# challenging-america-word-gap-prediction/predict.py
# snapshot: 2023-04-13 21:33:24 +02:00 — 141 lines, 3.6 KiB, Python
import lzma
import matplotlib.pyplot as plt
from math import log
from collections import OrderedDict
from collections import Counter
import regex as re
from itertools import islice
import json
import pdb
# Model version tag (used in file names) and the evaluation split directory.
model_v = "1"
PREFIX_VALID = 'test-A'

# 4-gram model: each TSV line is "w1<TAB>w2<TAB>w3<TAB>{json dict}", mapping
# a 3-word left-context tuple to {candidate_word: probability}.
prob_4gram = {}
with open(f'4_gram_model_{model_v}.tsv', 'r') as f:
    for line in f:
        line = line.rstrip()
        splitted_line = line.split('\t')
        prob_4gram[tuple(splitted_line[:3])] = json.loads(splitted_line[-1])

# Lower-order backoff models are currently disabled (their loaders were
# commented out: 3_gram_model_{model_v}.tsv / 2_gram_model_{model_v}.tsv,
# same TSV+JSON layout with 2-word / 1-word keys).  Kept as empty dicts so
# the count_probabilities() call signature still works.
prob_3gram = {}
prob_2gram = {}

# Vocabulary, one word per line; words outside this set map to "<UNK>".
vocab = set()
with open(f"vocab_{model_v}.txt", 'r') as f:
    for l in f:
        vocab.add(l.rstrip())
def count_probabilities(prob_4gram_x, prob_3gram_x, prob_2gram_x,
                        _chunk_left, _chunk_right, vocab_x=None):
    """Return gap-word hypotheses sorted by descending probability.

    Out-of-vocabulary context words are replaced by ``"<UNK>"`` and the
    3-word left context is looked up in the 4-gram model.

    Args:
        prob_4gram_x: dict mapping 3-word tuples to ``{word: prob}`` dicts.
        prob_3gram_x: unused — 3-gram backoff is currently disabled.
        prob_2gram_x: unused — 2-gram backoff is currently disabled.
        _chunk_left: list of (up to) 3 words preceding the gap.
        _chunk_right: list of (up to) 3 words following the gap.  Only
            normalized in the original code, never used for lookup.
        vocab_x: optional vocabulary set; defaults to the module-level
            ``vocab`` (backward-compatible addition).

    Returns:
        ``OrderedDict`` of ``{word: prob}`` sorted by descending
        probability, or ``{}`` when the left context is not in the model.
    """
    if vocab_x is None:
        vocab_x = vocab  # module-level vocabulary loaded at import time
    # Map OOV words to <UNK> WITHOUT mutating the caller's lists
    # (the original wrote back into _chunk_left/_chunk_right in place).
    key_left = tuple(w if w in vocab_x else "<UNK>" for w in _chunk_left)
    # NOTE(review): the right context was normalized but never consulted;
    # the dead computation is dropped here — lookup uses the left key only.
    hyps_4 = prob_4gram_x.get(key_left)
    if hyps_4 is None:
        return {}
    return OrderedDict(sorted(hyps_4.items(), key=lambda kv: kv[1],
                              reverse=True))
# Predict the masked middle word for every row of the evaluation set.
# Each TSV row's last two columns hold the text before/after the gap; we
# take the 3-word left context, query the 4-gram model, and print up to
# 31 weighted hypotheses plus a remainder mass for "any other word".
with lzma.open(f'{PREFIX_VALID}/in.tsv.xz', 'r') as train:
    for t_line in train:
        t_line = t_line.decode("utf-8")
        t_line = t_line.rstrip()
        t_line = t_line.lower()
        # the corpus stores newlines as literal backslash sequences
        t_line = t_line.replace("\\\\n", ' ')
        t_line_splitted_by_tab = t_line.split('\t')
        words_before = t_line_splitted_by_tab[-2]
        words_before = re.findall(r'\p{L}+', words_before)
        words_after = t_line_splitted_by_tab[-1]
        words_after = re.findall(r'\p{L}+', words_after)
        chunk_left = words_before[-3:]
        chunk_right = words_after[0:3]
        probs_ordered = count_probabilities(
            prob_4gram, prob_3gram, prob_2gram, chunk_left, chunk_right)
        if len(probs_ordered) == 0:
            # unknown context: emit a fixed fallback distribution
            print("the:0.1 to:0.1 a:0.1 :0.7")
            continue
        result_string = ''
        counter_ = 0
        p_sum = 0
        for word_, p in probs_ordered.items():
            if counter_ > 30:  # cap output at 31 hypotheses
                break
            re_ = re.search(r'\p{L}+', word_)
            if re_:
                word_cleared = re_.group(0)
                p = p * 0.9  # damp, reserving mass for the wildcard token
                p_sum += p
                result_string += f"{word_cleared}:{str(p)} "
            else:
                # hypothesis contains no letters (punctuation/<UNK>):
                # skip it, seeding a default guess if nothing emitted yet.
                # NOTE(review): these default probs are NOT added to p_sum,
                # so the remainder below can exceed 1 - (0.5+0.3) — confirm
                # this is intended before relying on the output summing to 1.
                if result_string == '':
                    result_string = "the:0.5 a:0.3 "
                continue
            counter_ += 1
        # assign the remaining probability mass to the empty/wildcard token
        res = 1 - p_sum
        result_string += f':{res}'
        print(result_string)