Merge pull request #2013 from borglab/fix/pruning

Fix pruning
release/4.3a0
Frank Dellaert 2025-01-30 10:54:46 -05:00 committed by GitHub
commit 3cf15901c7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 325 additions and 156 deletions

View File

@ -18,6 +18,7 @@
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteMarginals.h>
#include <gtsam/inference/FactorGraph-inst.h>
namespace gtsam {
@ -68,6 +69,59 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
return result;
}
/* ************************************************************************* */
// The implementation is: build the entire joint into one factor and then prune.
// TODO(Frank): This can be quite expensive *unless* the factors have already
// been pruned before. Another, possibly faster approach is branch and bound
// search to find the K-best leaves and then create a single pruned conditional.
DiscreteBayesNet DiscreteBayesNet::prune(
size_t maxNrLeaves, const std::optional<double>& marginalThreshold,
DiscreteValues* fixedValues) const {
// Multiply into one big conditional. NOTE: possibly quite expensive.
DiscreteConditional joint;
for (const DiscreteConditional::shared_ptr& conditional : *this)
joint = joint * (*conditional);
// Prune the joint. NOTE: imperative and, again, possibly quite expensive.
DiscreteConditional pruned = joint;
pruned.prune(maxNrLeaves);
DiscreteValues deadModesValues;
// If we have a dead mode threshold and discrete variables left after pruning,
// then we run dead mode removal.
if (marginalThreshold.has_value() && pruned.keys().size() > 0) {
DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
for (auto dkey : pruned.discreteKeys()) {
const Vector probabilities = marginals.marginalProbabilities(dkey);
int index = -1;
auto threshold = (probabilities.array() > *marginalThreshold);
// If atleast 1 value is non-zero, then we can find the index
// Else if all are zero, index would be set to 0 which is incorrect
if (!threshold.isZero()) {
threshold.maxCoeff(&index);
}
if (index >= 0) {
deadModesValues.emplace(dkey.first, index);
}
}
// Remove the modes (imperative)
pruned.removeDiscreteModes(deadModesValues);
// Set the fixed values if requested.
if (fixedValues) {
*fixedValues = deadModesValues;
}
}
// Return the resulting DiscreteBayesNet.
DiscreteBayesNet result;
if (pruned.keys().size() > 0) result.push_back(pruned);
return result;
}
/* *********************************************************************** */
std::string DiscreteBayesNet::markdown(
const KeyFormatter& keyFormatter,

View File

@ -124,6 +124,18 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
*/
DiscreteValues sample(DiscreteValues given) const;
/**
* @brief Prune the Bayes net
*
* @param maxNrLeaves The maximum number of leaves to keep.
* @param marginalThreshold If given, threshold on marginals to prune variables.
* @param fixedValues If given, return the fixed values removed.
* @return A new DiscreteBayesNet with pruned conditionals.
*/
DiscreteBayesNet prune(size_t maxNrLeaves,
const std::optional<double>& marginalThreshold = {},
DiscreteValues* fixedValues = nullptr) const;
///@}
/// @name Wrapper support
/// @{

View File

@ -73,23 +73,77 @@ class GTSAM_EXPORT DiscreteValues : public Assignment<Key> {
}
/**
* Insert key-assignment pair.
* Throws an invalid_argument exception if
* any keys to be inserted are already used. */
* @brief Insert key-assignment pair.
*
* @param assignment The key-assignment pair to insert.
* @return DiscreteValues& Reference to the updated DiscreteValues object.
* @throws std::invalid_argument if any keys to be inserted are already used.
*/
DiscreteValues& insert(const std::pair<Key, size_t>& assignment);
/** Insert all values from \c values. Throws an invalid_argument exception if
* any keys to be inserted are already used. */
/**
* @brief Insert all values from another DiscreteValues object.
*
* @param values The DiscreteValues object containing values to insert.
* @return DiscreteValues& Reference to the updated DiscreteValues object.
* @throws std::invalid_argument if any keys to be inserted are already used.
*/
DiscreteValues& insert(const DiscreteValues& values);
/** For all key/value pairs in \c values, replace values with corresponding
* keys in this object with those in \c values. Throws std::out_of_range if
* any keys in \c values are not present in this object. */
/**
* @brief Update values with corresponding keys from another DiscreteValues
* object.
*
* @param values The DiscreteValues object containing values to update.
* @return DiscreteValues& Reference to the updated DiscreteValues object.
* @throws std::out_of_range if any keys in values are not present in this
* object.
*/
DiscreteValues& update(const DiscreteValues& values);
/**
* @brief Return a vector of DiscreteValues, one for each possible
* combination of values.
* @brief Check if the DiscreteValues contains the given key.
*
* @param key The key to check for.
* @return True if the key is present, false otherwise.
*/
bool contains(Key key) const { return this->find(key) != this->end(); }
/**
* @brief Filter values by keys.
*
* @param keys The keys to filter by.
* @return DiscreteValues The filtered DiscreteValues object.
*/
DiscreteValues filter(const DiscreteKeys& keys) const {
DiscreteValues result;
for (const auto& [key, _] : keys) {
if (auto it = this->find(key); it != this->end())
result[key] = it->second;
}
return result;
}
/**
* @brief Return the keys that are not present in the DiscreteValues object.
*
* @param keys The keys to check for.
* @return DiscreteKeys Keys not present in the DiscreteValues object.
*/
DiscreteKeys missingKeys(const DiscreteKeys& keys) const {
DiscreteKeys result;
for (const auto& [key, cardinality] : keys) {
if (!this->contains(key)) result.emplace_back(key, cardinality);
}
return result;
}
/**
* @brief Return a vector of DiscreteValues, one for each possible combination
* of values.
*
* @param keys The keys to generate the Cartesian product for.
* @return std::vector<DiscreteValues> The vector of DiscreteValues.
*/
static std::vector<DiscreteValues> CartesianProduct(
const DiscreteKeys& keys) {
@ -135,12 +189,14 @@ inline std::vector<DiscreteValues> cartesianProduct(const DiscreteKeys& keys) {
}
/// Free version of markdown.
std::string GTSAM_EXPORT markdown(const DiscreteValues& values,
std::string GTSAM_EXPORT
markdown(const DiscreteValues& values,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteValues::Names& names = {});
/// Free version of html.
std::string GTSAM_EXPORT html(const DiscreteValues& values,
std::string GTSAM_EXPORT
html(const DiscreteValues& values,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteValues::Names& names = {});

View File

@ -40,6 +40,24 @@ TEST(DiscreteValues, Update) {
DiscreteValues(kExample).update({{12, 2}})));
}
/* ************************************************************************* */
// Test DiscreteValues::filter
TEST(DiscreteValues, Filter) {
DiscreteValues values = {{12, 1}, {5, 0}, {13, 2}};
DiscreteKeys keys = {{12, 0}, {13, 0}, {99, 0}}; // 99 is missing in values
EXPECT(assert_equal(DiscreteValues({{12, 1}, {13, 2}}), values.filter(keys)));
}
/* ************************************************************************* */
// Test DiscreteValues::missingKeys
TEST(DiscreteValues, MissingKeys) {
DiscreteValues values = {{12, 1}, {5, 0}};
DiscreteKeys keys = {{12, 0}, {5, 0}, {99, 0}, {42, 0}}; // 99 and 42 are missing
EXPECT(assert_equal(DiscreteKeys({{99, 0}, {42, 0}}), values.missingKeys(keys)));
}
/* ************************************************************************* */
// Check markdown representation with a value formatter.
TEST(DiscreteValues, markdownWithValueFormatter) {

View File

@ -19,7 +19,6 @@
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteMarginals.h>
#include <gtsam/discrete/TableDistribution.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridValues.h>
@ -43,114 +42,59 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
}
/* ************************************************************************* */
// The implementation is: build the entire joint into one factor and then prune.
// TODO(Frank): This can be quite expensive *unless* the factors have already
// been pruned before. Another, possibly faster approach is branch and bound
// search to find the K-best leaves and then create a single pruned conditional.
HybridBayesNet HybridBayesNet::prune(
size_t maxNrLeaves, const std::optional<double> &deadModeThreshold) const {
size_t maxNrLeaves, const std::optional<double> &marginalThreshold,
DiscreteValues *fixedValues) const {
#if GTSAM_HYBRID_TIMING
gttic_(HybridPruning);
#endif
// Collect all the discrete conditionals. Could be small if already pruned.
const DiscreteBayesNet marginal = discreteMarginal();
// Prune discrete Bayes net
DiscreteValues fixed;
auto prunedBN = marginal.prune(maxNrLeaves, marginalThreshold, &fixed);
// Multiply into one big conditional. NOTE: possibly quite expensive.
DiscreteConditional joint;
for (auto &&conditional : marginal) {
joint = joint * (*conditional);
DiscreteConditional pruned;
for (auto &&conditional : prunedBN) pruned = pruned * (*conditional);
// Set the fixed values if requested.
if (marginalThreshold && fixedValues) {
*fixedValues = fixed;
}
// Initialize the resulting HybridBayesNet.
HybridBayesNet result;
// Prune the joint. NOTE: imperative and, again, possibly quite expensive.
DiscreteConditional pruned = joint;
pruned.prune(maxNrLeaves);
// Go through all the Gaussian conditionals, restrict them according to
// fixed values, and then prune further.
for (std::shared_ptr<gtsam::HybridConditional> conditional : *this) {
if (conditional->isDiscrete()) continue;
DiscreteValues deadModesValues;
// If we have a dead mode threshold and discrete variables left after pruning,
// then we run dead mode removal.
if (deadModeThreshold.has_value() && pruned.keys().size() > 0) {
DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
for (auto dkey : pruned.discreteKeys()) {
Vector probabilities = marginals.marginalProbabilities(dkey);
// No-op if not a HybridGaussianConditional.
if (marginalThreshold) conditional = conditional->restrict(fixed);
int index = -1;
auto threshold = (probabilities.array() > *deadModeThreshold);
// If atleast 1 value is non-zero, then we can find the index
// Else if all are zero, index would be set to 0 which is incorrect
if (!threshold.isZero()) {
threshold.maxCoeff(&index);
}
if (index >= 0) {
deadModesValues.emplace(dkey.first, index);
}
}
// Remove the modes (imperative)
pruned.removeDiscreteModes(deadModesValues);
/*
If the pruned discrete conditional has any keys left,
we add it to the HybridBayesNet.
If not, it means it is an orphan so we don't add this pruned joint,
and instead add only the marginals below.
*/
if (pruned.keys().size() > 0) {
result.emplace_shared<DiscreteConditional>(pruned);
}
// Add the marginals for future factors
for (auto &&[key, _] : deadModesValues) {
result.push_back(
std::dynamic_pointer_cast<DiscreteConditional>(marginals(key)));
}
} else {
result.emplace_shared<DiscreteConditional>(pruned);
}
/* To prune, we visitWith every leaf in the HybridGaussianConditional.
* For each leaf, using the assignment we can check the discrete decision tree
* for 0.0 probability, then just set the leaf to a nullptr.
*
* We can later check the HybridGaussianConditional for just nullptrs.
*/
// Go through all the Gaussian conditionals in the Bayes Net and prune them as
// per pruned discrete joint.
for (auto &&conditional : *this) {
// Now decide on type what to do:
if (auto hgc = conditional->asHybrid()) {
// Prune the hybrid Gaussian conditional!
auto prunedHybridGaussianConditional = hgc->prune(pruned);
if (deadModeThreshold.has_value()) {
KeyVector deadKeys, conditionalDiscreteKeys;
for (const auto &kv : deadModesValues) {
deadKeys.push_back(kv.first);
if (!prunedHybridGaussianConditional) {
throw std::runtime_error(
"A HybridGaussianConditional had all its conditionals pruned");
}
for (auto dkey : prunedHybridGaussianConditional->discreteKeys()) {
conditionalDiscreteKeys.push_back(dkey.first);
}
// The discrete keys in the conditional are the same as the keys in the
// dead modes, then we just get the corresponding Gaussian conditional.
if (deadKeys == conditionalDiscreteKeys) {
result.push_back(
prunedHybridGaussianConditional->choose(deadModesValues));
} else {
// Add as-is
result.push_back(prunedHybridGaussianConditional);
}
} else {
// Type-erase and add to the pruned Bayes Net fragment.
result.push_back(prunedHybridGaussianConditional);
}
} else if (auto gc = conditional->asGaussian()) {
// Add the non-HybridGaussianConditional conditional
result.push_back(gc);
} else
throw std::runtime_error(
"HybrdiBayesNet::prune: Unknown HybridConditional type.");
}
// We ignore DiscreteConditional as they are already pruned and added.
}
// Add the pruned discrete conditionals to the result.
for (const DiscreteConditional::shared_ptr &discrete : prunedBN)
result.push_back(discrete);
return result;
}

View File

@ -217,16 +217,18 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves.
*
* @param maxNrLeaves Continuous values at which to compute the error.
* @param deadModeThreshold The threshold to check the mode marginals against.
* If greater than this threshold, the mode gets assigned that value and is
* considered "dead" for hybrid elimination.
* The mode can then be removed since it only has a single possible
* assignment.
* @param marginalThreshold The threshold to check the mode marginals against.
* @param fixedValues The fixed values resulting from dead mode removal.
*
* @note If marginal greater than this threshold, the mode gets assigned that
* value and is considered "dead" for hybrid elimination. The mode can then be
* removed since it only has a single possible assignment.
* @return A pruned HybridBayesNet
*/
HybridBayesNet prune(
size_t maxNrLeaves,
const std::optional<double> &deadModeThreshold = {}) const;
HybridBayesNet prune(size_t maxNrLeaves,
const std::optional<double> &marginalThreshold = {},
DiscreteValues *fixedValues = nullptr) const;
/**
* @brief Error method using HybridValues which returns specific error for

View File

@ -169,4 +169,41 @@ double HybridConditional::evaluate(const HybridValues &values) const {
return std::exp(logProbability(values));
}
/* ************************************************************************ */
HybridConditional::shared_ptr HybridConditional::restrict(
const DiscreteValues &discreteValues) const {
if (auto gc = asGaussian()) {
return std::make_shared<HybridConditional>(gc);
} else if (auto dc = asDiscrete()) {
return std::make_shared<HybridConditional>(dc);
};
auto hgc = asHybrid();
if (!hgc)
throw std::runtime_error(
"HybridConditional::restrict: conditional type not handled");
// Case 1: Fully determined, return corresponding Gaussian conditional
auto parentValues = discreteValues.filter(discreteKeys_);
if (parentValues.size() == discreteKeys_.size()) {
return std::make_shared<HybridConditional>(hgc->choose(parentValues));
}
// Case 2: Some live parents remain, build a new tree
auto unspecifiedParentKeys = discreteValues.missingKeys(discreteKeys_);
if (!unspecifiedParentKeys.empty()) {
auto newTree = hgc->factors();
for (const auto &[key, value] : parentValues) {
newTree = newTree.choose(key, value);
}
return std::make_shared<HybridConditional>(
std::make_shared<HybridGaussianConditional>(unspecifiedParentKeys,
newTree));
}
// Case 3: No changes needed, return original
return std::make_shared<HybridConditional>(hgc);
}
/* ************************************************************************ */
} // namespace gtsam

View File

@ -215,6 +215,14 @@ class GTSAM_EXPORT HybridConditional
return true;
}
/**
* Return a HybridConditional by choosing branches based on the given discrete
* values. If all discrete parents are specified, return a HybridConditional
* which is just a GaussianConditional. If this conditional is *not* a hybrid
* conditional, just return that.
*/
shared_ptr restrict(const DiscreteValues& discreteValues) const;
/// @}
private:

View File

@ -81,7 +81,7 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph,
if (maxNrLeaves) {
// `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in
// all the conditionals with the same keys in bayesNetFragment.
bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, deadModeThreshold_);
bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, marginalThreshold_);
}
// Add the partial bayes net to the posterior bayes net.

