prototype choose

release/4.3a0
Frank Dellaert 2025-01-29 21:36:51 -05:00
parent 98cdf1193f
commit d17215c69b
1 changed files with 78 additions and 17 deletions

View File

@ -238,22 +238,27 @@ TEST(HybridGaussianConditional, Likelihood2) {
EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8); EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8);
} }
/* ************************************************************************* */
namespace two_mode_measurement {
// Create a two key conditional:
const DiscreteKeys modes{{M(1), 2}, {M(2), 2}};
const std::vector<GaussianConditional::shared_ptr> gcs = {
GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(1), 1),
GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(2), 2),
GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(3), 3),
GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(4), 4)};
const HybridGaussianConditional::Conditionals conditionals(modes, gcs);
const auto hgc =
std::make_shared<HybridGaussianConditional>(modes, conditionals);
} // namespace two_mode_measurement
/* ************************************************************************* */ /* ************************************************************************* */
// Test pruning a HybridGaussianConditional with two discrete keys, based on a // Test pruning a HybridGaussianConditional with two discrete keys, based on a
// DecisionTreeFactor with 3 keys: // DecisionTreeFactor with 3 keys:
TEST(HybridGaussianConditional, Prune) { TEST(HybridGaussianConditional, Prune) {
// Create a two key conditional: using two_mode_measurement::hgc;
DiscreteKeys modes{{M(1), 2}, {M(2), 2}};
std::vector<GaussianConditional::shared_ptr> gcs;
for (size_t i = 0; i < 4; i++) {
gcs.push_back(
GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(i + 1), i + 1));
}
auto empty = std::make_shared<GaussianConditional>();
HybridGaussianConditional::Conditionals conditionals(modes, gcs);
HybridGaussianConditional hgc(modes, conditionals);
DiscreteKeys keys = modes; DiscreteKeys keys = two_mode_measurement::modes;
keys.push_back({M(3), 2}); keys.push_back({M(3), 2});
{ {
for (size_t i = 0; i < 8; i++) { for (size_t i = 0; i < 8; i++) {
@ -262,7 +267,7 @@ TEST(HybridGaussianConditional, Prune) {
const DecisionTreeFactor decisionTreeFactor(keys, potentials); const DecisionTreeFactor decisionTreeFactor(keys, potentials);
// Prune the HybridGaussianConditional // Prune the HybridGaussianConditional
const auto pruned = const auto pruned =
hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor)); hgc->prune(DiscreteConditional(keys.size(), decisionTreeFactor));
// Check that the pruned HybridGaussianConditional has 1 conditional // Check that the pruned HybridGaussianConditional has 1 conditional
EXPECT_LONGS_EQUAL(1, pruned->nrComponents()); EXPECT_LONGS_EQUAL(1, pruned->nrComponents());
} }
@ -273,14 +278,14 @@ TEST(HybridGaussianConditional, Prune) {
const DecisionTreeFactor decisionTreeFactor(keys, potentials); const DecisionTreeFactor decisionTreeFactor(keys, potentials);
const auto pruned = const auto pruned =
hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor)); hgc->prune(DiscreteConditional(keys.size(), decisionTreeFactor));
// Check that the pruned HybridGaussianConditional has 2 conditionals // Check that the pruned HybridGaussianConditional has 2 conditionals
EXPECT_LONGS_EQUAL(2, pruned->nrComponents()); EXPECT_LONGS_EQUAL(2, pruned->nrComponents());
// Check that the minimum negLogConstant is set correctly // Check that the minimum negLogConstant is set correctly
EXPECT_DOUBLES_EQUAL( EXPECT_DOUBLES_EQUAL(
hgc.conditionals()({{M(1), 0}, {M(2), 1}})->negLogConstant(), hgc->conditionals()({{M(1), 0}, {M(2), 1}})->negLogConstant(),
pruned->negLogConstant(), 1e-9); pruned->negLogConstant(), 1e-9);
} }
{ {
@ -289,18 +294,74 @@ TEST(HybridGaussianConditional, Prune) {
const DecisionTreeFactor decisionTreeFactor(keys, potentials); const DecisionTreeFactor decisionTreeFactor(keys, potentials);
const auto pruned = const auto pruned =
hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor)); hgc->prune(DiscreteConditional(keys.size(), decisionTreeFactor));
// Check that the pruned HybridGaussianConditional has 3 conditionals // Check that the pruned HybridGaussianConditional has 3 conditionals
EXPECT_LONGS_EQUAL(3, pruned->nrComponents()); EXPECT_LONGS_EQUAL(3, pruned->nrComponents());
// Check that the minimum negLogConstant is correct // Check that the minimum negLogConstant is correct
EXPECT_DOUBLES_EQUAL(hgc.negLogConstant(), pruned->negLogConstant(), 1e-9); EXPECT_DOUBLES_EQUAL(hgc->negLogConstant(), pruned->negLogConstant(), 1e-9);
} }
} }
/* ************************************************************************* /* ************************************************************************* */
#include <gtsam/hybrid/HybridConditional.h>
/**
* Return a HybridConditional by choosing branches based on the given discrete
* values. If all discrete parents are specified, return a HybridConditional
* which is just a GaussianConditional.
*/ */
HybridConditional::shared_ptr choose(
const HybridGaussianConditional::shared_ptr &self,
const DiscreteValues &discreteValues) {
const auto &discreteParents = self->discreteKeys();
DiscreteValues deadParentValues;
DiscreteKeys liveParents;
for (const auto &key : discreteParents) {
auto it = discreteValues.find(key.first);
if (it != discreteValues.end())
deadParentValues[key.first] = it->second;
else
liveParents.emplace_back(key);
}
// If so then we just get the corresponding Gaussian conditional:
if (deadParentValues.size() == discreteParents.size()) {
// print on how many discreteParents we are choosing:
return std::make_shared<HybridConditional>(self->choose(deadParentValues));
} else if (liveParents.size() > 0) {
auto newTree = self->factors();
for (auto &&[key, value] : discreteValues) {
newTree = newTree.choose(key, value);
}
return std::make_shared<HybridConditional>(
std::make_shared<HybridGaussianConditional>(liveParents, newTree));
} else {
// Add as-is
return std::make_shared<HybridConditional>(self);
}
}
/* ************************************************************************* */
// Test the pruning and dead-mode removal.
TEST(HybridGaussianConditional, PrunePlus) {
using two_mode_measurement::hgc; // two discrete parents
const HybridConditional::shared_ptr same = choose(hgc, {});
EXPECT(same->isHybrid());
EXPECT(same->asHybrid()->nrComponents() == 4);
const HybridConditional::shared_ptr oneParent = choose(hgc, {{M(1), 0}});
EXPECT(oneParent->isHybrid());
EXPECT(oneParent->asHybrid()->nrComponents() == 2);
const HybridConditional::shared_ptr gaussian =
choose(hgc, {{M(1), 0}, {M(2), 1}});
EXPECT(gaussian->asGaussian());
}
/* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;
return TestRegistry::runAllTests(tr); return TestRegistry::runAllTests(tr);