Merge pull request #2136 from borglab/wrap/rng

release/4.3a0
Varun Agrawal 2025-05-16 09:56:30 -04:00 committed by GitHub
commit b95a352ccb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 164 additions and 67 deletions

View File

@ -201,7 +201,7 @@ jobs:
if: runner.os == 'Linux' if: runner.os == 'Linux'
uses: pierotofy/set-swap-space@master uses: pierotofy/set-swap-space@master
with: with:
swap-size-gb: 12 swap-size-gb: 8
- name: Build & Test - name: Build & Test
run: | run: |

View File

@ -50,12 +50,13 @@ double DiscreteBayesNet::evaluate(const DiscreteValues& values) const {
} }
/* ************************************************************************* */ /* ************************************************************************* */
DiscreteValues DiscreteBayesNet::sample() const { DiscreteValues DiscreteBayesNet::sample(std::mt19937_64* rng) const {
DiscreteValues result; DiscreteValues result;
return sample(result); return sample(result, rng);
} }
DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const { DiscreteValues DiscreteBayesNet::sample(DiscreteValues result,
std::mt19937_64* rng) const {
// sample each node in turn in topological sort order (parents first) // sample each node in turn in topological sort order (parents first)
for (auto it = std::make_reverse_iterator(end()); for (auto it = std::make_reverse_iterator(end());
it != std::make_reverse_iterator(begin()); ++it) { it != std::make_reverse_iterator(begin()); ++it) {
@ -63,7 +64,7 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
// Sample the conditional only if value for j not already in result // Sample the conditional only if value for j not already in result
const Key j = conditional->firstFrontalKey(); const Key j = conditional->firstFrontalKey();
if (result.count(j) == 0) { if (result.count(j) == 0) {
conditional->sampleInPlace(&result); conditional->sampleInPlace(&result, rng);
} }
} }
return result; return result;

View File

@ -112,7 +112,7 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
* *
* @return a sampled value for all variables. * @return a sampled value for all variables.
*/ */
DiscreteValues sample() const; DiscreteValues sample(std::mt19937_64* rng = nullptr) const;
/** /**
* @brief do ancestral sampling, given certain variables. * @brief do ancestral sampling, given certain variables.
@ -122,7 +122,8 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
* *
* @return given values extended with sampled value for all other variables. * @return given values extended with sampled value for all other variables.
*/ */
DiscreteValues sample(DiscreteValues given) const; DiscreteValues sample(DiscreteValues given,
std::mt19937_64* rng = nullptr) const;
/** /**
* @brief Prune the Bayes net * @brief Prune the Bayes net

View File

@ -268,7 +268,8 @@ size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const {
} }
/* ************************************************************************** */ /* ************************************************************************** */
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const { void DiscreteConditional::sampleInPlace(DiscreteValues* values,
std::mt19937_64* rng) const {
// throw if more than one frontal: // throw if more than one frontal:
if (nrFrontals() != 1) { if (nrFrontals() != 1) {
throw std::invalid_argument( throw std::invalid_argument(
@ -281,14 +282,13 @@ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
throw std::invalid_argument( throw std::invalid_argument(
"DiscreteConditional::sampleInPlace: values already contains j"); "DiscreteConditional::sampleInPlace: values already contains j");
} }
size_t sampled = sample(*values); // Sample variable given parents size_t sampled = sample(*values, rng); // Sample variable given parents
(*values)[j] = sampled; // store result in partial solution (*values)[j] = sampled; // store result in partial solution
} }
/* ************************************************************************** */ /* ************************************************************************** */
size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const { size_t DiscreteConditional::sample(const DiscreteValues& parentsValues,
static mt19937 rng(2); // random number generator std::mt19937_64* rng) const {
// Get the correct conditional distribution // Get the correct conditional distribution
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues) ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
@ -309,28 +309,33 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
return value; // shortcut exit return value; // shortcut exit
} }
} }
// Check if rng is nullptr, then assign default
rng = (rng == nullptr) ? &kRandomNumberGenerator : rng;
std::discrete_distribution<size_t> distribution(p.begin(), p.end()); std::discrete_distribution<size_t> distribution(p.begin(), p.end());
return distribution(rng); return distribution(*rng);
} }
/* ************************************************************************** */ /* ************************************************************************** */
size_t DiscreteConditional::sample(size_t parent_value) const { size_t DiscreteConditional::sample(size_t parent_value,
std::mt19937_64* rng) const {
if (nrParents() != 1) if (nrParents() != 1)
throw std::invalid_argument( throw std::invalid_argument(
"Single value sample() can only be invoked on single-parent " "Single value sample() can only be invoked on single-parent "
"conditional"); "conditional");
DiscreteValues values; DiscreteValues values;
values.emplace(keys_.back(), parent_value); values.emplace(keys_.back(), parent_value);
return sample(values); return sample(values, rng);
} }
/* ************************************************************************** */ /* ************************************************************************** */
size_t DiscreteConditional::sample() const { size_t DiscreteConditional::sample(std::mt19937_64* rng) const {
if (nrParents() != 0) if (nrParents() != 0)
throw std::invalid_argument( throw std::invalid_argument(
"sample() can only be invoked on no-parent prior"); "sample() can only be invoked on no-parent prior");
DiscreteValues values; DiscreteValues values;
return sample(values); return sample(values, rng);
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -23,9 +23,13 @@
#include <gtsam/inference/Conditional-inst.h> #include <gtsam/inference/Conditional-inst.h>
#include <memory> #include <memory>
#include <random> // for std::mt19937_64
#include <string> #include <string>
#include <vector> #include <vector>
// In wrappers we can access std::mt19937_64 via gtsam.MT19937
static std::mt19937_64 kRandomNumberGenerator(42);
namespace gtsam { namespace gtsam {
/** /**
@ -195,17 +199,29 @@ class GTSAM_EXPORT DiscreteConditional
DecisionTreeFactor::shared_ptr likelihood(size_t frontal) const; DecisionTreeFactor::shared_ptr likelihood(size_t frontal) const;
/** /**
* sample * Sample from conditional, given missing variables
* Example:
* std::mt19937_64 rng(42);
* DiscreteValues given = ...;
* size_t sample = dc.sample(given, &rng);
*
* @param parentsValues Known values of the parents * @param parentsValues Known values of the parents
* @param rng Pseudo-Random Number Generator.
* @return sample from conditional * @return sample from conditional
*/ */
virtual size_t sample(const DiscreteValues& parentsValues) const; virtual size_t sample(const DiscreteValues& parentsValues,
std::mt19937_64* rng = nullptr) const;
/// Single parent version. /// Single parent version.
size_t sample(size_t parent_value) const; size_t sample(size_t parent_value, std::mt19937_64* rng = nullptr) const;
/// Zero parent version. /**
size_t sample() const; * Sample from conditional, zero parent version
* Example:
* std::mt19937_64 rng(42);
* auto sample = dc.sample(&rng);
*/
size_t sample(std::mt19937_64* rng = nullptr) const;
/** /**
* @brief Return assignment for single frontal variable that maximizes value. * @brief Return assignment for single frontal variable that maximizes value.
@ -227,8 +243,9 @@ class GTSAM_EXPORT DiscreteConditional
/// @name Advanced Interface /// @name Advanced Interface
/// @{ /// @{
/// sample in place, stores result in partial solution /// Sample in place with optional PRNG, stores result in partial solution
void sampleInPlace(DiscreteValues* parentsValues) const; void sampleInPlace(DiscreteValues* parentsValues,
std::mt19937_64* rng = nullptr) const;
/// Return all assignments for frontal variables. /// Return all assignments for frontal variables.
std::vector<DiscreteValues> frontalAssignments() const; std::vector<DiscreteValues> frontalAssignments() const;

View File

@ -144,9 +144,8 @@ void TableDistribution::prune(size_t maxNrAssignments) {
} }
/* ****************************************************************************/ /* ****************************************************************************/
size_t TableDistribution::sample(const DiscreteValues& parentsValues) const { size_t TableDistribution::sample(const DiscreteValues& parentsValues,
static mt19937 rng(2); // random number generator std::mt19937_64* rng) const {
DiscreteKeys parentsKeys; DiscreteKeys parentsKeys;
for (auto&& [key, _] : parentsValues) { for (auto&& [key, _] : parentsValues) {
parentsKeys.push_back({key, table_.cardinality(key)}); parentsKeys.push_back({key, table_.cardinality(key)});
@ -172,8 +171,12 @@ size_t TableDistribution::sample(const DiscreteValues& parentsValues) const {
return value; // shortcut exit return value; // shortcut exit
} }
} }
// Check if rng is nullptr, then assign default
rng = (rng == nullptr) ? &kRandomNumberGenerator : rng;
std::discrete_distribution<size_t> distribution(p.begin(), p.end()); std::discrete_distribution<size_t> distribution(p.begin(), p.end());
return distribution(rng); return distribution(*rng);
} }
} // namespace gtsam } // namespace gtsam

View File

@ -125,7 +125,6 @@ class GTSAM_EXPORT TableDistribution : public DiscreteConditional {
/// Create new factor by maximizing over all values with the same separator. /// Create new factor by maximizing over all values with the same separator.
DiscreteFactor::shared_ptr max(const Ordering& keys) const override; DiscreteFactor::shared_ptr max(const Ordering& keys) const override;
/// Multiply by scalar s /// Multiply by scalar s
DiscreteFactor::shared_ptr operator*(double s) const override; DiscreteFactor::shared_ptr operator*(double s) const override;
@ -143,9 +142,11 @@ class GTSAM_EXPORT TableDistribution : public DiscreteConditional {
/** /**
* sample * sample
* @param parentsValues Known values of the parents * @param parentsValues Known values of the parents
* @param rng Pseudo random number generator
* @return sample from conditional * @return sample from conditional
*/ */
virtual size_t sample(const DiscreteValues& parentsValues) const override; virtual size_t sample(const DiscreteValues& parentsValues,
std::mt19937_64* rng = nullptr) const override;
/// @} /// @}
/// @name Advanced Interface /// @name Advanced Interface

View File

@ -138,10 +138,13 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
gtsam::DecisionTreeFactor* likelihood( gtsam::DecisionTreeFactor* likelihood(
const gtsam::DiscreteValues& frontalValues) const; const gtsam::DiscreteValues& frontalValues) const;
gtsam::DecisionTreeFactor* likelihood(size_t value) const; gtsam::DecisionTreeFactor* likelihood(size_t value) const;
size_t sample(const gtsam::DiscreteValues& parentsValues) const; size_t sample(const gtsam::DiscreteValues& parentsValues,
size_t sample(size_t value) const; std::mt19937_64 @rng = nullptr) const;
size_t sample() const; size_t sample(size_t value,
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const; std::mt19937_64 @rng = nullptr) const;
size_t sample(std::mt19937_64 @rng = nullptr) const;
void sampleInPlace(gtsam::DiscreteValues @parentsValues,
std::mt19937_64 @rng = nullptr) const;
size_t argmax(const gtsam::DiscreteValues& parentsValues) const; size_t argmax(const gtsam::DiscreteValues& parentsValues) const;
// Markdown and HTML // Markdown and HTML
@ -233,8 +236,11 @@ class DiscreteBayesNet {
double evaluate(const gtsam::DiscreteValues& values) const; double evaluate(const gtsam::DiscreteValues& values) const;
double operator()(const gtsam::DiscreteValues& values) const; double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues sample() const; gtsam::DiscreteValues sample(std::mt19937_64
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const; @rng = nullptr) const;
gtsam::DiscreteValues sample(gtsam::DiscreteValues given,
std::mt19937_64
@rng = nullptr) const;
string dot( string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,

View File

@ -99,9 +99,9 @@ TEST(DiscreteBayesNet, Asia) {
// now sample from it // now sample from it
DiscreteValues expectedSample{{Asia.first, 1}, {Dyspnea.first, 1}, DiscreteValues expectedSample{{Asia.first, 1}, {Dyspnea.first, 1},
{XRay.first, 1}, {Tuberculosis.first, 0}, {XRay.first, 0}, {Tuberculosis.first, 0},
{Smoking.first, 1}, {Either.first, 1}, {Smoking.first, 1}, {Either.first, 0},
{LungCancer.first, 1}, {Bronchitis.first, 0}}; {LungCancer.first, 0}, {Bronchitis.first, 1}};
SETDEBUG("DiscreteConditional::sample", false); SETDEBUG("DiscreteConditional::sample", false);
auto actualSample = chordal2->sample(); auto actualSample = chordal2->sample();
EXPECT(assert_equal(expectedSample, actualSample)); EXPECT(assert_equal(expectedSample, actualSample));

View File

@ -25,9 +25,6 @@
#include <memory> #include <memory>
// In Wrappers we have no access to this so have a default ready
static std::mt19937_64 kRandomNumberGenerator(42);
namespace gtsam { namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
@ -191,7 +188,7 @@ HybridValues HybridBayesNet::sample(const HybridValues &given,
} }
} }
// Sample a discrete assignment. // Sample a discrete assignment.
const DiscreteValues assignment = dbn.sample(given.discrete()); const DiscreteValues assignment = dbn.sample(given.discrete(), rng);
// Select the continuous Bayes net corresponding to the assignment. // Select the continuous Bayes net corresponding to the assignment.
GaussianBayesNet gbn = choose(assignment); GaussianBayesNet gbn = choose(assignment);
// Sample from the Gaussian Bayes net. // Sample from the Gaussian Bayes net.

View File

@ -158,6 +158,8 @@ class HybridBayesNet {
gtsam::HybridValues optimize() const; gtsam::HybridValues optimize() const;
gtsam::VectorValues optimize(const gtsam::DiscreteValues& assignment) const; gtsam::VectorValues optimize(const gtsam::DiscreteValues& assignment) const;
gtsam::HybridValues sample(const gtsam::HybridValues& given, std::mt19937_64@ rng) const;
gtsam::HybridValues sample(std::mt19937_64@ rng) const;
gtsam::HybridValues sample(const gtsam::HybridValues& given) const; gtsam::HybridValues sample(const gtsam::HybridValues& given) const;
gtsam::HybridValues sample() const; gtsam::HybridValues sample() const;

View File

@ -90,7 +90,7 @@ TEST(HybridBayesNet, EvaluatePureDiscrete) {
// sample // sample
std::mt19937_64 rng(42); std::mt19937_64 rng(42);
EXPECT(assert_equal(zero, bayesNet.sample(&rng))); EXPECT(assert_equal(one, bayesNet.sample(&rng)));
EXPECT(assert_equal(one, bayesNet.sample(one, &rng))); EXPECT(assert_equal(one, bayesNet.sample(one, &rng)));
EXPECT(assert_equal(zero, bayesNet.sample(zero, &rng))); EXPECT(assert_equal(zero, bayesNet.sample(zero, &rng)));
@ -616,16 +616,16 @@ TEST(HybridBayesNet, Sampling) {
double discrete_sum = double discrete_sum =
std::accumulate(discrete_samples.begin(), discrete_samples.end(), std::accumulate(discrete_samples.begin(), discrete_samples.end(),
decltype(discrete_samples)::value_type(0)); decltype(discrete_samples)::value_type(0));
EXPECT_DOUBLES_EQUAL(0.477, discrete_sum / num_samples, 1e-9); EXPECT_DOUBLES_EQUAL(0.519, discrete_sum / num_samples, 1e-9);
VectorValues expected; VectorValues expected;
// regression for specific RNG seed // regression for specific RNG seed
#if __APPLE__ || _WIN32 #if __APPLE__ || _WIN32
expected.insert({X(0), Vector1(-0.0131207162712)}); expected.insert({X(0), Vector1(0.0252479903896)});
expected.insert({X(1), Vector1(-0.499026377568)}); expected.insert({X(1), Vector1(-0.513637101911)});
#elif __linux__ #elif __linux__
expected.insert({X(0), Vector1(-0.00799425182219)}); expected.insert({X(0), Vector1(0.0165089744897)});
expected.insert({X(1), Vector1(-0.526463854268)}); expected.insert({X(1), Vector1(-0.454323399979)});
#endif #endif
EXPECT(assert_equal(expected, average_continuous.scale(1.0 / num_samples))); EXPECT(assert_equal(expected, average_continuous.scale(1.0 / num_samples)));

View File

@ -34,7 +34,7 @@
#include <string> #include <string>
#include <cmath> #include <cmath>
// In Wrappers we have no access to this so have a default ready // In wrappers we can access std::mt19937_64 via gtsam.MT19937
static std::mt19937_64 kRandomNumberGenerator(42); static std::mt19937_64 kRandomNumberGenerator(42);
using namespace std; using namespace std;

View File

@ -215,7 +215,7 @@ namespace gtsam {
* Sample from conditional, zero parent version * Sample from conditional, zero parent version
* Example: * Example:
* std::mt19937_64 rng(42); * std::mt19937_64 rng(42);
* auto sample = gbn.sample(&rng); * auto sample = gc.sample(&rng);
*/ */
VectorValues sample(std::mt19937_64* rng) const; VectorValues sample(std::mt19937_64* rng) const;
@ -224,7 +224,7 @@ namespace gtsam {
* Example: * Example:
* std::mt19937_64 rng(42); * std::mt19937_64 rng(42);
* VectorValues given = ...; * VectorValues given = ...;
* auto sample = gbn.sample(given, &rng); * auto sample = gc.sample(given, &rng);
*/ */
VectorValues sample(const VectorValues& parentsValues, VectorValues sample(const VectorValues& parentsValues,
std::mt19937_64* rng) const; std::mt19937_64* rng) const;

View File

@ -559,9 +559,12 @@ virtual class GaussianConditional : gtsam::JacobianFactor {
gtsam::JacobianFactor* likelihood( gtsam::JacobianFactor* likelihood(
const gtsam::VectorValues& frontalValues) const; const gtsam::VectorValues& frontalValues) const;
gtsam::JacobianFactor* likelihood(gtsam::Vector frontal) const; gtsam::JacobianFactor* likelihood(gtsam::Vector frontal) const;
gtsam::VectorValues sample(const gtsam::VectorValues& parents) const;
gtsam::VectorValues sample(std::mt19937_64@ rng) const;
gtsam::VectorValues sample(const gtsam::VectorValues& parents, std::mt19937_64@ rng) const;
gtsam::VectorValues sample() const; gtsam::VectorValues sample() const;
gtsam::VectorValues sample(const gtsam::VectorValues& parents) const;
// Advanced Interface // Advanced Interface
gtsam::VectorValues solveOtherRHS(const gtsam::VectorValues& parents, gtsam::VectorValues solveOtherRHS(const gtsam::VectorValues& parents,
const gtsam::VectorValues& rhs) const; const gtsam::VectorValues& rhs) const;

View File

@ -10,3 +10,15 @@
* with `PYBIND11_MAKE_OPAQUE` this allows the types to be modified with Python, * with `PYBIND11_MAKE_OPAQUE` this allows the types to be modified with Python,
* and saves one copy operation. * and saves one copy operation.
*/ */
/*
* Custom Pybind11 module for the Mersenne-Twister PRNG object
* `std::mt19937_64`.
* This can be invoked with `gtsam.MT19937()` and passed
* wherever a rng pointer is expected.
*/
#include <random>
py::class_<std::mt19937_64>(m_, "MT19937")
.def(py::init<>())
.def(py::init<std::mt19937_64::result_type>())
.def("__call__", &std::mt19937_64::operator());

View File

@ -33,6 +33,41 @@ XRay = (2, 2)
Dyspnea = (1, 2) Dyspnea = (1, 2)
class TestDiscreteConditional(GtsamTestCase):
"""Tests for Discrete Conditional"""
def setUp(self):
self.key = (0, 2)
self.parent = (1, 2)
self.parents = DiscreteKeys()
self.parents.push_back(self.parent)
def test_sample(self):
"""Tests to check sampling in DiscreteConditionals"""
rng = gtsam.MT19937(11)
niters = 1000
# Sample with only 1 variable
conditional = DiscreteConditional(self.key, "7/3")
# Sample multiple times and average to get mean
p = 0
for _ in range(niters):
p += conditional.sample(rng)
self.assertAlmostEqual(p / niters, 0.3, 1)
# Sample with variable and parent
conditional = DiscreteConditional(self.key, self.parents, "7/3 2/8")
# Sample multiple times and average to get mean
p = 0
parentValues = gtsam.DiscreteValues()
parentValues[self.parent[0]] = 1
for _ in range(niters):
p += conditional.sample(parentValues, rng)
self.assertAlmostEqual(p / niters, 0.8, 1)
class TestDiscreteBayesNet(GtsamTestCase): class TestDiscreteBayesNet(GtsamTestCase):
"""Tests for Discrete Bayes Nets.""" """Tests for Discrete Bayes Nets."""
@ -85,10 +120,12 @@ class TestDiscreteBayesNet(GtsamTestCase):
# solve # solve
actualMPE = fg.optimize() actualMPE = fg.optimize()
expectedMPE = DiscreteValues() expectedMPE = DiscreteValues()
for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]: for key in [
Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer,
Bronchitis
]:
expectedMPE[key[0]] = 0 expectedMPE[key[0]] = 0
self.assertEqual(list(actualMPE.items()), self.assertEqual(list(actualMPE.items()), list(expectedMPE.items()))
list(expectedMPE.items()))
# Check value for MPE is the same # Check value for MPE is the same
self.assertAlmostEqual(asia(actualMPE), fg(actualMPE)) self.assertAlmostEqual(asia(actualMPE), fg(actualMPE))
@ -104,8 +141,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
expectedMPE2[key[0]] = 0 expectedMPE2[key[0]] = 0
for key in [Asia, Dyspnea, Smoking, Bronchitis]: for key in [Asia, Dyspnea, Smoking, Bronchitis]:
expectedMPE2[key[0]] = 1 expectedMPE2[key[0]] = 1
self.assertEqual(list(actualMPE2.items()), self.assertEqual(list(actualMPE2.items()), list(expectedMPE2.items()))
list(expectedMPE2.items()))
# now sample from it # now sample from it
chordal2 = fg.eliminateSequential(ordering) chordal2 = fg.eliminateSequential(ordering)
@ -135,8 +171,9 @@ class TestDiscreteBayesNet(GtsamTestCase):
# self.assertEqual(len(values), 5) # self.assertEqual(len(values), 5)
for i in [0, 1, 2]: for i in [0, 1, 2]:
self.assertAlmostEqual(fragment.at(i).logProbability(values), self.assertAlmostEqual(
math.log(fragment.at(i).evaluate(values))) fragment.at(i).logProbability(values),
math.log(fragment.at(i).evaluate(values)))
self.assertAlmostEqual(fragment.logProbability(values), self.assertAlmostEqual(fragment.logProbability(values),
math.log(fragment.evaluate(values))) math.log(fragment.evaluate(values)))
actual = fragment.sample(given) actual = fragment.sample(given)

View File

@ -23,11 +23,12 @@ _x_ = 11
_y_ = 22 _y_ = 22
_z_ = 33 _z_ = 33
I_1x1 = np.eye(1, dtype=float)
def smallBayesNet(): def smallBayesNet():
"""Create a small Bayes Net for testing""" """Create a small Bayes Net for testing"""
bayesNet = GaussianBayesNet() bayesNet = GaussianBayesNet()
I_1x1 = np.eye(1, dtype=float)
bayesNet.push_back(GaussianConditional(_x_, [9.0], I_1x1, _y_, I_1x1)) bayesNet.push_back(GaussianConditional(_x_, [9.0], I_1x1, _y_, I_1x1))
bayesNet.push_back(GaussianConditional(_y_, [5.0], I_1x1)) bayesNet.push_back(GaussianConditional(_y_, [5.0], I_1x1))
return bayesNet return bayesNet
@ -51,8 +52,9 @@ class TestGaussianBayesNet(GtsamTestCase):
values.insert(_x_, np.array([9.0])) values.insert(_x_, np.array([9.0]))
values.insert(_y_, np.array([5.0])) values.insert(_y_, np.array([5.0]))
for i in [0, 1]: for i in [0, 1]:
self.assertAlmostEqual(bayesNet.at(i).logProbability(values), self.assertAlmostEqual(
np.log(bayesNet.at(i).evaluate(values))) bayesNet.at(i).logProbability(values),
np.log(bayesNet.at(i).evaluate(values)))
self.assertAlmostEqual(bayesNet.logProbability(values), self.assertAlmostEqual(bayesNet.logProbability(values),
np.log(bayesNet.evaluate(values))) np.log(bayesNet.evaluate(values)))
@ -66,6 +68,16 @@ class TestGaussianBayesNet(GtsamTestCase):
mean = bayesNet.optimize() mean = bayesNet.optimize()
self.gtsamAssertEquals(sample, mean, tol=3.0) self.gtsamAssertEquals(sample, mean, tol=3.0)
# Sample with rng
rng = gtsam.MT19937(42)
conditional = GaussianConditional(_x_, [9.0], I_1x1)
# Sample multiple times and average to get mean
val = 0
niters = 10000
for _ in range(niters):
val += conditional.sample(rng).at(_x_).item()
self.assertAlmostEqual(val / niters, 9.0, 1)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -287,8 +287,8 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
print(f"P(mode=1; Z) = {marginals[1]}") print(f"P(mode=1; Z) = {marginals[1]}")
# Check that the estimate is close to the true value. # Check that the estimate is close to the true value.
self.assertAlmostEqual(marginals[0], 0.23, delta=0.01) self.assertAlmostEqual(marginals[0], 0.219, delta=0.01)
self.assertAlmostEqual(marginals[1], 0.77, delta=0.01) self.assertAlmostEqual(marginals[1], 0.781, delta=0.01)
# Convert to factor graph using measurements. # Convert to factor graph using measurements.
fg = bayesNet.toFactorGraph(measurements) fg = bayesNet.toFactorGraph(measurements)