guess-reddit-date-sumo/train_lrtfidf.py

30 lines
837 B
Python

#!/usr/bin/python3
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import TruncatedSVD
from sklearn.linear_model import LinearRegression
import numpy as np
import csv
import pandas as pd
import pickle
def train(*, max_rows=2_000_000, n_components=120, random_state=None):
    """Fit a TF-IDF -> TruncatedSVD -> LinearRegression model and pickle it.

    Reads one document per line from ``train/in_new.tsv`` and targets from
    ``train/expected.tsv``. The two files are paired row by row. At most
    ``max_rows`` rows are read from each. Three artifacts are written to the
    current directory:

    * ``clf.model`` -- the ``LinearRegression`` fitted on the reduced vectors;
    * ``vectorizer.model`` -- the fitted ``TfidfVectorizer``;
    * ``svd.model`` -- the fitted ``TruncatedSVD`` that maps TF-IDF vectors to
      ``n_components`` dimensions.

    The regressor is fitted on SVD output, not on raw TF-IDF vectors.
    Prediction must therefore apply ``vectorizer.transform``, then
    ``svd.transform``, then ``clf.predict``.

    Args:
        max_rows: Maximum number of rows to read from each input file.
        n_components: Output dimensionality of the SVD step.
        random_state: Seed forwarded to ``TruncatedSVD``. ``None`` (the
            default) keeps the original, unseeded behaviour.

    Raises:
        ValueError: If the text and target files yield different row counts.
    """
    # dtype=str / keep_default_na=False: keep every post as literal text, so
    # posts such as "NA", "null" or "123" are not turned into NaN or ints,
    # which TfidfVectorizer rejects. skip_blank_lines=False keeps empty
    # documents as rows, so row i of the text stays paired with row i of the
    # targets.
    texts = pd.read_csv(
        "train/in_new.tsv",
        delimiter="\t",
        header=None,
        names=["text"],
        quoting=csv.QUOTE_NONE,
        nrows=max_rows,
        dtype=str,
        keep_default_na=False,
        skip_blank_lines=False,
    )["text"].fillna("")

    targets = pd.read_csv("train/expected.tsv", header=None, nrows=max_rows)
    print(targets)

    # Fail fast, before the expensive TF-IDF/SVD fit, if the files disagree.
    if len(texts) != len(targets):
        raise ValueError(
            f"row count mismatch: {len(texts)} documents vs "
            f"{len(targets)} targets"
        )

    vectorizer = TfidfVectorizer(stop_words="english", ngram_range=(1, 1))
    tfidf = vectorizer.fit_transform(texts)

    svd = TruncatedSVD(n_components=n_components, random_state=random_state)
    reduced = svd.fit_transform(tfidf)

    reg = LinearRegression()
    reg.fit(reduced, targets)

    # The SVD must be saved too: without it the regressor's inputs cannot be
    # reconstructed at prediction time.
    artifacts = (
        ("clf.model", reg),
        ("vectorizer.model", vectorizer),
        ("svd.model", svd),
    )
    for path, obj in artifacts:
        with open(path, "wb") as fh:
            pickle.dump(obj, fh)
# Only train when executed as a script; importing this module (e.g. to reuse
# train()) must not start a full training run.
if __name__ == "__main__":
    train()