Merge pull request #2013 from borglab/fix/pruning

Fix pruning
release/4.3a0
Frank Dellaert 2025-01-30 10:54:46 -05:00 committed by GitHub
commit 3cf15901c7
13 changed files with 325 additions and 156 deletions

View File

@@ -18,6 +18,7 @@
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteMarginals.h>
#include <gtsam/inference/FactorGraph-inst.h>

namespace gtsam {

@@ -68,6 +69,59 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
return result;
}
/* ************************************************************************* */
// The implementation is: build the entire joint into one factor and then prune.
// TODO(Frank): This can be quite expensive *unless* the factors have already
// been pruned before. Another, possibly faster approach is branch and bound
// search to find the K-best leaves and then create a single pruned conditional.
DiscreteBayesNet DiscreteBayesNet::prune(
size_t maxNrLeaves, const std::optional<double>& marginalThreshold,
DiscreteValues* fixedValues) const {
// Multiply into one big conditional. NOTE: possibly quite expensive.
DiscreteConditional joint;
for (const DiscreteConditional::shared_ptr& conditional : *this)
joint = joint * (*conditional);
// Prune the joint. NOTE: imperative and, again, possibly quite expensive.
DiscreteConditional pruned = joint;
pruned.prune(maxNrLeaves);
DiscreteValues deadModesValues;
// If we have a dead mode threshold and discrete variables left after pruning,
// then we run dead mode removal.
if (marginalThreshold.has_value() && pruned.keys().size() > 0) {
DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
for (auto dkey : pruned.discreteKeys()) {
const Vector probabilities = marginals.marginalProbabilities(dkey);
int index = -1;
auto threshold = (probabilities.array() > *marginalThreshold);
// If at least one entry exceeds the threshold, we can find its index.
// Otherwise index stays at -1, so we do not incorrectly default to 0.
if (!threshold.isZero()) {
threshold.maxCoeff(&index);
}
if (index >= 0) {
deadModesValues.emplace(dkey.first, index);
}
}
// Remove the modes (imperative)
pruned.removeDiscreteModes(deadModesValues);
// Set the fixed values if requested.
if (fixedValues) {
*fixedValues = deadModesValues;
}
}
// Return the resulting DiscreteBayesNet.
DiscreteBayesNet result;
if (pruned.keys().size() > 0) result.push_back(pruned);
return result;
}
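As a quick orientation for readers of the diff, here is a minimal usage sketch of the new overload (editorial example, not part of the commit; the function name pruneSketch is hypothetical):

#include <gtsam/discrete/DiscreteBayesNet.h>

using namespace gtsam;

void pruneSketch(const DiscreteBayesNet& bn) {
  // Keep at most 4 leaves of the joint, no dead-mode removal.
  DiscreteBayesNet prunedOnly = bn.prune(4);

  // Additionally fix any mode whose marginal exceeds 0.99 ("dead" mode)
  // and report the assignments that were fixed.
  DiscreteValues fixed;
  DiscreteBayesNet pruned = bn.prune(4, 0.99, &fixed);
  // `fixed` now maps each removed mode key to its chosen value.
}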
/* *********************************************************************** */
std::string DiscreteBayesNet::markdown(
const KeyFormatter& keyFormatter,

View File

@@ -124,6 +124,18 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
*/
DiscreteValues sample(DiscreteValues given) const;
/**
* @brief Prune the Bayes net
*
* @param maxNrLeaves The maximum number of leaves to keep.
* @param marginalThreshold If given, threshold on marginals to prune variables.
* @param fixedValues If given, return the fixed values removed.
* @return A new DiscreteBayesNet with pruned conditionals.
*/
DiscreteBayesNet prune(size_t maxNrLeaves,
const std::optional<double>& marginalThreshold = {},
DiscreteValues* fixedValues = nullptr) const;
///@}
/// @name Wrapper support
/// @{

View File

@@ -68,28 +68,82 @@ class GTSAM_EXPORT DiscreteValues : public Assignment<Key> {
friend std::ostream& operator<<(std::ostream& os, const DiscreteValues& x);

// insert in base class;
std::pair<iterator, bool> insert(const value_type& value) {
return Base::insert(value);
}

/**
- * Insert key-assignment pair.
- * Throws an invalid_argument exception if
- * any keys to be inserted are already used. */
* @brief Insert key-assignment pair.
*
* @param assignment The key-assignment pair to insert.
* @return DiscreteValues& Reference to the updated DiscreteValues object.
* @throws std::invalid_argument if any keys to be inserted are already used.
*/
DiscreteValues& insert(const std::pair<Key, size_t>& assignment);

- /** Insert all values from \c values. Throws an invalid_argument exception if
- * any keys to be inserted are already used. */
/**
* @brief Insert all values from another DiscreteValues object.
*
* @param values The DiscreteValues object containing values to insert.
* @return DiscreteValues& Reference to the updated DiscreteValues object.
* @throws std::invalid_argument if any keys to be inserted are already used.
*/
DiscreteValues& insert(const DiscreteValues& values);

- /** For all key/value pairs in \c values, replace values with corresponding
- * keys in this object with those in \c values. Throws std::out_of_range if
- * any keys in \c values are not present in this object. */
/**
* @brief Update values with corresponding keys from another DiscreteValues
* object.
*
* @param values The DiscreteValues object containing values to update.
* @return DiscreteValues& Reference to the updated DiscreteValues object.
* @throws std::out_of_range if any keys in values are not present in this
* object.
*/
DiscreteValues& update(const DiscreteValues& values);
/**
* @brief Check if the DiscreteValues contains the given key.
*
* @param key The key to check for.
* @return True if the key is present, false otherwise.
*/
bool contains(Key key) const { return this->find(key) != this->end(); }
/**
* @brief Filter values by keys.
*
* @param keys The keys to filter by.
* @return DiscreteValues The filtered DiscreteValues object.
*/
DiscreteValues filter(const DiscreteKeys& keys) const {
DiscreteValues result;
for (const auto& [key, _] : keys) {
if (auto it = this->find(key); it != this->end())
result[key] = it->second;
}
return result;
}
/**
* @brief Return the keys that are not present in the DiscreteValues object.
*
* @param keys The keys to check for.
* @return DiscreteKeys Keys not present in the DiscreteValues object.
*/
DiscreteKeys missingKeys(const DiscreteKeys& keys) const {
DiscreteKeys result;
for (const auto& [key, cardinality] : keys) {
if (!this->contains(key)) result.emplace_back(key, cardinality);
}
return result;
}
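The new unit tests later in this commit exercise these helpers; as a short editorial sketch (keys 12, 5, and 99 are hypothetical, cardinality 2):

DiscreteValues values = {{12, 1}, {5, 0}};
DiscreteKeys keys = {{12, 2}, {5, 2}, {99, 2}};

bool has12 = values.contains(12);                 // true
DiscreteValues known = values.filter(keys);       // entries for keys 12 and 5 only
DiscreteKeys unknown = values.missingKeys(keys);  // {{99, 2}}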
/**
- * @brief Return a vector of DiscreteValues, one for each possible
- * combination of values.
* @brief Return a vector of DiscreteValues, one for each possible combination
* of values.
*
* @param keys The keys to generate the Cartesian product for.
* @return std::vector<DiscreteValues> The vector of DiscreteValues.
*/
static std::vector<DiscreteValues> CartesianProduct(
const DiscreteKeys& keys) {
@@ -135,14 +189,16 @@ inline std::vector<DiscreteValues> cartesianProduct(const DiscreteKeys& keys) {
}

/// Free version of markdown.
std::string GTSAM_EXPORT
markdown(const DiscreteValues& values,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteValues::Names& names = {});

/// Free version of html.
std::string GTSAM_EXPORT
html(const DiscreteValues& values,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteValues::Names& names = {});

// traits
template <>

View File

@@ -40,6 +40,24 @@ TEST(DiscreteValues, Update) {
DiscreteValues(kExample).update({{12, 2}})));
}
/* ************************************************************************* */
// Test DiscreteValues::filter
TEST(DiscreteValues, Filter) {
DiscreteValues values = {{12, 1}, {5, 0}, {13, 2}};
DiscreteKeys keys = {{12, 0}, {13, 0}, {99, 0}}; // 99 is missing in values
EXPECT(assert_equal(DiscreteValues({{12, 1}, {13, 2}}), values.filter(keys)));
}
/* ************************************************************************* */
// Test DiscreteValues::missingKeys
TEST(DiscreteValues, MissingKeys) {
DiscreteValues values = {{12, 1}, {5, 0}};
DiscreteKeys keys = {{12, 0}, {5, 0}, {99, 0}, {42, 0}}; // 99 and 42 are missing
EXPECT(assert_equal(DiscreteKeys({{99, 0}, {42, 0}}), values.missingKeys(keys)));
}
/* ************************************************************************* */
// Check markdown representation with a value formatter.
TEST(DiscreteValues, markdownWithValueFormatter) {

View File

@@ -19,7 +19,6 @@
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
- #include <gtsam/discrete/DiscreteMarginals.h>
#include <gtsam/discrete/TableDistribution.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridValues.h>
@@ -43,115 +42,60 @@
}

/* ************************************************************************* */
- // The implementation is: build the entire joint into one factor and then prune.
- // TODO(Frank): This can be quite expensive *unless* the factors have already
- // been pruned before. Another, possibly faster approach is branch and bound
- // search to find the K-best leaves and then create a single pruned conditional.
HybridBayesNet HybridBayesNet::prune(
- size_t maxNrLeaves, const std::optional<double> &deadModeThreshold) const {
size_t maxNrLeaves, const std::optional<double> &marginalThreshold,
DiscreteValues *fixedValues) const {
#if GTSAM_HYBRID_TIMING
gttic_(HybridPruning);
#endif
// Collect all the discrete conditionals. Could be small if already pruned.
const DiscreteBayesNet marginal = discreteMarginal();
// Prune discrete Bayes net
DiscreteValues fixed;
auto prunedBN = marginal.prune(maxNrLeaves, marginalThreshold, &fixed);
// Multiply into one big conditional. NOTE: possibly quite expensive.
- DiscreteConditional joint;
- for (auto &&conditional : marginal) {
- joint = joint * (*conditional);
- }
DiscreteConditional pruned;
for (auto &&conditional : prunedBN) pruned = pruned * (*conditional);
// Set the fixed values if requested.
if (marginalThreshold && fixedValues) {
*fixedValues = fixed;
}
// Initialize the resulting HybridBayesNet.
HybridBayesNet result;
- // Prune the joint. NOTE: imperative and, again, possibly quite expensive.
- DiscreteConditional pruned = joint;
- pruned.prune(maxNrLeaves);
- DiscreteValues deadModesValues;
- // If we have a dead mode threshold and discrete variables left after pruning,
- // then we run dead mode removal.
- if (deadModeThreshold.has_value() && pruned.keys().size() > 0) {
- DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
- for (auto dkey : pruned.discreteKeys()) {
- Vector probabilities = marginals.marginalProbabilities(dkey);
- int index = -1;
- auto threshold = (probabilities.array() > *deadModeThreshold);
- // If atleast 1 value is non-zero, then we can find the index
- // Else if all are zero, index would be set to 0 which is incorrect
- if (!threshold.isZero()) {
- threshold.maxCoeff(&index);
- }
- if (index >= 0) {
- deadModesValues.emplace(dkey.first, index);
- }
- }
- // Remove the modes (imperative)
- pruned.removeDiscreteModes(deadModesValues);
- /*
- If the pruned discrete conditional has any keys left,
- we add it to the HybridBayesNet.
- If not, it means it is an orphan so we don't add this pruned joint,
- and instead add only the marginals below.
- */
- if (pruned.keys().size() > 0) {
- result.emplace_shared<DiscreteConditional>(pruned);
- }
- // Add the marginals for future factors
- for (auto &&[key, _] : deadModesValues) {
- result.push_back(
- std::dynamic_pointer_cast<DiscreteConditional>(marginals(key)));
- }
- } else {
- result.emplace_shared<DiscreteConditional>(pruned);
- }
- /* To prune, we visitWith every leaf in the HybridGaussianConditional.
- * For each leaf, using the assignment we can check the discrete decision tree
- * for 0.0 probability, then just set the leaf to a nullptr.
- *
- * We can later check the HybridGaussianConditional for just nullptrs.
- */
- // Go through all the Gaussian conditionals in the Bayes Net and prune them as
- // per pruned discrete joint.
- for (auto &&conditional : *this) {
// Go through all the Gaussian conditionals, restrict them according to
// fixed values, and then prune further.
for (std::shared_ptr<gtsam::HybridConditional> conditional : *this) {
if (conditional->isDiscrete()) continue;
// No-op if not a HybridGaussianConditional.
if (marginalThreshold) conditional = conditional->restrict(fixed);
// Now decide on type what to do:
if (auto hgc = conditional->asHybrid()) {
// Prune the hybrid Gaussian conditional!
auto prunedHybridGaussianConditional = hgc->prune(pruned);
- if (deadModeThreshold.has_value()) {
- KeyVector deadKeys, conditionalDiscreteKeys;
- for (const auto &kv : deadModesValues) {
- deadKeys.push_back(kv.first);
- }
- for (auto dkey : prunedHybridGaussianConditional->discreteKeys()) {
- conditionalDiscreteKeys.push_back(dkey.first);
- }
- // The discrete keys in the conditional are the same as the keys in the
- // dead modes, then we just get the corresponding Gaussian conditional.
- if (deadKeys == conditionalDiscreteKeys) {
- result.push_back(
- prunedHybridGaussianConditional->choose(deadModesValues));
- } else {
- // Add as-is
- result.push_back(prunedHybridGaussianConditional);
- }
- } else {
- // Type-erase and add to the pruned Bayes Net fragment.
- result.push_back(prunedHybridGaussianConditional);
- }
if (!prunedHybridGaussianConditional) {
throw std::runtime_error(
"A HybridGaussianConditional had all its conditionals pruned");
}
// Type-erase and add to the pruned Bayes Net fragment.
result.push_back(prunedHybridGaussianConditional);
} else if (auto gc = conditional->asGaussian()) {
// Add the non-HybridGaussianConditional conditional
result.push_back(gc);
- }
- // We ignore DiscreteConditional as they are already pruned and added.
} else
throw std::runtime_error(
"HybridBayesNet::prune: Unknown HybridConditional type.");
}
// Add the pruned discrete conditionals to the result.
for (const DiscreteConditional::shared_ptr &discrete : prunedBN)
result.push_back(discrete);
return result;
}
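As an editorial aside (not part of the commit), the reworked method would be called like this, assuming an existing HybridBayesNet named posterior:

// Keep at most 2 leaves; fix modes whose marginal exceeds 0.99 and
// report which assignments were fixed.
DiscreteValues fixedValues;
HybridBayesNet pruned = posterior.prune(2, 0.99, &fixedValues);

// Without a threshold, only leaf-count pruning is performed.
HybridBayesNet prunedOnly = posterior.prune(2);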

View File

@@ -217,16 +217,18 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves.
*
* @param maxNrLeaves The maximum number of leaves to keep.
- * @param deadModeThreshold The threshold to check the mode marginals against.
- * If greater than this threshold, the mode gets assigned that value and is
- * considered "dead" for hybrid elimination.
- * The mode can then be removed since it only has a single possible
- * assignment.
* @param marginalThreshold The threshold to check the mode marginals against.
* @param fixedValues The fixed values resulting from dead mode removal.
*
* @note If a marginal exceeds this threshold, the mode gets assigned that
* value and is considered "dead" for hybrid elimination. The mode can then be
* removed since it only has a single possible assignment.
* @return A pruned HybridBayesNet
*/
- HybridBayesNet prune(
- size_t maxNrLeaves,
- const std::optional<double> &deadModeThreshold = {}) const;
HybridBayesNet prune(size_t maxNrLeaves,
const std::optional<double> &marginalThreshold = {},
DiscreteValues *fixedValues = nullptr) const;

/**
* @brief Error method using HybridValues which returns specific error for

View File

@@ -169,4 +169,41 @@ double HybridConditional::evaluate(const HybridValues &values) const {
return std::exp(logProbability(values));
}
/* ************************************************************************ */
HybridConditional::shared_ptr HybridConditional::restrict(
const DiscreteValues &discreteValues) const {
if (auto gc = asGaussian()) {
return std::make_shared<HybridConditional>(gc);
} else if (auto dc = asDiscrete()) {
return std::make_shared<HybridConditional>(dc);
};
auto hgc = asHybrid();
if (!hgc)
throw std::runtime_error(
"HybridConditional::restrict: conditional type not handled");
// Case 1: Fully determined, return corresponding Gaussian conditional
auto parentValues = discreteValues.filter(discreteKeys_);
if (parentValues.size() == discreteKeys_.size()) {
return std::make_shared<HybridConditional>(hgc->choose(parentValues));
}
// Case 2: Some live parents remain, build a new tree
auto unspecifiedParentKeys = discreteValues.missingKeys(discreteKeys_);
if (!unspecifiedParentKeys.empty()) {
auto newTree = hgc->factors();
for (const auto &[key, value] : parentValues) {
newTree = newTree.choose(key, value);
}
return std::make_shared<HybridConditional>(
std::make_shared<HybridGaussianConditional>(unspecifiedParentKeys,
newTree));
}
// Case 3: No changes needed, return original
return std::make_shared<HybridConditional>(hgc);
}
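A brief editorial illustration of the three cases (not from the commit), assuming hc wraps a HybridGaussianConditional P(z0 | m1, m2) with two binary parents, as in the new unit test added later in this commit:

HybridConditional::shared_ptr same = hc->restrict({});              // no parents fixed: unchanged, still hybrid
HybridConditional::shared_ptr partial = hc->restrict({{M(1), 0}});  // one parent fixed: smaller hybrid
HybridConditional::shared_ptr gaussian = hc->restrict({{M(1), 0}, {M(2), 1}});  // fully determined: Gaussian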
/* ************************************************************************ */
} // namespace gtsam

View File

@@ -215,6 +215,14 @@ class GTSAM_EXPORT HybridConditional
return true;
}
/**
* Return a HybridConditional by choosing branches based on the given discrete
* values. If all discrete parents are specified, return a HybridConditional
* which is just a GaussianConditional. If this conditional is *not* a hybrid
* conditional, just return that.
*/
shared_ptr restrict(const DiscreteValues& discreteValues) const;
/// @}

private:

View File

@@ -81,7 +81,7 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph,
if (maxNrLeaves) {
// `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in
// all the conditionals with the same keys in bayesNetFragment.
- bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, deadModeThreshold_);
bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, marginalThreshold_);
}

// Add the partial bayes net to the posterior bayes net.

View File

@@ -30,18 +30,19 @@ class GTSAM_EXPORT HybridSmoother {
HybridGaussianFactorGraph remainingFactorGraph_;

/// The threshold above which we make a decision about a mode.
- std::optional<double> deadModeThreshold_;
std::optional<double> marginalThreshold_;
DiscreteValues fixedValues_;

public:
/**
* @brief Constructor
*
* @param removeDeadModes Flag indicating whether to remove dead modes.
- * @param deadModeThreshold The threshold above which a mode gets assigned a
* @param marginalThreshold The threshold above which a mode gets assigned a
* value and is considered "dead". 0.99 is a good starting value.
*/
- HybridSmoother(const std::optional<double> deadModeThreshold = {})
- : deadModeThreshold_(deadModeThreshold) {}
HybridSmoother(const std::optional<double> marginalThreshold = {})
: marginalThreshold_(marginalThreshold) {}
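For orientation (editorial sketch, not part of the commit), a smoother that removes dead modes is constructed with the threshold described above:

// Modes whose marginal exceeds 0.99 are fixed ("dead") during pruning.
HybridSmoother smoother(0.99);
// Subsequent update() calls pass marginalThreshold_ on to
// HybridBayesNet::prune, as shown in HybridSmoother.cpp above.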
/**
* Given new factors, perform an incremental update.

View File

@@ -166,8 +166,8 @@ TEST(HybridBayesNet, Tiny) {
// prune
auto pruned = bayesNet.prune(1);
- CHECK(pruned.at(1)->asHybrid());
- EXPECT_LONGS_EQUAL(1, pruned.at(1)->asHybrid()->nrComponents());
CHECK(pruned.at(0)->asHybrid());
EXPECT_LONGS_EQUAL(1, pruned.at(0)->asHybrid()->nrComponents());
EXPECT(!pruned.equals(bayesNet));

// error
@@ -437,20 +437,21 @@ TEST(HybridBayesNet, RemoveDeadNodes) {
const double pruneDeadVariables = 0.99;
auto prunedBayesNet = posterior->prune(2, pruneDeadVariables);

// First conditional is still the same: P( x0 | x1 m0)
EXPECT(prunedBayesNet.at(0)->isHybrid());
// Check that hybrid conditional that only depend on M1
// is now Gaussian and not Hybrid
EXPECT(prunedBayesNet.at(1)->isContinuous());
// Third conditional is still Hybrid: P( x1 | m0 m1) -> P( x1 | m0)
EXPECT(prunedBayesNet.at(0)->isHybrid());

// Check that discrete joint only has M0 and not (M0, M1)
// since M0 is removed
- KeyVector actual_keys = prunedBayesNet.at(0)->asDiscrete()->keys();
- EXPECT(KeyVector{M(0)} == actual_keys);
- // Check that hybrid conditionals that only depend on M1
- // are now Gaussian and not Hybrid
- EXPECT(prunedBayesNet.at(0)->isDiscrete());
- EXPECT(prunedBayesNet.at(1)->isDiscrete());
- EXPECT(prunedBayesNet.at(2)->isHybrid());
- // Only P(X2 | X1, M1) depends on M1,
- // so it gets convert to a Gaussian P(X2 | X1)
- EXPECT(prunedBayesNet.at(3)->isContinuous());
- EXPECT(prunedBayesNet.at(4)->isHybrid());
auto joint = prunedBayesNet.at(3)->asDiscrete();
EXPECT(joint);
EXPECT(joint->keys() == KeyVector{M(0)});
}

/* ****************************************************************************/
@@ -479,13 +480,13 @@ TEST(HybridBayesNet, ErrorAfterPruning) {
const HybridValues hybridValues{delta.continuous(), discrete_values};
double pruned_logProbability = 0;
pruned_logProbability +=
- prunedBayesNet.at(0)->asDiscrete()->logProbability(hybridValues);
prunedBayesNet.at(0)->asHybrid()->logProbability(hybridValues);
pruned_logProbability +=
prunedBayesNet.at(1)->asHybrid()->logProbability(hybridValues);
pruned_logProbability +=
prunedBayesNet.at(2)->asHybrid()->logProbability(hybridValues);
pruned_logProbability +=
- prunedBayesNet.at(3)->asHybrid()->logProbability(hybridValues);
prunedBayesNet.at(3)->asDiscrete()->logProbability(hybridValues);

double pruned_negLogConstant = prunedBayesNet.negLogConstant(discrete_values);

@@ -548,8 +549,8 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
};

// Get the pruned discrete conditionals as an AlgebraicDecisionTree
- CHECK(pruned.at(0)->asDiscrete());
- auto pruned_discrete_conditionals = pruned.at(0)->asDiscrete();
CHECK(pruned.at(4)->asDiscrete());
auto pruned_discrete_conditionals = pruned.at(4)->asDiscrete();
auto discrete_conditional_tree =
std::dynamic_pointer_cast<DecisionTreeFactor::ADT>(
pruned_discrete_conditionals);

View File

@@ -21,6 +21,7 @@
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridGaussianConditional.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridValues.h>

@@ -238,22 +239,27 @@ TEST(HybridGaussianConditional, Likelihood2) {
EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8);
}
/* ************************************************************************* */
namespace two_mode_measurement {
// Create a two key conditional:
const DiscreteKeys modes{{M(1), 2}, {M(2), 2}};
const std::vector<GaussianConditional::shared_ptr> gcs = {
GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(1), 1),
GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(2), 2),
GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(3), 3),
GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(4), 4)};
const HybridGaussianConditional::Conditionals conditionals(modes, gcs);
const auto hgc =
std::make_shared<HybridGaussianConditional>(modes, conditionals);
} // namespace two_mode_measurement
/* ************************************************************************* */
// Test pruning a HybridGaussianConditional with two discrete keys, based on a
// DecisionTreeFactor with 3 keys:
TEST(HybridGaussianConditional, Prune) {
- // Create a two key conditional:
- DiscreteKeys modes{{M(1), 2}, {M(2), 2}};
- std::vector<GaussianConditional::shared_ptr> gcs;
- for (size_t i = 0; i < 4; i++) {
- gcs.push_back(
- GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(i + 1), i + 1));
- }
- auto empty = std::make_shared<GaussianConditional>();
- HybridGaussianConditional::Conditionals conditionals(modes, gcs);
- HybridGaussianConditional hgc(modes, conditionals);
using two_mode_measurement::hgc;

- DiscreteKeys keys = modes;
DiscreteKeys keys = two_mode_measurement::modes;
keys.push_back({M(3), 2});
{
for (size_t i = 0; i < 8; i++) {
@@ -262,7 +268,7 @@ TEST(HybridGaussianConditional, Prune) {
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
// Prune the HybridGaussianConditional
const auto pruned =
- hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
hgc->prune(DiscreteConditional(keys.size(), decisionTreeFactor));
// Check that the pruned HybridGaussianConditional has 1 conditional
EXPECT_LONGS_EQUAL(1, pruned->nrComponents());
}
@@ -273,14 +279,14 @@
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
const auto pruned =
- hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
hgc->prune(DiscreteConditional(keys.size(), decisionTreeFactor));
// Check that the pruned HybridGaussianConditional has 2 conditionals
EXPECT_LONGS_EQUAL(2, pruned->nrComponents());
// Check that the minimum negLogConstant is set correctly
EXPECT_DOUBLES_EQUAL(
- hgc.conditionals()({{M(1), 0}, {M(2), 1}})->negLogConstant(),
hgc->conditionals()({{M(1), 0}, {M(2), 1}})->negLogConstant(),
pruned->negLogConstant(), 1e-9);
}
{
@@ -289,18 +295,48 @@
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
const auto pruned =
- hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
hgc->prune(DiscreteConditional(keys.size(), decisionTreeFactor));
// Check that the pruned HybridGaussianConditional has 3 conditionals
EXPECT_LONGS_EQUAL(3, pruned->nrComponents());
// Check that the minimum negLogConstant is correct
- EXPECT_DOUBLES_EQUAL(hgc.negLogConstant(), pruned->negLogConstant(), 1e-9);
EXPECT_DOUBLES_EQUAL(hgc->negLogConstant(), pruned->negLogConstant(), 1e-9);
}
}

/* *************************************************************************
* This test verifies the behavior of the restrict method in different
* scenarios:
* - When no restrictions are applied.
* - When one parent is restricted.
* - When two parents are restricted.
* - When the restriction results in a Gaussian conditional.
*/ */
TEST(HybridGaussianConditional, Restrict) {
// Create a HybridConditional with two discrete parents P(z0|m0,m1)
const auto hc =
std::make_shared<HybridConditional>(two_mode_measurement::hgc);
const HybridConditional::shared_ptr same = hc->restrict({});
EXPECT(same->isHybrid());
EXPECT(same->asHybrid()->nrComponents() == 4);
const HybridConditional::shared_ptr oneParent = hc->restrict({{M(1), 0}});
EXPECT(oneParent->isHybrid());
EXPECT(oneParent->asHybrid()->nrComponents() == 2);
const HybridConditional::shared_ptr oneParent2 =
hc->restrict({{M(7), 0}, {M(1), 0}});
EXPECT(oneParent2->isHybrid());
EXPECT(oneParent2->asHybrid()->nrComponents() == 2);
const HybridConditional::shared_ptr gaussian =
hc->restrict({{M(1), 0}, {M(2), 1}});
EXPECT(gaussian->asGaussian());
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);

View File

@@ -103,7 +103,7 @@ TEST(HybridSmoother, IncrementalSmoother) {
}

EXPECT_LONGS_EQUAL(11,
- smoother.hybridBayesNet().at(3)->asDiscrete()->nrValues());
smoother.hybridBayesNet().at(5)->asDiscrete()->nrValues());

// Get the continuous delta update as well as
// the optimal discrete assignment.

@@ -157,7 +157,7 @@ TEST(HybridSmoother, ValidPruningError) {
}

EXPECT_LONGS_EQUAL(14,
- smoother.hybridBayesNet().at(6)->asDiscrete()->nrValues());
smoother.hybridBayesNet().at(8)->asDiscrete()->nrValues());

// Get the continuous delta update as well as
// the optimal discrete assignment.