add parameter fix
This commit is contained in:
parent
303bfb0193
commit
21fc9d3677
@ -23,4 +23,4 @@ Y_test.to_csv('Y_test.csv', index=False)
|
|||||||
Y_train.to_csv('Y_train.csv', index=False)
|
Y_train.to_csv('Y_train.csv', index=False)
|
||||||
Y_dev.to_csv('Y_dev.csv', index=False)
|
Y_dev.to_csv('Y_dev.csv', index=False)
|
||||||
|
|
||||||
train.main()
|
train.main(0)
|
||||||
|
8
train.py
8
train.py
@ -5,8 +5,9 @@ import tensorflow
|
|||||||
from keras.applications.densenet import layers
|
from keras.applications.densenet import layers
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main(EPOCHS):
|
||||||
EPOCHS = int(os.environ['EPOCHS'])
|
if EPOCHS == 0:
|
||||||
|
EPOCHS = 500
|
||||||
train_data_x = pd.read_csv('./X_train.csv')
|
train_data_x = pd.read_csv('./X_train.csv')
|
||||||
|
|
||||||
adults_train = train_data_x.copy()
|
adults_train = train_data_x.copy()
|
||||||
@ -30,4 +31,5 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
EPOCHS = int(os.environ['EPOCHS'])
|
||||||
|
main(EPOCHS)
|
||||||
|
Loading…
Reference in New Issue
Block a user