add pruning support for HybridNonlinearFactor
parent
7c672bb91b
commit
3c50f9387c
|
@ -16,6 +16,7 @@
|
||||||
* @date Sep 12, 2024
|
* @date Sep 12, 2024
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
#include <gtsam/hybrid/HybridNonlinearFactor.h>
|
#include <gtsam/hybrid/HybridNonlinearFactor.h>
|
||||||
#include <gtsam/linear/NoiseModel.h>
|
#include <gtsam/linear/NoiseModel.h>
|
||||||
#include <gtsam/nonlinear/NonlinearFactor.h>
|
#include <gtsam/nonlinear/NonlinearFactor.h>
|
||||||
|
@ -202,4 +203,34 @@ std::shared_ptr<HybridGaussianFactor> HybridNonlinearFactor::linearize(
|
||||||
linearized_factors);
|
linearized_factors);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* *******************************************************************************/
|
||||||
|
HybridNonlinearFactor::shared_ptr HybridNonlinearFactor::prune(
|
||||||
|
const DecisionTreeFactor& discreteProbs) const {
|
||||||
|
// Find keys in discreteProbs.keys() but not in this->keys():
|
||||||
|
std::set<Key> mine(this->keys().begin(), this->keys().end());
|
||||||
|
std::set<Key> theirs(discreteProbs.keys().begin(),
|
||||||
|
discreteProbs.keys().end());
|
||||||
|
std::vector<Key> diff;
|
||||||
|
std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(),
|
||||||
|
std::back_inserter(diff));
|
||||||
|
|
||||||
|
// Find maximum probability value for every combination of our keys.
|
||||||
|
Ordering keys(diff);
|
||||||
|
auto max = discreteProbs.max(keys);
|
||||||
|
|
||||||
|
// Check the max value for every combination of our keys.
|
||||||
|
// If the max value is 0.0, we can prune the corresponding conditional.
|
||||||
|
auto pruner =
|
||||||
|
[&](const Assignment<Key>& choices,
|
||||||
|
const NonlinearFactorValuePair& pair) -> NonlinearFactorValuePair {
|
||||||
|
if (max->evaluate(choices) == 0.0)
|
||||||
|
return {nullptr, std::numeric_limits<double>::infinity()};
|
||||||
|
else
|
||||||
|
return pair;
|
||||||
|
};
|
||||||
|
|
||||||
|
FactorValuePairs prunedFactors = factors().apply(pruner);
|
||||||
|
return std::make_shared<HybridNonlinearFactor>(discreteKeys(), prunedFactors);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
|
@ -166,6 +166,9 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
/// Getter for NonlinearFactor decision tree
|
||||||
|
const FactorValuePairs& factors() const { return factors_; }
|
||||||
|
|
||||||
/// Linearize specific nonlinear factors based on the assignment in
|
/// Linearize specific nonlinear factors based on the assignment in
|
||||||
/// discreteValues.
|
/// discreteValues.
|
||||||
GaussianFactor::shared_ptr linearize(
|
GaussianFactor::shared_ptr linearize(
|
||||||
|
@ -176,6 +179,10 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
|
||||||
std::shared_ptr<HybridGaussianFactor> linearize(
|
std::shared_ptr<HybridGaussianFactor> linearize(
|
||||||
const Values& continuousValues) const;
|
const Values& continuousValues) const;
|
||||||
|
|
||||||
|
/// Prune this factor based on the discrete probabilities.
|
||||||
|
HybridNonlinearFactor::shared_ptr prune(
|
||||||
|
const DecisionTreeFactor& discreteProbs) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/// Helper struct to assist private constructor below.
|
/// Helper struct to assist private constructor below.
|
||||||
struct ConstructorHelper;
|
struct ConstructorHelper;
|
||||||
|
|
Loading…
Reference in New Issue