# Source listing: 168 lines, 5.5 KiB, Python
import math
import os
import re
from math import exp, log

import gensim
import gensim.downloader as api
import matplotlib.pyplot as plt
import torch
from gensim.models.word2vec import Word2Vec
from sklearn.datasets import fetch_20newsgroups
from tqdm import tqdm
|
|
|
|
# True: train the classifier, saving a checkpoint every epoch.
# False: load 'l2_epoch_0.pth' and write predictions for test-A.
TRAINING_MODE = False

# The classifier's weights only make sense in the exact embedding space they
# were trained against, but Word2Vec training is not reproducible between runs
# (worker-thread scheduling and Python hash randomisation both perturb it).
# So the embeddings are trained and saved when training (or when no saved
# copy exists yet), and the saved copy is reused for inference.
W2V_PATH = 'w2v_text8.model'

if TRAINING_MODE or not os.path.exists(W2V_PATH):
    corpus = api.load('text8')
    w2v = Word2Vec(corpus)
    w2v.save(W2V_PATH)
else:
    w2v = Word2Vec.load(W2V_PATH)

DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Runs of punctuation, brackets, symbols and whitespace; tokenize() replaces
# each run with a single space before splitting.
obfuscator = re.compile('[\\[?.,!()\\]*&^%$#@{}|\\\\/~\\- \t\n]+')

# Sentences are truncated / zero-padded to this many tokens by collate().
MAX_SENTENCE_LEN = 128

# Number of output classes of the classifiers.
NUM_CATEGORIES = 2

BATCH_SIZE = 256
|
|
|
|
|
|
def tokenize(txt):
    """Split *txt* into lowercase tokens.

    Every run of characters matched by the module-level ``obfuscator``
    pattern is collapsed to a single space, then the text is split on
    whitespace and each piece is lowercased.
    """
    separated = obfuscator.sub(' ', txt)
    return list(map(str.lower, separated.split()))
|
|
|
|
|
|
class NetL2(torch.nn.Module):  # This model got much better performance
    """Two-layer MLP over a flattened, fixed-length sentence of word vectors.

    ``forward`` flattens its input to
    ``(-1, w2v.wv.vector_size * MAX_SENTENCE_LEN)`` (the layout produced by
    ``collate``) and returns log-probabilities over ``NUM_CATEGORIES``
    classes, matching the ``NLLLoss`` used for training.
    """

    def __init__(self, hidden_size=512):
        """Build the layers.

        :param hidden_size: width of the hidden layer; defaults to 512, the
            value the original code hard-coded, so existing checkpoints load.
        """
        super().__init__()
        # Computed once so fc1 and forward() can never disagree on it.
        self.in_features = w2v.wv.vector_size * MAX_SENTENCE_LEN
        self.fc1 = torch.nn.Linear(self.in_features, hidden_size)
        # Output width follows the shared constant instead of a literal 2.
        self.fc2 = torch.nn.Linear(hidden_size, NUM_CATEGORIES)

    def forward(self, x):
        """Return log-softmax class scores, one row per flattened input sample."""
        x = x.reshape(-1, self.in_features)
        x = torch.relu(self.fc1(x))
        return torch.log_softmax(self.fc2(x), dim=1)
|
|
|
|
|
|
class NetL1(torch.nn.Module):  # This model did not learn well enough
    """Single linear layer over a flattened, fixed-length sentence of word vectors.

    ``forward`` flattens its input to
    ``(-1, w2v.wv.vector_size * MAX_SENTENCE_LEN)`` and returns
    log-probabilities over ``NUM_CATEGORIES`` classes.
    """

    def __init__(self):
        super().__init__()
        # Computed once so fc and forward() can never disagree on it.
        self.in_features = w2v.wv.vector_size * MAX_SENTENCE_LEN
        # Output width follows the shared constant instead of a literal 2.
        self.fc = torch.nn.Linear(self.in_features, NUM_CATEGORIES)

    def forward(self, x):
        """Return log-softmax class scores, one row per flattened input sample."""
        x = x.reshape(-1, self.in_features)
        return torch.log_softmax(self.fc(x), dim=1)
|
|
|
|
|
|
# Module-level classifier instance used by infer() and by the training /
# inference script at the bottom of the file. NetL1 is defined but not
# instantiated here.
model = NetL2().to(DEVICE)
|
|
|
|
|
|
def collate(batch: 'list[tuple[list[str], int]]'):
    """Turn ``(tokens, label)`` pairs into a batch of model inputs and targets.

    :param batch: pairs of a token list (as returned by ``tokenize``) and an
        integer class label.
    :returns: ``(inputs, outputs)`` where ``inputs`` is a float tensor of
        shape ``(len(batch), w2v.wv.vector_size, MAX_SENTENCE_LEN)`` whose
        column ``j`` holds the word vector of token ``j``. Tokens past
        ``MAX_SENTENCE_LEN`` are dropped; tokens missing from the vocabulary
        and padding positions stay zero. ``outputs`` is a ``torch.long``
        tensor of the labels.
    """
    wv = w2v.wv  # hoisted: looked up once instead of per token
    inputs = torch.zeros(len(batch), wv.vector_size, MAX_SENTENCE_LEN)
    outputs = torch.zeros(len(batch), dtype=torch.long)
    for i, (sentence, expected) in enumerate(batch):
        outputs[i] = expected
        for j, word in enumerate(sentence[:MAX_SENTENCE_LEN]):
            if word in wv:
                inputs[i, :, j] = torch.from_numpy(wv[word])
    return inputs, outputs
