"""Predict a missing word between a left and a right context with a KenLM model.

For every row of ``<folder>/in.tsv`` the last two tab-separated fields are
taken as the text before and after the gap.  The last token of the left
text and the first token of the right text are used as context, every
vocabulary word is scored in between them, and the best candidates are
written to ``<folder>/out.tsv`` (one line per input row).
"""
import heapq
import pickle
from math import log10
from operator import itemgetter

import kenlm
import regex as re
from english_words import get_english_words_set
from nltk.tokenize import word_tokenize
from tqdm import tqdm

path = 'kenlm_model.binary'
model = kenlm.Model(path)

# NOTE: unpickling executes arbitrary code embedded in the file; only load a
# V.pickle that was produced locally by a trusted process.
with open('V.pickle', 'rb') as handle:
    V_counter = pickle.load(handle)

# Patterns are compiled once because clean_string runs twice per input row.
# Both match the *literal* two-character sequence backslash + "n" (optionally
# preceded by more backslashes in the first pattern), not a real newline.
_HYPHEN_BREAK_RE = re.compile(r" -\\*\\n")
_ESCAPED_NEWLINE_RE = re.compile(r"\\n")

# Number of candidates emitted by predict_probs / argmax respectively.
_PREDICT_TOP_K = 5
_ARGMAX_TOP_K = 4

# Line written when either side of the gap is empty after cleaning.
# NOTE(review): these are linear probabilities, whereas predict_probs emits
# raw model scores (presumably log10 probabilities, see the log10(0.99)
# remainder) -- confirm the evaluator accepts both forms.
_FALLBACK_LINE = "the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1"


def clean_string(text):
    """Normalise one text field of the input file.

    Lower-cases ``text``, deletes every " -" followed by an escaped newline
    marker, turns the remaining escaped newline markers into spaces and
    strips surrounding whitespace.
    """
    text = text.lower()
    text = _HYPHEN_BREAK_RE.sub("", text)
    text = _ESCAPED_NEWLINE_RE.sub(" ", text)
    return text.strip()


def _score_between(left, word, right):
    """Return the model score of ``left word right`` without sentence markers."""
    return model.score(' '.join([left, word, right]), bos=False, eos=False)


def predict_probs(w1, w3):
    """Return the output line for the gap between ``w1`` and ``w3``.

    Every word of ``V_counter`` is scored in the middle position and the five
    highest-scoring words are kept, best first (ties keep vocabulary order).

    Returns:
        A string of space-separated ``word:score`` pairs followed by a final
        ``:<log10(0.99)>`` entry with an empty word.
    """
    scored = ((word, _score_between(w1, word, w3)) for word in V_counter)
    # nlargest is equivalent to sorted(..., reverse=True)[:k] but keeps only
    # k items in memory and never re-sorts inside the vocabulary loop.
    best_scores = heapq.nlargest(_PREDICT_TOP_K, scored, key=itemgetter(1))

    parts = [f'{word}:{prob}' for word, prob in best_scores]
    parts.append(f':{log10(0.99)}')
    return ' '.join(parts)


def get_word_predictions(w1, w2,):
    """Yield ``(word, score)`` for every word of the ``web2`` English word set.

    Each word is scored by the model as the middle token between ``w1`` and
    ``w2``.
    """
    for word in get_english_words_set(['web2'], lower=True):
        yield (word, _score_between(w1, word, w2))


def argmax(w1, w2):
    """Return the 4 best ``get_word_predictions`` candidates as one line.

    Candidates are ordered by descending score and formatted as
    ``word:score`` with eight decimals, separated by single spaces.
    """
    top = heapq.nlargest(
        _ARGMAX_TOP_K, get_word_predictions(w1, w2), key=itemgetter(1)
    )
    return " ".join(f"{word}:{score:.8f}" for word, score in top)


def run_predictions(source_folder):
    """Write ``out.tsv`` with one prediction line per row of ``in.tsv``.

    Both files live in ``source_folder``.  Rows whose left or right context
    is empty after cleaning receive a fixed fallback line.
    """
    print(f"Run predictions on {source_folder} data...")

    with open(f"{source_folder}/in.tsv", encoding="utf8", mode="rt") as file:
        rows = file.readlines()

    with open(f"{source_folder}/out.tsv", "w", encoding="utf-8") as output_file:
        for row in tqdm(rows):
            fields = row.split("\t")
            # The gap sits between the last two columns of the row.
            left = clean_string(fields[-2])
            right = clean_string(fields[-1])

            if not left or not right:
                out_line = _FALLBACK_LINE
            else:
                w1 = word_tokenize(left)[-1]
                w2 = word_tokenize(right)[0]
                out_line = predict_probs(w1, w2)

            output_file.write(out_line + "\n")


run_predictions("dev-0")
run_predictions("test-A")