Zaktualizuj 'sacred.py'

This commit is contained in:
Michał Dudziak 2023-05-10 11:35:53 +02:00
parent 59346b7215
commit 01f7eccb82

View File

@ -34,33 +34,25 @@ def run_experiment():
feature_cols = ['year', 'mileage', 'vol_engine']
inputs = tf.keras.Input(shape=(len(feature_cols),))
# Warstwy sieci neuronowej
x = tf.keras.layers.Dense(10, activation='relu')(inputs)
x = tf.keras.layers.Dense(10, activation='relu')(x)
outputs = tf.keras.layers.Dense(1, activation='linear')(x)
# Utworzenie modelu
model = tf.keras.Model(inputs=inputs, outputs=outputs)
# Kompilacja modelu
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
loss='mse', metrics=['mae'])
# Trenowanie modelu
model.fit(cars_train[feature_cols], cars_train['price'], epochs=100)
# Zapis plików wejściowych
ex.add_resource('train_data.csv')
ex.add_resource('test_data.csv')
# Zapis kodu źródłowego
ex.add_artifact(__file__)
# Zapis modelu do pliku
model.save('model.h5')
ex.add_artifact('model.h5')
# Zapisanie metryk
metrics = model.evaluate(cars_train[feature_cols], cars_train['price'])
ex.log_scalar('mse', metrics[0])
ex.log_scalar('mae', metrics[1])