
Loading model fails: can only concatenate tuple #19776

Closed
sebplorenz opened this issue May 29, 2024 · 5 comments
sebplorenz commented May 29, 2024

Hi,
I'm trying to save and load the model from this example: https://keras.io/examples/rl/deep_q_network_breakout/

Saving the model works, but when I load it I get the following error:

Exception encountered: Exception encountered when calling Lambda.call().

can only concatenate tuple (not "TrackedList") to tuple

Arguments received by Lambda.call():
  • args=('<KerasTensor shape=(None, 4, 84, 84), dtype=float32, sparse=False, name=input_layer>',)
  • kwargs={'mask': 'None'}
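The error message points at a type mismatch during shape computation. A minimal sketch of the failure mode (not Keras internals verbatim, just an illustration): the batch dimension is a tuple, the deserialized `output_shape` arrives as a list-like tracked container rather than a tuple, and concatenating the two raises `TypeError`. Coercing with `tuple()` avoids it:

```python
# Illustrative sketch of the "can only concatenate tuple" failure and
# one plausible fix; variable names here are assumptions, not Keras code.
batch_dims = (None,)
output_shape = [84, 84, 4]  # stands in for the deserialized TrackedList

try:
    full_shape = batch_dims + output_shape  # tuple + list -> TypeError
except TypeError as exc:
    print(exc)  # e.g. can only concatenate tuple (not "list") to tuple

# Coercing the container to a tuple before concatenating works:
full_shape = batch_dims + tuple(output_shape)
print(full_shape)  # (None, 84, 84, 4)
```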

I've created a small script to reproduce:

import keras

def create_q_model():
    # Network defined by the Deepmind paper
    return keras.Sequential(
        [
            keras.layers.Lambda(
                lambda tensor: keras.ops.transpose(tensor, [0, 2, 3, 1]),
                output_shape=(84, 84, 4),
                input_shape=(4, 84, 84),
            ),
            # Convolutions on the frames on the screen
            keras.layers.Conv2D(32, 8, strides=4, activation="relu"),
            keras.layers.Conv2D(64, 4, strides=2, activation="relu"),
            keras.layers.Conv2D(64, 3, strides=1, activation="relu"),
            keras.layers.Flatten(),
            keras.layers.Dense(512, activation="relu"),
            keras.layers.Dense(4, activation="linear"),
        ]
    )

def save(model):
    keras.saving.save_model(model, './my_model.keras')

def load():
    return keras.models.load_model('my_model.keras', safe_mode=False)

model = create_q_model()
save(model)

# the following line fails
loaded_model = load()
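On Keras versions affected by this bug, one workaround sketch (an assumption on my part, not a confirmed fix from the maintainers) is to avoid the `Lambda` layer entirely: the built-in `Permute` layer performs the same transpose, is fully serializable, and needs no `output_shape` argument. `Permute` dims are 1-indexed and exclude the batch axis, so `(2, 3, 1)` maps `(4, 84, 84)` to `(84, 84, 4)`, matching `keras.ops.transpose(tensor, [0, 2, 3, 1])`:

```python
import keras

# Workaround sketch: replace the Lambda transpose with Permute so that
# loading never goes through Lambda's output_shape deserialization.
def create_q_model():
    return keras.Sequential(
        [
            keras.layers.Input(shape=(4, 84, 84)),
            keras.layers.Permute((2, 3, 1)),  # (4, 84, 84) -> (84, 84, 4)
            keras.layers.Conv2D(32, 8, strides=4, activation="relu"),
            keras.layers.Conv2D(64, 4, strides=2, activation="relu"),
            keras.layers.Conv2D(64, 3, strides=1, activation="relu"),
            keras.layers.Flatten(),
            keras.layers.Dense(512, activation="relu"),
            keras.layers.Dense(4, activation="linear"),
        ]
    )
```

Saving this model and reloading it with `keras.models.load_model` should succeed even without `safe_mode=False`, since no lambda needs to be deserialized.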
Grvzard added a commit to Grvzard/keras that referenced this issue Jun 1, 2024
Grvzard added a commit to Grvzard/keras that referenced this issue Jun 2, 2024
fchollet pushed a commit that referenced this issue Jun 2, 2024

* Fix `output_shape` deserialization in Lambda Layer (#19776)
@stylesaniswi

can only concatenate tuple (not "TrackedList") to tuple

Arguments received by Lambda.call():
  • args=('<KerasTensor shape=(None, 48, 48, 1), dtype=float32, sparse=False, name=input_layer_4>',)
  • kwargs={'mask': 'None'}

I have a model that works on a previous version of Keras, but due to GPU access constraints I need to stick with Keras 3.x and TF 2.16.
To be honest, I'm new to this; please help me resolve the problem. This is the model I saved, but loading it raises the error above:

def create_model(norm, start_filters=32):
    ir069 = tf.keras.layers.Input(shape=(192, 192, 1)) 
    ir107 = tf.keras.layers.Input(shape=(192, 192, 1)) 
    lght = tf.keras.layers.Input(shape=(48, 48, 1))
    inputs = [ir069, ir107, lght]

    # Normalize inputs
    ir069_norm = tf.keras.layers.Lambda(lambda x, mu, scale: (x - mu) / scale,
                                        arguments={'mu': norm['ir069']['shift'], 'scale': norm['ir069']['scale']},
                                        output_shape=(192, 192, 1))(ir069)
    ir107_norm = tf.keras.layers.Lambda(lambda x, mu, scale: (x - mu) / scale,
                                        arguments={'mu': norm['ir107']['shift'], 'scale': norm['ir107']['scale']},
                                        output_shape=(192, 192, 1))(ir107)
    lght_norm = tf.keras.layers.Lambda(lambda x, mu, scale: (x - mu) / scale,
                                       arguments={'mu': norm['lght']['shift'], 'scale': norm['lght']['scale']},
                                       output_shape=(48, 48, 1))(lght)

    # Reshape lght into 192
    lght_res = tf.keras.layers.Lambda(lambda t: tf.image.resize(t, (192, 192)), output_shape=(192, 192, 1))(lght_norm)

    # Concatenate all inputs
    x_inp = tf.keras.layers.Concatenate(axis=-1)([ir069_norm, ir107_norm, lght_res])

    encoder0_pool, encoder0 = encoder_block(x_inp, start_filters)
    encoder1_pool, encoder1 = encoder_block(encoder0_pool, start_filters * 2)
    encoder2_pool, encoder2 = encoder_block(encoder1_pool, start_filters * 4)
    encoder3_pool, encoder3 = encoder_block(encoder2_pool, start_filters * 8)
    center = conv_block(encoder3_pool, start_filters * 32)
    decoder3 = decoder_block(center, encoder3, start_filters * 8)
    decoder2 = decoder_block(decoder3, encoder2, start_filters * 6)
    decoder1 = decoder_block(decoder2, encoder1, start_filters * 4)
    decoder0 = decoder_block(decoder1, encoder0, start_filters * 2)
    decoder00 = decoder_block(decoder0, None, start_filters)
    output = tf.keras.layers.Conv2D(1, (1, 1), padding='same', activation='linear', name='output_layer')(decoder00)

    return inputs, output

james77777778 pushed a commit to james77777778/keras that referenced this issue Jun 15, 2024
…keras-team#19785)

* Fix `output_shape` deserialization in Lambda Layer (keras-team#19776)
@dhantule
Contributor

Hi @sebplorenz, thanks for reporting this. This issue no longer reproduces in the latest Keras version. Attaching gist


This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

@github-actions github-actions bot added the stale label Dec 11, 2024

This issue was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further.


6 participants