Skip to content

Commit

Permalink
Adding DUC semantic segmentation model (onnx#80)
Browse files Browse the repository at this point in the history
* Update README.md

* Create backlogs.md

* Create contribute.md

* Update README.md

* Add files via upload

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update contribute.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* image folder

* Update README.md

* Update README.md

* Update README.md

* added models folder

* fixes

* Update and rename backlogs.md to squeezenet.md

* Create resnet.md

* Create vgg.md

* Create mobilenet.md

* Change layout - add model pages

* Update README.md

* Update README.md

* added sqnet train nb

* Update train_notebook_squeezenet.ipynb

* Delete train_notebook_squeezenet.ipynb

* added sqnet train nb

* Update squeezenet.md

* Update squeezenet.md

* Update squeezenet.md

* added test script

* created folders for models

* added extract imagenet script

* Update squeezenet.md

* Update squeezenet.md

* Update squeezenet.md

* Update squeezenet.md

* Update squeezenet.md

* Update squeezenet.md

* Update squeezenet.md

* Rename train_notebook_squeezenet.ipynb to train_squeezenet.ipynb

* Update squeezenet.md

* adding artifacts for resnet

* minor fix

* updated train_squeezenet

* updated imagenet_verify

* Update README.md

* Update README.md

* Update squeezenet.md

* Update squeezenet.md

* Update squeezenet.md

* Update squeezenet.md

* Update squeezenet.md

* Update squeezenet.md

* Update resnet.md

* Update squeezenet.md

* Update squeezenet.md

* Update squeezenet.md

* Update squeezenet.md

* Update squeezenet.md

* Update squeezenet.md

* Update squeezenet.md

* Update squeezenet.md

* Update squeezenet.md

* Update squeezenet.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* resnet training notebook

* fix documentation

* changing name

* renaming squeezenet

* renamed files

* fixing link

* fixed format

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update imagenet_verify.ipynb

* updated train_squeezenet

* updated train_squeezenet

* Update README.md

* added train_vgg

* fix train_vgg

* Update README.md

* Update README.md

* Create imagenet_prep.md

* Update README.md

* Update README.md

* Update README.md

* Update imagenet_prep.md

* Update imagenet_prep.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update contribute.md

* documentation for mobilenet

* add mobilenet training

* fixed link

* fix spelling

* Update README.md

* fixed readme

* fixed imagenet_inference

* added MMS info

* initial feedback

* Update README.md

* Update README.md

* Update README.md

* Update contribute.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* comments and edits on the notebooks and imagenet prep

* added onnx links, fixed CamelCase and add accuracy tables

* fixed formatting

* fixed formatting

* fixed camelcase and contrbution section

* added model description

* added keywords and contributions

* added model server description

* Fixed file

* fixed landing page

* add image placeholders and categories

* fixed model links

* fixed links for real this time

* fixed resnet

* fixed resnet

* updated training and inference notebooks

* fixed models

* fixed models

* added synset

* synset

* removed DS store

* fixed spell

* fixed contri page

* fixed contri page

* fixed contri page

* updated imagenet_verify

* updated extract_imagenet.py, fixed archive sizes, changed name of imagenet_verify, fixed links

* updated target

* version differences

* fixed preprocess

* fixed

* updated readmes, moved folders

* fixed links

* removed DS

* removed DS

* fixed spacing

* push line brreaks

* review commenst

* addressed PR comments

* fixed formatting

* minor updates

* fixed code

* updates after feedback

* fixed links

* Delete .DS_Store

* added PR template

* added model use-cases

* added badges

* Updates after feedback

* fixed links, added imagenet licence info

* formatting

* update wip models

* minor edits

* WaveNet is not based on GAN framework

Rename the heading to Generative Models instead of GAN because WaveNet doesn't involve GAN training framework.

* added license, onnx version, removed MMS, MXNet code

* added pre/post-process scripts, updated PR template, model readme

* typos and grammar fixes

* added sample test data, updated val acc

* fixed formatting

* Update PULL_REQUEST_TEMPLATE.md

* remove README

* update README

* fixed links

* removed keywords, updated notebooks, fixed typos

* added model checker info

* merged ack with ref

* added arcface model

* updated PR temp, image name, imagenet pre/post

* added md5 checksum, opset version

* fixed links

* Update PULL_REQUEST_TEMPLATE.md

* added infer nb

* added duc val nb

* fix infer nb

* update main README

* fixed tiny-yolov2 link

* added duc readme

* fixed links, formatting

* added pre/post-process code

* fix typos

* update readme, added comments

* fix val nb

* added cell for gen .lst

* updated readme header

* added scripts info and comments

* updated notebooks
  • Loading branch information
abhinavs95 authored and ankkhedia committed Jul 28, 2018
1 parent 4aa95e1 commit 1f3653d
Show file tree
Hide file tree
Showing 9 changed files with 1,525 additions and 11 deletions.
23 changes: 12 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,22 @@ These models detect and/or recognize human faces in images. Some more popular mo

<hr>

### Semantic Segmentation
Semantic segmentation models partition an input image by labeling each pixel into a set of pre-defined categories.

|Model Class |Reference |Description |
|-|-|-|
|<b>[DUC](models/semantic_segmentation/DUC/)</b>|[Wang et al.](https://arxiv.org/abs/1702.08502)|Deep CNN based model with >80% [mIOU](/models/semantic_segmentation/DUC/README.md/#metric) (mean Intersection Over Union) trained on urban street images|
|<b>FCN</b>|[Long et al.](https://people.eecs.berkeley.edu/~jonlong/long_shelhamer_fcn.pdf)|[contribute](contribute.md)|

<hr>

### Object Detection & Segmentation
These models detect the presence of multiple objects in an image and segment out areas of the image where the objects are detected.

|Model Class |Reference |Description |
|-|-|-|
|<b>Tiny_YOLOv2</b>|[Redmon et al.](https://arxiv.org/pdf/1612.08242.pdf)|Deep CNN model for Object Detection|
|<b>[Tiny_YOLOv2](tiny_yolov2)</b>|[Redmon et al.](https://arxiv.org/pdf/1612.08242.pdf)|Deep CNN model for Object Detection|
|<b>SSD</b>|[Liu et al.](https://arxiv.org/abs/1512.02325)|[contribute](contribute.md)|
|<b>Faster-RCNN</b>|[Ren et al.](https://arxiv.org/abs/1506.01497)|[contribute](contribute.md)|
|<b>Mask-RCNN</b>|[He et al.](https://arxiv.org/abs/1703.06870)|[contribute](contribute.md)|
Expand All @@ -59,15 +69,6 @@ These models detect the presence of multiple objects in an image and segment out

<hr>

### Semantic Segmentation
Semantic segmentation models will identify multiple classes of objects in an image and provide information on the areas of the image that object was detected.

|Model Class |Reference |Description |
|-|-|-|
|<b>FCN</b>|[Long et al.](https://people.eecs.berkeley.edu/~jonlong/long_shelhamer_fcn.pdf)|[contribute](contribute.md)|

<hr>

### Emotion Recognition

|Model Class |Reference |Description |
Expand Down Expand Up @@ -226,4 +227,4 @@ Do you want to contribute a model? To get started, pick any model presented abov

# License

[MIT License](LICENSE)
[MIT License](LICENSE)
71 changes: 71 additions & 0 deletions models/semantic_segmentation/DUC/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Dense Upsampling Convolution (DUC)

## Use cases
DUC is a semantic segmentation model, i.e., for an input image the model labels each pixel in the image into a set of pre-defined categories. The model provides very good accuracy in terms of [mIOU](#metric) (mean Intersection Over Union) score and can be used in any application requiring semantic segmentation. In particular, since the model is trained on the [cityscapes dataset](#dset) which contains images from urban street scenes, it can be used effectively in self driving vehicle systems.

## Description
DUC is a CNN based model for semantic segmentation which uses an image classification network (ResNet) as a backend and achieves improved accuracy in terms of mIOU score using two novel techniques. The first technique is called Dense Upsampling Convolution (DUC) which generates pixel-level prediction by capturing and decoding more detailed information that is generally missing in bilinear upsampling. Secondly, a framework called Hybrid Dilated Convolution (HDC) is proposed in the encoding phase which enlarges the receptive fields of the network to aggregate global information. It also alleviates the checkerboard receptive field problem ("gridding") caused by the standard dilated convolution operation.

## Model
The model ResNet101_DUC_HDC uses ResNet101 as a backend network with both Dense Upsampling Convolution (DUC) and Hybrid Dilated Convolution (HDC) techniques.

|Model |Download |Checksum|Download (with sample test data)| ONNX version |Opset version|[mIOU](#metric) (%)|
|-------------|:--------------|:--------------|:--------------|:--------------|:--------------|:--------------|
|ResNet101_DUC_HDC| [248.6 MB](https://s3.amazonaws.com/onnx-model-zoo/duc/ResNet101_DUC_HDC.onnx) |[MD5](https://s3.amazonaws.com/onnx-model-zoo/duc/ResNet101_DUC_HDC-md5.txt) | [282.0 MB](https://s3.amazonaws.com/onnx-model-zoo/duc/ResNet101_DUC_HDC.tar.gz) |1.2.2 |7 |81.92 |

## Inference
We used MXNet as framework to perform inference. View the notebook [duc-inference](duc-inference.ipynb) to understand how to use above models for doing inference. A brief description of the inference process is provided below:

### Input
Since the model is trained on the cityscapes dataset which contains images of urban street scenes, the input should preferably be an image of a street scene to produce best results. There are no constraints on the size of the image. The example in the inference notebook is done using a png image.

### Preprocessing
The DUC layer has an effect of partitioning the image into d^2 subparts (d = downsampling rate). This is why the input image is extrapolated with a small border in order to obtain an accurate reshaped image after the DUC layer. After this the image is normalized using mean subtraction. Check [duc-preprocess.py](duc-preprocess.py) for code.

### Output
The output of the network is a tensor of shape (1 X `label_num` X `H` * `W`) where `H` and `W` are the height and width of the output segmented map.

### Postprocessing
The output tensor is reshaped and resized to give the softmax map of shape (`H` X `W` X `label_num`). The raw label map is computed by doing an argmax on the softmax map. The script [cityscapes_labels.py](cityscapes_labels.py) contains the segmentation category labels and their corresponding color map. Using this the colorized segmented images are generated. Check [duc-postprocess.py](duc-postprocess.py) for code.

<!--
To do quick inference with the model, check out [Model Server](https://github.com/awslabs/mxnet-model-server/blob/master/docs/model_zoo.md/#arcface-resnet100_onnx).
-->

## <a name="dset"></a>Dataset
Cityscapes dataset is used for training and validation. It is a large dataset that focuses on semantic understanding of urban street scenes. It contains 5000 images with fine annotations across 50 cities, different seasons, varying scene layout and background. There are a total of 30 categories in the dataset of which 19 are included for training and evaluation. The training, validation and test set contains 2975, 500 and 1525 fine images, respectively.

### Download
First, go to the [Cityscapes download page](https://www.cityscapes-dataset.com/downloads/) and register for an account (login if account already made). Next, find and download the following two files:

|Filename | Size | Details|
|-------------------------|:------|:-------|
|leftImg8bit_trainvaltest.zip| 11 GB| train/val/test images|
|gtFine_trainvaltest.zip | 241 MB| fine annotations for train and val sets|

### Setup
* Unpack the zip files into folders `leftImg8bit_trainvaltest` and `gtFine_trainvaltest`.
* Use the path to the train/val folders inside these folders for training/validation.

Please note that the dataset is under copyright. Refer to the [citation](https://www.cityscapes-dataset.com/citation/) page for details.

## Validation accuracy
The [mIOU](#metric) score obtained by the models on the validation set are mentioned above and they match with those mentioned in the paper.

## Training
Coming soon.

## <a name="metric"></a>Validation
**mean Intersection Over Union (mIOU)** is the metric used for validation. For each class the intersection over union (IOU) of pixel labels between the output and the target segmentation maps is computed and then averaged over all classes to give us the mean intersection over union (mIOU).

We used MXNet framework to compute mIOU of the models on the validation set described above. Use the notebook [duc-validation](duc-validation.ipynb) to verify the mIOU of the model. The scripts [cityscapes_loader.py](cityscapes_loader.py), [cityscapes_labels.py](cityscapes_labels.py) and [utils.py](utils.py) are used in the notebook for data loading and processing.

## References
* All models are from the paper [Understanding Convolution for Semantic Segmentation](https://arxiv.org/abs/1702.08502).
* [TuSimple-DUC repo](https://github.com/TuSimple/TuSimple-DUC), [MXNet](http://mxnet.incubator.apache.org)

## Contributors
[abhinavs95](https://github.com/abhinavs95) (Amazon AI)

## License
Apache 2.0
182 changes: 182 additions & 0 deletions models/semantic_segmentation/DUC/cityscapes_labels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
#!/usr/bin/python
#
'''
This script contains label names, original ids, modified train ids and color maps for the categories in the cityscapes dataset.
Used in the duc-inference.ipynb and duc-validation.ipynb
'''
from collections import namedtuple


#--------------------------------------------------------------------------------
# Definitions
#--------------------------------------------------------------------------------

# a label and all meta information
Label = namedtuple( 'Label' , [

'name' , # The identifier of this label, e.g. 'car', 'person', ... .
# We use them to uniquely name a class

'id' , # An integer ID that is associated with this label.
# The IDs are used to represent the label in ground truth images
# An ID of -1 means that this label does not have an ID and thus
# is ignored when creating ground truth images (e.g. license plate).
# Do not modify these IDs, since exactly these IDs are expected by the
# evaluation server.

'trainId' , # Feel free to modify these IDs as suitable for your method. Then create
# ground truth images with train IDs, using the tools provided in the
# 'preparation' folder. However, make sure to validate or submit results
# to our evaluation server using the regular IDs above!
# For trainIds, multiple labels might have the same ID. Then, these labels
# are mapped to the same class in the ground truth images. For the inverse
# mapping, we use the label that is defined first in the list below.
# For example, mapping all void-type classes to the same ID in training,
# might make sense for some approaches.
# Max value is 255!

'category' , # The name of the category that this label belongs to

'categoryId' , # The ID of this category. Used to create ground truth images
# on category level.

'hasInstances', # Whether this label distinguishes between single instances or not

'ignoreInEval', # Whether pixels having this class as ground truth label are ignored
# during evaluations or not

'color' , # The color of this label
] )


#--------------------------------------------------------------------------------
# A list of all labels
#--------------------------------------------------------------------------------

# Please adapt the train IDs as appropriate for you approach.
# Note that you might want to ignore labels with ID 255 during training.
# Further note that the current train IDs are only a suggestion. You can use whatever you like.
# Make sure to provide your results using the original IDs and not the training IDs.
# Note that many IDs are ignored in evaluation and thus you never need to predict these!

labels = [
# name id trainId category catId hasInstances ignoreInEval color
Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ),
Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ),
Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ),
Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ),
Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ),
Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ),
Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ),
Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ),
Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ),
Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ),
Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ),
Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ),
Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ),
Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,153) ),
Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ),
Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ),
Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ),
Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ),
Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ),
Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ),
Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ),
Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ),
Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ),
Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ),
Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ),
Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ),
Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ),
Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ),
Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ),
Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,142) ),
]


