diff --git a/train_model.py b/train_model.py index bdf091a..21bf880 100644 --- a/train_model.py +++ b/train_model.py @@ -109,3 +109,6 @@ for t in range(epochs): train(train_dataloader, model, loss_fn, optimizer) test(test_dataloader, model, loss_fn) print("Done!") + +torch.save(model.state_dict(), './model_out') +print("Model saved in ./model_out file.")