Male_zoo_Projekt_SI/decision_tree.py

50 lines
1.8 KiB
Python

import pandas as pd
from sklearn.tree import DecisionTreeClassifier, plot_tree, export_graphviz
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
# Names of the raw input columns used as model features.
headers = ['adult','active_time','ill','season','guests','hunger','wet_food','dry_food']
# Load the data set; the first CSV row supplies the column names.
data = pd.read_csv('dane.csv', header=0)
X = data[headers]
Y = data['decision']
# One-hot encode the ``season`` column; the other features are used as-is.
X = pd.get_dummies(data=X, columns=['season'])
# Hold out 20% of the rows for evaluation. The fixed seed keeps the split, and
# therefore the reported accuracy and the saved tree, reproducible between runs
# (the classifier below is already seeded the same way).
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, train_size=0.8, random_state=0)
clf = DecisionTreeClassifier(random_state=0, min_samples_leaf=4, min_samples_split=2)
clf = clf.fit(X_train, Y_train)
Y_pred = clf.predict(X_test)
accuracy = accuracy_score(Y_test, Y_pred)
print("Dokładność:", accuracy)
# Save the fitted tree to a file.
plt.figure(figsize=(50,30))
# filled=True colours the nodes. feature_names is passed as a plain list because
# some scikit-learn releases reject a pandas Index for this parameter.
# NOTE(review): class_names must follow the sorted order of the labels found in
# data['decision'] -- confirm that the order below matches clf.classes_.
plot_tree(clf, filled=True, feature_names=list(X.columns), class_names=['nie karmi', 'karmi mokrą karmą', 'karmi suchą karmą'])
plt.savefig('tree.png')
# New data: classify a single new observation
def add_data(adult, active_time, ill, season, guests, hunger, wet_food, dry_food):
    """Predict the decision class for a single new observation.

    Builds a one-row feature frame from the arguments, one-hot encodes
    ``season`` the same way the training features were encoded, aligns the
    columns with the module-level training frame ``X`` and passes the row to
    the module-level classifier ``clf``.

    Args:
        adult, active_time, ill, guests, hunger, wet_food, dry_food: Feature
            values, placed unchanged in the columns of the same name.
        season: Season value; one-hot encoded into a ``season_<value>``
            column, matching the encoding applied to the training data.

    Returns:
        The array returned by ``clf.predict`` for the single encoded row.
    """
    X_new = pd.DataFrame({
        'adult': [adult],
        'active_time': [active_time],
        'ill': [ill],
        'season': [season],
        'guests': [guests],
        'hunger': [hunger],
        'wet_food': [wet_food],
        'dry_food': [dry_food],
    })
    # Encode exactly the column that was encoded for training, so that a
    # numeric season is encoded too and no other column is touched.
    X_new_encoded = pd.get_dummies(data=X_new, columns=['season'])
    # Align with the training columns: season_* columns that this single row
    # did not produce are added as 0 and the column order is made identical.
    # The indicator produced above must not be overwritten, otherwise the
    # season information never reaches the model. A season value that was not
    # present in the training data ends up with all season_* columns at 0.
    X_new_encoded = X_new_encoded.reindex(columns=X.columns, fill_value=0)
    print("Atrybuty zwierzęcia:", adult,active_time,ill,season,guests,hunger,wet_food,dry_food)
    return clf.predict(X_new_encoded)