Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable Conv fusion optimizations in optimizeForIdeep #9255

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix test case of conv_fusion op
Signed-off-by: Gu, Jinghui <jinghui.gu@intel.com>
  • Loading branch information
gujinghui committed Jul 12, 2018
commit 1d801edde353b8f14b89bfcd820f2d1ec50290ff
54 changes: 52 additions & 2 deletions caffe2/python/ideep/conv_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import hypothesis.strategies as st
from hypothesis import given, settings
import numpy as np
from caffe2.proto import caffe2_pb2
from caffe2.python import core, workspace
from caffe2.python.transformations import optimizeForIDEEP
import caffe2.python.hypothesis_test_util as hu
import caffe2.python.ideep_test_util as mu

Expand Down Expand Up @@ -63,12 +65,60 @@ def test_depthwise_convolution(self, batch_size, gc, dc):
pad=0,
kernel=1,
group=4,
device_option=dc[0]
)
op1 = core.CreateOperator(
"Conv",
["X", "w", "b"],
["Y"],
stride=1,
pad=0,
kernel=1,
group=4,
device_option=dc[1]
)
X = np.random.rand(batch_size, 544, 14, 14).astype(np.float32)
w = np.random.rand(544, 136, 1, 1).astype(np.float32)
b = np.random.rand(544).astype(np.float32)
inputs = [X, w, b]
self.assertDeviceChecks(dc, op, inputs, [0])

workspace.SwitchWorkspace("_device_check_", True)
workspace.FeedBlob('X', X, dc[0])
workspace.FeedBlob('w', w, dc[0])
workspace.FeedBlob('b', b, dc[0])
workspace.RunOperatorOnce(op)
Y0 = workspace.FetchBlob('Y')

workspace.ResetWorkspace()
workspace.FeedBlob('X', X, dc[1])
workspace.FeedBlob('w', w, dc[1])
workspace.FeedBlob('b', b, dc[1])
net = core.Net("net")
old_net = caffe2_pb2.NetDef()
old_net.op.extend([op1])
net.Proto().CopyFrom(old_net)
optimizeForIDEEP(net)
workspace.RunOperatorOnce(net.Proto().op[0])
Y1 = workspace.FetchBlob('Y')

if not np.allclose(Y0, Y1, atol=0.01, rtol=0.01):
print(Y1.flatten())
print(Y0.flatten())
print(np.max(np.abs(Y1 - Y0)))
self.assertTrue(False)

workspace.ResetWorkspace()
workspace.FeedBlob('X', X, dc[1])
workspace.FeedBlob('w', w, dc[1])
workspace.FeedBlob('b', b, dc[1])
workspace.RunOperatorOnce(op1)
Y2 = workspace.FetchBlob('Y')

if not np.allclose(Y0, Y2, atol=0.01, rtol=0.01):
print(Y2.flatten())
print(Y0.flatten())
print(np.max(np.abs(Y2 - Y0)))
self.assertTrue(False)


# Entry point: run this test module's cases via unittest's CLI runner
# when executed directly as a script (not when imported).
if __name__ == "__main__":
unittest.main()
128 changes: 112 additions & 16 deletions caffe2/python/ideep/convfusion_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_convolution_relu_fusion(self, stride, pad, kernel, size,
device_option=dc[0]
)

