add pruning support for HybridNonlinearFactor
parent
7c672bb91b
commit
3c50f9387c
|
@ -16,6 +16,7 @@
|
|||
* @date Sep 12, 2024
|
||||
*/
|
||||
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/hybrid/HybridNonlinearFactor.h>
|
||||
#include <gtsam/linear/NoiseModel.h>
|
||||
#include <gtsam/nonlinear/NonlinearFactor.h>
|
||||
|
@ -202,4 +203,34 @@ std::shared_ptr<HybridGaussianFactor> HybridNonlinearFactor::linearize(
|
|||
linearized_factors);
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
HybridNonlinearFactor::shared_ptr HybridNonlinearFactor::prune(
|
||||
const DecisionTreeFactor& discreteProbs) const {
|
||||
// Find keys in discreteProbs.keys() but not in this->keys():
|
||||
std::set<Key> mine(this->keys().begin(), this->keys().end());
|
||||
std::set<Key> theirs(discreteProbs.keys().begin(),
|
||||
discreteProbs.keys().end());
|
||||
std::vector<Key> diff;
|
||||
std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(),
|
||||
std::back_inserter(diff));
|
||||
|
||||
// Find maximum probability value for every combination of our keys.
|
||||
Ordering keys(diff);
|
||||
auto max = discreteProbs.max(keys);
|
||||
|
||||
// Check the max value for every combination of our keys.
|
||||
// If the max value is 0.0, we can prune the corresponding conditional.
|
||||
auto pruner =
|
||||
[&](const Assignment<Key>& choices,
|
||||
const NonlinearFactorValuePair& pair) -> NonlinearFactorValuePair {
|
||||
if (max->evaluate(choices) == 0.0)
|
||||
return {nullptr, std::numeric_limits<double>::infinity()};
|
||||
else
|
||||
return pair;
|
||||
};
|
||||
|
||||
FactorValuePairs prunedFactors = factors().apply(pruner);
|
||||
return std::make_shared<HybridNonlinearFactor>(discreteKeys(), prunedFactors);
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
|
@ -166,6 +166,9 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
|
|||
|
||||
/// @}
|
||||
|
||||
/// Getter for NonlinearFactor decision tree
|
||||
const FactorValuePairs& factors() const { return factors_; }
|
||||
|
||||
/// Linearize specific nonlinear factors based on the assignment in
|
||||
/// discreteValues.
|
||||
GaussianFactor::shared_ptr linearize(
|
||||
|
@ -176,6 +179,10 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
|
|||
std::shared_ptr<HybridGaussianFactor> linearize(
|
||||
const Values& continuousValues) const;
|
||||
|
||||
/// Prune this factor based on the discrete probabilities.
|
||||
HybridNonlinearFactor::shared_ptr prune(
|
||||
const DecisionTreeFactor& discreteProbs) const;
|
||||
|
||||
private:
|
||||
/// Helper struct to assist private constructor below.
|
||||
struct ConstructorHelper;
|
||||
|
|
Loading…
Reference in New Issue