Skip to content

Commit

Permalink
More flexible range of Dirac measure with limits instead of range
Browse files Browse the repository at this point in the history
…. Removed `range`.
  • Loading branch information
kduxin committed Nov 2, 2022
1 parent 89597f0 commit 8e373de
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 19 deletions.
67 changes: 48 additions & 19 deletions firelang/measure/dirac.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from __future__ import annotations
from typing import List, Tuple
import torch
from torch import Tensor
from torch.nn import Parameter
from .base import Measure
from firelang.utils.limits import parse_rect_limits

__all__ = ["DiracMixture"]
__all__ = [
"DiracMixture",
]


class DiracMixture(Measure):
Expand All @@ -13,28 +19,30 @@ def __init__(
self,
dim: int,
k: int,
range: float = None,
limits: float | Tuple[float, float] | List[Tuple[float, float]] = None,
mfix: bool = False,
stack_size: int = 1,
):
Measure.__init__(self, locals())
assert range is None or range > 0
self._x = (
Parameter(torch.randn(stack_size, k, dim, dtype=torch.float32))
if range is None
else Parameter(
torch.rand(stack_size, k, dim, dtype=torch.float32) * (2 * range)
- range
if limits is None:
self._x = Parameter(torch.randn(stack_size, k, dim, dtype=torch.float32))
else:
limits = torch.tensor(
parse_rect_limits(limits, dim), dtype=torch.float32
) # (dim, 2)
ranges = (limits[:, 1] - limits[:, 0])[None, None] # (1, 1, dim)
starts = limits[:, 0][None, None] # (1, 1, dim)
self._x = Parameter(
torch.rand(stack_size, k, dim, dtype=torch.float32) * ranges + starts
)
)
self._m = (
1.0 if mfix else Parameter(torch.ones(stack_size, k, dtype=torch.float32))
)

self.dim = dim
self.k = k
self.stack_size = stack_size
self.range = range
self.limits = limits
self.mfix = mfix

def integral(self, func, cross=False, batch_size=1000000, sum=True):
Expand Down Expand Up @@ -64,8 +72,20 @@ def integral(self, func, cross=False, batch_size=1000000, sum=True):
return res

def get_x(self):
    """Return the particle locations, shape (stack_size, k, dim).

    The raw parameter ``self._x`` is unconstrained.  When ``self.limits``
    is set, each coordinate is squashed with tanh so it stays inside a
    band whose width equals the per-dimension limit width.

    NOTE(review): the tanh squashing maps into (-width/2, +width/2) and
    is not re-centered on the limits' midpoint, so asymmetric limits
    such as (0, 5) are not honored exactly -- confirm intended behavior.

    Raises:
        RuntimeError: if the squashed locations contain NaN (degenerate
            zero-width limits or a diverged parameter).
    """
    if self.limits is None:
        return self._x
    limits = self.limits.to(self.detect_device())
    # Half of each dimension's width, broadcastable to (1, 1, dim).
    half_ranges = (limits[:, 1] - limits[:, 0])[None, None] / 2
    x = torch.tanh(self._x / half_ranges) * half_ranges
    if x.isnan().any():
        # Fail loudly instead of print() + exit(): a library should raise,
        # not terminate the interpreter.
        raise RuntimeError(
            f"NaN in DiracMixture locations; "
            f"raw _x has NaN: {self._x.isnan().any().item()}, "
            f"half-ranges: {half_ranges}"
        )
    return x

Expand All @@ -74,7 +94,10 @@ def x(self):
return self.get_x()

def get_m(self):
    """Return the (non-negative) mixture weights.

    When ``mfix`` is True the weight is the fixed scalar 1.0; otherwise
    the learned parameter tensor is passed through ``abs()`` so weights
    stay non-negative regardless of the raw parameter's sign.
    """
    if isinstance(self._m, Tensor):
        return self._m.abs()
    return self._m

@property
def m(self):
Expand All @@ -86,12 +109,18 @@ def __repr__(self):
segs.append(f", k={self.k}")
if self.mfix:
segs.append(f", m=1.0")
if self.range is not None:
segs.append(f", range=[-{self.range}, {self.range}]")

if self.limits is not None:
limits = self.limits.data.cpu().numpy().tolist()
else:
limits = None
if limits is not None:
segs.append(f", limits={limits}")
segs.append(")")
return "".join(segs)

def _parameter_shape_hash(self):
    """Hash identifying this measure's parameter configuration.

    ``self.limits`` is either None or a torch.Tensor.  Hashing a Tensor
    directly is id-based, so two identically-configured measures would
    hash differently; convert it to a nested tuple of floats first so
    the hash depends on the values, not the object identity.
    """
    if self.limits is None:
        limits_key = None
    else:
        limits_key = tuple(tuple(row) for row in self.limits.tolist())
    hsh = Measure._parameter_shape_hash(self)
    hsh += hash(self.mfix)
    hsh += hash(limits_key)
    return hash(hsh)
20 changes: 20 additions & 0 deletions firelang/utils/limits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from __future__ import annotations
from typing import List, Tuple, Iterable

__all__ = [
"parse_rect_limits",
]


def parse_rect_limits(
    limits: float | Tuple[float, float] | List[Tuple[float, float]],
    dim: int,
) -> List[List[float]]:
    """Normalize a limits specification to one ``[lo, hi]`` pair per dim.

    Accepted forms:
      * scalar ``b``                -> ``[[-b, b]] * dim`` (symmetric box);
      * single pair ``(lo, hi)``    -> that pair repeated for every dim;
      * list of ``dim`` pairs       -> returned per-dimension.

    Every returned row is a freshly-built list, so mutating one row can
    never alias another (the single-pair branch previously repeated the
    *same* object ``dim`` times, which aliases when a list is passed).
    """
    if not isinstance(limits, Iterable):
        # Scalar bound: symmetric interval around the origin.
        return [[-limits, limits] for _ in range(dim)]
    if not isinstance(limits[0], Iterable):
        # One (lo, hi) pair shared by all dimensions; copy it per dim.
        assert len(limits) == 2
        return [list(limits) for _ in range(dim)]
    # Already one (lo, hi) pair per dimension; copy each row.
    assert len(limits) == dim
    return [list(pair) for pair in limits]

0 comments on commit 8e373de

Please sign in to comment.