-
Notifications
You must be signed in to change notification settings - Fork 591
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding utility for selecting objects from datasets #616
Changes from 6 commits
b9e4d53
443ea65
25ce94d
d6a0928
64f8c68
47a794d
15b1857
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
""" | ||
Utilities for selecting content from datasets. | ||
|
||
| Copyright 2017-2020, Voxel51, Inc. | ||
| `voxel51.com <https://voxel51.com/>`_ | ||
| | ||
""" | ||
from collections import defaultdict | ||
import warnings | ||
|
||
from bson import ObjectId | ||
|
||
import fiftyone.core.fields as fof | ||
import fiftyone.core.labels as fol | ||
from fiftyone.core.expressions import ViewField as F | ||
|
||
|
||
def select_samples(sample_collection, sample_ids): | ||
"""Selects the specified samples from the collection. | ||
|
||
Args: | ||
sample_collection: a | ||
:class:`fiftyone.core.collections.SampleCollection` | ||
sample_ids: an iterable of sample IDs to select | ||
|
||
Returns: | ||
a :class:`fiftyone.core.view.DatasetView` containing only the specified | ||
samples | ||
""" | ||
return sample_collection.select(sample_ids) | ||
|
||
|
||
def exclude_samples(sample_collection, sample_ids): | ||
"""Excludes the specified samples from the collection. | ||
|
||
Args: | ||
sample_collection: a | ||
:class:`fiftyone.core.collections.SampleCollection` | ||
sample_ids: an iterable of sample IDs to exclude | ||
|
||
Returns: | ||
a :class:`fiftyone.core.view.DatasetView` that excludes the specified | ||
samples | ||
""" | ||
return sample_collection.exclude(sample_ids) | ||
|
||
|
||
def select_objects(sample_collection, objects): | ||
"""Selects the specified objects from the sample collection. | ||
|
||
The returned view will omit samples, sample fields, and individual objects | ||
that do not appear in the provided ``objects`` argument, which should have | ||
the following format:: | ||
|
||
[ | ||
{ | ||
"sample_id": "5f8d254a27ad06815ab89df4", | ||
"field": "ground_truth", | ||
"object_id": "5f8d254a27ad06815ab89df3", | ||
}, | ||
{ | ||
"sample_id": "5f8d255e27ad06815ab93bf8", | ||
"field": "ground_truth", | ||
"object_id": "5f8d255e27ad06815ab93bf6", | ||
}, | ||
... | ||
] | ||
|
||
Args: | ||
sample_collection: a | ||
:class:`fiftyone.core.collections.SampleCollection` | ||
objects: a list of dicts defining the objects to select | ||
|
||
Returns: | ||
a :class:`fiftyone.core.view.DatasetView` containing only the specified | ||
objects | ||
""" | ||
sample_ids, object_ids = _parse_objects(objects) | ||
|
||
label_schema = sample_collection.get_field_schema( | ||
ftype=fof.EmbeddedDocumentField, embedded_doc_type=fol.Label | ||
) | ||
|
||
view = sample_collection.select(sample_ids) | ||
view = view.select_fields(list(object_ids.keys())) | ||
|
||
for field, object_ids in object_ids.items(): | ||
label_filter = F("_id").is_in(object_ids) | ||
view = _apply_label_filter(view, label_schema, field, label_filter) | ||
|
||
return view | ||
|
||
|
||
def exclude_objects(sample_collection, objects): | ||
"""Excludes the specified objects from the sample collection. | ||
|
||
The returned view will omit the labels specified in the provided | ||
``objects`` argument, which should have the following format:: | ||
|
||
[ | ||
{ | ||
"sample_id": "5f8d254a27ad06815ab89df4", | ||
"field": "ground_truth", | ||
"object_id": "5f8d254a27ad06815ab89df3", | ||
}, | ||
{ | ||
"sample_id": "5f8d255e27ad06815ab93bf8", | ||
"field": "ground_truth", | ||
"object_id": "5f8d255e27ad06815ab93bf6", | ||
}, | ||
... | ||
] | ||
|
||
Args: | ||
sample_collection: a | ||
:class:`fiftyone.core.collections.SampleCollection` | ||
objects: a list of dicts defining the objects to exclude | ||
|
||
Returns: | ||
a :class:`fiftyone.core.view.DatasetView` that excludes the specified | ||
objects | ||
""" | ||
_, object_ids = _parse_objects(objects) | ||
|
||
label_schema = sample_collection.get_field_schema( | ||
ftype=fof.EmbeddedDocumentField, embedded_doc_type=fol.Label | ||
) | ||
|
||
view = sample_collection | ||
for field, object_ids in object_ids.items(): | ||
label_filter = ~F("_id").is_in(object_ids) | ||
view = _apply_label_filter(view, label_schema, field, label_filter) | ||
|
||
return view | ||
|
||
|
||
def _parse_objects(objects): | ||
sample_ids = set() | ||
object_ids = defaultdict(set) | ||
for obj in objects: | ||
sample_ids.add(obj["sample_id"]) | ||
object_ids[obj["field"]].add(ObjectId(obj["object_id"])) | ||
|
||
return sample_ids, object_ids | ||
|
||
|
||
def _apply_label_filter(sample_collection, label_schema, field, label_filter): | ||
if field not in label_schema: | ||
raise ValueError( | ||
"%s '%s' has no label field '%s'" | ||
% ( | ||
sample_collection.__class__.__name__, | ||
sample_collection.name, | ||
field, | ||
) | ||
) | ||
|
||
field_type = label_schema[field] | ||
label_type = field_type.document_type | ||
|
||
if label_type in ( | ||
fol.Classification, | ||
fol.Detection, | ||
fol.Polyline, | ||
fol.Keypoint, | ||
): | ||
return sample_collection.filter_field(field, label_filter) | ||
|
||
if label_type is fol.Classifications: | ||
return sample_collection.filter_classifications(field, label_filter) | ||
|
||
if label_type is fol.Detections: | ||
return sample_collection.filter_detections(field, label_filter) | ||
|
||
if label_type is fol.Polylines: | ||
return sample_collection.filter_polylines(field, label_filter) | ||
|
||
if label_type is fol.Keypoints: | ||
return sample_collection.filter_keypoints(field, label_filter) | ||
|
||
msg = "Ignoring unsupported field '%s' (%s)" % (field, label_type) | ||
warnings.warn(msg) | ||
return sample_collection |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it would be fine (as well as consistent) to have this file in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I didn't because it involves downloading the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, I didn't see that. Looks like it's running in GitHub Actions and not taking too much time there, but separating it into another folder (like you did) is probably helpful for people running tests locally. |
||
Unit tests for the :mod:`fiftyone.utils.selection` module. | ||
|
||
| Copyright 2017-2020, Voxel51, Inc. | ||
| `voxel51.com <https://voxel51.com/>`_ | ||
| | ||
""" | ||
import random | ||
import unittest | ||
|
||
import fiftyone as fo | ||
import fiftyone.core.dataset as fod | ||
import fiftyone.utils.selection as fous | ||
import fiftyone.zoo as foz | ||
|
||
|
||
class SelectionTests(unittest.TestCase): | ||
def test_select_objects(self): | ||
num_samples_to_select = 5 | ||
max_objects_per_sample_to_select = 3 | ||
|
||
dataset = foz.load_zoo_dataset( | ||
"quickstart", dataset_name=fod.get_default_dataset_name(), | ||
) | ||
|
||
# Generate some random selections | ||
selected_objects = [] | ||
for sample in dataset.take(num_samples_to_select): | ||
detections = sample.ground_truth.detections | ||
|
||
max_num_objects = min( | ||
len(detections), max_objects_per_sample_to_select | ||
) | ||
if max_num_objects >= 1: | ||
num_objects = random.randint(1, max_num_objects) | ||
else: | ||
num_objects = 0 | ||
|
||
for detection in random.sample(detections, num_objects): | ||
selected_objects.append( | ||
{ | ||
"sample_id": sample.id, | ||
"field": "ground_truth", | ||
"object_id": detection.id, | ||
} | ||
) | ||
|
||
selected_view = fous.select_objects(dataset, selected_objects) | ||
excluded_view = fous.exclude_objects(dataset, selected_objects) | ||
|
||
total_objects = _count_detections(dataset, "ground_truth") | ||
num_selected_objects = len(selected_objects) | ||
num_objects_in_selected_view = _count_detections( | ||
selected_view, "ground_truth" | ||
) | ||
num_objects_in_excluded_view = _count_detections( | ||
excluded_view, "ground_truth" | ||
) | ||
num_objects_excluded = total_objects - num_objects_in_excluded_view | ||
|
||
self.assertEqual(num_selected_objects, num_objects_in_selected_view) | ||
self.assertEqual(num_selected_objects, num_objects_excluded) | ||
|
||
|
||
def _count_detections(sample_collection, label_field): | ||
num_objects = 0 | ||
for sample in sample_collection: | ||
num_objects += len(sample[label_field].detections) | ||
|
||
return num_objects | ||
|
||
|
||
if __name__ == "__main__": | ||
fo.config.show_progress_bars = False | ||
unittest.main(verbosity=2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this also need
frame_number
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't add that yet because filtering frame-level labels is not yet supported. I believe @benjaminpkane plans to work on that
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I plan on starting work on this today.