wko-on-cloud-n/benchmark/main.py
2022-02-15 16:42:28 +01:00

135 lines
5.2 KiB
Python

import os
from pathlib import Path
from typing import List
from loguru import logger
import pandas as pd
from PIL import Image
import torch
import typer
try:
from cloud_dataset import CloudDataset
from cloud_model import CloudModel
except ImportError:
from benchmark.cloud_dataset import CloudDataset
from benchmark.cloud_model import CloudModel
# Filesystem layout of the code-execution runtime.
ROOT_DIRECTORY = Path("/codeexecution")
DATA_DIRECTORY = ROOT_DIRECTORY.joinpath("data")
INPUT_IMAGES_DIRECTORY = DATA_DIRECTORY.joinpath("test_features")
PREDICTIONS_DIRECTORY = ROOT_DIRECTORY.joinpath("predictions")
# Relative path: resolved against the current working directory, not ROOT_DIRECTORY.
ASSETS_DIRECTORY = Path("./submission/assets")

# Set the pytorch cache directory and include cached models in your submission.zip
# NOTE(review): ASSETS_DIRECTORY already ends in "assets", so this resolves to
# "submission/assets/assets/torch" - confirm the doubled segment is intended.
os.environ.update({"TORCH_HOME": str(ASSETS_DIRECTORY.joinpath("assets/torch"))})
def get_metadata(features_dir: os.PathLike, bands: List[str]) -> pd.DataFrame:
    """
    Given a folder of feature data, return a dataframe with one row per chip.

    Every entry of features_dir whose name does not start with "." is treated as a
    chip id. The path for each band is built as features_dir / chip_id / "<band>.tif";
    the files themselves are not checked for existence.

    Args:
        features_dir (os.PathLike): path to the directory of feature data, which should have
            a folder for each chip
        bands (list[str]): list of bands provided for each chip

    Returns:
        pd.DataFrame: a "chip_id" column followed by one "<band>_path" column per band,
            with rows sorted by chip id and a default integer index.
    """
    # The annotation promises os.PathLike, but .iterdir() and "/" need a real Path.
    features_dir = Path(features_dir)
    band_columns = [f"{band}_path" for band in bands]

    # iterdir() yields entries in arbitrary order; sort so the result is reproducible.
    chip_ids = sorted(
        pth.name for pth in features_dir.iterdir() if not pth.name.startswith(".")
    )

    # Collect all rows and construct the frame once. Inserting one column per chip
    # and transposing afterwards fragments the frame when there are many chips.
    rows = [
        [chip_id, *(features_dir / chip_id / f"{band}.tif" for band in bands)]
        for chip_id in chip_ids
    ]
    return pd.DataFrame(rows, columns=["chip_id", *band_columns])
def make_predictions(
    model: CloudModel,
    x_paths: pd.DataFrame,
    bands: List[str],
    predictions_dir: os.PathLike,
):
    """Predicts cloud cover and saves results to the predictions directory.

    Each batch is passed through the model, a softmax is taken over dim 1, and the
    probability at class index 1 is thresholded at 0.5. The resulting uint8 mask
    for every chip is written to predictions_dir / "<chip_id>.tif".

    Args:
        model (CloudModel): an instantiated CloudModel based on pl.LightningModule
        x_paths (pd.DataFrame): a dataframe with a row for each chip. There must be a column for chip_id,
            and a column with the path to the TIF for each of bands provided
        bands (list[str]): list of bands provided for each chip
        predictions_dir (os.PathLike): Destination directory to save the predicted TIF masks.
            It is created if it does not exist.
    """
    # Accept str as well as Path, and make sure the destination exists before saving.
    predictions_dir = Path(predictions_dir)
    predictions_dir.mkdir(exist_ok=True, parents=True)

    test_dataset = CloudDataset(x_paths=x_paths, bands=bands)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=model.batch_size,
        num_workers=model.num_workers,
        shuffle=False,
        # Pinned host memory only helps host-to-GPU copies; skip it on CPU-only machines.
        pin_memory=torch.cuda.is_available(),
    )
    num_batches = len(test_dataloader)

    # Inference must run in eval mode: in training mode BatchNorm uses (and updates)
    # batch statistics and Dropout is active, so predictions would depend on batch
    # composition. The previous mode is restored afterwards so callers are unaffected.
    was_training = model.training
    model.eval()
    try:
        # No gradients are needed; this avoids building the autograd graph per batch.
        with torch.no_grad():
            for batch_index, batch in enumerate(test_dataloader):
                logger.debug(f"Predicting batch {batch_index} of {num_batches}")
                x = batch["chip"]
                preds = model.forward(x)
                preds = torch.softmax(preds, dim=1)[:, 1]
                # .cpu() is a no-op for CPU tensors and required before .numpy() otherwise.
                preds = (preds > 0.5).cpu().numpy().astype("uint8")
                for chip_id, pred in zip(batch["chip_id"], preds):
                    chip_pred_path = predictions_dir / f"{chip_id}.tif"
                    chip_pred_im = Image.fromarray(pred)
                    chip_pred_im.save(chip_pred_path)
    finally:
        model.train(was_training)
