lstm - drop this branch, looking for another way to generate music

parent 1997ff96ef · commit 35c19e1e80
@@ -1,26 +1,49 @@
 #!/usr/bin/env python3

 import numpy as np
+import midi
 import tensorflow as tf
 from keras.layers import Input, Dense, Conv2D
 from keras.models import Model
+from tensorflow.keras import layers
+from keras.layers import Input, Dense, Conv2D, Flatten, LSTM, Dropout, TimeDistributed, RepeatVector
+from keras.models import Model, Sequential
+import matplotlib.pyplot as plt
 import settings
+import pickle
+import sys

-#model
-input_shape = settings.midi_resolution*128
-input_img = tf.keras.layers.Input(shape=(input_shape,))
-encoded = tf.keras.layers.Dense(160, activation='relu')(input_img)
-decoded = tf.keras.layers.Dense(input_shape, activation='sigmoid')(encoded)
-autoencoder = tf.keras.models.Model(input_img, decoded)
-
-autoencoder.compile(optimizer='adadelta',
-                    loss='categorical_crossentropy',
-                    metrics=['accuracy'])
-
-# load weights into new model
-autoencoder.load_weights(settings.model_path)
-print("Loaded model from {}".format(settings.model_path))
-
-# generate_seed = np.random.rand(12288).reshape(1,12288)
-generate_seed = np.load(settings.samples_path)['arr_0'][15].reshape(1,12288)
-
-generated_sample = autoencoder.predict(generate_seed)
-np.savez_compressed(settings.generated_sample_path, generated_sample)
+trained_model_path = sys.argv[1]
+output_path = sys.argv[2]
+# threshold = float(sys.argv[3])
+
+# random seed
+generate_seed = np.random.rand(12288).reshape(1,96,128)
+
+# load and predict
+model = pickle.load(open(trained_model_path, 'rb'))
+
+generated_music = np.empty((0,128))
+
+for note in range(100):
+    generated_vector = model.predict(generate_seed).reshape(1,4,128)
+    generated_notes = np.zeros((4,128))
+    for i, col in enumerate(generated_vector[0]):
+        best_note = np.argmax(col)
+        generated_notes[i][best_note] = 1
+
+    generate_seed = np.concatenate([generated_notes, generate_seed[0][:-4]]).reshape(1,96,128)
+    generated_music = np.concatenate([generated_music, generated_notes])
+
+# generated_sample = generated_sample.reshape(96,128)
+generated_sample = generated_music
+# print(generated_music)
+# binarize generated music
+# generated_sample = generated_sample > 0 * generated_sample.max()
+
+# save to midi
+midi.to_midi(generated_sample, output_path='{}.mid'.format(output_path))
+
+# save piano roll to png
+plt.imshow(generated_sample, cmap = plt.get_cmap('gray'))
+plt.savefig('{}.png'.format(output_path))
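For reference, the sampling loop added above (in generate.py, per the README file list) reduces to the following standalone sketch. DummyModel is a hypothetical stand-in for the pickled LSTM, which maps a (1, 96, 128) window to 4 x 128 note scores; everything else mirrors the committed code.

# Sketch: autoregressive sampling as committed above.
import numpy as np

class DummyModel:
    def predict(self, window):
        # placeholder for the pickled model's predict()
        return np.random.rand(1, 4, 128)

model = DummyModel()
window = np.random.rand(1, 96, 128)   # random seed, as in the commit
music = np.empty((0, 128))

for step in range(100):
    scores = model.predict(window).reshape(4, 128)
    notes = np.zeros((4, 128))
    # one-hot: keep only the argmax pitch per predicted timestep
    notes[np.arange(4), scores.argmax(axis=1)] = 1
    # slide the window: the committed code prepends the new notes
    # and drops the last 4 rows of the old window
    window = np.concatenate([notes, window[0][:-4]]).reshape(1, 96, 128)
    music = np.concatenate([music, notes])

print(music.shape)   # (400, 128): 100 steps x 4 timesteps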
project/midi.py
@@ -8,36 +8,86 @@ import os
 from tqdm import tqdm
 from math import floor
 import sys
+import pickle
+from tqdm import tqdm
+from tqdm import trange
+from collections import defaultdict
+
+import bz2
+import pickle

 def to_samples(midi_file_path, midi_res=settings.midi_resolution):

-    # this function exports samples from a midi file:
-    # for every track in the midi file a chopped pianoroll
-    # for samples of a given beat_length (midi_res)
-    # every track is a single line
+    # TODO: add transpositions of every sample to every possible key transposition
+    # np.roll(sample, pitch_interval, axis=1) for transposition
+    # np.roll(sample, time_steps, axis=0) for time shifting
+    all_X_samples = []
+    all_y_samples = []
+    for track in roll.Multitrack(midi_file_path).tracks:
+        if not track.is_drum:
+            # TODO: this makes rollable samples and a dataset of y_train for prediction
+            # the idea is to predict the next N timesteps from the previous M timesteps
+            m_timesteps = 96
+            n_next_notes = 4

-    print('Exporting samples from: {}'.format(midi_file_path))
+            track_timesteps = track.pianoroll.shape[0] - (m_timesteps + n_next_notes)

-    all_beats = np.empty((0, settings.midi_resolution, 128))
+            X_track_samples = []
+            y_track_samples = []
+            for i in range(track_timesteps):
+                X = track.pianoroll[i : i + m_timesteps].reshape(96,128)
+                y = track.pianoroll[i + m_timesteps : i + m_timesteps + n_next_notes].reshape(4,128)
+                X_track_samples.append(X)
+                y_track_samples.append(y)
+
+            all_X_samples.extend(X_track_samples)
+            all_y_samples.extend(y_track_samples)
+        else:
+            # TODO: add code for drum samples
+            pass
+    return all_X_samples, all_y_samples
+
+def to_samples_by_instrument(midi_file_path, midi_res=settings.midi_resolution):
+
+    # add transpositions of every sample to every possible key transposition
+    # np.roll(sample, pitch_interval, axis=1) for transposition
+    # np.roll(sample, time_steps, axis=0) for time shifting
+
+    # TODO: make rollable samples with a train_Y set
+
+    fill_empty_array = lambda : [ np.empty((0, 96, 128)) , np.empty((0, 1, 128)) ]
+    samples_by_instrument = defaultdict(fill_empty_array)
+    all_beats = np.empty((0, 96, 128))

     for track in roll.Multitrack(midi_file_path).tracks:
-        print('Track: {}'.format(track.name))
         if not track.is_drum:
-            number_of_beats = floor(track.pianoroll.shape[0] / midi_res)
-            track_pianoroll = track.pianoroll[: number_of_beats * midi_res]
-            track_beats = track_pianoroll.reshape(number_of_beats, midi_res, 128)
-            all_beats = np.concatenate([track_beats, all_beats], axis=0)
-
-            print('Exported {} samples of {}'.format(number_of_beats, settings.midi_program[track.program]))
+            key = track.program + 1
+            # TODO: this makes a pack of samples of N x 96 x 128 shape
+            # number_of_beats = floor(track.pianoroll.shape[0] / midi_res)
+            # track_pianoroll = track.pianoroll[: number_of_beats * midi_res]
+            # track_beats = track_pianoroll.reshape(number_of_beats, midi_res, 128)
+
+            # TODO: this makes rollable samples and a dataset of y_train for prediction
+            # the idea is to predict the next n notes from the previous m timesteps
+            m_timesteps = 96
+            n_next_notes = 4
+            for i, value in tqdm(enumerate(track.pianoroll[:-(m_timesteps + n_next_notes)])):
+                X = track.pianoroll[i : i + m_timesteps].reshape(1,96,128)
+                y = track.pianoroll[i + m_timesteps : i + m_timesteps + n_next_notes].reshape(1,1,128)
+
+                samples_by_instrument[key][0] = np.concatenate([X, samples_by_instrument[ key ][0]], axis=0)
+                samples_by_instrument[key][1] = np.concatenate([y, samples_by_instrument[ key ][1]], axis=0)
+
+            # samples_by_instrument[track.program + 1][0] = np.concatenate([track_beats, samples_by_instrument[ track.program + 1]], axis=0)
         else:
-            # add code for drum samples
+            # TODO: add code for drum samples
             pass
-    return all_beats
+    return samples_by_instrument

 def to_midi(samples, output_path=settings.generated_midi_path, program=0, tempo=120, beat_resolution=settings.beat_resolution):
     tracks = [roll.Track(samples, program=program)]
     return_midi = roll.Multitrack(tracks=tracks, tempo=tempo, downbeat=[0, 96, 192, 288], beat_resolution=beat_resolution)
-    roll.write(return_midi, settings.generated_midi_path)
+    roll.write(return_midi, output_path)

 # TODO: this function is running too slow.
 def delete_empty_samples(sample_pack):
@@ -52,31 +102,61 @@ def delete_empty_samples(sample_pack):
     return temp_sample_pack

 def main():
+    print('Exporting...')

-    if sys.argv[1]=='export':
-        print('Exporting started...')
+    # from collections import defaultdict
+    # fill_empty_array = lambda : [ np.empty((0, 96, 128)) , np.empty((0, 1, 128)) ]
+    # samples_pack_by_instrument = defaultdict(fill_empty_array)

-        sample_pack = np.empty((0,settings.midi_resolution,128))
+    # sample_pack = np.empty((0,settings.midi_resolution,128))
+    X_train = []
+    y_train = []

-        for midi_file in os.listdir(settings.midi_dir):
-            midi_file_path = '{}/{}'.format(settings.midi_dir, midi_file)
-            midi_samples = to_samples(midi_file_path)
-            if midi_samples is None:
-                continue
-            sample_pack = np.concatenate((midi_samples, sample_pack), axis=0)
+    for midi_file in tqdm(os.listdir(settings.midi_dir)):
+        print(midi_file)
+        midi_file_path = '{}/{}'.format(settings.midi_dir, midi_file)
+        X, y = to_samples(midi_file_path)
+        # if midi_samples is None:
+        #     continue
+        X_train.extend(X)
+        y_train.extend(y)
+        # this is for instrument separation
+        # for key, value in midi_samples.items():
+        #     samples_pack_by_instrument[key][0] = np.concatenate((samples_pack_by_instrument[key][0], value[0]), axis=0)
+        #     samples_pack_by_instrument[key][1] = np.concatenate((samples_pack_by_instrument[key][1], value[1]), axis=0)

-        # I commented out this line, because it was too slow
+        # TODO: Delete empty samples
         # sample_pack = delete_empty_samples(sample_pack)

-        np.savez_compressed(settings.samples_dir, sample_pack)
-        print('Exported {} samples'.format(sample_pack.shape[0]))
+    # save as a compressed pickle (sample dictionary)
+    # sfile = bz2.BZ2File('data/samples.pickle', 'w')
+    # pickle.dump(dict(samples_pack_by_instrument), sfile)

-        fig, axes = plt.subplots(nrows=10, ncols=10, figsize=(20, 20))
-        for idx, ax in enumerate(axes.ravel()):
-            n = np.random.randint(0, sample_pack.shape[0])
-            sample = sample_pack[n]
-            ax.imshow(sample, cmap = plt.get_cmap('gray'))
-        plt.savefig(settings.sample_preview_path)
+    # this is for instrument separation
+    # print('Saving...')
+    # for key, value in tqdm(samples_pack_by_instrument.items()):
+    #     np.savez_compressed('data/samples/X_{}.npz'.format(settings.midi_program[key][0]), value)
+    #     np.savez_compressed('data/samples/y_{}.npz'.format(settings.midi_program[key][1]), value)
+
+    # this is for one big list
+    print('Saving...')
+
+    np_X_train = np.array(X_train)
+    np_y_train = np.array(y_train)
+    print(np_X_train.shape, np_y_train.shape)
+    np.savez_compressed('data/samples/X_{}.npz'.format(1), np_X_train)
+    np.savez_compressed('data/samples/y_{}.npz'.format(1), np_y_train)
+
+    # Give a preview of what the samples look like
+    # fig, axes = plt.subplots(nrows=10, ncols=10, figsize=(20, 20))
+    # for idx, ax in enumerate(axes.ravel()):
+    #     n = np.random.randint(0, value[0].shape[0])
+    #     sample = value[n]
+    #     ax.imshow(sample, cmap = plt.get_cmap('gray'))
+    # plt.savefig('data/samples/{}.png'.format(settings.midi_program[key]))
+
+    print('Exported {} samples'.format(np_X_train.shape[0]))
+    print('Done!')

 if __name__ == '__main__':
     main()
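The windowing that to_samples performs above is a plain sliding window over the piano roll: for each position i it takes 96 timesteps as input and the 4 timesteps that follow as the target. A minimal sketch with a synthetic roll (array names hypothetical):

# Sketch: the sliding window behind to_samples, on a synthetic (T, 128) roll.
import numpy as np

m_timesteps, n_next_notes = 96, 4
pianoroll = np.random.randint(0, 2, size=(500, 128))   # stand-in for track.pianoroll

X_samples, y_samples = [], []
for i in range(pianoroll.shape[0] - (m_timesteps + n_next_notes)):
    X_samples.append(pianoroll[i : i + m_timesteps])                              # (96, 128)
    y_samples.append(pianoroll[i + m_timesteps : i + m_timesteps + n_next_notes]) # (4, 128)

print(np.array(X_samples).shape, np.array(y_samples).shape)   # (400, 96, 128) (400, 4, 128)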
@@ -1,16 +1,28 @@
 ## MUSIC GENERATION USING DEEP LEARNING ##
 ## AUTHOR: CEZARY PUKOWNIK

+files:
+- midi.py - code for data extraction and midi conversion
+- train.py - code for model definition and training sessions
+- generate.py - code for model loading, predicting and saving to midi_dir
+- settings.py - file where default settings are stored
+- readme - this file
+- data/midi - directory where input midi files are stored
+- data/models - directory where trained models are stored
+- data/output - directory where generated music is stored
+- data/samples - directory where data extracted from midi is stored
+- data/samples.npz - deprecated
+
 How to use:

 1. Use midi.py to export data from midi files

-    ./midi.py export <midi_folder_path> <output_path>
+    ./midi.py <midi_folder_path> <output_path>

 2. Use train.py to train a model (this can take a while)

-    ./train.py <input_training_data> <model_save_path>
+    ./train.py <input_training_data> <model_save_path> <epochs>

 3. Use generate.py to generate music from trained models

-    ./generate.py <model_weights_path> <output_path>
+    ./generate.py <trained_model_path> <output_path> <threshold>
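A full session could then look like this (file names illustrative, not from the commit; note that train.py as committed actually reads the X and y sample files as two separate arguments, and that the <threshold> argument is parsed but still commented out in generate.py):

    ./midi.py data/midi data/samples
    ./train.py data/samples/X_1.npz data/samples/y_1.npz data/models/lstm 100
    ./generate.py data/models/lstm.pickle data/output/song 0.5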
@@ -1,18 +0,0 @@
-import pypianoroll as roll
-import matplotlib.pyplot as plt
-import numpy as np
-import os
-import settings
-
-instruments = np.load(settings.generated_sample_path)['arr_0'][0]
-
-instruments = instruments.reshape(96,128)
-# instruments = instruments>0.5
-instruments = instruments*255
-
-i = roll.Track(instruments, program=0)
-generated_midi = roll.Multitrack(tracks=[i], tempo=120.0, downbeat=[0, 96, 192, 288], beat_resolution=24)
-roll.write(generated_midi, settings.generated_midi_path)
-
-plt.imshow(instruments.T, cmap='gray')
-plt.savefig(settings.generated_pianoroll_path)
@@ -15,11 +15,11 @@ beats_per_sample = 1
 ignore_note_lenght = False

 #train_settings
-epochs = 1000
+epochs = 1

 #extras
 midi_program = {
-    0 : 'Perc',
+    # Piano
     1 : 'Acoustic Grand Piano',
     2 : 'Bright Acoustic Piano',
     3 : 'Electric Grand Piano',
@@ -28,6 +28,7 @@ midi_program = {
     6 : 'Electric Piano 2',
     7 : 'Harpsichord',
     8 : 'Clavi',
+    # Chromatic Percussion
     9 : 'Celesta',
     10 : 'Glockenspiel',
     11 : 'Music Box',
@@ -36,6 +37,7 @@ midi_program = {
     14 : 'Xylophone',
     15 : 'Tubular Bells',
     16 : 'Dulcimer',
+    # Organ
     17 : 'Drawbar Organ',
     18 : 'Percussive Organ',
     19 : 'Rock Organ',
@@ -44,6 +46,7 @@ midi_program = {
     22 : 'Accordion',
     23 : 'Harmonica',
     24 : 'Tango Accordion',
+    # Guitar
     25 : 'Acoustic Guitar (nylon)',
     26 : 'Acoustic Guitar (steel)',
     27 : 'Electric Guitar (jazz)',
@@ -52,6 +55,7 @@ midi_program = {
     30 : 'Overdriven Guitar',
     31 : 'Distortion Guitar',
     32 : 'Guitar harmonics',
+    # Bass
     33 : 'Acoustic Bass',
     34 : 'Electric Bass (finger)',
     35 : 'Electric Bass (pick)',
@@ -60,6 +64,7 @@ midi_program = {
     38 : 'Slap Bass 2',
     39 : 'Synth Bass 1',
     40 : 'Synth Bass 2',
+    # Strings
     41 : 'Violin',
     42 : 'Viola',
     43 : 'Cello',
@@ -68,6 +73,7 @@ midi_program = {
     46 : 'Pizzicato Strings',
     47 : 'Orchestral Harp',
     48 : 'Timpani',
+    # Ensemble
     49 : 'String Ensemble 1',
     50 : 'String Ensemble 2',
     51 : 'SynthStrings 1',
@@ -76,6 +82,7 @@ midi_program = {
     54 : 'Voice Oohs',
     55 : 'Synth Voice',
     56 : 'Orchestra Hit',
+    # Brass
     57 : 'Trumpet',
     58 : 'Trombone',
     59 : 'Tuba',
@@ -84,6 +91,7 @@ midi_program = {
     62 : 'Brass Section',
     63 : 'SynthBrass 1',
     64 : 'SynthBrass 2',
+    # Reed
     65 : 'Soprano Sax',
     66 : 'Alto Sax',
     67 : 'Tenor Sax',
@@ -92,6 +100,7 @@ midi_program = {
     70 : 'English Horn',
     71 : 'Bassoon',
     72 : 'Clarinet',
+    # Pipe
     73 : 'Piccolo',
     74 : 'Flute',
     75 : 'Recorder',
@@ -100,6 +109,7 @@ midi_program = {
     78 : 'Shakuhachi',
     79 : 'Whistle',
     80 : 'Ocarina',
+    # Synth Lead
     81 : 'Lead 1 (square)',
     82 : 'Lead 2 (sawtooth)',
     83 : 'Lead 3 (calliope)',
@@ -108,6 +118,7 @@ midi_program = {
     86 : 'Lead 6 (voice)',
     87 : 'Lead 7 (fifths)',
     88 : 'Lead 8 (bass + lead)',
+    # Synth Pad
     89 : 'Pad 1 (new age)',
     90 : 'Pad 2 (warm)',
     91 : 'Pad 3 (polysynth)',
@@ -116,6 +127,7 @@ midi_program = {
     94 : 'Pad 6 (metallic)',
     95 : 'Pad 7 (halo)',
     96 : 'Pad 8 (sweep)',
+    # Synth Effects
     97 : 'FX 1 (rain)',
     98 : 'FX 2 (soundtrack)',
     99 : 'FX 3 (crystal)',
@@ -124,6 +136,7 @@ midi_program = {
     102 : 'FX 6 (goblins)',
     103 : 'FX 7 (echoes)',
     104 : 'FX 8 (sci-fi)',
+    # Ethnic
     105 : 'Sitar',
     106 : 'Banjo',
     107 : 'Shamisen',
@@ -132,6 +145,7 @@ midi_program = {
     110 : 'Bag pipe',
     111 : 'Fiddle',
     112 : 'Shanai',
+    # Percussive
     113 : 'Tinkle Bell',
     114 : 'Agogo',
     115 : 'Steel Drums',
@@ -140,6 +154,7 @@ midi_program = {
     118 : 'Melodic Tom',
     119 : 'Synth Drum',
     120 : 'Reverse Cymbal',
+    # Sound Effects
     121 : 'Guitar Fret Noise',
     122 : 'Breath Noise',
     123 : 'Seashore',
@@ -3,31 +3,50 @@
 import tensorflow as tf
 import settings
 from tensorflow.keras import layers
-from keras.layers import Input, Dense, Conv2D, Flatten
+from keras.layers import Input, Dense, Conv2D, Flatten, LSTM, Dropout, TimeDistributed, RepeatVector, Activation, Bidirectional, Reshape
 from keras.models import Model, Sequential
 import numpy as np
-from sys import exit
+import sys
 import pickle

-print('Reading samples from: {}'.format(settings.samples_path))
+train_data_path_X = sys.argv[1]
+train_data_path_y = sys.argv[2]
+save_model_path = sys.argv[3]
+epochs = int(sys.argv[4])

-train_X = np.load(settings.samples_path)['arr_0']
+# model architecture
+# model = Sequential()
+# model.add(LSTM(128, activation='relu', input_shape=(96,128)))
+# model.add(RepeatVector(96))
+# model.add(LSTM(128, activation='softmax', return_sequences=True))
+# model.add(TimeDistributed(Dense(128)))
+#
+# model.compile(optimizer='adam',
+#               loss='categorical_crossentropy',
+#               metrics=['accuracy'])

-n_samples = train_X.shape[0]
-input_shape = settings.midi_resolution*128
-train_X = train_X.reshape(n_samples, input_shape)
+model = Sequential()
+model.add(LSTM(128, input_shape=(96, 128), return_sequences=True))
+model.add(Dropout(0.3))
+model.add(LSTM(512, return_sequences=True))
+model.add(Dropout(0.3))
+model.add(LSTM(512))
+model.add(Dense(512))
+# model.add(Dropout(0.3))
+# model.add(Dense(128))
+model.add(Activation('softmax'))
+model.add(Reshape((4, 128)))
+model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

-# encoder model
-input_img = tf.keras.layers.Input(shape=(input_shape,))
-encoded = tf.keras.layers.Dense(160, activation='relu')(input_img)
-decoded = tf.keras.layers.Dense(input_shape, activation='sigmoid')(encoded)
-autoencoder = tf.keras.models.Model(input_img, decoded)
-
-autoencoder.compile(optimizer='adam',
-                    loss='binary_crossentropy',
-                    metrics=['accuracy'])
-
-autoencoder.fit(train_X, train_X, epochs=settings.epochs, batch_size=32)
-
-autoencoder.save_weights(settings.model_path)
-print("Model save to {}".format(settings.model_path))
+# load training data
+print('Reading samples from: {}'.format(train_data_path_X))
+train_X = np.load(train_data_path_X)['arr_0']
+train_y = np.load(train_data_path_y)['arr_0']
+
+# model training
+model.fit(train_X, train_y, epochs=epochs, batch_size=32)
+
+# save trained model
+pickle_path = '{}.pickle'.format(save_model_path)
+pickle.dump(model, open(pickle_path, 'wb'))
+print("Model saved to {}".format(pickle_path))
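The Dense(512) feeding Reshape((4, 128)) works out because 512 = 4 x 128, though note the softmax normalizes over all 512 units at once, i.e. across all four predicted timesteps jointly rather than one distribution per timestep. A quick shape check with the same layer stack (a sketch, not part of the commit):

# Sketch: verify the output shape of the layer stack in train.py above.
from keras.models import Sequential
from keras.layers import LSTM, Dropout, Dense, Activation, Reshape

model = Sequential()
model.add(LSTM(128, input_shape=(96, 128), return_sequences=True))
model.add(Dropout(0.3))
model.add(LSTM(512, return_sequences=True))
model.add(Dropout(0.3))
model.add(LSTM(512))
model.add(Dense(512))            # 512 = 4 * 128
model.add(Activation('softmax')) # normalizes over the full 512-vector
model.add(Reshape((4, 128)))     # 4 timesteps x 128 pitches
print(model.output_shape)        # (None, 4, 128)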