Merge pull request #1996 from borglab/hybrid-smoother
commit
3302ad46c8
|
@ -47,8 +47,8 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
|
|||
// TODO(Frank): This can be quite expensive *unless* the factors have already
|
||||
// been pruned before. Another, possibly faster approach is branch and bound
|
||||
// search to find the K-best leaves and then create a single pruned conditional.
|
||||
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves,
|
||||
bool removeDeadModes) const {
|
||||
HybridBayesNet HybridBayesNet::prune(
|
||||
size_t maxNrLeaves, const std::optional<double> &deadModeThreshold) const {
|
||||
// Collect all the discrete conditionals. Could be small if already pruned.
|
||||
const DiscreteBayesNet marginal = discreteMarginal();
|
||||
|
||||
|
@ -58,24 +58,21 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves,
|
|||
joint = joint * (*conditional);
|
||||
}
|
||||
|
||||
// Create the result starting with the pruned joint.
|
||||
// Initialize the resulting HybridBayesNet.
|
||||
HybridBayesNet result;
|
||||
result.emplace_shared<DiscreteConditional>(joint);
|
||||
// Prune the joint. NOTE: imperative and, again, possibly quite expensive.
|
||||
result.back()->asDiscrete()->prune(maxNrLeaves);
|
||||
|
||||
// Get pruned discrete probabilities so
|
||||
// we can prune HybridGaussianConditionals.
|
||||
DiscreteConditional pruned = *result.back()->asDiscrete();
|
||||
// Prune the joint. NOTE: imperative and, again, possibly quite expensive.
|
||||
DiscreteConditional pruned = joint;
|
||||
pruned.prune(maxNrLeaves);
|
||||
|
||||
DiscreteValues deadModesValues;
|
||||
if (removeDeadModes) {
|
||||
if (deadModeThreshold.has_value()) {
|
||||
DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
|
||||
for (auto dkey : pruned.discreteKeys()) {
|
||||
Vector probabilities = marginals.marginalProbabilities(dkey);
|
||||
|
||||
int index = -1;
|
||||
auto threshold = (probabilities.array() > 0.99);
|
||||
auto threshold = (probabilities.array() > *deadModeThreshold);
|
||||
// If atleast 1 value is non-zero, then we can find the index
|
||||
// Else if all are zero, index would be set to 0 which is incorrect
|
||||
if (!threshold.isZero()) {
|
||||
|
@ -88,8 +85,26 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves,
|
|||
}
|
||||
|
||||
// Remove the modes (imperative)
|
||||
result.back()->asDiscrete()->removeDiscreteModes(deadModesValues);
|
||||
pruned = *result.back()->asDiscrete();
|
||||
pruned.removeDiscreteModes(deadModesValues);
|
||||
|
||||
/*
|
||||
If the pruned discrete conditional has any keys left,
|
||||
we add it to the HybridBayesNet.
|
||||
If not, it means it is an orphan so we don't add this pruned joint,
|
||||
and instead add only the marginals below.
|
||||
*/
|
||||
if (pruned.keys().size() > 0) {
|
||||
result.emplace_shared<DiscreteConditional>(pruned);
|
||||
}
|
||||
|
||||
// Add the marginals for future factors
|
||||
for (auto &&[key, _] : deadModesValues) {
|
||||
result.push_back(
|
||||
std::dynamic_pointer_cast<DiscreteConditional>(marginals(key)));
|
||||
}
|
||||
|
||||
} else {
|
||||
result.emplace_shared<DiscreteConditional>(pruned);
|
||||
}
|
||||
|
||||
/* To prune, we visitWith every leaf in the HybridGaussianConditional.
|
||||
|
@ -100,13 +115,13 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves,
|
|||
*/
|
||||
|
||||
// Go through all the Gaussian conditionals in the Bayes Net and prune them as
|
||||
// per pruned Discrete joint.
|
||||
// per pruned discrete joint.
|
||||
for (auto &&conditional : *this) {
|
||||
if (auto hgc = conditional->asHybrid()) {
|
||||
// Prune the hybrid Gaussian conditional!
|
||||
auto prunedHybridGaussianConditional = hgc->prune(pruned);
|
||||
|
||||
if (removeDeadModes) {
|
||||
if (deadModeThreshold.has_value()) {
|
||||
KeyVector deadKeys, conditionalDiscreteKeys;
|
||||
for (const auto &kv : deadModesValues) {
|
||||
deadKeys.push_back(kv.first);
|
||||
|
@ -186,6 +201,7 @@ DiscreteValues HybridBayesNet::mpe() const {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
return discrete_fg.optimize();
|
||||
}
|
||||
|
||||
|
|
|
@ -217,11 +217,16 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
* @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves.
|
||||
*
|
||||
* @param maxNrLeaves Continuous values at which to compute the error.
|
||||
* @param removeDeadModes Flag to enable removal of modes which only have a
|
||||
* single possible assignment.
|
||||
* @param deadModeThreshold The threshold to check the mode marginals against.
|
||||
* If greater than this threshold, the mode gets assigned that value and is
|
||||
* considered "dead" for hybrid elimination.
|
||||
* The mode can then be removed since it only has a single possible
|
||||
* assignment.
|
||||
* @return A pruned HybridBayesNet
|
||||
*/
|
||||
HybridBayesNet prune(size_t maxNrLeaves, bool removeDeadModes = false) const;
|
||||
HybridBayesNet prune(
|
||||
size_t maxNrLeaves,
|
||||
const std::optional<double> &deadModeThreshold = {}) const;
|
||||
|
||||
/**
|
||||
* @brief Error method using HybridValues which returns specific error for
|
||||
|
|
|
@ -24,17 +24,14 @@
|
|||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************* */
|
||||
Ordering HybridSmoother::getOrdering(
|
||||
const HybridGaussianFactorGraph &newFactors) {
|
||||
HybridGaussianFactorGraph factors(hybridBayesNet());
|
||||
factors.push_back(newFactors);
|
||||
|
||||
Ordering HybridSmoother::getOrdering(const HybridGaussianFactorGraph &factors,
|
||||
const KeySet &newFactorKeys) {
|
||||
// Get all the discrete keys from the factors
|
||||
KeySet allDiscrete = factors.discreteKeySet();
|
||||
|
||||
// Create KeyVector with continuous keys followed by discrete keys.
|
||||
KeyVector newKeysDiscreteLast;
|
||||
const KeySet newFactorKeys = newFactors.keys();
|
||||
|
||||
// Insert continuous keys first.
|
||||
for (auto &k : newFactorKeys) {
|
||||
if (!allDiscrete.exists(k)) {
|
||||
|
@ -56,29 +53,35 @@ Ordering HybridSmoother::getOrdering(
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void HybridSmoother::update(HybridGaussianFactorGraph graph,
|
||||
void HybridSmoother::update(const HybridGaussianFactorGraph &graph,
|
||||
std::optional<size_t> maxNrLeaves,
|
||||
const std::optional<Ordering> given_ordering) {
|
||||
HybridGaussianFactorGraph updatedGraph;
|
||||
// Add the necessary conditionals from the previous timestep(s).
|
||||
std::tie(updatedGraph, hybridBayesNet_) =
|
||||
addConditionals(graph, hybridBayesNet_);
|
||||
|
||||
Ordering ordering;
|
||||
// If no ordering provided, then we compute one
|
||||
if (!given_ordering.has_value()) {
|
||||
ordering = this->getOrdering(graph);
|
||||
// Get the keys from the new factors
|
||||
const KeySet newFactorKeys = graph.keys();
|
||||
|
||||
// Since updatedGraph now has all the connected conditionals,
|
||||
// we can get the correct ordering.
|
||||
ordering = this->getOrdering(updatedGraph, newFactorKeys);
|
||||
} else {
|
||||
ordering = *given_ordering;
|
||||
}
|
||||
|
||||
// Add the necessary conditionals from the previous timestep(s).
|
||||
std::tie(graph, hybridBayesNet_) =
|
||||
addConditionals(graph, hybridBayesNet_, ordering);
|
||||
|
||||
// Eliminate.
|
||||
HybridBayesNet bayesNetFragment = *graph.eliminateSequential(ordering);
|
||||
HybridBayesNet bayesNetFragment = *updatedGraph.eliminateSequential(ordering);
|
||||
|
||||
/// Prune
|
||||
if (maxNrLeaves) {
|
||||
// `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in
|
||||
// all the conditionals with the same keys in bayesNetFragment.
|
||||
bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves);
|
||||
bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, deadModeThreshold_);
|
||||
}
|
||||
|
||||
// Add the partial bayes net to the posterior bayes net.
|
||||
|
@ -88,10 +91,11 @@ void HybridSmoother::update(HybridGaussianFactorGraph graph,
|
|||
/* ************************************************************************* */
|
||||
std::pair<HybridGaussianFactorGraph, HybridBayesNet>
|
||||
HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
|
||||
const HybridBayesNet &originalHybridBayesNet,
|
||||
const Ordering &ordering) const {
|
||||
const HybridBayesNet &hybridBayesNet) const {
|
||||
HybridGaussianFactorGraph graph(originalGraph);
|
||||
HybridBayesNet hybridBayesNet(originalHybridBayesNet);
|
||||
HybridBayesNet updatedHybridBayesNet(hybridBayesNet);
|
||||
|
||||
KeySet factorKeys = graph.keys();
|
||||
|
||||
// If hybridBayesNet is not empty,
|
||||
// it means we have conditionals to add to the factor graph.
|
||||
|
@ -99,10 +103,6 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
|
|||
// We add all relevant hybrid conditionals on the last continuous variable
|
||||
// in the previous `hybridBayesNet` to the graph
|
||||
|
||||
// Conditionals to remove from the bayes net
|
||||
// since the conditional will be updated.
|
||||
std::vector<HybridConditional::shared_ptr> conditionals_to_erase;
|
||||
|
||||
// New conditionals to add to the graph
|
||||
gtsam::HybridBayesNet newConditionals;
|
||||
|
||||
|
@ -112,25 +112,32 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
|
|||
auto conditional = hybridBayesNet.at(i);
|
||||
|
||||
for (auto &key : conditional->frontals()) {
|
||||
if (std::find(ordering.begin(), ordering.end(), key) !=
|
||||
ordering.end()) {
|
||||
if (std::find(factorKeys.begin(), factorKeys.end(), key) !=
|
||||
factorKeys.end()) {
|
||||
newConditionals.push_back(conditional);
|
||||
conditionals_to_erase.push_back(conditional);
|
||||
|
||||
// Add the conditional parents to factorKeys
|
||||
// so we add those conditionals too.
|
||||
// NOTE: This assumes we have a structure where
|
||||
// variables depend on those in the future.
|
||||
for (auto &&parentKey : conditional->parents()) {
|
||||
factorKeys.insert(parentKey);
|
||||
}
|
||||
|
||||
// Remove the conditional from the updated Bayes net
|
||||
auto it = find(updatedHybridBayesNet.begin(),
|
||||
updatedHybridBayesNet.end(), conditional);
|
||||
updatedHybridBayesNet.erase(it);
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Remove conditionals at the end so we don't affect the order in the
|
||||
// original bayes net.
|
||||
for (auto &&conditional : conditionals_to_erase) {
|
||||
auto it = find(hybridBayesNet.begin(), hybridBayesNet.end(), conditional);
|
||||
hybridBayesNet.erase(it);
|
||||
}
|
||||
|
||||
graph.push_back(newConditionals);
|
||||
}
|
||||
return {graph, hybridBayesNet};
|
||||
|
||||
return {graph, updatedHybridBayesNet};
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -29,7 +29,20 @@ class GTSAM_EXPORT HybridSmoother {
|
|||
HybridBayesNet hybridBayesNet_;
|
||||
HybridGaussianFactorGraph remainingFactorGraph_;
|
||||
|
||||
/// The threshold above which we make a decision about a mode.
|
||||
std::optional<double> deadModeThreshold_;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Constructor
|
||||
*
|
||||
* @param removeDeadModes Flag indicating whether to remove dead modes.
|
||||
* @param deadModeThreshold The threshold above which a mode gets assigned a
|
||||
* value and is considered "dead". 0.99 is a good starting value.
|
||||
*/
|
||||
HybridSmoother(const std::optional<double> deadModeThreshold = {})
|
||||
: deadModeThreshold_(deadModeThreshold) {}
|
||||
|
||||
/**
|
||||
* Given new factors, perform an incremental update.
|
||||
* The relevant densities in the `hybridBayesNet` will be added to the input
|
||||
|
@ -49,11 +62,24 @@ class GTSAM_EXPORT HybridSmoother {
|
|||
* @param given_ordering The (optional) ordering for elimination, only
|
||||
* continuous variables are allowed
|
||||
*/
|
||||
void update(HybridGaussianFactorGraph graph,
|
||||
void update(const HybridGaussianFactorGraph& graph,
|
||||
std::optional<size_t> maxNrLeaves = {},
|
||||
const std::optional<Ordering> given_ordering = {});
|
||||
|
||||
Ordering getOrdering(const HybridGaussianFactorGraph& newFactors);
|
||||
/**
|
||||
* @brief Get an elimination ordering which eliminates continuous
|
||||
* and then discrete.
|
||||
*
|
||||
* Expects `factors` to already have the necessary conditionals
|
||||
* which were connected to the variables in the newly added factors.
|
||||
* Those variables should be in `newFactorKeys`.
|
||||
*
|
||||
* @param factors All the new factors and connected conditionals.
|
||||
* @param newFactorKeys The keys/variables in the newly added factors.
|
||||
* @return Ordering
|
||||
*/
|
||||
Ordering getOrdering(const HybridGaussianFactorGraph& factors,
|
||||
const KeySet& newFactorKeys);
|
||||
|
||||
/**
|
||||
* @brief Add conditionals from previous timestep as part of liquefication.
|
||||
|
@ -66,7 +92,7 @@ class GTSAM_EXPORT HybridSmoother {
|
|||
*/
|
||||
std::pair<HybridGaussianFactorGraph, HybridBayesNet> addConditionals(
|
||||
const HybridGaussianFactorGraph& graph,
|
||||
const HybridBayesNet& hybridBayesNet, const Ordering& ordering) const;
|
||||
const HybridBayesNet& hybridBayesNet) const;
|
||||
|
||||
/**
|
||||
* @brief Get the hybrid Gaussian conditional from
|
||||
|
|
|
@ -434,7 +434,7 @@ TEST(HybridBayesNet, RemoveDeadNodes) {
|
|||
HybridValues delta = posterior->optimize();
|
||||
|
||||
// Prune the Bayes net
|
||||
const bool pruneDeadVariables = true;
|
||||
const double pruneDeadVariables = 0.99;
|
||||
auto prunedBayesNet = posterior->prune(2, pruneDeadVariables);
|
||||
|
||||
// Check that discrete joint only has M0 and not (M0, M1)
|
||||
|
@ -445,11 +445,12 @@ TEST(HybridBayesNet, RemoveDeadNodes) {
|
|||
// Check that hybrid conditionals that only depend on M1
|
||||
// are now Gaussian and not Hybrid
|
||||
EXPECT(prunedBayesNet.at(0)->isDiscrete());
|
||||
EXPECT(prunedBayesNet.at(1)->isHybrid());
|
||||
EXPECT(prunedBayesNet.at(1)->isDiscrete());
|
||||
EXPECT(prunedBayesNet.at(2)->isHybrid());
|
||||
// Only P(X2 | X1, M1) depends on M1,
|
||||
// so it gets convert to a Gaussian P(X2 | X1)
|
||||
EXPECT(prunedBayesNet.at(2)->isContinuous());
|
||||
EXPECT(prunedBayesNet.at(3)->isHybrid());
|
||||
EXPECT(prunedBayesNet.at(3)->isContinuous());
|
||||
EXPECT(prunedBayesNet.at(4)->isHybrid());
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
|
|
|
@ -95,16 +95,15 @@ TEST(HybridSmoother, IncrementalSmoother) {
|
|||
initial.insert(X(k), switching.linearizationPoint.at<double>(X(k)));
|
||||
|
||||
HybridGaussianFactorGraph linearized = *graph.linearize(initial);
|
||||
Ordering ordering = smoother.getOrdering(linearized);
|
||||
|
||||
smoother.update(linearized, maxNrLeaves, ordering);
|
||||
smoother.update(linearized, maxNrLeaves);
|
||||
|
||||
// Clear all the factors from the graph
|
||||
graph.resize(0);
|
||||
}
|
||||
|
||||
EXPECT_LONGS_EQUAL(11,
|
||||
smoother.hybridBayesNet().at(0)->asDiscrete()->nrValues());
|
||||
smoother.hybridBayesNet().at(3)->asDiscrete()->nrValues());
|
||||
|
||||
// Get the continuous delta update as well as
|
||||
// the optimal discrete assignment.
|
||||
|
@ -150,16 +149,15 @@ TEST(HybridSmoother, ValidPruningError) {
|
|||
initial.insert(X(k), switching.linearizationPoint.at<double>(X(k)));
|
||||
|
||||
HybridGaussianFactorGraph linearized = *graph.linearize(initial);
|
||||
Ordering ordering = smoother.getOrdering(linearized);
|
||||
|
||||
smoother.update(linearized, maxNrLeaves, ordering);
|
||||
smoother.update(linearized, maxNrLeaves);
|
||||
|
||||
// Clear all the factors from the graph
|
||||
graph.resize(0);
|
||||
}
|
||||
|
||||
EXPECT_LONGS_EQUAL(14,
|
||||
smoother.hybridBayesNet().at(0)->asDiscrete()->nrValues());
|
||||
smoother.hybridBayesNet().at(6)->asDiscrete()->nrValues());
|
||||
|
||||
// Get the continuous delta update as well as
|
||||
// the optimal discrete assignment.
|
||||
|
@ -169,6 +167,59 @@ TEST(HybridSmoother, ValidPruningError) {
|
|||
EXPECT_DOUBLES_EQUAL(1e-8, errorTree(delta.discrete()), 1e-8);
|
||||
}
|
||||
|
||||
/****************************************************************************/
|
||||
// Test if dead mode removal works.
|
||||
TEST(HybridSmoother, DeadModeRemoval) {
|
||||
using namespace estimation_fixture;
|
||||
|
||||
size_t K = 8;
|
||||
|
||||
// Switching example of robot moving in 1D
|
||||
// with given measurements and equal mode priors.
|
||||
HybridNonlinearFactorGraph graph;
|
||||
Values initial;
|
||||
Switching switching = InitializeEstimationProblem(
|
||||
K, 0.1, 0.1, measurements, "1/1 1/1", &graph, &initial);
|
||||
|
||||
// Smoother with dead mode removal enabled.
|
||||
HybridSmoother smoother(true);
|
||||
|
||||
constexpr size_t maxNrLeaves = 3;
|
||||
for (size_t k = 1; k < K; k++) {
|
||||
if (k > 1) graph.push_back(switching.modeChain.at(k - 1)); // Mode chain
|
||||
graph.push_back(switching.binaryFactors.at(k - 1)); // Motion Model
|
||||
graph.push_back(switching.unaryFactors.at(k)); // Measurement
|
||||
|
||||
initial.insert(X(k), switching.linearizationPoint.at<double>(X(k)));
|
||||
|
||||
HybridGaussianFactorGraph linearized = *graph.linearize(initial);
|
||||
|
||||
smoother.update(linearized, maxNrLeaves);
|
||||
|
||||
// Clear all the factors from the graph
|
||||
graph.resize(0);
|
||||
}
|
||||
|
||||
// Get the continuous delta update as well as
|
||||
// the optimal discrete assignment.
|
||||
HybridValues delta = smoother.hybridBayesNet().optimize();
|
||||
|
||||
// Check discrete assignment
|
||||
DiscreteValues expected_discrete;
|
||||
for (size_t k = 0; k < K - 1; k++) {
|
||||
expected_discrete[M(k)] = discrete_seq[k];
|
||||
}
|
||||
EXPECT(assert_equal(expected_discrete, delta.discrete()));
|
||||
|
||||
// Update nonlinear solution and verify
|
||||
Values result = initial.retract(delta.continuous());
|
||||
Values expected_continuous;
|
||||
for (size_t k = 0; k < K; k++) {
|
||||
expected_continuous.insert(X(k), measurements[k]);
|
||||
}
|
||||
EXPECT(assert_equal(expected_continuous, result));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
Loading…
Reference in New Issue