Skip to content

Commit

Permalink
Merge branch 'gridmap'
Browse files Browse the repository at this point in the history
  • Loading branch information
kduxin committed Nov 2, 2022
2 parents b78e225 + 9d958c3 commit 41c80ec
Show file tree
Hide file tree
Showing 6 changed files with 663 additions and 0 deletions.
4 changes: 4 additions & 0 deletions firelang/map/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .rect import *
from .fourier import *
from .conv import *
from .wavelet import *
151 changes: 151 additions & 0 deletions firelang/map/_grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
from __future__ import annotations
from typing import List, Iterable
from numba import cuda
import torch
from torch import Tensor, dtype, device
from torch.nn import Parameter
from firelang.stack import StackingSlicing


class Grid(StackingSlicing):

_gridvals: Tensor

def __init__(
self,
dim_sizes: List[int],
dtype: dtype = torch.float32,
device: device = "cuda",
stack_size: int = 1,
):
StackingSlicing.__init__(self, locals())
self._gridvals = Parameter(
torch.empty(stack_size, *dim_sizes, dtype=dtype, device=device).normal_(
0, 0.1
)
)
self.dim_sizes = torch.tensor(dim_sizes, dtype=torch.long, device=device)
self.ndim = len(dim_sizes)

def detect_device(self):
return self._gridvals.device

@property
def gridvals(self):
return self._gridvals

def slice_rectangle(
self,
corners: Tensor,
rect_dim_sizes: int | List[int] | Tensor,
cross: bool = False,
) -> Tensor:
"""Slice batches of (hyper-)rectangles from the d-dim grid that are specified
by `corners` and `rect_dim_sizes`.
Args:
- corners (Tensor): (measure_stack, batch_size, ndim). \
If cross == False: measure_stack must be equal to self.stack_size
- rect_dim_sizes (int | List[int] | Tensor): rect_dim_sizes at each dimension of the rectangle. \
If is a `int`, it gives the size at all dimensions.
A rectangle at dimension d is located at:
- If `rect_dim_sizes` is int: [corner : corner + rect_dim_sizes]
- If `rect_dim_sizes` is List[int]: [corner : corner + rect_dim_sizes[d]]
Returns:
- Tensor: let each rectangle be represented by (n1, n2, ..., nd), returns
a Tensor with shape:
- if cross == False: (measure_size, batch_size, n1, n2, ..., nd),
- else: (self.stack_size, measure_size, batch_size, n1, n2, ..., nd).
"""
measure_stack, batch_size, ndim = corners.shape
device, dtype = corners.device, corners.dtype

grid_dim_sizes = self.dim_sizes.to(device)
stack_size = self.stack_size

if not isinstance(rect_dim_sizes, Iterable):
rect_dim_sizes = [rect_dim_sizes] * len(ndim)
else:
assert len(rect_dim_sizes) == ndim
if not isinstance(rect_dim_sizes, Tensor):
rect_dim_sizes = torch.tensor(
rect_dim_sizes, dtype=torch.long, device=device
)

corners = corners.reshape(measure_stack * batch_size, ndim)
offsets = torch.zeros(
measure_stack * batch_size,
torch.prod(rect_dim_sizes),
dtype=torch.long,
device=device,
)

BLOCKDIM_X = 512
n_block = (measure_stack * batch_size + BLOCKDIM_X - 1) // BLOCKDIM_X
rectangle_offsets_in_grid_kernel[n_block, BLOCKDIM_X](
cuda.as_cuda_array(corners),
cuda.as_cuda_array(grid_dim_sizes),
cuda.as_cuda_array(rect_dim_sizes),
measure_stack * batch_size,
ndim,
cuda.as_cuda_array(offsets),
)

offsets = offsets.reshape(measure_stack, batch_size, *rect_dim_sizes.tolist())
# (measure_stack, batch_size, n1, n2, ..., nd)

grid_size = torch.prod(grid_dim_sizes)
if not cross:
stack_offsets = (
torch.arange(measure_stack, dtype=torch.long, device=device) * grid_size
).reshape(-1, 1, *[1 for _ in range(ndim)])
# (measure_stack, 1 (batch_size), 1, ..., 1), length=2+ndim.
offsets = offsets + stack_offsets
else:
stack_offsets = (
torch.arange(stack_size, dtype=torch.long, device=device) * grid_size
).reshape(-1, 1, 1, *[1 for _ in range(ndim)])
# (self.stack_size, 1 (measure_stack), 1 (batch_size), 1, ..., 1), length=3+ndim.
offsets = offsets[None] + stack_offsets

rectangle_vals = self.gridvals.take(offsets)
return rectangle_vals


@cuda.jit("void(int64[:,:], int64[:], int64[:], int32, int32, int64[:,:])")
def rectangle_offsets_in_grid_kernel(
corners, grid_dim_sizes, rect_dim_sizes, batch_size, ndim, out
):
"""Compute the offsets of grid points from the beginning of the grid,
where the grid points are those inside rectangles specified by `corners` and `rect_dim_sizes`.
Args:
- corner (CudaArray): (batch_size, ndim)
- grid_dim_sizes (CudaArray): (ndim,)
- rect_dim_sizes (CudaArray): (ndim,)
- batch_size (int)
- ndim (int)
- out (CudaArray): (batch_size, n1*n2...*nd). Buffer for the output (offsets).
"""

i = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x
if i >= batch_size:
return

rect_size = 1
for d in range(ndim):
rect_size *= rect_dim_sizes[d]

for j in range(rect_size):
rect_stride = 1
grid_stride = 1
offset = 0
for d in range(ndim):
d = ndim - 1 - d
id_at_dim_d = (j // rect_stride) % rect_dim_sizes[d] + corners[i, d]
offset += id_at_dim_d * grid_stride
rect_stride *= rect_dim_sizes[d]
grid_stride *= grid_dim_sizes[d]
out[i, j] = offset
114 changes: 114 additions & 0 deletions firelang/map/conv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from __future__ import annotations
from typing import List, Tuple, Iterable
from typing_extensions import Literal
import torch
from torch import dtype, device, Tensor
from torch.nn import Module, Parameter, ModuleList

from firelang.function import Functional
from firelang.utils.limits import parse_rect_limits
from ._grid import Grid
from .rect import (
SmoothedRectMap,
)

__all__ = [
"SmoothedRectConv2DMap",
]


class Conv2DGrid(Grid):
def __init__(
self,
dim_sizes: List[int],
conv_size: int | List[int] = 3,
conv_chans: int = 1,
conv_layers: int = 1,
dtype: dtype = torch.float32,
device: device = "cuda",
stack_size: int = 1,
):
Grid.__init__(
self, dim_sizes=dim_sizes, dtype=dtype, device=device, stack_size=stack_size
)
self.register_extra_init_kwargs(
conv_size=conv_size, conv_chans=conv_chans, conv_layers=conv_layers
)
self.conv_size = conv_size
self.conv_chans = conv_chans
self.conv_layers = conv_layers

self.conv = ModuleList()
for l in range(conv_layers):
in_chans = 1 if l == 0 else conv_chans
self.conv.append(
torch.nn.Conv2d(
in_chans, conv_chans, conv_size, padding="same", device=device
)
)
self.unsliceable_params.add(f"conv.{l}.weight")
self.unsliceable_params.add(f"conv.{l}.bias")

@property
def gridvals(self) -> Tensor:
assert self.ndim == 2
g: Parameter = self._gridvals # (self.stack_size, n1, n2)

g = g[:, None] # (self.stack_size, 1, n1, n2)
for layer in self.conv:
g = layer(g) # (self.stack_size, channels, n1, n2)
g = g.sum(dim=1)
return g


class SmoothedRectConv2DMap(SmoothedRectMap):
def __init__(
self,
limits: Tuple[float, float] | List[Tuple[float, float]],
grid_dim_sizes: List[int],
rect_dim_sizes: int | List[int] = 3,
rect_weight_decay: Literal["gauss", "exp"] = "gauss",
bandwidth_mode: Literal["parameter", "constant"] = "parameter",
bandwidth_lb: float = 0.3,
conv_size: int | List[int] = 3,
conv_chans: int = 1,
conv_layers: int = 1,
dtype: dtype = torch.float32,
device: device = "cuda",
stack_size: int = 1,
):
Functional.__init__(self, locals())

self.stack_size = stack_size
self.ndim = len(grid_dim_sizes)
self._grid = Conv2DGrid(
grid_dim_sizes,
conv_size=conv_size,
conv_chans=conv_chans,
conv_layers=conv_layers,
dtype=dtype,
device=device,
stack_size=stack_size,
)

def _sizes_to_tensor(sizes: int | List[int], ndim: int) -> Tensor:
if not isinstance(sizes, Iterable):
sizes = [sizes] * ndim
sizes = torch.tensor(sizes, dtype=torch.long, device=device)
return sizes

self.grid_dim_sizes = _sizes_to_tensor(grid_dim_sizes, self.ndim)
self.rect_dim_sizes = _sizes_to_tensor(rect_dim_sizes, self.ndim)
self.limits = torch.tensor(
parse_rect_limits(limits, self.ndim), dtype=dtype, device=device
)

self.rect_weight_decay = rect_weight_decay
self.bandwidth_mode = bandwidth_mode
self.bandwidth_lb = bandwidth_lb
if bandwidth_mode == "constant":
self._bandwidth = bandwidth_lb * torch.ones(stack_size, dtype=dtype)
elif bandwidth_mode == "parameter":
self._bandwidth = Parameter(torch.ones(stack_size, dtype=dtype))
else:
raise ValueError(bandwidth_mode)
17 changes: 17 additions & 0 deletions firelang/map/fourier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import torch
from .rect import (
SmoothedRectMap,
)

__all__ = [
'SmoothedRectFourier2DMap',
]

class SmoothedRectFourier2DMap(SmoothedRectMap):

@property
def grid(self):
g = SmoothedRectMap.get_grid(self)
assert g.ndim == 2

g = torch.fft.fft2(g).real
Loading

0 comments on commit 41c80ec

Please sign in to comment.