Merge pull request #1781 from borglab/discrete-improv

release/4.3a0
Varun Agrawal 2024-07-22 09:47:22 -04:00 committed by GitHub
commit feab2a2d20
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 63 additions and 25 deletions

View File

@ -18,6 +18,8 @@
#include <gtsam/discrete/DiscreteBayesNet.h> #include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/inference/FactorGraph-inst.h> #include <gtsam/inference/FactorGraph-inst.h>
namespace gtsam { namespace gtsam {
@ -56,7 +58,8 @@ DiscreteValues DiscreteBayesNet::sample() const {
DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const { DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
// sample each node in turn in topological sort order (parents first) // sample each node in turn in topological sort order (parents first)
for (auto it = std::make_reverse_iterator(end()); it != std::make_reverse_iterator(begin()); ++it) { for (auto it = std::make_reverse_iterator(end());
it != std::make_reverse_iterator(begin()); ++it) {
(*it)->sampleInPlace(&result); (*it)->sampleInPlace(&result);
} }
return result; return result;

View File

@ -235,16 +235,19 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
} }
/* ************************************************************************** */ /* ************************************************************************** */
size_t DiscreteConditional::argmax() const { size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const {
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
// Initialize
size_t maxValue = 0; size_t maxValue = 0;
double maxP = 0; double maxP = 0;
DiscreteValues values = parentsValues;
assert(nrFrontals() == 1); assert(nrFrontals() == 1);
assert(nrParents() == 0);
DiscreteValues frontals;
Key j = firstFrontalKey(); Key j = firstFrontalKey();
for (size_t value = 0; value < cardinality(j); value++) { for (size_t value = 0; value < cardinality(j); value++) {
frontals[j] = value; values[j] = value;
double pValueS = (*this)(frontals); double pValueS = (*this)(values);
// Update MPE solution if better // Update MPE solution if better
if (pValueS > maxP) { if (pValueS > maxP) {
maxP = pValueS; maxP = pValueS;
@ -459,7 +462,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,
} }
/* ************************************************************************* */ /* ************************************************************************* */
double DiscreteConditional::evaluate(const HybridValues& x) const{ double DiscreteConditional::evaluate(const HybridValues& x) const {
return this->evaluate(x.discrete()); return this->evaluate(x.discrete());
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -18,9 +18,9 @@
#pragma once #pragma once
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Conditional-inst.h>
#include <memory> #include <memory>
#include <string> #include <string>
@ -39,7 +39,7 @@ class GTSAM_EXPORT DiscreteConditional
public Conditional<DecisionTreeFactor, DiscreteConditional> { public Conditional<DecisionTreeFactor, DiscreteConditional> {
public: public:
// typedefs needed to play nice with gtsam // typedefs needed to play nice with gtsam
typedef DiscreteConditional This; ///< Typedef to this class typedef DiscreteConditional This; ///< Typedef to this class
typedef std::shared_ptr<This> shared_ptr; ///< shared_ptr to this class typedef std::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class
typedef Conditional<BaseFactor, This> typedef Conditional<BaseFactor, This>
@ -159,9 +159,7 @@ class GTSAM_EXPORT DiscreteConditional
/// @{ /// @{
/// Log-probability is just -error(x). /// Log-probability is just -error(x).
double logProbability(const DiscreteValues& x) const { double logProbability(const DiscreteValues& x) const { return -error(x); }
return -error(x);
}
/// print index signature only /// print index signature only
void printSignature( void printSignature(
@ -214,10 +212,11 @@ class GTSAM_EXPORT DiscreteConditional
size_t sample() const; size_t sample() const;
/** /**
* @brief Return assignment that maximizes distribution. * @brief Return assignment for single frontal variable that maximizes value.
* @return Optimal assignment (1 frontal variable). * @param parentsValues Known assignments for the parents.
* @return maximizing assignment for the frontal variable.
*/ */
size_t argmax() const; size_t argmax(const DiscreteValues& parentsValues = DiscreteValues()) const;
/// @} /// @}
/// @name Advanced Interface /// @name Advanced Interface
@ -244,7 +243,6 @@ class GTSAM_EXPORT DiscreteConditional
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override; const Names& names = {}) const override;
/// @} /// @}
/// @name HybridValues methods. /// @name HybridValues methods.
/// @{ /// @{

View File

@ -119,7 +119,8 @@ DiscreteLookupDAG DiscreteLookupDAG::FromBayesNet(
DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const { DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const {
// Argmax each node in turn in topological sort order (parents first). // Argmax each node in turn in topological sort order (parents first).
for (auto it = std::make_reverse_iterator(end()); it != std::make_reverse_iterator(begin()); ++it) { for (auto it = std::make_reverse_iterator(end());
it != std::make_reverse_iterator(begin()); ++it) {
// dereference to get the sharedFactor to the lookup table // dereference to get the sharedFactor to the lookup table
(*it)->argmaxInPlace(&result); (*it)->argmaxInPlace(&result);
} }

View File

@ -14,6 +14,9 @@ class DiscreteKeys {
bool empty() const; bool empty() const;
gtsam::DiscreteKey at(size_t n) const; gtsam::DiscreteKey at(size_t n) const;
void push_back(const gtsam::DiscreteKey& point_pair); void push_back(const gtsam::DiscreteKey& point_pair);
void print(const std::string& s = "",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
}; };
// DiscreteValues is added in specializations/discrete.h as a std::map // DiscreteValues is added in specializations/discrete.h as a std::map
@ -104,6 +107,9 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
DiscreteConditional(const gtsam::DecisionTreeFactor& joint, DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
const gtsam::DecisionTreeFactor& marginal, const gtsam::DecisionTreeFactor& marginal,
const gtsam::Ordering& orderedKeys); const gtsam::Ordering& orderedKeys);
DiscreteConditional(const gtsam::DiscreteKey& key,
const gtsam::DiscreteKeys& parents,
const std::vector<double>& table);
// Standard interface // Standard interface
double logNormalizationConstant() const; double logNormalizationConstant() const;
@ -131,6 +137,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
size_t sample(size_t value) const; size_t sample(size_t value) const;
size_t sample() const; size_t sample() const;
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const; void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
size_t argmax(const gtsam::DiscreteValues& parents) const;
// Markdown and HTML // Markdown and HTML
string markdown(const gtsam::KeyFormatter& keyFormatter = string markdown(const gtsam::KeyFormatter& keyFormatter =
@ -159,7 +166,6 @@ virtual class DiscreteDistribution : gtsam::DiscreteConditional {
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
double operator()(size_t value) const; double operator()(size_t value) const;
std::vector<double> pmf() const; std::vector<double> pmf() const;
size_t argmax() const;
}; };
#include <gtsam/discrete/DiscreteBayesNet.h> #include <gtsam/discrete/DiscreteBayesNet.h>

View File

@ -16,14 +16,13 @@
* @author Frank Dellaert * @author Frank Dellaert
*/ */
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/Vector.h>
#include <gtsam/base/debug.h>
#include <gtsam/discrete/DiscreteBayesNet.h> #include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteMarginals.h> #include <gtsam/discrete/DiscreteMarginals.h>
#include <gtsam/base/debug.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/Vector.h>
#include <CppUnitLite/TestHarness.h>
#include <iostream> #include <iostream>
#include <string> #include <string>
@ -43,8 +42,7 @@ TEST(DiscreteBayesNet, bayesNet) {
DiscreteKey Parent(0, 2), Child(1, 2); DiscreteKey Parent(0, 2), Child(1, 2);
auto prior = std::make_shared<DiscreteConditional>(Parent % "6/4"); auto prior = std::make_shared<DiscreteConditional>(Parent % "6/4");
CHECK(assert_equal(ADT({Parent}, "0.6 0.4"), CHECK(assert_equal(ADT({Parent}, "0.6 0.4"), (ADT)*prior));
(ADT)*prior));
bayesNet.push_back(prior); bayesNet.push_back(prior);
auto conditional = auto conditional =

View File

@ -289,6 +289,35 @@ TEST(DiscreteConditional, choose) {
EXPECT(assert_equal(expected3, *actual3, 1e-9)); EXPECT(assert_equal(expected3, *actual3, 1e-9));
} }
/* ************************************************************************* */
// Check argmax on P(C|D) and P(D), plus tie-breaking for P(B)
TEST(DiscreteConditional, Argmax) {
DiscreteKey B(2, 2), C(2, 2), D(4, 2);
DiscreteConditional B_prior(D, "1/1");
DiscreteConditional D_prior(D, "1/3");
DiscreteConditional C_given_D((C | D) = "1/4 1/1");
// Case 1: Tie breaking
size_t actual1 = B_prior.argmax();
// In the case of ties, the first value is chosen.
EXPECT_LONGS_EQUAL(0, actual1);
// Case 2: No parents
size_t actual2 = D_prior.argmax();
// Selects 1 since it has 0.75 probability
EXPECT_LONGS_EQUAL(1, actual2);
// Case 3: Given parent values
DiscreteValues given;
given[D.first] = 1;
size_t actual3 = C_given_D.argmax(given);
// Should be 0 since D=1 gives 0.5/0.5
EXPECT_LONGS_EQUAL(0, actual3);
given[D.first] = 0;
size_t actual4 = C_given_D.argmax(given);
EXPECT_LONGS_EQUAL(1, actual4);
}
/* ************************************************************************* */ /* ************************************************************************* */
// Check markdown representation looks as expected, no parents. // Check markdown representation looks as expected, no parents.
TEST(DiscreteConditional, markdown_prior) { TEST(DiscreteConditional, markdown_prior) {