clean up the prunerFunc
parent
c15cfb6068
commit
2225ecf442
|
@ -131,7 +131,7 @@ void GaussianMixture::print(const std::string &s,
|
|||
|
||||
/* ************************************************************************* */
|
||||
/// Return the DiscreteKey vector as a set.
|
||||
static std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
|
||||
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
|
||||
std::set<DiscreteKey> s;
|
||||
s.insert(dkeys.begin(), dkeys.end());
|
||||
return s;
|
||||
|
@ -142,18 +142,19 @@ static std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
|
|||
* @brief Helper function to get the pruner functional.
|
||||
*
|
||||
* @param decisionTree The probability decision tree of only discrete keys.
|
||||
* @param decisionTreeKeySet Set of DiscreteKeys in decisionTree.
|
||||
* Pre-computed for efficiency.
|
||||
* @param gaussianMixtureKeySet Set of DiscreteKeys in the GaussianMixture.
|
||||
* @return std::function<GaussianConditional::shared_ptr(
|
||||
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
||||
*/
|
||||
std::function<GaussianConditional::shared_ptr(
|
||||
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
||||
PrunerFunc(const DecisionTreeFactor &decisionTree,
|
||||
const std::set<DiscreteKey> &decisionTreeKeySet,
|
||||
const std::set<DiscreteKey> &gaussianMixtureKeySet) {
|
||||
auto pruner = [&](const Assignment<Key> &choices,
|
||||
GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
|
||||
// Get the discrete keys as sets for the decision tree
|
||||
// and the gaussian mixture.
|
||||
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
|
||||
auto gaussianMixtureKeySet = DiscreteKeysAsSet(this->discreteKeys());
|
||||
|
||||
auto pruner = [decisionTree, decisionTreeKeySet, gaussianMixtureKeySet](
|
||||
const Assignment<Key> &choices,
|
||||
const GaussianConditional::shared_ptr &conditional)
|
||||
-> GaussianConditional::shared_ptr {
|
||||
// typecast so we can use this to get probability value
|
||||
|
@ -202,7 +203,7 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
|
|||
auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys());
|
||||
// Functional which loops over all assignments and create a set of
|
||||
// GaussianConditionals
|
||||
auto pruner = PrunerFunc(decisionTree, decisionTreeKeySet, gmKeySet);
|
||||
auto pruner = prunerFunc(decisionTree);
|
||||
|
||||
auto pruned_conditionals = conditionals_.apply(pruner);
|
||||
conditionals_.root_ = pruned_conditionals.root_;
|
||||
|
|
|
@ -69,6 +69,17 @@ class GTSAM_EXPORT GaussianMixture
|
|||
*/
|
||||
Sum asGaussianFactorGraphTree() const;
|
||||
|
||||
/**
|
||||
* @brief Helper function to get the pruner functor.
|
||||
*
|
||||
* @param decisionTree The pruned discrete probability decision tree.
|
||||
* @return std::function<GaussianConditional::shared_ptr(
|
||||
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
||||
*/
|
||||
std::function<GaussianConditional::shared_ptr(
|
||||
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
||||
prunerFunc(const DecisionTreeFactor &decisionTree);
|
||||
|
||||
public:
|
||||
/// @name Constructors
|
||||
/// @{
|
||||
|
|
Loading…
Reference in New Issue