Zaktualizuj 'sacred.py'
This commit is contained in:
parent
59346b7215
commit
01f7eccb82
@ -34,33 +34,25 @@ def run_experiment():
|
||||
feature_cols = ['year', 'mileage', 'vol_engine']
|
||||
inputs = tf.keras.Input(shape=(len(feature_cols),))
|
||||
|
||||
# Warstwy sieci neuronowej
|
||||
x = tf.keras.layers.Dense(10, activation='relu')(inputs)
|
||||
x = tf.keras.layers.Dense(10, activation='relu')(x)
|
||||
outputs = tf.keras.layers.Dense(1, activation='linear')(x)
|
||||
|
||||
# Utworzenie modelu
|
||||
model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
||||
|
||||
# Kompilacja modelu
|
||||
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
|
||||
loss='mse', metrics=['mae'])
|
||||
|
||||
# Trenowanie modelu
|
||||
model.fit(cars_train[feature_cols], cars_train['price'], epochs=100)
|
||||
|
||||
# Zapis plików wejściowych
|
||||
ex.add_resource('train_data.csv')
|
||||
ex.add_resource('test_data.csv')
|
||||
|
||||
# Zapis kodu źródłowego
|
||||
ex.add_artifact(__file__)
|
||||
|
||||
# Zapis modelu do pliku
|
||||
model.save('model.h5')
|
||||
ex.add_artifact('model.h5')
|
||||
|
||||
# Zapisanie metryk
|
||||
metrics = model.evaluate(cars_train[feature_cols], cars_train['price'])
|
||||
ex.log_scalar('mse', metrics[0])
|
||||
ex.log_scalar('mae', metrics[1])
|
Loading…
Reference in New Issue
Block a user