From 4217811ae23b0280d0ff0273524be1890e4081e0 Mon Sep 17 00:00:00 2001 From: Cezary Pukownik Date: Thu, 30 May 2019 13:36:15 +0200 Subject: [PATCH] delete empty samples - optimized --- project/midi.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/project/midi.py b/project/midi.py index b79378c..7033932 100644 --- a/project/midi.py +++ b/project/midi.py @@ -12,7 +12,6 @@ import pickle from tqdm import tqdm from tqdm import trange from collections import defaultdict - import bz2 import pickle @@ -29,10 +28,13 @@ def to_samples(midi_file_path, midi_res=settings.midi_resolution): for track in roll.Multitrack(midi_file_path).tracks: if not track.is_drum: key = track.program + 1 + # this makes pack of samples of N x 96 x 128 shape number_of_beats = floor(track.pianoroll.shape[0] / midi_res) track_pianoroll = track.pianoroll[: number_of_beats * midi_res] track_beats = track_pianoroll.reshape(number_of_beats, midi_res, 128) + + # save collected pack of data to dictionary with samples packs for every instrument samples_by_instrument[track.program + 1] = np.concatenate([track_beats, samples_by_instrument[ track.program + 1]], axis=0) else: # TODO: add code for drums samples @@ -46,15 +48,13 @@ def to_midi(samples, output_path=settings.generated_midi_path, program=0, tempo= # TODO: this function is running too slow. def delete_empty_samples(sample_pack): - print('Deleting empty samples...') - temp_sample_pack = sample_pack - index_manipulator = 1 - for index, sample in enumerate(sample_pack): - if sample.sum() == 0: - temp_sample_pack = np.delete(temp_sample_pack, index-index_manipulator, axis=0) - index_manipulator = index_manipulator + 1 - print('Deleted {} empty samples'.format(index_manipulator-1)) - return temp_sample_pack + non_empty_arrays = [] + for sample in sample_pack: + if sample.sum() != 0: + non_empty_arrays.append(sample) + + return np.array(non_empty_arrays) + def main(): print('Exporting...') @@ -66,7 +66,6 @@ def main(): sample_pack = np.empty((0,settings.midi_resolution,128)) for midi_file in tqdm(os.listdir(settings.midi_dir)): - print(midi_file) midi_file_path = '{}/{}'.format(settings.midi_dir, midi_file) midi_samples = to_samples(midi_file_path) if midi_samples is None: @@ -74,11 +73,9 @@ def main(): # this is for intrument separation for key, value in midi_samples.items(): + value = delete_empty_samples(value) samples_pack_by_instrument[key] = np.concatenate((samples_pack_by_instrument[key], value), axis=0) - # TODO: Delete empty samples - optimize - # sample_pack = delete_empty_samples(sample_pack) - # save as compressed pickle (sample-dictionary) # sfile = bz2.BZ2File('data/samples.pickle', 'w') # pickle.dump(dict(samples_pack_by_instrument), sfile) @@ -88,7 +85,7 @@ def main(): for key, value in tqdm(samples_pack_by_instrument.items()): np.savez_compressed('data/samples/{}.npz'.format(settings.midi_program[key]), value) - # Give a preview of what samples looks like + # Give a preview of what samples looks like fig, axes = plt.subplots(nrows=10, ncols=10, figsize=(20, 20)) for idx, ax in enumerate(axes.ravel()): n = np.random.randint(0, value.shape[0])