Zaktualizuj 'neural_network.py'

poprawka sieci neuronowych
This commit is contained in:
Aniela Walczak 2022-05-26 16:05:08 +02:00
parent c959093e38
commit 32179be076

View File

@ -3,8 +3,8 @@ import cv2
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import tensorflow as tf import tensorflow as tf
import random
def recognition():
mnist = tf.keras.datasets.mnist mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data() (x_train, y_train), (x_test, y_test) = mnist.load_data()
@ -24,8 +24,8 @@ def recognition():
model = tf.keras.models.load_model('handwritten.model') model = tf.keras.models.load_model('handwritten.model')
image_number = 1 def recognition():
while os.path.isfile(f"digits/digit{image_number}.png"): image_number = random.randint(1, 9)
try: try:
img = cv2.imread(f"digits/digit{image_number}.png")[:,:,0] img = cv2.imread(f"digits/digit{image_number}.png")[:,:,0]
img = np.invert(np.array([img])) img = np.invert(np.array([img]))
@ -35,12 +35,9 @@ def recognition():
plt.show() plt.show()
except: except:
print("Error!") print("Error!")
finally:
image_number +=1
loss, accuracy = model.evaluate(x_test, y_test) loss, accuracy = model.evaluate(x_test, y_test)
print(loss) print(loss)
print(accuracy) print(accuracy)
recognition() recognition()