Commit

v1.1
+ MlplanarDivFast
+ function.operators
+ new way of saving / loading models
kduxin committed Dec 7, 2022
1 parent 25cb72b commit c6c67c5
Showing 23 changed files with 1,160 additions and 460 deletions.
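
The `function.operators` module added by this commit supplies the Add / Sub / Mul / TrueDiv / Pow / Neg / Abs node classes that `base.py` now dispatches to (the module itself is not shown in this excerpt). A minimal sketch of the composition style it enables; `SomeLeafFunctional` is a stand-in name, not a real firelang class, and its constructor signature is assumed:

    import torch

    # Any concrete Functional subclass works here; the leaf class and its
    # `shape` argument are for illustration only.
    f = SomeLeafFunctional(shape=(10,))
    g = SomeLeafFunctional(shape=(10,))

    h = (f + g).abs() ** 2.0   # builds Add -> Abs -> Pow operator nodes
    y = h(torch.randn(10, 3))  # __call__ routes through the cached operator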
2 changes: 1 addition & 1 deletion firelang/function/__init__.py
@@ -1,3 +1,3 @@
from .base import *
from .functional import *
-from .compose import *
+from .operators import *
267 changes: 119 additions & 148 deletions firelang/function/base.py
@@ -1,209 +1,180 @@
from __future__ import division, annotations
-from typing import Tuple, Mapping, Any, Iterable, Callable, Union
-import operator as op
-import numpy as np
-from torch.nn import Module, ModuleList
+from typing import Mapping, Any, Iterable, Callable
+import copy
+from collections import OrderedDict
+from torch.nn import Module, ModuleList, ModuleDict
import firelang
from firelang.stack import StackingSlicing
-from firelang.utils.shape import parse_shape

__all__ = ["Functional"]


+_vmap_log = set()


+class cached_forward:
+    def __init__(self, callable: Callable, identifier: str):
+        self.callable = callable
+        self.identifier = identifier
+        self.cache = OrderedDict()
+
+    def __call__(self, x, *args, **kwargs):
+        key = (id(x), *args, *sorted(kwargs.items(), key=lambda kv: kv[0]))
+        if key in self.cache:
+            print(f"Triggered cache at {self.identifier}")
+            return self.cache[key]
+
+        fx = self.callable(x, *args, **kwargs)
+        # Keep at most one cached result: evict the oldest entry first.
+        while len(self.cache) >= 1:
+            self.cache.popitem(last=False)
+        self.cache[key] = fx
+        return fx
+
+    def clear_cache(self):
+        # Needed by Functional.clear_cache(), which calls this on every operator.
+        self.cache.clear()
+
+    def __getstate__(self):
+        # Exclude cached tensors when pickling, so saved models stay lean.
+        return {
+            "cache": OrderedDict(),
+            **{key: val for key, val in self.__dict__.items() if key != "cache"},
+        }
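
A quick sketch of the caching behavior of `cached_forward` above: the cache keeps a single entry keyed on the identity of `x` (plus the extra arguments), so repeating a call with the same tensor object hits the cache, while any new tensor evicts the old entry:

    import torch

    square = cached_forward(lambda x: x ** 2, identifier="square")

    x = torch.randn(5)
    y1 = square(x)      # computed and cached
    y2 = square(x)      # same id(x): prints "Triggered cache at square"
    y3 = square(x + 0)  # new tensor, new id: recomputed, old entry evicted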


class Functional(StackingSlicing):
    def __init__(
        self,
        locals_: Mapping[str, Any],
        unsliceable_params: Iterable[str] = [],
        prev=[],
-        operator=None,
+        operator: Callable = None,
+        is_fleaf: bool = True,
    ):
+        if "shape_out" not in locals_:
+            locals_["shape_out"] = locals_["shape"]
        StackingSlicing.__init__(
            self, locals_=locals_, unsliceable_params=unsliceable_params
        )
-        prevm = [p for p in prev if isinstance(p, Module)]
-        self.prevm = ModuleList(prevm) if len(prevm) else prevm
-        self.prev = prev
-        self.operator = operator
-
-    def __add__(self, func_or_scalar):
-        if isinstance(func_or_scalar, StackingSlicing):
-            assert self.shape == func_or_scalar.shape
-        return Functional(
-            locals_={"shape": self.shape},
-            prev=[self, func_or_scalar],
-            operator=op.add,
+        self._sanity_check()
+        self.shape_out = locals_["shape_out"]
+        # prevm = [p for p in prev if isinstance(p, Module)]
+        # self.prevm = ModuleList(prevm) if len(prevm) else prevm
+        # self.prev = prev
+        self.operator = cached_forward(
+            self.forward if operator is None else operator,
+            identifier=f"{self.__class__.__name__} (id={id(self)})",
        )
+        self.is_fleaf = is_fleaf

-    def __sub__(self, func_or_scalar):
-        if isinstance(func_or_scalar, StackingSlicing):
-            assert self.shape == func_or_scalar.shape
-        return Functional(
-            locals_={"shape": self.shape},
-            prev=[self, func_or_scalar],
-            operator=op.sub,
+    def _sanity_check(self):
+        assert "shape_out" in self.init_locals, (
+            "A `StackingSlicing` subclass must accept `shape_out` as an "
+            "initialization argument."
        )

-    def __mul__(self, other: Union[float, Functional, firelang.Measure]):
+    def __add__(self, other: Functional | float):
+        return firelang.function.Add(self, other)
+
+    def __sub__(self, other: Functional | float):
+        return firelang.function.Sub(self, other)
+
+    def __mul__(self, other: Functional | float | firelang.Measure):
"""
Args:
other (Union[float, Functional, Measure]):
- if is `float` or `Functional`: generate a new Functional.
- if is `Measure`, compute the paired integral.
"""
-        if isinstance(other, float) or isinstance(other, Functional):
-            return Functional(
-                locals_={"shape": self.shape},
-                prev=[self, other],
-                operator=op.mul,
-            )
+        if isinstance(other, Functional) or isinstance(other, float):
+            return firelang.function.Mul(self, other)
        elif isinstance(other, firelang.Measure):
            return other.integral(self)
        else:
            raise TypeError(
                f"`other` must be a float or Functional or Measure object, not {type(other)}."
            )
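
Taken together with `__matmul__` below, the overloads split the work like this (a sketch only; `f` and `g` are Functional instances of matching stack shape and `m` is a firelang.Measure, all assumed to exist):

    h = f * g    # Functional * Functional -> firelang.function.Mul node
    s = f * 2.0  # Functional * float      -> Mul node with a scalar operand
    v = f * m    # Functional * Measure    -> m.integral(f), the paired integral
    c = f @ m    # __matmul__              -> m.integral(f, cross=True),
                 # presumably integrals over all cross pairs of the two stacks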

-    def __truediv__(self, func_or_scalar):
-        if isinstance(func_or_scalar, StackingSlicing):
-            assert self.shape == func_or_scalar.shape
-        return Functional(
-            locals_={"shape": self.shape},
-            prev=[self, func_or_scalar],
-            operator=op.truediv,
-        )
+    def __truediv__(self, other: Functional | float):
+        return firelang.function.TrueDiv(self, other)

    def __pow__(self, pow: float):
-        return Functional(
-            locals_={"shape": self.shape},
-            prev=[self, pow],
-            operator=op.pow,
-        )
+        return firelang.function.Pow(self, pow)

    def __neg__(self):
-        return Functional(locals_={"shape": self.shape}, prev=[self], operator=op.neg)
+        return firelang.function.Neg(self)

    def neg(self):
-        return Functional(locals_={"shape": self.shape}, prev=[self], operator=op.neg)
+        return self.__neg__()

    def __abs__(self):
-        return Functional(locals_={"shape": self.shape}, prev=[self], operator=op.abs)
+        return firelang.function.Abs(self)

    def abs(self):
-        return Functional(locals_={"shape": self.shape}, prev=[self], operator=op.abs)
+        return self.__abs__()

-    def apply_op(self, mapping: Callable, *other_nodes):
-        return Functional(
-            locals_={"shape": self.shape},
-            prev=[self, *other_nodes],
-            operator=mapping,
-        )
+    def __matmul__(self, other: firelang.Measure):
+        return other.integral(self, cross=True)

-    def forward(self, x, *args, **kwargs):
+    def __call__(self, x, *args, **kwargs):
        operator = self.operator
-        prev = self.prev
-        if operator is None:
-            return NotImplementedError
-        elif operator is op.add:
-            return prev[0].forward(x, *args, **kwargs) + prev[1].forward(
-                x, *args, **kwargs
-            )
-        elif operator is op.sub:
-            return prev[0].forward(x, *args, **kwargs) - prev[1].forward(
-                x, *args, **kwargs
-            )
-        elif operator is op.mul:
-            return prev[0].forward(x, *args, **kwargs) * prev[1].forward(
-                x, *args, **kwargs
-            )
-        elif operator is op.truediv:
-            return prev[0].forward(x, *args, **kwargs) / prev[1].forward(
-                x, *args, **kwargs
-            )
-        elif operator is op.pow:
-            return prev[0].forward(x, *args, **kwargs) ** prev[1]
-        elif operator is op.neg:
-            return -prev[0].forward(x, *args, **kwargs)
-        elif operator is op.abs:
-            return prev[0].forward(x, *args, **kwargs).abs()
-        elif isinstance(operator, Callable):
-            return operator(prev, x, *args, **kwargs)
+        fx = operator(x, *args, **kwargs)
-        else:
-            raise ValueError(f"Unrecognized operator: {operator}")
+        return fx

-    def __call__(self, *args, **kwargs):
-        return self.forward(*args, **kwargs)
-
-    def is_leaf(self):
-        return not hasattr(self, "prev") or not len(self.prev)
-
-    def __getitem__(self, idx):
-        if self.is_leaf():
-            newop = StackingSlicing.__getitem__(self, idx)
-        else:
-            sliced = [
-                node[idx] if isinstance(node, StackingSlicing) else node
-                for node in self.prev
-            ]
-            newop = sliced[0].apply_op(self.operator, *sliced[1:])
-        return newop
-
-    def view(self, *shape, inplace: bool = False):
-        shape = parse_shape(shape, num_elements=int(np.prod(self.shape)))
+    def clear_cache(self):
+        self.operator.clear_cache()
+        for child in self.func_children():
+            child.clear_cache()

+    def vmap(self, num_extra_dimensions: int = 1, inplace: bool = False) -> Functional:
+        _vmap_log.clear()
        if inplace:
-            if self.is_leaf():
-                StackingSlicing.view(self, shape, inplace=True)
-            else:
-                for node in self.prev:
-                    if isinstance(node, Functional):
-                        Functional.view(node, shape, inplace=True)
-            return self
+            self._vmap_(num_extra_dimensions)
+            new = self
        else:
-            if self.is_leaf():
-                newop = StackingSlicing.view(self, shape)
-            else:
-                prev = [
-                    Functional.view(node, shape)
-                    if isinstance(node, Functional)
-                    else node
-                    for node in self.prev
-                ]
-                newop = prev[0].apply_op(self.operator, *prev[1:])
-        return newop
+            new = copy.deepcopy(self)
+            for name, p_from in self.named_parameters():
+                p_to = new.get_parameter(name)
+                p_to.requires_grad_(False)
+                p_to.copy_(p_from)
+            new._vmap_(num_extra_dimensions)
+        _vmap_log.clear()
+        return new

+    def _vmap_(self, num_extra_dimensions: int = 1):
+        if id(self) not in _vmap_log:
+            _vmap_log.add(id(self))
+            self.shape = (*[1] * num_extra_dimensions, *self.shape)
+            self.shape_out = (*[1] * num_extra_dimensions, *self.shape_out)
+            for func in self.func_children():
+                func._vmap_(num_extra_dimensions)

+    def func_children(self):
+        return (child for name, child in self.named_func_children())
+
+    def named_func_children(self):
+        for name, child in self.named_children():
+            if isinstance(child, Functional):
+                yield name, child
+            elif isinstance(child, ModuleList):
+                for i, c in enumerate(child):
+                    yield f"{name}.{i}", c
+            elif isinstance(child, ModuleDict):
+                for n, c in child.items():
+                    yield f"{name}.{n}", c
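
`vmap` prepends singleton dimensions to `shape` and `shape_out` of every Functional in the graph exactly once; the module-level `_vmap_log` set keeps shared nodes from being resized twice, and `func_children` / `named_func_children` drive the traversal. A sketch with a hypothetical leaf class:

    f = SomeLeafFunctional(shape=(10, 4))  # stand-in leaf, stacked as a (10, 4) grid
    g = f.vmap()                           # out of place: deepcopy, shapes -> (1, 10, 4)
    assert g.shape == (1, 10, 4) and g.shape_out == (1, 10, 4)

    f.vmap(num_extra_dimensions=2, inplace=True)  # in place: shapes -> (1, 1, 10, 4)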

    def __repr__(self):
-        if self.is_leaf():
-            return Module.__repr__(self) + f", shape={self.shape}"
+        if self.is_fleaf:
+            reprstr = Module.__repr__(self)
        else:
-            segs = [f"{self.__class__.__name__}("]
-            for i, node in enumerate(self.prev):
-                for j, line in enumerate(repr(node).split("\n")):
+            segs = [f"{self.__class__.__name__} {self.shape}->{self.shape_out} ("]
+            for name, child in self.named_children():
+                s = repr(child)
+                for j, line in enumerate(s.split("\n")):
                    if j == 0:
-                        segs.append(" " + f"prev[{i}]: " + line)
+                        segs.append(" " + f"{name}: " + line)
                    else:
                        segs.append(" " + line)
-            op_name = (
-                self.operator.__name__
-                if hasattr(self.operator, "__name__")
-                else self.operator.__class__.__name__
-            )
-            segs.append(f"), operator={op_name}")
-        return "\n".join(segs)
-
-    def restack(self, shape: Tuple[int] = None):
-        if self.is_leaf():
-            newop = StackingSlicing.restack(self, shape)
-        else:
-            stacked = [
-                node.restack(shape) if hasattr(node, "restack") else node
-                for node in self.prev
-            ]
-            newop = stacked[0].apply_op(self.operator, *stacked[1:])
-        return newop
-
-    stack = restack
-
-    def __matmul__(self, measure: firelang.Measure):
-        assert isinstance(measure, firelang.Measure)
-        return measure.integral(self, cross=True)
+            segs.append(")")
+            reprstr = "\n".join(segs)
+        return reprstr
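
With this `__repr__`, a leaf (`is_fleaf=True`) prints as a plain torch Module, while a composed node prints its class, its `shape -> shape_out` signature, and one indented line per child from `named_children()`. Illustratively (child attribute names and the leaf class are made up; the real output depends on how `function.operators` registers its children):

    Mul (10,)->(10,) (
     func1: SomeLeafFunctional(...)
     func2: SomeLeafFunctional(...)
    )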