Merge branch 'hybrid-smoother' into city10000

release/4.3a0
Varun Agrawal 2025-01-25 00:55:04 -05:00
commit 5cee0a2d34
6 changed files with 118 additions and 35 deletions

View File

@ -47,8 +47,8 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
// TODO(Frank): This can be quite expensive *unless* the factors have already
// been pruned before. Another, possibly faster approach is branch and bound
// search to find the K-best leaves and then create a single pruned conditional.
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves,
bool removeDeadModes) const {
HybridBayesNet HybridBayesNet::prune(
size_t maxNrLeaves, const std::optional<double> &deadModeThreshold) const {
// Collect all the discrete conditionals. Could be small if already pruned.
const DiscreteBayesNet marginal = discreteMarginal();
@ -58,25 +58,21 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves,
joint = joint * (*conditional);
}
// Prune the joint. NOTE: imperative and, again, possibly quite expensive.
joint.prune(maxNrLeaves);
// Create the result starting with the pruned joint.
// Initialize the resulting HybridBayesNet.
HybridBayesNet result;
result.emplace_shared<DiscreteConditional>(joint);
// Get pruned discrete probabilities so
// we can prune HybridGaussianConditionals.
DiscreteConditional pruned = *result.back()->asDiscrete();
// Prune the joint. NOTE: imperative and, again, possibly quite expensive.
DiscreteConditional pruned = joint;
pruned.prune(maxNrLeaves);
DiscreteValues deadModesValues;
if (removeDeadModes) {
if (deadModeThreshold.has_value()) {
DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
for (auto dkey : pruned.discreteKeys()) {
Vector probabilities = marginals.marginalProbabilities(dkey);
int index = -1;
auto threshold = (probabilities.array() > 0.99);
auto threshold = (probabilities.array() > *deadModeThreshold);
// If atleast 1 value is non-zero, then we can find the index
// Else if all are zero, index would be set to 0 which is incorrect
if (!threshold.isZero()) {
@ -89,8 +85,26 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves,
}
// Remove the modes (imperative)
result.back()->asDiscrete()->removeDiscreteModes(deadModesValues);
pruned = *result.back()->asDiscrete();
pruned.removeDiscreteModes(deadModesValues);
/*
If the pruned discrete conditional has any keys left,
we add it to the HybridBayesNet.
If not, it means it is an orphan so we don't add this pruned joint,
and instead add only the marginals below.
*/
if (pruned.keys().size() > 0) {
result.emplace_shared<DiscreteConditional>(pruned);
}
// Add the marginals for future factors
for (auto &&[key, _] : deadModesValues) {
result.push_back(
std::dynamic_pointer_cast<DiscreteConditional>(marginals(key)));
}
} else {
result.emplace_shared<DiscreteConditional>(pruned);
}
/* To prune, we visitWith every leaf in the HybridGaussianConditional.
@ -101,13 +115,13 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves,
*/
// Go through all the Gaussian conditionals in the Bayes Net and prune them as
// per pruned Discrete joint.
// per pruned discrete joint.
for (auto &&conditional : *this) {
if (auto hgc = conditional->asHybrid()) {
// Prune the hybrid Gaussian conditional!
auto prunedHybridGaussianConditional = hgc->prune(pruned);
if (removeDeadModes) {
if (deadModeThreshold.has_value()) {
KeyVector deadKeys, conditionalDiscreteKeys;
for (const auto &kv : deadModesValues) {
deadKeys.push_back(kv.first);

View File

@ -217,11 +217,16 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves.
*
* @param maxNrLeaves Continuous values at which to compute the error.
* @param removeDeadModes Flag to enable removal of modes which only have a
* single possible assignment.
* @param deadModeThreshold The threshold to check the mode marginals against.
* If greater than this threshold, the mode gets assigned that value and is
* considered "dead" for hybrid elimination.
* The mode can then be removed since it only has a single possible
* assignment.
* @return A pruned HybridBayesNet
*/
HybridBayesNet prune(size_t maxNrLeaves, bool removeDeadModes = false) const;
HybridBayesNet prune(
size_t maxNrLeaves,
const std::optional<double> &deadModeThreshold = {}) const;
/**
* @brief Error method using HybridValues which returns specific error for

View File

@ -81,7 +81,7 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph,
if (maxNrLeaves) {
// `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in
// all the conditionals with the same keys in bayesNetFragment.
bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves);
bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, deadModeThreshold_);
}
// Add the partial bayes net to the posterior bayes net.
@ -116,6 +116,14 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
factorKeys.end()) {
newConditionals.push_back(conditional);
// Add the conditional parents to factorKeys
// so we add those conditionals too.
// NOTE: This assumes we have a structure where
// variables depend on those in the future.
for (auto &&parentKey : conditional->parents()) {
factorKeys.insert(parentKey);
}
// Remove the conditional from the updated Bayes net
auto it = find(updatedHybridBayesNet.begin(),
updatedHybridBayesNet.end(), conditional);

View File

@ -29,7 +29,20 @@ class GTSAM_EXPORT HybridSmoother {
HybridBayesNet hybridBayesNet_;
HybridGaussianFactorGraph remainingFactorGraph_;
/// The threshold above which we make a decision about a mode.
std::optional<double> deadModeThreshold_;
public:
/**
* @brief Constructor
*
* @param removeDeadModes Flag indicating whether to remove dead modes.
* @param deadModeThreshold The threshold above which a mode gets assigned a
* value and is considered "dead". 0.99 is a good starting value.
*/
HybridSmoother(const std::optional<double> deadModeThreshold = {})
: deadModeThreshold_(deadModeThreshold) {}
/**
* Given new factors, perform an incremental update.
* The relevant densities in the `hybridBayesNet` will be added to the input
@ -53,17 +66,6 @@ class GTSAM_EXPORT HybridSmoother {
std::optional<size_t> maxNrLeaves = {},
const std::optional<Ordering> given_ordering = {});
/**
* @brief Get an elimination ordering which eliminates continuous and then
* discrete.
*
* Expects `newFactors` to already have the necessary conditionals connected
* to the
*
* @param factors
* @return Ordering
*/
/**
* @brief Get an elimination ordering which eliminates continuous
* and then discrete.

View File

@ -434,7 +434,7 @@ TEST(HybridBayesNet, RemoveDeadNodes) {
HybridValues delta = posterior->optimize();
// Prune the Bayes net
const bool pruneDeadVariables = true;
const double pruneDeadVariables = 0.99;
auto prunedBayesNet = posterior->prune(2, pruneDeadVariables);
// Check that discrete joint only has M0 and not (M0, M1)
@ -445,11 +445,12 @@ TEST(HybridBayesNet, RemoveDeadNodes) {
// Check that hybrid conditionals that only depend on M1
// are now Gaussian and not Hybrid
EXPECT(prunedBayesNet.at(0)->isDiscrete());
EXPECT(prunedBayesNet.at(1)->isHybrid());
EXPECT(prunedBayesNet.at(1)->isDiscrete());
EXPECT(prunedBayesNet.at(2)->isHybrid());
// Only P(X2 | X1, M1) depends on M1,
// so it gets convert to a Gaussian P(X2 | X1)
EXPECT(prunedBayesNet.at(2)->isContinuous());
EXPECT(prunedBayesNet.at(3)->isHybrid());
EXPECT(prunedBayesNet.at(3)->isContinuous());
EXPECT(prunedBayesNet.at(4)->isHybrid());
}
/* ****************************************************************************/

View File

@ -167,6 +167,59 @@ TEST(HybridSmoother, ValidPruningError) {
EXPECT_DOUBLES_EQUAL(1e-8, errorTree(delta.discrete()), 1e-8);
}
/****************************************************************************/
// Test if dead mode removal works.
TEST(HybridSmoother, DeadModeRemoval) {
using namespace estimation_fixture;
size_t K = 8;
// Switching example of robot moving in 1D
// with given measurements and equal mode priors.
HybridNonlinearFactorGraph graph;
Values initial;
Switching switching = InitializeEstimationProblem(
K, 0.1, 0.1, measurements, "1/1 1/1", &graph, &initial);
// Smoother with dead mode removal enabled.
HybridSmoother smoother(true);
constexpr size_t maxNrLeaves = 3;
for (size_t k = 1; k < K; k++) {
if (k > 1) graph.push_back(switching.modeChain.at(k - 1)); // Mode chain
graph.push_back(switching.binaryFactors.at(k - 1)); // Motion Model
graph.push_back(switching.unaryFactors.at(k)); // Measurement
initial.insert(X(k), switching.linearizationPoint.at<double>(X(k)));
HybridGaussianFactorGraph linearized = *graph.linearize(initial);
smoother.update(linearized, maxNrLeaves);
// Clear all the factors from the graph
graph.resize(0);
}
// Get the continuous delta update as well as
// the optimal discrete assignment.
HybridValues delta = smoother.hybridBayesNet().optimize();
// Check discrete assignment
DiscreteValues expected_discrete;
for (size_t k = 0; k < K - 1; k++) {
expected_discrete[M(k)] = discrete_seq[k];
}
EXPECT(assert_equal(expected_discrete, delta.discrete()));
// Update nonlinear solution and verify
Values result = initial.retract(delta.continuous());
Values expected_continuous;
for (size_t k = 0; k < K; k++) {
expected_continuous.insert(X(k), measurements[k]);
}
EXPECT(assert_equal(expected_continuous, result));
}
/* ************************************************************************* */
int main() {
TestResult tr;