Skip to content

Commit

Permalink
publish pytorch examples
Browse files Browse the repository at this point in the history
  • Loading branch information
arogozhnikov committed Nov 30, 2018
1 parent b0c9ca0 commit 28a3fe6
Show file tree
Hide file tree
Showing 6 changed files with 1,351 additions and 5 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ Suppose `x` shape was `(3, 4, 5)`, then `y` has shape ...
- keras, tensorflow.layers, mxnet and gluon: `(3, 20)`
- pytorch: no such function

<!-- TODO examples for depth-to-space and pixel shuffle? transpose vs permute? -->
<!-- TODO examples for depth-to-space and pixel shuffle? transpose vs permute? torch.repeat is numpy.tile -->

## Supported frameworks

Expand Down
1,282 changes: 1,282 additions & 0 deletions docs/pytorch-examples.html

Large diffs are not rendered by default.

27 changes: 27 additions & 0 deletions einops/_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def reduce(self, x, operation, axes):
def stack_on_zeroth_dimension(self, tensors: list):
raise NotImplementedError()

def tile(self, x, repeats):
raise NotImplementedError()

def is_float_type(self, x):
# some backends (torch) can't compute average for non-floating types.
# Decided to drop average for all backends if type is not floating
Expand Down Expand Up @@ -137,6 +140,9 @@ def arange(self, start, stop):
def stack_on_zeroth_dimension(self, tensors: list):
return self.np.stack(tensors)

def tile(self, x, repeats):
return self.np.tile(x, repeats)

def is_float_type(self, x):
return x.dtype in ('float16', 'float32', 'float64', 'float128')

Expand Down Expand Up @@ -170,6 +176,9 @@ def arange(self, start, stop):
def stack_on_zeroth_dimension(self, tensors: list):
return self.mx.nd.stack(*tensors)

def tile(self, x, repeats):
return self.mx.nd.tile(x, repeats)

def is_float_type(self, x):
return 'float' in str(x.dtype)

Expand Down Expand Up @@ -225,6 +234,9 @@ def arange(self, start, stop):
def stack_on_zeroth_dimension(self, tensors: list):
return self.mx.symbol.stack(*tensors)

def tile(self, x, repeats):
return self.mx.symbol.tile(x, repeats)

def is_float_type(self, x):
return 'float' in str(x.infer_type()[1][0])

Expand Down Expand Up @@ -274,6 +286,9 @@ def transpose(self, x, axes):
def stack_on_zeroth_dimension(self, tensors: list):
return self.torch.stack(tensors)

def tile(self, x, repeats):
return x.repeat(repeats)

def is_float_type(self, x):
return x.dtype in [self.torch.float16, self.torch.float32, self.torch.float64]

Expand Down Expand Up @@ -304,6 +319,9 @@ def arange(self, start, stop):
def stack_on_zeroth_dimension(self, tensors: list):
return self.cupy.stack(tensors)

def tile(self, x, repeats):
return self.cupy.tile(x, repeats)

def is_float_type(self, x):
return x.dtype in ('float16', 'float32', 'float64', 'float128')

Expand Down Expand Up @@ -338,6 +356,9 @@ def reduce(self, x, operation, axes):
def stack_on_zeroth_dimension(self, tensors: list):
return self.chainer.functions.stack(tensors)

def tile(self, x, repeats):
return self.chainer.functions.tile(x, repeats)

def is_float_type(self, x):
return x.dtype in ('float16', 'float32', 'float64', 'float128')

Expand Down Expand Up @@ -394,6 +415,9 @@ def transpose(self, x, axes):
def stack_on_zeroth_dimension(self, tensors: list):
return self.tf.stack(tensors)

def tile(self, x, repeats):
return self.tf.tile(x, repeats)

def is_float_type(self, x):
return x.dtype in ('float16', 'float32', 'float64', 'float128')

Expand Down Expand Up @@ -436,6 +460,9 @@ def transpose(self, x, axes):
def stack_on_zeroth_dimension(self, tensors: list):
return self.K.stack(tensors)

def tile(self, x, repeats):
return self.K.tile(x, repeats)

def is_float_type(self, x):
return 'float' in self.K.dtype(x)

Expand Down
2 changes: 1 addition & 1 deletion einops/einops.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ def rearrange(tensor, pattern, **axes_lengths):
:return: tensor of the same type as input. If possible, a view to the original tensor is returned.
When composing axes, C-order enumeration used (consecutive elements have different last axis)
More examples and explanations can be found in the einops guide.
More source_examples and explanations can be found in the einops guide.
"""
if isinstance(tensor, list):
if len(tensor) == 0:
Expand Down
4 changes: 2 additions & 2 deletions einops/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class RearrangeMixin:
:param pattern: str, rearrangement pattern
:param axes_lengths: any additional specification of dimensions
See einops.rearrange for examples.
See einops.rearrange for source_examples.
"""

def __init__(self, pattern, **axes_lengths):
Expand Down Expand Up @@ -50,7 +50,7 @@ class ReduceMixin:
:param reduction: one of available reductions ('min', 'max', 'sum', 'mean', 'prod'), case-sensitive
:param axes_lengths: any additional specification of dimensions
See einops.reduce for examples.
See einops.reduce for source_examples.
"""

def __init__(self, pattern, reduction, **axes_lengths):
Expand Down
39 changes: 38 additions & 1 deletion tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def unet_like_1d(x, usual_convolution):
]

for backend in imp_op_backends:
print('testing examples for ', backend.framework_name)
print('testing source_examples for ', backend.framework_name)
for test in tests:
x = numpy.arange(10 * 20 * 30 * 40).reshape([10, 20, 30, 40])
result1 = test(x)
Expand Down Expand Up @@ -585,3 +585,40 @@ def new_way(input, num_classes, num_anchors, anchors, stride_h, stride_w):
stride_h=stride_h, stride_w=stride_w, anchors=anchors)
result1 = result1.reshape(result2.shape)
assert torch.allclose(result1, result2)


def test_tiling_imperatives():
for backend in imp_op_backends:
print('Tiling tests for ', backend.framework_name)
input = numpy.arange(2 * 3 * 5, dtype='int64').reshape([2, 1, 3, 1, 5])
test_cases = [
(1, 1, 1, 1, 1),
(1, 2, 1, 3, 1),
(3, 1, 1, 4, 1),
]
for repeats in test_cases:
expected = numpy.tile(input, repeats)
converted = backend.from_numpy(input)
repeated = backend.tile(converted, repeats)
result = backend.to_numpy(repeated)
assert numpy.array_equal(result, expected)


def test_tiling_symbolic():
for backend in sym_op_backends:
print('Tiling tests for ', backend.framework_name)
input = numpy.arange(2 * 3 * 5, dtype='int64').reshape([2, 1, 3, 1, 5])
test_cases = [
(1, 1, 1, 1, 1),
(1, 2, 1, 3, 1),
(3, 1, 1, 4, 1),
]
for repeats in test_cases:
expected = numpy.tile(input, repeats)
sym = backend.create_symbol(input.shape)
result = backend.eval_symbol(backend.tile(sym, repeats), [[sym, input]])
assert numpy.array_equal(result, expected)

sym = backend.create_symbol([None] * len(input.shape))
result = backend.eval_symbol(backend.tile(sym, repeats), [[sym, input]])
assert numpy.array_equal(result, expected)

0 comments on commit 28a3fe6

Please sign in to comment.