277 lines
10 KiB
C++
277 lines
10 KiB
C++
/* ----------------------------------------------------------------------------
|
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
|
* Atlanta, Georgia 30332-0415
|
|
* All Rights Reserved
|
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
|
* See LICENSE for the license information
|
|
* -------------------------------------------------------------------------- */
|
|
|
|
/**
|
|
* @file HybridBayesNet.cpp
|
|
 * @brief A Bayes net of hybrid conditionals: Gaussian mixtures indexed by discrete keys, plus purely Gaussian and purely discrete conditionals.
|
|
* @author Fan Jiang
|
|
* @author Varun Agrawal
|
|
* @author Shangjie Xue
|
|
* @date January 2022
|
|
*/
|
|
|
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
|
#include <gtsam/hybrid/HybridBayesNet.h>
|
|
#include <gtsam/hybrid/HybridValues.h>
|
|
|
|
namespace gtsam {
|
|
|
|
/* ************************************************************************* */
|
|
/**
 * @brief Multiply all discrete-only conditionals of this Bayes net into a
 * single DecisionTreeFactor.
 *
 * Continuous and hybrid conditionals are ignored.
 *
 * @return Shared pointer to the product of every discrete conditional. If the
 * net holds no discrete conditionals, this is a default-constructed
 * DecisionTreeFactor.
 */
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
  // The canonical decision tree factor into which the discrete conditionals
  // are multiplied.
  DecisionTreeFactor dtFactor;

  for (const auto &conditional : factors_) {
    if (conditional->isDiscrete()) {
      // Convert to a DecisionTreeFactor and multiply it into the product.
      const DecisionTreeFactor f(*conditional->asDiscreteConditional());
      dtFactor = dtFactor * f;
    }
  }

