Add test for prune

release/4.3a0
Frank Dellaert 2024-10-01 11:31:16 -07:00
parent a709a2d750
commit 5b713032c1
1 changed files with 57 additions and 1 deletions

View File

@ -25,8 +25,12 @@
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianConditional.h>
#include <memory>
#include <vector>
#include "gtsam/discrete/DecisionTree.h"
#include "gtsam/discrete/DiscreteKey.h"
// Include for test suite
#include <CppUnitLite/TestHarness.h>
@ -250,8 +254,60 @@ TEST(HybridGaussianConditional, Likelihood2) {
}
/* ************************************************************************* */
// Test pruning a HybridGaussianConditional with two discrete keys, based on a
// DecisionTreeFactor with 3 keys:
TEST(HybridGaussianConditional, Prune) {
// Create a two key conditional:
DiscreteKeys modes{{M(1), 2}, {M(2), 2}};
std::vector<GaussianConditional::shared_ptr> gcs;
for (size_t i = 0; i < 4; i++) {
gcs.push_back(
GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(i + 1), i + 1));
}
auto empty = std::make_shared<GaussianConditional>();
HybridGaussianConditional::Conditionals conditionals(modes, gcs);
HybridGaussianConditional hgc(modes, conditionals);
DiscreteKeys keys = modes;
keys.push_back({M(3), 2});
{
for (size_t i = 0; i < 8; i++) {
std::vector<double> potentials{0, 0, 0, 0, 0, 0, 0, 0};
potentials[i] = 1;
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
// Prune the HybridGaussianConditional
const auto pruned = hgc.prune(decisionTreeFactor);
// Check that the pruned HybridGaussianConditional has 1 conditional
EXPECT_LONGS_EQUAL(1, pruned->nrComponents());
}
}
{
const std::vector<double> potentials{0, 0, 0.5, 0, //
0, 0, 0.5, 0};
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
const auto pruned = hgc.prune(decisionTreeFactor);
// Check that the pruned HybridGaussianConditional has 2 conditionals
EXPECT_LONGS_EQUAL(2, pruned->nrComponents());
}
{
const std::vector<double> potentials{0.2, 0, 0.3, 0, //
0, 0, 0.5, 0};
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
const auto pruned = hgc.prune(decisionTreeFactor);
// Check that the pruned HybridGaussianConditional has 3 conditionals
EXPECT_LONGS_EQUAL(3, pruned->nrComponents());
}
}
/* *************************************************************************
*/
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */
/* *************************************************************************
*/