diff --git a/.gitignore b/.gitignore
index 66b6392026..a7893c30bc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -44,3 +44,5 @@ htmlcov
 # text editors
 .vscode/
+
+venv/
diff --git a/src/pyhf/optimize/mixins.py b/src/pyhf/optimize/mixins.py
index 19e150f069..f2bd7c4f82 100644
--- a/src/pyhf/optimize/mixins.py
+++ b/src/pyhf/optimize/mixins.py
@@ -8,8 +8,11 @@
 from pyhf.optimize.common import shim
 from pyhf.tensor.manager import get_backend
 
+
 log = logging.getLogger(__name__)
 
+__all__ = ("OptimizerMixin",)
+
 
 class OptimizerMixin:
     """Mixin Class to build optimizers."""
@@ -50,6 +53,7 @@ def _internal_minimize(
             do_grad=do_grad,
             par_names=par_names,
         )
+
         result = self._minimize(
             minimizer,
             func,
@@ -67,7 +71,18 @@ def _internal_minimize(
             raise exceptions.FailedMinimization(result)
         return result
 
-    def _internal_postprocess(self, fitresult, stitch_pars, return_uncertainties=False):
+    def _internal_postprocess(
+        self,
+        fitresult,
+        stitch_pars,
+        using_minuit,
+        return_uncertainties=False,
+        uncertainties=None,
+        hess_inv=None,
+        calc_correlations=False,
+        fixed_vals=None,
+        init_pars=None,
+    ):
         """
         Post-process the fit result.
 
@@ -80,17 +95,29 @@ def _internal_postprocess(self, fitresult, stitch_pars, return_uncertainties=Fal
         fitted_pars = stitch_pars(tensorlib.astensor(fitresult.x))
 
         # check if uncertainties were provided (and stitch just in case)
-        uncertainties = getattr(fitresult, 'unc', None)
+        # prefer uncertainties from the fit result, falling back to any that
+        # were passed in (e.g. computed via automatic differentiation); an
+        # explicit None check avoids ambiguous truth values for array types
+        fit_uncertainties = getattr(fitresult, 'unc', None)
+        uncertainties = (
+            fit_uncertainties if fit_uncertainties is not None else uncertainties
+        )
         if uncertainties is not None:
             # extract number of fixed parameters
             num_fixed_pars = len(fitted_pars) - len(fitresult.x)
-            # FIXME: Set uncertainties for fixed parameters to 0 manually
-            # https://github.com/scikit-hep/iminuit/issues/762
-            # https://github.com/scikit-hep/pyhf/issues/1918
-            # https://github.com/scikit-hep/cabinetry/pull/346
-            uncertainties = np.where(fitresult.minuit.fixed, 0.0, uncertainties)
-
+            # Set uncertainties for fixed parameters to 0 manually
+            if fixed_vals is not None:  # check for fixed vals
+                if using_minuit:
+                    # See related discussion here:
+                    # https://github.com/scikit-hep/iminuit/issues/762
+                    # https://github.com/scikit-hep/pyhf/issues/1918
+                    # https://github.com/scikit-hep/cabinetry/pull/346
+                    uncertainties = np.where(fitresult.minuit.fixed, 0.0, uncertainties)
+                else:
+                    # Not using minuit, so `fitresult.minuit.fixed` is not
+                    # available here: build the fixed-parameter mask manually
+                    fixed_bools = [False] * len(init_pars)
+                    for index, _ in fixed_vals:
+                        fixed_bools[index] = True
+                    uncertainties = tensorlib.where(
+                        tensorlib.astensor(fixed_bools, dtype="bool"),
+                        tensorlib.astensor(0.0),
+                        uncertainties,
+                    )
             # stitch in zero-uncertainty for fixed values
             uncertainties = stitch_pars(
                 tensorlib.astensor(uncertainties),
@@ -99,24 +126,57 @@ def _internal_postprocess(self, fitresult, stitch_pars, return_uncertainties=Fal
         if return_uncertainties:
             fitted_pars = tensorlib.stack([fitted_pars, uncertainties], axis=1)
 
-        correlations = getattr(fitresult, 'corr', None)
-        if correlations is not None:
+        cov = getattr(fitresult, 'hess_inv', None)
+        if cov is None and hess_inv is not None:
+            cov = hess_inv
+
+        # we also need to zero out the covariance matrix entries of fixed parameters
+        # NOTE: minuit already does this
+        # (https://github.com/scikit-hep/iminuit/issues/762#issuecomment-1207436406)
+        if fixed_vals is not None and not using_minuit:
+            fixed_bools = [False] * len(init_pars)
+            for index, _ in fixed_vals:
+                fixed_bools[index] = True
+            # Convert fixed_bools to a tensor and reshape into a column vector
+            fixed_mask = tensorlib.reshape(
+                tensorlib.astensor(fixed_bools, dtype="bool"), (-1, 1)
+            )
+            # Create 2D masks for rows and columns
+            row_mask = fixed_mask
+            col_mask = tensorlib.transpose(fixed_mask)
+
+            # Use logical OR to combine the masks
+            final_mask = row_mask | col_mask
+
+            # Zero out elements of the covariance matrix where the mask is True
+            cov = tensorlib.where(
+                final_mask, tensorlib.astensor(0.0), tensorlib.astensor(cov)
+            )
+
+        correlations_from_fit = getattr(fitresult, 'corr', None)
+        if correlations_from_fit is None and calc_correlations:
+            correlations_from_fit = cov / tensorlib.outer(uncertainties, uncertainties)
+            # fixed parameters have zero uncertainty, so guard against the
+            # resulting division by zero by zeroing out non-finite entries
+            correlations_from_fit = tensorlib.where(
+                tensorlib.isfinite(correlations_from_fit),
+                correlations_from_fit,
+                tensorlib.astensor(0.0),
+            )
+
+        if correlations_from_fit is not None and not using_minuit:
             _zeros = tensorlib.zeros(num_fixed_pars)
             # possibly a more elegant way to do this
             stitched_columns = [
                 stitch_pars(tensorlib.astensor(column), stitch_with=_zeros)
-                for column in zip(*correlations)
+                for column in zip(*correlations_from_fit)
             ]
             stitched_rows = [
                 stitch_pars(tensorlib.astensor(row), stitch_with=_zeros)
                 for row in zip(*stitched_columns)
             ]
-            correlations = tensorlib.stack(stitched_rows, axis=1)
+            correlations_from_fit = tensorlib.stack(stitched_rows, axis=1)
 
         fitresult.x = fitted_pars
         fitresult.fun = tensorlib.astensor(fitresult.fun)
         fitresult.unc = uncertainties
-        fitresult.corr = correlations
+        fitresult.hess_inv = cov
+        fitresult.corr = correlations_from_fit
 
         return fitresult
@@ -164,6 +224,10 @@ def minimize(
             - minimum (:obj:`float`): if ``return_fitted_val`` flagged, return minimized objective value
             - result (:class:`scipy.optimize.OptimizeResult`): if ``return_result_obj`` flagged
         """
+        # only need the minimizer name here, to check whether we are using
+        # minuit and so decide how to handle the uncertainty calculation later
+        using_minuit = hasattr(self, "name") and self.name == "minuit"
+
         # Configure do_grad based on backend "automagically" if not set by user
         tensorlib, _ = get_backend()
         do_grad = tensorlib.default_do_grad if do_grad is None else do_grad
@@ -194,8 +258,30 @@ def minimize(
         result = self._internal_minimize(
             **minimizer_kwargs, options=kwargs, par_names=par_names
         )
+
+        # compute uncertainties with automatic differentiation
+        if not using_minuit and tensorlib.name in ['tensorflow', 'jax', 'pytorch']:
+            # stitch in missing parameters (e.g. fixed parameters)
+            all_pars = stitch_pars(tensorlib.astensor(result.x))
+            hess_inv = tensorlib.fisher_cov(pdf, all_pars, data)
+            uncertainties = tensorlib.sqrt(tensorlib.diagonal(hess_inv))
+            calc_correlations = True
+        else:
+            hess_inv = None
+            uncertainties = None
+            calc_correlations = False
+
+        # uncertainties of fixed parameters are set to 0 during postprocessing
         result = self._internal_postprocess(
-            result, stitch_pars, return_uncertainties=return_uncertainties
+            result,
+            stitch_pars,
+            using_minuit,
+            return_uncertainties=return_uncertainties,
+            uncertainties=uncertainties,
+            hess_inv=hess_inv,
+            calc_correlations=calc_correlations,
+            fixed_vals=fixed_vals,
+            init_pars=init_pars,
        )
 
         _returns = [result.x]
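The fixed-parameter handling above zeroes out both the row and the column of the covariance matrix for every fixed parameter, using a broadcast mask. A minimal numpy sketch of that masking logic (the matrix and the fixed-parameter choice here are illustrative, not taken from this changeset):

    import numpy as np

    cov = np.array(
        [
            [0.04, 0.01, 0.00],
            [0.01, 0.09, 0.02],
            [0.00, 0.02, 0.16],
        ]
    )
    fixed_bools = [False, True, False]  # parameter 1 is fixed

    # column vector of the fixed-parameter flags
    fixed_mask = np.reshape(np.asarray(fixed_bools, dtype=bool), (-1, 1))
    # True wherever the row OR the column belongs to a fixed parameter
    final_mask = fixed_mask | np.transpose(fixed_mask)
    cov = np.where(final_mask, 0.0, cov)

    print(cov)
    # [[0.04 0.   0.  ]
    #  [0.   0.   0.  ]
    #  [0.   0.   0.16]]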
diff --git a/src/pyhf/optimize/opt_minuit.py b/src/pyhf/optimize/opt_minuit.py
index 4f1edb3dca..17d11a16dc 100644
--- a/src/pyhf/optimize/opt_minuit.py
+++ b/src/pyhf/optimize/opt_minuit.py
@@ -5,6 +5,8 @@
 import scipy
 import iminuit
 
+__all__ = ("minuit_optimizer",)
+
 
 class minuit_optimizer(OptimizerMixin):
     """
diff --git a/src/pyhf/tensor/jax_backend.py b/src/pyhf/tensor/jax_backend.py
index 438f3e7325..172c03d8a2 100644
--- a/src/pyhf/tensor/jax_backend.py
+++ b/src/pyhf/tensor/jax_backend.py
@@ -2,7 +2,7 @@
 
 config.update('jax_enable_x64', True)
 
-from jax import Array
+from jax import Array, hessian
 import jax.numpy as jnp
 from jax.scipy.special import gammaln, xlogy
 from jax.scipy import special
@@ -622,3 +622,41 @@ def transpose(self, tensor_in):
             .. versionadded:: 0.7.0
         """
         return tensor_in.transpose()
+
+    def fisher_cov(self, model, pars, data):
+        """Calculate the inverse of the Fisher information matrix to estimate the covariance of the maximum likelihood estimate.
+        See the Cramér-Rao bound for details of the derivation.
+
+        Args:
+            model (:obj:`pyhf.pdf.Model`): The statistical model adhering to the schema ``model.json``.
+            pars (:obj:`tensor`): The (MLE) model parameters at which to evaluate the uncertainty.
+            data (:obj:`tensor`): The observed data.
+
+        Returns:
+            JAX ndarray: The covariance matrix of the maximum likelihood estimate.
+        """
+        return jnp.linalg.inv(
+            -hessian(lambda pars, data: model.logpdf(pars, data)[0])(pars, data)
+        )
+
+    def diagonal(self, tensor_in):
+        """
+        Return the diagonal elements of the tensor.
+
+        Example:
+            >>> import pyhf
+            >>> pyhf.set_backend("jax")
+            >>> tensor = pyhf.tensorlib.astensor([[1.0, 0.0], [0.0, 1.0]])
+            >>> tensor
+            Array([[1., 0.],
+                   [0., 1.]], dtype=float64)
+            >>> pyhf.tensorlib.diagonal(tensor)
+            Array([1., 1.], dtype=float64)
+
+        Args:
+            tensor_in (:obj:`tensor`): The input tensor object.
+
+        Returns:
+            JAX ndarray: The diagonal of the input tensor.
+        """
+        return jnp.diag(tensor_in)
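To make the Cramér-Rao construction in ``fisher_cov`` concrete: the covariance estimate is the inverse of the negated Hessian of the log-likelihood, evaluated at the maximum likelihood estimate. A self-contained JAX sketch, with a Gaussian log-likelihood standing in for ``model.logpdf`` (the model, parameter names, and data values here are illustrative assumptions, not part of pyhf):

    import jax
    import jax.numpy as jnp
    from jax.scipy.stats import norm

    jax.config.update("jax_enable_x64", True)

    def logpdf(pars, data):
        # total log-likelihood of i.i.d. Gaussian observations
        mu, sigma = pars
        return jnp.sum(norm.logpdf(data, mu, sigma))

    data = jnp.array([9.8, 10.1, 10.3, 9.9, 10.4])
    # closed-form Gaussian MLE for (mu, sigma)
    mle = jnp.array([jnp.mean(data), jnp.std(data)])

    # covariance estimate: inverse of the observed Fisher information
    cov = jnp.linalg.inv(-jax.hessian(logpdf)(mle, data))
    uncertainties = jnp.sqrt(jnp.diag(cov))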
diff --git a/src/pyhf/tensor/numpy_backend.py b/src/pyhf/tensor/numpy_backend.py
index 8e019a7ec9..ec33392e3b 100644
--- a/src/pyhf/tensor/numpy_backend.py
+++ b/src/pyhf/tensor/numpy_backend.py
@@ -3,7 +3,16 @@
 from __future__ import annotations
 
 import logging
-from typing import TYPE_CHECKING, Callable, Generic, Mapping, Sequence, TypeVar, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Generic,
+    Mapping,
+    Sequence,
+    TypeVar,
+    Union,
+)
 
 import numpy as np
@@ -648,3 +657,27 @@ def transpose(self, tensor_in: Tensor[T]) -> ArrayLike:
             .. versionadded:: 0.7.0
         """
         return tensor_in.transpose()
+
+    def fisher_cov(self, model: Any, pars: Tensor[T], data: Tensor[T]) -> ArrayLike:
+        # the numpy backend does not support automatic differentiation, so
+        # the Hessian needed for the Fisher information is not available
+        raise NotImplementedError
+
+    def diagonal(self, tensor_in: Tensor[T]) -> ArrayLike:
+        """Return the diagonal elements of the tensor.
+
+        Example:
+            >>> import pyhf
+            >>> pyhf.set_backend("numpy")
+            >>> tensor = pyhf.tensorlib.astensor([[1.0, 0.0], [0.0, 1.0]])
+            >>> tensor
+            array([[1., 0.],
+                   [0., 1.]])
+            >>> pyhf.tensorlib.diagonal(tensor)
+            array([1., 1.])
+
+        Args:
+            tensor_in (:obj:`tensor`): The input tensor object.
+
+        Returns:
+            :class:`numpy.ndarray`: The diagonal of the input tensor.
+        """
+        return np.diag(tensor_in)
diff --git a/src/pyhf/tensor/pytorch_backend.py b/src/pyhf/tensor/pytorch_backend.py
index c962c864f9..c4cc80032a 100644
--- a/src/pyhf/tensor/pytorch_backend.py
+++ b/src/pyhf/tensor/pytorch_backend.py
@@ -2,6 +2,7 @@
 
 import torch
 import torch.autograd
+from torch.func import hessian
 from torch.distributions.utils import broadcast_all
 import logging
 import math
@@ -626,3 +627,41 @@ def transpose(self, tensor_in):
             .. versionadded:: 0.7.0
         """
         return tensor_in.transpose(0, 1)
+
+    def fisher_cov(self, model, pars, data):
+        """Calculate the inverse of the Fisher information matrix to estimate the covariance of the maximum likelihood estimate.
+        See the Cramér-Rao bound for details of the derivation.
+
+        Args:
+            model (:obj:`pyhf.pdf.Model`): The statistical model adhering to the schema ``model.json``.
+            pars (:obj:`tensor`): The (MLE) model parameters at which to evaluate the uncertainty.
+            data (:obj:`tensor`): The observed data.
+
+        Returns:
+            PyTorch FloatTensor: The covariance matrix of the maximum likelihood estimate.
+        """
+        return torch.linalg.inv(
+            -hessian(lambda pars, data: model.logpdf(pars, data)[0])(pars, data)
+        )
+
+    def diagonal(self, tensor_in):
+        """
+        Return the diagonal elements of the tensor.
+
+        Example:
+            >>> import pyhf
+            >>> pyhf.set_backend("pytorch")
+            >>> tensor = pyhf.tensorlib.astensor([[1.0, 0.0], [0.0, 1.0]])
+            >>> tensor
+            tensor([[1., 0.],
+                    [0., 1.]])
+            >>> pyhf.tensorlib.diagonal(tensor)
+            tensor([1., 1.])
+
+        Args:
+            tensor_in (:obj:`tensor`): The input tensor object.
+
+        Returns:
+            PyTorch FloatTensor: The diagonal of the input tensor.
+        """
+        return torch.diagonal(tensor_in)
diff --git a/src/pyhf/tensor/tensorflow_backend.py b/src/pyhf/tensor/tensorflow_backend.py
index a36fa9206f..72b7dc224e 100644
--- a/src/pyhf/tensor/tensorflow_backend.py
+++ b/src/pyhf/tensor/tensorflow_backend.py
@@ -723,3 +723,47 @@ def transpose(self, tensor_in):
             .. versionadded:: 0.7.0
         """
         return tf.transpose(tensor_in)
+
+    def fisher_cov(self, model, pars, data):
+        """Calculate the inverse of the Fisher information matrix to estimate the covariance of the maximum likelihood estimate.
+        See the Cramér-Rao bound for details of the derivation.
+
+        Args:
+            model (:obj:`pyhf.pdf.Model`): The statistical model adhering to the schema ``model.json``.
+            pars (:obj:`tensor`): The (MLE) model parameters at which to evaluate the uncertainty.
+            data (:obj:`tensor`): The observed data.
+
+        Returns:
+            TensorFlow Tensor: The covariance matrix of the maximum likelihood estimate.
+        """
+        with tf.GradientTape() as t2:
+            t2.watch(pars)
+            with tf.GradientTape() as t1:
+                t1.watch(pars)
+                lhood = model.logpdf(pars, data)[0]
+            # gradient of the log-likelihood w.r.t. the parameters
+            g = t1.gradient(lhood, pars)
+        # Hessian as the Jacobian of the gradient
+        hess = t2.jacobian(g, pars)
+        return tf.linalg.inv(-hess)
+
+    def diagonal(self, tensor_in):
+        """Return the diagonal elements of the tensor.
+
+        Example:
+            >>> import pyhf
+            >>> pyhf.set_backend("tensorflow")
+            >>> tensor = pyhf.tensorlib.astensor([[1.0, 0.0], [0.0, 1.0]])
+            >>> tensor
+            <tf.Tensor: shape=(2, 2), dtype=float64, numpy=
+            array([[1., 0.],
+                   [0., 1.]])>
+            >>> pyhf.tensorlib.diagonal(tensor)
+            <tf.Tensor: shape=(2,), dtype=float64, numpy=array([1., 1.])>
+
+        Args:
+            tensor_in (:obj:`tensor`): The input tensor object.
+
+        Returns:
+            TensorFlow Tensor: The diagonal of the input tensor.
+        """
+        return tf.linalg.diag_part(tensor_in)
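Taken together, these changes let non-minuit fits report parameter uncertainties derived from automatic differentiation. A hedged end-to-end sketch of the intended usage, assuming the default scipy optimizer on the jax backend (the model and observed counts are illustrative):

    import pyhf

    pyhf.set_backend("jax")  # "pytorch" and "tensorflow" also provide fisher_cov

    model = pyhf.simplemodels.uncorrelated_background(
        signal=[12.0, 11.0], bkg=[50.0, 52.0], bkg_uncertainty=[3.0, 7.0]
    )
    data = [51, 48] + model.config.auxdata

    # column 0: best-fit parameters, column 1: autodiff-derived uncertainties
    bestfit = pyhf.infer.mle.fit(data, model, return_uncertainties=True)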