From 4982e94dcdfec0088d1c735b1bdef9fcb0daf2db Mon Sep 17 00:00:00 2001 From: affanv14 Date: Fri, 25 Aug 2017 20:21:39 +0530 Subject: [PATCH] add checks for 3d convolutions alternative optimizers --- theano/gpuarray/dnn.py | 7 ++++--- theano/gpuarray/opt.py | 18 +++++++++++++----- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/theano/gpuarray/dnn.py b/theano/gpuarray/dnn.py index b7045e50a49..4b45b92e4c3 100644 --- a/theano/gpuarray/dnn.py +++ b/theano/gpuarray/dnn.py @@ -3266,6 +3266,7 @@ def local_abstractconv3d_cudnn_alt(node): border_mode = node.op.border_mode subsample = node.op.subsample filter_dilation = node.op.filter_dilation + num_groups = node.op.num_groups precision = get_precision(None, [inp1, inp2]) if node.op.filter_flip: @@ -3274,7 +3275,7 @@ def local_abstractconv3d_cudnn_alt(node): conv_mode = 'cross' if isinstance(op, AbstractConv3d): - if border_mode == 'half' or subsample != (1, 1, 1): + if border_mode == 'half' or subsample != (1, 1, 1) or num_groups > 1: return None if border_mode == 'full': direction_hint = 'bprop inputs' @@ -3292,7 +3293,7 @@ def local_abstractconv3d_cudnn_alt(node): elif isinstance(op, AbstractConv3d_gradWeights): if(border_mode == 'valid' and subsample == (1, 1, 1) and - filter_dilation == (1, 1, 1)): + filter_dilation == (1, 1, 1) and num_groups == 1): img = gpu_contiguous(inp1) topgrad = gpu_contiguous(inp2) ctx_name = infer_context_name(img, topgrad) @@ -3323,7 +3324,7 @@ def local_abstractconv3d_cudnn_alt(node): return None elif isinstance(op, AbstractConv3d_gradInputs): - if border_mode == 'valid' and subsample == (1, 1, 1): + if border_mode == 'valid' and subsample == (1, 1, 1) and num_groups == 1: kerns = gpu_contiguous(inp1.dimshuffle(1, 0, 2, 3, 4)) topgrad = gpu_contiguous(inp2) ctx_name = infer_context_name(kerns, topgrad) diff --git a/theano/gpuarray/opt.py b/theano/gpuarray/opt.py index 744c6926604..822c21bfa86 100644 --- a/theano/gpuarray/opt.py +++ b/theano/gpuarray/opt.py @@ -1842,8 +1842,10 @@ def local_abstractconv3d_alt(node): border_mode = node.op.border_mode subsample = node.op.subsample filter_dilation = node.op.filter_dilation + num_groups = node.op.num_groups - if ((border_mode == 'full') and (subsample == (1, 1, 1))): + if((border_mode == 'full') and (subsample == (1, 1, 1)) and + (num_groups == 1)): if not node.op.filter_flip: kern = kern[:, :, ::-1, ::-1, ::-1] kern = kern.dimshuffle(1, 0, 2, 3, 4) @@ -1853,7 +1855,7 @@ def local_abstractconv3d_alt(node): gpu_contiguous(kern), gpu_contiguous(img)) elif(subsample == (1, 1, 1) and filter_dilation == (1, 1, 1) and - border_mode == 'valid'): + border_mode == 'valid' and num_groups == 1): if node.op.filter_flip: kern = kern[:, :, ::-1, ::-1, ::-1] rval = GpuCorr3dMM_gradWeights(border_mode, @@ -1881,8 +1883,10 @@ def local_abstractconv3d2d(node): border_mode = node.op.border_mode subsample = node.op.subsample filter_dilation = node.op.filter_dilation + num_groups = node.op.num_groups - if subsample == (1, 1, 1) and filter_dilation == (1, 1, 1): + if(subsample == (1, 1, 1) and filter_dilation == (1, 1, 1) and + num_groups == 1): reorder_array = [0, 2, 1, 3, 4] rval = conv3d2d.conv3d(gpu_contiguous(img.dimshuffle(*reorder_array)), gpu_contiguous(kern.dimshuffle(*reorder_array)), @@ -1968,8 +1972,10 @@ def local_abstractconv3d_gemm_gradweights_alt(node): border_mode = node.op.border_mode subsample = node.op.subsample filter_dilation = node.op.filter_dilation + num_groups = node.op.num_groups - if border_mode == 'valid' and subsample == (1, 1, 1) and filter_dilation == (1, 1, 1): + if(border_mode == 'valid' and subsample == (1, 1, 1) and + filter_dilation == (1, 1, 1) and num_groups == 1): rval = GpuCorr3dMM(border_mode, subsample, filter_dilation)( @@ -2091,8 +2097,10 @@ def local_abstractconv3d_gradinputs_gemm_alt(node): border_mode = node.op.border_mode subsample = node.op.subsample filter_dilation = node.op.filter_dilation + num_groups = node.op.num_groups - if border_mode == 'valid' and subsample == (1, 1, 1): + if(border_mode == 'valid' and subsample == (1, 1, 1) and + num_groups == 1): if not node.op.filter_flip: kern = kern[:, :, ::-1, ::-1, ::-1] rval = GpuCorr3dMM(border_mode='full',