discretePosterior in HNFG

release/4.3a0
Frank Dellaert 2024-10-01 21:41:58 -07:00
parent 14d1594bd1
commit 1bb5b9551b
3 changed files with 31 additions and 6 deletions

View File

@ -233,7 +233,7 @@ static double PotentiallyPrunedComponentError(
AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree( AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
const VectorValues &continuousValues) const { const VectorValues &continuousValues) const {
// functor to convert from sharedFactor to double error value. // functor to convert from sharedFactor to double error value.
auto errorFunc = [this, &continuousValues](const auto &pair) { auto errorFunc = [&continuousValues](const auto &pair) {
return PotentiallyPrunedComponentError(pair.first, continuousValues); return PotentiallyPrunedComponentError(pair.first, continuousValues);
}; };
DecisionTree<Key, double> error_tree(factors_, errorFunc); DecisionTree<Key, double> error_tree(factors_, errorFunc);

View File

@ -181,19 +181,19 @@ HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(
/* ************************************************************************* */ /* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridNonlinearFactorGraph::errorTree( AlgebraicDecisionTree<Key> HybridNonlinearFactorGraph::errorTree(
const Values& values) const { const Values& continuousValues) const {
AlgebraicDecisionTree<Key> result(0.0); AlgebraicDecisionTree<Key> result(0.0);
// Iterate over each factor. // Iterate over each factor.
for (auto& factor : factors_) { for (auto& factor : factors_) {
if (auto hnf = std::dynamic_pointer_cast<HybridNonlinearFactor>(factor)) { if (auto hnf = std::dynamic_pointer_cast<HybridNonlinearFactor>(factor)) {
// Compute factor error and add it. // Compute factor error and add it.
result = result + hnf->errorTree(values); result = result + hnf->errorTree(continuousValues);
} else if (auto nf = std::dynamic_pointer_cast<NonlinearFactor>(factor)) { } else if (auto nf = std::dynamic_pointer_cast<NonlinearFactor>(factor)) {
// If continuous only, get the (double) error // If continuous only, get the (double) error
// and add it to every leaf of the result // and add it to every leaf of the result
result = result + nf->error(values); result = result + nf->error(continuousValues);
} else if (auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) { } else if (auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
// If discrete, just add its errorTree as well // If discrete, just add its errorTree as well
@ -210,4 +210,16 @@ AlgebraicDecisionTree<Key> HybridNonlinearFactorGraph::errorTree(
return result; return result;
} }
/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridNonlinearFactorGraph::discretePosterior(
const Values& continuousValues) const {
AlgebraicDecisionTree<Key> errors = this->errorTree(continuousValues);
AlgebraicDecisionTree<Key> p = errors.apply([](double error) {
// NOTE: The 0.5 term is handled by each factor
return exp(-error);
});
return p / p.sum();
}
/* ************************************************************************ */
} // namespace gtsam } // namespace gtsam

View File

@ -98,10 +98,23 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
* *
* @note: Gaussian and hybrid Gaussian factors are not considered! * @note: Gaussian and hybrid Gaussian factors are not considered!
* *
* @param values Manifold values at which to compute the error. * @param continuousValues Manifold values at which to compute the error.
* @return AlgebraicDecisionTree<Key> * @return AlgebraicDecisionTree<Key>
*/ */
AlgebraicDecisionTree<Key> errorTree(const Values& values) const; AlgebraicDecisionTree<Key> errorTree(const Values& continuousValues) const;
/**
* @brief Computer posterior P(M|X=x) when all continuous values X are given.
* This is efficient as this simply takes -exp(.) of errorTree and normalizes.
*
* @note Not a DiscreteConditional as the cardinalities of the DiscreteKeys,
* which we would need, are hard to recover.
*
* @param continuousValues Continuous values x to condition on.
* @return DecisionTreeFactor
*/
AlgebraicDecisionTree<Key> discretePosterior(
const Values& continuousValues) const;
/// @} /// @}
}; };