diff --git a/DecisionTree/wyuczone_drzewo.pkl b/DecisionTree/wyuczone_drzewo.pkl new file mode 100644 index 0000000..751d454 Binary files /dev/null and b/DecisionTree/wyuczone_drzewo.pkl differ diff --git a/drzewo_decyzyjne.py b/drzewo_decyzyjne.py index f8efe20..a563a7e 100644 --- a/drzewo_decyzyjne.py +++ b/drzewo_decyzyjne.py @@ -1,4 +1,5 @@ import graphviz +import joblib import pandas as pd from sklearn.tree import DecisionTreeClassifier from sklearn.tree import export_graphviz @@ -23,36 +24,26 @@ def make_tree(): x = pd.read_csv('DecisionTree/training_data.txt', delimiter=';', names=['wielkosc', 'waga,', 'priorytet', 'ksztalt', 'kruchosc', 'dolna', 'gorna', 'g > d']) y = pd.read_csv('DecisionTree/decisions.txt', names=['polka']) - # X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=1) # 70% treningowe and 30% testowe + # Tworzenie instancji klasyfikatora ID3 clf = DecisionTreeClassifier(criterion='entropy') # Trenowanie klasyfikatora clf.fit(x.values, y.values) - # clf.fit(X_train, y_train) + + # Zapis drzewa do pliku + joblib.dump(clf, 'DecisionTree/wyuczone_drzewo.pkl') + return clf -# # Predykcja na nowych danych -# new_data = [[2, 2, 1, 0, 1, 1, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0]] -# predictions = clf.predict(new_data) -# # y_pred = clf.predict(X_test) +def stworz_plik_drzewa_w_pdf(clf, feature_names, class_names): + # Wygenerowanie pliku .dot reprezentującego drzewo + dot_data = export_graphviz(clf, out_file=None, feature_names=feature_names, class_names=class_names, filled=True, + rounded=True) + # Tworzenie obiektu graphviz z pliku .dot + graph = graphviz.Source(dot_data) -# print(predictions) -# # print("Accuracy:", clf.score(new_data, predictions)) -# # print("Accuracy:", metrics.accuracy_score(y_test, y_pred)) - - -# Wygenerowanie pliku .dot reprezentującego drzewo -# dot_data = export_graphviz(clf, out_file=None, feature_names=list(x.columns), class_names=['0', '1'], filled=True, -# rounded=True) - -# # Tworzenie obiektu graphviz z pliku .dot -# graph = graphviz.Source(dot_data) - -# # Wyświetlanie drzewa -# graph.view() - -# z = pd.concat([x, y], axis=1) -# z.to_csv('dane.csv', index=False) + # Wyświetlanie drzewa + graph.view() diff --git a/main.py b/main.py index 868f509..2282dc5 100644 --- a/main.py +++ b/main.py @@ -1,13 +1,13 @@ import sys + +import joblib import pygame from paczka import Paczka from wozek import Wozek import wyszukiwanie import ekran from grid import GridCellType, SearchGrid -from sklearn.tree import DecisionTreeClassifier -import pandas as pd -import drzewo_decyzyjne + from plansza import a_pix, b_pix pygame.init() @@ -19,7 +19,9 @@ def main(): p2 = Paczka('maly', 1, 'ogród', False, True, False, any, any, any, any, any) ekran.dodaj_paczki_na_rampe(p1, p2) grid_points = SearchGrid() - drzewo = drzewo_decyzyjne.make_tree() + + # Odczyt drzewa z pliku + drzewo = joblib.load('DecisionTree/wyuczone_drzewo.pkl') while True: for event in pygame.event.get():