from typing import Optional, List

import pandas as pd
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
import torch

try:
    from cloud_dataset import CloudDataset
    from losses import intersection_over_union
except ImportError:
    from benchmark_src.cloud_dataset import CloudDataset
    from benchmark_src.losses import intersection_over_union


class CloudModel(pl.LightningModule):
    """
    LightningModule wrapping a two-class U-Net for cloud segmentation.

    Training minimises the mean cross-entropy between the model output and the
    label; validation thresholds the softmax score of class 1 at 0.5 and logs
    the result of `intersection_over_union` against the label.
    """

    def __init__(
        self,
        bands: List[str],
        x_train: Optional[pd.DataFrame] = None,
        y_train: Optional[pd.DataFrame] = None,
        x_val: Optional[pd.DataFrame] = None,
        y_val: Optional[pd.DataFrame] = None,
        hparams: Optional[dict] = None,
    ):
        """
        Instantiate the CloudModel class based on pl.LightningModule.

        Args:
            bands (list[str]): Names of the bands provided for each chip
            x_train (pd.DataFrame, optional): a dataframe of the training features with a row
                for each chip. There must be a column for chip_id, and a column with the path
                to the TIF for each of bands. Required for model training
            y_train (pd.DataFrame, optional): a dataframe of the training labels with a row
                for each chip and columns for chip_id and the path to the label TIF with
                ground truth cloud cover. Required for model training
            x_val (pd.DataFrame, optional): a dataframe of the validation features with a row
                for each chip. There must be a column for chip_id, and a column with the path
                to the TIF for each of bands. Required for model training
            y_val (pd.DataFrame, optional): a dataframe of the validation labels with a row
                for each chip and columns for chip_id and the path to the label TIF with
                ground truth cloud cover. Required for model training
            hparams (dict, optional): Dictionary of additional modeling parameters.
                Recognised keys (with defaults): "backbone" ("resnet34"),
                "weights" ("imagenet"), "lr" (1e-3), "patience" (4),
                "num_workers" (2), "batch_size" (32), "gpu" (False).
                Defaults to an empty dict when omitted.
        """
        super().__init__()

        # Use a fresh dict per instance instead of a shared mutable default.
        # Rebinding happens before save_hyperparameters() so that an omitted
        # argument is still recorded as an empty dict.
        if hparams is None:
            hparams = {}
        self.hparams.update(hparams)
        self.save_hyperparameters()

        # required
        self.bands = bands

        # optional modeling params
        self.backbone = self.hparams.get("backbone", "resnet34")
        self.weights = self.hparams.get("weights", "imagenet")
        self.learning_rate = self.hparams.get("lr", 1e-3)
        self.patience = self.hparams.get("patience", 4)
        self.num_workers = self.hparams.get("num_workers", 2)
        self.batch_size = self.hparams.get("batch_size", 32)
        self.gpu = self.hparams.get("gpu", False)
        self.transform = None

        # Remember whether features AND labels were supplied for each split so
        # the step methods can fail with a clear message instead of an obscure
        # error from inside the dataset.
        self._has_train_data = x_train is not None and y_train is not None
        self._has_val_data = x_val is not None and y_val is not None

        # Instantiate datasets, model, and trainer params if provided
        self.train_dataset = CloudDataset(
            x_paths=x_train,
            bands=self.bands,
            y_paths=y_train,
            transforms=self.transform,
        )
        self.val_dataset = CloudDataset(
            x_paths=x_val,
            bands=self.bands,
            y_paths=y_val,
            transforms=None,
        )
        self.model = self._prepare_model()

    ## Required LightningModule methods ##

    def forward(self, image: torch.Tensor):
        """Run the wrapped model on `image` and return its raw output."""
        # Forward pass
        return self.model(image)

    def training_step(self, batch: dict, batch_idx: int):
        """
        Training step.

        Args:
            batch (dict): dictionary of items from CloudDataset of the form
                {'chip_id': list[str], 'chip': list[torch.Tensor], 'label': list[torch.Tensor]}
            batch_idx (int): batch number

        Returns:
            The mean cross-entropy loss for the batch (also logged as "loss").

        Raises:
            ValueError: if x_train or y_train was not given at instantiation.
        """
        if not self._has_train_data:
            raise ValueError(
                "x_train and y_train must be specified when CloudModel is instantiated to run training"
            )

        # Switch on training mode
        self.model.train()
        torch.set_grad_enabled(True)

        # Load images and labels
        x = batch["chip"]
        y = batch["label"].long()
        if self.gpu:
            x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)

        # Forward pass
        preds = self.forward(x)

        # Log batch loss
        loss = torch.nn.CrossEntropyLoss(reduction="none")(preds, y).mean()
        self.log(
            "loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss

    def validation_step(self, batch: dict, batch_idx: int):
        """
        Validation step.

        Args:
            batch (dict): dictionary of items from CloudDataset of the form
                {'chip_id': list[str], 'chip': list[torch.Tensor], 'label': list[torch.Tensor]}
            batch_idx (int): batch number

        Returns:
            The value of `intersection_over_union` for the batch (also logged
            as "iou").

        Raises:
            ValueError: if x_val or y_val was not given at instantiation.
        """
        if not self._has_val_data:
            raise ValueError(
                "x_val and y_val must be specified when CloudModel is instantiated to run validation"
            )

        # Switch on validation mode
        self.model.eval()
        torch.set_grad_enabled(False)

        # Load images and labels
        x = batch["chip"]
        y = batch["label"].long()
        if self.gpu:
            x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)

        # Forward pass & softmax
        preds = self.forward(x)
        # Keep only the score of class index 1, then binarise at 0.5
        preds = torch.softmax(preds, dim=1)[:, 1]
        preds = (preds > 0.5) * 1  # convert to int

        # Log batch IOU
        batch_iou = intersection_over_union(preds, y)
        self.log(
            "iou", batch_iou, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        return batch_iou

    def train_dataloader(self):
        """Return a shuffled DataLoader over the training dataset."""
        # DataLoader class for training
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=True,
        )

    def val_dataloader(self):
        """Return an unshuffled DataLoader over the validation dataset."""
        # DataLoader class for validation
        return torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=0,
            shuffle=False,
            pin_memory=True,
        )

    def configure_optimizers(self):
        """Return an Adam optimizer and a cosine-annealing LR scheduler."""
        opt = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
        sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)
        return [opt], [sch]

    ## Convenience Methods ##

    def _prepare_model(self):
        """Build the U-Net from the configured backbone and encoder weights."""
        # Instantiate U-Net model
        # NOTE(review): in_channels is fixed at 4 regardless of len(self.bands);
        # confirm that callers always pass exactly four bands.
        unet_model = smp.Unet(
            encoder_name=self.backbone,
            encoder_weights=self.weights,
            in_channels=4,
            classes=2,
        )
        if self.gpu:
            unet_model.cuda()

        return unet_model