Compare commits

..

No commits in common. "master" and "Drzewa-decyzyjne" have entirely different histories.

45 changed files with 46 additions and 423 deletions

View File

@ -3,5 +3,5 @@
<component name="Black">
<option name="sdkName" value="Python 3.9" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (pythonProject)" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9" project-jdk-type="Python SDK" />
</project>

View File

@ -1,10 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$">
<excludeFolder url="file://$MODULE_DIR$/.venv" />
</content>
<orderEntry type="jdk" jdkName="Python 3.10 (pythonProject)" jdkType="Python SDK" />
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

View File

@ -3,22 +3,11 @@ import pygame
from abc import abstractmethod
class Animal:
def choose_picture(self, name):
ran = random.randint(0, 1)
if ran == 0:
path = f'images/{name}.png'
return path
else:
path = f'images/{name}2.png'
return path
def __init__(self, x, y,name, image_path, food_image, food, environment, activity, ill=False, adult=False,):
def __init__(self, x, y,name, image, food_image, food, environment, activity, ill=False, adult=False,):
self.x = x - 1
self.y = y - 1
self.name = name
self.image_path = image_path
self.image = pygame.image.load(image_path)
self.image = image
self.adult = adult
self.food = food
self.food_image = food_image
@ -74,13 +63,6 @@ class Animal:
illness_image = pygame.transform.scale(illness_image, (int(grid_size * scale), int(grid_size * scale)))
screen.blit(illness_image, (x_blit, y * grid_size))
def draw_snack(self, screen, grid_size, x, y):
exclamation_image = pygame.image.load(self.food_image)
exclamation_image = pygame.transform.scale(exclamation_image, (int(grid_size * 0.45), int(grid_size * 0.45)))
screen.blit(exclamation_image, (x * grid_size, y * grid_size))
pygame.display.update()
pygame.time.wait(700)
@abstractmethod
def getting_hungry(self):
pass

View File

@ -4,13 +4,13 @@ from datetime import datetime
class Bat(Animal):
def __init__(self, x, y, adult=False):
Bat_image = pygame.image.load('images/bat.png')
name = 'bat'
image_path = self.choose_picture(name)
environment = "medium"
food_image = 'images/grains.png'
parrot_food = 'grains'
activity = 'nocturnal'
super().__init__(x, y,name, image_path, food_image,parrot_food, environment, adult)
super().__init__(x, y,name, Bat_image, food_image,parrot_food, environment, adult)
self._starttime = datetime.now()
def getting_hungry(self, const):

View File

@ -4,14 +4,14 @@ from datetime import datetime
class Bear(Animal):
def __init__(self, x, y, adult=False):
Bear_image = pygame.image.load('images/bear.png')
name = 'bear'
image_path = self.choose_picture(name)
environment = "cold"
activity = 'nocturnal'
ill = self.is_ill()
bear_food = 'meat'
food_image = 'images/meat.png'
super().__init__(x, y,name, image_path, food_image,bear_food,environment, activity, ill, adult)
super().__init__(x, y,name, Bear_image, food_image,bear_food,environment, activity, ill, adult)
self._starttime = datetime.now()
def getting_hungry(self, const):

View File

@ -4,8 +4,8 @@ from datetime import datetime
class Elephant(Animal):
def __init__(self, x, y, adult=False):
Elephant_image = pygame.image.load('images/elephant.png')
name = 'elephant'
image_path = self.choose_picture(name)
environment = "hot"
activity = 'diurnal'
ill = self.is_ill()
@ -16,7 +16,7 @@ class Elephant(Animal):
elephant_food = 'milk'
food_image = 'images/milk.png'
super().__init__(x, y,name, image_path, food_image,elephant_food, environment, activity, ill, adult)
super().__init__(x, y,name, Elephant_image, food_image,elephant_food, environment, activity, ill, adult)
self._starttime = datetime.now()
def getting_hungry(self, const):

View File