  return boost::make_shared<DecisionTreeFactor>(dtFactor);
}
|
|
|
|
/* ************************************************************************* */
|
|
/**
|
|
* @brief Helper function to get the pruner functional.
|
|
*
|
|
* @param decisionTree The probability decision tree of only discrete keys.
|
|
* @return std::function<GaussianConditional::shared_ptr(
|
|
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
|
*/
|
|
std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
|
const DecisionTreeFactor &decisionTree,
|
|
const HybridConditional &conditional) {
|
|
// Get the discrete keys as sets for the decision tree
|
|
// and the gaussian mixture.
|
|
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
|
|
auto conditionalKeySet = DiscreteKeysAsSet(conditional.discreteKeys());
|
|
|
|
auto pruner = [decisionTree, decisionTreeKeySet, conditionalKeySet](
|
|
const Assignment<Key> &choices,
|
|
double probability) -> double {
|
|
// typecast so we can use this to get probability value
|
|
DiscreteValues values(choices);
|
|
// Case where the gaussian mixture has the same
|
|
// discrete keys as the decision tree.
|
|
if (conditionalKeySet == decisionTreeKeySet) {
|
|
if (decisionTree(values) == 0) {
|
|
return 0.0;
|
|
} else {
|
|
return probability;
|
|
}
|
|
} else {
|
|
std::vector<DiscreteKey> set_diff;
|
|
std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
|
|
conditionalKeySet.begin(), conditionalKeySet.end(),
|
|
std::back_inserter(set_diff));
|
|
|
|
const std::vector<DiscreteValues> assignments =
|
|
DiscreteValues::CartesianProduct(set_diff);
|
|
for (const DiscreteValues &assignment : assignments) {
|
|
DiscreteValues augmented_values(values);
|
|
augmented_values.insert(assignment.begin(), assignment.end());
|
|
|
|
// If any one of the sub-branches are non-zero,
|
|
// we need this probability.
|
|
if (decisionTree(augmented_values) > 0.0) {
|
|
return probability;
|
|
}
|
|
}
|
|
// If we are here, it means that all the sub-branches are 0,
|
|
// so we prune.
|
|
return 0.0;
|
|
}
|
|
};
|
|
return pruner;
|
|
}
|
|
|
|
/* ************************************************************************* */
|
|
void HybridBayesNet::updateDiscreteConditionals(
|
|
const DecisionTreeFactor::shared_ptr &prunedDecisionTree) {
|
|
KeyVector prunedTreeKeys = prunedDecisionTree->keys();
|
|
|
|
for (size_t i = 0; i < this->size(); i++) {
|
|
HybridConditional::shared_ptr conditional = this->at(i);
|
|
if (conditional->isDiscrete()) {
|
|
// std::cout << demangle(typeid(conditional).name()) << std::endl;
|
|
auto discrete = conditional->asDiscreteConditional();
|
|
KeyVector frontals(discrete->frontals().begin(),
|
|
discrete->frontals().end());
|
|
|
|
// Apply prunerFunc to the underlying AlgebraicDecisionTree
|
|
auto discreteTree =
|
|
boost::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete);
|
|
DecisionTreeFactor::ADT prunedDiscreteTree =
|
|
discreteTree->apply(prunerFunc(*prunedDecisionTree, *conditional));
|
|
|
|
// Create the new (hybrid) conditional
|
|
auto prunedDiscrete = boost::make_shared<DiscreteLookupTable>(
|
|
frontals.size(), conditional->discreteKeys(), prunedDiscreteTree);
|
|
conditional = boost::make_shared<HybridConditional>(prunedDiscrete);
|
|
|
|
// Add it back to the BayesNet
|
|
this->at(i) = conditional;
|
|
}
|
|
}
|
|
}
|
|
|
|
/* ************************************************************************* */
|
|
/**
 * @brief Prune the hybrid Bayes net so that the joint discrete decision tree
 * keeps at most `maxNrLeaves` leaves.
 *
 * Side effect: the discrete conditionals of *this* net are replaced by their
 * pruned versions (via updateDiscreteConditionals).
 *
 * @param maxNrLeaves Maximum number of leaves in the pruned discrete tree.
 * @return A new HybridBayesNet holding pruned copies of the Gaussian mixtures.
 * Every other conditional is the same shared pointer as in this net.
 */
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
  // Decision tree over only the discrete keys, pruned to maxNrLeaves.
  const DecisionTreeFactor::shared_ptr decisionTree =
      boost::make_shared<DecisionTreeFactor>(
          this->discreteConditionals()->prune(maxNrLeaves));

  this->updateDiscreteConditionals(decisionTree);

  // Each GaussianMixture is copied and pruned against the decision tree;
  // GaussianMixture::prune is expected to null out the leaves whose
  // assignment has 0.0 probability there.
  HybridBayesNet prunedBayesNetFragment;
  for (const auto &conditional : factors_) {
    if (!conditional->isHybrid()) {
      // Not a GaussianMixture: carry it over unchanged.
      prunedBayesNetFragment.push_back(conditional);
      continue;
    }

    auto prunedMixture =
        boost::make_shared<GaussianMixture>(*conditional->asMixture());
    prunedMixture->prune(*decisionTree);

    // Type-erase and add to the pruned Bayes Net fragment.
    prunedBayesNetFragment.push_back(
        boost::make_shared<HybridConditional>(prunedMixture));
  }

  return prunedBayesNetFragment;
}
|
|
|
|
/* ************************************************************************* */
|
|
/**
 * @brief Access the conditional at index `i` as a Gaussian mixture.
 * @param i Index of the conditional, looked up with `factors_.at(i)`.
 * @return Result of `asMixture()` on that conditional.
 */
GaussianMixture::shared_ptr HybridBayesNet::atMixture(size_t i) const {
  const auto &conditional = factors_.at(i);
  return conditional->asMixture();
}
|
|
|
|
/* ************************************************************************* */
|
|
/**
 * @brief Access the conditional at index `i` as a Gaussian conditional.
 * @param i Index of the conditional, looked up with `factors_.at(i)`.
 * @return Result of `asGaussian()` on that conditional.
 */
GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
  const auto &conditional = factors_.at(i);
  return conditional->asGaussian();
}
|
|
|
|
/* ************************************************************************* */
|
|
/**
 * @brief Access the conditional at index `i` as a discrete conditional.
 * @param i Index of the conditional, looked up with `factors_.at(i)`.
 * @return Result of `asDiscreteConditional()` on that conditional.
 */
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
  const auto &conditional = factors_.at(i);
  return conditional->asDiscreteConditional();
}
|
|
|
|
/* ************************************************************************* */
|
|
/**
 * @brief Select the Gaussian Bayes net corresponding to a discrete assignment.
 *
 * Gaussian mixtures contribute the Gaussian conditional selected by
 * `assignment`, purely continuous conditionals are added as-is, and
 * discrete-only conditionals are skipped.
 *
 * @param assignment Values for the discrete keys of the mixtures.
 * @return GaussianBayesNet with one entry per hybrid or continuous
 * conditional, in the order they appear in this net.
 */
