From 4b683a4656acc3b0582a21a161b1fce0a0381eb3 Mon Sep 17 00:00:00 2001
From: Wojciech Lidwin
Date: Sun, 14 Apr 2024 16:56:32 +0200
Subject: [PATCH] add run.py

---
 run.py | 106 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 106 insertions(+)
 create mode 100644 run.py

diff --git a/run.py b/run.py
new file mode 100644
index 0000000..039cf33
--- /dev/null
+++ b/run.py
@@ -0,0 +1,106 @@
+import csv
+from collections import defaultdict, Counter
+
+import pandas as pd
+from nltk import trigrams
+from nltk.tokenize import word_tokenize  # requires the NLTK "punkt" data: nltk.download("punkt")
+
+
+class TextCompletionModel:
+    """Trigram gap-filling model: predicts the word missing between a left and a right context."""
+
+    def __init__(self, smoothing_factor):
+        # (word, word) pair -> {candidate word: count, later a smoothed probability}
+        self.language_model = defaultdict(lambda: defaultdict(float))
+        self.smoothing = smoothing_factor
+        self.dictionary = set()
+        # Uniform-ish guess for unknown contexts; the trailing ":0.1" assigns
+        # the leftover probability mass to "any other word".
+        self.fallback_prediction = "the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1"
+
+    @staticmethod
+    def clean_text(input_text):
+        # The corpus encodes line breaks as literal "\n" sequences; replace the
+        # longer escape sequences before the shorter ones so that "\\n" is not
+        # clobbered by the "\n" replacement.
+        return (input_text.lower()
+                .replace("-\\n", "")
+                .replace("\\\\n", " ")
+                .replace("\\n", " ")
+                .replace("\\\\", " ")
+                .replace("\xad", ""))
+
+    def data(self, file_path, num_rows=90000):
+        return pd.read_csv(file_path, sep="\t", header=None,
+                           quoting=csv.QUOTE_NONE, nrows=num_rows)
+
+    def train(self, content_data, tags_data):
+        content_data = content_data.reset_index(drop=True)
+        tags_data = tags_data.reset_index(drop=True)
+
+        # Column 6 holds the left context, column 7 the right context and
+        # tags_data column 0 the expected gap word; join them with spaces so
+        # that the boundary words are not glued together.
+        combined_data = pd.concat([content_data[[6, 7]], tags_data], axis=1)
+        combined_data['composed'] = (combined_data[6].astype(str) + " "
+                                     + tags_data[0].astype(str) + " "
+                                     + combined_data[7].astype(str))
+
+        for line in combined_data['composed']:
+            tokens = word_tokenize(self.clean_text(line))
+            for word1, word2, word3 in trigrams(tokens, pad_right=True, pad_left=True):
+                if word1 and word2 and word3:
+                    # Record both the word preceding a bigram and the word
+                    # following it; both directions land in the same table.
+                    self.language_model[(word2, word3)][word1] += 1
+                    self.language_model[(word1, word2)][word3] += 1
+                    self.dictionary.update([word1, word2, word3])
+
+        self.adjust_probabilities()
+
+    def adjust_probabilities(self):
+        # Add-k (Lidstone) smoothing: turn the raw counts into probabilities.
+        for pair in self.language_model:
+            total_count = (sum(self.language_model[pair].values())
+                           + self.smoothing * len(self.dictionary))
+            for token in self.language_model[pair]:
+                self.language_model[pair][token] = (
+                    (self.language_model[pair][token] + self.smoothing) / total_count)
+
+    def predict(self, context):
+        # The gap word is predicted from the first two tokens of the right context.
+        if len(context) < 2:
+            return self.fallback_prediction
+
+        possible_outcomes = dict(self.language_model[(context[0], context[1])])
+        if not possible_outcomes:
+            return self.fallback_prediction
+
+        top_candidates = Counter(possible_outcomes).most_common(6)
+        # Reserve the remaining probability mass for unseen words (the ":rest" term).
+        remainder = max(1.0 - sum(prob for _, prob in top_candidates), 0.0)
+        formatted_prediction = ' '.join(
+            f"{term}:{round(prob, 2)}" for term, prob in top_candidates)
+        return f"{formatted_prediction} :{round(remainder, 2)}".strip()
+
+    def output_results(self, source_file, target_file):
+        # Read every row so the output file lines up with the input file.
+        data = self.data(source_file, num_rows=None)
+        with open(target_file, "w", encoding="utf-8") as output:
+            for text in data[7]:  # column 7 holds the right context
+                tokens = word_tokenize(self.clean_text(text))
+                prediction = self.predict(tokens)
+                output.write(prediction + "\n")
+
+
+# Example usage:
+model = TextCompletionModel(smoothing_factor=0.00002)
+input_data = model.data("train/in.tsv.xz")
+expected_data = model.data("train/expected.tsv")
+print("data loaded")
+model.train(input_data, expected_data)
+print("model trained")
+model.output_results("dev-0/in.tsv.xz", "dev-0/out.tsv")
+print("dev-0 predictions written")
+model.output_results("test-A/in.tsv.xz", "test-A/out.tsv")
+print("test-A predictions written")