From 9f88a360dfd54f1e00361344f24d9d2ef085b50f Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 7 Dec 2024 18:29:13 -0500 Subject: [PATCH] make evaluate use the Assignment base class --- gtsam/discrete/DecisionTreeFactor.h | 4 ++-- gtsam/discrete/DiscreteFactor.h | 3 +++ gtsam/discrete/TableFactor.cpp | 2 +- gtsam/discrete/TableFactor.h | 9 ++++++--- 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 642187ff1..bf0e23b50 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -129,9 +129,9 @@ namespace gtsam { /// @name Standard Interface /// @{ - /// Calculate probability for given values `x`, + /// Calculate probability for given values, /// is just look up in AlgebraicDecisionTree. - double evaluate(const Assignment& values) const { + double evaluate(const Assignment& values) const override { return ADT::operator()(values); } diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 29984d795..23abf725e 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -129,6 +129,9 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { virtual DiscreteFactor::shared_ptr operator/( const DiscreteFactor::shared_ptr& f) const = 0; + /// Calculate probability for given values + virtual double evaluate(const Assignment& values) const = 0; + /** * Get the number of non-zero values contained in this factor. * It could be much smaller than `prod_{key}(cardinality(key))`. diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index a8adf0918..727c96ce4 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -135,7 +135,7 @@ bool TableFactor::equals(const DiscreteFactor& other, double tol) const { } /* ************************************************************************ */ -double TableFactor::operator()(const DiscreteValues& values) const { +double TableFactor::operator()(const Assignment& values) const { // a b c d => D * (C * (B * (a) + b) + c) + d uint64_t idx = 0, card = 1; for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) { diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 12266b2a5..ea222ca5c 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -169,14 +169,17 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { // /// @name Standard Interface // /// @{ - /// Calculate probability for given values `x`, + /// Calculate probability for given values, /// is just look up in TableFactor. - double evaluate(const DiscreteValues& values) const { + double evaluate(const Assignment& values) const override { return operator()(values); } /// Evaluate probability distribution, sugar. - double operator()(const DiscreteValues& values) const override; + double operator()(const Assignment& values) const; + double operator()(const DiscreteValues& values) const override { + return operator()(Assignment(values)); + } /// Calculate error for DiscreteValues `x`, is -log(probability). double error(const DiscreteValues& values) const override;