Skip to content

Commit

Permalink
Fix bug when loading group ids in CVAT video tasks (#1917)
Browse files Browse the repository at this point in the history
  • Loading branch information
ehofesmann authored Jul 5, 2022
1 parent 8a2ef7c commit 56a7aa5
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 4 deletions.
25 changes: 22 additions & 3 deletions fiftyone/utils/cvat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4450,6 +4450,7 @@ def download_annotations(self, results):
for track_index, track in enumerate(tracks, 1):
label_id = track["label_id"]
shapes = track["shapes"]
track_group_id = track.get("group", None)
for shape in shapes:
shape["label_id"] = label_id

Expand All @@ -4468,6 +4469,7 @@ def download_annotations(self, results):
ignore_types,
assigned_scalar_attrs=scalar_attrs,
track_index=track_index,
track_group_id=track_group_id,
immutable_attrs=immutable_attrs,
occluded_attrs=_occluded_attrs,
group_id_attrs=_group_id_attrs,
Expand Down Expand Up @@ -5311,6 +5313,7 @@ def _parse_shapes_tags(
ignore_types,
assigned_scalar_attrs=False,
track_index=None,
track_group_id=None,
immutable_attrs=None,
occluded_attrs=None,
group_id_attrs=None,
Expand Down Expand Up @@ -5350,6 +5353,7 @@ def _parse_shapes_tags(
ignore_types,
assigned_scalar_attrs=assigned_scalar_attrs,
track_index=track_index,
track_group_id=track_group_id,
immutable_attrs=immutable_attrs,
occluded_attrs=occluded_attrs,
group_id_attrs=group_id_attrs,
Expand Down Expand Up @@ -5378,6 +5382,7 @@ def _parse_shapes_tags(
ignore_types,
assigned_scalar_attrs=assigned_scalar_attrs,
track_index=track_index,
track_group_id=track_group_id,
immutable_attrs=immutable_attrs,
occluded_attrs=occluded_attrs,
group_id_attrs=group_id_attrs,
Expand All @@ -5401,6 +5406,7 @@ def _parse_annotation(
ignore_types,
assigned_scalar_attrs=False,
track_index=None,
track_group_id=None,
immutable_attrs=None,
occluded_attrs=None,
group_id_attrs=None,
Expand Down Expand Up @@ -5444,6 +5450,7 @@ def _parse_annotation(
immutable_attrs=immutable_attrs,
occluded_attrs=occluded_attrs,
group_id_attrs=group_id_attrs,
group_id=track_group_id,
)

# Non-keyframe annotations were interpolated from keyframes but
Expand Down Expand Up @@ -6550,6 +6557,8 @@ class CVATShape(CVATLabel):
corresponding attribute linked to the CVAT occlusion widget, if any
group_id_attrs (None): a dictonary mapping class names to the
corresponding attribute linked to the CVAT group id, if any
group_id (None): an optional group id value for this shape when it
cannot be parsed from the label dict
"""

def __init__(
Expand All @@ -6563,6 +6572,7 @@ def __init__(
immutable_attrs=None,
occluded_attrs=None,
group_id_attrs=None,
group_id=None,
):
super().__init__(
label_dict,
Expand All @@ -6583,13 +6593,22 @@ def __init__(
self._parse_named_attribute(label_dict, "occluded", occluded_attrs)

# Parse group id attribute, if necessary
self._parse_named_attribute(label_dict, "group", group_id_attrs)
self._parse_named_attribute(
label_dict, "group", group_id_attrs, default=group_id
)

def _parse_named_attribute(self, label_dict, attr_key, attrs):
def _parse_named_attribute(
self, label_dict, attr_key, attrs, default=None
):
if attrs is not None:
attr_name = attrs.get(self.label, None)
if attr_name is not None:
self.attributes[attr_name] = label_dict[attr_key]
if attr_key in label_dict:
attr_value = label_dict[attr_key]
else:
attr_value = default

self.attributes[attr_name] = attr_value

def _to_pairs_of_points(self, points):
reshaped_points = np.reshape(points, (-1, 2))
Expand Down
71 changes: 70 additions & 1 deletion tests/intensive/cvat_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,7 +1027,7 @@ def test_dest_field(self):
),
)

def test_group_id(self):
def test_group_id_image(self):
dataset = (
foz.load_zoo_dataset("quickstart", max_samples=2)
.select_fields("ground_truth")
Expand Down Expand Up @@ -1083,6 +1083,75 @@ def test_group_id(self):
any([gid == test_group_id for gid in id_group_map.values()])
)

def test_group_id_video(self):
dataset = (
foz.load_zoo_dataset("quickstart-video", max_samples=1)
.select_fields("frames.detections")
.clone()
)
group_id_attr_name = "group_id_attr"

prev_ids = dataset.values(
"frames.detections.detections.id", unwind=True
)

# Set group id attribute
sample = dataset.first()
for det in sample.frames[1].detections.detections:
det["group_id_attr"] = 1
sample.save()

anno_key = "cvat_group_ids"

# Populate a new `group_id` attribute on the existing `ground_truth` labels
label_schema = {
"frames.detections": {
"attributes": {
group_id_attr_name: {
"type": "group_id",
}
}
}
}

results = dataset.annotate(
anno_key, label_schema=label_schema, backend="cvat"
)

api = results.connect_to_api()
task_id = results.task_ids[0]

test_group_id = 2
_create_annotation(
api,
task_id,
track=(0, 1),
group_id=test_group_id,
)

dataset.load_annotations(anno_key, cleanup=True)

new_id = list(
set(dataset.values("frames.detections.detections.id", unwind=True))
- set(prev_ids)
)[0]

id_group_map = dict(
zip(
*dataset.values(
[
"frames.detections.detections.id",
"frames.detections.detections.%s" % group_id_attr_name,
],
unwind=True,
)
)
)
self.assertEqual(id_group_map.pop(new_id), test_group_id)
self.assertFalse(
any([gid == test_group_id for gid in id_group_map.values()])
)


if __name__ == "__main__":
fo.config.show_progress_bars = False
Expand Down

0 comments on commit 56a7aa5

Please sign in to comment.