add pruned flag to avoid extra pruning

release/4.3a0
Varun Agrawal 2024-12-02 12:07:30 -05:00
parent 94e31c99df
commit a9c75d8ef4
3 changed files with 21 additions and 10 deletions

View File

@ -210,9 +210,11 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
if (conditional->isHybrid()) { if (conditional->isHybrid()) {
auto hybridGaussianCond = conditional->asHybrid(); auto hybridGaussianCond = conditional->asHybrid();
// Imperative if (!hybridGaussianCond->pruned()) {
clique->conditional() = std::make_shared<HybridConditional>( // Imperative
hybridGaussianCond->prune(parentData.prunedDiscreteProbs)); clique->conditional() = std::make_shared<HybridConditional>(
hybridGaussianCond->prune(parentData.prunedDiscreteProbs));
}
} }
return parentData; return parentData;
} }

View File

@ -120,7 +120,7 @@ struct HybridGaussianConditional::Helper {
/* *******************************************************************************/ /* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional( HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKeys &discreteParents, Helper &&helper) const DiscreteKeys &discreteParents, Helper &&helper, bool pruned)
: BaseFactor(discreteParents, : BaseFactor(discreteParents,
FactorValuePairs( FactorValuePairs(
[&](const GaussianFactorValuePair [&](const GaussianFactorValuePair
@ -130,7 +130,8 @@ HybridGaussianConditional::HybridGaussianConditional(
}, },
std::move(helper.pairs))), std::move(helper.pairs))),
BaseConditional(*helper.nrFrontals), BaseConditional(*helper.nrFrontals),
negLogConstant_(helper.minNegLogConstant) {} negLogConstant_(helper.minNegLogConstant),
pruned_(pruned) {}
HybridGaussianConditional::HybridGaussianConditional( HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKey &discreteParent, const DiscreteKey &discreteParent,
@ -166,8 +167,9 @@ HybridGaussianConditional::HybridGaussianConditional(
: HybridGaussianConditional(discreteParents, Helper(conditionals)) {} : HybridGaussianConditional(discreteParents, Helper(conditionals)) {}
HybridGaussianConditional::HybridGaussianConditional( HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKeys &discreteParents, const FactorValuePairs &pairs) const DiscreteKeys &discreteParents, const FactorValuePairs &pairs,
: HybridGaussianConditional(discreteParents, Helper(pairs)) {} bool pruned)
: HybridGaussianConditional(discreteParents, Helper(pairs), pruned) {}
/* *******************************************************************************/ /* *******************************************************************************/
const HybridGaussianConditional::Conditionals const HybridGaussianConditional::Conditionals
@ -331,7 +333,7 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
FactorValuePairs prunedConditionals = factors().apply(pruner); FactorValuePairs prunedConditionals = factors().apply(pruner);
return std::make_shared<HybridGaussianConditional>(discreteKeys(), return std::make_shared<HybridGaussianConditional>(discreteKeys(),
prunedConditionals); prunedConditionals, true);
} }
/* *******************************************************************************/ /* *******************************************************************************/

View File

@ -68,6 +68,9 @@ class GTSAM_EXPORT HybridGaussianConditional
///< Take advantage of the neg-log space so everything is a minimization ///< Take advantage of the neg-log space so everything is a minimization
double negLogConstant_; double negLogConstant_;
/// Flag to indicate if the conditional has been pruned.
bool pruned_ = false;
public: public:
/// @name Constructors /// @name Constructors
/// @{ /// @{
@ -150,9 +153,10 @@ class GTSAM_EXPORT HybridGaussianConditional
* *
* @param discreteParents the discrete parents. Will be placed last. * @param discreteParents the discrete parents. Will be placed last.
* @param conditionalPairs Decision tree of GaussianFactor/scalar pairs. * @param conditionalPairs Decision tree of GaussianFactor/scalar pairs.
* @param pruned Flag indicating if conditional has been pruned.
*/ */
HybridGaussianConditional(const DiscreteKeys &discreteParents, HybridGaussianConditional(const DiscreteKeys &discreteParents,
const FactorValuePairs &pairs); const FactorValuePairs &pairs, bool pruned = false);
/// @} /// @}
/// @name Testable /// @name Testable
@ -233,6 +237,9 @@ class GTSAM_EXPORT HybridGaussianConditional
HybridGaussianConditional::shared_ptr prune( HybridGaussianConditional::shared_ptr prune(
const DecisionTreeFactor &discreteProbs) const; const DecisionTreeFactor &discreteProbs) const;
/// Return true if the conditional has already been pruned.
bool pruned() const { return pruned_; }
/// @} /// @}
private: private:
@ -241,7 +248,7 @@ class GTSAM_EXPORT HybridGaussianConditional
/// Private constructor that uses helper struct above. /// Private constructor that uses helper struct above.
HybridGaussianConditional(const DiscreteKeys &discreteParents, HybridGaussianConditional(const DiscreteKeys &discreteParents,
Helper &&helper); Helper &&helper, bool pruned = false);
/// Check whether `given` has values for all frontal keys. /// Check whether `given` has values for all frontal keys.
bool allFrontalsGiven(const VectorValues &given) const; bool allFrontalsGiven(const VectorValues &given) const;