discretePosterior
parent
38ed609614
commit
a898ad3661
|
@ -23,6 +23,7 @@
|
||||||
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
||||||
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
#include <gtsam/hybrid/HybridConditional.h>
|
#include <gtsam/hybrid/HybridConditional.h>
|
||||||
#include <gtsam/hybrid/HybridEliminationTree.h>
|
#include <gtsam/hybrid/HybridEliminationTree.h>
|
||||||
#include <gtsam/hybrid/HybridFactor.h>
|
#include <gtsam/hybrid/HybridFactor.h>
|
||||||
|
@ -42,7 +43,6 @@
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <iterator>
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
@ -342,14 +342,20 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
|
||||||
return std::make_shared<HybridGaussianFactor>(discreteSeparator, newFactors);
|
return std::make_shared<HybridGaussianFactor>(discreteSeparator, newFactors);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* *******************************************************************************/
|
||||||
|
/// Get the discrete keys from the HybridGaussianFactorGraph as DiscreteKeys.
|
||||||
|
static auto GetDiscreteKeys =
|
||||||
|
[](const HybridGaussianFactorGraph &hfg) -> DiscreteKeys {
|
||||||
|
const std::set<DiscreteKey> discreteKeySet = hfg.discreteKeys();
|
||||||
|
return {discreteKeySet.begin(), discreteKeySet.end()};
|
||||||
|
};
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
|
std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
|
||||||
HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
|
HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
|
||||||
// Since we eliminate all continuous variables first,
|
// Since we eliminate all continuous variables first,
|
||||||
// the discrete separator will be *all* the discrete keys.
|
// the discrete separator will be *all* the discrete keys.
|
||||||
const std::set<DiscreteKey> keysForDiscreteVariables = discreteKeys();
|
DiscreteKeys discreteSeparator = GetDiscreteKeys(*this);
|
||||||
DiscreteKeys discreteSeparator(keysForDiscreteVariables.begin(),
|
|
||||||
keysForDiscreteVariables.end());
|
|
||||||
|
|
||||||
// Collect all the factors to create a set of Gaussian factor graphs in a
|
// Collect all the factors to create a set of Gaussian factor graphs in a
|
||||||
// decision tree indexed by all discrete keys involved.
|
// decision tree indexed by all discrete keys involved.
|
||||||
|
@ -525,14 +531,21 @@ double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
|
DecisionTreeFactor HybridGaussianFactorGraph::probPrime(
|
||||||
const VectorValues &continuousValues) const {
|
const VectorValues &continuousValues) const {
|
||||||
AlgebraicDecisionTree<Key> error_tree = this->errorTree(continuousValues);
|
AlgebraicDecisionTree<Key> error_tree = this->errorTree(continuousValues);
|
||||||
AlgebraicDecisionTree<Key> prob_tree = error_tree.apply([](double error) {
|
AlgebraicDecisionTree<Key> prob_tree = error_tree.apply([](double error) {
|
||||||
// NOTE: The 0.5 term is handled by each factor
|
// NOTE: The 0.5 term is handled by each factor
|
||||||
return exp(-error);
|
return exp(-error);
|
||||||
});
|
});
|
||||||
return prob_tree;
|
return {GetDiscreteKeys(*this), prob_tree};
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
DiscreteConditional HybridGaussianFactorGraph::discretePosterior(
|
||||||
|
const VectorValues &continuousValues) const {
|
||||||
|
auto p = probPrime(continuousValues);
|
||||||
|
return {p.size(), p};
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
#include <gtsam/hybrid/HybridFactor.h>
|
#include <gtsam/hybrid/HybridFactor.h>
|
||||||
#include <gtsam/hybrid/HybridFactorGraph.h>
|
#include <gtsam/hybrid/HybridFactorGraph.h>
|
||||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||||
|
@ -187,17 +188,6 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
||||||
AlgebraicDecisionTree<Key> errorTree(
|
AlgebraicDecisionTree<Key> errorTree(
|
||||||
const VectorValues& continuousValues) const;
|
const VectorValues& continuousValues) const;
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
|
|
||||||
* for each discrete assignment, and return as a tree.
|
|
||||||
*
|
|
||||||
* @param continuousValues Continuous values at which to compute the
|
|
||||||
* probability.
|
|
||||||
* @return AlgebraicDecisionTree<Key>
|
|
||||||
*/
|
|
||||||
AlgebraicDecisionTree<Key> probPrime(
|
|
||||||
const VectorValues& continuousValues) const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Compute the unnormalized posterior probability for a continuous
|
* @brief Compute the unnormalized posterior probability for a continuous
|
||||||
* vector values given a specific assignment.
|
* vector values given a specific assignment.
|
||||||
|
@ -206,6 +196,26 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
||||||
*/
|
*/
|
||||||
double probPrime(const HybridValues& values) const;
|
double probPrime(const HybridValues& values) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
|
||||||
|
* for each discrete assignment, and return as a tree.
|
||||||
|
*
|
||||||
|
* @param continuousValues Continuous values at which to compute probability.
|
||||||
|
* @return DecisionTreeFactor
|
||||||
|
*/
|
||||||
|
DecisionTreeFactor probPrime(const VectorValues& continuousValues) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Computer posterior P(M|X=x) when all continuous values X are given.
|
||||||
|
* This is very efficient as this simply probPrime normalized into a
|
||||||
|
* conditional.
|
||||||
|
*
|
||||||
|
* @param continuousValues Continuous values x to condition on.
|
||||||
|
* @return DecisionTreeFactor
|
||||||
|
*/
|
||||||
|
DiscreteConditional discretePosterior(
|
||||||
|
const VectorValues& continuousValues) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Create a decision tree of factor graphs out of this hybrid factor
|
* @brief Create a decision tree of factor graphs out of this hybrid factor
|
||||||
* graph.
|
* graph.
|
||||||
|
|
|
@ -603,29 +603,34 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) {
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
// Test hybrid gaussian factor graph error and unnormalized probabilities
|
// Test hybrid gaussian factor graph error and unnormalized probabilities
|
||||||
TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
|
TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
|
||||||
|
// Create switching network with three continuous variables and two discrete:
|
||||||
|
// ϕ(x0) ϕ(x0,x1,m0) ϕ(x1,x2,m1) ϕ(x0;z0) ϕ(x1;z1) ϕ(x2;z2) ϕ(m0) ϕ(m0,m1)
|
||||||
Switching s(3);
|
Switching s(3);
|
||||||
|
|
||||||
HybridGaussianFactorGraph graph = s.linearizedFactorGraph;
|
const HybridGaussianFactorGraph &graph = s.linearizedFactorGraph;
|
||||||
|
|
||||||
HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential();
|
const HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential();
|
||||||
|
|
||||||
HybridValues delta = hybridBayesNet->optimize();
|
const HybridValues delta = hybridBayesNet->optimize();
|
||||||
auto error_tree = graph.errorTree(delta.continuous());
|
|
||||||
|
|
||||||
std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
|
// regression test for errorTree
|
||||||
std::vector<double> leaves = {0.9998558, 0.4902432, 0.5193694, 0.0097568};
|
std::vector<double> leaves = {0.9998558, 0.4902432, 0.5193694, 0.0097568};
|
||||||
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
|
AlgebraicDecisionTree<Key> expectedErrors(s.modes, leaves);
|
||||||
|
const auto error_tree = graph.errorTree(delta.continuous());
|
||||||
// regression
|
EXPECT(assert_equal(expectedErrors, error_tree, 1e-7));
|
||||||
EXPECT(assert_equal(expected_error, error_tree, 1e-7));
|
|
||||||
|
|
||||||
|
// regression test for probPrime
|
||||||
|
const DecisionTreeFactor expectedFactor(
|
||||||
|
s.modes, std::vector{0.36793249, 0.61247742, 0.59489556, 0.99029064});
|
||||||
auto probabilities = graph.probPrime(delta.continuous());
|
auto probabilities = graph.probPrime(delta.continuous());
|
||||||
std::vector<double> prob_leaves = {0.36793249, 0.61247742, 0.59489556,
|
EXPECT(assert_equal(expectedFactor, probabilities, 1e-7));
|
||||||
0.99029064};
|
|
||||||
AlgebraicDecisionTree<Key> expected_probabilities(discrete_keys, prob_leaves);
|
|
||||||
|
|
||||||
// regression
|
// regression test for discretePosterior
|
||||||
EXPECT(assert_equal(expected_probabilities, probabilities, 1e-7));
|
const DecisionTreeFactor normalized(
|
||||||
|
s.modes, std::vector{0.14341014, 0.23872714, 0.23187421, 0.38598852});
|
||||||
|
DiscreteConditional expectedPosterior(2, normalized);
|
||||||
|
auto posterior = graph.discretePosterior(delta.continuous());
|
||||||
|
EXPECT(assert_equal(expectedPosterior, posterior, 1e-7));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
|
|
Loading…
Reference in New Issue