2024-05-12 16:25:09 +02:00
|
|
|
import pandas as pd
|
2024-05-13 12:49:19 +02:00
|
|
|
from sklearn.tree import DecisionTreeClassifier, plot_tree
|
2024-05-12 16:25:09 +02:00
|
|
|
from sklearn.model_selection import train_test_split
|
|
|
|
from sklearn.metrics import accuracy_score
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
# Feature columns read from dane.csv; the target label is the 'decision' column.
headers = ['adult', 'active_time', 'ill', 'season', 'guests', 'hunger', 'wet_food', 'dry_food']

# Seed for the train/test split and the tree, so the reported accuracy and the
# saved tree are reproducible between runs. Change it to sample a different split.
RANDOM_STATE = 42

# Load the data.
data = pd.read_csv('dane.csv', header=0)

X = data[headers]
Y = data['decision']

# One-hot encode 'season' only. Every other feature is passed to the tree as-is,
# so those columns must already be numeric/boolean in dane.csv for fit() to work.
X = pd.get_dummies(data=X, columns=['season'])

clf = DecisionTreeClassifier(max_depth=6, random_state=RANDOM_STATE)

# 80/20 split: X1/Y1 train the tree, X2/Y2 are held out for the accuracy check.
X1, X2, Y1, Y2 = train_test_split(X, Y, train_size=0.8, random_state=RANDOM_STATE)

clf = clf.fit(X1, Y1)

pred = clf.predict(X2)

accuracy = accuracy_score(Y2, pred)

print("Dokładność:", accuracy)

# Save the tree to a file.
plt.figure(figsize=(50,30))
# NOTE(review): plot_tree assigns class_names to clf.classes_ in sorted order;
# confirm these three labels line up with the sorted values of 'decision'.
plot_tree(clf, filled=True, feature_names=X.columns.tolist(), class_names=['nie karmi', 'karmi mokrą karmą', 'karmi suchą karmą'])
plt.savefig('tree.png')
# Release the large figure now that it has been written to disk.
plt.close()

# Input data for a single decision.
def feed_decision(adult, active_time, ill, season, guests, hunger, dry_food, wet_food,
                  *, model=None, feature_columns=None):
    """Predict the feeding decision for one animal.

    The raw feature values are encoded the same way as the training data
    (only ``season`` is one-hot encoded). They are then aligned to the
    training feature columns before being passed to the model.

    Args:
        adult, active_time, ill, season, guests, hunger, dry_food, wet_food:
            Raw values of the feature columns listed in ``headers`` for one
            animal (``season`` before one-hot encoding).
        model: Fitted classifier exposing ``predict``. Defaults to the
            module-level ``clf``.
        feature_columns: Column order the model was trained on. Defaults to
            the module-level ``X.columns``.

    Returns:
        The result of ``model.predict`` on the single-row frame. For the
        default sklearn tree this is an array holding one predicted label.
    """
    if model is None:
        model = clf
    if feature_columns is None:
        feature_columns = X.columns

    X_new = pd.DataFrame({
        'adult': [adult],
        'active_time': [active_time],
        'ill': [ill],
        'season': [season],
        'guests': [guests],
        'hunger': [hunger],
        'wet_food': [wet_food],
        'dry_food': [dry_food],
    })

    # Encode only 'season', exactly as the training data was encoded.
    # A bare get_dummies(X_new) would skip a numeric season, leaving every
    # season_* column False. It would also dummy-encode, and later drop, any
    # other string-valued feature instead of failing loudly.
    X_new = pd.get_dummies(X_new, columns=['season'])

    # Add the season_* columns this row lacks (as False), drop dummies for
    # season values never seen in training, and match the training order.
    X_new = X_new.reindex(columns=feature_columns, fill_value=False)

    print("Atrybuty zwierzęcia:", adult,active_time,ill,season,guests,hunger,wet_food,dry_food)
    return model.predict(X_new)