# challenging-america-word-ga.../train.py
import lzma
import json
from collections import Counter, OrderedDict

import regex as re

# Frequency threshold: words seen fewer than this many times become <UNK>.
# Tuning notes (perplexity by threshold): 7500 -> 511.51, 9000 -> 505, 15000 -> 503.
ignore_rare = 15000
model_v = '1'
def freq_list(g, top=None):
    """Return an OrderedDict of item -> count, sorted by descending count."""
    c = Counter(g)
    if top is None:
        items = c.items()
    else:
        items = c.most_common(top)
    return OrderedDict(sorted(items, key=lambda t: -t[1]))
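# Illustrative example (this helper is not called below):
#   freq_list('abracadabra')
#   -> OrderedDict([('a', 5), ('b', 2), ('r', 2), ('c', 1), ('d', 1)])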
def get_words(t):
    """Yield word-like tokens (letters, digits, '-' and '*') from text t."""
    for m in re.finditer(r'[\p{L}0-9-\*]+', t):
        yield m.group(0)
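# Illustrative example (this helper is also unused below; the training loop
# tokenizes with re.findall instead):
#   list(get_words('state-of-the-art B-29s')) -> ['state-of-the-art', 'B-29s']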
def ngrams(tokens, size, w_counter):
    """Yield n-grams of the given size over tokens, mapping rare words to <UNK>."""
    ngram = []
    for item in tokens:
        if w_counter[item] < ignore_rare:
            ngram.append('<UNK>')
        else:
            ngram.append(item)
        if len(ngram) == size:
            yield tuple(ngram)
            ngram = ngram[1:]  # slide the window by one token
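# Illustrative behavior sketch (toy counter where 'the' is the only frequent word):
#   list(ngrams(['the', 'qux', 'the'], 2, Counter({'the': ignore_rare})))
#   -> [('the', '<UNK>'), ('<UNK>', 'the')]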
PREFIX_TRAIN = 'train'
words = []
counter_lines = 0
# Each input line keeps the left and right context of the gap in its last two
# tab-separated fields; the expected (gap) word comes from expected.tsv and is
# spliced between them to reconstruct the full passage.
with lzma.open(f'{PREFIX_TRAIN}/in.tsv.xz', 'r') as train, open(f'{PREFIX_TRAIN}/expected.tsv', 'r') as expected:
    for t_line, e_line in zip(train, expected):
        t_line = t_line.decode('utf-8').rstrip()  # lzma yields bytes
        e_line = e_line.rstrip()
        t_line_splitted_by_tab = t_line.split('\t')
        t_line_cleared = t_line_splitted_by_tab[-2] + ' ' + e_line + ' ' + t_line_splitted_by_tab[-1]
        t_line_cleared = t_line_cleared.lower()
        t_line_cleared = t_line_cleared.replace("\\\\n", ' ')  # escaped newlines in the corpus
        words += re.findall(r'\p{L}+', t_line_cleared)
        if counter_lines % 100000 == 0:
            print(counter_lines)  # progress log
        counter_lines += 1
        if counter_lines > 130000:  # cap the corpus to fit in RAM (50000 lines ~ 12 GB)
            break
words_c = Counter(words)

# Persist the vocabulary: every word that clears the frequency threshold.
with open(f'vocab_{model_v}.txt', 'w') as f:
    for word, amount in words_c.items():
        if amount < ignore_rare:
            continue
        f.write(word + '\n')
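# The vocab file lists one retained word per line; at prediction time anything
# outside it is presumably mapped to <UNK>, mirroring the ngrams() substitution.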
def create_model(grams4, trigrams):
    """Estimate P(last word | preceding context) from raw n-gram counts."""
    model = {}
    for gram4, amount4 in grams4.items():
        trigram = gram4[:-1]
        last_word = gram4[-1]
        if last_word == '<UNK>':
            continue
        probability = amount4 / trigrams[trigram]
        if trigram in model:
            model[trigram][last_word] = probability
            continue
        model[trigram] = {last_word: probability}
    return model
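# Maximum-likelihood estimate: P(w4 | w1 w2 w3) = count(w1 w2 w3 w4) / count(w1 w2 w3).
# Toy example (illustrative counts, not real data):
#   create_model(Counter({('a', 'b', 'c', 'd'): 2}), Counter({('a', 'b', 'c'): 4}))
#   -> {('a', 'b', 'c'): {'d': 0.5}}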
def create_bigram_model(bigram_x, word_c_x):
    """Estimate P(second word | first word) from bigram and unigram counts."""
    model = {}
    for bigram, amount in bigram_x.items():
        word_key = bigram[0]
        last_word = bigram[1]
        if last_word == '<UNK>' or word_key == '<UNK>':
            continue
        try:
            probability = amount / word_c_x[word_key]
        except (KeyError, ZeroDivisionError):
            # Defensive: a Counter returns 0 for missing keys, a plain dict raises.
            print(bigram)
            print(word_key)
            print(last_word)
            raise
        if word_key in model:
            model[word_key][last_word] = probability
            continue
        model[word_key] = {last_word: probability}
    return model
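# Toy example (illustrative counts, not real data):
#   create_bigram_model(Counter({('to', 'be'): 3}), Counter({'to': 12}))
#   -> {'to': {'be': 0.25}}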
# ========= Tetragram
trigrams_ = ngrams(words, 3, words_c)
tetragrams_ = ngrams(words, 4, words_c)
trigram_c = Counter(trigrams_)
trigrams_ = ''  # drop the exhausted generator to free memory
tetragrams_c = Counter(tetragrams_)
tetragrams_ = ''
model = create_model(tetragrams_c, trigram_c)
with open(f'4_gram_model_{model_v}.tsv', 'w') as f:
    for trigram, hyps in model.items():
        f.write("\t".join(trigram) + "\t" + json.dumps(hyps) + '\n')
# ========= Trigram
model = ''
trigrams_ = ngrams(words, 3, words_c)
bigrams_ = ngrams(words, 2, words_c)
trigram_c = Counter(trigrams_)
trigrams_ = ''
bigram_c = Counter(bigrams_)
bigrams_ = ''
model = create_model(trigram_c, bigram_c)
trigram_c = ''  # free memory before writing
with open(f'3_gram_model_{model_v}.tsv', 'w') as f:
    for trigram, hyps in model.items():
        f.write("\t".join(trigram) + "\t" + json.dumps(hyps) + '\n')
model = ''
# ========= Bigram
bigrams_ = ngrams(words, 2, words_c)
bigram_c = Counter(bigrams_)
bigrams_ = ''
model = create_bigram_model(bigram_c, words_c)
with open(f'2_gram_model_{model_v}.tsv', 'w') as f:
    for word, hyps in model.items():
        f.write(word + "\t" + json.dumps(hyps) + '\n')
model = ''