Fix a little
This commit is contained in:
parent
0db9211817
commit
bdc1e902e8
@ -1,7 +1,7 @@
|
||||
import torch
|
||||
|
||||
class NgramModel(torch.nn.Module):
|
||||
def __init__(self, n_hidden=256, n_layers=3, drop_prob=0.3, lr=0.001, vocab_size):
|
||||
def __init__(self, vocab_size, n_hidden=256, n_layers=3, drop_prob=0.3, lr=0.001):
|
||||
super().__init__()
|
||||
|
||||
self.drop_prob = drop_prob
|
||||
|
15
src/train.py
15
src/train.py
@ -1,5 +1,7 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
print("Imports")
|
||||
|
||||
import argparse
|
||||
import re
|
||||
import pickle
|
||||
@ -23,7 +25,6 @@ def read_clear_data(in_file):
|
||||
return texts
|
||||
|
||||
def create_ngrams(string, ngram_len=2):
|
||||
print("Creating ngrams")
|
||||
n_grams = []
|
||||
if len(string.split()) > ngram_len:
|
||||
for i in range(ngram_len, len(string.split())):
|
||||
@ -35,8 +36,12 @@ def create_ngrams(string, ngram_len=2):
|
||||
def get_ngrams(data, ngram_len=2):
|
||||
print("Creating ngrams")
|
||||
n_grams = []
|
||||
counter = 0
|
||||
for string in data:
|
||||
n_grams.append(create_ngrams(sring))
|
||||
n_grams.append(create_ngrams(string))
|
||||
counter += 1
|
||||
percentage = round((counter/len(data))*100, 2)
|
||||
print(f"Status: {percentage}%", end='\r')
|
||||
|
||||
n_grams = sum(n_grams, [])
|
||||
print("Created ngrams")
|
||||
@ -48,6 +53,7 @@ def segment_data(n_grams):
|
||||
target = []
|
||||
|
||||
for string in n_grams:
|
||||
# here, also take the following words into account
|
||||
source.append(" ".join(string.split()[:-1]))
|
||||
target.append(" ".join(string.split()[1:]))
|
||||
|
||||
@ -62,6 +68,8 @@ def create_vocab(data):
|
||||
for word in set(" ".join(data).split()):
|
||||
vocab[counter] = word
|
||||
counter += 1
|
||||
percentage = round((counter/len(data))*100, 2)
|
||||
print(f"Status: {percentage}%", end='\r')
|
||||
|
||||
vocab = {t:i for i,t in vocab.items()}
|
||||
print("Vocab created")
|
||||
@ -130,7 +138,7 @@ def train(net, source_int, target_int, seed, epochs=5, batch_size=32, lr=0.001,
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--in_file', '')
|
||||
parser.add_argument('--in_file')
|
||||
parser.add_argument('--ngram_level', default=2, help="Level of ngram")
|
||||
parser.add_argument('--ngrams', help="Path to pickle with ready bigrams")
|
||||
parser.add_argument('--vocab')
|
||||
@ -159,6 +167,7 @@ def main():
|
||||
print("Vocab read")
|
||||
else:
|
||||
vocab = create_vocab(data)
|
||||
print(f"Vocab size: {len(vocab)}")
|
||||
source_int, target_int = segment_with_vocab(vocab, target, source)
|
||||
print("Saving progress")
|
||||
with open(f"vocab-seed_{seed}.pickle", 'wb+') as f:
|
||||
|
Loading…
Reference in New Issue
Block a user