diff --git a/src/Model.py b/src/Model.py
index a2c6010..632ec4b 100644
--- a/src/Model.py
+++ b/src/Model.py
@@ -1,7 +1,7 @@
 import torch
 
 class NgramModel(torch.nn.Module):
-    def __init__(self, n_hidden=256, n_layers=3, drop_prob=0.3, lr=0.001, vocab_size):
+    def __init__(self, vocab_size, n_hidden=256, n_layers=3, drop_prob=0.3, lr=0.001):
         super().__init__()
 
         self.drop_prob = drop_prob
diff --git a/src/train.py b/src/train.py
index c98f796..93f8116 100644
--- a/src/train.py
+++ b/src/train.py
@@ -1,5 +1,7 @@
 #!/usr/bin/env python
 
+print("Imports")
+
 import argparse
 import re
 import pickle
@@ -23,7 +25,6 @@ def read_clear_data(in_file):
     return texts
 
 def create_ngrams(string, ngram_len=2):
-    print("Creating ngrams")
     n_grams = []
     if len(string.split()) > ngram_len:
         for i in range(ngram_len, len(string.split())):
@@ -35,8 +36,12 @@ def create_ngrams(string, ngram_len=2):
 def get_ngrams(data, ngram_len=2):
     print("Creating ngrams")
     n_grams = []
+    counter = 0
     for string in data:
-        n_grams.append(create_ngrams(sring))
+        n_grams.append(create_ngrams(string))
+        counter += 1
+        percentage = round((counter/len(data))*100, 2)
+        print(f"Status: {percentage}%", end='\r')
 
     n_grams = sum(n_grams, [])
     print("Created ngrams")
@@ -48,6 +53,7 @@ def segment_data(n_grams):
     target = []
 
     for string in n_grams:
+        # TODO: also take the following words into account here
         source.append(" ".join(string.split()[:-1]))
         target.append(" ".join(string.split()[1:]))
 
@@ -62,6 +68,8 @@ def create_vocab(data):
     for word in set(" ".join(data).split()):
         vocab[counter] = word
         counter += 1
+        # progress as a running word count (len(data) counts documents, not words)
+        print(f"Status: {counter} words", end='\r')
 
     vocab = {t:i for i,t in vocab.items()}
     print("Vocab created")
@@ -130,7 +138,7 @@ def train(net, source_int, target_int, seed, epochs=5, batch_size=32, lr=0.001,
 
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--in_file', '')
+    parser.add_argument('--in_file')
     parser.add_argument('--ngram_level', default=2, help="Level of ngram")
     parser.add_argument('--ngrams', help="Path to pickle with ready bigrams")
     parser.add_argument('--vocab')
@@ -159,6 +167,7 @@ def main():
         print("Vocab read")
     else:
         vocab = create_vocab(data)
+        print(f"Vocab size: {len(vocab)}")
     source_int, target_int = segment_with_vocab(vocab, target, source)
     print("Saving progress")
     with open(f"vocab-seed_{seed}.pickle", 'wb+') as f: