forked from s464965/WMICraft
Compare commits
2 Commits
master
...
cnn_modifi
Author | SHA1 | Date | |
---|---|---|---|
|
6aee7bb207 | ||
|
36f20d8895 |
@ -1,11 +1,9 @@
|
|||||||
import torch
|
import torch
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.optim import SGD, Adam, lr_scheduler
|
from torch.optim import Adam
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.utils.data import DataLoader
|
from common.constants import BATCH_SIZE, LEARNING_RATE
|
||||||
from watersandtreegrass import WaterSandTreeGrass
|
|
||||||
from common.constants import DEVICE, BATCH_SIZE, NUM_EPOCHS, LEARNING_RATE, SETUP_PHOTOS, ID_TO_CLASS
|
|
||||||
|
|
||||||
|
|
||||||
class NeuralNetwork(pl.LightningModule):
|
class NeuralNetwork(pl.LightningModule):
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
import torch
|
import torch
|
||||||
import common.helpers
|
import common.helpers
|
||||||
|
from algorithms.neural_network.neural_network import NeuralNetwork
|
||||||
|
from algorithms.neural_network.watersandtreegrass import WaterSandTreeGrass
|
||||||
from common.constants import DEVICE, BATCH_SIZE, NUM_EPOCHS, LEARNING_RATE, SETUP_PHOTOS, ID_TO_CLASS
|
from common.constants import DEVICE, BATCH_SIZE, NUM_EPOCHS, LEARNING_RATE, SETUP_PHOTOS, ID_TO_CLASS
|
||||||
from watersandtreegrass import WaterSandTreeGrass
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from neural_network import NeuralNetwork
|
|
||||||
from torchvision.io import read_image, ImageReadMode
|
from torchvision.io import read_image, ImageReadMode
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.optim import Adam
|
from torch.optim import Adam
|
||||||
@ -100,7 +100,7 @@ def what_is_it(img_path, show_img=False):
|
|||||||
plt.imshow(plt.imread(img_path))
|
plt.imshow(plt.imread(img_path))
|
||||||
plt.show()
|
plt.show()
|
||||||
image = SETUP_PHOTOS(image).unsqueeze(0)
|
image = SETUP_PHOTOS(image).unsqueeze(0)
|
||||||
model = NeuralNetwork.load_from_checkpoint('./lightning_logs/version_13/checkpoints/epoch=4-step=405.ckpt')
|
model = NeuralNetwork.load_from_checkpoint('D:/DEV/UAM/WMICraft/algorithms/neural_network/lightning_logs/version_3/checkpoints/epoch=8-step=810.ckpt')
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
model.eval()
|
model.eval()
|
||||||
@ -108,18 +108,18 @@ def what_is_it(img_path, show_img=False):
|
|||||||
return ID_TO_CLASS[idx]
|
return ID_TO_CLASS[idx]
|
||||||
|
|
||||||
|
|
||||||
CNN = NeuralNetwork()
|
# CNN = NeuralNetwork()
|
||||||
common.helpers.createCSV()
|
# common.helpers.createCSV()
|
||||||
|
|
||||||
#trainer = pl.Trainer(accelerator='gpu', devices=1, callbacks=[EarlyStopping('val_loss')], max_epochs=NUM_EPOCHS)
|
#trainer = pl.Trainer(accelerator='gpu', devices=1, callbacks=[EarlyStopping('val_loss')], max_epochs=NUM_EPOCHS)
|
||||||
trainer = pl.Trainer(accelerator='gpu', devices=1, auto_lr_find=True, max_epochs=NUM_EPOCHS)
|
# trainer = pl.Trainer(accelerator='cpu', devices=1, auto_lr_find=True, max_epochs=NUM_EPOCHS)
|
||||||
|
#
|
||||||
trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=SETUP_PHOTOS)
|
# trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=SETUP_PHOTOS)
|
||||||
testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=SETUP_PHOTOS)
|
# testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=SETUP_PHOTOS)
|
||||||
train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
|
# train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
|
||||||
test_loader = DataLoader(testset, batch_size=BATCH_SIZE)
|
# test_loader = DataLoader(testset, batch_size=BATCH_SIZE)
|
||||||
|
#
|
||||||
trainer.fit(CNN, train_loader, test_loader)
|
# trainer.fit(CNN, train_loader, test_loader)
|
||||||
#trainer.tune(CNN, train_loader, test_loader)
|
#trainer.tune(CNN, train_loader, test_loader)
|
||||||
#check_accuracy_tiles()
|
#check_accuracy_tiles()
|
||||||
#print(what_is_it('../../resources/textures/sand.png', True))
|
#print(what_is_it('../../resources/textures/sand.png', True))
|
||||||
|
@ -6,7 +6,7 @@ GAME_TITLE = 'WMICraft'
|
|||||||
WINDOW_HEIGHT = 800
|
WINDOW_HEIGHT = 800
|
||||||
WINDOW_WIDTH = 1360
|
WINDOW_WIDTH = 1360
|
||||||
FPS_COUNT = 60
|
FPS_COUNT = 60
|
||||||
TURN_INTERVAL = 300
|
TURN_INTERVAL = 500
|
||||||
|
|
||||||
GRID_CELL_PADDING = 5
|
GRID_CELL_PADDING = 5
|
||||||
GRID_CELL_SIZE = 36
|
GRID_CELL_SIZE = 36
|
||||||
|
@ -61,8 +61,8 @@ class Game:
|
|||||||
if event.key == pygame.K_n:
|
if event.key == pygame.K_n:
|
||||||
print_numbers_flag = not print_numbers_flag
|
print_numbers_flag = not print_numbers_flag
|
||||||
|
|
||||||
if event.type == NEXT_TURN: # is called every 'TURN_INTERVAL' milliseconds
|
# if event.type == NEXT_TURN: # is called every 'TURN_INTERVAL' milliseconds
|
||||||
level.handle_turn()
|
# level.handle_turn()
|
||||||
|
|
||||||
stats.update()
|
stats.update()
|
||||||
logs.draw()
|
logs.draw()
|
||||||
|
@ -3,6 +3,7 @@ import random
|
|||||||
import pygame
|
import pygame
|
||||||
|
|
||||||
from algorithms.a_star import a_star, State, TURN_RIGHT, TURN_LEFT, FORWARD
|
from algorithms.a_star import a_star, State, TURN_RIGHT, TURN_LEFT, FORWARD
|
||||||
|
from algorithms.neural_network.neural_network_interface import what_is_it
|
||||||
from common.constants import *
|
from common.constants import *
|
||||||
from learning.decision_tree import DecisionTree
|
from learning.decision_tree import DecisionTree
|
||||||
from logic.knights_queue import KnightsQueue
|
from logic.knights_queue import KnightsQueue
|
||||||
@ -31,10 +32,12 @@ class Level:
|
|||||||
self.knights_queue = None
|
self.knights_queue = None
|
||||||
|
|
||||||
def create_map(self):
|
def create_map(self):
|
||||||
self.generate_map()
|
print("Create map")
|
||||||
self.setup_base_tiles()
|
print(what_is_it('D:/DEV/UAM/WMICraft/resources/textures/t2.jpg'))
|
||||||
self.setup_objects()
|
# self.generate_map()
|
||||||
self.knights_queue = KnightsQueue(self.list_knights_blue, self.list_knights_red)
|
# self.setup_base_tiles()
|
||||||
|
# self.setup_objects()
|
||||||
|
# self.knights_queue = KnightsQueue(self.list_knights_blue, self.list_knights_red)
|
||||||
|
|
||||||
def generate_map(self):
|
def generate_map(self):
|
||||||
spawner = Spawner(self.map)
|
spawner = Spawner(self.map)
|
||||||
|
BIN
requirements.txt
BIN
requirements.txt
Binary file not shown.
Loading…
Reference in New Issue
Block a user