This commit is contained in:
Alagris 2022-04-04 18:41:05 +02:00
parent 7f1d42626a
commit e0e7302e8d
3 changed files with 17968 additions and 17943 deletions

45
Main.py
View File

@ -1,5 +1,5 @@
import re import re
import numpy as np import math
from tqdm import tqdm from tqdm import tqdm
from collections import defaultdict from collections import defaultdict
@ -36,11 +36,12 @@ for w in lexicon_array:
trigrams = defaultdict(lambda: defaultdict(lambda: defaultdict(int))) trigrams = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
bigrams = defaultdict(lambda: defaultdict(int)) bigrams = defaultdict(lambda: defaultdict(int))
reverse_bigrams = defaultdict(lambda: defaultdict(int))
with open('train/in.tsv') as f, open('train/expected.tsv') as e: with open('train/in.tsv') as f, open('train/expected.tsv') as e:
for line_no, (line, expected) in tqdm(enumerate(zip(f, e)), total=432022): for line_no, (line, expected) in tqdm(enumerate(zip(f, e)), total=432022):
if line_no == 4000: # if line_no == 4000:
break # break
line = line.split('\t') line = line.split('\t')
expected = ALPH.sub('', expected.lower()) expected = ALPH.sub('', expected.lower())
l_ctx = preprocess(line[6]) l_ctx = preprocess(line[6])
@ -57,9 +58,11 @@ with open('train/in.tsv') as f, open('train/expected.tsv') as e:
for next in sentence[2:]: for next in sentence[2:]:
trigrams[prev_prev][next][prev] += 1 trigrams[prev_prev][next][prev] += 1
bigrams[prev_prev][prev] += 1 bigrams[prev_prev][prev] += 1
reverse_bigrams[prev][prev_prev] += 1
prev_prev = prev prev_prev = prev
prev = next prev = next
bigrams[prev_prev][prev] += 1 bigrams[prev_prev][prev] += 1
reverse_bigrams[prev][prev_prev] += 1
def max_val(d): def max_val(d):
@ -72,9 +75,28 @@ def max_val(d):
return max_key return max_key
def words_and_probs(d, alpha=0.01, k=10, lexicon=None):
    """Format the top-k candidates of a count table as smoothed log10 probabilities.

    Counts are smoothed additively: every word of the lexicon receives
    ``alpha`` pseudo-counts, so a word with count ``c`` has probability
    ``(c + alpha) / (sum(d.values()) + len(lexicon) * alpha)``.

    Args:
        d: Mapping from lexicon index to observed count.
        alpha: Additive smoothing constant applied to every lexicon entry.
        k: Maximum number of candidates to list explicitly.
        lexicon: Sequence mapping an index to its word. Defaults to the
            module-level ``lexicon_array``.

    Returns:
        A string of space-separated ``word:log10prob`` entries, ordered by
        descending count, followed by `` :log10prob`` carrying the probability
        mass of all words that were not listed. The trailing entry is omitted
        when no mass is left over.
    """
    if lexicon is None:
        # Fall back to the vocabulary shared by the rest of the script.
        lexicon = lexicon_array
    vocab_size = len(lexicon)

    total = sum(d.values())
    denominator = total + vocab_size * alpha
    denominator_log = math.log10(denominator)

    # A stable descending sort keeps ties in the mapping's insertion order.
    top_k = sorted(d.items(), key=lambda item: item[1], reverse=True)[:k]
    strings = [
        lexicon[key] + ":" + str(math.log10(count + alpha) - denominator_log)
        for key, count in top_k
    ]

    # Mass left for unlisted words: their raw counts plus one alpha for each
    # lexicon entry that is not listed. The number of listed entries is
    # len(top_k), which is smaller than k whenever d has fewer than k keys;
    # charging k * alpha in that case would lose (or even negate) the
    # remainder.
    sum_top_k = sum(count for _, count in top_k)
    remaining = (total - sum_top_k) + (vocab_size - len(top_k)) * alpha
    if remaining <= 0:
        # Every lexicon entry is already listed; log10 is undefined here.
        return " ".join(strings)
    remaining_log = math.log10(remaining) - denominator_log
    return " ".join(strings) + " :" + str(remaining_log)
def infer(d): def infer(d):
empty = 0
with open(d + '/in.tsv') as f, open(d + '/out.tsv', "w+") as o: with open(d + '/in.tsv') as f, open(d + '/out.tsv', "w+") as o:
for line in f: for line in tqdm(f, desc=d):
line = line.split('\t') line = line.split('\t')
l_ctx = preprocess(line[6]) l_ctx = preprocess(line[6])
r_ctx = preprocess(line[7]) r_ctx = preprocess(line[7])
@ -87,17 +109,20 @@ def infer(d):
if next_i is not None: if next_i is not None:
options = trigrams[prev_prev_i][next_i] options = trigrams[prev_prev_i][next_i]
if len(options) > 0: if len(options) > 0:
prev_i = max_val(options) print(words_and_probs(options), file=o)
prev = lexicon_array[prev_i]
print(prev, file=o)
continue continue
options = bigrams[prev_prev_i] options = bigrams[prev_prev_i]
if len(options) > 0: if len(options) > 0:
prev_i = max_val(options) print(words_and_probs(options), file=o)
prev = lexicon_array[prev_i] continue
print(prev, file=o) if next_i is not None:
options = reverse_bigrams[next_i]
if len(options) > 0:
print(words_and_probs(options), file=o)
continue continue
print("", file=o) print("", file=o)
empty += 1
print("empty=", empty)
infer('dev-0') infer('dev-0')

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff