Naive bayes with sklearn
This commit is contained in:
parent
9cb2fb2612
commit
36f9bbfa9f
5452
dev-0/out.tsv
Normal file
5452
dev-0/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
52
main.py
Normal file
52
main.py
Normal file
@ -0,0 +1,52 @@
|
||||
import pandas as pd
|
||||
from sklearn.naive_bayes import MultinomialNB
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
|
||||
|
||||
def main():
|
||||
clf = MultinomialNB()
|
||||
vectorizer = TfidfVectorizer()
|
||||
x_train, y_train = get_training_data(vectorizer)
|
||||
clf.fit(x_train, y_train)
|
||||
|
||||
x_dev = get_dev_data(vectorizer)
|
||||
Y_dev_predicted = clf.predict(x_dev)
|
||||
save_to_tsv(Y_dev_predicted, 'dev-0/out.tsv')
|
||||
|
||||
x_test = get_test_data(vectorizer)
|
||||
Y_test_predicted = clf.predict(x_test)
|
||||
save_to_tsv(Y_test_predicted, 'test-A/out.tsv')
|
||||
|
||||
|
||||
def save_to_tsv(data, path):
|
||||
pd.DataFrame(data).to_csv(path, sep='\t', index=False, header=False)
|
||||
|
||||
|
||||
def get_training_data(vectorizer):
|
||||
train_dataset = pd.read_csv('train/train.tsv/train.tsv', sep='\t', header=None, error_bad_lines=False)
|
||||
|
||||
y_train = train_dataset[0]
|
||||
X_train = train_dataset[1]
|
||||
x_train = [str(item) for item in X_train.to_numpy()]
|
||||
|
||||
x_train = vectorizer.fit_transform(x_train)
|
||||
|
||||
return x_train, y_train
|
||||
|
||||
|
||||
def get_dev_data(vectorizer):
|
||||
dev_dataset = pd.read_csv('dev-0/in.tsv', sep='\t', header=None, error_bad_lines=False)
|
||||
X_dev = [str(item) for item in dev_dataset.to_numpy()]
|
||||
|
||||
return vectorizer.transform(X_dev)
|
||||
|
||||
|
||||
def get_test_data(vectorizer):
|
||||
test_dataset = pd.read_csv('test-A/in.tsv', sep='\t', header=None, error_bad_lines=False)
|
||||
X_test = [str(item) for item in test_dataset.to_numpy()]
|
||||
|
||||
return vectorizer.transform(X_test)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
5445
test-A/out.tsv
Normal file
5445
test-A/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
98132
train/train.tsv/train.tsv
Normal file
98132
train/train.tsv/train.tsv
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user