434749
This commit is contained in:
parent
7f1d42626a
commit
e0e7302e8d
45
Main.py
45
Main.py
@ -1,5 +1,5 @@
|
|||||||
import re
|
import re
|
||||||
import numpy as np
|
import math
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
@ -36,11 +36,12 @@ for w in lexicon_array:
|
|||||||
|
|
||||||
trigrams = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
|
trigrams = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
|
||||||
bigrams = defaultdict(lambda: defaultdict(int))
|
bigrams = defaultdict(lambda: defaultdict(int))
|
||||||
|
reverse_bigrams = defaultdict(lambda: defaultdict(int))
|
||||||
|
|
||||||
with open('train/in.tsv') as f, open('train/expected.tsv') as e:
|
with open('train/in.tsv') as f, open('train/expected.tsv') as e:
|
||||||
for line_no, (line, expected) in tqdm(enumerate(zip(f, e)), total=432022):
|
for line_no, (line, expected) in tqdm(enumerate(zip(f, e)), total=432022):
|
||||||
if line_no == 4000:
|
# if line_no == 4000:
|
||||||
break
|
# break
|
||||||
line = line.split('\t')
|
line = line.split('\t')
|
||||||
expected = ALPH.sub('', expected.lower())
|
expected = ALPH.sub('', expected.lower())
|
||||||
l_ctx = preprocess(line[6])
|
l_ctx = preprocess(line[6])
|
||||||
@ -57,9 +58,11 @@ with open('train/in.tsv') as f, open('train/expected.tsv') as e:
|
|||||||
for next in sentence[2:]:
|
for next in sentence[2:]:
|
||||||
trigrams[prev_prev][next][prev] += 1
|
trigrams[prev_prev][next][prev] += 1
|
||||||
bigrams[prev_prev][prev] += 1
|
bigrams[prev_prev][prev] += 1
|
||||||
|
reverse_bigrams[prev][prev_prev] += 1
|
||||||
prev_prev = prev
|
prev_prev = prev
|
||||||
prev = next
|
prev = next
|
||||||
bigrams[prev_prev][prev] += 1
|
bigrams[prev_prev][prev] += 1
|
||||||
|
reverse_bigrams[prev][prev_prev] += 1
|
||||||
|
|
||||||
|
|
||||||
def max_val(d):
|
def max_val(d):
|
||||||
@ -72,9 +75,28 @@ def max_val(d):
|
|||||||
return max_key
|
return max_key
|
||||||
|
|
||||||
|
|
||||||
|
def words_and_probs(d):
|
||||||
|
alpha = 0.01
|
||||||
|
k = 10
|
||||||
|
s = sum(d.values())
|
||||||
|
denominator = s + len(lexicon_array) * alpha
|
||||||
|
denominator_log = math.log10(denominator)
|
||||||
|
items = list(d.items())
|
||||||
|
items.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
top_k = items[:k]
|
||||||
|
top_k_probs = [(key, math.log10(elem + alpha) - denominator_log) for key, elem in top_k]
|
||||||
|
strings = [lexicon_array[key] + ":" + str(prob) for key, prob in top_k_probs]
|
||||||
|
sum_top_k = sum(map(lambda x: x[1], top_k))
|
||||||
|
smoothed_sum_top_k = sum_top_k + k * alpha
|
||||||
|
remaining = denominator - smoothed_sum_top_k
|
||||||
|
remaining_log = math.log10(remaining) - denominator_log
|
||||||
|
return " ".join(strings) + " :" + str(remaining_log)
|
||||||
|
|
||||||
|
|
||||||
def infer(d):
|
def infer(d):
|
||||||
|
empty = 0
|
||||||
with open(d + '/in.tsv') as f, open(d + '/out.tsv', "w+") as o:
|
with open(d + '/in.tsv') as f, open(d + '/out.tsv', "w+") as o:
|
||||||
for line in f:
|
for line in tqdm(f, desc=d):
|
||||||
line = line.split('\t')
|
line = line.split('\t')
|
||||||
l_ctx = preprocess(line[6])
|
l_ctx = preprocess(line[6])
|
||||||
r_ctx = preprocess(line[7])
|
r_ctx = preprocess(line[7])
|
||||||
@ -87,17 +109,20 @@ def infer(d):
|
|||||||
if next_i is not None:
|
if next_i is not None:
|
||||||
options = trigrams[prev_prev_i][next_i]
|
options = trigrams[prev_prev_i][next_i]
|
||||||
if len(options) > 0:
|
if len(options) > 0:
|
||||||
prev_i = max_val(options)
|
print(words_and_probs(options), file=o)
|
||||||
prev = lexicon_array[prev_i]
|
|
||||||
print(prev, file=o)
|
|
||||||
continue
|
continue
|
||||||
options = bigrams[prev_prev_i]
|
options = bigrams[prev_prev_i]
|
||||||
if len(options) > 0:
|
if len(options) > 0:
|
||||||
prev_i = max_val(options)
|
print(words_and_probs(options), file=o)
|
||||||
prev = lexicon_array[prev_i]
|
continue
|
||||||
print(prev, file=o)
|
if next_i is not None:
|
||||||
|
options = reverse_bigrams[next_i]
|
||||||
|
if len(options) > 0:
|
||||||
|
print(words_and_probs(options), file=o)
|
||||||
continue
|
continue
|
||||||
print("", file=o)
|
print("", file=o)
|
||||||
|
empty += 1
|
||||||
|
print("empty=", empty)
|
||||||
|
|
||||||
|
|
||||||
infer('dev-0')
|
infer('dev-0')
|
||||||
|
21038
dev-0/out.tsv
21038
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
14828
test-A/out.tsv
14828
test-A/out.tsv
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user