Skip to content
This repository has been archived by the owner on Oct 13, 2021. It is now read-only.

Commit

Permalink
Merge branch 'master' of https://github.com/onnx/keras-onnx
Browse files Browse the repository at this point in the history
  • Loading branch information
jiafatom committed Apr 7, 2020
2 parents 64b12c9 + 53a485a commit 7ad6a10
Showing 1 changed file with 3 additions and 16 deletions.
19 changes: 3 additions & 16 deletions keras2onnx/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,13 @@
list_input_tensors, list_input_mask, list_output_mask,
list_output_tensors, list_input_shapes, list_output_shapes, on_parsing_keras_layer)

ALLOWED_SHARED_KERAS_TYPES = {
keras.layers.embeddings.Embedding,
}

def _find_node(nodes, name):
try:
opname = tsname_to_node(name)
return next(n_ for n_ in nodes if n_.name == opname)
except StopIteration:
return None


def _locate_inputs_by_node(node_list, varset):
inputs = {}
for n_ in node_list:
Expand Down Expand Up @@ -402,9 +397,9 @@ def _create_keras_nodelist(layer, inference_nodeset, out_node=None):
if out_node is not None and out_node.name not in \
[tsname_to_node(ts_.name) for ts_ in list_output_tensors(node_)]:
continue # this layer could be reused several times in the whole graph.
if any(ts_.op not in inference_nodeset for ts_ in list_output_tensors(node_)):
continue
newly.extend([ts_.op for ts_ in list_output_tensors(node_)])
for ts_ in list_output_tensors(node_):
if ts_.op in inference_nodeset:
newly.extend([ts_.op for ts_ in list_output_tensors(node_)])
ts_end |= set(list_input_tensors(node_))

for ts_ in list_input_mask(layer):
Expand Down Expand Up @@ -573,7 +568,6 @@ def _parse_graph_core(graph, keras_node_dict, topology, top_scope, output_names)
for n_ in model_outputs:
q_overall.put_nowait(n_)

visited_layers = set()
visited = set() # since the output could be shared among the successor nodes.
inference_nodeset = _build_inference_nodeset(graph, model_outputs)
keras_nodeset = _build_keras_nodeset(inference_nodeset, keras_node_dict)
Expand All @@ -586,13 +580,6 @@ def _parse_graph_core(graph, keras_node_dict, topology, top_scope, output_names)
layer_key_, model_ = _parse_nodes(graph, inference_nodeset, input_nodes, keras_node_dict, keras_nodeset,
node, nodes, varset, visited, q_overall)

# Only parse Keras layers once (allow certain shared classes)
if layer_key_ in visited_layers:
if not type(layer_key_) in ALLOWED_SHARED_KERAS_TYPES:
continue
else:
visited_layers.add(layer_key_)

if not nodes: # already processed by the _parse_nodes
continue

Expand Down

0 comments on commit 7ad6a10

Please sign in to comment.