Update model.py
This commit is contained in:
parent
ca24c39ada
commit
03f4d0b47a
6
model.py
6
model.py
@ -6,7 +6,7 @@ import pandas as pd
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.preprocessing import LabelEncoder
|
||||
import torch.nn.functional as F
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
device = (
|
||||
@ -31,6 +31,9 @@ class Model(nn.Module):
|
||||
return x
|
||||
|
||||
def main():
|
||||
epochs = sys.argv[1]
|
||||
print(epochs)
|
||||
|
||||
forest_train = pd.read_csv('forest_train.csv')
|
||||
forest_val = pd.read_csv('forest_val.csv')
|
||||
|
||||
@ -60,7 +63,6 @@ def main():
|
||||
val_loader = DataLoader(list(zip(X_val, y_val)), batch_size=64)
|
||||
|
||||
# Training loop
|
||||
epochs = os.getenv("EPOCHS")
|
||||
for epoch in range(epochs):
|
||||
model.train() # Set model to training mode
|
||||
running_loss = 0.0
|
||||
|
Loading…
Reference in New Issue
Block a user