added multiple options for animals' images
This commit is contained in:
parent
af90938328
commit
07270d54d4
@ -3,5 +3,5 @@
|
|||||||
<component name="Black">
|
<component name="Black">
|
||||||
<option name="sdkName" value="Python 3.9" />
|
<option name="sdkName" value="Python 3.9" />
|
||||||
</component>
|
</component>
|
||||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9" project-jdk-type="Python SDK" />
|
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.12 (PROJEKT)" project-jdk-type="Python SDK" />
|
||||||
</project>
|
</project>
|
@ -1,8 +1,10 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
<module type="PYTHON_MODULE" version="4">
|
<module type="PYTHON_MODULE" version="4">
|
||||||
<component name="NewModuleRootManager">
|
<component name="NewModuleRootManager">
|
||||||
<content url="file://$MODULE_DIR$" />
|
<content url="file://$MODULE_DIR$">
|
||||||
<orderEntry type="inheritedJdk" />
|
<excludeFolder url="file://$MODULE_DIR$/.venv" />
|
||||||
|
</content>
|
||||||
|
<orderEntry type="jdk" jdkName="Python 3.12 (PROJEKT)" jdkType="Python SDK" />
|
||||||
<orderEntry type="sourceFolder" forTests="false" />
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
</component>
|
</component>
|
||||||
</module>
|
</module>
|
14
agent.py
14
agent.py
@ -5,7 +5,19 @@ from state_space_search import is_border, is_obstacle
|
|||||||
from night import draw_night
|
from night import draw_night
|
||||||
from decision_tree import feed_decision
|
from decision_tree import feed_decision
|
||||||
from constants import Constants
|
from constants import Constants
|
||||||
|
from classification import AnimalClassifier
|
||||||
|
|
||||||
|
const = Constants()
|
||||||
|
|
||||||
|
classes = [
|
||||||
|
"bat",
|
||||||
|
"bear",
|
||||||
|
"elephant",
|
||||||
|
"giraffe",
|
||||||
|
"owl",
|
||||||
|
"parrot",
|
||||||
|
"penguin"
|
||||||
|
]
|
||||||
class Agent:
|
class Agent:
|
||||||
def __init__(self, istate, image_path, grid_size):
|
def __init__(self, istate, image_path, grid_size):
|
||||||
self.istate = istate
|
self.istate = istate
|
||||||
@ -68,6 +80,7 @@ class Agent:
|
|||||||
|
|
||||||
def feed_animal(self, animals, goal,const):
|
def feed_animal(self, animals, goal,const):
|
||||||
goal_x, goal_y = goal
|
goal_x, goal_y = goal
|
||||||
|
guess = AnimalClassifier('./model/best_model.pth', classes)
|
||||||
if self.x == goal_x and self.y == goal_y:
|
if self.x == goal_x and self.y == goal_y:
|
||||||
for animal in animals:
|
for animal in animals:
|
||||||
if animal.x == goal_x and animal.y == goal_y:
|
if animal.x == goal_x and animal.y == goal_y:
|
||||||
@ -76,6 +89,7 @@ def feed_animal(self, animals, goal,const):
|
|||||||
else:
|
else:
|
||||||
activity_time = False
|
activity_time = False
|
||||||
guests = random.randint(1, 15)
|
guests = random.randint(1, 15)
|
||||||
|
guess.classify(animal.image)
|
||||||
decision = feed_decision(animal.adult, activity_time, animal.ill, const.season, guests, animal._feed, self._dryfood, self._wetfood)
|
decision = feed_decision(animal.adult, activity_time, animal.ill, const.season, guests, animal._feed, self._dryfood, self._wetfood)
|
||||||
if decision != [1]:
|
if decision != [1]:
|
||||||
if decision == [2]:
|
if decision == [2]:
|
||||||
|
48
classification.py
Normal file
48
classification.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
import torch
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
import PIL.Image as Image
|
||||||
|
|
||||||
|
class AnimalClassifier:
|
||||||
|
def __init__(self, model_path, classes, image_size=224, mean=None, std=None):
|
||||||
|
self.classes = classes
|
||||||
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
self.model = torch.load(model_path)
|
||||||
|
self.model = self.model.to(self.device)
|
||||||
|
self.model = self.model.eval()
|
||||||
|
self.image_size = image_size
|
||||||
|
self.mean = mean if mean is not None else [0.5164, 0.5147, 0.4746]
|
||||||
|
self.std = std if std is not None else [0.2180, 0.2126, 0.2172]
|
||||||
|
self.image_transforms = transforms.Compose([
|
||||||
|
transforms.Resize((self.image_size, self.image_size)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(torch.Tensor(self.mean), torch.Tensor(self.std))
|
||||||
|
])
|
||||||
|
|
||||||
|
def classify(self, image_path):
|
||||||
|
image = Image.open(image_path)
|
||||||
|
|
||||||
|
if image.mode == 'RGBA':
|
||||||
|
image = image.convert('RGB')
|
||||||
|
|
||||||
|
image = self.image_transforms(image).float()
|
||||||
|
image = image.unsqueeze(0).to(self.device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = self.model(image)
|
||||||
|
|
||||||
|
_, predicted = torch.max(output.data, 1)
|
||||||
|
|
||||||
|
return self.classes[predicted.item()]
|
||||||
|
|
||||||
|
# Define the classes
|
||||||
|
classes = [
|
||||||
|
"bat",
|
||||||
|
"bear",
|
||||||
|
"elephant",
|
||||||
|
"giraffe",
|
||||||
|
"owl",
|
||||||
|
"parrot",
|
||||||
|
"penguin"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
@ -17,6 +17,10 @@ class Constants:
|
|||||||
|
|
||||||
self.season = random.choice(["spring", "summer", "autumn", "winter"])
|
self.season = random.choice(["spring", "summer", "autumn", "winter"])
|
||||||
|
|
||||||
|
self.SIZE = 224
|
||||||
|
self.mean = [0.5164, 0.5147, 0.4746]
|
||||||
|
self.std = [0.2180, 0.2126, 0.2172]
|
||||||
|
|
||||||
def init_pygame(const):
|
def init_pygame(const):
|
||||||
pygame.init()
|
pygame.init()
|
||||||
const.screen = pygame.display.set_mode(const.WINDOW_SIZE)
|
const.screen = pygame.display.set_mode(const.WINDOW_SIZE)
|
Loading…
Reference in New Issue
Block a user