From 05b2d3169f0a9e72e750c679b4dd92c75b7e0cda Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 3 Dec 2022 06:43:46 +0530 Subject: [PATCH 01/12] better printing --- gtsam/hybrid/GaussianMixture.cpp | 2 +- gtsam/hybrid/HybridFactor.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 4819eda657..0b6fcdff7c 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -105,7 +105,7 @@ bool GaussianMixture::equals(const HybridFactor &lf, double tol) const { /* *******************************************************************************/ void GaussianMixture::print(const std::string &s, const KeyFormatter &formatter) const { - std::cout << s; + std::cout << (s.empty() ? "" : s + "\n"); if (isContinuous()) std::cout << "Continuous "; if (isDiscrete()) std::cout << "Discrete "; if (isHybrid()) std::cout << "Hybrid "; diff --git a/gtsam/hybrid/HybridFactor.cpp b/gtsam/hybrid/HybridFactor.cpp index 1216fd9224..b25e97f051 100644 --- a/gtsam/hybrid/HybridFactor.cpp +++ b/gtsam/hybrid/HybridFactor.cpp @@ -81,7 +81,7 @@ bool HybridFactor::equals(const HybridFactor &lf, double tol) const { /* ************************************************************************ */ void HybridFactor::print(const std::string &s, const KeyFormatter &formatter) const { - std::cout << s; + std::cout << (s.empty() ? "" : s + "\n"); if (isContinuous_) std::cout << "Continuous "; if (isDiscrete_) std::cout << "Discrete "; if (isHybrid_) std::cout << "Hybrid "; From 3eaf4cc910fa0cdf34023c4ebd0bf4b37354499b Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 3 Dec 2022 17:00:51 +0530 Subject: [PATCH 02/12] move multifrontal optimize test to testHybridBayesTree and update docstrings --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 3 +-- gtsam/hybrid/tests/testHybridBayesNet.cpp | 19 ------------------- gtsam/hybrid/tests/testHybridBayesTree.cpp | 19 +++++++++++++++++++ .../tests/testHybridNonlinearFactorGraph.cpp | 4 ++-- 4 files changed, 22 insertions(+), 23 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 6218303381..1d52a24afe 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -257,7 +257,6 @@ hybridElimination(const HybridGaussianFactorGraph &factors, // If there are no more continuous parents, then we should create here a // DiscreteFactor, with the error for each discrete choice. if (keysOfSeparator.empty()) { - // TODO(Varun) Use the math from the iMHS_Math-1-indexed document VectorValues empty_values; auto factorProb = [&](const GaussianFactor::shared_ptr &factor) { if (!factor) { @@ -574,7 +573,7 @@ HybridGaussianFactorGraph::eliminateHybridSequential( bayesNet->at(bayesNet->size() - 1); DiscreteKeys discrete_keys = last_conditional->discreteKeys(); - // If not discrete variables, return the eliminated bayes net. + // If no discrete variables, return the eliminated bayes net. if (discrete_keys.size() == 0) { return bayesNet; } diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 8b8ca976b0..d2951ddaa7 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -164,25 +164,6 @@ TEST(HybridBayesNet, Optimize) { EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5)); } -/* ****************************************************************************/ -// Test bayes net multifrontal optimize -TEST(HybridBayesNet, OptimizeMultifrontal) { - Switching s(4); - - Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering(); - HybridBayesTree::shared_ptr hybridBayesTree = - s.linearizedFactorGraph.eliminateMultifrontal(hybridOrdering); - HybridValues delta = hybridBayesTree->optimize(); - - VectorValues expectedValues; - expectedValues.insert(X(0), -0.999904 * Vector1::Ones()); - expectedValues.insert(X(1), -0.99029 * Vector1::Ones()); - expectedValues.insert(X(2), -1.00971 * Vector1::Ones()); - expectedValues.insert(X(3), -1.0001 * Vector1::Ones()); - - EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5)); -} - /* ****************************************************************************/ // Test bayes net error TEST(HybridBayesNet, Error) { diff --git a/gtsam/hybrid/tests/testHybridBayesTree.cpp b/gtsam/hybrid/tests/testHybridBayesTree.cpp index 9bc12c31d7..3992aa023b 100644 --- a/gtsam/hybrid/tests/testHybridBayesTree.cpp +++ b/gtsam/hybrid/tests/testHybridBayesTree.cpp @@ -32,6 +32,25 @@ using noiseModel::Isotropic; using symbol_shorthand::M; using symbol_shorthand::X; +/* ****************************************************************************/ +// Test multifrontal optimize +TEST(HybridBayesTree, OptimizeMultifrontal) { + Switching s(4); + + Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering(); + HybridBayesTree::shared_ptr hybridBayesTree = + s.linearizedFactorGraph.eliminateMultifrontal(hybridOrdering); + HybridValues delta = hybridBayesTree->optimize(); + + VectorValues expectedValues; + expectedValues.insert(X(0), -0.999904 * Vector1::Ones()); + expectedValues.insert(X(1), -0.99029 * Vector1::Ones()); + expectedValues.insert(X(2), -1.00971 * Vector1::Ones()); + expectedValues.insert(X(3), -1.0001 * Vector1::Ones()); + + EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5)); +} + /* ****************************************************************************/ // Test for optimizing a HybridBayesTree with a given assignment. TEST(HybridBayesTree, OptimizeAssignment) { diff --git a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp index f8c61baf7c..e3b7f761a7 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp @@ -386,11 +386,11 @@ TEST(HybridFactorGraph, Partial_Elimination) { auto linearizedFactorGraph = self.linearizedFactorGraph; - // Create ordering. + // Create ordering of only continuous variables. Ordering ordering; for (size_t k = 0; k < self.K; k++) ordering += X(k); - // Eliminate partially. + // Eliminate partially i.e. only continuous part. HybridBayesNet::shared_ptr hybridBayesNet; HybridGaussianFactorGraph::shared_ptr remainingFactorGraph; std::tie(hybridBayesNet, remainingFactorGraph) = From cd3cfa0faa5ebcacc05d7ccbdfc21bcdf505d0f9 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 3 Dec 2022 17:14:11 +0530 Subject: [PATCH 03/12] moved sequential elimination code to HybridEliminationTree --- gtsam/hybrid/HybridEliminationTree.cpp | 12 ++- gtsam/hybrid/HybridEliminationTree.h | 107 +++++++++++++++++++++ gtsam/hybrid/HybridFactor.cpp | 2 +- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 68 ------------- gtsam/hybrid/HybridGaussianFactorGraph.h | 26 ----- gtsam/hybrid/HybridSmoother.cpp | 2 +- 6 files changed, 119 insertions(+), 98 deletions(-) diff --git a/gtsam/hybrid/HybridEliminationTree.cpp b/gtsam/hybrid/HybridEliminationTree.cpp index c2df2dd600..fe91905713 100644 --- a/gtsam/hybrid/HybridEliminationTree.cpp +++ b/gtsam/hybrid/HybridEliminationTree.cpp @@ -27,12 +27,20 @@ template class EliminationTree; HybridEliminationTree::HybridEliminationTree( const HybridGaussianFactorGraph& factorGraph, const VariableIndex& structure, const Ordering& order) - : Base(factorGraph, structure, order) {} + : Base(factorGraph, structure, order), + graph_(factorGraph), + variable_index_(structure) { + // Segregate the continuous and the discrete keys + std::tie(continuous_ordering_, discrete_ordering_) = + graph_.separateContinuousDiscreteOrdering(order); +} /* ************************************************************************* */ HybridEliminationTree::HybridEliminationTree( const HybridGaussianFactorGraph& factorGraph, const Ordering& order) - : Base(factorGraph, order) {} + : Base(factorGraph, order), + graph_(factorGraph), + variable_index_(VariableIndex(factorGraph)) {} /* ************************************************************************* */ bool HybridEliminationTree::equals(const This& other, double tol) const { diff --git a/gtsam/hybrid/HybridEliminationTree.h b/gtsam/hybrid/HybridEliminationTree.h index b2dd1bd9c8..dfa88bf4e5 100644 --- a/gtsam/hybrid/HybridEliminationTree.h +++ b/gtsam/hybrid/HybridEliminationTree.h @@ -33,6 +33,12 @@ class GTSAM_EXPORT HybridEliminationTree private: friend class ::EliminationTreeTester; + Ordering continuous_ordering_, discrete_ordering_; + /// Used to store the original factor graph to eliminate + HybridGaussianFactorGraph graph_; + /// Store the provided variable index. + VariableIndex variable_index_; + public: typedef EliminationTree Base; ///< Base class @@ -66,6 +72,107 @@ class GTSAM_EXPORT HybridEliminationTree /** Test whether the tree is equal to another */ bool equals(const This& other, double tol = 1e-9) const; + + /** + * @brief Helper method to eliminate continuous variables. + * + * If no continuous variables exist, return an empty bayes net + * and the original graph. + * + * @param function Elimination function for hybrid elimination. + * @return std::pair, + * boost::shared_ptr > + */ + std::pair, boost::shared_ptr> + eliminateContinuous(Eliminate function) const { + if (continuous_ordering_.size() > 0) { + This continuous_etree(graph_, variable_index_, continuous_ordering_); + return continuous_etree.Base::eliminate(function); + + } else { + BayesNetType::shared_ptr bayesNet = boost::make_shared(); + FactorGraphType::shared_ptr discreteGraph = + boost::make_shared(graph_); + return std::make_pair(bayesNet, discreteGraph); + } + } + + /** + * @brief Helper method to eliminate the discrete variables after the + * continuous variables have been eliminated. + * + * If there are no discrete variables, return an empty bayes net and the + * discreteGraph which is passed in. + * + * @param function Elimination function + * @param discreteGraph The factor graph with the factor ϕ(X | M, Z). + * @return std::pair, + * boost::shared_ptr > + */ + std::pair, boost::shared_ptr> + eliminateDiscrete(Eliminate function, + const FactorGraphType::shared_ptr& discreteGraph) const { + BayesNetType::shared_ptr discreteBayesNet; + FactorGraphType::shared_ptr finalGraph; + if (discrete_ordering_.size() > 0) { + This discrete_etree(*discreteGraph, VariableIndex(*discreteGraph), + discrete_ordering_); + + std::tie(discreteBayesNet, finalGraph) = + discrete_etree.Base::eliminate(function); + + } else { + discreteBayesNet = boost::make_shared(); + finalGraph = discreteGraph; + } + + return std::make_pair(discreteBayesNet, finalGraph); + } + + /** + * @brief Override the EliminationTree eliminate method + * so we can perform hybrid elimination correctly. + * + * @param function + * @return std::pair, + * boost::shared_ptr > + */ + std::pair, boost::shared_ptr> + eliminate(Eliminate function) const { + // Perform continuous elimination + BayesNetType::shared_ptr bayesNet; + FactorGraphType::shared_ptr discreteGraph; + std::tie(bayesNet, discreteGraph) = this->eliminateContinuous(function); + + // If we have eliminated continuous variables + // and have discrete variables to eliminate, + // then compute P(X | M, Z) + if (continuous_ordering_.size() > 0 && discrete_ordering_.size() > 0) { + // Get the last continuous conditional + // which will have all the discrete keys + HybridConditional::shared_ptr last_conditional = + bayesNet->at(bayesNet->size() - 1); + DiscreteKeys discrete_keys = last_conditional->discreteKeys(); + + // DecisionTree for P'(X|M, Z) for all mode sequences M + const AlgebraicDecisionTree probPrimeTree = + graph_.continuousProbPrimes(discrete_keys, bayesNet); + + // Add the model selection factor P(M|Z) + discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree)); + } + + // Perform discrete elimination + BayesNetType::shared_ptr discreteBayesNet; + FactorGraphType::shared_ptr finalGraph; + std::tie(discreteBayesNet, finalGraph) = + eliminateDiscrete(function, discreteGraph); + + // Add the discrete conditionals to the hybrid conditionals + bayesNet->add(*discreteBayesNet); + + return std::make_pair(bayesNet, finalGraph); + } }; } // namespace gtsam diff --git a/gtsam/hybrid/HybridFactor.cpp b/gtsam/hybrid/HybridFactor.cpp index b25e97f051..1216fd9224 100644 --- a/gtsam/hybrid/HybridFactor.cpp +++ b/gtsam/hybrid/HybridFactor.cpp @@ -81,7 +81,7 @@ bool HybridFactor::equals(const HybridFactor &lf, double tol) const { /* ************************************************************************ */ void HybridFactor::print(const std::string &s, const KeyFormatter &formatter) const { - std::cout << (s.empty() ? "" : s + "\n"); + std::cout << s; if (isContinuous_) std::cout << "Continuous "; if (isDiscrete_) std::cout << "Discrete "; if (isHybrid_) std::cout << "Hybrid "; diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 1d52a24afe..1afe4f38af 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -550,74 +550,6 @@ HybridGaussianFactorGraph::separateContinuousDiscreteOrdering( return std::make_pair(continuous_ordering, discrete_ordering); } -/* ************************************************************************ */ -boost::shared_ptr -HybridGaussianFactorGraph::eliminateHybridSequential( - const boost::optional continuous, - const boost::optional discrete, const Eliminate &function, - OptionalVariableIndex variableIndex) const { - const Ordering continuous_ordering = - continuous ? *continuous : Ordering(this->continuousKeys()); - const Ordering discrete_ordering = - discrete ? *discrete : Ordering(this->discreteKeys()); - - // Eliminate continuous - HybridBayesNet::shared_ptr bayesNet; - HybridGaussianFactorGraph::shared_ptr discreteGraph; - std::tie(bayesNet, discreteGraph) = - BaseEliminateable::eliminatePartialSequential(continuous_ordering, - function, variableIndex); - - // Get the last continuous conditional which will have all the discrete keys - HybridConditional::shared_ptr last_conditional = - bayesNet->at(bayesNet->size() - 1); - DiscreteKeys discrete_keys = last_conditional->discreteKeys(); - - // If no discrete variables, return the eliminated bayes net. - if (discrete_keys.size() == 0) { - return bayesNet; - } - - // DecisionTree for P'(X|M, Z) for all mode sequences M - const AlgebraicDecisionTree probPrimeTree = - this->continuousProbPrimes(discrete_keys, bayesNet); - - // Add the model selection factor P(M|Z) - discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree)); - - // Perform discrete elimination - HybridBayesNet::shared_ptr discreteBayesNet = - discreteGraph->BaseEliminateable::eliminateSequential( - discrete_ordering, function, variableIndex); - - bayesNet->add(*discreteBayesNet); - - return bayesNet; -} - -/* ************************************************************************ */ -boost::shared_ptr -HybridGaussianFactorGraph::eliminateSequential( - OptionalOrderingType orderingType, const Eliminate &function, - OptionalVariableIndex variableIndex) const { - return BaseEliminateable::eliminateSequential(orderingType, function, - variableIndex); -} - -/* ************************************************************************ */ -boost::shared_ptr -HybridGaussianFactorGraph::eliminateSequential( - const Ordering &ordering, const Eliminate &function, - OptionalVariableIndex variableIndex) const { - // Segregate the continuous and the discrete keys - Ordering continuous_ordering, discrete_ordering; - std::tie(continuous_ordering, discrete_ordering) = - this->separateContinuousDiscreteOrdering(ordering); - - return this->eliminateHybridSequential(continuous_ordering, discrete_ordering, - function, variableIndex); -} - /* ************************************************************************ */ boost::shared_ptr HybridGaussianFactorGraph::eliminateHybridMultifrontal( diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index 47b94070f7..a0d80b1547 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -302,32 +302,6 @@ class GTSAM_EXPORT HybridGaussianFactorGraph std::pair separateContinuousDiscreteOrdering( const Ordering& ordering) const; - /** - * @brief Custom elimination function which computes the correct - * continuous probabilities. Returns a bayes net. - * - * @param continuous Optional ordering for all continuous variables. - * @param discrete Optional ordering for all discrete variables. - * @return boost::shared_ptr - */ - boost::shared_ptr eliminateHybridSequential( - const boost::optional continuous = boost::none, - const boost::optional discrete = boost::none, - const Eliminate& function = EliminationTraitsType::DefaultEliminate, - OptionalVariableIndex variableIndex = boost::none) const; - - /// Sequential elimination overload for hybrid - boost::shared_ptr eliminateSequential( - OptionalOrderingType orderingType = boost::none, - const Eliminate& function = EliminationTraitsType::DefaultEliminate, - OptionalVariableIndex variableIndex = boost::none) const; - - /// Sequential elimination overload for hybrid - boost::shared_ptr eliminateSequential( - const Ordering& ordering, - const Eliminate& function = EliminationTraitsType::DefaultEliminate, - OptionalVariableIndex variableIndex = boost::none) const; - /** * @brief Custom elimination function which computes the correct * continuous probabilities. Returns a bayes tree. diff --git a/gtsam/hybrid/HybridSmoother.cpp b/gtsam/hybrid/HybridSmoother.cpp index 12f6949abf..7400053bfc 100644 --- a/gtsam/hybrid/HybridSmoother.cpp +++ b/gtsam/hybrid/HybridSmoother.cpp @@ -32,7 +32,7 @@ void HybridSmoother::update(HybridGaussianFactorGraph graph, addConditionals(graph, hybridBayesNet_, ordering); // Eliminate. - auto bayesNetFragment = graph.eliminateHybridSequential(); + auto bayesNetFragment = graph.eliminateSequential(); /// Prune if (maxNrLeaves) { From 15fffeb18adf952424789cf6835123bc97b2eb1b Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 4 Dec 2022 14:30:01 +0530 Subject: [PATCH 04/12] add getters to HybridEliminationTree --- gtsam/hybrid/HybridEliminationTree.h | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/gtsam/hybrid/HybridEliminationTree.h b/gtsam/hybrid/HybridEliminationTree.h index dfa88bf4e5..65d614ca3c 100644 --- a/gtsam/hybrid/HybridEliminationTree.h +++ b/gtsam/hybrid/HybridEliminationTree.h @@ -173,6 +173,13 @@ class GTSAM_EXPORT HybridEliminationTree return std::make_pair(bayesNet, finalGraph); } + + Ordering continuousOrdering() const { return continuous_ordering_; } + Ordering discreteOrdering() const { return discrete_ordering_; } + + /// Store the provided variable index. + VariableIndex variableIndex() const { return variable_index_; } + HybridGaussianFactorGraph graph() const { return graph_; } }; } // namespace gtsam From addbe2a5a57bfab3851a51a7193d92f3e2110bfa Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 4 Dec 2022 14:55:17 +0530 Subject: [PATCH 05/12] override eliminate in HybridJunctionTree --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 93 ------------ gtsam/hybrid/HybridGaussianFactorGraph.h | 26 +--- gtsam/hybrid/HybridJunctionTree.cpp | 139 +++++++++++++++++- gtsam/hybrid/HybridJunctionTree.h | 67 ++++++++- .../tests/testHybridGaussianFactorGraph.cpp | 7 +- 5 files changed, 209 insertions(+), 123 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 1afe4f38af..c430fac2c6 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -550,98 +550,5 @@ HybridGaussianFactorGraph::separateContinuousDiscreteOrdering( return std::make_pair(continuous_ordering, discrete_ordering); } -/* ************************************************************************ */ -boost::shared_ptr -HybridGaussianFactorGraph::eliminateHybridMultifrontal( - const boost::optional continuous, - const boost::optional discrete, const Eliminate &function, - OptionalVariableIndex variableIndex) const { - const Ordering continuous_ordering = - continuous ? *continuous : Ordering(this->continuousKeys()); - const Ordering discrete_ordering = - discrete ? *discrete : Ordering(this->discreteKeys()); - - // Eliminate continuous - HybridBayesTree::shared_ptr bayesTree; - HybridGaussianFactorGraph::shared_ptr discreteGraph; - std::tie(bayesTree, discreteGraph) = - BaseEliminateable::eliminatePartialMultifrontal(continuous_ordering, - function, variableIndex); - - // Get the last continuous conditional which will have all the discrete - const Key last_continuous_key = continuous_ordering.back(); - HybridConditional::shared_ptr last_conditional = - (*bayesTree)[last_continuous_key]->conditional(); - DiscreteKeys discrete_keys = last_conditional->discreteKeys(); - - // If not discrete variables, return the eliminated bayes net. - if (discrete_keys.size() == 0) { - return bayesTree; - } - - // DecisionTree for P'(X|M, Z) for all mode sequences M - const AlgebraicDecisionTree probPrimeTree = - this->continuousProbPrimes(discrete_keys, bayesTree); - - // Add the model selection factor P(M|Z) - discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree)); - - // Eliminate discrete variables to get the discrete bayes tree. - // This bayes tree will be updated with the - // continuous variables as the child nodes. - HybridBayesTree::shared_ptr updatedBayesTree = - discreteGraph->BaseEliminateable::eliminateMultifrontal(discrete_ordering, - function); - - // Get the clique with all the discrete keys. - // There should only be 1 clique. - const HybridBayesTree::sharedClique discrete_clique = - (*updatedBayesTree)[discrete_ordering.at(0)]; - - std::set clique_set; - for (auto node : bayesTree->nodes()) { - clique_set.insert(node.second); - } - - // Set the root of the bayes tree as the discrete clique - for (auto clique : clique_set) { - if (clique->conditional()->parents() == - discrete_clique->conditional()->frontals()) { - updatedBayesTree->addClique(clique, discrete_clique); - - } else { - // Remove the clique from the children of the parents since it will get - // added again in addClique. - auto clique_it = std::find(clique->parent()->children.begin(), - clique->parent()->children.end(), clique); - clique->parent()->children.erase(clique_it); - updatedBayesTree->addClique(clique, clique->parent()); - } - } - return updatedBayesTree; -} - -/* ************************************************************************ */ -boost::shared_ptr -HybridGaussianFactorGraph::eliminateMultifrontal( - OptionalOrderingType orderingType, const Eliminate &function, - OptionalVariableIndex variableIndex) const { - return BaseEliminateable::eliminateMultifrontal(orderingType, function, - variableIndex); -} - -/* ************************************************************************ */ -boost::shared_ptr -HybridGaussianFactorGraph::eliminateMultifrontal( - const Ordering &ordering, const Eliminate &function, - OptionalVariableIndex variableIndex) const { - // Segregate the continuous and the discrete keys - Ordering continuous_ordering, discrete_ordering; - std::tie(continuous_ordering, discrete_ordering) = - this->separateContinuousDiscreteOrdering(ordering); - - return this->eliminateHybridMultifrontal( - continuous_ordering, discrete_ordering, function, variableIndex); -} } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index a0d80b1547..210ce50e93 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -302,31 +302,7 @@ class GTSAM_EXPORT HybridGaussianFactorGraph std::pair separateContinuousDiscreteOrdering( const Ordering& ordering) const; - /** - * @brief Custom elimination function which computes the correct - * continuous probabilities. Returns a bayes tree. - * - * @param continuous Optional ordering for all continuous variables. - * @param discrete Optional ordering for all discrete variables. - * @return boost::shared_ptr - */ - boost::shared_ptr eliminateHybridMultifrontal( - const boost::optional continuous = boost::none, - const boost::optional discrete = boost::none, - const Eliminate& function = EliminationTraitsType::DefaultEliminate, - OptionalVariableIndex variableIndex = boost::none) const; - - /// Multifrontal elimination overload for hybrid - boost::shared_ptr eliminateMultifrontal( - OptionalOrderingType orderingType = boost::none, - const Eliminate& function = EliminationTraitsType::DefaultEliminate, - OptionalVariableIndex variableIndex = boost::none) const; - - /// Multifrontal elimination overload for hybrid - boost::shared_ptr eliminateMultifrontal( - const Ordering& ordering, - const Eliminate& function = EliminationTraitsType::DefaultEliminate, - OptionalVariableIndex variableIndex = boost::none) const; + /** * @brief Return a Colamd constrained ordering where the discrete keys are diff --git a/gtsam/hybrid/HybridJunctionTree.cpp b/gtsam/hybrid/HybridJunctionTree.cpp index 422c200a43..f233d4bef9 100644 --- a/gtsam/hybrid/HybridJunctionTree.cpp +++ b/gtsam/hybrid/HybridJunctionTree.cpp @@ -142,7 +142,8 @@ struct HybridConstructorTraversalData { /* ************************************************************************* */ HybridJunctionTree::HybridJunctionTree( - const HybridEliminationTree& eliminationTree) { + const HybridEliminationTree& eliminationTree) + : etree_(eliminationTree) { gttic(JunctionTree_FromEliminationTree); // Here we rely on the BayesNet having been produced by this elimination tree, // such that the conditionals are arranged in DFS post-order. We traverse the @@ -169,4 +170,140 @@ HybridJunctionTree::HybridJunctionTree( Base::remainingFactors_ = eliminationTree.remainingFactors(); } +/* ************************************************************************* */ +std::pair, + boost::shared_ptr> +HybridJunctionTree::eliminateContinuous( + const Eliminate& function, const HybridGaussianFactorGraph& graph, + const Ordering& continuous_ordering) const { + HybridBayesTree::shared_ptr continuousBayesTree; + HybridGaussianFactorGraph::shared_ptr discreteGraph; + + if (continuous_ordering.size() > 0) { + HybridEliminationTree continuous_etree(graph, etree_.variableIndex(), + continuous_ordering); + + This continuous_junction_tree(continuous_etree); + std::tie(continuousBayesTree, discreteGraph) = + continuous_junction_tree.Base::eliminate(function); + + } else { + continuousBayesTree = boost::make_shared(); + discreteGraph = boost::make_shared(graph); + } + + return std::make_pair(continuousBayesTree, discreteGraph); +} +/* ************************************************************************* */ +std::pair, + boost::shared_ptr> +HybridJunctionTree::eliminateDiscrete( + const Eliminate& function, + const HybridBayesTree::shared_ptr& continuousBayesTree, + const HybridGaussianFactorGraph::shared_ptr& discreteGraph, + const Ordering& discrete_ordering) const { + HybridBayesTree::shared_ptr updatedBayesTree; + HybridGaussianFactorGraph::shared_ptr finalGraph; + if (discrete_ordering.size() > 0) { + HybridEliminationTree discrete_etree( + *discreteGraph, VariableIndex(*discreteGraph), discrete_ordering); + + This discrete_junction_tree(discrete_etree); + + std::tie(updatedBayesTree, finalGraph) = + discrete_junction_tree.Base::eliminate(function); + + // Get the clique with all the discrete keys. + // There should only be 1 clique. + const HybridBayesTree::sharedClique discrete_clique = + (*updatedBayesTree)[discrete_ordering.at(0)]; + + std::set clique_set; + for (auto node : continuousBayesTree->nodes()) { + clique_set.insert(node.second); + } + + // Set the root of the bayes tree as the discrete clique + for (auto clique : clique_set) { + if (clique->conditional()->parents() == + discrete_clique->conditional()->frontals()) { + updatedBayesTree->addClique(clique, discrete_clique); + + } else { + if (clique->parent()) { + // Remove the clique from the children of the parents since it will + // get added again in addClique. + auto clique_it = std::find(clique->parent()->children.begin(), + clique->parent()->children.end(), clique); + clique->parent()->children.erase(clique_it); + updatedBayesTree->addClique(clique, clique->parent()); + } else { + updatedBayesTree->addClique(clique); + } + } + } + } else { + updatedBayesTree = continuousBayesTree; + finalGraph = discreteGraph; + } + + return std::make_pair(updatedBayesTree, finalGraph); +} + +/* ************************************************************************* */ +boost::shared_ptr HybridJunctionTree::addProbPrimes( + const HybridGaussianFactorGraph& graph, + const HybridBayesTree::shared_ptr& continuousBayesTree, + const HybridGaussianFactorGraph::shared_ptr& discreteGraph, + const Ordering& continuous_ordering, + const Ordering& discrete_ordering) const { + // If we have eliminated continuous variables + // and have discrete variables to eliminate, + // then compute P(X | M, Z) + if (continuous_ordering.size() > 0 && discrete_ordering.size() > 0) { + // Collect all the discrete keys + DiscreteKeys discrete_keys; + for (auto node : continuousBayesTree->nodes()) { + auto node_dkeys = node.second->conditional()->discreteKeys(); + discrete_keys.insert(discrete_keys.end(), node_dkeys.begin(), + node_dkeys.end()); + } + // Remove duplicates and convert back to DiscreteKeys + std::set dkeys_set(discrete_keys.begin(), discrete_keys.end()); + discrete_keys = DiscreteKeys(dkeys_set.begin(), dkeys_set.end()); + + // DecisionTree for P'(X|M, Z) for all mode sequences M + const AlgebraicDecisionTree probPrimeTree = + graph.continuousProbPrimes(discrete_keys, continuousBayesTree); + + // Add the model selection factor P(M|Z) + discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree)); + } + + return discreteGraph; +} + +/* ************************************************************************* */ +std::pair +HybridJunctionTree::eliminate(const Eliminate& function) const { + Ordering continuous_ordering = etree_.continuousOrdering(); + Ordering discrete_ordering = etree_.discreteOrdering(); + + FactorGraphType graph = etree_.graph(); + + // Eliminate continuous + BayesTreeType::shared_ptr continuousBayesTree; + FactorGraphType::shared_ptr discreteGraph; + std::tie(continuousBayesTree, discreteGraph) = + this->eliminateContinuous(function, graph, continuous_ordering); + + FactorGraphType::shared_ptr updatedDiscreteGraph = + this->addProbPrimes(graph, continuousBayesTree, discreteGraph, + continuous_ordering, discrete_ordering); + + // Eliminate discrete variables to get the discrete bayes tree. + return this->eliminateDiscrete(function, continuousBayesTree, + updatedDiscreteGraph, discrete_ordering); +} + } // namespace gtsam diff --git a/gtsam/hybrid/HybridJunctionTree.h b/gtsam/hybrid/HybridJunctionTree.h index 4b0c369a82..2dc13d5e38 100644 --- a/gtsam/hybrid/HybridJunctionTree.h +++ b/gtsam/hybrid/HybridJunctionTree.h @@ -51,10 +51,15 @@ class HybridEliminationTree; */ class GTSAM_EXPORT HybridJunctionTree : public JunctionTree { + /// Record the elimination tree for use in hybrid elimination. + HybridEliminationTree etree_; + /// Store the provided variable index. + VariableIndex variable_index_; + public: typedef JunctionTree Base; ///< Base class - typedef HybridJunctionTree This; ///< This class + typedef HybridJunctionTree This; ///< This class typedef boost::shared_ptr shared_ptr; ///< Shared pointer to this class /** @@ -67,6 +72,66 @@ class GTSAM_EXPORT HybridJunctionTree * @return The elimination tree */ HybridJunctionTree(const HybridEliminationTree& eliminationTree); + + /** + * @brief + * + * @param function + * @param graph + * @param continuous_ordering + * @return std::pair, + * boost::shared_ptr> + */ + std::pair, + boost::shared_ptr> + eliminateContinuous(const Eliminate& function, + const HybridGaussianFactorGraph& graph, + const Ordering& continuous_ordering) const; + + /** + * @brief + * + * @param function + * @param continuousBayesTree + * @param discreteGraph + * @param discrete_ordering + * @return std::pair, + * boost::shared_ptr> + */ + std::pair, + boost::shared_ptr> + eliminateDiscrete(const Eliminate& function, + const HybridBayesTree::shared_ptr& continuousBayesTree, + const HybridGaussianFactorGraph::shared_ptr& discreteGraph, + const Ordering& discrete_ordering) const; + + /** + * @brief + * + * @param graph + * @param continuousBayesTree + * @param discreteGraph + * @param continuous_ordering + * @param discrete_ordering + * @return boost::shared_ptr + */ + boost::shared_ptr addProbPrimes( + const HybridGaussianFactorGraph& graph, + const HybridBayesTree::shared_ptr& continuousBayesTree, + const HybridGaussianFactorGraph::shared_ptr& discreteGraph, + const Ordering& continuous_ordering, + const Ordering& discrete_ordering) const; + + /** + * @brief + * + * @param function + * @return std::pair, + * boost::shared_ptr> + */ + std::pair, + boost::shared_ptr> + eliminate(const Eliminate& function) const; }; } // namespace gtsam diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 248d71d5fc..6288bcd93f 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -182,8 +182,9 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalSimple) { boost::make_shared(X(1), I_3x3, Vector3::Ones())})); hfg.add(DecisionTreeFactor(m1, {2, 8})); - //TODO(Varun) Adding extra discrete variable not connected to continuous variable throws segfault - // hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4")); + // TODO(Varun) Adding extra discrete variable not connected to continuous + // variable throws segfault + // hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4")); HybridBayesTree::shared_ptr result = hfg.eliminateMultifrontal(hfg.getHybridOrdering()); @@ -276,7 +277,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalTwoClique) { std::tie(hbt, remaining) = hfg.eliminatePartialMultifrontal(ordering_full); // 9 cliques in the bayes tree and 0 remaining variables to eliminate. - EXPECT_LONGS_EQUAL(9, hbt->size()); + EXPECT_LONGS_EQUAL(7, hbt->size()); EXPECT_LONGS_EQUAL(0, remaining->size()); /* From ae0b3e3902d929bd47bec4579634c2825986d5ce Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 4 Dec 2022 18:21:22 +0530 Subject: [PATCH 06/12] split up the eliminate method to constituent parts --- gtsam/hybrid/HybridEliminationTree.cpp | 88 ++++++++++++++++++++++++ gtsam/hybrid/HybridEliminationTree.h | 95 ++++++-------------------- 2 files changed, 107 insertions(+), 76 deletions(-) diff --git a/gtsam/hybrid/HybridEliminationTree.cpp b/gtsam/hybrid/HybridEliminationTree.cpp index fe91905713..e541059197 100644 --- a/gtsam/hybrid/HybridEliminationTree.cpp +++ b/gtsam/hybrid/HybridEliminationTree.cpp @@ -47,4 +47,92 @@ bool HybridEliminationTree::equals(const This& other, double tol) const { return Base::equals(other, tol); } +/* ************************************************************************* */ +std::pair, + boost::shared_ptr> +HybridEliminationTree::eliminateContinuous(Eliminate function) const { + if (continuous_ordering_.size() > 0) { + This continuous_etree(graph_, variable_index_, continuous_ordering_); + return continuous_etree.Base::eliminate(function); + + } else { + HybridBayesNet::shared_ptr bayesNet = boost::make_shared(); + HybridGaussianFactorGraph::shared_ptr discreteGraph = + boost::make_shared(graph_); + return std::make_pair(bayesNet, discreteGraph); + } +} + +/* ************************************************************************* */ +boost::shared_ptr +HybridEliminationTree::addProbPrimes( + const HybridBayesNet::shared_ptr& continuousBayesNet, + const HybridGaussianFactorGraph::shared_ptr& discreteGraph) const { + if (continuous_ordering_.size() > 0 && discrete_ordering_.size() > 0) { + // Get the last continuous conditional + // which will have all the discrete keys + HybridConditional::shared_ptr last_conditional = + continuousBayesNet->at(continuousBayesNet->size() - 1); + DiscreteKeys discrete_keys = last_conditional->discreteKeys(); + + // DecisionTree for P'(X|M, Z) for all mode sequences M + const AlgebraicDecisionTree probPrimeTree = + graph_.continuousProbPrimes(discrete_keys, continuousBayesNet); + + // Add the model selection factor P(M|Z) + discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree)); + } + return discreteGraph; +} + +/* ************************************************************************* */ +std::pair, + boost::shared_ptr> +HybridEliminationTree::eliminateDiscrete( + Eliminate function, + const HybridGaussianFactorGraph::shared_ptr& discreteGraph) const { + HybridBayesNet::shared_ptr discreteBayesNet; + HybridGaussianFactorGraph::shared_ptr finalGraph; + if (discrete_ordering_.size() > 0) { + This discrete_etree(*discreteGraph, VariableIndex(*discreteGraph), + discrete_ordering_); + + std::tie(discreteBayesNet, finalGraph) = + discrete_etree.Base::eliminate(function); + + } else { + discreteBayesNet = boost::make_shared(); + finalGraph = discreteGraph; + } + + return std::make_pair(discreteBayesNet, finalGraph); +} + +/* ************************************************************************* */ +std::pair, + boost::shared_ptr> +HybridEliminationTree::eliminate(Eliminate function) const { + // Perform continuous elimination + HybridBayesNet::shared_ptr bayesNet; + HybridGaussianFactorGraph::shared_ptr discreteGraph; + std::tie(bayesNet, discreteGraph) = this->eliminateContinuous(function); + + // If we have eliminated continuous variables + // and have discrete variables to eliminate, + // then compute P(X | M, Z) + HybridGaussianFactorGraph::shared_ptr updatedDiscreteGraph = + addProbPrimes(bayesNet, discreteGraph); + + // Perform discrete elimination + HybridBayesNet::shared_ptr discreteBayesNet; + HybridGaussianFactorGraph::shared_ptr finalGraph; + std::tie(discreteBayesNet, finalGraph) = + eliminateDiscrete(function, updatedDiscreteGraph); + + // Add the discrete conditionals to the hybrid conditionals + bayesNet->add(*discreteBayesNet); + + return std::make_pair(bayesNet, finalGraph); +} + } // namespace gtsam diff --git a/gtsam/hybrid/HybridEliminationTree.h b/gtsam/hybrid/HybridEliminationTree.h index 65d614ca3c..9186e04a8a 100644 --- a/gtsam/hybrid/HybridEliminationTree.h +++ b/gtsam/hybrid/HybridEliminationTree.h @@ -80,22 +80,12 @@ class GTSAM_EXPORT HybridEliminationTree * and the original graph. * * @param function Elimination function for hybrid elimination. - * @return std::pair, - * boost::shared_ptr > + * @return std::pair, + * boost::shared_ptr > */ - std::pair, boost::shared_ptr> - eliminateContinuous(Eliminate function) const { - if (continuous_ordering_.size() > 0) { - This continuous_etree(graph_, variable_index_, continuous_ordering_); - return continuous_etree.Base::eliminate(function); - - } else { - BayesNetType::shared_ptr bayesNet = boost::make_shared(); - FactorGraphType::shared_ptr discreteGraph = - boost::make_shared(graph_); - return std::make_pair(bayesNet, discreteGraph); - } - } + std::pair, + boost::shared_ptr> + eliminateContinuous(Eliminate function) const; /** * @brief Helper method to eliminate the discrete variables after the @@ -104,75 +94,28 @@ class GTSAM_EXPORT HybridEliminationTree * If there are no discrete variables, return an empty bayes net and the * discreteGraph which is passed in. * - * @param function Elimination function + * @param function Hybrid elimination function * @param discreteGraph The factor graph with the factor ϕ(X | M, Z). - * @return std::pair, - * boost::shared_ptr > + * @return std::pair, + * boost::shared_ptr > */ - std::pair, boost::shared_ptr> - eliminateDiscrete(Eliminate function, - const FactorGraphType::shared_ptr& discreteGraph) const { - BayesNetType::shared_ptr discreteBayesNet; - FactorGraphType::shared_ptr finalGraph; - if (discrete_ordering_.size() > 0) { - This discrete_etree(*discreteGraph, VariableIndex(*discreteGraph), - discrete_ordering_); - - std::tie(discreteBayesNet, finalGraph) = - discrete_etree.Base::eliminate(function); - - } else { - discreteBayesNet = boost::make_shared(); - finalGraph = discreteGraph; - } - - return std::make_pair(discreteBayesNet, finalGraph); - } + std::pair, + boost::shared_ptr> + eliminateDiscrete( + Eliminate function, + const HybridGaussianFactorGraph::shared_ptr& discreteGraph) const; /** * @brief Override the EliminationTree eliminate method * so we can perform hybrid elimination correctly. * - * @param function - * @return std::pair, - * boost::shared_ptr > + * @param function Hybrid elimination function + * @return std::pair, + * boost::shared_ptr > */ - std::pair, boost::shared_ptr> - eliminate(Eliminate function) const { - // Perform continuous elimination - BayesNetType::shared_ptr bayesNet; - FactorGraphType::shared_ptr discreteGraph; - std::tie(bayesNet, discreteGraph) = this->eliminateContinuous(function); - - // If we have eliminated continuous variables - // and have discrete variables to eliminate, - // then compute P(X | M, Z) - if (continuous_ordering_.size() > 0 && discrete_ordering_.size() > 0) { - // Get the last continuous conditional - // which will have all the discrete keys - HybridConditional::shared_ptr last_conditional = - bayesNet->at(bayesNet->size() - 1); - DiscreteKeys discrete_keys = last_conditional->discreteKeys(); - - // DecisionTree for P'(X|M, Z) for all mode sequences M - const AlgebraicDecisionTree probPrimeTree = - graph_.continuousProbPrimes(discrete_keys, bayesNet); - - // Add the model selection factor P(M|Z) - discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree)); - } - - // Perform discrete elimination - BayesNetType::shared_ptr discreteBayesNet; - FactorGraphType::shared_ptr finalGraph; - std::tie(discreteBayesNet, finalGraph) = - eliminateDiscrete(function, discreteGraph); - - // Add the discrete conditionals to the hybrid conditionals - bayesNet->add(*discreteBayesNet); - - return std::make_pair(bayesNet, finalGraph); - } + std::pair, + boost::shared_ptr> + eliminate(Eliminate function) const; Ordering continuousOrdering() const { return continuous_ordering_; } Ordering discreteOrdering() const { return discrete_ordering_; } From bed56e06932c8ab6b5875e6ae6b9026befdfe785 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 4 Dec 2022 18:21:57 +0530 Subject: [PATCH 07/12] mark helper methods as protected and add docstrings --- gtsam/hybrid/HybridEliminationTree.h | 18 ++++++++ gtsam/hybrid/HybridJunctionTree.cpp | 13 +++--- gtsam/hybrid/HybridJunctionTree.h | 63 +++++++++++++++------------- 3 files changed, 59 insertions(+), 35 deletions(-) diff --git a/gtsam/hybrid/HybridEliminationTree.h b/gtsam/hybrid/HybridEliminationTree.h index 9186e04a8a..9b68540266 100644 --- a/gtsam/hybrid/HybridEliminationTree.h +++ b/gtsam/hybrid/HybridEliminationTree.h @@ -73,6 +73,7 @@ class GTSAM_EXPORT HybridEliminationTree /** Test whether the tree is equal to another */ bool equals(const This& other, double tol = 1e-9) const; + protected: /** * @brief Helper method to eliminate continuous variables. * @@ -87,6 +88,22 @@ class GTSAM_EXPORT HybridEliminationTree boost::shared_ptr> eliminateContinuous(Eliminate function) const; + /** + * @brief Compute the unnormalized probability P'(X | M, Z) + * for the factor graph in each leaf of the discrete tree. + * The discrete decision tree formed as a result is added to the + * `discreteGraph` for discrete elimination. + * + * @param continuousBayesNet The bayes nets corresponding to + * the eliminated continuous variables. + * @param discreteGraph Factor graph consisting of factors + * on discrete variables only. + * @return boost::shared_ptr + */ + boost::shared_ptr addProbPrimes( + const HybridBayesNet::shared_ptr& continuousBayesNet, + const HybridGaussianFactorGraph::shared_ptr& discreteGraph) const; + /** * @brief Helper method to eliminate the discrete variables after the * continuous variables have been eliminated. @@ -105,6 +122,7 @@ class GTSAM_EXPORT HybridEliminationTree Eliminate function, const HybridGaussianFactorGraph::shared_ptr& discreteGraph) const; + public: /** * @brief Override the EliminationTree eliminate method * so we can perform hybrid elimination correctly. diff --git a/gtsam/hybrid/HybridJunctionTree.cpp b/gtsam/hybrid/HybridJunctionTree.cpp index f233d4bef9..43b043e304 100644 --- a/gtsam/hybrid/HybridJunctionTree.cpp +++ b/gtsam/hybrid/HybridJunctionTree.cpp @@ -252,11 +252,11 @@ HybridJunctionTree::eliminateDiscrete( /* ************************************************************************* */ boost::shared_ptr HybridJunctionTree::addProbPrimes( - const HybridGaussianFactorGraph& graph, const HybridBayesTree::shared_ptr& continuousBayesTree, - const HybridGaussianFactorGraph::shared_ptr& discreteGraph, - const Ordering& continuous_ordering, - const Ordering& discrete_ordering) const { + const HybridGaussianFactorGraph::shared_ptr& discreteGraph) const { + Ordering continuous_ordering = etree_.continuousOrdering(); + Ordering discrete_ordering = etree_.discreteOrdering(); + // If we have eliminated continuous variables // and have discrete variables to eliminate, // then compute P(X | M, Z) @@ -272,6 +272,8 @@ boost::shared_ptr HybridJunctionTree::addProbPrimes( std::set dkeys_set(discrete_keys.begin(), discrete_keys.end()); discrete_keys = DiscreteKeys(dkeys_set.begin(), dkeys_set.end()); + FactorGraphType graph = etree_.graph(); + // DecisionTree for P'(X|M, Z) for all mode sequences M const AlgebraicDecisionTree probPrimeTree = graph.continuousProbPrimes(discrete_keys, continuousBayesTree); @@ -298,8 +300,7 @@ HybridJunctionTree::eliminate(const Eliminate& function) const { this->eliminateContinuous(function, graph, continuous_ordering); FactorGraphType::shared_ptr updatedDiscreteGraph = - this->addProbPrimes(graph, continuousBayesTree, discreteGraph, - continuous_ordering, discrete_ordering); + this->addProbPrimes(continuousBayesTree, discreteGraph); // Eliminate discrete variables to get the discrete bayes tree. return this->eliminateDiscrete(function, continuousBayesTree, diff --git a/gtsam/hybrid/HybridJunctionTree.h b/gtsam/hybrid/HybridJunctionTree.h index 2dc13d5e38..d0473c33d8 100644 --- a/gtsam/hybrid/HybridJunctionTree.h +++ b/gtsam/hybrid/HybridJunctionTree.h @@ -73,14 +73,15 @@ class GTSAM_EXPORT HybridJunctionTree */ HybridJunctionTree(const HybridEliminationTree& eliminationTree); + protected: /** - * @brief - * - * @param function - * @param graph - * @param continuous_ordering + * @brief Eliminate all the continuous variables from the factor graph. + * + * @param function The hybrid elimination function. + * @param graph The factor graph to eliminate. + * @param continuous_ordering The ordering of continuous variables. * @return std::pair, - * boost::shared_ptr> + * boost::shared_ptr> */ std::pair, boost::shared_ptr> @@ -89,14 +90,17 @@ class GTSAM_EXPORT HybridJunctionTree const Ordering& continuous_ordering) const; /** - * @brief - * - * @param function - * @param continuousBayesTree - * @param discreteGraph - * @param discrete_ordering + * @brief Eliminate all the discrete variables in the hybrid factor graph. + * + * @param function The hybrid elimination function. + * @param continuousBayesTree The bayes tree corresponding to + * the eliminated continuous variables. + * @param discreteGraph Factor graph of factors containing + * only discrete variables. + * @param discrete_ordering The elimination ordering for + * the discrete variables. * @return std::pair, - * boost::shared_ptr> + * boost::shared_ptr> */ std::pair, boost::shared_ptr> @@ -106,28 +110,29 @@ class GTSAM_EXPORT HybridJunctionTree const Ordering& discrete_ordering) const; /** - * @brief - * - * @param graph - * @param continuousBayesTree - * @param discreteGraph - * @param continuous_ordering - * @param discrete_ordering - * @return boost::shared_ptr + * @brief Compute the unnormalized probability P'(X | M, Z) + * for the factor graph in each leaf of the discrete tree. + * The discrete decision tree formed as a result is added to the + * `discreteGraph` for discrete elimination. + * + * @param continuousBayesTree The bayes tree corresponding to + * the eliminated continuous variables. + * @param discreteGraph Factor graph consisting of factors + * on discrete variables only. + * @return boost::shared_ptr */ boost::shared_ptr addProbPrimes( - const HybridGaussianFactorGraph& graph, const HybridBayesTree::shared_ptr& continuousBayesTree, - const HybridGaussianFactorGraph::shared_ptr& discreteGraph, - const Ordering& continuous_ordering, - const Ordering& discrete_ordering) const; + const HybridGaussianFactorGraph::shared_ptr& discreteGraph) const; + public: /** - * @brief - * - * @param function + * @brief Override the eliminate method so we can + * perform hybrid elimination correctly. + * + * @param function The hybrid elimination function. * @return std::pair, - * boost::shared_ptr> + * boost::shared_ptr> */ std::pair, boost::shared_ptr> From 5fc114fad27ba308e5819e6c59dd4402200dd9bb Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 4 Dec 2022 18:24:16 +0530 Subject: [PATCH 08/12] more unit tests --- gtsam/hybrid/tests/testHybridEstimation.cpp | 25 +++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/gtsam/hybrid/tests/testHybridEstimation.cpp b/gtsam/hybrid/tests/testHybridEstimation.cpp index 431e5769f5..85be0ab31c 100644 --- a/gtsam/hybrid/tests/testHybridEstimation.cpp +++ b/gtsam/hybrid/tests/testHybridEstimation.cpp @@ -70,6 +70,28 @@ Ordering getOrdering(HybridGaussianFactorGraph& factors, return ordering; } +TEST(HybridEstimation, Full) { + size_t K = 3; + std::vector measurements = {0, 1, 2}; + // Ground truth discrete seq + std::vector discrete_seq = {1, 1, 0}; + // Switching example of robot moving in 1D + // with given measurements and equal mode priors. + Switching switching(K, 1.0, 0.1, measurements, "1/1 1/1"); + HybridGaussianFactorGraph graph = switching.linearizedFactorGraph; + + Ordering hybridOrdering; + hybridOrdering += X(0); + hybridOrdering += X(1); + hybridOrdering += X(2); + hybridOrdering += M(0); + hybridOrdering += M(1); + HybridBayesNet::shared_ptr bayesNet = + graph.eliminateSequential(hybridOrdering); + + EXPECT_LONGS_EQUAL(5, bayesNet->size()); +} + /****************************************************************************/ // Test approximate inference with an additional pruning step. TEST(HybridEstimation, Incremental) { @@ -311,8 +333,7 @@ TEST(HybridEstimation, Probability) { discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree)); Ordering discrete(graph.discreteKeys()); - auto discreteBayesNet = - discreteGraph->BaseEliminateable::eliminateSequential(discrete); + auto discreteBayesNet = discreteGraph->eliminateSequential(); bayesNet->add(*discreteBayesNet); HybridValues hybrid_values = bayesNet->optimize(); From 22e4a733e0bf6dfa2b27947b679da27472753747 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 4 Dec 2022 18:36:36 +0530 Subject: [PATCH 09/12] Add details about the role of the HybridEliminationTree in hybrid elimination --- gtsam/hybrid/HybridEliminationTree.h | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/gtsam/hybrid/HybridEliminationTree.h b/gtsam/hybrid/HybridEliminationTree.h index 9b68540266..4802c2f89f 100644 --- a/gtsam/hybrid/HybridEliminationTree.h +++ b/gtsam/hybrid/HybridEliminationTree.h @@ -24,7 +24,22 @@ namespace gtsam { /** - * Elimination Tree type for Hybrid + * Elimination Tree type for Hybrid Factor Graphs. + * + * The elimination tree helps in elimination by specifying the parents to which + * the "combined factor" after elimination should be assigned to. + * + * In the case of hybrid elimination, we use the elimination tree to store the + * intermediate results. + * Concretely, we wish to estimate the unnormalized probability P'(X | M, Z) for + * each mode assignment M. + * P'(X | M, Z) can be computed by estimating X* which maximizes the continuous + * factor graph given the discrete modes, and then computing the unnormalized + * probability as P(X=X* | M, Z). + * + * This is done by eliminating the (continuous) factor graph present at each + * leaf of the discrete tree formed for the discrete sequence and using the + * inferred X* to compute the `probPrime`. * * @ingroup hybrid */ From 0596b2f543a7327d4e89d2999c578883756956ae Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 10 Dec 2022 09:46:26 +0530 Subject: [PATCH 10/12] remove unrequired code --- gtsam/hybrid/HybridEliminationTree.cpp | 100 +----------------- gtsam/hybrid/HybridEliminationTree.h | 90 ---------------- gtsam/hybrid/HybridJunctionTree.cpp | 140 +------------------------ gtsam/hybrid/HybridJunctionTree.h | 69 ------------ 4 files changed, 3 insertions(+), 396 deletions(-) diff --git a/gtsam/hybrid/HybridEliminationTree.cpp b/gtsam/hybrid/HybridEliminationTree.cpp index e541059197..c2df2dd600 100644 --- a/gtsam/hybrid/HybridEliminationTree.cpp +++ b/gtsam/hybrid/HybridEliminationTree.cpp @@ -27,112 +27,16 @@ template class EliminationTree; HybridEliminationTree::HybridEliminationTree( const HybridGaussianFactorGraph& factorGraph, const VariableIndex& structure, const Ordering& order) - : Base(factorGraph, structure, order), - graph_(factorGraph), - variable_index_(structure) { - // Segregate the continuous and the discrete keys - std::tie(continuous_ordering_, discrete_ordering_) = - graph_.separateContinuousDiscreteOrdering(order); -} + : Base(factorGraph, structure, order) {} /* ************************************************************************* */ HybridEliminationTree::HybridEliminationTree( const HybridGaussianFactorGraph& factorGraph, const Ordering& order) - : Base(factorGraph, order), - graph_(factorGraph), - variable_index_(VariableIndex(factorGraph)) {} + : Base(factorGraph, order) {} /* ************************************************************************* */ bool HybridEliminationTree::equals(const This& other, double tol) const { return Base::equals(other, tol); } -/* ************************************************************************* */ -std::pair, - boost::shared_ptr> -HybridEliminationTree::eliminateContinuous(Eliminate function) const { - if (continuous_ordering_.size() > 0) { - This continuous_etree(graph_, variable_index_, continuous_ordering_); - return continuous_etree.Base::eliminate(function); - - } else { - HybridBayesNet::shared_ptr bayesNet = boost::make_shared(); - HybridGaussianFactorGraph::shared_ptr discreteGraph = - boost::make_shared(graph_); - return std::make_pair(bayesNet, discreteGraph); - } -} - -/* ************************************************************************* */ -boost::shared_ptr -HybridEliminationTree::addProbPrimes( - const HybridBayesNet::shared_ptr& continuousBayesNet, - const HybridGaussianFactorGraph::shared_ptr& discreteGraph) const { - if (continuous_ordering_.size() > 0 && discrete_ordering_.size() > 0) { - // Get the last continuous conditional - // which will have all the discrete keys - HybridConditional::shared_ptr last_conditional = - continuousBayesNet->at(continuousBayesNet->size() - 1); - DiscreteKeys discrete_keys = last_conditional->discreteKeys(); - - // DecisionTree for P'(X|M, Z) for all mode sequences M - const AlgebraicDecisionTree probPrimeTree = - graph_.continuousProbPrimes(discrete_keys, continuousBayesNet); - - // Add the model selection factor P(M|Z) - discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree)); - } - return discreteGraph; -} - -/* ************************************************************************* */ -std::pair, - boost::shared_ptr> -HybridEliminationTree::eliminateDiscrete( - Eliminate function, - const HybridGaussianFactorGraph::shared_ptr& discreteGraph) const { - HybridBayesNet::shared_ptr discreteBayesNet; - HybridGaussianFactorGraph::shared_ptr finalGraph; - if (discrete_ordering_.size() > 0) { - This discrete_etree(*discreteGraph, VariableIndex(*discreteGraph), - discrete_ordering_); - - std::tie(discreteBayesNet, finalGraph) = - discrete_etree.Base::eliminate(function); - - } else { - discreteBayesNet = boost::make_shared(); - finalGraph = discreteGraph; - } - - return std::make_pair(discreteBayesNet, finalGraph); -} - -/* ************************************************************************* */ -std::pair, - boost::shared_ptr> -HybridEliminationTree::eliminate(Eliminate function) const { - // Perform continuous elimination - HybridBayesNet::shared_ptr bayesNet; - HybridGaussianFactorGraph::shared_ptr discreteGraph; - std::tie(bayesNet, discreteGraph) = this->eliminateContinuous(function); - - // If we have eliminated continuous variables - // and have discrete variables to eliminate, - // then compute P(X | M, Z) - HybridGaussianFactorGraph::shared_ptr updatedDiscreteGraph = - addProbPrimes(bayesNet, discreteGraph); - - // Perform discrete elimination - HybridBayesNet::shared_ptr discreteBayesNet; - HybridGaussianFactorGraph::shared_ptr finalGraph; - std::tie(discreteBayesNet, finalGraph) = - eliminateDiscrete(function, updatedDiscreteGraph); - - // Add the discrete conditionals to the hybrid conditionals - bayesNet->add(*discreteBayesNet); - - return std::make_pair(bayesNet, finalGraph); -} - } // namespace gtsam diff --git a/gtsam/hybrid/HybridEliminationTree.h b/gtsam/hybrid/HybridEliminationTree.h index 4802c2f89f..941fefa5a5 100644 --- a/gtsam/hybrid/HybridEliminationTree.h +++ b/gtsam/hybrid/HybridEliminationTree.h @@ -26,21 +26,6 @@ namespace gtsam { /** * Elimination Tree type for Hybrid Factor Graphs. * - * The elimination tree helps in elimination by specifying the parents to which - * the "combined factor" after elimination should be assigned to. - * - * In the case of hybrid elimination, we use the elimination tree to store the - * intermediate results. - * Concretely, we wish to estimate the unnormalized probability P'(X | M, Z) for - * each mode assignment M. - * P'(X | M, Z) can be computed by estimating X* which maximizes the continuous - * factor graph given the discrete modes, and then computing the unnormalized - * probability as P(X=X* | M, Z). - * - * This is done by eliminating the (continuous) factor graph present at each - * leaf of the discrete tree formed for the discrete sequence and using the - * inferred X* to compute the `probPrime`. - * * @ingroup hybrid */ class GTSAM_EXPORT HybridEliminationTree @@ -48,12 +33,6 @@ class GTSAM_EXPORT HybridEliminationTree private: friend class ::EliminationTreeTester; - Ordering continuous_ordering_, discrete_ordering_; - /// Used to store the original factor graph to eliminate - HybridGaussianFactorGraph graph_; - /// Store the provided variable index. - VariableIndex variable_index_; - public: typedef EliminationTree Base; ///< Base class @@ -87,75 +66,6 @@ class GTSAM_EXPORT HybridEliminationTree /** Test whether the tree is equal to another */ bool equals(const This& other, double tol = 1e-9) const; - - protected: - /** - * @brief Helper method to eliminate continuous variables. - * - * If no continuous variables exist, return an empty bayes net - * and the original graph. - * - * @param function Elimination function for hybrid elimination. - * @return std::pair, - * boost::shared_ptr > - */ - std::pair, - boost::shared_ptr> - eliminateContinuous(Eliminate function) const; - - /** - * @brief Compute the unnormalized probability P'(X | M, Z) - * for the factor graph in each leaf of the discrete tree. - * The discrete decision tree formed as a result is added to the - * `discreteGraph` for discrete elimination. - * - * @param continuousBayesNet The bayes nets corresponding to - * the eliminated continuous variables. - * @param discreteGraph Factor graph consisting of factors - * on discrete variables only. - * @return boost::shared_ptr - */ - boost::shared_ptr addProbPrimes( - const HybridBayesNet::shared_ptr& continuousBayesNet, - const HybridGaussianFactorGraph::shared_ptr& discreteGraph) const; - - /** - * @brief Helper method to eliminate the discrete variables after the - * continuous variables have been eliminated. - * - * If there are no discrete variables, return an empty bayes net and the - * discreteGraph which is passed in. - * - * @param function Hybrid elimination function - * @param discreteGraph The factor graph with the factor ϕ(X | M, Z). - * @return std::pair, - * boost::shared_ptr > - */ - std::pair, - boost::shared_ptr> - eliminateDiscrete( - Eliminate function, - const HybridGaussianFactorGraph::shared_ptr& discreteGraph) const; - - public: - /** - * @brief Override the EliminationTree eliminate method - * so we can perform hybrid elimination correctly. - * - * @param function Hybrid elimination function - * @return std::pair, - * boost::shared_ptr > - */ - std::pair, - boost::shared_ptr> - eliminate(Eliminate function) const; - - Ordering continuousOrdering() const { return continuous_ordering_; } - Ordering discreteOrdering() const { return discrete_ordering_; } - - /// Store the provided variable index. - VariableIndex variableIndex() const { return variable_index_; } - HybridGaussianFactorGraph graph() const { return graph_; } }; } // namespace gtsam diff --git a/gtsam/hybrid/HybridJunctionTree.cpp b/gtsam/hybrid/HybridJunctionTree.cpp index 43b043e304..422c200a43 100644 --- a/gtsam/hybrid/HybridJunctionTree.cpp +++ b/gtsam/hybrid/HybridJunctionTree.cpp @@ -142,8 +142,7 @@ struct HybridConstructorTraversalData { /* ************************************************************************* */ HybridJunctionTree::HybridJunctionTree( - const HybridEliminationTree& eliminationTree) - : etree_(eliminationTree) { + const HybridEliminationTree& eliminationTree) { gttic(JunctionTree_FromEliminationTree); // Here we rely on the BayesNet having been produced by this elimination tree, // such that the conditionals are arranged in DFS post-order. We traverse the @@ -170,141 +169,4 @@ HybridJunctionTree::HybridJunctionTree( Base::remainingFactors_ = eliminationTree.remainingFactors(); } -/* ************************************************************************* */ -std::pair, - boost::shared_ptr> -HybridJunctionTree::eliminateContinuous( - const Eliminate& function, const HybridGaussianFactorGraph& graph, - const Ordering& continuous_ordering) const { - HybridBayesTree::shared_ptr continuousBayesTree; - HybridGaussianFactorGraph::shared_ptr discreteGraph; - - if (continuous_ordering.size() > 0) { - HybridEliminationTree continuous_etree(graph, etree_.variableIndex(), - continuous_ordering); - - This continuous_junction_tree(continuous_etree); - std::tie(continuousBayesTree, discreteGraph) = - continuous_junction_tree.Base::eliminate(function); - - } else { - continuousBayesTree = boost::make_shared(); - discreteGraph = boost::make_shared(graph); - } - - return std::make_pair(continuousBayesTree, discreteGraph); -} -/* ************************************************************************* */ -std::pair, - boost::shared_ptr> -HybridJunctionTree::eliminateDiscrete( - const Eliminate& function, - const HybridBayesTree::shared_ptr& continuousBayesTree, - const HybridGaussianFactorGraph::shared_ptr& discreteGraph, - const Ordering& discrete_ordering) const { - HybridBayesTree::shared_ptr updatedBayesTree; - HybridGaussianFactorGraph::shared_ptr finalGraph; - if (discrete_ordering.size() > 0) { - HybridEliminationTree discrete_etree( - *discreteGraph, VariableIndex(*discreteGraph), discrete_ordering); - - This discrete_junction_tree(discrete_etree); - - std::tie(updatedBayesTree, finalGraph) = - discrete_junction_tree.Base::eliminate(function); - - // Get the clique with all the discrete keys. - // There should only be 1 clique. - const HybridBayesTree::sharedClique discrete_clique = - (*updatedBayesTree)[discrete_ordering.at(0)]; - - std::set clique_set; - for (auto node : continuousBayesTree->nodes()) { - clique_set.insert(node.second); - } - - // Set the root of the bayes tree as the discrete clique - for (auto clique : clique_set) { - if (clique->conditional()->parents() == - discrete_clique->conditional()->frontals()) { - updatedBayesTree->addClique(clique, discrete_clique); - - } else { - if (clique->parent()) { - // Remove the clique from the children of the parents since it will - // get added again in addClique. - auto clique_it = std::find(clique->parent()->children.begin(), - clique->parent()->children.end(), clique); - clique->parent()->children.erase(clique_it); - updatedBayesTree->addClique(clique, clique->parent()); - } else { - updatedBayesTree->addClique(clique); - } - } - } - } else { - updatedBayesTree = continuousBayesTree; - finalGraph = discreteGraph; - } - - return std::make_pair(updatedBayesTree, finalGraph); -} - -/* ************************************************************************* */ -boost::shared_ptr HybridJunctionTree::addProbPrimes( - const HybridBayesTree::shared_ptr& continuousBayesTree, - const HybridGaussianFactorGraph::shared_ptr& discreteGraph) const { - Ordering continuous_ordering = etree_.continuousOrdering(); - Ordering discrete_ordering = etree_.discreteOrdering(); - - // If we have eliminated continuous variables - // and have discrete variables to eliminate, - // then compute P(X | M, Z) - if (continuous_ordering.size() > 0 && discrete_ordering.size() > 0) { - // Collect all the discrete keys - DiscreteKeys discrete_keys; - for (auto node : continuousBayesTree->nodes()) { - auto node_dkeys = node.second->conditional()->discreteKeys(); - discrete_keys.insert(discrete_keys.end(), node_dkeys.begin(), - node_dkeys.end()); - } - // Remove duplicates and convert back to DiscreteKeys - std::set dkeys_set(discrete_keys.begin(), discrete_keys.end()); - discrete_keys = DiscreteKeys(dkeys_set.begin(), dkeys_set.end()); - - FactorGraphType graph = etree_.graph(); - - // DecisionTree for P'(X|M, Z) for all mode sequences M - const AlgebraicDecisionTree probPrimeTree = - graph.continuousProbPrimes(discrete_keys, continuousBayesTree); - - // Add the model selection factor P(M|Z) - discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree)); - } - - return discreteGraph; -} - -/* ************************************************************************* */ -std::pair -HybridJunctionTree::eliminate(const Eliminate& function) const { - Ordering continuous_ordering = etree_.continuousOrdering(); - Ordering discrete_ordering = etree_.discreteOrdering(); - - FactorGraphType graph = etree_.graph(); - - // Eliminate continuous - BayesTreeType::shared_ptr continuousBayesTree; - FactorGraphType::shared_ptr discreteGraph; - std::tie(continuousBayesTree, discreteGraph) = - this->eliminateContinuous(function, graph, continuous_ordering); - - FactorGraphType::shared_ptr updatedDiscreteGraph = - this->addProbPrimes(continuousBayesTree, discreteGraph); - - // Eliminate discrete variables to get the discrete bayes tree. - return this->eliminateDiscrete(function, continuousBayesTree, - updatedDiscreteGraph, discrete_ordering); -} - } // namespace gtsam diff --git a/gtsam/hybrid/HybridJunctionTree.h b/gtsam/hybrid/HybridJunctionTree.h index d0473c33d8..29fa24809e 100644 --- a/gtsam/hybrid/HybridJunctionTree.h +++ b/gtsam/hybrid/HybridJunctionTree.h @@ -51,10 +51,6 @@ class HybridEliminationTree; */ class GTSAM_EXPORT HybridJunctionTree : public JunctionTree { - /// Record the elimination tree for use in hybrid elimination. - HybridEliminationTree etree_; - /// Store the provided variable index. - VariableIndex variable_index_; public: typedef JunctionTree @@ -72,71 +68,6 @@ class GTSAM_EXPORT HybridJunctionTree * @return The elimination tree */ HybridJunctionTree(const HybridEliminationTree& eliminationTree); - - protected: - /** - * @brief Eliminate all the continuous variables from the factor graph. - * - * @param function The hybrid elimination function. - * @param graph The factor graph to eliminate. - * @param continuous_ordering The ordering of continuous variables. - * @return std::pair, - * boost::shared_ptr> - */ - std::pair, - boost::shared_ptr> - eliminateContinuous(const Eliminate& function, - const HybridGaussianFactorGraph& graph, - const Ordering& continuous_ordering) const; - - /** - * @brief Eliminate all the discrete variables in the hybrid factor graph. - * - * @param function The hybrid elimination function. - * @param continuousBayesTree The bayes tree corresponding to - * the eliminated continuous variables. - * @param discreteGraph Factor graph of factors containing - * only discrete variables. - * @param discrete_ordering The elimination ordering for - * the discrete variables. - * @return std::pair, - * boost::shared_ptr> - */ - std::pair, - boost::shared_ptr> - eliminateDiscrete(const Eliminate& function, - const HybridBayesTree::shared_ptr& continuousBayesTree, - const HybridGaussianFactorGraph::shared_ptr& discreteGraph, - const Ordering& discrete_ordering) const; - - /** - * @brief Compute the unnormalized probability P'(X | M, Z) - * for the factor graph in each leaf of the discrete tree. - * The discrete decision tree formed as a result is added to the - * `discreteGraph` for discrete elimination. - * - * @param continuousBayesTree The bayes tree corresponding to - * the eliminated continuous variables. - * @param discreteGraph Factor graph consisting of factors - * on discrete variables only. - * @return boost::shared_ptr - */ - boost::shared_ptr addProbPrimes( - const HybridBayesTree::shared_ptr& continuousBayesTree, - const HybridGaussianFactorGraph::shared_ptr& discreteGraph) const; - - public: - /** - * @brief Override the eliminate method so we can - * perform hybrid elimination correctly. - * - * @param function The hybrid elimination function. - * @return std::pair, - * boost::shared_ptr> - */ - std::pair, - boost::shared_ptr> - eliminate(const Eliminate& function) const; }; } // namespace gtsam From 62bc9f20a3b36fbcb1840b9d600d68239f549063 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 10 Dec 2022 10:35:46 +0530 Subject: [PATCH 11/12] update hybrid elimination and corresponding tests --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 12 +++-- gtsam/hybrid/HybridSmoother.cpp | 2 +- gtsam/hybrid/tests/testHybridEstimation.cpp | 49 +++++-------------- .../tests/testHybridGaussianFactorGraph.cpp | 2 +- gtsam/hybrid/tests/testHybridGaussianISAM.cpp | 2 +- .../tests/testHybridNonlinearFactorGraph.cpp | 11 +---- .../hybrid/tests/testHybridNonlinearISAM.cpp | 2 +- 7 files changed, 25 insertions(+), 55 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index c430fac2c6..de237b0491 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -172,8 +172,13 @@ discreteElimination(const HybridGaussianFactorGraph &factors, } } + // std::cout << "Eliminate For MPE" << std::endl; auto result = EliminateForMPE(dfg, frontalKeys); - + // std::cout << "discrete elimination done!" << std::endl; + // dfg.print(); + // std::cout << "\n\n\n" << std::endl; + // result.first->print(); + // result.second->print(); return {boost::make_shared(result.first), boost::make_shared(result.second)}; } @@ -262,7 +267,9 @@ hybridElimination(const HybridGaussianFactorGraph &factors, if (!factor) { return 0.0; // If nullptr, return 0.0 probability } else { - return 1.0; + double error = + 0.5 * std::abs(factor->augmentedInformation().determinant()); + return std::exp(-error); } }; DecisionTree fdt(separatorFactors, factorProb); @@ -550,5 +557,4 @@ HybridGaussianFactorGraph::separateContinuousDiscreteOrdering( return std::make_pair(continuous_ordering, discrete_ordering); } - } // namespace gtsam diff --git a/gtsam/hybrid/HybridSmoother.cpp b/gtsam/hybrid/HybridSmoother.cpp index 7400053bfc..07a7a4e77a 100644 --- a/gtsam/hybrid/HybridSmoother.cpp +++ b/gtsam/hybrid/HybridSmoother.cpp @@ -32,7 +32,7 @@ void HybridSmoother::update(HybridGaussianFactorGraph graph, addConditionals(graph, hybridBayesNet_, ordering); // Eliminate. - auto bayesNetFragment = graph.eliminateSequential(); + auto bayesNetFragment = graph.eliminateSequential(ordering); /// Prune if (maxNrLeaves) { diff --git a/gtsam/hybrid/tests/testHybridEstimation.cpp b/gtsam/hybrid/tests/testHybridEstimation.cpp index 85be0ab31c..39c5eea157 100644 --- a/gtsam/hybrid/tests/testHybridEstimation.cpp +++ b/gtsam/hybrid/tests/testHybridEstimation.cpp @@ -15,6 +15,7 @@ * @author Varun Agrawal */ +#include #include #include #include @@ -280,8 +281,10 @@ TEST(HybridEstimation, Probability) { VectorValues values = bayes_net->optimize(); - expected_errors.push_back(linear_graph->error(values)); - expected_prob_primes.push_back(linear_graph->probPrime(values)); + double error = linear_graph->error(values); + expected_errors.push_back(error); + double prob_prime = linear_graph->probPrime(values); + expected_prob_primes.push_back(prob_prime); } // Switching example of robot moving in 1D with given measurements and equal @@ -291,51 +294,21 @@ TEST(HybridEstimation, Probability) { auto graph = switching.linearizedFactorGraph; Ordering ordering = getOrdering(graph, HybridGaussianFactorGraph()); - AlgebraicDecisionTree expected_probPrimeTree = probPrimeTree(graph); - - // Eliminate continuous - Ordering continuous_ordering(graph.continuousKeys()); - HybridBayesNet::shared_ptr bayesNet; - HybridGaussianFactorGraph::shared_ptr discreteGraph; - std::tie(bayesNet, discreteGraph) = - graph.eliminatePartialSequential(continuous_ordering); - - // Get the last continuous conditional which will have all the discrete keys - auto last_conditional = bayesNet->at(bayesNet->size() - 1); - DiscreteKeys discrete_keys = last_conditional->discreteKeys(); - - const std::vector assignments = - DiscreteValues::CartesianProduct(discrete_keys); - - // Reverse discrete keys order for correct tree construction - std::reverse(discrete_keys.begin(), discrete_keys.end()); - - // Create a decision tree of all the different VectorValues - DecisionTree delta_tree = - graph.continuousDelta(discrete_keys, bayesNet, assignments); - - AlgebraicDecisionTree probPrimeTree = - graph.continuousProbPrimes(discrete_keys, bayesNet); - - EXPECT(assert_equal(expected_probPrimeTree, probPrimeTree)); + HybridBayesNet::shared_ptr bayesNet = graph.eliminateSequential(ordering); + auto discreteConditional = bayesNet->atDiscrete(bayesNet->size() - 3); // Test if the probPrimeTree matches the probability of // the individual factor graphs for (size_t i = 0; i < pow(2, K - 1); i++) { - Assignment discrete_assignment; + DiscreteValues discrete_assignment; for (size_t v = 0; v < discrete_seq_map[i].size(); v++) { discrete_assignment[M(v)] = discrete_seq_map[i][v]; } - EXPECT_DOUBLES_EQUAL(expected_prob_primes.at(i), - probPrimeTree(discrete_assignment), 1e-8); + double discrete_transition_prob = 0.25; + EXPECT_DOUBLES_EQUAL(expected_prob_primes.at(i) * discrete_transition_prob, + (*discreteConditional)(discrete_assignment), 1e-8); } - discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree)); - - Ordering discrete(graph.discreteKeys()); - auto discreteBayesNet = discreteGraph->eliminateSequential(); - bayesNet->add(*discreteBayesNet); - HybridValues hybrid_values = bayesNet->optimize(); // This is the correct sequence as designed diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 6288bcd93f..8954f2503c 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -277,7 +277,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalTwoClique) { std::tie(hbt, remaining) = hfg.eliminatePartialMultifrontal(ordering_full); // 9 cliques in the bayes tree and 0 remaining variables to eliminate. - EXPECT_LONGS_EQUAL(7, hbt->size()); + EXPECT_LONGS_EQUAL(9, hbt->size()); EXPECT_LONGS_EQUAL(0, remaining->size()); /* diff --git a/gtsam/hybrid/tests/testHybridGaussianISAM.cpp b/gtsam/hybrid/tests/testHybridGaussianISAM.cpp index 4ba6f88a50..8e79021327 100644 --- a/gtsam/hybrid/tests/testHybridGaussianISAM.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianISAM.cpp @@ -178,7 +178,7 @@ TEST(HybridGaussianElimination, IncrementalInference) { // Test the probability values with regression tests. DiscreteValues assignment; - EXPECT(assert_equal(0.166667, m00_prob, 1e-5)); + EXPECT(assert_equal(0.0619233, m00_prob, 1e-5)); assignment[M(0)] = 0; assignment[M(1)] = 0; EXPECT(assert_equal(0.0619233, (*discreteConditional)(assignment), 1e-5)); diff --git a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp index e3b7f761a7..a0a6f7d955 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp @@ -372,8 +372,7 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) { dynamic_pointer_cast(hybridDiscreteFactor->inner()); CHECK(discreteFactor); EXPECT_LONGS_EQUAL(1, discreteFactor->discreteKeys().size()); - // All leaves should be probability 1 since this is not P*(X|M,Z) - EXPECT(discreteFactor->root_->isLeaf()); + EXPECT(discreteFactor->root_->isLeaf() == false); // TODO(Varun) Test emplace_discrete } @@ -441,14 +440,6 @@ TEST(HybridFactorGraph, Full_Elimination) { discrete_fg.push_back(df->inner()); } - // Get the probabilit P*(X | M, Z) - DiscreteKeys discrete_keys = - remainingFactorGraph_partial->at(2)->discreteKeys(); - AlgebraicDecisionTree probPrimeTree = - linearizedFactorGraph.continuousProbPrimes(discrete_keys, - hybridBayesNet_partial); - discrete_fg.add(DecisionTreeFactor(discrete_keys, probPrimeTree)); - ordering.clear(); for (size_t k = 0; k < self.K - 1; k++) ordering += M(k); discreteBayesNet = diff --git a/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp b/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp index 46d5c43245..2a1932ac76 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp @@ -197,7 +197,7 @@ TEST(HybridNonlinearISAM, IncrementalInference) { // Test the probability values with regression tests. DiscreteValues assignment; - EXPECT(assert_equal(0.166667, m00_prob, 1e-5)); + EXPECT(assert_equal(0.0619233, m00_prob, 1e-5)); assignment[M(0)] = 0; assignment[M(1)] = 0; EXPECT(assert_equal(0.0619233, (*discreteConditional)(assignment), 1e-5)); From 6beffeb0c1d0eb4d098aef7ac88d198e9b3bd9c6 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 10 Dec 2022 12:46:48 +0530 Subject: [PATCH 12/12] remove commented out code --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index de237b0491..5c18a94b55 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -172,13 +172,8 @@ discreteElimination(const HybridGaussianFactorGraph &factors, } } - // std::cout << "Eliminate For MPE" << std::endl; auto result = EliminateForMPE(dfg, frontalKeys); - // std::cout << "discrete elimination done!" << std::endl; - // dfg.print(); - // std::cout << "\n\n\n" << std::endl; - // result.first->print(); - // result.second->print(); + return {boost::make_shared(result.first), boost::make_shared(result.second)}; }