remove `factors_` from Bayes net implementation
parent
2483d7c421
commit
fe394cc074
|
|
@ -29,8 +29,8 @@ namespace gtsam {
|
|||
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
|
||||
AlgebraicDecisionTree<Key> decisionTree;
|
||||
|
||||
// The canonical decision tree factor which will get the discrete conditionals
|
||||
// added to it.
|
||||
// The canonical decision tree factor which will get
|
||||
// the discrete conditionals added to it.
|
||||
DecisionTreeFactor dtFactor;
|
||||
|
||||
for (size_t i = 0; i < this->size(); i++) {
|
||||
|
|
@ -176,35 +176,35 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
|
|||
|
||||
/* ************************************************************************* */
|
||||
GaussianMixture::shared_ptr HybridBayesNet::atMixture(size_t i) const {
|
||||
return factors_.at(i)->asMixture();
|
||||
return at(i)->asMixture();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
|
||||
return factors_.at(i)->asGaussian();
|
||||
return at(i)->asGaussian();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
|
||||
return factors_.at(i)->asDiscreteConditional();
|
||||
return at(i)->asDiscreteConditional();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
GaussianBayesNet HybridBayesNet::choose(
|
||||
const DiscreteValues &assignment) const {
|
||||
GaussianBayesNet gbn;
|
||||
for (size_t idx = 0; idx < size(); idx++) {
|
||||
if (factors_.at(idx)->isHybrid()) {
|
||||
// If factor is hybrid, select based on assignment.
|
||||
GaussianMixture gm = *this->atMixture(idx);
|
||||
for (const sharedConditional &conditional : *this) {
|
||||
if (conditional->isHybrid()) {
|
||||
// If conditional is hybrid, select based on assignment.
|
||||
GaussianMixture gm = *conditional->asMixture();
|
||||
gbn.push_back(gm(assignment));
|
||||
|
||||
} else if (factors_.at(idx)->isContinuous()) {
|
||||
} else if (conditional->isContinuous()) {
|
||||
// If continuous only, add gaussian conditional.
|
||||
gbn.push_back((this->atGaussian(idx)));
|
||||
gbn.push_back((conditional->asGaussian()));
|
||||
|
||||
} else if (factors_.at(idx)->isDiscrete()) {
|
||||
// If factor at `idx` is discrete-only, we simply continue.
|
||||
} else if (conditional->isDiscrete()) {
|
||||
// If conditional is discrete-only, we simply continue.
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
|
@ -216,7 +216,7 @@ GaussianBayesNet HybridBayesNet::choose(
|
|||
HybridValues HybridBayesNet::optimize() const {
|
||||
// Solve for the MPE
|
||||
DiscreteBayesNet discrete_bn;
|
||||
for (auto &conditional : factors_) {
|
||||
for (auto &conditional : *this) {
|
||||
if (conditional->isDiscrete()) {
|
||||
discrete_bn.push_back(conditional->asDiscreteConditional());
|
||||
}
|
||||
|
|
@ -236,12 +236,13 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
HybridValues HybridBayesNet::sample(VectorValues given, std::mt19937_64 *rng) const {
|
||||
HybridValues HybridBayesNet::sample(VectorValues given,
|
||||
std::mt19937_64 *rng) const {
|
||||
DiscreteBayesNet dbn;
|
||||
for (size_t idx = 0; idx < size(); idx++) {
|
||||
if (factors_.at(idx)->isDiscrete()) {
|
||||
// If factor at `idx` is discrete-only, we add to the discrete bayes net.
|
||||
dbn.push_back(this->atDiscrete(idx));
|
||||
for (const sharedConditional &conditional : *this) {
|
||||
if (conditional->isDiscrete()) {
|
||||
// If conditional is discrete-only, we add to the discrete bayes net.
|
||||
dbn.push_back(conditional->asDiscrete());
|
||||
}
|
||||
}
|
||||
// Sample a discrete assignment.
|
||||
|
|
@ -279,34 +280,28 @@ double HybridBayesNet::error(const VectorValues &continuousValues,
|
|||
/* ************************************************************************* */
|
||||
AlgebraicDecisionTree<Key> HybridBayesNet::error(
|
||||
const VectorValues &continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> error_tree;
|
||||
AlgebraicDecisionTree<Key> error_tree(0.0);
|
||||
|
||||
// Iterate over each factor.
|
||||
for (size_t idx = 0; idx < size(); idx++) {
|
||||
AlgebraicDecisionTree<Key> conditional_error;
|
||||
// Iterate over each conditional.
|
||||
for (const sharedConditional &conditional : *this) {
|
||||
if (conditional->isHybrid()) {
|
||||
// If conditional is hybrid, select based on assignment and compute error.
|
||||
GaussianMixture::shared_ptr gm = conditional->asMixture();
|
||||
AlgebraicDecisionTree<Key> conditional_error =
|
||||
gm->error(continuousValues);
|
||||
|
||||
if (factors_.at(idx)->isHybrid()) {
|
||||
// If factor is hybrid, select based on assignment and compute error.
|
||||
GaussianMixture::shared_ptr gm = this->atMixture(idx);
|
||||
conditional_error = gm->error(continuousValues);
|
||||
error_tree = error_tree + conditional_error;
|
||||
|
||||
// Assign for the first index, add error for subsequent ones.
|
||||
if (idx == 0) {
|
||||
error_tree = conditional_error;
|
||||
} else {
|
||||
error_tree = error_tree + conditional_error;
|
||||
}
|
||||
|
||||
} else if (factors_.at(idx)->isContinuous()) {
|
||||
} else if (conditional->isContinuous()) {
|
||||
// If continuous only, get the (double) error
|
||||
// and add it to the error_tree
|
||||
double error = this->atGaussian(idx)->error(continuousValues);
|
||||
double error = conditional->asGaussian()->error(continuousValues);
|
||||
// Add the computed error to every leaf of the error tree.
|
||||
error_tree = error_tree.apply(
|
||||
[error](double leaf_value) { return leaf_value + error; });
|
||||
|
||||
} else if (factors_.at(idx)->isDiscrete()) {
|
||||
// If factor at `idx` is discrete-only, we skip.
|
||||
} else if (conditional->isDiscrete()) {
|
||||
// Conditional is discrete-only, we skip.
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue