paranormal-or-skeptic3/predict.py

32 lines
694 B
Python
Raw Normal View History

2020-12-15 16:40:10 +01:00
import pickle
import sys
import torch
from tokenizator import tokenize
def get_x(line, weights, mapping):
    """Return the bag-of-words feature vector for ``line``.

    The vector has one slot per entry of ``weights``. Its final slot is set
    to 1 before any counting. This is presumably a bias/intercept input paired
    with the last weight; confirm against the training code. Each token
    produced by ``tokenize(line)`` that is a key of ``mapping`` then adds 1
    to the slot ``mapping[token]``. Tokens absent from ``mapping`` are
    skipped.

    Args:
        line: raw text to featurize.
        weights: model weights; only ``len(weights)`` is used here, to size
            the vector.
        mapping: lookup from token to slot index.

    Returns:
        A ``torch.float`` (float32) tensor of length ``len(weights)``.
    """
    tokens = tokenize(line)
    size = len(weights)
    counts = [0.] * size
    # Seed the last slot with 1 first; any token mapped to it adds on top.
    counts[size - 1] = 1
    known_tokens = (token for token in tokens if token in mapping)
    for token in known_tokens:
        counts[mapping[token]] += 1
    return torch.tensor(counts, dtype=torch.float)
def main(model_path='model.pkl', lower=0.15, upper=0.85):
    """Score every line on standard input and print one probability per line.

    Loads ``(w, word_to_index_mapping)`` from the pickle at ``model_path``.
    For each stdin line (stripped, empty lines included), it builds features
    with ``get_x`` and computes ``sigmoid(x @ w)``. It then clamps the result
    into ``[lower, upper]`` and prints it as a Python float.

    Args:
        model_path: pickle written at training time. Defaults to the
            previously hard-coded ``'model.pkl'``.
        lower: smallest probability ever printed (default 0.15, as before).
        upper: largest probability ever printed (default 0.85, as before).
    """
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only load model files you produced yourself.
    # The ``with`` block closes the file handle, which the original leaked.
    with open(model_path, 'rb') as model_file:
        w, word_to_index_mapping = pickle.load(model_file)
    for line in sys.stdin:
        x = get_x(line.strip(), w, word_to_index_mapping)
        # Keep predictions away from 0/1. Presumably this bounds the cost of
        # confident mistakes under a log-loss style metric; confirm.
        # torch.clamp propagates NaN, matching the old if/elif behaviour.
        y = torch.clamp(torch.sigmoid(x @ w), min=lower, max=upper)
        print(y.item())
# Run only when executed as a script, so importing this module has no
# side effects (no model load, no stdin read).
if __name__ == "__main__":
    main()