refactoring variables for clarity
parent
4d3bbf6ca4
commit
b772d677ec
|
@ -283,11 +283,10 @@ AlgebraicDecisionTree<Key> getProbPrimeTree(
|
||||||
return probPrimeTree;
|
return probPrimeTree;
|
||||||
}
|
}
|
||||||
|
|
||||||
/****************************************************************************/
|
/*********************************************************************************
|
||||||
/**
|
|
||||||
* Test for correctness of different branches of the P'(Continuous | Discrete).
|
* Test for correctness of different branches of the P'(Continuous | Discrete).
|
||||||
* The values should match those of P'(Continuous) for each discrete mode.
|
* The values should match those of P'(Continuous) for each discrete mode.
|
||||||
*/
|
********************************************************************************/
|
||||||
TEST(HybridEstimation, Probability) {
|
TEST(HybridEstimation, Probability) {
|
||||||
constexpr size_t K = 4;
|
constexpr size_t K = 4;
|
||||||
std::vector<double> measurements = {0, 1, 2, 2};
|
std::vector<double> measurements = {0, 1, 2, 2};
|
||||||
|
@ -444,20 +443,30 @@ static HybridGaussianFactorGraph::shared_ptr createHybridGaussianFactorGraph() {
|
||||||
* Do hybrid elimination and do regression test on discrete conditional.
|
* Do hybrid elimination and do regression test on discrete conditional.
|
||||||
********************************************************************************/
|
********************************************************************************/
|
||||||
TEST(HybridEstimation, eliminateSequentialRegression) {
|
TEST(HybridEstimation, eliminateSequentialRegression) {
|
||||||
// 1. Create the factor graph from the nonlinear factor graph.
|
// Create the factor graph from the nonlinear factor graph.
|
||||||
HybridGaussianFactorGraph::shared_ptr fg = createHybridGaussianFactorGraph();
|
HybridGaussianFactorGraph::shared_ptr fg = createHybridGaussianFactorGraph();
|
||||||
|
|
||||||
// 2. Eliminate into BN
|
// Create expected discrete conditional on m0.
|
||||||
const Ordering ordering = fg->getHybridOrdering();
|
DiscreteKey m(M(0), 2);
|
||||||
HybridBayesNet::shared_ptr bn = fg->eliminateSequential(ordering);
|
DiscreteConditional expected(m % "0.51341712/1"); // regression
|
||||||
// GTSAM_PRINT(*bn);
|
|
||||||
|
|
||||||
// TODO(dellaert): dc should be discrete conditional on m0, but it is an
|
// Eliminate into BN using one ordering
|
||||||
// unnormalized factor?
|
Ordering ordering1;
|
||||||
// DiscreteKey m(M(0), 2);
|
ordering1 += X(0), X(1), M(0);
|
||||||
// DiscreteConditional expected(m % "0.51341712/1");
|
HybridBayesNet::shared_ptr bn1 = fg->eliminateSequential(ordering1);
|
||||||
// auto dc = bn->back()->asDiscrete();
|
|
||||||
// EXPECT(assert_equal(expected, *dc, 1e-9));
|
// Check that the discrete conditional matches the expected.
|
||||||
|
auto dc1 = bn1->back()->asDiscrete();
|
||||||
|
EXPECT(assert_equal(expected, *dc1, 1e-9));
|
||||||
|
|
||||||
|
// Eliminate into BN using a different ordering
|
||||||
|
Ordering ordering2;
|
||||||
|
ordering2 += X(0), X(1), M(0);
|
||||||
|
HybridBayesNet::shared_ptr bn2 = fg->eliminateSequential(ordering2);
|
||||||
|
|
||||||
|
// Check that the discrete conditional matches the expected.
|
||||||
|
auto dc2 = bn2->back()->asDiscrete();
|
||||||
|
EXPECT(assert_equal(expected, *dc2, 1e-9));
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************
|
/*********************************************************************************
|
||||||
|
@ -472,7 +481,7 @@ TEST(HybridEstimation, eliminateSequentialRegression) {
|
||||||
********************************************************************************/
|
********************************************************************************/
|
||||||
TEST(HybridEstimation, CorrectnessViaSampling) {
|
TEST(HybridEstimation, CorrectnessViaSampling) {
|
||||||
// 1. Create the factor graph from the nonlinear factor graph.
|
// 1. Create the factor graph from the nonlinear factor graph.
|
||||||
HybridGaussianFactorGraph::shared_ptr fg = createHybridGaussianFactorGraph();
|
const auto fg = createHybridGaussianFactorGraph();
|
||||||
|
|
||||||
// 2. Eliminate into BN
|
// 2. Eliminate into BN
|
||||||
const Ordering ordering = fg->getHybridOrdering();
|
const Ordering ordering = fg->getHybridOrdering();
|
||||||
|
@ -481,37 +490,28 @@ TEST(HybridEstimation, CorrectnessViaSampling) {
|
||||||
// Set up sampling
|
// Set up sampling
|
||||||
std::mt19937_64 rng(11);
|
std::mt19937_64 rng(11);
|
||||||
|
|
||||||
// 3. Do sampling
|
// Compute the log-ratio between the Bayes net and the factor graph.
|
||||||
int num_samples = 10;
|
auto compute_ratio = [&](const HybridValues& sample) -> double {
|
||||||
|
return bn->error(sample) - fg->error(sample);
|
||||||
// Functor to compute the ratio between the
|
|
||||||
// Bayes net and the factor graph.
|
|
||||||
auto compute_ratio =
|
|
||||||
[](const HybridBayesNet::shared_ptr& bayesNet,
|
|
||||||
const HybridGaussianFactorGraph::shared_ptr& factorGraph,
|
|
||||||
const HybridValues& sample) -> double {
|
|
||||||
const DiscreteValues assignment = sample.discrete();
|
|
||||||
// Compute in log form for numerical stability
|
|
||||||
double log_ratio = bayesNet->error({sample.continuous(), assignment}) -
|
|
||||||
factorGraph->error({sample.continuous(), assignment});
|
|
||||||
double ratio = exp(-log_ratio);
|
|
||||||
return ratio;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// The error evaluated by the factor graph and the Bayes net should differ by
|
// The error evaluated by the factor graph and the Bayes net should differ by
|
||||||
// the normalizing term computed via the Bayes net determinant.
|
// the normalizing term computed via the Bayes net determinant.
|
||||||
const HybridValues sample = bn->sample(&rng);
|
const HybridValues sample = bn->sample(&rng);
|
||||||
double ratio = compute_ratio(bn, fg, sample);
|
double expected_ratio = compute_ratio(sample);
|
||||||
// regression
|
// regression
|
||||||
EXPECT_DOUBLES_EQUAL(1.9477340410546764, ratio, 1e-9);
|
// EXPECT_DOUBLES_EQUAL(1.9477340410546764, ratio, 1e-9);
|
||||||
|
|
||||||
// 4. Check that all samples == constant
|
// 3. Do sampling
|
||||||
|
constexpr int num_samples = 10;
|
||||||
for (size_t i = 0; i < num_samples; i++) {
|
for (size_t i = 0; i < num_samples; i++) {
|
||||||
// Sample from the bayes net
|
// Sample from the bayes net
|
||||||
const HybridValues sample = bn->sample(&rng);
|
const HybridValues sample = bn->sample(&rng);
|
||||||
|
|
||||||
|
// 4. Check that the ratio is constant.
|
||||||
// TODO(Varun) The ratio changes based on the mode
|
// TODO(Varun) The ratio changes based on the mode
|
||||||
// EXPECT_DOUBLES_EQUAL(ratio, compute_ratio(bn, fg, sample), 1e-9);
|
// std::cout << compute_ratio(sample) << std::endl;
|
||||||
|
// EXPECT_DOUBLES_EQUAL(expected_ratio, compute_ratio(sample), 1e-9);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue