discretePosterior for graphs

release/4.3a0
Frank Dellaert 2024-09-29 16:37:02 -07:00
parent 2abb410592
commit 64513eb6d9
3 changed files with 14 additions and 35 deletions

View File

@ -505,22 +505,22 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
/* ************************************************************************ */ /* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::errorTree( AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::errorTree(
const VectorValues &continuousValues) const { const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree(0.0); AlgebraicDecisionTree<Key> result(0.0);
// Iterate over each factor. // Iterate over each factor.
for (auto &factor : factors_) { for (auto &factor : factors_) {
if (auto f = std::dynamic_pointer_cast<HybridFactor>(factor)) { if (auto f = std::dynamic_pointer_cast<HybridFactor>(factor)) {
// Check for HybridFactor, and call errorTree // Check for HybridFactor, and call errorTree
error_tree = error_tree + f->errorTree(continuousValues); result = result + f->errorTree(continuousValues);
} else if (auto f = std::dynamic_pointer_cast<DiscreteFactor>(factor)) { } else if (auto f = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
// Skip discrete factors // Skip discrete factors
continue; continue;
} else { } else {
// Everything else is a continuous only factor // Everything else is a continuous only factor
HybridValues hv(continuousValues, DiscreteValues()); HybridValues hv(continuousValues, DiscreteValues());
error_tree = error_tree + AlgebraicDecisionTree<Key>(factor->error(hv)); result = result + AlgebraicDecisionTree<Key>(factor->error(hv));
} }
} }
return error_tree; return result;
} }
/* ************************************************************************ */ /* ************************************************************************ */
@ -531,21 +531,14 @@ double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const {
} }
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor HybridGaussianFactorGraph::probPrime( AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::discretePosterior(
const VectorValues &continuousValues) const { const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree = this->errorTree(continuousValues); AlgebraicDecisionTree<Key> errors = this->errorTree(continuousValues);
AlgebraicDecisionTree<Key> prob_tree = error_tree.apply([](double error) { AlgebraicDecisionTree<Key> p = errors.apply([](double error) {
// NOTE: The 0.5 term is handled by each factor // NOTE: The 0.5 term is handled by each factor
return exp(-error); return exp(-error);
}); });
return {GetDiscreteKeys(*this), prob_tree}; return p / p.sum();
}
/* ************************************************************************ */
DiscreteConditional HybridGaussianFactorGraph::discretePosterior(
const VectorValues &continuousValues) const {
auto p = probPrime(continuousValues);
return {p.size(), p};
} }
/* ************************************************************************ */ /* ************************************************************************ */

View File

@ -196,24 +196,17 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
*/ */
double probPrime(const HybridValues& values) const; double probPrime(const HybridValues& values) const;
/**
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
* for each discrete assignment, and return as a tree.
*
* @param continuousValues Continuous values at which to compute probability.
* @return DecisionTreeFactor
*/
DecisionTreeFactor probPrime(const VectorValues& continuousValues) const;
/** /**
* @brief Computer posterior P(M|X=x) when all continuous values X are given. * @brief Computer posterior P(M|X=x) when all continuous values X are given.
* This is very efficient as this simply probPrime normalized into a * This is efficient as this simply probPrime normalized.
* conditional. *
* @note Not a DiscreteConditional as the cardinalities of the DiscreteKeys,
* which we would need, are hard to recover.
* *
* @param continuousValues Continuous values x to condition on. * @param continuousValues Continuous values x to condition on.
* @return DecisionTreeFactor * @return DecisionTreeFactor
*/ */
DiscreteConditional discretePosterior( AlgebraicDecisionTree<Key> discretePosterior(
const VectorValues& continuousValues) const; const VectorValues& continuousValues) const;
/** /**

View File

@ -619,16 +619,9 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
const auto error_tree = graph.errorTree(delta.continuous()); const auto error_tree = graph.errorTree(delta.continuous());
EXPECT(assert_equal(expectedErrors, error_tree, 1e-7)); EXPECT(assert_equal(expectedErrors, error_tree, 1e-7));
// regression test for probPrime
const DecisionTreeFactor expectedFactor(
s.modes, std::vector{0.36793249, 0.61247742, 0.59489556, 0.99029064});
auto probabilities = graph.probPrime(delta.continuous());
EXPECT(assert_equal(expectedFactor, probabilities, 1e-7));
// regression test for discretePosterior // regression test for discretePosterior
const DecisionTreeFactor normalized( const AlgebraicDecisionTree<Key> expectedPosterior(
s.modes, std::vector{0.14341014, 0.23872714, 0.23187421, 0.38598852}); s.modes, std::vector{0.14341014, 0.23872714, 0.23187421, 0.38598852});
DiscreteConditional expectedPosterior(2, normalized);
auto posterior = graph.discretePosterior(delta.continuous()); auto posterior = graph.discretePosterior(delta.continuous());
EXPECT(assert_equal(expectedPosterior, posterior, 1e-7)); EXPECT(assert_equal(expectedPosterior, posterior, 1e-7));
} }