discretePosterior in HNFG
parent
14d1594bd1
commit
1bb5b9551b
|
@ -233,7 +233,7 @@ static double PotentiallyPrunedComponentError(
|
||||||
AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
|
AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
|
||||||
const VectorValues &continuousValues) const {
|
const VectorValues &continuousValues) const {
|
||||||
// functor to convert from sharedFactor to double error value.
|
// functor to convert from sharedFactor to double error value.
|
||||||
auto errorFunc = [this, &continuousValues](const auto &pair) {
|
auto errorFunc = [&continuousValues](const auto &pair) {
|
||||||
return PotentiallyPrunedComponentError(pair.first, continuousValues);
|
return PotentiallyPrunedComponentError(pair.first, continuousValues);
|
||||||
};
|
};
|
||||||
DecisionTree<Key, double> error_tree(factors_, errorFunc);
|
DecisionTree<Key, double> error_tree(factors_, errorFunc);
|
||||||
|
|
|
@ -181,19 +181,19 @@ HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
AlgebraicDecisionTree<Key> HybridNonlinearFactorGraph::errorTree(
|
AlgebraicDecisionTree<Key> HybridNonlinearFactorGraph::errorTree(
|
||||||
const Values& values) const {
|
const Values& continuousValues) const {
|
||||||
AlgebraicDecisionTree<Key> result(0.0);
|
AlgebraicDecisionTree<Key> result(0.0);
|
||||||
|
|
||||||
// Iterate over each factor.
|
// Iterate over each factor.
|
||||||
for (auto& factor : factors_) {
|
for (auto& factor : factors_) {
|
||||||
if (auto hnf = std::dynamic_pointer_cast<HybridNonlinearFactor>(factor)) {
|
if (auto hnf = std::dynamic_pointer_cast<HybridNonlinearFactor>(factor)) {
|
||||||
// Compute factor error and add it.
|
// Compute factor error and add it.
|
||||||
result = result + hnf->errorTree(values);
|
result = result + hnf->errorTree(continuousValues);
|
||||||
|
|
||||||
} else if (auto nf = std::dynamic_pointer_cast<NonlinearFactor>(factor)) {
|
} else if (auto nf = std::dynamic_pointer_cast<NonlinearFactor>(factor)) {
|
||||||
// If continuous only, get the (double) error
|
// If continuous only, get the (double) error
|
||||||
// and add it to every leaf of the result
|
// and add it to every leaf of the result
|
||||||
result = result + nf->error(values);
|
result = result + nf->error(continuousValues);
|
||||||
|
|
||||||
} else if (auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
|
} else if (auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
|
||||||
// If discrete, just add its errorTree as well
|
// If discrete, just add its errorTree as well
|
||||||
|
@ -210,4 +210,16 @@ AlgebraicDecisionTree<Key> HybridNonlinearFactorGraph::errorTree(
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
AlgebraicDecisionTree<Key> HybridNonlinearFactorGraph::discretePosterior(
|
||||||
|
const Values& continuousValues) const {
|
||||||
|
AlgebraicDecisionTree<Key> errors = this->errorTree(continuousValues);
|
||||||
|
AlgebraicDecisionTree<Key> p = errors.apply([](double error) {
|
||||||
|
// NOTE: The 0.5 term is handled by each factor
|
||||||
|
return exp(-error);
|
||||||
|
});
|
||||||
|
return p / p.sum();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -98,10 +98,23 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
|
||||||
*
|
*
|
||||||
* @note: Gaussian and hybrid Gaussian factors are not considered!
|
* @note: Gaussian and hybrid Gaussian factors are not considered!
|
||||||
*
|
*
|
||||||
* @param values Manifold values at which to compute the error.
|
* @param continuousValues Manifold values at which to compute the error.
|
||||||
* @return AlgebraicDecisionTree<Key>
|
* @return AlgebraicDecisionTree<Key>
|
||||||
*/
|
*/
|
||||||
AlgebraicDecisionTree<Key> errorTree(const Values& values) const;
|
AlgebraicDecisionTree<Key> errorTree(const Values& continuousValues) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Computer posterior P(M|X=x) when all continuous values X are given.
|
||||||
|
* This is efficient as this simply takes -exp(.) of errorTree and normalizes.
|
||||||
|
*
|
||||||
|
* @note Not a DiscreteConditional as the cardinalities of the DiscreteKeys,
|
||||||
|
* which we would need, are hard to recover.
|
||||||
|
*
|
||||||
|
* @param continuousValues Continuous values x to condition on.
|
||||||
|
* @return DecisionTreeFactor
|
||||||
|
*/
|
||||||
|
AlgebraicDecisionTree<Key> discretePosterior(
|
||||||
|
const Values& continuousValues) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
};
|
};
|
||||||
|
|
Loading…
Reference in New Issue