logistic regression
commit b167482d61 (parent d27ecb30e5)

Logistic.py (new file, 167 lines)

@@ -0,0 +1,167 @@
from tqdm import tqdm
import re
import math
from math import log, exp
from sklearn.datasets import fetch_20newsgroups
import gensim
import torch
import gensim.downloader as api
import matplotlib.pyplot as plt
from gensim.models.word2vec import Word2Vec

# Train word2vec embeddings on the text8 corpus (downloaded via gensim).
corpus = api.load('text8')
w2v = Word2Vec(corpus)

TRAINING_MODE = False
DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Matches punctuation and whitespace; used to split comments into tokens.
obfuscator = re.compile('[\\[?.,!()\\]*&^%$#@{}|\\\\/~\\- \t\n]+')

MAX_SENTENCE_LEN = 128
NUM_CATEGORIES = 2
BATCH_SIZE = 256


def tokenize(txt):
    return [token.lower() for token in obfuscator.sub(' ', txt).split()]


class NetL2(torch.nn.Module):  # This model got much better performance.

    def __init__(self):
        super(NetL2, self).__init__()
        # One hidden layer on top of the flattened word2vec sentence matrix.
        self.fc1 = torch.nn.Linear(w2v.wv.vector_size * MAX_SENTENCE_LEN, 512)
        self.fc2 = torch.nn.Linear(512, NUM_CATEGORIES)

    def forward(self, x):
        x = x.reshape(-1, w2v.wv.vector_size * MAX_SENTENCE_LEN)
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        x = torch.log_softmax(x, dim=1)
        return x


class NetL1(torch.nn.Module):  # This model did not learn well enough.

    def __init__(self):
        super(NetL1, self).__init__()
        # Plain logistic regression: a single linear layer plus log-softmax.
        self.fc = torch.nn.Linear(w2v.wv.vector_size * MAX_SENTENCE_LEN, NUM_CATEGORIES)

    def forward(self, x):
        x = x.reshape(-1, w2v.wv.vector_size * MAX_SENTENCE_LEN)
        x = self.fc(x)
        x = torch.log_softmax(x, dim=1)
        return x


model = NetL2().to(DEVICE)


def collate(batch):  # batch: list of (token list, label) pairs
    # Build a (batch, vector_size, MAX_SENTENCE_LEN) tensor; sentences longer than
    # the limit are truncated, out-of-vocabulary words stay as zero vectors.
    inputs = torch.zeros(len(batch), w2v.wv.vector_size, MAX_SENTENCE_LEN)
    outputs = torch.zeros(len(batch), dtype=torch.long)
    for i, (sentence, expected) in enumerate(batch):
        outputs[i] = expected
        for j, word in enumerate(sentence[:MAX_SENTENCE_LEN]):
            if word in w2v.wv:
                vec = w2v.wv[word]
                inputs[i, :, j] = torch.from_numpy(vec)
    return inputs, outputs


def infer(data_dir):
    # Read data_dir/in.tsv and write one predicted label per line to data_dir/out.tsv.
    with open(data_dir + '/in.tsv') as fd, open(data_dir + '/out.tsv', 'w+') as ex:
        for line in tqdm(fd, desc="inferring " + data_dir):
            comment, _ = line.split('\t')
            comment = tokenize(comment)
            comment, _ = collate([(comment, 0)])
            comment = comment.to(DEVICE)
            predicted = model(comment).argmax(dim=1).item()
            ex.write(str(predicted) + '\n')


if TRAINING_MODE:

    DATA = []

    with open('train/in.tsv') as fd, open('train/expected.tsv') as ex:
        k = 0
        for line, result in tqdm(zip(fd, ex), desc="preprocessing", total=289579):
            result = int(result)
            comment, _ = line.split('\t')
            DATA.append((tokenize(comment), result))
            k += 1
            if k == -1:  # debug limiter; -1 means "never stop early"
                break

    TEST_DATA = []

    with open('dev-0/in.tsv') as fd, open('dev-0/expected.tsv') as ex:
        k = 0
        for line, result in tqdm(zip(fd, ex), desc="test preprocessing", total=5272):
            result = int(result)
            comment, _ = line.split('\t')
            TEST_DATA.append((tokenize(comment), result))
            k += 1
            if k == -1:  # debug limiter; -1 means "never stop early"
                break

    dataloader = torch.utils.data.DataLoader(dataset=DATA, collate_fn=collate,
                                             batch_size=BATCH_SIZE, shuffle=True,
                                             drop_last=True)
    test_dataloader = torch.utils.data.DataLoader(dataset=TEST_DATA, collate_fn=collate,
                                                  batch_size=BATCH_SIZE, shuffle=True,
                                                  drop_last=True)

    criterion = torch.nn.NLLLoss()
    optimizer = torch.optim.Adam(model.parameters())
    bar = tqdm(total=len(DATA), desc="training", position=0)
    test_bar = tqdm(total=len(TEST_DATA), desc="testing", position=0)
    avg_losses = []
    accuracies = []
    test_accuracies = []

    for epoch in range(1000):
        avg_loss = 0
        bar.reset()
        accuracy = 0
        test_accuracy = 0
        total = 0

        # Training pass.
        for in_batch, out_batch in dataloader:
            in_batch = in_batch.to(DEVICE)
            outputs = model(in_batch)
            out_batch = out_batch.to(DEVICE)
            loss = criterion(outputs, out_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            avg_loss += loss.item() * BATCH_SIZE
            bar.update(BATCH_SIZE)
            accuracy += (outputs.argmax(dim=1) == out_batch).sum().item()
            total += BATCH_SIZE
        avg_losses.append(avg_loss / total)
        accuracies.append(accuracy / total)

        # Evaluation pass on dev-0.
        test_bar.reset()
        total = 0
        for in_batch, out_batch in test_dataloader:
            in_batch = in_batch.to(DEVICE)
            outputs = model(in_batch)
            out_batch = out_batch.to(DEVICE)
            test_bar.update(BATCH_SIZE)
            test_accuracy += (outputs.argmax(dim=1) == out_batch).sum().item()
            total += BATCH_SIZE
        test_accuracies.append(test_accuracy / total)

        # Live plot and per-epoch summary.
        plt.clf()
        plt.plot(avg_losses, label='avg loss')
        plt.plot(accuracies, label='accuracy')
        plt.plot(test_accuracies, label='test accuracy')
        print("epoch: " + str(epoch))
        print("avg loss: " + str(avg_losses[-1]))
        print("accuracy: " + str(accuracies[-1]))
        print("test accuracy: " + str(test_accuracies[-1]))
        print()
        plt.legend()
        plt.pause(0.0001)

        # Checkpoint after every epoch and annotate the dev set.
        torch.save(model.state_dict(), 'l2_epoch_' + str(epoch) + '.pth')
        infer('dev-0')
else:
    # Load the saved checkpoint (epoch 0, before over-fitting) and annotate the test set.
    model.load_state_dict(torch.load('l2_epoch_0.pth', map_location=DEVICE))
    model.eval()
    infer('test-A')

README.md (16 lines changed)

@@ -7,6 +7,22 @@ Classify a reddit as either from Skeptic subreddit or one of the
Output label is the probability of a paranormal subreddit.

# PyTorch logistic regression

The code can be found in Logistic.py.
Trained models end with the .pth extension.

Geval results:

```
$ ./geval -t dev-0
Likelihood  0.0000
Accuracy    0.7043
F1.0        0.4950
Precision   0.6257
Recall      0.4094
```
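
As a rough sanity check (not a replacement for geval), plain accuracy on dev-0 can be recomputed directly from the files in this repo. The sketch below assumes dev-0/expected.tsv holds one 0/1 label per line, matching the format of dev-0/out.tsv:

```
# Rough sanity check: recompute plain accuracy on dev-0 by comparing
# predicted labels with the expected ones line by line.
with open('dev-0/out.tsv') as pred, open('dev-0/expected.tsv') as gold:
    pairs = [(p.strip(), g.strip()) for p, g in zip(pred, gold)]

hits = sum(p == g for p, g in pairs)
print(f"accuracy: {hits / len(pairs):.4f}")  # should land near the 0.7043 reported above
```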

Sources
-------

dev-0/out.tsv (1376 lines; diff suppressed, file too large)

l1_epoch_6.pth (new binary file, not shown)

l1_epochs.txt (new file, 44 lines)

@@ -0,0 +1,44 @@
testing: 97%|█████████▋| 5120/5272 [00:03<00:00, 1562.58it/s]epoch: 0
avg loss: 0.6662371538425887
accuracy: 0.6769724323932329
test accuracy: 0.6496585735963581

testing: 97%|█████████▋| 5120/5272 [00:04<00:00, 1485.70it/s]epoch: 1
avg loss: 0.6602801650684239
accuracy: 0.6871561819054558
test accuracy: 0.6574355083459787

testing: 97%|█████████▋| 5120/5272 [00:03<00:00, 1509.15it/s]epoch: 2
avg loss: 0.6611704620506444
accuracy: 0.6899844256662258
test accuracy: 0.6566767830045523

testing: 97%|█████████▋| 5120/5272 [00:03<00:00, 1471.36it/s]epoch: 3
avg loss: 0.6616791084902397
accuracy: 0.6911412775097642
test accuracy: 0.6638846737481032

testing: 97%|█████████▋| 5120/5272 [00:03<00:00, 1579.47it/s]epoch: 4
avg loss: 0.6610813111163456
accuracy: 0.6913588347221311
test accuracy: 0.6403641881638846

testing: 97%|█████████▋| 5120/5272 [00:03<00:00, 1561.11it/s]epoch: 5
avg loss: 0.6612948830510013
accuracy: 0.6919113609757613
test accuracy: 0.6553490136570561

testing: 97%|█████████▋| 5120/5272 [00:03<00:00, 1496.72it/s]epoch: 6
avg loss: 0.662789237758215
accuracy: 0.6916558175834574
test accuracy: 0.6688163884673748 <--- this is the best we managed to get

testing: 97%|█████████▋| 5120/5272 [00:03<00:00, 1531.05it/s]epoch: 7
avg loss: 0.6635299078017594
accuracy: 0.6916730840288833
test accuracy: 0.6525037936267072

testing: 97%|█████████▋| 5120/5272 [00:03<00:00, 1539.94it/s]epoch: 8
avg loss: 0.6665739091241887
accuracy: 0.6917179767869908
test accuracy: 0.6676783004552352

l2_epoch_0.pth (new binary file, not shown)

l2_epochs.txt (new file, 9 lines)

@@ -0,0 +1,9 @@
testing: 97%|█████████▋| 5120/5272 [00:03<00:00, 1496.45it/s]epoch: 0
avg loss: 0.5669572093742932
accuracy: 0.705209252052117
test accuracy: 0.7043400606980273

testing: 97%|█████████▋| 5120/5272 [00:03<00:00, 1500.00it/s]epoch: 1
avg loss: 0.5200198413747087
accuracy: 0.7267515945562351
test accuracy: 0.6817147192716236 <-- starts over-fitting

test-A/out.tsv (2338 lines; diff suppressed, file too large)