197 lines
6.6 KiB
Python
197 lines
6.6 KiB
Python
from typing import Optional, List
|
|
|
|
import pandas as pd
|
|
import pytorch_lightning as pl
|
|
import segmentation_models_pytorch as smp
|
|
import torch
|
|
|
|
try:
|
|
from cloud_dataset import CloudDataset
|
|
from losses import intersection_over_union
|
|
except ImportError:
|
|
from benchmark_src.cloud_dataset import CloudDataset
|
|
from benchmark_src.losses import intersection_over_union
|
|
|
|
|
|
class CloudModel(pl.LightningModule):
    """U-Net segmentation model with its datasets, wrapped as a LightningModule.

    The module owns the training/validation ``CloudDataset`` objects, builds a
    ``segmentation_models_pytorch`` U-Net with two output classes, trains it
    with cross-entropy loss and reports intersection-over-union on validation.
    """

    def __init__(
        self,
        bands: List[str],
        x_train: Optional[pd.DataFrame] = None,
        y_train: Optional[pd.DataFrame] = None,
        x_val: Optional[pd.DataFrame] = None,
        y_val: Optional[pd.DataFrame] = None,
        hparams: Optional[dict] = None,
    ):
        """
        Instantiate the CloudModel class based on the pl.LightningModule
        (https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html).

        Args:
            bands (list[str]): Names of the bands provided for each chip
            x_train (pd.DataFrame, optional): a dataframe of the training features with a row for
                each chip. There must be a column for chip_id, and a column with the path to the
                TIF for each of bands. Required for model training
            y_train (pd.DataFrame, optional): a dataframe of the training labels with a row for
                each chip and columns for chip_id and the path to the label TIF with ground truth
                cloud cover. Required for model training
            x_val (pd.DataFrame, optional): a dataframe of the validation features with a row for
                each chip. There must be a column for chip_id, and a column with the path to the
                TIF for each of bands. Required for model training
            y_val (pd.DataFrame, optional): a dataframe of the validation labels with a row for
                each chip and columns for chip_id and the path to the label TIF with ground truth
                cloud cover. Required for model training
            hparams (dict, optional): Dictionary of additional modeling parameters. Recognised
                keys (with defaults): "backbone" ("resnet34"), "weights" ("imagenet"),
                "lr" (1e-3), "patience" (4), "num_workers" (2), "batch_size" (32),
                "gpu" (False). Defaults to an empty dict when omitted.
        """
        super().__init__()

        # A fresh dict per call instead of a shared mutable default. The same
        # name is rebound on purpose so that save_hyperparameters(), which
        # records the constructor arguments, sees a dict rather than None.
        if hparams is None:
            hparams = {}

        self.hparams.update(hparams)
        self.save_hyperparameters()

        # required
        self.bands = bands

        # optional modeling params
        self.backbone = self.hparams.get("backbone", "resnet34")
        self.weights = self.hparams.get("weights", "imagenet")
        self.learning_rate = self.hparams.get("lr", 1e-3)
        self.patience = self.hparams.get("patience", 4)
        self.num_workers = self.hparams.get("num_workers", 2)
        self.batch_size = self.hparams.get("batch_size", 32)
        self.gpu = self.hparams.get("gpu", False)
        self.transform = None

        # Instantiate datasets and model. The datasets are always constructed,
        # even when the dataframes are None; training_step / validation_step
        # check `dataset.data` and raise if the data was not supplied.
        self.train_dataset = CloudDataset(
            x_paths=x_train,
            bands=self.bands,
            y_paths=y_train,
            transforms=self.transform,
        )
        self.val_dataset = CloudDataset(
            x_paths=x_val,
            bands=self.bands,
            y_paths=y_val,
            transforms=None,
        )
        self.model = self._prepare_model()

    ## Required LightningModule methods ##

    def forward(self, image: torch.Tensor):
        """Run the wrapped U-Net on ``image`` and return its raw output."""
        # Forward pass
        return self.model(image)

    def training_step(self, batch: dict, batch_idx: int):
        """
        Training step.

        Args:
            batch (dict): dictionary of items from CloudDataset of the form
                {'chip_id': list[str], 'chip': list[torch.Tensor], 'label': list[torch.Tensor]}
            batch_idx (int): batch number

        Returns:
            The mean cross-entropy loss for the batch (also logged as "loss").

        Raises:
            ValueError: if the training dataset holds no data.
        """
        if self.train_dataset.data is None:
            raise ValueError(
                "x_train and y_train must be specified when CloudModel is instantiated to run training"
            )

        # Switch on training mode
        self.model.train()
        torch.set_grad_enabled(True)

        # Load images and labels
        x = batch["chip"]
        y = batch["label"].long()
        if self.gpu:
            x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)

        # Forward pass
        preds = self.forward(x)

        # Log batch loss: per-element cross-entropy, averaged over all elements
        loss = torch.nn.CrossEntropyLoss(reduction="none")(preds, y).mean()
        self.log(
            "loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss

    def validation_step(self, batch: dict, batch_idx: int):
        """
        Validation step.

        Args:
            batch (dict): dictionary of items from CloudDataset of the form
                {'chip_id': list[str], 'chip': list[torch.Tensor], 'label': list[torch.Tensor]}
            batch_idx (int): batch number

        Returns:
            The value of ``intersection_over_union`` for the batch (also logged as "iou").

        Raises:
            ValueError: if the validation dataset holds no data.
        """
        if self.val_dataset.data is None:
            raise ValueError(
                "x_val and y_val must be specified when CloudModel is instantiated to run validation"
            )

        # Switch on validation mode
        self.model.eval()
        torch.set_grad_enabled(False)

        # Load images and labels
        x = batch["chip"]
        y = batch["label"].long()
        if self.gpu:
            x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)

        # Forward pass & softmax: keep the probability of class index 1
        # (presumably the "cloud" class -- confirm against the label encoding)
        # and threshold it at 0.5 to obtain a binary mask.
        preds = self.forward(x)
        preds = torch.softmax(preds, dim=1)[:, 1]
        preds = (preds > 0.5) * 1  # convert to int

        # Log batch IOU
        batch_iou = intersection_over_union(preds, y)
        self.log(
            "iou", batch_iou, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        return batch_iou

    def train_dataloader(self):
        """Return the shuffled DataLoader over the training dataset."""
        # DataLoader class for training
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=True,
        )

    def val_dataloader(self):
        """Return the unshuffled, single-process DataLoader over the validation dataset."""
        # DataLoader class for validation
        return torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=0,
            shuffle=False,
            pin_memory=True,
        )

    def configure_optimizers(self):
        """Return an Adam optimizer and a cosine-annealing LR scheduler (T_max=10)."""
        opt = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
        sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)
        return [opt], [sch]

    ## Convenience Methods ##

    def _prepare_model(self):
        """Build the U-Net from the configured backbone and encoder weights.

        Returns:
            The ``smp.Unet`` instance, moved to CUDA when ``self.gpu`` is truthy.
        """
        # Instantiate U-Net model
        unet_model = smp.Unet(
            encoder_name=self.backbone,
            encoder_weights=self.weights,
            # One input channel per requested band instead of a hard-coded 4.
            # Assumes CloudDataset yields one channel per entry in `bands`
            # -- TODO confirm against CloudDataset.
            in_channels=len(self.bands),
            classes=2,
        )
        if self.gpu:
            unet_model.cuda()

        return unet_model