Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Big Re-"Factor" of Hybrid #1374

Merged
merged 16 commits into from
Jan 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions gtsam/hybrid/GaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ GaussianMixture::GaussianMixture(
Conditionals(discreteParents, conditionals)) {}

/* *******************************************************************************/
// TODO(dellaert): This is copy/paste: GaussianMixture should be derived from
// GaussianMixtureFactor, no?
dellaert marked this conversation as resolved.
Show resolved Hide resolved
GaussianFactorGraphTree GaussianMixture::add(
const GaussianFactorGraphTree &sum) const {
using Y = GraphAndConstant;
Expand Down
12 changes: 5 additions & 7 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(

/* ************************************************************************* */
void HybridBayesNet::updateDiscreteConditionals(
const DecisionTreeFactor::shared_ptr &prunedDecisionTree) {
KeyVector prunedTreeKeys = prunedDecisionTree->keys();
const DecisionTreeFactor &prunedDecisionTree) {
KeyVector prunedTreeKeys = prunedDecisionTree.keys();

// Loop with index since we need it later.
for (size_t i = 0; i < this->size(); i++) {
Expand All @@ -154,7 +154,7 @@ void HybridBayesNet::updateDiscreteConditionals(
auto discreteTree =
boost::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete);
DecisionTreeFactor::ADT prunedDiscreteTree =
discreteTree->apply(prunerFunc(*prunedDecisionTree, *conditional));
discreteTree->apply(prunerFunc(prunedDecisionTree, *conditional));

// Create the new (hybrid) conditional
KeyVector frontals(discrete->frontals().begin(),
Expand All @@ -173,9 +173,7 @@ void HybridBayesNet::updateDiscreteConditionals(
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
// Get the decision tree of only the discrete keys
auto discreteConditionals = this->discreteConditionals();
const DecisionTreeFactor::shared_ptr decisionTree =
boost::make_shared<DecisionTreeFactor>(
discreteConditionals->prune(maxNrLeaves));
const auto decisionTree = discreteConditionals->prune(maxNrLeaves);

this->updateDiscreteConditionals(decisionTree);

Expand All @@ -194,7 +192,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
if (auto gm = conditional->asMixture()) {
// Make a copy of the Gaussian mixture and prune it!
auto prunedGaussianMixture = boost::make_shared<GaussianMixture>(*gm);
prunedGaussianMixture->prune(*decisionTree); // imperative :-(
prunedGaussianMixture->prune(decisionTree); // imperative :-(

// Type-erase and add to the pruned Bayes Net fragment.
prunedBayesNetFragment.push_back(prunedGaussianMixture);
Expand Down
49 changes: 33 additions & 16 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,33 +51,51 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/// @{

/// GTSAM-style printing
void print(
const std::string &s = "",
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
void print(const std::string &s = "", const KeyFormatter &formatter =
DefaultKeyFormatter) const override;

/// GTSAM-style equals
bool equals(const This& fg, double tol = 1e-9) const;
bool equals(const This &fg, double tol = 1e-9) const;

/// @}
/// @name Standard Interface
/// @{

/// Add HybridConditional to Bayes Net
using Base::emplace_shared;
/**
* @brief Add a hybrid conditional using a shared_ptr.
*
* This is the "native" push back, as this class stores hybrid conditionals.
*/
void push_back(boost::shared_ptr<HybridConditional> conditional) {
factors_.push_back(conditional);
}

/// Add a conditional directly using a pointer.
/**
* Preferred: add a conditional directly using a pointer.
*
* Examples:
* hbn.emplace_back(new GaussianMixture(...)));
* hbn.emplace_back(new GaussianConditional(...)));
* hbn.emplace_back(new DiscreteConditional(...)));
*/
template <class Conditional>
void emplace_back(Conditional *conditional) {
factors_.push_back(boost::make_shared<HybridConditional>(
boost::shared_ptr<Conditional>(conditional)));
}

/// Add a conditional directly using a shared_ptr.
void push_back(boost::shared_ptr<HybridConditional> conditional) {
factors_.push_back(conditional);
}

/// Add a conditional directly using implicit conversion.
/**
* Add a conditional using a shared_ptr, using implicit conversion to
* a HybridConditional.
*
* This is useful when you create a conditional shared pointer as you need it
* somewhere else.
*
* Example:
* auto shared_ptr_to_a_conditional =
* boost::make_shared<GaussianMixture>(...);
* hbn.push_back(shared_ptr_to_a_conditional);
*/
void push_back(HybridConditional &&conditional) {
factors_.push_back(
boost::make_shared<HybridConditional>(std::move(conditional)));
Expand Down Expand Up @@ -214,8 +232,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
*
* @param prunedDecisionTree
*/
void updateDiscreteConditionals(
const DecisionTreeFactor::shared_ptr &prunedDecisionTree);
void updateDiscreteConditionals(const DecisionTreeFactor &prunedDecisionTree);

/** Serialization function */
friend class boost::serialization::access;
Expand Down
15 changes: 6 additions & 9 deletions gtsam/hybrid/HybridBayesTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,18 +132,15 @@ struct traits<HybridBayesTree> : public Testable<HybridBayesTree> {};
* This object stores parent keys in our base type factor so that
* eliminating those parent keys will pull this subtree into the
* elimination.
* This does special stuff for the hybrid case.
*
* @tparam CLIQUE
* This is a template instantiation for hybrid Bayes tree cliques, storing both
* the regular keys *and* discrete keys in the HybridConditional.
*/
template <class CLIQUE>
class BayesTreeOrphanWrapper<
CLIQUE, typename std::enable_if<
boost::is_same<CLIQUE, HybridBayesTreeClique>::value> >
: public CLIQUE::ConditionalType {
template <>
class BayesTreeOrphanWrapper<HybridBayesTreeClique> : public HybridConditional {
public:
typedef CLIQUE CliqueType;
typedef typename CLIQUE::ConditionalType Base;
typedef HybridBayesTreeClique CliqueType;
typedef HybridConditional Base;

boost::shared_ptr<CliqueType> clique;

Expand Down
61 changes: 0 additions & 61 deletions gtsam/hybrid/HybridDiscreteFactor.cpp

This file was deleted.

91 changes: 0 additions & 91 deletions gtsam/hybrid/HybridDiscreteFactor.h

This file was deleted.

5 changes: 2 additions & 3 deletions gtsam/hybrid/HybridFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,10 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
const DiscreteKeys &key2);

/**
* Base class for hybrid probabilistic factors
* Base class for *truly* hybrid probabilistic factors
*
* Examples:
* - HybridGaussianFactor
* - HybridDiscreteFactor
* - MixtureFactor
* - GaussianMixtureFactor
* - GaussianMixture
*
Expand Down
78 changes: 78 additions & 0 deletions gtsam/hybrid/HybridFactorGraph.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/* ----------------------------------------------------------------------------

* GTSAM Copyright 2010-2022, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)

* See LICENSE for the license information

* -------------------------------------------------------------------------- */

/**
* @file HybridFactorGraph.cpp
* @brief Factor graph with utilities for hybrid factors.
* @author Varun Agrawal
* @author Frank Dellaert
* @date January, 2023
*/

#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/hybrid/HybridFactorGraph.h>

#include <boost/format.hpp>

namespace gtsam {

/* ************************************************************************* */
DiscreteKeys HybridFactorGraph::discreteKeys() const {
DiscreteKeys keys;
for (auto& factor : factors_) {
if (auto p = boost::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
for (const DiscreteKey& key : p->discreteKeys()) {
keys.push_back(key);
}
}
if (auto p = boost::dynamic_pointer_cast<HybridFactor>(factor)) {
for (const DiscreteKey& key : p->discreteKeys()) {
keys.push_back(key);
}
}
}
return keys;
}

/* ************************************************************************* */
KeySet HybridFactorGraph::discreteKeySet() const {
KeySet keys;
for (const DiscreteKey& k : discreteKeys()) {
keys.insert(k.first);
}
return keys;
}

/* ************************************************************************* */
std::unordered_map<Key, DiscreteKey> HybridFactorGraph::discreteKeyMap() const {
std::unordered_map<Key, DiscreteKey> result;
for (const DiscreteKey& k : discreteKeys()) {
result[k.first] = k;
}
return result;
}

/* ************************************************************************* */
const KeySet HybridFactorGraph::continuousKeySet() const {
KeySet keys;
for (auto& factor : factors_) {
if (auto p = boost::dynamic_pointer_cast<HybridFactor>(factor)) {
for (const Key& key : p->continuousKeys()) {
keys.insert(key);
}
}
}
return keys;
}

/* ************************************************************************* */

} // namespace gtsam
Loading