Fix a little

This commit is contained in:
SzamanFL 2021-01-08 18:12:58 +01:00
parent 0db9211817
commit bdc1e902e8
2 changed files with 13 additions and 4 deletions

View File

@ -1,7 +1,7 @@
import torch
class NgramModel(torch.nn.Module):
def __init__(self, n_hidden=256, n_layers=3, drop_prob=0.3, lr=0.001, vocab_size):
def __init__(self, vocab_size, n_hidden=256, n_layers=3, drop_prob=0.3, lr=0.001):
super().__init__()
self.drop_prob = drop_prob

View File

@ -1,5 +1,7 @@
#!/usr/bin/env python
print("Imports")
import argparse
import re
import pickle
@ -23,7 +25,6 @@ def read_clear_data(in_file):
return texts
def create_ngrams(string, ngram_len=2):
print("Creating ngrams")
n_grams = []
if len(string.split()) > ngram_len:
for i in range(ngram_len, len(string.split())):
@ -35,8 +36,12 @@ def create_ngrams(string, ngram_len=2):
def get_ngrams(data, ngram_len=2):
print("Creating ngrams")
n_grams = []
counter = 0
for string in data:
n_grams.append(create_ngrams(sring))
n_grams.append(create_ngrams(string))
counter += 1
percentage = round((counter/len(data))*100, 2)
print(f"Status: {percentage}%", end='\r')
n_grams = sum(n_grams, [])
print("Created ngrams")
@ -48,6 +53,7 @@ def segment_data(n_grams):
target = []
for string in n_grams:
# tutaj brac pod uwage jescze follow slowa
source.append(" ".join(string.split()[:-1]))
target.append(" ".join(string.split()[1:]))
@ -62,6 +68,8 @@ def create_vocab(data):
for word in set(" ".join(data).split()):
vocab[counter] = word
counter += 1
percentage = round((counter/len(data))*100, 2)
print(f"Status: {percentage}%", end='\r')
vocab = {t:i for i,t in vocab.items()}
print("Vocab created")
@ -130,7 +138,7 @@ def train(net, source_int, target_int, seed, epochs=5, batch_size=32, lr=0.001,
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--in_file', '')
parser.add_argument('--in_file')
parser.add_argument('--ngram_level', default=2, help="Level of ngram")
parser.add_argument('--ngrams', help="Path to pickle with ready bigrams")
parser.add_argument('--vocab')
@ -159,6 +167,7 @@ def main():
print("Vocab read")
else:
vocab = create_vocab(data)
print(f"Vocab size: {len(vocab)}")
source_int, target_int = segment_with_vocab(vocab, target, source)
print("Saving progress")
with open(f"vocab-seed_{seed}.pickle", 'wb+') as f: