forked from kduxin/firelang
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
+ MlplanarDivFast + function.operators + new way of saving / loading models
- Loading branch information
Showing
23 changed files
with
1,160 additions
and
460 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
from .base import * | ||
from .functional import * | ||
from .compose import * | ||
from .operators import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,209 +1,180 @@ | ||
from __future__ import division, annotations | ||
from typing import Tuple, Mapping, Any, Iterable, Callable, Union | ||
import operator as op | ||
import numpy as np | ||
from torch.nn import Module, ModuleList | ||
from typing import Mapping, Any, Iterable, Callable | ||
import copy | ||
from collections import OrderedDict | ||
from torch.nn import Module, ModuleList, ModuleDict | ||
import firelang | ||
from firelang.stack import StackingSlicing | ||
from firelang.utils.shape import parse_shape | ||
|
||
__all__ = ["Functional"] | ||
|
||
|
||
_vmap_log = set() | ||
|
||
|
||
class cached_forward: | ||
def __init__(self, callable: Callable, identifier: str): | ||
self.callable = callable | ||
self.identifier = identifier | ||
self.cache = OrderedDict() | ||
|
||
def __call__(self, x, *args, **kwargs): | ||
key = (id(x), *args, *sorted(kwargs.items(), key=lambda x: x[0])) | ||
if key in self.cache: | ||
print(f"Triggered cache at {self.identifier}") | ||
return self.cache[key] | ||
|
||
fx = self.callable(x, *args, **kwargs) | ||
while len(self.cache) >= 1: | ||
self.cache.popitem(last=False) | ||
self.cache[key] = fx | ||
return fx | ||
|
||
def __getstate__(self): | ||
return { | ||
"cache": OrderedDict(), | ||
**{key: val for key, val in self.__dict__.items() if key != "cache"}, | ||
} | ||
|
||
|
||
class Functional(StackingSlicing): | ||
def __init__( | ||
self, | ||
locals_: Mapping[str, Any], | ||
unsliceable_params: Iterable[str] = [], | ||
prev=[], | ||
operator=None, | ||
operator: Callable = None, | ||
is_fleaf: bool = True, | ||
): | ||
if "shape_out" not in locals_: | ||
locals_["shape_out"] = locals_["shape"] | ||
StackingSlicing.__init__( | ||
self, locals_=locals_, unsliceable_params=unsliceable_params | ||
) | ||
prevm = [p for p in prev if isinstance(p, Module)] | ||
self.prevm = ModuleList(prevm) if len(prevm) else prevm | ||
self.prev = prev | ||
self.operator = operator | ||
|
||
def __add__(self, func_or_scalar): | ||
if isinstance(func_or_scalar, StackingSlicing): | ||
assert self.shape == func_or_scalar.shape | ||
return Functional( | ||
locals_={"shape": self.shape}, | ||
prev=[self, func_or_scalar], | ||
operator=op.add, | ||
self._sanity_check() | ||
self.shape_out = locals_["shape_out"] | ||
# prevm = [p for p in prev if isinstance(p, Module)] | ||
# self.prevm = ModuleList(prevm) if len(prevm) else prevm | ||
# self.prev = prev | ||
self.operator = cached_forward( | ||
self.forward if operator is None else operator, | ||
identifier=f"{self.__class__.__name__} (id={id(self)})", | ||
) | ||
self.is_fleaf = is_fleaf | ||
|
||
def __sub__(self, func_or_scalar): | ||
if isinstance(func_or_scalar, StackingSlicing): | ||
assert self.shape == func_or_scalar.shape | ||
return Functional( | ||
locals_={"shape": self.shape}, | ||
prev=[self, func_or_scalar], | ||
operator=op.sub, | ||
def _sanity_check(self): | ||
assert "shape_out" in self.init_locals, ( | ||
"A `StackingSlicing` subclass must accept `shape_out` as an " | ||
"initialization argument." | ||
) | ||
|
||
def __mul__(self, other: Union[float, Functional, firelang.Measure]): | ||
def __add__(self, other: Functional | float): | ||
return firelang.function.Add(self, other) | ||
|
||
def __sub__(self, other: Functional | float): | ||
return firelang.function.Sub(self, other) | ||
|
||
def __mul__(self, other: Functional | float | firelang.Measure): | ||
""" | ||
Args: | ||
other (Union[float, Functional, Measure]): | ||
- if is `float` or `Functional`: generate a new Functional. | ||
- if is `Measure`, compute the paired integral. | ||
""" | ||
if isinstance(other, float) or isinstance(other, Functional): | ||
return Functional( | ||
locals_={"shape": self.shape}, | ||
prev=[self, other], | ||
operator=op.mul, | ||
) | ||
if isinstance(other, Functional) or isinstance(other, float): | ||
return firelang.function.Mul(self, other) | ||
elif isinstance(other, firelang.Measure): | ||
return other.integral(self) | ||
else: | ||
raise TypeError( | ||
f"`other` must be a float or Functional or Measure object, not {type(other)}." | ||
) | ||
|
||
def __truediv__(self, func_or_scalar): | ||
if isinstance(func_or_scalar, StackingSlicing): | ||
assert self.shape == func_or_scalar.shape | ||
return Functional( | ||
locals_={"shape": self.shape}, | ||
prev=[self, func_or_scalar], | ||
operator=op.truediv, | ||
) | ||
def __truediv__(self, other: Functional | float): | ||
return firelang.function.TrueDiv(self, other) | ||
|
||
def __pow__(self, pow: float): | ||
return Functional( | ||
locals_={"shape": self.shape}, | ||
prev=[self, pow], | ||
operator=op.pow, | ||
) | ||
return firelang.function.Pow(self, pow) | ||
|
||
def __neg__(self): | ||
return Functional(locals_={"shape": self.shape}, prev=[self], operator=op.neg) | ||
return firelang.function.Neg(self) | ||
|
||
def neg(self): | ||
return Functional(locals_={"shape": self.shape}, prev=[self], operator=op.neg) | ||
return self.__neg__() | ||
|
||
def __abs__(self): | ||
return Functional(locals_={"shape": self.shape}, prev=[self], operator=op.abs) | ||
return firelang.function.Abs(self) | ||
|
||
def abs(self): | ||
return Functional(locals_={"shape": self.shape}, prev=[self], operator=op.abs) | ||
return self.__abs__() | ||
|
||
def apply_op(self, mapping: Callable, *other_nodes): | ||
return Functional( | ||
locals_={"shape": self.shape}, | ||
prev=[self, *other_nodes], | ||
operator=mapping, | ||
) | ||
def __matmul__(self, other: firelang.Measure): | ||
return other.integral(self, cross=True) | ||
|
||
def forward(self, x, *args, **kwargs): | ||
def __call__(self, x, *args, **kwargs): | ||
operator = self.operator | ||
prev = self.prev | ||
if operator is None: | ||
return NotImplementedError | ||
elif operator is op.add: | ||
return prev[0].forward(x, *args, **kwargs) + prev[1].forward( | ||
x, *args, **kwargs | ||
) | ||
elif operator is op.sub: | ||
return prev[0].forward(x, *args, **kwargs) - prev[1].forward( | ||
x, *args, **kwargs | ||
) | ||
elif operator is op.mul: | ||
return prev[0].forward(x, *args, **kwargs) * prev[1].forward( | ||
x, *args, **kwargs | ||
) | ||
elif operator is op.truediv: | ||
return prev[0].forward(x, *args, **kwargs) / prev[1].forward( | ||
x, *args, **kwargs | ||
) | ||
elif operator is op.pow: | ||
return prev[0].forward(x, *args, **kwargs) ** prev[1] | ||
elif operator is op.neg: | ||
return -prev[0].forward(x, *args, **kwargs) | ||
elif operator is op.abs: | ||
return prev[0].forward(x, *args, **kwargs).abs() | ||
elif isinstance(operator, Callable): | ||
return operator(prev, x, *args, **kwargs) | ||
fx = operator(x, *args, **kwargs) | ||
else: | ||
raise ValueError(f"Unrecognized operator: {operator}") | ||
return fx | ||
|
||
def __call__(self, *args, **kwargs): | ||
return self.forward(*args, **kwargs) | ||
|
||
def is_leaf(self): | ||
return not hasattr(self, "prev") or not len(self.prev) | ||
|
||
def __getitem__(self, idx): | ||
if self.is_leaf(): | ||
newop = StackingSlicing.__getitem__(self, idx) | ||
else: | ||
sliced = [ | ||
node[idx] if isinstance(node, StackingSlicing) else node | ||
for node in self.prev | ||
] | ||
newop = sliced[0].apply_op(self.operator, *sliced[1:]) | ||
return newop | ||
|
||
def view(self, *shape, inplace: bool = False): | ||
shape = parse_shape(shape, num_elements=int(np.prod(self.shape))) | ||
def clear_cache(self): | ||
self.operator.clear_cache() | ||
for child in self.func_children(): | ||
child.clear_cache() | ||
|
||
def vmap(self, num_extra_dimensions: int = 1, inplace: bool = False) -> Functional: | ||
_vmap_log.clear() | ||
if inplace: | ||
if self.is_leaf(): | ||
StackingSlicing.view(self, shape, inplace=True) | ||
else: | ||
for node in self.prev: | ||
if isinstance(node, Functional): | ||
Functional.view(node, shape, inplace=True) | ||
return self | ||
self._vmap_(num_extra_dimensions) | ||
new = self | ||
else: | ||
if self.is_leaf(): | ||
newop = StackingSlicing.view(self, shape) | ||
else: | ||
prev = [ | ||
Functional.view(node, shape) | ||
if isinstance(node, Functional) | ||
else node | ||
for node in self.prev | ||
] | ||
newop = prev[0].apply_op(self.operator, *prev[1:]) | ||
return newop | ||
new = copy.deepcopy(self) | ||
for name, p_from in self.named_parameters(): | ||
p_to = new.get_parameter(name) | ||
p_to.requires_grad_(False) | ||
p_to.copy_(p_from) | ||
new._vmap_(num_extra_dimensions) | ||
_vmap_log.clear() | ||
return new | ||
|
||
def _vmap_(self, num_extra_dimensions: int = 1): | ||
if id(self) not in _vmap_log: | ||
_vmap_log.add(id(self)) | ||
self.shape = (*[1] * num_extra_dimensions, *self.shape) | ||
self.shape_out = (*[1] * num_extra_dimensions, *self.shape_out) | ||
for func in self.func_children(): | ||
func._vmap_(num_extra_dimensions) | ||
|
||
def func_children(self): | ||
return (child for name, child in self.named_func_children()) | ||
|
||
def named_func_children(self): | ||
for name, child in self.named_children(): | ||
if isinstance(child, Functional): | ||
yield name, child | ||
elif isinstance(child, ModuleList): | ||
for i, c in enumerate(child): | ||
yield f"{name}.{i}", c | ||
elif isinstance(child, ModuleDict): | ||
for n, c in child.items(): | ||
yield f"{name}.{n}", c | ||
return StopIteration | ||
|
||
def __repr__(self): | ||
if self.is_leaf(): | ||
return Module.__repr__(self) + f", shape={self.shape}" | ||
if self.is_fleaf: | ||
reprstr = Module.__repr__(self) | ||
else: | ||
segs = [f"{self.__class__.__name__}("] | ||
for i, node in enumerate(self.prev): | ||
for j, line in enumerate(repr(node).split("\n")): | ||
segs = [f"{self.__class__.__name__} {self.shape}->{self.shape_out} ("] | ||
for name, child in self.named_children(): | ||
s = repr(child) | ||
for j, line in enumerate(s.split("\n")): | ||
if j == 0: | ||
segs.append(" " + f"prev[{i}]: " + line) | ||
segs.append(" " + f"{name}: " + line) | ||
else: | ||
segs.append(" " + line) | ||
op_name = ( | ||
self.operator.__name__ | ||
if hasattr(self.operator, "__name__") | ||
else self.operator.__class__.__name__ | ||
) | ||
segs.append(f"), operator={op_name}") | ||
return "\n".join(segs) | ||
|
||
def restack(self, shape: Tuple[int] = None): | ||
if self.is_leaf(): | ||
newop = StackingSlicing.restack(self, shape) | ||
else: | ||
stacked = [ | ||
node.restack(shape) if hasattr(node, "restack") else node | ||
for node in self.prev | ||
] | ||
newop = stacked[0].apply_op(self.operator, *stacked[1:]) | ||
return newop | ||
|
||
stack = restack | ||
|
||
def __matmul__(self, measure: firelang.Measure): | ||
assert isinstance(measure, firelang.Measure) | ||
return measure.integral(self, cross=True) | ||
segs.append(")") | ||
reprstr = "\n".join(segs) | ||
return reprstr |
Oops, something went wrong.