Skip to content

Commit

Permalink
Replace ConvolutionalGradient with FiniteDifferenct(..., circular=True)
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael-T-McCann committed Sep 21, 2021
1 parent eb995bb commit f6ccbb0
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 71 deletions.
2 changes: 1 addition & 1 deletion examples/scripts/admm_tv_circ_deconvolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
f = loss.SquaredL2Loss(y=y, A=A)
# Penalty parameters must be accounted for in the gi functions, not as additional inputs
g = λ * functional.L21Norm() # Regularization functionals gi
C = linop.ConvolutionalGradient(x_gt.shape)
C = linop.FiniteDifference(x_gt.shape, circular=True)
solver = ADMM(
f=f,
g_list=[g],
Expand Down
3 changes: 1 addition & 2 deletions scico/linop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ._matrix import MatrixOperator
from ._diff import FiniteDifference
from ._convolve import Convolve, ConvolveByX
from ._circconv import CircularConvolve, ConvolutionalGradient
from ._circconv import CircularConvolve
from ._dft import DFT
from ._stack import LinearOperatorStack

Expand All @@ -28,7 +28,6 @@
"FiniteDifference",
"Convolve",
"CircularConvolve",
"ConvolutionalGradient",
"DFT",
"LinearOperatorStack",
"Sum",
Expand Down
40 changes: 0 additions & 40 deletions scico/linop/_circconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,43 +286,3 @@ def _gradient_filters(ndim: int, axes: Shape, shape: Shape, dtype: DType = snp.f
Gf = snp.fft.fftn(g, shape, axes=fft_axes)

return Gf


class ConvolutionalGradient(CircularConvolve):
"""A gradient linear operator computed in the DFT domain."""

def __init__(
self,
input_shape: Shape,
grad_axes: Optional[Axes] = None,
input_dtype: DType = snp.float32,
jit: bool = True,
**kwargs,
):
"""
Args:
input_shape: Shape of input array
grad_axes: Indices of the dimensions in the input array on which gradient is to be computed. By default, gradients are evaluated along all dimensions.
input_dtype: `dtype` for input argument. Defaults to `float32`.
jit: If `True`, jit the evaluation, adjoint, and gram functions of the LinearOperator
"""

# Handle grad_axes default
self.grad_axes = parse_axes(grad_axes, input_shape)

# Shape of gradient axes
shape = tuple([input_shape[a] for a in self.grad_axes])
# Array of DFT-domain gradient filters
Gf = _gradient_filters(
ndim=len(input_shape), axes=self.grad_axes, shape=shape, dtype=input_dtype
)

super().__init__(
h=Gf,
input_shape=input_shape,
input_dtype=input_dtype,
ndims=len(input_shape),
h_is_dft=True,
jit=jit,
**kwargs,
)
27 changes: 1 addition & 26 deletions scico/test/linop/test_circconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest

import scico.numpy as snp
from scico.linop import CircularConvolve, ConvolutionalGradient, Convolve
from scico.linop import CircularConvolve, Convolve
from scico.random import randint, randn, uniform
from scico.test.linop.test_linop import adjoint_AAt_test, adjoint_AtA_test

Expand Down Expand Up @@ -138,28 +138,3 @@ def test_from_operator(self, axes_shape_spec, input_dtype, jit_old_op, jit_new_o
B = CircularConvolve.from_operator(A, ndims, jit=jit_new_op)

np.testing.assert_allclose(A @ x, B @ x, atol=1e-5)


class TestConvolutionalGradient:
def setup_method(self, method):
self.key = jax.random.PRNGKey(12345)

@pytest.mark.parametrize("input_dtype", [np.float32, np.complex64])
@pytest.mark.parametrize(
"axes_shape_spec",
[
((12, 12), None), # 2d
((12, 12), 0), # 2d specific axis
((15, 12, 12), None), # 3d
((15, 12, 12), [0, 2]), # 3d specific axes
],
)
def test_eval(self, axes_shape_spec, input_dtype):

input_shape, axes = axes_shape_spec
x, key = randn(tuple(input_shape), dtype=input_dtype, key=self.key)
A = ConvolutionalGradient(input_shape, axes, input_dtype)
Ax = A @ x

for ax in A.grad_axes:
np.testing.assert_allclose(x - np.roll(x, 1, ax), Ax[ax], atol=1e-5, rtol=0)
4 changes: 2 additions & 2 deletions scico/test/test_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def setup_method(self, method):
λ = 1e-2
self.f = loss.SquaredL2Loss(y=self.y, A=self.A)
self.g_list = [λ * functional.L1Norm()]
self.C_list = [linop.ConvolutionalGradient(input_shape=x.shape)]
self.C_list = [linop.FiniteDifference(input_shape=x.shape, circular=True)]

def test_admm(self):
maxiter = 50
Expand Down Expand Up @@ -201,5 +201,5 @@ def test_admm(self):
subproblem_solver=CircularConvolveSolver(),
)
x_dft = admm_dft.solve()
np.testing.assert_allclose(x_dft, x_lin, atol=1e-5, rtol=0)
np.testing.assert_allclose(x_dft, x_lin, atol=1e-4, rtol=0)
assert metric.mse(x_lin, x_dft) < 1e-9

0 comments on commit f6ccbb0

Please sign in to comment.