prototype choose
parent
98cdf1193f
commit
d17215c69b
|
@ -238,22 +238,27 @@ TEST(HybridGaussianConditional, Likelihood2) {
|
|||
EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
namespace two_mode_measurement {
|
||||
// Create a two key conditional:
|
||||
const DiscreteKeys modes{{M(1), 2}, {M(2), 2}};
|
||||
const std::vector<GaussianConditional::shared_ptr> gcs = {
|
||||
GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(1), 1),
|
||||
GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(2), 2),
|
||||
GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(3), 3),
|
||||
GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(4), 4)};
|
||||
const HybridGaussianConditional::Conditionals conditionals(modes, gcs);
|
||||
const auto hgc =
|
||||
std::make_shared<HybridGaussianConditional>(modes, conditionals);
|
||||
} // namespace two_mode_measurement
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Test pruning a HybridGaussianConditional with two discrete keys, based on a
|
||||
// DecisionTreeFactor with 3 keys:
|
||||
TEST(HybridGaussianConditional, Prune) {
|
||||
// Create a two key conditional:
|
||||
DiscreteKeys modes{{M(1), 2}, {M(2), 2}};
|
||||
std::vector<GaussianConditional::shared_ptr> gcs;
|
||||
for (size_t i = 0; i < 4; i++) {
|
||||
gcs.push_back(
|
||||
GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(i + 1), i + 1));
|
||||
}
|
||||
auto empty = std::make_shared<GaussianConditional>();
|
||||
HybridGaussianConditional::Conditionals conditionals(modes, gcs);
|
||||
HybridGaussianConditional hgc(modes, conditionals);
|
||||
using two_mode_measurement::hgc;
|
||||
|
||||
DiscreteKeys keys = modes;
|
||||
DiscreteKeys keys = two_mode_measurement::modes;
|
||||
keys.push_back({M(3), 2});
|
||||
{
|
||||
for (size_t i = 0; i < 8; i++) {
|
||||
|
@ -262,7 +267,7 @@ TEST(HybridGaussianConditional, Prune) {
|
|||
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
||||
// Prune the HybridGaussianConditional
|
||||
const auto pruned =
|
||||
hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
|
||||
hgc->prune(DiscreteConditional(keys.size(), decisionTreeFactor));
|
||||
// Check that the pruned HybridGaussianConditional has 1 conditional
|
||||
EXPECT_LONGS_EQUAL(1, pruned->nrComponents());
|
||||
}
|
||||
|
@ -273,14 +278,14 @@ TEST(HybridGaussianConditional, Prune) {
|
|||
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
||||
|
||||
const auto pruned =
|
||||
hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
|
||||
hgc->prune(DiscreteConditional(keys.size(), decisionTreeFactor));
|
||||
|
||||
// Check that the pruned HybridGaussianConditional has 2 conditionals
|
||||
EXPECT_LONGS_EQUAL(2, pruned->nrComponents());
|
||||
|
||||
// Check that the minimum negLogConstant is set correctly
|
||||
EXPECT_DOUBLES_EQUAL(
|
||||
hgc.conditionals()({{M(1), 0}, {M(2), 1}})->negLogConstant(),
|
||||
hgc->conditionals()({{M(1), 0}, {M(2), 1}})->negLogConstant(),
|
||||
pruned->negLogConstant(), 1e-9);
|
||||
}
|
||||
{
|
||||
|
@ -289,18 +294,74 @@ TEST(HybridGaussianConditional, Prune) {
|
|||
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
||||
|
||||
const auto pruned =
|
||||
hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
|
||||
hgc->prune(DiscreteConditional(keys.size(), decisionTreeFactor));
|
||||
|
||||
// Check that the pruned HybridGaussianConditional has 3 conditionals
|
||||
EXPECT_LONGS_EQUAL(3, pruned->nrComponents());
|
||||
|
||||
// Check that the minimum negLogConstant is correct
|
||||
EXPECT_DOUBLES_EQUAL(hgc.negLogConstant(), pruned->negLogConstant(), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(hgc->negLogConstant(), pruned->negLogConstant(), 1e-9);
|
||||
}
|
||||
}
|
||||
|
||||
/* *************************************************************************
|
||||
/* ************************************************************************* */
|
||||
|
||||
#include <gtsam/hybrid/HybridConditional.h>
|
||||
|
||||
/**
|
||||
* Return a HybridConditional by choosing branches based on the given discrete
|
||||
* values. If all discrete parents are specified, return a HybridConditional
|
||||
* which is just a GaussianConditional.
|
||||
*/
|
||||
HybridConditional::shared_ptr choose(
|
||||
const HybridGaussianConditional::shared_ptr &self,
|
||||
const DiscreteValues &discreteValues) {
|
||||
const auto &discreteParents = self->discreteKeys();
|
||||
DiscreteValues deadParentValues;
|
||||
DiscreteKeys liveParents;
|
||||
for (const auto &key : discreteParents) {
|
||||
auto it = discreteValues.find(key.first);
|
||||
if (it != discreteValues.end())
|
||||
deadParentValues[key.first] = it->second;
|
||||
else
|
||||
liveParents.emplace_back(key);
|
||||
}
|
||||
// If so then we just get the corresponding Gaussian conditional:
|
||||
if (deadParentValues.size() == discreteParents.size()) {
|
||||
// print on how many discreteParents we are choosing:
|
||||
return std::make_shared<HybridConditional>(self->choose(deadParentValues));
|
||||
} else if (liveParents.size() > 0) {
|
||||
auto newTree = self->factors();
|
||||
for (auto &&[key, value] : discreteValues) {
|
||||
newTree = newTree.choose(key, value);
|
||||
}
|
||||
return std::make_shared<HybridConditional>(
|
||||
std::make_shared<HybridGaussianConditional>(liveParents, newTree));
|
||||
} else {
|
||||
// Add as-is
|
||||
return std::make_shared<HybridConditional>(self);
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Test the pruning and dead-mode removal.
|
||||
TEST(HybridGaussianConditional, PrunePlus) {
|
||||
using two_mode_measurement::hgc; // two discrete parents
|
||||
|
||||
const HybridConditional::shared_ptr same = choose(hgc, {});
|
||||
EXPECT(same->isHybrid());
|
||||
EXPECT(same->asHybrid()->nrComponents() == 4);
|
||||
|
||||
const HybridConditional::shared_ptr oneParent = choose(hgc, {{M(1), 0}});
|
||||
EXPECT(oneParent->isHybrid());
|
||||
EXPECT(oneParent->asHybrid()->nrComponents() == 2);
|
||||
|
||||
const HybridConditional::shared_ptr gaussian =
|
||||
choose(hgc, {{M(1), 0}, {M(2), 1}});
|
||||
EXPECT(gaussian->asGaussian());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
return TestRegistry::runAllTests(tr);
|
||||
|
|
Loading…
Reference in New Issue