challenging-america-word-ga.../train.py

176 lines
4.2 KiB
Python
Raw Normal View History

2023-04-12 20:56:08 +02:00
import lzma
import matplotlib.pyplot as plt
from math import log
from collections import OrderedDict
from collections import Counter
import regex as re
from itertools import islice
import json
import tqdm
2023-04-13 21:33:24 +02:00
# Words occurring fewer than this many times in the corpus are replaced by
# the '<UNK>' placeholder when building n-grams. Tuning notes from the
# original author: 7500 -> perplexity 511.51, 9000 -> 505, 15000 -> 503.
ignore_rare = 15000 #7500 perpex511.51 9000 perpex=505 15000 perpex503
# Version tag appended to every output artifact (vocab file, n-gram models).
model_v = '1'
2023-04-12 20:56:08 +02:00
def freq_list(g, top=None):
    """Return an OrderedDict of item frequencies in g, most frequent first.

    If top is given, only the top most-common items are considered.
    """
    counts = Counter(g)
    selected = counts.items() if top is None else counts.most_common(top)
    # Negate counts so the sort is descending; ties keep insertion order.
    return OrderedDict(sorted(selected, key=lambda pair: -pair[1]))
def get_words(t):
    """Yield word-like tokens (letters, digits, hyphens, asterisks) from t.

    NOTE: the \p{L} Unicode property requires the third-party `regex`
    module, which this file imports as `re`.
    """
    token_pattern = r'[\p{L}0-9-\*]+'
    for match in re.finditer(token_pattern, t):
        yield match.group(0)
def ngrams(iter, size, w_counter, rare_threshold=None):
    """Yield successive `size`-grams from the token stream, masking rare words.

    Tokens whose corpus frequency in `w_counter` falls below the rarity
    threshold are replaced with the '<UNK>' placeholder before windowing.

    Args:
        iter: iterable of tokens. (Name kept for backward compatibility,
            although it shadows the builtin `iter`.)
        size: n-gram order (2 = bigrams, 3 = trigrams, ...).
        w_counter: Counter mapping token -> corpus frequency.
        rare_threshold: minimum frequency for a token to be kept verbatim.
            Defaults to the module-level `ignore_rare` setting, preserving
            the original hard-coded behavior.

    Yields:
        Tuples of `size` tokens (sliding window, step 1). Streams with
        fewer than `size` tokens yield nothing.
    """
    threshold = ignore_rare if rare_threshold is None else rare_threshold
    window = []
    for token in iter:
        # Map rare tokens to the unknown marker so the models stay compact.
        window.append('<UNK>' if w_counter[token] < threshold else token)
        if len(window) == size:
            yield tuple(window)
            window = window[1:]
# ---- Corpus loading --------------------------------------------------------
# Read the xz-compressed inputs together with the expected gap words,
# reassemble each record as "left context + gap word + right context",
# and collect all letter-only tokens into the module-level `words` list.
PREFIX_TRAIN = 'train'
words = []
counter_lines = 0
with lzma.open(f'{PREFIX_TRAIN}/in.tsv.xz', 'r') as train, open(f'{PREFIX_TRAIN}/expected.tsv', 'r') as expected:
    for raw_line, gap_line in zip(train, expected):
        text = raw_line.decode("utf-8").rstrip()
        gap_word = gap_line.rstrip()
        columns = text.split('\t')
        # The last two tab columns hold the left and right context of the gap.
        sentence = (columns[-2] + ' ' + gap_word + ' ' + columns[-1]).lower()
        # Collapse literal escaped-newline markers left in the raw text.
        sentence = sentence.replace("\\\\n", ' ')
        words += re.findall(r'\p{L}+', sentence)
        # t_line_splitted = t_line_cleared.split()
        # words += t_line_splitted
        if counter_lines % 100000 == 0:
            print(counter_lines)
        counter_lines += 1
        if counter_lines > 130000:  # cap the corpus to bound memory (50000 ~ 12 GB RAM)
            break
# ---- Vocabulary ------------------------------------------------------------
# Corpus-wide word frequencies; also consumed by ngrams() for <UNK> masking.
words_c = Counter(words)

