130 lines
4.2 KiB
Python
130 lines
4.2 KiB
Python
import h5py
|
|
import numpy as np
|
|
import os
|
|
import glob
|
|
import cv2
|
|
import warnings
|
|
from sklearn.model_selection import train_test_split
|
|
from sklearn.ensemble import RandomForestClassifier
|
|
import matplotlib.pyplot as plt
|
|
import mahotas
|
|
import random
|
|
from math import ceil
|
|
|
|
warnings.filterwarnings('ignore')
|
|
|
|
rozmiar_zbioru_testowego = 0.20
|
|
katalog_uczacy = "resources\\smieci_stare"
|
|
katalog_testujacy = "resources\\smieci w kontenerach"
|
|
h5_parametry = 'parametry_zdjec.h5'
|
|
h5_etykiety = 'etykiety.h5'
|
|
rozmiar_zdj = tuple((500, 500))
|
|
|
|
|
|
def wyznaczHuMomenty(zdj):
|
|
zdj = cv2.cvtColor(zdj, cv2.COLOR_BGR2GRAY)
|
|
momenty = cv2.HuMoments(cv2.moments(zdj)).flatten()
|
|
return momenty
|
|
|
|
|
|
def wyznaczHistogram(zdj, mask=None):
|
|
zdj = cv2.cvtColor(zdj, cv2.COLOR_BGR2HSV)
|
|
hist = cv2.calcHist([zdj], [0, 1, 2], mask, [8, 8, 8], [0, 256, 0, 256, 0, 256])
|
|
cv2.normalize(hist, hist)
|
|
return hist.flatten()
|
|
|
|
|
|
def wyznaczHaralick(zdj):
|
|
szare_zdj = cv2.cvtColor(zdj, cv2.COLOR_BGR2GRAY)
|
|
haralick = mahotas.features.haralick(szare_zdj).mean(axis=0)
|
|
return haralick
|
|
|
|
|
|
def rozpocznijUczenie():
|
|
klasy = os.listdir(katalog_uczacy)
|
|
klasy.sort()
|
|
|
|
h5f_parametry = h5py.File(h5_parametry, 'r')
|
|
h5f_etykiety = h5py.File(h5_etykiety, 'r')
|
|
|
|
dane = h5f_parametry['dataset_1']
|
|
etykiety = h5f_etykiety['dataset_1']
|
|
|
|
dane = np.array(dane)
|
|
etykiety = np.array(etykiety)
|
|
|
|
h5f_parametry.close()
|
|
h5f_etykiety.close()
|
|
|
|
(uczenieDane, testowanieDane, uczenieEtykiety, testowanieEtykiety) = train_test_split(np.array(dane),
|
|
np.array(etykiety),
|
|
test_size=rozmiar_zbioru_testowego)
|
|
|
|
rfc = RandomForestClassifier(max_depth=15, n_jobs=4, random_state=1)
|
|
rfc.fit(uczenieDane, uczenieEtykiety)
|
|
print("uzyskana skutecznosc: ", rfc.score(testowanieDane, testowanieEtykiety))
|
|
return rfc
|
|
|
|
|
|
def przewidz(zdjecie, rfc):
|
|
klasy = os.listdir(katalog_uczacy)
|
|
klasy.sort()
|
|
zdj = cv2.imread(zdjecie)
|
|
zdj = cv2.resize(zdj, rozmiar_zdj)
|
|
|
|
# wyznaczanie parametrow zdjecia
|
|
momenty = wyznaczHuMomenty(zdj)
|
|
haralick = wyznaczHaralick(zdj)
|
|
histogram = wyznaczHistogram(zdj)
|
|
|
|
wiersz = np.hstack([momenty, histogram, haralick]) # ustaw poziomo, jeden za drugim
|
|
wiersz = wiersz.reshape(1, -1) # zmniejsz wymiar z 2 do 1
|
|
przewidywany_typ = rfc.predict(wiersz)[0] # zwraca wartosc 0,1,2,3
|
|
return klasy[przewidywany_typ] # zwraca glass,metal,paper,plastic
|
|
|
|
|
|
def wyswietlZdjecia(rfc):
|
|
if rfc is None:
|
|
return
|
|
|
|
path = os.getcwd()
|
|
|
|
klasy = os.listdir(katalog_testujacy)
|
|
klasy.sort()
|
|
wszystkie_pliki = []
|
|
|
|
for dir in os.listdir(katalog_testujacy):
|
|
os.chdir(path + "\\" + katalog_testujacy + "\\" + dir)
|
|
pliki = glob.glob('*.jpg')
|
|
for i in range(len(pliki)):
|
|
pliki[i] = dir + "\\" + pliki[i]
|
|
wszystkie_pliki.append(pliki)
|
|
os.chdir(path)
|
|
|
|
wszystkie_pliki = sum(wszystkie_pliki, [])
|
|
rozmiar = ceil(0.25 * len(wszystkie_pliki))
|
|
wybrane_zdjecia = random.sample(wszystkie_pliki, k=rozmiar)
|
|
print("ilosc wybranych zdjec: ",len(wybrane_zdjecia))
|
|
|
|
for i in wybrane_zdjecia:
|
|
zdjecie = cv2.imread(path + "\\" + katalog_testujacy + "\\" + i)
|
|
zdjecie = cv2.resize(zdjecie, rozmiar_zdj)
|
|
|
|
# wyznaczanie parametrow zdjecia
|
|
momenty = wyznaczHuMomenty(zdjecie)
|
|
haralick = wyznaczHaralick(zdjecie)
|
|
histogram = wyznaczHistogram(zdjecie)
|
|
|
|
wiersz = np.hstack([momenty, histogram, haralick]) # ustaw poziomo, jeden za drugim
|
|
wiersz = wiersz.reshape(1, -1) # zmniejsz wymiar z 2 do 1
|
|
|
|
przewidywany = rfc.predict(wiersz)[0]
|
|
prawdopodobienstwo = rfc.predict_proba(wiersz)
|
|
|
|
cv2.putText(zdjecie, klasy[przewidywany], (3, 30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 0, 0), thickness=3)
|
|
cv2.putText(zdjecie, str(prawdopodobienstwo), (3, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), thickness=1)
|
|
cv2.putText(zdjecie, i, (3, 460), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), thickness=2)
|
|
|
|
plt.imshow(cv2.cvtColor(zdjecie, cv2.COLOR_BGR2RGB))
|
|
plt.show()
|