common errorTree method and its use in HybridGaussianFactorGraph

release/4.3a0
Varun Agrawal 2024-09-19 21:16:56 -04:00
parent 245f3e042e
commit cd3c590f32
6 changed files with 42 additions and 33 deletions

View File

@ -129,6 +129,22 @@ double HybridConditional::error(const HybridValues &values) const {
"HybridConditional::error: conditional type not handled"); "HybridConditional::error: conditional type not handled");
} }
/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridConditional::errorTree(
const VectorValues &values) const {
if (auto gc = asGaussian()) {
return AlgebraicDecisionTree<Key>(gc->error(values));
}
if (auto gm = asHybrid()) {
return gm->errorTree(values);
}
if (auto dc = asDiscrete()) {
return AlgebraicDecisionTree<Key>(0.0);
}
throw std::runtime_error(
"HybridConditional::error: conditional type not handled");
}
/* ************************************************************************ */ /* ************************************************************************ */
double HybridConditional::logProbability(const HybridValues &values) const { double HybridConditional::logProbability(const HybridValues &values) const {
if (auto gc = asGaussian()) { if (auto gc = asGaussian()) {

View File

@ -179,6 +179,16 @@ class GTSAM_EXPORT HybridConditional
/// Return the error of the underlying conditional. /// Return the error of the underlying conditional.
double error(const HybridValues& values) const override; double error(const HybridValues& values) const override;
/**
* @brief Compute error of the HybridConditional as a tree.
*
* @param continuousValues The continuous VectorValues.
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the conditionals involved, and leaf values as the error.
*/
virtual AlgebraicDecisionTree<Key> errorTree(
const VectorValues& values) const override;
/// Return the log-probability (or density) of the underlying conditional. /// Return the log-probability (or density) of the underlying conditional.
double logProbability(const HybridValues& values) const override; double logProbability(const HybridValues& values) const override;

View File

@ -136,6 +136,10 @@ class GTSAM_EXPORT HybridFactor : public Factor {
/// Return only the continuous keys for this factor. /// Return only the continuous keys for this factor.
const KeyVector &continuousKeys() const { return continuousKeys_; } const KeyVector &continuousKeys() const { return continuousKeys_; }
/// Virtual class to compute tree of linear errors.
virtual AlgebraicDecisionTree<Key> errorTree(
const VectorValues &values) const = 0;
/// @} /// @}
private: private:

View File

@ -109,9 +109,9 @@ class GTSAM_EXPORT HybridGaussianConditional
const Conditionals &conditionals); const Conditionals &conditionals);
/** /**
* @brief Make a Hybrid Gaussian Conditional from a vector of Gaussian * @brief Make a Hybrid Gaussian Conditional from
* conditionals. The DecisionTree-based constructor is preferred over this * a vector of Gaussian conditionals.
* one. * The DecisionTree-based constructor is preferred over this one.
* *
* @param continuousFrontals The continuous frontal variables * @param continuousFrontals The continuous frontal variables
* @param continuousParents The continuous parent variables * @param continuousParents The continuous parent variables
@ -208,8 +208,8 @@ class GTSAM_EXPORT HybridGaussianConditional
* @return AlgebraicDecisionTree<Key> A decision tree on the discrete keys * @return AlgebraicDecisionTree<Key> A decision tree on the discrete keys
* only, with the leaf values as the error for each assignment. * only, with the leaf values as the error for each assignment.
*/ */
AlgebraicDecisionTree<Key> errorTree( virtual AlgebraicDecisionTree<Key> errorTree(
const VectorValues &continuousValues) const; const VectorValues &continuousValues) const override;
/** /**
* @brief Compute the logProbability of this hybrid Gaussian conditional. * @brief Compute the logProbability of this hybrid Gaussian conditional.

View File

@ -148,8 +148,8 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys * @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the factors involved, and leaf values as the error. * as the factors involved, and leaf values as the error.
*/ */
AlgebraicDecisionTree<Key> errorTree( virtual AlgebraicDecisionTree<Key> errorTree(
const VectorValues &continuousValues) const; const VectorValues &continuousValues) const override;
/** /**
* @brief Compute the log-likelihood, including the log-normalizing constant. * @brief Compute the log-likelihood, including the log-normalizing constant.

View File

@ -539,36 +539,15 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::errorTree( AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::errorTree(
const VectorValues &continuousValues) const { const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree(0.0); AlgebraicDecisionTree<Key> error_tree(0.0);
// Iterate over each factor. // Iterate over each factor.
for (auto &factor : factors_) { for (auto &factor : factors_) {
// TODO(dellaert): just use a virtual method defined in HybridFactor. if (auto f = std::dynamic_pointer_cast<HybridFactor>(factor)) {
AlgebraicDecisionTree<Key> factor_error; error_tree = error_tree + f->errorTree(continuousValues);
} else if (auto f = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
auto f = factor; error_tree =
if (auto hc = dynamic_pointer_cast<HybridConditional>(factor)) { error_tree + AlgebraicDecisionTree<Key>(f->error(continuousValues));
f = hc->inner();
}
if (auto hybridGaussianCond =
dynamic_pointer_cast<HybridGaussianFactor>(f)) {
// Compute factor error and add it.
error_tree = error_tree + hybridGaussianCond->errorTree(continuousValues);
} else if (auto gaussian = dynamic_pointer_cast<GaussianFactor>(f)) {
// If continuous only, get the (double) error
// and add it to the error_tree
double error = gaussian->error(continuousValues);
// Add the gaussian factor error to every leaf of the error tree.
error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; });
} else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
// If factor at `idx` is discrete-only, we skip.
continue;
} else {
throwRuntimeError("HybridGaussianFactorGraph::error(VV)", f);
} }
} }
return error_tree; return error_tree;
} }