Merge branch 'hybrid-smoother' into city10000
commit
5cee0a2d34
|
@ -47,8 +47,8 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
|
||||||
// TODO(Frank): This can be quite expensive *unless* the factors have already
|
// TODO(Frank): This can be quite expensive *unless* the factors have already
|
||||||
// been pruned before. Another, possibly faster approach is branch and bound
|
// been pruned before. Another, possibly faster approach is branch and bound
|
||||||
// search to find the K-best leaves and then create a single pruned conditional.
|
// search to find the K-best leaves and then create a single pruned conditional.
|
||||||
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves,
|
HybridBayesNet HybridBayesNet::prune(
|
||||||
bool removeDeadModes) const {
|
size_t maxNrLeaves, const std::optional<double> &deadModeThreshold) const {
|
||||||
// Collect all the discrete conditionals. Could be small if already pruned.
|
// Collect all the discrete conditionals. Could be small if already pruned.
|
||||||
const DiscreteBayesNet marginal = discreteMarginal();
|
const DiscreteBayesNet marginal = discreteMarginal();
|
||||||
|
|
||||||
|
@ -58,25 +58,21 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves,
|
||||||
joint = joint * (*conditional);
|
joint = joint * (*conditional);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prune the joint. NOTE: imperative and, again, possibly quite expensive.
|
// Initialize the resulting HybridBayesNet.
|
||||||
joint.prune(maxNrLeaves);
|
|
||||||
|
|
||||||
// Create the result starting with the pruned joint.
|
|
||||||
HybridBayesNet result;
|
HybridBayesNet result;
|
||||||
result.emplace_shared<DiscreteConditional>(joint);
|
|
||||||
|
|
||||||
// Get pruned discrete probabilities so
|
// Prune the joint. NOTE: imperative and, again, possibly quite expensive.
|
||||||
// we can prune HybridGaussianConditionals.
|
DiscreteConditional pruned = joint;
|
||||||
DiscreteConditional pruned = *result.back()->asDiscrete();
|
pruned.prune(maxNrLeaves);
|
||||||
|
|
||||||
DiscreteValues deadModesValues;
|
DiscreteValues deadModesValues;
|
||||||
if (removeDeadModes) {
|
if (deadModeThreshold.has_value()) {
|
||||||
DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
|
DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
|
||||||
for (auto dkey : pruned.discreteKeys()) {
|
for (auto dkey : pruned.discreteKeys()) {
|
||||||
Vector probabilities = marginals.marginalProbabilities(dkey);
|
Vector probabilities = marginals.marginalProbabilities(dkey);
|
||||||
|
|
||||||
int index = -1;
|
int index = -1;
|
||||||
auto threshold = (probabilities.array() > 0.99);
|
auto threshold = (probabilities.array() > *deadModeThreshold);
|
||||||
// If atleast 1 value is non-zero, then we can find the index
|
// If atleast 1 value is non-zero, then we can find the index
|
||||||
// Else if all are zero, index would be set to 0 which is incorrect
|
// Else if all are zero, index would be set to 0 which is incorrect
|
||||||
if (!threshold.isZero()) {
|
if (!threshold.isZero()) {
|
||||||
|
@ -89,8 +85,26 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove the modes (imperative)
|
// Remove the modes (imperative)
|
||||||
result.back()->asDiscrete()->removeDiscreteModes(deadModesValues);
|
pruned.removeDiscreteModes(deadModesValues);
|
||||||
pruned = *result.back()->asDiscrete();
|
|
||||||
|
/*
|
||||||
|
If the pruned discrete conditional has any keys left,
|
||||||
|
we add it to the HybridBayesNet.
|
||||||
|
If not, it means it is an orphan so we don't add this pruned joint,
|
||||||
|
and instead add only the marginals below.
|
||||||
|
*/
|
||||||
|
if (pruned.keys().size() > 0) {
|
||||||
|
result.emplace_shared<DiscreteConditional>(pruned);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the marginals for future factors
|
||||||
|
for (auto &&[key, _] : deadModesValues) {
|
||||||
|
result.push_back(
|
||||||
|
std::dynamic_pointer_cast<DiscreteConditional>(marginals(key)));
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
result.emplace_shared<DiscreteConditional>(pruned);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* To prune, we visitWith every leaf in the HybridGaussianConditional.
|
/* To prune, we visitWith every leaf in the HybridGaussianConditional.
|
||||||
|
@ -101,13 +115,13 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves,
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// Go through all the Gaussian conditionals in the Bayes Net and prune them as
|
// Go through all the Gaussian conditionals in the Bayes Net and prune them as
|
||||||
// per pruned Discrete joint.
|
// per pruned discrete joint.
|
||||||
for (auto &&conditional : *this) {
|
for (auto &&conditional : *this) {
|
||||||
if (auto hgc = conditional->asHybrid()) {
|
if (auto hgc = conditional->asHybrid()) {
|
||||||
// Prune the hybrid Gaussian conditional!
|
// Prune the hybrid Gaussian conditional!
|
||||||
auto prunedHybridGaussianConditional = hgc->prune(pruned);
|
auto prunedHybridGaussianConditional = hgc->prune(pruned);
|
||||||
|
|
||||||
if (removeDeadModes) {
|
if (deadModeThreshold.has_value()) {
|
||||||
KeyVector deadKeys, conditionalDiscreteKeys;
|
KeyVector deadKeys, conditionalDiscreteKeys;
|
||||||
for (const auto &kv : deadModesValues) {
|
for (const auto &kv : deadModesValues) {
|
||||||
deadKeys.push_back(kv.first);
|
deadKeys.push_back(kv.first);
|
||||||
|
|
|
@ -217,11 +217,16 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
* @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves.
|
* @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves.
|
||||||
*
|
*
|
||||||
* @param maxNrLeaves Continuous values at which to compute the error.
|
* @param maxNrLeaves Continuous values at which to compute the error.
|
||||||
* @param removeDeadModes Flag to enable removal of modes which only have a
|
* @param deadModeThreshold The threshold to check the mode marginals against.
|
||||||
* single possible assignment.
|
* If greater than this threshold, the mode gets assigned that value and is
|
||||||
|
* considered "dead" for hybrid elimination.
|
||||||
|
* The mode can then be removed since it only has a single possible
|
||||||
|
* assignment.
|
||||||
* @return A pruned HybridBayesNet
|
* @return A pruned HybridBayesNet
|
||||||
*/
|
*/
|
||||||
HybridBayesNet prune(size_t maxNrLeaves, bool removeDeadModes = false) const;
|
HybridBayesNet prune(
|
||||||
|
size_t maxNrLeaves,
|
||||||
|
const std::optional<double> &deadModeThreshold = {}) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Error method using HybridValues which returns specific error for
|
* @brief Error method using HybridValues which returns specific error for
|
||||||
|
|
|
@ -81,7 +81,7 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph,
|
||||||
if (maxNrLeaves) {
|
if (maxNrLeaves) {
|
||||||
// `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in
|
// `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in
|
||||||
// all the conditionals with the same keys in bayesNetFragment.
|
// all the conditionals with the same keys in bayesNetFragment.
|
||||||
bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves);
|
bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, deadModeThreshold_);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the partial bayes net to the posterior bayes net.
|
// Add the partial bayes net to the posterior bayes net.
|
||||||
|
@ -116,6 +116,14 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
|
||||||
factorKeys.end()) {
|
factorKeys.end()) {
|
||||||
newConditionals.push_back(conditional);
|
newConditionals.push_back(conditional);
|
||||||
|
|
||||||
|
// Add the conditional parents to factorKeys
|
||||||
|
// so we add those conditionals too.
|
||||||
|
// NOTE: This assumes we have a structure where
|
||||||
|
// variables depend on those in the future.
|
||||||
|
for (auto &&parentKey : conditional->parents()) {
|
||||||
|
factorKeys.insert(parentKey);
|
||||||
|
}
|
||||||
|
|
||||||
// Remove the conditional from the updated Bayes net
|
// Remove the conditional from the updated Bayes net
|
||||||
auto it = find(updatedHybridBayesNet.begin(),
|
auto it = find(updatedHybridBayesNet.begin(),
|
||||||
updatedHybridBayesNet.end(), conditional);
|
updatedHybridBayesNet.end(), conditional);
|
||||||
|
|
|
@ -29,7 +29,20 @@ class GTSAM_EXPORT HybridSmoother {
|
||||||
HybridBayesNet hybridBayesNet_;
|
HybridBayesNet hybridBayesNet_;
|
||||||
HybridGaussianFactorGraph remainingFactorGraph_;
|
HybridGaussianFactorGraph remainingFactorGraph_;
|
||||||
|
|
||||||
|
/// The threshold above which we make a decision about a mode.
|
||||||
|
std::optional<double> deadModeThreshold_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
/**
|
||||||
|
* @brief Constructor
|
||||||
|
*
|
||||||
|
* @param removeDeadModes Flag indicating whether to remove dead modes.
|
||||||
|
* @param deadModeThreshold The threshold above which a mode gets assigned a
|
||||||
|
* value and is considered "dead". 0.99 is a good starting value.
|
||||||
|
*/
|
||||||
|
HybridSmoother(const std::optional<double> deadModeThreshold = {})
|
||||||
|
: deadModeThreshold_(deadModeThreshold) {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Given new factors, perform an incremental update.
|
* Given new factors, perform an incremental update.
|
||||||
* The relevant densities in the `hybridBayesNet` will be added to the input
|
* The relevant densities in the `hybridBayesNet` will be added to the input
|
||||||
|
@ -53,17 +66,6 @@ class GTSAM_EXPORT HybridSmoother {
|
||||||
std::optional<size_t> maxNrLeaves = {},
|
std::optional<size_t> maxNrLeaves = {},
|
||||||
const std::optional<Ordering> given_ordering = {});
|
const std::optional<Ordering> given_ordering = {});
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Get an elimination ordering which eliminates continuous and then
|
|
||||||
* discrete.
|
|
||||||
*
|
|
||||||
* Expects `newFactors` to already have the necessary conditionals connected
|
|
||||||
* to the
|
|
||||||
*
|
|
||||||
* @param factors
|
|
||||||
* @return Ordering
|
|
||||||
*/
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Get an elimination ordering which eliminates continuous
|
* @brief Get an elimination ordering which eliminates continuous
|
||||||
* and then discrete.
|
* and then discrete.
|
||||||
|
|
|
@ -434,7 +434,7 @@ TEST(HybridBayesNet, RemoveDeadNodes) {
|
||||||
HybridValues delta = posterior->optimize();
|
HybridValues delta = posterior->optimize();
|
||||||
|
|
||||||
// Prune the Bayes net
|
// Prune the Bayes net
|
||||||
const bool pruneDeadVariables = true;
|
const double pruneDeadVariables = 0.99;
|
||||||
auto prunedBayesNet = posterior->prune(2, pruneDeadVariables);
|
auto prunedBayesNet = posterior->prune(2, pruneDeadVariables);
|
||||||
|
|
||||||
// Check that discrete joint only has M0 and not (M0, M1)
|
// Check that discrete joint only has M0 and not (M0, M1)
|
||||||
|
@ -445,11 +445,12 @@ TEST(HybridBayesNet, RemoveDeadNodes) {
|
||||||
// Check that hybrid conditionals that only depend on M1
|
// Check that hybrid conditionals that only depend on M1
|
||||||
// are now Gaussian and not Hybrid
|
// are now Gaussian and not Hybrid
|
||||||
EXPECT(prunedBayesNet.at(0)->isDiscrete());
|
EXPECT(prunedBayesNet.at(0)->isDiscrete());
|
||||||
EXPECT(prunedBayesNet.at(1)->isHybrid());
|
EXPECT(prunedBayesNet.at(1)->isDiscrete());
|
||||||
|
EXPECT(prunedBayesNet.at(2)->isHybrid());
|
||||||
// Only P(X2 | X1, M1) depends on M1,
|
// Only P(X2 | X1, M1) depends on M1,
|
||||||
// so it gets convert to a Gaussian P(X2 | X1)
|
// so it gets convert to a Gaussian P(X2 | X1)
|
||||||
EXPECT(prunedBayesNet.at(2)->isContinuous());
|
EXPECT(prunedBayesNet.at(3)->isContinuous());
|
||||||
EXPECT(prunedBayesNet.at(3)->isHybrid());
|
EXPECT(prunedBayesNet.at(4)->isHybrid());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
|
|
|
@ -167,6 +167,59 @@ TEST(HybridSmoother, ValidPruningError) {
|
||||||
EXPECT_DOUBLES_EQUAL(1e-8, errorTree(delta.discrete()), 1e-8);
|
EXPECT_DOUBLES_EQUAL(1e-8, errorTree(delta.discrete()), 1e-8);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/****************************************************************************/
|
||||||
|
// Test if dead mode removal works.
|
||||||
|
TEST(HybridSmoother, DeadModeRemoval) {
|
||||||
|
using namespace estimation_fixture;
|
||||||
|
|
||||||
|
size_t K = 8;
|
||||||
|
|
||||||
|
// Switching example of robot moving in 1D
|
||||||
|
// with given measurements and equal mode priors.
|
||||||
|
HybridNonlinearFactorGraph graph;
|
||||||
|
Values initial;
|
||||||
|
Switching switching = InitializeEstimationProblem(
|
||||||
|
K, 0.1, 0.1, measurements, "1/1 1/1", &graph, &initial);
|
||||||
|
|
||||||
|
// Smoother with dead mode removal enabled.
|
||||||
|
HybridSmoother smoother(true);
|
||||||
|
|
||||||
|
constexpr size_t maxNrLeaves = 3;
|
||||||
|
for (size_t k = 1; k < K; k++) {
|
||||||
|
if (k > 1) graph.push_back(switching.modeChain.at(k - 1)); // Mode chain
|
||||||
|
graph.push_back(switching.binaryFactors.at(k - 1)); // Motion Model
|
||||||
|
graph.push_back(switching.unaryFactors.at(k)); // Measurement
|
||||||
|
|
||||||
|
initial.insert(X(k), switching.linearizationPoint.at<double>(X(k)));
|
||||||
|
|
||||||
|
HybridGaussianFactorGraph linearized = *graph.linearize(initial);
|
||||||
|
|
||||||
|
smoother.update(linearized, maxNrLeaves);
|
||||||
|
|
||||||
|
// Clear all the factors from the graph
|
||||||
|
graph.resize(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the continuous delta update as well as
|
||||||
|
// the optimal discrete assignment.
|
||||||
|
HybridValues delta = smoother.hybridBayesNet().optimize();
|
||||||
|
|
||||||
|
// Check discrete assignment
|
||||||
|
DiscreteValues expected_discrete;
|
||||||
|
for (size_t k = 0; k < K - 1; k++) {
|
||||||
|
expected_discrete[M(k)] = discrete_seq[k];
|
||||||
|
}
|
||||||
|
EXPECT(assert_equal(expected_discrete, delta.discrete()));
|
||||||
|
|
||||||
|
// Update nonlinear solution and verify
|
||||||
|
Values result = initial.retract(delta.continuous());
|
||||||
|
Values expected_continuous;
|
||||||
|
for (size_t k = 0; k < K; k++) {
|
||||||
|
expected_continuous.insert(X(k), measurements[k]);
|
||||||
|
}
|
||||||
|
EXPECT(assert_equal(expected_continuous, result));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
Loading…
Reference in New Issue