159 lines
5.5 KiB
Python
159 lines
5.5 KiB
Python
import csv
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Any, Callable, List, Optional, Tuple, Union
|
|
|
|
from PIL import Image
|
|
|
|
from .utils import download_and_extract_archive
|
|
from .vision import VisionDataset
|
|
|
|
|
|
class Kitti(VisionDataset):
    """`KITTI <http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark>`_ Dataset.

    It corresponds to the "left color images of object" dataset, for object detection.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
            Expects the following folder structure if download=False:

            .. code::

                <root>
                    └── Kitti
                        └─ raw
                            ├── training
                            |   ├── image_2
                            |   └── label_2
                            └── testing
                                └── image_2
        train (bool, optional): Use ``train`` split if true, else ``test`` split.
            Defaults to ``True``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample
            and its target as entry and returns a transformed version.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: If the folders required for the selected split are not found
            (after the optional download).
    """

    # Base URL that each entry of ``resources`` is appended to when downloading.
    data_url = "https://s3.eu-central-1.amazonaws.com/avg-kitti/"
    # Archive file names fetched by ``download``.
    resources = [
        "data_object_image_2.zip",
        "data_object_label_2.zip",
    ]
    # Sub-folder names (inside ``training`` / ``testing``) holding images and labels.
    image_dir_name = "image_2"
    labels_dir_name = "label_2"

    def __init__(
        self,
        root: Union[str, Path],
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
        download: bool = False,
    ):
        super().__init__(
            root,
            transform=transform,
            target_transform=target_transform,
            transforms=transforms,
        )
        self.images: List[str] = []
        self.targets: List[str] = []
        self.train = train
        self._location = "training" if self.train else "testing"

        if download:
            self.download()
        if not self._check_exists():
            raise RuntimeError("Dataset not found. You may use download=True to download it.")

        split_dir = os.path.join(self._raw_folder, self._location)
        image_dir = os.path.join(split_dir, self.image_dir_name)
        labels_dir = os.path.join(split_dir, self.labels_dir_name)

        # os.listdir() yields entries in arbitrary, filesystem-dependent order;
        # sorting keeps the index -> sample mapping stable across runs and machines.
        for img_file in sorted(os.listdir(image_dir)):
            self.images.append(os.path.join(image_dir, img_file))
            if self.train:
                # Label files share the image's base name. splitext() strips only
                # the extension, so a base name containing dots is not truncated.
                stem = os.path.splitext(img_file)[0]
                self.targets.append(os.path.join(labels_dir, f"{stem}.txt"))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Get item at a given index.

        Args:
            index (int): Index
        Returns:
            tuple: (image, target), where
            target is a list of dictionaries with the following keys:

            - type: str
            - truncated: float
            - occluded: int
            - alpha: float
            - bbox: float[4]
            - dimensions: float[3]
            - location: float[3]
            - rotation_y: float

            For the test split (``train=False``) no labels are read and
            target is ``None``.
        """
        image = Image.open(self.images[index])
        target = self._parse_target(index) if self.train else None
        if self.transforms:
            image, target = self.transforms(image, target)
        return image, target

    def _parse_target(self, index: int) -> List:
        """Parse the label file of the sample at ``index``.

        Each non-blank line of the space-delimited label file becomes one
        dictionary; fields are read positionally (columns 0 through 14).

        Args:
            index (int): Index into ``self.targets``.

        Returns:
            list: One dictionary per annotated object, with the keys ``type``,
            ``truncated``, ``occluded``, ``alpha``, ``bbox``, ``dimensions``,
            ``location`` and ``rotation_y``.
        """
        target = []
        # newline="" lets the csv module do its own line-ending handling.
        with open(self.targets[index], newline="") as inp:
            for row in csv.reader(inp, delimiter=" "):
                # csv.reader yields [] for an empty line and only empty strings
                # for a whitespace-only one; neither describes an object.
                if not any(row):
                    continue
                target.append(
                    {
                        "type": row[0],
                        "truncated": float(row[1]),
                        "occluded": int(row[2]),
                        "alpha": float(row[3]),
                        "bbox": [float(x) for x in row[4:8]],
                        "dimensions": [float(x) for x in row[8:11]],
                        "location": [float(x) for x in row[11:14]],
                        "rotation_y": float(row[14]),
                    }
                )
        return target

    def __len__(self) -> int:
        """Return the number of image files found for the selected split."""
        return len(self.images)

    @property
    def _raw_folder(self) -> str:
        """Path ``<root>/<class name>/raw`` under which the split folders live."""
        return os.path.join(self.root, self.__class__.__name__, "raw")

    def _check_exists(self) -> bool:
        """Check if the data directory exists.

        Only the folders needed by the selected split are checked: the image
        folder always, the label folder only when ``train`` is true.
        """
        folders = [self.image_dir_name]
        if self.train:
            folders.append(self.labels_dir_name)
        return all(os.path.isdir(os.path.join(self._raw_folder, self._location, fname)) for fname in folders)

    def download(self) -> None:
        """Download the KITTI data if it doesn't exist already."""

        if self._check_exists():
            return

        os.makedirs(self._raw_folder, exist_ok=True)

        # Fetch every archive listed in ``resources`` into the raw folder.
        for fname in self.resources:
            download_and_extract_archive(
                url=f"{self.data_url}{fname}",
                download_root=self._raw_folder,
                filename=fname,
            )
|