delete empty samples - optimized

This commit is contained in:
Cezary Pukownik 2019-05-30 13:36:15 +02:00
parent 56e7d72a64
commit 4217811ae2

View File

@ -12,7 +12,6 @@ import pickle
from tqdm import tqdm from tqdm import tqdm
from tqdm import trange from tqdm import trange
from collections import defaultdict from collections import defaultdict
import bz2 import bz2
import pickle import pickle
@ -29,10 +28,13 @@ def to_samples(midi_file_path, midi_res=settings.midi_resolution):
for track in roll.Multitrack(midi_file_path).tracks: for track in roll.Multitrack(midi_file_path).tracks:
if not track.is_drum: if not track.is_drum:
key = track.program + 1 key = track.program + 1
# this makes pack of samples of N x 96 x 128 shape # this makes pack of samples of N x 96 x 128 shape
number_of_beats = floor(track.pianoroll.shape[0] / midi_res) number_of_beats = floor(track.pianoroll.shape[0] / midi_res)
track_pianoroll = track.pianoroll[: number_of_beats * midi_res] track_pianoroll = track.pianoroll[: number_of_beats * midi_res]
track_beats = track_pianoroll.reshape(number_of_beats, midi_res, 128) track_beats = track_pianoroll.reshape(number_of_beats, midi_res, 128)
# save collected pack of data to dictionary with samples packs for every instrument
samples_by_instrument[track.program + 1] = np.concatenate([track_beats, samples_by_instrument[ track.program + 1]], axis=0) samples_by_instrument[track.program + 1] = np.concatenate([track_beats, samples_by_instrument[ track.program + 1]], axis=0)
else: else:
# TODO: add code for drums samples # TODO: add code for drums samples
@ -46,15 +48,13 @@ def to_midi(samples, output_path=settings.generated_midi_path, program=0, tempo=
# TODO: this function is running too slow. # TODO: this function is running too slow.
def delete_empty_samples(sample_pack): def delete_empty_samples(sample_pack):
print('Deleting empty samples...') non_empty_arrays = []
temp_sample_pack = sample_pack for sample in sample_pack:
index_manipulator = 1 if sample.sum() != 0:
for index, sample in enumerate(sample_pack): non_empty_arrays.append(sample)
if sample.sum() == 0:
temp_sample_pack = np.delete(temp_sample_pack, index-index_manipulator, axis=0) return np.array(non_empty_arrays)
index_manipulator = index_manipulator + 1
print('Deleted {} empty samples'.format(index_manipulator-1))
return temp_sample_pack
def main(): def main():
print('Exporting...') print('Exporting...')
@ -66,7 +66,6 @@ def main():
sample_pack = np.empty((0,settings.midi_resolution,128)) sample_pack = np.empty((0,settings.midi_resolution,128))
for midi_file in tqdm(os.listdir(settings.midi_dir)): for midi_file in tqdm(os.listdir(settings.midi_dir)):
print(midi_file)
midi_file_path = '{}/{}'.format(settings.midi_dir, midi_file) midi_file_path = '{}/{}'.format(settings.midi_dir, midi_file)
midi_samples = to_samples(midi_file_path) midi_samples = to_samples(midi_file_path)
if midi_samples is None: if midi_samples is None:
@ -74,11 +73,9 @@ def main():
# this is for intrument separation # this is for intrument separation
for key, value in midi_samples.items(): for key, value in midi_samples.items():
value = delete_empty_samples(value)
samples_pack_by_instrument[key] = np.concatenate((samples_pack_by_instrument[key], value), axis=0) samples_pack_by_instrument[key] = np.concatenate((samples_pack_by_instrument[key], value), axis=0)
# TODO: Delete empty samples - optimize
# sample_pack = delete_empty_samples(sample_pack)
# save as compressed pickle (sample-dictionary) # save as compressed pickle (sample-dictionary)
# sfile = bz2.BZ2File('data/samples.pickle', 'w') # sfile = bz2.BZ2File('data/samples.pickle', 'w')
# pickle.dump(dict(samples_pack_by_instrument), sfile) # pickle.dump(dict(samples_pack_by_instrument), sfile)
@ -88,7 +85,7 @@ def main():
for key, value in tqdm(samples_pack_by_instrument.items()): for key, value in tqdm(samples_pack_by_instrument.items()):
np.savez_compressed('data/samples/{}.npz'.format(settings.midi_program[key]), value) np.savez_compressed('data/samples/{}.npz'.format(settings.midi_program[key]), value)
# Give a preview of what samples looks like # Give a preview of what samples looks like
fig, axes = plt.subplots(nrows=10, ncols=10, figsize=(20, 20)) fig, axes = plt.subplots(nrows=10, ncols=10, figsize=(20, 20))
for idx, ax in enumerate(axes.ravel()): for idx, ax in enumerate(axes.ravel()):
n = np.random.randint(0, value.shape[0]) n = np.random.randint(0, value.shape[0])