113 lines
3.7 KiB
Python
113 lines
3.7 KiB
Python
import lightning as L
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class UNet(nn.Module):
    """Small convolutional encoder-decoder for per-pixel prediction.

    Despite its name, this network has no skip connections between the
    encoder and decoder paths; it is a plain encoder-decoder, not a U-Net in
    the original (Ronneberger et al.) sense.

    Spatial flow: ``encoder`` halves height and width once and ``middle``
    halves them again (1/4 overall); ``decoder`` then doubles them twice with
    stride-2 transposed convolutions. The output therefore has the same
    height/width as the input only when both are divisible by 4; otherwise
    max-pooling floors the size and the output comes back smaller.

    The final layer is a Sigmoid, so every output value lies in [0, 1].

    Args:
        in_channels: Number of channels in the input tensor.
        out_channels: Number of channels produced by the network.
        base_channels: Width of the encoder/decoder convolutions; the
            bottleneck (``middle``) uses twice this many. Defaults to 64,
            the original hard-coded width, so the default layout and
            state-dict shapes are unchanged.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        base_channels: int = 64,
    ) -> None:
        super().__init__()
        wide_channels = 2 * base_channels  # bottleneck width (128 by default)

        # Two 3x3 convs at full resolution, then halve H and W.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(base_channels, base_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Widen to the bottleneck, then halve H and W again (1/4 overall).
        self.middle = nn.Sequential(
            nn.Conv2d(base_channels, wide_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(wide_channels, wide_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Two stride-2 transposed convs restore the input resolution; the
        # trailing Sigmoid squashes every output value into [0, 1].
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(wide_channels, base_channels, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(base_channels, base_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(base_channels, out_channels, kernel_size=2, stride=2),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the encoder, bottleneck and decoder in sequence.

        Args:
            x: Image batch of shape ``(N, in_channels, H, W)``.

        Returns:
            Tensor of shape ``(N, out_channels, H', W')`` with values in
            [0, 1], where ``H' == H`` and ``W' == W`` when H and W are
            divisible by 4.
        """
        x1 = self.encoder(x)
        x2 = self.middle(x1)
        return self.decoder(x2)
|
|
|
|
|
|
class UNetLightning(L.LightningModule):
    """Lightning wrapper that trains :class:`UNet` with a cross-entropy loss.

    Training, validation and test batches are unpacked as ``(x, y)`` pairs;
    ``predict_step`` additionally accepts a bare input tensor so prediction
    dataloaders do not need to yield targets.

    Args:
        in_channels: Channels of the input images passed to ``UNet``.
        out_channels: Channels produced by ``UNet``.
        learning_rate: Learning rate for the Adam optimizer.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        learning_rate: float = 1e-3,
    ) -> None:
        super().__init__()
        # Store the constructor arguments in ``self.hparams`` and in saved
        # checkpoints so ``load_from_checkpoint`` can rebuild this module
        # without the caller re-supplying them.
        self.save_hyperparameters()
        self.model = UNet(in_channels, out_channels)
        # Kept as a plain attribute for code that reads it directly.
        self.learning_rate = learning_rate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Delegate to the wrapped ``UNet``."""
        return self.model(x)

    def _compute_loss(self, batch) -> torch.Tensor:
        """Run the model on an ``(x, y)`` batch and return the cross-entropy loss."""
        x, y = batch
        y_hat = self(x)
        # NOTE(review): ``UNet`` ends in a Sigmoid, whereas cross-entropy
        # applies log-softmax to its input and expects raw logits, so the
        # loss is computed on already-squashed values. Kept unchanged to
        # preserve the existing training objective — confirm it is intended.
        return nn.functional.cross_entropy(y_hat, y)

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the training loss."""
        loss = self._compute_loss(batch)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Optimize all parameters with Adam at ``self.learning_rate``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss as ``val_loss``."""
        loss = self._compute_loss(batch)
        self.log("val_loss", loss)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return model outputs for one prediction batch.

        ``batch`` may be a tuple/list whose first element is the input (e.g.
        ``(x, y)``; any other elements are ignored) or a bare input tensor.
        Previously a bare tensor was unpacked along its batch dimension,
        which failed for most batch sizes and silently split a size-2 batch.
        """
        x = batch[0] if isinstance(batch, (tuple, list)) else batch
        return self(x)

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss as ``test_loss``."""
        loss = self._compute_loss(batch)
        self.log("test_loss", loss)
|