IUM_07
This commit is contained in:
parent
eda9b367de
commit
d8769d16e1
54
sacred_model.py
Normal file
54
sacred_model.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
from sacred import Experiment
|
||||||
|
from sacred.observers import FileStorageObserver
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from tensorflow.keras import Sequential
|
||||||
|
from tensorflow.keras.layers import Dense
|
||||||
|
from sklearn.preprocessing import MinMaxScaler
|
||||||
|
|
||||||
|
ex = Experiment("car_price_prediction")
|
||||||
|
|
||||||
|
ex.observers.append(FileStorageObserver('my_runs'))
|
||||||
|
|
||||||
|
|
||||||
|
@ex.config
|
||||||
|
def my_config():
|
||||||
|
data_path = './data/car_prices_train.csv'
|
||||||
|
epochs = 20
|
||||||
|
batch_size = 32
|
||||||
|
model_path = './car_prices_predict_model.h5'
|
||||||
|
|
||||||
|
|
||||||
|
@ex.main
|
||||||
|
def train_model(data_path, epochs, batch_size, model_path, _run):
|
||||||
|
train_data = pd.read_csv(data_path)
|
||||||
|
train_data.dropna(inplace=True)
|
||||||
|
|
||||||
|
y_train = train_data['sellingprice'].astype(np.float32)
|
||||||
|
X_train = train_data[['year', 'condition', 'transmission']]
|
||||||
|
scaler_x = MinMaxScaler()
|
||||||
|
X_train['condition'] = scaler_x.fit_transform(X_train[['condition']])
|
||||||
|
scaler_y = MinMaxScaler()
|
||||||
|
y_train = scaler_y.fit_transform(y_train.values.reshape(-1, 1))
|
||||||
|
X_train = pd.get_dummies(X_train, columns=['transmission'])
|
||||||
|
|
||||||
|
model = Sequential([
|
||||||
|
Dense(64, activation='relu'),
|
||||||
|
Dense(32, activation='relu'),
|
||||||
|
Dense(1)
|
||||||
|
])
|
||||||
|
|
||||||
|
model.compile(optimizer='adam', loss='mean_squared_error')
|
||||||
|
model.summary(print_fn=lambda x: _run.info.setdefault('model_summary', []).append(x))
|
||||||
|
history = model.fit(X_train, y_train, epochs=epochs, batch_size=batch_size)
|
||||||
|
|
||||||
|
for epoch, loss in enumerate(history.history['loss']):
|
||||||
|
_run.log_scalar("loss", loss, epoch)
|
||||||
|
|
||||||
|
model.save(model_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
ex.run()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user