Wrapped elimination and junction trees

release/4.3a0
Frank Dellaert 2023-06-05 02:11:13 +01:00
parent fa7bde7529
commit 93d9ab6a2e
6 changed files with 86 additions and 6 deletions

View File

@ -66,4 +66,6 @@ namespace gtsam {
DiscreteJunctionTree(const DiscreteEliminationTree& eliminationTree);
};
/// typedef for wrapper:
using DiscreteCluster = DiscreteJunctionTree::Cluster;
}

View File

@ -62,6 +62,8 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
size_t cardinality(gtsam::Key j) const;
double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const;
size_t cardinality(gtsam::Key j) const;
@ -247,9 +249,9 @@ class DiscreteBayesTree {
class DiscreteLookupTable : gtsam::DiscreteConditional{
DiscreteLookupTable(size_t nFrontals, const gtsam::DiscreteKeys& keys,
const gtsam::DecisionTreeFactor::ADT& potentials);
void print(
const std::string& s = "Discrete Lookup Table: ",
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const;
void print(string s = "Discrete Lookup Table: ",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
size_t argmax(const gtsam::DiscreteValues& parentsValues) const;
};
@ -333,4 +335,41 @@ class DiscreteFactorGraph {
std::map<gtsam::Key, std::vector<std::string>> names) const;
};
#include <gtsam/discrete/DiscreteEliminationTree.h>
class DiscreteEliminationTree {
DiscreteEliminationTree(const gtsam::DiscreteFactorGraph& factorGraph,
const gtsam::VariableIndex& structure,
const gtsam::Ordering& order);
DiscreteEliminationTree(const gtsam::DiscreteFactorGraph& factorGraph,
const gtsam::Ordering& order);
void print(
string name = "EliminationTree: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteEliminationTree& other,
double tol = 1e-9) const;
};
#include <gtsam/discrete/DiscreteJunctionTree.h>
class DiscreteCluster {
gtsam::Ordering orderedFrontalKeys;
gtsam::DiscreteFactorGraph factors;
const gtsam::DiscreteCluster& operator[](size_t i) const;
size_t nrChildren() const;
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};
class DiscreteJunctionTree {
DiscreteJunctionTree(const gtsam::DiscreteEliminationTree& eliminationTree);
void print(
string name = "JunctionTree: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
size_t nrRoots() const;
const gtsam::DiscreteCluster& operator[](size_t i) const;
};
} // namespace gtsam

View File

@ -49,7 +49,7 @@ class ClusterTree {
virtual ~Cluster() {}
const Cluster& operator[](size_t i) const {
return *(children[i]);
return *(children.at(i));
}
/// Construct from factors associated with a single key
@ -161,7 +161,7 @@ class ClusterTree {
}
const Cluster& operator[](size_t i) const {
return *(roots_[i]);
return *(roots_.at(i));
}
/// @}

View File

@ -147,7 +147,7 @@ class Ordering {
// Standard interface
size_t size() const;
size_t at(size_t key) const;
size_t at(size_t i) const;
void push_back(size_t key);
// enabling serialization functionality

View File

@ -65,4 +65,6 @@ namespace gtsam {
SymbolicJunctionTree(const SymbolicEliminationTree& eliminationTree);
};
/// typedef for wrapper:
using SymbolicCluster = SymbolicJunctionTree::Cluster;
}

View File

@ -137,6 +137,43 @@ class SymbolicBayesNet {
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
};
#include <gtsam/symbolic/SymbolicEliminationTree.h>
class SymbolicEliminationTree {
SymbolicEliminationTree(const gtsam::SymbolicFactorGraph& factorGraph,
const gtsam::VariableIndex& structure,
const gtsam::Ordering& order);
SymbolicEliminationTree(const gtsam::SymbolicFactorGraph& factorGraph,
const gtsam::Ordering& order);
void print(
string name = "EliminationTree: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::SymbolicEliminationTree& other,
double tol = 1e-9) const;
};
#include <gtsam/symbolic/SymbolicJunctionTree.h>
class SymbolicCluster {
gtsam::Ordering orderedFrontalKeys;
gtsam::SymbolicFactorGraph factors;
const gtsam::SymbolicCluster& operator[](size_t i) const;
size_t nrChildren() const;
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};
class SymbolicJunctionTree {
SymbolicJunctionTree(const gtsam::SymbolicEliminationTree& eliminationTree);
void print(
string name = "JunctionTree: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
size_t nrRoots() const;
const gtsam::SymbolicCluster& operator[](size_t i) const;
};
#include <gtsam/symbolic/SymbolicBayesTree.h>
class SymbolicBayesTreeClique {