# Forked from s464965/WMICraft (126 lines, 4.2 KiB, Python).
import torch
|
|
import common.helpers
|
|
from common.constants import DEVICE, BATCH_SIZE, NUM_EPOCHS, LEARNING_RATE, SETUP_PHOTOS, ID_TO_CLASS
|
|
from watersandtreegrass import WaterSandTreeGrass
|
|
from torch.utils.data import DataLoader
|
|
from neural_network import NeuralNetwork
|
|
from torchvision.io import read_image, ImageReadMode
|
|
import torch.nn as nn
|
|
from torch.optim import Adam
|
|
import matplotlib.pyplot as plt
|
|
import pytorch_lightning as pl
|
|
from pytorch_lightning.callbacks import EarlyStopping
|
|
import torchvision.transforms.functional as F
|
|
from PIL import Image
|
|
|
|
|
|
def check_accuracy_tiles(trials=100, texture_dir='../../resources/textures'):
    """Print, for each reference texture, how often ``what_is_it`` labels it correctly.

    Every texture in the table below is classified ``trials`` times and the
    share of correct predictions is printed as a percentage. With the default
    ``trials=100`` the printed number equals the raw hit count, matching the
    previous output exactly.

    NOTE(review): repeating the classification only makes sense if
    ``SETUP_PHOTOS`` applies random transforms; if it is deterministic, every
    trial of a tile yields the same label and the result is always 0 or 100.

    Args:
        trials: number of classifications per texture (must be positive).
        texture_dir: directory holding the reference texture files.

    Raises:
        ValueError: if ``trials`` is not positive.
    """
    import os

    # (texture file name, class name the model is expected to predict)
    tiles = (
        ('grass_with_tree.jpg', 'tree'),
        ('grass2.png', 'grass'),
        ('grass3.png', 'grass'),
        ('grass4.png', 'grass'),
        ('grass1.png', 'grass'),
        ('water.png', 'water'),
        ('sand.png', 'sand'),
    )
    if trials <= 0:
        raise ValueError(f"trials must be positive, got {trials}")

    for file_name, expected in tiles:
        path = os.path.join(texture_dir, file_name)
        hits = sum(what_is_it(path) == expected for _ in range(trials))
        percent = 100 * hits / trials
        # ':g' drops a trailing '.0', so 100 trials print e.g. "57" as before.
        print(f"Accuracy(%) {file_name}", f"{percent:g}")
|
|
|
|
|
|
# Models already loaded by ``_load_model``, keyed by checkpoint path.
_MODEL_CACHE = {}


def _load_model(checkpoint_path):
    """Return the ``NeuralNetwork`` stored at *checkpoint_path*, reading it from disk at most once."""
    model = _MODEL_CACHE.get(checkpoint_path)
    if model is None:
        model = NeuralNetwork.load_from_checkpoint(checkpoint_path)
        _MODEL_CACHE[checkpoint_path] = model
    return model


def what_is_it(img_path, show_img=False,
               checkpoint_path='./lightning_logs/version_20/checkpoints/epoch=3-step=324.ckpt'):
    """Classify a single image file with a saved ``NeuralNetwork`` checkpoint.

    The checkpoint is loaded on first use and reused afterwards, so calling
    this in a loop no longer re-reads the checkpoint file every time.

    Args:
        img_path: path of the image to classify; it is converted to RGB.
        show_img: when true, display the image with matplotlib first.
        checkpoint_path: Lightning checkpoint to classify with.

    Returns:
        The entry of ``ID_TO_CLASS`` for the arg-max output index.

    NOTE(review): this defaults to the version_20 checkpoint, while
    ``check_accuracy`` loads version_23. Confirm which one is intended.
    """
    # The context manager closes the file; convert() returns an independent image.
    with Image.open(img_path) as raw:
        image = raw.convert('RGB')

    if show_img:
        plt.imshow(image)
        plt.show()

    model = _load_model(checkpoint_path)
    model.eval()
    # Add a batch dimension and put the input on whatever device the model is on.
    batch = SETUP_PHOTOS(image).unsqueeze(0).to(model.device)

    with torch.no_grad():
        idx = int(model(batch).argmax(dim=1))

    return ID_TO_CLASS[idx]
