Use np.linalg.slogdet where possible.
anntzer committed Jan 2, 2023
1 parent b0991bd commit 852e49f
Showing 3 changed files with 21 additions and 9 deletions.
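For context on the change (an illustration of mine, not part of the commit): np.linalg.det can silently overflow or underflow for matrices whose determinant falls outside float64 range, whereas np.linalg.slogdet returns the sign and the log of the absolute determinant and stays finite. A minimal sketch:

import numpy as np

# det(1e-3 * I_200) = 1e-600, which underflows to 0.0 in float64.
a = 1e-3 * np.eye(200)
print(np.log(np.linalg.det(a)))    # -inf: the determinant underflowed
sign, ld = np.linalg.slogdet(a)
print(sign, ld)                    # 1.0 -1381.55..., computed stably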
14 changes: 8 additions & 6 deletions lib/hmmlearn/_kl_divergence.py
@@ -7,6 +7,8 @@
 import numpy as np
 from scipy.special import gammaln, digamma

+from . import _utils
+

 def kl_dirichlet(q, p):
     """
@@ -54,10 +56,10 @@ def kl_multivariate_normal_distribution(mean_q, covar_q, mean_p, covar_p):
     D = mean_q.shape[0]

     # These correspond to the four terms in the ~wpenny paper documented above
-    return (0.5 * np.log(np.linalg.det(covar_p) / np.linalg.det(covar_q))
-            + 0.5 * np.trace(precision_p @ covar_q)
-            + 0.5 * mean_diff @ precision_p @ mean_diff
-            - D/2)
+    return .5 * (_utils.logdet(covar_p) - _utils.logdet(covar_q)
+                 + np.trace(precision_p @ covar_q)
+                 + mean_diff @ precision_p @ mean_diff
+                 - D)


def kl_gamma_distribution(b_q, c_q, b_p, c_p):
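For reference, the rewritten return value is the standard KL divergence between $D$-dimensional Gaussians $q = \mathcal{N}(\mu_q, \Sigma_q)$ and $p = \mathcal{N}(\mu_p, \Sigma_p)$, i.e. the four terms named in the comment:

$KL(q \| p) = \frac{1}{2}\left[\log|\Sigma_p| - \log|\Sigma_q| + \operatorname{tr}(\Sigma_p^{-1}\Sigma_q) + (\mu_q - \mu_p)^\top \Sigma_p^{-1}(\mu_q - \mu_p) - D\right]$

The refactoring is algebraically neutral: $\log(|\Sigma_p|/|\Sigma_q|)$ becomes $\log|\Sigma_p| - \log|\Sigma_q|$, which lets each log-determinant go through slogdet.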
@@ -104,12 +106,12 @@ def _E(dof, scale):
r"""
$L(a, B) = \int \mathcal{Wishart}(\Gamma; a, B) \log |\Gamma| d\Gamma$
"""
return (-np.log(np.linalg.det(scale / 2))
return (-_utils.logdet(scale / 2)
+ digamma((dof - np.arange(scale.shape[0])) / 2).sum())


def _logZ(dof, scale):
D = scale.shape[0]
return ((D * (D - 1) / 4) * np.log(np.pi)
- dof / 2 * np.log(np.linalg.det(scale / 2))
- dof / 2 * _utils.logdet(scale / 2)
+ gammaln((dof - np.arange(scale.shape[0])) / 2).sum())
11 changes: 11 additions & 0 deletions lib/hmmlearn/_utils.py
@@ -1,9 +1,20 @@
"""Private utilities."""

import warnings

import numpy as np
from sklearn.utils.validation import NotFittedError


def logdet(a):
sign, logdet = np.linalg.slogdet(a)
if (sign < 0).any():
warnings.warn("invalid value encountered in log", RuntimeWarning)
return np.where(sign < 0, np.nan, logdet)
else:
return logdet


def split_X_lengths(X, lengths):
if lengths is None:
return [X]
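A minimal sketch of how the new helper behaves (illustrative usage of mine, assuming hmmlearn is importable; _utils is a private module):

import numpy as np
from hmmlearn import _utils

# SPD matrix whose determinant underflows float64: slogdet stays finite.
a = 1e-200 * np.eye(4)
print(_utils.logdet(a))            # ~ -1842.07
print(np.log(np.linalg.det(a)))    # -inf: det() underflowed to 0.0

# Negative determinant: warns and returns NaN, matching np.log(det).
b = np.diag([-1.0, 2.0])
print(_utils.logdet(b))            # nan, with a RuntimeWarning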
5 changes: 2 additions & 3 deletions lib/hmmlearn/vhmm.py
@@ -648,12 +648,11 @@ def _compute_subnorm_log_likelihood(self, X):
         scale_posterior_ = fill_covars(self.scale_posterior_,
             self.covariance_type, self.n_components, self.n_features)
         W_k = np.linalg.inv(scale_posterior_)
-        term1 += self.n_features * np.log(2)
-        term1 += np.log(np.linalg.det(W_k))
+        term1 += self.n_features * np.log(2) + _utils.logdet(W_k)
         term1 /= 2

         # We ignore the constant that is typically excluded in the literature
-        # term2 = self.n_features * log(2 * M_PI ) / 2
+        # term2 = self.n_features * log(2 * M_PI) / 2
         term2 = 0
         term3 = self.n_features / self.beta_posterior_
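A side note on this hunk (an observation, not part of the commit): W_k is the inverse of scale_posterior_, and for an invertible matrix logdet(inv(S)) equals -logdet(S), so the explicit inverse is not needed just to get the log-determinant. A quick check:

import numpy as np
from hmmlearn import _utils

s = np.array([[2.0, 0.5],
              [0.5, 1.0]])                   # an SPD scale matrix
print(_utils.logdet(np.linalg.inv(s)))       # ~ -0.5596
print(-_utils.logdet(s))                     # same value up to rounding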
