# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from torchvision
# ------------------------------------------------------------------------

"""
Copy-Paste from torchvision, but add utility of caching images on memory
"""
import os
import os.path
from io import BytesIO

import tqdm
from PIL import Image
from torchvision.datasets.vision import VisionDataset


class CocoDetection(VisionDataset):
    """MS COCO Detection dataset with optional in-memory caching of image files.

    Args:
        root (string): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a
            PIL image and returns a transformed version.
            E.g, ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        transforms (callable, optional): A function/transform that takes input
            sample and its target as entry and returns a transformed version.
        cache_mode (bool, optional): If true, the raw (still encoded) bytes of
            image files are kept in memory. A subset is preloaded at
            construction time (see ``local_rank`` / ``local_size``); any other
            image is read from disk and cached the first time it is requested.
        local_rank (int, optional): Only images whose position ``index`` in
            the sorted id list satisfies
            ``index % local_size == local_rank`` are preloaded.
        local_size (int, optional): Modulus used together with ``local_rank``
            to select the images to preload.
    """

    def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None,
                 cache_mode=False, local_rank=0, local_size=1):
        super().__init__(root, transforms, transform, target_transform)
        # Imported lazily so that pycocotools is only required when the
        # dataset is actually instantiated.
        from pycocotools.coco import COCO
        self.coco = COCO(annFile)
        self.ids = sorted(self.coco.imgs.keys())
        self.cache_mode = cache_mode
        self.local_rank = local_rank
        self.local_size = local_size
        # Always defined, so the attribute exists even when caching is off.
        self.cache = {}
        if cache_mode:
            self.cache_images()

    def cache_images(self):
        """Preload the raw bytes of this rank's share of the images.

        Resets ``self.cache`` and then reads every image whose position in
        ``self.ids`` satisfies ``index % local_size == local_rank``. Entries
        are keyed by the image's ``file_name`` from the annotation file.
        A tqdm progress bar is shown over all ids, including skipped ones.
        """
        self.cache = {}
        for index, img_id in enumerate(tqdm.tqdm(self.ids)):
            if index % self.local_size != self.local_rank:
                continue
            path = self.coco.loadImgs(img_id)[0]['file_name']
            with open(os.path.join(self.root, path), 'rb') as f:
                self.cache[path] = f.read()

    def get_image(self, path):
        """Load the image at ``path`` (relative to ``self.root``) as RGB.

        Args:
            path (string): Image file name relative to ``self.root``.

        Returns:
            PIL.Image.Image: The decoded image converted to ``'RGB'``.

        In cache mode the encoded bytes are taken from ``self.cache``; a
        missing entry is read from disk and stored first. The image is
        decoded again on every call; only the encoded bytes are cached.
        """
        if self.cache_mode:
            if path not in self.cache:
                with open(os.path.join(self.root, path), 'rb') as f:
                    self.cache[path] = f.read()
            return Image.open(BytesIO(self.cache[path])).convert('RGB')
        # Open the file ourselves so the handle is closed deterministically;
        # ``convert`` fully decodes the image before the ``with`` block exits.
        with open(os.path.join(self.root, path), 'rb') as f:
            return Image.open(f).convert('RGB')

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is the object returned by
            ``coco.loadAnns``, after ``self.transforms`` has been applied to
            the pair when transforms are set.
        """
        coco = self.coco
        img_id = self.ids[index]
        ann_ids = coco.getAnnIds(imgIds=img_id)
        target = coco.loadAnns(ann_ids)

        path = coco.loadImgs(img_id)[0]['file_name']

        img = self.get_image(path)
        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Return the number of images listed in the annotation file."""
        return len(self.ids)