Update 'train.py'
This commit is contained in:
parent
5cfd930e4f
commit
d4942a6d30
39
train.py
39
train.py
@ -19,38 +19,35 @@ exint.observers.append(FileStorageObserver('z-s487174-training/master'))
|
||||
@exint.config
|
||||
def my_config():
|
||||
EPOCHS = int(os.environ['EPOCHS'])
|
||||
|
||||
|
||||
@exint.main
|
||||
def main(EPOCHS, _run):
|
||||
if EPOCHS == 0:
|
||||
EPOCHS = 500
|
||||
train_data_x = pd.read_csv('./train_data.csv')
|
||||
train_data_x_tmp = train_data_x.copy()
|
||||
|
||||
price_train = train_data_x.copy()
|
||||
price_predict = train_data_x_tmp.pop('Index')
|
||||
def my_main(EPOCHS, _run):
|
||||
_run.info["epochs"] = EPOCHS,
|
||||
normalize = layers.Normalization()
|
||||
normalize.adapt(price_train)
|
||||
|
||||
price_model = tensorflow.keras.Sequential([
|
||||
train_data_x = pandas.read_csv('./data_train.csv')
|
||||
_run.info["dataset"] = train_data_x
|
||||
dat_all = train_data_x.copy()
|
||||
dat_predict = train_data_x.pop('Price')
|
||||
normalize.adapt(dat_all)
|
||||
norm_model = tensorflow.keras.Sequential([
|
||||
normalize,
|
||||
layers.Dense(64),
|
||||
layers.Dense(1)
|
||||
])
|
||||
|
||||
price_model.compile(
|
||||
norm_model.compile(
|
||||
loss=tensorflow.keras.losses.MeanSquaredError(),
|
||||
optimizer=tensorflow.keras.optimizers.Adam())
|
||||
|
||||
price_model.fit(price_train, price_predict, epochs=EPOCHS)
|
||||
norm_model.fit(dat_all, dat_predict, epochs=EPOCHS)
|
||||
counter = 0
|
||||
while counter < 20:
|
||||
counter+=1
|
||||
value = counter * random.randint(5, 5000)
|
||||
_run.log_scalar("training.accuracy", value * 2)
|
||||
norm_model.save('test')
|
||||
|
||||
price_model.save('model')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
EPOCHS = int(os.environ['EPOCHS'])
|
||||
main(EPOCHS)
|
||||
|
||||
exint.run()
|
||||
exint.add_artifact('saved_model.pb')
|
||||
exint.add_artifact('test/saved_model.pb')
|
||||
exint.add_source_file('./ium_z487174/train.py')
|
Loading…
Reference in New Issue
Block a user