batched dynamic shape write much faster
davidbuniat committed Dec 1, 2020
1 parent 6100e9a commit 142b1ba
Showing 5 changed files with 34 additions and 45 deletions.
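At a glance: for dynamic-shape tensors, a slice assignment such as ds[key, 0:n] = batch previously fell back to one write per sample; after this commit the whole batch is padded to a common shape and written in one operation. A minimal usage sketch based on the new test below (the local path is illustrative, and the imports assume the repo's 1.x layout):

import numpy as np
from hub import Dataset
from hub.schema import Image

# "image" has dynamic height and width (None), bounded by max_shape.
schema = {"image": Image(shape=(None, None, 3), max_shape=(640, 640, 3))}
ds = Dataset("./data/batch_demo", shape=(100,), mode="w", schema=schema)

# One batched assignment of 14 variably-shaped images, exercising the
# fast path added by this commit.
ds["image", 0:14] = [np.ones((640 - i, 640, 3)) for i in range(14)]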
11 changes: 1 addition & 10 deletions examples/upload_tfds.py
@@ -4,16 +4,7 @@
if __name__ == "__main__":
    path = "./data/test/tfds_new/coco"
    with Timer("Eurosat TFDS"):
-        out_ds = hub.Dataset.from_tfds("coco", num=100)
+        out_ds = hub.Dataset.from_tfds("coco", num=10000)
-
-        ds = hub.Dataset(
-            "./data/test/tfds_new/coco2", schema=out_ds.schema, shape=(10000,), mode="w"
-        )
-        print(out_ds.schema)
-        for key in ds.keys:
-            print(ds[key].chunksize)
-        exit()
        res_ds = out_ds.store(path)
        ds = hub.load(path)
        print(ds)
        print(ds["image", 0].compute())
20 changes: 14 additions & 6 deletions hub/api/tests/test_dataset.py
@@ -231,17 +231,23 @@ def test_dataset_batch_write():

ds["image", 0:4] = 4 * np.ones((4, 67, 65, 3))

assert(ds["image", 0].numpy() == 4 * np.ones((67, 65, 3))).all()
assert(ds["image", 1].numpy() == 4 * np.ones((67, 65, 3))).all()
assert(ds["image", 2].numpy() == 4 * np.ones((67, 65, 3))).all()
assert(ds["image", 3].numpy() == 4 * np.ones((67, 65, 3))).all()
assert (ds["image", 0].numpy() == 4 * np.ones((67, 65, 3))).all()
assert (ds["image", 1].numpy() == 4 * np.ones((67, 65, 3))).all()
assert (ds["image", 2].numpy() == 4 * np.ones((67, 65, 3))).all()
assert (ds["image", 3].numpy() == 4 * np.ones((67, 65, 3))).all()

ds["image", 5:7] = [2 * np.ones((60, 65, 3)), 3 * np.ones((54, 30, 3))]

assert(ds["image", 5].numpy() == 2 * np.ones((60, 65, 3))).all()
assert(ds["image", 6].numpy() == 3 * np.ones((54, 30, 3))).all()
assert (ds["image", 5].numpy() == 2 * np.ones((60, 65, 3))).all()
assert (ds["image", 6].numpy() == 3 * np.ones((54, 30, 3))).all()


def test_dataset_batch_write_2():
schema = {"image": Image(shape=(None, None, 3), max_shape=(640, 640, 3))}
ds = Dataset("./data/batch", shape=(100,), mode="w", schema=schema)

ds["image", 0:14] = [np.ones((640 - i, 640, 3)) for i in range(14)]


@pytest.mark.skipif(not gcp_creds_exist(), reason="requires gcp credentials")
def test_dataset_gcs():
@@ -331,6 +337,8 @@ def test_append_dataset():
    # test_tensorview_slicing()
    # test_datasetview_slicing()
    # test_dataset()
+    test_dataset_batch_write_2()
+    exit()
    test_append_dataset()
    test_dataset2()
    test_text_dataset()
11 changes: 3 additions & 8 deletions hub/compute/ray.py
@@ -117,15 +117,10 @@ def upload_chunk(i_batch, key, ds):
            batch = ray.get(batch)

            # TODO some sort of synchronizer across nodes
-            # FIXME replace below 8 lines with ds[key, i * length : (i + 1) * length] = batch
-            if not ds[key].is_dynamic:
-                if len(batch) != 1:
-                    ds[key, i * length : (i + 1) * length] = batch
-                else:
-                    ds[key, i * length] = batch[0]
+            if length != 1:
+                ds[key, i * length : (i + 1) * length] = batch
            else:
-                for k, el in enumerate(batch):
-                    ds[key, i * length + k] = el
+                ds[key, i * length] = batch[0]

    def upload(self, results, url: str, token: dict, progressbar: bool = True):
        """Batchified upload of results
12 changes: 4 additions & 8 deletions hub/compute/transform.py
@@ -213,15 +213,11 @@ def upload(self, results, ds: Dataset, token: dict, progressbar: bool = True):

        def upload_chunk(i_batch):
            i, batch = i_batch
-            if not ds[key].is_dynamic:
-                batch_length = len(batch)
-                if batch_length != 1:
-                    ds[key, i * length : i * length + batch_length] = batch
-                else:
-                    ds[key, i * length] = batch[0]
+            batch_length = len(batch)
+            if batch_length != 1:
+                ds[key, i * length : i * length + batch_length] = batch
            else:
-                for k, el in enumerate(batch):
-                    ds[key, i * length + k] = el
+                ds[key, i * length] = batch[0]

        index_batched_values = list(
            zip(list(range(len(batched_values))), batched_values)
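Both the Ray scheduler above and the default transform scheduler here now share one write pattern: batches with more than one element go to the tensor layer as a single slice assignment regardless of whether the tensor is dynamic, and only single-element batches are written per sample. A distilled sketch of that pattern as a hypothetical standalone helper:

def write_batch(ds, key, i, length, batch):
    # Batch i of nominal size `length`; dynamic shapes are handled by
    # the padded batched write in dynamic_tensor.py (next file).
    batch_length = len(batch)
    if batch_length != 1:
        ds[key, i * length : i * length + batch_length] = batch
    else:
        ds[key, i * length] = batch[0]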
25 changes: 12 additions & 13 deletions hub/store/dynamic_tensor.py
@@ -196,14 +196,23 @@ def __setitem__(self, slice_, value):
            for item in value:
                max_shape = tuple([max(value) for value in zip(max_shape, item.shape)])
            for i in range(len(value)):
-                pad = [(0, max_shape[dim] - value[i].shape[dim]) for dim in range(value[i].ndim)]
+                pad = [
+                    (0, max_shape[dim] - value[i].shape[dim])
+                    for dim in range(value[i].ndim)
+                ]
                value[i] = np.pad(value[i], pad)
-            real_shapes = np.array([max_shape[i] for i in range(len(max_shape)) if i + 1 in self._dynamic_dims])
+            real_shapes = np.array(
+                [
+                    max_shape[i]
+                    for i in range(len(max_shape))
+                    if i + 1 in self._dynamic_dims
+                ]
+            )
        else:
            real_shapes = None

        if not self._enabled_dynamicness:
-            real_shapes = list(value.shape) if hasattr(value, "shape") else [1]
+            real_shapes = list(value.shape) if hasattr(value, "shape") else real_shapes

        slice_ = self._get_slice(slice_, real_shapes)
        value = self.check_value_shape(value, slice_)
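The batched dynamic write pads every sample up to the element-wise maximum shape of the batch so the whole batch can be stored in one operation, while the true per-sample shapes are tracked separately via real_shapes. A self-contained sketch of just the padding step (the function name is illustrative):

import numpy as np

def pad_to_batch_max(value):
    # Element-wise maximum shape across the batch.
    max_shape = value[0].shape
    for item in value:
        max_shape = tuple(max(a, b) for a, b in zip(max_shape, item.shape))
    # Zero-pad each sample at the end of every dimension up to max_shape.
    for i in range(len(value)):
        pad = [(0, max_shape[dim] - value[i].shape[dim]) for dim in range(value[i].ndim)]
        value[i] = np.pad(value[i], pad)
    return value, max_shape

batch, max_shape = pad_to_batch_max([2 * np.ones((60, 65, 3)), 3 * np.ones((54, 30, 3))])
assert max_shape == (60, 65, 3) and batch[1].shape == (60, 65, 3)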
@@ -254,17 +263,7 @@ def resize_shape(self, size: int) -> None:
        self._resize_shape(self._storage_tensor, size)

        if self._dynamic_tensor:
-            print(
-                "-> write before",
-                self._dynamic_tensor.shape,
-                self._dynamic_tensor.chunks,
-            )
            self._resize_shape(self._dynamic_tensor, size)
-            print(
-                "-> write after",
-                self._dynamic_tensor.shape,
-                self._dynamic_tensor.chunks,
-            )

        self.fs_map[".hub.dynamic_tensor"] = bytes(
            json.dumps({"shape": self.shape}), "utf-8"
