wko-on-cloud-n/benchmark/main.py

135 lines
5.2 KiB
Python
Raw Normal View History

2022-02-15 16:42:28 +01:00
import os
from pathlib import Path
from typing import List
from loguru import logger
import pandas as pd
from PIL import Image
import torch
import typer
try:
from cloud_dataset import CloudDataset
from cloud_model import CloudModel
except ImportError:
from benchmark.cloud_dataset import CloudDataset
from benchmark.cloud_model import CloudModel
# Root of the code execution runtime; the prediction and data directories hang off it.
ROOT_DIRECTORY = Path("/codeexecution")
# Default destination for the predicted masks (used as main()'s default).
PREDICTIONS_DIRECTORY = ROOT_DIRECTORY / "predictions"
# NOTE(review): unlike the other directories this one is relative to the current
# working directory rather than ROOT_DIRECTORY — confirm that is intended.
ASSETS_DIRECTORY = Path("./submission/assets")
DATA_DIRECTORY = ROOT_DIRECTORY / "data"
# NOTE(review): not referenced anywhere else in the visible code; main() rebuilds the
# same path from DATA_DIRECTORY for its default.
INPUT_IMAGES_DIRECTORY = DATA_DIRECTORY / "test_features"
# Set the pytorch cache directory and include cached models in your submission.zip
# NOTE(review): this resolves to "submission/assets/assets/torch" ("assets" appears
# twice) — confirm the cached models are really stored there.
os.environ["TORCH_HOME"] = str(ASSETS_DIRECTORY / "assets/torch")
def get_metadata(features_dir: os.PathLike, bands: List[str]):
    """
    Given a folder of feature data, return a dataframe with one row per chip.

    Every entry of ``features_dir`` whose name does not start with "." is treated as a
    chip id. Entries are not checked to be directories and the TIF paths are not checked
    to exist.

    Args:
        features_dir (os.PathLike): path to the directory of feature data, which should have
            a folder for each chip
        bands (list[str]): list of bands provided for each chip

    Returns:
        pd.DataFrame: rows sorted by chip id, with a ``chip_id`` column followed by one
            ``{band}_path`` column per band holding the path to that band's TIF image.
            An empty directory yields an empty dataframe with the same columns.
    """
    # Accept plain strings / any os.PathLike, as the annotation promises.
    features_dir = Path(features_dir)

    # iterdir() yields entries in arbitrary order; sorting makes the row order (and
    # therefore any later `head()` subset) reproducible.
    chip_ids = sorted(
        pth.name for pth in features_dir.iterdir() if not pth.name.startswith(".")
    )

    # Build every row up front rather than inserting one dataframe column per chip and
    # transposing, which fragments the frame and is slow for thousands of chips.
    rows = [
        [chip_id, *(features_dir / chip_id / f"{band}.tif" for band in bands)]
        for chip_id in chip_ids
    ]
    columns = ["chip_id", *(f"{band}_path" for band in bands)]
    return pd.DataFrame(rows, columns=columns)
def make_predictions(
    model: CloudModel,
    x_paths: pd.DataFrame,
    bands: List[str],
    predictions_dir: os.PathLike,
):
    """Predicts cloud cover and saves results to the predictions directory.

    Each batch is run through the model, the class-1 softmax probability is thresholded
    at 0.5, and the resulting 0/1 ``uint8`` mask is written to
    ``predictions_dir / f"{chip_id}.tif"``.

    The model is switched to evaluation mode for the duration of the call and its
    previous training/eval mode is restored afterwards.

    Args:
        model (CloudModel): an instantiated CloudModel based on pl.LightningModule
        x_paths (pd.DataFrame): a dataframe with a row for each chip. There must be a column for chip_id,
            and a column with the path to the TIF for each of bands provided
        bands (list[str]): list of bands provided for each chip
        predictions_dir (os.PathLike): Destination directory to save the predicted TIF masks.
            Created (including parents) if it does not exist.
    """
    # Accept any os.PathLike and make sure the destination exists before the first save.
    predictions_dir = Path(predictions_dir)
    predictions_dir.mkdir(exist_ok=True, parents=True)

    test_dataset = CloudDataset(x_paths=x_paths, bands=bands)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=model.batch_size,
        num_workers=model.num_workers,
        shuffle=False,
        pin_memory=True,
    )
    num_batches = len(test_dataloader)

    # Inference must run in eval mode (dropout off, batch-norm using running statistics);
    # remember the previous mode so the caller's model is left as it was found.
    was_training = model.training
    model.eval()
    try:
        # No gradients are needed for prediction; skipping autograd saves memory and time.
        with torch.no_grad():
            for batch_index, batch in enumerate(test_dataloader):
                logger.debug(f"Predicting batch {batch_index} of {num_batches}")
                x = batch["chip"]
                preds = model.forward(x)
                # Probability of class index 1 along dim=1 (presumably the "cloud"
                # class — confirm against CloudModel).
                preds = torch.softmax(preds, dim=1)[:, 1]
                # .cpu() keeps the numpy conversion valid if the tensor lives on a GPU.
                preds = (preds > 0.5).detach().cpu().numpy().astype("uint8")
                for chip_id, pred in zip(batch["chip_id"], preds):
                    chip_pred_path = predictions_dir / f"{chip_id}.tif"
                    chip_pred_im = Image.fromarray(pred)
                    chip_pred_im.save(chip_pred_path)
    finally:
        model.train(was_training)
