Compare commits
24 Commits
c0c5d3adcd
...
edf1e43d33
Author | SHA1 | Date | |
---|---|---|---|
edf1e43d33 | |||
0b33ff1803 | |||
86be72ba33 | |||
da73e223e3 | |||
e27acbacaf | |||
28bf53c037 | |||
931e40d88f | |||
d4e382a7f0 | |||
|
b9fba20676 | ||
d16267826d | |||
|
d3ffe50c91 | ||
|
0c3b174078 | ||
|
0c025a857d | ||
5608eb2729 | |||
2437125100 | |||
e6dd642006 | |||
08b318d1e6 | |||
ab37f8af99 | |||
e60d18d3f6 | |||
eaf7ed46fe | |||
2cb44dcb01 | |||
c5d86faade | |||
24894482e4 | |||
425a9bf3e2 |
115
astar_search.py
Normal file
115
astar_search.py
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
class Node:
|
||||||
|
def __init__(self, state, parent='', action='', distance=0):
|
||||||
|
self.state = state
|
||||||
|
self.parent = parent
|
||||||
|
self.action = action
|
||||||
|
self.distance = distance
|
||||||
|
|
||||||
|
class Search:
|
||||||
|
def __init__(self, cell_size, cell_number):
|
||||||
|
self.cell_size = cell_size
|
||||||
|
self.cell_number = cell_number
|
||||||
|
|
||||||
|
def succ(self, state):
|
||||||
|
x = state[0]
|
||||||
|
y = state[1]
|
||||||
|
angle = state[2]
|
||||||
|
match(angle):
|
||||||
|
case 'UP':
|
||||||
|
possible = [['left', x, y, 'LEFT'], ['right', x, y, 'RIGHT']]
|
||||||
|
if y != 0: possible.append(['move', x, y - self.cell_size, 'UP'])
|
||||||
|
return possible
|
||||||
|
case 'RIGHT':
|
||||||
|
possible = [['left', x, y, 'UP'], ['right', x, y, 'DOWN']]
|
||||||
|
if x != self.cell_size*(self.cell_number-1): possible.append(['move', x + self.cell_size, y, 'RIGHT'])
|
||||||
|
return possible
|
||||||
|
case 'DOWN':
|
||||||
|
possible = [['left', x, y, 'RIGHT'], ['right', x, y, 'LEFT']]
|
||||||
|
if y != self.cell_size*(self.cell_number-1): possible.append(['move', x, y + self.cell_size, 'DOWN'])
|
||||||
|
return possible
|
||||||
|
case 'LEFT':
|
||||||
|
possible = [['left', x, y, 'DOWN'], ['right', x, y, 'UP']]
|
||||||
|
if x != 0: possible.append(['move', x - self.cell_size, y, 'LEFT'])
|
||||||
|
return possible
|
||||||
|
|
||||||
|
def cost(self, node, stones, goal, flowers):
|
||||||
|
# cost = node.distance
|
||||||
|
cost = 0
|
||||||
|
# cost += 10 if stones[node.state[0], node.state[1]] == 1 else 1
|
||||||
|
cost += 1000 if (node.state[0], node.state[1]) in stones else 1
|
||||||
|
cost += 300 if ((node.state[0]), (node.state[1])) in flowers else 1
|
||||||
|
|
||||||
|
if node.parent:
|
||||||
|
node = node.parent
|
||||||
|
cost += node.distance # should return only elem.action in prod
|
||||||
|
return cost
|
||||||
|
|
||||||
|
def heuristic(self, node, goal):
|
||||||
|
return abs(node.state[0] - goal[0]) + abs(node.state[1] - goal[1])
|
||||||
|
|
||||||
|
#bandaid to know about stones
|
||||||
|
def astarsearch(self, istate, goaltest, cStones, cFlowers):
|
||||||
|
|
||||||
|
#to be expanded
|
||||||
|
def cost_old(x, y):
|
||||||
|
if (x, y) in stones:
|
||||||
|
return 10
|
||||||
|
else:
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
x = istate[0]
|
||||||
|
y = istate[1]
|
||||||
|
angle = istate[2]
|
||||||
|
|
||||||
|
stones = [(x*50, y*50) for (x, y) in cStones]
|
||||||
|
flowers = [(x*50, y*50) for (x, y) in cFlowers]
|
||||||
|
|
||||||
|
print(stones)
|
||||||
|
|
||||||
|
# fringe = [(Node([x, y, angle]), cost_old(x, y))] # queue (moves/states to check)
|
||||||
|
fringe = [(Node([x, y, angle]))] # queue (moves/states to check)
|
||||||
|
fringe[0].distance = self.cost(fringe[0], stones, goaltest, flowers)
|
||||||
|
fringe.append((Node([x, y, angle]), self.cost(fringe[0], stones, goaltest, flowers)))
|
||||||
|
fringe.pop(0)
|
||||||
|
|
||||||
|
explored = []
|
||||||
|
|
||||||
|
while True:
|
||||||
|
if len(fringe) == 0:
|
||||||
|
return False
|
||||||
|
|
||||||
|
fringe.sort(key=lambda x: x[1])
|
||||||
|
elem = fringe.pop(0)[0]
|
||||||
|
|
||||||
|
# if goal_test(elem.state):
|
||||||
|
# return
|
||||||
|
# print(elem.state[0], elem.state[1], elem.state[2])
|
||||||
|
if elem.state[0] == goaltest[0] and elem.state[1] == goaltest[1]: # checks if we reached the given point
|
||||||
|
steps = []
|
||||||
|
while elem.parent:
|
||||||
|
steps.append([elem.action, elem.state[0], elem.state[1]]) # should return only elem.action in prod
|
||||||
|
elem = elem.parent
|
||||||
|
|
||||||
|
steps.reverse()
|
||||||
|
print(steps) # only for dev
|
||||||
|
return steps
|
||||||
|
|
||||||
|
explored.append(elem.state)
|
||||||
|
|
||||||
|
for (action, state_x, state_y, state_angle) in self.succ(elem.state):
|
||||||
|
x = Node([state_x, state_y, state_angle], elem, action)
|
||||||
|
x.parent = elem
|
||||||
|
|
||||||
|
priority = self.cost(elem, stones, goaltest, flowers) + self.heuristic(elem, goaltest)
|
||||||
|
elem.distance = priority
|
||||||
|
# priority = cost_old(x, y) + self.heuristic(elem, goaltest)
|
||||||
|
fringe_states = [node.state for (node, p) in fringe]
|
||||||
|
|
||||||
|
if x.state not in fringe_states and x.state not in explored:
|
||||||
|
fringe.append((x, priority))
|
||||||
|
elif x.state in fringe_states:
|
||||||
|
for i in range(len(fringe)):
|
||||||
|
if fringe[i][0].state == x.state:
|
||||||
|
if fringe[i][1] > priority:
|
||||||
|
fringe[i] = (x, priority)
|
@ -5,6 +5,7 @@ import soil
|
|||||||
|
|
||||||
|
|
||||||
class Blocks:
|
class Blocks:
|
||||||
|
|
||||||
def __init__(self, parent_screen,cell_size):
|
def __init__(self, parent_screen,cell_size):
|
||||||
self.parent_screen = parent_screen
|
self.parent_screen = parent_screen
|
||||||
self.flower_image = pygame.image.load(r'resources/flower.png').convert_alpha()
|
self.flower_image = pygame.image.load(r'resources/flower.png').convert_alpha()
|
||||||
@ -25,9 +26,12 @@ class Blocks:
|
|||||||
self.fawn_wheat_image = pygame.image.load(r'resources/fawn_wheat.png').convert_alpha()
|
self.fawn_wheat_image = pygame.image.load(r'resources/fawn_wheat.png').convert_alpha()
|
||||||
self.fawn_wheat_image = pygame.transform.scale(self.fawn_wheat_image, (cell_size, cell_size))
|
self.fawn_wheat_image = pygame.transform.scale(self.fawn_wheat_image, (cell_size, cell_size))
|
||||||
|
|
||||||
|
self.red_image = pygame.image.load(r'resources/redBush.png').convert_alpha()
|
||||||
|
self.red_image = pygame.transform.scale(self.red_image, (cell_size, cell_size))
|
||||||
|
|
||||||
self.soil = soil.Soil()
|
self.soil = soil.Soil()
|
||||||
|
|
||||||
|
|
||||||
def locate_blocks(self, blocks_number, cell_number, body):
|
def locate_blocks(self, blocks_number, cell_number, body):
|
||||||
for i in range(blocks_number):
|
for i in range(blocks_number):
|
||||||
self.x = random.randint(0, cell_number-1)
|
self.x = random.randint(0, cell_number-1)
|
||||||
@ -53,6 +57,8 @@ class Blocks:
|
|||||||
self.parent_screen.blit(self.fawn_seed_image, (x, y))
|
self.parent_screen.blit(self.fawn_seed_image, (x, y))
|
||||||
if color == 'fawn_wheat':
|
if color == 'fawn_wheat':
|
||||||
self.parent_screen.blit(self.fawn_wheat_image, (x, y))
|
self.parent_screen.blit(self.fawn_wheat_image, (x, y))
|
||||||
|
if color == 'red':
|
||||||
|
self.parent_screen.blit(self.red_image, (x, y))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -6,22 +6,31 @@ class Node:
|
|||||||
|
|
||||||
|
|
||||||
class Search:
|
class Search:
|
||||||
def __init__(self, cell_size):
|
def __init__(self, cell_size, cell_number):
|
||||||
self.cell_size = cell_size
|
self.cell_size = cell_size
|
||||||
|
self.cell_number = cell_number
|
||||||
|
|
||||||
# WARNING! IT EXCEEDS THE PLANE!!!
|
def succ(self, state):
|
||||||
def succ(self, state): # successor function
|
|
||||||
x = state[0]
|
x = state[0]
|
||||||
y = state[1]
|
y = state[1]
|
||||||
angle = state[2]
|
angle = state[2]
|
||||||
if angle == 0:
|
match(angle):
|
||||||
return [['move', x, y - self.cell_size, 0], ['left', x, y, 270], ['right', x, y, 90]]
|
case 'UP':
|
||||||
if angle == 90:
|
possible = [['left', x, y, 'LEFT'], ['right', x, y, 'RIGHT']]
|
||||||
return [['move', x + self.cell_size, y, 90], ['left', x, y, 0], ['right', x, y, 180]]
|
if y != 0: possible.append(['move', x, y - self.cell_size, 'UP'])
|
||||||
if angle == 180:
|
return possible
|
||||||
return [['move', x, y + self.cell_size, 180], ['left', x, y, 90], ['right', x, y, 270]]
|
case 'RIGHT':
|
||||||
if angle == 270:
|
possible = [['left', x, y, 'UP'], ['right', x, y, 'DOWN']]
|
||||||
return [['move', x - self.cell_size, y, 270], ['left', x, y, 180], ['right', x, y, 0]]
|
if x != self.cell_size*(self.cell_number-1): possible.append(['move', x + self.cell_size, y, 'RIGHT'])
|
||||||
|
return possible
|
||||||
|
case 'DOWN':
|
||||||
|
possible = [['left', x, y, 'RIGHT'], ['right', x, y, 'LEFT']]
|
||||||
|
if y != self.cell_size*(self.cell_number-1): possible.append(['move', x, y + self.cell_size, 'DOWN'])
|
||||||
|
return possible
|
||||||
|
case 'LEFT':
|
||||||
|
possible = [['left', x, y, 'DOWN'], ['right', x, y, 'UP']]
|
||||||
|
if x != 0: possible.append(['move', x - self.cell_size, y, 'LEFT'])
|
||||||
|
return possible
|
||||||
|
|
||||||
def graphsearch(self, istate, goaltest):
|
def graphsearch(self, istate, goaltest):
|
||||||
x = istate[0]
|
x = istate[0]
|
||||||
@ -44,7 +53,7 @@ class Search:
|
|||||||
# print(elem.state[0], elem.state[1], elem.state[2])
|
# print(elem.state[0], elem.state[1], elem.state[2])
|
||||||
if elem.state[0] == goaltest[0] and elem.state[1] == goaltest[1]: # checks if we reached the given point
|
if elem.state[0] == goaltest[0] and elem.state[1] == goaltest[1]: # checks if we reached the given point
|
||||||
steps = []
|
steps = []
|
||||||
while elem.parent != '':
|
while elem.parent:
|
||||||
steps.append([elem.action, elem.state[0], elem.state[1]]) # should return only elem.action in prod
|
steps.append([elem.action, elem.state[0], elem.state[1]]) # should return only elem.action in prod
|
||||||
elem = elem.parent
|
elem = elem.parent
|
||||||
|
|
||||||
@ -55,8 +64,6 @@ class Search:
|
|||||||
explored.append(elem.state)
|
explored.append(elem.state)
|
||||||
|
|
||||||
for (action, state_x, state_y, state_angle) in self.succ(elem.state):
|
for (action, state_x, state_y, state_angle) in self.succ(elem.state):
|
||||||
if state_x < 0 or state_y < 0: # check if any of the values are negative
|
|
||||||
continue
|
|
||||||
if [state_x, state_y, state_angle] not in fringe_state and \
|
if [state_x, state_y, state_angle] not in fringe_state and \
|
||||||
[state_x, state_y, state_angle] not in explored:
|
[state_x, state_y, state_angle] not in explored:
|
||||||
x = Node([state_x, state_y, state_angle])
|
x = Node([state_x, state_y, state_angle])
|
||||||
@ -64,7 +71,3 @@ class Search:
|
|||||||
x.action = action
|
x.action = action
|
||||||
fringe.append(x)
|
fringe.append(x)
|
||||||
fringe_state.append(x.state)
|
fringe_state.append(x.state)
|
||||||
|
|
||||||
|
|
||||||
se = Search(50)
|
|
||||||
se.graphsearch(istate=[50, 50, 0], goaltest=[150, 250])
|
|
||||||
|
51
learn_tree.py
Normal file
51
learn_tree.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
from collections import Counter
|
||||||
|
|
||||||
|
def tree_learn(examples, attributes, default_class):
|
||||||
|
if len(examples) == 0:
|
||||||
|
return default_class
|
||||||
|
|
||||||
|
if all(examples[0][-1] == example[-1] for example in examples):
|
||||||
|
return examples[0][-1]
|
||||||
|
|
||||||
|
if len(attributes) == 0:
|
||||||
|
class_counts = Counter(example[-1] for example in examples)
|
||||||
|
majority_class = class_counts.most_common(1)[0][0]
|
||||||
|
return majority_class
|
||||||
|
|
||||||
|
# Choose the attribute A as the root of the decision tree
|
||||||
|
A = select_attribute(attributes, examples)
|
||||||
|
|
||||||
|
tree = {A: {}}
|
||||||
|
new_attributes = [attr for attr in attributes if attr != A]
|
||||||
|
new_default_class = Counter(example[-1] for example in examples).most_common(1)[0][0]
|
||||||
|
|
||||||
|
for value in get_attribute_values(A):
|
||||||
|
new_examples = [example for example in examples if example[attributes.index(A)] == value]
|
||||||
|
subtree = tree_learn(new_examples, new_attributes, new_default_class)
|
||||||
|
tree[A][value] = subtree
|
||||||
|
|
||||||
|
return tree
|
||||||
|
|
||||||
|
# Helper function: Select the best attribute based on a certain criterion (e.g., information gain)
|
||||||
|
def select_attribute(attributes, examples):
|
||||||
|
# Implement your attribute selection criterion here
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Helper function: Get the possible values of an attribute from the examples
|
||||||
|
def get_attribute_values(attribute):
|
||||||
|
# Implement your code to retrieve the attribute values from the examples here
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Example usage with coordinates
|
||||||
|
examples = [
|
||||||
|
[1, 2, 'A'],
|
||||||
|
[3, 4, 'A'],
|
||||||
|
[5, 6, 'B'],
|
||||||
|
[7, 8, 'B']
|
||||||
|
]
|
||||||
|
|
||||||
|
attributes = ['x', 'y']
|
||||||
|
default_class = 'unknown'
|
||||||
|
|
||||||
|
decision_tree = tree_learn(examples, attributes, default_class)
|
||||||
|
print(decision_tree)
|
58
main.py
58
main.py
@ -1,12 +1,12 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
import pygame
|
import pygame
|
||||||
import random
|
import random
|
||||||
import land
|
import land
|
||||||
import tractor
|
import tractor
|
||||||
import blocks
|
import blocks
|
||||||
|
import astar_search
|
||||||
|
import neural_network.inference
|
||||||
from pygame.locals import *
|
from pygame.locals import *
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
examples = [
|
examples = [
|
||||||
['piasek', 'sucha', 'jalowa', 'żółty'],
|
['piasek', 'sucha', 'jalowa', 'żółty'],
|
||||||
@ -93,7 +93,7 @@ class Node:
|
|||||||
class Game:
|
class Game:
|
||||||
cell_size = 50
|
cell_size = 50
|
||||||
cell_number = 15 # horizontally
|
cell_number = 15 # horizontally
|
||||||
blocks_number = 15
|
blocks_number = 20
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
||||||
@ -103,6 +103,7 @@ class Game:
|
|||||||
self.flower_body = []
|
self.flower_body = []
|
||||||
self.dead_grass_body = []
|
self.dead_grass_body = []
|
||||||
self.grass_body = []
|
self.grass_body = []
|
||||||
|
self.red_block = [] #aim block
|
||||||
|
|
||||||
self.fawn_seed_body = []
|
self.fawn_seed_body = []
|
||||||
self.fawn_wheat_body = []
|
self.fawn_wheat_body = []
|
||||||
@ -135,6 +136,8 @@ class Game:
|
|||||||
self.blocks.locate_blocks(self.blocks_number, self.cell_number, self.stone_body)
|
self.blocks.locate_blocks(self.blocks_number, self.cell_number, self.stone_body)
|
||||||
self.blocks.locate_blocks(self.blocks_number, self.cell_number, self.flower_body)
|
self.blocks.locate_blocks(self.blocks_number, self.cell_number, self.flower_body)
|
||||||
|
|
||||||
|
#self.blocks.locate_blocks(1, self.cell_number, self.red_block)
|
||||||
|
|
||||||
# self.potato = blocks.Blocks(self.surface, self.cell_size)
|
# self.potato = blocks.Blocks(self.surface, self.cell_size)
|
||||||
# self.potato.locate_soil('black earth', 6, 1, [])
|
# self.potato.locate_soil('black earth', 6, 1, [])
|
||||||
|
|
||||||
@ -147,12 +150,17 @@ class Game:
|
|||||||
# print(self.potato.get_soil_info().get_irrigation())
|
# print(self.potato.get_soil_info().get_irrigation())
|
||||||
running = True
|
running = True
|
||||||
clock = pygame.time.Clock()
|
clock = pygame.time.Clock()
|
||||||
# last_time = datetime.now()
|
|
||||||
|
move_tractor_event = pygame.USEREVENT + 1
|
||||||
|
pygame.time.set_timer(move_tractor_event, 500) # tractor moves every 1000 ms
|
||||||
|
tractor_next_moves = []
|
||||||
|
astar_search_object = astar_search.Search(self.cell_size, self.cell_number)
|
||||||
|
|
||||||
|
veggies = dict()
|
||||||
|
veggies_debug = dict()
|
||||||
|
|
||||||
while running:
|
while running:
|
||||||
clock.tick(60) # manual fps control not to overwork the computer
|
clock.tick(60) # manual fps control not to overwork the computer
|
||||||
# time_now = datetime.now()
|
|
||||||
|
|
||||||
for event in pygame.event.get():
|
for event in pygame.event.get():
|
||||||
if event.type == KEYDOWN:
|
if event.type == KEYDOWN:
|
||||||
if pygame.key.get_pressed()[K_ESCAPE]:
|
if pygame.key.get_pressed()[K_ESCAPE]:
|
||||||
@ -173,29 +181,57 @@ class Game:
|
|||||||
if pygame.key.get_pressed()[K_q]:
|
if pygame.key.get_pressed()[K_q]:
|
||||||
self.tractor.harvest(self.fawn_seed_body, self.fawn_wheat_body, self.cell_size)
|
self.tractor.harvest(self.fawn_seed_body, self.fawn_wheat_body, self.cell_size)
|
||||||
self.tractor.put_seed(self.fawn_soil_body, self.fawn_seed_body, self.cell_size)
|
self.tractor.put_seed(self.fawn_soil_body, self.fawn_seed_body, self.cell_size)
|
||||||
|
if event.type == move_tractor_event:
|
||||||
|
if len(tractor_next_moves) == 0:
|
||||||
|
random_x = random.randrange(0, self.cell_number * self.cell_size, 50)
|
||||||
|
random_y = random.randrange(0, self.cell_number * self.cell_size, 50)
|
||||||
|
print("Generated target: ",random_x, random_y)
|
||||||
|
if self.red_block:
|
||||||
|
self.red_block.pop()
|
||||||
|
self.red_block.append([random_x/50, random_y/50])
|
||||||
|
# below line should be later moved into tractor.py
|
||||||
|
angles = {0: 'UP', 90: 'RIGHT', 270: 'LEFT', 180: 'DOWN'}
|
||||||
|
#bandaid to know about stones
|
||||||
|
tractor_next_moves = astar_search_object.astarsearch(
|
||||||
|
[self.tractor.x, self.tractor.y, angles[self.tractor.angle]], [random_x, random_y], self.stone_body, self.flower_body)
|
||||||
|
current_veggie = next(os.walk('./neural_network/images/test'))[1][random.randint(0, len(next(os.walk('./neural_network/images/test'))[1])-1)]
|
||||||
|
if(current_veggie in veggies_debug):
|
||||||
|
veggies_debug[current_veggie]+=1
|
||||||
|
else:
|
||||||
|
veggies_debug[current_veggie] = 1
|
||||||
|
|
||||||
|
current_veggie_example = next(os.walk(f'./neural_network/images/test/{current_veggie}'))[2][random.randint(0, len(next(os.walk(f'./neural_network/images/test/{current_veggie}'))[2])-1)]
|
||||||
|
predicted_veggie = neural_network.inference.main(f"./neural_network/images/test/{current_veggie}/{current_veggie_example}")
|
||||||
|
if predicted_veggie in veggies:
|
||||||
|
veggies[predicted_veggie]+=1
|
||||||
|
else:
|
||||||
|
veggies[predicted_veggie] = 1
|
||||||
|
print("Debug veggies: ", veggies_debug, "Predicted veggies: ", veggies)
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.tractor.move(tractor_next_moves.pop(0)[0], self.cell_size, self.cell_number)
|
||||||
elif event.type == QUIT:
|
elif event.type == QUIT:
|
||||||
running = False
|
running = False
|
||||||
|
|
||||||
self.surface.fill((123, 56, 51)) # background color
|
self.surface.fill((123, 56, 51)) # background color
|
||||||
|
|
||||||
self.grass.set_and_place_block_of_grass('good')
|
self.grass.set_and_place_block_of_grass('good')
|
||||||
self.black_earth.place_soil(self.black_earth_body, 'black_earth')
|
self.black_earth.place_soil(self.black_earth_body, 'black_earth')
|
||||||
self.green_earth.place_soil(self.green_earth_body, 'green_earth')
|
self.green_earth.place_soil(self.green_earth_body, 'green_earth')
|
||||||
self.fawn_soil.place_soil(self.fawn_soil_body, 'fawn_soil')
|
self.fawn_soil.place_soil(self.fawn_soil_body, 'fawn_soil')
|
||||||
self.fen_soil.place_soil(self.fen_soil_body, 'fen_soil')
|
self.fen_soil.place_soil(self.fen_soil_body, 'fen_soil')
|
||||||
|
|
||||||
#plants examples
|
# plants examples
|
||||||
self.blocks.place_blocks(self.surface, self.cell_size, self.dead_leaf_body, 'leaf')
|
self.blocks.place_blocks(self.surface, self.cell_size, self.dead_leaf_body, 'leaf')
|
||||||
self.blocks.place_blocks(self.surface, self.cell_size, self.green_leaf_body, 'alive')
|
self.blocks.place_blocks(self.surface, self.cell_size, self.green_leaf_body, 'alive')
|
||||||
self.blocks.place_blocks(self.surface, self.cell_size, self.stone_body, 'stone')
|
self.blocks.place_blocks(self.surface, self.cell_size, self.stone_body, 'stone')
|
||||||
self.blocks.place_blocks(self.surface, self.cell_size, self.flower_body, 'flower')
|
self.blocks.place_blocks(self.surface, self.cell_size, self.flower_body, 'flower')
|
||||||
|
|
||||||
#seeds
|
self.blocks.place_blocks(self.surface, self.cell_size, self.red_block, 'red')
|
||||||
|
|
||||||
|
# seeds
|
||||||
self.blocks.place_blocks(self.surface, self.cell_size, self.fawn_seed_body, 'fawn_seed')
|
self.blocks.place_blocks(self.surface, self.cell_size, self.fawn_seed_body, 'fawn_seed')
|
||||||
|
|
||||||
#wheat
|
# wheat
|
||||||
self.blocks.place_blocks(self.surface, self.cell_size, self.fawn_wheat_body, 'fawn_wheat')
|
self.blocks.place_blocks(self.surface, self.cell_size, self.fawn_wheat_body, 'fawn_wheat')
|
||||||
|
|
||||||
self.tractor.draw()
|
self.tractor.draw()
|
||||||
|
42
neural_network/datasets.py
Normal file
42
neural_network/datasets.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
import torchvision
|
||||||
|
import torch
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
BATCH_SIZE = 64
|
||||||
|
|
||||||
|
|
||||||
|
train_transform = transforms.Compose([
|
||||||
|
transforms.Resize((224, 224)), #validate that all images are 224x244
|
||||||
|
transforms.RandomHorizontalFlip(p=0.5),
|
||||||
|
transforms.RandomVerticalFlip(p=0.5),
|
||||||
|
transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
|
||||||
|
transforms.RandomRotation(degrees=(30, 70)), #random effects are applied to prevent overfitting
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(
|
||||||
|
mean=[0.5, 0.5, 0.5],
|
||||||
|
std=[0.5, 0.5, 0.5]
|
||||||
|
)
|
||||||
|
])
|
||||||
|
|
||||||
|
valid_transform = transforms.Compose([
|
||||||
|
transforms.Resize((224, 224)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(
|
||||||
|
mean=[0.5, 0.5, 0.5],
|
||||||
|
std=[0.5, 0.5, 0.5]
|
||||||
|
)
|
||||||
|
])
|
||||||
|
|
||||||
|
train_dataset = torchvision.datasets.ImageFolder(root='./images/train', transform=train_transform)
|
||||||
|
|
||||||
|
validation_dataset = torchvision.datasets.ImageFolder(root='./images/validation', transform=valid_transform)
|
||||||
|
|
||||||
|
train_loader = DataLoader(
|
||||||
|
train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True
|
||||||
|
)
|
||||||
|
|
||||||
|
valid_loader = DataLoader(
|
||||||
|
validation_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True
|
||||||
|
)
|
59
neural_network/inference.py
Normal file
59
neural_network/inference.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
import torch
|
||||||
|
import cv2
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
import argparse
|
||||||
|
from neural_network.model import CNNModel
|
||||||
|
# construct the argument parser
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('-i', '--input',
|
||||||
|
default='',
|
||||||
|
help='path to the input image')
|
||||||
|
args = vars(parser.parse_args())
|
||||||
|
|
||||||
|
def main(path):
|
||||||
|
# the computation device
|
||||||
|
device = ('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
# list containing all the class labels
|
||||||
|
labels = [
|
||||||
|
'bean', 'bitter gourd', 'bottle gourd', 'brinjal', 'broccoli',
|
||||||
|
'cabbage', 'capsicum', 'carrot', 'cauliflower', 'cucumber',
|
||||||
|
'papaya', 'potato', 'pumpkin', 'radish', 'tomato'
|
||||||
|
]
|
||||||
|
|
||||||
|
# initialize the model and load the trained weights
|
||||||
|
model = CNNModel().to(device)
|
||||||
|
checkpoint = torch.load('./neural_network/outputs/model.pth', map_location=device)
|
||||||
|
model.load_state_dict(checkpoint['model_state_dict'])
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# define preprocess transforms
|
||||||
|
transform = transforms.Compose([
|
||||||
|
transforms.ToPILImage(),
|
||||||
|
transforms.Resize(224),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(
|
||||||
|
mean=[0.5, 0.5, 0.5],
|
||||||
|
std=[0.5, 0.5, 0.5]
|
||||||
|
)
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
# read and preprocess the image
|
||||||
|
image = cv2.imread(path)
|
||||||
|
# get the ground truth class
|
||||||
|
gt_class = path.split('/')[-2]
|
||||||
|
orig_image = image.copy()
|
||||||
|
# convert to RGB format
|
||||||
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||||
|
image = transform(image)
|
||||||
|
# add batch dimension
|
||||||
|
image = torch.unsqueeze(image, 0)
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(image.to(device))
|
||||||
|
output_label = torch.topk(outputs, 1)
|
||||||
|
pred_class = labels[int(output_label.indices)]
|
||||||
|
|
||||||
|
return pred_class
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main(args['input'])
|
24
neural_network/model.py
Normal file
24
neural_network/model.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
class CNNModel(nn.Module): #model of the CNN type
|
||||||
|
def __init__(self):
|
||||||
|
super(CNNModel, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(3, 32, 5)
|
||||||
|
self.conv2 = nn.Conv2d(32, 64, 5)
|
||||||
|
self.conv3 = nn.Conv2d(64, 128, 3)
|
||||||
|
self.conv4 = nn.Conv2d(128, 256, 5)
|
||||||
|
|
||||||
|
self.fc1 = nn.Linear(256, 50)
|
||||||
|
|
||||||
|
self.pool = nn.MaxPool2d(2, 2)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.pool(F.relu(self.conv1(x)))
|
||||||
|
x = self.pool(F.relu(self.conv2(x)))
|
||||||
|
x = self.pool(F.relu(self.conv3(x)))
|
||||||
|
x = self.pool(F.relu(self.conv4(x)))
|
||||||
|
bs, _, _, _ = x.shape
|
||||||
|
x = F.adaptive_avg_pool2d(x, 1).reshape(bs, -1)
|
||||||
|
x = self.fc1(x)
|
||||||
|
return x
|
BIN
neural_network/outputs/accuracy.png
Normal file
BIN
neural_network/outputs/accuracy.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 40 KiB |
BIN
neural_network/outputs/loss.png
Normal file
BIN
neural_network/outputs/loss.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 41 KiB |
BIN
neural_network/outputs/model.pth
Normal file
BIN
neural_network/outputs/model.pth
Normal file
Binary file not shown.
119
neural_network/train.py
Normal file
119
neural_network/train.py
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
import torch
|
||||||
|
import argparse
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
import time
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
from neural_network.model import CNNModel
|
||||||
|
from neural_network.datasets import train_loader, valid_loader
|
||||||
|
from neural_network.utils import save_model, save_plots
|
||||||
|
|
||||||
|
# construct the argument parser
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('-e', '--epochs', type=int, default=20,
|
||||||
|
help='number of epochs to train our network for')
|
||||||
|
args = vars(parser.parse_args())
|
||||||
|
|
||||||
|
|
||||||
|
lr = 1e-3
|
||||||
|
epochs = args['epochs']
|
||||||
|
device = ('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
print(f"Computation device: {device}\n")
|
||||||
|
|
||||||
|
model = CNNModel().to(device)
|
||||||
|
print(model)
|
||||||
|
|
||||||
|
total_params = sum(p.numel() for p in model.parameters())
|
||||||
|
print(f"{total_params:,} total parameters.")
|
||||||
|
total_trainable_params = sum(
|
||||||
|
p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
print(f"{total_trainable_params:,} training parameters.")
|
||||||
|
# optimizer
|
||||||
|
optimizer = optim.Adam(model.parameters(), lr=lr)
|
||||||
|
# loss function
|
||||||
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
|
||||||
|
|
||||||
|
# training
|
||||||
|
def train(model, trainloader, optimizer, criterion):
|
||||||
|
model.train()
|
||||||
|
print('Training')
|
||||||
|
train_running_loss = 0.0
|
||||||
|
train_running_correct = 0
|
||||||
|
counter = 0
|
||||||
|
for i, data in tqdm(enumerate(trainloader), total=len(trainloader)):
|
||||||
|
counter += 1
|
||||||
|
image, labels = data
|
||||||
|
image = image.to(device)
|
||||||
|
labels = labels.to(device)
|
||||||
|
optimizer.zero_grad()
|
||||||
|
# forward pass
|
||||||
|
outputs = model(image)
|
||||||
|
# calculate the loss
|
||||||
|
loss = criterion(outputs, labels)
|
||||||
|
train_running_loss += loss.item()
|
||||||
|
# calculate the accuracy
|
||||||
|
_, preds = torch.max(outputs.data, 1)
|
||||||
|
train_running_correct += (preds == labels).sum().item()
|
||||||
|
# backpropagation
|
||||||
|
loss.backward()
|
||||||
|
# update the optimizer parameters
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
# loss and accuracy for the complete epoch
|
||||||
|
epoch_loss = train_running_loss / counter
|
||||||
|
epoch_acc = 100. * (train_running_correct / len(trainloader.dataset))
|
||||||
|
return epoch_loss, epoch_acc
|
||||||
|
|
||||||
|
# validation
|
||||||
|
def validate(model, testloader, criterion):
|
||||||
|
model.eval()
|
||||||
|
print('Validation')
|
||||||
|
valid_running_loss = 0.0
|
||||||
|
valid_running_correct = 0
|
||||||
|
counter = 0
|
||||||
|
with torch.no_grad():
|
||||||
|
for i, data in tqdm(enumerate(testloader), total=len(testloader)):
|
||||||
|
counter += 1
|
||||||
|
|
||||||
|
image, labels = data
|
||||||
|
image = image.to(device)
|
||||||
|
labels = labels.to(device)
|
||||||
|
# forward pass
|
||||||
|
outputs = model(image)
|
||||||
|
# calculate the loss
|
||||||
|
loss = criterion(outputs, labels)
|
||||||
|
valid_running_loss += loss.item()
|
||||||
|
# calculate the accuracy
|
||||||
|
_, preds = torch.max(outputs.data, 1)
|
||||||
|
valid_running_correct += (preds == labels).sum().item()
|
||||||
|
|
||||||
|
# loss and accuracy for the complete epoch
|
||||||
|
epoch_loss = valid_running_loss / counter
|
||||||
|
epoch_acc = 100. * (valid_running_correct / len(testloader.dataset))
|
||||||
|
return epoch_loss, epoch_acc
|
||||||
|
|
||||||
|
# lists to keep track of losses and accuracies
|
||||||
|
train_loss, valid_loss = [], []
|
||||||
|
train_acc, valid_acc = [], []
|
||||||
|
# start the training
|
||||||
|
for epoch in range(epochs):
|
||||||
|
print(f"[INFO]: Epoch {epoch+1} of {epochs}")
|
||||||
|
train_epoch_loss, train_epoch_acc = train(model, train_loader,
|
||||||
|
optimizer, criterion)
|
||||||
|
valid_epoch_loss, valid_epoch_acc = validate(model, valid_loader,
|
||||||
|
criterion)
|
||||||
|
train_loss.append(train_epoch_loss)
|
||||||
|
valid_loss.append(valid_epoch_loss)
|
||||||
|
train_acc.append(train_epoch_acc)
|
||||||
|
valid_acc.append(valid_epoch_acc)
|
||||||
|
print(f"Training loss: {train_epoch_loss:.3f}, training acc: {train_epoch_acc:.3f}")
|
||||||
|
print(f"Validation loss: {valid_epoch_loss:.3f}, validation acc: {valid_epoch_acc:.3f}")
|
||||||
|
print('-'*50)
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
# save the trained model weights
|
||||||
|
save_model(epochs, model, optimizer, criterion)
|
||||||
|
# save the loss and accuracy plots
|
||||||
|
save_plots(train_acc, valid_acc, train_loss, valid_loss)
|
||||||
|
print('TRAINING COMPLETE')
|
49
neural_network/utils.py
Normal file
49
neural_network/utils.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
import torch
|
||||||
|
import matplotlib
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
matplotlib.style.use('ggplot')
|
||||||
|
|
||||||
|
def save_model(epochs, model, optimizer, criterion):
|
||||||
|
"""
|
||||||
|
Function to save the trained model to disk.
|
||||||
|
"""
|
||||||
|
torch.save({
|
||||||
|
'epoch': epochs,
|
||||||
|
'model_state_dict': model.state_dict(),
|
||||||
|
'optimizer_state_dict': optimizer.state_dict(),
|
||||||
|
'loss': criterion,
|
||||||
|
}, 'outputs/model.pth')
|
||||||
|
|
||||||
|
def save_plots(train_acc, valid_acc, train_loss, valid_loss):
|
||||||
|
"""
|
||||||
|
Function to save the loss and accuracy plots to disk.
|
||||||
|
"""
|
||||||
|
# accuracy plots
|
||||||
|
plt.figure(figsize=(10, 7))
|
||||||
|
plt.plot(
|
||||||
|
train_acc, color='green', linestyle='-',
|
||||||
|
label='train accuracy'
|
||||||
|
)
|
||||||
|
plt.plot(
|
||||||
|
valid_acc, color='blue', linestyle='-',
|
||||||
|
label='validataion accuracy'
|
||||||
|
)
|
||||||
|
plt.xlabel('Epochs')
|
||||||
|
plt.ylabel('Accuracy')
|
||||||
|
plt.legend()
|
||||||
|
plt.savefig('outputs/accuracy.png')
|
||||||
|
|
||||||
|
# loss plots
|
||||||
|
plt.figure(figsize=(10, 7))
|
||||||
|
plt.plot(
|
||||||
|
train_loss, color='orange', linestyle='-',
|
||||||
|
label='train loss'
|
||||||
|
)
|
||||||
|
plt.plot(
|
||||||
|
valid_loss, color='red', linestyle='-',
|
||||||
|
label='validataion loss'
|
||||||
|
)
|
||||||
|
plt.xlabel('Epochs')
|
||||||
|
plt.ylabel('Loss')
|
||||||
|
plt.legend()
|
||||||
|
plt.savefig('outputs/loss.png')
|
BIN
resources/redBush.png
Normal file
BIN
resources/redBush.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.3 KiB |
16
tractor.py
16
tractor.py
@ -31,22 +31,6 @@ class Tractor:
|
|||||||
|
|
||||||
|
|
||||||
def move(self, direction, cell_size, cell_number):
|
def move(self, direction, cell_size, cell_number):
|
||||||
# if direction == 'up':
|
|
||||||
# if self.y != 0:
|
|
||||||
# self.y -= cell_size
|
|
||||||
# self.image = self.up
|
|
||||||
# if direction == 'down':
|
|
||||||
# if self.y != (cell_number-1)*cell_size:
|
|
||||||
# self.y += cell_size
|
|
||||||
# self.image = self.down
|
|
||||||
# if direction == 'left':
|
|
||||||
# if self.x != 0:
|
|
||||||
# self.x -= cell_size
|
|
||||||
# self.image = self.left
|
|
||||||
# if direction == 'right':
|
|
||||||
# if self.x != (cell_number-1)*cell_size:
|
|
||||||
# self.x += cell_size
|
|
||||||
# self.image = self.right
|
|
||||||
if direction == 'move':
|
if direction == 'move':
|
||||||
if self.angle == 0 and self.y != 0:
|
if self.angle == 0 and self.y != 0:
|
||||||
self.y -= cell_size
|
self.y -= cell_size
|
||||||
|
Loading…
Reference in New Issue
Block a user