Update model.py

This commit is contained in:
Alicja Szulecka 2024-04-29 21:27:45 +02:00
parent ca24c39ada
commit 03f4d0b47a
1 changed files with 4 additions and 2 deletions

View File

@ -6,7 +6,7 @@ import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import torch.nn.functional as F
import os
import sys
device = (
@ -31,6 +31,9 @@ class Model(nn.Module):
return x
def main():
epochs = sys.argv[1]
print(epochs)
forest_train = pd.read_csv('forest_train.csv')
forest_val = pd.read_csv('forest_val.csv')
@ -60,7 +63,6 @@ def main():
val_loader = DataLoader(list(zip(X_val, y_val)), batch_size=64)
# Training loop
epochs = os.getenv("EPOCHS")
for epoch in range(epochs):
model.train() # Set model to training mode
running_loss = 0.0