diff --git a/firelang/function/__init__.py b/firelang/function/__init__.py index b10b3ca..d21342b 100644 --- a/firelang/function/__init__.py +++ b/firelang/function/__init__.py @@ -1,3 +1,3 @@ from .base import * from .functional import * -from .compose import * \ No newline at end of file +from .operators import * \ No newline at end of file diff --git a/firelang/function/base.py b/firelang/function/base.py index 9128738..ea4a150 100644 --- a/firelang/function/base.py +++ b/firelang/function/base.py @@ -1,62 +1,87 @@ from __future__ import division, annotations -from typing import Tuple, Mapping, Any, Iterable, Callable, Union -import operator as op -import numpy as np -from torch.nn import Module, ModuleList +from typing import Mapping, Any, Iterable, Callable +import copy +from collections import OrderedDict +from torch.nn import Module, ModuleList, ModuleDict import firelang from firelang.stack import StackingSlicing -from firelang.utils.shape import parse_shape __all__ = ["Functional"] +_vmap_log = set() + + +class cached_forward: + def __init__(self, callable: Callable, identifier: str): + self.callable = callable + self.identifier = identifier + self.cache = OrderedDict() + + def __call__(self, x, *args, **kwargs): + key = (id(x), *args, *sorted(kwargs.items(), key=lambda x: x[0])) + if key in self.cache: + print(f"Triggered cache at {self.identifier}") + return self.cache[key] + + fx = self.callable(x, *args, **kwargs) + while len(self.cache) >= 1: + self.cache.popitem(last=False) + self.cache[key] = fx + return fx + + def __getstate__(self): + return { + "cache": OrderedDict(), + **{key: val for key, val in self.__dict__.items() if key != "cache"}, + } + + class Functional(StackingSlicing): def __init__( self, locals_: Mapping[str, Any], unsliceable_params: Iterable[str] = [], - prev=[], - operator=None, + operator: Callable = None, + is_fleaf: bool = True, ): + if "shape_out" not in locals_: + locals_["shape_out"] = locals_["shape"] StackingSlicing.__init__( self, locals_=locals_, unsliceable_params=unsliceable_params ) - prevm = [p for p in prev if isinstance(p, Module)] - self.prevm = ModuleList(prevm) if len(prevm) else prevm - self.prev = prev - self.operator = operator - - def __add__(self, func_or_scalar): - if isinstance(func_or_scalar, StackingSlicing): - assert self.shape == func_or_scalar.shape - return Functional( - locals_={"shape": self.shape}, - prev=[self, func_or_scalar], - operator=op.add, + self._sanity_check() + self.shape_out = locals_["shape_out"] + # prevm = [p for p in prev if isinstance(p, Module)] + # self.prevm = ModuleList(prevm) if len(prevm) else prevm + # self.prev = prev + self.operator = cached_forward( + self.forward if operator is None else operator, + identifier=f"{self.__class__.__name__} (id={id(self)})", ) + self.is_fleaf = is_fleaf - def __sub__(self, func_or_scalar): - if isinstance(func_or_scalar, StackingSlicing): - assert self.shape == func_or_scalar.shape - return Functional( - locals_={"shape": self.shape}, - prev=[self, func_or_scalar], - operator=op.sub, + def _sanity_check(self): + assert "shape_out" in self.init_locals, ( + "A `StackingSlicing` subclass must accept `shape_out` as an " + "initialization argument." 
) - def __mul__(self, other: Union[float, Functional, firelang.Measure]): + def __add__(self, other: Functional | float): + return firelang.function.Add(self, other) + + def __sub__(self, other: Functional | float): + return firelang.function.Sub(self, other) + + def __mul__(self, other: Functional | float | firelang.Measure): """ Args: other (Union[float, Functional, Measure]): - if is `float` or `Functional`: generate a new Functional. - if is `Measure`, compute the paired integral. """ - if isinstance(other, float) or isinstance(other, Functional): - return Functional( - locals_={"shape": self.shape}, - prev=[self, other], - operator=op.mul, - ) + if isinstance(other, Functional) or isinstance(other, float): + return firelang.function.Mul(self, other) elif isinstance(other, firelang.Measure): return other.integral(self) else: @@ -64,146 +89,92 @@ def __mul__(self, other: Union[float, Functional, firelang.Measure]): f"`other` must be a float or Functional or Measure object, not {type(other)}." ) - def __truediv__(self, func_or_scalar): - if isinstance(func_or_scalar, StackingSlicing): - assert self.shape == func_or_scalar.shape - return Functional( - locals_={"shape": self.shape}, - prev=[self, func_or_scalar], - operator=op.truediv, - ) + def __truediv__(self, other: Functional | float): + return firelang.function.TrueDiv(self, other) def __pow__(self, pow: float): - return Functional( - locals_={"shape": self.shape}, - prev=[self, pow], - operator=op.pow, - ) + return firelang.function.Pow(self, pow) def __neg__(self): - return Functional(locals_={"shape": self.shape}, prev=[self], operator=op.neg) + return firelang.function.Neg(self) def neg(self): - return Functional(locals_={"shape": self.shape}, prev=[self], operator=op.neg) + return self.__neg__() def __abs__(self): - return Functional(locals_={"shape": self.shape}, prev=[self], operator=op.abs) + return firelang.function.Abs(self) def abs(self): - return Functional(locals_={"shape": self.shape}, prev=[self], operator=op.abs) + return self.__abs__() - def apply_op(self, mapping: Callable, *other_nodes): - return Functional( - locals_={"shape": self.shape}, - prev=[self, *other_nodes], - operator=mapping, - ) + def __matmul__(self, other: firelang.Measure): + return other.integral(self, cross=True) - def forward(self, x, *args, **kwargs): + def __call__(self, x, *args, **kwargs): operator = self.operator - prev = self.prev if operator is None: return NotImplementedError - elif operator is op.add: - return prev[0].forward(x, *args, **kwargs) + prev[1].forward( - x, *args, **kwargs - ) - elif operator is op.sub: - return prev[0].forward(x, *args, **kwargs) - prev[1].forward( - x, *args, **kwargs - ) - elif operator is op.mul: - return prev[0].forward(x, *args, **kwargs) * prev[1].forward( - x, *args, **kwargs - ) - elif operator is op.truediv: - return prev[0].forward(x, *args, **kwargs) / prev[1].forward( - x, *args, **kwargs - ) - elif operator is op.pow: - return prev[0].forward(x, *args, **kwargs) ** prev[1] - elif operator is op.neg: - return -prev[0].forward(x, *args, **kwargs) - elif operator is op.abs: - return prev[0].forward(x, *args, **kwargs).abs() elif isinstance(operator, Callable): - return operator(prev, x, *args, **kwargs) + fx = operator(x, *args, **kwargs) else: raise ValueError(f"Unrecognized operator: {operator}") + return fx - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - - def is_leaf(self): - return not hasattr(self, "prev") or not len(self.prev) - - def __getitem__(self, idx): 
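# ---------------------------------------------------------------------------
# Editor's note on the rewritten arithmetic above: each dunder now returns a
# dedicated node from firelang.function.operators (Add, Sub, Mul, TrueDiv, Pow,
# Neg, Abs) instead of a generic Functional carrying `prev`/`operator`, and
# `cached_forward` memoizes only the most recent call (the cache is trimmed to
# a single entry keyed by `id(x)`).  The `clear_cache` method added further
# down calls `self.operator.clear_cache()`, which `cached_forward` as shown
# does not define; presumably it just resets `self.cache = OrderedDict()`.
#
# Minimal usage sketch (hedged: constructor arguments follow
# scripts/wacky/4_train.sh, and the result shapes are what the new operator and
# integral code imply, not tested output):
from firelang.function import MLPlanarDivFast
from firelang.measure import DiracMixture

f1 = MLPlanarDivFast(2, 4, shape=(3,))   # dim=2, 4 planar layers, stack of 3 functionals
f2 = MLPlanarDivFast(2, 4, shape=(3,))
g = (f1 + f2).abs() * 0.5                # Mul(Abs(Add(f1, f2)), 0.5)

mu = DiracMixture(dim=2, k=10, shape=(3,))
paired = g * mu                          # == mu.integral(g)             -> shape (3,)
cross = g @ mu                           # == mu.integral(g, cross=True) -> shape (3, 3)
# ---------------------------------------------------------------------------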
- if self.is_leaf(): - newop = StackingSlicing.__getitem__(self, idx) - else: - sliced = [ - node[idx] if isinstance(node, StackingSlicing) else node - for node in self.prev - ] - newop = sliced[0].apply_op(self.operator, *sliced[1:]) - return newop - - def view(self, *shape, inplace: bool = False): - shape = parse_shape(shape, num_elements=int(np.prod(self.shape))) + def clear_cache(self): + self.operator.clear_cache() + for child in self.func_children(): + child.clear_cache() + def vmap(self, num_extra_dimensions: int = 1, inplace: bool = False) -> Functional: + _vmap_log.clear() if inplace: - if self.is_leaf(): - StackingSlicing.view(self, shape, inplace=True) - else: - for node in self.prev: - if isinstance(node, Functional): - Functional.view(node, shape, inplace=True) - return self + self._vmap_(num_extra_dimensions) + new = self else: - if self.is_leaf(): - newop = StackingSlicing.view(self, shape) - else: - prev = [ - Functional.view(node, shape) - if isinstance(node, Functional) - else node - for node in self.prev - ] - newop = prev[0].apply_op(self.operator, *prev[1:]) - return newop + new = copy.deepcopy(self) + for name, p_from in self.named_parameters(): + p_to = new.get_parameter(name) + p_to.requires_grad_(False) + p_to.copy_(p_from) + new._vmap_(num_extra_dimensions) + _vmap_log.clear() + return new + + def _vmap_(self, num_extra_dimensions: int = 1): + if id(self) not in _vmap_log: + _vmap_log.add(id(self)) + self.shape = (*[1] * num_extra_dimensions, *self.shape) + self.shape_out = (*[1] * num_extra_dimensions, *self.shape_out) + for func in self.func_children(): + func._vmap_(num_extra_dimensions) + + def func_children(self): + return (child for name, child in self.named_func_children()) + + def named_func_children(self): + for name, child in self.named_children(): + if isinstance(child, Functional): + yield name, child + elif isinstance(child, ModuleList): + for i, c in enumerate(child): + yield f"{name}.{i}", c + elif isinstance(child, ModuleDict): + for n, c in child.items(): + yield f"{name}.{n}", c + return StopIteration def __repr__(self): - if self.is_leaf(): - return Module.__repr__(self) + f", shape={self.shape}" + if self.is_fleaf: + reprstr = Module.__repr__(self) else: - segs = [f"{self.__class__.__name__}("] - for i, node in enumerate(self.prev): - for j, line in enumerate(repr(node).split("\n")): + segs = [f"{self.__class__.__name__} {self.shape}->{self.shape_out} ("] + for name, child in self.named_children(): + s = repr(child) + for j, line in enumerate(s.split("\n")): if j == 0: - segs.append(" " + f"prev[{i}]: " + line) + segs.append(" " + f"{name}: " + line) else: segs.append(" " + line) - op_name = ( - self.operator.__name__ - if hasattr(self.operator, "__name__") - else self.operator.__class__.__name__ - ) - segs.append(f"), operator={op_name}") - return "\n".join(segs) - - def restack(self, shape: Tuple[int] = None): - if self.is_leaf(): - newop = StackingSlicing.restack(self, shape) - else: - stacked = [ - node.restack(shape) if hasattr(node, "restack") else node - for node in self.prev - ] - newop = stacked[0].apply_op(self.operator, *stacked[1:]) - return newop - - stack = restack - - def __matmul__(self, measure: firelang.Measure): - assert isinstance(measure, firelang.Measure) - return measure.integral(self, cross=True) + segs.append(")") + reprstr = "\n".join(segs) + return reprstr diff --git a/firelang/function/components/multilayer.py b/firelang/function/components/multilayer.py index da56c13..38dd676 100644 --- 
a/firelang/function/components/multilayer.py +++ b/firelang/function/components/multilayer.py @@ -17,6 +17,7 @@ "_MLPlanarInit", "_Forward", "_Divergence", + "_DivergenceViaQuadform", "_Jacdet", "_Jaclogdet", ] @@ -36,7 +37,11 @@ def __init__( norm: Literal[None, "batch", "layer"] = None, shape: int | Tuple[int] = (1,), ): - Functional.__init__(self, locals()) + Functional.__init__( + self, + locals_={**locals(), "shape_out": shape}, + is_fleaf=True, + ) dims = [input_dim] + hidden_dims layer_kwargs = {"activation": activation, "norm": norm} last_layer_kwargs = {"activation": None, "norm": None} @@ -55,14 +60,17 @@ def __init__( self.input_dim = input_dim self.hidden_dims = hidden_dims self.dims = dims - self.shape = shape self.activation = activation self.norm = norm class _MLPlanarInit(Functional): def __init__(self, dim, nlayers, shape=(1,), **planar_kwargs): - Functional.__init__(self, locals()) + Functional.__init__( + self, + locals_={**locals(), "shape_out": shape}, + is_fleaf=True, + ) self.layers = ModuleList( [ PseudoPlanarTransform(dim=dim, shape=shape, **planar_kwargs) @@ -143,6 +151,43 @@ def forward(self, x: Tensor) -> Tensor: return div +class _DivergenceViaQuadform(_MultiLayerBase): + """ + Compute divergence indirectly via quadratic forms: + tr(A) = u1^T @ A @ u1 + ... + uD^T @ A @ uD, + where ui is the i-th column of an D-by-D identity matrix. + This way of computation is especially fast for backward pass. + Each layer should implement .jacob_mul_vecs method. + """ + + def forward(self, x: Tensor) -> Tensor: + dim = x.shape[-1] + + """ Estimate tr(A) by E[z^T @ A @ z], where z is selected as the columns of an identity matrix""" + """ Initialize z's as the columns of an identity matrix `vecs` """ + vecs = torch.eye(dim, dtype=x.dtype, device=x.device) # (nvecs, dim) + # vecs = vecs[None, None] # (1, 1, nvecs, dim) + + jacob_mul_vecs = vecs + for i, layer in enumerate(self.layers): + """Compute A @ z, and then A @ (A @ z), ...""" + if i < len(self.layers) - 1: + jacob_mul_vecs, x = layer.jacob_mul_vecs( + x, jacob_mul_vecs, return_fx=True + ) + else: # last iter + jacob_mul_vecs = layer.jacob_mul_vecs( + x, jacob_mul_vecs, return_fx=False + ) + + """ Compute z^T @ (A @ z) """ + dots = torch.einsum("...i,...i->...", vecs, jacob_mul_vecs) + + # compute trace by summing up the dots of all vecs + div = dots.sum(dim=-1) + + dim = x.shape[-1] + div = div / dim return div @@ -169,4 +214,4 @@ class _Jaclogdet(_MultiLayerBase): def forward(self, x: torch.Tensor, eps=LOGABS_EPS): cumjacdet = _Jacdet.forward(self, x) cumjaclogdet = (eps + cumjacdet.abs()).log() - return cumjaclogdet \ No newline at end of file + return cumjaclogdet diff --git a/firelang/function/components/planar.py b/firelang/function/components/planar.py index 7312d36..6e6b7fd 100644 --- a/firelang/function/components/planar.py +++ b/firelang/function/components/planar.py @@ -18,10 +18,9 @@ def __init__(self, dim, activation="tanh", shape: int | Tuple[int] = (1,)): Functional.__init__(self, locals()) scale = 0.1 / dim**0.5 - size = int(np.prod(shape)) - self.v = Parameter(torch.randn(size, dim).normal_(0, scale)) - self.b = Parameter(torch.randn(size, 1).normal_(0, scale)) - self.u = Parameter(torch.randn(size, dim).normal_(0, scale)) + self.v = Parameter(torch.empty(*shape, dim).normal_(0, scale)) + self.b = Parameter(torch.empty(*shape).normal_(0, scale)) + self.u = Parameter(torch.empty(*shape, dim).normal_(0, scale)) if activation is None: self.act, self.actderiv = identity, identity_deriv @@ -52,9 +51,7 @@ def 
forward(self, x: Tensor) -> Tensor: (*xshape, dim) = x.shape check_shape_consistency(fshape, xshape) - v = self.v.view(*fshape, self.dim) - b = self.b.view(*fshape) - u = self.u.view(*fshape, self.dim) + u, v, b = self.u, self.v, self.b a = torch.einsum("...i,...i->...", x, v) + b # (...shape,) fx = x + self.act(a)[..., None] * u # (...shape, dim) @@ -75,14 +72,12 @@ def jacob(self, x: Tensor, return_fx: bool = False) -> Tensor: (*xshape, dim) = x.shape check_shape_consistency(fshape, xshape) - v = self.v.view(*fshape, self.dim) - b = self.b.view(*fshape) - u = self.u.view(*fshape, self.dim) - I = torch.eye(dim, device=x.device, dtype=x.dtype).reshape( *[1 for _ in xshape], dim, dim ) + u, v, b = self.u, self.v, self.b + a = torch.einsum("...i,...i->...", x, v) + b ad = self.actderiv(a) jacob = ( @@ -109,9 +104,7 @@ def jacdet(self, x: Tensor, return_fx: bool = False) -> Tensor: (*xshape, dim) = x.shape check_shape_consistency(fshape, xshape) - v = self.v.view(*fshape, self.dim) - b = self.b.view(*fshape) - u = self.u.view(*fshape, self.dim) + u, v, b = self.u, self.v, self.b u_dot_v = torch.einsum("...i,...i->...", u, v) # (...fshape,) a = torch.einsum("...i,...i->...", x, v) + b # (...fshape,) @@ -123,3 +116,38 @@ def jacdet(self, x: Tensor, return_fx: bool = False) -> Tensor: return jacdet, fx else: return jacdet + + def jacob_mul_vecs( + self, x: Tensor, vecs: Tensor, return_fx: bool = False + ) -> Tensor: + """Compute (df(x) / dx) @ vec + + Args: + x (Tensor): (...shape, dim) + vecs (Tensor): (...shape, nvecs, dim) + return_fx (bool, optional): whether returns f(x) or not. Defaults to False. + + Returns: + Tensor: (...shape, nvecs, dim) + """ + fshape = self.shape + (*xshape, dim) = x.shape + check_shape_consistency(fshape, xshape) + + u, v, b = self.u, self.v, self.b + + a = torch.einsum("...i,...i->...", x, v) + b # (...fshape,) + ad = self.actderiv(a) # (...fshape,) + + v_dot_vecs = torch.einsum("...i,...ni->...n", v, vecs) # (...fshape, nvecs) + scalar = ad[..., None] * v_dot_vecs # (...fshape, nvecs) + + jacob_mul_vecs = ( + vecs + scalar[..., None] * u[..., None, :] + ) # (...fshape, nvecs, dim) + + if return_fx: + fx = x + self.act(a)[..., None] * u # (...shape, dim) + return jacob_mul_vecs, fx + else: + return jacob_mul_vecs diff --git a/firelang/function/compose.py b/firelang/function/compose.py deleted file mode 100644 index d09db4f..0000000 --- a/firelang/function/compose.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import Iterable -from .base import Functional - -__all__ = ["Sequential", "sequential"] - - -class Sequential(Functional): - def __init__(self, funcs, shape=(1,)): - Functional.__init__( - self, - locals(), - prev=[func.restack(shape) for func in funcs], - operator=sequential, - ) - - -def sequential(other_funcs: Iterable[Functional], x, *args, **kwargs): - for f in other_funcs: - x = f(x, *args, **kwargs) - return x diff --git a/firelang/function/functional.py b/firelang/function/functional.py index 4a620a2..63fff4f 100644 --- a/firelang/function/functional.py +++ b/firelang/function/functional.py @@ -8,6 +8,7 @@ from .components import ( _Forward, _Divergence, + _DivergenceViaQuadform, _Jacdet, _Jaclogdet, ) @@ -16,6 +17,7 @@ "MLP", "MLPDiv", "MLPlanarDiv", + "MLPlanarDivFast", "MLPlanarJacdet", "MLPlanarJaclogdet", ] @@ -39,9 +41,13 @@ class MLPlanarDiv(_Divergence, _MLPlanarInit): pass +class MLPlanarDivFast(_DivergenceViaQuadform, _MLPlanarInit): + pass + + class MLPlanarJacdet(_Jacdet, _MLPlanarInit): pass class MLPlanarJaclogdet(_Jaclogdet, 
_MLPlanarInit): - pass \ No newline at end of file + pass diff --git a/firelang/function/operators.py b/firelang/function/operators.py new file mode 100644 index 0000000..ebb1bfd --- /dev/null +++ b/firelang/function/operators.py @@ -0,0 +1,210 @@ +from __future__ import annotations +from typing import List +from typing_extensions import Literal +import torch +from torch import Tensor +from torch.nn import ModuleList +from firelang.utils.shape import check_shape_consistency +from firelang.utils.index import normalize_index +from .base import Functional + + +class Add(Functional): + def __init__(self, f1: Functional, f2: Functional | float): + if isinstance(f2, Functional): + assert check_shape_consistency(f1.shape, f2.shape) + assert check_shape_consistency(f1.shape_out, f2.shape_out) + assert f1.dim == f2.dim + self.dim = f1.dim + Functional.__init__( + self, + locals_={"shape": f1.shape, "shape_out": f1.shape_out}, + is_fleaf=False, + ) + self.f1 = f1 + self.f2 = f2 + + def forward(self, x, *args, **kwargs): + fx1 = self.f1.forward(x, *args, **kwargs) + fx2 = ( + self.f2.forward(x, *args, **kwargs) + if isinstance(self.f2, Functional) + else self.f2 + ) + return fx1 + fx2 + + def restack(self, shape): + f1stack = self.f1.restack(shape) + f2stack = self.f2.restack(shape) if isinstance(self.f2, Functional) else self.f2 + return Add(f1stack, f2stack) + + +class Sub(Functional): + def __init__(self, f1: Functional, f2: Functional | float): + if isinstance(f2, Functional): + assert check_shape_consistency(f1.shape, f2.shape) + assert check_shape_consistency(f1.shape_out, f2.shape_out) + assert f1.dim == f2.dim + self.dim = f1.dim + Functional.__init__( + self, + locals_={"shape": f1.shape, "shape_out": f1.shape_out}, + is_fleaf=False, + ) + self.f1 = f1 + self.f2 = f2 + + def forward(self, x, *args, **kwargs): + fx1 = self.f1.forward(x, *args, **kwargs) + fx2 = ( + self.f2.forward(x, *args, **kwargs) + if isinstance(self.f2, Functional) + else self.f2 + ) + return fx1 - fx2 + + def restack(self, shape): + f1stack = self.f1.restack(shape) + f2stack = self.f2.restack(shape) if isinstance(self.f2, Functional) else self.f2 + return Sub(f1stack, f2stack) + + +class Mul(Functional): + def __init__(self, f1: Functional, f2: Functional | float): + if isinstance(f2, Functional): + assert check_shape_consistency(f1.shape, f2.shape) + assert check_shape_consistency(f1.shape_out, f2.shape_out) + assert f1.dim == f2.dim + self.dim = f1.dim + Functional.__init__( + self, + locals_={"shape": f1.shape, "shape_out": f1.shape_out}, + is_fleaf=False, + ) + self.f1 = f1 + self.f2 = f2 + + def forward(self, x, *args, **kwargs): + fx1 = self.f1.forward(x, *args, **kwargs) + fx2 = ( + self.f2.forward(x, *args, **kwargs) + if isinstance(self.f2, Functional) + else self.f2 + ) + return fx1 * fx2 + + def restack(self, shape): + f1stack = self.f1.restack(shape) + f2stack = self.f2.restack(shape) if isinstance(self.f2, Functional) else self.f2 + return Mul(f1stack, f2stack) + + +class TrueDiv(Functional): + def __init__(self, f1: Functional, f2: Functional | float): + if isinstance(f2, Functional): + assert check_shape_consistency(f1.shape, f2.shape) + assert check_shape_consistency(f1.shape_out, f2.shape_out) + assert f1.dim == f2.dim + self.dim = f1.dim + Functional.__init__( + self, + locals_={"shape": f1.shape, "shape_out": f1.shape_out}, + is_fleaf=False, + ) + self.f1 = f1 + self.f2 = f2 + + def forward(self, x, *args, **kwargs): + fx1 = self.f1.forward(x, *args, **kwargs) + fx2 = ( + self.f2.forward(x, *args, 
**kwargs) + if isinstance(self.f2, Functional) + else self.f2 + ) + return fx1 / fx2 + + def restack(self, shape): + f1stack = self.f1.restack(shape) + f2stack = self.f2.restack(shape) if isinstance(self.f2, Functional) else self.f2 + return TrueDiv(f1stack, f2stack) + + +class Pow(Functional): + def __init__(self, f: Functional, pow: float): + Functional.__init__( + self, + locals_={"shape": f.shape, "shape_out": f.shape_out}, + is_fleaf=False, + ) + self.f = f + self.pow = pow + self.dim = f.dim + + def forward(self, x, *args, **kwargs): + fx = self.f.forward(x, *args, **kwargs) + return fx**self.pow + + def restack(self, shape): + fstack = self.f.restack(shape) + return Pow(fstack, self.pow) + + +class Neg(Functional): + def __init__(self, f: Functional): + Functional.__init__( + self, + locals_={"shape": f.shape, "shape_out": f.shape_out}, + is_fleaf=False, + ) + self.f = f + self.dim = f.dim + + def forward(self, x, *args, **kwargs): + return -self.f.forward(x, *args, **kwargs) + + def restack(self, shape): + fstack = self.f.restack(shape) + return Neg(fstack) + + +class Abs(Functional): + def __init__(self, f: Functional): + Functional.__init__( + self, + locals_={"shape": f.shape, "shape_out": f.shape_out}, + is_fleaf=False, + ) + self.f = f + self.dim = f.dim + + def forward(self, x, *args, **kwargs): + return self.f.forward(x, *args, **kwargs).abs() + + def restack(self, shape): + fstack = self.f.restack(shape) + return Abs(fstack) + + +class Sequential(Functional): + def __init__(self, fs: List[Functional]): + for f1, f2 in zip(fs[:-1], fs[1:]): + assert f1.shape_out == f2.shape + assert f1.dim == f2.dim + f2.is_fleaf = False + self.dim = fs[0].dim + Functional.__init__( + self, + locals_={"shape": fs[0].shape, "shape_out": fs[-1].shape_out}, + is_fleaf=False, + ) + self.fs = ModuleList(fs) + + def forward(self, x, *args, **kwargs): + for f in self.fs: + x = f.forward(x, *args, **kwargs) + return x + + def restack(self, shape): + return Sequential([f.restack(shape) for f in self.fs]) + + diff --git a/firelang/measure/dirac.py b/firelang/measure/dirac.py index bf5f3b9..1cedbbb 100644 --- a/firelang/measure/dirac.py +++ b/firelang/measure/dirac.py @@ -1,11 +1,12 @@ from __future__ import annotations from typing import List, Tuple -import numpy as np import torch from torch import Tensor from torch.nn import Parameter from .base import Measure from firelang.utils.limits import parse_rect_limits +from firelang.utils.index import normalize_index +from firelang.utils.shape import check_shape_consistency from firelang.function import Functional __all__ = [ @@ -21,67 +22,78 @@ def __init__( self, dim: int, k: int, - limits: float | Tuple[float, float] | List[Tuple[float, float]] = None, + limits: float | Tuple[float, float] | List[Tuple[float, float]] | Tensor = None, mfix: bool = False, + signed: bool = False, shape: Tuple[int] = (1,), ): Measure.__init__(self, locals()) - size = int(np.prod(shape)) if limits is None: - self._x = Parameter(torch.randn(size, k, dim, dtype=torch.float32)) + self._x = Parameter(torch.randn(*shape, k, dim, dtype=torch.float32)) else: - limits = torch.tensor( - parse_rect_limits(limits, dim), dtype=torch.float32 - ) # (dim, 2) + if isinstance(limits, Tensor): + assert ( + limits.ndim == 2 and limits.shape[0] == dim and limits.shape[1] == 2 + ), f"Invalid limits: {limits}" + else: + limits = torch.tensor( + parse_rect_limits(limits, dim), dtype=torch.float32 + ) # (dim, 2) ranges = (limits[:, 1] - limits[:, 0])[None, None] # (1, 1, dim) starts = limits[:, 
0][None, None] # (1, 1, dim) self._x = Parameter( - torch.rand(size, k, dim, dtype=torch.float32) * ranges + starts + torch.rand(*shape, k, dim, dtype=torch.float32) * ranges + starts ) - self._m = 1.0 if mfix else Parameter(torch.ones(size, k, dtype=torch.float32)) + self._m = 1.0 if mfix else Parameter(torch.ones(*shape, k, dtype=torch.float32)) self.dim = dim self.k = k self.limits = limits self.mfix = mfix self.shape = shape + self.signed = signed def integral( self, func: Functional, cross: bool = False, batch_size: int = 1000000, - sum: bool = True, ) -> Tensor: + + m, x = self.m, self.x if not cross: - m = self.m.view(*self.shape, self.k) # (...shape, k) - x = self.x.view(*self.shape, self.k, self.dim) # (...shape, k, dim) - func = func.view(*func.shape, 1) - fx = func(x) * m + func = func.vmap() # (1, ...shape) + x = x.permute(-2, *range(x.ndim - 2), -1) # (k, ...shape, dim) + if isinstance(m, Tensor): + m = m.permute(-1, *range(m.ndim - 1)) # (k, ...shape) + fx = func(x) * m # (k, ...shape) + fx = fx.sum(dim=0) else: - assert ( - self.shape[:-1] == func.shape[:-1] - ), f"Shape inconsistent: {func.shape[:-1]} ({func.shape}) != {self.shape[:-1]} ({self.shape})." - - measure_size = self.shape[-1] - func_size = func.shape[-1] + assert check_shape_consistency(self.shape[:-1], func.shape[:-1]) + batch_shape = self.shape[:-1] + msize = self.shape[-1] + fsize = func.shape[-1] - m = self.m.view(*self.shape[:-1], 1, measure_size, self.k) - x = self.x.view(*self.shape[:-1], 1, measure_size, self.k, self.dim) - func = func.view(*func.shape[:-1], func_size, 1, 1) + func = func.vmap(2) # (1, 1, ...batch_shape, fsize) + x = x.permute(-2, -3, *range(x.ndim - 3), -1).unsqueeze(-2) + # x: (k, msize, ...batch_shape, 1, dim) + if isinstance(m, Tensor): + m = m.permute(-1, -2, *range(m.ndim - 2)).unsqueeze(-1) + # m: (k, msize, ...batch_shape, 1) - size = func_size * self.k + size = fsize * self.k nrow_per_batch = (batch_size + size - 1) // size fx = [] - for i in range(0, measure_size, nrow_per_batch): - _x = x[..., i : i + nrow_per_batch, :, :] - _m = m[..., i : i + nrow_per_batch, :] - _fx = func(_x) * _m # (...shape[:-1], _nrow_per_batch, measure_size, k) + for i in range(0, msize, nrow_per_batch): + _x = x if x.shape[1] == 1 else x[:, i : i + nrow_per_batch, ...] + _m = m if m.shape[1] == 1 else m[:, i : i + nrow_per_batch, ...] 
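# ---------------------------------------------------------------------------
# Editor's shape sketch for this cross=True branch (illustration only: the
# batching over `msize` is omitted, and `func` is a hypothetical stand-in that
# broadcasts like a vmap(2)'ed Functional with stack shape (B, fsize)):
import torch

B, msize, fsize, k, dim = 2, 3, 5, 4, 2
x = torch.randn(B, msize, k, dim)                            # supports: (...shape, k, dim)
m = torch.ones(B, msize, k)                                  # weights:  (...shape, k)

x = x.permute(-2, -3, *range(x.ndim - 3), -1).unsqueeze(-2)  # (k, msize, B, 1, dim)
m = m.permute(-1, -2, *range(m.ndim - 2)).unsqueeze(-1)      # (k, msize, B, 1)

func = lambda t: t.sum(-1) * torch.ones(fsize)               # (k, msize, B, fsize)
fx = (func(x) * m).sum(dim=0)                                # sum over k -> (msize, B, fsize)
fx = fx.permute(*range(1, fx.ndim - 1), -1, 0)               # (B, fsize, msize)
assert fx.shape == (B, fsize, msize)
# ---------------------------------------------------------------------------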
+ _fx = func(_x) * _m # (k, nrow_per_batch, ...batch_shape, fsize) fx.append(_fx) - fx = torch.cat(fx, dim=-2) - - if sum: - fx = fx.sum(-1) # (...shape[:-1], func_size, measure_size, k) + fx = torch.cat(fx, dim=1) # (k, msize, ...batch_shape, fsize) + fx = fx.sum(dim=0) # (msize, ...batch_shape, fsize) + fx = fx.permute( + *range(1, fx.ndim - 1), -1, 0 + ) # (...batch_shape, fsize, msize) return fx def get_x(self): @@ -102,7 +114,10 @@ def x(self): def get_m(self): if isinstance(self._m, Tensor): - return self._m.abs() + if self.signed: + return self._m + else: + return self._m.abs() else: return self._m @@ -128,4 +143,299 @@ def _parameter_shape_hash(self): hsh = Measure._parameter_shape_hash(self) hsh += hash(self.mfix) hsh += hash(self.limits) - return hash(hsh) \ No newline at end of file + return hash(hsh) + + def unsqueeze(self, dim: int) -> DiracMixture: + dim = normalize_index(dim, len(self.shape)) + new_shape = (*self.shape[:dim], 1, *self.shape[dim + 1 :]) + new = DiracMixture( + dim=self.dim, + k=self.k, + limits=self.limits, + mfix=self.mfix, + shape=new_shape, + signed=self.signed, + ) + + """ substitute _x """ + _x = self._x.unsqueeze(dim) + if new._x.shape != _x.shape: + new._x.requires_grad_(False) + new._x = Parameter(torch.empty_like(_x)) + new._x.requires_grad_(False) + new._x.copy_(_x) + + """ substitute _m """ + if isinstance(new._m, Tensor): + _m = self.m + _m = _m.reshape(*new_shape, self.k) + new._m.requires_grad_(False) + new._m.copy_(_m) + + return new + + def squeeze(self, dim: int) -> DiracMixture: + dim = normalize_index(dim, len(self.shape)) + assert ( + self.shape[dim] == 1 + ), f"Unable to squeeze a dimension of size {self.shape[dim]}." + new_shape = (*self.shape[: dim - 1], *self.shape[dim + 1 :]) + new = DiracMixture( + dim=self.dim, + k=self.k, + limits=self.limits, + mfix=self.mfix, + shape=new_shape, + signed=self.signed, + ) + + """ substitute _x """ + _x = self._x.squeeze(dim) + if new._x.shape != _x.shape: + new._x.requires_grad_(False) + new._x = Parameter(torch.empty_like(_x)) + new._x.requires_grad_(False) + new._x.copy_(_x) + + """ substitute _m """ + if isinstance(new._m, Tensor): + _m = self.m + _m = _m.reshape(*new_shape, self.k) + new._m.requires_grad_(False) + new._m.copy_(_m) + + return new + + def split(self, n_heads: int) -> DiracMixture: + assert ( + self.k % n_heads == 0 + ), f"K ({self.k}) must be multiple of the number of heads ({n_heads})." 
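# ---------------------------------------------------------------------------
# Editor's usage sketch for the head-manipulation methods of this class
# (`unsqueeze`/`squeeze` above, `split`/`concat` below); the numbers are
# hypothetical, the calls follow the signatures in this file:
from firelang.measure import DiracMixture

mu = DiracMixture(dim=2, k=20, shape=(100,))   # 100 stacked mixtures, 20 points each
heads = mu.split(4)                            # -> shape (4, 100), k = 5 per head
merged = heads.concat(0)                       # -> shape (100,),   k = 20 again
# Side note: `unsqueeze`/`squeeze` build `new_shape` from `shape[dim + 1:]` and
# `shape[: dim - 1]`; inserting/removing an axis would normally use `shape[dim:]`
# and `shape[:dim]`, so those two slices may be worth double-checking.
# ---------------------------------------------------------------------------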
+ new_k = self.k // n_heads + new_shape = (n_heads, *self.shape) + new = DiracMixture( + dim=self.dim, + k=new_k, + limits=self.limits, + mfix=self.mfix, + shape=new_shape, + signed=self.signed, + ) + + """ substitute _x """ + xshape = self._x.shape[:-2] + _x = self._x.reshape(*xshape, n_heads, new_k, self.dim) + _x = _x.permute( + -3, *range(len(self.shape)), -2, -1 + ).contiguous() # (n_heads, ...shape, new_k, dim) + if new._x.shape != _x.shape: + new._x.requires_grad_(False) + new._x = Parameter(torch.empty_like(_x)) + new._x.requires_grad_(False) + new._x.copy_(_x) + + """ substitute _m """ + if isinstance(new._m, Tensor): + _m = self.m + _m = _m.reshape(*self.shape, n_heads, self.k // n_heads) + _m = _m.permute(-2, *range(len(self.shape)), -1) # (n_heads, ...shape, k) + new._m.requires_grad_(False) + new._m.copy_(_m) + return new + + def concat(self, dim: int = 0) -> DiracMixture: + dim = normalize_index(dim, len(self.shape)) + n_heads = self.shape[dim] + new_k = self.k * n_heads + new_shape = (*self.shape[:dim], *self.shape[dim + 1 :]) + + new = DiracMixture( + dim=self.dim, + k=new_k, + limits=self.limits, + mfix=self.mfix, + shape=new_shape, + signed=self.signed, + ) + + """ substitute _x """ + _x = self._x # (...shape_before, n_heads, ...shape_after, k, field_dim) + _x = _x.permute( + *range(dim), *range(dim + 1, len(self.shape)), dim, -2, -1 + ) # (...shape_before, ...shape_after, n_heads, k, field_dim) + _x = _x.reshape(*_x.shape[:-3], new_k, self.dim) + if new._x.shape != _x.shape: + new._x.requires_grad_(False) + new._x = Parameter(torch.empty_like(_x)) + new._x.requires_grad_(False) + new._x.copy_(_x) + + """ substitute _m """ + if isinstance(new._m, Tensor): + _m = self.m # (...shape_before, n_heads, ...shape_after, k) + _m = _m.permute( + *range(dim), *range(dim + 1, len(self.shape)), dim, -1 + ) # (...shape_before, ...shape_after, n_heads, k) + _m = _m.reshape(*new_shape, new_k) + new._m.requires_grad_(False) + new._m.copy_(_m) + return new + + def share_supports(self, dim: int) -> DiracMixture: + """ + Warning: this function does not support automatic detection / reduction + of duplicated supports. So you should only call this function once. + + Args: + dim: the axis along which the Measure supports are shared. + + Returns: + DiracMixture: (...shape) of `n*k` mixtures, where `n` is the \ + size at dimension `dim`. 
+ """ + dim = normalize_index(dim, len(self.shape)) + + n = self.shape[dim] + k = self.k + nk = n * k + + new = DiracMixture( + dim=self.dim, + k=nk, + limits=self.limits, + mfix=False, + shape=self.shape, + signed=self.signed, + ) + + """ substitute _x: (..., n, ..., k) -> (..., n, ..., nk) """ + _x: Tensor = self._x + xshape = _x.shape + xdims = len(xshape) + dimorder = ( + *range(dim), + *range(dim + 1, len(self.shape)), + dim, + *range(len(self.shape), xdims), + ) + _x = _x.permute(dimorder) # (...shape_before, ...shape_after, n, k, field_dim) + + common_x = _x.reshape( + *xshape[:dim], 1, *xshape[dim + 1 : -2], nk, self.dim + ) # (...shape_before, 1, ...shape_after, nk, field_dim) + + new._x.requires_grad_(False) + new._x = Parameter(torch.empty_like(common_x)) + new._x.requires_grad_(False) + new._x.copy_(common_x) + + """ substitute _m: (..., n, ..., k) -> (..., n, ..., nk) """ + + if not isinstance(self._m, Tensor): + _m = torch.ones( + *self.shape, self.k, dtype=torch.float32, device=self._x.device + ) + else: + _m = self.m + ids = [slice(None)] * len(_m.shape) + ids[dim] = torch.arange(n).reshape(n, 1) + ids[-1] = torch.arange(nk).reshape(n, k) + + newm = torch.zeros(*self.shape, nk, dtype=_m.dtype, device=_m.device) + if dim + 1 == len(ids) - 1: # dim is the second to last dimension + newm[ids] = _m + else: + _m = _m.permute( + dim, len(_m.shape) - 1, *range(dim), *range(dim + 1, len(_m.shape) - 1) + ) + newm[ids] = _m + new._m.requires_grad_(False) + new._m.copy_(newm) + + return new + + def linear(self, A: Tensor) -> DiracMixture: + """Weighted linear combination of the measures along the last dimension, + which must have already shared the same set of `x`. + + Args: + A (Tensor): (..., out_dim, in_dim), where out_dim == in_dim + + Returns: + DiracMixture: (...shape[:-1], out_dim) + """ + assert A.shape[-1] == self.shape[-1] + new_shape = (*self.shape[:-1], A.shape[-2]) + + new = DiracMixture( + dim=self.dim, + k=self.k, + limits=self.limits, + mfix=False, + shape=new_shape, + signed=self.signed, + ) + + """ substitute _x """ + if new._x.shape != self._x.shape: + new._x.requires_grad_(False) + new._x = Parameter(torch.empty_like(self._x)) + new._x.requires_grad_(False) + new._x.copy_(self._x) + + """ substitute _m """ + if self.mfix: + _m = torch.ones( + *self.shape, self.k, dtype=torch.float32, device=self._x.device + ) + else: + _m = self.m # (...shape[:-1], shape[-1] == in_dim, k) + _m = ( + _m.unsqueeze(-3) # (...shape[:-1], 1, in_dim, k) + * A.unsqueeze(-1) # (...shape[:-1], out_dim, in_dim, 1) + ).sum( + dim=-2 + ) # (...shape[:-1], out_dim, k) + new._m.requires_grad_(False) + new._m.copy_(_m) + + return new + + def __add__(self, other: DiracMixture) -> DiracMixture: + assert self.shape == other.shape + assert (self._x == other._x).all() + + new = DiracMixture( + dim=self.dim, + k=self.k, + limits=self.limits, + mfix=False, + shape=self.shape, + signed=self.signed, + ) + + if self.mfix: + _m1 = torch.ones( + *self.shape, self.k, dtype=torch.float32, device=self._x.device + ) + else: + _m1 = self.m + + if other.mfix: + _m2 = torch.ones( + *other.shape, other.k, dtype=torch.float32, device=other._x.device + ) + else: + _m2 = other.m + + """ substitute _x """ + if new._x.shape != self._x.shape: + new._x.requires_grad_(False) + new._x = Parameter(torch.empty_like(self._x)) + new._x.requires_grad_(False) + new._x.copy_(self._x) + + """ substitute _m """ + new._m.requires_grad_(False) + new._m.copy_(_m1 + _m2) + + return new diff --git a/firelang/models/__init__.py 
b/firelang/models/__init__.py index d34cd71..3b98fba 100644 --- a/firelang/models/__init__.py +++ b/firelang/models/__init__.py @@ -1 +1,2 @@ +from .tensor import * from .word import * \ No newline at end of file diff --git a/firelang/models/_firetensor.py b/firelang/models/_firetensor.py new file mode 100644 index 0000000..679ee68 --- /dev/null +++ b/firelang/models/_firetensor.py @@ -0,0 +1,105 @@ +from __future__ import annotations +import torch +from torch import Tensor +from torch.nn import Module +from firelang.measure import Measure +from firelang.function import Functional +from firelang.stack import IndexLike + +__all__ = [ + "FireTensor", + "FIRETensor", +] + + +class FireTensor(Module): + def __init__(self, funcs: Functional, measures: Measure): + Module.__init__(self) + # check_shape_consistency(funcs.shape, measures.shape) + self.funcs: Functional = funcs + self.measures: Measure = measures + + def __getitem__(self, index: IndexLike) -> FireTensor: + return FireTensor(self.funcs[index], self.measures[index]) + + def view(self, *shape, inplace=False) -> FireTensor: + if inplace: + self.funcs.view(*shape, inplace=True) + return self + else: + return FireTensor( + funcs=self.funcs.view(*shape, inplace=False), + measures=self.measures.view(*shape, inplace=False), + ) + + def __add__(self, other: FireTensor | Functional) -> FireTensor: + if isinstance(other, FireTensor): + return FireTensor( + funcs=self.funcs + other.funcs, measures=self.measures + other.measures + ) + elif isinstance(other, Functional): + return FireTensor(funcs=self.funcs + other, measures=self.measures) + else: + raise TypeError(other) + + def __mul__(self, other: FIRETensor) -> Tensor: + if id(other) == id(self): + return self.measures.integral(self.funcs) * 2 + else: + return other.measures.integral(self.funcs) + self.measures.integral( + other.funcs + ) + + def __matmul__(self, other: FIRETensor) -> Tensor: + if id(other) == id(self): + mat = self.measures.integral(self.funcs, cross=True) + return mat + torch.transpose(mat, -2, -1) + else: + return other.measures.integral(self.funcs, cross=True) + torch.transpose( + self.measures.integral(other.funcs, cross=True), -2, -1 + ) + + def __repr__(self): + return ( + f"" + ) + + def split(self, n_heads: int): + return FireTensor( + funcs=self.funcs.vmap(), measures=self.measures.split(n_heads) + ) + + def size(self): + return self.funcs.shape + + @property + def shape(self): + return self.size() + + def detect_device(self): + return self.funcs.detect_device() + + def detect_dtype(self): + return self.funcs.detect_dtype() + + def flatten_parameter(self): + return torch.cat( + [self.funcs.flatten_parameter(), self.measures.flatten_parameter()], dim=-1 + ) + + def load_flatten_parameter(self, flattened: Tensor) -> int: + offset = self.funcs.load_flatten_parameter(flattened, offset=0) + offset = self.measures.load_flatten_parameter(flattened, offset=offset) + return offset + + def restack(self, shape): + return FireTensor( + funcs=self.funcs.restack(shape), measures=self.measures.restack(shape) + ) + + stack = restack + + +FIRETensor = FireTensor diff --git a/firelang/models/_fireword.py b/firelang/models/_fireword.py index bbece7d..25219ab 100644 --- a/firelang/models/_fireword.py +++ b/firelang/models/_fireword.py @@ -1,66 +1,55 @@ from __future__ import annotations from argparse import Namespace from typing import List, Union, Tuple +from collections import OrderedDict +import os +import json import numpy as np import torch from torch import Tensor from torch.nn 
import Module, functional as F -from firelang.measure import Measure -from firelang.function import Functional -from firelang.stack import StackingSlicing, IndexLike +from firelang.stack import StackingSlicing from firelang.measure import DiracMixture, metrics from firelang.utils.timer import Timer, elapsed from firelang.utils.optim import Loss +from firelang.utils.log import logger +from firelang.utils.parse import parse_func, parse_measure +from . import FireTensor from corpusit import Vocab __all__ = [ + "FireEmbedding", + "FireWord", + "FireWordConfig", "FIREWord", - "FIRETensor", ] -class FIREWord(Module): +class FireEmbedding(Module): funcs: StackingSlicing measures: StackingSlicing dim: int - vocab: Vocab def __init__( self, func_template: StackingSlicing, measure_template: StackingSlicing, - dim, - vocab: Vocab, + vocab_size, ): - super().__init__() - self.vocab_size = len(vocab) - self.funcs: StackingSlicing = func_template.restack(self.vocab_size) - self.measures: StackingSlicing = measure_template.restack(self.vocab_size) - - self.dim = dim - self.vocab = vocab - self.i2rank, self.rank2i = self._ranking() - - def _ranking(self): - ids = sorted(self.vocab.i2s_dict().keys()) - maxid = max(ids) - i2rank = -np.ones(maxid + 1, dtype=np.int64) - rank2i = -np.ones(len(ids), dtype=np.int64) - for rank, idx in enumerate(ids): - i2rank[idx] = rank - rank2i[rank] = idx - return i2rank, rank2i + Module.__init__(self) + assert func_template.dim == measure_template.dim + self.dim = func_template.dim + self.funcs: StackingSlicing = func_template.restack(vocab_size) + self.measures: StackingSlicing = measure_template.restack(vocab_size) + self.vocab_size = vocab_size def __len__(self): return self.vocab_size - def detect_device(self) -> torch.device: - return next(iter(self.parameters())).device - - def forward(self, ranks: Tensor) -> FIRETensor: + def forward(self, ranks: Tensor) -> FireTensor: """ Args: ranks (Tensor): (n, ) of word ranks @@ -68,17 +57,61 @@ def forward(self, ranks: Tensor) -> FIRETensor: Returns: (Functional, Measure) """ - if not isinstance(ranks, Tensor): - ranks = torch.tensor(ranks, device=self.detect_device(), dtype=torch.long) with Timer(elapsed, "slicing", sync_cuda=True), Timer( elapsed, "slicing", sync_cuda=True, relative=False ): - func = self.funcs[ranks] - measure = self.measures[ranks] - return FIRETensor(func, measure) + return FireTensor( + funcs=self.funcs[ranks], + measures=self.measures[ranks], + ) + + def detect_device(self) -> torch.device: + return next(iter(self.parameters())).device + - def __getitem__(self, words: Union[str, List[str]]) -> FIRETensor: +class FireWordConfig(dict): + dim: int + func: str + measure: str + + def __init__(self, **kwargs): + dict.__init__(self, **kwargs) + self.__dict__ = self + + +class FireWord(FireEmbedding): + + config: FireWordConfig + vocab: Vocab + + def __init__( + self, + config: FireWordConfig, + vocab: Vocab, + ): + FireEmbedding.__init__( + self, + func_template=parse_func(config.func, dim=config.dim), + measure_template=parse_measure(config.measure, dim=config.dim), + vocab_size=len(vocab), + ) + self.config = config + self.i2rank, self.rank2i = self._ranking(vocab) + self.vocab = vocab + self.vocab_size = len(vocab) + + def _ranking(self, vocab): + ids = sorted(vocab.i2s_dict().keys()) + maxid = max(ids) + i2rank = -np.ones(maxid + 1, dtype=np.int64) + rank2i = -np.ones(len(ids), dtype=np.int64) + for rank, idx in enumerate(ids): + i2rank[idx] = rank + rank2i[rank] = idx + return i2rank, rank2i + + def 
__getitem__(self, words: Union[str, List[str]]) -> FireTensor: """ Args: words (str or List[str]): a word or a list of words @@ -123,7 +156,7 @@ def field( assert m.shape == meshx.shape stack = int(np.prod(meshx.shape)) - func, _ = self[[word] * stack] + ft: FireTensor = self[[word] * stack] measure = DiracMixture(self.dim, 1, limits=None, shape=(stack,)).to( self.detect_device() ) @@ -135,7 +168,7 @@ def field( ) measure.m.copy_(torch.ones(stack, 1, dtype=torch.float32)) - outputs = measure.integral(func) # (stack, 1) + outputs = measure.integral(ft.funcs) # (stack, 1) outputs = outputs.view(*meshx.shape) return outputs @@ -187,8 +220,8 @@ def loss_skipgram( corresponding word pair is a positive or a negative sample. """ - x1: FIRETensor = self.forward(pairs[..., 0]) - x2: FIRETensor = self.forward(pairs[..., 1]) + x1: FireTensor = self.forward(pairs[..., 0]) + x2: FireTensor = self.forward(pairs[..., 1]) loss = Loss() @@ -213,46 +246,41 @@ def loss_skipgram( return loss + @staticmethod + def from_pretrained(dirpath) -> FireWord: + dirpath = os.path.abspath(dirpath) + if not os.path.exists(dirpath): + raise FileNotFoundError(f"Directory not found at {dirpath}") + + # config + with open(f"{dirpath}/config.json", "rt") as f: + config = FireWordConfig(**json.load(f)) + + # vocab + vocab = Vocab.from_json(f"{dirpath}/vocab.json") + + # state_dict + word = FireWord(config=config, vocab=vocab) + state_dict = torch.load(f"{dirpath}/pytorch_model.bin") + word.load_state_dict(state_dict) + return word + + def save(self, dirpath): + dirpath = os.path.abspath(dirpath) + if os.path.exists(dirpath): + logger.warn(f"Overwriting files in directory {dirpath}.") + else: + os.makedirs(dirpath, exist_ok=True) -class FIRETensor: - def __init__(self, funcs: Functional, measures: Measure): - assert funcs.shape == measures.shape - self.funcs: Functional = funcs - self.measures: Measure = measures + # config + with open(f"{dirpath}/config.json", "wt") as f: + json.dump(self.config, f) - def __getitem__(self, index: IndexLike) -> FIRETensor: - return FIRETensor(self.funcs[index], self.measures[index]) + # vocab + self.vocab.to_json(f"{dirpath}/vocab.json") - def view(self, *shape, inplace=False) -> FIRETensor: - if inplace: - self.funcs.view(*shape, inplace=True) - return self - else: - return FIRETensor( - funcs=self.funcs.view(*shape, inplace=False), - measures=self.measures.view(*shape, inplace=False), - ) + # state_dict + torch.save(self.state_dict(), f"{dirpath}/pytorch_model.bin") - def __mul__(self, other: FIRETensor): - if id(other) == id(self): - return self.measures.integral(self.funcs) * 2 - else: - return other.measures.integral(self.funcs) + self.measures.integral( - other.funcs - ) - - def __matmul__(self, other: FIRETensor): - if id(other) == id(self): - mat = self.measures.integral(self.funcs, cross=True) - return mat + torch.transpose(mat, -2, -1) - else: - return other.measures.integral(self.funcs, cross=True) + torch.transpose( - self.measures.integral(other.funcs, cross=True), -2, -1 - ) - def __repr__(self): - return ( - f"" - ) +FIREWord = FireWord diff --git a/firelang/models/tensor.py b/firelang/models/tensor.py new file mode 100644 index 0000000..89b49ce --- /dev/null +++ b/firelang/models/tensor.py @@ -0,0 +1 @@ +from ._firetensor import * \ No newline at end of file diff --git a/firelang/stack.py b/firelang/stack.py index 5b00e11..cba3a57 100644 --- a/firelang/stack.py +++ b/firelang/stack.py @@ -4,9 +4,10 @@ from collections import OrderedDict, defaultdict import inspect import numpy 
as np +import torch from torch import Tensor from torch.nn import Module, ModuleList, ModuleDict -from .utils.index import parse_index, IndexLike +from .utils.index import IndexLike from .utils.shape import parse_shape __all__ = [ @@ -60,36 +61,9 @@ def _sanity_check(self): "initialization argument." ) - def view(self, *shape, inplace: bool = False): - shape = parse_shape(shape, int(np.prod(self.shape))) - - if inplace: - self.shape = shape - return self - - else: - new = deepcopy(self) - new.shape = shape - - for module in new.children(): - if isinstance(module, StackingSlicing): - StackingSlicing.view(module, *shape, inplace=True) - elif isinstance(module, ModuleList): - for m in module: - if isinstance(m, StackingSlicing): - StackingSlicing.view(m, *shape, inplace=True) - elif isinstance(module, ModuleDict): - for _, m in module.items(): - if isinstance(m, StackingSlicing): - StackingSlicing.view(m, *shape, inplace=True) - return new - def __getitem__(self, index: IndexLike): - idtensor: Tensor = parse_index(index, self.shape) - ids = idtensor.reshape(-1) - shape = tuple(idtensor.shape) - - to: StackingSlicing = self.restack(shape) + new_shape = tuple(torch.empty(self.shape)[index].shape) + to: StackingSlicing = self.restack(new_shape) # A parameter not listed in `unsliceable_params` should be # sliced and copied. Otherwise, the whole parameter is copied. @@ -99,7 +73,7 @@ def __getitem__(self, index: IndexLike): if name in self.unsliceable_params: param_to.copy_(param) else: - param_to.copy_(param[ids]) + param_to.copy_(param[index]) # A submodule that is a `StackingSlicing` should be sliced # and copied. Otherwise, the whole submodule is copied. @@ -127,14 +101,35 @@ def __getitem__(self, index: IndexLike): ) else: submod_from: Module = module - submod_from.shape = shape + submod_from.shape = new_shape setattr(to, name, submod_from) - to.shape = shape + to.shape = new_shape return to def detect_device(self): - return next(iter(self.parameters())).device + for m in self.modules(): + try: + return next(m.parameters()).device + except StopIteration: + if hasattr(m, "_former_parameters"): + # `self` is an instance from torch.nn.parallel.replicate + fp = m._former_parameters + if len(fp): + return next(iter(fp.values())).device + raise ValueError("Failed to detect the device.") + + def detect_dtype(self): + for m in self.modules(): + try: + return next(m.parameters()).dtype + except StopIteration: + if hasattr(m, "_former_parameters"): + # `self` is an instance from torch.nn.parallel.replicate + fp = m._former_parameters + if len(fp): + return next(iter(fp.values())).dtype + raise ValueError("Failed to detect the dtype.") def _parameter_shape_hash(self): name_shapes = [(name, p.shape) for name, p in self.named_parameters()] @@ -142,12 +137,12 @@ def _parameter_shape_hash(self): def restack( self, - shape: int | Tuple[int], + *shape: int | Tuple[int], use_cached: bool = True, max_cached_copies: int = 100, ): - if not isinstance(shape, Tuple): - shape = (shape,) + if len(shape) == 1 and isinstance(shape[0], Iterable): + shape = tuple(shape[0]) tag = f"stacked/{self.__class__.__name__}-{self._parameter_shape_hash()}" if use_cached and shape in _cache[tag]: @@ -189,9 +184,15 @@ def _recover_args_from_locals( inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD, ]: - positional.append(value) + keywords[key] = value elif sign.kind == inspect.Parameter.VAR_POSITIONAL: positional.extend(value) elif sign.kind == inspect.Parameter.VAR_KEYWORD: keywords = {**keywords, **value} - return 
positional, keywords \ No newline at end of file + return positional, keywords + + @property + def ndim(self) -> int: + return len(self.shape) + + diff --git a/firelang/utils/index.py b/firelang/utils/index.py index f408abd..236c51c 100644 --- a/firelang/utils/index.py +++ b/firelang/utils/index.py @@ -1,16 +1,17 @@ from __future__ import annotations -from typing import List, Tuple, Union, Any +from typing import List, Tuple, Union, Any, Iterable import numpy as np import torch from torch import Tensor +from .shape import ShapeLike IndexLike = Union[ int, slice, List, Tensor, None, Tuple[Union[int, slice, List, Tensor, None]] ] -def parse_index(index: IndexLike, shape: Tuple[int]) -> Tensor: - if not isinstance(index, tuple): +def flatten_index(index: IndexLike, shape: Tuple[int]) -> Tensor: + if not isinstance(index, Iterable): index = (index,) index = _complete_ellipsis(index, ndim=len(shape)) @@ -73,3 +74,49 @@ def _complete_ellipsis(index: Tuple[Any | Ellipsis], ndim: int): list(index[:i]) + [slice(None)] * (ndim - len(index) + 1) + list(index[i + 1 :]) ) return tuple(completed) + + +def normalize_index(index: IndexLike, shape: ShapeLike): + if not isinstance(shape, Iterable): + shape = [shape] + if not isinstance(index, Iterable): + normalized = normalize_index([index], shape) + return normalized[0] + + index = _complete_ellipsis(index, ndim=len(shape)) + assert len(index) == len(shape) + normalized = [] + for dim, (index_at_dim, size_at_dim) in enumerate(zip(index, shape)): + + if isinstance(index_at_dim, int): + assert -size_at_dim <= index_at_dim < size_at_dim + if index_at_dim < 0: + index_at_dim += size_at_dim + normalized.append(index_at_dim) + elif isinstance(index_at_dim, slice): + index_at_dim = list(range(size_at_dim))[index_at_dim] + normalized.append(index_at_dim) + elif isinstance(index_at_dim, list): + nonneg = [] + for idx in index_at_dim: + assert -size_at_dim <= idx < size_at_dim + if idx < 0: + idx += size_at_dim + nonneg.append(idx) + normalized.append(nonneg) + elif isinstance(index_at_dim, Tensor): + assert index_at_dim.ndim == 1, ( + f"Index at dimension {dim} should be 1-dimensional, " + f"not {index_at_dim.ndim}-d." 
+ ) + assert (-size_at_dim <= index_at_dim).all() and (index_at_dim < size_at_dim).all() + index_at_dim = index_at_dim.data.cpu().numpy() + index_at_dim[index_at_dim < 0] += size_at_dim + normalized.append(index_at_dim.tolist()) + else: + raise TypeError( + f"Index at dimension {dim} should be " + f"a `int`, a `slice`, or a `Tensor`, not {type(index_at_dim)}" + ) + + return normalized \ No newline at end of file diff --git a/firelang/utils/parse.py b/firelang/utils/parse.py index e76cf7a..ff4b668 100644 --- a/firelang/utils/parse.py +++ b/firelang/utils/parse.py @@ -1,4 +1,5 @@ from argparse import Namespace +import firelang from firelang import * __all__ = ["parse_func", "parse_measure"] @@ -13,7 +14,7 @@ def parse_func(name: str, args: Namespace = Namespace(), **kwargs): for nameseg in segs: seg = eval(nameseg) funcsegs.append(seg) - return Sequential(funcsegs) + return firelang.function.Sequential(funcsegs) def parse_measure(name, args: Namespace = Namespace(), **kwargs): diff --git a/firelang/utils/shape.py b/firelang/utils/shape.py index e774660..a57f7bd 100644 --- a/firelang/utils/shape.py +++ b/firelang/utils/shape.py @@ -1,7 +1,8 @@ from __future__ import annotations -from typing import Tuple, Iterable +from typing import Tuple, Iterable, Union import numpy as np +ShapeLike = Union[int, Tuple[int]] def check_shape_consistency(shape1: Tuple[int], shape2: Tuple[int]): assert len(shape1) == len( @@ -10,6 +11,7 @@ def check_shape_consistency(shape1: Tuple[int], shape2: Tuple[int]): for d1, d2 in zip(shape1, shape2): if d1 != 1 and d2 != 1 and d1 != d2: raise ValueError(f"Inconsistent shape: {shape1} and {shape2}") + return True def parse_shape(shape, num_elements): diff --git a/scripts/benchmark.py b/scripts/benchmark.py index f12ed4a..a422a76 100644 --- a/scripts/benchmark.py +++ b/scripts/benchmark.py @@ -15,7 +15,7 @@ from sklearn.cluster import DBSCAN from sklearn.metrics import accuracy_score -from firelang.models import FIREWord, FIRETensor +from firelang.models import FireWord, FireTensor from firelang.utils.log import logger from firelang.utils.timer import Timer, elapsed from scripts.sentsim import sentsim_as_weighted_wordsim_cuda @@ -212,7 +212,7 @@ def load_all_word_benchmarks(dirpath=DEFAULT_WORDSIM_DIR, lower=True): @torch.no_grad() @Timer(elapsed, "wordsim") def benchmark_word_similarity( - model: FIREWord, benchmarks: Mapping[str, SimilarityBenchmark] + model: FireWord, benchmarks: Mapping[str, SimilarityBenchmark] ): vocab: Vocab = model.vocab device = model.detect_device() @@ -407,7 +407,7 @@ def load_all_sentsim_benchmarks(dirpath=DEFAULT_SENTSIM_DIR, lower=True): @torch.no_grad() @Timer(elapsed, "sentsim") def benchmark_sentence_similarity( - model: FIREWord, + model: FireWord, benchmarks: Mapping[str, SimilarityBenchmark], sif_alpha=1e-3, ): @@ -627,7 +627,7 @@ def get_relwordpos(allmeasures, model, centerword, k=1000, pca=False): @torch.no_grad() @Timer(elapsed, "wordsense") def benchmark_wordsense_number( - model: FIREWord, + model: FireWord, w2nsense, num_workers=os.cpu_count(), eps=0.4, @@ -679,16 +679,16 @@ def benchmark_wordsense_number( "--checkpoints_for_similarity", nargs="+", default=[ - "checkpoints/wacky_mlplanardiv_d2_l4_k1_polysemy", - "checkpoints/wacky_mlplanardiv_d2_l4_k10", - "checkpoints/wacky_mlplanardiv_d2_l8_k20", + "checkpoints/v1.1/wacky_mlplanardiv_d2_l4_k1_polysemy", + "checkpoints/v1.1/wacky_mlplanardiv_d2_l4_k10", + "checkpoints/v1.1/wacky_mlplanardiv_d2_l8_k20", ], ) parser.add_argument( "--checkpoints_for_polysemy", nargs="+", default=[ 
- "checkpoints/wacky_mlplanardiv_d2_l4_k1_polysemy", + "checkpoints/v1.1/wacky_mlplanardiv_d2_l4_k1_polysemy", ], ) args = parser.parse_args() @@ -696,7 +696,7 @@ def benchmark_wordsense_number( for checkpoint in args.checkpoints_for_similarity: elapsed.clear() print(f"=============== Checkpoint `{checkpoint}` ================") - model = torch.load(checkpoint, map_location=device) + model = FireWord.from_pretrained(checkpoint).to(device) print("------------- word similarity -------------") benchmarks = load_all_word_benchmarks() @@ -728,7 +728,7 @@ def benchmark_wordsense_number( f"========= Word polysemy detection with checkpoint `{checkpoint}` =========" ) w2nsense = load_numwordsense("data/wordnet-542.txt") - model = torch.load(checkpoint, map_location=device) + model = FireWord.from_pretrained(checkpoint).to(device) acc, corr = benchmark_wordsense_number(model, w2nsense, eps=0.40) print( f"Accuracy = {acc*100:.3g}%, Pearson Correlation Coefficient = {corr:.3g}" diff --git a/scripts/benchmark/1_download_pretrained.sh b/scripts/benchmark/1_download_pretrained.sh index 14bf5db..bb9cc1e 100644 --- a/scripts/benchmark/1_download_pretrained.sh +++ b/scripts/benchmark/1_download_pretrained.sh @@ -1,11 +1,19 @@ -BASE_URL="https://www.cl.rcast.u-tokyo.ac.jp/~duxin/firelang/pretrained/word/v1.0/" -MODEL_23="wacky_mlplanardiv_d2_l4_k1_polysemy.gz" -MODEL_50="wacky_mlplanardiv_d2_l4_k10.gz" -MODEL_100="wacky_mlplanardiv_d2_l8_k20.gz" +BASE_URL="https://www.cl.rcast.u-tokyo.ac.jp/~duxin/firelang/pretrained/word/" +VERSION="v1.1/" +MODEL_23="wacky_mlplanardiv_d2_l4_k1_polysemy.tar.gz" +MODEL_50="wacky_mlplanardiv_d2_l4_k10.tar.gz" +MODEL_100="wacky_mlplanardiv_d2_l8_k20.tar.gz" -mkdir -p checkpoints -wget "$BASE_URL$MODEL_23" -O checkpoints/$MODEL_23 -wget "$BASE_URL$MODEL_50" -O checkpoints/$MODEL_50 -wget "$BASE_URL$MODEL_100" -O checkpoints/$MODEL_100 +mkdir -p checkpoints/$VERSION -gzip -d checkpoints/* \ No newline at end of file +wget "$BASE_URL$VERSION$MODEL_23" -O checkpoints/$VERSION$MODEL_23 +tar zxvf checkpoints/$VERSION$MODEL_23 -C checkpoints/$VERSION +rm checkpoints/$VERSION$MODEL_23 + +wget "$BASE_URL$VERSION$MODEL_50" -O checkpoints/$VERSION$MODEL_50 +tar zxvf checkpoints/$VERSION$MODEL_50 -C checkpoints/$VERSION +rm checkpoints/$VERSION$MODEL_50 + +wget "$BASE_URL$VERSION$MODEL_100" -O checkpoints/$VERSION$MODEL_100 +tar zxvf checkpoints/$VERSION$MODEL_100 -C checkpoints/$VERSION +rm checkpoints/$VERSION$MODEL_100 \ No newline at end of file diff --git a/scripts/benchmark/2_run_benchmark.sh b/scripts/benchmark/2_run_benchmark.sh index 2663d6d..d597354 100644 --- a/scripts/benchmark/2_run_benchmark.sh +++ b/scripts/benchmark/2_run_benchmark.sh @@ -1,7 +1,7 @@ python -m scripts.benchmark \ --checkpoints_for_similarity \ - checkpoints/wacky_mlplanardiv_d2_l4_k1_polysemy \ - checkpoints/wacky_mlplanardiv_d2_l4_k10 \ - checkpoints/wacky_mlplanardiv_d2_l8_k20 \ + checkpoints/v1.1/wacky_mlplanardiv_d2_l4_k1_polysemy \ + checkpoints/v1.1/wacky_mlplanardiv_d2_l4_k10 \ + checkpoints/v1.1/wacky_mlplanardiv_d2_l8_k20 \ --checkpoints_for_polysemy \ - checkpoints/wacky_mlplanardiv_d2_l4_k1_polysemy \ No newline at end of file + checkpoints/v1.1/wacky_mlplanardiv_d2_l4_k1_polysemy \ No newline at end of file diff --git a/scripts/text8/4_train.sh b/scripts/text8/4_train.sh index 84a1ba5..47eb9ff 100644 --- a/scripts/text8/4_train.sh +++ b/scripts/text8/4_train.sh @@ -11,8 +11,7 @@ python -m scripts.train \ --optimizer=adamw \ --seed=0 \ --accum_steps=10 \ - --func='MLPlanarDiv(args.dim, 4)' \ - 
--measure='DiracMixture(args.dim, 10)' \ + --func='MLPlanarDivFast(dim, 4)' \ + --measure='DiracMixture(dim, 10)' \ --weight_decay=1e-6 \ - --use_wandb \ - --amp \ No newline at end of file + --use_wandb \ No newline at end of file diff --git a/scripts/train.py b/scripts/train.py index 9db3812..3b77c3d 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -20,8 +20,7 @@ from torch.cuda.amp.grad_scaler import GradScaler from corpusit import Vocab, SkipGramDataset -from firelang import FIREWord, PFIREWord -from firelang.utils.parse import parse_func, parse_measure +from firelang import FireWord, FireWordConfig, FireTensor from firelang.utils.optim import DummyScheduler from firelang.utils.log import logger from firelang.utils.timer import elapsed, Timer @@ -30,7 +29,6 @@ ALL_WORDSIM_BENCHMARKS, load_word_benchmark, benchmark_word_similarity, - benchmark_word_similarity_pfire, ) @@ -38,8 +36,7 @@ try: import wandb - from wandb_config import config - from wandb_config import download_wandb_file + from wandb_config import download_wandb_files except Exception as e: logger.warn( "Unable to import wandb for experiment tracking. " @@ -51,7 +48,10 @@ @total_timer def train(args): + torch.set_num_threads(1) if args.use_wandb: + from wandb_config import config + config() exp_name = "".join(random.choices(ascii_uppercase + digits, k=8)) args.savedir = f"{args.savedir}/{exp_name}/" @@ -93,29 +93,16 @@ def train(args): set_seed(args.seed) if not (args.use_wandb and args.wandb_pretrained): - func_template = parse_func(args.func, args).to(device) - measure_template = parse_measure(args.measure, args).to(device) - if args.model == "FIREWord": - model = FIREWord( - func_template, - measure_template, - args.dim, - vocab, - ) - elif args.model == "PFIREWord": - model = PFIREWord( - func_template, - args.grid_limits, - args.grid_dim_sizes, - vocab, - ) + config = FireWordConfig(dim=args.dim, func=args.func, measure=args.measure) + if args.model.lower() == "fireword": + model = FireWord(config=config, vocab=vocab) else: raise ValueError(args.model) else: - model = torch.load( - download_wandb_file(args.wandb_pretrained, "best"), - map_location="cpu", - ) + if args.model.lower() == "fireword": + model = FireWord.from_pretrained( + download_wandb_files(args.wandb_pretrained, "best") + ) wandb.config.update({"continue": True, "previous_run": args.wandb_pretrained}) logger.info(model) num_parameters = count_parameters(model) // len(vocab) @@ -183,15 +170,8 @@ def train(args): """ ----------------- forward pass -------------------""" with prof, autocaster, Timer(elapsed, "forward", sync_cuda=True): if args.task == "skipgram": - if args.model == "FIREWord": - model: FIREWord - loss = model.loss_skipgram( - inputs, - labels, - args, - ) - elif args.model == "PFIREWord": - model: PFIREWord + if args.model.lower() == "fireword": + model: FireWord loss = model.loss_skipgram( inputs, labels, @@ -212,6 +192,7 @@ def train(args): scaler.scale(steploss).backward() else: steploss.backward() + grad_norm = ( torch.cat([p.grad.data.reshape(-1) for p in model.parameters()]) .norm() @@ -254,7 +235,7 @@ def train(args): """--------------- similarity benchmark ---------------""" with Timer(elapsed, "benchmark", sync_cuda=True): - if args.model == "FIREWord": + if args.model.lower() == "fireword": simscores = ( benchmark_word_similarity( model, @@ -262,14 +243,6 @@ def train(args): ) * 100 ) - elif args.model == "PFIREWord": - simscores = ( - benchmark_word_similarity_pfire( - model, - benchmarks, - ) - * 100 - ) else: raise 
ValueError(args.model) simscore = simscores.mean() @@ -277,7 +250,7 @@ def train(args): best_iter = i best_simscore = simscore best_loss = total_loss.item() - torch.save(model, best_savepath) + model.save(best_savepath) if args.task == "skipgram": n_pos = labels.sum() @@ -306,17 +279,15 @@ def train(args): } if args.dim == 2: """---------------- visualize ----------------""" - if args.model == "FIREWord": + if args.model.lower() == "fireword": fig = visualize_fire(model, args.plot_words) - elif args.model == "PFIREWord": - fig = visualize_pfire(model, args.plot_words) else: raise ValueError(args.model) img = wandb.Image(_fig2array(fig)) plt.close(fig) loginfo["wordfig"] = img wandb.log(loginfo) - wandb.save(best_savepath, args.savedir, policy="end") + wandb.save(f"{best_savepath}/**", args.savedir, policy="end") model.train() @@ -359,11 +330,14 @@ def _fig2array(fig): @torch.no_grad() -def visualize_fire(model: FIREWord, words: List[str], r: float = 4): - _, measure = model[words] +def visualize_fire(model: FireWord, words: List[str], r: float = 4): + ft: FireTensor = model[words] + measure = ft.measures positions = measure.get_x() weights = ( - torch.ones(x.shape[0], x.shape[1], dtype=x.dtype, device=x.device) + torch.ones( + positions.shape[0], positions.shape[1], dtype=positions.dtype, device=positions.device + ) if isinstance(measure.m, float) else measure.m.abs() ) @@ -412,33 +386,6 @@ def _sigmoid(x): return fig -@torch.no_grad() -def visualize_pfire(model: PFIREWord, words): - limits = model.limits - meshx, meshy = torch.meshgrid( - torch.linspace(*limits[0], 100), torch.linspace(*limits[1], 100) - ) - grids = model.grids(words, reshape=True, meshx=meshx, meshy=meshy) # (stack, n, n) - - grids = grids.data.cpu().numpy() # (stack, n, n) - meshx = meshx.data.cpu().numpy() - meshy = meshy.data.cpu().numpy() - - naxes = len(words) - ncols = 8 - fig = plt.figure(figsize=(4 * ncols / (ncols - 1), 4 * naxes)) - gs = fig.add_gridspec(naxes, ncols) - for i, (word, grid) in enumerate(zip(words, grids)): - ax = fig.add_subplot(gs[i, : ncols - 1]) - ax.set_title(word) - cont = ax.contourf(meshx, meshy, grid) - ax = fig.add_subplot(gs[i, -1]) - fig.colorbar(cont, cax=ax) - - fig.subplots_adjust(right=0.8) - return fig - - def parse_arguments(): parser = argparse.ArgumentParser() @@ -458,9 +405,7 @@ def boolean_string(s): type=str, default=None, ) - parser.add_argument( - "--model", type=str, default="FIREWord", choices=["FIREWord", "PFIREWord"] - ) + parser.add_argument("--model", type=str, default="FireWord", choices=["FireWord"]) parser.add_argument("--task", type=str, default="skipgram", choices=["skipgram"]) # ----- fire model settings ----- diff --git a/scripts/wacky/4_train.sh b/scripts/wacky/4_train.sh index ae944ff..1b37ee1 100644 --- a/scripts/wacky/4_train.sh +++ b/scripts/wacky/4_train.sh @@ -11,8 +11,7 @@ python -m scripts.train \ --optimizer=adamw \ --seed=0 \ --accum_steps=10 \ - --func='MLPlanarDiv(args.dim, 4).neg()' \ - --measure='DiracMixture(args.dim, 10)' \ + --func='MLPlanarDivFast(dim, 4).neg()' \ + --measure='DiracMixture(dim, 10)' \ --weight_decay=1e-6 \ - --use_wandb \ - --amp \ No newline at end of file + --use_wandb \ No newline at end of file diff --git a/wandb_config.template.py b/wandb_config.template.py index 9094d68..f228e3e 100644 --- a/wandb_config.template.py +++ b/wandb_config.template.py @@ -20,6 +20,7 @@ api = None + def config(): for k, v in settings.items(): os.environ[k] = v @@ -32,7 +33,19 @@ def download_wandb_file(runid: str, path) -> Union[str, Path]: if api is None:
config() run = api.run(runid) - file = run.file('best').download(f'{wandb_dir}/download-{runid}', - replace=True) - logging.info(f'Downloaded file from wandb://{runid}/{path}') - return file.name \ No newline at end of file + file = run.file(path).download(f"{wandb_dir}/download-{runid}", replace=True) + print(f"Downloaded file at wandb://{runid}/{path}") + return file.name + + +def download_wandb_files(runid: str, path) -> Union[str, Path]: + if api is None: + config() + run = api.run(runid) + files = run.files() + save_dir = f"{wandb_dir}/download-{runid}" + for file in files: + if file.name.startswith(path): + file.download(save_dir, replace=True) + print(f"Downloaded file at wandb://{runid}/{file.name}") + return f"{save_dir}/{path}"
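Usage note (a minimal sketch, not part of the patch): the new download_wandb_files helper is meant to pair with FireWord.from_pretrained, as wired up in scripts/train.py above. The run id below is a placeholder, and "best" is assumed to be the checkpoint directory written by model.save(best_savepath) and uploaded with wandb.save(f"{best_savepath}/**", ...).

# Sketch only: the run id "entity/project/run_id" is a placeholder.
import torch

from firelang import FireWord
from wandb_config import download_wandb_files

device = "cuda" if torch.cuda.is_available() else "cpu"

# download_wandb_files fetches every file of the run whose name starts with
# "best" (the saved checkpoint directory) and returns the local path to it.
ckpt_dir = download_wandb_files("entity/project/run_id", "best")

# Restore the model from the downloaded checkpoint directory.
model = FireWord.from_pretrained(ckpt_dir).to(device)
model.eval()

Unlike the old download_wandb_file, which fetched a single pickled file to be read back with torch.load, the new helper restores the whole directory produced by model.save, which is also why wandb.save now uploads f"{best_savepath}/**" instead of a single file.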