add decision tree
This commit is contained in:
parent
8872992b6b
commit
5a6e5181e8
@ -34,7 +34,7 @@ Change sizes map in config.py
|
||||
VERTICAL_NUM_OF_FIELDS = 3
|
||||
HORIZONTAL_NUM_OF_FIELDS = 3
|
||||
```
|
||||
\
|
||||
|
||||
#### 4.1 Save generated map:
|
||||
```bash
|
||||
python main.py --save-map
|
||||
|
@ -44,6 +44,7 @@ class App:
|
||||
|
||||
if keys[pygame.K_w]:
|
||||
self.__tractor.move()
|
||||
self.__tractor.choose_action()
|
||||
print(self.__tractor)
|
||||
|
||||
if keys[pygame.K_n]:
|
||||
|
@ -66,6 +66,15 @@ class Board:
|
||||
print(f"{j} - {type(self.__fields[i][j]).__name__}", end=" | ")
|
||||
print()
|
||||
|
||||
def convert_fields_to_vectors(self) -> list[list]:
|
||||
list_of_vectors = []
|
||||
for i in range(HORIZONTAL_NUM_OF_FIELDS):
|
||||
list_of_vectors.append([])
|
||||
for j in range(VERTICAL_NUM_OF_FIELDS):
|
||||
list_of_vectors[i].append(self.__fields[i][j].transform())
|
||||
print(list_of_vectors)
|
||||
return list_of_vectors
|
||||
|
||||
def convert_fields_to_list_of_types(self) -> list:
|
||||
data = []
|
||||
for i in range(HORIZONTAL_NUM_OF_FIELDS):
|
||||
|
68
app/decision_tree.py
Normal file
68
app/decision_tree.py
Normal file
@ -0,0 +1,68 @@
|
||||
#!/usr/bin/python3
|
||||
import os
|
||||
from typing import Union
|
||||
import pydotplus
|
||||
import pandas as pd
|
||||
from joblib import dump, load
|
||||
from sklearn.tree import DecisionTreeClassifier
|
||||
from sklearn.tree import export_graphviz, export_text
|
||||
|
||||
from app.weather import Weather
|
||||
from config import *
|
||||
|
||||
|
||||
class DecisionTree:
|
||||
WEATHER = {W_SUNNY: 0, W_CLOUDY: 1, W_SNOW: 2, W_RAINY: 3}
|
||||
SEASON = {S_AUTUMN: 0, S_WINTER: 1, S_SPRING: 2, S_SUMMER: 3}
|
||||
FEATURES = ['Season', 'Weather', 'Fertilize', 'Hydrate', 'Sow', 'Harvest', 'Action']
|
||||
|
||||
def __int__(self):
|
||||
self.tree = None
|
||||
|
||||
def learn_tree(self) -> None:
|
||||
path = os.path.join(DATA_DIR, MODEL_TREE_FILENAME)
|
||||
if os.path.exists(path):
|
||||
self.tree = load(path)
|
||||
else:
|
||||
# read data
|
||||
training_data = pd.read_csv(os.path.join(DATA_DIR, DATA_TRAINING_FOR_DECISION_TREE))
|
||||
print(training_data.head())
|
||||
|
||||
training_data = self.map_data(training_data)
|
||||
# print(training_data)
|
||||
|
||||
X = training_data[self.FEATURES[:-1]]
|
||||
Y = training_data[self.FEATURES[-1]]
|
||||
|
||||
self.tree = DecisionTreeClassifier()
|
||||
self.tree = self.tree.fit(X, Y)
|
||||
dump(self.tree, path)
|
||||
|
||||
text = export_text(self.tree, feature_names=self.FEATURES[:-1])
|
||||
print(text)
|
||||
|
||||
data = export_graphviz(self.tree, out_file=None, feature_names=self.FEATURES[:-1])
|
||||
graph = pydotplus.graph_from_dot_data(data)
|
||||
graph.write_png(os.path.join(DATA_DIR, IMG_DECISION_TREE))
|
||||
|
||||
def map_data(self, data: Union[pd.Series, pd.DataFrame]) -> Union[pd.Series, pd.DataFrame]:
|
||||
# print(data)
|
||||
data['Season'] = data['Season'].map(DecisionTree.SEASON)
|
||||
data['Weather'] = data['Weather'].map(DecisionTree.WEATHER)
|
||||
return data
|
||||
|
||||
def predict(self, vector: Union[pd.Series, pd.DataFrame]) -> str:
|
||||
print(vector)
|
||||
x = self.map_data(vector)
|
||||
action = self.tree.predict(x)
|
||||
return action
|
||||
|
||||
def make_decision(self, weather: Weather, v: list):
|
||||
s, w = weather.randomize_weather()
|
||||
tree = DecisionTree()
|
||||
tree.learn_tree()
|
||||
final_vector = [s, w] + v
|
||||
print(final_vector)
|
||||
df = pd.DataFrame([final_vector])
|
||||
df.columns = DecisionTree.FEATURES[:-1]
|
||||
return tree.predict(df)
|
@ -27,6 +27,9 @@ class Crops(Field):
|
||||
self.weight = 1.0
|
||||
self._value = VALUE_OF_CROPS
|
||||
|
||||
def transform(self) -> list:
|
||||
return [0, 0, 0, 1]
|
||||
|
||||
|
||||
class Plant(Field):
|
||||
def __init__(self, img_path: str):
|
||||
@ -34,6 +37,9 @@ class Plant(Field):
|
||||
self.is_hydrated = False
|
||||
self._value = VALUE_OF_PLANT
|
||||
|
||||
def transform(self) -> list:
|
||||
return [0, 1, 0, 0]
|
||||
|
||||
|
||||
class Clay(Soil):
|
||||
def __init__(self):
|
||||
@ -41,6 +47,9 @@ class Clay(Soil):
|
||||
self.is_fertilized = False
|
||||
self._value = VALUE_OF_CLAY
|
||||
|
||||
def transform(self) -> list:
|
||||
return [1, 0, 0, 0]
|
||||
|
||||
|
||||
class Sand(Soil):
|
||||
def __init__(self):
|
||||
@ -49,6 +58,12 @@ class Sand(Soil):
|
||||
self.is_hydrated = False
|
||||
self._value = VALUE_OF_SAND
|
||||
|
||||
def transform(self) -> list:
|
||||
if not self.is_hydrated :
|
||||
return [0, 1, 0, 0]
|
||||
else:
|
||||
return [0, 0, 1, 0]
|
||||
|
||||
|
||||
class Grass(Plant):
|
||||
def __init__(self):
|
||||
|
@ -17,6 +17,8 @@ from app.fields import CROPS, PLANTS, Crops, Sand, Clay, Field
|
||||
from config import *
|
||||
|
||||
from app.fields import Plant, Soil, Crops
|
||||
from app.decision_tree import DecisionTree
|
||||
from app.weather import Weather
|
||||
|
||||
|
||||
class Tractor(BaseField):
|
||||
@ -30,6 +32,8 @@ class Tractor(BaseField):
|
||||
self.__harvested_corps = []
|
||||
self.__fuel = 10
|
||||
self.__neural_network = None
|
||||
self.__tree = DecisionTree()
|
||||
self.__weather = Weather()
|
||||
|
||||
def draw(self, screen: pygame.Surface) -> None:
|
||||
self.draw_field(screen, self.__pos_x + FIELD_SIZE / 2, self.__pos_y + FIELD_SIZE / 2,
|
||||
@ -304,3 +308,19 @@ class Tractor(BaseField):
|
||||
time.sleep(1)
|
||||
|
||||
is_running.clear()
|
||||
|
||||
def choose_action(self) -> None:
|
||||
vectors = self.__board.convert_fields_to_vectors()
|
||||
print(vectors)
|
||||
coords = None
|
||||
action = None
|
||||
for i in range(HORIZONTAL_NUM_OF_FIELDS):
|
||||
for j in range(VERTICAL_NUM_OF_FIELDS):
|
||||
action = self.__tree.make_decision(self.__weather, vectors[i][j])
|
||||
if action != A_DO_NOTHING:
|
||||
coords = (i, j)
|
||||
break
|
||||
print(coords, action)
|
||||
if coords is not None:
|
||||
# astar coords
|
||||
pass
|
||||
|
28
app/weather.py
Normal file
28
app/weather.py
Normal file
@ -0,0 +1,28 @@
|
||||
#!/usr/bin/python3
|
||||
import random
|
||||
from config import *
|
||||
|
||||
|
||||
class Weather:
|
||||
def __init__(self):
|
||||
self.months = (S_WINTER, S_WINTER, S_SPRING, S_SPRING, S_SPRING, S_SUMMER,
|
||||
S_SUMMER, S_SUMMER, S_AUTUMN, S_AUTUMN, S_AUTUMN, S_WINTER)
|
||||
self.current_month = 0
|
||||
|
||||
def randomize_weather(self) -> tuple[str, str]:
|
||||
season = self.months[self.current_month]
|
||||
|
||||
if season == S_WINTER:
|
||||
weather = random.choices([W_SNOW, W_CLOUDY])
|
||||
elif season == S_SUMMER:
|
||||
weights = [0.5, 0.3, 0.2]
|
||||
weather = random.choices([W_SUNNY, W_CLOUDY, W_RAINY], weights)
|
||||
elif season == S_SPRING:
|
||||
weights = [0.3, 0.5, 0.2]
|
||||
weather = random.choices([W_SUNNY, W_CLOUDY, W_RAINY], weights)
|
||||
else:
|
||||
weights = [0.2, 0.3, 0.4]
|
||||
weather = random.choices([W_SUNNY, W_CLOUDY, W_RAINY], weights)
|
||||
|
||||
self.current_month = (self.current_month + 1) % len(self.months)
|
||||
return season, weather[0]
|
28
config.py
28
config.py
@ -9,11 +9,14 @@ __all__ = (
|
||||
'SAND', 'CLAY', 'GRASS', 'CORN', 'SUNFLOWER',
|
||||
'FIELD_TYPES', 'TIME_OF_GROWING', 'AMOUNT_OF_CROPS',
|
||||
'M_GO_FORWARD', 'M_ROTATE_LEFT', 'M_ROTATE_RIGHT',
|
||||
'S_AUTUMN', 'S_SPRING', 'S_SUMMER', 'S_WINTER', 'TYPES_OF_SEASON',
|
||||
'W_SUNNY', 'W_CLOUDY', 'W_SNOW', 'W_RAINY', 'TYPES_OF_WEATHER',
|
||||
'A_SOW', 'A_HARVEST', 'A_HYDRATE', 'A_FERTILIZE', 'A_DO_NOTHING',
|
||||
'D_NORTH', 'D_EAST', 'D_SOUTH', 'D_WEST',
|
||||
'TYPES_OF_ACTION', 'D_NORTH', 'D_EAST', 'D_SOUTH', 'D_WEST',
|
||||
'VALUE_OF_CROPS', 'VALUE_OF_PLANT', 'VALUE_OF_SAND', 'VALUE_OF_CLAY',
|
||||
'MAP_FILE_NAME', 'JSON', 'SAVE_MAP', 'LOAD_MAP',
|
||||
'TRAINING_SET_DIR', 'TEST_SET_DIR', 'ADAPTED_IMG_DIR', 'MODEL_DIR'
|
||||
'TRAINING_SET_DIR', 'TEST_SET_DIR', 'ADAPTED_IMG_DIR', 'MODEL_DIR',
|
||||
'DATA_DIR','IMG_DECISION_TREE','MODEL_TREE_FILENAME','DATA_TRAINING_FOR_DECISION_TREE'
|
||||
)
|
||||
|
||||
# Board settings:
|
||||
@ -31,12 +34,17 @@ CAPTION = 'Tractor'
|
||||
BASE_DIR = os.path.dirname(__file__)
|
||||
RESOURCE_DIR = os.path.join(BASE_DIR, 'resources')
|
||||
MAP_DIR = os.path.join(BASE_DIR, 'maps')
|
||||
DATA_DIR = os.path.join(BASE_DIR, 'data')
|
||||
MAP_FILE_NAME = 'map'
|
||||
TRAINING_SET_DIR = os.path.join(RESOURCE_DIR, 'smaller_train')
|
||||
TEST_SET_DIR = os.path.join(RESOURCE_DIR, 'smaller_test')
|
||||
ADAPTED_IMG_DIR = os.path.join(RESOURCE_DIR, "adapted_images")
|
||||
MODEL_DIR = os.path.join(RESOURCE_DIR, 'saved_model')
|
||||
|
||||
MODEL_TREE_FILENAME = 'tree_model.joblib'
|
||||
IMG_DECISION_TREE = 'decision_tree.png'
|
||||
DATA_TRAINING_FOR_DECISION_TREE = 'data_training.csv'
|
||||
|
||||
# Picture format
|
||||
PNG = "png"
|
||||
|
||||
@ -75,6 +83,7 @@ A_HARVEST = "harvest"
|
||||
A_HYDRATE = "hydrate"
|
||||
A_FERTILIZE = "fertilize"
|
||||
A_DO_NOTHING = "do nothing"
|
||||
TYPES_OF_ACTION = [A_SOW, A_HARVEST, A_HYDRATE, A_FERTILIZE, A_DO_NOTHING]
|
||||
|
||||
# Costs fields:
|
||||
VALUE_OF_CROPS = 1
|
||||
@ -82,6 +91,21 @@ VALUE_OF_PLANT = 4
|
||||
VALUE_OF_SAND = 7
|
||||
VALUE_OF_CLAY = 10
|
||||
|
||||
# Weather
|
||||
W_SUNNY = 'Sunny'
|
||||
W_CLOUDY = 'Cloudy'
|
||||
W_SNOW = 'Snow'
|
||||
W_RAINY = 'Rainy'
|
||||
TYPES_OF_WEATHER = [W_SUNNY, W_CLOUDY, W_SNOW, W_RAINY]
|
||||
|
||||
# Seasons
|
||||
S_AUTUMN = 'Autumn'
|
||||
S_WINTER = 'Winter'
|
||||
S_SPRING = 'Spring'
|
||||
S_SUMMER = 'Summer'
|
||||
|
||||
TYPES_OF_SEASON = [S_AUTUMN, S_WINTER, S_SPRING, S_SUMMER]
|
||||
|
||||
# Times
|
||||
TIME_OF_GROWING = 2
|
||||
TIME_OF_MOVING = 2
|
||||
|
0
data/.gitignore
vendored
Normal file
0
data/.gitignore
vendored
Normal file
49
data/data_training.csv
Normal file
49
data/data_training.csv
Normal file
@ -0,0 +1,49 @@
|
||||
Season,Weather,Fertilize,Hydrate,Sow,Harvest,Action
|
||||
Winter,Snow,0,0,0,1,do nothing
|
||||
Winter,Snow,0,0,1,0,do nothing
|
||||
Winter,Snow,0,1,0,0,do nothing
|
||||
Winter,Snow,1,0,0,0,do nothing
|
||||
Winter,Cloudy,0,0,0,1,do nothing
|
||||
Winter,Cloudy,0,0,1,0,do nothing
|
||||
Winter,Cloudy,0,1,0,0,do nothing
|
||||
Winter,Cloudy,1,0,0,0,do nothing
|
||||
Autumn,Cloudy,0,0,0,1,harvest
|
||||
Autumn,Cloudy,0,0,1,0,do nothing
|
||||
Autumn,Cloudy,0,1,0,0,do nothing
|
||||
Autumn,Cloudy,1,0,0,0,fertilize
|
||||
Autumn,Sunny,0,0,0,1,harvest
|
||||
Autumn,Sunny,0,0,1,0,Plant
|
||||
Autumn,Sunny,0,1,0,0,hydrate
|
||||
Autumn,Sunny,1,0,0,0,fertilize
|
||||
Autumn,Rainy,0,0,0,1,harvest
|
||||
Autumn,Rainy,0,0,1,0,do nothing
|
||||
Autumn,Rainy,0,1,0,0,do nothing
|
||||
Autumn,Rainy,1,0,0,0,do nothing
|
||||
Spring,Sunny,0,0,0,1,harvest
|
||||
Spring,Sunny,0,0,1,0,do nothing
|
||||
Spring,Sunny,0,1,0,0,hydrate
|
||||
Spring,Sunny,1,0,0,0,do nothing
|
||||
Spring,Cloudy,0,0,0,1,harvest
|
||||
Spring,Cloudy,0,0,1,0,do nothing
|
||||
Spring,Cloudy,0,1,0,0,hydrate
|
||||
Spring,Cloudy,1,0,0,0,do nothing
|
||||
Spring,Rainy,0,0,0,1,harvest
|
||||
Spring,Rainy,0,0,1,0,do nothing
|
||||
Spring,Rainy,0,1,0,0,do nothing
|
||||
Spring,Rainy,1,0,0,0,do nothing
|
||||
Spring,Rainy,0,0,0,1,harvest
|
||||
Spring,Rainy,0,0,1,0,do nothing
|
||||
Spring,Rainy,0,1,0,0,do nothing
|
||||
Spring,Rainy,1,0,0,0,do nothing
|
||||
Summer,Rainy,0,0,0,1,harvest
|
||||
Summer,Rainy,0,0,1,0,do nothing
|
||||
Summer,Rainy,0,1,0,0,do nothing
|
||||
Summer,Rainy,1,0,0,0,do nothing
|
||||
Summer,Sunny,0,0,0,1,harvest
|
||||
Summer,Sunny,0,0,1,0,do nothing
|
||||
Summer,Sunny,0,1,0,0,hydrate
|
||||
Summer,Sunny,1,0,0,0,fertilize
|
||||
Summer,Cloudy,0,0,0,1,harvest
|
||||
Summer,Cloudy,0,0,1,0,do nothing
|
||||
Summer,Cloudy,0,1,0,0,hydrate
|
||||
Summer,Cloudy,1,0,0,0,fertilize
|
|
BIN
data/decision_tree.png
Normal file
BIN
data/decision_tree.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 121 KiB |
BIN
data/tree_model.joblib
Normal file
BIN
data/tree_model.joblib
Normal file
Binary file not shown.
@ -2,3 +2,7 @@ pygame==2.0.1
|
||||
tensorflow~=2.5.0
|
||||
numpy~=1.19.5
|
||||
pillow~=8.2.0
|
||||
joblib~=1.0.1
|
||||
scikit-learn~=0.24.2
|
||||
pandas~=1.2.5
|
||||
pydotplus~=2.0.2
|
Loading…
Reference in New Issue
Block a user