challenging-america-word-ga.../Main.py

130 lines
4.1 KiB
Python
Raw Permalink Normal View History

2022-04-02 16:09:25 +02:00
import re
2022-04-04 18:41:05 +02:00
import math
2022-04-02 16:09:25 +02:00
from tqdm import tqdm
from collections import defaultdict
# Matches any character that is not a lowercase ASCII letter
# (used to strip the expected gap word down to plain letters).
ALPH = re.compile('[^a-z]')
# Collapses runs of punctuation, digits, brackets, quotes and literal
# "\n" escape sequences into a single space.
REPLACE_WITH_SPACE = re.compile(r"(\\+n|[{}\[\]”&:•¦()*0-9;\"«»$\-><^,®¬¿?¡!#+. \t\n])+")
# Removes possessive "'s" and hyphenated line breaks ("-\n"; the character
# class holds both an ASCII hyphen and a soft hyphen U+00AD).
REMOVE = re.compile(r"'s|[\-­]\\n")
def preprocess(l):
    """Lowercase and clean one raw text fragment.

    Strips punctuation/digits via the module regexes, normalizes
    apostrophes, and expands the most common English contractions so the
    tokens match the lexicon. Returns the cleaned, stripped string.
    """
    l = l.lower()
    # Normalize the typographic apostrophe (U+2019) to ASCII "'" so the
    # contraction rules below can match.
    # BUGFIX: the original called l.replace("", "'"), which inserts an
    # apostrophe between every character (the U+2019 literal was lost in
    # an encoding mangle) and made every contraction rule a no-op.
    l = l.replace("\u2019", "'")
    l = REMOVE.sub('', l)
    l = REPLACE_WITH_SPACE.sub(" ", l)
    l = l.replace("i'm", "i am")
    l = l.replace("won't", "will not")
    l = l.replace("n't", " not")
    l = l.replace("'ll", " will")
    # Drop any apostrophes that survived the rules above.
    l = l.replace("'", "")
    l = l.strip()
    return l
def words(l):
    """Tokenize a preprocessed line by splitting on whitespace."""
    return l.split()
# Load the fixed vocabulary: one word per line; each word is addressed by
# its line index (lexicon maps word -> index, lexicon_array maps back).
lexicon_array = []
lexicon = {}
with open('words_alpha.txt') as f:
    lexicon_array = [word.strip() for word in f]
lexicon = {w: i for i, w in enumerate(lexicon_array)}
# Nested count tables over lexicon indices, filled by the training pass.
trigrams = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
bigrams = defaultdict(lambda: defaultdict(int))
reverse_bigrams = defaultdict(lambda: defaultdict(int))
2022-04-02 16:09:25 +02:00
# Build the count tables from the training corpus. Each line's gap word
# ("expected") is spliced between its left and right context, every token
# is mapped to a lexicon index (out-of-vocabulary tokens are dropped), and
# each middle word is counted under its surrounding pair:
#   trigrams[left][right][middle], bigrams[left][middle],
#   reverse_bigrams[middle][left].
with open('train/in.tsv') as f, open('train/expected.tsv') as e:
    for line_no, (line, expected) in tqdm(enumerate(zip(f, e)), total=432022):
        line = line.split('\t')
        expected = ALPH.sub('', expected.lower())
        l_ctx = preprocess(line[6])
        r_ctx = preprocess(line[7])
        w_list = words(l_ctx) + [expected] + words(r_ctx)
        sentence = []
        for w in w_list:
            i = lexicon.get(w)
            if i is not None:
                sentence.append(i)
        if len(sentence) >= 3:
            prev_prev = sentence[0]
            prev = sentence[1]
            # `cur` (renamed from `next`, which shadowed the builtin)
            # walks one word ahead of the (prev_prev, prev) window.
            for cur in sentence[2:]:
                trigrams[prev_prev][cur][prev] += 1
                bigrams[prev_prev][prev] += 1
                reverse_bigrams[prev][prev_prev] += 1
                prev_prev = prev
                prev = cur
            # Count the final pair left in the window after the loop.
            bigrams[prev_prev][prev] += 1
            reverse_bigrams[prev][prev_prev] += 1
def max_val(d):
    """Return the key whose value is largest in d.

    Returns None for an empty dict, and also when no value exceeds 0
    (the running maximum starts at 0 and only a strictly greater value
    replaces it, so ties keep the earliest key in iteration order).
    """
    best_key = None
    best_val = 0
    for key, val in d.items():
        if val > best_val:
            best_val, best_key = val, key
    return best_key
2022-04-04 18:41:05 +02:00
def words_and_probs(d):
    """Format the top-10 candidates of a count dict as an output line.

    d maps lexicon index -> count. Each of the k best candidates is
    rendered as "word:log10prob" with add-alpha smoothing over the whole
    lexicon; the trailing ":<log10prob>" carries the probability mass left
    for all unlisted words.
    """
    alpha = 0.01
    k = 10
    total = sum(d.values())
    denominator = total + len(lexicon_array) * alpha
    denominator_log = math.log10(denominator)
    ranked = sorted(d.items(), key=lambda kv: kv[1], reverse=True)
    top_k = ranked[:k]
    pieces = []
    for idx, count in top_k:
        log_prob = math.log10(count + alpha) - denominator_log
        pieces.append(lexicon_array[idx] + ":" + str(log_prob))
    # Smoothed mass already assigned to the k listed candidates.
    covered = sum(count for _, count in top_k)
    smoothed_covered = covered + k * alpha
    remaining_log = math.log10(denominator - smoothed_covered) - denominator_log
    return " ".join(pieces) + " :" + str(remaining_log)
2022-04-02 16:09:25 +02:00
def infer(d):
    """Predict the gap word for every line of d/in.tsv.

    Writes one candidate-distribution line per input row to d/out.tsv,
    falling back in order:
      1. trigram table (both neighbouring words of the gap are in-vocab),
      2. forward bigram table (left neighbour in-vocab),
      3. reverse bigram table (right neighbour in-vocab),
      4. an empty line when no table has counts.
    Prints the number of empty predictions at the end.
    """
    empty = 0  # rows for which no prediction could be made
    with open(d + '/in.tsv') as f, open(d + '/out.tsv', "w+") as o:
        for line in tqdm(f, desc=d):
            line = line.split('\t')
            l_ctx = preprocess(line[6])
            r_ctx = preprocess(line[7])
            if l_ctx != '' and r_ctx != '':
                # Word immediately before the gap and immediately after it
                # (renamed from prev_prev/next; `next` shadowed the builtin).
                left = l_ctx.rsplit(" ", 1)[-1]
                right = r_ctx.split(" ", 1)[0]
                left_i = lexicon.get(left)
                right_i = lexicon.get(right)
                if left_i is not None:
                    if right_i is not None:
                        # Distribution of middle words seen between this pair.
                        options = trigrams[left_i][right_i]
                        if len(options) > 0:
                            print(words_and_probs(options), file=o)
                            continue
                    options = bigrams[left_i]
                    if len(options) > 0:
                        print(words_and_probs(options), file=o)
                        continue
                if right_i is not None:
                    options = reverse_bigrams[right_i]
                    if len(options) > 0:
                        print(words_and_probs(options), file=o)
                        continue
            print("", file=o)
            empty += 1
    print("empty=", empty)
2022-04-02 16:09:25 +02:00
# Run inference for both evaluation splits.
for split in ('dev-0', 'test-A'):
    infer(split)