nn Bigram
This commit is contained in:
parent
c1e6d53513
commit
f9a0b05308
@ -6,7 +6,7 @@ import pandas as pd
|
|||||||
from os.path import exists
|
from os.path import exists
|
||||||
|
|
||||||
from utils import read_csv, clean_text, get_words_from_line
|
from utils import read_csv, clean_text, get_words_from_line
|
||||||
from nn import Trigrams, Model
|
from nn import Bigrams, Model
|
||||||
|
|
||||||
data = read_csv("train/in.tsv.xz")
|
data = read_csv("train/in.tsv.xz")
|
||||||
train_words = read_csv("train/expected.tsv")
|
train_words = read_csv("train/expected.tsv")
|
||||||
@ -19,7 +19,7 @@ train_data = train_data.apply(clean_text)
|
|||||||
vocab_size = 30000
|
vocab_size = 30000
|
||||||
embed_size = 150
|
embed_size = 150
|
||||||
|
|
||||||
train_dataset = Trigrams(train_data, vocab_size)
|
train_dataset = Bigrams(train_data, vocab_size)
|
||||||
|
|
||||||
##################################################################################
|
##################################################################################
|
||||||
|
|
2
nn.py
2
nn.py
@ -4,7 +4,7 @@ from torchtext.vocab import build_vocab_from_iterator
|
|||||||
import itertools
|
import itertools
|
||||||
|
|
||||||
|
|
||||||
class Trigrams(torch.utils.data.IterableDataset):
|
class Bigrams(torch.utils.data.IterableDataset):
|
||||||
def __init__(self, data, vocabulary_size):
|
def __init__(self, data, vocabulary_size):
|
||||||
self.vocab = build_vocab_from_iterator(
|
self.vocab = build_vocab_from_iterator(
|
||||||
get_word_lines_from_data(data),
|
get_word_lines_from_data(data),
|
||||||
|
Loading…
Reference in New Issue
Block a user