lstm - drop this branch, looking for other way to generate music
This commit is contained in:
parent
1997ff96ef
commit
35c19e1e80
@ -1,26 +1,49 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import numpy as np
|
||||
import midi
|
||||
import tensorflow as tf
|
||||
from keras.layers import Input, Dense, Conv2D
|
||||
from keras.models import Model
|
||||
from tensorflow.keras import layers
|
||||
from keras.layers import Input, Dense, Conv2D, Flatten, LSTM, Dropout, TimeDistributed, RepeatVector
|
||||
from keras.models import Model, Sequential
|
||||
import matplotlib.pyplot as plt
|
||||
import settings
|
||||
import pickle
|
||||
import sys
|
||||
|
||||
#model
|
||||
input_shape = settings.midi_resolution*128
|
||||
input_img = tf.keras.layers.Input(shape=(input_shape,))
|
||||
encoded = tf.keras.layers.Dense(160, activation='relu')(input_img)
|
||||
decoded = tf.keras.layers.Dense(input_shape, activation='sigmoid')(encoded)
|
||||
autoencoder = tf.keras.models.Model(input_img, decoded)
|
||||
trained_model_path = sys.argv[1]
|
||||
output_path = sys.argv[2]
|
||||
# treshold = float(sys.argv[3])
|
||||
|
||||
autoencoder.compile(optimizer='adadelta',
|
||||
loss='categorical_crossentropy',
|
||||
metrics=['accuracy'])
|
||||
#random seed
|
||||
generate_seed = np.random.rand(12288).reshape(1,96,128)
|
||||
|
||||
# load weights into new model
|
||||
autoencoder.load_weights(settings.model_path)
|
||||
print("Loaded model from {}".format(settings.model_path))
|
||||
# load and predict
|
||||
model = pickle.load(open(trained_model_path, 'rb'))
|
||||
|
||||
# generate_seed = np.random.rand(12288).reshape(1,12288)
|
||||
generate_seed = np.load(settings.samples_path)['arr_0'][15].reshape(1,12288)
|
||||
generated_music = np.empty((0,128))
|
||||
|
||||
generated_sample = autoencoder.predict(generate_seed)
|
||||
np.savez_compressed(settings.generated_sample_path, generated_sample)
|
||||
for note in range(100):
|
||||
generated_vector = model.predict(generate_seed).reshape(1,4,128)
|
||||
generated_notes = np.zeros((4,128))
|
||||
for i, col in enumerate(generated_vector[0]):
|
||||
best_note = np.argmax(col)
|
||||
generated_notes[i][best_note] = 1
|
||||
|
||||
generate_seed = np.concatenate([generated_notes, generate_seed[0][:-4]]).reshape(1,96,128)
|
||||
generated_music = np.concatenate([generated_music, generated_notes])
|
||||
|
||||
# generated_sample = generated_sample.reshape(96,128)
|
||||
generated_sample = generated_music
|
||||
# print(generated_music)
|
||||
# binarize generated music
|
||||
# generated_sample = generated_sample > 0 * generated_sample.max()
|
||||
|
||||
#save to midi
|
||||
midi.to_midi(generated_sample, output_path='{}.mid'.format(output_path) )
|
||||
|
||||
#save piano roll to png
|
||||
plt.imshow(generated_sample, cmap = plt.get_cmap('gray'))
|
||||
plt.savefig('{}.png'.format(output_path))
|
||||
|
148
project/midi.py
148
project/midi.py
@ -8,36 +8,86 @@ import os
|
||||
from tqdm import tqdm
|
||||
from math import floor
|
||||
import sys
|
||||
import pickle
|
||||
from tqdm import tqdm
|
||||
from tqdm import trange
|
||||
from collections import defaultdict
|
||||
|
||||
import bz2
|
||||
import pickle
|
||||
|
||||
def to_samples(midi_file_path, midi_res=settings.midi_resolution):
|
||||
|
||||
# this function export a samples from midi file:
|
||||
# and for every track in midi file chopped pianoroll
|
||||
# for a samples of given beat_lenth (midi_res)
|
||||
# every track is single line
|
||||
# TODO: add transpositions of every sample to every possible key transposition
|
||||
# np.roll(sample, pitch_interval, axis=1) for transposition
|
||||
# np.roll(sample, time_steps, axis=0) for time shifting
|
||||
all_X_samples = []
|
||||
all_y_samples = []
|
||||
for track in roll.Multitrack(midi_file_path).tracks:
|
||||
if not track.is_drum:
|
||||
# TODO: this makes rollable samples and dataset of y_train for prdiction
|
||||
# the idea is to predict next N timesteps from prevous M timesteps
|
||||
m_timesteps = 96
|
||||
n_next_notes = 4
|
||||
|
||||
print('Exporting samples from: {}'.format(midi_file_path))
|
||||
track_timesteps = track.pianoroll.shape[0] - (m_timesteps + n_next_notes)
|
||||
|
||||
all_beats = np.empty((0, settings.midi_resolution, 128))
|
||||
X_track_samples = []
|
||||
y_track_samples = []
|
||||
for i in range(track_timesteps):
|
||||
X = track.pianoroll[i : i + m_timesteps].reshape(96,128)
|
||||
y = track.pianoroll[i + m_timesteps : i + m_timesteps + n_next_notes].reshape(4,128)
|
||||
X_track_samples.append(X)
|
||||
y_track_samples.append(y)
|
||||
|
||||
all_X_samples.extend(X_track_samples)
|
||||
all_y_samples.extend(y_track_samples)
|
||||
else:
|
||||
# TODO: add code for drums samples
|
||||
pass
|
||||
return all_X_samples, all_y_samples
|
||||
|
||||
def to_samples_by_instrument(midi_file_path, midi_res=settings.midi_resolution):
|
||||
|
||||
# add transpositions of every sample to every possible key transposition
|
||||
# np.roll(sample, pitch_interval, axis=1) for transposition
|
||||
# np.roll(sample, time_steps, axis=0) for time shifting
|
||||
|
||||
# TODO: make rollable samples with train_Y set
|
||||
|
||||
fill_empty_array = lambda : [ np.empty((0, 96, 128)) , np.empty((0, 1, 128)) ]
|
||||
samples_by_instrument = defaultdict(fill_empty_array)
|
||||
all_beats = np.empty((0, 96, 128))
|
||||
|
||||
for track in roll.Multitrack(midi_file_path).tracks:
|
||||
print('Track: {}'.format(track.name))
|
||||
if not track.is_drum:
|
||||
number_of_beats = floor(track.pianoroll.shape[0] / midi_res)
|
||||
track_pianoroll = track.pianoroll[: number_of_beats * midi_res]
|
||||
track_beats = track_pianoroll.reshape(number_of_beats, midi_res, 128)
|
||||
all_beats = np.concatenate([track_beats, all_beats], axis=0)
|
||||
key = track.program + 1
|
||||
# TODO: this makes pack of samples of N x 96 x 128 shape
|
||||
# number_of_beats = floor(track.pianoroll.shape[0] / midi_res)
|
||||
# track_pianoroll = track.pianoroll[: number_of_beats * midi_res]
|
||||
# track_beats = track_pianoroll.reshape(number_of_beats, midi_res, 128)
|
||||
|
||||
print('Exported {} samples of {}'.format(number_of_beats, settings.midi_program[track.program]))
|
||||
# TODO: this makes rollable samples and dataset of y_train for prdiction
|
||||
# the idea is to predict next n notes from prevous m timesteps
|
||||
m_timesteps = 96
|
||||
n_next_notes = 4
|
||||
for i, value in tqdm(enumerate(track.pianoroll[:-(m_timesteps + n_next_notes)])):
|
||||
X = track.pianoroll[i : i + m_timesteps].reshape(1,96,128)
|
||||
y = track.pianoroll[i + m_timesteps : i + m_timesteps + n_next_notes].reshape(1,1,128)
|
||||
|
||||
samples_by_instrument[key][0] = np.concatenate([X, samples_by_instrument[ key ][0]], axis=0)
|
||||
samples_by_instrument[key][1] = np.concatenate([y, samples_by_instrument[ key ][1]], axis=0)
|
||||
|
||||
# samples_by_instrument[track.program + 1][0] = np.concatenate([track_beats, samples_by_instrument[ track.program + 1]], axis=0)
|
||||
else:
|
||||
# add code for drums samples
|
||||
# TODO: add code for drums samples
|
||||
pass
|
||||
return all_beats
|
||||
return samples_by_instrument
|
||||
|
||||
def to_midi(samples, output_path=settings.generated_midi_path, program=0, tempo=120, beat_resolution=settings.beat_resolution):
|
||||
tracks = [roll.Track(samples, program=program)]
|
||||
return_midi = roll.Multitrack(tracks=tracks, tempo=tempo, downbeat=[0, 96, 192, 288], beat_resolution=beat_resolution)
|
||||
roll.write(return_midi, settings.generated_midi_path)
|
||||
roll.write(return_midi, output_path)
|
||||
|
||||
# todo: this function is running too slow.
|
||||
def delete_empty_samples(sample_pack):
|
||||
@ -52,31 +102,61 @@ def delete_empty_samples(sample_pack):
|
||||
return temp_sample_pack
|
||||
|
||||
def main():
|
||||
print('Exporting...')
|
||||
|
||||
if sys.argv[1]=='export':
|
||||
print('Exporting started...')
|
||||
# from collections import defaultdict
|
||||
# fill_empty_array = lambda : [ np.empty((0, 96, 128)) , np.empty((0, 1, 128)) ]
|
||||
# samples_pack_by_instrument = defaultdict(fill_empty_array)
|
||||
|
||||
sample_pack = np.empty((0,settings.midi_resolution,128))
|
||||
# sample_pack = np.empty((0,settings.midi_resolution,128))
|
||||
X_train = []
|
||||
y_train = []
|
||||
|
||||
for midi_file in os.listdir(settings.midi_dir):
|
||||
midi_file_path = '{}/{}'.format(settings.midi_dir, midi_file)
|
||||
midi_samples = to_samples(midi_file_path)
|
||||
if midi_samples is None:
|
||||
continue
|
||||
sample_pack = np.concatenate((midi_samples, sample_pack), axis=0)
|
||||
for midi_file in tqdm(os.listdir(settings.midi_dir)):
|
||||
print(midi_file)
|
||||
midi_file_path = '{}/{}'.format(settings.midi_dir, midi_file)
|
||||
X, y = to_samples(midi_file_path)
|
||||
# if midi_samples is None:
|
||||
# continue
|
||||
X_train.extend(X)
|
||||
y_train.extend(y)
|
||||
# this is for intrument separation
|
||||
# for key, value in midi_samples.items():
|
||||
# samples_pack_by_instrument[key][0] = np.concatenate((samples_pack_by_instrument[key][0], value[0]), axis=0)
|
||||
# samples_pack_by_instrument[key][1] = np.concatenate((samples_pack_by_instrument[key][1], value[1]), axis=0)
|
||||
|
||||
# I commented out this line, because it was too slow
|
||||
# sample_pack = delete_empty_samples(sample_pack)
|
||||
# TODO: Delete empty samples
|
||||
# sample_pack = delete_empty_samples(sample_pack)
|
||||
|
||||
np.savez_compressed(settings.samples_dir, sample_pack)
|
||||
print('Exported {} samples'.format(sample_pack.shape[0]))
|
||||
# save as compressed pickle (sample-dictionary)
|
||||
# sfile = bz2.BZ2File('data/samples.pickle', 'w')
|
||||
# pickle.dump(dict(samples_pack_by_instrument), sfile)
|
||||
|
||||
fig, axes = plt.subplots(nrows=10, ncols=10, figsize=(20, 20))
|
||||
for idx, ax in enumerate(axes.ravel()):
|
||||
n = np.random.randint(0, sample_pack.shape[0])
|
||||
sample = sample_pack[n]
|
||||
ax.imshow(sample, cmap = plt.get_cmap('gray'))
|
||||
plt.savefig(settings.sample_preview_path)
|
||||
# this is for intrument separation
|
||||
# print('Saving...')
|
||||
# for key, value in tqdm(samples_pack_by_instrument.items()):
|
||||
# np.savez_compressed('data/samples/X_{}.npz'.format(settings.midi_program[key][0]), value)
|
||||
# np.savez_compressed('data/samples/y_{}.npz'.format(settings.midi_program[key][1]), value)
|
||||
|
||||
# this if for one big list
|
||||
print('Saving...')
|
||||
|
||||
np_X_train = np.array(X_train)
|
||||
np_y_train = np.array(y_train)
|
||||
print(np_X_train.shape, np_y_train.shape)
|
||||
np.savez_compressed('data/samples/X_{}.npz'.format(1), np_X_train)
|
||||
np.savez_compressed('data/samples/y_{}.npz'.format(1), np_y_train)
|
||||
|
||||
# Give a preview of what samples looks like
|
||||
# fig, axes = plt.subplots(nrows=10, ncols=10, figsize=(20, 20))
|
||||
# for idx, ax in enumerate(axes.ravel()):
|
||||
# n = np.random.randint(0, value[0].shape[0])
|
||||
# sample = value[n]
|
||||
# ax.imshow(sample, cmap = plt.get_cmap('gray'))
|
||||
# plt.savefig('data/samples/{}.png'.format(settings.midi_program[key]))
|
||||
|
||||
print('Exported {} samples'.format(np_X_train.shape[0]))
|
||||
print('Done!')
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
@ -1,16 +1,28 @@
|
||||
## MUSIC GENERATION USING DEEP LEARNING ##
|
||||
## AUTHOR: CEZARY PUKOWNIK
|
||||
|
||||
files:
|
||||
- midi.py - code for data extraction, and midi convertion
|
||||
- train.py - code for model definition, and training session
|
||||
- generate.py - code for model loading, predicting ang saving to midi_dir
|
||||
- settings.py - file where deafult settings are stored
|
||||
- readme - this file
|
||||
- data/midi - directory where input midi are stored
|
||||
- data/models - directory where trained models are stored
|
||||
- data/output - directory where generated music is stored
|
||||
- data/samples - directory where extracted data from midi is stored
|
||||
- data/samples.npz - deprecated
|
||||
|
||||
How to use:
|
||||
|
||||
1. Use midi.py to export data from midi files
|
||||
|
||||
./midi.py export <midi_folder_path> <output_path>
|
||||
./midi.py <midi_folder_path> <output_path>
|
||||
|
||||
2. Use train.py to train a model (this can take a while)
|
||||
|
||||
./train.py <input_training_data> <model_save_path>
|
||||
./train.py <input_training_data> <model_save_path> <epochs>
|
||||
|
||||
3. Use generate.py to generate music from trained models
|
||||
|
||||
./generate.py <model_weights_path> <output_path>
|
||||
./generate.py <trained_model_path> <output_path> <treshold>
|
||||
|
@ -1,18 +0,0 @@
|
||||
import pypianoroll as roll
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import os
|
||||
import settings
|
||||
|
||||
instruments = np.load(settings.generated_sample_path)['arr_0'][0]
|
||||
|
||||
instruments = instruments.reshape(96,128)
|
||||
# instruments = instruments>0.5
|
||||
instruments = instruments*255
|
||||
|
||||
i = roll.Track(instruments, program=0)
|
||||
generated_midi = roll.Multitrack(tracks=[i], tempo=120.0, downbeat=[0, 96, 192, 288], beat_resolution=24)
|
||||
roll.write(generated_midi, settings.generated_midi_path)
|
||||
|
||||
plt.imshow(instruments.T, cmap='gray')
|
||||
plt.savefig(settings.generated_pianoroll_path)
|
@ -15,11 +15,11 @@ beats_per_sample = 1
|
||||
ignore_note_lenght = False
|
||||
|
||||
#train_settings
|
||||
epochs = 1000
|
||||
epochs = 1
|
||||
|
||||
#extras
|
||||
midi_program = {
|
||||
0 : 'Perc',
|
||||
# Piano
|
||||
1 : 'Acoustic Grand Piano',
|
||||
2 : 'Bright Acoustic Piano',
|
||||
3 : 'Electric Grand Piano',
|
||||
@ -28,6 +28,7 @@ midi_program = {
|
||||
6 : 'Electric Piano 2',
|
||||
7 : 'Harpsichord',
|
||||
8 : 'Clavi',
|
||||
# Chromatic Percussion
|
||||
9 : 'Celesta',
|
||||
10 : 'Glockenspiel',
|
||||
11 : 'Music Box',
|
||||
@ -36,6 +37,7 @@ midi_program = {
|
||||
14 : 'Xylophone',
|
||||
15 : 'Tubular Bells',
|
||||
16 : 'Dulcimer',
|
||||
# Organ
|
||||
17 : 'Drawbar Organ',
|
||||
18 : 'Percussive Organ',
|
||||
19 : 'Rock Organ',
|
||||
@ -44,6 +46,7 @@ midi_program = {
|
||||
22 : 'Accordion',
|
||||
23 : 'Harmonica',
|
||||
24 : 'Tango Accordion',
|
||||
# Guitar
|
||||
25 : 'Acoustic Guitar (nylon)',
|
||||
26 : 'Acoustic Guitar (steel)',
|
||||
27 : 'Electric Guitar (jazz)',
|
||||
@ -52,6 +55,7 @@ midi_program = {
|
||||
30 : 'Overdriven Guitar',
|
||||
31 : 'Distortion Guitar',
|
||||
32 : 'Guitar harmonics',
|
||||
# Bass
|
||||
33 : 'Acoustic Bass',
|
||||
34 : 'Electric Bass (finger)',
|
||||
35 : 'Electric Bass (pick)',
|
||||
@ -60,6 +64,7 @@ midi_program = {
|
||||
38 : 'Slap Bass 2',
|
||||
39 : 'Synth Bass 1',
|
||||
40 : 'Synth Bass 2',
|
||||
# Strings
|
||||
41 : 'Violin',
|
||||
42 : 'Viola',
|
||||
43 : 'Cello',
|
||||
@ -68,6 +73,7 @@ midi_program = {
|
||||
46 : 'Pizzicato Strings',
|
||||
47 : 'Orchestral Harp',
|
||||
48 : 'Timpani',
|
||||
# Ensemble
|
||||
49 : 'String Ensemble 1',
|
||||
50 : 'String Ensemble 2',
|
||||
51 : 'SynthStrings 1',
|
||||
@ -76,6 +82,7 @@ midi_program = {
|
||||
54 : 'Voice Oohs',
|
||||
55 : 'Synth Voice',
|
||||
56 : 'Orchestra Hit',
|
||||
# Brass
|
||||
57 : 'Trumpet',
|
||||
58 : 'Trombone',
|
||||
59 : 'Tuba',
|
||||
@ -84,6 +91,7 @@ midi_program = {
|
||||
62 : 'Brass Section',
|
||||
63 : 'SynthBrass 1',
|
||||
64 : 'SynthBrass 2',
|
||||
# Reed
|
||||
65 : 'Soprano Sax',
|
||||
66 : 'Alto Sax',
|
||||
67 : 'Tenor Sax',
|
||||
@ -92,6 +100,7 @@ midi_program = {
|
||||
70 : 'English Horn',
|
||||
71 : 'Bassoon',
|
||||
72 : 'Clarinet',
|
||||
# Pipe
|
||||
73 : 'Piccolo',
|
||||
74 : 'Flute',
|
||||
75 : 'Recorder',
|
||||
@ -100,6 +109,7 @@ midi_program = {
|
||||
78 : 'Shakuhachi',
|
||||
79 : 'Whistle',
|
||||
80 : 'Ocarina',
|
||||
# Synth Lead
|
||||
81 : 'Lead 1 (square)',
|
||||
82 : 'Lead 2 (sawtooth)',
|
||||
83 : 'Lead 3 (calliope)',
|
||||
@ -108,6 +118,7 @@ midi_program = {
|
||||
86 : 'Lead 6 (voice)',
|
||||
87 : 'Lead 7 (fifths)',
|
||||
88 : 'Lead 8 (bass + lead)',
|
||||
# Synth Pad
|
||||
89 : 'Pad 1 (new age)',
|
||||
90 : 'Pad 2 (warm)',
|
||||
91 : 'Pad 3 (polysynth)',
|
||||
@ -116,6 +127,7 @@ midi_program = {
|
||||
94 : 'Pad 6 (metallic)',
|
||||
95 : 'Pad 7 (halo)',
|
||||
96 : 'Pad 8 (sweep)',
|
||||
# Synth Effects
|
||||
97 : 'FX 1 (rain)',
|
||||
98 : 'FX 2 (soundtrack)',
|
||||
99 : 'FX 3 (crystal)',
|
||||
@ -124,6 +136,7 @@ midi_program = {
|
||||
102 : 'FX 6 (goblins)',
|
||||
103 : 'FX 7 (echoes)',
|
||||
104 : 'FX 8 (sci-fi)',
|
||||
# Ethnic
|
||||
105 : 'Sitar',
|
||||
106 : 'Banjo',
|
||||
107 : 'Shamisen',
|
||||
@ -132,6 +145,7 @@ midi_program = {
|
||||
110 : 'Bag pipe',
|
||||
111 : 'Fiddle',
|
||||
112 : 'Shanai',
|
||||
# Percussive
|
||||
113 : 'Tinkle Bell',
|
||||
114 : 'Agogo',
|
||||
115 : 'Steel Drums',
|
||||
@ -140,6 +154,7 @@ midi_program = {
|
||||
118 : 'Melodic Tom',
|
||||
119 : 'Synth Drum',
|
||||
120 : 'Reverse Cymbal',
|
||||
# Sound Effects
|
||||
121 : 'Guitar Fret Noise',
|
||||
122 : 'Breath Noise',
|
||||
123 : 'Seashore',
|
||||
|
@ -3,31 +3,50 @@
|
||||
import tensorflow as tf
|
||||
import settings
|
||||
from tensorflow.keras import layers
|
||||
from keras.layers import Input, Dense, Conv2D, Flatten
|
||||
from keras.layers import Input, Dense, Conv2D, Flatten, LSTM, Dropout, TimeDistributed, RepeatVector, Activation, Bidirectional, Reshape
|
||||
from keras.models import Model, Sequential
|
||||
import numpy as np
|
||||
from sys import exit
|
||||
import sys
|
||||
import pickle
|
||||
|
||||
print('Reading samples from: {}'.format(settings.samples_path))
|
||||
train_data_path_X = sys.argv[1]
|
||||
train_data_path_y = sys.argv[2]
|
||||
save_model_path = sys.argv[3]
|
||||
epochs = int(sys.argv[4])
|
||||
|
||||
train_X = np.load(settings.samples_path)['arr_0']
|
||||
# model architecture
|
||||
# model = Sequential()
|
||||
# model.add(LSTM(128, activation='relu', input_shape=(96,128)))
|
||||
# model.add(RepeatVector(96))
|
||||
# model.add(LSTM(128, activation='softmax', return_sequences=True))
|
||||
# model.add(TimeDistributed(Dense(128)))
|
||||
#
|
||||
# model.compile(optimizer='adam',
|
||||
# loss='categorical_crossentropy',
|
||||
# metrics=['accuracy'])
|
||||
|
||||
n_samples = train_X.shape[0]
|
||||
input_shape = settings.midi_resolution*128
|
||||
train_X = train_X.reshape(n_samples, input_shape)
|
||||
model = Sequential()
|
||||
model.add(LSTM(128,input_shape=(96, 128),return_sequences=True))
|
||||
model.add(Dropout(0.3))
|
||||
model.add(LSTM(512, return_sequences=True))
|
||||
model.add(Dropout(0.3))
|
||||
model.add(LSTM(512))
|
||||
model.add(Dense(512))
|
||||
# model.add(Dropout(0.3))
|
||||
# model.add(Dense(128))
|
||||
model.add(Activation('softmax'))
|
||||
model.add(Reshape((4, 128)))
|
||||
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
|
||||
|
||||
# encoder model
|
||||
input_img = tf.keras.layers.Input(shape=(input_shape,))
|
||||
encoded = tf.keras.layers.Dense(160, activation='relu')(input_img)
|
||||
decoded = tf.keras.layers.Dense(input_shape, activation='sigmoid')(encoded)
|
||||
autoencoder = tf.keras.models.Model(input_img, decoded)
|
||||
# load training data
|
||||
print('Reading samples from: {}'.format(train_data_path_X))
|
||||
train_X = np.load(train_data_path_X)['arr_0']
|
||||
train_y = np.load(train_data_path_y)['arr_0']
|
||||
|
||||
autoencoder.compile(optimizer='adam',
|
||||
loss='binary_crossentropy',
|
||||
metrics=['accuracy'])
|
||||
# model training
|
||||
model.fit(train_X, train_y, epochs=epochs, batch_size=32)
|
||||
|
||||
autoencoder.fit(train_X, train_X, epochs=settings.epochs, batch_size=32)
|
||||
|
||||
autoencoder.save_weights(settings.model_path)
|
||||
print("Model save to {}".format(settings.model_path))
|
||||
# save trained model
|
||||
pickle_path = '{}.pickle'.format(save_model_path)
|
||||
pickle.dump(model, open(pickle_path,'wb'))
|
||||
print("Model save to {}".format(pickle_path))
|
||||
|
Loading…
Reference in New Issue
Block a user