Merge pull request #1926 from borglab/common-ops

release/4.3a0
Varun Agrawal 2024-12-10 10:18:29 -05:00 committed by GitHub
commit 77ba91bf52
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 21 additions and 15 deletions

View File

@ -83,7 +83,7 @@ namespace gtsam {
} }
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(ADT::Unary op) const { DecisionTreeFactor DecisionTreeFactor::apply(Unary op) const {
// apply operand // apply operand
ADT result = ADT::apply(op); ADT result = ADT::apply(op);
// Make a new factor // Make a new factor
@ -91,7 +91,7 @@ namespace gtsam {
} }
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(ADT::UnaryAssignment op) const { DecisionTreeFactor DecisionTreeFactor::apply(UnaryAssignment op) const {
// apply operand // apply operand
ADT result = ADT::apply(op); ADT result = ADT::apply(op);
// Make a new factor // Make a new factor
@ -100,7 +100,7 @@ namespace gtsam {
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f, DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f,
ADT::Binary op) const { Binary op) const {
map<Key, size_t> cs; // new cardinalities map<Key, size_t> cs; // new cardinalities
// make unique key-cardinality map // make unique key-cardinality map
for (Key j : keys()) cs[j] = cardinality(j); for (Key j : keys()) cs[j] = cardinality(j);
@ -118,8 +118,8 @@ namespace gtsam {
} }
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine( DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(size_t nrFrontals,
size_t nrFrontals, ADT::Binary op) const { Binary op) const {
if (nrFrontals > size()) { if (nrFrontals > size()) {
throw invalid_argument( throw invalid_argument(
"DecisionTreeFactor::combine: invalid number of frontal " "DecisionTreeFactor::combine: invalid number of frontal "
@ -146,7 +146,7 @@ namespace gtsam {
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine( DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
const Ordering& frontalKeys, ADT::Binary op) const { const Ordering& frontalKeys, Binary op) const {
if (frontalKeys.size() > size()) { if (frontalKeys.size() > size()) {
throw invalid_argument( throw invalid_argument(
"DecisionTreeFactor::combine: invalid number of frontal " "DecisionTreeFactor::combine: invalid number of frontal "

View File

@ -51,6 +51,11 @@ namespace gtsam {
typedef std::shared_ptr<DecisionTreeFactor> shared_ptr; typedef std::shared_ptr<DecisionTreeFactor> shared_ptr;
typedef AlgebraicDecisionTree<Key> ADT; typedef AlgebraicDecisionTree<Key> ADT;
// Needed since we have definitions in both DiscreteFactor and DecisionTree
using Base::Binary;
using Base::Unary;
using Base::UnaryAssignment;
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{
@ -185,21 +190,21 @@ namespace gtsam {
* Apply unary operator (*this) "op" f * Apply unary operator (*this) "op" f
* @param op a unary operator that operates on AlgebraicDecisionTree * @param op a unary operator that operates on AlgebraicDecisionTree
*/ */
DecisionTreeFactor apply(ADT::Unary op) const; DecisionTreeFactor apply(Unary op) const;
/** /**
* Apply unary operator (*this) "op" f * Apply unary operator (*this) "op" f
* @param op a unary operator that operates on AlgebraicDecisionTree. Takes * @param op a unary operator that operates on AlgebraicDecisionTree. Takes
* both the assignment and the value. * both the assignment and the value.
*/ */
DecisionTreeFactor apply(ADT::UnaryAssignment op) const; DecisionTreeFactor apply(UnaryAssignment op) const;
/** /**
* Apply binary operator (*this) "op" f * Apply binary operator (*this) "op" f
* @param f the second argument for op * @param f the second argument for op
* @param op a binary operator that operates on AlgebraicDecisionTree * @param op a binary operator that operates on AlgebraicDecisionTree
*/ */
DecisionTreeFactor apply(const DecisionTreeFactor& f, ADT::Binary op) const; DecisionTreeFactor apply(const DecisionTreeFactor& f, Binary op) const;
/** /**
* Combine frontal variables using binary operator "op" * Combine frontal variables using binary operator "op"
@ -207,7 +212,7 @@ namespace gtsam {
* @param op a binary operator that operates on AlgebraicDecisionTree * @param op a binary operator that operates on AlgebraicDecisionTree
* @return shared pointer to newly created DecisionTreeFactor * @return shared pointer to newly created DecisionTreeFactor
*/ */
shared_ptr combine(size_t nrFrontals, ADT::Binary op) const; shared_ptr combine(size_t nrFrontals, Binary op) const;
/** /**
* Combine frontal variables in an Ordering using binary operator "op" * Combine frontal variables in an Ordering using binary operator "op"
@ -215,7 +220,7 @@ namespace gtsam {
* @param op a binary operator that operates on AlgebraicDecisionTree * @param op a binary operator that operates on AlgebraicDecisionTree
* @return shared pointer to newly created DecisionTreeFactor * @return shared pointer to newly created DecisionTreeFactor
*/ */
shared_ptr combine(const Ordering& keys, ADT::Binary op) const; shared_ptr combine(const Ordering& keys, Binary op) const;
/// Enumerate all values into a map from values to double. /// Enumerate all values into a map from values to double.
std::vector<std::pair<DiscreteValues, double>> enumerate() const; std::vector<std::pair<DiscreteValues, double>> enumerate() const;

View File

@ -46,6 +46,11 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
using Values = DiscreteValues; ///< backwards compatibility using Values = DiscreteValues; ///< backwards compatibility
using Unary = std::function<double(const double&)>;
using UnaryAssignment =
std::function<double(const Assignment<Key>&, const double&)>;
using Binary = std::function<double(const double, const double)>;
protected: protected:
/// Map of Keys and their cardinalities. /// Map of Keys and their cardinalities.
std::map<Key, size_t> cardinalities_; std::map<Key, size_t> cardinalities_;

View File

@ -94,10 +94,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
typedef std::shared_ptr<TableFactor> shared_ptr; typedef std::shared_ptr<TableFactor> shared_ptr;
typedef Eigen::SparseVector<double>::InnerIterator SparseIt; typedef Eigen::SparseVector<double>::InnerIterator SparseIt;
typedef std::vector<std::pair<DiscreteValues, double>> AssignValList; typedef std::vector<std::pair<DiscreteValues, double>> AssignValList;
using Unary = std::function<double(const double&)>;
using UnaryAssignment =
std::function<double(const Assignment<Key>&, const double&)>;
using Binary = std::function<double(const double, const double)>;
public: public:
/// @name Standard Constructors /// @name Standard Constructors