gtsam/gtsam/hybrid/HybridBayesNet.cpp

293 lines
10 KiB
C++

/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010-2022, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file HybridBayesNet.cpp
* @brief A Bayes net of Gaussian Conditionals indexed by discrete keys.
* @author Fan Jiang
* @author Varun Agrawal
* @author Shangjie Xue
* @author Frank Dellaert
* @date January 2022
*/
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/TableDistribution.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridValues.h>
#include <memory>
// In Wrappers we have no access to this so have a default ready
static std::mt19937_64 kRandomNumberGenerator(42);
namespace gtsam {
/* ************************************************************************* */
void HybridBayesNet::print(const std::string &s,
const KeyFormatter &formatter) const {
Base::print(s, formatter);
}
/* ************************************************************************* */
bool HybridBayesNet::equals(const This &bn, double tol) const {
return Base::equals(bn, tol);
}
/* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(
size_t maxNrLeaves, const std::optional<double> &marginalThreshold,
DiscreteValues *fixedValues) const {
#if GTSAM_HYBRID_TIMING
gttic_(HybridPruning);
#endif
// Collect all the discrete conditionals. Could be small if already pruned.
const DiscreteBayesNet marginal = discreteMarginal();
// Prune discrete Bayes net
DiscreteValues fixed;
auto prunedBN = marginal.prune(maxNrLeaves, marginalThreshold, &fixed);
// Multiply into one big conditional. NOTE: possibly quite expensive.
DiscreteConditional pruned;
for (auto &&conditional : prunedBN) pruned = pruned * (*conditional);
// Set the fixed values if requested.
if (marginalThreshold && fixedValues) {
*fixedValues = fixed;
}
HybridBayesNet result;
result.reserve(size());
// Go through all the Gaussian conditionals, restrict them according to
// fixed values, and then prune further.
for (std::shared_ptr<HybridConditional> conditional : *this) {
if (conditional->isDiscrete()) continue;
// No-op if not a HybridGaussianConditional.
if (marginalThreshold) {
conditional = std::static_pointer_cast<HybridConditional>(
conditional->restrict(fixed));
}
// Now decide on type what to do:
if (auto hgc = conditional->asHybrid()) {
// Prune the hybrid Gaussian conditional!
auto prunedHybridGaussianConditional = hgc->prune(pruned);
if (!prunedHybridGaussianConditional) {
throw std::runtime_error(
"A HybridGaussianConditional had all its conditionals pruned");
}
// Type-erase and add to the pruned Bayes Net fragment.
result.push_back(prunedHybridGaussianConditional);
} else if (conditional->isContinuous()) {
// Add the non-Hybrid GaussianConditional conditional
result.push_back(conditional);
} else
throw std::runtime_error(
"HybrdiBayesNet::prune: Unknown HybridConditional type.");
}
#if GTSAM_HYBRID_TIMING
gttoc_(HybridPruning);
#endif
// Add the pruned discrete conditionals to the result.
for (const DiscreteConditional::shared_ptr &discrete : prunedBN)
result.push_back(discrete);
return result;
}
/* ************************************************************************* */
DiscreteBayesNet HybridBayesNet::discreteMarginal() const {
DiscreteBayesNet result;
for (auto &&conditional : *this) {
if (auto dc = conditional->asDiscrete()) {
result.push_back(dc);
}
}
return result;
}
/* ************************************************************************* */
GaussianBayesNet HybridBayesNet::choose(
const DiscreteValues &assignment) const {
GaussianBayesNet gbn;
for (auto &&conditional : *this) {
if (auto gm = conditional->asHybrid()) {
// If conditional is hybrid, select based on assignment.
gbn.push_back(gm->choose(assignment));
} else if (auto gc = conditional->asGaussian()) {
// If continuous only, add Gaussian conditional.
gbn.push_back(gc);
} else if (auto dc = conditional->asDiscrete()) {
// If conditional is discrete-only, we simply continue.
continue;
}
}
return gbn;
}
/* ************************************************************************* */
DiscreteValues HybridBayesNet::mpe() const {
// Collect all the discrete factors to compute MPE
DiscreteFactorGraph discrete_fg;
for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
if (auto dtc = conditional->asDiscrete<TableDistribution>()) {
// The number of keys should be small so should not
// be expensive to convert to DiscreteConditional.
discrete_fg.push_back(DiscreteConditional(dtc->nrFrontals(),
dtc->toDecisionTreeFactor()));
} else {
discrete_fg.push_back(conditional->asDiscrete());
}
}
}
return discrete_fg.optimize();
}
/* ************************************************************************* */
HybridValues HybridBayesNet::optimize() const {
// Solve for the MPE
DiscreteValues mpe = this->mpe();
// Given the MPE, compute the optimal continuous values.
return HybridValues(optimize(mpe), mpe);
}
/* ************************************************************************* */
VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
GaussianBayesNet gbn = choose(assignment);
// Check if there exists a nullptr in the GaussianBayesNet
// If yes, return an empty VectorValues
if (std::find(gbn.begin(), gbn.end(), nullptr) != gbn.end()) {
return VectorValues();
}
return gbn.optimize();
}
/* ************************************************************************* */
HybridValues HybridBayesNet::sample(const HybridValues &given,
std::mt19937_64 *rng) const {
DiscreteBayesNet dbn;
for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
// If conditional is discrete-only, we add to the discrete Bayes net.
dbn.push_back(conditional->asDiscrete());
}
}
// Sample a discrete assignment.
const DiscreteValues assignment = dbn.sample(given.discrete());
// Select the continuous Bayes net corresponding to the assignment.
GaussianBayesNet gbn = choose(assignment);
// Sample from the Gaussian Bayes net.
VectorValues sample = gbn.sample(given.continuous(), rng);
return {sample, assignment};
}
/* ************************************************************************* */
HybridValues HybridBayesNet::sample(std::mt19937_64 *rng) const {
HybridValues given;
return sample(given, rng);
}
/* ************************************************************************* */
HybridValues HybridBayesNet::sample(const HybridValues &given) const {
return sample(given, &kRandomNumberGenerator);
}
/* ************************************************************************* */
HybridValues HybridBayesNet::sample() const {
return sample(&kRandomNumberGenerator);
}
/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::errorTree(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> result(0.0);
// Iterate over each conditional.
for (auto &&conditional : *this) {
result = result + conditional->errorTree(continuousValues);
}
return result;
}
/* ************************************************************************* */
double HybridBayesNet::negLogConstant(
const std::optional<DiscreteValues> &discrete) const {
double negLogNormConst = 0.0;
// Iterate over each conditional.
for (auto &&conditional : *this) {
if (discrete.has_value()) {
if (auto gm = conditional->asHybrid()) {
negLogNormConst += gm->choose(*discrete)->negLogConstant();
} else if (auto gc = conditional->asGaussian()) {
negLogNormConst += gc->negLogConstant();
} else if (auto dc = conditional->asDiscrete()) {
negLogNormConst += dc->choose(*discrete)->negLogConstant();
} else {
throw std::runtime_error(
"Unknown conditional type when computing negLogConstant");
}
} else {
negLogNormConst += conditional->negLogConstant();
}
}
return negLogNormConst;
}
/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::discretePosterior(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> errors = this->errorTree(continuousValues);
AlgebraicDecisionTree<Key> p =
errors.apply([](double error) { return exp(-error); });
return p / p.sum();
}
/* ************************************************************************* */
double HybridBayesNet::evaluate(const HybridValues &values) const {
return exp(logProbability(values));
}
/* ************************************************************************* */
HybridGaussianFactorGraph HybridBayesNet::toFactorGraph(
const VectorValues &measurements) const {
HybridGaussianFactorGraph fg;
// For all nodes in the Bayes net, if its frontal variable is in measurements,
// replace it by a likelihood factor:
for (auto &&conditional : *this) {
if (conditional->frontalsIn(measurements)) {
if (auto gc = conditional->asGaussian()) {
fg.push_back(gc->likelihood(measurements));
} else if (auto gm = conditional->asHybrid()) {
fg.push_back(gm->likelihood(measurements));
} else {
throw std::runtime_error("Unknown conditional type");
}
} else {
fg.push_back(conditional);
}
}
return fg;
}
} // namespace gtsam