Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hybrid Serialization #1368

Merged
merged 9 commits into from
Jan 4, 2023
Prev Previous commit
Next Next commit
serialization tests for HybridBayesNet and HybridBayesTree
  • Loading branch information
varunagrawal committed Jan 3, 2023
commit 230f65dd4496cac516574c713e9f924132e20e88
13 changes: 13 additions & 0 deletions gtsam/hybrid/HybridConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,19 @@ class GTSAM_EXPORT HybridConditional
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
ar& BOOST_SERIALIZATION_NVP(inner_);

// register the various casts based on the type of inner_
// https://www.boost.org/doc/libs/1_80_0/libs/serialization/doc/serialization.html#runtimecasting
if (isDiscrete()) {
boost::serialization::void_cast_register<DiscreteConditional, Factor>(
static_cast<DiscreteConditional*>(NULL), static_cast<Factor*>(NULL));
} else if (isContinuous()) {
boost::serialization::void_cast_register<GaussianConditional, Factor>(
static_cast<GaussianConditional*>(NULL), static_cast<Factor*>(NULL));
} else {
boost::serialization::void_cast_register<GaussianMixture, Factor>(
static_cast<GaussianMixture*>(NULL), static_cast<Factor*>(NULL));
}
}

}; // HybridConditional
Expand Down
16 changes: 0 additions & 16 deletions gtsam/hybrid/tests/testHybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
* @date December 2021
*/

#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
Expand All @@ -31,7 +30,6 @@

using namespace std;
using namespace gtsam;
using namespace gtsam::serializationTestHelpers;

using noiseModel::Isotropic;
using symbol_shorthand::M;
Expand Down Expand Up @@ -330,20 +328,6 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
discrete_conditional_tree->apply(checker);
}

/* ****************************************************************************/
// Test HybridBayesNet serialization.
TEST(HybridBayesNet, Serialization) {
Switching s(4);
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering));

// TODO(Varun) Serialization of inner factor doesn't work. Requires
// serialization support for all hybrid factors.
// EXPECT(equalsObj<HybridBayesNet>(hbn));
// EXPECT(equalsXML<HybridBayesNet>(hbn));
// EXPECT(equalsBinary<HybridBayesNet>(hbn));
}

/* ****************************************************************************/
// Test HybridBayesNet sampling.
TEST(HybridBayesNet, Sampling) {
Expand Down
16 changes: 0 additions & 16 deletions gtsam/hybrid/tests/testHybridBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,22 +220,6 @@ TEST(HybridBayesTree, Choose) {
EXPECT(assert_equal(expected_gbt, gbt));
}

/* ****************************************************************************/
// Test HybridBayesTree serialization.
TEST(HybridBayesTree, Serialization) {
Switching s(4);
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesTree hbt =
*(s.linearizedFactorGraph.eliminateMultifrontal(ordering));

using namespace gtsam::serializationTestHelpers;
// TODO(Varun) Serialization of inner factor doesn't work. Requires
// serialization support for all hybrid factors.
// EXPECT(equalsObj<HybridBayesTree>(hbt));
// EXPECT(equalsXML<HybridBayesTree>(hbt));
// EXPECT(equalsBinary<HybridBayesTree>(hbt));
}

/* ************************************************************************* */
int main() {
TestResult tr;
Expand Down
38 changes: 36 additions & 2 deletions gtsam/hybrid/tests/testSerializationHybrid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,19 @@
*/

#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridDiscreteFactor.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianConditional.h>

#include "Switching.h"

// Include for test suite
#include <CppUnitLite/TestHarness.h>

Expand All @@ -36,15 +41,17 @@ using symbol_shorthand::Z;

using namespace serializationTestHelpers;

BOOST_CLASS_EXPORT_GUID(Factor, "gtsam_Factor");
BOOST_CLASS_EXPORT_GUID(HybridFactor, "gtsam_HybridFactor");
BOOST_CLASS_EXPORT_GUID(JacobianFactor, "gtsam_JacobianFactor");
BOOST_CLASS_EXPORT_GUID(GaussianConditional, "gtsam_GaussianConditional");
BOOST_CLASS_EXPORT_GUID(DiscreteConditional, "gtsam_DiscreteConditional");

BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor");
using ADT = AlgebraicDecisionTree<Key>;
BOOST_CLASS_EXPORT_GUID(ADT, "gtsam_AlgebraicDecisionTree");
BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_DecisionTree_Leaf");
BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_DecisionTree_Choice");
BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_AlgebraicDecisionTree_Leaf");
BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_AlgebraicDecisionTree_Choice")

BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor, "gtsam_GaussianMixtureFactor");
BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor::Factors,
Expand All @@ -64,6 +71,8 @@ BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals::Choice,
// Needed since GaussianConditional::FromMeanAndStddev uses it
BOOST_CLASS_EXPORT_GUID(noiseModel::Isotropic, "gtsam_noiseModel_Isotropic");

BOOST_CLASS_EXPORT_GUID(HybridBayesNet, "gtsam_HybridBayesNet");

/* ****************************************************************************/
// Test HybridGaussianFactor serialization.
TEST(HybridSerialization, HybridGaussianFactor) {
Expand Down Expand Up @@ -137,6 +146,31 @@ TEST(HybridSerialization, GaussianMixture) {
EXPECT(equalsBinary<GaussianMixture>(gm));
}

/* ****************************************************************************/
// Test HybridBayesNet serialization.
TEST(HybridSerialization, HybridBayesNet) {
Switching s(2);
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering));

EXPECT(equalsObj<HybridBayesNet>(hbn));
EXPECT(equalsXML<HybridBayesNet>(hbn));
EXPECT(equalsBinary<HybridBayesNet>(hbn));
}

/* ****************************************************************************/
// Test HybridBayesTree serialization.
TEST(HybridSerialization, HybridBayesTree) {
Switching s(2);
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesTree hbt =
*(s.linearizedFactorGraph.eliminateMultifrontal(ordering));

EXPECT(equalsObj<HybridBayesTree>(hbt));
EXPECT(equalsXML<HybridBayesTree>(hbt));
EXPECT(equalsBinary<HybridBayesTree>(hbt));
}

/* ************************************************************************* */
int main() {
TestResult tr;
Expand Down