Fix a little
This commit is contained in:
parent
0db9211817
commit
bdc1e902e8
@ -1,7 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
class NgramModel(torch.nn.Module):
|
class NgramModel(torch.nn.Module):
|
||||||
def __init__(self, n_hidden=256, n_layers=3, drop_prob=0.3, lr=0.001, vocab_size):
|
def __init__(self, vocab_size, n_hidden=256, n_layers=3, drop_prob=0.3, lr=0.001):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.drop_prob = drop_prob
|
self.drop_prob = drop_prob
|
||||||
|
15
src/train.py
15
src/train.py
@ -1,5 +1,7 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
print("Imports")
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import re
|
import re
|
||||||
import pickle
|
import pickle
|
||||||
@ -23,7 +25,6 @@ def read_clear_data(in_file):
|
|||||||
return texts
|
return texts
|
||||||
|
|
||||||
def create_ngrams(string, ngram_len=2):
|
def create_ngrams(string, ngram_len=2):
|
||||||
print("Creating ngrams")
|
|
||||||
n_grams = []
|
n_grams = []
|
||||||
if len(string.split()) > ngram_len:
|
if len(string.split()) > ngram_len:
|
||||||
for i in range(ngram_len, len(string.split())):
|
for i in range(ngram_len, len(string.split())):
|
||||||
@ -35,8 +36,12 @@ def create_ngrams(string, ngram_len=2):
|
|||||||
def get_ngrams(data, ngram_len=2):
|
def get_ngrams(data, ngram_len=2):
|
||||||
print("Creating ngrams")
|
print("Creating ngrams")
|
||||||
n_grams = []
|
n_grams = []
|
||||||
|
counter = 0
|
||||||
for string in data:
|
for string in data:
|
||||||
n_grams.append(create_ngrams(sring))
|
n_grams.append(create_ngrams(string))
|
||||||
|
counter += 1
|
||||||
|
percentage = round((counter/len(data))*100, 2)
|
||||||
|
print(f"Status: {percentage}%", end='\r')
|
||||||
|
|
||||||
n_grams = sum(n_grams, [])
|
n_grams = sum(n_grams, [])
|
||||||
print("Created ngrams")
|
print("Created ngrams")
|
||||||
@ -48,6 +53,7 @@ def segment_data(n_grams):
|
|||||||
target = []
|
target = []
|
||||||
|
|
||||||
for string in n_grams:
|
for string in n_grams:
|
||||||
|
# tutaj brac pod uwage jescze follow slowa
|
||||||
source.append(" ".join(string.split()[:-1]))
|
source.append(" ".join(string.split()[:-1]))
|
||||||
target.append(" ".join(string.split()[1:]))
|
target.append(" ".join(string.split()[1:]))
|
||||||
|
|
||||||
@ -62,6 +68,8 @@ def create_vocab(data):
|
|||||||
for word in set(" ".join(data).split()):
|
for word in set(" ".join(data).split()):
|
||||||
vocab[counter] = word
|
vocab[counter] = word
|
||||||
counter += 1
|
counter += 1
|
||||||
|
percentage = round((counter/len(data))*100, 2)
|
||||||
|
print(f"Status: {percentage}%", end='\r')
|
||||||
|
|
||||||
vocab = {t:i for i,t in vocab.items()}
|
vocab = {t:i for i,t in vocab.items()}
|
||||||
print("Vocab created")
|
print("Vocab created")
|
||||||
@ -130,7 +138,7 @@ def train(net, source_int, target_int, seed, epochs=5, batch_size=32, lr=0.001,
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--in_file', '')
|
parser.add_argument('--in_file')
|
||||||
parser.add_argument('--ngram_level', default=2, help="Level of ngram")
|
parser.add_argument('--ngram_level', default=2, help="Level of ngram")
|
||||||
parser.add_argument('--ngrams', help="Path to pickle with ready bigrams")
|
parser.add_argument('--ngrams', help="Path to pickle with ready bigrams")
|
||||||
parser.add_argument('--vocab')
|
parser.add_argument('--vocab')
|
||||||
@ -159,6 +167,7 @@ def main():
|
|||||||
print("Vocab read")
|
print("Vocab read")
|
||||||
else:
|
else:
|
||||||
vocab = create_vocab(data)
|
vocab = create_vocab(data)
|
||||||
|
print(f"Vocab size: {len(vocab)}")
|
||||||
source_int, target_int = segment_with_vocab(vocab, target, source)
|
source_int, target_int = segment_with_vocab(vocab, target, source)
|
||||||
print("Saving progress")
|
print("Saving progress")
|
||||||
with open(f"vocab-seed_{seed}.pickle", 'wb+') as f:
|
with open(f"vocab-seed_{seed}.pickle", 'wb+') as f:
|
||||||
|
Loading…
Reference in New Issue
Block a user