83 lines
2.3 KiB
Python
83 lines
2.3 KiB
Python
from tqdm import tqdm
|
|
from numpy import argmax
|
|
|
|
def preprocess(corpus):
    """Normalise raw text before it is split into tokens.

    The substitutions below are applied in this exact order (order matters,
    e.g. the hyphen + newline rule must run before newlines become spaces):

    1. remove a hyphen immediately followed by a newline character, which
       re-joins a word that was split across a line break;
    2. replace the two-character sequence backslash + ``n`` with a space;
    3. replace real newline characters with a space;
    4. replace every ``.`` with `` EOS`` so a full stop becomes an ``EOS`` token.

    Returns the transformed string.
    """
    substitutions = (
        ('-\n', ''),
        ('\\n', ' '),
        ('\n', ' '),
        ('.', ' EOS'),
    )
    text = corpus
    for old, new in substitutions:
        text = text.replace(old, new)
    return text
|
|
|
|
def generate_freq(tokens):
    """Count the occurrences of each item in *tokens*.

    Items must be hashable. A progress bar is shown while iterating.

    Returns a dict mapping each distinct item to its count. Keys appear in
    order of first occurrence.
    """
    counts = {}
    for item in tqdm(tokens):
        counts[item] = counts.get(item, 0) + 1
    return counts
|
|
|
|
def generate_ngrams(tokens, n):
    """Return all contiguous n-grams of *tokens*, in order, as tuples.

    Each n-gram is a tuple rather than a list slice. Tuples are hashable, so
    the n-grams can be counted with ``generate_freq`` or used as dict keys
    (list slices raise ``TypeError: unhashable type``).

    Args:
        tokens: sequence of tokens (must support ``len()`` and slicing).
        n: n-gram length; must be at least 1.

    Returns:
        A list of ``max(0, len(tokens) - n + 1)`` tuples, each of length ``n``.

    Raises:
        ValueError: if ``n`` is less than 1.
    """
    if n < 1:
        raise ValueError(f'n must be >= 1, got {n}')
    return [tuple(tokens[i:i + n]) for i in tqdm(range(len(tokens) - n + 1))]
|
|
|
|
def generate_distribution(unique_tokens, tokens_freq, bigrams_freq):
    """Build a dense matrix of bigram conditional probabilities.

    With ``vocab = list(unique_tokens)``, the entry ``distribution[i][j]`` is
    the maximum-likelihood estimate

        P(vocab[j] | vocab[i]) = count(vocab[i], vocab[j]) / count(vocab[i])

    Pairs that never occur are 0.0. A whole row is 0.0 when its token has no
    positive count in ``tokens_freq``.

    Args:
        unique_tokens: iterable of distinct tokens. Its iteration order
            defines the row and column order of the result, so pass an ordered
            sequence (e.g. a sorted list) if callers index by position.
        tokens_freq: dict mapping token -> number of occurrences.
        bigrams_freq: dict mapping a (first, second) token pair -> number of
            occurrences. Each key must be a 2-item sequence such as a tuple.

    Returns:
        A list of ``len(vocab)`` rows, each a list of ``len(vocab)`` floats.
        Memory use is quadratic in the vocabulary size.
    """
    vocab = list(unique_tokens)
    n = len(vocab)
    position = {token: i for i, token in enumerate(vocab)}

    # Build a fresh list for each row; ``[[...] * n] * n`` would alias one row n times.
    distribution = [[0.0] * n for _ in range(n)]

    # Only observed bigrams are non-zero, so walk those instead of all n*n pairs.
    for (first, second), count in tqdm(bigrams_freq.items()):
        i = position.get(first)
        j = position.get(second)
        if i is None or j is None:
            continue  # pair involves a token outside the vocabulary
        denominator = tokens_freq.get(first, 0)
        if denominator:
            distribution[i][j] = count / denominator

    return distribution
|
|
|
|
# Read the training text: for every TSV row keep the columns from index 6 on.
# The encoding is explicit so decoding does not depend on the platform locale.
# NOTE(review): assumes the data is UTF-8 — confirm.
with open('train/in.tsv', 'r', encoding='utf-8') as f:
    print('Reading corpus...')
    corpus = []
    for line in tqdm(f):
        ctx = line.split('\t')[6:]
        # Presumably ctx[0] is the text before the gap and ctx[1] the text
        # after it (the dev loop predicts from the last word of ctx[0]).
        # NOTE(review): 'BLANK' is joined without surrounding spaces, so it
        # fuses the last word of ctx[0] and the first word of ctx[1] into a
        # single token; confirm this is intended.
        corpus.append(ctx[0] + 'BLANK' + ctx[1])
|
|
|
|
# Turn the collected training text into a bigram probability matrix.
print('Preprocessing corpus...')
corpus = preprocess(' '.join(corpus))

tokens = corpus.split()

# Vocabulary: a sorted *list* of distinct tokens, not a set of characters.
# The distribution rows/columns and the prediction step look tokens up by
# position (``unique_tokens[...]`` / ``unique_tokens.index(...)``).
unique_tokens = sorted(set(tokens))

print('Generating tokens frequency...')
tokens_freq = generate_freq(tokens)

print('Generating n-grams...')
bigrams = generate_ngrams(tokens, 2)

print('Generating bigrams frequency...')
bigrams_freq = generate_freq(bigrams)

print('Generate distribution...')
distribution = generate_distribution(unique_tokens, tokens_freq, bigrams_freq)
|
|
|
|
|
|
# Predict the gap word for each dev row. The prediction is the most probable
# successor of the last word before the gap, or 'NONE' when the model has
# nothing to offer.
with open('dev-0/in.tsv', 'r', encoding='utf-8') as f:
    print('Generating output...')

    # Position-indexed view of the vocabulary. list() also works if
    # unique_tokens is a set (which cannot be subscripted).
    vocab = list(unique_tokens)
    # Map token -> first position, matching list.index() but O(1) per lookup.
    token_index = {}
    for position, token in enumerate(vocab):
        token_index.setdefault(token, position)

    results = []
    for line in tqdm(f):
        ctx = line.split('\t')[6:]
        # split() with no argument ignores trailing whitespace. split(' ')
        # would return '' as the last element whenever the text ends in a space.
        words = preprocess(ctx[0]).split()
        row_index = token_index.get(words[-1]) if words else None

        blank_word = 'NONE'
        if row_index is not None:
            row = distribution[row_index]
            best = int(argmax(row))
            # An all-zero row means this word was never followed by anything in
            # training. argmax would then arbitrarily return column 0.
            if row[best] > 0:
                blank_word = vocab[best]
        results.append(blank_word)
|
|
|
|
# Write one prediction line per dev row. 'NONE' falls back to a fixed
# distribution over common words; otherwise most of the probability mass
# goes to the predicted word. The trailing ':p' entry holds the rest.
with open('dev-0/out.tsv', 'w', encoding='utf-8') as f:
    print('Writing output...')
    for result in tqdm(results):
        # Each prediction must end with a newline so output rows line up
        # with input rows.
        if result == 'NONE':
            f.write('a:0.6 the:0.2 :0.2\n')
        else:
            f.write(f'{result}:0.9 :0.1\n')