pass DiscreteConditional& for pruning instead of shared_ptr

release/4.3a0
Varun Agrawal 2025-01-04 05:48:50 -05:00
parent b7bddde82b
commit d6bc1e11a6
5 changed files with 15 additions and 16 deletions

View File

@ -53,8 +53,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
DiscreteConditional joint;
for (auto &&conditional : marginal) {
// The last discrete conditional may be a TableDistribution
if (auto dtc =
std::dynamic_pointer_cast<TableDistribution>(conditional)) {
if (auto dtc = std::dynamic_pointer_cast<TableDistribution>(conditional)) {
DiscreteConditional dc(dtc->nrFrontals(), dtc->toDecisionTreeFactor());
joint = joint * dc;
} else {
@ -81,7 +80,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
for (auto &&conditional : *this) {
if (auto hgc = conditional->asHybrid()) {
// Prune the hybrid Gaussian conditional!
auto prunedHybridGaussianConditional = hgc->prune(pruned);
auto prunedHybridGaussianConditional = hgc->prune(*pruned);
// Type-erase and add to the pruned Bayes Net fragment.
result.push_back(prunedHybridGaussianConditional);

View File

@ -236,7 +236,7 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
if (!hybridGaussianCond->pruned()) {
// Imperative
clique->conditional() = std::make_shared<HybridConditional>(
hybridGaussianCond->prune(parentData.prunedDiscreteProbs));
hybridGaussianCond->prune(*parentData.prunedDiscreteProbs));
}
}
return parentData;

View File

@ -304,18 +304,18 @@ std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
/* *******************************************************************************/
HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
const DiscreteConditional::shared_ptr &discreteProbs) const {
// Find keys in discreteProbs->keys() but not in this->keys():
const DiscreteConditional &discreteProbs) const {
// Find keys in discreteProbs.keys() but not in this->keys():
std::set<Key> mine(this->keys().begin(), this->keys().end());
std::set<Key> theirs(discreteProbs->keys().begin(),
discreteProbs->keys().end());
std::set<Key> theirs(discreteProbs.keys().begin(),
discreteProbs.keys().end());
std::vector<Key> diff;
std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(),
std::back_inserter(diff));
// Find maximum probability value for every combination of our keys.
Ordering keys(diff);
auto max = discreteProbs->max(keys);
auto max = discreteProbs.max(keys);
// Check the max value for every combination of our keys.
// If the max value is 0.0, we can prune the corresponding conditional.

View File

@ -236,7 +236,7 @@ class GTSAM_EXPORT HybridGaussianConditional
* @return Shared pointer to possibly a pruned HybridGaussianConditional
*/
HybridGaussianConditional::shared_ptr prune(
const DiscreteConditional::shared_ptr &discreteProbs) const;
const DiscreteConditional &discreteProbs) const;
/// Return true if the conditional has already been pruned.
bool pruned() const { return pruned_; }

View File

@ -261,8 +261,8 @@ TEST(HybridGaussianConditional, Prune) {
potentials[i] = 1;
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
// Prune the HybridGaussianConditional
const auto pruned = hgc.prune(std::make_shared<DiscreteConditional>(
keys.size(), decisionTreeFactor));
const auto pruned =
hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
// Check that the pruned HybridGaussianConditional has 1 conditional
EXPECT_LONGS_EQUAL(1, pruned->nrComponents());
}
@ -272,8 +272,8 @@ TEST(HybridGaussianConditional, Prune) {
0, 0, 0.5, 0};
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
const auto pruned = hgc.prune(
std::make_shared<DiscreteConditional>(keys.size(), decisionTreeFactor));
const auto pruned =
hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
// Check that the pruned HybridGaussianConditional has 2 conditionals
EXPECT_LONGS_EQUAL(2, pruned->nrComponents());
@ -288,8 +288,8 @@ TEST(HybridGaussianConditional, Prune) {
0, 0, 0.5, 0};
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
const auto pruned = hgc.prune(
std::make_shared<DiscreteConditional>(keys.size(), decisionTreeFactor));
const auto pruned =
hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
// Check that the pruned HybridGaussianConditional has 3 conditionals
EXPECT_LONGS_EQUAL(3, pruned->nrComponents());