# iumKC/model.py
import lightning as L
import torch
import torch.nn as nn


class UNet(nn.Module):
    """Small encoder/middle/decoder convolutional network (no skip connections)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Encoder: two 3x3 convs at 64 channels, then downsample by 2.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Middle: two 3x3 convs at 128 channels, then downsample by 2.
        self.middle = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Decoder: two transposed convs upsample back to the input resolution.
        # Note: the final Sigmoid squashes outputs to [0, 1], while the
        # CrossEntropyLoss used in the LightningModule below expects raw logits;
        # drop this layer if the targets are per-pixel class indices.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, out_channels, kernel_size=2, stride=2),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # For an (N, C, 64, 64) input: encoder -> (N, 64, 32, 32),
        # middle -> (N, 128, 16, 16), decoder -> (N, out_channels, 64, 64).
        x1 = self.encoder(x)
        x2 = self.middle(x1)
        x3 = self.decoder(x2)
        return x3


class UNetLightning(L.LightningModule):
    def __init__(self, in_channels=3, out_channels=3, learning_rate=1e-3):
        super().__init__()
        # Record constructor arguments in self.hparams so that
        # log_hyperparameters() has something to iterate over.
        self.save_hyperparameters()
        self.model = UNet(in_channels, out_channels)
        self.learning_rate = learning_rate

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def log_hyperparameters(self):
        # Assumes an MLflow-style logger whose `experiment` is an MlflowClient;
        # MlflowClient.log_param() takes the run id as its first argument.
        for key, value in self.hparams.items():
            self.logger.experiment.log_param(self.logger.run_id, key, value)

    def configure_optimizers(self):
        self.log_hyperparameters()
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        self.log("val_loss", loss)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        x, y = batch
        y_hat = self(x)
        return y_hat

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        self.log("test_loss", loss)