From 8359ba19e6d469d400200438a4c131489c27612f Mon Sep 17 00:00:00 2001 From: Jan Nowak Date: Mon, 25 Apr 2022 16:58:55 +0200 Subject: [PATCH] kenlm --- run.py | 83 +++++++++++++++++++++++++++++++--------------------------- 1 file changed, 45 insertions(+), 38 deletions(-) diff --git a/run.py b/run.py index 80a5dfa..d00a082 100644 --- a/run.py +++ b/run.py @@ -5,45 +5,53 @@ from nltk.tokenize import RegexpTokenizer from nltk import trigrams import regex as re import lzma +import kenlm class WordPred: def __init__(self): self.tokenizer = RegexpTokenizer(r"\w+") - self.model = defaultdict(lambda: defaultdict(lambda: 0)) - self.vocab = set() - self.alpha = 0.001 - + # self.model = defaultdict(lambda: defaultdict(lambda: 0)) + self.model = kenlm.Model("model.binary") + self.words = set() + def read_file(self, file): - for line in file: - text = line.split("\t") - yield re.sub(r"[^\w\d'\s]+", '', re.sub(' +', ' ', ' '.join([text[6], text[7]]).replace("\\n"," ").replace("\n","").lower())) - + for line in file: + text = line.split("\t") + yield re.sub(r"[^\w\d'\s]+", '', + re.sub(' +', ' ', ' '.join([text[6], text[7]]).replace("\\n", " ").replace("\n", "").lower())) + def read_file_7(self, file): - for line in file: - text = line.split("\t") - yield re.sub(r"[^\w\d'\s]+", '', re.sub(' +', ' ', text[7].replace("\\n"," ").replace("\n","").lower())) + for line in file: + text = line.split("\t") + yield re.sub(r"[^\w\d'\s]+", '', re.sub(' +', ' ', text[7].replace("\\n", " ").replace("\n", "").lower())) - def read_train_data(self, file_path): - with lzma.open(file_path, mode='rt') as file: - for index, text in enumerate(self.read_file(file)): - tokens = self.tokenizer.tokenize(text) - for w1, w2, w3 in trigrams(tokens, pad_right=True, pad_left=True): - if w1 and w2 and w3: - self.model[(w2, w3)][w1] += 1 - self.vocab.add(w1) - self.vocab.add(w2) - self.vocab.add(w3) - if index == 300000: - break - - for word_pair in self.model: - num_n_grams = float(sum(self.model[word_pair].values())) - for word in self.model[word_pair]: - self.model[word_pair][word] = (self.model[word_pair][word] + self.alpha) / (num_n_grams + self.alpha*len(self.vocab)) + def fill_words(self, file_path, output_file): + with open(output_file, 'w') as out: + with lzma.open(file_path, mode='rt') as file: + for text in self.read_file(file): + for word in text.split(" "): + if word not in self.words: + out.write(word + "\n") + self.words.add(word) - def generate_outputs(self, input_file, output_file): + def read_words(self, file_path): + with open(file_path, 'r') as fin: + for word in fin.readline(): + self.words.add(word.replace("\n","")) + + + def create_train_file(self, file_path, output_path, rows=10000): + with open(output_path, 'w') as outputfile: + with lzma.open(file_path, mode='rt') as file: + for index, text in enumerate(self.read_file(file)): + outputfile.write(text) + if index == rows: + break + outputfile.close() + + def generate_outputs(self, input_file, output_file): with open(output_file, 'w') as outputf: with lzma.open(input_file, mode='rt') as file: for index, text in enumerate(self.read_file_7(file)): @@ -55,9 +63,8 @@ class WordPred: outputf.write(prediction + '\n') def predict_probs(self, word1, word2): - predictions = dict(self.model[word1, word2]) - most_common = dict(Counter(predictions).most_common(6)) - + + total_prob = 0.0 str_prediction = '' @@ -69,13 +76,13 @@ class WordPred: return 'the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1' if 1 - total_prob >= 0.01: - str_prediction += f":{1-total_prob}" + str_prediction += f":{1 - total_prob}" else: str_prediction += f":0.01" - + return str_prediction -wp = WordPred() -wp.read_train_data('train/in.tsv.xz') -wp.generate_outputs('dev-0/in.tsv.xz', 'dev-0/out.tsv') -wp.generate_outputs('test-A/in.tsv.xz', 'test-A/out.tsv') +if __name__ == "__main__": + wp = WordPred() + # wp.create_train_file("train/in.tsv.xz", "train/in.txt") + # wp.fill_words("train/in.tsv.xz", "words.txt") \ No newline at end of file