# ium_470618/train-eval/learning.py — 78 lines, 2.3 KiB, Python, executable file

#!/usr/bin/python3
import numpy as np
import torch
from torch import nn
import pandas as pd
# import subprocess
import sys
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
from sklearn.preprocessing import LabelEncoder
class Model(nn.Module):
    """Small fully connected classifier: input_dim -> 50 -> 20 -> 2.

    Hidden layers use ReLU. ``forward`` returns per-class probabilities
    (softmax over the last dimension), not raw logits.
    """

    def __init__(self, input_dim):
        """Build the three linear layers.

        Args:
            input_dim: number of input features per sample.
        """
        super().__init__()
        self.layer1 = nn.Linear(input_dim, 50)
        self.layer2 = nn.Linear(50, 20)
        self.layer3 = nn.Linear(20, 2)

    def forward(self, x):
        """Return class probabilities for ``x`` of shape (..., input_dim).

        The result has shape (..., 2) and sums to 1 along the last axis.
        """
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        # Explicit dim: the implicit-dim form is deprecated (warns each call)
        # and would normalise over dim 0 for 1-D/3-D inputs instead of over
        # the class axis, which is always the last one here.
        return F.softmax(self.layer3(x), dim=-1)
def print_(loss):
    """Report ``loss`` on stdout, preceded by a fixed label."""
    label = "The loss calculated: "
    print(label, loss)
if __name__ == "__main__":
if sys.argv[1]=='default':
alpha = 0.003 #learning rate
epochs = 1000
else:
pass
#TODO split args string to make hyperparameters work
df = pd.read_csv("output.csv")
df = df.dropna() #drop NA values
columns_to_normalize=['Age','Fare'] #NORMALIZATION
for colname in columns_to_normalize:
df[colname]=(df[colname]-df[colname].min())/(df[colname].max()-df[colname].min())
X = df[['Pclass', 'Sex', 'Age','SibSp', 'Fare']] #only reasonable numerical data
Y = df[['Survived']]
X.loc[:,('Sex')].replace(['female', 'male'], [0,1], inplace=True) #categorical data transformed to
X_train, X_test, Y_train, Y_test = train_test_split(X,Y, test_size=0.2, shuffle=True) #split the date into train and test sets
testing_data = pd.concat([X_test, Y_test], axis=1)
testing_data.to_csv('testing_data.csv', sep=',')
Y_train = np.ravel(Y_train)
encoder = LabelEncoder()
encoder.fit(Y_train)
Y_train = encoder.transform(Y_train)
Xt = torch.tensor(X_train.values, dtype = torch.float32)
Yt = torch.tensor(Y_train, dtype=torch.long)
model = Model(Xt.shape[1])
optimizer = torch.optim.Adam(model.parameters(), lr=alpha)
loss_fn = nn.CrossEntropyLoss()
#TRAINING LOOP
for epoch in range(1, epochs+1):
print("Epoch #", epoch)
y_pred = model(Xt)
loss = loss_fn(y_pred, Yt)
print_(loss.item())
optimizer.zero_grad()
loss.backward()
optimizer.step()
torch.save(model.state_dict(), 'model.pt')