Skip to content

Commit

Permalink
Fix pinning sidebar filters for group datasets (#4097)
Browse files Browse the repository at this point in the history
* use match for slice and group selection

* fix e2e entry count

* add hidden select group slices kwargs to serialization

* update server view test

* omit group slice when creating saved view

* remove print

* clean up get view calls

* cleanup
  • Loading branch information
benjaminpkane authored Apr 4, 2024
1 parent d4615a6 commit 8c8505d
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 49 deletions.
26 changes: 21 additions & 5 deletions fiftyone/core/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
| `voxel51.com <https://voxel51.com/>`_
|
"""

from collections import defaultdict, OrderedDict
import contextlib
from copy import deepcopy
Expand Down Expand Up @@ -941,8 +942,8 @@ def _get_meta_filtered_fields(sample_collection, meta_filter, frames=False):
base, leaf = key.split(".", 1)
info_filter[leaf] = val

matcher = (
lambda q, v: q.lower() in v.lower()
matcher = lambda q, v: (
q.lower() in v.lower()
if isinstance(v, str) and isinstance(q, str)
else (
q.lower() in str(v).lower()
Expand All @@ -951,8 +952,8 @@ def _get_meta_filtered_fields(sample_collection, meta_filter, frames=False):
)
)

type_matcher = (
lambda query, field: (
type_matcher = lambda query, field: (
(
type(field.document_type).__name__ == query
or field.document_type.__name__ == query
if isinstance(field, EmbeddedDocumentField)
Expand Down Expand Up @@ -4711,7 +4712,12 @@ def _get_group_media_types(self, sample_collection):
}

def _kwargs(self):
return [["slices", self._slices], ["media_type", self._media_type]]
return [
["slices", self._slices],
["media_type", self._media_type],
["_allow_mixed", self._allow_mixed],
["_force_mixed", self._force_mixed],
]

@classmethod
def _params(cls):
Expand All @@ -4728,6 +4734,16 @@ def _params(cls):
"placeholder": "media_type (default=None)",
"default": "None",
},
{
"name": "_allow_mixed",
"type": "NoneType|bool",
"default": "None",
},
{
"name": "_force_mixed",
"type": "NoneType|bool",
"default": "None",
},
]


Expand Down
43 changes: 25 additions & 18 deletions fiftyone/server/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,25 +114,12 @@ async def aggregate_resolver(
if not form.dataset:
raise ValueError("Aggregate form missing dataset")

view = await fosv.get_view(
form.dataset,
view_name=form.view_name or None,
stages=form.view,
filters=form.filters,
extended_stages=form.extended_stages,
sample_filter=SampleFilter(
group=(
GroupElementFilter(
id=form.group_id, slice=form.slice, slices=form.slices
)
if not form.sample_ids
else None
)
),
awaitable=True,
)
view = await _load_view(form, form.slices)

slice_view = view if form.mixed and "" in form.paths else None
slice_view = None

if form.mixed and "" in form.paths:
slice_view = await _load_view(form, [form.slice])

if form.sample_ids:
view = fov.make_optimized_select_view(view, form.sample_ids)
Expand Down Expand Up @@ -202,6 +189,26 @@ async def aggregate_resolver(
}


async def _load_view(form: AggregationForm, slices: t.List[str]):
    """Loads the view described by an aggregation form.

    Args:
        form: the :class:`AggregationForm` supplying the dataset, saved view
            name, view stages, filters, extended stages, group ID, and
            current slice
        slices: the group slices to pass to the group element filter

    Returns:
        the awaited result of ``fosv.get_view()``
    """
    # No group element filter is built when explicit sample IDs were provided
    if form.sample_ids:
        group_filter = None
    else:
        group_filter = GroupElementFilter(
            id=form.group_id, slice=form.slice, slices=slices
        )

    # Normalize an empty view name to None
    saved_view_name = form.view_name or None

    return await fosv.get_view(
        form.dataset,
        view_name=saved_view_name,
        stages=form.view,
        filters=form.filters,
        extended_stages=form.extended_stages,
        sample_filter=SampleFilter(group=group_filter),
        awaitable=True,
    )


def _resolve_path_aggregation(
path: str, view: foc.SampleCollection
) -> AggregateResult:
Expand Down
9 changes: 0 additions & 9 deletions fiftyone/server/mutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,15 +292,6 @@ async def create_saved_view(
stages=view_stages if view_stages else None,
filters=form.filters if form else None,
extended_stages=form.extended if form else None,
sample_filter=(
SampleFilter(
group=GroupElementFilter(
slice=form.slice, slices=[form.slice]
)
)
if form.slice
else None
),
awaitable=True,
)

Expand Down
59 changes: 43 additions & 16 deletions fiftyone/server/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def get_view(
a :class:`fiftyone.core.view.DatasetView`
"""

def run(dataset):
def run(dataset, stages):
if isinstance(dataset, str):
dataset = fod.load_dataset(dataset)

Expand All @@ -121,19 +121,7 @@ def run(dataset):

if sample_filter is not None:
if sample_filter.group:
if sample_filter.group.slice:
view.group_slice = sample_filter.group.slice

if sample_filter.group.id:
view = fov.make_optimized_select_view(
view, sample_filter.group.id, groups=True
)

if sample_filter.group.slices:
view = view.select_group_slices(
sample_filter.group.slices,
_force_mixed=True,
)
view = _handle_group_filter(dataset, view, sample_filter.group)

elif sample_filter.id:
view = fov.make_optimized_select_view(view, sample_filter.id)
Expand All @@ -149,9 +137,9 @@ def run(dataset):
return view

if awaitable:
return fou.run_sync_task(run, dataset)
return fou.run_sync_task(run, dataset, stages)

return run(dataset)
return run(dataset, stages)


def get_extended_view(
Expand Down Expand Up @@ -251,6 +239,45 @@ def _add_labels_tags_counts(view):
return view


def _handle_group_filter(
    dataset: fod.Dataset,
    view: foc.SampleCollection,
    filter: GroupElementFilter,
):
    """Applies a group element filter to the given view.

    When the view contains no :class:`SelectGroupSlices` stage and a slice is
    requested, the view is rebuilt from ``dataset``: the dataset is flattened
    via ``select_group_slices()``, the group ID match (if any) is applied, and
    then the view's original stages are re-added on top. Otherwise, the
    requested slice and group ID are applied to the existing view directly.

    In both cases, any requested ``slices`` are then selected via a ``match``
    on the group field's ``name``.

    Args:
        dataset: the dataset, which provides ``group_field`` and is the
            collection that is flattened when the view must be rebuilt
        view: the view to filter
        filter: the :class:`GroupElementFilter` whose ``slice``, ``id``, and
            ``slices`` attributes are read

    Returns:
        the filtered view
    """
    # capture the view's stages up front; they are replayed below if the view
    # is rebuilt from the dataset
    stages = view._stages

    # True when none of the view's stages is a SelectGroupSlices stage
    unselected = all(
        not isinstance(stage, fosg.SelectGroupSlices) for stage in stages
    )
    group_field = dataset.group_field
    if unselected and filter.slice:
        # flatten the collection if the view has no slice selection
        # NOTE(review): `_force_mixed` is a hidden kwarg of SelectGroupSlices;
        # presumably it bypasses media type validation -- confirm
        view = dataset.select_group_slices(_force_mixed=True)

        if filter.id:
            # use 'match' to select a group by 'id'
            view = view.match(
                {group_field + "._id": {"$in": [ObjectId(filter.id)]}}
            )

        for stage in stages:
            # add stages after flattening and group match
            view = view._add_view_stage(stage, validate=False)

    else:
        # the view already has a slice selection, or no slice was requested:
        # operate on the existing view rather than rebuilding it
        if filter.slice:
            view.group_slice = filter.slice

        if filter.id:
            view = fov.make_optimized_select_view(view, filter.id, groups=True)

    if filter.slices:
        # use 'match' to select requested slices, and avoid media type
        # validation
        view = view.match({group_field + ".name": {"$in": filter.slices}})

    return view


def _project_pagination_paths(view: foc.SampleCollection):
schema = view.get_field_schema(flat=True)
frame_schema = view.get_frame_field_schema(flat=True)
Expand Down
3 changes: 2 additions & 1 deletion tests/unittests/view_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4269,8 +4269,9 @@ def test_make_optimized_select_view_group_dataset(self):
optimized_view = fov.make_optimized_select_view(
dataset, sample_ids[0], flatten=True
)

expected_stages = [
fosg.SelectGroupSlices(),
fosg.SelectGroupSlices(_allow_mixed=True),
fosg.Select(sample_ids[0]),
]
self.assertEqual(optimized_view._all_stages, expected_stages)
Expand Down

0 comments on commit 8c8505d

Please sign in to comment.