100 lines
2.5 KiB
Python
100 lines
2.5 KiB
Python
from tqdm import tqdm
|
|
import regex as re
|
|
from english_words import get_english_words_set
|
|
import kenlm
|
|
import pickle
|
|
import math
|
|
import numpy as np
|
|
|
|
# Location of the binary KenLM model, loaded once at import time.
path = "kenlm_model.binary"

# Shared language model queried by p() and perplexity().
model = kenlm.Model(path)
|
|
|
|
# Contraction -> expansion table used by formalize_text().
# Entry order is significant: it is the order of the regex alternation.
# "'ll" is a bare suffix, so its expansion carries its own leading space
# ("we'll" -> "we will").
CONTRACTIONS = dict(
    [
        ("I'm", "I am"),
        ("you're", "you are"),
        ("he's", "he is"),
        ("she's", "she is"),
        ("it's", "it is"),
        ("we're", "we are"),
        ("they're", "they are"),
        ("aren't", "are not"),
        ("don't", "do not"),
        ("doesn't", "does not"),
        ("weren't", "were not"),
        ("'ll", " will"),
    ]
)
|
|
|
|
|
|
def formalize_text(text, contractions=None):
    """Expand contractions and flatten line breaks in *text*.

    Args:
        text: Raw input string.
        contractions: Optional mapping of contraction -> expansion. Defaults
            to the module-level ``CONTRACTIONS`` table.

    Returns:
        The text with:
        - every contraction replaced by its expansion (case-sensitive, matched
          between word boundaries);
        - each hyphen that ends a line removed together with its newline,
          which re-joins hyphenated words;
        - every remaining newline turned into a single space.
    """
    if contractions is None:
        contractions = CONTRACTIONS

    # An empty table would compile to r'\b()\b', which matches the empty
    # string and then fails the dictionary lookup, so skip the step instead.
    if contractions:
        # Escape the keys so they are matched literally even if an entry ever
        # contains a regex metacharacter such as '.' or '?'.
        alternation = '|'.join(re.escape(key) for key in contractions)
        pattern = re.compile(r'\b(' + alternation + r')\b')
        text = pattern.sub(lambda match: contractions[match.group()], text)

    # Re-join words hyphenated across a line break, then flatten the
    # remaining newlines into spaces.
    text = text.replace('-\n', '')
    text = text.replace('\n', ' ')
    return text
|
|
|
|
|
|
def clean_string(text):
    """Normalise one raw context string before it is split into words.

    First applies formalize_text(). It then handles *escaped* line breaks,
    i.e. a literal backslash followed by 'n' in the string data:

    - " -" followed by zero or more extra backslashes and an escaped "\\n"
      is deleted;
    - every other escaped "\\n" becomes a space.

    Finally, surrounding whitespace is stripped.
    """
    # NOTE(review): unlike formalize_text's '-\n' handling, this pattern
    # requires a space *before* the hyphen, so "exam-\\nple" is not re-joined
    # (it becomes "exam- ple"). Confirm against the dataset.
    escaped_hyphen_break = r" -\\*\\n"
    escaped_newline = r"\\n"

    normalized = formalize_text(text)
    normalized = re.sub(escaped_hyphen_break, "", normalized)
    normalized = re.sub(escaped_newline, " ", normalized)
    return normalized.strip()
|
|
|
|
|
|
def p(text):
    """Map the KenLM score of *text* into (0, 1) with a logistic sigmoid.

    ``model.score`` returns a log10 probability (KenLM convention). Here it is
    computed without sentence-boundary tokens (``bos=False, eos=False``).

    The result is ``1 / (1 + exp(-score))``. This is a monotone ranking
    score, not the probability ``10 ** score`` itself.

    The sigmoid is evaluated in a numerically stable form. The naive
    expression raises OverflowError once ``-score`` exceeds ~709 (long or
    very unlikely texts); this version instead underflows gracefully
    toward 0.0.
    """
    score = model.score(text, bos=False, eos=False)
    if score >= 0:
        return 1.0 / (1.0 + math.exp(-score))
    # Equivalent form for negative scores: exp() of a negative argument can
    # only underflow to 0.0, never overflow.
    z = math.exp(score)
    return z / (1.0 + z)
