2019-05-29 10:36:34 +02:00
#!/usr/bin/env python3
import settings
import pypianoroll as roll
import matplotlib.pyplot as plt
import numpy as np
import os
from tqdm import tqdm
from math import floor
import sys
2019-05-30 11:23:34 +02:00
import pickle
from tqdm import tqdm
from tqdm import trange
from collections import defaultdict
import bz2
import pickle
2019-05-29 10:36:34 +02:00
def to_samples(midi_file_path, midi_res=settings.midi_resolution):
    """Cut every non-drum track of a MIDI file into sliding (X, y) windows.

    For each start offset ``i`` in a track's pianoroll, ``X`` is the
    ``m_timesteps`` rows starting at ``i`` and ``y`` is the ``n_next_notes``
    rows that immediately follow, so a model can learn to predict the next
    rows from the preceding ones.

    Args:
        midi_file_path: Path handed to ``roll.Multitrack``.
        midi_res: Currently unused; kept so existing callers keep working.

    Returns:
        A tuple ``(all_X_samples, all_y_samples)`` of two equally long lists.
        Each ``X`` has shape ``(96, 128)`` and each ``y`` has shape
        ``(4, 128)``. Tracks with no more than 100 rows contribute nothing.
    """
    # TODO: add transpositions of every sample to every possible key transposition
    #   np.roll(sample, pitch_interval, axis=1) for transposition
    #   np.roll(sample, time_steps, axis=0) for time shifting
    m_timesteps = 96
    n_next_notes = 4
    window = m_timesteps + n_next_notes

    all_X_samples = []
    all_y_samples = []
    for track in roll.Multitrack(midi_file_path).tracks:
        if track.is_drum:
            # TODO: add code for drums samples
            continue

        pianoroll = track.pianoroll
        # A negative count gives an empty range, so short tracks are skipped.
        track_timesteps = pianoroll.shape[0] - window
        for i in range(track_timesteps):
            split = i + m_timesteps
            X = pianoroll[i:split].reshape(m_timesteps, 128)
            y = pianoroll[split:split + n_next_notes].reshape(n_next_notes, 128)
            # The windows used to be computed and then dropped, so the
            # function always returned empty lists; collect them here.
            all_X_samples.append(X)
            all_y_samples.append(y)
    return all_X_samples, all_y_samples
def to_samples_by_instrument(midi_file_path, midi_res=settings.midi_resolution):
    """Build sliding (X, y) windows for a MIDI file, grouped by instrument.

    Every non-drum track is cut into windows of ``m_timesteps`` input rows
    (``X``) followed by ``n_next_notes`` target rows (``y``). Windows are
    stored under the key ``track.program + 1``; tracks sharing a program are
    concatenated in the order they are encountered.

    Args:
        midi_file_path: Path handed to ``roll.Multitrack``.
        midi_res: Currently unused; kept so existing callers keep working.

    Returns:
        A ``defaultdict`` mapping ``program + 1`` to a two-element list
        ``[X, y]``, where ``X`` has shape ``(N, 96, 128)`` and ``y`` has shape
        ``(N, 4, 128)``. Tracks with no more than 100 rows add no key.
    """
    # TODO: add transpositions of every sample to every possible key transposition
    #   np.roll(sample, pitch_interval, axis=1) for transposition
    #   np.roll(sample, time_steps, axis=0) for time shifting
    m_timesteps = 96
    n_next_notes = 4
    window = m_timesteps + n_next_notes

    def empty_pair():
        # The y placeholder must match the real target shape. It used to be
        # (0, 1, 128) while the 4-row target slice was reshaped to
        # (1, 1, 128), which raised ValueError on the first window.
        return [
            np.empty((0, m_timesteps, 128)),
            np.empty((0, n_next_notes, 128)),
        ]

    samples_by_instrument = defaultdict(empty_pair)

    for track in roll.Multitrack(midi_file_path).tracks:
        if track.is_drum:
            # TODO: add code for drums samples
            continue

        pianoroll = track.pianoroll
        n_windows = pianoroll.shape[0] - window
        if n_windows <= 0:
            # Too short for a single window; do not create an empty entry.
            continue

        X_windows = []
        y_windows = []
        for i in tqdm(range(n_windows)):
            split = i + m_timesteps
            X_windows.append(pianoroll[i:split].reshape(m_timesteps, 128))
            y_windows.append(
                pianoroll[split:split + n_next_notes].reshape(n_next_notes, 128)
            )

        # Concatenate once per track rather than once per window, which
        # copied the whole accumulated array on every step.
        key = track.program + 1
        bucket = samples_by_instrument[key]
        bucket[0] = np.concatenate([bucket[0], np.stack(X_windows)], axis=0)
        bucket[1] = np.concatenate([bucket[1], np.stack(y_windows)], axis=0)

    return samples_by_instrument
