forked from s464965/WMICraft
commit b6ba817d55 (parent 5c1a1605b8)
update
Binary file not shown.
Binary file not shown.
@@ -0,0 +1 @@
+{}
@@ -16,9 +16,11 @@ class NeuralNetwork(pl.LightningModule):
         self.maxpool1 = nn.MaxPool2d((2, 2), stride=2)
         self.conv2 = nn.Conv2d(24, 48, (3, 3), padding=1)
         self.relu2 = nn.ReLU()
-        self.fc1 = nn.Linear(48*18*18, 4)
+        self.fc1 = nn.Linear(48*18*18, 800)
         self.relu3 = nn.ReLU()
-        self.fc2 = nn.Linear(500, num_classes)
+        self.fc2 = nn.Linear(800, 400)
+        self.relu4 = nn.ReLU()
+        self.fc3 = nn.Linear(400, 4)
         self.logSoftmax = nn.LogSoftmax(dim=1)
 
         self.batch_size = batch_size
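Note: the classifier head now goes 48*18*18 -> 800 -> 400 -> 4 instead of a single 15552 -> 4 layer (the old fc2 expected 500 inputs, which never matched fc1's 4 outputs and was not used in forward()). With the 36x36 tiles produced by SETUP_PHOTOS and the 2x2/stride-2 max-pool, the flattened size is 48 * 18 * 18 = 15552. A standalone sketch of the new head, for shape checking only:

import torch
import torch.nn as nn

# Sketch of the reworked head only; comments map each layer to the names in the diff.
head = nn.Sequential(
    nn.Flatten(),                  # (N, 48, 18, 18) -> (N, 48*18*18) = (N, 15552)
    nn.Linear(48 * 18 * 18, 800),  # fc1
    nn.ReLU(),                     # relu3
    nn.Linear(800, 400),           # fc2
    nn.ReLU(),                     # relu4
    nn.Linear(400, 4),             # fc3 -> one output per class in CLASSES
    nn.LogSoftmax(dim=1),
)
print(head(torch.zeros(2, 48, 18, 18)).shape)  # torch.Size([2, 4])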
@@ -32,6 +34,10 @@ class NeuralNetwork(pl.LightningModule):
         x = self.relu2(x)
         x = x.reshape(x.shape[0], -1)
         x = self.fc1(x)
+        x = self.relu3(x)
+        x = self.fc2(x)
+        x = self.relu4(x)
+        x = self.fc3(x)
         x = self.logSoftmax(x)
         return x
 
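Note: forward() now runs the whole head before LogSoftmax. The same lines annotated with the expected shapes for a batch of N tiles (the shape comments are inferred, not part of the original):

x = self.relu2(x)                 # (N, 48, 18, 18)
x = x.reshape(x.shape[0], -1)     # (N, 15552)
x = self.fc1(x)                   # (N, 800)
x = self.relu3(x)
x = self.fc2(x)                   # (N, 400)
x = self.relu4(x)
x = self.fc3(x)                   # (N, 4)
x = self.logSoftmax(x)            # log-probabilities over CLASSES
return x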
@@ -10,7 +10,8 @@ from torch.optim import Adam
 import matplotlib.pyplot as plt
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import EarlyStopping
-
+import torchvision.transforms.functional as F
+from PIL import Image
 
 def train(model):
     model = model.to(DEVICE)
@@ -95,12 +96,13 @@ def check_accuracy_tiles():
 
+
 def what_is_it(img_path, show_img=False):
-    image = read_image(img_path, mode=ImageReadMode.RGB)
+    image = Image.open(img_path).convert('RGB')
     if show_img:
-        plt.imshow(plt.imread(img_path))
+        plt.imshow(image)
         plt.show()
 
     image = SETUP_PHOTOS(image).unsqueeze(0)
-    model = NeuralNetwork.load_from_checkpoint('./lightning_logs/version_13/checkpoints/epoch=4-step=405.ckpt')
+    model = NeuralNetwork.load_from_checkpoint('./lightning_logs/version_20/checkpoints/epoch=3-step=324.ckpt')
 
     with torch.no_grad():
         model.eval()
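Note: what_is_it() now loads images with PIL instead of torchvision.io.read_image, which is why ToPILImage is dropped from SETUP_PHOTOS in the config hunk further down. A minimal sketch of the inference path, assuming the prediction is the argmax over the log-probabilities (that exact line sits outside this hunk):

image = Image.open(img_path).convert('RGB')
batch = SETUP_PHOTOS(image).unsqueeze(0)          # (1, 3, 36, 36)
model = NeuralNetwork.load_from_checkpoint(
    './lightning_logs/version_20/checkpoints/epoch=3-step=324.ckpt')
model.eval()
with torch.no_grad():
    log_probs = model(batch)
    idx = int(torch.argmax(log_probs, dim=1))     # assumed; the actual line is not shown
print(ID_TO_CLASS[idx])                           # 'grass', 'sand', 'tree' or 'water'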
@@ -108,18 +110,19 @@ def what_is_it(img_path, show_img=False):
         return ID_TO_CLASS[idx]
 
 
-CNN = NeuralNetwork()
-common.helpers.createCSV()
+#CNN = NeuralNetwork()
+#common.helpers.createCSV()
 
-#trainer = pl.Trainer(accelerator='gpu', devices=1, callbacks=[EarlyStopping('val_loss')], max_epochs=NUM_EPOCHS)
-trainer = pl.Trainer(accelerator='gpu', devices=1, auto_lr_find=True, max_epochs=NUM_EPOCHS)
+#trainer = pl.Trainer(accelerator='gpu', callbacks=EarlyStopping('val_loss'), devices=1, max_epochs=NUM_EPOCHS)
+#trainer = pl.Trainer(accelerator='gpu', devices=1, auto_lr_find=True, max_epochs=NUM_EPOCHS)
 
-trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=SETUP_PHOTOS)
-testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=SETUP_PHOTOS)
-train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
-test_loader = DataLoader(testset, batch_size=BATCH_SIZE)
+#trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=SETUP_PHOTOS)
+#testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=SETUP_PHOTOS)
+#train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
+#test_loader = DataLoader(testset, batch_size=BATCH_SIZE)
 
-trainer.fit(CNN, train_loader, test_loader)
+#trainer.fit(CNN, train_loader, test_loader)
 #trainer.tune(CNN, train_loader, test_loader)
 #check_accuracy_tiles()
-#print(what_is_it('../../resources/textures/sand.png', True))
+
+#print(what_is_it('../../resources/textures/grass2.png', True))
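Note: the whole training driver is now commented out, leaving only the what_is_it() inference path active. If training is re-enabled with the EarlyStopping variant, the monitored key ('val_loss') must match a metric the LightningModule actually logs during validation. A condensed sketch assembled from the commented-out lines above (only the list form of callbacks is my choice):

# Sketch only; every name comes from this file.
CNN = NeuralNetwork()
trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=SETUP_PHOTOS)
testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=SETUP_PHOTOS)
train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(testset, batch_size=BATCH_SIZE)

trainer = pl.Trainer(accelerator='gpu', devices=1,
                     callbacks=[EarlyStopping(monitor='val_loss')],  # 'val_loss' must be logged in validation
                     max_epochs=NUM_EPOCHS)
trainer.fit(CNN, train_loader, test_loader)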
@@ -3,6 +3,7 @@ from torch.utils.data import Dataset
 import pandas as pd
 from torchvision.io import read_image, ImageReadMode
 from common.helpers import createCSV
+from PIL import Image
 
 
 class WaterSandTreeGrass(Dataset):
@@ -15,7 +16,8 @@ class WaterSandTreeGrass(Dataset):
         return len(self.img_labels)
 
     def __getitem__(self, idx):
-        image = read_image(self.img_labels.iloc[idx, 0], mode=ImageReadMode.RGB)
+        image = Image.open(self.img_labels.iloc[idx, 0]).convert('RGB')
+
         label = torch.tensor(int(self.img_labels.iloc[idx, 1]))
 
         if self.transform:
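Note: __getitem__ now loads with PIL as well, so the Dataset and what_is_it() feed the same kind of object into SETUP_PHOTOS. A sketch of the full method, assuming the transform is applied and the (image, label) pair is returned in the usual torchvision Dataset pattern (those lines fall outside this hunk):

def __getitem__(self, idx):
    # Column 0 of the CSV holds the image path, column 1 the integer class id.
    image = Image.open(self.img_labels.iloc[idx, 0]).convert('RGB')
    label = torch.tensor(int(self.img_labels.iloc[idx, 1]))

    if self.transform:
        image = self.transform(image)   # assumed: standard transform application
    return image, label                 # assumed return, not visible in the diff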
@@ -77,16 +77,15 @@ BAR_HEIGHT_MULTIPLIER = 0.1
 
 
 #NEURAL_NETWORK
-LEARNING_RATE = 0.00478630092322638
+LEARNING_RATE = 0.000630957344480193
 BATCH_SIZE = 64
-NUM_EPOCHS = 20
+NUM_EPOCHS = 9
 
 DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
 print("Using ", DEVICE)
 CLASSES = ['grass', 'sand', 'tree', 'water']
 
 SETUP_PHOTOS = transforms.Compose([
-    transforms.ToPILImage(),
     transforms.ToTensor(),
     transforms.Resize((36, 36)),
     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
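Note: with ToPILImage removed, SETUP_PHOTOS expects a PIL image (which the Dataset and what_is_it() now supply) and turns it into a normalised 3x36x36 tensor. A quick usage sketch; the file name is just a placeholder:

from PIL import Image

img = Image.open('some_tile.png').convert('RGB')   # hypothetical path, for illustration only
x = SETUP_PHOTOS(img)                              # ToTensor -> Resize -> Normalize
print(x.shape)                                     # torch.Size([3, 36, 36])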