Use two new methods
parent
803dae75f3
commit
8746b15a4a
|
@ -308,6 +308,14 @@ TEST(HybridGaussianConditional, Prune) {
|
||||||
|
|
||||||
#include <gtsam/hybrid/HybridConditional.h>
|
#include <gtsam/hybrid/HybridConditional.h>
|
||||||
|
|
||||||
|
// Helper function to apply discrete values to the tree
|
||||||
|
auto choose(auto tree, const DiscreteValues &discreteValues) {
|
||||||
|
for (const auto &[key, value] : discreteValues) {
|
||||||
|
tree = tree.choose(key, value);
|
||||||
|
}
|
||||||
|
return tree;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Return a HybridConditional by choosing branches based on the given discrete
|
* Return a HybridConditional by choosing branches based on the given discrete
|
||||||
* values. If all discrete parents are specified, return a HybridConditional
|
* values. If all discrete parents are specified, return a HybridConditional
|
||||||
|
@ -316,31 +324,27 @@ TEST(HybridGaussianConditional, Prune) {
|
||||||
HybridConditional::shared_ptr choose(
|
HybridConditional::shared_ptr choose(
|
||||||
const HybridGaussianConditional::shared_ptr &self,
|
const HybridGaussianConditional::shared_ptr &self,
|
||||||
const DiscreteValues &discreteValues) {
|
const DiscreteValues &discreteValues) {
|
||||||
const auto &discreteParents = self->discreteKeys();
|
auto parentValues = discreteValues.filter(self->discreteKeys());
|
||||||
DiscreteValues deadParentValues;
|
auto unspecifiedParentKeys = discreteValues.missingKeys(self->discreteKeys());
|
||||||
DiscreteKeys liveParents;
|
|
||||||
for (const auto &key : discreteParents) {
|
// Case 1: Fully determined, return corresponding Gaussian conditional
|
||||||
auto it = discreteValues.find(key.first);
|
if (parentValues.size() == self->discreteKeys().size()) {
|
||||||
if (it != discreteValues.end())
|
return std::make_shared<HybridConditional>(self->choose(parentValues));
|
||||||
deadParentValues[key.first] = it->second;
|
|
||||||
else
|
|
||||||
liveParents.emplace_back(key);
|
|
||||||
}
|
}
|
||||||
// If so then we just get the corresponding Gaussian conditional:
|
|
||||||
if (deadParentValues.size() == discreteParents.size()) {
|
// Case 2: Some live parents remain, build a new tree
|
||||||
// print on how many discreteParents we are choosing:
|
if (!unspecifiedParentKeys.empty()) {
|
||||||
return std::make_shared<HybridConditional>(self->choose(deadParentValues));
|
|
||||||
} else if (liveParents.size() > 0) {
|
|
||||||
auto newTree = self->factors();
|
auto newTree = self->factors();
|
||||||
for (auto &&[key, value] : discreteValues) {
|
for (const auto &[key, value] : parentValues) {
|
||||||
newTree = newTree.choose(key, value);
|
newTree = newTree.choose(key, value);
|
||||||
}
|
}
|
||||||
return std::make_shared<HybridConditional>(
|
return std::make_shared<HybridConditional>(
|
||||||
std::make_shared<HybridGaussianConditional>(liveParents, newTree));
|
std::make_shared<HybridGaussianConditional>(unspecifiedParentKeys,
|
||||||
} else {
|
newTree));
|
||||||
// Add as-is
|
|
||||||
return std::make_shared<HybridConditional>(self);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Case 3: No changes needed, return original
|
||||||
|
return std::make_shared<HybridConditional>(self);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -356,6 +360,11 @@ TEST(HybridGaussianConditional, PrunePlus) {
|
||||||
EXPECT(oneParent->isHybrid());
|
EXPECT(oneParent->isHybrid());
|
||||||
EXPECT(oneParent->asHybrid()->nrComponents() == 2);
|
EXPECT(oneParent->asHybrid()->nrComponents() == 2);
|
||||||
|
|
||||||
|
const HybridConditional::shared_ptr oneParent2 =
|
||||||
|
choose(hgc, {{M(7), 0}, {M(1), 0}});
|
||||||
|
EXPECT(oneParent2->isHybrid());
|
||||||
|
EXPECT(oneParent2->asHybrid()->nrComponents() == 2);
|
||||||
|
|
||||||
const HybridConditional::shared_ptr gaussian =
|
const HybridConditional::shared_ptr gaussian =
|
||||||
choose(hgc, {{M(1), 0}, {M(2), 1}});
|
choose(hgc, {{M(1), 0}, {M(2), 1}});
|
||||||
EXPECT(gaussian->asGaussian());
|
EXPECT(gaussian->asGaussian());
|
||||||
|
|
Loading…
Reference in New Issue