# 2022-04-25 01:19:11 +02:00
#
# 156 lines
# 2.5 KiB

#!/usr/bin/env python
# coding: utf-8
# In[1]:
import heapq
import os
import subprocess
import sys

import pandas as pd

from utils import *
# In[2]:
# --- Build the training corpus: one "left gap right" line per training row.
data = get_csv("train/in.tsv.xz")
train_labels = get_csv("train/expected.tsv")

# Columns 6/7 of in.tsv and column 0 of expected.tsv are combined below;
# presumably left context, right context and the gap word -- TODO confirm.
train_data = pd.concat([data[[6, 7]], train_labels], axis=1)

# Join with spaces: plain `a + b + c` glued the last left-context word, the gap
# word and the first right-context word into a single bogus token.
# fillna("") keeps one missing field from turning the whole row into NaN.
parts = train_data[[6, 0, 7]].fillna("")
train_data[607] = (parts[6] + " " + parts[0] + " " + parts[7]).apply(clean_text)

KENLM_BUILD_PATH = "../kenlm/build/"
CORPUS_PATH = "tmp"          # temporary plain-text corpus fed to lmplz
MODEL_PATH = "model.arpa"    # ARPA model written by lmplz, loaded below

with open(CORPUS_PATH, "w", encoding="utf-8") as f:
    for t in train_data[607]:
        f.write(t + "\n")

# --- Train a 4-gram model with KenLM's lmplz (reads stdin, writes ARPA to stdout).
# Run it directly instead of through get_ipython().system(): the old command
# ended in a bare `>` with no output file (shell syntax error), and
# get_ipython() only exists inside IPython, so the script failed when run
# with plain `python`. check=True surfaces lmplz failures; the corpus file
# is then left on disk for inspection.
with open(CORPUS_PATH, "rb") as corpus, open(MODEL_PATH, "wb") as arpa:
    subprocess.run(
        [os.path.join(KENLM_BUILD_PATH, "bin", "lmplz"), "-o", "4"],
        stdin=corpus,
        stdout=arpa,
        check=True,
    )
os.remove(CORPUS_PATH)

import kenlm

# Load the file lmplz just produced (previously "./", a directory).
model = kenlm.Model(MODEL_PATH)

# Install through the running interpreter's pip so the package lands in the
# same environment. english_words 2.x no longer exports
# `english_words_alpha_set`, so stay on the 1.x API used by predict().
subprocess.run(
    [sys.executable, "-m", "pip", "install", "english_words<2"],
    check=True,
)

from english_words import english_words_alpha_set
from math import log10
def predict(before, after, *, top_k=12, vocabulary=None, lm=None):
    """Rank candidate gap words by language-model score.

    Each candidate word is placed between ``before`` and ``after``. The
    string ``"before word after"`` is scored with
    ``lm.score(text, bos=False, eos=False)``, and the ``top_k``
    highest-scoring words are kept.

    Args:
        before: Left-context text placed before the candidate word.
        after: Right-context text placed after the candidate word.
        top_k: Number of candidates to keep (12 by default, as before).
        vocabulary: Iterable of candidate words. Defaults to the
            module-level ``english_words_alpha_set``.
        lm: Object with a kenlm-style ``score(text, bos=..., eos=...)``
            method. Defaults to the module-level ``model``.

    Returns:
        A single line ``"w1:s1 w2:s2 ... :r"``:
        - Candidates appear in descending score order.
        - Each ``s`` is the raw value returned by ``lm.score``; kenlm
          documents this as a log10 probability.
        - ``r`` is ``log10(0.99)``.
        With no candidates the line is just ``":r"``.
    """
    if vocabulary is None:
        vocabulary = english_words_alpha_set
    if lm is None:
        lm = model

    scored = (
        (word, lm.score(' '.join([before, word, after]), bos=False, eos=False))
        for word in vocabulary
    )
    # nlargest is equivalent to sorted(..., reverse=True)[:top_k] but holds
    # only top_k items. The previous hand-rolled loop never evicted the worst
    # entry, so the list grew without bound, and it appended some words twice.
    best = heapq.nlargest(top_k, scored, key=lambda pair: pair[1])

    # NOTE(review): the trailing ":value" entry is a fixed log10(0.99) while
    # the other values are unnormalized LM scores -- confirm this matches the
    # evaluator's expected output format.
    return ''.join(f'{word}:{score} ' for word, score in best) + f':{log10(0.99)}'
# In[27]:
from nltk import trigrams, word_tokenize
def _context_text(value):
    """Return a context cell as text, treating missing (NaN/None) cells as empty."""
    # str(NaN) would otherwise become the literal word "nan".
    return '' if pd.isna(value) else str(value)


def make_prediction(path, result_path):
    """Write one ``predict`` output line per row of the input file.

    Args:
        path: Input file read with ``get_csv``. Columns 6 and 7 are used as
            left and right context (presumably the text before/after the
            gap -- TODO confirm).
        result_path: Output file; overwritten, UTF-8, one line per input row.
    """
    pdata = get_csv(path)
    with open(result_path, 'w', encoding='utf-8') as file_out:
        for _, row in pdata.iterrows():
            before = word_tokenize(clean_text(_context_text(row[6])))
            after = word_tokenize(clean_text(_context_text(row[7])))
            # Use the word adjacent to the gap on each side. When a side has
            # no tokens, fall back to an empty context instead of the old
            # `before[-1]` / `after[0]` (IndexError) and the undefined name
            # `prediction` (NameError).
            left = before[-1] if before else ''
            right = after[0] if after else ''
            file_out.write(predict(left, right) + '\n')
# In[28]:
# Generate predictions for the dev and test splits, one (input, output) pair each.
_PREDICTION_JOBS = (
    ("dev-0/in.tsv.xz", "dev-0/out.tsv"),
    ("test-A/in.tsv.xz", "test-A/out.tsv"),
)
for _source, _target in _PREDICTION_JOBS:
    make_prediction(_source, _target)
# In[ ]: