37 lines
966 B
Python
37 lines
966 B
Python
|
import pickle
|
||
|
import sys
|
||
|
import math
|
||
|
import fileinput
|
||
|
|
||
|
# Load the trained model from model.pkl. It must unpack into exactly four
# objects, which predict() uses as:
#   word_index  - maps a term to its index into `weights`
#   vocabulary  - container of terms that contribute to the score
#   weights     - indexable coefficients; weights[0] is the intercept
#   words_count - dict of term -> count (predict() also updates it)
# NOTE(review): presumably model.pkl is written by a separate training
# script; pickle.load can execute arbitrary code, so only load a model file
# from a trusted source.
# The `with` block guarantees the file handle is closed (the original
# `pickle.load(open(...))` leaked it).
with open("model.pkl", "rb") as model_file:
    model = pickle.load(model_file)

word_index, vocabulary, weights, words_count = model
|
||
|
|
||
|
def predict():
|
||
|
output = []
|
||
|
for line in fileinput.input(openhook=fileinput.hook_encoded("utf-8")):
|
||
|
line = line.rstrip()
|
||
|
fields = line.split('\t')
|
||
|
label = fields[0].strip()
|
||
|
document = fields[0]
|
||
|
terms = document.split(' ')
|
||
|
for term in terms:
|
||
|
if term in words_count:
|
||
|
words_count[term] += 1
|
||
|
else:
|
||
|
words_count[term] = 1
|
||
|
expected = weights[0]
|
||
|
for t in terms:
|
||
|
if t in vocabulary:
|
||
|
expected +=(words_count[t]/len(words_count)*(weights[word_index[t]]))
|
||
|
if expected > 0.9:
|
||
|
output.append(1)
|
||
|
else:
|
||
|
output.append(0)
|
||
|
|
||
|
with open("out.tsv", "w") as out:
|
||
|
for val in output:
|
||
|
out.write(str(val)+"\n")
|
||
|
|
||
|
# Run prediction only when executed as a script. Importing this module does
# not consume argv/stdin or overwrite out.tsv.
if __name__ == "__main__":
    predict()
|
||
|
|