import os.path
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union

from PIL import Image

from .vision import VisionDataset


class CocoDetection(VisionDataset):
    """`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.

    It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    def __init__(
        self,
        root: Union[str, Path],
        annFile: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transforms, transform, target_transform)
        # Imported lazily so that the module can be imported without
        # pycocotools; the dependency is only needed once a dataset is built.
        from pycocotools.coco import COCO

        self.coco = COCO(annFile)
        # Sorted image ids give a deterministic index -> image mapping.
        self.ids = sorted(self.coco.imgs.keys())

    def _load_image(self, id: int) -> Image.Image:
        """Load the image with the given COCO image id as an RGB PIL image.

        Args:
            id (int): COCO image id; its ``file_name`` entry is resolved
                relative to ``self.root``.

        Returns:
            PIL.Image.Image: The image converted to ``"RGB"`` mode.
        """
        file_name = self.coco.loadImgs(id)[0]["file_name"]
        # ``convert`` forces the lazy decode and returns a new image, so the
        # underlying file handle can be released deterministically on exit
        # instead of waiting for garbage collection.
        with Image.open(os.path.join(self.root, file_name)) as img:
            return img.convert("RGB")

    def _load_target(self, id: int) -> List[Any]:
        """Return the annotations the COCO API reports for the given image id.

        Args:
            id (int): COCO image id.

        Returns:
            list: Result of ``coco.loadAnns`` for the annotation ids of ``id``.
        """
        return self.coco.loadAnns(self.coco.getAnnIds(id))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return the ``(image, target)`` pair at position ``index``.

        Args:
            index (int): Position into the sorted list of image ids.

        Returns:
            tuple: ``(image, target)``, passed through ``self.transforms``
            when it is not ``None``.

        Raises:
            ValueError: If ``index`` is not an ``int``.
        """
        if not isinstance(index, int):
            raise ValueError(f"Index must be of type integer, got {type(index)} instead.")

        image_id = self.ids[index]
        image = self._load_image(image_id)
        target = self._load_target(image_id)

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        """Return the number of images in the dataset."""
        return len(self.ids)


class CocoCaptions(CocoDetection):
    """`MS Coco Captions <https://cocodataset.org/#captions-2015>`_ Dataset.

    It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.

    Example:

        .. code:: python

            import torchvision.datasets as dset
            import torchvision.transforms as transforms
            cap = dset.CocoCaptions(root = 'dir where images are',
                                    annFile = 'json annotation file',
                                    transform=transforms.PILToTensor())

            print('Number of samples: ', len(cap))
            img, target = cap[3] # load 4th sample

            print("Image Size: ", img.size())
            print(target)

        Output: ::

            Number of samples: 82783
            Image Size: (3L, 427L, 640L)
            [u'A plane emitting smoke stream flying over a mountain.',
            u'A plane darts across a bright blue sky behind a mountain covered in snow',
            u'A plane leaves a contrail above the snowy mountain top.',
            u'A mountain that has a plane flying overheard in the distance.',
            u'A mountain view with a plume of smoke in the background']

    """

    def _load_target(self, id: int) -> List[str]:
        """Return the caption strings of every annotation for the given image id.

        Args:
            id (int): COCO image id.

        Returns:
            list[str]: The ``"caption"`` field of each annotation.
        """
        return [ann["caption"] for ann in super()._load_target(id)]