Skip to content

Commit

Permalink
Add barrel/pincushion correction support (labsyspharm#183)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmuhlich authored Mar 29, 2023
1 parent 48119b7 commit 7aded93
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 15 deletions.
19 changes: 19 additions & 0 deletions ashlar/reg.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import zarr
from . import utils
from . import thumbnail
from . import transform
from . import __version__ as _version


Expand Down Expand Up @@ -408,6 +409,24 @@ def read(self, series, c):
return img


class BarrelCorrectionReader(Reader):
"""Wraps a reader to correct barrel/pincushion image distortion."""

def __init__(self, reader, k):
self.reader = reader
self.k = k

@property
def metadata(self):
return self.reader.metadata

def read(self, series, c):
img = self.reader.read(series, c)
img = transform.barrel_correction(img, self.k)
img = utils.dtype_convert(img, self.metadata.pixel_dtype)
return img


class CachingReader(Reader):
"""Wraps a reader to provide tile image caching."""

Expand Down
28 changes: 17 additions & 11 deletions ashlar/scripts/ashlar.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ def main(argv=sys.argv):
" file for every cycle. Channel counts must match input files."
" (default: no dark field correction)"),
)
parser.add_argument(
"--barrel-correction", type=float, default=0, help=argparse.SUPPRESS
)
parser.add_argument(
'--plates', default=False, action='store_true',
help='Enable plate mode for HTS data',
Expand Down Expand Up @@ -215,15 +218,15 @@ def main(argv=sys.argv):
if args.plates:
return process_plates(
filepaths, output_path, args.filename_format, args.flip_x,
args.flip_y, ffp_paths, dfp_paths, aligner_args, mosaic_args,
args.pyramid, args.quiet
args.flip_y, ffp_paths, dfp_paths, args.barrel_correction,
aligner_args, mosaic_args, args.pyramid, args.quiet
)
else:
mosaic_path_format = str(output_path / args.filename_format)
return process_single(
filepaths, mosaic_path_format, args.flip_x, args.flip_y,
ffp_paths, dfp_paths, aligner_args, mosaic_args, args.pyramid,
args.quiet
ffp_paths, dfp_paths, args.barrel_correction, aligner_args,
mosaic_args, args.pyramid, args.quiet
)
except ProcessingError as e:
print_error(str(e))
Expand All @@ -232,7 +235,8 @@ def main(argv=sys.argv):

def process_single(
filepaths, output_path_format, flip_x, flip_y, ffp_paths, dfp_paths,
aligner_args, mosaic_args, pyramid, quiet, plate_well=None
barrel_correction, aligner_args, mosaic_args, pyramid, quiet,
plate_well=None
):

mosaic_args = mosaic_args.copy()
Expand All @@ -245,7 +249,7 @@ def process_single(
print("Stitching and registering input images")
print('Cycle 0:')
print(' reading %s' % filepaths[0])
reader = build_reader(filepaths[0], plate_well=plate_well)
reader = build_reader(filepaths[0], barrel_correction, plate_well=plate_well)
process_axis_flip(reader, flip_x, flip_y)
ea_args = aligner_args.copy()
for arg in ("alpha", "max_error"):
Expand All @@ -266,7 +270,7 @@ def process_single(
if not quiet:
print('Cycle %d:' % cycle)
print(' reading %s' % filepath)
reader = build_reader(filepath, plate_well=plate_well)
reader = build_reader(filepath, barrel_correction, plate_well=plate_well)
process_axis_flip(reader, flip_x, flip_y)
layer_aligner = reg.LayerAligner(reader, edge_aligner, **aligner_args)
layer_aligner.run()
Expand Down Expand Up @@ -294,7 +298,7 @@ def process_single(

def process_plates(
filepaths, output_path, filename_format, flip_x, flip_y, ffp_paths,
dfp_paths, aligner_args, mosaic_args, pyramid, quiet
dfp_paths, barrel_correction, aligner_args, mosaic_args, pyramid, quiet
):

temp_reader = build_reader(filepaths[0])
Expand All @@ -318,8 +322,8 @@ def process_plates(
mosaic_path_format = str(out_file_path)
process_single(
filepaths, mosaic_path_format, flip_x, flip_y,
ffp_paths, dfp_paths, aligner_args, mosaic_args, pyramid,
quiet, plate_well=(p, w)
ffp_paths, dfp_paths, barrel_correction, aligner_args,
mosaic_args, pyramid, quiet, plate_well=(p, w)
)
else:
print("Skipping -- No images found.")
Expand Down Expand Up @@ -347,7 +351,7 @@ def process_axis_flip(reader, flip_x, flip_y):

# This is a short-term hack to provide a way to specify alternate reader
# classes and pass specific args to them.
def build_reader(path, plate_well=None):
def build_reader(path, barrel_correction=0, plate_well=None):
# Default to BioformatsReader if name not specified.
reader_class = BioformatsReader
kwargs = {}
Expand All @@ -369,6 +373,8 @@ def build_reader(path, plate_well=None):
)
kwargs.update(plate=plate_well[0], well=plate_well[1])
reader = reader_class(path, **kwargs)
if barrel_correction != 0:
reader = reg.BarrelCorrectionReader(reader, barrel_correction)
return reader


Expand Down
95 changes: 95 additions & 0 deletions ashlar/transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import numpy as np
from skimage.transform import warp


def _barrel_mapping(xy, center, k):
x, y = xy.T.astype(float)
y0, x0 = center
x -= x0
y -= y0
# The warp mapping function defines the INVERSE map to apply, which in the
# case of a distortion correction is the FORWARD map of the distortion
# itself. See Fitzgibbon 2001 eq 1 for details.
r2 = x ** 2 + y ** 2
f = 1 + k * r2
xy[..., 0] = x * f + x0
xy[..., 1] = y * f + y0
return xy


def barrel_correction(
image,
k,
center=None,
output_shape=None,
order=1,
mode=None,
cval=0,
clip=True,
preserve_range=False,
):
"""Perform a transform to correct for barrel or pincushion distortion.
Parameters
----------
image : ndarray
Input image.
k : float
Distortion parameter
center : (row, column) tuple or (2,) ndarray, optional
Center coordinate of transformation.
Returns
-------
cartesian : ndarray
Cartesian version of the input.
Rows correspond to radius and columns to angle values.
Other parameters
----------------
output_shape : tuple (rows, cols), optional
Shape of the output image generated. By default the shape of the input
image is preserved.
order : int, optional
The order of the spline interpolation, default is 1. The order has to
be in the range 0-5. See `skimage.transform.warp` for detail.
mode : {'constant', 'edge', 'symmetric', 'reflect', 'wrap'}, optional
Points outside the boundaries of the input are filled according
to the given mode, with 'constant' used as the default. Modes match
the behaviour of `numpy.pad`.
cval : float, optional
Used in conjunction with mode 'constant', the value outside
the image boundaries.
clip : bool, optional
Whether to clip the output to the range of values of the input image.
This is enabled by default, since higher order interpolation may
produce values outside the given input range.
preserve_range : bool, optional
Whether to keep the original range of values. Otherwise, the input
image is converted according to the conventions of `img_as_float`.
Notes
-----
Radial distortion is modeled here using a one-parameter model described in
[1]_ (Equation 1).
References
----------
.. [1] A. W. Fitzgibbon, "Simultaneous linear estimation of multiple view
geometry and lens distortion," Proceedings of the 2001 IEEE Computer
Society Conference on Computer Vision and Pattern Recognition. CVPR 2001,
Kauai, HI, USA, 2001, pp. I-I. :DOI:`10.1109/CVPR.2001.990465`.
"""

if mode is None:
mode = "constant"

if center is None:
center = np.array(image.shape)[:2] / 2

warp_args = {"center": center, "k": k}

return warp(image, _barrel_mapping, map_args=warp_args,
output_shape=output_shape, order=order, mode=mode, cval=cval,
clip=clip, preserve_range=preserve_range)
18 changes: 14 additions & 4 deletions ashlar/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,7 @@ def paste(target, img, pos, func=None):
return
if np.issubdtype(img.dtype, np.floating):
np.clip(img, 0, 1, img)
# It's safe to silence this FutureWarning as we pinned the skimage version.
with warnings.catch_warnings():
warnings.filterwarnings("ignore", r".*scikit-image 1\.0", FutureWarning)
img = skimage.util.dtype.convert(img, target.dtype)
img = dtype_convert(img, target.dtype)
if func is None:
target_slice[:] = img
elif isinstance(func, np.ufunc):
Expand Down Expand Up @@ -192,6 +189,19 @@ def crop_like(img, target):
return img


def dtype_convert(img, dtype):
"""Convert an image to the requested data-type.
This is just a wrapper around skimage.util.dtype.convert that silences its
FutureWarning, as Ashlar pins skimage to a version before that planned
deprecation.
"""
with warnings.catch_warnings():
warnings.filterwarnings("ignore", r".*scikit-image 1\.0", FutureWarning)
return skimage.util.dtype.convert(img, dtype)


def imsave(fname, arr, **kwargs):
"""Save an image to file.
Expand Down

0 comments on commit 7aded93

Please sign in to comment.