125 lines
5.1 KiB
Python
125 lines
5.1 KiB
Python
|
import json
import random
from typing import Dict

from survival.ai.decision_tree.decision_tree import DecisionTree
from survival.generators.resource_type import ResourceType
|
||
|
|
||
|
|
||
|
class TreeDataGenerator:
    """Generate random, labelled samples for training the decision tree.

    For every sample each resource type gets a random inventory level,
    visibility flag, distance and count, plus a result computed from those
    values. The sample is labelled with the resource whose result is the
    highest.
    """

    # Bounds (inclusive) of the randomly drawn inventory level.
    INV_RANGE = (1, 100)
    # Possible visibility states of a resource.
    VISIBLE = (True, False)
    # Bounds (inclusive) of the randomly drawn distance to a visible resource.
    DISTANCE_RANGE = (3, 7)
    # Weight of the distance term in the result equation.
    DISTANCE_FACTOR = 0.2
    # Possible counts of a visible resource.
    COUNT = (1, 2, 3)

    def generate(self, count=1000):
        """Create ``count`` samples, write them to a file and return them.

        :param count: number of samples to generate.
        :return: list of sample dictionaries as built by :meth:`process`.
        """
        full_data = []
        self.process(count, full_data)

        self.write_data_to_file(full_data)
        return full_data

    def process(self, count, full_data):
        """Append ``count`` freshly generated samples to ``full_data``.

        :param count: number of samples to generate.
        :param full_data: list that receives the sample dictionaries.
        """
        # Key prefix used in the sample for each resource, in output order.
        prefixes = (
            ('food', ResourceType.FOOD),
            ('water', ResourceType.WATER),
            ('wood', ResourceType.WOOD),
        )
        for _ in range(count):
            # Create resource data for each resource type.
            package = {resource: self.create_resource_data() for resource in ResourceType}

            # Get the resource with highest result among all generated resource types.
            best_resource = self.get_best_resource(package)

            # Create dictionary filled with data.
            sample = {}
            for prefix, resource in prefixes:
                sample.update(self._sample_fields(prefix, package[resource]))
            sample['result'] = best_resource.name.lower()
            full_data.append(sample)

    @staticmethod
    def _sample_fields(prefix, resource_data):
        """Map one resource's data list to its ``<prefix>_*`` sample entries."""
        inventory, is_visible, distance, cnt, result = resource_data
        return {
            f'{prefix}_inv': inventory,
            # Visibility is stored as the text 'True'/'False' in the sample.
            f'{prefix}_visible': str(is_visible),
            f'{prefix}_distance': distance,
            f'{prefix}_count': cnt,
            f'{prefix}_result': result,
        }

    @staticmethod
    def write_data_to_file(full_data, path='tree_data.json'):
        """Write every sample as one JSON object per line.

        :param full_data: iterable of sample dictionaries.
        :param path: target file; defaults to ``tree_data.json`` in the
            current working directory.
        """
        print("Writing to file...")
        # Open the target file to which the data will be saved and write all the data to it.
        with open(path, 'w', encoding='utf-8') as f:
            f.writelines(TreeDataGenerator._to_json_line(record) for record in full_data)
        print("Success!")

    @staticmethod
    def _to_json_line(record):
        """Serialise one sample as a JSON object followed by a newline.

        The text values ``'True'`` / ``'False'`` are emitted as the JSON
        booleans ``true`` / ``false``.
        """
        converted = {
            key: (value == 'True') if value in ('True', 'False') else value
            for key, value in record.items()
        }
        return json.dumps(converted) + '\n'

    def create_resource_data(self):
        """Draw random parameters for one resource and compute its result.

        :return: list ``[inventory, is_visible, distance, count, result]``;
            distance and count are 0 when the resource is not visible.
        """
        # The order of the random draws is kept stable so that seeded runs
        # stay reproducible.
        is_visible = random.choice(self.VISIBLE)
        inventory = random.randint(min(self.INV_RANGE), max(self.INV_RANGE))

        max_distance = max(self.DISTANCE_RANGE)
        if is_visible:
            cnt = random.choice(self.COUNT)
            distance = random.randint(min(self.DISTANCE_RANGE), max_distance)
            count_weight = cnt
            distance_term = max_distance / distance
        else:
            cnt = 0
            distance = 0
            # Fixed substitutes used when the resource cannot be seen.
            count_weight = 0.9
            distance_term = 0.5

        # Equation determining the results processed by decision tree.
        result = (max(self.INV_RANGE) / inventory) * count_weight + distance_term * self.DISTANCE_FACTOR

        return [inventory, is_visible, distance, cnt, result]

    @staticmethod
    def get_best_resource(package: Dict) -> 'ResourceType':
        """Return the key of ``package`` whose data has the highest result.

        :param package: mapping of resource to the list returned by
            :meth:`create_resource_data`; the last list element is the result.
        :return: the key with the highest result (the first one on ties), or
            ``None`` when ``package`` is empty.
        """
        return max(package, key=lambda resource: package[resource][-1], default=None)

    @staticmethod
    def print_data(full_data):
        """Print every sample as a small table followed by its best resource.

        :param full_data: iterable of sample dictionaries.
        """
        for data in full_data:
            rows = (
                ["Data", "Apple", "Water", "Wood"],
                ["Inventory", data["food_inv"], data["water_inv"], data["wood_inv"]],
                ["Visible", data["food_visible"], data["water_visible"], data["wood_visible"]],
                ["Distance", data["food_distance"], data["water_distance"], data["wood_distance"]],
                ["Count", data["food_count"], data["water_count"], data["wood_count"]],
                ["Result", round(data["food_result"], 3), round(data["water_result"], 3),
                 round(data["wood_result"], 3)],
            )
            for row in rows:
                print(TreeDataGenerator.format_words(row))
            print(f'Best resource: {data["result"]}')
            print('--------------------------------------------------------------')

    @staticmethod
    def format_words(words):
        """Right-align the first four items of ``words`` in 12-character columns.

        :param words: sequence with at least four items.
        :return: the formatted table row.
        """
        return '{:>12} {:>12} {:>12} {:>12}'.format(words[0], words[1], words[2], words[3])
|
||
|
|
||
|
|
||
|
# Train tree
# Runs at module level: generates a data set, prints it, then builds,
# plots and saves a decision tree.
generator = TreeDataGenerator()
# generate() also writes the samples to 'tree_data.json' before returning them.
data = generator.generate(50000)
# Prints a formatted table for every generated sample.
generator.print_data(data)
tree = DecisionTree()
# NOTE(review): presumably build() trains on the 'tree_data.json' file written
# above; the meaning of the argument 1000 is not visible here - confirm
# against DecisionTree.
tree.build(1000)
tree.plot_tree()
tree.save_model('classifier.joblib', 'vectorizer.joblib')
|
||
|
# ----------------------------------------------------------- #
|
||
|
# Use trained tree
|
||
|
# tree = DecisionTree()
|
||
|
# tree.load_model('classifier.joblib', 'vectorizer.joblib')
|
||
|
#
|
||
|
# answ = tree.predict_answer({'food_inv': 40, 'water_inv': 10, 'wood_inv': 20,
|
||
|
# 'food_distance': 2, 'water_distance': -1, 'wood_distance': 4,
|
||
|
# 'food_visible': True, 'water_visible': False, 'wood_visible': True,
|
||
|
# 'food_count': 1, 'water_count': 1, 'wood_count': 1})
|
||
|
# print(answ)
|