import numpy as np
import torch
from gensim import downloader
BATCH_SIZE = 10
EPOCHS = 100
FEATURES = 200  # dimensionality of the glove-twitter-200 vectors
class NeuralNetworkModel(torch.nn.Module):
    def __init__(self):
        super(NeuralNetworkModel, self).__init__()
        self.fc1 = torch.nn.Linear(FEATURES, 500)
        self.fc2 = torch.nn.Linear(500, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        x = torch.sigmoid(x)
        return x
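
# Illustrative sanity check (not part of the original script; "probe" is a throwaway
# name): a dummy batch of FEATURES-dimensional vectors should map to one sigmoid
# probability per row.
probe = NeuralNetworkModel()(torch.zeros(BATCH_SIZE, FEATURES))
assert probe.shape == (BATCH_SIZE, 1)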
# Pre-trained 200-dimensional GloVe embeddings trained on Twitter data.
word2vec = downloader.load("glove-twitter-200")
def readData(fileName):
    # One lower-cased document per line.
    with open(f'{fileName}/in.tsv', 'r', encoding='utf8') as f:
        X = np.array([x.strip().lower() for x in f.readlines()])
    # One binary label per line.
    with open(f'{fileName}/expected.tsv', 'r', encoding='utf8') as f:
        y = np.array([int(x.strip()) for x in f.readlines()])
    return X, y
X_file, y_file = readData('dev-0')

# Represent each document as the mean of the GloVe vectors of its known words;
# a document with no known words gets a zero vector.
x_train_w2v = [np.mean([word2vec[word] for word in doc.split() if word in word2vec]
                       or [np.zeros(FEATURES)], axis=0) for doc in X_file]
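
# The same mean-pooling is repeated for test-A further down; a small helper like
# this (not part of the original code) would avoid that duplication.
def document_vector(doc):
    # Average the embeddings of the words GloVe knows; zero vector if none match.
    vectors = [word2vec[word] for word in doc.split() if word in word2vec]
    return np.mean(vectors or [np.zeros(FEATURES)], axis=0)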
def train_model(X_file, y_file):
    model = NeuralNetworkModel()
    criterion = torch.nn.BCELoss()
    optimizer = torch.optim.ASGD(model.parameters(), lr=0.05)
    for epoch in range(EPOCHS):
        print(epoch)
        loss_score = 0
        acc_score = 0
        items_total = 0
        # Mini-batch pass over the data in fixed order.
        for i in range(0, y_file.shape[0], BATCH_SIZE):
            x = X_file[i:i+BATCH_SIZE]
            x = torch.tensor(np.array(x).astype(np.float32))
            y = y_file[i:i+BATCH_SIZE]
            y = torch.tensor(y.astype(np.float32)).reshape(-1, 1)
            y_pred = model(x)
            acc_score += torch.sum((y_pred > 0.5) == y).item()
            items_total += y.shape[0]
            optimizer.zero_grad()
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()
            loss_score += loss.item() * y.shape[0]
        # Report mean loss and accuracy for the epoch.
        print((loss_score / items_total), (acc_score / items_total))
    return model
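
# train_model visits the batches in the same fixed order every epoch. Shuffling
# once per epoch is a common tweak (not used for the results logged below); a
# minimal sketch, assuming the features and labels are index-aligned:
def shuffle_data(X, y):
    # Reorder the features (a list of vectors) and labels by one random permutation.
    perm = np.random.permutation(len(y))
    return [X[i] for i in perm], y[perm]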
def predict(model, x_file):
    y_dev = []
    with torch.no_grad():
        for i in range(0, len(x_file), BATCH_SIZE):
            x = x_file[i:i+BATCH_SIZE]
            x = torch.tensor(np.array(x).astype(np.float32))
            outputs = model(x)
            # Threshold the sigmoid outputs at 0.5 to get binary predictions.
            y = (outputs > 0.5)
            y_dev.extend(y)
    return y_dev
def writeToFile(fileName, y_file):
    y_out = []
    for y in y_file:
        # Each prediction is a one-element boolean tensor; convert it to 0/1.
        y_out.append(int(y[0].item()))
    with open(f'{fileName}/out.tsv', 'w', encoding='utf8') as f:
        for y in y_out:
            f.write(f'{y}\n')
model = train_model(x_train_w2v, y_file)
# Training log (epoch, mean BCE loss, accuracy), abridged:
#  0   0.6415   0.6464
#  1   0.6119   0.6590
#  ...
# 98   0.3436   0.8519
# 99   0.3396   0.8528
y_dev = predict(model, x_train_w2v)
writeToFile("dev-0", y_dev)
# Repeat the pipeline for test-A (no labels available).
with open('test-A/in.tsv', 'r', encoding='utf8') as f:
    X = np.array([x.strip().lower() for x in f.readlines()])
x_test_w2v = [np.mean([word2vec[word] for word in doc.split() if word in word2vec]
                      or [np.zeros(FEATURES)], axis=0) for doc in X]
y_test = predict(model, x_test_w2v)
writeToFile("test-A", y_test)