discretePosterior
parent
38ed609614
commit
a898ad3661
|
@ -23,6 +23,7 @@
|
|||
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/hybrid/HybridConditional.h>
|
||||
#include <gtsam/hybrid/HybridEliminationTree.h>
|
||||
#include <gtsam/hybrid/HybridFactor.h>
|
||||
|
@ -42,7 +43,6 @@
|
|||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <iostream>
|
||||
#include <iterator>
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <utility>
|
||||
|
@ -342,14 +342,20 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
|
|||
return std::make_shared<HybridGaussianFactor>(discreteSeparator, newFactors);
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
/// Get the discrete keys from the HybridGaussianFactorGraph as DiscreteKeys.
|
||||
static auto GetDiscreteKeys =
|
||||
[](const HybridGaussianFactorGraph &hfg) -> DiscreteKeys {
|
||||
const std::set<DiscreteKey> discreteKeySet = hfg.discreteKeys();
|
||||
return {discreteKeySet.begin(), discreteKeySet.end()};
|
||||
};
|
||||
|
||||
/* *******************************************************************************/
|
||||
std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
|
||||
HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
|
||||
// Since we eliminate all continuous variables first,
|
||||
// the discrete separator will be *all* the discrete keys.
|
||||
const std::set<DiscreteKey> keysForDiscreteVariables = discreteKeys();
|
||||
DiscreteKeys discreteSeparator(keysForDiscreteVariables.begin(),
|
||||
keysForDiscreteVariables.end());
|
||||
DiscreteKeys discreteSeparator = GetDiscreteKeys(*this);
|
||||
|
||||
// Collect all the factors to create a set of Gaussian factor graphs in a
|
||||
// decision tree indexed by all discrete keys involved.
|
||||
|
@ -525,14 +531,21 @@ double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const {
|
|||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
|
||||
DecisionTreeFactor HybridGaussianFactorGraph::probPrime(
|
||||
const VectorValues &continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> error_tree = this->errorTree(continuousValues);
|
||||
AlgebraicDecisionTree<Key> prob_tree = error_tree.apply([](double error) {
|
||||
// NOTE: The 0.5 term is handled by each factor
|
||||
return exp(-error);
|
||||
});
|
||||
return prob_tree;
|
||||
return {GetDiscreteKeys(*this), prob_tree};
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DiscreteConditional HybridGaussianFactorGraph::discretePosterior(
|
||||
const VectorValues &continuousValues) const {
|
||||
auto p = probPrime(continuousValues);
|
||||
return {p.size(), p};
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/hybrid/HybridFactor.h>
|
||||
#include <gtsam/hybrid/HybridFactorGraph.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||
|
@ -187,17 +188,6 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
AlgebraicDecisionTree<Key> errorTree(
|
||||
const VectorValues& continuousValues) const;
|
||||
|
||||
/**
|
||||
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
|
||||
* for each discrete assignment, and return as a tree.
|
||||
*
|
||||
* @param continuousValues Continuous values at which to compute the
|
||||
* probability.
|
||||
* @return AlgebraicDecisionTree<Key>
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> probPrime(
|
||||
const VectorValues& continuousValues) const;
|
||||
|
||||
/**
|
||||
* @brief Compute the unnormalized posterior probability for a continuous
|
||||
* vector values given a specific assignment.
|
||||
|
@ -206,6 +196,26 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
*/
|
||||
double probPrime(const HybridValues& values) const;
|
||||
|
||||
/**
|
||||
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
|
||||
* for each discrete assignment, and return as a tree.
|
||||
*
|
||||
* @param continuousValues Continuous values at which to compute probability.
|
||||
* @return DecisionTreeFactor
|
||||
*/
|
||||
DecisionTreeFactor probPrime(const VectorValues& continuousValues) const;
|
||||
|
||||
/**
|
||||
* @brief Computer posterior P(M|X=x) when all continuous values X are given.
|
||||
* This is very efficient as this simply probPrime normalized into a
|
||||
* conditional.
|
||||
*
|
||||
* @param continuousValues Continuous values x to condition on.
|
||||
* @return DecisionTreeFactor
|
||||
*/
|
||||
DiscreteConditional discretePosterior(
|
||||
const VectorValues& continuousValues) const;
|
||||
|
||||
/**
|
||||
* @brief Create a decision tree of factor graphs out of this hybrid factor
|
||||
* graph.
|
||||
|
|
|
@ -603,29 +603,34 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) {
|
|||
/* ****************************************************************************/
|
||||
// Test hybrid gaussian factor graph error and unnormalized probabilities
|
||||
TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
|
||||
// Create switching network with three continuous variables and two discrete:
|
||||
// ϕ(x0) ϕ(x0,x1,m0) ϕ(x1,x2,m1) ϕ(x0;z0) ϕ(x1;z1) ϕ(x2;z2) ϕ(m0) ϕ(m0,m1)
|
||||
Switching s(3);
|
||||
|
||||
HybridGaussianFactorGraph graph = s.linearizedFactorGraph;
|
||||
const HybridGaussianFactorGraph &graph = s.linearizedFactorGraph;
|
||||
|
||||
HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential();
|
||||
const HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential();
|
||||
|
||||
HybridValues delta = hybridBayesNet->optimize();
|
||||
auto error_tree = graph.errorTree(delta.continuous());
|
||||
const HybridValues delta = hybridBayesNet->optimize();
|
||||
|
||||
std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
|
||||
// regression test for errorTree
|
||||
std::vector<double> leaves = {0.9998558, 0.4902432, 0.5193694, 0.0097568};
|
||||
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
|
||||
|
||||
// regression
|
||||
EXPECT(assert_equal(expected_error, error_tree, 1e-7));
|
||||
AlgebraicDecisionTree<Key> expectedErrors(s.modes, leaves);
|
||||
const auto error_tree = graph.errorTree(delta.continuous());
|
||||
EXPECT(assert_equal(expectedErrors, error_tree, 1e-7));
|
||||
|
||||
// regression test for probPrime
|
||||
const DecisionTreeFactor expectedFactor(
|
||||
s.modes, std::vector{0.36793249, 0.61247742, 0.59489556, 0.99029064});
|
||||
auto probabilities = graph.probPrime(delta.continuous());
|
||||
std::vector<double> prob_leaves = {0.36793249, 0.61247742, 0.59489556,
|
||||
0.99029064};
|
||||
AlgebraicDecisionTree<Key> expected_probabilities(discrete_keys, prob_leaves);
|
||||
EXPECT(assert_equal(expectedFactor, probabilities, 1e-7));
|
||||
|
||||
// regression
|
||||
EXPECT(assert_equal(expected_probabilities, probabilities, 1e-7));
|
||||
// regression test for discretePosterior
|
||||
const DecisionTreeFactor normalized(
|
||||
s.modes, std::vector{0.14341014, 0.23872714, 0.23187421, 0.38598852});
|
||||
DiscreteConditional expectedPosterior(2, normalized);
|
||||
auto posterior = graph.discretePosterior(delta.continuous());
|
||||
EXPECT(assert_equal(expectedPosterior, posterior, 1e-7));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
|
|
Loading…
Reference in New Issue