discretePosterior in HNFG
parent
14d1594bd1
commit
1bb5b9551b
|
@ -233,7 +233,7 @@ static double PotentiallyPrunedComponentError(
|
|||
AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
|
||||
const VectorValues &continuousValues) const {
|
||||
// functor to convert from sharedFactor to double error value.
|
||||
auto errorFunc = [this, &continuousValues](const auto &pair) {
|
||||
auto errorFunc = [&continuousValues](const auto &pair) {
|
||||
return PotentiallyPrunedComponentError(pair.first, continuousValues);
|
||||
};
|
||||
DecisionTree<Key, double> error_tree(factors_, errorFunc);
|
||||
|
|
|
@ -181,19 +181,19 @@ HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(
|
|||
|
||||
/* ************************************************************************* */
|
||||
AlgebraicDecisionTree<Key> HybridNonlinearFactorGraph::errorTree(
|
||||
const Values& values) const {
|
||||
const Values& continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> result(0.0);
|
||||
|
||||
// Iterate over each factor.
|
||||
for (auto& factor : factors_) {
|
||||
if (auto hnf = std::dynamic_pointer_cast<HybridNonlinearFactor>(factor)) {
|
||||
// Compute factor error and add it.
|
||||
result = result + hnf->errorTree(values);
|
||||
result = result + hnf->errorTree(continuousValues);
|
||||
|
||||
} else if (auto nf = std::dynamic_pointer_cast<NonlinearFactor>(factor)) {
|
||||
// If continuous only, get the (double) error
|
||||
// and add it to every leaf of the result
|
||||
result = result + nf->error(values);
|
||||
result = result + nf->error(continuousValues);
|
||||
|
||||
} else if (auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
|
||||
// If discrete, just add its errorTree as well
|
||||
|
@ -210,4 +210,16 @@ AlgebraicDecisionTree<Key> HybridNonlinearFactorGraph::errorTree(
|
|||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
AlgebraicDecisionTree<Key> HybridNonlinearFactorGraph::discretePosterior(
|
||||
const Values& continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> errors = this->errorTree(continuousValues);
|
||||
AlgebraicDecisionTree<Key> p = errors.apply([](double error) {
|
||||
// NOTE: The 0.5 term is handled by each factor
|
||||
return exp(-error);
|
||||
});
|
||||
return p / p.sum();
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -98,10 +98,23 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
|
|||
*
|
||||
* @note: Gaussian and hybrid Gaussian factors are not considered!
|
||||
*
|
||||
* @param values Manifold values at which to compute the error.
|
||||
* @param continuousValues Manifold values at which to compute the error.
|
||||
* @return AlgebraicDecisionTree<Key>
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> errorTree(const Values& values) const;
|
||||
AlgebraicDecisionTree<Key> errorTree(const Values& continuousValues) const;
|
||||
|
||||
/**
|
||||
* @brief Computer posterior P(M|X=x) when all continuous values X are given.
|
||||
* This is efficient as this simply takes -exp(.) of errorTree and normalizes.
|
||||
*
|
||||
* @note Not a DiscreteConditional as the cardinalities of the DiscreteKeys,
|
||||
* which we would need, are hard to recover.
|
||||
*
|
||||
* @param continuousValues Continuous values x to condition on.
|
||||
* @return DecisionTreeFactor
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> discretePosterior(
|
||||
const Values& continuousValues) const;
|
||||
|
||||
/// @}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue