diff --git a/model.py b/model.py index 12e8a95..4bd3e73 100644 --- a/model.py +++ b/model.py @@ -31,7 +31,7 @@ class Model(nn.Module): return x def main(): - epochs = sys.argv[1] + epochs = int(sys.argv[1]) print(epochs) forest_train = pd.read_csv('forest_train.csv')