Read tree from file

This commit is contained in:
Mateusz Kantorski 2023-05-30 19:09:24 +02:00
parent 3f65be482c
commit bdea00c5b2
3 changed files with 20 additions and 27 deletions

Binary file not shown.

View File

@ -1,4 +1,5 @@
import graphviz import graphviz
import joblib
import pandas as pd import pandas as pd
from sklearn.tree import DecisionTreeClassifier from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_graphviz from sklearn.tree import export_graphviz
@ -23,36 +24,26 @@ def make_tree():
x = pd.read_csv('DecisionTree/training_data.txt', delimiter=';', x = pd.read_csv('DecisionTree/training_data.txt', delimiter=';',
names=['wielkosc', 'waga,', 'priorytet', 'ksztalt', 'kruchosc', 'dolna', 'gorna', 'g > d']) names=['wielkosc', 'waga,', 'priorytet', 'ksztalt', 'kruchosc', 'dolna', 'gorna', 'g > d'])
y = pd.read_csv('DecisionTree/decisions.txt', names=['polka']) y = pd.read_csv('DecisionTree/decisions.txt', names=['polka'])
# X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=1) # 70% treningowe and 30% testowe
# Tworzenie instancji klasyfikatora ID3 # Tworzenie instancji klasyfikatora ID3
clf = DecisionTreeClassifier(criterion='entropy') clf = DecisionTreeClassifier(criterion='entropy')
# Trenowanie klasyfikatora # Trenowanie klasyfikatora
clf.fit(x.values, y.values) clf.fit(x.values, y.values)
# clf.fit(X_train, y_train)
# Zapis drzewa do pliku
joblib.dump(clf, 'DecisionTree/wyuczone_drzewo.pkl')
return clf return clf
# # Predykcja na nowych danych
# new_data = [[2, 2, 1, 0, 1, 1, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0]]
# predictions = clf.predict(new_data)
# # y_pred = clf.predict(X_test)
def stworz_plik_drzewa_w_pdf(clf, feature_names, class_names):
# Wygenerowanie pliku .dot reprezentującego drzewo
dot_data = export_graphviz(clf, out_file=None, feature_names=feature_names, class_names=class_names, filled=True,
rounded=True)
# Tworzenie obiektu graphviz z pliku .dot
graph = graphviz.Source(dot_data)
# print(predictions) # Wyświetlanie drzewa
# # print("Accuracy:", clf.score(new_data, predictions)) graph.view()
# # print("Accuracy:", metrics.accuracy_score(y_test, y_pred))
# Wygenerowanie pliku .dot reprezentującego drzewo
# dot_data = export_graphviz(clf, out_file=None, feature_names=list(x.columns), class_names=['0', '1'], filled=True,
# rounded=True)
# # Tworzenie obiektu graphviz z pliku .dot
# graph = graphviz.Source(dot_data)
# # Wyświetlanie drzewa
# graph.view()
# z = pd.concat([x, y], axis=1)
# z.to_csv('dane.csv', index=False)

10
main.py
View File

@ -1,13 +1,13 @@
import sys import sys
import joblib
import pygame import pygame
from paczka import Paczka from paczka import Paczka
from wozek import Wozek from wozek import Wozek
import wyszukiwanie import wyszukiwanie
import ekran import ekran
from grid import GridCellType, SearchGrid from grid import GridCellType, SearchGrid
from sklearn.tree import DecisionTreeClassifier
import pandas as pd
import drzewo_decyzyjne
from plansza import a_pix, b_pix from plansza import a_pix, b_pix
pygame.init() pygame.init()
@ -19,7 +19,9 @@ def main():
p2 = Paczka('maly', 1, 'ogród', False, True, False, any, any, any, any, any) p2 = Paczka('maly', 1, 'ogród', False, True, False, any, any, any, any, any)
ekran.dodaj_paczki_na_rampe(p1, p2) ekran.dodaj_paczki_na_rampe(p1, p2)
grid_points = SearchGrid() grid_points = SearchGrid()
drzewo = drzewo_decyzyjne.make_tree()
# Odczyt drzewa z pliku
drzewo = joblib.load('DecisionTree/wyuczone_drzewo.pkl')
while True: while True:
for event in pygame.event.get(): for event in pygame.event.get():