diff --git a/train.py b/train.py index e07529b..623a4c0 100644 --- a/train.py +++ b/train.py @@ -41,7 +41,7 @@ X_test = torch.FloatTensor(X_test) y_train = torch.LongTensor(y_train) y_test = torch.LongTensor(y_test) -#### Model +### Model class ANN_Model(nn.Module): def __init__(self,input_features=82,hidden1=20,hidden2=20,out_features=3):