diff --git a/src/machine_learning/data_set.py b/src/machine_learning/data_set.py index a343751..a1ce059 100644 --- a/src/machine_learning/data_set.py +++ b/src/machine_learning/data_set.py @@ -1,3 +1,4 @@ +from typing import List from itertools import product visibility = ('bad', 'medium', 'good') @@ -9,7 +10,7 @@ pressure_gt_two = (True, False) attributes = ['visibility', 'stability', 'ground', 'mine_type', 'armed', 'pressure_gt_two'] -def generate_data_set(): +def generate_data_set() -> List[dict]: data_list = list(product(visibility, stability, ground, mine_type, armed, pressure_gt_two)) data = [ { diff --git a/src/machine_learning/decision_tree.py b/src/machine_learning/decision_tree.py index b8ba973..f50dd4c 100644 --- a/src/machine_learning/decision_tree.py +++ b/src/machine_learning/decision_tree.py @@ -60,7 +60,7 @@ def tree_learn(examples: List[Dict], attributes: List[str], default_class: str) return root -def get_decision(data, root): +def get_decision(data: dict, root: Node) -> str: while root.children: for children in root.children: if data[root.name] == children.edge: diff --git a/src/tile.py b/src/tile.py index 1c726f1..3c8d3e3 100644 --- a/src/tile.py +++ b/src/tile.py @@ -2,13 +2,20 @@ from typing import Union from ap_mine import APMine from at_mine import ATMine +from adm_mine import ADMMine class Tile: - def __init__(self, number: int, weight: int, visibility, stability, ground, mine: Union[None, APMine, ATMine] = None): - self.number = number - self.mine = mine - self.weight = weight - self.visibility = visibility - self.stability = stability - self.ground = ground + def __init__(self, + number: int, + weight: int, + visibility: str, + stability: str, + ground: str, + mine: Union[None, APMine, ATMine, ADMMine] = None): + self.number = number + self.mine = mine + self.weight = weight + self.visibility = visibility + self.stability = stability + self.ground = ground