cnn #28

Merged
s464869 merged 10 commits from cnn into master 2022-05-25 19:57:09 +02:00
31 changed files with 284 additions and 21 deletions

View File

@ -51,7 +51,7 @@ def graphsearch(initial_state: State, map, goal_list, fringe: List[Node] = None,
explored_states = set() explored_states = set()
fringe_states = set() fringe_states = set()
# root Node # train Node
fringe.append(Node(initial_state)) fringe.append(Node(initial_state))
fringe_states.add((initial_state.row, initial_state.column, initial_state.direction)) fringe_states.add((initial_state.row, initial_state.column, initial_state.direction))
@ -71,7 +71,7 @@ def graphsearch(initial_state: State, map, goal_list, fringe: List[Node] = None,
parent = element.parent parent = element.parent
while parent is not None: while parent is not None:
# root's action will be None, don't add it # train's action will be None, don't add it
if parent.action is not None: if parent.action is not None:
actions_sequence.append(parent.action) actions_sequence.append(parent.action)
parent = parent.parent parent = parent.parent

Binary file not shown.

After

Width:  |  Height:  |  Size: 814 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 820 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 789 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 760 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 725 B

View File

@ -0,0 +1 @@
{}

View File

@ -0,0 +1 @@
{}

View File

@ -0,0 +1 @@
{}

View File

@ -0,0 +1 @@
{}

View File

@ -0,0 +1,48 @@
import torch
import pytorch_lightning as pl
import torch.nn as nn
from torch.optim import SGD, Adam, lr_scheduler
import torch.nn.functional as F
from torch.utils.data import DataLoader
from watersandtreegrass import WaterSandTreeGrass
from common.constants import DEVICE, BATCH_SIZE, NUM_EPOCHS, LEARNING_RATE, SETUP_PHOTOS, ID_TO_CLASS
class NeuralNetwork(pl.LightningModule):
def __init__(self, numChannels=3, batch_size=BATCH_SIZE, learning_rate=LEARNING_RATE, num_classes=4):
super().__init__()
self.layer = nn.Sequential(
nn.Linear(36*36*3, 300),
nn.ReLU(),
nn.Linear(300, 4),
nn.LogSoftmax(dim=-1)
)
self.batch_size = batch_size
self.learning_rate = learning_rate
def forward(self, x):
x = x.reshape(x.shape[0], -1)
x = self.layer(x)
return x
def configure_optimizers(self):
optimizer = SGD(self.parameters(), lr=self.learning_rate)
return optimizer
def training_step(self, batch, batch_idx):
x, y = batch
scores = self(x)
loss = F.nll_loss(scores, y)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
scores = self(x)
val_loss = F.nll_loss(scores, y)
self.log("val_loss", val_loss, on_step=True, on_epoch=True, sync_dist=True)
def test_step(self, batch, batch_idx):
x, y = batch
scores = self(x)
test_loss = F.nll_loss(scores, y)
self.log("test_loss", test_loss, on_step=True, on_epoch=True, sync_dist=True)

View File

@ -0,0 +1,125 @@
import torch
import common.helpers
from common.constants import DEVICE, BATCH_SIZE, NUM_EPOCHS, LEARNING_RATE, SETUP_PHOTOS, ID_TO_CLASS
from watersandtreegrass import WaterSandTreeGrass
from torch.utils.data import DataLoader
from neural_network import NeuralNetwork
from torchvision.io import read_image, ImageReadMode
import torch.nn as nn
from torch.optim import Adam
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
def train(model):
model = model.to(DEVICE)
model.train()
trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=SETUP_PHOTOS)
testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=SETUP_PHOTOS)
train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=True)
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
for epoch in range(NUM_EPOCHS):
for batch_idx, (data, targets) in enumerate(train_loader):
data = data.to(device=DEVICE)
targets = targets.to(device=DEVICE)
scores = model(data)
loss = criterion(scores, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if batch_idx % 4 == 0:
print("epoch: %d loss: %.4f" % (epoch, loss.item()))
print("FINISHED TRAINING!")
torch.save(model.state_dict(), "./learnednetwork.pth")
print("Checking accuracy for the train set.")
check_accuracy(train_loader)
print("Checking accuracy for the test set.")
check_accuracy(test_loader)
print("Checking accuracy for the tiles.")
check_accuracy_tiles()
def check_accuracy_tiles():
answer = 0
for i in range(100):
if what_is_it('../../resources/textures/grass_with_tree.jpg') == 'tree':
answer = answer + 1
print("Accuracy(%) grass_with_tree.jpg", answer)
answer = 0
for i in range(100):
if what_is_it('../../resources/textures/grass2.png') == 'grass':
answer = answer + 1
print("Accuracy(%) grass2.png", answer)
answer = 0
for i in range(100):
if what_is_it('../../resources/textures/grass3.png') == 'grass':
answer = answer + 1
print("Accuracy(%) grass3.png", answer)
answer = 0
for i in range(100):
if what_is_it('../../resources/textures/grass4.png') == 'grass':
answer = answer + 1
print("Accuracy(%) grass4.png", answer)
answer = 0
for i in range(100):
if what_is_it('../../resources/textures/grass1.png') == 'grass':
answer = answer + 1
print("Accuracy(%) grass1.png", answer)
answer = 0
for i in range(100):
if what_is_it('../../resources/textures/water.png') == 'water':
answer = answer + 1
print("Accuracy(%) water.png", answer)
answer = 0
for i in range(100):
if what_is_it('../../resources/textures/sand.png') == 'sand':
answer = answer + 1
print("Accuracy(%) sand.png", answer)
def what_is_it(img_path, show_img=False):
image = read_image(img_path, mode=ImageReadMode.RGB)
if show_img:
plt.imshow(plt.imread(img_path))
plt.show()
image = SETUP_PHOTOS(image).unsqueeze(0)
model = NeuralNetwork.load_from_checkpoint('./lightning_logs/version_3/checkpoints/epoch=8-step=810.ckpt')
with torch.no_grad():
model.eval()
idx = int(model(image).argmax(dim=1))
return ID_TO_CLASS[idx]
CNN = NeuralNetwork()
trainer = pl.Trainer(accelerator='gpu', devices=1, auto_scale_batch_size=True, callbacks=[EarlyStopping('val_loss')], max_epochs=NUM_EPOCHS)
#trainer = pl.Trainer(accelerator='gpu', devices=1, auto_lr_find=True, max_epochs=NUM_EPOCHS)
trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=SETUP_PHOTOS)
testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=SETUP_PHOTOS)
train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(testset, batch_size=BATCH_SIZE)
#trainer.fit(CNN, train_loader, test_loader)
#trainer.tune(CNN, train_loader, test_loader)
check_accuracy_tiles()
print(what_is_it('../../resources/textures/sand.png', True))

