refactor HybridBayesTree::optimize
parent
e16460358f
commit
773af1ed44
|
@ -221,6 +221,6 @@ void PrintForest(const FOREST& forest, std::string str,
|
|||
PrintForestVisitorPre visitor(keyFormatter);
|
||||
DepthFirstForest(forest, str, visitor);
|
||||
}
|
||||
}
|
||||
} // namespace treeTraversal
|
||||
|
||||
}
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include <gtsam/hybrid/HybridBayesTree.h>
|
||||
#include <gtsam/inference/BayesTree-inst.h>
|
||||
#include <gtsam/inference/BayesTreeCliqueBase-inst.h>
|
||||
#include <gtsam/linear/GaussianJunctionTree.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
|
@ -39,95 +40,108 @@ bool HybridBayesTree::equals(const This& other, double tol) const {
|
|||
|
||||
/* ************************************************************************* */
|
||||
HybridValues HybridBayesTree::optimize() const {
|
||||
HybridBayesNet hbn;
|
||||
DiscreteBayesNet dbn;
|
||||
DiscreteValues mpe;
|
||||
|
||||
KeyVector added_keys;
|
||||
auto root = roots_.at(0);
|
||||
// Access the clique and get the underlying hybrid conditional
|
||||
HybridConditional::shared_ptr root_conditional = root->conditional();
|
||||
|
||||
// Iterate over all the nodes in the BayesTree
|
||||
for (auto&& node : nodes()) {
|
||||
// Check if conditional being added is already in the Bayes net.
|
||||
if (std::find(added_keys.begin(), added_keys.end(), node.first) ==
|
||||
added_keys.end()) {
|
||||
// Access the clique and get the underlying hybrid conditional
|
||||
HybridBayesTreeClique::shared_ptr clique = node.second;
|
||||
HybridConditional::shared_ptr conditional = clique->conditional();
|
||||
|
||||
// Record the key being added
|
||||
added_keys.insert(added_keys.end(), conditional->frontals().begin(),
|
||||
conditional->frontals().end());
|
||||
|
||||
if (conditional->isDiscrete()) {
|
||||
// If discrete, we use it to compute the MPE
|
||||
dbn.push_back(conditional->asDiscreteConditional());
|
||||
|
||||
} else {
|
||||
// Else conditional is hybrid or continuous-only,
|
||||
// so we directly add it to the Hybrid Bayes net.
|
||||
hbn.push_back(conditional);
|
||||
}
|
||||
}
|
||||
// The root should be discrete only, we compute the MPE
|
||||
if (root_conditional->isDiscrete()) {
|
||||
dbn.push_back(root_conditional->asDiscreteConditional());
|
||||
mpe = DiscreteFactorGraph(dbn).optimize();
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"HybridBayesTree root is not discrete-only. Please check elimination "
|
||||
"ordering or use continuous factor graph.");
|
||||
}
|
||||
// Get the MPE
|
||||
DiscreteValues mpe = DiscreteFactorGraph(dbn).optimize();
|
||||
// Given the MPE, compute the optimal continuous values.
|
||||
GaussianBayesNet gbn = hbn.choose(mpe);
|
||||
|
||||
// If TBB is enabled, the bayes net order gets reversed,
|
||||
// so we pre-reverse it
|
||||
#ifdef GTSAM_USE_TBB
|
||||
auto reversed = boost::adaptors::reverse(gbn);
|
||||
gbn = GaussianBayesNet(reversed.begin(), reversed.end());
|
||||
#endif
|
||||
|
||||
return HybridValues(mpe, gbn.optimize());
|
||||
VectorValues values = optimize(mpe);
|
||||
return HybridValues(mpe, values);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
||||
GaussianBayesNet gbn;
|
||||
/**
|
||||
* @brief Helper class for Depth First Forest traversal on the HybridBayesTree.
|
||||
*
|
||||
* When traversing the tree, the pre-order visitor will receive an instance of
|
||||
* this class with the parent clique data.
|
||||
*/
|
||||
struct HybridAssignmentData {
|
||||
const DiscreteValues assignment_;
|
||||
GaussianBayesTree::sharedNode parentClique_;
|
||||
// The gaussian bayes tree that will be recursively created.
|
||||
GaussianBayesTree* gaussianbayesTree_;
|
||||
|
||||
KeyVector added_keys;
|
||||
/**
|
||||
* @brief Construct a new Hybrid Assignment Data object.
|
||||
*
|
||||
* @param assignment The MPE assignment for the optimal Gaussian cliques.
|
||||
* @param parentClique The clique from the parent node of the current node.
|
||||
* @param gbt The Gaussian Bayes Tree being generated during tree traversal.
|
||||
*/
|
||||
HybridAssignmentData(const DiscreteValues& assignment,
|
||||
const GaussianBayesTree::sharedNode& parentClique,
|
||||
GaussianBayesTree* gbt)
|
||||
: assignment_(assignment),
|
||||
parentClique_(parentClique),
|
||||
gaussianbayesTree_(gbt) {}
|
||||
|
||||
// Iterate over all the nodes in the BayesTree
|
||||
for (auto&& node : nodes()) {
|
||||
// Check if conditional being added is already in the Bayes net.
|
||||
if (std::find(added_keys.begin(), added_keys.end(), node.first) ==
|
||||
added_keys.end()) {
|
||||
// Access the clique and get the underlying hybrid conditional
|
||||
HybridBayesTreeClique::shared_ptr clique = node.second;
|
||||
HybridConditional::shared_ptr conditional = clique->conditional();
|
||||
|
||||
// Record the key being added
|
||||
added_keys.insert(added_keys.end(), conditional->frontals().begin(),
|
||||
conditional->frontals().end());
|
||||
|
||||
// If conditional is hybrid (and not discrete-only), we get the Gaussian
|
||||
// Conditional corresponding to the assignment and add it to the Gaussian
|
||||
// Bayes Net.
|
||||
if (conditional->isHybrid()) {
|
||||
auto gm = conditional->asMixture();
|
||||
GaussianConditional::shared_ptr gaussian_conditional =
|
||||
(*gm)(assignment);
|
||||
|
||||
gbn.push_back(gaussian_conditional);
|
||||
|
||||
} else if (conditional->isContinuous()) {
|
||||
// If conditional is Gaussian, we simply add it to the Bayes net.
|
||||
gbn.push_back(conditional->asGaussian());
|
||||
}
|
||||
/**
|
||||
* @brief A function used during tree traversal that operators on each node
|
||||
* before visiting the node's children.
|
||||
*
|
||||
* @param node The current node being visited.
|
||||
* @param parentData The HybridAssignmentData from the parent node.
|
||||
* @return HybridAssignmentData
|
||||
*/
|
||||
static HybridAssignmentData AssignmentPreOrderVisitor(
|
||||
const HybridBayesTree::sharedNode& node,
|
||||
HybridAssignmentData& parentData) {
|
||||
// Extract the gaussian conditional from the Hybrid clique
|
||||
HybridConditional::shared_ptr hybrid_conditional = node->conditional();
|
||||
GaussianConditional::shared_ptr conditional;
|
||||
if (hybrid_conditional->isHybrid()) {
|
||||
conditional = (*hybrid_conditional->asMixture())(parentData.assignment_);
|
||||
} else if (hybrid_conditional->isContinuous()) {
|
||||
conditional = hybrid_conditional->asGaussian();
|
||||
} else {
|
||||
// Discrete only conditional, so we set to empty gaussian conditional
|
||||
conditional = boost::make_shared<GaussianConditional>();
|
||||
}
|
||||
|
||||
// Create the GaussianClique for the current node
|
||||
auto clique = boost::make_shared<GaussianBayesTree::Node>(conditional);
|
||||
// Add the current clique to the GaussianBayesTree.
|
||||
parentData.gaussianbayesTree_->addClique(clique, parentData.parentClique_);
|
||||
|
||||
// Create new HybridAssignmentData where the current node is the parent
|
||||
// This will be passed down to the children nodes
|
||||
HybridAssignmentData data(parentData.assignment_, clique,
|
||||
parentData.gaussianbayesTree_);
|
||||
return data;
|
||||
}
|
||||
};
|
||||
|
||||
/* *************************************************************************
|
||||
*/
|
||||
VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
||||
GaussianBayesTree gbt;
|
||||
HybridAssignmentData rootData(assignment, 0, &gbt);
|
||||
{
|
||||
treeTraversal::no_op visitorPost;
|
||||
// Limits OpenMP threads since we're mixing TBB and OpenMP
|
||||
TbbOpenMPMixedScope threadLimiter;
|
||||
treeTraversal::DepthFirstForestParallel(
|
||||
*this, rootData, HybridAssignmentData::AssignmentPreOrderVisitor,
|
||||
visitorPost);
|
||||
}
|
||||
|
||||
// If TBB is enabled, the bayes net order gets reversed,
|
||||
// so we pre-reverse it
|
||||
#ifdef GTSAM_USE_TBB
|
||||
auto reversed = boost::adaptors::reverse(gbn);
|
||||
gbn = GaussianBayesNet(reversed.begin(), reversed.end());
|
||||
#endif
|
||||
VectorValues result = gbt.optimize();
|
||||
|
||||
// Return the optimized bayes net.
|
||||
return gbn.optimize();
|
||||
// Return the optimized bayes net result.
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
Loading…
Reference in New Issue