From fe63ef269c0ee4dfe0d7d44ae6d92d4dc68c07fa Mon Sep 17 00:00:00 2001 From: Marcin Kostrzewski Date: Thu, 5 May 2022 22:33:34 +0200 Subject: [PATCH] Move script into main fn --- train_model.py | 56 ++++++++++++++++++++++++++++---------------------- 1 file changed, 31 insertions(+), 25 deletions(-) diff --git a/train_model.py b/train_model.py index 5fe4978..ecc27ef 100644 --- a/train_model.py +++ b/train_model.py @@ -12,6 +12,9 @@ default_batch_size = 64 default_epochs = 5 +device = "cuda" if torch.cuda.is_available() else "cpu" + + def hour_to_int(text: str): return float(text.replace(':', '')) @@ -86,7 +89,7 @@ def test(dataloader, model, loss_fn): pred = model(X) test_loss += loss_fn(pred, y).item() test_loss /= num_batches - print(f"Avg loss: {test_loss:>8f} \n") + print(f"Avg loss (using {loss_fn}): {test_loss:>8f} \n") def setup_args(): @@ -97,35 +100,38 @@ def setup_args(): return args_parser.parse_args() -device = "cuda" if torch.cuda.is_available() else "cpu" -print(f"Using {device} device") +def main(): + print(f"Using {device} device") -args = setup_args() + args = setup_args() -batch_size = args.batchSize + batch_size = args.batchSize -plant_test = PlantsDataset('data/Plant_1_Generation_Data.csv.test') -plant_train = PlantsDataset('data/Plant_1_Generation_Data.csv.train') + plant_test = PlantsDataset('data/Plant_1_Generation_Data.csv.test') + plant_train = PlantsDataset('data/Plant_1_Generation_Data.csv.train') -train_dataloader = DataLoader(plant_train, batch_size=batch_size) -test_dataloader = DataLoader(plant_test, batch_size=batch_size) + train_dataloader = DataLoader(plant_train, batch_size=batch_size) + test_dataloader = DataLoader(plant_test, batch_size=batch_size) -for i, (data, labels) in enumerate(train_dataloader): - print(data.shape, labels.shape) - print(data, labels) - break + for i, (data, labels) in enumerate(train_dataloader): + print(data.shape, labels.shape) + print(data, labels) + break -model = MLP() -print(model) + model = MLP() + print(model) -loss_fn = nn.MSELoss() -optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) -epochs = args.epochs -for t in range(epochs): - print(f"Epoch {t + 1}\n-------------------------------") - train(train_dataloader, model, loss_fn, optimizer) - test(test_dataloader, model, loss_fn) -print("Done!") + loss_fn = nn.MSELoss() + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + epochs = args.epochs + for t in range(epochs): + print(f"Epoch {t + 1}\n-------------------------------") + train(train_dataloader, model, loss_fn, optimizer) + test(test_dataloader, model, loss_fn) + print("Done!") -torch.save(model.state_dict(), './model_out') -print("Model saved in ./model_out file.") + torch.save(model.state_dict(), './model_out') + print("Model saved in ./model_out file.") + +if __name__ == "__main__": + main() \ No newline at end of file