1.1 MiB
Wykorzystanie indeksu wegetacji NDVI do klasyfikacji gatunków drzew
Rafał Borkowski
Pochodzenie danych
Spis treści:
- Pochodzenie źródła danych
- Etap 0 - preprocesing i przygotowanie
- Obliczanie indeksu NDVI
- Przetwarzanie próbek na dane
- Przykładowe dane
- Import bibliotek i funkcji
- Wartości histogramu
Pochodzenie źródła danych
Źródłem danych są lotnicze zdjęcia spektralne wykonane w barwach widzialnych (RGB) oraz bliskiej podczerwieni (NIR) dostępne na Geoportalu krajowym.
Obszarem zainteresowań objęty jest zwarty obszar leśny Nadleśnictwa Doświadczalnego Zielonka 52º31' N, 17º4' E.
Etap 0 - preprocesing i przygotowanie
Po pobraniu zdjęcia lotnicze w formacie .tif zostały wczytane do programu QGIS będącym otwartym oprogramowaniem geoinformacyjnym. Korzystając z danych dotyczących drzewostanów, dostępnych na stronie Banku Danych o Lasach wytypowano zwarte, jednogatunkowe drzewostany sosny zwyczajnej (_Pinus sylvestris) oraz dębu (Quercus spp). Wybrano drzewostany, których wiek jest zbliżony do średniego wieku dla danego gatunku wewnątrz kompleksu leśnego.
Dla dębu:
- 153 lata,
- 158 lat,
- 159 lat.
Dla sosny:
- 54 lata,
- 55 lat,
- 58 lat.
Kolejnym krokiem było wykorzystanie wtyczki QRectangleCreator do generowania próbek drzew. Wtyczka ta umożliwia definiowanie predefiniowanych kształtów poligonów, w naszym przypadku kwadratów o wymiarach 2x2 metrów. Na podstawie oceny wizualnej, wyznaczono lokalizacje tych poligonów, które odpowiadały koronom drzew na zdjęciach spektralnych. Ze względu na rozmiar korony w przypadku dębu możliwe było wyznaczenie więcej, niż jednego poligonu z jednej korony drzewa. Ważne było pozyskanie próbki indeksu wegetacji NDVI.
Rycina 1 prezentuje koncept pozyskiwania danych.
from IPython.display import display
from PIL import Image
def show_image(image_path):
image = Image.open(image_path)
display(image)
image_path = "C:/python_zaliczenie/jpg/fig1.jpg"
show_image(image_path)
Na powyższej ilustracji widzimy fragment zdjęcia spektralnego w kompozycji barw umownych (CIR). Często wykorzystywaną techniką jest CIR (_Color Infrared), gdzie roślinność przedstawiona jest w czerwieni. Dzięki zastosowaniu podczerwieni widoczne jest zróżnicowanie gatunkowe. Niebieskim prostokątem oznaczona została przykładowa korona dębu, zielonym prostokątem przykładowe korona sosny. Różowe prostokąty to poligony próbkujące, w tym przypadku fragmenty koron dębów.
Obliczanie indeksu NDVI
Założeniem projektu była hipoteza twierdząca, że na podstawie indeksu wegetacji NDVI możliwe jest rozróznienie gatunków drzew. NDVI (_Normalized Difference Vegetation Index), czyli znormalizowany indeks wegetacji to wskaźnik stosowany do oceny aktywności wegetacyjnej roślin na podstawie różnicy wchłaniania światła w obszarach podczerwieni i bliskiej podczerwieni. Jego wartości skorelowane są z zawartością chlorofilu.
NDVI przyjmuje postać:
NDVI = (NIR - RED) / (NIR + RED), gdzie
NIR to wartość odbicia promieniowania w bliskiej podczerieni, RED to wartość odbicia promieniowania w zakrsie czerwonym światła widzialnego.
Wskaźnik NDVI przyjmuje wartości z zakresu -1 do 1.
Przetwarzanie próbek na dane
Przykład kodu, do pozyskiwania statystyk z przygotowanych próbek. Podajemy folder z próbkami poligonów NDVI oraz folder wynikowoy dla pliku CSV.
Przeliczane są następujące wartości:
- liczba pixeli w poligonie,
- wartość minimalna,
- wartość maksymalna,
- wartość średnia,
- odchylenie standardowe.
"""
import rasterio
import pandas as pd
import glob
import numpy as np
import os
import time # Dodaj ten import!
import tkinter as tk
from tkinter import filedialog
def calculate_raster_statistics(input_folder, output_csv):
# Lista rastrów do przetworzenia
raster_list = glob.glob(os.path.join(input_folder, '*.tif'))
# Tworzenie pustego dataframe'u do przechowywania wyników
results_df = pd.DataFrame()
# Pętla przez rastry
start_time = time.time() # Początkowy czas
for raster_path in raster_list:
with rasterio.open(raster_path) as src:
raster = src.read(1)
# Obliczanie statystyk
min_val = np.min(raster)
max_val = np.max(raster)
mean_val = np.mean(raster)
std_val = np.std(raster)
count_val = np.count_nonzero(~np.isnan(raster))
# Tworzenie dataframe'u ze statystykami
df = pd.DataFrame({
'Raster': [raster_path],
'Count': [count_val],
'Min': [min_val],
'Max': [max_val],
'Mean': [mean_val],
'Std': [std_val]
})
# Łączenie wyniku z dataframe'em wyników
results_df = pd.concat([results_df, df])
# Zapisanie wyników do pliku CSV
results_df.to_csv(output_csv, index=False)
elapsed_time = time.time() - start_time # Czas trwania obliczeń
print(f"Wyniki zostały zapisane do: {output_csv}")
print(f"Czas trwania obliczeń: {elapsed_time:.2f} sekundy")
# Tworzymy GUI z pomocą tkinter
root = tk.Tk()
root.withdraw() # Ukrywamy główne okno, ponieważ nie potrzebujemy pełnej aplikacji GUI
# Wybierz folder z rastrami
input_folder = filedialog.askdirectory(title="Wybierz folder z rastrami")
# Sprawdź, czy użytkownik wybrał folder
if not input_folder:
print("Anulowano wybór folderu z rastrami.")
else:
# Wybierz ścieżkę do zapisu pliku CSV
output_csv = filedialog.asksaveasfilename(
title="Wybierz miejsce zapisu pliku CSV",
defaultextension=".csv",
filetypes=[("CSV Files", "*.csv")]
)
# Sprawdź, czy użytkownik wybrał miejsce zapisu
if not output_csv:
print("Anulowano wybór miejsca zapisu pliku CSV.")
else:
# Wywołaj funkcję przetwarzania z wybranymi ścieżkami
calculate_raster_statistics(input_folder, output_csv
"""
Przykładowe dane
Wyjaśnienie danych:
- id - numer id każdej próbki
- min - minimalna wartość indeksu NDVI
- max - maksymalna wartość indeksu NDVI
- mean - średnia wartość indeksu NDVI
- std - odchylenie standardowe wartość indeksu NDVI
- count - liczba pixeli w poligonie
- species - gatunek drzewa
- age - wiek drzewostanu
- tsl - typ siedliskowy lasu; cecha lasu świadcząca o żyzności siedliska
import pandas as pd
df = pd.read_csv(r"C:\python_zaliczenie\CSV\trees.csv", delimiter=";")
print(df)
id tree_id min max mean std count species \ 0 5640 0 0.238095 0.489933 0.352131 0.054493 156 DB 1 5641 1 0.203791 0.537415 0.421799 0.072868 156 DB 2 5642 10 0.210762 0.513889 0.427296 0.053479 144 DB 3 5643 100 0.362162 0.609023 0.531492 0.045743 169 DB 4 5644 1000 0.283019 0.570470 0.441266 0.069439 144 DB ... ... ... ... ... ... ... ... ... 5633 11273 95 0.300000 0.611940 0.460505 0.076445 156 DB 5634 11274 96 0.227273 0.594595 0.432737 0.073367 156 DB 5635 11275 97 0.252336 0.572414 0.436175 0.070488 144 DB 5636 11276 98 0.227907 0.594406 0.407233 0.073816 156 DB 5637 11277 99 0.285714 0.664122 0.482923 0.079463 156 DB age tsl 0 153 LSW 1 153 LSW 2 153 LSW 3 153 LSW 4 153 LSW ... ... ... 5633 158 LSW 5634 158 LSW 5635 158 LSW 5636 158 LSW 5637 158 LSW [5638 rows x 10 columns]
Główne statystyki zbioru
Liczba obserwacji = 5638 Liczba kolumn = 10
Liczba obserwacji w podziale na gatunek i wiek
species | age | liczba obserwacji |
---|---|---|
DB | 153 | 1081 |
DB | 158 | 893 |
DB | 159 | 1026 |
SO | 54 | 333 |
SO | 55 | 1329 |
SO | 58 | 976 |
Średnia wartość parametru NDVI dla zbioru = 0.30291
Średnia wartość parametru NDVI w podziele na gatunki:
DB = 0,424301 SO = 0,164860
0 976
Średnia wartośćc parametru NDVI w podziale na gatunki i wiek
species | age | 54 | 55 | 58 | 153 | 158 | 159 |
---|---|---|---|---|---|---|---|
DB | NaN | NaN | NaN | 0.429681 | 0.416908 | 0.425066 | |
SO | 0.164 | 0.149983 | 0.185314 | NaN | NaN | NaN |
Analiza danych
Import bibliotek i funkcji
#import bibliotek
import pandas as pd
import numpy as np
from pandas import Series, DataFrame
import matplotlib.pyplot as plt
from scipy.stats import shapiro
from scipy.stats import mannwhitneyu
from statsmodels.graphics.gofplots import qqplot
import seaborn as sns
### import funkcji ###
# wczytywanie zboru danych do data frame'u
def read_csv_pandas(path_to_file):
df = pd.read_csv(path_to_file,sep=';')
return(df)
# Tworzenie histogramu danych
def plot_histogram_all(df, colors):
grouped_df = df.groupby('species')
for (species, group), color in zip(grouped_df, colors):
plt.hist(group['mean'], label=species, alpha=0.5, color=color, edgecolor='black')
mean_value = group['mean'].mean()
plt.axvline(mean_value, color='red', linestyle='dashed', linewidth=2, label=f'Mean ({species}): {mean_value:.2f}')
plt.title('Histogram wartości średnich według gatunku')
plt.xlabel('Wartości średnie NDVI')
plt.ylabel('Częstotliwość')
plt.legend()
plt.show()
# Histogram danych z podziałem na gatunek
def plot_histogram(df, species, color):
df_species = df[df['species'] == species]
# Tworzymy histogram z kolorem ramki wokół prostokątów
plt.hist(df_species['mean'], bins=40, label=species, alpha=0.5, edgecolor='black', color=color)
# Dodajemy linię reprezentującą średnią wartość
mean_value = df_species['mean'].mean()
plt.axvline(mean_value, color='red', linestyle='dashed', linewidth=2, label=f'Mean: {mean_value:.2f}')
# Ustawiamy etykiety osi
plt.xlabel('Mean Value')
plt.ylabel('Frequency')
# Ustawiamy zakres osi x od minimum do maksimum danych
plt.xlim(df_species['mean'].min(), df_species['mean'].max())
# Dodajemy legendę
plt.legend()
# Dodajemy tytuł
plt.title(f'Histogram dla species=\'{species}\'')
# Wyświetlamy histogram
plt.show()
# Statystyki opisowe
def descriptive_statistics(data, species):
# Filtrujemy dane dla danego gatunku
data_species = data[data['species'] == species]['mean']
# Miary pozycyjne
mean_value = data_species.mean()
median_value = data_species.median()
mode_value = data_species.mode().iloc[0] # Mode może mieć więcej niż jedną wartość, dlatego wybieramy pierwszą
# Miary przeciętne
arithmetic_mean = data_species.mean()
harmonic_mean = 1 / (1 / data_species).mean()
geometric_mean = data_species.prod() ** (1 / len(data_species))
# Miary zmienności klasyczne
variance_value = data_species.var()
std_deviation_value = data_species.std()
# Miary pozycyjne dla zmiennych ilościowych
q1 = data_species.quantile(0.25)
q3 = data_species.quantile(0.75)
interquartile_range = q3 - q1
# Wyświetlanie wyników
print(f"Statystyki opisowe dla gatunku '{species}':")
print("\nMiary pozycyjne:")
print(f"Średnia arytmetyczna: {mean_value}")
print(f"Mediana: {median_value}")
print(f"Moda: {mode_value}")
print("\nMiary przeciętne:")
print(f"Średnia arytmetyczna: {arithmetic_mean}")
print(f"Średnia harmoniczna: {harmonic_mean}")
print(f"Średnia geometryczna: {geometric_mean}")
print("\nMiary zmienności klasyczne:")
print(f"Wariancja: {variance_value}")
print(f"Odchylenie standardowe: {std_deviation_value}")
print("\nMiary pozycyjne dla zmiennych ilościowych:")
print(f"Kwartyl 1 (Q1): {q1}")
print(f"Kwartyl 3 (Q3): {q3}")
print(f"Rozstęp międzykwartylowy: {interquartile_range}")
# Test normalności rozkładu
def assess_normality(data, species):
# Filtrujemy dane dla danego gatunku
data_species = data[data['species'] == species]['mean']
# Wykres QQ
qqplot(data_species, line='s')
plt.title(f'Q-Q Plot dla gatunku {species}')
plt.show()
# Histogram
plt.hist(data_species, bins=20, density=True, alpha=0.5, color='grey', edgecolor='black')
plt.title(f'Histogram dla gatunku {species}')
plt.xlabel('Wartości średnie')
plt.ylabel('Częstotliwość')
plt.show()
# Test Shapiro-Wilka
stat, p_value = shapiro(data_species)
print(f'Test Shapiro-Wilka dla gatunku {species}:')
print(f'Statystyka testowa: {stat}')
print(f'Wartość p: {p_value}')
if p_value > 0.05:
print('Nie ma podstaw do odrzucenia hipotezy zerowej - dane mogą pochodzić z rozkładu normalnego.')
else:
print('Hipoteza zerowa (rozkład normalny) jest odrzucana.')
#
def compare_groups(data, group1_name, group2_name):
# Wyodrębnienie grup
group1 = data[data['species'] == group1_name]['mean']
group2 = data[data['species'] == group2_name]['mean']
# Przeprowadzenie testu U Manna-Whitneya
u_stat, p_value = mannwhitneyu(group1, group2)
print(f"Test U Manna-Whitneya: statystyka U = {u_stat}, p-wartość = {p_value}")
# Wykres pudełkowy dla obu grup
sns.boxplot(x='species', y='mean', data=data)
plt.title(f'Porównanie grup {group1_name} i {group2_name}')
plt.xlabel('Species')
plt.ylabel('Mean Value')
plt.show()
def compare_wilcoxon(data, group1_name, group2_name, alpha=0.05):
# Wyodrębnienie grup
group1 = data[data['species'] == group1_name]['mean']
group2 = data[data['species'] == group2_name]['mean']
# Test Wilcoxona dla dwóch grup
stat, p_value = wilcoxon(group1, group2)
# Sprawdzenie istotności statystycznej i wypisanie wyniku
if p_value < alpha:
print("Istnieją istotne różnice między grupami.")
else:
print("Brak istotnych różnic między grupami.")
return stat, p_value
# podstawowe statystyki zbioru
df = read_csv_pandas(r'C:\python_zaliczenie\CSV\trees.csv')
print("Liczba obserwacji, liczba kolumn")
print(df.shape)
print("\n")
#Ilość obserwacji w przeiczeniu na gatuenk i wiek
counts = df.groupby(['species', 'age']).size()
print("Liczba obeserwacji: ")
print(counts)
Liczba obserwacji, liczba kolumn (5638, 10) Liczba obeserwacji: species age DB 153 1081 158 893 159 1026 SO 54 333 55 1329 58 976 dtype: int64
#Średnia wartość parametru NDVI
print(df['mean'].mean())
0.30290945529230223
#ŚRednia wartość parametru NDVI w rozbiciu na gatunki
print(df.groupby('species')['mean'].mean())
species DB 0.424301 SO 0.164860 Name: mean, dtype: float64
#ŚRednia wartość parametru NDVI w rozbiciu na gatunki i wiek
print(df.pivot_table(index='species', columns='age',values=('mean')))
age 54 55 58 153 158 159 species DB NaN NaN NaN 0.429681 0.416908 0.425066 SO 0.16429 0.149983 0.185314 NaN NaN NaN
Wartości histogramu
colors = ['orange', 'blue']
plot_histogram_all(df, colors)
Parametry dla Sosny
Statystyki opisowe dla gatunku 'SO':
Miary pozycyjne:
- Średnia arytmetyczna: 0.16486043726990143
- Mediana: 0.16289376500000002
- Moda: 0.04902203
Miary przeciętne:
- Średnia arytmetyczna: 0.16486043726990143
- Średnia harmoniczna: 0.15346692899481953
- Średnia geometryczna: 0.0
Miary zmienności klasyczne:
- Wariancja: 0.0017045115313393045
- Odchylenie standardowe: 0.04128573035976601
Miary pozycyjne dla zmiennych ilościowych:
- Kwartyl 1 (Q1): 0.1361851875
- Kwartyl 3 (Q3): 0.19286027
- Rozstęp międzykwartylowy: 0.0566750825 ylowy: 0.0566750825
plot_histogram(df, species='SO', color='orange')
descriptive_statistics(df,species='SO')
Statystyki opisowe dla gatunku 'SO': Miary pozycyjne: Średnia arytmetyczna: 0.16486043726990143 Mediana: 0.16289376500000002 Moda: 0.04902203 Miary przeciętne: Średnia arytmetyczna: 0.16486043726990143 Średnia harmoniczna: 0.15346692899481953 Średnia geometryczna: 0.0 Miary zmienności klasyczne: Wariancja: 0.0017045115313393045 Odchylenie standardowe: 0.04128573035976601 Miary pozycyjne dla zmiennych ilościowych: Kwartyl 1 (Q1): 0.1361851875 Kwartyl 3 (Q3): 0.19286027 Rozstęp międzykwartylowy: 0.0566750825
Parametry dla dębu
Statystyki opisowe dla gatunku DB
Miary pozycyjne:
- Średnia arytmetyczna: 0.4243005584733333
- Mediana: 0.42180143999999997
- Moda: 0.3600402
Miary przeciętne:
- Średnia arytmetyczna: 0.4243005584733333
- Średnia harmoniczna: 0.42035434595669174
- Średnia geometryczna: 0.0
Miary zmienności klasyczne:
- Wariancja: 0.0016796518236727545
- Odchylenie standardowe: 0.04098355552746436
Miary pozycyjne dla zmiennych ilościowych:
- Kwartyl 1 (Q1): 0.39590554499999997
- Kwartyl 3 (Q3): 0.45235550999999996
- Rozstęp międzykwartylowy: 0.05644996499999999
plot_histogram(df, species='DB', color='blue')
descriptive_statistics(df,species='DB')
Statystyki opisowe dla gatunku 'DB': Miary pozycyjne: Średnia arytmetyczna: 0.4243005584733333 Mediana: 0.42180143999999997 Moda: 0.3600402 Miary przeciętne: Średnia arytmetyczna: 0.4243005584733333 Średnia harmoniczna: 0.42035434595669174 Średnia geometryczna: 0.0 Miary zmienności klasyczne: Wariancja: 0.0016796518236727545 Odchylenie standardowe: 0.04098355552746436 Miary pozycyjne dla zmiennych ilościowych: Kwartyl 1 (Q1): 0.39590554499999997 Kwartyl 3 (Q3): 0.45235550999999996 Rozstęp międzykwartylowy: 0.05644996499999999
Test Shapiro-Wilka dla SO
assess_normality(df, species='SO')
Test Shapiro-Wilka dla gatunku SO: Statystyka testowa: 0.9979913830757141 Wartość p: 0.0020150956697762012 Hipoteza zerowa (rozkład normalny) jest odrzucana.
Test Shapiro-Wilka dla DB
assess_normality(df, species='DB')
Test Shapiro-Wilka dla gatunku DB: Statystyka testowa: 0.9962971806526184 Wartość p: 9.341939062323945e-07 Hipoteza zerowa (rozkład normalny) jest odrzucana.
Test U Manna-Whitneya
Test U Manna-Whitneya:
statystyka U = 42.0, p-wartość = 0.0 ** Istnieją istotne różnice między grupa**mi
compare_groups(df, 'SO', 'DB')
Test U Manna-Whitneya: statystyka U = 42.0, p-wartość = 0.0
Trenowanie modelu
Wyniki oceny modelu są następujące:
Dokładność (Accuracy): Wynosi 100%, co oznacza, że model poprawnie sklasyfikował wszystkie próbki w zbiorze testowym.
Macierz pomyłek (Confusion Matrix): Pokazuje, że:
- 615 próbek dla gatunku DB zostało poprawnie sklasyfikowanych jako DB (True Positives).
- 510 próbek dla gatunku SO zostało poprawnie sklasyfikowanych jako SO (True Positives).
- 3 próbki dla gatunku SO zostały błędnie sklasyfikowane jako DB (False Negatives).
- Nie ma błędnych sklasyfikowań dla gatunku DB.
Raport klasyfikacji (Classification Report): Zapewnia szczegółowe informacje na temat wyników klasyfikacji dla każdej klasy, w tym precision, recall, f1-score i support (liczba próbek w danej klasie).
- Precision (precyzja) dla obu klas wynosi 100%, co oznacza, że wszystkie pozytywne predykcje dla danej klasy są prawidłowe.
- Recall (czułość) dla klasy DB wynosi 100%, a dla klasy SO wynosi 99%, co oznacza, że większość próbek dla każdej klasy została poprawnie zidentyfikowana.
- F1-score dla obu klas wynosi 100%, co oznacza dobrą równowagę między precyzją a czułością.
- Support określa liczbę próbek w każdej klasie.
Rozmiar zbioru treningowego i testowego: W zbiorze treningowym znajduje się 4510 próbek, a w zbiorze testowym 1128 próbek.
#Impor bibliotek
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import numpy as np
import pandas as pd
import joblib
#Import zbioru danych
def read_csv_pandas(path_to_file):
df = pd.read_csv(path_to_file, sep=';')
return df
df = read_csv_pandas(r'C:\python_zaliczenie\CSV\trees.csv')
# Trenowanie modelu
# Podziel dane na zbiór treningowy i testowy
X_train, X_test, y_train, y_test = train_test_split(df['mean'].values.reshape(-1, 1), df['species'], test_size=0.2, random_state=42)
# Wytrenuj prosty model regresji logistycznej
model = LogisticRegression(random_state=42)
model.fit(X_train, y_train)
# Określ ścieżkę, pod którą zapiszesz model
model_path = "C:\python_zaliczenie\model\so_db_model.joblib"
# Zapisz model do zmiennej 'model_path'
joblib.dump(model, model_path)
# Przewiduj na zbiorze testowym
y_pred = model.predict(X_test)
# Ocen jakość modelu
accuracy = accuracy_score(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred)
classification_rep = classification_report(y_test, y_pred)
# Wyświetl wyniki
print(f'Accuracy: {accuracy:.2f}')
print(f'Confusion Matrix:\n{conf_matrix}')
print(f'Classification Report:\n{classification_rep}')
print("Rozmiar zbioru treningowego:", len(X_train))
print("Rozmiar zbioru testowego:", len(X_test))
Accuracy: 1.00 Confusion Matrix: [[615 0] [ 3 510]] Classification Report: precision recall f1-score support DB 1.00 1.00 1.00 615 SO 1.00 0.99 1.00 513 accuracy 1.00 1128 macro avg 1.00 1.00 1.00 1128 weighted avg 1.00 1.00 1.00 1128 Rozmiar zbioru treningowego: 4510 Rozmiar zbioru testowego: 1128
Wyniki dla danych testowych (nie znanych przez model)
Ocena modelu
Dokładność (Accuracy): Wynosi 98%, co oznacza, że model poprawnie sklasyfikował 98% wszystkich próbek w zbiorze testowym.
Macierz pomyłek (Confusion Matrix): Pokazuje, że:
- 486 próbek dla gatunku DB zostało poprawnie sklasyfikowanych jako DB (True Positives), natomiast 18 próbek zostało błędnie sklasyfikowanych jako SO (False Negatives).
- Wszystkie 600 próbek dla gatunku SO zostało poprawnie sklasyfikowanych jako SO (True Positives), bez błędów.
Raport klasyfikacji (Classification Report): Zapewnia szczegółowe informacje na temat wyników klasyfikacji dla każdej klasy, w tym precision, recall, f1-score i support (liczba próbek w danej klasie).
- Precyzja (Precision) dla klasy DB wynosi 100%, co oznacza, że wszystkie pozytywne predykcje dla klasy DB są prawidłowe. Dla klasy SO precyzja wynosi 97%, co oznacza, że większość pozytywnych predykcji dla klasy SO jest prawidłowa.
- Czułość (Recall) dla klasy DB wynosi 96%, co oznacza, że model zidentyfikował 96% wszystkich rzeczywistych próbek klasy DB. Dla klasy SO czułość wynosi 100%, co oznacza, że model zidentyfikował wszystkie rzeczywiste próbki klasy SO.
- F1-score dla klasy DB wynosi 98%, a dla klasy SO wynosi 99%, co oznacza, że model osiągnął dobrą równowagę między precyzją a czułością dla obu klas.
- Support określa liczbę próbek w każdej klasie.
Rozmiar zbioru testowego: W zbiorze testowym znajduje się 1104 próbek.
#import biblitek
import joblib
import pandas as pd
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, roc_curve, auc
import matplotlib.pyplot as plt
# Wczytaj model (upewnij się, że wcześniej model został zapisany)
model_path = "C:\python_zaliczenie\model\so_db_model.joblib"
model = joblib.load(model_path)
# Wczytaj dane testowe z pliku CSV
test_data = pd.read_csv(r'C:\python_zaliczenie\CSV\trees_results_for_testing_model.csv', sep=';')
# Dodaj numer ID do danych wynikowych
test_data['id'] = test_data['id'].astype(str) # Jeśli numer ID jest liczbą, zamień na tekst
test_data['id'] = test_data['id'] + '_predicted' # Dodaj '_predicted' do numeru ID
# Przewiduj gatunki dla danych testowych
predicted_classes = model.predict(test_data['mean'].values.reshape(-1, 1))
predicted_probabilities = model.predict_proba(test_data['mean'].values.reshape(-1, 1))[:, 1]
# Dodaj przewidziane gatunki i prawdopodobieństwa do danych wynikowych
test_data['predicted_species'] = predicted_classes
test_data['predicted_probabilities'] = predicted_probabilities
# Sprawdź dokładność modelu
accuracy = accuracy_score(test_data['species'], predicted_classes)
conf_matrix = confusion_matrix(test_data['species'], predicted_classes)
classification_rep = classification_report(test_data['species'], predicted_classes)
# Wyświetl wyniki dokładności
print(f'Accuracy: {accuracy:.2f}')
print(f'Confusion Matrix:\n{conf_matrix}')
print(f'Classification Report:\n{classification_rep}')
Accuracy: 0.98 Confusion Matrix: [[486 18] [ 0 600]] Classification Report: precision recall f1-score support DB 1.00 0.96 0.98 504 SO 0.97 1.00 0.99 600 accuracy 0.98 1104 macro avg 0.99 0.98 0.98 1104 weighted avg 0.98 0.98 0.98 1104
# Oblicz krzywą ROC
fpr, tpr, thresholds = roc_curve(test_data['species'], predicted_probabilities, pos_label='SO')
roc_auc = auc(fpr, tpr)
# Wykres krzywej ROC
plt.figure(figsize=(8, 8))
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc='lower right')
plt.show()
# Zapisz wyniki do nowego pliku CSV
test_data.to_csv(r'C:\python_zaliczenie\CSV\test_results.csv', index=False)
# Wyświetl wyniki
print(test_data[['id', 'mean', 'species', 'predicted_species', 'predicted_probabilities']])
id mean species predicted_species \ 0 11278_predicted 0.348585 DB DB 1 11279_predicted 0.415473 DB DB 2 11280_predicted 0.370652 DB DB 3 11281_predicted 0.399752 DB DB 4 11282_predicted 0.319520 DB DB ... ... ... ... ... 1099 12377_predicted 0.144427 SO SO 1100 12378_predicted 0.168115 SO SO 1101 12379_predicted 0.139620 SO SO 1102 12380_predicted 0.147586 SO SO 1103 12381_predicted 0.184157 SO SO predicted_probabilities 0 0.199945 1 0.045746 2 0.126598 3 0.066005 4 0.338688 ... ... 1099 0.974745 1100 0.955569 1101 0.977507 1102 0.972752 1103 0.935376 [1104 rows x 5 columns]