Merge branch 'discrete-improv' into discrete-improv-2
commit
dd8de1f300
|
@ -18,6 +18,8 @@
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
#include <gtsam/discrete/DiscreteConditional.h>
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
|
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||||
#include <gtsam/inference/FactorGraph-inst.h>
|
#include <gtsam/inference/FactorGraph-inst.h>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
@ -56,7 +58,8 @@ DiscreteValues DiscreteBayesNet::sample() const {
|
||||||
|
|
||||||
DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
|
DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
|
||||||
// sample each node in turn in topological sort order (parents first)
|
// sample each node in turn in topological sort order (parents first)
|
||||||
for (auto it = std::make_reverse_iterator(end()); it != std::make_reverse_iterator(begin()); ++it) {
|
for (auto it = std::make_reverse_iterator(end());
|
||||||
|
it != std::make_reverse_iterator(begin()); ++it) {
|
||||||
(*it)->sampleInPlace(&result);
|
(*it)->sampleInPlace(&result);
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
|
|
|
@ -241,13 +241,13 @@ size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const {
|
||||||
// Initialize
|
// Initialize
|
||||||
size_t maxValue = 0;
|
size_t maxValue = 0;
|
||||||
double maxP = 0;
|
double maxP = 0;
|
||||||
|
DiscreteValues values = parentsValues;
|
||||||
|
|
||||||
assert(nrFrontals() == 1);
|
assert(nrFrontals() == 1);
|
||||||
assert(nrParents() == 0);
|
|
||||||
DiscreteValues frontals;
|
|
||||||
Key j = firstFrontalKey();
|
Key j = firstFrontalKey();
|
||||||
for (size_t value = 0; value < cardinality(j); value++) {
|
for (size_t value = 0; value < cardinality(j); value++) {
|
||||||
frontals[j] = value;
|
values[j] = value;
|
||||||
double pValueS = (*this)(frontals);
|
double pValueS = (*this)(values);
|
||||||
// Update MPE solution if better
|
// Update MPE solution if better
|
||||||
if (pValueS > maxP) {
|
if (pValueS > maxP) {
|
||||||
maxP = pValueS;
|
maxP = pValueS;
|
||||||
|
|
|
@ -216,7 +216,7 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
* @param parentsValues Known assignments for the parents.
|
* @param parentsValues Known assignments for the parents.
|
||||||
* @return maximizing assignment for the frontal variable.
|
* @return maximizing assignment for the frontal variable.
|
||||||
*/
|
*/
|
||||||
size_t argmax() const;
|
size_t argmax(const DiscreteValues& parentsValues = DiscreteValues()) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Calculate assignment for frontal variables that maximizes value.
|
* @brief Calculate assignment for frontal variables that maximizes value.
|
||||||
|
|
|
@ -14,6 +14,9 @@ class DiscreteKeys {
|
||||||
bool empty() const;
|
bool empty() const;
|
||||||
gtsam::DiscreteKey at(size_t n) const;
|
gtsam::DiscreteKey at(size_t n) const;
|
||||||
void push_back(const gtsam::DiscreteKey& point_pair);
|
void push_back(const gtsam::DiscreteKey& point_pair);
|
||||||
|
void print(const std::string& s = "",
|
||||||
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
// DiscreteValues is added in specializations/discrete.h as a std::map
|
// DiscreteValues is added in specializations/discrete.h as a std::map
|
||||||
|
@ -104,6 +107,9 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||||
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
|
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
|
||||||
const gtsam::DecisionTreeFactor& marginal,
|
const gtsam::DecisionTreeFactor& marginal,
|
||||||
const gtsam::Ordering& orderedKeys);
|
const gtsam::Ordering& orderedKeys);
|
||||||
|
DiscreteConditional(const gtsam::DiscreteKey& key,
|
||||||
|
const gtsam::DiscreteKeys& parents,
|
||||||
|
const std::vector<double>& table);
|
||||||
|
|
||||||
// Standard interface
|
// Standard interface
|
||||||
double logNormalizationConstant() const;
|
double logNormalizationConstant() const;
|
||||||
|
@ -131,6 +137,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||||
size_t sample(size_t value) const;
|
size_t sample(size_t value) const;
|
||||||
size_t sample() const;
|
size_t sample() const;
|
||||||
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
|
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
|
||||||
|
size_t argmax(const gtsam::DiscreteValues& parents) const;
|
||||||
|
|
||||||
// Markdown and HTML
|
// Markdown and HTML
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
@ -159,7 +166,6 @@ virtual class DiscreteDistribution : gtsam::DiscreteConditional {
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
double operator()(size_t value) const;
|
double operator()(size_t value) const;
|
||||||
std::vector<double> pmf() const;
|
std::vector<double> pmf() const;
|
||||||
size_t argmax() const;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
|
|
|
@ -16,14 +16,13 @@
|
||||||
* @author Frank Dellaert
|
* @author Frank Dellaert
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
|
#include <gtsam/base/Vector.h>
|
||||||
|
#include <gtsam/base/debug.h>
|
||||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
#include <gtsam/discrete/DiscreteMarginals.h>
|
#include <gtsam/discrete/DiscreteMarginals.h>
|
||||||
#include <gtsam/base/debug.h>
|
|
||||||
#include <gtsam/base/Testable.h>
|
|
||||||
#include <gtsam/base/Vector.h>
|
|
||||||
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
@ -43,8 +42,7 @@ TEST(DiscreteBayesNet, bayesNet) {
|
||||||
DiscreteKey Parent(0, 2), Child(1, 2);
|
DiscreteKey Parent(0, 2), Child(1, 2);
|
||||||
|
|
||||||
auto prior = std::make_shared<DiscreteConditional>(Parent % "6/4");
|
auto prior = std::make_shared<DiscreteConditional>(Parent % "6/4");
|
||||||
CHECK(assert_equal(ADT({Parent}, "0.6 0.4"),
|
CHECK(assert_equal(ADT({Parent}, "0.6 0.4"), (ADT)*prior));
|
||||||
(ADT)*prior));
|
|
||||||
bayesNet.push_back(prior);
|
bayesNet.push_back(prior);
|
||||||
|
|
||||||
auto conditional =
|
auto conditional =
|
||||||
|
|
|
@ -289,6 +289,35 @@ TEST(DiscreteConditional, choose) {
|
||||||
EXPECT(assert_equal(expected3, *actual3, 1e-9));
|
EXPECT(assert_equal(expected3, *actual3, 1e-9));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check argmax on P(C|D) and P(D), plus tie-breaking for P(B)
|
||||||
|
TEST(DiscreteConditional, Argmax) {
|
||||||
|
DiscreteKey B(2, 2), C(2, 2), D(4, 2);
|
||||||
|
DiscreteConditional B_prior(D, "1/1");
|
||||||
|
DiscreteConditional D_prior(D, "1/3");
|
||||||
|
DiscreteConditional C_given_D((C | D) = "1/4 1/1");
|
||||||
|
|
||||||
|
// Case 1: Tie breaking
|
||||||
|
size_t actual1 = B_prior.argmax();
|
||||||
|
// In the case of ties, the first value is chosen.
|
||||||
|
EXPECT_LONGS_EQUAL(0, actual1);
|
||||||
|
// Case 2: No parents
|
||||||
|
size_t actual2 = D_prior.argmax();
|
||||||
|
// Selects 1 since it has 0.75 probability
|
||||||
|
EXPECT_LONGS_EQUAL(1, actual2);
|
||||||
|
|
||||||
|
// Case 3: Given parent values
|
||||||
|
DiscreteValues given;
|
||||||
|
given[D.first] = 1;
|
||||||
|
size_t actual3 = C_given_D.argmax(given);
|
||||||
|
// Should be 0 since D=1 gives 0.5/0.5
|
||||||
|
EXPECT_LONGS_EQUAL(0, actual3);
|
||||||
|
|
||||||
|
given[D.first] = 0;
|
||||||
|
size_t actual4 = C_given_D.argmax(given);
|
||||||
|
EXPECT_LONGS_EQUAL(1, actual4);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// Check markdown representation looks as expected, no parents.
|
// Check markdown representation looks as expected, no parents.
|
||||||
TEST(DiscreteConditional, markdown_prior) {
|
TEST(DiscreteConditional, markdown_prior) {
|
||||||
|
|
Loading…
Reference in New Issue