Merge branch 'seq2seq_model' of s444337/praca-magisterska into master
seq2seq model
This commit is contained in:
commit
76db0a4b9d
BIN
project/__pycache__/midi_processing.cpython-36.pyc
Normal file
BIN
project/__pycache__/midi_processing.cpython-36.pyc
Normal file
Binary file not shown.
BIN
project/__pycache__/model.cpython-36.pyc
Normal file
BIN
project/__pycache__/model.cpython-36.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Before Width: | Height: | Size: 7.8 KiB |
Binary file not shown.
Binary file not shown.
Before Width: | Height: | Size: 7.8 KiB |
74
project/extract.py
Normal file
74
project/extract.py
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import argparse
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
from midi_processing import extract_data, analyze_data
|
||||||
|
|
||||||
|
def make_folder_if_not_exist(path):
|
||||||
|
try:
|
||||||
|
os.mkdir(path)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def parse_argv():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('midi_pack', help='folder name for midi pack in midi_packs folder', type=str)
|
||||||
|
parser.add_argument('--n', help='name for experiment', type=str)
|
||||||
|
parser.add_argument('--b', help='lengh of sequence in bars', type=int)
|
||||||
|
parser.add_argument('-a', help='analize data', action='store_true')
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
def ask_for_workflow():
|
||||||
|
'''MODEL WORKFLOW DIALOG'''
|
||||||
|
number_of_instruments = int(input('Please specify number of instruments\n'))
|
||||||
|
model_workflow = dict()
|
||||||
|
for i in range(number_of_instruments):
|
||||||
|
input_string = input('Please specify a workflow step <Instrument> [<Second Instrument>] <mode> {m : melody, a : arrangment}\n')
|
||||||
|
tokens = input_string.split()
|
||||||
|
if tokens[-1] == 'm':
|
||||||
|
model_workflow[i] = (tokens[0], 'melody')
|
||||||
|
elif tokens[-1] == 'a':
|
||||||
|
model_workflow[i] = ((tokens[1], tokens[0]), 'arrangment')
|
||||||
|
else:
|
||||||
|
raise ValueError("The step definitiom must end with {'m', 'a'}");
|
||||||
|
|
||||||
|
make_folder_if_not_exist(os.path.join('training_sets', EXPERIMENT_NAME))
|
||||||
|
pickle.dump(model_workflow, open(os.path.join('training_sets', EXPERIMENT_NAME, 'workflow.pkl'),'wb'))
|
||||||
|
|
||||||
|
return model_workflow
|
||||||
|
|
||||||
|
def extract_from_folder(model_workflow):
|
||||||
|
for key, (instrument, how) in model_workflow.items():
|
||||||
|
if how == 'melody':
|
||||||
|
instrument_name = instrument
|
||||||
|
else:
|
||||||
|
instrument_name = instrument[1]
|
||||||
|
|
||||||
|
make_folder_if_not_exist(os.path.join('training_sets', EXPERIMENT_NAME))
|
||||||
|
save_path = os.path.join('training_sets', EXPERIMENT_NAME, instrument_name.lower() + '_data.pkl')
|
||||||
|
|
||||||
|
x_train, y_train, program = extract_data(midi_folder_path=os.path.join('midi_packs', MIDI_PACK_NAME),
|
||||||
|
how=how,
|
||||||
|
instrument=instrument,
|
||||||
|
bar_in_seq=BARS_IN_SEQ)
|
||||||
|
|
||||||
|
pickle.dump((x_train, y_train, program), open(save_path,'wb'))
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args = parse_argv()
|
||||||
|
|
||||||
|
MIDI_PACK_NAME = args.midi_pack
|
||||||
|
EXPERIMENT_NAME = args.n
|
||||||
|
BARS_IN_SEQ = args.b
|
||||||
|
if not EXPERIMENT_NAME:
|
||||||
|
EXPERIMENT_NAME = MIDI_PACK_NAME
|
||||||
|
if not BARS_IN_SEQ:
|
||||||
|
BARS_IN_SEQ = 4
|
||||||
|
ANALIZE = args.a
|
||||||
|
|
||||||
|
if ANALIZE:
|
||||||
|
analyze_data(os.path.join('midi_packs', MIDI_PACK_NAME))
|
||||||
|
else:
|
||||||
|
extract_from_folder(ask_for_workflow())
|
@ -1,96 +1,94 @@
|
|||||||
#!python3
|
from midi_processing import MultiTrack, SingleTrack, Stream
|
||||||
#!/usr/bin/env python3
|
from model import Seq2SeqModel, seq_to_numpy
|
||||||
''' This module generates a sample, and create a midi file.
|
from tqdm import tqdm
|
||||||
|
import argparse
|
||||||
Usage:
|
import os
|
||||||
>>> ./generate.py [trained_model_path] [output_path]
|
|
||||||
|
|
||||||
'''
|
|
||||||
import settings
|
|
||||||
import sys
|
|
||||||
import random
|
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
import numpy as np
|
parser = argparse.ArgumentParser()
|
||||||
import tensorflow as tf
|
parser.add_argument('n', help='name for experiment', type=str)
|
||||||
import pypianoroll as roll
|
parser.add_argument('s', help='session name', type=str)
|
||||||
import matplotlib.pyplot as plt
|
parser.add_argument('--i', help='number of midis to generate', type=int)
|
||||||
from tqdm import trange, tqdm
|
parser.add_argument('--l', help='latent_dim_of_model', type=int)
|
||||||
from music21 import converter, instrument, note, chord, stream
|
parser.add_argument('--m', help="mode {'from_seq', 'from_state}'", type=str)
|
||||||
from keras.layers import Input, Dense, Conv2D
|
args = parser.parse_args()
|
||||||
from keras.models import Model
|
|
||||||
from keras.layers import Input, Dense, Conv2D, Flatten, LSTM, Dropout, TimeDistributed, RepeatVector
|
|
||||||
from keras.models import Model, Sequential
|
|
||||||
|
|
||||||
|
EXPERIMENT_NAME = args.n
|
||||||
|
SESSION_NAME = args.s
|
||||||
|
GENERETIONS_COUNT = args.i
|
||||||
|
LATENT_DIM = args.l
|
||||||
|
MODE = args.m
|
||||||
|
|
||||||
def choose_by_prob(list_of_probs):
|
if not GENERETIONS_COUNT:
|
||||||
''' This functions a list of values and assumed
|
GENERETIONS_COUNT = 1
|
||||||
that if the value is bigger it should by returned often
|
|
||||||
|
|
||||||
It was crated to give more options to choose than argmax function,
|
if not LATENT_DIM:
|
||||||
thus is more than one way that you can develop a melody.
|
LATENT_DIM = 256
|
||||||
|
|
||||||
Returns a index of choosen value from given list.
|
if not MODE:
|
||||||
'''
|
MODE = 'from_seq'
|
||||||
sum_prob = np.array(list_of_probs).sum()
|
|
||||||
prob_normalized = [x/sum_prob for x in list_of_probs]
|
|
||||||
cumsum = np.array(prob_normalized).cumsum()
|
|
||||||
prob_cum = cumsum.tolist()
|
|
||||||
random_x = random.random()
|
|
||||||
for i, x in enumerate(prob_cum):
|
|
||||||
if random_x < x:
|
|
||||||
return i
|
|
||||||
|
|
||||||
trained_model_path = sys.argv[1]
|
model_workflow = pickle.load(open(os.path.join('training_sets', EXPERIMENT_NAME, 'workflow.pkl'),'rb'))
|
||||||
output_path = sys.argv[2]
|
|
||||||
|
|
||||||
# load model and dictionary that can translate back index_numbers to notes
|
band = dict()
|
||||||
# this dictionary is generated with model
|
for key, value in model_workflow.items():
|
||||||
print('Loading... {}'.format(trained_model_path))
|
if isinstance(value[0], str):
|
||||||
model = pickle.load(open(trained_model_path, 'rb'))
|
instrument = value[0]
|
||||||
int_to_note, n_vocab, seq_len = pickle.load(open('{}_dict'.format(trained_model_path), 'rb'))
|
generator = None
|
||||||
|
|
||||||
seed = [random.randint(0,n_vocab) for x in range(seq_len)]
|
|
||||||
music = []
|
|
||||||
|
|
||||||
print('Generating...')
|
|
||||||
for i in trange(124):
|
|
||||||
predicted_vector = model.predict(np.array(seed).reshape(1,seq_len,1))
|
|
||||||
# using best fitted note
|
|
||||||
# predicted_index = np.argmax(predicted_vector)
|
|
||||||
# using propability distribution for choosing note
|
|
||||||
# to prevent looping
|
|
||||||
predicted_index = choose_by_prob(predicted_vector)
|
|
||||||
music.append(int_to_note[predicted_index])
|
|
||||||
seed.append(predicted_index)
|
|
||||||
seed = seed[1:1+seq_len]
|
|
||||||
|
|
||||||
|
|
||||||
print('Saving...')
|
|
||||||
offset = 0
|
|
||||||
output_notes = []
|
|
||||||
for _event in tqdm(music):
|
|
||||||
event, note_len = _event.split(';')
|
|
||||||
if (' ' in event) or event.isdigit():
|
|
||||||
notes_in_chord = event.split(' ')
|
|
||||||
notes = []
|
|
||||||
for current_note in notes_in_chord:
|
|
||||||
new_note = note.Note(current_note)
|
|
||||||
new_note.storedInstrument = instrument.Piano()
|
|
||||||
notes.append(new_note)
|
|
||||||
new_chord = chord.Chord(notes)
|
|
||||||
new_chord.offset = offset
|
|
||||||
output_notes.append(new_chord)
|
|
||||||
else:
|
else:
|
||||||
new_note = note.Note(event)
|
instrument = value[0][1]
|
||||||
new_note.offset = offset
|
generator = value[0][0]
|
||||||
new_note.storedInstrument = instrument.Piano()
|
|
||||||
output_notes.append(new_note)
|
|
||||||
|
|
||||||
offset += float(note_len)
|
band[instrument] = [None, None, generator]
|
||||||
|
|
||||||
midi_stream = stream.Stream(output_notes)
|
'''LOAD MODELS'''
|
||||||
|
for instrument in tqdm(band):
|
||||||
|
|
||||||
midi_stream.write('midi', fp='{}.mid'.format(output_path))
|
data_path = os.path.join('training_sets', EXPERIMENT_NAME, instrument.lower() + '_data.pkl')
|
||||||
|
model_path = os.path.join('models', EXPERIMENT_NAME, instrument.lower() + '_model.h5')
|
||||||
|
|
||||||
print('Done!')
|
x_train, y_train, program = pickle.load(open(data_path,'rb'))
|
||||||
|
model = Seq2SeqModel(LATENT_DIM, x_train, y_train)
|
||||||
|
model.load(model_path)
|
||||||
|
band[instrument][0] = model
|
||||||
|
band[instrument][1] = program
|
||||||
|
|
||||||
|
for midi_counter in tqdm(range(GENERETIONS_COUNT)):
|
||||||
|
''' MAKE MULTIINSTRUMENTAL MUSIC !!!'''
|
||||||
|
notes = dict()
|
||||||
|
|
||||||
|
for instrument, (model, program, generator) in band.items():
|
||||||
|
if generator == None:
|
||||||
|
notes[instrument] = model.develop(mode=MODE)
|
||||||
|
else:
|
||||||
|
input_data = seq_to_numpy(notes[generator],
|
||||||
|
model.transformer.x_max_seq_length,
|
||||||
|
model.transformer.x_vocab_size,
|
||||||
|
model.transformer.x_transform_dict)
|
||||||
|
notes[instrument] = model.predict(input_data)[:-1]
|
||||||
|
|
||||||
|
'''COMPILE TO MIDI'''
|
||||||
|
generated_midi = MultiTrack()
|
||||||
|
for instrument, (model, program, generator) in band.items():
|
||||||
|
if instrument == 'Drums':
|
||||||
|
is_drums = True
|
||||||
|
else:
|
||||||
|
is_drums = False
|
||||||
|
|
||||||
|
stream = Stream(first_tick=0, notes=notes[instrument])
|
||||||
|
track = SingleTrack(name=instrument ,program=program, is_drum=is_drums, stream=stream)
|
||||||
|
generated_midi.tracks.append(track)
|
||||||
|
|
||||||
|
# make folder for new experiment
|
||||||
|
try:
|
||||||
|
os.mkdir(os.path.join('generated_music', EXPERIMENT_NAME))
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
os.mkdir(os.path.join('generated_music', EXPERIMENT_NAME, SESSION_NAME))
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
save_path = os.path.join('generated_music', EXPERIMENT_NAME, SESSION_NAME, f'{EXPERIMENT_NAME}_{midi_counter}_{MODE}_{LATENT_DIM}.mid')
|
||||||
|
generated_midi.save(save_path)
|
||||||
|
# print(f'Generated succefuly to {save_path}')
|
||||||
|
133
project/midi.py
133
project/midi.py
@ -1,133 +0,0 @@
|
|||||||
#!python3
|
|
||||||
#!/usr/bin/env python3
|
|
||||||
|
|
||||||
''' This module contains functions to endocing midi files into data samples
|
|
||||||
that is prepared for model training.
|
|
||||||
|
|
||||||
midi_folder_path - the path to directiory containing midi files
|
|
||||||
output_path - the output path where will be created samples of data
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
>>> ./midi.py <midi_folder_path> <output_path> <sequence_lenth>
|
|
||||||
'''
|
|
||||||
|
|
||||||
import settings
|
|
||||||
import pypianoroll as roll
|
|
||||||
import numpy as np
|
|
||||||
import os
|
|
||||||
from tqdm import tqdm
|
|
||||||
from math import floor
|
|
||||||
import sys
|
|
||||||
from collections import defaultdict
|
|
||||||
import pickle
|
|
||||||
from music21 import converter, instrument, note, chord, stream
|
|
||||||
import music21
|
|
||||||
|
|
||||||
class MidiParseError(Exception):
|
|
||||||
"""Error that is raised then midi file cannot be parsed"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def parse_argv(argv):
|
|
||||||
'''This function is parsing given arguments when running a midi script.
|
|
||||||
Returns a tuple consinting of midi_folder_path, output_path, seq_len'''
|
|
||||||
try:
|
|
||||||
midi_folder_path = argv[1]
|
|
||||||
output_path = argv[2]
|
|
||||||
seq_len = int(argv[3])
|
|
||||||
return midi_folder_path, output_path, seq_len
|
|
||||||
except IndexError:
|
|
||||||
raise AttributeError('You propably didnt pass parameters to run midi.py script.\
|
|
||||||
>>> ./midi.py <midi_folder_path> <output_path> <sequence_lenth>')
|
|
||||||
|
|
||||||
def to_sequence(midi_path, seq_len):
|
|
||||||
''' This function is supposed to be used on one midi file in directory loop.
|
|
||||||
Its encoding midi files, into sequances of given lenth as a train_X,
|
|
||||||
and the next note as a train_y. Also splitting midi samples into
|
|
||||||
instrument group.
|
|
||||||
|
|
||||||
Use for LSTM neural network.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- midi_path: path to midi file
|
|
||||||
- seq_len: lenght of sequance before prediction
|
|
||||||
|
|
||||||
Returns: Tuple of train_X, train_y dictionaries consisinting of samples of song grouped by instruments
|
|
||||||
'''
|
|
||||||
|
|
||||||
seq_by_instrument = defaultdict( lambda : [] )
|
|
||||||
|
|
||||||
try:
|
|
||||||
midi_file = music21.converter.parse(midi_path)
|
|
||||||
except music21.midi.MidiException:
|
|
||||||
raise MidiParseError
|
|
||||||
stream = music21.instrument.partitionByInstrument(midi_file)
|
|
||||||
for part in stream:
|
|
||||||
for event in part:
|
|
||||||
if part.partName != None:
|
|
||||||
if isinstance(event, music21.note.Note):
|
|
||||||
to_export_event = '{};{}'.format(str(event.pitch), float(event.quarterLength))
|
|
||||||
seq_by_instrument[part.partName].append(to_export_event)
|
|
||||||
elif isinstance(event, music21.chord.Chord):
|
|
||||||
to_export_event = '{};{}'.format(' '.join(str(note) for note in event.pitches), float(event.quarterLength))
|
|
||||||
seq_by_instrument[part.partName].append(to_export_event)
|
|
||||||
|
|
||||||
X_train_by_instrument = defaultdict( lambda : [] )
|
|
||||||
y_train_by_instrument = defaultdict( lambda : [] )
|
|
||||||
|
|
||||||
for instrument, sequence in seq_by_instrument.items():
|
|
||||||
for i in range(len(sequence)-(seq_len)) :
|
|
||||||
X_train_by_instrument[instrument].append(np.array(sequence[i:i+seq_len])) # <seq lenth
|
|
||||||
y_train_by_instrument[instrument].append(np.array(sequence[i+seq_len]))
|
|
||||||
|
|
||||||
return X_train_by_instrument, y_train_by_instrument
|
|
||||||
|
|
||||||
def colect_samples(midi_folder_path, seq_len):
|
|
||||||
'''This function is looping throuth given directories and
|
|
||||||
collecting samples from midi files.
|
|
||||||
|
|
||||||
Parameters: midi_folder_path - a path to directory with midi files
|
|
||||||
seq_len - a lenth of train_X sample that tells
|
|
||||||
how many notes is given do LSTM to predict the next note.
|
|
||||||
|
|
||||||
Returns: Tuple of train_X, train_y dictionaries consisinting
|
|
||||||
of samples of all songs in directory grouped by instruments.
|
|
||||||
'''
|
|
||||||
|
|
||||||
print('Collecting samples...')
|
|
||||||
train_X = defaultdict( lambda : [] )
|
|
||||||
train_y = defaultdict( lambda : [] )
|
|
||||||
|
|
||||||
for directory, subdirectories, files in os.walk(midi_folder_path):
|
|
||||||
for midi_file in tqdm(files):
|
|
||||||
midi_file_path = os.path.join(directory, midi_file)
|
|
||||||
try:
|
|
||||||
_X_train, _y_train = to_sequence(midi_file_path, seq_len)
|
|
||||||
except MidiParseError:
|
|
||||||
continue
|
|
||||||
for (X_key, X_value), (y_key, y_value) in zip(_X_train.items(), _y_train.items()):
|
|
||||||
train_X[X_key].extend(np.array(X_value))
|
|
||||||
train_y[y_key].extend(np.array(y_value))
|
|
||||||
|
|
||||||
return train_X, train_y
|
|
||||||
|
|
||||||
def save_samples(output_path, samples):
|
|
||||||
'''This function save samples to npz packages, splitted by instrument.'''
|
|
||||||
|
|
||||||
print('Saving...')
|
|
||||||
|
|
||||||
if not os.path.exists(output_path):
|
|
||||||
os.makedirs(output_path)
|
|
||||||
|
|
||||||
train_X, train_y = samples
|
|
||||||
for (X_key, X_value), (y_key, y_value) in tqdm(zip(train_X.items(), train_y.items())):
|
|
||||||
if X_key == y_key:
|
|
||||||
np.savez_compressed('{}/{}.npz'.format(output_path, X_key), np.array(X_value), np.array(y_value))
|
|
||||||
|
|
||||||
def main():
|
|
||||||
midi_folder_path, output_path, seq_len = parse_argv(sys.argv)
|
|
||||||
save_samples(output_path, colect_samples(midi_folder_path, seq_len))
|
|
||||||
print('Done!')
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
@ -1,324 +1,16 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# coding: utf-8
|
|
||||||
|
|
||||||
# In[1]:
|
|
||||||
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
import pickle
|
import pickle
|
||||||
|
import operator
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from collections import Counter
|
||||||
|
from random import randint
|
||||||
|
|
||||||
import pretty_midi as pm
|
import pretty_midi as pm
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# TODO: Stream class is no logner needed <- remore from code and make just SingleTrack.notes instead on SingleTrack.stream.notes
|
||||||
# In[98]:
|
|
||||||
|
|
||||||
|
|
||||||
TODO = '''
|
|
||||||
TODO: put methods of data extraction for seq2seq arangment model to multitrack class [DONE]
|
|
||||||
TODO: make functions for data extraction for seq2seq model for riff/melody generation [DONE]
|
|
||||||
'''
|
|
||||||
|
|
||||||
|
|
||||||
# In[367]:
|
|
||||||
|
|
||||||
|
|
||||||
# '''return a dictionary with tracks indexes grouped by instrument class'''
|
|
||||||
# tracks = file.tracks
|
|
||||||
# names = [track.name for track in tracks]
|
|
||||||
# uniqe_instruemnts = set(names)
|
|
||||||
# tracks_by_instrument = dict()
|
|
||||||
# for key in uniqe_instruemnts:
|
|
||||||
# tracks_by_instrument[key] = []
|
|
||||||
|
|
||||||
# for i, track in enumerate(tracks):
|
|
||||||
# tracks_by_instrument[track.name].append(i)
|
|
||||||
|
|
||||||
# tracks_by_instrument
|
|
||||||
|
|
||||||
|
|
||||||
# In[368]:
|
|
||||||
|
|
||||||
|
|
||||||
# def get_posible_pairs(instrument_x, instrument_y):
|
|
||||||
# '''it takes two lists, and return a list of tuples with every posible 2-element combination
|
|
||||||
# parameters:
|
|
||||||
# -----------
|
|
||||||
# instrument_x, instrument_y : string {'Guitar','Bass','Drums'}
|
|
||||||
# a string that represent a instrument class you want to look for in midi file.
|
|
||||||
|
|
||||||
# returns:
|
|
||||||
# ----------
|
|
||||||
# pairs: list of tuples
|
|
||||||
# a list of posible 2-element combination of two lists
|
|
||||||
# '''
|
|
||||||
# x_indexes = tracks_by_instrument[instrument_x]
|
|
||||||
# y_indexes = tracks_by_instrument[instrument_y]
|
|
||||||
# pairs = []
|
|
||||||
# # pairs = [(x,y) for x in x_indexes for y in y_indexes]
|
|
||||||
|
|
||||||
# for x in x_indexes:
|
|
||||||
# for y in y_indexes:
|
|
||||||
# pairs.append((x,y))
|
|
||||||
|
|
||||||
# return pairs
|
|
||||||
|
|
||||||
|
|
||||||
# In[369]:
|
|
||||||
|
|
||||||
|
|
||||||
# def get_common_bars_for_every_possible_pair(pairs)
|
|
||||||
# ''' for every possible pair of given instrument classes
|
|
||||||
# returns common bars from multitrack'''
|
|
||||||
# x_bars = []
|
|
||||||
# y_bars = []
|
|
||||||
# for x_track_index, y_track_index in pairs:
|
|
||||||
# _x_bars, _y_bars = get_common_bars(file.tracks[x_track_index], file.tracks[y_track_index])
|
|
||||||
# x_bars.extend(_x_bars)
|
|
||||||
# y_bars.extend(_y_bars)
|
|
||||||
|
|
||||||
# return x_bars, y_bars
|
|
||||||
|
|
||||||
|
|
||||||
# In[370]:
|
|
||||||
|
|
||||||
|
|
||||||
# def get_data_seq2seq_arrangment(self, bars_in_seq):
|
|
||||||
# ## This is the end of extracting data from midis to seq2seq arranging network.
|
|
||||||
# '''this method is returning a sequances of given lenth by rolling this lists of x and y for arrangemt generation'''
|
|
||||||
# x_seq = []
|
|
||||||
# y_seq = []
|
|
||||||
|
|
||||||
# for i in range(len(x_bars) - bars_in_seq + 1):
|
|
||||||
# x_seq_to_add = [note for bar in x_bars[i:i+bars_in_seq] for note in bar ]
|
|
||||||
# y_seq_to_add = [note for bar in y_bars[i:i+bars_in_seq] for note in bar ]
|
|
||||||
# x_seq.append(x_seq_to_add)
|
|
||||||
# y_seq.append(y_seq_to_add)
|
|
||||||
|
|
||||||
# len(x_seq), len(y_seq)
|
|
||||||
# # get_bar_len(y_seq[0])
|
|
||||||
|
|
||||||
|
|
||||||
# In[371]:
|
|
||||||
|
|
||||||
|
|
||||||
# def get_track_by_instrument(self):
|
|
||||||
# '''return a dictionary with tracks indexes grouped by instrument class'''
|
|
||||||
# tracks = self.tracks
|
|
||||||
# names = [track.name for track in tracks]
|
|
||||||
# uniqe_instruemnts = set(names)
|
|
||||||
# tracks_by_instrument = dict()
|
|
||||||
# for key in uniqe_instruemnts:
|
|
||||||
# tracks_by_instrument[key] = []
|
|
||||||
|
|
||||||
# for i, track in enumerate(tracks):
|
|
||||||
# tracks_by_instrument[track.name].append(i)
|
|
||||||
|
|
||||||
# return tracks_by_instrument
|
|
||||||
|
|
||||||
|
|
||||||
# In[372]:
|
|
||||||
|
|
||||||
|
|
||||||
# def get_data_seq2seq_melody(self,instrument_class, x_seq_len=4)
|
|
||||||
# '''return a list of bars with content for every track with given instrument class for melody generaiton'''
|
|
||||||
|
|
||||||
# instrument_tracks = tracks_by_instrument[instrument_class]
|
|
||||||
|
|
||||||
# for track_index in instrument_tracks:
|
|
||||||
# # make below as function: get_bars_with_content
|
|
||||||
# bars = file.tracks[track_index].stream_to_bars()
|
|
||||||
# bars_indexes_with_content = get_bar_indexes_with_content(bars)
|
|
||||||
# bars_with_content = [bars[i] for i in get_bar_indexes_with_content(bars)]
|
|
||||||
|
|
||||||
# # make below as function: get_sequances_from_bars (for seq2seq melody generator)
|
|
||||||
# x_seq = []
|
|
||||||
# y_bar = []
|
|
||||||
# for i in range(len(bars_with_content)-seq_len-1):
|
|
||||||
# _x_seq = bars_with_content[i:i+seq_len]
|
|
||||||
# _y_bar = bars_with_content[i+seq_len]
|
|
||||||
# x_seq.append(_x_seq)
|
|
||||||
# y_bar.append(_y_bar)
|
|
||||||
|
|
||||||
|
|
||||||
# len(x_seq), len(y_bar)
|
|
||||||
# # print( ' x:' ,x_seq[1],'\n', 'y: ', y_bar[1],'\n', 'seq: ',bars_with_content[1:6])
|
|
||||||
|
|
||||||
|
|
||||||
# In[15]:
|
|
||||||
|
|
||||||
|
|
||||||
def get_bar_indexes_with_content(bars):
|
|
||||||
'''this method is looking for non-empty bars in the tracks bars
|
|
||||||
the empty bar consist of only rest notes.
|
|
||||||
returns: a set of bars indexes with notes
|
|
||||||
'''
|
|
||||||
bars_indexes_with_content = set()
|
|
||||||
for i, bar in enumerate(bars):
|
|
||||||
if bar_has_content(bar):
|
|
||||||
bars_indexes_with_content.add(i)
|
|
||||||
|
|
||||||
return bars_indexes_with_content
|
|
||||||
|
|
||||||
|
|
||||||
# In[4]:
|
|
||||||
|
|
||||||
|
|
||||||
def get_bars_with_content(bars):
|
|
||||||
'''this method is looking for non-empty bars in the tracks bars
|
|
||||||
the empty bar consist of only rest notes.
|
|
||||||
returns: a set of bars with notes
|
|
||||||
'''
|
|
||||||
bars_with_content = []
|
|
||||||
for bar in bars:
|
|
||||||
if bar_has_content(bar):
|
|
||||||
bars_with_content.append(bar)
|
|
||||||
|
|
||||||
return bars_with_content
|
|
||||||
|
|
||||||
|
|
||||||
# In[5]:
|
|
||||||
|
|
||||||
|
|
||||||
def get_common_bars(track_x,track_y):
|
|
||||||
'''return common bars, for two tracks is song
|
|
||||||
return X_train, y_train list of
|
|
||||||
'''
|
|
||||||
bars_x = track_x.stream_to_bars()
|
|
||||||
bars_y = track_y.stream_to_bars()
|
|
||||||
bwc_x = get_bar_indexes_with_content(bars_x)
|
|
||||||
bwc_y = get_bar_indexes_with_content(bars_y)
|
|
||||||
|
|
||||||
common_bars = bwc_x.intersection(bwc_y)
|
|
||||||
common_bars_x = [bars_x[i] for i in common_bars]
|
|
||||||
common_bars_y = [bars_y[i] for i in common_bars]
|
|
||||||
return common_bars_x, common_bars_y
|
|
||||||
|
|
||||||
|
|
||||||
# In[6]:
|
|
||||||
|
|
||||||
|
|
||||||
def get_bar_len(bar):
|
|
||||||
"""calculate a lenth of a bar
|
|
||||||
parameters:
|
|
||||||
bar : list
|
|
||||||
list of "notes", tuples like (pitches, len)
|
|
||||||
"""
|
|
||||||
time = 0
|
|
||||||
for note in bar:
|
|
||||||
time += note[1]
|
|
||||||
return time
|
|
||||||
|
|
||||||
|
|
||||||
# In[7]:
|
|
||||||
|
|
||||||
|
|
||||||
def bar_has_content(bar):
|
|
||||||
'''check if bar has any musical information, more accurate
|
|
||||||
it checks if in a bar is any non-rest event like note, or chord
|
|
||||||
|
|
||||||
parameters:
|
|
||||||
-----------
|
|
||||||
bar: list
|
|
||||||
list of notes
|
|
||||||
|
|
||||||
return:
|
|
||||||
-------
|
|
||||||
bool:
|
|
||||||
True if bas has concent and False of doesn't
|
|
||||||
'''
|
|
||||||
bar_notes = len(bar)
|
|
||||||
count_rest = 0
|
|
||||||
for note in bar:
|
|
||||||
if note[0] == (-1,):
|
|
||||||
count_rest += 1
|
|
||||||
if count_rest == bar_notes:
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
# In[8]:
|
|
||||||
|
|
||||||
|
|
||||||
def round_to_sixteenth_note(x, base=0.25):
|
|
||||||
'''round value to closest multiplication by base
|
|
||||||
in default to 0.25 witch is sisteenth note accuracy
|
|
||||||
'''
|
|
||||||
|
|
||||||
return base * round(x/base)
|
|
||||||
|
|
||||||
|
|
||||||
# In[9]:
|
|
||||||
|
|
||||||
|
|
||||||
def parse_pretty_midi_instrument(instrument, resolution, time_to_tick, key_offset):
|
|
||||||
''' arguments: a prettyMidi instrument object
|
|
||||||
return: a custom SingleTrack object
|
|
||||||
'''
|
|
||||||
|
|
||||||
first_tick = None
|
|
||||||
prev_tick = 0
|
|
||||||
prev_note_lenth = 0
|
|
||||||
max_rest_len = 4.0
|
|
||||||
|
|
||||||
notes = defaultdict(lambda:[set(), set()])
|
|
||||||
for note in instrument.notes:
|
|
||||||
if first_tick == None:
|
|
||||||
# first_tick = round_to_sixteenth_note(time_to_tick(note.start)/resolution)
|
|
||||||
first_tick = 0
|
|
||||||
|
|
||||||
tick = round_to_sixteenth_note(time_to_tick(note.start)/resolution)
|
|
||||||
# add rest if needed
|
|
||||||
if prev_tick != None:
|
|
||||||
act_tick = prev_tick + prev_note_lenth
|
|
||||||
if act_tick < tick:
|
|
||||||
rest_lenth = tick - act_tick
|
|
||||||
while rest_lenth > max_rest_len:
|
|
||||||
notes[act_tick] = [{-1},{max_rest_len}]
|
|
||||||
act_tick += max_rest_len
|
|
||||||
rest_lenth -= max_rest_len
|
|
||||||
notes[act_tick] = [{-1},{rest_lenth}]
|
|
||||||
|
|
||||||
note_lenth = round_to_sixteenth_note(time_to_tick(note.end-note.start)/resolution)
|
|
||||||
|
|
||||||
if -1 in notes[tick][0]:
|
|
||||||
notes[tick] = [set(), set()]
|
|
||||||
|
|
||||||
if instrument.is_drum:
|
|
||||||
notes[tick][0].add(note.pitch)
|
|
||||||
else:
|
|
||||||
notes[tick][0].add(note.pitch+key_offset)
|
|
||||||
notes[tick][1].add(note_lenth)
|
|
||||||
|
|
||||||
prev_tick = tick
|
|
||||||
prev_note_lenth = note_lenth
|
|
||||||
|
|
||||||
notes = [(tuple(e[0]), max(e[1])) for e in notes.values()]
|
|
||||||
|
|
||||||
name = 'Drums' if instrument.is_drum else pm.program_to_instrument_class(instrument.program)
|
|
||||||
return SingleTrack(name, instrument.program, instrument.is_drum, Stream(first_tick,notes) )
|
|
||||||
|
|
||||||
|
|
||||||
# In[10]:
|
|
||||||
|
|
||||||
|
|
||||||
def remove_duplicated_sequences(xy_tuple):
|
|
||||||
x = xy_tuple[0]
|
|
||||||
y = xy_tuple[1]
|
|
||||||
x_freeze = [tuple(seq) for seq in x]
|
|
||||||
y_freeze = [tuple(seq) for seq in y]
|
|
||||||
unique_data = list(set(zip(x_freeze,y_freeze)))
|
|
||||||
x_unique = [seq[0] for seq in unique_data]
|
|
||||||
y_unique = [seq[1] for seq in unique_data]
|
|
||||||
return x_unique, y_unique
|
|
||||||
|
|
||||||
|
|
||||||
# In[11]:
|
|
||||||
|
|
||||||
|
|
||||||
class Stream():
|
class Stream():
|
||||||
|
|
||||||
def __init__ (self, first_tick, notes):
|
def __init__ (self, first_tick, notes):
|
||||||
@ -328,10 +20,6 @@ class Stream():
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '<Stream object with {} musical events>'.format(len(self.notes))
|
return '<Stream object with {} musical events>'.format(len(self.notes))
|
||||||
|
|
||||||
|
|
||||||
# In[12]:
|
|
||||||
|
|
||||||
|
|
||||||
class SingleTrack():
|
class SingleTrack():
|
||||||
'''class of single track in midi file encoded from pretty midi library
|
'''class of single track in midi file encoded from pretty midi library
|
||||||
|
|
||||||
@ -352,6 +40,7 @@ class SingleTrack():
|
|||||||
self.program = program
|
self.program = program
|
||||||
self.is_drum = is_drum
|
self.is_drum = is_drum
|
||||||
self.stream = stream
|
self.stream = stream
|
||||||
|
self.is_melody = self.check_if_melody()
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "<SingleTrack object. Name:{}, Program:{}, is_drum:{}>".format(self.name, self.program, self.is_drum)
|
return "<SingleTrack object. Name:{}, Program:{}, is_drum:{}>".format(self.name, self.program, self.is_drum)
|
||||||
@ -387,13 +76,21 @@ class SingleTrack():
|
|||||||
def stream_to_bars(self, beat_per_bar=4):
|
def stream_to_bars(self, beat_per_bar=4):
|
||||||
'''it takes notes and split it into equaly time distibuted sequances
|
'''it takes notes and split it into equaly time distibuted sequances
|
||||||
if note is between bars, the note is splited into two notes, with time sum equal to the note between bars.
|
if note is between bars, the note is splited into two notes, with time sum equal to the note between bars.
|
||||||
|
|
||||||
arguments:
|
arguments:
|
||||||
|
----------
|
||||||
stream: list of "notes"
|
stream: list of "notes"
|
||||||
|
|
||||||
return:
|
return:
|
||||||
|
-------
|
||||||
bars: list: list of lists of notes, every list has equal time. in musical context it returns bars
|
bars: list: list of lists of notes, every list has equal time. in musical context it returns bars
|
||||||
'''
|
'''
|
||||||
# TODO: if last bar of sequance has less notes to has time equal given bar lenth it is left shorter
|
# TODO: if last bar of sequance has less notes to has time equal given bar lenth it is left shorter
|
||||||
# fill the rest of bar with rests
|
# fill the rest of bar with rests
|
||||||
|
|
||||||
|
# FIXME: there is a problem, where note is longer that bar and negative time occured
|
||||||
|
# split note to max_rest_note, the problem occured when note is longer then 2 bars
|
||||||
|
|
||||||
notes = self.stream.notes
|
notes = self.stream.notes
|
||||||
bars = []
|
bars = []
|
||||||
time = 0
|
time = 0
|
||||||
@ -408,8 +105,15 @@ class SingleTrack():
|
|||||||
bars.append([])
|
bars.append([])
|
||||||
|
|
||||||
if add_tail:
|
if add_tail:
|
||||||
bars[bar_index].append(tail_note)
|
tail_pitch = note_pitch(tail_note)
|
||||||
time += note_len(tail_note)
|
while tail_note_len > beat_per_bar:
|
||||||
|
bars[bar_index].append((tail_pitch, beat_per_bar))
|
||||||
|
tail_note_len -= beat_per_bar
|
||||||
|
bar_index += 1
|
||||||
|
bars.append([])
|
||||||
|
|
||||||
|
bars[bar_index].append((tail_pitch, tail_note_len))
|
||||||
|
time += tail_note_len
|
||||||
add_tail = False
|
add_tail = False
|
||||||
|
|
||||||
time += note_len(note)
|
time += note_len(note)
|
||||||
@ -436,8 +140,41 @@ class SingleTrack():
|
|||||||
|
|
||||||
return bars
|
return bars
|
||||||
|
|
||||||
|
def check_if_melody(self):
|
||||||
|
'''checks if Track object could be a melody
|
||||||
|
|
||||||
# In[99]:
|
it checks if percentage of single notes in Track.stream.notes is higher than treshold of 90%
|
||||||
|
TODO: and there is at least 3 notes in bar per average
|
||||||
|
|
||||||
|
'''
|
||||||
|
events = None
|
||||||
|
single_notes = None
|
||||||
|
content_lenth = None
|
||||||
|
|
||||||
|
for note in self.stream.notes:
|
||||||
|
if self.name not in ['Bass','Drums']:
|
||||||
|
events = 0
|
||||||
|
content_lenth = 0
|
||||||
|
single_notes = 0
|
||||||
|
if note[0][0] != -1: # if note is not a rest
|
||||||
|
events += 1
|
||||||
|
content_lenth += note[1]
|
||||||
|
if len(note[0]) == 1: # if note is a single note, not a chord
|
||||||
|
single_notes += 1
|
||||||
|
|
||||||
|
if events != None:
|
||||||
|
if events == 0 or content_lenth == 0:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
single_notes_rate = single_notes/events
|
||||||
|
density_rate = events/content_lenth
|
||||||
|
if single_notes_rate >= 0.9 and density_rate < 2:
|
||||||
|
self.name = 'Melody'
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class MultiTrack():
|
class MultiTrack():
|
||||||
@ -452,13 +189,14 @@ class MultiTrack():
|
|||||||
|
|
||||||
def __init__(self, path=None, tempo=100):
|
def __init__(self, path=None, tempo=100):
|
||||||
self.tempo = tempo
|
self.tempo = tempo
|
||||||
self.pm_obj = pm.PrettyMIDI(path, initial_tempo=self.tempo)
|
self.pm_obj = pm.PrettyMIDI(path, initial_tempo=self.tempo) # changename to self.PrettyMIDI
|
||||||
self.res = self.pm_obj.resolution
|
self.res = self.pm_obj.resolution
|
||||||
self.time_to_tick = self.pm_obj.time_to_tick
|
self.time_to_tick = self.pm_obj.time_to_tick
|
||||||
self.name = path
|
self.name = path
|
||||||
self.tracks = [parse_pretty_midi_instrument(instrument, self.res, self.time_to_tick, self.get_pitch_offset_to_C() ) for instrument in self.pm_obj.instruments]
|
self.tracks = [parse_pretty_midi_instrument(instrument, self.res, self.time_to_tick, self.get_pitch_offset_to_C() ) for instrument in self.pm_obj.instruments]
|
||||||
self.tracks_by_instrument = self.get_track_by_instrument()
|
self.tracks_by_instrument = self.get_track_by_instrument()
|
||||||
|
|
||||||
|
# TODO: this function is deprecated <- remove from code
|
||||||
def get_multiseq(self):
|
def get_multiseq(self):
|
||||||
'''tracks: list of SingleTrack objects
|
'''tracks: list of SingleTrack objects
|
||||||
reaturn a dictionary of sequences for every sequence in SingleTrack
|
reaturn a dictionary of sequences for every sequence in SingleTrack
|
||||||
@ -476,6 +214,14 @@ class MultiTrack():
|
|||||||
|
|
||||||
return multiseq
|
return multiseq
|
||||||
|
|
||||||
|
def get_programs(self, instrument):
|
||||||
|
program_list = []
|
||||||
|
for track in self.tracks:
|
||||||
|
if track.name == instrument:
|
||||||
|
program_list.append(track.program)
|
||||||
|
|
||||||
|
return program_list
|
||||||
|
|
||||||
def get_pitch_offset_to_C(self):
|
def get_pitch_offset_to_C(self):
|
||||||
'''to get better train resoult without augmenting midis to all posible keys
|
'''to get better train resoult without augmenting midis to all posible keys
|
||||||
we assumed that most frequent note is the rootnote of song then calculate
|
we assumed that most frequent note is the rootnote of song then calculate
|
||||||
@ -526,7 +272,10 @@ class MultiTrack():
|
|||||||
return x_bars, y_bars
|
return x_bars, y_bars
|
||||||
|
|
||||||
def get_data_seq2seq_arrangment(self, x_instrument, y_instrument, bars_in_seq=4):
|
def get_data_seq2seq_arrangment(self, x_instrument, y_instrument, bars_in_seq=4):
|
||||||
'''this method is returning a sequances of given lenth by rolling this lists of x and y for arrangemt generation'''
|
'''this method is returning a sequances of given lenth by rolling this lists of x and y for arrangemt generation
|
||||||
|
x and y has the same bar lenth, and represent the same musical phrase playd my difrent instruments (tracks)
|
||||||
|
|
||||||
|
'''
|
||||||
x_seq = []
|
x_seq = []
|
||||||
y_seq = []
|
y_seq = []
|
||||||
x_bars, y_bars = self.get_common_bars_for_every_possible_pair(x_instrument, y_instrument)
|
x_bars, y_bars = self.get_common_bars_for_every_possible_pair(x_instrument, y_instrument)
|
||||||
@ -540,7 +289,12 @@ class MultiTrack():
|
|||||||
return x_seq, y_seq
|
return x_seq, y_seq
|
||||||
|
|
||||||
def get_data_seq2seq_melody(self,instrument_class, x_seq_len=4):
|
def get_data_seq2seq_melody(self,instrument_class, x_seq_len=4):
|
||||||
'''return a list of bars with content for every track with given instrument class for melody generaiton'''
|
'''return a list of bars with content for every track with given instrument class for melody generaiton
|
||||||
|
x_seq_len and y_seq_len
|
||||||
|
|
||||||
|
x previous sentence, y next sentence of the same melody line
|
||||||
|
|
||||||
|
'''
|
||||||
|
|
||||||
instrument_tracks = self.tracks_by_instrument[instrument_class]
|
instrument_tracks = self.tracks_by_instrument[instrument_class]
|
||||||
|
|
||||||
@ -573,13 +327,7 @@ class MultiTrack():
|
|||||||
'''
|
'''
|
||||||
x_indexes = self.tracks_by_instrument[instrument_x]
|
x_indexes = self.tracks_by_instrument[instrument_x]
|
||||||
y_indexes = self.tracks_by_instrument[instrument_y]
|
y_indexes = self.tracks_by_instrument[instrument_y]
|
||||||
# pairs = []
|
|
||||||
pairs = [(x,y) for x in x_indexes for y in y_indexes]
|
pairs = [(x,y) for x in x_indexes for y in y_indexes]
|
||||||
|
|
||||||
# for x in x_indexes:
|
|
||||||
# for y in y_indexes:
|
|
||||||
# pairs.append((x,y))
|
|
||||||
|
|
||||||
return pairs
|
return pairs
|
||||||
|
|
||||||
def show_map(self):
|
def show_map(self):
|
||||||
@ -597,11 +345,215 @@ class MultiTrack():
|
|||||||
print(track.name[:4],':', track_str)
|
print(track.name[:4],':', track_str)
|
||||||
|
|
||||||
|
|
||||||
|
def stream_to_bars(notes, beat_per_bar=4):
|
||||||
|
'''it takes notes and split it into equaly time distibuted sequances
|
||||||
|
if note is between bars, the note is splited into two notes, with time sum equal to the note between bars.
|
||||||
|
arguments:
|
||||||
|
stream: list of "notes"
|
||||||
|
return:
|
||||||
|
bars: list: list of lists of notes, every list has equal time. in musical context it returns bars
|
||||||
|
'''
|
||||||
|
# TODO: if last bar of sequance has less notes to has time equal given bar lenth it is left shorter
|
||||||
|
# fill the rest of bar with rests
|
||||||
|
|
||||||
# In[104]:
|
# FIXME: there is a problem, where note is longer that bar and negative time occured
|
||||||
|
# split note to max_rest_note, the problem occured when note is longer then 2 bars - FIXED
|
||||||
|
|
||||||
|
bars = []
|
||||||
|
time = 0
|
||||||
|
bar_index = 0
|
||||||
|
add_tail = False
|
||||||
|
note_pitch = lambda note: note[0]
|
||||||
|
note_len = lambda note: note[1]
|
||||||
|
for note in notes:
|
||||||
|
try:
|
||||||
|
temp = bars[bar_index]
|
||||||
|
except IndexError:
|
||||||
|
bars.append([])
|
||||||
|
|
||||||
|
if add_tail:
|
||||||
|
tail_pitch = note_pitch(tail_note)
|
||||||
|
while tail_note_len > beat_per_bar:
|
||||||
|
bars[bar_index].append((tail_pitch, beat_per_bar))
|
||||||
|
tail_note_len -= beat_per_bar
|
||||||
|
bar_index += 1
|
||||||
|
|
||||||
|
bars[bar_index].append((tail_pitch, tail_note_len))
|
||||||
|
time += tail_note_len
|
||||||
|
add_tail = False
|
||||||
|
time += note_len(note)
|
||||||
|
|
||||||
|
if time == beat_per_bar:
|
||||||
|
bars[bar_index].append(note)
|
||||||
|
time = 0
|
||||||
|
bar_index += 1
|
||||||
|
|
||||||
|
elif time > beat_per_bar: # if note is between bars
|
||||||
|
between_bars_note_len = note_len(note)
|
||||||
|
tail_note_len = time - beat_per_bar
|
||||||
|
leading_note_len = between_bars_note_len - tail_note_len
|
||||||
|
leading_note = (note_pitch(note), leading_note_len)
|
||||||
|
bars[bar_index].append(leading_note)
|
||||||
|
tail_note = (note_pitch(note), tail_note_len)
|
||||||
|
|
||||||
|
add_tail = True
|
||||||
|
time = 0
|
||||||
|
bar_index += 1
|
||||||
|
else:
|
||||||
|
bars[bar_index].append(note)
|
||||||
|
|
||||||
|
return bars
|
||||||
|
|
||||||
|
def get_bar_len(bar):
|
||||||
|
"""calculate a lenth of a bar
|
||||||
|
parameters:
|
||||||
|
bar : list
|
||||||
|
list of "notes", tuples like (pitches, len)
|
||||||
|
"""
|
||||||
|
time = 0
|
||||||
|
for note in bar:
|
||||||
|
time += note[1]
|
||||||
|
return time
|
||||||
|
|
||||||
|
def get_common_bars(track_x,track_y):
|
||||||
|
'''return common bars, for two tracks is song
|
||||||
|
return X_train, y_train list of
|
||||||
|
'''
|
||||||
|
bars_x = track_x.stream_to_bars()
|
||||||
|
bars_y = track_y.stream_to_bars()
|
||||||
|
bwc_x = get_bar_indexes_with_content(bars_x)
|
||||||
|
bwc_y = get_bar_indexes_with_content(bars_y)
|
||||||
|
|
||||||
|
common_bars = bwc_x.intersection(bwc_y)
|
||||||
|
common_bars_x = [bars_x[i] for i in common_bars]
|
||||||
|
common_bars_y = [bars_y[i] for i in common_bars]
|
||||||
|
return common_bars_x, common_bars_y
|
||||||
|
|
||||||
|
def get_bar_indexes_with_content(bars):
|
||||||
|
'''this method is looking for non-empty bars in the tracks bars
|
||||||
|
the empty bar consist of only rest notes.
|
||||||
|
returns: a set of bars indexes with notes
|
||||||
|
'''
|
||||||
|
bars_indexes_with_content = set()
|
||||||
|
for i, bar in enumerate(bars):
|
||||||
|
if bar_has_content(bar):
|
||||||
|
bars_indexes_with_content.add(i)
|
||||||
|
|
||||||
|
return bars_indexes_with_content
|
||||||
|
|
||||||
|
def get_bars_with_content(bars):
|
||||||
|
'''this method is looking for non-empty bars in the tracks bars
|
||||||
|
the empty bar consist of only rest notes.
|
||||||
|
returns: a set of bars with notes
|
||||||
|
'''
|
||||||
|
bars_with_content = []
|
||||||
|
for bar in bars:
|
||||||
|
if bar_has_content(bar):
|
||||||
|
bars_with_content.append(bar)
|
||||||
|
|
||||||
|
return bars_with_content
|
||||||
|
|
||||||
|
|
||||||
def extract_data(midi_folder_path=None, how=None, instrument=None, remove_duplicates=True):
|
def bar_has_content(bar):
|
||||||
|
'''check if bar has any musical information, more accurate
|
||||||
|
it checks if in a bar is any non-rest event like note, or chord
|
||||||
|
|
||||||
|
parameters:
|
||||||
|
-----------
|
||||||
|
bar: list
|
||||||
|
list of notes
|
||||||
|
|
||||||
|
return:
|
||||||
|
-------
|
||||||
|
bool:
|
||||||
|
True if bas has concent and False of doesn't
|
||||||
|
'''
|
||||||
|
bar_notes = len(bar)
|
||||||
|
count_rest = 0
|
||||||
|
for note in bar:
|
||||||
|
if note[0] == (-1,):
|
||||||
|
count_rest += 1
|
||||||
|
if count_rest == bar_notes:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def round_to_sixteenth_note(x, base=0.25):
|
||||||
|
'''round value to closest multiplication by base
|
||||||
|
in default to 0.25 witch is sisteenth note accuracy
|
||||||
|
'''
|
||||||
|
|
||||||
|
return base * round(x/base)
|
||||||
|
|
||||||
|
def parse_pretty_midi_instrument(instrument, resolution, time_to_tick, key_offset):
|
||||||
|
''' arguments: a prettyMidi instrument object
|
||||||
|
return: a custom SingleTrack object
|
||||||
|
'''
|
||||||
|
|
||||||
|
first_tick = None
|
||||||
|
prev_tick = 0
|
||||||
|
prev_note_lenth = 0
|
||||||
|
max_rest_len = 4.0
|
||||||
|
|
||||||
|
notes = defaultdict(lambda:[set(), set()])
|
||||||
|
for note in instrument.notes:
|
||||||
|
if first_tick == None:
|
||||||
|
first_tick = 0
|
||||||
|
|
||||||
|
tick = round_to_sixteenth_note(time_to_tick(note.start)/resolution)
|
||||||
|
if prev_tick != None:
|
||||||
|
act_tick = prev_tick + prev_note_lenth
|
||||||
|
if act_tick < tick:
|
||||||
|
rest_lenth = tick - act_tick
|
||||||
|
while rest_lenth > max_rest_len:
|
||||||
|
notes[act_tick] = [{-1},{max_rest_len}]
|
||||||
|
act_tick += max_rest_len
|
||||||
|
rest_lenth -= max_rest_len
|
||||||
|
notes[act_tick] = [{-1},{rest_lenth}]
|
||||||
|
|
||||||
|
note_lenth = round_to_sixteenth_note(time_to_tick(note.end-note.start)/resolution)
|
||||||
|
|
||||||
|
if -1 in notes[tick][0]:
|
||||||
|
notes[tick] = [set(), set()]
|
||||||
|
|
||||||
|
if instrument.is_drum:
|
||||||
|
notes[tick][0].add(note.pitch)
|
||||||
|
else:
|
||||||
|
notes[tick][0].add(note.pitch+key_offset)
|
||||||
|
|
||||||
|
notes[tick][1].add(note_lenth)
|
||||||
|
|
||||||
|
prev_tick = tick
|
||||||
|
prev_note_lenth = note_lenth
|
||||||
|
|
||||||
|
notes = [(tuple(e[0]), max(e[1])) for e in notes.values()]
|
||||||
|
|
||||||
|
name = 'Drums' if instrument.is_drum else pm.program_to_instrument_class(instrument.program)
|
||||||
|
return SingleTrack(name, instrument.program, instrument.is_drum, Stream(first_tick,notes) )
|
||||||
|
|
||||||
|
def remove_duplicated_sequences(xy_tuple):
|
||||||
|
''' removes duplicated x,y sequences
|
||||||
|
parameters:
|
||||||
|
-----------
|
||||||
|
xy_tuple: tuple of lists
|
||||||
|
tuple of x,y lists that represens sequances in training set
|
||||||
|
|
||||||
|
return:
|
||||||
|
------
|
||||||
|
x_unique, y_unique: tuple
|
||||||
|
a tuple of cleaned x, y traing set
|
||||||
|
'''
|
||||||
|
x = xy_tuple[0]
|
||||||
|
y = xy_tuple[1]
|
||||||
|
x_freeze = [tuple(seq) for seq in x]
|
||||||
|
y_freeze = [tuple(seq) for seq in y]
|
||||||
|
unique_data = list(set(zip(x_freeze,y_freeze)))
|
||||||
|
x_unique = [seq[0] for seq in unique_data]
|
||||||
|
y_unique = [seq[1] for seq in unique_data]
|
||||||
|
return x_unique, y_unique
|
||||||
|
|
||||||
|
|
||||||
|
def extract_data(midi_folder_path=None, how=None, instrument=None, bar_in_seq=4, remove_duplicates=True):
|
||||||
'''extract musical data from midis in given folder, to x_train, y_train lists on sequences
|
'''extract musical data from midis in given folder, to x_train, y_train lists on sequences
|
||||||
|
|
||||||
parameters:
|
parameters:
|
||||||
@ -628,49 +580,74 @@ def extract_data(midi_folder_path=None, how=None, instrument=None, remove_duplic
|
|||||||
|
|
||||||
notes:
|
notes:
|
||||||
------
|
------
|
||||||
extracted data is transposed to the key od C
|
extracted data is transposed to the key of C
|
||||||
duplicated x,y pairs are removed
|
duplicated x,y pairs are removed
|
||||||
'''
|
'''
|
||||||
if how not in {'melody','arrangment'}:
|
if how not in {'melody','arrangment'}:
|
||||||
raise ValueError('how parameter must by one of {melody,arrangment} ')
|
raise ValueError('how parameter must by one of {melody, arrangment} ')
|
||||||
|
|
||||||
x_train = []
|
x_train = []
|
||||||
y_train = []
|
y_train = []
|
||||||
|
|
||||||
|
programs_for_instrument = []
|
||||||
|
|
||||||
|
from collections import Counter
|
||||||
|
|
||||||
|
for directory, subdirectories, files in os.walk(midi_folder_path):
|
||||||
|
for midi_file in tqdm(files, desc='Exporting: {}'.format(instrument)):
|
||||||
|
midi_file_path = os.path.join(directory, midi_file)
|
||||||
|
try:
|
||||||
|
mt = MultiTrack(midi_file_path)
|
||||||
|
# get programs
|
||||||
|
mt.get_programs(instrument)
|
||||||
|
|
||||||
|
if how=='melody':
|
||||||
|
x ,y = mt.get_data_seq2seq_melody(instrument, bar_in_seq)
|
||||||
|
programs_for_instrument.extend(mt.get_programs(instrument))
|
||||||
|
if how=='arrangment':
|
||||||
|
x ,y = mt.get_data_seq2seq_arrangment(instrument[0], instrument[1], bar_in_seq)
|
||||||
|
programs_for_instrument.extend(mt.get_programs(instrument[1]))
|
||||||
|
x_train.extend(x)
|
||||||
|
y_train.extend(y)
|
||||||
|
except:
|
||||||
|
continue
|
||||||
|
|
||||||
|
most_recent_program = most_recent(programs_for_instrument)
|
||||||
|
|
||||||
|
if remove_duplicates:
|
||||||
|
x_train, y_train = remove_duplicated_sequences((x_train, y_train))
|
||||||
|
|
||||||
|
return x_train , y_train, most_recent_program
|
||||||
|
|
||||||
|
def most_recent(list):
|
||||||
|
occurence_count = Counter(list)
|
||||||
|
return occurence_count.most_common(1)[0][0]
|
||||||
|
|
||||||
|
def analyze_data(midi_folder_path):
|
||||||
|
'''Show usage of instumets in midipack
|
||||||
|
|
||||||
|
parameters:
|
||||||
|
-----------
|
||||||
|
midi_folder_path : string
|
||||||
|
a path to directory where midi files are stored
|
||||||
|
'''
|
||||||
|
|
||||||
|
instrument_count = dict()
|
||||||
|
instrument_programs = dict()
|
||||||
|
|
||||||
for directory, subdirectories, files in os.walk(midi_folder_path):
|
for directory, subdirectories, files in os.walk(midi_folder_path):
|
||||||
for midi_file in tqdm(files):
|
for midi_file in tqdm(files):
|
||||||
midi_file_path = os.path.join(directory, midi_file)
|
midi_file_path = os.path.join(directory, midi_file)
|
||||||
try:
|
try:
|
||||||
mt = MultiTrack(midi_file_path)
|
mt = MultiTrack(midi_file_path)
|
||||||
if how=='melody':
|
for track in mt.tracks:
|
||||||
x ,y = mt.get_data_seq2seq_melody(instrument)
|
try:
|
||||||
if how=='arrangment':
|
instrument_count[track.name] += len(get_bars_with_content(track.stream_to_bars()))
|
||||||
x ,y = mt.get_data_seq2seq_arrangment(instrument[0], instrument[1])
|
except KeyError:
|
||||||
x_train.extend(x)
|
instrument_count[track.name] = 1
|
||||||
y_train.extend(y)
|
except Exception as e:
|
||||||
except:
|
print(e)
|
||||||
continue
|
|
||||||
|
|
||||||
if remove_duplicates:
|
for key, value in sorted(instrument_count.items(), key=lambda x: x[1], reverse=True):
|
||||||
x_train, y_train = remove_duplicated_sequences((x_train, y_train))
|
print(value, 'of', key)
|
||||||
|
|
||||||
return x_train , y_train
|
|
||||||
|
|
||||||
|
|
||||||
# In[109]:
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
'''extract data from midis
|
|
||||||
'''
|
|
||||||
x_train, y_train = extract_data(midi_folder_path='WhiteStripes', how='arrangment', instrument=('Guitar','Bass'))
|
|
||||||
pickle.dump((x_train, y_train), open('Guitar_to_Bass_data.pkl','wb'))
|
|
||||||
return x_train, y_train
|
|
||||||
|
|
||||||
|
|
||||||
# In[107]:
|
|
||||||
|
|
||||||
|
|
||||||
if __name__=='__main__':
|
|
||||||
main()
|
|
||||||
|
|
370
project/model.py
Normal file
370
project/model.py
Normal file
@ -0,0 +1,370 @@
|
|||||||
|
from __future__ import print_function
|
||||||
|
from midi_processing import stream_to_bars
|
||||||
|
from keras.models import Model, load_model
|
||||||
|
from keras.layers import Input, LSTM, Dense, LSTM, LSTMCell, TimeDistributed
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
class Seq2SeqTransformer():
|
||||||
|
''' encoder/transforer
|
||||||
|
params:
|
||||||
|
-------
|
||||||
|
x_train, y_train - list of sequences
|
||||||
|
|
||||||
|
methods:
|
||||||
|
fit
|
||||||
|
transform'''
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.transform_dict = None
|
||||||
|
self.reverse_dict = None
|
||||||
|
self.vocab_x = None
|
||||||
|
self.vocab_y = None
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess(self, x_train, y_train):
|
||||||
|
'''Converts training set do list and add special chars'''
|
||||||
|
|
||||||
|
_x_train = []
|
||||||
|
for i, seq in enumerate(x_train):
|
||||||
|
_x_train.append([])
|
||||||
|
for note in seq:
|
||||||
|
_x_train[i].append(note)
|
||||||
|
|
||||||
|
_y_train = []
|
||||||
|
for i, seq in enumerate(y_train):
|
||||||
|
_y_train.append([])
|
||||||
|
_y_train[i].append('<GO>')
|
||||||
|
for note in seq:
|
||||||
|
_y_train[i].append(note)
|
||||||
|
_y_train[i].append('<EOS>')
|
||||||
|
|
||||||
|
return _x_train, _y_train
|
||||||
|
|
||||||
|
def transform(self, x_train, y_train):
|
||||||
|
|
||||||
|
x_vocab = set([note for seq in x_train for note in seq])
|
||||||
|
y_vocab = set([note for seq in y_train for note in seq])
|
||||||
|
|
||||||
|
self.x_vocab = sorted(list(x_vocab))
|
||||||
|
self.y_vocab = ['<GO>','<EOS>']
|
||||||
|
self.y_vocab.extend(sorted(list(y_vocab)))
|
||||||
|
|
||||||
|
self.x_vocab_size = len(self.x_vocab)
|
||||||
|
self.y_vocab_size = len(self.y_vocab)
|
||||||
|
|
||||||
|
self.x_transform_dict = dict(
|
||||||
|
[(char, i) for i, char in enumerate(self.x_vocab)])
|
||||||
|
self.y_transform_dict = dict(
|
||||||
|
[(char, i) for i, char in enumerate(self.y_vocab)])
|
||||||
|
self.x_reverse_dict = dict(
|
||||||
|
(i, char) for char, i in self.x_transform_dict.items())
|
||||||
|
self.y_reverse_dict = dict(
|
||||||
|
(i, char) for char, i in self.y_transform_dict.items())
|
||||||
|
|
||||||
|
x_train, y_train = self.preprocess(x_train, y_train)
|
||||||
|
|
||||||
|
self.x_max_seq_length = max([len(seq) for seq in x_train])
|
||||||
|
self.y_max_seq_length = max([len(seq) for seq in y_train])
|
||||||
|
|
||||||
|
encoder_input_data = np.zeros(
|
||||||
|
(len(x_train), self.x_max_seq_length, self.x_vocab_size),
|
||||||
|
dtype='float32')
|
||||||
|
decoder_input_data = np.zeros(
|
||||||
|
(len(x_train), self.y_max_seq_length, self.y_vocab_size),
|
||||||
|
dtype='float32')
|
||||||
|
decoder_target_data = np.zeros(
|
||||||
|
(len(x_train), self.y_max_seq_length, self.y_vocab_size),
|
||||||
|
dtype='float32')
|
||||||
|
|
||||||
|
for i, (x_train, y_train) in enumerate(zip(x_train, y_train)):
|
||||||
|
for t, char in enumerate(x_train):
|
||||||
|
encoder_input_data[i, t, self.x_transform_dict[char]] = 1.
|
||||||
|
for t, char in enumerate(y_train):
|
||||||
|
decoder_input_data[i, t, self.y_transform_dict[char]] = 1.
|
||||||
|
if t > 0:
|
||||||
|
decoder_target_data[i, t - 1, self.y_transform_dict[char]] = 1.
|
||||||
|
|
||||||
|
return encoder_input_data, decoder_input_data, decoder_target_data
|
||||||
|
|
||||||
|
|
||||||
|
class Seq2SeqModel():
|
||||||
|
'''NeuralNerwork Seq2Seq model.
|
||||||
|
The network is created based on training data
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self, latent_dim, x_train, y_train):
|
||||||
|
self.has_predict_model = False
|
||||||
|
self.has_train_model = False
|
||||||
|
self.x_train = x_train
|
||||||
|
self.y_train = y_train
|
||||||
|
self.latent_dim = latent_dim
|
||||||
|
self.transformer = Seq2SeqTransformer()
|
||||||
|
self.encoder_input_data, self.decoder_input_data, self.decoder_target_data = self.transformer.transform(self.x_train, self.y_train)
|
||||||
|
|
||||||
|
# ---------------
|
||||||
|
# SEQ 2 SEQ MODEL:
|
||||||
|
# INPUT_1 : encoder_input_data
|
||||||
|
# INPUT_2 : decodet_input_data
|
||||||
|
# OUTPUT : decoder_target_data
|
||||||
|
# ---------------
|
||||||
|
|
||||||
|
# ENCODER MODEL
|
||||||
|
#---------------
|
||||||
|
|
||||||
|
# 1 layer - Input : encoder_input_data
|
||||||
|
self.encoder_inputs = Input(shape=(None, self.transformer.x_vocab_size ))
|
||||||
|
|
||||||
|
# 2 layer - LSTM_1, LSTM
|
||||||
|
self.encoder = LSTM(latent_dim, return_state=True)
|
||||||
|
#self.encoder = LSTM(latent_dim, return_state=True)
|
||||||
|
|
||||||
|
# 2 layer - LSTM_1 : outputs
|
||||||
|
self.encoder_outputs, self.state_h, self.state_c = self.encoder(self.encoder_inputs)
|
||||||
|
self.encoder_states = [self.state_h, self.state_c]
|
||||||
|
|
||||||
|
|
||||||
|
# DECODER MODEL
|
||||||
|
#---------------
|
||||||
|
|
||||||
|
# 1 layer - Input : decoder_input_data
|
||||||
|
self.decoder_inputs = Input(shape=(None, self.transformer.y_vocab_size))
|
||||||
|
|
||||||
|
# 2 layer - LSTM_1, LSTM
|
||||||
|
self.decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
|
||||||
|
#self.decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
|
||||||
|
|
||||||
|
# 2 layer - LSTM_2 : outputs, full sequance as lstm layer
|
||||||
|
self.decoder_outputs, _, _ = self.decoder_lstm(self.decoder_inputs,
|
||||||
|
initial_state=self.encoder_states)
|
||||||
|
|
||||||
|
# 3 layer - Dense
|
||||||
|
self.decoder_dense = Dense(self.transformer.y_vocab_size, activation='softmax')
|
||||||
|
|
||||||
|
# 3 layer - Dense : outputs, full sequance as the array of one-hot-encoded elements
|
||||||
|
self.decoder_outputs = self.decoder_dense(self.decoder_outputs)
|
||||||
|
|
||||||
|
def init_train_model(self):
|
||||||
|
self.train_model = Model([self.encoder_inputs, self.decoder_inputs], self.decoder_outputs)
|
||||||
|
self.train_model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
|
||||||
|
|
||||||
|
def fit(self, batch_size, epochs, callbacks):
|
||||||
|
|
||||||
|
if not self.has_train_model:
|
||||||
|
self.init_train_model()
|
||||||
|
self.has_train_model = True
|
||||||
|
|
||||||
|
|
||||||
|
history = self.train_model.fit([self.encoder_input_data, self.decoder_input_data], self.decoder_target_data,
|
||||||
|
batch_size=batch_size,
|
||||||
|
epochs=epochs,
|
||||||
|
callbacks=callbacks,
|
||||||
|
validation_split=0.2)
|
||||||
|
return history
|
||||||
|
|
||||||
|
def save(self, path):
|
||||||
|
self.train_model.save(path)
|
||||||
|
|
||||||
|
def load(self, path):
|
||||||
|
self.train_model = load_model(path)
|
||||||
|
self.has_train_model = True
|
||||||
|
|
||||||
|
self.encoder_inputs = self.train_model.layers[0].input
|
||||||
|
self.encoder = self.train_model.layers[2]
|
||||||
|
self.encoder_outputs, self.state_h, self.state_c = self.train_model.layers[2].output
|
||||||
|
self.encoder_states = [self.state_h, self.state_c]
|
||||||
|
self.decoder_inputs = self.train_model.layers[1].input
|
||||||
|
self.decoder_lstm = self.train_model.layers[3]
|
||||||
|
self.decoder_outputs, _, _ = self.train_model.layers[3].output
|
||||||
|
self.decoder_dense = self.train_model.layers[4]
|
||||||
|
self.decoder_outputs = self.train_model.layers[4].output
|
||||||
|
|
||||||
|
def init_predict_model(self):
|
||||||
|
|
||||||
|
# ENCODER MODEL <- note used in develop music process
|
||||||
|
# from encoder_input to encoder_states
|
||||||
|
# to give a context to decoder model
|
||||||
|
#---------------------------------
|
||||||
|
|
||||||
|
self.encoder_model = Model(self.encoder_inputs, self.encoder_states)
|
||||||
|
|
||||||
|
|
||||||
|
# DECODER MODEL
|
||||||
|
# From states (context) to sequance by generating firts element from context vector
|
||||||
|
# and starting element <GO>. Then adding predicted element as input to next cell, with
|
||||||
|
# updated states (context) by prevously generated element.
|
||||||
|
#
|
||||||
|
# INPUT_1 : state_h
|
||||||
|
# INPUT_2 : state_c
|
||||||
|
# INPUT_3 : y_train sized layer, that will be recursivly generated starting from <GO> sign
|
||||||
|
#
|
||||||
|
# INPUT -> LSTM -> DENSE
|
||||||
|
#
|
||||||
|
# OUTPUT : one-hot-encoded element of sequance
|
||||||
|
# OUTPUT : state_h (updated)
|
||||||
|
# OUTPUT : state_c (updated)
|
||||||
|
# -------------
|
||||||
|
|
||||||
|
# 1 layer: TWO INPUTS: decoder_state_h, decoder_state_c
|
||||||
|
self.decoder_state_input_h = Input(shape=(self.latent_dim,))
|
||||||
|
self.decoder_state_input_c = Input(shape=(self.latent_dim,))
|
||||||
|
self.decoder_states_inputs = [self.decoder_state_input_h, self.decoder_state_input_c]
|
||||||
|
|
||||||
|
|
||||||
|
# 2 layer: LSTM_1 output: element of sequance, lstm cell states
|
||||||
|
self.decoder_outputs, self.state_h, self.state_c = self.decoder_lstm(
|
||||||
|
self.decoder_inputs,
|
||||||
|
initial_state = self.decoder_states_inputs
|
||||||
|
)
|
||||||
|
|
||||||
|
self.decoder_states = [self.state_h, self.state_c]
|
||||||
|
|
||||||
|
# 3 layer: Dense output: one-hot-encoded representation of element of sequance
|
||||||
|
self.decoder_outputs = self.decoder_dense(self.decoder_outputs)
|
||||||
|
|
||||||
|
|
||||||
|
self.decoder_model = Model(
|
||||||
|
[self.decoder_inputs] + self.decoder_states_inputs,
|
||||||
|
[self.decoder_outputs] + self.decoder_states)
|
||||||
|
self.has_predict_model = True
|
||||||
|
|
||||||
|
def predict(self, input_seq=None, mode=None):
|
||||||
|
|
||||||
|
if not self.has_predict_model:
|
||||||
|
self.init_predict_model()
|
||||||
|
self.has_predict_model = True
|
||||||
|
|
||||||
|
if mode == 'generate':
|
||||||
|
# create a random context as starting point
|
||||||
|
h = np.random.rand(1,self.latent_dim)*2 - 1
|
||||||
|
c = np.random.rand(1,self.latent_dim)*2 - 1
|
||||||
|
states_value = [h, c]
|
||||||
|
else:
|
||||||
|
# get context from input sequance
|
||||||
|
states_value = self.encoder_model.predict(input_seq)
|
||||||
|
|
||||||
|
# make the empty decoder_input_data
|
||||||
|
# and create the starting <GO> element of decoder_input_data
|
||||||
|
target_seq = np.zeros((1, 1, self.transformer.y_vocab_size))
|
||||||
|
target_seq[0, 0, self.transformer.y_transform_dict['<GO>']] = 1.
|
||||||
|
|
||||||
|
# sequance generation loop of decoder model
|
||||||
|
stop_condition = False
|
||||||
|
decoded_sentence = []
|
||||||
|
# time = 0
|
||||||
|
while not stop_condition:
|
||||||
|
|
||||||
|
# INPUT_1 : target_seq : started from empty array with start <GO> char
|
||||||
|
# and recursivly updated by predicted elements
|
||||||
|
|
||||||
|
# INPUT_2 : states_value :context from encoder model or randomly generated in develop mode
|
||||||
|
# this can give as a 2 * latent_dim parameters to play with in manual generation
|
||||||
|
|
||||||
|
# OUTPUT_1 : output_tokens : one hot encoded predicted element of sequance
|
||||||
|
# OUTPUT_2,3 : h, c : context updated by predicted element
|
||||||
|
output_tokens, h, c = self.decoder_model.predict(
|
||||||
|
[target_seq] + states_value)
|
||||||
|
|
||||||
|
# get most likly element index
|
||||||
|
# translate from index to final (in normal form) preidcted element
|
||||||
|
# append it to output list
|
||||||
|
sampled_token_index = np.argmax(output_tokens[0, -1, :])
|
||||||
|
sampled_char = self.transformer.y_reverse_dict[sampled_token_index]
|
||||||
|
decoded_sentence.append(sampled_char)
|
||||||
|
|
||||||
|
# time += sampled_char[1]
|
||||||
|
# or time>=16
|
||||||
|
if (sampled_char == '<EOS>' or len(decoded_sentence) > self.transformer.y_max_seq_length ):
|
||||||
|
stop_condition = True
|
||||||
|
|
||||||
|
target_seq = np.zeros((1, 1, self.transformer.y_vocab_size))
|
||||||
|
target_seq[0, 0, sampled_token_index] = 1.
|
||||||
|
|
||||||
|
states_value = [h, c]
|
||||||
|
|
||||||
|
return decoded_sentence
|
||||||
|
|
||||||
|
def develop(self, mode='from_seq'):
|
||||||
|
|
||||||
|
# music generation for seq2seq for melody
|
||||||
|
input_seq_start = random_seed_generator(16,
|
||||||
|
self.transformer.x_max_seq_length,
|
||||||
|
self.transformer.x_vocab_size,
|
||||||
|
self.transformer.x_transform_dict,
|
||||||
|
self.transformer.x_reverse_dict)
|
||||||
|
|
||||||
|
input_data = seq_to_numpy(input_seq_start,
|
||||||
|
self.transformer.x_max_seq_length,
|
||||||
|
self.transformer.x_vocab_size,
|
||||||
|
self.transformer.x_transform_dict)
|
||||||
|
|
||||||
|
# generate sequnce iterativly for melody
|
||||||
|
input_seq = input_seq_start.copy()
|
||||||
|
melody = []
|
||||||
|
for i in range(4):
|
||||||
|
if mode == 'from_seq':
|
||||||
|
decoded_sentence = self.predict(input_data)[:-1]
|
||||||
|
elif mode == 'from_state':
|
||||||
|
decoded_sentence = self.predict(mode='generate')[:-1]
|
||||||
|
else:
|
||||||
|
raise ValueError('mode must be in {from_seq, from_state}')
|
||||||
|
melody.append(decoded_sentence)
|
||||||
|
input_seq.extend(decoded_sentence)
|
||||||
|
input_bars = stream_to_bars(input_seq, 4)
|
||||||
|
input_bars = input_bars[1:5]
|
||||||
|
input_seq = [note for bar in input_bars for note in bar]
|
||||||
|
input_data = seq_to_numpy(input_seq,
|
||||||
|
self.transformer.x_max_seq_length,
|
||||||
|
self.transformer.x_vocab_size,
|
||||||
|
self.transformer.x_transform_dict)
|
||||||
|
|
||||||
|
melody = [note for bar in melody for note in bar]
|
||||||
|
return melody
|
||||||
|
|
||||||
|
def random_seed_generator(time_of_seq, max_encoder_seq_length, num_encoder_tokens, input_token_index, reverse_input_char_index):
|
||||||
|
time = 0
|
||||||
|
random_seq = []
|
||||||
|
items = 0
|
||||||
|
stop_sign = False
|
||||||
|
while (time < time_of_seq):
|
||||||
|
seed = np.random.randint(0,num_encoder_tokens-1)
|
||||||
|
note = reverse_input_char_index[seed]
|
||||||
|
time += note[1]
|
||||||
|
if time > time_of_seq:
|
||||||
|
note_time = note[1] - (time-time_of_seq)
|
||||||
|
trimmed_note = (note[0],note_time)
|
||||||
|
try:
|
||||||
|
seed = input_token_index[trimmed_note]
|
||||||
|
random_seq.append(trimmed_note)
|
||||||
|
items += 1
|
||||||
|
except KeyError:
|
||||||
|
time -= note[1]
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
random_seq.append(note)
|
||||||
|
items += 1
|
||||||
|
|
||||||
|
if items > max_encoder_seq_length:
|
||||||
|
time = 0
|
||||||
|
random_seq = []
|
||||||
|
items = 0
|
||||||
|
stop_sign = False
|
||||||
|
|
||||||
|
|
||||||
|
return random_seq
|
||||||
|
|
||||||
|
# seq to numpy array:
|
||||||
|
def seq_to_numpy(seq, max_encoder_seq_length, num_encoder_tokens, input_token_index):
|
||||||
|
input_data = np.zeros(
|
||||||
|
(1, max_encoder_seq_length, num_encoder_tokens),
|
||||||
|
dtype='float32')
|
||||||
|
|
||||||
|
for t, char in enumerate(seq):
|
||||||
|
try:
|
||||||
|
input_data[0, t, input_token_index[char]] = 1.
|
||||||
|
except KeyError:
|
||||||
|
char_time = char[1]
|
||||||
|
_char = ((-1,), char_time)
|
||||||
|
except IndexError:
|
||||||
|
break
|
||||||
|
|
||||||
|
return input_data
|
@ -1,26 +1,12 @@
|
|||||||
## MUSIC GENERATION USING DEEP LEARNING
|
## MUSIC GENERATION USING DEEP LEARNING
|
||||||
## AUTHOR: CEZARY PUKOWNIK
|
## AUTHOR: CEZARY PUKOWNIK
|
||||||
|
|
||||||
### Files:
|
|
||||||
- midi.py - code for data extraction, and midi convertion
|
|
||||||
- train.py - code for model definition, and training session
|
|
||||||
- generate.py - code for model loading, predicting ang saving to midi_dir
|
|
||||||
- settings.py - file where deafult settings are stored
|
|
||||||
- readme.md - this file
|
|
||||||
|
|
||||||
### Directories:
|
|
||||||
- data/midi - directory where input midi are stored
|
|
||||||
- data/models - directory where trained models are stored
|
|
||||||
- data/output - directory where generated music is stored
|
|
||||||
- data/samples - directory where extracted data from midi is stored
|
|
||||||
- data/samples.npz - deprecated
|
|
||||||
|
|
||||||
## How to use:
|
## How to use:
|
||||||
1. Use midi.py to export data from midi files
|
1. In folder ./midi_packs make folder with midi files you want train on
|
||||||
> ./midi.py [midi_folder_path] [output_path]
|
2. Use extract.py to export data from midis
|
||||||
|
> ./extract.py [str: midi_pack_name] [str: name_of_session] --b [int: seq_len] -a [analize data first]
|
||||||
|
3. Use train.py to train model
|
||||||
|
> ./train.py [str: name_of_session] --b [int: batch_size] --l [int: latent_space] --e [int: epochs] --i [str: instrument] -r [reset]
|
||||||
|
4. Use generate.py to generate music from models
|
||||||
|
> ./generate.py [str: name_of_session] --n [number of generations] --m [mode {'from_seq','from_state'}]
|
||||||
|
|
||||||
2. Use train.py to train a model (this can take a while)
|
|
||||||
> ./train.py [input_training_data] [model_save_path] [epochs]
|
|
||||||
|
|
||||||
3. Use generate.py to generate music from trained models
|
|
||||||
> ./generate.py [trained_model_path] [output_path] [treshold]
|
|
||||||
|
@ -1,313 +0,0 @@
|
|||||||
# paths
|
|
||||||
midi_dir = 'data/midi'
|
|
||||||
samples_dir = 'data/samples'
|
|
||||||
samples_path = 'data/samples.npz'
|
|
||||||
sample_preview_path = 'data/samples_preview.png'
|
|
||||||
model_path = 'data/autoencoder_model.h5'
|
|
||||||
generated_sample_path = 'data/output/generated_bar.npz'
|
|
||||||
generated_midi_path = 'data/output/generated_midi.mid'
|
|
||||||
generated_pianoroll_path = 'data/output/pianoroll.png'
|
|
||||||
|
|
||||||
# export_settings
|
|
||||||
midi_resolution = 96
|
|
||||||
beat_resolution = 24
|
|
||||||
beats_per_sample = 1
|
|
||||||
ignore_note_lenght = False
|
|
||||||
|
|
||||||
#train_settings
|
|
||||||
epochs = 1
|
|
||||||
|
|
||||||
#extras
|
|
||||||
midi_program = {
|
|
||||||
# Piano
|
|
||||||
1 : 'Acoustic Grand Piano',
|
|
||||||
2 : 'Bright Acoustic Piano',
|
|
||||||
3 : 'Electric Grand Piano',
|
|
||||||
4 : 'Honky-tonk Piano',
|
|
||||||
5 : 'Electric Piano 1',
|
|
||||||
6 : 'Electric Piano 2',
|
|
||||||
7 : 'Harpsichord',
|
|
||||||
8 : 'Clavi',
|
|
||||||
# Chromatic Percussion
|
|
||||||
9 : 'Celesta',
|
|
||||||
10 : 'Glockenspiel',
|
|
||||||
11 : 'Music Box',
|
|
||||||
12 : 'Vibraphone',
|
|
||||||
13 : 'Marimba',
|
|
||||||
14 : 'Xylophone',
|
|
||||||
15 : 'Tubular Bells',
|
|
||||||
16 : 'Dulcimer',
|
|
||||||
# Organ
|
|
||||||
17 : 'Drawbar Organ',
|
|
||||||
18 : 'Percussive Organ',
|
|
||||||
19 : 'Rock Organ',
|
|
||||||
20 : 'Church Organ',
|
|
||||||
21 : 'Reed Organ',
|
|
||||||
22 : 'Accordion',
|
|
||||||
23 : 'Harmonica',
|
|
||||||
24 : 'Tango Accordion',
|
|
||||||
# Guitar
|
|
||||||
25 : 'Acoustic Guitar (nylon)',
|
|
||||||
26 : 'Acoustic Guitar (steel)',
|
|
||||||
27 : 'Electric Guitar (jazz)',
|
|
||||||
28 : 'Electric Guitar (clean)',
|
|
||||||
29 : 'Electric Guitar (muted)',
|
|
||||||
30 : 'Overdriven Guitar',
|
|
||||||
31 : 'Distortion Guitar',
|
|
||||||
32 : 'Guitar harmonics',
|
|
||||||
# Bass
|
|
||||||
33 : 'Acoustic Bass',
|
|
||||||
34 : 'Electric Bass (finger)',
|
|
||||||
35 : 'Electric Bass (pick)',
|
|
||||||
36 : 'Fretless Bass',
|
|
||||||
37 : 'Slap Bass 1',
|
|
||||||
38 : 'Slap Bass 2',
|
|
||||||
39 : 'Synth Bass 1',
|
|
||||||
40 : 'Synth Bass 2',
|
|
||||||
# Strings
|
|
||||||
41 : 'Violin',
|
|
||||||
42 : 'Viola',
|
|
||||||
43 : 'Cello',
|
|
||||||
44 : 'Contrabass',
|
|
||||||
45 : 'Tremolo Strings',
|
|
||||||
46 : 'Pizzicato Strings',
|
|
||||||
47 : 'Orchestral Harp',
|
|
||||||
48 : 'Timpani',
|
|
||||||
# Ensemble
|
|
||||||
49 : 'String Ensemble 1',
|
|
||||||
50 : 'String Ensemble 2',
|
|
||||||
51 : 'SynthStrings 1',
|
|
||||||
52 : 'SynthStrings 2',
|
|
||||||
53 : 'Choir Aahs',
|
|
||||||
54 : 'Voice Oohs',
|
|
||||||
55 : 'Synth Voice',
|
|
||||||
56 : 'Orchestra Hit',
|
|
||||||
# Brass
|
|
||||||
57 : 'Trumpet',
|
|
||||||
58 : 'Trombone',
|
|
||||||
59 : 'Tuba',
|
|
||||||
60 : 'Muted Trumpet',
|
|
||||||
61 : 'French Horn',
|
|
||||||
62 : 'Brass Section',
|
|
||||||
63 : 'SynthBrass 1',
|
|
||||||
64 : 'SynthBrass 2',
|
|
||||||
# Reed
|
|
||||||
65 : 'Soprano Sax',
|
|
||||||
66 : 'Alto Sax',
|
|
||||||
67 : 'Tenor Sax',
|
|
||||||
68 : 'Baritone Sax',
|
|
||||||
69 : 'Oboe',
|
|
||||||
70 : 'English Horn',
|
|
||||||
71 : 'Bassoon',
|
|
||||||
72 : 'Clarinet',
|
|
||||||
# Pipe
|
|
||||||
73 : 'Piccolo',
|
|
||||||
74 : 'Flute',
|
|
||||||
75 : 'Recorder',
|
|
||||||
76 : 'Pan Flute',
|
|
||||||
77 : 'Blown Bottle',
|
|
||||||
78 : 'Shakuhachi',
|
|
||||||
79 : 'Whistle',
|
|
||||||
80 : 'Ocarina',
|
|
||||||
# Synth Lead
|
|
||||||
81 : 'Lead 1 (square)',
|
|
||||||
82 : 'Lead 2 (sawtooth)',
|
|
||||||
83 : 'Lead 3 (calliope)',
|
|
||||||
84 : 'Lead 4 (chiff)',
|
|
||||||
85 : 'Lead 5 (charang)',
|
|
||||||
86 : 'Lead 6 (voice)',
|
|
||||||
87 : 'Lead 7 (fifths)',
|
|
||||||
88 : 'Lead 8 (bass + lead)',
|
|
||||||
# Synth Pad
|
|
||||||
89 : 'Pad 1 (new age)',
|
|
||||||
90 : 'Pad 2 (warm)',
|
|
||||||
91 : 'Pad 3 (polysynth)',
|
|
||||||
92 : 'Pad 4 (choir)',
|
|
||||||
93 : 'Pad 5 (bowed)',
|
|
||||||
94 : 'Pad 6 (metallic)',
|
|
||||||
95 : 'Pad 7 (halo)',
|
|
||||||
96 : 'Pad 8 (sweep)',
|
|
||||||
# Synth Effects
|
|
||||||
97 : 'FX 1 (rain)',
|
|
||||||
98 : 'FX 2 (soundtrack)',
|
|
||||||
99 : 'FX 3 (crystal)',
|
|
||||||
100 : 'FX 4 (atmosphere)',
|
|
||||||
101 : 'FX 5 (brightness)',
|
|
||||||
102 : 'FX 6 (goblins)',
|
|
||||||
103 : 'FX 7 (echoes)',
|
|
||||||
104 : 'FX 8 (sci-fi)',
|
|
||||||
# Ethnic
|
|
||||||
105 : 'Sitar',
|
|
||||||
106 : 'Banjo',
|
|
||||||
107 : 'Shamisen',
|
|
||||||
108 : 'Koto',
|
|
||||||
109 : 'Kalimba',
|
|
||||||
110 : 'Bag pipe',
|
|
||||||
111 : 'Fiddle',
|
|
||||||
112 : 'Shanai',
|
|
||||||
# Percussive
|
|
||||||
113 : 'Tinkle Bell',
|
|
||||||
114 : 'Agogo',
|
|
||||||
115 : 'Steel Drums',
|
|
||||||
116 : 'Woodblock',
|
|
||||||
117 : 'Taiko Drum',
|
|
||||||
118 : 'Melodic Tom',
|
|
||||||
119 : 'Synth Drum',
|
|
||||||
120 : 'Reverse Cymbal',
|
|
||||||
# Sound Effects
|
|
||||||
121 : 'Guitar Fret Noise',
|
|
||||||
122 : 'Breath Noise',
|
|
||||||
123 : 'Seashore',
|
|
||||||
124 : 'Bird Tweet',
|
|
||||||
125 : 'Telephone Ring',
|
|
||||||
126 : 'Helicopter',
|
|
||||||
127 : 'Applause',
|
|
||||||
128 : 'Gunshot'
|
|
||||||
}
|
|
||||||
|
|
||||||
midi_group = {
|
|
||||||
# Piano
|
|
||||||
1 : 'Piano',
|
|
||||||
2 : 'Piano',
|
|
||||||
3 : 'Piano',
|
|
||||||
4 : 'Piano',
|
|
||||||
5 : 'Piano',
|
|
||||||
6 : 'Piano',
|
|
||||||
7 : 'Piano',
|
|
||||||
8 : 'Piano',
|
|
||||||
# Chromatic Percussion
|
|
||||||
9 : 'Chromatic_Percussion',
|
|
||||||
10 : 'Chromatic_Percussion',
|
|
||||||
11 : 'Chromatic_Percussion',
|
|
||||||
12 : 'Chromatic_Percussion',
|
|
||||||
13 : 'Chromatic_Percussion',
|
|
||||||
14 : 'Chromatic_Percussion',
|
|
||||||
15 : 'Chromatic_Percussion',
|
|
||||||
16 : 'Chromatic_Percussion',
|
|
||||||
# Organ
|
|
||||||
17 : 'Organ',
|
|
||||||
18 : 'Organ',
|
|
||||||
19 : 'Organ',
|
|
||||||
20 : 'Organ',
|
|
||||||
21 : 'Organ',
|
|
||||||
22 : 'Organ',
|
|
||||||
23 : 'Organ',
|
|
||||||
24 : 'Organ',
|
|
||||||
# Guitar
|
|
||||||
25 : 'Guitar',
|
|
||||||
26 : 'Guitar',
|
|
||||||
27 : 'Guitar',
|
|
||||||
28 : 'Guitar',
|
|
||||||
29 : 'Guitar',
|
|
||||||
30 : 'Guitar',
|
|
||||||
31 : 'Guitar',
|
|
||||||
32 : 'Guitar',
|
|
||||||
# Bass
|
|
||||||
33 : 'Bass',
|
|
||||||
34 : 'Bass',
|
|
||||||
35 : 'Bass',
|
|
||||||
36 : 'Bass',
|
|
||||||
37 : 'Bass',
|
|
||||||
38 : 'Bass',
|
|
||||||
39 : 'Bass',
|
|
||||||
40 : 'Bass',
|
|
||||||
# Strings
|
|
||||||
41 : 'Strings',
|
|
||||||
42 : 'Strings',
|
|
||||||
43 : 'Strings',
|
|
||||||
44 : 'Strings',
|
|
||||||
45 : 'Strings',
|
|
||||||
46 : 'Strings',
|
|
||||||
47 : 'Strings',
|
|
||||||
48 : 'Strings',
|
|
||||||
# Ensemble
|
|
||||||
49 : 'Ensemble',
|
|
||||||
50 : 'Ensemble',
|
|
||||||
51 : 'Ensemble',
|
|
||||||
52 : 'Ensemble',
|
|
||||||
53 : 'Ensemble',
|
|
||||||
54 : 'Ensemble',
|
|
||||||
55 : 'Ensemblee',
|
|
||||||
56 : 'Ensemble',
|
|
||||||
# Brass
|
|
||||||
57 : 'Brass',
|
|
||||||
58 : 'Brass',
|
|
||||||
59 : 'Brass',
|
|
||||||
60 : 'Brass',
|
|
||||||
61 : 'Brass',
|
|
||||||
62 : 'Brass',
|
|
||||||
63 : 'Brass',
|
|
||||||
64 : 'Brass',
|
|
||||||
# Reed
|
|
||||||
65 : 'Reed',
|
|
||||||
66 : 'Reed',
|
|
||||||
67 : 'Reed',
|
|
||||||
68 : 'Reed',
|
|
||||||
69 : 'Reed',
|
|
||||||
70 : 'Reed',
|
|
||||||
71 : 'Reed',
|
|
||||||
72 : 'Reed',
|
|
||||||
# Pipe
|
|
||||||
73 : 'Pipe',
|
|
||||||
74 : 'Pipe',
|
|
||||||
75 : 'Pipe',
|
|
||||||
76 : 'Pipe',
|
|
||||||
77 : 'Pipe',
|
|
||||||
78 : 'Pipe',
|
|
||||||
79 : 'Pipe',
|
|
||||||
80 : 'Pipe',
|
|
||||||
# Synth Lead
|
|
||||||
81 : 'Synth_Lead',
|
|
||||||
82 : 'Synth_Lead',
|
|
||||||
83 : 'Synth_Lead',
|
|
||||||
84 : 'Synth_Lead',
|
|
||||||
85 : 'Synth_Lead',
|
|
||||||
86 : 'Synth_Lead',
|
|
||||||
87 : 'Synth_Lead',
|
|
||||||
88 : 'Synth_Lead',
|
|
||||||
# Synth Pad
|
|
||||||
89 : 'Synth_Pad',
|
|
||||||
90 : 'Synth_Pad',
|
|
||||||
91 : 'Synth_Pad',
|
|
||||||
92 : 'Synth_Pad',
|
|
||||||
93 : 'Synth_Pad',
|
|
||||||
94 : 'Synth_Pad',
|
|
||||||
95 : 'Synth_Pad',
|
|
||||||
96 : 'Synth_Pad',
|
|
||||||
# Synth Effects
|
|
||||||
97 : 'Synth_Effects',
|
|
||||||
98 : 'Synth_Effects',
|
|
||||||
99 : 'Synth_Effects',
|
|
||||||
100 : 'Synth_Effects',
|
|
||||||
101 : 'Synth_Effects',
|
|
||||||
102 : 'Synth_Effects',
|
|
||||||
103 : 'Synth_Effects',
|
|
||||||
104 : 'Synth_Effects',
|
|
||||||
# Ethnic
|
|
||||||
105 : 'Ethnic',
|
|
||||||
106 : 'Ethnic',
|
|
||||||
107 : 'Ethnic',
|
|
||||||
108 : 'Ethnic',
|
|
||||||
109 : 'Ethnic',
|
|
||||||
110 : 'Ethnic',
|
|
||||||
111 : 'Ethnic',
|
|
||||||
112 : 'Ethnic',
|
|
||||||
# Percussive
|
|
||||||
113 : 'Percussive',
|
|
||||||
114 : 'Percussive',
|
|
||||||
115 : 'Percussive',
|
|
||||||
116 : 'Percussive',
|
|
||||||
117 : 'Percussive',
|
|
||||||
118 : 'Percussive',
|
|
||||||
119 : 'Percussive',
|
|
||||||
120 : 'Percussive',
|
|
||||||
# Sound Effects
|
|
||||||
121 : 'Sound_Effects',
|
|
||||||
122 : 'Sound_Effects',
|
|
||||||
123 : 'Sound_Effects',
|
|
||||||
124 : 'Sound_Effects',
|
|
||||||
125 : 'Sound_Effects',
|
|
||||||
126 : 'Sound_Effects',
|
|
||||||
127 : 'Sound_Effects',
|
|
||||||
128 : 'Sound_Effects'
|
|
||||||
}
|
|
Binary file not shown.
119
project/train.py
119
project/train.py
@ -1,68 +1,79 @@
|
|||||||
#!python3
|
import os
|
||||||
#!/usr/bin/env python3
|
|
||||||
import sys
|
import sys
|
||||||
import pickle
|
import pickle
|
||||||
import settings
|
import keras
|
||||||
|
import argparse
|
||||||
|
import warnings
|
||||||
|
from model import Seq2SeqModel
|
||||||
|
from extract import make_folder_if_not_exist
|
||||||
|
|
||||||
import numpy as np
|
# TODO:
|
||||||
from keras.layers import Input, Dense, Conv2D, Flatten, LSTM, Dropout, TimeDistributed, RepeatVector, Activation, Bidirectional, Reshape
|
# FIXME:
|
||||||
from keras.models import Model, Sequential
|
|
||||||
from keras.utils.np_utils import to_categorical
|
|
||||||
|
|
||||||
|
def parse_argv():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('n', help='name for experiment', type=str)
|
||||||
|
parser.add_argument('--b', help='batch_size', type=int)
|
||||||
|
parser.add_argument('--l', help='latent_dim', type=int)
|
||||||
|
parser.add_argument('--e', help='epochs', type=int)
|
||||||
|
parser.add_argument('--i', help='refrance to instrument to train, if you want to train only one instument')
|
||||||
|
parser.add_argument('-r', help='reset, use when you want to reset waights and train from scratch', action='store_true')
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
def load_data(samples_path):
|
def load_workflow():
|
||||||
print('Loading... {}'.format(train_data_path))
|
workflow_path = os.path.join('training_sets', EXPERIMENT_NAME, 'workflow.pkl')
|
||||||
train_X = np.load(train_data_path, allow_pickle=True)['arr_0']
|
if os.path.isfile(workflow_path):
|
||||||
train_y = np.load(train_data_path, allow_pickle=True)['arr_1']
|
model_workflow = pickle.load(open(workflow_path,'rb'))
|
||||||
return train_X, train_y
|
else:
|
||||||
|
raise FileNotFoundError(f'There is no workflow.pkl file in trainig_sets/{EXPERIMENT_NAME}/ folder')
|
||||||
|
return model_workflow
|
||||||
|
|
||||||
# TODO: make transformer class with fit, transform and reverse definitions
|
def train_models(model_workflow):
|
||||||
def preprocess_samples(train_X, train_y):
|
|
||||||
vocab_X = np.unique(train_X)
|
|
||||||
vocab_y = np.unique(train_y)
|
|
||||||
vocab = np.concatenate([vocab_X, vocab_y])
|
|
||||||
n_vocab = vocab.shape[0]
|
|
||||||
note_to_int = dict((note, number) for number, note in enumerate(vocab))
|
|
||||||
int_to_note = dict((number, note) for number, note in enumerate(vocab))
|
|
||||||
_train_X = []
|
|
||||||
_train_y = []
|
|
||||||
for sample in train_X:
|
|
||||||
# TODO: add normalizasion
|
|
||||||
_train_X.append([note_to_int[note] for note in sample])
|
|
||||||
|
|
||||||
train_X = np.array(_train_X).reshape(train_X.shape[0], train_X.shape[1], 1)
|
instruments = [instrument if how == 'melody' else instrument[1] for key, (instrument, how) in model_workflow.items()]
|
||||||
train_y = np.array([note_to_int[note] for note in train_y]).reshape(-1,1)
|
# make_folder_if_not_exist(os.mkdir(os.path.join('models',EXPERIMENT_NAME)))
|
||||||
train_y = to_categorical(train_y)
|
|
||||||
|
|
||||||
return train_X, train_y, n_vocab, int_to_note
|
found = False
|
||||||
|
for instrument in instruments:
|
||||||
|
|
||||||
train_data_path = sys.argv[1]
|
if INSTRUMENT == None or INSTRUMENT == instrument:
|
||||||
|
data_path = os.path.join('training_sets', EXPERIMENT_NAME, instrument.lower() + '_data.pkl')
|
||||||
|
model_path = os.path.join('models', EXPERIMENT_NAME, f'{instrument.lower()}_model.h5')
|
||||||
|
|
||||||
train_X, train_y = load_data(train_data_path)
|
x_train, y_train, _ = pickle.load(open(data_path,'rb'))
|
||||||
train_X, train_y, n_vocab, int_to_note = preprocess_samples(train_X, train_y)
|
model = Seq2SeqModel(LATENT_DIM, x_train, y_train)
|
||||||
|
if os.path.isfile(model_path) and not RESET:
|
||||||
|
model.load(model_path)
|
||||||
|
|
||||||
save_model_path = sys.argv[2]
|
print(f'Training: {instrument}')
|
||||||
epochs = int(sys.argv[3])
|
model.fit(BATCH_SIZE, EPOCHS, callbacks=[])
|
||||||
|
model.save(model_path)
|
||||||
|
found = True
|
||||||
|
|
||||||
model = Sequential()
|
if not found:
|
||||||
model.add(LSTM(512, input_shape=(train_X.shape[1], train_X.shape[2]), return_sequences=True))
|
raise ValueError(f'Instrument not found. Use one of the {instruments}')
|
||||||
model.add(Dropout(0.3))
|
|
||||||
model.add(LSTM(512, return_sequences=True))
|
|
||||||
model.add(Dropout(0.3))
|
|
||||||
model.add(LSTM(512))
|
|
||||||
model.add(Dense(256))
|
|
||||||
model.add(Dropout(0.3))
|
|
||||||
model.add(Dense(n_vocab))
|
|
||||||
model.add(Activation('softmax'))
|
|
||||||
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
|
|
||||||
|
|
||||||
# This code will train our model, with given by parameter number of epochs
|
if __name__ == '__main__':
|
||||||
print('Training...')
|
|
||||||
model.fit(train_X, train_y, epochs=epochs, batch_size=64)
|
|
||||||
|
|
||||||
# it saves model, and additional informations of model
|
warnings.filterwarnings("ignore")
|
||||||
# that is needed to generate music from it
|
args = parse_argv()
|
||||||
pickle.dump(model, open(save_model_path,'wb'))
|
|
||||||
pickle.dump((int_to_note, n_vocab, train_X.shape[1]), open('{}_dict'.format(save_model_path),'wb'))
|
EXPERIMENT_NAME = args.n
|
||||||
print('Done!')
|
BATCH_SIZE = args.b
|
||||||
print("Model saved to: {}".format(save_model_path))
|
LATENT_DIM = args.l
|
||||||
|
EPOCHS = args.e
|
||||||
|
RESET = args.r
|
||||||
|
INSTRUMENT = args.i
|
||||||
|
|
||||||
|
# default settings if not args passed
|
||||||
|
if not BATCH_SIZE:
|
||||||
|
BATCH_SIZE = 32
|
||||||
|
if not LATENT_DIM:
|
||||||
|
LATENT_DIM = 256
|
||||||
|
if not EPOCHS:
|
||||||
|
EPOCHS = 1
|
||||||
|
if not RESET:
|
||||||
|
RESET = False
|
||||||
|
|
||||||
|
train_models(load_workflow())
|
Loading…
Reference in New Issue
Block a user