133 lines
4.6 KiB
Python
133 lines
4.6 KiB
Python
import csv
|
|
import random
|
|
|
|
import pandas
|
|
|
|
from mesa.visualization.ModularVisualization import ModularServer
|
|
from mesa.visualization.modules import CanvasGrid
|
|
from sklearn import metrics, preprocessing
|
|
from sklearn.model_selection import train_test_split
|
|
from sklearn.tree import DecisionTreeClassifier
|
|
|
|
from ClientParamsFactory import ClientParamsFactory
|
|
from ForkliftAgent import ForkliftAgent
|
|
from PatchAgent import PatchAgent
|
|
from PatchType import PatchType
|
|
from data.enum.CompanySize import CompanySize
|
|
from data.enum.Direction import Direction
|
|
from data.enum.Priority import Priority
|
|
|
|
# Drawing palette for the grid. It is six CSS colour names followed by ten
# hex codes (16 entries in total). agent_portrayal reads entries from it by
# index, so keep the order stable.
colors = (
    ['blue', 'cyan', 'orange', 'yellow', 'magenta', 'purple']
    + ['#103d3e', '#9fc86c', '#b4c2ed', '#31767d', '#31a5fa',
       '#ba96e0', '#fef3e4', '#6237ac', '#f9cacd', '#1e8123']
)
|
|
|
|
|
|
def agent_portrayal(agent):
    """Return the Mesa portrayal dict used to draw ``agent`` on the grid.

    * ``ForkliftAgent``: an image chosen by ``agent.current_rotation``.
      The ``"Shape"`` is empty if the rotation is not one of top/right/down/left.
    * ``PatchAgent``: an image chosen by ``agent.patch_type`` for wall,
      drop-off, pick-up and different-terrain patches. Any other patch type
      is drawn as a filled 1x1 square. Its colour is picked at random from
      ``colors[3:]`` on every call, so it can change between renders.
    * Any other object: an empty dict. Previously this case raised
      ``UnboundLocalError``.
      NOTE(review): Mesa's CanvasGrid is expected to skip falsy
      portrayals; confirm for the pinned Mesa version.

    If an object were both a ForkliftAgent and a PatchAgent, the patch
    portrayal wins, because it is assigned last.
    """
    portrayal = {}

    if isinstance(agent, ForkliftAgent):
        rotation_images = {
            Direction.top: "img/image_top.png",
            Direction.right: "img/image_right.png",
            Direction.down: "img/image_down.png",
            Direction.left: "img/image_left.png",
        }
        shape = rotation_images.get(agent.current_rotation, "")
        portrayal = {"Shape": shape, "scale": 1.0, "Layer": 0}

    if isinstance(agent, PatchAgent):
        patch_images = {
            PatchType.wall: "img/brick.webp",
            PatchType.dropOff: "img/truck.png",
            PatchType.pickUp: "img/okB00mer.png",
            PatchType.diffTerrain: "img/puddle.png",
        }
        image = patch_images.get(agent.patch_type)
        if image is not None:
            portrayal = {"Shape": image, "scale": 1.0, "Layer": 0}
        else:
            # Picks from every palette entry except the first three.
            portrayal = {"Shape": "rect",
                         "Filled": "true",
                         "Layer": 0,
                         "Color": random.choice(colors[3:]),
                         "w": 1,
                         "h": 1}

    return portrayal
|
|
|
|
|
|
# Columns of the client CSV files, in order. The same list is used as the
# header of the generated file and as the model's feature selection.
FEATURE_COLUMNS = ['DELAY',
                   'PAYED',
                   'NET-WORTH',
                   'INFLUENCE',
                   'SKARBOWKA',
                   'MEMBER',
                   'HAT',
                   'SIZE']

# SIZE-column labels in rank order. Judging by the names, this runs from
# smallest to largest company.
# NOTE(review): confirm against the CompanySize enum.
COMPANY_SIZE_ORDER = ['CompanySize.NO',
                      'CompanySize.SMALL',
                      'CompanySize.NORMAL',
                      'CompanySize.BIG',
                      'CompanySize.HUGE',
                      'CompanySize.GIGANTISHE']


