Prune the joint discrete probability, which is faster

release/4.3a0
Varun Agrawal 2023-07-14 17:52:22 -04:00
parent f7071298c3
commit 3fe2682d93
3 changed files with 58 additions and 66 deletions

View File

@ -37,19 +37,6 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
return Base::equals(bn, tol);
}
/* ************************************************************************* */
DiscreteConditional::shared_ptr HybridBayesNet::discreteConditionals() const {
  // Accumulate the product of every discrete conditional in this Bayes net,
  // starting from a default-constructed DiscreteConditional.
  DiscreteConditional joint;
  for (const auto &node : *this) {
    // Anything that is not a discrete conditional does not contribute.
    if (!node->isDiscrete()) continue;
    const auto discrete = node->asDiscrete();
    joint = joint * (*discrete);
  }
  // Return a shared copy of the accumulated product.
  return std::make_shared<DiscreteConditional>(joint);
}
/* ************************************************************************* */
/**
* @brief Helper function to get the pruner functional.
@ -139,52 +126,52 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
}
/* ************************************************************************* */
void HybridBayesNet::updateDiscreteConditionals(
const DecisionTreeFactor &prunedDiscreteProbs) {
// TODO(Varun) Should prune the joint conditional, maybe during elimination?
// Loop with index since we need it later.
/**
 * @brief Multiply all discrete conditionals into one joint, prune that joint
 * via prune(maxNrLeaves), eliminate the pruned joint back into conditionals,
 * and store those at the indices the original discrete conditionals occupied.
 *
 * @param maxNrLeaves Forwarded unchanged to prune() on the joint.
 * @return The pruned joint discrete probability.
 */
DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
size_t maxNrLeaves) {
// Get the joint distribution of only the discrete keys
gttic_(HybridBayesNet_PruneDiscreteConditionals);
// The joint discrete probability.
DiscreteConditional discreteProbs;
// Positions in this Bayes net of the discrete conditionals folded into the
// joint; used below to write the re-eliminated conditionals back.
std::vector<size_t> discrete_factor_idxs;
// Record frontal keys so we can maintain ordering
Ordering discrete_frontals;
for (size_t i = 0; i < this->size(); i++) {
HybridConditional::shared_ptr conditional = this->at(i);
auto conditional = this->at(i);
if (conditional->isDiscrete()) {
auto discrete = conditional->asDiscrete();
discreteProbs = discreteProbs * (*conditional->asDiscrete());
// Convert pointer from conditional to factor
auto discreteFactor =
std::dynamic_pointer_cast<DecisionTreeFactor>(discrete);
// Apply prunerFunc to the underlying conditional
DecisionTreeFactor::ADT prunedDiscreteFactor =
discreteFactor->apply(prunerFunc(prunedDiscreteProbs, *conditional));
gttic_(HybridBayesNet_MakeConditional);
// Create the new (hybrid) conditional
KeyVector frontals(discrete->frontals().begin(),
discrete->frontals().end());
auto prunedDiscrete = std::make_shared<DiscreteLookupTable>(
frontals.size(), conditional->discreteKeys(), prunedDiscreteFactor);
conditional = std::make_shared<HybridConditional>(prunedDiscrete);
gttoc_(HybridBayesNet_MakeConditional);
// Add it back to the BayesNet
this->at(i) = conditional;
Ordering conditional_keys(conditional->frontals());
discrete_frontals += conditional_keys;
discrete_factor_idxs.push_back(i);
}
}
const DecisionTreeFactor prunedDiscreteProbs =
discreteProbs.prune(maxNrLeaves);
gttoc_(HybridBayesNet_PruneDiscreteConditionals);
// Eliminate joint probability back into conditionals
gttic_(HybridBayesNet_UpdateDiscreteConditionals);
DiscreteFactorGraph dfg{prunedDiscreteProbs};
DiscreteBayesNet::shared_ptr dbn = dfg.eliminateSequential(discrete_frontals);
// Assign pruned discrete conditionals back at the correct indices.
// NOTE(review): dbn is indexed by position here, which presumes elimination
// yields exactly one conditional per recorded discrete conditional, in the
// same order (e.g. one frontal key per conditional) -- TODO confirm.
for (size_t i = 0; i < discrete_factor_idxs.size(); i++) {
size_t idx = discrete_factor_idxs.at(i);
this->at(idx) = std::make_shared<HybridConditional>(dbn->at(i));
}
gttoc_(HybridBayesNet_UpdateDiscreteConditionals);
return prunedDiscreteProbs;
}
/* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
// Get the joint distribution of only the discrete keys
gttic_(HybridBayesNet_PruneDiscreteConditionals);
DiscreteConditional::shared_ptr discreteConditionals =
this->discreteConditionals();
const DecisionTreeFactor prunedDiscreteProbs =
discreteConditionals->prune(maxNrLeaves);
gttoc_(HybridBayesNet_PruneDiscreteConditionals);
DecisionTreeFactor prunedDiscreteProbs =
this->pruneDiscreteConditionals(maxNrLeaves);
gttic_(HybridBayesNet_UpdateDiscreteConditionals);
this->updateDiscreteConditionals(prunedDiscreteProbs);
gttoc_(HybridBayesNet_UpdateDiscreteConditionals);
/* To Prune, we visitWith every leaf in the GaussianMixture.
/* To prune, we visitWith every leaf in the GaussianMixture.
* For each leaf, using the assignment we can check the discrete decision tree
* for 0.0 probability, then just set the leaf to a nullptr.
*

View File

@ -136,13 +136,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
*/
VectorValues optimize(const DiscreteValues &assignment) const;
/**
* @brief Get the product of all the discrete conditionals in this Bayes net
* as a single DiscreteConditional.
*
* Conditionals that are not discrete are skipped.
*
* @return DiscreteConditional::shared_ptr Shared pointer to the product.
*/
DiscreteConditional::shared_ptr discreteConditionals() const;
/**
* @brief Sample from an incomplete BayesNet, given missing variables.
*
@ -222,11 +215,11 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
private:
/**
* @brief Update the discrete conditionals with the pruned versions.
* @brief Prune all the discrete conditionals.
*
* @param prunedDiscreteProbs The pruned joint discrete probability
* (taken by updateDiscreteConditionals).
* @param maxNrLeaves Forwarded to prune() on the joint discrete probability
* (taken by pruneDiscreteConditionals).
* @return The pruned joint discrete probability (pruneDiscreteConditionals).
*/
void updateDiscreteConditionals(const DecisionTreeFactor &prunedDiscreteProbs);
DecisionTreeFactor pruneDiscreteConditionals(size_t maxNrLeaves);
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */

