Merge branch 'hybrid/improved-prune-2' into varun/test-hybrid-estimation
commit
42e915f7d4
|
@ -31,8 +31,32 @@ static std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
HybridBayesNet HybridBayesNet::prune(
|
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
|
||||||
const DecisionTreeFactor::shared_ptr &discreteFactor) const {
|
AlgebraicDecisionTree<Key> decisionTree;
|
||||||
|
|
||||||
|
// The canonical decision tree factor which will get the discrete conditionals
|
||||||
|
// added to it.
|
||||||
|
DecisionTreeFactor dtFactor;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < this->size(); i++) {
|
||||||
|
HybridConditional::shared_ptr conditional = this->at(i);
|
||||||
|
if (conditional->isDiscrete()) {
|
||||||
|
// Convert to a DecisionTreeFactor and add it to the main factor.
|
||||||
|
DecisionTreeFactor f(*conditional->asDiscreteConditional());
|
||||||
|
dtFactor = dtFactor * f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return boost::make_shared<DecisionTreeFactor>(dtFactor);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
|
||||||
|
// Get the decision tree of only the discrete keys
|
||||||
|
auto discreteConditionals = this->discreteConditionals();
|
||||||
|
const DecisionTreeFactor::shared_ptr discreteFactor =
|
||||||
|
boost::make_shared<DecisionTreeFactor>(
|
||||||
|
discreteConditionals->prune(maxNrLeaves));
|
||||||
|
|
||||||
/* To Prune, we visitWith every leaf in the GaussianMixture.
|
/* To Prune, we visitWith every leaf in the GaussianMixture.
|
||||||
* For each leaf, using the assignment we can check the discrete decision tree
|
* For each leaf, using the assignment we can check the discrete decision tree
|
||||||
* for 0.0 probability, then just set the leaf to a nullptr.
|
* for 0.0 probability, then just set the leaf to a nullptr.
|
||||||
|
|
|
@ -111,9 +111,17 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
*/
|
*/
|
||||||
VectorValues optimize(const DiscreteValues &assignment) const;
|
VectorValues optimize(const DiscreteValues &assignment) const;
|
||||||
|
|
||||||
/// Prune the Hybrid Bayes Net given the discrete decision tree.
|
protected:
|
||||||
HybridBayesNet prune(
|
/**
|
||||||
const DecisionTreeFactor::shared_ptr &discreteFactor) const;
|
* @brief Get all the discrete conditionals as a decision tree factor.
|
||||||
|
*
|
||||||
|
* @return DecisionTreeFactor::shared_ptr
|
||||||
|
*/
|
||||||
|
DecisionTreeFactor::shared_ptr discreteConditionals() const;
|
||||||
|
|
||||||
|
public:
|
||||||
|
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
|
||||||
|
HybridBayesNet prune(size_t maxNrLeaves) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
|
|
@ -34,8 +34,6 @@
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
class HybridGaussianFactorGraph;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Hybrid Conditional Density
|
* Hybrid Conditional Density
|
||||||
*
|
*
|
||||||
|
|
|
@ -66,7 +66,6 @@ TEST(HybridBayesNet, Add) {
|
||||||
EXPECT(bayesNet.equals(other));
|
EXPECT(bayesNet.equals(other));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
// Test choosing an assignment of conditionals
|
// Test choosing an assignment of conditionals
|
||||||
TEST(HybridBayesNet, Choose) {
|
TEST(HybridBayesNet, Choose) {
|
||||||
|
@ -184,6 +183,24 @@ TEST(HybridBayesNet, OptimizeMultifrontal) {
|
||||||
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
|
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
// Test bayes net pruning
|
||||||
|
TEST(HybridBayesNet, Prune) {
|
||||||
|
Switching s(4);
|
||||||
|
|
||||||
|
Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
|
||||||
|
HybridBayesNet::shared_ptr hybridBayesNet =
|
||||||
|
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
|
||||||
|
|
||||||
|
HybridValues delta = hybridBayesNet->optimize();
|
||||||
|
|
||||||
|
auto prunedBayesNet = hybridBayesNet->prune(2);
|
||||||
|
HybridValues pruned_delta = prunedBayesNet.optimize();
|
||||||
|
|
||||||
|
EXPECT(assert_equal(delta.discrete(), pruned_delta.discrete()));
|
||||||
|
EXPECT(assert_equal(delta.continuous(), pruned_delta.continuous()));
|
||||||
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
// Test HybridBayesNet serialization.
|
// Test HybridBayesNet serialization.
|
||||||
TEST(HybridBayesNet, Serialization) {
|
TEST(HybridBayesNet, Serialization) {
|
||||||
|
|
|
@ -33,7 +33,6 @@ namespace gtsam {
|
||||||
// Forward declarations
|
// Forward declarations
|
||||||
template<class FACTOR> class FactorGraph;
|
template<class FACTOR> class FactorGraph;
|
||||||
template<class BAYESTREE, class GRAPH> class EliminatableClusterTree;
|
template<class BAYESTREE, class GRAPH> class EliminatableClusterTree;
|
||||||
class HybridBayesTreeClique;
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
/** clique statistics */
|
/** clique statistics */
|
||||||
|
|
Loading…
Reference in New Issue