diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 5e0d185e8..de940ec69 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -29,8 +29,8 @@ namespace gtsam { DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const { AlgebraicDecisionTree decisionTree; - // The canonical decision tree factor which will get the discrete conditionals - // added to it. + // The canonical decision tree factor which will get + // the discrete conditionals added to it. DecisionTreeFactor dtFactor; for (size_t i = 0; i < this->size(); i++) { @@ -176,35 +176,35 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) { /* ************************************************************************* */ GaussianMixture::shared_ptr HybridBayesNet::atMixture(size_t i) const { - return factors_.at(i)->asMixture(); + return at(i)->asMixture(); } /* ************************************************************************* */ GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const { - return factors_.at(i)->asGaussian(); + return at(i)->asGaussian(); } /* ************************************************************************* */ DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const { - return factors_.at(i)->asDiscreteConditional(); + return at(i)->asDiscreteConditional(); } /* ************************************************************************* */ GaussianBayesNet HybridBayesNet::choose( const DiscreteValues &assignment) const { GaussianBayesNet gbn; - for (size_t idx = 0; idx < size(); idx++) { - if (factors_.at(idx)->isHybrid()) { - // If factor is hybrid, select based on assignment. - GaussianMixture gm = *this->atMixture(idx); + for (const sharedConditional &conditional : *this) { + if (conditional->isHybrid()) { + // If conditional is hybrid, select based on assignment. + GaussianMixture gm = *conditional->asMixture(); gbn.push_back(gm(assignment)); - } else if (factors_.at(idx)->isContinuous()) { + } else if (conditional->isContinuous()) { // If continuous only, add gaussian conditional. - gbn.push_back((this->atGaussian(idx))); + gbn.push_back((conditional->asGaussian())); - } else if (factors_.at(idx)->isDiscrete()) { - // If factor at `idx` is discrete-only, we simply continue. + } else if (conditional->isDiscrete()) { + // If conditional is discrete-only, we simply continue. continue; } } @@ -216,7 +216,7 @@ GaussianBayesNet HybridBayesNet::choose( HybridValues HybridBayesNet::optimize() const { // Solve for the MPE DiscreteBayesNet discrete_bn; - for (auto &conditional : factors_) { + for (auto &conditional : *this) { if (conditional->isDiscrete()) { discrete_bn.push_back(conditional->asDiscreteConditional()); } @@ -236,12 +236,13 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const { } /* ************************************************************************* */ -HybridValues HybridBayesNet::sample(VectorValues given, std::mt19937_64 *rng) const { +HybridValues HybridBayesNet::sample(VectorValues given, + std::mt19937_64 *rng) const { DiscreteBayesNet dbn; - for (size_t idx = 0; idx < size(); idx++) { - if (factors_.at(idx)->isDiscrete()) { - // If factor at `idx` is discrete-only, we add to the discrete bayes net. - dbn.push_back(this->atDiscrete(idx)); + for (const sharedConditional &conditional : *this) { + if (conditional->isDiscrete()) { + // If conditional is discrete-only, we add to the discrete bayes net. + dbn.push_back(conditional->asDiscrete()); } } // Sample a discrete assignment. @@ -279,34 +280,28 @@ double HybridBayesNet::error(const VectorValues &continuousValues, /* ************************************************************************* */ AlgebraicDecisionTree HybridBayesNet::error( const VectorValues &continuousValues) const { - AlgebraicDecisionTree error_tree; + AlgebraicDecisionTree error_tree(0.0); - // Iterate over each factor. - for (size_t idx = 0; idx < size(); idx++) { - AlgebraicDecisionTree conditional_error; + // Iterate over each conditional. + for (const sharedConditional &conditional : *this) { + if (conditional->isHybrid()) { + // If conditional is hybrid, select based on assignment and compute error. + GaussianMixture::shared_ptr gm = conditional->asMixture(); + AlgebraicDecisionTree conditional_error = + gm->error(continuousValues); - if (factors_.at(idx)->isHybrid()) { - // If factor is hybrid, select based on assignment and compute error. - GaussianMixture::shared_ptr gm = this->atMixture(idx); - conditional_error = gm->error(continuousValues); + error_tree = error_tree + conditional_error; - // Assign for the first index, add error for subsequent ones. - if (idx == 0) { - error_tree = conditional_error; - } else { - error_tree = error_tree + conditional_error; - } - - } else if (factors_.at(idx)->isContinuous()) { + } else if (conditional->isContinuous()) { // If continuous only, get the (double) error // and add it to the error_tree - double error = this->atGaussian(idx)->error(continuousValues); + double error = conditional->asGaussian()->error(continuousValues); // Add the computed error to every leaf of the error tree. error_tree = error_tree.apply( [error](double leaf_value) { return leaf_value + error; }); - } else if (factors_.at(idx)->isDiscrete()) { - // If factor at `idx` is discrete-only, we skip. + } else if (conditional->isDiscrete()) { + // Conditional is discrete-only, we skip. continue; } }