# File: Madra_smieciarka/classes/Decisiontree.py
# Source: repository web view (Python, 47 lines, 1.4 KiB); last change 2024-05-11 22:20:35 +02:00
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import LabelEncoder
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
from sklearn.tree import export_text
# Load the training table and turn text columns into integer codes, because
# DecisionTreeClassifier only accepts numeric input.
data = pd.read_csv("data.csv")

# One fitted LabelEncoder per encoded column, keyed by column name, so codes can
# be mapped back to the original values later (e.g. labels['decyzja'].classes_).
labels = {}

# Encode the text feature columns. The target 'decyzja' is deliberately skipped
# here and encoded exactly once below: encoding it in this loop and then again
# made labels['decyzja'] learn the integer codes instead of the real decision
# names. Besides 'object', pandas' dedicated string dtype (used by default for
# text in pandas 3) is treated as text as well.
for column in data.columns:
    if column == 'decyzja':
        continue
    if pd.api.types.is_object_dtype(data[column]) or pd.api.types.is_string_dtype(data[column]):
        labels[column] = LabelEncoder()
        data[column] = labels[column].fit_transform(data[column])

labels['decyzja'] = LabelEncoder()
data['decyzja'] = labels['decyzja'].fit_transform(data['decyzja'])
# Separate the features from the target and hold out 30% of the rows for evaluation.
x = data.drop('decyzja', axis=1)
y = data['decyzja']
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.3, random_state=3
)

# Entropy (information-gain) split criterion. random_state is fixed because
# scikit-learn permutes the features at every split, so ties between equally
# good splits could otherwise produce a different tree on each run.
treeclf = DecisionTreeClassifier(criterion='entropy', random_state=3)
treeclf.fit(x_train, y_train)
y_pred = treeclf.predict(x_test)
# Report accuracy on the held-out rows and draw the fitted tree.
accuracy = accuracy_score(y_test, y_pred)
print("Dokładność:", accuracy)

# Class names for the plot are taken from the target's fitted LabelEncoder.
class_names = [str(class_label) for class_label in labels['decyzja'].classes_]
plt.figure(figsize=(25, 20))
# Feature names are passed as a plain list, the same form export_text receives.
plot_tree(treeclf, feature_names=list(x.columns), class_names=class_names, filled=True)
plt.show()
# Save a plain-text dump of the learned decision rules. The encoding is explicit
# so the output does not depend on the platform's default encoding (feature
# names may contain non-ASCII characters, e.g. Polish letters).
tree_text = export_text(treeclf, feature_names=list(x.columns))
tree_file_path = "wyuczone_drzewo.txt"
with open(tree_file_path, "w", encoding="utf-8") as tree_file:
    tree_file.write(tree_text)