
Public API to force load custom ops #1151

Closed
@guillaumekln

Description

Currently, it is inconvenient to load in Python a SavedModel that includes Addons custom ops. Consider the example below:

  • save.py
```python
import tensorflow as tf
import tensorflow_addons as tfa

class Model(tf.keras.Model):
    @tf.function(input_signature=(tf.TensorSpec(shape=[None, 32], dtype=tf.float32),))
    def call(self, x):
        return tfa.activations.gelu(x)

model = Model()
tf.saved_model.save(model, '/tmp/model', signatures=model.call)
```
  • load.py
```python
import tensorflow as tf

# Fails: the Addons custom ops used by the model are not registered yet.
tf.saved_model.load("/tmp/model")
```

The load fails because the Addons custom ops are not registered with the TensorFlow runtime. This is expected: the custom op libraries must first be loaded with `tf.load_op_library`.

However, since the lazy loading work in #855, forcing this op registration has become harder. For this model, the user has to run the following, which relies on internal APIs:

```python
from tensorflow_addons.activations.gelu import _activation_so
_activation_so.ops
```
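Why touching `_activation_so.ops` forces registration can be illustrated with a minimal lazy wrapper. This is a sketch, not the Addons source: the class name `LazySO`, its `loader` parameter, and the `.so` path are assumptions made for illustration, and a stand-in loader replaces `tf.load_op_library` so the example runs without TensorFlow.

```python
from typing import Any, Callable, List


class LazySO:
    """Defers loading a custom-op shared library until `.ops` is first read."""

    def __init__(self, relative_path: str, loader: Callable[[str], Any]) -> None:
        self.relative_path = relative_path
        self._loader = loader  # in real code this would be tf.load_op_library
        self._ops: Any = None
        self._is_loaded = False

    @property
    def ops(self) -> Any:
        # First access loads the library; loading registers its ops with the
        # runtime as a side effect. Later accesses reuse the cached handle.
        if not self._is_loaded:
            self._ops = self._loader(self.relative_path)
            self._is_loaded = True
        return self._ops


loaded: List[str] = []


def fake_loader(path: str) -> str:
    """Stand-in for tf.load_op_library: records the call and returns a token."""
    loaded.append(path)
    return f"<ops from {path}>"


# Creating the object (what happens at import time) loads nothing.
# The path below is hypothetical.
activation_so = LazySO("custom_ops/activations/_activation_ops.so", fake_loader)
assert loaded == []

# Touching `.ops` forces the load, like `_activation_so.ops` above.
activation_so.ops
activation_so.ops  # cached: the loader is not called a second time
print(loaded)  # ['custom_ops/activations/_activation_ops.so']
```

This is what makes the workaround depend on internals: the user must know which private module holds which lazy library object.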

If the custom ops are not loaded during the main import (i.e. during `import tensorflow_addons`), then the package should expose a public API that registers all custom ops.
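One possible shape for such a public API, sketched under assumptions: every custom op library is listed in a single module-level table, and one public function loads each entry once. The names `_CUSTOM_OP_LIBRARIES` and `register_all_custom_ops`, and the `.so` paths, are hypothetical and not existing Addons API. The loader is passed in so the sketch runs without TensorFlow; a real package would pass `tf.load_op_library`.

```python
from typing import Any, Callable, Dict, List

# Hypothetical table of every custom-op library the package ships
# (library name -> shared-object path). Paths are illustrative only.
_CUSTOM_OP_LIBRARIES: Dict[str, str] = {
    "activations": "custom_ops/activations/_activation_ops.so",
    "image": "custom_ops/image/_image_ops.so",
}

# Handles of libraries that have already been loaded, keyed by name.
_loaded_libraries: Dict[str, Any] = {}


def register_all_custom_ops(loader: Callable[[str], Any]) -> Dict[str, Any]:
    """Loads every known custom-op library once and returns the handles.

    Loading a library registers its ops with the runtime, so calling this
    before tf.saved_model.load() would let the SavedModel resolve them.
    Libraries that are already loaded are skipped on later calls.
    """
    for name, path in _CUSTOM_OP_LIBRARIES.items():
        if name not in _loaded_libraries:
            _loaded_libraries[name] = loader(path)
    return dict(_loaded_libraries)


# Usage with a stand-in loader that records each call.
calls: List[str] = []


def recording_loader(path: str) -> str:
    calls.append(path)
    return f"<ops from {path}>"


handles = register_all_custom_ops(recording_loader)
register_all_custom_ops(recording_loader)  # second call loads nothing new
print(sorted(handles))  # ['activations', 'image']
print(len(calls))  # 2
```

With a function like this, `load.py` would only need one call before `tf.saved_model.load`, and adding a new op library to the table would be enough to cover it.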

Any thoughts?