View File

@ -30,18 +30,19 @@ class GTSAM_EXPORT HybridSmoother {
HybridGaussianFactorGraph remainingFactorGraph_;
/// The threshold above which we make a decision about a mode.
std::optional<double> deadModeThreshold_;
std::optional<double> marginalThreshold_;
DiscreteValues fixedValues_;
public:
/**
* @brief Constructor
*
* @param removeDeadModes Flag indicating whether to remove dead modes.
* @param deadModeThreshold The threshold above which a mode gets assigned a
* @param marginalThreshold The threshold above which a mode gets assigned a
* value and is considered "dead". 0.99 is a good starting value.
*/
HybridSmoother(const std::optional<double> deadModeThreshold = {})
: deadModeThreshold_(deadModeThreshold) {}
HybridSmoother(const std::optional<double> marginalThreshold = {})
: marginalThreshold_(marginalThreshold) {}
/**
* Given new factors, perform an incremental update.

View File

@ -166,8 +166,8 @@ TEST(HybridBayesNet, Tiny) {
// prune
auto pruned = bayesNet.prune(1);
CHECK(pruned.at(1)->asHybrid());
EXPECT_LONGS_EQUAL(1, pruned.at(1)->asHybrid()->nrComponents());
CHECK(pruned.at(0)->asHybrid());
EXPECT_LONGS_EQUAL(1, pruned.at(0)->asHybrid()->nrComponents());
EXPECT(!pruned.equals(bayesNet));
// error
@ -437,20 +437,21 @@ TEST(HybridBayesNet, RemoveDeadNodes) {
const double pruneDeadVariables = 0.99;
auto prunedBayesNet = posterior->prune(2, pruneDeadVariables);
// First conditional is still the same: P( x0 | x1 m0)
EXPECT(prunedBayesNet.at(0)->isHybrid());
// Check that hybrid conditional that only depend on M1
// is now Gaussian and not Hybrid
EXPECT(prunedBayesNet.at(1)->isContinuous());
// Third conditional is still Hybrid: P( x1 | m0 m1) -> P( x1 | m0)
EXPECT(prunedBayesNet.at(0)->isHybrid());
// Check that discrete joint only has M0 and not (M0, M1)
// since M0 is removed
KeyVector actual_keys = prunedBayesNet.at(0)->asDiscrete()->keys();
EXPECT(KeyVector{M(0)} == actual_keys);
// Check that hybrid conditionals that only depend on M1
// are now Gaussian and not Hybrid
EXPECT(prunedBayesNet.at(0)->isDiscrete());
EXPECT(prunedBayesNet.at(1)->isDiscrete());
EXPECT(prunedBayesNet.at(2)->isHybrid());
// Only P(X2 | X1, M1) depends on M1,
// so it gets convert to a Gaussian P(X2 | X1)
EXPECT(prunedBayesNet.at(3)->isContinuous());
EXPECT(prunedBayesNet.at(4)->isHybrid());
auto joint = prunedBayesNet.at(3)->asDiscrete();
EXPECT(joint);
EXPECT(joint->keys() == KeyVector{M(0)});
}
/* ****************************************************************************/
@ -479,13 +480,13 @@ TEST(HybridBayesNet, ErrorAfterPruning) {
const HybridValues hybridValues{delta.continuous(), discrete_values};
double pruned_logProbability = 0;
pruned_logProbability +=
prunedBayesNet.at(0)->asDiscrete()->logProbability(hybridValues);
prunedBayesNet.at(0)->asHybrid()->logProbability(hybridValues);
pruned_logProbability +=
prunedBayesNet.at(1)->asHybrid()->logProbability(hybridValues);
pruned_logProbability +=
prunedBayesNet.at(2)->asHybrid()->logProbability(hybridValues);
pruned_logProbability +=
prunedBayesNet.at(3)->asHybrid()->logProbability(hybridValues);
prunedBayesNet.at(3)->asDiscrete()->logProbability(hybridValues);
double pruned_negLogConstant = prunedBayesNet.negLogConstant(discrete_values);
@ -548,8 +549,8 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
};
// Get the pruned discrete conditionals as an AlgebraicDecisionTree
CHECK(pruned.at(0)->asDiscrete());
auto pruned_discrete_conditionals = pruned.at(0)->asDiscrete();
CHECK(pruned.at(4)->asDiscrete());
auto pruned_discrete_conditionals = pruned.at(4)->asDiscrete();
auto discrete_conditional_tree =
std::dynamic_pointer_cast<DecisionTreeFactor::ADT>(
pruned_discrete_conditionals);

View File

@ -21,6 +21,7 @@
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridGaussianConditional.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridValues.h>
@ -238,22 +239,27 @@ TEST(HybridGaussianConditional, Likelihood2) {
EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8);
}
/* ************************************************************************* */
namespace two_mode_measurement {
// Create a two key conditional:
const DiscreteKeys modes{{M(1), 2}, {M(2), 2}};
const std::vector<GaussianConditional::shared_ptr> gcs = {
GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(1), 1),
GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(2), 2),
GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(3), 3),
GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(4), 4)};
const HybridGaussianConditional::Conditionals conditionals(modes, gcs);
const auto hgc =
std::make_shared<HybridGaussianConditional>(modes, conditionals);
} // namespace two_mode_measurement
/* ************************************************************************* */
// Test pruning a HybridGaussianConditional with two discrete keys, based on a
// DecisionTreeFactor with 3 keys:
TEST(HybridGaussianConditional, Prune) {
// Create a two key conditional:
DiscreteKeys modes{{M(1), 2}, {M(2), 2}};
std::vector<GaussianConditional::shared_ptr> gcs;
for (size_t i = 0; i < 4; i++) {
gcs.push_back(
GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(i + 1), i + 1));
}
auto empty = std::make_shared<GaussianConditional>();
HybridGaussianConditional::Conditionals conditionals(modes, gcs);
HybridGaussianConditional hgc(modes, conditionals);
using two_mode_measurement::hgc;
DiscreteKeys keys = modes;
DiscreteKeys keys = two_mode_measurement::modes;
keys.push_back({M(3), 2});
{
for (size_t i = 0; i < 8; i++) {
@ -262,7 +268,7 @@ TEST(HybridGaussianConditional, Prune) {
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
// Prune the HybridGaussianConditional
const auto pruned =
hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
hgc->prune(DiscreteConditional(keys.size(), decisionTreeFactor));
// Check that the pruned HybridGaussianConditional has 1 conditional
EXPECT_LONGS_EQUAL(1, pruned->nrComponents());
}
@ -273,14 +279,14 @@ TEST(HybridGaussianConditional, Prune) {
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
const auto pruned =
hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
hgc->prune(DiscreteConditional(keys.size(), decisionTreeFactor));
// Check that the pruned HybridGaussianConditional has 2 conditionals
EXPECT_LONGS_EQUAL(2, pruned->nrComponents());
// Check that the minimum negLogConstant is set correctly
EXPECT_DOUBLES_EQUAL(
hgc.conditionals()({{M(1), 0}, {M(2), 1}})->negLogConstant(),
hgc->conditionals()({{M(1), 0}, {M(2), 1}})->negLogConstant(),
pruned->negLogConstant(), 1e-9);
}
{
@ -289,18 +295,48 @@ TEST(HybridGaussianConditional, Prune) {
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
const auto pruned =
hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
hgc->prune(DiscreteConditional(keys.size(), decisionTreeFactor));
// Check that the pruned HybridGaussianConditional has 3 conditionals
EXPECT_LONGS_EQUAL(3, pruned->nrComponents());
// Check that the minimum negLogConstant is correct
EXPECT_DOUBLES_EQUAL(hgc.negLogConstant(), pruned->negLogConstant(), 1e-9);
EXPECT_DOUBLES_EQUAL(hgc->negLogConstant(), pruned->negLogConstant(), 1e-9);
}
}
/* *************************************************************************
* This test verifies the behavior of the restrict method in different
* scenarios:
* - When no restrictions are applied.
* - When one parent is restricted.
* - When two parents are restricted.
* - When the restriction results in a Gaussian conditional.
*/
TEST(HybridGaussianConditional, Restrict) {
// Create a HybridConditional with two discrete parents P(z0|m0,m1)
const auto hc =
std::make_shared<HybridConditional>(two_mode_measurement::hgc);
const HybridConditional::shared_ptr same = hc->restrict({});
EXPECT(same->isHybrid());
EXPECT(same->asHybrid()->nrComponents() == 4);
const HybridConditional::shared_ptr oneParent = hc->restrict({{M(1), 0}});
EXPECT(oneParent->isHybrid());
EXPECT(oneParent->asHybrid()->nrComponents() == 2);
const HybridConditional::shared_ptr oneParent2 =
hc->restrict({{M(7), 0}, {M(1), 0}});
EXPECT(oneParent2->isHybrid());
EXPECT(oneParent2->asHybrid()->nrComponents() == 2);
const HybridConditional::shared_ptr gaussian =
hc->restrict({{M(1), 0}, {M(2), 1}});
EXPECT(gaussian->asGaussian());
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);

View File

@ -103,7 +103,7 @@ TEST(HybridSmoother, IncrementalSmoother) {
}
EXPECT_LONGS_EQUAL(11,
smoother.hybridBayesNet().at(3)->asDiscrete()->nrValues());
smoother.hybridBayesNet().at(5)->asDiscrete()->nrValues());
// Get the continuous delta update as well as
// the optimal discrete assignment.
@ -157,7 +157,7 @@ TEST(HybridSmoother, ValidPruningError) {
}
EXPECT_LONGS_EQUAL(14,
smoother.hybridBayesNet().at(6)->asDiscrete()->nrValues());
smoother.hybridBayesNet().at(8)->asDiscrete()->nrValues());
// Get the continuous delta update as well as
// the optimal discrete assignment.