GaussianBayesNet HybridBayesNet::choose(
    const DiscreteValues &assignment) const {
  GaussianBayesNet gbn;
  for (const auto &conditional : factors_) {
    if (conditional->isHybrid()) {
      // Select the mixture component by assignment through the shared
      // pointer; there is no need to deep-copy the whole mixture.
      const GaussianMixture::shared_ptr gm = conditional->asMixture();
      gbn.push_back((*gm)(assignment));
    } else if (conditional->isContinuous()) {
      // If continuous only, add the gaussian conditional.
      gbn.push_back(conditional->asGaussian());
    }
    // Discrete-only conditionals contribute nothing to the Gaussian net.
  }
  return gbn;
}
|
|
|
|
/* ************************************************************************* */
|
|
/**
 * @brief Find the MPE of the discrete conditionals, then the optimal
 * continuous values given that assignment.
 *
 * @return HybridValues holding the discrete MPE and the solution of the
 * Gaussian Bayes net chosen by it.
 */
HybridValues HybridBayesNet::optimize() const {
  // Collect the discrete-only conditionals into their own Bayes net.
  DiscreteBayesNet discreteNet;
  for (const auto &conditional : factors_) {
    if (!conditional->isDiscrete()) continue;
    discreteNet.push_back(conditional->asDiscreteConditional());
  }

  // Solve for the MPE of the discrete part.
  const DiscreteValues mpe = DiscreteFactorGraph(discreteNet).optimize();

  // Given the MPE, compute the optimal continuous values.
  return HybridValues(mpe, this->choose(mpe).optimize());
}
|
|
|
|
/* ************************************************************************* */
|
|
/**
 * @brief Solve for the continuous variables given a fixed discrete
 * assignment.
 * @param assignment Discrete values used to select the mixture components.
 * @return Solution of the Gaussian Bayes net returned by choose(assignment).
 */
VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
  return this->choose(assignment).optimize();
}
|
|
|
|
/* ************************************************************************* */
|
|
double HybridBayesNet::error(const VectorValues &continuousValues,
|
|
const DiscreteValues &discreteValues) const {
|
|
GaussianBayesNet gbn = this->choose(discreteValues);
|
|
return gbn.error(continuousValues);
|
|
}
|
|
|
|
/* ************************************************************************* */
|
|
/**
 * @brief Error of the continuous values for every discrete assignment,
 * returned as a decision tree over the discrete keys.
 *
 * Mixture errors are added as decision trees. The scalar error of each purely
 * continuous conditional is added to every leaf. Discrete-only conditionals
 * are skipped.
 *
 * @param continuousValues Values of the continuous variables.
 * @return Decision tree whose leaf for each discrete assignment is the summed
 * error of the conditionals.
 */
AlgebraicDecisionTree<Key> HybridBayesNet::error(
    const VectorValues &continuousValues) const {
  // Start from an explicit zero tree and always accumulate. Then the sum is
  // correct regardless of where (or whether) a mixture appears in the net,
  // without relying on the value of a default-constructed tree.
  AlgebraicDecisionTree<Key> error_tree(0.0);

  for (const auto &conditional : factors_) {
    if (conditional->isHybrid()) {
      // Hybrid: add the per-assignment error tree of the mixture.
      const GaussianMixture::shared_ptr gm = conditional->asMixture();
      error_tree = error_tree + gm->error(continuousValues);
    } else if (conditional->isContinuous()) {
      // If continuous only, get the (double) error
      // and add it to every leaf of the error_tree.
      const double error = conditional->asGaussian()->error(continuousValues);
      error_tree = error_tree.apply(
          [error](double leaf_value) { return leaf_value + error; });
    }
    // Discrete-only conditionals do not contribute to the continuous error.
  }

  return error_tree;
}
|
|
|
|
} // namespace gtsam
|