From 32179be076861432970b481c38c46f1d210a31a8 Mon Sep 17 00:00:00 2001 From: Aniela Walczak Date: Thu, 26 May 2022 16:05:08 +0200 Subject: [PATCH] Zaktualizuj 'neural_network.py' poprawka sieci neuronowych --- neural_network.py | 63 ++++++++++++++++++++++------------------------- 1 file changed, 30 insertions(+), 33 deletions(-) diff --git a/neural_network.py b/neural_network.py index 1247f86..04e4860 100644 --- a/neural_network.py +++ b/neural_network.py @@ -3,44 +3,41 @@ import cv2 import numpy as np import matplotlib.pyplot as plt import tensorflow as tf +import random + +mnist = tf.keras.datasets.mnist +(x_train, y_train), (x_test, y_test) = mnist.load_data() + +x_train = tf.keras.utils.normalize(x_train, axis=1) +x_test = tf.keras.utils.normalize(x_test, axis=1) + +model = tf.keras.models.Sequential() +model.add(tf.keras.layers.Flatten(input_shape=(28, 28))) +model.add(tf.keras.layers.Dense(128, activation='relu')) +model.add(tf.keras.layers.Dense(128, activation='relu')) +model.add(tf.keras.layers.Dense(10, activation='softmax')) + +model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) + +model.fit(x_train, y_train, epochs=3) +model.save('handwritten.model') + +model = tf.keras.models.load_model('handwritten.model') def recognition(): - mnist = tf.keras.datasets.mnist - (x_train, y_train), (x_test, y_test) = mnist.load_data() - - x_train = tf.keras.utils.normalize(x_train, axis=1) - x_test = tf.keras.utils.normalize(x_test, axis=1) - - model = tf.keras.models.Sequential() - model.add(tf.keras.layers.Flatten(input_shape=(28,28))) - model.add(tf.keras.layers.Dense(128, activation='relu')) - model.add(tf.keras.layers.Dense(128, activation='relu')) - model.add(tf.keras.layers.Dense(10, activation='softmax')) - - model.compile(optimizer ='adam', loss = 'sparse_categorical_crossentropy', metrics = ['accuracy']) - - model.fit(x_train, y_train, epochs = 3) - model.save('handwritten.model') - - model = tf.keras.models.load_model('handwritten.model') - - image_number = 1 - while os.path.isfile(f"digits/digit{image_number}.png"): - try: - img = cv2.imread(f"digits/digit{image_number}.png")[:,:,0] - img = np.invert(np.array([img])) - prediction = model.predict(img) - print(f"This digit is probably a {np.argmax(prediction)}") - plt.imshow(img[0], cmap = plt.cm.binary) - plt.show() - except: - print("Error!") - finally: - image_number +=1 + image_number = random.randint(1, 9) + try: + img = cv2.imread(f"digits/digit{image_number}.png")[:,:,0] + img = np.invert(np.array([img])) + prediction = model.predict(img) + print(f"This digit is probably a {np.argmax(prediction)}") + plt.imshow(img[0], cmap = plt.cm.binary) + plt.show() + except: + print("Error!") loss, accuracy = model.evaluate(x_test, y_test) print(loss) print(accuracy) - recognition() \ No newline at end of file