prune joint discrete probability which is faster

release/4.3a0
Varun Agrawal 2023-07-14 17:52:22 -04:00
parent f7071298c3
commit 3fe2682d93
3 changed files with 58 additions and 66 deletions


@@ -37,19 +37,6 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
   return Base::equals(bn, tol);
 }
 
-/* ************************************************************************* */
-DiscreteConditional::shared_ptr HybridBayesNet::discreteConditionals() const {
-  // The joint discrete probability.
-  DiscreteConditional discreteProbs;
-  for (auto &&conditional : *this) {
-    if (conditional->isDiscrete()) {
-      discreteProbs = discreteProbs * (*conditional->asDiscrete());
-    }
-  }
-  return std::make_shared<DiscreteConditional>(discreteProbs);
-}
-
 /* ************************************************************************* */
 /**
  * @brief Helper function to get the pruner functional.
@@ -139,52 +126,52 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
 }
 
 /* ************************************************************************* */
-void HybridBayesNet::updateDiscreteConditionals(
-    const DecisionTreeFactor &prunedDiscreteProbs) {
-  // TODO(Varun) Should prune the joint conditional, maybe during elimination?
-  // Loop with index since we need it later.
+DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
+    size_t maxNrLeaves) {
+  // Get the joint distribution of only the discrete keys
+  gttic_(HybridBayesNet_PruneDiscreteConditionals);
+  // The joint discrete probability.
+  DiscreteConditional discreteProbs;
+
+  std::vector<size_t> discrete_factor_idxs;
+  // Record frontal keys so we can maintain ordering
+  Ordering discrete_frontals;
+
   for (size_t i = 0; i < this->size(); i++) {
-    HybridConditional::shared_ptr conditional = this->at(i);
+    auto conditional = this->at(i);
     if (conditional->isDiscrete()) {
-      auto discrete = conditional->asDiscrete();
-      // Convert pointer from conditional to factor
-      auto discreteFactor =
-          std::dynamic_pointer_cast<DecisionTreeFactor>(discrete);
-      // Apply prunerFunc to the underlying conditional
-      DecisionTreeFactor::ADT prunedDiscreteFactor =
-          discreteFactor->apply(prunerFunc(prunedDiscreteProbs, *conditional));
-
-      gttic_(HybridBayesNet_MakeConditional);
-      // Create the new (hybrid) conditional
-      KeyVector frontals(discrete->frontals().begin(),
-                         discrete->frontals().end());
-      auto prunedDiscrete = std::make_shared<DiscreteLookupTable>(
-          frontals.size(), conditional->discreteKeys(), prunedDiscreteFactor);
-      conditional = std::make_shared<HybridConditional>(prunedDiscrete);
-      gttoc_(HybridBayesNet_MakeConditional);
-
-      // Add it back to the BayesNet
-      this->at(i) = conditional;
+      discreteProbs = discreteProbs * (*conditional->asDiscrete());
+
+      Ordering conditional_keys(conditional->frontals());
+      discrete_frontals += conditional_keys;
+      discrete_factor_idxs.push_back(i);
     }
   }
+
+  const DecisionTreeFactor prunedDiscreteProbs =
+      discreteProbs.prune(maxNrLeaves);
+  gttoc_(HybridBayesNet_PruneDiscreteConditionals);
+
+  // Eliminate joint probability back into conditionals
+  gttic_(HybridBayesNet_UpdateDiscreteConditionals);
+  DiscreteFactorGraph dfg{prunedDiscreteProbs};
+  DiscreteBayesNet::shared_ptr dbn = dfg.eliminateSequential(discrete_frontals);
+
+  // Assign pruned discrete conditionals back at the correct indices.
+  for (size_t i = 0; i < discrete_factor_idxs.size(); i++) {
+    size_t idx = discrete_factor_idxs.at(i);
+    this->at(idx) = std::make_shared<HybridConditional>(dbn->at(i));
+  }
+  gttoc_(HybridBayesNet_UpdateDiscreteConditionals);
+
+  return prunedDiscreteProbs;
 }
 
 /* ************************************************************************* */
 HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
-  // Get the joint distribution of only the discrete keys
-  gttic_(HybridBayesNet_PruneDiscreteConditionals);
-  DiscreteConditional::shared_ptr discreteConditionals =
-      this->discreteConditionals();
-  const DecisionTreeFactor prunedDiscreteProbs =
-      discreteConditionals->prune(maxNrLeaves);
-  gttoc_(HybridBayesNet_PruneDiscreteConditionals);
-
-  gttic_(HybridBayesNet_UpdateDiscreteConditionals);
-  this->updateDiscreteConditionals(prunedDiscreteProbs);
-  gttoc_(HybridBayesNet_UpdateDiscreteConditionals);
-
-  /* To Prune, we visitWith every leaf in the GaussianMixture.
+  DecisionTreeFactor prunedDiscreteProbs =
+      this->pruneDiscreteConditionals(maxNrLeaves);
+
+  /* To prune, we visitWith every leaf in the GaussianMixture.
   * For each leaf, using the assignment we can check the discrete decision tree
   * for 0.0 probability, then just set the leaf to a nullptr.
   *
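The change above means HybridBayesNet::prune now forms the joint discrete probability once, prunes that single table to maxNrLeaves, and eliminates the pruned joint back into per-conditional form, instead of re-pruning every discrete conditional through a pruner functional. Below is a minimal usage sketch of the public entry point; the wrapper function name, the leaf count, and the assumption that `posterior` comes from hybrid elimination are illustrative, not part of this commit.

#include <gtsam/hybrid/HybridBayesNet.h>

// Sketch only: prune a hybrid posterior down to its 3 most probable discrete
// assignments. `posterior` stands in for a HybridBayesNet produced elsewhere,
// e.g. by eliminating a hybrid factor graph.
gtsam::HybridBayesNet pruneToThreeLeaves(gtsam::HybridBayesNet &posterior) {
  const size_t maxNrLeaves = 3;
  // Internally this calls pruneDiscreteConditionals(maxNrLeaves): multiply all
  // discrete conditionals into one joint, prune that joint once, then
  // eliminate it back into the Bayes net before pruning the GaussianMixture
  // leaves against the pruned discrete probabilities.
  return posterior.prune(maxNrLeaves);
}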


@@ -136,13 +136,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
    */
   VectorValues optimize(const DiscreteValues &assignment) const;
 
-  /**
-   * @brief Get all the discrete conditionals as a decision tree factor.
-   *
-   * @return DiscreteConditional::shared_ptr
-   */
-  DiscreteConditional::shared_ptr discreteConditionals() const;
-
   /**
    * @brief Sample from an incomplete BayesNet, given missing variables.
    *
@@ -222,11 +215,11 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
  private:
   /**
-   * @brief Update the discrete conditionals with the pruned versions.
+   * @brief Prune all the discrete conditionals.
    *
-   * @param prunedDiscreteProbs
+   * @param maxNrLeaves
    */
-  void updateDiscreteConditionals(const DecisionTreeFactor &prunedDiscreteProbs);
+  DecisionTreeFactor pruneDiscreteConditionals(size_t maxNrLeaves);
 
 #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
   /** Serialization function */
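For intuition about the new private helper declared above, the same multiply, prune, then eliminate pattern can be exercised with GTSAM's discrete types alone. This is a standalone sketch, not code from the commit: the keys, CPT strings, ordering, and leaf count are made up, and it relies only on calls visible in this diff (DiscreteConditional multiplication, prune, DiscreteFactorGraph::eliminateSequential) plus the long-standing Signature syntax for specifying discrete conditionals.

#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Ordering.h>

int main() {
  using namespace gtsam;

  // Two binary discrete variables (keys chosen for illustration).
  DiscreteKey m0(0, 2), m1(1, 2);

  // P(m0) and P(m1 | m0), specified with the Signature syntax.
  DiscreteConditional p_m0(m0 % "4/6");
  DiscreteConditional p_m1(m1 | m0 = "9/1 2/8");

  // 1. Multiply the discrete conditionals into one joint probability,
  //    as pruneDiscreteConditionals does over the whole Bayes net.
  DiscreteConditional joint = p_m1 * p_m0;

  // 2. Prune the joint once, keeping only the most probable leaves.
  const size_t maxNrLeaves = 2;
  const DecisionTreeFactor pruned = joint.prune(maxNrLeaves);

  // 3. Eliminate the pruned joint back into per-variable conditionals,
  //    preserving the original frontal ordering.
  Ordering ordering;
  ordering.push_back(m1.first);
  ordering.push_back(m0.first);
  DiscreteFactorGraph dfg{pruned};
  DiscreteBayesNet::shared_ptr dbn = dfg.eliminateSequential(ordering);

  dbn->print("pruned discrete conditionals");
  return 0;
}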


@@ -231,7 +231,7 @@ TEST(HybridBayesNet, Pruning) {
   auto prunedTree = prunedBayesNet.evaluate(delta.continuous());
 
   // Regression test on pruned logProbability tree
-  std::vector<double> pruned_leaves = {0.0, 20.346113, 0.0, 19.738098};
+  std::vector<double> pruned_leaves = {0.0, 32.713418, 0.0, 31.735823};
   AlgebraicDecisionTree<Key> expected_pruned(discrete_keys, pruned_leaves);
   EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6));
@@ -248,8 +248,10 @@ TEST(HybridBayesNet, Pruning) {
   logProbability +=
       posterior->at(4)->asDiscrete()->logProbability(hybridValues);
 
+  // Regression
   double density = exp(logProbability);
-  EXPECT_DOUBLES_EQUAL(density, actualTree(discrete_values), 1e-9);
+  EXPECT_DOUBLES_EQUAL(density,
+                       1.6078460548731697 * actualTree(discrete_values), 1e-6);
   EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9);
   EXPECT_DOUBLES_EQUAL(logProbability, posterior->logProbability(hybridValues),
                        1e-9);
@@ -283,20 +285,30 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
   EXPECT_LONGS_EQUAL(7, posterior->size());
 
   size_t maxNrLeaves = 3;
-  auto discreteConditionals = posterior->discreteConditionals();
+  DiscreteConditional discreteConditionals;
+  for (auto&& conditional : *posterior) {
+    if (conditional->isDiscrete()) {
+      discreteConditionals =
+          discreteConditionals * (*conditional->asDiscrete());
+    }
+  }
   const DecisionTreeFactor::shared_ptr prunedDecisionTree =
       std::make_shared<DecisionTreeFactor>(
-          discreteConditionals->prune(maxNrLeaves));
+          discreteConditionals.prune(maxNrLeaves));
 
   EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,
                      prunedDecisionTree->nrLeaves());
 
-  auto original_discrete_conditionals = *(posterior->at(4)->asDiscrete());
+  // regression
+  DiscreteKeys dkeys{{M(0), 2}, {M(1), 2}, {M(2), 2}};
+  DecisionTreeFactor::ADT potentials(
+      dkeys, std::vector<double>{0, 0, 0, 0.505145423, 0, 1, 0, 0.494854577});
+  DiscreteConditional expected_discrete_conditionals(1, dkeys, potentials);
 
   // Prune!
   posterior->prune(maxNrLeaves);
 
-  // Functor to verify values against the original_discrete_conditionals
+  // Functor to verify values against the expected_discrete_conditionals
   auto checker = [&](const Assignment<Key>& assignment,
                      double probability) -> double {
     // typecast so we can use this to get probability value
@@ -304,7 +316,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
     if (prunedDecisionTree->operator()(choices) == 0) {
       EXPECT_DOUBLES_EQUAL(0.0, probability, 1e-9);
     } else {
-      EXPECT_DOUBLES_EQUAL(original_discrete_conditionals(choices), probability,
+      EXPECT_DOUBLES_EQUAL(expected_discrete_conditionals(choices), probability,
                            1e-9);
     }
     return 0.0;