Update 'simple_regression_lab7.py'

This commit is contained in:
Kacper Dudzic 2022-05-05 22:55:45 +02:00
parent debf8a3d66
commit f3f9e656e4
1 changed files with 3 additions and 5 deletions

View File

@ -3,8 +3,7 @@ from keras import layers
from keras.models import save_model from keras.models import save_model
import pandas as pd import pandas as pd
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as pltsave_git_info=False
import sys
from sacred import Experiment from sacred import Experiment
from sacred.observers import FileStorageObserver from sacred.observers import FileStorageObserver
from sacred.observers import MongoObserver from sacred.observers import MongoObserver
@ -23,13 +22,12 @@ ex.observers.append(MongoObserver(url='mongodb://mongo_user:mongo_password_IUM_2
def config(): def config():
units = 1 units = 1
learning_rate = 0.1 learning_rate = 0.1
epochs = 100
# Reszta kodu wrzucona do udekorowanej funkcji train do wywołania przez Sacred, żeby coś było capture'owane # Reszta kodu wrzucona do udekorowanej funkcji train do wywołania przez Sacred, żeby coś było capture'owane
@ex.capture @ex.capture
def train(units, learning_rate, _run): def train(units, learning_rate, _run):
# Pobranie przykładowego argumentu trenowania z poziomu Jenkinsa
EPOCHS_NUM = int(sys.argv[1])
# Wczytanie danych # Wczytanie danych
data_train = pd.read_csv('lego_sets_clean_train.csv') data_train = pd.read_csv('lego_sets_clean_train.csv')
@ -61,7 +59,7 @@ def train(units, learning_rate, _run):
history = model.fit( history = model.fit(
train_piece_counts, train_piece_counts,
train_prices, train_prices,
epochs=EPOCHS_NUM, epochs=epochs,
verbose=0, verbose=0,
validation_split=0.2 validation_split=0.2
) )