add pruned flag to avoid extra pruning

release/4.3a0
Varun Agrawal 2024-12-02 12:07:30 -05:00
parent 94e31c99df
commit a9c75d8ef4
3 changed files with 21 additions and 10 deletions

View File

@ -210,9 +210,11 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
if (conditional->isHybrid()) {
auto hybridGaussianCond = conditional->asHybrid();
// Imperative
clique->conditional() = std::make_shared<HybridConditional>(
hybridGaussianCond->prune(parentData.prunedDiscreteProbs));
if (!hybridGaussianCond->pruned()) {
// Imperative
clique->conditional() = std::make_shared<HybridConditional>(
hybridGaussianCond->prune(parentData.prunedDiscreteProbs));
}
}
return parentData;
}

View File

@ -120,7 +120,7 @@ struct HybridGaussianConditional::Helper {
/* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKeys &discreteParents, Helper &&helper)
const DiscreteKeys &discreteParents, Helper &&helper, bool pruned)
: BaseFactor(discreteParents,
FactorValuePairs(
[&](const GaussianFactorValuePair
@ -130,7 +130,8 @@ HybridGaussianConditional::HybridGaussianConditional(
},
std::move(helper.pairs))),
BaseConditional(*helper.nrFrontals),
negLogConstant_(helper.minNegLogConstant) {}
negLogConstant_(helper.minNegLogConstant),
pruned_(pruned) {}
HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKey &discreteParent,
@ -166,8 +167,9 @@ HybridGaussianConditional::HybridGaussianConditional(
: HybridGaussianConditional(discreteParents, Helper(conditionals)) {}
HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKeys &discreteParents, const FactorValuePairs &pairs)
: HybridGaussianConditional(discreteParents, Helper(pairs)) {}
const DiscreteKeys &discreteParents, const FactorValuePairs &pairs,
bool pruned)
: HybridGaussianConditional(discreteParents, Helper(pairs), pruned) {}
/* *******************************************************************************/
const HybridGaussianConditional::Conditionals
@ -331,7 +333,7 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
FactorValuePairs prunedConditionals = factors().apply(pruner);
return std::make_shared<HybridGaussianConditional>(discreteKeys(),
prunedConditionals);
prunedConditionals, true);
}
/* *******************************************************************************/

View File

@ -68,6 +68,9 @@ class GTSAM_EXPORT HybridGaussianConditional
///< Take advantage of the neg-log space so everything is a minimization
double negLogConstant_;
/// Flag to indicate if the conditional has been pruned.
bool pruned_ = false;
public:
/// @name Constructors
/// @{
@ -150,9 +153,10 @@ class GTSAM_EXPORT HybridGaussianConditional
*
* @param discreteParents the discrete parents. Will be placed last.
* @param conditionalPairs Decision tree of GaussianFactor/scalar pairs.
* @param pruned Flag indicating if conditional has been pruned.
*/
HybridGaussianConditional(const DiscreteKeys &discreteParents,
const FactorValuePairs &pairs);
const FactorValuePairs &pairs, bool pruned = false);
/// @}
/// @name Testable
@ -233,6 +237,9 @@ class GTSAM_EXPORT HybridGaussianConditional
HybridGaussianConditional::shared_ptr prune(
const DecisionTreeFactor &discreteProbs) const;
/// Return true if the conditional has already been pruned.
bool pruned() const { return pruned_; }
/// @}
private:
@ -241,7 +248,7 @@ class GTSAM_EXPORT HybridGaussianConditional
/// Private constructor that uses helper struct above.
HybridGaussianConditional(const DiscreteKeys &discreteParents,
Helper &&helper);
Helper &&helper, bool pruned = false);
/// Check whether `given` has values for all frontal keys.
bool allFrontalsGiven(const VectorValues &given) const;