add HybridBayesTree optimize method
parent
0edcfd4ff8
commit
5169b2ec30
|
@ -18,6 +18,8 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/base/treeTraversal-inst.h>
|
#include <gtsam/base/treeTraversal-inst.h>
|
||||||
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||||
#include <gtsam/hybrid/HybridBayesTree.h>
|
#include <gtsam/hybrid/HybridBayesTree.h>
|
||||||
#include <gtsam/inference/BayesTree-inst.h>
|
#include <gtsam/inference/BayesTree-inst.h>
|
||||||
|
@ -35,6 +37,42 @@ bool HybridBayesTree::equals(const This& other, double tol) const {
|
||||||
return Base::equals(other, tol);
|
return Base::equals(other, tol);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
HybridValues HybridBayesTree::optimize() const {
|
||||||
|
HybridBayesNet hbn;
|
||||||
|
DiscreteBayesNet dbn;
|
||||||
|
|
||||||
|
KeyVector added_keys;
|
||||||
|
|
||||||
|
// Iterate over all the nodes in the BayesTree
|
||||||
|
for (auto&& node : nodes()) {
|
||||||
|
// Check if conditional being added is already in the Bayes net.
|
||||||
|
if (std::find(added_keys.begin(), added_keys.end(), node.first) ==
|
||||||
|
added_keys.end()) {
|
||||||
|
// Access the clique and get the underlying hybrid conditional
|
||||||
|
HybridBayesTreeClique::shared_ptr clique = node.second;
|
||||||
|
HybridConditional::shared_ptr conditional = clique->conditional();
|
||||||
|
|
||||||
|
// Record the key being added
|
||||||
|
added_keys.insert(added_keys.end(), conditional->frontals().begin(),
|
||||||
|
conditional->frontals().end());
|
||||||
|
|
||||||
|
if (conditional->isHybrid()) {
|
||||||
|
// If conditional is hybrid, add it to a Hybrid Bayes net.
|
||||||
|
hbn.push_back(conditional);
|
||||||
|
} else if (conditional->isDiscrete()) {
|
||||||
|
// Else if discrete, we use it to compute the MPE
|
||||||
|
dbn.push_back(conditional->asDiscreteConditional());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Get the MPE
|
||||||
|
DiscreteValues mpe = DiscreteFactorGraph(dbn).optimize();
|
||||||
|
// Given the MPE, compute the optimal continuous values.
|
||||||
|
GaussianBayesNet gbn = hbn.choose(mpe);
|
||||||
|
return HybridValues(mpe, gbn.optimize());
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
||||||
GaussianBayesNet gbn;
|
GaussianBayesNet gbn;
|
||||||
|
@ -50,11 +88,9 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
||||||
HybridBayesTreeClique::shared_ptr clique = node.second;
|
HybridBayesTreeClique::shared_ptr clique = node.second;
|
||||||
HybridConditional::shared_ptr conditional = clique->conditional();
|
HybridConditional::shared_ptr conditional = clique->conditional();
|
||||||
|
|
||||||
KeyVector frontals(conditional->frontals().begin(),
|
|
||||||
conditional->frontals().end());
|
|
||||||
|
|
||||||
// Record the key being added
|
// Record the key being added
|
||||||
added_keys.insert(added_keys.end(), frontals.begin(), frontals.end());
|
added_keys.insert(added_keys.end(), conditional->frontals().begin(),
|
||||||
|
conditional->frontals().end());
|
||||||
|
|
||||||
// If conditional is hybrid (and not discrete-only), we get the Gaussian
|
// If conditional is hybrid (and not discrete-only), we get the Gaussian
|
||||||
// Conditional corresponding to the assignment and add it to the Gaussian
|
// Conditional corresponding to the assignment and add it to the Gaussian
|
||||||
|
|
|
@ -70,6 +70,15 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
|
||||||
/** Check equality */
|
/** Check equality */
|
||||||
bool equals(const This& other, double tol = 1e-9) const;
|
bool equals(const This& other, double tol = 1e-9) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Optimize the hybrid Bayes tree by computing the MPE for the current
|
||||||
|
* set of discrete variables and using it to compute the best continuous
|
||||||
|
* update delta.
|
||||||
|
*
|
||||||
|
* @return HybridValues
|
||||||
|
*/
|
||||||
|
HybridValues optimize() const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Recursively optimize the BayesTree to produce a vector solution.
|
* @brief Recursively optimize the BayesTree to produce a vector solution.
|
||||||
*
|
*
|
||||||
|
|
|
@ -125,11 +125,6 @@ TEST(HybridBayesNet, OptimizeAssignment) {
|
||||||
TEST(HybridBayesNet, Optimize) {
|
TEST(HybridBayesNet, Optimize) {
|
||||||
Switching s(4);
|
Switching s(4);
|
||||||
|
|
||||||
Ordering ordering;
|
|
||||||
for (auto&& kvp : s.linearizationPoint) {
|
|
||||||
ordering += kvp.key;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
|
Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
|
||||||
HybridBayesNet::shared_ptr hybridBayesNet =
|
HybridBayesNet::shared_ptr hybridBayesNet =
|
||||||
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
|
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
* @date August 2022
|
* @date August 2022
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
#include <gtsam/hybrid/HybridBayesTree.h>
|
#include <gtsam/hybrid/HybridBayesTree.h>
|
||||||
#include <gtsam/hybrid/HybridGaussianISAM.h>
|
#include <gtsam/hybrid/HybridGaussianISAM.h>
|
||||||
|
|
||||||
|
@ -31,8 +32,8 @@ using symbol_shorthand::M;
|
||||||
using symbol_shorthand::X;
|
using symbol_shorthand::X;
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
// Test for optimizing a HybridBayesTree.
|
// Test for optimizing a HybridBayesTree with a given assignment.
|
||||||
TEST(HybridBayesTree, Optimize) {
|
TEST(HybridBayesTree, OptimizeAssignment) {
|
||||||
Switching s(4);
|
Switching s(4);
|
||||||
|
|
||||||
HybridGaussianISAM isam;
|
HybridGaussianISAM isam;
|
||||||
|
@ -85,6 +86,58 @@ TEST(HybridBayesTree, Optimize) {
|
||||||
EXPECT(assert_equal(expected, delta));
|
EXPECT(assert_equal(expected, delta));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
// Test for optimizing a HybridBayesTree.
|
||||||
|
TEST(HybridBayesTree, Optimize) {
|
||||||
|
Switching s(4);
|
||||||
|
|
||||||
|
HybridGaussianISAM isam;
|
||||||
|
HybridGaussianFactorGraph graph1;
|
||||||
|
|
||||||
|
// Add the 3 hybrid factors, x1-x2, x2-x3, x3-x4
|
||||||
|
for (size_t i = 1; i < 4; i++) {
|
||||||
|
graph1.push_back(s.linearizedFactorGraph.at(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the Gaussian factors, 1 prior on X(1),
|
||||||
|
// 3 measurements on X(2), X(3), X(4)
|
||||||
|
graph1.push_back(s.linearizedFactorGraph.at(0));
|
||||||
|
for (size_t i = 4; i <= 6; i++) {
|
||||||
|
graph1.push_back(s.linearizedFactorGraph.at(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the discrete factors
|
||||||
|
for (size_t i = 7; i <= 9; i++) {
|
||||||
|
graph1.push_back(s.linearizedFactorGraph.at(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
isam.update(graph1);
|
||||||
|
|
||||||
|
HybridValues delta = isam.optimize();
|
||||||
|
|
||||||
|
// Create ordering.
|
||||||
|
Ordering ordering;
|
||||||
|
for (size_t k = 1; k <= s.K; k++) ordering += X(k);
|
||||||
|
|
||||||
|
HybridBayesNet::shared_ptr hybridBayesNet;
|
||||||
|
HybridGaussianFactorGraph::shared_ptr remainingFactorGraph;
|
||||||
|
std::tie(hybridBayesNet, remainingFactorGraph) =
|
||||||
|
s.linearizedFactorGraph.eliminatePartialSequential(ordering);
|
||||||
|
|
||||||
|
DiscreteFactorGraph dfg;
|
||||||
|
for (auto&& f : *remainingFactorGraph) {
|
||||||
|
auto factor = dynamic_pointer_cast<HybridDiscreteFactor>(f);
|
||||||
|
dfg.push_back(
|
||||||
|
boost::dynamic_pointer_cast<DecisionTreeFactor>(factor->inner()));
|
||||||
|
}
|
||||||
|
|
||||||
|
DiscreteValues expectedMPE = dfg.optimize();
|
||||||
|
VectorValues expectedValues = hybridBayesNet->optimize(expectedMPE);
|
||||||
|
|
||||||
|
EXPECT(assert_equal(expectedMPE, delta.discrete()));
|
||||||
|
EXPECT(assert_equal(expectedValues, delta.continuous()));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
Loading…
Reference in New Issue