DiscreteNet::prune

release/4.3a0
Frank Dellaert 2025-01-30 07:57:42 -05:00
parent 957c967d0c
commit 3c10913c70
3 changed files with 85 additions and 70 deletions

View File

@ -18,6 +18,7 @@
#include <gtsam/discrete/DiscreteBayesNet.h> #include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteMarginals.h>
#include <gtsam/inference/FactorGraph-inst.h> #include <gtsam/inference/FactorGraph-inst.h>
namespace gtsam { namespace gtsam {
@ -68,6 +69,59 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
return result; return result;
} }
/* ************************************************************************* */
// The implementation is: build the entire joint into one factor and then prune.
// TODO(Frank): This can be quite expensive *unless* the factors have already
// been pruned before. Another, possibly faster approach is branch and bound
// search to find the K-best leaves and then create a single pruned conditional.
DiscreteBayesNet DiscreteBayesNet::prune(
size_t maxNrLeaves, const std::optional<double>& deadModeThreshold,
DiscreteValues* fixedValues) const {
// Multiply into one big conditional. NOTE: possibly quite expensive.
DiscreteConditional joint;
for (const DiscreteConditional::shared_ptr& conditional : *this)
joint = joint * (*conditional);
// Prune the joint. NOTE: imperative and, again, possibly quite expensive.
DiscreteConditional pruned = joint;
pruned.prune(maxNrLeaves);
DiscreteValues deadModesValues;
// If we have a dead mode threshold and discrete variables left after pruning,
// then we run dead mode removal.
if (deadModeThreshold.has_value() && pruned.keys().size() > 0) {
DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
for (auto dkey : pruned.discreteKeys()) {
const Vector probabilities = marginals.marginalProbabilities(dkey);
int index = -1;
auto threshold = (probabilities.array() > *deadModeThreshold);
// If atleast 1 value is non-zero, then we can find the index
// Else if all are zero, index would be set to 0 which is incorrect
if (!threshold.isZero()) {
threshold.maxCoeff(&index);
}
if (index >= 0) {
deadModesValues.emplace(dkey.first, index);
}
}
// Remove the modes (imperative)
pruned.removeDiscreteModes(deadModesValues);
// Set the fixed values if requested.
if (fixedValues) {
*fixedValues = deadModesValues;
}
}
// Return the resulting DiscreteBayesNet.
DiscreteBayesNet result;
if (pruned.keys().size() > 0) result.push_back(pruned);
return result;
}
/* *********************************************************************** */ /* *********************************************************************** */
std::string DiscreteBayesNet::markdown( std::string DiscreteBayesNet::markdown(
const KeyFormatter& keyFormatter, const KeyFormatter& keyFormatter,

View File

@ -124,6 +124,18 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
*/ */
DiscreteValues sample(DiscreteValues given) const; DiscreteValues sample(DiscreteValues given) const;
/**
* @brief Prune the Bayes net
*
* @param maxNrLeaves The maximum number of leaves to keep.
* @param deadModeThreshold If given, threshold on marginals to prune variables.
* @param fixedValues If given, return the fixed values removed.
* @return A new DiscreteBayesNet with pruned conditionals.
*/
DiscreteBayesNet prune(size_t maxNrLeaves,
const std::optional<double>& deadModeThreshold = {},
DiscreteValues* fixedValues = nullptr) const;
///@} ///@}
/// @name Wrapper support /// @name Wrapper support
/// @{ /// @{

View File

@ -19,7 +19,6 @@
#include <gtsam/discrete/DiscreteBayesNet.h> #include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteMarginals.h>
#include <gtsam/discrete/TableDistribution.h> #include <gtsam/discrete/TableDistribution.h>
#include <gtsam/hybrid/HybridBayesNet.h> #include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridValues.h> #include <gtsam/hybrid/HybridValues.h>
@ -43,10 +42,6 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
} }
/* ************************************************************************* */ /* ************************************************************************* */
// The implementation is: build the entire joint into one factor and then prune.
// TODO(Frank): This can be quite expensive *unless* the factors have already
// been pruned before. Another, possibly faster approach is branch and bound
// search to find the K-best leaves and then create a single pruned conditional.
HybridBayesNet HybridBayesNet::prune( HybridBayesNet HybridBayesNet::prune(
size_t maxNrLeaves, const std::optional<double> &deadModeThreshold) const { size_t maxNrLeaves, const std::optional<double> &deadModeThreshold) const {
#if GTSAM_HYBRID_TIMING #if GTSAM_HYBRID_TIMING
@ -55,63 +50,31 @@ HybridBayesNet HybridBayesNet::prune(
// Collect all the discrete conditionals. Could be small if already pruned. // Collect all the discrete conditionals. Could be small if already pruned.
const DiscreteBayesNet marginal = discreteMarginal(); const DiscreteBayesNet marginal = discreteMarginal();
// Prune discrete Bayes net
DiscreteValues fixed;
auto prunedBN = marginal.prune(maxNrLeaves, deadModeThreshold, &fixed);
// Multiply into one big conditional. NOTE: possibly quite expensive. // Multiply into one big conditional. NOTE: possibly quite expensive.
DiscreteConditional joint; DiscreteConditional pruned;
for (auto &&conditional : marginal) { for (auto &&conditional : prunedBN) pruned = pruned * (*conditional);
joint = joint * (*conditional);
// Set the fixed values if requested.
if (deadModeThreshold && fixedValues) {
*fixedValues = fixed;
} }
// Initialize the resulting HybridBayesNet.
HybridBayesNet result; HybridBayesNet result;
// Prune the joint. NOTE: imperative and, again, possibly quite expensive.
DiscreteConditional pruned = joint;
pruned.prune(maxNrLeaves);
DiscreteValues deadModesValues;
// If we have a dead mode threshold and discrete variables left after pruning,
// then we run dead mode removal.
if (deadModeThreshold.has_value() && pruned.keys().size() > 0) {
#if GTSAM_HYBRID_TIMING
gttic_(DeadModeRemoval);
#endif
DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
for (auto dkey : pruned.discreteKeys()) {
Vector probabilities = marginals.marginalProbabilities(dkey);
int index = -1;
auto threshold = (probabilities.array() > *deadModeThreshold);
// If atleast 1 value is non-zero, then we can find the index
// Else if all are zero, index would be set to 0 which is incorrect
if (!threshold.isZero()) {
threshold.maxCoeff(&index);
}
if (index >= 0) {
deadModesValues.emplace(dkey.first, index);
}
}
// Remove the modes (imperative)
pruned.removeDiscreteModes(deadModesValues);
#if GTSAM_HYBRID_TIMING
gttoc_(DeadModeRemoval);
#endif
}
// Go through all the Gaussian conditionals, restrict them according to // Go through all the Gaussian conditionals, restrict them according to
// deadModesValues, and then prune further. // fixed values, and then prune further.
for (auto &&conditional : *this) { for (std::shared_ptr<gtsam::HybridConditional> conditional : *this) {
if (conditional->isDiscrete()) continue; if (conditional->isDiscrete()) continue;
// Restrict conditional using deadModesValues. // No-op if not a HybridGaussianConditional.
// No-op if not a HybridGaussianConditional or deadModesValues empty. if (deadModeThreshold) conditional = conditional->restrict(fixed);
auto restricted = conditional->restrict(deadModesValues);
// Now decide on type what to do: // Now decide on type what to do:
if (auto hgc = restricted->asHybrid()) { if (auto hgc = conditional->asHybrid()) {
// Prune the hybrid Gaussian conditional! // Prune the hybrid Gaussian conditional!
auto prunedHybridGaussianConditional = hgc->prune(pruned); auto prunedHybridGaussianConditional = hgc->prune(pruned);
if (!prunedHybridGaussianConditional) { if (!prunedHybridGaussianConditional) {
@ -120,7 +83,7 @@ HybridBayesNet HybridBayesNet::prune(
} }
// Type-erase and add to the pruned Bayes Net fragment. // Type-erase and add to the pruned Bayes Net fragment.
result.push_back(prunedHybridGaussianConditional); result.push_back(prunedHybridGaussianConditional);
} else if (auto gc = restricted->asGaussian()) { } else if (auto gc = conditional->asGaussian()) {
// Add the non-HybridGaussianConditional conditional // Add the non-HybridGaussianConditional conditional
result.push_back(gc); result.push_back(gc);
} else } else
@ -128,23 +91,9 @@ HybridBayesNet HybridBayesNet::prune(
"HybrdiBayesNet::prune: Unknown HybridConditional type."); "HybrdiBayesNet::prune: Unknown HybridConditional type.");
} }
#if GTSAM_HYBRID_TIMING // Add the pruned discrete conditionals to the result.
gttoc_(HybridPruning); for (const DiscreteConditional::shared_ptr &discrete : prunedBN)
#endif result.push_back(discrete);
if (deadModeThreshold.has_value()) {
/*
If the pruned discrete conditional has any keys left, we add it to the
HybridBayesNet. If not, it means it is an orphan so we don't add this
pruned joint, and instead add only the marginals below.
*/
if (pruned.keys().size() > 0) {
result.emplace_shared<DiscreteConditional>(pruned);
}
} else {
result.emplace_shared<DiscreteConditional>(pruned);
}
return result; return result;
} }