add pruning support for HybridNonlinearFactor

release/4.3a0
Varun Agrawal 2024-10-29 01:55:04 -04:00
parent 7c672bb91b
commit 3c50f9387c
2 changed files with 38 additions and 0 deletions

View File

@ -16,6 +16,7 @@
* @date Sep 12, 2024
*/
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/hybrid/HybridNonlinearFactor.h>
#include <gtsam/linear/NoiseModel.h>
#include <gtsam/nonlinear/NonlinearFactor.h>
@ -202,4 +203,34 @@ std::shared_ptr<HybridGaussianFactor> HybridNonlinearFactor::linearize(
linearized_factors);
}
/* *******************************************************************************/
HybridNonlinearFactor::shared_ptr HybridNonlinearFactor::prune(
const DecisionTreeFactor& discreteProbs) const {
// Find keys in discreteProbs.keys() but not in this->keys():
std::set<Key> mine(this->keys().begin(), this->keys().end());
std::set<Key> theirs(discreteProbs.keys().begin(),
discreteProbs.keys().end());
std::vector<Key> diff;
std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(),
std::back_inserter(diff));
// Find maximum probability value for every combination of our keys.
Ordering keys(diff);
auto max = discreteProbs.max(keys);
// Check the max value for every combination of our keys.
// If the max value is 0.0, we can prune the corresponding conditional.
auto pruner =
[&](const Assignment<Key>& choices,
const NonlinearFactorValuePair& pair) -> NonlinearFactorValuePair {
if (max->evaluate(choices) == 0.0)
return {nullptr, std::numeric_limits<double>::infinity()};
else
return pair;
};
FactorValuePairs prunedFactors = factors().apply(pruner);
return std::make_shared<HybridNonlinearFactor>(discreteKeys(), prunedFactors);
}
} // namespace gtsam

View File

@ -166,6 +166,9 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
/// @}
/// Getter for NonlinearFactor decision tree
const FactorValuePairs& factors() const { return factors_; }
/// Linearize specific nonlinear factors based on the assignment in
/// discreteValues.
GaussianFactor::shared_ptr linearize(
@ -176,6 +179,10 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
std::shared_ptr<HybridGaussianFactor> linearize(
const Values& continuousValues) const;
/// Prune this factor based on the discrete probabilities.
HybridNonlinearFactor::shared_ptr prune(
const DecisionTreeFactor& discreteProbs) const;
private:
/// Helper struct to assist private constructor below.
struct ConstructorHelper;