paranormal-or-skeptic-ISI-p.../classifier.py

42 lines
1.4 KiB
Python
Raw Normal View History

2022-05-07 21:32:19 +02:00
import lzma
from naivebayes import NaiveBayesTextClassifier
import nltk
from nltk.corpus import stopwords
# Fetch the NLTK stop-word resource that stopwords.words("english") reads below.
nltk.download("stopwords")

# Read train files: every input line is stripped and lower-cased, and every
# line of the expected file is parsed as an integer label.
with lzma.open("train/in.tsv.xz", "rt", encoding="utf-8") as train_file:
    x_train = [line.strip().lower() for line in train_file]
with open("train/expected.tsv", "r", encoding="utf-8") as train_file:
    y_train = [int(line.strip()) for line in train_file]

# Texts and labels are paired by position, so a count mismatch would train on
# misaligned data; fail loudly instead of silently truncating one side.
if len(x_train) != len(y_train):
    raise ValueError(
        "train/in.tsv.xz has {} rows but train/expected.tsv has {} labels".format(
            len(x_train), len(y_train)
        )
    )

nbc = NaiveBayesTextClassifier(
    categories=[0, 1],
    stop_words=stopwords.words("english"),
    min_df=1,
)

# Train in fixed-size batches. Slicing already clamps at the end of the list,
# so the final (shorter) batch needs no explicit min().
step = 15000
for i in range(0, len(x_train), step):
    nbc.train(x_train[i:i + step], y_train[i:i + step])
# Load the compressed dev-0 input; each line is stripped and lower-cased.
with lzma.open("dev-0/in.tsv.xz", "rt", encoding="utf-8") as dev_in:
    x_dev = [row.strip().lower() for row in dev_in]
# Read test file: each line of the compressed test-A input is stripped and
# lower-cased.
with lzma.open("test-A/in.tsv.xz", "rt", encoding="utf-8") as test_file:
    x_test = [line.strip().lower() for line in test_file]
# Classify the dev-0 inputs and write one predicted label per line.
dev_labels = nbc.classify(x_dev)
pred_dev = [str(label) + "\n" for label in dev_labels]
with open("dev-0/out.tsv", "w", encoding="utf-8") as dev_out:
    dev_out.write("".join(pred_dev))
# Classify the test-A inputs and write one predicted label per line.
test_labels = nbc.classify(x_test)
pred_test = [str(label) + "\n" for label in test_labels]
with open("test-A/out.tsv", "w", encoding="utf-8") as test_out:
    test_out.write("".join(pred_test))