Add yield file read.

Jan Nowak 2022-04-04 18:31:33 +02:00
parent ca72f4ea4a
commit 827021f1a3

run.py

@@ -3,6 +3,8 @@ import csv
 from collections import Counter, defaultdict
 from nltk.tokenize import RegexpTokenizer
 from nltk import trigrams
+import regex as re
+import lzma
 
 class WordPred:
@@ -10,34 +12,42 @@ class WordPred:
     def __init__(self):
         self.tokenizer = RegexpTokenizer(r"\w+")
         self.model = defaultdict(lambda: defaultdict(lambda: 0))
 
-    def read_train_data(self, file):
-        data = pd.read_csv(file, compression='xz', sep="\t", error_bad_lines=False, index_col=0, header=None)
-        for row in data[:140000].itertuples():
-            if len(row)<8:
-                continue
-            text = str(row[6]) + ' ' + str(row[7])
-            tokens = self.tokenizer.tokenize(text)
-            for w1, w2, w3 in trigrams(tokens, pad_right=True, pad_left=True):
-                if w1 and w2 and w3:
-                    self.model[(w2, w3)][w1] += 1
+    def read_file(self, file):
+        for line in file:
+            text = line.split("\t")
+            yield re.sub(r"[^\w\d'\s]+", '', re.sub(' +', ' ', ' '.join([text[6], text[7]]).replace("\\n"," ").replace("\n","").lower()))
+
+    def read_file_7(self, file):
+        for line in file:
+            text = line.split("\t")
+            yield re.sub(r"[^\w\d'\s]+", '', re.sub(' +', ' ', text[7].replace("\\n"," ").replace("\n","").lower()))
+
+    def read_train_data(self, file_path):
+        with lzma.open(file_path, mode='rt') as file:
+            for index, text in enumerate(self.read_file(file)):
+                tokens = self.tokenizer.tokenize(text)
+                for w1, w2, w3 in trigrams(tokens, pad_right=True, pad_left=True):
+                    if w1 and w2 and w3:
+                        self.model[(w2, w3)][w1] += 1
+                if index == 1000000:
+                    break
         for word_pair in self.model:
             num_n_grams = float(sum(self.model[word_pair].values()))
             for word in self.model[word_pair]:
                 self.model[word_pair][word] /= num_n_grams
 
-    def generate_outputs(self, input_file, output_file):
-        data = pd.read_csv(input_file, compression='xz', sep='\t', error_bad_lines=False, index_col=0, header=None, quoting=csv.QUOTE_NONE)
-        with open(output_file, 'w') as f:
-            for row in data.iterrows():
-                text = str(row[7])
-                tokens = self.tokenizer.tokenize(text)
-                if len(tokens) < 4:
-                    prediction = 'the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1'
-                else:
-                    prediction = word_gap_prediction.predict_probs(tokens[0], tokens[1])
-                f.write(prediction + '\n')
+    def generate_outputs(self, input_file, output_file):
+        with open(output_file, 'w') as outputf:
+            with lzma.open(input_file, mode='rt') as file:
+                for index, text in enumerate(self.read_file_7(file)):
+                    tokens = self.tokenizer.tokenize(text)
+                    if len(tokens) < 4:
+                        prediction = 'the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1'
+                    else:
+                        prediction = word_gap_prediction.predict_probs(tokens[0], tokens[1])
+                    outputf.write(prediction + '\n')
 
     def predict_probs(self, word1, word2):
         predictions = dict(self.model[word1, word2])
@@ -61,6 +71,6 @@ class WordPred:
         return str_prediction
 
 word_gap_prediction = WordPred()
-word_gap_prediction.read_train_data('./train/in.tsv.xz')
-# word_gap_prediction.generate_outputs('dev-0/in.tsv.xz', 'dev-0/out.tsv')
+word_gap_prediction.read_train_data('train/in.tsv.xz')
+word_gap_prediction.generate_outputs('dev-0/in.tsv.xz', 'dev-0/out.tsv')
 # word_gap_prediction.generate_outputs('test-A/in.tsv.xz', 'test-A/out.tsv')
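
Note: the rewritten read_train_data streams the xz archive line by line through a generator instead of loading the whole TSV with pandas, so memory use stays flat regardless of corpus size. A minimal standalone sketch of the same pattern; the function name and default column indices here are illustrative, mirroring read_file in the diff:

import lzma
import regex as re

def stream_clean_text(path, columns=(6, 7)):
    # Stream the xz-compressed TSV in text mode; only the current line
    # is held in memory at any point.
    with lzma.open(path, mode='rt') as file:
        for line in file:
            fields = line.split("\t")
            # Same cleanup as read_file in the diff: join the context
            # columns, drop literal "\n" markers, lowercase, collapse
            # runs of spaces, then strip everything except word
            # characters, digits, apostrophes, and whitespace.
            text = ' '.join(fields[i] for i in columns)
            text = text.replace("\\n", " ").replace("\n", "").lower()
            text = re.sub(' +', ' ', text)
            yield re.sub(r"[^\w\d'\s]+", '', text)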
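The model itself maps a pair of following words to normalized counts of the word that preceded them (trigrams are stored as model[(w2, w3)][w1]), so filling a gap is a dictionary lookup plus formatting. A hedged sketch of that lookup; top_k and the empty-candidates fallback are assumptions based on the default output string the diff writes for short lines:

def predict_gap(model, word1, word2, top_k=5):
    # model maps (w2, w3) -> {w1: probability}; w1 is the word that
    # preceded the pair in training, so this predicts the gap *before*
    # (word1, word2), i.e. the first two tokens of the right context.
    candidates = dict(model.get((word1, word2), {}))
    if not candidates:
        # Uniform-ish fallback, same string the diff uses for short lines.
        return 'the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1'
    best = sorted(candidates.items(), key=lambda kv: kv[1], reverse=True)[:top_k]
    return ' '.join(f'{word}:{prob:.2f}' for word, prob in best)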