Update model.py
This commit is contained in:
parent
ca24c39ada
commit
03f4d0b47a
6
model.py
6
model.py
@ -6,7 +6,7 @@ import pandas as pd
|
|||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
from sklearn.preprocessing import LabelEncoder
|
from sklearn.preprocessing import LabelEncoder
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import os
|
import sys
|
||||||
|
|
||||||
|
|
||||||
device = (
|
device = (
|
||||||
@ -31,6 +31,9 @@ class Model(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
epochs = sys.argv[1]
|
||||||
|
print(epochs)
|
||||||
|
|
||||||
forest_train = pd.read_csv('forest_train.csv')
|
forest_train = pd.read_csv('forest_train.csv')
|
||||||
forest_val = pd.read_csv('forest_val.csv')
|
forest_val = pd.read_csv('forest_val.csv')
|
||||||
|
|
||||||
@ -60,7 +63,6 @@ def main():
|
|||||||
val_loader = DataLoader(list(zip(X_val, y_val)), batch_size=64)
|
val_loader = DataLoader(list(zip(X_val, y_val)), batch_size=64)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
epochs = os.getenv("EPOCHS")
|
|
||||||
for epoch in range(epochs):
|
for epoch in range(epochs):
|
||||||
model.train() # Set model to training mode
|
model.train() # Set model to training mode
|
||||||
running_loss = 0.0
|
running_loss = 0.0
|
||||||
|
Loading…
Reference in New Issue
Block a user