Compare commits

..

24 Commits

Author SHA1 Message Date
edf1e43d33 Merge pull request 'image_recognition' (#5) from image_recognition into master
Reviewed-on: #5
2023-06-12 16:25:59 +02:00
0b33ff1803 fix index out of range error 2023-06-05 15:49:35 +02:00
86be72ba33 add neural network implementation to main 2023-06-05 05:25:13 +02:00
da73e223e3 add more code flexibility 2023-06-05 04:48:04 +02:00
e27acbacaf make code more flexible 2023-06-05 04:42:53 +02:00
28bf53c037 cleanup 2023-06-05 04:10:27 +02:00
931e40d88f cleanup 2023-06-05 04:10:16 +02:00
d4e382a7f0 begin image recognition neural network implementation 2023-06-05 03:35:16 +02:00
Aliaksei Brown
b9fba20676 scratch for nn 2023-06-04 10:35:05 +02:00
d16267826d remove bandaid stones, TODO: fix astar 2023-05-15 16:58:23 +02:00
Aliaksei Brown
d3ffe50c91 inor fix 2023-05-15 15:42:44 +02:00
Aliaksei Brown
0c3b174078 changes: added a visible mark where the chicken goes to 2023-05-08 21:26:55 +02:00
Aliaksei Brown
0c025a857d changes: added a visible mark where the chicken goes to 2023-05-08 21:15:37 +02:00
5608eb2729 Merge pull request 'astar implementation' (#3) from astar into master
Reviewed-on: #3
2023-05-08 16:29:30 +02:00
2437125100 add heuristic; fix cost method 2023-05-08 16:23:21 +02:00
e6dd642006 implement astar. TODO: fix the stones bandaids 2023-05-08 08:23:04 +02:00
08b318d1e6 purge magic numbers from graph_search, TODO: do the same in tractor.py 2023-04-17 16:53:54 +02:00
ab37f8af99 improve succ() switch tree 2023-04-17 16:35:22 +02:00
e60d18d3f6 Merge pull request 'WIP succ_limitfunction' (#2) from succ_limitfunction into master
Reviewed-on: #2
2023-04-17 16:25:13 +02:00
eaf7ed46fe improve limits of succ(), comments cleanup 2023-04-17 16:24:08 +02:00
2cb44dcb01 partial cleanup of graph_search.py 2023-04-17 16:00:22 +02:00
c5d86faade add: display target field in console 2023-04-17 15:46:17 +02:00
24894482e4 add tractor event; tractor moves using graph_search 2023-04-17 10:20:05 +02:00
425a9bf3e2 provisionary limit to the succ function 2023-04-17 04:05:42 +02:00
15 changed files with 533 additions and 45 deletions

115
astar_search.py Normal file
View File

@ -0,0 +1,115 @@
class Node:
def __init__(self, state, parent='', action='', distance=0):
self.state = state
self.parent = parent
self.action = action
self.distance = distance
class Search:
def __init__(self, cell_size, cell_number):
self.cell_size = cell_size
self.cell_number = cell_number
def succ(self, state):
x = state[0]
y = state[1]
angle = state[2]
match(angle):
case 'UP':
possible = [['left', x, y, 'LEFT'], ['right', x, y, 'RIGHT']]
if y != 0: possible.append(['move', x, y - self.cell_size, 'UP'])
return possible
case 'RIGHT':
possible = [['left', x, y, 'UP'], ['right', x, y, 'DOWN']]
if x != self.cell_size*(self.cell_number-1): possible.append(['move', x + self.cell_size, y, 'RIGHT'])
return possible
case 'DOWN':
possible = [['left', x, y, 'RIGHT'], ['right', x, y, 'LEFT']]
if y != self.cell_size*(self.cell_number-1): possible.append(['move', x, y + self.cell_size, 'DOWN'])
return possible
case 'LEFT':
possible = [['left', x, y, 'DOWN'], ['right', x, y, 'UP']]
if x != 0: possible.append(['move', x - self.cell_size, y, 'LEFT'])
return possible
def cost(self, node, stones, goal, flowers):
# cost = node.distance
cost = 0
# cost += 10 if stones[node.state[0], node.state[1]] == 1 else 1
cost += 1000 if (node.state[0], node.state[1]) in stones else 1
cost += 300 if ((node.state[0]), (node.state[1])) in flowers else 1
if node.parent:
node = node.parent
cost += node.distance # should return only elem.action in prod
return cost
def heuristic(self, node, goal):
return abs(node.state[0] - goal[0]) + abs(node.state[1] - goal[1])
#bandaid to know about stones
def astarsearch(self, istate, goaltest, cStones, cFlowers):
#to be expanded
def cost_old(x, y):
if (x, y) in stones:
return 10
else:
return 1
x = istate[0]
y = istate[1]
angle = istate[2]
stones = [(x*50, y*50) for (x, y) in cStones]
flowers = [(x*50, y*50) for (x, y) in cFlowers]
print(stones)
# fringe = [(Node([x, y, angle]), cost_old(x, y))] # queue (moves/states to check)
fringe = [(Node([x, y, angle]))] # queue (moves/states to check)
fringe[0].distance = self.cost(fringe[0], stones, goaltest, flowers)
fringe.append((Node([x, y, angle]), self.cost(fringe[0], stones, goaltest, flowers)))
fringe.pop(0)
explored = []
while True:
if len(fringe) == 0:
return False
fringe.sort(key=lambda x: x[1])
elem = fringe.pop(0)[0]
# if goal_test(elem.state):
# return
# print(elem.state[0], elem.state[1], elem.state[2])
if elem.state[0] == goaltest[0] and elem.state[1] == goaltest[1]: # checks if we reached the given point
steps = []
while elem.parent:
steps.append([elem.action, elem.state[0], elem.state[1]]) # should return only elem.action in prod
elem = elem.parent
steps.reverse()
print(steps) # only for dev
return steps
explored.append(elem.state)
for (action, state_x, state_y, state_angle) in self.succ(elem.state):
x = Node([state_x, state_y, state_angle], elem, action)
x.parent = elem
priority = self.cost(elem, stones, goaltest, flowers) + self.heuristic(elem, goaltest)
elem.distance = priority
# priority = cost_old(x, y) + self.heuristic(elem, goaltest)
fringe_states = [node.state for (node, p) in fringe]
if x.state not in fringe_states and x.state not in explored:
fringe.append((x, priority))
elif x.state in fringe_states:
for i in range(len(fringe)):
if fringe[i][0].state == x.state:
if fringe[i][1] > priority:
fringe[i] = (x, priority)

View File

@ -5,6 +5,7 @@ import soil
class Blocks: class Blocks:
def __init__(self, parent_screen,cell_size): def __init__(self, parent_screen,cell_size):
self.parent_screen = parent_screen self.parent_screen = parent_screen
self.flower_image = pygame.image.load(r'resources/flower.png').convert_alpha() self.flower_image = pygame.image.load(r'resources/flower.png').convert_alpha()
@ -25,9 +26,12 @@ class Blocks:
self.fawn_wheat_image = pygame.image.load(r'resources/fawn_wheat.png').convert_alpha() self.fawn_wheat_image = pygame.image.load(r'resources/fawn_wheat.png').convert_alpha()
self.fawn_wheat_image = pygame.transform.scale(self.fawn_wheat_image, (cell_size, cell_size)) self.fawn_wheat_image = pygame.transform.scale(self.fawn_wheat_image, (cell_size, cell_size))
self.red_image = pygame.image.load(r'resources/redBush.png').convert_alpha()
self.red_image = pygame.transform.scale(self.red_image, (cell_size, cell_size))
self.soil = soil.Soil() self.soil = soil.Soil()
def locate_blocks(self, blocks_number, cell_number, body): def locate_blocks(self, blocks_number, cell_number, body):
for i in range(blocks_number): for i in range(blocks_number):
self.x = random.randint(0, cell_number-1) self.x = random.randint(0, cell_number-1)
@ -53,6 +57,8 @@ class Blocks:
self.parent_screen.blit(self.fawn_seed_image, (x, y)) self.parent_screen.blit(self.fawn_seed_image, (x, y))
if color == 'fawn_wheat': if color == 'fawn_wheat':
self.parent_screen.blit(self.fawn_wheat_image, (x, y)) self.parent_screen.blit(self.fawn_wheat_image, (x, y))
if color == 'red':
self.parent_screen.blit(self.red_image, (x, y))

View File

@ -6,22 +6,31 @@ class Node:
class Search: class Search:
def __init__(self, cell_size): def __init__(self, cell_size, cell_number):
self.cell_size = cell_size self.cell_size = cell_size
self.cell_number = cell_number
# WARNING! IT EXCEEDS THE PLANE!!! def succ(self, state):
def succ(self, state): # successor function
x = state[0] x = state[0]
y = state[1] y = state[1]
angle = state[2] angle = state[2]
if angle == 0: match(angle):
return [['move', x, y - self.cell_size, 0], ['left', x, y, 270], ['right', x, y, 90]] case 'UP':
if angle == 90: possible = [['left', x, y, 'LEFT'], ['right', x, y, 'RIGHT']]
return [['move', x + self.cell_size, y, 90], ['left', x, y, 0], ['right', x, y, 180]] if y != 0: possible.append(['move', x, y - self.cell_size, 'UP'])
if angle == 180: return possible
return [['move', x, y + self.cell_size, 180], ['left', x, y, 90], ['right', x, y, 270]] case 'RIGHT':
if angle == 270: possible = [['left', x, y, 'UP'], ['right', x, y, 'DOWN']]
return [['move', x - self.cell_size, y, 270], ['left', x, y, 180], ['right', x, y, 0]] if x != self.cell_size*(self.cell_number-1): possible.append(['move', x + self.cell_size, y, 'RIGHT'])
return possible
case 'DOWN':
possible = [['left', x, y, 'RIGHT'], ['right', x, y, 'LEFT']]
if y != self.cell_size*(self.cell_number-1): possible.append(['move', x, y + self.cell_size, 'DOWN'])
return possible
case 'LEFT':
possible = [['left', x, y, 'DOWN'], ['right', x, y, 'UP']]
if x != 0: possible.append(['move', x - self.cell_size, y, 'LEFT'])
return possible
def graphsearch(self, istate, goaltest): def graphsearch(self, istate, goaltest):
x = istate[0] x = istate[0]
@ -44,7 +53,7 @@ class Search:
# print(elem.state[0], elem.state[1], elem.state[2]) # print(elem.state[0], elem.state[1], elem.state[2])
if elem.state[0] == goaltest[0] and elem.state[1] == goaltest[1]: # checks if we reached the given point if elem.state[0] == goaltest[0] and elem.state[1] == goaltest[1]: # checks if we reached the given point
steps = [] steps = []
while elem.parent != '': while elem.parent:
steps.append([elem.action, elem.state[0], elem.state[1]]) # should return only elem.action in prod steps.append([elem.action, elem.state[0], elem.state[1]]) # should return only elem.action in prod
elem = elem.parent elem = elem.parent
@ -55,8 +64,6 @@ class Search:
explored.append(elem.state) explored.append(elem.state)
for (action, state_x, state_y, state_angle) in self.succ(elem.state): for (action, state_x, state_y, state_angle) in self.succ(elem.state):
if state_x < 0 or state_y < 0: # check if any of the values are negative
continue
if [state_x, state_y, state_angle] not in fringe_state and \ if [state_x, state_y, state_angle] not in fringe_state and \
[state_x, state_y, state_angle] not in explored: [state_x, state_y, state_angle] not in explored:
x = Node([state_x, state_y, state_angle]) x = Node([state_x, state_y, state_angle])
@ -64,7 +71,3 @@ class Search:
x.action = action x.action = action
fringe.append(x) fringe.append(x)
fringe_state.append(x.state) fringe_state.append(x.state)
se = Search(50)
se.graphsearch(istate=[50, 50, 0], goaltest=[150, 250])

51
learn_tree.py Normal file
View File

@ -0,0 +1,51 @@
from collections import Counter
def tree_learn(examples, attributes, default_class):
if len(examples) == 0:
return default_class
if all(examples[0][-1] == example[-1] for example in examples):
return examples[0][-1]
if len(attributes) == 0:
class_counts = Counter(example[-1] for example in examples)
majority_class = class_counts.most_common(1)[0][0]
return majority_class
# Choose the attribute A as the root of the decision tree
A = select_attribute(attributes, examples)
tree = {A: {}}
new_attributes = [attr for attr in attributes if attr != A]
new_default_class = Counter(example[-1] for example in examples).most_common(1)[0][0]
for value in get_attribute_values(A):
new_examples = [example for example in examples if example[attributes.index(A)] == value]
subtree = tree_learn(new_examples, new_attributes, new_default_class)
tree[A][value] = subtree
return tree
# Helper function: Select the best attribute based on a certain criterion (e.g., information gain)
def select_attribute(attributes, examples):
# Implement your attribute selection criterion here
pass
# Helper function: Get the possible values of an attribute from the examples
def get_attribute_values(attribute):
# Implement your code to retrieve the attribute values from the examples here
pass
# Example usage with coordinates
examples = [
[1, 2, 'A'],
[3, 4, 'A'],
[5, 6, 'B'],
[7, 8, 'B']
]
attributes = ['x', 'y']
default_class = 'unknown'
decision_tree = tree_learn(examples, attributes, default_class)
print(decision_tree)

58
main.py
View File

@ -1,12 +1,12 @@
import os import os
import pygame import pygame
import random import random
import land import land
import tractor import tractor
import blocks import blocks
import astar_search
import neural_network.inference
from pygame.locals import * from pygame.locals import *
from datetime import datetime
examples = [ examples = [
['piasek', 'sucha', 'jalowa', 'żółty'], ['piasek', 'sucha', 'jalowa', 'żółty'],
@ -93,7 +93,7 @@ class Node:
class Game: class Game:
cell_size = 50 cell_size = 50
cell_number = 15 # horizontally cell_number = 15 # horizontally
blocks_number = 15 blocks_number = 20
def __init__(self): def __init__(self):
@ -103,6 +103,7 @@ class Game:
self.flower_body = [] self.flower_body = []
self.dead_grass_body = [] self.dead_grass_body = []
self.grass_body = [] self.grass_body = []
self.red_block = [] #aim block
self.fawn_seed_body = [] self.fawn_seed_body = []
self.fawn_wheat_body = [] self.fawn_wheat_body = []
@ -135,6 +136,8 @@ class Game:
self.blocks.locate_blocks(self.blocks_number, self.cell_number, self.stone_body) self.blocks.locate_blocks(self.blocks_number, self.cell_number, self.stone_body)
self.blocks.locate_blocks(self.blocks_number, self.cell_number, self.flower_body) self.blocks.locate_blocks(self.blocks_number, self.cell_number, self.flower_body)
#self.blocks.locate_blocks(1, self.cell_number, self.red_block)
# self.potato = blocks.Blocks(self.surface, self.cell_size) # self.potato = blocks.Blocks(self.surface, self.cell_size)
# self.potato.locate_soil('black earth', 6, 1, []) # self.potato.locate_soil('black earth', 6, 1, [])
@ -147,12 +150,17 @@ class Game:
# print(self.potato.get_soil_info().get_irrigation()) # print(self.potato.get_soil_info().get_irrigation())
running = True running = True
clock = pygame.time.Clock() clock = pygame.time.Clock()
# last_time = datetime.now()
move_tractor_event = pygame.USEREVENT + 1
pygame.time.set_timer(move_tractor_event, 500) # tractor moves every 1000 ms
tractor_next_moves = []
astar_search_object = astar_search.Search(self.cell_size, self.cell_number)
veggies = dict()
veggies_debug = dict()
while running: while running:
clock.tick(60) # manual fps control not to overwork the computer clock.tick(60) # manual fps control not to overwork the computer
# time_now = datetime.now()
for event in pygame.event.get(): for event in pygame.event.get():
if event.type == KEYDOWN: if event.type == KEYDOWN:
if pygame.key.get_pressed()[K_ESCAPE]: if pygame.key.get_pressed()[K_ESCAPE]:
@ -173,29 +181,57 @@ class Game:
if pygame.key.get_pressed()[K_q]: if pygame.key.get_pressed()[K_q]:
self.tractor.harvest(self.fawn_seed_body, self.fawn_wheat_body, self.cell_size) self.tractor.harvest(self.fawn_seed_body, self.fawn_wheat_body, self.cell_size)
self.tractor.put_seed(self.fawn_soil_body, self.fawn_seed_body, self.cell_size) self.tractor.put_seed(self.fawn_soil_body, self.fawn_seed_body, self.cell_size)
if event.type == move_tractor_event:
if len(tractor_next_moves) == 0:
random_x = random.randrange(0, self.cell_number * self.cell_size, 50)
random_y = random.randrange(0, self.cell_number * self.cell_size, 50)
print("Generated target: ",random_x, random_y)
if self.red_block:
self.red_block.pop()
self.red_block.append([random_x/50, random_y/50])
# below line should be later moved into tractor.py
angles = {0: 'UP', 90: 'RIGHT', 270: 'LEFT', 180: 'DOWN'}
#bandaid to know about stones
tractor_next_moves = astar_search_object.astarsearch(
[self.tractor.x, self.tractor.y, angles[self.tractor.angle]], [random_x, random_y], self.stone_body, self.flower_body)
current_veggie = next(os.walk('./neural_network/images/test'))[1][random.randint(0, len(next(os.walk('./neural_network/images/test'))[1])-1)]
if(current_veggie in veggies_debug):
veggies_debug[current_veggie]+=1
else:
veggies_debug[current_veggie] = 1
current_veggie_example = next(os.walk(f'./neural_network/images/test/{current_veggie}'))[2][random.randint(0, len(next(os.walk(f'./neural_network/images/test/{current_veggie}'))[2])-1)]
predicted_veggie = neural_network.inference.main(f"./neural_network/images/test/{current_veggie}/{current_veggie_example}")
if predicted_veggie in veggies:
veggies[predicted_veggie]+=1
else:
veggies[predicted_veggie] = 1
print("Debug veggies: ", veggies_debug, "Predicted veggies: ", veggies)
else:
self.tractor.move(tractor_next_moves.pop(0)[0], self.cell_size, self.cell_number)
elif event.type == QUIT: elif event.type == QUIT:
running = False running = False
self.surface.fill((123, 56, 51)) # background color self.surface.fill((123, 56, 51)) # background color
self.grass.set_and_place_block_of_grass('good') self.grass.set_and_place_block_of_grass('good')
self.black_earth.place_soil(self.black_earth_body, 'black_earth') self.black_earth.place_soil(self.black_earth_body, 'black_earth')
self.green_earth.place_soil(self.green_earth_body, 'green_earth') self.green_earth.place_soil(self.green_earth_body, 'green_earth')
self.fawn_soil.place_soil(self.fawn_soil_body, 'fawn_soil') self.fawn_soil.place_soil(self.fawn_soil_body, 'fawn_soil')
self.fen_soil.place_soil(self.fen_soil_body, 'fen_soil') self.fen_soil.place_soil(self.fen_soil_body, 'fen_soil')
#plants examples # plants examples
self.blocks.place_blocks(self.surface, self.cell_size, self.dead_leaf_body, 'leaf') self.blocks.place_blocks(self.surface, self.cell_size, self.dead_leaf_body, 'leaf')
self.blocks.place_blocks(self.surface, self.cell_size, self.green_leaf_body, 'alive') self.blocks.place_blocks(self.surface, self.cell_size, self.green_leaf_body, 'alive')
self.blocks.place_blocks(self.surface, self.cell_size, self.stone_body, 'stone') self.blocks.place_blocks(self.surface, self.cell_size, self.stone_body, 'stone')
self.blocks.place_blocks(self.surface, self.cell_size, self.flower_body, 'flower') self.blocks.place_blocks(self.surface, self.cell_size, self.flower_body, 'flower')
#seeds self.blocks.place_blocks(self.surface, self.cell_size, self.red_block, 'red')
# seeds
self.blocks.place_blocks(self.surface, self.cell_size, self.fawn_seed_body, 'fawn_seed') self.blocks.place_blocks(self.surface, self.cell_size, self.fawn_seed_body, 'fawn_seed')
#wheat # wheat
self.blocks.place_blocks(self.surface, self.cell_size, self.fawn_wheat_body, 'fawn_wheat') self.blocks.place_blocks(self.surface, self.cell_size, self.fawn_wheat_body, 'fawn_wheat')
self.tractor.draw() self.tractor.draw()

View File

@ -0,0 +1,42 @@
import torchvision
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
BATCH_SIZE = 64
train_transform = transforms.Compose([
transforms.Resize((224, 224)), #validate that all images are 224x244
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomVerticalFlip(p=0.5),
transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
transforms.RandomRotation(degrees=(30, 70)), #random effects are applied to prevent overfitting
transforms.ToTensor(),
transforms.Normalize(
mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5]
)
])
valid_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5]
)
])
train_dataset = torchvision.datasets.ImageFolder(root='./images/train', transform=train_transform)
validation_dataset = torchvision.datasets.ImageFolder(root='./images/validation', transform=valid_transform)
train_loader = DataLoader(
train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True
)
valid_loader = DataLoader(
validation_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True
)

