Add test for prune
parent
a709a2d750
commit
5b713032c1
|
|
@ -25,8 +25,12 @@
|
|||
#include <gtsam/inference/Symbol.h>
|
||||
#include <gtsam/linear/GaussianConditional.h>
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "gtsam/discrete/DecisionTree.h"
|
||||
#include "gtsam/discrete/DiscreteKey.h"
|
||||
|
||||
// Include for test suite
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
|
||||
|
|
@ -250,8 +254,60 @@ TEST(HybridGaussianConditional, Likelihood2) {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Test pruning a HybridGaussianConditional with two discrete keys, based on a
|
||||
// DecisionTreeFactor with 3 keys:
|
||||
TEST(HybridGaussianConditional, Prune) {
|
||||
// Create a two key conditional:
|
||||
DiscreteKeys modes{{M(1), 2}, {M(2), 2}};
|
||||
std::vector<GaussianConditional::shared_ptr> gcs;
|
||||
for (size_t i = 0; i < 4; i++) {
|
||||
gcs.push_back(
|
||||
GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(i + 1), i + 1));
|
||||
}
|
||||
auto empty = std::make_shared<GaussianConditional>();
|
||||
HybridGaussianConditional::Conditionals conditionals(modes, gcs);
|
||||
HybridGaussianConditional hgc(modes, conditionals);
|
||||
|
||||
DiscreteKeys keys = modes;
|
||||
keys.push_back({M(3), 2});
|
||||
{
|
||||
for (size_t i = 0; i < 8; i++) {
|
||||
std::vector<double> potentials{0, 0, 0, 0, 0, 0, 0, 0};
|
||||
potentials[i] = 1;
|
||||
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
||||
// Prune the HybridGaussianConditional
|
||||
const auto pruned = hgc.prune(decisionTreeFactor);
|
||||
// Check that the pruned HybridGaussianConditional has 1 conditional
|
||||
EXPECT_LONGS_EQUAL(1, pruned->nrComponents());
|
||||
}
|
||||
}
|
||||
{
|
||||
const std::vector<double> potentials{0, 0, 0.5, 0, //
|
||||
0, 0, 0.5, 0};
|
||||
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
||||
|
||||
const auto pruned = hgc.prune(decisionTreeFactor);
|
||||
|
||||
// Check that the pruned HybridGaussianConditional has 2 conditionals
|
||||
EXPECT_LONGS_EQUAL(2, pruned->nrComponents());
|
||||
}
|
||||
{
|
||||
const std::vector<double> potentials{0.2, 0, 0.3, 0, //
|
||||
0, 0, 0.5, 0};
|
||||
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
||||
|
||||
const auto pruned = hgc.prune(decisionTreeFactor);
|
||||
|
||||
// Check that the pruned HybridGaussianConditional has 3 conditionals
|
||||
EXPECT_LONGS_EQUAL(3, pruned->nrComponents());
|
||||
}
|
||||
}
|
||||
|
||||
/* *************************************************************************
|
||||
*/
|
||||
int main() {
|
||||
TestResult tr;
|
||||
return TestRegistry::runAllTests(tr);
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
/* *************************************************************************
|
||||
*/
|
||||
|
|
|
|||
Loading…
Reference in New Issue