Skip to content

Commit

Permalink
Add server group tests (voxel51#4257)
Browse files Browse the repository at this point in the history
* add server group tests, lint

* fix test

* check for slices

* use not any

* lint docstring

* slice is still needed

* update tests
  • Loading branch information
benjaminpkane authored Apr 18, 2024
1 parent 78d4885 commit 35959eb
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 33 deletions.
1 change: 1 addition & 0 deletions fiftyone/server/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
| `voxel51.com <https://voxel51.com/>`_
|
"""

import strawberry as gql
from strawberry.schema_directive import Location
import typing as t
Expand Down
78 changes: 45 additions & 33 deletions fiftyone/server/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import fiftyone.core.utils as fou
import fiftyone.core.view as fov

from fiftyone.server.aggregations import GroupElementFilter, SampleFilter
from fiftyone.server.filters import GroupElementFilter, SampleFilter
from fiftyone.server.scalars import BSONArray, JSON


Expand Down Expand Up @@ -121,7 +121,7 @@ def run(dataset, stages):

if sample_filter is not None:
if sample_filter.group:
view = _handle_group_filter(dataset, view, sample_filter.group)
view = handle_group_filter(dataset, view, sample_filter.group)

elif sample_filter.id:
view = fov.make_optimized_select_view(view, sample_filter.id)
Expand Down Expand Up @@ -214,43 +214,30 @@ def extend_view(view, extended_stages):
return view


def _add_labels_tags_counts(view):
view = view.set_field(_LABEL_TAGS, [], _allow_missing=True)

for path, field in foc._iter_label_fields(view):
if isinstance(field, fof.ListField) or (
isinstance(field, fof.EmbeddedDocumentField)
and issubclass(field.document_type, fol._HasLabelList)
):
if path.startswith(view._FRAMES_PREFIX):
add_tags = _add_frame_labels_tags
else:
add_tags = _add_labels_tags
else:
if path.startswith(view._FRAMES_PREFIX):
add_tags = _add_frame_label_tags
else:
add_tags = _add_label_tags

view = add_tags(path, field, view)

view = _count_list_items(_LABEL_TAGS, view)

return view


def _handle_group_filter(
def handle_group_filter(
dataset: fod.Dataset,
view: foc.SampleCollection,
filter: GroupElementFilter,
):
) -> fov.DatasetView:
"""Handle a group filter for App view requests.
Args:
dataset: the :class:`fiftyone.core.dataset.Dataset`
view: the base :class:`fiftyone.core.collections.SampleCollection`
filter: the :class:`fiftyone.server.aggregations.GroupElementFilter`
Returns:
a :class:`fiftyone.core.view.DatasetView` with a group or slice
selection
"""
stages = view._stages
unselected = all(
not isinstance(stage, fosg.SelectGroupSlices) for stage in stages
)
group_field = dataset.group_field

unselected = not any(
isinstance(stage, fosg.SelectGroupSlices) for stage in stages
)
if unselected and filter.slice:
# flatten the collection if the view has no slice selection
# flatten the collection if the view has no slice(s) selected
view = dataset.select_group_slices(_force_mixed=True)

if filter.id:
Expand Down Expand Up @@ -278,6 +265,31 @@ def _handle_group_filter(
return view


def _add_labels_tags_counts(view):
view = view.set_field(_LABEL_TAGS, [], _allow_missing=True)

for path, field in foc._iter_label_fields(view):
if isinstance(field, fof.ListField) or (
isinstance(field, fof.EmbeddedDocumentField)
and issubclass(field.document_type, fol._HasLabelList)
):
if path.startswith(view._FRAMES_PREFIX):
add_tags = _add_frame_labels_tags
else:
add_tags = _add_labels_tags
else:
if path.startswith(view._FRAMES_PREFIX):
add_tags = _add_frame_label_tags
else:
add_tags = _add_label_tags

view = add_tags(path, field, view)

view = _count_list_items(_LABEL_TAGS, view)

return view


def _project_pagination_paths(view: foc.SampleCollection):
schema = view.get_field_schema(flat=True)
frame_schema = view.get_frame_field_schema(flat=True)
Expand Down
64 changes: 64 additions & 0 deletions tests/unittests/server_group_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""
FiftyOne Server group tests.
| Copyright 2017-2024, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""

import unittest

from bson import ObjectId

import fiftyone as fo
from fiftyone import ViewExpression as F
from fiftyone.server.aggregations import GroupElementFilter
import fiftyone.server.view as fosv

from decorators import drop_datasets


class ServerGroupTests(unittest.TestCase):
@drop_datasets
def test_manual_group_slice(self):
dataset: fo.Dataset = fo.Dataset()
group = fo.Group()
image = fo.Sample(
filepath="image.png",
group=group.element("image"),
label=fo.Classification(label="label"),
)
dataset.add_sample(image)
expr = F("label") == "label"
filtered = dataset.filter_labels("label", F("label") == "label")

view = fosv.handle_group_filter(
dataset,
filtered,
GroupElementFilter(slice="image", slices=["image"]),
)
self.assertEqual(
view._all_stages,
[
fo.SelectGroupSlices(_force_mixed=True),
fo.FilterLabels("label", expr),
fo.Match({"group.name": {"$in": ["image"]}}),
],
)

view = fosv.handle_group_filter(
dataset,
filtered,
GroupElementFilter(
id=image.group.id, slice="image", slices=["image"]
),
)
self.assertEqual(
view._all_stages,
[
fo.SelectGroupSlices(_force_mixed=True),
fo.Match({"group._id": {"$in": [ObjectId(image.group.id)]}}),
fo.FilterLabels("label", expr),
fo.Match({"group.name": {"$in": ["image"]}}),
],
)

0 comments on commit 35959eb

Please sign in to comment.