168 lines
5.3 KiB
Python
168 lines
5.3 KiB
Python
|
import glob
|
||
|
import os
|
||
|
from collections import defaultdict
|
||
|
from html.parser import HTMLParser
|
||
|
from pathlib import Path
|
||
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||
|
|
||
|
from PIL import Image
|
||
|
|
||
|
from .vision import VisionDataset
|
||
|
|
||
|
|
||
|
class Flickr8kParser(HTMLParser):
    """Parser for extracting captions from the Flickr8k dataset web page.

    Only content inside a ``<table>`` element is considered.  Within the
    table, the text of an ``<a>`` element is treated as the image URL and the
    text of each following ``<li>`` element as one caption for that image.
    The literal text ``"Image Not Found"`` marks an entry whose captions are
    skipped.

    Args:
        root (str or ``pathlib.Path``): Directory containing the image files,
            named ``<image id>_<anything>.jpg``.

    Attributes:
        annotations: Maps the resolved image file path to its list of captions.
    """

    def __init__(self, root: Union[str, Path]) -> None:
        super().__init__()

        self.root = root

        # Data structure to store captions: image file path -> captions
        self.annotations: Dict[str, List[str]] = {}

        # State variables
        self.in_table = False  # True while between <table> and </table>
        self.current_tag: Optional[str] = None  # most recently opened tag, reset on any end tag
        self.current_img: Optional[str] = None  # image the next <li> captions belong to

    def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]) -> None:
        """Remember the tag just opened and note when a table starts."""
        self.current_tag = tag

        if tag == "table":
            self.in_table = True

    def handle_endtag(self, tag: str) -> None:
        """Forget the current tag and note when a table ends."""
        self.current_tag = None

        if tag == "table":
            self.in_table = False

    def handle_data(self, data: str) -> None:
        """Record image links and captions found inside the table.

        Raises:
            IndexError: If the link text has fewer than two ``/``-separated
                parts, or if no image file for the link exists under ``root``.
        """
        if not self.in_table:
            return

        if data == "Image Not Found":
            # Captions that follow belong to a missing image; drop them.
            self.current_img = None
        elif self.current_tag == "a":
            img_path = self._find_image(data)
            self.current_img = img_path
            self.annotations[img_path] = []
        elif self.current_tag == "li" and self.current_img:
            self.annotations[self.current_img].append(data.strip())

    def _find_image(self, url: str) -> str:
        """Return the path of the image file under ``root`` that matches ``url``."""
        # The image id is the last path component before the trailing slash.
        img_id = url.split("/")[-2]

        # Escape the literal parts so glob metacharacters in ``root`` (for
        # example "[" or "*") are not interpreted as wildcards.
        pattern = os.path.join(
            glob.escape(os.fspath(self.root)),
            glob.escape(img_id) + "_*.jpg",
        )

        # glob returns matches in arbitrary order; sort for a deterministic pick.
        matches = sorted(glob.glob(pattern))
        if not matches:
            # IndexError is kept for backward compatibility with callers that
            # relied on the exception raised by indexing an empty match list.
            raise IndexError(f"No image file matching {pattern!r} for annotation link {url!r}")
        return matches[0]
|
||
|
|
||
|
|
||
|
class Flickr8k(VisionDataset):
    """`Flickr8k Entities <http://hockenmaier.cs.illinois.edu/8k-pictures.html>`_ Dataset.

    The annotation file is parsed once at construction time with
    :class:`Flickr8kParser`; images are opened lazily on item access.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        ann_file (string): Path to annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(
        self,
        root: Union[str, Path],
        ann_file: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.ann_file = os.path.expanduser(ann_file)

        # Parse the whole annotation page in one go; the parser builds a
        # mapping from image file path to its captions.
        caption_parser = Flickr8kParser(self.root)
        with open(self.ann_file) as ann_page:
            html_text = ann_page.read()
        caption_parser.feed(html_text)
        self.annotations = caption_parser.annotations

        # Sorted image paths give every index a stable meaning.
        self.ids = sorted(self.annotations)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Fetch one sample.

        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is a list of captions for the image.
        """
        # Each id is the image file path stored as a key of ``annotations``.
        image_path = self.ids[index]

        # Image
        image = Image.open(image_path).convert("RGB")
        transform = self.transform
        if transform is not None:
            image = transform(image)

        # Captions
        captions = self.annotations[image_path]
        target_transform = self.target_transform
        if target_transform is not None:
            captions = target_transform(captions)

        return image, captions

    def __len__(self) -> int:
        """Return the number of annotated images."""
        return len(self.ids)
|
||
|
|
||
|
|
||
|
class Flickr30k(VisionDataset):
    """`Flickr30k Entities <https://bryanplummer.com/Flickr30kEntities/>`_ Dataset.

    The annotation file is read line by line.  Each non-blank line must hold
    exactly two tab-separated fields: an image identifier whose last two
    characters are dropped to obtain the image file name, and one caption.
    Captions sharing the same image file name are grouped together.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        ann_file (string): Path to annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    Raises:
        ValueError: If a non-blank annotation line does not consist of exactly
            two tab-separated fields.
    """

    def __init__(
        self,
        root: Union[str, Path],
        ann_file: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.ann_file = os.path.expanduser(ann_file)

        # Read annotations and store in a dict: image file name -> captions
        self.annotations: Dict[str, List[str]] = defaultdict(list)
        with open(self.ann_file) as fh:
            for line_no, raw_line in enumerate(fh, start=1):
                line = raw_line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline at end of file).
                    continue

                fields = line.split("\t")
                if len(fields) != 2:
                    raise ValueError(
                        f"{self.ann_file}:{line_no}: expected two tab-separated fields "
                        f"(image id and caption), got {len(fields)}: {line!r}"
                    )
                img_id, caption = fields

                # Drop the two-character suffix of the identifier to get the file name.
                self.annotations[img_id[:-2]].append(caption)

        self.ids = sorted(self.annotations)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Fetch one sample.

        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is a list of captions for the image.
        """
        img_id = self.ids[index]

        # Image: ids are file names relative to ``root``
        filename = os.path.join(self.root, img_id)
        img = Image.open(filename).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)

        # Captions
        target = self.annotations[img_id]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        """Return the number of annotated images."""
        return len(self.ids)
|