import torch
import pytorch_lightning as pl
import torch.nn as nn
from torch.optim import SGD, Adam, lr_scheduler
import torch.nn.functional as F
from torch.utils.data import DataLoader

from watersandtreegrass import WaterSandTreeGrass
from common.constants import (
    DEVICE,
    BATCH_SIZE,
    NUM_EPOCHS,
    LEARNING_RATE,
    SETUP_PHOTOS,
    ID_TO_CLASS,
)


class NeuralNetwork(pl.LightningModule):
    """Small two-convolution image classifier trained with NLL loss.

    Pipeline used by ``forward``: conv(3x3) -> ReLU -> 2x2 max-pool ->
    conv(3x3) -> ReLU -> flatten -> linear -> log-softmax.

    Both convolutions use ``padding=1`` with a 3x3 kernel, so they keep the
    spatial size; the single max-pool halves it. The flattened feature size
    is fixed at ``48 * 18 * 18``, which means inputs must be 36x36 images.

    Args:
        numChannels: Number of channels in the input images.
        batch_size: Stored on the module as ``self.batch_size``; not used
            by any method in this class.
        learning_rate: Learning rate handed to the Adam optimizer.
        num_classes: Number of output classes, i.e. the width of the
            log-probability vector returned by ``forward``.
    """

    def __init__(self, numChannels=3, batch_size=BATCH_SIZE,
                 learning_rate=LEARNING_RATE, num_classes=4):
        super().__init__()

        # Feature extractor.
        self.conv1 = nn.Conv2d(numChannels, 24, (3, 3), padding=1)
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d((2, 2), stride=2)
        self.conv2 = nn.Conv2d(24, 48, (3, 3), padding=1)
        self.relu2 = nn.ReLU()

        # Classifier head. The output width follows ``num_classes`` instead
        # of a hard-coded 4, so the argument actually controls the number
        # of predicted classes (identical to before for the default of 4).
        self.fc1 = nn.Linear(48 * 18 * 18, num_classes)

        # ``relu3`` and ``fc2`` are not used by ``forward``. They are kept
        # so the module's state_dict keys stay the same as before.
        # NOTE(review): unused parameters may require
        # ``find_unused_parameters=True`` under DDP - confirm.
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(500, num_classes)

        self.logSoftmax = nn.LogSoftmax(dim=1)

        self.batch_size = batch_size
        self.learning_rate = learning_rate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-class log-probabilities for a batch of images.

        Args:
            x: Image batch of shape ``(N, numChannels, 36, 36)``.

        Returns:
            Tensor of shape ``(N, num_classes)`` holding log-probabilities
            (log-softmax over dimension 1).
        """
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)

        x = self.conv2(x)
        x = self.relu2(x)

        # Flatten everything except the batch dimension for the linear layer.
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x)
        return self.logSoftmax(x)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters of the module."""
        return Adam(self.parameters(), lr=self.learning_rate)

    def _nll_from_batch(self, batch) -> torch.Tensor:
        """Run the model on ``batch = (inputs, targets)`` and return NLL loss."""
        x, y = batch
        scores = self(x)
        # ``forward`` already applies log-softmax, so NLL loss is the
        # matching criterion here.
        return F.nll_loss(scores, y)

    def training_step(self, batch, batch_idx):
        """Compute and return the training loss for one batch."""
        return self._nll_from_batch(batch)

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss for one batch and log it as ``val_loss``."""
        val_loss = self._nll_from_batch(batch)
        self.log("val_loss", val_loss, on_step=True, on_epoch=True,
                 sync_dist=True)

    def test_step(self, batch, batch_idx):
        """Compute the test loss for one batch and log it as ``test_loss``."""
        test_loss = self._nll_from_batch(batch)
        self.log("test_loss", test_loss, on_step=True, on_epoch=True,
                 sync_dist=True)