99 lines
2.7 KiB
Python
99 lines
2.7 KiB
Python
import torch
|
|
import torch.nn.functional as F
|
|
from torch import nn
|
|
import torch.optim as optim
|
|
from string import punctuation
|
|
from unidecode import unidecode
|
|
import pickle
|
|
|
|
PREDICT_DIR = "test-A"
TRAINED_MODEL_FILE = 'model.bin'
INDEX_FILE = "index.pkl"

# Training corpus: one document per line of in.tsv, with its integer label on
# the matching line of expected.tsv.  Read both eagerly and close the handles
# immediately instead of leaking them for the lifetime of the script.
with open('train/in.tsv', 'r', encoding='utf-8') as train_in:
    train_lines = train_in.readlines()
with open('train/expected.tsv', 'r', encoding='utf-8') as train_expected:
    result_lines = train_expected.readlines()

# Prediction input/output stay open until the end of the script, where they
# are closed.  Decode with the same encoding as the training data so that
# non-ASCII text is read identically regardless of the platform default.
input_file = open(PREDICT_DIR + '/in.tsv', 'r', encoding='utf-8')
out_file = open(PREDICT_DIR + '/out.tsv', 'w', encoding='utf-8')

# Maps the integer labels read from expected.tsv to the model's class indices.
LABEL_TO_INDEX = {0: 0, 1: 1}
class Classifier(nn.Module):
    """Bag-of-words classifier made of a single affine layer.

    Maps a vector of word counts over the vocabulary to one score per label
    and returns those scores as log-probabilities, ready for ``nn.NLLLoss``.
    """

    def __init__(self, num_labels, vocab_size):
        super().__init__()
        # The attribute name determines the saved state_dict keys
        # ('linear.weight', 'linear.bias') that the reload code reads back.
        self.linear = nn.Linear(in_features=vocab_size, out_features=num_labels)

    def forward(self, bow_vec):
        """Return log-probabilities over the labels.

        Assumes ``bow_vec`` is (batch, vocab_size): the softmax is taken over
        dim 1, which is then the label axis.
        """
        scores = self.linear(bow_vec)
        return F.log_softmax(scores, dim=1)
# Normalise each training document into a list of lower-cased, ASCII-
# transliterated, purely alphabetic tokens, paired with its integer label.
#
# str.replace returns a new string (strings are immutable), so the original
# per-character replace loop had no effect and words glued to punctuation
# ("good," / "end.") were silently dropped by the isalpha filter.  A single
# translation table maps every punctuation character to a space in one pass.
_PUNCT_TO_SPACE = str.maketrans(punctuation, " " * len(punctuation))

data = []
for i, raw_text in enumerate(train_lines):
    cleaned = unidecode(raw_text.lower()).translate(_PUNCT_TO_SPACE)
    only_alpha = [word for word in cleaned.split() if word.isalpha()]
    # Labels are read positionally from expected.tsv; a shorter label file
    # still raises IndexError here rather than silently truncating the data.
    data.append((only_alpha, int(result_lines[i])))
# Number every distinct training token by order of first appearance (dicts
# preserve insertion order); the number is the token's bag-of-words column.
_unique_words = dict.fromkeys(
    word for sentence, _ in data for word in sentence
)
word_indexes = {word: position for position, word in enumerate(_unique_words)}

VOCAB_SIZE = len(word_indexes)
NUM_LABELS = 2
# Train the bag-of-words classifier with plain SGD, one document per update.
NUM_EPOCHS = 3

# Use the first GPU when available, otherwise fall back to the CPU instead of
# crashing on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = Classifier(NUM_LABELS, VOCAB_SIZE)
model.to(device)
loss_function = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

data_len = len(data)
for epoch in range(NUM_EPOCHS):
    # i counts documents processed so far in this epoch (1-based).
    for i, (instance, label) in enumerate(data, start=1):
        model.zero_grad()

        # Word-count vector shaped (1, VOCAB_SIZE): a batch of one, as the
        # model's log_softmax over dim 1 expects.  Filled on the CPU (cheap
        # scalar writes) and moved to the model's device in one transfer.
        vec = torch.zeros(1, VOCAB_SIZE)
        for word in instance:
            if word in word_indexes:
                vec[0, word_indexes[word]] += 1
        vec = vec.to(device)
        target = torch.tensor(
            [LABEL_TO_INDEX[label]], dtype=torch.long, device=device
        )

        # Actually run the forward pass (the module itself was previously
        # handed to the loss function, which cannot work).
        log_probs = model(vec)
        loss = loss_function(log_probs, target)
        loss.backward()
        optimizer.step()

        if i % 10000 == 0:
            print(str(epoch) + "/" + str(NUM_EPOCHS) + " " + str(i / data_len))
# Persist the trained weights and the vocabulary; both are reloaded below to
# rebuild the classifier and vectorise the prediction input.
torch.save(model.state_dict(), TRAINED_MODEL_FILE)

# The context manager closes the index file even if pickling fails.
with open(INDEX_FILE, "wb") as f:
    pickle.dump(word_indexes, f)
# Rebuild the classifier from the saved checkpoint.  map_location='cpu' lets
# weights saved from a GPU run load on a CPU-only host; the checkpoint is
# read once and reused instead of being deserialised a second time.
state_dict = torch.load(TRAINED_MODEL_FILE, map_location='cpu')
linear_weight = state_dict['linear.weight']  # (num_labels, vocab_size)
linear_bias = state_dict['linear.bias']      # (num_labels,)

# Layer sizes are recovered from the stored tensors themselves.
model = Classifier(len(linear_bias), len(linear_weight[0]))
model.load_state_dict(state_dict)
model.eval()

# Vocabulary (token -> bag-of-words column) saved alongside the weights.
# NOTE: pickle can execute arbitrary code; only load index files this
# script produced itself.
with open(INDEX_FILE, 'rb') as index_file:
    word_indexes = pickle.load(index_file)
# Same punctuation-to-space table as the training preprocessing.
_PUNCT_TO_SPACE = str.maketrans(punctuation, " " * len(punctuation))


def _line_to_bow(text, index, vocab_size):
    """Vectorise one raw document the same way the training data is cleaned.

    Lower-cases, transliterates to ASCII, turns punctuation into spaces, and
    counts every purely alphabetic token that appears in ``index``.  Words
    missing from the training vocabulary are ignored.

    Returns a float tensor of shape (1, vocab_size).
    """
    cleaned = unidecode(text.lower()).translate(_PUNCT_TO_SPACE)
    bow = torch.zeros(1, vocab_size)
    for token in cleaned.split():
        if token.isalpha() and token in index:
            bow[0, index[token]] += 1
    return bow


# Score every line of the prediction input and write one label per line.
# Vectors must live on the same device as the model's parameters.
predict_device = next(model.parameters()).device
try:
    with torch.no_grad():
        for line in input_file:
            bow = _line_to_bow(line, word_indexes, model.linear.in_features)
            log_probs = model(bow.to(predict_device))
            # Predict 0 only when it is strictly more likely; ties go to 1.
            if log_probs[0][0] > log_probs[0][1]:
                out_file.write("0\n")
            else:
                out_file.write("1\n")
finally:
    # Close (and flush) both files even if scoring fails part-way.
    input_file.close()
    out_file.close()