restrict method
parent
8746b15a4a
commit
4d1a8e5057
|
@ -169,4 +169,41 @@ double HybridConditional::evaluate(const HybridValues &values) const {
|
||||||
return std::exp(logProbability(values));
|
return std::exp(logProbability(values));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
HybridConditional::shared_ptr HybridConditional::restrict(
|
||||||
|
const DiscreteValues &discreteValues) const {
|
||||||
|
if (auto gc = asGaussian()) {
|
||||||
|
return std::make_shared<HybridConditional>(gc);
|
||||||
|
} else if (auto dc = asDiscrete()) {
|
||||||
|
return std::make_shared<HybridConditional>(dc);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto hgc = asHybrid();
|
||||||
|
if (!hgc)
|
||||||
|
throw std::runtime_error(
|
||||||
|
"HybridConditional::restrict: conditional type not handled");
|
||||||
|
|
||||||
|
// Case 1: Fully determined, return corresponding Gaussian conditional
|
||||||
|
auto parentValues = discreteValues.filter(discreteKeys_);
|
||||||
|
if (parentValues.size() == discreteKeys_.size()) {
|
||||||
|
return std::make_shared<HybridConditional>(hgc->choose(parentValues));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Case 2: Some live parents remain, build a new tree
|
||||||
|
auto unspecifiedParentKeys = discreteValues.missingKeys(discreteKeys_);
|
||||||
|
if (!unspecifiedParentKeys.empty()) {
|
||||||
|
auto newTree = hgc->factors();
|
||||||
|
for (const auto &[key, value] : parentValues) {
|
||||||
|
newTree = newTree.choose(key, value);
|
||||||
|
}
|
||||||
|
return std::make_shared<HybridConditional>(
|
||||||
|
std::make_shared<HybridGaussianConditional>(unspecifiedParentKeys,
|
||||||
|
newTree));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Case 3: No changes needed, return original
|
||||||
|
return std::make_shared<HybridConditional>(hgc);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -215,6 +215,14 @@ class GTSAM_EXPORT HybridConditional
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Return a HybridConditional by choosing branches based on the given discrete
|
||||||
|
* values. If all discrete parents are specified, return a HybridConditional
|
||||||
|
* which is just a GaussianConditional. If this conditional is *not* a hybrid
|
||||||
|
* conditional, just return that.
|
||||||
|
*/
|
||||||
|
shared_ptr restrict(const DiscreteValues& discreteValues) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -316,57 +316,34 @@ auto choose(auto tree, const DiscreteValues &discreteValues) {
|
||||||
return tree;
|
return tree;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/* *************************************************************************
|
||||||
* Return a HybridConditional by choosing branches based on the given discrete
|
* This test verifies the behavior of the restrict method in different
|
||||||
* values. If all discrete parents are specified, return a HybridConditional
|
* scenarios:
|
||||||
* which is just a GaussianConditional.
|
* - When no restrictions are applied.
|
||||||
|
* - When one parent is restricted.
|
||||||
|
* - When two parents are restricted.
|
||||||
|
* - When the restriction results in a Gaussian conditional.
|
||||||
*/
|
*/
|
||||||
HybridConditional::shared_ptr choose(
|
TEST(HybridGaussianConditional, Restrict) {
|
||||||
const HybridGaussianConditional::shared_ptr &self,
|
// Create a HybridConditional with two discrete parents P(z0|m0,m1)
|
||||||
const DiscreteValues &discreteValues) {
|
const auto hc =
|
||||||
auto parentValues = discreteValues.filter(self->discreteKeys());
|
std::make_shared<HybridConditional>(two_mode_measurement::hgc);
|
||||||
auto unspecifiedParentKeys = discreteValues.missingKeys(self->discreteKeys());
|
|
||||||
|
|
||||||
// Case 1: Fully determined, return corresponding Gaussian conditional
|
const HybridConditional::shared_ptr same = hc->restrict({});
|
||||||
if (parentValues.size() == self->discreteKeys().size()) {
|
|
||||||
return std::make_shared<HybridConditional>(self->choose(parentValues));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Case 2: Some live parents remain, build a new tree
|
|
||||||
if (!unspecifiedParentKeys.empty()) {
|
|
||||||
auto newTree = self->factors();
|
|
||||||
for (const auto &[key, value] : parentValues) {
|
|
||||||
newTree = newTree.choose(key, value);
|
|
||||||
}
|
|
||||||
return std::make_shared<HybridConditional>(
|
|
||||||
std::make_shared<HybridGaussianConditional>(unspecifiedParentKeys,
|
|
||||||
newTree));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Case 3: No changes needed, return original
|
|
||||||
return std::make_shared<HybridConditional>(self);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
// Test the pruning and dead-mode removal.
|
|
||||||
TEST(HybridGaussianConditional, PrunePlus) {
|
|
||||||
using two_mode_measurement::hgc; // two discrete parents
|
|
||||||
|
|
||||||
const HybridConditional::shared_ptr same = choose(hgc, {});
|
|
||||||
EXPECT(same->isHybrid());
|
EXPECT(same->isHybrid());
|
||||||
EXPECT(same->asHybrid()->nrComponents() == 4);
|
EXPECT(same->asHybrid()->nrComponents() == 4);
|
||||||
|
|
||||||
const HybridConditional::shared_ptr oneParent = choose(hgc, {{M(1), 0}});
|
const HybridConditional::shared_ptr oneParent = hc->restrict({{M(1), 0}});
|
||||||
EXPECT(oneParent->isHybrid());
|
EXPECT(oneParent->isHybrid());
|
||||||
EXPECT(oneParent->asHybrid()->nrComponents() == 2);
|
EXPECT(oneParent->asHybrid()->nrComponents() == 2);
|
||||||
|
|
||||||
const HybridConditional::shared_ptr oneParent2 =
|
const HybridConditional::shared_ptr oneParent2 =
|
||||||
choose(hgc, {{M(7), 0}, {M(1), 0}});
|
hc->restrict({{M(7), 0}, {M(1), 0}});
|
||||||
EXPECT(oneParent2->isHybrid());
|
EXPECT(oneParent2->isHybrid());
|
||||||
EXPECT(oneParent2->asHybrid()->nrComponents() == 2);
|
EXPECT(oneParent2->asHybrid()->nrComponents() == 2);
|
||||||
|
|
||||||
const HybridConditional::shared_ptr gaussian =
|
const HybridConditional::shared_ptr gaussian =
|
||||||
choose(hgc, {{M(1), 0}, {M(2), 1}});
|
hc->restrict({{M(1), 0}, {M(2), 1}});
|
||||||
EXPECT(gaussian->asGaussian());
|
EXPECT(gaussian->asGaussian());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue