discretePosterior for graphs
parent
2abb410592
commit
64513eb6d9
|
|
@ -505,22 +505,22 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::errorTree(
|
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::errorTree(
|
||||||
const VectorValues &continuousValues) const {
|
const VectorValues &continuousValues) const {
|
||||||
AlgebraicDecisionTree<Key> error_tree(0.0);
|
AlgebraicDecisionTree<Key> result(0.0);
|
||||||
// Iterate over each factor.
|
// Iterate over each factor.
|
||||||
for (auto &factor : factors_) {
|
for (auto &factor : factors_) {
|
||||||
if (auto f = std::dynamic_pointer_cast<HybridFactor>(factor)) {
|
if (auto f = std::dynamic_pointer_cast<HybridFactor>(factor)) {
|
||||||
// Check for HybridFactor, and call errorTree
|
// Check for HybridFactor, and call errorTree
|
||||||
error_tree = error_tree + f->errorTree(continuousValues);
|
result = result + f->errorTree(continuousValues);
|
||||||
} else if (auto f = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
|
} else if (auto f = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
|
||||||
// Skip discrete factors
|
// Skip discrete factors
|
||||||
continue;
|
continue;
|
||||||
} else {
|
} else {
|
||||||
// Everything else is a continuous only factor
|
// Everything else is a continuous only factor
|
||||||
HybridValues hv(continuousValues, DiscreteValues());
|
HybridValues hv(continuousValues, DiscreteValues());
|
||||||
error_tree = error_tree + AlgebraicDecisionTree<Key>(factor->error(hv));
|
result = result + AlgebraicDecisionTree<Key>(factor->error(hv));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return error_tree;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
|
|
@ -531,21 +531,14 @@ double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor HybridGaussianFactorGraph::probPrime(
|
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::discretePosterior(
|
||||||
const VectorValues &continuousValues) const {
|
const VectorValues &continuousValues) const {
|
||||||
AlgebraicDecisionTree<Key> error_tree = this->errorTree(continuousValues);
|
AlgebraicDecisionTree<Key> errors = this->errorTree(continuousValues);
|
||||||
AlgebraicDecisionTree<Key> prob_tree = error_tree.apply([](double error) {
|
AlgebraicDecisionTree<Key> p = errors.apply([](double error) {
|
||||||
// NOTE: The 0.5 term is handled by each factor
|
// NOTE: The 0.5 term is handled by each factor
|
||||||
return exp(-error);
|
return exp(-error);
|
||||||
});
|
});
|
||||||
return {GetDiscreteKeys(*this), prob_tree};
|
return p / p.sum();
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************ */
|
|
||||||
DiscreteConditional HybridGaussianFactorGraph::discretePosterior(
|
|
||||||
const VectorValues &continuousValues) const {
|
|
||||||
auto p = probPrime(continuousValues);
|
|
||||||
return {p.size(), p};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
|
|
|
||||||
|
|
@ -196,24 +196,17 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
||||||
*/
|
*/
|
||||||
double probPrime(const HybridValues& values) const;
|
double probPrime(const HybridValues& values) const;
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
|
|
||||||
* for each discrete assignment, and return as a tree.
|
|
||||||
*
|
|
||||||
* @param continuousValues Continuous values at which to compute probability.
|
|
||||||
* @return DecisionTreeFactor
|
|
||||||
*/
|
|
||||||
DecisionTreeFactor probPrime(const VectorValues& continuousValues) const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Computer posterior P(M|X=x) when all continuous values X are given.
|
* @brief Computer posterior P(M|X=x) when all continuous values X are given.
|
||||||
* This is very efficient as this simply probPrime normalized into a
|
* This is efficient as this simply probPrime normalized.
|
||||||
* conditional.
|
*
|
||||||
|
* @note Not a DiscreteConditional as the cardinalities of the DiscreteKeys,
|
||||||
|
* which we would need, are hard to recover.
|
||||||
*
|
*
|
||||||
* @param continuousValues Continuous values x to condition on.
|
* @param continuousValues Continuous values x to condition on.
|
||||||
* @return DecisionTreeFactor
|
* @return DecisionTreeFactor
|
||||||
*/
|
*/
|
||||||
DiscreteConditional discretePosterior(
|
AlgebraicDecisionTree<Key> discretePosterior(
|
||||||
const VectorValues& continuousValues) const;
|
const VectorValues& continuousValues) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
||||||
|
|
@ -619,16 +619,9 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
|
||||||
const auto error_tree = graph.errorTree(delta.continuous());
|
const auto error_tree = graph.errorTree(delta.continuous());
|
||||||
EXPECT(assert_equal(expectedErrors, error_tree, 1e-7));
|
EXPECT(assert_equal(expectedErrors, error_tree, 1e-7));
|
||||||
|
|
||||||
// regression test for probPrime
|
|
||||||
const DecisionTreeFactor expectedFactor(
|
|
||||||
s.modes, std::vector{0.36793249, 0.61247742, 0.59489556, 0.99029064});
|
|
||||||
auto probabilities = graph.probPrime(delta.continuous());
|
|
||||||
EXPECT(assert_equal(expectedFactor, probabilities, 1e-7));
|
|
||||||
|
|
||||||
// regression test for discretePosterior
|
// regression test for discretePosterior
|
||||||
const DecisionTreeFactor normalized(
|
const AlgebraicDecisionTree<Key> expectedPosterior(
|
||||||
s.modes, std::vector{0.14341014, 0.23872714, 0.23187421, 0.38598852});
|
s.modes, std::vector{0.14341014, 0.23872714, 0.23187421, 0.38598852});
|
||||||
DiscreteConditional expectedPosterior(2, normalized);
|
|
||||||
auto posterior = graph.discretePosterior(delta.continuous());
|
auto posterior = graph.discretePosterior(delta.continuous());
|
||||||
EXPECT(assert_equal(expectedPosterior, posterior, 1e-7));
|
EXPECT(assert_equal(expectedPosterior, posterior, 1e-7));
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue