Projekt_AI-Automatyczny_saper/Engine/DecisionTree.py

61 lines
2.2 KiB
Python
Raw Normal View History

2021-05-18 00:21:14 +02:00
import pandas as pd
2021-05-18 19:01:17 +02:00
import json
from sklearn import tree
import pydotplus
from sklearn.tree import DecisionTreeClassifier
import matplotlib.pyplot as plt
import matplotlib.image as pltimg
import pickle
2021-05-18 00:21:14 +02:00
class DecisionTree:
    """Decision tree mapping bomb attributes to an action.

    Training data is a CSV whose categorical columns are encoded to integers
    by :meth:`mapData`; the predicted integer action code can be turned back
    into its name ('detonate', 'poligon', 'defuse') with :meth:`mapAction`.
    """

    # Columns fed to the classifier, in training order.
    FEATURES = ['bomb_type', 'detonation_duration', 'size', 'detonation_area', 'defusable']

    # Category -> integer code for each CSV column; mapData() applies these in
    # this order. 'action' is the training target.
    COLUMN_CODES = {
        'bomb_type': {'Atomic Bomb': 0, 'Claymore': 1, 'Land Mine': 2, 'Chemical Bomb': 3, 'Decoy': 4},
        'detonation_duration': {'immediate': 0, 'short': 1, 'long': 2, 'none': 3},
        'size': {'small': 0, 'medium': 1, 'large': 2},
        'detonation_area': {'small': 0, 'large': 1},
        'defusable': {'no': 0, 'yes': 1},
        'action': {'detonate': 0, 'poligon': 1, 'defuse': 2},
    }

    def __init__(self, doCreation, *, data_path=None, model_path=None):
        """Read the training CSV, then train a new tree or load a saved one.

        Args:
            doCreation: If true, encode ``self.data`` with :meth:`mapData`,
                fit a new ``DecisionTreeClassifier`` and pickle it to
                ``model_path``. If false, unpickle the tree stored at
                ``model_path``.
            data_path: Training CSV. Defaults to ``out.csv`` in the parent of
                this module's directory (the project root).
            model_path: Pickled model file. Defaults to ``tree.pkl`` in this
                module's directory.
        """
        # Local import: only the constructor needs path handling.
        from pathlib import Path

        # Resolve defaults relative to this file instead of one developer's
        # absolute home-directory path, so the project runs on any machine.
        engine_dir = Path(__file__).resolve().parent
        if data_path is None:
            data_path = engine_dir.parent / 'out.csv'
        if model_path is None:
            model_path = engine_dir / 'tree.pkl'

        # Kept unconditional so ``self.data`` exists exactly as before.
        self.data = pd.read_csv(data_path)

        if doCreation:
            self.mapData()
            dtree = DecisionTreeClassifier().fit(self.data[self.FEATURES], self.data['action'])
            # Write to the same file the non-training branch reads, so the
            # saved model and the one loaded later cannot diverge.
            with open(model_path, 'wb') as model_file:
                pickle.dump(dtree, model_file)
            self.dtree = dtree
        else:
            # NOTE: pickle.load can execute arbitrary code; only load model
            # files you produced yourself.
            with open(model_path, 'rb') as model_file:
                self.dtree = pickle.load(model_file)

    def getTree(self):
        """Return the trained or loaded classifier."""
        return self.dtree

    def mapData(self):
        """Encode the categorical columns of ``self.data`` as integers, in place.

        Uses :attr:`COLUMN_CODES`. Values missing from a column's table become
        NaN (``pandas.Series.map`` semantics).
        """
        for column, codes in self.COLUMN_CODES.items():
            self.data[column] = self.data[column].map(codes)

    def mapAction(self, action):
        """Return the action name for an integer action code, or None if unknown."""
        # Derived from COLUMN_CODES so encoding and decoding can never disagree.
        names = {code: name for name, code in self.COLUMN_CODES['action'].items()}
        return names.get(action)
2021-05-18 00:21:14 +02:00
2021-05-18 19:01:17 +02:00
if __name__ == "__main__":
    # Running this module as a script trains a fresh tree on the CSV data and
    # pickles it, by constructing DecisionTree with doCreation=True.
    DecisionTree(True)
2021-05-18 19:01:17 +02:00
2021-05-18 00:21:14 +02:00
2021-05-18 17:25:03 +02:00