project1/datasets/torchvision_datasets/coco.py

85 lines
3.2 KiB
Python

# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from torchvision
# ------------------------------------------------------------------------
"""
Copied from torchvision, with an added utility for caching images in memory.
"""
from torchvision.datasets.vision import VisionDataset
from PIL import Image
import os
import os.path
import tqdm
from io import BytesIO
class CocoDetection(VisionDataset):
    """`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.

    Args:
        root (string): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
        cache_mode (bool, optional): If true, the raw (still encoded) bytes of image
            files are kept in an in-memory dict keyed by ``file_name``. Part of the
            dataset is read eagerly at construction time (see ``cache_images``); any
            other image is read and cached the first time it is requested.
        local_rank (int, optional): Selects which images are cached eagerly: those
            whose position ``i`` in the sorted id list satisfies
            ``i % local_size == local_rank``.
            NOTE(review): presumably the rank of this process on its node -- confirm
            with the caller.
        local_size (int, optional): Modulus used together with ``local_rank``.
            NOTE(review): presumably the number of processes per node -- confirm
            with the caller.
    """

    def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None,
                 cache_mode=False, local_rank=0, local_size=1):
        super(CocoDetection, self).__init__(root, transforms, transform, target_transform)
        # Local import: pycocotools is only required once a dataset is instantiated.
        from pycocotools.coco import COCO

        self.coco = COCO(annFile)
        self.ids = sorted(self.coco.imgs.keys())
        self.cache_mode = cache_mode
        self.local_rank = local_rank
        self.local_size = local_size
        # Always define the attribute so it exists regardless of cache_mode.
        self.cache = {}
        if cache_mode:
            self.cache_images()

    def _read_bytes(self, path):
        """Return the raw contents of the file ``path`` located under ``self.root``."""
        with open(os.path.join(self.root, path), 'rb') as f:
            return f.read()

    def cache_images(self):
        """Rebuild ``self.cache`` with this rank's share of the image files.

        Only images whose index ``i`` in ``self.ids`` satisfies
        ``i % self.local_size == self.local_rank`` are read; the progress bar
        nevertheless advances over every id.
        """
        self.cache = {}
        for index, img_id in enumerate(tqdm.tqdm(self.ids)):
            if index % self.local_size != self.local_rank:
                continue
            path = self.coco.loadImgs(img_id)[0]['file_name']
            self.cache[path] = self._read_bytes(path)

    def get_image(self, path):
        """Load the image stored at ``path`` (relative to ``self.root``) as RGB.

        Args:
            path (string): Image file name relative to the dataset root.

        Returns:
            PIL.Image.Image: The decoded image converted to mode ``'RGB'``.
        """
        if self.cache_mode:
            # Cache miss: read the file once and keep its bytes for later calls.
            if path not in self.cache:
                self.cache[path] = self._read_bytes(path)
            return Image.open(BytesIO(self.cache[path])).convert('RGB')
        # Open the file ourselves so the handle is closed deterministically;
        # convert() fully decodes the data and returns an independent image.
        with open(os.path.join(self.root, path), 'rb') as f:
            return Image.open(f).convert('RGB')

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
        """
        coco = self.coco
        img_id = self.ids[index]
        ann_ids = coco.getAnnIds(imgIds=img_id)
        target = coco.loadAnns(ann_ids)

        path = coco.loadImgs(img_id)[0]['file_name']
        img = self.get_image(path)
        # The joint transform receives and returns both the image and its target.
        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Return the number of images listed in the annotation file."""
        return len(self.ids)