Document methods, refactor pruning a tiny bit.

release/4.3a0
Frank Dellaert 2023-01-05 19:52:23 -08:00
parent a3b177c604
commit 408c14b837
3 changed files with 39 additions and 25 deletions

View File

@ -141,8 +141,8 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
/* ************************************************************************* */
void HybridBayesNet::updateDiscreteConditionals(
const DecisionTreeFactor::shared_ptr &prunedDecisionTree) {
KeyVector prunedTreeKeys = prunedDecisionTree->keys();
const DecisionTreeFactor &prunedDecisionTree) {
KeyVector prunedTreeKeys = prunedDecisionTree.keys();
// Loop with index since we need it later.
for (size_t i = 0; i < this->size(); i++) {
@ -154,7 +154,7 @@ void HybridBayesNet::updateDiscreteConditionals(
auto discreteTree =
boost::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete);
DecisionTreeFactor::ADT prunedDiscreteTree =
discreteTree->apply(prunerFunc(*prunedDecisionTree, *conditional));
discreteTree->apply(prunerFunc(prunedDecisionTree, *conditional));
// Create the new (hybrid) conditional
KeyVector frontals(discrete->frontals().begin(),
@ -173,9 +173,7 @@ void HybridBayesNet::updateDiscreteConditionals(
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
// Get the decision tree of only the discrete keys
auto discreteConditionals = this->discreteConditionals();
const DecisionTreeFactor::shared_ptr decisionTree =
boost::make_shared<DecisionTreeFactor>(
discreteConditionals->prune(maxNrLeaves));
const auto decisionTree = discreteConditionals->prune(maxNrLeaves);
this->updateDiscreteConditionals(decisionTree);
@ -194,7 +192,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
if (auto gm = conditional->asMixture()) {
// Make a copy of the Gaussian mixture and prune it!
auto prunedGaussianMixture = boost::make_shared<GaussianMixture>(*gm);
prunedGaussianMixture->prune(*decisionTree); // imperative :-(
prunedGaussianMixture->prune(decisionTree); // imperative :-(
// Type-erase and add to the pruned Bayes Net fragment.
prunedBayesNetFragment.push_back(prunedGaussianMixture);

View File

@ -51,33 +51,51 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/// @{
/// GTSAM-style printing
void print(
const std::string &s = "",
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
void print(const std::string &s = "", const KeyFormatter &formatter =
DefaultKeyFormatter) const override;
/// GTSAM-style equals
bool equals(const This& fg, double tol = 1e-9) const;
bool equals(const This &fg, double tol = 1e-9) const;
/// @}
/// @name Standard Interface
/// @{
/// Add HybridConditional to Bayes Net
using Base::emplace_shared;
/**
* @brief Add a hybrid conditional using a shared_ptr.
*
* This is the "native" push back, as this class stores hybrid conditionals.
*/
void push_back(boost::shared_ptr<HybridConditional> conditional) {
factors_.push_back(conditional);
}
/// Add a conditional directly using a pointer.
/**
* Preferred: add a conditional directly using a pointer.
*
* Examples:
* hbn.emplace_back(new GaussianMixture(...)));
* hbn.emplace_back(new GaussianConditional(...)));
* hbn.emplace_back(new DiscreteConditional(...)));
*/
template <class Conditional>
void emplace_back(Conditional *conditional) {
factors_.push_back(boost::make_shared<HybridConditional>(
boost::shared_ptr<Conditional>(conditional)));
}
/// Add a conditional directly using a shared_ptr.
void push_back(boost::shared_ptr<HybridConditional> conditional) {
factors_.push_back(conditional);
}
/// Add a conditional directly using implicit conversion.
/**
* Add a conditional using a shared_ptr, using implicit conversion to
* a HybridConditional.
*
* This is useful when you create a conditional shared pointer as you need it
* somewhere else.
*
* Example:
* auto shared_ptr_to_a_conditional =
* boost::make_shared<GaussianMixture>(...);
* hbn.push_back(shared_ptr_to_a_conditional);
*/
void push_back(HybridConditional &&conditional) {
factors_.push_back(
boost::make_shared<HybridConditional>(std::move(conditional)));
@ -214,8 +232,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
*
* @param prunedDecisionTree
*/
void updateDiscreteConditionals(
const DecisionTreeFactor::shared_ptr &prunedDecisionTree);
void updateDiscreteConditionals(const DecisionTreeFactor &prunedDecisionTree);
/** Serialization function */
friend class boost::serialization::access;

View File

@ -93,8 +93,7 @@ TEST(HybridBayesNet, evaluateHybrid) {
// Create hybrid Bayes net.
HybridBayesNet bayesNet;
bayesNet.push_back(GaussianConditional::sharedMeanAndStddev(
X(0), 2 * I_1x1, X(1), Vector1(-4.0), 5.0));
bayesNet.push_back(continuousConditional);
bayesNet.emplace_back(
new GaussianMixture({X(1)}, {}, {Asia}, {conditional0, conditional1}));
bayesNet.emplace_back(new DiscreteConditional(Asia, "99/1"));