AI-Project/survival/ai/decision_tree/decision_tree_data.py
2021-06-19 18:04:59 +02:00

125 lines
5.1 KiB
Python

import random
from typing import Dict
from survival.ai.decision_tree.decision_tree import DecisionTree
from survival.generators.resource_type import ResourceType
class TreeDataGenerator:
    """Generate random training samples for the resource-choice decision tree.

    Every sample describes three resources (food, water, wood) by inventory
    level, visibility, distance and visible count, carries a computed score
    per resource, and is labelled with the name of the resource chosen by
    ``get_best_resource``.
    """

    # Inclusive bounds of the randomly drawn inventory level.
    INV_RANGE = (1, 100)
    # Possible visibility states of a resource.
    VISIBLE = (True, False)
    # Inclusive bounds of the distance drawn for a visible resource.
    DISTANCE_RANGE = (3, 7)
    # Weight of the distance term in the score equation.
    DISTANCE_FACTOR = 0.2
    # Possible unit counts of a visible resource.
    COUNT = (1, 2, 3)

    def generate(self, count=1000):
        """Create ``count`` samples, write them to disk and return them.

        :param count: number of samples to generate.
        :return: list of sample dictionaries (keys are described in ``process``).
        """
        full_data = []
        self.process(count, full_data)
        self.write_data_to_file(full_data)
        return full_data

    def process(self, count, full_data):
        """Append ``count`` freshly generated samples to ``full_data``.

        Each sample holds, for every prefix in ``food``/``water``/``wood``, the
        keys ``<prefix>_inv``, ``<prefix>_visible`` (``"True"``/``"False"``),
        ``<prefix>_distance``, ``<prefix>_count`` and ``<prefix>_result``,
        followed by ``result``: the lower-cased name of the selected resource.

        :param count: number of samples to append.
        :param full_data: list that is extended in place.
        """
        # The order of this table fixes the key order of every sample,
        # and therefore the layout of each line in the output file.
        prefixes = (
            ("food", ResourceType.FOOD),
            ("water", ResourceType.WATER),
            ("wood", ResourceType.WOOD),
        )
        for _ in range(count):
            # Create resource data for each resource type.
            package = {resource: self.create_resource_data() for resource in ResourceType}
            # Pick the label for this sample. The choice is made on the
            # features (inventory first), not on the computed result;
            # see get_best_resource.
            best_resource = self.get_best_resource(package)

            sample = {}
            for prefix, resource in prefixes:
                inventory, is_visible, distance, cnt, result = package[resource]
                sample[f"{prefix}_inv"] = inventory
                # Stored as a string here; write_data_to_file turns it into
                # a JSON boolean.
                sample[f"{prefix}_visible"] = str(is_visible)
                sample[f"{prefix}_distance"] = distance
                sample[f"{prefix}_count"] = cnt
                sample[f"{prefix}_result"] = result
            sample["result"] = best_resource.name.lower()
            full_data.append(sample)

    @staticmethod
    def write_data_to_file(full_data):
        """Write the samples to ``tree_data.json``, one JSON object per line.

        The file is created in (or overwritten in) the current working
        directory. String values ``"True"``/``"False"`` are emitted as the JSON
        booleans ``true``/``false``.

        :param full_data: iterable of JSON-serialisable sample dictionaries.
        :raises TypeError: if a sample contains a value ``json`` cannot encode.
        """
        # Imported here so the method does not depend on a module-level import.
        import json

        print("Writing to file...")
        with open('tree_data.json', 'w', encoding='utf-8') as f:
            for data in full_data:
                # json.dumps produces valid JSON (proper quoting/escaping)
                # instead of patching the quotes of the dict's repr.
                line = json.dumps(data).replace('"False"', 'false').replace('"True"', 'true')
                f.write(line + '\n')
        print("Success!")

    def create_resource_data(self):
        """Draw random features for one resource and compute its score.

        An invisible resource gets a count and distance of ``0`` and fixed
        substitute factors (``0.9`` and ``0.5``) in the score equation.

        :return: ``[inventory, is_visible, distance, count, result]``.
        """
        # The order of the random draws is kept stable so that seeding the
        # ``random`` module reproduces the same data set.
        is_visible = random.choice(self.VISIBLE)
        inventory = random.randint(min(self.INV_RANGE), max(self.INV_RANGE))
        if is_visible:
            cnt = random.choice(self.COUNT)
            distance = random.randint(min(self.DISTANCE_RANGE), max(self.DISTANCE_RANGE))
            need_factor = cnt
            distance_score = max(self.DISTANCE_RANGE) / distance
        else:
            cnt = 0
            distance = 0
            need_factor = 0.9
            distance_score = 0.5
        # Equation determining the results processed by decision tree:
        # low inventory, high count and short distance all raise the score.
        result = (max(self.INV_RANGE) / inventory) * need_factor + distance_score * self.DISTANCE_FACTOR
        return [inventory, is_visible, distance, cnt, result]

    @staticmethod
    def get_best_resource(package: Dict) -> "ResourceType":
        """Return the key whose feature list (all items but the last) is smallest.

        Feature lists are compared lexicographically, so the resource with the
        lowest inventory wins and later features only break ties. On a full tie
        the first key in iteration order is kept.

        :param package: mapping of resource to
            ``[inventory, is_visible, distance, count, result]``.
        :return: the selected key, or ``None`` when ``package`` is empty.
        """
        # NOTE(review): the trailing ``result`` score is deliberately sliced
        # off and never compared. Confirm that selecting by lowest inventory
        # (rather than by highest result) is the intended labelling rule.
        return min(package, key=lambda resource: package[resource][:-1], default=None)

    @staticmethod
    def print_data(full_data):
        """Print every sample as a small right-aligned table on stdout.

        :param full_data: iterable of sample dictionaries as built by ``process``.
        """
        for data in full_data:
            rows = (
                ["Data", "Apple", "Water", "Wood"],
                ["Inventory", data["food_inv"], data["water_inv"], data["wood_inv"]],
                ["Visible", data["food_visible"], data["water_visible"], data["wood_visible"]],
                ["Distance", data["food_distance"], data["water_distance"], data["wood_distance"]],
                ["Count", data["food_count"], data["water_count"], data["wood_count"]],
                ["Result", round(data["food_result"], 3), round(data["water_result"], 3),
                 round(data["wood_result"], 3)],
            )
            for row in rows:
                print(TreeDataGenerator.format_words(row))
            print(f'Best resource: {data["result"]}')
            print('--------------------------------------------------------------')

    @staticmethod
    def format_words(words):
        """Format the first four items of ``words`` as right-aligned 12-char columns.

        :param words: indexable sequence with at least four items.
        :return: the formatted table row.
        """
        return '{:>12} {:>12} {:>12} {:>12}'.format(words[0], words[1], words[2], words[3])
