Skip to content

Commit

Permalink
Fix import of Keras (keras-team#2420)
Browse files Browse the repository at this point in the history
* Fix import of Keras

* Fix import of Keras

* Fix import of Keras

* Fix Keras2 import
  • Loading branch information
sampathweb authored Apr 23, 2024
1 parent c60112e commit 60ebdf6
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 66 deletions.
9 changes: 8 additions & 1 deletion keras_cv/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,14 @@
- `random`: `keras.random` for Keras 3 or `keras_core.ops` for Keras 2.
"""
from keras_cv.backend import config # noqa: E402
from keras_cv.backend import keras # noqa: E402

if config.keras_3():
import keras # noqa: E402

keras.backend.name_scope = keras.name_scope
else:
import keras_cv.backend.keras2 as keras # noqa: E402

from keras_cv.backend import ops # noqa: E402
from keras_cv.backend import random # noqa: E402
from keras_cv.backend import tf_ops # noqa: E402
Expand Down
65 changes: 0 additions & 65 deletions keras_cv/backend/keras.py

This file was deleted.

59 changes: 59 additions & 0 deletions keras_cv/backend/keras2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import types

from tensorflow import keras # noqa: F403, F401
from tensorflow.keras import * # noqa: F403, F401

from keras_cv.backend import config # noqa: F403, F401

_KERAS_CORE_ALIASES = {
"utils->saving": [
"register_keras_serializable",
"deserialize_keras_object",
"serialize_keras_object",
"get_registered_object",
],
"models->saving": ["load_model"],
}

if not hasattr(keras, "saving"):
keras.saving = types.SimpleNamespace()

# add aliases
for key, value in _KERAS_CORE_ALIASES.items():
src, _, dst = key.partition("->")
src = src.split(".")
dst = dst.split(".")

src_mod, dst_mod = keras, keras

# navigate to where we want to alias the attributes
for mod in src:
src_mod = getattr(src_mod, mod)
for mod in dst:
dst_mod = getattr(dst_mod, mod)

# add an alias for each attribute
for attr in value:
if isinstance(attr, tuple):
src_attr, dst_attr = attr
else:
src_attr, dst_attr = attr, attr
attr_val = getattr(src_mod, src_attr)
setattr(dst_mod, dst_attr, attr_val)

# TF Keras doesn't have this rename.
keras.activations.silu = keras.activations.swish

0 comments on commit 60ebdf6

Please sign in to comment.