Updates to support coco training and clean up
amdegroot committed Mar 5, 2018
1 parent c8c386b commit 66faf9c
Showing 9 changed files with 210 additions and 225 deletions.
27 changes: 24 additions & 3 deletions data/__init__.py
@@ -1,13 +1,34 @@
from .voc0712 import VOCDetection, AnnotationTransform, detection_collate, VOC_CLASSES
from .coco import COCODetection, COCOAnnotationTransform
from .voc0712 import VOCDetection, VOCAnnotationTransform, VOC_CLASSES, VOC_ROOT
from .coco import COCODetection, COCOAnnotationTransform, COCO_CLASSES, COCO_ROOT
from .config import *
import torch
import cv2
import numpy as np


def detection_collate(batch):
"""Custom collate fn for dealing with batches of images that have a different
number of associated object annotations (bounding boxes).
Arguments:
batch: (tuple) A tuple of tensor images and lists of annotations
Return:
A tuple containing:
1) (tensor) batch of images stacked on their 0 dim
2) (list of tensors) annotations for a given image are stacked on
0 dim
"""
targets = []
imgs = []
for sample in batch:
imgs.append(sample[0])
targets.append(torch.FloatTensor(sample[1]))
return torch.stack(imgs, 0), targets


def base_transform(image, size, mean):
x = cv2.resize(image, (size, size)).astype(np.float32)
# x = cv2.resize(np.array(image), (size, size)).astype(np.float32)
x -= mean
x = x.astype(np.float32)
return x
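
detection_collate (now exported from data/__init__.py) is what lets a standard PyTorch DataLoader batch images whose annotation counts differ: images stack into a single tensor while each image's boxes stay in their own tensor in a list. A minimal usage sketch follows; the SSDAugmentation import, the image_sets value, and the loader settings are illustrative assumptions, not part of this diff.

import torch.utils.data as data

from data import VOC_ROOT, VOCDetection, VOCAnnotationTransform, detection_collate
from utils.augmentations import SSDAugmentation  # assumed repo helper, not shown in this diff

# Any transform that resizes every image to a fixed square works here,
# since torch.stack requires identical shapes across the batch.
dataset = VOCDetection(root=VOC_ROOT, image_sets=[('2007', 'trainval')],
                       transform=SSDAugmentation(300),
                       target_transform=VOCAnnotationTransform())
loader = data.DataLoader(dataset, batch_size=32, shuffle=True,
                         num_workers=4, collate_fn=detection_collate)

images, targets = next(iter(loader))
# images:  FloatTensor of shape (32, 3, 300, 300)
# targets: list of 32 FloatTensors, each (num_objects, 5) holding
#          [xmin, ymin, xmax, ymax, label] rows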
114 changes: 82 additions & 32 deletions data/coco.py
@@ -1,3 +1,4 @@
from .config import HOME
import os
import os.path
import sys
@@ -7,29 +8,41 @@
import cv2
import numpy as np

COCO_ROOT = os.path.join(HOME, 'data/coco/')
IMAGES = 'images'
ANNOTATIONS = 'annotations'
COCO_API = 'PythonAPI'
INSTANCES_SET = 'instances_{}.json'
COCO_CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush')


class COCOAnnotationTransform(object):
"""Transforms a VOC annotation into a Tensor of bbox coords and label index
"""Transforms a COCO annotation into a Tensor of bbox coords and label index
Initialized with a dictionary lookup of classnames to indexes
Arguments:
class_to_ind (dict, optional): dictionary lookup of classnames -> indexes
(default: alphabetic indexing of COCO's 80 classes)
keep_difficult (bool, optional): keep difficult instances or not
(default: False)
"""

# def __init__(self)

def __call__(self, target, width, height):
"""
Arguments:
target (annotation) : the target annotation to be made usable
will be an ET.Element
Args:
target (dict): COCO target json annotation as a python dict
height (int): height
width (int): width
Returns:
a list containing lists of bounding boxes [bbox coords, class name]
a list containing lists of bounding boxes [bbox coords, class idx]
"""
scale = np.array([width, height, width, height])
res = []
@@ -41,35 +54,40 @@ def __call__(self, target, width, height):
label_idx = obj['category_id']
final_box = list(np.array(bbox)/scale)
final_box.append(label_idx)
res += [final_box] # [xmin, ymin, xmax, ymax, label_ind]
return res # [[xmin, ymin, xmax, ymax, label_ind], ... ]
res += [final_box] # [xmin, ymin, xmax, ymax, label_idx]
return res # [[xmin, ymin, xmax, ymax, label_idx], ... ]
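
To make the new contract concrete, a hypothetical call with a made-up annotation (values invented for illustration, not from this diff). One caveat: raw COCO category_id values are non-contiguous (1-90), so using them directly as indices into the 80-entry COCO_CLASSES tuple would require a separate id-to-index mapping.

transform = COCOAnnotationTransform()
# One fake object; real dicts come from pycocotools' coco.loadAnns().
target = [{'bbox': [48.0, 24.0, 96.0, 120.0], 'category_id': 1}]
res = transform(target, width=640, height=480)
# res -> [[xmin, ymin, xmax, ymax, 1]], coordinates scaled into [0, 1]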


class COCODetection(data.Dataset):
"""`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.
Args:
root (string): Root directory where images are downloaded to.
annFile (string): Path to json annotation file.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.ToTensor``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
set_name (string): Name of the specific set of COCO images.
transform (callable, optional): A function/transform that augments the
raw images
target_transform (callable, optional): A function/transform that takes
in the target (bbox) and transforms it.
"""

def __init__(self, root, annFile, transform=None, target_transform=None):
def __init__(self, root, image_set, transform=None,
target_transform=None, dataset_name='COCO2014'):
sys.path.append(os.path.join(root, COCO_API))
from pycocotools.coco import COCO
self.root = root
self.coco = COCO(annFile)
self.root = os.path.join(root, IMAGES, image_set)
self.coco = COCO(os.path.join(root, ANNOTATIONS,
INSTANCES_SET.format(image_set)))
self.ids = list(self.coco.imgs.keys())
self.transform = transform
self.target_transform = target_transform
self.name = dataset_name

def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
tuple: Tuple (image, target).
target is the object returned by ``coco.loadAnns``.
"""
im, gt, h, w = self.pull_item(index)
return im, gt
@@ -82,26 +100,58 @@ def pull_item(self, index):
Args:
index (int): Index
Returns:
tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
tuple: Tuple (image, target, height, width).
target is the object returned by ``coco.loadAnns``.
"""
coco = self.coco
img_id = self.ids[index]
ann_ids = coco.getAnnIds(imgIds=img_id)
target = coco.loadAnns(ann_ids)
path = coco.loadImgs(img_id)[0]['file_name']
ann_ids = self.coco.getAnnIds(imgIds=img_id)
target = self.coco.loadAnns(ann_ids)
path = self.coco.loadImgs(img_id)[0]['file_name']
img = cv2.imread(os.path.join(self.root, path))
height, width, channels = img.shape
if self.target_transform is not None:
target = self.target_transform(target, width, height)
if self.transform is not None:
target = np.array(target)
img, boxes, labels = self.transform(img, target[:, :4], target[:, 4])
img, boxes, labels = self.transform(img, target[:, :4],
target[:, 4])
# to rgb
img = img[:, :, (2, 1, 0)]
# img = img.transpose(2, 0, 1)
target = np.hstack((boxes, np.expand_dims(labels, axis=1)))
return torch.from_numpy(img).permute(2, 0, 1), target, height, width

def pull_image(self, index):
'''Returns the original image at index, as loaded by cv2 (BGR ndarray)
Note: not using self.__getitem__(), as any transformations passed in
could mess up this functionality.
Argument:
index (int): index of img to show
Return:
cv2 img
'''
img_id = self.ids[index]
path = self.coco.loadImgs(img_id)[0]['file_name']
return cv2.imread(os.path.join(self.root, path), cv2.IMREAD_COLOR)

def pull_anno(self, index):
'''Returns the original annotation of image at index
Note: not using self.__getitem__(), as any transformations passed in
could mess up this functionality.
Argument:
index (int): index of img to get annotation of
Return:
list: the image's COCO annotation dicts, as returned by coco.loadAnns
'''
img_id = self.ids[index]
ann_ids = self.coco.getAnnIds(imgIds=img_id)
return self.coco.loadAnns(ann_ids)

def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
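Putting the new pieces together, a hypothetical construction of the dataset. The set name 'train2014' is an assumption consistent with INSTANCES_SET = 'instances_{}.json' and dataset_name='COCO2014': it resolves to <COCO_ROOT>/annotations/instances_train2014.json, with images under <COCO_ROOT>/images/train2014/.

from data import COCO_ROOT, COCODetection, COCOAnnotationTransform

# Note: __init__ appends <COCO_ROOT>/PythonAPI to sys.path, so the
# pycocotools API is expected to live under the COCO root.
dataset = COCODetection(root=COCO_ROOT, image_set='train2014',
                        target_transform=COCOAnnotationTransform())

img, gt = dataset[0]          # CxHxW image tensor, [[xmin, ymin, xmax, ymax, label], ...]
raw = dataset.pull_image(0)   # untouched BGR image via cv2.imread
anns = dataset.pull_anno(0)   # raw COCO annotation dicts from coco.loadAnns
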
86 changes: 29 additions & 57 deletions data/config.py
@@ -2,63 +2,35 @@
import os.path

# gets home dir cross platform
home = os.path.expanduser("~")
ddir = os.path.join(home,"data/VOCdevkit/")

# note: if you used our download scripts, this should be right
VOCroot = ddir # path to VOCdevkit root dir

# default batch size
BATCHES = 32
# data reshuffled at every epoch
SHUFFLE = True
# number of subprocesses to use for data loading
WORKERS = 4


#SSD300 CONFIGS
# newer version: use additional conv11_2 layer as last layer before multibox layers
v2 = {
'feature_maps' : [38, 19, 10, 5, 3, 1],

'min_dim' : 300,

'steps' : [8, 16, 32, 64, 100, 300],

'min_sizes' : [30, 60, 111, 162, 213, 264],

'max_sizes' : [60, 111, 162, 213, 264, 315],

# 'aspect_ratios' : [[2, 1/2], [2, 1/2, 3, 1/3], [2, 1/2, 3, 1/3],
# [2, 1/2, 3, 1/3], [2, 1/2], [2, 1/2]],
'aspect_ratios' : [[2], [2, 3], [2, 3], [2, 3], [2], [2]],

'variance' : [0.1, 0.2],

'clip' : True,

'name' : 'v2',
HOME = os.path.expanduser("~")

# for making bounding boxes pretty
COLORS = ((255, 0, 0, 128), (0, 255, 0, 128), (0, 0, 255, 128),
(0, 255, 255, 128), (255, 0, 255, 128), (255, 255, 0, 128))

MEANS = (104, 117, 123)

# SSD300 CONFIGS
voc = {
'feature_maps': [38, 19, 10, 5, 3, 1],
'min_dim': 300,
'steps': [8, 16, 32, 64, 100, 300],
'min_sizes': [30, 60, 111, 162, 213, 264],
'max_sizes': [60, 111, 162, 213, 264, 315],
'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
'variance': [0.1, 0.2],
'clip': True,
'name': 'VOC',
}

# use average pooling layer as last layer before multibox layers
v1 = {
'feature_maps' : [38, 19, 10, 5, 3, 1],

'min_dim' : 300,

'steps' : [8, 16, 32, 64, 100, 300],

'min_sizes' : [30, 60, 114, 168, 222, 276],

'max_sizes' : [-1, 114, 168, 222, 276, 330],

# 'aspect_ratios' : [[2], [2, 3], [2, 3], [2, 3], [2, 3], [2, 3]],
'aspect_ratios' : [[1,1,2,1/2],[1,1,2,1/2,3,1/3],[1,1,2,1/2,3,1/3],
[1,1,2,1/2,3,1/3],[1,1,2,1/2,3,1/3],[1,1,2,1/2,3,1/3]],

'variance' : [0.1, 0.2],

'clip' : True,

'name' : 'v1',
coco = {
'feature_maps': [38, 19, 10, 5, 3, 1],
'min_dim': 300,
'steps': [8, 16, 32, 64, 100, 300],
'min_sizes': [21, 45, 99, 153, 207, 261],
'max_sizes': [45, 99, 153, 207, 261, 315],
'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
'variance': [0.1, 0.2],
'clip': True,
'name': 'COCO',
}
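
Both dicts share one schema, so training code can pick a config by dataset name and derive the prior-box layout from it; the substantive difference is min_sizes/max_sizes, since COCO objects tend to be smaller relative to the image. A small sketch (the helper below is illustrative, not repo API):

from data import voc, coco

def choose_config(dataset_name):
    # Illustrative helper: select the SSD300 prior-box config for a dataset.
    return coco if dataset_name == 'COCO' else voc

cfg = choose_config('VOC')
# SSD places 2 + 2 * len(aspect_ratios[k]) priors per cell of feature map k:
# a min-size square, a sqrt(min*max) square, plus each ratio and its reciprocal.
priors_per_cell = [2 + 2 * len(ar) for ar in cfg['aspect_ratios']]
print(priors_per_cell)  # [4, 6, 6, 6, 4, 4] for both VOC and COCO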
30 changes: 5 additions & 25 deletions data/voc0712.py
@@ -5,7 +5,7 @@
Updated by: Ellis Brown, Max deGroot
"""

from .config import HOME
import os
import os.path
import sys
@@ -27,12 +27,11 @@
'motorbike', 'person', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor')

# for making bounding boxes pretty
COLORS = ((255, 0, 0, 128), (0, 255, 0, 128), (0, 0, 255, 128),
(0, 255, 255, 128), (255, 0, 255, 128), (255, 255, 0, 128))
# note: if you used our download scripts, this should be right
VOC_ROOT = os.path.join(HOME, "data/VOCdevkit/")


class AnnotationTransform(object):
class VOCAnnotationTransform(object):
"""Transforms a VOC annotation into a Tensor of bbox coords and label index
Initialized with a dictionary lookup of classnames to indexes
@@ -115,6 +114,7 @@ def __init__(self, root, image_sets, transform=None, target_transform=None,

def __getitem__(self, index):
im, gt, h, w = self.pull_item(index)

return im, gt

def __len__(self):
@@ -183,23 +183,3 @@ def pull_tensor(self, index):
tensorized version of img, squeezed
'''
return torch.Tensor(self.pull_image(index)).unsqueeze_(0)


def detection_collate(batch):
"""Custom collate fn for dealing with batches of images that have a different
number of associated object annotations (bounding boxes).
Arguments:
batch: (tuple) A tuple of tensor images and lists of annotations
Return:
A tuple containing:
1) (tensor) batch of images stacked on their 0 dim
2) (list of tensors) annotations for a given image are stacked on 0 dim
"""
targets = []
imgs = []
for sample in batch:
imgs.append(sample[0])
targets.append(torch.FloatTensor(sample[1]))
return torch.stack(imgs, 0), targets
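
With detection_collate relocated to the data package and the dataset roots exported alongside it, call sites shrink to a single import; a sketch of the change (old line shown for contrast):

# before this commit:
#   from data.voc0712 import VOCDetection, AnnotationTransform, detection_collate
# after it:
from data import VOC_ROOT, VOCDetection, VOCAnnotationTransform, detection_collate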
2 changes: 1 addition & 1 deletion layers/functions/detection.py
@@ -1,7 +1,7 @@
import torch
from torch.autograd import Function
from ..box_utils import decode, nms
from data import v2 as cfg
from data import voc as cfg


class Detect(Function):