Add yeild file read.

This commit is contained in:
Jan Nowak 2022-04-04 18:31:33 +02:00
parent ca72f4ea4a
commit 827021f1a3

64
run.py
View File

@ -3,6 +3,8 @@ import csv
from collections import Counter, defaultdict
from nltk.tokenize import RegexpTokenizer
from nltk import trigrams
import regex as re
import lzma
class WordPred:
@ -10,34 +12,42 @@ class WordPred:
def __init__(self):
self.tokenizer = RegexpTokenizer(r"\w+")
self.model = defaultdict(lambda: defaultdict(lambda: 0))
def read_train_data(self, file):
data = pd.read_csv(file, compression='xz', sep="\t", error_bad_lines=False, index_col=0, header=None)
for row in data[:140000].itertuples():
if len(row)<8:
continue
text = str(row[6]) + ' ' + str(row[7])
tokens = self.tokenizer.tokenize(text)
for w1, w2, w3 in trigrams(tokens, pad_right=True, pad_left=True):
if w1 and w2 and w3:
self.model[(w2, w3)][w1] += 1
for word_pair in self.model:
num_n_grams = float(sum(self.model[word_pair].values()))
for word in self.model[word_pair]:
self.model[word_pair][word] /= num_n_grams
def read_file(self, file):
for line in file:
text = line.split("\t")
yield re.sub(r"[^\w\d'\s]+", '', re.sub(' +', ' ', ' '.join([text[6], text[7]]).replace("\\n"," ").replace("\n","").lower()))
def read_file_7(self, file):
for line in file:
text = line.split("\t")
yield re.sub(r"[^\w\d'\s]+", '', re.sub(' +', ' ', text[7].replace("\\n"," ").replace("\n","").lower()))
def generate_outputs(self, input_file, output_file):
data = pd.read_csv(input_file, compression='xz', sep='\t', error_bad_lines=False, index_col=0, header=None, quoting=csv.QUOTE_NONE)
with open(output_file, 'w') as f:
for row in data.iterrows():
text = str(row[7])
def read_train_data(self, file_path):
with lzma.open(file_path, mode='rt') as file:
for index, text in enumerate(self.read_file(file)):
tokens = self.tokenizer.tokenize(text)
if len(tokens) < 4:
prediction = 'the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1'
else:
prediction = word_gap_prediction.predict_probs(tokens[0], tokens[1])
f.write(prediction + '\n')
for w1, w2, w3 in trigrams(tokens, pad_right=True, pad_left=True):
if w1 and w2 and w3:
self.model[(w2, w3)][w1] += 1
if index == 1000000:
break
for word_pair in self.model:
num_n_grams = float(sum(self.model[word_pair].values()))
for word in self.model[word_pair]:
self.model[word_pair][word] /= num_n_grams
def generate_outputs(self, input_file, output_file):
with open(output_file, 'w') as outputf:
with lzma.open(input_file, mode='rt') as file:
for index, text in enumerate(self.read_file_7(file)):
tokens = self.tokenizer.tokenize(text)
if len(tokens) < 4:
prediction = 'the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1'
else:
prediction = word_gap_prediction.predict_probs(tokens[0], tokens[1])
outputf.write(prediction + '\n')
def predict_probs(self, word1, word2):
predictions = dict(self.model[word1, word2])
@ -61,6 +71,6 @@ class WordPred:
return str_prediction
word_gap_prediction = WordPred()
word_gap_prediction.read_train_data('./train/in.tsv.xz')
# word_gap_prediction.generate_outputs('dev-0/in.tsv.xz', 'dev-0/out.tsv')
word_gap_prediction.read_train_data('train/in.tsv.xz')
word_gap_prediction.generate_outputs('dev-0/in.tsv.xz', 'dev-0/out.tsv')
# word_gap_prediction.generate_outputs('test-A/in.tsv.xz', 'test-A/out.tsv')