nn Bigram

This commit is contained in:
Dominik Strzałko 2022-05-09 10:51:57 +02:00
parent c1e6d53513
commit f9a0b05308
2 changed files with 3 additions and 3 deletions

View File

@@ -6,7 +6,7 @@ import pandas as pd
from os.path import exists from os.path import exists
from utils import read_csv, clean_text, get_words_from_line from utils import read_csv, clean_text, get_words_from_line
from nn import Trigrams, Model from nn import Bigrams, Model
data = read_csv("train/in.tsv.xz") data = read_csv("train/in.tsv.xz")
train_words = read_csv("train/expected.tsv") train_words = read_csv("train/expected.tsv")
@@ -19,7 +19,7 @@ train_data = train_data.apply(clean_text)
vocab_size = 30000 vocab_size = 30000
embed_size = 150 embed_size = 150
train_dataset = Trigrams(train_data, vocab_size) train_dataset = Bigrams(train_data, vocab_size)
################################################################################## ##################################################################################

2
nn.py
View File

@@ -4,7 +4,7 @@ from torchtext.vocab import build_vocab_from_iterator
import itertools import itertools
class Trigrams(torch.utils.data.IterableDataset): class Bigrams(torch.utils.data.IterableDataset):
def __init__(self, data, vocabulary_size): def __init__(self, data, vocabulary_size):
self.vocab = build_vocab_from_iterator( self.vocab = build_vocab_from_iterator(
get_word_lines_from_data(data), get_word_lines_from_data(data),