# Train tree
# NOTE(review): these statements run at module import time, so importing this
# module regenerates the data set. Consider an `if __name__ == "__main__":`
# guard if the module is ever imported from elsewhere.
generator = TreeDataGenerator()
# Generate 50000 samples; generate() also writes them to 'tree_data.json'
# in the current working directory.
data = generator.generate(50000)
# Print every generated sample to stdout as a small table.
generator.print_data(data)
tree = DecisionTree()
# DecisionTree comes from survival.ai.decision_tree.decision_tree and is not
# visible here; presumably build() trains on the generated file and 1000 is a
# size/limit argument — TODO confirm against that module.
tree.build(1000)
tree.plot_tree()
# Save the model using the two given file names.
tree.save_model('classifier.joblib', 'vectorizer.joblib')
# ----------------------------------------------------------- #
# Use trained tree
# tree = DecisionTree()
# tree.load_model('classifier.joblib', 'vectorizer.joblib')
#
# answ = tree.predict_answer({'food_inv': 40, 'water_inv': 10, 'wood_inv': 20,
# 'food_distance': 2, 'water_distance': -1, 'wood_distance': 4,
# 'food_visible': True, 'water_visible': False, 'wood_visible': True,
# 'food_count': 1, 'water_count': 1, 'wood_count': 1})
# print(answ)