DiscreteNet::prune
parent
957c967d0c
commit
3c10913c70
|
@ -18,6 +18,7 @@
|
|||
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/discrete/DiscreteMarginals.h>
|
||||
#include <gtsam/inference/FactorGraph-inst.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
@ -68,6 +69,59 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
|
|||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// The implementation is: build the entire joint into one factor and then prune.
|
||||
// TODO(Frank): This can be quite expensive *unless* the factors have already
|
||||
// been pruned before. Another, possibly faster approach is branch and bound
|
||||
// search to find the K-best leaves and then create a single pruned conditional.
|
||||
DiscreteBayesNet DiscreteBayesNet::prune(
|
||||
size_t maxNrLeaves, const std::optional<double>& deadModeThreshold,
|
||||
DiscreteValues* fixedValues) const {
|
||||
// Multiply into one big conditional. NOTE: possibly quite expensive.
|
||||
DiscreteConditional joint;
|
||||
for (const DiscreteConditional::shared_ptr& conditional : *this)
|
||||
joint = joint * (*conditional);
|
||||
|
||||
// Prune the joint. NOTE: imperative and, again, possibly quite expensive.
|
||||
DiscreteConditional pruned = joint;
|
||||
pruned.prune(maxNrLeaves);
|
||||
|
||||
DiscreteValues deadModesValues;
|
||||
// If we have a dead mode threshold and discrete variables left after pruning,
|
||||
// then we run dead mode removal.
|
||||
if (deadModeThreshold.has_value() && pruned.keys().size() > 0) {
|
||||
DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
|
||||
for (auto dkey : pruned.discreteKeys()) {
|
||||
const Vector probabilities = marginals.marginalProbabilities(dkey);
|
||||
|
||||
int index = -1;
|
||||
auto threshold = (probabilities.array() > *deadModeThreshold);
|
||||
// If atleast 1 value is non-zero, then we can find the index
|
||||
// Else if all are zero, index would be set to 0 which is incorrect
|
||||
if (!threshold.isZero()) {
|
||||
threshold.maxCoeff(&index);
|
||||
}
|
||||
|
||||
if (index >= 0) {
|
||||
deadModesValues.emplace(dkey.first, index);
|
||||
}
|
||||
}
|
||||
|
||||
// Remove the modes (imperative)
|
||||
pruned.removeDiscreteModes(deadModesValues);
|
||||
|
||||
// Set the fixed values if requested.
|
||||
if (fixedValues) {
|
||||
*fixedValues = deadModesValues;
|
||||
}
|
||||
}
|
||||
|
||||
// Return the resulting DiscreteBayesNet.
|
||||
DiscreteBayesNet result;
|
||||
if (pruned.keys().size() > 0) result.push_back(pruned);
|
||||
return result;
|
||||
}
|
||||
|
||||
/* *********************************************************************** */
|
||||
std::string DiscreteBayesNet::markdown(
|
||||
const KeyFormatter& keyFormatter,
|
||||
|
|
|
@ -124,6 +124,18 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
|
|||
*/
|
||||
DiscreteValues sample(DiscreteValues given) const;
|
||||
|
||||
/**
|
||||
* @brief Prune the Bayes net
|
||||
*
|
||||
* @param maxNrLeaves The maximum number of leaves to keep.
|
||||
* @param deadModeThreshold If given, threshold on marginals to prune variables.
|
||||
* @param fixedValues If given, return the fixed values removed.
|
||||
* @return A new DiscreteBayesNet with pruned conditionals.
|
||||
*/
|
||||
DiscreteBayesNet prune(size_t maxNrLeaves,
|
||||
const std::optional<double>& deadModeThreshold = {},
|
||||
DiscreteValues* fixedValues = nullptr) const;
|
||||
|
||||
///@}
|
||||
/// @name Wrapper support
|
||||
/// @{
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/discrete/DiscreteMarginals.h>
|
||||
#include <gtsam/discrete/TableDistribution.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
|
@ -43,10 +42,6 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// The implementation is: build the entire joint into one factor and then prune.
|
||||
// TODO(Frank): This can be quite expensive *unless* the factors have already
|
||||
// been pruned before. Another, possibly faster approach is branch and bound
|
||||
// search to find the K-best leaves and then create a single pruned conditional.
|
||||
HybridBayesNet HybridBayesNet::prune(
|
||||
size_t maxNrLeaves, const std::optional<double> &deadModeThreshold) const {
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
|
@ -55,63 +50,31 @@ HybridBayesNet HybridBayesNet::prune(
|
|||
// Collect all the discrete conditionals. Could be small if already pruned.
|
||||
const DiscreteBayesNet marginal = discreteMarginal();
|
||||
|
||||
// Prune discrete Bayes net
|
||||
DiscreteValues fixed;
|
||||
auto prunedBN = marginal.prune(maxNrLeaves, deadModeThreshold, &fixed);
|
||||
|
||||
// Multiply into one big conditional. NOTE: possibly quite expensive.
|
||||
DiscreteConditional joint;
|
||||
for (auto &&conditional : marginal) {
|
||||
joint = joint * (*conditional);
|
||||
DiscreteConditional pruned;
|
||||
for (auto &&conditional : prunedBN) pruned = pruned * (*conditional);
|
||||
|
||||
// Set the fixed values if requested.
|
||||
if (deadModeThreshold && fixedValues) {
|
||||
*fixedValues = fixed;
|
||||
}
|
||||
|
||||
// Initialize the resulting HybridBayesNet.
|
||||
HybridBayesNet result;
|
||||
|
||||
// Prune the joint. NOTE: imperative and, again, possibly quite expensive.
|
||||
DiscreteConditional pruned = joint;
|
||||
pruned.prune(maxNrLeaves);
|
||||
|
||||
DiscreteValues deadModesValues;
|
||||
// If we have a dead mode threshold and discrete variables left after pruning,
|
||||
// then we run dead mode removal.
|
||||
if (deadModeThreshold.has_value() && pruned.keys().size() > 0) {
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(DeadModeRemoval);
|
||||
#endif
|
||||
|
||||
DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
|
||||
for (auto dkey : pruned.discreteKeys()) {
|
||||
Vector probabilities = marginals.marginalProbabilities(dkey);
|
||||
|
||||
int index = -1;
|
||||
auto threshold = (probabilities.array() > *deadModeThreshold);
|
||||
// If atleast 1 value is non-zero, then we can find the index
|
||||
// Else if all are zero, index would be set to 0 which is incorrect
|
||||
if (!threshold.isZero()) {
|
||||
threshold.maxCoeff(&index);
|
||||
}
|
||||
|
||||
if (index >= 0) {
|
||||
deadModesValues.emplace(dkey.first, index);
|
||||
}
|
||||
}
|
||||
|
||||
// Remove the modes (imperative)
|
||||
pruned.removeDiscreteModes(deadModesValues);
|
||||
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(DeadModeRemoval);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Go through all the Gaussian conditionals, restrict them according to
|
||||
// deadModesValues, and then prune further.
|
||||
for (auto &&conditional : *this) {
|
||||
// fixed values, and then prune further.
|
||||
for (std::shared_ptr<gtsam::HybridConditional> conditional : *this) {
|
||||
if (conditional->isDiscrete()) continue;
|
||||
|
||||
// Restrict conditional using deadModesValues.
|
||||
// No-op if not a HybridGaussianConditional or deadModesValues empty.
|
||||
auto restricted = conditional->restrict(deadModesValues);
|
||||
// No-op if not a HybridGaussianConditional.
|
||||
if (deadModeThreshold) conditional = conditional->restrict(fixed);
|
||||
|
||||
// Now decide on type what to do:
|
||||
if (auto hgc = restricted->asHybrid()) {
|
||||
if (auto hgc = conditional->asHybrid()) {
|
||||
// Prune the hybrid Gaussian conditional!
|
||||
auto prunedHybridGaussianConditional = hgc->prune(pruned);
|
||||
if (!prunedHybridGaussianConditional) {
|
||||
|
@ -120,7 +83,7 @@ HybridBayesNet HybridBayesNet::prune(
|
|||
}
|
||||
// Type-erase and add to the pruned Bayes Net fragment.
|
||||
result.push_back(prunedHybridGaussianConditional);
|
||||
} else if (auto gc = restricted->asGaussian()) {
|
||||
} else if (auto gc = conditional->asGaussian()) {
|
||||
// Add the non-HybridGaussianConditional conditional
|
||||
result.push_back(gc);
|
||||
} else
|
||||
|
@ -128,23 +91,9 @@ HybridBayesNet HybridBayesNet::prune(
|
|||
"HybrdiBayesNet::prune: Unknown HybridConditional type.");
|
||||
}
|
||||
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(HybridPruning);
|
||||
#endif
|
||||
|
||||
if (deadModeThreshold.has_value()) {
|
||||
/*
|
||||
If the pruned discrete conditional has any keys left, we add it to the
|
||||
HybridBayesNet. If not, it means it is an orphan so we don't add this
|
||||
pruned joint, and instead add only the marginals below.
|
||||
*/
|
||||
if (pruned.keys().size() > 0) {
|
||||
result.emplace_shared<DiscreteConditional>(pruned);
|
||||
}
|
||||
|
||||
} else {
|
||||
result.emplace_shared<DiscreteConditional>(pruned);
|
||||
}
|
||||
// Add the pruned discrete conditionals to the result.
|
||||
for (const DiscreteConditional::shared_ptr &discrete : prunedBN)
|
||||
result.push_back(discrete);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue