Change threshold name

release/4.3a0
Frank Dellaert 2025-01-30 08:57:04 -05:00
parent 3c10913c70
commit 9bae03a6fa
6 changed files with 26 additions and 22 deletions

View File

@ -75,7 +75,7 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
// been pruned before. Another, possibly faster approach is branch and bound // been pruned before. Another, possibly faster approach is branch and bound
// search to find the K-best leaves and then create a single pruned conditional. // search to find the K-best leaves and then create a single pruned conditional.
DiscreteBayesNet DiscreteBayesNet::prune( DiscreteBayesNet DiscreteBayesNet::prune(
size_t maxNrLeaves, const std::optional<double>& deadModeThreshold, size_t maxNrLeaves, const std::optional<double>& marginalThreshold,
DiscreteValues* fixedValues) const { DiscreteValues* fixedValues) const {
// Multiply into one big conditional. NOTE: possibly quite expensive. // Multiply into one big conditional. NOTE: possibly quite expensive.
DiscreteConditional joint; DiscreteConditional joint;
@ -89,13 +89,13 @@ DiscreteBayesNet DiscreteBayesNet::prune(
DiscreteValues deadModesValues; DiscreteValues deadModesValues;
// If we have a dead mode threshold and discrete variables left after pruning, // If we have a dead mode threshold and discrete variables left after pruning,
// then we run dead mode removal. // then we run dead mode removal.
if (deadModeThreshold.has_value() && pruned.keys().size() > 0) { if (marginalThreshold.has_value() && pruned.keys().size() > 0) {
DiscreteMarginals marginals(DiscreteFactorGraph{pruned}); DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
for (auto dkey : pruned.discreteKeys()) { for (auto dkey : pruned.discreteKeys()) {
const Vector probabilities = marginals.marginalProbabilities(dkey); const Vector probabilities = marginals.marginalProbabilities(dkey);
int index = -1; int index = -1;
auto threshold = (probabilities.array() > *deadModeThreshold); auto threshold = (probabilities.array() > *marginalThreshold);
// If atleast 1 value is non-zero, then we can find the index // If atleast 1 value is non-zero, then we can find the index
// Else if all are zero, index would be set to 0 which is incorrect // Else if all are zero, index would be set to 0 which is incorrect
if (!threshold.isZero()) { if (!threshold.isZero()) {

View File

@ -128,12 +128,12 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
* @brief Prune the Bayes net * @brief Prune the Bayes net
* *
* @param maxNrLeaves The maximum number of leaves to keep. * @param maxNrLeaves The maximum number of leaves to keep.
* @param deadModeThreshold If given, threshold on marginals to prune variables. * @param marginalThreshold If given, threshold on marginals to prune variables.
* @param fixedValues If given, return the fixed values removed. * @param fixedValues If given, return the fixed values removed.
* @return A new DiscreteBayesNet with pruned conditionals. * @return A new DiscreteBayesNet with pruned conditionals.
*/ */
DiscreteBayesNet prune(size_t maxNrLeaves, DiscreteBayesNet prune(size_t maxNrLeaves,
const std::optional<double>& deadModeThreshold = {}, const std::optional<double>& marginalThreshold = {},
DiscreteValues* fixedValues = nullptr) const; DiscreteValues* fixedValues = nullptr) const;
///@} ///@}

View File

@ -43,7 +43,8 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
/* ************************************************************************* */ /* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune( HybridBayesNet HybridBayesNet::prune(
size_t maxNrLeaves, const std::optional<double> &deadModeThreshold) const { size_t maxNrLeaves, const std::optional<double> &marginalThreshold,
DiscreteValues *fixedValues) const {
#if GTSAM_HYBRID_TIMING #if GTSAM_HYBRID_TIMING
gttic_(HybridPruning); gttic_(HybridPruning);
#endif #endif
@ -52,14 +53,14 @@ HybridBayesNet HybridBayesNet::prune(
// Prune discrete Bayes net // Prune discrete Bayes net
DiscreteValues fixed; DiscreteValues fixed;
auto prunedBN = marginal.prune(maxNrLeaves, deadModeThreshold, &fixed); auto prunedBN = marginal.prune(maxNrLeaves, marginalThreshold, &fixed);
// Multiply into one big conditional. NOTE: possibly quite expensive. // Multiply into one big conditional. NOTE: possibly quite expensive.
DiscreteConditional pruned; DiscreteConditional pruned;
for (auto &&conditional : prunedBN) pruned = pruned * (*conditional); for (auto &&conditional : prunedBN) pruned = pruned * (*conditional);
// Set the fixed values if requested. // Set the fixed values if requested.
if (deadModeThreshold && fixedValues) { if (marginalThreshold && fixedValues) {
*fixedValues = fixed; *fixedValues = fixed;
} }
@ -71,7 +72,7 @@ HybridBayesNet HybridBayesNet::prune(
if (conditional->isDiscrete()) continue; if (conditional->isDiscrete()) continue;
// No-op if not a HybridGaussianConditional. // No-op if not a HybridGaussianConditional.
if (deadModeThreshold) conditional = conditional->restrict(fixed); if (marginalThreshold) conditional = conditional->restrict(fixed);
// Now decide on type what to do: // Now decide on type what to do:
if (auto hgc = conditional->asHybrid()) { if (auto hgc = conditional->asHybrid()) {

View File

@ -217,16 +217,18 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves. * @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves.
* *
* @param maxNrLeaves Continuous values at which to compute the error. * @param maxNrLeaves Continuous values at which to compute the error.
* @param deadModeThreshold The threshold to check the mode marginals against. * @param marginalThreshold The threshold to check the mode marginals against.
* If greater than this threshold, the mode gets assigned that value and is * @param fixedValues The fixed values resulting from dead mode removal.
* considered "dead" for hybrid elimination. *
* The mode can then be removed since it only has a single possible * @note If marginal greater than this threshold, the mode gets assigned that
* assignment. * value and is considered "dead" for hybrid elimination. The mode can then be
* removed since it only has a single possible assignment.
* @return A pruned HybridBayesNet * @return A pruned HybridBayesNet
*/ */
HybridBayesNet prune( HybridBayesNet prune(size_t maxNrLeaves,
size_t maxNrLeaves, const std::optional<double> &marginalThreshold = {},
const std::optional<double> &deadModeThreshold = {}) const; DiscreteValues *fixedValues = nullptr) const;
/** /**
* @brief Error method using HybridValues which returns specific error for * @brief Error method using HybridValues which returns specific error for

View File

@ -81,7 +81,7 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph,
if (maxNrLeaves) { if (maxNrLeaves) {
// `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in // `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in
// all the conditionals with the same keys in bayesNetFragment. // all the conditionals with the same keys in bayesNetFragment.
bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, deadModeThreshold_); bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, marginalThreshold_);
} }
// Add the partial bayes net to the posterior bayes net. // Add the partial bayes net to the posterior bayes net.

View File

@ -30,18 +30,19 @@ class GTSAM_EXPORT HybridSmoother {
HybridGaussianFactorGraph remainingFactorGraph_; HybridGaussianFactorGraph remainingFactorGraph_;
/// The threshold above which we make a decision about a mode. /// The threshold above which we make a decision about a mode.
std::optional<double> deadModeThreshold_; std::optional<double> marginalThreshold_;
DiscreteValues fixedValues_;
public: public:
/** /**
* @brief Constructor * @brief Constructor
* *
* @param removeDeadModes Flag indicating whether to remove dead modes. * @param removeDeadModes Flag indicating whether to remove dead modes.
* @param deadModeThreshold The threshold above which a mode gets assigned a * @param marginalThreshold The threshold above which a mode gets assigned a
* value and is considered "dead". 0.99 is a good starting value. * value and is considered "dead". 0.99 is a good starting value.
*/ */
HybridSmoother(const std::optional<double> deadModeThreshold = {}) HybridSmoother(const std::optional<double> marginalThreshold = {})
: deadModeThreshold_(deadModeThreshold) {} : marginalThreshold_(marginalThreshold) {}
/** /**
* Given new factors, perform an incremental update. * Given new factors, perform an incremental update.