ium_444498/neutral_network.py


import numpy as np
import pandas as pd
import torch
import argparse
from torch import nn
from torch.utils.data import DataLoader, Dataset
from sacred import Experiment
from sacred.observers import FileStorageObserver, MongoObserver

default_batch_size = 64
default_epochs = 4
device = "cuda" if torch.cuda.is_available() else "cpu"


class AtpDataset(Dataset):
    """ATP match stats: predict the loser's average (AvgL) from the winner's average (AvgW)."""

    def __init__(self, file_name):
        df = pd.read_csv(file_name, usecols=["AvgL", "AvgW"])
        df = df.dropna()
        # AvgW (winner average) is the input, AvgL (loser average) the target.
        x = df["AvgW"].values
        y = df["AvgL"].values
        self.x_train = torch.from_numpy(x)
        self.y_train = torch.from_numpy(y)

    def __len__(self):
        return len(self.y_train)

    def __getitem__(self, idx):
        # Cast to float32 on access; read_csv loads the columns as float64.
        return self.x_train[idx].float(), self.y_train[idx].float()
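
# Illustrative access sketch (assumes an atp_train.csv with AvgL/AvgW columns):
#   ds = AtpDataset("atp_train.csv")
#   x0, y0 = ds[0]   # two float32 scalar tensors: winner avg in, loser avg out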


class MLP(nn.Module):
    """Small fully connected regressor mapping one input feature to one output."""

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(1, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )

    def forward(self, x):
        # Reshape each batch of scalars to a (batch, 1) feature matrix.
        x = x.view(x.size(0), -1)
        return self.layers(x)
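
# Quick shape check (illustrative sketch, not part of the training flow):
#   mlp = MLP()
#   out = mlp(torch.rand(8))   # forward() reshapes (8,) into (8, 1)
#   assert out.shape == (8, 1)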


def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Squeeze the (batch, 1) prediction to match the (batch,) target;
        # otherwise MSELoss would broadcast both to (batch, batch).
        pred = model(X)
        loss = loss_fn(pred.squeeze(1), y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")


def test(dataloader, model, loss_fn):
    num_batches = len(dataloader)
    model.eval()
    test_loss = 0.0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred.squeeze(1), y).item()
    test_loss /= num_batches
    print(f"Avg loss (using {loss_fn}): {test_loss:>8f} \n")
    return test_loss


def setup_args():
    # Legacy CLI parsing; the Sacred config below supersedes it, and nothing
    # in this file calls it anymore.
    args_parser = argparse.ArgumentParser(prefix_chars="-")
    args_parser.add_argument("-b", "--batchSize", type=int, default=default_batch_size)
    args_parser.add_argument("-e", "--epochs", type=int, default=default_epochs)
    return args_parser.parse_args()


def main(batch_size, epochs):
    print(f"Using {device} device")

    plant_test = AtpDataset("atp_test.csv")
    plant_train = AtpDataset("atp_train.csv")

    train_dataloader = DataLoader(plant_train, batch_size=batch_size)
    test_dataloader = DataLoader(plant_test, batch_size=batch_size)

    # Peek at the first batch to sanity-check tensor shapes.
    for data, labels in train_dataloader:
        print(data.shape, labels.shape)
        print(data, labels)
        break

    # Move the model to the same device the batches are sent to.
    model = MLP().to(device)
    print(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for t in range(epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        train(train_dataloader, model, loss_fn, optimizer)
        test(test_dataloader, model, loss_fn)
    print("Finished!")

    torch.save(model.state_dict(), "./model.zip")
    print("Model saved in ./model.zip file.")


def setup_experiment():
    ex = Experiment('Simple Experiment')
    ex.observers.append(FileStorageObserver('sacred_runs'))
    ex.observers.append(MongoObserver(url='mongodb://admin:IUM_2021@172.17.0.1:27017', db_name='sacred'))
    return ex


ex = setup_experiment()


@ex.config
def experiment_config():
    batch_size = 64
    epochs = 5


@ex.automain
def run(batch_size, epochs):
    main(batch_size, epochs)
    ex.add_artifact('model.zip')
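
# Usage sketch (assumes Sacred's standard command line; the config values above
# can be overridden per run):
#
#   python neutral_network.py with batch_size=128 epochs=10
#
# Minimal sketch of reloading the saved weights for inference:
#
#   model = MLP()
#   model.load_state_dict(torch.load("model.zip"))
#   model.eval()
#   with torch.no_grad():
#       pred = model(torch.tensor([[1.85]]))  # hypothetical winner-average input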