2022-05-08 20:36:33 +02:00
|
|
|
from typing import List, Tuple
|
|
|
|
|
|
|
|
import pandas as pd
|
|
|
|
from sklearn.preprocessing import LabelEncoder
|
|
|
|
from sklearn.tree import DecisionTreeClassifier
|
|
|
|
|
|
|
|
from common.helpers import castle_neighbors, find_neighbours
|
|
|
|
from models.castle import Castle
|
|
|
|
from models.knight import Knight
|
|
|
|
from models.monster import Monster
|
|
|
|
|
|
|
|
|
|
|
|
def manhattan_distance(p1: Tuple[int, int], p2: Tuple[int, int]) -> int:
    """Return the taxicab (L1) distance between two 2-D grid points.

    Each point must be a pair; the result is ``|x1 - x2| + |y1 - y2|``.
    """
    (ax, ay), (bx, by) = p1, p2
    dx = abs(ax - bx)
    dy = abs(ay - by)
    return dx + dy
|
|
|
|
|
|
|
|
|
|
|
|
def parse_hp(hp: int) -> int:
    """Clamp a hit-point value so it is never reported as negative."""
    return hp if hp > 0 else 0
|
|
|
|
|
|
|
|
|
|
|
|
def parse_idx_of_opp_or_monster(s: str) -> int:
    """Convert a 1-based target label such as ``'opp3'`` or ``'mob1'`` to a 0-based index.

    The whole run of trailing ASCII digits is parsed, so multi-digit labels
    (e.g. ``'opp10'`` -> 9) work instead of reading only the last character.

    Raises:
        ValueError: if ``s`` does not end with at least one digit.
    """
    # Strip the trailing digits to find where the numeric suffix begins.
    digits = s[len(s.rstrip('0123456789')):]
    if not digits:
        raise ValueError(f'Label {s!r} has no trailing numeric index')
    return int(digits) - 1
|
|
|
|
|
|
|
|
|
|
|
|
class DecisionTree:
    """Decision-tree policy choosing which target the current knight should head for.

    A ``DecisionTreeClassifier`` (entropy criterion) is trained on a
    ``;``-delimited CSV. Its ``goal`` column holds the string label (e.g.
    ``'tower'``, ``'opp1'``, ``'mob2'``), and every other column is an input
    feature.
    """

    def __init__(self, dataset_path: str = 'learning/dataset_tree_1000.csv') -> None:
        """Load the training data and fit the classifier.

        Args:
            dataset_path: Path to the training CSV. The default is a relative
                path, so it is resolved against the current working directory.
        """
        data_frame = pd.read_csv(dataset_path, delimiter=';')

        # sklearn needs numeric targets: encode the string goals, and keep the
        # encoder so predictions can be decoded back to labels.
        self.goals_label_encoder = LabelEncoder()
        self.goals = self.goals_label_encoder.fit_transform(data_frame['goal'])

        # All non-goal columns are features. The model is fit on raw `.values`
        # (positional, no column names), so the CSV column order must match
        # the order in which get_prediction builds its feature row.
        self.train_set = data_frame.drop('goal', axis='columns')

        self.model = DecisionTreeClassifier(criterion='entropy')
        self.model.fit(self.train_set.values, self.goals)

    def predict_move(self, grid: List[List[int]], current_knight: Knight, castle: Castle, monsters: List[Monster],
                     opponents: List[Knight]) -> List[Tuple[int, int]]:
        """Predict a goal for ``current_knight`` and return the cells next to that goal.

        Args:
            grid: The game grid, forwarded to the neighbour helpers.
            current_knight: The knight being steered.
            castle: The castle (the ``'tower'`` goal).
            monsters: Monsters on the map. Exactly the first two are used as
                features, so at least two are required.
            opponents: Opposing knights. Exactly the first four are used as
                features, so at least four are required.

        Returns:
            Whatever ``castle_neighbors`` / ``find_neighbours`` return for the
            chosen goal.

        Raises:
            IndexError: if fewer than two monsters or four opponents are given.
        """
        position = current_knight.position
        distance_to_castle = manhattan_distance(position, castle.position)

        # (distance from the knight, hp clamped at 0) for each entity.
        monsters_parsed = [
            (manhattan_distance(position, monster.position), parse_hp(monster.health_bar.current_hp))
            for monster in monsters
        ]
        opponents_parsed = [
            (manhattan_distance(position, opponent.position), parse_hp(opponent.health_bar.current_hp))
            for opponent in opponents
        ]

        prediction = self.get_prediction(tower_dist=distance_to_castle, tower_hp=castle.health_bar.current_hp,
                                         mob1_dist=monsters_parsed[0][0], mob1_hp=monsters_parsed[0][1],
                                         mob2_dist=monsters_parsed[1][0], mob2_hp=monsters_parsed[1][1],
                                         opp1_dist=opponents_parsed[0][0], opp1_hp=opponents_parsed[0][1],
                                         opp2_dist=opponents_parsed[1][0], opp2_hp=opponents_parsed[1][1],
                                         opp3_dist=opponents_parsed[2][0], opp3_hp=opponents_parsed[2][1],
                                         opp4_dist=opponents_parsed[3][0], opp4_hp=opponents_parsed[3][1],
                                         agent_hp=current_knight.health_bar.current_hp)
        print(f'Prediction = {prediction}')

        if prediction == 'tower':  # castle...
            # castle.position is passed as the castle's bottom-right (row, col),
            # per castle_neighbors' keyword names.
            return castle_neighbors(grid, castle_bottom_right_row=castle.position[0],
                                    castle_bottom_right_col=castle.position[1])

        # Any non-tower label is treated as 'opp<N>' or otherwise as a monster label.
        idx = parse_idx_of_opp_or_monster(prediction)
        target = opponents[idx] if prediction.startswith('opp') else monsters[idx]
        # NOTE(review): find_neighbours receives position[1] then position[0]
        # (the reverse of castle_neighbors) - confirm its argument order.
        return find_neighbours(grid, target.position[1], target.position[0])

    def get_prediction(self, tower_dist: int, mob1_dist: int, mob2_dist: int, opp1_dist: int, opp2_dist: int,
                       opp3_dist: int, opp4_dist: int, agent_hp: int, tower_hp: int, mob1_hp: int, mob2_hp: int,
                       opp1_hp: int, opp2_hp: int, opp3_hp: int, opp4_hp: int) -> str:
        """Run the classifier on one feature row and return the decoded goal label.

        The row is built positionally in the order: all distances (tower, mob1-2,
        opp1-4), then agent_hp, then all hp values (tower, mob1-2, opp1-4). This
        order must match the non-goal column order of the training CSV.
        """
        features = [tower_dist, mob1_dist, mob2_dist, opp1_dist, opp2_dist, opp3_dist, opp4_dist, agent_hp,
                    tower_hp, mob1_hp, mob2_hp, opp1_hp, opp2_hp, opp3_hp, opp4_hp]
        encoded = self.model.predict([features])
        return self.goals_label_encoder.inverse_transform(encoded)[0]
|