code clarity
This commit is contained in:
parent
eaaad76a2a
commit
4d9ee930a1
@ -5,54 +5,70 @@ import pickle
|
|||||||
|
|
||||||
from midi_processing import extract_data, analyze_data
|
from midi_processing import extract_data, analyze_data
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
def make_folder_if_not_exist(path):
|
||||||
parser.add_argument('midi_pack', help='folder name for midi pack in midi_packs folder', type=str)
|
try:
|
||||||
parser.add_argument('name', help='name for experiment', type=str)
|
os.mkdir(path)
|
||||||
parser.add_argument('--b', help='lengh of sequence in bars', type=int)
|
except:
|
||||||
parser.add_argument('-a', help='analize data', action='store_true')
|
pass
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
'''SETTINGS'''
|
def parse_argv():
|
||||||
MIDI_PACK_NAME = args.midi_pack
|
parser = argparse.ArgumentParser()
|
||||||
EXPERIMENT_NAME = args.name
|
parser.add_argument('midi_pack', help='folder name for midi pack in midi_packs folder', type=str)
|
||||||
BARS_IN_SEQ = args.b
|
parser.add_argument('--n', help='name for experiment', type=str)
|
||||||
|
parser.add_argument('--b', help='lengh of sequence in bars', type=int)
|
||||||
|
parser.add_argument('-a', help='analize data', action='store_true')
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
midi_folder_path = os.path.join('midi_packs', MIDI_PACK_NAME)
|
def ask_for_workflow():
|
||||||
|
'''MODEL WORKFLOW DIALOG'''
|
||||||
|
number_of_instruments = int(input('Please specify number of instruments\n'))
|
||||||
|
model_workflow = dict()
|
||||||
|
for i in range(number_of_instruments):
|
||||||
|
input_string = input('Please specify a workflow step <Instrument> [<Second Instrument>] <mode> {m : melody, a : arrangment}\n')
|
||||||
|
tokens = input_string.split()
|
||||||
|
if tokens[-1] == 'm':
|
||||||
|
model_workflow[i] = (tokens[0], 'melody')
|
||||||
|
elif tokens[-1] == 'a':
|
||||||
|
model_workflow[i] = ((tokens[1], tokens[0]), 'arrangment')
|
||||||
|
else:
|
||||||
|
raise ValueError("The step definitiom must end with {'m', 'a'}");
|
||||||
|
|
||||||
|
make_folder_if_not_exist(os.path.join('training_sets', EXPERIMENT_NAME))
|
||||||
|
pickle.dump(model_workflow, open(os.path.join('training_sets', EXPERIMENT_NAME, 'workflow.pkl'),'wb'))
|
||||||
|
|
||||||
|
return model_workflow
|
||||||
|
|
||||||
# analyze data set for intresting intruments
|
def extract_from_folder(model_workflow):
|
||||||
if args.a:
|
for key, (instrument, how) in model_workflow.items():
|
||||||
analyze_data(midi_folder_path)
|
if how == 'melody':
|
||||||
sys.exit()
|
instrument_name = instrument
|
||||||
|
else:
|
||||||
|
instrument_name = instrument[1]
|
||||||
|
|
||||||
|
make_folder_if_not_exist(os.path.join('training_sets', EXPERIMENT_NAME))
|
||||||
|
save_path = os.path.join('training_sets', EXPERIMENT_NAME, instrument_name.lower() + '_data.pkl')
|
||||||
|
|
||||||
|
x_train, y_train, program = extract_data(midi_folder_path=os.path.join('midi_packs', MIDI_PACK_NAME),
|
||||||
|
how=how,
|
||||||
|
instrument=instrument,
|
||||||
|
bar_in_seq=BARS_IN_SEQ)
|
||||||
|
|
||||||
|
pickle.dump((x_train, y_train, program), open(save_path,'wb'))
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args = parse_argv()
|
||||||
|
|
||||||
'''MODEL WORKFLOW DIALOG'''
|
MIDI_PACK_NAME = args.midi_pack
|
||||||
number_of_instruments = int(input('Please specify number of instruments\n'))
|
EXPERIMENT_NAME = args.n
|
||||||
model_workflow = dict()
|
BARS_IN_SEQ = args.b
|
||||||
input_list = []
|
if not EXPERIMENT_NAME:
|
||||||
for i in range(number_of_instruments):
|
EXPERIMENT_NAME = MIDI_PACK_NAME
|
||||||
input_string = input('Please specify a workflow step\n')
|
if not BARS_IN_SEQ:
|
||||||
tokens = input_string.split()
|
BARS_IN_SEQ = 4
|
||||||
if tokens[-1] == 'melody':
|
ANALIZE = args.a
|
||||||
model_workflow[i] = (tokens[0], tokens[1])
|
|
||||||
|
if ANALIZE:
|
||||||
|
analyze_data(os.path.join('midi_packs', MIDI_PACK_NAME))
|
||||||
else:
|
else:
|
||||||
model_workflow[i] = ((tokens[1], tokens[0]), tokens[2])
|
extract_from_folder(ask_for_workflow())
|
||||||
|
|
||||||
# make folder for new experiment if no exist
|
|
||||||
try:
|
|
||||||
os.mkdir(os.path.join('training_sets', EXPERIMENT_NAME))
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# extract process
|
|
||||||
for key, (instrument, how) in model_workflow.items():
|
|
||||||
if how == 'melody':
|
|
||||||
instrument_name = instrument
|
|
||||||
else:
|
|
||||||
instrument_name = instrument[1]
|
|
||||||
|
|
||||||
save_path = os.path.join('training_sets', EXPERIMENT_NAME, instrument_name.lower() + '_data.pkl')
|
|
||||||
x_train, y_train, program = extract_data(midi_folder_path=midi_folder_path, how=how,
|
|
||||||
instrument=instrument, bar_in_seq=BARS_IN_SEQ)
|
|
||||||
pickle.dump((x_train, y_train, program), open(save_path,'wb'))
|
|
||||||
|
|
||||||
pickle.dump(model_workflow, open(os.path.join('training_sets', EXPERIMENT_NAME, 'workflow.pkl'),'wb'))
|
|
@ -7,19 +7,27 @@ import pickle
|
|||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('n', help='name for experiment', type=str)
|
parser.add_argument('n', help='name for experiment', type=str)
|
||||||
|
parser.add_argument('s', help='session name', type=str)
|
||||||
parser.add_argument('--i', help='number of midis to generate', type=int)
|
parser.add_argument('--i', help='number of midis to generate', type=int)
|
||||||
parser.add_argument('--l', help='latent_dim_of_model', type=int)
|
parser.add_argument('--l', help='latent_dim_of_model', type=int)
|
||||||
parser.add_argument('--m', help="mode {'from_seq', 'from_state}'", type=str)
|
parser.add_argument('--m', help="mode {'from_seq', 'from_state}'", type=str)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
EXPERIMENT_NAME = args.n
|
EXPERIMENT_NAME = args.n
|
||||||
|
SESSION_NAME = args.s
|
||||||
GENERETIONS_COUNT = args.i
|
GENERETIONS_COUNT = args.i
|
||||||
LATENT_DIM = args.l
|
LATENT_DIM = args.l
|
||||||
MODE = args.m
|
MODE = args.m
|
||||||
|
|
||||||
if GENERETIONS_COUNT == None:
|
if not GENERETIONS_COUNT:
|
||||||
GENERETIONS_COUNT = 1
|
GENERETIONS_COUNT = 1
|
||||||
|
|
||||||
|
if not LATENT_DIM:
|
||||||
|
LATENT_DIM = 256
|
||||||
|
|
||||||
|
if not MODE:
|
||||||
|
MODE = 'from_seq'
|
||||||
|
|
||||||
model_workflow = pickle.load(open(os.path.join('training_sets', EXPERIMENT_NAME, 'workflow.pkl'),'rb'))
|
model_workflow = pickle.load(open(os.path.join('training_sets', EXPERIMENT_NAME, 'workflow.pkl'),'rb'))
|
||||||
|
|
||||||
band = dict()
|
band = dict()
|
||||||
@ -45,7 +53,7 @@ for instrument in tqdm(band):
|
|||||||
band[instrument][0] = model
|
band[instrument][0] = model
|
||||||
band[instrument][1] = program
|
band[instrument][1] = program
|
||||||
|
|
||||||
for midi_counter in range(GENERETIONS_COUNT):
|
for midi_counter in tqdm(range(GENERETIONS_COUNT)):
|
||||||
''' MAKE MULTIINSTRUMENTAL MUSIC !!!'''
|
''' MAKE MULTIINSTRUMENTAL MUSIC !!!'''
|
||||||
notes = dict()
|
notes = dict()
|
||||||
|
|
||||||
@ -76,7 +84,11 @@ for midi_counter in range(GENERETIONS_COUNT):
|
|||||||
os.mkdir(os.path.join('generated_music', EXPERIMENT_NAME))
|
os.mkdir(os.path.join('generated_music', EXPERIMENT_NAME))
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
try:
|
||||||
|
os.mkdir(os.path.join('generated_music', EXPERIMENT_NAME, SESSION_NAME))
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
save_path = os.path.join('generated_music', EXPERIMENT_NAME, f'{EXPERIMENT_NAME}_{midi_counter}_{MODE}_{LATENT_DIM}.mid')
|
save_path = os.path.join('generated_music', EXPERIMENT_NAME, SESSION_NAME, f'{EXPERIMENT_NAME}_{midi_counter}_{MODE}_{LATENT_DIM}.mid')
|
||||||
generated_midi.save(save_path)
|
generated_midi.save(save_path)
|
||||||
print(f'Generated succefuly to {save_path}')
|
# print(f'Generated succefuly to {save_path}')
|
||||||
|
@ -10,8 +10,6 @@ from random import randint
|
|||||||
import pretty_midi as pm
|
import pretty_midi as pm
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: Stream class is no logner needed <- remore from code and make just SingleTrack.notes instead on SingleTrack.stream.notes
|
# TODO: Stream class is no logner needed <- remore from code and make just SingleTrack.notes instead on SingleTrack.stream.notes
|
||||||
class Stream():
|
class Stream():
|
||||||
|
|
||||||
@ -486,7 +484,7 @@ def round_to_sixteenth_note(x, base=0.25):
|
|||||||
'''
|
'''
|
||||||
|
|
||||||
return base * round(x/base)
|
return base * round(x/base)
|
||||||
|
|
||||||
def parse_pretty_midi_instrument(instrument, resolution, time_to_tick, key_offset):
|
def parse_pretty_midi_instrument(instrument, resolution, time_to_tick, key_offset):
|
||||||
''' arguments: a prettyMidi instrument object
|
''' arguments: a prettyMidi instrument object
|
||||||
return: a custom SingleTrack object
|
return: a custom SingleTrack object
|
||||||
|
@ -52,7 +52,6 @@ class Seq2SeqTransformer():
|
|||||||
self.x_vocab_size = len(self.x_vocab)
|
self.x_vocab_size = len(self.x_vocab)
|
||||||
self.y_vocab_size = len(self.y_vocab)
|
self.y_vocab_size = len(self.y_vocab)
|
||||||
|
|
||||||
|
|
||||||
self.x_transform_dict = dict(
|
self.x_transform_dict = dict(
|
||||||
[(char, i) for i, char in enumerate(self.x_vocab)])
|
[(char, i) for i, char in enumerate(self.x_vocab)])
|
||||||
self.y_transform_dict = dict(
|
self.y_transform_dict = dict(
|
||||||
|
113
project/train.py
113
project/train.py
@ -3,64 +3,77 @@ import sys
|
|||||||
import pickle
|
import pickle
|
||||||
import keras
|
import keras
|
||||||
import argparse
|
import argparse
|
||||||
|
import warnings
|
||||||
from model import Seq2SeqModel
|
from model import Seq2SeqModel
|
||||||
|
from extract import make_folder_if_not_exist
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
# TODO:
|
||||||
parser.add_argument('n', help='name for experiment', type=str)
|
# FIXME:
|
||||||
parser.add_argument('--b', help='batch_size', type=int)
|
|
||||||
parser.add_argument('--l', help='latent_dim', type=int)
|
|
||||||
parser.add_argument('--e', help='epochs', type=int)
|
|
||||||
parser.add_argument('--r', help='reset, use when you want to reset waights and train from scratch', action='store_true')
|
|
||||||
parser.add_argument('--i', help='refrance to instrument to train, if you want to train only one instument')
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
|
def parse_argv():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('n', help='name for experiment', type=str)
|
||||||
|
parser.add_argument('--b', help='batch_size', type=int)
|
||||||
|
parser.add_argument('--l', help='latent_dim', type=int)
|
||||||
|
parser.add_argument('--e', help='epochs', type=int)
|
||||||
|
parser.add_argument('--i', help='refrance to instrument to train, if you want to train only one instument')
|
||||||
|
parser.add_argument('-r', help='reset, use when you want to reset waights and train from scratch', action='store_true')
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
'''HYPER PARAMETERS'''
|
def load_workflow():
|
||||||
EXPERIMENT_NAME = args.n
|
workflow_path = os.path.join('training_sets', EXPERIMENT_NAME, 'workflow.pkl')
|
||||||
BATCH_SIZE = args.b
|
if os.path.isfile(workflow_path):
|
||||||
LATENT_DIM = args.l
|
model_workflow = pickle.load(open(workflow_path,'rb'))
|
||||||
EPOCHS = args.e
|
else:
|
||||||
RESET = args.r
|
raise FileNotFoundError(f'There is no workflow.pkl file in trainig_sets/{EXPERIMENT_NAME}/ folder')
|
||||||
INSTRUMENT = args.i
|
return model_workflow
|
||||||
|
|
||||||
if BATCH_SIZE == None:
|
def train_models(model_workflow):
|
||||||
BATCH_SIZE = 32
|
|
||||||
if LATENT_DIM == None:
|
instruments = [instrument if how == 'melody' else instrument[1] for key, (instrument, how) in model_workflow.items()]
|
||||||
LATENT_DIM = 256
|
# make_folder_if_not_exist(os.mkdir(os.path.join('models',EXPERIMENT_NAME)))
|
||||||
if EPOCHS == None:
|
|
||||||
EPOCHS = 1
|
found = False
|
||||||
if RESET == None:
|
for instrument in instruments:
|
||||||
RESET = False
|
|
||||||
|
|
||||||
## TODO: raise error if file not found
|
if INSTRUMENT == None or INSTRUMENT == instrument:
|
||||||
model_workflow = pickle.load(open(os.path.join('training_sets', EXPERIMENT_NAME, 'workflow.pkl'),'rb'))
|
data_path = os.path.join('training_sets', EXPERIMENT_NAME, instrument.lower() + '_data.pkl')
|
||||||
tbCallBack = keras.callbacks.TensorBoard(log_dir='./Graph', histogram_freq=0, write_graph=True, write_images=True)
|
model_path = os.path.join('models', EXPERIMENT_NAME, f'{instrument.lower()}_model.h5')
|
||||||
|
|
||||||
instruments = [instrument if how == 'melody' else instrument[1] for key, (instrument, how) in model_workflow.items()]
|
x_train, y_train, _ = pickle.load(open(data_path,'rb'))
|
||||||
|
model = Seq2SeqModel(LATENT_DIM, x_train, y_train)
|
||||||
|
if os.path.isfile(model_path) and not RESET:
|
||||||
|
model.load(model_path)
|
||||||
|
|
||||||
# make folder for new experiment
|
print(f'Training: {instrument}')
|
||||||
try:
|
model.fit(BATCH_SIZE, EPOCHS, callbacks=[])
|
||||||
os.mkdir(os.path.join('models',EXPERIMENT_NAME))
|
model.save(model_path)
|
||||||
except:
|
found = True
|
||||||
pass
|
|
||||||
|
|
||||||
# init models
|
if not found:
|
||||||
found = False
|
raise ValueError(f'Instrument not found. Use one of the {instruments}')
|
||||||
for instrument in instruments:
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
if INSTRUMENT == None or INSTRUMENT == instrument:
|
warnings.filterwarnings("ignore")
|
||||||
data_path = os.path.join('training_sets', EXPERIMENT_NAME, instrument.lower() + '_data.pkl')
|
args = parse_argv()
|
||||||
model_path = os.path.join('models', EXPERIMENT_NAME, f'{instrument.lower()}_model.h5')
|
|
||||||
|
EXPERIMENT_NAME = args.n
|
||||||
|
BATCH_SIZE = args.b
|
||||||
|
LATENT_DIM = args.l
|
||||||
|
EPOCHS = args.e
|
||||||
|
RESET = args.r
|
||||||
|
INSTRUMENT = args.i
|
||||||
|
|
||||||
x_train, y_train, _ = pickle.load(open(data_path,'rb'))
|
# default settings if not args passed
|
||||||
model = Seq2SeqModel(LATENT_DIM, x_train, y_train)
|
if not BATCH_SIZE:
|
||||||
if os.path.isfile(model_path) and not RESET:
|
BATCH_SIZE = 32
|
||||||
model.load(model_path)
|
if not LATENT_DIM:
|
||||||
|
LATENT_DIM = 256
|
||||||
print(f'Training: {instrument}')
|
if not EPOCHS:
|
||||||
train_history = model.fit(BATCH_SIZE, EPOCHS, callbacks=[tbCallBack])
|
EPOCHS = 1
|
||||||
model.save(model_path)
|
if not RESET:
|
||||||
found = True
|
RESET = False
|
||||||
|
|
||||||
if not found:
|
train_models(load_workflow())
|
||||||
raise ValueError(f'Instrument not found. Use one of the {instruments}')
|
|
Loading…
Reference in New Issue
Block a user