diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 74eb3ddb3..9a0520a34 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -56,6 +56,16 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys, sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); } +/* ************************************************************************ */ +TableFactor::TableFactor(const DiscreteKeys& dkeys, + const DecisionTreeFactor& dtf) + : TableFactor(dkeys, dtf.probabilities()) {} + +/* ************************************************************************ */ +TableFactor::TableFactor(const DiscreteKeys& dkeys, + const DecisionTree& dtree) + : TableFactor(dkeys, DecisionTreeFactor(dkeys, dtree)) {} + /* ************************************************************************ */ TableFactor::TableFactor(const DiscreteConditional& c) : TableFactor(c.discreteKeys(), c.probabilities()) {} diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index bd637bb7d..103e14bff 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -141,6 +141,12 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { TableFactor(const DiscreteKey& key, const std::vector& row) : TableFactor(DiscreteKeys{key}, row) {} + /// Constructor from DecisionTreeFactor + TableFactor(const DiscreteKeys& keys, const DecisionTreeFactor& dtf); + + /// Constructor from DecisionTree/AlgebraicDecisionTree + TableFactor(const DiscreteKeys& keys, const DecisionTree& dtree); + /** Construct from a DiscreteConditional type */ explicit TableFactor(const DiscreteConditional& c); @@ -177,7 +183,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { return apply(f, Ring::mul); }; - /// multiple with DecisionTreeFactor + /// multiply with DecisionTreeFactor DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; static double safe_div(const double& a, const double& b);