Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding utility for selecting objects from datasets #616

Merged
merged 7 commits into from
Oct 20, 2020
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 183 additions & 0 deletions fiftyone/utils/selection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
"""
Utilities for selecting content from datasets.

| Copyright 2017-2020, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
from collections import defaultdict
import warnings

from bson import ObjectId

import fiftyone.core.fields as fof
import fiftyone.core.labels as fol
from fiftyone.core.expressions import ViewField as F


def select_samples(sample_collection, sample_ids):
"""Selects the specified samples from the collection.

Args:
sample_collection: a
:class:`fiftyone.core.collections.SampleCollection`
sample_ids: an iterable of sample IDs to select

Returns:
a :class:`fiftyone.core.view.DatasetView` containing only the specified
samples
"""
return sample_collection.select(sample_ids)


def exclude_samples(sample_collection, sample_ids):
"""Excludes the specified samples from the collection.

Args:
sample_collection: a
:class:`fiftyone.core.collections.SampleCollection`
sample_ids: an iterable of sample IDs to exclude

Returns:
a :class:`fiftyone.core.view.DatasetView` that excludes the specified
samples
"""
return sample_collection.exclude(sample_ids)


def select_objects(sample_collection, objects):
"""Selects the specified objects from the sample collection.

The returned view will omit samples, sample fields, and individual objects
that do not appear in the provided ``objects`` argument, which should have
the following format::

[
{
"sample_id": "5f8d254a27ad06815ab89df4",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this also need frame_number?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't add that yet because filtering frame-level labels is not yet supported. I believe @benjaminpkane plans to work on that

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I plan on starting work on this today.

"field": "ground_truth",
"object_id": "5f8d254a27ad06815ab89df3",
},
{
"sample_id": "5f8d255e27ad06815ab93bf8",
"field": "ground_truth",
"object_id": "5f8d255e27ad06815ab93bf6",
},
...
]

Args:
sample_collection: a
:class:`fiftyone.core.collections.SampleCollection`
objects: a list of dicts defining the objects to select

Returns:
a :class:`fiftyone.core.view.DatasetView` containing only the specified
objects
"""
sample_ids, object_ids = _parse_objects(objects)

label_schema = sample_collection.get_field_schema(
ftype=fof.EmbeddedDocumentField, embedded_doc_type=fol.Label
)

view = sample_collection.select(sample_ids)
view = view.select_fields(list(object_ids.keys()))

for field, object_ids in object_ids.items():
label_filter = F("_id").is_in(object_ids)
view = _apply_label_filter(view, label_schema, field, label_filter)

return view


def exclude_objects(sample_collection, objects):
"""Excludes the specified objects from the sample collection.

The returned view will omit the labels specified in the provided
``objects`` argument, which should have the following format::

[
{
"sample_id": "5f8d254a27ad06815ab89df4",
"field": "ground_truth",
"object_id": "5f8d254a27ad06815ab89df3",
},
{
"sample_id": "5f8d255e27ad06815ab93bf8",
"field": "ground_truth",
"object_id": "5f8d255e27ad06815ab93bf6",
},
...
]

Args:
sample_collection: a
:class:`fiftyone.core.collections.SampleCollection`
objects: a list of dicts defining the objects to exclude

Returns:
a :class:`fiftyone.core.view.DatasetView` that excludes the specified
objects
"""
_, object_ids = _parse_objects(objects)

label_schema = sample_collection.get_field_schema(
ftype=fof.EmbeddedDocumentField, embedded_doc_type=fol.Label
)

view = sample_collection
for field, object_ids in object_ids.items():
label_filter = ~F("_id").is_in(object_ids)
view = _apply_label_filter(view, label_schema, field, label_filter)

return view


def _parse_objects(objects):
sample_ids = set()
object_ids = defaultdict(set)
for obj in objects:
sample_ids.add(obj["sample_id"])
object_ids[obj["field"]].add(ObjectId(obj["object_id"]))

return sample_ids, object_ids


def _apply_label_filter(sample_collection, label_schema, field, label_filter):
if field not in label_schema:
raise ValueError(
"%s '%s' has no label field '%s'"
% (
sample_collection.__class__.__name__,
sample_collection.name,
field,
)
)

field_type = label_schema[field]
label_type = field_type.document_type

if label_type in (
fol.Classification,
fol.Detection,
fol.Polyline,
fol.Keypoint,
):
return sample_collection.filter_field(field, label_filter)

if label_type is fol.Classifications:
return sample_collection.filter_classifications(field, label_filter)

if label_type is fol.Detections:
return sample_collection.filter_detections(field, label_filter)

if label_type is fol.Polylines:
return sample_collection.filter_polylines(field, label_filter)

if label_type is fol.Keypoints:
return sample_collection.filter_keypoints(field, label_filter)

msg = "Ignoring unsupported field '%s' (%s)" % (field, label_type)
warnings.warn(msg)
return sample_collection
1 change: 1 addition & 0 deletions tests/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Tests do exist, but their coverage generally needs improvement...
| `benchmarking/*.py` | Tests related to benchmarking the performance of FiftyOne |
| `import_export/*.py` | Tests for importing/exporting datasets |
| `isolated/*.py` | Tests that must be run in a separate `pytest` process to avoid interfering with other tests |
| `misc/*.py` | Miscellaneous tests that have not been upgraded to official unit tests |

## Running a test

Expand Down
75 changes: 75 additions & 0 deletions tests/misc/selection_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be fine (as well as consistent) to have this file in tests/unittests

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't because it involves downloading the quickstart dataset (~25MB). If you're comfortable with that happening all the time, feel free to upgrade it from misc/ to unittests

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I didn't see that. Looks like it's running in GitHub Actions and not taking too much time there, but separating it into another folder (like you did) is probably helpful for people running tests locally.

Unit tests for the :mod:`fiftyone.utils.selection` module.

| Copyright 2017-2020, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
import random
import unittest

import fiftyone as fo
import fiftyone.core.dataset as fod
import fiftyone.utils.selection as fous
import fiftyone.zoo as foz


class SelectionTests(unittest.TestCase):
def test_select_objects(self):
num_samples_to_select = 5
max_objects_per_sample_to_select = 3

dataset = foz.load_zoo_dataset(
"quickstart", dataset_name=fod.get_default_dataset_name(),
)

# Generate some random selections
selected_objects = []
for sample in dataset.take(num_samples_to_select):
detections = sample.ground_truth.detections

max_num_objects = min(
len(detections), max_objects_per_sample_to_select
)
if max_num_objects >= 1:
num_objects = random.randint(1, max_num_objects)
else:
num_objects = 0

for detection in random.sample(detections, num_objects):
selected_objects.append(
{
"sample_id": sample.id,
"field": "ground_truth",
"object_id": detection.id,
}
)

selected_view = fous.select_objects(dataset, selected_objects)
excluded_view = fous.exclude_objects(dataset, selected_objects)

total_objects = _count_detections(dataset, "ground_truth")
num_selected_objects = len(selected_objects)
num_objects_in_selected_view = _count_detections(
selected_view, "ground_truth"
)
num_objects_in_excluded_view = _count_detections(
excluded_view, "ground_truth"
)
num_objects_excluded = total_objects - num_objects_in_excluded_view

self.assertEqual(num_selected_objects, num_objects_in_selected_view)
self.assertEqual(num_selected_objects, num_objects_excluded)


def _count_detections(sample_collection, label_field):
num_objects = 0
for sample in sample_collection:
num_objects += len(sample[label_field].detections)

return num_objects


if __name__ == "__main__":
fo.config.show_progress_bars = False
unittest.main(verbosity=2)