This commit is contained in:
Kacper 2022-06-10 03:16:43 +02:00
parent d7c2e446a3
commit 2908d62fb3
7 changed files with 51 additions and 35 deletions

View File

@ -12,7 +12,7 @@ def getPrediction(img_path, network_name):
img = Image.open(img_path)
transform_tensor = transforms.ToTensor()(img).unsqueeze_(0)
classes = ['glass', 'metal', 'paper', 'plastic']
neural_net.load_state_dict(torch.load(PATH + network_name))
neural_net.load_state_dict(torch.load(PATH + network_name, map_location='cpu'))
neural_net.eval()
outputs = neural_net(transform_tensor)

View File

@ -4,46 +4,46 @@
| | | |--- class: 0
| | |--- feature_0 > 1.50
| | | |--- feature_3 <= 3.50
| | | | |--- feature_2 <= 2.50
| | | | |--- feature_4 <= 2.50
| | | | | |--- class: 1
| | | | |--- feature_2 > 2.50
| | | | | |--- feature_4 <= 2.50
| | | | |--- feature_4 > 2.50
| | | | | |--- feature_2 <= 2.50
| | | | | | |--- class: 1
| | | | | |--- feature_4 > 2.50
| | | | | |--- feature_2 > 2.50
| | | | | | |--- class: 0
| | | |--- feature_3 > 3.50
| | | | |--- feature_3 <= 4.50
| | | | | |--- feature_1 <= 2.50
| | | | | | |--- feature_0 <= 2.50
| | | | | | | |--- feature_1 <= 1.50
| | | | | | | | |--- feature_4 <= 2.50
| | | | | | | | |--- feature_2 <= 2.50
| | | | | | | | | |--- class: 1
| | | | | | | | |--- feature_4 > 2.50
| | | | | | | | | |--- feature_2 <= 2.00
| | | | | | | | |--- feature_2 > 2.50
| | | | | | | | | |--- feature_4 <= 2.00
| | | | | | | | | | |--- class: 1
| | | | | | | | | |--- feature_2 > 2.00
| | | | | | | | | |--- feature_4 > 2.00
| | | | | | | | | | |--- class: 0
| | | | | | | |--- feature_1 > 1.50
| | | | | | | | |--- class: 0
| | | | | | |--- feature_0 > 2.50
| | | | | | | |--- feature_2 <= 2.50
| | | | | | | |--- feature_4 <= 2.50
| | | | | | | | |--- class: 1
| | | | | | | |--- feature_2 > 2.50
| | | | | | | | |--- feature_4 <= 2.50
| | | | | | | |--- feature_4 > 2.50
| | | | | | | | |--- feature_2 <= 2.50
| | | | | | | | | |--- class: 1
| | | | | | | | |--- feature_4 > 2.50
| | | | | | | | |--- feature_2 > 2.50
| | | | | | | | | |--- class: 0
| | | | | |--- feature_1 > 2.50
| | | | | | |--- feature_1 <= 3.50
| | | | | | | |--- feature_0 <= 3.50
| | | | | | | | |--- class: 0
| | | | | | | |--- feature_0 > 3.50
| | | | | | | | |--- feature_4 <= 2.50
| | | | | | | | |--- feature_2 <= 2.50
| | | | | | | | | |--- class: 1
| | | | | | | | |--- feature_4 > 2.50
| | | | | | | | | |--- feature_2 <= 2.00
| | | | | | | | |--- feature_2 > 2.50
| | | | | | | | | |--- feature_4 <= 2.00
| | | | | | | | | | |--- class: 1
| | | | | | | | | |--- feature_2 > 2.00
| | | | | | | | | |--- feature_4 > 2.00
| | | | | | | | | | |--- class: 0
| | | | | | |--- feature_1 > 3.50
| | | | | | | |--- class: 0
@ -76,13 +76,13 @@
| |--- feature_4 <= 1.50
| | |--- feature_1 <= 1.50
| | | |--- feature_2 <= 4.50
| | | | |--- feature_3 <= 4.50
| | | | | |--- feature_0 <= 1.50
| | | | | | |--- class: 0
| | | | | |--- feature_0 > 1.50
| | | | | | |--- class: 1
| | | | |--- feature_3 > 4.50
| | | | |--- feature_0 <= 1.50
| | | | | |--- class: 0
| | | | |--- feature_0 > 1.50
| | | | | |--- feature_3 <= 4.50
| | | | | | |--- class: 1
| | | | | |--- feature_3 > 4.50
| | | | | | |--- class: 0
| | | |--- feature_2 > 4.50
| | | | |--- class: 0
| | |--- feature_1 > 1.50

