nn - bigrams

adnovac 2022-04-29 09:34:44 +02:00
parent 2a93b66184
commit 27944ca2c8
4 changed files with 18057 additions and 17978 deletions

.gitignore vendored (1 line changed)

@@ -10,3 +10,4 @@ geval
*in.tsv
train_file.txt
model.arpa
model*.bin

File diff suppressed because it is too large

run.py (166 lines changed)

@@ -1,66 +1,144 @@
#%%
# imports
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader
import torch
import pandas as pd
import regex as re
import csv
import os
import itertools
from os.path import exists

vocab_size = 30000
embed_size = 150
#%%
# helper functions
def clean(text):
    # normalize curly apostrophes so the contraction rules below apply
    text = str(text).strip().lower().replace("’", "'")
    # strip markup remnants and punctuation
    text = re.sub(r'[><.\\"”,*:/-]', '', text)
    # expand common contractions, then drop the remaining apostrophes
    text = text.replace('\\n', ' ').replace("'t", ' not').replace("'s", ' is').replace("'ll", ' will').replace("'m", ' am').replace("'ve", ' have')
    text = text.replace("'", '')
    return text

def get_words_from_line(line):
    line = line.rstrip()
    yield '<s>'
    for m in re.finditer(r'[\p{L}0-9\*]+|\p{P}+', line):
        yield m.group(0).lower()
    yield '</s>'

def get_word_lines_from_data(d):
    for line in d:
        yield get_words_from_line(line)
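
# Quick check of the tokenizer (my own example, not part of the commit):
# each line is wrapped in sentence markers and tokens are lowercased.
print(list(get_words_from_line("Hello, world")))
# -> ['<s>', 'hello', ',', 'world', '</s>']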
#%%
class Model(torch.nn.Module):
    def __init__(self, vocabulary_size, embedding_size):
        super(Model, self).__init__()
        self.model = torch.nn.Sequential(
            torch.nn.Embedding(vocabulary_size, embedding_size),
            torch.nn.Linear(embedding_size, vocabulary_size),
            torch.nn.Softmax(dim=1)  # dim=1: normalize over the vocabulary axis
        )

    def forward(self, x):
        return self.model(x)
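
# Review note: the training loop below applies torch.log to these probabilities
# before NLLLoss. An equivalent but numerically safer layout (a suggestion, not
# what this commit does) ends the stack with LogSoftmax:
#
#     torch.nn.Sequential(
#         torch.nn.Embedding(vocabulary_size, embedding_size),
#         torch.nn.Linear(embedding_size, vocabulary_size),
#         torch.nn.LogSoftmax(dim=1),
#     )
#
# with that variant, criterion(ypredicted, y) can be called without torch.log.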
#%%
class Trigrams(torch.utils.data.IterableDataset):
    def __init__(self, data, vocabulary_size):
        self.vocab = build_vocab_from_iterator(
            get_word_lines_from_data(data),
            max_tokens = vocabulary_size,
            specials = ['<unk>'])
        self.vocab.set_default_index(self.vocab['<unk>'])
        self.vocabulary_size = vocabulary_size
        self.data = data

    @staticmethod
    def look_ahead_iterator(gen):
        # slide a one-token window over the stream, yielding consecutive pairs
        w1 = None
        for item in gen:
            if w1 is not None:
                yield (w1, item)
            w1 = item

    def __iter__(self):
        return self.look_ahead_iterator(
            (self.vocab[t] for t in itertools.chain.from_iterable(get_word_lines_from_data(self.data))))
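
# Despite the class name, look_ahead_iterator yields bigram pairs
# (matching the commit title). A small illustration, not part of the commit:
print(list(Trigrams.look_ahead_iterator(iter([5, 7, 9]))))
# -> [(5, 7), (7, 9)]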
#%%
# load the training data
train_in = pd.read_csv("train/in.tsv.xz", sep='\t', header=None, encoding="UTF-8", on_bad_lines="skip", quoting=csv.QUOTE_NONE, nrows=300000)[[6, 7]]
train_expected = pd.read_csv("train/expected.tsv", sep='\t', header=None, encoding="UTF-8", on_bad_lines="skip", quoting=csv.QUOTE_NONE, nrows=300000)

train_data = pd.concat([train_in, train_expected], axis=1)
# column 6 is the left context, column 0 the expected word, column 7 the right context
train_data = train_data[6] + train_data[0] + train_data[7]
train_data = train_data.apply(clean)
train_dataset = Trigrams(train_data, vocab_size)
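
# Minimal sanity check of the resulting stream (my own snippet; the exact
# ids depend on the vocabulary built above, values shown are illustrative):
print(next(iter(train_dataset)))
# e.g. (2, 1371): the ids of the first two tokens in the corpus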
#%%
# previous kenlm baseline, left commented out:
#get_ipython().system('../kenlm/build/bin/lmplz -o 4 < train_file.txt > model.arpa --skip_symbols')

# train the model, or load it from disk if already trained
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Model(vocab_size, embed_size).to(device)

if not exists('model1.bin'):
    data = DataLoader(train_dataset, batch_size=200)
    optimizer = torch.optim.Adam(model.parameters())
    criterion = torch.nn.NLLLoss()
    model.train()
    step = 0
    for x, y in data:
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        ypredicted = model(x)
        # the network outputs probabilities, so log them for NLLLoss
        loss = criterion(torch.log(ypredicted), y)
        if step % 100 == 0:
            print(step, loss)
        step += 1
        loss.backward()
        optimizer.step()
    torch.save(model.state_dict(), 'model1.bin')
else:
    model.load_state_dict(torch.load('model1.bin'))
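
# Small review suggestion, not part of this commit: switch to inference mode
# before the prediction passes below.
model.eval()  # disables train-time behavior (relevant if e.g. Dropout is added later)
# optionally wrap the predict_file calls below in `with torch.no_grad():`
# to skip gradient bookkeeping during prediction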
#%%
# prediction helper
import nltk
from nltk import word_tokenize
nltk.download('punkt')

vocab = train_dataset.vocab

def predict(tokens):
    ixs = torch.tensor(vocab.forward(tokens)).to(device)
    out = model(ixs)
    # take the 10 most probable successors of the context token
    top = torch.topk(out[0], 10)
    top_indices = top.indices.tolist()
    top_probs = top.values.tolist()
    top_words = vocab.lookup_tokens(top_indices)
    result = ""
    for word, prob in zip(top_words, top_probs):
        result += f"{word}:{prob} "
    result += ':0.01'  # remainder probability mass for all other words
    return result
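
# For reference, the returned string follows the challenge's word:probability
# format, with ':0.01' as the leftover mass (values below are illustrative):
print(predict(['the']))
# e.g.: first:0.0712 same:0.0498 most:0.0303 ... :0.01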
#%%
def predict_file(result_path, data):
    with open(result_path, "w+", encoding="UTF-8") as f:
        for row in data:
            # the bigram model conditions on the last token of the left context
            before = word_tokenize(clean(str(row)))[-1:]
            if len(before) < 1:
                # fallback distribution for empty contexts
                # (note: 'of' appears twice here in the original)
                result = "a:0.2 the:0.2 to:0.2 of:0.1 and:0.1 of:0.1 :0.1"
            else:
                result = predict(before)
            f.write(result + "\n")
            print(result)
#%%
# generate predictions for the dev and test sets
dev_data = pd.read_csv("dev-0/in.tsv.xz", sep='\t', header=None, quoting=csv.QUOTE_NONE)[6]
dev_data = dev_data.apply(clean)
predict_file("dev-0/out.tsv", dev_data)

test_data = pd.read_csv("test-A/in.tsv.xz", sep='\t', header=None, quoting=csv.QUOTE_NONE)[6]
test_data = test_data.apply(clean)
predict_file("test-A/out.tsv", test_data)

File diff suppressed because it is too large