From f8d9da1b8edeebdba2acdfc2100da85d6b8ed2e0 Mon Sep 17 00:00:00 2001 From: Joel Truher Date: Mon, 9 Dec 2024 06:57:35 -0800 Subject: [PATCH 1/7] actually fix CHECK_EQUAL --- CppUnitLite/Test.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CppUnitLite/Test.h b/CppUnitLite/Test.h index 9ee0d10eb..c3fa83234 100644 --- a/CppUnitLite/Test.h +++ b/CppUnitLite/Test.h @@ -129,7 +129,7 @@ protected: result_.addFailure(Failure(name_, __FILE__, __LINE__, #expected, #actual)); } #define CHECK_EQUAL(expected,actual)\ -{ if (!((expected) == (actual))) { result_.addFailure(Failure(name_, __FILE__, __LINE__, std::to_string(expected), std::to_string(actual))); } } +{ if (!((expected) == (actual))) { result_.addFailure(Failure(name_, __FILE__, __LINE__, std::to_string(expected), std::to_string(actual))); return; } } #define LONGS_EQUAL(expected,actual)\ { long actualTemp = actual; \ From 88b36da602c9697ca86db5d42921fa6d058cc1af Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 9 Dec 2024 16:55:15 -0500 Subject: [PATCH 2/7] make evaluate a common base method --- gtsam/discrete/DecisionTreeFactor.h | 9 ++------- gtsam/discrete/DiscreteFactor.h | 15 ++++++++++++++- gtsam/discrete/TableFactor.cpp | 3 ++- gtsam/discrete/TableFactor.h | 10 ++-------- gtsam_unstable/discrete/AllDiff.cpp | 2 +- gtsam_unstable/discrete/AllDiff.h | 2 +- gtsam_unstable/discrete/BinaryAllDiff.h | 2 +- gtsam_unstable/discrete/Domain.cpp | 2 +- gtsam_unstable/discrete/Domain.h | 2 +- gtsam_unstable/discrete/SingleValue.cpp | 2 +- gtsam_unstable/discrete/SingleValue.h | 2 +- 11 files changed, 27 insertions(+), 24 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 9a3cde96d..8f3dd85a6 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -130,14 +130,9 @@ namespace gtsam { /// @name Standard Interface /// @{ - /// Calculate probability for given values `x`, + /// Calculate probability for given values, /// is just look up in AlgebraicDecisionTree. - double evaluate(const Assignment& values) const { - return ADT::operator()(values); - } - - /// Evaluate probability distribution, sugar. - double operator()(const DiscreteValues& values) const override { + double operator()(const Assignment& values) const override { return ADT::operator()(values); } diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 19af5bd13..5bc3c463f 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -92,8 +92,21 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { size_t cardinality(Key j) const { return cardinalities_.at(j); } + /** + * @brief Calculate probability for given values. + * Calls specialized evaluation under the hood. + * + * Note: Uses Assignment as it is the base class of DiscreteValues. + * + * @param values Discrete assignment. + * @return double + */ + double evaluate(const Assignment& values) const { + return operator()(values); + } + /// Find value for given assignment of values to variables - virtual double operator()(const DiscreteValues&) const = 0; + virtual double operator()(const Assignment& values) const = 0; /// Error is just -log(value) virtual double error(const DiscreteValues& values) const; diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index f4e023a4d..32cba84ed 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -133,7 +133,7 @@ bool TableFactor::equals(const DiscreteFactor& other, double tol) const { } /* ************************************************************************ */ -double TableFactor::operator()(const DiscreteValues& values) const { +double TableFactor::operator()(const Assignment& values) const { // a b c d => D * (C * (B * (a) + b) + c) + d uint64_t idx = 0, card = 1; for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) { @@ -180,6 +180,7 @@ DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { for (auto i = 0; i < sparse_table_.size(); i++) { table.push_back(sparse_table_.coeff(i)); } + // NOTE(Varun): This constructor is really expensive!! DecisionTreeFactor f(dkeys, table); return f; } diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 7015353e1..53495d078 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -155,14 +155,8 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { // /// @name Standard Interface // /// @{ - /// Calculate probability for given values `x`, - /// is just look up in TableFactor. - double evaluate(const DiscreteValues& values) const { - return operator()(values); - } - - /// Evaluate probability distribution, sugar. - double operator()(const DiscreteValues& values) const override; + /// Evaluate probability distribution, is just look up in TableFactor. + double operator()(const Assignment& values) const override; /// Calculate error for DiscreteValues `x`, is -log(probability). double error(const DiscreteValues& values) const override; diff --git a/gtsam_unstable/discrete/AllDiff.cpp b/gtsam_unstable/discrete/AllDiff.cpp index 2bd9e6dfd..a450605b3 100644 --- a/gtsam_unstable/discrete/AllDiff.cpp +++ b/gtsam_unstable/discrete/AllDiff.cpp @@ -26,7 +26,7 @@ void AllDiff::print(const std::string& s, const KeyFormatter& formatter) const { } /* ************************************************************************* */ -double AllDiff::operator()(const DiscreteValues& values) const { +double AllDiff::operator()(const Assignment& values) const { std::set taken; // record values taken by keys for (Key dkey : keys_) { size_t value = values.at(dkey); // get the value for that key diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index d7a63eae0..6fe43568e 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -45,7 +45,7 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { } /// Calculate value = expensive ! - double operator()(const DiscreteValues& values) const override; + double operator()(const Assignment& values) const override; /// Convert into a decisiontree, can be *very* expensive ! DecisionTreeFactor toDecisionTreeFactor() const override; diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index 18b335092..a55865c77 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -47,7 +47,7 @@ class BinaryAllDiff : public Constraint { } /// Calculate value - double operator()(const DiscreteValues& values) const override { + double operator()(const Assignment& values) const override { return (double)(values.at(keys_[0]) != values.at(keys_[1])); } diff --git a/gtsam_unstable/discrete/Domain.cpp b/gtsam_unstable/discrete/Domain.cpp index bbbc87667..752228c18 100644 --- a/gtsam_unstable/discrete/Domain.cpp +++ b/gtsam_unstable/discrete/Domain.cpp @@ -30,7 +30,7 @@ string Domain::base1Str() const { } /* ************************************************************************* */ -double Domain::operator()(const DiscreteValues& values) const { +double Domain::operator()(const Assignment& values) const { return contains(values.at(key())); } diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index 7f7b717c2..13249d733 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -82,7 +82,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { bool contains(size_t value) const { return values_.count(value) > 0; } /// Calculate value - double operator()(const DiscreteValues& values) const override; + double operator()(const Assignment& values) const override; /// Convert into a decisiontree DecisionTreeFactor toDecisionTreeFactor() const override; diff --git a/gtsam_unstable/discrete/SingleValue.cpp b/gtsam_unstable/discrete/SingleValue.cpp index 6b78f38f5..9762aec0f 100644 --- a/gtsam_unstable/discrete/SingleValue.cpp +++ b/gtsam_unstable/discrete/SingleValue.cpp @@ -22,7 +22,7 @@ void SingleValue::print(const string& s, const KeyFormatter& formatter) const { } /* ************************************************************************* */ -double SingleValue::operator()(const DiscreteValues& values) const { +double SingleValue::operator()(const Assignment& values) const { return (double)(values.at(keys_[0]) == value_); } diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h index 3f7f22d6a..93fe38aaa 100644 --- a/gtsam_unstable/discrete/SingleValue.h +++ b/gtsam_unstable/discrete/SingleValue.h @@ -55,7 +55,7 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { } /// Calculate value - double operator()(const DiscreteValues& values) const override; + double operator()(const Assignment& values) const override; /// Convert into a decisiontree DecisionTreeFactor toDecisionTreeFactor() const override; From e0fedda712808b557ef229b2ec4cda50f25fb743 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 9 Dec 2024 17:22:51 -0500 Subject: [PATCH 3/7] make the Unary and Binary ops common --- gtsam/discrete/DecisionTreeFactor.cpp | 12 ++++++------ gtsam/discrete/DecisionTreeFactor.h | 15 ++++++++++----- gtsam/discrete/DiscreteFactor.h | 5 +++++ gtsam/discrete/TableFactor.h | 4 ---- 4 files changed, 21 insertions(+), 15 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 9ec3b0ac5..776d4bd90 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -83,7 +83,7 @@ namespace gtsam { } /* ************************************************************************ */ - DecisionTreeFactor DecisionTreeFactor::apply(ADT::Unary op) const { + DecisionTreeFactor DecisionTreeFactor::apply(Unary op) const { // apply operand ADT result = ADT::apply(op); // Make a new factor @@ -91,7 +91,7 @@ namespace gtsam { } /* ************************************************************************ */ - DecisionTreeFactor DecisionTreeFactor::apply(ADT::UnaryAssignment op) const { + DecisionTreeFactor DecisionTreeFactor::apply(UnaryAssignment op) const { // apply operand ADT result = ADT::apply(op); // Make a new factor @@ -100,7 +100,7 @@ namespace gtsam { /* ************************************************************************ */ DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f, - ADT::Binary op) const { + Binary op) const { map cs; // new cardinalities // make unique key-cardinality map for (Key j : keys()) cs[j] = cardinality(j); @@ -118,8 +118,8 @@ namespace gtsam { } /* ************************************************************************ */ - DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine( - size_t nrFrontals, ADT::Binary op) const { + DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(size_t nrFrontals, + Binary op) const { if (nrFrontals > size()) { throw invalid_argument( "DecisionTreeFactor::combine: invalid number of frontal " @@ -146,7 +146,7 @@ namespace gtsam { /* ************************************************************************ */ DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine( - const Ordering& frontalKeys, ADT::Binary op) const { + const Ordering& frontalKeys, Binary op) const { if (frontalKeys.size() > size()) { throw invalid_argument( "DecisionTreeFactor::combine: invalid number of frontal " diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 8f3dd85a6..9172180b1 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -51,6 +51,11 @@ namespace gtsam { typedef std::shared_ptr shared_ptr; typedef AlgebraicDecisionTree ADT; + // Needed since we have definitions in both DiscreteFactor and DecisionTree + using Base::Binary; + using Base::Unary; + using Base::UnaryAssignment; + /// @name Standard Constructors /// @{ @@ -182,21 +187,21 @@ namespace gtsam { * Apply unary operator (*this) "op" f * @param op a unary operator that operates on AlgebraicDecisionTree */ - DecisionTreeFactor apply(ADT::Unary op) const; + DecisionTreeFactor apply(Unary op) const; /** * Apply unary operator (*this) "op" f * @param op a unary operator that operates on AlgebraicDecisionTree. Takes * both the assignment and the value. */ - DecisionTreeFactor apply(ADT::UnaryAssignment op) const; + DecisionTreeFactor apply(UnaryAssignment op) const; /** * Apply binary operator (*this) "op" f * @param f the second argument for op * @param op a binary operator that operates on AlgebraicDecisionTree */ - DecisionTreeFactor apply(const DecisionTreeFactor& f, ADT::Binary op) const; + DecisionTreeFactor apply(const DecisionTreeFactor& f, Binary op) const; /** * Combine frontal variables using binary operator "op" @@ -204,7 +209,7 @@ namespace gtsam { * @param op a binary operator that operates on AlgebraicDecisionTree * @return shared pointer to newly created DecisionTreeFactor */ - shared_ptr combine(size_t nrFrontals, ADT::Binary op) const; + shared_ptr combine(size_t nrFrontals, Binary op) const; /** * Combine frontal variables in an Ordering using binary operator "op" @@ -212,7 +217,7 @@ namespace gtsam { * @param op a binary operator that operates on AlgebraicDecisionTree * @return shared pointer to newly created DecisionTreeFactor */ - shared_ptr combine(const Ordering& keys, ADT::Binary op) const; + shared_ptr combine(const Ordering& keys, Binary op) const; /// Enumerate all values into a map from values to double. std::vector> enumerate() const; diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 5bc3c463f..7175f7608 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -46,6 +46,11 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { using Values = DiscreteValues; ///< backwards compatibility + using Unary = std::function; + using UnaryAssignment = + std::function&, const double&)>; + using Binary = std::function; + protected: /// Map of Keys and their cardinalities. std::map cardinalities_; diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 53495d078..e7ed2294f 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -94,10 +94,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { typedef std::shared_ptr shared_ptr; typedef Eigen::SparseVector::InnerIterator SparseIt; typedef std::vector> AssignValList; - using Unary = std::function; - using UnaryAssignment = - std::function&, const double&)>; - using Binary = std::function; public: /// @name Standard Constructors From 1152470800d2f7845343b20ce592e32837564b6a Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 9 Dec 2024 17:37:11 -0500 Subject: [PATCH 4/7] fix wrapper --- gtsam/discrete/discrete.i | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 55c8f9e22..d82ba9794 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -41,7 +41,7 @@ virtual class DiscreteFactor : gtsam::Factor { const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const; - double operator()(const gtsam::DiscreteValues& values) const; + double operator()(const gtsam::Assignment& values) const; }; #include @@ -69,7 +69,7 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor { size_t cardinality(gtsam::Key j) const; - double operator()(const gtsam::DiscreteValues& values) const; + double operator()(const gtsam::Assignment& values) const; gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const; size_t cardinality(gtsam::Key j) const; gtsam::DecisionTreeFactor operator/(const gtsam::DecisionTreeFactor& f) const; @@ -248,7 +248,6 @@ class DiscreteBayesTree { void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; - double operator()(const gtsam::DiscreteValues& values) const; string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; From ab943b539ec3f7c186aa40be2579407fa320278d Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 9 Dec 2024 18:21:26 -0500 Subject: [PATCH 5/7] Revert "fix wrapper" This reverts commit 1152470800d2f7845343b20ce592e32837564b6a. --- gtsam/discrete/discrete.i | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index d82ba9794..55c8f9e22 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -41,7 +41,7 @@ virtual class DiscreteFactor : gtsam::Factor { const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const; - double operator()(const gtsam::Assignment& values) const; + double operator()(const gtsam::DiscreteValues& values) const; }; #include @@ -69,7 +69,7 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor { size_t cardinality(gtsam::Key j) const; - double operator()(const gtsam::Assignment& values) const; + double operator()(const gtsam::DiscreteValues& values) const; gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const; size_t cardinality(gtsam::Key j) const; gtsam::DecisionTreeFactor operator/(const gtsam::DecisionTreeFactor& f) const; @@ -248,6 +248,7 @@ class DiscreteBayesTree { void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; + double operator()(const gtsam::DiscreteValues& values) const; string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; From a98ac0fdb232e8bf50b48df7008565010bfd0a4f Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 9 Dec 2024 21:09:00 -0500 Subject: [PATCH 6/7] make evaluate the overriden method --- gtsam/discrete/DecisionTreeFactor.cpp | 2 +- gtsam/discrete/DecisionTreeFactor.h | 5 ++++- gtsam/discrete/DiscreteConditional.h | 4 ++-- gtsam/discrete/DiscreteFactor.h | 8 ++++---- gtsam/discrete/TableFactor.cpp | 2 +- gtsam/discrete/TableFactor.h | 2 +- gtsam_unstable/discrete/AllDiff.cpp | 2 +- gtsam_unstable/discrete/AllDiff.h | 2 +- gtsam_unstable/discrete/BinaryAllDiff.h | 2 +- gtsam_unstable/discrete/Domain.cpp | 2 +- gtsam_unstable/discrete/Domain.h | 2 +- gtsam_unstable/discrete/SingleValue.cpp | 2 +- gtsam_unstable/discrete/SingleValue.h | 2 +- 13 files changed, 20 insertions(+), 17 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 9ec3b0ac5..a57915e45 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -195,7 +195,7 @@ namespace gtsam { // Construct unordered_map with values std::vector> result; for (const auto& assignment : assignments) { - result.emplace_back(assignment, operator()(assignment)); + result.emplace_back(assignment, evaluate(assignment)); } return result; } diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 8f3dd85a6..741715e43 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -132,10 +132,13 @@ namespace gtsam { /// Calculate probability for given values, /// is just look up in AlgebraicDecisionTree. - double operator()(const Assignment& values) const override { + virtual double evaluate(const Assignment& values) const override { return ADT::operator()(values); } + /// Disambiguate to use DiscreteFactor version. Mainly for wrapper + using DiscreteFactor::operator(); + /// Calculate error for DiscreteValues `x`, is -log(probability). double error(const DiscreteValues& values) const override; diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index f59e29285..858623301 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -169,12 +169,12 @@ class GTSAM_EXPORT DiscreteConditional } /// Evaluate, just look up in AlgebraicDecisionTree - double evaluate(const DiscreteValues& values) const { + virtual double evaluate(const Assignment& values) const override { return ADT::operator()(values); } using DecisionTreeFactor::error; ///< DiscreteValues version - using DecisionTreeFactor::operator(); ///< DiscreteValues version + using DiscreteFactor::operator(); ///< DiscreteValues version /** * @brief restrict to given *parent* values. diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 5bc3c463f..2ba670004 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -101,12 +101,12 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { * @param values Discrete assignment. * @return double */ - double evaluate(const Assignment& values) const { - return operator()(values); - } + virtual double evaluate(const Assignment& values) const = 0; /// Find value for given assignment of values to variables - virtual double operator()(const Assignment& values) const = 0; + double operator()(const DiscreteValues& values) const { + return evaluate(values); + } /// Error is just -log(value) virtual double error(const DiscreteValues& values) const; diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 32cba84ed..ea51a996c 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -133,7 +133,7 @@ bool TableFactor::equals(const DiscreteFactor& other, double tol) const { } /* ************************************************************************ */ -double TableFactor::operator()(const Assignment& values) const { +double TableFactor::evaluate(const Assignment& values) const { // a b c d => D * (C * (B * (a) + b) + c) + d uint64_t idx = 0, card = 1; for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) { diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 53495d078..1aecc1669 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -156,7 +156,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { // /// @{ /// Evaluate probability distribution, is just look up in TableFactor. - double operator()(const Assignment& values) const override; + double evaluate(const Assignment& values) const override; /// Calculate error for DiscreteValues `x`, is -log(probability). double error(const DiscreteValues& values) const override; diff --git a/gtsam_unstable/discrete/AllDiff.cpp b/gtsam_unstable/discrete/AllDiff.cpp index a450605b3..585ca8103 100644 --- a/gtsam_unstable/discrete/AllDiff.cpp +++ b/gtsam_unstable/discrete/AllDiff.cpp @@ -26,7 +26,7 @@ void AllDiff::print(const std::string& s, const KeyFormatter& formatter) const { } /* ************************************************************************* */ -double AllDiff::operator()(const Assignment& values) const { +double AllDiff::evaluate(const Assignment& values) const { std::set taken; // record values taken by keys for (Key dkey : keys_) { size_t value = values.at(dkey); // get the value for that key diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index 6fe43568e..1180abad4 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -45,7 +45,7 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { } /// Calculate value = expensive ! - double operator()(const Assignment& values) const override; + double evaluate(const Assignment& values) const override; /// Convert into a decisiontree, can be *very* expensive ! DecisionTreeFactor toDecisionTreeFactor() const override; diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index a55865c77..e96bfdfde 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -47,7 +47,7 @@ class BinaryAllDiff : public Constraint { } /// Calculate value - double operator()(const Assignment& values) const override { + double evaluate(const Assignment& values) const override { return (double)(values.at(keys_[0]) != values.at(keys_[1])); } diff --git a/gtsam_unstable/discrete/Domain.cpp b/gtsam_unstable/discrete/Domain.cpp index 752228c18..74f621dc7 100644 --- a/gtsam_unstable/discrete/Domain.cpp +++ b/gtsam_unstable/discrete/Domain.cpp @@ -30,7 +30,7 @@ string Domain::base1Str() const { } /* ************************************************************************* */ -double Domain::operator()(const Assignment& values) const { +double Domain::evaluate(const Assignment& values) const { return contains(values.at(key())); } diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index 13249d733..23a566d24 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -82,7 +82,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { bool contains(size_t value) const { return values_.count(value) > 0; } /// Calculate value - double operator()(const Assignment& values) const override; + double evaluate(const Assignment& values) const override; /// Convert into a decisiontree DecisionTreeFactor toDecisionTreeFactor() const override; diff --git a/gtsam_unstable/discrete/SingleValue.cpp b/gtsam_unstable/discrete/SingleValue.cpp index 9762aec0f..220bc9c06 100644 --- a/gtsam_unstable/discrete/SingleValue.cpp +++ b/gtsam_unstable/discrete/SingleValue.cpp @@ -22,7 +22,7 @@ void SingleValue::print(const string& s, const KeyFormatter& formatter) const { } /* ************************************************************************* */ -double SingleValue::operator()(const Assignment& values) const { +double SingleValue::evaluate(const Assignment& values) const { return (double)(values.at(keys_[0]) == value_); } diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h index 93fe38aaa..3df1209b8 100644 --- a/gtsam_unstable/discrete/SingleValue.h +++ b/gtsam_unstable/discrete/SingleValue.h @@ -55,7 +55,7 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { } /// Calculate value - double operator()(const Assignment& values) const override; + double evaluate(const Assignment& values) const override; /// Convert into a decisiontree DecisionTreeFactor toDecisionTreeFactor() const override; From 8145086e5b57315bc94491def0bcdc67aa48148d Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 9 Dec 2024 21:09:37 -0500 Subject: [PATCH 7/7] formatting --- gtsam/discrete/discrete.i | 4 ++-- python/gtsam/tests/test_DecisionTreeFactor.py | 3 ++- python/gtsam/tests/test_DiscreteBayesTree.py | 4 ++-- python/gtsam/tests/test_DiscreteConditional.py | 3 ++- python/gtsam/tests/test_DiscreteFactorGraph.py | 5 ++++- 5 files changed, 12 insertions(+), 7 deletions(-) diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 55c8f9e22..b2e2524f8 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -61,14 +61,14 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor { DecisionTreeFactor(const std::vector& keys, string table); DecisionTreeFactor(const gtsam::DiscreteConditional& c); - + void print(string s = "DecisionTreeFactor\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; size_t cardinality(gtsam::Key j) const; - + double operator()(const gtsam::DiscreteValues& values) const; gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const; size_t cardinality(gtsam::Key j) const; diff --git a/python/gtsam/tests/test_DecisionTreeFactor.py b/python/gtsam/tests/test_DecisionTreeFactor.py index 12308bb3c..a78d9c94a 100644 --- a/python/gtsam/tests/test_DecisionTreeFactor.py +++ b/python/gtsam/tests/test_DecisionTreeFactor.py @@ -13,9 +13,10 @@ Author: Frank Dellaert import unittest +from gtsam.utils.test_case import GtsamTestCase + from gtsam import (DecisionTreeFactor, DiscreteDistribution, DiscreteValues, Ordering) -from gtsam.utils.test_case import GtsamTestCase class TestDecisionTreeFactor(GtsamTestCase): diff --git a/python/gtsam/tests/test_DiscreteBayesTree.py b/python/gtsam/tests/test_DiscreteBayesTree.py index 2a9b6ea09..e08491fab 100644 --- a/python/gtsam/tests/test_DiscreteBayesTree.py +++ b/python/gtsam/tests/test_DiscreteBayesTree.py @@ -19,8 +19,8 @@ from gtsam.utils.test_case import GtsamTestCase import gtsam from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique, - DiscreteConditional, DiscreteFactorGraph, - DiscreteValues, Ordering) + DiscreteConditional, DiscreteFactorGraph, DiscreteValues, + Ordering) class TestDiscreteBayesNet(GtsamTestCase): diff --git a/python/gtsam/tests/test_DiscreteConditional.py b/python/gtsam/tests/test_DiscreteConditional.py index 241a5f0be..6c9eb9aec 100644 --- a/python/gtsam/tests/test_DiscreteConditional.py +++ b/python/gtsam/tests/test_DiscreteConditional.py @@ -13,9 +13,10 @@ Author: Varun Agrawal import unittest -from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys from gtsam.utils.test_case import GtsamTestCase +from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys + # Some DiscreteKeys for binary variables: A = 0, 2 B = 1, 2 diff --git a/python/gtsam/tests/test_DiscreteFactorGraph.py b/python/gtsam/tests/test_DiscreteFactorGraph.py index d725ceac8..3053087b4 100644 --- a/python/gtsam/tests/test_DiscreteFactorGraph.py +++ b/python/gtsam/tests/test_DiscreteFactorGraph.py @@ -14,9 +14,12 @@ Author: Frank Dellaert import unittest import numpy as np -from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, Symbol from gtsam.utils.test_case import GtsamTestCase +from gtsam import (DecisionTreeFactor, DiscreteConditional, + DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, + Symbol) + OrderingType = Ordering.OrderingType