From b7da0f483821a40feeb4ef1eef521726366009fa Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 10 Jul 2024 18:40:45 -0400 Subject: [PATCH 01/16] add DiscreteConditional constructor using table --- gtsam/discrete/discrete.i | 3 +++ 1 file changed, 3 insertions(+) diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index a1731f8e5..0deeb8033 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -104,6 +104,9 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { DiscreteConditional(const gtsam::DecisionTreeFactor& joint, const gtsam::DecisionTreeFactor& marginal, const gtsam::Ordering& orderedKeys); + DiscreteConditional(const gtsam::DiscreteKey& key, + const gtsam::DiscreteKeys& parents, + const std::vector& table); // Standard interface double logNormalizationConstant() const; From 0a7db4129098004e2263561250ba00f6380f7de2 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 10 Jul 2024 18:54:25 -0400 Subject: [PATCH 02/16] print method for DiscreteKeys --- gtsam/discrete/discrete.i | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 0deeb8033..5eacb3634 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -14,6 +14,9 @@ class DiscreteKeys { bool empty() const; gtsam::DiscreteKey at(size_t n) const; void push_back(const gtsam::DiscreteKey& point_pair); + void print(const std::string& s = "", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; }; // DiscreteValues is added in specializations/discrete.h as a std::map @@ -162,7 +165,6 @@ virtual class DiscreteDistribution : gtsam::DiscreteConditional { gtsam::DefaultKeyFormatter) const; double operator()(size_t value) const; std::vector pmf() const; - size_t argmax() const; }; #include From 89f7f7f72198b385c18fa6de73b9fcf70d0fc46b Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 10 Jul 2024 23:43:29 -0400 Subject: [PATCH 03/16] improve DiscreteConditional::argmax method to accept parent values --- gtsam/discrete/DiscreteConditional.cpp | 10 ++++---- gtsam/discrete/DiscreteConditional.h | 2 +- .../tests/testDiscreteConditional.cpp | 23 +++++++++++++++++++ 3 files changed, 29 insertions(+), 6 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 5abc094fb..a7f472f26 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -235,16 +235,16 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( } /* ************************************************************************** */ -size_t DiscreteConditional::argmax() const { +size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const { size_t maxValue = 0; double maxP = 0; + DiscreteValues values = parentsValues; + assert(nrFrontals() == 1); - assert(nrParents() == 0); - DiscreteValues frontals; Key j = firstFrontalKey(); for (size_t value = 0; value < cardinality(j); value++) { - frontals[j] = value; - double pValueS = (*this)(frontals); + values[j] = value; + double pValueS = (*this)(values); // Update MPE solution if better if (pValueS > maxP) { maxP = pValueS; diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 50fa6e161..8f38a83be 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -217,7 +217,7 @@ class GTSAM_EXPORT DiscreteConditional * @brief Return assignment that maximizes distribution. * @return Optimal assignment (1 frontal variable). */ - size_t argmax() const; + size_t argmax(const DiscreteValues& parentsValues = DiscreteValues()) const; /// @} /// @name Advanced Interface diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index f2c6d7b9f..a11c87975 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -289,6 +289,29 @@ TEST(DiscreteConditional, choose) { EXPECT(assert_equal(expected3, *actual3, 1e-9)); } +/* ************************************************************************* */ +// Check argmax on P(C|D) and P(D) +TEST(DiscreteConditional, Argmax) { + DiscreteKey C(2, 2), D(4, 2); + DiscreteConditional D_cond(D, "1/3"); + DiscreteConditional C_given_DE((C | D) = "1/4 1/1"); + + // Case 1: No parents + size_t actual1 = D_cond.argmax(); + EXPECT_LONGS_EQUAL(1, actual1); + + // Case 2: Given parent values + DiscreteValues given; + given[D.first] = 1; + size_t actual2 = C_given_DE.argmax(given); + // Should be 0 since D=1 gives 0.5/0.5 + EXPECT_LONGS_EQUAL(0, actual2); + + given[D.first] = 0; + size_t actual3 = C_given_DE.argmax(given); + EXPECT_LONGS_EQUAL(1, actual3); +} + /* ************************************************************************* */ // Check markdown representation looks as expected, no parents. TEST(DiscreteConditional, markdown_prior) { From 1657262c8747a3913fa3ad6083b816893c10d536 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 10 Jul 2024 23:44:43 -0400 Subject: [PATCH 04/16] DiscreteBayesNet::mode method to get maximizing assignment --- gtsam/discrete/DiscreteBayesNet.cpp | 8 ++++++ gtsam/discrete/DiscreteBayesNet.h | 8 ++++++ gtsam/discrete/tests/testDiscreteBayesNet.cpp | 26 +++++++++++++++++++ 3 files changed, 42 insertions(+) diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index f754250ed..bce14ad46 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -62,6 +62,14 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const { return result; } +DiscreteValues DiscreteBayesNet::mode() const { + DiscreteValues result; + for (auto it = begin(); it != end(); ++it) { + result[(*it)->firstFrontalKey()] = (*it)->argmax(result); + } + return result; +} + /* *********************************************************************** */ std::string DiscreteBayesNet::markdown( const KeyFormatter& keyFormatter, diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index a5a4621aa..3bcdcfe84 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -124,6 +124,14 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet { */ DiscreteValues sample(DiscreteValues given) const; + /** + * @brief Compute the discrete assignment which gives the highest + * probability for the DiscreteBayesNet. + * + * @return DiscreteValues + */ + DiscreteValues mode() const; + ///@} /// @name Wrapper support /// @{ diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index 95f407cae..7cd445c5b 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -122,6 +122,32 @@ TEST(DiscreteBayesNet, Asia) { EXPECT(assert_equal(expectedSample, actualSample)); } +/* ************************************************************************* */ +TEST(DiscreteBayesNet, Mode) { + DiscreteBayesNet asia; + + asia.add(Asia, "99/1"); + asia.add(Smoking % "50/50"); // Signature version + + asia.add(Tuberculosis | Asia = "99/1 95/5"); + asia.add(LungCancer | Smoking = "99/1 90/10"); + asia.add(Bronchitis | Smoking = "70/30 40/60"); + + asia.add((Either | Tuberculosis, LungCancer) = "F T T T"); + + asia.add(XRay | Either = "95/5 2/98"); + asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9"); + + DiscreteValues actual = asia.mode(); + // NOTE: Examined the DBN and found the optimal assignment. + DiscreteValues expected{ + {Asia.first, 0}, {Smoking.first, 0}, {Tuberculosis.first, 0}, + {LungCancer.first, 0}, {Bronchitis.first, 0}, {Either.first, 0}, + {XRay.first, 0}, {Dyspnea.first, 0}, + }; + EXPECT(assert_equal(expected, actual)); +} + /* ************************************************************************* */ TEST(DiscreteBayesNet, Sugar) { DiscreteKey T(0, 2), L(1, 2), E(2, 2), C(8, 3), S(7, 2); From d5be6d9bcea1e9a67a588bb0b1aab658889d2342 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 11 Jul 2024 00:19:17 -0400 Subject: [PATCH 05/16] wrap argmax and mode methods --- gtsam/discrete/discrete.i | 2 ++ 1 file changed, 2 insertions(+) diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 5eacb3634..9f55da28e 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -137,6 +137,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { size_t sample(size_t value) const; size_t sample() const; void sampleInPlace(gtsam::DiscreteValues @parentsValues) const; + size_t argmax(const gtsam::DiscreteValues& parents) const; // Markdown and HTML string markdown(const gtsam::KeyFormatter& keyFormatter = @@ -192,6 +193,7 @@ class DiscreteBayesNet { gtsam::DiscreteValues sample() const; gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const; + gtsam::DiscreteValues mode() const; string dot( const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, From 96a24445a4d314285384250a8332590ddb3d00b6 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 14 Jul 2024 10:12:49 -0400 Subject: [PATCH 06/16] address review comments --- gtsam/discrete/tests/testDiscreteConditional.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index a11c87975..2abb67fca 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -293,22 +293,22 @@ TEST(DiscreteConditional, choose) { // Check argmax on P(C|D) and P(D) TEST(DiscreteConditional, Argmax) { DiscreteKey C(2, 2), D(4, 2); - DiscreteConditional D_cond(D, "1/3"); - DiscreteConditional C_given_DE((C | D) = "1/4 1/1"); + DiscreteConditional D_prior(D, "1/3"); + DiscreteConditional C_given_D((C | D) = "1/4 1/1"); // Case 1: No parents - size_t actual1 = D_cond.argmax(); + size_t actual1 = D_prior.argmax(); EXPECT_LONGS_EQUAL(1, actual1); // Case 2: Given parent values DiscreteValues given; given[D.first] = 1; - size_t actual2 = C_given_DE.argmax(given); + size_t actual2 = C_given_D.argmax(given); // Should be 0 since D=1 gives 0.5/0.5 EXPECT_LONGS_EQUAL(0, actual2); given[D.first] = 0; - size_t actual3 = C_given_DE.argmax(given); + size_t actual3 = C_given_D.argmax(given); EXPECT_LONGS_EQUAL(1, actual3); } From f6449c0ad8563cb576268dcb3b8692dd06ba59da Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 14 Jul 2024 10:30:23 -0400 Subject: [PATCH 07/16] turns out we can merge DiscreteConditional and DiscreteLookupTable --- gtsam/discrete/DiscreteConditional.cpp | 32 +++++++++- gtsam/discrete/DiscreteConditional.h | 20 ++++--- gtsam/discrete/DiscreteLookupDAG.cpp | 83 +------------------------- gtsam/discrete/DiscreteLookupDAG.h | 38 +----------- 4 files changed, 49 insertions(+), 124 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index a7f472f26..ec17e22f6 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -236,6 +236,9 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( /* ************************************************************************** */ size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const { + ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues) + + // Initialize size_t maxValue = 0; double maxP = 0; DiscreteValues values = parentsValues; @@ -254,6 +257,33 @@ size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const { return maxValue; } +/* ************************************************************************** */ +void DiscreteConditional::argmaxInPlace(DiscreteValues* values) const { + ADT pFS = choose(*values, true); // P(F|S=parentsValues) + + // Initialize + DiscreteValues mpe; + double maxP = 0; + + // Get all Possible Configurations + const auto allPosbValues = frontalAssignments(); + + // Find the maximum + for (const auto& frontalVals : allPosbValues) { + double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues) + // Update maximum solution if better + if (pValueS > maxP) { + maxP = pValueS; + mpe = frontalVals; + } + } + + // set values (inPlace) to maximum + for (Key j : frontals()) { + (*values)[j] = mpe[j]; + } +} + /* ************************************************************************** */ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const { assert(nrFrontals() == 1); @@ -459,7 +489,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter, } /* ************************************************************************* */ -double DiscreteConditional::evaluate(const HybridValues& x) const{ +double DiscreteConditional::evaluate(const HybridValues& x) const { return this->evaluate(x.discrete()); } /* ************************************************************************* */ diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 8f38a83be..eda838e91 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -18,9 +18,9 @@ #pragma once -#include #include #include +#include #include #include @@ -39,7 +39,7 @@ class GTSAM_EXPORT DiscreteConditional public Conditional { public: // typedefs needed to play nice with gtsam - typedef DiscreteConditional This; ///< Typedef to this class + typedef DiscreteConditional This; ///< Typedef to this class typedef std::shared_ptr shared_ptr; ///< shared_ptr to this class typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class typedef Conditional @@ -159,9 +159,7 @@ class GTSAM_EXPORT DiscreteConditional /// @{ /// Log-probability is just -error(x). - double logProbability(const DiscreteValues& x) const { - return -error(x); - } + double logProbability(const DiscreteValues& x) const { return -error(x); } /// print index signature only void printSignature( @@ -214,11 +212,18 @@ class GTSAM_EXPORT DiscreteConditional size_t sample() const; /** - * @brief Return assignment that maximizes distribution. - * @return Optimal assignment (1 frontal variable). + * @brief Return assignment for single frontal variable that maximizes value. + * @param parentsValues Known assignments for the parents. + * @return maximizing assignment for the frontal variable. */ size_t argmax(const DiscreteValues& parentsValues = DiscreteValues()) const; + /** + * @brief Calculate assignment for frontal variables that maximizes value. + * @param (in/out) parentsValues Known assignments for the parents. + */ + void argmaxInPlace(DiscreteValues* parentsValues) const; + /// @} /// @name Advanced Interface /// @{ @@ -244,7 +249,6 @@ class GTSAM_EXPORT DiscreteConditional std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, const Names& names = {}) const override; - /// @} /// @name HybridValues methods. /// @{ diff --git a/gtsam/discrete/DiscreteLookupDAG.cpp b/gtsam/discrete/DiscreteLookupDAG.cpp index ab62055ed..11900b502 100644 --- a/gtsam/discrete/DiscreteLookupDAG.cpp +++ b/gtsam/discrete/DiscreteLookupDAG.cpp @@ -29,97 +29,20 @@ using std::vector; namespace gtsam { -/* ************************************************************************** */ -// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-( -void DiscreteLookupTable::print(const std::string& s, - const KeyFormatter& formatter) const { - using std::cout; - using std::endl; - - cout << s << " g( "; - for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) { - cout << formatter(*it) << " "; - } - if (nrParents()) { - cout << "; "; - for (const_iterator it = beginParents(); it != endParents(); ++it) { - cout << formatter(*it) << " "; - } - } - cout << "):\n"; - ADT::print("", formatter); - cout << endl; -} - -/* ************************************************************************** */ -void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const { - ADT pFS = choose(*values, true); // P(F|S=parentsValues) - - // Initialize - DiscreteValues mpe; - double maxP = 0; - - // Get all Possible Configurations - const auto allPosbValues = frontalAssignments(); - - // Find the maximum - for (const auto& frontalVals : allPosbValues) { - double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues) - // Update maximum solution if better - if (pValueS > maxP) { - maxP = pValueS; - mpe = frontalVals; - } - } - - // set values (inPlace) to maximum - for (Key j : frontals()) { - (*values)[j] = mpe[j]; - } -} - -/* ************************************************************************** */ -size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const { - ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues) - - // Then, find the max over all remaining - // TODO(Duy): only works for one key now, seems horribly slow this way - size_t mpe = 0; - double maxP = 0; - DiscreteValues frontals; - assert(nrFrontals() == 1); - Key j = (firstFrontalKey()); - for (size_t value = 0; value < cardinality(j); value++) { - frontals[j] = value; - double pValueS = pFS(frontals); // P(F=value|S=parentsValues) - // Update MPE solution if better - if (pValueS > maxP) { - maxP = pValueS; - mpe = value; - } - } - return mpe; -} - /* ************************************************************************** */ DiscreteLookupDAG DiscreteLookupDAG::FromBayesNet( const DiscreteBayesNet& bayesNet) { DiscreteLookupDAG dag; for (auto&& conditional : bayesNet) { - if (auto lookupTable = - std::dynamic_pointer_cast(conditional)) { - dag.push_back(lookupTable); - } else { - throw std::runtime_error( - "DiscreteFactorGraph::maxProduct: Expected look up table."); - } + dag.push_back(conditional); } return dag; } DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const { // Argmax each node in turn in topological sort order (parents first). - for (auto it = std::make_reverse_iterator(end()); it != std::make_reverse_iterator(begin()); ++it) { + for (auto it = std::make_reverse_iterator(end()); + it != std::make_reverse_iterator(begin()); ++it) { // dereference to get the sharedFactor to the lookup table (*it)->argmaxInPlace(&result); } diff --git a/gtsam/discrete/DiscreteLookupDAG.h b/gtsam/discrete/DiscreteLookupDAG.h index f077a13d9..3b0a5770d 100644 --- a/gtsam/discrete/DiscreteLookupDAG.h +++ b/gtsam/discrete/DiscreteLookupDAG.h @@ -37,41 +37,9 @@ class DiscreteBayesNet; * Inherits from discrete conditional for convenience, but is not normalized. * Is used in the max-product algorithm. */ -class GTSAM_EXPORT DiscreteLookupTable : public DiscreteConditional { - public: - using This = DiscreteLookupTable; - using shared_ptr = std::shared_ptr; - using BaseConditional = Conditional; - - /** - * @brief Construct a new Discrete Lookup Table object - * - * @param nFrontals number of frontal variables - * @param keys a sorted list of gtsam::Keys - * @param potentials the algebraic decision tree with lookup values - */ - DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys, - const ADT& potentials) - : DiscreteConditional(nFrontals, keys, potentials) {} - - /// GTSAM-style print - void print( - const std::string& s = "Discrete Lookup Table: ", - const KeyFormatter& formatter = DefaultKeyFormatter) const override; - - /** - * @brief return assignment for single frontal variable that maximizes value. - * @param parentsValues Known assignments for the parents. - * @return maximizing assignment for the frontal variable. - */ - size_t argmax(const DiscreteValues& parentsValues) const; - - /** - * @brief Calculate assignment for frontal variables that maximizes value. - * @param (in/out) parentsValues Known assignments for the parents. - */ - void argmaxInPlace(DiscreteValues* parentsValues) const; -}; +// Typedef for backwards compatibility +// TODO(Varun): Remove +using DiscreteLookupTable = DiscreteConditional; /** A DAG made from lookup tables, as defined above. */ class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet { From a43dad2e3439f1877944143114639d17c6a7712e Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 14 Jul 2024 10:31:27 -0400 Subject: [PATCH 08/16] use DiscreteLookupDAG for DiscreteBayesNet mode --- gtsam/discrete/DiscreteBayesNet.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index bce14ad46..bef0413c8 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -18,6 +18,7 @@ #include #include +#include #include namespace gtsam { @@ -56,18 +57,15 @@ DiscreteValues DiscreteBayesNet::sample() const { DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const { // sample each node in turn in topological sort order (parents first) - for (auto it = std::make_reverse_iterator(end()); it != std::make_reverse_iterator(begin()); ++it) { + for (auto it = std::make_reverse_iterator(end()); + it != std::make_reverse_iterator(begin()); ++it) { (*it)->sampleInPlace(&result); } return result; } DiscreteValues DiscreteBayesNet::mode() const { - DiscreteValues result; - for (auto it = begin(); it != end(); ++it) { - result[(*it)->firstFrontalKey()] = (*it)->argmax(result); - } - return result; + return DiscreteLookupDAG::FromBayesNet(*this).argmax(); } /* *********************************************************************** */ From 19ea2712c0281523a0b70a7e24789d43e2ae7683 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 14 Jul 2024 10:31:50 -0400 Subject: [PATCH 09/16] setup discrete bayes net in mode test with proper ordering --- gtsam/discrete/tests/testDiscreteBayesNet.cpp | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index 7cd445c5b..64c823203 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -16,14 +16,13 @@ * @author Frank Dellaert */ +#include +#include +#include +#include #include #include #include -#include -#include -#include - -#include #include #include @@ -43,8 +42,7 @@ TEST(DiscreteBayesNet, bayesNet) { DiscreteKey Parent(0, 2), Child(1, 2); auto prior = std::make_shared(Parent % "6/4"); - CHECK(assert_equal(ADT({Parent}, "0.6 0.4"), - (ADT)*prior)); + CHECK(assert_equal(ADT({Parent}, "0.6 0.4"), (ADT)*prior)); bayesNet.push_back(prior); auto conditional = @@ -126,17 +124,18 @@ TEST(DiscreteBayesNet, Asia) { TEST(DiscreteBayesNet, Mode) { DiscreteBayesNet asia; - asia.add(Asia, "99/1"); - asia.add(Smoking % "50/50"); // Signature version + // We need to order the Bayes net in bottom-up fashion + asia.add(XRay | Either = "95/5 2/98"); + asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9"); + + asia.add((Either | Tuberculosis, LungCancer) = "F T T T"); asia.add(Tuberculosis | Asia = "99/1 95/5"); asia.add(LungCancer | Smoking = "99/1 90/10"); asia.add(Bronchitis | Smoking = "70/30 40/60"); - asia.add((Either | Tuberculosis, LungCancer) = "F T T T"); - - asia.add(XRay | Either = "95/5 2/98"); - asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9"); + asia.add(Asia, "99/1"); + asia.add(Smoking % "50/50"); // Signature version DiscreteValues actual = asia.mode(); // NOTE: Examined the DBN and found the optimal assignment. From ffa72e7fadfff48864c990027f9b60a9f23506e0 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 14 Jul 2024 15:53:14 -0400 Subject: [PATCH 10/16] remove DiscreteLookupTable from wrapper --- gtsam/discrete/discrete.i | 9 --------- 1 file changed, 9 deletions(-) diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 9f55da28e..43d2559d9 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -263,15 +263,6 @@ class DiscreteBayesTree { #include -class DiscreteLookupTable : gtsam::DiscreteConditional{ - DiscreteLookupTable(size_t nFrontals, const gtsam::DiscreteKeys& keys, - const gtsam::DecisionTreeFactor::ADT& potentials); - void print(string s = "Discrete Lookup Table: ", - const gtsam::KeyFormatter& keyFormatter = - gtsam::DefaultKeyFormatter) const; - size_t argmax(const gtsam::DiscreteValues& parentsValues) const; -}; - class DiscreteLookupDAG { DiscreteLookupDAG(); void push_back(const gtsam::DiscreteLookupTable* table); From 4e66fff1537ebf7ab5f5c104ac01284a82d8235f Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 14 Jul 2024 17:57:37 -0400 Subject: [PATCH 11/16] use MaxProduct to compute Discrete Bayes Net mode --- gtsam/discrete/DiscreteBayesNet.cpp | 3 ++- gtsam/discrete/DiscreteConditional.cpp | 5 +++-- gtsam/discrete/tests/testDiscreteBayesNet.cpp | 22 +++++++++++++++++++ 3 files changed, 27 insertions(+), 3 deletions(-) diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index bef0413c8..f00eca60c 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -18,6 +18,7 @@ #include #include +#include #include #include @@ -65,7 +66,7 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const { } DiscreteValues DiscreteBayesNet::mode() const { - return DiscreteLookupDAG::FromBayesNet(*this).argmax(); + return DiscreteFactorGraph(*this).optimize(); } /* *********************************************************************** */ diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index ec17e22f6..90b3cfa39 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -238,7 +238,8 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const { ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues) - // Initialize + // Then, find the max over all remaining + // TODO(Duy): only works for one key now, seems horribly slow this way size_t maxValue = 0; double maxP = 0; DiscreteValues values = parentsValues; @@ -247,7 +248,7 @@ size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const { Key j = firstFrontalKey(); for (size_t value = 0; value < cardinality(j); value++) { values[j] = value; - double pValueS = (*this)(values); + double pValueS = pFS(values); // P(F=value|S=parentsValues) // Update MPE solution if better if (pValueS > maxP) { maxP = pValueS; diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index 64c823203..b87e1c67a 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -147,6 +147,28 @@ TEST(DiscreteBayesNet, Mode) { EXPECT(assert_equal(expected, actual)); } +/* ************************************************************************* */ +TEST(DiscreteBayesNet, ModeEdgeCase) { + // Declare 2 keys + DiscreteKey A(0, 2), B(1, 2); + + // Create Bayes net such that marginal on A is bigger for 0 than 1, but the + // MPE does not have A=0. + DiscreteBayesNet bayesNet; + bayesNet.add(B | A = "1/1 1/2"); + bayesNet.add(A % "10/9"); + + // Which we verify using max-product: + DiscreteFactorGraph graph(bayesNet); + // The expected MPE is A=1, B=1 + DiscreteValues expectedMPE = graph.optimize(); + + auto actualMPE = bayesNet.mode(); + + EXPECT(assert_equal(expectedMPE, actualMPE)); + EXPECT_DOUBLES_EQUAL(0.315789, bayesNet(expectedMPE), 1e-5); // regression +} + /* ************************************************************************* */ TEST(DiscreteBayesNet, Sugar) { DiscreteKey T(0, 2), L(1, 2), E(2, 2), C(8, 3), S(7, 2); From 4a04963197799e098e37cea414eb991342476279 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 15 Jul 2024 12:26:49 -0400 Subject: [PATCH 12/16] kill DiscreteBayesNet::mode --- gtsam/discrete/DiscreteBayesNet.cpp | 4 -- gtsam/discrete/DiscreteBayesNet.h | 8 --- gtsam/discrete/discrete.i | 1 - gtsam/discrete/tests/testDiscreteBayesNet.cpp | 49 ------------------- 4 files changed, 62 deletions(-) diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index f00eca60c..c1aa18828 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -65,10 +65,6 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const { return result; } -DiscreteValues DiscreteBayesNet::mode() const { - return DiscreteFactorGraph(*this).optimize(); -} - /* *********************************************************************** */ std::string DiscreteBayesNet::markdown( const KeyFormatter& keyFormatter, diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index 3bcdcfe84..a5a4621aa 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -124,14 +124,6 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet { */ DiscreteValues sample(DiscreteValues given) const; - /** - * @brief Compute the discrete assignment which gives the highest - * probability for the DiscreteBayesNet. - * - * @return DiscreteValues - */ - DiscreteValues mode() const; - ///@} /// @name Wrapper support /// @{ diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 43d2559d9..0f34840bf 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -193,7 +193,6 @@ class DiscreteBayesNet { gtsam::DiscreteValues sample() const; gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const; - gtsam::DiscreteValues mode() const; string dot( const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index b87e1c67a..49a360cbb 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -120,55 +120,6 @@ TEST(DiscreteBayesNet, Asia) { EXPECT(assert_equal(expectedSample, actualSample)); } -/* ************************************************************************* */ -TEST(DiscreteBayesNet, Mode) { - DiscreteBayesNet asia; - - // We need to order the Bayes net in bottom-up fashion - asia.add(XRay | Either = "95/5 2/98"); - asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9"); - - asia.add((Either | Tuberculosis, LungCancer) = "F T T T"); - - asia.add(Tuberculosis | Asia = "99/1 95/5"); - asia.add(LungCancer | Smoking = "99/1 90/10"); - asia.add(Bronchitis | Smoking = "70/30 40/60"); - - asia.add(Asia, "99/1"); - asia.add(Smoking % "50/50"); // Signature version - - DiscreteValues actual = asia.mode(); - // NOTE: Examined the DBN and found the optimal assignment. - DiscreteValues expected{ - {Asia.first, 0}, {Smoking.first, 0}, {Tuberculosis.first, 0}, - {LungCancer.first, 0}, {Bronchitis.first, 0}, {Either.first, 0}, - {XRay.first, 0}, {Dyspnea.first, 0}, - }; - EXPECT(assert_equal(expected, actual)); -} - -/* ************************************************************************* */ -TEST(DiscreteBayesNet, ModeEdgeCase) { - // Declare 2 keys - DiscreteKey A(0, 2), B(1, 2); - - // Create Bayes net such that marginal on A is bigger for 0 than 1, but the - // MPE does not have A=0. - DiscreteBayesNet bayesNet; - bayesNet.add(B | A = "1/1 1/2"); - bayesNet.add(A % "10/9"); - - // Which we verify using max-product: - DiscreteFactorGraph graph(bayesNet); - // The expected MPE is A=1, B=1 - DiscreteValues expectedMPE = graph.optimize(); - - auto actualMPE = bayesNet.mode(); - - EXPECT(assert_equal(expectedMPE, actualMPE)); - EXPECT_DOUBLES_EQUAL(0.315789, bayesNet(expectedMPE), 1e-5); // regression -} - /* ************************************************************************* */ TEST(DiscreteBayesNet, Sugar) { DiscreteKey T(0, 2), L(1, 2), E(2, 2), C(8, 3), S(7, 2); From 83eff969c5d2e3c1cc34848bdb19d3aa25f9cf07 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 15 Jul 2024 17:46:26 -0400 Subject: [PATCH 13/16] add tie-breaking test --- .../tests/testDiscreteConditional.cpp | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 2abb67fca..172dd0fa1 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -290,26 +290,32 @@ TEST(DiscreteConditional, choose) { } /* ************************************************************************* */ -// Check argmax on P(C|D) and P(D) +// Check argmax on P(C|D) and P(D), plus tie-breaking for P(B) TEST(DiscreteConditional, Argmax) { - DiscreteKey C(2, 2), D(4, 2); + DiscreteKey B(2, 2), C(2, 2), D(4, 2); + DiscreteConditional B_prior(D, "1/1"); DiscreteConditional D_prior(D, "1/3"); DiscreteConditional C_given_D((C | D) = "1/4 1/1"); - // Case 1: No parents - size_t actual1 = D_prior.argmax(); - EXPECT_LONGS_EQUAL(1, actual1); + // Case 1: Tie breaking + size_t actual1 = B_prior.argmax(); + // In the case of ties, the first value is chosen. + EXPECT_LONGS_EQUAL(0, actual1); + // Case 2: No parents + size_t actual2 = D_prior.argmax(); + // Selects 1 since it has 0.75 probability + EXPECT_LONGS_EQUAL(1, actual2); - // Case 2: Given parent values + // Case 3: Given parent values DiscreteValues given; given[D.first] = 1; - size_t actual2 = C_given_D.argmax(given); + size_t actual3 = C_given_D.argmax(given); // Should be 0 since D=1 gives 0.5/0.5 - EXPECT_LONGS_EQUAL(0, actual2); + EXPECT_LONGS_EQUAL(0, actual3); given[D.first] = 0; - size_t actual3 = C_given_D.argmax(given); - EXPECT_LONGS_EQUAL(1, actual3); + size_t actual4 = C_given_D.argmax(given); + EXPECT_LONGS_EQUAL(1, actual4); } /* ************************************************************************* */ From 016f6f28d1eba7f0780795557a064de80f16bd07 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 15 Jul 2024 18:39:37 -0400 Subject: [PATCH 14/16] Revert "turns out we can merge DiscreteConditional and DiscreteLookupTable" This reverts commit f6449c0ad8563cb576268dcb3b8692dd06ba59da. --- gtsam/discrete/DiscreteConditional.cpp | 33 +--------- gtsam/discrete/DiscreteConditional.h | 20 +++---- gtsam/discrete/DiscreteLookupDAG.cpp | 83 +++++++++++++++++++++++++- gtsam/discrete/DiscreteLookupDAG.h | 38 +++++++++++- 4 files changed, 124 insertions(+), 50 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 90b3cfa39..6df26d291 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -236,10 +236,6 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( /* ************************************************************************** */ size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const { - ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues) - - // Then, find the max over all remaining - // TODO(Duy): only works for one key now, seems horribly slow this way size_t maxValue = 0; double maxP = 0; DiscreteValues values = parentsValues; @@ -258,33 +254,6 @@ size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const { return maxValue; } -/* ************************************************************************** */ -void DiscreteConditional::argmaxInPlace(DiscreteValues* values) const { - ADT pFS = choose(*values, true); // P(F|S=parentsValues) - - // Initialize - DiscreteValues mpe; - double maxP = 0; - - // Get all Possible Configurations - const auto allPosbValues = frontalAssignments(); - - // Find the maximum - for (const auto& frontalVals : allPosbValues) { - double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues) - // Update maximum solution if better - if (pValueS > maxP) { - maxP = pValueS; - mpe = frontalVals; - } - } - - // set values (inPlace) to maximum - for (Key j : frontals()) { - (*values)[j] = mpe[j]; - } -} - /* ************************************************************************** */ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const { assert(nrFrontals() == 1); @@ -490,7 +459,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter, } /* ************************************************************************* */ -double DiscreteConditional::evaluate(const HybridValues& x) const { +double DiscreteConditional::evaluate(const HybridValues& x) const{ return this->evaluate(x.discrete()); } /* ************************************************************************* */ diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index eda838e91..8f38a83be 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -18,9 +18,9 @@ #pragma once +#include #include #include -#include #include #include @@ -39,7 +39,7 @@ class GTSAM_EXPORT DiscreteConditional public Conditional { public: // typedefs needed to play nice with gtsam - typedef DiscreteConditional This; ///< Typedef to this class + typedef DiscreteConditional This; ///< Typedef to this class typedef std::shared_ptr shared_ptr; ///< shared_ptr to this class typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class typedef Conditional @@ -159,7 +159,9 @@ class GTSAM_EXPORT DiscreteConditional /// @{ /// Log-probability is just -error(x). - double logProbability(const DiscreteValues& x) const { return -error(x); } + double logProbability(const DiscreteValues& x) const { + return -error(x); + } /// print index signature only void printSignature( @@ -212,18 +214,11 @@ class GTSAM_EXPORT DiscreteConditional size_t sample() const; /** - * @brief Return assignment for single frontal variable that maximizes value. - * @param parentsValues Known assignments for the parents. - * @return maximizing assignment for the frontal variable. + * @brief Return assignment that maximizes distribution. + * @return Optimal assignment (1 frontal variable). */ size_t argmax(const DiscreteValues& parentsValues = DiscreteValues()) const; - /** - * @brief Calculate assignment for frontal variables that maximizes value. - * @param (in/out) parentsValues Known assignments for the parents. - */ - void argmaxInPlace(DiscreteValues* parentsValues) const; - /// @} /// @name Advanced Interface /// @{ @@ -249,6 +244,7 @@ class GTSAM_EXPORT DiscreteConditional std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, const Names& names = {}) const override; + /// @} /// @name HybridValues methods. /// @{ diff --git a/gtsam/discrete/DiscreteLookupDAG.cpp b/gtsam/discrete/DiscreteLookupDAG.cpp index 11900b502..ab62055ed 100644 --- a/gtsam/discrete/DiscreteLookupDAG.cpp +++ b/gtsam/discrete/DiscreteLookupDAG.cpp @@ -29,20 +29,97 @@ using std::vector; namespace gtsam { +/* ************************************************************************** */ +// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-( +void DiscreteLookupTable::print(const std::string& s, + const KeyFormatter& formatter) const { + using std::cout; + using std::endl; + + cout << s << " g( "; + for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) { + cout << formatter(*it) << " "; + } + if (nrParents()) { + cout << "; "; + for (const_iterator it = beginParents(); it != endParents(); ++it) { + cout << formatter(*it) << " "; + } + } + cout << "):\n"; + ADT::print("", formatter); + cout << endl; +} + +/* ************************************************************************** */ +void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const { + ADT pFS = choose(*values, true); // P(F|S=parentsValues) + + // Initialize + DiscreteValues mpe; + double maxP = 0; + + // Get all Possible Configurations + const auto allPosbValues = frontalAssignments(); + + // Find the maximum + for (const auto& frontalVals : allPosbValues) { + double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues) + // Update maximum solution if better + if (pValueS > maxP) { + maxP = pValueS; + mpe = frontalVals; + } + } + + // set values (inPlace) to maximum + for (Key j : frontals()) { + (*values)[j] = mpe[j]; + } +} + +/* ************************************************************************** */ +size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const { + ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues) + + // Then, find the max over all remaining + // TODO(Duy): only works for one key now, seems horribly slow this way + size_t mpe = 0; + double maxP = 0; + DiscreteValues frontals; + assert(nrFrontals() == 1); + Key j = (firstFrontalKey()); + for (size_t value = 0; value < cardinality(j); value++) { + frontals[j] = value; + double pValueS = pFS(frontals); // P(F=value|S=parentsValues) + // Update MPE solution if better + if (pValueS > maxP) { + maxP = pValueS; + mpe = value; + } + } + return mpe; +} + /* ************************************************************************** */ DiscreteLookupDAG DiscreteLookupDAG::FromBayesNet( const DiscreteBayesNet& bayesNet) { DiscreteLookupDAG dag; for (auto&& conditional : bayesNet) { - dag.push_back(conditional); + if (auto lookupTable = + std::dynamic_pointer_cast(conditional)) { + dag.push_back(lookupTable); + } else { + throw std::runtime_error( + "DiscreteFactorGraph::maxProduct: Expected look up table."); + } } return dag; } DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const { // Argmax each node in turn in topological sort order (parents first). - for (auto it = std::make_reverse_iterator(end()); - it != std::make_reverse_iterator(begin()); ++it) { + for (auto it = std::make_reverse_iterator(end()); it != std::make_reverse_iterator(begin()); ++it) { // dereference to get the sharedFactor to the lookup table (*it)->argmaxInPlace(&result); } diff --git a/gtsam/discrete/DiscreteLookupDAG.h b/gtsam/discrete/DiscreteLookupDAG.h index 3b0a5770d..f077a13d9 100644 --- a/gtsam/discrete/DiscreteLookupDAG.h +++ b/gtsam/discrete/DiscreteLookupDAG.h @@ -37,9 +37,41 @@ class DiscreteBayesNet; * Inherits from discrete conditional for convenience, but is not normalized. * Is used in the max-product algorithm. */ -// Typedef for backwards compatibility -// TODO(Varun): Remove -using DiscreteLookupTable = DiscreteConditional; +class GTSAM_EXPORT DiscreteLookupTable : public DiscreteConditional { + public: + using This = DiscreteLookupTable; + using shared_ptr = std::shared_ptr; + using BaseConditional = Conditional; + + /** + * @brief Construct a new Discrete Lookup Table object + * + * @param nFrontals number of frontal variables + * @param keys a sorted list of gtsam::Keys + * @param potentials the algebraic decision tree with lookup values + */ + DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys, + const ADT& potentials) + : DiscreteConditional(nFrontals, keys, potentials) {} + + /// GTSAM-style print + void print( + const std::string& s = "Discrete Lookup Table: ", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; + + /** + * @brief return assignment for single frontal variable that maximizes value. + * @param parentsValues Known assignments for the parents. + * @return maximizing assignment for the frontal variable. + */ + size_t argmax(const DiscreteValues& parentsValues) const; + + /** + * @brief Calculate assignment for frontal variables that maximizes value. + * @param (in/out) parentsValues Known assignments for the parents. + */ + void argmaxInPlace(DiscreteValues* parentsValues) const; +}; /** A DAG made from lookup tables, as defined above. */ class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet { From e0444ac722d8cb0cc692b17e07d19fc028f713d4 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 15 Jul 2024 18:40:07 -0400 Subject: [PATCH 15/16] Revert "remove DiscreteLookupTable from wrapper" This reverts commit ffa72e7fadfff48864c990027f9b60a9f23506e0. --- gtsam/discrete/discrete.i | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 0f34840bf..0bdebd0e1 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -262,6 +262,15 @@ class DiscreteBayesTree { #include +class DiscreteLookupTable : gtsam::DiscreteConditional{ + DiscreteLookupTable(size_t nFrontals, const gtsam::DiscreteKeys& keys, + const gtsam::DecisionTreeFactor::ADT& potentials); + void print(string s = "Discrete Lookup Table: ", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + size_t argmax(const gtsam::DiscreteValues& parentsValues) const; +}; + class DiscreteLookupDAG { DiscreteLookupDAG(); void push_back(const gtsam::DiscreteLookupTable* table); From 3d58ce56b2beab47561dc8e8981172cc182a8387 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 15 Jul 2024 18:45:15 -0400 Subject: [PATCH 16/16] small fix --- gtsam/discrete/DiscreteConditional.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 6df26d291..a7f472f26 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -244,7 +244,7 @@ size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const { Key j = firstFrontalKey(); for (size_t value = 0; value < cardinality(j); value++) { values[j] = value; - double pValueS = pFS(values); // P(F=value|S=parentsValues) + double pValueS = (*this)(values); // Update MPE solution if better if (pValueS > maxP) { maxP = pValueS;