Update run.py

This commit is contained in:
Iwona Christop 2022-04-25 01:30:11 +02:00
parent 5621e4ce9a
commit 61a9a4632a
2 changed files with 48 additions and 0 deletions

BIN
._main.ipynb Executable file

Binary file not shown.

48
run.py
View File

@ -0,0 +1,48 @@
import lzma
# import re
from sklearn.feature_extraction.text import CountVectorizer
import csv
from sklearn.naive_bayes import GaussianNB
# def get_str_cleaned(str_dirty):
# punctuation = '!"#$%&\'()*+,-./:;<=>?@[\\\\]^_`{|}~'
# new_str = str_dirty.lower()
# new_str = re.sub(' +', ' ', new_str)
# for char in punctuation:
# new_str = new_str.replace(char, '')
# new_str = new_str.replace('\n', '')
# return new_str
# with open('train/expected.tsv') as f:
# trainY = list(csv.reader(f))
def read_lines(path):
    """Return the lines of a UTF-8 text file without their line terminators.

    Paths ending in ``.xz`` are decompressed transparently with ``lzma``;
    any other path is opened as a plain text file. Opening both kinds in
    text mode means every caller receives ``str`` and never has to decode
    bytes by hand.

    Args:
        path: Location of the file to read.

    Returns:
        A list with one ``str`` per line, in file order. Empty lines are
        kept so that inputs stay aligned with their expected labels.
    """
    opener = lzma.open if str(path).endswith('.xz') else open
    with opener(path, mode='rt', encoding='utf-8') as f:
        # Strip the newline so labels read as '1', not '1\n'.
        return [line.rstrip('\n') for line in f]


def main():
    """Train a Gaussian naive Bayes text classifier and print dev-0 predictions.

    Reads ``train/in.tsv.xz`` and ``train/expected.tsv``, fits a bag-of-words
    vectorizer and the classifier on them, then predicts a label for every
    line of ``dev-0/in.tsv`` and prints the resulting array.
    """
    train_texts = read_lines('train/in.tsv.xz')
    train_labels = read_lines('train/expected.tsv')

    vectorizer = CountVectorizer()
    train_features = vectorizer.fit_transform(train_texts)

    model = GaussianNB()
    # GaussianNB only accepts dense arrays, while CountVectorizer produces a
    # scipy sparse matrix, so the features are densified before fitting.
    # NOTE(review): this allocates n_samples x vocabulary_size values; if it
    # does not fit in memory, consider a sparse-capable classifier such as
    # MultinomialNB instead.
    model.fit(train_features.toarray(), train_labels)

    dev_texts = read_lines('dev-0/in.tsv')
    # The dev texts must go through the vocabulary fitted on the training
    # set; the classifier cannot consume raw strings.
    dev_features = vectorizer.transform(dev_texts)
    predicted_labels = model.predict(dev_features.toarray())
    print(predicted_labels)


if __name__ == '__main__':
    main()
# with open('dev-0/expected.tsv') as f:
# expectedY = list(csv.reader(f))