use DiscreteConditional shared_ptr for dynamic dispatch

release/4.3a0
Varun Agrawal 2024-12-31 00:26:20 -05:00
parent b7b273468c
commit 214043d60d
4 changed files with 14 additions and 12 deletions

View File

@ -56,11 +56,11 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
}
// Prune the joint. NOTE: again, possibly quite expensive.
const DecisionTreeFactor pruned = joint.prune(maxNrLeaves);
const DiscreteConditional::shared_ptr pruned = joint.prune(maxNrLeaves);
// Create a the result starting with the pruned joint.
HybridBayesNet result;
result.emplace_shared<DiscreteConditional>(pruned.size(), pruned);
result.push_back(std::move(pruned));
/* To prune, we visitWith every leaf in the HybridGaussianConditional.
* For each leaf, using the assignment we can check the discrete decision tree

View File

@ -181,14 +181,15 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
void HybridBayesTree::prune(const size_t maxNrLeaves) {
auto discreteProbs = this->roots_.at(0)->conditional()->asDiscrete();
DecisionTreeFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves);
discreteProbs->root_ = prunedDiscreteProbs.root_;
DiscreteConditional::shared_ptr prunedDiscreteProbs =
discreteProbs->prune(maxNrLeaves);
discreteProbs->setData(prunedDiscreteProbs);
/// Helper struct for pruning the hybrid bayes tree.
struct HybridPrunerData {
/// The discrete decision tree after pruning.
DecisionTreeFactor prunedDiscreteProbs;
HybridPrunerData(const DecisionTreeFactor& prunedDiscreteProbs,
DiscreteConditional::shared_ptr prunedDiscreteProbs;
HybridPrunerData(const DiscreteConditional::shared_ptr& prunedDiscreteProbs,
const HybridBayesTree::sharedNode& parentClique)
: prunedDiscreteProbs(prunedDiscreteProbs) {}

View File

@ -304,18 +304,18 @@ std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
/* *******************************************************************************/
HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
const DecisionTreeFactor &discreteProbs) const {
// Find keys in discreteProbs.keys() but not in this->keys():
const DiscreteConditional::shared_ptr &discreteProbs) const {
// Find keys in discreteProbs->keys() but not in this->keys():
std::set<Key> mine(this->keys().begin(), this->keys().end());
std::set<Key> theirs(discreteProbs.keys().begin(),
discreteProbs.keys().end());
std::set<Key> theirs(discreteProbs->keys().begin(),
discreteProbs->keys().end());
std::vector<Key> diff;
std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(),
std::back_inserter(diff));
// Find maximum probability value for every combination of our keys.
Ordering keys(diff);
auto max = discreteProbs.max(keys);
auto max = discreteProbs->max(keys);
// Check the max value for every combination of our keys.
// If the max value is 0.0, we can prune the corresponding conditional.

View File

@ -23,6 +23,7 @@
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
@ -235,7 +236,7 @@ class GTSAM_EXPORT HybridGaussianConditional
* @return Shared pointer to possibly a pruned HybridGaussianConditional
*/
HybridGaussianConditional::shared_ptr prune(
const DecisionTreeFactor &discreteProbs) const;
const DiscreteConditional::shared_ptr &discreteProbs) const;
/// Return true if the conditional has already been pruned.
bool pruned() const { return pruned_; }