challenging-america-word-ga.../lmn.py

83 lines
2.3 KiB
Python

from tqdm import tqdm
from numpy import argmax
def preprocess(corpus):
    """Normalise raw text so it can be split into whitespace-separated tokens.

    The substitutions are applied in order:

    1. a hyphen followed by a real newline is removed, joining the two
       halves of a word that was broken across lines;
    2. the literal two-character sequence backslash + ``n`` becomes a space;
    3. a real newline becomes a space;
    4. every ``.`` becomes `` EOS``, so each period turns into a separate
       end-of-sentence marker token.

    Args:
        corpus: the text to normalise.

    Returns:
        The normalised text.
    """
    # Order matters: the hyphen/newline pair must be handled before the
    # bare newline is turned into a space.
    substitutions = (
        ('-\n', ''),
        ('\\n', ' '),
        ('\n', ' '),
        ('.', ' EOS'),
    )
    cleaned = corpus
    for old, new in substitutions:
        cleaned = cleaned.replace(old, new)
    return cleaned
def generate_freq(tokens):
    """Count how many times each item of ``tokens`` occurs.

    Args:
        tokens: an iterable of items to count. Items are normally hashable
            (e.g. strings); ``list`` items, such as the slices returned by
            ``generate_ngrams`` for a list of tokens, are also accepted.

    Returns:
        A dict mapping each distinct item to its number of occurrences.
        ``list`` items are stored under the equivalent ``tuple`` key.
    """
    tokens_freq = {}
    for token in tqdm(tokens):
        # Lists are unhashable and cannot be dictionary keys, so count them
        # under the tuple with the same contents instead of raising TypeError.
        key = tuple(token) if isinstance(token, list) else token
        tokens_freq[key] = tokens_freq.get(key, 0) + 1
    return tokens_freq
def generate_ngrams(tokens, n):
    """Return every contiguous window of ``n`` items taken from ``tokens``.

    Args:
        tokens: a sliceable sequence (each window is ``tokens[i:i + n]``, so
            its type is whatever slicing ``tokens`` produces).
        n: the window length.

    Returns:
        A list of the windows in order of their start position. The list is
        empty when ``n`` is greater than ``len(tokens)``.
    """
    window_count = len(tokens) - n + 1
    return [tokens[start:start + n] for start in tqdm(range(window_count))]
def generate_distribution(unique_tokens, tokens_freq, bigrams_freq):
    """Build the next-token probability table of a bigram model.

    ``distribution[i][j]`` is ``count(token_i, token_j) / count(token_i)``,
    where ``token_i`` and ``token_j`` are the i-th and j-th items of
    ``unique_tokens``.

    Args:
        unique_tokens: the vocabulary. Its iteration order defines the row
            and column order of the result, so pass an ordered sequence when
            the indices have to be mapped back to tokens.
        tokens_freq: mapping of token -> number of occurrences.
        bigrams_freq: mapping of ``(first, second)`` tuple -> number of
            occurrences of that pair.

    Returns:
        A list of ``n`` rows, each a list of ``n`` floats, where ``n`` is the
        vocabulary size. Rows of tokens with no recorded occurrences are all
        zeros. Note that the table is dense, so memory grows with ``n * n``.
    """
    vocab = list(unique_tokens)
    n = len(vocab)
    # One independent row per token; ``[row] * n`` would alias a single list.
    distribution = [[0.0] * n for _ in range(n)]
    for i in tqdm(range(n)):
        first = vocab[i]
        denominator = tokens_freq.get(first, 0)
        if not denominator:
            # Unknown token: leave the row at zero instead of dividing by 0.
            continue
        row = distribution[i]
        for j in range(n):
            # A tuple key keeps the two tokens distinct; concatenating the
            # strings would make "ab"+"c" and "a"+"bc" collide.
            numerator = bigrams_freq.get((first, vocab[j]), 0)
            row[j] = numerator / denominator
    return distribution
# --- Training: read the text around each gap and fit a bigram model ---------
with open('train/in.tsv', 'r', encoding='utf-8') as f:
    print('Reading corpus...')
    corpus = []
    for line in tqdm(f):
        # Columns 7 and 8 are used as the text before and after the gap.
        ctx = line.split('\t')[6:]
        # Spaces keep the gap marker a separate token instead of fusing it
        # with the last word before and the first word after the gap.
        corpus.append(ctx[0] + ' BLANK ' + ctx[1])

print('Preprocessing corpus...')
corpus = preprocess(' '.join(corpus))
tokens = corpus.split()
# The vocabulary is the distinct *tokens* (not the characters of the corpus),
# kept as a sorted list so that matrix indices map back to tokens.
unique_tokens = sorted(set(tokens))
# Token -> row index, so each lookup is O(1) instead of a list scan.
token_index = {token: i for i, token in enumerate(unique_tokens)}

print('Generating tokens frequency...')
tokens_freq = generate_freq(tokens)

print('Generating n-grams...')
# Slices of a list are lists, which cannot be dictionary keys; tuples can.
bigrams = [tuple(bigram) for bigram in generate_ngrams(tokens, 2)]

print('Generating bigrams frequency...')
bigrams_freq = generate_freq(bigrams)

print('Generate distribution...')
distribution = generate_distribution(unique_tokens, tokens_freq, bigrams_freq)

# --- Prediction: most likely successor of the last word before the gap -----
with open('dev-0/in.tsv', 'r', encoding='utf-8') as f:
    print('Generating output...')
    results = []
    for line in tqdm(f):
        ctx = line.split('\t')[6:]
        # split() without arguments ignores trailing whitespace, so the last
        # item is a real word (or the list is empty for an empty context).
        left_words = preprocess(ctx[0]).split()
        row_index = token_index.get(left_words[-1]) if left_words else None
        if row_index is None:
            # No left context, or a word never seen in training.
            blank_word = 'NONE'
        else:
            blank_word = unique_tokens[argmax(distribution[row_index])]
        results.append(blank_word)

with open('dev-0/out.tsv', 'w', encoding='utf-8') as f:
    print('Writing output...')
    for result in tqdm(results):
        # One prediction per line, in the same order as the input rows.
        if result == 'NONE':
            f.write('a:0.6 the:0.2 :0.2\n')
        else:
            f.write(f'{result}:0.9 :0.1\n')