# petite-difference-challenge2/main.py
# 2021-02-23 20:27:19 +01:00
#
# 99 lines
# 2.7 KiB
# Python

import torch
import torch.nn.functional as F
from torch import nn
import torch.optim as optim
from string import punctuation
from unidecode import unidecode
import pickle
# Directory that holds the unlabeled input (in.tsv) and receives the
# predictions (out.tsv).
PREDICT_DIR = "test-A"
# File the trained model's state_dict is saved to and loaded from.
TRAINED_MODEL_FILE = 'model.bin'
# Pickle file holding the word -> feature-index vocabulary.
INDEX_FILE = "index.pkl"

# Read the training corpus eagerly. The `with` blocks close the handles as
# soon as the lines are in memory instead of leaking them for the whole run.
with open('train/in.tsv', 'r', encoding='utf-8') as _train_in:
    train_lines = _train_in.readlines()
with open('train/expected.tsv', 'r', encoding='utf-8') as _train_expected:
    result_lines = _train_expected.readlines()

# These two handles intentionally stay open: the prediction code further down
# iterates over `input_file`, writes to `out_file`, and closes both itself.
# The input is decoded as UTF-8 explicitly so it is read the same way as the
# training data regardless of the platform's default encoding.
# NOTE: opening out.tsv in 'w' mode truncates any existing file right here.
input_file = open(PREDICT_DIR + '/in.tsv', 'r', encoding='utf-8')
out_file = open(PREDICT_DIR + '/out.tsv', 'w', encoding='utf-8')

# Maps a gold label (as parsed from expected.tsv) to the class index that is
# handed to the loss function.
LABEL_TO_INDEX = {0: 0, 1: 1}
class Classifier(nn.Module):
    """A single linear layer followed by a log-softmax over dimension 1."""

    def __init__(self, num_labels, vocab_size):
        """Create the ``vocab_size -> num_labels`` linear projection."""
        super().__init__()
        self.linear = nn.Linear(in_features=vocab_size, out_features=num_labels)

    def forward(self, bow_vec):
        """Return log-probabilities for ``bow_vec``.

        The log-softmax is taken over dimension 1, so ``bow_vec`` must carry
        a leading (batch) dimension in front of the feature axis.
        """
        scores = self.linear(bow_vec)
        return F.log_softmax(scores, dim=1)
# Build the training set: one (tokens, label) pair per training line.

# Every input line needs a label; fail early with a clear message instead of
# an IndexError somewhere in the middle of the corpus.
if len(result_lines) < len(train_lines):
    raise ValueError(
        "expected.tsv has fewer lines (%d) than in.tsv (%d)"
        % (len(result_lines), len(train_lines))
    )

# Translation table that turns every ASCII punctuation character into a
# space. str.replace() returns a new string, so calling it without using the
# result leaves the text unchanged; translate() does the whole substitution
# in a single pass and its result is actually kept.
_PUNCT_TO_SPACE = str.maketrans({ch: " " for ch in punctuation})

data = []
for raw_line, raw_label in zip(train_lines, result_lines):
    # Lower-case, transliterate to ASCII, then blank out punctuation so that
    # e.g. "word," contributes the token "word" instead of being discarded.
    normalized = unidecode(raw_line.lower()).translate(_PUNCT_TO_SPACE)
    # Keep purely alphabetic tokens only (drops numbers and mixed tokens).
    tokens = [tok for tok in normalized.split() if tok.isalpha()]
    data.append((tokens, int(raw_label)))
# Vocabulary: each distinct training token receives the next free integer id,
# in order of first appearance.
word_indexes = {}
for tokens, _label in data:
    for token in tokens:
        # setdefault() inserts only when the token is new; len() is evaluated
        # before the insertion, so ids run 0, 1, 2, ... without gaps.
        word_indexes.setdefault(token, len(word_indexes))

# Width of the bag-of-words feature vector.
VOCAB_SIZE = len(word_indexes)
# Number of output classes of the classifier.
NUM_LABELS = 2
# Train the bag-of-words classifier with plain SGD, one example per step,
# then persist the weights and the vocabulary.

# Use the GPU when one is present, otherwise fall back to the CPU so the
# script does not crash on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = Classifier(NUM_LABELS, VOCAB_SIZE)
model.to(device)
loss_function = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

num_epochs = 3
data_len = len(data)
for epoch in range(num_epochs):
    i = 1
    for instance, label in data:
        model.zero_grad()

        # Bag-of-words counts. The vector is filled on the CPU (cheap scalar
        # updates) and moved to the model's device once it is complete.
        vec = torch.zeros(len(word_indexes))
        for word in instance:
            if word in word_indexes:
                vec[word_indexes[word]] += 1

        # The model takes log_softmax over dim=1 and NLLLoss expects an
        # (N, C) input with an (N,) target, so add a batch axis of size 1.
        # Both tensors must live on the same device as the model.
        features = vec.view(1, -1).to(device)
        target = torch.tensor(
            [LABEL_TO_INDEX[label]], dtype=torch.long, device=device
        )

        # Actually run the forward pass; passing the module object itself to
        # the loss function would raise instead of computing a loss.
        log_probs = model(features)
        loss = loss_function(log_probs, target)
        loss.backward()
        optimizer.step()

        i = i + 1
        if i % 10000 == 0:
            # Progress: "<epoch>/<total epochs> <fraction of the data seen>".
            print(f"{epoch}/{num_epochs} {i / data_len}")

# Persist the learned weights and the vocabulary needed to rebuild features.
torch.save(model.state_dict(), TRAINED_MODEL_FILE)
with open(INDEX_FILE, "wb") as f:
    pickle.dump(word_indexes, f)
# Reload the persisted model and vocabulary, then label every line of the
# prediction input, writing one "0" or "1" per line to the output file.

# map_location='cpu' lets a checkpoint that was saved from a GPU be loaded on
# a machine without CUDA; inference below runs on the CPU.
state_dict = torch.load(TRAINED_MODEL_FILE, map_location='cpu')
linear_weight = state_dict['linear.weight']
linear_bias = state_dict['linear.bias']
# Recover the layer dimensions from the stored tensors:
# weight is (num_labels, vocab_size), bias is (num_labels,).
vocab_size = len(linear_weight[0])
model = Classifier(len(linear_bias), vocab_size)
model.load_state_dict(state_dict)
model.eval()

# NOTE: pickle must only be used on trusted files; this one is the
# vocabulary written by the training step of this same script.
with open(INDEX_FILE, 'rb') as index_fh:
    word_indexes = pickle.load(index_fh)

# Every ASCII punctuation character becomes a space, so words next to
# punctuation are still looked up as plain words.
punct_to_space = str.maketrans({ch: " " for ch in punctuation})

try:
    with torch.no_grad():
        for line in input_file:
            # Same normalisation as for the training text: lower-case,
            # transliterate to ASCII, blank out punctuation.
            text = unidecode(line.lower()).translate(punct_to_space)

            # Bag-of-words counts with a leading batch axis of size 1, as
            # required by the model's log_softmax over dim=1.
            vec = torch.zeros(1, vocab_size)
            for word in text.split():
                # Out-of-vocabulary tokens are skipped. The vocabulary only
                # contains alphabetic tokens, so no extra filter is needed.
                idx = word_indexes.get(word)
                if idx is not None:
                    vec[0, idx] += 1

            log_probs = model(vec)
            if log_probs[0][0] > log_probs[0][1]:
                out_file.write("0\n")
            else:
                out_file.write("1\n")
finally:
    # Close both handles even if prediction fails part-way through.
    input_file.close()
    out_file.close()