From ddab19955aafd1ac805eeeb262b607788bff838f Mon Sep 17 00:00:00 2001 From: bddppq Date: Sun, 10 Sep 2017 18:22:53 -0700 Subject: [PATCH] Python3 compatibility fixes (#4) --- onnx_caffe2/backend.py | 4 +++- onnx_caffe2/frontend.py | 2 +- tests/caffe2_ref_test.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/onnx_caffe2/backend.py b/onnx_caffe2/backend.py index 963d6ce..b25cd98 100644 --- a/onnx_caffe2/backend.py +++ b/onnx_caffe2/backend.py @@ -315,7 +315,9 @@ def prepare(cls, predict_graph, device='CPU', for init_tensor in predict_graph.initializer: workspace.FeedBlob(init_tensor.name, to_array(init_tensor)) workspace.RunNetOnce(init_net) - uninitialized = filter(lambda x:not workspace.HasBlob(x), predict_net.external_input) + uninitialized = [x + for x in predict_net.external_input + if not workspace.HasBlob(x)] return Caffe2Rep(init_net, predict_net, device, tmp_ws, uninitialized) @classmethod diff --git a/onnx_caffe2/frontend.py b/onnx_caffe2/frontend.py index f106a94..2e2ad30 100644 --- a/onnx_caffe2/frontend.py +++ b/onnx_caffe2/frontend.py @@ -21,7 +21,7 @@ _blacklist_caffe2_args = {'order', 'global_pooling'} # expected argument values -_expected_arg_values = {'order': ['NCHW'], 'global_pooling': [1]} +_expected_arg_values = {'order': [b'NCHW'], 'global_pooling': [1]} _renamed_args = { 'Squeeze': {'dims': 'axes'}, diff --git a/tests/caffe2_ref_test.py b/tests/caffe2_ref_test.py index 74b2a53..1ccc509 100644 --- a/tests/caffe2_ref_test.py +++ b/tests/caffe2_ref_test.py @@ -96,7 +96,7 @@ def test_initializer(self): name="test_initializer", inputs=["X", "Y", "weight"], outputs=["W"], - initializer=[helper.make_tensor("weight", onnx_pb2.TensorProto.FLOAT, [2, 2], weight.flatten())] + initializer=[helper.make_tensor("weight", onnx_pb2.TensorProto.FLOAT, [2, 2], weight.flatten().astype(float))] ) def sigmoid(x): return 1 / (1 + np.exp(-x))