Support more kinds of marginalization via dim analysis
This commit lifts the restriction that only Elemwise operations may link marginalized RVs to their dependent RVs. We map input dims to output dims to assess whether an operation mixes information from different dims. Graphs where information is not mixed can be marginalized efficiently.
ricardoV94 committed Oct 4, 2024
1 parent d965959 commit e96d07f
Showing 9 changed files with 926 additions and 216 deletions.
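To illustrate the idea behind the dim analysis (a hedged sketch, not code from this commit; the dim-mapping notation mirrors the `dims_connections` introduced below):

```python
import pytensor.tensor as pt

x = pt.matrix("x")  # input dims (0, 1)

# Operations that merely route dims can be marginalized efficiently:
transposed = x.T         # output dims map back to input dims as (1, 0)
expanded = x[..., None]  # mapping (0, 1, None): the new dim carries no input information

# Operations that mix information across dims cannot:
summed = x.sum(axis=0)   # dim 0 is collapsed into every output entry
dotted = x @ x.T         # dot products combine information from both dims
```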
2 changes: 1 addition & 1 deletion conda-envs/environment-test.yml
@@ -10,6 +10,6 @@ dependencies:
- xhistogram
- statsmodels
- pip:
- pymc>=5.16.1 # CI was failing to resolve
- pymc>=5.17.0 # CI was failing to resolve
- blackjax
- scikit-learn
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-test.yml
@@ -10,6 +10,6 @@ dependencies:
- xhistogram
- statsmodels
- pip:
- pymc>=5.16.1 # CI was failing to resolve
- pymc>=5.17.0 # CI was failing to resolve
- blackjax
- scikit-learn
179 changes: 132 additions & 47 deletions pymc_experimental/model/marginal/distributions.py
@@ -1,30 +1,55 @@
from collections.abc import Sequence

import numpy as np
import pytensor.tensor as pt

from pymc import Bernoulli, Categorical, DiscreteUniform, SymbolicRandomVariable, logp
from pymc.logprob import conditional_logp
from pymc.logprob.abstract import _logprob
from pymc.distributions import Bernoulli, Categorical, DiscreteUniform
from pymc.logprob.abstract import MeasurableOp, _logprob
from pymc.logprob.basic import conditional_logp, logp
from pymc.pytensorf import constant_fold
from pytensor import Mode, clone_replace, graph_replace, scan
from pytensor import map as scan_map
from pytensor import tensor as pt
from pytensor.graph import vectorize_graph
from pytensor.tensor import TensorType, TensorVariable
from pytensor import Variable
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.mode import Mode
from pytensor.graph import Op, vectorize_graph
from pytensor.graph.replace import clone_replace, graph_replace
from pytensor.scan import map as scan_map
from pytensor.scan import scan
from pytensor.tensor import TensorVariable

from pymc_experimental.distributions import DiscreteMarkovChain


class MarginalRV(SymbolicRandomVariable):
class MarginalRV(OpFromGraph, MeasurableOp):
"""Base class for Marginalized RVs"""

def __init__(self, *args, dims_connections: tuple[tuple[int | None, ...], ...], **kwargs) -> None:
self.dims_connections = dims_connections
super().__init__(*args, **kwargs)

class FiniteDiscreteMarginalRV(MarginalRV):
"""Base class for Finite Discrete Marginalized RVs"""
@property
def support_axes(self) -> tuple[tuple[int, ...], ...]:
"""Dimensions of dependent RVs that belong to the core (non-batched) marginalized variable."""
marginalized_ndim_supp = self.inner_outputs[0].owner.op.ndim_supp
support_axes_vars = []
for dims_connection in self.dims_connections:
ndim = len(dims_connection)
marginalized_supp_axes = ndim - marginalized_ndim_supp
support_axes_vars.append(
tuple(
-i
for i, dim in enumerate(reversed(dims_connection), start=1)
if (dim is None or dim > marginalized_supp_axes)
)
)
return tuple(support_axes_vars)
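For intuition, the axis computation above run by hand on concrete values (illustrative only; matches the docstring example further down):

```python
# dims_connection (None, 0, 1) for a dependent RV with ndim = 3 and a
# scalar marginalized RV (ndim_supp = 0), so marginalized_supp_axes = 3 - 0 = 3.
dims_connection = (None, 0, 1)
marginalized_supp_axes = len(dims_connection) - 0

support_axes = tuple(
    -i
    for i, dim in enumerate(reversed(dims_connection), start=1)
    if (dim is None or dim > marginalized_supp_axes)
)
print(support_axes)  # (-3,): only the leading unconnected axis belongs to the support
```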


class DiscreteMarginalMarkovChainRV(MarginalRV):
"""Base class for Discrete Marginal Markov Chain RVs"""
class MarginalFiniteDiscreteRV(MarginalRV):
"""Base class for Marginalized Finite Discrete RVs"""


class MarginalDiscreteMarkovChainRV(MarginalRV):
"""Base class for Marginalized Discrete Markov Chain RVs"""


def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
@@ -34,7 +59,8 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
return (0, 1)
elif isinstance(op, Categorical):
[p_param] = dist_params
return tuple(range(pt.get_vector_length(p_param)))
[p_param_length] = constant_fold([p_param.shape[-1]])
return tuple(range(p_param_length))
elif isinstance(op, DiscreteUniform):
lower, upper = constant_fold(dist_params)
return tuple(np.arange(lower, upper + 1))
@@ -45,31 +71,81 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
raise NotImplementedError(f"Cannot compute domain for op {op}")
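A minimal sketch of the `constant_fold` pattern used in the Categorical branch above (assumes the parameter has a static shape):

```python
import pytensor.tensor as pt
from pymc.pytensorf import constant_fold

p = pt.as_tensor([0.2, 0.3, 0.5])           # Categorical probabilities, 3 classes
[n_classes] = constant_fold([p.shape[-1]])  # folds the symbolic shape down to the constant 3
domain = tuple(range(n_classes))            # (0, 1, 2)
```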


def _add_reduce_batch_dependent_logps(
marginalized_type: TensorType, dependent_logps: Sequence[TensorVariable]
):
"""Add the logps of dependent RVs while reducing extra batch dims relative to `marginalized_type`."""
def reduce_batch_dependent_logps(
dependent_dims_connections: Sequence[tuple[int | None, ...]],
dependent_ops: Sequence[Op],
dependent_logps: Sequence[TensorVariable],
) -> TensorVariable:
"""Combine the logps of dependent RVs and align them with the marginalized logp.
This requires reducing extra batch dims and transposing when they are not aligned.
idx = pm.Bernoulli(idx, shape=(3, 2)) # 0, 1
pm.Normal("dep1", mu=idx.T[..., None] * 2, shape=(3, 2, 5))
pm.Normal("dep2", mu=idx * 2, shape=(7, 2, 3))
marginalize(idx)
The marginalized op will have dims_connections = [(1, 0, None), (None, 0, 1)]
which tells us we need to reduce the last axis of dep1 logp and the first of dep2 logp,
as well as transpose the remaining axis of dep1 logp before adding the two element-wise.
"""
from pymc_experimental.model.marginal.graph_analysis import get_support_axes

mbcast = marginalized_type.broadcastable
reduced_logps = []
for dependent_logp in dependent_logps:
dbcast = dependent_logp.type.broadcastable
dim_diff = len(dbcast) - len(mbcast)
mbcast_aligned = (True,) * dim_diff + mbcast
vbcast_axis = [i for i, (m, v) in enumerate(zip(mbcast_aligned, dbcast)) if m and not v]
reduced_logps.append(dependent_logp.sum(vbcast_axis))
return pt.add(*reduced_logps)
for dependent_op, dependent_logp, dependent_dims_connection in zip(
dependent_ops, dependent_logps, dependent_dims_connections
):
if dependent_logp.type.ndim > 0:
# Find which support axes implied by the MarginalRV need to be reduced
# Some may have already been reduced by the logp expression of the dependent RV (e.g., multivariate RVs)
dep_supp_axes = get_support_axes(dependent_op)[0]

# Dependent RV support axes are already collapsed in the logp, so we ignore them
supp_axes = [
-i
for i, dim in enumerate(reversed(dependent_dims_connection), start=1)
if (dim is None and -i not in dep_supp_axes)
]
dependent_logp = dependent_logp.sum(supp_axes)

@_logprob.register(FiniteDiscreteMarginalRV)
def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
# Clone the inner RV graph of the Marginalized RV
marginalized_rvs_node = op.make_node(*inputs)
marginalized_rv, *inner_rvs = clone_replace(
# Finally, we need to align the dependent logp batch dimensions with the marginalized logp
dims_alignment = [dim for dim in dependent_dims_connection if dim is not None]
dependent_logp = dependent_logp.transpose(*dims_alignment)

reduced_logps.append(dependent_logp)

reduced_logp = pt.add(*reduced_logps)
return reduced_logp


def align_logp_dims(dims: tuple[int | None, ...], logp: TensorVariable) -> TensorVariable:
"""Align the logp with the order specified in dims."""
dims_alignment = [dim for dim in dims if dim is not None]
return logp.transpose(*dims_alignment)
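For example, with illustrative values (NumPy as a stand-in for the symbolic logp):

```python
import numpy as np

logp = np.zeros((3, 5))
dims = (1, 0, None)  # None axes were already reduced away
alignment = [dim for dim in dims if dim is not None]  # [1, 0]
aligned = logp.transpose(*alignment)  # shape (5, 3)
```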


def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]:
"""Inline the inner graph (outputs) of an OpFromGraph Op.
Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
the inner graph.
"""
return clone_replace(
op.inner_outputs,
replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
replace=tuple(zip(op.inner_inputs, inputs)),
)
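A self-contained sketch of wrapping and then inlining an `OpFromGraph` (not part of the commit):

```python
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph
from pytensor.graph.replace import clone_replace

x = pt.scalar("x")
ofg = OpFromGraph([x], [x + 1, x * 2])  # wrap a two-output graph in a single Op

y = pt.scalar("y")
wrapped_out1, wrapped_out2 = ofg(y)  # the wrapped Op applied to a new input

# "Unwrap": rebuild the inner outputs directly on top of the outer input
inlined_out1, inlined_out2 = clone_replace(
    ofg.inner_outputs, replace=tuple(zip(ofg.inner_inputs, [y]))
)
```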


DUMMY_ZERO = pt.constant(0, name="dummy_zero")


@_logprob.register(MarginalFiniteDiscreteRV)
def finite_discrete_marginal_rv_logp(op: MarginalFiniteDiscreteRV, values, *inputs, **kwargs):
# Clone the inner RV graph of the Marginalized RV
marginalized_rv, *inner_rvs = inline_ofg_outputs(op, inputs)

# Obtain the joint_logp graph of the inner RV graph
inner_rv_values = dict(zip(inner_rvs, values))
marginalized_vv = marginalized_rv.clone()
@@ -78,8 +154,10 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):

# Reduce logp dimensions corresponding to broadcasted variables
marginalized_logp = logps_dict.pop(marginalized_vv)
joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps(
marginalized_rv.type, logps_dict.values()
joint_logp = marginalized_logp + reduce_batch_dependent_logps(
dependent_dims_connections=op.dims_connections,
dependent_ops=[inner_rv.owner.op for inner_rv in inner_rvs],
dependent_logps=[logps_dict[value] for value in values],
)

# Compute the joint_logp for all possible n values of the marginalized RV. We assume
@@ -116,21 +194,20 @@ def logp_fn(marginalized_rv_const, *non_sequences):
mode=Mode().including("local_remove_check_parameter"),
)

joint_logps = pt.logsumexp(joint_logps, axis=0)
joint_logp = pt.logsumexp(joint_logps, axis=0)

# Align logp with non-collapsed batch dimensions of first RV
joint_logp = align_logp_dims(dims=op.dims_connections[0], logp=joint_logp)

# We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
return joint_logps, *(pt.constant(0),) * (len(values) - 1)
dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
return joint_logp, *dummy_logps
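Numerically, the `logsumexp` above just sums the discrete variable out of the joint probability. A toy check, independent of the code in this commit:

```python
import numpy as np
from scipy.special import logsumexp

logp_idx = np.log([0.3, 0.7])            # p(idx) over the domain {0, 1}
logp_y_given_idx = np.log([0.02, 0.10])  # p(y | idx) at some observed y

logp_y = logsumexp(logp_idx + logp_y_given_idx)  # log p(y)
np.testing.assert_allclose(np.exp(logp_y), 0.3 * 0.02 + 0.7 * 0.10)
```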


@_logprob.register(DiscreteMarginalMarkovChainRV)
@_logprob.register(MarginalDiscreteMarkovChainRV)
def marginal_hmm_logp(op, values, *inputs, **kwargs):
marginalized_rvs_node = op.make_node(*inputs)
inner_rvs = clone_replace(
op.inner_outputs,
replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
)
chain_rv, *dependent_rvs = inline_ofg_outputs(op, inputs)

chain_rv, *dependent_rvs = inner_rvs
P, n_steps_, init_dist_, rng = chain_rv.owner.inputs
domain = pt.arange(P.shape[-1], dtype="int32")

@@ -145,8 +222,10 @@ def marginal_hmm_logp(op, values, *inputs, **kwargs):
logp_emissions_dict = conditional_logp(dict(zip(dependent_rvs, values)))

# Reduce and add the batch dims beyond the chain dimension
reduced_logp_emissions = _add_reduce_batch_dependent_logps(
chain_rv.type, logp_emissions_dict.values()
reduced_logp_emissions = reduce_batch_dependent_logps(
dependent_dims_connections=op.dims_connections,
dependent_ops=[dependent_rv.owner.op for dependent_rv in dependent_rvs],
dependent_logps=[logp_emissions_dict[value] for value in values],
)

# Add a batch dimension for the domain of the chain
@@ -185,7 +264,13 @@ def step_alpha(logp_emission, log_alpha, log_P):
# Final logp is just the sum of the last scan state
joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0)

# Align logp with non-collapsed batch dimensions of first RV
remaining_dims_first_emission = list(op.dims_connections[0])
# The last dim of chain_rv was removed when computing the logp
remaining_dims_first_emission.remove(chain_rv.type.ndim - 1)
joint_logp = align_logp_dims(remaining_dims_first_emission, joint_logp)

# If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
# return is the joint probability of everything together, but PyMC still expects one logp for each one.
dummy_logps = (pt.constant(0),) * (len(values) - 1)
# return is the joint probability of everything together, but PyMC still expects one logp for each emission stream.
dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
return joint_logp, *dummy_logps
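For reference, a minimal NumPy version of the forward pass that the `scan` above computes (toy sketch, not the committed implementation):

```python
import numpy as np
from scipy.special import logsumexp

log_P = np.log([[0.9, 0.1],
                [0.2, 0.8]])    # transition matrix, rows index the previous state
log_alpha = np.log([0.5, 0.5])  # init_dist over the hidden states

for log_emission in np.log([[0.8, 0.3],
                            [0.1, 0.6]]):  # per-step emission likelihoods per state
    # alpha'[j] = emission[j] * sum_i alpha[i] * P[i, j]
    log_alpha = log_emission + logsumexp(log_alpha[:, None] + log_P, axis=0)

logp = logsumexp(log_alpha)  # marginal loglikelihood of the emission sequence
```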