@ -4,14 +4,14 @@ from datetime import datetime
class Giraffe(Animal):
def __init__(self, x, y, adult=False):
Giraffe_image = pygame.image.load('images/giraffe.png')
name = 'giraffe'
image_path = self.choose_picture(name)
environment = "hot"
activity = 'diurnal'
ill = self.is_ill()
food_image = 'images/leaves.png'
giraffe_food = 'leaves'
super().__init__(x, y, name, image_path, food_image,giraffe_food, environment, activity, ill, adult)
super().__init__(x, y, name, Giraffe_image, food_image,giraffe_food, environment, activity, ill, adult)
self._starttime = datetime.now()
def getting_hungry(self, const):

View File

@ -4,13 +4,13 @@ from datetime import datetime
class Owl(Animal):
def __init__(self, x, y, adult=False):
Owl_image = pygame.image.load('images/owl.png')
name = 'owl'
image_path = self.choose_picture(name)
environment = "medium"
food_image = 'images/grains.png'
parrot_food = 'grains'
activity = 'nocturnal'
super().__init__(x, y,name, image_path, food_image,parrot_food, environment, adult)
super().__init__(x, y,name, Owl_image, food_image,parrot_food, environment, adult)
self._starttime = datetime.now()
def getting_hungry(self, const):

View File

@ -4,14 +4,14 @@ from datetime import datetime
class Parrot(Animal):
def __init__(self, x, y, adult=False):
Parrot_image = pygame.image.load('images/parrot.png')
name = 'parrot'
image_path = self.choose_picture(name)
environment = "medium"
activity = 'diurnal'
ill = self.is_ill()
food_image = 'images/grains.png'
parrot_food = 'grains'
super().__init__(x, y, name, image_path, food_image, parrot_food, environment, activity, ill, adult)
super().__init__(x, y, name, Parrot_image, food_image, parrot_food, environment, activity, ill, adult)
self._starttime = datetime.now()
def getting_hungry(self, const):

View File

@ -4,14 +4,14 @@ from datetime import datetime
class Penguin(Animal):
def __init__(self, x, y, adult=False):
Penguin_image = pygame.image.load('images/penguin.png')
name = 'penguin'
image_path = self.choose_picture(name)
environment = "cold"
activity = 'diurnal'
ill = self.is_ill()
food_image = 'images/fish.png'
penguin_food = 'fish'
super().__init__(x, y, name, image_path, food_image, penguin_food, environment, activity, ill, adult)
super().__init__(x, y, name, Penguin_image, food_image, penguin_food, environment, activity, ill, adult)
self._starttime = datetime.now()
def getting_hungry(self, const):

View File

@ -5,19 +5,7 @@ from state_space_search import is_border, is_obstacle
from night import draw_night
from decision_tree import feed_decision
from constants import Constants
from classification import AnimalClassifier
const = Constants()
classes = [
"bat",
"bear",
"elephant",
"giraffe",
"owl",
"parrot",
"penguin"
]
class Agent:
def __init__(self, istate, image_path, grid_size):
self.istate = istate
@ -78,9 +66,8 @@ class Agent:
feed_animal(self, animals, goal,const)
take_food(self)
def feed_animal(self, animals, goal,const):
def feed_animal(self, animals, goal,const):
goal_x, goal_y = goal
neuron = AnimalClassifier('./model/best_model.pth', classes)
if self.x == goal_x and self.y == goal_y:
for animal in animals:
if animal.x == goal_x and animal.y == goal_y:
@ -89,12 +76,6 @@ def feed_animal(self, animals, goal,const):
else:
activity_time = False
guests = random.randint(1, 15)
guess = neuron.classify(animal.image_path)
if guess == animal.name:
print(f"I'm sure this is {guess} and i give it {animal.food} as a snack")
animal.draw_snack(const.screen, const.GRID_SIZE, animal.x, animal.y)
else:
print(f"I was wrong, this is not a {guess} but a {animal.name}")
decision = feed_decision(animal.adult, activity_time, animal.ill, const.season, guests, animal._feed, self._dryfood, self._wetfood)
if decision != [1]:
if decision == [2]:

View File

