Hybrid/simplify #1388

Merged: 29 commits, Jan 17, 2023
Commits (29):
- 34a9aef: normalizationConstants returns all constants as a DecisionTreeFactor (dellaert, Jan 12, 2023)
- 1dcc6dd: All tests still work with zero constant! (dellaert, Jan 12, 2023)
- 03ad393: Removed FactorAndConstant, no longer needed (dellaert, Jan 13, 2023)
- 906330f: Add discrete contribution to logProbability (dellaert, Jan 13, 2023)
- 681c75c: Expose toFactorGraph to wrapper (dellaert, Jan 13, 2023)
- dfef2c2: Simplify elimination (dellaert, Jan 13, 2023)
- 070cdb7: insert_or_assign (dellaert, Jan 14, 2023)
- 96e3eb7: Some test refactoring (dellaert, Jan 14, 2023)
- c22b2ca: Improved docs (dellaert, Jan 15, 2023)
- 5b0408c: Check for error>0 and proper normalization constant (dellaert, Jan 16, 2023)
- 191e614: Fix print (dellaert, Jan 16, 2023)
- 57e59d1: Compute log-normalization constant as the max of the individual norma… (dellaert, Jan 16, 2023)
- 7a41180: Refactored tests and removed incorrect (R not upper-triangular) test. (dellaert, Jan 16, 2023)
- 207c9b7: Implemented the "hidden constant" scheme. (dellaert, Jan 17, 2023)
- 3a446d7: Explicitly implement logNormalizationConstant (dellaert, Jan 17, 2023)
- 202a5a3: Fixed toFactorGraph and added test to verify (dellaert, Jan 17, 2023)
- a5951d8: Fixed test to work with "hidden constant" scheme (dellaert, Jan 17, 2023)
- 8357fc7: Fix python tests (and expose HybridBayesNet.error) (dellaert, Jan 17, 2023)
- e31884c: Eradicated GraphAndConstant (dellaert, Jan 17, 2023)
- 9af7236: Added DEBUG_MARGINALS flag (dellaert, Jan 17, 2023)
- 519b2bb: Added comment (dellaert, Jan 17, 2023)
- 32d69a3: Trap if conditional==null. (dellaert, Jan 17, 2023)
- f4859f0: Fix logProbability tests (dellaert, Jan 17, 2023)
- 4283925: Ratio test succeeds on fg, but not on posterior yet (dellaert, Jan 17, 2023)
- b494a61: Removed obsolete normalizationConstants method (dellaert, Jan 17, 2023)
- 892759e: Add math related to hybrid classes (dellaert, Jan 17, 2023)
- c3ca31f: Added partial elimination test (dellaert, Jan 17, 2023)
- e444962: Added correction with the normalization constant in the second elimin… (dellaert, Jan 17, 2023)
- f714c4a: Merge branch 'develop' into hybrid/simplify (dellaert, Jan 17, 2023)
Commit e444962aad06985d6b75517bc7a3fac8d98b7b4b (dellaert committed Jan 17, 2023):
Added correction with the normalization constant in the second elimination path.
gtsam/hybrid/HybridGaussianFactorGraph.cpp (67 changes: 39 additions & 28 deletions)
@@ -220,12 +220,11 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
   // FG has a nullptr as we're looping over the factors.
   factorGraphTree = removeEmpty(factorGraphTree);
 
-  using EliminationPair = std::pair<boost::shared_ptr<GaussianConditional>,
-                                    GaussianMixtureFactor::sharedFactor>;
+  using Result = std::pair<boost::shared_ptr<GaussianConditional>,
+                           GaussianMixtureFactor::sharedFactor>;
 
   // This is the elimination method on the leaf nodes
