Update 'train.py'

This commit is contained in:
Tomasz Koszarek 2023-09-30 01:02:27 +02:00
parent 5cfd930e4f
commit d4942a6d30

View File

@ -19,38 +19,35 @@ exint.observers.append(FileStorageObserver('z-s487174-training/master'))
@exint.config
def my_config():
EPOCHS = int(os.environ['EPOCHS'])
@exint.main
def main(EPOCHS, _run):
if EPOCHS == 0:
EPOCHS = 500
train_data_x = pd.read_csv('./train_data.csv')
train_data_x_tmp = train_data_x.copy()
price_train = train_data_x.copy()
price_predict = train_data_x_tmp.pop('Index')
def my_main(EPOCHS, _run):
_run.info["epochs"] = EPOCHS,
normalize = layers.Normalization()
normalize.adapt(price_train)
price_model = tensorflow.keras.Sequential([
train_data_x = pandas.read_csv('./data_train.csv')
_run.info["dataset"] = train_data_x
dat_all = train_data_x.copy()
dat_predict = train_data_x.pop('Price')
normalize.adapt(dat_all)
norm_model = tensorflow.keras.Sequential([
normalize,
layers.Dense(64),
layers.Dense(1)
])
price_model.compile(
norm_model.compile(
loss=tensorflow.keras.losses.MeanSquaredError(),
optimizer=tensorflow.keras.optimizers.Adam())
price_model.fit(price_train, price_predict, epochs=EPOCHS)
norm_model.fit(dat_all, dat_predict, epochs=EPOCHS)
counter = 0
while counter < 20:
counter+=1
value = counter * random.randint(5, 5000)
_run.log_scalar("training.accuracy", value * 2)
norm_model.save('test')
price_model.save('model')
if __name__ == "__main__":
EPOCHS = int(os.environ['EPOCHS'])
main(EPOCHS)
exint.run()
exint.add_artifact('saved_model.pb')
exint.add_artifact('test/saved_model.pb')
exint.add_source_file('./ium_z487174/train.py')