@ -1,47 +0,0 @@
import torch
import torchvision.transforms as transforms
import PIL.Image as Image
class AnimalClassifier:
def __init__(self, model_path, classes, image_size=224, mean=None, std=None):
self.classes = classes
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.model = torch.load(model_path, map_location=torch.device('cpu'))
self.model = self.model.to(self.device)
self.model = self.model.eval()
self.image_size = image_size
self.mean = mean if mean is not None else [0.5164, 0.5147, 0.4746]
self.std = std if std is not None else [0.2180, 0.2126, 0.2172]
self.image_transforms = transforms.Compose([
transforms.Resize((self.image_size, self.image_size)),
transforms.ToTensor(),
transforms.Normalize(torch.Tensor(self.mean), torch.Tensor(self.std))
])
def classify(self, image_path):
image = Image.open(image_path)
if image.mode == 'RGBA':
image = image.convert('RGB')
image = self.image_transforms(image).float()
image = image.unsqueeze(0).to(self.device)
with torch.no_grad():
output = self.model(image)
_, predicted = torch.max(output.data, 1)
return self.classes[predicted.item()]
classes = [
"bat",
"bear",
"elephant",
"giraffe",
"owl",
"parrot",
"penguin"
]

View File

@ -6,7 +6,7 @@ class Constants:
def __init__(self):
self.BLACK = (0, 0, 0)
self.RED = (255, 0, 0)
self.GRID_SIZE = 65
self.GRID_SIZE = 50
self.GRID_WIDTH = 30
self.GRID_HEIGHT = 15
self.WINDOW_SIZE = (self.GRID_WIDTH * self.GRID_SIZE, self.GRID_HEIGHT * self.GRID_SIZE)
@ -17,10 +17,6 @@ class Constants:
self.season = random.choice(["spring", "summer", "autumn", "winter"])
self.SIZE = 224
self.mean = [0.5164, 0.5147, 0.4746]
self.std = [0.2180, 0.2126, 0.2172]
def init_pygame(const):
pygame.init()
const.screen = pygame.display.set_mode(const.WINDOW_SIZE)

View File

