"""Train a naive Bayes text classifier and write predictions to TSV files.

Reconstructed from the patch "Add initial implementation" (commit 5bb9042,
Marcin Kostrzewski, 2022-05-10), which created this file as ``run.py``.
"""
import lzma


def get_data(fname):
    """Return the lines of the UTF-8 text file *fname* as a list of ``str``.

    Line terminators are kept, exactly as ``readlines()`` yields them.
    """
    with open(fname, 'r', encoding='utf8') as file:
        return file.readlines()


def get_data_zipped(fname):
    """Return the lines of the xz/LZMA-compressed file *fname*.

    The file is opened in binary mode, so the result is a list of ``bytes``
    objects (line terminators kept), unlike ``get_data`` which returns ``str``.
    """
    # NOTE(review): downstream code receives undecoded bytes; confirm the
    # classifier decodes them as intended before switching to text mode.
    with lzma.open(fname, 'rb') as file:
        return file.readlines()


def train_bayes(model, x, y, step=15000):
    """Feed *x* and *y* to ``model.train`` in consecutive batches.

    Args:
        model: object exposing ``train(documents, labels)``.
        x: sequence of documents; must support ``len()`` and slicing.
        y: sequence of labels aligned with *x*.
        step: maximum number of items per ``model.train`` call.

    Raises:
        ValueError: if *step* is smaller than 1 or *x* and *y* differ in
            length.
    """
    if step < 1:
        raise ValueError(f'step must be a positive integer, got {step!r}')
    if len(x) != len(y):
        raise ValueError(
            f'x and y must have the same length, got {len(x)} and {len(y)}'
        )

    # Slicing clamps at the end of the sequence, so the final batch is simply
    # shorter when len(x) is not a multiple of step.
    for start in range(0, len(x), step):
        stop = start + step
        model.train(x[start:stop], y[start:stop])


def write_file(fname, data):
    """Write each item of *data* to *fname* as UTF-8 text, one per line."""
    # Explicit encoding keeps the output independent of the platform default.
    with open(fname, 'wt', encoding='utf8') as f:
        f.writelines(f'{d}\n' for d in data)


def main():
    """Train on ``train/`` and write predictions for ``dev-0`` and ``test-A``.

    Reads ``train/in.tsv.xz`` and ``train/expected.tsv`` (labels converted to
    ``int``), trains a ``NaiveBayesTextClassifier`` with categories 0 and 1,
    then classifies ``<split>/in.tsv.xz`` and writes ``<split>/out.tsv`` for
    each evaluation split.
    """
    # Imported here rather than at module level so that the helper functions
    # above stay importable without the third-party packages installed.
    from naivebayes import NaiveBayesTextClassifier
    from spacy.lang.en.stop_words import STOP_WORDS as en_stop

    train_x = get_data_zipped('train/in.tsv.xz')
    # preprocessing: labels are stored one per line as text
    train_y = [int(label) for label in get_data('train/expected.tsv')]

    # Load every evaluation input before training so a missing file fails
    # early instead of after the model has been fitted.
    splits = ('test-A', 'dev-0')
    inputs = {split: get_data_zipped(f'{split}/in.tsv.xz') for split in splits}

    model = NaiveBayesTextClassifier(
        categories=[0, 1],
        stop_words=en_stop,
    )

    train_bayes(model, train_x, train_y)

    for split in ('dev-0', 'test-A'):
        write_file(f'{split}/out.tsv', model.classify(inputs[split]))


if __name__ == '__main__':
    main()