praca-magisterska/project/train.py
Cezary Pukownik [Impakt S.A.] 86b0540feb init commit
2019-10-24 14:01:43 +02:00

66 lines
2.2 KiB
Python

import os
import sys
import pickle
import keras
import argparse
from model import Seq2SeqModel
# Command-line interface. The positional argument is the experiment name: it
# selects training_sets/<name>/ as input and models/<name>/ as output.
# Defaults live in argparse itself, so every hyper-parameter below is always
# set (the old `== None` fallbacks are no longer needed).
parser = argparse.ArgumentParser()
parser.add_argument('n', help='name for experiment', type=str)
parser.add_argument('--b', help='batch_size', type=int, default=32)
parser.add_argument('--l', help='latent_dim', type=int, default=256)
parser.add_argument('--e', help='epochs', type=int, default=1)
# store_true already defaults to False, so no separate None check is needed.
parser.add_argument('--r', help='reset, use when you want to reset weights and train from scratch', action='store_true')
parser.add_argument('--i', help='reference to instrument to train, if you want to train only one instrument')
args = parser.parse_args()

# HYPER PARAMETERS
EXPERIMENT_NAME = args.n
BATCH_SIZE = args.b
LATENT_DIM = args.l
EPOCHS = args.e
RESET = args.r
# None means "train every instrument in the workflow".
INSTRUMENT = args.i
# Load the workflow description saved alongside the experiment's training set.
# NOTE(review): pickle can execute arbitrary code while loading — only point
# this script at training_sets/ directories you created yourself.
workflow_path = os.path.join('training_sets', EXPERIMENT_NAME, 'workflow.pkl')
try:
    workflow_file = open(workflow_path, 'rb')
except FileNotFoundError as err:
    # Same exception type as before, but with an actionable message.
    raise FileNotFoundError(
        f'Workflow file not found: {workflow_path}. '
        f'Is "{EXPERIMENT_NAME}" the name of an existing training set?'
    ) from err
with workflow_file:
    model_workflow = pickle.load(workflow_file)

# One TensorBoard callback shared by every model trained in this run.
tbCallBack = keras.callbacks.TensorBoard(log_dir='./Graph', histogram_freq=0, write_graph=True, write_images=True)

# Each workflow value unpacks to (instrument, how). For 'melody' entries
# `instrument` is used directly as the instrument name; otherwise its second
# element is. NOTE(review): presumably non-melody entries hold a pair —
# confirm against the code that writes workflow.pkl.
instruments = [
    instrument if how == 'melody' else instrument[1]
    for instrument, how in model_workflow.values()
]
# Make the output folder for this experiment, creating the parent 'models'
# folder too if needed. An already-existing folder is fine. Any other failure
# (e.g. a permission error) is raised here, before training starts, instead of
# being swallowed and only surfacing at model.save after training.
os.makedirs(os.path.join('models', EXPERIMENT_NAME), exist_ok=True)
# Train one Seq2SeqModel per instrument. With --i only the instrument whose
# name matches exactly (case-sensitive) is trained.
selected = [
    instrument for instrument in instruments
    if INSTRUMENT is None or INSTRUMENT == instrument
]
# Fail fast, before any (possibly long) training run starts.
if not selected:
    raise ValueError(f'Instrument not found. Use one of the {instruments}')

for instrument in selected:
    data_path = os.path.join('training_sets', EXPERIMENT_NAME, instrument.lower() + '_data.pkl')
    model_path = os.path.join('models', EXPERIMENT_NAME, f'{instrument.lower()}_model.h5')
    # The pickled data unpacks to three items; the third is unused here.
    # NOTE(review): pickle can execute code on load — only use trusted files.
    with open(data_path, 'rb') as data_file:
        x_train, y_train, _ = pickle.load(data_file)
    model = Seq2SeqModel(LATENT_DIM, x_train, y_train)
    # Continue from the previously saved model file unless --r (reset) was
    # given or no such file exists yet.
    if os.path.isfile(model_path) and not RESET:
        model.load(model_path)
    print(f'Training: {instrument}')
    train_history = model.fit(BATCH_SIZE, EPOCHS, callbacks=[tbCallBack])
    model.save(model_path)