decisionTree update
This commit is contained in:
parent
f7b6da279d
commit
77625d79e4
@ -3,7 +3,6 @@ import random
|
||||
from data.ClientParams import ClientParams
|
||||
from data.enum.CompanySize import CompanySize
|
||||
|
||||
|
||||
class ClientParamsFactory:
|
||||
|
||||
def __init__(self) -> None:
|
||||
|
@ -1,7 +1,7 @@
|
||||
from typing import Tuple, List
|
||||
|
||||
from AgentBase import AgentBase
|
||||
from data.Direction import Direction
|
||||
from data.enum.Direction import Direction
|
||||
from decision.ActionType import ActionType
|
||||
|
||||
|
||||
|
51
data/TEST/importedData.csv
Normal file
51
data/TEST/importedData.csv
Normal file
@ -0,0 +1,51 @@
|
||||
DELAY,PAYED,NET-WORTH,INFLUENCE,SKARBOWKA,MEMBER,HAT,SIZE,PRIORITY
|
||||
11,FALSE,41,97,TRUE,FALSE,TRUE,CompanySize.HUGE,LOW
|
||||
7,FALSE,22,80,TRUE,FALSE,FALSE,CompanySize.NORMAL,LOW
|
||||
3,FALSE,58,0,TRUE,TRUE,TRUE,CompanySize.BIG,LOW
|
||||
11,FALSE,3,15,FALSE,TRUE,FALSE,CompanySize.NO,LOW
|
||||
5,FALSE,42,18,TRUE,FALSE,FALSE,CompanySize.SMALL,LOW
|
||||
4,TRUE,51,54,TRUE,FALSE,TRUE,CompanySize.NORMAL,HIGH
|
||||
7,TRUE,18,47,FALSE,TRUE,TRUE,CompanySize.SMALL,LOW
|
||||
3,TRUE,96,61,FALSE,TRUE,TRUE,CompanySize.SMALL,MEDIUM
|
||||
13,FALSE,42,44,TRUE,TRUE,TRUE,CompanySize.NORMAL,LOW
|
||||
4,TRUE,6,3,FALSE,FALSE,FALSE,CompanySize.NORMAL,LOW
|
||||
10,TRUE,91,36,FALSE,FALSE,TRUE,CompanySize.NO,MEDIUM
|
||||
1,TRUE,80,11,TRUE,FALSE,TRUE,CompanySize.GIGANTISHE,HIGH
|
||||
6,FALSE,91,82,FALSE,TRUE,TRUE,CompanySize.NO,LOW
|
||||
4,FALSE,1,93,TRUE,TRUE,FALSE,CompanySize.BIG,LOW
|
||||
14,FALSE,67,13,TRUE,TRUE,TRUE,CompanySize.SMALL,LOW
|
||||
0,FALSE,7,58,FALSE,FALSE,FALSE,CompanySize.NORMAL,LOW
|
||||
8,TRUE,74,67,TRUE,TRUE,TRUE,CompanySize.NORMAL,HIGH
|
||||
4,TRUE,33,43,FALSE,TRUE,FALSE,CompanySize.BIG,MEDIUM
|
||||
8,TRUE,74,44,TRUE,FALSE,TRUE,CompanySize.HUGE,HIGH
|
||||
14,FALSE,59,33,TRUE,FALSE,FALSE,CompanySize.NORMAL,LOW
|
||||
6,FALSE,87,80,TRUE,TRUE,FALSE,CompanySize.GIGANTISHE,LOW
|
||||
10,FALSE,2,45,FALSE,FALSE,FALSE,CompanySize.BIG,LOW
|
||||
7,FALSE,74,17,TRUE,FALSE,FALSE,CompanySize.SMALL,LOW
|
||||
14,FALSE,14,80,FALSE,TRUE,FALSE,CompanySize.NO,LOW
|
||||
1,FALSE,74,82,TRUE,TRUE,FALSE,CompanySize.NO,LOW
|
||||
13,FALSE,66,50,FALSE,TRUE,TRUE,CompanySize.HUGE,LOW
|
||||
12,TRUE,55,82,TRUE,TRUE,TRUE,CompanySize.NO,HIGH
|
||||
0,TRUE,63,1,TRUE,TRUE,TRUE,CompanySize.NO,HIGH
|
||||
0,FALSE,39,70,FALSE,FALSE,TRUE,CompanySize.NORMAL,LOW
|
||||
1,FALSE,14,66,FALSE,FALSE,FALSE,CompanySize.BIG,LOW
|
||||
7,FALSE,48,86,TRUE,TRUE,TRUE,CompanySize.BIG,LOW
|
||||
7,FALSE,39,41,FALSE,TRUE,FALSE,CompanySize.HUGE,LOW
|
||||
6,TRUE,29,90,FALSE,FALSE,FALSE,CompanySize.GIGANTISHE,HIGH
|
||||
8,TRUE,79,49,FALSE,TRUE,FALSE,CompanySize.BIG,HIGH
|
||||
14,TRUE,51,51,FALSE,TRUE,FALSE,CompanySize.NO,LOW
|
||||
1,FALSE,92,97,FALSE,TRUE,TRUE,CompanySize.HUGE,LOW
|
||||
6,FALSE,92,90,TRUE,FALSE,FALSE,CompanySize.GIGANTISHE,LOW
|
||||
9,FALSE,89,34,FALSE,TRUE,TRUE,CompanySize.NO,LOW
|
||||
14,FALSE,85,8,FALSE,FALSE,TRUE,CompanySize.HUGE,LOW
|
||||
14,TRUE,86,30,FALSE,TRUE,FALSE,CompanySize.GIGANTISHE,MEDIUM
|
||||
3,TRUE,82,57,FALSE,TRUE,FALSE,CompanySize.BIG,MEDIUM
|
||||
8,TRUE,18,44,FALSE,TRUE,FALSE,CompanySize.HUGE,LOW
|
||||
0,FALSE,87,32,FALSE,FALSE,FALSE,CompanySize.NO,LOW
|
||||
10,TRUE,97,26,FALSE,TRUE,TRUE,CompanySize.HUGE,HIGH
|
||||
0,FALSE,88,98,FALSE,TRUE,FALSE,CompanySize.NO,LOW
|
||||
10,TRUE,27,82,FALSE,FALSE,FALSE,CompanySize.HUGE,MEDIUM
|
||||
8,TRUE,28,36,TRUE,FALSE,FALSE,CompanySize.HUGE,HIGH
|
||||
14,FALSE,48,94,TRUE,FALSE,TRUE,CompanySize.HUGE,LOW
|
||||
7,FALSE,40,63,TRUE,TRUE,TRUE,CompanySize.BIG,LOW
|
||||
8,FALSE,90,20,TRUE,TRUE,FALSE,CompanySize.NO,LOW
|
|
92
main.py
92
main.py
@ -1,14 +1,21 @@
|
||||
import csv
|
||||
import random
|
||||
|
||||
import pandas
|
||||
|
||||
from mesa.visualization.ModularVisualization import ModularServer
|
||||
from mesa.visualization.modules import CanvasGrid
|
||||
from sklearn import metrics, preprocessing
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.tree import DecisionTreeClassifier
|
||||
|
||||
from ClientParamsFactory import ClientParamsFactory
|
||||
from ForkliftAgent import ForkliftAgent
|
||||
from GameModel import GameModel
|
||||
from PatchAgent import PatchAgent
|
||||
from PatchType import PatchType
|
||||
from data.enum.CompanySize import CompanySize
|
||||
from data.enum.Direction import Direction
|
||||
from util.PathDefinitions import GridWithWeights
|
||||
from data.enum.Priority import Priority
|
||||
|
||||
colors = [
|
||||
'blue', 'cyan', 'orange', 'yellow', 'magenta', 'purple', '#103d3e', '#9fc86c',
|
||||
@ -52,23 +59,74 @@ def agent_portrayal(agent):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
base = 512
|
||||
gridWidth = 10
|
||||
gridHeight = 10
|
||||
scale = base / gridWidth
|
||||
test = ClientParamsFactory()
|
||||
|
||||
diagram4 = GridWithWeights(gridWidth, gridHeight)
|
||||
diagram4.walls = [(6, 5), (6, 6), (6, 7), (6, 8), (2, 3), (2, 4), (3, 4), (4, 4), (6, 4)]
|
||||
header = ['DELAY',
|
||||
'PAYED',
|
||||
'NET-WORTH',
|
||||
'INFLUENCE',
|
||||
'SKARBOWKA',
|
||||
'MEMBER',
|
||||
'HAT',
|
||||
'SIZE']
|
||||
|
||||
diagram5 = GridWithWeights(gridWidth, gridHeight)
|
||||
diagram5.puddles = [(2, 2), (2, 5), (2, 6), (5, 4)]
|
||||
with open("data/TEST/generatedData.csv", 'w', newline='') as file:
|
||||
writer = csv.writer(file)
|
||||
|
||||
grid = CanvasGrid(agent_portrayal, gridWidth, gridHeight, scale * gridWidth, scale * gridHeight)
|
||||
writer.writerow(header)
|
||||
|
||||
server = ModularServer(GameModel,
|
||||
[grid],
|
||||
"Automatyczny Wózek Widłowy",
|
||||
{"width": gridHeight, "height": gridWidth, "graph": diagram4, "graph2": diagram5},)
|
||||
for i in range(50):
|
||||
data = test.get_client_params()
|
||||
|
||||
server.port = 8888
|
||||
server.launch()
|
||||
writer.writerow([data.payment_delay,
|
||||
data.payed,
|
||||
data.net_worth,
|
||||
data.infuence_rate,
|
||||
data.is_skarbowka,
|
||||
data.membership,
|
||||
data.is_hat,
|
||||
data.company_size])
|
||||
|
||||
file.close()
|
||||
|
||||
data_input = pandas.read_csv('data/TEST/importedData.csv', delimiter=",")
|
||||
|
||||
X = data_input[['DELAY','PAYED','NET-WORTH','INFLUENCE','SKARBOWKA','MEMBER','HAT','SIZE']].values
|
||||
Y = data_input["PRIORITY"]
|
||||
|
||||
label_BP = preprocessing.LabelEncoder()
|
||||
label_BP.fit(['CompanySize.NO', 'CompanySize.SMALL', 'CompanySize.NORMAL', 'CompanySize.BIG', 'CompanySize.HUGE', 'CompanySize.GIGANTISHE'])
|
||||
X[:, 7] = label_BP.transform(X[:, 7])
|
||||
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.1, train_size=0.9)
|
||||
|
||||
drugTree = DecisionTreeClassifier(criterion="entropy", max_depth=4)
|
||||
|
||||
drugTree.fit(X_train, y_train)
|
||||
predicted = drugTree.predict(X_test)
|
||||
|
||||
print(X_test)
|
||||
print(predicted)
|
||||
|
||||
print("\nDecisionTrees's Accuracy: ", metrics.accuracy_score(y_test, predicted))
|
||||
|
||||
# base = 512
|
||||
# gridWidth = 10
|
||||
# gridHeight = 10
|
||||
# scale = base / gridWidth
|
||||
#
|
||||
# diagram4 = GridWithWeights(gridWidth, gridHeight)
|
||||
# diagram4.walls = [(6, 5), (6, 6), (6, 7), (6, 8), (2, 3), (2, 4), (3, 4), (4, 4), (6, 4)]
|
||||
#
|
||||
# diagram5 = GridWithWeights(gridWidth, gridHeight)
|
||||
# diagram5.puddles = [(2, 2), (2, 5), (2, 6), (5, 4)]
|
||||
#
|
||||
# grid = CanvasGrid(agent_portrayal, gridWidth, gridHeight, scale * gridWidth, scale * gridHeight)
|
||||
#
|
||||
# server = ModularServer(GameModel,
|
||||
# [grid],
|
||||
# "Automatyczny Wózek Widłowy",
|
||||
# {"width": gridHeight, "height": gridWidth, "graph": diagram4, "graph2": diagram5},)
|
||||
#
|
||||
# server.port = 8888
|
||||
# server.launch()
|
||||
|
@ -1,8 +1,9 @@
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
|
||||
from data.Direction import Direction
|
||||
|
||||
from data.GameConstants import GameConstants
|
||||
from data.enum.Direction import Direction
|
||||
from decision.ActionType import ActionType
|
||||
from pathfinding.PathFinderState import PathFinderState
|
||||
from pathfinding.PrioritizedItem import PrioritizedItem
|
||||
|
@ -1,4 +1,5 @@
|
||||
jupyter
|
||||
matplotlib
|
||||
mesa
|
||||
numpy
|
||||
numpy
|
||||
sklearn
|
Loading…
Reference in New Issue
Block a user