View File

@ -0,0 +1,59 @@
import torch
import cv2
import torchvision.transforms as transforms
import argparse
from neural_network.model import CNNModel
# construct the argument parser
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input',
default='',
help='path to the input image')
args = vars(parser.parse_args())
def main(path):
# the computation device
device = ('cuda' if torch.cuda.is_available() else 'cpu')
# list containing all the class labels
labels = [
'bean', 'bitter gourd', 'bottle gourd', 'brinjal', 'broccoli',
'cabbage', 'capsicum', 'carrot', 'cauliflower', 'cucumber',
'papaya', 'potato', 'pumpkin', 'radish', 'tomato'
]
# initialize the model and load the trained weights
model = CNNModel().to(device)
checkpoint = torch.load('./neural_network/outputs/model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
# define preprocess transforms
transform = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5]
)
])
# read and preprocess the image
image = cv2.imread(path)
# get the ground truth class
gt_class = path.split('/')[-2]
orig_image = image.copy()
# convert to RGB format
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = transform(image)
# add batch dimension
image = torch.unsqueeze(image, 0)
with torch.no_grad():
outputs = model(image.to(device))
output_label = torch.topk(outputs, 1)
pred_class = labels[int(output_label.indices)]
return pred_class
if __name__ == "__main__":
main(args['input'])

