Wrapped elimination and junction trees

(cherry picked from commit 93d9ab6a2e)
release/4.3a0
Frank Dellaert 2023-06-05 02:11:13 +01:00
parent b7e2650a02
commit 07e8d24cbf
6 changed files with 86 additions and 6 deletions

View File

@ -66,4 +66,6 @@ namespace gtsam {
DiscreteJunctionTree(const DiscreteEliminationTree& eliminationTree); DiscreteJunctionTree(const DiscreteEliminationTree& eliminationTree);
}; };
/// typedef for wrapper:
using DiscreteCluster = DiscreteJunctionTree::Cluster;
} }

View File

@ -62,6 +62,8 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
size_t cardinality(gtsam::Key j) const;
double operator()(const gtsam::DiscreteValues& values) const; double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const; gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const;
size_t cardinality(gtsam::Key j) const; size_t cardinality(gtsam::Key j) const;
@ -247,9 +249,9 @@ class DiscreteBayesTree {
class DiscreteLookupTable : gtsam::DiscreteConditional{ class DiscreteLookupTable : gtsam::DiscreteConditional{
DiscreteLookupTable(size_t nFrontals, const gtsam::DiscreteKeys& keys, DiscreteLookupTable(size_t nFrontals, const gtsam::DiscreteKeys& keys,
const gtsam::DecisionTreeFactor::ADT& potentials); const gtsam::DecisionTreeFactor::ADT& potentials);
void print( void print(string s = "Discrete Lookup Table: ",
const std::string& s = "Discrete Lookup Table: ", const gtsam::KeyFormatter& keyFormatter =
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
size_t argmax(const gtsam::DiscreteValues& parentsValues) const; size_t argmax(const gtsam::DiscreteValues& parentsValues) const;
}; };
@ -333,4 +335,41 @@ class DiscreteFactorGraph {
std::map<gtsam::Key, std::vector<std::string>> names) const; std::map<gtsam::Key, std::vector<std::string>> names) const;
}; };
#include <gtsam/discrete/DiscreteEliminationTree.h>
class DiscreteEliminationTree {
DiscreteEliminationTree(const gtsam::DiscreteFactorGraph& factorGraph,
const gtsam::VariableIndex& structure,
const gtsam::Ordering& order);
DiscreteEliminationTree(const gtsam::DiscreteFactorGraph& factorGraph,
const gtsam::Ordering& order);
void print(
string name = "EliminationTree: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteEliminationTree& other,
double tol = 1e-9) const;
};
#include <gtsam/discrete/DiscreteJunctionTree.h>
class DiscreteCluster {
gtsam::Ordering orderedFrontalKeys;
gtsam::DiscreteFactorGraph factors;
const gtsam::DiscreteCluster& operator[](size_t i) const;
size_t nrChildren() const;
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};
class DiscreteJunctionTree {
DiscreteJunctionTree(const gtsam::DiscreteEliminationTree& eliminationTree);
void print(
string name = "JunctionTree: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
size_t nrRoots() const;
const gtsam::DiscreteCluster& operator[](size_t i) const;
};
} // namespace gtsam } // namespace gtsam

View File

@ -49,7 +49,7 @@ class ClusterTree {
virtual ~Cluster() {} virtual ~Cluster() {}
const Cluster& operator[](size_t i) const { const Cluster& operator[](size_t i) const {
return *(children[i]); return *(children.at(i));
} }
/// Construct from factors associated with a single key /// Construct from factors associated with a single key
@ -161,7 +161,7 @@ class ClusterTree {
} }
const Cluster& operator[](size_t i) const { const Cluster& operator[](size_t i) const {
return *(roots_[i]); return *(roots_.at(i));
} }
/// @} /// @}

View File

@ -148,7 +148,7 @@ class Ordering {
// Standard interface // Standard interface
size_t size() const; size_t size() const;
size_t at(size_t key) const; size_t at(size_t i) const;
void push_back(size_t key); void push_back(size_t key);
// enabling serialization functionality // enabling serialization functionality

View File

@ -65,4 +65,6 @@ namespace gtsam {
SymbolicJunctionTree(const SymbolicEliminationTree& eliminationTree); SymbolicJunctionTree(const SymbolicEliminationTree& eliminationTree);
}; };
/// typedef for wrapper:
using SymbolicCluster = SymbolicJunctionTree::Cluster;
} }

View File

@ -137,6 +137,43 @@ class SymbolicBayesNet {
const gtsam::DotWriter& writer = gtsam::DotWriter()) const; const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
}; };
#include <gtsam/symbolic/SymbolicEliminationTree.h>
class SymbolicEliminationTree {
SymbolicEliminationTree(const gtsam::SymbolicFactorGraph& factorGraph,
const gtsam::VariableIndex& structure,
const gtsam::Ordering& order);
SymbolicEliminationTree(const gtsam::SymbolicFactorGraph& factorGraph,
const gtsam::Ordering& order);
void print(
string name = "EliminationTree: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::SymbolicEliminationTree& other,
double tol = 1e-9) const;
};
#include <gtsam/symbolic/SymbolicJunctionTree.h>
class SymbolicCluster {
gtsam::Ordering orderedFrontalKeys;
gtsam::SymbolicFactorGraph factors;
const gtsam::SymbolicCluster& operator[](size_t i) const;
size_t nrChildren() const;
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};
class SymbolicJunctionTree {
SymbolicJunctionTree(const gtsam::SymbolicEliminationTree& eliminationTree);
void print(
string name = "JunctionTree: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
size_t nrRoots() const;
const gtsam::SymbolicCluster& operator[](size_t i) const;
};
#include <gtsam/symbolic/SymbolicBayesTree.h> #include <gtsam/symbolic/SymbolicBayesTree.h>
class SymbolicBayesTreeClique { class SymbolicBayesTreeClique {