|
|
|
|
|
|
def check_accuracy(tset, *, model=None,
                   checkpoint_path='./lightning_logs/version_23/checkpoints/epoch=3-step=324.ckpt',
                   device=None):
    """Evaluate a classifier over *tset*, print its accuracy and return it.

    Args:
        tset: iterable of ``(photo, label)`` batches (e.g. a ``DataLoader``).
            ``photo`` is fed to the model; ``label`` is compared against the
            ``argmax(dim=1)`` of the model output.
        model: an already-constructed model to evaluate. When omitted, the
            ``NeuralNetwork`` checkpoint at *checkpoint_path* is loaded.
        checkpoint_path: checkpoint loaded when *model* is not given.
        device: device to evaluate on; defaults to ``DEVICE``.

    Returns:
        Accuracy in percent, or ``nan`` when *tset* yields no samples.
    """
    if model is None:
        model = NeuralNetwork.load_from_checkpoint(checkpoint_path)
    if device is None:
        device = DEVICE
    model = model.to(device)
    model.eval()

    num_correct = 0
    num_samples = 0
    with torch.no_grad():
        for photo, label in tset:
            photo = photo.to(device)
            label = label.to(device)
            predictions = model(photo).argmax(dim=1)
            num_correct += (predictions == label).sum()
            num_samples += predictions.size(0)

    # The sum is a (possibly CUDA) tensor. Convert it once so the report shows
    # "123" rather than "tensor(123, device='cuda:0')".
    num_correct = int(num_correct)

    if num_samples == 0:
        print('Got 0 / 0 samples; accuracy is undefined')
        return float('nan')

    accuracy = num_correct / num_samples * 100
    print(f'Got {num_correct} / {num_samples} with accuracy {accuracy:.2f}%')
    return accuracy
|
|
|
|
|
|
def check_accuracy_data(train_csv='./data/train_csv_file.csv',
                        test_csv='./data/test_csv_file.csv'):
    """Print the accuracy reported by ``check_accuracy`` on the train and test splits.

    Both CSV files are read through ``WaterSandTreeGrass`` with the
    ``SETUP_PHOTOS`` transform and batched with ``BATCH_SIZE``.

    Args:
        train_csv: CSV file describing the training split.
        test_csv: CSV file describing the test split.
    """
    # Accuracy does not depend on sample order, so neither loader shuffles.
    # Shuffling would only consume the global RNG.
    loaders = {
        'train_set': DataLoader(WaterSandTreeGrass(train_csv, transform=SETUP_PHOTOS),
                                batch_size=BATCH_SIZE),
        'test_set': DataLoader(WaterSandTreeGrass(test_csv, transform=SETUP_PHOTOS),
                               batch_size=BATCH_SIZE),
    }

    for split_name, loader in loaders.items():
        print(f"Accuracy of {split_name}:")
        check_accuracy(loader)
|
|
|
|
#CNN = NeuralNetwork()
|
|
#common.helpers.createCSV()
|
|
|
|
#trainer = pl.Trainer(accelerator='gpu', callbacks=EarlyStopping('val_loss'), devices=1, max_epochs=NUM_EPOCHS)
|
|
#trainer = pl.Trainer(accelerator='gpu', devices=1, auto_lr_find=True, max_epochs=NUM_EPOCHS)
|
|
|
|
#trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=SETUP_PHOTOS)
|
|
#testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=SETUP_PHOTOS)
|
|
#train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
|
|
#test_loader = DataLoader(testset, batch_size=BATCH_SIZE)
|
|
#trainer.fit(CNN, train_loader, test_loader)
|
|
#trainer.tune(CNN, train_loader, test_loader)
|
|
|
|
|
|
#print(what_is_it('../../resources/textures/grass2.png', True))
|
|
|
|
#check_accuracy_data()
|
|
|
|
#check_accuracy_tiles()
|