Compare commits
No commits in common. "master" and "zad_9" have entirely different histories.
@@ -1,9 +0,0 @@
Challenging America word-gap prediction
===================================

Guess a word in a gap.

Evaluation metric
-----------------

LikelihoodHashed is the metric
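Both solutions compared here write one probability distribution per line of out.tsv: space-separated word:probability pairs plus a trailing ":rest" entry for the remaining probability mass (run.py and x_train.py below fall back to lines such as "the:0.5 a:0.3 :0.2"). A minimal, illustrative sketch of that line format; the helper name and the toy distribution are not part of either branch:

# Illustrative only: build an out.tsv line "word:prob ... :rest" as printed by run.py / x_train.py.
def format_prediction_line(candidates):
    # candidates: list of (word, probability) pairs for the gap
    used_mass = sum(p for _, p in candidates)
    parts = [f"{word}:{prob}" for word, prob in candidates]
    parts.append(f":{round(1.0 - used_mass, 2)}")  # leftover mass for "any other word"
    return " ".join(parts)

print(format_prediction_line([("the", 0.5), ("a", 0.3)]))  # -> the:0.5 a:0.3 :0.2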
@@ -1 +0,0 @@
--metric PerplexityHashed --precision 2 --in-header in-header.tsv --out-header out-header.tsv
21038
dev-0/out.tsv
File diff suppressed because it is too large
@@ -1 +0,0 @@
FileId Year LeftContext RightContext
@@ -1 +0,0 @@
Word
153
run.py
@@ -1,153 +0,0 @@
import lzma
import matplotlib.pyplot as plt
from math import log
from collections import OrderedDict
from collections import Counter
import regex as re
from itertools import islice


def freq_list(g, top=None):
    c = Counter(g)

    if top is None:
        items = c.items()
    else:
        items = c.most_common(top)

    return OrderedDict(sorted(items, key=lambda t: -t[1]))


def get_words(t):
    for m in re.finditer(r'[\p{L}0-9-\*]+', t):
        yield m.group(0)


def ngrams(iterable, size):
    # Slide a window of `size` items over the stream and yield each window as a tuple.
    ngram = []
    for item in iterable:
        ngram.append(item)
        if len(ngram) == size:
            yield tuple(ngram)
            ngram = ngram[1:]


PREFIX_TRAIN = 'train'
words = []

# Build the training word stream: left context + expected word + right context,
# limited to the first 90 000 lines to keep memory in check.
counter_lines = 0
with lzma.open(f'{PREFIX_TRAIN}/in.tsv.xz', 'r') as train, open(f'{PREFIX_TRAIN}/expected.tsv', 'r') as expected:
    for t_line, e_line in zip(train, expected):
        t_line = t_line.decode("utf-8")

        t_line = t_line.rstrip()
        e_line = e_line.rstrip()

        t_line_splitted_by_tab = t_line.split('\t')

        t_line_cleared = t_line_splitted_by_tab[-2] + ' ' + e_line + ' ' + t_line_splitted_by_tab[-1]

        words += t_line_cleared.split()

        counter_lines += 1
        if counter_lines > 90000:
            break

# lzmaFile = lzma.open('dev-0/in.tsv.xz', 'rb')
# content = lzmaFile.read().decode("utf-8")
# words = get_words(trainset)

ngrams_ = ngrams(words, 2)


def create_probabilities_bigrams(w_c, b_c):
    # For every bigram seen more than twice, store
    # (count(bigram) / count(first word), count(bigram) / count(second word)).
    probabilities_bigrams = {}
    for bigram, bigram_amount in b_c.items():
        if bigram_amount <= 2:
            continue
        p_word_before = bigram_amount / w_c[bigram[0]]
        p_word_after = bigram_amount / w_c[bigram[1]]
        probabilities_bigrams[bigram] = (p_word_before, p_word_after)

    return probabilities_bigrams


words_c = Counter(words)
word_ = ''
bigram_c = Counter(ngrams_)
ngrams_ = ''
probabilities = create_probabilities_bigrams(words_c, bigram_c)

items = probabilities.items()
probabilities = OrderedDict(sorted(items, key=lambda t: t[1], reverse=True))
items = ''
# sorted_by_freq = freq_list(ngrams)

PREFIX_VALID = 'test-A'


def count_probabilities(w_b, w_a, probs, w_c, b_c):
    # Collect up to ~20 bigrams starting with the word before the gap (w_b)
    # and up to ~20 bigrams ending with the word after the gap (w_a),
    # then score candidate gap words by multiplying the two estimates.
    results_before = {}
    results_after = {}
    for bigram, probses in probs.items():
        if len(results_before) > 20 or len(results_after) > 20:
            break
        if w_b == bigram[0]:
            results_before[bigram] = probses[0]
        if w_a == bigram[1]:
            results_after[bigram] = probses[1]
    best_ = {}

    for bigram, probses in results_before.items():
        for bigram_2, probses_2 in results_after.items():
            best_[bigram[1]] = probses * probses_2

    for bigram, probses in results_after.items():
        for bigram_2, probses_2 in results_before.items():
            if bigram[0] in best_:
                if probses * probses_2 < probses_2:
                    continue
            best_[bigram[0]] = probses * probses_2

    items = best_.items()
    return OrderedDict(sorted(items, key=lambda t: t[1], reverse=True))


with lzma.open(f'{PREFIX_VALID}/in.tsv.xz', 'r') as train:
    for t_line in train:
        t_line = t_line.decode("utf-8")

        t_line = t_line.rstrip()
        t_line = t_line.replace('\\n', ' ')

        t_line_splitted_by_tab = t_line.split('\t')

        words_pre = t_line_splitted_by_tab[-2].split()
        words_po = t_line_splitted_by_tab[-1].split()

        if not words_pre or not words_po:
            # Empty context on either side of the gap: fall back to a fixed distribution.
            print("the:0.5 a:0.3 :0.2")
            continue

        w_pre = words_pre[-1]
        w_po = words_po[0]

        probs_ordered = count_probabilities(w_pre, w_po, probabilities, words_c, bigram_c)
        if len(probs_ordered) == 0:
            print("the:0.5 a:0.3 :0.2")
            continue
        result_string = ''
        counter_ = 0
        for word_, p in probs_ordered.items():
            if counter_ > 4:
                break
            re_ = re.search(r'\p{L}+', word_)
            if re_:
                word_cleared = re_.group(0)
                result_string += f"{word_cleared}:{str(p)} "
            else:
                if result_string == '':
                    result_string = "the:0.5 a:0.3 "
                continue

            counter_ += 1
        result_string += ':0.1'
        print(result_string)
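For readers skimming the removed baseline: count_probabilities looks up bigrams that start with the word left of the gap and bigrams that end with the word right of the gap, and ranks candidate gap words by the product of the two count-based estimates from create_probabilities_bigrams. A simplified sketch of that scoring idea, with made-up toy counts rather than the exact pairing logic above:

from collections import Counter

# Toy counts, for illustration only.
words_c = Counter({'of': 10, 'the': 20, 'united': 5, 'states': 5})
bigram_c = Counter({('of', 'the'): 8, ('of', 'united'): 2,
                    ('the', 'states'): 3, ('united', 'states'): 4})

def score(candidate, w_pre, w_po):
    # Mirrors create_probabilities_bigrams():
    p_after_pre = bigram_c[(w_pre, candidate)] / words_c[w_pre]   # count(w_pre, c) / count(w_pre)
    p_before_po = bigram_c[(candidate, w_po)] / words_c[w_po]     # count(c, w_po) / count(w_po)
    return p_after_pre * p_before_po

print(score('the', 'of', 'states'))      # 8/10 * 3/5 ≈ 0.48
print(score('united', 'of', 'states'))   # 2/10 * 4/5 ≈ 0.16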
16
scripts.py
Normal file
@@ -0,0 +1,16 @@
import regex as re
import string


def get_words_from_line(line):
    # Tokenizer used by x_train.py: strips ASCII punctuation, then yields
    # lowercased words/digits (and any remaining punctuation runs).
    line = line.rstrip()
    # line = line.lower()
    line = line.strip()
    line = line.translate(str.maketrans('', '', string.punctuation))
    # yield '<s>'
    for m in re.finditer(r'[\p{L}0-9\*]+|\p{P}+', line):
        yield m.group(0).lower()
    # yield '</s>'


vocab_size = 60000
learning_rate = 0.0001
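A quick, illustrative check of this tokenizer (the example sentence is arbitrary, not part of the diff):

import scripts

print(list(scripts.get_words_from_line("Guess a word, in a gap.\n")))
# -> ['guess', 'a', 'word', 'in', 'a', 'gap']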
14828
test-A/out.tsv
File diff suppressed because it is too large
70
utils.py
Normal file
@@ -0,0 +1,70 @@
import regex as re
import string
from torch import nn
import torch
from torch.utils.data import DataLoader

from torch.utils.data import IterableDataset
import itertools
import lzma
import pickle
import scripts


def get_words_from_line(line):
    line = line.rstrip()
    line = line.lower()
    line = line.strip()
    line = line.translate(str.maketrans('', '', string.punctuation))
    yield '<s>'
    for m in re.finditer(r'\p{L}+', line):
        yield m.group(0)
    yield '</s>'


vocab_size = 32000
learning_rate = 0.0001
embed_size = 100
device = 'cuda'


class LanguageModel(nn.Module):
    def __init__(self, vocabulary_size, embedding_size):
        super(LanguageModel, self).__init__()
        self.embedings = nn.Embedding(vocabulary_size, embedding_size)
        self.linear = nn.Linear(embedding_size * 3, vocabulary_size)

        self.linear_first_layer = nn.Linear(embedding_size * 5, embedding_size * 3)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

        # self.model = nn.Sequential(
        #     nn.Embedding(vocabulary_size, embedding_size),
        #     nn.Linear(embedding_size, vocabulary_size),
        #     nn.Softmax()
        # )

    def forward(self, x_in):
        # x_in is a list of 10 context-token tensors, one per position.
        # emb_1 = self.embedings(x[0])
        # emb_2 = self.embedings(x[1])

        embeddings = [self.embedings(x) for x in x_in]

        # Sum the embeddings of the first six positions into one vector,
        # then concatenate it with the last four embeddings (5 * embed_size in total).
        first = embeddings[0]
        to_sum = embeddings[1:6]
        to_concat = embeddings[6:]

        for t in to_sum:
            first = torch.add(first, t)

        to_concat.insert(0, first)

        first_layer = self.linear_first_layer(torch.cat(to_concat, dim=1))
        after_relu = self.relu(first_layer)
        concated = self.linear(after_relu)

        y = self.softmax(concated)

        return y
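A minimal sketch of how this model is driven, mirroring what x_train.py below does with x = batch[:10]; the batch size and random token ids are illustrative, and a CUDA device is assumed because utils.device is 'cuda':

import torch
import utils

model = utils.LanguageModel(utils.vocab_size, utils.embed_size).to(utils.device)

# Ten context-token tensors of shape (batch,), as produced by the DataLoader in x_train.py.
batch_size = 4
x = [torch.randint(0, utils.vocab_size, (batch_size,), device=utils.device) for _ in range(10)]

y = model(x)
print(y.shape)  # (batch_size, vocab_size): a softmax distribution over the vocabulary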
29
x_create_vocab.py
Normal file
@@ -0,0 +1,29 @@
from itertools import islice
import regex as re
import sys
from torchtext.vocab import build_vocab_from_iterator
import lzma
import utils
import torch


def get_word_lines_from_file(file_name):
    counter = 0
    with lzma.open(file_name, 'r') as fh:
        for line in fh:
            counter += 1
            # if counter == 4000:
            #     break
            line = line.decode("utf-8")
            yield utils.get_words_from_line(line)


# Build a vocabulary of the utils.vocab_size most frequent tokens from the training corpus.
vocab_size = utils.vocab_size

vocab = build_vocab_from_iterator(
    get_word_lines_from_file('train/in.tsv.xz'),
    max_tokens=vocab_size,
    specials=['<unk>', '<empty>'])


import pickle
with open("vocab.pickle", 'wb') as handle:
    pickle.dump(vocab, handle)
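Once vocab.pickle has been written, it can be loaded and queried the same way x_train.py does; a small illustrative round-trip (the example words are arbitrary):

import pickle

with open("vocab.pickle", 'rb') as handle:
    vocab = pickle.load(handle)
vocab.set_default_index(vocab['<unk>'])  # out-of-vocabulary words map to <unk>

ids = [vocab[w] for w in ['the', 'president', 'notaword123']]
print(ids)
print(vocab.lookup_tokens(ids))  # map ids back to tokens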
348
x_train.py
Normal file
@@ -0,0 +1,348 @@
from torch import nn
import torch
from torch.utils.data import DataLoader
import copy
from torch.utils.data import IterableDataset
import itertools
import lzma
import regex as re
import pickle
import scripts
import string
import pdb
import utils


def divide_chunks(l, n):
    # looping till length l
    for i in range(0, len(l), n):
        yield l[i:i + n]


with open("vocab.pickle", 'rb') as handle:
    vocab = pickle.load(handle)
vocab.set_default_index(vocab['<unk>'])


def look_ahead_iterator(gen):
    # Group the stream of token ids into 11-element sequences:
    # 10 context tokens followed by the target word.
    seq = []
    counter = 0
    for item in gen:
        seq.append(item)
        if counter % 11 == 0 and counter != 0:
            if len(seq) == 11:
                yield seq
            seq = []
        counter += 1


def get_word_lines_from_file(file_name):
    counter = 0
    with lzma.open(file_name, 'r') as fh:
        for line in fh:
            counter += 1
            # if counter == 100000:
            #     break
            line = line.decode("utf-8")
            yield scripts.get_words_from_line(line)


class Grams_10(IterableDataset):
    def load_vocab(self):
        with open("vocab.pickle", 'rb') as handle:
            vocab = pickle.load(handle)
        return vocab

    def __init__(self, text_file, vocab):
        self.vocab = vocab
        self.vocab.set_default_index(self.vocab['<unk>'])
        self.text_file = text_file

    def __iter__(self):
        return look_ahead_iterator(
            (self.vocab[t] for t in itertools.chain.from_iterable(get_word_lines_from_file(self.text_file))))


vocab_size = scripts.vocab_size

train_dataset = Grams_10('train/in.tsv.xz', vocab)
BATCH_SIZE = 2048

train_data = DataLoader(train_dataset, batch_size=BATCH_SIZE)


PREFIX_TRAIN = 'train'
PREFIX_VALID = 'dev-0'
BATCHES = []

# def read_train_file(folder_prefix, vocab):
#     dataset_x = []
#     dataset_y = []
#     counter_lines = 0
#     seq_len = 10
#     with lzma.open(f'{folder_prefix}/in.tsv.xz', 'r') as train, open(f'{folder_prefix}/expected.tsv', 'r') as expected:
#         for t_line, e_line in zip(train, expected):
#             t_line = t_line.decode("utf-8")
#             t_line = t_line.rstrip()
#             e_line = e_line.rstrip()
#             t_line = t_line.translate(str.maketrans('', '', string.punctuation))
#             t_line_splitted_by_tab = t_line.split('\t')
#             # t_line_cleared = t_line_splitted_by_tab[-2] + ' ' + e_line + ' ' + t_line_splitted_by_tab[-1]
#
#             whole_line = t_line_splitted_by_tab[-2] + ' ' + e_line + ' ' + t_line_splitted_by_tab[-1]
#
#             whole_line_splitted = list(scripts.get_words_from_line(whole_line))
#
#             whole_lines_splitted = divide_chunks(whole_line_splitted, 11)
#
#             for chunk_line in whole_line_splitted:
#
#                 left_context_splitted = chunk_line[0:10]
#
#                 seq_x = []
#                 for i in range(seq_len):
#                     index = -1 - i
#                     if len(left_context_splitted) < i + 1:
#                         seq_x.insert(0, '<empty>')
#                     else:
#                         seq_x.insert(0, left_context_splitted[-1 - i])
#
#                 left_vocabed = [vocab[t] for t in seq_x]
#
#                 dataset_x.append(left_vocabed)
#                 dataset_y.append([vocab[chunk_line[10]]])
#
#             counter_lines += 1
#             # if counter_lines > 20000:
#             #     break
#     return dataset_x, dataset_y


def read_dev_file(folder_prefix, vocab):
    dataset_x = []
    dataset_y = []
    counter_lines = 0
    seq_len = 10
    with lzma.open(f'{folder_prefix}/in.tsv.xz', 'r') as train, open(f'{folder_prefix}/expected.tsv', 'r') as expected:
        for t_line, e_line in zip(train, expected):
            t_line = t_line.decode("utf-8")
            t_line = t_line.rstrip()
            e_line = e_line.rstrip()
            t_line = t_line.translate(str.maketrans('', '', string.punctuation))
            t_line_splitted_by_tab = t_line.split('\t')
            # t_line_cleared = t_line_splitted_by_tab[-2] + ' ' + e_line + ' ' + t_line_splitted_by_tab[-1]

            left_context = t_line_splitted_by_tab[-2]
            left_context_splitted = list(scripts.get_words_from_line(left_context))

            # Take the last 10 tokens of the left context, padding with <empty> if it is shorter.
            seq_x = []
            for i in range(seq_len):
                index = -1 - i
                if len(left_context_splitted) < i + 1:
                    seq_x.insert(0, '<empty>')
                else:
                    seq_x.insert(0, left_context_splitted[-1 - i])

            left_vocabed = [vocab[t] for t in seq_x]

            dataset_x.append(left_vocabed)
            dataset_y.append([vocab[e_line]])

            counter_lines += 1
            # if counter_lines > 20000:
            #     break
    return dataset_x, dataset_y


def read_test_file(folder_prefix, vocab):
    dataset_x = []
    dataset_y = []
    counter_lines = 0
    seq_len = 10
    with lzma.open(f'{folder_prefix}/in.tsv.xz', 'r') as train:
        for t_line in train:
            t_line = t_line.decode("utf-8")
            t_line = t_line.rstrip()
            t_line = t_line.translate(str.maketrans('', '', string.punctuation))
            t_line_splitted_by_tab = t_line.split('\t')
            # t_line_cleared = t_line_splitted_by_tab[-2] + ' ' + e_line + ' ' + t_line_splitted_by_tab[-1]

            left_context = t_line_splitted_by_tab[-2]
            left_context_splitted = list(scripts.get_words_from_line(left_context))

            seq_x = []
            for i in range(seq_len):
                index = -1 - i
                if len(left_context_splitted) < i + 1:
                    seq_x.insert(0, '<empty>')
                else:
                    seq_x.insert(0, left_context_splitted[-1 - i])

            left_vocabed = [vocab[t] for t in seq_x]

            dataset_x.append(left_vocabed)

            counter_lines += 1
            # if counter_lines > 20000:
            #     break
    return dataset_x


# train_set_x, train_set_y = read_file(PREFIX_TRAIN, vocab)
dev_set_x, dev_set_y = read_dev_file(PREFIX_VALID, vocab)

test_set_x = read_test_file('test-A', vocab)

# train_data_x = DataLoader(train_set_x, batch_size=4048)
# train_data_y = DataLoader(train_set_y, batch_size=4048)

dev_data_x = DataLoader(dev_set_x, batch_size=1)
dev_data_y = DataLoader(dev_set_y, batch_size=1)

test_set_x = DataLoader(test_set_x, batch_size=1)
# pdb.set_trace()
device = utils.device

model = utils.LanguageModel(scripts.vocab_size, utils.embed_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=utils.learning_rate)
criterion = torch.nn.NLLLoss()
model.train()

step = 0
last_best_acc = -1
epochs = 3
for epoch in range(epochs):
    model.train()
    for batch in train_data:
        # Positions 0-9 are the context tokens, position 10 is the word to predict.
        x = batch[:10]
        y = [batch[10]]

        x = [i.to(device) for i in x]
        y = y[0].to(device)
        optimizer.zero_grad()
        ypredicted = model(x)
        # pdb.set_trace()
        loss = criterion(torch.log(ypredicted), y)
        if step % 10000 == 0:
            print('Step: ', step, loss)
            # torch.save(model.state_dict(), f'model1_{step}.bin')
        step += 1
        loss.backward()
        optimizer.step()

    # evaluation
    model.eval()
    y_predeicted = []
    top_50_true = 0  # counts hits within the top-64 predictions (despite the name)
    for d_x, d_y in zip(dev_data_x, dev_data_y):
        # pdb.set_trace()
        d_x = [i.to(device) for i in d_x]
        # d_y = d_y.to(device)

        optimizer.zero_grad()
        ypredicted = model(d_x)

        top = torch.topk(ypredicted[0], 64)
        top_indices = top.indices.tolist()
        if d_y[0] in top_indices:
            top_50_true += 1
    my_acc = top_50_true / len(dev_data_y)
    print('My_accuracy: ', my_acc, ", epoch: ", epoch)
    if my_acc > last_best_acc:
        print('NEW BEST -- My_accuracy: ', my_acc, ", epoch: ", epoch)
        last_best_acc = my_acc
        best_model = copy.deepcopy(model)
        torch.save(model.state_dict(), 'model_last_best_.bin')
    if epoch % 15 == 0:
        print('Epoch: ', epoch, step, loss)
        # torch.save(model.state_dict(), f'model_epoch_{epoch}_.bin')


# inference
print('inference')
inference_result = []
for d_x, d_y in zip(dev_data_x, dev_data_y):
    # pdb.set_trace()
    d_x = [i.to(device) for i in d_x]
    # d_y = d_y.to(device)

    optimizer.zero_grad()
    ypredicted = model(d_x)

    top = torch.topk(ypredicted[0], 10)
    top_indices = top.indices.tolist()
    top_probs = top.values.tolist()
    top_words = vocab.lookup_tokens(top_indices)

    string_to_print = ''

    sum_probs = 0
    for w, p in zip(top_words, top_probs):
        # print(top_words)
        if '<unk>' in w:
            continue
        string_to_print += f"{w}:{p} "
        sum_probs += p

    if string_to_print == '':
        inference_result.append("the:0.2 a:0.3 :0.5")
        continue
    unknow_prob = 1 - sum_probs
    string_to_print += f":{unknow_prob}"

    inference_result.append(string_to_print)

with open('dev-0/out.tsv', 'w') as f:
    for line in inference_result:
        f.write(line + '\n')


print('inference test')
inference_result = []
for d_x in test_set_x:
    # pdb.set_trace()
    d_x = [i.to(device) for i in d_x]
    # d_y = d_y.to(device)

    optimizer.zero_grad()
    ypredicted = model(d_x)

    top = torch.topk(ypredicted[0], 64)
    top_indices = top.indices.tolist()
    top_probs = top.values.tolist()
    top_words = vocab.lookup_tokens(top_indices)

    string_to_print = ''

    sum_probs = 0
    for w, p in zip(top_words, top_probs):
        # print(top_words)
        if '<unk>' in w:
            continue
        string_to_print += f"{w}:{p} "
        sum_probs += p

    if string_to_print == '':
        inference_result.append("the:0.2 a:0.3 :0.5")
        continue
    unknow_prob = 1 - sum_probs
    string_to_print += f":{unknow_prob}"

    inference_result.append(string_to_print)

with open('test-A/out.tsv', 'w') as f:
    for line in inference_result:
        f.write(line + '\n')
print('All done')