common errorTree method and its use in HybridGaussianFactorGraph
parent
245f3e042e
commit
cd3c590f32
|
|
@ -129,6 +129,22 @@ double HybridConditional::error(const HybridValues &values) const {
|
|||
"HybridConditional::error: conditional type not handled");
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
AlgebraicDecisionTree<Key> HybridConditional::errorTree(
|
||||
const VectorValues &values) const {
|
||||
if (auto gc = asGaussian()) {
|
||||
return AlgebraicDecisionTree<Key>(gc->error(values));
|
||||
}
|
||||
if (auto gm = asHybrid()) {
|
||||
return gm->errorTree(values);
|
||||
}
|
||||
if (auto dc = asDiscrete()) {
|
||||
return AlgebraicDecisionTree<Key>(0.0);
|
||||
}
|
||||
throw std::runtime_error(
|
||||
"HybridConditional::error: conditional type not handled");
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
double HybridConditional::logProbability(const HybridValues &values) const {
|
||||
if (auto gc = asGaussian()) {
|
||||
|
|
|
|||
|
|
@ -179,6 +179,16 @@ class GTSAM_EXPORT HybridConditional
|
|||
/// Return the error of the underlying conditional.
|
||||
double error(const HybridValues& values) const override;
|
||||
|
||||
/**
|
||||
* @brief Compute error of the HybridConditional as a tree.
|
||||
*
|
||||
* @param continuousValues The continuous VectorValues.
|
||||
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
|
||||
* as the conditionals involved, and leaf values as the error.
|
||||
*/
|
||||
virtual AlgebraicDecisionTree<Key> errorTree(
|
||||
const VectorValues& values) const override;
|
||||
|
||||
/// Return the log-probability (or density) of the underlying conditional.
|
||||
double logProbability(const HybridValues& values) const override;
|
||||
|
||||
|
|
|
|||
|
|
@ -136,6 +136,10 @@ class GTSAM_EXPORT HybridFactor : public Factor {
|
|||
/// Return only the continuous keys for this factor.
|
||||
const KeyVector &continuousKeys() const { return continuousKeys_; }
|
||||
|
||||
/// Virtual class to compute tree of linear errors.
|
||||
virtual AlgebraicDecisionTree<Key> errorTree(
|
||||
const VectorValues &values) const = 0;
|
||||
|
||||
/// @}
|
||||
|
||||
private:
|
||||
|
|
|
|||
|
|
@ -109,9 +109,9 @@ class GTSAM_EXPORT HybridGaussianConditional
|
|||
const Conditionals &conditionals);
|
||||
|
||||
/**
|
||||
* @brief Make a Hybrid Gaussian Conditional from a vector of Gaussian
|
||||
* conditionals. The DecisionTree-based constructor is preferred over this
|
||||
* one.
|
||||
* @brief Make a Hybrid Gaussian Conditional from
|
||||
* a vector of Gaussian conditionals.
|
||||
* The DecisionTree-based constructor is preferred over this one.
|
||||
*
|
||||
* @param continuousFrontals The continuous frontal variables
|
||||
* @param continuousParents The continuous parent variables
|
||||
|
|
@ -208,8 +208,8 @@ class GTSAM_EXPORT HybridGaussianConditional
|
|||
* @return AlgebraicDecisionTree<Key> A decision tree on the discrete keys
|
||||
* only, with the leaf values as the error for each assignment.
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> errorTree(
|
||||
const VectorValues &continuousValues) const;
|
||||
virtual AlgebraicDecisionTree<Key> errorTree(
|
||||
const VectorValues &continuousValues) const override;
|
||||
|
||||
/**
|
||||
* @brief Compute the logProbability of this hybrid Gaussian conditional.
|
||||
|
|
|
|||
|
|
@ -148,8 +148,8 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
|||
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
|
||||
* as the factors involved, and leaf values as the error.
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> errorTree(
|
||||
const VectorValues &continuousValues) const;
|
||||
virtual AlgebraicDecisionTree<Key> errorTree(
|
||||
const VectorValues &continuousValues) const override;
|
||||
|
||||
/**
|
||||
* @brief Compute the log-likelihood, including the log-normalizing constant.
|
||||
|
|
|
|||
|
|
@ -539,36 +539,15 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
|
|||
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::errorTree(
|
||||
const VectorValues &continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> error_tree(0.0);
|
||||
|
||||
// Iterate over each factor.
|
||||
for (auto &factor : factors_) {
|
||||
// TODO(dellaert): just use a virtual method defined in HybridFactor.
|
||||
AlgebraicDecisionTree<Key> factor_error;
|
||||
|
||||
auto f = factor;
|
||||
if (auto hc = dynamic_pointer_cast<HybridConditional>(factor)) {
|
||||
f = hc->inner();
|
||||
}
|
||||
|
||||
if (auto hybridGaussianCond =
|
||||
dynamic_pointer_cast<HybridGaussianFactor>(f)) {
|
||||
// Compute factor error and add it.
|
||||
error_tree = error_tree + hybridGaussianCond->errorTree(continuousValues);
|
||||
} else if (auto gaussian = dynamic_pointer_cast<GaussianFactor>(f)) {
|
||||
// If continuous only, get the (double) error
|
||||
// and add it to the error_tree
|
||||
double error = gaussian->error(continuousValues);
|
||||
// Add the gaussian factor error to every leaf of the error tree.
|
||||
error_tree = error_tree.apply(
|
||||
[error](double leaf_value) { return leaf_value + error; });
|
||||
} else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||
// If factor at `idx` is discrete-only, we skip.
|
||||
continue;
|
||||
} else {
|
||||
throwRuntimeError("HybridGaussianFactorGraph::error(VV)", f);
|
||||
if (auto f = std::dynamic_pointer_cast<HybridFactor>(factor)) {
|
||||
error_tree = error_tree + f->errorTree(continuousValues);
|
||||
} else if (auto f = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
|
||||
error_tree =
|
||||
error_tree + AlgebraicDecisionTree<Key>(f->error(continuousValues));
|
||||
}
|
||||
}
|
||||
|
||||
return error_tree;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue