WMICraft/learning/decision_tree.py

79 lines
3.8 KiB
Python
Raw Permalink Normal View History

2022-05-08 20:36:33 +02:00
from typing import List, Tuple
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.tree import DecisionTreeClassifier
from common.helpers import castle_neighbors, find_neighbours
from models.castle import Castle
from models.knight import Knight
from models.monster import Monster
def manhattan_distance(p1: Tuple[int, int], p2: Tuple[int, int]) -> int:
    """Return the Manhattan (taxicab) distance between two 2-D points.

    Each point is a pair of coordinates; the result is the sum of the
    absolute coordinate-wise differences.
    """
    (a_first, a_second), (b_first, b_second) = p1, p2
    first_gap = abs(a_first - b_first)
    second_gap = abs(a_second - b_second)
    return first_gap + second_gap
def parse_hp(hp: int) -> int:
    """Clamp a hit-point value so that it is never negative.

    Positive values are returned unchanged; anything else becomes 0.
    """
    return hp if hp > 0 else 0
def parse_idx_of_opp_or_monster(s: str) -> int:
    """Convert a 1-based goal label such as ``'opp3'`` into a 0-based index.

    The whole run of trailing decimal digits is read as the number, so
    labels whose index is ten or more (e.g. ``'opp10'``) are parsed
    correctly instead of only looking at the final character.

    Raises:
        ValueError: if ``s`` does not end with a decimal digit.
    """
    # Whatever rstrip() removes from the end is exactly the numeric suffix.
    prefix_len = len(s.rstrip('0123456789'))
    return int(s[prefix_len:]) - 1
class DecisionTree:
    """Decision-tree classifier that chooses the next goal for a knight.

    The tree is fitted once, at construction time, on a CSV dataset.
    ``predict_move`` then converts the current game state into one feature
    row, asks the tree for a goal label and delegates to the neighbour
    helpers to obtain the cells to move towards.
    """

    def __init__(self) -> None:
        """Load the training data and fit the classifier.

        The ``goal`` column of the dataset is the label; every remaining
        column is used as a feature.
        """
        dataset = pd.read_csv('learning/dataset_tree_1000.csv', sep=';')

        # The encoder is kept so that numeric predictions can be decoded
        # back into goal labels in get_prediction().
        self.goals_label_encoder = LabelEncoder()
        self.goals = self.goals_label_encoder.fit_transform(dataset['goal'])

        self.train_set = dataset.drop(columns='goal')
        self.model = DecisionTreeClassifier(criterion='entropy')
        self.model.fit(self.train_set.values, self.goals)

    def predict_move(self, grid: List[List[int]], current_knight: Knight, castle: Castle, monsters: List[Monster],
                     opponents: List[Knight]) -> List[Tuple[int, int]]:
        """Predict the knight's next goal and return the cells adjacent to it.

        Args:
            grid: game grid, passed through to the neighbour helpers.
            current_knight: the knight the decision is made for.
            castle: the castle ("tower" in the model's vocabulary).
            monsters: monsters on the map; the first two are used.
            opponents: opposing knights; the first four are used.

        Returns:
            The result of ``castle_neighbors`` when the predicted goal is
            ``'tower'``, otherwise the result of ``find_neighbours`` for the
            predicted opponent or monster.

        Raises:
            IndexError: if fewer than two monsters or fewer than four
                opponents are supplied.
        """
        origin = current_knight.position
        distance_to_castle = manhattan_distance(origin, castle.position)

        # (distance from the knight, HP clamped at 0) for every monster / opponent.
        monster_stats = [
            (manhattan_distance(origin, monster.position), parse_hp(monster.health_bar.current_hp))
            for monster in monsters
        ]
        opponent_stats = [
            (manhattan_distance(origin, opponent.position), parse_hp(opponent.health_bar.current_hp))
            for opponent in opponents
        ]

        features = {'tower_dist': distance_to_castle, 'tower_hp': castle.health_bar.current_hp}
        # get_prediction() has feature slots for exactly two monsters and four opponents.
        for slot in range(2):
            dist, hp = monster_stats[slot]
            features[f'mob{slot + 1}_dist'] = dist
            features[f'mob{slot + 1}_hp'] = hp
        for slot in range(4):
            dist, hp = opponent_stats[slot]
            features[f'opp{slot + 1}_dist'] = dist
            features[f'opp{slot + 1}_hp'] = hp
        features['agent_hp'] = current_knight.health_bar.current_hp

        prediction = self.get_prediction(**features)
        print(f'Prediction = {prediction}')

        if prediction == 'tower':  # the model calls the castle "tower"
            return castle_neighbors(grid, castle_bottom_right_row=castle.position[0],
                                    castle_bottom_right_col=castle.position[1])

        # Any label that is neither 'tower' nor 'opp…' is treated as a monster.
        candidates = opponents if prediction.startswith('opp') else monsters
        target = candidates[parse_idx_of_opp_or_monster(prediction)]
        # NOTE(review): position is passed as [1] then [0] here, the reverse of the
        # castle call above — presumably find_neighbours expects (col, row); confirm.
        return find_neighbours(grid, target.position[1], target.position[0])

    def get_prediction(self, tower_dist: int, mob1_dist: int, mob2_dist: int, opp1_dist: int, opp2_dist: int,
                       opp3_dist: int, opp4_dist: int, agent_hp: int, tower_hp: int, mob1_hp: int, mob2_hp: int,
                       opp1_hp: int, opp2_hp: int, opp3_hp: int, opp4_hp: int) -> str:
        """Run the fitted tree on one feature row and return the decoded goal label.

        The arguments are the distances to, and hit points of, the tower,
        two monsters and four opponents, plus the agent's own hit points.
        """
        # NOTE(review): the order below presumably mirrors the feature-column
        # order of the training CSV — verify against the dataset.
        feature_row = [
            tower_dist, mob1_dist, mob2_dist, opp1_dist, opp2_dist, opp3_dist, opp4_dist, agent_hp,
            tower_hp, mob1_hp, mob2_hp, opp1_hp, opp2_hp, opp3_hp, opp4_hp,
        ]
        encoded = self.model.predict([feature_row])
        decoded = self.goals_label_encoder.inverse_transform(encoded)
        return decoded[0]