# challenging-america-word-gap/run.py
# (repository-viewer header removed: 99 lines, 3.7 KiB, Python; last commit 2022-04-04 17:54:10 +02:00)
import pandas as pd
import csv
from collections import Counter, defaultdict
from nltk.tokenize import RegexpTokenizer
from nltk import trigrams
# commit: 2022-04-04 18:31:33 +02:00
import regex as re
import lzma
# commit: 2022-04-25 16:58:55 +02:00
import kenlm
# commit: 2022-04-25 23:18:15 +02:00
from math import log10
from english_words import english_words_set
# commit: 2022-04-04 17:54:10 +02:00
class WordPred:
    """Predict the word missing between two text fragments ("word gap")
    with a pre-trained KenLM language model.

    Input corpora are TSV files (xz-compressed) where column 6 holds the
    text before the gap and column 7 the text after it.
    """

    def __init__(self):
        # Tokenizer that keeps runs of word characters, dropping punctuation.
        self.tokenizer = RegexpTokenizer(r"\w+")
        # Pre-trained KenLM model; "model.binary" must exist in the CWD.
        self.model = kenlm.Model("model.binary")
        # Vocabulary accumulated by fill_words() / read_words().
        self.words = set()

    def read_file(self, file):
        """Yield one normalized text per input line, built from TSV
        columns 6 and 7 joined with a space.

        Normalization: literal "\\n" markers become spaces, real newlines
        are removed, text is lower-cased, runs of spaces are collapsed,
        and everything except word chars, digits, apostrophes and
        whitespace is stripped.
        """
        for line in file:
            columns = line.split("\t")
            merged = ' '.join([columns[6], columns[7]])
            merged = merged.replace("\\n", " ").replace("\n", "").lower()
            merged = re.sub(' +', ' ', merged)
            yield re.sub(r"[^\w\d'\s]+", '', merged)

    def read_file_7(self, file):
        """Yield the normalized contents of TSV column 7 only
        (same normalization steps as read_file)."""
        for line in file:
            columns = line.split("\t")
            merged = columns[7].replace("\\n", " ").replace("\n", "").lower()
            merged = re.sub(' +', ' ', merged)
            yield re.sub(r"[^\w\d'\s]+", '', merged)

    def fill_words(self, file_path, output_file):
        """Collect every distinct token of the corpus into self.words,
        writing each newly seen token to output_file, one per line."""
        with open(output_file, 'w') as out:
            with lzma.open(file_path, mode='rt') as file:
                for text in self.read_file(file):
                    for token in text.split(" "):
                        if token not in self.words:
                            out.write(token + "\n")
                            self.words.add(token)

    def read_words(self, file_path):
        """Load a previously saved vocabulary (one word per line) into
        self.words, skipping empty lines."""
        with open(file_path, 'r') as fin:
            for word in fin:
                word = word.replace("\n", "")
                if word:
                    self.words.add(word)

    def create_train_file(self, file_path, output_path, rows=10000):
        """Write the first `rows` normalized documents of the corpus to
        output_path, one document per line (language-model training input).

        Fixes vs. the previous version: documents are newline-separated
        (read_file strips newlines, so they used to run together on one
        line), exactly `rows` documents are written (was rows + 1), and
        the redundant close() on an already-closed file is gone.
        """
        with open(output_path, 'w') as outputfile:
            with lzma.open(file_path, mode='rt') as file:
                for index, text in enumerate(self.read_file(file)):
                    outputfile.write(text + "\n")
                    if index + 1 >= rows:
                        break

    def generate_outputs(self, input_file, output_file):
        """Write one prediction line to output_file per row of input_file.

        Rows whose right context has fewer than 4 tokens get a fixed
        fallback distribution; otherwise the gap is scored between the
        first two tokens of the right context.
        """
        with open(output_file, 'w') as outputf:
            with lzma.open(input_file, mode='rt') as file:
                for text in self.read_file_7(file):
                    tokens = self.tokenizer.tokenize(text)
                    if len(tokens) < 4:
                        prediction = 'the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1'
                    else:
                        # Was wp.predict_probs (module-level global); use
                        # self so the method works on any instance.
                        prediction = self.predict_probs(tokens[0], tokens[1])
                    outputf.write(prediction + '\n')

    def predict_probs(self, word1, word2):
        """Return the 12 best gap candidates formatted as
        'word:score word:score ... :rest', where score is the KenLM
        log10 probability of 'word1 candidate word2'."""
        scored = (
            (word, self.model.score(word1 + ' ' + word + ' ' + word2, bos=False, eos=False))
            for word in english_words_set
        )
        # Top 12 by model score; replaces the manual running-minimum scan.
        top = sorted(scored, key=lambda pair: pair[1], reverse=True)[:12]
        str_prediction = ''.join(f'{word}:{prob} ' for word, prob in top)
        # Probability mass reserved for "any other word".
        str_prediction += f':{log10(0.99)}'
        return str_prediction
# commit: 2022-04-25 16:58:55 +02:00
if __name__ == "__main__":
    wp = WordPred()
    # One-off preparation steps, run once while building the model:
    # wp.create_train_file("train/in.tsv.xz", "train/in.txt")
    # wp.fill_words("train/in.tsv.xz", "words.txt")
    # wp.read_words("words.txt")
    wp.generate_outputs("dev-0/in.tsv.xz", "dev-0/out3.tsv")
    wp.generate_outputs("test-A/in.tsv.xz", "test-A/out3.tsv")