from tqdm import tqdm
import re
import torch
import gensim.downloader as api
import matplotlib.pyplot as plt
from gensim.models.word2vec import Word2Vec

# Train word2vec embeddings on the text8 corpus.
corpus = api.load('text8')
w2v = Word2Vec(corpus)

TRAINING_MODE = False
DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# Splits text on punctuation and whitespace.
obfuscator = re.compile(r'[\[?.,!()\]*&^%$#@{}|\\/~\- \t\n]+')
MAX_SENTENCE_LEN = 128
NUM_CATEGORIES = 2
BATCH_SIZE = 256


def tokenize(txt):
    return [token.lower() for token in obfuscator.sub(' ', txt).split()]


class NetL2(torch.nn.Module):
    """Two-layer classifier; this model performed much better than NetL1."""

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(w2v.wv.vector_size * MAX_SENTENCE_LEN, 512)
        self.fc2 = torch.nn.Linear(512, NUM_CATEGORIES)

    def forward(self, x):
        x = x.reshape(-1, w2v.wv.vector_size * MAX_SENTENCE_LEN)
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        return torch.log_softmax(x, dim=1)


class NetL1(torch.nn.Module):
    """Single-layer baseline; this model did not learn well enough."""

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(w2v.wv.vector_size * MAX_SENTENCE_LEN, NUM_CATEGORIES)

    def forward(self, x):
        x = x.reshape(-1, w2v.wv.vector_size * MAX_SENTENCE_LEN)
        x = self.fc(x)
        return torch.log_softmax(x, dim=1)


model = NetL2().to(DEVICE)


def collate(batch: list[tuple[list[str], int]]):
    """Embed each tokenized sentence into a (vector_size, MAX_SENTENCE_LEN) matrix.

    Out-of-vocabulary words and padding positions stay zero.
    """
    inputs = torch.zeros(len(batch), w2v.wv.vector_size, MAX_SENTENCE_LEN)
    outputs = torch.zeros(len(batch), dtype=torch.long)
    for i, (sentence, expected) in enumerate(batch):
        outputs[i] = expected
        for j, word in enumerate(sentence[:MAX_SENTENCE_LEN]):
            if word in w2v.wv:
                inputs[i, :, j] = torch.from_numpy(w2v.wv[word])
    return inputs, outputs


def infer(data_dir):
    with open(data_dir + '/in.tsv') as fd, open(data_dir + '/out.tsv', 'w+') as ex:
        for line in tqdm(fd, desc="inferring " + data_dir):
            comment, _ = line.split('\t')
            comment, _ = collate([(tokenize(comment), 0)])
            comment = comment.to(DEVICE)
            with torch.no_grad():
                predicted = model(comment).argmax(dim=1).item()
            ex.write(str(predicted) + '\n')


if TRAINING_MODE:
    DATA = []
    with open('train/in.tsv') as fd, open('train/expected.tsv') as ex:
        for line, result in tqdm(zip(fd, ex), desc="preprocessing", total=289579):
            comment, _ = line.split('\t')
            DATA.append((tokenize(comment), int(result)))

    TEST_DATA = []
    with open('dev-0/in.tsv') as fd, open('dev-0/expected.tsv') as ex:
        for line, result in tqdm(zip(fd, ex), desc="test preprocessing", total=5272):
            comment, _ = line.split('\t')
            TEST_DATA.append((tokenize(comment), int(result)))

    dataloader = torch.utils.data.DataLoader(dataset=DATA, collate_fn=collate,
                                             batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    test_dataloader = torch.utils.data.DataLoader(dataset=TEST_DATA, collate_fn=collate,
                                                  batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    criterion = torch.nn.NLLLoss()
    optimizer = torch.optim.Adam(model.parameters())
    bar = tqdm(total=len(DATA), desc="training", position=0)
    test_bar = tqdm(total=len(TEST_DATA), desc="testing", position=0)
    avg_losses = []
    accuracies = []
    test_accuracies = []
    for epoch in range(1000):
        avg_loss = 0
        bar.reset()
        accuracy = 0
        test_accuracy = 0
        total = 0
        for in_batch, out_batch in dataloader:
            in_batch = in_batch.to(DEVICE)
            outputs = model(in_batch)
            out_batch = out_batch.to(DEVICE)
            loss = criterion(outputs, out_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            avg_loss += loss.item() * BATCH_SIZE
            bar.update(BATCH_SIZE)
            accuracy += (outputs.argmax(dim=1) == out_batch).sum().item()
            total += BATCH_SIZE
        avg_losses.append(avg_loss / total)
        accuracies.append(accuracy / total)

        # Evaluate on the dev set after every epoch.
        test_bar.reset()
        total = 0
        with torch.no_grad():
            for in_batch, out_batch in test_dataloader:
                in_batch = in_batch.to(DEVICE)
                outputs = model(in_batch)
                out_batch = out_batch.to(DEVICE)
                test_bar.update(BATCH_SIZE)
                test_accuracy += (outputs.argmax(dim=1) == out_batch).sum().item()
                total += BATCH_SIZE
        test_accuracies.append(test_accuracy / total)

        # Live plot of the training curves.
        plt.clf()
        plt.plot(avg_losses, label='avg loss')
        plt.plot(accuracies, label='accuracy')
        plt.plot(test_accuracies, label='test accuracy')
        plt.legend()
        plt.pause(0.0001)

        print("epoch: " + str(epoch))
        print("avg loss: " + str(avg_losses[-1]))
        print("accuracy: " + str(accuracies[-1]))
        print("test accuracy: " + str(test_accuracies[-1]))
        print()

        # Checkpoint after every epoch and write dev-set predictions.
        torch.save(model.state_dict(), 'l2_epoch_' + str(epoch) + '.pth')
        infer('dev-0')
else:
    model.load_state_dict(torch.load('l2_epoch_0.pth', map_location=DEVICE))
    model.eval()
    infer('test-A')
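
# --- Optional variant (a hypothetical sketch, not used in the pipeline above) ---
# The flattened input treats every padding position as a separate feature. An
# alternative is to pool the (vector_size, MAX_SENTENCE_LEN) embedding matrix
# down to a single vector_size-dimensional input before the linear layers.
# Note that mean(dim=2) averages over the full padded length (zeros included),
# so it is effectively a sum of word vectors scaled by 1/MAX_SENTENCE_LEN.
# The class name NetMean is illustrative only.
class NetMean(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(w2v.wv.vector_size, 512)
        self.fc2 = torch.nn.Linear(512, NUM_CATEGORIES)

    def forward(self, x):
        # x: (batch, vector_size, MAX_SENTENCE_LEN) -> (batch, vector_size)
        x = x.mean(dim=2)
        x = torch.relu(self.fc1(x))
        return torch.log_softmax(self.fc2(x), dim=1)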