implemented composition/decomposition for backends
arogozhnikov committed Sep 19, 2021
1 parent 8a7c834 commit a15dd73
Showing 4 changed files with 64 additions and 2 deletions.
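
Note on the change: einsum itself can only address elementary axes, so when a layer pattern composes or decomposes axes with parentheses, each backend now wraps the einsum in an optional pre-Rearrange and post-Rearrange. A sketch of that split using the parameter names from the diff below; the concrete patterns here are hypothetical (the mixin derives the real ones from the layer's pattern):

pattern = 'b (h d) -> b (h k)'              # user-facing pattern with grouped axes (hypothetical)
pre_reshape_pattern = 'b (h d) -> b h d'    # decomposition applied before einsum
pre_reshape_lengths = dict(h=4, d=8)
einsum_pattern = 'bhd,hdk->bhk'             # the einsum proper sees only elementary axes
post_reshape_pattern = 'b h k -> b (h k)'   # composition applied after einsum
post_reshape_lengths = dict(h=4, k=16)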
16 changes: 16 additions & 0 deletions einops/layers/chainer.py
@@ -1,3 +1,5 @@
from typing import Optional, Dict

import chainer

from . import RearrangeMixin, ReduceMixin
@@ -26,6 +28,20 @@ def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound)
        else:
            self.bias = None

    def _create_rearrange_layers(self,
                                 pre_reshape_pattern: Optional[str],
                                 pre_reshape_lengths: Optional[Dict],
                                 post_reshape_pattern: Optional[str],
                                 post_reshape_lengths: Optional[Dict],
                                 ):
        self.pre_rearrange = None
        if pre_reshape_pattern is not None:
            self.pre_rearrange = Rearrange(pre_reshape_pattern, **pre_reshape_lengths)

        self.post_rearrange = None
        if post_reshape_pattern is not None:
            self.post_rearrange = Rearrange(post_reshape_pattern, **post_reshape_lengths)

    def __call__(self, input):
        result = chainer.functions.einsum(self.einsum_pattern, input, self.weight)
        if self.bias is not None:
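Why the axes lengths are forwarded into Rearrange above: decomposing a group needs the axis sizes, while composing does not. A minimal functional sketch with einops on numpy (shapes are illustrative):

import numpy as np
from einops import rearrange

x = np.zeros((2, 32))
split = rearrange(x, 'b (h d) -> b h d', h=4, d=8)  # sizes are needed to split 32 into 4 * 8
merged = rearrange(split, 'b h d -> b (h d)')       # merging back infers everything
assert merged.shape == (2, 32)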
14 changes: 13 additions & 1 deletion einops/layers/gluon.py
@@ -1,3 +1,5 @@
from typing import Optional, Dict

import mxnet

from . import RearrangeMixin, ReduceMixin
@@ -30,8 +32,18 @@ def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound)
        else:
            self.bias = None

    def _create_rearrange_layers(self,
                                 pre_reshape_pattern: Optional[str],
                                 pre_reshape_lengths: Optional[Dict],
                                 post_reshape_pattern: Optional[str],
                                 post_reshape_lengths: Optional[Dict]):
        raise NotImplementedError("WeightedEinsum in mxnet/gluon doesn't support axis merge/unmerge "
                                  "because einsum in gluon is defined only for mx.np.ndarrays")

    def hybrid_forward(self, F, x, *args, **kwargs):
        result = mxnet.np.einsum(self.einsum_pattern, x, self.weight.data())
        # mxnet.np can't work with 'usual' ndarrays; .data() is a standard way to get weights within gluon
        # .as_np_ndarray makes the necessary conversion
        result = mxnet.np.einsum(self.einsum_pattern, x.as_np_ndarray(), self.weight.data())
        if self.bias is not None:
            result += self.bias.data()
        return result
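
The conversion in hybrid_forward, shown in isolation; a sketch assuming mxnet >= 1.6, where the mx.np namespace exists. Gluon hands hybrid_forward a classic NDArray, while mx.np.einsum accepts only numpy-style arrays:

import mxnet as mx

x = mx.nd.ones((2, 3))         # classic NDArray, as gluon provides it
w = mx.np.ones((3, 4))         # numpy-style array, as mx.np.einsum expects
y = mx.np.einsum('ij,jk->ik', x.as_np_ndarray(), w)  # the conversion makes x acceptable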
16 changes: 15 additions & 1 deletion einops/layers/tensorflow.py
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List, Optional, Dict

import tensorflow as tf
from tensorflow.keras.layers import Layer
@@ -51,6 +51,20 @@ def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound)
        else:
            self.bias = None

    def _create_rearrange_layers(self,
                                 pre_reshape_pattern: Optional[str],
                                 pre_reshape_lengths: Optional[Dict],
                                 post_reshape_pattern: Optional[str],
                                 post_reshape_lengths: Optional[Dict],
                                 ):
        self.pre_rearrange = None
        if pre_reshape_pattern is not None:
            self.pre_rearrange = Rearrange(pre_reshape_pattern, **pre_reshape_lengths)

        self.post_rearrange = None
        if post_reshape_pattern is not None:
            self.post_rearrange = Rearrange(post_reshape_pattern, **post_reshape_lengths)

    def build(self, input_shape):
        pass

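In the tensorflow backend the pre/post steps are themselves keras layers, so they compose with the rest of the model as usual. A minimal sketch with a hypothetical pre-reshape pattern:

import tensorflow as tf
from einops.layers.tensorflow import Rearrange

pre = Rearrange('b (h d) -> b h d', h=4, d=8)  # hypothetical pre-reshape step
print(pre(tf.zeros([2, 32])).shape)            # (2, 4, 8)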
20 changes: 20 additions & 0 deletions einops/layers/torch.py
@@ -1,3 +1,5 @@
from typing import Optional, Dict

import torch

from . import RearrangeMixin, ReduceMixin
@@ -35,8 +37,26 @@ def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound)
        else:
            self.bias = None

    def _create_rearrange_layers(self,
                                 pre_reshape_pattern: Optional[str],
                                 pre_reshape_lengths: Optional[Dict],
                                 post_reshape_pattern: Optional[str],
                                 post_reshape_lengths: Optional[Dict],
                                 ):
        self.pre_rearrange = None
        if pre_reshape_pattern is not None:
            self.pre_rearrange = Rearrange(pre_reshape_pattern, **pre_reshape_lengths)

        self.post_rearrange = None
        if post_reshape_pattern is not None:
            self.post_rearrange = Rearrange(post_reshape_pattern, **post_reshape_lengths)

    def forward(self, input):
        if self.pre_rearrange is not None:
            input = self.pre_rearrange(input)
        result = torch.einsum(self.einsum_pattern, input, self.weight)
        if self.bias is not None:
            result += self.bias
        if self.post_rearrange is not None:
            result = self.post_rearrange(result)
        return result
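
Putting the torch pieces together, a hedged usage sketch: WeightedEinsum (the layer these mixin methods extend; renamed EinMix in later einops versions) is assumed here to take a pattern, a weight_shape, and axes lengths, following einops conventions:

import torch
from einops.layers.torch import WeightedEinsum  # constructor signature assumed, not shown in this diff

mix = WeightedEinsum('b (h d) -> b h k', weight_shape='h d k', h=4, d=8, k=16)
y = mix(torch.randn(2, 32))  # pre-Rearrange splits (h d); einsum mixes d -> k; y.shape == (2, 4, 16)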
