Skip to content

Commit

Permalink
add checks for 3d convolutions alternative optimizers
Browse files Browse the repository at this point in the history
  • Loading branch information
affanv14 committed Aug 25, 2017
1 parent 9592125 commit 4982e94
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
7 changes: 4 additions & 3 deletions theano/gpuarray/dnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3266,6 +3266,7 @@ def local_abstractconv3d_cudnn_alt(node):
border_mode = node.op.border_mode
subsample = node.op.subsample
filter_dilation = node.op.filter_dilation
num_groups = node.op.num_groups
precision = get_precision(None, [inp1, inp2])

if node.op.filter_flip:
Expand All @@ -3274,7 +3275,7 @@ def local_abstractconv3d_cudnn_alt(node):
conv_mode = 'cross'

if isinstance(op, AbstractConv3d):
if border_mode == 'half' or subsample != (1, 1, 1):
if border_mode == 'half' or subsample != (1, 1, 1) or num_groups > 1:
return None
if border_mode == 'full':
direction_hint = 'bprop inputs'
Expand All @@ -3292,7 +3293,7 @@ def local_abstractconv3d_cudnn_alt(node):

elif isinstance(op, AbstractConv3d_gradWeights):
if(border_mode == 'valid' and subsample == (1, 1, 1) and
filter_dilation == (1, 1, 1)):
filter_dilation == (1, 1, 1) and num_groups == 1):
img = gpu_contiguous(inp1)
topgrad = gpu_contiguous(inp2)
ctx_name = infer_context_name(img, topgrad)
Expand Down Expand Up @@ -3323,7 +3324,7 @@ def local_abstractconv3d_cudnn_alt(node):
return None

elif isinstance(op, AbstractConv3d_gradInputs):
if border_mode == 'valid' and subsample == (1, 1, 1):
if border_mode == 'valid' and subsample == (1, 1, 1) and num_groups == 1:
kerns = gpu_contiguous(inp1.dimshuffle(1, 0, 2, 3, 4))
topgrad = gpu_contiguous(inp2)
ctx_name = infer_context_name(kerns, topgrad)
Expand Down
18 changes: 13 additions & 5 deletions theano/gpuarray/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1842,8 +1842,10 @@ def local_abstractconv3d_alt(node):
border_mode = node.op.border_mode
subsample = node.op.subsample
filter_dilation = node.op.filter_dilation
num_groups = node.op.num_groups

if ((border_mode == 'full') and (subsample == (1, 1, 1))):
if((border_mode == 'full') and (subsample == (1, 1, 1)) and
(num_groups == 1)):
if not node.op.filter_flip:
kern = kern[:, :, ::-1, ::-1, ::-1]
kern = kern.dimshuffle(1, 0, 2, 3, 4)
Expand All @@ -1853,7 +1855,7 @@ def local_abstractconv3d_alt(node):
gpu_contiguous(kern), gpu_contiguous(img))

elif(subsample == (1, 1, 1) and filter_dilation == (1, 1, 1) and
border_mode == 'valid'):
border_mode == 'valid' and num_groups == 1):
if node.op.filter_flip:
kern = kern[:, :, ::-1, ::-1, ::-1]
rval = GpuCorr3dMM_gradWeights(border_mode,
Expand Down Expand Up @@ -1881,8 +1883,10 @@ def local_abstractconv3d2d(node):
border_mode = node.op.border_mode
subsample = node.op.subsample
filter_dilation = node.op.filter_dilation
num_groups = node.op.num_groups

if subsample == (1, 1, 1) and filter_dilation == (1, 1, 1):
if(subsample == (1, 1, 1) and filter_dilation == (1, 1, 1) and
num_groups == 1):
reorder_array = [0, 2, 1, 3, 4]
rval = conv3d2d.conv3d(gpu_contiguous(img.dimshuffle(*reorder_array)),
gpu_contiguous(kern.dimshuffle(*reorder_array)),
Expand Down Expand Up @@ -1968,8 +1972,10 @@ def local_abstractconv3d_gemm_gradweights_alt(node):
border_mode = node.op.border_mode
subsample = node.op.subsample
filter_dilation = node.op.filter_dilation
num_groups = node.op.num_groups

if border_mode == 'valid' and subsample == (1, 1, 1) and filter_dilation == (1, 1, 1):
if(border_mode == 'valid' and subsample == (1, 1, 1) and
filter_dilation == (1, 1, 1) and num_groups == 1):
rval = GpuCorr3dMM(border_mode,
subsample,
filter_dilation)(
Expand Down Expand Up @@ -2091,8 +2097,10 @@ def local_abstractconv3d_gradinputs_gemm_alt(node):
border_mode = node.op.border_mode
subsample = node.op.subsample
filter_dilation = node.op.filter_dilation
num_groups = node.op.num_groups

if border_mode == 'valid' and subsample == (1, 1, 1):
if(border_mode == 'valid' and subsample == (1, 1, 1) and
num_groups == 1):
if not node.op.filter_flip:
kern = kern[:, :, ::-1, ::-1, ::-1]
rval = GpuCorr3dMM(border_mode='full',
Expand Down

0 comments on commit 4982e94

Please sign in to comment.