gtsam/gtsam/hybrid/HybridBayesTree.cpp

226 lines
7.9 KiB
C++

/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file HybridBayesTree.cpp
* @brief Hybrid Bayes Tree, the result of eliminating a
* HybridJunctionTree
* @date Mar 11, 2022
* @author Fan Jiang, Varun Agrawal
*/
#include <gtsam/base/treeTraversal-inst.h>
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/inference/BayesTree-inst.h>
#include <gtsam/inference/BayesTreeCliqueBase-inst.h>
#include <gtsam/linear/GaussianJunctionTree.h>
namespace gtsam {
// Instantiate base class
template class BayesTreeCliqueBase<HybridBayesTreeClique,
HybridGaussianFactorGraph>;
template class BayesTree<HybridBayesTreeClique>;
/* ************************************************************************* */
bool HybridBayesTree::equals(const This& other, double tol) const {
return Base::equals(other, tol);
}
/* ************************************************************************* */
HybridValues HybridBayesTree::optimize() const {
DiscreteBayesNet dbn;
DiscreteValues mpe;
auto root = roots_.at(0);
// Access the clique and get the underlying hybrid conditional
HybridConditional::shared_ptr root_conditional = root->conditional();
// The root should be discrete only, we compute the MPE
if (root_conditional->isDiscrete()) {
dbn.push_back(root_conditional->asDiscrete());
mpe = DiscreteFactorGraph(dbn).optimize();
} else {
throw std::runtime_error(
"HybridBayesTree root is not discrete-only. Please check elimination "
"ordering or use continuous factor graph.");
}
VectorValues values = optimize(mpe);
return HybridValues(values, mpe);
}
/* ************************************************************************* */
/**
* @brief Helper class for Depth First Forest traversal on the HybridBayesTree.
*
* When traversing the tree, the pre-order visitor will receive an instance of
* this class with the parent clique data.
*/
struct HybridAssignmentData {
const DiscreteValues assignment_;
GaussianBayesTree::sharedNode parentClique_;
// The gaussian bayes tree that will be recursively created.
GaussianBayesTree* gaussianbayesTree_;
// Flag indicating if all the nodes are valid. Used in optimize().
bool valid_;
/**
* @brief Construct a new Hybrid Assignment Data object.
*
* @param assignment The MPE assignment for the optimal Gaussian cliques.
* @param parentClique The clique from the parent node of the current node.
* @param gbt The Gaussian Bayes Tree being generated during tree traversal.
*/
HybridAssignmentData(const DiscreteValues& assignment,
const GaussianBayesTree::sharedNode& parentClique,
GaussianBayesTree* gbt, bool valid = true)
: assignment_(assignment),
parentClique_(parentClique),
gaussianbayesTree_(gbt),
valid_(valid) {}
bool isValid() const { return valid_; }
/**
* @brief A function used during tree traversal that operates on each node
* before visiting the node's children.
*
* @param node The current node being visited.
* @param parentData The HybridAssignmentData from the parent node.
* @return HybridAssignmentData which is passed to the children.
*/
static HybridAssignmentData AssignmentPreOrderVisitor(
const HybridBayesTree::sharedNode& node,
HybridAssignmentData& parentData) {
// Extract the gaussian conditional from the Hybrid clique
HybridConditional::shared_ptr hybrid_conditional = node->conditional();
GaussianConditional::shared_ptr conditional;
if (hybrid_conditional->isHybrid()) {
conditional = (*hybrid_conditional->asMixture())(parentData.assignment_);
} else if (hybrid_conditional->isContinuous()) {
conditional = hybrid_conditional->asGaussian();
} else {
// Discrete only conditional, so we set to empty gaussian conditional
conditional = std::make_shared<GaussianConditional>();
}
GaussianBayesTree::sharedNode clique;
if (conditional) {
// Create the GaussianClique for the current node
clique = std::make_shared<GaussianBayesTree::Node>(conditional);
// Add the current clique to the GaussianBayesTree.
parentData.gaussianbayesTree_->addClique(clique,
parentData.parentClique_);
} else {
parentData.valid_ = false;
}
// Create new HybridAssignmentData where the current node is the parent
// This will be passed down to the children nodes
HybridAssignmentData data(parentData.assignment_, clique,
parentData.gaussianbayesTree_, parentData.valid_);
return data;
}
};
/* *************************************************************************
*/
GaussianBayesTree HybridBayesTree::choose(
const DiscreteValues& assignment) const {
GaussianBayesTree gbt;
HybridAssignmentData rootData(assignment, 0, &gbt);
{
treeTraversal::no_op visitorPost;
// Limits OpenMP threads since we're mixing TBB and OpenMP
TbbOpenMPMixedScope threadLimiter;
treeTraversal::DepthFirstForestParallel(
*this, rootData, HybridAssignmentData::AssignmentPreOrderVisitor,
visitorPost);
}
if (!rootData.isValid()) {
return GaussianBayesTree();
}
return gbt;
}
/* *************************************************************************
*/
VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
GaussianBayesTree gbt = this->choose(assignment);
// If empty GaussianBayesTree, means a clique is pruned hence invalid
if (gbt.size() == 0) {
return VectorValues();
}
VectorValues result = gbt.optimize();
// Return the optimized bayes net result.
return result;
}
/* ************************************************************************* */
void HybridBayesTree::prune(const size_t maxNrLeaves) {
auto decisionTree =
this->roots_.at(0)->conditional()->asDiscrete();
DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves);
decisionTree->root_ = prunedDecisionTree.root_;
/// Helper struct for pruning the hybrid bayes tree.
struct HybridPrunerData {
/// The discrete decision tree after pruning.
DecisionTreeFactor prunedDecisionTree;
HybridPrunerData(const DecisionTreeFactor& prunedDecisionTree,
const HybridBayesTree::sharedNode& parentClique)
: prunedDecisionTree(prunedDecisionTree) {}
/**
* @brief A function used during tree traversal that operates on each node
* before visiting the node's children.
*
* @param node The current node being visited.
* @param parentData The data from the parent node.
* @return HybridPrunerData which is passed to the children.
*/
static HybridPrunerData AssignmentPreOrderVisitor(
const HybridBayesTree::sharedNode& clique,
HybridPrunerData& parentData) {
// Get the conditional
HybridConditional::shared_ptr conditional = clique->conditional();
// If conditional is hybrid, we prune it.
if (conditional->isHybrid()) {
auto gaussianMixture = conditional->asMixture();
gaussianMixture->prune(parentData.prunedDecisionTree);
}
return parentData;
}
};
HybridPrunerData rootData(prunedDecisionTree, 0);
{
treeTraversal::no_op visitorPost;
// Limits OpenMP threads since we're mixing TBB and OpenMP
TbbOpenMPMixedScope threadLimiter;
treeTraversal::DepthFirstForestParallel(
*this, rootData, HybridPrunerData::AssignmentPreOrderVisitor,
visitorPost);
}
}
} // namespace gtsam