# Manual fusion
# Manual fusion for Conv + ReLU
conv_fusion = core.CreateOperator(
"ConvFusion",
["X1", "w1", "b1"] if use_bias else ["X1", "w1"],
Expand All @@ -60,21 +60,6 @@ def test_convolution_relu_fusion(self, stride, pad, kernel, size,
device_option=dc[1]
)

# Auto fusion
old_net = caffe2_pb2.NetDef()
conv_old = caffe2_pb2.OperatorDef()
conv_old.CopyFrom(conv)
conv_old.device_option.CopyFrom(dc[1])
relu_old = caffe2_pb2.OperatorDef()
relu_old.CopyFrom(relu)
relu_old.device_option.CopyFrom(dc[1])
old_net.op.extend([conv_old, relu_old])
net = core.Net("net")
net.Proto().CopyFrom(old_net)
optimizeForIDEEP(net)
self.assertTrue(len(net.Proto().op) == 1)
self.assertTrue(net.Proto().op[0].type == "ConvFusion")

X = np.random.rand(
batch_size, input_channels * group, size, size).astype(np.float32) - 0.5
w = np.random.rand(
Expand Down Expand Up @@ -103,10 +88,24 @@ def test_convolution_relu_fusion(self, stride, pad, kernel, size,
print(np.max(np.abs(Y1 - Y0)))
self.assertTrue(False)

# Auto fusion for Conv + ReLU
workspace.ResetWorkspace()
old_net = caffe2_pb2.NetDef()
conv_old = caffe2_pb2.OperatorDef()
conv_old.CopyFrom(conv)
conv_old.device_option.CopyFrom(dc[1])
relu_old = caffe2_pb2.OperatorDef()
relu_old.CopyFrom(relu)
relu_old.device_option.CopyFrom(dc[1])
old_net.op.extend([conv_old, relu_old])
workspace.FeedBlob('X0', X, dc[1])
workspace.FeedBlob('w0', w, dc[1])
workspace.FeedBlob('b0', b, dc[1])
net = core.Net("net")
net.Proto().CopyFrom(old_net)
optimizeForIDEEP(net)
self.assertTrue(len(net.Proto().op) == 1)
self.assertTrue(net.Proto().op[0].type == "ConvFusion")
workspace.RunOperatorOnce(net.Proto().op[0])
Y2 = workspace.FetchBlob('Y0')
if not np.allclose(Y0, Y2, atol=0.01, rtol=0.01):
Expand All @@ -130,6 +129,12 @@ def test_convolution_relu_fusion(self, stride, pad, kernel, size,
def test_convolution_sum_fusion(self, stride, pad, kernel, size,
input_channels, output_channels,
batch_size, use_bias, group, gc, dc):
relu_S0 = core.CreateOperator(
"Relu",
["S0"],
["S0"],
device_option=dc[0]
)
conv = core.CreateOperator(
"Conv",
["X0", "w0", "b0"] if use_bias else ["X0", "w0"],
Expand All @@ -146,6 +151,14 @@ def test_convolution_sum_fusion(self, stride, pad, kernel, size,
["S0"],
device_option=dc[0]
)

# Manual fusion for Conv + Sum
relu_S1 = core.CreateOperator(
"Relu",
["S1"],
["S1"],
device_option=dc[1]
)
conv_fusion = core.CreateOperator(
"ConvFusion",
["X1", "w1", "b1", "S1"] if use_bias else ["X1", "w1", "S1"],
Expand Down Expand Up @@ -173,6 +186,7 @@ def test_convolution_sum_fusion(self, stride, pad, kernel, size,
Y0 = workspace.FetchBlob('Y0')
S = np.random.rand(*Y0.shape).astype(np.float32) - 0.5
workspace.FeedBlob('S0', S, dc[0])
workspace.RunOperatorOnce(relu_S0)
workspace.RunOperatorOnce(sum)
S0 = workspace.FetchBlob('S0')

Expand All @@ -181,6 +195,7 @@ def test_convolution_sum_fusion(self, stride, pad, kernel, size,
workspace.FeedBlob('w1', w, dc[1])
workspace.FeedBlob('b1', b, dc[1])
workspace.FeedBlob('S1', S, dc[1])
workspace.RunOperatorOnce(relu_S1)
workspace.RunOperatorOnce(conv_fusion)
S1 = workspace.FetchBlob('S1')

Expand All @@ -189,6 +204,37 @@ def test_convolution_sum_fusion(self, stride, pad, kernel, size,
print(S0.flatten())
print(np.max(np.abs(S1 - S0)))
self.assertTrue(False)

# Auto fusion for Conv + Sum
workspace.ResetWorkspace()
old_net = caffe2_pb2.NetDef()
relu_S0_old = caffe2_pb2.OperatorDef()
relu_S0_old.CopyFrom(relu_S0)
relu_S0_old.device_option.CopyFrom(dc[1])
conv_old = caffe2_pb2.OperatorDef()
conv_old.CopyFrom(conv)
conv_old.device_option.CopyFrom(dc[1])
sum_old = caffe2_pb2.OperatorDef()
sum_old.CopyFrom(sum)
sum_old.device_option.CopyFrom(dc[1])
old_net.op.extend([relu_S0_old, conv_old, sum_old])
workspace.FeedBlob('X0', X, dc[1])
workspace.FeedBlob('w0', w, dc[1])
workspace.FeedBlob('b0', b, dc[1])
workspace.FeedBlob('S0', S, dc[1])
net = core.Net("net")
net.Proto().CopyFrom(old_net)
optimizeForIDEEP(net)
self.assertTrue(len(net.Proto().op) == 2)
self.assertTrue(net.Proto().op[1].type == "ConvFusion")
workspace.RunNetOnce(net.Proto())
S2 = workspace.FetchBlob('S0')
if not np.allclose(S0, S2, atol=0.01, rtol=0.01):
print(S2.flatten())
print(S0.flatten())
print(np.max(np.abs(S2 - S0)))
self.assertTrue(False)

workspace.SwitchWorkspace(old_ws_name)

@given(stride=st.integers(1, 3),
Expand All @@ -204,6 +250,12 @@ def test_convolution_sum_fusion(self, stride, pad, kernel, size,
def test_convolution_sum_relu_fusion(self, stride, pad, kernel, size,
input_channels, output_channels,
batch_size, use_bias, group, gc, dc):
relu_S0 = core.CreateOperator(
"Relu",
["S0"],
["S0"],
device_option=dc[0]
)
conv = core.CreateOperator(
"Conv",
["X0", "w0", "b0"] if use_bias else ["X0", "w0"],
Expand All @@ -226,6 +278,14 @@ def test_convolution_sum_relu_fusion(self, stride, pad, kernel, size,
["S0"],
device_option=dc[0]
)

# Manual fusion for Conv + Sum + ReLU
relu_S1 = core.CreateOperator(
"Relu",
["S1"],
["S1"],
device_option=dc[1]
)
conv_fusion = core.CreateOperator(
"ConvFusion",
["X1", "w1", "b1", "S1"] if use_bias else ["X1", "w1", "S1"],
Expand Down Expand Up @@ -253,6 +313,7 @@ def test_convolution_sum_relu_fusion(self, stride, pad, kernel, size,
Y0 = workspace.FetchBlob('Y0')
S = np.random.rand(*Y0.shape).astype(np.float32) - 0.5
workspace.FeedBlob('S0', S, dc[0])
workspace.RunOperatorOnce(relu_S0)
workspace.RunOperatorOnce(sum)
workspace.RunOperatorOnce(relu)
S0 = workspace.FetchBlob('S0')
Expand All @@ -262,6 +323,7 @@ def test_convolution_sum_relu_fusion(self, stride, pad, kernel, size,
workspace.FeedBlob('w1', w, dc[1])
workspace.FeedBlob('b1', b, dc[1])
workspace.FeedBlob('S1', S, dc[1])
workspace.RunOperatorOnce(relu_S1)
workspace.RunOperatorOnce(conv_fusion)
S1 = workspace.FetchBlob('S1')

Expand All @@ -270,6 +332,40 @@ def test_convolution_sum_relu_fusion(self, stride, pad, kernel, size,
print(S0.flatten())
print(np.max(np.abs(S1 - S0)))
self.assertTrue(False)

# Auto fusion for Conv + Sum + ReLU
workspace.ResetWorkspace()
old_net = caffe2_pb2.NetDef()
relu_S0_old = caffe2_pb2.OperatorDef()
relu_S0_old.CopyFrom(relu_S0)
relu_S0_old.device_option.CopyFrom(dc[1])
conv_old = caffe2_pb2.OperatorDef()
conv_old.CopyFrom(conv)
conv_old.device_option.CopyFrom(dc[1])
sum_old = caffe2_pb2.OperatorDef()
sum_old.CopyFrom(sum)
sum_old.device_option.CopyFrom(dc[1])
relu_old = caffe2_pb2.OperatorDef()
relu_old.CopyFrom(relu)
relu_old.device_option.CopyFrom(dc[1])
old_net.op.extend([relu_S0_old, conv_old, sum_old, relu_old])
workspace.FeedBlob('X0', X, dc[1])
workspace.FeedBlob('w0', w, dc[1])
workspace.FeedBlob('b0', b, dc[1])
workspace.FeedBlob('S0', S, dc[1])
net = core.Net("net")
net.Proto().CopyFrom(old_net)
optimizeForIDEEP(net)
self.assertTrue(len(net.Proto().op) == 2)
self.assertTrue(net.Proto().op[1].type == "ConvFusion")
workspace.RunNetOnce(net.Proto())
S2 = workspace.FetchBlob('S0')
if not np.allclose(S0, S2, atol=0.01, rtol=0.01):
print(S2.flatten())
print(S0.flatten())
print(np.max(np.abs(S2 - S0)))
self.assertTrue(False)

workspace.SwitchWorkspace(old_ws_name)

if __name__ == "__main__":
Expand Down