diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 9a3cde96d..8f3dd85a6 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -130,14 +130,9 @@ namespace gtsam { /// @name Standard Interface /// @{ - /// Calculate probability for given values `x`, + /// Calculate probability for given values, /// is just look up in AlgebraicDecisionTree. - double evaluate(const Assignment& values) const { - return ADT::operator()(values); - } - - /// Evaluate probability distribution, sugar. - double operator()(const DiscreteValues& values) const override { + double operator()(const Assignment& values) const override { return ADT::operator()(values); } diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 19af5bd13..5bc3c463f 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -92,8 +92,21 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { size_t cardinality(Key j) const { return cardinalities_.at(j); } + /** + * @brief Calculate probability for given values. + * Calls specialized evaluation under the hood. + * + * Note: Uses Assignment as it is the base class of DiscreteValues. + * + * @param values Discrete assignment. + * @return double + */ + double evaluate(const Assignment& values) const { + return operator()(values); + } + /// Find value for given assignment of values to variables - virtual double operator()(const DiscreteValues&) const = 0; + virtual double operator()(const Assignment& values) const = 0; /// Error is just -log(value) virtual double error(const DiscreteValues& values) const; diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index f4e023a4d..32cba84ed 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -133,7 +133,7 @@ bool TableFactor::equals(const DiscreteFactor& other, double tol) const { } /* ************************************************************************ */ -double TableFactor::operator()(const DiscreteValues& values) const { +double TableFactor::operator()(const Assignment& values) const { // a b c d => D * (C * (B * (a) + b) + c) + d uint64_t idx = 0, card = 1; for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) { @@ -180,6 +180,7 @@ DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { for (auto i = 0; i < sparse_table_.size(); i++) { table.push_back(sparse_table_.coeff(i)); } + // NOTE(Varun): This constructor is really expensive!! DecisionTreeFactor f(dkeys, table); return f; } diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 7015353e1..53495d078 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -155,14 +155,8 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { // /// @name Standard Interface // /// @{ - /// Calculate probability for given values `x`, - /// is just look up in TableFactor. - double evaluate(const DiscreteValues& values) const { - return operator()(values); - } - - /// Evaluate probability distribution, sugar. - double operator()(const DiscreteValues& values) const override; + /// Evaluate probability distribution, is just look up in TableFactor. + double operator()(const Assignment& values) const override; /// Calculate error for DiscreteValues `x`, is -log(probability). double error(const DiscreteValues& values) const override; diff --git a/gtsam_unstable/discrete/AllDiff.cpp b/gtsam_unstable/discrete/AllDiff.cpp index 2bd9e6dfd..a450605b3 100644 --- a/gtsam_unstable/discrete/AllDiff.cpp +++ b/gtsam_unstable/discrete/AllDiff.cpp @@ -26,7 +26,7 @@ void AllDiff::print(const std::string& s, const KeyFormatter& formatter) const { } /* ************************************************************************* */ -double AllDiff::operator()(const DiscreteValues& values) const { +double AllDiff::operator()(const Assignment& values) const { std::set taken; // record values taken by keys for (Key dkey : keys_) { size_t value = values.at(dkey); // get the value for that key diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index d7a63eae0..6fe43568e 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -45,7 +45,7 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { } /// Calculate value = expensive ! - double operator()(const DiscreteValues& values) const override; + double operator()(const Assignment& values) const override; /// Convert into a decisiontree, can be *very* expensive ! DecisionTreeFactor toDecisionTreeFactor() const override; diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index 18b335092..a55865c77 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -47,7 +47,7 @@ class BinaryAllDiff : public Constraint { } /// Calculate value - double operator()(const DiscreteValues& values) const override { + double operator()(const Assignment& values) const override { return (double)(values.at(keys_[0]) != values.at(keys_[1])); } diff --git a/gtsam_unstable/discrete/Domain.cpp b/gtsam_unstable/discrete/Domain.cpp index bbbc87667..752228c18 100644 --- a/gtsam_unstable/discrete/Domain.cpp +++ b/gtsam_unstable/discrete/Domain.cpp @@ -30,7 +30,7 @@ string Domain::base1Str() const { } /* ************************************************************************* */ -double Domain::operator()(const DiscreteValues& values) const { +double Domain::operator()(const Assignment& values) const { return contains(values.at(key())); } diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index 7f7b717c2..13249d733 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -82,7 +82,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { bool contains(size_t value) const { return values_.count(value) > 0; } /// Calculate value - double operator()(const DiscreteValues& values) const override; + double operator()(const Assignment& values) const override; /// Convert into a decisiontree DecisionTreeFactor toDecisionTreeFactor() const override; diff --git a/gtsam_unstable/discrete/SingleValue.cpp b/gtsam_unstable/discrete/SingleValue.cpp index 6b78f38f5..9762aec0f 100644 --- a/gtsam_unstable/discrete/SingleValue.cpp +++ b/gtsam_unstable/discrete/SingleValue.cpp @@ -22,7 +22,7 @@ void SingleValue::print(const string& s, const KeyFormatter& formatter) const { } /* ************************************************************************* */ -double SingleValue::operator()(const DiscreteValues& values) const { +double SingleValue::operator()(const Assignment& values) const { return (double)(values.at(keys_[0]) == value_); } diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h index 3f7f22d6a..93fe38aaa 100644 --- a/gtsam_unstable/discrete/SingleValue.h +++ b/gtsam_unstable/discrete/SingleValue.h @@ -55,7 +55,7 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { } /// Calculate value - double operator()(const DiscreteValues& values) const override; + double operator()(const Assignment& values) const override; /// Convert into a decisiontree DecisionTreeFactor toDecisionTreeFactor() const override;