DiscreteNet::prune

parent 957c967d0c
commit 3c10913c70
@@ -18,6 +18,7 @@
 
 #include <gtsam/discrete/DiscreteBayesNet.h>
 #include <gtsam/discrete/DiscreteConditional.h>
+#include <gtsam/discrete/DiscreteMarginals.h>
 #include <gtsam/inference/FactorGraph-inst.h>
 
 namespace gtsam {
@@ -68,6 +69,59 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
   return result;
 }
 
+/* ************************************************************************* */
+// The implementation is: build the entire joint into one factor and then prune.
+// TODO(Frank): This can be quite expensive *unless* the factors have already
+// been pruned before. Another, possibly faster approach is branch and bound
+// search to find the K-best leaves and then create a single pruned conditional.
+DiscreteBayesNet DiscreteBayesNet::prune(
+    size_t maxNrLeaves, const std::optional<double>& deadModeThreshold,
+    DiscreteValues* fixedValues) const {
+  // Multiply into one big conditional. NOTE: possibly quite expensive.
+  DiscreteConditional joint;
+  for (const DiscreteConditional::shared_ptr& conditional : *this)
+    joint = joint * (*conditional);
+
+  // Prune the joint. NOTE: imperative and, again, possibly quite expensive.
+  DiscreteConditional pruned = joint;
+  pruned.prune(maxNrLeaves);
+
+  DiscreteValues deadModesValues;
+  // If we have a dead mode threshold and discrete variables left after pruning,
+  // then we run dead mode removal.
+  if (deadModeThreshold.has_value() && pruned.keys().size() > 0) {
+    DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
+    for (auto dkey : pruned.discreteKeys()) {
+      const Vector probabilities = marginals.marginalProbabilities(dkey);
+
+      int index = -1;
+      auto threshold = (probabilities.array() > *deadModeThreshold);
+      // If atleast 1 value is non-zero, then we can find the index
+      // Else if all are zero, index would be set to 0 which is incorrect
+      if (!threshold.isZero()) {
+        threshold.maxCoeff(&index);
+      }
+
+      if (index >= 0) {
+        deadModesValues.emplace(dkey.first, index);
+      }
+    }
+
+    // Remove the modes (imperative)
+    pruned.removeDiscreteModes(deadModesValues);
+
+    // Set the fixed values if requested.
+    if (fixedValues) {
+      *fixedValues = deadModesValues;
+    }
+  }
+
+  // Return the resulting DiscreteBayesNet.
+  DiscreteBayesNet result;
+  if (pruned.keys().size() > 0) result.push_back(pruned);
+  return result;
+}
+
 /* *********************************************************************** */
 std::string DiscreteBayesNet::markdown(
     const KeyFormatter& keyFormatter,
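The dead-mode test above keys off an Eigen boolean mask: a mode counts as "dead" (effectively decided) when its marginal probability exceeds deadModeThreshold, and the guarded maxCoeff call recovers which value it is fixed to. A minimal standalone sketch of that pattern follows; it is not part of this commit, assumes only Eigen, and the probability values and 0.9 threshold are made up for illustration.

#include <Eigen/Core>
#include <iostream>

int main() {
  // Pretend marginal P(m) of one discrete key, and an assumed threshold.
  Eigen::VectorXd probabilities(3);
  probabilities << 0.005, 0.99, 0.005;
  const double deadModeThreshold = 0.9;

  int index = -1;
  // Boolean mask: true where the marginal exceeds the threshold.
  auto mask = (probabilities.array() > deadModeThreshold);
  // Only query the index if at least one entry passed the test; otherwise
  // maxCoeff would report index 0, which would be a spurious "dead mode".
  if (!mask.isZero()) mask.maxCoeff(&index);

  if (index >= 0)
    std::cout << "variable is effectively fixed to value " << index << "\n";
  else
    std::cout << "no dead mode detected\n";
  return 0;
}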
@@ -124,6 +124,18 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
    */
   DiscreteValues sample(DiscreteValues given) const;
 
+  /**
+   * @brief Prune the Bayes net
+   *
+   * @param maxNrLeaves The maximum number of leaves to keep.
+   * @param deadModeThreshold If given, threshold on marginals to prune variables.
+   * @param fixedValues If given, return the fixed values removed.
+   * @return A new DiscreteBayesNet with pruned conditionals.
+   */
+  DiscreteBayesNet prune(size_t maxNrLeaves,
+                         const std::optional<double>& deadModeThreshold = {},
+                         DiscreteValues* fixedValues = nullptr) const;
+
   ///@}
   /// @name Wrapper support
   /// @{
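For context, a hedged usage sketch of the declaration above (not part of this commit); the Bayes net argument, the leaf count 4, and the 0.99 threshold are illustrative assumptions.

#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteValues.h>

using namespace gtsam;

DiscreteBayesNet pruneExample(const DiscreteBayesNet& bayesNet) {
  // Keep at most 4 leaves in the joint; treat any variable whose marginal
  // exceeds 0.99 as a dead mode, remove it, and report its value in `fixed`.
  DiscreteValues fixed;
  DiscreteBayesNet pruned = bayesNet.prune(4, 0.99, &fixed);

  // `fixed` now holds the assignments of the removed dead modes; `pruned`
  // holds at most one joint conditional over the surviving variables.
  return pruned;
}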
@@ -19,7 +19,6 @@
 #include <gtsam/discrete/DiscreteBayesNet.h>
 #include <gtsam/discrete/DiscreteConditional.h>
 #include <gtsam/discrete/DiscreteFactorGraph.h>
-#include <gtsam/discrete/DiscreteMarginals.h>
 #include <gtsam/discrete/TableDistribution.h>
 #include <gtsam/hybrid/HybridBayesNet.h>
 #include <gtsam/hybrid/HybridValues.h>
@@ -43,10 +42,6 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
 }
 
 /* ************************************************************************* */
-// The implementation is: build the entire joint into one factor and then prune.
-// TODO(Frank): This can be quite expensive *unless* the factors have already
-// been pruned before. Another, possibly faster approach is branch and bound
-// search to find the K-best leaves and then create a single pruned conditional.
 HybridBayesNet HybridBayesNet::prune(
     size_t maxNrLeaves, const std::optional<double> &deadModeThreshold) const {
 #if GTSAM_HYBRID_TIMING
@@ -55,63 +50,31 @@ HybridBayesNet HybridBayesNet::prune(
   // Collect all the discrete conditionals. Could be small if already pruned.
   const DiscreteBayesNet marginal = discreteMarginal();
 
+  // Prune discrete Bayes net
+  DiscreteValues fixed;
+  auto prunedBN = marginal.prune(maxNrLeaves, deadModeThreshold, &fixed);
+
   // Multiply into one big conditional. NOTE: possibly quite expensive.
-  DiscreteConditional joint;
-  for (auto &&conditional : marginal) {
-    joint = joint * (*conditional);
+  DiscreteConditional pruned;
+  for (auto &&conditional : prunedBN) pruned = pruned * (*conditional);
+
+  // Set the fixed values if requested.
+  if (deadModeThreshold && fixedValues) {
+    *fixedValues = fixed;
   }
 
-  // Initialize the resulting HybridBayesNet.
   HybridBayesNet result;
 
-  // Prune the joint. NOTE: imperative and, again, possibly quite expensive.
-  DiscreteConditional pruned = joint;
-  pruned.prune(maxNrLeaves);
-
-  DiscreteValues deadModesValues;
-  // If we have a dead mode threshold and discrete variables left after pruning,
-  // then we run dead mode removal.
-  if (deadModeThreshold.has_value() && pruned.keys().size() > 0) {
-#if GTSAM_HYBRID_TIMING
-    gttic_(DeadModeRemoval);
-#endif
-
-    DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
-    for (auto dkey : pruned.discreteKeys()) {
-      Vector probabilities = marginals.marginalProbabilities(dkey);
-
-      int index = -1;
-      auto threshold = (probabilities.array() > *deadModeThreshold);
-      // If atleast 1 value is non-zero, then we can find the index
-      // Else if all are zero, index would be set to 0 which is incorrect
-      if (!threshold.isZero()) {
-        threshold.maxCoeff(&index);
-      }
-
-      if (index >= 0) {
-        deadModesValues.emplace(dkey.first, index);
-      }
-    }
-
-    // Remove the modes (imperative)
-    pruned.removeDiscreteModes(deadModesValues);
-
-#if GTSAM_HYBRID_TIMING
-    gttoc_(DeadModeRemoval);
-#endif
-  }
-
   // Go through all the Gaussian conditionals, restrict them according to
-  // deadModesValues, and then prune further.
-  for (auto &&conditional : *this) {
+  // fixed values, and then prune further.
+  for (std::shared_ptr<gtsam::HybridConditional> conditional : *this) {
     if (conditional->isDiscrete()) continue;
 
-    // Restrict conditional using deadModesValues.
-    // No-op if not a HybridGaussianConditional or deadModesValues empty.
-    auto restricted = conditional->restrict(deadModesValues);
+    // No-op if not a HybridGaussianConditional.
+    if (deadModeThreshold) conditional = conditional->restrict(fixed);
 
     // Now decide on type what to do:
-    if (auto hgc = restricted->asHybrid()) {
+    if (auto hgc = conditional->asHybrid()) {
       // Prune the hybrid Gaussian conditional!
       auto prunedHybridGaussianConditional = hgc->prune(pruned);
       if (!prunedHybridGaussianConditional) {
@@ -120,7 +83,7 @@ HybridBayesNet HybridBayesNet::prune(
       }
       // Type-erase and add to the pruned Bayes Net fragment.
       result.push_back(prunedHybridGaussianConditional);
-    } else if (auto gc = restricted->asGaussian()) {
+    } else if (auto gc = conditional->asGaussian()) {
       // Add the non-HybridGaussianConditional conditional
       result.push_back(gc);
     } else
@@ -128,23 +91,9 @@ HybridBayesNet HybridBayesNet::prune(
           "HybrdiBayesNet::prune: Unknown HybridConditional type.");
   }
 
-#if GTSAM_HYBRID_TIMING
-  gttoc_(HybridPruning);
-#endif
-
-  if (deadModeThreshold.has_value()) {
-    /*
-      If the pruned discrete conditional has any keys left, we add it to the
-      HybridBayesNet. If not, it means it is an orphan so we don't add this
-      pruned joint, and instead add only the marginals below.
-    */
-    if (pruned.keys().size() > 0) {
-      result.emplace_shared<DiscreteConditional>(pruned);
-    }
-
-  } else {
-    result.emplace_shared<DiscreteConditional>(pruned);
-  }
+  // Add the pruned discrete conditionals to the result.
+  for (const DiscreteConditional::shared_ptr &discrete : prunedBN)
+    result.push_back(discrete);
 
   return result;
 }
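With the discrete pruning now delegated to DiscreteBayesNet::prune, the hybrid entry point only restricts and prunes the Gaussian conditionals against the pruned discrete joint. A hedged calling sketch (not part of this commit); the leaf count 8 and the 0.999 threshold are illustrative assumptions.

#include <gtsam/hybrid/HybridBayesNet.h>

using namespace gtsam;

HybridBayesNet pruneHybrid(const HybridBayesNet& hybridBayesNet) {
  // Keep at most 8 leaves in the discrete part; treat modes whose marginal
  // exceeds 0.999 as dead, restricting the Gaussian conditionals accordingly.
  return hybridBayesNet.prune(8, 0.999);
}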