# Forked from s464965/WMICraft (Python, 126 lines, 4.4 KiB)
import torch
|
|
import common.helpers
|
|
from algorithms.neural_network.neural_network import NeuralNetwork
|
|
from algorithms.neural_network.watersandtreegrass import WaterSandTreeGrass
|
|
from common.constants import DEVICE, BATCH_SIZE, NUM_EPOCHS, LEARNING_RATE, SETUP_PHOTOS, ID_TO_CLASS
|
|
from torch.utils.data import DataLoader
|
|
from torchvision.io import read_image, ImageReadMode
|
|
import torch.nn as nn
|
|
from torch.optim import Adam
|
|
import matplotlib.pyplot as plt
|
|
import pytorch_lightning as pl
|
|
from pytorch_lightning.callbacks import EarlyStopping
|
|
|
|
|
|
def train(model):
    """Train *model* on the water/sand/tree/grass dataset and report accuracy.

    The model is moved to ``DEVICE`` and optimised with Adam against a
    cross-entropy loss for ``NUM_EPOCHS`` epochs. Its ``state_dict`` is then
    written to ``./learnednetwork.pth`` and accuracy is printed for the
    training set, the test set and the texture tiles.

    Args:
        model: Network to train. It is called on the batches yielded by the
            data loaders and its output is passed to ``nn.CrossEntropyLoss``.
    """
    model = model.to(DEVICE)
    model.train()

    trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=SETUP_PHOTOS)
    testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=SETUP_PHOTOS)
    train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
    # Evaluation order does not influence accuracy, so the test set is not shuffled.
    test_loader = DataLoader(testset, batch_size=BATCH_SIZE)

    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=LEARNING_RATE)

    for epoch in range(NUM_EPOCHS):
        for batch_idx, (data, targets) in enumerate(train_loader):
            data = data.to(device=DEVICE)
            targets = targets.to(device=DEVICE)

            # Forward pass.
            scores = model(data)
            loss = criterion(scores, targets)

            # Backward pass: clear stale gradients before accumulating new ones.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Log every fourth batch to keep the output readable.
            if batch_idx % 4 == 0:
                print("epoch: %d loss: %.4f" % (epoch, loss.item()))

    print("FINISHED TRAINING!")
    torch.save(model.state_dict(), "./learnednetwork.pth")

    print("Checking accuracy for the train set.")
    _check_accuracy(train_loader, model, DEVICE)
    print("Checking accuracy for the test set.")
    _check_accuracy(test_loader, model, DEVICE)
    print("Checking accuracy for the tiles.")
    # NOTE(review): check_accuracy_tiles() goes through what_is_it(), which
    # loads a checkpoint from disk rather than using the `model` trained here.
    check_accuracy_tiles()


def _check_accuracy(loader, model, device):
    """Print and return the fraction of samples in *loader* that *model* gets right.

    Args:
        loader: Iterable yielding ``(data, targets)`` batches of tensors.
        model: Module whose output is reduced with ``argmax(dim=1)`` to obtain
            one predicted class index per sample.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Accuracy in the range ``[0.0, 1.0]``; ``0.0`` when *loader* is empty.
    """
    correct = 0
    total = 0
    was_training = model.training
    model.eval()
    with torch.no_grad():
        for data, targets in loader:
            data = data.to(device=device)
            targets = targets.to(device=device)
            predictions = model(data).argmax(dim=1)
            correct += int((predictions == targets).sum())
            total += targets.size(0)
    # Leave the model in whatever mode the caller had it in.
    model.train(was_training)

    accuracy = correct / total if total else 0.0
    print(f"Got {correct} / {total} correct ({accuracy * 100:.2f}%)")
    return accuracy
