refactoring variables for clarity

release/4.3a0
Frank Dellaert 2022-12-28 08:19:59 -05:00
parent 4d3bbf6ca4
commit b772d677ec
1 changed files with 34 additions and 34 deletions

View File

@ -283,11 +283,10 @@ AlgebraicDecisionTree<Key> getProbPrimeTree(
return probPrimeTree; return probPrimeTree;
} }
/****************************************************************************/ /*********************************************************************************
/**
* Test for correctness of different branches of the P'(Continuous | Discrete). * Test for correctness of different branches of the P'(Continuous | Discrete).
* The values should match those of P'(Continuous) for each discrete mode. * The values should match those of P'(Continuous) for each discrete mode.
*/ ********************************************************************************/
TEST(HybridEstimation, Probability) { TEST(HybridEstimation, Probability) {
constexpr size_t K = 4; constexpr size_t K = 4;
std::vector<double> measurements = {0, 1, 2, 2}; std::vector<double> measurements = {0, 1, 2, 2};
@ -444,20 +443,30 @@ static HybridGaussianFactorGraph::shared_ptr createHybridGaussianFactorGraph() {
* Do hybrid elimination and do regression test on discrete conditional. * Do hybrid elimination and do regression test on discrete conditional.
********************************************************************************/ ********************************************************************************/
TEST(HybridEstimation, eliminateSequentialRegression) { TEST(HybridEstimation, eliminateSequentialRegression) {
// 1. Create the factor graph from the nonlinear factor graph. // Create the factor graph from the nonlinear factor graph.
HybridGaussianFactorGraph::shared_ptr fg = createHybridGaussianFactorGraph(); HybridGaussianFactorGraph::shared_ptr fg = createHybridGaussianFactorGraph();
// 2. Eliminate into BN // Create expected discrete conditional on m0.
const Ordering ordering = fg->getHybridOrdering(); DiscreteKey m(M(0), 2);
HybridBayesNet::shared_ptr bn = fg->eliminateSequential(ordering); DiscreteConditional expected(m % "0.51341712/1"); // regression
// GTSAM_PRINT(*bn);
// TODO(dellaert): dc should be discrete conditional on m0, but it is an // Eliminate into BN using one ordering
// unnormalized factor? Ordering ordering1;
// DiscreteKey m(M(0), 2); ordering1 += X(0), X(1), M(0);
// DiscreteConditional expected(m % "0.51341712/1"); HybridBayesNet::shared_ptr bn1 = fg->eliminateSequential(ordering1);
// auto dc = bn->back()->asDiscrete();
// EXPECT(assert_equal(expected, *dc, 1e-9)); // Check that the discrete conditional matches the expected.
auto dc1 = bn1->back()->asDiscrete();
EXPECT(assert_equal(expected, *dc1, 1e-9));
// Eliminate into BN using a different ordering
Ordering ordering2;
ordering2 += X(0), X(1), M(0);
HybridBayesNet::shared_ptr bn2 = fg->eliminateSequential(ordering2);
// Check that the discrete conditional matches the expected.
auto dc2 = bn2->back()->asDiscrete();
EXPECT(assert_equal(expected, *dc2, 1e-9));
} }
/********************************************************************************* /*********************************************************************************
@ -472,7 +481,7 @@ TEST(HybridEstimation, eliminateSequentialRegression) {
********************************************************************************/ ********************************************************************************/
TEST(HybridEstimation, CorrectnessViaSampling) { TEST(HybridEstimation, CorrectnessViaSampling) {
// 1. Create the factor graph from the nonlinear factor graph. // 1. Create the factor graph from the nonlinear factor graph.
HybridGaussianFactorGraph::shared_ptr fg = createHybridGaussianFactorGraph(); const auto fg = createHybridGaussianFactorGraph();
// 2. Eliminate into BN // 2. Eliminate into BN
const Ordering ordering = fg->getHybridOrdering(); const Ordering ordering = fg->getHybridOrdering();
@ -481,37 +490,28 @@ TEST(HybridEstimation, CorrectnessViaSampling) {
// Set up sampling // Set up sampling
std::mt19937_64 rng(11); std::mt19937_64 rng(11);
// 3. Do sampling // Compute the log-ratio between the Bayes net and the factor graph.
int num_samples = 10; auto compute_ratio = [&](const HybridValues& sample) -> double {
return bn->error(sample) - fg->error(sample);
// Functor to compute the ratio between the
// Bayes net and the factor graph.
auto compute_ratio =
[](const HybridBayesNet::shared_ptr& bayesNet,
const HybridGaussianFactorGraph::shared_ptr& factorGraph,
const HybridValues& sample) -> double {
const DiscreteValues assignment = sample.discrete();
// Compute in log form for numerical stability
double log_ratio = bayesNet->error({sample.continuous(), assignment}) -
factorGraph->error({sample.continuous(), assignment});
double ratio = exp(-log_ratio);
return ratio;
}; };
// The error evaluated by the factor graph and the Bayes net should differ by // The error evaluated by the factor graph and the Bayes net should differ by
// the normalizing term computed via the Bayes net determinant. // the normalizing term computed via the Bayes net determinant.
const HybridValues sample = bn->sample(&rng); const HybridValues sample = bn->sample(&rng);
double ratio = compute_ratio(bn, fg, sample); double expected_ratio = compute_ratio(sample);
// regression // regression
EXPECT_DOUBLES_EQUAL(1.9477340410546764, ratio, 1e-9); // EXPECT_DOUBLES_EQUAL(1.9477340410546764, ratio, 1e-9);
// 4. Check that all samples == constant // 3. Do sampling
constexpr int num_samples = 10;
for (size_t i = 0; i < num_samples; i++) { for (size_t i = 0; i < num_samples; i++) {
// Sample from the bayes net // Sample from the bayes net
const HybridValues sample = bn->sample(&rng); const HybridValues sample = bn->sample(&rng);
// 4. Check that the ratio is constant.
// TODO(Varun) The ratio changes based on the mode // TODO(Varun) The ratio changes based on the mode
// EXPECT_DOUBLES_EQUAL(ratio, compute_ratio(bn, fg, sample), 1e-9); // std::cout << compute_ratio(sample) << std::endl;
// EXPECT_DOUBLES_EQUAL(expected_ratio, compute_ratio(sample), 1e-9);
} }
} }