From fbcd2ffffc6c65774ba248556cfacbb15cfd1691 Mon Sep 17 00:00:00 2001
From: duxin
Date: Tue, 8 Nov 2022 11:14:52 +0900
Subject: [PATCH] FIRETensor supporting maps

---
 firelang/map/_grid.py   |  72 ++++++++++----------
 firelang/map/conv.py    |  40 ++++++++----
 firelang/map/fourier.py |   6 +-
 firelang/map/rect.py    | 136 +++++++++++++---------------------------
 firelang/map/wavelet.py |  32 ++++++----
 5 files changed, 128 insertions(+), 158 deletions(-)

diff --git a/firelang/map/_grid.py b/firelang/map/_grid.py
index ed0b0e8..1055288 100644
--- a/firelang/map/_grid.py
+++ b/firelang/map/_grid.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
-from typing import List, Iterable
+from typing import List, Tuple, Iterable
+import numpy as np
 from numba import cuda
 import torch
 from torch import Tensor, dtype, device
@@ -16,13 +17,11 @@ def __init__(
         dim_sizes: List[int],
         dtype: dtype = torch.float32,
         device: device = "cuda",
-        stack_size: int = 1,
+        shape: Tuple[int, ...] = (1,),
     ):
         StackingSlicing.__init__(self, locals())
         self._gridvals = Parameter(
-            torch.empty(stack_size, *dim_sizes, dtype=dtype, device=device).normal_(
-                0, 0.1
-            )
+            torch.empty(*shape, *dim_sizes, dtype=dtype, device=device).normal_(0, 0.1)
         )
         self.dim_sizes = torch.tensor(dim_sizes, dtype=torch.long, device=device)
         self.ndim = len(dim_sizes)
@@ -38,14 +37,13 @@ def slice_rectangle(
         self,
         corners: Tensor,
         rect_dim_sizes: int | List[int] | Tensor,
-        cross: bool = False,
     ) -> Tensor:
-        """Slice batches of (hyper-)rectangles from the d-dim grid that are specified
-        by `corners` and `rect_dim_sizes`.
+        """Parallelized slicing of (hyper-)rectangles from the d-dim grid
+        that are specified by `corners` and `rect_dim_sizes`.
 
         Args:
-        - corners (Tensor): (measure_stack, batch_size, ndim). \
-            If cross == False: measure_stack must be equal to self.stack_size
+        - corners (Tensor): (...shape, ndim). The i-th corner out of the `shape`
+            corners is aligned with the i-th grid of self._gridvals.
         - rect_dim_sizes (int | List[int] | Tensor): rect_dim_sizes at each dimension
             of the rectangle. \
-            If is a `int`, it gives the size at all dimensions.
+            If an `int` is given, it is used as the size at all dimensions.
@@ -55,15 +53,13 @@
 
         Returns:
         - Tensor: let each rectangle be represented by (n1, n2, ..., nd), returns \
-            a Tensor with shape:
-            - if cross == False: (measure_size, batch_size, n1, n2, ..., nd),
-            - else: (self.stack_size, measure_size, batch_size, n1, n2, ..., nd).
+            a Tensor with shape (...shape, n1, n2, ..., nd).
         """
-        measure_stack, batch_size, ndim = corners.shape
+        (*shape, ndim) = corners.shape
+        size = int(np.prod(shape))
         device, dtype = corners.device, corners.dtype
         grid_dim_sizes = self.dim_sizes.to(device)
-        stack_size = self.stack_size
 
         if not isinstance(rect_dim_sizes, Iterable):
-            rect_dim_sizes = [rect_dim_sizes] * len(ndim)
+            rect_dim_sizes = [rect_dim_sizes] * ndim
         else:
@@ -71,44 +67,42 @@
             assert len(rect_dim_sizes) == ndim
         if not isinstance(rect_dim_sizes, Tensor):
             rect_dim_sizes = torch.tensor(
-                rect_dim_sizes, dtype=torch.long, device=device
+                rect_dim_sizes,
+                dtype=torch.long,
+                device=device,
             )
 
-        corners = corners.reshape(measure_stack * batch_size, ndim)
+        """ Compute the offsets for all grid points within the rectangle,
+            which will be used in torch.take(...).
+        """
+        corners = corners.reshape(size, ndim)
         offsets = torch.zeros(
-            measure_stack * batch_size,
-            torch.prod(rect_dim_sizes),
+            size,
+            int(np.prod(rect_dim_sizes.data.cpu().numpy())),
             dtype=torch.long,
             device=device,
         )
-
         BLOCKDIM_X = 512
-        n_block = (measure_stack * batch_size + BLOCKDIM_X - 1) // BLOCKDIM_X
+        n_block = (size + BLOCKDIM_X - 1) // BLOCKDIM_X
         rectangle_offsets_in_grid_kernel[n_block, BLOCKDIM_X](
             cuda.as_cuda_array(corners),
             cuda.as_cuda_array(grid_dim_sizes),
             cuda.as_cuda_array(rect_dim_sizes),
-            measure_stack * batch_size,
+            size,
             ndim,
             cuda.as_cuda_array(offsets),
         )
-
-        offsets = offsets.reshape(measure_stack, batch_size, *rect_dim_sizes.tolist())
-        # (measure_stack, batch_size, n1, n2, ..., nd)
-
-        grid_size = torch.prod(grid_dim_sizes)
-        if not cross:
-            stack_offsets = (
-                torch.arange(measure_stack, dtype=torch.long, device=device) * grid_size
-            ).reshape(-1, 1, *[1 for _ in range(ndim)])
-            # (measure_stack, 1 (batch_size), 1, ..., 1), length=2+ndim.
-            offsets = offsets + stack_offsets
-        else:
-            stack_offsets = (
-                torch.arange(stack_size, dtype=torch.long, device=device) * grid_size
-            ).reshape(-1, 1, 1, *[1 for _ in range(ndim)])
-            # (self.stack_size, 1 (measure_stack), 1 (batch_size), 1, ..., 1), length=3+ndim.
-            offsets = offsets[None] + stack_offsets
+        offsets = offsets.reshape(*shape, *rect_dim_sizes.tolist())
+        # (*shape, n1, n2, ..., nd)
+
+        """ Account for the extra offset introduced by `shape`. """
+        grid_size = int(np.prod(grid_dim_sizes.data.cpu().numpy()))
+        stack_offsets = (
+            torch.arange(int(np.prod(self.shape)), dtype=torch.long, device=device)
+            * grid_size
+        ).reshape(*self.shape, *[1] * ndim)
+        # (*self.shape, 1, ..., 1)
+        offsets = offsets + stack_offsets
 
         rectangle_vals = self.gridvals.take(offsets)
         return rectangle_vals
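
Note for reviewers: the CUDA kernel above fills `offsets` with flat indices into
the stacked grid, which `torch.take` then gathers. A minimal pure-PyTorch sketch
of the same offset arithmetic for a single 2-D grid (illustration only, not part
of the patch; `slice_rectangle_reference` is a hypothetical name):

    import torch

    def slice_rectangle_reference(gridvals, corner, rect_dim_sizes):
        # gridvals: (n1, n2); corner and rect_dim_sizes: 2-tuples of ints.
        rows = corner[0] + torch.arange(rect_dim_sizes[0])
        cols = corner[1] + torch.arange(rect_dim_sizes[1])
        # Flat offset of every rectangle cell in the flattened grid.
        offsets = rows[:, None] * gridvals.shape[1] + cols[None, :]
        return gridvals.flatten().take(offsets)

    g = torch.arange(36.0).reshape(6, 6)
    print(slice_rectangle_reference(g, (1, 2), (3, 3)))
    # tensor([[ 8.,  9., 10.],
    #         [14., 15., 16.],
    #         [20., 21., 22.]])

With stacked grids, `stack_offsets` shifts each rectangle into its own grid's
slab of the flattened parameter, which is what the broadcast add implements.
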
diff --git a/firelang/map/conv.py b/firelang/map/conv.py
index f91170e..5673df0 100644
--- a/firelang/map/conv.py
+++ b/firelang/map/conv.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 from typing import List, Tuple, Iterable
 from typing_extensions import Literal
+import numpy as np
 import torch
 from torch import dtype, device, Tensor
 from torch.nn import Module, Parameter, ModuleList
@@ -26,13 +27,19 @@ def __init__(
         conv_layers: int = 1,
         dtype: dtype = torch.float32,
         device: device = "cuda",
-        stack_size: int = 1,
+        shape: Tuple[int, ...] = (1,),
     ):
         Grid.__init__(
-            self, dim_sizes=dim_sizes, dtype=dtype, device=device, stack_size=stack_size
+            self,
+            dim_sizes=dim_sizes,
+            dtype=dtype,
+            device=device,
+            shape=shape,
         )
         self.register_extra_init_kwargs(
-            conv_size=conv_size, conv_chans=conv_chans, conv_layers=conv_layers
+            conv_size=conv_size,
+            conv_chans=conv_chans,
+            conv_layers=conv_layers,
         )
         self.conv_size = conv_size
         self.conv_chans = conv_chans
@@ -43,7 +50,11 @@ def __init__(
             in_chans = 1 if l == 0 else conv_chans
             self.conv.append(
                 torch.nn.Conv2d(
-                    in_chans, conv_chans, conv_size, padding="same", device=device
+                    in_chans,
+                    conv_chans,
+                    conv_size,
+                    padding="same",
+                    device=device,
                 )
             )
             self.unsliceable_params.add(f"conv.{l}.weight")
@@ -52,11 +63,11 @@ def __init__(
     @property
     def gridvals(self) -> Tensor:
         assert self.ndim == 2
-        g: Parameter = self._gridvals  # (self.stack_size, n1, n2)
+        g: Parameter = self._gridvals  # (size, n1, n2)
 
-        g = g[:, None]  # (self.stack_size, 1, n1, n2)
+        g = g[:, None, :, :]  # (size, 1, n1, n2)
         for layer in self.conv:
-            g = layer(g)  # (self.stack_size, channels, n1, n2)
+            g = layer(g)  # (size, channels, n1, n2)
         g = g.sum(dim=1)
         return g
@@ -75,11 +86,12 @@ def __init__(
         conv_layers: int = 1,
         dtype: dtype = torch.float32,
         device: device = "cuda",
-        stack_size: int = 1,
+        shape: Tuple[int, ...] = (1,),
     ):
         Functional.__init__(self, locals())
-        self.stack_size = stack_size
+        self.shape = shape
+        size = int(np.prod(shape))
         self.ndim = len(grid_dim_sizes)
         self._grid = Conv2DGrid(
             grid_dim_sizes,
@@ -88,7 +100,7 @@
             conv_layers=conv_layers,
             dtype=dtype,
             device=device,
-            stack_size=stack_size,
+            shape=shape,
         )
 
     def _sizes_to_tensor(sizes: int | List[int], ndim: int) -> Tensor:
@@ -100,15 +112,17 @@
         self.grid_dim_sizes = _sizes_to_tensor(grid_dim_sizes, self.ndim)
         self.rect_dim_sizes = _sizes_to_tensor(rect_dim_sizes, self.ndim)
         self.limits = torch.tensor(
-            parse_rect_limits(limits, self.ndim), dtype=dtype, device=device
+            parse_rect_limits(limits, self.ndim),
+            dtype=dtype,
+            device=device,
         )
         self.rect_weight_decay = rect_weight_decay
         self.bandwidth_mode = bandwidth_mode
         self.bandwidth_lb = bandwidth_lb
         if bandwidth_mode == "constant":
-            self._bandwidth = bandwidth_lb * torch.ones(stack_size, dtype=dtype)
+            self._bandwidth = bandwidth_lb * torch.ones(size, dtype=dtype)
         elif bandwidth_mode == "parameter":
-            self._bandwidth = Parameter(torch.ones(stack_size, dtype=dtype))
+            self._bandwidth = Parameter(torch.zeros(size, dtype=dtype))
         else:
             raise ValueError(bandwidth_mode)
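
Note for reviewers: `Conv2DGrid.gridvals` treats each of the `size` grids as a
1-channel image, runs it through the conv stack, and sums the channels back to
a single grid. A self-contained sketch of that flow (hyper-parameters here are
assumptions, not the patch's configuration):

    import torch

    size, n1, n2, chans = 4, 16, 16, 8
    convs = torch.nn.ModuleList(
        [torch.nn.Conv2d(1 if l == 0 else chans, chans, 3, padding="same")
         for l in range(2)]
    )
    g = torch.randn(size, n1, n2)
    x = g[:, None, :, :]            # (size, 1, n1, n2)
    for layer in convs:
        x = layer(x)                # (size, chans, n1, n2)
    print(x.sum(dim=1).shape)       # torch.Size([4, 16, 16])

`padding="same"` keeps the spatial size fixed, so the smoothed grid can be
sliced exactly like the raw one.
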
diff --git a/firelang/map/fourier.py b/firelang/map/fourier.py
index a570451..d76d38c 100644
--- a/firelang/map/fourier.py
+++ b/firelang/map/fourier.py
@@ -4,14 +4,15 @@
 )
 
 __all__ = [
-    'SmoothedRectFourier2DMap',
+    "SmoothedRectFourier2DMap",
 ]
 
 
-class SmoothedRectFourier2DMap(SmoothedRectMap):
+class SmoothedRectFourier2DMap(SmoothedRectMap):
     @property
     def grid(self):
         g = SmoothedRectMap.get_grid(self)
         assert g.ndim == 2
-        g = torch.fft.fft2(g).real
\ No newline at end of file
+        g = torch.fft.fft2(g).real
+        return g
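
Note for reviewers: without the final return the property yields None.
`torch.fft.fft2` transforms over the last two dimensions and `.real` keeps the
tensor shape, so the spectral grid can be sliced like the spatial one:

    import torch

    g = torch.randn(4, 8, 8)        # (size, n1, n2)
    out = torch.fft.fft2(g).real    # FFT over the last two dims, real part
    print(out.shape)                # torch.Size([4, 8, 8])
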
diff --git a/firelang/map/rect.py b/firelang/map/rect.py
index 7965d28..31b454b 100644
--- a/firelang/map/rect.py
+++ b/firelang/map/rect.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 from typing import List, Tuple, Union, Iterable
 from typing_extensions import Literal
+import numpy as np
 import torch
 from torch import Tensor, dtype, device
 from torch.nn import Parameter
@@ -26,15 +27,14 @@ def __init__(
         bandwidth_lb: float = 0.3,
         dtype: dtype = torch.float32,
         device: device = "cuda",
-        stack_size: int = 1,
+        shape: Tuple[int, ...] = (1,),
     ):
         Functional.__init__(self, locals())
-        self.stack_size = stack_size
+        self.shape = shape
+        size = int(np.prod(shape))
         self.ndim = len(grid_dim_sizes)
-        self._grid: Grid = Grid(
-            grid_dim_sizes, stack_size=stack_size, dtype=dtype, device=device
-        )
+        self._grid: Grid = Grid(grid_dim_sizes, shape=shape, dtype=dtype, device=device)
 
     def _sizes_to_tensor(sizes: int | List[int], ndim: int) -> Tensor:
         if not isinstance(sizes, Iterable):
@@ -52,9 +52,9 @@
         self.bandwidth_mode = bandwidth_mode
         self.bandwidth_lb = bandwidth_lb
         if bandwidth_mode == "constant":
-            self._bandwidth = bandwidth_lb * torch.ones(stack_size, dtype=dtype)
+            self._bandwidth = bandwidth_lb * torch.ones(size, dtype=dtype)
         elif bandwidth_mode == "parameter":
-            self._bandwidth = Parameter(torch.ones(stack_size, dtype=dtype))
+            self._bandwidth = Parameter(torch.zeros(size, dtype=dtype))
         else:
             raise ValueError(bandwidth_mode)
 
@@ -83,40 +83,27 @@ def get_grid(self):
     def detect_device(self):
         return self._grid.detect_device()
 
-    def forward(self, locs: Tensor, cross: bool = False) -> Tensor:
+    def forward(self, locs: Tensor) -> Tensor:
         """Interpolate grid values at `locs`
 
         Args:
-        - locs (Tensor): (measure_stack, batch_size, dim).
+        - locs (Tensor): (...shape, dim).
             Locations in the grid
-            If cross == False:
-                measure_stack must be equal to self.stack_size
-
-        - cross (bool, optional): Defaults to False.
 
         Returns:
-            Tensor:
-            - If cross == False:
-                (measure_stack, batch_size)
-            - If cross == True:
-                (self.stack_size, measure_stack, batch_size)
+            Tensor: (...shape,)
         """
         device = locs.device
-        measure_stack, batch_size, ndim = locs.shape
-        map_stack = self.stack_size
-        if not cross:
-            assert map_stack == measure_stack
+        ndim = locs.shape[-1]
 
         with Timer(elapsed, "gridmap/location_transform"):
             """Transform values of `locs` from the range [limits_lower, limits_upper]
             to [0, dim_size]
             """
             limits = self.limits.to(device)  # (ndim, 2)
-            eps = 0
-            limits_lower = limits[:, 0].reshape(1, 1, -1) - eps
-            limits_upper = limits[:, 1].reshape(1, 1, -1) + eps
-            # print('locs:', locs)
+            limits_lower = limits[:, 0]
+            limits_upper = limits[:, 1]
             assert (
                 locs >= limits_lower
             ).all(), f"locs: {locs}\nSmallest value = {locs.min().item()}"
@@ -125,57 +112,41 @@
             ).all(), f"locs: {locs}\nLargest value = {locs.max().item()}"
             dim_sizes = self.grid_dim_sizes.to(device)  # (ndim,)
             locs = (locs - limits_lower) / (limits_upper - limits_lower)
-            locs = locs * dim_sizes.type(torch.float32).reshape(1, 1, -1)
+            locs = locs * dim_sizes.type(torch.float32)
 
         with Timer(elapsed, "gridmap/vertices"):
-            rect_dim_sizes = self.rect_dim_sizes.to(device)
-            rect_dim_sizes = rect_dim_sizes.reshape(1, 1, -1)  # (1, 1, ndim)
+            rect_dim_sizes = self.rect_dim_sizes.to(device)  # (ndim,)
             lower = torch.ceil(locs.data - rect_dim_sizes / 2).type(torch.long)
             lower = torch.maximum(lower, torch.zeros_like(lower))
-            upper = torch.minimum(lower + rect_dim_sizes, dim_sizes.reshape(1, 1, -1))
+            upper = torch.minimum(lower + rect_dim_sizes, dim_sizes)
             lower = upper - rect_dim_sizes
             corners = lower
-            # (measure_stack, batch_size, ndim, 2)
+            # (...shape, ndim)
 
         with Timer(elapsed, "gridmap/subgrid_weights", sync_cuda=True):
             rect_dim_sizes = self.rect_dim_sizes.to(device)
-            # distances = self._grid.rectangle_distance_to_loc(locs, corners, rect_dim_sizes)
             distances = rectangle_distance_to_loc(locs, corners, rect_dim_sizes)
-            # (measure_stack, batch_size, n1, ..., nd)
-            bandwidth = self.bandwidth.to(device)
+            # (...shape, n1, ..., nd)
+            bandwidth = self.bandwidth.to(device).view(*self.shape, *[1] * ndim)
+            # (...shape, 1, ..., 1)
             subgrid_weights = weights_from_distances(
-                distances, self.rect_weight_decay, bandwidth, cross=cross
+                distances, ndim, self.rect_weight_decay, bandwidth
             )
-            # If cross==False: (measure_stack, batch_size, n1, ..., nd)
-            # Else: (self.stack_size, measure_stack, batch_size, n1, ..., nd)
-
-        if not cross:
-            with Timer(elapsed, "gridmap/subgrid_values", sync_cuda=True):
-                subgrid_vals = self._grid.slice_rectangle(
-                    corners, rect_dim_sizes, cross=cross
-                )
-                # (measure_stack, batch_size, n1, ..., nd)
-
-            with Timer(elapsed, "gridmap/weighted_sum"):
-                vals = (subgrid_vals * subgrid_weights).sum(
-                    dim=list(range(2, ndim + 2))
-                )
-                # (measure_stack, batch_size)
-        else:
-            with Timer(elapsed, "gridmap/subgrid_values", sync_cuda=True):
-                subgrid_vals = self._grid.slice_rectangle(
-                    corners, rect_dim_sizes, cross=cross
-                )
-                # (self.stack_size, measure_stack, batch_size, n1, ..., nd)
-
-            with Timer(elapsed, "gridmap/weighted_sum"):
-                vals = (subgrid_vals * subgrid_weights).sum(
-                    dim=list(range(3, ndim + 3))
-                )
-                # (self.stack_size, measure_stack, batch_size)
+            # (...shape, n1, ..., nd)
+
+        with Timer(elapsed, "gridmap/subgrid_values", sync_cuda=True):
+            subgrid_vals = self._grid.slice_rectangle(
+                corners,
+                rect_dim_sizes,
+            )
+            # (...shape, n1, ..., nd)
+
+        with Timer(elapsed, "gridmap/weighted_sum"):
+            vals = (subgrid_vals * subgrid_weights).sum(dim=list(range(-ndim, 0)))
+            # (...shape,)
 
         return vals
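
Note for reviewers: the corner computation in `forward` first maps `locs` from
[limits_lower, limits_upper] to grid coordinates [0, dim_size], then clamps the
rectangle so it always lies inside the grid. A small worked 1-D example with
hypothetical numbers:

    import torch

    limits_lower, limits_upper = -4.0, 4.0
    dim_size, rect = 64, 4
    locs = torch.tensor([-4.0, 0.0, 3.9])
    x = (locs - limits_lower) / (limits_upper - limits_lower) * dim_size
    lower = torch.ceil(x - rect / 2).long().clamp(min=0)
    upper = torch.minimum(lower + rect, torch.tensor(dim_size))
    lower = upper - rect    # rectangles: [0, 4), [30, 34), [60, 64)
    print(x, lower)

Even for a location at the boundary, the final `lower = upper - rect` step
keeps the full rect-sized window inside [0, dim_size).
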
@@ -187,66 +158,49 @@ def rectangle_distance_to_loc(
     (hyper-)rectangle that are specified by `corners` and `rect_dim_sizes`.
 
     Args:
-    - locs (Tensor): (measure_stack, batch_size, ndim) locations
-    - corners (Tensor): (measure_stack, batch_size, ndim). \
-        If cross == False: measure_stack must be equal to self.stack_size
+    - locs (Tensor): (...shape, ndim) locations
+    - corners (Tensor): (...shape, ndim).
    - rect_dim_sizes (Tensor): rect_dim_sizes at each dimension of the rectangle.
 
     Returns:
         Tensor: let each rectangle be represented by (n1, n2, ..., nd), returns \
-            a Tensor with shape (measure_stack, batch_size, n1, n2, ..., nd).
+            a Tensor with shape (...shape, n1, n2, ..., nd).
     """
-    measure_stack, batch_size, ndim = locs.shape
     device = locs.device
+    (*shape, ndim) = locs.shape
 
-    shifts = corners.to(locs) - locs  # (measure_stack, batch_size, ndim)
+    shifts = corners.to(locs) - locs  # (...shape, ndim)
     distsq = 0
     for d in range(ndim):
-        shifts_at_dim_d = shifts[:, :, d : d + 1] + torch.arange(
+        shifts_at_dim_d = shifts[..., d : d + 1] + torch.arange(
             rect_dim_sizes[d], device=device, dtype=torch.long
-        ).reshape(
-            1, 1, -1
-        )  # (measure_stack, batch_size, nd=rect_dim_sizes[d])
+        )  # (...shape, nd=rect_dim_sizes[d])
         distsq_at_dim_d = (shifts_at_dim_d**2).reshape(
-            measure_stack, batch_size, *([1] * d), -1, *([1] * (ndim - d - 1))
-        )  # (measure_stack, batch_size, 1, ..., 1, nd, 1, ..., 1)
+            *shape, *([1] * d), -1, *([1] * (ndim - d - 1))
+        )  # (...shape, 1, ..., 1, nd, 1, ..., 1)
         distsq = distsq + distsq_at_dim_d
     return distsq**0.5
 
 
 def weights_from_distances(
     distances: Tensor,
+    ndim: int,
     decay: Literal["gauss", "exp"],
     bandwidth: Tensor,
-    cross: bool = False,
 ) -> Tensor:
     """Compute weights from distances.
 
     Args:
-    - distances (Tensor): (measure_stack, batch_size, *)
+    - distances (Tensor): (...shape, *)
     - decay ("gauss" | "exp"): different weight decaying patterns with respect to distance.
         - "gauss": proportional to $exp(-d^2)$
         - "exp": proportional to $exp(-d)$
-    - bandwidth (Tensor): (map_stack,) the weights decay slower with a larger `bandwidth`
+    - bandwidth (Tensor): (...shape, 1, ..., 1), broadcastable against `distances`;
+        the weights decay more slowly with a larger `bandwidth`.
 
     Returns:
-        Tensor:
-        - If cross == False, measure_stack must be equal to map_stack.
-            Returns (measure_stack, *)
-        - Else, returns (map_stack, measure_stack, *)
+        Tensor: (...shape, *)
     """
-    ndim = distances.ndim - 2
-    measure_stack = distances.shape[0]
-    map_stack = bandwidth.shape[0]
-
-    bandwidth = bandwidth.reshape(map_stack, 1, *[1 for _ in range(ndim)])
-    if not cross:
-        assert map_stack == measure_stack
-    else:
-        distances = distances[None]
-        bandwidth = bandwidth[:, None]
-
     distances = distances / bandwidth
     if decay == "gauss":
         logweights = -(distances**2)
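
Note for reviewers: the two decay modes give the familiar Gaussian and
Laplacian kernels in the distance d, scaled by the bandwidth. A quick numeric
check (the normalization at the end of `weights_from_distances` is not visible
in this hunk, so the explicit `/ sum()` below is just for illustration):

    import torch

    d = torch.tensor([0.0, 1.0, 2.0])
    bandwidth = 0.5
    x = d / bandwidth
    w_gauss = torch.exp(-(x**2))    # "gauss": proportional to exp(-d^2)
    w_exp = torch.exp(-x)           # "exp":   proportional to exp(-d)
    print(w_gauss / w_gauss.sum())
    print(w_exp / w_exp.sum())

A larger `bandwidth` shrinks `x = d / bandwidth`, so the weights decay more
slowly with distance, as the docstring states.
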
diff --git a/firelang/map/wavelet.py b/firelang/map/wavelet.py
index edca4bb..a804563 100644
--- a/firelang/map/wavelet.py
+++ b/firelang/map/wavelet.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 from typing import List, Tuple, Iterable
 from typing_extensions import Literal
+import numpy as np
 import torch
 from torch import Tensor, dtype, device
 from torch.nn import Parameter
@@ -26,10 +27,14 @@ def __init__(
         level: int = 3,
         dtype: dtype = torch.float32,
         device: device = "cuda",
-        stack_size: int = 1,
+        shape: Tuple[int, ...] = (1,),
     ):
         Grid.__init__(
-            self, dim_sizes=dim_sizes, dtype=dtype, device=device, stack_size=stack_size
+            self,
+            dim_sizes=dim_sizes,
+            dtype=dtype,
+            device=device,
+            shape=shape,
         )
         self.register_extra_init_kwargs(
             wavelet=wavelet,
@@ -42,14 +47,14 @@
     @property
     def gridvals(self) -> Tensor:
         assert self.ndim == 2
-        g: Parameter = self._gridvals  # (self.stack_size, n1, n2)
+        g: Parameter = self._gridvals  # (size, n1, n2)
 
-        g = g[:, None]  # (self.stack_size, 1, n1, n2)
+        g = g[:, None, :, :]  # (size, 1, n1, n2)
         sizes: Tensor = self.dim_sizes
         sizes = sizes.data.cpu().numpy().tolist()
         coeffs = []
-        for l in range(self.level - 1):
+        for _ in range(self.level - 1):
             h, w = (sizes[0] + 1) // 2, (sizes[1] + 1) // 2
             coeff1 = g[:, :, -h:, :w]
             coeff2 = g[:, :, -h:, -w:]
@@ -59,7 +64,7 @@
         coeffs.append(g[:, :, :h, :w])
         coeffs = list(reversed(coeffs))
 
-        return ptwt.waverec2(coeffs, self.wavelet)  # (self.stack_size, 1, n1, n2)
+        return ptwt.waverec2(coeffs, self.wavelet)  # (size, 1, n1, n2)
 
 
 class SmoothedRectWavelet2DMap(SmoothedRectMap):
@@ -75,11 +80,12 @@ def __init__(
-        level: str = 3,
+        level: int = 3,
         dtype: dtype = torch.float32,
         device: device = "cuda",
-        stack_size: int = 1,
+        shape: Tuple[int, ...] = (1,),
     ):
         Functional.__init__(self, locals())
-        self.stack_size = stack_size
+        self.shape = shape
+        size = int(np.prod(shape))
         self.ndim = len(grid_dim_sizes)
         self._grid = Wavelet2DMap(
             grid_dim_sizes,
@@ -87,7 +93,7 @@
             level=level,
             dtype=dtype,
             device=device,
-            stack_size=stack_size,
+            shape=shape,
         )
 
     def _sizes_to_tensor(sizes: int | List[int], ndim: int) -> Tensor:
@@ -99,15 +105,17 @@
         self.grid_dim_sizes = _sizes_to_tensor(grid_dim_sizes, self.ndim)
         self.rect_dim_sizes = _sizes_to_tensor(rect_dim_sizes, self.ndim)
         self.limits = torch.tensor(
-            parse_rect_limits(limits, self.ndim), dtype=dtype, device=device
+            parse_rect_limits(limits, self.ndim),
+            dtype=dtype,
+            device=device,
         )
         self.rect_weight_decay = rect_weight_decay
         self.bandwidth_mode = bandwidth_mode
         self.bandwidth_lb = bandwidth_lb
         if bandwidth_mode == "constant":
-            self._bandwidth = bandwidth_lb * torch.ones(stack_size, dtype=dtype)
+            self._bandwidth = bandwidth_lb * torch.ones(size, dtype=dtype)
         elif bandwidth_mode == "parameter":
-            self._bandwidth = Parameter(torch.ones(stack_size, dtype=dtype))
+            self._bandwidth = Parameter(torch.zeros(size, dtype=dtype))
         else:
             raise ValueError(bandwidth_mode)
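
Note for reviewers: with this patch, stacking is expressed through an arbitrary
`shape` tuple instead of a flat `stack_size`, and the `cross` flag is gone;
alignment between grids and query locations now follows plain broadcasting over
`shape`. A hypothetical usage sketch (constructor defaults such as `limits` are
assumptions, `Functional.__call__` is assumed to dispatch to `forward`, and the
numba kernel requires a CUDA device):

    import torch
    from firelang.map.rect import SmoothedRectMap

    fmap = SmoothedRectMap(
        grid_dim_sizes=[64, 64],
        rect_dim_sizes=4,
        limits=4.0,          # assumed to expand to [-4, 4] per dim
        shape=(3, 5),        # a (3, 5) stack of independent 2-D grids
        device="cuda",
    )
    locs = torch.zeros(3, 5, 2, device="cuda")  # (...shape, ndim)
    vals = fmap(locs)                           # (...shape,) == (3, 5)
    print(vals.shape)
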