This commit is contained in:
Mikołaj Pokrywka 2023-06-08 12:42:04 +02:00
parent c7d96f1597
commit c0894d950a
3 changed files with 211 additions and 10672 deletions

File diff suppressed because it is too large Load Diff

43
hf.py Normal file
View File

@ -0,0 +1,43 @@
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AutoModelForCausalLM
import sys
import regex as re
import sys
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')
for line in sys.stdin:
input_text = line.split('\t')[-2].rstrip()
input_ids = tokenizer.encode(input_text, return_tensors='pt')
with torch.no_grad():
outputs = model(input_ids)
token_logits = outputs.logits[:, -1, : ]
probs = torch.nn.functional.softmax(token_logits, dim=1)[0]
top = torch.topk(probs, 10)
top_indices = top.indices.tolist()
top_probs = top.values.tolist()
top_words = [tokenizer.decode(x) for x in top_indices]
string_to_print = ''
sum_probs = 0
for w, p in zip(top_words, top_probs):
if '<unk>' in w:
continue
if re.search(r'\p{L}+', w):
string_to_print += f"{w}:{p} "
sum_probs += p
if string_to_print == '':
print(f"the:0.2 a:0.3 :0.5")
continue
unknow_prob = 1 - sum_probs
string_to_print += f":{unknow_prob}"
print(string_to_print)

153
run.py
View File

@ -1,153 +0,0 @@
import lzma
import matplotlib.pyplot as plt
from math import log
from collections import OrderedDict
from collections import Counter
import regex as re
from itertools import islice
def freq_list(g, top=None):
c = Counter(g)
if top is None:
items = c.items()
else:
items = c.most_common(top)
return OrderedDict(sorted(items, key=lambda t: -t[1]))
def get_words(t):
for m in re.finditer(r'[\p{L}0-9-\*]+', t):
yield m.group(0)
def ngrams(iter, size):
ngram = []
for item in iter:
ngram.append(item)
if len(ngram) == size:
yield tuple(ngram)
ngram = ngram[1:]
PREFIX_TRAIN = 'train'
words = []
counter_lines = 0
with lzma.open(f'{PREFIX_TRAIN}/in.tsv.xz', 'r') as train, open(f'{PREFIX_TRAIN}/expected.tsv', 'r') as expected:
for t_line, e_line in zip(train, expected):
t_line = t_line.decode("utf-8")
t_line = t_line.rstrip()
e_line = e_line.rstrip()
t_line_splitted_by_tab = t_line.split('\t')
t_line_cleared = t_line_splitted_by_tab[-2] + ' ' + e_line + ' ' + t_line_splitted_by_tab[-1]
words += t_line_cleared.split()
counter_lines+=1
if counter_lines > 90000:
break
# lzmaFile = lzma.open('dev-0/in.tsv.xz', 'rb')
# content = lzmaFile.read().decode("utf-8")
# words = get_words(trainset)
ngrams_ = ngrams(words, 2)
def create_probabilities_bigrams(w_c, b_c):
probabilities_bigrams = {}
for bigram, bigram_amount in b_c.items():
if bigram_amount <=2:
continue
p_word_before = bigram_amount / w_c[bigram[0]]
p_word_after = bigram_amount / w_c[bigram[1]]
probabilities_bigrams[bigram] = (p_word_before, p_word_after)
return probabilities_bigrams
words_c = Counter(words)
word_=''
bigram_c = Counter(ngrams_)
ngrams_=''
probabilities = create_probabilities_bigrams(words_c, bigram_c)
items = probabilities.items()
probabilities = OrderedDict(sorted(items, key=lambda t:t[1], reverse=True))
items=''
# sorted_by_freq = freq_list(ngrams)
PREFIX_VALID = 'test-A'
def count_probabilities(w_b, w_a, probs, w_c, b_c):
results_before = {}
results_after = {}
for bigram, probses in probs.items():
if len(results_before) > 20 or len(results_after) > 20:
break
if w_b == bigram[0]:
results_before[bigram] = probses[0]
if w_a == bigram[1]:
results_after[bigram] = probses[1]
a=1
best_ = {}
for bigram, probses in results_before.items():
for bigram_2, probses_2 in results_after.items():
best_[bigram[1]] = probses * probses_2
for bigram, probses in results_after.items():
for bigram_2, probses_2 in results_before.items():
if bigram[0] in best_:
if probses * probses_2 < probses_2:
continue
best_[bigram[0]] = probses * probses_2
items = best_.items()
return OrderedDict(sorted(items, key=lambda t:t[1], reverse=True))
with lzma.open(f'{PREFIX_VALID}/in.tsv.xz', 'r') as train:
for t_line in train:
t_line = t_line.decode("utf-8")
t_line = t_line.rstrip()
t_line = t_line.replace('\\n', ' ')
t_line_splitted_by_tab = t_line.split('\t')
words_pre = t_line_splitted_by_tab[-2].split()
words_po = t_line_splitted_by_tab[-1].split()
w_pre = words_pre[-1]
w_po = words_po[0]
probs_ordered = count_probabilities(w_pre, w_po,probabilities, words_c, bigram_c)
if len(probs_ordered) ==0:
print(f"the:0.5 a:0.3 :0.2")
continue
result_string = ''
counter_ = 0
for word_, p in probs_ordered.items():
if counter_>4:
break
re_ = re.search(r'\p{L}+', word_)
if re_:
word_cleared = re_.group(0)
result_string += f"{word_cleared}:{str(p)} "
else:
if result_string == '':
result_string = f"the:0.5 a:0.3 "
continue
counter_+=1
result_string += ':0.1'
print(result_string)
a=1