Binary file not shown.

Binary file not shown.

34
main.py
View File

@ -101,10 +101,10 @@ class Game():
atrrs_container = i.get_attributes()
x, y = i.get_coords()
dec = decisionTree.decision(getTree(), *atrrs_container)
if dec[0] == 1:
self.positive_decision.append(i)
else:
self.negative_decision.append(i)
# if dec[0] == 1:
self.positive_decision.append(i) # zmiana po to by losowało wszystkie smietniki a nie poprawne tylko, zeby ladniej bylo widac algorytm genetyczny
# else:
# self.negative_decision.append(i)
print('positive actions')
print(len(self.positive_decision))
@ -114,12 +114,26 @@ class Game():
# print(i)
# print('----')
def decsion_tree_move(self):
for i in self.positive_decision:
for i in range(0,len(self.positive_decision)):
# print(i.get_coords())
print('action')
trash_x, trash_y = i.get_coords()
# trash_x, trash_y = i.get_coords()
# for ii in self.tsp_list:
temp_tsp = str(self.tsp_list[i])
temp_tsp = temp_tsp.strip("()")
temp_tsp = temp_tsp.split(",")
trash_x = int(temp_tsp[0])
trash_y = int(temp_tsp[1])
print(trash_x, trash_y)
action = a_star_controller.get_actions_for_target_coords(trash_x, trash_y, self)
print(action)
self.t.startAiController(action)
@ -127,7 +141,7 @@ class Game():
print('--rozpoczecie sortowania smietnika--')
dir = "./resources/trash_dataset/test/all"
files = os.listdir(dir)
for i in range(0, 10):
for j in range(0, 10):
random = randint(0, 48)
file = files[random]
result = prediction.getPrediction(dir + '/' + file, 'trained_nn_20.pth')
@ -155,8 +169,8 @@ class Game():
# dist = a_star.get_cost
tsp_list = TSP.geneticAlgorithmPlot(population=city_list, popSize=100, eliteSize=20, mutationRate=0.01, generations=200)
print(tsp_list)
self.tsp_list = TSP.geneticAlgorithmPlot(population=city_list, popSize=100, eliteSize=20, mutationRate=0.01, generations=200)
print(self.tsp_list)
def load_data(self):
game_folder = os.path.dirname(__file__)

View File

@ -21,7 +21,7 @@ def generate_map():
map[y][x] = 1
# generowanie smietnikow
for i in range(0, 30):
for i in range(0, 20):
x = random.randint(0, MAP_WIDTH-1)
y = random.randint(0, MAP_HEIGHT-1)
map[y][x] = 2

View File

@ -84,7 +84,9 @@ def get_rotate_change(rotationA: Rotation, rotationB: Rotation) -> int:
return int(rotationA) - int(rotationB)
# get new rotation for target_node as neighbour of start_node
def get_needed_rotation(start_node: Node, target_node: Node) -> Rotation:
def get_needed_rotation(start_node: Node or bool, target_node: Node) -> Rotation:
if(start_node == False):
return target_node.rotation
if (start_node.x - target_node.x > 0):
return Rotation.LEFT
if (start_node.x - target_node.x < 0):