From a9b2c326693b6c087b190628a3fd3780c671a094 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 2 Jan 2022 23:45:01 -0500 Subject: [PATCH 1/4] Move DefaultFormatter to base class and add defaults. Also replace Super with Base and add using. --- gtsam/discrete/AlgebraicDecisionTree.h | 48 +++++++++--------- gtsam/discrete/DecisionTree.h | 69 +++++++++++++++++++------- 2 files changed, 75 insertions(+), 42 deletions(-) diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index 17a38f7cf3..0b13f408e4 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -29,16 +29,9 @@ namespace gtsam { */ template class GTSAM_EXPORT AlgebraicDecisionTree: public DecisionTree { - /// Default method used by `formatter` when printing. - static std::string DefaultFormatter(const L& x) { - std::stringstream ss; - ss << x; - return ss.str(); - } - public: - typedef DecisionTree Super; + using Base = DecisionTree; /** The Real ring with addition and multiplication */ struct Ring { @@ -66,33 +59,33 @@ namespace gtsam { }; AlgebraicDecisionTree() : - Super(1.0) { + Base(1.0) { } - AlgebraicDecisionTree(const Super& add) : - Super(add) { + AlgebraicDecisionTree(const Base& add) : + Base(add) { } /** Create a new leaf function splitting on a variable */ AlgebraicDecisionTree(const L& label, double y1, double y2) : - Super(label, y1, y2) { + Base(label, y1, y2) { } /** Create a new leaf function splitting on a variable */ - AlgebraicDecisionTree(const typename Super::LabelC& labelC, double y1, double y2) : - Super(labelC, y1, y2) { + AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1, double y2) : + Base(labelC, y1, y2) { } /** Create from keys and vector table */ AlgebraicDecisionTree // - (const std::vector& labelCs, const std::vector& ys) { - this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(), + (const std::vector& labelCs, const std::vector& ys) { + this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); } /** Create from keys and string table */ AlgebraicDecisionTree // - (const std::vector& labelCs, const std::string& table) { + (const std::vector& labelCs, const std::string& table) { // Convert string to doubles std::vector ys; std::istringstream iss(table); @@ -100,18 +93,23 @@ namespace gtsam { std::istream_iterator(), std::back_inserter(ys)); // now call recursive Create - this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(), + this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); } /** Create a new function splitting on a variable */ template AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) : - Super(nullptr) { + Base(nullptr) { this->root_ = compose(begin, end, label); } - /** Convert */ + /** + * Convert labels from type M to type L. + * + * @param other: The AlgebraicDecisionTree with label type M to convert. + * @param map: Map from label type M to label type L. + */ template AlgebraicDecisionTree(const AlgebraicDecisionTree& other, const std::map& map) { @@ -143,18 +141,18 @@ namespace gtsam { } /** sum out variable */ - AlgebraicDecisionTree sum(const typename Super::LabelC& labelC) const { + AlgebraicDecisionTree sum(const typename Base::LabelC& labelC) const { return this->combine(labelC, &Ring::add); } /// print method customized to node type `double`. void print(const std::string& s, - const typename Super::LabelFormatter& labelFormatter = - &DefaultFormatter) const { + const typename Base::LabelFormatter& labelFormatter = + &Base::DefaultFormatter) const { auto valueFormatter = [](const double& v) { return (boost::format("%4.2g") % v).str(); }; - Super::print(s, labelFormatter, valueFormatter); + Base::print(s, labelFormatter, valueFormatter); } /// Equality method customized to node type `double`. @@ -163,7 +161,7 @@ namespace gtsam { auto compare = [tol](double a, double b) { return std::abs(a - b) < tol; }; - return Super::equals(other, compare); + return Base::equals(other, compare); } }; // AlgebraicDecisionTree diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index ecc3d17dce..b02c2b3023 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -39,11 +39,24 @@ namespace gtsam { template class GTSAM_EXPORT DecisionTree { + protected: /// Default method for comparison of two objects of type Y. static bool DefaultCompare(const Y& a, const Y& b) { return a == b; } + /** + * @brief Default method used by `labelFormatter` or `valueFormatter` when printing. + * + * @param x The value passed to format. + * @return std::string + */ + static std::string DefaultFormatter(const L& x) { + std::stringstream ss; + ss << x; + return ss.str(); + } + public: using LabelFormatter = std::function; @@ -88,12 +101,14 @@ namespace gtsam { const void* id() const { return this; } // everything else is virtual, no documentation here as internal - virtual void print(const std::string& s, - const LabelFormatter& labelFormatter, - const ValueFormatter& valueFormatter) const = 0; - virtual void dot(std::ostream& os, const LabelFormatter& labelFormatter, - const ValueFormatter& valueFormatter, - bool showZero) const = 0; + virtual void print( + const std::string& s, + const LabelFormatter& labelFormatter = &DefaultFormatter, + const ValueFormatter& valueFormatter = &DefaultFormatter) const = 0; + virtual void dot(std::ostream& os, + const LabelFormatter& labelFormatter = &DefaultFormatter, + const ValueFormatter& valueFormatter = &DefaultFormatter, + bool showZero = true) const = 0; virtual bool sameLeaf(const Leaf& q) const = 0; virtual bool sameLeaf(const Node& q) const = 0; virtual bool equals(const Node& other, const CompareFunc& compare = @@ -111,7 +126,7 @@ namespace gtsam { public: /** A function is a shared pointer to the root of a DT */ - typedef typename Node::Ptr NodePtr; + using NodePtr = typename Node::Ptr; /// a DecisionTree just contains the root. TODO(dellaert): make protected. NodePtr root_; @@ -164,7 +179,16 @@ namespace gtsam { DecisionTree(const DecisionTree& other, std::function Y_of_X); - /** Convert from a different type, also transate labels via map. */ + /** + * @brief Convert from a different node type X to node type Y, also transate + * labels via map from type M to L. + * + * @tparam M Previous label type. + * @tparam X Previous node type. + * @param other The decision tree to convert. + * @param L_of_M Map from label type M to type L. + * @param Y_of_X Functor to convert from type X to type Y. + */ template DecisionTree(const DecisionTree& other, const std::map& L_of_M, std::function Y_of_X); @@ -173,9 +197,16 @@ namespace gtsam { /// @name Testable /// @{ - /** GTSAM-style print */ - void print(const std::string& s, const LabelFormatter& labelFormatter, - const ValueFormatter& valueFormatter) const; + /** + * @brief GTSAM-style print + * + * @param s Prefix string. + * @param labelFormatter Functor to format the node label. + * @param valueFormatter Functor to format the node value. + */ + void print(const std::string& s, + const LabelFormatter& labelFormatter = &DefaultFormatter, + const ValueFormatter& valueFormatter = &DefaultFormatter) const; // Testable bool equals(const DecisionTree& other, @@ -220,16 +251,20 @@ namespace gtsam { } /** output to graphviz format, stream version */ - void dot(std::ostream& os, const LabelFormatter& labelFormatter, - const ValueFormatter& valueFormatter, bool showZero = true) const; + void dot(std::ostream& os, + const LabelFormatter& labelFormatter = &DefaultFormatter, + const ValueFormatter& valueFormatter = &DefaultFormatter, + bool showZero = true) const; /** output to graphviz format, open a file */ - void dot(const std::string& name, const LabelFormatter& labelFormatter, - const ValueFormatter& valueFormatter, bool showZero = true) const; + void dot(const std::string& name, + const LabelFormatter& labelFormatter = &DefaultFormatter, + const ValueFormatter& valueFormatter = &DefaultFormatter, + bool showZero = true) const; /** output to graphviz format string */ - std::string dot(const LabelFormatter& labelFormatter, - const ValueFormatter& valueFormatter, + std::string dot(const LabelFormatter& labelFormatter = &DefaultFormatter, + const ValueFormatter& valueFormatter = &DefaultFormatter, bool showZero = true) const; /// @name Advanced Interface From 174490eb510dc39b5cc2b9f2c50764081f99f092 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 2 Jan 2022 23:49:47 -0500 Subject: [PATCH 2/4] kill commented out code --- gtsam/discrete/tests/testDecisionTree.cpp | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index cc61a382f9..53f3c43797 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -45,15 +45,6 @@ struct Crazy { double b; }; -// bool equals(const Crazy& other, double tol = 1e-12) const { -// return a == other.a && std::abs(b - other.b) < tol; -// } - -// bool operator==(const Crazy& other) const { -// return this->equals(other); -// } -// }; - struct CrazyDecisionTree : public DecisionTree { /// print to stdout void print(const std::string& s = "") const { @@ -261,8 +252,6 @@ TEST(DT, conversion) return y != 0; }; BDT f2(f1, ordering, bool_of_int); - // f1.print("f1"); - // f2.print("f2"); // create a value Assignment