diff --git a/gtsam/discrete/DiscreteJunctionTree.h b/gtsam/discrete/DiscreteJunctionTree.h index f417cf6fa..6b70f444b 100644 --- a/gtsam/discrete/DiscreteJunctionTree.h +++ b/gtsam/discrete/DiscreteJunctionTree.h @@ -66,4 +66,6 @@ namespace gtsam { DiscreteJunctionTree(const DiscreteEliminationTree& eliminationTree); }; + /// typedef for wrapper: + using DiscreteCluster = DiscreteJunctionTree::Cluster; } diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index b6c6ee2cd..6df443300 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -62,6 +62,8 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor { gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; + size_t cardinality(gtsam::Key j) const; + double operator()(const gtsam::DiscreteValues& values) const; gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const; size_t cardinality(gtsam::Key j) const; @@ -247,9 +249,9 @@ class DiscreteBayesTree { class DiscreteLookupTable : gtsam::DiscreteConditional{ DiscreteLookupTable(size_t nFrontals, const gtsam::DiscreteKeys& keys, const gtsam::DecisionTreeFactor::ADT& potentials); - void print( - const std::string& s = "Discrete Lookup Table: ", - const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; + void print(string s = "Discrete Lookup Table: ", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; size_t argmax(const gtsam::DiscreteValues& parentsValues) const; }; @@ -333,4 +335,41 @@ class DiscreteFactorGraph { std::map> names) const; }; +#include + +class DiscreteEliminationTree { + DiscreteEliminationTree(const gtsam::DiscreteFactorGraph& factorGraph, + const gtsam::VariableIndex& structure, + const gtsam::Ordering& order); + + DiscreteEliminationTree(const gtsam::DiscreteFactorGraph& factorGraph, + const gtsam::Ordering& order); + + void print( + string name = "EliminationTree: ", + const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::DiscreteEliminationTree& other, + double tol = 1e-9) const; +}; + +#include + +class DiscreteCluster { + gtsam::Ordering orderedFrontalKeys; + gtsam::DiscreteFactorGraph factors; + const gtsam::DiscreteCluster& operator[](size_t i) const; + size_t nrChildren() const; + void print(string s = "", const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; +}; + +class DiscreteJunctionTree { + DiscreteJunctionTree(const gtsam::DiscreteEliminationTree& eliminationTree); + void print( + string name = "JunctionTree: ", + const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; + size_t nrRoots() const; + const gtsam::DiscreteCluster& operator[](size_t i) const; +}; + } // namespace gtsam diff --git a/gtsam/inference/ClusterTree.h b/gtsam/inference/ClusterTree.h index 26c853a7b..3711d6429 100644 --- a/gtsam/inference/ClusterTree.h +++ b/gtsam/inference/ClusterTree.h @@ -49,7 +49,7 @@ class ClusterTree { virtual ~Cluster() {} const Cluster& operator[](size_t i) const { - return *(children[i]); + return *(children.at(i)); } /// Construct from factors associated with a single key @@ -161,7 +161,7 @@ class ClusterTree { } const Cluster& operator[](size_t i) const { - return *(roots_[i]); + return *(roots_.at(i)); } /// @} diff --git a/gtsam/inference/inference.i b/gtsam/inference/inference.i index 1f4f88e2b..17ea117c3 100644 --- a/gtsam/inference/inference.i +++ b/gtsam/inference/inference.i @@ -147,7 +147,7 @@ class Ordering { // Standard interface size_t size() const; - size_t at(size_t key) const; + size_t at(size_t i) const; void push_back(size_t key); // enabling serialization functionality diff --git a/gtsam/symbolic/SymbolicJunctionTree.h b/gtsam/symbolic/SymbolicJunctionTree.h index f1168f962..f0bd09cf6 100644 --- a/gtsam/symbolic/SymbolicJunctionTree.h +++ b/gtsam/symbolic/SymbolicJunctionTree.h @@ -65,4 +65,6 @@ namespace gtsam { SymbolicJunctionTree(const SymbolicEliminationTree& eliminationTree); }; + /// typedef for wrapper: + using SymbolicCluster = SymbolicJunctionTree::Cluster; } diff --git a/gtsam/symbolic/symbolic.i b/gtsam/symbolic/symbolic.i index c05f35895..4da59dfa9 100644 --- a/gtsam/symbolic/symbolic.i +++ b/gtsam/symbolic/symbolic.i @@ -137,6 +137,43 @@ class SymbolicBayesNet { const gtsam::DotWriter& writer = gtsam::DotWriter()) const; }; +#include + +class SymbolicEliminationTree { + SymbolicEliminationTree(const gtsam::SymbolicFactorGraph& factorGraph, + const gtsam::VariableIndex& structure, + const gtsam::Ordering& order); + + SymbolicEliminationTree(const gtsam::SymbolicFactorGraph& factorGraph, + const gtsam::Ordering& order); + + void print( + string name = "EliminationTree: ", + const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::SymbolicEliminationTree& other, + double tol = 1e-9) const; +}; + +#include + +class SymbolicCluster { + gtsam::Ordering orderedFrontalKeys; + gtsam::SymbolicFactorGraph factors; + const gtsam::SymbolicCluster& operator[](size_t i) const; + size_t nrChildren() const; + void print(string s = "", const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; +}; + +class SymbolicJunctionTree { + SymbolicJunctionTree(const gtsam::SymbolicEliminationTree& eliminationTree); + void print( + string name = "JunctionTree: ", + const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; + size_t nrRoots() const; + const gtsam::SymbolicCluster& operator[](size_t i) const; +}; + #include class SymbolicBayesTreeClique {