Merge pull request #1343 from borglab/hybrid/model-selection

release/4.3a0
Varun Agrawal 2022-12-30 10:03:03 -05:00 committed by GitHub
commit f0cd78f2c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 212 additions and 67 deletions

View File

@ -47,19 +47,21 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
/** /**
* @brief Helper function to get the pruner functional. * @brief Helper function to get the pruner functional.
* *
* @param decisionTree The probability decision tree of only discrete keys. * @param prunedDecisionTree The prob. decision tree of only discrete keys.
* @return std::function<GaussianConditional::shared_ptr( * @param conditional Conditional to prune. Used to get full assignment.
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)> * @return std::function<double(const Assignment<Key> &, double)>
*/ */
std::function<double(const Assignment<Key> &, double)> prunerFunc( std::function<double(const Assignment<Key> &, double)> prunerFunc(
const DecisionTreeFactor &decisionTree, const DecisionTreeFactor &prunedDecisionTree,
const HybridConditional &conditional) { const HybridConditional &conditional) {
// Get the discrete keys as sets for the decision tree // Get the discrete keys as sets for the decision tree
// and the Gaussian mixture. // and the Gaussian mixture.
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys()); std::set<DiscreteKey> decisionTreeKeySet =
auto conditionalKeySet = DiscreteKeysAsSet(conditional.discreteKeys()); DiscreteKeysAsSet(prunedDecisionTree.discreteKeys());
std::set<DiscreteKey> conditionalKeySet =
DiscreteKeysAsSet(conditional.discreteKeys());
auto pruner = [decisionTree, decisionTreeKeySet, conditionalKeySet]( auto pruner = [prunedDecisionTree, decisionTreeKeySet, conditionalKeySet](
const Assignment<Key> &choices, const Assignment<Key> &choices,
double probability) -> double { double probability) -> double {
// typecast so we can use this to get probability value // typecast so we can use this to get probability value
@ -67,17 +69,44 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
// Case where the Gaussian mixture has the same // Case where the Gaussian mixture has the same
// discrete keys as the decision tree. // discrete keys as the decision tree.
if (conditionalKeySet == decisionTreeKeySet) { if (conditionalKeySet == decisionTreeKeySet) {
if (decisionTree(values) == 0) { if (prunedDecisionTree(values) == 0) {
return 0.0; return 0.0;
} else { } else {
return probability; return probability;
} }
} else { } else {
// Due to branch merging (aka pruning) in DecisionTree, it is possible we
// get a `values` which doesn't have the full set of keys.
std::set<Key> valuesKeys;
for (auto kvp : values) {
valuesKeys.insert(kvp.first);
}
std::set<Key> conditionalKeys;
for (auto kvp : conditionalKeySet) {
conditionalKeys.insert(kvp.first);
}
// If true, then values is missing some keys
if (conditionalKeys != valuesKeys) {
// Get the keys present in conditionalKeys but not in valuesKeys
std::vector<Key> missing_keys;
std::set_difference(conditionalKeys.begin(), conditionalKeys.end(),
valuesKeys.begin(), valuesKeys.end(),
std::back_inserter(missing_keys));
// Insert missing keys with a default assignment.
for (auto missing_key : missing_keys) {
values[missing_key] = 0;
}
}
// Now we generate the full assignment by enumerating
// over all keys in the prunedDecisionTree.
// First we find the differing keys
std::vector<DiscreteKey> set_diff; std::vector<DiscreteKey> set_diff;
std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(), std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
conditionalKeySet.begin(), conditionalKeySet.end(), conditionalKeySet.begin(), conditionalKeySet.end(),
std::back_inserter(set_diff)); std::back_inserter(set_diff));
// Now enumerate over all assignments of the differing keys
const std::vector<DiscreteValues> assignments = const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(set_diff); DiscreteValues::CartesianProduct(set_diff);
for (const DiscreteValues &assignment : assignments) { for (const DiscreteValues &assignment : assignments) {
@ -86,7 +115,7 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
// If any one of the sub-branches are non-zero, // If any one of the sub-branches are non-zero,
// we need this probability. // we need this probability.
if (decisionTree(augmented_values) > 0.0) { if (prunedDecisionTree(augmented_values) > 0.0) {
return probability; return probability;
} }
} }
@ -99,7 +128,6 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
} }
/* ************************************************************************* */ /* ************************************************************************* */
// TODO(dellaert): what is this non-const method used for? Abolish it?
void HybridBayesNet::updateDiscreteConditionals( void HybridBayesNet::updateDiscreteConditionals(
const DecisionTreeFactor::shared_ptr &prunedDecisionTree) { const DecisionTreeFactor::shared_ptr &prunedDecisionTree) {
KeyVector prunedTreeKeys = prunedDecisionTree->keys(); KeyVector prunedTreeKeys = prunedDecisionTree->keys();
@ -109,8 +137,6 @@ void HybridBayesNet::updateDiscreteConditionals(
HybridConditional::shared_ptr conditional = this->at(i); HybridConditional::shared_ptr conditional = this->at(i);
if (conditional->isDiscrete()) { if (conditional->isDiscrete()) {
auto discrete = conditional->asDiscrete(); auto discrete = conditional->asDiscrete();
KeyVector frontals(discrete->frontals().begin(),
discrete->frontals().end());
// Apply prunerFunc to the underlying AlgebraicDecisionTree // Apply prunerFunc to the underlying AlgebraicDecisionTree
auto discreteTree = auto discreteTree =
@ -119,6 +145,8 @@ void HybridBayesNet::updateDiscreteConditionals(
discreteTree->apply(prunerFunc(*prunedDecisionTree, *conditional)); discreteTree->apply(prunerFunc(*prunedDecisionTree, *conditional));
// Create the new (hybrid) conditional // Create the new (hybrid) conditional
KeyVector frontals(discrete->frontals().begin(),
discrete->frontals().end());
auto prunedDiscrete = boost::make_shared<DiscreteLookupTable>( auto prunedDiscrete = boost::make_shared<DiscreteLookupTable>(
frontals.size(), conditional->discreteKeys(), prunedDiscreteTree); frontals.size(), conditional->discreteKeys(), prunedDiscreteTree);
conditional = boost::make_shared<HybridConditional>(prunedDiscrete); conditional = boost::make_shared<HybridConditional>(prunedDiscrete);
@ -206,7 +234,7 @@ GaussianBayesNet HybridBayesNet::choose(
/* ************************************************************************* */ /* ************************************************************************* */
HybridValues HybridBayesNet::optimize() const { HybridValues HybridBayesNet::optimize() const {
// Solve for the MPE // Collect all the discrete factors to compute MPE
DiscreteBayesNet discrete_bn; DiscreteBayesNet discrete_bn;
for (auto &&conditional : *this) { for (auto &&conditional : *this) {
if (conditional->isDiscrete()) { if (conditional->isDiscrete()) {
@ -214,6 +242,7 @@ HybridValues HybridBayesNet::optimize() const {
} }
} }
// Solve for the MPE
DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize(); DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize();
// Given the MPE, compute the optimal continuous values. // Given the MPE, compute the optimal continuous values.

View File

@ -138,7 +138,8 @@ struct HybridAssignmentData {
/* ************************************************************************* /* *************************************************************************
*/ */
VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const { GaussianBayesTree HybridBayesTree::choose(
const DiscreteValues& assignment) const {
GaussianBayesTree gbt; GaussianBayesTree gbt;
HybridAssignmentData rootData(assignment, 0, &gbt); HybridAssignmentData rootData(assignment, 0, &gbt);
{ {
@ -151,6 +152,17 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
} }
if (!rootData.isValid()) { if (!rootData.isValid()) {
return GaussianBayesTree();
}
return gbt;
}
/* *************************************************************************
*/
VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
GaussianBayesTree gbt = this->choose(assignment);
// If empty GaussianBayesTree, means a clique is pruned hence invalid
if (gbt.size() == 0) {
return VectorValues(); return VectorValues();
} }
VectorValues result = gbt.optimize(); VectorValues result = gbt.optimize();

View File

@ -24,6 +24,7 @@
#include <gtsam/inference/BayesTree.h> #include <gtsam/inference/BayesTree.h>
#include <gtsam/inference/BayesTreeCliqueBase.h> #include <gtsam/inference/BayesTreeCliqueBase.h>
#include <gtsam/inference/Conditional.h> #include <gtsam/inference/Conditional.h>
#include <gtsam/linear/GaussianBayesTree.h>
#include <string> #include <string>
@ -76,6 +77,15 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
/** Check equality */ /** Check equality */
bool equals(const This& other, double tol = 1e-9) const; bool equals(const This& other, double tol = 1e-9) const;
/**
* @brief Get the Gaussian Bayes Tree which corresponds to a specific discrete
* value assignment.
*
* @param assignment The discrete value assignment for the discrete keys.
* @return GaussianBayesTree
*/
GaussianBayesTree choose(const DiscreteValues& assignment) const;
/** /**
* @brief Optimize the hybrid Bayes tree by computing the MPE for the current * @brief Optimize the hybrid Bayes tree by computing the MPE for the current
* set of discrete variables and using it to compute the best continuous * set of discrete variables and using it to compute the best continuous

View File

@ -261,6 +261,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
if (!factor) { if (!factor) {
return 0.0; // If nullptr, return 0.0 probability return 0.0; // If nullptr, return 0.0 probability
} else { } else {
// This is the probability q(μ) at the MLE point.
double error = double error =
0.5 * std::abs(factor->augmentedInformation().determinant()); 0.5 * std::abs(factor->augmentedInformation().determinant());
return std::exp(-error); return std::exp(-error);
@ -396,18 +397,16 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
if (discrete_only) { if (discrete_only) {
// Case 1: we are only dealing with discrete // Case 1: we are only dealing with discrete
return discreteElimination(factors, frontalKeys); return discreteElimination(factors, frontalKeys);
} else { } else if (mapFromKeyToDiscreteKey.empty()) {
// Case 2: we are only dealing with continuous // Case 2: we are only dealing with continuous
if (mapFromKeyToDiscreteKey.empty()) { return continuousElimination(factors, frontalKeys);
return continuousElimination(factors, frontalKeys); } else {
} else { // Case 3: We are now in the hybrid land!
// Case 3: We are now in the hybrid land!
#ifdef HYBRID_TIMING #ifdef HYBRID_TIMING
tictoc_reset_(); tictoc_reset_();
#endif #endif
return hybridElimination(factors, frontalKeys, continuousSeparator, return hybridElimination(factors, frontalKeys, continuousSeparator,
discreteSeparatorSet); discreteSeparatorSet);
}
} }
} }

View File

@ -12,7 +12,7 @@
/** /**
* @file HybridGaussianFactorGraph.h * @file HybridGaussianFactorGraph.h
* @brief Linearized Hybrid factor graph that uses type erasure * @brief Linearized Hybrid factor graph that uses type erasure
* @author Fan Jiang * @author Fan Jiang, Varun Agrawal
* @date Mar 11, 2022 * @date Mar 11, 2022
*/ */

View File

@ -100,8 +100,7 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
/* ************************************************************************* */ /* ************************************************************************* */
GaussianMixture::shared_ptr HybridSmoother::gaussianMixture( GaussianMixture::shared_ptr HybridSmoother::gaussianMixture(
size_t index) const { size_t index) const {
return boost::dynamic_pointer_cast<GaussianMixture>( return hybridBayesNet_.atMixture(index);
hybridBayesNet_.at(index));
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -104,7 +104,7 @@ class GTSAM_EXPORT HybridValues {
* @param j The index with which the value will be associated. */ * @param j The index with which the value will be associated. */
void insert(Key j, const Vector& value) { continuous_.insert(j, value); } void insert(Key j, const Vector& value) { continuous_.insert(j, value); }
// TODO(Shangjie)- update() and insert_or_assign() , similar to Values.h // TODO(Shangjie)- insert_or_assign() , similar to Values.h
/** /**
* Read/write access to the discrete value with key \c j, throws * Read/write access to the discrete value with key \c j, throws

View File

@ -188,12 +188,14 @@ TEST(HybridBayesNet, Optimize) {
HybridValues delta = hybridBayesNet->optimize(); HybridValues delta = hybridBayesNet->optimize();
//TODO(Varun) The expectedAssignment should be 111, not 101
DiscreteValues expectedAssignment; DiscreteValues expectedAssignment;
expectedAssignment[M(0)] = 1; expectedAssignment[M(0)] = 1;
expectedAssignment[M(1)] = 0; expectedAssignment[M(1)] = 0;
expectedAssignment[M(2)] = 1; expectedAssignment[M(2)] = 1;
EXPECT(assert_equal(expectedAssignment, delta.discrete())); EXPECT(assert_equal(expectedAssignment, delta.discrete()));
//TODO(Varun) This should be all -Vector1::Ones()
VectorValues expectedValues; VectorValues expectedValues;
expectedValues.insert(X(0), -0.999904 * Vector1::Ones()); expectedValues.insert(X(0), -0.999904 * Vector1::Ones());
expectedValues.insert(X(1), -0.99029 * Vector1::Ones()); expectedValues.insert(X(1), -0.99029 * Vector1::Ones());

View File

@ -169,6 +169,57 @@ TEST(HybridBayesTree, Optimize) {
EXPECT(assert_equal(expectedValues, delta.continuous())); EXPECT(assert_equal(expectedValues, delta.continuous()));
} }
/* ****************************************************************************/
// Test for choosing a GaussianBayesTree from a HybridBayesTree.
TEST(HybridBayesTree, Choose) {
Switching s(4);
HybridGaussianISAM isam;
HybridGaussianFactorGraph graph1;
// Add the 3 hybrid factors, x1-x2, x2-x3, x3-x4
for (size_t i = 1; i < 4; i++) {
graph1.push_back(s.linearizedFactorGraph.at(i));
}
// Add the Gaussian factors, 1 prior on X(0),
// 3 measurements on X(2), X(3), X(4)
graph1.push_back(s.linearizedFactorGraph.at(0));
for (size_t i = 4; i <= 6; i++) {
graph1.push_back(s.linearizedFactorGraph.at(i));
}
// Add the discrete factors
for (size_t i = 7; i <= 9; i++) {
graph1.push_back(s.linearizedFactorGraph.at(i));
}
isam.update(graph1);
DiscreteValues assignment;
assignment[M(0)] = 1;
assignment[M(1)] = 1;
assignment[M(2)] = 1;
GaussianBayesTree gbt = isam.choose(assignment);
Ordering ordering;
ordering += X(0);
ordering += X(1);
ordering += X(2);
ordering += X(3);
ordering += M(0);
ordering += M(1);
ordering += M(2);
//TODO(Varun) get segfault if ordering not provided
auto bayesTree = s.linearizedFactorGraph.eliminateMultifrontal(ordering);
auto expected_gbt = bayesTree->choose(assignment);
EXPECT(assert_equal(expected_gbt, gbt));
}
/* ****************************************************************************/ /* ****************************************************************************/
// Test HybridBayesTree serialization. // Test HybridBayesTree serialization.
TEST(HybridBayesTree, Serialization) { TEST(HybridBayesTree, Serialization) {

View File

@ -72,25 +72,44 @@ Ordering getOrdering(HybridGaussianFactorGraph& factors,
} }
TEST(HybridEstimation, Full) { TEST(HybridEstimation, Full) {
size_t K = 3; size_t K = 6;
std::vector<double> measurements = {0, 1, 2}; std::vector<double> measurements = {0, 1, 2, 2, 2, 3};
// Ground truth discrete seq // Ground truth discrete seq
std::vector<size_t> discrete_seq = {1, 1, 0}; std::vector<size_t> discrete_seq = {1, 1, 0, 0, 1};
// Switching example of robot moving in 1D // Switching example of robot moving in 1D
// with given measurements and equal mode priors. // with given measurements and equal mode priors.
Switching switching(K, 1.0, 0.1, measurements, "1/1 1/1"); Switching switching(K, 1.0, 0.1, measurements, "1/1 1/1");
HybridGaussianFactorGraph graph = switching.linearizedFactorGraph; HybridGaussianFactorGraph graph = switching.linearizedFactorGraph;
Ordering hybridOrdering; Ordering hybridOrdering;
hybridOrdering += X(0); for (size_t k = 0; k < K; k++) {
hybridOrdering += X(1); hybridOrdering += X(k);
hybridOrdering += X(2); }
hybridOrdering += M(0); for (size_t k = 0; k < K - 1; k++) {
hybridOrdering += M(1); hybridOrdering += M(k);
}
HybridBayesNet::shared_ptr bayesNet = HybridBayesNet::shared_ptr bayesNet =
graph.eliminateSequential(hybridOrdering); graph.eliminateSequential(hybridOrdering);
EXPECT_LONGS_EQUAL(5, bayesNet->size()); EXPECT_LONGS_EQUAL(2 * K - 1, bayesNet->size());
HybridValues delta = bayesNet->optimize();
Values initial = switching.linearizationPoint;
Values result = initial.retract(delta.continuous());
DiscreteValues expected_discrete;
for (size_t k = 0; k < K - 1; k++) {
expected_discrete[M(k)] = discrete_seq[k];
}
EXPECT(assert_equal(expected_discrete, delta.discrete()));
Values expected_continuous;
for (size_t k = 0; k < K; k++) {
expected_continuous.insert(X(k), measurements[k]);
}
EXPECT(assert_equal(expected_continuous, result));
} }
/****************************************************************************/ /****************************************************************************/
@ -102,8 +121,8 @@ TEST(HybridEstimation, Incremental) {
// Ground truth discrete seq // Ground truth discrete seq
std::vector<size_t> discrete_seq = {1, 1, 0, 0, 0, 1, 1, 1, 1, 0, std::vector<size_t> discrete_seq = {1, 1, 0, 0, 0, 1, 1, 1, 1, 0,
1, 1, 1, 0, 0, 1, 1, 0, 0, 0}; 1, 1, 1, 0, 0, 1, 1, 0, 0, 0};
// Switching example of robot moving in 1D with given measurements and equal // Switching example of robot moving in 1D
// mode priors. // with given measurements and equal mode priors.
Switching switching(K, 1.0, 0.1, measurements, "1/1 1/1"); Switching switching(K, 1.0, 0.1, measurements, "1/1 1/1");
HybridSmoother smoother; HybridSmoother smoother;
HybridNonlinearFactorGraph graph; HybridNonlinearFactorGraph graph;
@ -209,13 +228,16 @@ std::vector<size_t> getDiscreteSequence(size_t x) {
} }
/** /**
* @brief Helper method to get the tree of unnormalized probabilities * @brief Helper method to get the tree of
* as per the new elimination scheme. * unnormalized probabilities as per the elimination scheme.
*
* Used as a helper to compute q(\mu | M, Z) which is used by
* both P(X | M, Z) and P(M | Z).
* *
* @param graph The HybridGaussianFactorGraph to eliminate. * @param graph The HybridGaussianFactorGraph to eliminate.
* @return AlgebraicDecisionTree<Key> * @return AlgebraicDecisionTree<Key>
*/ */
AlgebraicDecisionTree<Key> probPrimeTree( AlgebraicDecisionTree<Key> getProbPrimeTree(
const HybridGaussianFactorGraph& graph) { const HybridGaussianFactorGraph& graph) {
HybridBayesNet::shared_ptr bayesNet; HybridBayesNet::shared_ptr bayesNet;
HybridGaussianFactorGraph::shared_ptr remainingGraph; HybridGaussianFactorGraph::shared_ptr remainingGraph;
@ -239,20 +261,19 @@ AlgebraicDecisionTree<Key> probPrimeTree(
DecisionTree<Key, VectorValues::shared_ptr> delta_tree(discrete_keys, DecisionTree<Key, VectorValues::shared_ptr> delta_tree(discrete_keys,
vector_values); vector_values);
// Get the probPrime tree with the correct leaf probabilities
std::vector<double> probPrimes; std::vector<double> probPrimes;
for (const DiscreteValues& assignment : assignments) { for (const DiscreteValues& assignment : assignments) {
double error = 0.0;
VectorValues delta = *delta_tree(assignment); VectorValues delta = *delta_tree(assignment);
for (auto factor : graph) {
if (factor->isHybrid()) {
auto f = boost::static_pointer_cast<GaussianMixtureFactor>(factor);
error += f->error(delta, assignment);
} else if (factor->isContinuous()) { // If VectorValues is empty, it means this is a pruned branch.
auto f = boost::static_pointer_cast<HybridGaussianFactor>(factor); // Set the probPrime to 0.0.
error += f->inner()->error(delta); if (delta.size() == 0) {
} probPrimes.push_back(0.0);
continue;
} }
double error = graph.error(delta, assignment);
probPrimes.push_back(exp(-error)); probPrimes.push_back(exp(-error));
} }
AlgebraicDecisionTree<Key> probPrimeTree(discrete_keys, probPrimes); AlgebraicDecisionTree<Key> probPrimeTree(discrete_keys, probPrimes);
@ -274,10 +295,23 @@ TEST(HybridEstimation, Probability) {
Switching switching(K, between_sigma, measurement_sigma, measurements, Switching switching(K, between_sigma, measurement_sigma, measurements,
"1/1 1/1"); "1/1 1/1");
auto graph = switching.linearizedFactorGraph; auto graph = switching.linearizedFactorGraph;
Ordering ordering = getOrdering(graph, HybridGaussianFactorGraph());
HybridBayesNet::shared_ptr bayesNet = graph.eliminateSequential(ordering); // Continuous elimination
auto discreteConditional = bayesNet->atDiscrete(bayesNet->size() - 3); Ordering continuous_ordering(graph.continuousKeys());
HybridBayesNet::shared_ptr bayesNet;
HybridGaussianFactorGraph::shared_ptr discreteGraph;
std::tie(bayesNet, discreteGraph) =
graph.eliminatePartialSequential(continuous_ordering);
// Discrete elimination
Ordering discrete_ordering(graph.discreteKeys());
auto discreteBayesNet = discreteGraph->eliminateSequential(discrete_ordering);
// Add the discrete conditionals to make it a full bayes net.
for (auto discrete_conditional : *discreteBayesNet) {
bayesNet->add(discrete_conditional);
}
auto discreteConditional = discreteBayesNet->atDiscrete(0);
HybridValues hybrid_values = bayesNet->optimize(); HybridValues hybrid_values = bayesNet->optimize();
@ -310,7 +344,7 @@ TEST(HybridEstimation, ProbabilityMultifrontal) {
Ordering ordering = getOrdering(graph, HybridGaussianFactorGraph()); Ordering ordering = getOrdering(graph, HybridGaussianFactorGraph());
// Get the tree of unnormalized probabilities for each mode sequence. // Get the tree of unnormalized probabilities for each mode sequence.
AlgebraicDecisionTree<Key> expected_probPrimeTree = probPrimeTree(graph); AlgebraicDecisionTree<Key> expected_probPrimeTree = getProbPrimeTree(graph);
// Eliminate continuous // Eliminate continuous
Ordering continuous_ordering(graph.continuousKeys()); Ordering continuous_ordering(graph.continuousKeys());
@ -326,8 +360,7 @@ TEST(HybridEstimation, ProbabilityMultifrontal) {
DiscreteKeys discrete_keys = last_conditional->discreteKeys(); DiscreteKeys discrete_keys = last_conditional->discreteKeys();
Ordering discrete(graph.discreteKeys()); Ordering discrete(graph.discreteKeys());
auto discreteBayesTree = auto discreteBayesTree = discreteGraph->eliminateMultifrontal(discrete);
discreteGraph->BaseEliminateable::eliminateMultifrontal(discrete);
EXPECT_LONGS_EQUAL(1, discreteBayesTree->size()); EXPECT_LONGS_EQUAL(1, discreteBayesTree->size());
// DiscreteBayesTree should have only 1 clique // DiscreteBayesTree should have only 1 clique
@ -345,8 +378,8 @@ TEST(HybridEstimation, ProbabilityMultifrontal) {
discreteBayesTree->addClique(clique, discrete_clique); discreteBayesTree->addClique(clique, discrete_clique);
} else { } else {
// Remove the clique from the children of the parents since it will get // Remove the clique from the children of the parents since
// added again in addClique. // it will get added again in addClique.
auto clique_it = std::find(clique->parent()->children.begin(), auto clique_it = std::find(clique->parent()->children.begin(),
clique->parent()->children.end(), clique); clique->parent()->children.end(), clique);
clique->parent()->children.erase(clique_it); clique->parent()->children.erase(clique_it);
@ -392,7 +425,7 @@ static HybridNonlinearFactorGraph createHybridNonlinearFactorGraph() {
} }
/********************************************************************************* /*********************************************************************************
// Create a hybrid nonlinear factor graph f(x0, x1, m0; z0, z1) // Create a hybrid linear factor graph f(x0, x1, m0; z0, z1)
********************************************************************************/ ********************************************************************************/
static HybridGaussianFactorGraph::shared_ptr createHybridGaussianFactorGraph() { static HybridGaussianFactorGraph::shared_ptr createHybridGaussianFactorGraph() {
HybridNonlinearFactorGraph nfg = createHybridNonlinearFactorGraph(); HybridNonlinearFactorGraph nfg = createHybridNonlinearFactorGraph();

View File

@ -81,14 +81,16 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
self.assertEqual(hv.atDiscrete(C(0)), 1) self.assertEqual(hv.atDiscrete(C(0)), 1)
@staticmethod @staticmethod
def tiny(num_measurements: int = 1): def tiny(num_measurements: int = 1) -> gtsam.HybridBayesNet:
"""Create a tiny two variable hybrid model.""" """
Create a tiny two variable hybrid model which represents
the generative probability P(z, x, n) = P(z | x, n)P(x)P(n).
"""
# Create hybrid Bayes net. # Create hybrid Bayes net.
bayesNet = gtsam.HybridBayesNet() bayesNet = gtsam.HybridBayesNet()
# Create mode key: 0 is low-noise, 1 is high-noise. # Create mode key: 0 is low-noise, 1 is high-noise.
modeKey = M(0) mode = (M(0), 2)
mode = (modeKey, 2)
# Create Gaussian mixture Z(0) = X(0) + noise for each measurement. # Create Gaussian mixture Z(0) = X(0) + noise for each measurement.
I = np.eye(1) I = np.eye(1)
@ -141,14 +143,22 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
return bayesNet.evaluate(sample) / fg.probPrime( return bayesNet.evaluate(sample) / fg.probPrime(
continuous, sample.discrete()) continuous, sample.discrete())
def test_tiny2(self): def test_ratio(self):
"""Test a tiny two variable hybrid model, with 2 measurements.""" """
# Create the Bayes net and sample from it. Given a tiny two variable hybrid model, with 2 measurements,
test the ratio of the bayes net model representing P(z, x, n)=P(z|x, n)P(x)P(n)
and the factor graph P(x, n | z)=P(x | n, z)P(n|z),
both of which represent the same posterior.
"""
# Create the Bayes net representing the generative model P(z, x, n)=P(z|x, n)P(x)P(n)
bayesNet = self.tiny(num_measurements=2) bayesNet = self.tiny(num_measurements=2)
sample = bayesNet.sample() # Sample from the Bayes net.
sample: gtsam.HybridValues = bayesNet.sample()
# print(sample) # print(sample)
# Create a factor graph from the Bayes net with sampled measurements. # Create a factor graph from the Bayes net with sampled measurements.
# The factor graph is `P(x)P(n) ϕ(x, n; z1) ϕ(x, n; z2)`
# and thus represents the same joint probability as the Bayes net.
fg = HybridGaussianFactorGraph() fg = HybridGaussianFactorGraph()
for i in range(2): for i in range(2):
conditional = bayesNet.atMixture(i) conditional = bayesNet.atMixture(i)