# Madra_smieciarka/classes/Decisiontree.py
# Neerka 8c37aaea5d tree complete
# Signed-off-by: Neerka <kuba.markil0220@gmail.com>
# 2024-05-13 15:05:11 +02:00
#
# 54 lines
# 1.6 KiB
# Python

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import LabelEncoder
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
from sklearn.tree import export_text
from joblib import dump
# Load the training data and integer-encode every text (object-dtype) feature
# column. Each fitted LabelEncoder is kept in `labels`, keyed by column name,
# so the integer codes can be mapped back to the original values later.
data = pd.read_csv("data.csv")
labels = {}
for column in data.columns:
    # The target column 'decyzja' ("decision") is encoded separately below.
    # Encoding it here as well would make the second encoder fit on integer
    # codes, so labels['decyzja'].classes_ would lose the real class names.
    if column == 'decyzja':
        continue
    if data[column].dtype == 'object':
        labels[column] = LabelEncoder()
        data[column] = labels[column].fit_transform(data[column])
# Encode the target exactly once, from its raw values, so classes_ holds the
# original class labels (used later for plot class names).
labels['decyzja'] = LabelEncoder()
data['decyzja'] = labels['decyzja'].fit_transform(data['decyzja'])
# Features are every column except the target 'decyzja'; the target is y.
x = data.drop('decyzja', axis=1)
y = data['decyzja']
# Hold out 30% of the rows for evaluation; the fixed seed makes the split repeatable.
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=3)
# Seed the tree too: scikit-learn randomly permutes the features at every split,
# so without a fixed random_state equally good splits may be chosen differently
# between runs and the tree/accuracy would not be reproducible.
treeclf = DecisionTreeClassifier(criterion='entropy', random_state=3)
treeclf.fit(x_train, y_train)
y_pred = treeclf.predict(x_test)
accuracy = accuracy_score(y_test, y_pred)
# "Dokładność" is Polish for "Accuracy".
print("Dokładność:", accuracy)
# Save the trained model to a file (joblib format).
model_path = 'drzewo.joblib'
dump(treeclf, model_path)
# Draw the fitted tree, labelling nodes with the feature column names and
# with the target classes stored in labels['decyzja'] (as strings).
class_names = list(map(str, labels['decyzja'].classes_))
plt.figure(figsize=(25, 20))
plot_tree(
    treeclf,
    feature_names=x.columns,
    class_names=class_names,
    filled=True,
)
plt.show()
# Print, for every encoded column except the target, the integer code that
# LabelEncoder assigned to each category.
for col_name, fitted_encoder in labels.items():
    if col_name != 'decyzja':
        print(f"{col_name}:")
        for code, category in enumerate(fitted_encoder.classes_):
            print(f"{code}: {category}")
# Write a plain-text rendering of the learned decision rules to a file.
tree_text = export_text(treeclf, feature_names=list(x.columns))
tree_file_path = "wyuczone_drzewo.txt"
# Use an explicit encoding: the default is platform-dependent, and feature names
# come from the CSV header, which may contain non-ASCII characters.
with open(tree_file_path, "w", encoding="utf-8") as tree_file:
    tree_file.write(tree_text)