47 lines
1.8 KiB
Python
47 lines
1.8 KiB
Python
#!/usr/bin/env python
|
|
# coding: utf-8
|
|
|
|
from sklearn.naive_bayes import MultinomialNB
|
|
#from sklearn.metrics import accuracy_score
|
|
from sklearn.feature_extraction.text import CountVectorizer
|
|
import lzma
|
|
import re
|
|
|
|
# Load every split as a list of raw lines (readlines() keeps each trailing
# newline). Context managers close each handle deterministically instead of
# leaving that to the garbage collector.
with lzma.open("train/in.tsv.xz", mode='rt', encoding='utf-8') as fh:
    X_train_raw = fh.readlines()

# Label lines keep their "\n"; the classifier is fitted on these exact strings,
# so its predictions can later be written out verbatim with writelines().
# Labels are left on the platform default encoding so that they round-trip
# unchanged when written back.
with open('train/expected.tsv') as fh:
    y_train_raw = fh.readlines()

# The dev/test text is decoded as UTF-8 to match the training text above;
# otherwise tokenization could differ between splits on platforms whose
# default encoding is not UTF-8.
with open("dev-0/in.tsv", "r", encoding='utf-8') as fh:
    X_dev0_raw = fh.readlines()

with open("dev-0/expected.tsv", "r") as fh:
    y_expected_dev0_raw = fh.readlines()

with open("dev-1/in.tsv", "r", encoding='utf-8') as fh:
    X_dev1_raw = fh.readlines()

with open("dev-1/expected.tsv", "r") as fh:
    y_expected_dev1_raw = fh.readlines()

with open("test-A/in.tsv", "r", encoding='utf-8') as fh:
    X_test_raw = fh.readlines()
|
|
|
|
# Matches a line ending of the form
#   <TAB>[not-]for-humans<TAB>[not-]contaminated<NEWLINE>
# Compiled once as a raw string instead of repeating the literal three times.
_TRAILING_COLUMNS = re.compile(r'\t(not-)?for-humans\t(not-)?contaminated\n')


def _strip_trailing_columns(lines):
    """Return a new list with the trailing two-column suffix removed.

    Lines that do not end with the suffix are returned unchanged,
    including their newline.
    """
    return [_TRAILING_COLUMNS.sub('', line) for line in lines]


# NOTE(review): only dev/test inputs are cleaned; X_train_raw is vectorized
# as-is. Presumably the training input has no such columns -- confirm.
X_dev0_cleaned = _strip_trailing_columns(X_dev0_raw)
X_dev1_cleaned = _strip_trailing_columns(X_dev1_raw)
X_test_cleaned = _strip_trailing_columns(X_test_raw)
|
|
|
|
# Token-count features: the vectorizer is fitted on the training lines only,
# then the same fitted vectorizer transforms every other split.
count_vect = CountVectorizer()
X_train_counts = count_vect.fit_transform(X_train_raw)

X_dev0_counts, X_dev1_counts, X_test_counts = [
    count_vect.transform(split_lines)
    for split_lines in (X_dev0_cleaned, X_dev1_cleaned, X_test_cleaned)
]
|
|
|
|
# Fit a multinomial Naive Bayes classifier on the training counts; the
# targets are the raw label lines read from train/expected.tsv.
clf2 = MultinomialNB()
clf2.fit(X_train_counts, y_train_raw)

# Predict labels for each evaluation split, in dev-0, dev-1, test-A order.
y_predicted_dev0_MNB, y_predicted_dev1_MNB, y_predicted_test_MNB = [
    clf2.predict(split_counts)
    for split_counts in (X_dev0_counts, X_dev1_counts, X_test_counts)
]
|
|
|
|
# accuracy_dev0_MNB = accuracy_score(y_expected_dev0_raw, y_predicted_dev0_MNB)
|
|
# print(f"Accuracy dev0: {accuracy_dev0_MNB}")
|
|
# accuracy_dev1_MNB = accuracy_score(y_expected_dev1_raw, y_predicted_dev1_MNB)
|
|
# print(f"Accuracy dev1: {accuracy_dev1_MNB}")
|
|
|
|
# Write one prediction file per split. The `with` block closes (and therefore
# flushes) each file before moving on, rather than relying on garbage
# collection of an anonymous handle.
# writelines() adds no separators: the predicted labels are the label lines
# the classifier was trained on, which normally already end with "\n".
for out_path, predictions in (
    ("dev-0/out.tsv", y_predicted_dev0_MNB),
    ("dev-1/out.tsv", y_predicted_dev1_MNB),
    ("test-A/out.tsv", y_predicted_test_MNB),
):
    with open(out_path, mode='w') as out_file:
        out_file.writelines(predictions)
|
|
|
|
|
|
|
|
|
|
|