Merge pull request #1999 from borglab/fix-deadmoderemoval

release/4.3a0
Varun Agrawal 2025-01-25 00:53:43 -05:00 committed by GitHub
commit 5a3005dcc2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 102 additions and 30 deletions

View File

@ -47,8 +47,8 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
// TODO(Frank): This can be quite expensive *unless* the factors have already // TODO(Frank): This can be quite expensive *unless* the factors have already
// been pruned before. Another, possibly faster approach is branch and bound // been pruned before. Another, possibly faster approach is branch and bound
// search to find the K-best leaves and then create a single pruned conditional. // search to find the K-best leaves and then create a single pruned conditional.
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, HybridBayesNet HybridBayesNet::prune(
bool removeDeadModes) const { size_t maxNrLeaves, const std::optional<double> &deadModeThreshold) const {
// Collect all the discrete conditionals. Could be small if already pruned. // Collect all the discrete conditionals. Could be small if already pruned.
const DiscreteBayesNet marginal = discreteMarginal(); const DiscreteBayesNet marginal = discreteMarginal();
@ -58,24 +58,21 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves,
joint = joint * (*conditional); joint = joint * (*conditional);
} }
// Create the result starting with the pruned joint. // Initialize the resulting HybridBayesNet.
HybridBayesNet result; HybridBayesNet result;
result.emplace_shared<DiscreteConditional>(joint);
// Prune the joint. NOTE: imperative and, again, possibly quite expensive.
result.back()->asDiscrete()->prune(maxNrLeaves);
// Get pruned discrete probabilities so // Prune the joint. NOTE: imperative and, again, possibly quite expensive.
// we can prune HybridGaussianConditionals. DiscreteConditional pruned = joint;
DiscreteConditional pruned = *result.back()->asDiscrete(); pruned.prune(maxNrLeaves);
DiscreteValues deadModesValues; DiscreteValues deadModesValues;
if (removeDeadModes) { if (deadModeThreshold.has_value()) {
DiscreteMarginals marginals(DiscreteFactorGraph{pruned}); DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
for (auto dkey : pruned.discreteKeys()) { for (auto dkey : pruned.discreteKeys()) {
Vector probabilities = marginals.marginalProbabilities(dkey); Vector probabilities = marginals.marginalProbabilities(dkey);
int index = -1; int index = -1;
auto threshold = (probabilities.array() > 0.99); auto threshold = (probabilities.array() > *deadModeThreshold);
// If at least 1 value is non-zero, then we can find the index // If at least 1 value is non-zero, then we can find the index
// Else if all are zero, index would be set to 0 which is incorrect // Else if all are zero, index would be set to 0 which is incorrect
if (!threshold.isZero()) { if (!threshold.isZero()) {
@ -88,8 +85,26 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves,
} }
// Remove the modes (imperative) // Remove the modes (imperative)
result.back()->asDiscrete()->removeDiscreteModes(deadModesValues); pruned.removeDiscreteModes(deadModesValues);
pruned = *result.back()->asDiscrete();
/*
If the pruned discrete conditional has any keys left,
we add it to the HybridBayesNet.
If not, it means it is an orphan so we don't add this pruned joint,
and instead add only the marginals below.
*/
if (pruned.keys().size() > 0) {
result.emplace_shared<DiscreteConditional>(pruned);
}
// Add the marginals for future factors
for (auto &&[key, _] : deadModesValues) {
result.push_back(
std::dynamic_pointer_cast<DiscreteConditional>(marginals(key)));
}
} else {
result.emplace_shared<DiscreteConditional>(pruned);
} }
/* To prune, we visitWith every leaf in the HybridGaussianConditional. /* To prune, we visitWith every leaf in the HybridGaussianConditional.
@ -100,13 +115,13 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves,
*/ */
// Go through all the Gaussian conditionals in the Bayes Net and prune them as // Go through all the Gaussian conditionals in the Bayes Net and prune them as
// per pruned Discrete joint. // per pruned discrete joint.
for (auto &&conditional : *this) { for (auto &&conditional : *this) {
if (auto hgc = conditional->asHybrid()) { if (auto hgc = conditional->asHybrid()) {
// Prune the hybrid Gaussian conditional! // Prune the hybrid Gaussian conditional!
auto prunedHybridGaussianConditional = hgc->prune(pruned); auto prunedHybridGaussianConditional = hgc->prune(pruned);
if (removeDeadModes) { if (deadModeThreshold.has_value()) {
KeyVector deadKeys, conditionalDiscreteKeys; KeyVector deadKeys, conditionalDiscreteKeys;
for (const auto &kv : deadModesValues) { for (const auto &kv : deadModesValues) {
deadKeys.push_back(kv.first); deadKeys.push_back(kv.first);
@ -186,6 +201,7 @@ DiscreteValues HybridBayesNet::mpe() const {
} }
} }
} }
return discrete_fg.optimize(); return discrete_fg.optimize();
} }

View File

@ -217,11 +217,16 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves. * @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves.
* *
* @param maxNrLeaves Continuous values at which to compute the error. * @param maxNrLeaves Continuous values at which to compute the error.
* @param removeDeadModes Flag to enable removal of modes which only have a * @param deadModeThreshold The threshold to check the mode marginals against.
* single possible assignment. * If greater than this threshold, the mode gets assigned that value and is
* considered "dead" for hybrid elimination.
* The mode can then be removed since it only has a single possible
* assignment.
* @return A pruned HybridBayesNet * @return A pruned HybridBayesNet
*/ */
HybridBayesNet prune(size_t maxNrLeaves, bool removeDeadModes = false) const; HybridBayesNet prune(
size_t maxNrLeaves,
const std::optional<double> &deadModeThreshold = {}) const;
/** /**
* @brief Error method using HybridValues which returns specific error for * @brief Error method using HybridValues which returns specific error for

View File

@ -81,7 +81,7 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph,
if (maxNrLeaves) { if (maxNrLeaves) {
// `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in // `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in
// all the conditionals with the same keys in bayesNetFragment. // all the conditionals with the same keys in bayesNetFragment.
bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, removeDeadModes_); bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, deadModeThreshold_);
} }
// Add the partial bayes net to the posterior bayes net. // Add the partial bayes net to the posterior bayes net.

View File

@ -29,10 +29,8 @@ class GTSAM_EXPORT HybridSmoother {
HybridBayesNet hybridBayesNet_; HybridBayesNet hybridBayesNet_;
HybridGaussianFactorGraph remainingFactorGraph_; HybridGaussianFactorGraph remainingFactorGraph_;
/// Flag indicating that we should remove dead discrete modes.
bool removeDeadModes_;
/// The threshold above which we make a decision about a mode. /// The threshold above which we make a decision about a mode.
double deadModeThreshold_; std::optional<double> deadModeThreshold_;
public: public:
/** /**
@ -40,11 +38,10 @@ class GTSAM_EXPORT HybridSmoother {
* *
* @param removeDeadModes Flag indicating whether to remove dead modes. * @param removeDeadModes Flag indicating whether to remove dead modes.
* @param deadModeThreshold The threshold above which a mode gets assigned a * @param deadModeThreshold The threshold above which a mode gets assigned a
* value and is considered "dead". * value and is considered "dead". 0.99 is a good starting value.
*/ */
HybridSmoother(bool removeDeadModes = false, double deadModeThreshold = 0.99) HybridSmoother(const std::optional<double> deadModeThreshold = {})
: removeDeadModes_(removeDeadModes), : deadModeThreshold_(deadModeThreshold) {}
deadModeThreshold_(deadModeThreshold) {}
/** /**
* Given new factors, perform an incremental update. * Given new factors, perform an incremental update.

View File

@ -434,7 +434,7 @@ TEST(HybridBayesNet, RemoveDeadNodes) {
HybridValues delta = posterior->optimize(); HybridValues delta = posterior->optimize();
// Prune the Bayes net // Prune the Bayes net
const bool pruneDeadVariables = true; const double pruneDeadVariables = 0.99;
auto prunedBayesNet = posterior->prune(2, pruneDeadVariables); auto prunedBayesNet = posterior->prune(2, pruneDeadVariables);
// Check that discrete joint only has M0 and not (M0, M1) // Check that discrete joint only has M0 and not (M0, M1)
@ -445,11 +445,12 @@ TEST(HybridBayesNet, RemoveDeadNodes) {
// Check that hybrid conditionals that only depend on M1 // Check that hybrid conditionals that only depend on M1
// are now Gaussian and not Hybrid // are now Gaussian and not Hybrid
EXPECT(prunedBayesNet.at(0)->isDiscrete()); EXPECT(prunedBayesNet.at(0)->isDiscrete());
EXPECT(prunedBayesNet.at(1)->isHybrid()); EXPECT(prunedBayesNet.at(1)->isDiscrete());
EXPECT(prunedBayesNet.at(2)->isHybrid());
// Only P(X2 | X1, M1) depends on M1, // Only P(X2 | X1, M1) depends on M1,
// so it gets converted to a Gaussian P(X2 | X1) // so it gets converted to a Gaussian P(X2 | X1)
EXPECT(prunedBayesNet.at(2)->isContinuous()); EXPECT(prunedBayesNet.at(3)->isContinuous());
EXPECT(prunedBayesNet.at(3)->isHybrid()); EXPECT(prunedBayesNet.at(4)->isHybrid());
} }
/* ****************************************************************************/ /* ****************************************************************************/

View File

@ -167,6 +167,59 @@ TEST(HybridSmoother, ValidPruningError) {
EXPECT_DOUBLES_EQUAL(1e-8, errorTree(delta.discrete()), 1e-8); EXPECT_DOUBLES_EQUAL(1e-8, errorTree(delta.discrete()), 1e-8);
} }
/****************************************************************************/
// Test if dead mode removal works.
TEST(HybridSmoother, DeadModeRemoval) {
using namespace estimation_fixture;
size_t K = 8;
// Switching example of robot moving in 1D
// with given measurements and equal mode priors.
HybridNonlinearFactorGraph graph;
Values initial;
Switching switching = InitializeEstimationProblem(
K, 0.1, 0.1, measurements, "1/1 1/1", &graph, &initial);
// Smoother with dead mode removal enabled (threshold 0.99).
// NOTE(review): passing `true` here implicitly converts bool -> double -> 1.0
// inside std::optional<double>, and the strict `> *deadModeThreshold` check
// can then never fire, silently disabling dead-mode removal.
HybridSmoother smoother(0.99);
constexpr size_t maxNrLeaves = 3;
for (size_t k = 1; k < K; k++) {
if (k > 1) graph.push_back(switching.modeChain.at(k - 1)); // Mode chain
graph.push_back(switching.binaryFactors.at(k - 1)); // Motion Model
graph.push_back(switching.unaryFactors.at(k)); // Measurement
initial.insert(X(k), switching.linearizationPoint.at<double>(X(k)));
HybridGaussianFactorGraph linearized = *graph.linearize(initial);
smoother.update(linearized, maxNrLeaves);
// Clear all the factors from the graph
graph.resize(0);
}
// Get the continuous delta update as well as
// the optimal discrete assignment.
HybridValues delta = smoother.hybridBayesNet().optimize();
// Check discrete assignment
DiscreteValues expected_discrete;
for (size_t k = 0; k < K - 1; k++) {
expected_discrete[M(k)] = discrete_seq[k];
}
EXPECT(assert_equal(expected_discrete, delta.discrete()));
// Update nonlinear solution and verify
Values result = initial.retract(delta.continuous());
Values expected_continuous;
for (size_t k = 0; k < K; k++) {
expected_continuous.insert(X(k), measurements[k]);
}
EXPECT(assert_equal(expected_continuous, result));
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;