factor out class methods 'weights' and 'kruskal' into free-functions
parent
1f574c95ea
commit
981d8e9391
|
|
@ -238,6 +238,101 @@ std::string SubgraphBuilderParameters::augmentationWeightTranslator(
|
|||
return "UNKNOWN";
|
||||
}
|
||||
|
||||
/****************************************************************/
|
||||
std::vector<double> utils::assignWeights(const GaussianFactorGraph &gfg, const SubgraphBuilderParameters::SkeletonWeight &skeletonWeight)
|
||||
{
|
||||
using Weights = std::vector<double>;
|
||||
|
||||
const size_t m = gfg.size();
|
||||
Weights weights;
|
||||
weights.reserve(m);
|
||||
|
||||
for (const GaussianFactor::shared_ptr &gf : gfg)
|
||||
{
|
||||
switch (skeletonWeight)
|
||||
{
|
||||
case SubgraphBuilderParameters::EQUAL:
|
||||
weights.push_back(1.0);
|
||||
break;
|
||||
case SubgraphBuilderParameters::RHS_2NORM:
|
||||
{
|
||||
if (JacobianFactor::shared_ptr jf =
|
||||
std::dynamic_pointer_cast<JacobianFactor>(gf))
|
||||
{
|
||||
weights.push_back(jf->getb().norm());
|
||||
}
|
||||
else if (HessianFactor::shared_ptr hf =
|
||||
std::dynamic_pointer_cast<HessianFactor>(gf))
|
||||
{
|
||||
weights.push_back(hf->linearTerm().norm());
|
||||
}
|
||||
}
|
||||
break;
|
||||
case SubgraphBuilderParameters::LHS_FNORM:
|
||||
{
|
||||
if (JacobianFactor::shared_ptr jf =
|
||||
std::dynamic_pointer_cast<JacobianFactor>(gf))
|
||||
{
|
||||
weights.push_back(std::sqrt(jf->getA().squaredNorm()));
|
||||
}
|
||||
else if (HessianFactor::shared_ptr hf =
|
||||
std::dynamic_pointer_cast<HessianFactor>(gf))
|
||||
{
|
||||
weights.push_back(std::sqrt(hf->information().squaredNorm()));
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
case SubgraphBuilderParameters::RANDOM:
|
||||
weights.push_back(std::rand() % 100 + 1.0);
|
||||
break;
|
||||
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"utils::assign_weights: undefined weight scheme ");
|
||||
break;
|
||||
}
|
||||
}
|
||||
return weights;
|
||||
}
|
||||
|
||||
|
||||
/****************************************************************/
|
||||
std::vector<size_t> utils::kruskal(const GaussianFactorGraph &gfg,
|
||||
const FastMap<Key, size_t> &ordering,
|
||||
const std::vector<double> &weights)
|
||||
{
|
||||
const VariableIndex variableIndex(gfg);
|
||||
const size_t n = variableIndex.size();
|
||||
const vector<size_t> sortedIndices = sort_idx(weights);
|
||||
|
||||
/* initialize buffer */
|
||||
vector<size_t> treeIndices;
|
||||
treeIndices.reserve(n - 1);
|
||||
|
||||
// container for acsendingly sorted edges
|
||||
DSFVector dsf(n);
|
||||
|
||||
size_t count = 0;
|
||||
for (const size_t index : sortedIndices)
|
||||
{
|
||||
const GaussianFactor &gf = *gfg[index];
|
||||
const auto keys = gf.keys();
|
||||
if (keys.size() != 2)
|
||||
continue;
|
||||
const size_t u = ordering.find(keys[0])->second,
|
||||
v = ordering.find(keys[1])->second;
|
||||
if (dsf.find(u) != dsf.find(v))
|
||||
{
|
||||
dsf.merge(u, v);
|
||||
treeIndices.push_back(index);
|
||||
if (++count == n - 1)
|
||||
break;
|
||||
}
|
||||
}
|
||||
return treeIndices;
|
||||
}
|
||||
|
||||
/****************************************************************/
|
||||
vector<size_t> SubgraphBuilder::buildTree(const GaussianFactorGraph &gfg,
|
||||
const FastMap<Key, size_t> &ordering,
|
||||
|
|
@ -329,31 +424,7 @@ vector<size_t> SubgraphBuilder::bfs(const GaussianFactorGraph &gfg) const {
|
|||
vector<size_t> SubgraphBuilder::kruskal(const GaussianFactorGraph &gfg,
|
||||
const FastMap<Key, size_t> &ordering,
|
||||
const vector<double> &weights) const {
|
||||
const VariableIndex variableIndex(gfg);
|
||||
const size_t n = variableIndex.size();
|
||||
const vector<size_t> sortedIndices = sort_idx(weights);
|
||||
|
||||
/* initialize buffer */
|
||||
vector<size_t> treeIndices;
|
||||
treeIndices.reserve(n - 1);
|
||||
|
||||
// container for acsendingly sorted edges
|
||||
DSFVector dsf(n);
|
||||
|
||||
size_t count = 0;
|
||||
for (const size_t index : sortedIndices) {
|
||||
const GaussianFactor &gf = *gfg[index];
|
||||
const auto keys = gf.keys();
|
||||
if (keys.size() != 2) continue;
|
||||
const size_t u = ordering.find(keys[0])->second,
|
||||
v = ordering.find(keys[1])->second;
|
||||
if (dsf.find(u) != dsf.find(v)) {
|
||||
dsf.merge(u, v);
|
||||
treeIndices.push_back(index);
|
||||
if (++count == n - 1) break;
|
||||
}
|
||||
}
|
||||
return treeIndices;
|
||||
return utils::kruskal(gfg, ordering, weights);
|
||||
}
|
||||
|
||||
/****************************************************************/
|
||||
|
|
@ -406,45 +477,7 @@ Subgraph SubgraphBuilder::operator()(const GaussianFactorGraph &gfg) const {
|
|||
/****************************************************************/
|
||||
SubgraphBuilder::Weights SubgraphBuilder::weights(
|
||||
const GaussianFactorGraph &gfg) const {
|
||||
const size_t m = gfg.size();
|
||||
Weights weight;
|
||||
weight.reserve(m);
|
||||
|
||||
for (const GaussianFactor::shared_ptr &gf : gfg) {
|
||||
switch (parameters_.skeletonWeight) {
|
||||
case SubgraphBuilderParameters::EQUAL:
|
||||
weight.push_back(1.0);
|
||||
break;
|
||||
case SubgraphBuilderParameters::RHS_2NORM: {
|
||||
if (JacobianFactor::shared_ptr jf =
|
||||
std::dynamic_pointer_cast<JacobianFactor>(gf)) {
|
||||
weight.push_back(jf->getb().norm());
|
||||
} else if (HessianFactor::shared_ptr hf =
|
||||
std::dynamic_pointer_cast<HessianFactor>(gf)) {
|
||||
weight.push_back(hf->linearTerm().norm());
|
||||
}
|
||||
} break;
|
||||
case SubgraphBuilderParameters::LHS_FNORM: {
|
||||
if (JacobianFactor::shared_ptr jf =
|
||||
std::dynamic_pointer_cast<JacobianFactor>(gf)) {
|
||||
weight.push_back(std::sqrt(jf->getA().squaredNorm()));
|
||||
} else if (HessianFactor::shared_ptr hf =
|
||||
std::dynamic_pointer_cast<HessianFactor>(gf)) {
|
||||
weight.push_back(std::sqrt(hf->information().squaredNorm()));
|
||||
}
|
||||
} break;
|
||||
|
||||
case SubgraphBuilderParameters::RANDOM:
|
||||
weight.push_back(std::rand() % 100 + 1.0);
|
||||
break;
|
||||
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"SubgraphBuilder::weights: undefined weight scheme ");
|
||||
break;
|
||||
}
|
||||
}
|
||||
return weight;
|
||||
return utils::assignWeights(gfg, parameters_.skeletonWeight);
|
||||
}
|
||||
|
||||
/*****************************************************************************/
|
||||
|
|
|
|||
|
|
@ -149,6 +149,16 @@ struct GTSAM_EXPORT SubgraphBuilderParameters {
|
|||
static std::string augmentationWeightTranslator(AugmentationWeight w);
|
||||
};
|
||||
|
||||
namespace utils
|
||||
{
|
||||
|
||||
std::vector<double> assignWeights(const GaussianFactorGraph &gfg,
|
||||
const SubgraphBuilderParameters::SkeletonWeight &skeleton_weight);
|
||||
std::vector<size_t> kruskal(const GaussianFactorGraph &gfg,
|
||||
const FastMap<Key, size_t> &ordering,
|
||||
const std::vector<double> &weights);
|
||||
}
|
||||
|
||||
/*****************************************************************************/
|
||||
class GTSAM_EXPORT SubgraphBuilder {
|
||||
public:
|
||||
|
|
@ -161,6 +171,8 @@ class GTSAM_EXPORT SubgraphBuilder {
|
|||
virtual ~SubgraphBuilder() {}
|
||||
virtual Subgraph operator()(const GaussianFactorGraph &jfg) const;
|
||||
|
||||
public:
|
||||
|
||||
private:
|
||||
std::vector<size_t> buildTree(const GaussianFactorGraph &gfg,
|
||||
const FastMap<Key, size_t> &ordering,
|
||||
|
|
@ -168,12 +180,13 @@ class GTSAM_EXPORT SubgraphBuilder {
|
|||
std::vector<size_t> unary(const GaussianFactorGraph &gfg) const;
|
||||
std::vector<size_t> natural_chain(const GaussianFactorGraph &gfg) const;
|
||||
std::vector<size_t> bfs(const GaussianFactorGraph &gfg) const;
|
||||
std::vector<size_t> sample(const std::vector<double> &weights,
|
||||
const size_t t) const;
|
||||
std::vector<size_t> kruskal(const GaussianFactorGraph &gfg,
|
||||
const FastMap<Key, size_t> &ordering,
|
||||
const std::vector<double> &weights) const;
|
||||
std::vector<size_t> sample(const std::vector<double> &weights,
|
||||
const size_t t) const;
|
||||
Weights weights(const GaussianFactorGraph &gfg) const;
|
||||
Weights weights(const GaussianFactorGraph &gfg) const ;
|
||||
|
||||
SubgraphBuilderParameters parameters_;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -129,6 +129,43 @@ TEST( SubgraphSolver, constructor3 )
|
|||
DOUBLES_EQUAL(0.0, error(Ab, optimized), 1e-5);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(SubgraphBuilder, utilsKruskal)
|
||||
{
|
||||
|
||||
const auto [g, _] = example::planarGraph(N); // A*x-b
|
||||
|
||||
const FastMap<Key, size_t> forward_ordering = Ordering::Natural(g).invert();
|
||||
const auto weights = utils::assignWeights(g, gtsam::SubgraphBuilderParameters::SkeletonWeight::EQUAL);
|
||||
|
||||
const auto mstEdgeIndices = utils::kruskal(g, forward_ordering, weights);
|
||||
|
||||
// auto PrintMst = [](const auto &graph, const auto &mst_edge_indices)
|
||||
// {
|
||||
// std::cout << "MST Edge indices are: \n";
|
||||
// for (const auto &edge : mst_edge_indices)
|
||||
// {
|
||||
// std::cout << edge << " : ";
|
||||
// for (const auto &key : graph[edge]->keys())
|
||||
// {
|
||||
// std::cout << gtsam::DefaultKeyFormatter(gtsam::Symbol(key)) << ", ";
|
||||
// }
|
||||
// std::cout << "\n";
|
||||
// }
|
||||
// };
|
||||
|
||||
// PrintMst(g, mstEdgeIndices);
|
||||
|
||||
EXPECT(mstEdgeIndices[0] == 1);
|
||||
EXPECT(mstEdgeIndices[1] == 2);
|
||||
EXPECT(mstEdgeIndices[2] == 3);
|
||||
EXPECT(mstEdgeIndices[3] == 4);
|
||||
EXPECT(mstEdgeIndices[4] == 5);
|
||||
EXPECT(mstEdgeIndices[5] == 6);
|
||||
EXPECT(mstEdgeIndices[6] == 7);
|
||||
EXPECT(mstEdgeIndices[7] == 8);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() { TestResult tr; return TestRegistry::runAllTests(tr); }
|
||||
/* ************************************************************************* */
|
||||
|
|
|
|||
Loading…
Reference in New Issue