-
Notifications
You must be signed in to change notification settings - Fork 631
/
transform.py
464 lines (399 loc) · 14.5 KB
/
transform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
import zarr
import numpy as np
import math
from typing import Dict, Iterable
from hub.api.dataset import Dataset
from tqdm import tqdm
from collections.abc import MutableMapping
from hub.utils import batchify
from hub.api.dataset_utils import get_value, slice_extract_info, slice_split, str_to_int
import collections.abc as abc
from hub.api.datasetview import DatasetView
from pathos.pools import ProcessPool, ThreadPool
from hub.schema import Primitive
from hub.schema.sequence import Sequence
from hub.schema.features import featurify
import posixpath
from hub.defaults import OBJECT_CHUNK
def get_sample_size(schema, workers):
    """Given Schema, decides how many samples to take at once and returns it.

    The batch size is chosen so that one batch of any single feature stays
    within a ~16 MB (expressed in bits) budget, capped at 10000 samples,
    then scaled by the number of workers.

    Parameters
    ----------
    schema:
        schema (or dict convertible via ``featurify``) describing the samples
    workers: int
        number of parallel workers the batch will be split across

    Returns
    -------
    int
        number of samples to process per shard (always >= workers)
    """
    schema = featurify(schema)
    samples = 10000
    # Per-feature memory budget, in bits (~16 MB)
    budget = 16 * 1024 * 1024 * 8
    for feature in schema._flatten():
        shp = list(feature.max_shape) or [1]
        if feature.dtype == "object":
            # Variable-sized objects: assume an average size of budget/128 bits
            # per element instead of the dtype's (meaningless) itemsize.
            # BUGFIX: use integer division here — the old float division made
            # the whole function return a float instead of an int.
            sz = budget // 128
        else:
            sz = np.dtype(feature.dtype).itemsize
        samples = min(samples, budget // (math.prod(shp) * sz))
    samples = max(samples, 1)
    return samples * workers
class Transform:
    def __init__(
        self, func, schema, ds, scheduler: str = "single", workers: int = 1, **kwargs
    ):
        """| Transform applies a user defined function to each sample in single threaded manner.

        Parameters
        ----------
        func: function
            user defined function func(x, **kwargs)
        schema: dict of dtypes
            the structure of the final dataset that will be created
        ds: Iterative
            input dataset or a list that can be iterated
        scheduler: str
            choice between "single", "threaded", "processed", "ray"
        workers: int
            how many threads or processes to use
        **kwargs:
            additional arguments that will be passed to func as static argument for all samples
        """
        self._func = func
        self.schema = schema
        self._ds = ds
        self.kwargs = kwargs
        self.workers = workers

        if isinstance(self._ds, Transform):
            # Chaining transforms: inherit the upstream pipeline (copies, so the
            # upstream Transform is not mutated) and append this step at the end.
            self.base_ds = self._ds.base_ds
            self._func = self._ds._func[:]
            self._func.append(func)
            self.kwargs = self._ds.kwargs[:]
            self.kwargs.append(kwargs)
        else:
            self.base_ds = ds
            self._func = [func]
            self.kwargs = [kwargs]

        if scheduler == "threaded" or (scheduler == "single" and workers > 1):
            self.map = ThreadPool(nodes=workers).map
        elif scheduler == "processed":
            self.map = ProcessPool(nodes=workers).map
        elif scheduler == "single":
            self.map = map
        elif scheduler == "ray":
            try:
                from ray.util.multiprocessing import Pool as RayPool
            except Exception as e:
                # BUGFIX: the import failure used to be silently swallowed,
                # which made the next line crash with a confusing NameError.
                raise Exception(
                    "Scheduler 'ray' requires the 'ray' package to be installed"
                ) from e
            self.map = RayPool().map
        else:
            # BUGFIX: the message previously omitted the supported 'ray' option
            raise Exception(
                f"Scheduler {scheduler} not understood, please use 'single', 'threaded', 'processed' or 'ray'"
            )

    def __len__(self):
        """Number of samples in the (lazy) output dataset."""
        return self.shape[0]

    def __getitem__(self, slice_):
        """| Get an item to be computed without iterating on the whole dataset.

        | Creates a dataset view, then a temporary dataset to apply the transform.

        Parameters:
        ----------
        slice_: slice
            Gets a slice or slices from dataset
        """
        if not isinstance(slice_, abc.Iterable) or isinstance(slice_, str):
            slice_ = [slice_]
        slice_ = list(slice_)
        subpath, slice_list = slice_split(slice_)
        if len(slice_list) == 0:
            slice_list = [slice(None, None, None)]

        num, ofs = slice_extract_info(slice_list[0], self.shape[0])
        ds_view = DatasetView(
            dataset=self._ds,
            num_samples=num,
            offset=ofs,
            squeeze_dim=isinstance(slice_list[0], int),
        )

        # Materialize only the requested view into a temporary in-memory dataset
        path = posixpath.expanduser("~/.activeloop/tmparray")
        new_ds = self.store(path, length=num, ds=ds_view, progressbar=False)

        index = 1 if len(slice_) > 1 else 0
        slice_[index] = (
            slice(None, None, None) if not isinstance(slice_list[0], int) else 0
        )  # Get all shape dimension since we already sliced
        return new_ds[slice_]

    def __iter__(self):
        for index in range(len(self)):
            yield self[index]

    @classmethod
    def _flatten_dict(cls, d: Dict, parent_key="", schema=None):
        """| Helper function to flatten dictionary of a recursive tensor

        Parameters
        ----------
        d: dict
            nested dictionary to flatten; keys are joined with "/"
        """
        items = []
        for k, v in d.items():
            new_key = parent_key + "/" + k if parent_key else k
            # Sequences are kept as-is even though they are mappings
            if isinstance(v, MutableMapping) and not isinstance(
                cls.dtype_from_path(new_key, schema), Sequence
            ):
                items.extend(
                    cls._flatten_dict(v, parent_key=new_key, schema=schema).items()
                )
            else:
                items.append((new_key, v))
        return dict(items)

    @classmethod
    def _flatten(cls, items, schema):
        """
        Takes a dictionary or list of dictionary
        Returns a dictionary of concatenated values
        Dictionary follows schema
        """
        final_item = {}
        for item in cls._unwrap(items):
            item = cls._flatten_dict(item, schema=schema)
            for k, v in item.items():
                if k in final_item:
                    final_item[k].append(v)
                else:
                    final_item[k] = [v]
        return final_item

    @classmethod
    def dtype_from_path(cls, path, schema):
        """
        Helper function to get the dtype from the path
        """
        path = path.split("/")
        cur_type = schema
        # Walk down the nested schema; every intermediate node is a dict-like type
        for subpath in path[:-1]:
            cur_type = cur_type[subpath]
            cur_type = cur_type.dict_
        return cur_type[path[-1]]

    @classmethod
    def _unwrap(cls, results):
        """
        If there is any list then unwrap it into its elements
        """
        items = []
        for r in results:
            if isinstance(r, dict):
                items.append(r)
            else:
                items.extend(r)
        return items

    def _split_list_to_dicts(self, xs):
        """| Helper function that transform list of dicts into dicts of lists

        Parameters
        ----------
        xs: list of dicts

        Returns
        ----------
        xs_new: dicts of lists
        """
        xs_new = {}
        for x in xs:
            if isinstance(x, list):
                # A bare list is positional: match it against the flattened schema keys
                x = dict(
                    zip(self._flatten_dict(self.schema, schema=self.schema).keys(), x)
                )
            for key, value in x.items():
                if key in xs_new:
                    xs_new[key].append(value)
                else:
                    xs_new[key] = [value]
        return xs_new

    def _pbar(self, show: bool = True):
        """
        Returns a progress bar, if empty then it function does nothing
        """

        def _empty_pbar(xs, **kwargs):
            return xs

        # Only show a bar in single-threaded mode; pool workers would garble it
        single_threaded = self.map == map
        return tqdm if show and single_threaded else _empty_pbar

    def create_dataset(
        self, url: str, length: int = None, token: dict = None, public: bool = True
    ):
        """Helper function to create a dataset"""
        shape = (length,)
        ds = Dataset(
            url,
            mode="w",
            shape=shape,
            schema=self.schema,
            token=token,
            # Temporary datasets live purely in memory
            fs=zarr.storage.MemoryStore() if "tmp" in url else None,
            cache=False,
            public=public,
        )
        return ds

    def upload(self, results, ds: Dataset, token: dict, progressbar: bool = True):
        """Batchified upload of results.
        For each tensor batchify based on its chunk and upload.
        If tensor is dynamic then still upload element by element.
        For dynamic tensors, it disable dynamicness and then enables it back.

        Parameters
        ----------
        results:
            Output of transform function
        ds: hub.Dataset
            Dataset object that should be written to
        token: dict
            credentials (unused here; kept for interface compatibility)
        progressbar: bool

        Returns
        ----------
        ds: hub.Dataset
            Uploaded dataset
        """
        for key, value in results.items():
            chunk = ds[key].chunksize[0]
            chunk = 1 if chunk == 0 else chunk
            value = get_value(value)
            value = str_to_int(value, ds.dataset.tokenizer)

            # Split the values so each worker uploads whole chunks
            num_chunks = math.ceil(len(value) / (chunk * self.workers))
            length = num_chunks * chunk if self.workers != 1 else len(value)
            batched_values = batchify(value, length)

            def upload_chunk(i_batch):
                i, batch = i_batch
                length = len(batch)
                slice_ = slice(i * length, (i + 1) * length)
                ds[key, slice_] = batch

            index_batched_values = list(
                zip(list(range(len(batched_values))), batched_values)
            )

            # Disable dynamic arrays
            ds.dataset._tensors[f"/{key}"].disable_dynamicness()
            list(self.map(upload_chunk, index_batched_values))

            # Enable and rewrite shapes
            if ds.dataset._tensors[f"/{key}"].is_dynamic:
                ds.dataset._tensors[f"/{key}"].enable_dynamicness()
                ds.dataset._tensors[f"/{key}"].set_shape(
                    [slice(ds.offset, ds.offset + len(value))], value
                )

        ds.commit()
        return ds

    def call_func(self, fn_index, item, as_list=False):
        """Calls all the functions one after the other

        Parameters
        ----------
        fn_index: int
            The index starting from which the functions need to be called
        item:
            The item on which functions need to be applied
        as_list: bool, optional
            If true then treats the item as a list.

        Returns
        ----------
        result:
            The final output obtained after all transforms
        """
        result = item
        if fn_index < len(self._func):
            if as_list:
                result = [self.call_func(fn_index, it) for it in result]
            else:
                result = self._func[fn_index](result, **self.kwargs[fn_index])
                result = self.call_func(fn_index + 1, result, isinstance(result, list))
        result = self._unwrap(result) if isinstance(result, list) else result
        return result

    def store_shard(self, ds_in: Iterable, ds_out: Dataset, offset: int, token=None):
        """
        Takes a shard of iteratable ds_in, compute and stores in DatasetView
        """

        def _func_argd(item):
            if isinstance(item, DatasetView) or isinstance(item, Dataset):
                item = item.numpy()
            result = self.call_func(
                0, item
            )  # If the iterable obtained from iterating ds_in is a list, it is not treated as list
            return result

        ds_in = list(ds_in)
        results = self.map(
            _func_argd,
            ds_in,
        )
        results = self._unwrap(results)

        results = self.map(lambda x: self._flatten_dict(x, schema=self.schema), results)
        results = list(results)

        results = self._split_list_to_dicts(results)

        results_values = list(results.values())
        if len(results_values) == 0:
            return 0

        n_results = len(results_values[0])
        if n_results == 0:
            return 0

        # Grow the output dataset if this shard overflows the current shape
        additional = max(offset + n_results - ds_out.shape[0], 0)
        ds_out.append_shape(additional)

        self.upload(
            results,
            ds_out[offset : offset + n_results],
            token=token,
        )

        return n_results

    def store(
        self,
        url: str,
        token: dict = None,
        length: int = None,
        ds: Iterable = None,
        progressbar: bool = True,
        sample_per_shard: int = None,
        public: bool = True,
    ):
        """| The function to apply the transformation for each element in batchified manner

        Parameters
        ----------
        url: str
            path where the data is going to be stored
        token: str or dict, optional
            If url is referring to a place where authorization is required,
            token is the parameter to pass the credentials, it can be filepath or dict
        length: int
            in case shape is None, user can provide length
        ds: Iterable
        progressbar: bool
            Show progress bar
        sample_per_shard: int
            How to split the iterator not to overfill RAM
        public: bool, optional
            only applicable if using hub storage, ignored otherwise
            setting this to False allows only the user who created it to access the dataset and
            the dataset won't be visible in the visualizer to the public

        Returns
        ----------
        ds: hub.Dataset
            uploaded dataset
        """
        # BUGFIX: compare against None so a falsy (e.g. empty) explicit `ds`
        # is honored instead of silently falling back to the base dataset
        ds_in = ds if ds is not None else self.base_ds

        # compute shard length
        if sample_per_shard is None:
            n_samples = get_sample_size(self.schema, self.workers)
        else:
            n_samples = sample_per_shard
        try:
            length = len(ds_in) if hasattr(ds_in, "__len__") else n_samples
        except Exception:
            length = length or n_samples

        if length < n_samples:
            n_samples = length

        ds_out = self.create_dataset(url, length=length, token=token, public=public)

        def batchify_generator(iterator: Iterable, size: int):
            # Lazily group the iterator into lists of at most `size` elements.
            # BUGFIX: do not yield a trailing empty batch.
            batch = []
            for el in iterator:
                batch.append(el)
                if len(batch) >= size:
                    yield batch
                    batch = []
            if batch:
                yield batch

        start = 0
        total = 0
        with tqdm(
            total=length,
            unit_scale=True,
            unit=" items",
            desc="Computing the transformation",
            disable=not progressbar,  # BUGFIX: the progressbar flag used to be ignored
        ) as pbar:
            for ds_in_shard in batchify_generator(ds_in, n_samples):
                n_results = self.store_shard(ds_in_shard, ds_out, start, token=token)
                total += n_results
                pbar.update(len(ds_in_shard))
                start += n_results

        # Shrink to the number of samples actually produced
        ds_out.resize_shape(total)
        ds_out.commit()
        return ds_out

    @property
    def shape(self):
        """Shape of the input dataset."""
        return self._ds.shape