|
|
|
|
|
|
def infer(data_dir):
    """Classify every line of ``<data_dir>/in.tsv`` with the module-level ``model``.

    Writes one predicted class index per line to ``<data_dir>/out.tsv``, in
    input order. Each input line must contain exactly two tab-separated
    fields (otherwise unpacking raises ``ValueError``); only the first field
    is tokenized and classified.
    """
    # no_grad: prediction only, so skip building autograd graphs per line.
    with open(data_dir + '/in.tsv', encoding='utf-8') as fd, \
            open(data_dir + '/out.tsv', 'w', encoding='utf-8') as ex, \
            torch.no_grad():
        for line in tqdm(fd, desc="inferring " + data_dir):
            comment, _ = line.split('\t')
            # The label passed to collate is a dummy; only the inputs are used.
            inputs, _ = collate([(tokenize(comment), 0)])
            predicted = model(inputs.to(DEVICE)).argmax(dim=1).item()
            ex.write(str(predicted) + '\n')
|
|
|
|
|
|
def _load_labelled(data_dir, desc, total=None):
    """Read ``<data_dir>/in.tsv`` with ``<data_dir>/expected.tsv`` as ``(tokens, label)`` pairs.

    Each in.tsv line must hold exactly two tab-separated fields; only the
    first is tokenized. Each expected.tsv line holds an integer label.
    ``total`` only sizes the progress bar.
    """
    data = []
    with open(data_dir + '/in.tsv', encoding='utf-8') as fd, \
            open(data_dir + '/expected.tsv', encoding='utf-8') as ex:
        for line, result in tqdm(zip(fd, ex), desc=desc, total=total):
            comment, _ = line.split('\t')
            data.append((tokenize(comment), int(result)))
    return data


if TRAINING_MODE:
    DATA = _load_labelled('train', "preprocessing", total=289579)
    TEST_DATA = _load_labelled('dev-0', "test preprocessing", total=5272)

    dataloader = torch.utils.data.DataLoader(dataset=DATA, collate_fn=collate, batch_size=BATCH_SIZE,
                                             shuffle=True, drop_last=True)
    # Evaluate on every dev sample in a fixed order: shuffling plus drop_last
    # would score a different random subset each epoch.
    test_dataloader = torch.utils.data.DataLoader(dataset=TEST_DATA, collate_fn=collate,
                                                  batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

    criterion = torch.nn.NLLLoss()
    optimizer = torch.optim.Adam(model.parameters())
    bar = tqdm(total=len(DATA), desc="training", position=0)
    # Separate line so the two bars do not overwrite each other.
    test_bar = tqdm(total=len(TEST_DATA), desc="testing", position=1)
    avg_losses = []
    accuracies = []
    test_accuracies = []

    for epoch in range(1000):
        # --- training pass ---
        model.train()
        bar.reset()
        avg_loss = 0
        accuracy = 0
        total = 0
        for in_batch, out_batch in dataloader:
            in_batch = in_batch.to(DEVICE)
            out_batch = out_batch.to(DEVICE)
            outputs = model(in_batch)
            loss = criterion(outputs, out_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batch_len = len(out_batch)
            avg_loss += loss.item() * batch_len
            accuracy += (outputs.argmax(dim=1) == out_batch).sum().item()
            total += batch_len
            bar.update(batch_len)
        avg_losses.append(avg_loss / total)
        accuracies.append(accuracy / total)

        # --- evaluation pass (no gradients needed) ---
        model.eval()
        test_bar.reset()
        test_accuracy = 0
        total = 0
        with torch.no_grad():
            for in_batch, out_batch in test_dataloader:
                in_batch = in_batch.to(DEVICE)
                out_batch = out_batch.to(DEVICE)
                outputs = model(in_batch)
                # The last batch may be partial, so count its real size.
                batch_len = len(out_batch)
                test_accuracy += (outputs.argmax(dim=1) == out_batch).sum().item()
                total += batch_len
                test_bar.update(batch_len)
        test_accuracies.append(test_accuracy / total)

        plt.clf()
        plt.plot(avg_losses, label='avg loss')
        plt.plot(accuracies, label='accuracy')
        plt.plot(test_accuracies, label='test accuracy')
        print("epoch: " + str(epoch))
        print("avg loss: " + str(avg_losses[-1]))
        print("accuracy: " + str(accuracies[-1]))
        print("test accuracy: " + str(test_accuracies[-1]))
        print()
        plt.legend()
        plt.pause(0.0001)
        torch.save(model.state_dict(), 'l2_epoch_' + str(epoch) + ".pth")
        infer('dev-0')
else:
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load('l2_epoch_0.pth', map_location=DEVICE))
    model.eval()
    infer('test-A')
|