new updateDiscreteConditionals method for after we prune
parent 82f328b808
commit 949958dc6e
@@ -42,13 +42,100 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
 }
 
 /* ************************************************************************* */
-HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
+/**
+ * @brief Helper function to get the pruner functional.
+ *
+ * @param decisionTree The probability decision tree of only discrete keys.
+ * @param conditional The hybrid conditional whose discrete tree is pruned.
+ * @return std::function<double(const Assignment<Key> &, double)>
+ */
+std::function<double(const Assignment<Key> &, double)> prunerFunc(
+    const DecisionTreeFactor &decisionTree,
+    const HybridConditional &conditional) {
+  // Get the discrete keys as sets for the decision tree
+  // and the Gaussian mixture.
+  auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
+  auto conditionalKeySet = DiscreteKeysAsSet(conditional.discreteKeys());
+
+  auto pruner = [decisionTree, decisionTreeKeySet, conditionalKeySet](
+                    const Assignment<Key> &choices,
+                    double probability) -> double {
+    // Typecast so we can use this to get the probability value.
+    DiscreteValues values(choices);
+
+    // Case where the Gaussian mixture has the same
+    // discrete keys as the decision tree.
+    if (conditionalKeySet == decisionTreeKeySet) {
+      if (decisionTree(values) == 0) {
+        return 0.0;
+      } else {
+        return probability;
+      }
+    } else {
+      std::vector<DiscreteKey> set_diff;
+      std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
+                          conditionalKeySet.begin(), conditionalKeySet.end(),
+                          std::back_inserter(set_diff));
+
+      const std::vector<DiscreteValues> assignments =
+          DiscreteValues::CartesianProduct(set_diff);
+      for (const DiscreteValues &assignment : assignments) {
+        DiscreteValues augmented_values(values);
+        augmented_values.insert(assignment.begin(), assignment.end());
+
+        // If any one of the sub-branches is non-zero,
+        // we need this probability.
+        if (decisionTree(augmented_values) > 0.0) {
+          return probability;
+        }
+      }
+      // If we reach here, all the sub-branches are 0,
+      // so we prune.
+      return 0.0;
+    }
+  };
+  return pruner;
+}
+
+/* ************************************************************************* */
+void HybridBayesNet::updateDiscreteConditionals(
+    const DecisionTreeFactor::shared_ptr &prunedDecisionTree) {
+  KeyVector prunedTreeKeys = prunedDecisionTree->keys();
+
+  for (size_t i = 0; i < this->size(); i++) {
+    HybridConditional::shared_ptr conditional = this->at(i);
+    if (conditional->isDiscrete()) {
+      auto discrete = conditional->asDiscreteConditional();
+      KeyVector frontals(discrete->frontals().begin(),
+                         discrete->frontals().end());
+
+      // Apply prunerFunc to the underlying AlgebraicDecisionTree.
+      auto discreteTree =
+          boost::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete);
+      DecisionTreeFactor::ADT prunedDiscreteTree =
+          discreteTree->apply(prunerFunc(*prunedDecisionTree, *conditional));
+
+      // Create the new (hybrid) conditional.
+      auto prunedDiscrete = boost::make_shared<DiscreteLookupTable>(
+          frontals.size(), conditional->discreteKeys(), prunedDiscreteTree);
+      conditional = boost::make_shared<HybridConditional>(prunedDiscrete);
+
+      // Add it back to the Bayes net.
+      this->at(i) = conditional;
+    }
+  }
+}
+
+/* ************************************************************************* */
+HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
   // Get the decision tree of only the discrete keys
   auto discreteConditionals = this->discreteConditionals();
   const DecisionTreeFactor::shared_ptr decisionTree =
       boost::make_shared<DecisionTreeFactor>(
           discreteConditionals->prune(maxNrLeaves));
 
+  this->updateDiscreteConditionals(decisionTree);
+
   /* To Prune, we visitWith every leaf in the GaussianMixture.
    * For each leaf, using the assignment we can check the discrete decision tree
    * for 0.0 probability, then just set the leaf to a nullptr.
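The interesting case in prunerFunc above is when the conditional only covers a subset of the discrete keys in the pruned tree: a leaf is kept whenever some completion of its assignment over the missing keys still has non-zero probability. Here is a small standalone sketch of that check; the keys, cardinalities, and table values are invented for illustration, and only the GTSAM discrete types used in the code above are assumed.

// Illustration only: the keys, cardinalities, and table values are made up.
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteValues.h>

#include <vector>

using namespace gtsam;

int main() {
  // Pruned tree over two binary keys m1 and m2 (values chosen arbitrarily).
  DiscreteKey m1(1, 2), m2(2, 2);
  DecisionTreeFactor prunedTree(m1 & m2, "0.0 0.5 0.5 0.0");

  // A conditional that only involves m1: does its m1 = 0 leaf survive?
  DiscreteValues partial;
  partial[m1.first] = 0;

  // Keys present in the tree but not in the conditional (here just m2),
  // mirroring the std::set_difference step in prunerFunc.
  std::vector<DiscreteKey> missing{m2};

  // Keep the leaf if any completion over the missing keys is non-zero.
  bool keep = false;
  for (const DiscreteValues &completion :
       DiscreteValues::CartesianProduct(missing)) {
    DiscreteValues augmented(partial);
    augmented.insert(completion.begin(), completion.end());
    if (prunedTree(augmented) > 0.0) {
      keep = true;
      break;
    }
  }
  // One of the two m1 = 0 entries is 0.5, so the leaf is kept and its
  // probability is passed through; if both were 0.0, prunerFunc would
  // return 0.0 for this leaf instead.
  return keep ? 0 : 1;
}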
@@ -111,7 +111,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
    */
   VectorValues optimize(const DiscreteValues &assignment) const;
 
- protected:
   /**
    * @brief Get all the discrete conditionals as a decision tree factor.
    *
@@ -121,11 +120,19 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
 
  public:
   /// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
-  HybridBayesNet prune(size_t maxNrLeaves) const;
+  HybridBayesNet prune(size_t maxNrLeaves);
 
   /// @}
 
  private:
+  /**
+   * @brief Update the discrete conditionals with the pruned versions.
+   *
+   * @param prunedDecisionTree The pruned decision tree of the discrete keys.
+   */
+  void updateDiscreteConditionals(
+      const DecisionTreeFactor::shared_ptr &prunedDecisionTree);
+
   /** Serialization function */
   friend class boost::serialization::access;
   template <class ARCHIVE>
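With the header change, prune is no longer const: it rewrites the discrete conditionals stored in the Bayes net via updateDiscreteConditionals before building the pruned result. A minimal sketch of the call pattern, mirroring the new test below; the Switching fixture and the getHybridOrdering helper are taken from the existing GTSAM hybrid test code, so treat the exact setup as an assumption.

// Sketch under assumptions: Switching and getHybridOrdering come from the
// existing GTSAM hybrid test utilities (see the test added below).
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/inference/Ordering.h>

#include "Switching.h"  // test fixture providing linearizedFactorGraph

using namespace gtsam;

void pruneExample() {
  Switching s(4);

  // Eliminate the linearized hybrid factor graph into a HybridBayesNet.
  Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
  HybridBayesNet::shared_ptr hybridBayesNet =
      s.linearizedFactorGraph.eliminateSequential(hybridOrdering);

  // Keep at most 3 leaves in the joint discrete decision tree. Because prune
  // now calls updateDiscreteConditionals, the discrete conditionals stored in
  // *hybridBayesNet are rewritten to agree with the pruned tree.
  HybridBayesNet prunedBayesNet = hybridBayesNet->prune(3);
}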
@@ -201,6 +201,55 @@ TEST(HybridBayesNet, Prune) {
   EXPECT(assert_equal(delta.continuous(), pruned_delta.continuous()));
 }
 
+/* ****************************************************************************/
+// Test Bayes net updateDiscreteConditionals
+TEST(HybridBayesNet, UpdateDiscreteConditionals) {
+  Switching s(4);
+
+  Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
+  HybridBayesNet::shared_ptr hybridBayesNet =
+      s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
+
+  size_t maxNrLeaves = 3;
+  auto discreteConditionals = hybridBayesNet->discreteConditionals();
+  const DecisionTreeFactor::shared_ptr prunedDecisionTree =
+      boost::make_shared<DecisionTreeFactor>(
+          discreteConditionals->prune(maxNrLeaves));
+
+  EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,
+                     prunedDecisionTree->nrLeaves());
+
+  auto original_discrete_conditionals =
+      *(hybridBayesNet->at(4)->asDiscreteConditional());
+
+  // Prune!
+  hybridBayesNet->prune(maxNrLeaves);
+
+  // Functor to verify values against the original_discrete_conditionals.
+  auto checker = [&](const Assignment<Key>& assignment,
+                     double probability) -> double {
+    // Typecast so we can use this to get the probability value.
+    DiscreteValues choices(assignment);
+    if (prunedDecisionTree->operator()(choices) == 0) {
+      EXPECT_DOUBLES_EQUAL(0.0, probability, 1e-9);
+    } else {
+      EXPECT_DOUBLES_EQUAL(original_discrete_conditionals(choices), probability,
+                           1e-9);
+    }
+    return 0.0;
+  };
+
+  // Get the pruned discrete conditionals as an AlgebraicDecisionTree.
+  auto pruned_discrete_conditionals =
+      hybridBayesNet->at(4)->asDiscreteConditional();
+  auto discrete_conditional_tree =
+      boost::dynamic_pointer_cast<DecisionTreeFactor::ADT>(
+          pruned_discrete_conditionals);
+
+  // The checker functor verifies the values for us.
+  discrete_conditional_tree->apply(checker);
+}
+
 /* ****************************************************************************/
 // Test HybridBayesNet serialization.
 TEST(HybridBayesNet, Serialization) {