|
|
|
|
|
|
def perplexity(text):
    """Return the perplexity of *text* under the module-level KenLM model.

    Unlike p(), no boundary flags are passed: KenLM's own defaults apply.
    """
    lm = model
    return lm.perplexity(text)
|
|
|
|
|
|
def predict_probs_w1w2wi(w1, w2, *, vocab=None, prob_fn=None, top_k=5,
                         weights=(0.1, 0.3, 0.6)):
    """Rank candidate words for the position following ``w1 w2``.

    Each word in *vocab* is scored by linearly interpolating *prob_fn* applied
    to three texts: the unigram ``word``, the bigram ``w2 word`` and the
    trigram ``w1 w2 word``. The *top_k* best candidates are emitted best-first
    as ``"word:score "`` pairs, followed by ``":rest"`` holding the leftover
    probability mass.

    Args:
        w1: The word two positions before the gap.
        w2: The word immediately before the gap.
        vocab: Iterable of candidate words; defaults to the module-level
            ``V_counter``.
        prob_fn: Callable mapping a text to a score; defaults to ``p``.
        top_k: Number of candidates to emit.
        weights: ``(unigram, bigram, trigram)`` interpolation weights.

    Returns:
        A single-line string ``"w_1:s_1 ... w_k:s_k :rest"``.
        - The scores are non-negative inputs from *prob_fn*.
        - If the top-k scores sum to more than 1, they are rescaled to sum
          to 1, so ``rest`` is never negative.
    """
    import heapq

    if vocab is None:
        vocab = V_counter
    if prob_fn is None:
        prob_fn = p
    uni_w, bi_w, tri_w = weights

    def interpolated(word):
        # Same evaluation order as before: unigram, bigram, trigram.
        return (uni_w * prob_fn(word)
                + bi_w * prob_fn(' '.join([w2, word]))
                + tri_w * prob_fn(' '.join([w1, w2, word])))

    # nlargest keeps a correct top-k regardless of arrival order, which the
    # hand-rolled "replace the last element" scheme did not. It also breaks
    # ties in favour of the word seen first.
    best = heapq.nlargest(
        top_k,
        ((word, interpolated(word)) for word in vocab),
        key=lambda pair: pair[1],
    )

    total = sum(score for _, score in best)
    if total > 1:
        # Interpolated scores are not normalised. Rescale so the emitted
        # distribution never has a negative remainder.
        best = [(word, score / total) for word, score in best]
        total = 1.0
    rest = max(0.0, 1 - total)

    return ''.join(f'{word}:{score} ' for word, score in best) + f':{rest}'
|
|
|
|
|
|
def run_predictions(source_folder):
    """Predict gap words for ``<source_folder>/in.tsv`` into ``out.tsv``.

    Each input line is split on tabs, and its second-to-last field is used
    as the left context. The last two words of the cleaned context (see
    clean_string()) go to predict_probs_w1w2wi(). Its result is written as
    one output line per input line.

    A left context with fewer than two words is left-padded with empty
    strings. Previously the tuple unpacking raised ValueError and aborted
    the whole run.
    """
    print(f"Run predictions on {source_folder} data...")

    with open(f"{source_folder}/in.tsv", encoding="utf8", mode="rt") as file:
        lines = file.readlines()

    with open(f"{source_folder}/out.tsv", "w", encoding="utf-8") as output_file:
        for line in tqdm(lines):
            fields = line.split("\t")

            # Padding guarantees two context words. An empty word only adds a
            # leading space to the joined n-gram text. KenLM's Python `score`
            # splits on whitespace, so this presumably degrades to a
            # lower-order n-gram; confirm if prob_fn changes.
            words = ["", ""] + clean_string(fields[-2]).split()
            w1, w2 = words[-2:]

            output_file.write(predict_probs_w1w2wi(w1, w2) + "\n")
|
|
|
|
|
|
# Candidate vocabulary iterated by predict_probs_w1w2wi(). The file name
# suggests ~3000 entries; TODO confirm.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open("V_3000.pickle", mode="rb") as handle:
    V_counter = pickle.load(handle)

# Predict for the dev split. For test-A, run: run_predictions("../test-A")
run_predictions("../dev-0")
|