2019-06-19 13:40:35 +02:00
|
|
|
#!python3
|
2019-05-29 10:36:34 +02:00
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
2019-06-19 13:40:35 +02:00
|
|
|
''' This module contains functions for encoding midi files into data samples
|
|
|
|
that are prepared for model training.
|
|
|
|
|
|
|
|
midi_folder_path - the path to the directory containing midi files
|
|
|
|
output_path - the output path where the data samples will be created
|
|
|
|
|
|
|
|
Usage:
|
2019-06-19 15:48:39 +02:00
|
|
|
>>> ./midi.py <midi_folder_path> <output_path> <sequence_length>
|
2019-06-19 13:40:35 +02:00
|
|
|
'''
|
|
|
|
|
2019-05-29 10:36:34 +02:00
|
|
|
import settings
|
|
|
|
import pypianoroll as roll
|
|
|
|
import numpy as np
|
|
|
|
import os
|
|
|
|
from tqdm import tqdm
|
|
|
|
from math import floor
|
|
|
|
import sys
|
2019-05-30 11:23:34 +02:00
|
|
|
from collections import defaultdict
|
|
|
|
import pickle
|
2019-06-01 17:05:38 +02:00
|
|
|
from music21 import converter, instrument, note, chord, stream
|
|
|
|
import music21
|
2019-05-29 10:36:34 +02:00
|
|
|
|
2019-06-19 15:48:39 +02:00
|
|
|
class MidiParseError(Exception):
    '''Raised when a midi file cannot be parsed.'''
|
|
|
|
|
|
|
|
|
|
|
|
def parse_argv(argv):
    '''Parse the arguments given when running the midi script.

    Parameters:
    - argv: argument list shaped like sys.argv, i.e. the script name
            followed by <midi_folder_path> <output_path> <sequence_length>.
            Elements after the third argument are ignored.

    Returns: Tuple of (midi_folder_path, output_path, seq_len), where seq_len
             has been converted to int.

    Raises:
    - AttributeError: when fewer than three arguments were passed.
    - ValueError: when the sequence length cannot be converted to int.
    '''
    # Only the indexing can raise IndexError, so keep the try body minimal.
    try:
        midi_folder_path = argv[1]
        output_path = argv[2]
        raw_seq_len = argv[3]
    except IndexError as error:
        # AttributeError is kept so existing callers that catch it still work.
        # Adjacent literals are used instead of a backslash continuation so
        # that no source indentation leaks into the message.
        raise AttributeError(
            'You probably did not pass parameters to run midi.py script.\n'
            '>>> ./midi.py <midi_folder_path> <output_path> <sequence_length>'
        ) from error

    return midi_folder_path, output_path, int(raw_seq_len)
|
2019-06-19 13:40:35 +02:00
|
|
|
|
|
|
|
def _encode_event(event):
    '''Encode one music21 event as a '<pitches>;<quarter length>' string.

    A Note contributes its single pitch; a Chord contributes all of its
    pitches joined by spaces. Returns None for any other kind of event.
    '''
    if isinstance(event, music21.note.Note):
        pitches = str(event.pitch)
    elif isinstance(event, music21.chord.Chord):
        pitches = ' '.join(str(pitch) for pitch in event.pitches)
    else:
        return None
    return '{};{}'.format(pitches, float(event.quarterLength))


