new updateDiscreteConditionals method for after we prune
parent
82f328b808
commit
949958dc6e
|
@ -42,13 +42,100 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
|
||||
/**
|
||||
* @brief Helper function to get the pruner functional.
|
||||
*
|
||||
* @param decisionTree The probability decision tree of only discrete keys.
|
||||
* @return std::function<GaussianConditional::shared_ptr(
|
||||
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
||||
*/
|
||||
std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
||||
const DecisionTreeFactor &decisionTree,
|
||||
const HybridConditional &conditional) {
|
||||
// Get the discrete keys as sets for the decision tree
|
||||
// and the gaussian mixture.
|
||||
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
|
||||
auto conditionalKeySet = DiscreteKeysAsSet(conditional.discreteKeys());
|
||||
|
||||
auto pruner = [decisionTree, decisionTreeKeySet, conditionalKeySet](
|
||||
const Assignment<Key> &choices,
|
||||
double probability) -> double {
|
||||
// typecast so we can use this to get probability value
|
||||
DiscreteValues values(choices);
|
||||
// Case where the gaussian mixture has the same
|
||||
// discrete keys as the decision tree.
|
||||
if (conditionalKeySet == decisionTreeKeySet) {
|
||||
if (decisionTree(values) == 0) {
|
||||
return 0.0;
|
||||
} else {
|
||||
return probability;
|
||||
}
|
||||
} else {
|
||||
std::vector<DiscreteKey> set_diff;
|
||||
std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
|
||||
conditionalKeySet.begin(), conditionalKeySet.end(),
|
||||
std::back_inserter(set_diff));
|
||||
|
||||
const std::vector<DiscreteValues> assignments =
|
||||
DiscreteValues::CartesianProduct(set_diff);
|
||||
for (const DiscreteValues &assignment : assignments) {
|
||||
DiscreteValues augmented_values(values);
|
||||
augmented_values.insert(assignment.begin(), assignment.end());
|
||||
|
||||
// If any one of the sub-branches are non-zero,
|
||||
// we need this probability.
|
||||
if (decisionTree(augmented_values) > 0.0) {
|
||||
return probability;
|
||||
}
|
||||
}
|
||||
// If we are here, it means that all the sub-branches are 0,
|
||||
// so we prune.
|
||||
return 0.0;
|
||||
}
|
||||
};
|
||||
return pruner;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void HybridBayesNet::updateDiscreteConditionals(
|
||||
const DecisionTreeFactor::shared_ptr &prunedDecisionTree) {
|
||||
KeyVector prunedTreeKeys = prunedDecisionTree->keys();
|
||||
|
||||
for (size_t i = 0; i < this->size(); i++) {
|
||||
HybridConditional::shared_ptr conditional = this->at(i);
|
||||
if (conditional->isDiscrete()) {
|
||||
// std::cout << demangle(typeid(conditional).name()) << std::endl;
|
||||
auto discrete = conditional->asDiscreteConditional();
|
||||
KeyVector frontals(discrete->frontals().begin(),
|
||||
discrete->frontals().end());
|
||||
|
||||
// Apply prunerFunc to the underlying AlgebraicDecisionTree
|
||||
auto discreteTree =
|
||||
boost::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete);
|
||||
DecisionTreeFactor::ADT prunedDiscreteTree =
|
||||
discreteTree->apply(prunerFunc(*prunedDecisionTree, *conditional));
|
||||
|
||||
// Create the new (hybrid) conditional
|
||||
auto prunedDiscrete = boost::make_shared<DiscreteLookupTable>(
|
||||
frontals.size(), conditional->discreteKeys(), prunedDiscreteTree);
|
||||
conditional = boost::make_shared<HybridConditional>(prunedDiscrete);
|
||||
|
||||
// Add it back to the BayesNet
|
||||
this->at(i) = conditional;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
|
||||
// Get the decision tree of only the discrete keys
|
||||
auto discreteConditionals = this->discreteConditionals();
|
||||
const DecisionTreeFactor::shared_ptr decisionTree =
|
||||
boost::make_shared<DecisionTreeFactor>(
|
||||
discreteConditionals->prune(maxNrLeaves));
|
||||
|
||||
this->updateDiscreteConditionals(decisionTree);
|
||||
|
||||
/* To Prune, we visitWith every leaf in the GaussianMixture.
|
||||
* For each leaf, using the assignment we can check the discrete decision tree
|
||||
* for 0.0 probability, then just set the leaf to a nullptr.
|
||||
|
|
|
@ -111,7 +111,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
*/
|
||||
VectorValues optimize(const DiscreteValues &assignment) const;
|
||||
|
||||
protected:
|
||||
/**
|
||||
* @brief Get all the discrete conditionals as a decision tree factor.
|
||||
*
|
||||
|
@ -121,11 +120,19 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
|
||||
public:
|
||||
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
|
||||
HybridBayesNet prune(size_t maxNrLeaves) const;
|
||||
HybridBayesNet prune(size_t maxNrLeaves);
|
||||
|
||||
/// @}
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief Update the discrete conditionals with the pruned versions.
|
||||
*
|
||||
* @param prunedDecisionTree
|
||||
*/
|
||||
void updateDiscreteConditionals(
|
||||
const DecisionTreeFactor::shared_ptr &prunedDecisionTree);
|
||||
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
|
|
|
@ -201,6 +201,55 @@ TEST(HybridBayesNet, Prune) {
|
|||
EXPECT(assert_equal(delta.continuous(), pruned_delta.continuous()));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test bayes net updateDiscreteConditionals
|
||||
TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
||||
Switching s(4);
|
||||
|
||||
Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
|
||||
HybridBayesNet::shared_ptr hybridBayesNet =
|
||||
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
|
||||
|
||||
size_t maxNrLeaves = 3;
|
||||
auto discreteConditionals = hybridBayesNet->discreteConditionals();
|
||||
const DecisionTreeFactor::shared_ptr prunedDecisionTree =
|
||||
boost::make_shared<DecisionTreeFactor>(
|
||||
discreteConditionals->prune(maxNrLeaves));
|
||||
|
||||
EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,
|
||||
prunedDecisionTree->nrLeaves());
|
||||
|
||||
auto original_discrete_conditionals =
|
||||
*(hybridBayesNet->at(4)->asDiscreteConditional());
|
||||
|
||||
// Prune!
|
||||
hybridBayesNet->prune(maxNrLeaves);
|
||||
|
||||
// Functor to verify values against the original_discrete_conditionals
|
||||
auto checker = [&](const Assignment<Key>& assignment,
|
||||
double probability) -> double {
|
||||
// typecast so we can use this to get probability value
|
||||
DiscreteValues choices(assignment);
|
||||
if (prunedDecisionTree->operator()(choices) == 0) {
|
||||
EXPECT_DOUBLES_EQUAL(0.0, probability, 1e-9);
|
||||
} else {
|
||||
EXPECT_DOUBLES_EQUAL(original_discrete_conditionals(choices), probability,
|
||||
1e-9);
|
||||
}
|
||||
return 0.0;
|
||||
};
|
||||
|
||||
// Get the pruned discrete conditionals as an AlgebraicDecisionTree
|
||||
auto pruned_discrete_conditionals =
|
||||
hybridBayesNet->at(4)->asDiscreteConditional();
|
||||
auto discrete_conditional_tree =
|
||||
boost::dynamic_pointer_cast<DecisionTreeFactor::ADT>(
|
||||
pruned_discrete_conditionals);
|
||||
|
||||
// The checker functor verifies the values for us.
|
||||
discrete_conditional_tree->apply(checker);
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test HybridBayesNet serialization.
|
||||
TEST(HybridBayesNet, Serialization) {
|
||||
|
|
Loading…
Reference in New Issue