forked from s464965/WMICraft
work on what_is_it func;
This commit is contained in:
parent
36f20d8895
commit
6aee7bb207
@ -1,10 +1,9 @@
|
||||
import torch
|
||||
import pytorch_lightning as pl
|
||||
import torch.nn as nn
|
||||
from torch.optim import SGD, Adam, lr_scheduler
|
||||
from torch.optim import Adam
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
from common.constants import DEVICE, BATCH_SIZE, NUM_EPOCHS, LEARNING_RATE, SETUP_PHOTOS, ID_TO_CLASS
|
||||
from common.constants import BATCH_SIZE, LEARNING_RATE
|
||||
|
||||
|
||||
class NeuralNetwork(pl.LightningModule):
|
||||
|
@ -100,7 +100,7 @@ def what_is_it(img_path, show_img=False):
|
||||
plt.imshow(plt.imread(img_path))
|
||||
plt.show()
|
||||
image = SETUP_PHOTOS(image).unsqueeze(0)
|
||||
model = NeuralNetwork.load_from_checkpoint('./lightning_logs/version_0/checkpoints/epoch=4-step=405.ckpt')
|
||||
model = NeuralNetwork.load_from_checkpoint('D:/DEV/UAM/WMICraft/algorithms/neural_network/lightning_logs/version_3/checkpoints/epoch=8-step=810.ckpt')
|
||||
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
@ -108,18 +108,18 @@ def what_is_it(img_path, show_img=False):
|
||||
return ID_TO_CLASS[idx]
|
||||
|
||||
|
||||
CNN = NeuralNetwork()
|
||||
common.helpers.createCSV()
|
||||
# CNN = NeuralNetwork()
|
||||
# common.helpers.createCSV()
|
||||
|
||||
#trainer = pl.Trainer(accelerator='gpu', devices=1, callbacks=[EarlyStopping('val_loss')], max_epochs=NUM_EPOCHS)
|
||||
trainer = pl.Trainer(accelerator='gpu', devices=1, auto_lr_find=True, max_epochs=NUM_EPOCHS)
|
||||
|
||||
trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=SETUP_PHOTOS)
|
||||
testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=SETUP_PHOTOS)
|
||||
train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
test_loader = DataLoader(testset, batch_size=BATCH_SIZE)
|
||||
|
||||
trainer.fit(CNN, train_loader, test_loader)
|
||||
# trainer = pl.Trainer(accelerator='cpu', devices=1, auto_lr_find=True, max_epochs=NUM_EPOCHS)
|
||||
#
|
||||
# trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=SETUP_PHOTOS)
|
||||
# testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=SETUP_PHOTOS)
|
||||
# train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
# test_loader = DataLoader(testset, batch_size=BATCH_SIZE)
|
||||
#
|
||||
# trainer.fit(CNN, train_loader, test_loader)
|
||||
#trainer.tune(CNN, train_loader, test_loader)
|
||||
#check_accuracy_tiles()
|
||||
#print(what_is_it('../../resources/textures/sand.png', True))
|
||||
|
@ -62,7 +62,7 @@ class Game:
|
||||
print_numbers_flag = not print_numbers_flag
|
||||
|
||||
# if event.type == NEXT_TURN: # is called every 'TURN_INTERVAL' milliseconds
|
||||
# level.handle_turn()
|
||||
# level.handle_turn()
|
||||
|
||||
stats.update()
|
||||
logs.draw()
|
||||
|
@ -33,7 +33,7 @@ class Level:
|
||||
|
||||
def create_map(self):
|
||||
print("Create map")
|
||||
print(what_is_it('../../resources/textures/grass1.png'))
|
||||
print(what_is_it('D:/DEV/UAM/WMICraft/resources/textures/t2.jpg'))
|
||||
# self.generate_map()
|
||||
# self.setup_base_tiles()
|
||||
# self.setup_objects()
|
||||
|
Loading…
Reference in New Issue
Block a user