def main(
    model_weights_path: Path = ASSETS_DIRECTORY / "cloud_model.pt",
    test_features_dir: Path = DATA_DIRECTORY / "test_features",
    predictions_dir: Path = PREDICTIONS_DIRECTORY,
    bands: List[str] = ["B02", "B03", "B04", "B08"],
    fast_dev_run: bool = False,
):
    """
    Generate predictions for the chips in test_features_dir using the model saved at
    model_weights_path.

    Predictions are saved in predictions_dir. The default paths are based on
    the structure of the code execution runtime.

    If a ``data/train_features`` directory exists relative to the current working
    directory, predictions are additionally generated for it and written to
    ``data/predictions``; otherwise that step is skipped with a warning.

    Args:
        model_weights_path (os.PathLike): Path to the weights of a trained CloudModel.
        test_features_dir (os.PathLike, optional): Path to the features for the test data.
        predictions_dir (os.PathLike, optional): Destination directory to save the predicted TIF masks.
            Created if it does not exist.
        bands (List[str], optional): List of bands provided for each chip
        fast_dev_run (bool, optional): If True, only the first rows of each metadata
            dataframe (``DataFrame.head()``) are predicted, for a quick smoke test.

    Raises:
        ValueError: if test_features_dir does not exist.
    """
    if not test_features_dir.exists():
        raise ValueError(
            f"The directory for test feature images must exist and {test_features_dir} does not exist"
        )
    predictions_dir.mkdir(exist_ok=True, parents=True)

    logger.info("Loading model")
    model = CloudModel(bands=bands, hparams={"weights": None})
    # Always deserialize onto the CPU: load_state_dict copies the values into the model's
    # existing parameters, so this works for checkpoints saved on either CPU or GPU and
    # avoids re-reading the file whenever load_state_dict itself raises RuntimeError.
    state_dict = torch.load(model_weights_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)

    logger.info("Loading test metadata")
    test_metadata = get_metadata(test_features_dir, bands=bands)
    if fast_dev_run:
        test_metadata = test_metadata.head()
    logger.info(f"Found {len(test_metadata)} chips")

    logger.info("Generating predictions in batches")
    make_predictions(model, test_metadata, bands, predictions_dir)

    # Optional second pass over the training chips. It runs after the test predictions
    # and only when the directory is present, so a missing local training set cannot
    # abort the required test-set output.
    train_features_dir = Path("data/train_features")
    train_predictions_dir = Path("data/predictions")
    if train_features_dir.exists():
        train_metadata = get_metadata(train_features_dir, bands=bands)
        if fast_dev_run:
            train_metadata = train_metadata.head()
        train_predictions_dir.mkdir(exist_ok=True, parents=True)
        make_predictions(model, train_metadata, bands, train_predictions_dir)
    else:
        logger.warning(
            f"Skipping train predictions: {train_features_dir} does not exist"
        )

    logger.info(f"""Saved {len(list(predictions_dir.glob("*.tif")))} predictions""")
if __name__ == "__main__":
    # typer builds the command-line interface from main()'s signature and defaults.
    typer.run(main)