forked from amir-abdi/keras_to_tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
removed output_node_names_of_input_networ
- Loading branch information
Showing
2 changed files
with
181 additions
and
16 deletions.
There are no files selected for viewing
173 changes: 173 additions & 0 deletions
173
.ipynb_checkpoints/keras_to_tensorflow-checkpoint.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Set parameters" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"input_fld = 'input_fld_path'\n", | ||
"weight_file = 'kerasmodel_file_name located inside input_fld'\n", | ||
"num_output = 1\n", | ||
"write_graph_def_ascii_flag = True\n", | ||
"prefix_output_node_names_of_final_network = 'output_node'\n", | ||
"output_graph_name = 'constant_graph_weights.pb'" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# initialize" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"Using TensorFlow backend.\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from keras.models import load_model\n", | ||
"import tensorflow as tf\n", | ||
"import os\n", | ||
"import os.path as osp\n", | ||
"\n", | ||
"output_fld = input_fld + 'tensorflow_model/'\n", | ||
"if not os.path.isdir(output_fld):\n", | ||
" os.mkdir(output_fld)\n", | ||
"weight_file_path = osp.join(input_fld, weight_file)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Load keras model and rename output" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"output nodes names are: ['output_node0']\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"net_model = load_model(weight_file_path)\n", | ||
"\n", | ||
"pred = [None]*num_output\n", | ||
"pred_node_names = [None]*num_output\n", | ||
"for i in range(num_output):\n", | ||
" pred_node_names[i] = prefix_output_node_names_of_final_network+str(i)\n", | ||
" pred[i] = tf.identity(net_model.output[i], name=pred_node_names[i])\n", | ||
"print('output nodes names are: ', pred_node_names)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### [optional] write graph definition in ascii" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"saved the graph definition in ascii format at: /home/amir/deep-batch/trained/KerasEchoQualityMultiStream/snapshots/freeze_uniformValid_bn/s/tensorflow_model/only_the_graph_def.pb.ascii\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from keras import backend as K\n", | ||
"sess = K.get_session()\n", | ||
"\n", | ||
"if write_graph_def_ascii_flag:\n", | ||
" f = 'only_the_graph_def.pb.ascii'\n", | ||
" tf.train.write_graph(sess.graph.as_graph_def(), output_fld, f, as_text=True)\n", | ||
" print('saved the graph definition in ascii format at: ', osp.join(output_fld, f))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### convert variables to constants and save" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": { | ||
"scrolled": true | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"INFO:tensorflow:Froze 27 variables.\n", | ||
"Converted 27 variables to const ops.\n", | ||
"saved the constant graph (ready for inference) at: /home/amir/deep-batch/trained/KerasEchoQualityMultiStream/snapshots/freeze_uniformValid_bn/s/tensorflow_model/constant_graph_weights.pb\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from tensorflow.python.framework import graph_util\n", | ||
"from tensorflow.python.framework import graph_io\n", | ||
"constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), pred_node_names)\n", | ||
"graph_io.write_graph(constant_graph, output_fld, output_graph_name, as_text=False)\n", | ||
"print('saved the constant graph (ready for inference) at: ', osp.join(output_fld, output_graph_name))" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.5.0" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 1 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters