challenging-america-word-ga.../Main.py
2022-04-04 18:41:05 +02:00

130 lines
4.1 KiB
Python
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import re
import math
from tqdm import tqdm
from collections import defaultdict
# Any character outside a-z; used to reduce the expected gap word to letters.
ALPH = re.compile(r'[^a-z]')
# Runs of punctuation, digits, whitespace and literal backslash+"n" sequences
# (presumably newlines escaped in the TSV input) collapse to a single space.
REPLACE_WITH_SPACE = re.compile(r"(\\+n|[{}\[\]”&:•¦()*0-9;\"«»$\-><^,®¬¿?¡!#+. \t\n])+")
# Deleted outright: a word-final "'s", and a hyphen or soft hyphen (U+00AD)
# directly followed by a literal backslash+"n", which re-joins words that
# were hyphenated across a line break.
REMOVE = re.compile(r"'s\b|[\-\u00ad]\\n")


def preprocess(l):
    """Normalize one raw context string into space-separated lowercase tokens.

    Steps: lowercase; map typographic single quotes to ASCII apostrophes;
    drop word-final "'s" and line-break hyphenation; collapse punctuation,
    digit and whitespace runs to single spaces; expand a few English
    contractions; remove any remaining apostrophes; strip the ends.

    Args:
        l: raw text, e.g. one context column of a TSV line.

    Returns:
        The normalized string (possibly empty).
    """
    text = l.lower()
    # Typographic apostrophes are not in REPLACE_WITH_SPACE and would
    # otherwise stay glued to tokens; escapes keep them visible in source.
    text = text.replace("\u2019", "'").replace("\u2018", "'")
    text = REMOVE.sub('', text)
    text = REPLACE_WITH_SPACE.sub(" ", text)
    # Order matters: irregular forms must be expanded before generic "n't"
    # (otherwise "can't" would become "ca not").
    text = text.replace("i'm", "i am")
    text = text.replace("won't", "will not")
    text = text.replace("can't", "can not")
    text = text.replace("n't", " not")
    text = text.replace("'ll", " will")
    text = text.replace("'", "")
    return text.strip()
def words(l):
    """Return the whitespace-separated tokens of ``l`` (no empty tokens)."""
    return l.split()
# Vocabulary read from words_alpha.txt, one word per line:
#   lexicon_array[i] -> word, lexicon[word] -> i
with open('words_alpha.txt', encoding='utf-8') as f:
    lexicon_array = [word.strip() for word in f]
lexicon = {}
for i, w in enumerate(lexicon_array):
    # Keep the first index of a repeated word so lexicon_array[lexicon[w]] == w
    # always holds (assigning len(lexicon) on a duplicate would give two
    # different words the same index).
    lexicon.setdefault(w, i)
# Count tables keyed by vocabulary index:
#   trigrams[left][right][middle] -- times `middle` occurred between `left` and `right`
#   bigrams[left][right]          -- times `right` directly followed `left`
#   reverse_bigrams[right][left]  -- the same bigram counts, indexed by the right word
trigrams = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
bigrams = defaultdict(lambda: defaultdict(int))
reverse_bigrams = defaultdict(lambda: defaultdict(int))
# Each training example is: left context + expected gap word + right context,
# reduced to the words found in the lexicon (unknown words are dropped, so the
# words around them are counted as adjacent).
with open('train/in.tsv', encoding='utf-8') as f, \
        open('train/expected.tsv', encoding='utf-8') as e:
    # total= is only a progress-bar hint (presumably the training line count).
    for line, expected in tqdm(zip(f, e), total=432022):
        columns = line.split('\t')
        expected = ALPH.sub('', expected.lower())
        l_ctx = preprocess(columns[6])
        r_ctx = preprocess(columns[7])
        w_list = words(l_ctx) + [expected] + words(r_ctx)
        sentence = [lexicon[w] for w in w_list if w in lexicon]
        # Every sentence with at least two known words contributes bigrams;
        # at least three are needed for a trigram.
        for left, right in zip(sentence, sentence[1:]):
            bigrams[left][right] += 1
            reverse_bigrams[right][left] += 1
        for left, middle, right in zip(sentence, sentence[1:], sentence[2:]):
            trigrams[left][right][middle] += 1
