forked from s464965/WMICraft
pytorch lighning addition
This commit is contained in:
parent
4e3e68d4c3
commit
088e90ec5b
Binary file not shown.
Binary file not shown.
@ -0,0 +1 @@
|
|||||||
|
{}
|
Binary file not shown.
Binary file not shown.
@ -0,0 +1 @@
|
|||||||
|
{}
|
Binary file not shown.
Binary file not shown.
@ -0,0 +1 @@
|
|||||||
|
{}
|
Binary file not shown.
Binary file not shown.
@ -0,0 +1 @@
|
|||||||
|
{}
|
@ -1,22 +1,48 @@
|
|||||||
import torch
|
import torch
|
||||||
|
import pytorch_lightning as pl
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from torch.optim import SGD, Adam, lr_scheduler
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from watersandtreegrass import WaterSandTreeGrass
|
||||||
|
from common.constants import DEVICE, BATCH_SIZE, NUM_EPOCHS, LEARNING_RATE, SETUP_PHOTOS, ID_TO_CLASS
|
||||||
|
|
||||||
|
|
||||||
class NeuralNetwork(nn.Module):
|
class NeuralNetwork(pl.LightningModule):
|
||||||
def __init__(self, num_classes=4):
|
def __init__(self, numChannels=3, batch_size=BATCH_SIZE, learning_rate=LEARNING_RATE, num_classes=4):
|
||||||
super(NeuralNetwork, self).__init__()
|
super().__init__()
|
||||||
self.conv1 = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
self.layer = nn.Sequential(
|
||||||
self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
|
nn.Linear(36*36*3, 300),
|
||||||
self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
nn.ReLU(),
|
||||||
self.fc1 = nn.Linear(20*9*9, num_classes)
|
nn.Linear(300, 4),
|
||||||
|
nn.LogSoftmax(dim=-1)
|
||||||
|
)
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.learning_rate = learning_rate
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = F.relu(self.conv1(x))
|
|
||||||
x = self.pool(x)
|
|
||||||
x = F.relu(self.conv2(x))
|
|
||||||
x = self.pool(x)
|
|
||||||
x = x.reshape(x.shape[0], -1)
|
x = x.reshape(x.shape[0], -1)
|
||||||
x = self.fc1(x)
|
x = self.layer(x)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
optimizer = SGD(self.parameters(), lr=self.learning_rate)
|
||||||
|
return optimizer
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx):
|
||||||
|
x, y = batch
|
||||||
|
scores = self(x)
|
||||||
|
loss = F.nll_loss(scores, y)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def validation_step(self, batch, batch_idx):
|
||||||
|
x, y = batch
|
||||||
|
scores = self(x)
|
||||||
|
val_loss = F.nll_loss(scores, y)
|
||||||
|
self.log("val_loss", val_loss, on_step=True, on_epoch=True, sync_dist=True)
|
||||||
|
|
||||||
|
def test_step(self, batch, batch_idx):
|
||||||
|
x, y = batch
|
||||||
|
scores = self(x)
|
||||||
|
test_loss = F.nll_loss(scores, y)
|
||||||
|
self.log("test_loss", test_loss, on_step=True, on_epoch=True, sync_dist=True)
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
from common.constants import device, batch_size, num_epochs, learning_rate, setup_photos, id_to_class
|
import common.helpers
|
||||||
|
from common.constants import DEVICE, BATCH_SIZE, NUM_EPOCHS, LEARNING_RATE, SETUP_PHOTOS, ID_TO_CLASS
|
||||||
from watersandtreegrass import WaterSandTreeGrass
|
from watersandtreegrass import WaterSandTreeGrass
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from neural_network import NeuralNetwork
|
from neural_network import NeuralNetwork
|
||||||
@ -7,24 +8,25 @@ from torchvision.io import read_image, ImageReadMode
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.optim import Adam
|
from torch.optim import Adam
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
import pytorch_lightning as pl
|
||||||
CNN = NeuralNetwork().to(device)
|
from pytorch_lightning.callbacks import EarlyStopping
|
||||||
|
|
||||||
|
|
||||||
def train(model):
|
def train(model):
|
||||||
|
model = model.to(DEVICE)
|
||||||
model.train()
|
model.train()
|
||||||
trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=setup_photos)
|
trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=SETUP_PHOTOS)
|
||||||
testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=setup_photos)
|
testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=SETUP_PHOTOS)
|
||||||
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
|
train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
|
||||||
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=True)
|
test_loader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=True)
|
||||||
|
|
||||||
criterion = nn.CrossEntropyLoss()
|
criterion = nn.CrossEntropyLoss()
|
||||||
optimizer = Adam(model.parameters(), lr=learning_rate)
|
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
|
||||||
|
|
||||||
for epoch in range(num_epochs):
|
for epoch in range(NUM_EPOCHS):
|
||||||
for batch_idx, (data, targets) in enumerate(train_loader):
|
for batch_idx, (data, targets) in enumerate(train_loader):
|
||||||
data = data.to(device=device)
|
data = data.to(device=DEVICE)
|
||||||
targets = targets.to(device=device)
|
targets = targets.to(device=DEVICE)
|
||||||
|
|
||||||
scores = model(data)
|
scores = model(data)
|
||||||
loss = criterion(scores, targets)
|
loss = criterion(scores, targets)
|
||||||
@ -34,39 +36,62 @@ def train(model):
|
|||||||
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
if epoch % 2 == 0:
|
if batch_idx % 4 == 0:
|
||||||
print("epoch: %d loss: %.4f" % (epoch, loss.item()))
|
print("epoch: %d loss: %.4f" % (epoch, loss.item()))
|
||||||
|
|
||||||
print("FINISHED TRAINING!")
|
print("FINISHED TRAINING!")
|
||||||
|
torch.save(model.state_dict(), "./learnednetwork.pth")
|
||||||
|
|
||||||
print("Checking accuracy for the train set.")
|
print("Checking accuracy for the train set.")
|
||||||
check_accuracy(train_loader)
|
check_accuracy(train_loader)
|
||||||
print("Checking accuracy for the test set.")
|
print("Checking accuracy for the test set.")
|
||||||
check_accuracy(test_loader)
|
check_accuracy(test_loader)
|
||||||
|
print("Checking accuracy for the tiles.")
|
||||||
torch.save(model.state_dict(), "./learnedNetwork.pt")
|
check_accuracy_tiles()
|
||||||
|
|
||||||
|
|
||||||
def check_accuracy(loader):
|
def check_accuracy_tiles():
|
||||||
num_correct = 0
|
answer = 0
|
||||||
num_samples = 0
|
for i in range(100):
|
||||||
model = NeuralNetwork()
|
if what_is_it('../../resources/textures/grass_with_tree.jpg') == 'tree':
|
||||||
|
answer = answer + 1
|
||||||
|
print("Accuracy(%) grass_with_tree.jpg", answer)
|
||||||
|
|
||||||
model.load_state_dict(torch.load("./learnedNetwork.pt", map_location=device))
|
answer = 0
|
||||||
model = model.to(device)
|
for i in range(100):
|
||||||
|
if what_is_it('../../resources/textures/grass2.png') == 'grass':
|
||||||
|
answer = answer + 1
|
||||||
|
print("Accuracy(%) grass2.png", answer)
|
||||||
|
|
||||||
with torch.no_grad():
|
answer = 0
|
||||||
model.eval()
|
for i in range(100):
|
||||||
for x, y in loader:
|
if what_is_it('../../resources/textures/grass3.png') == 'grass':
|
||||||
x = x.to(device=device)
|
answer = answer + 1
|
||||||
y = y.to(device=device)
|
print("Accuracy(%) grass3.png", answer)
|
||||||
|
|
||||||
scores = model(x)
|
answer = 0
|
||||||
|
for i in range(100):
|
||||||
|
if what_is_it('../../resources/textures/grass4.png') == 'grass':
|
||||||
|
answer = answer + 1
|
||||||
|
print("Accuracy(%) grass4.png", answer)
|
||||||
|
|
||||||
_, predictions = scores.max(1)
|
answer = 0
|
||||||
num_correct += (predictions == y).sum()
|
for i in range(100):
|
||||||
num_samples += predictions.size(0)
|
if what_is_it('../../resources/textures/grass1.png') == 'grass':
|
||||||
|
answer = answer + 1
|
||||||
|
print("Accuracy(%) grass1.png", answer)
|
||||||
|
|
||||||
print(f"Got {num_correct}/{num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}%")
|
answer = 0
|
||||||
|
for i in range(100):
|
||||||
|
if what_is_it('../../resources/textures/water.png') == 'water':
|
||||||
|
answer = answer + 1
|
||||||
|
print("Accuracy(%) water.png", answer)
|
||||||
|
|
||||||
|
answer = 0
|
||||||
|
for i in range(100):
|
||||||
|
if what_is_it('../../resources/textures/sand.png') == 'sand':
|
||||||
|
answer = answer + 1
|
||||||
|
print("Accuracy(%) sand.png", answer)
|
||||||
|
|
||||||
|
|
||||||
def what_is_it(img_path, show_img=False):
|
def what_is_it(img_path, show_img=False):
|
||||||
@ -74,17 +99,27 @@ def what_is_it(img_path, show_img=False):
|
|||||||
if show_img:
|
if show_img:
|
||||||
plt.imshow(plt.imread(img_path))
|
plt.imshow(plt.imread(img_path))
|
||||||
plt.show()
|
plt.show()
|
||||||
image = setup_photos(image).unsqueeze(0)
|
image = SETUP_PHOTOS(image).unsqueeze(0)
|
||||||
model = NeuralNetwork()
|
model = NeuralNetwork.load_from_checkpoint('./lightning_logs/version_3/checkpoints/epoch=8-step=810.ckpt')
|
||||||
|
|
||||||
model.load_state_dict(torch.load("./learnedNetwork.pt", map_location=device))
|
|
||||||
model = model.to(device)
|
|
||||||
image = image.to(device)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
model.eval()
|
model.eval()
|
||||||
idx = int(model(image).argmax(dim=1))
|
idx = int(model(image).argmax(dim=1))
|
||||||
return id_to_class[idx]
|
return ID_TO_CLASS[idx]
|
||||||
|
|
||||||
|
|
||||||
train(CNN)
|
CNN = NeuralNetwork()
|
||||||
|
|
||||||
|
|
||||||
|
trainer = pl.Trainer(accelerator='gpu', devices=1, auto_scale_batch_size=True, callbacks=[EarlyStopping('val_loss')], max_epochs=NUM_EPOCHS)
|
||||||
|
#trainer = pl.Trainer(accelerator='gpu', devices=1, auto_lr_find=True, max_epochs=NUM_EPOCHS)
|
||||||
|
|
||||||
|
trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=SETUP_PHOTOS)
|
||||||
|
testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=SETUP_PHOTOS)
|
||||||
|
train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
|
||||||
|
test_loader = DataLoader(testset, batch_size=BATCH_SIZE)
|
||||||
|
|
||||||
|
#trainer.fit(CNN, train_loader, test_loader)
|
||||||
|
#trainer.tune(CNN, train_loader, test_loader)
|
||||||
|
check_accuracy_tiles()
|
||||||
|
print(what_is_it('../../resources/textures/sand.png', True))
|
||||||
|
@ -72,14 +72,15 @@ BAR_HEIGHT_MULTIPLIER = 0.1
|
|||||||
|
|
||||||
|
|
||||||
#NEURAL_NETWORK
|
#NEURAL_NETWORK
|
||||||
learning_rate = 0.001
|
LEARNING_RATE = 0.13182567385564073
|
||||||
batch_size = 7
|
BATCH_SIZE = 64
|
||||||
num_epochs = 100
|
NUM_EPOCHS = 50
|
||||||
|
|
||||||
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||||
classes = ['grass', 'sand', 'tree', 'water']
|
print("Using ", DEVICE)
|
||||||
|
CLASSES = ['grass', 'sand', 'tree', 'water']
|
||||||
|
|
||||||
setup_photos = transforms.Compose([
|
SETUP_PHOTOS = transforms.Compose([
|
||||||
transforms.Resize(36),
|
transforms.Resize(36),
|
||||||
transforms.CenterCrop(36),
|
transforms.CenterCrop(36),
|
||||||
transforms.ToPILImage(),
|
transforms.ToPILImage(),
|
||||||
@ -87,5 +88,5 @@ setup_photos = transforms.Compose([
|
|||||||
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
|
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
|
||||||
])
|
])
|
||||||
|
|
||||||
id_to_class = {i: j for i, j in enumerate(classes)}
|
ID_TO_CLASS = {i: j for i, j in enumerate(CLASSES)}
|
||||||
class_to_id = {value: key for key, value in id_to_class.items()}
|
CLASS_TO_ID = {value: key for key, value in ID_TO_CLASS.items()}
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import pygame
|
import pygame
|
||||||
from common.constants import GRID_CELL_PADDING, GRID_CELL_SIZE, COLUMNS, ROWS, classes, class_to_id
|
from common.constants import GRID_CELL_PADDING, GRID_CELL_SIZE, COLUMNS, ROWS, CLASSES, CLASS_TO_ID
|
||||||
import csv
|
import csv
|
||||||
import os
|
import os
|
||||||
|
|
||||||
@ -24,29 +24,29 @@ def createCSV():
|
|||||||
writer = csv.writer(train_csvfile)
|
writer = csv.writer(train_csvfile)
|
||||||
writer.writerow(["filepath", "type"])
|
writer.writerow(["filepath", "type"])
|
||||||
|
|
||||||
for class_name in classes:
|
for class_name in CLASSES:
|
||||||
class_dir = train_data_path + "/" + class_name
|
class_dir = train_data_path + "/" + class_name
|
||||||
for filename in os.listdir(class_dir):
|
for filename in os.listdir(class_dir):
|
||||||
f = os.path.join(class_dir, filename)
|
f = os.path.join(class_dir, filename)
|
||||||
if os.path.isfile(f):
|
if os.path.isfile(f):
|
||||||
writer.writerow([f, class_to_id[class_name]])
|
writer.writerow([f, CLASS_TO_ID[class_name]])
|
||||||
|
|
||||||
train_csvfile.close()
|
train_csvfile.close()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
print("Brak plików do uczenia")
|
print("Brak plików do uczenia")
|
||||||
|
|
||||||
if os.path.exists(train_data_path):
|
if os.path.exists(test_data_path):
|
||||||
test_csvfile = open('./data/test_csv_file.csv', 'w', newline="")
|
test_csvfile = open('./data/test_csv_file.csv', 'w', newline="")
|
||||||
writer = csv.writer(test_csvfile)
|
writer = csv.writer(test_csvfile)
|
||||||
writer.writerow(["filepath", "type"])
|
writer.writerow(["filepath", "type"])
|
||||||
|
|
||||||
for class_name in classes:
|
for class_name in CLASSES:
|
||||||
class_dir = test_data_path + "/" + class_name
|
class_dir = test_data_path + "/" + class_name
|
||||||
for filename in os.listdir(class_dir):
|
for filename in os.listdir(class_dir):
|
||||||
f = os.path.join(class_dir, filename)
|
f = os.path.join(class_dir, filename)
|
||||||
if os.path.isfile(f):
|
if os.path.isfile(f):
|
||||||
writer.writerow([f, class_to_id[class_name]])
|
writer.writerow([f, CLASS_TO_ID[class_name]])
|
||||||
|
|
||||||
test_csvfile.close()
|
test_csvfile.close()
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user