diff --git a/gtsam/linear/HessianFactor.cpp b/gtsam/linear/HessianFactor.cpp index 83e30f881..2fb03481b 100644 --- a/gtsam/linear/HessianFactor.cpp +++ b/gtsam/linear/HessianFactor.cpp @@ -45,6 +45,7 @@ #include #include +#include using namespace std; using namespace boost::assign; @@ -60,21 +61,37 @@ string SlotEntry::toString() const { } /* ************************************************************************* */ -Scatter::Scatter(const GaussianFactorGraph& gfg) { +Scatter::Scatter(const GaussianFactorGraph& gfg, boost::optional ordering) +{ + static const size_t none = std::numeric_limits::max(); + // First do the set union. BOOST_FOREACH(const GaussianFactor::shared_ptr& factor, gfg) { if(factor) { for(GaussianFactor::const_iterator variable = factor->begin(); variable != factor->end(); ++variable) { - this->insert(make_pair(*variable, SlotEntry(0, factor->getDim(variable)))); + this->insert(make_pair(*variable, SlotEntry(none, factor->getDim(variable)))); } } } + // If we have an ordering, pre-fill the ordered variables first + size_t slot = 0; + if(ordering) { + BOOST_FOREACH(Key key, *ordering) { + const_iterator entry = find(key); + if(entry == end()) + throw std::invalid_argument( + "The ordering provided to the HessianFactor Scatter constructor\n" + "contained extra variables that did not appear in the factors to combine."); + at(key).slot = (slot ++); + } + } + // Next fill in the slot indices (we can only get these after doing the set // union. - size_t slot = 0; BOOST_FOREACH(value_type& var_slot, *this) { - var_slot.second.slot = (slot ++); + if(var_slot.second.slot == none) + var_slot.second.slot = (slot ++); } } @@ -247,7 +264,6 @@ namespace { /* ************************************************************************* */ HessianFactor::HessianFactor(const GaussianFactorGraph& factors, - boost::optional ordering, boost::optional scatter) { boost::optional computedScatter; @@ -427,7 +443,7 @@ std::pair, boost::shared_ptr(factors, keys); + jointFactor = boost::make_shared(factors, Scatter(factors, keys)); } catch(std::invalid_argument&) { throw InvalidDenseElimination( "EliminateCholesky was called with a request to eliminate variables that are not\n" diff --git a/gtsam/linear/HessianFactor.h b/gtsam/linear/HessianFactor.h index fcff217e8..2f0a2e1c4 100644 --- a/gtsam/linear/HessianFactor.h +++ b/gtsam/linear/HessianFactor.h @@ -47,7 +47,7 @@ namespace gtsam { class Scatter : public FastMap { public: Scatter() {} - Scatter(const GaussianFactorGraph& gfg); + Scatter(const GaussianFactorGraph& gfg, boost::optional ordering = boost::none); }; /** @@ -190,7 +190,6 @@ namespace gtsam { /** Combine a set of factors into a single dense HessianFactor */ explicit HessianFactor(const GaussianFactorGraph& factors, - boost::optional ordering = boost::none, boost::optional scatter = boost::none); /** Destructor */ diff --git a/tests/testGaussianBayesTree.cpp b/tests/testGaussianBayesTree.cpp index 5bfbf251a..bb8a2da62 100644 --- a/tests/testGaussianBayesTree.cpp +++ b/tests/testGaussianBayesTree.cpp @@ -37,6 +37,8 @@ using namespace example; using symbol_shorthand::X; using symbol_shorthand::L; +#define TEST TEST_UNSAFE + /* ************************************************************************* */ // Some numbers that should be consistent among all smoother tests