kenlm
This commit is contained in:
parent
0aa79cba31
commit
8359ba19e6
83
run.py
83
run.py
@ -5,45 +5,53 @@ from nltk.tokenize import RegexpTokenizer
|
|||||||
from nltk import trigrams
|
from nltk import trigrams
|
||||||
import regex as re
|
import regex as re
|
||||||
import lzma
|
import lzma
|
||||||
|
import kenlm
|
||||||
|
|
||||||
|
|
||||||
class WordPred:
|
class WordPred:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.tokenizer = RegexpTokenizer(r"\w+")
|
self.tokenizer = RegexpTokenizer(r"\w+")
|
||||||
self.model = defaultdict(lambda: defaultdict(lambda: 0))
|
# self.model = defaultdict(lambda: defaultdict(lambda: 0))
|
||||||
self.vocab = set()
|
self.model = kenlm.Model("model.binary")
|
||||||
self.alpha = 0.001
|
self.words = set()
|
||||||
|
|
||||||
def read_file(self, file):
|
def read_file(self, file):
|
||||||
for line in file:
|
for line in file:
|
||||||
text = line.split("\t")
|
text = line.split("\t")
|
||||||
yield re.sub(r"[^\w\d'\s]+", '', re.sub(' +', ' ', ' '.join([text[6], text[7]]).replace("\\n"," ").replace("\n","").lower()))
|
yield re.sub(r"[^\w\d'\s]+", '',
|
||||||
|
re.sub(' +', ' ', ' '.join([text[6], text[7]]).replace("\\n", " ").replace("\n", "").lower()))
|
||||||
|
|
||||||
def read_file_7(self, file):
|
def read_file_7(self, file):
|
||||||
for line in file:
|
for line in file:
|
||||||
text = line.split("\t")
|
text = line.split("\t")
|
||||||
yield re.sub(r"[^\w\d'\s]+", '', re.sub(' +', ' ', text[7].replace("\\n"," ").replace("\n","").lower()))
|
yield re.sub(r"[^\w\d'\s]+", '', re.sub(' +', ' ', text[7].replace("\\n", " ").replace("\n", "").lower()))
|
||||||
|
|
||||||
def read_train_data(self, file_path):
|
def fill_words(self, file_path, output_file):
|
||||||
with lzma.open(file_path, mode='rt') as file:
|
with open(output_file, 'w') as out:
|
||||||
for index, text in enumerate(self.read_file(file)):
|
with lzma.open(file_path, mode='rt') as file:
|
||||||
tokens = self.tokenizer.tokenize(text)
|
for text in self.read_file(file):
|
||||||
for w1, w2, w3 in trigrams(tokens, pad_right=True, pad_left=True):
|
for word in text.split(" "):
|
||||||
if w1 and w2 and w3:
|
if word not in self.words:
|
||||||
self.model[(w2, w3)][w1] += 1
|
out.write(word + "\n")
|
||||||
self.vocab.add(w1)
|
self.words.add(word)
|
||||||
self.vocab.add(w2)
|
|
||||||
self.vocab.add(w3)
|
|
||||||
if index == 300000:
|
|
||||||
break
|
|
||||||
|
|
||||||
for word_pair in self.model:
|
|
||||||
num_n_grams = float(sum(self.model[word_pair].values()))
|
|
||||||
for word in self.model[word_pair]:
|
|
||||||
self.model[word_pair][word] = (self.model[word_pair][word] + self.alpha) / (num_n_grams + self.alpha*len(self.vocab))
|
|
||||||
|
|
||||||
def generate_outputs(self, input_file, output_file):
|
def read_words(self, file_path):
|
||||||
|
with open(file_path, 'r') as fin:
|
||||||
|
for word in fin.readline():
|
||||||
|
self.words.add(word.replace("\n",""))
|
||||||
|
|
||||||
|
|
||||||
|
def create_train_file(self, file_path, output_path, rows=10000):
|
||||||
|
with open(output_path, 'w') as outputfile:
|
||||||
|
with lzma.open(file_path, mode='rt') as file:
|
||||||
|
for index, text in enumerate(self.read_file(file)):
|
||||||
|
outputfile.write(text)
|
||||||
|
if index == rows:
|
||||||
|
break
|
||||||
|
outputfile.close()
|
||||||
|
|
||||||
|
def generate_outputs(self, input_file, output_file):
|
||||||
with open(output_file, 'w') as outputf:
|
with open(output_file, 'w') as outputf:
|
||||||
with lzma.open(input_file, mode='rt') as file:
|
with lzma.open(input_file, mode='rt') as file:
|
||||||
for index, text in enumerate(self.read_file_7(file)):
|
for index, text in enumerate(self.read_file_7(file)):
|
||||||
@ -55,9 +63,8 @@ class WordPred:
|
|||||||
outputf.write(prediction + '\n')
|
outputf.write(prediction + '\n')
|
||||||
|
|
||||||
def predict_probs(self, word1, word2):
|
def predict_probs(self, word1, word2):
|
||||||
predictions = dict(self.model[word1, word2])
|
|
||||||
most_common = dict(Counter(predictions).most_common(6))
|
|
||||||
|
|
||||||
total_prob = 0.0
|
total_prob = 0.0
|
||||||
str_prediction = ''
|
str_prediction = ''
|
||||||
|
|
||||||
@ -69,13 +76,13 @@ class WordPred:
|
|||||||
return 'the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1'
|
return 'the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1'
|
||||||
|
|
||||||
if 1 - total_prob >= 0.01:
|
if 1 - total_prob >= 0.01:
|
||||||
str_prediction += f":{1-total_prob}"
|
str_prediction += f":{1 - total_prob}"
|
||||||
else:
|
else:
|
||||||
str_prediction += f":0.01"
|
str_prediction += f":0.01"
|
||||||
|
|
||||||
return str_prediction
|
return str_prediction
|
||||||
|
|
||||||
wp = WordPred()
|
if __name__ == "__main__":
|
||||||
wp.read_train_data('train/in.tsv.xz')
|
wp = WordPred()
|
||||||
wp.generate_outputs('dev-0/in.tsv.xz', 'dev-0/out.tsv')
|
# wp.create_train_file("train/in.tsv.xz", "train/in.txt")
|
||||||
wp.generate_outputs('test-A/in.tsv.xz', 'test-A/out.tsv')
|
# wp.fill_words("train/in.tsv.xz", "words.txt")
|
Loading…
Reference in New Issue
Block a user