praca-magisterska/project/generate.py

40 lines
1.2 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
2019-05-28 12:40:26 +02:00
import numpy as np
import midi
2019-05-28 12:40:26 +02:00
import tensorflow as tf
import pypianoroll as roll
2019-05-28 12:40:26 +02:00
from keras.layers import Input, Dense, Conv2D
from keras.models import Model
from tensorflow.keras import layers
from keras.layers import Input, Dense, Conv2D, Flatten, LSTM, Dropout, TimeDistributed, RepeatVector
from keras.models import Model, Sequential
import matplotlib.pyplot as plt
2019-05-28 12:40:26 +02:00
import settings
import pickle
import sys
# Geometry of one generated segment. Segments are stacked along axis 0 before
# being written out, so axis 0 is presumably time steps and axis 1 the 128 MIDI
# pitches (piano-roll layout) -- TODO confirm against the training pipeline.
N_SEGMENTS = 4
SEGMENT_STEPS = 96
N_PITCHES = 128
# Seeds are drawn from [0, SEED_HIGH). np.random.randint excludes its upper
# bound, so 128 is required to cover the full MIDI value range 0..127
# (the previous bound of 127 could never produce 127).
SEED_HIGH = 128


def load_model(path):
    """Unpickle and return the trained model stored at *path*.

    SECURITY: ``pickle.load`` can execute arbitrary code while loading;
    only pass model files that come from a trusted source.
    """
    with open(path, 'rb') as fh:  # closes the handle, unlike a bare open()
        return pickle.load(fh)


def random_seed(rng=None):
    """Return a random integer seed of shape (1, SEGMENT_STEPS, N_PITCHES).

    Values are uniform integers in [0, SEED_HIGH - 1]. *rng* is any object
    with a ``numpy.random.RandomState.randint``-compatible method (e.g. a
    seeded ``RandomState``); it defaults to the global ``np.random`` state.
    """
    rng = np.random if rng is None else rng
    return rng.randint(0, SEED_HIGH, size=(1, SEGMENT_STEPS, N_PITCHES))


def generate_music(model, n_segments=N_SEGMENTS, rng=None):
    """Generate *n_segments* segments with *model* and concatenate them.

    Each segment is ``model.predict`` applied to a fresh random seed; the
    prediction must be reshapeable to (SEGMENT_STEPS, N_PITCHES).

    Returns a float array of shape (n_segments * SEGMENT_STEPS, N_PITCHES).
    """
    segments = np.empty((n_segments, SEGMENT_STEPS, N_PITCHES))
    for index in range(n_segments):
        # One seed per predict call, as before; batching all seeds would
        # assume the model accepts batch sizes > 1 -- not verified here.
        prediction = model.predict(random_seed(rng))
        segments[index] = prediction.reshape(SEGMENT_STEPS, N_PITCHES)
    return segments.reshape(n_segments * SEGMENT_STEPS, N_PITCHES)


def binarize(sample, threshold):
    """Return a boolean mask of entries greater than ``threshold * sample.max()``.

    *threshold* is relative to the largest value in *sample*; e.g. 0.5 keeps
    entries above half of the maximum. An empty *sample* yields an empty
    mask instead of raising from ``max()``.
    """
    if sample.size == 0:
        return np.zeros(sample.shape, dtype=bool)
    return sample > threshold * sample.max()


def main(argv=None):
    """Command-line entry point: ``generate.py MODEL_PATH OUTPUT_PATH THRESHOLD``.

    Loads the pickled model, generates and binarizes a sample, then writes
    ``OUTPUT_PATH.mid`` (through the project ``midi`` helper with
    ``is_drum=True``) and a ``OUTPUT_PATH.png`` preview via ``pypianoroll``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description='Generate music from a pickled trained model.')
    parser.add_argument('trained_model_path',
                        help='path to the pickled trained model')
    parser.add_argument('output_path',
                        help="output prefix; '.mid' and '.png' are appended")
    parser.add_argument('threshold', type=float,
                        help='binarization threshold, relative to the '
                             'maximum predicted value')
    args = parser.parse_args(argv)

    model = load_model(args.trained_model_path)
    generated_sample = binarize(generate_music(model), args.threshold)

    # save to midi
    generated_midi = midi.to_midi(generated_sample,
                                  output_path='{}.mid'.format(args.output_path),
                                  is_drum=True, program=0)
    # save plot for preview
    roll.plot(generated_midi, filename='{}.png'.format(args.output_path))


if __name__ == '__main__':
    main()