diff --git a/train.py b/train.py index 4a5f6a8..d827367 100644 --- a/train.py +++ b/train.py @@ -93,6 +93,7 @@ def my_main(batch_size, learning_rate, epochs): model_scripted = torch.jit.script(model) # Export to TorchScript model_scripted.save('model_scripted.pt') # Save + exint.add_artifact('model_scripted.pt') exint.run()