iumKC/model.py

113 lines
3.7 KiB
Python

import lightning as L
import torch
import torch.nn as nn
class UNet(nn.Module):
    """Small convolutional encoder-decoder producing per-pixel values in [0, 1].

    The network is a plain sequential pipeline: ``encoder`` -> ``middle`` ->
    ``decoder``. Despite the class name, no skip connections link the
    down-sampling path to the up-sampling path.

    Spatial resolution is halved twice (two stride-2 max-pools) and doubled
    twice (two stride-2 transposed convolutions), so the output height/width
    equal the input's only when both are divisible by 4; otherwise the output
    is rounded down to the nearest multiple of 4.

    Args:
        in_channels: Number of channels in the input tensor.
        out_channels: Number of channels in the output tensor.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Stage 1: two 3x3 convs at 64 channels, then 2x down-sampling.
        self.encoder = nn.Sequential(
            *self._conv_relu(in_channels, 64),
            *self._conv_relu(64, 64),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Stage 2: two 3x3 convs at 128 channels, then 2x down-sampling.
        self.middle = nn.Sequential(
            *self._conv_relu(64, 128),
            *self._conv_relu(128, 128),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Stage 3: two 2x up-sampling steps with one 3x3 conv in between;
        # the final Sigmoid squashes every output value into [0, 1].
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            *self._conv_relu(64, 64),
            nn.ConvTranspose2d(64, out_channels, kernel_size=2, stride=2),
            nn.Sigmoid(),
        )

    @staticmethod
    def _conv_relu(channels_in, channels_out):
        """Return a size-preserving 3x3 convolution followed by an in-place ReLU."""
        return [
            nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]

    def forward(self, x):
        """Run ``x`` through encoder, middle and decoder and return the result."""
        return self.decoder(self.middle(self.encoder(x)))
class UNetLightning(L.LightningModule):
    """Lightning wrapper that trains ``UNet`` with a cross-entropy objective.

    Training, validation and test steps all compute the same loss and log it
    under ``train_loss``, ``val_loss`` and ``test_loss`` respectively.

    Args:
        in_channels: Number of input channels forwarded to ``UNet``.
        out_channels: Number of output channels forwarded to ``UNet``.
        learning_rate: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, in_channels=3, out_channels=3, learning_rate=1e-3):
        super().__init__()
        self.model = UNet(in_channels, out_channels)
        self.learning_rate = learning_rate

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

    def _compute_loss(self, batch):
        """Unpack an ``(inputs, targets)`` batch and return its cross-entropy loss."""
        x, y = batch
        y_hat = self(x)
        # NOTE(review): UNet ends with a Sigmoid, whereas cross-entropy is
        # documented to expect unnormalized logits. Confirm this pairing is
        # intended before changing either side.
        return nn.functional.cross_entropy(y_hat, y)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss for one batch."""
        loss = self._compute_loss(batch)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters at ``learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        self.log("val_loss", self._compute_loss(batch))

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return model predictions for one batch.

        ``batch`` may be an ``(inputs, targets)`` pair, as in the other steps,
        or the inputs alone. Only the inputs are used.
        """
        # Unpacking a bare tensor with ``x, y = batch`` would either fail or
        # silently split it along its first dimension, so only index into
        # real sequences.
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        return self(x)

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss for one batch."""
        self.log("test_loss", self._compute_loss(batch))
# visualize
# if batch_idx == 0:
# import matplotlib.pyplot as plt
# import numpy as np
# import torchvision.transforms as transforms
# x = transforms.ToPILImage()(x[0])
# y = transforms.ToPILImage()(y[0])
# y_hat = transforms.ToPILImage()(y_hat[0])
# plt.figure(figsize=(15, 15))
# plt.subplot(131)
# plt.imshow(np.array(x))
# plt.title("Input")
# plt.axis("off")
# plt.subplot(132)
# plt.imshow(np.array(y))
# plt.title("Ground Truth")
# plt.axis("off")
# plt.subplot(133)
# plt.imshow(np.array(y_hat))
# plt.title("Prediction")
# plt.axis("off")
# plt.show()
# def on_test_epoch_end(self):
# all_preds, all_labels = [], []
# for output in self.trainer.predictions:
# # predicted values
# probs = list(output['logits'].cpu().detach().numpy())
# labels = list(output['labels'].flatten().cpu().detach().numpy())
# all_preds.extend(probs)
# all_labels.extend(labels)
# # save predictions and labels
# import numpy as np
# np.save('predictions.npy', all_preds)
# np.save('labels.npy', all_labels)