-  auto eliminateFunc =
-      [&](const GaussianFactorGraph &graph) -> EliminationPair {
+  auto eliminate = [&](const GaussianFactorGraph &graph) -> Result {
     if (graph.empty()) {
       return {nullptr, nullptr};
     }
@@ -234,21 +233,17 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
     gttic_(hybrid_eliminate);
 #endif
 
-    boost::shared_ptr<GaussianConditional> conditional;
-    boost::shared_ptr<GaussianFactor> newFactor;
-    boost::tie(conditional, newFactor) =
-        EliminatePreferCholesky(graph, frontalKeys);
+    auto result = EliminatePreferCholesky(graph, frontalKeys);
 
 #ifdef HYBRID_TIMING
     gttoc_(hybrid_eliminate);
 #endif
 
-    return {conditional, newFactor};
+    return result;
   };
 
   // Perform elimination!
-  DecisionTree<Key, EliminationPair> eliminationResults(factorGraphTree,
-                                                        eliminateFunc);
+  DecisionTree<Key, Result> eliminationResults(factorGraphTree, eliminate);
 
 #ifdef HYBRID_TIMING
   tictoc_print_();
@@ -264,30 +259,46 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
   auto gaussianMixture = boost::make_shared<GaussianMixture>(
       frontalKeys, continuousSeparator, discreteSeparator, conditionals);
 
-  // If there are no more continuous parents, then we should create a
-  // DiscreteFactor here, with the error for each discrete choice.
   if (continuousSeparator.empty()) {
-    auto probPrime = [&](const EliminationPair &pair) {
-      // This is the unnormalized probability q(μ;m) at the mean.
-      // q(μ;m) = exp(-error(μ;m)) * sqrt(det(2π Σ_m))
-      // The factor has no keys, just contains the residual.
+    // If there are no more continuous parents, then we create a
+    // DiscreteFactor here, with the error for each discrete choice.
+
+    // Integrate the probability mass in the last continuous conditional using
+    // the unnormalized probability q(μ;m) = exp(-error(μ;m)) at the mean.
+    // discrete_probability = exp(-error(μ;m)) * sqrt(det(2π Σ_m))
+    auto probability = [&](const Result &pair) -> double {
       static const VectorValues kEmpty;
-      return pair.second ? exp(-pair.second->error(kEmpty)) /
-                               pair.first->normalizationConstant()
-                         : 1.0;
+      // If the factor is not null, it has no keys, just contains the residual.
+      const auto &factor = pair.second;
+      if (!factor) return 1.0;  // TODO(dellaert): not loving this.
+      return exp(-factor->error(kEmpty)) / pair.first->normalizationConstant();
     };
 
-    const auto discreteFactor = boost::make_shared<DecisionTreeFactor>(
-        discreteSeparator,
-        DecisionTree<Key, double>(eliminationResults, probPrime));
-
+    DecisionTree<Key, double> probabilities(eliminationResults, probability);
     return {boost::make_shared<HybridConditional>(gaussianMixture),
-            discreteFactor};
+            boost::make_shared<DecisionTreeFactor>(discreteSeparator,
+                                                   probabilities)};
   } else {
-    // Create a resulting GaussianMixtureFactor on the separator.
+    // Otherwise, we create a resulting GaussianMixtureFactor on the separator,
+    // taking care to correct for the conditional's normalization constant.
+
+    // Correct for the normalization constant used up by the conditional
+    auto correct = [&](const Result &pair) -> GaussianFactor::shared_ptr {
+      const auto &factor = pair.second;
+      if (!factor) return factor;  // TODO(dellaert): not loving this.
+      auto hf = boost::dynamic_pointer_cast<HessianFactor>(factor);
+      if (!hf) throw std::runtime_error("Expected HessianFactor!");
+      hf->constantTerm() += 2.0 * pair.first->logNormalizationConstant();
+      return hf;
+    };
+
+    GaussianMixtureFactor::Factors correctedFactors(eliminationResults,
+                                                    correct);
+    const auto mixtureFactor = boost::make_shared<GaussianMixtureFactor>(
+        continuousSeparator, discreteSeparator, correctedFactors);
+
     return {boost::make_shared<HybridConditional>(gaussianMixture),
-            boost::make_shared<GaussianMixtureFactor>(
-                continuousSeparator, discreteSeparator, newFactors)};
+            mixtureFactor};
   }
 }

Review thread on `if (!factor) return 1.0;  // TODO(dellaert): not loving this.`:

Collaborator: How could this branch be taken?

dellaert (Member, Author): Yeah, whenever we use a BayesTree there is some way (that I don't yet understand) that nullptrs are used, and we are dealing with them all over the place. Can you explain? Or @varunagrawal?

dellaert marked this conversation as resolved.

Review thread on `hf->constantTerm() += 2.0 * pair.first->logNormalizationConstant();`: dellaert marked this conversation as resolved.
