discretePosterior for graphs
parent
2abb410592
commit
64513eb6d9
|
|
@ -505,22 +505,22 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
|
|||
/* ************************************************************************ */
|
||||
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::errorTree(
|
||||
const VectorValues &continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> error_tree(0.0);
|
||||
AlgebraicDecisionTree<Key> result(0.0);
|
||||
// Iterate over each factor.
|
||||
for (auto &factor : factors_) {
|
||||
if (auto f = std::dynamic_pointer_cast<HybridFactor>(factor)) {
|
||||
// Check for HybridFactor, and call errorTree
|
||||
error_tree = error_tree + f->errorTree(continuousValues);
|
||||
result = result + f->errorTree(continuousValues);
|
||||
} else if (auto f = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
|
||||
// Skip discrete factors
|
||||
continue;
|
||||
} else {
|
||||
// Everything else is a continuous only factor
|
||||
HybridValues hv(continuousValues, DiscreteValues());
|
||||
error_tree = error_tree + AlgebraicDecisionTree<Key>(factor->error(hv));
|
||||
result = result + AlgebraicDecisionTree<Key>(factor->error(hv));
|
||||
}
|
||||
}
|
||||
return error_tree;
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
|
|
@ -531,21 +531,14 @@ double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const {
|
|||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor HybridGaussianFactorGraph::probPrime(
|
||||
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::discretePosterior(
|
||||
const VectorValues &continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> error_tree = this->errorTree(continuousValues);
|
||||
AlgebraicDecisionTree<Key> prob_tree = error_tree.apply([](double error) {
|
||||
AlgebraicDecisionTree<Key> errors = this->errorTree(continuousValues);
|
||||
AlgebraicDecisionTree<Key> p = errors.apply([](double error) {
|
||||
// NOTE: The 0.5 term is handled by each factor
|
||||
return exp(-error);
|
||||
});
|
||||
return {GetDiscreteKeys(*this), prob_tree};
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DiscreteConditional HybridGaussianFactorGraph::discretePosterior(
|
||||
const VectorValues &continuousValues) const {
|
||||
auto p = probPrime(continuousValues);
|
||||
return {p.size(), p};
|
||||
return p / p.sum();
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
|
|
|
|||
|
|
@ -196,24 +196,17 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
*/
|
||||
double probPrime(const HybridValues& values) const;
|
||||
|
||||
/**
|
||||
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
|
||||
* for each discrete assignment, and return as a tree.
|
||||
*
|
||||
* @param continuousValues Continuous values at which to compute probability.
|
||||
* @return DecisionTreeFactor
|
||||
*/
|
||||
DecisionTreeFactor probPrime(const VectorValues& continuousValues) const;
|
||||
|
||||
/**
|
||||
* @brief Computer posterior P(M|X=x) when all continuous values X are given.
|
||||
* This is very efficient as this simply probPrime normalized into a
|
||||
* conditional.
|
||||
* This is efficient as this simply probPrime normalized.
|
||||
*
|
||||
* @note Not a DiscreteConditional as the cardinalities of the DiscreteKeys,
|
||||
* which we would need, are hard to recover.
|
||||
*
|
||||
* @param continuousValues Continuous values x to condition on.
|
||||
* @return DecisionTreeFactor
|
||||
*/
|
||||
DiscreteConditional discretePosterior(
|
||||
AlgebraicDecisionTree<Key> discretePosterior(
|
||||
const VectorValues& continuousValues) const;
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -619,16 +619,9 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
|
|||
const auto error_tree = graph.errorTree(delta.continuous());
|
||||
EXPECT(assert_equal(expectedErrors, error_tree, 1e-7));
|
||||
|
||||
// regression test for probPrime
|
||||
const DecisionTreeFactor expectedFactor(
|
||||
s.modes, std::vector{0.36793249, 0.61247742, 0.59489556, 0.99029064});
|
||||
auto probabilities = graph.probPrime(delta.continuous());
|
||||
EXPECT(assert_equal(expectedFactor, probabilities, 1e-7));
|
||||
|
||||
// regression test for discretePosterior
|
||||
const DecisionTreeFactor normalized(
|
||||
const AlgebraicDecisionTree<Key> expectedPosterior(
|
||||
s.modes, std::vector{0.14341014, 0.23872714, 0.23187421, 0.38598852});
|
||||
DiscreteConditional expectedPosterior(2, normalized);
|
||||
auto posterior = graph.discretePosterior(delta.continuous());
|
||||
EXPECT(assert_equal(expectedPosterior, posterior, 1e-7));
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue