-
Notifications
You must be signed in to change notification settings - Fork 584
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #616 from voxel51/selections
Adding utility for selecting objects from datasets
- Loading branch information
Showing
3 changed files
with
258 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
""" | ||
Utilities for selecting content from datasets. | ||
| Copyright 2017-2020, Voxel51, Inc. | ||
| `voxel51.com <https://voxel51.com/>`_ | ||
| | ||
""" | ||
from collections import defaultdict | ||
import warnings | ||
|
||
from bson import ObjectId | ||
|
||
import fiftyone.core.fields as fof | ||
import fiftyone.core.labels as fol | ||
from fiftyone.core.expressions import ViewField as F | ||
|
||
|
||
def select_samples(sample_collection, sample_ids):
    """Builds a view that contains only the given samples.

    This simply delegates to the collection's own ``select()`` method.

    Args:
        sample_collection: a
            :class:`fiftyone.core.collections.SampleCollection`
        sample_ids: an iterable of the IDs of the samples to keep

    Returns:
        a :class:`fiftyone.core.view.DatasetView` restricted to the requested
        samples
    """
    selected_view = sample_collection.select(sample_ids)
    return selected_view
|
||
|
||
def exclude_samples(sample_collection, sample_ids):
    """Builds a view that omits the given samples.

    This simply delegates to the collection's own ``exclude()`` method.

    Args:
        sample_collection: a
            :class:`fiftyone.core.collections.SampleCollection`
        sample_ids: an iterable of the IDs of the samples to drop

    Returns:
        a :class:`fiftyone.core.view.DatasetView` without the requested
        samples
    """
    remaining_view = sample_collection.exclude(sample_ids)
    return remaining_view
|
||
|
||
def select_objects(sample_collection, objects):
    """Builds a view that contains only the specified objects.

    Samples, sample fields, and individual objects that are not referenced by
    the ``objects`` argument are omitted from the returned view. The
    ``objects`` argument is expected to look like this::

        [
            {
                "sample_id": "5f8d254a27ad06815ab89df4",
                "field": "ground_truth",
                "object_id": "5f8d254a27ad06815ab89df3",
            },
            {
                "sample_id": "5f8d255e27ad06815ab93bf8",
                "field": "ground_truth",
                "object_id": "5f8d255e27ad06815ab93bf6",
            },
            ...
        ]

    Args:
        sample_collection: a
            :class:`fiftyone.core.collections.SampleCollection`
        objects: a list of dicts describing the objects to select

    Returns:
        a :class:`fiftyone.core.view.DatasetView` containing only the
        requested objects
    """
    sample_ids, ids_by_field = _parse_objects(objects)

    label_schema = sample_collection.get_field_schema(
        ftype=fof.EmbeddedDocumentField, embedded_doc_type=fol.Label
    )

    # Restrict to the referenced samples and fields first, then narrow each
    # field down to the requested object IDs
    view = sample_collection.select(sample_ids).select_fields(
        list(ids_by_field.keys())
    )

    for field, field_object_ids in ids_by_field.items():
        keep_filter = F("_id").is_in(field_object_ids)
        view = _apply_label_filter(view, label_schema, field, keep_filter)

    return view
|
||
|
||
def exclude_objects(sample_collection, objects):
    """Builds a view from which the specified objects have been removed.

    The labels referenced by the ``objects`` argument are omitted from the
    returned view. The ``objects`` argument is expected to look like this::

        [
            {
                "sample_id": "5f8d254a27ad06815ab89df4",
                "field": "ground_truth",
                "object_id": "5f8d254a27ad06815ab89df3",
            },
            {
                "sample_id": "5f8d255e27ad06815ab93bf8",
                "field": "ground_truth",
                "object_id": "5f8d255e27ad06815ab93bf6",
            },
            ...
        ]

    Args:
        sample_collection: a
            :class:`fiftyone.core.collections.SampleCollection`
        objects: a list of dicts describing the objects to exclude

    Returns:
        a :class:`fiftyone.core.view.DatasetView` that omits the requested
        objects
    """
    # The sample IDs are not needed here: the filters are applied to the
    # entire collection
    _, ids_by_field = _parse_objects(objects)

    label_schema = sample_collection.get_field_schema(
        ftype=fof.EmbeddedDocumentField, embedded_doc_type=fol.Label
    )

    view = sample_collection
    for field, field_object_ids in ids_by_field.items():
        drop_filter = ~F("_id").is_in(field_object_ids)
        view = _apply_label_filter(view, label_schema, field, drop_filter)

    return view
