remove dead modes in HybridBayesNet
parent ff9a56c055
commit 22bf9df39a
gtsam/hybrid/HybridBayesNet.cpp
@@ -19,6 +19,7 @@
 #include <gtsam/discrete/DiscreteBayesNet.h>
 #include <gtsam/discrete/DiscreteConditional.h>
 #include <gtsam/discrete/DiscreteFactorGraph.h>
+#include <gtsam/discrete/DiscreteMarginals.h>
 #include <gtsam/discrete/TableDistribution.h>
 #include <gtsam/hybrid/HybridBayesNet.h>
 #include <gtsam/hybrid/HybridValues.h>
@@ -46,7 +47,8 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
 // TODO(Frank): This can be quite expensive *unless* the factors have already
 // been pruned before. Another, possibly faster approach is branch and bound
 // search to find the K-best leaves and then create a single pruned conditional.
-HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
+HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves,
+                                     bool removeDeadModes) const {
   // Collect all the discrete conditionals. Could be small if already pruned.
   const DiscreteBayesNet marginal = discreteMarginal();
 
@@ -66,6 +68,30 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
   // we can prune HybridGaussianConditionals.
   DiscreteConditional pruned = *result.back()->asDiscrete();
 
+  DiscreteValues deadModesValues;
+  if (removeDeadModes) {
+    DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
+    for (auto dkey : pruned.discreteKeys()) {
+      Vector probabilities = marginals.marginalProbabilities(dkey);
+
+      int index = -1;
+      auto threshold = (probabilities.array() > 0.99);
+      // If at least 1 value is non-zero, then we can find the index.
+      // Else, if all are zero, index would be set to 0, which is incorrect.
+      if (!threshold.isZero()) {
+        threshold.maxCoeff(&index);
+      }
+
+      if (index >= 0) {
+        deadModesValues.insert(std::make_pair(dkey.first, index));
+      }
+    }
+
+    // Remove the modes (imperative)
+    result.back()->removeModes(deadModesValues);
+    pruned = *result.back()->asDiscrete();
+  }
+
   /* To prune, we visitWith every leaf in the HybridGaussianConditional.
    * For each leaf, using the assignment we can check the discrete decision tree
    * for 0.0 probability, then just set the leaf to a nullptr.
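As an aside on the threshold logic added above: a discrete mode is treated as dead when a single assignment holds essentially all of its marginal probability mass. Below is an illustrative, self-contained sketch of that idiom in plain Eigen (the probability values and the 3-valued mode are made up; only the `> 0.99` threshold and the guarded arg-max mirror the diff):

// Illustrative sketch only (plain Eigen, outside GTSAM): how the 0.99
// threshold identifies a "dead" discrete mode.
#include <Eigen/Dense>
#include <iostream>

int main() {
  // Hypothetical marginal over a 3-valued discrete mode.
  Eigen::VectorXd probabilities(3);
  probabilities << 0.002, 0.995, 0.003;

  int index = -1;
  // Boolean mask of entries whose marginal exceeds the threshold.
  Eigen::Array<bool, Eigen::Dynamic, 1> aboveThreshold =
      probabilities.array() > 0.99;
  // Only take the arg-max if some entry actually passed the threshold;
  // otherwise index stays -1 and the mode is kept alive.
  if (aboveThreshold.any()) {
    probabilities.maxCoeff(&index);
  }
  std::cout << "dead-mode index: " << index << "\n";  // prints 1
}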
@@ -80,8 +106,28 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
       // Prune the hybrid Gaussian conditional!
       auto prunedHybridGaussianConditional = hgc->prune(pruned);
 
-      // Type-erase and add to the pruned Bayes Net fragment.
-      result.push_back(prunedHybridGaussianConditional);
+      if (removeDeadModes) {
+        KeyVector deadKeys, conditionalDiscreteKeys;
+        for (const auto &kv : deadModesValues) {
+          deadKeys.push_back(kv.first);
+        }
+        for (auto dkey : prunedHybridGaussianConditional->discreteKeys()) {
+          conditionalDiscreteKeys.push_back(dkey.first);
+        }
+        // If the discrete keys in the conditional match the dead-mode keys,
+        // then we just get the corresponding Gaussian conditional.
+        if (deadKeys == conditionalDiscreteKeys) {
+          result.push_back(
+              prunedHybridGaussianConditional->choose(deadModesValues));
+        } else {
+          // Add as-is
+          result.push_back(prunedHybridGaussianConditional);
+        }
+      } else {
+        // Type-erase and add to the pruned Bayes Net fragment.
+        result.push_back(prunedHybridGaussianConditional);
+      }
 
     } else if (auto gc = conditional->asGaussian()) {
       // Add the non-HybridGaussianConditional conditional
       result.push_back(gc);
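One detail worth noting in the branch above: `deadKeys == conditionalDiscreteKeys` is an exact, order-sensitive comparison, so the hybrid conditional is collapsed with `choose()` only when its discrete parents are precisely the dead keys. A small sketch of that check, assuming only the GTSAM container types (the helper name `allParentsDead` is ours, not part of the change):

// Sketch of the key-matching check: collapse with choose() only when every
// discrete parent of the pruned hybrid conditional was found to be dead.
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h>

bool allParentsDead(const gtsam::DiscreteValues &deadModesValues,
                    const gtsam::DiscreteKeys &conditionalKeys) {
  gtsam::KeyVector deadKeys, conditionalDiscreteKeys;
  for (const auto &kv : deadModesValues) deadKeys.push_back(kv.first);
  for (const auto &dkey : conditionalKeys)
    conditionalDiscreteKeys.push_back(dkey.first);
  // Note: operator== on KeyVector compares element-wise in order, so this
  // requires the same keys in the same order, not just the same set.
  return deadKeys == conditionalDiscreteKeys;
}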
gtsam/hybrid/HybridBayesNet.h
@@ -209,9 +209,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
    * @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves.
    *
    * @param maxNrLeaves Continuous values at which to compute the error.
+   * @param removeDeadModes Whether to also remove dead (near-certain) modes.
    * @return A pruned HybridBayesNet
    */
-  HybridBayesNet prune(size_t maxNrLeaves) const;
+  HybridBayesNet prune(size_t maxNrLeaves, bool removeDeadModes = false) const;
 
   /**
    * @brief Error method using HybridValues which returns specific error for
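For context, a minimal usage sketch of the overload declared above (the function name `pruneExample` and the leaf cap of 2 are illustrative; the call signatures come from the declaration):

#include <gtsam/hybrid/HybridBayesNet.h>

// Usage sketch for prune() with the new removeDeadModes flag.
gtsam::HybridBayesNet pruneExample(const gtsam::HybridBayesNet &posterior,
                                   bool collapseDeadModes) {
  if (!collapseDeadModes) {
    // Existing behaviour: only cap the number of discrete leaves.
    return posterior.prune(2);
  }
  // New behaviour: additionally remove modes whose marginal probability is
  // already near-certain, so conditionals that depended only on those modes
  // become plain Gaussian conditionals.
  return posterior.prune(2, /*removeDeadModes=*/true);
}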
gtsam/hybrid/tests/testHybridBayesNet.cpp
@@ -407,7 +407,7 @@ TEST(HybridBayesNet, Prune) {
 
   HybridBayesNet::shared_ptr posterior =
       s.linearizedFactorGraph().eliminateSequential();
-  EXPECT_LONGS_EQUAL(7, posterior->size());
+  EXPECT_LONGS_EQUAL(5, posterior->size());
 
   // Call Max-Product to get MAP
   HybridValues delta = posterior->optimize();
@@ -421,6 +421,35 @@
   EXPECT(assert_equal(delta.continuous(), pruned_delta.continuous()));
 }
 
+/* ****************************************************************************/
+// Test Bayes net pruning and dead node removal
+TEST(HybridBayesNet, RemoveDeadNodes) {
+  Switching s(3);
+
+  HybridBayesNet::shared_ptr posterior =
+      s.linearizedFactorGraph().eliminateSequential();
+  EXPECT_LONGS_EQUAL(5, posterior->size());
+
+  // Call Max-Product to get MAP
+  HybridValues delta = posterior->optimize();
+
+  // Prune the Bayes net
+  const bool pruneDeadVariables = true;
+  auto prunedBayesNet = posterior->prune(2, pruneDeadVariables);
+
+  // Check that the discrete joint only has M0 and not (M0, M1),
+  // since M1 is removed
+  KeyVector actual_keys = prunedBayesNet.at(0)->asDiscrete()->keys();
+  EXPECT(KeyVector{M(0)} == actual_keys);
+
+  // Check that hybrid conditionals that only depend on M1 are no longer hybrid
+  EXPECT(prunedBayesNet.at(0)->isDiscrete());
+  EXPECT(prunedBayesNet.at(1)->isHybrid());
+  // Only P(X2 | X1, M1) depends on M1, so it is Gaussian
+  EXPECT(prunedBayesNet.at(2)->isContinuous());
+  EXPECT(prunedBayesNet.at(3)->isHybrid());
+}
+
 /* ****************************************************************************/
 // Test Bayes net error and log-probability after pruning
 TEST(HybridBayesNet, ErrorAfterPruning) {