@ -1,148 +0,0 @@
from state_space_search import graphsearch, generate_cost_map
import random
# Parametry algorytmu genetycznego
POPULATION_SIZE = 700
MUTATION_RATE = 0.01
NUM_GENERATIONS = 600
# Generowanie początkowej populacji
def generate_individual(animals):
return random.sample(animals, len(animals))
def generate_population(animals, size):
return [generate_individual(animals) for _ in range(size)]
# Obliczanie odległości między zwierzetami
def calculate_distance(animal1, animal2):
x1, y1 = animal1
x2, y2 = animal2
return abs(x1 - x2) + abs(y1 - y2) # Odległość Manhattana
def calculate_total_distance(animals):
total_distance = 0
for i in range(len(animals) - 1):
total_distance += calculate_distance(animals[i], animals[i+1])
total_distance += calculate_distance(animals[-1], animals[0]) # Zamknięcie cyklu
return total_distance
# Selekcja rodziców za pomocą metody ruletki
def select_parents(population, num_parents):
fitness_scores = [1 / calculate_total_distance(individual) for individual in population]
total_fitness = sum(fitness_scores)
selection_probs = [fitness / total_fitness for fitness in fitness_scores]
parents = random.choices(population, weights=selection_probs, k=num_parents)
return parents
# Krzyżowanie rodziców (OX,Davis)
def crossover(parent1, parent2):
child1 = [None] * len(parent1)
child2 = [None] * len(parent1)
start_index = random.randint(0, len(parent1) - 1)
end_index = random.randint(start_index, len(parent1) - 1)
child1[start_index:end_index+1] = parent1[start_index:end_index+1]
child2[start_index:end_index+1] = parent2[start_index:end_index+1]
# Uzupełnienie brakujących zwierząt z drugiego rodzica
for i in range(len(parent1)):
if parent2[i] not in child1:
for j in range(len(parent2)):
if child1[j] is None:
child1[j] = parent2[i]
break
for i in range(len(parent1)):
if parent1[i] not in child2:
for j in range(len(parent1)):
if child2[j] is None:
child2[j] = parent1[i]
break
return child1, child2
# Mutacja: zamiana dwóch losowych zwierząt z prawdopodobieństwem MUTATION_RATE
def mutate(individual):
if random.random() < MUTATION_RATE:
index1, index2 = random.sample(range(len(individual)), 2)
individual[index1], individual[index2] = individual[index2], individual[index1]
# Algorytm genetyczny
def genetic_algorithm(animals):
population = generate_population(animals, POPULATION_SIZE)
for generation in range(NUM_GENERATIONS):
# Selekcja rodziców
parents = select_parents(population, POPULATION_SIZE // 2)
# Krzyżowanie i tworzenie nowej populacji
next_generation = []
for i in range(0, len(parents), 2):
parent1 = parents[i]
if i + 1 < len(parents):
parent2 = parents[i + 1]
else:
parent2 = parents[0]
child1, child2 = crossover(parent1, parent2)
next_generation.extend([child1, child2])
# Mutacja nowej populacji
for individual in next_generation:
mutate(individual)
# Zastąpienie starej populacji nową
population = next_generation
# Znalezienie najlepszego osobnika
best_individual = min(population, key=calculate_total_distance)
return best_individual
# def calculate_distance(start, goal, max_x, max_y, obstacles, cost_map):
# istate = (start[0], start[1], 'N') # Zakładamy, że zaczynamy od kierunku północnego
# actions, cost = graphsearch(istate, goal, max_x, max_y, obstacles, cost_map)
# return cost
# def calculate_total_distance(animals, max_x, max_y, obstacles, cost_map):
# total_distance = 0
# for i in range(len(animals) - 1):
# total_distance += calculate_distance(animals[i], animals[i+1], max_x, max_y, obstacles, cost_map)
# total_distance += calculate_distance(animals[-1], animals[0], max_x, max_y, obstacles, cost_map) # Zamknięcie cyklu
# return total_distance
# # Selekcja rodziców za pomocą metody ruletki
# def select_parents(population, num_parents, max_x, max_y, obstacles, cost_map):
# fitness_scores = [1 / calculate_total_distance(individual, max_x, max_y, obstacles, cost_map) for individual in population]
# total_fitness = sum(fitness_scores)
# selection_probs = [fitness / total_fitness for fitness in fitness_scores]
# parents = random.choices(population, weights=selection_probs, k=num_parents)
# return parents
# def genetic_algorithm(animals, max_x, max_y, obstacles, cost_map):
# population = generate_population(animals, POPULATION_SIZE)
# for generation in range(NUM_GENERATIONS):
# # Selekcja rodziców
# parents = select_parents(population, POPULATION_SIZE // 2, max_x, max_y, obstacles, cost_map)
# # Krzyżowanie i tworzenie nowej populacji
# next_generation = []
# for i in range(0, len(parents), 2):
# parent1 = parents[i]
# parent2 = parents[i + 1]
# child1, child2 = crossover(parent1, parent2)
# next_generation.extend([child1, child2])
# # Mutacja nowej populacji
# for individual in next_generation:
# mutate(individual)
# # Zastąpienie starej populacji nową
# population = next_generation
# # Znalezienie najlepszego osobnika
# best_individual = min(population, key=lambda individual: calculate_total_distance(individual, max_x, max_y, obstacles, cost_map))
# return best_individual

Binary file not shown.

Before

Width:  |  Height:  |  Size: 458 KiB

After

Width:  |  Height:  |  Size: 740 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 438 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 366 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 373 KiB

After

Width:  |  Height:  |  Size: 642 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 294 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 161 KiB

After

Width:  |  Height:  |  Size: 444 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.7 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 178 KiB

After

Width:  |  Height:  |  Size: 286 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.5 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 268 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 176 KiB

53
main.py
View File

@ -13,7 +13,6 @@ from constants import Constants, init_pygame
from draw import draw_goal, draw_grid, draw_house
from season import draw_background
from night import change_time
from genetics import genetic_algorithm
const = Constants()
init_pygame(const)
@ -78,13 +77,12 @@ def main():
actions = []
clock = pygame.time.Clock()
spawned = False
route = False
# # Lista zawierająca klatki do odwiedzenia
# enclosures_to_visit = Enclosures.copy()
# current_enclosure_index = -1 # Indeks bieżącej klatki
# actions_to_compare_list = [] # Lista zawierająca ścieżki do porównania
# goals_to_compare_list = list() # Lista zawierająca cele do porównania
# Lista zawierająca klatki do odwiedzenia
enclosures_to_visit = Enclosures.copy()
current_enclosure_index = -1 # Indeks bieżącej klatki
actions_to_compare_list = [] # Lista zawierająca ścieżki do porównania
goals_to_compare_list = list() # Lista zawierająca cele do porównania
while True:
for event in pygame.event.get():
@ -95,6 +93,7 @@ def main():
change_time(const)
draw_background(const)
draw_grid(const)
draw_enclosures(Enclosures, const)
draw_gates(Enclosures, const)
draw_house(const)
@ -107,11 +106,6 @@ def main():
# animal._feed = 0
animal._feed = random.randint(0, 10)
spawned = True
if not route:
animals = [(animal.x, animal.y) for animal in Animals]
best_route = genetic_algorithm(animals)
route = True
draw_Animals(Animals, const)
draw_Terrain_Obstacles(Terrain_Obstacles, const)
@ -125,34 +119,31 @@ def main():
pygame.time.wait(200)
else:
if agent._dryfood > 1 and agent._wetfood > 1 :
# if not goals_to_compare_list:
# current_enclosure_index = (current_enclosure_index + 1) % len(enclosures_to_visit)
# current_enclosure = enclosures_to_visit[current_enclosure_index]
if not goals_to_compare_list:
current_enclosure_index = (current_enclosure_index + 1) % len(enclosures_to_visit)
current_enclosure = enclosures_to_visit[current_enclosure_index]
# for animal in current_enclosure.animals:
# goal = (animal.x, animal.y)
# goals_to_compare_list.append(goal)
for animal in current_enclosure.animals:
goal = (animal.x, animal.y)
goals_to_compare_list.append(goal)
# actions_to_compare = graphsearch(agent.istate, goal, const.GRID_WIDTH, const.GRID_HEIGHT, obstacles, cost_map)
# actions_to_compare_list.append((actions_to_compare, goal))
actions_to_compare = graphsearch(agent.istate, goal, const.GRID_WIDTH, const.GRID_HEIGHT, obstacles, cost_map)
actions_to_compare_list.append((actions_to_compare, goal))
# chosen_path_and_goal = min(actions_to_compare_list, key=lambda x: len(x[0]))
# goal = chosen_path_and_goal[1]
# draw_goal(const, goal)
# # Usuń wybrany element z listy
# actions_to_compare_list.remove(chosen_path_and_goal)
# goals_to_compare_list.remove(goal)
goal = best_route.pop(0)
best_route.append(goal)
chosen_path_and_goal = min(actions_to_compare_list, key=lambda x: len(x[0]))
goal = chosen_path_and_goal[1]
draw_goal(const, goal)
actions, cost = graphsearch(agent.istate, goal, const.GRID_WIDTH, const.GRID_HEIGHT, obstacles, cost_map)
# Usuń wybrany element z listy
actions_to_compare_list.remove(chosen_path_and_goal)
goals_to_compare_list.remove(goal)
actions = graphsearch(agent.istate, goal, const.GRID_WIDTH, const.GRID_HEIGHT, obstacles, cost_map)
else:
goal = (3,1)
draw_goal(const, goal)
actions, cost = graphsearch(agent.istate, goal, const.GRID_WIDTH, const.GRID_HEIGHT, obstacles, cost_map)
actions = graphsearch(agent.istate, goal, const.GRID_WIDTH, const.GRID_HEIGHT, obstacles, cost_map)
if __name__ == "__main__":
main()

Binary file not shown.

Binary file not shown.

Before

Width:  |  Height:  |  Size: 28 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 34 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 173 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 107 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 14 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 19 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 38 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 53 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 77 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 173 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 158 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 126 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 44 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 30 KiB

View File

@ -1,129 +0,0 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
def set_device():
if torch.cuda.is_available():
device = 'cuda'
else:
device = 'cpu'
return torch.device(device)
train_dataset_path = './data/train'
test_dataset_path = './data/val'
number_of_classes = 7
SIZE = 224
mean = [0.5164, 0.5147, 0.4746]
std = [0.2180, 0.2126, 0.2172]
train_transforms = transforms.Compose([
transforms.Resize((SIZE, SIZE)),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(10),
transforms.ToTensor(),
transforms.Normalize(torch.Tensor(mean), torch.Tensor(std))
])
test_transforms = transforms.Compose([
transforms.Resize((SIZE, SIZE)),
transforms.ToTensor(),
transforms.Normalize(torch.Tensor(mean), torch.Tensor(std))
])
train_dataset = torchvision.datasets.ImageFolder(root=train_dataset_path, transform=train_transforms)
test_dataset = torchvision.datasets.ImageFolder(root=test_dataset_path, transform=test_transforms)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
resnet18_model = models.resnet18(weights=None)
num_ftrs = resnet18_model.fc.in_features
resnet18_model.fc = nn.Linear(num_ftrs, number_of_classes)
device = set_device()
resnet18_model = resnet18_model.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(resnet18_model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.003)
def save_checkpoint(model, epoch, optimizer, best_acc):
state = {
'epoch': epoch + 1,
'model': model.state_dict(),
'best accuracy': best_acc,
'optimizer': optimizer.state_dict()
}
torch.save(state, 'model_best_checkpoint.pth.tar')
def train_nn(model, train_loader, test_loader, criterion, optimizer, n_epochs):
device = set_device()
best_acc = 0
for epoch in range(n_epochs):
print("Epoch number %d " % (epoch + 1))
model.train()
running_loss = 0.0
running_correct = 0.0
total = 0
for data in train_loader:
images, labels = data
images = images.to(device)
labels = labels.to(device)
total += labels.size(0)
optimizer.zero_grad()
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
running_correct += (labels == predicted).sum().item()
epoch_loss = running_loss/len(train_loader)
epoch_acc = 100 * running_correct / total
print(f"Training dataset. Got {running_correct} out of {total} images correctly ({epoch_acc}). Epoch loss: {epoch_loss}")
test_data_acc = evaluate_model_on_test_set(model, test_loader)
if test_data_acc > best_acc:
best_acc = test_data_acc
save_checkpoint(model, epoch, optimizer, best_acc)
print("Finished")
return model
def evaluate_model_on_test_set(model, test_loader):
model.eval()
predicted_correctly_on_epoch = 0
total = 0
device = set_device()
with torch.no_grad():
for data in test_loader:
images, labels = data
images = images.to(device)
labels = labels.to(device)
total += labels.size(0)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
predicted_correctly_on_epoch += (predicted == labels).sum().item()
epoch_acc = 100 * predicted_correctly_on_epoch / total
print(f"Testing dataset. Got {predicted_correctly_on_epoch} out of {total} images correctly ({epoch_acc})")
return epoch_acc
train_nn(resnet18_model, train_loader, test_loader, loss_fn, optimizer, n_epochs=30)
checkpoint = torch.load('model_best_checkpoint.pth.tar')
resnet18_model.load_state_dict(checkpoint['model'])
torch.save(resnet18_model, 'best_model.pth')

Binary file not shown.

View File

@ -40,7 +40,7 @@ def graphsearch(istate, goal, max_x, max_y, obstacles, cost_map):
state, _, _ = node
if goaltest(state, goal):
return build_action_sequence(node), current_cost(node, cost_map)
return build_action_sequence(node)
explored.add(state)
@ -61,7 +61,7 @@ def graphsearch(istate, goal, max_x, max_y, obstacles, cost_map):
else:
break
return False, float('inf')
return False
def is_state_in_queue(state, queue):
for _, (s, _, _) in queue.queue:
@ -124,5 +124,4 @@ def generate_cost_map(Animals, Terrain_Obstacles, cost_map={}):
else:
cost_map[(terrain_obstacle.x , terrain_obstacle.y )] = bush_cost
return cost_map
return cost_map

BIN
tree.png

Binary file not shown.

Before

Width:  |  Height:  |  Size: 616 KiB

After

Width:  |  Height:  |  Size: 613 KiB