nn Bigram
This commit is contained in:
parent
c1e6d53513
commit
f9a0b05308
@ -6,7 +6,7 @@ import pandas as pd
|
||||
from os.path import exists
|
||||
|
||||
from utils import read_csv, clean_text, get_words_from_line
|
||||
from nn import Trigrams, Model
|
||||
from nn import Bigrams, Model
|
||||
|
||||
data = read_csv("train/in.tsv.xz")
|
||||
train_words = read_csv("train/expected.tsv")
|
||||
@ -19,7 +19,7 @@ train_data = train_data.apply(clean_text)
|
||||
vocab_size = 30000
|
||||
embed_size = 150
|
||||
|
||||
train_dataset = Trigrams(train_data, vocab_size)
|
||||
train_dataset = Bigrams(train_data, vocab_size)
|
||||
|
||||
##################################################################################
|
||||
|
2
nn.py
2
nn.py
@ -4,7 +4,7 @@ from torchtext.vocab import build_vocab_from_iterator
|
||||
import itertools
|
||||
|
||||
|
||||
class Trigrams(torch.utils.data.IterableDataset):
|
||||
class Bigrams(torch.utils.data.IterableDataset):
|
||||
def __init__(self, data, vocabulary_size):
|
||||
self.vocab = build_vocab_from_iterator(
|
||||
get_word_lines_from_data(data),
|
||||
|
Loading…
Reference in New Issue
Block a user