praca-magisterska/project/midi.py

163 lines
6.8 KiB
Python
Raw Normal View History

2019-05-29 10:36:34 +02:00
#!/usr/bin/env python3
import settings
import pypianoroll as roll
import matplotlib.pyplot as plt
import numpy as np
import os
from tqdm import tqdm
from math import floor
import sys
import pickle
from tqdm import tqdm
from tqdm import trange
from collections import defaultdict
import bz2
import pickle
2019-05-29 10:36:34 +02:00
def to_samples(midi_file_path, midi_res=settings.midi_resolution, m_timesteps=96, n_next_notes=4):
    """Cut every non-drum track of a MIDI file into sliding (X, y) windows.

    The file is opened with ``roll.Multitrack``. For each window start ``i``
    in a track's pianoroll, ``X`` is the ``m_timesteps`` rows starting at
    ``i`` and ``y`` is the ``n_next_notes`` rows that immediately follow, so a
    model can learn to predict the next rows from the previous ones.

    Args:
        midi_file_path: Path handed to ``roll.Multitrack``.
        midi_res: Currently unused; kept so existing callers keep working.
        m_timesteps: Number of pianoroll rows in every ``X`` sample.
        n_next_notes: Number of pianoroll rows in every ``y`` sample.

    Returns:
        A tuple ``(all_X_samples, all_y_samples)`` of two equally long lists.
        Entries are arrays of shape ``(m_timesteps, 128)`` and
        ``(n_next_notes, 128)`` respectively, in track order and then in
        window order. Both lists are empty when no track is long enough.

    Raises:
        ValueError: If a window cannot be reshaped to 128 columns.
    """
    # TODO: add transpositions of every sample to every possible key:
    #   np.roll(sample, pitch_interval, axis=1) for transposition
    #   np.roll(sample, time_steps, axis=0) for time shifting
    all_X_samples = []
    all_y_samples = []
    window = m_timesteps + n_next_notes

    for track in roll.Multitrack(midi_file_path).tracks:
        if track.is_drum:
            # TODO: add code for drums samples
            continue

        pianoroll = track.pianoroll
        # A roll of T rows holds T - window + 1 complete windows. The previous
        # bound of T - window silently dropped the last one. range() is empty
        # for tracks shorter than a single window.
        n_windows = pianoroll.shape[0] - window + 1
        for start in range(n_windows):
            split = start + m_timesteps
            # The reshape doubles as a check that the roll has 128 columns.
            all_X_samples.append(pianoroll[start:split].reshape(m_timesteps, 128))
            all_y_samples.append(
                pianoroll[split:split + n_next_notes].reshape(n_next_notes, 128)
            )

    return all_X_samples, all_y_samples
def to_samples_by_instrument(midi_file_path, midi_res=settings.midi_resolution):
    """Cut a MIDI file into sliding (X, y) windows grouped by instrument.

    Every non-drum track of ``roll.Multitrack(midi_file_path)`` is split into
    windows of 96 input rows (``X``) followed by 4 target rows (``y``).
    Windows are grouped under the key ``track.program + 1``.

    Args:
        midi_file_path: Path handed to ``roll.Multitrack``.
        midi_res: Currently unused; kept so existing callers keep working.

    Returns:
        A ``defaultdict`` mapping ``track.program + 1`` to a two-item list
        ``[X, y]`` where ``X`` has shape ``(N, 96, 128)`` and ``y`` has shape
        ``(N, 4, 128)``. Within a key the most recently produced window comes
        first, which is the order the original prepend-based code built.
        Instruments without a single complete window get no key; looking up a
        missing key yields a pair of zero-length arrays.

    Raises:
        ValueError: If a window cannot be reshaped to 128 columns.
    """
    # TODO: add transpositions of every sample to every possible key:
    #   np.roll(sample, pitch_interval, axis=1) for transposition
    #   np.roll(sample, time_steps, axis=0) for time shifting
    # TODO: alternatively pack whole beats as N x 96 x 128 samples.
    m_timesteps = 96
    n_next_notes = 4
    window = m_timesteps + n_next_notes

    def fill_empty_array():
        # The y seed used to be (0, 1, 128), which cannot hold the 4-row
        # targets sliced below; the reshape to (1, 1, 128) raised ValueError
        # on the very first window.
        return [
            np.empty((0, m_timesteps, 128)),
            np.empty((0, n_next_notes, 128)),
        ]

    # Collect windows in plain lists and join once per instrument. Calling
    # np.concatenate for every window copied the whole pack each time.
    windows_by_instrument = defaultdict(lambda: ([], []))

    for track in roll.Multitrack(midi_file_path).tracks:
        if track.is_drum:
            # TODO: add code for drums samples
            continue

        pianoroll = track.pianoroll
        # A roll of T rows holds T - window + 1 complete windows.
        n_windows = pianoroll.shape[0] - window + 1
        if n_windows <= 0:
            continue

        X_windows, y_windows = windows_by_instrument[track.program + 1]
        for start in tqdm(range(n_windows)):
            split = start + m_timesteps
            X_windows.append(
                pianoroll[start:split].reshape(1, m_timesteps, 128)
            )
            y_windows.append(
                pianoroll[split:split + n_next_notes].reshape(1, n_next_notes, 128)
            )

    samples_by_instrument = defaultdict(fill_empty_array)
    for key, (X_windows, y_windows) in windows_by_instrument.items():
        X_seed, y_seed = fill_empty_array()
        # Reversed windows followed by the empty seed reproduce the layout of
        # the former per-window prepend, including its dtype promotion.
        samples_by_instrument[key] = [
            np.concatenate(X_windows[::-1] + [X_seed], axis=0),
            np.concatenate(y_windows[::-1] + [y_seed], axis=0),
        ]
    return samples_by_instrument
