Merge pull request #2136 from borglab/wrap/rng
commit
b95a352ccb
|
@ -201,7 +201,7 @@ jobs:
|
||||||
if: runner.os == 'Linux'
|
if: runner.os == 'Linux'
|
||||||
uses: pierotofy/set-swap-space@master
|
uses: pierotofy/set-swap-space@master
|
||||||
with:
|
with:
|
||||||
swap-size-gb: 12
|
swap-size-gb: 8
|
||||||
|
|
||||||
- name: Build & Test
|
- name: Build & Test
|
||||||
run: |
|
run: |
|
||||||
|
|
|
@ -50,12 +50,13 @@ double DiscreteBayesNet::evaluate(const DiscreteValues& values) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
DiscreteValues DiscreteBayesNet::sample() const {
|
DiscreteValues DiscreteBayesNet::sample(std::mt19937_64* rng) const {
|
||||||
DiscreteValues result;
|
DiscreteValues result;
|
||||||
return sample(result);
|
return sample(result, rng);
|
||||||
}
|
}
|
||||||
|
|
||||||
DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
|
DiscreteValues DiscreteBayesNet::sample(DiscreteValues result,
|
||||||
|
std::mt19937_64* rng) const {
|
||||||
// sample each node in turn in topological sort order (parents first)
|
// sample each node in turn in topological sort order (parents first)
|
||||||
for (auto it = std::make_reverse_iterator(end());
|
for (auto it = std::make_reverse_iterator(end());
|
||||||
it != std::make_reverse_iterator(begin()); ++it) {
|
it != std::make_reverse_iterator(begin()); ++it) {
|
||||||
|
@ -63,7 +64,7 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
|
||||||
// Sample the conditional only if value for j not already in result
|
// Sample the conditional only if value for j not already in result
|
||||||
const Key j = conditional->firstFrontalKey();
|
const Key j = conditional->firstFrontalKey();
|
||||||
if (result.count(j) == 0) {
|
if (result.count(j) == 0) {
|
||||||
conditional->sampleInPlace(&result);
|
conditional->sampleInPlace(&result, rng);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
|
|
|
@ -112,7 +112,7 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
|
||||||
*
|
*
|
||||||
* @return a sampled value for all variables.
|
* @return a sampled value for all variables.
|
||||||
*/
|
*/
|
||||||
DiscreteValues sample() const;
|
DiscreteValues sample(std::mt19937_64* rng = nullptr) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief do ancestral sampling, given certain variables.
|
* @brief do ancestral sampling, given certain variables.
|
||||||
|
@ -122,7 +122,8 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
|
||||||
*
|
*
|
||||||
* @return given values extended with sampled value for all other variables.
|
* @return given values extended with sampled value for all other variables.
|
||||||
*/
|
*/
|
||||||
DiscreteValues sample(DiscreteValues given) const;
|
DiscreteValues sample(DiscreteValues given,
|
||||||
|
std::mt19937_64* rng = nullptr) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Prune the Bayes net
|
* @brief Prune the Bayes net
|
||||||
|
|
|
@ -268,7 +268,8 @@ size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
|
void DiscreteConditional::sampleInPlace(DiscreteValues* values,
|
||||||
|
std::mt19937_64* rng) const {
|
||||||
// throw if more than one frontal:
|
// throw if more than one frontal:
|
||||||
if (nrFrontals() != 1) {
|
if (nrFrontals() != 1) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
|
@ -281,14 +282,13 @@ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"DiscreteConditional::sampleInPlace: values already contains j");
|
"DiscreteConditional::sampleInPlace: values already contains j");
|
||||||
}
|
}
|
||||||
size_t sampled = sample(*values); // Sample variable given parents
|
size_t sampled = sample(*values, rng); // Sample variable given parents
|
||||||
(*values)[j] = sampled; // store result in partial solution
|
(*values)[j] = sampled; // store result in partial solution
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
|
size_t DiscreteConditional::sample(const DiscreteValues& parentsValues,
|
||||||
static mt19937 rng(2); // random number generator
|
std::mt19937_64* rng) const {
|
||||||
|
|
||||||
// Get the correct conditional distribution
|
// Get the correct conditional distribution
|
||||||
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
||||||
|
|
||||||
|
@ -309,28 +309,33 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
|
||||||
return value; // shortcut exit
|
return value; // shortcut exit
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if rng is nullptr, then assign default
|
||||||
|
rng = (rng == nullptr) ? &kRandomNumberGenerator : rng;
|
||||||
|
|
||||||
std::discrete_distribution<size_t> distribution(p.begin(), p.end());
|
std::discrete_distribution<size_t> distribution(p.begin(), p.end());
|
||||||
return distribution(rng);
|
return distribution(*rng);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
size_t DiscreteConditional::sample(size_t parent_value) const {
|
size_t DiscreteConditional::sample(size_t parent_value,
|
||||||
|
std::mt19937_64* rng) const {
|
||||||
if (nrParents() != 1)
|
if (nrParents() != 1)
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"Single value sample() can only be invoked on single-parent "
|
"Single value sample() can only be invoked on single-parent "
|
||||||
"conditional");
|
"conditional");
|
||||||
DiscreteValues values;
|
DiscreteValues values;
|
||||||
values.emplace(keys_.back(), parent_value);
|
values.emplace(keys_.back(), parent_value);
|
||||||
return sample(values);
|
return sample(values, rng);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
size_t DiscreteConditional::sample() const {
|
size_t DiscreteConditional::sample(std::mt19937_64* rng) const {
|
||||||
if (nrParents() != 0)
|
if (nrParents() != 0)
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"sample() can only be invoked on no-parent prior");
|
"sample() can only be invoked on no-parent prior");
|
||||||
DiscreteValues values;
|
DiscreteValues values;
|
||||||
return sample(values);
|
return sample(values, rng);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -23,9 +23,13 @@
|
||||||
#include <gtsam/inference/Conditional-inst.h>
|
#include <gtsam/inference/Conditional-inst.h>
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <random> // for std::mt19937_64
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
// In wrappers we can access std::mt19937_64 via gtsam.MT19937
|
||||||
|
static std::mt19937_64 kRandomNumberGenerator(42);
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -195,17 +199,29 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
DecisionTreeFactor::shared_ptr likelihood(size_t frontal) const;
|
DecisionTreeFactor::shared_ptr likelihood(size_t frontal) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* sample
|
* Sample from conditional, given missing variables
|
||||||
|
* Example:
|
||||||
|
* std::mt19937_64 rng(42);
|
||||||
|
* DiscreteValues given = ...;
|
||||||
|
* size_t sample = dc.sample(given, &rng);
|
||||||
|
*
|
||||||
* @param parentsValues Known values of the parents
|
* @param parentsValues Known values of the parents
|
||||||
|
* @param rng Pseudo-Random Number Generator.
|
||||||
* @return sample from conditional
|
* @return sample from conditional
|
||||||
*/
|
*/
|
||||||
virtual size_t sample(const DiscreteValues& parentsValues) const;
|
virtual size_t sample(const DiscreteValues& parentsValues,
|
||||||
|
std::mt19937_64* rng = nullptr) const;
|
||||||
|
|
||||||
/// Single parent version.
|
/// Single parent version.
|
||||||
size_t sample(size_t parent_value) const;
|
size_t sample(size_t parent_value, std::mt19937_64* rng = nullptr) const;
|
||||||
|
|
||||||
/// Zero parent version.
|
/**
|
||||||
size_t sample() const;
|
* Sample from conditional, zero parent version
|
||||||
|
* Example:
|
||||||
|
* std::mt19937_64 rng(42);
|
||||||
|
* auto sample = dc.sample(&rng);
|
||||||
|
*/
|
||||||
|
size_t sample(std::mt19937_64* rng = nullptr) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Return assignment for single frontal variable that maximizes value.
|
* @brief Return assignment for single frontal variable that maximizes value.
|
||||||
|
@ -227,8 +243,9 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
/// @name Advanced Interface
|
/// @name Advanced Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// sample in place, stores result in partial solution
|
/// Sample in place with optional PRNG, stores result in partial solution
|
||||||
void sampleInPlace(DiscreteValues* parentsValues) const;
|
void sampleInPlace(DiscreteValues* parentsValues,
|
||||||
|
std::mt19937_64* rng = nullptr) const;
|
||||||
|
|
||||||
/// Return all assignments for frontal variables.
|
/// Return all assignments for frontal variables.
|
||||||
std::vector<DiscreteValues> frontalAssignments() const;
|
std::vector<DiscreteValues> frontalAssignments() const;
|
||||||
|
|
|
@ -144,9 +144,8 @@ void TableDistribution::prune(size_t maxNrAssignments) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
size_t TableDistribution::sample(const DiscreteValues& parentsValues) const {
|
size_t TableDistribution::sample(const DiscreteValues& parentsValues,
|
||||||
static mt19937 rng(2); // random number generator
|
std::mt19937_64* rng) const {
|
||||||
|
|
||||||
DiscreteKeys parentsKeys;
|
DiscreteKeys parentsKeys;
|
||||||
for (auto&& [key, _] : parentsValues) {
|
for (auto&& [key, _] : parentsValues) {
|
||||||
parentsKeys.push_back({key, table_.cardinality(key)});
|
parentsKeys.push_back({key, table_.cardinality(key)});
|
||||||
|
@ -172,8 +171,12 @@ size_t TableDistribution::sample(const DiscreteValues& parentsValues) const {
|
||||||
return value; // shortcut exit
|
return value; // shortcut exit
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if rng is nullptr, then assign default
|
||||||
|
rng = (rng == nullptr) ? &kRandomNumberGenerator : rng;
|
||||||
|
|
||||||
std::discrete_distribution<size_t> distribution(p.begin(), p.end());
|
std::discrete_distribution<size_t> distribution(p.begin(), p.end());
|
||||||
return distribution(rng);
|
return distribution(*rng);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -125,7 +125,6 @@ class GTSAM_EXPORT TableDistribution : public DiscreteConditional {
|
||||||
/// Create new factor by maximizing over all values with the same separator.
|
/// Create new factor by maximizing over all values with the same separator.
|
||||||
DiscreteFactor::shared_ptr max(const Ordering& keys) const override;
|
DiscreteFactor::shared_ptr max(const Ordering& keys) const override;
|
||||||
|
|
||||||
|
|
||||||
/// Multiply by scalar s
|
/// Multiply by scalar s
|
||||||
DiscreteFactor::shared_ptr operator*(double s) const override;
|
DiscreteFactor::shared_ptr operator*(double s) const override;
|
||||||
|
|
||||||
|
@ -143,9 +142,11 @@ class GTSAM_EXPORT TableDistribution : public DiscreteConditional {
|
||||||
/**
|
/**
|
||||||
* sample
|
* sample
|
||||||
* @param parentsValues Known values of the parents
|
* @param parentsValues Known values of the parents
|
||||||
|
* @param rng Pseudo random number generator
|
||||||
* @return sample from conditional
|
* @return sample from conditional
|
||||||
*/
|
*/
|
||||||
virtual size_t sample(const DiscreteValues& parentsValues) const override;
|
virtual size_t sample(const DiscreteValues& parentsValues,
|
||||||
|
std::mt19937_64* rng = nullptr) const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Advanced Interface
|
/// @name Advanced Interface
|
||||||
|
|
|
@ -138,10 +138,13 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||||
gtsam::DecisionTreeFactor* likelihood(
|
gtsam::DecisionTreeFactor* likelihood(
|
||||||
const gtsam::DiscreteValues& frontalValues) const;
|
const gtsam::DiscreteValues& frontalValues) const;
|
||||||
gtsam::DecisionTreeFactor* likelihood(size_t value) const;
|
gtsam::DecisionTreeFactor* likelihood(size_t value) const;
|
||||||
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
|
size_t sample(const gtsam::DiscreteValues& parentsValues,
|
||||||
size_t sample(size_t value) const;
|
std::mt19937_64 @rng = nullptr) const;
|
||||||
size_t sample() const;
|
size_t sample(size_t value,
|
||||||
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
|
std::mt19937_64 @rng = nullptr) const;
|
||||||
|
size_t sample(std::mt19937_64 @rng = nullptr) const;
|
||||||
|
void sampleInPlace(gtsam::DiscreteValues @parentsValues,
|
||||||
|
std::mt19937_64 @rng = nullptr) const;
|
||||||
size_t argmax(const gtsam::DiscreteValues& parentsValues) const;
|
size_t argmax(const gtsam::DiscreteValues& parentsValues) const;
|
||||||
|
|
||||||
// Markdown and HTML
|
// Markdown and HTML
|
||||||
|
@ -233,8 +236,11 @@ class DiscreteBayesNet {
|
||||||
double evaluate(const gtsam::DiscreteValues& values) const;
|
double evaluate(const gtsam::DiscreteValues& values) const;
|
||||||
double operator()(const gtsam::DiscreteValues& values) const;
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
|
|
||||||
gtsam::DiscreteValues sample() const;
|
gtsam::DiscreteValues sample(std::mt19937_64
|
||||||
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;
|
@rng = nullptr) const;
|
||||||
|
gtsam::DiscreteValues sample(gtsam::DiscreteValues given,
|
||||||
|
std::mt19937_64
|
||||||
|
@rng = nullptr) const;
|
||||||
|
|
||||||
string dot(
|
string dot(
|
||||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
|
|
@ -99,9 +99,9 @@ TEST(DiscreteBayesNet, Asia) {
|
||||||
|
|
||||||
// now sample from it
|
// now sample from it
|
||||||
DiscreteValues expectedSample{{Asia.first, 1}, {Dyspnea.first, 1},
|
DiscreteValues expectedSample{{Asia.first, 1}, {Dyspnea.first, 1},
|
||||||
{XRay.first, 1}, {Tuberculosis.first, 0},
|
{XRay.first, 0}, {Tuberculosis.first, 0},
|
||||||
{Smoking.first, 1}, {Either.first, 1},
|
{Smoking.first, 1}, {Either.first, 0},
|
||||||
{LungCancer.first, 1}, {Bronchitis.first, 0}};
|
{LungCancer.first, 0}, {Bronchitis.first, 1}};
|
||||||
SETDEBUG("DiscreteConditional::sample", false);
|
SETDEBUG("DiscreteConditional::sample", false);
|
||||||
auto actualSample = chordal2->sample();
|
auto actualSample = chordal2->sample();
|
||||||
EXPECT(assert_equal(expectedSample, actualSample));
|
EXPECT(assert_equal(expectedSample, actualSample));
|
||||||
|
|
|
@ -25,9 +25,6 @@
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
// In Wrappers we have no access to this so have a default ready
|
|
||||||
static std::mt19937_64 kRandomNumberGenerator(42);
|
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -191,7 +188,7 @@ HybridValues HybridBayesNet::sample(const HybridValues &given,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Sample a discrete assignment.
|
// Sample a discrete assignment.
|
||||||
const DiscreteValues assignment = dbn.sample(given.discrete());
|
const DiscreteValues assignment = dbn.sample(given.discrete(), rng);
|
||||||
// Select the continuous Bayes net corresponding to the assignment.
|
// Select the continuous Bayes net corresponding to the assignment.
|
||||||
GaussianBayesNet gbn = choose(assignment);
|
GaussianBayesNet gbn = choose(assignment);
|
||||||
// Sample from the Gaussian Bayes net.
|
// Sample from the Gaussian Bayes net.
|
||||||
|
|
|
@ -158,6 +158,8 @@ class HybridBayesNet {
|
||||||
gtsam::HybridValues optimize() const;
|
gtsam::HybridValues optimize() const;
|
||||||
gtsam::VectorValues optimize(const gtsam::DiscreteValues& assignment) const;
|
gtsam::VectorValues optimize(const gtsam::DiscreteValues& assignment) const;
|
||||||
|
|
||||||
|
gtsam::HybridValues sample(const gtsam::HybridValues& given, std::mt19937_64@ rng) const;
|
||||||
|
gtsam::HybridValues sample(std::mt19937_64@ rng) const;
|
||||||
gtsam::HybridValues sample(const gtsam::HybridValues& given) const;
|
gtsam::HybridValues sample(const gtsam::HybridValues& given) const;
|
||||||
gtsam::HybridValues sample() const;
|
gtsam::HybridValues sample() const;
|
||||||
|
|
||||||
|
|
|
@ -90,7 +90,7 @@ TEST(HybridBayesNet, EvaluatePureDiscrete) {
|
||||||
|
|
||||||
// sample
|
// sample
|
||||||
std::mt19937_64 rng(42);
|
std::mt19937_64 rng(42);
|
||||||
EXPECT(assert_equal(zero, bayesNet.sample(&rng)));
|
EXPECT(assert_equal(one, bayesNet.sample(&rng)));
|
||||||
EXPECT(assert_equal(one, bayesNet.sample(one, &rng)));
|
EXPECT(assert_equal(one, bayesNet.sample(one, &rng)));
|
||||||
EXPECT(assert_equal(zero, bayesNet.sample(zero, &rng)));
|
EXPECT(assert_equal(zero, bayesNet.sample(zero, &rng)));
|
||||||
|
|
||||||
|
@ -616,16 +616,16 @@ TEST(HybridBayesNet, Sampling) {
|
||||||
double discrete_sum =
|
double discrete_sum =
|
||||||
std::accumulate(discrete_samples.begin(), discrete_samples.end(),
|
std::accumulate(discrete_samples.begin(), discrete_samples.end(),
|
||||||
decltype(discrete_samples)::value_type(0));
|
decltype(discrete_samples)::value_type(0));
|
||||||
EXPECT_DOUBLES_EQUAL(0.477, discrete_sum / num_samples, 1e-9);
|
EXPECT_DOUBLES_EQUAL(0.519, discrete_sum / num_samples, 1e-9);
|
||||||
|
|
||||||
VectorValues expected;
|
VectorValues expected;
|
||||||
// regression for specific RNG seed
|
// regression for specific RNG seed
|
||||||
#if __APPLE__ || _WIN32
|
#if __APPLE__ || _WIN32
|
||||||
expected.insert({X(0), Vector1(-0.0131207162712)});
|
expected.insert({X(0), Vector1(0.0252479903896)});
|
||||||
expected.insert({X(1), Vector1(-0.499026377568)});
|
expected.insert({X(1), Vector1(-0.513637101911)});
|
||||||
#elif __linux__
|
#elif __linux__
|
||||||
expected.insert({X(0), Vector1(-0.00799425182219)});
|
expected.insert({X(0), Vector1(0.0165089744897)});
|
||||||
expected.insert({X(1), Vector1(-0.526463854268)});
|
expected.insert({X(1), Vector1(-0.454323399979)});
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
EXPECT(assert_equal(expected, average_continuous.scale(1.0 / num_samples)));
|
EXPECT(assert_equal(expected, average_continuous.scale(1.0 / num_samples)));
|
||||||
|
|
|
@ -34,7 +34,7 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
|
||||||
// In Wrappers we have no access to this so have a default ready
|
// In wrappers we can access std::mt19937_64 via gtsam.MT19937
|
||||||
static std::mt19937_64 kRandomNumberGenerator(42);
|
static std::mt19937_64 kRandomNumberGenerator(42);
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
|
@ -215,7 +215,7 @@ namespace gtsam {
|
||||||
* Sample from conditional, zero parent version
|
* Sample from conditional, zero parent version
|
||||||
* Example:
|
* Example:
|
||||||
* std::mt19937_64 rng(42);
|
* std::mt19937_64 rng(42);
|
||||||
* auto sample = gbn.sample(&rng);
|
* auto sample = gc.sample(&rng);
|
||||||
*/
|
*/
|
||||||
VectorValues sample(std::mt19937_64* rng) const;
|
VectorValues sample(std::mt19937_64* rng) const;
|
||||||
|
|
||||||
|
@ -224,7 +224,7 @@ namespace gtsam {
|
||||||
* Example:
|
* Example:
|
||||||
* std::mt19937_64 rng(42);
|
* std::mt19937_64 rng(42);
|
||||||
* VectorValues given = ...;
|
* VectorValues given = ...;
|
||||||
* auto sample = gbn.sample(given, &rng);
|
* auto sample = gc.sample(given, &rng);
|
||||||
*/
|
*/
|
||||||
VectorValues sample(const VectorValues& parentsValues,
|
VectorValues sample(const VectorValues& parentsValues,
|
||||||
std::mt19937_64* rng) const;
|
std::mt19937_64* rng) const;
|
||||||
|
|
|
@ -559,9 +559,12 @@ virtual class GaussianConditional : gtsam::JacobianFactor {
|
||||||
gtsam::JacobianFactor* likelihood(
|
gtsam::JacobianFactor* likelihood(
|
||||||
const gtsam::VectorValues& frontalValues) const;
|
const gtsam::VectorValues& frontalValues) const;
|
||||||
gtsam::JacobianFactor* likelihood(gtsam::Vector frontal) const;
|
gtsam::JacobianFactor* likelihood(gtsam::Vector frontal) const;
|
||||||
gtsam::VectorValues sample(const gtsam::VectorValues& parents) const;
|
|
||||||
|
gtsam::VectorValues sample(std::mt19937_64@ rng) const;
|
||||||
|
gtsam::VectorValues sample(const gtsam::VectorValues& parents, std::mt19937_64@ rng) const;
|
||||||
gtsam::VectorValues sample() const;
|
gtsam::VectorValues sample() const;
|
||||||
|
gtsam::VectorValues sample(const gtsam::VectorValues& parents) const;
|
||||||
|
|
||||||
// Advanced Interface
|
// Advanced Interface
|
||||||
gtsam::VectorValues solveOtherRHS(const gtsam::VectorValues& parents,
|
gtsam::VectorValues solveOtherRHS(const gtsam::VectorValues& parents,
|
||||||
const gtsam::VectorValues& rhs) const;
|
const gtsam::VectorValues& rhs) const;
|
||||||
|
|
|
@ -10,3 +10,15 @@
|
||||||
* with `PYBIND11_MAKE_OPAQUE` this allows the types to be modified with Python,
|
* with `PYBIND11_MAKE_OPAQUE` this allows the types to be modified with Python,
|
||||||
* and saves one copy operation.
|
* and saves one copy operation.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Custom Pybind11 module for the Mersenne-Twister PRNG object
|
||||||
|
* `std::mt19937_64`.
|
||||||
|
* This can be invoked with `gtsam.MT19937()` and passed
|
||||||
|
* wherever a rng pointer is expected.
|
||||||
|
*/
|
||||||
|
#include <random>
|
||||||
|
py::class_<std::mt19937_64>(m_, "MT19937")
|
||||||
|
.def(py::init<>())
|
||||||
|
.def(py::init<std::mt19937_64::result_type>())
|
||||||
|
.def("__call__", &std::mt19937_64::operator());
|
||||||
|
|
|
@ -33,6 +33,41 @@ XRay = (2, 2)
|
||||||
Dyspnea = (1, 2)
|
Dyspnea = (1, 2)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDiscreteConditional(GtsamTestCase):
|
||||||
|
"""Tests for Discrete Conditional"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.key = (0, 2)
|
||||||
|
self.parent = (1, 2)
|
||||||
|
self.parents = DiscreteKeys()
|
||||||
|
self.parents.push_back(self.parent)
|
||||||
|
|
||||||
|
def test_sample(self):
|
||||||
|
"""Tests to check sampling in DiscreteConditionals"""
|
||||||
|
rng = gtsam.MT19937(11)
|
||||||
|
niters = 1000
|
||||||
|
|
||||||
|
# Sample with only 1 variable
|
||||||
|
conditional = DiscreteConditional(self.key, "7/3")
|
||||||
|
# Sample multiple times and average to get mean
|
||||||
|
p = 0
|
||||||
|
for _ in range(niters):
|
||||||
|
p += conditional.sample(rng)
|
||||||
|
|
||||||
|
self.assertAlmostEqual(p / niters, 0.3, 1)
|
||||||
|
|
||||||
|
# Sample with variable and parent
|
||||||
|
conditional = DiscreteConditional(self.key, self.parents, "7/3 2/8")
|
||||||
|
# Sample multiple times and average to get mean
|
||||||
|
p = 0
|
||||||
|
parentValues = gtsam.DiscreteValues()
|
||||||
|
parentValues[self.parent[0]] = 1
|
||||||
|
for _ in range(niters):
|
||||||
|
p += conditional.sample(parentValues, rng)
|
||||||
|
|
||||||
|
self.assertAlmostEqual(p / niters, 0.8, 1)
|
||||||
|
|
||||||
|
|
||||||
class TestDiscreteBayesNet(GtsamTestCase):
|
class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
"""Tests for Discrete Bayes Nets."""
|
"""Tests for Discrete Bayes Nets."""
|
||||||
|
|
||||||
|
@ -85,10 +120,12 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
# solve
|
# solve
|
||||||
actualMPE = fg.optimize()
|
actualMPE = fg.optimize()
|
||||||
expectedMPE = DiscreteValues()
|
expectedMPE = DiscreteValues()
|
||||||
for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]:
|
for key in [
|
||||||
|
Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer,
|
||||||
|
Bronchitis
|
||||||
|
]:
|
||||||
expectedMPE[key[0]] = 0
|
expectedMPE[key[0]] = 0
|
||||||
self.assertEqual(list(actualMPE.items()),
|
self.assertEqual(list(actualMPE.items()), list(expectedMPE.items()))
|
||||||
list(expectedMPE.items()))
|
|
||||||
|
|
||||||
# Check value for MPE is the same
|
# Check value for MPE is the same
|
||||||
self.assertAlmostEqual(asia(actualMPE), fg(actualMPE))
|
self.assertAlmostEqual(asia(actualMPE), fg(actualMPE))
|
||||||
|
@ -104,8 +141,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
expectedMPE2[key[0]] = 0
|
expectedMPE2[key[0]] = 0
|
||||||
for key in [Asia, Dyspnea, Smoking, Bronchitis]:
|
for key in [Asia, Dyspnea, Smoking, Bronchitis]:
|
||||||
expectedMPE2[key[0]] = 1
|
expectedMPE2[key[0]] = 1
|
||||||
self.assertEqual(list(actualMPE2.items()),
|
self.assertEqual(list(actualMPE2.items()), list(expectedMPE2.items()))
|
||||||
list(expectedMPE2.items()))
|
|
||||||
|
|
||||||
# now sample from it
|
# now sample from it
|
||||||
chordal2 = fg.eliminateSequential(ordering)
|
chordal2 = fg.eliminateSequential(ordering)
|
||||||
|
@ -135,8 +171,9 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
# self.assertEqual(len(values), 5)
|
# self.assertEqual(len(values), 5)
|
||||||
|
|
||||||
for i in [0, 1, 2]:
|
for i in [0, 1, 2]:
|
||||||
self.assertAlmostEqual(fragment.at(i).logProbability(values),
|
self.assertAlmostEqual(
|
||||||
math.log(fragment.at(i).evaluate(values)))
|
fragment.at(i).logProbability(values),
|
||||||
|
math.log(fragment.at(i).evaluate(values)))
|
||||||
self.assertAlmostEqual(fragment.logProbability(values),
|
self.assertAlmostEqual(fragment.logProbability(values),
|
||||||
math.log(fragment.evaluate(values)))
|
math.log(fragment.evaluate(values)))
|
||||||
actual = fragment.sample(given)
|
actual = fragment.sample(given)
|
||||||
|
|
|
@ -23,11 +23,12 @@ _x_ = 11
|
||||||
_y_ = 22
|
_y_ = 22
|
||||||
_z_ = 33
|
_z_ = 33
|
||||||
|
|
||||||
|
I_1x1 = np.eye(1, dtype=float)
|
||||||
|
|
||||||
|
|
||||||
def smallBayesNet():
|
def smallBayesNet():
|
||||||
"""Create a small Bayes Net for testing"""
|
"""Create a small Bayes Net for testing"""
|
||||||
bayesNet = GaussianBayesNet()
|
bayesNet = GaussianBayesNet()
|
||||||
I_1x1 = np.eye(1, dtype=float)
|
|
||||||
bayesNet.push_back(GaussianConditional(_x_, [9.0], I_1x1, _y_, I_1x1))
|
bayesNet.push_back(GaussianConditional(_x_, [9.0], I_1x1, _y_, I_1x1))
|
||||||
bayesNet.push_back(GaussianConditional(_y_, [5.0], I_1x1))
|
bayesNet.push_back(GaussianConditional(_y_, [5.0], I_1x1))
|
||||||
return bayesNet
|
return bayesNet
|
||||||
|
@ -51,8 +52,9 @@ class TestGaussianBayesNet(GtsamTestCase):
|
||||||
values.insert(_x_, np.array([9.0]))
|
values.insert(_x_, np.array([9.0]))
|
||||||
values.insert(_y_, np.array([5.0]))
|
values.insert(_y_, np.array([5.0]))
|
||||||
for i in [0, 1]:
|
for i in [0, 1]:
|
||||||
self.assertAlmostEqual(bayesNet.at(i).logProbability(values),
|
self.assertAlmostEqual(
|
||||||
np.log(bayesNet.at(i).evaluate(values)))
|
bayesNet.at(i).logProbability(values),
|
||||||
|
np.log(bayesNet.at(i).evaluate(values)))
|
||||||
self.assertAlmostEqual(bayesNet.logProbability(values),
|
self.assertAlmostEqual(bayesNet.logProbability(values),
|
||||||
np.log(bayesNet.evaluate(values)))
|
np.log(bayesNet.evaluate(values)))
|
||||||
|
|
||||||
|
@ -66,6 +68,16 @@ class TestGaussianBayesNet(GtsamTestCase):
|
||||||
mean = bayesNet.optimize()
|
mean = bayesNet.optimize()
|
||||||
self.gtsamAssertEquals(sample, mean, tol=3.0)
|
self.gtsamAssertEquals(sample, mean, tol=3.0)
|
||||||
|
|
||||||
|
# Sample with rng
|
||||||
|
rng = gtsam.MT19937(42)
|
||||||
|
conditional = GaussianConditional(_x_, [9.0], I_1x1)
|
||||||
|
# Sample multiple times and average to get mean
|
||||||
|
val = 0
|
||||||
|
niters = 10000
|
||||||
|
for _ in range(niters):
|
||||||
|
val += conditional.sample(rng).at(_x_).item()
|
||||||
|
self.assertAlmostEqual(val / niters, 9.0, 1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -287,8 +287,8 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
print(f"P(mode=1; Z) = {marginals[1]}")
|
print(f"P(mode=1; Z) = {marginals[1]}")
|
||||||
|
|
||||||
# Check that the estimate is close to the true value.
|
# Check that the estimate is close to the true value.
|
||||||
self.assertAlmostEqual(marginals[0], 0.23, delta=0.01)
|
self.assertAlmostEqual(marginals[0], 0.219, delta=0.01)
|
||||||
self.assertAlmostEqual(marginals[1], 0.77, delta=0.01)
|
self.assertAlmostEqual(marginals[1], 0.781, delta=0.01)
|
||||||
|
|
||||||
# Convert to factor graph using measurements.
|
# Convert to factor graph using measurements.
|
||||||
fg = bayesNet.toFactorGraph(measurements)
|
fg = bayesNet.toFactorGraph(measurements)
|
||||||
|
|
Loading…
Reference in New Issue