Skip to content

Commit

Permalink
Do not perform sum of channels in InputFile
Browse files Browse the repository at this point in the history
Make channel arguments uniform
  • Loading branch information
gmazzamuto committed Jun 9, 2022
1 parent d5f00f8 commit 348c7ba
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 19 deletions.
24 changes: 11 additions & 13 deletions zetastitcher/align/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def parse_args():

parser.add_argument('input_folder', help='input folder')
parser.add_argument('-o', type=str, default='stitch.yml', dest='output_file', help='output file')
parser.add_argument('-c', type=str, default='s', dest='channel', choices=['r', 'g', 'b', 's'], help='color channel')
parser.add_argument('-c', '--ch', type=int, dest='channel', help='color channel')
parser.add_argument('-j', type=int, dest='n_of_workers',
help='number of parallel jobs (defaults to number of system cores)')
parser.add_argument('-r', action='store_true', dest='recursive', help='recursively look for files')
Expand Down Expand Up @@ -120,15 +120,6 @@ def parse_args():
setattr(args, 'overlap_h', args.overlap)
setattr(args, 'overlap_v', args.overlap)

channels = {
's': -2, # sum
'r': 0,
'g': 1,
'b': 2
}

args.channel = channels[args.channel]

args.max_dx = int(round(args.max_dx / args.px_size_xy))
args.max_dy = int(round(args.max_dy / args.px_size_xy))
args.max_dz = int(round(args.max_dz / args.px_size_z))
Expand Down Expand Up @@ -157,18 +148,25 @@ def worker(item, overlap_dict, channel, max_dz, max_dy, max_dx):
a = InputFile(aname)
b = InputFile(bname)

a.channel = channel
b.channel = channel

z_min = z_frame - max_dz
z_max = z_frame + max_dz + 1

aslice = a.zslice(z_min, z_max, copy=True)
if a.nchannels > 1:
if channel is not None:
aslice = aslice[:, channel]
else:
aslice = np.sum(aslice.astype(np.float32), axis=1)
if axis == 2:
aslice = np.rot90(aslice, axes=(-1, -2))
aslice = aslice[..., -(overlap + max_dy):, :]

bframe = b.zslice_idx(z_frame, copy=True)
if b.nchannels > 1:
if channel is not None:
bframe = bframe[:, channel]
else:
bframe = np.sum(bframe.astype(np.float32), axis=1)
if axis == 2:
bframe = np.rot90(bframe, axes=(-1, -2))
bframe = bframe[..., :overlap - max_dy, :]
Expand Down
2 changes: 1 addition & 1 deletion zetastitcher/fuse/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def parse_args():
group.add_argument('-d', dest='debug', action='store_true',
help='overlay debug info')

group.add_argument('--ch', type=int, dest='channel', help='channel')
group.add_argument('-c', '--ch', type=int, dest='channel', help='channel')

group.add_argument('--downsample-xy', metavar='S', type=int, required=False,
help='downsample XY plane by factor S')
Expand Down
8 changes: 3 additions & 5 deletions zetastitcher/io/inputfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self, file_path=None):
super().__init__()
self.file_path = file_path
self.wrapper = None
self._channel = -1
self._channel = None
self.squeeze = True

self.nfrms = None
Expand Down Expand Up @@ -103,7 +103,7 @@ def shape(self):
the `channels` dimension is squeezed.
"""
s = [self.nfrms, self.nchannels, self.ysize, self.xsize]
if self.nchannels == 1 or self.channel != -1:
if self.nchannels == 1 or self.channel is not None:
del s[1]
return tuple(s)

Expand Down Expand Up @@ -223,9 +223,7 @@ def zslice(self, arg1, arg2=None, step=None, dtype=None, copy=True):
a[z] = self.wrapper.frame(i)
z += 1

if self.channel == -2:
a = np.sum(a, axis=-1)
elif self.channel != -1:
if self.channel is not None:
a = a[..., self.channel]
elif self.nchannels > 1 and a.ndim >= 3:
a = np.moveaxis(a, -1, -3)
Expand Down

0 comments on commit 348c7ba

Please sign in to comment.