Document methods, refactor pruning a tiny bit.
parent
a3b177c604
commit
408c14b837
|
|
@ -141,8 +141,8 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
|||
|
||||
/* ************************************************************************* */
|
||||
void HybridBayesNet::updateDiscreteConditionals(
|
||||
const DecisionTreeFactor::shared_ptr &prunedDecisionTree) {
|
||||
KeyVector prunedTreeKeys = prunedDecisionTree->keys();
|
||||
const DecisionTreeFactor &prunedDecisionTree) {
|
||||
KeyVector prunedTreeKeys = prunedDecisionTree.keys();
|
||||
|
||||
// Loop with index since we need it later.
|
||||
for (size_t i = 0; i < this->size(); i++) {
|
||||
|
|
@ -154,7 +154,7 @@ void HybridBayesNet::updateDiscreteConditionals(
|
|||
auto discreteTree =
|
||||
boost::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete);
|
||||
DecisionTreeFactor::ADT prunedDiscreteTree =
|
||||
discreteTree->apply(prunerFunc(*prunedDecisionTree, *conditional));
|
||||
discreteTree->apply(prunerFunc(prunedDecisionTree, *conditional));
|
||||
|
||||
// Create the new (hybrid) conditional
|
||||
KeyVector frontals(discrete->frontals().begin(),
|
||||
|
|
@ -173,9 +173,7 @@ void HybridBayesNet::updateDiscreteConditionals(
|
|||
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
|
||||
// Get the decision tree of only the discrete keys
|
||||
auto discreteConditionals = this->discreteConditionals();
|
||||
const DecisionTreeFactor::shared_ptr decisionTree =
|
||||
boost::make_shared<DecisionTreeFactor>(
|
||||
discreteConditionals->prune(maxNrLeaves));
|
||||
const auto decisionTree = discreteConditionals->prune(maxNrLeaves);
|
||||
|
||||
this->updateDiscreteConditionals(decisionTree);
|
||||
|
||||
|
|
@ -194,7 +192,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
|
|||
if (auto gm = conditional->asMixture()) {
|
||||
// Make a copy of the Gaussian mixture and prune it!
|
||||
auto prunedGaussianMixture = boost::make_shared<GaussianMixture>(*gm);
|
||||
prunedGaussianMixture->prune(*decisionTree); // imperative :-(
|
||||
prunedGaussianMixture->prune(decisionTree); // imperative :-(
|
||||
|
||||
// Type-erase and add to the pruned Bayes Net fragment.
|
||||
prunedBayesNetFragment.push_back(prunedGaussianMixture);
|
||||
|
|
|
|||
|
|
@ -51,33 +51,51 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
/// @{
|
||||
|
||||
/// GTSAM-style printing
|
||||
void print(
|
||||
const std::string &s = "",
|
||||
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
||||
void print(const std::string &s = "", const KeyFormatter &formatter =
|
||||
DefaultKeyFormatter) const override;
|
||||
|
||||
/// GTSAM-style equals
|
||||
bool equals(const This& fg, double tol = 1e-9) const;
|
||||
|
||||
bool equals(const This &fg, double tol = 1e-9) const;
|
||||
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/// Add HybridConditional to Bayes Net
|
||||
using Base::emplace_shared;
|
||||
/**
|
||||
* @brief Add a hybrid conditional using a shared_ptr.
|
||||
*
|
||||
* This is the "native" push back, as this class stores hybrid conditionals.
|
||||
*/
|
||||
void push_back(boost::shared_ptr<HybridConditional> conditional) {
|
||||
factors_.push_back(conditional);
|
||||
}
|
||||
|
||||
/// Add a conditional directly using a pointer.
|
||||
/**
|
||||
* Preferred: add a conditional directly using a pointer.
|
||||
*
|
||||
* Examples:
|
||||
* hbn.emplace_back(new GaussianMixture(...)));
|
||||
* hbn.emplace_back(new GaussianConditional(...)));
|
||||
* hbn.emplace_back(new DiscreteConditional(...)));
|
||||
*/
|
||||
template <class Conditional>
|
||||
void emplace_back(Conditional *conditional) {
|
||||
factors_.push_back(boost::make_shared<HybridConditional>(
|
||||
boost::shared_ptr<Conditional>(conditional)));
|
||||
}
|
||||
|
||||
/// Add a conditional directly using a shared_ptr.
|
||||
void push_back(boost::shared_ptr<HybridConditional> conditional) {
|
||||
factors_.push_back(conditional);
|
||||
}
|
||||
|
||||
/// Add a conditional directly using implicit conversion.
|
||||
/**
|
||||
* Add a conditional using a shared_ptr, using implicit conversion to
|
||||
* a HybridConditional.
|
||||
*
|
||||
* This is useful when you create a conditional shared pointer as you need it
|
||||
* somewhere else.
|
||||
*
|
||||
* Example:
|
||||
* auto shared_ptr_to_a_conditional =
|
||||
* boost::make_shared<GaussianMixture>(...);
|
||||
* hbn.push_back(shared_ptr_to_a_conditional);
|
||||
*/
|
||||
void push_back(HybridConditional &&conditional) {
|
||||
factors_.push_back(
|
||||
boost::make_shared<HybridConditional>(std::move(conditional)));
|
||||
|
|
@ -214,8 +232,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
*
|
||||
* @param prunedDecisionTree
|
||||
*/
|
||||
void updateDiscreteConditionals(
|
||||
const DecisionTreeFactor::shared_ptr &prunedDecisionTree);
|
||||
void updateDiscreteConditionals(const DecisionTreeFactor &prunedDecisionTree);
|
||||
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
|
|
|
|||
|
|
@ -93,8 +93,7 @@ TEST(HybridBayesNet, evaluateHybrid) {
|
|||
|
||||
// Create hybrid Bayes net.
|
||||
HybridBayesNet bayesNet;
|
||||
bayesNet.push_back(GaussianConditional::sharedMeanAndStddev(
|
||||
X(0), 2 * I_1x1, X(1), Vector1(-4.0), 5.0));
|
||||
bayesNet.push_back(continuousConditional);
|
||||
bayesNet.emplace_back(
|
||||
new GaussianMixture({X(1)}, {}, {Asia}, {conditional0, conditional1}));
|
||||
bayesNet.emplace_back(new DiscreteConditional(Asia, "99/1"));
|
||||
|
|
|
|||
Loading…
Reference in New Issue