|
|
|
|
|
|
def check_accuracy_tiles():
    """Classify each texture tile 100 times and print the number of correct answers.

    Every tile is passed to ``what_is_it`` 100 times; the count of results
    equal to the expected label is printed next to the tile's name.
    """
    # (texture path, expected label, text printed in front of the hit count)
    tile_checks = (
        ('../../resources/textures/grass_with_tree.jpg', 'tree', "Accuracy(%) grass_with_tree.jpg"),
        ('../../resources/textures/grass2.png', 'grass', "Accuracy(%) grass2.png"),
        ('../../resources/textures/grass3.png', 'grass', "Accuracy(%) grass3.png"),
        ('../../resources/textures/grass4.png', 'grass', "Accuracy(%) grass4.png"),
        ('../../resources/textures/grass1.png', 'grass', "Accuracy(%) grass1.png"),
        ('../../resources/textures/water.png', 'water', "Accuracy(%) water.png"),
        ('../../resources/textures/sand.png', 'sand', "Accuracy(%) sand.png"),
    )

    for texture_path, expected_label, report in tile_checks:
        # 100 attempts, so the hit count reads directly as a percentage.
        hits = sum(
            1 for _ in range(100) if what_is_it(texture_path) == expected_label
        )
        print(report, hits)
|
|
|
|
|
|
# Checkpoint used when the caller does not name one.
# NOTE(review): absolute, machine-specific path kept as the default so
# existing callers behave as before; pass `checkpoint_path` to override it.
_DEFAULT_CHECKPOINT = 'D:/DEV/UAM/WMICraft/algorithms/neural_network/lightning_logs/version_3/checkpoints/epoch=8-step=810.ckpt'

# Networks already loaded by _load_model, keyed by checkpoint path.
_MODEL_CACHE = {}


def _load_model(checkpoint_path):
    """Return the network stored at *checkpoint_path*, loading it at most once."""
    if checkpoint_path not in _MODEL_CACHE:
        model = NeuralNetwork.load_from_checkpoint(checkpoint_path)
        model.eval()
        _MODEL_CACHE[checkpoint_path] = model
    return _MODEL_CACHE[checkpoint_path]


def what_is_it(img_path, show_img=False, checkpoint_path=_DEFAULT_CHECKPOINT):
    """Classify the image at *img_path* and return its class name.

    The image is read as RGB, passed through ``SETUP_PHOTOS``, given a
    leading batch dimension and fed to the network loaded from
    *checkpoint_path*. The index of the highest score is looked up in
    ``ID_TO_CLASS``.

    Args:
        img_path: Path of the image file to classify.
        show_img: When true, display the image with matplotlib first.
        checkpoint_path: Checkpoint to load the network from. The network is
            cached per path, so repeated calls do not reload it from disk.

    Returns:
        The ``ID_TO_CLASS`` entry for the predicted class index.
    """
    image = read_image(img_path, mode=ImageReadMode.RGB)
    if show_img:
        plt.imshow(plt.imread(img_path))
        plt.show()
    image = SETUP_PHOTOS(image).unsqueeze(0)

    model = _load_model(checkpoint_path)

    with torch.no_grad():
        idx = int(model(image).argmax(dim=1))
    return ID_TO_CLASS[idx]
|
|
|
|
|
|
# CNN = NeuralNetwork()
|
|
# common.helpers.createCSV()
|
|
|
|
#trainer = pl.Trainer(accelerator='gpu', devices=1, callbacks=[EarlyStopping('val_loss')], max_epochs=NUM_EPOCHS)
|
|
# trainer = pl.Trainer(accelerator='cpu', devices=1, auto_lr_find=True, max_epochs=NUM_EPOCHS)
|
|
#
|
|
# trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=SETUP_PHOTOS)
|
|
# testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=SETUP_PHOTOS)
|
|
# train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
|
|
# test_loader = DataLoader(testset, batch_size=BATCH_SIZE)
|
|
#
|
|
# trainer.fit(CNN, train_loader, test_loader)
|
|
#trainer.tune(CNN, train_loader, test_loader)
|
|
#check_accuracy_tiles()
|
|
#print(what_is_it('../../resources/textures/sand.png', True))
|