Skip to content

Commit

Permalink
visdom multi-obj segmentation + bbox support
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthieu Paul committed Sep 18, 2022
1 parent b4a3ec9 commit f7d4964
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 7 deletions.
6 changes: 5 additions & 1 deletion pytracking/evaluation/multi_object_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,12 @@ def track(self, image, info: dict = None) -> dict:
return out_merged

def visdom_draw_tracking(self, image, box, segmentation):
if isinstance(box, (OrderedDict, dict)):
if box is None:
box = []
elif isinstance(box, (OrderedDict, dict)):
box = [v for k, v in box.items()]
elif isinstance(box, list):
box = [list(col)[0] for col in zip(*[d.values() for d in box])]
else:
box = (box,)
if segmentation is None:
Expand Down
31 changes: 29 additions & 2 deletions pytracking/evaluation/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,12 +190,27 @@ def _store_outputs(tracker_out: dict, defaults=None):
prev_output = OrderedDict(out)

init_default = {'target_bbox': init_info.get('init_bbox'),
'clf_target_bbox': init_info.get('init_bbox'),
'time': time.time() - start_time,
'segmentation': init_info.get('init_mask'),
'object_presence_score': 1.}

_store_outputs(out, init_default)

segmentation = out['segmentation'] if 'segmentation' in out else None
bboxes = [init_default['target_bbox']]
if 'clf_target_bbox' in out:
bboxes.append(out['clf_target_bbox'])
if 'clf_search_area' in out:
bboxes.append(out['clf_search_area'])
if 'segm_search_area' in out:
bboxes.append(out['segm_search_area'])

if self.visdom is not None:
tracker.visdom_draw_tracking(image, bboxes, segmentation)
elif tracker.params.visualization:
self.visualize(image, bboxes, segmentation)

for frame_num, frame_path in enumerate(seq.frames[1:], start=1):
while True:
if not self.pause_mode:
Expand All @@ -218,15 +233,25 @@ def _store_outputs(tracker_out: dict, defaults=None):
_store_outputs(out, {'time': time.time() - start_time})

segmentation = out['segmentation'] if 'segmentation' in out else None

bboxes = [out['target_bbox']]
if 'clf_target_bbox' in out:
bboxes.append(out['clf_target_bbox'])
if 'clf_search_area' in out:
bboxes.append(out['clf_search_area'])
if 'segm_search_area' in out:
bboxes.append(out['segm_search_area'])

if self.visdom is not None:
tracker.visdom_draw_tracking(image, out['target_bbox'], segmentation)
tracker.visdom_draw_tracking(image, bboxes, segmentation)
elif tracker.params.visualization:
self.visualize(image, out['target_bbox'], segmentation)
self.visualize(image, bboxes, segmentation)

for key in ['target_bbox', 'segmentation']:
if key in output and len(output[key]) <= 1:
output.pop(key)

# next two lines are needed for oxuva output format.
output['image_shape'] = image.shape[:2]
output['object_presence_score_threshold'] = tracker.params.get('object_presence_score_threshold', 0.55)

Expand Down Expand Up @@ -674,6 +699,8 @@ def visualize(self, image, state, segmentation=None):

if isinstance(state, (OrderedDict, dict)):
boxes = [v for k, v in state.items()]
elif isinstance(state, list):
boxes = state
else:
boxes = (state,)

Expand Down
6 changes: 5 additions & 1 deletion pytracking/tracker/base/basetracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,12 @@ def track(self, image, info: dict = None) -> dict:


def visdom_draw_tracking(self, image, box, segmentation=None):
if isinstance(box, OrderedDict):
if box is None:
box = []
elif isinstance(box, OrderedDict):
box = [v for k, v in box.items()]
elif isinstance(box, list):
box = box
else:
box = (box,)
if segmentation is None:
Expand Down
11 changes: 8 additions & 3 deletions pytracking/utils/visdom.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import visdom
import visdom.server
from pytracking.features.preprocessing import numpy_to_torch
from pytracking.utils.plotting import show_image_with_boxes, overlay_mask
from pytracking.utils.plotting import show_image_with_boxes, overlay_mask, _pascal_color_map
import cv2
import torch
import copy
Expand Down Expand Up @@ -322,11 +322,16 @@ def draw_data(self):

boxes = [resize_factor * b.clone() for b in self.raw_data[1]]

colors = np.asarray(_pascal_color_map(), dtype=np.uint8)

for i, disp_rect in enumerate(boxes):
color = ((255*((i%3)>0)), 255*((i+1)%2), (255*(i%5))//4)
# Changed the color to match the _pascal_color_map used in overlay_mask
# color = ((255*((i%3)>0)), 255*((i+1)%2), (255*(i%5))//4)
color = colors[i+1].tolist()
cv2.rectangle(disp_image,
(int(disp_rect[0]), int(disp_rect[1])),
(int(disp_rect[0] + disp_rect[2]), int(disp_rect[1] + disp_rect[3])), color, 2)
(int(disp_rect[0] + disp_rect[2]), int(disp_rect[1] + disp_rect[3])), color, 1)

for i, mask in enumerate(self.raw_data[2], 1):
disp_image = overlay_mask(disp_image, mask * i)
disp_image = numpy_to_torch(disp_image).squeeze(0)
Expand Down

0 comments on commit f7d4964

Please sign in to comment.