Compare commits
2 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
e60a4b57db | ||
|
c98d437e47 |
5272
dev-0/out.tsv
Normal file
5272
dev-0/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
46
main.py
Normal file
46
main.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
import lzma
|
||||||
|
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||||
|
from sklearn.naive_bayes import BernoulliNB
|
||||||
|
import numpy as np
|
||||||
|
import csv
|
||||||
|
|
||||||
|
|
||||||
|
def readInput(dir):
|
||||||
|
X = []
|
||||||
|
if 'xz' in dir:
|
||||||
|
with lzma.open(dir) as f:
|
||||||
|
for line in f:
|
||||||
|
X.append(line.decode('utf-8'))
|
||||||
|
else:
|
||||||
|
with open(dir) as f:
|
||||||
|
for line in f:
|
||||||
|
X. append(line.replace('\n',''))
|
||||||
|
return X
|
||||||
|
|
||||||
|
def writeOutput(output, dir):
|
||||||
|
with open(dir, 'w', newline='') as f:
|
||||||
|
writer = csv.writer(f)
|
||||||
|
writer.writerows(output)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# train
|
||||||
|
trainX = readInput('train/in.tsv.xz')
|
||||||
|
trainY = readInput('train/expected.tsv')
|
||||||
|
vectorizer = TfidfVectorizer()
|
||||||
|
trainX = vectorizer.fit_transform(trainX)
|
||||||
|
trainY = np.array(trainY)
|
||||||
|
bernoulli = BernoulliNB()
|
||||||
|
bernoulli.fit(trainX, trainY)
|
||||||
|
|
||||||
|
# dev-0
|
||||||
|
devX = readInput('dev-0/in.tsv.xz')
|
||||||
|
devX = vectorizer.transform(devX)
|
||||||
|
devPredicted = bernoulli.predict(devX)
|
||||||
|
writeOutput(devPredicted, 'dev-0/out.tsv')
|
||||||
|
|
||||||
|
# test-A
|
||||||
|
testX = readInput('test-A/in.tsv.xz')
|
||||||
|
testX = vectorizer.transform(testX)
|
||||||
|
testPredicted = bernoulli.predict(testX)
|
||||||
|
writeOutput(testPredicted, 'test-A/out.tsv')
|
5152
test-A/out.tsv
Normal file
5152
test-A/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user