new updateDiscreteConditionals method for after we prune

release/4.3a0
Varun Agrawal 2022-10-20 16:35:14 -04:00
parent 82f328b808
commit 949958dc6e
3 changed files with 146 additions and 3 deletions

View File

@ -42,13 +42,100 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
} }
/* ************************************************************************* */ /* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { /**
* @brief Helper function to get the pruner functional.
*
* @param decisionTree The probability decision tree of only discrete keys.
* @return std::function<GaussianConditional::shared_ptr(
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
*/
std::function<double(const Assignment<Key> &, double)> prunerFunc(
const DecisionTreeFactor &decisionTree,
const HybridConditional &conditional) {
// Get the discrete keys as sets for the decision tree
// and the gaussian mixture.
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
auto conditionalKeySet = DiscreteKeysAsSet(conditional.discreteKeys());
auto pruner = [decisionTree, decisionTreeKeySet, conditionalKeySet](
const Assignment<Key> &choices,
double probability) -> double {
// typecast so we can use this to get probability value
DiscreteValues values(choices);
// Case where the gaussian mixture has the same
// discrete keys as the decision tree.
if (conditionalKeySet == decisionTreeKeySet) {
if (decisionTree(values) == 0) {
return 0.0;
} else {
return probability;
}
} else {
std::vector<DiscreteKey> set_diff;
std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
conditionalKeySet.begin(), conditionalKeySet.end(),
std::back_inserter(set_diff));
const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(set_diff);
for (const DiscreteValues &assignment : assignments) {
DiscreteValues augmented_values(values);
augmented_values.insert(assignment.begin(), assignment.end());
// If any one of the sub-branches are non-zero,
// we need this probability.
if (decisionTree(augmented_values) > 0.0) {
return probability;
}
}
// If we are here, it means that all the sub-branches are 0,
// so we prune.
return 0.0;
}
};
return pruner;
}
/* ************************************************************************* */
void HybridBayesNet::updateDiscreteConditionals(
const DecisionTreeFactor::shared_ptr &prunedDecisionTree) {
KeyVector prunedTreeKeys = prunedDecisionTree->keys();
for (size_t i = 0; i < this->size(); i++) {
HybridConditional::shared_ptr conditional = this->at(i);
if (conditional->isDiscrete()) {
// std::cout << demangle(typeid(conditional).name()) << std::endl;
auto discrete = conditional->asDiscreteConditional();
KeyVector frontals(discrete->frontals().begin(),
discrete->frontals().end());
// Apply prunerFunc to the underlying AlgebraicDecisionTree
auto discreteTree =
boost::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete);
DecisionTreeFactor::ADT prunedDiscreteTree =
discreteTree->apply(prunerFunc(*prunedDecisionTree, *conditional));
// Create the new (hybrid) conditional
auto prunedDiscrete = boost::make_shared<DiscreteLookupTable>(
frontals.size(), conditional->discreteKeys(), prunedDiscreteTree);
conditional = boost::make_shared<HybridConditional>(prunedDiscrete);
// Add it back to the BayesNet
this->at(i) = conditional;
}
}
}
/* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
// Get the decision tree of only the discrete keys // Get the decision tree of only the discrete keys
auto discreteConditionals = this->discreteConditionals(); auto discreteConditionals = this->discreteConditionals();
const DecisionTreeFactor::shared_ptr decisionTree = const DecisionTreeFactor::shared_ptr decisionTree =
boost::make_shared<DecisionTreeFactor>( boost::make_shared<DecisionTreeFactor>(
discreteConditionals->prune(maxNrLeaves)); discreteConditionals->prune(maxNrLeaves));
this->updateDiscreteConditionals(decisionTree);
/* To Prune, we visitWith every leaf in the GaussianMixture. /* To Prune, we visitWith every leaf in the GaussianMixture.
* For each leaf, using the assignment we can check the discrete decision tree * For each leaf, using the assignment we can check the discrete decision tree
* for 0.0 probability, then just set the leaf to a nullptr. * for 0.0 probability, then just set the leaf to a nullptr.

View File

@ -111,7 +111,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
*/ */
VectorValues optimize(const DiscreteValues &assignment) const; VectorValues optimize(const DiscreteValues &assignment) const;
protected:
/** /**
* @brief Get all the discrete conditionals as a decision tree factor. * @brief Get all the discrete conditionals as a decision tree factor.
* *
@ -121,11 +120,19 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
public: public:
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves. /// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
HybridBayesNet prune(size_t maxNrLeaves) const; HybridBayesNet prune(size_t maxNrLeaves);
/// @} /// @}
private: private:
/**
* @brief Update the discrete conditionals with the pruned versions.
*
* @param prunedDecisionTree
*/
void updateDiscreteConditionals(
const DecisionTreeFactor::shared_ptr &prunedDecisionTree);
/** Serialization function */ /** Serialization function */
friend class boost::serialization::access; friend class boost::serialization::access;
template <class ARCHIVE> template <class ARCHIVE>

View File

@ -201,6 +201,55 @@ TEST(HybridBayesNet, Prune) {
EXPECT(assert_equal(delta.continuous(), pruned_delta.continuous())); EXPECT(assert_equal(delta.continuous(), pruned_delta.continuous()));
} }
/* ****************************************************************************/
// Test bayes net updateDiscreteConditionals
TEST(HybridBayesNet, UpdateDiscreteConditionals) {
Switching s(4);
Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet::shared_ptr hybridBayesNet =
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
size_t maxNrLeaves = 3;
auto discreteConditionals = hybridBayesNet->discreteConditionals();
const DecisionTreeFactor::shared_ptr prunedDecisionTree =
boost::make_shared<DecisionTreeFactor>(
discreteConditionals->prune(maxNrLeaves));
EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,
prunedDecisionTree->nrLeaves());
auto original_discrete_conditionals =
*(hybridBayesNet->at(4)->asDiscreteConditional());
// Prune!
hybridBayesNet->prune(maxNrLeaves);
// Functor to verify values against the original_discrete_conditionals
auto checker = [&](const Assignment<Key>& assignment,
double probability) -> double {
// typecast so we can use this to get probability value
DiscreteValues choices(assignment);
if (prunedDecisionTree->operator()(choices) == 0) {
EXPECT_DOUBLES_EQUAL(0.0, probability, 1e-9);
} else {
EXPECT_DOUBLES_EQUAL(original_discrete_conditionals(choices), probability,
1e-9);
}
return 0.0;
};
// Get the pruned discrete conditionals as an AlgebraicDecisionTree
auto pruned_discrete_conditionals =
hybridBayesNet->at(4)->asDiscreteConditional();
auto discrete_conditional_tree =
boost::dynamic_pointer_cast<DecisionTreeFactor::ADT>(
pruned_discrete_conditionals);
// The checker functor verifies the values for us.
discrete_conditional_tree->apply(checker);
}
/* ****************************************************************************/ /* ****************************************************************************/
// Test HybridBayesNet serialization. // Test HybridBayesNet serialization.
TEST(HybridBayesNet, Serialization) { TEST(HybridBayesNet, Serialization) {