Merge pull request #2136 from borglab/wrap/rng
commit
b95a352ccb
|
@ -201,7 +201,7 @@ jobs:
|
|||
if: runner.os == 'Linux'
|
||||
uses: pierotofy/set-swap-space@master
|
||||
with:
|
||||
swap-size-gb: 12
|
||||
swap-size-gb: 8
|
||||
|
||||
- name: Build & Test
|
||||
run: |
|
||||
|
|
|
@ -50,12 +50,13 @@ double DiscreteBayesNet::evaluate(const DiscreteValues& values) const {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteValues DiscreteBayesNet::sample() const {
|
||||
DiscreteValues DiscreteBayesNet::sample(std::mt19937_64* rng) const {
|
||||
DiscreteValues result;
|
||||
return sample(result);
|
||||
return sample(result, rng);
|
||||
}
|
||||
|
||||
DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
|
||||
DiscreteValues DiscreteBayesNet::sample(DiscreteValues result,
|
||||
std::mt19937_64* rng) const {
|
||||
// sample each node in turn in topological sort order (parents first)
|
||||
for (auto it = std::make_reverse_iterator(end());
|
||||
it != std::make_reverse_iterator(begin()); ++it) {
|
||||
|
@ -63,7 +64,7 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
|
|||
// Sample the conditional only if value for j not already in result
|
||||
const Key j = conditional->firstFrontalKey();
|
||||
if (result.count(j) == 0) {
|
||||
conditional->sampleInPlace(&result);
|
||||
conditional->sampleInPlace(&result, rng);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
|
|
|
@ -112,7 +112,7 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
|
|||
*
|
||||
* @return a sampled value for all variables.
|
||||
*/
|
||||
DiscreteValues sample() const;
|
||||
DiscreteValues sample(std::mt19937_64* rng = nullptr) const;
|
||||
|
||||
/**
|
||||
* @brief do ancestral sampling, given certain variables.
|
||||
|
@ -122,7 +122,8 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
|
|||
*
|
||||
* @return given values extended with sampled value for all other variables.
|
||||
*/
|
||||
DiscreteValues sample(DiscreteValues given) const;
|
||||
DiscreteValues sample(DiscreteValues given,
|
||||
std::mt19937_64* rng = nullptr) const;
|
||||
|
||||
/**
|
||||
* @brief Prune the Bayes net
|
||||
|
|
|
@ -268,7 +268,8 @@ size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const {
|
|||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
|
||||
void DiscreteConditional::sampleInPlace(DiscreteValues* values,
|
||||
std::mt19937_64* rng) const {
|
||||
// throw if more than one frontal:
|
||||
if (nrFrontals() != 1) {
|
||||
throw std::invalid_argument(
|
||||
|
@ -281,14 +282,13 @@ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
|
|||
throw std::invalid_argument(
|
||||
"DiscreteConditional::sampleInPlace: values already contains j");
|
||||
}
|
||||
size_t sampled = sample(*values); // Sample variable given parents
|
||||
(*values)[j] = sampled; // store result in partial solution
|
||||
size_t sampled = sample(*values, rng); // Sample variable given parents
|
||||
(*values)[j] = sampled; // store result in partial solution
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
|
||||
static mt19937 rng(2); // random number generator
|
||||
|
||||
size_t DiscreteConditional::sample(const DiscreteValues& parentsValues,
|
||||
std::mt19937_64* rng) const {
|
||||
// Get the correct conditional distribution
|
||||
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
||||
|
||||
|
@ -309,28 +309,33 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
|
|||
return value; // shortcut exit
|
||||
}
|
||||
}
|
||||
|
||||
// Check if rng is nullptr, then assign default
|
||||
rng = (rng == nullptr) ? &kRandomNumberGenerator : rng;
|
||||
|
||||
std::discrete_distribution<size_t> distribution(p.begin(), p.end());
|
||||
return distribution(rng);
|
||||
return distribution(*rng);
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
size_t DiscreteConditional::sample(size_t parent_value) const {
|
||||
size_t DiscreteConditional::sample(size_t parent_value,
|
||||
std::mt19937_64* rng) const {
|
||||
if (nrParents() != 1)
|
||||
throw std::invalid_argument(
|
||||
"Single value sample() can only be invoked on single-parent "
|
||||
"conditional");
|
||||
DiscreteValues values;
|
||||
values.emplace(keys_.back(), parent_value);
|
||||
return sample(values);
|
||||
return sample(values, rng);
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
size_t DiscreteConditional::sample() const {
|
||||
size_t DiscreteConditional::sample(std::mt19937_64* rng) const {
|
||||
if (nrParents() != 0)
|
||||
throw std::invalid_argument(
|
||||
"sample() can only be invoked on no-parent prior");
|
||||
DiscreteValues values;
|
||||
return sample(values);
|
||||
return sample(values, rng);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -23,9 +23,13 @@
|
|||
#include <gtsam/inference/Conditional-inst.h>
|
||||
|
||||
#include <memory>
|
||||
#include <random> // for std::mt19937_64
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
// In wrappers we can access std::mt19937_64 via gtsam.MT19937
|
||||
static std::mt19937_64 kRandomNumberGenerator(42);
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/**
|
||||
|
@ -195,17 +199,29 @@ class GTSAM_EXPORT DiscreteConditional
|
|||
DecisionTreeFactor::shared_ptr likelihood(size_t frontal) const;
|
||||
|
||||
/**
|
||||
* sample
|
||||
* Sample from conditional, given missing variables
|
||||
* Example:
|
||||
* std::mt19937_64 rng(42);
|
||||
* DiscreteValues given = ...;
|
||||
* size_t sample = dc.sample(given, &rng);
|
||||
*
|
||||
* @param parentsValues Known values of the parents
|
||||
* @param rng Pseudo-Random Number Generator.
|
||||
* @return sample from conditional
|
||||
*/
|
||||
virtual size_t sample(const DiscreteValues& parentsValues) const;
|
||||
virtual size_t sample(const DiscreteValues& parentsValues,
|
||||
std::mt19937_64* rng = nullptr) const;
|
||||
|
||||
/// Single parent version.
|
||||
size_t sample(size_t parent_value) const;
|
||||
size_t sample(size_t parent_value, std::mt19937_64* rng = nullptr) const;
|
||||
|
||||
/// Zero parent version.
|
||||
size_t sample() const;
|
||||
/**
|
||||
* Sample from conditional, zero parent version
|
||||
* Example:
|
||||
* std::mt19937_64 rng(42);
|
||||
* auto sample = dc.sample(&rng);
|
||||
*/
|
||||
size_t sample(std::mt19937_64* rng = nullptr) const;
|
||||
|
||||
/**
|
||||
* @brief Return assignment for single frontal variable that maximizes value.
|
||||
|
@ -227,8 +243,9 @@ class GTSAM_EXPORT DiscreteConditional
|
|||
/// @name Advanced Interface
|
||||
/// @{
|
||||
|
||||
/// sample in place, stores result in partial solution
|
||||
void sampleInPlace(DiscreteValues* parentsValues) const;
|
||||
/// Sample in place with optional PRNG, stores result in partial solution
|
||||
void sampleInPlace(DiscreteValues* parentsValues,
|
||||
std::mt19937_64* rng = nullptr) const;
|
||||
|
||||
/// Return all assignments for frontal variables.
|
||||
std::vector<DiscreteValues> frontalAssignments() const;
|
||||
|
|
|
@ -144,9 +144,8 @@ void TableDistribution::prune(size_t maxNrAssignments) {
|
|||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
size_t TableDistribution::sample(const DiscreteValues& parentsValues) const {
|
||||
static mt19937 rng(2); // random number generator
|
||||
|
||||
size_t TableDistribution::sample(const DiscreteValues& parentsValues,
|
||||
std::mt19937_64* rng) const {
|
||||
DiscreteKeys parentsKeys;
|
||||
for (auto&& [key, _] : parentsValues) {
|
||||
parentsKeys.push_back({key, table_.cardinality(key)});
|
||||
|
@ -172,8 +171,12 @@ size_t TableDistribution::sample(const DiscreteValues& parentsValues) const {
|
|||
return value; // shortcut exit
|
||||
}
|
||||
}
|
||||
|
||||
// Check if rng is nullptr, then assign default
|
||||
rng = (rng == nullptr) ? &kRandomNumberGenerator : rng;
|
||||
|
||||
std::discrete_distribution<size_t> distribution(p.begin(), p.end());
|
||||
return distribution(rng);
|
||||
return distribution(*rng);
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -125,7 +125,6 @@ class GTSAM_EXPORT TableDistribution : public DiscreteConditional {
|
|||
/// Create new factor by maximizing over all values with the same separator.
|
||||
DiscreteFactor::shared_ptr max(const Ordering& keys) const override;
|
||||
|
||||
|
||||
/// Multiply by scalar s
|
||||
DiscreteFactor::shared_ptr operator*(double s) const override;
|
||||
|
||||
|
@ -143,9 +142,11 @@ class GTSAM_EXPORT TableDistribution : public DiscreteConditional {
|
|||
/**
|
||||
* sample
|
||||
* @param parentsValues Known values of the parents
|
||||
* @param rng Pseudo random number generator
|
||||
* @return sample from conditional
|
||||
*/
|
||||
virtual size_t sample(const DiscreteValues& parentsValues) const override;
|
||||
virtual size_t sample(const DiscreteValues& parentsValues,
|
||||
std::mt19937_64* rng = nullptr) const override;
|
||||
|
||||
/// @}
|
||||
/// @name Advanced Interface
|
||||
|
|
|
@ -138,10 +138,13 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
|||
gtsam::DecisionTreeFactor* likelihood(
|
||||
const gtsam::DiscreteValues& frontalValues) const;
|
||||
gtsam::DecisionTreeFactor* likelihood(size_t value) const;
|
||||
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
|
||||
size_t sample(size_t value) const;
|
||||
size_t sample() const;
|
||||
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
|
||||
size_t sample(const gtsam::DiscreteValues& parentsValues,
|
||||
std::mt19937_64 @rng = nullptr) const;
|
||||
size_t sample(size_t value,
|
||||
std::mt19937_64 @rng = nullptr) const;
|
||||
size_t sample(std::mt19937_64 @rng = nullptr) const;
|
||||
void sampleInPlace(gtsam::DiscreteValues @parentsValues,
|
||||
std::mt19937_64 @rng = nullptr) const;
|
||||
size_t argmax(const gtsam::DiscreteValues& parentsValues) const;
|
||||
|
||||
// Markdown and HTML
|
||||
|
@ -233,8 +236,11 @@ class DiscreteBayesNet {
|
|||
double evaluate(const gtsam::DiscreteValues& values) const;
|
||||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
|
||||
gtsam::DiscreteValues sample() const;
|
||||
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;
|
||||
gtsam::DiscreteValues sample(std::mt19937_64
|
||||
@rng = nullptr) const;
|
||||
gtsam::DiscreteValues sample(gtsam::DiscreteValues given,
|
||||
std::mt19937_64
|
||||
@rng = nullptr) const;
|
||||
|
||||
string dot(
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||
|
|
|
@ -99,9 +99,9 @@ TEST(DiscreteBayesNet, Asia) {
|
|||
|
||||
// now sample from it
|
||||
DiscreteValues expectedSample{{Asia.first, 1}, {Dyspnea.first, 1},
|
||||
{XRay.first, 1}, {Tuberculosis.first, 0},
|
||||
{Smoking.first, 1}, {Either.first, 1},
|
||||
{LungCancer.first, 1}, {Bronchitis.first, 0}};
|
||||
{XRay.first, 0}, {Tuberculosis.first, 0},
|
||||
{Smoking.first, 1}, {Either.first, 0},
|
||||
{LungCancer.first, 0}, {Bronchitis.first, 1}};
|
||||
SETDEBUG("DiscreteConditional::sample", false);
|
||||
auto actualSample = chordal2->sample();
|
||||
EXPECT(assert_equal(expectedSample, actualSample));
|
||||
|
|
|
@ -25,9 +25,6 @@
|
|||
|
||||
#include <memory>
|
||||
|
||||
// In Wrappers we have no access to this so have a default ready
|
||||
static std::mt19937_64 kRandomNumberGenerator(42);
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -191,7 +188,7 @@ HybridValues HybridBayesNet::sample(const HybridValues &given,
|
|||
}
|
||||
}
|
||||
// Sample a discrete assignment.
|
||||
const DiscreteValues assignment = dbn.sample(given.discrete());
|
||||
const DiscreteValues assignment = dbn.sample(given.discrete(), rng);
|
||||
// Select the continuous Bayes net corresponding to the assignment.
|
||||
GaussianBayesNet gbn = choose(assignment);
|
||||
// Sample from the Gaussian Bayes net.
|
||||
|
|
|
@ -158,6 +158,8 @@ class HybridBayesNet {
|
|||
gtsam::HybridValues optimize() const;
|
||||
gtsam::VectorValues optimize(const gtsam::DiscreteValues& assignment) const;
|
||||
|
||||
gtsam::HybridValues sample(const gtsam::HybridValues& given, std::mt19937_64@ rng) const;
|
||||
gtsam::HybridValues sample(std::mt19937_64@ rng) const;
|
||||
gtsam::HybridValues sample(const gtsam::HybridValues& given) const;
|
||||
gtsam::HybridValues sample() const;
|
||||
|
||||
|
|
|
@ -90,7 +90,7 @@ TEST(HybridBayesNet, EvaluatePureDiscrete) {
|
|||
|
||||
// sample
|
||||
std::mt19937_64 rng(42);
|
||||
EXPECT(assert_equal(zero, bayesNet.sample(&rng)));
|
||||
EXPECT(assert_equal(one, bayesNet.sample(&rng)));
|
||||
EXPECT(assert_equal(one, bayesNet.sample(one, &rng)));
|
||||
EXPECT(assert_equal(zero, bayesNet.sample(zero, &rng)));
|
||||
|
||||
|
@ -616,16 +616,16 @@ TEST(HybridBayesNet, Sampling) {
|
|||
double discrete_sum =
|
||||
std::accumulate(discrete_samples.begin(), discrete_samples.end(),
|
||||
decltype(discrete_samples)::value_type(0));
|
||||
EXPECT_DOUBLES_EQUAL(0.477, discrete_sum / num_samples, 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(0.519, discrete_sum / num_samples, 1e-9);
|
||||
|
||||
VectorValues expected;
|
||||
// regression for specific RNG seed
|
||||
#if __APPLE__ || _WIN32
|
||||
expected.insert({X(0), Vector1(-0.0131207162712)});
|
||||
expected.insert({X(1), Vector1(-0.499026377568)});
|
||||
expected.insert({X(0), Vector1(0.0252479903896)});
|
||||
expected.insert({X(1), Vector1(-0.513637101911)});
|
||||
#elif __linux__
|
||||
expected.insert({X(0), Vector1(-0.00799425182219)});
|
||||
expected.insert({X(1), Vector1(-0.526463854268)});
|
||||
expected.insert({X(0), Vector1(0.0165089744897)});
|
||||
expected.insert({X(1), Vector1(-0.454323399979)});
|
||||
#endif
|
||||
|
||||
EXPECT(assert_equal(expected, average_continuous.scale(1.0 / num_samples)));
|
||||
|
|
|
@ -34,7 +34,7 @@
|
|||
#include <string>
|
||||
#include <cmath>
|
||||
|
||||
// In Wrappers we have no access to this so have a default ready
|
||||
// In wrappers we can access std::mt19937_64 via gtsam.MT19937
|
||||
static std::mt19937_64 kRandomNumberGenerator(42);
|
||||
|
||||
using namespace std;
|
||||
|
|
|
@ -215,7 +215,7 @@ namespace gtsam {
|
|||
* Sample from conditional, zero parent version
|
||||
* Example:
|
||||
* std::mt19937_64 rng(42);
|
||||
* auto sample = gbn.sample(&rng);
|
||||
* auto sample = gc.sample(&rng);
|
||||
*/
|
||||
VectorValues sample(std::mt19937_64* rng) const;
|
||||
|
||||
|
@ -224,7 +224,7 @@ namespace gtsam {
|
|||
* Example:
|
||||
* std::mt19937_64 rng(42);
|
||||
* VectorValues given = ...;
|
||||
* auto sample = gbn.sample(given, &rng);
|
||||
* auto sample = gc.sample(given, &rng);
|
||||
*/
|
||||
VectorValues sample(const VectorValues& parentsValues,
|
||||
std::mt19937_64* rng) const;
|
||||
|
|
|
@ -559,9 +559,12 @@ virtual class GaussianConditional : gtsam::JacobianFactor {
|
|||
gtsam::JacobianFactor* likelihood(
|
||||
const gtsam::VectorValues& frontalValues) const;
|
||||
gtsam::JacobianFactor* likelihood(gtsam::Vector frontal) const;
|
||||
gtsam::VectorValues sample(const gtsam::VectorValues& parents) const;
|
||||
|
||||
gtsam::VectorValues sample(std::mt19937_64@ rng) const;
|
||||
gtsam::VectorValues sample(const gtsam::VectorValues& parents, std::mt19937_64@ rng) const;
|
||||
gtsam::VectorValues sample() const;
|
||||
|
||||
gtsam::VectorValues sample(const gtsam::VectorValues& parents) const;
|
||||
|
||||
// Advanced Interface
|
||||
gtsam::VectorValues solveOtherRHS(const gtsam::VectorValues& parents,
|
||||
const gtsam::VectorValues& rhs) const;
|
||||
|
|
|
@ -10,3 +10,15 @@
|
|||
* with `PYBIND11_MAKE_OPAQUE` this allows the types to be modified with Python,
|
||||
* and saves one copy operation.
|
||||
*/
|
||||
|
||||
/*
|
||||
* Custom Pybind11 module for the Mersenne-Twister PRNG object
|
||||
* `std::mt19937_64`.
|
||||
* This can be invoked with `gtsam.MT19937()` and passed
|
||||
* wherever a rng pointer is expected.
|
||||
*/
|
||||
#include <random>
|
||||
py::class_<std::mt19937_64>(m_, "MT19937")
|
||||
.def(py::init<>())
|
||||
.def(py::init<std::mt19937_64::result_type>())
|
||||
.def("__call__", &std::mt19937_64::operator());
|
||||
|
|
|
@ -33,6 +33,41 @@ XRay = (2, 2)
|
|||
Dyspnea = (1, 2)
|
||||
|
||||
|
||||
class TestDiscreteConditional(GtsamTestCase):
|
||||
"""Tests for Discrete Conditional"""
|
||||
|
||||
def setUp(self):
|
||||
self.key = (0, 2)
|
||||
self.parent = (1, 2)
|
||||
self.parents = DiscreteKeys()
|
||||
self.parents.push_back(self.parent)
|
||||
|
||||
def test_sample(self):
|
||||
"""Tests to check sampling in DiscreteConditionals"""
|
||||
rng = gtsam.MT19937(11)
|
||||
niters = 1000
|
||||
|
||||
# Sample with only 1 variable
|
||||
conditional = DiscreteConditional(self.key, "7/3")
|
||||
# Sample multiple times and average to get mean
|
||||
p = 0
|
||||
for _ in range(niters):
|
||||
p += conditional.sample(rng)
|
||||
|
||||
self.assertAlmostEqual(p / niters, 0.3, 1)
|
||||
|
||||
# Sample with variable and parent
|
||||
conditional = DiscreteConditional(self.key, self.parents, "7/3 2/8")
|
||||
# Sample multiple times and average to get mean
|
||||
p = 0
|
||||
parentValues = gtsam.DiscreteValues()
|
||||
parentValues[self.parent[0]] = 1
|
||||
for _ in range(niters):
|
||||
p += conditional.sample(parentValues, rng)
|
||||
|
||||
self.assertAlmostEqual(p / niters, 0.8, 1)
|
||||
|
||||
|
||||
class TestDiscreteBayesNet(GtsamTestCase):
|
||||
"""Tests for Discrete Bayes Nets."""
|
||||
|
||||
|
@ -85,10 +120,12 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
|||
# solve
|
||||
actualMPE = fg.optimize()
|
||||
expectedMPE = DiscreteValues()
|
||||
for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]:
|
||||
for key in [
|
||||
Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer,
|
||||
Bronchitis
|
||||
]:
|
||||
expectedMPE[key[0]] = 0
|
||||
self.assertEqual(list(actualMPE.items()),
|
||||
list(expectedMPE.items()))
|
||||
self.assertEqual(list(actualMPE.items()), list(expectedMPE.items()))
|
||||
|
||||
# Check value for MPE is the same
|
||||
self.assertAlmostEqual(asia(actualMPE), fg(actualMPE))
|
||||
|
@ -104,8 +141,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
|||
expectedMPE2[key[0]] = 0
|
||||
for key in [Asia, Dyspnea, Smoking, Bronchitis]:
|
||||
expectedMPE2[key[0]] = 1
|
||||
self.assertEqual(list(actualMPE2.items()),
|
||||
list(expectedMPE2.items()))
|
||||
self.assertEqual(list(actualMPE2.items()), list(expectedMPE2.items()))
|
||||
|
||||
# now sample from it
|
||||
chordal2 = fg.eliminateSequential(ordering)
|
||||
|
@ -135,8 +171,9 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
|||
# self.assertEqual(len(values), 5)
|
||||
|
||||
for i in [0, 1, 2]:
|
||||
self.assertAlmostEqual(fragment.at(i).logProbability(values),
|
||||
math.log(fragment.at(i).evaluate(values)))
|
||||
self.assertAlmostEqual(
|
||||
fragment.at(i).logProbability(values),
|
||||
math.log(fragment.at(i).evaluate(values)))
|
||||
self.assertAlmostEqual(fragment.logProbability(values),
|
||||
math.log(fragment.evaluate(values)))
|
||||
actual = fragment.sample(given)
|
||||
|
|
|
@ -23,11 +23,12 @@ _x_ = 11
|
|||
_y_ = 22
|
||||
_z_ = 33
|
||||
|
||||
I_1x1 = np.eye(1, dtype=float)
|
||||
|
||||
|
||||
def smallBayesNet():
|
||||
"""Create a small Bayes Net for testing"""
|
||||
bayesNet = GaussianBayesNet()
|
||||
I_1x1 = np.eye(1, dtype=float)
|
||||
bayesNet.push_back(GaussianConditional(_x_, [9.0], I_1x1, _y_, I_1x1))
|
||||
bayesNet.push_back(GaussianConditional(_y_, [5.0], I_1x1))
|
||||
return bayesNet
|
||||
|
@ -51,8 +52,9 @@ class TestGaussianBayesNet(GtsamTestCase):
|
|||
values.insert(_x_, np.array([9.0]))
|
||||
values.insert(_y_, np.array([5.0]))
|
||||
for i in [0, 1]:
|
||||
self.assertAlmostEqual(bayesNet.at(i).logProbability(values),
|
||||
np.log(bayesNet.at(i).evaluate(values)))
|
||||
self.assertAlmostEqual(
|
||||
bayesNet.at(i).logProbability(values),
|
||||
np.log(bayesNet.at(i).evaluate(values)))
|
||||
self.assertAlmostEqual(bayesNet.logProbability(values),
|
||||
np.log(bayesNet.evaluate(values)))
|
||||
|
||||
|
@ -66,6 +68,16 @@ class TestGaussianBayesNet(GtsamTestCase):
|
|||
mean = bayesNet.optimize()
|
||||
self.gtsamAssertEquals(sample, mean, tol=3.0)
|
||||
|
||||
# Sample with rng
|
||||
rng = gtsam.MT19937(42)
|
||||
conditional = GaussianConditional(_x_, [9.0], I_1x1)
|
||||
# Sample multiple times and average to get mean
|
||||
val = 0
|
||||
niters = 10000
|
||||
for _ in range(niters):
|
||||
val += conditional.sample(rng).at(_x_).item()
|
||||
self.assertAlmostEqual(val / niters, 9.0, 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -287,8 +287,8 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
|||
print(f"P(mode=1; Z) = {marginals[1]}")
|
||||
|
||||
# Check that the estimate is close to the true value.
|
||||
self.assertAlmostEqual(marginals[0], 0.23, delta=0.01)
|
||||
self.assertAlmostEqual(marginals[1], 0.77, delta=0.01)
|
||||
self.assertAlmostEqual(marginals[0], 0.219, delta=0.01)
|
||||
self.assertAlmostEqual(marginals[1], 0.781, delta=0.01)
|
||||
|
||||
# Convert to factor graph using measurements.
|
||||
fg = bayesNet.toFactorGraph(measurements)
|
||||
|
|
Loading…
Reference in New Issue