remove `factors_` from Bayes net implementation

release/4.3a0
Varun Agrawal 2022-12-24 07:35:18 +05:30
parent 2483d7c421
commit fe394cc074
1 changed files with 33 additions and 38 deletions

View File

@ -29,8 +29,8 @@ namespace gtsam {
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
AlgebraicDecisionTree<Key> decisionTree;
// The canonical decision tree factor which will get the discrete conditionals
// added to it.
// The canonical decision tree factor which will get
// the discrete conditionals added to it.
DecisionTreeFactor dtFactor;
for (size_t i = 0; i < this->size(); i++) {
@ -176,35 +176,35 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
/* ************************************************************************* */
GaussianMixture::shared_ptr HybridBayesNet::atMixture(size_t i) const {
return factors_.at(i)->asMixture();
return at(i)->asMixture();
}
/* ************************************************************************* */
GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
return factors_.at(i)->asGaussian();
return at(i)->asGaussian();
}
/* ************************************************************************* */
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
return factors_.at(i)->asDiscreteConditional();
return at(i)->asDiscreteConditional();
}
/* ************************************************************************* */
GaussianBayesNet HybridBayesNet::choose(
const DiscreteValues &assignment) const {
GaussianBayesNet gbn;
for (size_t idx = 0; idx < size(); idx++) {
if (factors_.at(idx)->isHybrid()) {
// If factor is hybrid, select based on assignment.
GaussianMixture gm = *this->atMixture(idx);
for (const sharedConditional &conditional : *this) {
if (conditional->isHybrid()) {
// If conditional is hybrid, select based on assignment.
GaussianMixture gm = *conditional->asMixture();
gbn.push_back(gm(assignment));
} else if (factors_.at(idx)->isContinuous()) {
} else if (conditional->isContinuous()) {
// If continuous only, add gaussian conditional.
gbn.push_back((this->atGaussian(idx)));
gbn.push_back((conditional->asGaussian()));
} else if (factors_.at(idx)->isDiscrete()) {
// If factor at `idx` is discrete-only, we simply continue.
} else if (conditional->isDiscrete()) {
// If conditional is discrete-only, we simply continue.
continue;
}
}
@ -216,7 +216,7 @@ GaussianBayesNet HybridBayesNet::choose(
HybridValues HybridBayesNet::optimize() const {
// Solve for the MPE
DiscreteBayesNet discrete_bn;
for (auto &conditional : factors_) {
for (auto &conditional : *this) {
if (conditional->isDiscrete()) {
discrete_bn.push_back(conditional->asDiscreteConditional());
}
@ -236,12 +236,13 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
}
/* ************************************************************************* */
HybridValues HybridBayesNet::sample(VectorValues given, std::mt19937_64 *rng) const {
HybridValues HybridBayesNet::sample(VectorValues given,
std::mt19937_64 *rng) const {
DiscreteBayesNet dbn;
for (size_t idx = 0; idx < size(); idx++) {
if (factors_.at(idx)->isDiscrete()) {
// If factor at `idx` is discrete-only, we add to the discrete bayes net.
dbn.push_back(this->atDiscrete(idx));
for (const sharedConditional &conditional : *this) {
if (conditional->isDiscrete()) {
// If conditional is discrete-only, we add to the discrete bayes net.
dbn.push_back(conditional->asDiscrete());
}
}
// Sample a discrete assignment.
@ -279,34 +280,28 @@ double HybridBayesNet::error(const VectorValues &continuousValues,
/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::error(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree;
AlgebraicDecisionTree<Key> error_tree(0.0);
// Iterate over each factor.
for (size_t idx = 0; idx < size(); idx++) {
AlgebraicDecisionTree<Key> conditional_error;
// Iterate over each conditional.
for (const sharedConditional &conditional : *this) {
if (conditional->isHybrid()) {
// If conditional is hybrid, select based on assignment and compute error.
GaussianMixture::shared_ptr gm = conditional->asMixture();
AlgebraicDecisionTree<Key> conditional_error =
gm->error(continuousValues);
if (factors_.at(idx)->isHybrid()) {
// If factor is hybrid, select based on assignment and compute error.
GaussianMixture::shared_ptr gm = this->atMixture(idx);
conditional_error = gm->error(continuousValues);
// Assign for the first index, add error for subsequent ones.
if (idx == 0) {
error_tree = conditional_error;
} else {
error_tree = error_tree + conditional_error;
}
} else if (factors_.at(idx)->isContinuous()) {
} else if (conditional->isContinuous()) {
// If continuous only, get the (double) error
// and add it to the error_tree
double error = this->atGaussian(idx)->error(continuousValues);
double error = conditional->asGaussian()->error(continuousValues);
// Add the computed error to every leaf of the error tree.
error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; });
} else if (factors_.at(idx)->isDiscrete()) {
// If factor at `idx` is discrete-only, we skip.
} else if (conditional->isDiscrete()) {
// Conditional is discrete-only, we skip.
continue;
}
}