2019-05-29 10:36:34 +02:00
def to_midi(samples, output_path=settings.generated_midi_path, program=0, tempo=120, beat_resolution=settings.beat_resolution):
    """Wrap ``samples`` in a one-track Multitrack and write it to ``output_path``.

    Args:
        samples: Pianoroll data passed straight to ``roll.Track``.
        output_path: Destination handed to ``roll.write``.
        program: Program number assigned to the single track.
        tempo: Tempo passed to ``roll.Multitrack``.
        beat_resolution: Beat resolution passed to ``roll.Multitrack``.
    """
    single_track = roll.Track(samples, program=program)
    multitrack = roll.Multitrack(
        tracks=[single_track],
        tempo=tempo,
        downbeat=[0, 96, 192, 288],
        beat_resolution=beat_resolution,
    )
    roll.write(multitrack, output_path)
2019-05-29 10:36:34 +02:00
2019-05-29 10:50:00 +02:00
def delete_empty_samples(sample_pack):
    """Return ``sample_pack`` without the samples whose elements sum to zero.

    Samples are the entries along axis 0. All rows are filtered in one pass
    with a boolean mask. The earlier loop called ``np.delete`` once per empty
    sample, which was slow, and it started its index offset at 1, so it
    removed the row *before* each empty sample instead of the empty sample.

    Args:
        sample_pack: Array-like whose first axis enumerates the samples.

    Returns:
        A new array holding only the samples with a non-zero sum, in their
        original order. The input is not modified.
    """
    print('Deleting empty samples...')
    samples = np.asarray(sample_pack)
    # Sum over every axis except the first to get one total per sample.
    per_sample_total = samples.sum(axis=tuple(range(1, samples.ndim)))
    keep = per_sample_total != 0
    deleted = int(samples.shape[0] - np.count_nonzero(keep))
    print('Deleted {} empty samples'.format(deleted))
    return samples[keep]
def main():
    """Convert every file in ``settings.midi_dir`` into one training set.

    Each file is passed to ``to_samples``; the returned windows are gathered
    into two lists, converted to arrays and saved as compressed ``.npz``
    files under ``data/samples``.
    """
    X_train = []
    y_train = []
    for midi_file in tqdm(os.listdir(settings.midi_dir)):
        midi_file_path = os.path.join(settings.midi_dir, midi_file)
        X, y = to_samples(midi_file_path)
        # The per-file result used to be discarded, leaving both training
        # lists empty; accumulate it here.
        X_train.extend(X)
        y_train.extend(y)

    # TODO: delete empty samples (see delete_empty_samples).
    # TODO: per-instrument export - merge the dicts returned by
    #   to_samples_by_instrument and save one X_/y_ file pair per program,
    #   named via settings.midi_program.
    # TODO: optionally save a preview image grid of random samples.

    # This is for one big list.
    np_X_train = np.array(X_train)
    np_y_train = np.array(y_train)
    print(np_X_train.shape, np_y_train.shape)

    # np.savez_compressed does not create missing parent directories.
    os.makedirs('data/samples', exist_ok=True)
    np.savez_compressed('data/samples/X_{}.npz'.format(1), np_X_train)
    np.savez_compressed('data/samples/y_{}.npz'.format(1), np_y_train)

    print('Exported {} samples'.format(np_X_train.shape[0]))
2019-05-29 10:36:34 +02:00
if __name__ == '__main__':