Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Additional Nonlinear Hybrid #1277

Merged
merged 11 commits into from
Aug 22, 2022
Next Next commit
more tests running
  • Loading branch information
varunagrawal committed Aug 12, 2022
commit aa486586264e082c872f9277092590060fea30f8
95 changes: 46 additions & 49 deletions gtsam/hybrid/tests/testHybridIncremental.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,6 @@ TEST(HybridGaussianElimination, IncrementalInference) {
graph1.push_back(switching.linearizedFactorGraph.at(3)); // P(X2)
graph1.push_back(switching.linearizedFactorGraph.at(5)); // P(M1)

//TODO(Varun) we cannot enforce ordering
// // Create ordering.
// Ordering ordering1;
// ordering1 += X(1);
// ordering1 += X(2);

// Run update step
isam.update(graph1);

Expand All @@ -133,14 +127,7 @@ TEST(HybridGaussianElimination, IncrementalInference) {
graph2.push_back(switching.linearizedFactorGraph.at(4)); // P(X3)
graph2.push_back(switching.linearizedFactorGraph.at(6)); // P(M1, M2)

//TODO(Varun) we cannot enforce ordering
// // Create ordering.
// Ordering ordering2;
// ordering2 += X(2);
// ordering2 += X(3);

isam.update(graph2);
GTSAM_PRINT(isam);

/********************************************************/
// Run batch elimination so we can compare results.
Expand All @@ -150,68 +137,78 @@ TEST(HybridGaussianElimination, IncrementalInference) {
ordering += X(3);

// Now we calculate the actual factors using full elimination
HybridBayesNet::shared_ptr expectedHybridBayesNet;
HybridBayesTree::shared_ptr expectedHybridBayesTree;
HybridGaussianFactorGraph::shared_ptr expectedRemainingGraph;
std::tie(expectedHybridBayesNet, expectedRemainingGraph) =
switching.linearizedFactorGraph.eliminatePartialSequential(ordering);
std::tie(expectedHybridBayesTree, expectedRemainingGraph) =
switching.linearizedFactorGraph.eliminatePartialMultifrontal(ordering);

// The densities on X(1) should be the same
auto x1_conditional =
dynamic_pointer_cast<GaussianMixture>(isam[X(1)]->conditional()->inner());
EXPECT(
assert_equal(*x1_conditional, *(expectedHybridBayesNet->atGaussian(0))));
auto actual_x1_conditional = dynamic_pointer_cast<GaussianMixture>(
(*expectedHybridBayesTree)[X(1)]->conditional()->inner());
EXPECT(assert_equal(*x1_conditional, *actual_x1_conditional));

// The densities on X(2) should be the same
auto x2_conditional =
dynamic_pointer_cast<GaussianMixture>(isam[X(2)]->conditional()->inner());
EXPECT(
assert_equal(*x2_conditional, *(expectedHybridBayesNet->atGaussian(1))));

// // The densities on X(3) should be the same
// auto x3_conditional =
// dynamic_pointer_cast<GaussianMixture>(isam[X(3)]->conditional()->inner());
// EXPECT(
// assert_equal(*x3_conditional, *(expectedHybridBayesNet->atGaussian(2))));

GTSAM_PRINT(*expectedHybridBayesNet);

// we only do the manual continuous elimination for 0,0
// the other discrete probabilities on M(2) are calculated the same way
auto actual_x2_conditional = dynamic_pointer_cast<GaussianMixture>(
(*expectedHybridBayesTree)[X(2)]->conditional()->inner());
EXPECT(assert_equal(*x2_conditional, *actual_x2_conditional));

// The densities on X(3) should be the same
auto x3_conditional =
dynamic_pointer_cast<GaussianMixture>(isam[X(3)]->conditional()->inner());
auto actual_x3_conditional = dynamic_pointer_cast<GaussianMixture>(
(*expectedHybridBayesTree)[X(2)]->conditional()->inner());
EXPECT(assert_equal(*x3_conditional, *actual_x3_conditional));

// We only perform manual continuous elimination for 0,0.
// The other discrete probabilities on M(2) are calculated the same way
auto m00_prob = [&]() {
GaussianFactorGraph gf;
// gf.add(switching.linearizedFactorGraph.gaussianGraph().at(3));
auto x2_prior = boost::dynamic_pointer_cast<HybridGaussianFactor>(
switching.linearizedFactorGraph.at(3))->inner();
gf.add(x2_prior);

DiscreteValues m00;
m00[M(1)] = 0, m00[M(2)] = 0;
// auto dcMixture =
// dynamic_pointer_cast<DCGaussianMixtureFactor>(graph2.dcGraph().at(0));
// gf.add(dcMixture->factors()(m00));
// auto x2_mixed =
// boost::dynamic_pointer_cast<GaussianMixture>(hybridBayesNet.at(1));
// gf.add(x2_mixed->factors()(m00));
// P(X2, X3 | M2)
auto dcMixture =
dynamic_pointer_cast<GaussianMixtureFactor>(graph2.at(0));
gf.add(dcMixture->factors()(m00));

auto x2_mixed =
boost::dynamic_pointer_cast<GaussianMixture>(isam[X(2)]->conditional()->inner());
// Perform explicit cast so we can add the conditional to `gf`.
auto x2_cond = boost::dynamic_pointer_cast<GaussianFactor>(
x2_mixed->conditionals()(m00));
gf.add(x2_cond);

auto result_gf = gf.eliminateSequential();
return gf.probPrime(result_gf->optimize());
}();

/// Test if the probability values are as expected with regression tests.
// DiscreteValues assignment;
// EXPECT(assert_equal(m00_prob, 0.60656, 1e-5));
// assignment[M(1)] = 0;
// assignment[M(2)] = 0;
// EXPECT(assert_equal(m00_prob, (*discreteFactor)(assignment), 1e-5));
auto discreteConditional = isam[M(1)]->conditional()->asDiscreteConditional();
// Test if the probability values are as expected with regression tests.
// DiscreteValues assignment;
// EXPECT(assert_equal(m00_prob, 0.60656, 1e-5));
// assignment[M(1)] = 0;
// assignment[M(2)] = 0;
// EXPECT(assert_equal(m00_prob, (*discreteConditional)(assignment), 1e-5));
// assignment[M(1)] = 1;
// assignment[M(2)] = 0;
// EXPECT(assert_equal(0.612477, (*discreteFactor)(assignment), 1e-5));
// EXPECT(assert_equal(0.612477, (*discreteConditional)(assignment), 1e-5));
// assignment[M(1)] = 0;
// assignment[M(2)] = 1;
// EXPECT(assert_equal(0.999952, (*discreteFactor)(assignment), 1e-5));
// EXPECT(assert_equal(0.999952, (*discreteConditional)(assignment), 1e-5));
// assignment[M(1)] = 1;
// assignment[M(2)] = 1;
// EXPECT(assert_equal(1.0, (*discreteFactor)(assignment), 1e-5));
// EXPECT(assert_equal(1.0, (*discreteConditional)(assignment), 1e-5));

// DiscreteFactorGraph dfg;
// dfg.add(*discreteFactor);
// dfg.add(discreteFactor_m1);
// dfg.add(*discreteConditional);
// dfg.add(discreteConditional_m1);
// dfg.add_factors(switching.linearizedFactorGraph.discreteGraph());

// // Check if the chordal graph generated from incremental elimination
Expand Down