Compare commits

...

2 Commits

Author SHA1 Message Date
Angelika Iskra
6aee7bb207 work on what_is_it func; 2022-05-27 00:15:32 +02:00
Angelika Iskra
36f20d8895 work on what_is_it func; 2022-05-26 23:34:46 +02:00
6 changed files with 25 additions and 24 deletions

View File

@ -1,11 +1,9 @@
import torch import torch
import pytorch_lightning as pl import pytorch_lightning as pl
import torch.nn as nn import torch.nn as nn
from torch.optim import SGD, Adam, lr_scheduler from torch.optim import Adam
import torch.nn.functional as F import torch.nn.functional as F
from torch.utils.data import DataLoader from common.constants import BATCH_SIZE, LEARNING_RATE
from watersandtreegrass import WaterSandTreeGrass
from common.constants import DEVICE, BATCH_SIZE, NUM_EPOCHS, LEARNING_RATE, SETUP_PHOTOS, ID_TO_CLASS
class NeuralNetwork(pl.LightningModule): class NeuralNetwork(pl.LightningModule):

View File

@ -1,9 +1,9 @@
import torch import torch
import common.helpers import common.helpers
from algorithms.neural_network.neural_network import NeuralNetwork
from algorithms.neural_network.watersandtreegrass import WaterSandTreeGrass
from common.constants import DEVICE, BATCH_SIZE, NUM_EPOCHS, LEARNING_RATE, SETUP_PHOTOS, ID_TO_CLASS from common.constants import DEVICE, BATCH_SIZE, NUM_EPOCHS, LEARNING_RATE, SETUP_PHOTOS, ID_TO_CLASS
from watersandtreegrass import WaterSandTreeGrass
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from neural_network import NeuralNetwork
from torchvision.io import read_image, ImageReadMode from torchvision.io import read_image, ImageReadMode
import torch.nn as nn import torch.nn as nn
from torch.optim import Adam from torch.optim import Adam
@ -100,7 +100,7 @@ def what_is_it(img_path, show_img=False):
plt.imshow(plt.imread(img_path)) plt.imshow(plt.imread(img_path))
plt.show() plt.show()
image = SETUP_PHOTOS(image).unsqueeze(0) image = SETUP_PHOTOS(image).unsqueeze(0)
model = NeuralNetwork.load_from_checkpoint('./lightning_logs/version_13/checkpoints/epoch=4-step=405.ckpt') model = NeuralNetwork.load_from_checkpoint('D:/DEV/UAM/WMICraft/algorithms/neural_network/lightning_logs/version_3/checkpoints/epoch=8-step=810.ckpt')
with torch.no_grad(): with torch.no_grad():
model.eval() model.eval()
@ -108,18 +108,18 @@ def what_is_it(img_path, show_img=False):
return ID_TO_CLASS[idx] return ID_TO_CLASS[idx]
CNN = NeuralNetwork() # CNN = NeuralNetwork()
common.helpers.createCSV() # common.helpers.createCSV()
#trainer = pl.Trainer(accelerator='gpu', devices=1, callbacks=[EarlyStopping('val_loss')], max_epochs=NUM_EPOCHS) #trainer = pl.Trainer(accelerator='gpu', devices=1, callbacks=[EarlyStopping('val_loss')], max_epochs=NUM_EPOCHS)
trainer = pl.Trainer(accelerator='gpu', devices=1, auto_lr_find=True, max_epochs=NUM_EPOCHS) # trainer = pl.Trainer(accelerator='cpu', devices=1, auto_lr_find=True, max_epochs=NUM_EPOCHS)
#
trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=SETUP_PHOTOS) # trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=SETUP_PHOTOS)
testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=SETUP_PHOTOS) # testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=SETUP_PHOTOS)
train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True) # train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(testset, batch_size=BATCH_SIZE) # test_loader = DataLoader(testset, batch_size=BATCH_SIZE)
#
trainer.fit(CNN, train_loader, test_loader) # trainer.fit(CNN, train_loader, test_loader)
#trainer.tune(CNN, train_loader, test_loader) #trainer.tune(CNN, train_loader, test_loader)
#check_accuracy_tiles() #check_accuracy_tiles()
#print(what_is_it('../../resources/textures/sand.png', True)) #print(what_is_it('../../resources/textures/sand.png', True))

View File

@ -6,7 +6,7 @@ GAME_TITLE = 'WMICraft'
WINDOW_HEIGHT = 800 WINDOW_HEIGHT = 800
WINDOW_WIDTH = 1360 WINDOW_WIDTH = 1360
FPS_COUNT = 60 FPS_COUNT = 60
TURN_INTERVAL = 300 TURN_INTERVAL = 500
GRID_CELL_PADDING = 5 GRID_CELL_PADDING = 5
GRID_CELL_SIZE = 36 GRID_CELL_SIZE = 36

View File

@ -61,8 +61,8 @@ class Game:
if event.key == pygame.K_n: if event.key == pygame.K_n:
print_numbers_flag = not print_numbers_flag print_numbers_flag = not print_numbers_flag
if event.type == NEXT_TURN: # is called every 'TURN_INTERVAL' milliseconds # if event.type == NEXT_TURN: # is called every 'TURN_INTERVAL' milliseconds
level.handle_turn() # level.handle_turn()
stats.update() stats.update()
logs.draw() logs.draw()

View File

@ -3,6 +3,7 @@ import random
import pygame import pygame
from algorithms.a_star import a_star, State, TURN_RIGHT, TURN_LEFT, FORWARD from algorithms.a_star import a_star, State, TURN_RIGHT, TURN_LEFT, FORWARD
from algorithms.neural_network.neural_network_interface import what_is_it
from common.constants import * from common.constants import *
from learning.decision_tree import DecisionTree from learning.decision_tree import DecisionTree
from logic.knights_queue import KnightsQueue from logic.knights_queue import KnightsQueue
@ -31,10 +32,12 @@ class Level:
self.knights_queue = None self.knights_queue = None
def create_map(self): def create_map(self):
self.generate_map() print("Create map")
self.setup_base_tiles() print(what_is_it('D:/DEV/UAM/WMICraft/resources/textures/t2.jpg'))
self.setup_objects() # self.generate_map()
self.knights_queue = KnightsQueue(self.list_knights_blue, self.list_knights_red) # self.setup_base_tiles()
# self.setup_objects()
# self.knights_queue = KnightsQueue(self.list_knights_blue, self.list_knights_red)
def generate_map(self): def generate_map(self):
spawner = Spawner(self.map) spawner = Spawner(self.map)

Binary file not shown.