import h5py import numpy as np import os import glob import cv2 import warnings from sklearn.model_selection import train_test_split from sklearn.ensemble import RandomForestClassifier import matplotlib.pyplot as plt import mahotas import random from math import ceil from io import StringIO from sklearn.tree import export_graphviz import pydotplus warnings.filterwarnings('ignore') rozmiar_zbioru_testowego = 0.20 katalog_uczacy = "resources\\smieci_stare" katalog_testujacy = "resources\\smieci w kontenerach" h5_parametry = 'parametry_zdjec.h5' h5_etykiety = 'etykiety.h5' rozmiar_zdj = tuple((500, 500)) def wyznaczHuMomenty(zdj): zdj = cv2.cvtColor(zdj, cv2.COLOR_BGR2GRAY) momenty = cv2.HuMoments(cv2.moments(zdj)).flatten() return momenty def wyznaczHistogram(zdj, mask=None): zdj = cv2.cvtColor(zdj, cv2.COLOR_BGR2HSV) hist = cv2.calcHist([zdj], [0, 1, 2], mask, [8, 8, 8], [0, 256, 0, 256, 0, 256]) cv2.normalize(hist, hist) return hist.flatten() def wyznaczHaralick(zdj): szare_zdj = cv2.cvtColor(zdj, cv2.COLOR_BGR2GRAY) haralick = mahotas.features.haralick(szare_zdj).mean(axis=0) return haralick def rozpocznijUczenie(): klasy = os.listdir(katalog_uczacy) klasy.sort() h5f_parametry = h5py.File(h5_parametry, 'r') h5f_etykiety = h5py.File(h5_etykiety, 'r') dane = h5f_parametry['dataset_1'] etykiety = h5f_etykiety['dataset_1'] dane = np.array(dane) etykiety = np.array(etykiety) h5f_parametry.close() h5f_etykiety.close() (uczenieDane, testowanieDane, uczenieEtykiety, testowanieEtykiety) = train_test_split(np.array(dane), np.array(etykiety), test_size=rozmiar_zbioru_testowego) rfc = RandomForestClassifier(max_depth=15, n_jobs=4, random_state=1) rfc.fit(uczenieDane, uczenieEtykiety) print("uzyskana skutecznosc: ", rfc.score(testowanieDane, testowanieEtykiety)) # tworzenie grafu # dot_data = StringIO() # print(rfc.estimators_) # estimator = rfc.estimators_[5] # export_graphviz(estimator, out_file=dot_data, # feature_names=dane[1], # rounded=True, proportion=False, # precision=2, filled=True, # special_characters=True, class_names=['glass', 'metal', 'paper', 'plastic']) # graph = pydotplus.graph_from_dot_data(dot_data.getvalue()) # graph.write_png('graph.png') return rfc def przewidz(zdjecie, rfc): klasy = os.listdir(katalog_uczacy) klasy.sort() zdj = cv2.imread(zdjecie) zdj = cv2.resize(zdj, rozmiar_zdj) # wyznaczanie parametrow zdjecia momenty = wyznaczHuMomenty(zdj) haralick = wyznaczHaralick(zdj) histogram = wyznaczHistogram(zdj) wiersz = np.hstack([momenty, histogram, haralick]) # ustaw poziomo, jeden za drugim wiersz = wiersz.reshape(1, -1) # zmniejsz wymiar z 2 do 1 przewidywany_typ = rfc.predict(wiersz)[0] # zwraca wartosc 0,1,2,3 return klasy[przewidywany_typ] # zwraca glass,metal,paper,plastic def wyswietlZdjecia(rfc): if rfc is None: return path = os.getcwd() klasy = os.listdir(katalog_testujacy) klasy.sort() wszystkie_pliki = [] for dir in os.listdir(katalog_testujacy): os.chdir(path + "\\" + katalog_testujacy + "\\" + dir) pliki = glob.glob('*.jpg') for i in range(len(pliki)): pliki[i] = dir + "\\" + pliki[i] wszystkie_pliki.append(pliki) os.chdir(path) wszystkie_pliki = sum(wszystkie_pliki, []) rozmiar = ceil(0.25 * len(wszystkie_pliki)) wybrane_zdjecia = random.sample(wszystkie_pliki, k=rozmiar) print("ilosc wybranych zdjec: ",len(wybrane_zdjecia)) for i in wybrane_zdjecia: zdjecie = cv2.imread(path + "\\" + katalog_testujacy + "\\" + i) zdjecie = cv2.resize(zdjecie, rozmiar_zdj) # wyznaczanie parametrow zdjecia momenty = wyznaczHuMomenty(zdjecie) haralick = wyznaczHaralick(zdjecie) histogram = wyznaczHistogram(zdjecie) wiersz = np.hstack([momenty, histogram, haralick]) # ustaw poziomo, jeden za drugim wiersz = wiersz.reshape(1, -1) # zmniejsz wymiar z 2 do 1 przewidywany = rfc.predict(wiersz)[0] prawdopodobienstwo = rfc.predict_proba(wiersz) cv2.putText(zdjecie, klasy[przewidywany], (3, 30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 0, 0), thickness=3) cv2.putText(zdjecie, str(prawdopodobienstwo), (3, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), thickness=1) cv2.putText(zdjecie, i, (3, 460), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), thickness=2) plt.imshow(cv2.cvtColor(zdjecie, cv2.COLOR_BGR2RGB)) plt.show()