Added argument for max text label size
AbhinavTuli committed Sep 14, 2020
1 parent 8690049 commit 1a2fe37
Showing 4 changed files with 23 additions and 20 deletions.
3 changes: 2 additions & 1 deletion examples/fashion-mnist/train_pytorch.py
@@ -71,7 +71,8 @@ def main():
ds = dataset.load("mnist/fashion-mnist")

# Transform into pytorch
- ds = ds.to_pytorch()
+ # max_text_len is an optional argument that sets the maximum length of text labels, default is 30
+ ds = ds.to_pytorch(max_text_len=15)

# Splitting back into the original train and test sets, instead of random split
train_dataset = torch.utils.data.Subset(ds, range(60000))
3 changes: 2 additions & 1 deletion examples/fashion-mnist/train_tf_fit.py
@@ -32,7 +32,8 @@ def main():
ds = dataset.load("mnist/fashion-mnist")

# transform into Tensorflow dataset
- ds = ds.to_tensorflow()
+ # max_text_len is an optional argument that sets the maximum length of text labels, default is 30
+ ds = ds.to_tensorflow(max_text_len=15)

# converting ds so that it can be directly used in model.fit
ds = ds.map(lambda x: to_model_fit(x))
3 changes: 2 additions & 1 deletion examples/fashion-mnist/train_tf_gradient_tape.py
@@ -60,7 +60,8 @@ def main():
ds = dataset.load("mnist/fashion-mnist")

# transform into Tensorflow dataset
- ds = ds.to_tensorflow()
+ # max_text_len is an optional argument that sets the maximum length of text labels, default is 30
+ ds = ds.to_tensorflow(max_text_len=15)

# Splitting back into the original train and test sets
train_dataset = ds.take(60000)
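All three example scripts change only the conversion call: with max_text_len=15, each text label is returned as a fixed-length array of character codes instead of a string. Below is a minimal sketch (not part of this commit) of how such a padded label can be decoded back to text, assuming the padding value 32 (ASCII space) used in core.py further down; decode_label is a hypothetical helper added only for illustration.

```python
import numpy as np

def decode_label(codes: np.ndarray) -> str:
    # Labels are truncated to max_text_len and right-padded with 32 (ASCII space),
    # so mapping the codes back to characters and stripping trailing spaces
    # recovers the (possibly truncated) label text.
    return "".join(chr(int(c)) for c in codes).rstrip(" ")

# "Ankle boot" is one of the Fashion-MNIST class names (10 characters),
# so with max_text_len=15 it gets 5 padding characters appended.
padded = np.array([ord(c) for c in "Ankle boot"] + [32] * 5, dtype=np.int8)
print(decode_label(padded))  # Ankle boot
```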
34 changes: 17 additions & 17 deletions hub/collections/dataset/core.py
@@ -559,7 +559,7 @@ def store(self, tag, creds=None, session_creds=True) -> "Dataset":

return load(tag, creds)

- def to_pytorch(self, transform=None):
+ def to_pytorch(self, transform=None, max_text_len=30):
"""
Transforms into pytorch dataset
@@ -568,9 +568,9 @@ def to_pytorch(self, transform=None):
transform: func
any transform that takes input a dictionary of a sample and returns transformed dictionary
"""
- return TorchDataset(self, transform)
+ return TorchDataset(self, transform, max_text_len)

- def to_tensorflow(self):
+ def to_tensorflow(self, max_text_len=30):
"""
Transforms into tensorflow dataset
"""
@@ -588,12 +588,12 @@ def tf_gen(step=4):
arrs
)
arrs = arrs.compute()
- for ind,arr in enumerate(arrs):
+ for ind, arr in enumerate(arrs):
if arr.dtype.type is np.str_:
- arr = [([ord(x) for x in sample.tolist()]) for sample in arr]
- arr = np.array([np.pad(sample, (0, 20-len(sample)), 'constant', constant_values=(32)) for sample in arr])
- arrs[ind]=arr
+ arr = [[ord(x) for x in sample.tolist()[0:max_text_len]] for sample in arr]
+ arr = np.array([np.pad(sample, (0, max_text_len - len(sample)), 'constant', constant_values=(32)) for sample in arr])
+ arrs[ind] = arr

for i in range(step):
sample = {key: r[i] for key, r in zip(self[index].keys(), arrs)}
yield sample
@@ -609,13 +609,13 @@ def tf_dtype(np_dtype):
output_shapes={}
output_types={}
for key in self.keys():
- output_types[key]=tf_dtype(self._tensors[key].dtype)
- output_shapes[key]= self._tensors[key].shape[1:]
+ output_types[key] = tf_dtype(self._tensors[key].dtype)
+ output_shapes[key] = self._tensors[key].shape[1:]

# if this is a string, we change the type to int, as it's going to become ascii. shape is also set to None
- if output_types[key]==tf.dtypes.as_dtype("string"):
- output_types[key]=tf.dtypes.as_dtype("int8")
- output_shapes[key]=None
+ if output_types[key] == tf.dtypes.as_dtype("string"):
+ output_types[key] = tf.dtypes.as_dtype("int8")
+ output_shapes[key] = None

# TODO use None for dimensions you don't know the length tf.TensorShape([None])
# FIXME Dataset Generator is not very good with multiprocessing but its good for fast tensorflow support
@@ -742,12 +742,13 @@ def _is_tensor_dynamic(tensor):


class TorchDataset:
- def __init__(self, ds, transform=None):
+ def __init__(self, ds, transform=None, max_text_len=30):
self._ds = ds
self._transform = transform
self._dynkeys = {
key for key in self._ds.keys() if _is_tensor_dynamic(self._ds[key])
}
+ self._max_text_len = max_text_len

def cost(nbytes, time):
print(nbytes, time)
@@ -784,8 +785,8 @@ def __iter__(self):
def _to_tensor(self, key, sample):
if key not in self._dynkeys:
if isinstance(sample, np.str_):
- sample = np.array([ord(x) for x in sample.tolist()])
- sample=np.pad(sample, (0, 20-len(sample)), 'constant',constant_values=(32))
+ sample = np.array([ord(x) for x in sample.tolist()[0:self._max_text_len]])
+ sample = np.pad(sample, (0, self._max_text_len - len(sample)), 'constant', constant_values=(32))
return torch.tensor(sample)
else:
return [torch.tensor(item) for item in sample]
@@ -801,7 +802,6 @@ def collate_fn(self, batch):

return ans

-
def _dask_concat(arr):
if len(arr) == 1:
return arr[0]
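The same truncate-and-pad logic now appears twice in core.py: once inside tf_gen for the Tensorflow path and once in TorchDataset._to_tensor for the pytorch path. A standalone sketch of that transformation (illustrative only; encode_text is a hypothetical helper mirroring the added lines above):

```python
import numpy as np

def encode_text(sample: str, max_text_len: int = 30) -> np.ndarray:
    # Convert the string to character codes, keep at most max_text_len of them,
    # then right-pad with 32 (ASCII space) so every label has exactly max_text_len entries.
    codes = [ord(x) for x in sample[0:max_text_len]]
    return np.pad(codes, (0, max_text_len - len(codes)), 'constant', constant_values=(32))

print(encode_text("Pullover", max_text_len=15))
# [ 80 117 108 108 111 118 101 114  32  32  32  32  32  32  32]
```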
