Compare commits
2 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
e60a4b57db | ||
|
c98d437e47 |
5272
dev-0/out.tsv
Normal file
5272
dev-0/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
46
main.py
Normal file
46
main.py
Normal file
@ -0,0 +1,46 @@
|
||||
import lzma
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.naive_bayes import BernoulliNB
|
||||
import numpy as np
|
||||
import csv
|
||||
|
||||
|
||||
def readInput(dir):
|
||||
X = []
|
||||
if 'xz' in dir:
|
||||
with lzma.open(dir) as f:
|
||||
for line in f:
|
||||
X.append(line.decode('utf-8'))
|
||||
else:
|
||||
with open(dir) as f:
|
||||
for line in f:
|
||||
X. append(line.replace('\n',''))
|
||||
return X
|
||||
|
||||
def writeOutput(output, dir):
|
||||
with open(dir, 'w', newline='') as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerows(output)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# train
|
||||
trainX = readInput('train/in.tsv.xz')
|
||||
trainY = readInput('train/expected.tsv')
|
||||
vectorizer = TfidfVectorizer()
|
||||
trainX = vectorizer.fit_transform(trainX)
|
||||
trainY = np.array(trainY)
|
||||
bernoulli = BernoulliNB()
|
||||
bernoulli.fit(trainX, trainY)
|
||||
|
||||
# dev-0
|
||||
devX = readInput('dev-0/in.tsv.xz')
|
||||
devX = vectorizer.transform(devX)
|
||||
devPredicted = bernoulli.predict(devX)
|
||||
writeOutput(devPredicted, 'dev-0/out.tsv')
|
||||
|
||||
# test-A
|
||||
testX = readInput('test-A/in.tsv.xz')
|
||||
testX = vectorizer.transform(testX)
|
||||
testPredicted = bernoulli.predict(testX)
|
||||
writeOutput(testPredicted, 'test-A/out.tsv')
|
5152
test-A/out.tsv
Normal file
5152
test-A/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user