diff --git a/learning/decision_tree.py b/learning/decision_tree.py index 7f92344..57ef4b2 100644 --- a/learning/decision_tree.py +++ b/learning/decision_tree.py @@ -34,7 +34,7 @@ class DecisionTree: self.model = DecisionTreeClassifier(criterion='entropy') self.model.fit(self.train_set.values, self.goals) - def predict_move(self, grid: List[List[str]], current_knight: Knight, castle: Castle, monsters: List[Monster], + def predict_move(self, grid: List[List[int]], current_knight: Knight, castle: Castle, monsters: List[Monster], opponents: List[Knight]) -> \ List[Tuple[int, int]]: distance_to_castle = manhattan_distance(current_knight.position, castle.position) @@ -42,14 +42,15 @@ class DecisionTree: monsters_parsed = [] for monster in monsters: monsters_parsed.append((manhattan_distance(current_knight.position, monster.position), parse_hp( - monster.current_hp))) + monster.health_bar.current_hp))) opponents_parsed = [] for opponent in opponents: opponents_parsed.append( - (manhattan_distance(current_knight.position, opponent.position), parse_hp(opponent.health_bar.current_hp))) + (manhattan_distance(current_knight.position, opponent.position), + parse_hp(opponent.health_bar.current_hp))) - prediction = self.get_prediction(tower_dist=distance_to_castle, tower_hp=castle.current_hp, + prediction = self.get_prediction(tower_dist=distance_to_castle, tower_hp=castle.health_bar.current_hp, mob1_dist=monsters_parsed[0][0], mob1_hp=monsters_parsed[0][1], mob2_dist=monsters_parsed[1][0], mob2_hp=monsters_parsed[1][1], opp1_dist=opponents_parsed[0][0], opp1_hp=opponents_parsed[0][1], @@ -57,7 +58,7 @@ class DecisionTree: opp3_dist=opponents_parsed[2][0], opp3_hp=opponents_parsed[2][1], opp4_dist=opponents_parsed[3][0], opp4_hp=opponents_parsed[3][1], agent_hp=current_knight.health_bar.current_hp) - print(prediction) + print(f'Prediction = {prediction}') if prediction == 'tower': # castle... return castle_neighbors(grid, castle_bottom_right_row=castle.position[0], castle_bottom_right_col=castle.position[1]) diff --git a/logic/level.py b/logic/level.py index 0ea19a8..9f43721 100644 --- a/logic/level.py +++ b/logic/level.py @@ -100,8 +100,8 @@ class Level: goal_list = self.decision_tree.predict_move(grid=self.map, current_knight=current_knight, monsters=self.list_monsters, - opponents=self.list_knights_red - if current_knight.team_alias() == 'k_r' else self.list_knights_blue, + opponents=self.list_knights_blue + if current_knight.team_alias() == 'k_r' else self.list_knights_red, castle=self.list_castles[0]) if len(goal_list) == 0: