challenging-america-word-ga.../run.py
2022-04-29 09:34:44 +02:00

144 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#%%
# imports
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader
import torch
import pandas as pd
import regex as re
import csv
import itertools
from os.path import exists
# Hyper-parameters shared by the vocabulary and the model.
vocab_size = 30_000  # max_tokens for the vocabulary; also embedding rows / output classes
embed_size = 150     # width of the word-embedding vectors
#%%
# helper functions
# Characters deleted by clean(): < > . \ " ” (U+201D, right double quotation
# mark) - , * : /
# Raw string + character class: in a plain string literal, "\\" and "\"" are
# collapsed by Python before the regex engine sees them.
_STRIP_CHARS_RE = re.compile(r'[<>.\\"\u201d\-,*:/]')


def clean(text):
    """Normalise one raw text field before tokenisation.

    Steps, in order:
      1. convert to ``str``, strip surrounding whitespace, lower-case;
      2. turn every literal backslash-``n`` sequence into a space;
      3. delete the characters matched by ``_STRIP_CHARS_RE``;
      4. expand the contractions ``'t``, ``'s``, ``'ll``, ``'m``, ``'ve``;
      5. drop any remaining apostrophes.

    Args:
        text: any object. It goes through ``str()`` first, so a missing
            pandas value (NaN) becomes ``"nan"``.

    Returns:
        The cleaned string.
    """
    text = str(text).strip().lower()
    # Literal "\n" sequences (presumably escaped line breaks in the TSV data)
    # must become spaces *before* backslashes are stripped; otherwise the
    # leftover "n" would be glued onto the neighbouring words.
    text = text.replace('\\n', " ")
    text = _STRIP_CHARS_RE.sub("", text)
    text = (
        text.replace("'t", " not")
        .replace("'s", " is")
        .replace("'ll", " will")
        .replace("'m", " am")
        .replace("'ve", " have")
    )
    text = text.replace("'", "")
    return text
def get_words_from_line(line):
    """Yield the tokens of a single line wrapped in sentence markers.

    Trailing whitespace is removed first. The generator then yields
    ``'<s>'``, every maximal run of letters/ASCII digits/asterisks or of
    Unicode punctuation (each lower-cased), and finally ``'</s>'``.
    Other characters (whitespace, symbols) are skipped.

    The Unicode property classes in the pattern need the third-party
    ``regex`` module, which this file imports under the name ``re``.
    """
    stripped = line.rstrip()
    yield '<s>'
    token_pattern = r'[\p{L}0-9\*]+|\p{P}+'
    yield from (match.group(0).lower() for match in re.finditer(token_pattern, stripped))
    yield '</s>'
def get_word_lines_from_data(d):
    """Lazily tokenise every element of ``d``.

    Yields one token generator (from ``get_words_from_line``) per element
    of ``d``; nothing is read from ``d`` until the result is iterated.
    """
    yield from map(get_words_from_line, d)
#%%
class Model(torch.nn.Module):
    """Single-token next-word model: Embedding -> Linear -> Softmax.

    Maps integer token indices to a probability distribution over the
    vocabulary. In this script it is trained on (previous token, next
    token) index pairs, i.e. as a bigram model.

    Args:
        vocabulary_size: number of token indices (embedding rows and
            output classes).
        embedding_size: width of the embedding vectors.
    """

    def __init__(self, vocabulary_size, embedding_size):
        super().__init__()
        # The attribute name and layer order are kept so that state_dict keys
        # ("model.0.weight", "model.1.weight", "model.1.bias") stay compatible
        # with previously saved checkpoints.
        self.model = torch.nn.Sequential(
            torch.nn.Embedding(vocabulary_size, embedding_size),
            torch.nn.Linear(embedding_size, vocabulary_size),
            # Normalise explicitly over the vocabulary (last) axis. A bare
            # Softmax() infers the axis from the input rank and, for 3-D
            # input such as (batch, seq, vocab), would normalise over axis 0.
            torch.nn.Softmax(dim=-1),
        )

    def forward(self, x):
        """Return probabilities of shape ``x.shape + (vocabulary_size,)``."""
        return self.model(x)
#%%
class Trigrams(torch.utils.data.IterableDataset):
    """Iterable dataset of consecutive token-index pairs.

    Despite the name, each item is a 2-tuple ``(previous_index,
    next_index)``, i.e. a bigram. Pairs are taken over the whole
    concatenated token stream of ``data``, so they also span line
    boundaries (the ``'</s>'`` of one line is paired with the ``'<s>'``
    of the next).

    Args:
        data: iterable of text lines (strings); iterated once to build the
            vocabulary and again on every ``__iter__`` call.
        vocabulary_size: passed as ``max_tokens`` to
            ``build_vocab_from_iterator``. ``'<unk>'`` is the only special
            token and is the fallback index for unknown words.
    """

    def __init__(self, data, vocabulary_size):
        token_lines = get_word_lines_from_data(data)
        self.vocab = build_vocab_from_iterator(
            token_lines,
            max_tokens=vocabulary_size,
            specials=['<unk>'],
        )
        unk_index = self.vocab['<unk>']
        self.vocab.set_default_index(unk_index)
        self.vocabulary_size = vocabulary_size
        self.data = data

    @staticmethod
    def look_ahead_iterator(gen):
        """Yield ``(prev, cur)`` for every pair of adjacent items in ``gen``.

        A pair is skipped whenever its ``prev`` item is ``None``.
        """
        prev = None
        for cur in gen:
            if prev is None:
                prev = cur
                continue
            yield (prev, cur)
            prev = cur

    def __iter__(self):
        """Return an iterator of (previous, next) vocabulary-index pairs."""
        all_tokens = itertools.chain.from_iterable(get_word_lines_from_data(self.data))
        token_indices = (self.vocab[token] for token in all_tokens)
        return self.look_ahead_iterator(token_indices)
