praca-magisterska/project/generate.py

81 lines
2.7 KiB
Python
Raw Normal View History

2019-10-24 14:01:43 +02:00
from midi_processing import MultiTrack, SingleTrack, Stream
from model import Seq2SeqModel, seq_to_numpy
from tqdm import tqdm
import argparse
import os
2019-06-19 13:40:35 +02:00
import pickle
2019-10-24 14:01:43 +02:00
# Command-line interface: experiment name (positional) and how many MIDI files to generate.
parser = argparse.ArgumentParser()
parser.add_argument('n', help='name for experiment', type=str)
# default=1 replaces the former post-hoc `GENERETIONS_COUNT == None` check.
parser.add_argument('--i', help='number of midis to generate', type=int, default=1)
args = parser.parse_args()

EXPERIMENT_NAME = args.n
GENERETIONS_COUNT = args.i  # (sic) name kept because the generation loop below uses it
# Passed to Seq2SeqModel before loading saved weights; presumably must match the
# latent size used at training time — confirm against the training script.
LATENT_DIM = 256
# Passed to model.develop(mode=...) and embedded in the output file names.
MODE = 'from_seq'
# Load the instrument workflow stored alongside the experiment's training data.
# NOTE(review): pickle can execute arbitrary code on load; only use trusted files.
with open(os.path.join('training_sets', EXPERIMENT_NAME, 'workflow.pkl'), 'rb') as workflow_file:
    model_workflow = pickle.load(workflow_file)
# band maps instrument name -> [model, program, generator].
# model and program start as None and are filled in when the models are loaded;
# generator names the instrument whose generated notes drive this one (or None).
band = {}
for entry in model_workflow.values():
    spec = entry[0]
    if isinstance(spec, str):
        # A bare instrument name: this instrument generates on its own.
        instrument, generator = spec, None
    else:
        # A (generator, instrument) pair: this instrument is conditioned on
        # the notes produced for `generator`.
        generator, instrument = spec[0], spec[1]
    band[instrument] = [None, None, generator]
# Load models: for each instrument, rebuild its Seq2SeqModel from the pickled
# training data, load the saved weights, and record the model and its program.
for instrument in tqdm(band):
    data_path = os.path.join('training_sets', EXPERIMENT_NAME, instrument.lower() + '_data.pkl')
    model_path = os.path.join('models', EXPERIMENT_NAME, instrument.lower() + '_model.h5')
    # NOTE(review): pickle can execute arbitrary code on load; only use trusted files.
    with open(data_path, 'rb') as data_file:
        x_train, y_train, program = pickle.load(data_file)
    # The training arrays are passed to the constructor before the weights are loaded.
    model = Seq2SeqModel(LATENT_DIM, x_train, y_train)
    model.load(model_path)
    band[instrument][0] = model
    # Later passed to SingleTrack(program=...) when compiling the MIDI file.
    band[instrument][1] = program
# Make the output folder for this experiment once; makedirs also creates the
# parent 'generated_music' directory if it is missing (os.mkdir did not).
output_dir = os.path.join('generated_music', EXPERIMENT_NAME)
os.makedirs(output_dir, exist_ok=True)

for midi_counter in range(GENERETIONS_COUNT):
    # Make multi-instrumental music: an instrument without a generator develops
    # its own sequence; otherwise it is predicted from its generator's notes.
    notes = {}
    for instrument, (model, program, generator) in band.items():
        if generator is None:
            notes[instrument] = model.develop(mode=MODE)
        else:
            # NOTE(review): requires the generator instrument to come earlier in
            # `band` (i.e. in the workflow) than this one, else notes[generator]
            # raises KeyError.
            input_data = seq_to_numpy(notes[generator],
                                      model.transformer.x_max_seq_length,
                                      model.transformer.x_vocab_size,
                                      model.transformer.x_transform_dict)
            # NOTE(review): the last predicted element is dropped; the reason is
            # not visible here — presumably an end-of-sequence token, confirm.
            notes[instrument] = model.predict(input_data)[:-1]

    # Compile to MIDI: one track per instrument; the 'Drums' track is flagged as drums.
    generated_midi = MultiTrack()
    for instrument, (_, program, _) in band.items():
        stream = Stream(first_tick=0, notes=notes[instrument])
        track = SingleTrack(name=instrument, program=program,
                            is_drum=(instrument == 'Drums'), stream=stream)
        generated_midi.tracks.append(track)

    save_path = os.path.join(output_dir, f'{EXPERIMENT_NAME}_{midi_counter}_{MODE}_{LATENT_DIM}.mid')
    generated_midi.save(save_path)
    print(f'Generated successfully to {save_path}')