wko-on-cloud-n/benchmark/cloud_dataset.py

68 lines
2.3 KiB
Python
Raw Normal View History

2022-02-15 16:42:28 +01:00
from typing import Callable, List, Optional

import numpy as np
import pandas as pd
import rasterio
import torch
class CloudDataset(torch.utils.data.Dataset):
    """Serve chips as dictionaries of chip id, band array and label mask.

    For every requested item the band TIFs listed in ``x_paths`` are read,
    stacked into one array, optionally augmented, and returned together with
    the chip id and (when ``y_paths`` was given) the label mask.
    """

    def __init__(
        self,
        x_paths: pd.DataFrame,
        bands: List[str],
        y_paths: Optional[pd.DataFrame] = None,
        transforms: Optional[Callable[..., dict]] = None,
    ):
        """
        Instantiate the CloudDataset class.

        Rows are looked up by index label (``.loc[idx]``), so when items are
        requested by position ``0..len(self) - 1`` both dataframes need a
        default integer index (e.g. ``reset_index(drop=True)``).

        Args:
            x_paths (pd.DataFrame): a dataframe with a row for each chip.
                There must be a column for chip_id, and a column named
                ``<band>_path`` with the path to the TIF for each of bands
            bands (list[str]): list of the bands included in the data
            y_paths (pd.DataFrame, optional): a dataframe with a row for each
                chip and a ``label_path`` column holding the path to the
                label TIF with ground truth cloud cover
            transforms (callable, optional): augmentation pipeline applied to
                the feature data. It is called with keyword arrays
                (``image=...`` and, when a label exists, ``mask=...``) and
                must return a dict with the same keys. This matches the
                calling convention of an albumentations ``Compose``.
        """
        self.data = x_paths
        self.label = y_paths
        self.transforms = transforms
        self.bands = bands

    def __len__(self) -> int:
        """Return the number of chips, i.e. the number of rows in ``x_paths``."""
        return len(self.data)

    def __getitem__(self, idx: int) -> dict:
        """Load the chip stored under index label ``idx``.

        Returns:
            dict: ``{"chip_id": ..., "chip": ...}`` plus ``"label"`` when
            label paths were supplied. ``"chip"`` is the stacked band array
            moved to channels-first order; ``"label"`` is the float32 mask.
        """
        row = self.data.loc[idx]

        # Stack the single-band rasters along a trailing channel axis.
        x_arr = np.stack(
            [self._read_band(row[f"{band}_path"]) for band in self.bands],
            axis=-1,
        )

        # The label is read *before* augmenting so that image and mask can be
        # transformed together in a single call.
        y_arr = None
        if self.label is not None:
            y_arr = self._read_band(self.label.loc[idx].label_path)

        x_arr, y_arr = self._apply_transforms(x_arr, y_arr)

        # Channels-last -> channels-first.
        x_arr = np.transpose(x_arr, [2, 0, 1])

        item = {"chip_id": row.chip_id, "chip": x_arr}
        if y_arr is not None:
            item["label"] = y_arr
        return item

    @staticmethod
    def _read_band(path):
        """Read band 1 of the raster at ``path`` as a float32 array."""
        with rasterio.open(path) as src:
            return src.read(1).astype("float32")

    def _apply_transforms(self, x_arr, y_arr=None):
        """Run the augmentation pipeline once over the image and its mask.

        Image and mask are passed in the same call: two separate calls would
        let random augmentations draw different parameters for each, leaving
        the label misaligned with the imagery, and would apply pixel-value
        transforms to the label as if it were an image.

        Returns:
            tuple: ``(x_arr, y_arr)`` after augmentation; inputs are returned
            unchanged when no transforms are configured.
        """
        if not self.transforms:
            return x_arr, y_arr
        if y_arr is None:
            return self.transforms(image=x_arr)["image"], None
        augmented = self.transforms(image=x_arr, mask=y_arr)
        return augmented["image"], augmented["mask"]