68 lines
2.3 KiB
Python
68 lines
2.3 KiB
Python
from typing import Callable, List, Optional

import numpy as np
import pandas as pd
import rasterio
import torch
|
|
|
|
|
|
class CloudDataset(torch.utils.data.Dataset):
    """Serve satellite chips and, where available, their cloud-cover masks.

    Each item is a dictionary holding the chip id, the stacked band image as
    a float32 NumPy array in channels-first layout, and (when label paths
    were supplied) the float32 label mask.
    """

    def __init__(
        self,
        x_paths: pd.DataFrame,
        bands: List[str],
        y_paths: Optional[pd.DataFrame] = None,
        transforms: Optional[Callable] = None,
    ):
        """
        Instantiate the CloudDataset class.

        Args:
            x_paths (pd.DataFrame): a dataframe with a row for each chip.
                There must be a column for chip_id, and a ``<band>_path``
                column with the path to the TIF for each of bands
            bands (list[str]): list of the bands included in the data
            y_paths (pd.DataFrame, optional): a dataframe with a row for each
                chip, sharing its index with ``x_paths``, with a
                ``label_path`` column pointing at the label TIF with ground
                truth cloud cover
            transforms (callable, optional): augmentation pipeline. It is
                called as ``transforms(image=...)`` when there are no labels
                and ``transforms(image=..., mask=...)`` otherwise, and must
                return a mapping with an ``"image"`` (and ``"mask"``) entry.
        """
        self.data = x_paths
        self.label = y_paths
        self.transforms = transforms
        self.bands = bands

    def __len__(self):
        """Return the number of chips (rows of ``x_paths``)."""
        return len(self.data)

    @staticmethod
    def _read_band(path) -> np.ndarray:
        """Read the first band of the raster at ``path`` as a float32 array."""
        with rasterio.open(path) as src:
            return src.read(1).astype("float32")

    def __getitem__(self, idx: int):
        """Load the chip at position ``idx``.

        Args:
            idx (int): positional index of the chip, ``0 <= idx < len(self)``.

        Returns:
            dict: ``{"chip_id": ..., "chip": ...}`` plus a ``"label"`` entry
            when label paths were provided. ``"chip"`` has its last axis
            (bands) moved to the front.
        """
        # Positional lookup: samplers hand out 0..len-1, which only coincides
        # with index labels when the frame has a default RangeIndex.
        row = self.data.iloc[idx]

        # Stack the single-band rasters into one array, bands on the last axis
        x_arr = np.stack(
            [self._read_band(row[f"{band}_path"]) for band in self.bands],
            axis=-1,
        )

        # Load the label if available, matched to the chip by index label
        y_arr = None
        if self.label is not None:
            y_arr = self._read_band(self.label.loc[row.name].label_path)

        # Apply data augmentations, if provided. Image and mask go through a
        # single call so that random transforms draw the same parameters for
        # both and the mask stays aligned with the image.
        if self.transforms:
            if y_arr is None:
                x_arr = self.transforms(image=x_arr)["image"]
            else:
                augmented = self.transforms(image=x_arr, mask=y_arr)
                x_arr = augmented["image"]
                y_arr = augmented["mask"]

        # Channels-last -> channels-first
        x_arr = np.transpose(x_arr, [2, 0, 1])

        # Prepare dictionary for item
        item = {"chip_id": row.chip_id, "chip": x_arr}
        if y_arr is not None:
            item["label"] = y_arr

        return item