wko-on-cloud-n/benchmark/cloud_model.py
2022-02-15 16:42:28 +01:00

197 lines
6.6 KiB
Python

from typing import Optional, List
import pandas as pd
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
import torch
try:
from cloud_dataset import CloudDataset
from losses import intersection_over_union
except ImportError:
from benchmark_src.cloud_dataset import CloudDataset
from benchmark_src.losses import intersection_over_union
class CloudModel(pl.LightningModule):
    """U-Net segmentation model wrapped as a ``pl.LightningModule``.

    The module owns its training and validation ``CloudDataset`` instances,
    the ``segmentation_models_pytorch`` U-Net, the dataloaders and the
    optimizer/scheduler configuration.
    """

    def __init__(
        self,
        bands: List[str],
        x_train: Optional[pd.DataFrame] = None,
        y_train: Optional[pd.DataFrame] = None,
        x_val: Optional[pd.DataFrame] = None,
        y_val: Optional[pd.DataFrame] = None,
        hparams: Optional[dict] = None,
    ):
        """
        Instantiate the CloudModel class based on the pl.LightningModule
        (https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html).

        Args:
            bands (list[str]): Names of the bands provided for each chip
            x_train (pd.DataFrame, optional): a dataframe of the training features with a row for each chip.
                There must be a column for chip_id, and a column with the path to the TIF for each of bands.
                Required for model training
            y_train (pd.DataFrame, optional): a dataframe of the training labels with a row for each chip
                and columns for chip_id and the path to the label TIF with ground truth cloud cover.
                Required for model training
            x_val (pd.DataFrame, optional): a dataframe of the validation features with a row for each chip.
                There must be a column for chip_id, and a column with the path to the TIF for each of bands.
                Required for validation
            y_val (pd.DataFrame, optional): a dataframe of the validation labels with a row for each chip
                and columns for chip_id and the path to the label TIF with ground truth cloud cover.
                Required for validation
            hparams (dict, optional): Dictionary of additional modeling parameters. Recognised keys
                (with defaults): "backbone" ("resnet34"), "weights" ("imagenet"), "lr" (1e-3),
                "patience" (4), "num_workers" (2), "batch_size" (32), "gpu" (False).
                ``None`` is treated as an empty dict.
        """
        super().__init__()

        # A literal ``{}`` default would be one dict object shared by every
        # instance created without ``hparams``. Rebind the local *before*
        # ``save_hyperparameters()`` so the recorded init argument is still a dict.
        hparams = {} if hparams is None else hparams
        self.hparams.update(hparams)
        self.save_hyperparameters()

        # required
        self.bands = bands

        # optional modeling params
        self.backbone = self.hparams.get("backbone", "resnet34")
        self.weights = self.hparams.get("weights", "imagenet")
        self.learning_rate = self.hparams.get("lr", 1e-3)
        # Stored for external use; no method of this class reads it.
        self.patience = self.hparams.get("patience", 4)
        self.num_workers = self.hparams.get("num_workers", 2)
        self.batch_size = self.hparams.get("batch_size", 32)
        self.gpu = self.hparams.get("gpu", False)
        self.transform = None

        # Datasets are always constructed, even when the dataframes are None;
        # the step methods check ``dataset.data`` and raise if it is missing.
        self.train_dataset = CloudDataset(
            x_paths=x_train,
            bands=self.bands,
            y_paths=y_train,
            transforms=self.transform,
        )
        self.val_dataset = CloudDataset(
            x_paths=x_val,
            bands=self.bands,
            y_paths=y_val,
            transforms=None,
        )
        self.model = self._prepare_model()

    ## Required LightningModule methods ##

    def forward(self, image: torch.Tensor):
        """Run the wrapped U-Net on ``image`` and return its raw output."""
        return self.model(image)

    def training_step(self, batch: dict, batch_idx: int):
        """
        Training step.

        Args:
            batch (dict): dictionary of items from CloudDataset of the form
                {'chip_id': list[str], 'chip': list[torch.Tensor], 'label': list[torch.Tensor]}
            batch_idx (int): batch number

        Returns:
            The mean cross-entropy loss for the batch (also logged as "loss").

        Raises:
            ValueError: if the training dataset was built without data.
        """
        if self.train_dataset.data is None:
            raise ValueError(
                "x_train and y_train must be specified when CloudModel is instantiated to run training"
            )

        # Switch on training mode
        self.model.train()
        torch.set_grad_enabled(True)

        # Load images and labels
        x = batch["chip"]
        y = batch["label"].long()
        if self.gpu:
            x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)

        # Forward pass
        preds = self.forward(x)

        # Functional form avoids constructing a loss module on every step;
        # per-element losses are averaged exactly as before.
        loss = torch.nn.functional.cross_entropy(preds, y, reduction="none").mean()

        # Log batch loss
        self.log(
            "loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss

    def validation_step(self, batch: dict, batch_idx: int):
        """
        Validation step.

        Args:
            batch (dict): dictionary of items from CloudDataset of the form
                {'chip_id': list[str], 'chip': list[torch.Tensor], 'label': list[torch.Tensor]}
            batch_idx (int): batch number

        Returns:
            The value returned by ``intersection_over_union`` for the batch
            (also logged as "iou").

        Raises:
            ValueError: if the validation dataset was built without data.
        """
        if self.val_dataset.data is None:
            raise ValueError(
                "x_val and y_val must be specified when CloudModel is instantiated to run validation"
            )

        # Switch on validation mode
        self.model.eval()
        torch.set_grad_enabled(False)

        # Load images and labels
        x = batch["chip"]
        y = batch["label"].long()
        if self.gpu:
            x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)

        # Forward pass & softmax; keep the probability of class index 1
        # (presumably the "cloud" class -- confirm against the label encoding).
        preds = self.forward(x)
        preds = torch.softmax(preds, dim=1)[:, 1]
        preds = (preds > 0.5) * 1  # convert to int

        # Log batch IOU
        batch_iou = intersection_over_union(preds, y)
        self.log(
            "iou", batch_iou, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        return batch_iou

    def train_dataloader(self):
        """Return a shuffling DataLoader over the training dataset."""
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=True,
        )

    def val_dataloader(self):
        """Return a non-shuffling DataLoader over the validation dataset.

        Unlike the training loader, this one always uses ``num_workers=0``.
        """
        return torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=0,
            shuffle=False,
            pin_memory=True,
        )

    def configure_optimizers(self):
        """Return Adam plus a cosine-annealing LR schedule (``T_max=10``)."""
        opt = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
        sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)
        return [opt], [sch]

    ## Convenience Methods ##

    def _prepare_model(self):
        """Build the two-class U-Net using the configured backbone and weights."""
        unet_model = smp.Unet(
            encoder_name=self.backbone,
            encoder_weights=self.weights,
            # Derived from the band list instead of a hard-coded 4; assumes
            # CloudDataset emits one input channel per band name -- TODO confirm.
            in_channels=len(self.bands),
            classes=2,
        )
        if self.gpu:
            unet_model.cuda()
        return unet_model