WMICraft/algorithms/neural_network/neural_network_interface.py

126 lines
4.2 KiB
Python
Raw Normal View History

import torch
2022-05-25 19:47:08 +02:00
import common.helpers
from common.constants import DEVICE, BATCH_SIZE, NUM_EPOCHS, LEARNING_RATE, SETUP_PHOTOS, ID_TO_CLASS
from watersandtreegrass import WaterSandTreeGrass
from torch.utils.data import DataLoader
from neural_network import NeuralNetwork
from torchvision.io import read_image, ImageReadMode
import torch.nn as nn
from torch.optim import Adam
2022-05-18 12:18:59 +02:00
import matplotlib.pyplot as plt
2022-05-25 19:47:08 +02:00
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
2022-05-27 01:38:20 +02:00
import torchvision.transforms.functional as F
from PIL import Image
2022-05-31 09:25:36 +02:00
2022-05-25 19:47:08 +02:00
def check_accuracy_tiles(trials=100):
    """Print how often ``what_is_it`` labels each sample texture correctly.

    Each texture in ``../../resources/textures`` listed below is classified
    ``trials`` times. The share of correct answers is printed as a whole
    percentage, rounded down. With the default of 100 trials, the printed
    number equals the count of correct answers.

    :param trials: how many times each texture is classified; must be >= 1.
    :raises ValueError: if ``trials`` is smaller than 1 (the percentage
        would otherwise divide by zero).
    """
    if trials < 1:
        raise ValueError(f"trials must be >= 1, got {trials}")
    # (texture file name, class the classifier is expected to report)
    expectations = (
        ('grass_with_tree.jpg', 'tree'),
        ('grass2.png', 'grass'),
        ('grass3.png', 'grass'),
        ('grass4.png', 'grass'),
        ('grass1.png', 'grass'),
        ('water.png', 'water'),
        ('sand.png', 'sand'),
    )
    for file_name, expected in expectations:
        path = f'../../resources/textures/{file_name}'
        # NOTE(review): repeating the classification only yields varying
        # results if SETUP_PHOTOS applies random transforms — confirm;
        # otherwise a single trial per texture would suffice.
        hits = sum(what_is_it(path) == expected for _ in range(trials))
        print(f"Accuracy(%) {file_name}", hits * 100 // trials)
2022-05-18 12:18:59 +02:00
# Checkpoint used by what_is_it unless the caller passes another one.
# NOTE(review): check_accuracy evaluates version_23 while this points at
# version_20 — confirm which checkpoint is intended here.
_TILE_CHECKPOINT = './lightning_logs/version_20/checkpoints/epoch=3-step=324.ckpt'

# Models already restored from disk, keyed by checkpoint path. The cache lets
# repeated calls (check_accuracy_tiles makes hundreds) skip re-reading the
# checkpoint. Entries stay alive for the lifetime of the module.
_LOADED_TILE_MODELS = {}


def _load_tile_model(checkpoint_path):
    """Return the NeuralNetwork stored at ``checkpoint_path``, loading it at most once."""
    model = _LOADED_TILE_MODELS.get(checkpoint_path)
    if model is None:
        model = NeuralNetwork.load_from_checkpoint(checkpoint_path)
        _LOADED_TILE_MODELS[checkpoint_path] = model
    return model


def what_is_it(img_path, show_img=False, checkpoint_path=_TILE_CHECKPOINT):
    """Classify a single image file and return its class name.

    The image is converted to RGB and preprocessed with ``SETUP_PHOTOS``. It
    is then given a leading batch dimension and run through the model restored
    from ``checkpoint_path``. The argmax over dim 1 of the output is mapped
    through ``ID_TO_CLASS``.

    :param img_path: path of the image to classify.
    :param show_img: when true, display the image with matplotlib first.
    :param checkpoint_path: Lightning checkpoint to load the model from.
    :returns: the ``ID_TO_CLASS`` entry for the predicted index.
    """
    # Context manager so the underlying file handle is closed promptly.
    with Image.open(img_path) as opened:
        image = opened.convert('RGB')
    if show_img:
        plt.imshow(image)
        plt.show()
    batch = SETUP_PHOTOS(image).unsqueeze(0)
    model = _load_tile_model(checkpoint_path)
    # Re-assert eval mode on every call, since the cached instance is shared.
    model.eval()
    with torch.no_grad():
        idx = int(model(batch).argmax(dim=1))
    return ID_TO_CLASS[idx]
2022-05-31 09:25:36 +02:00
# Checkpoint evaluated by check_accuracy unless the caller passes another one.
_EVAL_CHECKPOINT = './lightning_logs/version_23/checkpoints/epoch=3-step=324.ckpt'


def check_accuracy(tset, checkpoint_path=_EVAL_CHECKPOINT):
    """Print and return the classification accuracy of a checkpoint over ``tset``.

    :param tset: iterable of ``(photo, label)`` batches, e.g. a ``DataLoader``.
        ``label`` presumably holds class indices comparable with the argmax
        over the model output — TODO confirm against the dataset.
    :param checkpoint_path: Lightning checkpoint to evaluate.
    :returns: accuracy in percent as a float, or ``None`` when ``tset`` yields
        no samples (a message is printed instead of dividing by zero).
    """
    model = NeuralNetwork.load_from_checkpoint(checkpoint_path).to(DEVICE)
    model.eval()
    num_correct = 0
    num_samples = 0
    with torch.no_grad():
        for photo, label in tset:
            predictions = model(photo.to(DEVICE)).argmax(dim=1)
            # Accumulate as a plain int instead of a device tensor.
            num_correct += int((predictions == label.to(DEVICE)).sum())
            num_samples += predictions.size(0)
    if num_samples == 0:
        print('Got 0 / 0 samples; accuracy is undefined')
        return None
    accuracy = num_correct / num_samples * 100
    print(f'Got {num_correct} / {num_samples} with accuracy {accuracy:.2f}%')
    return accuracy
def check_accuracy_data():
    """Report model accuracy on the training split, then on the test split.

    Both splits are read from the CSV files under ``./data`` through
    ``WaterSandTreeGrass`` with the ``SETUP_PHOTOS`` transform. They are
    batched with ``BATCH_SIZE`` (the training loader shuffles) and each one
    is passed to ``check_accuracy`` after a header line is printed.
    """
    train_data, test_data = (
        WaterSandTreeGrass(csv_path, transform=SETUP_PHOTOS)
        for csv_path in ('./data/train_csv_file.csv', './data/test_csv_file.csv')
    )
    # Both loaders are built before anything is printed, in the same order
    # as the datasets above.
    splits = (
        ("Accuracy of train_set:", DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)),
        ("Accuracy of test_set:", DataLoader(test_data, batch_size=BATCH_SIZE)),
    )
    for header, loader in splits:
        print(header)
        check_accuracy(loader)
2022-05-27 01:38:20 +02:00
#CNN = NeuralNetwork()
#common.helpers.createCSV()
2022-05-25 19:47:08 +02:00
2022-05-27 01:38:20 +02:00
#trainer = pl.Trainer(accelerator='gpu', callbacks=EarlyStopping('val_loss'), devices=1, max_epochs=NUM_EPOCHS)
#trainer = pl.Trainer(accelerator='gpu', devices=1, auto_lr_find=True, max_epochs=NUM_EPOCHS)
2022-05-27 01:38:20 +02:00
#trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=SETUP_PHOTOS)
#testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=SETUP_PHOTOS)
#train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
#test_loader = DataLoader(testset, batch_size=BATCH_SIZE)
#trainer.fit(CNN, train_loader, test_loader)
2022-05-25 19:47:08 +02:00
#trainer.tune(CNN, train_loader, test_loader)
2022-05-31 09:25:36 +02:00
2022-05-27 01:38:20 +02:00
#print(what_is_it('../../resources/textures/grass2.png', True))
2022-05-31 09:25:36 +02:00
#check_accuracy_data()
#check_accuracy_tiles()