def max_val(d):
    """Return the key of ``d`` with the largest value.

    Only strictly positive values are considered and ties go to the key
    encountered first. Returns ``None`` when ``d`` is empty or holds no
    positive value.
    """
    positive = {key: value for key, value in d.items() if value > 0}
    if not positive:
        return None
    return max(positive, key=positive.get)
def words_and_probs(d, *, alpha=0.01, k=10, vocab=None):
    """Format the top-``k`` candidates of a count table as one prediction line.

    Probabilities use additive smoothing over the whole vocabulary:
    P(w) = (count(w) + alpha) / (sum(counts) + len(vocab) * alpha).
    The line lists the ``k`` most frequent candidates as ``word:log10(P)``,
    then a bare ``:log10(P_rest)`` entry holding the mass of every other
    vocabulary word, so the listed probabilities sum to 1.

    NOTE(review): scores are base-10 log-probabilities; confirm this is the
    format the evaluator expects.

    Args:
        d: mapping from vocabulary index to count.
        alpha: smoothing pseudo-count added to every vocabulary word.
        k: maximum number of explicitly listed candidates.
        vocab: index -> word sequence; defaults to module-level ``lexicon_array``.

    Returns:
        Space-separated ``word:logprob`` entries followed by `` :logprob``.
    """
    if vocab is None:
        vocab = lexicon_array
    denominator = sum(d.values()) + len(vocab) * alpha
    denominator_log = math.log10(denominator)
    # Stable sort: equal counts keep the dict's insertion order.
    top_k = sorted(d.items(), key=lambda item: item[1], reverse=True)[:k]
    strings = [vocab[key] + ":" + str(math.log10(count + alpha) - denominator_log)
               for key, count in top_k]
    # Only len(top_k) candidates (possibly fewer than k) received their own
    # "+ alpha"; everything else is the remaining mass.
    listed_mass = sum(count for _, count in top_k) + len(top_k) * alpha
    remaining_log = math.log10(denominator - listed_mass) - denominator_log
    return " ".join(strings) + " :" + str(remaining_log)
def infer(d):
    """Write one prediction line per example of ``d + '/in.tsv'`` to ``d + '/out.tsv'``.

    The gap is predicted from the last word of the left context and the first
    word of the right context, backing off in this order:
    trigram (left, right) -> bigram after left -> reverse bigram before right.
    Either context may be empty; examples with no usable counts get an empty
    line. The number of empty lines is printed at the end.

    Args:
        d: directory of the data split, e.g. ``'dev-0'``.
    """
    empty = 0
    with open(d + '/in.tsv', encoding='utf-8') as f, \
            open(d + '/out.tsv', 'w', encoding='utf-8') as o:
        for line in tqdm(f, desc=d):
            columns = line.split('\t')
            l_ctx = preprocess(columns[6])
            r_ctx = preprocess(columns[7])
            # Tokenize on any whitespace, like training's words(); an empty
            # context has no neighbouring word to condition on.
            left_i = lexicon.get(l_ctx.rsplit(None, 1)[-1]) if l_ctx else None
            right_i = lexicon.get(r_ctx.split(None, 1)[0]) if r_ctx else None
            # .get() rather than [] so lookups never insert empty entries
            # into the defaultdict count tables.
            options = None
            if left_i is not None and right_i is not None:
                options = trigrams.get(left_i, {}).get(right_i)
            if not options and left_i is not None:
                options = bigrams.get(left_i)
            if not options and right_i is not None:
                options = reverse_bigrams.get(right_i)
            if options:
                print(words_and_probs(options), file=o)
            else:
                print("", file=o)
                empty += 1
    print("empty=", empty)
# Produce predictions for the development split, then the test split.
for split_name in ('dev-0', 'test-A'):
    infer(split_name)