discretePosterior for graphs

release/4.3a0
Frank Dellaert 2024-09-29 16:37:02 -07:00
parent 2abb410592
commit 64513eb6d9
3 changed files with 14 additions and 35 deletions

View File

@ -505,22 +505,22 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::errorTree(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree(0.0);
AlgebraicDecisionTree<Key> result(0.0);
// Iterate over each factor.
for (auto &factor : factors_) {
if (auto f = std::dynamic_pointer_cast<HybridFactor>(factor)) {
// Check for HybridFactor, and call errorTree
error_tree = error_tree + f->errorTree(continuousValues);
result = result + f->errorTree(continuousValues);
} else if (auto f = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
// Skip discrete factors
continue;
} else {
// Everything else is a continuous only factor
HybridValues hv(continuousValues, DiscreteValues());
error_tree = error_tree + AlgebraicDecisionTree<Key>(factor->error(hv));
result = result + AlgebraicDecisionTree<Key>(factor->error(hv));
}
}
return error_tree;
return result;
}
/* ************************************************************************ */
@ -531,21 +531,14 @@ double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const {
}
/* ************************************************************************ */
DecisionTreeFactor HybridGaussianFactorGraph::probPrime(
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::discretePosterior(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree = this->errorTree(continuousValues);
AlgebraicDecisionTree<Key> prob_tree = error_tree.apply([](double error) {
AlgebraicDecisionTree<Key> errors = this->errorTree(continuousValues);
AlgebraicDecisionTree<Key> p = errors.apply([](double error) {
// NOTE: The 0.5 term is handled by each factor
return exp(-error);
});
return {GetDiscreteKeys(*this), prob_tree};
}
/* ************************************************************************ */
DiscreteConditional HybridGaussianFactorGraph::discretePosterior(
const VectorValues &continuousValues) const {
auto p = probPrime(continuousValues);
return {p.size(), p};
return p / p.sum();
}
/* ************************************************************************ */

View File

@ -196,24 +196,17 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
*/
double probPrime(const HybridValues& values) const;
/**
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
* for each discrete assignment, and return as a tree.
*
* @param continuousValues Continuous values at which to compute probability.
* @return DecisionTreeFactor
*/
DecisionTreeFactor probPrime(const VectorValues& continuousValues) const;
/**
* @brief Computer posterior P(M|X=x) when all continuous values X are given.
* This is very efficient as this simply probPrime normalized into a
* conditional.
* This is efficient as this simply probPrime normalized.
*
* @note Not a DiscreteConditional as the cardinalities of the DiscreteKeys,
* which we would need, are hard to recover.
*
* @param continuousValues Continuous values x to condition on.
* @return DecisionTreeFactor
*/
DiscreteConditional discretePosterior(
AlgebraicDecisionTree<Key> discretePosterior(
const VectorValues& continuousValues) const;
/**

View File

@ -619,16 +619,9 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
const auto error_tree = graph.errorTree(delta.continuous());
EXPECT(assert_equal(expectedErrors, error_tree, 1e-7));
// regression test for probPrime
const DecisionTreeFactor expectedFactor(
s.modes, std::vector{0.36793249, 0.61247742, 0.59489556, 0.99029064});
auto probabilities = graph.probPrime(delta.continuous());
EXPECT(assert_equal(expectedFactor, probabilities, 1e-7));
// regression test for discretePosterior
const DecisionTreeFactor normalized(
const AlgebraicDecisionTree<Key> expectedPosterior(
s.modes, std::vector{0.14341014, 0.23872714, 0.23187421, 0.38598852});
DiscreteConditional expectedPosterior(2, normalized);
auto posterior = graph.discretePosterior(delta.continuous());
EXPECT(assert_equal(expectedPosterior, posterior, 1e-7));
}