added argv, fix no-dir-exist bug, loop through all files, dirs in input
This commit is contained in:
parent
b36407e3c6
commit
b296eb9e6c
@ -15,6 +15,9 @@ from collections import defaultdict
|
||||
import bz2
|
||||
import pickle
|
||||
|
||||
midi_folder_path = sys.argv[1]
|
||||
output_path = sys.argv[2]
|
||||
|
||||
def to_samples(midi_file_path, midi_res=settings.midi_resolution):
|
||||
|
||||
# add transpositions of every sample to every possible key transposition
|
||||
@ -45,36 +48,34 @@ def to_midi(samples, output_path=settings.generated_midi_path, program=0, tempo=
|
||||
tracks = [roll.Track(samples, program=program)]
|
||||
return_midi = roll.Multitrack(tracks=tracks, tempo=tempo, downbeat=[0, 96, 192, 288], beat_resolution=beat_resolution)
|
||||
roll.write(return_midi, output_path)
|
||||
return return_midi
|
||||
|
||||
# TODO: this function is running too slow.
|
||||
def delete_empty_samples(sample_pack):
|
||||
non_empty_arrays = []
|
||||
for sample in sample_pack:
|
||||
if sample.sum() != 0:
|
||||
non_empty_arrays.append(sample)
|
||||
|
||||
return np.array(non_empty_arrays)
|
||||
|
||||
|
||||
def main():
|
||||
print('Exporting...')
|
||||
|
||||
from collections import defaultdict
|
||||
fill_empty_array = lambda : np.empty((0, 96, 128))
|
||||
samples_pack_by_instrument = defaultdict(fill_empty_array)
|
||||
|
||||
sample_pack = np.empty((0,settings.midi_resolution,128))
|
||||
|
||||
for midi_file in tqdm(os.listdir(settings.midi_dir)):
|
||||
midi_file_path = '{}/{}'.format(settings.midi_dir, midi_file)
|
||||
midi_samples = to_samples(midi_file_path)
|
||||
if midi_samples is None:
|
||||
continue
|
||||
|
||||
# this is for intrument separation
|
||||
for key, value in midi_samples.items():
|
||||
value = delete_empty_samples(value)
|
||||
samples_pack_by_instrument[key] = np.concatenate((samples_pack_by_instrument[key], value), axis=0)
|
||||
for directory, subdirectories, files in os.walk(midi_folder_path):
|
||||
for midi_file in tqdm(files):
|
||||
midi_file_path = os.path.join(directory, midi_file)
|
||||
try:
|
||||
midi_samples = to_samples(midi_file_path)
|
||||
except:
|
||||
pass
|
||||
if midi_samples is None:
|
||||
continue
|
||||
# this is for intrument separation
|
||||
for key, value in midi_samples.items():
|
||||
value = delete_empty_samples(value)
|
||||
samples_pack_by_instrument[key] = np.concatenate((samples_pack_by_instrument[key], value), axis=0)
|
||||
|
||||
# save as compressed pickle (sample-dictionary)
|
||||
# sfile = bz2.BZ2File('data/samples.pickle', 'w')
|
||||
@ -83,15 +84,17 @@ def main():
|
||||
# this is for intrument separation
|
||||
print('Saving...')
|
||||
for key, value in tqdm(samples_pack_by_instrument.items()):
|
||||
np.savez_compressed('data/samples/{}.npz'.format(settings.midi_program[key]), value)
|
||||
if not os.path.exists(output_path):
|
||||
os.makedirs(output_path)
|
||||
np.savez_compressed('{}/{}.npz'.format(output_path, settings.midi_program[key]), value)
|
||||
|
||||
# Give a preview of what samples looks like
|
||||
fig, axes = plt.subplots(nrows=10, ncols=10, figsize=(20, 20))
|
||||
for idx, ax in enumerate(axes.ravel()):
|
||||
n = np.random.randint(0, value.shape[0])
|
||||
sample = value[n]
|
||||
ax.imshow(sample, cmap = plt.get_cmap('gray'))
|
||||
plt.savefig('data/samples/{}.png'.format(settings.midi_program[key]))
|
||||
# # Give a preview of what samples looks like
|
||||
# fig, axes = plt.subplots(nrows=10, ncols=10, figsize=(20, 20))
|
||||
# for idx, ax in enumerate(axes.ravel()):
|
||||
# n = np.random.randint(0, value.shape[0])
|
||||
# sample = value[n]
|
||||
# ax.imshow(sample, cmap = plt.get_cmap('gray'))
|
||||
# plt.savefig('data/samples/{}.png'.format(settings.midi_program[key]))
|
||||
|
||||
print('Done!')
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user