factor out class methods 'weights' and 'kruskal' into free-functions

release/4.3a0
Ankur Roy Chowdhury 2023-01-23 16:25:51 -08:00
parent 1f574c95ea
commit 981d8e9391
3 changed files with 150 additions and 67 deletions

View File

@ -238,6 +238,101 @@ std::string SubgraphBuilderParameters::augmentationWeightTranslator(
return "UNKNOWN"; return "UNKNOWN";
} }
/****************************************************************/
std::vector<double> utils::assignWeights(const GaussianFactorGraph &gfg, const SubgraphBuilderParameters::SkeletonWeight &skeletonWeight)
{
using Weights = std::vector<double>;
const size_t m = gfg.size();
Weights weights;
weights.reserve(m);
for (const GaussianFactor::shared_ptr &gf : gfg)
{
switch (skeletonWeight)
{
case SubgraphBuilderParameters::EQUAL:
weights.push_back(1.0);
break;
case SubgraphBuilderParameters::RHS_2NORM:
{
if (JacobianFactor::shared_ptr jf =
std::dynamic_pointer_cast<JacobianFactor>(gf))
{
weights.push_back(jf->getb().norm());
}
else if (HessianFactor::shared_ptr hf =
std::dynamic_pointer_cast<HessianFactor>(gf))
{
weights.push_back(hf->linearTerm().norm());
}
}
break;
case SubgraphBuilderParameters::LHS_FNORM:
{
if (JacobianFactor::shared_ptr jf =
std::dynamic_pointer_cast<JacobianFactor>(gf))
{
weights.push_back(std::sqrt(jf->getA().squaredNorm()));
}
else if (HessianFactor::shared_ptr hf =
std::dynamic_pointer_cast<HessianFactor>(gf))
{
weights.push_back(std::sqrt(hf->information().squaredNorm()));
}
}
break;
case SubgraphBuilderParameters::RANDOM:
weights.push_back(std::rand() % 100 + 1.0);
break;
default:
throw std::invalid_argument(
"utils::assign_weights: undefined weight scheme ");
break;
}
}
return weights;
}
/****************************************************************/
std::vector<size_t> utils::kruskal(const GaussianFactorGraph &gfg,
const FastMap<Key, size_t> &ordering,
const std::vector<double> &weights)
{
const VariableIndex variableIndex(gfg);
const size_t n = variableIndex.size();
const vector<size_t> sortedIndices = sort_idx(weights);
/* initialize buffer */
vector<size_t> treeIndices;
treeIndices.reserve(n - 1);
// container for acsendingly sorted edges
DSFVector dsf(n);
size_t count = 0;
for (const size_t index : sortedIndices)
{
const GaussianFactor &gf = *gfg[index];
const auto keys = gf.keys();
if (keys.size() != 2)
continue;
const size_t u = ordering.find(keys[0])->second,
v = ordering.find(keys[1])->second;
if (dsf.find(u) != dsf.find(v))
{
dsf.merge(u, v);
treeIndices.push_back(index);
if (++count == n - 1)
break;
}
}
return treeIndices;
}
/****************************************************************/ /****************************************************************/
vector<size_t> SubgraphBuilder::buildTree(const GaussianFactorGraph &gfg, vector<size_t> SubgraphBuilder::buildTree(const GaussianFactorGraph &gfg,
const FastMap<Key, size_t> &ordering, const FastMap<Key, size_t> &ordering,
@ -329,31 +424,7 @@ vector<size_t> SubgraphBuilder::bfs(const GaussianFactorGraph &gfg) const {
vector<size_t> SubgraphBuilder::kruskal(const GaussianFactorGraph &gfg, vector<size_t> SubgraphBuilder::kruskal(const GaussianFactorGraph &gfg,
const FastMap<Key, size_t> &ordering, const FastMap<Key, size_t> &ordering,
const vector<double> &weights) const { const vector<double> &weights) const {
const VariableIndex variableIndex(gfg); return utils::kruskal(gfg, ordering, weights);
const size_t n = variableIndex.size();
const vector<size_t> sortedIndices = sort_idx(weights);
/* initialize buffer */
vector<size_t> treeIndices;
treeIndices.reserve(n - 1);
// container for acsendingly sorted edges
DSFVector dsf(n);
size_t count = 0;
for (const size_t index : sortedIndices) {
const GaussianFactor &gf = *gfg[index];
const auto keys = gf.keys();
if (keys.size() != 2) continue;
const size_t u = ordering.find(keys[0])->second,
v = ordering.find(keys[1])->second;
if (dsf.find(u) != dsf.find(v)) {
dsf.merge(u, v);
treeIndices.push_back(index);
if (++count == n - 1) break;
}
}
return treeIndices;
} }
/****************************************************************/ /****************************************************************/
@ -406,45 +477,7 @@ Subgraph SubgraphBuilder::operator()(const GaussianFactorGraph &gfg) const {
/****************************************************************/ /****************************************************************/
SubgraphBuilder::Weights SubgraphBuilder::weights( SubgraphBuilder::Weights SubgraphBuilder::weights(
const GaussianFactorGraph &gfg) const { const GaussianFactorGraph &gfg) const {
const size_t m = gfg.size(); return utils::assignWeights(gfg, parameters_.skeletonWeight);
Weights weight;
weight.reserve(m);
for (const GaussianFactor::shared_ptr &gf : gfg) {
switch (parameters_.skeletonWeight) {
case SubgraphBuilderParameters::EQUAL:
weight.push_back(1.0);
break;
case SubgraphBuilderParameters::RHS_2NORM: {
if (JacobianFactor::shared_ptr jf =
std::dynamic_pointer_cast<JacobianFactor>(gf)) {
weight.push_back(jf->getb().norm());
} else if (HessianFactor::shared_ptr hf =
std::dynamic_pointer_cast<HessianFactor>(gf)) {
weight.push_back(hf->linearTerm().norm());
}
} break;
case SubgraphBuilderParameters::LHS_FNORM: {
if (JacobianFactor::shared_ptr jf =
std::dynamic_pointer_cast<JacobianFactor>(gf)) {
weight.push_back(std::sqrt(jf->getA().squaredNorm()));
} else if (HessianFactor::shared_ptr hf =
std::dynamic_pointer_cast<HessianFactor>(gf)) {
weight.push_back(std::sqrt(hf->information().squaredNorm()));
}
} break;
case SubgraphBuilderParameters::RANDOM:
weight.push_back(std::rand() % 100 + 1.0);
break;
default:
throw std::invalid_argument(
"SubgraphBuilder::weights: undefined weight scheme ");
break;
}
}
return weight;
} }
/*****************************************************************************/ /*****************************************************************************/

View File

@ -149,6 +149,16 @@ struct GTSAM_EXPORT SubgraphBuilderParameters {
static std::string augmentationWeightTranslator(AugmentationWeight w); static std::string augmentationWeightTranslator(AugmentationWeight w);
}; };
namespace utils
{
std::vector<double> assignWeights(const GaussianFactorGraph &gfg,
const SubgraphBuilderParameters::SkeletonWeight &skeleton_weight);
std::vector<size_t> kruskal(const GaussianFactorGraph &gfg,
const FastMap<Key, size_t> &ordering,
const std::vector<double> &weights);
}
/*****************************************************************************/ /*****************************************************************************/
class GTSAM_EXPORT SubgraphBuilder { class GTSAM_EXPORT SubgraphBuilder {
public: public:
@ -161,6 +171,8 @@ class GTSAM_EXPORT SubgraphBuilder {
virtual ~SubgraphBuilder() {} virtual ~SubgraphBuilder() {}
virtual Subgraph operator()(const GaussianFactorGraph &jfg) const; virtual Subgraph operator()(const GaussianFactorGraph &jfg) const;
public:
private: private:
std::vector<size_t> buildTree(const GaussianFactorGraph &gfg, std::vector<size_t> buildTree(const GaussianFactorGraph &gfg,
const FastMap<Key, size_t> &ordering, const FastMap<Key, size_t> &ordering,
@ -168,12 +180,13 @@ class GTSAM_EXPORT SubgraphBuilder {
std::vector<size_t> unary(const GaussianFactorGraph &gfg) const; std::vector<size_t> unary(const GaussianFactorGraph &gfg) const;
std::vector<size_t> natural_chain(const GaussianFactorGraph &gfg) const; std::vector<size_t> natural_chain(const GaussianFactorGraph &gfg) const;
std::vector<size_t> bfs(const GaussianFactorGraph &gfg) const; std::vector<size_t> bfs(const GaussianFactorGraph &gfg) const;
std::vector<size_t> sample(const std::vector<double> &weights,
const size_t t) const;
std::vector<size_t> kruskal(const GaussianFactorGraph &gfg, std::vector<size_t> kruskal(const GaussianFactorGraph &gfg,
const FastMap<Key, size_t> &ordering, const FastMap<Key, size_t> &ordering,
const std::vector<double> &weights) const; const std::vector<double> &weights) const;
std::vector<size_t> sample(const std::vector<double> &weights, Weights weights(const GaussianFactorGraph &gfg) const ;
const size_t t) const;
Weights weights(const GaussianFactorGraph &gfg) const;
SubgraphBuilderParameters parameters_; SubgraphBuilderParameters parameters_;
}; };

View File

@ -129,6 +129,43 @@ TEST( SubgraphSolver, constructor3 )
DOUBLES_EQUAL(0.0, error(Ab, optimized), 1e-5); DOUBLES_EQUAL(0.0, error(Ab, optimized), 1e-5);
} }
/* ************************************************************************* */
TEST(SubgraphBuilder, utilsKruskal)
{
const auto [g, _] = example::planarGraph(N); // A*x-b
const FastMap<Key, size_t> forward_ordering = Ordering::Natural(g).invert();
const auto weights = utils::assignWeights(g, gtsam::SubgraphBuilderParameters::SkeletonWeight::EQUAL);
const auto mstEdgeIndices = utils::kruskal(g, forward_ordering, weights);
// auto PrintMst = [](const auto &graph, const auto &mst_edge_indices)
// {
// std::cout << "MST Edge indices are: \n";
// for (const auto &edge : mst_edge_indices)
// {
// std::cout << edge << " : ";
// for (const auto &key : graph[edge]->keys())
// {
// std::cout << gtsam::DefaultKeyFormatter(gtsam::Symbol(key)) << ", ";
// }
// std::cout << "\n";
// }
// };
// PrintMst(g, mstEdgeIndices);
EXPECT(mstEdgeIndices[0] == 1);
EXPECT(mstEdgeIndices[1] == 2);
EXPECT(mstEdgeIndices[2] == 3);
EXPECT(mstEdgeIndices[3] == 4);
EXPECT(mstEdgeIndices[4] == 5);
EXPECT(mstEdgeIndices[5] == 6);
EXPECT(mstEdgeIndices[6] == 7);
EXPECT(mstEdgeIndices[7] == 8);
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { TestResult tr; return TestRegistry::runAllTests(tr); } int main() { TestResult tr; return TestRegistry::runAllTests(tr); }
/* ************************************************************************* */ /* ************************************************************************* */