Move script into main fn
This commit is contained in:
parent
46d7831b98
commit
fe63ef269c
@ -12,6 +12,9 @@ default_batch_size = 64
|
|||||||
default_epochs = 5
|
default_epochs = 5
|
||||||
|
|
||||||
|
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
|
||||||
def hour_to_int(text: str):
|
def hour_to_int(text: str):
|
||||||
return float(text.replace(':', ''))
|
return float(text.replace(':', ''))
|
||||||
|
|
||||||
@ -86,7 +89,7 @@ def test(dataloader, model, loss_fn):
|
|||||||
pred = model(X)
|
pred = model(X)
|
||||||
test_loss += loss_fn(pred, y).item()
|
test_loss += loss_fn(pred, y).item()
|
||||||
test_loss /= num_batches
|
test_loss /= num_batches
|
||||||
print(f"Avg loss: {test_loss:>8f} \n")
|
print(f"Avg loss (using {loss_fn}): {test_loss:>8f} \n")
|
||||||
|
|
||||||
|
|
||||||
def setup_args():
|
def setup_args():
|
||||||
@ -97,35 +100,38 @@ def setup_args():
|
|||||||
return args_parser.parse_args()
|
return args_parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
def main():
|
||||||
print(f"Using {device} device")
|
print(f"Using {device} device")
|
||||||
|
|
||||||
args = setup_args()
|
args = setup_args()
|
||||||
|
|
||||||
batch_size = args.batchSize
|
batch_size = args.batchSize
|
||||||
|
|
||||||
plant_test = PlantsDataset('data/Plant_1_Generation_Data.csv.test')
|
plant_test = PlantsDataset('data/Plant_1_Generation_Data.csv.test')
|
||||||
plant_train = PlantsDataset('data/Plant_1_Generation_Data.csv.train')
|
plant_train = PlantsDataset('data/Plant_1_Generation_Data.csv.train')
|
||||||
|
|
||||||
train_dataloader = DataLoader(plant_train, batch_size=batch_size)
|
train_dataloader = DataLoader(plant_train, batch_size=batch_size)
|
||||||
test_dataloader = DataLoader(plant_test, batch_size=batch_size)
|
test_dataloader = DataLoader(plant_test, batch_size=batch_size)
|
||||||
|
|
||||||
for i, (data, labels) in enumerate(train_dataloader):
|
for i, (data, labels) in enumerate(train_dataloader):
|
||||||
print(data.shape, labels.shape)
|
print(data.shape, labels.shape)
|
||||||
print(data, labels)
|
print(data, labels)
|
||||||
break
|
break
|
||||||
|
|
||||||
model = MLP()
|
model = MLP()
|
||||||
print(model)
|
print(model)
|
||||||
|
|
||||||
loss_fn = nn.MSELoss()
|
loss_fn = nn.MSELoss()
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
|
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||||
epochs = args.epochs
|
epochs = args.epochs
|
||||||
for t in range(epochs):
|
for t in range(epochs):
|
||||||
print(f"Epoch {t + 1}\n-------------------------------")
|
print(f"Epoch {t + 1}\n-------------------------------")
|
||||||
train(train_dataloader, model, loss_fn, optimizer)
|
train(train_dataloader, model, loss_fn, optimizer)
|
||||||
test(test_dataloader, model, loss_fn)
|
test(test_dataloader, model, loss_fn)
|
||||||
print("Done!")
|
print("Done!")
|
||||||
|
|
||||||
torch.save(model.state_dict(), './model_out')
|
torch.save(model.state_dict(), './model_out')
|
||||||
print("Model saved in ./model_out file.")
|
print("Model saved in ./model_out file.")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
Loading…
Reference in New Issue
Block a user