Compare commits

...

2 Commits

Author SHA1 Message Date
205338aa25 Add predictions 2022-05-10 23:56:28 +02:00
5bb9042abf Add initial implementation 2022-05-10 23:54:10 +02:00
3 changed files with 10482 additions and 0 deletions

5272
dev-0/out.tsv Normal file

File diff suppressed because it is too large Load Diff

58
run.py Normal file
View File

@ -0,0 +1,58 @@
import lzma
from naivebayes import NaiveBayesTextClassifier
from spacy.lang.en.stop_words import STOP_WORDS as en_stop
def get_data(fname):
with open(fname, 'r', encoding='utf8') as file:
return file.readlines()
def get_data_zipped(fname):
with lzma.open(fname, 'r') as file:
return file.readlines()
def train_bayes(model, x, y, step=15000):
start = 0
end = step
for _ in range(0, len(x), step):
model.train(x[start:end], y[start:end])
if start + step < len(x):
start += step
else:
start = 0
end = min(start + step, len(x))
def write_file(fname, data):
with open(fname, 'wt') as f:
for d in data:
f.write(f'{str(d)}\n')
def main():
train_x = get_data_zipped('train/in.tsv.xz')
train_y = get_data('train/expected.tsv')
# preprocessing
train_y = [int(y) for y in train_y]
test_x = get_data_zipped('test-A/in.tsv.xz')
dev_x = get_data_zipped('dev-0/in.tsv.xz')
model = NaiveBayesTextClassifier(
categories=[0, 1],
stop_words=en_stop
)
train_bayes(model, train_x, train_y)
predicted = model.classify(dev_x)
predicted_2= model.classify(test_x)
write_file('dev-0/out.tsv', predicted)
write_file('test-A/out.tsv', predicted_2)
main()

5152
test-A/out.tsv Normal file

File diff suppressed because it is too large Load Diff