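"""Word-gap prediction with a pre-trained 5-gram KenLM language model.

For each line of an input TSV, the left context (columns[6]) is cleaned and
tokenised, its last four tokens are combined with every candidate from an
English word list and scored with KenLM, and the ten most probable candidates
are written out as ``word:probability`` pairs plus the leftover probability
mass.
"""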
import multiprocessing as mp
from collections import Counter

import kenlm
import nltk  # word_tokenize() needs the 'punkt' tokenizer data: nltk.download('punkt')
import regex as re
from tqdm import tqdm

from english_words import get_english_words_set

# Candidate vocabulary: every lower-cased, purely alphabetic entry of the
# package's 'web2' word list.
words = get_english_words_set(['web2'], lower=True, alpha=True)

# Pre-trained 5-gram KenLM language model (binary format).
path = 'model_5.binary'

language_model = kenlm.Model(path)

def clean(text):
    # The TSV stores newlines and tabs as literal escape sequences; also
    # rewrite '<s>' so KenLM's sentence-start token never appears in the text.
    text = text.replace('-\\n', '').replace('\\n', ' ').replace('\\t', ' ').replace('<s>', 's')

    # Collapse runs of spaces into a single space.
    while '  ' in text:
        text = text.replace('  ', ' ')

    # Strip all punctuation (Unicode category P).
    return re.sub(r'\p{P}', '', text)

def generate_file(input_path, expected_path, output_path):
    # Build a plain-text training corpus: left context + gap word + right context.
    with open(input_path, encoding='utf-8') as input_file, \
         open(expected_path, encoding='utf-8') as expected_file, \
         open(output_path, 'w', encoding='utf-8') as output_file:
        for line, word in zip(input_file, expected_file):
            columns = line.split('\t')
            prefix = clean(columns[6])
            suffix = clean(columns[7])

            train_line = f"{prefix.strip()} {word.strip()} {suffix.strip()}"

            output_file.write(train_line + '\n')


#generate_file('train/in.tsv', 'train/expected.tsv', 'train/train.txt')

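# NOTE (assumption): model_5.binary is not built by this script. A 5-gram
# binary model of this kind would typically be produced from train/train.txt
# with the standard KenLM command-line tools, e.g.:
#   lmplz -o 5 < train/train.txt > model_5.arpa
#   build_binary model_5.arpa model_5.binary
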
def predict(prefix):
    # Approximate log10 P(word | prefix) as score(prefix + word) - score(prefix);
    # the prefix score is constant, so compute it once outside the loop.
    prefix_score = language_model.score(prefix.strip(), bos=False, eos=False)

    scores = {}
    for word in words:
        candidate = f"{prefix} {word}".strip()
        score = language_model.score(candidate, bos=False, eos=False)
        scores[word] = score - prefix_score

    # Keep the ten most probable candidates.
    highest_probs = Counter(scores).most_common(10)

    output = ''
    probs = 0
    for word, logprob in highest_probs:
        prob = 10 ** logprob
        probs += prob
        output += f"{word}:{prob} "

    # Assign the remaining probability mass to the unknown-word bucket.
    output += f":{1 - probs}"

    return output

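# predict() returns one output line: ten "word:probability" pairs followed by
# the leftover probability mass, e.g. (illustrative values only):
#   "the:0.31 a:0.09 ... :0.45"
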
def parse_line(line):
    columns = line.split('\t')
    prefix = clean(columns[6])

    # Tokenise the left context and keep only its last four tokens
    # (four words of history is all a 5-gram model can use).
    prefix = nltk.tokenize.word_tokenize(prefix)
    prefix_input = " ".join(prefix[-4:])

    return predict(prefix_input)

def parse(input_path, output_path='out.tsv'):
    with open(input_path, encoding='utf-8') as f:
        lines = f.readlines()

    with open(output_path, 'w', encoding='utf-8') as output_file:
        # Score the lines in parallel; each worker process reuses the
        # module-level model and vocabulary.
        with mp.Pool() as pool:
            results = list(tqdm(pool.imap(parse_line, lines), total=len(lines)))

        for result in results:
            output_file.write(result + '\n')


if __name__ == '__main__':
    parse('test-A/in.tsv', output_path="test-A/out.tsv")