added argv, fix no-dir-exist bug, loop through all files, dirs in input

This commit is contained in:
Cezary Pukownik 2019-05-30 20:47:47 +02:00
parent b36407e3c6
commit b296eb9e6c

View File

@ -15,6 +15,9 @@ from collections import defaultdict
import bz2
import pickle
midi_folder_path = sys.argv[1]
output_path = sys.argv[2]
def to_samples(midi_file_path, midi_res=settings.midi_resolution):
# add transpositions of every sample to every possible key transposition
@ -45,36 +48,34 @@ def to_midi(samples, output_path=settings.generated_midi_path, program=0, tempo=
tracks = [roll.Track(samples, program=program)]
return_midi = roll.Multitrack(tracks=tracks, tempo=tempo, downbeat=[0, 96, 192, 288], beat_resolution=beat_resolution)
roll.write(return_midi, output_path)
return return_midi
# TODO: this function is running too slow.
def delete_empty_samples(sample_pack):
non_empty_arrays = []
for sample in sample_pack:
if sample.sum() != 0:
non_empty_arrays.append(sample)
return np.array(non_empty_arrays)
def main():
print('Exporting...')
from collections import defaultdict
fill_empty_array = lambda : np.empty((0, 96, 128))
samples_pack_by_instrument = defaultdict(fill_empty_array)
sample_pack = np.empty((0,settings.midi_resolution,128))
for midi_file in tqdm(os.listdir(settings.midi_dir)):
midi_file_path = '{}/{}'.format(settings.midi_dir, midi_file)
midi_samples = to_samples(midi_file_path)
if midi_samples is None:
continue
# this is for intrument separation
for key, value in midi_samples.items():
value = delete_empty_samples(value)
samples_pack_by_instrument[key] = np.concatenate((samples_pack_by_instrument[key], value), axis=0)
for directory, subdirectories, files in os.walk(midi_folder_path):
for midi_file in tqdm(files):
midi_file_path = os.path.join(directory, midi_file)
try:
midi_samples = to_samples(midi_file_path)
except:
pass
if midi_samples is None:
continue
# this is for intrument separation
for key, value in midi_samples.items():
value = delete_empty_samples(value)
samples_pack_by_instrument[key] = np.concatenate((samples_pack_by_instrument[key], value), axis=0)
# save as compressed pickle (sample-dictionary)
# sfile = bz2.BZ2File('data/samples.pickle', 'w')
@ -83,15 +84,17 @@ def main():
# this is for intrument separation
print('Saving...')
for key, value in tqdm(samples_pack_by_instrument.items()):
np.savez_compressed('data/samples/{}.npz'.format(settings.midi_program[key]), value)
if not os.path.exists(output_path):
os.makedirs(output_path)
np.savez_compressed('{}/{}.npz'.format(output_path, settings.midi_program[key]), value)
# Give a preview of what samples looks like
fig, axes = plt.subplots(nrows=10, ncols=10, figsize=(20, 20))
for idx, ax in enumerate(axes.ravel()):
n = np.random.randint(0, value.shape[0])
sample = value[n]
ax.imshow(sample, cmap = plt.get_cmap('gray'))
plt.savefig('data/samples/{}.png'.format(settings.midi_program[key]))
# # Give a preview of what samples looks like
# fig, axes = plt.subplots(nrows=10, ncols=10, figsize=(20, 20))
# for idx, ax in enumerate(axes.ravel()):
# n = np.random.randint(0, value.shape[0])
# sample = value[n]
# ax.imshow(sample, cmap = plt.get_cmap('gray'))
# plt.savefig('data/samples/{}.png'.format(settings.midi_program[key]))
print('Done!')