|
||
|
||
def _parse_objects(objects): | ||
sample_ids = set() | ||
object_ids = defaultdict(set) | ||
for obj in objects: | ||
sample_ids.add(obj["sample_id"]) | ||
object_ids[obj["field"]].add(ObjectId(obj["object_id"])) | ||
|
||
return sample_ids, object_ids | ||
|
||
|
||
def _apply_label_filter(sample_collection, label_schema, field, label_filter): | ||
if field not in label_schema: | ||
raise ValueError( | ||
"%s '%s' has no label field '%s'" | ||
% ( | ||
sample_collection.__class__.__name__, | ||
sample_collection.name, | ||
field, | ||
) | ||
) | ||
|
||
label_type = label_schema[field].document_type | ||
|
||
if label_type in ( | ||
fol.Classification, | ||
fol.Detection, | ||
fol.Polyline, | ||
fol.Keypoint, | ||
): | ||
return sample_collection.filter_field(field, label_filter) | ||
|
||
if label_type is fol.Classifications: | ||
return sample_collection.filter_classifications(field, label_filter) | ||
|
||
if label_type is fol.Detections: | ||
return sample_collection.filter_detections(field, label_filter) | ||
|
||
if label_type is fol.Polylines: | ||
return sample_collection.filter_polylines(field, label_filter) | ||
|
||
if label_type is fol.Keypoints: | ||
return sample_collection.filter_keypoints(field, label_filter) | ||
|
||
msg = "Ignoring unsupported field '%s' (%s)" % (field, label_type) | ||
warnings.warn(msg) | ||
return sample_collection |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
""" | ||
Unit tests for the :mod:`fiftyone.utils.selection` module. | ||
| Copyright 2017-2020, Voxel51, Inc. | ||
| `voxel51.com <https://voxel51.com/>`_ | ||
| | ||
""" | ||
import random | ||
import unittest | ||
|
||
import fiftyone as fo | ||
import fiftyone.core.dataset as fod | ||
import fiftyone.utils.selection as fous | ||
import fiftyone.zoo as foz | ||
|
||
|
||
class SelectionTests(unittest.TestCase):
    """Tests for the object selection utilities."""

    def test_select_objects(self):
        """Checks that selecting and excluding random objects is consistent.

        Random ``ground_truth`` detections are picked from a few samples of
        the ``quickstart`` zoo dataset. The selected view must then contain
        exactly that many detections, and the excluded view must be missing
        exactly that many.
        """
        sample_count = 5
        per_sample_cap = 3

        dataset = foz.load_zoo_dataset(
            "quickstart", dataset_name=fod.get_default_dataset_name(),
        )

        # Generate some random selections
        selections = []
        for sample in dataset.take(sample_count):
            detections = sample.ground_truth.detections

            upper_bound = min(len(detections), per_sample_cap)
            pick_count = (
                random.randint(1, upper_bound) if upper_bound >= 1 else 0
            )

            selections.extend(
                {
                    "sample_id": sample.id,
                    "field": "ground_truth",
                    "object_id": detection.id,
                }
                for detection in random.sample(detections, pick_count)
            )

        selected_view = fous.select_objects(dataset, selections)
        excluded_view = fous.exclude_objects(dataset, selections)

        total_count = _count_detections(dataset, "ground_truth")
        expected_count = len(selections)
        selected_count = _count_detections(selected_view, "ground_truth")
        remaining_count = _count_detections(excluded_view, "ground_truth")
        excluded_count = total_count - remaining_count

        self.assertEqual(expected_count, selected_count)
        self.assertEqual(expected_count, excluded_count)
|
||
|
||
def _count_detections(sample_collection, label_field): | ||
num_objects = 0 | ||
for sample in sample_collection: | ||
num_objects += len(sample[label_field].detections) | ||
|
||
return num_objects | ||
|
||
|
||
if __name__ == "__main__":
    # Turn off FiftyOne's progress bars, then run this module's tests with
    # verbose output
    fo.config.show_progress_bars = False
    unittest.main(verbosity=2)