def to_sequence(midi_path, seq_len):
    ''' Encode a single midi file into training samples grouped by instrument.

    This function is supposed to be used on one midi file in a directory loop.
    The notes and chords of every named part are encoded as strings, and then
    cut into sliding windows: each train_X sample holds seq_len consecutive
    events and the matching train_y sample is the event that follows them.

    Use for LSTM neural network.

    Parameters:
    - midi_path: path to midi file
    - seq_len: length of the sequence before prediction

    Returns: Tuple of train_X, train_y dictionaries consisting of samples of
             the song grouped by instrument (part name). Instruments with no
             more than seq_len events produce no entry at all.

    Raises:
    - MidiParseError: when music21 reports a MidiException for the file.
    '''
    try:
        midi_file = music21.converter.parse(midi_path)
    except music21.midi.MidiException as error:
        # Keep the original cause and name the offending file.
        raise MidiParseError(midi_path) from error

    parts = music21.instrument.partitionByInstrument(midi_file)
    if parts is None:
        # NOTE(review): some music21 versions appear to return None when the
        # score cannot be partitioned by instrument - confirm against the
        # installed version. Treated here as "no samples" instead of crashing.
        return defaultdict(list), defaultdict(list)

    seq_by_instrument = defaultdict(list)
    for part in parts:
        part_name = part.partName
        # Parts without a name cannot be grouped, so they are skipped whole.
        if part_name is None:
            continue
        for event in part:
            encoded = _encode_event(event)
            if encoded is not None:
                seq_by_instrument[part_name].append(encoded)

    X_train_by_instrument = defaultdict(list)
    y_train_by_instrument = defaultdict(list)

    for instrument_name, sequence in seq_by_instrument.items():
        # Appending inside the loop (rather than assigning a list) keeps
        # instruments that are too short for a single window out of the result.
        for start in range(len(sequence) - seq_len):
            end = start + seq_len
            X_train_by_instrument[instrument_name].append(np.array(sequence[start:end]))
            y_train_by_instrument[instrument_name].append(np.array(sequence[end]))

    return X_train_by_instrument, y_train_by_instrument
|
2019-05-30 13:36:15 +02:00
|
|
|
|
2019-06-19 15:48:39 +02:00
|
|
|
def colect_samples(midi_folder_path, seq_len):
    '''Walk the given directory tree and gather samples from its midi files.

    Parameters: midi_folder_path - a path to directory with midi files
                seq_len - the length of a train_X sample, i.e. how many
                          notes are given to the LSTM to predict the next note.

    Returns: Tuple of train_X, train_y dictionaries consisting of samples
             of all songs in the directory, grouped by instrument.
    '''
    print('Collecting samples...')

    collected_X = defaultdict(list)
    collected_y = defaultdict(list)

    for folder, _, file_names in os.walk(midi_folder_path):
        for file_name in tqdm(file_names):
            full_path = os.path.join(folder, file_name)
            try:
                file_X, file_y = to_sequence(full_path, seq_len)
            except MidiParseError:
                # Files that cannot be parsed are skipped.
                continue
            else:
                paired_items = zip(file_X.items(), file_y.items())
                for (X_name, X_samples), (y_name, y_samples) in paired_items:
                    collected_X[X_name].extend(np.array(X_samples))
                    collected_y[y_name].extend(np.array(y_samples))

    return collected_X, collected_y
|
|
|
|
|
|
|
|
def save_samples(output_path, samples):
    '''Save samples to compressed npz packages, one file per instrument.

    Parameters:
    - output_path: directory that receives the '<instrument>.npz' files;
                   it is created when it does not exist yet.
    - samples: tuple of (train_X, train_y) dictionaries keyed by instrument.

    Each package stores the train_X samples as its first positional array and
    the train_y samples as its second one. Instruments present in train_X but
    missing from train_y are skipped.
    '''
    print('Saving...')

    # exist_ok avoids the race between a separate exists() check and makedirs().
    os.makedirs(output_path, exist_ok=True)

    train_X, train_y = samples

    # X and y are paired by key, not by position, so a difference in dictionary
    # order cannot silently drop an instrument. Iterating the items view also
    # gives tqdm a length to display.
    for instrument_name, X_value in tqdm(train_X.items()):
        if instrument_name not in train_y:
            continue
        package_path = os.path.join(output_path, '{}.npz'.format(instrument_name))
        np.savez_compressed(
            package_path,
            np.array(X_value),
            np.array(train_y[instrument_name]),
        )
|
2019-05-30 11:23:34 +02:00
|
|
|
|
2019-06-19 15:48:39 +02:00
|
|
|
def main():
    '''Run the script: parse sys.argv, collect the samples and save them.'''
    midi_folder_path, output_path, seq_len = parse_argv(sys.argv)
    samples = colect_samples(midi_folder_path, seq_len)
    save_samples(output_path, samples)
    print('Done!')
|
2019-05-29 10:36:34 +02:00
|
|
|
|
|
|
|
# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|