Removed FactorAndConstant, no longer needed

release/4.3a0
Frank Dellaert 2023-01-12 22:34:34 -08:00
parent 1dcc6ddde9
commit 03ad393e12
7 changed files with 99 additions and 169 deletions
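In short: the leaves of a `GaussianMixtureFactor` revert to plain `boost::shared_ptr<GaussianFactor>`, so the `Mixture` typedef, the `FactorAndConstant` wrapper, and the `factor()`/`constant()` accessor pair all go away. A minimal before/after sketch of the client-facing API; the keys, matrices, and factors here are invented for illustration and are not part of the commit:

    #include <gtsam/hybrid/GaussianMixtureFactor.h>
    #include <gtsam/linear/JacobianFactor.h>
    #include <gtsam/inference/Symbol.h>

    using namespace gtsam;
    using symbol_shorthand::M;
    using symbol_shorthand::X;

    void sketch() {
      const DiscreteKey mode(M(0), 2);  // invented binary mode
      const auto f0 = boost::make_shared<JacobianFactor>(X(0), I_1x1, Vector1(0.0));
      const auto f1 = boost::make_shared<JacobianFactor>(X(0), I_1x1, Vector1(1.0));

      // Construction is unchanged for the common vector-of-factors case; the
      // leaves are now the shared pointers themselves, not FactorAndConstant:
      const GaussianMixtureFactor mixture({X(0)}, {mode}, {f0, f1});

      // Before: mixture.factor(dv) and mixture.constant(dv).
      // After: component selection via operator():
      const DiscreteValues dv{{M(0), 1}};
      const auto component = mixture(dv);  // == f1
    }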

gtsam/hybrid/GaussianMixture.cpp View File

@@ -203,8 +203,8 @@ boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
   const KeyVector continuousParentKeys = continuousParents();
   const GaussianMixtureFactor::Factors likelihoods(
       conditionals_, [&](const GaussianConditional::shared_ptr &conditional) {
-        return GaussianMixtureFactor::FactorAndConstant{
-            conditional->likelihood(given), 0.0};
+        return GaussianMixtureFactor::sharedFactor{
+            conditional->likelihood(given)};
       });
   return boost::make_shared<GaussianMixtureFactor>(
       continuousParentKeys, discreteParentKeys, likelihoods);
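Each leaf is now just the likelihood factor returned by `GaussianConditional::likelihood`, with no separate log-constant. A small sketch of what that per-leaf call does; the conditional, keys, sigma, and measurement value are all invented, assuming the usual GTSAM headers and `symbol_shorthand::X`/`Z`:

    // Measurement model P(z0 | x0) with z0 = x0 + noise (sigma invented):
    const auto pZgivenX = boost::make_shared<GaussianConditional>(
        Z(0), Vector1(0.0), I_1x1,  // frontal z0, d = 0, R = I
        X(0), -I_1x1,               // parent x0, S = -I, residual z0 - x0
        noiseModel::Isotropic::Sigma(1, 0.5));
    // Fix the measurement z0 = 5.0 to get a likelihood factor on X(0):
    const VectorValues given{{Z(0), Vector1(5.0)}};
    const auto likelihoodFactor = pZgivenX->likelihood(given);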

gtsam/hybrid/GaussianMixtureFactor.cpp View File

@@ -31,11 +31,8 @@ namespace gtsam {
 /* *******************************************************************************/
 GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys,
                                              const DiscreteKeys &discreteKeys,
-                                             const Mixture &factors)
-    : Base(continuousKeys, discreteKeys),
-      factors_(factors, [](const GaussianFactor::shared_ptr &gf) {
-        return FactorAndConstant{gf, 0.0};
-      }) {}
+                                             const Factors &factors)
+    : Base(continuousKeys, discreteKeys), factors_(factors) {}

 /* *******************************************************************************/
 bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
@@ -48,11 +45,10 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
   // Check the base and the factors:
   return Base::equals(*e, tol) &&
-         factors_.equals(e->factors_, [tol](const FactorAndConstant &f1,
-                                            const FactorAndConstant &f2) {
-           return f1.factor->equals(*(f2.factor), tol) &&
-                  std::abs(f1.constant - f2.constant) < tol;
-         });
+         factors_.equals(e->factors_,
+                         [tol](const sharedFactor &f1, const sharedFactor &f2) {
+                           return f1->equals(*f2, tol);
+                         });
 }

 /* *******************************************************************************/
@@ -65,8 +61,7 @@ void GaussianMixtureFactor::print(const std::string &s,
   } else {
     factors_.print(
         "", [&](Key k) { return formatter(k); },
-        [&](const FactorAndConstant &gf_z) -> std::string {
-          auto gf = gf_z.factor;
+        [&](const sharedFactor &gf) -> std::string {
          RedirectCout rd;
          std::cout << ":\n";
          if (gf && !gf->empty()) {
@@ -81,14 +76,9 @@ void GaussianMixtureFactor::print(const std::string &s,
 }

 /* *******************************************************************************/
-GaussianFactor::shared_ptr GaussianMixtureFactor::factor(
+GaussianMixtureFactor::sharedFactor GaussianMixtureFactor::operator()(
     const DiscreteValues &assignment) const {
-  return factors_(assignment).factor;
-}
-
-/* *******************************************************************************/
-double GaussianMixtureFactor::constant(const DiscreteValues &assignment) const {
-  return factors_(assignment).constant;
+  return factors_(assignment);
 }

 /* *******************************************************************************/
@@ -107,10 +97,10 @@ GaussianFactorGraphTree GaussianMixtureFactor::add(
 /* *******************************************************************************/
 GaussianFactorGraphTree GaussianMixtureFactor::asGaussianFactorGraphTree()
     const {
-  auto wrap = [](const FactorAndConstant &factor_z) {
+  auto wrap = [](const sharedFactor &gf) {
     GaussianFactorGraph result;
-    result.push_back(factor_z.factor);
-    return GraphAndConstant(result, factor_z.constant);
+    result.push_back(gf);
+    return GraphAndConstant(result, 0.0);
   };
   return {factors_, wrap};
 }
@@ -119,8 +109,8 @@ GaussianFactorGraphTree GaussianMixtureFactor::asGaussianFactorGraphTree()
 AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
     const VectorValues &continuousValues) const {
   // functor to convert from sharedFactor to double error value.
-  auto errorFunc = [continuousValues](const FactorAndConstant &factor_z) {
-    return factor_z.error(continuousValues);
+  auto errorFunc = [&continuousValues](const sharedFactor &gf) {
+    return gf->error(continuousValues);
   };
   DecisionTree<Key, double> errorTree(factors_, errorFunc);
   return errorTree;
@@ -128,8 +118,8 @@ AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
 /* *******************************************************************************/
 double GaussianMixtureFactor::error(const HybridValues &values) const {
-  const FactorAndConstant factor_z = factors_(values.discrete());
-  return factor_z.error(values.continuous());
+  const sharedFactor gf = factors_(values.discrete());
+  return gf->error(values.continuous());
 }

 /* *******************************************************************************/
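With the constant gone, `error` simply selects the component for the discrete assignment and delegates. Continuing the invented `mixture`/`mode` sketch from the top of the page:

    const DiscreteValues choice{{M(0), 0}};
    const auto component = mixture(choice);  // was mixture.factor(choice)

    const HybridValues values(VectorValues{{X(0), Vector1(0.5)}}, choice);
    const double e = mixture.error(values);  // == component->error(values.continuous())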

gtsam/hybrid/GaussianMixtureFactor.h View File

@@ -39,7 +39,7 @@ class VectorValues;
  * serves to "select" a mixture component corresponding to a GaussianFactor type
  * of measurement.
  *
- * Represents the underlying Gaussian Mixture as a Decision Tree, where the set
+ * Represents the underlying Gaussian mixture as a Decision Tree, where the set
  * of discrete variables indexes to the continuous gaussian distribution.
  *
  * @ingroup hybrid
@@ -52,38 +52,8 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
   using sharedFactor = boost::shared_ptr<GaussianFactor>;

-  /// Gaussian factor and log of normalizing constant.
-  struct FactorAndConstant {
-    sharedFactor factor;
-    double constant;
-
-    // Return error with constant correction.
-    double error(const VectorValues &values) const {
-      // Note: constant is log of normalization constant for probabilities.
-      // Errors is the negative log-likelihood,
-      // hence we subtract the constant here.
-      if (!factor) return 0.0;  // If nullptr, return 0.0 error
-      return factor->error(values) - constant;
-    }
-
-    // Check pointer equality.
-    bool operator==(const FactorAndConstant &other) const {
-      return factor == other.factor && constant == other.constant;
-    }
-
-   private:
-    /** Serialization function */
-    friend class boost::serialization::access;
-    template <class ARCHIVE>
-    void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
-      ar &BOOST_SERIALIZATION_NVP(factor);
-      ar &BOOST_SERIALIZATION_NVP(constant);
-    }
-  };
-
   /// typedef for Decision Tree of Gaussian factors and log-constant.
-  using Factors = DecisionTree<Key, FactorAndConstant>;
-  using Mixture = DecisionTree<Key, sharedFactor>;
+  using Factors = DecisionTree<Key, sharedFactor>;

  private:
   /// Decision tree of Gaussian factors indexed by discrete keys.
@@ -105,7 +75,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
   GaussianMixtureFactor() = default;

   /**
-   * @brief Construct a new Gaussian Mixture Factor object.
+   * @brief Construct a new Gaussian mixture factor.
    *
    * @param continuousKeys A vector of keys representing continuous variables.
    * @param discreteKeys A vector of keys representing discrete variables and
@@ -115,12 +85,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
    */
   GaussianMixtureFactor(const KeyVector &continuousKeys,
                         const DiscreteKeys &discreteKeys,
-                        const Mixture &factors);
-
-  GaussianMixtureFactor(const KeyVector &continuousKeys,
-                        const DiscreteKeys &discreteKeys,
-                        const Factors &factors_and_z)
-      : Base(continuousKeys, discreteKeys), factors_(factors_and_z) {}
+                        const Factors &factors);

   /**
    * @brief Construct a new GaussianMixtureFactor object using a vector of
@@ -134,7 +99,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
                         const DiscreteKeys &discreteKeys,
                         const std::vector<sharedFactor> &factors)
       : GaussianMixtureFactor(continuousKeys, discreteKeys,
-                              Mixture(discreteKeys, factors)) {}
+                              Factors(discreteKeys, factors)) {}

   /// @}
   /// @name Testable
@@ -151,10 +116,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
   /// @{

   /// Get factor at a given discrete assignment.
-  sharedFactor factor(const DiscreteValues &assignment) const;
-
-  /// Get constant at a given discrete assignment.
-  double constant(const DiscreteValues &assignment) const;
+  sharedFactor operator()(const DiscreteValues &assignment) const;

   /**
    * @brief Combine the Gaussian Factor Graphs in `sum` and `this` while
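Both remaining constructors funnel into the same `Factors` tree. Continuing the invented `f0`/`f1`/`mode` sketch from the top of the page, the two construction routes should be interchangeable:

    // An explicit DecisionTree and the vector convenience constructor
    // produce equal factors (same invented leaves as before):
    const GaussianMixtureFactor::Factors tree(mode, f0, f1);
    const GaussianMixtureFactor viaTree({X(0)}, {mode}, tree);
    const GaussianMixtureFactor viaVector(
        {X(0)}, {mode}, std::vector<GaussianMixtureFactor::sharedFactor>{f0, f1});
    // assert_equal(viaTree, viaVector) should hold.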

gtsam/hybrid/HybridGaussianFactorGraph.cpp View File

@@ -213,20 +213,20 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
   // Collect all the factors to create a set of Gaussian factor graphs in a
   // decision tree indexed by all discrete keys involved.
-  GaussianFactorGraphTree sum = factors.assembleGraphTree();
+  GaussianFactorGraphTree factorGraphTree = factors.assembleGraphTree();

   // Convert factor graphs with a nullptr to an empty factor graph.
   // This is done after assembly since it is non-trivial to keep track of which
   // FG has a nullptr as we're looping over the factors.
-  sum = removeEmpty(sum);
+  factorGraphTree = removeEmpty(factorGraphTree);

   using EliminationPair = std::pair<boost::shared_ptr<GaussianConditional>,
-                                    GaussianMixtureFactor::FactorAndConstant>;
+                                    GaussianMixtureFactor::sharedFactor>;

   // This is the elimination method on the leaf nodes
   auto eliminateFunc = [&](const GraphAndConstant &graph_z) -> EliminationPair {
     if (graph_z.graph.empty()) {
-      return {nullptr, {nullptr, 0.0}};
+      return {nullptr, nullptr};
     }

 #ifdef HYBRID_TIMING
@@ -240,27 +240,19 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
     // Get the log of the log normalization constant inverse and
     // add it to the previous constant.
-    const double logZ =
-        graph_z.constant - conditional->logNormalizationConstant();
-    // Get the log of the log normalization constant inverse.
-    // double logZ = -conditional->logNormalizationConstant();
-    // // IF this is the last continuous variable to eliminated, we need to
-    // // calculate the error here: the value of all factors at the mean, see
-    // // ml_map_rao.pdf.
-    // if (continuousSeparator.empty()) {
-    //   const auto posterior_mean = conditional->solve(VectorValues());
-    //   logZ += graph_z.graph.error(posterior_mean);
-    // }
+    // const double logZ =
+    //     graph_z.constant - conditional->logNormalizationConstant();

 #ifdef HYBRID_TIMING
     gttoc_(hybrid_eliminate);
 #endif

-    return {conditional, {newFactor, logZ}};
+    return {conditional, newFactor};
   };

   // Perform elimination!
-  DecisionTree<Key, EliminationPair> eliminationResults(sum, eliminateFunc);
+  DecisionTree<Key, EliminationPair> eliminationResults(factorGraphTree,
+                                                        eliminateFunc);

 #ifdef HYBRID_TIMING
   tictoc_print_();
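For context, `DecisionTree`'s converting constructor maps a functor over every leaf; that is the mechanism running dense elimination once per discrete assignment above. A toy sketch of the same pattern (the size functor is invented, not from the commit):

    // Map each leaf factor graph to its size, one value per discrete choice:
    const DecisionTree<Key, size_t> graphSizes(
        factorGraphTree,
        [](const GraphAndConstant &graph_z) { return graph_z.graph.size(); });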
@@ -279,26 +271,17 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
   // If there are no more continuous parents, then we should create a
   // DiscreteFactor here, with the error for each discrete choice.
   if (continuousSeparator.empty()) {
-    auto factorProb =
-        [&](const GaussianMixtureFactor::FactorAndConstant &factor_z) {
-          // This is the probability q(μ) at the MLE point.
-          // factor_z.factor is a factor without keys,
-          // just containing the residual.
-          return exp(-factor_z.error(VectorValues()));
-        };
+    auto factorProb = [&](const EliminationPair &conditionalAndFactor) {
+      // This is the probability q(μ) at the MLE point.
+      // conditionalAndFactor.second is a factor without keys, just containing the residual.
+      static const VectorValues kEmpty;
+      // return exp(-conditionalAndFactor.first->logNormalizationConstant());
+      // return exp(-conditionalAndFactor.first->logNormalizationConstant() - conditionalAndFactor.second->error(kEmpty));
+      return exp(-conditionalAndFactor.second->error(kEmpty));
+      // return 1.0;
+    };

-    const DecisionTree<Key, double> fdt(newFactors, factorProb);
-
-    // // Normalize the values of decision tree to be valid probabilities
-    // double sum = 0.0;
-    // auto visitor = [&](double y) { sum += y; };
-    // fdt.visit(visitor);
-    // // Check if sum is 0, and update accordingly.
-    // if (sum == 0) {
-    //   sum = 1.0;
-    // }
-    // fdt = DecisionTree<Key, double>(fdt,
-    //                                 [sum](const double &x) { return x / sum;
-    //                                 });
+    const DecisionTree<Key, double> fdt(eliminationResults, factorProb);

     const auto discreteFactor =
         boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
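What `factorProb` computes, in isolation: once the last continuous variable is eliminated, the leftover factor has no keys, its error at an empty assignment is a constant, and `exp(-error)` is the unnormalized probability of that discrete branch. A stand-alone sketch with a hypothetical helper name:

    // Sketch (hypothetical helper mirroring factorProb above): unnormalized
    // probability of one discrete branch from its key-less residual factor.
    double leafProbability(const GaussianFactor::shared_ptr &keylessFactor) {
      static const VectorValues kEmpty;
      return keylessFactor ? std::exp(-keylessFactor->error(kEmpty)) : 0.0;
    }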
@@ -375,6 +358,11 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
   // PREPROCESS: Identify the nature of the current elimination

+  // TODO(dellaert): just check the factors:
+  // 1. if all factors are discrete, then we can do discrete elimination:
+  // 2. if all factors are continuous, then we can do continuous elimination:
+  // 3. if not, we do hybrid elimination:
+
   // First, identify the separator keys, i.e. all keys that are not frontal.
   KeySet separatorKeys;
   for (auto &&factor : factors) {

gtsam/hybrid/tests/testGaussianMixture.cpp View File

@@ -197,8 +197,7 @@ TEST(GaussianMixture, Likelihood) {
   const GaussianMixtureFactor::Factors factors(
       gm.conditionals(),
       [measurements](const GaussianConditional::shared_ptr& conditional) {
-        return GaussianMixtureFactor::FactorAndConstant{
-            conditional->likelihood(measurements), 0.0};
+        return conditional->likelihood(measurements);
       });
   const GaussianMixtureFactor expected({X(0)}, {mode}, factors);
   EXPECT(assert_equal(expected, *factor));

gtsam/hybrid/tests/testHybridBayesNet.cpp View File

@@ -34,6 +34,7 @@ using namespace gtsam;
 using noiseModel::Isotropic;
 using symbol_shorthand::M;
 using symbol_shorthand::X;
+using symbol_shorthand::Z;

 static const Key asiaKey = 0;
 static const DiscreteKey Asia(asiaKey, 2);
@@ -73,8 +74,12 @@ TEST(HybridBayesNet, EvaluatePureDiscrete) {
 /* ****************************************************************************/
 // Test creation of a tiny hybrid Bayes net.
 TEST(HybridBayesNet, Tiny) {
-  auto bayesNet = tiny::createHybridBayesNet();
-  EXPECT_LONGS_EQUAL(3, bayesNet.size());
+  auto bn = tiny::createHybridBayesNet();
+  EXPECT_LONGS_EQUAL(3, bn.size());
+
+  const VectorValues measurements{{Z(0), Vector1(5.0)}};
+  auto fg = bn.toFactorGraph(measurements);
+  EXPECT_LONGS_EQUAL(4, fg.size());
 }

 /* ****************************************************************************/

gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp View File

@@ -57,6 +57,9 @@ using gtsam::symbol_shorthand::X;
 using gtsam::symbol_shorthand::Y;
 using gtsam::symbol_shorthand::Z;

+// Set up sampling
+std::mt19937_64 kRng(42);
+
 /* ************************************************************************* */
 TEST(HybridGaussianFactorGraph, Creation) {
   HybridConditional conditional;
@@ -638,24 +641,47 @@ TEST(HybridGaussianFactorGraph, assembleGraphTree) {
   // f(x0;mode=0)P(x0) and f(x0;mode=1)P(x0)
   GaussianFactorGraphTree expected{
       M(0),
-      {GaussianFactorGraph(std::vector<GF>{mixture->factor(d0), prior}),
-       mixture->constant(d0)},
-      {GaussianFactorGraph(std::vector<GF>{mixture->factor(d1), prior}),
-       mixture->constant(d1)}};
+      {GaussianFactorGraph(std::vector<GF>{(*mixture)(d0), prior}), 0.0},
+      {GaussianFactorGraph(std::vector<GF>{(*mixture)(d1), prior}), 0.0}};

   EXPECT(assert_equal(expected(d0), actual(d0), 1e-5));
   EXPECT(assert_equal(expected(d1), actual(d1), 1e-5));
 }

+/* ****************************************************************************/
+// Check that the factor graph unnormalized probability is proportional to the
+// Bayes net probability for the given measurements.
+bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements,
+               const HybridGaussianFactorGraph &fg, size_t num_samples = 10) {
+  auto compute_ratio = [&](HybridValues *sample) -> double {
+    sample->update(measurements);  // update sample with given measurements:
+    return bn.evaluate(*sample) / fg.probPrime(*sample);
+    // return bn.evaluate(*sample) / posterior->evaluate(*sample);
+  };
+
+  HybridValues sample = bn.sample(&kRng);
+  double expected_ratio = compute_ratio(&sample);
+
+  // Test ratios for a number of independent samples:
+  for (size_t i = 0; i < num_samples; i++) {
+    HybridValues sample = bn.sample(&kRng);
+    if (std::abs(expected_ratio - compute_ratio(&sample)) > 1e-6) return false;
+  }
+  return true;
+}
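`ratioTest` leans on `probPrime(x) = exp(-error(x))` being the factor graph's unnormalized density: `bn.evaluate(x) / fg.probPrime(x)` must come out to the same normalization constant for every sample. A usage sketch mirroring the tests below, with the helper and measurement value copied from `EliminateTiny1`:

    // Check a Bayes-net-to-factor-graph conversion round-trip:
    const VectorValues measurements{{Z(0), Vector1(5.0)}};
    auto bn = tiny::createHybridBayesNet(1);
    EXPECT(ratioTest(bn, measurements, bn.toFactorGraph(measurements)));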
 /* ****************************************************************************/
 // Check that eliminating tiny net with 1 measurement yields correct result.
 TEST(HybridGaussianFactorGraph, EliminateTiny1) {
   using symbol_shorthand::Z;
   const int num_measurements = 1;
-  auto fg = tiny::createHybridGaussianFactorGraph(
-      num_measurements, VectorValues{{Z(0), Vector1(5.0)}});
+  const VectorValues measurements{{Z(0), Vector1(5.0)}};
+  auto bn = tiny::createHybridBayesNet(num_measurements);
+  auto fg = bn.toFactorGraph(measurements);
   EXPECT_LONGS_EQUAL(4, fg.size());
+
+  EXPECT(ratioTest(bn, measurements, fg));

   // Create expected Bayes Net:
   HybridBayesNet expectedBayesNet;
@@ -675,6 +701,8 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) {
   // Test elimination
   const auto posterior = fg.eliminateSequential();
   EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
+
+  EXPECT(ratioTest(bn, measurements, *posterior));
 }

 /* ****************************************************************************/
@@ -683,9 +711,9 @@ TEST(HybridGaussianFactorGraph, EliminateTiny2) {
   // Create factor graph with 2 measurements such that posterior mean = 5.0.
   using symbol_shorthand::Z;
   const int num_measurements = 2;
-  auto fg = tiny::createHybridGaussianFactorGraph(
-      num_measurements,
-      VectorValues{{Z(0), Vector1(4.0)}, {Z(1), Vector1(6.0)}});
+  const VectorValues measurements{{Z(0), Vector1(4.0)}, {Z(1), Vector1(6.0)}};
+  auto bn = tiny::createHybridBayesNet(num_measurements);
+  auto fg = bn.toFactorGraph(measurements);
   EXPECT_LONGS_EQUAL(6, fg.size());

   // Create expected Bayes Net:
@@ -707,6 +735,8 @@ TEST(HybridGaussianFactorGraph, EliminateTiny2) {
   // Test elimination
   const auto posterior = fg.eliminateSequential();
   EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
+
+  EXPECT(ratioTest(bn, measurements, *posterior));
 }

 /* ****************************************************************************/
@@ -723,32 +753,12 @@ TEST(HybridGaussianFactorGraph, EliminateTiny22) {
   auto fg = bn.toFactorGraph(measurements);
   EXPECT_LONGS_EQUAL(7, fg.size());

+  EXPECT(ratioTest(bn, measurements, fg));
+
   // Test elimination
   const auto posterior = fg.eliminateSequential();

-  // Compute the log-ratio between the Bayes net and the factor graph.
-  auto compute_ratio = [&](HybridValues *sample) -> double {
-    // update sample with given measurements:
-    sample->update(measurements);
-    return bn.evaluate(*sample) / posterior->evaluate(*sample);
-  };
-
-  // Set up sampling
-  std::mt19937_64 rng(42);
-
-  // The error evaluated by the factor graph and the Bayes net should differ by
-  // the normalizing term computed via the Bayes net determinant.
-  HybridValues sample = bn.sample(&rng);
-  double expected_ratio = compute_ratio(&sample);
-  // regression
-  EXPECT_DOUBLES_EQUAL(0.018253037966018862, expected_ratio, 1e-6);
-
-  // Test ratios for a number of independent samples:
-  constexpr int num_samples = 100;
-  for (size_t i = 0; i < num_samples; i++) {
-    HybridValues sample = bn.sample(&rng);
-    EXPECT_DOUBLES_EQUAL(expected_ratio, compute_ratio(&sample), 1e-6);
-  }
+  EXPECT(ratioTest(bn, measurements, *posterior));
 }

 /* ****************************************************************************/
@@ -818,31 +828,7 @@ TEST(HybridGaussianFactorGraph, EliminateSwitchingNetwork) {
   // Test resulting posterior Bayes net has correct size:
   EXPECT_LONGS_EQUAL(8, posterior->size());

-  // TODO(dellaert): below is copy/pasta from above, refactor
-  // Compute the log-ratio between the Bayes net and the factor graph.
-  auto compute_ratio = [&](HybridValues *sample) -> double {
-    // update sample with given measurements:
-    sample->update(measurements);
-    return bn.evaluate(*sample) / posterior->evaluate(*sample);
-  };
-
-  // Set up sampling
-  std::mt19937_64 rng(42);
-
-  // The error evaluated by the factor graph and the Bayes net should differ by
-  // the normalizing term computed via the Bayes net determinant.
-  HybridValues sample = bn.sample(&rng);
-  double expected_ratio = compute_ratio(&sample);
-  // regression
-  EXPECT_DOUBLES_EQUAL(0.0094526745785019472, expected_ratio, 1e-6);
-
-  // Test ratios for a number of independent samples:
-  constexpr int num_samples = 100;
-  for (size_t i = 0; i < num_samples; i++) {
-    HybridValues sample = bn.sample(&rng);
-    EXPECT_DOUBLES_EQUAL(expected_ratio, compute_ratio(&sample), 1e-6);
-  }
+  EXPECT(ratioTest(bn, measurements, *posterior));
 }

 /* ************************************************************************* */