# naive_bayes/Logistic.py
from tqdm import tqdm
import re
import torch
import matplotlib.pyplot as plt
import gensim.downloader as api
from gensim.models.word2vec import Word2Vec
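# Train a Word2Vec embedding from scratch on the text8 corpus (fetched via gensim's
# dataset downloader); with gensim's default settings this yields 100-dimensional
# vectors, and the training is repeated every time the script starts.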
corpus = api.load('text8')
w2v = Word2Vec(corpus)
TRAINING_MODE = False
DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
obfuscator = re.compile(r'[\[?.,!()\]*&^%$#@{}|\\/~\- \t\n]+')
MAX_SENTENCE_LEN = 128
NUM_CATEGORIES = 2
BATCH_SIZE = 256
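# tokenize() replaces every character matched by `obfuscator` (punctuation, brackets and
# whitespace) with a space, splits on whitespace, and lowercases the tokens, e.g.
# tokenize("Hello, world!") -> ['hello', 'world'].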
def tokenize(txt):
    return [token.lower() for token in obfuscator.sub(' ', txt).split()]
class NetL2(torch.nn.Module):  # This model got much better performance
    def __init__(self):
        super(NetL2, self).__init__()
        self.fc1 = torch.nn.Linear(w2v.wv.vector_size * MAX_SENTENCE_LEN, 512)
        self.fc2 = torch.nn.Linear(512, NUM_CATEGORIES)

    def forward(self, x):
        x = x.reshape(-1, w2v.wv.vector_size * MAX_SENTENCE_LEN)
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        x = torch.log_softmax(x, dim=1)
        return x
class NetL1(torch.nn.Module):  # This model did not learn well enough
    def __init__(self):
        super(NetL1, self).__init__()
        self.fc = torch.nn.Linear(w2v.wv.vector_size * MAX_SENTENCE_LEN, NUM_CATEGORIES)

    def forward(self, x):
        x = x.reshape(-1, w2v.wv.vector_size * MAX_SENTENCE_LEN)
        x = self.fc(x)
        x = torch.log_softmax(x, dim=1)
        return x
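# The two-layer network is the one actually trained and used for inference below;
# NetL1 is kept only for comparison.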
model = NetL2().to(DEVICE)
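# collate() turns a batch of (token list, label) pairs into a
# (batch, vector_size, MAX_SENTENCE_LEN) tensor of word2vec vectors: sentences are
# truncated to MAX_SENTENCE_LEN and out-of-vocabulary words stay as zero columns.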
def collate(batch):  # batch: list of (token list, label) pairs
    inputs = torch.zeros(len(batch), w2v.wv.vector_size, MAX_SENTENCE_LEN)
    outputs = torch.zeros(len(batch), dtype=torch.long)
    for i, (sentence, expected) in enumerate(batch):
        outputs[i] = expected
        for j, word in enumerate(sentence[:MAX_SENTENCE_LEN]):
            if word in w2v.wv:
                vec = w2v.wv[word]
                inputs[i, :, j] = torch.from_numpy(vec)
    return inputs, outputs
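# infer() reads <data_dir>/in.tsv, predicts a label for each comment with the current
# model, and writes one prediction per line to <data_dir>/out.tsv.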
def infer(data_dir):
    with open(data_dir + '/in.tsv') as fd, open(data_dir + '/out.tsv', 'w+') as ex:
        for line in tqdm(fd, desc="inferring " + data_dir):
            comment, _ = line.split('\t')
            comment = tokenize(comment)
            comment, _ = collate([(comment, 0)])
            comment = comment.to(DEVICE)
            with torch.no_grad():  # no gradients needed at inference time
                predicted = model(comment).argmax(dim=1).item()
            ex.write(str(predicted) + '\n')
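# Training data layout (as read below): each line of in.tsv appears to carry a
# tab-separated comment plus one extra field, and expected.tsv holds one integer
# label (0 or 1) per line.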
if TRAINING_MODE:
    DATA = []
    with open('train/in.tsv') as fd, open('train/expected.tsv') as ex:
        k = 0
        for line, result in tqdm(zip(fd, ex), desc="preprocessing", total=289579):
            result = int(result)
            comment, _ = line.split('\t')
            DATA.append((tokenize(comment), result))
            k += 1
            if k == -1:  # change -1 to a positive count to cap the dataset for quick runs
                break
    TEST_DATA = []
    with open('dev-0/in.tsv') as fd, open('dev-0/expected.tsv') as ex:
        k = 0
        for line, result in tqdm(zip(fd, ex), desc="test preprocessing", total=5272):
            result = int(result)
            comment, _ = line.split('\t')
            TEST_DATA.append((tokenize(comment), result))
            k += 1
            if k == -1:  # same cap mechanism as above
                break
    dataloader = torch.utils.data.DataLoader(dataset=DATA, collate_fn=collate, batch_size=BATCH_SIZE,
                                              shuffle=True, drop_last=True)
    test_dataloader = torch.utils.data.DataLoader(dataset=TEST_DATA, collate_fn=collate, batch_size=BATCH_SIZE,
                                                  shuffle=True, drop_last=True)
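    # NLLLoss expects log-probabilities, which both models already produce via log_softmax,
    # so this pairing is equivalent to CrossEntropyLoss on raw logits.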
    criterion = torch.nn.NLLLoss()
    optimizer = torch.optim.Adam(model.parameters())
    bar = tqdm(total=len(DATA), desc="training", position=0)
    test_bar = tqdm(total=len(TEST_DATA), desc="testing", position=0)
    avg_losses = []
    accuracies = []
    test_accuracies = []
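    # Train for up to 1000 epochs; after each epoch the average loss, training accuracy and
    # dev-set accuracy are appended to the running lists, re-plotted live with matplotlib,
    # printed, and the model is checkpointed along with fresh dev-0 predictions.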
    for epoch in range(1000):
        avg_loss = 0
        bar.reset()
        accuracy = 0
        test_accuracy = 0
        total = 0
        for in_batch, out_batch in dataloader:
            in_batch = in_batch.to(DEVICE)
            outputs = model(in_batch)
            out_batch = out_batch.to(DEVICE)
            loss = criterion(outputs, out_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            avg_loss += loss.item() * BATCH_SIZE
            bar.update(BATCH_SIZE)
            accuracy += (outputs.argmax(dim=1) == out_batch).sum().item()
            total += BATCH_SIZE
        avg_losses.append(avg_loss / total)
        accuracies.append(accuracy / total)
        test_bar.reset()
        total = 0
        for in_batch, out_batch in test_dataloader:
            in_batch = in_batch.to(DEVICE)
            outputs = model(in_batch)
            out_batch = out_batch.to(DEVICE)
            test_bar.update(BATCH_SIZE)
            test_accuracy += (outputs.argmax(dim=1) == out_batch).sum().item()
            total += BATCH_SIZE
        test_accuracies.append(test_accuracy / total)
        plt.clf()
        plt.plot(avg_losses, label='avg loss')
        plt.plot(accuracies, label='accuracy')
        plt.plot(test_accuracies, label='test accuracy')
        print("epoch: " + str(epoch))
        print("avg loss: " + str(avg_losses[-1]))
        print("accuracy: " + str(accuracies[-1]))
        print("test accuracy: " + str(test_accuracies[-1]))
        print()
        plt.legend()
        plt.pause(0.0001)
        torch.save(model.state_dict(), 'l2_epoch_' + str(epoch) + ".pth")
        infer('dev-0')
else:
    model.load_state_dict(torch.load('l2_epoch_0.pth'))
    model.eval()
    infer('test-A')