prune joint discrete probability which is faster
parent
f7071298c3
commit
3fe2682d93
|
@ -37,19 +37,6 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
|
|||
return Base::equals(bn, tol);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteConditional::shared_ptr HybridBayesNet::discreteConditionals() const {
|
||||
// The joint discrete probability.
|
||||
DiscreteConditional discreteProbs;
|
||||
|
||||
for (auto &&conditional : *this) {
|
||||
if (conditional->isDiscrete()) {
|
||||
discreteProbs = discreteProbs * (*conditional->asDiscrete());
|
||||
}
|
||||
}
|
||||
return std::make_shared<DiscreteConditional>(discreteProbs);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/**
|
||||
* @brief Helper function to get the pruner functional.
|
||||
|
@ -139,52 +126,52 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void HybridBayesNet::updateDiscreteConditionals(
|
||||
const DecisionTreeFactor &prunedDiscreteProbs) {
|
||||
// TODO(Varun) Should prune the joint conditional, maybe during elimination?
|
||||
// Loop with index since we need it later.
|
||||
DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
|
||||
size_t maxNrLeaves) {
|
||||
// Get the joint distribution of only the discrete keys
|
||||
gttic_(HybridBayesNet_PruneDiscreteConditionals);
|
||||
// The joint discrete probability.
|
||||
DiscreteConditional discreteProbs;
|
||||
|
||||
std::vector<size_t> discrete_factor_idxs;
|
||||
// Record frontal keys so we can maintain ordering
|
||||
Ordering discrete_frontals;
|
||||
|
||||
for (size_t i = 0; i < this->size(); i++) {
|
||||
HybridConditional::shared_ptr conditional = this->at(i);
|
||||
auto conditional = this->at(i);
|
||||
if (conditional->isDiscrete()) {
|
||||
auto discrete = conditional->asDiscrete();
|
||||
discreteProbs = discreteProbs * (*conditional->asDiscrete());
|
||||
|
||||
// Convert pointer from conditional to factor
|
||||
auto discreteFactor =
|
||||
std::dynamic_pointer_cast<DecisionTreeFactor>(discrete);
|
||||
// Apply prunerFunc to the underlying conditional
|
||||
DecisionTreeFactor::ADT prunedDiscreteFactor =
|
||||
discreteFactor->apply(prunerFunc(prunedDiscreteProbs, *conditional));
|
||||
|
||||
gttic_(HybridBayesNet_MakeConditional);
|
||||
// Create the new (hybrid) conditional
|
||||
KeyVector frontals(discrete->frontals().begin(),
|
||||
discrete->frontals().end());
|
||||
auto prunedDiscrete = std::make_shared<DiscreteLookupTable>(
|
||||
frontals.size(), conditional->discreteKeys(), prunedDiscreteFactor);
|
||||
conditional = std::make_shared<HybridConditional>(prunedDiscrete);
|
||||
gttoc_(HybridBayesNet_MakeConditional);
|
||||
|
||||
// Add it back to the BayesNet
|
||||
this->at(i) = conditional;
|
||||
Ordering conditional_keys(conditional->frontals());
|
||||
discrete_frontals += conditional_keys;
|
||||
discrete_factor_idxs.push_back(i);
|
||||
}
|
||||
}
|
||||
const DecisionTreeFactor prunedDiscreteProbs =
|
||||
discreteProbs.prune(maxNrLeaves);
|
||||
gttoc_(HybridBayesNet_PruneDiscreteConditionals);
|
||||
|
||||
// Eliminate joint probability back into conditionals
|
||||
gttic_(HybridBayesNet_UpdateDiscreteConditionals);
|
||||
DiscreteFactorGraph dfg{prunedDiscreteProbs};
|
||||
DiscreteBayesNet::shared_ptr dbn = dfg.eliminateSequential(discrete_frontals);
|
||||
|
||||
// Assign pruned discrete conditionals back at the correct indices.
|
||||
for (size_t i = 0; i < discrete_factor_idxs.size(); i++) {
|
||||
size_t idx = discrete_factor_idxs.at(i);
|
||||
this->at(idx) = std::make_shared<HybridConditional>(dbn->at(i));
|
||||
}
|
||||
gttoc_(HybridBayesNet_UpdateDiscreteConditionals);
|
||||
|
||||
return prunedDiscreteProbs;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
|
||||
// Get the joint distribution of only the discrete keys
|
||||
gttic_(HybridBayesNet_PruneDiscreteConditionals);
|
||||
DiscreteConditional::shared_ptr discreteConditionals =
|
||||
this->discreteConditionals();
|
||||
const DecisionTreeFactor prunedDiscreteProbs =
|
||||
discreteConditionals->prune(maxNrLeaves);
|
||||
gttoc_(HybridBayesNet_PruneDiscreteConditionals);
|
||||
DecisionTreeFactor prunedDiscreteProbs =
|
||||
this->pruneDiscreteConditionals(maxNrLeaves);
|
||||
|
||||
gttic_(HybridBayesNet_UpdateDiscreteConditionals);
|
||||
this->updateDiscreteConditionals(prunedDiscreteProbs);
|
||||
gttoc_(HybridBayesNet_UpdateDiscreteConditionals);
|
||||
|
||||
/* To Prune, we visitWith every leaf in the GaussianMixture.
|
||||
/* To prune, we visitWith every leaf in the GaussianMixture.
|
||||
* For each leaf, using the assignment we can check the discrete decision tree
|
||||
* for 0.0 probability, then just set the leaf to a nullptr.
|
||||
*
|
||||
|
|
|
@ -136,13 +136,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
*/
|
||||
VectorValues optimize(const DiscreteValues &assignment) const;
|
||||
|
||||
/**
|
||||
* @brief Get all the discrete conditionals as a decision tree factor.
|
||||
*
|
||||
* @return DiscreteConditional::shared_ptr
|
||||
*/
|
||||
DiscreteConditional::shared_ptr discreteConditionals() const;
|
||||
|
||||
/**
|
||||
* @brief Sample from an incomplete BayesNet, given missing variables.
|
||||
*
|
||||
|
@ -222,11 +215,11 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
|
||||
private:
|
||||
/**
|
||||
* @brief Update the discrete conditionals with the pruned versions.
|
||||
* @brief Prune all the discrete conditionals.
|
||||
*
|
||||
* @param prunedDiscreteProbs
|
||||
* @param maxNrLeaves
|
||||
*/
|
||||
void updateDiscreteConditionals(const DecisionTreeFactor &prunedDiscreteProbs);
|
||||
DecisionTreeFactor pruneDiscreteConditionals(size_t maxNrLeaves);
|
||||
|
||||
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
|
||||
/** Serialization function */
|
||||
|
|
|
@ -231,7 +231,7 @@ TEST(HybridBayesNet, Pruning) {
|
|||
auto prunedTree = prunedBayesNet.evaluate(delta.continuous());
|
||||
|
||||
// Regression test on pruned logProbability tree
|
||||
std::vector<double> pruned_leaves = {0.0, 20.346113, 0.0, 19.738098};
|
||||
std::vector<double> pruned_leaves = {0.0, 32.713418, 0.0, 31.735823};
|
||||
AlgebraicDecisionTree<Key> expected_pruned(discrete_keys, pruned_leaves);
|
||||
EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6));
|
||||
|
||||
|
@ -248,8 +248,10 @@ TEST(HybridBayesNet, Pruning) {
|
|||
logProbability +=
|
||||
posterior->at(4)->asDiscrete()->logProbability(hybridValues);
|
||||
|
||||
// Regression
|
||||
double density = exp(logProbability);
|
||||
EXPECT_DOUBLES_EQUAL(density, actualTree(discrete_values), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(density,
|
||||
1.6078460548731697 * actualTree(discrete_values), 1e-6);
|
||||
EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(logProbability, posterior->logProbability(hybridValues),
|
||||
1e-9);
|
||||
|
@ -283,20 +285,30 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
|||
EXPECT_LONGS_EQUAL(7, posterior->size());
|
||||
|
||||
size_t maxNrLeaves = 3;
|
||||
auto discreteConditionals = posterior->discreteConditionals();
|
||||
DiscreteConditional discreteConditionals;
|
||||
for (auto&& conditional : *posterior) {
|
||||
if (conditional->isDiscrete()) {
|
||||
discreteConditionals =
|
||||
discreteConditionals * (*conditional->asDiscrete());
|
||||
}
|
||||
}
|
||||
const DecisionTreeFactor::shared_ptr prunedDecisionTree =
|
||||
std::make_shared<DecisionTreeFactor>(
|
||||
discreteConditionals->prune(maxNrLeaves));
|
||||
discreteConditionals.prune(maxNrLeaves));
|
||||
|
||||
EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,
|
||||
prunedDecisionTree->nrLeaves());
|
||||
|
||||
auto original_discrete_conditionals = *(posterior->at(4)->asDiscrete());
|
||||
// regression
|
||||
DiscreteKeys dkeys{{M(0), 2}, {M(1), 2}, {M(2), 2}};
|
||||
DecisionTreeFactor::ADT potentials(
|
||||
dkeys, std::vector<double>{0, 0, 0, 0.505145423, 0, 1, 0, 0.494854577});
|
||||
DiscreteConditional expected_discrete_conditionals(1, dkeys, potentials);
|
||||
|
||||
// Prune!
|
||||
posterior->prune(maxNrLeaves);
|
||||
|
||||
// Functor to verify values against the original_discrete_conditionals
|
||||
// Functor to verify values against the expected_discrete_conditionals
|
||||
auto checker = [&](const Assignment<Key>& assignment,
|
||||
double probability) -> double {
|
||||
// typecast so we can use this to get probability value
|
||||
|
@ -304,7 +316,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
|||
if (prunedDecisionTree->operator()(choices) == 0) {
|
||||
EXPECT_DOUBLES_EQUAL(0.0, probability, 1e-9);
|
||||
} else {
|
||||
EXPECT_DOUBLES_EQUAL(original_discrete_conditionals(choices), probability,
|
||||
EXPECT_DOUBLES_EQUAL(expected_discrete_conditionals(choices), probability,
|
||||
1e-9);
|
||||
}
|
||||
return 0.0;
|
||||
|
|
Loading…
Reference in New Issue