Document methods, refactor pruning a tiny bit.
parent
a3b177c604
commit
408c14b837
|
|
@ -141,8 +141,8 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
void HybridBayesNet::updateDiscreteConditionals(
|
void HybridBayesNet::updateDiscreteConditionals(
|
||||||
const DecisionTreeFactor::shared_ptr &prunedDecisionTree) {
|
const DecisionTreeFactor &prunedDecisionTree) {
|
||||||
KeyVector prunedTreeKeys = prunedDecisionTree->keys();
|
KeyVector prunedTreeKeys = prunedDecisionTree.keys();
|
||||||
|
|
||||||
// Loop with index since we need it later.
|
// Loop with index since we need it later.
|
||||||
for (size_t i = 0; i < this->size(); i++) {
|
for (size_t i = 0; i < this->size(); i++) {
|
||||||
|
|
@ -154,7 +154,7 @@ void HybridBayesNet::updateDiscreteConditionals(
|
||||||
auto discreteTree =
|
auto discreteTree =
|
||||||
boost::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete);
|
boost::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete);
|
||||||
DecisionTreeFactor::ADT prunedDiscreteTree =
|
DecisionTreeFactor::ADT prunedDiscreteTree =
|
||||||
discreteTree->apply(prunerFunc(*prunedDecisionTree, *conditional));
|
discreteTree->apply(prunerFunc(prunedDecisionTree, *conditional));
|
||||||
|
|
||||||
// Create the new (hybrid) conditional
|
// Create the new (hybrid) conditional
|
||||||
KeyVector frontals(discrete->frontals().begin(),
|
KeyVector frontals(discrete->frontals().begin(),
|
||||||
|
|
@ -173,9 +173,7 @@ void HybridBayesNet::updateDiscreteConditionals(
|
||||||
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
|
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
|
||||||
// Get the decision tree of only the discrete keys
|
// Get the decision tree of only the discrete keys
|
||||||
auto discreteConditionals = this->discreteConditionals();
|
auto discreteConditionals = this->discreteConditionals();
|
||||||
const DecisionTreeFactor::shared_ptr decisionTree =
|
const auto decisionTree = discreteConditionals->prune(maxNrLeaves);
|
||||||
boost::make_shared<DecisionTreeFactor>(
|
|
||||||
discreteConditionals->prune(maxNrLeaves));
|
|
||||||
|
|
||||||
this->updateDiscreteConditionals(decisionTree);
|
this->updateDiscreteConditionals(decisionTree);
|
||||||
|
|
||||||
|
|
@ -194,7 +192,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
|
||||||
if (auto gm = conditional->asMixture()) {
|
if (auto gm = conditional->asMixture()) {
|
||||||
// Make a copy of the Gaussian mixture and prune it!
|
// Make a copy of the Gaussian mixture and prune it!
|
||||||
auto prunedGaussianMixture = boost::make_shared<GaussianMixture>(*gm);
|
auto prunedGaussianMixture = boost::make_shared<GaussianMixture>(*gm);
|
||||||
prunedGaussianMixture->prune(*decisionTree); // imperative :-(
|
prunedGaussianMixture->prune(decisionTree); // imperative :-(
|
||||||
|
|
||||||
// Type-erase and add to the pruned Bayes Net fragment.
|
// Type-erase and add to the pruned Bayes Net fragment.
|
||||||
prunedBayesNetFragment.push_back(prunedGaussianMixture);
|
prunedBayesNetFragment.push_back(prunedGaussianMixture);
|
||||||
|
|
|
||||||
|
|
@ -51,33 +51,51 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// GTSAM-style printing
|
/// GTSAM-style printing
|
||||||
void print(
|
void print(const std::string &s = "", const KeyFormatter &formatter =
|
||||||
const std::string &s = "",
|
DefaultKeyFormatter) const override;
|
||||||
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
|
||||||
|
|
||||||
/// GTSAM-style equals
|
/// GTSAM-style equals
|
||||||
bool equals(const This& fg, double tol = 1e-9) const;
|
bool equals(const This &fg, double tol = 1e-9) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Standard Interface
|
/// @name Standard Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// Add HybridConditional to Bayes Net
|
/**
|
||||||
using Base::emplace_shared;
|
* @brief Add a hybrid conditional using a shared_ptr.
|
||||||
|
*
|
||||||
|
* This is the "native" push back, as this class stores hybrid conditionals.
|
||||||
|
*/
|
||||||
|
void push_back(boost::shared_ptr<HybridConditional> conditional) {
|
||||||
|
factors_.push_back(conditional);
|
||||||
|
}
|
||||||
|
|
||||||
/// Add a conditional directly using a pointer.
|
/**
|
||||||
|
* Preferred: add a conditional directly using a pointer.
|
||||||
|
*
|
||||||
|
* Examples:
|
||||||
|
* hbn.emplace_back(new GaussianMixture(...)));
|
||||||
|
* hbn.emplace_back(new GaussianConditional(...)));
|
||||||
|
* hbn.emplace_back(new DiscreteConditional(...)));
|
||||||
|
*/
|
||||||
template <class Conditional>
|
template <class Conditional>
|
||||||
void emplace_back(Conditional *conditional) {
|
void emplace_back(Conditional *conditional) {
|
||||||
factors_.push_back(boost::make_shared<HybridConditional>(
|
factors_.push_back(boost::make_shared<HybridConditional>(
|
||||||
boost::shared_ptr<Conditional>(conditional)));
|
boost::shared_ptr<Conditional>(conditional)));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add a conditional directly using a shared_ptr.
|
/**
|
||||||
void push_back(boost::shared_ptr<HybridConditional> conditional) {
|
* Add a conditional using a shared_ptr, using implicit conversion to
|
||||||
factors_.push_back(conditional);
|
* a HybridConditional.
|
||||||
}
|
*
|
||||||
|
* This is useful when you create a conditional shared pointer as you need it
|
||||||
/// Add a conditional directly using implicit conversion.
|
* somewhere else.
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
* auto shared_ptr_to_a_conditional =
|
||||||
|
* boost::make_shared<GaussianMixture>(...);
|
||||||
|
* hbn.push_back(shared_ptr_to_a_conditional);
|
||||||
|
*/
|
||||||
void push_back(HybridConditional &&conditional) {
|
void push_back(HybridConditional &&conditional) {
|
||||||
factors_.push_back(
|
factors_.push_back(
|
||||||
boost::make_shared<HybridConditional>(std::move(conditional)));
|
boost::make_shared<HybridConditional>(std::move(conditional)));
|
||||||
|
|
@ -214,8 +232,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
*
|
*
|
||||||
* @param prunedDecisionTree
|
* @param prunedDecisionTree
|
||||||
*/
|
*/
|
||||||
void updateDiscreteConditionals(
|
void updateDiscreteConditionals(const DecisionTreeFactor &prunedDecisionTree);
|
||||||
const DecisionTreeFactor::shared_ptr &prunedDecisionTree);
|
|
||||||
|
|
||||||
/** Serialization function */
|
/** Serialization function */
|
||||||
friend class boost::serialization::access;
|
friend class boost::serialization::access;
|
||||||
|
|
|
||||||
|
|
@ -93,8 +93,7 @@ TEST(HybridBayesNet, evaluateHybrid) {
|
||||||
|
|
||||||
// Create hybrid Bayes net.
|
// Create hybrid Bayes net.
|
||||||
HybridBayesNet bayesNet;
|
HybridBayesNet bayesNet;
|
||||||
bayesNet.push_back(GaussianConditional::sharedMeanAndStddev(
|
bayesNet.push_back(continuousConditional);
|
||||||
X(0), 2 * I_1x1, X(1), Vector1(-4.0), 5.0));
|
|
||||||
bayesNet.emplace_back(
|
bayesNet.emplace_back(
|
||||||
new GaussianMixture({X(1)}, {}, {Asia}, {conditional0, conditional1}));
|
new GaussianMixture({X(1)}, {}, {Asia}, {conditional0, conditional1}));
|
||||||
bayesNet.emplace_back(new DiscreteConditional(Asia, "99/1"));
|
bayesNet.emplace_back(new DiscreteConditional(Asia, "99/1"));
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue