#!/usr/bin/env python3
"""Slice MIDI files into beat-length piano-roll samples grouped by instrument.

Usage: <script> MIDI_FOLDER OUTPUT_FOLDER

Every file found (recursively) under MIDI_FOLDER is loaded with pypianoroll,
each track is cut into samples of ``settings.midi_resolution`` time steps,
all-zero samples are dropped, and the samples of each instrument group are
saved to ``OUTPUT_FOLDER/<group>.npz``.
"""
import settings
import pypianoroll as roll
import matplotlib.pyplot as plt
import numpy as np
import os
from tqdm import tqdm
from math import floor
import sys
import pickle
from tqdm import trange
from collections import defaultdict
import bz2

# NOTE(review): the CLI arguments are read at import time, so importing this
# module without two command-line arguments raises IndexError.
midi_folder_path = sys.argv[1]
output_path = sys.argv[2]

# Number of pitch rows every sample is reshaped to.
_N_PITCHES = 128


def to_samples(midi_file_path, midi_res=settings.midi_resolution):
    """Split every track of a MIDI file into samples of ``midi_res`` steps.

    Args:
        midi_file_path: path of the MIDI file, loaded with
            ``pypianoroll.Multitrack``.
        midi_res: number of time steps per sample.

    Returns:
        A ``defaultdict`` mapping an instrument-group key (looked up in
        ``settings.midi_group``, or ``'Drums'`` for drum tracks) to an array
        of shape ``(n_samples, midi_res, 128)``. Trailing time steps that do
        not fill a whole sample are discarded. The float64 empty seed array
        promotes the result to float64.
    """
    # TODO: add transpositions of every sample to every possible key:
    #   np.roll(sample, pitch_interval, axis=1) for transposition
    #   np.roll(sample, time_steps, axis=0) for time shifting
    samples_by_instrument = defaultdict(
        lambda: np.empty((0, midi_res, _N_PITCHES)))
    for track in roll.Multitrack(midi_file_path).tracks:
        if track.is_drum:
            key = 'Drums'
        else:
            # NOTE(review): the +1 presumably maps 0-based MIDI program
            # numbers onto 1-based keys in settings.midi_group - confirm.
            key = settings.midi_group[track.program + 1]

        number_of_beats = track.pianoroll.shape[0] // midi_res
        track_pianoroll = track.pianoroll[:number_of_beats * midi_res]
        track_beats = track_pianoroll.reshape(
            number_of_beats, midi_res, _N_PITCHES)
        # Each new track's samples are placed before earlier ones of the key.
        samples_by_instrument[key] = np.concatenate(
            [track_beats, samples_by_instrument[key]], axis=0)
    return samples_by_instrument


def to_midi(samples, output_path=settings.generated_midi_path, program=0,
            tempo=120, is_drum=False,
            beat_resolution=settings.beat_resolution):
    """Write a piano roll as a single-track MIDI file and return it.

    Args:
        samples: piano-roll array passed to ``pypianoroll.Track``.
        output_path: destination MIDI file path.
        program: MIDI program number of the track.
        tempo: tempo passed to ``pypianoroll.Multitrack``.
        is_drum: whether the track is a drum track.
        beat_resolution: beat resolution passed to ``pypianoroll.Multitrack``.

    Returns:
        The ``pypianoroll.Multitrack`` that was written to ``output_path``.
    """
    tracks = [roll.Track(samples, program=program, is_drum=is_drum)]
    # NOTE(review): downbeats are hard-coded at steps 0/96/192/288 regardless
    # of beat_resolution - confirm this matches the settings in use.
    return_midi = roll.Multitrack(tracks=tracks, tempo=tempo,
                                  downbeat=[0, 96, 192, 288],
                                  beat_resolution=beat_resolution)
    roll.write(return_midi, output_path)
    return return_midi


def delete_empty_samples(sample_pack):
    """Return only the samples of ``sample_pack`` whose values do not sum to 0.

    Filtering is done with a boolean mask along the first axis, so the result
    keeps the input's trailing shape and dtype even when no sample survives,
    e.g. an all-empty ``(n, 96, 128)`` pack becomes ``(0, 96, 128)``. It can
    therefore still be concatenated with other packs.
    """
    sample_pack = np.asarray(sample_pack)
    per_sample_sums = sample_pack.sum(axis=tuple(range(1, sample_pack.ndim)))
    return sample_pack[per_sample_sums != 0]


def main():
    """Collect samples from every file under ``midi_folder_path`` and save
    one compressed ``.npz`` per instrument group into ``output_path``.

    Files that ``to_samples`` cannot process are reported and skipped.
    """
    print('Exporting...')
    # Per-instrument list of sample packs, concatenated once before saving
    # instead of re-copying an ever-growing array for every file.
    sample_chunks_by_instrument = defaultdict(list)
    for directory, _subdirectories, files in os.walk(midi_folder_path):
        for midi_file in tqdm(files):
            midi_file_path = os.path.join(directory, midi_file)
            try:
                midi_samples = to_samples(midi_file_path)
            except Exception as err:  # best effort: skip unreadable files
                tqdm.write('Skipping {}: {}'.format(midi_file_path, err))
                continue
            for key, value in midi_samples.items():
                sample_chunks_by_instrument[key].append(
                    delete_empty_samples(value))

    print('Saving...')
    if sample_chunks_by_instrument:
        os.makedirs(output_path, exist_ok=True)
    for key, chunks in tqdm(sample_chunks_by_instrument.items()):
        np.savez_compressed(os.path.join(output_path, '{}.npz'.format(key)),
                            np.concatenate(chunks, axis=0))
    print('Done!')


if __name__ == '__main__':
    main()