Update 'simple_regression_lab7.py'

This commit is contained in:
Kacper Dudzic 2022-05-05 22:55:45 +02:00
parent debf8a3d66
commit f3f9e656e4
1 changed files with 3 additions and 5 deletions

View File

@ -3,8 +3,7 @@ from keras import layers
from keras.models import save_model
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import sys
import matplotlib.pyplot as pltsave_git_info=False
from sacred import Experiment
from sacred.observers import FileStorageObserver
from sacred.observers import MongoObserver
@ -23,13 +22,12 @@ ex.observers.append(MongoObserver(url='mongodb://mongo_user:mongo_password_IUM_2
def config():
units = 1
learning_rate = 0.1
epochs = 100
# Reszta kodu wrzucona do udekorowanej funkcji train do wywołania przez Sacred, żeby coś było capture'owane
@ex.capture
def train(units, learning_rate, _run):
# Pobranie przykładowego argumentu trenowania z poziomu Jenkinsa
EPOCHS_NUM = int(sys.argv[1])
# Wczytanie danych
data_train = pd.read_csv('lego_sets_clean_train.csv')
@ -61,7 +59,7 @@ def train(units, learning_rate, _run):
history = model.fit(
train_piece_counts,
train_prices,
epochs=EPOCHS_NUM,
epochs=epochs,
verbose=0,
validation_split=0.2
)