batchified transformations
davidbuniat committed Nov 14, 2020
1 parent 623efff commit 56d06be
Showing 8 changed files with 54 additions and 26 deletions.
7 changes: 7 additions & 0 deletions examples/upload_tfds.py
@@ -0,0 +1,7 @@
import hub
from hub.utils import Timer

if __name__ == "__main__":
    with Timer("Eurosat TFDS"):
        out_ds = hub.Dataset.from_tfds('eurosat')
        res_ds = out_ds.store("./data/test/tfds_new/eurosat")
1 change: 1 addition & 0 deletions hub/api/dataset.py
@@ -449,6 +449,7 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
    def chunksize(self):
        # FIXME assumes chunking is done on the first sample
        chunks = [t.chunksize[0] for t in self._tensors.values()]
+        print(chunks)
        return compute_lcm(chunks)

    @property
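
Note on the chunksize property above: it combines the per-tensor chunk lengths with compute_lcm, so the dataset-level chunksize is a batch size that covers whole chunks of every tensor. A standalone sketch of that arithmetic (the chunk lengths 8 and 12 are made up for illustration; this is not the hub code path itself):

from math import gcd

def compute_lcm(chunks):
    # least common multiple of the per-tensor chunk lengths
    lcm = chunks[0]
    for c in chunks[1:]:
        lcm = lcm * c // gcd(lcm, c)
    return int(lcm)

chunks = [8, 12]            # hypothetical per-tensor chunk lengths
print(compute_lcm(chunks))  # 24: a batch of 24 samples spans exactly 3 chunks
                            # of the first tensor and 2 chunks of the second
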
4 changes: 4 additions & 0 deletions hub/api/tensorview.py
@@ -180,3 +180,7 @@ def make_shape(self, shape):
            shape.append(shape[i])
        final_shape = [dim for dim in shape if dim != 1]
        return tuple(final_shape)
+
+    @property
+    def chunksize(self):
+        return self.dataset._tensors[self.subpath].chunksize
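
The new TensorView.chunksize simply forwards to the underlying tensor, which is what the batchified upload in hub/compute/transform.py reads (ds[key].chunksize[0]) to pick its batch length. A hypothetical usage sketch, assuming ds is a hub.Dataset with an "image" tensor chunked along the sample axis:

tv = ds["image"]              # a TensorView over the whole tensor
chunk_len = tv.chunksize[0]   # chunk length along the first (sample) axis
batch = tv[0:chunk_len]       # a slice that lines up with exactly one chunk
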
15 changes: 8 additions & 7 deletions hub/compute/pathos.py
@@ -1,5 +1,5 @@
import hub
-from hub.utils import batch
+from hub.utils import batchify
from hub.compute.transform import Transform

try:
@@ -19,24 +19,25 @@ def store(self, url, token=None):
        mary chunks with compute
        """
        ds = hub.Dataset(
-            url, mode="w", shape=self._ds.shape, schema=self._schema, token=token, cache=False
+            url, mode="w", shape=(len(self._ds),), schema=self._schema, token=token, cache=False
        )

        # Chunkwise compute
        batch_size = ds.chunksize

-        def batchify(ds):
-            return tuple(batch(ds, batch_size))
+        def batchify_remote(ds):
+            return tuple(batchify(ds, batch_size))

        def batched_func(i_xs):
            i, xs = i_xs
+            print(i)
+            print(xs)
            xs = [self._func(x) for x in xs]
            self._transfer_batch(ds, i, xs)

-        batched = batchify(ds)
+        batched = batchify_remote(ds)
        results = self.map(batched_func, enumerate(batched))

        # uploading can be done per chunk per item
        results = list(results)
        return ds

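The store flow above batches the dataset to its chunk size and then maps a per-batch function over the enumerated batches. The same pattern can be sketched with plain multiprocessing; square(), the sample list, and the batch size below are illustrative stand-ins for self._func, self._ds, and ds.chunksize, not hub APIs:

from multiprocessing import Pool

def batchify(iterable, n=1):
    # assumed behaviour of hub.utils.batchify: consecutive slices of at most n items
    return [iterable[i:i + n] for i in range(0, len(iterable), n)]

def square(x):
    return x * x  # stand-in for the user transform self._func

def batched_func(i_xs):
    i, xs = i_xs
    # transform a whole batch; in hub this batch would then be written as chunk i
    return i, [square(x) for x in xs]

if __name__ == "__main__":
    samples = list(range(10))   # stand-in for self._ds
    batch_size = 4              # stand-in for ds.chunksize
    batched = tuple(batchify(samples, batch_size))
    with Pool(2) as pool:
        results = pool.map(batched_func, enumerate(batched))
    print(results)  # [(0, [0, 1, 4, 9]), (1, [16, 25, 36, 49]), (2, [64, 81])]
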
8 changes: 4 additions & 4 deletions hub/compute/ray.py
@@ -1,5 +1,5 @@
import hub
-from hub.utils import batch
+from hub.utils import batchify
from hub.compute import Transform

try:
@@ -40,10 +40,10 @@ def store_chunkwise(self, url, token=None):
        batch_size = ds.chunksize

        @remote(num_returns=int(len(ds) / batch_size))
-        def batchify(results):
-            return tuple(batch(results, batch_size))
+        def batchify_remote(results):
+            return tuple(batchify(results, batch_size))

-        results_batched = batchify.remote(results)
+        results_batched = batchify_remote.remote(results)
        if isinstance(results_batched, list):
            results = [
                self._transfer_batch.remote(self, ds, i, result)
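
The Ray variant above relies on num_returns matching the number of batches that batchify_remote returns. A minimal standalone sketch of that pattern (assuming Ray >= 1.0; the sample list and batch size are illustrative, and the dataset length is chosen to divide evenly, since int(len(ds) / batch_size) counts only full batches):

import ray

ray.init(ignore_reinit_error=True)

samples = list(range(10))
batch_size = 5
num_batches = int(len(samples) / batch_size)  # 2, mirroring num_returns in the diff

@ray.remote(num_returns=num_batches)
def batchify_remote(xs):
    # return one value per batch; Ray hands back num_returns separate object refs
    return tuple(xs[i:i + batch_size] for i in range(0, len(xs), batch_size))

refs = batchify_remote.remote(samples)     # a list of num_returns ObjectRefs
print(ray.get(refs[0]), ray.get(refs[1]))  # [0, 1, 2, 3, 4] [5, 6, 7, 8, 9]
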
40 changes: 28 additions & 12 deletions hub/compute/transform.py
@@ -1,6 +1,8 @@
import hub
from collections.abc import MutableMapping
from hub.features.features import Primitive
+from tqdm import tqdm
+from hub.utils import batchify


class Transform:
@@ -28,19 +30,17 @@ def store(self, url, token=None, length=None):
url, mode="w", shape=shape, schema=self._schema, token=token, cache=False,
)

# apply transformation and some rewrapping
results = [self._func(item) for item in self._ds]

for i, result in enumerate(results):
dic = self.flatten_dict(result)
for key in dic:
path_key = key.split("/")
if isinstance(self._schema[path_key[0]], Primitive):
ds[path_key[0], i] = result[path_key[0]]
else:
val = result
for path in path_key:
val = val.get(path)
ds[key, i] = val
results = [self.flatten_dict(r) for r in results]
results = self.split_list_to_dicts(results)

# batchified upload
for key, value in results.items():
length = ds[key].chunksize[0]
batched_values = batchify(value, length)
for i, batch in enumerate(batched_values):
ds[key, i * length:(i + 1) * length] = batch
return ds

def flatten_dict(self, d, parent_key=''):
@@ -53,6 +53,22 @@ def flatten_dict(self, d, parent_key=''):
                items.append((new_key, v))
        return dict(items)

+    def split_list_to_dicts(self, xs):
+        """
+        Transform a list of dicts into a dict of lists
+        """
+        xs_new = {}
+        for x in xs:
+            for key, value in x.items():
+                if key in xs_new:
+                    xs_new[key].append(value)
+                else:
+                    xs_new[key] = [value]
+        return xs_new

    def dtype_from_path(self, path):
        path = path.split('/')
        cur_type = self._schema
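
To make the new upload path in store concrete: each transformed sample is flattened, the list of per-sample dicts is turned into one dict of lists, and each key is then written in chunk-sized slices. The walkthrough below is standalone; the sample dicts, the chunk length of 2, and the plain-list store are illustrative stand-ins for the real schema, ds[key].chunksize[0], and the hub dataset, and flatten_dict/batchify here are assumed minimal versions, not the hub implementations:

def flatten_dict(d, parent_key=''):
    # assumed behaviour: nested keys are joined with "/"
    items = []
    for k, v in d.items():
        new_key = parent_key + '/' + k if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten_dict(v, new_key).items())
        else:
            items.append((new_key, v))
    return dict(items)

def split_list_to_dicts(xs):
    # list of dicts -> dict of lists, the same shape change as the new Transform method
    xs_new = {}
    for x in xs:
        for key, value in x.items():
            xs_new.setdefault(key, []).append(value)
    return xs_new

def batchify(values, n):
    # assumed behaviour of hub.utils.batchify
    return [values[i:i + n] for i in range(0, len(values), n)]

results = [
    {"label": 0, "meta": {"id": 10}},
    {"label": 1, "meta": {"id": 11}},
    {"label": 0, "meta": {"id": 12}},
]
results = [flatten_dict(r) for r in results]
results = split_list_to_dicts(results)
# {'label': [0, 1, 0], 'meta/id': [10, 11, 12]}

length = 2                                # stand-in for ds[key].chunksize[0]
store = {k: [None] * 3 for k in results}  # stand-in for the output dataset
for key, value in results.items():
    for i, batch in enumerate(batchify(value, length)):
        store[key][i * length:(i + 1) * length] = batch
print(store)  # {'label': [0, 1, 0], 'meta/id': [10, 11, 12]}
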
4 changes: 2 additions & 2 deletions hub/utils.py
@@ -103,11 +103,11 @@ def compute_lcm(a):
"""
lcm = a[0]
for i in a[1:]:
lcm = lcm * i / gcd(lcm, i)
lcm = lcm * i // gcd(lcm, i)
return int(lcm)


def batch(iterable, n=1):
def batchify(iterable, n=1):
"""
Batchify an iteratable
"""
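
On the compute_lcm change above: with true division the running lcm becomes a float, and int() then truncates a value that may already have been rounded; floor division keeps the arithmetic in exact integers. The values below are contrived to make the rounding visible and are not taken from hub:

from math import gcd

a, b = 2**53 + 1, 3

float_lcm = int(a * b / gcd(a, b))  # goes through a float on the way
int_lcm = a * b // gcd(a, b)        # stays an exact integer

print(int_lcm - float_lcm)  # 1: the true quotient 2**53 + 1 is not representable
                            # as a float, so the '/' path rounds it down
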
1 change: 0 additions & 1 deletion requirements.txt
@@ -11,4 +11,3 @@ lz4>=3,<4
zarr==2.5
lmdb==1.0.0
boto3==1.16.10
