# x1/train.py
# Commit 91decc353d by wojciechbatruszewicz: "Fix train"
# 2023-06-27 15:19:55 +02:00
#
# Listing metadata: 82 lines, 2.7 KiB, Python

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd
from sklearn.preprocessing import LabelBinarizer
import numpy as np
import argparse
class MyNeuralNetwork(nn.Module):
    """Small feed-forward binary classifier: Linear -> ReLU -> Linear -> Sigmoid.

    The network produces one probability per input row, which is the input
    format ``nn.BCELoss`` expects.

    Args:
        *args: Forwarded unchanged to ``nn.Module.__init__``.
        in_features: Number of input features per row (default 7).
        hidden_features: Width of the hidden layer (default 12).
        **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, *args, in_features: int = 7, hidden_features: int = 12, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Each layer is built exactly once; attribute names are kept stable so
        # saved state_dicts keep the keys fc1.* / fc2.*.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_features, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(..., in_features)`` inputs to ``(..., 1)`` probabilities in [0, 1]."""
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return x
def _is_id_column(name) -> bool:
    """Return True if *name* has a standalone ``id`` token (e.g. ``id``, ``ID``, ``user_id``)."""
    # Split on every non-alphanumeric character so names that merely contain
    # the letters "id" (e.g. "forehead_width_cm") are not treated as ids.
    cleaned = "".join(ch if ch.isalnum() else " " for ch in str(name).lower())
    return "id" in cleaned.split()


def _binarize(series: pd.Series) -> pd.Series:
    """Encode a column with at most two distinct values as a 0/1 ``Int16`` column.

    Distinct values are sorted and the larger one maps to 1. This is the same
    convention as ``sklearn.preprocessing.LabelBinarizer`` for binary input.
    A single-valued column becomes all zeros.

    Raises:
        ValueError: if the column has more than two distinct values, because a
            single 0/1 column cannot represent it.
    """
    classes = sorted(series.unique())
    if len(classes) > 2:
        raise ValueError(
            f"Column {series.name!r} has {len(classes)} distinct values; "
            "only binary categorical columns are supported."
        )
    if len(classes) == 2:
        is_positive = series == classes[1]
    else:
        is_positive = pd.Series(False, index=series.index)
    return is_positive.astype(pd.Int16Dtype())


def prepare_df_for_nn(df):
    """Return a numeric copy of *df* ready to be converted to tensors.

    Steps:
      1. Drop identifier columns, i.e. names with a standalone ``id`` token
         (``id``, ``ID``, ``user_id``, ...).
      2. Reset the row index to 0..n-1 without adding it as a column.
      3. Encode every object / string / categorical column as 0/1 ``Int16``.

    The caller's DataFrame is not modified.

    Args:
        df: Raw frame, e.g. as returned by ``pd.read_csv``.

    Returns:
        A new ``pd.DataFrame`` with the remaining columns in their original order.

    Raises:
        ValueError: if a categorical column has more than two distinct values.
    """
    id_columns = [column for column in df.columns if _is_id_column(column)]
    # drop() returns a new frame, so df itself is untouched. drop=True keeps
    # the old row labels out of the feature columns.
    result = df.drop(columns=id_columns).reset_index(drop=True)
    for column in result.columns:
        dtype = result[column].dtype
        if (
            pd.api.types.is_object_dtype(dtype)
            or pd.api.types.is_string_dtype(dtype)
            or isinstance(dtype, pd.CategoricalDtype)
        ):
            result[column] = _binarize(result[column])
    return result
def load_data(path):
    """Read a CSV file and wrap it as a ``TensorDataset`` of (features, target).

    The frame is cleaned with ``prepare_df_for_nn``. Every column except the
    last becomes the float32 feature matrix, and the last column becomes the
    float32 target vector.
    """
    frame = prepare_df_for_nn(pd.read_csv(path))
    feature_part, target_part = frame.iloc[:, :-1], frame.iloc[:, -1]

    def _as_float32(part):
        # Convert to a float64 NumPy array first, then let torch down-cast.
        return torch.tensor(part.values.astype(float), dtype=torch.float32)

    return TensorDataset(_as_float32(feature_part), _as_float32(target_part))
def train(epochs, dataloader_train, lr=0.001, model=None):
    """Fit a binary classifier with Adam and ``nn.BCELoss``, then return it.

    Args:
        epochs: Number of full passes over ``dataloader_train``.
        dataloader_train: Iterable of ``(inputs, labels)`` batches. ``labels``
            must hold one 0/1 float target per input row.
        lr: Adam learning rate (default 0.001).
        model: Module to train. When None, a fresh ``MyNeuralNetwork()`` is
            built. The module must output probabilities shaped ``(batch, 1)``.

    Returns:
        The trained model. This is the same object as *model* when one is passed.
    """
    if model is None:
        model = MyNeuralNetwork()
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        n_samples = 0
        for inputs, labels in dataloader_train:
            outputs = model(inputs)
            # BCELoss requires targets with the same (batch, 1) shape as outputs.
            targets = labels.reshape(labels.shape[0], 1)
            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight by batch size so a smaller final batch does not skew the mean.
            total_loss += loss.item() * targets.shape[0]
            n_samples += targets.shape[0]
        # An empty loader reports nan instead of crashing on an unbound loss.
        mean_loss = total_loss / n_samples if n_samples else float("nan")
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss:.4f}")
    return model
def main(argv=None):
    """Command-line entry point: train on a CSV file and save the model weights.

    Args:
        argv: Argument list to parse. ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description="Train the gender-classification network and save its state_dict."
    )
    parser.add_argument(
        "--epochs", type=int, default=10,
        help="Number of training epochs (default: %(default)s).",
    )
    parser.add_argument(
        "--data", default="gender_classification_train.csv",
        help="Training CSV; the last column is the target (default: %(default)s).",
    )
    parser.add_argument(
        "--batch-size", type=int, default=32,
        help="Mini-batch size (default: %(default)s).",
    )
    parser.add_argument(
        "--output", default="model.pt",
        help="Path for the saved model state_dict (default: %(default)s).",
    )
    args = parser.parse_args(argv)
    if args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")

    train_dataset = load_data(args.data)
    dataloader_train = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    model = train(args.epochs, dataloader_train)
    torch.save(model.state_dict(), args.output)


if __name__ == "__main__":
    main()