remove `factors_` from Bayes net implementation

release/4.3a0
Varun Agrawal 2022-12-24 07:35:18 +05:30
parent 2483d7c421
commit fe394cc074
1 changed files with 33 additions and 38 deletions

View File

@ -29,8 +29,8 @@ namespace gtsam {
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const { DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
AlgebraicDecisionTree<Key> decisionTree; AlgebraicDecisionTree<Key> decisionTree;
// The canonical decision tree factor which will get the discrete conditionals // The canonical decision tree factor which will get
// added to it. // the discrete conditionals added to it.
DecisionTreeFactor dtFactor; DecisionTreeFactor dtFactor;
for (size_t i = 0; i < this->size(); i++) { for (size_t i = 0; i < this->size(); i++) {
@ -176,35 +176,35 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
/* ************************************************************************* */ /* ************************************************************************* */
GaussianMixture::shared_ptr HybridBayesNet::atMixture(size_t i) const { GaussianMixture::shared_ptr HybridBayesNet::atMixture(size_t i) const {
return factors_.at(i)->asMixture(); return at(i)->asMixture();
} }
/* ************************************************************************* */ /* ************************************************************************* */
GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const { GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
return factors_.at(i)->asGaussian(); return at(i)->asGaussian();
} }
/* ************************************************************************* */ /* ************************************************************************* */
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const { DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
return factors_.at(i)->asDiscreteConditional(); return at(i)->asDiscreteConditional();
} }
/* ************************************************************************* */ /* ************************************************************************* */
GaussianBayesNet HybridBayesNet::choose( GaussianBayesNet HybridBayesNet::choose(
const DiscreteValues &assignment) const { const DiscreteValues &assignment) const {
GaussianBayesNet gbn; GaussianBayesNet gbn;
for (size_t idx = 0; idx < size(); idx++) { for (const sharedConditional &conditional : *this) {
if (factors_.at(idx)->isHybrid()) { if (conditional->isHybrid()) {
// If factor is hybrid, select based on assignment. // If conditional is hybrid, select based on assignment.
GaussianMixture gm = *this->atMixture(idx); GaussianMixture gm = *conditional->asMixture();
gbn.push_back(gm(assignment)); gbn.push_back(gm(assignment));
} else if (factors_.at(idx)->isContinuous()) { } else if (conditional->isContinuous()) {
// If continuous only, add gaussian conditional. // If continuous only, add gaussian conditional.
gbn.push_back((this->atGaussian(idx))); gbn.push_back((conditional->asGaussian()));
} else if (factors_.at(idx)->isDiscrete()) { } else if (conditional->isDiscrete()) {
// If factor at `idx` is discrete-only, we simply continue. // If conditional is discrete-only, we simply continue.
continue; continue;
} }
} }
@ -216,7 +216,7 @@ GaussianBayesNet HybridBayesNet::choose(
HybridValues HybridBayesNet::optimize() const { HybridValues HybridBayesNet::optimize() const {
// Solve for the MPE // Solve for the MPE
DiscreteBayesNet discrete_bn; DiscreteBayesNet discrete_bn;
for (auto &conditional : factors_) { for (auto &conditional : *this) {
if (conditional->isDiscrete()) { if (conditional->isDiscrete()) {
discrete_bn.push_back(conditional->asDiscreteConditional()); discrete_bn.push_back(conditional->asDiscreteConditional());
} }
@ -236,12 +236,13 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
} }
/* ************************************************************************* */ /* ************************************************************************* */
HybridValues HybridBayesNet::sample(VectorValues given, std::mt19937_64 *rng) const { HybridValues HybridBayesNet::sample(VectorValues given,
std::mt19937_64 *rng) const {
DiscreteBayesNet dbn; DiscreteBayesNet dbn;
for (size_t idx = 0; idx < size(); idx++) { for (const sharedConditional &conditional : *this) {
if (factors_.at(idx)->isDiscrete()) { if (conditional->isDiscrete()) {
// If factor at `idx` is discrete-only, we add to the discrete bayes net. // If conditional is discrete-only, we add to the discrete bayes net.
dbn.push_back(this->atDiscrete(idx)); dbn.push_back(conditional->asDiscrete());
} }
} }
// Sample a discrete assignment. // Sample a discrete assignment.
@ -279,34 +280,28 @@ double HybridBayesNet::error(const VectorValues &continuousValues,
/* ************************************************************************* */ /* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::error( AlgebraicDecisionTree<Key> HybridBayesNet::error(
const VectorValues &continuousValues) const { const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree; AlgebraicDecisionTree<Key> error_tree(0.0);
// Iterate over each factor. // Iterate over each conditional.
for (size_t idx = 0; idx < size(); idx++) { for (const sharedConditional &conditional : *this) {
AlgebraicDecisionTree<Key> conditional_error; if (conditional->isHybrid()) {
// If conditional is hybrid, select based on assignment and compute error.
GaussianMixture::shared_ptr gm = conditional->asMixture();
AlgebraicDecisionTree<Key> conditional_error =
gm->error(continuousValues);
if (factors_.at(idx)->isHybrid()) {
// If factor is hybrid, select based on assignment and compute error.
GaussianMixture::shared_ptr gm = this->atMixture(idx);
conditional_error = gm->error(continuousValues);
// Assign for the first index, add error for subsequent ones.
if (idx == 0) {
error_tree = conditional_error;
} else {
error_tree = error_tree + conditional_error; error_tree = error_tree + conditional_error;
}
} else if (factors_.at(idx)->isContinuous()) { } else if (conditional->isContinuous()) {
// If continuous only, get the (double) error // If continuous only, get the (double) error
// and add it to the error_tree // and add it to the error_tree
double error = this->atGaussian(idx)->error(continuousValues); double error = conditional->asGaussian()->error(continuousValues);
// Add the computed error to every leaf of the error tree. // Add the computed error to every leaf of the error tree.
error_tree = error_tree.apply( error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; }); [error](double leaf_value) { return leaf_value + error; });
} else if (factors_.at(idx)->isDiscrete()) { } else if (conditional->isDiscrete()) {
// If factor at `idx` is discrete-only, we skip. // Conditional is discrete-only, we skip.
continue; continue;
} }
} }