praca-magisterska/project/train.py

40 lines
1.2 KiB
Python
Raw Normal View History

2019-05-29 10:37:29 +02:00
#!/usr/bin/env python3
2019-05-28 12:40:26 +02:00
import tensorflow as tf
import settings
from tensorflow.keras import layers
from keras.layers import Input, Dense, Conv2D, Flatten, LSTM, Dropout, TimeDistributed, RepeatVector, Activation, Bidirectional, Reshape
2019-05-28 12:40:26 +02:00
from keras.models import Model, Sequential
import numpy as np
import sys
2019-05-28 12:40:26 +02:00
import pickle
# Train an LSTM autoencoder on a .npz dataset and pickle the trained model.
#
# Usage: train.py TRAIN_DATA_NPZ SAVE_MODEL_PATH EPOCHS
#   TRAIN_DATA_NPZ   .npz archive whose 'arr_0' entry holds the training array;
#                    it is used both as the model input and as the target
#   SAVE_MODEL_PATH  output path prefix; '.pickle' is appended to it
#   EPOCHS           number of training epochs (integer)
if len(sys.argv) < 4:
    # Fail with a usage message instead of an IndexError traceback.
    sys.exit('usage: {} TRAIN_DATA_NPZ SAVE_MODEL_PATH EPOCHS'.format(sys.argv[0]))
train_data_path = sys.argv[1]
save_model_path = sys.argv[2]
epochs = int(sys.argv[3])

# Model: stacked LSTM encoder followed by a dense decoder that produces a
# (96, 128) output, matching the (96, 128) input shape.
model = Sequential()
model.add(LSTM(128, input_shape=(96, 128), return_sequences=True))
model.add(Dropout(0.3))
model.add(LSTM(512, return_sequences=True))
model.add(Dropout(0.3))
# return_sequences defaults to False here, so only the last timestep's output
# (a 128-d vector) is passed on: this is the bottleneck of the autoencoder.
model.add(LSTM(128))
model.add(Dense(128))
model.add(Dropout(0.3))
model.add(Dense(128 * 96))
# NOTE(review): this softmax normalizes across all 96*128 units jointly, while
# categorical_crossentropy (after the Reshape below) reduces over the last axis,
# i.e. per row of 128. If per-row distributions are intended, the softmax should
# come after the Reshape -- confirm before changing, since it alters the model.
model.add(Activation('softmax'))
model.add(Reshape((96, 128)))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

# load training data
print('Training Samples: {}'.format(train_data_path))
# np.load on a .npz archive returns an NpzFile holding an open file handle;
# the context manager closes it once the array has been read into memory.
with np.load(train_data_path) as train_data:
    train_X = train_data['arr_0']

# model training: autoencoder setup, so the input is also the reconstruction target
model.fit(train_X, train_X, epochs=epochs, batch_size=32)

# save trained model
pickle_path = '{}.pickle'.format(save_model_path)
# NOTE(review): not every Keras version can pickle a compiled model; model.save()
# is Keras' documented serialization. Pickle is kept here so existing consumers
# of the '.pickle' output keep working.
with open(pickle_path, 'wb') as model_file:
    pickle.dump(model, model_file)
print("Model saved to {}".format(pickle_path))