"""Train one Seq2SeqModel per instrument of a prepared experiment.

Reads ``training_sets/<name>/workflow.pkl`` and the per-instrument
``training_sets/<name>/<instrument>_data.pkl`` files. For each selected
instrument it trains a model and saves it to
``models/<name>/<instrument>_model.h5``. If a saved model already exists, it
is loaded first unless ``--r`` was given.
"""
import os
import sys
import pickle
import keras
import argparse
from model import Seq2SeqModel


def _load_pickle(path, what):
    """Unpickle and return the object stored at *path*, closing the file.

    Args:
        path: File to read.
        what: Human-readable description used in the error message.

    Raises:
        FileNotFoundError: if *path* does not exist (message names *what*).
    """
    # NOTE(review): pickle can execute arbitrary code while loading; only
    # point this script at trusted, locally generated training sets.
    try:
        with open(path, 'rb') as f:
            return pickle.load(f)
    except FileNotFoundError as err:
        raise FileNotFoundError(f'{what} not found: {path}') from err


parser = argparse.ArgumentParser()
parser.add_argument('n', help='name for experiment', type=str)
parser.add_argument('--b', help='batch_size', type=int, default=32)
parser.add_argument('--l', help='latent_dim', type=int, default=256)
parser.add_argument('--e', help='epochs', type=int, default=1)
parser.add_argument('--r', help='reset, use when you want to reset weights and train from scratch',
                    action='store_true')
parser.add_argument('--i', help='reference to instrument to train, if you want to train only one instrument')
args = parser.parse_args()

# Hyper-parameters (defaults are supplied by argparse above).
EXPERIMENT_NAME = args.n
BATCH_SIZE = args.b
LATENT_DIM = args.l
EPOCHS = args.e
RESET = args.r  # store_true flag: always a bool, never None
INSTRUMENT = args.i  # None means "train every instrument"

workflow_path = os.path.join('training_sets', EXPERIMENT_NAME, 'workflow.pkl')
model_workflow = _load_pickle(workflow_path, 'Workflow file')

tbCallBack = keras.callbacks.TensorBoard(log_dir='./Graph', histogram_freq=0,
                                         write_graph=True, write_images=True)

# Each workflow value unpacks to (instrument, how). For 'melody' entries the
# instrument name is used directly; otherwise element [1] is taken
# (presumably the name inside a tuple/list; verify against the workflow builder).
instruments = [instrument if how == 'melody' else instrument[1]
               for instrument, how in model_workflow.values()]

# Resolve the selection before doing any work, so an unknown --i fails fast.
selected = [name for name in instruments if INSTRUMENT is None or name == INSTRUMENT]
if not selected:
    raise ValueError(f'Instrument not found. Use one of the {instruments}')

# Make the output folder for this experiment, creating 'models/' if needed.
model_dir = os.path.join('models', EXPERIMENT_NAME)
os.makedirs(model_dir, exist_ok=True)

for instrument in selected:
    data_path = os.path.join('training_sets', EXPERIMENT_NAME, instrument.lower() + '_data.pkl')
    model_path = os.path.join(model_dir, f'{instrument.lower()}_model.h5')
    # The pickled data holds three items; the third is not needed for training.
    x_train, y_train, _ = _load_pickle(data_path, f'Training data for {instrument}')
    model = Seq2SeqModel(LATENT_DIM, x_train, y_train)
    # Continue from the previously saved model unless --r asked to start over.
    if os.path.isfile(model_path) and not RESET:
        model.load(model_path)
    print(f'Training: {instrument}')
    model.fit(BATCH_SIZE, EPOCHS, callbacks=[tbCallBack])
    model.save(model_path)