add choose method to HybridBayesTree
parent
1e17dd3655
commit
5da56c1393
|
@ -138,7 +138,8 @@ struct HybridAssignmentData {
|
||||||
|
|
||||||
/* *************************************************************************
|
/* *************************************************************************
|
||||||
*/
|
*/
|
||||||
VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
GaussianBayesTree HybridBayesTree::choose(
|
||||||
|
const DiscreteValues& assignment) const {
|
||||||
GaussianBayesTree gbt;
|
GaussianBayesTree gbt;
|
||||||
HybridAssignmentData rootData(assignment, 0, &gbt);
|
HybridAssignmentData rootData(assignment, 0, &gbt);
|
||||||
{
|
{
|
||||||
|
@ -151,6 +152,17 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!rootData.isValid()) {
|
if (!rootData.isValid()) {
|
||||||
|
return GaussianBayesTree();
|
||||||
|
}
|
||||||
|
return gbt;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* *************************************************************************
|
||||||
|
*/
|
||||||
|
VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
||||||
|
GaussianBayesTree gbt = this->choose(assignment);
|
||||||
|
// If empty GaussianBayesTree, means a clique is pruned hence invalid
|
||||||
|
if (gbt.size() == 0) {
|
||||||
return VectorValues();
|
return VectorValues();
|
||||||
}
|
}
|
||||||
VectorValues result = gbt.optimize();
|
VectorValues result = gbt.optimize();
|
||||||
|
|
|
@ -24,6 +24,7 @@
|
||||||
#include <gtsam/inference/BayesTree.h>
|
#include <gtsam/inference/BayesTree.h>
|
||||||
#include <gtsam/inference/BayesTreeCliqueBase.h>
|
#include <gtsam/inference/BayesTreeCliqueBase.h>
|
||||||
#include <gtsam/inference/Conditional.h>
|
#include <gtsam/inference/Conditional.h>
|
||||||
|
#include <gtsam/linear/GaussianBayesTree.h>
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
@ -76,6 +77,15 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
|
||||||
/** Check equality */
|
/** Check equality */
|
||||||
bool equals(const This& other, double tol = 1e-9) const;
|
bool equals(const This& other, double tol = 1e-9) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Get the Gaussian Bayes Tree which corresponds to a specific discrete
|
||||||
|
* value assignment.
|
||||||
|
*
|
||||||
|
* @param assignment The discrete value assignment for the discrete keys.
|
||||||
|
* @return GaussianBayesTree
|
||||||
|
*/
|
||||||
|
GaussianBayesTree choose(const DiscreteValues& assignment) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Optimize the hybrid Bayes tree by computing the MPE for the current
|
* @brief Optimize the hybrid Bayes tree by computing the MPE for the current
|
||||||
* set of discrete variables and using it to compute the best continuous
|
* set of discrete variables and using it to compute the best continuous
|
||||||
|
|
|
@ -169,6 +169,57 @@ TEST(HybridBayesTree, Optimize) {
|
||||||
EXPECT(assert_equal(expectedValues, delta.continuous()));
|
EXPECT(assert_equal(expectedValues, delta.continuous()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
// Test for choosing a GaussianBayesTree from a HybridBayesTree.
|
||||||
|
TEST(HybridBayesTree, Choose) {
|
||||||
|
Switching s(4);
|
||||||
|
|
||||||
|
HybridGaussianISAM isam;
|
||||||
|
HybridGaussianFactorGraph graph1;
|
||||||
|
|
||||||
|
// Add the 3 hybrid factors, x1-x2, x2-x3, x3-x4
|
||||||
|
for (size_t i = 1; i < 4; i++) {
|
||||||
|
graph1.push_back(s.linearizedFactorGraph.at(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the Gaussian factors, 1 prior on X(0),
|
||||||
|
// 3 measurements on X(2), X(3), X(4)
|
||||||
|
graph1.push_back(s.linearizedFactorGraph.at(0));
|
||||||
|
for (size_t i = 4; i <= 6; i++) {
|
||||||
|
graph1.push_back(s.linearizedFactorGraph.at(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the discrete factors
|
||||||
|
for (size_t i = 7; i <= 9; i++) {
|
||||||
|
graph1.push_back(s.linearizedFactorGraph.at(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
isam.update(graph1);
|
||||||
|
|
||||||
|
DiscreteValues assignment;
|
||||||
|
assignment[M(0)] = 1;
|
||||||
|
assignment[M(1)] = 1;
|
||||||
|
assignment[M(2)] = 1;
|
||||||
|
|
||||||
|
GaussianBayesTree gbt = isam.choose(assignment);
|
||||||
|
|
||||||
|
Ordering ordering;
|
||||||
|
ordering += X(0);
|
||||||
|
ordering += X(1);
|
||||||
|
ordering += X(2);
|
||||||
|
ordering += X(3);
|
||||||
|
ordering += M(0);
|
||||||
|
ordering += M(1);
|
||||||
|
ordering += M(2);
|
||||||
|
|
||||||
|
//TODO(Varun) get segfault if ordering not provided
|
||||||
|
auto bayesTree = s.linearizedFactorGraph.eliminateMultifrontal(ordering);
|
||||||
|
|
||||||
|
auto expected_gbt = bayesTree->choose(assignment);
|
||||||
|
|
||||||
|
EXPECT(assert_equal(expected_gbt, gbt));
|
||||||
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
// Test HybridBayesTree serialization.
|
// Test HybridBayesTree serialization.
|
||||||
TEST(HybridBayesTree, Serialization) {
|
TEST(HybridBayesTree, Serialization) {
|
||||||
|
|
Loading…
Reference in New Issue