From 7e32a8739e02ebbdeebc0ec1ca6b5e7c146a2bc7 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 6 Jan 2023 21:07:51 -0800 Subject: [PATCH] Removed HybridDiscreteFactor wrapper --- gtsam/hybrid/HybridDiscreteFactor.cpp | 61 ------------- gtsam/hybrid/HybridDiscreteFactor.h | 91 ------------------- gtsam/hybrid/HybridFactor.h | 5 +- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 97 ++++++++++----------- gtsam/hybrid/HybridGaussianFactorGraph.h | 11 ++- gtsam/hybrid/HybridGaussianISAM.cpp | 2 +- gtsam/hybrid/HybridNonlinearFactorGraph.cpp | 7 +- gtsam/hybrid/HybridNonlinearFactorGraph.h | 8 +- gtsam/hybrid/hybrid.i | 15 +--- gtsam/hybrid/tests/Switching.h | 16 ++-- 10 files changed, 69 insertions(+), 244 deletions(-) delete mode 100644 gtsam/hybrid/HybridDiscreteFactor.cpp delete mode 100644 gtsam/hybrid/HybridDiscreteFactor.h diff --git a/gtsam/hybrid/HybridDiscreteFactor.cpp b/gtsam/hybrid/HybridDiscreteFactor.cpp deleted file mode 100644 index afdb6472a..000000000 --- a/gtsam/hybrid/HybridDiscreteFactor.cpp +++ /dev/null @@ -1,61 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file HybridDiscreteFactor.cpp - * @brief Wrapper for a discrete factor - * @date Mar 11, 2022 - * @author Fan Jiang - */ - -#include -#include - -#include - -#include "gtsam/discrete/DecisionTreeFactor.h" - -namespace gtsam { - -/* ************************************************************************ */ -HybridDiscreteFactor::HybridDiscreteFactor(DiscreteFactor::shared_ptr other) - : Base(boost::dynamic_pointer_cast(other) - ->discreteKeys()), - inner_(other) {} - -/* ************************************************************************ */ -HybridDiscreteFactor::HybridDiscreteFactor(DecisionTreeFactor &&dtf) - : Base(dtf.discreteKeys()), - inner_(boost::make_shared(std::move(dtf))) {} - -/* ************************************************************************ */ -bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const { - const This *e = dynamic_cast(&lf); - if (e == nullptr) return false; - if (!Base::equals(*e, tol)) return false; - return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false) - : !(e->inner_); -} - -/* ************************************************************************ */ -void HybridDiscreteFactor::print(const std::string &s, - const KeyFormatter &formatter) const { - HybridFactor::print(s, formatter); - inner_->print("\n", formatter); -}; - -/* ************************************************************************ */ -double HybridDiscreteFactor::error(const HybridValues &values) const { - return -log((*inner_)(values.discrete())); -} -/* ************************************************************************ */ - -} // namespace gtsam diff --git a/gtsam/hybrid/HybridDiscreteFactor.h b/gtsam/hybrid/HybridDiscreteFactor.h deleted file mode 100644 index 7a43ab3a5..000000000 --- a/gtsam/hybrid/HybridDiscreteFactor.h +++ /dev/null @@ -1,91 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file HybridDiscreteFactor.h - * @date Mar 11, 2022 - * @author Fan Jiang - * @author Varun Agrawal - */ - -#pragma once - -#include -#include -#include - -namespace gtsam { - -class HybridValues; - -/** - * A HybridDiscreteFactor is a thin container for DiscreteFactor, which - * allows us to hide the implementation of DiscreteFactor and thus avoid - * diamond inheritance. - * - * @ingroup hybrid - */ -class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor { - private: - DiscreteFactor::shared_ptr inner_; - - public: - using Base = HybridFactor; - using This = HybridDiscreteFactor; - using shared_ptr = boost::shared_ptr; - - /// @name Constructors - /// @{ - - /// Default constructor - for serialization. - HybridDiscreteFactor() = default; - - // Implicit conversion from a shared ptr of DF - HybridDiscreteFactor(DiscreteFactor::shared_ptr other); - - // Forwarding constructor from concrete DecisionTreeFactor - HybridDiscreteFactor(DecisionTreeFactor &&dtf); - - /// @} - /// @name Testable - /// @{ - virtual bool equals(const HybridFactor &lf, double tol) const override; - - void print( - const std::string &s = "HybridFactor\n", - const KeyFormatter &formatter = DefaultKeyFormatter) const override; - - /// @} - /// @name Standard Interface - /// @{ - - /// Return pointer to the internal discrete factor. - DiscreteFactor::shared_ptr inner() const { return inner_; } - - /// Return the error of the underlying Discrete Factor. - double error(const HybridValues &values) const override; - /// @} - - private: - /** Serialization function */ - friend class boost::serialization::access; - template - void serialize(ARCHIVE &ar, const unsigned int /*version*/) { - ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); - ar &BOOST_SERIALIZATION_NVP(inner_); - } -}; - -// traits -template <> -struct traits : public Testable {}; - -} // namespace gtsam diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h index 8c1b0dad3..bab38aa07 100644 --- a/gtsam/hybrid/HybridFactor.h +++ b/gtsam/hybrid/HybridFactor.h @@ -67,11 +67,10 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1, const DiscreteKeys &key2); /** - * Base class for hybrid probabilistic factors + * Base class for *truly* hybrid probabilistic factors * * Examples: - * - HybridGaussianFactor - * - HybridDiscreteFactor + * - MixtureFactor * - GaussianMixtureFactor * - GaussianMixture * diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index a2f420c3f..3896782b0 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -26,7 +26,6 @@ #include #include #include -#include #include #include #include @@ -47,7 +46,6 @@ #include #include #include -#include #include #include @@ -58,6 +56,15 @@ namespace gtsam { /// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph: template class EliminateableFactorGraph; +/* ************************************************************************ */ +// Throw a runtime exception for method specified in string s, and factor f: +static void throwRuntimeError(const std::string &s, + const boost::shared_ptr &f) { + auto &fr = *f; + throw std::runtime_error(s + " not implemented for factor type " + + demangle(typeid(fr).name()) + "."); +} + /* ************************************************************************ */ static GaussianFactorGraphTree addGaussian( const GaussianFactorGraphTree &gfgTree, @@ -67,7 +74,6 @@ static GaussianFactorGraphTree addGaussian( GaussianFactorGraph result; result.push_back(factor); return GaussianFactorGraphTree(GraphAndConstant(result, 0.0)); - } else { auto add = [&factor](const GraphAndConstant &graph_z) { auto result = graph_z.graph; @@ -103,8 +109,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const { } } else if (auto gf = dynamic_pointer_cast(f)) { result = addGaussian(result, gf->inner()); - } else if (dynamic_pointer_cast(f) || - dynamic_pointer_cast(f)) { + } else if (dynamic_pointer_cast(f)) { // Don't do anything for discrete-only factors // since we want to eliminate continuous values only. continue; @@ -116,10 +121,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const { "gtsam::assembleGraphTree: BayesTreeOrphanWrapper is not implemented " "yet."); } else { - auto &fr = *f; - throw std::invalid_argument( - std::string("gtsam::assembleGraphTree: factor type not handled: ") + - demangle(typeid(fr).name())); + throwRuntimeError("gtsam::assembleGraphTree", f); } } @@ -129,16 +131,18 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const { } /* ************************************************************************ */ -static std::pair +static std::pair> continuousElimination(const HybridGaussianFactorGraph &factors, const Ordering &frontalKeys) { + using boost::dynamic_pointer_cast; GaussianFactorGraph gfg; for (auto &fp : factors) { - if (auto ptr = boost::dynamic_pointer_cast(fp)) { - gfg.push_back(ptr->inner()); - } else if (auto ptr = boost::static_pointer_cast(fp)) { - gfg.push_back( - boost::static_pointer_cast(ptr->inner())); + if (auto hgf = dynamic_pointer_cast(fp)) { + gfg.push_back(hgf->inner()); + } else if (auto hc = dynamic_pointer_cast(fp)) { + auto gc = hc->asGaussian(); + assert(gc); + gfg.push_back(gc); } else { // It is an orphan wrapped conditional } @@ -150,18 +154,17 @@ continuousElimination(const HybridGaussianFactorGraph &factors, } /* ************************************************************************ */ -static std::pair +static std::pair> discreteElimination(const HybridGaussianFactorGraph &factors, const Ordering &frontalKeys) { DiscreteFactorGraph dfg; for (auto &factor : factors) { - if (auto p = boost::dynamic_pointer_cast(factor)) { - dfg.push_back(p->inner()); - } else if (auto p = boost::static_pointer_cast(factor)) { - auto discrete_conditional = - boost::static_pointer_cast(p->inner()); - dfg.push_back(discrete_conditional); + if (auto dtf = boost::dynamic_pointer_cast(factor)) { + dfg.push_back(dtf); + } else if (auto hc = + boost::static_pointer_cast(factor)) { + dfg.push_back(hc->asDiscrete()); } else { // It is an orphan wrapper } @@ -170,8 +173,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors, // NOTE: This does sum-product. For max-product, use EliminateForMPE. auto result = EliminateDiscrete(dfg, frontalKeys); - return {boost::make_shared(result.first), - boost::make_shared(result.second)}; + return {boost::make_shared(result.first), result.second}; } /* ************************************************************************ */ @@ -189,7 +191,7 @@ GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) { } /* ************************************************************************ */ -static std::pair +static std::pair> hybridElimination(const HybridGaussianFactorGraph &factors, const Ordering &frontalKeys, const KeyVector &continuousSeparator, @@ -291,7 +293,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors, boost::make_shared(discreteSeparator, fdt); return {boost::make_shared(gaussianMixture), - boost::make_shared(discreteFactor)}; + discreteFactor}; } else { // Create a resulting GaussianMixtureFactor on the separator. return {boost::make_shared(gaussianMixture), @@ -314,7 +316,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors, * eliminate a discrete variable (as specified in the ordering), the result will * be INCORRECT and there will be NO error raised. */ -std::pair // +std::pair> // EliminateHybrid(const HybridGaussianFactorGraph &factors, const Ordering &frontalKeys) { // NOTE: Because we are in the Conditional Gaussian regime there are only @@ -374,14 +376,7 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors, } // Build a map from keys to DiscreteKeys - std::unordered_map mapFromKeyToDiscreteKey; - for (auto &&factor : factors) { - if (auto p = boost::dynamic_pointer_cast(factor)) { - for (auto &k : p->discreteKeys()) { - mapFromKeyToDiscreteKey[k.first] = k; - } - } - } + auto mapFromKeyToDiscreteKey = factors.discreteKeyMap(); // Fill in discrete frontals and continuous frontals. std::set discreteFrontals; @@ -433,23 +428,25 @@ void HybridGaussianFactorGraph::add(JacobianFactor &&factor) { } /* ************************************************************************ */ -void HybridGaussianFactorGraph::add(boost::shared_ptr &factor) { +void HybridGaussianFactorGraph::add( + const boost::shared_ptr &factor) { FactorGraph::add(boost::make_shared(factor)); } /* ************************************************************************ */ void HybridGaussianFactorGraph::add(DecisionTreeFactor &&factor) { - FactorGraph::add(boost::make_shared(std::move(factor))); + FactorGraph::add(std::move(factor)); } /* ************************************************************************ */ -void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) { - FactorGraph::add(boost::make_shared(factor)); +void HybridGaussianFactorGraph::add( + const DecisionTreeFactor::shared_ptr &factor) { + FactorGraph::add(factor); } /* ************************************************************************ */ const Ordering HybridGaussianFactorGraph::getHybridOrdering() const { - KeySet discrete_keys = discreteKeys(); + const KeySet discrete_keys = discreteKeySet(); const VariableIndex index(factors_); Ordering ordering = Ordering::ColamdConstrainedLast( index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true); @@ -484,16 +481,11 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::error( error_tree = error_tree.apply( [error](double leaf_value) { return leaf_value + error; }); - } else if (dynamic_pointer_cast(f) || - dynamic_pointer_cast(f)) { + } else if (dynamic_pointer_cast(f)) { // If factor at `idx` is discrete-only, we skip. continue; } else { - auto &fr = *f; - throw std::invalid_argument( - std::string( - "HybridGaussianFactorGraph::error: factor type not handled: ") + - demangle(typeid(fr).name())); + throwRuntimeError("HybridGaussianFactorGraph::error", f); } } @@ -503,9 +495,14 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::error( /* ************************************************************************ */ double HybridGaussianFactorGraph::error(const HybridValues &values) const { double error = 0.0; - for (auto &factor : factors_) { - if (auto p = boost::dynamic_pointer_cast(factor)) { - error += p->error(values); + for (auto &f : factors_) { + if (auto hf = boost::dynamic_pointer_cast(f)) { + // TODO(dellaert): needs to change when we discard other wrappers. + error += hf->error(values); + } else if (auto dtf = boost::dynamic_pointer_cast(f)) { + error -= log((*dtf)(values.discrete())); + } else { + throwRuntimeError("HybridGaussianFactorGraph::error", f); } } return error; diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index 144d144bb..c5fa27651 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -50,13 +50,13 @@ class HybridValues; * @ingroup hybrid */ GTSAM_EXPORT -std::pair, HybridFactor::shared_ptr> +std::pair, boost::shared_ptr> EliminateHybrid(const HybridGaussianFactorGraph& factors, const Ordering& keys); /* ************************************************************************* */ template <> struct EliminationTraits { - typedef HybridFactor FactorType; ///< Type of factors in factor graph + typedef Factor FactorType; ///< Type of factors in factor graph typedef HybridGaussianFactorGraph FactorGraphType; ///< Type of the factor graph (e.g. ///< HybridGaussianFactorGraph) @@ -70,7 +70,7 @@ struct EliminationTraits { typedef HybridJunctionTree JunctionTreeType; ///< Type of Junction tree /// The default dense elimination function static std::pair, - boost::shared_ptr > + boost::shared_ptr> DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) { return EliminateHybrid(factors, keys); } @@ -80,7 +80,6 @@ struct EliminationTraits { * Hybrid Gaussian Factor Graph * ----------------------- * This is the linearized version of a hybrid factor graph. - * Everything inside needs to be hybrid factor or hybrid conditional. * * @ingroup hybrid */ @@ -130,13 +129,13 @@ class GTSAM_EXPORT HybridGaussianFactorGraph void add(JacobianFactor&& factor); /// Add a Jacobian factor as a shared ptr. - void add(boost::shared_ptr& factor); + void add(const boost::shared_ptr& factor); /// Add a DecisionTreeFactor to the factor graph. void add(DecisionTreeFactor&& factor); /// Add a DecisionTreeFactor as a shared ptr. - void add(DecisionTreeFactor::shared_ptr factor); + void add(const boost::shared_ptr& factor); /** * Add a gaussian factor *pointer* to the internal gaussian factor graph diff --git a/gtsam/hybrid/HybridGaussianISAM.cpp b/gtsam/hybrid/HybridGaussianISAM.cpp index aa6b3f266..3f63cb089 100644 --- a/gtsam/hybrid/HybridGaussianISAM.cpp +++ b/gtsam/hybrid/HybridGaussianISAM.cpp @@ -43,7 +43,7 @@ Ordering HybridGaussianISAM::GetOrdering( HybridGaussianFactorGraph& factors, const HybridGaussianFactorGraph& newFactors) { // Get all the discrete keys from the factors - KeySet allDiscrete = factors.discreteKeys(); + const KeySet allDiscrete = factors.discreteKeySet(); // Create KeyVector with continuous keys followed by discrete keys. KeyVector newKeysDiscreteLast; diff --git a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp index 380469b45..bc67bd0d7 100644 --- a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp @@ -16,7 +16,11 @@ * @date May 28, 2022 */ +#include +#include #include +#include +#include namespace gtsam { @@ -63,8 +67,7 @@ HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize( const GaussianFactor::shared_ptr& gf = nlf->linearize(continuousValues); const auto hgf = boost::make_shared(gf); linearFG->push_back(hgf); - } else if (dynamic_pointer_cast(f) || - dynamic_pointer_cast(f)) { + } else if (dynamic_pointer_cast(f)) { // If discrete-only: doesn't need linearization. linearFG->push_back(f); } else { diff --git a/gtsam/hybrid/HybridNonlinearFactorGraph.h b/gtsam/hybrid/HybridNonlinearFactorGraph.h index 59921822e..60aee431b 100644 --- a/gtsam/hybrid/HybridNonlinearFactorGraph.h +++ b/gtsam/hybrid/HybridNonlinearFactorGraph.h @@ -18,16 +18,12 @@ #pragma once -#include #include -#include -#include -#include -#include -#include namespace gtsam { +class HybridGaussianFactorGraph; + /** * Nonlinear Hybrid Factor Graph * ----------------------- diff --git a/gtsam/hybrid/hybrid.i b/gtsam/hybrid/hybrid.i index e877e5ee7..012f707e4 100644 --- a/gtsam/hybrid/hybrid.i +++ b/gtsam/hybrid/hybrid.i @@ -68,17 +68,6 @@ virtual class HybridConditional { double error(const gtsam::HybridValues& values) const; }; -#include -virtual class HybridDiscreteFactor { - HybridDiscreteFactor(gtsam::DecisionTreeFactor dtf); - void print(string s = "HybridDiscreteFactor\n", - const gtsam::KeyFormatter& keyFormatter = - gtsam::DefaultKeyFormatter) const; - bool equals(const gtsam::HybridDiscreteFactor& other, double tol = 1e-9) const; - gtsam::Factor* inner(); - double error(const gtsam::HybridValues &values) const; -}; - #include class GaussianMixtureFactor : gtsam::HybridFactor { GaussianMixtureFactor( @@ -217,9 +206,7 @@ class HybridNonlinearFactorGraph { HybridNonlinearFactorGraph(const gtsam::HybridNonlinearFactorGraph& graph); void push_back(gtsam::HybridFactor* factor); void push_back(gtsam::NonlinearFactor* factor); - void push_back(gtsam::HybridDiscreteFactor* factor); - void add(gtsam::NonlinearFactor* factor); - void add(gtsam::DiscreteFactor* factor); + void push_back(gtsam::DiscreteFactor* factor); gtsam::HybridGaussianFactorGraph linearize(const gtsam::Values& continuousValues) const; bool empty() const; diff --git a/gtsam/hybrid/tests/Switching.h b/gtsam/hybrid/tests/Switching.h index 385a7c3d5..46831c54e 100644 --- a/gtsam/hybrid/tests/Switching.h +++ b/gtsam/hybrid/tests/Switching.h @@ -206,13 +206,11 @@ struct Switching { */ void addModeChain(HybridNonlinearFactorGraph *fg, std::string discrete_transition_prob = "1/2 3/2") { - auto prior = boost::make_shared(modes[0], "1/1"); - fg->push_discrete(prior); + fg->emplace_shared(modes[0], "1/1"); for (size_t k = 0; k < K - 2; k++) { auto parents = {modes[k]}; - auto conditional = boost::make_shared( - modes[k + 1], parents, discrete_transition_prob); - fg->push_discrete(conditional); + fg->emplace_shared(modes[k + 1], parents, + discrete_transition_prob); } } @@ -224,13 +222,11 @@ struct Switching { */ void addModeChain(HybridGaussianFactorGraph *fg, std::string discrete_transition_prob = "1/2 3/2") { - auto prior = boost::make_shared(modes[0], "1/1"); - fg->push_discrete(prior); + fg->emplace_shared(modes[0], "1/1"); for (size_t k = 0; k < K - 2; k++) { auto parents = {modes[k]}; - auto conditional = boost::make_shared( - modes[k + 1], parents, discrete_transition_prob); - fg->push_discrete(conditional); + fg->emplace_shared(modes[k + 1], parents, + discrete_transition_prob); } } };