fix hybrid tests

release/4.3a0
Varun Agrawal 2024-12-31 10:34:03 -05:00
parent 5019153e12
commit 623bd63ec8
3 changed files with 8 additions and 6 deletions

View File

@ -454,7 +454,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
}
size_t maxNrLeaves = 3;
auto prunedDecisionTree = joint.prune(maxNrLeaves);
auto prunedDecisionTree = *joint.prune(maxNrLeaves);
#ifdef GTSAM_DT_MERGING
EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,

View File

@ -261,7 +261,8 @@ TEST(HybridGaussianConditional, Prune) {
potentials[i] = 1;
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
// Prune the HybridGaussianConditional
const auto pruned = hgc.prune(decisionTreeFactor);
const auto pruned = hgc.prune(std::make_shared<DiscreteConditional>(
keys.size(), decisionTreeFactor));
// Check that the pruned HybridGaussianConditional has 1 conditional
EXPECT_LONGS_EQUAL(1, pruned->nrComponents());
}
@ -271,7 +272,8 @@ TEST(HybridGaussianConditional, Prune) {
0, 0, 0.5, 0};
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
const auto pruned = hgc.prune(decisionTreeFactor);
const auto pruned = hgc.prune(
std::make_shared<DiscreteConditional>(keys.size(), decisionTreeFactor));
// Check that the pruned HybridGaussianConditional has 2 conditionals
EXPECT_LONGS_EQUAL(2, pruned->nrComponents());
@ -286,7 +288,8 @@ TEST(HybridGaussianConditional, Prune) {
0, 0, 0.5, 0};
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
const auto pruned = hgc.prune(decisionTreeFactor);
const auto pruned = hgc.prune(
std::make_shared<DiscreteConditional>(keys.size(), decisionTreeFactor));
// Check that the pruned HybridGaussianConditional has 3 conditionals
EXPECT_LONGS_EQUAL(3, pruned->nrComponents());

View File

@ -368,10 +368,9 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) {
EXPECT_LONGS_EQUAL(1, hybridGaussianConditional->nrParents());
// This is now a discreteFactor
auto discreteFactor = dynamic_pointer_cast<DecisionTreeFactor>(factorOnModes);
auto discreteFactor = dynamic_pointer_cast<TableFactor>(factorOnModes);
CHECK(discreteFactor);
EXPECT_LONGS_EQUAL(1, discreteFactor->discreteKeys().size());
EXPECT(discreteFactor->root_->isLeaf() == false);
}
/****************************************************************************