Move script into main fn

This commit is contained in:
Marcin Kostrzewski 2022-05-05 22:33:34 +02:00
parent 46d7831b98
commit fe63ef269c

View File

@ -12,6 +12,9 @@ default_batch_size = 64
default_epochs = 5
device = "cuda" if torch.cuda.is_available() else "cpu"
def hour_to_int(text: str):
return float(text.replace(':', ''))
@ -86,7 +89,7 @@ def test(dataloader, model, loss_fn):
pred = model(X)
test_loss += loss_fn(pred, y).item()
test_loss /= num_batches
print(f"Avg loss: {test_loss:>8f} \n")
print(f"Avg loss (using {loss_fn}): {test_loss:>8f} \n")
def setup_args():
@ -97,7 +100,7 @@ def setup_args():
return args_parser.parse_args()
device = "cuda" if torch.cuda.is_available() else "cpu"
def main():
print(f"Using {device} device")
args = setup_args()
@ -129,3 +132,6 @@ print("Done!")
torch.save(model.state_dict(), './model_out')
print("Model saved in ./model_out file.")
if __name__ == "__main__":
main()