# challenging-america-word-gaps/predict.py
# Predict gap words from a 4-gram language model for the test-A split.
import lzma
import matplotlib.pyplot as plt
from math import log
from collections import OrderedDict
from collections import Counter
import regex as re
from itertools import islice
import json
import pdb
2023-04-13 21:33:24 +02:00
model_v = "1"
2023-04-12 20:56:08 +02:00
PREFIX_VALID = 'test-A'
2023-04-13 21:33:24 +02:00
prob_4gram = {}
with open(f'4_gram_model_{model_v}.tsv', 'r') as f:
2023-04-12 20:56:08 +02:00
for line in f:
line = line.rstrip()
splitted_line = line.split('\t')
2023-04-13 21:33:24 +02:00
prob_4gram[tuple(splitted_line[:3])] = json.loads(splitted_line[-1])
prob_3gram = {}
# with open(f'3_gram_model_{model_v}.tsv', 'r') as f:
# for line in f:
# line = line.rstrip()
# splitted_line = line.split('\t')
# prob_3gram[tuple(splitted_line[:2])] = json.loads(splitted_line[-1])
prob_2gram = {}
# with open(f'2_gram_model_{model_v}.tsv', 'r') as f:
# for line in f:
# line = line.rstrip()
# splitted_line = line.split('\t')
# prob_2gram[tuple(splitted_line[0])] = json.loads(splitted_line[-1])
2023-04-12 20:56:08 +02:00
vocab = set()
with open(f"vocab_{model_v}.txt", 'r') as f:
for l in f:
vocab.add(l.rstrip())
2023-04-13 21:33:24 +02:00
# probabilities_bi = {}
# with open(f'bigram_big_unk_20', 'r') as f:
# for line in f:
# line = line.rstrip()
# splitted_line = line.split('\t')
# probabilities_bi[tuple(splitted_line[:2])] = (float(splitted_line[2]), float(splitted_line[3]))
def count_probabilities(prob_4gram_x, prob_3gram_x, prob_2gram_x, _chunk_left, _chunk_right):
2023-04-12 20:56:08 +02:00
2023-04-13 21:33:24 +02:00
for index, (l, r) in enumerate(zip(_chunk_left, _chunk_right)):
2023-04-12 20:56:08 +02:00
if l not in vocab:
_chunk_left[index] = "<UNK>"
if r not in vocab:
_chunk_right[index] = "<UNK>"
2023-04-13 21:33:24 +02:00
_chunk_left = tuple(_chunk_left)
_chunk_right = tuple(_chunk_right)
hyps_4 = prob_4gram_x.get(_chunk_left)
# if _chunk_left not in prob_3gram_x:
# return {}
# hyps_3 = prob_3gram_x.get(_chunk_left)
# if _chunk_left not in prob_2gram_x:
# return {}
# hyps_2 = prob_2gram_x.get(_chunk_left)
if hyps_4 is None:
return {}
items = hyps_4.items()
2023-04-12 20:56:08 +02:00
return OrderedDict(sorted(items, key=lambda t:t[1], reverse=True))
with lzma.open(f'{PREFIX_VALID}/in.tsv.xz', 'r') as train:
for t_line in train:
t_line = t_line.decode("utf-8")
t_line = t_line.rstrip()
t_line = t_line.lower()
2023-04-13 21:33:24 +02:00
t_line = t_line.replace("\\\\n", ' ')
2023-04-12 20:56:08 +02:00
t_line_splitted_by_tab = t_line.split('\t')
words_before = t_line_splitted_by_tab[-2]
words_before = re.findall(r'\p{L}+', words_before)
words_after = t_line_splitted_by_tab[-1]
words_after = re.findall(r'\p{L}+', words_after)
chunk_left = words_before[-3:]
chunk_right = words_after[0:3]
2023-04-13 21:33:24 +02:00
probs_ordered = count_probabilities(prob_4gram, prob_3gram, prob_2gram, chunk_left, chunk_right)
2023-04-12 20:56:08 +02:00
# if len(probs_ordered) !=0:
# print(probs_ordered)
if len(probs_ordered) ==0:
print(f"the:0.1 to:0.1 a:0.1 :0.7")
continue
result_string = ''
counter_ = 0
2023-04-13 21:33:24 +02:00
p_sum = 0
2023-04-12 20:56:08 +02:00
for word_, p in probs_ordered.items():
2023-04-13 21:33:24 +02:00
if counter_>30:
2023-04-12 20:56:08 +02:00
break
re_ = re.search(r'\p{L}+', word_)
if re_:
word_cleared = re_.group(0)
2023-04-13 21:33:24 +02:00
p = p*0.9
p_sum += p
2023-04-12 20:56:08 +02:00
result_string += f"{word_cleared}:{str(p)} "
else:
if result_string == '':
result_string = f"the:0.5 a:0.3 "
continue
counter_+=1
2023-04-13 21:33:24 +02:00
res = 1 - p_sum
result_string += f':{res}'
2023-04-12 20:56:08 +02:00
print(result_string)
a=1