# WMICraft/algorithms/neural_network/neural_network.py

import torch
import pytorch_lightning as pl
import torch.nn as nn
from torch.optim import SGD, Adam, lr_scheduler
import torch.nn.functional as F
from torch.utils.data import DataLoader
from watersandtreegrass import WaterSandTreeGrass
from common.constants import DEVICE, BATCH_SIZE, NUM_EPOCHS, LEARNING_RATE, SETUP_PHOTOS, ID_TO_CLASS
class NeuralNetwork(pl.LightningModule):
    """Fully connected image classifier trained with PyTorch Lightning.

    Each input sample is flattened and passed through
    ``Linear -> ReLU -> Linear -> LogSoftmax``, so ``forward`` returns
    per-class log-probabilities that pair with ``F.nll_loss``.

    Args:
        numChannels: Number of channels per input image.
        batch_size: Stored as ``self.batch_size``; not read inside this class
            (presumably consumed by data-loading code or Lightning's
            batch-size tuner — verify).
        learning_rate: Learning rate for the SGD optimizer.
        num_classes: Number of output classes.
        image_size: Side length of the (square) input image, in pixels.
        hidden_size: Width of the hidden layer.
    """

    def __init__(self, numChannels=3, batch_size=BATCH_SIZE, learning_rate=LEARNING_RATE,
                 num_classes=4, image_size=36, hidden_size=300):
        super().__init__()
        # Previously hard-coded as 36*36*3 -> 300 -> 4, which ignored
        # numChannels/num_classes. The defaults reproduce those sizes, and the
        # Sequential layout is unchanged, so existing state_dict keys still load.
        in_features = numChannels * image_size * image_size
        self.layer = nn.Sequential(
            nn.Linear(in_features, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes),
            # Log-probabilities, as required by F.nll_loss in the step methods.
            nn.LogSoftmax(dim=-1),
        )
        self.batch_size = batch_size
        self.learning_rate = learning_rate

    def forward(self, x):
        """Return class log-probabilities of shape ``(batch, num_classes)``.

        All dimensions after the first are flattened, so each sample must hold
        ``numChannels * image_size * image_size`` values (presumably a
        ``(batch, C, H, W)`` image tensor — TODO confirm against the dataset).
        """
        x = x.reshape(x.shape[0], -1)
        return self.layer(x)

    def configure_optimizers(self):
        """Return a plain SGD optimizer over all parameters using ``self.learning_rate``."""
        return SGD(self.parameters(), lr=self.learning_rate)

    def _nll_step(self, batch):
        """Return the negative log-likelihood loss on one ``(inputs, targets)`` batch."""
        x, y = batch
        scores = self(x)
        # forward() already applies LogSoftmax, so NLL here equals cross-entropy;
        # targets must be integer class indices.
        return F.nll_loss(scores, y)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss for one batch."""
        loss = self._nll_step(batch)
        # Log the training loss so it is tracked alongside val_loss/test_loss.
        self.log("train_loss", loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        val_loss = self._nll_step(batch)
        self.log("val_loss", val_loss, on_step=True, on_epoch=True, sync_dist=True)

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss for one batch."""
        test_loss = self._nll_step(batch)
        self.log("test_loss", test_loss, on_step=True, on_epoch=True, sync_dist=True)