Compare commits
5 Commits
Author | SHA1 | Date | |
---|---|---|---|
39c1f3a341 | |||
bb121718aa | |||
9332c1957b | |||
|
d877969ac2 | ||
|
2a4ab01f29 |
10
README.md
10
README.md
@ -1,9 +1 @@
|
|||||||
Challenging America word-gap prediction
|
# The solution implements all the variants for extra points
|
||||||
===================================
|
|
||||||
|
|
||||||
Guess a word in a gap.
|
|
||||||
|
|
||||||
Evaluation metric
|
|
||||||
-----------------
|
|
||||||
|
|
||||||
LikelihoodHashed is the metric
|
|
21038
dev-0/out.tsv
21038
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
@ -1,11 +0,0 @@
|
|||||||
import sys
|
|
||||||
|
|
||||||
file = sys.argv[1]
|
|
||||||
|
|
||||||
with open(file, encoding='utf-8') as f1, open('out.tsv', 'w', encoding='utf-8') as f2:
|
|
||||||
for line in f1:
|
|
||||||
line = line.split('\t')
|
|
||||||
if line[-1][0].isupper():
|
|
||||||
f2.write('the:0.9 :0.1\n')
|
|
||||||
else:
|
|
||||||
f2.write('the:0.4 a:0.4 :0.2\n')
|
|
9
generations.txt
Normal file
9
generations.txt
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
According to recent news and a half of the north side or the other must be a man of great force of character to the country. He was a member of a family in the United States to the full amount of the principal and interest of this section shall be subject to the
|
||||||
|
|
||||||
|
Recent studies have shown that the present condition of things in which we have been in a very short time after the war and the war was over and that every man who has been in the hands of the United States of the North and the South American States, and those States who had
|
||||||
|
|
||||||
|
Today I was taking a stroll in the park when suddenly and that the said estate has by no Mr. ii him, but he was too young to be a very good reason that the above named de- - . . . They are able six e of a good deal of time and money to the amount of the tax
|
||||||
|
|
||||||
|
The most unbelievable story ever told goes like this to be the most important of these are the only two men who were at the time of the year when they tried an 1 the said sum of money to be paid in case of your South and the West may have to do the work of the committee
|
||||||
|
|
||||||
|
The war between natural and the few who are not in the interest of the said William H. and acres of them, more o the State from the control of the state of New York and New York and New York are the looked for food in the greatest number of the most
|
177
lab5.py
Normal file
177
lab5.py
Normal file
@ -0,0 +1,177 @@
|
|||||||
|
from tqdm import tqdm
|
||||||
|
from collections import Counter
|
||||||
|
import mmap
|
||||||
|
import pickle
|
||||||
|
from math import prod
|
||||||
|
from copy import deepcopy
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
|
class TetragramModel:
    """Interpolated tetragram (4-gram) language model.

    Supports three operations over a tab-separated corpus whose last two
    fields of each line hold text:

    * ``train``        -- build a capped vocabulary and uni/bi/tri/tetragram counts
                          (or load pickled counts from disk),
    * ``predict_gaps`` -- write word-gap probability distributions to ``out.tsv``,
    * ``generate``     -- print a sampled continuation of a prompt with
                          tetragram -> trigram -> bigram back-off.

    Out-of-vocabulary words are mapped to the ``'<UNK>'`` token throughout.
    """

    def __init__(self):
        # Populated by train(): vocabulary (set of kept surface words) and the
        # dict of n-gram Counters keyed 'unigrams'/'bigrams'/'trigrams'/'tetragrams'.
        self.vocab = None
        self.ngram_counts = None
        # Lazily cached sum of all unigram counts; see get_ngram_prob().
        self._unigram_total = None

    def get_num_lines(self, filename):
        """Return the number of lines in *filename* (used for tqdm totals).

        Uses a read-only binary mmap; the original opened the file writable
        ('r+') and never closed the mapping — both file and mmap are now
        released deterministically by the context managers.
        """
        with open(filename, 'rb') as fp, \
                mmap.mmap(fp.fileno(), 0, access=mmap.ACCESS_READ) as buf:
            lines = 0
            while buf.readline():
                lines += 1
            return lines

    def train(self, filename, vocab_size=5000, load_ngrams=None):
        """Build the vocabulary and n-gram counts from *filename*.

        filename: TSV training corpus; text is taken from the last two fields.
        vocab_size: keep at most this many most-frequent words; everything
            else becomes '<UNK>'.
        load_ngrams: optional path to a pickled ngram_counts dict to load
            instead of recounting (the vocabulary is still rebuilt).
        """

        def get_vocab(filename, vocab_size):
            # Frequency-count every word in the two text columns.
            word_freqs = Counter()
            with open(filename, encoding='utf-8') as f:
                for line in tqdm(f, total=self.get_num_lines(filename),
                                 desc='Generating vocab'):
                    # Literal "\n" sequences inside the text mark line breaks
                    # in the corpus; treat them as spaces.
                    words = ' '.join(line.strip().split('\t')[-2:]) \
                        .replace(r'\n', ' ').split()
                    word_freqs.update(words)
            if len(word_freqs) > vocab_size:
                return [word for word, _ in word_freqs.most_common(vocab_size)]
            return list(word_freqs.keys())

        def get_gram_counts(filename):
            gram_names = ('unigrams', 'bigrams', 'trigrams', 'tetragrams')
            counts = {name: Counter() for name in gram_names}
            with open(filename, encoding='utf-8') as f:
                for line in tqdm(f, total=self.get_num_lines(filename),
                                 desc='Generating n-gram counts'):
                    fields = line.strip().replace(r'\n', ' ').split('\t')[-2:]
                    for field in fields:
                        tokens = [w if w in self.vocab else '<UNK>'
                                  for w in field.split()]
                        counts['unigrams'].update(tokens)
                        # Sliding windows of width 2..4 over the token list.
                        for size, name in zip((2, 3, 4), gram_names[1:]):
                            counts[name].update(
                                tuple(tokens[i:i + size])
                                for i in range(len(tokens) - size + 1))
            return counts

        # A set gives O(1) membership tests; the original list/dict_keys made
        # every `word in self.vocab` check O(|V|).
        self.vocab = set(get_vocab(filename, vocab_size))
        if load_ngrams:
            print('Loading n-gram model from file...')
            # NOTE(review): pickle.load is unsafe on untrusted files — only
            # load model files you produced yourself.
            with open(load_ngrams, 'rb') as f:
                self.ngram_counts = pickle.load(f)
            print('Model loaded.')
        else:
            self.ngram_counts = get_gram_counts(filename)
        # Counts changed: invalidate the cached unigram total.
        self._unigram_total = None

    def get_ngram_prob(self, ngram):
        """Return the MLE conditional probability of *ngram*.

        A plain string is scored as a unigram (count / total tokens); a tuple
        of length 2-4 as count(ngram) / count(ngram[:-1]). Returns 0 when the
        history was never observed (zero denominator).
        """
        try:
            if isinstance(ngram, str):
                # Cache the corpus token total: re-summing the unigram Counter
                # on every call made each lookup O(|V|).
                total = getattr(self, '_unigram_total', None)
                if not total:
                    total = sum(self.ngram_counts['unigrams'].values())
                    self._unigram_total = total
                return self.ngram_counts['unigrams'][ngram] / total
            if len(ngram) == 2:
                return (self.ngram_counts['bigrams'][ngram]
                        / self.ngram_counts['unigrams'][ngram[0]])
            if len(ngram) == 3:
                return (self.ngram_counts['trigrams'][ngram]
                        / self.ngram_counts['bigrams'][ngram[:2]])
            return (self.ngram_counts['tetragrams'][ngram]
                    / self.ngram_counts['trigrams'][ngram[:3]])
        except ZeroDivisionError:
            return 0

    def predict_gaps(self, filename, lambdas=(0.25, 0.25, 0.25, 0.25),
                     top_k=False, k=10):
        """Write an interpolated gap distribution per line of *filename* to out.tsv.

        For every candidate word (the whole vocabulary plus '<UNK>') the gap
        word is slid through the four positions of a tetragram window spanning
        the gap; each window is scored as a lambda-weighted mix of the
        uni/bi/tri/tetragram probabilities and the four window scores are
        multiplied. '<UNK>' mass is emitted as the empty-word column. With
        top_k only the k best candidates are written, renormalised.
        """
        with open(filename, encoding='utf-8') as f, \
                open('out.tsv', 'w', encoding='utf-8') as out:
            # Hoisted out of the per-line loop: the original rebuilt
            # deepcopy(list(self.vocab)) + ['<UNK>'] for every input line.
            candidate_words = list(self.vocab) + ['<UNK>']
            for line in tqdm(f, total=self.get_num_lines(filename),
                             desc='Generating gap predictions'):
                fields = line.strip().replace(r'\n', ' ').split('\t')[-2:]
                left = [w if w in self.vocab else '<UNK>' for w in fields[0].split()]
                right = [w if w in self.vocab else '<UNK>' for w in fields[1].split()]
                context_probs = {}
                for word in candidate_words:
                    window_probs = []
                    for i in range(4):
                        # Window i uses the last (3 - i) left words and the
                        # first i right words around the candidate.
                        tetragram = tuple(left[-4:][1 + i:] + [word]
                                          + right[:4][:-4 + i])
                        interpolated = (
                            self.get_ngram_prob(tetragram[-1]) * lambdas[0]
                            + self.get_ngram_prob(tetragram[-2:]) * lambdas[1]
                            + self.get_ngram_prob(tetragram[-3:]) * lambdas[2]
                            + self.get_ngram_prob(tetragram) * lambdas[3])
                        window_probs.append(interpolated)
                    context_probs[word] = prod(window_probs)
                # NOTE(review): if every candidate scores 0 the normalisation
                # below raises ZeroDivisionError (unchanged from the original).
                if top_k:
                    best = sorted(context_probs.items(),
                                  key=lambda kv: kv[1], reverse=True)[:k]
                    mass = sum(p for _, p in best)
                    best = [(w, p / mass) for w, p in best]
                    word_parts = []
                    unk_part = ''
                    for w, p in best:
                        if w == '<UNK>':
                            # '<UNK>' is written as the trailing empty-word column.
                            unk_part = f':{p}'
                        else:
                            word_parts.append(f'{w}:{p}')
                    if unk_part:
                        word_parts.append(unk_part)
                    probs_string = '\t'.join(word_parts)
                else:
                    mass = sum(context_probs.values())
                    unk_prob = context_probs.pop('<UNK>') / mass
                    ranked = sorted(((w, p / mass) for w, p in context_probs.items()),
                                    key=lambda kv: kv[1], reverse=True)
                    probs_string = '\t'.join(f'{w}:{p}' for w, p in ranked)
                    if unk_prob > 0:
                        probs_string += f"\t:{unk_prob}"
                out.write(probs_string + '\n')

    def _sample_continuation(self, candidates, temperature):
        """Pick the next word from a non-empty list of (ngram, count) candidates.

        With probability ~(1 - temperature) the single most frequent
        continuation wins; otherwise one of the top three is chosen uniformly
        (falling back to a random vocabulary word when fewer than three
        continuations exist — preserved from the original exploration logic).
        """
        top = sorted(candidates, key=lambda kv: kv[1], reverse=True)[:3]
        if random.randint(1, 10) > temperature * 10:
            if len(candidates) < 3:
                return random.choice(list(self.ngram_counts['unigrams'].keys()))
            return random.choice(top)[0][-1]
        return top[0][0][-1]

    def generate(self, prompt, length, temperature=0.5):
        """Print *prompt* extended by *length* sampled words.

        Backs off tetragram -> trigram -> bigram over the last three generated
        words; '<UNK>' continuations are filtered out of the candidate lists.
        The prompt is assumed to be at least three words long.
        """
        generation = prompt.split()
        history = [w if w in self.vocab else '<UNK>' for w in generation[-3:]]
        for _ in tqdm(range(length), desc='Generating text'):
            grams = self.ngram_counts
            candidates = [(g, c) for g, c in grams['tetragrams'].items()
                          if list(g[:3]) == history and g[-1] != '<UNK>']
            if not candidates:
                candidates = [(g, c) for g, c in grams['trigrams'].items()
                              if list(g[:2]) == history[-2:] and g[-1] != '<UNK>']
            if not candidates:
                candidates = [(g, c) for g, c in grams['bigrams'].items()
                              if list(g[:1]) == history[-1:] and g[-1] != '<UNK>']
            if candidates:
                generation.append(self._sample_continuation(candidates, temperature))
            else:
                # No continuation at any order: fall back to a random unigram.
                generation.append(random.choice(list(grams['unigrams'].keys())))
            # NOTE(review): as in the original, the refreshed history is NOT
            # re-mapped through the <UNK> filter, so raw OOV prompt words can
            # linger in the match key for a couple of steps — confirm intended.
            history = generation[-3:]
        print(' '.join(generation))
|
||||||
|
|
||||||
|
|
||||||
|
# model = TetragramModel()
|
||||||
|
# model.train('train.tsv', vocab_size=X)
|
||||||
|
# model.predict_gaps('dev-0/in.tsv', top_k=True, k=5, lambdas=(X, X, X, X))
|
||||||
|
# model.predict_gaps('test-A/in.tsv', top_k=True, k=5, lambdas=(X, X, X, X))
|
||||||
|
# model.generate('According to recent news', length=50)
|
||||||
|
# model.generate('Recent studies have shown that', length=50)
|
||||||
|
# model.generate('Today I was taking a stroll in the park when suddenly', length=50)
|
||||||
|
# model.generate('The most unbelievable story ever told goes like this', length=50)
|
||||||
|
# model.generate('The war between', length=50)
|
14828
test-A/out.tsv
14828
test-A/out.tsv
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user