From b0ad7f759a234748619845863d27c07d8dcadb00 Mon Sep 17 00:00:00 2001 From: Dmitry Kurtaev Date: Tue, 18 Sep 2018 09:04:28 +0300 Subject: [PATCH] Import tensorflow to create text graphs if import cv is failed --- samples/dnn/tf_text_graph_common.py | 23 +++++++++++++++++++++++ samples/dnn/tf_text_graph_faster_rcnn.py | 3 +-- samples/dnn/tf_text_graph_mask_rcnn.py | 3 +-- samples/dnn/tf_text_graph_ssd.py | 7 +++---- 4 files changed, 28 insertions(+), 8 deletions(-) diff --git a/samples/dnn/tf_text_graph_common.py b/samples/dnn/tf_text_graph_common.py index d4d202e68297..564c572d58b4 100644 --- a/samples/dnn/tf_text_graph_common.py +++ b/samples/dnn/tf_text_graph_common.py @@ -302,3 +302,26 @@ def removeUnusedNodesAndAttrs(to_remove, graph_def): for i in reversed(range(len(node.input))): if node.input[i] in removedNodes: del node.input[i] + + +def writeTextGraph(modelPath, outputPath, outNodes): + try: + import cv2 as cv + + cv.dnn.writeTextGraph(modelPath, outputPath) + except: + import tensorflow as tf + from tensorflow.tools.graph_transforms import TransformGraph + + with tf.gfile.FastGFile(modelPath, 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + + graph_def = TransformGraph(graph_def, ['image_tensor'], outNodes, ['sort_by_execution_order']) + + for node in graph_def.node: + if node.op == 'Const': + if 'value' in node.attr: + del node.attr['value'] + + tf.train.write_graph(graph_def, "", outputPath, as_text=True) diff --git a/samples/dnn/tf_text_graph_faster_rcnn.py b/samples/dnn/tf_text_graph_faster_rcnn.py index 368edaa6fa59..a6db8dcd4a6a 100644 --- a/samples/dnn/tf_text_graph_faster_rcnn.py +++ b/samples/dnn/tf_text_graph_faster_rcnn.py @@ -1,6 +1,5 @@ import argparse import numpy as np -import cv2 as cv from tf_text_graph_common import * @@ -42,7 +41,7 @@ def createFasterRCNNGraph(modelPath, configPath, outputPath): print('Features stride: %f' % features_stride) # Read the graph. - cv.dnn.writeTextGraph(modelPath, outputPath) + writeTextGraph(modelPath, outputPath, ['num_detections', 'detection_scores', 'detection_boxes', 'detection_classes']) graph_def = parseTextGraph(outputPath) removeIdentity(graph_def) diff --git a/samples/dnn/tf_text_graph_mask_rcnn.py b/samples/dnn/tf_text_graph_mask_rcnn.py index b80a0fc4dfcc..b92d4623b8e5 100644 --- a/samples/dnn/tf_text_graph_mask_rcnn.py +++ b/samples/dnn/tf_text_graph_mask_rcnn.py @@ -1,6 +1,5 @@ import argparse import numpy as np -import cv2 as cv from tf_text_graph_common import * parser = argparse.ArgumentParser(description='Run this script to get a text graph of ' @@ -48,7 +47,7 @@ print('Features stride: %f' % features_stride) # Read the graph. -cv.dnn.writeTextGraph(args.input, args.output) +writeTextGraph(args.input, args.output, ['num_detections', 'detection_scores', 'detection_boxes', 'detection_classes', 'detection_masks']) graph_def = parseTextGraph(args.output) removeIdentity(graph_def) diff --git a/samples/dnn/tf_text_graph_ssd.py b/samples/dnn/tf_text_graph_ssd.py index eb4b33042cb9..5017dba7a74b 100644 --- a/samples/dnn/tf_text_graph_ssd.py +++ b/samples/dnn/tf_text_graph_ssd.py @@ -11,7 +11,6 @@ # See details and examples on the following wiki page: https://github.com/opencv/opencv/wiki/TensorFlow-Object-Detection-API import argparse from math import sqrt -import cv2 as cv from tf_text_graph_common import * def createSSDGraph(modelPath, configPath, outputPath): @@ -52,12 +51,12 @@ def createSSDGraph(modelPath, configPath, outputPath): print('Input image size: %dx%d' % (image_width, image_height)) # Read the graph. - cv.dnn.writeTextGraph(modelPath, outputPath) - graph_def = parseTextGraph(outputPath) - inpNames = ['image_tensor'] outNames = ['num_detections', 'detection_scores', 'detection_boxes', 'detection_classes'] + writeTextGraph(modelPath, outputPath, outNames) + graph_def = parseTextGraph(outputPath) + def getUnconnectedNodes(): unconnected = [] for node in graph_def.node: