Move script into main fn

This commit is contained in:
Marcin Kostrzewski 2022-05-05 22:33:34 +02:00
parent 46d7831b98
commit fe63ef269c

View File

@ -12,6 +12,9 @@ default_batch_size = 64
default_epochs = 5 default_epochs = 5
device = "cuda" if torch.cuda.is_available() else "cpu"
def hour_to_int(text: str): def hour_to_int(text: str):
return float(text.replace(':', '')) return float(text.replace(':', ''))
@ -86,7 +89,7 @@ def test(dataloader, model, loss_fn):
pred = model(X) pred = model(X)
test_loss += loss_fn(pred, y).item() test_loss += loss_fn(pred, y).item()
test_loss /= num_batches test_loss /= num_batches
print(f"Avg loss: {test_loss:>8f} \n") print(f"Avg loss (using {loss_fn}): {test_loss:>8f} \n")
def setup_args(): def setup_args():
@ -97,35 +100,38 @@ def setup_args():
return args_parser.parse_args() return args_parser.parse_args()
device = "cuda" if torch.cuda.is_available() else "cpu" def main():
print(f"Using {device} device") print(f"Using {device} device")
args = setup_args() args = setup_args()
batch_size = args.batchSize batch_size = args.batchSize
plant_test = PlantsDataset('data/Plant_1_Generation_Data.csv.test') plant_test = PlantsDataset('data/Plant_1_Generation_Data.csv.test')
plant_train = PlantsDataset('data/Plant_1_Generation_Data.csv.train') plant_train = PlantsDataset('data/Plant_1_Generation_Data.csv.train')
train_dataloader = DataLoader(plant_train, batch_size=batch_size) train_dataloader = DataLoader(plant_train, batch_size=batch_size)
test_dataloader = DataLoader(plant_test, batch_size=batch_size) test_dataloader = DataLoader(plant_test, batch_size=batch_size)
for i, (data, labels) in enumerate(train_dataloader): for i, (data, labels) in enumerate(train_dataloader):
print(data.shape, labels.shape) print(data.shape, labels.shape)
print(data, labels) print(data, labels)
break break
model = MLP() model = MLP()
print(model) print(model)
loss_fn = nn.MSELoss() loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
epochs = args.epochs epochs = args.epochs
for t in range(epochs): for t in range(epochs):
print(f"Epoch {t + 1}\n-------------------------------") print(f"Epoch {t + 1}\n-------------------------------")
train(train_dataloader, model, loss_fn, optimizer) train(train_dataloader, model, loss_fn, optimizer)
test(test_dataloader, model, loss_fn) test(test_dataloader, model, loss_fn)
print("Done!") print("Done!")
torch.save(model.state_dict(), './model_out') torch.save(model.state_dict(), './model_out')
print("Model saved in ./model_out file.") print("Model saved in ./model_out file.")
if __name__ == "__main__":
main()