aitech-eks-pub/wyk/pytorch_regression/analyzer_classification.py

42 lines
1.0 KiB
Python
Raw Normal View History

2021-05-05 13:35:25 +02:00
import regex as re
from sklearn.feature_extraction.text import HashingVectorizer
import torch
# Tokens are truncated to at most this many leading characters
# (applied in Analyzer.__call__).
token_root_len = 7
class Analyzer(object):
    """Callable analyzer that turns a document into truncated tokens.

    A token is a maximal run of Unicode letters and/or digits. Every token
    is cut down to its first ``token_root_len`` characters.
    """

    def __init__(self):
        # \p{L} (any Unicode letter) is a feature of the third-party `regex`
        # package, which this module imports under the name `re`.
        self.token_pat = re.compile(r'(?:\p{L}|\d)+')

    def __call__(self, doc):
        """Return the list of truncated tokens found in *doc*, in order."""
        return [
            match.group()[:token_root_len]
            for match in self.token_pat.finditer(doc)
        ]
# Hyperparameter: number of hashed feature columns (2**18, i.e. an 18-bit hash).
vector_length = 2**18
# Module-level vectorizer shared by vectorize_text and vectorize_batch;
# tokenization is delegated to an Analyzer instance.
vectorizer = HashingVectorizer(n_features=vector_length, analyzer=Analyzer())
def vectorize_text(content):
    """Vectorize a single document and return it as a PyTorch tensor.

    The document is wrapped in a one-element list, passed through the
    module-level ``vectorizer``, densified, and converted to a tensor; the
    first (and only) row of that tensor is returned.
    """
    # Conversion chain: vectorizer output -> dense numpy array -> torch tensor.
    # HashingVectorizer is stateless, so fit_transform learns nothing here.
    hashed = vectorizer.fit_transform([content])
    dense = hashed.toarray()
    as_tensor = torch.from_numpy(dense)
    return as_tensor[0]
def vectorize_batch(contents):
    """Vectorize a collection of documents and return them as one tensor.

    ``contents`` is passed through the module-level ``vectorizer``; the
    result is densified and converted to a PyTorch tensor with one row per
    input document.
    """
    # Conversion chain: vectorizer output -> dense numpy array -> torch tensor.
    # HashingVectorizer is stateless, so fit_transform learns nothing here.
    hashed = vectorizer.fit_transform(contents)
    dense = hashed.toarray()
    return torch.from_numpy(dense)
def process_line(line):
    """Parse one ``label<TAB>content`` line into ``(content, target)``.

    Newline characters at both ends of *line* are stripped before it is
    split on tabs. The label is converted with ``float`` as-is (no scaling)
    and wrapped in a zero-dimensional tensor.

    Raises:
        ValueError: if the line does not contain exactly two tab-separated
            fields, or if the label cannot be parsed as a float.
    """
    label, content = line.strip('\n').split('\t')
    return content, torch.tensor(float(label))