View File

@ -0,0 +1,25 @@
import torch
from torch.utils.data import Dataset
import pandas as pd
from torchvision.io import read_image, ImageReadMode
from common.helpers import createCSV
class WaterSandTreeGrass(Dataset):
def __init__(self, annotations_file, transform=None):
createCSV()
self.img_labels = pd.read_csv(annotations_file)
self.transform = transform
def __len__(self):
return len(self.img_labels)
def __getitem__(self, idx):
image = read_image(self.img_labels.iloc[idx, 0], mode=ImageReadMode.RGB)
label = torch.tensor(int(self.img_labels.iloc[idx, 1]))
if self.transform:
image = self.transform(image)
return image, label

View File

@ -1,4 +1,6 @@
from enum import Enum from enum import Enum
import torchvision.transforms as transforms
import torch
GAME_TITLE = 'WMICraft' GAME_TITLE = 'WMICraft'
WINDOW_HEIGHT = 800 WINDOW_HEIGHT = 800
@ -63,12 +65,34 @@ ACTION = {
"go": 0, "go": 0,
} }
LEFT = 'LEFT'
RIGHT = 'RIGHT'
UP = 'UP'
DOWN = 'DOWN'
# HEALTH_BAR # HEALTH_BAR
BAR_ANIMATION_SPEED = 1 BAR_ANIMATION_SPEED = 1
BAR_WIDTH_MULTIPLIER = 0.9 # (0;1> BAR_WIDTH_MULTIPLIER = 0.9 # (0;1>
BAR_HEIGHT_MULTIPLIER = 0.1 BAR_HEIGHT_MULTIPLIER = 0.1
LEFT = 'LEFT'
RIGHT = 'RIGHT' #NEURAL_NETWORK
UP = 'UP' LEARNING_RATE = 0.13182567385564073
DOWN = 'DOWN' BATCH_SIZE = 64
NUM_EPOCHS = 50
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Using ", DEVICE)
CLASSES = ['grass', 'sand', 'tree', 'water']
SETUP_PHOTOS = transforms.Compose([
transforms.Resize(36),
transforms.CenterCrop(36),
transforms.ToPILImage(),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
ID_TO_CLASS = {i: j for i, j in enumerate(CLASSES)}
CLASS_TO_ID = {value: key for key, value in ID_TO_CLASS.items()}

View File

@ -1,6 +1,9 @@
from typing import Tuple, List from typing import Tuple, List
import pygame import pygame
from common.constants import GRID_CELL_PADDING, GRID_CELL_SIZE, COLUMNS, ROWS, CLASSES, CLASS_TO_ID
import csv
import os
from common.constants import GRID_CELL_PADDING, GRID_CELL_SIZE from common.constants import GRID_CELL_PADDING, GRID_CELL_SIZE
from common.constants import ROWS, COLUMNS, LEFT, RIGHT, UP, DOWN from common.constants import ROWS, COLUMNS, LEFT, RIGHT, UP, DOWN
@ -24,6 +27,44 @@ def draw_text(text, color, surface, x, y, text_size=30, is_bold=False):
surface.blit(textobj, textrect) surface.blit(textobj, textrect)
def createCSV():
train_data_path = './data/train'
test_data_path = './data/test'
if os.path.exists(train_data_path):
train_csvfile = open('./data/train_csv_file.csv', 'w', newline="")
writer = csv.writer(train_csvfile)
writer.writerow(["filepath", "type"])
for class_name in CLASSES:
class_dir = train_data_path + "/" + class_name
for filename in os.listdir(class_dir):
f = os.path.join(class_dir, filename)
if os.path.isfile(f):
writer.writerow([f, CLASS_TO_ID[class_name]])
train_csvfile.close()
else:
print("Brak plików do uczenia")
if os.path.exists(test_data_path):
test_csvfile = open('./data/test_csv_file.csv', 'w', newline="")
writer = csv.writer(test_csvfile)
writer.writerow(["filepath", "type"])
for class_name in CLASSES:
class_dir = test_data_path + "/" + class_name
for filename in os.listdir(class_dir):
f = os.path.join(class_dir, filename)
if os.path.isfile(f):
writer.writerow([f, CLASS_TO_ID[class_name]])
test_csvfile.close()
else:
print("Brak plików do testowania")
def print_numbers(): def print_numbers():
display_surface = pygame.display.get_surface() display_surface = pygame.display.get_surface()
font = pygame.font.SysFont('Arial', 16) font = pygame.font.SysFont('Arial', 16)

View File

@ -46,7 +46,7 @@ class HealthBar:
def heal(self, amount): def heal(self, amount):
if self.current_hp + amount < self.max_hp: if self.current_hp + amount < self.max_hp:
self.current_hp += amount self.current_hp += amount
elif self.current_hp + amount > self.max_hp: elif self.current_hp + amount >= self.max_hp:
self.current_hp = self.max_hp self.current_hp = self.max_hp
def show(self): def show(self):

View File

@ -155,19 +155,6 @@ class Level:
self.logs.enqueue_log(f'AI {current_knight.team}: Ruch w lewo.') self.logs.enqueue_log(f'AI {current_knight.team}: Ruch w lewo.')
self.map[knight_pos_y][knight_pos_x - 1] = current_knight.team_alias() self.map[knight_pos_y][knight_pos_x - 1] = current_knight.team_alias()
def update_health_bars(self):
for knight in self.list_knights_blue:
knight.health_bar.update()
for knight in self.list_knights_red:
knight.health_bar.update()
for monster in self.list_monsters:
monster.health_bar.update()
for castle in self.list_castles:
castle.health_bar.update()
def update(self): def update(self):
bg_width = (GRID_CELL_PADDING + GRID_CELL_SIZE) * COLUMNS + BORDER_WIDTH bg_width = (GRID_CELL_PADDING + GRID_CELL_SIZE) * COLUMNS + BORDER_WIDTH
bg_height = (GRID_CELL_PADDING + GRID_CELL_SIZE) * ROWS + BORDER_WIDTH bg_height = (GRID_CELL_PADDING + GRID_CELL_SIZE) * ROWS + BORDER_WIDTH
@ -175,4 +162,4 @@ class Level:
# update and draw the game # update and draw the game
self.sprites.draw(self.screen) self.sprites.draw(self.screen)
self.update_health_bars() # has to be called last self.sprites.update()

View File

@ -18,3 +18,6 @@ class Castle(pygame.sprite.Sprite):
self.max_hp = 80 self.max_hp = 80
self.current_hp = random.randint(1, self.max_hp) self.current_hp = random.randint(1, self.max_hp)
self.health_bar = HealthBar(screen, self.rect, current_hp=self.current_hp, max_hp=self.max_hp, calculate_xy=True, calculate_size=True) self.health_bar = HealthBar(screen, self.rect, current_hp=self.current_hp, max_hp=self.max_hp, calculate_xy=True, calculate_size=True)
def update(self):
self.health_bar.update()

View File

@ -43,6 +43,9 @@ class Knight(pygame.sprite.Sprite):
self.direction = self.direction.left() self.direction = self.direction.left()
self.image = self.states[self.direction.value] self.image = self.states[self.direction.value]
def update(self):
self.health_bar.update()
def rotate_right(self): def rotate_right(self):
self.direction = self.direction.right() self.direction = self.direction.right()
self.image = self.states[self.direction.value] self.image = self.states[self.direction.value]

View File

@ -43,3 +43,6 @@ class Monster(pygame.sprite.Sprite):
self.max_hp = 7 self.max_hp = 7
self.attack = 2 self.attack = 2
self.points = 2 self.points = 2
def update(self):
self.health_bar.update()

Binary file not shown.