"""Predict the word in a text gap with a 4-gram language model.

Reads a pre-trained 4-gram model (TSV: three context words + JSON dict of
next-word probabilities) and a vocabulary file, then for each line of the
xz-compressed test input prints a ``word:prob ... :rest`` distribution for
the gap between the last two tab-separated fields.
"""
import lzma
import matplotlib.pyplot as plt
from math import log
from collections import OrderedDict
from collections import Counter
import regex as re
from itertools import islice
import json
import pdb

# Model version suffix used in the model/vocab file names.
model_v = "1"
# Directory of the evaluation split to score.
PREFIX_VALID = 'test-A'

# 4-gram model: (w1, w2, w3) left-context tuple -> {next_word: probability}.
prob_4gram = {}
with open(f'4_gram_model_{model_v}.tsv', 'r') as f:
    for line in f:
        line = line.rstrip()
        splitted_line = line.split('\t')
        prob_4gram[tuple(splitted_line[:3])] = json.loads(splitted_line[-1])

# Lower-order models are currently disabled — no backoff is performed.
prob_3gram = {}
prob_2gram = {}

# Vocabulary: one token per line; context words outside it are masked to "".
vocab = set()
with open(f"vocab_{model_v}.txt", 'r') as f:
    for l in f:
        vocab.add(l.rstrip())


def count_probabilities(prob_4gram_x, prob_3gram_x, prob_2gram_x,
                        _chunk_left, _chunk_right):
    """Return next-word hypotheses for the left context, best first.

    OOV words in either context are replaced by "" before lookup.

    Args:
        prob_4gram_x: mapping of 3-word context tuples to {word: prob}.
        prob_3gram_x / prob_2gram_x: accepted for interface compatibility
            but unused (backoff is disabled).
        _chunk_left: up to 3 words immediately before the gap.
        _chunk_right: up to 3 words immediately after the gap (currently
            unused in the lookup).

    Returns:
        OrderedDict {word: prob} sorted by descending probability, or an
        empty dict when the masked left context is unknown to the model.
    """
    # BUG FIX: the original zipped left/right chunks together while masking,
    # so when one chunk was shorter (gap near the start/end of a text) the
    # tail of the longer chunk was never OOV-masked. Mask each side
    # independently, and build new tuples instead of mutating the arguments.
    masked_left = tuple(w if w in vocab else "" for w in _chunk_left)
    masked_right = tuple(w if w in vocab else "" for w in _chunk_right)  # kept for symmetry; not used in lookup

    hyps_4 = prob_4gram_x.get(masked_left)
    if hyps_4 is None:
        return {}
    return OrderedDict(sorted(hyps_4.items(), key=lambda t: t[1], reverse=True))


# Score every test line and print one "word:prob ... :rest" row per line.
with lzma.open(f'{PREFIX_VALID}/in.tsv.xz', 'r') as train:
    for t_line in train:
        t_line = t_line.decode("utf-8").rstrip().lower()
        t_line = t_line.replace("\\\\n", ' ')
        t_line_splitted_by_tab = t_line.split('\t')
        # The last two tab-separated fields hold the text before/after the gap.
        words_before = re.findall(r'\p{L}+', t_line_splitted_by_tab[-2])
        words_after = re.findall(r'\p{L}+', t_line_splitted_by_tab[-1])
        chunk_left = words_before[-3:]
        chunk_right = words_after[0:3]

        probs_ordered = count_probabilities(
            prob_4gram, prob_3gram, prob_2gram, chunk_left, chunk_right)

        if len(probs_ordered) == 0:
            # Unknown context: emit a fixed fallback distribution.
            print("the:0.1 to:0.1 a:0.1 :0.7")
            continue

        result_string = ''
        counter_ = 0
        p_sum = 0
        for word_, p in probs_ordered.items():
            if counter_ > 30:  # cap the hypothesis list at 31 entries
                break
            re_ = re.search(r'\p{L}+', word_)
            if re_:
                word_cleared = re_.group(0)
                # Reserve 10% of each hypothesis' mass for the unknown bucket.
                p = p * 0.9
                p_sum += p
                result_string += f"{word_cleared}:{p} "
            else:
                if result_string == '':
                    result_string = "the:0.5 a:0.3 "
                    # BUG FIX: account for the fallback's 0.8 mass so that the
                    # trailing ":<rest>" keeps the distribution summing to 1
                    # (the original printed "the:0.5 a:0.3 :1", total 1.8).
                    p_sum = 0.8
                continue
            counter_ += 1

        # Whatever mass remains goes to the empty (unknown-word) token.
        res = 1 - p_sum
        result_string += f':{res}'
        print(result_string)

a = 1  # leftover no-op (likely a pdb anchor); kept to preserve module state