add docstrings, fix choose_by_prob

This commit is contained in:
Cezary Pukownik 2019-06-19 13:40:35 +02:00
parent 75e1c6cb90
commit fee8de37cb
3 changed files with 90 additions and 23 deletions

View File

@ -1,42 +1,74 @@
#!python3
#!/usr/bin/env python3
''' This module generates a sample, and create a midi file.
Usage:
>>> ./generate.py [trained_model_path] [output_path]
'''
import settings
import sys
import random
import pickle
import numpy as np
import tensorflow as tf
import pypianoroll as roll
import matplotlib.pyplot as plt
from tqdm import trange, tqdm
from music21 import converter, instrument, note, chord, stream
from keras.layers import Input, Dense, Conv2D
from keras.models import Model
from tensorflow.keras import layers
from keras.layers import Input, Dense, Conv2D, Flatten, LSTM, Dropout, TimeDistributed, RepeatVector
from keras.models import Model, Sequential
import matplotlib.pyplot as plt
import settings
import random
import pickle
from tqdm import trange, tqdm
import sys
from music21 import converter, instrument, note, chord, stream
def choose_by_prob(list_of_probs):
''' This functions a list of values and assumed
that if the value is bigger it should by returned often
It was crated to give more options to choose than argmax function,
thus is more than one way that you can develop a melody.
Returns a index of choosen value from given list.
'''
sum_prob = np.array(list_of_probs).sum()
prob_normalized = [x/sum_prob for x in list_of_probs]
cumsum = np.array(prob_normalized).cumsum()
prob_cum = cumsum.tolist()
random_x = random.random()
for i, x in enumerate(prob_cum):
if random_x < x:
return i
trained_model_path = sys.argv[1]
output_path = sys.argv[2]
# load and predict
# load model and dictionary that can translate back index_numbers to notes
# this dictionary is generated with model
print('Loading... {}'.format(trained_model_path))
model = pickle.load(open(trained_model_path, 'rb'))
int_to_note = pickle.load(open('{}_dict'.format(trained_model_path), 'rb'))
seed = [random.randint(0,50) for x in range(8)]
# TODO: 16 it should a variable by integrated with model seq_len
# TODO: random.randint(0,50), the range should be a variable of lenght of vocab size
seed = [random.randint(0,250) for x in range(16)]
music = []
print('Generating...')
for i in trange(500):
predicted_vector = model.predict(np.array(seed).reshape(1,8,1))
predicted_index = np.argmax(predicted_vector)
for i in trange(124):
#TODO: 16 it should a variable by integrated with model seq_len
predicted_vector = model.predict(np.array(seed).reshape(1,16,1))
# using best fitted note
# predicted_index = np.argmax(predicted_vector)
# using propability distribution for choosing note
# to prevent looping
predicted_index = choose_by_prob(predicted_vector)
music.append(int_to_note[predicted_index])
seed.append(predicted_index)
seed = seed[1:9]
#TODO: 16 it should a variable by integrated with model seq_len
seed = seed[1:1+16]
print('Saving...')

View File

@ -1,5 +1,17 @@
#!python3
#!/usr/bin/env python3
''' This module contains functions to endocing midi files into data samples
that is prepared for model training.
midi_folder_path - the path to directiory containing midi files
output_path - the output path where will be created samples of data
Usage:
>>> ./midi.py <midi_folder_path> <output_path>
'''
import settings
import pypianoroll as roll
import numpy as np
@ -14,8 +26,22 @@ import music21
midi_folder_path = sys.argv[1]
output_path = sys.argv[2]
seq_len = int(sys.argv[3])
def to_sequence(midi_path):
def to_sequence(midi_path, seq_len):
''' This function is supposed to be used on one midi file in directory loop.
Its encoding midi files, into sequances of given lenth as a train_X,
and the next note as a train_y. Also splitting midi samples into
instrument group.
Use for LSTM neural network.
Parameters:
- midi_path: path to midi file
- seq_len: lenght of sequance before prediction
Returns: Tuple of train_X, train_y directories'''
seq_by_instrument = defaultdict( lambda : [] )
midi_file = music21.converter.parse(midi_path)
stream = music21.instrument.partitionByInstrument(midi_file)
@ -36,9 +62,9 @@ def to_sequence(midi_path):
y_train_by_instrument = defaultdict( lambda : [] )
for instrument, sequence in seq_by_instrument.items():
for i in range(len(sequence)-8) :
X_train_by_instrument[instrument].append(np.array(sequence[i: i + 8])) # <seq lenth
y_train_by_instrument[instrument].append(np.array(sequence[i + 8]))
for i in range(len(sequence)-(seq_len)) :
X_train_by_instrument[instrument].append(np.array(sequence[i:i+seq_len])) # <seq lenth
y_train_by_instrument[instrument].append(np.array(sequence[i+seq_len]))
# TODO: Notes to integers
return X_train_by_instrument, y_train_by_instrument
@ -52,7 +78,13 @@ def main():
for directory, subdirectories, files in os.walk(midi_folder_path):
for midi_file in tqdm(files):
midi_file_path = os.path.join(directory, midi_file)
_X_train, _y_train = to_sequence(midi_file_path)
# some midi files can be corupted, and cannot be parsed
# so we just omit corupted files, and go to the next file.
try:
_X_train, _y_train = to_sequence(midi_file_path, seq_len)
except music21.midi.MidiException:
continue
for (X_key, X_value), (y_key, y_value) in zip(_X_train.items(), _y_train.items()):
train_X[X_key].extend(np.array(X_value))

View File

@ -1,8 +1,9 @@
#!python3
#!/usr/bin/env python3
import tensorflow as tf
import settings
from tensorflow.keras import layers
#from tensorflow.keras import layers
from keras.layers import Input, Dense, Conv2D, Flatten, LSTM, Dropout, TimeDistributed, RepeatVector, Activation, Bidirectional, Reshape
from keras.models import Model, Sequential
from keras.utils.np_utils import to_categorical
@ -18,7 +19,9 @@ def load_data(samples_path):
# TODO: make transformer class with fit, transform and reverse definitions
def preprocess_samples(train_X, train_y):
vocab = np.unique(train_X)
vocab_X = np.unique(train_X)
vocab_y = np.unique(train_y)
vocab = np.concatenate([vocab_X, vocab_y])
n_vocab = vocab.shape[0]
note_to_int = dict((note, number) for number, note in enumerate(vocab))
int_to_note = dict((number, note) for number, note in enumerate(vocab))