diff --git a/fiftyone/server/filters.py b/fiftyone/server/filters.py index 76ac714926..68bec7a6c8 100644 --- a/fiftyone/server/filters.py +++ b/fiftyone/server/filters.py @@ -5,6 +5,7 @@ | `voxel51.com `_ | """ + import strawberry as gql from strawberry.schema_directive import Location import typing as t diff --git a/fiftyone/server/view.py b/fiftyone/server/view.py index 2b2f2eaae1..4f2e94e162 100644 --- a/fiftyone/server/view.py +++ b/fiftyone/server/view.py @@ -21,7 +21,7 @@ import fiftyone.core.utils as fou import fiftyone.core.view as fov -from fiftyone.server.aggregations import GroupElementFilter, SampleFilter +from fiftyone.server.filters import GroupElementFilter, SampleFilter from fiftyone.server.scalars import BSONArray, JSON @@ -121,7 +121,7 @@ def run(dataset, stages): if sample_filter is not None: if sample_filter.group: - view = _handle_group_filter(dataset, view, sample_filter.group) + view = handle_group_filter(dataset, view, sample_filter.group) elif sample_filter.id: view = fov.make_optimized_select_view(view, sample_filter.id) @@ -214,43 +214,30 @@ def extend_view(view, extended_stages): return view -def _add_labels_tags_counts(view): - view = view.set_field(_LABEL_TAGS, [], _allow_missing=True) - - for path, field in foc._iter_label_fields(view): - if isinstance(field, fof.ListField) or ( - isinstance(field, fof.EmbeddedDocumentField) - and issubclass(field.document_type, fol._HasLabelList) - ): - if path.startswith(view._FRAMES_PREFIX): - add_tags = _add_frame_labels_tags - else: - add_tags = _add_labels_tags - else: - if path.startswith(view._FRAMES_PREFIX): - add_tags = _add_frame_label_tags - else: - add_tags = _add_label_tags - - view = add_tags(path, field, view) - - view = _count_list_items(_LABEL_TAGS, view) - - return view - - -def _handle_group_filter( +def handle_group_filter( dataset: fod.Dataset, view: foc.SampleCollection, filter: GroupElementFilter, -): +) -> fov.DatasetView: + """Handle a group filter for App view requests. + + Args: + dataset: the :class:`fiftyone.core.dataset.Dataset` + view: the base :class:`fiftyone.core.collections.SampleCollection` + filter: the :class:`fiftyone.server.aggregations.GroupElementFilter` + + Returns: + a :class:`fiftyone.core.view.DatasetView` with a group or slice + selection + """ stages = view._stages - unselected = all( - not isinstance(stage, fosg.SelectGroupSlices) for stage in stages - ) group_field = dataset.group_field + + unselected = not any( + isinstance(stage, fosg.SelectGroupSlices) for stage in stages + ) if unselected and filter.slice: - # flatten the collection if the view has no slice selection + # flatten the collection if the view has no slice(s) selected view = dataset.select_group_slices(_force_mixed=True) if filter.id: @@ -278,6 +265,31 @@ def _handle_group_filter( return view +def _add_labels_tags_counts(view): + view = view.set_field(_LABEL_TAGS, [], _allow_missing=True) + + for path, field in foc._iter_label_fields(view): + if isinstance(field, fof.ListField) or ( + isinstance(field, fof.EmbeddedDocumentField) + and issubclass(field.document_type, fol._HasLabelList) + ): + if path.startswith(view._FRAMES_PREFIX): + add_tags = _add_frame_labels_tags + else: + add_tags = _add_labels_tags + else: + if path.startswith(view._FRAMES_PREFIX): + add_tags = _add_frame_label_tags + else: + add_tags = _add_label_tags + + view = add_tags(path, field, view) + + view = _count_list_items(_LABEL_TAGS, view) + + return view + + def _project_pagination_paths(view: foc.SampleCollection): schema = view.get_field_schema(flat=True) frame_schema = view.get_frame_field_schema(flat=True) diff --git a/tests/unittests/server_group_tests.py b/tests/unittests/server_group_tests.py new file mode 100644 index 0000000000..b88feec461 --- /dev/null +++ b/tests/unittests/server_group_tests.py @@ -0,0 +1,64 @@ +""" +FiftyOne Server group tests. + +| Copyright 2017-2024, Voxel51, Inc. +| `voxel51.com `_ +| +""" + +import unittest + +from bson import ObjectId + +import fiftyone as fo +from fiftyone import ViewExpression as F +from fiftyone.server.aggregations import GroupElementFilter +import fiftyone.server.view as fosv + +from decorators import drop_datasets + + +class ServerGroupTests(unittest.TestCase): + @drop_datasets + def test_manual_group_slice(self): + dataset: fo.Dataset = fo.Dataset() + group = fo.Group() + image = fo.Sample( + filepath="image.png", + group=group.element("image"), + label=fo.Classification(label="label"), + ) + dataset.add_sample(image) + expr = F("label") == "label" + filtered = dataset.filter_labels("label", F("label") == "label") + + view = fosv.handle_group_filter( + dataset, + filtered, + GroupElementFilter(slice="image", slices=["image"]), + ) + self.assertEqual( + view._all_stages, + [ + fo.SelectGroupSlices(_force_mixed=True), + fo.FilterLabels("label", expr), + fo.Match({"group.name": {"$in": ["image"]}}), + ], + ) + + view = fosv.handle_group_filter( + dataset, + filtered, + GroupElementFilter( + id=image.group.id, slice="image", slices=["image"] + ), + ) + self.assertEqual( + view._all_stages, + [ + fo.SelectGroupSlices(_force_mixed=True), + fo.Match({"group._id": {"$in": [ObjectId(image.group.id)]}}), + fo.FilterLabels("label", expr), + fo.Match({"group.name": {"$in": ["image"]}}), + ], + )