prune joint discrete probability which is faster
parent
f7071298c3
commit
3fe2682d93
|
@ -37,19 +37,6 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
|
||||||
return Base::equals(bn, tol);
|
return Base::equals(bn, tol);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
DiscreteConditional::shared_ptr HybridBayesNet::discreteConditionals() const {
|
|
||||||
// The joint discrete probability.
|
|
||||||
DiscreteConditional discreteProbs;
|
|
||||||
|
|
||||||
for (auto &&conditional : *this) {
|
|
||||||
if (conditional->isDiscrete()) {
|
|
||||||
discreteProbs = discreteProbs * (*conditional->asDiscrete());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return std::make_shared<DiscreteConditional>(discreteProbs);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
/**
|
/**
|
||||||
* @brief Helper function to get the pruner functional.
|
* @brief Helper function to get the pruner functional.
|
||||||
|
@ -139,52 +126,52 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
void HybridBayesNet::updateDiscreteConditionals(
|
DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
|
||||||
const DecisionTreeFactor &prunedDiscreteProbs) {
|
size_t maxNrLeaves) {
|
||||||
// TODO(Varun) Should prune the joint conditional, maybe during elimination?
|
// Get the joint distribution of only the discrete keys
|
||||||
// Loop with index since we need it later.
|
gttic_(HybridBayesNet_PruneDiscreteConditionals);
|
||||||
|
// The joint discrete probability.
|
||||||
|
DiscreteConditional discreteProbs;
|
||||||
|
|
||||||
|
std::vector<size_t> discrete_factor_idxs;
|
||||||
|
// Record frontal keys so we can maintain ordering
|
||||||
|
Ordering discrete_frontals;
|
||||||
|
|
||||||
for (size_t i = 0; i < this->size(); i++) {
|
for (size_t i = 0; i < this->size(); i++) {
|
||||||
HybridConditional::shared_ptr conditional = this->at(i);
|
auto conditional = this->at(i);
|
||||||
if (conditional->isDiscrete()) {
|
if (conditional->isDiscrete()) {
|
||||||
auto discrete = conditional->asDiscrete();
|
discreteProbs = discreteProbs * (*conditional->asDiscrete());
|
||||||
|
|
||||||
// Convert pointer from conditional to factor
|
Ordering conditional_keys(conditional->frontals());
|
||||||
auto discreteFactor =
|
discrete_frontals += conditional_keys;
|
||||||
std::dynamic_pointer_cast<DecisionTreeFactor>(discrete);
|
discrete_factor_idxs.push_back(i);
|
||||||
// Apply prunerFunc to the underlying conditional
|
|
||||||
DecisionTreeFactor::ADT prunedDiscreteFactor =
|
|
||||||
discreteFactor->apply(prunerFunc(prunedDiscreteProbs, *conditional));
|
|
||||||
|
|
||||||
gttic_(HybridBayesNet_MakeConditional);
|
|
||||||
// Create the new (hybrid) conditional
|
|
||||||
KeyVector frontals(discrete->frontals().begin(),
|
|
||||||
discrete->frontals().end());
|
|
||||||
auto prunedDiscrete = std::make_shared<DiscreteLookupTable>(
|
|
||||||
frontals.size(), conditional->discreteKeys(), prunedDiscreteFactor);
|
|
||||||
conditional = std::make_shared<HybridConditional>(prunedDiscrete);
|
|
||||||
gttoc_(HybridBayesNet_MakeConditional);
|
|
||||||
|
|
||||||
// Add it back to the BayesNet
|
|
||||||
this->at(i) = conditional;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
const DecisionTreeFactor prunedDiscreteProbs =
|
||||||
|
discreteProbs.prune(maxNrLeaves);
|
||||||
|
gttoc_(HybridBayesNet_PruneDiscreteConditionals);
|
||||||
|
|
||||||
|
// Eliminate joint probability back into conditionals
|
||||||
|
gttic_(HybridBayesNet_UpdateDiscreteConditionals);
|
||||||
|
DiscreteFactorGraph dfg{prunedDiscreteProbs};
|
||||||
|
DiscreteBayesNet::shared_ptr dbn = dfg.eliminateSequential(discrete_frontals);
|
||||||
|
|
||||||
|
// Assign pruned discrete conditionals back at the correct indices.
|
||||||
|
for (size_t i = 0; i < discrete_factor_idxs.size(); i++) {
|
||||||
|
size_t idx = discrete_factor_idxs.at(i);
|
||||||
|
this->at(idx) = std::make_shared<HybridConditional>(dbn->at(i));
|
||||||
|
}
|
||||||
|
gttoc_(HybridBayesNet_UpdateDiscreteConditionals);
|
||||||
|
|
||||||
|
return prunedDiscreteProbs;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
|
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
|
||||||
// Get the joint distribution of only the discrete keys
|
DecisionTreeFactor prunedDiscreteProbs =
|
||||||
gttic_(HybridBayesNet_PruneDiscreteConditionals);
|
this->pruneDiscreteConditionals(maxNrLeaves);
|
||||||
DiscreteConditional::shared_ptr discreteConditionals =
|
|
||||||
this->discreteConditionals();
|
|
||||||
const DecisionTreeFactor prunedDiscreteProbs =
|
|
||||||
discreteConditionals->prune(maxNrLeaves);
|
|
||||||
gttoc_(HybridBayesNet_PruneDiscreteConditionals);
|
|
||||||
|
|
||||||
gttic_(HybridBayesNet_UpdateDiscreteConditionals);
|
/* To prune, we visitWith every leaf in the GaussianMixture.
|
||||||
this->updateDiscreteConditionals(prunedDiscreteProbs);
|
|
||||||
gttoc_(HybridBayesNet_UpdateDiscreteConditionals);
|
|
||||||
|
|
||||||
/* To Prune, we visitWith every leaf in the GaussianMixture.
|
|
||||||
* For each leaf, using the assignment we can check the discrete decision tree
|
* For each leaf, using the assignment we can check the discrete decision tree
|
||||||
* for 0.0 probability, then just set the leaf to a nullptr.
|
* for 0.0 probability, then just set the leaf to a nullptr.
|
||||||
*
|
*
|
||||||
|
|
|
@ -136,13 +136,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
*/
|
*/
|
||||||
VectorValues optimize(const DiscreteValues &assignment) const;
|
VectorValues optimize(const DiscreteValues &assignment) const;
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Get all the discrete conditionals as a decision tree factor.
|
|
||||||
*
|
|
||||||
* @return DiscreteConditional::shared_ptr
|
|
||||||
*/
|
|
||||||
DiscreteConditional::shared_ptr discreteConditionals() const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Sample from an incomplete BayesNet, given missing variables.
|
* @brief Sample from an incomplete BayesNet, given missing variables.
|
||||||
*
|
*
|
||||||
|
@ -222,11 +215,11 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/**
|
/**
|
||||||
* @brief Update the discrete conditionals with the pruned versions.
|
* @brief Prune all the discrete conditionals.
|
||||||
*
|
*
|
||||||
* @param prunedDiscreteProbs
|
* @param maxNrLeaves
|
||||||
*/
|
*/
|
||||||
void updateDiscreteConditionals(const DecisionTreeFactor &prunedDiscreteProbs);
|
DecisionTreeFactor pruneDiscreteConditionals(size_t maxNrLeaves);
|
||||||
|
|
||||||
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
|
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
|
||||||
/** Serialization function */
|
/** Serialization function */
|
||||||
|
|
|
@ -231,7 +231,7 @@ TEST(HybridBayesNet, Pruning) {
|
||||||
auto prunedTree = prunedBayesNet.evaluate(delta.continuous());
|
auto prunedTree = prunedBayesNet.evaluate(delta.continuous());
|
||||||
|
|
||||||
// Regression test on pruned logProbability tree
|
// Regression test on pruned logProbability tree
|
||||||
std::vector<double> pruned_leaves = {0.0, 20.346113, 0.0, 19.738098};
|
std::vector<double> pruned_leaves = {0.0, 32.713418, 0.0, 31.735823};
|
||||||
AlgebraicDecisionTree<Key> expected_pruned(discrete_keys, pruned_leaves);
|
AlgebraicDecisionTree<Key> expected_pruned(discrete_keys, pruned_leaves);
|
||||||
EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6));
|
EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6));
|
||||||
|
|
||||||
|
@ -248,8 +248,10 @@ TEST(HybridBayesNet, Pruning) {
|
||||||
logProbability +=
|
logProbability +=
|
||||||
posterior->at(4)->asDiscrete()->logProbability(hybridValues);
|
posterior->at(4)->asDiscrete()->logProbability(hybridValues);
|
||||||
|
|
||||||
|
// Regression
|
||||||
double density = exp(logProbability);
|
double density = exp(logProbability);
|
||||||
EXPECT_DOUBLES_EQUAL(density, actualTree(discrete_values), 1e-9);
|
EXPECT_DOUBLES_EQUAL(density,
|
||||||
|
1.6078460548731697 * actualTree(discrete_values), 1e-6);
|
||||||
EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9);
|
EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9);
|
||||||
EXPECT_DOUBLES_EQUAL(logProbability, posterior->logProbability(hybridValues),
|
EXPECT_DOUBLES_EQUAL(logProbability, posterior->logProbability(hybridValues),
|
||||||
1e-9);
|
1e-9);
|
||||||
|
@ -283,20 +285,30 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
||||||
EXPECT_LONGS_EQUAL(7, posterior->size());
|
EXPECT_LONGS_EQUAL(7, posterior->size());
|
||||||
|
|
||||||
size_t maxNrLeaves = 3;
|
size_t maxNrLeaves = 3;
|
||||||
auto discreteConditionals = posterior->discreteConditionals();
|
DiscreteConditional discreteConditionals;
|
||||||
|
for (auto&& conditional : *posterior) {
|
||||||
|
if (conditional->isDiscrete()) {
|
||||||
|
discreteConditionals =
|
||||||
|
discreteConditionals * (*conditional->asDiscrete());
|
||||||
|
}
|
||||||
|
}
|
||||||
const DecisionTreeFactor::shared_ptr prunedDecisionTree =
|
const DecisionTreeFactor::shared_ptr prunedDecisionTree =
|
||||||
std::make_shared<DecisionTreeFactor>(
|
std::make_shared<DecisionTreeFactor>(
|
||||||
discreteConditionals->prune(maxNrLeaves));
|
discreteConditionals.prune(maxNrLeaves));
|
||||||
|
|
||||||
EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,
|
EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,
|
||||||
prunedDecisionTree->nrLeaves());
|
prunedDecisionTree->nrLeaves());
|
||||||
|
|
||||||
auto original_discrete_conditionals = *(posterior->at(4)->asDiscrete());
|
// regression
|
||||||
|
DiscreteKeys dkeys{{M(0), 2}, {M(1), 2}, {M(2), 2}};
|
||||||
|
DecisionTreeFactor::ADT potentials(
|
||||||
|
dkeys, std::vector<double>{0, 0, 0, 0.505145423, 0, 1, 0, 0.494854577});
|
||||||
|
DiscreteConditional expected_discrete_conditionals(1, dkeys, potentials);
|
||||||
|
|
||||||
// Prune!
|
// Prune!
|
||||||
posterior->prune(maxNrLeaves);
|
posterior->prune(maxNrLeaves);
|
||||||
|
|
||||||
// Functor to verify values against the original_discrete_conditionals
|
// Functor to verify values against the expected_discrete_conditionals
|
||||||
auto checker = [&](const Assignment<Key>& assignment,
|
auto checker = [&](const Assignment<Key>& assignment,
|
||||||
double probability) -> double {
|
double probability) -> double {
|
||||||
// typecast so we can use this to get probability value
|
// typecast so we can use this to get probability value
|
||||||
|
@ -304,7 +316,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
||||||
if (prunedDecisionTree->operator()(choices) == 0) {
|
if (prunedDecisionTree->operator()(choices) == 0) {
|
||||||
EXPECT_DOUBLES_EQUAL(0.0, probability, 1e-9);
|
EXPECT_DOUBLES_EQUAL(0.0, probability, 1e-9);
|
||||||
} else {
|
} else {
|
||||||
EXPECT_DOUBLES_EQUAL(original_discrete_conditionals(choices), probability,
|
EXPECT_DOUBLES_EQUAL(expected_discrete_conditionals(choices), probability,
|
||||||
1e-9);
|
1e-9);
|
||||||
}
|
}
|
||||||
return 0.0;
|
return 0.0;
|
||||||
|
|
Loading…
Reference in New Issue