#--------------------------------------------------------------------------------
# Create dictionaries for a fast lookup
#--------------------------------------------------------------------------------

# Please refer to the main method below for example usages!

# name to label object
name2label = { label.name : label for label in labels }
# id to label object
id2label = { label.id : label for label in labels }
# trainId to label object
trainId2label = { label.trainId : label for label in reversed(labels) }
# category to list of label objects
category2labels = {}
for label in labels:
category = label.category
if category in category2labels:
category2labels[category].append(label)
else:
category2labels[category] = [label]

#--------------------------------------------------------------------------------
# Assure single instance name
#--------------------------------------------------------------------------------

# returns the label name that describes a single instance (if possible)
# e.g. input | output
# ----------------------
# car | car
# cargroup | car
# foo | None
# foogroup | None
# skygroup | None
def assureSingleInstanceName( name ):
# if the name is known, it is not a group
if name in name2label:
return name
# test if the name actually denotes a group
if not name.endswith("group"):
return None
# remove group
name = name[:-len("group")]
# test if the new name exists
if not name in name2label:
return None
# test if the new name denotes a label that actually has instances
if not name2label[name].hasInstances:
return None
# all good then
return name

#--------------------------------------------------------------------------------
# Main for testing
#--------------------------------------------------------------------------------

# just a dummy main
if __name__ == "__main__":
# Print all the labels
print("List of cityscapes labels:")
print("")
print(" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( 'name', 'id', 'trainId', 'category', 'categoryId', 'hasInstances', 'ignoreInEval' ))
print(" " + ('-' * 98))
for label in labels:
print(" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( label.name, label.id, label.trainId, label.category, label.categoryId, label.hasInstances, label.ignoreInEval ))
print("")

print("Example usages:")

# Map from name to label
name = 'car'
id = name2label[name].id
print("ID of label '{name}': {id}".format( name=name, id=id ))

# Map from ID to label
category = id2label[id].category
print("Category of label with ID '{id}': {category}".format( id=id, category=category ))

# Map from trainID to label
trainId = 0
name = trainId2label[trainId].name
print("Name of label with trainID '{id}': {name}".format( id=trainId, name=name ))
Loading

0 comments on commit 1f3653d

Please sign in to comment.