diff --git a/stellargraph/layer/appnp.py b/stellargraph/layer/appnp.py index 26d76b532..103cc4bec 100644 --- a/stellargraph/layer/appnp.py +++ b/stellargraph/layer/appnp.py @@ -15,7 +15,7 @@ # limitations under the License. import warnings -from tensorflow.keras.layers import Dense, Lambda, Dropout, Input, Layer +from tensorflow.keras.layers import Dense, Lambda, Dropout, Input, Layer, InputLayer import tensorflow.keras.backend as K from ..mapper import FullBatchGenerator @@ -434,7 +434,10 @@ def propagate_model(self, base_model): "APPNP does not currently support propagating a link model" ) - return self._tensors(multiplicity=1, feature_layers=base_model.layers[1:]) + feature_layers = [ + layer for layer in base_model.layers if not isinstance(layer, InputLayer) + ] + return self._tensors(multiplicity=1, feature_layers=feature_layers) node_model = deprecated_model_function(_node_model, "node_model") link_model = deprecated_model_function(_link_model, "link_model") diff --git a/tests/layer/test_appnp.py b/tests/layer/test_appnp.py index cf98e5323..0043e2980 100644 --- a/tests/layer/test_appnp.py +++ b/tests/layer/test_appnp.py @@ -220,18 +220,7 @@ def test_APPNP_apply_propagate_model_dense(): assert preds_1 == pytest.approx(preds_2) -@pytest.mark.parametrize( - "model_type", - [ - pytest.param( - "sequential", - marks=pytest.mark.xfail( - reason="https://github.com/stellargraph/stellargraph/issues/1213" - ), - ), - "model", - ], -) +@pytest.mark.parametrize("model_type", ["sequential", "model"]) def test_APPNP_propagate_model_matches_manual(model_type): dense_size = 5