diff --git a/gtsam/hybrid/HybridConditional.cpp b/gtsam/hybrid/HybridConditional.cpp index 074534b8d..3cb3bba65 100644 --- a/gtsam/hybrid/HybridConditional.cpp +++ b/gtsam/hybrid/HybridConditional.cpp @@ -129,6 +129,22 @@ double HybridConditional::error(const HybridValues &values) const { "HybridConditional::error: conditional type not handled"); } +/* ************************************************************************ */ +AlgebraicDecisionTree HybridConditional::errorTree( + const VectorValues &values) const { + if (auto gc = asGaussian()) { + return AlgebraicDecisionTree(gc->error(values)); + } + if (auto gm = asHybrid()) { + return gm->errorTree(values); + } + if (auto dc = asDiscrete()) { + return AlgebraicDecisionTree(0.0); + } + throw std::runtime_error( + "HybridConditional::error: conditional type not handled"); +} + /* ************************************************************************ */ double HybridConditional::logProbability(const HybridValues &values) const { if (auto gc = asGaussian()) { diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index f44ee2bf9..0009d6bd8 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -179,6 +179,16 @@ class GTSAM_EXPORT HybridConditional /// Return the error of the underlying conditional. double error(const HybridValues& values) const override; + /** + * @brief Compute error of the HybridConditional as a tree. + * + * @param continuousValues The continuous VectorValues. + * @return AlgebraicDecisionTree A decision tree with the same keys + * as the conditionals involved, and leaf values as the error. + */ + virtual AlgebraicDecisionTree errorTree( + const VectorValues& values) const override; + /// Return the log-probability (or density) of the underlying conditional. double logProbability(const HybridValues& values) const override; diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h index ad29dfdca..fc91e0838 100644 --- a/gtsam/hybrid/HybridFactor.h +++ b/gtsam/hybrid/HybridFactor.h @@ -136,6 +136,10 @@ class GTSAM_EXPORT HybridFactor : public Factor { /// Return only the continuous keys for this factor. const KeyVector &continuousKeys() const { return continuousKeys_; } + /// Virtual class to compute tree of linear errors. + virtual AlgebraicDecisionTree errorTree( + const VectorValues &values) const = 0; + /// @} private: diff --git a/gtsam/hybrid/HybridGaussianConditional.h b/gtsam/hybrid/HybridGaussianConditional.h index 72a999472..5e585acef 100644 --- a/gtsam/hybrid/HybridGaussianConditional.h +++ b/gtsam/hybrid/HybridGaussianConditional.h @@ -109,9 +109,9 @@ class GTSAM_EXPORT HybridGaussianConditional const Conditionals &conditionals); /** - * @brief Make a Hybrid Gaussian Conditional from a vector of Gaussian - * conditionals. The DecisionTree-based constructor is preferred over this - * one. + * @brief Make a Hybrid Gaussian Conditional from + * a vector of Gaussian conditionals. + * The DecisionTree-based constructor is preferred over this one. * * @param continuousFrontals The continuous frontal variables * @param continuousParents The continuous parent variables @@ -208,8 +208,8 @@ class GTSAM_EXPORT HybridGaussianConditional * @return AlgebraicDecisionTree A decision tree on the discrete keys * only, with the leaf values as the error for each assignment. */ - AlgebraicDecisionTree errorTree( - const VectorValues &continuousValues) const; + virtual AlgebraicDecisionTree errorTree( + const VectorValues &continuousValues) const override; /** * @brief Compute the logProbability of this hybrid Gaussian conditional. diff --git a/gtsam/hybrid/HybridGaussianFactor.h b/gtsam/hybrid/HybridGaussianFactor.h index a86714863..8d57ad7da 100644 --- a/gtsam/hybrid/HybridGaussianFactor.h +++ b/gtsam/hybrid/HybridGaussianFactor.h @@ -148,8 +148,8 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { * @return AlgebraicDecisionTree A decision tree with the same keys * as the factors involved, and leaf values as the error. */ - AlgebraicDecisionTree errorTree( - const VectorValues &continuousValues) const; + virtual AlgebraicDecisionTree errorTree( + const VectorValues &continuousValues) const override; /** * @brief Compute the log-likelihood, including the log-normalizing constant. diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 28a0c446f..0d4e534e1 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -539,36 +539,15 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors, AlgebraicDecisionTree HybridGaussianFactorGraph::errorTree( const VectorValues &continuousValues) const { AlgebraicDecisionTree error_tree(0.0); - // Iterate over each factor. for (auto &factor : factors_) { - // TODO(dellaert): just use a virtual method defined in HybridFactor. - AlgebraicDecisionTree factor_error; - - auto f = factor; - if (auto hc = dynamic_pointer_cast(factor)) { - f = hc->inner(); - } - - if (auto hybridGaussianCond = - dynamic_pointer_cast(f)) { - // Compute factor error and add it. - error_tree = error_tree + hybridGaussianCond->errorTree(continuousValues); - } else if (auto gaussian = dynamic_pointer_cast(f)) { - // If continuous only, get the (double) error - // and add it to the error_tree - double error = gaussian->error(continuousValues); - // Add the gaussian factor error to every leaf of the error tree. - error_tree = error_tree.apply( - [error](double leaf_value) { return leaf_value + error; }); - } else if (dynamic_pointer_cast(f)) { - // If factor at `idx` is discrete-only, we skip. - continue; - } else { - throwRuntimeError("HybridGaussianFactorGraph::error(VV)", f); + if (auto f = std::dynamic_pointer_cast(factor)) { + error_tree = error_tree + f->errorTree(continuousValues); + } else if (auto f = std::dynamic_pointer_cast(factor)) { + error_tree = + error_tree + AlgebraicDecisionTree(f->error(continuousValues)); } } - return error_tree; }