"""Train and evaluate a small Keras regression model predicting IMDb ratings.

The experiment is tracked with Sacred (file-storage and MongoDB observers).
Run it as a script, e.g. ``python training.py`` or
``python training.py with epochs=50 verbose=1``; Sacred's ``automain``
parses the command line and executes the run exactly once.
"""
import os
import sys
import os.path

import numpy as np
import pandas as pd
import tensorflow as tf
from sacred import Experiment
from sacred.observers import FileStorageObserver, MongoObserver
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers.experimental import preprocessing

# Feature columns fed to the model; shared by training and evaluation so the
# two can never drift apart.
FEATURE_COLUMNS = ['Year', 'Runtime', 'Netflix']
# Regression target column.
TARGET_COLUMN = 'IMDb'
TRAIN_PATH = "./MoviesOnStreamingPlatforms_updated.train"
TEST_PATH = "./MoviesOnStreamingPlatforms_updated.test"
MODEL_PATH = 'linear_regression.h5'

# NOTE(review): credentials should not be committed to source control. The URL
# can now be supplied through the MONGO_URL environment variable; the
# hard-coded fallback is kept only for backward compatibility and should be
# removed once the environment provides it.
MONGO_URL = os.environ.get(
    "MONGO_URL",
    'mongodb://mongo_user:mongo_password_IUM_2021@172.17.0.1:27017',
)

exp = Experiment("s434704", interactive=False, save_git_info=False)
exp.observers.append(FileStorageObserver("sacred_file"))
exp.observers.append(MongoObserver(url=MONGO_URL, db_name="sacred"))
# Must be registered before the run is created: Sacred captures the source
# files when a run starts, so adding them after the run has no effect.
exp.add_source_file("./training.py")


@exp.config
def my_config():
    verbose = 0  # Keras verbosity passed to model.fit
    epochs = 100  # number of training epochs


def _load_xy(path):
    """Read a CSV file and return (features, target) as tensors.

    Features are the FEATURE_COLUMNS columns; the target is a one-column
    frame built from TARGET_COLUMN.
    """
    data = pd.read_csv(path)
    features = tf.convert_to_tensor(data[FEATURE_COLUMNS])
    target = tf.convert_to_tensor(data[[TARGET_COLUMN]])
    return features, target


@exp.capture
def training(verbose, epochs, _log, _run):
    """Train the regression model, save it, and log its RMSE on the test set.

    Args:
        verbose: Keras verbosity level for ``model.fit``.
        epochs: number of training epochs.
        _log: logger injected by Sacred.
        _run: the current Sacred run, used to attach the saved model and to
            log the evaluation metric.
    """
    pd.set_option("display.max_columns", None)

    # Load training data
    train_X, train_Y = _load_xy(TRAIN_PATH)

    # Build the model. The normalizer is adapted on the training features
    # before the model is assembled.
    normalizer = preprocessing.Normalization(input_shape=[len(FEATURE_COLUMNS)])
    normalizer.adapt(train_X)

    model = keras.Sequential([
        keras.Input(shape=(len(FEATURE_COLUMNS),)),
        normalizer,
        layers.Dense(30, activation='relu'),
        layers.Dense(10, activation='relu'),
        layers.Dense(25, activation='relu'),
        layers.Dense(1),
    ])

    model.compile(loss='mean_absolute_error',
                  optimizer=tf.keras.optimizers.Adam(0.001),
                  metrics=[tf.keras.metrics.RootMeanSquaredError()])

    params = f"Verbose: {verbose}, Epochs: {epochs}"
    _log.info(params)

    model.fit(train_X, train_Y, verbose=verbose, epochs=epochs)
    model.save(MODEL_PATH)
    # Attach the model while the run is still active so every observer
    # records it as part of this run.
    _run.add_artifact(MODEL_PATH)

    # Evaluation
    test_X, test_Y = _load_xy(TEST_PATH)
    # evaluate() returns [loss, *metrics]; index 1 is the RMSE metric.
    scores = model.evaluate(x=test_X, y=test_Y)
    # NOTE: the value is measured on the test set despite the "training."
    # prefix; the key is left unchanged so existing dashboards keep working.
    _run.log_scalar("training.RMSE", scores[1])


@exp.automain
def run(verbose, epochs):
    """Sacred entry point; automain executes it once when run as a script.

    No explicit ``exp.run()`` follows: that would train the model a second
    time (and would also run on import).
    """
    training()