24
neural_network/model.py Normal file
View File

@ -0,0 +1,24 @@
import torch.nn as nn
import torch.nn.functional as F
class CNNModel(nn.Module): #model of the CNN type
def __init__(self):
super(CNNModel, self).__init__()
self.conv1 = nn.Conv2d(3, 32, 5)
self.conv2 = nn.Conv2d(32, 64, 5)
self.conv3 = nn.Conv2d(64, 128, 3)
self.conv4 = nn.Conv2d(128, 256, 5)
self.fc1 = nn.Linear(256, 50)
self.pool = nn.MaxPool2d(2, 2)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = self.pool(F.relu(self.conv3(x)))
x = self.pool(F.relu(self.conv4(x)))
bs, _, _, _ = x.shape
x = F.adaptive_avg_pool2d(x, 1).reshape(bs, -1)
x = self.fc1(x)
return x

Binary file not shown.

After

Width:  |  Height:  |  Size: 40 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 41 KiB

Binary file not shown.

119
neural_network/train.py Normal file
View File

@ -0,0 +1,119 @@
import torch
import argparse
import torch.nn as nn
import torch.optim as optim
import time
from tqdm.auto import tqdm
from neural_network.model import CNNModel
from neural_network.datasets import train_loader, valid_loader
from neural_network.utils import save_model, save_plots
# construct the argument parser
parser = argparse.ArgumentParser()
parser.add_argument('-e', '--epochs', type=int, default=20,
help='number of epochs to train our network for')
args = vars(parser.parse_args())
lr = 1e-3
epochs = args['epochs']
device = ('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Computation device: {device}\n")
model = CNNModel().to(device)
print(model)
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(
p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} training parameters.")
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# loss function
criterion = nn.CrossEntropyLoss()
# training
def train(model, trainloader, optimizer, criterion):
model.train()
print('Training')
train_running_loss = 0.0
train_running_correct = 0
counter = 0
for i, data in tqdm(enumerate(trainloader), total=len(trainloader)):
counter += 1
image, labels = data
image = image.to(device)
labels = labels.to(device)
optimizer.zero_grad()
# forward pass
outputs = model(image)
# calculate the loss
loss = criterion(outputs, labels)
train_running_loss += loss.item()
# calculate the accuracy
_, preds = torch.max(outputs.data, 1)
train_running_correct += (preds == labels).sum().item()
# backpropagation
loss.backward()
# update the optimizer parameters
optimizer.step()
# loss and accuracy for the complete epoch
epoch_loss = train_running_loss / counter
epoch_acc = 100. * (train_running_correct / len(trainloader.dataset))
return epoch_loss, epoch_acc
# validation
def validate(model, testloader, criterion):
model.eval()
print('Validation')
valid_running_loss = 0.0
valid_running_correct = 0
counter = 0
with torch.no_grad():
for i, data in tqdm(enumerate(testloader), total=len(testloader)):
counter += 1
image, labels = data
image = image.to(device)
labels = labels.to(device)
# forward pass
outputs = model(image)
# calculate the loss
loss = criterion(outputs, labels)
valid_running_loss += loss.item()
# calculate the accuracy
_, preds = torch.max(outputs.data, 1)
valid_running_correct += (preds == labels).sum().item()
# loss and accuracy for the complete epoch
epoch_loss = valid_running_loss / counter
epoch_acc = 100. * (valid_running_correct / len(testloader.dataset))
return epoch_loss, epoch_acc
# lists to keep track of losses and accuracies
train_loss, valid_loss = [], []
train_acc, valid_acc = [], []
# start the training
for epoch in range(epochs):
print(f"[INFO]: Epoch {epoch+1} of {epochs}")
train_epoch_loss, train_epoch_acc = train(model, train_loader,
optimizer, criterion)
valid_epoch_loss, valid_epoch_acc = validate(model, valid_loader,
criterion)
train_loss.append(train_epoch_loss)
valid_loss.append(valid_epoch_loss)
train_acc.append(train_epoch_acc)
valid_acc.append(valid_epoch_acc)
print(f"Training loss: {train_epoch_loss:.3f}, training acc: {train_epoch_acc:.3f}")
print(f"Validation loss: {valid_epoch_loss:.3f}, validation acc: {valid_epoch_acc:.3f}")
print('-'*50)
time.sleep(5)
# save the trained model weights
save_model(epochs, model, optimizer, criterion)
# save the loss and accuracy plots
save_plots(train_acc, valid_acc, train_loss, valid_loss)
print('TRAINING COMPLETE')

49
neural_network/utils.py Normal file
View File

@ -0,0 +1,49 @@
import torch
import matplotlib
import matplotlib.pyplot as plt
matplotlib.style.use('ggplot')
def save_model(epochs, model, optimizer, criterion):
"""
Function to save the trained model to disk.
"""
torch.save({
'epoch': epochs,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': criterion,
}, 'outputs/model.pth')
def save_plots(train_acc, valid_acc, train_loss, valid_loss):
"""
Function to save the loss and accuracy plots to disk.
"""
# accuracy plots
plt.figure(figsize=(10, 7))
plt.plot(
train_acc, color='green', linestyle='-',
label='train accuracy'
)
plt.plot(
valid_acc, color='blue', linestyle='-',
label='validataion accuracy'
)
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.savefig('outputs/accuracy.png')
# loss plots
plt.figure(figsize=(10, 7))
plt.plot(
train_loss, color='orange', linestyle='-',
label='train loss'
)
plt.plot(
valid_loss, color='red', linestyle='-',
label='validataion loss'
)
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.savefig('outputs/loss.png')

BIN
resources/redBush.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.3 KiB

View File

@ -31,22 +31,6 @@ class Tractor:
def move(self, direction, cell_size, cell_number): def move(self, direction, cell_size, cell_number):
# if direction == 'up':
# if self.y != 0:
# self.y -= cell_size
# self.image = self.up
# if direction == 'down':
# if self.y != (cell_number-1)*cell_size:
# self.y += cell_size
# self.image = self.down
# if direction == 'left':
# if self.x != 0:
# self.x -= cell_size
# self.image = self.left
# if direction == 'right':
# if self.x != (cell_number-1)*cell_size:
# self.x += cell_size
# self.image = self.right
if direction == 'move': if direction == 'move':
if self.angle == 0 and self.y != 0: if self.angle == 0 and self.y != 0:
self.y -= cell_size self.y -= cell_size