touch ups

This commit is contained in:
Kamila Bobkowska 2020-05-10 16:27:46 +00:00
parent 1aeb9bf1fa
commit e8bc0b6722

View File

@ -6,6 +6,9 @@ import shutil
from keras.models import Sequential from keras.models import Sequential
from keras.layers import Conv2D, Flatten, MaxPooling2D, Dense from keras.layers import Conv2D, Flatten, MaxPooling2D, Dense
from keras.preprocessing import image from keras.preprocessing import image
from keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
import random
#dataset from https://www.kaggle.com/asdasdasasdas/garbage-classification #dataset from https://www.kaggle.com/asdasdasasdas/garbage-classification
@ -42,7 +45,7 @@ classifier.add(Dense(activation = "softmax", units = 5))
classifier.compile(optimizer = "adam", loss = "binary_crossentropy", metrics = ["accuracy"]) classifier.compile(optimizer = "adam", loss = "binary_crossentropy", metrics = ["accuracy"])
from keras.preprocessing.image import ImageDataGenerator
train_datagen = ImageDataGenerator( train_datagen = ImageDataGenerator(
rescale=1./255, rescale=1./255,
@ -76,16 +79,16 @@ test_generator = test_datagen.flow_from_directory(
#Teaching the classifier #Teaching the classifier
'''classifier.fit_generator( train_generator, steps_per_epoch = 150, epochs = 25, validation_data = test_generator ) '''classifier.fit_generator( train_generator, steps_per_epoch = 165, epochs = 32, validation_data = test_generator )
classifier.save_weights('model_ver_4.h5')''' classifier.save_weights('model_ver_5.h5')'''
import matplotlib.pyplot as plt
labels = (train_generator.class_indices) labels = (train_generator.class_indices)
labels = dict((value,key) for key,value in labels.items()) labels = dict((value,key) for key,value in labels.items())
classifier.load_weights("model_ver_4.h5") classifier.load_weights("model_ver_5.h5")
import random
def getTrashPhoto(x, type): def getTrashPhoto(x, type):
@ -94,13 +97,13 @@ def getTrashPhoto(x, type):
path = "Garbage classification\\testset\\" + kind path = "Garbage classification\\testset\\" + kind
file = random.choice(os.listdir(path)) file = random.choice(os.listdir(path))
path = "Garbage classification\\testset\\" + kind + "\\" + file path = "Garbage classification\\testset\\" + kind + "\\" + file
gz = image.load_img(path, target_size = (110,110)) var = image.load_img(path, target_size = (110,110))
ti = image.img_to_array(gz) ti = image.img_to_array(var)
ti=np.array(ti)/255.0 ti=np.array(ti)/255.0
ti = np.expand_dims(ti, axis = 0) ti = np.expand_dims(ti, axis = 0)
prediction = classifier.predict(ti) prediction = classifier.predict(ti)
plt.subplot(1, 3, i+1) plt.subplot(1, 3, i+1)
plt.imshow(gz) plt.imshow(var)
plt.title("AI thinks:%s \nReality:\n %s" % (labels[np.argmax(prediction)], file)) plt.title("AI thinks:%s \nReality:\n %s" % (labels[np.argmax(prediction)], file))
plt.show() plt.show()