add run.py

parent 708733b58d
commit 4b683a4656

run.py (new file, 78 lines)

@@ -0,0 +1,78 @@
from nltk.tokenize import word_tokenize
from nltk import trigrams
from collections import defaultdict, Counter
import pandas as pd
import csv
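
# NOTE: word_tokenize needs NLTK's "punkt" tokenizer data; if it is missing,
# install it once with:  python -c "import nltk; nltk.download('punkt')"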


class TextCompletionModel:

    def __init__(self, smoothing_factor):
        # context bigram -> {candidate word: count, then probability}
        self.language_model = defaultdict(lambda: defaultdict(float))
        self.smoothing = smoothing_factor
        self.dictionary = set()
        # Returned when no prediction is possible; the bare ":0.1" assigns
        # the remaining weight to "any other word" in this output format.
        self.fallback_prediction = "the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1"

    @staticmethod
    def clean_text(input_text):
        # Lowercase; drop hyphenated line breaks, newlines, soft hyphens,
        # and stray literal "\n" / backslash sequences left in the raw text.
        return (input_text.lower()
                .replace("-\n", "")
                .replace("\n", " ")
                .replace("\xad", "")
                .replace("\\n", " ")
                .replace("\\", " "))

    def data(self, file_path, num_rows=90000):
        # read_csv infers .xz compression from the extension; QUOTE_NONE
        # keeps quote characters as ordinary text.
        data_frame = pd.read_csv(file_path, sep="\t", header=None, quoting=csv.QUOTE_NONE, nrows=num_rows)
        return data_frame

    def train(self, content_data, tags_data):
        content_data = content_data.reset_index(drop=True)
        tags_data = tags_data.reset_index(drop=True)

        # Columns 6 and 7 carry the text on either side of the gap and the
        # expected file carries the gap word; join them with spaces so the
        # boundary tokens do not get glued together.
        combined_data = pd.concat([content_data[[6, 7]], tags_data], axis=1)
        combined_data['composed'] = (combined_data[6].astype(str) + " "
                                     + tags_data[0].astype(str) + " "
                                     + combined_data[7].astype(str))

        for line in combined_data['composed']:
            tokens = word_tokenize(self.clean_text(line))
            for word1, word2, word3 in trigrams(tokens, pad_right=True, pad_left=True):
                # Padding tokens are None, so padded trigrams are skipped.
                if word1 and word2 and word3:
                    # Count both the word preceding a bigram and the word
                    # following it in the same lookup table.
                    self.language_model[(word2, word3)][word1] += 1
                    self.language_model[(word1, word2)][word3] += 1
                    self.dictionary.update([word1, word2, word3])

        self.adjust_probabilities()

    def adjust_probabilities(self):
        # Add-k smoothing: P(w | pair) = (count(w) + k) / (total + k * |V|).
        for pair in self.language_model:
            total_count = sum(self.language_model[pair].values()) + self.smoothing * len(self.dictionary)
            for token in self.language_model[pair]:
                self.language_model[pair][token] = (self.language_model[pair][token] + self.smoothing) / total_count

    def predict(self, context):
        if len(context) < 3:
            return self.fallback_prediction

        # Candidates stored for the bigram formed by the first two tokens.
        possible_outcomes = dict(self.language_model[(context[0], context[1])])
        if not possible_outcomes:
            return self.fallback_prediction

        # Emit up to six "word:probability" pairs, highest probability first.
        formatted_prediction = ' '.join(
            f"{term}:{round(prob, 2)}" for term, prob in Counter(possible_outcomes).most_common(6))
        return formatted_prediction.strip()

    def output_results(self, source_file, target_file):
        data = self.data(source_file)
        with open(target_file, "w", encoding="utf-8") as output:
            for text in data[7]:
                tokens = word_tokenize(self.clean_text(text))
                prediction = self.predict(tokens)
                output.write(prediction + "\n")


# Example usage (the bare print() calls are simple progress markers):
model = TextCompletionModel(smoothing_factor=0.00002)
input_data = model.data("train/in.tsv.xz")
expected_data = model.data("train/expected.tsv")
print('0')
model.train(input_data, expected_data)
print('1')
model.output_results("dev-0/in.tsv.xz", "dev-0/out.tsv")
print('2')
model.output_results("test-A/in.tsv.xz", "test-A/out.tsv")
print('3')
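
# --- Optional sanity check (a sketch, not part of the pipeline above) ---
# Hand-seeding the model illustrates predict()'s output format:
#
#   toy = TextCompletionModel(smoothing_factor=0.01)
#   toy.language_model[("of", "the")]["united"] = 1.0
#   print(toy.predict(["of", "the", "states"]))   # -> "united:1.0"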