common errorTree method and its use in HybridGaussianFactorGraph

release/4.3a0
Varun Agrawal 2024-09-19 21:16:56 -04:00
parent 245f3e042e
commit cd3c590f32
6 changed files with 42 additions and 33 deletions

View File

@ -129,6 +129,22 @@ double HybridConditional::error(const HybridValues &values) const {
"HybridConditional::error: conditional type not handled");
}
/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridConditional::errorTree(
const VectorValues &values) const {
if (auto gc = asGaussian()) {
return AlgebraicDecisionTree<Key>(gc->error(values));
}
if (auto gm = asHybrid()) {
return gm->errorTree(values);
}
if (auto dc = asDiscrete()) {
return AlgebraicDecisionTree<Key>(0.0);
}
throw std::runtime_error(
"HybridConditional::error: conditional type not handled");
}
/* ************************************************************************ */
double HybridConditional::logProbability(const HybridValues &values) const {
if (auto gc = asGaussian()) {

View File

@ -179,6 +179,16 @@ class GTSAM_EXPORT HybridConditional
/// Return the error of the underlying conditional.
double error(const HybridValues& values) const override;
/**
* @brief Compute error of the HybridConditional as a tree.
*
* @param continuousValues The continuous VectorValues.
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the conditionals involved, and leaf values as the error.
*/
virtual AlgebraicDecisionTree<Key> errorTree(
const VectorValues& values) const override;
/// Return the log-probability (or density) of the underlying conditional.
double logProbability(const HybridValues& values) const override;

View File

@ -136,6 +136,10 @@ class GTSAM_EXPORT HybridFactor : public Factor {
/// Return only the continuous keys for this factor.
const KeyVector &continuousKeys() const { return continuousKeys_; }
/// Virtual class to compute tree of linear errors.
virtual AlgebraicDecisionTree<Key> errorTree(
const VectorValues &values) const = 0;
/// @}
private:

View File

@ -109,9 +109,9 @@ class GTSAM_EXPORT HybridGaussianConditional
const Conditionals &conditionals);
/**
* @brief Make a Hybrid Gaussian Conditional from a vector of Gaussian
* conditionals. The DecisionTree-based constructor is preferred over this
* one.
* @brief Make a Hybrid Gaussian Conditional from
* a vector of Gaussian conditionals.
* The DecisionTree-based constructor is preferred over this one.
*
* @param continuousFrontals The continuous frontal variables
* @param continuousParents The continuous parent variables
@ -208,8 +208,8 @@ class GTSAM_EXPORT HybridGaussianConditional
* @return AlgebraicDecisionTree<Key> A decision tree on the discrete keys
* only, with the leaf values as the error for each assignment.
*/
AlgebraicDecisionTree<Key> errorTree(
const VectorValues &continuousValues) const;
virtual AlgebraicDecisionTree<Key> errorTree(
const VectorValues &continuousValues) const override;
/**
* @brief Compute the logProbability of this hybrid Gaussian conditional.

View File

@ -148,8 +148,8 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the factors involved, and leaf values as the error.
*/
AlgebraicDecisionTree<Key> errorTree(
const VectorValues &continuousValues) const;
virtual AlgebraicDecisionTree<Key> errorTree(
const VectorValues &continuousValues) const override;
/**
* @brief Compute the log-likelihood, including the log-normalizing constant.

View File

@ -539,36 +539,15 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::errorTree(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree(0.0);
// Iterate over each factor.
for (auto &factor : factors_) {
// TODO(dellaert): just use a virtual method defined in HybridFactor.
AlgebraicDecisionTree<Key> factor_error;
auto f = factor;
if (auto hc = dynamic_pointer_cast<HybridConditional>(factor)) {
f = hc->inner();
}
if (auto hybridGaussianCond =
dynamic_pointer_cast<HybridGaussianFactor>(f)) {
// Compute factor error and add it.
error_tree = error_tree + hybridGaussianCond->errorTree(continuousValues);
} else if (auto gaussian = dynamic_pointer_cast<GaussianFactor>(f)) {
// If continuous only, get the (double) error
// and add it to the error_tree
double error = gaussian->error(continuousValues);
// Add the gaussian factor error to every leaf of the error tree.
error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; });
} else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
// If factor at `idx` is discrete-only, we skip.
continue;
} else {
throwRuntimeError("HybridGaussianFactorGraph::error(VV)", f);
if (auto f = std::dynamic_pointer_cast<HybridFactor>(factor)) {
error_tree = error_tree + f->errorTree(continuousValues);
} else if (auto f = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
error_tree =
error_tree + AlgebraicDecisionTree<Key>(f->error(continuousValues));
}
}
return error_tree;
}