2019-05-29 10:36:34 +02:00
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
|
|
|
import settings
|
|
|
|
import pypianoroll as roll
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
import numpy as np
|
|
|
|
import os
|
|
|
|
from tqdm import tqdm
|
|
|
|
from math import floor
|
|
|
|
import sys
|
2019-05-30 11:23:34 +02:00
|
|
|
import pickle
|
|
|
|
from tqdm import tqdm
|
|
|
|
from tqdm import trange
|
|
|
|
from collections import defaultdict
|
|
|
|
import bz2
|
|
|
|
import pickle
|
2019-05-29 10:36:34 +02:00
|
|
|
|
|
|
|
def to_samples(midi_file_path, midi_res=settings.midi_resolution):
    """Split every non-drum track of a MIDI file into fixed-length pianoroll chunks.

    Each track's pianoroll is truncated to a whole number of ``midi_res``-step
    chunks and reshaped to ``(n_chunks, midi_res, 128)``. The reshape requires
    the pianoroll to have 128 pitch columns. Chunks are grouped under the key
    ``track.program + 1``. When several tracks share a program, the chunks of a
    later track are placed before those of earlier tracks. Drum tracks are
    skipped for now.

    Args:
        midi_file_path: path of the MIDI file, parsed with ``pypianoroll.Multitrack``.
        midi_res: number of pianoroll time steps per chunk. The default is read
            from ``settings`` once, when the module is imported.

    Returns:
        A ``defaultdict`` mapping ``program + 1`` to an array of shape
        ``(n, midi_res, 128)``. Looking up an absent key yields an empty
        ``(0, midi_res, 128)`` array.
    """
    # Ideas for data augmentation (not implemented yet):
    #   np.roll(sample, pitch_interval, axis=1) -> transpose to another key
    #   np.roll(sample, time_steps, axis=0)     -> shift in time
    # The empty pack is sized from midi_res (previously a hard-coded 96), so it
    # concatenates with the chunks for any resolution.
    samples_by_instrument = defaultdict(lambda: np.empty((0, midi_res, 128)))

    for track in roll.Multitrack(midi_file_path).tracks:
        if track.is_drum:
            # TODO: add code for drum samples
            continue

        key = track.program + 1

        # Drop the trailing partial chunk, then cut into N x midi_res x 128.
        number_of_beats = track.pianoroll.shape[0] // midi_res
        track_pianoroll = track.pianoroll[: number_of_beats * midi_res]
        track_beats = track_pianoroll.reshape(number_of_beats, midi_res, 128)

        # Add this track's chunks to the pack collected for its instrument.
        samples_by_instrument[key] = np.concatenate(
            [track_beats, samples_by_instrument[key]], axis=0
        )

    return samples_by_instrument
|
2019-05-29 10:36:34 +02:00
|
|
|
|
|
|
|
def to_midi(samples, output_path=settings.generated_midi_path, program=0, tempo=120, beat_resolution=settings.beat_resolution):
    """Write a single-instrument pianoroll to a MIDI file.

    ``samples`` is wrapped in one ``pypianoroll.Track`` using ``program``. That
    track is placed in a ``Multitrack`` with the given tempo and beat
    resolution, and the result is written to ``output_path`` with
    ``pypianoroll.write``. The default values for ``output_path`` and
    ``beat_resolution`` are read from ``settings`` once, at import time.

    Args:
        samples: pianoroll handed to ``pypianoroll.Track``. Presumably shaped
            (time_steps, 128); TODO confirm against callers.
        output_path: destination MIDI file path.
        program: MIDI program number for the track.
        tempo: tempo passed to the ``Multitrack``.
        beat_resolution: time steps per beat passed to the ``Multitrack``.
    """
    # NOTE(review): the downbeat positions are fixed at steps 0/96/192/288.
    # They do not depend on beat_resolution or on the length of `samples`.
    # Confirm this matches the downbeat format the installed pypianoroll expects.
    multitrack = roll.Multitrack(
        tracks=[roll.Track(samples, program=program)],
        tempo=tempo,
        downbeat=[0, 96, 192, 288],
        beat_resolution=beat_resolution,
    )
    roll.write(multitrack, output_path)
