added multiple options for animals' images
This commit is contained in:
parent
af90938328
commit
07270d54d4
@ -3,5 +3,5 @@
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="Python 3.9" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9" project-jdk-type="Python SDK" />
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.12 (PROJEKT)" project-jdk-type="Python SDK" />
|
||||
</project>
|
@ -1,8 +1,10 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="inheritedJdk" />
|
||||
<content url="file://$MODULE_DIR$">
|
||||
<excludeFolder url="file://$MODULE_DIR$/.venv" />
|
||||
</content>
|
||||
<orderEntry type="jdk" jdkName="Python 3.12 (PROJEKT)" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
14
agent.py
14
agent.py
@ -5,7 +5,19 @@ from state_space_search import is_border, is_obstacle
|
||||
from night import draw_night
|
||||
from decision_tree import feed_decision
|
||||
from constants import Constants
|
||||
from classification import AnimalClassifier
|
||||
|
||||
const = Constants()
|
||||
|
||||
classes = [
|
||||
"bat",
|
||||
"bear",
|
||||
"elephant",
|
||||
"giraffe",
|
||||
"owl",
|
||||
"parrot",
|
||||
"penguin"
|
||||
]
|
||||
class Agent:
|
||||
def __init__(self, istate, image_path, grid_size):
|
||||
self.istate = istate
|
||||
@ -68,6 +80,7 @@ class Agent:
|
||||
|
||||
def feed_animal(self, animals, goal,const):
|
||||
goal_x, goal_y = goal
|
||||
guess = AnimalClassifier('./model/best_model.pth', classes)
|
||||
if self.x == goal_x and self.y == goal_y:
|
||||
for animal in animals:
|
||||
if animal.x == goal_x and animal.y == goal_y:
|
||||
@ -76,6 +89,7 @@ def feed_animal(self, animals, goal,const):
|
||||
else:
|
||||
activity_time = False
|
||||
guests = random.randint(1, 15)
|
||||
guess.classify(animal.image)
|
||||
decision = feed_decision(animal.adult, activity_time, animal.ill, const.season, guests, animal._feed, self._dryfood, self._wetfood)
|
||||
if decision != [1]:
|
||||
if decision == [2]:
|
||||
|
48
classification.py
Normal file
48
classification.py
Normal file
@ -0,0 +1,48 @@
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
import PIL.Image as Image
|
||||
|
||||
class AnimalClassifier:
|
||||
def __init__(self, model_path, classes, image_size=224, mean=None, std=None):
|
||||
self.classes = classes
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self.model = torch.load(model_path)
|
||||
self.model = self.model.to(self.device)
|
||||
self.model = self.model.eval()
|
||||
self.image_size = image_size
|
||||
self.mean = mean if mean is not None else [0.5164, 0.5147, 0.4746]
|
||||
self.std = std if std is not None else [0.2180, 0.2126, 0.2172]
|
||||
self.image_transforms = transforms.Compose([
|
||||
transforms.Resize((self.image_size, self.image_size)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(torch.Tensor(self.mean), torch.Tensor(self.std))
|
||||
])
|
||||
|
||||
def classify(self, image_path):
|
||||
image = Image.open(image_path)
|
||||
|
||||
if image.mode == 'RGBA':
|
||||
image = image.convert('RGB')
|
||||
|
||||
image = self.image_transforms(image).float()
|
||||
image = image.unsqueeze(0).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
output = self.model(image)
|
||||
|
||||
_, predicted = torch.max(output.data, 1)
|
||||
|
||||
return self.classes[predicted.item()]
|
||||
|
||||
# Define the classes
|
||||
classes = [
|
||||
"bat",
|
||||
"bear",
|
||||
"elephant",
|
||||
"giraffe",
|
||||
"owl",
|
||||
"parrot",
|
||||
"penguin"
|
||||
]
|
||||
|
||||
|
@ -17,6 +17,10 @@ class Constants:
|
||||
|
||||
self.season = random.choice(["spring", "summer", "autumn", "winter"])
|
||||
|
||||
self.SIZE = 224
|
||||
self.mean = [0.5164, 0.5147, 0.4746]
|
||||
self.std = [0.2180, 0.2126, 0.2172]
|
||||
|
||||
def init_pygame(const):
|
||||
pygame.init()
|
||||
const.screen = pygame.display.set_mode(const.WINDOW_SIZE)
|
Loading…
Reference in New Issue
Block a user