From 9d958c3a638624b950b35379d4bf678d19ab3941 Mon Sep 17 00:00:00 2001 From: duxin Date: Wed, 2 Nov 2022 15:53:20 +0900 Subject: [PATCH] Wavelet reconstruction of 2D rectangle map --- firelang/map/__init__.py | 1 + firelang/map/wavelet.py | 113 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 114 insertions(+) create mode 100644 firelang/map/wavelet.py diff --git a/firelang/map/__init__.py b/firelang/map/__init__.py index f5b4a44..402fe70 100644 --- a/firelang/map/__init__.py +++ b/firelang/map/__init__.py @@ -1,3 +1,4 @@ from .rect import * from .fourier import * from .conv import * +from .wavelet import * \ No newline at end of file diff --git a/firelang/map/wavelet.py b/firelang/map/wavelet.py new file mode 100644 index 0000000..edca4bb --- /dev/null +++ b/firelang/map/wavelet.py @@ -0,0 +1,113 @@ +from __future__ import annotations +from typing import List, Tuple, Iterable +from typing_extensions import Literal +import torch +from torch import Tensor, dtype, device +from torch.nn import Parameter +import pywt +import ptwt +from ._grid import Grid +from .rect import ( + SmoothedRectMap, +) +from firelang.function import Functional +from firelang.utils.limits import parse_rect_limits + +__all__ = [ + "SmoothedRectWavelet2DMap", +] + + +class Wavelet2DMap(Grid): + def __init__( + self, + dim_sizes: List[int], + wavelet: Literal["haar"] = "haar", + level: int = 3, + dtype: dtype = torch.float32, + device: device = "cuda", + stack_size: int = 1, + ): + Grid.__init__( + self, dim_sizes=dim_sizes, dtype=dtype, device=device, stack_size=stack_size + ) + self.register_extra_init_kwargs( + wavelet=wavelet, + level=level, + ) + self._wavelet = wavelet + self.wavelet = pywt.Wavelet(wavelet) + self.level = level + + @property + def gridvals(self) -> Tensor: + assert self.ndim == 2 + g: Parameter = self._gridvals # (self.stack_size, n1, n2) + + g = g[:, None] # (self.stack_size, 1, n1, n2) + sizes: Tensor = self.dim_sizes + sizes = sizes.data.cpu().numpy().tolist() + + coeffs = [] + for l in range(self.level - 1): + h, w = (sizes[0] + 1) // 2, (sizes[1] + 1) // 2 + coeff1 = g[:, :, -h:, :w] + coeff2 = g[:, :, -h:, -w:] + coeff3 = g[:, :, :h, -w:] + coeffs.append((coeff1, coeff2, coeff3)) + sizes = [h, w] + + coeffs.append(g[:, :, :h, :w]) + coeffs = list(reversed(coeffs)) + return ptwt.waverec2(coeffs, self.wavelet) # (self.stack, 1, n1, n2) + + +class SmoothedRectWavelet2DMap(SmoothedRectMap): + def __init__( + self, + limits: Tuple[float, float] | List[Tuple[float, float]], + grid_dim_sizes: List[int], + rect_dim_sizes: int | List[int] = 3, + rect_weight_decay: Literal["gauss", "exp"] = "gauss", + bandwidth_mode: Literal["parameter", "constant"] = "parameter", + bandwidth_lb: float = 0.3, + wavelet: str = "haar", + level: str = 3, + dtype: dtype = torch.float32, + device: device = "cuda", + stack_size: int = 1, + ): + Functional.__init__(self, locals()) + + self.stack_size = stack_size + self.ndim = len(grid_dim_sizes) + self._grid = Wavelet2DMap( + grid_dim_sizes, + wavelet=wavelet, + level=level, + dtype=dtype, + device=device, + stack_size=stack_size, + ) + + def _sizes_to_tensor(sizes: int | List[int], ndim: int) -> Tensor: + if not isinstance(sizes, Iterable): + sizes = [sizes] * ndim + sizes = torch.tensor(sizes, dtype=torch.long, device=device) + return sizes + + self.grid_dim_sizes = _sizes_to_tensor(grid_dim_sizes, self.ndim) + self.rect_dim_sizes = _sizes_to_tensor(rect_dim_sizes, self.ndim) + self.limits = torch.tensor( + parse_rect_limits(limits, self.ndim), dtype=dtype, device=device + ) + + self.rect_weight_decay = rect_weight_decay + self.bandwidth_mode = bandwidth_mode + self.bandwidth_lb = bandwidth_lb + if bandwidth_mode == "constant": + self._bandwidth = bandwidth_lb * torch.ones(stack_size, dtype=dtype) + elif bandwidth_mode == "parameter": + self._bandwidth = Parameter(torch.ones(stack_size, dtype=dtype)) + else: + raise ValueError(bandwidth_mode)