Merge pull request #1343 from borglab/hybrid/model-selection
commit
f0cd78f2c9
|
@ -47,19 +47,21 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
|
|||
/**
|
||||
* @brief Helper function to get the pruner functional.
|
||||
*
|
||||
* @param decisionTree The probability decision tree of only discrete keys.
|
||||
* @return std::function<GaussianConditional::shared_ptr(
|
||||
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
||||
* @param prunedDecisionTree The prob. decision tree of only discrete keys.
|
||||
* @param conditional Conditional to prune. Used to get full assignment.
|
||||
* @return std::function<double(const Assignment<Key> &, double)>
|
||||
*/
|
||||
std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
||||
const DecisionTreeFactor &decisionTree,
|
||||
const DecisionTreeFactor &prunedDecisionTree,
|
||||
const HybridConditional &conditional) {
|
||||
// Get the discrete keys as sets for the decision tree
|
||||
// and the Gaussian mixture.
|
||||
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
|
||||
auto conditionalKeySet = DiscreteKeysAsSet(conditional.discreteKeys());
|
||||
std::set<DiscreteKey> decisionTreeKeySet =
|
||||
DiscreteKeysAsSet(prunedDecisionTree.discreteKeys());
|
||||
std::set<DiscreteKey> conditionalKeySet =
|
||||
DiscreteKeysAsSet(conditional.discreteKeys());
|
||||
|
||||
auto pruner = [decisionTree, decisionTreeKeySet, conditionalKeySet](
|
||||
auto pruner = [prunedDecisionTree, decisionTreeKeySet, conditionalKeySet](
|
||||
const Assignment<Key> &choices,
|
||||
double probability) -> double {
|
||||
// typecast so we can use this to get probability value
|
||||
|
@ -67,17 +69,44 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
|||
// Case where the Gaussian mixture has the same
|
||||
// discrete keys as the decision tree.
|
||||
if (conditionalKeySet == decisionTreeKeySet) {
|
||||
if (decisionTree(values) == 0) {
|
||||
if (prunedDecisionTree(values) == 0) {
|
||||
return 0.0;
|
||||
} else {
|
||||
return probability;
|
||||
}
|
||||
} else {
|
||||
// Due to branch merging (aka pruning) in DecisionTree, it is possible we
|
||||
// get a `values` which doesn't have the full set of keys.
|
||||
std::set<Key> valuesKeys;
|
||||
for (auto kvp : values) {
|
||||
valuesKeys.insert(kvp.first);
|
||||
}
|
||||
std::set<Key> conditionalKeys;
|
||||
for (auto kvp : conditionalKeySet) {
|
||||
conditionalKeys.insert(kvp.first);
|
||||
}
|
||||
// If true, then values is missing some keys
|
||||
if (conditionalKeys != valuesKeys) {
|
||||
// Get the keys present in conditionalKeys but not in valuesKeys
|
||||
std::vector<Key> missing_keys;
|
||||
std::set_difference(conditionalKeys.begin(), conditionalKeys.end(),
|
||||
valuesKeys.begin(), valuesKeys.end(),
|
||||
std::back_inserter(missing_keys));
|
||||
// Insert missing keys with a default assignment.
|
||||
for (auto missing_key : missing_keys) {
|
||||
values[missing_key] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Now we generate the full assignment by enumerating
|
||||
// over all keys in the prunedDecisionTree.
|
||||
// First we find the differing keys
|
||||
std::vector<DiscreteKey> set_diff;
|
||||
std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
|
||||
conditionalKeySet.begin(), conditionalKeySet.end(),
|
||||
std::back_inserter(set_diff));
|
||||
|
||||
// Now enumerate over all assignments of the differing keys
|
||||
const std::vector<DiscreteValues> assignments =
|
||||
DiscreteValues::CartesianProduct(set_diff);
|
||||
for (const DiscreteValues &assignment : assignments) {
|
||||
|
@ -86,7 +115,7 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
|||
|
||||
// If any one of the sub-branches are non-zero,
|
||||
// we need this probability.
|
||||
if (decisionTree(augmented_values) > 0.0) {
|
||||
if (prunedDecisionTree(augmented_values) > 0.0) {
|
||||
return probability;
|
||||
}
|
||||
}
|
||||
|
@ -99,7 +128,6 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// TODO(dellaert): what is this non-const method used for? Abolish it?
|
||||
void HybridBayesNet::updateDiscreteConditionals(
|
||||
const DecisionTreeFactor::shared_ptr &prunedDecisionTree) {
|
||||
KeyVector prunedTreeKeys = prunedDecisionTree->keys();
|
||||
|
@ -109,8 +137,6 @@ void HybridBayesNet::updateDiscreteConditionals(
|
|||
HybridConditional::shared_ptr conditional = this->at(i);
|
||||
if (conditional->isDiscrete()) {
|
||||
auto discrete = conditional->asDiscrete();
|
||||
KeyVector frontals(discrete->frontals().begin(),
|
||||
discrete->frontals().end());
|
||||
|
||||
// Apply prunerFunc to the underlying AlgebraicDecisionTree
|
||||
auto discreteTree =
|
||||
|
@ -119,6 +145,8 @@ void HybridBayesNet::updateDiscreteConditionals(
|
|||
discreteTree->apply(prunerFunc(*prunedDecisionTree, *conditional));
|
||||
|
||||
// Create the new (hybrid) conditional
|
||||
KeyVector frontals(discrete->frontals().begin(),
|
||||
discrete->frontals().end());
|
||||
auto prunedDiscrete = boost::make_shared<DiscreteLookupTable>(
|
||||
frontals.size(), conditional->discreteKeys(), prunedDiscreteTree);
|
||||
conditional = boost::make_shared<HybridConditional>(prunedDiscrete);
|
||||
|
@ -206,7 +234,7 @@ GaussianBayesNet HybridBayesNet::choose(
|
|||
|
||||
/* ************************************************************************* */
|
||||
HybridValues HybridBayesNet::optimize() const {
|
||||
// Solve for the MPE
|
||||
// Collect all the discrete factors to compute MPE
|
||||
DiscreteBayesNet discrete_bn;
|
||||
for (auto &&conditional : *this) {
|
||||
if (conditional->isDiscrete()) {
|
||||
|
@ -214,6 +242,7 @@ HybridValues HybridBayesNet::optimize() const {
|
|||
}
|
||||
}
|
||||
|
||||
// Solve for the MPE
|
||||
DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize();
|
||||
|
||||
// Given the MPE, compute the optimal continuous values.
|
||||
|
|
|
@ -138,7 +138,8 @@ struct HybridAssignmentData {
|
|||
|
||||
/* *************************************************************************
|
||||
*/
|
||||
VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
||||
GaussianBayesTree HybridBayesTree::choose(
|
||||
const DiscreteValues& assignment) const {
|
||||
GaussianBayesTree gbt;
|
||||
HybridAssignmentData rootData(assignment, 0, &gbt);
|
||||
{
|
||||
|
@ -151,6 +152,17 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
|||
}
|
||||
|
||||
if (!rootData.isValid()) {
|
||||
return GaussianBayesTree();
|
||||
}
|
||||
return gbt;
|
||||
}
|
||||
|
||||
/* *************************************************************************
|
||||
*/
|
||||
VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
||||
GaussianBayesTree gbt = this->choose(assignment);
|
||||
// If empty GaussianBayesTree, means a clique is pruned hence invalid
|
||||
if (gbt.size() == 0) {
|
||||
return VectorValues();
|
||||
}
|
||||
VectorValues result = gbt.optimize();
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include <gtsam/inference/BayesTree.h>
|
||||
#include <gtsam/inference/BayesTreeCliqueBase.h>
|
||||
#include <gtsam/inference/Conditional.h>
|
||||
#include <gtsam/linear/GaussianBayesTree.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
|
@ -76,6 +77,15 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
|
|||
/** Check equality */
|
||||
bool equals(const This& other, double tol = 1e-9) const;
|
||||
|
||||
/**
|
||||
* @brief Get the Gaussian Bayes Tree which corresponds to a specific discrete
|
||||
* value assignment.
|
||||
*
|
||||
* @param assignment The discrete value assignment for the discrete keys.
|
||||
* @return GaussianBayesTree
|
||||
*/
|
||||
GaussianBayesTree choose(const DiscreteValues& assignment) const;
|
||||
|
||||
/**
|
||||
* @brief Optimize the hybrid Bayes tree by computing the MPE for the current
|
||||
* set of discrete variables and using it to compute the best continuous
|
||||
|
|
|
@ -261,6 +261,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
|||
if (!factor) {
|
||||
return 0.0; // If nullptr, return 0.0 probability
|
||||
} else {
|
||||
// This is the probability q(μ) at the MLE point.
|
||||
double error =
|
||||
0.5 * std::abs(factor->augmentedInformation().determinant());
|
||||
return std::exp(-error);
|
||||
|
@ -396,9 +397,8 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
|
|||
if (discrete_only) {
|
||||
// Case 1: we are only dealing with discrete
|
||||
return discreteElimination(factors, frontalKeys);
|
||||
} else {
|
||||
} else if (mapFromKeyToDiscreteKey.empty()) {
|
||||
// Case 2: we are only dealing with continuous
|
||||
if (mapFromKeyToDiscreteKey.empty()) {
|
||||
return continuousElimination(factors, frontalKeys);
|
||||
} else {
|
||||
// Case 3: We are now in the hybrid land!
|
||||
|
@ -408,7 +408,6 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
|
|||
return hybridElimination(factors, frontalKeys, continuousSeparator,
|
||||
discreteSeparatorSet);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
/**
|
||||
* @file HybridGaussianFactorGraph.h
|
||||
* @brief Linearized Hybrid factor graph that uses type erasure
|
||||
* @author Fan Jiang
|
||||
* @author Fan Jiang, Varun Agrawal
|
||||
* @date Mar 11, 2022
|
||||
*/
|
||||
|
||||
|
|
|
@ -100,8 +100,7 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
|
|||
/* ************************************************************************* */
|
||||
GaussianMixture::shared_ptr HybridSmoother::gaussianMixture(
|
||||
size_t index) const {
|
||||
return boost::dynamic_pointer_cast<GaussianMixture>(
|
||||
hybridBayesNet_.at(index));
|
||||
return hybridBayesNet_.atMixture(index);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -104,7 +104,7 @@ class GTSAM_EXPORT HybridValues {
|
|||
* @param j The index with which the value will be associated. */
|
||||
void insert(Key j, const Vector& value) { continuous_.insert(j, value); }
|
||||
|
||||
// TODO(Shangjie)- update() and insert_or_assign() , similar to Values.h
|
||||
// TODO(Shangjie)- insert_or_assign() , similar to Values.h
|
||||
|
||||
/**
|
||||
* Read/write access to the discrete value with key \c j, throws
|
||||
|
|
|
@ -188,12 +188,14 @@ TEST(HybridBayesNet, Optimize) {
|
|||
|
||||
HybridValues delta = hybridBayesNet->optimize();
|
||||
|
||||
//TODO(Varun) The expectedAssignment should be 111, not 101
|
||||
DiscreteValues expectedAssignment;
|
||||
expectedAssignment[M(0)] = 1;
|
||||
expectedAssignment[M(1)] = 0;
|
||||
expectedAssignment[M(2)] = 1;
|
||||
EXPECT(assert_equal(expectedAssignment, delta.discrete()));
|
||||
|
||||
//TODO(Varun) This should be all -Vector1::Ones()
|
||||
VectorValues expectedValues;
|
||||
expectedValues.insert(X(0), -0.999904 * Vector1::Ones());
|
||||
expectedValues.insert(X(1), -0.99029 * Vector1::Ones());
|
||||
|
|
|
@ -169,6 +169,57 @@ TEST(HybridBayesTree, Optimize) {
|
|||
EXPECT(assert_equal(expectedValues, delta.continuous()));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test for choosing a GaussianBayesTree from a HybridBayesTree.
|
||||
TEST(HybridBayesTree, Choose) {
|
||||
Switching s(4);
|
||||
|
||||
HybridGaussianISAM isam;
|
||||
HybridGaussianFactorGraph graph1;
|
||||
|
||||
// Add the 3 hybrid factors, x1-x2, x2-x3, x3-x4
|
||||
for (size_t i = 1; i < 4; i++) {
|
||||
graph1.push_back(s.linearizedFactorGraph.at(i));
|
||||
}
|
||||
|
||||
// Add the Gaussian factors, 1 prior on X(0),
|
||||
// 3 measurements on X(2), X(3), X(4)
|
||||
graph1.push_back(s.linearizedFactorGraph.at(0));
|
||||
for (size_t i = 4; i <= 6; i++) {
|
||||
graph1.push_back(s.linearizedFactorGraph.at(i));
|
||||
}
|
||||
|
||||
// Add the discrete factors
|
||||
for (size_t i = 7; i <= 9; i++) {
|
||||
graph1.push_back(s.linearizedFactorGraph.at(i));
|
||||
}
|
||||
|
||||
isam.update(graph1);
|
||||
|
||||
DiscreteValues assignment;
|
||||
assignment[M(0)] = 1;
|
||||
assignment[M(1)] = 1;
|
||||
assignment[M(2)] = 1;
|
||||
|
||||
GaussianBayesTree gbt = isam.choose(assignment);
|
||||
|
||||
Ordering ordering;
|
||||
ordering += X(0);
|
||||
ordering += X(1);
|
||||
ordering += X(2);
|
||||
ordering += X(3);
|
||||
ordering += M(0);
|
||||
ordering += M(1);
|
||||
ordering += M(2);
|
||||
|
||||
//TODO(Varun) get segfault if ordering not provided
|
||||
auto bayesTree = s.linearizedFactorGraph.eliminateMultifrontal(ordering);
|
||||
|
||||
auto expected_gbt = bayesTree->choose(assignment);
|
||||
|
||||
EXPECT(assert_equal(expected_gbt, gbt));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test HybridBayesTree serialization.
|
||||
TEST(HybridBayesTree, Serialization) {
|
||||
|
|
|
@ -72,25 +72,44 @@ Ordering getOrdering(HybridGaussianFactorGraph& factors,
|
|||
}
|
||||
|
||||
TEST(HybridEstimation, Full) {
|
||||
size_t K = 3;
|
||||
std::vector<double> measurements = {0, 1, 2};
|
||||
size_t K = 6;
|
||||
std::vector<double> measurements = {0, 1, 2, 2, 2, 3};
|
||||
// Ground truth discrete seq
|
||||
std::vector<size_t> discrete_seq = {1, 1, 0};
|
||||
std::vector<size_t> discrete_seq = {1, 1, 0, 0, 1};
|
||||
// Switching example of robot moving in 1D
|
||||
// with given measurements and equal mode priors.
|
||||
Switching switching(K, 1.0, 0.1, measurements, "1/1 1/1");
|
||||
HybridGaussianFactorGraph graph = switching.linearizedFactorGraph;
|
||||
|
||||
Ordering hybridOrdering;
|
||||
hybridOrdering += X(0);
|
||||
hybridOrdering += X(1);
|
||||
hybridOrdering += X(2);
|
||||
hybridOrdering += M(0);
|
||||
hybridOrdering += M(1);
|
||||
for (size_t k = 0; k < K; k++) {
|
||||
hybridOrdering += X(k);
|
||||
}
|
||||
for (size_t k = 0; k < K - 1; k++) {
|
||||
hybridOrdering += M(k);
|
||||
}
|
||||
|
||||
HybridBayesNet::shared_ptr bayesNet =
|
||||
graph.eliminateSequential(hybridOrdering);
|
||||
|
||||
EXPECT_LONGS_EQUAL(5, bayesNet->size());
|
||||
EXPECT_LONGS_EQUAL(2 * K - 1, bayesNet->size());
|
||||
|
||||
HybridValues delta = bayesNet->optimize();
|
||||
|
||||
Values initial = switching.linearizationPoint;
|
||||
Values result = initial.retract(delta.continuous());
|
||||
|
||||
DiscreteValues expected_discrete;
|
||||
for (size_t k = 0; k < K - 1; k++) {
|
||||
expected_discrete[M(k)] = discrete_seq[k];
|
||||
}
|
||||
EXPECT(assert_equal(expected_discrete, delta.discrete()));
|
||||
|
||||
Values expected_continuous;
|
||||
for (size_t k = 0; k < K; k++) {
|
||||
expected_continuous.insert(X(k), measurements[k]);
|
||||
}
|
||||
EXPECT(assert_equal(expected_continuous, result));
|
||||
}
|
||||
|
||||
/****************************************************************************/
|
||||
|
@ -102,8 +121,8 @@ TEST(HybridEstimation, Incremental) {
|
|||
// Ground truth discrete seq
|
||||
std::vector<size_t> discrete_seq = {1, 1, 0, 0, 0, 1, 1, 1, 1, 0,
|
||||
1, 1, 1, 0, 0, 1, 1, 0, 0, 0};
|
||||
// Switching example of robot moving in 1D with given measurements and equal
|
||||
// mode priors.
|
||||
// Switching example of robot moving in 1D
|
||||
// with given measurements and equal mode priors.
|
||||
Switching switching(K, 1.0, 0.1, measurements, "1/1 1/1");
|
||||
HybridSmoother smoother;
|
||||
HybridNonlinearFactorGraph graph;
|
||||
|
@ -209,13 +228,16 @@ std::vector<size_t> getDiscreteSequence(size_t x) {
|
|||
}
|
||||
|
||||
/**
|
||||
* @brief Helper method to get the tree of unnormalized probabilities
|
||||
* as per the new elimination scheme.
|
||||
* @brief Helper method to get the tree of
|
||||
* unnormalized probabilities as per the elimination scheme.
|
||||
*
|
||||
* Used as a helper to compute q(\mu | M, Z) which is used by
|
||||
* both P(X | M, Z) and P(M | Z).
|
||||
*
|
||||
* @param graph The HybridGaussianFactorGraph to eliminate.
|
||||
* @return AlgebraicDecisionTree<Key>
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> probPrimeTree(
|
||||
AlgebraicDecisionTree<Key> getProbPrimeTree(
|
||||
const HybridGaussianFactorGraph& graph) {
|
||||
HybridBayesNet::shared_ptr bayesNet;
|
||||
HybridGaussianFactorGraph::shared_ptr remainingGraph;
|
||||
|
@ -239,20 +261,19 @@ AlgebraicDecisionTree<Key> probPrimeTree(
|
|||
DecisionTree<Key, VectorValues::shared_ptr> delta_tree(discrete_keys,
|
||||
vector_values);
|
||||
|
||||
// Get the probPrime tree with the correct leaf probabilities
|
||||
std::vector<double> probPrimes;
|
||||
for (const DiscreteValues& assignment : assignments) {
|
||||
double error = 0.0;
|
||||
VectorValues delta = *delta_tree(assignment);
|
||||
for (auto factor : graph) {
|
||||
if (factor->isHybrid()) {
|
||||
auto f = boost::static_pointer_cast<GaussianMixtureFactor>(factor);
|
||||
error += f->error(delta, assignment);
|
||||
|
||||
} else if (factor->isContinuous()) {
|
||||
auto f = boost::static_pointer_cast<HybridGaussianFactor>(factor);
|
||||
error += f->inner()->error(delta);
|
||||
}
|
||||
// If VectorValues is empty, it means this is a pruned branch.
|
||||
// Set the probPrime to 0.0.
|
||||
if (delta.size() == 0) {
|
||||
probPrimes.push_back(0.0);
|
||||
continue;
|
||||
}
|
||||
|
||||
double error = graph.error(delta, assignment);
|
||||
probPrimes.push_back(exp(-error));
|
||||
}
|
||||
AlgebraicDecisionTree<Key> probPrimeTree(discrete_keys, probPrimes);
|
||||
|
@ -274,10 +295,23 @@ TEST(HybridEstimation, Probability) {
|
|||
Switching switching(K, between_sigma, measurement_sigma, measurements,
|
||||
"1/1 1/1");
|
||||
auto graph = switching.linearizedFactorGraph;
|
||||
Ordering ordering = getOrdering(graph, HybridGaussianFactorGraph());
|
||||
|
||||
HybridBayesNet::shared_ptr bayesNet = graph.eliminateSequential(ordering);
|
||||
auto discreteConditional = bayesNet->atDiscrete(bayesNet->size() - 3);
|
||||
// Continuous elimination
|
||||
Ordering continuous_ordering(graph.continuousKeys());
|
||||
HybridBayesNet::shared_ptr bayesNet;
|
||||
HybridGaussianFactorGraph::shared_ptr discreteGraph;
|
||||
std::tie(bayesNet, discreteGraph) =
|
||||
graph.eliminatePartialSequential(continuous_ordering);
|
||||
|
||||
// Discrete elimination
|
||||
Ordering discrete_ordering(graph.discreteKeys());
|
||||
auto discreteBayesNet = discreteGraph->eliminateSequential(discrete_ordering);
|
||||
|
||||
// Add the discrete conditionals to make it a full bayes net.
|
||||
for (auto discrete_conditional : *discreteBayesNet) {
|
||||
bayesNet->add(discrete_conditional);
|
||||
}
|
||||
auto discreteConditional = discreteBayesNet->atDiscrete(0);
|
||||
|
||||
HybridValues hybrid_values = bayesNet->optimize();
|
||||
|
||||
|
@ -310,7 +344,7 @@ TEST(HybridEstimation, ProbabilityMultifrontal) {
|
|||
Ordering ordering = getOrdering(graph, HybridGaussianFactorGraph());
|
||||
|
||||
// Get the tree of unnormalized probabilities for each mode sequence.
|
||||
AlgebraicDecisionTree<Key> expected_probPrimeTree = probPrimeTree(graph);
|
||||
AlgebraicDecisionTree<Key> expected_probPrimeTree = getProbPrimeTree(graph);
|
||||
|
||||
// Eliminate continuous
|
||||
Ordering continuous_ordering(graph.continuousKeys());
|
||||
|
@ -326,8 +360,7 @@ TEST(HybridEstimation, ProbabilityMultifrontal) {
|
|||
DiscreteKeys discrete_keys = last_conditional->discreteKeys();
|
||||
|
||||
Ordering discrete(graph.discreteKeys());
|
||||
auto discreteBayesTree =
|
||||
discreteGraph->BaseEliminateable::eliminateMultifrontal(discrete);
|
||||
auto discreteBayesTree = discreteGraph->eliminateMultifrontal(discrete);
|
||||
|
||||
EXPECT_LONGS_EQUAL(1, discreteBayesTree->size());
|
||||
// DiscreteBayesTree should have only 1 clique
|
||||
|
@ -345,8 +378,8 @@ TEST(HybridEstimation, ProbabilityMultifrontal) {
|
|||
discreteBayesTree->addClique(clique, discrete_clique);
|
||||
|
||||
} else {
|
||||
// Remove the clique from the children of the parents since it will get
|
||||
// added again in addClique.
|
||||
// Remove the clique from the children of the parents since
|
||||
// it will get added again in addClique.
|
||||
auto clique_it = std::find(clique->parent()->children.begin(),
|
||||
clique->parent()->children.end(), clique);
|
||||
clique->parent()->children.erase(clique_it);
|
||||
|
@ -392,7 +425,7 @@ static HybridNonlinearFactorGraph createHybridNonlinearFactorGraph() {
|
|||
}
|
||||
|
||||
/*********************************************************************************
|
||||
// Create a hybrid nonlinear factor graph f(x0, x1, m0; z0, z1)
|
||||
// Create a hybrid linear factor graph f(x0, x1, m0; z0, z1)
|
||||
********************************************************************************/
|
||||
static HybridGaussianFactorGraph::shared_ptr createHybridGaussianFactorGraph() {
|
||||
HybridNonlinearFactorGraph nfg = createHybridNonlinearFactorGraph();
|
||||
|
|
|
@ -81,14 +81,16 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
|||
self.assertEqual(hv.atDiscrete(C(0)), 1)
|
||||
|
||||
@staticmethod
|
||||
def tiny(num_measurements: int = 1):
|
||||
"""Create a tiny two variable hybrid model."""
|
||||
def tiny(num_measurements: int = 1) -> gtsam.HybridBayesNet:
|
||||
"""
|
||||
Create a tiny two variable hybrid model which represents
|
||||
the generative probability P(z, x, n) = P(z | x, n)P(x)P(n).
|
||||
"""
|
||||
# Create hybrid Bayes net.
|
||||
bayesNet = gtsam.HybridBayesNet()
|
||||
|
||||
# Create mode key: 0 is low-noise, 1 is high-noise.
|
||||
modeKey = M(0)
|
||||
mode = (modeKey, 2)
|
||||
mode = (M(0), 2)
|
||||
|
||||
# Create Gaussian mixture Z(0) = X(0) + noise for each measurement.
|
||||
I = np.eye(1)
|
||||
|
@ -141,14 +143,22 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
|||
return bayesNet.evaluate(sample) / fg.probPrime(
|
||||
continuous, sample.discrete())
|
||||
|
||||
def test_tiny2(self):
|
||||
"""Test a tiny two variable hybrid model, with 2 measurements."""
|
||||
# Create the Bayes net and sample from it.
|
||||
def test_ratio(self):
|
||||
"""
|
||||
Given a tiny two variable hybrid model, with 2 measurements,
|
||||
test the ratio of the bayes net model representing P(z, x, n)=P(z|x, n)P(x)P(n)
|
||||
and the factor graph P(x, n | z)=P(x | n, z)P(n|z),
|
||||
both of which represent the same posterior.
|
||||
"""
|
||||
# Create the Bayes net representing the generative model P(z, x, n)=P(z|x, n)P(x)P(n)
|
||||
bayesNet = self.tiny(num_measurements=2)
|
||||
sample = bayesNet.sample()
|
||||
# Sample from the Bayes net.
|
||||
sample: gtsam.HybridValues = bayesNet.sample()
|
||||
# print(sample)
|
||||
|
||||
# Create a factor graph from the Bayes net with sampled measurements.
|
||||
# The factor graph is `P(x)P(n) ϕ(x, n; z1) ϕ(x, n; z2)`
|
||||
# and thus represents the same joint probability as the Bayes net.
|
||||
fg = HybridGaussianFactorGraph()
|
||||
for i in range(2):
|
||||
conditional = bayesNet.atMixture(i)
|
||||
|
|
Loading…
Reference in New Issue