diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index ff2752bcb..b4bf61220 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -37,19 +37,6 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { return Base::equals(bn, tol); } -/* ************************************************************************* */ -DiscreteConditional::shared_ptr HybridBayesNet::discreteConditionals() const { - // The joint discrete probability. - DiscreteConditional discreteProbs; - - for (auto &&conditional : *this) { - if (conditional->isDiscrete()) { - discreteProbs = discreteProbs * (*conditional->asDiscrete()); - } - } - return std::make_shared(discreteProbs); -} - /* ************************************************************************* */ /** * @brief Helper function to get the pruner functional. @@ -139,52 +126,52 @@ std::function &, double)> prunerFunc( } /* ************************************************************************* */ -void HybridBayesNet::updateDiscreteConditionals( - const DecisionTreeFactor &prunedDiscreteProbs) { - // TODO(Varun) Should prune the joint conditional, maybe during elimination? - // Loop with index since we need it later. +DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals( + size_t maxNrLeaves) { + // Get the joint distribution of only the discrete keys + gttic_(HybridBayesNet_PruneDiscreteConditionals); + // The joint discrete probability. + DiscreteConditional discreteProbs; + + std::vector discrete_factor_idxs; + // Record frontal keys so we can maintain ordering + Ordering discrete_frontals; + for (size_t i = 0; i < this->size(); i++) { - HybridConditional::shared_ptr conditional = this->at(i); + auto conditional = this->at(i); if (conditional->isDiscrete()) { - auto discrete = conditional->asDiscrete(); + discreteProbs = discreteProbs * (*conditional->asDiscrete()); - // Convert pointer from conditional to factor - auto discreteFactor = - std::dynamic_pointer_cast(discrete); - // Apply prunerFunc to the underlying conditional - DecisionTreeFactor::ADT prunedDiscreteFactor = - discreteFactor->apply(prunerFunc(prunedDiscreteProbs, *conditional)); - - gttic_(HybridBayesNet_MakeConditional); - // Create the new (hybrid) conditional - KeyVector frontals(discrete->frontals().begin(), - discrete->frontals().end()); - auto prunedDiscrete = std::make_shared( - frontals.size(), conditional->discreteKeys(), prunedDiscreteFactor); - conditional = std::make_shared(prunedDiscrete); - gttoc_(HybridBayesNet_MakeConditional); - - // Add it back to the BayesNet - this->at(i) = conditional; + Ordering conditional_keys(conditional->frontals()); + discrete_frontals += conditional_keys; + discrete_factor_idxs.push_back(i); } } + const DecisionTreeFactor prunedDiscreteProbs = + discreteProbs.prune(maxNrLeaves); + gttoc_(HybridBayesNet_PruneDiscreteConditionals); + + // Eliminate joint probability back into conditionals + gttic_(HybridBayesNet_UpdateDiscreteConditionals); + DiscreteFactorGraph dfg{prunedDiscreteProbs}; + DiscreteBayesNet::shared_ptr dbn = dfg.eliminateSequential(discrete_frontals); + + // Assign pruned discrete conditionals back at the correct indices. + for (size_t i = 0; i < discrete_factor_idxs.size(); i++) { + size_t idx = discrete_factor_idxs.at(i); + this->at(idx) = std::make_shared(dbn->at(i)); + } + gttoc_(HybridBayesNet_UpdateDiscreteConditionals); + + return prunedDiscreteProbs; } /* ************************************************************************* */ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) { - // Get the joint distribution of only the discrete keys - gttic_(HybridBayesNet_PruneDiscreteConditionals); - DiscreteConditional::shared_ptr discreteConditionals = - this->discreteConditionals(); - const DecisionTreeFactor prunedDiscreteProbs = - discreteConditionals->prune(maxNrLeaves); - gttoc_(HybridBayesNet_PruneDiscreteConditionals); + DecisionTreeFactor prunedDiscreteProbs = + this->pruneDiscreteConditionals(maxNrLeaves); - gttic_(HybridBayesNet_UpdateDiscreteConditionals); - this->updateDiscreteConditionals(prunedDiscreteProbs); - gttoc_(HybridBayesNet_UpdateDiscreteConditionals); - - /* To Prune, we visitWith every leaf in the GaussianMixture. + /* To prune, we visitWith every leaf in the GaussianMixture. * For each leaf, using the assignment we can check the discrete decision tree * for 0.0 probability, then just set the leaf to a nullptr. * diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 19e88d754..e71cfe9b4 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -136,13 +136,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { */ VectorValues optimize(const DiscreteValues &assignment) const; - /** - * @brief Get all the discrete conditionals as a decision tree factor. - * - * @return DiscreteConditional::shared_ptr - */ - DiscreteConditional::shared_ptr discreteConditionals() const; - /** * @brief Sample from an incomplete BayesNet, given missing variables. * @@ -222,11 +215,11 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { private: /** - * @brief Update the discrete conditionals with the pruned versions. + * @brief Prune all the discrete conditionals. * - * @param prunedDiscreteProbs + * @param maxNrLeaves */ - void updateDiscreteConditionals(const DecisionTreeFactor &prunedDiscreteProbs); + DecisionTreeFactor pruneDiscreteConditionals(size_t maxNrLeaves); #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION /** Serialization function */ diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index f25675a55..1dfcbd6b7 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -231,7 +231,7 @@ TEST(HybridBayesNet, Pruning) { auto prunedTree = prunedBayesNet.evaluate(delta.continuous()); // Regression test on pruned logProbability tree - std::vector pruned_leaves = {0.0, 20.346113, 0.0, 19.738098}; + std::vector pruned_leaves = {0.0, 32.713418, 0.0, 31.735823}; AlgebraicDecisionTree expected_pruned(discrete_keys, pruned_leaves); EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6)); @@ -248,8 +248,10 @@ TEST(HybridBayesNet, Pruning) { logProbability += posterior->at(4)->asDiscrete()->logProbability(hybridValues); + // Regression double density = exp(logProbability); - EXPECT_DOUBLES_EQUAL(density, actualTree(discrete_values), 1e-9); + EXPECT_DOUBLES_EQUAL(density, + 1.6078460548731697 * actualTree(discrete_values), 1e-6); EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9); EXPECT_DOUBLES_EQUAL(logProbability, posterior->logProbability(hybridValues), 1e-9); @@ -283,20 +285,30 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) { EXPECT_LONGS_EQUAL(7, posterior->size()); size_t maxNrLeaves = 3; - auto discreteConditionals = posterior->discreteConditionals(); + DiscreteConditional discreteConditionals; + for (auto&& conditional : *posterior) { + if (conditional->isDiscrete()) { + discreteConditionals = + discreteConditionals * (*conditional->asDiscrete()); + } + } const DecisionTreeFactor::shared_ptr prunedDecisionTree = std::make_shared( - discreteConditionals->prune(maxNrLeaves)); + discreteConditionals.prune(maxNrLeaves)); EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/, prunedDecisionTree->nrLeaves()); - auto original_discrete_conditionals = *(posterior->at(4)->asDiscrete()); + // regression + DiscreteKeys dkeys{{M(0), 2}, {M(1), 2}, {M(2), 2}}; + DecisionTreeFactor::ADT potentials( + dkeys, std::vector{0, 0, 0, 0.505145423, 0, 1, 0, 0.494854577}); + DiscreteConditional expected_discrete_conditionals(1, dkeys, potentials); // Prune! posterior->prune(maxNrLeaves); - // Functor to verify values against the original_discrete_conditionals + // Functor to verify values against the expected_discrete_conditionals auto checker = [&](const Assignment& assignment, double probability) -> double { // typecast so we can use this to get probability value @@ -304,7 +316,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) { if (prunedDecisionTree->operator()(choices) == 0) { EXPECT_DOUBLES_EQUAL(0.0, probability, 1e-9); } else { - EXPECT_DOUBLES_EQUAL(original_discrete_conditionals(choices), probability, + EXPECT_DOUBLES_EQUAL(expected_discrete_conditionals(choices), probability, 1e-9); } return 0.0;