SZI-Smieciarka/uczenie_adamO.py
2020-05-19 21:14:55 +02:00

152 lines
5.0 KiB
Python

import h5py
import numpy as np
import os
import glob
import cv2
import warnings
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
import matplotlib.pyplot as plt
import mahotas
import random
from math import ceil
from io import StringIO
from sklearn.tree import export_graphviz
import pydotplus
warnings.filterwarnings('ignore')
rozmiar_zbioru_testowego = 0.20
katalog_uczacy = "resources\\smieci_stare"
katalog_testujacy = "resources\\smieci w kontenerach"
h5_parametry = 'parametry_zdjec.h5'
h5_etykiety = 'etykiety.h5'
rozmiar_zdj = tuple((500, 500))
def wyznaczHuMomenty(zdj):
zdj = cv2.cvtColor(zdj, cv2.COLOR_BGR2GRAY)
momenty = cv2.HuMoments(cv2.moments(zdj)).flatten()
return momenty
def wyznaczHistogram(zdj, mask=None):
zdj = cv2.cvtColor(zdj, cv2.COLOR_BGR2HSV)
hist = cv2.calcHist([zdj], [0, 1, 2], mask, [8, 8, 8],
[0, 256, 0, 256, 0, 256])
cv2.normalize(hist, hist)
return hist.flatten()
def wyznaczHaralick(zdj):
szare_zdj = cv2.cvtColor(zdj, cv2.COLOR_BGR2GRAY)
haralick = mahotas.features.haralick(szare_zdj).mean(axis=0)
return haralick
def rozpocznijUczenie():
klasy = os.listdir(katalog_uczacy)
klasy.sort()
h5f_parametry = h5py.File(h5_parametry, 'r')
h5f_etykiety = h5py.File(h5_etykiety, 'r')
dane = h5f_parametry['dataset_1']
etykiety = h5f_etykiety['dataset_1']
dane = np.array(dane)
etykiety = np.array(etykiety)
h5f_parametry.close()
h5f_etykiety.close()
(uczenieDane, testowanieDane, uczenieEtykiety, testowanieEtykiety) = train_test_split(np.array(dane),
np.array(
etykiety),
test_size=rozmiar_zbioru_testowego)
rfc = RandomForestClassifier(max_depth=15, n_jobs=4, random_state=1)
rfc.fit(uczenieDane, uczenieEtykiety)
print("uzyskana skutecznosc: ", rfc.score(
testowanieDane, testowanieEtykiety))
# tworzenie grafu
# dot_data = StringIO()
# print(rfc.estimators_)
# estimator = rfc.estimators_[5]
# export_graphviz(estimator, out_file=dot_data,
# feature_names=dane[1],
# rounded=True, proportion=False,
# precision=2, filled=True,
# special_characters=True, class_names=['glass', 'metal', 'paper', 'plastic'])
# graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
# graph.write_png('graph.png')
return rfc
def przewidz(zdjecie, rfc):
klasy = os.listdir(katalog_uczacy)
klasy.sort()
zdj = cv2.imread(zdjecie)
zdj = cv2.resize(zdj, rozmiar_zdj)
# wyznaczanie parametrow zdjecia
momenty = wyznaczHuMomenty(zdj)
haralick = wyznaczHaralick(zdj)
histogram = wyznaczHistogram(zdj)
# ustaw poziomo, jeden za drugim
wiersz = np.hstack([momenty, histogram, haralick])
wiersz = wiersz.reshape(1, -1) # zmniejsz wymiar z 2 do 1
przewidywany_typ = rfc.predict(wiersz)[0] # zwraca wartosc 0,1,2,3
return klasy[przewidywany_typ] # zwraca glass,metal,paper,plastic
def wyswietlZdjecia(rfc):
if rfc is None:
return
path = os.getcwd()
klasy = os.listdir(katalog_testujacy)
klasy.sort()
wszystkie_pliki = []
for dir in os.listdir(katalog_testujacy):
os.chdir(path + "\\" + katalog_testujacy + "\\" + dir)
pliki = glob.glob('*.jpg')
for i in range(len(pliki)):
pliki[i] = dir + "\\" + pliki[i]
wszystkie_pliki.append(pliki)
os.chdir(path)
wszystkie_pliki = sum(wszystkie_pliki, [])
rozmiar = ceil(0.25 * len(wszystkie_pliki))
wybrane_zdjecia = random.sample(wszystkie_pliki, k=rozmiar)
print("ilosc wybranych zdjec: ", len(wybrane_zdjecia))
for i in wybrane_zdjecia:
zdjecie = cv2.imread(path + "\\" + katalog_testujacy + "\\" + i)
zdjecie = cv2.resize(zdjecie, rozmiar_zdj)
# wyznaczanie parametrow zdjecia
momenty = wyznaczHuMomenty(zdjecie)
haralick = wyznaczHaralick(zdjecie)
histogram = wyznaczHistogram(zdjecie)
# ustaw poziomo, jeden za drugim
wiersz = np.hstack([momenty, histogram, haralick])
wiersz = wiersz.reshape(1, -1) # zmniejsz wymiar z 2 do 1
przewidywany = rfc.predict(wiersz)[0]
prawdopodobienstwo = rfc.predict_proba(wiersz)
cv2.putText(zdjecie, klasy[przewidywany], (3, 30),
cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 0, 0), thickness=3)
cv2.putText(zdjecie, str(prawdopodobienstwo), (3, 60),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), thickness=1)
cv2.putText(zdjecie, i, (3, 460), cv2.FONT_HERSHEY_SIMPLEX,
1, (255, 0, 0), thickness=2)
plt.imshow(cv2.cvtColor(zdjecie, cv2.COLOR_BGR2RGB))
plt.show()