factor out class methods 'weights' and 'kruskal' into free-functions
parent
1f574c95ea
commit
981d8e9391
|
|
@ -238,6 +238,101 @@ std::string SubgraphBuilderParameters::augmentationWeightTranslator(
|
||||||
return "UNKNOWN";
|
return "UNKNOWN";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/****************************************************************/
|
||||||
|
std::vector<double> utils::assignWeights(const GaussianFactorGraph &gfg, const SubgraphBuilderParameters::SkeletonWeight &skeletonWeight)
|
||||||
|
{
|
||||||
|
using Weights = std::vector<double>;
|
||||||
|
|
||||||
|
const size_t m = gfg.size();
|
||||||
|
Weights weights;
|
||||||
|
weights.reserve(m);
|
||||||
|
|
||||||
|
for (const GaussianFactor::shared_ptr &gf : gfg)
|
||||||
|
{
|
||||||
|
switch (skeletonWeight)
|
||||||
|
{
|
||||||
|
case SubgraphBuilderParameters::EQUAL:
|
||||||
|
weights.push_back(1.0);
|
||||||
|
break;
|
||||||
|
case SubgraphBuilderParameters::RHS_2NORM:
|
||||||
|
{
|
||||||
|
if (JacobianFactor::shared_ptr jf =
|
||||||
|
std::dynamic_pointer_cast<JacobianFactor>(gf))
|
||||||
|
{
|
||||||
|
weights.push_back(jf->getb().norm());
|
||||||
|
}
|
||||||
|
else if (HessianFactor::shared_ptr hf =
|
||||||
|
std::dynamic_pointer_cast<HessianFactor>(gf))
|
||||||
|
{
|
||||||
|
weights.push_back(hf->linearTerm().norm());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case SubgraphBuilderParameters::LHS_FNORM:
|
||||||
|
{
|
||||||
|
if (JacobianFactor::shared_ptr jf =
|
||||||
|
std::dynamic_pointer_cast<JacobianFactor>(gf))
|
||||||
|
{
|
||||||
|
weights.push_back(std::sqrt(jf->getA().squaredNorm()));
|
||||||
|
}
|
||||||
|
else if (HessianFactor::shared_ptr hf =
|
||||||
|
std::dynamic_pointer_cast<HessianFactor>(gf))
|
||||||
|
{
|
||||||
|
weights.push_back(std::sqrt(hf->information().squaredNorm()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
|
||||||
|
case SubgraphBuilderParameters::RANDOM:
|
||||||
|
weights.push_back(std::rand() % 100 + 1.0);
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"utils::assign_weights: undefined weight scheme ");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return weights;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/****************************************************************/
|
||||||
|
std::vector<size_t> utils::kruskal(const GaussianFactorGraph &gfg,
|
||||||
|
const FastMap<Key, size_t> &ordering,
|
||||||
|
const std::vector<double> &weights)
|
||||||
|
{
|
||||||
|
const VariableIndex variableIndex(gfg);
|
||||||
|
const size_t n = variableIndex.size();
|
||||||
|
const vector<size_t> sortedIndices = sort_idx(weights);
|
||||||
|
|
||||||
|
/* initialize buffer */
|
||||||
|
vector<size_t> treeIndices;
|
||||||
|
treeIndices.reserve(n - 1);
|
||||||
|
|
||||||
|
// container for acsendingly sorted edges
|
||||||
|
DSFVector dsf(n);
|
||||||
|
|
||||||
|
size_t count = 0;
|
||||||
|
for (const size_t index : sortedIndices)
|
||||||
|
{
|
||||||
|
const GaussianFactor &gf = *gfg[index];
|
||||||
|
const auto keys = gf.keys();
|
||||||
|
if (keys.size() != 2)
|
||||||
|
continue;
|
||||||
|
const size_t u = ordering.find(keys[0])->second,
|
||||||
|
v = ordering.find(keys[1])->second;
|
||||||
|
if (dsf.find(u) != dsf.find(v))
|
||||||
|
{
|
||||||
|
dsf.merge(u, v);
|
||||||
|
treeIndices.push_back(index);
|
||||||
|
if (++count == n - 1)
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return treeIndices;
|
||||||
|
}
|
||||||
|
|
||||||
/****************************************************************/
|
/****************************************************************/
|
||||||
vector<size_t> SubgraphBuilder::buildTree(const GaussianFactorGraph &gfg,
|
vector<size_t> SubgraphBuilder::buildTree(const GaussianFactorGraph &gfg,
|
||||||
const FastMap<Key, size_t> &ordering,
|
const FastMap<Key, size_t> &ordering,
|
||||||
|
|
@ -329,31 +424,7 @@ vector<size_t> SubgraphBuilder::bfs(const GaussianFactorGraph &gfg) const {
|
||||||
vector<size_t> SubgraphBuilder::kruskal(const GaussianFactorGraph &gfg,
|
vector<size_t> SubgraphBuilder::kruskal(const GaussianFactorGraph &gfg,
|
||||||
const FastMap<Key, size_t> &ordering,
|
const FastMap<Key, size_t> &ordering,
|
||||||
const vector<double> &weights) const {
|
const vector<double> &weights) const {
|
||||||
const VariableIndex variableIndex(gfg);
|
return utils::kruskal(gfg, ordering, weights);
|
||||||
const size_t n = variableIndex.size();
|
|
||||||
const vector<size_t> sortedIndices = sort_idx(weights);
|
|
||||||
|
|
||||||
/* initialize buffer */
|
|
||||||
vector<size_t> treeIndices;
|
|
||||||
treeIndices.reserve(n - 1);
|
|
||||||
|
|
||||||
// container for acsendingly sorted edges
|
|
||||||
DSFVector dsf(n);
|
|
||||||
|
|
||||||
size_t count = 0;
|
|
||||||
for (const size_t index : sortedIndices) {
|
|
||||||
const GaussianFactor &gf = *gfg[index];
|
|
||||||
const auto keys = gf.keys();
|
|
||||||
if (keys.size() != 2) continue;
|
|
||||||
const size_t u = ordering.find(keys[0])->second,
|
|
||||||
v = ordering.find(keys[1])->second;
|
|
||||||
if (dsf.find(u) != dsf.find(v)) {
|
|
||||||
dsf.merge(u, v);
|
|
||||||
treeIndices.push_back(index);
|
|
||||||
if (++count == n - 1) break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return treeIndices;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/****************************************************************/
|
/****************************************************************/
|
||||||
|
|
@ -406,45 +477,7 @@ Subgraph SubgraphBuilder::operator()(const GaussianFactorGraph &gfg) const {
|
||||||
/****************************************************************/
|
/****************************************************************/
|
||||||
SubgraphBuilder::Weights SubgraphBuilder::weights(
|
SubgraphBuilder::Weights SubgraphBuilder::weights(
|
||||||
const GaussianFactorGraph &gfg) const {
|
const GaussianFactorGraph &gfg) const {
|
||||||
const size_t m = gfg.size();
|
return utils::assignWeights(gfg, parameters_.skeletonWeight);
|
||||||
Weights weight;
|
|
||||||
weight.reserve(m);
|
|
||||||
|
|
||||||
for (const GaussianFactor::shared_ptr &gf : gfg) {
|
|
||||||
switch (parameters_.skeletonWeight) {
|
|
||||||
case SubgraphBuilderParameters::EQUAL:
|
|
||||||
weight.push_back(1.0);
|
|
||||||
break;
|
|
||||||
case SubgraphBuilderParameters::RHS_2NORM: {
|
|
||||||
if (JacobianFactor::shared_ptr jf =
|
|
||||||
std::dynamic_pointer_cast<JacobianFactor>(gf)) {
|
|
||||||
weight.push_back(jf->getb().norm());
|
|
||||||
} else if (HessianFactor::shared_ptr hf =
|
|
||||||
std::dynamic_pointer_cast<HessianFactor>(gf)) {
|
|
||||||
weight.push_back(hf->linearTerm().norm());
|
|
||||||
}
|
|
||||||
} break;
|
|
||||||
case SubgraphBuilderParameters::LHS_FNORM: {
|
|
||||||
if (JacobianFactor::shared_ptr jf =
|
|
||||||
std::dynamic_pointer_cast<JacobianFactor>(gf)) {
|
|
||||||
weight.push_back(std::sqrt(jf->getA().squaredNorm()));
|
|
||||||
} else if (HessianFactor::shared_ptr hf =
|
|
||||||
std::dynamic_pointer_cast<HessianFactor>(gf)) {
|
|
||||||
weight.push_back(std::sqrt(hf->information().squaredNorm()));
|
|
||||||
}
|
|
||||||
} break;
|
|
||||||
|
|
||||||
case SubgraphBuilderParameters::RANDOM:
|
|
||||||
weight.push_back(std::rand() % 100 + 1.0);
|
|
||||||
break;
|
|
||||||
|
|
||||||
default:
|
|
||||||
throw std::invalid_argument(
|
|
||||||
"SubgraphBuilder::weights: undefined weight scheme ");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return weight;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/*****************************************************************************/
|
/*****************************************************************************/
|
||||||
|
|
|
||||||
|
|
@ -149,6 +149,16 @@ struct GTSAM_EXPORT SubgraphBuilderParameters {
|
||||||
static std::string augmentationWeightTranslator(AugmentationWeight w);
|
static std::string augmentationWeightTranslator(AugmentationWeight w);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
namespace utils
|
||||||
|
{
|
||||||
|
|
||||||
|
std::vector<double> assignWeights(const GaussianFactorGraph &gfg,
|
||||||
|
const SubgraphBuilderParameters::SkeletonWeight &skeleton_weight);
|
||||||
|
std::vector<size_t> kruskal(const GaussianFactorGraph &gfg,
|
||||||
|
const FastMap<Key, size_t> &ordering,
|
||||||
|
const std::vector<double> &weights);
|
||||||
|
}
|
||||||
|
|
||||||
/*****************************************************************************/
|
/*****************************************************************************/
|
||||||
class GTSAM_EXPORT SubgraphBuilder {
|
class GTSAM_EXPORT SubgraphBuilder {
|
||||||
public:
|
public:
|
||||||
|
|
@ -161,6 +171,8 @@ class GTSAM_EXPORT SubgraphBuilder {
|
||||||
virtual ~SubgraphBuilder() {}
|
virtual ~SubgraphBuilder() {}
|
||||||
virtual Subgraph operator()(const GaussianFactorGraph &jfg) const;
|
virtual Subgraph operator()(const GaussianFactorGraph &jfg) const;
|
||||||
|
|
||||||
|
public:
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<size_t> buildTree(const GaussianFactorGraph &gfg,
|
std::vector<size_t> buildTree(const GaussianFactorGraph &gfg,
|
||||||
const FastMap<Key, size_t> &ordering,
|
const FastMap<Key, size_t> &ordering,
|
||||||
|
|
@ -168,12 +180,13 @@ class GTSAM_EXPORT SubgraphBuilder {
|
||||||
std::vector<size_t> unary(const GaussianFactorGraph &gfg) const;
|
std::vector<size_t> unary(const GaussianFactorGraph &gfg) const;
|
||||||
std::vector<size_t> natural_chain(const GaussianFactorGraph &gfg) const;
|
std::vector<size_t> natural_chain(const GaussianFactorGraph &gfg) const;
|
||||||
std::vector<size_t> bfs(const GaussianFactorGraph &gfg) const;
|
std::vector<size_t> bfs(const GaussianFactorGraph &gfg) const;
|
||||||
|
std::vector<size_t> sample(const std::vector<double> &weights,
|
||||||
|
const size_t t) const;
|
||||||
std::vector<size_t> kruskal(const GaussianFactorGraph &gfg,
|
std::vector<size_t> kruskal(const GaussianFactorGraph &gfg,
|
||||||
const FastMap<Key, size_t> &ordering,
|
const FastMap<Key, size_t> &ordering,
|
||||||
const std::vector<double> &weights) const;
|
const std::vector<double> &weights) const;
|
||||||
std::vector<size_t> sample(const std::vector<double> &weights,
|
|
||||||
const size_t t) const;
|
|
||||||
Weights weights(const GaussianFactorGraph &gfg) const ;
|
Weights weights(const GaussianFactorGraph &gfg) const ;
|
||||||
|
|
||||||
SubgraphBuilderParameters parameters_;
|
SubgraphBuilderParameters parameters_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -129,6 +129,43 @@ TEST( SubgraphSolver, constructor3 )
|
||||||
DOUBLES_EQUAL(0.0, error(Ab, optimized), 1e-5);
|
DOUBLES_EQUAL(0.0, error(Ab, optimized), 1e-5);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(SubgraphBuilder, utilsKruskal)
|
||||||
|
{
|
||||||
|
|
||||||
|
const auto [g, _] = example::planarGraph(N); // A*x-b
|
||||||
|
|
||||||
|
const FastMap<Key, size_t> forward_ordering = Ordering::Natural(g).invert();
|
||||||
|
const auto weights = utils::assignWeights(g, gtsam::SubgraphBuilderParameters::SkeletonWeight::EQUAL);
|
||||||
|
|
||||||
|
const auto mstEdgeIndices = utils::kruskal(g, forward_ordering, weights);
|
||||||
|
|
||||||
|
// auto PrintMst = [](const auto &graph, const auto &mst_edge_indices)
|
||||||
|
// {
|
||||||
|
// std::cout << "MST Edge indices are: \n";
|
||||||
|
// for (const auto &edge : mst_edge_indices)
|
||||||
|
// {
|
||||||
|
// std::cout << edge << " : ";
|
||||||
|
// for (const auto &key : graph[edge]->keys())
|
||||||
|
// {
|
||||||
|
// std::cout << gtsam::DefaultKeyFormatter(gtsam::Symbol(key)) << ", ";
|
||||||
|
// }
|
||||||
|
// std::cout << "\n";
|
||||||
|
// }
|
||||||
|
// };
|
||||||
|
|
||||||
|
// PrintMst(g, mstEdgeIndices);
|
||||||
|
|
||||||
|
EXPECT(mstEdgeIndices[0] == 1);
|
||||||
|
EXPECT(mstEdgeIndices[1] == 2);
|
||||||
|
EXPECT(mstEdgeIndices[2] == 3);
|
||||||
|
EXPECT(mstEdgeIndices[3] == 4);
|
||||||
|
EXPECT(mstEdgeIndices[4] == 5);
|
||||||
|
EXPECT(mstEdgeIndices[5] == 6);
|
||||||
|
EXPECT(mstEdgeIndices[6] == 7);
|
||||||
|
EXPECT(mstEdgeIndices[7] == 8);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() { TestResult tr; return TestRegistry::runAllTests(tr); }
|
int main() { TestResult tr; return TestRegistry::runAllTests(tr); }
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue