Merge pull request #1300 from borglab/hybrid/improved-prune-2
commit
cae787a175
|
@ -31,8 +31,32 @@ static std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
HybridBayesNet HybridBayesNet::prune(
|
||||
const DecisionTreeFactor::shared_ptr &discreteFactor) const {
|
||||
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
|
||||
AlgebraicDecisionTree<Key> decisionTree;
|
||||
|
||||
// The canonical decision tree factor which will get the discrete conditionals
|
||||
// added to it.
|
||||
DecisionTreeFactor dtFactor;
|
||||
|
||||
for (size_t i = 0; i < this->size(); i++) {
|
||||
HybridConditional::shared_ptr conditional = this->at(i);
|
||||
if (conditional->isDiscrete()) {
|
||||
// Convert to a DecisionTreeFactor and add it to the main factor.
|
||||
DecisionTreeFactor f(*conditional->asDiscreteConditional());
|
||||
dtFactor = dtFactor * f;
|
||||
}
|
||||
}
|
||||
return boost::make_shared<DecisionTreeFactor>(dtFactor);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
|
||||
// Get the decision tree of only the discrete keys
|
||||
auto discreteConditionals = this->discreteConditionals();
|
||||
const DecisionTreeFactor::shared_ptr discreteFactor =
|
||||
boost::make_shared<DecisionTreeFactor>(
|
||||
discreteConditionals->prune(maxNrLeaves));
|
||||
|
||||
/* To Prune, we visitWith every leaf in the GaussianMixture.
|
||||
* For each leaf, using the assignment we can check the discrete decision tree
|
||||
* for 0.0 probability, then just set the leaf to a nullptr.
|
||||
|
|
|
@ -111,9 +111,17 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
*/
|
||||
VectorValues optimize(const DiscreteValues &assignment) const;
|
||||
|
||||
/// Prune the Hybrid Bayes Net given the discrete decision tree.
|
||||
HybridBayesNet prune(
|
||||
const DecisionTreeFactor::shared_ptr &discreteFactor) const;
|
||||
protected:
|
||||
/**
|
||||
* @brief Get all the discrete conditionals as a decision tree factor.
|
||||
*
|
||||
* @return DecisionTreeFactor::shared_ptr
|
||||
*/
|
||||
DecisionTreeFactor::shared_ptr discreteConditionals() const;
|
||||
|
||||
public:
|
||||
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
|
||||
HybridBayesNet prune(size_t maxNrLeaves) const;
|
||||
|
||||
/// @}
|
||||
|
||||
|
|
|
@ -34,8 +34,6 @@
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
class HybridGaussianFactorGraph;
|
||||
|
||||
/**
|
||||
* Hybrid Conditional Density
|
||||
*
|
||||
|
|
|
@ -66,7 +66,6 @@ TEST(HybridBayesNet, Add) {
|
|||
EXPECT(bayesNet.equals(other));
|
||||
}
|
||||
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test choosing an assignment of conditionals
|
||||
TEST(HybridBayesNet, Choose) {
|
||||
|
@ -184,6 +183,24 @@ TEST(HybridBayesNet, OptimizeMultifrontal) {
|
|||
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test bayes net pruning
|
||||
TEST(HybridBayesNet, Prune) {
|
||||
Switching s(4);
|
||||
|
||||
Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
|
||||
HybridBayesNet::shared_ptr hybridBayesNet =
|
||||
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
|
||||
|
||||
HybridValues delta = hybridBayesNet->optimize();
|
||||
|
||||
auto prunedBayesNet = hybridBayesNet->prune(2);
|
||||
HybridValues pruned_delta = prunedBayesNet.optimize();
|
||||
|
||||
EXPECT(assert_equal(delta.discrete(), pruned_delta.discrete()));
|
||||
EXPECT(assert_equal(delta.continuous(), pruned_delta.continuous()));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test HybridBayesNet serialization.
|
||||
TEST(HybridBayesNet, Serialization) {
|
||||
|
|
|
@ -33,7 +33,6 @@ namespace gtsam {
|
|||
// Forward declarations
|
||||
template<class FACTOR> class FactorGraph;
|
||||
template<class BAYESTREE, class GRAPH> class EliminatableClusterTree;
|
||||
class HybridBayesTreeClique;
|
||||
|
||||
/* ************************************************************************* */
|
||||
/** clique statistics */
|
||||
|
|
Loading…
Reference in New Issue