refactor HybridBayesTree::optimize
parent
e16460358f
commit
773af1ed44
|
@ -221,6 +221,6 @@ void PrintForest(const FOREST& forest, std::string str,
|
||||||
PrintForestVisitorPre visitor(keyFormatter);
|
PrintForestVisitorPre visitor(keyFormatter);
|
||||||
DepthFirstForest(forest, str, visitor);
|
DepthFirstForest(forest, str, visitor);
|
||||||
}
|
}
|
||||||
}
|
} // namespace treeTraversal
|
||||||
|
|
||||||
}
|
} // namespace gtsam
|
||||||
|
|
|
@ -24,6 +24,7 @@
|
||||||
#include <gtsam/hybrid/HybridBayesTree.h>
|
#include <gtsam/hybrid/HybridBayesTree.h>
|
||||||
#include <gtsam/inference/BayesTree-inst.h>
|
#include <gtsam/inference/BayesTree-inst.h>
|
||||||
#include <gtsam/inference/BayesTreeCliqueBase-inst.h>
|
#include <gtsam/inference/BayesTreeCliqueBase-inst.h>
|
||||||
|
#include <gtsam/linear/GaussianJunctionTree.h>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
@ -39,95 +40,108 @@ bool HybridBayesTree::equals(const This& other, double tol) const {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
HybridValues HybridBayesTree::optimize() const {
|
HybridValues HybridBayesTree::optimize() const {
|
||||||
HybridBayesNet hbn;
|
|
||||||
DiscreteBayesNet dbn;
|
DiscreteBayesNet dbn;
|
||||||
|
DiscreteValues mpe;
|
||||||
|
|
||||||
KeyVector added_keys;
|
auto root = roots_.at(0);
|
||||||
|
|
||||||
// Iterate over all the nodes in the BayesTree
|
|
||||||
for (auto&& node : nodes()) {
|
|
||||||
// Check if conditional being added is already in the Bayes net.
|
|
||||||
if (std::find(added_keys.begin(), added_keys.end(), node.first) ==
|
|
||||||
added_keys.end()) {
|
|
||||||
// Access the clique and get the underlying hybrid conditional
|
// Access the clique and get the underlying hybrid conditional
|
||||||
HybridBayesTreeClique::shared_ptr clique = node.second;
|
HybridConditional::shared_ptr root_conditional = root->conditional();
|
||||||
HybridConditional::shared_ptr conditional = clique->conditional();
|
|
||||||
|
|
||||||
// Record the key being added
|
|
||||||
added_keys.insert(added_keys.end(), conditional->frontals().begin(),
|
|
||||||
conditional->frontals().end());
|
|
||||||
|
|
||||||
if (conditional->isDiscrete()) {
|
|
||||||
// If discrete, we use it to compute the MPE
|
|
||||||
dbn.push_back(conditional->asDiscreteConditional());
|
|
||||||
|
|
||||||
|
// The root should be discrete only, we compute the MPE
|
||||||
|
if (root_conditional->isDiscrete()) {
|
||||||
|
dbn.push_back(root_conditional->asDiscreteConditional());
|
||||||
|
mpe = DiscreteFactorGraph(dbn).optimize();
|
||||||
} else {
|
} else {
|
||||||
// Else conditional is hybrid or continuous-only,
|
throw std::runtime_error(
|
||||||
// so we directly add it to the Hybrid Bayes net.
|
"HybridBayesTree root is not discrete-only. Please check elimination "
|
||||||
hbn.push_back(conditional);
|
"ordering or use continuous factor graph.");
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
// Get the MPE
|
|
||||||
DiscreteValues mpe = DiscreteFactorGraph(dbn).optimize();
|
|
||||||
// Given the MPE, compute the optimal continuous values.
|
|
||||||
GaussianBayesNet gbn = hbn.choose(mpe);
|
|
||||||
|
|
||||||
// If TBB is enabled, the bayes net order gets reversed,
|
VectorValues values = optimize(mpe);
|
||||||
// so we pre-reverse it
|
return HybridValues(mpe, values);
|
||||||
#ifdef GTSAM_USE_TBB
|
|
||||||
auto reversed = boost::adaptors::reverse(gbn);
|
|
||||||
gbn = GaussianBayesNet(reversed.begin(), reversed.end());
|
|
||||||
#endif
|
|
||||||
|
|
||||||
return HybridValues(mpe, gbn.optimize());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
/**
|
||||||
|
* @brief Helper class for Depth First Forest traversal on the HybridBayesTree.
|
||||||
|
*
|
||||||
|
* When traversing the tree, the pre-order visitor will receive an instance of
|
||||||
|
* this class with the parent clique data.
|
||||||
|
*/
|
||||||
|
struct HybridAssignmentData {
|
||||||
|
const DiscreteValues assignment_;
|
||||||
|
GaussianBayesTree::sharedNode parentClique_;
|
||||||
|
// The gaussian bayes tree that will be recursively created.
|
||||||
|
GaussianBayesTree* gaussianbayesTree_;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Construct a new Hybrid Assignment Data object.
|
||||||
|
*
|
||||||
|
* @param assignment The MPE assignment for the optimal Gaussian cliques.
|
||||||
|
* @param parentClique The clique from the parent node of the current node.
|
||||||
|
* @param gbt The Gaussian Bayes Tree being generated during tree traversal.
|
||||||
|
*/
|
||||||
|
HybridAssignmentData(const DiscreteValues& assignment,
|
||||||
|
const GaussianBayesTree::sharedNode& parentClique,
|
||||||
|
GaussianBayesTree* gbt)
|
||||||
|
: assignment_(assignment),
|
||||||
|
parentClique_(parentClique),
|
||||||
|
gaussianbayesTree_(gbt) {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief A function used during tree traversal that operators on each node
|
||||||
|
* before visiting the node's children.
|
||||||
|
*
|
||||||
|
* @param node The current node being visited.
|
||||||
|
* @param parentData The HybridAssignmentData from the parent node.
|
||||||
|
* @return HybridAssignmentData
|
||||||
|
*/
|
||||||
|
static HybridAssignmentData AssignmentPreOrderVisitor(
|
||||||
|
const HybridBayesTree::sharedNode& node,
|
||||||
|
HybridAssignmentData& parentData) {
|
||||||
|
// Extract the gaussian conditional from the Hybrid clique
|
||||||
|
HybridConditional::shared_ptr hybrid_conditional = node->conditional();
|
||||||
|
GaussianConditional::shared_ptr conditional;
|
||||||
|
if (hybrid_conditional->isHybrid()) {
|
||||||
|
conditional = (*hybrid_conditional->asMixture())(parentData.assignment_);
|
||||||
|
} else if (hybrid_conditional->isContinuous()) {
|
||||||
|
conditional = hybrid_conditional->asGaussian();
|
||||||
|
} else {
|
||||||
|
// Discrete only conditional, so we set to empty gaussian conditional
|
||||||
|
conditional = boost::make_shared<GaussianConditional>();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the GaussianClique for the current node
|
||||||
|
auto clique = boost::make_shared<GaussianBayesTree::Node>(conditional);
|
||||||
|
// Add the current clique to the GaussianBayesTree.
|
||||||
|
parentData.gaussianbayesTree_->addClique(clique, parentData.parentClique_);
|
||||||
|
|
||||||
|
// Create new HybridAssignmentData where the current node is the parent
|
||||||
|
// This will be passed down to the children nodes
|
||||||
|
HybridAssignmentData data(parentData.assignment_, clique,
|
||||||
|
parentData.gaussianbayesTree_);
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/* *************************************************************************
|
||||||
|
*/
|
||||||
VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
||||||
GaussianBayesNet gbn;
|
GaussianBayesTree gbt;
|
||||||
|
HybridAssignmentData rootData(assignment, 0, &gbt);
|
||||||
KeyVector added_keys;
|
{
|
||||||
|
treeTraversal::no_op visitorPost;
|
||||||
// Iterate over all the nodes in the BayesTree
|
// Limits OpenMP threads since we're mixing TBB and OpenMP
|
||||||
for (auto&& node : nodes()) {
|
TbbOpenMPMixedScope threadLimiter;
|
||||||
// Check if conditional being added is already in the Bayes net.
|
treeTraversal::DepthFirstForestParallel(
|
||||||
if (std::find(added_keys.begin(), added_keys.end(), node.first) ==
|
*this, rootData, HybridAssignmentData::AssignmentPreOrderVisitor,
|
||||||
added_keys.end()) {
|
visitorPost);
|
||||||
// Access the clique and get the underlying hybrid conditional
|
|
||||||
HybridBayesTreeClique::shared_ptr clique = node.second;
|
|
||||||
HybridConditional::shared_ptr conditional = clique->conditional();
|
|
||||||
|
|
||||||
// Record the key being added
|
|
||||||
added_keys.insert(added_keys.end(), conditional->frontals().begin(),
|
|
||||||
conditional->frontals().end());
|
|
||||||
|
|
||||||
// If conditional is hybrid (and not discrete-only), we get the Gaussian
|
|
||||||
// Conditional corresponding to the assignment and add it to the Gaussian
|
|
||||||
// Bayes Net.
|
|
||||||
if (conditional->isHybrid()) {
|
|
||||||
auto gm = conditional->asMixture();
|
|
||||||
GaussianConditional::shared_ptr gaussian_conditional =
|
|
||||||
(*gm)(assignment);
|
|
||||||
|
|
||||||
gbn.push_back(gaussian_conditional);
|
|
||||||
|
|
||||||
} else if (conditional->isContinuous()) {
|
|
||||||
// If conditional is Gaussian, we simply add it to the Bayes net.
|
|
||||||
gbn.push_back(conditional->asGaussian());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// If TBB is enabled, the bayes net order gets reversed,
|
VectorValues result = gbt.optimize();
|
||||||
// so we pre-reverse it
|
|
||||||
#ifdef GTSAM_USE_TBB
|
|
||||||
auto reversed = boost::adaptors::reverse(gbn);
|
|
||||||
gbn = GaussianBayesNet(reversed.begin(), reversed.end());
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// Return the optimized bayes net.
|
// Return the optimized bayes net result.
|
||||||
return gbn.optimize();
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
Loading…
Reference in New Issue