Merge pull request #1300 from borglab/hybrid/improved-prune-2

release/4.3a0
Varun Agrawal 2022-10-04 10:01:13 -04:00 committed by GitHub
commit cae787a175
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 55 additions and 9 deletions

View File

@ -31,8 +31,32 @@ static std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
}
/* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(
const DecisionTreeFactor::shared_ptr &discreteFactor) const {
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
AlgebraicDecisionTree<Key> decisionTree;
// The canonical decision tree factor which will get the discrete conditionals
// added to it.
DecisionTreeFactor dtFactor;
for (size_t i = 0; i < this->size(); i++) {
HybridConditional::shared_ptr conditional = this->at(i);
if (conditional->isDiscrete()) {
// Convert to a DecisionTreeFactor and add it to the main factor.
DecisionTreeFactor f(*conditional->asDiscreteConditional());
dtFactor = dtFactor * f;
}
}
return boost::make_shared<DecisionTreeFactor>(dtFactor);
}
/* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
// Get the decision tree of only the discrete keys
auto discreteConditionals = this->discreteConditionals();
const DecisionTreeFactor::shared_ptr discreteFactor =
boost::make_shared<DecisionTreeFactor>(
discreteConditionals->prune(maxNrLeaves));
/* To Prune, we visitWith every leaf in the GaussianMixture.
* For each leaf, using the assignment we can check the discrete decision tree
* for 0.0 probability, then just set the leaf to a nullptr.

View File

@ -111,9 +111,17 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
*/
VectorValues optimize(const DiscreteValues &assignment) const;
/// Prune the Hybrid Bayes Net given the discrete decision tree.
HybridBayesNet prune(
const DecisionTreeFactor::shared_ptr &discreteFactor) const;
protected:
/**
* @brief Get all the discrete conditionals as a decision tree factor.
*
* @return DecisionTreeFactor::shared_ptr
*/
DecisionTreeFactor::shared_ptr discreteConditionals() const;
public:
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
HybridBayesNet prune(size_t maxNrLeaves) const;
/// @}

View File

@ -34,8 +34,6 @@
namespace gtsam {
class HybridGaussianFactorGraph;
/**
* Hybrid Conditional Density
*

View File

@ -66,7 +66,6 @@ TEST(HybridBayesNet, Add) {
EXPECT(bayesNet.equals(other));
}
/* ****************************************************************************/
// Test choosing an assignment of conditionals
TEST(HybridBayesNet, Choose) {
@ -184,6 +183,24 @@ TEST(HybridBayesNet, OptimizeMultifrontal) {
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
}
/* ****************************************************************************/
// Test bayes net pruning
TEST(HybridBayesNet, Prune) {
Switching s(4);
Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet::shared_ptr hybridBayesNet =
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
HybridValues delta = hybridBayesNet->optimize();
auto prunedBayesNet = hybridBayesNet->prune(2);
HybridValues pruned_delta = prunedBayesNet.optimize();
EXPECT(assert_equal(delta.discrete(), pruned_delta.discrete()));
EXPECT(assert_equal(delta.continuous(), pruned_delta.continuous()));
}
/* ****************************************************************************/
// Test HybridBayesNet serialization.
TEST(HybridBayesNet, Serialization) {

View File

@ -33,7 +33,6 @@ namespace gtsam {
// Forward declarations
template<class FACTOR> class FactorGraph;
template<class BAYESTREE, class GRAPH> class EliminatableClusterTree;
class HybridBayesTreeClique;
/* ************************************************************************* */
/** clique statistics */