|
2019-05-29 10:36:34 +02:00
|
|
|
|
2019-05-30 12:36:59 +02:00
|
|
|
def delete_empty_samples(sample_pack):
    """Return only the samples of ``sample_pack`` whose values do not sum to zero.

    A sample is kept when the sum of all its values is non-zero, the same
    criterion as the original per-sample loop. The check is vectorised over
    the first axis instead of looping in Python.

    The result always keeps the trailing dimensions, even when no sample
    survives (e.g. ``(0, 96, 128)`` rather than a 1-D ``(0,)`` array). This
    lets it still be concatenated with other packs along axis 0.

    Args:
        sample_pack: array-like whose first axis indexes the samples.

    Returns:
        A new array holding the non-empty samples in their original order.
    """
    pack = np.asarray(sample_pack)
    # One total per sample: sum over every axis except the sample axis.
    per_sample_sum = pack.sum(axis=tuple(range(1, pack.ndim)))
    # Boolean indexing copies the selected samples and preserves pack.shape[1:].
    return pack[per_sample_sum != 0]
|
|
|
|
|
2019-05-29 10:36:34 +02:00
|
|
|
|
|
|
|
def _save_preview(pack, png_path, grid=10):
    """Save a ``grid`` x ``grid`` image of randomly chosen samples from ``pack``."""
    fig, axes = plt.subplots(nrows=grid, ncols=grid, figsize=(2 * grid, 2 * grid))
    for ax in axes.ravel():
        sample = pack[np.random.randint(0, pack.shape[0])]
        ax.imshow(sample, cmap=plt.get_cmap('gray'))
    fig.savefig(png_path)
    # Without this, pyplot keeps every figure (one per instrument) alive for
    # the rest of the run.
    plt.close(fig)


def main():
    """Convert every file in ``settings.midi_dir`` into per-instrument sample packs.

    For each file:

    * ``to_samples`` splits its non-drum tracks into chunks.
    * ``delete_empty_samples`` drops silent chunks.
    * The remaining chunks are grouped by instrument key (``program + 1``).

    Each group is saved to ``data/samples/<name>.npz``, where ``name`` is
    looked up in ``settings.midi_program``. Next to it, a 10x10 grid of
    randomly chosen chunks is saved as ``data/samples/<name>.png``.
    """
    print('Exporting...')

    midi_res = settings.midi_resolution
    # Gather chunks in lists and concatenate once per instrument at the end.
    # Concatenating inside the loop re-copied the whole pack for every file.
    chunks_by_instrument = defaultdict(list)

    for midi_file in tqdm(os.listdir(settings.midi_dir)):
        midi_file_path = os.path.join(settings.midi_dir, midi_file)
        midi_samples = to_samples(midi_file_path)
        # NOTE(review): to_samples as written never returns None. The guard
        # is kept in case it later signals unreadable files this way.
        if midi_samples is None:
            continue

        # this is for instrument separation
        for key, value in midi_samples.items():
            value = delete_empty_samples(value)
            # Touch the key first, so instruments whose chunks were all silent
            # still get an (empty) pack saved.
            chunks = chunks_by_instrument[key]
            if len(value):
                chunks.append(value)

    print('Saving...')
    os.makedirs('data/samples', exist_ok=True)
    empty_pack = np.empty((0, midi_res, 128))
    for key, chunks in tqdm(chunks_by_instrument.items()):
        # Leading empty float pack: keeps the original dtype and handles
        # instruments that have no chunks.
        value = np.concatenate([empty_pack] + chunks, axis=0)
        name = settings.midi_program[key]
        np.savez_compressed('data/samples/{}.npz'.format(name), value)

        # Give a preview of what the samples look like.
        # With no samples there is nothing to draw, and randint(0, 0) would raise.
        if value.shape[0] == 0:
            continue
        _save_preview(value, 'data/samples/{}.png'.format(name))

    print('Done!')
|
2019-05-29 10:36:34 +02:00
|
|
|
|
|
|
|
# Run the export only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|