challenging-america-word-ga.../predict.py
Mikołaj Pokrywka 65d889d652 tetragram
2023-04-12 20:56:08 +02:00

101 lines
2.9 KiB
Python

import lzma
import matplotlib.pyplot as plt
from math import log
from collections import OrderedDict
from collections import Counter
import regex as re
from itertools import islice
import json
import pdb
# Which trained model/vocabulary pair to load: model_<v>.tsv and vocab_<v>.txt.
model_v = "4000"
# Directory holding the in.tsv.xz file to predict for.
PREFIX_VALID = 'test-A'

# Tetragram model. Each TSV row is four words followed by two float columns;
# key = the 4-word tuple, value = (first float, second float). Only the first
# float is used by count_probabilities for left-context scoring.
# Explicit UTF-8 so loading does not depend on the platform's locale encoding.
probabilities = {}
with open(f'model_{model_v}.tsv', 'r', encoding='utf-8') as f:
    for line in f:
        line = line.rstrip()
        if not line:
            # Tolerate blank lines (e.g. an extra newline at EOF) instead of
            # crashing on the missing columns; malformed rows still raise.
            continue
        fields = line.split('\t')
        probabilities[tuple(fields[:4])] = (float(fields[4]), float(fields[5]))

# Known words, one per line; context words outside this set become "<UNK>".
vocab = set()
with open(f"vocab_{model_v}.txt", 'r', encoding='utf-8') as f:
    for word in f:
        vocab.add(word.rstrip())
def count_probabilities(_probabilities, _chunk_left, _chunk_right,
                        vocabulary=None, left_weight=0.7):
    """Rank candidate gap words for a left context using the tetragram model.

    Out-of-vocabulary words of the left context are mapped to "<UNK>" (on a
    private copy; the caller's lists are never modified). Then every tetragram
    whose first three words equal that context contributes its last word as a
    candidate, scored as its first stored float times ``left_weight``.

    Args:
        _probabilities: mapping from 4-word tuples to a pair of floats; only
            the first float of each pair is used here.
        _chunk_left: words immediately preceding the gap. Anything other than
            exactly three words cannot match a 4-word key, so the result is
            empty.
        _chunk_right: words following the gap. Currently unused for scoring
            (the right-context combination is disabled); kept for interface
            compatibility.
        vocabulary: set of known words; defaults to the module-level ``vocab``.
        left_weight: multiplier applied to each matching tetragram's first
            float (0.7 by default, as before).

    Returns:
        OrderedDict of candidate word -> score, sorted by score descending
        (ties keep model order). "<UNK>" is never proposed.
    """
    if vocabulary is None:
        vocabulary = vocab
    # Map each left word independently into a new tuple. Pairing it with the
    # right context via zip() would stop at the shorter list and leave later
    # OOV words unmapped.
    left_key = tuple(w if w in vocabulary else "<UNK>" for w in _chunk_left)

    best_ = {}
    # NOTE(review): this is a linear scan of the whole model on every call.
    # Indexing the model once by its 3-word prefix would make it a lookup.
    for tetragram, probses in _probabilities.items():
        candidate = tetragram[-1]
        # "<UNK>" is not a usable answer. Skip the entry, but keep scanning:
        # stopping here would ignore every later entry in the model.
        if candidate == "<UNK>":
            continue
        if tetragram[:3] == left_key and candidate not in best_:
            best_[candidate] = probses[0] * left_weight
    return OrderedDict(sorted(best_.items(), key=lambda t: t[1], reverse=True))
def _format_prediction(probs_ordered, max_words=5):
    """Render ranked candidates as one "word:score ... :0.2" output line.

    At most ``max_words`` candidates are emitted, in the given order. Each
    candidate is reduced to its first run of letters. A candidate containing
    no letters is skipped; if no entry has been emitted yet, the fixed
    fallback "the:0.5 a:0.3 " is emitted in its place. The line always ends
    with the catch-all entry ":0.2".
    """
    parts = []
    emitted = 0
    for word_, p in probs_ordered.items():
        if emitted >= max_words:
            break
        match = re.search(r'\p{L}+', word_)
        if match is None:
            if not parts:
                parts.append("the:0.5 a:0.3 ")
            continue
        parts.append(f"{match.group(0)}:{p} ")
        emitted += 1
    parts.append(':0.2')
    return ''.join(parts)


# Predict the gap word for every input line and print one prediction line
# per input line. Stdout is presumably redirected to the output file; confirm
# against the caller.
with lzma.open(f'{PREFIX_VALID}/in.tsv.xz', 'r') as train:
    for raw_line in train:
        # Strip only the line terminator. A bare rstrip() would also eat a
        # trailing tab, dropping an empty right-context column and shifting
        # which fields [-2] and [-1] refer to.
        t_line = raw_line.decode("utf-8").rstrip('\r\n').lower()
        fields = t_line.split('\t')
        # The last two columns are the text before and after the gap.
        words_before = re.findall(r'\p{L}+', fields[-2])
        words_after = re.findall(r'\p{L}+', fields[-1])
        chunk_left = words_before[-3:]
        chunk_right = words_after[:3]
        probs_ordered = count_probabilities(probabilities, chunk_left, chunk_right)
        if not probs_ordered:
            # No tetragram matched this context: fall back to frequent words.
            print("the:0.1 to:0.1 a:0.1 :0.7")
            continue
        print(_format_prediction(probs_ordered))