#%%
# Load the training data: columns 6 and 7 of in.tsv.xz plus column 0 of
# expected.tsv, concatenated below in the order 6, 0, 7.
# NOTE(review): on_bad_lines="skip" is applied to each file independently, so a
# malformed line skipped in in.tsv.xz but not in expected.tsv shifts every later
# row out of alignment. Confirm the training files contain no such lines.
train_in = pd.read_csv(
    "train/in.tsv.xz", sep='\t', header=None, encoding="UTF-8",
    on_bad_lines="skip", quoting=csv.QUOTE_NONE, nrows=300000,
)[[6, 7]]
train_expected = pd.read_csv(
    "train/expected.tsv", sep='\t', header=None, encoding="UTF-8",
    on_bad_lines="skip", quoting=csv.QUOTE_NONE, nrows=300000,
)
# Missing fields become "" so that one empty piece does not turn the whole
# concatenated line into NaN (which clean() would render as "nan").
train_data = pd.concat([train_in, train_expected], axis=1).fillna("")
# Join with spaces: without a separator the gap word fuses with the last left
# word and the first right word ("on" + "the" + "mat" -> "onthemat").
train_data = train_data[6] + " " + train_data[0] + " " + train_data[7]
train_data = train_data.apply(clean)
train_dataset = Trigrams(train_data, vocab_size)
#%%
# Train the model, or load previously saved weights from model1.bin.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Model(vocab_size, embed_size).to(device)
# NOTE(review): the vocabulary is rebuilt from the training data on every run,
# so a cached model1.bin is only valid while the data and clean() are unchanged.
if not exists('model1.bin'):
    data = DataLoader(train_dataset, batch_size=200)
    optimizer = torch.optim.Adam(model.parameters())
    criterion = torch.nn.NLLLoss()
    model.train()
    step = 0
    for x, y in data:
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        ypredicted = model(x)
        # Model returns probabilities but NLLLoss expects log-probabilities.
        # Clamping at the smallest normal float keeps log() finite if a
        # probability underflows to 0 (which would give an inf loss and NaN
        # gradients); it has no effect on any representable normal value.
        tiny = torch.finfo(ypredicted.dtype).tiny
        loss = criterion(torch.log(ypredicted.clamp_min(tiny)), y)
        if step % 100 == 0:
            print(step, loss.item())
        step += 1
        loss.backward()
        optimizer.step()
    torch.save(model.state_dict(), 'model1.bin')
else:
    # map_location lets a checkpoint written on a GPU load on a CPU-only host.
    model.load_state_dict(torch.load('model1.bin', map_location=device))
# Switch to evaluation mode for the predictions that follow.
model.eval()
#%%
vocab = train_dataset.vocab


def predict(tokens, top_k=10):
    """Return the model's top-k next-word guesses as one gap-prediction line.

    Args:
        tokens: non-empty list of context tokens (strings). The prediction is
            conditioned on the last one. Tokens missing from ``vocab`` map to
            its default index.
        top_k: number of candidate words to emit (default 10).

    Returns:
        ``"w1:p1 w2:p2 ... wk:pk :0.01"``: the most probable words with their
        probabilities in descending order, followed by a fixed ``:0.01`` entry.

    Raises:
        ValueError: if ``tokens`` is empty.
    """
    if not tokens:
        raise ValueError("predict() needs at least one context token")
    indices = torch.tensor(vocab.forward(tokens)).to(device)
    # Inference only: skip autograd graph construction.
    with torch.no_grad():
        probs = model(indices)
    # Condition on the token adjacent to the gap. For the single-token
    # context that predict_file passes this is the same row as probs[0].
    top = torch.topk(probs[-1], top_k)
    words = vocab.lookup_tokens(top.indices.tolist())
    pairs = "".join(f"{word}:{prob} " for word, prob in zip(words, top.values.tolist()))
    return pairs + ':0.01'
from nltk import word_tokenize
def predict_file(result_path, data):
    """Write one gap prediction per element of ``data`` to ``result_path``.

    Each element is cleaned and tokenised with NLTK's ``word_tokenize``, and
    its last token is passed to ``predict``. Elements that produce no tokens
    get a fixed fallback distribution over a few frequent English words.
    Every output line is also printed to stdout.

    Args:
        result_path: output file path; created or truncated, written as UTF-8.
        data: iterable of left-context texts, one output line per element.
    """
    # Fallback for an empty context. Each word appears once, and the
    # probabilities (including the trailing ":0.1" remainder) sum to 1.
    fallback = "a:0.2 the:0.2 to:0.2 of:0.2 and:0.1 :0.1"
    with open(result_path, "w", encoding="UTF-8") as f:
        for row in data:
            # Keep at most the last token (a 0- or 1-element list).
            context = word_tokenize(clean(str(row)))[-1:]
            result = predict(context) if context else fallback
            f.write(result + "\n")
            print(result)
#%%
def _load_left_context(path):
    """Read column 6 of a dev/test ``in.tsv.xz`` file and clean it.

    Missing fields become ``""`` before cleaning, so they reach
    ``predict_file`` as an empty context (fallback prediction) instead of
    the literal text ``"nan"``. Bad lines are not skipped here: the output
    must have exactly one line per input line, so a parse error is raised.
    """
    column = pd.read_csv(path, sep='\t', header=None, quoting=csv.QUOTE_NONE)[6]
    return column.fillna("").apply(clean)


dev_data = _load_left_context("dev-0/in.tsv.xz")
predict_file("dev-0/out.tsv", dev_data)
test_data = _load_left_context("test-A/in.tsv.xz")
predict_file("test-A/out.tsv", test_data)