Zaktualizuj 'sacred.py'
This commit is contained in:
parent
59346b7215
commit
01f7eccb82
@ -34,33 +34,25 @@ def run_experiment():
|
|||||||
feature_cols = ['year', 'mileage', 'vol_engine']
|
feature_cols = ['year', 'mileage', 'vol_engine']
|
||||||
inputs = tf.keras.Input(shape=(len(feature_cols),))
|
inputs = tf.keras.Input(shape=(len(feature_cols),))
|
||||||
|
|
||||||
# Warstwy sieci neuronowej
|
|
||||||
x = tf.keras.layers.Dense(10, activation='relu')(inputs)
|
x = tf.keras.layers.Dense(10, activation='relu')(inputs)
|
||||||
x = tf.keras.layers.Dense(10, activation='relu')(x)
|
x = tf.keras.layers.Dense(10, activation='relu')(x)
|
||||||
outputs = tf.keras.layers.Dense(1, activation='linear')(x)
|
outputs = tf.keras.layers.Dense(1, activation='linear')(x)
|
||||||
|
|
||||||
# Utworzenie modelu
|
|
||||||
model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
||||||
|
|
||||||
# Kompilacja modelu
|
|
||||||
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
|
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
|
||||||
loss='mse', metrics=['mae'])
|
loss='mse', metrics=['mae'])
|
||||||
|
|
||||||
# Trenowanie modelu
|
|
||||||
model.fit(cars_train[feature_cols], cars_train['price'], epochs=100)
|
model.fit(cars_train[feature_cols], cars_train['price'], epochs=100)
|
||||||
|
|
||||||
# Zapis plików wejściowych
|
|
||||||
ex.add_resource('train_data.csv')
|
ex.add_resource('train_data.csv')
|
||||||
ex.add_resource('test_data.csv')
|
ex.add_resource('test_data.csv')
|
||||||
|
|
||||||
# Zapis kodu źródłowego
|
|
||||||
ex.add_artifact(__file__)
|
ex.add_artifact(__file__)
|
||||||
|
|
||||||
# Zapis modelu do pliku
|
|
||||||
model.save('model.h5')
|
model.save('model.h5')
|
||||||
ex.add_artifact('model.h5')
|
ex.add_artifact('model.h5')
|
||||||
|
|
||||||
# Zapisanie metryk
|
|
||||||
metrics = model.evaluate(cars_train[feature_cols], cars_train['price'])
|
metrics = model.evaluate(cars_train[feature_cols], cars_train['price'])
|
||||||
ex.log_scalar('mse', metrics[0])
|
ex.log_scalar('mse', metrics[0])
|
||||||
ex.log_scalar('mae', metrics[1])
|
ex.log_scalar('mae', metrics[1])
|
Loading…
Reference in New Issue
Block a user