diff --git a/gtsam/linear/SubgraphBuilder.cpp b/gtsam/linear/SubgraphBuilder.cpp index 97d681547..083cd72c3 100644 --- a/gtsam/linear/SubgraphBuilder.cpp +++ b/gtsam/linear/SubgraphBuilder.cpp @@ -238,6 +238,101 @@ std::string SubgraphBuilderParameters::augmentationWeightTranslator( return "UNKNOWN"; } +/****************************************************************/ +std::vector utils::assignWeights(const GaussianFactorGraph &gfg, const SubgraphBuilderParameters::SkeletonWeight &skeletonWeight) +{ + using Weights = std::vector; + + const size_t m = gfg.size(); + Weights weights; + weights.reserve(m); + + for (const GaussianFactor::shared_ptr &gf : gfg) + { + switch (skeletonWeight) + { + case SubgraphBuilderParameters::EQUAL: + weights.push_back(1.0); + break; + case SubgraphBuilderParameters::RHS_2NORM: + { + if (JacobianFactor::shared_ptr jf = + std::dynamic_pointer_cast(gf)) + { + weights.push_back(jf->getb().norm()); + } + else if (HessianFactor::shared_ptr hf = + std::dynamic_pointer_cast(gf)) + { + weights.push_back(hf->linearTerm().norm()); + } + } + break; + case SubgraphBuilderParameters::LHS_FNORM: + { + if (JacobianFactor::shared_ptr jf = + std::dynamic_pointer_cast(gf)) + { + weights.push_back(std::sqrt(jf->getA().squaredNorm())); + } + else if (HessianFactor::shared_ptr hf = + std::dynamic_pointer_cast(gf)) + { + weights.push_back(std::sqrt(hf->information().squaredNorm())); + } + } + break; + + case SubgraphBuilderParameters::RANDOM: + weights.push_back(std::rand() % 100 + 1.0); + break; + + default: + throw std::invalid_argument( + "utils::assign_weights: undefined weight scheme "); + break; + } + } + return weights; +} + + +/****************************************************************/ +std::vector utils::kruskal(const GaussianFactorGraph &gfg, + const FastMap &ordering, + const std::vector &weights) +{ + const VariableIndex variableIndex(gfg); + const size_t n = variableIndex.size(); + const vector sortedIndices = sort_idx(weights); + + /* initialize buffer */ + vector treeIndices; + treeIndices.reserve(n - 1); + + // container for acsendingly sorted edges + DSFVector dsf(n); + + size_t count = 0; + for (const size_t index : sortedIndices) + { + const GaussianFactor &gf = *gfg[index]; + const auto keys = gf.keys(); + if (keys.size() != 2) + continue; + const size_t u = ordering.find(keys[0])->second, + v = ordering.find(keys[1])->second; + if (dsf.find(u) != dsf.find(v)) + { + dsf.merge(u, v); + treeIndices.push_back(index); + if (++count == n - 1) + break; + } + } + return treeIndices; +} + /****************************************************************/ vector SubgraphBuilder::buildTree(const GaussianFactorGraph &gfg, const FastMap &ordering, @@ -329,31 +424,7 @@ vector SubgraphBuilder::bfs(const GaussianFactorGraph &gfg) const { vector SubgraphBuilder::kruskal(const GaussianFactorGraph &gfg, const FastMap &ordering, const vector &weights) const { - const VariableIndex variableIndex(gfg); - const size_t n = variableIndex.size(); - const vector sortedIndices = sort_idx(weights); - - /* initialize buffer */ - vector treeIndices; - treeIndices.reserve(n - 1); - - // container for acsendingly sorted edges - DSFVector dsf(n); - - size_t count = 0; - for (const size_t index : sortedIndices) { - const GaussianFactor &gf = *gfg[index]; - const auto keys = gf.keys(); - if (keys.size() != 2) continue; - const size_t u = ordering.find(keys[0])->second, - v = ordering.find(keys[1])->second; - if (dsf.find(u) != dsf.find(v)) { - dsf.merge(u, v); - treeIndices.push_back(index); - if (++count == n - 1) break; - } - } - return treeIndices; + return utils::kruskal(gfg, ordering, weights); } /****************************************************************/ @@ -406,45 +477,7 @@ Subgraph SubgraphBuilder::operator()(const GaussianFactorGraph &gfg) const { /****************************************************************/ SubgraphBuilder::Weights SubgraphBuilder::weights( const GaussianFactorGraph &gfg) const { - const size_t m = gfg.size(); - Weights weight; - weight.reserve(m); - - for (const GaussianFactor::shared_ptr &gf : gfg) { - switch (parameters_.skeletonWeight) { - case SubgraphBuilderParameters::EQUAL: - weight.push_back(1.0); - break; - case SubgraphBuilderParameters::RHS_2NORM: { - if (JacobianFactor::shared_ptr jf = - std::dynamic_pointer_cast(gf)) { - weight.push_back(jf->getb().norm()); - } else if (HessianFactor::shared_ptr hf = - std::dynamic_pointer_cast(gf)) { - weight.push_back(hf->linearTerm().norm()); - } - } break; - case SubgraphBuilderParameters::LHS_FNORM: { - if (JacobianFactor::shared_ptr jf = - std::dynamic_pointer_cast(gf)) { - weight.push_back(std::sqrt(jf->getA().squaredNorm())); - } else if (HessianFactor::shared_ptr hf = - std::dynamic_pointer_cast(gf)) { - weight.push_back(std::sqrt(hf->information().squaredNorm())); - } - } break; - - case SubgraphBuilderParameters::RANDOM: - weight.push_back(std::rand() % 100 + 1.0); - break; - - default: - throw std::invalid_argument( - "SubgraphBuilder::weights: undefined weight scheme "); - break; - } - } - return weight; + return utils::assignWeights(gfg, parameters_.skeletonWeight); } /*****************************************************************************/ diff --git a/gtsam/linear/SubgraphBuilder.h b/gtsam/linear/SubgraphBuilder.h index fe8f704dc..cd369fc5f 100644 --- a/gtsam/linear/SubgraphBuilder.h +++ b/gtsam/linear/SubgraphBuilder.h @@ -149,6 +149,16 @@ struct GTSAM_EXPORT SubgraphBuilderParameters { static std::string augmentationWeightTranslator(AugmentationWeight w); }; +namespace utils +{ + + std::vector assignWeights(const GaussianFactorGraph &gfg, + const SubgraphBuilderParameters::SkeletonWeight &skeleton_weight); + std::vector kruskal(const GaussianFactorGraph &gfg, + const FastMap &ordering, + const std::vector &weights); +} + /*****************************************************************************/ class GTSAM_EXPORT SubgraphBuilder { public: @@ -161,6 +171,8 @@ class GTSAM_EXPORT SubgraphBuilder { virtual ~SubgraphBuilder() {} virtual Subgraph operator()(const GaussianFactorGraph &jfg) const; +public: + private: std::vector buildTree(const GaussianFactorGraph &gfg, const FastMap &ordering, @@ -168,12 +180,13 @@ class GTSAM_EXPORT SubgraphBuilder { std::vector unary(const GaussianFactorGraph &gfg) const; std::vector natural_chain(const GaussianFactorGraph &gfg) const; std::vector bfs(const GaussianFactorGraph &gfg) const; + std::vector sample(const std::vector &weights, + const size_t t) const; std::vector kruskal(const GaussianFactorGraph &gfg, const FastMap &ordering, const std::vector &weights) const; - std::vector sample(const std::vector &weights, - const size_t t) const; - Weights weights(const GaussianFactorGraph &gfg) const; + Weights weights(const GaussianFactorGraph &gfg) const ; + SubgraphBuilderParameters parameters_; }; diff --git a/tests/testSubgraphSolver.cpp b/tests/testSubgraphSolver.cpp index 69b5fe5f9..245572896 100644 --- a/tests/testSubgraphSolver.cpp +++ b/tests/testSubgraphSolver.cpp @@ -129,6 +129,43 @@ TEST( SubgraphSolver, constructor3 ) DOUBLES_EQUAL(0.0, error(Ab, optimized), 1e-5); } +/* ************************************************************************* */ +TEST(SubgraphBuilder, utilsKruskal) +{ + + const auto [g, _] = example::planarGraph(N); // A*x-b + + const FastMap forward_ordering = Ordering::Natural(g).invert(); + const auto weights = utils::assignWeights(g, gtsam::SubgraphBuilderParameters::SkeletonWeight::EQUAL); + + const auto mstEdgeIndices = utils::kruskal(g, forward_ordering, weights); + + // auto PrintMst = [](const auto &graph, const auto &mst_edge_indices) + // { + // std::cout << "MST Edge indices are: \n"; + // for (const auto &edge : mst_edge_indices) + // { + // std::cout << edge << " : "; + // for (const auto &key : graph[edge]->keys()) + // { + // std::cout << gtsam::DefaultKeyFormatter(gtsam::Symbol(key)) << ", "; + // } + // std::cout << "\n"; + // } + // }; + + // PrintMst(g, mstEdgeIndices); + + EXPECT(mstEdgeIndices[0] == 1); + EXPECT(mstEdgeIndices[1] == 2); + EXPECT(mstEdgeIndices[2] == 3); + EXPECT(mstEdgeIndices[3] == 4); + EXPECT(mstEdgeIndices[4] == 5); + EXPECT(mstEdgeIndices[5] == 6); + EXPECT(mstEdgeIndices[6] == 7); + EXPECT(mstEdgeIndices[7] == 8); +} + /* ************************************************************************* */ int main() { TestResult tr; return TestRegistry::runAllTests(tr); } /* ************************************************************************* */