praca-magisterska/project/midi.py

134 lines
5.1 KiB
Python

#!python3
#!/usr/bin/env python3
''' This module contains functions for encoding MIDI files into data samples
that are prepared for model training.
midi_folder_path - the path to the directory containing MIDI files
output_path - the output path where the data samples will be created
sequence_length - the number of events in each input sequence
Usage:
>>> ./midi.py <midi_folder_path> <output_path> <sequence_length>
'''
import settings
import pypianoroll as roll
import numpy as np
import os
from tqdm import tqdm
from math import floor
import sys
from collections import defaultdict
import pickle
from music21 import converter, instrument, note, chord, stream
import music21
class MidiParseError(Exception):
    '''Signals that a MIDI file could not be parsed.'''
def parse_argv(argv):
    '''Parse the arguments given when running the midi script.

    Parameters:
    - argv: argument list laid out like sys.argv (script name at index 0)
    Returns: Tuple of midi_folder_path, output_path, seq_len (seq_len as int)
    Raises: AttributeError when fewer than three arguments follow the script
            name; ValueError when the third argument is not an integer.
    '''
    # Only the indexing can raise IndexError, so keep the try body to that.
    try:
        midi_folder_path, output_path, raw_seq_len = argv[1], argv[2], argv[3]
    except IndexError:
        # The usage hint goes on its own line; the original backslash
        # continuation glued it directly onto the end of the sentence.
        # "from None" hides the IndexError, which adds nothing for the user.
        raise AttributeError(
            'You propably didnt pass parameters to run midi.py script.\n'
            '>>> ./midi.py <midi_folder_path> <output_path> <sequence_lenth>'
        ) from None
    return midi_folder_path, output_path, int(raw_seq_len)
def to_sequence(midi_path, seq_len):
    ''' Encode one midi file into training samples grouped by instrument.

    Every note or chord of a named part becomes a "<pitches>;<quarterLength>"
    string. A window of seq_len consecutive events is one train_X sample and
    the event right after the window is its train_y target.
    Use for LSTM neural network.
    Parameters:
    - midi_path: path to midi file
    - seq_len: length of sequence before prediction
    Returns: Tuple of train_X, train_y dictionaries consisting of samples of
             the song grouped by instrument (part name). An instrument with
             no more than seq_len events contributes no samples.
    Raises: MidiParseError when music21 raises MidiException while parsing.
    '''
    try:
        midi_file = music21.converter.parse(midi_path)
    except music21.midi.MidiException as err:
        raise MidiParseError(
            'Cannot parse midi file: {}'.format(midi_path)) from err

    parts = music21.instrument.partitionByInstrument(midi_file)
    if parts is None:
        # NOTE(review): partitionByInstrument appears to return None when the
        # score has no instrument objects in some music21 versions - confirm.
        # Returning empty results avoids a TypeError on iteration.
        return defaultdict(list), defaultdict(list)

    seq_by_instrument = defaultdict(list)
    for part in parts:
        part_name = part.partName
        # Events of unnamed parts cannot be attributed to an instrument.
        if part_name is None:
            continue
        for event in part:
            encoded = _encode_event(event)
            if encoded is not None:
                seq_by_instrument[part_name].append(encoded)

    X_train_by_instrument = defaultdict(list)
    y_train_by_instrument = defaultdict(list)
    for part_name, sequence in seq_by_instrument.items():
        for start in range(len(sequence) - seq_len):
            stop = start + seq_len
            X_train_by_instrument[part_name].append(
                np.array(sequence[start:stop]))
            y_train_by_instrument[part_name].append(np.array(sequence[stop]))
    return X_train_by_instrument, y_train_by_instrument


def _encode_event(event):
    '''Return the "<pitches>;<quarterLength>" string for a note or a chord,
    or None for any other kind of event.'''
    if isinstance(event, music21.note.Note):
        pitches = str(event.pitch)
    elif isinstance(event, music21.chord.Chord):
        # Chord pitches are space-separated inside the pitch field.
        pitches = ' '.join(str(pitch) for pitch in event.pitches)
    else:
        return None
    return '{};{}'.format(pitches, float(event.quarterLength))
def colect_samples(midi_folder_path, seq_len):
    '''Walk the given directory tree and collect samples from its midi files.

    Parameters:
    - midi_folder_path: a path to directory with midi files
    - seq_len: a length of train_X sample that tells how many notes
               are given to LSTM to predict the next note.
    Returns: Tuple of train_X, train_y dictionaries consisting of samples
             of all songs in the directory tree grouped by instruments.
    Files for which to_sequence raises MidiParseError are skipped.
    '''
    print('Collecting samples...')
    train_X = defaultdict(list)
    train_y = defaultdict(list)
    for directory, _subdirectories, files in os.walk(midi_folder_path):
        for midi_file in tqdm(files):
            midi_file_path = os.path.join(directory, midi_file)
            try:
                file_X, file_y = to_sequence(midi_file_path, seq_len)
            except MidiParseError:
                # Best effort: one unreadable file must not abort the run.
                continue
            # Pair inputs and targets by instrument key instead of zipping
            # two dicts, which only works while their orders stay in sync.
            for instrument_name, windows in file_X.items():
                train_X[instrument_name].extend(np.array(windows))
                train_y[instrument_name].extend(
                    np.array(file_y[instrument_name]))
    return train_X, train_y
def save_samples(output_path, samples):
    '''Save samples to compressed npz packages, one file per instrument.

    Parameters:
    - output_path: directory for the npz files, created when missing
    - samples: tuple of train_X, train_y dictionaries keyed by instrument
    Each "<instrument>.npz" holds the X array and then the y array, passed
    positionally to np.savez_compressed. Instruments that have no entry in
    train_y are skipped.
    '''
    print('Saving...')
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_path, exist_ok=True)
    train_X, train_y = samples
    # Iterating the dict items (not a zip) lets tqdm know the total, and
    # pairing by key does not depend on both dicts sharing the same order.
    for instrument_name, X_value in tqdm(train_X.items()):
        if instrument_name not in train_y:
            continue
        # Instrument names come from the midi files; a path separator in one
        # would point into a directory that does not exist.
        file_stem = str(instrument_name).replace('/', '_').replace(os.sep, '_')
        np.savez_compressed(
            os.path.join(output_path, '{}.npz'.format(file_stem)),
            np.array(X_value),
            np.array(train_y[instrument_name]),
        )
def main():
    '''Script entry point: parse sys.argv, collect the samples, save them.'''
    midi_folder_path, output_path, seq_len = parse_argv(sys.argv)
    collected = colect_samples(midi_folder_path, seq_len)
    save_samples(output_path, collected)
    print('Done!')
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()