import json
import random
from typing import Dict, List, Optional

from survival.ai.decision_tree.decision_tree import DecisionTree
from survival.generators.resource_type import ResourceType


class TreeDataGenerator:
    """Generate random labelled samples for training the resource decision tree.

    Each sample describes three resources (food, water, wood) by inventory
    level, visibility, distance and count, together with a computed score
    per resource. The sample's label (``result``) is the lower-cased name of
    the resource with the highest score.
    """

    # Inclusive bounds of the randomly drawn inventory level.
    INV_RANGE = (1, 100)
    # Possible visibility states of a resource.
    VISIBLE = (True, False)
    # Inclusive bounds of the randomly drawn distance to a visible resource.
    DISTANCE_RANGE = (3, 7)
    # Weight of the distance term in the score.
    DISTANCE_FACTOR = 0.2
    # Possible amounts of a visible resource.
    COUNT = (1, 2, 3)

    def generate(self, count=1000):
        """Create ``count`` samples, write them to disk and return them.

        :param count: number of samples to generate.
        :return: list of sample dictionaries (see :meth:`process`).
        """
        full_data = []
        self.process(count, full_data)
        self.write_data_to_file(full_data)
        return full_data

    def process(self, count, full_data):
        """Append ``count`` freshly generated samples to ``full_data``.

        :param count: number of samples to generate.
        :param full_data: list that is extended in place with the samples.
        """
        for _ in range(count):
            # Create resource data for each resource type (enum iteration order).
            package = {resource: self.create_resource_data() for resource in ResourceType}

            # Get the resource with the highest result among all generated resources.
            best_resource = self.get_best_resource(package)

            # Flatten the per-resource lists into one record; key order is
            # food_*, water_*, wood_*, then the label.
            data = {}
            for prefix, resource in (
                ("food", ResourceType.FOOD),
                ("water", ResourceType.WATER),
                ("wood", ResourceType.WOOD),
            ):
                inventory, visible, distance, cnt, result = package[resource]
                data[f"{prefix}_inv"] = inventory
                # Visibility is stored as the string "True"/"False".
                data[f"{prefix}_visible"] = str(visible)
                data[f"{prefix}_distance"] = distance
                data[f"{prefix}_count"] = cnt
                data[f"{prefix}_result"] = result
            data["result"] = best_resource.name.lower()

            full_data.append(data)

    @staticmethod
    def write_data_to_file(full_data):
        """Write every sample as one JSON object per line to ``tree_data.json``.

        String values ``"True"`` / ``"False"`` are emitted as JSON booleans
        ``true`` / ``false``; all other values are serialized unchanged.

        :param full_data: iterable of sample dictionaries.
        """
        print("Writing to file...")

        # The samples carry visibility as strings; the file stores real booleans.
        as_bool = {"True": True, "False": False}

        def to_json_line(sample):
            record = {
                key: as_bool.get(value, value) if isinstance(value, str) else value
                for key, value in sample.items()
            }
            return json.dumps(record) + "\n"

        # Open the target file and write all samples to it, one per line.
        with open('tree_data.json', 'w', encoding='utf-8') as f:
            f.writelines(to_json_line(sample) for sample in full_data)
        print("Success!")

    def create_resource_data(self):
        """Draw random attributes for one resource and compute its score.

        An invisible resource has count 0 and distance 0 and uses fixed
        fallback terms (0.9 and 0.5) in the score instead of count/distance.

        :return: ``[inventory, is_visible, distance, cnt, result]``.
        """
        is_visible = random.choice(self.VISIBLE)
        inventory = random.randint(min(self.INV_RANGE), max(self.INV_RANGE))

        if is_visible:
            cnt = random.choice(self.COUNT)
            distance = random.randint(min(self.DISTANCE_RANGE), max(self.DISTANCE_RANGE))
            amount_term = 1 * cnt
            # Safe division: distance is drawn from DISTANCE_RANGE, never 0 here.
            distance_term = max(self.DISTANCE_RANGE) / distance
        else:
            cnt = 0
            distance = 0
            amount_term = 0.9
            distance_term = 0.5

        # Equation determining the results processed by the decision tree:
        # a low inventory and a large visible amount raise the score, and a
        # short distance adds a smaller, DISTANCE_FACTOR-weighted bonus.
        result = (self.INV_RANGE[1] / inventory) * amount_term + distance_term * self.DISTANCE_FACTOR

        return [inventory, is_visible, distance, cnt, result]

    @staticmethod
    def get_best_resource(package: Dict) -> Optional[ResourceType]:
        """Return the resource whose data has the highest result score.

        The score is the last element of each resource's data list. On ties
        the resource that appears first in ``package`` wins.

        :param package: mapping of resource -> ``[..., result]`` list.
        :return: the best resource, or ``None`` when ``package`` is empty.
        """
        return max(package, key=lambda resource: package[resource][-1], default=None)

    @staticmethod
    def print_data(full_data):
        """Print every sample as a small right-aligned table.

        :param full_data: iterable of sample dictionaries as built by
            :meth:`process`.
        """
        fmt = TreeDataGenerator.format_words
        for data in full_data:
            rows: List[list] = [
                ["Data", "Apple", "Water", "Wood"],
                ["Inventory", data["food_inv"], data["water_inv"], data["wood_inv"]],
                ["Visible", data["food_visible"], data["water_visible"], data["wood_visible"]],
                ["Distance", data["food_distance"], data["water_distance"], data["wood_distance"]],
                ["Count", data["food_count"], data["water_count"], data["wood_count"]],
                ["Result", round(data["food_result"], 3), round(data["water_result"], 3),
                 round(data["wood_result"], 3)],
            ]
            for row in rows:
                print(fmt(row))
            print(f'Best resource: {data["result"]}')
            print('--------------------------------------------------------------')

    @staticmethod
    def format_words(words):
        """Format the first four items of ``words`` as 12-wide right-aligned columns.

        :param words: sequence with at least four items.
        :return: the formatted line.
        """
        return '{:>12} {:>12} {:>12} {:>12}'.format(words[0], words[1], words[2], words[3])


# Train tree
# NOTE: these statements run at module level, i.e. also whenever this module
# is imported.
generator = TreeDataGenerator()
data = generator.generate(50000)
generator.print_data(data)
tree = DecisionTree()
# NOTE(review): presumably build() consumes the tree_data.json written above - confirm.
tree.build(1000)
tree.plot_tree()
tree.save_model('classifier.joblib', 'vectorizer.joblib')

# ----------------------------------------------------------- #
# Use trained tree
# tree = DecisionTree()
# tree.load_model('classifier.joblib', 'vectorizer.joblib')
#
# answ = tree.predict_answer({'food_inv': 40, 'water_inv': 10, 'wood_inv': 20,
#                             'food_distance': 2, 'water_distance': -1, 'wood_distance': 4,
#                             'food_visible': True, 'water_visible': False, 'wood_visible': True,
#                             'food_count': 1, 'water_count': 1, 'wood_count': 1})
# print(answ)