diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 776d4bd90..f68d2ae00 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -195,7 +195,7 @@ namespace gtsam { // Construct unordered_map with values std::vector> result; for (const auto& assignment : assignments) { - result.emplace_back(assignment, operator()(assignment)); + result.emplace_back(assignment, evaluate(assignment)); } return result; } diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 9172180b1..e1eb2a453 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -137,10 +137,13 @@ namespace gtsam { /// Calculate probability for given values, /// is just look up in AlgebraicDecisionTree. - double operator()(const Assignment& values) const override { + virtual double evaluate(const Assignment& values) const override { return ADT::operator()(values); } + /// Disambiguate to use DiscreteFactor version. Mainly for wrapper + using DiscreteFactor::operator(); + /// Calculate error for DiscreteValues `x`, is -log(probability). double error(const DiscreteValues& values) const override; diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index f59e29285..858623301 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -169,12 +169,12 @@ class GTSAM_EXPORT DiscreteConditional } /// Evaluate, just look up in AlgebraicDecisionTree - double evaluate(const DiscreteValues& values) const { + virtual double evaluate(const Assignment& values) const override { return ADT::operator()(values); } using DecisionTreeFactor::error; ///< DiscreteValues version - using DecisionTreeFactor::operator(); ///< DiscreteValues version + using DiscreteFactor::operator(); ///< DiscreteValues version /** * @brief restrict to given *parent* values. diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 7175f7608..fa1179c39 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -106,12 +106,12 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { * @param values Discrete assignment. * @return double */ - double evaluate(const Assignment& values) const { - return operator()(values); - } + virtual double evaluate(const Assignment& values) const = 0; /// Find value for given assignment of values to variables - virtual double operator()(const Assignment& values) const = 0; + double operator()(const DiscreteValues& values) const { + return evaluate(values); + } /// Error is just -log(value) virtual double error(const DiscreteValues& values) const; diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 32cba84ed..ea51a996c 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -133,7 +133,7 @@ bool TableFactor::equals(const DiscreteFactor& other, double tol) const { } /* ************************************************************************ */ -double TableFactor::operator()(const Assignment& values) const { +double TableFactor::evaluate(const Assignment& values) const { // a b c d => D * (C * (B * (a) + b) + c) + d uint64_t idx = 0, card = 1; for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) { diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index e7ed2294f..07564892f 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -152,7 +152,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { // /// @{ /// Evaluate probability distribution, is just look up in TableFactor. - double operator()(const Assignment& values) const override; + double evaluate(const Assignment& values) const override; /// Calculate error for DiscreteValues `x`, is -log(probability). double error(const DiscreteValues& values) const override; diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index d82ba9794..b2e2524f8 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -41,7 +41,7 @@ virtual class DiscreteFactor : gtsam::Factor { const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const; - double operator()(const gtsam::Assignment& values) const; + double operator()(const gtsam::DiscreteValues& values) const; }; #include @@ -61,15 +61,15 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor { DecisionTreeFactor(const std::vector& keys, string table); DecisionTreeFactor(const gtsam::DiscreteConditional& c); - + void print(string s = "DecisionTreeFactor\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; size_t cardinality(gtsam::Key j) const; - - double operator()(const gtsam::Assignment& values) const; + + double operator()(const gtsam::DiscreteValues& values) const; gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const; size_t cardinality(gtsam::Key j) const; gtsam::DecisionTreeFactor operator/(const gtsam::DecisionTreeFactor& f) const; @@ -248,6 +248,7 @@ class DiscreteBayesTree { void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; + double operator()(const gtsam::DiscreteValues& values) const; string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; diff --git a/gtsam_unstable/discrete/AllDiff.cpp b/gtsam_unstable/discrete/AllDiff.cpp index a450605b3..585ca8103 100644 --- a/gtsam_unstable/discrete/AllDiff.cpp +++ b/gtsam_unstable/discrete/AllDiff.cpp @@ -26,7 +26,7 @@ void AllDiff::print(const std::string& s, const KeyFormatter& formatter) const { } /* ************************************************************************* */ -double AllDiff::operator()(const Assignment& values) const { +double AllDiff::evaluate(const Assignment& values) const { std::set taken; // record values taken by keys for (Key dkey : keys_) { size_t value = values.at(dkey); // get the value for that key diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index 6fe43568e..1180abad4 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -45,7 +45,7 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { } /// Calculate value = expensive ! - double operator()(const Assignment& values) const override; + double evaluate(const Assignment& values) const override; /// Convert into a decisiontree, can be *very* expensive ! DecisionTreeFactor toDecisionTreeFactor() const override; diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index a55865c77..e96bfdfde 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -47,7 +47,7 @@ class BinaryAllDiff : public Constraint { } /// Calculate value - double operator()(const Assignment& values) const override { + double evaluate(const Assignment& values) const override { return (double)(values.at(keys_[0]) != values.at(keys_[1])); } diff --git a/gtsam_unstable/discrete/Domain.cpp b/gtsam_unstable/discrete/Domain.cpp index 752228c18..74f621dc7 100644 --- a/gtsam_unstable/discrete/Domain.cpp +++ b/gtsam_unstable/discrete/Domain.cpp @@ -30,7 +30,7 @@ string Domain::base1Str() const { } /* ************************************************************************* */ -double Domain::operator()(const Assignment& values) const { +double Domain::evaluate(const Assignment& values) const { return contains(values.at(key())); } diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index 13249d733..23a566d24 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -82,7 +82,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { bool contains(size_t value) const { return values_.count(value) > 0; } /// Calculate value - double operator()(const Assignment& values) const override; + double evaluate(const Assignment& values) const override; /// Convert into a decisiontree DecisionTreeFactor toDecisionTreeFactor() const override; diff --git a/gtsam_unstable/discrete/SingleValue.cpp b/gtsam_unstable/discrete/SingleValue.cpp index 9762aec0f..220bc9c06 100644 --- a/gtsam_unstable/discrete/SingleValue.cpp +++ b/gtsam_unstable/discrete/SingleValue.cpp @@ -22,7 +22,7 @@ void SingleValue::print(const string& s, const KeyFormatter& formatter) const { } /* ************************************************************************* */ -double SingleValue::operator()(const Assignment& values) const { +double SingleValue::evaluate(const Assignment& values) const { return (double)(values.at(keys_[0]) == value_); } diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h index 93fe38aaa..3df1209b8 100644 --- a/gtsam_unstable/discrete/SingleValue.h +++ b/gtsam_unstable/discrete/SingleValue.h @@ -55,7 +55,7 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { } /// Calculate value - double operator()(const Assignment& values) const override; + double evaluate(const Assignment& values) const override; /// Convert into a decisiontree DecisionTreeFactor toDecisionTreeFactor() const override; diff --git a/python/gtsam/tests/test_DecisionTreeFactor.py b/python/gtsam/tests/test_DecisionTreeFactor.py index 12308bb3c..a78d9c94a 100644 --- a/python/gtsam/tests/test_DecisionTreeFactor.py +++ b/python/gtsam/tests/test_DecisionTreeFactor.py @@ -13,9 +13,10 @@ Author: Frank Dellaert import unittest +from gtsam.utils.test_case import GtsamTestCase + from gtsam import (DecisionTreeFactor, DiscreteDistribution, DiscreteValues, Ordering) -from gtsam.utils.test_case import GtsamTestCase class TestDecisionTreeFactor(GtsamTestCase): diff --git a/python/gtsam/tests/test_DiscreteBayesTree.py b/python/gtsam/tests/test_DiscreteBayesTree.py index 2a9b6ea09..e08491fab 100644 --- a/python/gtsam/tests/test_DiscreteBayesTree.py +++ b/python/gtsam/tests/test_DiscreteBayesTree.py @@ -19,8 +19,8 @@ from gtsam.utils.test_case import GtsamTestCase import gtsam from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique, - DiscreteConditional, DiscreteFactorGraph, - DiscreteValues, Ordering) + DiscreteConditional, DiscreteFactorGraph, DiscreteValues, + Ordering) class TestDiscreteBayesNet(GtsamTestCase): diff --git a/python/gtsam/tests/test_DiscreteConditional.py b/python/gtsam/tests/test_DiscreteConditional.py index 241a5f0be..6c9eb9aec 100644 --- a/python/gtsam/tests/test_DiscreteConditional.py +++ b/python/gtsam/tests/test_DiscreteConditional.py @@ -13,9 +13,10 @@ Author: Varun Agrawal import unittest -from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys from gtsam.utils.test_case import GtsamTestCase +from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys + # Some DiscreteKeys for binary variables: A = 0, 2 B = 1, 2 diff --git a/python/gtsam/tests/test_DiscreteFactorGraph.py b/python/gtsam/tests/test_DiscreteFactorGraph.py index d725ceac8..3053087b4 100644 --- a/python/gtsam/tests/test_DiscreteFactorGraph.py +++ b/python/gtsam/tests/test_DiscreteFactorGraph.py @@ -14,9 +14,12 @@ Author: Frank Dellaert import unittest import numpy as np -from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, Symbol from gtsam.utils.test_case import GtsamTestCase +from gtsam import (DecisionTreeFactor, DiscreteConditional, + DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, + Symbol) + OrderingType = Ordering.OrderingType