80 lines
2.5 KiB
Python
80 lines
2.5 KiB
Python
import lightning as L
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class UNet(nn.Module):
    """Small convolutional encoder-decoder with a sigmoid output.

    Despite the name, this network has no skip connections: the input is
    passed straight through ``encoder -> middle -> decoder``. The encoder and
    middle stages each end with a 2x2 max-pool (halving height and width), and
    the decoder upsamples twice with stride-2 transposed convolutions. The
    output therefore matches the input's height and width when both are
    divisible by 4; otherwise pooling floors the size and the output is
    smaller. Every output value passes through ``nn.Sigmoid`` and lies in
    [0, 1].

    Args:
        in_channels: Number of channels in the input tensor.
        out_channels: Number of channels produced by the final layer.
    """

    @staticmethod
    def _conv_relu(c_in, c_out):
        """Return a padded 3x3 convolution followed by an in-place ReLU."""
        return [
            nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Layers are created in the same order (and land at the same
        # Sequential indices) as a flat listing would produce, so parameter
        # initialisation and state_dict keys are unaffected by the helper.
        self.encoder = nn.Sequential(
            *self._conv_relu(in_channels, 64),
            *self._conv_relu(64, 64),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.middle = nn.Sequential(
            *self._conv_relu(64, 128),
            *self._conv_relu(128, 128),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            *self._conv_relu(64, 64),
            nn.ConvTranspose2d(64, out_channels, kernel_size=2, stride=2),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Apply the encoder, middle and decoder stages to ``x`` in order."""
        features = self.middle(self.encoder(x))
        return self.decoder(features)
|
|
|
|
|
|
class UNetLightning(L.LightningModule):
    """Lightning wrapper that trains :class:`UNet` with a cross-entropy loss.

    Args:
        in_channels: Channels of the input images passed to ``UNet``.
        out_channels: Channels produced by ``UNet``; the cross-entropy loss
            treats dimension 1 of the model output as the class dimension.
        learning_rate: Learning rate for the Adam optimizer.
    """

    def __init__(self, in_channels=3, out_channels=3, learning_rate=1e-3):
        super().__init__()
        # Record the constructor arguments in ``self.hparams``. Lightning then
        # logs them to the attached logger and stores them in checkpoints, so
        # ``load_from_checkpoint`` can rebuild the module. Without this call
        # ``self.hparams`` stays empty.
        self.save_hyperparameters()
        self.model = UNet(in_channels, out_channels)
        # Kept as a plain attribute for existing callers that read it.
        self.learning_rate = learning_rate

    def forward(self, x):
        """Return the wrapped ``UNet``'s output for ``x``."""
        return self.model(x)

    def _shared_step(self, batch, stage):
        """Compute and log the cross-entropy loss for one ``(x, y)`` batch.

        ``stage`` is the metric prefix: ``"train"`` logs ``"train_loss"``,
        ``"val"`` logs ``"val_loss"``, ``"test"`` logs ``"test_loss"``.
        Returns the loss tensor.
        """
        x, y = batch
        y_hat = self(x)
        # NOTE(review): UNet ends in nn.Sigmoid, but cross_entropy expects raw
        # (unnormalized) logits and applies log-softmax itself. Squashing the
        # scores into [0, 1] first caps how peaked the predicted distribution
        # can get and weakens gradients. Consider removing the Sigmoid for
        # training (apply it only at inference), or switching to a loss that
        # matches the targets `y` -- confirm the intended target format.
        loss = nn.functional.cross_entropy(y_hat, y)
        self.log(f"{stage}_loss", loss)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for ``batch`` and log it as ``train_loss``."""
        return self._shared_step(batch, "train")

    def log_hyperparameters(self):
        """Send ``self.hparams`` to the attached logger, if there is one.

        Uses the logger-agnostic ``Logger.log_hyperparams`` API instead of
        reaching into ``logger.experiment``, whose methods differ per backend
        (e.g. MLflow's ``MlflowClient.log_param`` also requires a run id).
        Does nothing when no logger is attached or no hyperparameters are saved.
        """
        if self.logger is None or not self.hparams:
            return
        self.logger.log_hyperparams(self.hparams)

    def configure_optimizers(self):
        """Create an Adam optimizer over all parameters using ``learning_rate``."""
        # Hyperparameters are logged automatically by Lightning because
        # __init__ calls save_hyperparameters(); calling log_hyperparameters()
        # here as well would only log them a second time.
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def validation_step(self, batch, batch_idx):
        """Log the validation loss for ``batch`` as ``val_loss``."""
        self._shared_step(batch, "val")

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return model outputs for a prediction batch.

        Accepts either a bare input tensor or a list/tuple whose first element
        is the input, so prediction loaders do not need to supply labels.
        """
        # A bare tensor must not be unpacked: ``x, y = tensor`` would split it
        # along the batch dimension instead of separating inputs from labels.
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        return self(x)

    def test_step(self, batch, batch_idx):
        """Log the test loss for ``batch`` as ``test_loss``."""
        self._shared_step(batch, "test")
|