prototype choose
parent
98cdf1193f
commit
d17215c69b
|
@ -238,22 +238,27 @@ TEST(HybridGaussianConditional, Likelihood2) {
|
||||||
EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8);
|
EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
namespace two_mode_measurement {
|
||||||
|
// Create a two key conditional:
|
||||||
|
const DiscreteKeys modes{{M(1), 2}, {M(2), 2}};
|
||||||
|
const std::vector<GaussianConditional::shared_ptr> gcs = {
|
||||||
|
GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(1), 1),
|
||||||
|
GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(2), 2),
|
||||||
|
GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(3), 3),
|
||||||
|
GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(4), 4)};
|
||||||
|
const HybridGaussianConditional::Conditionals conditionals(modes, gcs);
|
||||||
|
const auto hgc =
|
||||||
|
std::make_shared<HybridGaussianConditional>(modes, conditionals);
|
||||||
|
} // namespace two_mode_measurement
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// Test pruning a HybridGaussianConditional with two discrete keys, based on a
|
// Test pruning a HybridGaussianConditional with two discrete keys, based on a
|
||||||
// DecisionTreeFactor with 3 keys:
|
// DecisionTreeFactor with 3 keys:
|
||||||
TEST(HybridGaussianConditional, Prune) {
|
TEST(HybridGaussianConditional, Prune) {
|
||||||
// Create a two key conditional:
|
using two_mode_measurement::hgc;
|
||||||
DiscreteKeys modes{{M(1), 2}, {M(2), 2}};
|
|
||||||
std::vector<GaussianConditional::shared_ptr> gcs;
|
|
||||||
for (size_t i = 0; i < 4; i++) {
|
|
||||||
gcs.push_back(
|
|
||||||
GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(i + 1), i + 1));
|
|
||||||
}
|
|
||||||
auto empty = std::make_shared<GaussianConditional>();
|
|
||||||
HybridGaussianConditional::Conditionals conditionals(modes, gcs);
|
|
||||||
HybridGaussianConditional hgc(modes, conditionals);
|
|
||||||
|
|
||||||
DiscreteKeys keys = modes;
|
DiscreteKeys keys = two_mode_measurement::modes;
|
||||||
keys.push_back({M(3), 2});
|
keys.push_back({M(3), 2});
|
||||||
{
|
{
|
||||||
for (size_t i = 0; i < 8; i++) {
|
for (size_t i = 0; i < 8; i++) {
|
||||||
|
@ -262,7 +267,7 @@ TEST(HybridGaussianConditional, Prune) {
|
||||||
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
||||||
// Prune the HybridGaussianConditional
|
// Prune the HybridGaussianConditional
|
||||||
const auto pruned =
|
const auto pruned =
|
||||||
hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
|
hgc->prune(DiscreteConditional(keys.size(), decisionTreeFactor));
|
||||||
// Check that the pruned HybridGaussianConditional has 1 conditional
|
// Check that the pruned HybridGaussianConditional has 1 conditional
|
||||||
EXPECT_LONGS_EQUAL(1, pruned->nrComponents());
|
EXPECT_LONGS_EQUAL(1, pruned->nrComponents());
|
||||||
}
|
}
|
||||||
|
@ -273,14 +278,14 @@ TEST(HybridGaussianConditional, Prune) {
|
||||||
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
||||||
|
|
||||||
const auto pruned =
|
const auto pruned =
|
||||||
hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
|
hgc->prune(DiscreteConditional(keys.size(), decisionTreeFactor));
|
||||||
|
|
||||||
// Check that the pruned HybridGaussianConditional has 2 conditionals
|
// Check that the pruned HybridGaussianConditional has 2 conditionals
|
||||||
EXPECT_LONGS_EQUAL(2, pruned->nrComponents());
|
EXPECT_LONGS_EQUAL(2, pruned->nrComponents());
|
||||||
|
|
||||||
// Check that the minimum negLogConstant is set correctly
|
// Check that the minimum negLogConstant is set correctly
|
||||||
EXPECT_DOUBLES_EQUAL(
|
EXPECT_DOUBLES_EQUAL(
|
||||||
hgc.conditionals()({{M(1), 0}, {M(2), 1}})->negLogConstant(),
|
hgc->conditionals()({{M(1), 0}, {M(2), 1}})->negLogConstant(),
|
||||||
pruned->negLogConstant(), 1e-9);
|
pruned->negLogConstant(), 1e-9);
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
|
@ -289,18 +294,74 @@ TEST(HybridGaussianConditional, Prune) {
|
||||||
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
||||||
|
|
||||||
const auto pruned =
|
const auto pruned =
|
||||||
hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
|
hgc->prune(DiscreteConditional(keys.size(), decisionTreeFactor));
|
||||||
|
|
||||||
// Check that the pruned HybridGaussianConditional has 3 conditionals
|
// Check that the pruned HybridGaussianConditional has 3 conditionals
|
||||||
EXPECT_LONGS_EQUAL(3, pruned->nrComponents());
|
EXPECT_LONGS_EQUAL(3, pruned->nrComponents());
|
||||||
|
|
||||||
// Check that the minimum negLogConstant is correct
|
// Check that the minimum negLogConstant is correct
|
||||||
EXPECT_DOUBLES_EQUAL(hgc.negLogConstant(), pruned->negLogConstant(), 1e-9);
|
EXPECT_DOUBLES_EQUAL(hgc->negLogConstant(), pruned->negLogConstant(), 1e-9);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *************************************************************************
|
/* ************************************************************************* */
|
||||||
|
|
||||||
|
#include <gtsam/hybrid/HybridConditional.h>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Return a HybridConditional by choosing branches based on the given discrete
|
||||||
|
* values. If all discrete parents are specified, return a HybridConditional
|
||||||
|
* which is just a GaussianConditional.
|
||||||
*/
|
*/
|
||||||
|
HybridConditional::shared_ptr choose(
|
||||||
|
const HybridGaussianConditional::shared_ptr &self,
|
||||||
|
const DiscreteValues &discreteValues) {
|
||||||
|
const auto &discreteParents = self->discreteKeys();
|
||||||
|
DiscreteValues deadParentValues;
|
||||||
|
DiscreteKeys liveParents;
|
||||||
|
for (const auto &key : discreteParents) {
|
||||||
|
auto it = discreteValues.find(key.first);
|
||||||
|
if (it != discreteValues.end())
|
||||||
|
deadParentValues[key.first] = it->second;
|
||||||
|
else
|
||||||
|
liveParents.emplace_back(key);
|
||||||
|
}
|
||||||
|
// If so then we just get the corresponding Gaussian conditional:
|
||||||
|
if (deadParentValues.size() == discreteParents.size()) {
|
||||||
|
// print on how many discreteParents we are choosing:
|
||||||
|
return std::make_shared<HybridConditional>(self->choose(deadParentValues));
|
||||||
|
} else if (liveParents.size() > 0) {
|
||||||
|
auto newTree = self->factors();
|
||||||
|
for (auto &&[key, value] : discreteValues) {
|
||||||
|
newTree = newTree.choose(key, value);
|
||||||
|
}
|
||||||
|
return std::make_shared<HybridConditional>(
|
||||||
|
std::make_shared<HybridGaussianConditional>(liveParents, newTree));
|
||||||
|
} else {
|
||||||
|
// Add as-is
|
||||||
|
return std::make_shared<HybridConditional>(self);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Test the pruning and dead-mode removal.
|
||||||
|
TEST(HybridGaussianConditional, PrunePlus) {
|
||||||
|
using two_mode_measurement::hgc; // two discrete parents
|
||||||
|
|
||||||
|
const HybridConditional::shared_ptr same = choose(hgc, {});
|
||||||
|
EXPECT(same->isHybrid());
|
||||||
|
EXPECT(same->asHybrid()->nrComponents() == 4);
|
||||||
|
|
||||||
|
const HybridConditional::shared_ptr oneParent = choose(hgc, {{M(1), 0}});
|
||||||
|
EXPECT(oneParent->isHybrid());
|
||||||
|
EXPECT(oneParent->asHybrid()->nrComponents() == 2);
|
||||||
|
|
||||||
|
const HybridConditional::shared_ptr gaussian =
|
||||||
|
choose(hgc, {{M(1), 0}, {M(2), 1}});
|
||||||
|
EXPECT(gaussian->asGaussian());
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
return TestRegistry::runAllTests(tr);
|
return TestRegistry::runAllTests(tr);
|
||||||
|
|
Loading…
Reference in New Issue