diff --git a/fiftyone/server/filters.py b/fiftyone/server/filters.py
index 76ac714926..68bec7a6c8 100644
--- a/fiftyone/server/filters.py
+++ b/fiftyone/server/filters.py
@@ -5,6 +5,7 @@
| `voxel51.com `_
|
"""
+
import strawberry as gql
from strawberry.schema_directive import Location
import typing as t
diff --git a/fiftyone/server/view.py b/fiftyone/server/view.py
index 2b2f2eaae1..4f2e94e162 100644
--- a/fiftyone/server/view.py
+++ b/fiftyone/server/view.py
@@ -21,7 +21,7 @@
import fiftyone.core.utils as fou
import fiftyone.core.view as fov
-from fiftyone.server.aggregations import GroupElementFilter, SampleFilter
+from fiftyone.server.filters import GroupElementFilter, SampleFilter
from fiftyone.server.scalars import BSONArray, JSON
@@ -121,7 +121,7 @@ def run(dataset, stages):
if sample_filter is not None:
if sample_filter.group:
- view = _handle_group_filter(dataset, view, sample_filter.group)
+ view = handle_group_filter(dataset, view, sample_filter.group)
elif sample_filter.id:
view = fov.make_optimized_select_view(view, sample_filter.id)
@@ -214,43 +214,30 @@ def extend_view(view, extended_stages):
return view
-def _add_labels_tags_counts(view):
- view = view.set_field(_LABEL_TAGS, [], _allow_missing=True)
-
- for path, field in foc._iter_label_fields(view):
- if isinstance(field, fof.ListField) or (
- isinstance(field, fof.EmbeddedDocumentField)
- and issubclass(field.document_type, fol._HasLabelList)
- ):
- if path.startswith(view._FRAMES_PREFIX):
- add_tags = _add_frame_labels_tags
- else:
- add_tags = _add_labels_tags
- else:
- if path.startswith(view._FRAMES_PREFIX):
- add_tags = _add_frame_label_tags
- else:
- add_tags = _add_label_tags
-
- view = add_tags(path, field, view)
-
- view = _count_list_items(_LABEL_TAGS, view)
-
- return view
-
-
-def _handle_group_filter(
+def handle_group_filter(
dataset: fod.Dataset,
view: foc.SampleCollection,
filter: GroupElementFilter,
-):
+) -> fov.DatasetView:
+ """Handle a group filter for App view requests.
+
+ Args:
+ dataset: the :class:`fiftyone.core.dataset.Dataset`
+ view: the base :class:`fiftyone.core.collections.SampleCollection`
+ filter: the :class:`fiftyone.server.aggregations.GroupElementFilter`
+
+ Returns:
+ a :class:`fiftyone.core.view.DatasetView` with a group or slice
+ selection
+ """
stages = view._stages
- unselected = all(
- not isinstance(stage, fosg.SelectGroupSlices) for stage in stages
- )
group_field = dataset.group_field
+
+ unselected = not any(
+ isinstance(stage, fosg.SelectGroupSlices) for stage in stages
+ )
if unselected and filter.slice:
- # flatten the collection if the view has no slice selection
+ # flatten the collection if the view has no slice(s) selected
view = dataset.select_group_slices(_force_mixed=True)
if filter.id:
@@ -278,6 +265,31 @@ def _handle_group_filter(
return view
+def _add_labels_tags_counts(view):
+ view = view.set_field(_LABEL_TAGS, [], _allow_missing=True)
+
+ for path, field in foc._iter_label_fields(view):
+ if isinstance(field, fof.ListField) or (
+ isinstance(field, fof.EmbeddedDocumentField)
+ and issubclass(field.document_type, fol._HasLabelList)
+ ):
+ if path.startswith(view._FRAMES_PREFIX):
+ add_tags = _add_frame_labels_tags
+ else:
+ add_tags = _add_labels_tags
+ else:
+ if path.startswith(view._FRAMES_PREFIX):
+ add_tags = _add_frame_label_tags
+ else:
+ add_tags = _add_label_tags
+
+ view = add_tags(path, field, view)
+
+ view = _count_list_items(_LABEL_TAGS, view)
+
+ return view
+
+
def _project_pagination_paths(view: foc.SampleCollection):
schema = view.get_field_schema(flat=True)
frame_schema = view.get_frame_field_schema(flat=True)
diff --git a/tests/unittests/server_group_tests.py b/tests/unittests/server_group_tests.py
new file mode 100644
index 0000000000..b88feec461
--- /dev/null
+++ b/tests/unittests/server_group_tests.py
@@ -0,0 +1,64 @@
+"""
+FiftyOne Server group tests.
+
+| Copyright 2017-2024, Voxel51, Inc.
+| `voxel51.com `_
+|
+"""
+
+import unittest
+
+from bson import ObjectId
+
+import fiftyone as fo
+from fiftyone import ViewExpression as F
+from fiftyone.server.aggregations import GroupElementFilter
+import fiftyone.server.view as fosv
+
+from decorators import drop_datasets
+
+
+class ServerGroupTests(unittest.TestCase):
+ @drop_datasets
+ def test_manual_group_slice(self):
+ dataset: fo.Dataset = fo.Dataset()
+ group = fo.Group()
+ image = fo.Sample(
+ filepath="image.png",
+ group=group.element("image"),
+ label=fo.Classification(label="label"),
+ )
+ dataset.add_sample(image)
+ expr = F("label") == "label"
+ filtered = dataset.filter_labels("label", F("label") == "label")
+
+ view = fosv.handle_group_filter(
+ dataset,
+ filtered,
+ GroupElementFilter(slice="image", slices=["image"]),
+ )
+ self.assertEqual(
+ view._all_stages,
+ [
+ fo.SelectGroupSlices(_force_mixed=True),
+ fo.FilterLabels("label", expr),
+ fo.Match({"group.name": {"$in": ["image"]}}),
+ ],
+ )
+
+ view = fosv.handle_group_filter(
+ dataset,
+ filtered,
+ GroupElementFilter(
+ id=image.group.id, slice="image", slices=["image"]
+ ),
+ )
+ self.assertEqual(
+ view._all_stages,
+ [
+ fo.SelectGroupSlices(_force_mixed=True),
+ fo.Match({"group._id": {"$in": [ObjectId(image.group.id)]}}),
+ fo.FilterLabels("label", expr),
+ fo.Match({"group.name": {"$in": ["image"]}}),
+ ],
+ )