Merge pull request #1996 from borglab/hybrid-smoother
commit
3302ad46c8
|
@ -47,8 +47,8 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
|
||||||
// TODO(Frank): This can be quite expensive *unless* the factors have already
|
// TODO(Frank): This can be quite expensive *unless* the factors have already
|
||||||
// been pruned before. Another, possibly faster approach is branch and bound
|
// been pruned before. Another, possibly faster approach is branch and bound
|
||||||
// search to find the K-best leaves and then create a single pruned conditional.
|
// search to find the K-best leaves and then create a single pruned conditional.
|
||||||
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves,
|
HybridBayesNet HybridBayesNet::prune(
|
||||||
bool removeDeadModes) const {
|
size_t maxNrLeaves, const std::optional<double> &deadModeThreshold) const {
|
||||||
// Collect all the discrete conditionals. Could be small if already pruned.
|
// Collect all the discrete conditionals. Could be small if already pruned.
|
||||||
const DiscreteBayesNet marginal = discreteMarginal();
|
const DiscreteBayesNet marginal = discreteMarginal();
|
||||||
|
|
||||||
|
@ -58,24 +58,21 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves,
|
||||||
joint = joint * (*conditional);
|
joint = joint * (*conditional);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the result starting with the pruned joint.
|
// Initialize the resulting HybridBayesNet.
|
||||||
HybridBayesNet result;
|
HybridBayesNet result;
|
||||||
result.emplace_shared<DiscreteConditional>(joint);
|
|
||||||
// Prune the joint. NOTE: imperative and, again, possibly quite expensive.
|
|
||||||
result.back()->asDiscrete()->prune(maxNrLeaves);
|
|
||||||
|
|
||||||
// Get pruned discrete probabilities so
|
// Prune the joint. NOTE: imperative and, again, possibly quite expensive.
|
||||||
// we can prune HybridGaussianConditionals.
|
DiscreteConditional pruned = joint;
|
||||||
DiscreteConditional pruned = *result.back()->asDiscrete();
|
pruned.prune(maxNrLeaves);
|
||||||
|
|
||||||
DiscreteValues deadModesValues;
|
DiscreteValues deadModesValues;
|
||||||
if (removeDeadModes) {
|
if (deadModeThreshold.has_value()) {
|
||||||
DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
|
DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
|
||||||
for (auto dkey : pruned.discreteKeys()) {
|
for (auto dkey : pruned.discreteKeys()) {
|
||||||
Vector probabilities = marginals.marginalProbabilities(dkey);
|
Vector probabilities = marginals.marginalProbabilities(dkey);
|
||||||
|
|
||||||
int index = -1;
|
int index = -1;
|
||||||
auto threshold = (probabilities.array() > 0.99);
|
auto threshold = (probabilities.array() > *deadModeThreshold);
|
||||||
// If atleast 1 value is non-zero, then we can find the index
|
// If atleast 1 value is non-zero, then we can find the index
|
||||||
// Else if all are zero, index would be set to 0 which is incorrect
|
// Else if all are zero, index would be set to 0 which is incorrect
|
||||||
if (!threshold.isZero()) {
|
if (!threshold.isZero()) {
|
||||||
|
@ -88,8 +85,26 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove the modes (imperative)
|
// Remove the modes (imperative)
|
||||||
result.back()->asDiscrete()->removeDiscreteModes(deadModesValues);
|
pruned.removeDiscreteModes(deadModesValues);
|
||||||
pruned = *result.back()->asDiscrete();
|
|
||||||
|
/*
|
||||||
|
If the pruned discrete conditional has any keys left,
|
||||||
|
we add it to the HybridBayesNet.
|
||||||
|
If not, it means it is an orphan so we don't add this pruned joint,
|
||||||
|
and instead add only the marginals below.
|
||||||
|
*/
|
||||||
|
if (pruned.keys().size() > 0) {
|
||||||
|
result.emplace_shared<DiscreteConditional>(pruned);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the marginals for future factors
|
||||||
|
for (auto &&[key, _] : deadModesValues) {
|
||||||
|
result.push_back(
|
||||||
|
std::dynamic_pointer_cast<DiscreteConditional>(marginals(key)));
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
result.emplace_shared<DiscreteConditional>(pruned);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* To prune, we visitWith every leaf in the HybridGaussianConditional.
|
/* To prune, we visitWith every leaf in the HybridGaussianConditional.
|
||||||
|
@ -100,13 +115,13 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves,
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// Go through all the Gaussian conditionals in the Bayes Net and prune them as
|
// Go through all the Gaussian conditionals in the Bayes Net and prune them as
|
||||||
// per pruned Discrete joint.
|
// per pruned discrete joint.
|
||||||
for (auto &&conditional : *this) {
|
for (auto &&conditional : *this) {
|
||||||
if (auto hgc = conditional->asHybrid()) {
|
if (auto hgc = conditional->asHybrid()) {
|
||||||
// Prune the hybrid Gaussian conditional!
|
// Prune the hybrid Gaussian conditional!
|
||||||
auto prunedHybridGaussianConditional = hgc->prune(pruned);
|
auto prunedHybridGaussianConditional = hgc->prune(pruned);
|
||||||
|
|
||||||
if (removeDeadModes) {
|
if (deadModeThreshold.has_value()) {
|
||||||
KeyVector deadKeys, conditionalDiscreteKeys;
|
KeyVector deadKeys, conditionalDiscreteKeys;
|
||||||
for (const auto &kv : deadModesValues) {
|
for (const auto &kv : deadModesValues) {
|
||||||
deadKeys.push_back(kv.first);
|
deadKeys.push_back(kv.first);
|
||||||
|
@ -186,6 +201,7 @@ DiscreteValues HybridBayesNet::mpe() const {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return discrete_fg.optimize();
|
return discrete_fg.optimize();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -217,11 +217,16 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
* @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves.
|
* @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves.
|
||||||
*
|
*
|
||||||
* @param maxNrLeaves Continuous values at which to compute the error.
|
* @param maxNrLeaves Continuous values at which to compute the error.
|
||||||
* @param removeDeadModes Flag to enable removal of modes which only have a
|
* @param deadModeThreshold The threshold to check the mode marginals against.
|
||||||
* single possible assignment.
|
* If greater than this threshold, the mode gets assigned that value and is
|
||||||
|
* considered "dead" for hybrid elimination.
|
||||||
|
* The mode can then be removed since it only has a single possible
|
||||||
|
* assignment.
|
||||||
* @return A pruned HybridBayesNet
|
* @return A pruned HybridBayesNet
|
||||||
*/
|
*/
|
||||||
HybridBayesNet prune(size_t maxNrLeaves, bool removeDeadModes = false) const;
|
HybridBayesNet prune(
|
||||||
|
size_t maxNrLeaves,
|
||||||
|
const std::optional<double> &deadModeThreshold = {}) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Error method using HybridValues which returns specific error for
|
* @brief Error method using HybridValues which returns specific error for
|
||||||
|
|
|
@ -24,17 +24,14 @@
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
Ordering HybridSmoother::getOrdering(
|
Ordering HybridSmoother::getOrdering(const HybridGaussianFactorGraph &factors,
|
||||||
const HybridGaussianFactorGraph &newFactors) {
|
const KeySet &newFactorKeys) {
|
||||||
HybridGaussianFactorGraph factors(hybridBayesNet());
|
|
||||||
factors.push_back(newFactors);
|
|
||||||
|
|
||||||
// Get all the discrete keys from the factors
|
// Get all the discrete keys from the factors
|
||||||
KeySet allDiscrete = factors.discreteKeySet();
|
KeySet allDiscrete = factors.discreteKeySet();
|
||||||
|
|
||||||
// Create KeyVector with continuous keys followed by discrete keys.
|
// Create KeyVector with continuous keys followed by discrete keys.
|
||||||
KeyVector newKeysDiscreteLast;
|
KeyVector newKeysDiscreteLast;
|
||||||
const KeySet newFactorKeys = newFactors.keys();
|
|
||||||
// Insert continuous keys first.
|
// Insert continuous keys first.
|
||||||
for (auto &k : newFactorKeys) {
|
for (auto &k : newFactorKeys) {
|
||||||
if (!allDiscrete.exists(k)) {
|
if (!allDiscrete.exists(k)) {
|
||||||
|
@ -56,29 +53,35 @@ Ordering HybridSmoother::getOrdering(
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
void HybridSmoother::update(HybridGaussianFactorGraph graph,
|
void HybridSmoother::update(const HybridGaussianFactorGraph &graph,
|
||||||
std::optional<size_t> maxNrLeaves,
|
std::optional<size_t> maxNrLeaves,
|
||||||
const std::optional<Ordering> given_ordering) {
|
const std::optional<Ordering> given_ordering) {
|
||||||
|
HybridGaussianFactorGraph updatedGraph;
|
||||||
|
// Add the necessary conditionals from the previous timestep(s).
|
||||||
|
std::tie(updatedGraph, hybridBayesNet_) =
|
||||||
|
addConditionals(graph, hybridBayesNet_);
|
||||||
|
|
||||||
Ordering ordering;
|
Ordering ordering;
|
||||||
// If no ordering provided, then we compute one
|
// If no ordering provided, then we compute one
|
||||||
if (!given_ordering.has_value()) {
|
if (!given_ordering.has_value()) {
|
||||||
ordering = this->getOrdering(graph);
|
// Get the keys from the new factors
|
||||||
|
const KeySet newFactorKeys = graph.keys();
|
||||||
|
|
||||||
|
// Since updatedGraph now has all the connected conditionals,
|
||||||
|
// we can get the correct ordering.
|
||||||
|
ordering = this->getOrdering(updatedGraph, newFactorKeys);
|
||||||
} else {
|
} else {
|
||||||
ordering = *given_ordering;
|
ordering = *given_ordering;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the necessary conditionals from the previous timestep(s).
|
|
||||||
std::tie(graph, hybridBayesNet_) =
|
|
||||||
addConditionals(graph, hybridBayesNet_, ordering);
|
|
||||||
|
|
||||||
// Eliminate.
|
// Eliminate.
|
||||||
HybridBayesNet bayesNetFragment = *graph.eliminateSequential(ordering);
|
HybridBayesNet bayesNetFragment = *updatedGraph.eliminateSequential(ordering);
|
||||||
|
|
||||||
/// Prune
|
/// Prune
|
||||||
if (maxNrLeaves) {
|
if (maxNrLeaves) {
|
||||||
// `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in
|
// `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in
|
||||||
// all the conditionals with the same keys in bayesNetFragment.
|
// all the conditionals with the same keys in bayesNetFragment.
|
||||||
bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves);
|
bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, deadModeThreshold_);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the partial bayes net to the posterior bayes net.
|
// Add the partial bayes net to the posterior bayes net.
|
||||||
|
@ -88,10 +91,11 @@ void HybridSmoother::update(HybridGaussianFactorGraph graph,
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
std::pair<HybridGaussianFactorGraph, HybridBayesNet>
|
std::pair<HybridGaussianFactorGraph, HybridBayesNet>
|
||||||
HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
|
HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
|
||||||
const HybridBayesNet &originalHybridBayesNet,
|
const HybridBayesNet &hybridBayesNet) const {
|
||||||
const Ordering &ordering) const {
|
|
||||||
HybridGaussianFactorGraph graph(originalGraph);
|
HybridGaussianFactorGraph graph(originalGraph);
|
||||||
HybridBayesNet hybridBayesNet(originalHybridBayesNet);
|
HybridBayesNet updatedHybridBayesNet(hybridBayesNet);
|
||||||
|
|
||||||
|
KeySet factorKeys = graph.keys();
|
||||||
|
|
||||||
// If hybridBayesNet is not empty,
|
// If hybridBayesNet is not empty,
|
||||||
// it means we have conditionals to add to the factor graph.
|
// it means we have conditionals to add to the factor graph.
|
||||||
|
@ -99,10 +103,6 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
|
||||||
// We add all relevant hybrid conditionals on the last continuous variable
|
// We add all relevant hybrid conditionals on the last continuous variable
|
||||||
// in the previous `hybridBayesNet` to the graph
|
// in the previous `hybridBayesNet` to the graph
|
||||||
|
|
||||||
// Conditionals to remove from the bayes net
|
|
||||||
// since the conditional will be updated.
|
|
||||||
std::vector<HybridConditional::shared_ptr> conditionals_to_erase;
|
|
||||||
|
|
||||||
// New conditionals to add to the graph
|
// New conditionals to add to the graph
|
||||||
gtsam::HybridBayesNet newConditionals;
|
gtsam::HybridBayesNet newConditionals;
|
||||||
|
|
||||||
|
@ -112,25 +112,32 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
|
||||||
auto conditional = hybridBayesNet.at(i);
|
auto conditional = hybridBayesNet.at(i);
|
||||||
|
|
||||||
for (auto &key : conditional->frontals()) {
|
for (auto &key : conditional->frontals()) {
|
||||||
if (std::find(ordering.begin(), ordering.end(), key) !=
|
if (std::find(factorKeys.begin(), factorKeys.end(), key) !=
|
||||||
ordering.end()) {
|
factorKeys.end()) {
|
||||||
newConditionals.push_back(conditional);
|
newConditionals.push_back(conditional);
|
||||||
conditionals_to_erase.push_back(conditional);
|
|
||||||
|
// Add the conditional parents to factorKeys
|
||||||
|
// so we add those conditionals too.
|
||||||
|
// NOTE: This assumes we have a structure where
|
||||||
|
// variables depend on those in the future.
|
||||||
|
for (auto &&parentKey : conditional->parents()) {
|
||||||
|
factorKeys.insert(parentKey);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove the conditional from the updated Bayes net
|
||||||
|
auto it = find(updatedHybridBayesNet.begin(),
|
||||||
|
updatedHybridBayesNet.end(), conditional);
|
||||||
|
updatedHybridBayesNet.erase(it);
|
||||||
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Remove conditionals at the end so we don't affect the order in the
|
|
||||||
// original bayes net.
|
|
||||||
for (auto &&conditional : conditionals_to_erase) {
|
|
||||||
auto it = find(hybridBayesNet.begin(), hybridBayesNet.end(), conditional);
|
|
||||||
hybridBayesNet.erase(it);
|
|
||||||
}
|
|
||||||
|
|
||||||
graph.push_back(newConditionals);
|
graph.push_back(newConditionals);
|
||||||
}
|
}
|
||||||
return {graph, hybridBayesNet};
|
|
||||||
|
return {graph, updatedHybridBayesNet};
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -29,7 +29,20 @@ class GTSAM_EXPORT HybridSmoother {
|
||||||
HybridBayesNet hybridBayesNet_;
|
HybridBayesNet hybridBayesNet_;
|
||||||
HybridGaussianFactorGraph remainingFactorGraph_;
|
HybridGaussianFactorGraph remainingFactorGraph_;
|
||||||
|
|
||||||
|
/// The threshold above which we make a decision about a mode.
|
||||||
|
std::optional<double> deadModeThreshold_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
/**
|
||||||
|
* @brief Constructor
|
||||||
|
*
|
||||||
|
* @param removeDeadModes Flag indicating whether to remove dead modes.
|
||||||
|
* @param deadModeThreshold The threshold above which a mode gets assigned a
|
||||||
|
* value and is considered "dead". 0.99 is a good starting value.
|
||||||
|
*/
|
||||||
|
HybridSmoother(const std::optional<double> deadModeThreshold = {})
|
||||||
|
: deadModeThreshold_(deadModeThreshold) {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Given new factors, perform an incremental update.
|
* Given new factors, perform an incremental update.
|
||||||
* The relevant densities in the `hybridBayesNet` will be added to the input
|
* The relevant densities in the `hybridBayesNet` will be added to the input
|
||||||
|
@ -49,11 +62,24 @@ class GTSAM_EXPORT HybridSmoother {
|
||||||
* @param given_ordering The (optional) ordering for elimination, only
|
* @param given_ordering The (optional) ordering for elimination, only
|
||||||
* continuous variables are allowed
|
* continuous variables are allowed
|
||||||
*/
|
*/
|
||||||
void update(HybridGaussianFactorGraph graph,
|
void update(const HybridGaussianFactorGraph& graph,
|
||||||
std::optional<size_t> maxNrLeaves = {},
|
std::optional<size_t> maxNrLeaves = {},
|
||||||
const std::optional<Ordering> given_ordering = {});
|
const std::optional<Ordering> given_ordering = {});
|
||||||
|
|
||||||
Ordering getOrdering(const HybridGaussianFactorGraph& newFactors);
|
/**
|
||||||
|
* @brief Get an elimination ordering which eliminates continuous
|
||||||
|
* and then discrete.
|
||||||
|
*
|
||||||
|
* Expects `factors` to already have the necessary conditionals
|
||||||
|
* which were connected to the variables in the newly added factors.
|
||||||
|
* Those variables should be in `newFactorKeys`.
|
||||||
|
*
|
||||||
|
* @param factors All the new factors and connected conditionals.
|
||||||
|
* @param newFactorKeys The keys/variables in the newly added factors.
|
||||||
|
* @return Ordering
|
||||||
|
*/
|
||||||
|
Ordering getOrdering(const HybridGaussianFactorGraph& factors,
|
||||||
|
const KeySet& newFactorKeys);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Add conditionals from previous timestep as part of liquefication.
|
* @brief Add conditionals from previous timestep as part of liquefication.
|
||||||
|
@ -66,7 +92,7 @@ class GTSAM_EXPORT HybridSmoother {
|
||||||
*/
|
*/
|
||||||
std::pair<HybridGaussianFactorGraph, HybridBayesNet> addConditionals(
|
std::pair<HybridGaussianFactorGraph, HybridBayesNet> addConditionals(
|
||||||
const HybridGaussianFactorGraph& graph,
|
const HybridGaussianFactorGraph& graph,
|
||||||
const HybridBayesNet& hybridBayesNet, const Ordering& ordering) const;
|
const HybridBayesNet& hybridBayesNet) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Get the hybrid Gaussian conditional from
|
* @brief Get the hybrid Gaussian conditional from
|
||||||
|
|
|
@ -434,7 +434,7 @@ TEST(HybridBayesNet, RemoveDeadNodes) {
|
||||||
HybridValues delta = posterior->optimize();
|
HybridValues delta = posterior->optimize();
|
||||||
|
|
||||||
// Prune the Bayes net
|
// Prune the Bayes net
|
||||||
const bool pruneDeadVariables = true;
|
const double pruneDeadVariables = 0.99;
|
||||||
auto prunedBayesNet = posterior->prune(2, pruneDeadVariables);
|
auto prunedBayesNet = posterior->prune(2, pruneDeadVariables);
|
||||||
|
|
||||||
// Check that discrete joint only has M0 and not (M0, M1)
|
// Check that discrete joint only has M0 and not (M0, M1)
|
||||||
|
@ -445,11 +445,12 @@ TEST(HybridBayesNet, RemoveDeadNodes) {
|
||||||
// Check that hybrid conditionals that only depend on M1
|
// Check that hybrid conditionals that only depend on M1
|
||||||
// are now Gaussian and not Hybrid
|
// are now Gaussian and not Hybrid
|
||||||
EXPECT(prunedBayesNet.at(0)->isDiscrete());
|
EXPECT(prunedBayesNet.at(0)->isDiscrete());
|
||||||
EXPECT(prunedBayesNet.at(1)->isHybrid());
|
EXPECT(prunedBayesNet.at(1)->isDiscrete());
|
||||||
|
EXPECT(prunedBayesNet.at(2)->isHybrid());
|
||||||
// Only P(X2 | X1, M1) depends on M1,
|
// Only P(X2 | X1, M1) depends on M1,
|
||||||
// so it gets convert to a Gaussian P(X2 | X1)
|
// so it gets convert to a Gaussian P(X2 | X1)
|
||||||
EXPECT(prunedBayesNet.at(2)->isContinuous());
|
EXPECT(prunedBayesNet.at(3)->isContinuous());
|
||||||
EXPECT(prunedBayesNet.at(3)->isHybrid());
|
EXPECT(prunedBayesNet.at(4)->isHybrid());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
|
|
|
@ -95,16 +95,15 @@ TEST(HybridSmoother, IncrementalSmoother) {
|
||||||
initial.insert(X(k), switching.linearizationPoint.at<double>(X(k)));
|
initial.insert(X(k), switching.linearizationPoint.at<double>(X(k)));
|
||||||
|
|
||||||
HybridGaussianFactorGraph linearized = *graph.linearize(initial);
|
HybridGaussianFactorGraph linearized = *graph.linearize(initial);
|
||||||
Ordering ordering = smoother.getOrdering(linearized);
|
|
||||||
|
|
||||||
smoother.update(linearized, maxNrLeaves, ordering);
|
smoother.update(linearized, maxNrLeaves);
|
||||||
|
|
||||||
// Clear all the factors from the graph
|
// Clear all the factors from the graph
|
||||||
graph.resize(0);
|
graph.resize(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPECT_LONGS_EQUAL(11,
|
EXPECT_LONGS_EQUAL(11,
|
||||||
smoother.hybridBayesNet().at(0)->asDiscrete()->nrValues());
|
smoother.hybridBayesNet().at(3)->asDiscrete()->nrValues());
|
||||||
|
|
||||||
// Get the continuous delta update as well as
|
// Get the continuous delta update as well as
|
||||||
// the optimal discrete assignment.
|
// the optimal discrete assignment.
|
||||||
|
@ -150,16 +149,15 @@ TEST(HybridSmoother, ValidPruningError) {
|
||||||
initial.insert(X(k), switching.linearizationPoint.at<double>(X(k)));
|
initial.insert(X(k), switching.linearizationPoint.at<double>(X(k)));
|
||||||
|
|
||||||
HybridGaussianFactorGraph linearized = *graph.linearize(initial);
|
HybridGaussianFactorGraph linearized = *graph.linearize(initial);
|
||||||
Ordering ordering = smoother.getOrdering(linearized);
|
|
||||||
|
|
||||||
smoother.update(linearized, maxNrLeaves, ordering);
|
smoother.update(linearized, maxNrLeaves);
|
||||||
|
|
||||||
// Clear all the factors from the graph
|
// Clear all the factors from the graph
|
||||||
graph.resize(0);
|
graph.resize(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPECT_LONGS_EQUAL(14,
|
EXPECT_LONGS_EQUAL(14,
|
||||||
smoother.hybridBayesNet().at(0)->asDiscrete()->nrValues());
|
smoother.hybridBayesNet().at(6)->asDiscrete()->nrValues());
|
||||||
|
|
||||||
// Get the continuous delta update as well as
|
// Get the continuous delta update as well as
|
||||||
// the optimal discrete assignment.
|
// the optimal discrete assignment.
|
||||||
|
@ -169,6 +167,59 @@ TEST(HybridSmoother, ValidPruningError) {
|
||||||
EXPECT_DOUBLES_EQUAL(1e-8, errorTree(delta.discrete()), 1e-8);
|
EXPECT_DOUBLES_EQUAL(1e-8, errorTree(delta.discrete()), 1e-8);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/****************************************************************************/
|
||||||
|
// Test if dead mode removal works.
|
||||||
|
TEST(HybridSmoother, DeadModeRemoval) {
|
||||||
|
using namespace estimation_fixture;
|
||||||
|
|
||||||
|
size_t K = 8;
|
||||||
|
|
||||||
|
// Switching example of robot moving in 1D
|
||||||
|
// with given measurements and equal mode priors.
|
||||||
|
HybridNonlinearFactorGraph graph;
|
||||||
|
Values initial;
|
||||||
|
Switching switching = InitializeEstimationProblem(
|
||||||
|
K, 0.1, 0.1, measurements, "1/1 1/1", &graph, &initial);
|
||||||
|
|
||||||
|
// Smoother with dead mode removal enabled.
|
||||||
|
HybridSmoother smoother(true);
|
||||||
|
|
||||||
|
constexpr size_t maxNrLeaves = 3;
|
||||||
|
for (size_t k = 1; k < K; k++) {
|
||||||
|
if (k > 1) graph.push_back(switching.modeChain.at(k - 1)); // Mode chain
|
||||||
|
graph.push_back(switching.binaryFactors.at(k - 1)); // Motion Model
|
||||||
|
graph.push_back(switching.unaryFactors.at(k)); // Measurement
|
||||||
|
|
||||||
|
initial.insert(X(k), switching.linearizationPoint.at<double>(X(k)));
|
||||||
|
|
||||||
|
HybridGaussianFactorGraph linearized = *graph.linearize(initial);
|
||||||
|
|
||||||
|
smoother.update(linearized, maxNrLeaves);
|
||||||
|
|
||||||
|
// Clear all the factors from the graph
|
||||||
|
graph.resize(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the continuous delta update as well as
|
||||||
|
// the optimal discrete assignment.
|
||||||
|
HybridValues delta = smoother.hybridBayesNet().optimize();
|
||||||
|
|
||||||
|
// Check discrete assignment
|
||||||
|
DiscreteValues expected_discrete;
|
||||||
|
for (size_t k = 0; k < K - 1; k++) {
|
||||||
|
expected_discrete[M(k)] = discrete_seq[k];
|
||||||
|
}
|
||||||
|
EXPECT(assert_equal(expected_discrete, delta.discrete()));
|
||||||
|
|
||||||
|
// Update nonlinear solution and verify
|
||||||
|
Values result = initial.retract(delta.continuous());
|
||||||
|
Values expected_continuous;
|
||||||
|
for (size_t k = 0; k < K; k++) {
|
||||||
|
expected_continuous.insert(X(k), measurements[k]);
|
||||||
|
}
|
||||||
|
EXPECT(assert_equal(expected_continuous, result));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
Loading…
Reference in New Issue