refactor HybridBayesTree::optimize

release/4.3a0
Varun Agrawal 2022-09-02 15:04:28 -04:00
parent e16460358f
commit 773af1ed44
2 changed files with 91 additions and 77 deletions

View File

@ -221,6 +221,6 @@ void PrintForest(const FOREST& forest, std::string str,
PrintForestVisitorPre visitor(keyFormatter); PrintForestVisitorPre visitor(keyFormatter);
DepthFirstForest(forest, str, visitor); DepthFirstForest(forest, str, visitor);
} }
} } // namespace treeTraversal
} } // namespace gtsam

View File

@ -24,6 +24,7 @@
#include <gtsam/hybrid/HybridBayesTree.h> #include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/inference/BayesTree-inst.h> #include <gtsam/inference/BayesTree-inst.h>
#include <gtsam/inference/BayesTreeCliqueBase-inst.h> #include <gtsam/inference/BayesTreeCliqueBase-inst.h>
#include <gtsam/linear/GaussianJunctionTree.h>
namespace gtsam { namespace gtsam {
@ -39,95 +40,108 @@ bool HybridBayesTree::equals(const This& other, double tol) const {
/* ************************************************************************* */ /* ************************************************************************* */
HybridValues HybridBayesTree::optimize() const { HybridValues HybridBayesTree::optimize() const {
HybridBayesNet hbn;
DiscreteBayesNet dbn; DiscreteBayesNet dbn;
DiscreteValues mpe;
KeyVector added_keys; auto root = roots_.at(0);
// Iterate over all the nodes in the BayesTree
for (auto&& node : nodes()) {
// Check if conditional being added is already in the Bayes net.
if (std::find(added_keys.begin(), added_keys.end(), node.first) ==
added_keys.end()) {
// Access the clique and get the underlying hybrid conditional // Access the clique and get the underlying hybrid conditional
HybridBayesTreeClique::shared_ptr clique = node.second; HybridConditional::shared_ptr root_conditional = root->conditional();
HybridConditional::shared_ptr conditional = clique->conditional();
// Record the key being added
added_keys.insert(added_keys.end(), conditional->frontals().begin(),
conditional->frontals().end());
if (conditional->isDiscrete()) {
// If discrete, we use it to compute the MPE
dbn.push_back(conditional->asDiscreteConditional());
// The root should be discrete only, we compute the MPE
if (root_conditional->isDiscrete()) {
dbn.push_back(root_conditional->asDiscreteConditional());
mpe = DiscreteFactorGraph(dbn).optimize();
} else { } else {
// Else conditional is hybrid or continuous-only, throw std::runtime_error(
// so we directly add it to the Hybrid Bayes net. "HybridBayesTree root is not discrete-only. Please check elimination "
hbn.push_back(conditional); "ordering or use continuous factor graph.");
} }
}
}
// Get the MPE
DiscreteValues mpe = DiscreteFactorGraph(dbn).optimize();
// Given the MPE, compute the optimal continuous values.
GaussianBayesNet gbn = hbn.choose(mpe);
// If TBB is enabled, the bayes net order gets reversed, VectorValues values = optimize(mpe);
// so we pre-reverse it return HybridValues(mpe, values);
#ifdef GTSAM_USE_TBB
auto reversed = boost::adaptors::reverse(gbn);
gbn = GaussianBayesNet(reversed.begin(), reversed.end());
#endif
return HybridValues(mpe, gbn.optimize());
} }
/* ************************************************************************* */ /* ************************************************************************* */
/**
* @brief Helper class for Depth First Forest traversal on the HybridBayesTree.
*
* When traversing the tree, the pre-order visitor will receive an instance of
* this class with the parent clique data.
*/
struct HybridAssignmentData {
const DiscreteValues assignment_;
GaussianBayesTree::sharedNode parentClique_;
// The gaussian bayes tree that will be recursively created.
GaussianBayesTree* gaussianbayesTree_;
/**
* @brief Construct a new Hybrid Assignment Data object.
*
* @param assignment The MPE assignment for the optimal Gaussian cliques.
* @param parentClique The clique from the parent node of the current node.
* @param gbt The Gaussian Bayes Tree being generated during tree traversal.
*/
HybridAssignmentData(const DiscreteValues& assignment,
const GaussianBayesTree::sharedNode& parentClique,
GaussianBayesTree* gbt)
: assignment_(assignment),
parentClique_(parentClique),
gaussianbayesTree_(gbt) {}
/**
* @brief A function used during tree traversal that operators on each node
* before visiting the node's children.
*
* @param node The current node being visited.
* @param parentData The HybridAssignmentData from the parent node.
* @return HybridAssignmentData
*/
static HybridAssignmentData AssignmentPreOrderVisitor(
const HybridBayesTree::sharedNode& node,
HybridAssignmentData& parentData) {
// Extract the gaussian conditional from the Hybrid clique
HybridConditional::shared_ptr hybrid_conditional = node->conditional();
GaussianConditional::shared_ptr conditional;
if (hybrid_conditional->isHybrid()) {
conditional = (*hybrid_conditional->asMixture())(parentData.assignment_);
} else if (hybrid_conditional->isContinuous()) {
conditional = hybrid_conditional->asGaussian();
} else {
// Discrete only conditional, so we set to empty gaussian conditional
conditional = boost::make_shared<GaussianConditional>();
}
// Create the GaussianClique for the current node
auto clique = boost::make_shared<GaussianBayesTree::Node>(conditional);
// Add the current clique to the GaussianBayesTree.
parentData.gaussianbayesTree_->addClique(clique, parentData.parentClique_);
// Create new HybridAssignmentData where the current node is the parent
// This will be passed down to the children nodes
HybridAssignmentData data(parentData.assignment_, clique,
parentData.gaussianbayesTree_);
return data;
}
};
/* *************************************************************************
*/
VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const { VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
GaussianBayesNet gbn; GaussianBayesTree gbt;
HybridAssignmentData rootData(assignment, 0, &gbt);
KeyVector added_keys; {
treeTraversal::no_op visitorPost;
// Iterate over all the nodes in the BayesTree // Limits OpenMP threads since we're mixing TBB and OpenMP
for (auto&& node : nodes()) { TbbOpenMPMixedScope threadLimiter;
// Check if conditional being added is already in the Bayes net. treeTraversal::DepthFirstForestParallel(
if (std::find(added_keys.begin(), added_keys.end(), node.first) == *this, rootData, HybridAssignmentData::AssignmentPreOrderVisitor,
added_keys.end()) { visitorPost);
// Access the clique and get the underlying hybrid conditional
HybridBayesTreeClique::shared_ptr clique = node.second;
HybridConditional::shared_ptr conditional = clique->conditional();
// Record the key being added
added_keys.insert(added_keys.end(), conditional->frontals().begin(),
conditional->frontals().end());
// If conditional is hybrid (and not discrete-only), we get the Gaussian
// Conditional corresponding to the assignment and add it to the Gaussian
// Bayes Net.
if (conditional->isHybrid()) {
auto gm = conditional->asMixture();
GaussianConditional::shared_ptr gaussian_conditional =
(*gm)(assignment);
gbn.push_back(gaussian_conditional);
} else if (conditional->isContinuous()) {
// If conditional is Gaussian, we simply add it to the Bayes net.
gbn.push_back(conditional->asGaussian());
}
}
} }
// If TBB is enabled, the bayes net order gets reversed, VectorValues result = gbt.optimize();
// so we pre-reverse it
#ifdef GTSAM_USE_TBB
auto reversed = boost::adaptors::reverse(gbn);
gbn = GaussianBayesNet(reversed.begin(), reversed.end());
#endif
// Return the optimized bayes net. // Return the optimized bayes net result.
return gbn.optimize(); return result;
} }
} // namespace gtsam } // namespace gtsam