diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp
index 8bce45c51..a64136384 100644
--- a/gtsam/hybrid/HybridBayesNet.cpp
+++ b/gtsam/hybrid/HybridBayesNet.cpp
@@ -42,13 +42,100 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
 }
 
 /* ************************************************************************* */
-HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
+/**
+ * @brief Helper function to get the pruner functional.
+ *
+ * @param decisionTree The probability decision tree of only discrete keys.
+ * @param conditional The hybrid conditional whose discrete keys are compared.
+ * @return std::function<double(const Assignment<Key> &, double)>
+ */
+std::function<double(const Assignment<Key> &, double)> prunerFunc(
+    const DecisionTreeFactor &decisionTree,
+    const HybridConditional &conditional) {
+  // Get the discrete keys as sets for the decision tree
+  // and the gaussian mixture.
+  auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
+  auto conditionalKeySet = DiscreteKeysAsSet(conditional.discreteKeys());
+
+  auto pruner = [decisionTree, decisionTreeKeySet, conditionalKeySet](
+                    const Assignment<Key> &choices,
+                    double probability) -> double {
+    // typecast so we can use this to get probability value
+    DiscreteValues values(choices);
+    // Case where the gaussian mixture has the same
+    // discrete keys as the decision tree.
+    if (conditionalKeySet == decisionTreeKeySet) {
+      if (decisionTree(values) == 0) {
+        return 0.0;
+      } else {
+        return probability;
+      }
+    } else {
+      std::vector<DiscreteKey> set_diff;
+      std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
+                          conditionalKeySet.begin(), conditionalKeySet.end(),
+                          std::back_inserter(set_diff));
+
+      const std::vector<DiscreteValues> assignments =
+          DiscreteValues::CartesianProduct(set_diff);
+      for (const DiscreteValues &assignment : assignments) {
+        DiscreteValues augmented_values(values);
+        augmented_values.insert(assignment.begin(), assignment.end());
+
+        // If any one of the sub-branches are non-zero,
+        // we need this probability.
+        if (decisionTree(augmented_values) > 0.0) {
+          return probability;
+        }
+      }
+      // If we are here, it means that all the sub-branches are 0,
+      // so we prune.
+      return 0.0;
+    }
+  };
+  return pruner;
+}
+
+/* ************************************************************************* */
+void HybridBayesNet::updateDiscreteConditionals(
+    const DecisionTreeFactor::shared_ptr &prunedDecisionTree) {
+  KeyVector prunedTreeKeys = prunedDecisionTree->keys();
+
+  for (size_t i = 0; i < this->size(); i++) {
+    HybridConditional::shared_ptr conditional = this->at(i);
+    if (conditional->isDiscrete()) {
+      // std::cout << demangle(typeid(conditional).name()) << std::endl;
+      auto discrete = conditional->asDiscreteConditional();
+      KeyVector frontals(discrete->frontals().begin(),
+                         discrete->frontals().end());
+
+      // Apply prunerFunc to the underlying AlgebraicDecisionTree
+      auto discreteTree =
+          boost::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete);
+      DecisionTreeFactor::ADT prunedDiscreteTree =
+          discreteTree->apply(prunerFunc(*prunedDecisionTree, *conditional));
+
+      // Create the new (hybrid) conditional
+      auto prunedDiscrete = boost::make_shared<DiscreteConditional>(
+          frontals.size(), conditional->discreteKeys(), prunedDiscreteTree);
+      conditional = boost::make_shared<HybridConditional>(prunedDiscrete);
+
+      // Add it back to the BayesNet
+      this->at(i) = conditional;
+    }
+  }
+}
+
+/* ************************************************************************* */
+HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
   // Get the decision tree of only the discrete keys
   auto discreteConditionals = this->discreteConditionals();
   const DecisionTreeFactor::shared_ptr decisionTree =
       boost::make_shared<DecisionTreeFactor>(
           discreteConditionals->prune(maxNrLeaves));
 
+  this->updateDiscreteConditionals(decisionTree);
+
   /* To Prune, we visitWith every leaf in the GaussianMixture.
    * For each leaf, using the assignment we can check the discrete decision tree
    * for 0.0 probability, then just set the leaf to a nullptr.
diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h
index b8234d70a..87e6c5db6 100644
--- a/gtsam/hybrid/HybridBayesNet.h
+++ b/gtsam/hybrid/HybridBayesNet.h
@@ -111,7 +111,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
    */
   VectorValues optimize(const DiscreteValues &assignment) const;
 
- protected:
   /**
    * @brief Get all the discrete conditionals as a decision tree factor.
    *
@@ -121,11 +120,19 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
  public:
   /// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
-  HybridBayesNet prune(size_t maxNrLeaves) const;
+  HybridBayesNet prune(size_t maxNrLeaves);
 
   /// @}
 
  private:
+  /**
+   * @brief Update the discrete conditionals with the pruned versions.
+   *
+   * @param prunedDecisionTree
+   */
+  void updateDiscreteConditionals(
+      const DecisionTreeFactor::shared_ptr &prunedDecisionTree);
+
   /** Serialization function */
   friend class boost::serialization::access;
   template <class ARCHIVE>
diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp
index 5885fdcdc..fc353f9c1 100644
--- a/gtsam/hybrid/tests/testHybridBayesNet.cpp
+++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp
@@ -201,6 +201,55 @@ TEST(HybridBayesNet, Prune) {
   EXPECT(assert_equal(delta.continuous(), pruned_delta.continuous()));
 }
 
+/* ****************************************************************************/
+// Test bayes net updateDiscreteConditionals
+TEST(HybridBayesNet, UpdateDiscreteConditionals) {
+  Switching s(4);
+
+  Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
+  HybridBayesNet::shared_ptr hybridBayesNet =
+      s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
+
+  size_t maxNrLeaves = 3;
+  auto discreteConditionals = hybridBayesNet->discreteConditionals();
+  const DecisionTreeFactor::shared_ptr prunedDecisionTree =
+      boost::make_shared<DecisionTreeFactor>(
+          discreteConditionals->prune(maxNrLeaves));
+
+  EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,
+                     prunedDecisionTree->nrLeaves());
+
+  auto original_discrete_conditionals =
+      *(hybridBayesNet->at(4)->asDiscreteConditional());
+
+  // Prune!
+  hybridBayesNet->prune(maxNrLeaves);
+
+  // Functor to verify values against the original_discrete_conditionals
+  auto checker = [&](const Assignment<Key>& assignment,
+                     double probability) -> double {
+    // typecast so we can use this to get probability value
+    DiscreteValues choices(assignment);
+    if (prunedDecisionTree->operator()(choices) == 0) {
+      EXPECT_DOUBLES_EQUAL(0.0, probability, 1e-9);
+    } else {
+      EXPECT_DOUBLES_EQUAL(original_discrete_conditionals(choices), probability,
+                           1e-9);
+    }
+    return 0.0;
+  };
+
+  // Get the pruned discrete conditionals as an AlgebraicDecisionTree
+  auto pruned_discrete_conditionals =
+      hybridBayesNet->at(4)->asDiscreteConditional();
+  auto discrete_conditional_tree =
+      boost::dynamic_pointer_cast<DecisionTreeFactor::ADT>(
+          pruned_discrete_conditionals);
+
+  // The checker functor verifies the values for us.
+  discrete_conditional_tree->apply(checker);
+}
+
 /* ****************************************************************************/
 // Test HybridBayesNet serialization.
 TEST(HybridBayesNet, Serialization) {