Skip to content

Commit

Permalink
removed output_node_names_of_input_networ
Browse files Browse the repository at this point in the history
  • Loading branch information
amir-abdi committed Jul 2, 2017
1 parent 491a5d1 commit 5e27593
Show file tree
Hide file tree
Showing 2 changed files with 181 additions and 16 deletions.
173 changes: 173 additions & 0 deletions .ipynb_checkpoints/keras_to_tensorflow-checkpoint.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Set parameters"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"input_fld = 'input_fld_path'\n",
"weight_file = 'kerasmodel_file_name located inside input_fld'\n",
"num_output = 1\n",
"write_graph_def_ascii_flag = True\n",
"prefix_output_node_names_of_final_network = 'output_node'\n",
"output_graph_name = 'constant_graph_weights.pb'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# initialize"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
}
],
"source": [
"from keras.models import load_model\n",
"import tensorflow as tf\n",
"import os\n",
"import os.path as osp\n",
"\n",
"output_fld = input_fld + 'tensorflow_model/'\n",
"if not os.path.isdir(output_fld):\n",
" os.mkdir(output_fld)\n",
"weight_file_path = osp.join(input_fld, weight_file)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load keras model and rename output"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"output nodes names are: ['output_node0']\n"
]
}
],
"source": [
"net_model = load_model(weight_file_path)\n",
"\n",
"pred = [None]*num_output\n",
"pred_node_names = [None]*num_output\n",
"for i in range(num_output):\n",
" pred_node_names[i] = prefix_output_node_names_of_final_network+str(i)\n",
" pred[i] = tf.identity(net_model.output[i], name=pred_node_names[i])\n",
"print('output nodes names are: ', pred_node_names)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### [optional] write graph definition in ascii"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"saved the graph definition in ascii format at: /home/amir/deep-batch/trained/KerasEchoQualityMultiStream/snapshots/freeze_uniformValid_bn/s/tensorflow_model/only_the_graph_def.pb.ascii\n"
]
}
],
"source": [
"from keras import backend as K\n",
"sess = K.get_session()\n",
"\n",
"if write_graph_def_ascii_flag:\n",
" f = 'only_the_graph_def.pb.ascii'\n",
" tf.train.write_graph(sess.graph.as_graph_def(), output_fld, f, as_text=True)\n",
" print('saved the graph definition in ascii format at: ', osp.join(output_fld, f))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### convert variables to constants and save"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Froze 27 variables.\n",
"Converted 27 variables to const ops.\n",
"saved the constant graph (ready for inference) at: /home/amir/deep-batch/trained/KerasEchoQualityMultiStream/snapshots/freeze_uniformValid_bn/s/tensorflow_model/constant_graph_weights.pb\n"
]
}
],
"source": [
"from tensorflow.python.framework import graph_util\n",
"from tensorflow.python.framework import graph_io\n",
"constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), pred_node_names)\n",
"graph_io.write_graph(constant_graph, output_fld, output_graph_name, as_text=False)\n",
"print('saved the constant graph (ready for inference) at: ', osp.join(output_fld, output_graph_name))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.0"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
24 changes: 8 additions & 16 deletions keras_to_tensorflow.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
"source": [
"input_fld = 'input_fld_path'\n",
"weight_file = 'kerasmodel_file_name located inside input_fld'\n",
"output_node_names_of_input_network = [\"pred0\"] \n",
"num_output = 1\n",
"write_graph_def_ascii_flag = True\n",
"output_node_names_of_final_network = 'output_node'\n",
"prefix_output_node_names_of_final_network = 'output_node'\n",
"output_graph_name = 'constant_graph_weights.pb'"
]
},
Expand All @@ -33,9 +33,7 @@
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"metadata": {},
"outputs": [
{
"name": "stderr",
Expand Down Expand Up @@ -67,9 +65,7 @@
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"metadata": {},
"outputs": [
{
"name": "stdout",
Expand All @@ -82,11 +78,10 @@
"source": [
"net_model = load_model(weight_file_path)\n",
"\n",
"num_output = len(output_node_names_of_input_network)\n",
"pred = [None]*num_output\n",
"pred_node_names = [None]*num_output\n",
"for i in range(num_output):\n",
" pred_node_names[i] = output_node_names_of_final_network+str(i)\n",
" pred_node_names[i] = prefix_output_node_names_of_final_network+str(i)\n",
" pred[i] = tf.identity(net_model.output[i], name=pred_node_names[i])\n",
"print('output nodes names are: ', pred_node_names)"
]
Expand All @@ -101,9 +96,7 @@
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"metadata": {},
"outputs": [
{
"name": "stdout",
Expand Down Expand Up @@ -134,7 +127,6 @@
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
Expand Down Expand Up @@ -173,9 +165,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
"version": "3.5.0"
}
},
"nbformat": 4,
"nbformat_minor": 0
"nbformat_minor": 1
}

0 comments on commit 5e27593

Please sign in to comment.