Skip to content

Commit

Permalink
Merge pull request #131 from tdiekel/master
Browse files Browse the repository at this point in the history
Support for training custom datasets
  • Loading branch information
bharatsingh430 authored Oct 8, 2019
2 parents 3b82040 + ce7b3bc commit 0117d7a
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion lib/dataset/imdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def create_roidb_from_box_list(self, box_list, mapping_list, gt_roidb):
self.num_images = len(gt_roidb)
assert len(box_list) == self.num_images, 'number of boxes matrix must match number of images'
roidb = []
stats = np.zeros(81)
stats = np.zeros(self.num_classes)
for i in range(self.num_images):
roi_rec = dict()
roi_rec['image'] = gt_roidb[i]['image']
Expand Down
4 changes: 2 additions & 2 deletions symbols/faster/mobilenetv2_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def get_symbol_rcnn(self, cfg, is_train=True):
fc_new_2_relu = mx.sym.Activation(data=fc_new_2, act_type='relu', name='fc_new_2_relu')

num_reg_classes = 1
num_classes = 81
num_classes = cfg.dataset.NUM_CLASSES
cls_score = mx.sym.FullyConnected(name='cls_score', data=fc_new_2_relu, num_hidden=num_classes)
bbox_pred = mx.sym.FullyConnected(name='bbox_pred', data=fc_new_2_relu, num_hidden=num_reg_classes * 4)

Expand Down Expand Up @@ -351,7 +351,7 @@ def get_symbol_rcnn(self, cfg, is_train=True):
fc_new_2_relu = mx.sym.Activation(data=fc_new_2, act_type='relu', name='fc_new_2_relu')

num_reg_classes = 1
num_classes = 81
num_classes = cfg.dataset.NUM_CLASSES
cls_score = mx.sym.FullyConnected(name='cls_score', data=fc_new_2_relu, num_hidden=num_classes)
bbox_pred = mx.sym.FullyConnected(name='bbox_pred', data=fc_new_2_relu, num_hidden=num_reg_classes * 4)

Expand Down
4 changes: 2 additions & 2 deletions symbols/faster/resnet_mx_101_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def get_symbol_rcnn(self, cfg, is_train=True):

fc_new_2 = mx.sym.FullyConnected(name='fc_new_2', data=fc_new_1_relu, num_hidden=1024)
fc_new_2_relu = mx.sym.Activation(data=fc_new_2, act_type='relu', name='fc_new_2_relu')
num_classes = 81
num_classes = cfg.dataset.NUM_CLASSES
num_reg_classes = 1
cls_score = mx.sym.FullyConnected(name='cls_score', data=fc_new_2_relu, num_hidden=num_classes)
bbox_pred = mx.sym.FullyConnected(name='bbox_pred', data=fc_new_2_relu, num_hidden=num_reg_classes * 4)
Expand Down Expand Up @@ -353,7 +353,7 @@ def get_symbol_rcnn(self, cfg, is_train=True):

fc_new_2 = mx.sym.FullyConnected(name='fc_new_2', data=fc_new_1_relu, num_hidden=1024)
fc_new_2_relu = mx.sym.Activation(data=fc_new_2, act_type='relu', name='fc_new_2_relu')
num_classes = 81
num_classes = cfg.dataset.NUM_CLASSES
num_reg_classes = 1
cls_score = mx.sym.FullyConnected(name='cls_score', data=fc_new_2_relu, num_hidden=num_classes)
bbox_pred = mx.sym.FullyConnected(name='bbox_pred', data=fc_new_2_relu, num_hidden=num_reg_classes * 4)
Expand Down
4 changes: 2 additions & 2 deletions symbols/faster/resnet_mx_101_e2e_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def get_symbol_rcnn(self, cfg, is_train=True):

fc_new_2 = mx.sym.FullyConnected(name='fc_new_2', data=fc_new_1_relu, num_hidden=1024)
fc_new_2_relu = mx.sym.Activation(data=fc_new_2, act_type='relu', name='fc_new_2_relu')
num_classes = 81
num_classes = cfg.dataset.NUM_CLASSES
num_reg_classes = 1
cls_score = mx.sym.FullyConnected(name='cls_score', data=fc_new_2_relu, num_hidden=num_classes)
bbox_pred = mx.sym.FullyConnected(name='bbox_pred', data=fc_new_2_relu, num_hidden=num_reg_classes * 4)
Expand Down Expand Up @@ -428,7 +428,7 @@ def get_symbol_rcnn(self, cfg, is_train=True):

fc_new_2 = mx.sym.FullyConnected(name='fc_new_2', data=fc_new_1_relu, num_hidden=1024)
fc_new_2_relu = mx.sym.Activation(data=fc_new_2, act_type='relu', name='fc_new_2_relu')
num_classes = 81
num_classes = cfg.dataset.NUM_CLASSES
num_reg_classes = 1
cls_score = mx.sym.FullyConnected(name='cls_score', data=fc_new_2_relu, num_hidden=num_classes)
bbox_pred = mx.sym.FullyConnected(name='bbox_pred', data=fc_new_2_relu, num_hidden=num_reg_classes * 4)
Expand Down
4 changes: 2 additions & 2 deletions symbols/faster/resnet_mx_50_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def get_symbol_rcnn(self, cfg, is_train=True):

fc_new_2 = mx.sym.FullyConnected(name='fc_new_2', data=fc_new_1_relu, num_hidden=1024)
fc_new_2_relu = mx.sym.Activation(data=fc_new_2, act_type='relu', name='fc_new_2_relu')
num_classes = 81
num_classes = cfg.dataset.NUM_CLASSES
num_reg_classes = 1
cls_score = mx.sym.FullyConnected(name='cls_score', data=fc_new_2_relu, num_hidden=num_classes)
bbox_pred = mx.sym.FullyConnected(name='bbox_pred', data=fc_new_2_relu, num_hidden=num_reg_classes * 4)
Expand Down Expand Up @@ -354,7 +354,7 @@ def get_symbol_rcnn(self, cfg, is_train=True):

fc_new_2 = mx.sym.FullyConnected(name='fc_new_2', data=fc_new_1_relu, num_hidden=1024)
fc_new_2_relu = mx.sym.Activation(data=fc_new_2, act_type='relu', name='fc_new_2_relu')
num_classes = 81
num_classes = cfg.dataset.NUM_CLASSES
num_reg_classes = 1
cls_score = mx.sym.FullyConnected(name='cls_score', data=fc_new_2_relu, num_hidden=num_classes)
bbox_pred = mx.sym.FullyConnected(name='bbox_pred', data=fc_new_2_relu, num_hidden=num_reg_classes * 4)
Expand Down

0 comments on commit 0117d7a

Please sign in to comment.