challenging-america-word-ga.../Main.py
2022-04-02 16:10:21 +02:00

105 lines
3.1 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import re
import numpy as np
from tqdm import tqdm
from collections import defaultdict
# Any character outside a-z; used to reduce the expected gap word to letters.
ALPH = re.compile('[^a-z]')
# Runs of literal "\n" escape sequences (one or more backslashes followed by
# "n"), punctuation, digits and whitespace; each run collapses to one space.
# Both typographic double quotes (“ and ”) are included so that a word like
# “the is reduced to "the" and can be found in the lexicon.
REPLACE_WITH_SPACE = re.compile(r"(\\+n|[{}\[\]“”&:•¦()*0-9;\"«»$\-><^,®¬¿?¡!#+. \t\n])+")
# Possessive "'s", and hyphenated line breaks: a hyphen or a soft hyphen
# (U+00AD, written as an escape because it is invisible) followed by a
# literal backslash-n sequence.
REMOVE = re.compile(r"'s|[\-\u00ad]\\n")


def preprocess(l):
    """Normalise a raw context string into lowercase space-separated words.

    Steps: lowercase; turn typographic apostrophes (U+2019) into ASCII ones;
    drop possessive "'s" and hyphenated line breaks; collapse runs of
    punctuation, digits and whitespace into single spaces; expand the
    contractions "i'm", "won't", "n't" and "'ll"; drop any remaining
    apostrophes; trim both ends.

    :param l: raw text, e.g. one context column of a TSV row.
    :return: the normalised string; empty if nothing word-like remains.
    """
    l = l.lower()
    # Written as an escape because U+2019 is visually confusable with "'".
    # An empty search string here would instead insert an apostrophe between
    # every character, and the "'s" rule below would then delete every "s".
    l = l.replace("\u2019", "'")
    l = REMOVE.sub('', l)
    l = REPLACE_WITH_SPACE.sub(" ", l)
    # "won't" must be expanded before the generic "n't" rule.
    l = l.replace("i'm", "i am")
    l = l.replace("won't", "will not")
    l = l.replace("n't", " not")
    l = l.replace("'ll", " will")
    l = l.replace("'", "")
    l = l.strip()
    return l
def words(l):
    """Return the whitespace-delimited tokens of *l* as a list.

    Consecutive whitespace counts as a single separator and leading/trailing
    whitespace is ignored, so an empty or blank string gives ``[]``.
    """
    return l.split()
# Vocabulary loaded from words_alpha.txt (one word per line):
#   lexicon_array: id -> word
#   lexicon:       word -> id
# Blank lines are skipped so the empty string never becomes a "word".
# The reverse map comes from enumerate(), so lexicon_array[lexicon[w]] == w
# holds even if the file contains duplicates; the later position wins.
with open('words_alpha.txt', encoding='utf-8') as f:
    lexicon_array = [w for w in (raw.strip() for raw in f) if w]
lexicon = {w: i for i, w in enumerate(lexicon_array)}
# Gap-filling n-gram counts over lexicon ids:
#   trigrams[left][right][middle] -> times `middle` appeared between `left`
#                                    and `right`
#   bigrams[left][following]      -> times `following` appeared right after
#                                    `left`
trigrams = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
bigrams = defaultdict(lambda: defaultdict(int))

# Number of rows in train/in.tsv (only used for the progress bar total).
TRAIN_SIZE = 432022
# Only this many leading training rows are used to build the counts.
MAX_TRAIN_LINES = 4000


def _vocab_runs(tokens):
    """Yield maximal runs of consecutive in-lexicon tokens as lists of ids.

    An out-of-vocabulary token ends the current run, so two words that were
    not actually adjacent in the text are never counted as neighbours.
    """
    run = []
    for token in tokens:
        token_id = lexicon.get(token)
        if token_id is None:
            if run:
                yield run
            run = []
        else:
            run.append(token_id)
    if run:
        yield run


def _count_ngrams(ids):
    """Add every adjacent pair and (left, right) -> middle triple of *ids* to the counts."""
    for left, following in zip(ids, ids[1:]):
        bigrams[left][following] += 1
    for left, middle, right in zip(ids, ids[1:], ids[2:]):
        trigrams[left][right][middle] += 1


with open('train/in.tsv', encoding='utf-8') as f, \
        open('train/expected.tsv', encoding='utf-8') as e:
    # range() comes first in zip so that no extra line is read once the
    # limit is reached.
    rows = zip(range(MAX_TRAIN_LINES), f, e)
    for _, line, expected in tqdm(rows, total=min(TRAIN_SIZE, MAX_TRAIN_LINES)):
        fields = line.split('\t')
        # Columns 6 and 7 hold the left and right context of the gap.
        gap_word = ALPH.sub('', expected.lower())
        tokens = words(preprocess(fields[6])) + [gap_word] + words(preprocess(fields[7]))
        for run in _vocab_runs(tokens):
            _count_ngrams(run)
def max_val(d):
    """Return the key of *d* with the largest strictly positive value.

    On ties the key encountered first in iteration order wins. Returns
    ``None`` when *d* is empty or has no positive value.
    """
    positive = [item for item in d.items() if item[1] > 0]
    if not positive:
        return None
    best_key, _ = max(positive, key=lambda item: item[1])
    return best_key
def _predict_gap(l_ctx, r_ctx):
    """Return the predicted gap word for preprocessed contexts, or '' if none."""
    left_words = l_ctx.rsplit(None, 1)
    if not left_words:
        return ''
    left_i = lexicon.get(left_words[-1])
    if left_i is None:
        return ''
    right_words = r_ctx.split(None, 1)
    if right_words:
        right_i = lexicon.get(right_words[0])
        if right_i is not None:
            # .get() instead of [] so lookups do not insert empty entries
            # into the defaultdict-based model during inference.
            options = trigrams.get(left_i, {}).get(right_i)
            if options:
                return lexicon_array[max_val(options)]
    # The bigram fallback only needs the left word.
    options = bigrams.get(left_i)
    if options:
        return lexicon_array[max_val(options)]
    return ''


def infer(d):
    """Write one predicted gap word per row of ``<d>/in.tsv`` to ``<d>/out.tsv``.

    The last word of the left context (column 6) and the first word of the
    right context (column 7) are looked up in the lexicon. The output word is
    chosen in this order:
      1. the most frequent trigram middle word for that (left, right) pair;
      2. otherwise, the most frequent word that followed the left word;
      3. otherwise, an empty line.

    :param d: dataset directory, e.g. ``'dev-0'`` or ``'test-A'``.
    """
    with open(d + '/in.tsv', encoding='utf-8') as f, \
            open(d + '/out.tsv', 'w', encoding='utf-8') as o:
        for line in f:
            fields = line.split('\t')
            print(_predict_gap(preprocess(fields[6]), preprocess(fields[7])), file=o)
# Produce predictions for the dev split first, then for the test split.
for split_dir in ('dev-0', 'test-A'):
    infer(split_dir)