added multiple options for animals' images

This commit is contained in:
LuminoX 2024-05-26 17:58:48 +02:00
parent af90938328
commit 07270d54d4
6 changed files with 72 additions and 4 deletions

View File

@ -3,5 +3,5 @@
<component name="Black">
<option name="sdkName" value="Python 3.9" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.12 (PROJEKT)" project-jdk-type="Python SDK" />
</project>

View File

@ -1,8 +1,10 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<content url="file://$MODULE_DIR$">
<excludeFolder url="file://$MODULE_DIR$/.venv" />
</content>
<orderEntry type="jdk" jdkName="Python 3.12 (PROJEKT)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

View File

@ -5,7 +5,19 @@ from state_space_search import is_border, is_obstacle
from night import draw_night
from decision_tree import feed_decision
from constants import Constants
from classification import AnimalClassifier
const = Constants()
classes = [
"bat",
"bear",
"elephant",
"giraffe",
"owl",
"parrot",
"penguin"
]
class Agent:
def __init__(self, istate, image_path, grid_size):
self.istate = istate
@ -66,8 +78,9 @@ class Agent:
feed_animal(self, animals, goal,const)
take_food(self)
def feed_animal(self, animals, goal,const):
def feed_animal(self, animals, goal,const):
goal_x, goal_y = goal
guess = AnimalClassifier('./model/best_model.pth', classes)
if self.x == goal_x and self.y == goal_y:
for animal in animals:
if animal.x == goal_x and animal.y == goal_y:
@ -76,6 +89,7 @@ def feed_animal(self, animals, goal,const):
else:
activity_time = False
guests = random.randint(1, 15)
guess.classify(animal.image)
decision = feed_decision(animal.adult, activity_time, animal.ill, const.season, guests, animal._feed, self._dryfood, self._wetfood)
if decision != [1]:
if decision == [2]:

48
classification.py Normal file
View File

@ -0,0 +1,48 @@
import torch
import torchvision.transforms as transforms
import PIL.Image as Image
class AnimalClassifier:
def __init__(self, model_path, classes, image_size=224, mean=None, std=None):
self.classes = classes
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.model = torch.load(model_path)
self.model = self.model.to(self.device)
self.model = self.model.eval()
self.image_size = image_size
self.mean = mean if mean is not None else [0.5164, 0.5147, 0.4746]
self.std = std if std is not None else [0.2180, 0.2126, 0.2172]
self.image_transforms = transforms.Compose([
transforms.Resize((self.image_size, self.image_size)),
transforms.ToTensor(),
transforms.Normalize(torch.Tensor(self.mean), torch.Tensor(self.std))
])
def classify(self, image_path):
image = Image.open(image_path)
if image.mode == 'RGBA':
image = image.convert('RGB')
image = self.image_transforms(image).float()
image = image.unsqueeze(0).to(self.device)
with torch.no_grad():
output = self.model(image)
_, predicted = torch.max(output.data, 1)
return self.classes[predicted.item()]
# Define the classes
classes = [
"bat",
"bear",
"elephant",
"giraffe",
"owl",
"parrot",
"penguin"
]

View File

@ -17,6 +17,10 @@ class Constants:
self.season = random.choice(["spring", "summer", "autumn", "winter"])
self.SIZE = 224
self.mean = [0.5164, 0.5147, 0.4746]
self.std = [0.2180, 0.2126, 0.2172]
def init_pygame(const):
pygame.init()
const.screen = pygame.display.set_mode(const.WINDOW_SIZE)

BIN
tree.png

Binary file not shown.

Before

Width:  |  Height:  |  Size: 578 KiB

After

Width:  |  Height:  |  Size: 636 KiB