GuessRedditDateSumo/predict.py

40 lines
1.1 KiB
Python
Raw Normal View History

2020-05-05 14:17:49 +02:00
import pickle
import numpy as np
from sklearn.decomposition import PCA
from linear_regression import create_dictionary
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import TruncatedSVD
def _load_pickle(path):
    """Return the object unpickled from *path*, closing the file afterwards."""
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(path, "rb") as model_file:
        return pickle.load(model_file)


def _write_predictions(path, predictions):
    """Write *predictions* to *path* via ``np.savetxt`` and close the file."""
    with open(path, "w", encoding="UTF-8") as output:
        np.savetxt(output, np.array(predictions))


def predict():
    """Generate predictions for the dev-0 and test-A inputs.

    Loads the pickled objects from ``l_regression.pkl`` and
    ``tfidf_model.pkl``, builds the inputs with ``create_dictionary`` from
    ``dev-0/in.tsv`` and ``test-A/in.tsv``, vectorizes them, reduces them to
    300 components with ``TruncatedSVD`` and writes the regressor's
    predictions to ``dev-0/out.tsv`` and ``test-A/out.tsv``. The dev-0
    predictions are also printed to stdout.

    All files are opened in ``with`` blocks so they are closed (and the
    outputs flushed) even if a later step raises.
    """
    l_regression = _load_pickle("l_regression.pkl")
    tfidf = _load_pickle("tfidf_model.pkl")

    dev0 = create_dictionary("dev-0/in.tsv")
    testA = create_dictionary("test-A/in.tsv")

    # NOTE(review): fit_transform re-fits the loaded TF-IDF model on each
    # dataset, so the learned vocabulary/IDF from the pickle is discarded.
    # Presumably this should be tfidf.transform(...) -- confirm against the
    # training script before changing, since it alters the outputs.
    dev0_vector = tfidf.fit_transform(dev0)
    testA_vector = tfidf.fit_transform(testA)

    # NOTE(review): the SVD is created here and fitted separately on dev and
    # test, so its components are not the ones the regressor was trained on
    # and results vary between runs (no random_state). Presumably the SVD
    # fitted at training time should be persisted and loaded -- TODO confirm.
    pca = TruncatedSVD(n_components=300)
    dev0_pca = pca.fit_transform(dev0_vector)
    testA_pca = pca.fit_transform(testA_vector)

    y_dev = l_regression.predict(dev0_pca)
    print(y_dev)
    _write_predictions("dev-0/out.tsv", y_dev)

    y_test = l_regression.predict(testA_pca)
    _write_predictions("test-A/out.tsv", y_test)
# Run the prediction pipeline whenever this module is executed; there is no
# __main__ guard, so importing the module triggers it as well.
predict()