discretePosterior

release/4.3a0
Frank Dellaert 2024-09-29 11:54:54 -07:00
parent 38ed609614
commit a898ad3661
3 changed files with 59 additions and 31 deletions

View File

@@ -23,6 +23,7 @@
#include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteJunctionTree.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridEliminationTree.h>
#include <gtsam/hybrid/HybridFactor.h>
@@ -42,7 +43,6 @@
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <memory>
#include <stdexcept>
#include <utility>
@@ -342,14 +342,20 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
return std::make_shared<HybridGaussianFactor>(discreteSeparator, newFactors);
}
/* *******************************************************************************/
/// Get the discrete keys from the HybridGaussianFactorGraph as DiscreteKeys.
static auto GetDiscreteKeys =
[](const HybridGaussianFactorGraph &hfg) -> DiscreteKeys {
const std::set<DiscreteKey> discreteKeySet = hfg.discreteKeys();
return {discreteKeySet.begin(), discreteKeySet.end()};
};
/* *******************************************************************************/
std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
// Since we eliminate all continuous variables first,
// the discrete separator will be *all* the discrete keys.
const std::set<DiscreteKey> keysForDiscreteVariables = discreteKeys();
DiscreteKeys discreteSeparator(keysForDiscreteVariables.begin(),
keysForDiscreteVariables.end());
DiscreteKeys discreteSeparator = GetDiscreteKeys(*this);
// Collect all the factors to create a set of Gaussian factor graphs in a
// decision tree indexed by all discrete keys involved.
@@ -525,14 +531,21 @@ double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const {
}
/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
DecisionTreeFactor HybridGaussianFactorGraph::probPrime(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree = this->errorTree(continuousValues);
AlgebraicDecisionTree<Key> prob_tree = error_tree.apply([](double error) {
// NOTE: The 0.5 term is handled by each factor
return exp(-error);
});
return prob_tree;
return {GetDiscreteKeys(*this), prob_tree};
}
/* ************************************************************************ */
DiscreteConditional HybridGaussianFactorGraph::discretePosterior(
const VectorValues &continuousValues) const {
auto p = probPrime(continuousValues);
return {p.size(), p};
}
/* ************************************************************************ */
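Taken together, errorTree, probPrime, and discretePosterior form a small pipeline: errorTree evaluates the error E(x, M) of every discrete assignment M at the given continuous values x, probPrime exponentiates those errors into an unnormalized DecisionTreeFactor, and discretePosterior normalizes that factor into a DiscreteConditional. As a paraphrase of the code above (E, x, M are informal shorthand, not GTSAM identifiers):

    P(M \mid x, Z) = \frac{\exp(-E(x, M))}{\sum_{M'} \exp(-E(x, M'))}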

View File

@@ -18,6 +18,7 @@
#pragma once
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridFactorGraph.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
@@ -187,17 +188,6 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
AlgebraicDecisionTree<Key> errorTree(
const VectorValues& continuousValues) const;
/**
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
* for each discrete assignment, and return as a tree.
*
* @param continuousValues Continuous values at which to compute the
* probability.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> probPrime(
const VectorValues& continuousValues) const;
/**
* @brief Compute the unnormalized posterior probability for a continuous
* vector values given a specific assignment.
@@ -206,6 +196,26 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
*/
double probPrime(const HybridValues& values) const;
/**
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
* for each discrete assignment, and return as a tree.
*
* @param continuousValues Continuous values at which to compute the probability.
* @return DecisionTreeFactor
*/
DecisionTreeFactor probPrime(const VectorValues& continuousValues) const;
/**
* @brief Compute posterior P(M|X=x) when all continuous values X are given.
* This is very efficient since it is simply probPrime normalized into a
* conditional.
*
* @param continuousValues Continuous values x to condition on.
* @return DiscreteConditional
*/
DiscreteConditional discretePosterior(
const VectorValues& continuousValues) const;
/**
* @brief Create a decision tree of factor graphs out of this hybrid factor
* graph.
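For orientation, here is a minimal usage sketch of the two members declared above; it is not part of the commit. graph and x are placeholders for a HybridGaussianFactorGraph and the continuous values to condition on (e.g. obtained from graph.eliminateSequential()->optimize().continuous(), as in the test below), and M is gtsam::symbol_shorthand::M:

    const auto errors = graph.errorTree(x);             // AlgebraicDecisionTree<Key>: E(x, M) per mode
    const auto p = graph.probPrime(x);                  // DecisionTreeFactor: exp(-E(x, M)), unnormalized
    const auto posterior = graph.discretePosterior(x);  // DiscreteConditional: P(M | x), normalized

    // Evaluate the posterior at one mode assignment:
    gtsam::DiscreteValues modes;
    modes[M(0)] = 1;
    modes[M(1)] = 0;
    const double pM = posterior(modes);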

View File

@@ -603,29 +603,34 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) {
/* ****************************************************************************/
// Test hybrid gaussian factor graph error and unnormalized probabilities
TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
// Create switching network with three continuous variables and two discrete:
// ϕ(x0) ϕ(x0,x1,m0) ϕ(x1,x2,m1) ϕ(x0;z0) ϕ(x1;z1) ϕ(x2;z2) ϕ(m0) ϕ(m0,m1)
Switching s(3);
HybridGaussianFactorGraph graph = s.linearizedFactorGraph;
const HybridGaussianFactorGraph &graph = s.linearizedFactorGraph;
HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential();
const HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential();
HybridValues delta = hybridBayesNet->optimize();
auto error_tree = graph.errorTree(delta.continuous());
const HybridValues delta = hybridBayesNet->optimize();
std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
// regression test for errorTree
std::vector<double> leaves = {0.9998558, 0.4902432, 0.5193694, 0.0097568};
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
// regression
EXPECT(assert_equal(expected_error, error_tree, 1e-7));
AlgebraicDecisionTree<Key> expectedErrors(s.modes, leaves);
const auto error_tree = graph.errorTree(delta.continuous());
EXPECT(assert_equal(expectedErrors, error_tree, 1e-7));
// regression test for probPrime
const DecisionTreeFactor expectedFactor(
s.modes, std::vector{0.36793249, 0.61247742, 0.59489556, 0.99029064});
auto probabilities = graph.probPrime(delta.continuous());
std::vector<double> prob_leaves = {0.36793249, 0.61247742, 0.59489556,
0.99029064};
AlgebraicDecisionTree<Key> expected_probabilities(discrete_keys, prob_leaves);
EXPECT(assert_equal(expectedFactor, probabilities, 1e-7));
// regression
EXPECT(assert_equal(expected_probabilities, probabilities, 1e-7));
// regression test for discretePosterior
const DecisionTreeFactor normalized(
s.modes, std::vector{0.14341014, 0.23872714, 0.23187421, 0.38598852});
DiscreteConditional expectedPosterior(2, normalized);
auto posterior = graph.discretePosterior(delta.continuous());
EXPECT(assert_equal(expectedPosterior, posterior, 1e-7));
}
/* ****************************************************************************/
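As a quick sanity check on the regression values above (not part of the commit), the discretePosterior leaves are simply the probPrime leaves divided by their sum:

    0.36793249 + 0.61247742 + 0.59489556 + 0.99029064 ≈ 2.5655961
    0.36793249 / 2.5655961 ≈ 0.14341014,  0.61247742 / 2.5655961 ≈ 0.23872714, and so on.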