View File

@ -231,7 +231,7 @@ TEST(HybridBayesNet, Pruning) {
auto prunedTree = prunedBayesNet.evaluate(delta.continuous());
// Regression test on pruned logProbability tree
std::vector<double> pruned_leaves = {0.0, 20.346113, 0.0, 19.738098};
std::vector<double> pruned_leaves = {0.0, 32.713418, 0.0, 31.735823};
AlgebraicDecisionTree<Key> expected_pruned(discrete_keys, pruned_leaves);
EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6));
@ -248,8 +248,10 @@ TEST(HybridBayesNet, Pruning) {
logProbability +=
posterior->at(4)->asDiscrete()->logProbability(hybridValues);
// Regression
double density = exp(logProbability);
EXPECT_DOUBLES_EQUAL(density, actualTree(discrete_values), 1e-9);
EXPECT_DOUBLES_EQUAL(density,
1.6078460548731697 * actualTree(discrete_values), 1e-6);
EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9);
EXPECT_DOUBLES_EQUAL(logProbability, posterior->logProbability(hybridValues),
1e-9);
@ -283,20 +285,30 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
EXPECT_LONGS_EQUAL(7, posterior->size());
size_t maxNrLeaves = 3;
auto discreteConditionals = posterior->discreteConditionals();
DiscreteConditional discreteConditionals;
for (auto&& conditional : *posterior) {
if (conditional->isDiscrete()) {
discreteConditionals =
discreteConditionals * (*conditional->asDiscrete());
}
}
const DecisionTreeFactor::shared_ptr prunedDecisionTree =
std::make_shared<DecisionTreeFactor>(
discreteConditionals->prune(maxNrLeaves));
discreteConditionals.prune(maxNrLeaves));
EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,
prunedDecisionTree->nrLeaves());
auto original_discrete_conditionals = *(posterior->at(4)->asDiscrete());
// regression
DiscreteKeys dkeys{{M(0), 2}, {M(1), 2}, {M(2), 2}};
DecisionTreeFactor::ADT potentials(
dkeys, std::vector<double>{0, 0, 0, 0.505145423, 0, 1, 0, 0.494854577});
DiscreteConditional expected_discrete_conditionals(1, dkeys, potentials);
// Prune!
posterior->prune(maxNrLeaves);
// Functor to verify values against the original_discrete_conditionals
// Functor to verify values against the expected_discrete_conditionals
auto checker = [&](const Assignment<Key>& assignment,
double probability) -> double {
// typecast so we can use this to get probability value
@ -304,7 +316,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
if (prunedDecisionTree->operator()(choices) == 0) {
EXPECT_DOUBLES_EQUAL(0.0, probability, 1e-9);
} else {
EXPECT_DOUBLES_EQUAL(original_discrete_conditionals(choices), probability,
EXPECT_DOUBLES_EQUAL(expected_discrete_conditionals(choices), probability,
1e-9);
}
return 0.0;