diff --git a/model.py b/model.py index fed8835..12e8a95 100644 --- a/model.py +++ b/model.py @@ -6,7 +6,7 @@ import pandas as pd from sklearn.model_selection import train_test_split from sklearn.preprocessing import LabelEncoder import torch.nn.functional as F -import os +import sys device = ( @@ -31,6 +31,9 @@ class Model(nn.Module): return x def main(): + epochs = sys.argv[1] + print(epochs) + forest_train = pd.read_csv('forest_train.csv') forest_val = pd.read_csv('forest_val.csv') @@ -60,7 +63,6 @@ def main(): val_loader = DataLoader(list(zip(X_val, y_val)), batch_size=64) # Training loop - epochs = os.getenv("EPOCHS") for epoch in range(epochs): model.train() # Set model to training mode running_loss = 0.0