Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Big Re-"Factor" of Hybrid #1374

Merged
merged 16 commits into from
Jan 8, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Make HybridFactorGraph just a FactorGraph<Factor> with extra methods
  • Loading branch information
dellaert committed Jan 7, 2023
commit 1538452d5ae61fe429cd55be73870a0bc585f6fd
2 changes: 2 additions & 0 deletions gtsam/hybrid/GaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ GaussianMixture::GaussianMixture(
Conditionals(discreteParents, conditionals)) {}

/* *******************************************************************************/
// TODO(dellaert): This is copy/paste: GaussianMixture should be derived from
// GaussianMixtureFactor, no?
dellaert marked this conversation as resolved.
Show resolved Hide resolved
GaussianFactorGraphTree GaussianMixture::add(
const GaussianFactorGraphTree &sum) const {
using Y = GraphAndConstant;
Expand Down
23 changes: 13 additions & 10 deletions gtsam/hybrid/HybridFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@

/**
* @file HybridFactorGraph.h
* @brief Hybrid factor graph base class that uses type erasure
* @brief Factor graph with utilities for hybrid factors.
* @author Varun Agrawal
* @author Frank Dellaert
* @date May 28, 2022
*/

Expand All @@ -31,13 +32,11 @@ using SharedFactor = boost::shared_ptr<Factor>;

/**
* Hybrid Factor Graph
* -----------------------
* This is the base hybrid factor graph.
* Everything inside needs to be hybrid factor or hybrid conditional.
* Factor graph with utilities for hybrid factors.
*/
class HybridFactorGraph : public FactorGraph<HybridFactor> {
class HybridFactorGraph : public FactorGraph<Factor> {
dellaert marked this conversation as resolved.
Show resolved Hide resolved
public:
using Base = FactorGraph<HybridFactor>;
using Base = FactorGraph<Factor>;
using This = HybridFactorGraph; ///< this class
using shared_ptr = boost::shared_ptr<This>; ///< shared_ptr to This

Expand Down Expand Up @@ -140,8 +139,10 @@ class HybridFactorGraph : public FactorGraph<HybridFactor> {
const KeySet discreteKeys() const {
KeySet discrete_keys;
for (auto& factor : factors_) {
for (const DiscreteKey& k : factor->discreteKeys()) {
discrete_keys.insert(k.first);
if (auto p = boost::dynamic_pointer_cast<HybridFactor>(factor)) {
for (const DiscreteKey& k : p->discreteKeys()) {
discrete_keys.insert(k.first);
}
}
}
return discrete_keys;
Expand All @@ -151,8 +152,10 @@ class HybridFactorGraph : public FactorGraph<HybridFactor> {
const KeySet continuousKeys() const {
KeySet keys;
for (auto& factor : factors_) {
for (const Key& key : factor->continuousKeys()) {
keys.insert(key);
if (auto p = boost::dynamic_pointer_cast<HybridFactor>(factor)) {
for (const Key& key : p->continuousKeys()) {
keys.insert(key);
}
}
}
return keys;
Expand Down
92 changes: 46 additions & 46 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,48 +79,47 @@ static GaussianFactorGraphTree addGaussian(
}

/* ************************************************************************ */
// TODO(dellaert): Implementation-wise, it's probably more efficient to first
// collect the discrete keys, and then loop over all assignments to populate a
// vector.
// TODO(dellaert): it's probably more efficient to first collect the discrete
// keys, and then loop over all assignments to populate a vector.
GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
using boost::dynamic_pointer_cast;

gttic(assembleGraphTree);

GaussianFactorGraphTree result;

for (auto &f : factors_) {
// TODO(dellaert): just use a virtual method defined in HybridFactor.
if (f->isHybrid()) {
if (auto gm = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
if (auto gm = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
result = gm->add(result);
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
if (auto gm = hc->asMixture()) {
result = gm->add(result);
} else if (auto g = hc->asGaussian()) {
result = addGaussian(result, g);
} else {
// Has to be discrete.
continue;
}
if (auto gm = boost::dynamic_pointer_cast<HybridConditional>(f)) {
result = gm->asMixture()->add(result);
}

} else if (f->isContinuous()) {
if (auto gf = boost::dynamic_pointer_cast<HybridGaussianFactor>(f)) {
result = addGaussian(result, gf->inner());
}
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(f)) {
result = addGaussian(result, cg->asGaussian());
}

} else if (f->isDiscrete()) {
} else if (auto gf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
result = addGaussian(result, gf->inner());
} else if (dynamic_pointer_cast<DiscreteFactor>(f) ||
dynamic_pointer_cast<HybridDiscreteFactor>(f)) {
// Don't do anything for discrete-only factors
// since we want to eliminate continuous values only.
continue;

} else {
} else if (auto orphan = dynamic_pointer_cast<
BayesTreeOrphanWrapper<HybridBayesTree::Clique>>(f)) {
// We need to handle the case where the object is actually an
// BayesTreeOrphanWrapper!
auto orphan = boost::dynamic_pointer_cast<
BayesTreeOrphanWrapper<HybridBayesTree::Clique>>(f);
if (!orphan) {
auto &fr = *f;
throw std::invalid_argument(
std::string("factor is discrete in continuous elimination ") +
demangle(typeid(fr).name()));
}
throw std::invalid_argument(
"gtsam::assembleGraphTree: BayesTreeOrphanWrapper is not implemented "
"yet.");
} else {
auto &fr = *f;
throw std::invalid_argument(
std::string("gtsam::assembleGraphTree: factor type not handled: ") +
demangle(typeid(fr).name()));
}
}

Expand Down Expand Up @@ -377,8 +376,8 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
// Build a map from keys to DiscreteKeys
std::unordered_map<Key, DiscreteKey> mapFromKeyToDiscreteKey;
for (auto &&factor : factors) {
if (!factor->isContinuous()) {
for (auto &k : factor->discreteKeys()) {
if (auto p = boost::dynamic_pointer_cast<HybridFactor>(factor)) {
for (auto &k : p->discreteKeys()) {
mapFromKeyToDiscreteKey[k.first] = k;
}
}
Expand Down Expand Up @@ -451,12 +450,6 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) {
/* ************************************************************************ */
const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
KeySet discrete_keys = discreteKeys();
for (auto &factor : factors_) {
for (const DiscreteKey &k : factor->discreteKeys()) {
discrete_keys.insert(k.first);
}
}

const VariableIndex index(factors_);
Ordering ordering = Ordering::ColamdConstrainedLast(
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
Expand All @@ -466,25 +459,23 @@ const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
const VectorValues &continuousValues) const {
using boost::dynamic_pointer_cast;

AlgebraicDecisionTree<Key> error_tree(0.0);

// Iterate over each factor.
for (size_t idx = 0; idx < size(); idx++) {
for (auto &f : factors_) {
// TODO(dellaert): just use a virtual method defined in HybridFactor.
AlgebraicDecisionTree<Key> factor_error;

if (factors_.at(idx)->isHybrid()) {
// If factor is hybrid, select based on assignment.
GaussianMixtureFactor::shared_ptr gaussianMixture =
boost::static_pointer_cast<GaussianMixtureFactor>(factors_.at(idx));
if (auto gaussianMixture = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
dellaert marked this conversation as resolved.
Show resolved Hide resolved
// Compute factor error and add it.
error_tree = error_tree + gaussianMixture->error(continuousValues);

} else if (factors_.at(idx)->isContinuous()) {
} else if (auto hybridGaussianFactor =
dynamic_pointer_cast<HybridGaussianFactor>(f)) {
// If continuous only, get the (double) error
// and add it to the error_tree
auto hybridGaussianFactor =
boost::static_pointer_cast<HybridGaussianFactor>(factors_.at(idx));
GaussianFactor::shared_ptr gaussian = hybridGaussianFactor->inner();

// Compute the error of the gaussian factor.
Expand All @@ -493,9 +484,16 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; });

} else if (factors_.at(idx)->isDiscrete()) {
} else if (dynamic_pointer_cast<DiscreteFactor>(f) ||
dynamic_pointer_cast<HybridDiscreteFactor>(f)) {
// If factor at `idx` is discrete-only, we skip.
continue;
} else {
auto &fr = *f;
throw std::invalid_argument(
std::string(
"HybridGaussianFactorGraph::error: factor type not handled: ") +
demangle(typeid(fr).name()));
}
}

Expand All @@ -506,7 +504,9 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
double HybridGaussianFactorGraph::error(const HybridValues &values) const {
double error = 0.0;
for (auto &factor : factors_) {
error += factor->error(values);
if (auto p = boost::dynamic_pointer_cast<HybridFactor>(factor)) {
error += p->error(values);
}
}
return error;
}
Expand Down
8 changes: 5 additions & 3 deletions gtsam/hybrid/HybridJunctionTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,11 @@ struct HybridConstructorTraversalData {
parentData.junctionTreeNode->addChild(data.junctionTreeNode);

// Add all the discrete keys in the hybrid factors to the current data
for (HybridFactor::shared_ptr& f : node->factors) {
for (auto& k : f->discreteKeys()) {
data.discreteKeys.insert(k.first);
for (const auto& f : node->factors) {
if (auto p = boost::dynamic_pointer_cast<HybridFactor>(f)) {
for (auto& k : p->discreteKeys()) {
data.discreteKeys.insert(k.first);
}
}
}

Expand Down
61 changes: 28 additions & 33 deletions gtsam/hybrid/HybridNonlinearFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,47 +50,42 @@ void HybridNonlinearFactorGraph::print(const std::string& s,
/* ************************************************************************* */
HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(
const Values& continuousValues) const {
using boost::dynamic_pointer_cast;

// create an empty linear FG
auto linearFG = boost::make_shared<HybridGaussianFactorGraph>();

linearFG->reserve(size());

// linearize all hybrid factors
for (auto&& factor : factors_) {
for (auto& f : factors_) {
// First check if it is a valid factor
if (factor) {
// Check if the factor is a hybrid factor.
// It can be either a nonlinear MixtureFactor or a linear
// GaussianMixtureFactor.
if (factor->isHybrid()) {
// Check if it is a nonlinear mixture factor
if (auto nlmf = boost::dynamic_pointer_cast<MixtureFactor>(factor)) {
linearFG->push_back(nlmf->linearize(continuousValues));
} else {
linearFG->push_back(factor);
}

// Now check if the factor is a continuous only factor.
} else if (factor->isContinuous()) {
// In this case, we check if factor->inner() is nonlinear since
// HybridFactors wrap over continuous factors.
auto nlhf = boost::dynamic_pointer_cast<HybridNonlinearFactor>(factor);
if (auto nlf =
boost::dynamic_pointer_cast<NonlinearFactor>(nlhf->inner())) {
auto hgf = boost::make_shared<HybridGaussianFactor>(
nlf->linearize(continuousValues));
linearFG->push_back(hgf);
} else {
linearFG->push_back(factor);
}
// Finally if nothing else, we are discrete-only which doesn't need
// lineariztion.
} else {
linearFG->push_back(factor);
}

} else {
if (!f) {
// TODO(dellaert): why?
dellaert marked this conversation as resolved.
Show resolved Hide resolved
linearFG->push_back(GaussianFactor::shared_ptr());
continue;
}
// Check if it is a nonlinear mixture factor
if (auto nlmf = dynamic_pointer_cast<MixtureFactor>(f)) {
const GaussianMixtureFactor::shared_ptr& gmf =
nlmf->linearize(continuousValues);
linearFG->push_back(gmf);
} else if (auto nlhf = dynamic_pointer_cast<HybridNonlinearFactor>(f)) {
// Nonlinear wrapper case:
const GaussianFactor::shared_ptr& gf =
nlhf->inner()->linearize(continuousValues);
const auto hgf = boost::make_shared<HybridGaussianFactor>(gf);
linearFG->push_back(hgf);
} else if (dynamic_pointer_cast<DiscreteFactor>(f) ||
dynamic_pointer_cast<HybridDiscreteFactor>(f)) {
// If discrete-only: doesn't need linearization.
linearFG->push_back(f);
} else {
auto& fr = *f;
throw std::invalid_argument(
std::string("HybridNonlinearFactorGraph::linearize: factor type "
"not handled: ") +
demangle(typeid(fr).name()));
}
}
return linearFG;
Expand Down
7 changes: 1 addition & 6 deletions gtsam/hybrid/HybridNonlinearFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,6 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
using HasDerivedValueType = typename std::enable_if<
std::is_base_of<HybridFactor, typename T::value_type>::value>::type;

/// Check if T has a pointer type derived from FactorType.
template <typename T>
using HasDerivedElementType = typename std::enable_if<std::is_base_of<
HybridFactor, typename T::value_type::element_type>::value>::type;

public:
using Base = HybridFactorGraph;
using This = HybridNonlinearFactorGraph; ///< this class
Expand Down Expand Up @@ -124,7 +119,7 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
* copied)
*/
template <typename CONTAINER>
HasDerivedElementType<CONTAINER> push_back(const CONTAINER& container) {
void push_back(const CONTAINER& container) {
Base::push_back(container.begin(), container.end());
}

Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ TEST(HybridFactorGraph, Full_Elimination) {

DiscreteFactorGraph discrete_fg;
// TODO(Varun) Make this a function of HybridGaussianFactorGraph?
for (HybridFactor::shared_ptr& factor : (*remainingFactorGraph_partial)) {
for (auto& factor : (*remainingFactorGraph_partial)) {
auto df = dynamic_pointer_cast<HybridDiscreteFactor>(factor);
discrete_fg.push_back(df->inner());
}
Expand Down