def encode_company_size(sizes):
    """Map company-size labels to their rank in ``COMPANY_SIZE_ORDER``.

    An explicit mapping is used instead of ``LabelEncoder``. LabelEncoder
    sorts its classes alphabetically, which would scramble the size order
    that a decision tree can otherwise split on.

    Args:
        sizes: pandas Series of labels such as ``'CompanySize.SMALL'``.

    Returns:
        pandas Series of ints with the same index as ``sizes``.

    Raises:
        ValueError: if any label is not listed in ``COMPANY_SIZE_ORDER``.
    """
    rank = {label: position for position, label in enumerate(COMPANY_SIZE_ORDER)}
    encoded = sizes.map(rank)
    unknown = sizes[encoded.isna()]
    if not unknown.empty:
        raise ValueError(
            f"Unknown company size label(s): {sorted(set(map(str, unknown)))}")
    return encoded.astype(int)


def generate_client_data(path, count=200):
    """Write ``count`` generated client rows to the CSV file at ``path``.

    The header is ``FEATURE_COLUMNS``. Each row comes from
    ``ClientParamsFactory().get_client_params()``. No PRIORITY label column
    is written.
    """
    factory = ClientParamsFactory()
    with open(path, 'w', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        writer.writerow(FEATURE_COLUMNS)
        for _ in range(count):
            params = factory.get_client_params()
            # ``infuence_rate`` is the attribute's actual (misspelt) name.
            writer.writerow([params.payment_delay,
                             params.payed,
                             params.net_worth,
                             params.infuence_rate,
                             params.is_skarbowka,
                             params.membership,
                             params.is_hat,
                             params.company_size])


def train_priority_tree(path, random_state=None):
    """Train and evaluate a decision tree that predicts PRIORITY.

    Reads the labelled CSV at ``path``, ordinal-encodes SIZE, holds out 10%
    of the rows for testing, and prints the test features, the predictions
    and the accuracy.

    Args:
        path: CSV file with the ``FEATURE_COLUMNS`` plus a ``PRIORITY`` column.
        random_state: optional seed for the split and the tree, so that runs
            can be reproduced. ``None`` keeps the old non-deterministic behaviour.

    Returns:
        (fitted DecisionTreeClassifier, accuracy on the held-out rows)
    """
    data_input = pandas.read_csv(path, delimiter=",")

    features = data_input[FEATURE_COLUMNS].copy()
    features['SIZE'] = encode_company_size(features['SIZE'])
    X = features.values
    y = data_input["PRIORITY"]

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.1, random_state=random_state)

    priority_tree = DecisionTreeClassifier(
        criterion="entropy", max_depth=4, random_state=random_state)
    priority_tree.fit(X_train, y_train)
    predicted = priority_tree.predict(X_test)

    print(X_test)
    print(predicted)

    accuracy = metrics.accuracy_score(y_test, predicted)
    print("\nDecisionTrees's Accuracy: ", accuracy)
    return priority_tree, accuracy


if __name__ == '__main__':
    generate_client_data("data/TEST/generatedData.csv")

    # Training reads a different, labelled file: the generated one has no
    # PRIORITY column.
    train_priority_tree('data/TEST/importedData.csv')

    # Grid visualization, currently disabled.
    # NOTE(review): GameModel and GridWithWeights are not imported above.
    # base = 512
    # gridWidth = 10
    # gridHeight = 10
    # scale = base / gridWidth
    #
    # diagram4 = GridWithWeights(gridWidth, gridHeight)
    # diagram4.walls = [(6, 5), (6, 6), (6, 7), (6, 8), (2, 3), (2, 4), (3, 4), (4, 4), (6, 4)]
    #
    # diagram5 = GridWithWeights(gridWidth, gridHeight)
    # diagram5.puddles = [(2, 2), (2, 5), (2, 6), (5, 4)]
    #
    # grid = CanvasGrid(agent_portrayal, gridWidth, gridHeight, scale * gridWidth, scale * gridHeight)
    #
    # server = ModularServer(GameModel,
    #                        [grid],
    #                        "Automatyczny Wózek Widłowy",
    #                        {"width": gridHeight, "height": gridWidth, "graph": diagram4, "graph2": diagram5},)
    #
    # server.port = 8888
    # server.launch()
|