add pruned flag to avoid extra pruning
parent
94e31c99df
commit
a9c75d8ef4
|
@ -210,10 +210,12 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
||||||
if (conditional->isHybrid()) {
|
if (conditional->isHybrid()) {
|
||||||
auto hybridGaussianCond = conditional->asHybrid();
|
auto hybridGaussianCond = conditional->asHybrid();
|
||||||
|
|
||||||
|
if (!hybridGaussianCond->pruned()) {
|
||||||
// Imperative
|
// Imperative
|
||||||
clique->conditional() = std::make_shared<HybridConditional>(
|
clique->conditional() = std::make_shared<HybridConditional>(
|
||||||
hybridGaussianCond->prune(parentData.prunedDiscreteProbs));
|
hybridGaussianCond->prune(parentData.prunedDiscreteProbs));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return parentData;
|
return parentData;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -120,7 +120,7 @@ struct HybridGaussianConditional::Helper {
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
HybridGaussianConditional::HybridGaussianConditional(
|
HybridGaussianConditional::HybridGaussianConditional(
|
||||||
const DiscreteKeys &discreteParents, Helper &&helper)
|
const DiscreteKeys &discreteParents, Helper &&helper, bool pruned)
|
||||||
: BaseFactor(discreteParents,
|
: BaseFactor(discreteParents,
|
||||||
FactorValuePairs(
|
FactorValuePairs(
|
||||||
[&](const GaussianFactorValuePair
|
[&](const GaussianFactorValuePair
|
||||||
|
@ -130,7 +130,8 @@ HybridGaussianConditional::HybridGaussianConditional(
|
||||||
},
|
},
|
||||||
std::move(helper.pairs))),
|
std::move(helper.pairs))),
|
||||||
BaseConditional(*helper.nrFrontals),
|
BaseConditional(*helper.nrFrontals),
|
||||||
negLogConstant_(helper.minNegLogConstant) {}
|
negLogConstant_(helper.minNegLogConstant),
|
||||||
|
pruned_(pruned) {}
|
||||||
|
|
||||||
HybridGaussianConditional::HybridGaussianConditional(
|
HybridGaussianConditional::HybridGaussianConditional(
|
||||||
const DiscreteKey &discreteParent,
|
const DiscreteKey &discreteParent,
|
||||||
|
@ -166,8 +167,9 @@ HybridGaussianConditional::HybridGaussianConditional(
|
||||||
: HybridGaussianConditional(discreteParents, Helper(conditionals)) {}
|
: HybridGaussianConditional(discreteParents, Helper(conditionals)) {}
|
||||||
|
|
||||||
HybridGaussianConditional::HybridGaussianConditional(
|
HybridGaussianConditional::HybridGaussianConditional(
|
||||||
const DiscreteKeys &discreteParents, const FactorValuePairs &pairs)
|
const DiscreteKeys &discreteParents, const FactorValuePairs &pairs,
|
||||||
: HybridGaussianConditional(discreteParents, Helper(pairs)) {}
|
bool pruned)
|
||||||
|
: HybridGaussianConditional(discreteParents, Helper(pairs), pruned) {}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
const HybridGaussianConditional::Conditionals
|
const HybridGaussianConditional::Conditionals
|
||||||
|
@ -331,7 +333,7 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
|
||||||
|
|
||||||
FactorValuePairs prunedConditionals = factors().apply(pruner);
|
FactorValuePairs prunedConditionals = factors().apply(pruner);
|
||||||
return std::make_shared<HybridGaussianConditional>(discreteKeys(),
|
return std::make_shared<HybridGaussianConditional>(discreteKeys(),
|
||||||
prunedConditionals);
|
prunedConditionals, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
|
|
|
@ -68,6 +68,9 @@ class GTSAM_EXPORT HybridGaussianConditional
|
||||||
///< Take advantage of the neg-log space so everything is a minimization
|
///< Take advantage of the neg-log space so everything is a minimization
|
||||||
double negLogConstant_;
|
double negLogConstant_;
|
||||||
|
|
||||||
|
/// Flag to indicate if the conditional has been pruned.
|
||||||
|
bool pruned_ = false;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/// @name Constructors
|
/// @name Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
@ -150,9 +153,10 @@ class GTSAM_EXPORT HybridGaussianConditional
|
||||||
*
|
*
|
||||||
* @param discreteParents the discrete parents. Will be placed last.
|
* @param discreteParents the discrete parents. Will be placed last.
|
||||||
* @param conditionalPairs Decision tree of GaussianFactor/scalar pairs.
|
* @param conditionalPairs Decision tree of GaussianFactor/scalar pairs.
|
||||||
|
* @param pruned Flag indicating if conditional has been pruned.
|
||||||
*/
|
*/
|
||||||
HybridGaussianConditional(const DiscreteKeys &discreteParents,
|
HybridGaussianConditional(const DiscreteKeys &discreteParents,
|
||||||
const FactorValuePairs &pairs);
|
const FactorValuePairs &pairs, bool pruned = false);
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
|
@ -233,6 +237,9 @@ class GTSAM_EXPORT HybridGaussianConditional
|
||||||
HybridGaussianConditional::shared_ptr prune(
|
HybridGaussianConditional::shared_ptr prune(
|
||||||
const DecisionTreeFactor &discreteProbs) const;
|
const DecisionTreeFactor &discreteProbs) const;
|
||||||
|
|
||||||
|
/// Return true if the conditional has already been pruned.
|
||||||
|
bool pruned() const { return pruned_; }
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -241,7 +248,7 @@ class GTSAM_EXPORT HybridGaussianConditional
|
||||||
|
|
||||||
/// Private constructor that uses helper struct above.
|
/// Private constructor that uses helper struct above.
|
||||||
HybridGaussianConditional(const DiscreteKeys &discreteParents,
|
HybridGaussianConditional(const DiscreteKeys &discreteParents,
|
||||||
Helper &&helper);
|
Helper &&helper, bool pruned = false);
|
||||||
|
|
||||||
/// Check whether `given` has values for all frontal keys.
|
/// Check whether `given` has values for all frontal keys.
|
||||||
bool allFrontalsGiven(const VectorValues &given) const;
|
bool allFrontalsGiven(const VectorValues &given) const;
|
||||||
|
|
Loading…
Reference in New Issue