praca-magisterska/project/train.py

95 lines
3.4 KiB
Python
Raw Permalink Normal View History

2019-10-24 14:01:43 +02:00
import os
2019-06-19 15:48:39 +02:00
import sys
import pickle
2019-10-24 14:01:43 +02:00
import keras
import argparse
2019-10-27 14:34:02 +01:00
import warnings
import pandas as pd
2019-10-24 14:01:43 +02:00
from model import Seq2SeqModel
2019-10-27 14:34:02 +01:00
from extract import make_folder_if_not_exist
2019-06-19 15:48:39 +02:00
2019-10-27 14:34:02 +01:00
# TODO:
# FIXME:
2019-06-19 15:48:39 +02:00
2019-10-27 14:34:02 +01:00
def parse_argv():
    """Build the training script's command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with attributes ``n`` (experiment name), ``b``
        (batch size), ``l`` (latent dim), ``e`` (epochs), ``ed`` / ``dd``
        (encoder / decoder dropout), ``i`` (single instrument to train) and
        ``r`` (reset flag). Optional values that were not supplied are
        ``None``; ``r`` is ``False`` unless ``-r`` is passed.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('n', help='name for experiment', type=str)
    # Optional typed hyper-parameters, registered in the same order as before
    # so the generated --help output is unchanged.
    typed_options = (
        ('--b', 'batch_size', int),
        ('--l', 'latent_dim', int),
        ('--e', 'epochs', int),
        ('--ed', 'encoder dropout', float),
        ('--dd', 'decoder dropout', float),
    )
    for flag, description, converter in typed_options:
        cli.add_argument(flag, help=description, type=converter)
    cli.add_argument('--i', help='refrance to instrument to train, if you want to train only one instument')
    cli.add_argument('-r', help='reset, use when you want to reset waights and train from scratch', action='store_true')
    return cli.parse_args()
2019-05-28 12:40:26 +02:00
2019-10-27 14:34:02 +01:00
def load_workflow(experiment_name=None):
    """Load the pickled model workflow of an experiment.

    Reads ``training_sets/<experiment_name>/workflow.pkl`` relative to the
    current working directory.

    Args:
        experiment_name: experiment folder to read from. When omitted, the
            module-level ``EXPERIMENT_NAME`` (set by the command-line entry
            point) is used, which keeps ``load_workflow()`` working as before.

    Returns:
        The unpickled workflow object (consumed by ``train_models``).

    Raises:
        FileNotFoundError: if the workflow file does not exist.
    """
    if experiment_name is None:
        experiment_name = EXPERIMENT_NAME
    workflow_path = os.path.join('training_sets', experiment_name, 'workflow.pkl')
    if not os.path.isfile(workflow_path):
        raise FileNotFoundError(
            f'There is no workflow.pkl file in training_sets/{experiment_name}/ folder')
    # NOTE(review): pickle.load can execute arbitrary code; only load workflow
    # files produced by this project, never ones from untrusted sources.
    with open(workflow_path, 'rb') as workflow_file:  # closes the handle deterministically
        return pickle.load(workflow_file)
2019-10-27 14:34:02 +01:00
def train_models(model_workflow):
    """Train (or resume training of) one Seq2SeqModel per workflow instrument.

    For every instrument of ``model_workflow`` -- or only the one named by the
    module-level ``INSTRUMENT`` when it is set -- this:

    * loads ``training_sets/<EXPERIMENT_NAME>/<instrument>_data.pkl``;
    * loads the saved model ``models/<EXPERIMENT_NAME>/<instrument>_model.h5``
      unless ``RESET`` is set or the file is missing, in which case a new
      model is built from ``LATENT_DIM``, ``ENCODER_DROPOUT``,
      ``DECODER_DROPOUT`` and the data's ``bars_in_seq``;
    * calls ``model.fit(BATCH_SIZE, EPOCHS, callbacks=[])``;
    * appends the fit history to ``<instrument>_history.csv`` and saves the
      model back to ``<instrument>_model.h5``.

    Instrument names are lower-cased when building file names.

    Args:
        model_workflow: mapping whose values are ``(instrument, how)`` pairs
            (as returned by ``load_workflow``). When ``how == 'melody'`` the
            instrument value itself is the name; otherwise ``instrument[1]``
            is used.

    Raises:
        ValueError: if the workflow lists no instruments, or ``INSTRUMENT`` is
            set but does not match any of them.
    """
    instruments = [
        instrument if how == 'melody' else instrument[1]
        for instrument, how in model_workflow.values()
    ]
    if not instruments:
        raise ValueError('Workflow contains no instruments to train')
    # Fail fast, before any data is loaded, when the requested instrument is unknown.
    if INSTRUMENT and INSTRUMENT not in instruments:
        raise ValueError(f'Instrument not found. Use one of the {instruments}')

    models_dir = os.path.join('models', EXPERIMENT_NAME)
    for instrument in instruments:
        if INSTRUMENT and INSTRUMENT != instrument:
            continue
        file_stem = instrument.lower()
        data_path = os.path.join('training_sets', EXPERIMENT_NAME, file_stem + '_data.pkl')
        model_path = os.path.join(models_dir, f'{file_stem}_model.h5')
        history_path = os.path.join(models_dir, f'{file_stem}_history.csv')

        # NOTE(review): pickle.load can execute arbitrary code; only load data
        # sets produced by this project.
        with open(data_path, 'rb') as data_file:
            x_train, y_train, _, bars_in_seq = pickle.load(data_file)

        if os.path.isfile(model_path) and not RESET:
            # Built without hyper-parameters; presumably model.load() restores
            # the saved architecture/weights -- TODO confirm in model.py.
            model = Seq2SeqModel(x_train, y_train)
            model.load(model_path)
        else:
            model = Seq2SeqModel(x_train, y_train, LATENT_DIM, ENCODER_DROPOUT, DECODER_DROPOUT, bars_in_seq)

        print(f'Training: {instrument}')
        history = model.fit(BATCH_SIZE, EPOCHS, callbacks=[])

        make_folder_if_not_exist(models_dir)
        # Written with mode='a' and header=False: repeated runs add rows to the
        # same file and no column names are stored.
        pd.DataFrame(history.history).to_csv(history_path, mode='a', header=False)
        model.save(model_path)
if __name__ == '__main__':
    warnings.filterwarnings("ignore")
    args = parse_argv()
    # These module-level names are read as globals by load_workflow() and
    # train_models().
    EXPERIMENT_NAME = args.n
    RESET = args.r  # store_true flag: already False when -r is omitted
    INSTRUMENT = args.i
    # Fall back to defaults only for options that were not passed at all, so an
    # explicit value such as ``--e 0`` or ``--ed 0`` is respected rather than
    # silently replaced.
    BATCH_SIZE = 32 if args.b is None else args.b
    LATENT_DIM = 256 if args.l is None else args.l
    EPOCHS = 1 if args.e is None else args.e
    ENCODER_DROPOUT = 0.0 if args.ed is None else args.ed
    DECODER_DROPOUT = 0.0 if args.dd is None else args.dd
    train_models(load_workflow())