Merge pull request #1343 from borglab/hybrid/model-selection
commit
f0cd78f2c9
|
@ -47,19 +47,21 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
|
||||||
/**
|
/**
|
||||||
* @brief Helper function to get the pruner functional.
|
* @brief Helper function to get the pruner functional.
|
||||||
*
|
*
|
||||||
* @param decisionTree The probability decision tree of only discrete keys.
|
* @param prunedDecisionTree The prob. decision tree of only discrete keys.
|
||||||
* @return std::function<GaussianConditional::shared_ptr(
|
* @param conditional Conditional to prune. Used to get full assignment.
|
||||||
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
* @return std::function<double(const Assignment<Key> &, double)>
|
||||||
*/
|
*/
|
||||||
std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
||||||
const DecisionTreeFactor &decisionTree,
|
const DecisionTreeFactor &prunedDecisionTree,
|
||||||
const HybridConditional &conditional) {
|
const HybridConditional &conditional) {
|
||||||
// Get the discrete keys as sets for the decision tree
|
// Get the discrete keys as sets for the decision tree
|
||||||
// and the Gaussian mixture.
|
// and the Gaussian mixture.
|
||||||
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
|
std::set<DiscreteKey> decisionTreeKeySet =
|
||||||
auto conditionalKeySet = DiscreteKeysAsSet(conditional.discreteKeys());
|
DiscreteKeysAsSet(prunedDecisionTree.discreteKeys());
|
||||||
|
std::set<DiscreteKey> conditionalKeySet =
|
||||||
|
DiscreteKeysAsSet(conditional.discreteKeys());
|
||||||
|
|
||||||
auto pruner = [decisionTree, decisionTreeKeySet, conditionalKeySet](
|
auto pruner = [prunedDecisionTree, decisionTreeKeySet, conditionalKeySet](
|
||||||
const Assignment<Key> &choices,
|
const Assignment<Key> &choices,
|
||||||
double probability) -> double {
|
double probability) -> double {
|
||||||
// typecast so we can use this to get probability value
|
// typecast so we can use this to get probability value
|
||||||
|
@ -67,17 +69,44 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
||||||
// Case where the Gaussian mixture has the same
|
// Case where the Gaussian mixture has the same
|
||||||
// discrete keys as the decision tree.
|
// discrete keys as the decision tree.
|
||||||
if (conditionalKeySet == decisionTreeKeySet) {
|
if (conditionalKeySet == decisionTreeKeySet) {
|
||||||
if (decisionTree(values) == 0) {
|
if (prunedDecisionTree(values) == 0) {
|
||||||
return 0.0;
|
return 0.0;
|
||||||
} else {
|
} else {
|
||||||
return probability;
|
return probability;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
// Due to branch merging (aka pruning) in DecisionTree, it is possible we
|
||||||
|
// get a `values` which doesn't have the full set of keys.
|
||||||
|
std::set<Key> valuesKeys;
|
||||||
|
for (auto kvp : values) {
|
||||||
|
valuesKeys.insert(kvp.first);
|
||||||
|
}
|
||||||
|
std::set<Key> conditionalKeys;
|
||||||
|
for (auto kvp : conditionalKeySet) {
|
||||||
|
conditionalKeys.insert(kvp.first);
|
||||||
|
}
|
||||||
|
// If true, then values is missing some keys
|
||||||
|
if (conditionalKeys != valuesKeys) {
|
||||||
|
// Get the keys present in conditionalKeys but not in valuesKeys
|
||||||
|
std::vector<Key> missing_keys;
|
||||||
|
std::set_difference(conditionalKeys.begin(), conditionalKeys.end(),
|
||||||
|
valuesKeys.begin(), valuesKeys.end(),
|
||||||
|
std::back_inserter(missing_keys));
|
||||||
|
// Insert missing keys with a default assignment.
|
||||||
|
for (auto missing_key : missing_keys) {
|
||||||
|
values[missing_key] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now we generate the full assignment by enumerating
|
||||||
|
// over all keys in the prunedDecisionTree.
|
||||||
|
// First we find the differing keys
|
||||||
std::vector<DiscreteKey> set_diff;
|
std::vector<DiscreteKey> set_diff;
|
||||||
std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
|
std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
|
||||||
conditionalKeySet.begin(), conditionalKeySet.end(),
|
conditionalKeySet.begin(), conditionalKeySet.end(),
|
||||||
std::back_inserter(set_diff));
|
std::back_inserter(set_diff));
|
||||||
|
|
||||||
|
// Now enumerate over all assignments of the differing keys
|
||||||
const std::vector<DiscreteValues> assignments =
|
const std::vector<DiscreteValues> assignments =
|
||||||
DiscreteValues::CartesianProduct(set_diff);
|
DiscreteValues::CartesianProduct(set_diff);
|
||||||
for (const DiscreteValues &assignment : assignments) {
|
for (const DiscreteValues &assignment : assignments) {
|
||||||
|
@ -86,7 +115,7 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
||||||
|
|
||||||
// If any one of the sub-branches are non-zero,
|
// If any one of the sub-branches are non-zero,
|
||||||
// we need this probability.
|
// we need this probability.
|
||||||
if (decisionTree(augmented_values) > 0.0) {
|
if (prunedDecisionTree(augmented_values) > 0.0) {
|
||||||
return probability;
|
return probability;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -99,7 +128,6 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// TODO(dellaert): what is this non-const method used for? Abolish it?
|
|
||||||
void HybridBayesNet::updateDiscreteConditionals(
|
void HybridBayesNet::updateDiscreteConditionals(
|
||||||
const DecisionTreeFactor::shared_ptr &prunedDecisionTree) {
|
const DecisionTreeFactor::shared_ptr &prunedDecisionTree) {
|
||||||
KeyVector prunedTreeKeys = prunedDecisionTree->keys();
|
KeyVector prunedTreeKeys = prunedDecisionTree->keys();
|
||||||
|
@ -109,8 +137,6 @@ void HybridBayesNet::updateDiscreteConditionals(
|
||||||
HybridConditional::shared_ptr conditional = this->at(i);
|
HybridConditional::shared_ptr conditional = this->at(i);
|
||||||
if (conditional->isDiscrete()) {
|
if (conditional->isDiscrete()) {
|
||||||
auto discrete = conditional->asDiscrete();
|
auto discrete = conditional->asDiscrete();
|
||||||
KeyVector frontals(discrete->frontals().begin(),
|
|
||||||
discrete->frontals().end());
|
|
||||||
|
|
||||||
// Apply prunerFunc to the underlying AlgebraicDecisionTree
|
// Apply prunerFunc to the underlying AlgebraicDecisionTree
|
||||||
auto discreteTree =
|
auto discreteTree =
|
||||||
|
@ -119,6 +145,8 @@ void HybridBayesNet::updateDiscreteConditionals(
|
||||||
discreteTree->apply(prunerFunc(*prunedDecisionTree, *conditional));
|
discreteTree->apply(prunerFunc(*prunedDecisionTree, *conditional));
|
||||||
|
|
||||||
// Create the new (hybrid) conditional
|
// Create the new (hybrid) conditional
|
||||||
|
KeyVector frontals(discrete->frontals().begin(),
|
||||||
|
discrete->frontals().end());
|
||||||
auto prunedDiscrete = boost::make_shared<DiscreteLookupTable>(
|
auto prunedDiscrete = boost::make_shared<DiscreteLookupTable>(
|
||||||
frontals.size(), conditional->discreteKeys(), prunedDiscreteTree);
|
frontals.size(), conditional->discreteKeys(), prunedDiscreteTree);
|
||||||
conditional = boost::make_shared<HybridConditional>(prunedDiscrete);
|
conditional = boost::make_shared<HybridConditional>(prunedDiscrete);
|
||||||
|
@ -206,7 +234,7 @@ GaussianBayesNet HybridBayesNet::choose(
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
HybridValues HybridBayesNet::optimize() const {
|
HybridValues HybridBayesNet::optimize() const {
|
||||||
// Solve for the MPE
|
// Collect all the discrete factors to compute MPE
|
||||||
DiscreteBayesNet discrete_bn;
|
DiscreteBayesNet discrete_bn;
|
||||||
for (auto &&conditional : *this) {
|
for (auto &&conditional : *this) {
|
||||||
if (conditional->isDiscrete()) {
|
if (conditional->isDiscrete()) {
|
||||||
|
@ -214,6 +242,7 @@ HybridValues HybridBayesNet::optimize() const {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Solve for the MPE
|
||||||
DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize();
|
DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize();
|
||||||
|
|
||||||
// Given the MPE, compute the optimal continuous values.
|
// Given the MPE, compute the optimal continuous values.
|
||||||
|
|
|
@ -138,7 +138,8 @@ struct HybridAssignmentData {
|
||||||
|
|
||||||
/* *************************************************************************
|
/* *************************************************************************
|
||||||
*/
|
*/
|
||||||
VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
GaussianBayesTree HybridBayesTree::choose(
|
||||||
|
const DiscreteValues& assignment) const {
|
||||||
GaussianBayesTree gbt;
|
GaussianBayesTree gbt;
|
||||||
HybridAssignmentData rootData(assignment, 0, &gbt);
|
HybridAssignmentData rootData(assignment, 0, &gbt);
|
||||||
{
|
{
|
||||||
|
@ -151,6 +152,17 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!rootData.isValid()) {
|
if (!rootData.isValid()) {
|
||||||
|
return GaussianBayesTree();
|
||||||
|
}
|
||||||
|
return gbt;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* *************************************************************************
|
||||||
|
*/
|
||||||
|
VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
||||||
|
GaussianBayesTree gbt = this->choose(assignment);
|
||||||
|
// If empty GaussianBayesTree, means a clique is pruned hence invalid
|
||||||
|
if (gbt.size() == 0) {
|
||||||
return VectorValues();
|
return VectorValues();
|
||||||
}
|
}
|
||||||
VectorValues result = gbt.optimize();
|
VectorValues result = gbt.optimize();
|
||||||
|
|
|
@ -24,6 +24,7 @@
|
||||||
#include <gtsam/inference/BayesTree.h>
|
#include <gtsam/inference/BayesTree.h>
|
||||||
#include <gtsam/inference/BayesTreeCliqueBase.h>
|
#include <gtsam/inference/BayesTreeCliqueBase.h>
|
||||||
#include <gtsam/inference/Conditional.h>
|
#include <gtsam/inference/Conditional.h>
|
||||||
|
#include <gtsam/linear/GaussianBayesTree.h>
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
@ -76,6 +77,15 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
|
||||||
/** Check equality */
|
/** Check equality */
|
||||||
bool equals(const This& other, double tol = 1e-9) const;
|
bool equals(const This& other, double tol = 1e-9) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Get the Gaussian Bayes Tree which corresponds to a specific discrete
|
||||||
|
* value assignment.
|
||||||
|
*
|
||||||
|
* @param assignment The discrete value assignment for the discrete keys.
|
||||||
|
* @return GaussianBayesTree
|
||||||
|
*/
|
||||||
|
GaussianBayesTree choose(const DiscreteValues& assignment) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Optimize the hybrid Bayes tree by computing the MPE for the current
|
* @brief Optimize the hybrid Bayes tree by computing the MPE for the current
|
||||||
* set of discrete variables and using it to compute the best continuous
|
* set of discrete variables and using it to compute the best continuous
|
||||||
|
|
|
@ -261,6 +261,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
if (!factor) {
|
if (!factor) {
|
||||||
return 0.0; // If nullptr, return 0.0 probability
|
return 0.0; // If nullptr, return 0.0 probability
|
||||||
} else {
|
} else {
|
||||||
|
// This is the probability q(μ) at the MLE point.
|
||||||
double error =
|
double error =
|
||||||
0.5 * std::abs(factor->augmentedInformation().determinant());
|
0.5 * std::abs(factor->augmentedInformation().determinant());
|
||||||
return std::exp(-error);
|
return std::exp(-error);
|
||||||
|
@ -396,18 +397,16 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
|
||||||
if (discrete_only) {
|
if (discrete_only) {
|
||||||
// Case 1: we are only dealing with discrete
|
// Case 1: we are only dealing with discrete
|
||||||
return discreteElimination(factors, frontalKeys);
|
return discreteElimination(factors, frontalKeys);
|
||||||
} else {
|
} else if (mapFromKeyToDiscreteKey.empty()) {
|
||||||
// Case 2: we are only dealing with continuous
|
// Case 2: we are only dealing with continuous
|
||||||
if (mapFromKeyToDiscreteKey.empty()) {
|
return continuousElimination(factors, frontalKeys);
|
||||||
return continuousElimination(factors, frontalKeys);
|
} else {
|
||||||
} else {
|
// Case 3: We are now in the hybrid land!
|
||||||
// Case 3: We are now in the hybrid land!
|
|
||||||
#ifdef HYBRID_TIMING
|
#ifdef HYBRID_TIMING
|
||||||
tictoc_reset_();
|
tictoc_reset_();
|
||||||
#endif
|
#endif
|
||||||
return hybridElimination(factors, frontalKeys, continuousSeparator,
|
return hybridElimination(factors, frontalKeys, continuousSeparator,
|
||||||
discreteSeparatorSet);
|
discreteSeparatorSet);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
/**
|
/**
|
||||||
* @file HybridGaussianFactorGraph.h
|
* @file HybridGaussianFactorGraph.h
|
||||||
* @brief Linearized Hybrid factor graph that uses type erasure
|
* @brief Linearized Hybrid factor graph that uses type erasure
|
||||||
* @author Fan Jiang
|
* @author Fan Jiang, Varun Agrawal
|
||||||
* @date Mar 11, 2022
|
* @date Mar 11, 2022
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
|
@ -100,8 +100,7 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
GaussianMixture::shared_ptr HybridSmoother::gaussianMixture(
|
GaussianMixture::shared_ptr HybridSmoother::gaussianMixture(
|
||||||
size_t index) const {
|
size_t index) const {
|
||||||
return boost::dynamic_pointer_cast<GaussianMixture>(
|
return hybridBayesNet_.atMixture(index);
|
||||||
hybridBayesNet_.at(index));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -104,7 +104,7 @@ class GTSAM_EXPORT HybridValues {
|
||||||
* @param j The index with which the value will be associated. */
|
* @param j The index with which the value will be associated. */
|
||||||
void insert(Key j, const Vector& value) { continuous_.insert(j, value); }
|
void insert(Key j, const Vector& value) { continuous_.insert(j, value); }
|
||||||
|
|
||||||
// TODO(Shangjie)- update() and insert_or_assign() , similar to Values.h
|
// TODO(Shangjie)- insert_or_assign() , similar to Values.h
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Read/write access to the discrete value with key \c j, throws
|
* Read/write access to the discrete value with key \c j, throws
|
||||||
|
|
|
@ -188,12 +188,14 @@ TEST(HybridBayesNet, Optimize) {
|
||||||
|
|
||||||
HybridValues delta = hybridBayesNet->optimize();
|
HybridValues delta = hybridBayesNet->optimize();
|
||||||
|
|
||||||
|
//TODO(Varun) The expectedAssignment should be 111, not 101
|
||||||
DiscreteValues expectedAssignment;
|
DiscreteValues expectedAssignment;
|
||||||
expectedAssignment[M(0)] = 1;
|
expectedAssignment[M(0)] = 1;
|
||||||
expectedAssignment[M(1)] = 0;
|
expectedAssignment[M(1)] = 0;
|
||||||
expectedAssignment[M(2)] = 1;
|
expectedAssignment[M(2)] = 1;
|
||||||
EXPECT(assert_equal(expectedAssignment, delta.discrete()));
|
EXPECT(assert_equal(expectedAssignment, delta.discrete()));
|
||||||
|
|
||||||
|
//TODO(Varun) This should be all -Vector1::Ones()
|
||||||
VectorValues expectedValues;
|
VectorValues expectedValues;
|
||||||
expectedValues.insert(X(0), -0.999904 * Vector1::Ones());
|
expectedValues.insert(X(0), -0.999904 * Vector1::Ones());
|
||||||
expectedValues.insert(X(1), -0.99029 * Vector1::Ones());
|
expectedValues.insert(X(1), -0.99029 * Vector1::Ones());
|
||||||
|
|
|
@ -169,6 +169,57 @@ TEST(HybridBayesTree, Optimize) {
|
||||||
EXPECT(assert_equal(expectedValues, delta.continuous()));
|
EXPECT(assert_equal(expectedValues, delta.continuous()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
// Test for choosing a GaussianBayesTree from a HybridBayesTree.
|
||||||
|
TEST(HybridBayesTree, Choose) {
|
||||||
|
Switching s(4);
|
||||||
|
|
||||||
|
HybridGaussianISAM isam;
|
||||||
|
HybridGaussianFactorGraph graph1;
|
||||||
|
|
||||||
|
// Add the 3 hybrid factors, x1-x2, x2-x3, x3-x4
|
||||||
|
for (size_t i = 1; i < 4; i++) {
|
||||||
|
graph1.push_back(s.linearizedFactorGraph.at(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the Gaussian factors, 1 prior on X(0),
|
||||||
|
// 3 measurements on X(2), X(3), X(4)
|
||||||
|
graph1.push_back(s.linearizedFactorGraph.at(0));
|
||||||
|
for (size_t i = 4; i <= 6; i++) {
|
||||||
|
graph1.push_back(s.linearizedFactorGraph.at(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the discrete factors
|
||||||
|
for (size_t i = 7; i <= 9; i++) {
|
||||||
|
graph1.push_back(s.linearizedFactorGraph.at(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
isam.update(graph1);
|
||||||
|
|
||||||
|
DiscreteValues assignment;
|
||||||
|
assignment[M(0)] = 1;
|
||||||
|
assignment[M(1)] = 1;
|
||||||
|
assignment[M(2)] = 1;
|
||||||
|
|
||||||
|
GaussianBayesTree gbt = isam.choose(assignment);
|
||||||
|
|
||||||
|
Ordering ordering;
|
||||||
|
ordering += X(0);
|
||||||
|
ordering += X(1);
|
||||||
|
ordering += X(2);
|
||||||
|
ordering += X(3);
|
||||||
|
ordering += M(0);
|
||||||
|
ordering += M(1);
|
||||||
|
ordering += M(2);
|
||||||
|
|
||||||
|
//TODO(Varun) get segfault if ordering not provided
|
||||||
|
auto bayesTree = s.linearizedFactorGraph.eliminateMultifrontal(ordering);
|
||||||
|
|
||||||
|
auto expected_gbt = bayesTree->choose(assignment);
|
||||||
|
|
||||||
|
EXPECT(assert_equal(expected_gbt, gbt));
|
||||||
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
// Test HybridBayesTree serialization.
|
// Test HybridBayesTree serialization.
|
||||||
TEST(HybridBayesTree, Serialization) {
|
TEST(HybridBayesTree, Serialization) {
|
||||||
|
|
|
@ -72,25 +72,44 @@ Ordering getOrdering(HybridGaussianFactorGraph& factors,
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(HybridEstimation, Full) {
|
TEST(HybridEstimation, Full) {
|
||||||
size_t K = 3;
|
size_t K = 6;
|
||||||
std::vector<double> measurements = {0, 1, 2};
|
std::vector<double> measurements = {0, 1, 2, 2, 2, 3};
|
||||||
// Ground truth discrete seq
|
// Ground truth discrete seq
|
||||||
std::vector<size_t> discrete_seq = {1, 1, 0};
|
std::vector<size_t> discrete_seq = {1, 1, 0, 0, 1};
|
||||||
// Switching example of robot moving in 1D
|
// Switching example of robot moving in 1D
|
||||||
// with given measurements and equal mode priors.
|
// with given measurements and equal mode priors.
|
||||||
Switching switching(K, 1.0, 0.1, measurements, "1/1 1/1");
|
Switching switching(K, 1.0, 0.1, measurements, "1/1 1/1");
|
||||||
HybridGaussianFactorGraph graph = switching.linearizedFactorGraph;
|
HybridGaussianFactorGraph graph = switching.linearizedFactorGraph;
|
||||||
|
|
||||||
Ordering hybridOrdering;
|
Ordering hybridOrdering;
|
||||||
hybridOrdering += X(0);
|
for (size_t k = 0; k < K; k++) {
|
||||||
hybridOrdering += X(1);
|
hybridOrdering += X(k);
|
||||||
hybridOrdering += X(2);
|
}
|
||||||
hybridOrdering += M(0);
|
for (size_t k = 0; k < K - 1; k++) {
|
||||||
hybridOrdering += M(1);
|
hybridOrdering += M(k);
|
||||||
|
}
|
||||||
|
|
||||||
HybridBayesNet::shared_ptr bayesNet =
|
HybridBayesNet::shared_ptr bayesNet =
|
||||||
graph.eliminateSequential(hybridOrdering);
|
graph.eliminateSequential(hybridOrdering);
|
||||||
|
|
||||||
EXPECT_LONGS_EQUAL(5, bayesNet->size());
|
EXPECT_LONGS_EQUAL(2 * K - 1, bayesNet->size());
|
||||||
|
|
||||||
|
HybridValues delta = bayesNet->optimize();
|
||||||
|
|
||||||
|
Values initial = switching.linearizationPoint;
|
||||||
|
Values result = initial.retract(delta.continuous());
|
||||||
|
|
||||||
|
DiscreteValues expected_discrete;
|
||||||
|
for (size_t k = 0; k < K - 1; k++) {
|
||||||
|
expected_discrete[M(k)] = discrete_seq[k];
|
||||||
|
}
|
||||||
|
EXPECT(assert_equal(expected_discrete, delta.discrete()));
|
||||||
|
|
||||||
|
Values expected_continuous;
|
||||||
|
for (size_t k = 0; k < K; k++) {
|
||||||
|
expected_continuous.insert(X(k), measurements[k]);
|
||||||
|
}
|
||||||
|
EXPECT(assert_equal(expected_continuous, result));
|
||||||
}
|
}
|
||||||
|
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
|
@ -102,8 +121,8 @@ TEST(HybridEstimation, Incremental) {
|
||||||
// Ground truth discrete seq
|
// Ground truth discrete seq
|
||||||
std::vector<size_t> discrete_seq = {1, 1, 0, 0, 0, 1, 1, 1, 1, 0,
|
std::vector<size_t> discrete_seq = {1, 1, 0, 0, 0, 1, 1, 1, 1, 0,
|
||||||
1, 1, 1, 0, 0, 1, 1, 0, 0, 0};
|
1, 1, 1, 0, 0, 1, 1, 0, 0, 0};
|
||||||
// Switching example of robot moving in 1D with given measurements and equal
|
// Switching example of robot moving in 1D
|
||||||
// mode priors.
|
// with given measurements and equal mode priors.
|
||||||
Switching switching(K, 1.0, 0.1, measurements, "1/1 1/1");
|
Switching switching(K, 1.0, 0.1, measurements, "1/1 1/1");
|
||||||
HybridSmoother smoother;
|
HybridSmoother smoother;
|
||||||
HybridNonlinearFactorGraph graph;
|
HybridNonlinearFactorGraph graph;
|
||||||
|
@ -209,13 +228,16 @@ std::vector<size_t> getDiscreteSequence(size_t x) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Helper method to get the tree of unnormalized probabilities
|
* @brief Helper method to get the tree of
|
||||||
* as per the new elimination scheme.
|
* unnormalized probabilities as per the elimination scheme.
|
||||||
|
*
|
||||||
|
* Used as a helper to compute q(\mu | M, Z) which is used by
|
||||||
|
* both P(X | M, Z) and P(M | Z).
|
||||||
*
|
*
|
||||||
* @param graph The HybridGaussianFactorGraph to eliminate.
|
* @param graph The HybridGaussianFactorGraph to eliminate.
|
||||||
* @return AlgebraicDecisionTree<Key>
|
* @return AlgebraicDecisionTree<Key>
|
||||||
*/
|
*/
|
||||||
AlgebraicDecisionTree<Key> probPrimeTree(
|
AlgebraicDecisionTree<Key> getProbPrimeTree(
|
||||||
const HybridGaussianFactorGraph& graph) {
|
const HybridGaussianFactorGraph& graph) {
|
||||||
HybridBayesNet::shared_ptr bayesNet;
|
HybridBayesNet::shared_ptr bayesNet;
|
||||||
HybridGaussianFactorGraph::shared_ptr remainingGraph;
|
HybridGaussianFactorGraph::shared_ptr remainingGraph;
|
||||||
|
@ -239,20 +261,19 @@ AlgebraicDecisionTree<Key> probPrimeTree(
|
||||||
DecisionTree<Key, VectorValues::shared_ptr> delta_tree(discrete_keys,
|
DecisionTree<Key, VectorValues::shared_ptr> delta_tree(discrete_keys,
|
||||||
vector_values);
|
vector_values);
|
||||||
|
|
||||||
|
// Get the probPrime tree with the correct leaf probabilities
|
||||||
std::vector<double> probPrimes;
|
std::vector<double> probPrimes;
|
||||||
for (const DiscreteValues& assignment : assignments) {
|
for (const DiscreteValues& assignment : assignments) {
|
||||||
double error = 0.0;
|
|
||||||
VectorValues delta = *delta_tree(assignment);
|
VectorValues delta = *delta_tree(assignment);
|
||||||
for (auto factor : graph) {
|
|
||||||
if (factor->isHybrid()) {
|
|
||||||
auto f = boost::static_pointer_cast<GaussianMixtureFactor>(factor);
|
|
||||||
error += f->error(delta, assignment);
|
|
||||||
|
|
||||||
} else if (factor->isContinuous()) {
|
// If VectorValues is empty, it means this is a pruned branch.
|
||||||
auto f = boost::static_pointer_cast<HybridGaussianFactor>(factor);
|
// Set the probPrime to 0.0.
|
||||||
error += f->inner()->error(delta);
|
if (delta.size() == 0) {
|
||||||
}
|
probPrimes.push_back(0.0);
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
double error = graph.error(delta, assignment);
|
||||||
probPrimes.push_back(exp(-error));
|
probPrimes.push_back(exp(-error));
|
||||||
}
|
}
|
||||||
AlgebraicDecisionTree<Key> probPrimeTree(discrete_keys, probPrimes);
|
AlgebraicDecisionTree<Key> probPrimeTree(discrete_keys, probPrimes);
|
||||||
|
@ -274,10 +295,23 @@ TEST(HybridEstimation, Probability) {
|
||||||
Switching switching(K, between_sigma, measurement_sigma, measurements,
|
Switching switching(K, between_sigma, measurement_sigma, measurements,
|
||||||
"1/1 1/1");
|
"1/1 1/1");
|
||||||
auto graph = switching.linearizedFactorGraph;
|
auto graph = switching.linearizedFactorGraph;
|
||||||
Ordering ordering = getOrdering(graph, HybridGaussianFactorGraph());
|
|
||||||
|
|
||||||
HybridBayesNet::shared_ptr bayesNet = graph.eliminateSequential(ordering);
|
// Continuous elimination
|
||||||
auto discreteConditional = bayesNet->atDiscrete(bayesNet->size() - 3);
|
Ordering continuous_ordering(graph.continuousKeys());
|
||||||
|
HybridBayesNet::shared_ptr bayesNet;
|
||||||
|
HybridGaussianFactorGraph::shared_ptr discreteGraph;
|
||||||
|
std::tie(bayesNet, discreteGraph) =
|
||||||
|
graph.eliminatePartialSequential(continuous_ordering);
|
||||||
|
|
||||||
|
// Discrete elimination
|
||||||
|
Ordering discrete_ordering(graph.discreteKeys());
|
||||||
|
auto discreteBayesNet = discreteGraph->eliminateSequential(discrete_ordering);
|
||||||
|
|
||||||
|
// Add the discrete conditionals to make it a full bayes net.
|
||||||
|
for (auto discrete_conditional : *discreteBayesNet) {
|
||||||
|
bayesNet->add(discrete_conditional);
|
||||||
|
}
|
||||||
|
auto discreteConditional = discreteBayesNet->atDiscrete(0);
|
||||||
|
|
||||||
HybridValues hybrid_values = bayesNet->optimize();
|
HybridValues hybrid_values = bayesNet->optimize();
|
||||||
|
|
||||||
|
@ -310,7 +344,7 @@ TEST(HybridEstimation, ProbabilityMultifrontal) {
|
||||||
Ordering ordering = getOrdering(graph, HybridGaussianFactorGraph());
|
Ordering ordering = getOrdering(graph, HybridGaussianFactorGraph());
|
||||||
|
|
||||||
// Get the tree of unnormalized probabilities for each mode sequence.
|
// Get the tree of unnormalized probabilities for each mode sequence.
|
||||||
AlgebraicDecisionTree<Key> expected_probPrimeTree = probPrimeTree(graph);
|
AlgebraicDecisionTree<Key> expected_probPrimeTree = getProbPrimeTree(graph);
|
||||||
|
|
||||||
// Eliminate continuous
|
// Eliminate continuous
|
||||||
Ordering continuous_ordering(graph.continuousKeys());
|
Ordering continuous_ordering(graph.continuousKeys());
|
||||||
|
@ -326,8 +360,7 @@ TEST(HybridEstimation, ProbabilityMultifrontal) {
|
||||||
DiscreteKeys discrete_keys = last_conditional->discreteKeys();
|
DiscreteKeys discrete_keys = last_conditional->discreteKeys();
|
||||||
|
|
||||||
Ordering discrete(graph.discreteKeys());
|
Ordering discrete(graph.discreteKeys());
|
||||||
auto discreteBayesTree =
|
auto discreteBayesTree = discreteGraph->eliminateMultifrontal(discrete);
|
||||||
discreteGraph->BaseEliminateable::eliminateMultifrontal(discrete);
|
|
||||||
|
|
||||||
EXPECT_LONGS_EQUAL(1, discreteBayesTree->size());
|
EXPECT_LONGS_EQUAL(1, discreteBayesTree->size());
|
||||||
// DiscreteBayesTree should have only 1 clique
|
// DiscreteBayesTree should have only 1 clique
|
||||||
|
@ -345,8 +378,8 @@ TEST(HybridEstimation, ProbabilityMultifrontal) {
|
||||||
discreteBayesTree->addClique(clique, discrete_clique);
|
discreteBayesTree->addClique(clique, discrete_clique);
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
// Remove the clique from the children of the parents since it will get
|
// Remove the clique from the children of the parents since
|
||||||
// added again in addClique.
|
// it will get added again in addClique.
|
||||||
auto clique_it = std::find(clique->parent()->children.begin(),
|
auto clique_it = std::find(clique->parent()->children.begin(),
|
||||||
clique->parent()->children.end(), clique);
|
clique->parent()->children.end(), clique);
|
||||||
clique->parent()->children.erase(clique_it);
|
clique->parent()->children.erase(clique_it);
|
||||||
|
@ -392,7 +425,7 @@ static HybridNonlinearFactorGraph createHybridNonlinearFactorGraph() {
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************
|
/*********************************************************************************
|
||||||
// Create a hybrid nonlinear factor graph f(x0, x1, m0; z0, z1)
|
// Create a hybrid linear factor graph f(x0, x1, m0; z0, z1)
|
||||||
********************************************************************************/
|
********************************************************************************/
|
||||||
static HybridGaussianFactorGraph::shared_ptr createHybridGaussianFactorGraph() {
|
static HybridGaussianFactorGraph::shared_ptr createHybridGaussianFactorGraph() {
|
||||||
HybridNonlinearFactorGraph nfg = createHybridNonlinearFactorGraph();
|
HybridNonlinearFactorGraph nfg = createHybridNonlinearFactorGraph();
|
||||||
|
|
|
@ -81,14 +81,16 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
self.assertEqual(hv.atDiscrete(C(0)), 1)
|
self.assertEqual(hv.atDiscrete(C(0)), 1)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def tiny(num_measurements: int = 1):
|
def tiny(num_measurements: int = 1) -> gtsam.HybridBayesNet:
|
||||||
"""Create a tiny two variable hybrid model."""
|
"""
|
||||||
|
Create a tiny two variable hybrid model which represents
|
||||||
|
the generative probability P(z, x, n) = P(z | x, n)P(x)P(n).
|
||||||
|
"""
|
||||||
# Create hybrid Bayes net.
|
# Create hybrid Bayes net.
|
||||||
bayesNet = gtsam.HybridBayesNet()
|
bayesNet = gtsam.HybridBayesNet()
|
||||||
|
|
||||||
# Create mode key: 0 is low-noise, 1 is high-noise.
|
# Create mode key: 0 is low-noise, 1 is high-noise.
|
||||||
modeKey = M(0)
|
mode = (M(0), 2)
|
||||||
mode = (modeKey, 2)
|
|
||||||
|
|
||||||
# Create Gaussian mixture Z(0) = X(0) + noise for each measurement.
|
# Create Gaussian mixture Z(0) = X(0) + noise for each measurement.
|
||||||
I = np.eye(1)
|
I = np.eye(1)
|
||||||
|
@ -141,14 +143,22 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
return bayesNet.evaluate(sample) / fg.probPrime(
|
return bayesNet.evaluate(sample) / fg.probPrime(
|
||||||
continuous, sample.discrete())
|
continuous, sample.discrete())
|
||||||
|
|
||||||
def test_tiny2(self):
|
def test_ratio(self):
|
||||||
"""Test a tiny two variable hybrid model, with 2 measurements."""
|
"""
|
||||||
# Create the Bayes net and sample from it.
|
Given a tiny two variable hybrid model, with 2 measurements,
|
||||||
|
test the ratio of the bayes net model representing P(z, x, n)=P(z|x, n)P(x)P(n)
|
||||||
|
and the factor graph P(x, n | z)=P(x | n, z)P(n|z),
|
||||||
|
both of which represent the same posterior.
|
||||||
|
"""
|
||||||
|
# Create the Bayes net representing the generative model P(z, x, n)=P(z|x, n)P(x)P(n)
|
||||||
bayesNet = self.tiny(num_measurements=2)
|
bayesNet = self.tiny(num_measurements=2)
|
||||||
sample = bayesNet.sample()
|
# Sample from the Bayes net.
|
||||||
|
sample: gtsam.HybridValues = bayesNet.sample()
|
||||||
# print(sample)
|
# print(sample)
|
||||||
|
|
||||||
# Create a factor graph from the Bayes net with sampled measurements.
|
# Create a factor graph from the Bayes net with sampled measurements.
|
||||||
|
# The factor graph is `P(x)P(n) ϕ(x, n; z1) ϕ(x, n; z2)`
|
||||||
|
# and thus represents the same joint probability as the Bayes net.
|
||||||
fg = HybridGaussianFactorGraph()
|
fg = HybridGaussianFactorGraph()
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
conditional = bayesNet.atMixture(i)
|
conditional = bayesNet.atMixture(i)
|
||||||
|
|
Loading…
Reference in New Issue