Move script into main fn
parent 46d7831b98
commit fe63ef269c
@@ -12,6 +12,9 @@ default_batch_size = 64
 default_epochs = 5
 
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
 def hour_to_int(text: str):
     return float(text.replace(':', ''))
 
 
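A side note on the context lines above: hour_to_int drops the colon from a clock-style timestamp and parses the remaining digits as a float (the "HH:MM" input format is an assumption based on the dataset's time column, not shown in this diff). A quick sketch of its behaviour:

    # hour_to_int drops the colon and parses what is left as a float,
    # so "12:45" -> 1245.0 and "06:00" -> 600.0 (assuming HH:MM input).
    def hour_to_int(text: str):
        return float(text.replace(':', ''))

    print(hour_to_int("12:45"))  # 1245.0
    print(hour_to_int("06:00"))  # 600.0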
@@ -86,7 +89,7 @@ def test(dataloader, model, loss_fn):
             pred = model(X)
             test_loss += loss_fn(pred, y).item()
     test_loss /= num_batches
-    print(f"Avg loss: {test_loss:>8f} \n")
+    print(f"Avg loss (using {loss_fn}): {test_loss:>8f} \n")
 
 
 def setup_args():
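For reference on the changed print: interpolating an nn.Module into an f-string uses its repr, so with loss_fn = nn.MSELoss() the updated line reports the loss function by name. A minimal sketch (the loss value here is a placeholder, only to show the formatting):

    import torch.nn as nn

    loss_fn = nn.MSELoss()
    test_loss = 0.123456  # placeholder value, not a real result
    print(f"Avg loss (using {loss_fn}): {test_loss:>8f} \n")
    # -> Avg loss (using MSELoss()): 0.123456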
@@ -97,35 +100,38 @@ def setup_args():
     return args_parser.parse_args()
 
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
-print(f"Using {device} device")
+def main():
+    print(f"Using {device} device")
 
-args = setup_args()
+    args = setup_args()
 
-batch_size = args.batchSize
+    batch_size = args.batchSize
 
-plant_test = PlantsDataset('data/Plant_1_Generation_Data.csv.test')
-plant_train = PlantsDataset('data/Plant_1_Generation_Data.csv.train')
+    plant_test = PlantsDataset('data/Plant_1_Generation_Data.csv.test')
+    plant_train = PlantsDataset('data/Plant_1_Generation_Data.csv.train')
 
-train_dataloader = DataLoader(plant_train, batch_size=batch_size)
-test_dataloader = DataLoader(plant_test, batch_size=batch_size)
+    train_dataloader = DataLoader(plant_train, batch_size=batch_size)
+    test_dataloader = DataLoader(plant_test, batch_size=batch_size)
 
-for i, (data, labels) in enumerate(train_dataloader):
-    print(data.shape, labels.shape)
-    print(data, labels)
-    break
+    for i, (data, labels) in enumerate(train_dataloader):
+        print(data.shape, labels.shape)
+        print(data, labels)
+        break
 
-model = MLP()
-print(model)
+    model = MLP()
+    print(model)
 
-loss_fn = nn.MSELoss()
-optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
-epochs = args.epochs
-for t in range(epochs):
-    print(f"Epoch {t + 1}\n-------------------------------")
-    train(train_dataloader, model, loss_fn, optimizer)
-    test(test_dataloader, model, loss_fn)
-print("Done!")
+    loss_fn = nn.MSELoss()
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+    epochs = args.epochs
+    for t in range(epochs):
+        print(f"Epoch {t + 1}\n-------------------------------")
+        train(train_dataloader, model, loss_fn, optimizer)
+        test(test_dataloader, model, loss_fn)
+    print("Done!")
 
-torch.save(model.state_dict(), './model_out')
-print("Model saved in ./model_out file.")
+    torch.save(model.state_dict(), './model_out')
+    print("Model saved in ./model_out file.")
+
+if __name__ == "__main__":
+    main()
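The net effect of the guard added at the end is that importing this module no longer kicks off training; the pipeline only runs when the file is executed directly. A minimal sketch of the pattern (placeholder names, not the project's actual code):

    # sketch.py -- illustrative only
    def main():
        print("dataloaders, training loop and model saving would run here")

    if __name__ == "__main__":
        # Executes only via `python sketch.py`, not on `import sketch`.
        main()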