Skip to content

Commit

Permalink
fixed bug (#1414)
Browse files Browse the repository at this point in the history
  • Loading branch information
kieranricardo authored Apr 30, 2020
2 parents 0c03f7c + 36fbd17 commit 8a831f9
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 14 deletions.
7 changes: 5 additions & 2 deletions stellargraph/layer/appnp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.

import warnings
from tensorflow.keras.layers import Dense, Lambda, Dropout, Input, Layer
from tensorflow.keras.layers import Dense, Lambda, Dropout, Input, Layer, InputLayer
import tensorflow.keras.backend as K

from ..mapper import FullBatchGenerator
Expand Down Expand Up @@ -434,7 +434,10 @@ def propagate_model(self, base_model):
"APPNP does not currently support propagating a link model"
)

return self._tensors(multiplicity=1, feature_layers=base_model.layers[1:])
feature_layers = [
layer for layer in base_model.layers if not isinstance(layer, InputLayer)
]
return self._tensors(multiplicity=1, feature_layers=feature_layers)

node_model = deprecated_model_function(_node_model, "node_model")
link_model = deprecated_model_function(_link_model, "link_model")
Expand Down
13 changes: 1 addition & 12 deletions tests/layer/test_appnp.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,18 +220,7 @@ def test_APPNP_apply_propagate_model_dense():
assert preds_1 == pytest.approx(preds_2)


@pytest.mark.parametrize(
"model_type",
[
pytest.param(
"sequential",
marks=pytest.mark.xfail(
reason="https://github.com/stellargraph/stellargraph/issues/1213"
),
),
"model",
],
)
@pytest.mark.parametrize("model_type", ["sequential", "model"])
def test_APPNP_propagate_model_matches_manual(model_type):
dense_size = 5

Expand Down

0 comments on commit 8a831f9

Please sign in to comment.