remove `factors_` from Bayes net implementation
parent
2483d7c421
commit
fe394cc074
|
|
@ -29,8 +29,8 @@ namespace gtsam {
|
||||||
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
|
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
|
||||||
AlgebraicDecisionTree<Key> decisionTree;
|
AlgebraicDecisionTree<Key> decisionTree;
|
||||||
|
|
||||||
// The canonical decision tree factor which will get the discrete conditionals
|
// The canonical decision tree factor which will get
|
||||||
// added to it.
|
// the discrete conditionals added to it.
|
||||||
DecisionTreeFactor dtFactor;
|
DecisionTreeFactor dtFactor;
|
||||||
|
|
||||||
for (size_t i = 0; i < this->size(); i++) {
|
for (size_t i = 0; i < this->size(); i++) {
|
||||||
|
|
@ -176,35 +176,35 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
GaussianMixture::shared_ptr HybridBayesNet::atMixture(size_t i) const {
|
GaussianMixture::shared_ptr HybridBayesNet::atMixture(size_t i) const {
|
||||||
return factors_.at(i)->asMixture();
|
return at(i)->asMixture();
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
|
GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
|
||||||
return factors_.at(i)->asGaussian();
|
return at(i)->asGaussian();
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
|
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
|
||||||
return factors_.at(i)->asDiscreteConditional();
|
return at(i)->asDiscreteConditional();
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
GaussianBayesNet HybridBayesNet::choose(
|
GaussianBayesNet HybridBayesNet::choose(
|
||||||
const DiscreteValues &assignment) const {
|
const DiscreteValues &assignment) const {
|
||||||
GaussianBayesNet gbn;
|
GaussianBayesNet gbn;
|
||||||
for (size_t idx = 0; idx < size(); idx++) {
|
for (const sharedConditional &conditional : *this) {
|
||||||
if (factors_.at(idx)->isHybrid()) {
|
if (conditional->isHybrid()) {
|
||||||
// If factor is hybrid, select based on assignment.
|
// If conditional is hybrid, select based on assignment.
|
||||||
GaussianMixture gm = *this->atMixture(idx);
|
GaussianMixture gm = *conditional->asMixture();
|
||||||
gbn.push_back(gm(assignment));
|
gbn.push_back(gm(assignment));
|
||||||
|
|
||||||
} else if (factors_.at(idx)->isContinuous()) {
|
} else if (conditional->isContinuous()) {
|
||||||
// If continuous only, add gaussian conditional.
|
// If continuous only, add gaussian conditional.
|
||||||
gbn.push_back((this->atGaussian(idx)));
|
gbn.push_back((conditional->asGaussian()));
|
||||||
|
|
||||||
} else if (factors_.at(idx)->isDiscrete()) {
|
} else if (conditional->isDiscrete()) {
|
||||||
// If factor at `idx` is discrete-only, we simply continue.
|
// If conditional is discrete-only, we simply continue.
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -216,7 +216,7 @@ GaussianBayesNet HybridBayesNet::choose(
|
||||||
HybridValues HybridBayesNet::optimize() const {
|
HybridValues HybridBayesNet::optimize() const {
|
||||||
// Solve for the MPE
|
// Solve for the MPE
|
||||||
DiscreteBayesNet discrete_bn;
|
DiscreteBayesNet discrete_bn;
|
||||||
for (auto &conditional : factors_) {
|
for (auto &conditional : *this) {
|
||||||
if (conditional->isDiscrete()) {
|
if (conditional->isDiscrete()) {
|
||||||
discrete_bn.push_back(conditional->asDiscreteConditional());
|
discrete_bn.push_back(conditional->asDiscreteConditional());
|
||||||
}
|
}
|
||||||
|
|
@ -236,12 +236,13 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
HybridValues HybridBayesNet::sample(VectorValues given, std::mt19937_64 *rng) const {
|
HybridValues HybridBayesNet::sample(VectorValues given,
|
||||||
|
std::mt19937_64 *rng) const {
|
||||||
DiscreteBayesNet dbn;
|
DiscreteBayesNet dbn;
|
||||||
for (size_t idx = 0; idx < size(); idx++) {
|
for (const sharedConditional &conditional : *this) {
|
||||||
if (factors_.at(idx)->isDiscrete()) {
|
if (conditional->isDiscrete()) {
|
||||||
// If factor at `idx` is discrete-only, we add to the discrete bayes net.
|
// If conditional is discrete-only, we add to the discrete bayes net.
|
||||||
dbn.push_back(this->atDiscrete(idx));
|
dbn.push_back(conditional->asDiscrete());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Sample a discrete assignment.
|
// Sample a discrete assignment.
|
||||||
|
|
@ -279,34 +280,28 @@ double HybridBayesNet::error(const VectorValues &continuousValues,
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
AlgebraicDecisionTree<Key> HybridBayesNet::error(
|
AlgebraicDecisionTree<Key> HybridBayesNet::error(
|
||||||
const VectorValues &continuousValues) const {
|
const VectorValues &continuousValues) const {
|
||||||
AlgebraicDecisionTree<Key> error_tree;
|
AlgebraicDecisionTree<Key> error_tree(0.0);
|
||||||
|
|
||||||
// Iterate over each factor.
|
// Iterate over each conditional.
|
||||||
for (size_t idx = 0; idx < size(); idx++) {
|
for (const sharedConditional &conditional : *this) {
|
||||||
AlgebraicDecisionTree<Key> conditional_error;
|
if (conditional->isHybrid()) {
|
||||||
|
// If conditional is hybrid, select based on assignment and compute error.
|
||||||
|
GaussianMixture::shared_ptr gm = conditional->asMixture();
|
||||||
|
AlgebraicDecisionTree<Key> conditional_error =
|
||||||
|
gm->error(continuousValues);
|
||||||
|
|
||||||
if (factors_.at(idx)->isHybrid()) {
|
|
||||||
// If factor is hybrid, select based on assignment and compute error.
|
|
||||||
GaussianMixture::shared_ptr gm = this->atMixture(idx);
|
|
||||||
conditional_error = gm->error(continuousValues);
|
|
||||||
|
|
||||||
// Assign for the first index, add error for subsequent ones.
|
|
||||||
if (idx == 0) {
|
|
||||||
error_tree = conditional_error;
|
|
||||||
} else {
|
|
||||||
error_tree = error_tree + conditional_error;
|
error_tree = error_tree + conditional_error;
|
||||||
}
|
|
||||||
|
|
||||||
} else if (factors_.at(idx)->isContinuous()) {
|
} else if (conditional->isContinuous()) {
|
||||||
// If continuous only, get the (double) error
|
// If continuous only, get the (double) error
|
||||||
// and add it to the error_tree
|
// and add it to the error_tree
|
||||||
double error = this->atGaussian(idx)->error(continuousValues);
|
double error = conditional->asGaussian()->error(continuousValues);
|
||||||
// Add the computed error to every leaf of the error tree.
|
// Add the computed error to every leaf of the error tree.
|
||||||
error_tree = error_tree.apply(
|
error_tree = error_tree.apply(
|
||||||
[error](double leaf_value) { return leaf_value + error; });
|
[error](double leaf_value) { return leaf_value + error; });
|
||||||
|
|
||||||
} else if (factors_.at(idx)->isDiscrete()) {
|
} else if (conditional->isDiscrete()) {
|
||||||
// If factor at `idx` is discrete-only, we skip.
|
// Conditional is discrete-only, we skip.
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue