
Hybrid/simplify #1388

Merged Jan 17, 2023 · 29 commits

Changes from 1 commit
34a9aef
normalizationConstants returns all constants as a DecisionTreeFactor
dellaert Jan 12, 2023
1dcc6dd
All tests still work with zero constant!
dellaert Jan 12, 2023
03ad393
Removed FactorAndConstant, no longer needed
dellaert Jan 13, 2023
906330f
Add discrete contribution to logProbability
dellaert Jan 13, 2023
681c75c
Expose toFactorGraph to wrapper
dellaert Jan 13, 2023
dfef2c2
Simplify elimination
dellaert Jan 13, 2023
070cdb7
insert_or_assign
dellaert Jan 14, 2023
96e3eb7
Some test refactoring
dellaert Jan 14, 2023
c22b2ca
Improved docs
dellaert Jan 15, 2023
5b0408c
Check for error>0 and proper normalization constant
dellaert Jan 16, 2023
191e614
Fix print
dellaert Jan 16, 2023
57e59d1
Compute log-normalization constant as the max of the individual normalization constants.
dellaert Jan 16, 2023
7a41180
Refactored tests and removed incorrect (R not upper-triangular) test.
dellaert Jan 16, 2023
207c9b7
Implemented the "hidden constant" scheme.
dellaert Jan 17, 2023
3a446d7
Explicitly implement logNormalizationConstant
dellaert Jan 17, 2023
202a5a3
Fixed toFactorGraph and added test to verify
dellaert Jan 17, 2023
a5951d8
Fixed test to work with "hidden constant" scheme
dellaert Jan 17, 2023
8357fc7
Fix python tests (and expose HybridBayesNet.error)
dellaert Jan 17, 2023
e31884c
Eradicated GraphAndConstant
dellaert Jan 17, 2023
9af7236
Added DEBUG_MARGINALS flag
dellaert Jan 17, 2023
519b2bb
Added comment
dellaert Jan 17, 2023
32d69a3
Trap if conditional==null.
dellaert Jan 17, 2023
f4859f0
Fix logProbability tests
dellaert Jan 17, 2023
4283925
Ratio test succeeds on fg, but not on posterior yet,
dellaert Jan 17, 2023
b494a61
Removed obsolete normalizationConstants method
dellaert Jan 17, 2023
892759e
Add math related to hybrid classes
dellaert Jan 17, 2023
c3ca31f
Added partial elimination test
dellaert Jan 17, 2023
e444962
Added correction with the normalization constant in the second elimin…
dellaert Jan 17, 2023
f714c4a
Merge branch 'develop' into hybrid/simplify
dellaert Jan 17, 2023
Compute log-normalization constant as the max of the individual normalization constants.
dellaert committed Jan 16, 2023
commit 57e59d1237380e83b82e061885b1ab4291c4b6e4
28 changes: 24 additions & 4 deletions gtsam/hybrid/GaussianMixture.cpp
@@ -35,7 +35,16 @@ GaussianMixture::GaussianMixture(
: BaseFactor(CollectKeys(continuousFrontals, continuousParents),
discreteParents),
BaseConditional(continuousFrontals.size()),
conditionals_(conditionals) {}
conditionals_(conditionals) {
// Calculate logConstant_ as the maximum of the log constants of the
// conditionals, by visiting the decision tree:
logConstant_ = -std::numeric_limits<double>::infinity();
conditionals_.visit(
[this](const GaussianConditional::shared_ptr &conditional) {
this->logConstant_ = std::max(this->logConstant_,
conditional->logNormalizationConstant());
});
}
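The constructor change above folds the per-mode log constants into a single maximum by visiting the decision tree. A minimal sketch of the same fold outside GTSAM, using a plain `std::vector` in place of the `Conditionals` tree and a free function in place of the `visit()` lambda (names are illustrative, not the GTSAM API):

```cpp
#include <algorithm>
#include <limits>
#include <vector>

// Fold the per-mode log-normalization constants K_m into K = max_m K_m,
// starting from -infinity so that any finite constant wins, mirroring the
// visit() lambda in the constructor above.
double maxLogNormalizationConstant(const std::vector<double>& logConstants) {
  double K = -std::numeric_limits<double>::infinity();
  for (double Km : logConstants) K = std::max(K, Km);
  return K;
}
```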

/* *******************************************************************************/
const GaussianMixture::Conditionals &GaussianMixture::conditionals() const {
@@ -203,8 +212,7 @@ boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
const KeyVector continuousParentKeys = continuousParents();
const GaussianMixtureFactor::Factors likelihoods(
conditionals_, [&](const GaussianConditional::shared_ptr &conditional) {
return GaussianMixtureFactor::sharedFactor{
conditional->likelihood(given)};
return conditional->likelihood(given);
});
return boost::make_shared<GaussianMixtureFactor>(
continuousParentKeys, discreteParentKeys, likelihoods);
@@ -307,11 +315,23 @@ AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
return DecisionTree<Key, double>(conditionals_, errorFunc);
}

/* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixture::error(
const VectorValues &continuousValues) const {
auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
return logConstant_ + conditional->error(continuousValues) -
conditional->logNormalizationConstant();
};
DecisionTree<Key, double> errorTree(conditionals_, errorFunc);
return errorTree;
}

/* *******************************************************************************/
double GaussianMixture::error(const HybridValues &values) const {
// Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(values.discrete());
return conditional->error(values.continuous()) - conditional->logNormalizationConstant();
return logConstant_ + conditional->error(values.continuous()) -
conditional->logNormalizationConstant();
}
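The shift applied in `error(const HybridValues&)` above can be isolated as a one-liner. A sketch with hypothetical names (`rawError` standing in for `conditional->error(...)`, `Km` for that mode's `logNormalizationConstant()`, `K` for `logConstant_`), not the GTSAM API:

```cpp
// Hidden-constant error shift: given the raw Gaussian error e_m(x;y) >= 0,
// the mode's log constant K_m, and the shared K = max_m K_m, report
//   error(x;y,m) = K + e_m(x;y) - K_m
// which is nonnegative because K >= K_m by construction.
double hiddenConstantError(double rawError, double Km, double K) {
  return K + rawError - Km;
}
```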

/* *******************************************************************************/
28 changes: 22 additions & 6 deletions gtsam/hybrid/GaussianMixture.h
@@ -63,7 +63,8 @@ class GTSAM_EXPORT GaussianMixture
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;

private:
Conditionals conditionals_;
Conditionals conditionals_; ///< a decision tree of Gaussian conditionals.
double logConstant_; ///< log of the normalization constant.

/**
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
@@ -155,6 +156,10 @@ class GTSAM_EXPORT GaussianMixture
/// Returns the continuous keys among the parents.
KeyVector continuousParents() const;

/// The log normalization constant is the max of the individual
/// log-normalization constants.
double logNormalizationConstant() const override { return logConstant_; }

/// Return a discrete factor with possibly varying normalization constants.
/// If there is no variation, return nullptr.
boost::shared_ptr<DecisionTreeFactor> normalizationConstants() const;
@@ -192,18 +197,29 @@ class GTSAM_EXPORT GaussianMixture
* in Conditional.h, should not depend on x, y, or m, only on the parameters
* of the density. Hence, we delegate to the underlying Gaussian
* conditionals, indexed by m, which do satisfy:
*
*
* log(probability_m(x;y)) = K_m - error_m(x;y)
*
* We resolve by having K == 0.0 and
*
* error(x;y,m) = error_m(x;y) - K_m
*
* We resolve by having K == max(K_m) and
*
* error(x;y,m) = error_m(x;y) + K - K_m
*
* which also makes error(x;y,m) >= 0 for all x,y,m.
*
* @param values Continuous values and discrete assignment.
* @return double
*/
double error(const HybridValues &values) const override;
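A numeric sketch of the scheme documented above, assuming two hypothetical modes with log constants K_0 = 1 and K_1 = 3: with K = max(K_m), every shifted error e_m + K - K_m stays nonnegative, and the shift vanishes for the mode whose constant attains the max. This is a standalone illustration, not GTSAM code:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Compute K = max_m K_m and the shifted errors e_m + K - K_m for all
// modes, as in the comment above; inputs are raw errors e_m >= 0 and the
// per-mode log-normalization constants K_m.
std::vector<double> shiftedErrors(const std::vector<double>& rawErrors,
                                  const std::vector<double>& logKs) {
  const double K = *std::max_element(logKs.begin(), logKs.end());
  std::vector<double> out;
  for (std::size_t m = 0; m < rawErrors.size(); ++m)
    out.push_back(rawErrors[m] + K - logKs[m]);
  return out;
}
```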

/**
* @brief Compute error of the GaussianMixture as a tree.
*
* @param continuousValues The continuous VectorValues.
* @return AlgebraicDecisionTree<Key> A decision tree on the discrete keys
* only, with the leaf values as the error for each assignment.
*/
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;

/**
* @brief Compute the logProbability of this Gaussian Mixture.
*