2019-10-24 14:01:43 +02:00
|
|
|
import os
|
2019-06-19 15:48:39 +02:00
|
|
|
import sys
|
|
|
|
import pickle
|
2019-10-24 14:01:43 +02:00
|
|
|
import keras
|
|
|
|
import argparse
|
2019-10-27 14:34:02 +01:00
|
|
|
import warnings
|
2019-10-24 14:01:43 +02:00
|
|
|
from model import Seq2SeqModel
|
2019-10-27 14:34:02 +01:00
|
|
|
from extract import make_folder_if_not_exist
|
2019-06-19 15:48:39 +02:00
|
|
|
|
2019-10-27 14:34:02 +01:00
|
|
|
# TODO:
|
|
|
|
# FIXME:
|
2019-06-19 15:48:39 +02:00
|
|
|
|
2019-10-27 14:34:02 +01:00
|
|
|
def parse_argv(argv=None):
    """Parse the command-line options for a training run.

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with attributes ``n`` (experiment name),
        ``b`` (batch size), ``l`` (latent dim), ``e`` (epochs),
        ``i`` (single instrument to train) and ``r`` (reset flag).
        Options that are not given are ``None``; ``r`` is ``False``
        unless ``-r`` is passed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('n', help='name for experiment', type=str)
    parser.add_argument('--b', help='batch_size', type=int)
    parser.add_argument('--l', help='latent_dim', type=int)
    parser.add_argument('--e', help='epochs', type=int)
    parser.add_argument('--i', help='reference to the instrument to train, if you want to train only one instrument')
    parser.add_argument('-r', help='reset, use when you want to reset weights and train from scratch', action='store_true')
    return parser.parse_args(argv)
|
2019-05-28 12:40:26 +02:00
|
|
|
|
2019-10-27 14:34:02 +01:00
|
|
|
def load_workflow(experiment_name=None):
    """Load the pickled model workflow for an experiment.

    Reads ``training_sets/<experiment_name>/workflow.pkl`` relative to the
    current working directory.

    Args:
        experiment_name: experiment folder name. ``None`` (the default)
            falls back to the module-level ``EXPERIMENT_NAME`` setting.

    Returns:
        The unpickled workflow object (presumably a dict of
        ``(instrument, how)`` pairs -- TODO confirm against the writer).

    Raises:
        FileNotFoundError: if the workflow file does not exist.
    """
    if experiment_name is None:
        experiment_name = EXPERIMENT_NAME
    workflow_path = os.path.join('training_sets', experiment_name, 'workflow.pkl')
    if not os.path.isfile(workflow_path):
        raise FileNotFoundError(f'There is no workflow.pkl file in training_sets/{experiment_name}/ folder')
    # NOTE: pickle can execute arbitrary code on load; only use workflow
    # files produced locally by this project, never untrusted ones.
    with open(workflow_path, 'rb') as f:
        return pickle.load(f)
|
2019-06-01 17:05:38 +02:00
|
|
|
|
2019-10-27 14:34:02 +01:00
|
|
|
def train_models(model_workflow):
    """Train one Seq2SeqModel per selected instrument and save it.

    Reads the module-level settings EXPERIMENT_NAME, INSTRUMENT,
    LATENT_DIM, RESET, BATCH_SIZE and EPOCHS.

    Args:
        model_workflow: mapping whose values are ``(instrument, how)``
            pairs. When ``how == 'melody'`` the instrument is used as-is;
            otherwise ``instrument[1]`` is taken as the name to train
            (presumably a (source, target) pair -- TODO confirm).

    Raises:
        ValueError: if no workflow instrument matches INSTRUMENT (this
            includes an empty workflow).
    """
    # Only the values are needed; the workflow keys are ignored.
    instruments = [instrument if how == 'melody' else instrument[1]
                   for instrument, how in model_workflow.values()]

    # Resolve the selection up front so an unknown --i fails before any
    # training work or filesystem changes happen.
    selected = [instrument for instrument in instruments
                if INSTRUMENT is None or INSTRUMENT == instrument]
    if not selected:
        raise ValueError(f'Instrument not found. Use one of the {instruments}')

    # Make sure models/<experiment>/ exists before saving into it;
    # presumably model.save() does not create missing parent folders
    # (verify against Seq2SeqModel).
    model_dir = os.path.join('models', EXPERIMENT_NAME)
    os.makedirs(model_dir, exist_ok=True)

    for instrument in selected:
        name = instrument.lower()
        data_path = os.path.join('training_sets', EXPERIMENT_NAME, name + '_data.pkl')
        model_path = os.path.join(model_dir, f'{name}_model.h5')

        # The data file unpickles to a 3-item sequence; only the first two
        # items are used for training.
        with open(data_path, 'rb') as f:
            x_train, y_train, _ = pickle.load(f)

        model = Seq2SeqModel(LATENT_DIM, x_train, y_train)
        # Load a previously saved model unless a reset was requested.
        if os.path.isfile(model_path) and not RESET:
            model.load(model_path)

        print(f'Training: {instrument}')
        model.fit(BATCH_SIZE, EPOCHS, callbacks=[])
        model.save(model_path)
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Suppress all Python warnings for the rest of the run.
    warnings.filterwarnings("ignore")

    args = parse_argv()

    # Module-level settings read by load_workflow() and train_models().
    EXPERIMENT_NAME = args.n
    INSTRUMENT = args.i

    # Any missing (or otherwise falsy) CLI value falls back to a default.
    BATCH_SIZE = args.b or 32
    LATENT_DIM = args.l or 256
    EPOCHS = args.e or 1
    RESET = args.r or False

    train_models(load_workflow())
|