"""Trigram/bigram gap-word predictor.

This file is a script: at module level it loads a word list, counts n-grams
over the first rows of ``train/in.tsv`` + ``train/expected.tsv`` and then
writes one predicted word per input row to ``dev-0/out.tsv`` and
``test-A/out.tsv``.
"""
import re
from collections import defaultdict
from itertools import islice

import numpy as np  # not referenced below; kept in case other tooling relies on it
from tqdm import tqdm

# Everything that is not a lowercase ASCII letter.
ALPH = re.compile('[^a-z]')
# Runs of escaped newlines (backslashes followed by "n"), digits, punctuation
# and whitespace; each run is collapsed to a single space by preprocess().
REPLACE_WITH_SPACE = re.compile(r"(\\+n|[{}\[\]”&:•¦()*0-9;\"«»$\-><^,®¬¿?¡!#+. \t\n])+")
# Possessive "'s", and a hyphen or soft hyphen (U+00AD, invisible inside the
# character class) directly followed by an escaped newline.
REMOVE = re.compile(r"'s|[\-­]\\n")

# Number of training rows actually consumed.
MAX_TRAIN_LINES = 4000
# Upper bound used for the progress bar; presumably the row count of
# train/in.tsv -- TODO confirm.
TRAIN_TOTAL_LINES = 432022


def preprocess(l):
    """Normalise a raw context string to space-separated lowercase words.

    Lowercases, maps the typographic apostrophe to ``'``, deletes the
    REMOVE matches, collapses REPLACE_WITH_SPACE runs to one space, expands
    a few English contractions, drops remaining apostrophes and strips
    surrounding whitespace.
    """
    l = l.lower()
    l = l.replace("’", "'")
    l = REMOVE.sub('', l)
    l = REPLACE_WITH_SPACE.sub(" ", l)
    # Order matters: "won't" must be expanded before the generic "n't" rule.
    l = l.replace("i'm", "i am")
    l = l.replace("won't", "will not")
    l = l.replace("n't", " not")
    l = l.replace("'ll", " will")
    l = l.replace("'", "")
    return l.strip()


def words(l):
    """Split a preprocessed string on whitespace and return the word list."""
    return l.split()


def _load_lexicon(path):
    """Read one word per line and return ``(word_list, word -> index)``.

    Blank lines and repeated words are skipped so that
    ``word_list[index[w]] == w`` holds for every word; assigning indices from
    the running dictionary size instead would break that invariant as soon as
    the file contained a duplicate.
    """
    vocabulary = []
    index = {}
    with open(path, encoding='utf-8') as f:
        for raw in f:
            word = raw.strip()
            if word and word not in index:
                index[word] = len(vocabulary)
                vocabulary.append(word)
    return vocabulary, index


lexicon_array, lexicon = _load_lexicon('words_alpha.txt')

# trigrams[left][right][middle] -> count of "left middle right" sequences.
trigrams = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
# bigrams[first][second] -> count of "first second" sequences.
bigrams = defaultdict(lambda: defaultdict(int))


def _count_ngrams(sentence):
    """Add the n-grams of one index sequence to ``trigrams`` and ``bigrams``.

    Sequences shorter than three entries contribute nothing.
    """
    if len(sentence) < 3:
        return
    for left, middle, right in zip(sentence, sentence[1:], sentence[2:]):
        trigrams[left][right][middle] += 1
        bigrams[left][middle] += 1
    # The loop above stops one pair short; count the final adjacent pair too.
    bigrams[sentence[-2]][sentence[-1]] += 1


def _train(in_path, expected_path, max_lines=MAX_TRAIN_LINES):
    """Count n-grams over at most ``max_lines`` aligned rows of the two files.

    Each sentence is built as: words of column 6, the expected word reduced
    to lowercase ASCII letters, words of column 7. Words missing from
    ``lexicon`` are dropped before counting.
    """
    with open(in_path, encoding='utf-8') as f, \
            open(expected_path, encoding='utf-8') as e:
        rows = islice(zip(f, e), max_lines)
        for line, expected in tqdm(rows, total=min(max_lines, TRAIN_TOTAL_LINES)):
            fields = line.split('\t')
            gap_word = ALPH.sub('', expected.lower())
            tokens = (words(preprocess(fields[6]))
                      + [gap_word]
                      + words(preprocess(fields[7])))
            _count_ngrams([lexicon[t] for t in tokens if t in lexicon])


_train('train/in.tsv', 'train/expected.tsv')


def max_val(d):
    """Return the key with the largest strictly positive value in ``d``.

    The first such key in iteration order wins ties. Returns ``None`` when
    ``d`` is empty or holds no value greater than zero.
    """
    best_key = None
    best_count = 0
    for key, count in d.items():
        if count > best_count:
            best_count = count
            best_key = key
    return best_key


def _predict(line):
    """Return the predicted gap word for one TSV row, or ``''`` if none.

    Uses the last word of column 6 and the first word of column 7. The
    trigram table is consulted first; if it has nothing for that pair (or the
    right-hand word is unknown) the bigram table for the left-hand word is
    used instead.
    """
    fields = line.split('\t')
    l_ctx = preprocess(fields[6])
    r_ctx = preprocess(fields[7])
    if l_ctx == '' or r_ctx == '':
        return ''

    left_i = lexicon.get(l_ctx.rsplit(" ", 1)[-1])
    if left_i is None:
        return ''

    right_i = lexicon.get(r_ctx.split(" ", 1)[0])
    if right_i is not None:
        # .get() rather than indexing: indexing a defaultdict would insert an
        # empty entry for every unseen query and grow the model.
        options = trigrams.get(left_i, {}).get(right_i, {})
        if options:
            return lexicon_array[max_val(options)]

    options = bigrams.get(left_i, {})
    if options:
        return lexicon_array[max_val(options)]
    return ''


def infer(d):
    """Write predictions for ``d + '/in.tsv'`` to ``d + '/out.tsv'``.

    Exactly one output line is written per input line; rows without a
    prediction produce an empty line.
    """
    with open(d + '/in.tsv', encoding='utf-8') as f, \
            open(d + '/out.tsv', "w+", encoding='utf-8') as o:
        for line in f:
            print(_predict(line), file=o)


infer('dev-0')
infer('test-A')