2019-05-29 10:36:34 +02:00
def to_midi(
    samples,
    output_path=settings.generated_midi_path,
    program=0,
    tempo=120,
    beat_resolution=settings.beat_resolution,
):
    """Write ``samples`` to ``output_path`` as a single-track MIDI file.

    Args:
        samples: Passed to ``roll.Track`` as its first argument (presumably a
            timesteps x 128 pianoroll array; confirm against callers).
        output_path: Destination handed to ``roll.write``.
        program: Program number given to the track.
        tempo: Tempo given to the multitrack.
        beat_resolution: Beat resolution given to the multitrack.

    Returns:
        None. The only effect is the file written by ``roll.write``.
    """
    single_track = roll.Track(samples, program=program)
    multitrack = roll.Multitrack(
        tracks=[single_track],
        tempo=tempo,
        downbeat=[0, 96, 192, 288],
        beat_resolution=beat_resolution,
    )
    roll.write(multitrack, output_path)
2019-05-29 10:36:34 +02:00
def delete_empty_samples(sample_pack):
    """Return ``sample_pack`` without the samples whose values sum to zero.

    Samples are the entries along axis 0. Emptiness is decided with one
    vectorised sum and a boolean mask. The earlier loop called ``np.delete``
    once per hit, which was slow and, because its offset counter started at 1,
    removed the sample *before* each empty one.

    Args:
        sample_pack: Array-like whose first axis enumerates the samples.

    Returns:
        ``sample_pack`` itself when nothing is empty, otherwise a new array
        holding only the non-empty samples in their original order.
    """
    print('Deleting empty samples...')
    pack = np.asarray(sample_pack)
    # Sum over every axis except the first to get one total per sample.
    per_sample_axes = tuple(range(1, pack.ndim))
    is_empty = pack.sum(axis=per_sample_axes) == 0
    deleted = int(is_empty.sum())
    print('Deleted {} empty samples'.format(deleted))
    if deleted == 0:
        return sample_pack
    return pack[~is_empty]
def main():
    """Export every file in ``settings.midi_dir`` as one compressed dataset.

    Each directory entry is passed to ``to_samples``. The resulting X and y
    windows are gathered into two arrays and written to
    ``data/samples/X_1.npz`` and ``data/samples/y_1.npz``.
    """
    print('Exporting...')
    # Create the output directory up front. Otherwise a missing directory
    # only surfaces at save time, after the whole corpus has been processed.
    os.makedirs('data/samples', exist_ok=True)

    X_train = []
    y_train = []
    for midi_file in tqdm(os.listdir(settings.midi_dir)):
        print(midi_file)
        midi_file_path = '{}/{}'.format(settings.midi_dir, midi_file)
        X, y = to_samples(midi_file_path)
        X_train.extend(X)
        y_train.extend(y)

    # TODO: drop empty samples with delete_empty_samples before saving.
    # TODO: per-instrument export: merge to_samples_by_instrument results per
    #   key and save one X_/y_ archive per settings.midi_program entry, or a
    #   bz2-compressed pickle of the whole dictionary.
    # TODO: save a preview grid of random samples as a PNG next to the data.

    # Everything currently goes into one big pair of arrays.
    print('Saving...')
    np_X_train = np.array(X_train)
    np_y_train = np.array(y_train)
    print(np_X_train.shape, np_y_train.shape)
    np.savez_compressed('data/samples/X_{}.npz'.format(1), np_X_train)
    np.savez_compressed('data/samples/y_{}.npz'.format(1), np_y_train)

    print('Exported {} samples'.format(np_X_train.shape[0]))
    print('Done!')
2019-05-29 10:36:34 +02:00
# Run the export only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()