# Persist the vocabulary: every word frequent enough to escape <UNK>.
# Fix: the original script wrote this exact file twice back to back;
# writing it once produces the identical artifact.
with open(f'vocab_{model_v}.txt', 'w') as f:
    for word, amount in words_c.items():
        if amount < ignore_rare:
            continue
        f.write(word + '\n')
def create_model(grams4, trigrams):
    """Build conditional probabilities P(last word | preceding (n-1)-gram).

    grams4: Counter/dict mapping n-gram tuple -> count.
    trigrams: Counter/dict mapping (n-1)-gram tuple -> count of that prefix.
    Returns {prefix_tuple: {word: probability}}. Continuations equal to
    '<UNK>' are dropped so the model never predicts the placeholder.
    """
    model = {}
    for gram, count in grams4.items():
        prefix, continuation = gram[:-1], gram[-1]
        if continuation == "<UNK>":
            continue
        probability = count / trigrams[prefix]
        model.setdefault(prefix, {})[continuation] = probability
    return model
def create_bigram_model(bigram_x, word_c_x):
    """Build conditional probabilities P(second word | first word).

    bigram_x: Counter/dict mapping (w1, w2) tuples -> count.
    word_c_x: Counter/dict mapping w1 -> unigram count.
    Returns {w1: {w2: probability}}. Pairs containing '<UNK>' on either
    side are skipped.

    Raises:
        Exception (chained to the original KeyError/ZeroDivisionError) when
        a bigram's first word is missing from word_c_x or has a zero count.
    """
    model = {}
    for bigram, amount in bigram_x.items():
        word_key, last_word = bigram[0], bigram[1]
        if last_word == "<UNK>" or word_key == "<UNK>":
            continue
        try:
            # Counter gives 0 for missing keys (-> ZeroDivisionError);
            # a plain dict raises KeyError instead. Catch exactly those
            # rather than the original bare `except:`.
            probability = amount / word_c_x[word_key]
        except (KeyError, ZeroDivisionError) as err:
            print(bigram)
            print(word_key)
            print(last_word)
            # Chain the cause instead of discarding it (original raised a
            # bare Exception, losing the underlying error).
            raise Exception from err
        if word_key in model:
            model[word_key][last_word] = probability
            continue
        model[word_key] = {last_word: probability}
    return model
# ---- 4-gram model ----------------------------------------------------------
# Count 3-gram prefixes and 4-grams over the corpus, derive
# P(w4 | w1 w2 w3), and dump one tab-separated prefix + JSON hypothesis
# dict per line.
trigram_c = Counter(ngrams(words, 3, words_c))
tetragrams_c = Counter(ngrams(words, 4, words_c))
model = create_model(tetragrams_c, trigram_c)
with open(f'4_gram_model_{model_v}.tsv', 'w') as f:
    for prefix, hyps in model.items():
        f.write("\t".join(prefix) + "\t" + json.dumps(hyps) + '\n')
# ========= Trigram
# Same procedure one order lower: P(w3 | w1 w2) from trigram/bigram counts.
model = ""  # drop the previous model before building the next one
trigram_c = Counter(ngrams(words, 3, words_c))
bigram_c = Counter(ngrams(words, 2, words_c))
model = create_model(trigram_c, bigram_c)
trigram_c = ""  # release the large counter before serializing
with open(f'3_gram_model_{model_v}.tsv', 'w') as f:
    for prefix, hyps in model.items():
        f.write("\t".join(prefix) + "\t" + json.dumps(hyps) + '\n')
model = ""
# ========= Bigram
2023-04-12 20:56:08 +02:00
2023-04-13 21:33:24 +02:00
model=""
2023-04-12 20:56:08 +02:00
2023-04-13 21:33:24 +02:00
bigrams_ = ngrams(words, 2, words_c)
2023-04-12 20:56:08 +02:00
2023-04-13 21:33:24 +02:00
bigram_c = Counter(bigrams_)
bigrams_ = ''
model = create_bigram_model(bigram_c, words_c)
2023-04-12 20:56:08 +02:00
2023-04-13 21:33:24 +02:00
with open(f'2_gram_model_{model_v}.tsv', 'w') as f:
for trigram, hyps in model.items():
f.write(trigram + "\t" + json.dumps(hyps) + '\n')
model = ""