Use two new methods

release/4.3a0
Frank Dellaert 2025-01-29 22:08:12 -05:00
parent 803dae75f3
commit 8746b15a4a
1 changed files with 28 additions and 19 deletions

View File

@ -308,6 +308,14 @@ TEST(HybridGaussianConditional, Prune) {
#include <gtsam/hybrid/HybridConditional.h> #include <gtsam/hybrid/HybridConditional.h>
// Helper function to apply discrete values to the tree
auto choose(auto tree, const DiscreteValues &discreteValues) {
for (const auto &[key, value] : discreteValues) {
tree = tree.choose(key, value);
}
return tree;
}
/** /**
* Return a HybridConditional by choosing branches based on the given discrete * Return a HybridConditional by choosing branches based on the given discrete
* values. If all discrete parents are specified, return a HybridConditional * values. If all discrete parents are specified, return a HybridConditional
@ -316,31 +324,27 @@ TEST(HybridGaussianConditional, Prune) {
HybridConditional::shared_ptr choose( HybridConditional::shared_ptr choose(
const HybridGaussianConditional::shared_ptr &self, const HybridGaussianConditional::shared_ptr &self,
const DiscreteValues &discreteValues) { const DiscreteValues &discreteValues) {
const auto &discreteParents = self->discreteKeys(); auto parentValues = discreteValues.filter(self->discreteKeys());
DiscreteValues deadParentValues; auto unspecifiedParentKeys = discreteValues.missingKeys(self->discreteKeys());
DiscreteKeys liveParents;
for (const auto &key : discreteParents) { // Case 1: Fully determined, return corresponding Gaussian conditional
auto it = discreteValues.find(key.first); if (parentValues.size() == self->discreteKeys().size()) {
if (it != discreteValues.end()) return std::make_shared<HybridConditional>(self->choose(parentValues));
deadParentValues[key.first] = it->second;
else
liveParents.emplace_back(key);
} }
// If so then we just get the corresponding Gaussian conditional:
if (deadParentValues.size() == discreteParents.size()) { // Case 2: Some live parents remain, build a new tree
// print on how many discreteParents we are choosing: if (!unspecifiedParentKeys.empty()) {
return std::make_shared<HybridConditional>(self->choose(deadParentValues));
} else if (liveParents.size() > 0) {
auto newTree = self->factors(); auto newTree = self->factors();
for (auto &&[key, value] : discreteValues) { for (const auto &[key, value] : parentValues) {
newTree = newTree.choose(key, value); newTree = newTree.choose(key, value);
} }
return std::make_shared<HybridConditional>( return std::make_shared<HybridConditional>(
std::make_shared<HybridGaussianConditional>(liveParents, newTree)); std::make_shared<HybridGaussianConditional>(unspecifiedParentKeys,
} else { newTree));
// Add as-is
return std::make_shared<HybridConditional>(self);
} }
// Case 3: No changes needed, return original
return std::make_shared<HybridConditional>(self);
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -356,6 +360,11 @@ TEST(HybridGaussianConditional, PrunePlus) {
EXPECT(oneParent->isHybrid()); EXPECT(oneParent->isHybrid());
EXPECT(oneParent->asHybrid()->nrComponents() == 2); EXPECT(oneParent->asHybrid()->nrComponents() == 2);
const HybridConditional::shared_ptr oneParent2 =
choose(hgc, {{M(7), 0}, {M(1), 0}});
EXPECT(oneParent2->isHybrid());
EXPECT(oneParent2->asHybrid()->nrComponents() == 2);
const HybridConditional::shared_ptr gaussian = const HybridConditional::shared_ptr gaussian =
choose(hgc, {{M(1), 0}, {M(2), 1}}); choose(hgc, {{M(1), 0}, {M(2), 1}});
EXPECT(gaussian->asGaussian()); EXPECT(gaussian->asGaussian());