challenging-america-word-ga.../run.py

85 lines
2.2 KiB
Python

import multiprocessing as mp
import nltk
from tqdm import tqdm
from functools import partial
import kenlm
import regex as re
from tqdm import tqdm
from collections import Counter
from english_words import get_english_words_set
# Language model used by predict() to score "prefix + candidate" strings.
# NOTE(review): the file name suggests a 5-gram model (parse_line feeds the
# last 4 tokens as context) — confirm against how the binary was built.
path = 'model_5.binary'
language_model = kenlm.Model(path)

# Candidate vocabulary scanned by predict(): the 'web2' word list, requested
# lower-cased and alphabetic-only.
words = get_english_words_set(['web2'], lower=True, alpha=True)
def clean(text):
    """Normalise one raw context field from the TSV into punctuation-free text.

    Steps, in order:
      1. Join words hyphenated across an escaped line break (literal ``-\\n``).
      2. Turn escaped newlines and tabs (literal ``\\n`` / ``\\t``) into spaces.
      3. Replace the literal token ``<s>`` with ``s``.
      4. Collapse every run of two or more spaces into a single space.
      5. Delete every Unicode punctuation character (``\\p{P}``).

    Args:
        text: one column value as read from the input TSV.

    Returns:
        The cleaned string. Leading/trailing whitespace is not stripped.
    """
    text = (
        text.replace('-\\n', '')
        .replace('\\n', ' ')
        .replace('\\t', ' ')
        # NOTE(review): maps '<s>' to 's' rather than removing it — confirm intended.
        .replace('<s>', 's')
    )
    # One regex pass replaces the old `while ' ' in text` loop. As written, that
    # loop replaced a single space with a single space, so it never terminated
    # once the text contained any space.
    text = re.sub(r' {2,}', ' ', text)
    return re.sub(r'\p{P}', '', text)
def generate_file(input_path, expected_path, output_path):
    """Write a plain-text corpus with one "left word right" sentence per line.

    Args:
        input_path: TSV whose columns 6 and 7 (0-based) hold the left and right
            context of the gap; read as UTF-8.
        expected_path: file with the gap word for each input row, one per line,
            in the same order as ``input_path``; read as UTF-8.
        output_path: destination text file, written as UTF-8 (overwritten).

    Rows are paired with ``zip``, so surplus lines in the longer file are ignored.
    """
    with open(input_path, encoding='utf-8') as input_file, \
            open(expected_path, encoding='utf-8') as expected_file, \
            open(output_path, 'w', encoding='utf-8') as output_file:
        for line, word in zip(input_file, expected_file):
            columns = line.split('\t')
            prefix = clean(columns[6])
            suffix = clean(columns[7])
            train_line = f"{prefix.strip()} {word.strip()} {suffix.strip()}"
            # Terminate each sentence with a newline; without it every row was
            # concatenated onto one single line of the output file.
            output_file.write(train_line + '\n')
# Uncomment to (re)build the training corpus train/train.txt from the training split:
# generate_file('train/in.tsv', 'train/expected.tsv', 'train/train.txt')
def predict(prefix, top_k=10):
    """Rank every vocabulary word as the continuation of ``prefix``.

    Each word in the module-level ``words`` is scored with the module-level
    ``language_model`` as ``score(prefix + word) - score(prefix)``. Both calls
    use ``bos=False, eos=False``. The difference is treated as a log10
    probability: it is turned into a probability with ``10 ** score``.

    Args:
        prefix: space-separated context preceding the gap.
        top_k: number of best-scoring words to emit (default 10, as before).

    Returns:
        A line ``"w1:p1 w2:p2 ... :rest"`` with ``rest = 1 - (p1 + p2 + ...)``.
    """
    # The prefix score does not depend on the candidate word, so compute it
    # once instead of once per vocabulary entry.
    prefix_score = language_model.score(prefix.strip(), bos=False, eos=False)
    scores = {
        word: language_model.score(f"{prefix} {word}".strip(), bos=False, eos=False)
        - prefix_score
        for word in words
    }
    parts = []
    total = 0
    for word, logprob in Counter(scores).most_common(top_k):
        prob = 10 ** logprob
        total += prob
        parts.append(f"{word}:{prob} ")
    # NOTE(review): ``rest`` can be zero or negative if the listed probabilities
    # sum to >= 1; consider clamping if the evaluator rejects that.
    parts.append(f":{1 - total}")
    return ''.join(parts)
def parse_line(line):
    """Predict the gap word for one input TSV row.

    Args:
        line: raw TSV row; column 6 (0-based) holds the left context.

    Returns:
        The prediction string produced by ``predict`` for the last (up to) four
        tokens of the cleaned left context.
    """
    columns = line.split('\t')
    tokens = nltk.tokenize.word_tokenize(clean(columns[6]))
    # Slicing tolerates contexts with fewer than four tokens (including empty
    # ones). Indexing tokens[-4] .. tokens[-1] raised IndexError there, and that
    # error aborted the whole parallel run in parse().
    # NOTE(review): generate_file writes whitespace-separated cleaned text, not
    # NLTK tokens, so the two tokenisations may disagree on some words — confirm.
    prefix_input = " ".join(tokens[-4:])
    return predict(prefix_input)
def parse(input_path, output_path='out.tsv'):
    """Predict every row of ``input_path`` in parallel; write one line per row.

    Args:
        input_path: TSV file to predict for (read as UTF-8).
        output_path: destination file (UTF-8, overwritten). Line *i* holds the
            prediction for input row *i*, because ``Pool.imap`` preserves order.
    """
    with open(input_path, encoding='utf-8') as f:
        lines = f.readlines()
    # The output is opened before the long prediction run so a bad path fails
    # fast. The pool is a context manager so its workers are shut down once all
    # results have been collected; the pool was previously never closed.
    with open(output_path, 'w', encoding="utf-8") as output_file, mp.Pool() as pool:
        results = list(tqdm(pool.imap(parse_line, lines), total=len(lines)))
        output_file.writelines(result + '\n' for result in results)
# Guard the entry point. Under the "spawn" start method, multiprocessing
# re-imports this module in every worker. Without the guard, each child would
# re-run parse() and try to start a new pool while still bootstrapping.
if __name__ == '__main__':
    parse('test-A/in.tsv', output_path="test-A/out.tsv")