From 03f4d0b47a6c23f413a846a9e23d4fb67ad2a7dd Mon Sep 17 00:00:00 2001 From: Alicja Szulecka <73056579+AliSzu@users.noreply.github.com> Date: Mon, 29 Apr 2024 21:27:45 +0200 Subject: [PATCH] Update model.py --- model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/model.py b/model.py index fed8835..12e8a95 100644 --- a/model.py +++ b/model.py @@ -6,7 +6,7 @@ import pandas as pd from sklearn.model_selection import train_test_split from sklearn.preprocessing import LabelEncoder import torch.nn.functional as F -import os +import sys device = ( @@ -31,6 +31,9 @@ class Model(nn.Module): return x def main(): + epochs = sys.argv[1] + print(epochs) + forest_train = pd.read_csv('forest_train.csv') forest_val = pd.read_csv('forest_val.csv') @@ -60,7 +63,6 @@ def main(): val_loader = DataLoader(list(zip(X_val, y_val)), batch_size=64) # Training loop - epochs = os.getenv("EPOCHS") for epoch in range(epochs): model.train() # Set model to training mode running_loss = 0.0