log_reg_um/log_reg.py

73 lines
2.1 KiB
Python
Raw Permalink Normal View History

2021-05-26 00:07:15 +02:00
import gensim.downloader as gensim
import numpy as np
import pandas as pd
import torch
from nltk.tokenize import word_tokenize
class NeuralNetworkModel(torch.nn.Module):
    """Feed-forward binary classifier with one hidden layer.

    The forward pass is linear -> ReLU -> linear -> sigmoid, so every input
    vector is mapped to a single value in ``[0, 1]``.

    Args:
        input_dim: Size of each input feature vector (default 300).
        hidden_dim: Number of units in the hidden layer (default 500).
    """

    def __init__(self, input_dim: int = 300, hidden_dim: int = 500):
        super().__init__()
        # Layer attribute names are unchanged so parameter names
        # (state-dict keys) stay the same.
        self.l01 = torch.nn.Linear(input_dim, hidden_dim)
        self.l02 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return the sigmoid-activated output for the batch ``x``.

        Args:
            x: Float tensor whose last dimension equals ``input_dim``.

        Returns:
            Tensor with a last dimension of 1 holding values in ``[0, 1]``.
        """
        hidden = torch.relu(self.l01(x))
        return torch.sigmoid(self.l02(hidden))
def doc2vec(doc):
    """Embed a tokenised document as the mean of its known word vectors.

    Tokens that are not present in the module-level ``word2vec`` lookup are
    ignored. When no token is found (including an empty ``doc``), a
    300-element zero vector is returned instead.
    """
    known_vectors = [word2vec[token] for token in doc if token in word2vec]
    if not known_vectors:
        # Nothing to average: fall back to a single all-zero vector.
        known_vectors = [np.zeros(300)]
    return np.mean(known_vectors, axis=0)
2021-05-30 22:43:03 +02:00
# Load the tab-separated corpora. quoting=3 is csv.QUOTE_NONE, so quote
# characters inside the text are kept literally rather than opening a
# quoted field.
# Malformed training rows are skipped. `on_bad_lines='skip'` (pandas >= 1.3)
# replaces `error_bad_lines=False`, which was removed in pandas 2.0.
x_train = pd.read_table('in-train.tsv.xz', compression='xz', sep='\t', header=None, on_bad_lines='skip', quoting=3)
y_train = pd.read_table('expected-train.tsv', sep='\t', header=None, quoting=3)
x_dev = pd.read_table('in-dev.tsv.xz', compression='xz', sep='\t', header=None, quoting=3)

# A skipped training row would shift every later text against its label,
# so stop here if the two files no longer line up row for row.
if len(x_train) != len(y_train):
    raise ValueError(
        f'in-train.tsv.xz has {len(x_train)} rows after skipping bad lines '
        f'but expected-train.tsv has {len(y_train)}; texts and labels are misaligned'
    )
2021-05-26 00:07:15 +02:00
# Labels are in the first (only) column of the expected file.
y_train = y_train[0]

# Lower-case the first column of each corpus and split it into word tokens.
x_train = list(map(word_tokenize, x_train[0].str.lower()))
x_dev = list(map(word_tokenize, x_dev[0].str.lower()))

# Pretrained embeddings looked up by doc2vec().
word2vec = gensim.load('word2vec-google-news-300')

# Replace every token list with its averaged embedding vector.
x_train = list(map(doc2vec, x_train))
x_dev = list(map(doc2vec, x_dev))
# Model, loss and optimiser for binary classification on document vectors.
model = NeuralNetworkModel()
BATCH_SIZE = 1024
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters())

# Convert the training data to tensors once, before the loops. Calling
# torch.tensor() on a Python list of ndarrays for every batch of every epoch
# is a known slow path; stacking with NumPy first avoids it.
train_features = torch.tensor(np.asarray(x_train), dtype=torch.float32)
# BCELoss needs float targets with the same (batch, 1) shape as the output.
train_labels = torch.tensor(y_train.astype(np.float32).to_numpy()).reshape(-1, 1)

for epoch in range(5):
    model.train()
    # Mini-batches are taken in file order; there is no shuffling.
    for i in range(0, y_train.shape[0], BATCH_SIZE):
        X = train_features[i:i + BATCH_SIZE]
        y = train_labels[i:i + BATCH_SIZE]
        optimizer.zero_grad()
        outputs = model(X)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()
# Predict a 0/1 label for every dev document.
y_dev = []
# NOTE(review): y_test is never read in the visible code; kept in case
# something later in the file relies on the name.
y_test = []
model.eval()
with torch.no_grad():
    for i in range(0, len(x_dev), BATCH_SIZE):
        # Stack the batch with NumPy first; torch.tensor() on a list of
        # ndarrays is a known slow path.
        X = torch.tensor(np.asarray(x_dev[i:i + BATCH_SIZE]), dtype=torch.float32)
        outputs = model(X)
        # Threshold the sigmoid output at 0.5 to get the predicted class.
        y = (outputs > 0.5)
        # Flatten the (batch, 1) result to plain Python values so the final
        # array is 1-D and can be used as a DataFrame column.
        y_dev.extend(y.reshape(-1).tolist())
y_dev = np.asarray(y_dev, dtype=np.int32)
Y_dev = pd.DataFrame({'label': y_dev})
2021-05-30 22:43:03 +02:00
# Write the dev predictions as a tab-separated file with no header row and
# no index column.
Y_dev.to_csv(
    r'dev-out.tsv',
    sep='\t',
    index=False,
    header=False,
)