From 594a64e45f10d7d134cc7584450046b5ebec7040 Mon Sep 17 00:00:00 2001 From: Norbert Litkowski Date: Mon, 25 Apr 2022 01:19:11 +0200 Subject: [PATCH] add exported python file --- run.py | 155 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 155 insertions(+) create mode 100644 run.py diff --git a/run.py b/run.py new file mode 100644 index 0000000..062a755 --- /dev/null +++ b/run.py @@ -0,0 +1,155 @@ +#!/usr/bin/env python +# coding: utf-8 + +# In[1]: + + +import pandas as pd +from utils import * + + +# In[2]: + + +data = get_csv("train/in.tsv.xz") + + +# In[3]: + + +train_labels = get_csv("train/expected.tsv") + + +# In[4]: + + +train_data = data[[6,7]] + + +# In[5]: + + +train_data = pd.concat([train_data, train_labels], axis=1) + + +# In[6]: + + +train_data[607] = train_data[6] + train_data[0] + train_data[7] + + +# In[7]: + + +train_data[607] = train_data[607].apply(clean_text) + + +# In[8]: + + +train_data[607] + + +# In[15]: + + +with open("tmp", "w+") as f: + for t in train_data[607]: + f.write(t + "\n") + + +# In[10]: + + +KENLM_BUILD_PATH = "../kenlm/build/" +get_ipython().system('$KENLM_BUILD_PATH/bin/lmplz -o 4 < tmp > model.arpa') + + +# In[11]: + + +get_ipython().system('rm tmp') + + +# In[16]: + + +import kenlm +model = kenlm.Model("./model.arpa") + + +# In[23]: + + +get_ipython().system('pip install english_words') + + +# In[24]: + + +from english_words import english_words_alpha_set +from math import log10 + +def predict(before, after): + result = '' + prob = 0.0 + best = [] + for word in english_words_alpha_set: + text = ' '.join([before, word, after]) + text_score = model.score(text, bos=False, eos=False) + if len(best) < 12: + best.append((word, text_score)) + else: + is_better = False + worst_score = None + for score in best: + if not worst_score: + worst_score = score + else: + if worst_score[1] > score[1]: + worst_score = score + if worst_score[1] < text_score: + best.remove(worst_score) + best.append((word, text_score)) + probs = sorted(best, key=lambda tup: tup[1], reverse=True) + pred_str = '' + for word, prob in probs: + pred_str += f'{word}:{prob} ' + pred_str += f':{log10(0.99)}' + return pred_str + + +# In[27]: + + +from nltk import trigrams, word_tokenize + +def make_prediction(path, result_path): + pdata = get_csv(path) + with open(result_path, 'w', encoding='utf-8') as file_out: + for _, row in pdata.iterrows(): + before, after = word_tokenize(clean_text(str(row[6]))), word_tokenize(clean_text(str(row[7]))) + if len(before) < 2 or len(after) < 2: + pred = prediction + else: + pred = predict(before[-1], after[0]) + file_out.write(pred + '\n') + + +# In[28]: + + +make_prediction("dev-0/in.tsv.xz", "dev-0/out.tsv") + + +# In[29]: + + +make_prediction("test-A/in.tsv.xz", "test-A/out.tsv") + + +# In[ ]: + + + +