import lightning as L
import torch
import torch.nn as nn


class UNet(nn.Module):
    """Small convolutional encoder/decoder mapping ``in_channels`` to ``out_channels``.

    Despite the name, there are no skip connections: the input flows straight
    through ``encoder`` -> ``middle`` -> ``decoder``. ``encoder`` and ``middle``
    each halve the spatial size with a stride-2 max-pool, and the two stride-2
    transposed convolutions in ``decoder`` each double it, so the output has the
    input's height and width when both are divisible by 4.

    The final ``nn.Sigmoid`` squashes every output value into (0, 1).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Layer order is part of the state_dict key layout (encoder.0, ...);
        # keep it stable so existing checkpoints still load.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.middle = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, out_channels, kernel_size=2, stride=2),
            # NOTE(review): see UNetLightning._shared_step — this Sigmoid is
            # followed by CrossEntropyLoss, which expects raw logits.
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run ``x`` through encoder, middle and decoder and return the result."""
        encoded = self.encoder(x)
        bottleneck = self.middle(encoded)
        return self.decoder(bottleneck)


class UNetLightning(L.LightningModule):
    """LightningModule wrapping :class:`UNet` with a cross-entropy objective and Adam.

    Args:
        in_channels: Number of channels in the input images.
        out_channels: Number of channels produced by the network.
        learning_rate: Learning rate passed to ``torch.optim.Adam``.
    """

    def __init__(self, in_channels=3, out_channels=3, learning_rate=1e-3):
        super().__init__()
        # Store the constructor arguments in ``self.hparams`` (and in checkpoints)
        # so Lightning can log them; without this ``self.hparams`` stays empty.
        self.save_hyperparameters()
        self.model = UNet(in_channels, out_channels)
        self.learning_rate = learning_rate
        # Built once rather than on every step. It has no parameters or
        # non-None buffers, so the state_dict layout is unchanged.
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return the wrapped UNet's output for ``x``."""
        return self.model(x)

    def _shared_step(self, batch, stage):
        """Compute the loss on an ``(x, y)`` batch, log it as ``f"{stage}_loss"`` and return it."""
        x, y = batch
        y_hat = self(x)
        # NOTE(review): UNet ends in nn.Sigmoid, but CrossEntropyLoss expects
        # unnormalised logits (it applies log-softmax internally). Squashing to
        # (0, 1) first limits the loss's dynamic range. Confirm whether the
        # Sigmoid should be removed or a different loss (BCE/MSE) is intended.
        # Presumably ``y`` holds class indices of shape (N, H, W) — TODO confirm.
        loss = self.loss_fn(y_hat, y)
        self.log(f"{stage}_loss", loss)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for ``batch`` (logged as ``train_loss``)."""
        return self._shared_step(batch, "train")

    def log_hyperparameters(self):
        """Send ``self.hparams`` to the attached logger, if there is one.

        Uses the logger-agnostic ``Logger.log_hyperparams`` API rather than a
        backend-specific ``experiment.log_param`` call, which TensorBoard's
        ``SummaryWriter`` lacks and MLflow's client expects a run id for.
        Normally unnecessary: hyperparameters recorded by
        ``save_hyperparameters`` are logged by Lightning when training starts.
        """
        if self.logger is None:
            return
        self.logger.log_hyperparams(self.hparams)

    def configure_optimizers(self):
        """Return an Adam optimiser over all parameters at ``self.learning_rate``."""
        # Hyperparameter logging is handled via ``save_hyperparameters``; an
        # optimiser factory should not perform logging side effects.
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for ``batch`` as ``val_loss``."""
        self._shared_step(batch, "val")

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return the model output for ``batch``.

        Accepts a list/tuple batch (e.g. ``(x, y)``), whose first element is
        the input, or a bare input tensor, so unlabeled prediction loaders work.
        """
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        return self(x)

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss for ``batch`` as ``test_loss``."""
        self._shared_step(batch, "test")