"""Train a Naive Bayes text classifier and write predictions for two data sets.

Reads ``train/in.tsv`` (one document per line) and ``train/expected.tsv``
(one integer label per line), trains a ``NaiveBayesTextClassifier`` in
batches, then classifies ``dev-0/in.tsv`` and ``test-A/in.tsv`` and writes
one prediction per line to the matching ``out.tsv`` files.

Reconstructed from the ``run.py`` added by patch 6916e0e ("solution").
"""

from naivebayes import NaiveBayesTextClassifier
from spacy.lang.en.stop_words import STOP_WORDS as en_stop

# Number of training documents passed to the classifier per ``train`` call.
BATCH_SIZE = 20000


def _read_lines(path):
    """Return every line of the UTF-8 text file at *path*.

    Line terminators are kept, exactly as ``readlines`` yields them.
    """
    with open(path, 'r', encoding='utf8') as f:
        return f.readlines()


def _write_predictions(path, predictions):
    """Write each item of *predictions* to *path*, one per line."""
    # The ``with`` block closes the file; no explicit ``close()`` is needed.
    with open(path, 'wt', encoding='utf8') as f:
        f.writelines(str(p) + '\n' for p in predictions)


def _train_in_batches(classifier, documents, labels, batch_size=BATCH_SIZE):
    """Call ``classifier.train`` on consecutive slices of the training data.

    The number of batches is driven by ``len(labels)``; each slice holds at
    most *batch_size* items and never extends past the last label.
    NOTE(review): this presumes successive ``train`` calls accumulate rather
    than overwrite the model -- confirm against the ``naivebayes`` package.
    """
    total = len(labels)
    for start in range(0, total, batch_size):
        end = min(start + batch_size, total)
        classifier.train(documents[start:end], labels[start:end])


def _classify_file(classifier, in_path, out_path):
    """Classify the lines of *in_path* and write the results to *out_path*."""
    predictions = classifier.classify(_read_lines(in_path))
    _write_predictions(out_path, predictions)


def main():
    """Train on the ``train`` split, then predict ``dev-0`` and ``test-A``."""
    classifier = NaiveBayesTextClassifier(
        categories=[0, 1],
        stop_words=en_stop,
    )

    documents = _read_lines('train/in.tsv')
    # ``int`` tolerates the trailing newline left on each label line.
    labels = [int(line) for line in _read_lines('train/expected.tsv')]
    _train_in_batches(classifier, documents, labels)

    for split in ('dev-0', 'test-A'):
        _classify_file(classifier, split + '/in.tsv', split + '/out.tsv')


if __name__ == '__main__':
    main()