Added a sampler class for sampling from noise model distributions with a user-specified seed

release/4.3a0
Alex Cunningham 2010-11-08 03:38:27 +00:00
parent fa81eb4b5e
commit 55d98c7e69
4 changed files with 137 additions and 2 deletions

View File

@ -12,8 +12,8 @@ check_PROGRAMS =
# Noise Model
headers += SharedGaussian.h SharedDiagonal.h
sources += NoiseModel.cpp Errors.cpp
check_PROGRAMS += tests/testNoiseModel tests/testErrors
sources += NoiseModel.cpp Errors.cpp Sampler.cpp
check_PROGRAMS += tests/testNoiseModel tests/testErrors tests/testSampler
# Vector Configurations
headers += VectorValues.h

45
gtsam/linear/Sampler.cpp Normal file
View File

@ -0,0 +1,45 @@
/**
* @file Sampler.cpp
* @author Alex Cunningham
*/
#include <boost/random/normal_distribution.hpp>
#include <boost/random/variate_generator.hpp>
#include <gtsam/linear/Sampler.h>
namespace gtsam {
/* ************************************************************************* */
Sampler::Sampler(const SharedDiagonal& model, int32_t seed)
: sigmas_(model->sigmas()), generator_(static_cast<unsigned>(seed))
{
}
/* ************************************************************************* */
Sampler::Sampler(const Vector& sigmas, int32_t seed)
: sigmas_(sigmas), generator_(static_cast<unsigned>(seed))
{
}
/* ************************************************************************* */
Vector Sampler::sample() {
size_t d = dim();
Vector result(d);
for (size_t i = 0; i < d; i++) {
double sigma = sigmas_(i);
// handle constrained case separately
if (sigma == 0.0) {
result(i) = 0.0;
} else {
typedef boost::normal_distribution<double> Normal;
Normal dist(0.0, sigma);
boost::variate_generator<boost::minstd_rand&, Normal> norm(generator_, dist);
result(i) = norm();
}
}
return result;
}
} // \namespace gtsam

59
gtsam/linear/Sampler.h Normal file
View File

@ -0,0 +1,59 @@
/**
* @file Sampler.h
* @brief sampling that can be parameterized using a NoiseModel to generate samples from
* the given distribution
* @author Alex Cunningham
*/
#pragma once
#include <gtsam/linear/SharedDiagonal.h>
namespace gtsam {
/**
* Sampling structure that keeps internal random number generators for
* diagonal distributions specified by NoiseModel
*
* This is primarily to allow for variable seeds, and does roughly the same
* thing as sample() in NoiseModel.
*/
class Sampler {
protected:
/** sigmas from the noise model */
Vector sigmas_;
/** generator */
boost::minstd_rand generator_;
public:
typedef boost::shared_ptr<Sampler> shared_ptr;
/**
* Create a sampler for the distribution specified by a diagonal NoiseModel
* with a manually specified seed
*
* NOTE: do not use zero as a seed, it will break the generator
*/
Sampler(const SharedDiagonal& model, int32_t seed = 42u);
/**
* Create a sampler for a distribution specified by a vector of sigmas directly
*
* NOTE: do not use zero as a seed, it will break the generator
*/
Sampler(const Vector& sigmas, int32_t seed = 42u);
/** access functions */
size_t dim() const { return sigmas_.size(); }
Vector sigmas() const { return sigmas_; }
/**
* sample from distribution
* NOTE: not const due to need to update the underlying generator
*/
Vector sample();
};
}

View File

@ -0,0 +1,31 @@
/**
* @file testSampler
* @author Alex Cunningham
*/
#include <CppUnitLite/TestHarness.h>
#include <gtsam/linear/Sampler.h>
using namespace gtsam;
const double tol = 1e-5;
/* ************************************************************************* */
TEST(testSampler, basic) {
Vector sigmas = Vector_(3, 1.0, 0.1, 0.0);
SharedDiagonal model = noiseModel::Diagonal::Sigmas(sigmas);
char seed = 'A';
Sampler sampler1(model, seed), sampler2(model, 1), sampler3(model, 1);
EXPECT(assert_equal(sigmas, sampler1.sigmas()));
EXPECT(assert_equal(sigmas, sampler2.sigmas()));
EXPECT_LONGS_EQUAL(3, sampler1.dim());
EXPECT_LONGS_EQUAL(3, sampler2.dim());
Vector actual1 = sampler1.sample();
EXPECT_DOUBLES_EQUAL(0.0, actual1(2), tol);
EXPECT(assert_equal(sampler2.sample(), sampler3.sample(), tol));
}
/* ************************************************************************* */
int main() { TestResult tr; return TestRegistry::runAllTests(tr); }
/* ************************************************************************* */