Skip to content

Commit

Permalink
initial work
Browse files Browse the repository at this point in the history
  • Loading branch information
brimoor committed Aug 25, 2022
1 parent ef536b1 commit 5eb1cf3
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 20 deletions.
18 changes: 12 additions & 6 deletions fiftyone/core/odm/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,17 +573,23 @@ def import_document(json_path):
return json_util.loads(f.read())


def import_collection(json_path, key="documents"):
    """Imports a collection from JSON on disk.

    Args:
        json_path: the path to the collection on disk
        key ("documents"): the field name under which the documents are stored

    Returns:
        a tuple of

        -   the list of BSON documents
        -   the number of documents
    """
    with open(json_path, "r") as f:
        data = f.read()

    documents = json_util.loads(data).get(key, [])
    return documents, len(documents)


def insert_documents(docs, coll, ordered=False):
Expand All @@ -593,8 +599,8 @@ def insert_documents(docs, coll, ordered=False):
already set.
Args:
docs: the list of BSON document dicts to insert
coll: a pymongo collection instance
docs: an iterable of BSON document dicts
coll: a pymongo collection
ordered (False): whether the documents must be inserted in order
"""
try:
Expand All @@ -610,7 +616,7 @@ def bulk_write(ops, coll, ordered=False):
Args:
ops: a list of pymongo operations
coll: a pymongo collection instance
coll: a pymongo collection
ordered (False): whether the operations must be performed in order
"""
try:
Expand Down
46 changes: 32 additions & 14 deletions fiftyone/utils/data/importers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
| `voxel51.com <https://voxel51.com/>`_
|
"""
from copy import copy
import inspect
import itertools
import logging
import os
import random
Expand Down Expand Up @@ -831,25 +831,39 @@ def _preprocess_list(self, l):
applying the values of the ``shuffle``, ``seed``, and ``max_samples``
parameters of the importer.
You may also provide an iterable, in which case the output will also be
an iterable, unless the elements must be shuffled, in which case the
iterable must be read in-memory into a list and returned as a list.
Args:
l: a list
l: a list or iterable
Returns:
a processed copy of the list
a processed copy of the list/iterable
"""
if self.shuffle:
if self.seed is not None:
random.seed(self.seed)

l = copy(l)
random.shuffle(l)
_random = _get_rng(self.seed)
l = list(l).copy()
_random.shuffle(l)

if self.max_samples is not None:
l = l[: self.max_samples]
if isinstance(l, (list, tuple)):
l = l[: self.max_samples]
else:
l = itertools.islice(l, self.max_samples)

return l


def _get_rng(seed):
if seed is None:
return random

_random = random.Random()
_random.seed(seed)
return _random


class BatchDatasetImporter(DatasetImporter):
"""Base interface for importers that load all of their samples in a single
call to :meth:`import_samples`.
Expand Down Expand Up @@ -1566,7 +1580,7 @@ def _import_samples(self, dataset, dataset_dict, tags=None):
#

logger.info("Importing samples...")
samples = foo.import_collection(self._samples_path).get("samples", [])
samples, _ = foo.import_collection(self._samples_path, key="samples")

samples = self._preprocess_list(samples)

Expand All @@ -1577,17 +1591,21 @@ def _import_samples(self, dataset, dataset_dict, tags=None):
# Prepend `dataset_dir` to all relative paths
rel_dir = self.dataset_dir

for sample in samples:
def parse_sample(sample):
filepath = sample["filepath"]
if not os.path.isabs(filepath):
sample["filepath"] = os.path.join(rel_dir, filepath)

if tags is not None:
for sample in samples:
if tags is not None:
sample["tags"].extend(tags)

return sample

samples = list(map(parse_sample, samples))

foo.insert_documents(samples, dataset._sample_collection, ordered=True)

# @todo return from `insert_documents()` so `samples` can be iterable
sample_ids = [s["_id"] for s in samples]

#
Expand All @@ -1596,7 +1614,7 @@ def _import_samples(self, dataset, dataset_dict, tags=None):

if os.path.isfile(self._frames_path):
logger.info("Importing frames...")
frames = foo.import_collection(self._frames_path).get("frames", [])
frames, _ = foo.import_collection(self._frames_path, key="frames")

if self.max_samples is not None:
frames = [
Expand Down

0 comments on commit 5eb1cf3

Please sign in to comment.