From 852e49f32aedad4a0b3c5ddb33684a59442a6b87 Mon Sep 17 00:00:00 2001
From: Antony Lee
Date: Sun, 1 Jan 2023 20:58:46 +0100
Subject: [PATCH] Use np.linalg.slogdet where possible.

---
 lib/hmmlearn/_kl_divergence.py | 14 ++++++++------
 lib/hmmlearn/_utils.py         | 11 +++++++++++
 lib/hmmlearn/vhmm.py           |  5 ++---
 3 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/lib/hmmlearn/_kl_divergence.py b/lib/hmmlearn/_kl_divergence.py
index fa78c703..5fbdb7cd 100644
--- a/lib/hmmlearn/_kl_divergence.py
+++ b/lib/hmmlearn/_kl_divergence.py
@@ -7,6 +7,8 @@
 import numpy as np
 from scipy.special import gammaln, digamma
 
+from . import _utils
+
 
 def kl_dirichlet(q, p):
     """
@@ -54,10 +56,10 @@ def kl_multivariate_normal_distribution(mean_q, covar_q, mean_p, covar_p):
     D = mean_q.shape[0]
 
     # These correspond to the four terms in the ~wpenny paper documented above
-    return (0.5 * np.log(np.linalg.det(covar_p) / np.linalg.det(covar_q))
-            + 0.5 * np.trace(precision_p @ covar_q)
-            + 0.5 * mean_diff @ precision_p @ mean_diff
-            - D/2)
+    return .5 * (_utils.logdet(covar_p) - _utils.logdet(covar_q)
+                 + np.trace(precision_p @ covar_q)
+                 + mean_diff @ precision_p @ mean_diff
+                 - D)
 
 
 def kl_gamma_distribution(b_q, c_q, b_p, c_p):
@@ -104,12 +106,12 @@ def _E(dof, scale):
     r"""
     $L(a, B) = \int \mathcal{Wishart}(\Gamma; a, B) \log |\Gamma| d\Gamma$
     """
-    return (-np.log(np.linalg.det(scale / 2))
+    return (-_utils.logdet(scale / 2)
             + digamma((dof - np.arange(scale.shape[0])) / 2).sum())
 
 
 def _logZ(dof, scale):
     D = scale.shape[0]
     return ((D * (D - 1) / 4) * np.log(np.pi)
-            - dof / 2 * np.log(np.linalg.det(scale / 2))
+            - dof / 2 * _utils.logdet(scale / 2)
             + gammaln((dof - np.arange(scale.shape[0])) / 2).sum())
diff --git a/lib/hmmlearn/_utils.py b/lib/hmmlearn/_utils.py
index 5b2ad8c6..f5456dd6 100644
--- a/lib/hmmlearn/_utils.py
+++ b/lib/hmmlearn/_utils.py
@@ -1,9 +1,20 @@
 """Private utilities."""
 
+import warnings
+
 import numpy as np
 from sklearn.utils.validation import NotFittedError
 
 
+def logdet(a):
+    sign, logdet = np.linalg.slogdet(a)
+    if (sign < 0).any():
+        warnings.warn("invalid value encountered in log", RuntimeWarning)
+        return np.where(sign < 0, np.nan, logdet)
+    else:
+        return logdet
+
+
 def split_X_lengths(X, lengths):
     if lengths is None:
         return [X]
diff --git a/lib/hmmlearn/vhmm.py b/lib/hmmlearn/vhmm.py
index 42377039..2b4bbe6b 100644
--- a/lib/hmmlearn/vhmm.py
+++ b/lib/hmmlearn/vhmm.py
@@ -648,12 +648,11 @@ def _compute_subnorm_log_likelihood(self, X):
             scale_posterior_ = fill_covars(self.scale_posterior_,
                 self.covariance_type, self.n_components, self.n_features)
         W_k = np.linalg.inv(scale_posterior_)
-        term1 += self.n_features * np.log(2)
-        term1 += np.log(np.linalg.det(W_k))
+        term1 += self.n_features * np.log(2) + _utils.logdet(W_k)
         term1 /= 2
 
         # We ignore the constant that is typically excluded in the literature
-        # term2 = self.n_features * log(2 * M_PI ) / 2
+        # term2 = self.n_features * log(2 * M_PI) / 2
         term2 = 0
         term3 = self.n_features / self.beta_posterior_
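
Background on the change: np.log(np.linalg.det(a)) goes to -inf or +inf when the
determinant of a under- or overflows float64, even though the log-determinant itself
is a perfectly ordinary number; np.linalg.slogdet computes the sign and log-magnitude
directly and sidesteps that. A minimal sketch (the 200x200 toy covariance below is an
assumed example, not hmmlearn data):

    import numpy as np

    cov = 1e-2 * np.eye(200)             # toy covariance whose determinant underflows
    print(np.linalg.det(cov))            # 0.0 -- 1e-400 is below the float64 range
    print(np.log(np.linalg.det(cov)))    # -inf, with a divide-by-zero RuntimeWarning
    sign, logdet = np.linalg.slogdet(cov)
    print(sign, logdet)                  # 1.0 -921.034... (= 200 * log(1e-2)), finite

The new _utils.logdet helper wraps np.linalg.slogdet and additionally maps a
negative-sign result to NaN with a RuntimeWarning, mirroring what np.log of a negative
determinant used to produce.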