pass DiscreteConditional& for pruning instead of shared_ptr

release/4.3a0
Varun Agrawal 2025-01-04 05:48:50 -05:00
parent b7bddde82b
commit d6bc1e11a6
5 changed files with 15 additions and 16 deletions

View File

@ -53,8 +53,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
DiscreteConditional joint; DiscreteConditional joint;
for (auto &&conditional : marginal) { for (auto &&conditional : marginal) {
// The last discrete conditional may be a TableDistribution // The last discrete conditional may be a TableDistribution
if (auto dtc = if (auto dtc = std::dynamic_pointer_cast<TableDistribution>(conditional)) {
std::dynamic_pointer_cast<TableDistribution>(conditional)) {
DiscreteConditional dc(dtc->nrFrontals(), dtc->toDecisionTreeFactor()); DiscreteConditional dc(dtc->nrFrontals(), dtc->toDecisionTreeFactor());
joint = joint * dc; joint = joint * dc;
} else { } else {
@ -81,7 +80,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
for (auto &&conditional : *this) { for (auto &&conditional : *this) {
if (auto hgc = conditional->asHybrid()) { if (auto hgc = conditional->asHybrid()) {
// Prune the hybrid Gaussian conditional! // Prune the hybrid Gaussian conditional!
auto prunedHybridGaussianConditional = hgc->prune(pruned); auto prunedHybridGaussianConditional = hgc->prune(*pruned);
// Type-erase and add to the pruned Bayes Net fragment. // Type-erase and add to the pruned Bayes Net fragment.
result.push_back(prunedHybridGaussianConditional); result.push_back(prunedHybridGaussianConditional);

View File

@ -236,7 +236,7 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
if (!hybridGaussianCond->pruned()) { if (!hybridGaussianCond->pruned()) {
// Imperative // Imperative
clique->conditional() = std::make_shared<HybridConditional>( clique->conditional() = std::make_shared<HybridConditional>(
hybridGaussianCond->prune(parentData.prunedDiscreteProbs)); hybridGaussianCond->prune(*parentData.prunedDiscreteProbs));
} }
} }
return parentData; return parentData;

View File

@ -304,18 +304,18 @@ std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
/* *******************************************************************************/ /* *******************************************************************************/
HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune( HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
const DiscreteConditional::shared_ptr &discreteProbs) const { const DiscreteConditional &discreteProbs) const {
// Find keys in discreteProbs->keys() but not in this->keys(): // Find keys in discreteProbs.keys() but not in this->keys():
std::set<Key> mine(this->keys().begin(), this->keys().end()); std::set<Key> mine(this->keys().begin(), this->keys().end());
std::set<Key> theirs(discreteProbs->keys().begin(), std::set<Key> theirs(discreteProbs.keys().begin(),
discreteProbs->keys().end()); discreteProbs.keys().end());
std::vector<Key> diff; std::vector<Key> diff;
std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(), std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(),
std::back_inserter(diff)); std::back_inserter(diff));
// Find maximum probability value for every combination of our keys. // Find maximum probability value for every combination of our keys.
Ordering keys(diff); Ordering keys(diff);
auto max = discreteProbs->max(keys); auto max = discreteProbs.max(keys);
// Check the max value for every combination of our keys. // Check the max value for every combination of our keys.
// If the max value is 0.0, we can prune the corresponding conditional. // If the max value is 0.0, we can prune the corresponding conditional.

View File

@ -236,7 +236,7 @@ class GTSAM_EXPORT HybridGaussianConditional
* @return Shared pointer to possibly a pruned HybridGaussianConditional * @return Shared pointer to possibly a pruned HybridGaussianConditional
*/ */
HybridGaussianConditional::shared_ptr prune( HybridGaussianConditional::shared_ptr prune(
const DiscreteConditional::shared_ptr &discreteProbs) const; const DiscreteConditional &discreteProbs) const;
/// Return true if the conditional has already been pruned. /// Return true if the conditional has already been pruned.
bool pruned() const { return pruned_; } bool pruned() const { return pruned_; }

View File

@ -261,8 +261,8 @@ TEST(HybridGaussianConditional, Prune) {
potentials[i] = 1; potentials[i] = 1;
const DecisionTreeFactor decisionTreeFactor(keys, potentials); const DecisionTreeFactor decisionTreeFactor(keys, potentials);
// Prune the HybridGaussianConditional // Prune the HybridGaussianConditional
const auto pruned = hgc.prune(std::make_shared<DiscreteConditional>( const auto pruned =
keys.size(), decisionTreeFactor)); hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
// Check that the pruned HybridGaussianConditional has 1 conditional // Check that the pruned HybridGaussianConditional has 1 conditional
EXPECT_LONGS_EQUAL(1, pruned->nrComponents()); EXPECT_LONGS_EQUAL(1, pruned->nrComponents());
} }
@ -272,8 +272,8 @@ TEST(HybridGaussianConditional, Prune) {
0, 0, 0.5, 0}; 0, 0, 0.5, 0};
const DecisionTreeFactor decisionTreeFactor(keys, potentials); const DecisionTreeFactor decisionTreeFactor(keys, potentials);
const auto pruned = hgc.prune( const auto pruned =
std::make_shared<DiscreteConditional>(keys.size(), decisionTreeFactor)); hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
// Check that the pruned HybridGaussianConditional has 2 conditionals // Check that the pruned HybridGaussianConditional has 2 conditionals
EXPECT_LONGS_EQUAL(2, pruned->nrComponents()); EXPECT_LONGS_EQUAL(2, pruned->nrComponents());
@ -288,8 +288,8 @@ TEST(HybridGaussianConditional, Prune) {
0, 0, 0.5, 0}; 0, 0, 0.5, 0};
const DecisionTreeFactor decisionTreeFactor(keys, potentials); const DecisionTreeFactor decisionTreeFactor(keys, potentials);
const auto pruned = hgc.prune( const auto pruned =
std::make_shared<DiscreteConditional>(keys.size(), decisionTreeFactor)); hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
// Check that the pruned HybridGaussianConditional has 3 conditionals // Check that the pruned HybridGaussianConditional has 3 conditionals
EXPECT_LONGS_EQUAL(3, pruned->nrComponents()); EXPECT_LONGS_EQUAL(3, pruned->nrComponents());