68 lines
2.3 KiB
Python
68 lines
2.3 KiB
Python
|
from typing import Callable, List, Optional

import numpy as np
import pandas as pd
import rasterio
import torch
|
||
|
|
||
|
|
||
|
class CloudDataset(torch.utils.data.Dataset):
    """Map-style dataset that reads chip imagery from disk on demand.

    Each item is a dictionary holding the chip id, the stacked band image
    as a channels-first float32 array, and (when label paths were supplied)
    the label mask.
    """

    def __init__(
        self,
        x_paths: pd.DataFrame,
        bands: List[str],
        y_paths: Optional[pd.DataFrame] = None,
        transforms: Optional[Callable[..., dict]] = None,
    ):
        """
        Instantiate the CloudDataset class.

        Args:
            x_paths (pd.DataFrame): a dataframe with a row for each chip. There must be a
                chip_id column and, for every entry of bands, a "<band>_path" column with
                the path to that band's TIF
            bands (list[str]): list of the bands included in the data
            y_paths (pd.DataFrame, optional): a dataframe with a row for each chip, in the
                same row order as x_paths, and a label_path column with the path to the
                label TIF with ground truth cloud cover
            transforms (callable, optional): augmentation pipeline. It is called as
                transforms(image=chip) for unlabelled data and as
                transforms(image=chip, mask=label) for labelled data, and must return a
                mapping with an "image" key (plus a "mask" key when a mask was passed).
                This is the albumentations calling convention. A falsy value disables
                augmentation.
        """
        self.data = x_paths
        self.label = y_paths
        self.transforms = transforms
        self.bands = bands

    def __len__(self) -> int:
        """Return the number of chips (rows of x_paths)."""
        return len(self.data)

    def __getitem__(self, idx: int) -> dict:
        """Load one chip, and its label if available, by row position.

        Args:
            idx (int): zero-based row position, as produced by a PyTorch sampler.

        Returns:
            dict: "chip_id" (taken from x_paths), "chip" (float32 image with the band
            axis first) and, when y_paths was supplied, "label" (float32 mask).
        """
        # Positional lookup: samplers hand out 0..len-1, which only coincides
        # with index labels when the dataframe has a default RangeIndex.
        row = self.data.iloc[idx]

        # Read the first raster band of every requested TIF and stack the
        # resulting 2-D arrays along a trailing channel axis.
        band_arrs = []
        for band in self.bands:
            with rasterio.open(row[f"{band}_path"]) as src:
                band_arrs.append(src.read(1).astype("float32"))
        x_arr = np.stack(band_arrs, axis=-1)

        # Load the label before augmenting so image and mask can be
        # transformed together.
        y_arr = None
        if self.label is not None:
            label_path = self.label.iloc[idx].label_path
            with rasterio.open(label_path) as src:
                y_arr = src.read(1).astype("float32")

        # Apply data augmentations, if provided. A single call with both
        # targets keeps random spatial transforms identical for image and
        # mask and keeps pixel-value transforms off the label.
        if self.transforms:
            if y_arr is None:
                x_arr = self.transforms(image=x_arr)["image"]
            else:
                augmented = self.transforms(image=x_arr, mask=y_arr)
                x_arr = augmented["image"]
                y_arr = augmented["mask"]

        # Move the channel axis to the front (bands, rows, cols).
        x_arr = np.transpose(x_arr, [2, 0, 1])

        item = {"chip_id": row.chip_id, "chip": x_arr}
        if y_arr is not None:
            item["label"] = y_arr

        return item
|