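"""Word-gap prediction with a 4-gram language model.

For each input line, look up the three words preceding the gap in a
precomputed 4-gram model and print the top candidate words as `word:prob`
pairs, assigning the leftover probability mass to an empty catch-all token.
Unseen contexts fall back to a fixed distribution.
"""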
import lzma
import json
import regex as re
from collections import OrderedDict

model_v = "1"
PREFIX_VALID = 'test-A'
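# Load the 4-gram model: each TSV line holds a 3-word context in its first
# three columns and, in the last column, a JSON dict mapping candidate next
# words to their probabilities (as the parsing below assumes).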
prob_4gram = {}
with open(f'4_gram_model_{model_v}.tsv', 'r') as f:
    for line in f:
        line = line.rstrip()
        splitted_line = line.split('\t')
        prob_4gram[tuple(splitted_line[:3])] = json.loads(splitted_line[-1])
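
# Lower-order models for backoff; loading them is currently disabled.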
prob_3gram = {}
# with open(f'3_gram_model_{model_v}.tsv', 'r') as f:
#     for line in f:
#         line = line.rstrip()
#         splitted_line = line.split('\t')
#         prob_3gram[tuple(splitted_line[:2])] = json.loads(splitted_line[-1])

prob_2gram = {}
# with open(f'2_gram_model_{model_v}.tsv', 'r') as f:
#     for line in f:
#         line = line.rstrip()
#         splitted_line = line.split('\t')
#         prob_2gram[tuple(splitted_line[:1])] = json.loads(splitted_line[-1])
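
# Vocabulary: one word per line; anything outside it maps to <UNK>.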
vocab = set()
with open(f"vocab_{model_v}.txt", 'r') as f:
    for l in f:
        vocab.add(l.rstrip())
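
# Earlier bigram-model loading, kept for reference.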
# probabilities_bi = {}
# with open(f'bigram_big_unk_20', 'r') as f:
#     for line in f:
#         line = line.rstrip()
#         splitted_line = line.split('\t')
#         probabilities_bi[tuple(splitted_line[:2])] = (float(splitted_line[2]), float(splitted_line[3]))
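

# Look up the (possibly <UNK>-mapped) left context in the 4-gram model and
# return its hypotheses sorted by probability, or an empty dict for an unseen
# context. The right context and the lower-order models are accepted for the
# planned backoff but are not used yet.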
def count_probabilities(prob_4gram_x, prob_3gram_x, prob_2gram_x, _chunk_left, _chunk_right):
    # Map out-of-vocabulary words to <UNK> so the chunks match the model's
    # keys; each side is handled on its own, since the left and right
    # contexts may differ in length.
    _chunk_left = tuple(w if w in vocab else "<UNK>" for w in _chunk_left)
    _chunk_right = tuple(w if w in vocab else "<UNK>" for w in _chunk_right)

    hyps_4 = prob_4gram_x.get(_chunk_left)

    # Backoff lookups in the 3-gram and 2-gram models, currently disabled:
    # if _chunk_left not in prob_3gram_x:
    #     return {}
    # hyps_3 = prob_3gram_x.get(_chunk_left)
    # if _chunk_left not in prob_2gram_x:
    #     return {}
    # hyps_2 = prob_2gram_x.get(_chunk_left)

    if hyps_4 is None:
        return {}

    items = hyps_4.items()
    return OrderedDict(sorted(items, key=lambda t: t[1], reverse=True))
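

# Walk the compressed validation input; each line is a tab-separated record
# whose last two fields hold the text before and after the gap.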
with lzma.open(f'{PREFIX_VALID}/in.tsv.xz', 'r') as train:
    for t_line in train:
        t_line = t_line.decode("utf-8")
        t_line = t_line.rstrip()
        t_line = t_line.lower()
        t_line = t_line.replace("\\\\n", ' ')

        t_line_splitted_by_tab = t_line.split('\t')

        words_before = t_line_splitted_by_tab[-2]
        words_before = re.findall(r'\p{L}+', words_before)
        words_after = t_line_splitted_by_tab[-1]
        words_after = re.findall(r'\p{L}+', words_after)

        # Use at most three words on each side of the gap as context.
        chunk_left = words_before[-3:]
        chunk_right = words_after[0:3]

        probs_ordered = count_probabilities(prob_4gram, prob_3gram, prob_2gram, chunk_left, chunk_right)
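        # Unknown context: fall back to a fixed distribution over frequent words.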
        if len(probs_ordered) == 0:
            print("the:0.1 to:0.1 a:0.1 :0.7")
            continue

        result_string = ''
        counter_ = 0
        p_sum = 0
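        # Emit at most ~30 hypotheses. Each probability is scaled by 0.9 so
        # the held-back mass can go to the final catch-all empty token.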
        for word_, p in probs_ordered.items():
            if counter_ > 30:
                break
            re_ = re.search(r'\p{L}+', word_)
            if re_:
                word_cleared = re_.group(0)
                p = p * 0.9
                p_sum += p
                result_string += f"{word_cleared}:{p} "
            else:
                if result_string == '':
                    result_string = "the:0.5 a:0.3 "
                continue
            counter_ += 1

        res = 1 - p_sum
        result_string += f':{res}'
        print(result_string)