2024-05-11 22:20:35 +02:00
|
|
|
import pandas as pd
|
|
|
|
from sklearn.model_selection import train_test_split
|
|
|
|
from sklearn.tree import DecisionTreeClassifier
|
|
|
|
from sklearn.metrics import accuracy_score
|
|
|
|
from sklearn.preprocessing import LabelEncoder
|
|
|
|
from sklearn.tree import plot_tree
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
from sklearn.tree import export_text
|
2024-05-13 15:05:11 +02:00
|
|
|
from joblib import dump
|
2024-05-11 22:20:35 +02:00
|
|
|
|
|
|
|
# Load the raw dataset from the working directory.
data = pd.read_csv("data.csv")

# Fitted encoders keyed by column name, kept so integer codes can be mapped
# back to the original values later (legend printout, tree class names).
labels = {}

# Encode every text-valued feature column as integer codes. The target column
# is skipped here and handled once below: encoding it in this loop and then
# again afterwards would fit the second encoder on the integer codes, so its
# classes_ would hold 0..k-1 instead of the original decision labels.
for column in data.columns:
    if column == 'decyzja':
        continue
    if data[column].dtype == 'object':
        labels[column] = LabelEncoder()
        data[column] = labels[column].fit_transform(data[column])

# Encode the target exactly once, whatever its dtype, so that
# labels['decyzja'].classes_ keeps the original decision values.
labels['decyzja'] = LabelEncoder()
data['decyzja'] = labels['decyzja'].fit_transform(data['decyzja'])
|
|
|
|
|
|
|
|
# Separate the target column from the remaining (feature) columns.
y = data['decyzja']
x = data.drop(columns='decyzja')

# Hold out 30% of the rows for evaluation; the fixed seed makes the split
# the same on every run.
x_train, x_test, y_train, y_test = train_test_split(
    x,
    y,
    test_size=0.3,
    random_state=3,
)
|
|
|
|
|
2024-05-12 03:03:41 +02:00
|
|
|
# Unfitted decision tree that uses entropy as its split-quality criterion.
treeclf = DecisionTreeClassifier(
    criterion='entropy',
)
|
2024-05-11 22:20:35 +02:00
|
|
|
# Fit the tree on the training split only.
treeclf.fit(x_train, y_train)

# Predict the held-out rows and compare the predictions with the true labels.
y_pred = treeclf.predict(x_test)
accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)

# Report the accuracy score (the printed label is Polish for "Accuracy").
print("Dokładność:", accuracy)
|
|
|
|
|
2024-05-13 15:05:11 +02:00
|
|
|
# Save the fitted model to a file.
dump(value=treeclf, filename='drzewo.joblib')
|
2024-05-11 22:20:35 +02:00
|
|
|
|
|
|
|
# Class names for the plot, taken from the target encoder's classes and
# converted to strings.
class_names = [str(class_label) for class_label in labels['decyzja'].classes_]

# Large canvas so the nodes of the drawn tree have room.
plt.figure(figsize=(25, 20))

# Pass the feature names as a plain list rather than a pandas Index, matching
# the export_text call in this script; some scikit-learn releases reject a
# non-list value for feature_names.
plot_tree(
    treeclf,
    feature_names=list(x.columns),
    class_names=class_names,
    filled=True,
)
plt.show()
|
|
|
|
|
2024-05-13 15:05:11 +02:00
|
|
|
# Print a legend for every encoded feature column: the column name followed
# by one "<code>: <original value>" line per class known to its encoder.
for column, encoder in labels.items():
    # The target column's encoder is left out of the legend.
    if column != 'decyzja':
        print(f"{column}:")
        for i, label in enumerate(encoder.classes_):
            print(f"{i}: {label}")
|
|
|
|
|
2024-05-11 22:20:35 +02:00
|
|
|
# Render the fitted tree as indented text rules, labelled with the feature
# column names.
tree_text = export_text(treeclf, feature_names=list(x.columns))

tree_file_path = "wyuczone_drzewo.txt"

# Write with an explicit encoding: the platform default may not be able to
# represent non-ASCII feature names and differs between machines.
with open(tree_file_path, "w", encoding="utf-8") as tree_file:
    tree_file.write(tree_text)
|