Skip to content

Commit

Permalink
Merge pull request #616 from voxel51/selections
Browse files Browse the repository at this point in the history
Adding utility for selecting objects from datasets
  • Loading branch information
brimoor authored Oct 20, 2020
2 parents d651c23 + 15b1857 commit 771f216
Show file tree
Hide file tree
Showing 3 changed files with 258 additions and 0 deletions.
182 changes: 182 additions & 0 deletions fiftyone/utils/selection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
"""
Utilities for selecting content from datasets.
| Copyright 2017-2020, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
from collections import defaultdict
import warnings

from bson import ObjectId

import fiftyone.core.fields as fof
import fiftyone.core.labels as fol
from fiftyone.core.expressions import ViewField as F


def select_samples(sample_collection, sample_ids):
    """Selects the specified samples from the collection.

    This is a thin wrapper around ``sample_collection.select()``.

    Args:
        sample_collection: a
            :class:`fiftyone.core.collections.SampleCollection`
        sample_ids: an iterable of sample IDs to select

    Returns:
        a :class:`fiftyone.core.view.DatasetView` containing only the specified
        samples
    """
    return sample_collection.select(sample_ids)


def exclude_samples(sample_collection, sample_ids):
    """Excludes the specified samples from the collection.

    This is a thin wrapper around ``sample_collection.exclude()``.

    Args:
        sample_collection: a
            :class:`fiftyone.core.collections.SampleCollection`
        sample_ids: an iterable of sample IDs to exclude

    Returns:
        a :class:`fiftyone.core.view.DatasetView` that excludes the specified
        samples
    """
    return sample_collection.exclude(sample_ids)


def select_objects(sample_collection, objects):
    """Selects the specified objects from the sample collection.

    The returned view will omit samples, sample fields, and individual objects
    that do not appear in the provided ``objects`` argument, which should have
    the following format::

        [
            {
                "sample_id": "5f8d254a27ad06815ab89df4",
                "field": "ground_truth",
                "object_id": "5f8d254a27ad06815ab89df3",
            },
            {
                "sample_id": "5f8d255e27ad06815ab93bf8",
                "field": "ground_truth",
                "object_id": "5f8d255e27ad06815ab93bf6",
            },
            ...
        ]

    Args:
        sample_collection: a
            :class:`fiftyone.core.collections.SampleCollection`
        objects: a list of dicts defining the objects to select

    Returns:
        a :class:`fiftyone.core.view.DatasetView` containing only the specified
        objects

    Raises:
        ValueError: if an entry references a field that is not a label field
            of the collection
    """
    sample_ids, object_ids_by_field = _parse_objects(objects)

    label_schema = sample_collection.get_field_schema(
        ftype=fof.EmbeddedDocumentField, embedded_doc_type=fol.Label
    )

    # Narrow to the referenced samples and label fields before filtering the
    # individual objects within each field
    view = sample_collection.select(sample_ids)
    view = view.select_fields(list(object_ids_by_field.keys()))

    # Use a distinct loop variable so the dict being iterated is not shadowed
    for field, field_object_ids in object_ids_by_field.items():
        label_filter = F("_id").is_in(field_object_ids)
        view = _apply_label_filter(view, label_schema, field, label_filter)

    return view


def exclude_objects(sample_collection, objects):
    """Excludes the specified objects from the sample collection.

    The returned view will omit the labels specified in the provided
    ``objects`` argument, which should have the following format::

        [
            {
                "sample_id": "5f8d254a27ad06815ab89df4",
                "field": "ground_truth",
                "object_id": "5f8d254a27ad06815ab89df3",
            },
            {
                "sample_id": "5f8d255e27ad06815ab93bf8",
                "field": "ground_truth",
                "object_id": "5f8d255e27ad06815ab93bf6",
            },
            ...
        ]

    Args:
        sample_collection: a
            :class:`fiftyone.core.collections.SampleCollection`
        objects: a list of dicts defining the objects to exclude

    Returns:
        a :class:`fiftyone.core.view.DatasetView` that excludes the specified
        objects

    Raises:
        ValueError: if an entry references a field that is not a label field
            of the collection
    """
    # The sample IDs are not needed here: no samples are dropped, only the
    # listed objects are filtered out of their fields
    _, object_ids_by_field = _parse_objects(objects)

    label_schema = sample_collection.get_field_schema(
        ftype=fof.EmbeddedDocumentField, embedded_doc_type=fol.Label
    )

    view = sample_collection

    # Use a distinct loop variable so the dict being iterated is not shadowed
    for field, field_object_ids in object_ids_by_field.items():
        label_filter = ~F("_id").is_in(field_object_ids)
        view = _apply_label_filter(view, label_schema, field, label_filter)

    return view


def _parse_objects(objects):
    """Groups a list of object selection dicts by sample and by field.

    Args:
        objects: an iterable of dicts with ``"sample_id"``, ``"field"``, and
            ``"object_id"`` keys

    Returns:
        a tuple of

        -   sample_ids: the set of ``"sample_id"`` values encountered
        -   object_ids: a ``defaultdict`` mapping each ``"field"`` value to
            the set of ``ObjectId`` instances built from the corresponding
            ``"object_id"`` values
    """
    sample_ids = set()
    object_ids = defaultdict(set)
    for obj in objects:
        sample_ids.add(obj["sample_id"])
        object_ids[obj["field"]].add(ObjectId(obj["object_id"]))

    return sample_ids, object_ids


def _apply_label_filter(sample_collection, label_schema, field, label_filter):
    """Applies the given filter to the given label field of the collection.

    The filtering method that is invoked on ``sample_collection`` is chosen
    based on the document type that ``label_schema`` declares for ``field``.

    Args:
        sample_collection: the collection (or view) to filter
        label_schema: a dict mapping label field names to field instances
            that expose a ``document_type`` attribute
        field: the name of the label field to filter
        label_filter: the filter expression to apply

    Returns:
        the filtered view, or ``sample_collection`` itself if the label type
        of ``field`` is not supported (in which case a warning is issued)

    Raises:
        ValueError: if ``field`` is not in ``label_schema``
    """
    if field not in label_schema:
        raise ValueError(
            "%s '%s' has no label field '%s'"
            % (
                sample_collection.__class__.__name__,
                sample_collection.name,
                field,
            )
        )

    label_type = label_schema[field].document_type

    # Single-label types are filtered via the generic field filter
    if label_type in (
        fol.Classification,
        fol.Detection,
        fol.Polyline,
        fol.Keypoint,
    ):
        return sample_collection.filter_field(field, label_filter)

    # Label list types each have a dedicated filtering method
    if label_type is fol.Classifications:
        return sample_collection.filter_classifications(field, label_filter)

    if label_type is fol.Detections:
        return sample_collection.filter_detections(field, label_filter)

    if label_type is fol.Polylines:
        return sample_collection.filter_polylines(field, label_filter)

    if label_type is fol.Keypoints:
        return sample_collection.filter_keypoints(field, label_filter)

    # Best effort: leave the collection untouched rather than failing
    msg = "Ignoring unsupported field '%s' (%s)" % (field, label_type)
    warnings.warn(msg)
    return sample_collection
1 change: 1 addition & 0 deletions tests/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ FiftyOne currently uses both
| `benchmarking/*.py` | Tests related to benchmarking the performance of FiftyOne |
| `import_export/*.py` | Tests for importing/exporting datasets |
| `isolated/*.py` | Tests that must be run in a separate `pytest` process to avoid interfering with other tests |
| `misc/*.py` | Miscellaneous tests that have not been upgraded to official unit tests |

## Running tests

Expand Down
75 changes: 75 additions & 0 deletions tests/misc/selection_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""
Unit tests for the :mod:`fiftyone.utils.selection` module.
| Copyright 2017-2020, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
import random
import unittest

import fiftyone as fo
import fiftyone.core.dataset as fod
import fiftyone.utils.selection as fous
import fiftyone.zoo as foz


class SelectionTests(unittest.TestCase):
    """Tests for selecting and excluding objects from a dataset."""

    def test_select_objects(self):
        """Checks that selecting and excluding a random set of objects
        changes the ``ground_truth`` detection counts by exactly the number
        of objects chosen.
        """
        num_samples_to_select = 5
        max_objects_per_sample_to_select = 3

        dataset = foz.load_zoo_dataset(
            "quickstart", dataset_name=fod.get_default_dataset_name(),
        )

        # Generate some random selections
        selected_objects = []
        for sample in dataset.take(num_samples_to_select):
            detections = sample.ground_truth.detections

            max_num_objects = min(
                len(detections), max_objects_per_sample_to_select
            )
            if max_num_objects < 1:
                # Nothing to select from this sample
                continue

            num_objects = random.randint(1, max_num_objects)
            for detection in random.sample(detections, num_objects):
                selected_objects.append(
                    {
                        "sample_id": sample.id,
                        "field": "ground_truth",
                        "object_id": detection.id,
                    }
                )

        selected_view = fous.select_objects(dataset, selected_objects)
        excluded_view = fous.exclude_objects(dataset, selected_objects)

        total_objects = _count_detections(dataset, "ground_truth")
        num_selected_objects = len(selected_objects)
        num_objects_in_selected_view = _count_detections(
            selected_view, "ground_truth"
        )
        num_objects_in_excluded_view = _count_detections(
            excluded_view, "ground_truth"
        )
        num_objects_excluded = total_objects - num_objects_in_excluded_view

        # The selected view must contain exactly the chosen objects, and the
        # excluded view must be missing exactly that many
        self.assertEqual(num_selected_objects, num_objects_in_selected_view)
        self.assertEqual(num_selected_objects, num_objects_excluded)


def _count_detections(sample_collection, label_field):
    """Counts the detections in the given field across all samples.

    Args:
        sample_collection: an iterable of samples that support
            ``sample[label_field]`` lookups
        label_field: the name of a field whose value has a ``detections``
            attribute

    Returns:
        the total number of detections
    """
    return sum(
        len(sample[label_field].detections) for sample in sample_collection
    )


if __name__ == "__main__":
    # Turn off FiftyOne progress bars before running the tests (presumably to
    # keep the verbose test output readable)
    fo.config.show_progress_bars = False
    unittest.main(verbosity=2)

0 comments on commit 771f216

Please sign in to comment.