def main(
    model_weights_path: Path = ASSETS_DIRECTORY / "cloud_model.pt",
    test_features_dir: Path = DATA_DIRECTORY / "test_features",
    predictions_dir: Path = PREDICTIONS_DIRECTORY,
    # Kept as a list literal so the CLI default is unchanged; it is never mutated here.
    bands: List[str] = ["B02", "B03", "B04", "B08"],
    fast_dev_run: bool = False,
    train_features_dir: Path = Path("data/train_features"),
    train_predictions_dir: Path = Path("data/predictions"),
):
    """
    Generate predictions for the chips in test_features_dir using the model saved at
    model_weights_path.

    Predictions are saved in predictions_dir. The default paths to the weights, test
    features and predictions are based on the structure of the code execution runtime.
    If train_features_dir exists, predictions for those chips are also generated and
    saved in train_predictions_dir; otherwise that pass is skipped with a warning.

    Args:
        model_weights_path (os.PathLike): Path to the weights of a trained CloudModel.
        test_features_dir (os.PathLike, optional): Path to the features for the test data.
        predictions_dir (os.PathLike, optional): Destination directory to save the predicted
            TIF masks for the test chips. Created if missing.
        bands (List[str], optional): List of bands provided for each chip
        fast_dev_run (bool, optional): If True, only the first rows (DataFrame.head()) of
            each metadata table are predicted.
        train_features_dir (os.PathLike, optional): Path to the features for the training
            data, relative to the working directory by default.
        train_predictions_dir (os.PathLike, optional): Destination directory for the
            training-chip masks. Created if missing.

    Raises:
        ValueError: if test_features_dir does not exist.
    """
    if not test_features_dir.exists():
        raise ValueError(
            f"The directory for test feature images must exist and {test_features_dir} does not exist"
        )
    predictions_dir.mkdir(exist_ok=True, parents=True)

    logger.info("Loading model")
    model = CloudModel(bands=bands, hparams={"weights": None})
    # Map the checkpoint to CPU in a single load. load_state_dict copies values into
    # the model's existing parameters, so this works for GPU-saved weights too, and a
    # genuine state-dict mismatch is no longer retried by a RuntimeError fallback.
    model.load_state_dict(
        torch.load(model_weights_path, map_location=torch.device("cpu"))
    )

    logger.info("Loading test metadata")
    test_metadata = get_metadata(test_features_dir, bands=bands)
    if fast_dev_run:
        test_metadata = test_metadata.head()
    logger.info(f"Found {len(test_metadata)} chips")

    logger.info("Generating predictions in batches")
    make_predictions(model, test_metadata, bands, predictions_dir)

    # The training data is not guaranteed to be present, so the extra pass is optional
    # instead of crashing in iterdir() on a missing directory.
    if train_features_dir.exists():
        train_metadata = get_metadata(train_features_dir, bands=bands)
        if fast_dev_run:
            # Keep the quick-run flag quick for the training pass as well.
            train_metadata = train_metadata.head()
        train_predictions_dir.mkdir(exist_ok=True, parents=True)
        make_predictions(model, train_metadata, bands, train_predictions_dir)
    else:
        logger.warning(
            f"Skipping train predictions: {train_features_dir} does not exist"
        )

    logger.info(f"""Saved {len(list(predictions_dir.glob("*.tif")))} predictions""")
if __name__ == "__main__":
    # Script entry point: typer builds the command-line interface from main's signature.
    typer.run(main)