Skip to content

Commit

Permalink
FIRETensor supporting maps
Browse files Browse the repository at this point in the history
  • Loading branch information
kduxin committed Nov 8, 2022
1 parent 1232569 commit fbcd2ff
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 158 deletions.
72 changes: 33 additions & 39 deletions firelang/map/_grid.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
from typing import List, Iterable
from typing import List, Tuple, Iterable
import numpy as np
from numba import cuda
import torch
from torch import Tensor, dtype, device
Expand All @@ -16,13 +17,11 @@ def __init__(
dim_sizes: List[int],
dtype: dtype = torch.float32,
device: device = "cuda",
stack_size: int = 1,
shape: Tuple[int] = (1,),
):
StackingSlicing.__init__(self, locals())
self._gridvals = Parameter(
torch.empty(stack_size, *dim_sizes, dtype=dtype, device=device).normal_(
0, 0.1
)
torch.empty(*shape, *dim_sizes, dtype=dtype, device=device).normal_(0, 0.1)
)
self.dim_sizes = torch.tensor(dim_sizes, dtype=torch.long, device=device)
self.ndim = len(dim_sizes)
Expand All @@ -38,14 +37,13 @@ def slice_rectangle(
self,
corners: Tensor,
rect_dim_sizes: int | List[int] | Tensor,
cross: bool = False,
) -> Tensor:
"""Slice batches of (hyper-)rectangles from the d-dim grid that are specified
by `corners` and `rect_dim_sizes`.
"""Parallelized slicing of (hyper-)rectangles from the d-dim grid
that are specified by `corners` and `rect_dim_sizes`.
Args:
- corners (Tensor): (measure_stack, batch_size, ndim). \
If cross == False: measure_stack must be equal to self.stack_size
- corners (Tensor): (...xshape, ndim). The i-th corner out of `xshape` corners
is aligned with the i-th grid of self._gridvals.
- rect_dim_sizes (int | List[int] | Tensor): rect_dim_sizes at each dimension of the rectangle. \
If is a `int`, it gives the size at all dimensions.
Expand All @@ -55,60 +53,56 @@ def slice_rectangle(
Returns:
- Tensor: let each rectangle be represented by (n1, n2, ..., nd), returns
a Tensor with shape:
- if cross == False: (measure_size, batch_size, n1, n2, ..., nd),
- else: (self.stack_size, measure_size, batch_size, n1, n2, ..., nd).
a Tensor with shape: (...shape, n1, n2, ..., nd),
"""
measure_stack, batch_size, ndim = corners.shape
(*shape, ndim) = corners.shape
size = int(np.prod(shape))
device, dtype = corners.device, corners.dtype

grid_dim_sizes = self.dim_sizes.to(device)
stack_size = self.stack_size

if not isinstance(rect_dim_sizes, Iterable):
rect_dim_sizes = [rect_dim_sizes] * len(ndim)
else:
assert len(rect_dim_sizes) == ndim
if not isinstance(rect_dim_sizes, Tensor):
rect_dim_sizes = torch.tensor(
rect_dim_sizes, dtype=torch.long, device=device
rect_dim_sizes,
dtype=torch.long,
device=device,
)

corners = corners.reshape(measure_stack * batch_size, ndim)
""" Compute the offsets for all grid points within the rectangle,
which will be used in torch.take(...).
"""
corners = corners.reshape(size, ndim)
offsets = torch.zeros(
measure_stack * batch_size,
torch.prod(rect_dim_sizes),
size,
int(np.prod(rect_dim_sizes.data.cpu().numpy())),
dtype=torch.long,
device=device,
)

BLOCKDIM_X = 512
n_block = (measure_stack * batch_size + BLOCKDIM_X - 1) // BLOCKDIM_X
n_block = (size + BLOCKDIM_X - 1) // BLOCKDIM_X
rectangle_offsets_in_grid_kernel[n_block, BLOCKDIM_X](
cuda.as_cuda_array(corners),
cuda.as_cuda_array(grid_dim_sizes),
cuda.as_cuda_array(rect_dim_sizes),
measure_stack * batch_size,
size,
ndim,
cuda.as_cuda_array(offsets),
)

offsets = offsets.reshape(measure_stack, batch_size, *rect_dim_sizes.tolist())
# (measure_stack, batch_size, n1, n2, ..., nd)

grid_size = torch.prod(grid_dim_sizes)
if not cross:
stack_offsets = (
torch.arange(measure_stack, dtype=torch.long, device=device) * grid_size
).reshape(-1, 1, *[1 for _ in range(ndim)])
# (measure_stack, 1 (batch_size), 1, ..., 1), length=2+ndim.
offsets = offsets + stack_offsets
else:
stack_offsets = (
torch.arange(stack_size, dtype=torch.long, device=device) * grid_size
).reshape(-1, 1, 1, *[1 for _ in range(ndim)])
# (self.stack_size, 1 (measure_stack), 1 (batch_size), 1, ..., 1), length=3+ndim.
offsets = offsets[None] + stack_offsets
offsets = offsets.reshape(*shape, *rect_dim_sizes.tolist())
# (*shape, n1, n2, ..., nd)

""" Consider additional offset caused by `shape` """
grid_size = int(np.prod(grid_dim_sizes.data.cpu().numpy()))
stack_offsets = (
torch.arange(int(np.prod(self.shape)), dtype=torch.long, device=device)
* grid_size
).reshape(*self.shape, *[1] * ndim)
# (*self.shape, 1, ..., 1)
offsets = offsets + stack_offsets

rectangle_vals = self.gridvals.take(offsets)
return rectangle_vals
Expand Down
40 changes: 27 additions & 13 deletions firelang/map/conv.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations
from typing import List, Tuple, Iterable
from typing_extensions import Literal
import numpy as np
import torch
from torch import dtype, device, Tensor
from torch.nn import Module, Parameter, ModuleList
Expand All @@ -26,13 +27,19 @@ def __init__(
conv_layers: int = 1,
dtype: dtype = torch.float32,
device: device = "cuda",
stack_size: int = 1,
shape: Tuple[int] = (1,),
):
Grid.__init__(
self, dim_sizes=dim_sizes, dtype=dtype, device=device, stack_size=stack_size
self,
dim_sizes=dim_sizes,
dtype=dtype,
device=device,
shape=shape,
)
self.register_extra_init_kwargs(
conv_size=conv_size, conv_chans=conv_chans, conv_layers=conv_layers
conv_size=conv_size,
conv_chans=conv_chans,
conv_layers=conv_layers,
)
self.conv_size = conv_size
self.conv_chans = conv_chans
Expand All @@ -43,7 +50,11 @@ def __init__(
in_chans = 1 if l == 0 else conv_chans
self.conv.append(
torch.nn.Conv2d(
in_chans, conv_chans, conv_size, padding="same", device=device
in_chans,
conv_chans,
conv_size,
padding="same",
device=device,
)
)
self.unsliceable_params.add(f"conv.{l}.weight")
Expand All @@ -52,11 +63,11 @@ def __init__(
@property
def gridvals(self) -> Tensor:
assert self.ndim == 2
g: Parameter = self._gridvals # (self.stack_size, n1, n2)
g: Parameter = self._gridvals # (size, n1, n2)

g = g[:, None] # (self.stack_size, 1, n1, n2)
g = g[:, None, :, :] # (size, shape, 1, n1, n2)
for layer in self.conv:
g = layer(g) # (self.stack_size, channels, n1, n2)
g = layer(g) # (size, channels, n1, n2)
g = g.sum(dim=1)
return g

Expand All @@ -75,11 +86,12 @@ def __init__(
conv_layers: int = 1,
dtype: dtype = torch.float32,
device: device = "cuda",
stack_size: int = 1,
shape: Tuple[int] = (1,),
):
Functional.__init__(self, locals())

self.stack_size = stack_size
self.shape = shape
size = int(np.prod(shape))
self.ndim = len(grid_dim_sizes)
self._grid = Conv2DGrid(
grid_dim_sizes,
Expand All @@ -88,7 +100,7 @@ def __init__(
conv_layers=conv_layers,
dtype=dtype,
device=device,
stack_size=stack_size,
shape=shape,
)

def _sizes_to_tensor(sizes: int | List[int], ndim: int) -> Tensor:
Expand All @@ -100,15 +112,17 @@ def _sizes_to_tensor(sizes: int | List[int], ndim: int) -> Tensor:
self.grid_dim_sizes = _sizes_to_tensor(grid_dim_sizes, self.ndim)
self.rect_dim_sizes = _sizes_to_tensor(rect_dim_sizes, self.ndim)
self.limits = torch.tensor(
parse_rect_limits(limits, self.ndim), dtype=dtype, device=device
parse_rect_limits(limits, self.ndim),
dtype=dtype,
device=device,
)

self.rect_weight_decay = rect_weight_decay
self.bandwidth_mode = bandwidth_mode
self.bandwidth_lb = bandwidth_lb
if bandwidth_mode == "constant":
self._bandwidth = bandwidth_lb * torch.ones(stack_size, dtype=dtype)
self._bandwidth = bandwidth_lb * torch.ones(size, dtype=dtype)
elif bandwidth_mode == "parameter":
self._bandwidth = Parameter(torch.ones(stack_size, dtype=dtype))
self._bandwidth = Parameter(torch.zeros(size, dtype=dtype))
else:
raise ValueError(bandwidth_mode)
6 changes: 3 additions & 3 deletions firelang/map/fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
)

__all__ = [
'SmoothedRectFourier2DMap',
"SmoothedRectFourier2DMap",
]

class SmoothedRectFourier2DMap(SmoothedRectMap):

class SmoothedRectFourier2DMap(SmoothedRectMap):
@property
def grid(self):
g = SmoothedRectMap.get_grid(self)
assert g.ndim == 2

g = torch.fft.fft2(g).real
g = torch.fft.fft2(g).real
Loading

0 comments on commit fbcd2ff

Please sign in to comment.