praca-magisterska/project/generate.py

95 lines
3.2 KiB
Python
Raw Normal View History

2019-10-24 14:01:43 +02:00
from midi_processing import MultiTrack, SingleTrack, Stream
from model import Seq2SeqModel, seq_to_numpy
from tqdm import tqdm
import argparse
import os
2019-06-19 13:40:35 +02:00
import pickle
2019-10-24 14:01:43 +02:00
# Command-line interface. Both positionals are required; the optional flags
# carry defaults (the same values are also used as fallbacks further down),
# and ArgumentDefaultsHelpFormatter shows them in --help.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('n', help='name for experiment', type=str)
parser.add_argument('s', help='session name', type=str)
parser.add_argument('--i', help='number of midis to generate', type=int, default=1)
parser.add_argument('--l', help='latent_dim_of_model', type=int, default=256)
# Fixed help text: the closing quote of 'from_state' was misplaced.
parser.add_argument('--m', help="mode {'from_seq', 'from_state'}", type=str, default='from_seq')
args = parser.parse_args()
# Copy the parsed CLI values into module-level settings. Any omitted or falsy
# optional value (None, 0, '') is replaced by its default.
EXPERIMENT_NAME = args.n
SESSION_NAME = args.s
GENERETIONS_COUNT = args.i or 1
LATENT_DIM = args.l or 256
MODE = args.m or 'from_seq'
# Load the workflow description saved for this experiment.
# NOTE(review): pickle.load can execute arbitrary code while deserializing;
# only load workflow files produced by this project's own scripts.
workflow_path = os.path.join('training_sets', EXPERIMENT_NAME, 'workflow.pkl')
with open(workflow_path, 'rb') as workflow_file:  # closes the handle (was leaked)
    model_workflow = pickle.load(workflow_file)

# band maps instrument -> [model, program, generator].
# The first element of each workflow entry is either an instrument name
# (track generated on its own, generator = None) or a (generator, instrument)
# pair, where generator names the instrument whose notes condition this track.
# model and program stay None here and are filled in when models are loaded.
band = dict()
for value in model_workflow.values():  # keys are not needed
    if isinstance(value[0], str):
        instrument, generator = value[0], None
    else:
        generator, instrument = value[0][0], value[0][1]
    band[instrument] = [None, None, generator]
# Load models: for every instrument, rebuild its Seq2SeqModel from the saved
# training data, load the trained weights, and record the MIDI program stored
# alongside the data.
for instrument in tqdm(band):
    data_path = os.path.join('training_sets', EXPERIMENT_NAME, instrument.lower() + '_data.pkl')
    model_path = os.path.join('models', EXPERIMENT_NAME, instrument.lower() + '_model.h5')
    # `with` closes the pickle file (the original left one handle open per
    # instrument).
    with open(data_path, 'rb') as data_file:
        x_train, y_train, program = pickle.load(data_file)
    # The model is constructed from the training data, presumably to rebuild
    # its vocabulary/transformer before loading weights — TODO confirm.
    model = Seq2SeqModel(LATENT_DIM, x_train, y_train)
    model.load(model_path)
    band[instrument][0] = model
    band[instrument][1] = program
# All generated files go to generated_music/<experiment>/<session>/.
# makedirs(exist_ok=True) also creates missing parents and tolerates an
# existing directory, replacing the bare `except: pass` around os.mkdir that
# hid real failures (e.g. permissions). It is created once, not once per file.
output_dir = os.path.join('generated_music', EXPERIMENT_NAME, SESSION_NAME)
os.makedirs(output_dir, exist_ok=True)

for midi_counter in tqdm(range(GENERETIONS_COUNT)):
    # Make multi-instrumental music: one note sequence per instrument.
    notes = dict()
    for instrument, (model, program, generator) in band.items():
        if generator is None:
            # Track produced by the instrument's model on its own.
            notes[instrument] = model.develop(mode=MODE)
        else:
            # Track predicted from another instrument's generated notes.
            # NOTE(review): requires `generator` to appear before `instrument`
            # in `band` (i.e. in the workflow pickle); otherwise
            # notes[generator] raises KeyError.
            input_data = seq_to_numpy(notes[generator],
                                      model.transformer.x_max_seq_length,
                                      model.transformer.x_vocab_size,
                                      model.transformer.x_transform_dict)
            # The last predicted element is dropped; presumably an
            # end-of-sequence token — TODO confirm against Seq2SeqModel.predict.
            notes[instrument] = model.predict(input_data)[:-1]

    # Compile to MIDI: one track per instrument; only the 'Drums' instrument
    # is flagged as a drum track.
    generated_midi = MultiTrack()
    for instrument, (_, program, _) in band.items():
        stream = Stream(first_tick=0, notes=notes[instrument])
        track = SingleTrack(name=instrument, program=program,
                            is_drum=(instrument == 'Drums'), stream=stream)
        generated_midi.tracks.append(track)

    save_path = os.path.join(output_dir, f'{EXPERIMENT_NAME}_{midi_counter}_{MODE}_{LATENT_DIM}.mid')
    generated_midi.save(save_path)