forked from s464965/WMICraft
57 lines
1.8 KiB
Python
57 lines
1.8 KiB
Python
import torch
|
|
import pytorch_lightning as pl
|
|
import torch.nn as nn
|
|
from torch.optim import Adam
|
|
import torch.nn.functional as F
|
|
from common.constants import BATCH_SIZE, LEARNING_RATE
|
|
|
|
|
|
class NeuralNetwork(pl.LightningModule):
    """Small CNN image classifier wrapped as a PyTorch Lightning module.

    Layer stack:
        conv 3x3 (24 ch) -> ReLU -> 2x2 max-pool -> conv 3x3 (48 ch) -> ReLU
        -> flatten -> linear (500) -> ReLU -> linear (num_classes) -> log-softmax

    ``forward`` returns log-probabilities, so every step uses ``F.nll_loss``
    (log-softmax followed by NLL loss is equivalent to cross-entropy).
    """

    def __init__(self, numChannels=3, batch_size=BATCH_SIZE,
                 learning_rate=LEARNING_RATE, num_classes=4):
        """Create the layers.

        Args:
            numChannels: number of channels in the input images.
            batch_size: stored as ``self.batch_size``; not read inside this
                class (presumably consumed by data-loading code — verify).
            learning_rate: learning rate given to Adam in
                ``configure_optimizers``.
            num_classes: number of output classes (width of the final layer).
        """
        super().__init__()

        # A 3x3 kernel with padding=1 (stride 1) keeps height/width unchanged.
        self.conv1 = nn.Conv2d(numChannels, 24, (3, 3), padding=1)
        self.relu1 = nn.ReLU()
        # Halves height and width (floor division).
        self.maxpool1 = nn.MaxPool2d((2, 2), stride=2)
        self.conv2 = nn.Conv2d(24, 48, (3, 3), padding=1)
        self.relu2 = nn.ReLU()

        # 48 channels on an 18x18 map after the single pool, so inputs must
        # satisfy (H // 2) * (W // 2) == 18 * 18, e.g. 36x36 images.
        # NOTE(review): confirm the image resolution produced by the data
        # pipeline.
        # NOTE: fc1 used to project straight to 4 outputs, which ignored
        # num_classes and left relu3/fc2 unused. Checkpoints saved with that
        # layout will not load into this module.
        self.fc1 = nn.Linear(48 * 18 * 18, 500)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(500, num_classes)
        self.logSoftmax = nn.LogSoftmax(dim=1)

        self.batch_size = batch_size
        self.learning_rate = learning_rate

    def forward(self, x):
        """Return class log-probabilities of shape (batch, num_classes).

        Args:
            x: image batch, presumably (batch, numChannels, H, W) with the
                spatial size described in ``__init__``.
        """
        x = self.maxpool1(self.relu1(self.conv1(x)))
        x = self.relu2(self.conv2(x))
        # Keep the batch dimension, flatten channels and spatial dims together.
        x = torch.flatten(x, start_dim=1)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return self.logSoftmax(x)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters at ``self.learning_rate``."""
        optimizer = Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def _batch_loss(self, batch):
        """Compute the NLL loss of the model on one ``(inputs, targets)`` batch.

        ``targets`` must be class indices, as required by ``F.nll_loss``.
        """
        x, y = batch
        log_probs = self(x)
        # forward() already applies log-softmax, so NLL loss == cross-entropy.
        return F.nll_loss(log_probs, y)

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch (Lightning backpropagates it)."""
        return self._batch_loss(batch)

    def validation_step(self, batch, batch_idx):
        """Log the validation loss, per step and per epoch, synced across devices."""
        val_loss = self._batch_loss(batch)
        self.log("val_loss", val_loss, on_step=True, on_epoch=True, sync_dist=True)

    def test_step(self, batch, batch_idx):
        """Log the test loss, per step and per epoch, synced across devices."""
        test_loss = self._batch_loss(batch)
        self.log("test_loss", test_loss, on_step=True, on_epoch=True, sync_dist=True)
|