Skip to content

Commit

Permalink
merge keras and tensorflow layers
Browse files Browse the repository at this point in the history
  • Loading branch information
arogozhnikov committed Sep 10, 2021
1 parent b928179 commit 3156a42
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 44 deletions.
42 changes: 6 additions & 36 deletions einops/layers/keras.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,9 @@
from tensorflow.keras.layers import Layer

from .._backends import UnknownSize
from . import RearrangeMixin, ReduceMixin

__author__ = 'Alex Rogozhnikov'

from ..layers.tensorflow import Rearrange, Reduce, WeightedEinsum

def _compute_output_shape(self, input_shape):
input_shape = tuple(UnknownSize() if d is None else int(d) for d in input_shape)
init_shapes, reduced_axes, axes_reordering, added_axes, final_shape = \
self.recipe().reconstruct_from_shape(input_shape)
final_shape = tuple(None if isinstance(d, UnknownSize) else int(d) for d in final_shape)
return final_shape


class Rearrange(RearrangeMixin, Layer):
def compute_output_shape(self, input_shape):
return _compute_output_shape(self, input_shape)

def call(self, inputs):
return self._apply_recipe(inputs)

def get_config(self):
return {'pattern': self.pattern, **self.axes_lengths}


class Reduce(ReduceMixin, Layer):
def compute_output_shape(self, input_shape):
return _compute_output_shape(self, input_shape)

def call(self, inputs):
return self._apply_recipe(inputs)

def get_config(self):
return {'pattern': self.pattern, 'reduction': self.reduction, **self.axes_lengths}


keras_custom_objects = {Rearrange.__name__: Rearrange, Reduce.__name__: Reduce}
keras_custom_objects = {
Rearrange.__name__: Rearrange,
Reduce.__name__: Reduce,
WeightedEinsum.__name__: WeightedEinsum,
}
21 changes: 13 additions & 8 deletions einops/layers/tensorflow.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
from typing import List, Optional

import tensorflow as tf
from tensorflow.keras.layers import Layer

from .._backends import UnknownSize
from . import RearrangeMixin, ReduceMixin
from ._weighted_einsum import WeightedEinsumMixin
from ..einops import TransformRecipe, _reconstruct_from_shape_uncached

__author__ = 'Alex Rogozhnikov'


def _compute_output_shape(recipe: TransformRecipe, input_shape) -> List[Optional[int]]:
input_shape = [UnknownSize() if d is None else int(d) for d in input_shape]
init_shapes, reduced_axes, axes_reordering, added_axes, final_shape = \
_reconstruct_from_shape_uncached(recipe, input_shape)
output_shape: List[Optional[int]] = [None if isinstance(d, UnknownSize) else int(d) for d in final_shape]
return output_shape


class Rearrange(RearrangeMixin, Layer):
def compute_output_shape(self, input_shape):
input_shape = tuple(UnknownSize() if d.value is None else int(d) for d in input_shape)
init_shapes, reduced_axes, axes_reordering, final_shape = self.recipe().reconstruct_from_shape(input_shape)
final_shape = tuple(None if isinstance(d, UnknownSize) else int(d) for d in final_shape)
return final_shape
return _compute_output_shape(self.recipe(), input_shape)

def call(self, inputs):
return self._apply_recipe(inputs)
Expand All @@ -24,10 +32,7 @@ def get_config(self):

class Reduce(ReduceMixin, Layer):
def compute_output_shape(self, input_shape):
input_shape = tuple(UnknownSize() if d.value is None else int(d) for d in input_shape)
init_shapes, reduced_axes, axes_reordering, final_shape = self.recipe().reconstruct_from_shape(input_shape)
final_shape = tuple(None if isinstance(d, UnknownSize) else int(d) for d in final_shape)
return final_shape
return _compute_output_shape(self.recipe(), input_shape)

def call(self, inputs):
return self._apply_recipe(inputs)
Expand Down

0 comments on commit 3156a42

Please sign in to comment.