Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Big Re-"Factor" of Hybrid #1374

Merged
merged 16 commits into from
Jan 8, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Made tests compile after purging HybridDiscreteFactor
  • Loading branch information
dellaert committed Jan 7, 2023
commit 18d4bdf4f46140bc709150f41c4f3cdc07441bdd
3 changes: 1 addition & 2 deletions gtsam/hybrid/tests/testHybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,8 +341,7 @@ TEST(HybridBayesNet, Sampling) {
KeyVector{X(0), X(1)}, DiscreteKeys{DiscreteKey(M(0), 2)}, factors);

DiscreteKey mode(M(0), 2);
auto discrete_prior = boost::make_shared<DiscreteDistribution>(mode, "1/1");
nfg.push_discrete(discrete_prior);
nfg.emplace_shared<DiscreteDistribution>(mode, "1/1");

Values initial;
double z0 = 0.0, z1 = 1.0;
Expand Down
6 changes: 3 additions & 3 deletions gtsam/hybrid/tests/testHybridBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,9 @@ TEST(HybridBayesTree, Optimize) {

DiscreteFactorGraph dfg;
for (auto&& f : *remainingFactorGraph) {
auto factor = dynamic_pointer_cast<HybridDiscreteFactor>(f);
dfg.push_back(
boost::dynamic_pointer_cast<DecisionTreeFactor>(factor->inner()));
auto discreteFactor = dynamic_pointer_cast<DecisionTreeFactor>(f);
assert(discreteFactor);
dfg.push_back(discreteFactor);
}

// Add the probabilities for each branch
Expand Down
12 changes: 6 additions & 6 deletions gtsam/hybrid/tests/testHybridEstimation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ Ordering getOrdering(HybridGaussianFactorGraph& factors,
const HybridGaussianFactorGraph& newFactors) {
factors += newFactors;
// Get all the discrete keys from the factors
KeySet allDiscrete = factors.discreteKeys();
KeySet allDiscrete = factors.discreteKeySet();

// Create KeyVector with continuous keys followed by discrete keys.
KeyVector newKeysDiscreteLast;
Expand Down Expand Up @@ -241,7 +241,7 @@ AlgebraicDecisionTree<Key> getProbPrimeTree(
const HybridGaussianFactorGraph& graph) {
HybridBayesNet::shared_ptr bayesNet;
HybridGaussianFactorGraph::shared_ptr remainingGraph;
Ordering continuous(graph.continuousKeys());
Ordering continuous(graph.continuousKeySet());
std::tie(bayesNet, remainingGraph) =
graph.eliminatePartialSequential(continuous);

Expand Down Expand Up @@ -296,14 +296,14 @@ TEST(HybridEstimation, Probability) {
auto graph = switching.linearizedFactorGraph;

// Continuous elimination
Ordering continuous_ordering(graph.continuousKeys());
Ordering continuous_ordering(graph.continuousKeySet());
HybridBayesNet::shared_ptr bayesNet;
HybridGaussianFactorGraph::shared_ptr discreteGraph;
std::tie(bayesNet, discreteGraph) =
graph.eliminatePartialSequential(continuous_ordering);

// Discrete elimination
Ordering discrete_ordering(graph.discreteKeys());
Ordering discrete_ordering(graph.discreteKeySet());
auto discreteBayesNet = discreteGraph->eliminateSequential(discrete_ordering);

// Add the discrete conditionals to make it a full bayes net.
Expand Down Expand Up @@ -346,7 +346,7 @@ TEST(HybridEstimation, ProbabilityMultifrontal) {
AlgebraicDecisionTree<Key> expected_probPrimeTree = getProbPrimeTree(graph);

// Eliminate continuous
Ordering continuous_ordering(graph.continuousKeys());
Ordering continuous_ordering(graph.continuousKeySet());
HybridBayesTree::shared_ptr bayesTree;
HybridGaussianFactorGraph::shared_ptr discreteGraph;
std::tie(bayesTree, discreteGraph) =
Expand All @@ -358,7 +358,7 @@ TEST(HybridEstimation, ProbabilityMultifrontal) {
auto last_conditional = (*bayesTree)[last_continuous_key]->conditional();
DiscreteKeys discrete_keys = last_conditional->discreteKeys();

Ordering discrete(graph.discreteKeys());
Ordering discrete(graph.discreteKeySet());
auto discreteBayesTree = discreteGraph->eliminateMultifrontal(discrete);

EXPECT_LONGS_EQUAL(1, discreteBayesTree->size());
Expand Down
19 changes: 9 additions & 10 deletions gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridDiscreteFactor.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
Expand Down Expand Up @@ -102,7 +101,7 @@ TEST(HybridGaussianFactorGraph, EliminateMultifrontal) {

// Add priors on x0 and c1
hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1));
hfg.add(HybridDiscreteFactor(DecisionTreeFactor(m, {2, 8})));
hfg.add(DecisionTreeFactor(m, {2, 8}));

Ordering ordering;
ordering.push_back(X(0));
Expand All @@ -116,24 +115,25 @@ TEST(HybridGaussianFactorGraph, EliminateMultifrontal) {
TEST(HybridGaussianFactorGraph, eliminateFullSequentialEqualChance) {
HybridGaussianFactorGraph hfg;

DiscreteKey m1(M(1), 2);

// Add prior on x0
hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1));

// Add factor between x0 and x1
hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1));

// Add a gaussian mixture factor ϕ(x1, c1)
DiscreteKey m1(M(1), 2);
DecisionTree<Key, GaussianFactor::shared_ptr> dt(
M(1), boost::make_shared<JacobianFactor>(X(1), I_3x3, Z_3x1),
boost::make_shared<JacobianFactor>(X(1), I_3x3, Vector3::Ones()));

hfg.add(GaussianMixtureFactor({X(1)}, {m1}, dt));

auto result =
hfg.eliminateSequential(Ordering::ColamdConstrainedLast(hfg, {M(1)}));
// Do elimination.
Ordering ordering = Ordering::ColamdConstrainedLast(hfg, {M(1)});
auto result = hfg.eliminateSequential(ordering);

auto dc = result->at(2)->asDiscrete();
CHECK(dc);
DiscreteValues dv;
dv[M(1)] = 0;
// Regression test
Expand Down Expand Up @@ -215,7 +215,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalCLG) {
// Hybrid factor P(x1|c1)
hfg.add(GaussianMixtureFactor({X(1)}, {m}, dt));
// Prior factor on c1
hfg.add(HybridDiscreteFactor(DecisionTreeFactor(m, {2, 8})));
hfg.add(DecisionTreeFactor(m, {2, 8}));

// Get a constrained ordering keeping c1 last
auto ordering_full = hfg.getHybridOrdering();
Expand Down Expand Up @@ -250,8 +250,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalTwoClique) {
hfg.add(GaussianMixtureFactor({X(2)}, {{M(1), 2}}, dt1));
}

hfg.add(HybridDiscreteFactor(
DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4")));
hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4"));

hfg.add(JacobianFactor(X(3), I_3x3, X(4), -I_3x3, Z_3x1));
hfg.add(JacobianFactor(X(4), I_3x3, X(5), -I_3x3, Z_3x1));
Expand Down
18 changes: 7 additions & 11 deletions gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,7 @@ TEST(HybridsGaussianElimination, Eliminate_x1) {
Ordering ordering;
ordering += X(1);

std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> result =
EliminateHybrid(factors, ordering);
auto result = EliminateHybrid(factors, ordering);
CHECK(result.first);
EXPECT_LONGS_EQUAL(1, result.first->nrFrontals());
CHECK(result.second);
Expand Down Expand Up @@ -350,7 +349,7 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) {
ordering += X(1);

HybridConditional::shared_ptr hybridConditionalMixture;
HybridFactor::shared_ptr factorOnModes;
boost::shared_ptr<Factor> factorOnModes;

std::tie(hybridConditionalMixture, factorOnModes) =
EliminateHybrid(factors, ordering);
Expand All @@ -364,12 +363,8 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) {
// 1 parent, which is the mode
EXPECT_LONGS_EQUAL(1, gaussianConditionalMixture->nrParents());

// This is now a HybridDiscreteFactor
auto hybridDiscreteFactor =
dynamic_pointer_cast<HybridDiscreteFactor>(factorOnModes);
// Access the type-erased inner object and convert to DecisionTreeFactor
auto discreteFactor =
dynamic_pointer_cast<DecisionTreeFactor>(hybridDiscreteFactor->inner());
// This is now a discreteFactor
auto discreteFactor = dynamic_pointer_cast<DecisionTreeFactor>(factorOnModes);
CHECK(discreteFactor);
EXPECT_LONGS_EQUAL(1, discreteFactor->discreteKeys().size());
EXPECT(discreteFactor->root_->isLeaf() == false);
Expand Down Expand Up @@ -436,8 +431,9 @@ TEST(HybridFactorGraph, Full_Elimination) {
DiscreteFactorGraph discrete_fg;
// TODO(Varun) Make this a function of HybridGaussianFactorGraph?
for (auto& factor : (*remainingFactorGraph_partial)) {
auto df = dynamic_pointer_cast<HybridDiscreteFactor>(factor);
discrete_fg.push_back(df->inner());
auto df = dynamic_pointer_cast<DecisionTreeFactor>(factor);
assert(df);
discrete_fg.push_back(df);
}

ordering.clear();
Expand Down
13 changes: 0 additions & 13 deletions gtsam/hybrid/tests/testSerializationHybrid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridDiscreteFactor.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianConditional.h>
Expand Down Expand Up @@ -83,18 +82,6 @@ TEST(HybridSerialization, HybridGaussianFactor) {
EXPECT(equalsBinary<HybridGaussianFactor>(factor));
}

/* ****************************************************************************/
// Test HybridDiscreteFactor serialization.
TEST(HybridSerialization, HybridDiscreteFactor) {
DiscreteKeys discreteKeys{{M(0), 2}};
const HybridDiscreteFactor factor(
DecisionTreeFactor(discreteKeys, std::vector<double>{0.4, 0.6}));

EXPECT(equalsObj<HybridDiscreteFactor>(factor));
EXPECT(equalsXML<HybridDiscreteFactor>(factor));
EXPECT(equalsBinary<HybridDiscreteFactor>(factor));
}

/* ****************************************************************************/
// Test GaussianMixtureFactor serialization.
TEST(HybridSerialization, GaussianMixtureFactor) {
Expand Down