p|s with torch
This commit is contained in:
parent
ecfafbf86c
commit
6577971e50
5272
dev-0/in.tsv
Normal file
File diff suppressed because one or more lines are too long
5272
dev-0/out.tsv
Normal file
File diff suppressed because it is too large
117
run.py
Normal file
@@ -0,0 +1,117 @@
import re

import numpy as np
import pandas as pd
import torch
from gensim import downloader

BATCH_SIZE = 64
EPOCHS = 50
FEATURES = 200  # dimensionality of the glove-twitter-200 vectors

class NeuralNetworkModel(torch.nn.Module):
    """Two-layer feed-forward binary classifier over document embeddings."""

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(FEATURES, 500)
        self.fc2 = torch.nn.Linear(500, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        x = torch.sigmoid(x)
        return x

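# A minimal shape sanity check (illustrative only; torch.randn stands in for
# real document embeddings):
#
#   model = NeuralNetworkModel()
#   out = model(torch.randn(4, FEATURES))  # shape (4, 1), values in (0, 1)
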
def train_model(X_train, y_train):
    model = NeuralNetworkModel()

    criterion = torch.nn.BCELoss()
    optimizer = torch.optim.ASGD(model.parameters(), lr=0.05)

    for epoch in range(EPOCHS):
        print(epoch)
        loss_score = 0
        acc_score = 0
        items_total = 0

        for i in range(0, y_train.shape[0], BATCH_SIZE):
            x = X_train[i:i + BATCH_SIZE]
            x = torch.tensor(np.array(x).astype(np.float32))
            y = y_train[i:i + BATCH_SIZE]
            y = torch.tensor(y.astype(np.float32)).reshape(-1, 1)

            y_pred = model(x)
            acc_score += torch.sum((y_pred > 0.5) == y).item()
            items_total += y.shape[0]

            optimizer.zero_grad()
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()

            loss_score += loss.item() * y.shape[0]

        # Running loss and accuracy for the epoch.
        print((loss_score / items_total), (acc_score / items_total))

    return model

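# Toy invocation (illustrative only; random features and string labels stand
# in for the real word2vec features and expected.tsv labels):
#
#   X = [np.random.rand(FEATURES) for _ in range(256)]
#   y = np.array(['0', '1'] * 128)
#   model = train_model(X, y)
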
def predict(model, x_test):
    y_dev = []

    with torch.no_grad():
        for i in range(0, len(x_test), BATCH_SIZE):
            x = x_test[i:i + BATCH_SIZE]
            x = torch.tensor(np.array(x).astype(np.float32))
            outputs = model(x)
            y = (outputs > 0.5)  # threshold the sigmoid outputs at 0.5
            y_dev.extend(y)

    return y_dev

# Load the pretrained embeddings once at module level, so that repeated
# word_to_vec calls do not re-download and re-parse the model.
word2vec = downloader.load("glove-twitter-200")


def word_to_vec(docs):
    # Mean-pool the vectors of all in-vocabulary tokens in each document,
    # falling back to a zero vector when no token is in the vocabulary.
    return [np.mean([word2vec[word.lower()] for word in doc.split()
                     if word.lower() in word2vec]
                    or [np.zeros(FEATURES)], axis=0) for doc in docs]

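# Illustrative behaviour (hypothetical inputs; assumes the embeddings above
# are loaded). Each document becomes one 200-dimensional vector, and a
# document whose tokens are all out of vocabulary maps to the zero vector:
#
#   word_to_vec(["good movie", "qqqzzz"])
#   # -> [mean of the "good" and "movie" vectors, np.zeros(FEATURES)]
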
def load_data(path):
    # return pd.read_csv(path, sep='\t', header=None)
    with open(path, 'r', encoding='utf8') as f:
        return f.readlines()

def write_res(data, path):
    with open(path, 'w') as f:
        for line in data:
            f.write(f'{line}\n')
    print(f"Data written to {path}")

def main():
    # Strip the trailing tab-separated number and newline from each input
    # line, and the trailing newline from each label.
    x_train = [re.sub(r'\t[0-9]+\n', '', i) for i in load_data('train/in.tsv')]
    y_train = [re.sub(r'\n', '', i) for i in load_data('train/expected.tsv')]

    x_train_word2vec = word_to_vec(x_train)
    y_train = np.array(y_train)

    model = train_model(x_train_word2vec, y_train)

    for path in ['dev-0', 'test-A']:
        x = [re.sub(r'\t[0-9]+\n', '', i) for i in load_data(f'{path}/in.tsv')]
        x_word2vec = word_to_vec(x)
        y = predict(model, x_word2vec)
        result = ['1' if bool(i) else '0' for i in y]
        write_res(result, f'{path}/out.tsv')


if __name__ == '__main__':
    main()
5152
test-A/in.tsv
Normal file
File diff suppressed because one or more lines are too long
5152
test-A/out.tsv
Normal file
File diff suppressed because it is too large
289579
train/in.tsv
Normal file
File diff suppressed because one or more lines are too long