ium_470607/lab5/train/train.py

37 lines
991 B
Python
Raw Normal View History

2021-05-02 22:01:32 +02:00
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers
def onezero(label: str) -> int:
    """Encode a stability label as a binary class index.

    Args:
        label: Label value to encode; it is compared against the literal
            string ``'unstable'``.

    Returns:
        0 when ``label`` equals ``'unstable'``, otherwise 1.
    """
    # Every value other than 'unstable' (including unexpected ones) maps to 1.
    return 0 if label == 'unstable' else 1
# Read the train/test splits from the current working directory.
X_train, X_test = pd.read_csv('train.csv'), pd.read_csv('test.csv')

# Detach the 'stabf' label column; the remaining columns stay as features.
Y_train, Y_test = X_train.pop('stabf'), X_test.pop('stabf')

# Integer class per row, produced by onezero().
Y_train_one_zero = list(map(onezero, Y_train))
Y_test_one_zero = list(map(onezero, Y_test))

# One-hot rows are selected from a 2x2 identity matrix by class index.
Y_train_onehot = np.eye(2)[Y_train_one_zero]
Y_test_onehot = np.eye(2)[Y_test_one_zero]
# Feed-forward classifier: two hidden layers followed by a two-way softmax
# whose outputs line up with the one-hot targets.
# Keras Dense layers apply no activation unless one is given, so the hidden
# layers need an explicit non-linearity; otherwise they compose into a single
# linear map and the extra layers add no modelling capacity.
model = tf.keras.Sequential([
    # Input width follows the feature columns remaining in X_train.
    layers.Input(shape=(X_train.shape[1],)),
    layers.Dense(32, activation='relu'),
    layers.Dense(16, activation='relu'),
    layers.Dense(2, activation='softmax'),
])
# Training configuration: binary cross-entropy loss, Adam optimizer with its
# default settings, and binary accuracy reported as the metric.
model.compile(
    loss=tf.losses.BinaryCrossentropy(),
    optimizer=tf.optimizers.Adam(),
    metrics=[tf.keras.metrics.BinaryAccuracy()],
)

# Features are passed as a float32 tensor; targets are the one-hot labels.
train_features = tf.convert_to_tensor(X_train, dtype=np.float32)
history = model.fit(train_features, Y_train_onehot, epochs=5)

# Persist the trained model to an HDF5 file in the working directory.
model.save('grid_stability.h5')