use std::optional for specifying dead mode threshold
parent
d2b9eb5df6
commit
1764b58e8c
|
@ -47,8 +47,8 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
|
||||||
// TODO(Frank): This can be quite expensive *unless* the factors have already
|
// TODO(Frank): This can be quite expensive *unless* the factors have already
|
||||||
// been pruned before. Another, possibly faster approach is branch and bound
|
// been pruned before. Another, possibly faster approach is branch and bound
|
||||||
// search to find the K-best leaves and then create a single pruned conditional.
|
// search to find the K-best leaves and then create a single pruned conditional.
|
||||||
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, bool removeDeadModes,
|
HybridBayesNet HybridBayesNet::prune(
|
||||||
double deadModeThreshold) const {
|
size_t maxNrLeaves, const std::optional<double> &deadModeThreshold) const {
|
||||||
// Collect all the discrete conditionals. Could be small if already pruned.
|
// Collect all the discrete conditionals. Could be small if already pruned.
|
||||||
const DiscreteBayesNet marginal = discreteMarginal();
|
const DiscreteBayesNet marginal = discreteMarginal();
|
||||||
|
|
||||||
|
@ -66,13 +66,13 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, bool removeDeadModes,
|
||||||
pruned.prune(maxNrLeaves);
|
pruned.prune(maxNrLeaves);
|
||||||
|
|
||||||
DiscreteValues deadModesValues;
|
DiscreteValues deadModesValues;
|
||||||
if (removeDeadModes) {
|
if (deadModeThreshold.has_value()) {
|
||||||
DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
|
DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
|
||||||
for (auto dkey : pruned.discreteKeys()) {
|
for (auto dkey : pruned.discreteKeys()) {
|
||||||
Vector probabilities = marginals.marginalProbabilities(dkey);
|
Vector probabilities = marginals.marginalProbabilities(dkey);
|
||||||
|
|
||||||
int index = -1;
|
int index = -1;
|
||||||
auto threshold = (probabilities.array() > deadModeThreshold);
|
auto threshold = (probabilities.array() > *deadModeThreshold);
|
||||||
// If atleast 1 value is non-zero, then we can find the index
|
// If atleast 1 value is non-zero, then we can find the index
|
||||||
// Else if all are zero, index would be set to 0 which is incorrect
|
// Else if all are zero, index would be set to 0 which is incorrect
|
||||||
if (!threshold.isZero()) {
|
if (!threshold.isZero()) {
|
||||||
|
@ -121,7 +121,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, bool removeDeadModes,
|
||||||
// Prune the hybrid Gaussian conditional!
|
// Prune the hybrid Gaussian conditional!
|
||||||
auto prunedHybridGaussianConditional = hgc->prune(pruned);
|
auto prunedHybridGaussianConditional = hgc->prune(pruned);
|
||||||
|
|
||||||
if (removeDeadModes) {
|
if (deadModeThreshold.has_value()) {
|
||||||
KeyVector deadKeys, conditionalDiscreteKeys;
|
KeyVector deadKeys, conditionalDiscreteKeys;
|
||||||
for (const auto &kv : deadModesValues) {
|
for (const auto &kv : deadModesValues) {
|
||||||
deadKeys.push_back(kv.first);
|
deadKeys.push_back(kv.first);
|
||||||
|
|
|
@ -217,15 +217,15 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
* @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves.
|
* @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves.
|
||||||
*
|
*
|
||||||
* @param maxNrLeaves Continuous values at which to compute the error.
|
* @param maxNrLeaves Continuous values at which to compute the error.
|
||||||
* @param removeDeadModes Flag to enable removal of modes which only have a
|
|
||||||
* single possible assignment.
|
|
||||||
* @param deadModeThreshold The threshold to check the mode marginals against.
|
* @param deadModeThreshold The threshold to check the mode marginals against.
|
||||||
* If greater than this threshold, the mode gets assigned that value and is
|
* If greater than this threshold, the mode gets assigned that value and is
|
||||||
* considered "dead" for hybrid elimination.
|
* considered "dead" for hybrid elimination.
|
||||||
|
* The mode can then be removed since it only has a single possible
|
||||||
|
* assignment.
|
||||||
* @return A pruned HybridBayesNet
|
* @return A pruned HybridBayesNet
|
||||||
*/
|
*/
|
||||||
HybridBayesNet prune(size_t maxNrLeaves, bool removeDeadModes = false,
|
HybridBayesNet prune(size_t maxNrLeaves,
|
||||||
double deadModeThreshold = 0.99) const;
|
const std::optional<double> &deadModeThreshold) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Error method using HybridValues which returns specific error for
|
* @brief Error method using HybridValues which returns specific error for
|
||||||
|
|
Loading…
Reference in New Issue