Change threshold name
parent
3c10913c70
commit
9bae03a6fa
|
|
@ -75,7 +75,7 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
|
||||||
// been pruned before. Another, possibly faster approach is branch and bound
|
// been pruned before. Another, possibly faster approach is branch and bound
|
||||||
// search to find the K-best leaves and then create a single pruned conditional.
|
// search to find the K-best leaves and then create a single pruned conditional.
|
||||||
DiscreteBayesNet DiscreteBayesNet::prune(
|
DiscreteBayesNet DiscreteBayesNet::prune(
|
||||||
size_t maxNrLeaves, const std::optional<double>& deadModeThreshold,
|
size_t maxNrLeaves, const std::optional<double>& marginalThreshold,
|
||||||
DiscreteValues* fixedValues) const {
|
DiscreteValues* fixedValues) const {
|
||||||
// Multiply into one big conditional. NOTE: possibly quite expensive.
|
// Multiply into one big conditional. NOTE: possibly quite expensive.
|
||||||
DiscreteConditional joint;
|
DiscreteConditional joint;
|
||||||
|
|
@ -89,13 +89,13 @@ DiscreteBayesNet DiscreteBayesNet::prune(
|
||||||
DiscreteValues deadModesValues;
|
DiscreteValues deadModesValues;
|
||||||
// If we have a dead mode threshold and discrete variables left after pruning,
|
// If we have a dead mode threshold and discrete variables left after pruning,
|
||||||
// then we run dead mode removal.
|
// then we run dead mode removal.
|
||||||
if (deadModeThreshold.has_value() && pruned.keys().size() > 0) {
|
if (marginalThreshold.has_value() && pruned.keys().size() > 0) {
|
||||||
DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
|
DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
|
||||||
for (auto dkey : pruned.discreteKeys()) {
|
for (auto dkey : pruned.discreteKeys()) {
|
||||||
const Vector probabilities = marginals.marginalProbabilities(dkey);
|
const Vector probabilities = marginals.marginalProbabilities(dkey);
|
||||||
|
|
||||||
int index = -1;
|
int index = -1;
|
||||||
auto threshold = (probabilities.array() > *deadModeThreshold);
|
auto threshold = (probabilities.array() > *marginalThreshold);
|
||||||
// If atleast 1 value is non-zero, then we can find the index
|
// If atleast 1 value is non-zero, then we can find the index
|
||||||
// Else if all are zero, index would be set to 0 which is incorrect
|
// Else if all are zero, index would be set to 0 which is incorrect
|
||||||
if (!threshold.isZero()) {
|
if (!threshold.isZero()) {
|
||||||
|
|
|
||||||
|
|
@ -128,12 +128,12 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
|
||||||
* @brief Prune the Bayes net
|
* @brief Prune the Bayes net
|
||||||
*
|
*
|
||||||
* @param maxNrLeaves The maximum number of leaves to keep.
|
* @param maxNrLeaves The maximum number of leaves to keep.
|
||||||
* @param deadModeThreshold If given, threshold on marginals to prune variables.
|
* @param marginalThreshold If given, threshold on marginals to prune variables.
|
||||||
* @param fixedValues If given, return the fixed values removed.
|
* @param fixedValues If given, return the fixed values removed.
|
||||||
* @return A new DiscreteBayesNet with pruned conditionals.
|
* @return A new DiscreteBayesNet with pruned conditionals.
|
||||||
*/
|
*/
|
||||||
DiscreteBayesNet prune(size_t maxNrLeaves,
|
DiscreteBayesNet prune(size_t maxNrLeaves,
|
||||||
const std::optional<double>& deadModeThreshold = {},
|
const std::optional<double>& marginalThreshold = {},
|
||||||
DiscreteValues* fixedValues = nullptr) const;
|
DiscreteValues* fixedValues = nullptr) const;
|
||||||
|
|
||||||
///@}
|
///@}
|
||||||
|
|
|
||||||
|
|
@ -43,7 +43,8 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
HybridBayesNet HybridBayesNet::prune(
|
HybridBayesNet HybridBayesNet::prune(
|
||||||
size_t maxNrLeaves, const std::optional<double> &deadModeThreshold) const {
|
size_t maxNrLeaves, const std::optional<double> &marginalThreshold,
|
||||||
|
DiscreteValues *fixedValues) const {
|
||||||
#if GTSAM_HYBRID_TIMING
|
#if GTSAM_HYBRID_TIMING
|
||||||
gttic_(HybridPruning);
|
gttic_(HybridPruning);
|
||||||
#endif
|
#endif
|
||||||
|
|
@ -52,14 +53,14 @@ HybridBayesNet HybridBayesNet::prune(
|
||||||
|
|
||||||
// Prune discrete Bayes net
|
// Prune discrete Bayes net
|
||||||
DiscreteValues fixed;
|
DiscreteValues fixed;
|
||||||
auto prunedBN = marginal.prune(maxNrLeaves, deadModeThreshold, &fixed);
|
auto prunedBN = marginal.prune(maxNrLeaves, marginalThreshold, &fixed);
|
||||||
|
|
||||||
// Multiply into one big conditional. NOTE: possibly quite expensive.
|
// Multiply into one big conditional. NOTE: possibly quite expensive.
|
||||||
DiscreteConditional pruned;
|
DiscreteConditional pruned;
|
||||||
for (auto &&conditional : prunedBN) pruned = pruned * (*conditional);
|
for (auto &&conditional : prunedBN) pruned = pruned * (*conditional);
|
||||||
|
|
||||||
// Set the fixed values if requested.
|
// Set the fixed values if requested.
|
||||||
if (deadModeThreshold && fixedValues) {
|
if (marginalThreshold && fixedValues) {
|
||||||
*fixedValues = fixed;
|
*fixedValues = fixed;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -71,7 +72,7 @@ HybridBayesNet HybridBayesNet::prune(
|
||||||
if (conditional->isDiscrete()) continue;
|
if (conditional->isDiscrete()) continue;
|
||||||
|
|
||||||
// No-op if not a HybridGaussianConditional.
|
// No-op if not a HybridGaussianConditional.
|
||||||
if (deadModeThreshold) conditional = conditional->restrict(fixed);
|
if (marginalThreshold) conditional = conditional->restrict(fixed);
|
||||||
|
|
||||||
// Now decide on type what to do:
|
// Now decide on type what to do:
|
||||||
if (auto hgc = conditional->asHybrid()) {
|
if (auto hgc = conditional->asHybrid()) {
|
||||||
|
|
|
||||||
|
|
@ -217,16 +217,18 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
* @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves.
|
* @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves.
|
||||||
*
|
*
|
||||||
* @param maxNrLeaves Continuous values at which to compute the error.
|
* @param maxNrLeaves Continuous values at which to compute the error.
|
||||||
* @param deadModeThreshold The threshold to check the mode marginals against.
|
* @param marginalThreshold The threshold to check the mode marginals against.
|
||||||
* If greater than this threshold, the mode gets assigned that value and is
|
* @param fixedValues The fixed values resulting from dead mode removal.
|
||||||
* considered "dead" for hybrid elimination.
|
*
|
||||||
* The mode can then be removed since it only has a single possible
|
* @note If marginal greater than this threshold, the mode gets assigned that
|
||||||
* assignment.
|
* value and is considered "dead" for hybrid elimination. The mode can then be
|
||||||
|
* removed since it only has a single possible assignment.
|
||||||
|
|
||||||
* @return A pruned HybridBayesNet
|
* @return A pruned HybridBayesNet
|
||||||
*/
|
*/
|
||||||
HybridBayesNet prune(
|
HybridBayesNet prune(size_t maxNrLeaves,
|
||||||
size_t maxNrLeaves,
|
const std::optional<double> &marginalThreshold = {},
|
||||||
const std::optional<double> &deadModeThreshold = {}) const;
|
DiscreteValues *fixedValues = nullptr) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Error method using HybridValues which returns specific error for
|
* @brief Error method using HybridValues which returns specific error for
|
||||||
|
|
|
||||||
|
|
@ -81,7 +81,7 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph,
|
||||||
if (maxNrLeaves) {
|
if (maxNrLeaves) {
|
||||||
// `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in
|
// `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in
|
||||||
// all the conditionals with the same keys in bayesNetFragment.
|
// all the conditionals with the same keys in bayesNetFragment.
|
||||||
bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, deadModeThreshold_);
|
bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, marginalThreshold_);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the partial bayes net to the posterior bayes net.
|
// Add the partial bayes net to the posterior bayes net.
|
||||||
|
|
|
||||||
|
|
@ -30,18 +30,19 @@ class GTSAM_EXPORT HybridSmoother {
|
||||||
HybridGaussianFactorGraph remainingFactorGraph_;
|
HybridGaussianFactorGraph remainingFactorGraph_;
|
||||||
|
|
||||||
/// The threshold above which we make a decision about a mode.
|
/// The threshold above which we make a decision about a mode.
|
||||||
std::optional<double> deadModeThreshold_;
|
std::optional<double> marginalThreshold_;
|
||||||
|
DiscreteValues fixedValues_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/**
|
/**
|
||||||
* @brief Constructor
|
* @brief Constructor
|
||||||
*
|
*
|
||||||
* @param removeDeadModes Flag indicating whether to remove dead modes.
|
* @param removeDeadModes Flag indicating whether to remove dead modes.
|
||||||
* @param deadModeThreshold The threshold above which a mode gets assigned a
|
* @param marginalThreshold The threshold above which a mode gets assigned a
|
||||||
* value and is considered "dead". 0.99 is a good starting value.
|
* value and is considered "dead". 0.99 is a good starting value.
|
||||||
*/
|
*/
|
||||||
HybridSmoother(const std::optional<double> deadModeThreshold = {})
|
HybridSmoother(const std::optional<double> marginalThreshold = {})
|
||||||
: deadModeThreshold_(deadModeThreshold) {}
|
: marginalThreshold_(marginalThreshold) {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Given new factors, perform an incremental update.
|
* Given new factors, perform an incremental update.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue