Skip to content

Commit

Permalink
Refactor, remove unnecessary flags, switched to absl.logging and absl…
Browse files Browse the repository at this point in the history
….flag
  • Loading branch information
amir-abdi committed Oct 10, 2018
1 parent 4b89d1e commit dbae300
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 820 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
.ipynb_checkpoints
*.h5
*.hdf5
*.json
*.pb
*.yml
*.ckpt
Expand Down
286 changes: 142 additions & 144 deletions keras_to_tensorflow.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@

# coding: utf-8

# In[ ]:

"""
Copyright (c) 2017, by the Authors: Amir H. Abdi
This software is freely available under the MIT Public License.
Expand All @@ -12,149 +7,152 @@
which is saved using model.save('kerasmodel_weight_file'),
to the freezed .pb tensorflow weight file which holds both the
network architecture and its associated weights.
""";


# In[ ]:

'''
Input arguments:
num_output: this value has nothing to do with the number of classes, batch_size, etc.,
and it is mostly equal to 1. If the network is a **multi-stream network**
(forked network with multiple outputs), set the value to the number of outputs.
quantize: if set to True, use the quantize feature of Tensorflow
(https://www.tensorflow.org/performance/quantization) [default: False]
use_theano: Thaeno and Tensorflow implement convolution in different ways.
When using Keras with Theano backend, the order is set to 'channels_first'.
This feature is not fully tested, and doesn't work with quantizization [default: False]
input_fld: directory holding the keras weights file [default: .]
output_fld: destination directory to save the tensorflow files [default: .]
input_model_file: name of the input weight file [default: 'model.h5']
output_model_file: name of the output weight file [default: args.input_model_file + '.pb']
graph_def: if set to True, will write the graph definition as an ascii file [default: False]
output_graphdef_file: if graph_def is set to True, the file name of the
graph definition [default: model.ascii]
output_node_prefix: the prefix to use for output nodes. [default: output_node]
'''


# Parse input arguments

# In[ ]:

import argparse
parser = argparse.ArgumentParser(description='set input arguments')
parser.add_argument('-input_fld', action="store",
dest='input_fld', type=str, default='.')
parser.add_argument('-output_fld', action="store",
dest='output_fld', type=str, default='')
parser.add_argument('-input_model_file', action="store",
dest='input_model_file', type=str, default='model.h5')
parser.add_argument('-output_model_file', action="store",
dest='output_model_file', type=str, default='')
parser.add_argument('-output_graphdef_file', action="store",
dest='output_graphdef_file', type=str, default='model.ascii')
parser.add_argument('-num_outputs', action="store",
dest='num_outputs', type=int, default=1)
parser.add_argument('-graph_def', action="store",
dest='graph_def', type=bool, default=False)
parser.add_argument('-output_node_prefix', action="store",
dest='output_node_prefix', type=str, default='output_node')
parser.add_argument('-quantize', action="store",
dest='quantize', type=bool, default=False)
parser.add_argument('-theano_backend', action="store",
dest='theano_backend', type=bool, default=False)
parser.add_argument('-f')
args = parser.parse_args()
parser.print_help()
print('input args: ', args)

if args.theano_backend is True and args.quantize is True:
raise ValueError("Quantize feature does not work with theano backend.")


# initialize

