Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hybrid/simplify #1388

Merged
merged 29 commits into from
Jan 17, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
34a9aef
normalizationConstants returns all constants as a DecisionTreeFactor
dellaert Jan 12, 2023
1dcc6dd
All tests still work with zero constant!
dellaert Jan 12, 2023
03ad393
Removed FactorAndConstant, no longer needed
dellaert Jan 13, 2023
906330f
Add discrete contribution to logProbability
dellaert Jan 13, 2023
681c75c
Expose toFactorGraph to wrapper
dellaert Jan 13, 2023
dfef2c2
Simplify elimination
dellaert Jan 13, 2023
070cdb7
insert_or_assign
dellaert Jan 14, 2023
96e3eb7
Some test refactoring
dellaert Jan 14, 2023
c22b2ca
Improved docs
dellaert Jan 15, 2023
5b0408c
Check for error>0 and proper normalization constant
dellaert Jan 16, 2023
191e614
Fix print
dellaert Jan 16, 2023
57e59d1
Compute log-normalization constant as the max of the individual norma…
dellaert Jan 16, 2023
7a41180
Refactored tests and removed incorrect (R not upper-triangular) test.
dellaert Jan 16, 2023
207c9b7
Implemented the "hidden constant" scheme.
dellaert Jan 17, 2023
3a446d7
Explicitly implement logNormalizationConstant
dellaert Jan 17, 2023
202a5a3
Fixed toFactorGraph and added test to verify
dellaert Jan 17, 2023
a5951d8
Fixed test to work with "hidden constant" scheme
dellaert Jan 17, 2023
8357fc7
Fix python tests (and expose HybridBayesNet.error)
dellaert Jan 17, 2023
e31884c
Eradicated GraphAndConstant
dellaert Jan 17, 2023
9af7236
Added DEBUG_MARGINALS flag
dellaert Jan 17, 2023
519b2bb
Added comment
dellaert Jan 17, 2023
32d69a3
Trap if conditional==null.
dellaert Jan 17, 2023
f4859f0
Fix logProbability tests
dellaert Jan 17, 2023
4283925
Ratio test succeeds on fg, but not on posterior yet,
dellaert Jan 17, 2023
b494a61
Removed obsolete normalizationConstants method
dellaert Jan 17, 2023
892759e
Add math related to hybrid classes
dellaert Jan 17, 2023
c3ca31f
Added partial elimination test
dellaert Jan 17, 2023
e444962
Added correction with the normalization constant in the second elimin…
dellaert Jan 17, 2023
f714c4a
Merge branch 'develop' into hybrid/simplify
dellaert Jan 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Simplify elimination
  • Loading branch information
dellaert committed Jan 16, 2023
commit dfef2c202ff2dd86d8a36086ef114f680aa3438f
16 changes: 6 additions & 10 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,19 +271,15 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
// If there are no more continuous parents, then we should create a
// DiscreteFactor here, with the error for each discrete choice.
if (continuousSeparator.empty()) {
auto factorProb = [&](const EliminationPair &conditionalAndFactor) {
// This is the probability q(μ) at the MLE point.
// conditionalAndFactor.second is a factor without keys, just containing the residual.
auto probPrime = [&](const GaussianMixtureFactor::sharedFactor &factor) {
// This is the unnormalized probability q(μ) at the mean.
// The factor has no keys, just contains the residual.
static const VectorValues kEmpty;
// return exp(-conditionalAndFactor.first->logNormalizationConstant());
// return exp(-conditionalAndFactor.first->logNormalizationConstant() - conditionalAndFactor.second->error(kEmpty));
return exp( - conditionalAndFactor.second->error(kEmpty));
// return 1.0;
return factor? exp(-factor->error(kEmpty)) : 1.0;
};

const DecisionTree<Key, double> fdt(eliminationResults, factorProb);
const auto discreteFactor =
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
const auto discreteFactor = boost::make_shared<DecisionTreeFactor>(
discreteSeparator, DecisionTree<Key, double>(newFactors, probPrime));

return {boost::make_shared<HybridConditional>(gaussianMixture),
discreteFactor};
Expand Down
26 changes: 25 additions & 1 deletion gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ TEST(HybridGaussianFactorGraph, assembleGraphTree) {
// Check that the factor graph unnormalized probability is proportional to the
// Bayes net probability for the given measurements.
bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements,
const HybridGaussianFactorGraph &fg, size_t num_samples = 10) {
const HybridGaussianFactorGraph &fg, size_t num_samples = 100) {
auto compute_ratio = [&](HybridValues *sample) -> double {
sample->update(measurements); // update sample with given measurements:
return bn.evaluate(*sample) / fg.probPrime(*sample);
Expand All @@ -670,6 +670,28 @@ bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements,
return true;
}

/* ****************************************************************************/
// Check that the posterior Bayes net unnormalized probability is proportional
// to the original Bayes net probability for the given measurements.
bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements,
const HybridBayesNet &posterior, size_t num_samples = 100) {
auto compute_ratio = [&](HybridValues *sample) -> double {
sample->update(measurements); // update sample with given measurements:
// return bn.evaluate(*sample) / fg.probPrime(*sample);
return bn.evaluate(*sample) / posterior.evaluate(*sample);
};

HybridValues sample = bn.sample(&kRng);
double expected_ratio = compute_ratio(&sample);

// Test ratios for a number of independent samples:
for (size_t i = 0; i < num_samples; i++) {
HybridValues sample = bn.sample(&kRng);
if (std::abs(expected_ratio - compute_ratio(&sample)) > 1e-6) return false;
}
return true;
}

/* ****************************************************************************/
// Check that eliminating tiny net with 1 measurement yields correct result.
TEST(HybridGaussianFactorGraph, EliminateTiny1) {
Expand All @@ -678,6 +700,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) {
const VectorValues measurements{{Z(0), Vector1(5.0)}};
auto bn = tiny::createHybridBayesNet(num_measurements);
auto fg = bn.toFactorGraph(measurements);
GTSAM_PRINT(bn);
EXPECT_LONGS_EQUAL(4, fg.size());

EXPECT(ratioTest(bn, measurements, fg));
Expand All @@ -701,6 +724,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) {
// Test elimination
const auto posterior = fg.eliminateSequential();
EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
GTSAM_PRINT(*posterior);

EXPECT(ratioTest(bn, measurements, *posterior));
}
Expand Down