Merge pull request #1857 from borglab/feature/posteriors
commit
b89e9c9a24
|
@ -70,6 +70,7 @@ namespace gtsam {
|
|||
return a / b;
|
||||
}
|
||||
static inline double id(const double& x) { return x; }
|
||||
static inline double negate(const double& x) { return -x; }
|
||||
};
|
||||
|
||||
AlgebraicDecisionTree(double leaf = 1.0) : Base(leaf) {}
|
||||
|
@ -186,6 +187,16 @@ namespace gtsam {
|
|||
return this->apply(g, &Ring::add);
|
||||
}
|
||||
|
||||
/** negation */
|
||||
AlgebraicDecisionTree operator-() const {
|
||||
return this->apply(&Ring::negate);
|
||||
}
|
||||
|
||||
/** subtract */
|
||||
AlgebraicDecisionTree operator-(const AlgebraicDecisionTree& g) const {
|
||||
return *this + (-g);
|
||||
}
|
||||
|
||||
/** product */
|
||||
AlgebraicDecisionTree operator*(const AlgebraicDecisionTree& g) const {
|
||||
return this->apply(g, &Ring::mul);
|
||||
|
|
|
@ -131,7 +131,7 @@ namespace gtsam {
|
|||
|
||||
/// Calculate probability for given values `x`,
|
||||
/// is just look up in AlgebraicDecisionTree.
|
||||
double evaluate(const DiscreteValues& values) const {
|
||||
double evaluate(const Assignment<Key>& values) const {
|
||||
return ADT::operator()(values);
|
||||
}
|
||||
|
||||
|
@ -155,7 +155,7 @@ namespace gtsam {
|
|||
return apply(f, safe_div);
|
||||
}
|
||||
|
||||
/// Convert into a decisiontree
|
||||
/// Convert into a decision tree
|
||||
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
|
||||
|
||||
/// Create new factor by summing all values with the same separator values
|
||||
|
|
|
@ -10,10 +10,10 @@
|
|||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/*
|
||||
* @file testDecisionTree.cpp
|
||||
* @brief Develop DecisionTree
|
||||
* @author Frank Dellaert
|
||||
* @date Mar 6, 2011
|
||||
* @file testAlgebraicDecisionTree.cpp
|
||||
* @brief Unit tests for Algebraic decision tree
|
||||
* @author Frank Dellaert
|
||||
* @date Mar 6, 2011
|
||||
*/
|
||||
|
||||
#include <gtsam/base/Testable.h>
|
||||
|
@ -46,23 +46,35 @@ void dot(const T& f, const string& filename) {
|
|||
#endif
|
||||
}
|
||||
|
||||
/** I can't get this to work !
|
||||
class Mul: std::function<double(const double&, const double&)> {
|
||||
inline double operator()(const double& a, const double& b) {
|
||||
return a * b;
|
||||
}
|
||||
};
|
||||
/* ************************************************************************** */
|
||||
// Test arithmetic:
|
||||
TEST(ADT, arithmetic) {
|
||||
DiscreteKey A(0, 2), B(1, 2);
|
||||
ADT zero{0}, one{1};
|
||||
ADT a(A, 1, 2);
|
||||
ADT b(B, 3, 4);
|
||||
|
||||
// If second argument of binary op is Leaf
|
||||
template<typename L>
|
||||
typename DecisionTree<L, double>::Node::Ptr DecisionTree<L,
|
||||
double>::Choice::apply_fC_op_gL( Cache& cache, const Leaf& gL, Mul op) const {
|
||||
Ptr h(new Choice(label(), cardinality()));
|
||||
for(const NodePtr& branch: branches_)
|
||||
h->push_back(branch->apply_f_op_g(cache, gL, op));
|
||||
return Unique(cache, h);
|
||||
}
|
||||
*/
|
||||
// Addition
|
||||
CHECK(assert_equal(a, zero + a));
|
||||
|
||||
// Negate and subtraction
|
||||
CHECK(assert_equal(-a, zero - a));
|
||||
CHECK(assert_equal({zero}, a - a));
|
||||
CHECK(assert_equal(a + b, b + a));
|
||||
CHECK(assert_equal({A, 3, 4}, a + 2));
|
||||
CHECK(assert_equal({B, 1, 2}, b - 2));
|
||||
|
||||
// Multiplication
|
||||
CHECK(assert_equal(zero, zero * a));
|
||||
CHECK(assert_equal(zero, a * zero));
|
||||
CHECK(assert_equal(a, one * a));
|
||||
CHECK(assert_equal(a, a * one));
|
||||
CHECK(assert_equal(a * b, b * a));
|
||||
|
||||
// division
|
||||
// CHECK(assert_equal(a, (a * b) / b)); // not true because no pruning
|
||||
CHECK(assert_equal(b, (a * b) / a));
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
// instrumented operators
|
||||
|
|
|
@ -17,10 +17,13 @@
|
|||
*/
|
||||
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
// In Wrappers we have no access to this so have a default ready
|
||||
static std::mt19937_64 kRandomNumberGenerator(42);
|
||||
|
||||
|
@ -38,135 +41,26 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/**
|
||||
* @brief Helper function to get the pruner functional.
|
||||
*
|
||||
* @param prunedDiscreteProbs The prob. decision tree of only discrete keys.
|
||||
* @param conditional Conditional to prune. Used to get full assignment.
|
||||
* @return std::function<double(const Assignment<Key> &, double)>
|
||||
*/
|
||||
std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
||||
const DecisionTreeFactor &prunedDiscreteProbs,
|
||||
const HybridConditional &conditional) {
|
||||
// Get the discrete keys as sets for the decision tree
|
||||
// and the hybrid Gaussian conditional.
|
||||
std::set<DiscreteKey> discreteProbsKeySet =
|
||||
DiscreteKeysAsSet(prunedDiscreteProbs.discreteKeys());
|
||||
std::set<DiscreteKey> conditionalKeySet =
|
||||
DiscreteKeysAsSet(conditional.discreteKeys());
|
||||
// The implementation is: build the entire joint into one factor and then prune.
|
||||
// TODO(Frank): This can be quite expensive *unless* the factors have already
|
||||
// been pruned before. Another, possibly faster approach is branch and bound
|
||||
// search to find the K-best leaves and then create a single pruned conditional.
|
||||
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
|
||||
// Collect all the discrete conditionals. Could be small if already pruned.
|
||||
const DiscreteBayesNet marginal = discreteMarginal();
|
||||
|
||||
auto pruner = [prunedDiscreteProbs, discreteProbsKeySet, conditionalKeySet](
|
||||
const Assignment<Key> &choices,
|
||||
double probability) -> double {
|
||||
// This corresponds to 0 probability
|
||||
double pruned_prob = 0.0;
|
||||
|
||||
// typecast so we can use this to get probability value
|
||||
DiscreteValues values(choices);
|
||||
// Case where the hybrid Gaussian conditional has the same
|
||||
// discrete keys as the decision tree.
|
||||
if (conditionalKeySet == discreteProbsKeySet) {
|
||||
if (prunedDiscreteProbs(values) == 0) {
|
||||
return pruned_prob;
|
||||
} else {
|
||||
return probability;
|
||||
}
|
||||
} else {
|
||||
// Due to branch merging (aka pruning) in DecisionTree, it is possible we
|
||||
// get a `values` which doesn't have the full set of keys.
|
||||
std::set<Key> valuesKeys;
|
||||
for (auto kvp : values) {
|
||||
valuesKeys.insert(kvp.first);
|
||||
}
|
||||
std::set<Key> conditionalKeys;
|
||||
for (auto kvp : conditionalKeySet) {
|
||||
conditionalKeys.insert(kvp.first);
|
||||
}
|
||||
// If true, then values is missing some keys
|
||||
if (conditionalKeys != valuesKeys) {
|
||||
// Get the keys present in conditionalKeys but not in valuesKeys
|
||||
std::vector<Key> missing_keys;
|
||||
std::set_difference(conditionalKeys.begin(), conditionalKeys.end(),
|
||||
valuesKeys.begin(), valuesKeys.end(),
|
||||
std::back_inserter(missing_keys));
|
||||
// Insert missing keys with a default assignment.
|
||||
for (auto missing_key : missing_keys) {
|
||||
values[missing_key] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Now we generate the full assignment by enumerating
|
||||
// over all keys in the prunedDiscreteProbs.
|
||||
// First we find the differing keys
|
||||
std::vector<DiscreteKey> set_diff;
|
||||
std::set_difference(discreteProbsKeySet.begin(),
|
||||
discreteProbsKeySet.end(), conditionalKeySet.begin(),
|
||||
conditionalKeySet.end(),
|
||||
std::back_inserter(set_diff));
|
||||
|
||||
// Now enumerate over all assignments of the differing keys
|
||||
const std::vector<DiscreteValues> assignments =
|
||||
DiscreteValues::CartesianProduct(set_diff);
|
||||
for (const DiscreteValues &assignment : assignments) {
|
||||
DiscreteValues augmented_values(values);
|
||||
augmented_values.insert(assignment);
|
||||
|
||||
// If any one of the sub-branches are non-zero,
|
||||
// we need this probability.
|
||||
if (prunedDiscreteProbs(augmented_values) > 0.0) {
|
||||
return probability;
|
||||
}
|
||||
}
|
||||
// If we are here, it means that all the sub-branches are 0,
|
||||
// so we prune.
|
||||
return pruned_prob;
|
||||
}
|
||||
};
|
||||
return pruner;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
|
||||
size_t maxNrLeaves) {
|
||||
// Get the joint distribution of only the discrete keys
|
||||
// The joint discrete probability.
|
||||
DiscreteConditional discreteProbs;
|
||||
|
||||
std::vector<size_t> discrete_factor_idxs;
|
||||
// Record frontal keys so we can maintain ordering
|
||||
Ordering discrete_frontals;
|
||||
|
||||
for (size_t i = 0; i < this->size(); i++) {
|
||||
auto conditional = this->at(i);
|
||||
if (conditional->isDiscrete()) {
|
||||
discreteProbs = discreteProbs * (*conditional->asDiscrete());
|
||||
|
||||
Ordering conditional_keys(conditional->frontals());
|
||||
discrete_frontals += conditional_keys;
|
||||
discrete_factor_idxs.push_back(i);
|
||||
}
|
||||
// Multiply into one big conditional. NOTE: possibly quite expensive.
|
||||
DiscreteConditional joint;
|
||||
for (auto &&conditional : marginal) {
|
||||
joint = joint * (*conditional);
|
||||
}
|
||||
|
||||
const DecisionTreeFactor prunedDiscreteProbs =
|
||||
discreteProbs.prune(maxNrLeaves);
|
||||
// Prune the joint. NOTE: again, possibly quite expensive.
|
||||
const DecisionTreeFactor pruned = joint.prune(maxNrLeaves);
|
||||
|
||||
// Eliminate joint probability back into conditionals
|
||||
DiscreteFactorGraph dfg{prunedDiscreteProbs};
|
||||
DiscreteBayesNet::shared_ptr dbn = dfg.eliminateSequential(discrete_frontals);
|
||||
|
||||
// Assign pruned discrete conditionals back at the correct indices.
|
||||
for (size_t i = 0; i < discrete_factor_idxs.size(); i++) {
|
||||
size_t idx = discrete_factor_idxs.at(i);
|
||||
this->at(idx) = std::make_shared<HybridConditional>(dbn->at(i));
|
||||
}
|
||||
|
||||
return prunedDiscreteProbs;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
|
||||
DecisionTreeFactor prunedDiscreteProbs =
|
||||
this->pruneDiscreteConditionals(maxNrLeaves);
|
||||
// Create a the result starting with the pruned joint.
|
||||
HybridBayesNet result;
|
||||
result.emplace_shared<DiscreteConditional>(pruned.size(), pruned);
|
||||
|
||||
/* To prune, we visitWith every leaf in the HybridGaussianConditional.
|
||||
* For each leaf, using the assignment we can check the discrete decision tree
|
||||
|
@ -175,28 +69,34 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
|
|||
* We can later check the HybridGaussianConditional for just nullptrs.
|
||||
*/
|
||||
|
||||
HybridBayesNet prunedBayesNetFragment;
|
||||
|
||||
// Go through all the conditionals in the
|
||||
// Bayes Net and prune them as per prunedDiscreteProbs.
|
||||
// Go through all the Gaussian conditionals in the Bayes Net and prune them as
|
||||
// per pruned Discrete joint.
|
||||
for (auto &&conditional : *this) {
|
||||
if (auto gm = conditional->asHybrid()) {
|
||||
if (auto hgc = conditional->asHybrid()) {
|
||||
// Make a copy of the hybrid Gaussian conditional and prune it!
|
||||
auto prunedHybridGaussianConditional =
|
||||
std::make_shared<HybridGaussianConditional>(*gm);
|
||||
prunedHybridGaussianConditional->prune(
|
||||
prunedDiscreteProbs); // imperative :-(
|
||||
auto prunedHybridGaussianConditional = hgc->prune(pruned);
|
||||
|
||||
// Type-erase and add to the pruned Bayes Net fragment.
|
||||
prunedBayesNetFragment.push_back(prunedHybridGaussianConditional);
|
||||
|
||||
} else {
|
||||
result.push_back(prunedHybridGaussianConditional);
|
||||
} else if (auto gc = conditional->asGaussian()) {
|
||||
// Add the non-HybridGaussianConditional conditional
|
||||
prunedBayesNetFragment.push_back(conditional);
|
||||
result.push_back(gc);
|
||||
}
|
||||
// We ignore DiscreteConditional as they are already pruned and added.
|
||||
}
|
||||
|
||||
return prunedBayesNetFragment;
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteBayesNet HybridBayesNet::discreteMarginal() const {
|
||||
DiscreteBayesNet result;
|
||||
for (auto &&conditional : *this) {
|
||||
if (auto dc = conditional->asDiscrete()) {
|
||||
result.push_back(dc);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -291,66 +191,19 @@ AlgebraicDecisionTree<Key> HybridBayesNet::errorTree(
|
|||
|
||||
// Iterate over each conditional.
|
||||
for (auto &&conditional : *this) {
|
||||
if (auto gm = conditional->asHybrid()) {
|
||||
// If conditional is hybrid, compute error for all assignments.
|
||||
result = result + gm->errorTree(continuousValues);
|
||||
|
||||
} else if (auto gc = conditional->asGaussian()) {
|
||||
// If continuous, get the error and add it to the result
|
||||
double error = gc->error(continuousValues);
|
||||
// Add the computed error to every leaf of the result tree.
|
||||
result = result.apply(
|
||||
[error](double leaf_value) { return leaf_value + error; });
|
||||
|
||||
} else if (auto dc = conditional->asDiscrete()) {
|
||||
// If discrete, add the discrete error in the right branch
|
||||
result = result.apply(
|
||||
[dc](const Assignment<Key> &assignment, double leaf_value) {
|
||||
return leaf_value + dc->error(DiscreteValues(assignment));
|
||||
});
|
||||
}
|
||||
result = result + conditional->errorTree(continuousValues);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(
|
||||
AlgebraicDecisionTree<Key> HybridBayesNet::discretePosterior(
|
||||
const VectorValues &continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> result(0.0);
|
||||
|
||||
// Iterate over each conditional.
|
||||
for (auto &&conditional : *this) {
|
||||
if (auto gm = conditional->asHybrid()) {
|
||||
// If conditional is hybrid, select based on assignment and compute
|
||||
// logProbability.
|
||||
result = result + gm->logProbability(continuousValues);
|
||||
} else if (auto gc = conditional->asGaussian()) {
|
||||
// If continuous, get the (double) logProbability and add it to the
|
||||
// result
|
||||
double logProbability = gc->logProbability(continuousValues);
|
||||
// Add the computed logProbability to every leaf of the logProbability
|
||||
// tree.
|
||||
result = result.apply([logProbability](double leaf_value) {
|
||||
return leaf_value + logProbability;
|
||||
});
|
||||
} else if (auto dc = conditional->asDiscrete()) {
|
||||
// If discrete, add the discrete logProbability in the right branch
|
||||
result = result.apply(
|
||||
[dc](const Assignment<Key> &assignment, double leaf_value) {
|
||||
return leaf_value + dc->logProbability(DiscreteValues(assignment));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
AlgebraicDecisionTree<Key> HybridBayesNet::evaluate(
|
||||
const VectorValues &continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> tree = this->logProbability(continuousValues);
|
||||
return tree.apply([](double log) { return exp(log); });
|
||||
AlgebraicDecisionTree<Key> errors = this->errorTree(continuousValues);
|
||||
AlgebraicDecisionTree<Key> p =
|
||||
errors.apply([](double error) { return exp(-error); });
|
||||
return p / p.sum();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/global_includes.h>
|
||||
#include <gtsam/hybrid/HybridConditional.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
|
@ -77,16 +78,11 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
}
|
||||
|
||||
/**
|
||||
* Add a conditional using a shared_ptr, using implicit conversion to
|
||||
* a HybridConditional.
|
||||
*
|
||||
* This is useful when you create a conditional shared pointer as you need it
|
||||
* somewhere else.
|
||||
*
|
||||
* Move a HybridConditional into a shared pointer and add.
|
||||
|
||||
* Example:
|
||||
* auto shared_ptr_to_a_conditional =
|
||||
* std::make_shared<HybridGaussianConditional>(...);
|
||||
* hbn.push_back(shared_ptr_to_a_conditional);
|
||||
* HybridGaussianConditional conditional(...);
|
||||
* hbn.push_back(conditional); // loses the original conditional
|
||||
*/
|
||||
void push_back(HybridConditional &&conditional) {
|
||||
factors_.push_back(
|
||||
|
@ -124,13 +120,21 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
}
|
||||
|
||||
/**
|
||||
* @brief Get the Gaussian Bayes Net which corresponds to a specific discrete
|
||||
* value assignment.
|
||||
* @brief Get the discrete Bayes Net P(M). As the hybrid Bayes net defines
|
||||
* P(X,M) = P(X|M) P(M), this method returns the marginal distribution on the
|
||||
* discrete variables.
|
||||
*
|
||||
* @note Any pure discrete factors are ignored.
|
||||
* @return discrete marginal as a DiscreteBayesNet.
|
||||
*/
|
||||
DiscreteBayesNet discreteMarginal() const;
|
||||
|
||||
/**
|
||||
* @brief Get the Gaussian Bayes net P(X|M=m) corresponding to a specific
|
||||
* assignment m for the discrete variables M. As the hybrid Bayes net defines
|
||||
* P(X,M) = P(X|M) P(M), this method returns the **posterior** p(X|M=m).
|
||||
*
|
||||
* @param assignment The discrete value assignment for the discrete keys.
|
||||
* @return GaussianBayesNet
|
||||
* @return Gaussian posterior P(X|M=m) as a GaussianBayesNet.
|
||||
*/
|
||||
GaussianBayesNet choose(const DiscreteValues &assignment) const;
|
||||
|
||||
|
@ -201,18 +205,13 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
*/
|
||||
HybridValues sample() const;
|
||||
|
||||
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
|
||||
HybridBayesNet prune(size_t maxNrLeaves);
|
||||
|
||||
/**
|
||||
* @brief Compute conditional error for each discrete assignment,
|
||||
* and return as a tree.
|
||||
* @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves.
|
||||
*
|
||||
* @param continuousValues Continuous values at which to compute the error.
|
||||
* @return AlgebraicDecisionTree<Key>
|
||||
* @param maxNrLeaves Continuous values at which to compute the error.
|
||||
* @return A pruned HybridBayesNet
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> errorTree(
|
||||
const VectorValues &continuousValues) const;
|
||||
HybridBayesNet prune(size_t maxNrLeaves) const;
|
||||
|
||||
/**
|
||||
* @brief Error method using HybridValues which returns specific error for
|
||||
|
@ -221,29 +220,33 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
using Base::error;
|
||||
|
||||
/**
|
||||
* @brief Compute log probability for each discrete assignment,
|
||||
* and return as a tree.
|
||||
* @brief Compute the negative log posterior log P'(M|x) of all assignments up
|
||||
* to a constant, returning the result as an algebraic decision tree.
|
||||
*
|
||||
* @param continuousValues Continuous values at which
|
||||
* to compute the log probability.
|
||||
* @note The joint P(X,M) is p(X|M) P(M)
|
||||
* Then the posterior on M given X=x is is P(M|x) = p(x|M) P(M) / p(x).
|
||||
* Ideally we want log P(M|x) = log p(x|M) + log P(M) - log p(x), but
|
||||
* unfortunately log p(x) is expensive, so we compute the log of the
|
||||
* unnormalized posterior log P'(M|x) = log p(x|M) + log P(M)
|
||||
*
|
||||
* @param continuousValues Continuous values x at which to compute log P'(M|x)
|
||||
* @return AlgebraicDecisionTree<Key>
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> logProbability(
|
||||
AlgebraicDecisionTree<Key> errorTree(
|
||||
const VectorValues &continuousValues) const;
|
||||
|
||||
using BayesNet::logProbability; // expose HybridValues version
|
||||
|
||||
/**
|
||||
* @brief Compute unnormalized probability q(μ|M),
|
||||
* for each discrete assignment, and return as a tree.
|
||||
* q(μ|M) is the unnormalized probability at the MLE point μ,
|
||||
* conditioned on the discrete variables.
|
||||
* @brief Compute normalized posterior P(M|X=x) and return as a tree.
|
||||
*
|
||||
* @param continuousValues Continuous values at which to compute the
|
||||
* probability.
|
||||
* @note Not a DiscreteConditional as the cardinalities of the DiscreteKeys,
|
||||
* which we would need, are hard to recover.
|
||||
*
|
||||
* @param continuousValues Continuous values x to condition P(M|X=x) on.
|
||||
* @return AlgebraicDecisionTree<Key>
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> evaluate(
|
||||
AlgebraicDecisionTree<Key> discretePosterior(
|
||||
const VectorValues &continuousValues) const;
|
||||
|
||||
/**
|
||||
|
@ -255,13 +258,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
/// @}
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief Prune all the discrete conditionals.
|
||||
*
|
||||
* @param maxNrLeaves
|
||||
*/
|
||||
DecisionTreeFactor pruneDiscreteConditionals(size_t maxNrLeaves);
|
||||
|
||||
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
|
|
|
@ -26,6 +26,10 @@
|
|||
#include <gtsam/inference/BayesTreeCliqueBase-inst.h>
|
||||
#include <gtsam/linear/GaussianJunctionTree.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "gtsam/hybrid/HybridConditional.h"
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
// Instantiate base class
|
||||
|
@ -207,7 +211,9 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
|||
if (conditional->isHybrid()) {
|
||||
auto hybridGaussianCond = conditional->asHybrid();
|
||||
|
||||
hybridGaussianCond->prune(parentData.prunedDiscreteProbs);
|
||||
// Imperative
|
||||
clique->conditional() = std::make_shared<HybridConditional>(
|
||||
hybridGaussianCond->prune(parentData.prunedDiscreteProbs));
|
||||
}
|
||||
return parentData;
|
||||
}
|
||||
|
|
|
@ -64,7 +64,6 @@ void HybridConditional::print(const std::string &s,
|
|||
|
||||
if (inner_) {
|
||||
inner_->print("", formatter);
|
||||
|
||||
} else {
|
||||
if (isContinuous()) std::cout << "Continuous ";
|
||||
if (isDiscrete()) std::cout << "Discrete ";
|
||||
|
@ -100,79 +99,68 @@ bool HybridConditional::equals(const HybridFactor &other, double tol) const {
|
|||
if (auto gm = asHybrid()) {
|
||||
auto other = e->asHybrid();
|
||||
return other != nullptr && gm->equals(*other, tol);
|
||||
}
|
||||
if (auto gc = asGaussian()) {
|
||||
} else if (auto gc = asGaussian()) {
|
||||
auto other = e->asGaussian();
|
||||
return other != nullptr && gc->equals(*other, tol);
|
||||
}
|
||||
if (auto dc = asDiscrete()) {
|
||||
} else if (auto dc = asDiscrete()) {
|
||||
auto other = e->asDiscrete();
|
||||
return other != nullptr && dc->equals(*other, tol);
|
||||
}
|
||||
|
||||
return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
|
||||
: !(e->inner_);
|
||||
} else
|
||||
return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
|
||||
: !(e->inner_);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
double HybridConditional::error(const HybridValues &values) const {
|
||||
if (auto gc = asGaussian()) {
|
||||
return gc->error(values.continuous());
|
||||
}
|
||||
if (auto gm = asHybrid()) {
|
||||
} else if (auto gm = asHybrid()) {
|
||||
return gm->error(values);
|
||||
}
|
||||
if (auto dc = asDiscrete()) {
|
||||
} else if (auto dc = asDiscrete()) {
|
||||
return dc->error(values.discrete());
|
||||
}
|
||||
throw std::runtime_error(
|
||||
"HybridConditional::error: conditional type not handled");
|
||||
} else
|
||||
throw std::runtime_error(
|
||||
"HybridConditional::error: conditional type not handled");
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
AlgebraicDecisionTree<Key> HybridConditional::errorTree(
|
||||
const VectorValues &values) const {
|
||||
if (auto gc = asGaussian()) {
|
||||
return AlgebraicDecisionTree<Key>(gc->error(values));
|
||||
}
|
||||
if (auto gm = asHybrid()) {
|
||||
return {gc->error(values)}; // NOTE: a "constant" tree
|
||||
} else if (auto gm = asHybrid()) {
|
||||
return gm->errorTree(values);
|
||||
}
|
||||
if (auto dc = asDiscrete()) {
|
||||
return AlgebraicDecisionTree<Key>(0.0);
|
||||
}
|
||||
throw std::runtime_error(
|
||||
"HybridConditional::error: conditional type not handled");
|
||||
} else if (auto dc = asDiscrete()) {
|
||||
return dc->errorTree();
|
||||
} else
|
||||
throw std::runtime_error(
|
||||
"HybridConditional::error: conditional type not handled");
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
double HybridConditional::logProbability(const HybridValues &values) const {
|
||||
if (auto gc = asGaussian()) {
|
||||
return gc->logProbability(values.continuous());
|
||||
}
|
||||
if (auto gm = asHybrid()) {
|
||||
} else if (auto gm = asHybrid()) {
|
||||
return gm->logProbability(values);
|
||||
}
|
||||
if (auto dc = asDiscrete()) {
|
||||
} else if (auto dc = asDiscrete()) {
|
||||
return dc->logProbability(values.discrete());
|
||||
}
|
||||
throw std::runtime_error(
|
||||
"HybridConditional::logProbability: conditional type not handled");
|
||||
} else
|
||||
throw std::runtime_error(
|
||||
"HybridConditional::logProbability: conditional type not handled");
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
double HybridConditional::negLogConstant() const {
|
||||
if (auto gc = asGaussian()) {
|
||||
return gc->negLogConstant();
|
||||
}
|
||||
if (auto gm = asHybrid()) {
|
||||
return gm->negLogConstant(); // 0.0!
|
||||
}
|
||||
if (auto dc = asDiscrete()) {
|
||||
} else if (auto gm = asHybrid()) {
|
||||
return gm->negLogConstant();
|
||||
} else if (auto dc = asDiscrete()) {
|
||||
return dc->negLogConstant(); // 0.0!
|
||||
}
|
||||
throw std::runtime_error(
|
||||
"HybridConditional::negLogConstant: conditional type not handled");
|
||||
} else
|
||||
throw std::runtime_error(
|
||||
"HybridConditional::negLogConstant: conditional type not handled");
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
|
|
|
@ -288,85 +288,32 @@ std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
|
|||
return s;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
std::function<GaussianConditional::shared_ptr(
|
||||
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
||||
HybridGaussianConditional::prunerFunc(const DecisionTreeFactor &discreteProbs) {
|
||||
// Get the discrete keys as sets for the decision tree
|
||||
// and the hybrid gaussian conditional.
|
||||
auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
|
||||
auto hybridGaussianCondKeySet = DiscreteKeysAsSet(this->discreteKeys());
|
||||
/* *******************************************************************************/
|
||||
HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
|
||||
const DecisionTreeFactor &discreteProbs) const {
|
||||
// Find keys in discreteProbs.keys() but not in this->keys():
|
||||
std::set<Key> mine(this->keys().begin(), this->keys().end());
|
||||
std::set<Key> theirs(discreteProbs.keys().begin(),
|
||||
discreteProbs.keys().end());
|
||||
std::vector<Key> diff;
|
||||
std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(),
|
||||
std::back_inserter(diff));
|
||||
|
||||
auto pruner = [discreteProbs, discreteProbsKeySet, hybridGaussianCondKeySet](
|
||||
const Assignment<Key> &choices,
|
||||
// Find maximum probability value for every combination of our keys.
|
||||
Ordering keys(diff);
|
||||
auto max = discreteProbs.max(keys);
|
||||
|
||||
// Check the max value for every combination of our keys.
|
||||
// If the max value is 0.0, we can prune the corresponding conditional.
|
||||
auto pruner = [&](const Assignment<Key> &choices,
|
||||
const GaussianConditional::shared_ptr &conditional)
|
||||
-> GaussianConditional::shared_ptr {
|
||||
// typecast so we can use this to get probability value
|
||||
const DiscreteValues values(choices);
|
||||
|
||||
// Case where the hybrid gaussian conditional has the same
|
||||
// discrete keys as the decision tree.
|
||||
if (hybridGaussianCondKeySet == discreteProbsKeySet) {
|
||||
if (discreteProbs(values) == 0.0) {
|
||||
// empty aka null pointer
|
||||
std::shared_ptr<GaussianConditional> null;
|
||||
return null;
|
||||
} else {
|
||||
return conditional;
|
||||
}
|
||||
} else {
|
||||
std::vector<DiscreteKey> set_diff;
|
||||
std::set_difference(
|
||||
discreteProbsKeySet.begin(), discreteProbsKeySet.end(),
|
||||
hybridGaussianCondKeySet.begin(), hybridGaussianCondKeySet.end(),
|
||||
std::back_inserter(set_diff));
|
||||
|
||||
const std::vector<DiscreteValues> assignments =
|
||||
DiscreteValues::CartesianProduct(set_diff);
|
||||
for (const DiscreteValues &assignment : assignments) {
|
||||
DiscreteValues augmented_values(values);
|
||||
augmented_values.insert(assignment);
|
||||
|
||||
// If any one of the sub-branches are non-zero,
|
||||
// we need this conditional.
|
||||
if (discreteProbs(augmented_values) > 0.0) {
|
||||
return conditional;
|
||||
}
|
||||
}
|
||||
// If we are here, it means that all the sub-branches are 0,
|
||||
// so we prune.
|
||||
return nullptr;
|
||||
}
|
||||
return (max->evaluate(choices) == 0.0) ? nullptr : conditional;
|
||||
};
|
||||
return pruner;
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
void HybridGaussianConditional::prune(const DecisionTreeFactor &discreteProbs) {
|
||||
// Functional which loops over all assignments and create a set of
|
||||
// GaussianConditionals
|
||||
auto pruner = prunerFunc(discreteProbs);
|
||||
|
||||
auto pruned_conditionals = conditionals_.apply(pruner);
|
||||
conditionals_.root_ = pruned_conditionals.root_;
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
AlgebraicDecisionTree<Key> HybridGaussianConditional::logProbability(
|
||||
const VectorValues &continuousValues) const {
|
||||
// functor to calculate (double) logProbability value from
|
||||
// GaussianConditional.
|
||||
auto probFunc =
|
||||
[continuousValues](const GaussianConditional::shared_ptr &conditional) {
|
||||
if (conditional) {
|
||||
return conditional->logProbability(continuousValues);
|
||||
} else {
|
||||
// Return arbitrarily small logProbability if conditional is null
|
||||
// Conditional is null if it is pruned out.
|
||||
return -1e20;
|
||||
}
|
||||
};
|
||||
return DecisionTree<Key, double>(conditionals_, probFunc);
|
||||
return std::make_shared<HybridGaussianConditional>(discreteKeys(),
|
||||
pruned_conditionals);
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
* @brief A hybrid conditional in the Conditional Linear Gaussian scheme
|
||||
* @author Fan Jiang
|
||||
* @author Varun Agrawal
|
||||
* @author Frank Dellaert
|
||||
* @date Mar 12, 2022
|
||||
*/
|
||||
|
||||
|
@ -194,16 +195,6 @@ class GTSAM_EXPORT HybridGaussianConditional
|
|||
/// Getter for the underlying Conditionals DecisionTree
|
||||
const Conditionals &conditionals() const;
|
||||
|
||||
/**
|
||||
* @brief Compute logProbability of the HybridGaussianConditional as a tree.
|
||||
*
|
||||
* @param continuousValues The continuous VectorValues.
|
||||
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
|
||||
* as the conditionals, and leaf values as the logProbability.
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> logProbability(
|
||||
const VectorValues &continuousValues) const;
|
||||
|
||||
/**
|
||||
* @brief Compute the logProbability of this hybrid Gaussian conditional.
|
||||
*
|
||||
|
@ -225,8 +216,10 @@ class GTSAM_EXPORT HybridGaussianConditional
|
|||
* `discreteProbs`.
|
||||
*
|
||||
* @param discreteProbs A pruned set of probabilities for the discrete keys.
|
||||
* @return Shared pointer to possibly a pruned HybridGaussianConditional
|
||||
*/
|
||||
void prune(const DecisionTreeFactor &discreteProbs);
|
||||
HybridGaussianConditional::shared_ptr prune(
|
||||
const DecisionTreeFactor &discreteProbs) const;
|
||||
|
||||
/// @}
|
||||
|
||||
|
@ -241,17 +234,6 @@ class GTSAM_EXPORT HybridGaussianConditional
|
|||
/// Convert to a DecisionTree of Gaussian factor graphs.
|
||||
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
|
||||
|
||||
/**
|
||||
* @brief Get the pruner function from discrete probabilities.
|
||||
*
|
||||
* @param discreteProbs The probabilities of only discrete keys.
|
||||
* @return std::function<GaussianConditional::shared_ptr(
|
||||
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
||||
*/
|
||||
std::function<GaussianConditional::shared_ptr(
|
||||
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
||||
prunerFunc(const DecisionTreeFactor &prunedProbabilities);
|
||||
|
||||
/// Check whether `given` has values for all frontal keys.
|
||||
bool allFrontalsGiven(const VectorValues &given) const;
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/hybrid/HybridConditional.h>
|
||||
#include <gtsam/hybrid/HybridEliminationTree.h>
|
||||
#include <gtsam/hybrid/HybridFactor.h>
|
||||
|
@ -42,7 +43,6 @@
|
|||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <iostream>
|
||||
#include <iterator>
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <utility>
|
||||
|
@ -342,14 +342,20 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
|
|||
return std::make_shared<HybridGaussianFactor>(discreteSeparator, newFactors);
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
/// Get the discrete keys from the HybridGaussianFactorGraph as DiscreteKeys.
|
||||
static auto GetDiscreteKeys =
|
||||
[](const HybridGaussianFactorGraph &hfg) -> DiscreteKeys {
|
||||
const std::set<DiscreteKey> discreteKeySet = hfg.discreteKeys();
|
||||
return {discreteKeySet.begin(), discreteKeySet.end()};
|
||||
};
|
||||
|
||||
/* *******************************************************************************/
|
||||
std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
|
||||
HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
|
||||
// Since we eliminate all continuous variables first,
|
||||
// the discrete separator will be *all* the discrete keys.
|
||||
const std::set<DiscreteKey> keysForDiscreteVariables = discreteKeys();
|
||||
DiscreteKeys discreteSeparator(keysForDiscreteVariables.begin(),
|
||||
keysForDiscreteVariables.end());
|
||||
DiscreteKeys discreteSeparator = GetDiscreteKeys(*this);
|
||||
|
||||
// Collect all the factors to create a set of Gaussian factor graphs in a
|
||||
// decision tree indexed by all discrete keys involved.
|
||||
|
@ -499,22 +505,22 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
|
|||
/* ************************************************************************ */
|
||||
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::errorTree(
|
||||
const VectorValues &continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> error_tree(0.0);
|
||||
AlgebraicDecisionTree<Key> result(0.0);
|
||||
// Iterate over each factor.
|
||||
for (auto &factor : factors_) {
|
||||
if (auto f = std::dynamic_pointer_cast<HybridFactor>(factor)) {
|
||||
// Check for HybridFactor, and call errorTree
|
||||
error_tree = error_tree + f->errorTree(continuousValues);
|
||||
} else if (auto f = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
|
||||
// Skip discrete factors
|
||||
continue;
|
||||
if (auto hf = std::dynamic_pointer_cast<HybridFactor>(factor)) {
|
||||
// Add errorTree for hybrid factors, includes HybridGaussianConditionals!
|
||||
result = result + hf->errorTree(continuousValues);
|
||||
} else if (auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
|
||||
// If discrete, just add its errorTree as well
|
||||
result = result + df->errorTree();
|
||||
} else {
|
||||
// Everything else is a continuous only factor
|
||||
HybridValues hv(continuousValues, DiscreteValues());
|
||||
error_tree = error_tree + AlgebraicDecisionTree<Key>(factor->error(hv));
|
||||
result = result + factor->error(hv); // NOTE: yes, you can add constants
|
||||
}
|
||||
}
|
||||
return error_tree;
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
|
@ -525,18 +531,18 @@ double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const {
|
|||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
|
||||
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::discretePosterior(
|
||||
const VectorValues &continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> error_tree = this->errorTree(continuousValues);
|
||||
AlgebraicDecisionTree<Key> prob_tree = error_tree.apply([](double error) {
|
||||
AlgebraicDecisionTree<Key> errors = this->errorTree(continuousValues);
|
||||
AlgebraicDecisionTree<Key> p = errors.apply([](double error) {
|
||||
// NOTE: The 0.5 term is handled by each factor
|
||||
return exp(-error);
|
||||
});
|
||||
return prob_tree;
|
||||
return p / p.sum();
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
GaussianFactorGraph HybridGaussianFactorGraph::operator()(
|
||||
GaussianFactorGraph HybridGaussianFactorGraph::choose(
|
||||
const DiscreteValues &assignment) const {
|
||||
GaussianFactorGraph gfg;
|
||||
for (auto &&f : *this) {
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/hybrid/HybridFactor.h>
|
||||
#include <gtsam/hybrid/HybridFactorGraph.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||
|
@ -187,17 +188,6 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
AlgebraicDecisionTree<Key> errorTree(
|
||||
const VectorValues& continuousValues) const;
|
||||
|
||||
/**
|
||||
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
|
||||
* for each discrete assignment, and return as a tree.
|
||||
*
|
||||
* @param continuousValues Continuous values at which to compute the
|
||||
* probability.
|
||||
* @return AlgebraicDecisionTree<Key>
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> probPrime(
|
||||
const VectorValues& continuousValues) const;
|
||||
|
||||
/**
|
||||
* @brief Compute the unnormalized posterior probability for a continuous
|
||||
* vector values given a specific assignment.
|
||||
|
@ -206,6 +196,19 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
*/
|
||||
double probPrime(const HybridValues& values) const;
|
||||
|
||||
/**
|
||||
* @brief Computer posterior P(M|X=x) when all continuous values X are given.
|
||||
* This is efficient as this simply probPrime normalized.
|
||||
*
|
||||
* @note Not a DiscreteConditional as the cardinalities of the DiscreteKeys,
|
||||
* which we would need, are hard to recover.
|
||||
*
|
||||
* @param continuousValues Continuous values x to condition on.
|
||||
* @return DecisionTreeFactor
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> discretePosterior(
|
||||
const VectorValues& continuousValues) const;
|
||||
|
||||
/**
|
||||
* @brief Create a decision tree of factor graphs out of this hybrid factor
|
||||
* graph.
|
||||
|
@ -227,8 +230,23 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
eliminate(const Ordering& keys) const;
|
||||
/// @}
|
||||
|
||||
/// Get the GaussianFactorGraph at a given discrete assignment.
|
||||
GaussianFactorGraph operator()(const DiscreteValues& assignment) const;
|
||||
/**
|
||||
@brief Get the GaussianFactorGraph at a given discrete assignment. Note this
|
||||
* corresponds to the Gaussian posterior p(X|M=m, Z=z) of the continuous
|
||||
* variables X given the discrete assignment M=m and whatever measurements z
|
||||
* where assumed in the creation of the factor Graph.
|
||||
*
|
||||
* @note Be careful, as any factors not Gaussian are ignored.
|
||||
*
|
||||
* @param assignment The discrete value assignment for the discrete keys.
|
||||
* @return Gaussian factors as a GaussianFactorGraph
|
||||
*/
|
||||
GaussianFactorGraph choose(const DiscreteValues& assignment) const;
|
||||
|
||||
/// Syntactic sugar for choose
|
||||
GaussianFactorGraph operator()(const DiscreteValues& assignment) const {
|
||||
return choose(assignment);
|
||||
}
|
||||
};
|
||||
|
||||
// traits
|
||||
|
|
|
@ -72,21 +72,17 @@ void HybridSmoother::update(HybridGaussianFactorGraph graph,
|
|||
addConditionals(graph, hybridBayesNet_, ordering);
|
||||
|
||||
// Eliminate.
|
||||
HybridBayesNet::shared_ptr bayesNetFragment =
|
||||
graph.eliminateSequential(ordering);
|
||||
HybridBayesNet bayesNetFragment = *graph.eliminateSequential(ordering);
|
||||
|
||||
/// Prune
|
||||
if (maxNrLeaves) {
|
||||
// `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in
|
||||
// all the conditionals with the same keys in bayesNetFragment.
|
||||
HybridBayesNet prunedBayesNetFragment =
|
||||
bayesNetFragment->prune(*maxNrLeaves);
|
||||
// Set the bayes net fragment to the pruned version
|
||||
bayesNetFragment = std::make_shared<HybridBayesNet>(prunedBayesNetFragment);
|
||||
bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves);
|
||||
}
|
||||
|
||||
// Add the partial bayes net to the posterior bayes net.
|
||||
hybridBayesNet_.add(*bayesNetFragment);
|
||||
hybridBayesNet_.add(bayesNetFragment);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -39,7 +39,7 @@ class GTSAM_EXPORT HybridSmoother {
|
|||
* discrete factor on all discrete keys, plus all discrete factors in the
|
||||
* original graph.
|
||||
*
|
||||
* \note If maxComponents is given, we look at the discrete factor resulting
|
||||
* \note If maxNrLeaves is given, we look at the discrete factor resulting
|
||||
* from this elimination, and prune it and the Gaussian components
|
||||
* corresponding to the pruned choices.
|
||||
*
|
||||
|
|
|
@ -46,29 +46,29 @@ using symbol_shorthand::X;
|
|||
* @brief Create a switching system chain. A switching system is a continuous
|
||||
* system which depends on a discrete mode at each time step of the chain.
|
||||
*
|
||||
* @param n The number of chain elements.
|
||||
* @param K The number of chain elements.
|
||||
* @param x The functional to help specify the continuous key.
|
||||
* @param m The functional to help specify the discrete key.
|
||||
* @return HybridGaussianFactorGraph::shared_ptr
|
||||
*/
|
||||
inline HybridGaussianFactorGraph::shared_ptr makeSwitchingChain(
|
||||
size_t n, std::function<Key(int)> x = X, std::function<Key(int)> m = M) {
|
||||
size_t K, std::function<Key(int)> x = X, std::function<Key(int)> m = M) {
|
||||
HybridGaussianFactorGraph hfg;
|
||||
|
||||
hfg.add(JacobianFactor(x(1), I_3x3, Z_3x1));
|
||||
|
||||
// x(1) to x(n+1)
|
||||
for (size_t t = 1; t < n; t++) {
|
||||
DiscreteKeys dKeys{{m(t), 2}};
|
||||
for (size_t k = 1; k < K; k++) {
|
||||
DiscreteKeys dKeys{{m(k), 2}};
|
||||
std::vector<GaussianFactor::shared_ptr> components;
|
||||
components.emplace_back(
|
||||
new JacobianFactor(x(t), I_3x3, x(t + 1), I_3x3, Z_3x1));
|
||||
new JacobianFactor(x(k), I_3x3, x(k + 1), I_3x3, Z_3x1));
|
||||
components.emplace_back(
|
||||
new JacobianFactor(x(t), I_3x3, x(t + 1), I_3x3, Vector3::Ones()));
|
||||
hfg.add(HybridGaussianFactor({m(t), 2}, components));
|
||||
new JacobianFactor(x(k), I_3x3, x(k + 1), I_3x3, Vector3::Ones()));
|
||||
hfg.add(HybridGaussianFactor({m(k), 2}, components));
|
||||
|
||||
if (t > 1) {
|
||||
hfg.add(DecisionTreeFactor({{m(t - 1), 2}, {m(t), 2}}, "0 1 1 3"));
|
||||
if (k > 1) {
|
||||
hfg.add(DecisionTreeFactor({{m(k - 1), 2}, {m(k), 2}}, "0 1 1 3"));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -118,7 +118,7 @@ inline std::pair<KeyVector, std::vector<int>> makeBinaryOrdering(
|
|||
using MotionModel = BetweenFactor<double>;
|
||||
|
||||
// Test fixture with switching network.
|
||||
/// ϕ(X(0)) .. ϕ(X(k),X(k+1)) .. ϕ(X(k);z_k) .. ϕ(M(0)) .. ϕ(M(k),M(k+1))
|
||||
/// ϕ(X(0)) .. ϕ(X(k),X(k+1)) .. ϕ(X(k);z_k) .. ϕ(M(0)) .. ϕ(M(K-3),M(K-2))
|
||||
struct Switching {
|
||||
size_t K;
|
||||
DiscreteKeys modes;
|
||||
|
@ -195,7 +195,7 @@ struct Switching {
|
|||
}
|
||||
|
||||
/**
|
||||
* @brief Add "mode chain" to HybridNonlinearFactorGraph from M(0) to M(K-1).
|
||||
* @brief Add "mode chain" to HybridNonlinearFactorGraph from M(0) to M(K-2).
|
||||
* E.g. if K=4, we want M0, M1 and M2.
|
||||
*
|
||||
* @param fg The factor graph to which the mode chain is added.
|
||||
|
|
|
@ -87,21 +87,29 @@ TEST(HybridBayesNet, EvaluatePureDiscrete) {
|
|||
EXPECT(assert_equal(one, bayesNet.sample(one, &rng)));
|
||||
EXPECT(assert_equal(zero, bayesNet.sample(zero, &rng)));
|
||||
|
||||
// prune
|
||||
EXPECT(assert_equal(bayesNet, bayesNet.prune(2)));
|
||||
EXPECT_LONGS_EQUAL(1, bayesNet.prune(1).at(0)->size());
|
||||
|
||||
// error
|
||||
EXPECT_DOUBLES_EQUAL(-log(0.4), bayesNet.error(zero), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(-log(0.6), bayesNet.error(one), 1e-9);
|
||||
|
||||
// errorTree
|
||||
AlgebraicDecisionTree<Key> expected(asiaKey, -log(0.4), -log(0.6));
|
||||
EXPECT(assert_equal(expected, bayesNet.errorTree({})));
|
||||
|
||||
// logProbability
|
||||
EXPECT_DOUBLES_EQUAL(log(0.4), bayesNet.logProbability(zero), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(log(0.6), bayesNet.logProbability(one), 1e-9);
|
||||
|
||||
// discretePosterior
|
||||
AlgebraicDecisionTree<Key> expectedPosterior(asiaKey, 0.4, 0.6);
|
||||
EXPECT(assert_equal(expectedPosterior, bayesNet.discretePosterior({})));
|
||||
|
||||
// toFactorGraph
|
||||
HybridGaussianFactorGraph expectedFG{pAsia}, fg = bayesNet.toFactorGraph({});
|
||||
EXPECT(assert_equal(expectedFG, fg));
|
||||
|
||||
// prune, imperative :-(
|
||||
EXPECT(assert_equal(bayesNet, bayesNet.prune(2)));
|
||||
EXPECT_LONGS_EQUAL(1, bayesNet.prune(1).at(0)->size());
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
|
@ -145,19 +153,38 @@ TEST(HybridBayesNet, Tiny) {
|
|||
EXPECT(assert_equal(one, bayesNet.optimize()));
|
||||
EXPECT(assert_equal(chosen0.optimize(), bayesNet.optimize(zero.discrete())));
|
||||
|
||||
// sample
|
||||
std::mt19937_64 rng(42);
|
||||
EXPECT(assert_equal({{M(0), 1}}, bayesNet.sample(&rng).discrete()));
|
||||
// sample. Not deterministic !!! TODO(Frank): figure out why
|
||||
// std::mt19937_64 rng(42);
|
||||
// EXPECT(assert_equal({{M(0), 1}}, bayesNet.sample(&rng).discrete()));
|
||||
|
||||
// prune
|
||||
auto pruned = bayesNet.prune(1);
|
||||
CHECK(pruned.at(1)->asHybrid());
|
||||
EXPECT_LONGS_EQUAL(1, pruned.at(1)->asHybrid()->nrComponents());
|
||||
EXPECT(!pruned.equals(bayesNet));
|
||||
|
||||
// error
|
||||
const double error0 = chosen0.error(vv) + gc0->negLogConstant() -
|
||||
px->negLogConstant() - log(0.4);
|
||||
const double error1 = chosen1.error(vv) + gc1->negLogConstant() -
|
||||
px->negLogConstant() - log(0.6);
|
||||
// print errors:
|
||||
EXPECT_DOUBLES_EQUAL(error0, bayesNet.error(zero), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(error1, bayesNet.error(one), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(error0 + logP0, error1 + logP1, 1e-9);
|
||||
|
||||
// errorTree
|
||||
AlgebraicDecisionTree<Key> expected(M(0), error0, error1);
|
||||
EXPECT(assert_equal(expected, bayesNet.errorTree(vv)));
|
||||
|
||||
// discretePosterior
|
||||
// We have: P(z|x,mode)P(x)P(mode). When we condition on z and x, we get
|
||||
// P(mode|z,x) \propto P(z|x,mode)P(x)P(mode)
|
||||
// Normalizing this yields posterior P(mode|z,x) = {0.8, 0.2}
|
||||
double q0 = std::exp(logP0), q1 = std::exp(logP1), sum = q0 + q1;
|
||||
AlgebraicDecisionTree<Key> expectedPosterior(M(0), q0 / sum, q1 / sum);
|
||||
EXPECT(assert_equal(expectedPosterior, bayesNet.discretePosterior(vv)));
|
||||
|
||||
// toFactorGraph
|
||||
auto fg = bayesNet.toFactorGraph({{Z(0), Vector1(5.0)}});
|
||||
EXPECT_LONGS_EQUAL(3, fg.size());
|
||||
|
@ -168,11 +195,15 @@ TEST(HybridBayesNet, Tiny) {
|
|||
ratio[1] = std::exp(-fg.error(one)) / bayesNet.evaluate(one);
|
||||
EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8);
|
||||
|
||||
// prune, imperative :-(
|
||||
auto pruned = bayesNet.prune(1);
|
||||
EXPECT_LONGS_EQUAL(1, pruned.at(0)->asHybrid()->nrComponents());
|
||||
EXPECT(!pruned.equals(bayesNet));
|
||||
|
||||
// Better and more general test:
|
||||
// Since ϕ(M, x) \propto P(M,x|z) the discretePosteriors should agree
|
||||
q0 = std::exp(-fg.error(zero));
|
||||
q1 = std::exp(-fg.error(one));
|
||||
sum = q0 + q1;
|
||||
EXPECT(assert_equal(expectedPosterior, {M(0), q0 / sum, q1 / sum}));
|
||||
VectorValues xv{{X(0), Vector1(5.0)}};
|
||||
auto fgPosterior = fg.discretePosterior(xv);
|
||||
EXPECT(assert_equal(expectedPosterior, fgPosterior));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
|
@ -206,21 +237,6 @@ TEST(HybridBayesNet, evaluateHybrid) {
|
|||
bayesNet.evaluate(values), 1e-9);
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test error for a hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia).
|
||||
TEST(HybridBayesNet, Error) {
|
||||
using namespace different_sigmas;
|
||||
|
||||
AlgebraicDecisionTree<Key> actual = bayesNet.errorTree(values.continuous());
|
||||
|
||||
// Regression.
|
||||
// Manually added all the error values from the 3 conditional types.
|
||||
AlgebraicDecisionTree<Key> expected(
|
||||
{Asia}, std::vector<double>{2.33005033585, 5.38619084965});
|
||||
|
||||
EXPECT(assert_equal(expected, actual));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test choosing an assignment of conditionals
|
||||
TEST(HybridBayesNet, Choose) {
|
||||
|
@ -318,22 +334,19 @@ TEST(HybridBayesNet, Pruning) {
|
|||
|
||||
// Optimize
|
||||
HybridValues delta = posterior->optimize();
|
||||
auto actualTree = posterior->evaluate(delta.continuous());
|
||||
|
||||
// Regression test on density tree.
|
||||
std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
|
||||
std::vector<double> leaves = {6.1112424, 20.346113, 17.785849, 19.738098};
|
||||
AlgebraicDecisionTree<Key> expected(discrete_keys, leaves);
|
||||
EXPECT(assert_equal(expected, actualTree, 1e-6));
|
||||
// Verify discrete posterior at optimal value sums to 1.
|
||||
auto discretePosterior = posterior->discretePosterior(delta.continuous());
|
||||
EXPECT_DOUBLES_EQUAL(1.0, discretePosterior.sum(), 1e-9);
|
||||
|
||||
// Regression test on discrete posterior at optimal value.
|
||||
std::vector<double> leaves = {0.095516068, 0.31800092, 0.27798511, 0.3084979};
|
||||
AlgebraicDecisionTree<Key> expected(s.modes, leaves);
|
||||
EXPECT(assert_equal(expected, discretePosterior, 1e-6));
|
||||
|
||||
// Prune and get probabilities
|
||||
auto prunedBayesNet = posterior->prune(2);
|
||||
auto prunedTree = prunedBayesNet.evaluate(delta.continuous());
|
||||
|
||||
// Regression test on pruned logProbability tree
|
||||
std::vector<double> pruned_leaves = {0.0, 32.713418, 0.0, 31.735823};
|
||||
AlgebraicDecisionTree<Key> expected_pruned(discrete_keys, pruned_leaves);
|
||||
EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6));
|
||||
auto prunedTree = prunedBayesNet.discretePosterior(delta.continuous());
|
||||
|
||||
// Verify logProbability computation and check specific logProbability value
|
||||
const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}};
|
||||
|
@ -346,14 +359,21 @@ TEST(HybridBayesNet, Pruning) {
|
|||
posterior->at(3)->asDiscrete()->logProbability(hybridValues);
|
||||
logProbability +=
|
||||
posterior->at(4)->asDiscrete()->logProbability(hybridValues);
|
||||
|
||||
// Regression
|
||||
double density = exp(logProbability);
|
||||
EXPECT_DOUBLES_EQUAL(density,
|
||||
1.6078460548731697 * actualTree(discrete_values), 1e-6);
|
||||
EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(logProbability, posterior->logProbability(hybridValues),
|
||||
1e-9);
|
||||
|
||||
// Check agreement with discrete posterior
|
||||
// double density = exp(logProbability);
|
||||
// FAILS: EXPECT_DOUBLES_EQUAL(density, discretePosterior(discrete_values),
|
||||
// 1e-6);
|
||||
|
||||
// Regression test on pruned logProbability tree
|
||||
std::vector<double> pruned_leaves = {0.0, 0.50758422, 0.0, 0.49241578};
|
||||
AlgebraicDecisionTree<Key> expected_pruned(s.modes, pruned_leaves);
|
||||
EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6));
|
||||
|
||||
// Regression
|
||||
// FAILS: EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9);
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
|
@ -383,49 +403,47 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
|||
s.linearizedFactorGraph.eliminateSequential();
|
||||
EXPECT_LONGS_EQUAL(7, posterior->size());
|
||||
|
||||
size_t maxNrLeaves = 3;
|
||||
DiscreteConditional discreteConditionals;
|
||||
for (auto&& conditional : *posterior) {
|
||||
if (conditional->isDiscrete()) {
|
||||
discreteConditionals =
|
||||
discreteConditionals * (*conditional->asDiscrete());
|
||||
}
|
||||
DiscreteConditional joint;
|
||||
for (auto&& conditional : posterior->discreteMarginal()) {
|
||||
joint = joint * (*conditional);
|
||||
}
|
||||
const DecisionTreeFactor::shared_ptr prunedDecisionTree =
|
||||
std::make_shared<DecisionTreeFactor>(
|
||||
discreteConditionals.prune(maxNrLeaves));
|
||||
|
||||
size_t maxNrLeaves = 3;
|
||||
auto prunedDecisionTree = joint.prune(maxNrLeaves);
|
||||
|
||||
#ifdef GTSAM_DT_MERGING
|
||||
EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,
|
||||
prunedDecisionTree->nrLeaves());
|
||||
prunedDecisionTree.nrLeaves());
|
||||
#else
|
||||
EXPECT_LONGS_EQUAL(8 /*full tree*/, prunedDecisionTree->nrLeaves());
|
||||
EXPECT_LONGS_EQUAL(8 /*full tree*/, prunedDecisionTree.nrLeaves());
|
||||
#endif
|
||||
|
||||
// regression
|
||||
// NOTE(Frank): I had to include *three* non-zeroes here now.
|
||||
DecisionTreeFactor::ADT potentials(
|
||||
s.modes, std::vector<double>{0, 0, 0, 0.505145423, 0, 1, 0, 0.494854577});
|
||||
DiscreteConditional expected_discrete_conditionals(1, s.modes, potentials);
|
||||
s.modes,
|
||||
std::vector<double>{0, 0, 0, 0.28739288, 0, 0.43106901, 0, 0.2815381});
|
||||
DiscreteConditional expectedConditional(3, s.modes, potentials);
|
||||
|
||||
// Prune!
|
||||
posterior->prune(maxNrLeaves);
|
||||
auto pruned = posterior->prune(maxNrLeaves);
|
||||
|
||||
// Functor to verify values against the expected_discrete_conditionals
|
||||
// Functor to verify values against the expectedConditional
|
||||
auto checker = [&](const Assignment<Key>& assignment,
|
||||
double probability) -> double {
|
||||
// typecast so we can use this to get probability value
|
||||
DiscreteValues choices(assignment);
|
||||
if (prunedDecisionTree->operator()(choices) == 0) {
|
||||
if (prunedDecisionTree(choices) == 0) {
|
||||
EXPECT_DOUBLES_EQUAL(0.0, probability, 1e-9);
|
||||
} else {
|
||||
EXPECT_DOUBLES_EQUAL(expected_discrete_conditionals(choices), probability,
|
||||
1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(expectedConditional(choices), probability, 1e-6);
|
||||
}
|
||||
return 0.0;
|
||||
};
|
||||
|
||||
// Get the pruned discrete conditionals as an AlgebraicDecisionTree
|
||||
auto pruned_discrete_conditionals = posterior->at(4)->asDiscrete();
|
||||
CHECK(pruned.at(0)->asDiscrete());
|
||||
auto pruned_discrete_conditionals = pruned.at(0)->asDiscrete();
|
||||
auto discrete_conditional_tree =
|
||||
std::dynamic_pointer_cast<DecisionTreeFactor::ADT>(
|
||||
pruned_discrete_conditionals);
|
||||
|
@ -549,8 +567,8 @@ TEST(HybridBayesNet, ErrorTreeWithConditional) {
|
|||
AlgebraicDecisionTree<Key> errorTree = gfg.errorTree(vv);
|
||||
|
||||
// regression
|
||||
AlgebraicDecisionTree<Key> expected(m1, 59.335390372, 5050.125);
|
||||
EXPECT(assert_equal(expected, errorTree, 1e-9));
|
||||
AlgebraicDecisionTree<Key> expected(m1, 60.028538, 5050.8181);
|
||||
EXPECT(assert_equal(expected, errorTree, 1e-4));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -109,6 +109,7 @@ TEST(HybridEstimation, IncrementalSmoother) {
|
|||
|
||||
HybridGaussianFactorGraph linearized;
|
||||
|
||||
constexpr size_t maxNrLeaves = 3;
|
||||
for (size_t k = 1; k < K; k++) {
|
||||
// Motion Model
|
||||
graph.push_back(switching.nonlinearFactorGraph.at(k));
|
||||
|
@ -120,8 +121,12 @@ TEST(HybridEstimation, IncrementalSmoother) {
|
|||
linearized = *graph.linearize(initial);
|
||||
Ordering ordering = smoother.getOrdering(linearized);
|
||||
|
||||
smoother.update(linearized, 3, ordering);
|
||||
smoother.update(linearized, maxNrLeaves, ordering);
|
||||
graph.resize(0);
|
||||
|
||||
// Uncomment to print out pruned discrete marginal:
|
||||
// smoother.hybridBayesNet().at(0)->asDiscrete()->dot("smoother_" +
|
||||
// std::to_string(k));
|
||||
}
|
||||
|
||||
HybridValues delta = smoother.hybridBayesNet().optimize();
|
||||
|
|
|
@ -25,8 +25,12 @@
|
|||
#include <gtsam/inference/Symbol.h>
|
||||
#include <gtsam/linear/GaussianConditional.h>
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "gtsam/discrete/DecisionTree.h"
|
||||
#include "gtsam/discrete/DiscreteKey.h"
|
||||
|
||||
// Include for test suite
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
|
||||
|
@ -74,17 +78,6 @@ TEST(HybridGaussianConditional, Invariants) {
|
|||
/// Check LogProbability.
|
||||
TEST(HybridGaussianConditional, LogProbability) {
|
||||
using namespace equal_constants;
|
||||
auto actual = hybrid_conditional.logProbability(vv);
|
||||
|
||||
// Check result.
|
||||
std::vector<DiscreteKey> discrete_keys = {mode};
|
||||
std::vector<double> leaves = {conditionals[0]->logProbability(vv),
|
||||
conditionals[1]->logProbability(vv)};
|
||||
AlgebraicDecisionTree<Key> expected(discrete_keys, leaves);
|
||||
|
||||
EXPECT(assert_equal(expected, actual, 1e-6));
|
||||
|
||||
// Check for non-tree version.
|
||||
for (size_t mode : {0, 1}) {
|
||||
const HybridValues hv{vv, {{M(0), mode}}};
|
||||
EXPECT_DOUBLES_EQUAL(conditionals[mode]->logProbability(vv),
|
||||
|
@ -261,8 +254,60 @@ TEST(HybridGaussianConditional, Likelihood2) {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Test pruning a HybridGaussianConditional with two discrete keys, based on a
|
||||
// DecisionTreeFactor with 3 keys:
|
||||
TEST(HybridGaussianConditional, Prune) {
|
||||
// Create a two key conditional:
|
||||
DiscreteKeys modes{{M(1), 2}, {M(2), 2}};
|
||||
std::vector<GaussianConditional::shared_ptr> gcs;
|
||||
for (size_t i = 0; i < 4; i++) {
|
||||
gcs.push_back(
|
||||
GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(i + 1), i + 1));
|
||||
}
|
||||
auto empty = std::make_shared<GaussianConditional>();
|
||||
HybridGaussianConditional::Conditionals conditionals(modes, gcs);
|
||||
HybridGaussianConditional hgc(modes, conditionals);
|
||||
|
||||
DiscreteKeys keys = modes;
|
||||
keys.push_back({M(3), 2});
|
||||
{
|
||||
for (size_t i = 0; i < 8; i++) {
|
||||
std::vector<double> potentials{0, 0, 0, 0, 0, 0, 0, 0};
|
||||
potentials[i] = 1;
|
||||
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
||||
// Prune the HybridGaussianConditional
|
||||
const auto pruned = hgc.prune(decisionTreeFactor);
|
||||
// Check that the pruned HybridGaussianConditional has 1 conditional
|
||||
EXPECT_LONGS_EQUAL(1, pruned->nrComponents());
|
||||
}
|
||||
}
|
||||
{
|
||||
const std::vector<double> potentials{0, 0, 0.5, 0, //
|
||||
0, 0, 0.5, 0};
|
||||
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
||||
|
||||
const auto pruned = hgc.prune(decisionTreeFactor);
|
||||
|
||||
// Check that the pruned HybridGaussianConditional has 2 conditionals
|
||||
EXPECT_LONGS_EQUAL(2, pruned->nrComponents());
|
||||
}
|
||||
{
|
||||
const std::vector<double> potentials{0.2, 0, 0.3, 0, //
|
||||
0, 0, 0.5, 0};
|
||||
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
||||
|
||||
const auto pruned = hgc.prune(decisionTreeFactor);
|
||||
|
||||
// Check that the pruned HybridGaussianConditional has 3 conditionals
|
||||
EXPECT_LONGS_EQUAL(3, pruned->nrComponents());
|
||||
}
|
||||
}
|
||||
|
||||
/* *************************************************************************
|
||||
*/
|
||||
int main() {
|
||||
TestResult tr;
|
||||
return TestRegistry::runAllTests(tr);
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
/* *************************************************************************
|
||||
*/
|
||||
|
|
|
@ -357,16 +357,9 @@ TEST(HybridGaussianFactor, DifferentCovariancesFG) {
|
|||
cv.insert(X(0), Vector1(0.0));
|
||||
cv.insert(X(1), Vector1(0.0));
|
||||
|
||||
// Check that the error values at the MLE point μ.
|
||||
AlgebraicDecisionTree<Key> errorTree = hbn->errorTree(cv);
|
||||
|
||||
DiscreteValues dv0{{M(1), 0}};
|
||||
DiscreteValues dv1{{M(1), 1}};
|
||||
|
||||
// regression
|
||||
EXPECT_DOUBLES_EQUAL(9.90348755254, errorTree(dv0), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(0.69314718056, errorTree(dv1), 1e-9);
|
||||
|
||||
DiscreteConditional expected_m1(m1, "0.5/0.5");
|
||||
DiscreteConditional actual_m1 = *(hbn->at(2)->asDiscrete());
|
||||
|
||||
|
|
|
@ -603,34 +603,31 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) {
|
|||
/* ****************************************************************************/
|
||||
// Test hybrid gaussian factor graph error and unnormalized probabilities
|
||||
TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
|
||||
// Create switching network with three continuous variables and two discrete:
|
||||
// ϕ(x0) ϕ(x0,x1,m0) ϕ(x1,x2,m1) ϕ(x0;z0) ϕ(x1;z1) ϕ(x2;z2) ϕ(m0) ϕ(m0,m1)
|
||||
Switching s(3);
|
||||
|
||||
HybridGaussianFactorGraph graph = s.linearizedFactorGraph;
|
||||
const HybridGaussianFactorGraph &graph = s.linearizedFactorGraph;
|
||||
|
||||
HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential();
|
||||
const HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential();
|
||||
|
||||
HybridValues delta = hybridBayesNet->optimize();
|
||||
auto error_tree = graph.errorTree(delta.continuous());
|
||||
const HybridValues delta = hybridBayesNet->optimize();
|
||||
|
||||
std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
|
||||
std::vector<double> leaves = {0.9998558, 0.4902432, 0.5193694, 0.0097568};
|
||||
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
|
||||
// regression test for errorTree
|
||||
std::vector<double> leaves = {2.7916153, 1.5888555, 1.7233422, 1.6191947};
|
||||
AlgebraicDecisionTree<Key> expectedErrors(s.modes, leaves);
|
||||
const auto error_tree = graph.errorTree(delta.continuous());
|
||||
EXPECT(assert_equal(expectedErrors, error_tree, 1e-7));
|
||||
|
||||
// regression
|
||||
EXPECT(assert_equal(expected_error, error_tree, 1e-7));
|
||||
|
||||
auto probabilities = graph.probPrime(delta.continuous());
|
||||
std::vector<double> prob_leaves = {0.36793249, 0.61247742, 0.59489556,
|
||||
0.99029064};
|
||||
AlgebraicDecisionTree<Key> expected_probabilities(discrete_keys, prob_leaves);
|
||||
|
||||
// regression
|
||||
EXPECT(assert_equal(expected_probabilities, probabilities, 1e-7));
|
||||
// regression test for discretePosterior
|
||||
const AlgebraicDecisionTree<Key> expectedPosterior(
|
||||
s.modes, std::vector{0.095516068, 0.31800092, 0.27798511, 0.3084979});
|
||||
auto posterior = graph.discretePosterior(delta.continuous());
|
||||
EXPECT(assert_equal(expectedPosterior, posterior, 1e-7));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test hybrid gaussian factor graph errorTree during
|
||||
// incremental operation
|
||||
// Test hybrid gaussian factor graph errorTree during incremental operation
|
||||
TEST(HybridGaussianFactorGraph, IncrementalErrorTree) {
|
||||
Switching s(4);
|
||||
|
||||
|
@ -650,8 +647,7 @@ TEST(HybridGaussianFactorGraph, IncrementalErrorTree) {
|
|||
auto error_tree = graph.errorTree(delta.continuous());
|
||||
|
||||
std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
|
||||
std::vector<double> leaves = {0.99985581, 0.4902432, 0.51936941,
|
||||
0.0097568009};
|
||||
std::vector<double> leaves = {2.7916153, 1.5888555, 1.7233422, 1.6191947};
|
||||
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
|
||||
|
||||
// regression
|
||||
|
@ -668,12 +664,10 @@ TEST(HybridGaussianFactorGraph, IncrementalErrorTree) {
|
|||
delta = hybridBayesNet->optimize();
|
||||
auto error_tree2 = graph.errorTree(delta.continuous());
|
||||
|
||||
discrete_keys = {{M(0), 2}, {M(1), 2}, {M(2), 2}};
|
||||
// regression
|
||||
leaves = {0.50985198, 0.0097577296, 0.50009425, 0,
|
||||
0.52922138, 0.029127133, 0.50985105, 0.0097567964};
|
||||
AlgebraicDecisionTree<Key> expected_error2(discrete_keys, leaves);
|
||||
|
||||
// regression
|
||||
AlgebraicDecisionTree<Key> expected_error2(s.modes, leaves);
|
||||
EXPECT(assert_equal(expected_error, error_tree, 1e-7));
|
||||
}
|
||||
|
||||
|
|
|
@ -1025,16 +1025,9 @@ TEST(HybridNonlinearFactorGraph, DifferentCovariances) {
|
|||
cv.insert(X(0), Vector1(0.0));
|
||||
cv.insert(X(1), Vector1(0.0));
|
||||
|
||||
// Check that the error values at the MLE point μ.
|
||||
AlgebraicDecisionTree<Key> errorTree = hbn->errorTree(cv);
|
||||
|
||||
DiscreteValues dv0{{M(1), 0}};
|
||||
DiscreteValues dv1{{M(1), 1}};
|
||||
|
||||
// regression
|
||||
EXPECT_DOUBLES_EQUAL(9.90348755254, errorTree(dv0), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(0.69314718056, errorTree(dv1), 1e-9);
|
||||
|
||||
DiscreteConditional expected_m1(m1, "0.5/0.5");
|
||||
DiscreteConditional actual_m1 = *(hbn->at(2)->asDiscrete());
|
||||
|
||||
|
|
|
@ -140,6 +140,9 @@ namespace gtsam {
|
|||
/** Access the conditional */
|
||||
const sharedConditional& conditional() const { return conditional_; }
|
||||
|
||||
/** Write access to the conditional */
|
||||
sharedConditional& conditional() { return conditional_; }
|
||||
|
||||
/// Return true if this clique is the root of a Bayes tree.
|
||||
inline bool isRoot() const { return parent_.expired(); }
|
||||
|
||||
|
|
Loading…
Reference in New Issue