# In[ ]:
"""

from keras.models import load_model
import tensorflow as tf
from pathlib import Path
from absl import app
from absl import flags
from absl import logging
import keras
from keras import backend as K

output_fld = args.input_fld if args.output_fld == '' else args.output_fld
if args.output_model_file == '':
args.output_model_file = str(Path(args.input_model_file).name) + '.pb'
Path(output_fld).mkdir(parents=True, exist_ok=True)
weight_file_path = str(Path(args.input_fld) / args.input_model_file)


# Load keras model and rename output

# In[ ]:

K.set_learning_phase(0)
if args.theano_backend:
K.set_image_data_format('channels_first')
else:
K.set_image_data_format('channels_last')

try:
net_model = load_model(weight_file_path)
except ValueError as err:
print('''Input file specified ({}) only holds the weights, and not the model defenition.
Save the model using mode.save(filename.h5) which will contain the network architecture
as well as its weights.
If the model is saved using model.save_weights(filename.h5), the model architecture is
expected to be saved separately in a json format and loaded prior to loading the weights.
Check the keras documentation for more details (https://keras.io/getting-started/faq/)'''
.format(weight_file_path))
raise err
num_output = args.num_outputs
pred = [None]*num_output
pred_node_names = [None]*num_output
for i in range(num_output):
pred_node_names[i] = args.output_node_prefix+str(i)
pred[i] = tf.identity(net_model.outputs[i], name=pred_node_names[i])
print('output nodes names are: ', pred_node_names)


# [optional] write graph definition in ascii

# In[ ]:

sess = K.get_session()

if args.graph_def:
f = args.output_graphdef_file
tf.train.write_graph(sess.graph.as_graph_def(), output_fld, f, as_text=True)
print('saved the graph definition in ascii format at: ', str(Path(output_fld) / f))


# convert variables to constants and save

# In[ ]:

from keras.models import model_from_json
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import graph_io
if args.quantize:
from tensorflow.tools.graph_transforms import TransformGraph
transforms = ["quantize_weights", "quantize_nodes"]
transformed_graph_def = TransformGraph(sess.graph.as_graph_def(), [], pred_node_names, transforms)
constant_graph = graph_util.convert_variables_to_constants(sess, transformed_graph_def, pred_node_names)
else:
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), pred_node_names)
graph_io.write_graph(constant_graph, output_fld, args.output_model_file, as_text=False)
print('saved the freezed graph (ready for inference) at: ', str(Path(output_fld) / args.output_model_file))

K.set_learning_phase(0)
FLAGS = flags.FLAGS

flags.DEFINE_string('input_model', None, 'Path to the input model.')
flags.DEFINE_string('input_model_json', None, 'Path to the input model'
'architecture in json format.')
flags.DEFINE_string('output_model', None, 'Path where the converted model will'
'be stored.')
# flags.DEFINE_integer('num_output_nodes', 1,
# 'Number of outputs the network produces. Most networks '
# 'has only one output, while others can have a forked '
# 'architecture and produce multiple outputs.')
flags.DEFINE_boolean('save_graph_def', False,
'Whether to save the graphdef.pbtxt file which contains '
'the graph definition in ASCII format.')
flags.DEFINE_string('output_nodes_prefix', None,
'If set, the output nodes will be renamed to '
'`output_nodes_prefix`+i, where `i` will numerate the '
'number of of output nodes of the network.')
flags.DEFINE_boolean('quantize', False,
'If set, the resultant TensorFlow graph weights will be '
'converted from float into eight-bit equivalents. See '
'documentation here: '
'https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms')
flags.DEFINE_boolean('channels_first', False,
'Whether channels are the first dimension of a tensor. '
'The default is TensorFlow behaviour where channels are '
'the last dimension.')

flags.mark_flag_as_required('input_model')
flags.mark_flag_as_required('output_model')


def load_model(input_model_path, input_json_path):
model = None
# TODO(amirabdi): enable loading weights and json files separately
try:
model = keras.models.load_model(input_model_path)
return model
except FileNotFoundError as err:
logging.error('Input mode file (%s) does not exist.', FLAGS.input_model)
raise err
except ValueError as wrong_file_err:
if input_json_path and Path(input_json_path).exists():
try:
model = model_from_json(open(str(input_json_path)).read())
model.load_weights(input_model_path)
return model
except Exception as err:
logging.error("Couldn't load model from json.")
raise err
else:
logging.error(
'Input file specified only holds the weights, and not '
'the model definition. Save the model using '
'model.save(filename.h5) which will contain the network '
'architecture as well as its weights. If the model is '
'saved using model.save_weights(filename), the flag '
'input_model_json should also be set to the '
'architecture which is exported separately in a '
'json format. Check the keras documentation for more details '
'(https://keras.io/getting-started/faq/)')
raise wrong_file_err


def main(args):
# If output_model path is relative, make it absolute
output_model = FLAGS.output_model
if str(Path(output_model).parent) == '.':
output_model = str((Path.cwd() / output_model))

output_fld = Path(output_model).parent
output_model_pbtxt_name = Path(output_model).stem + '.pbtxt'
output_model_name = Path(output_model).name

# Create output directory if it does not exist
Path(output_model).parent.mkdir(parents=True, exist_ok=True)

if FLAGS.channels_first:
K.set_image_data_format('channels_first')
else:
K.set_image_data_format('channels_last')

model = load_model(FLAGS.input_model, FLAGS.input_model_json)

# TODO(amirabdi): Support networks with multiple inputs


orig_output_node_names = [node.op.name for node in model.outputs]
if FLAGS.output_nodes_prefix:
num_output = len(orig_output_node_names)
pred = [None] * num_output
converted_output_node_names = [None] * num_output

# Create dummy tf nodes to rename output
for i in range(num_output):
converted_output_node_names[i] = '{}{}'.format(
FLAGS.output_nodes_prefix, i)
pred[i] = tf.identity(model.outputs[i],
name=converted_output_node_names[i])
else:
converted_output_node_names = orig_output_node_names
logging.info('Converted output node names are: %s',
str(converted_output_node_names))

sess = K.get_session()
if FLAGS.save_graph_def:
tf.train.write_graph(sess.graph.as_graph_def(), str(output_fld),
output_model_pbtxt_name, as_text=True)
logging.info('Saved the graph definition in ascii format at %s',
str(Path(output_fld) / output_model_pbtxt_name))

if FLAGS.quantize:
from tensorflow.tools.graph_transforms import TransformGraph
transforms = ["quantize_weights", "quantize_nodes"]
transformed_graph_def = TransformGraph(sess.graph.as_graph_def(), [],
converted_output_node_names,
transforms)
constant_graph = graph_util.convert_variables_to_constants(
sess,
transformed_graph_def,
converted_output_node_names)
else:
constant_graph = graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(),
converted_output_node_names)

graph_io.write_graph(constant_graph, str(output_fld), output_model_name,
as_text=False)
logging.info('Saved the freezed graph at %s',
str(Path(output_fld) / output_model_name))


if __name__ == "__main__":
app.run(main)
Loading

0 comments on commit dbae300

Please sign in to comment.