import multiprocessing as mp
from collections import Counter
from functools import partial

import kenlm
import nltk
import regex as re
from english_words import get_english_words_set
from tqdm import tqdm

# Candidate vocabulary scored for every prediction: lower-cased,
# alphabetic-only words from the 'web2' list.
words = get_english_words_set(['web2'], lower=True, alpha=True)

path = 'model_5.binary'
# Loaded at import time; the multiprocessing workers used by `parse` rely on
# this module-level global.
language_model = kenlm.Model(path)

# Any Unicode punctuation character (`\p{...}` needs the `regex` module).
_PUNCTUATION_RE = re.compile(r'\p{P}')
# Runs of two or more spaces.
_MULTI_SPACE_RE = re.compile(r' {2,}')


def clean(text):
    """Normalise one raw TSV context field into plain, punctuation-free text.

    The fields contain line breaks and tabs as literal two-character escape
    sequences (backslash + 'n' / backslash + 't'). A hyphen directly before
    an escaped line break is removed together with it, presumably to re-join
    words hyphenated across a line break.

    Returns:
        The text with escapes turned into spaces, every Unicode punctuation
        character removed, and runs of spaces collapsed to a single space.
    """
    text = text.replace('-\\n', '').replace('\\n', ' ').replace('\\t', ' ')
    text = _PUNCTUATION_RE.sub('', text)
    # Collapse after stripping punctuation, which can itself leave double spaces.
    return _MULTI_SPACE_RE.sub(' ', text)


def generate_file(input_path, expected_path, output_path):
    """Write a language-model training corpus with one example per line.

    Each tab-separated row of `input_path` supplies the text before the gap
    (column 6) and after it (column 7); the line at the same position in
    `expected_path` is the gap word. Each output line is
    "<prefix> <word> <suffix>".
    """
    with open(input_path, encoding='utf-8') as input_file, \
            open(expected_path, encoding='utf-8') as expected_file, \
            open(output_path, 'w', encoding='utf-8') as output_file:
        for line, word in zip(input_file, expected_file):
            columns = line.split('\t')
            prefix = clean(columns[6])
            suffix = clean(columns[7])
            # Newline-terminated so every example is its own line (sentence)
            # in the training text instead of one concatenated line.
            output_file.write(f"{prefix.strip()} {word.strip()} {suffix.strip()}\n")


def predict(prefix, top_k=10):
    """Return the `top_k` most likely words to follow `prefix`.

    Each candidate in `words` is scored as
    log10 P(word | prefix) = score(prefix + word) - score(prefix),
    using KenLM scores computed without sentence-boundary markers.

    Args:
        prefix: Space-separated context words (may be empty).
        top_k: Number of candidates to list (default 10).

    Returns:
        A string "w1:p1 w2:p2 ... :rest", where each p is a probability and
        rest = 1 - (sum of the listed probabilities).
    """
    # The prefix's own score is identical for every candidate; compute it once.
    prefix_score = language_model.score(prefix.strip(), bos=False, eos=False)
    scores = {
        word: language_model.score(f"{prefix} {word}".strip(), bos=False, eos=False)
        - prefix_score
        for word in words
    }

    parts = []
    probs = 0
    for word, logprob in Counter(scores).most_common(top_k):
        prob = 10 ** logprob  # scores are log10 probabilities
        probs += prob
        parts.append(f"{word}:{prob}")
    # Remaining probability mass, attached to the empty (wildcard) entry.
    parts.append(f":{1 - probs}")
    return ' '.join(parts)


def parse_line(line):
    """Predict the gap word for one TSV row from the last words before the gap."""
    columns = line.split('\t')
    tokens = nltk.tokenize.word_tokenize(clean(columns[6]))
    # Up to four words of history (presumably matching the order of the
    # 5-gram model in model_5.binary). Slicing tolerates shorter prefixes,
    # where fixed negative indices would raise IndexError.
    return predict(' '.join(tokens[-4:]))


def parse(input_path, output_path='out.tsv'):
    """Write one prediction per input line of `input_path` to `output_path`.

    Lines are processed in parallel with a process pool. `imap` yields results
    in input order, so output line i corresponds to input line i.
    """
    with open(input_path, encoding='utf-8') as f:
        lines = f.readlines()
    # The context manager tears the pool down instead of leaking workers.
    with mp.Pool() as pool, open(output_path, 'w', encoding='utf-8') as output_file:
        for result in tqdm(pool.imap(parse_line, lines), total=len(lines)):
            output_file.write(result + '\n')


if __name__ == '__main__':
    # Guarded so that pool workers re-importing this module (spawn start
    # method) do not launch the whole job again.
    # To rebuild the training text instead, run:
    # generate_file('train/in.tsv', 'train/expected.tsv', 'train/train.txt')
    parse('test-A/in.tsv', output_path="test-A/out.tsv")