Merge pull request #1996 from borglab/hybrid-smoother

release/4.3a0
Varun Agrawal 2025-01-25 11:19:54 -05:00 committed by GitHub
commit 3302ad46c8
6 changed files with 168 additions and 62 deletions


@@ -47,8 +47,8 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
 // TODO(Frank): This can be quite expensive *unless* the factors have already
 // been pruned before. Another, possibly faster approach is branch and bound
 // search to find the K-best leaves and then create a single pruned conditional.
-HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves,
-                                     bool removeDeadModes) const {
+HybridBayesNet HybridBayesNet::prune(
+    size_t maxNrLeaves, const std::optional<double> &deadModeThreshold) const {
   // Collect all the discrete conditionals. Could be small if already pruned.
   const DiscreteBayesNet marginal = discreteMarginal();
@@ -58,24 +58,21 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves,
     joint = joint * (*conditional);
   }

-  // Create the result starting with the pruned joint.
+  // Initialize the resulting HybridBayesNet.
   HybridBayesNet result;
-  result.emplace_shared<DiscreteConditional>(joint);
-  // Prune the joint. NOTE: imperative and, again, possibly quite expensive.
-  result.back()->asDiscrete()->prune(maxNrLeaves);
-  // Get pruned discrete probabilities so
-  // we can prune HybridGaussianConditionals.
-  DiscreteConditional pruned = *result.back()->asDiscrete();
+  // Prune the joint. NOTE: imperative and, again, possibly quite expensive.
+  DiscreteConditional pruned = joint;
+  pruned.prune(maxNrLeaves);

   DiscreteValues deadModesValues;
-  if (removeDeadModes) {
+  if (deadModeThreshold.has_value()) {
     DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
     for (auto dkey : pruned.discreteKeys()) {
       Vector probabilities = marginals.marginalProbabilities(dkey);
       int index = -1;
-      auto threshold = (probabilities.array() > 0.99);
+      auto threshold = (probabilities.array() > *deadModeThreshold);
       // If at least one value is non-zero, we can find the index.
       // Otherwise, if all were zero, the index would incorrectly stay at 0.
       if (!threshold.isZero()) {
@@ -88,8 +85,26 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves,
     }

     // Remove the modes (imperative)
-    result.back()->asDiscrete()->removeDiscreteModes(deadModesValues);
-    pruned = *result.back()->asDiscrete();
+    pruned.removeDiscreteModes(deadModesValues);
+
+    /*
+      If the pruned discrete conditional has any keys left,
+      we add it to the HybridBayesNet.
+      If not, it means it is an orphan so we don't add this pruned joint,
+      and instead add only the marginals below.
+    */
+    if (pruned.keys().size() > 0) {
+      result.emplace_shared<DiscreteConditional>(pruned);
+    }
+
+    // Add the marginals for future factors
+    for (auto &&[key, _] : deadModesValues) {
+      result.push_back(
+          std::dynamic_pointer_cast<DiscreteConditional>(marginals(key)));
+    }
+  } else {
+    result.emplace_shared<DiscreteConditional>(pruned);
   }

   /* To prune, we visitWith every leaf in the HybridGaussianConditional.
@@ -100,13 +115,13 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves,
   */

   // Go through all the Gaussian conditionals in the Bayes Net and prune them as
-  // per pruned Discrete joint.
+  // per pruned discrete joint.
   for (auto &&conditional : *this) {
     if (auto hgc = conditional->asHybrid()) {
       // Prune the hybrid Gaussian conditional!
       auto prunedHybridGaussianConditional = hgc->prune(pruned);

-      if (removeDeadModes) {
+      if (deadModeThreshold.has_value()) {
         KeyVector deadKeys, conditionalDiscreteKeys;
         for (const auto &kv : deadModesValues) {
           deadKeys.push_back(kv.first);
@@ -186,6 +201,7 @@ DiscreteValues HybridBayesNet::mpe() const {
       }
     }
   }
+  return discrete_fg.optimize();
 }
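
For readers following the new dead-mode logic above: after pruning the discrete joint, the code computes each mode's marginal and fixes any mode whose marginal exceeds `deadModeThreshold`. A minimal sketch of just that check, using the same GTSAM calls that appear in the diff (`FindDeadModes` is a hypothetical helper, not part of this PR):

```cpp
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteMarginals.h>

using namespace gtsam;

// Collect assignments for "dead" modes: modes whose marginal probability
// exceeds `threshold` and can therefore be fixed to a single value.
DiscreteValues FindDeadModes(const DiscreteConditional &pruned,
                             double threshold) {
  DiscreteValues deadModesValues;
  DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
  for (auto dkey : pruned.discreteKeys()) {
    const Vector probabilities = marginals.marginalProbabilities(dkey);
    for (int i = 0; i < probabilities.size(); ++i) {
      if (probabilities(i) > threshold) {
        deadModesValues[dkey.first] = i;  // fix the mode to this value
        break;
      }
    }
  }
  return deadModesValues;
}
```

With a threshold like 0.99 this only fires when the posterior is essentially certain about a mode, which is why the tests below use that value.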


@@ -217,11 +217,16 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
   /**
    * @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves.
    *
    * @param maxNrLeaves The maximum number of leaves to keep after pruning.
-   * @param removeDeadModes Flag to enable removal of modes which only have a
-   * single possible assignment.
+   * @param deadModeThreshold The threshold to check the mode marginals
+   * against. If a mode's marginal probability exceeds this threshold, the
+   * mode is assigned that value and is considered "dead" for hybrid
+   * elimination. The mode can then be removed, since it only has a single
+   * possible assignment.
    * @return A pruned HybridBayesNet
    */
-  HybridBayesNet prune(size_t maxNrLeaves, bool removeDeadModes = false) const;
+  HybridBayesNet prune(
+      size_t maxNrLeaves,
+      const std::optional<double> &deadModeThreshold = {}) const;

   /**
    * @brief Error method using HybridValues which returns specific error for
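
A hedged usage sketch for the new overload (the leaf bound of 10 and the 0.99 threshold are illustrative values, not defaults from the PR):

```cpp
#include <gtsam/hybrid/HybridBayesNet.h>

using namespace gtsam;

// Keep at most 10 leaves, and fix any mode whose marginal exceeds 0.99.
HybridBayesNet PruneWithDeadModeRemoval(const HybridBayesNet &posterior) {
  return posterior.prune(10, 0.99);
}

// Omitting the threshold reproduces the old behavior: prune only.
HybridBayesNet PruneOnly(const HybridBayesNet &posterior) {
  return posterior.prune(10);
}
```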


@@ -24,17 +24,14 @@
 namespace gtsam {

 /* ************************************************************************* */
-Ordering HybridSmoother::getOrdering(
-    const HybridGaussianFactorGraph &newFactors) {
-  HybridGaussianFactorGraph factors(hybridBayesNet());
-  factors.push_back(newFactors);
+Ordering HybridSmoother::getOrdering(const HybridGaussianFactorGraph &factors,
+                                     const KeySet &newFactorKeys) {
   // Get all the discrete keys from the factors
   KeySet allDiscrete = factors.discreteKeySet();

   // Create KeyVector with continuous keys followed by discrete keys.
   KeyVector newKeysDiscreteLast;
-  const KeySet newFactorKeys = newFactors.keys();
   // Insert continuous keys first.
   for (auto &k : newFactorKeys) {
     if (!allDiscrete.exists(k)) {
@@ -56,29 +53,35 @@ Ordering HybridSmoother::getOrdering(
 }

 /* ************************************************************************* */
-void HybridSmoother::update(HybridGaussianFactorGraph graph,
+void HybridSmoother::update(const HybridGaussianFactorGraph &graph,
                             std::optional<size_t> maxNrLeaves,
                             const std::optional<Ordering> given_ordering) {
+  HybridGaussianFactorGraph updatedGraph;
+  // Add the necessary conditionals from the previous timestep(s).
+  std::tie(updatedGraph, hybridBayesNet_) =
+      addConditionals(graph, hybridBayesNet_);
+
   Ordering ordering;
   // If no ordering provided, then we compute one
   if (!given_ordering.has_value()) {
-    ordering = this->getOrdering(graph);
+    // Get the keys from the new factors
+    const KeySet newFactorKeys = graph.keys();
+    // Since updatedGraph now has all the connected conditionals,
+    // we can get the correct ordering.
+    ordering = this->getOrdering(updatedGraph, newFactorKeys);
   } else {
     ordering = *given_ordering;
   }

-  // Add the necessary conditionals from the previous timestep(s).
-  std::tie(graph, hybridBayesNet_) =
-      addConditionals(graph, hybridBayesNet_, ordering);
-
   // Eliminate.
-  HybridBayesNet bayesNetFragment = *graph.eliminateSequential(ordering);
+  HybridBayesNet bayesNetFragment = *updatedGraph.eliminateSequential(ordering);

   /// Prune
   if (maxNrLeaves) {
     // `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in
     // all the conditionals with the same keys in bayesNetFragment.
-    bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves);
+    bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, deadModeThreshold_);
   }

   // Add the partial bayes net to the posterior bayes net.
@@ -88,10 +91,11 @@ void HybridSmoother::update(HybridGaussianFactorGraph graph,
 /* ************************************************************************* */
 std::pair<HybridGaussianFactorGraph, HybridBayesNet>
 HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
-                                const HybridBayesNet &originalHybridBayesNet,
-                                const Ordering &ordering) const {
+                                const HybridBayesNet &hybridBayesNet) const {
   HybridGaussianFactorGraph graph(originalGraph);
-  HybridBayesNet hybridBayesNet(originalHybridBayesNet);
+  HybridBayesNet updatedHybridBayesNet(hybridBayesNet);
+
+  KeySet factorKeys = graph.keys();

   // If hybridBayesNet is not empty,
   // it means we have conditionals to add to the factor graph.
@@ -99,10 +103,6 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
     // We add all relevant hybrid conditionals on the last continuous variable
     // in the previous `hybridBayesNet` to the graph

-    // Conditionals to remove from the bayes net
-    // since the conditional will be updated.
-    std::vector<HybridConditional::shared_ptr> conditionals_to_erase;
-
     // New conditionals to add to the graph
     gtsam::HybridBayesNet newConditionals;
@@ -112,25 +112,32 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
       auto conditional = hybridBayesNet.at(i);

       for (auto &key : conditional->frontals()) {
-        if (std::find(ordering.begin(), ordering.end(), key) !=
-            ordering.end()) {
+        if (std::find(factorKeys.begin(), factorKeys.end(), key) !=
+            factorKeys.end()) {
           newConditionals.push_back(conditional);
-          conditionals_to_erase.push_back(conditional);

+          // Add the conditional parents to factorKeys
+          // so we add those conditionals too.
+          // NOTE: This assumes we have a structure where
+          // variables depend on those in the future.
+          for (auto &&parentKey : conditional->parents()) {
+            factorKeys.insert(parentKey);
+          }
+
+          // Remove the conditional from the updated Bayes net
+          auto it = find(updatedHybridBayesNet.begin(),
+                         updatedHybridBayesNet.end(), conditional);
+          updatedHybridBayesNet.erase(it);
+
+          break;
         }
       }
     }

-    // Remove conditionals at the end so we don't affect the order in the
-    // original bayes net.
-    for (auto &&conditional : conditionals_to_erase) {
-      auto it = find(hybridBayesNet.begin(), hybridBayesNet.end(), conditional);
-      hybridBayesNet.erase(it);
-    }
-
     graph.push_back(newConditionals);
   }
-  return {graph, hybridBayesNet};
+  return {graph, updatedHybridBayesNet};
 }
/* ************************************************************************* */
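
The rewritten `addConditionals` no longer takes an elimination ordering. Instead it matches conditionals against the keys of the new factors and grows that key set with each matched conditional's parents, so conditionals from earlier timesteps are pulled in transitively. A standalone sketch of that closure using plain containers (`MiniConditional` and `SelectConditionals` are illustrative stand-ins, not GTSAM types):

```cpp
#include <set>
#include <vector>

using Key = unsigned long;

struct MiniConditional {
  std::vector<Key> frontals;
  std::vector<Key> parents;
};

// Select every conditional whose frontal key appears in the (growing)
// key set, inserting its parents so that conditionals on earlier
// timesteps get picked up too. Mirrors the NOTE above: this assumes
// variables only depend on variables that appear later.
std::vector<MiniConditional> SelectConditionals(
    const std::vector<MiniConditional> &bayesNet, std::set<Key> factorKeys) {
  std::vector<MiniConditional> selected;
  for (const auto &conditional : bayesNet) {
    for (Key key : conditional.frontals) {
      if (factorKeys.count(key)) {
        selected.push_back(conditional);
        factorKeys.insert(conditional.parents.begin(),
                          conditional.parents.end());
        break;  // one frontal match is enough, as in the new code
      }
    }
  }
  return selected;
}
```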


@@ -29,7 +29,20 @@ class GTSAM_EXPORT HybridSmoother {
   HybridBayesNet hybridBayesNet_;
   HybridGaussianFactorGraph remainingFactorGraph_;

+  /// The threshold above which we make a decision about a mode.
+  std::optional<double> deadModeThreshold_;
+
  public:
+  /**
+   * @brief Constructor
+   *
+   * @param deadModeThreshold The threshold above which a mode gets assigned a
+   * value and is considered "dead". 0.99 is a good starting value.
+   */
+  HybridSmoother(const std::optional<double> deadModeThreshold = {})
+      : deadModeThreshold_(deadModeThreshold) {}
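
Putting the new constructor and `update` signature together, a typical smoothing loop looks roughly as follows (a sketch only: `factorsAtStep` is a hypothetical source of linearized factors, and the 0.99 threshold follows the docstring's suggestion):

```cpp
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/hybrid/HybridSmoother.h>
#include <gtsam/hybrid/HybridValues.h>

using namespace gtsam;

HybridGaussianFactorGraph factorsAtStep(size_t k);  // hypothetical helper

void RunSmoother(size_t K) {
  HybridSmoother smoother(0.99);  // enable dead-mode removal
  constexpr size_t maxNrLeaves = 3;
  for (size_t k = 1; k < K; k++) {
    // No explicit Ordering needed anymore; update() computes one
    // from the new factor keys and the connected conditionals.
    smoother.update(factorsAtStep(k), maxNrLeaves);
  }
  // Continuous delta and the optimal discrete assignment.
  HybridValues delta = smoother.hybridBayesNet().optimize();
}
```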
   /**
    * Given new factors, perform an incremental update.
    * The relevant densities in the `hybridBayesNet` will be added to the input
@@ -49,11 +62,24 @@ class GTSAM_EXPORT HybridSmoother {
    * @param given_ordering The (optional) ordering for elimination, only
    * continuous variables are allowed
    */
-  void update(HybridGaussianFactorGraph graph,
+  void update(const HybridGaussianFactorGraph& graph,
               std::optional<size_t> maxNrLeaves = {},
               const std::optional<Ordering> given_ordering = {});

-  Ordering getOrdering(const HybridGaussianFactorGraph& newFactors);
+  /**
+   * @brief Get an elimination ordering which eliminates continuous
+   * variables first and discrete variables last.
+   *
+   * Expects `factors` to already contain the conditionals that were
+   * connected to the variables in the newly added factors; those
+   * variables should be in `newFactorKeys`.
+   *
+   * @param factors All the new factors and connected conditionals.
+   * @param newFactorKeys The keys/variables in the newly added factors.
+   * @return Ordering
+   */
+  Ordering getOrdering(const HybridGaussianFactorGraph& factors,
+                       const KeySet& newFactorKeys);
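
Under the hood, the ordering constrains discrete keys to be eliminated last, so continuous variables come first. A sketch of that core step, assuming GTSAM's constrained-COLAMD helper `Ordering::ColamdConstrainedLast` (the PR's implementation additionally orders the new continuous keys before the discrete ones; this is a simplification):

```cpp
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/inference/Ordering.h>

using namespace gtsam;

// Eliminate continuous variables first by forcing all discrete keys
// to come last in the ordering.
Ordering ContinuousFirstOrdering(const HybridGaussianFactorGraph &factors) {
  const KeySet discrete = factors.discreteKeySet();
  const KeyVector constrainLast(discrete.begin(), discrete.end());
  return Ordering::ColamdConstrainedLast(factors, constrainLast, true);
}
```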
   /**
    * @brief Add conditionals from the previous timestep as part of liquefaction.
@@ -66,7 +92,7 @@ class GTSAM_EXPORT HybridSmoother {
    */
   std::pair<HybridGaussianFactorGraph, HybridBayesNet> addConditionals(
       const HybridGaussianFactorGraph& graph,
-      const HybridBayesNet& hybridBayesNet, const Ordering& ordering) const;
+      const HybridBayesNet& hybridBayesNet) const;

   /**
    * @brief Get the hybrid Gaussian conditional from

@@ -434,7 +434,7 @@ TEST(HybridBayesNet, RemoveDeadNodes) {
   HybridValues delta = posterior->optimize();

   // Prune the Bayes net
-  const bool pruneDeadVariables = true;
+  const double pruneDeadVariables = 0.99;
   auto prunedBayesNet = posterior->prune(2, pruneDeadVariables);

   // Check that discrete joint only has M0 and not (M0, M1)
@@ -445,11 +445,12 @@ TEST(HybridBayesNet, RemoveDeadNodes) {
   // Check that hybrid conditionals that only depend on M1
   // are now Gaussian and not Hybrid
   EXPECT(prunedBayesNet.at(0)->isDiscrete());
-  EXPECT(prunedBayesNet.at(1)->isHybrid());
+  EXPECT(prunedBayesNet.at(1)->isDiscrete());
+  EXPECT(prunedBayesNet.at(2)->isHybrid());

   // Only P(X2 | X1, M1) depends on M1,
   // so it gets converted to a Gaussian P(X2 | X1)
-  EXPECT(prunedBayesNet.at(2)->isContinuous());
-  EXPECT(prunedBayesNet.at(3)->isHybrid());
+  EXPECT(prunedBayesNet.at(3)->isContinuous());
+  EXPECT(prunedBayesNet.at(4)->isHybrid());
 }
/* ****************************************************************************/


@@ -95,16 +95,15 @@ TEST(HybridSmoother, IncrementalSmoother) {
     initial.insert(X(k), switching.linearizationPoint.at<double>(X(k)));

     HybridGaussianFactorGraph linearized = *graph.linearize(initial);

-    Ordering ordering = smoother.getOrdering(linearized);
-    smoother.update(linearized, maxNrLeaves, ordering);
+    smoother.update(linearized, maxNrLeaves);

     // Clear all the factors from the graph
     graph.resize(0);
   }

   EXPECT_LONGS_EQUAL(11,
-                     smoother.hybridBayesNet().at(0)->asDiscrete()->nrValues());
+                     smoother.hybridBayesNet().at(3)->asDiscrete()->nrValues());

   // Get the continuous delta update as well as
   // the optimal discrete assignment.
@@ -150,16 +149,15 @@ TEST(HybridSmoother, ValidPruningError) {
     initial.insert(X(k), switching.linearizationPoint.at<double>(X(k)));

     HybridGaussianFactorGraph linearized = *graph.linearize(initial);

-    Ordering ordering = smoother.getOrdering(linearized);
-    smoother.update(linearized, maxNrLeaves, ordering);
+    smoother.update(linearized, maxNrLeaves);

     // Clear all the factors from the graph
     graph.resize(0);
   }

   EXPECT_LONGS_EQUAL(14,
-                     smoother.hybridBayesNet().at(0)->asDiscrete()->nrValues());
+                     smoother.hybridBayesNet().at(6)->asDiscrete()->nrValues());

   // Get the continuous delta update as well as
   // the optimal discrete assignment.
@@ -169,6 +167,59 @@ TEST(HybridSmoother, ValidPruningError) {
   EXPECT_DOUBLES_EQUAL(1e-8, errorTree(delta.discrete()), 1e-8);
 }

+/****************************************************************************/
+// Test if dead mode removal works.
+TEST(HybridSmoother, DeadModeRemoval) {
+  using namespace estimation_fixture;
+
+  size_t K = 8;
+
+  // Switching example of robot moving in 1D
+  // with given measurements and equal mode priors.
+  HybridNonlinearFactorGraph graph;
+  Values initial;
+  Switching switching = InitializeEstimationProblem(
+      K, 0.1, 0.1, measurements, "1/1 1/1", &graph, &initial);
+
+  // Smoother with dead mode removal enabled at a 0.99 threshold.
+  HybridSmoother smoother(0.99);
+
+  constexpr size_t maxNrLeaves = 3;
+  for (size_t k = 1; k < K; k++) {
+    if (k > 1) graph.push_back(switching.modeChain.at(k - 1));  // Mode chain
+    graph.push_back(switching.binaryFactors.at(k - 1));         // Motion Model
+    graph.push_back(switching.unaryFactors.at(k));              // Measurement
+
+    initial.insert(X(k), switching.linearizationPoint.at<double>(X(k)));
+
+    HybridGaussianFactorGraph linearized = *graph.linearize(initial);
+
+    smoother.update(linearized, maxNrLeaves);
+
+    // Clear all the factors from the graph
+    graph.resize(0);
+  }
+
+  // Get the continuous delta update as well as
+  // the optimal discrete assignment.
+  HybridValues delta = smoother.hybridBayesNet().optimize();
+
+  // Check discrete assignment
+  DiscreteValues expected_discrete;
+  for (size_t k = 0; k < K - 1; k++) {
+    expected_discrete[M(k)] = discrete_seq[k];
+  }
+  EXPECT(assert_equal(expected_discrete, delta.discrete()));
+
+  // Update nonlinear solution and verify
+  Values result = initial.retract(delta.continuous());
+  Values expected_continuous;
+  for (size_t k = 0; k < K; k++) {
+    expected_continuous.insert(X(k), measurements[k]);
+  }
+  EXPECT(assert_equal(expected_continuous, result));
+}

 /* ************************************************************************* */
 int main() {
   TestResult tr;