349 lines
14 KiB
C++
349 lines
14 KiB
C++
/* ----------------------------------------------------------------------------

 * GTSAM Copyright 2010, Georgia Tech Research Corporation,
 * Atlanta, Georgia 30332-0415
 * All Rights Reserved
 * Authors: Frank Dellaert, et al. (see THANKS for the full author list)

 * See LICENSE for the license information

 * -------------------------------------------------------------------------- */
|
|
|
|
/**
 * @file HybridGaussianConditional.cpp
 * @brief A hybrid conditional in the Conditional Linear Gaussian scheme
 * @author Fan Jiang
 * @author Varun Agrawal
 * @author Frank Dellaert
 * @date Mar 12, 2022
 */
|
|
|
|
#include <gtsam/base/utilities.h>
|
|
#include <gtsam/discrete/DiscreteValues.h>
|
|
#include <gtsam/hybrid/HybridGaussianConditional.h>
|
|
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
|
#include <gtsam/hybrid/HybridValues.h>
|
|
#include <gtsam/inference/Conditional-inst.h>
|
|
#include <gtsam/linear/GaussianBayesNet.h>
|
|
#include <gtsam/linear/GaussianFactorGraph.h>
|
|
#include <gtsam/linear/JacobianFactor.h>
|
|
|
|
#include <cstddef>
|
|
#include <memory>
|
|
#include "gtsam/linear/GaussianConditional.h"
|
|
|
|
namespace gtsam {
|
|
/* *******************************************************************************/
|
|
/**
|
|
* @brief Helper struct for constructing HybridGaussianConditional objects
|
|
*
|
|
* This struct contains the following fields:
|
|
* - nrFrontals: Optional size_t for number of frontal variables
|
|
* - pairs: FactorValuePairs for storing conditionals with their negLogConstant
|
|
* - minNegLogConstant: minimum negLogConstant, computed here, subtracted in
|
|
* constructor
|
|
*/
|
|
struct HybridGaussianConditional::Helper {
|
|
FactorValuePairs pairs;
|
|
std::optional<size_t> nrFrontals = {};
|
|
double minNegLogConstant = std::numeric_limits<double>::infinity();
|
|
|
|
using GC = GaussianConditional;
|
|
using P = std::vector<std::pair<Vector, double>>;
|
|
|
|
/// Construct from a vector of mean and sigma pairs, plus extra args.
|
|
template <typename... Args>
|
|
explicit Helper(const DiscreteKey &mode, const P &p, Args &&...args) {
|
|
nrFrontals = 1;
|
|
std::vector<GaussianFactorValuePair> fvs;
|
|
std::vector<GC::shared_ptr> gcs;
|
|
fvs.reserve(p.size());
|
|
gcs.reserve(p.size());
|
|
for (auto &&[mean, sigma] : p) {
|
|
auto gaussianConditional =
|
|
GC::sharedMeanAndStddev(std::forward<Args>(args)..., mean, sigma);
|
|
double value = gaussianConditional->negLogConstant();
|
|
minNegLogConstant = std::min(minNegLogConstant, value);
|
|
fvs.emplace_back(gaussianConditional, value);
|
|
gcs.push_back(gaussianConditional);
|
|
}
|
|
|
|
pairs = FactorValuePairs({mode}, fvs);
|
|
}
|
|
|
|
/// Construct from tree of GaussianConditionals.
|
|
explicit Helper(const Conditionals &conditionals) {
|
|
auto func = [this](const GC::shared_ptr &gc) -> GaussianFactorValuePair {
|
|
if (!gc) return {nullptr, std::numeric_limits<double>::infinity()};
|
|
if (!nrFrontals) nrFrontals = gc->nrFrontals();
|
|
double value = gc->negLogConstant();
|
|
minNegLogConstant = std::min(minNegLogConstant, value);
|
|
return {gc, value};
|
|
};
|
|
pairs = FactorValuePairs(conditionals, func);
|
|
if (!nrFrontals.has_value()) {
|
|
throw std::runtime_error(
|
|
"HybridGaussianConditional: need at least one frontal variable. "
|
|
"Provided conditionals do not contain any frontal variables.");
|
|
}
|
|
}
|
|
|
|
/// Construct from tree of factor/scalar pairs.
|
|
explicit Helper(const FactorValuePairs &pairs) : pairs(pairs) {
|
|
auto func = [this](const GaussianFactorValuePair &pair) {
|
|
if (!pair.first) return;
|
|
auto gc = std::dynamic_pointer_cast<GaussianConditional>(pair.first);
|
|
if (!gc)
|
|
throw std::runtime_error(
|
|
"HybridGaussianConditional called with non-conditional.");
|
|
if (!nrFrontals) nrFrontals = gc->nrFrontals();
|
|
minNegLogConstant = std::min(minNegLogConstant, pair.second);
|
|
};
|
|
pairs.visit(func);
|
|
if (!nrFrontals.has_value()) {
|
|
throw std::runtime_error(
|
|
"HybridGaussianConditional: need at least one frontal variable. "
|
|
"Provided conditionals do not contain any frontal variables.");
|
|
}
|
|
}
|
|
};
|
|
|
|
/* *******************************************************************************/
|
|
/// Private constructor that every public constructor delegates to.
/// Stores each factor with its scalar shifted by helper.minNegLogConstant, and
/// keeps that minimum itself in negLogConstant_.
HybridGaussianConditional::HybridGaussianConditional(
    const DiscreteKeys &discreteParents, const Helper &helper)
    : BaseFactor(
          discreteParents,
          FactorValuePairs(
              helper.pairs,
              [&helper](const GaussianFactorValuePair &entry)
                  -> GaussianFactorValuePair {
                // Subtract minNegLogConstant; the factor itself is untouched.
                const auto &[factor, scalar] = entry;
                return {factor, scalar - helper.minNegLogConstant};
              })),
      BaseConditional(*helper.nrFrontals),
      negLogConstant_(helper.minNegLogConstant) {}
|
|
|
|
HybridGaussianConditional::HybridGaussianConditional(
|
|
const DiscreteKey &discreteParent,
|
|
const std::vector<GaussianConditional::shared_ptr> &conditionals)
|
|
: HybridGaussianConditional(DiscreteKeys{discreteParent},
|
|
Conditionals({discreteParent}, conditionals)) {}
|
|
|
|
HybridGaussianConditional::HybridGaussianConditional(
|
|
const DiscreteKey &discreteParent, Key key, //
|
|
const std::vector<std::pair<Vector, double>> ¶meters)
|
|
: HybridGaussianConditional(DiscreteKeys{discreteParent},
|
|
Helper(discreteParent, parameters, key)) {}
|
|
|
|
HybridGaussianConditional::HybridGaussianConditional(
|
|
const DiscreteKey &discreteParent, Key key, //
|
|
const Matrix &A, Key parent,
|
|
const std::vector<std::pair<Vector, double>> ¶meters)
|
|
: HybridGaussianConditional(
|
|
DiscreteKeys{discreteParent},
|
|
Helper(discreteParent, parameters, key, A, parent)) {}
|
|
|
|
HybridGaussianConditional::HybridGaussianConditional(
|
|
const DiscreteKey &discreteParent, Key key, //
|
|
const Matrix &A1, Key parent1, const Matrix &A2, Key parent2,
|
|
const std::vector<std::pair<Vector, double>> ¶meters)
|
|
: HybridGaussianConditional(
|
|
DiscreteKeys{discreteParent},
|
|
Helper(discreteParent, parameters, key, A1, parent1, A2, parent2)) {}
|
|
|
|
/// Construct from a set of discrete parents and a tree of Gaussian
/// conditionals. Throws std::runtime_error (raised by Helper) when every leaf
/// of the tree is null.
HybridGaussianConditional::HybridGaussianConditional(
    const DiscreteKeys &discreteParents,
    const HybridGaussianConditional::Conditionals &conditionals)
    : HybridGaussianConditional(discreteParents,
                                Helper(conditionals)) {}
|
|
|
|
/// Construct from a set of discrete parents and a tree of (factor, scalar)
/// pairs. Throws std::runtime_error (raised by Helper) when a non-null factor
/// is not a GaussianConditional or when every leaf is null.
HybridGaussianConditional::HybridGaussianConditional(
    const DiscreteKeys &discreteParents, const FactorValuePairs &pairs)
    : HybridGaussianConditional(discreteParents,
                                Helper(pairs)) {}
|
|
|
|
/* *******************************************************************************/
|
|
/// Return the components as a tree of GaussianConditional pointers, dropping
/// the scalar of each (factor, scalar) pair. A leaf whose factor is null or
/// not a GaussianConditional maps to a null pointer.
const HybridGaussianConditional::Conditionals
HybridGaussianConditional::conditionals() const {
  const auto toConditional = [](const auto &factorAndScalar) {
    return std::dynamic_pointer_cast<GaussianConditional>(
        factorAndScalar.first);
  };
  return Conditionals(factors(), toConditional);
}
|
|
|
|
/* *******************************************************************************/
|
|
/// Count the leaves of the factor tree that hold a non-null factor.
size_t HybridGaussianConditional::nrComponents() const {
  size_t count = 0;
  const auto tally = [&count](const auto &entry) {
    if (entry.first != nullptr) ++count;
  };
  factors().visit(tally);
  return count;
}
|
|
|
|
/* *******************************************************************************/
|
|
/// Select the Gaussian conditional stored for the given discrete assignment.
///
/// @param discreteValues Assignment used to index the factor tree.
/// @return The conditional at that leaf, or nullptr if the leaf is empty.
/// @throws std::logic_error if the leaf holds a factor that is not a
/// GaussianConditional.
GaussianConditional::shared_ptr HybridGaussianConditional::choose(
    const DiscreteValues &discreteValues) const {
  const auto &entry = factors()(discreteValues);
  const auto &factor = entry.first;

  // An empty leaf is a legitimate result, not an error.
  if (factor == nullptr) return nullptr;

  auto conditional = std::dynamic_pointer_cast<GaussianConditional>(factor);
  if (!conditional) {
    throw std::logic_error(
        "A HybridGaussianConditional unexpectedly contained a non-conditional");
  }
  return conditional;
}
|
|
|
|
/* *******************************************************************************/
|
|
bool HybridGaussianConditional::equals(const HybridFactor &lf,
|
|
double tol) const {
|
|
const This *e = dynamic_cast<const This *>(&lf);
|
|
if (e == nullptr) return false;
|
|
|
|
// Factors existence and scalar values are checked in BaseFactor::equals.
|
|
// Here we check additionally that the factors *are* conditionals and are equal.
|
|
auto compareFunc = [tol](const GaussianFactorValuePair& pair1,
|
|
const GaussianFactorValuePair& pair2) {
|
|
auto c1 = std::dynamic_pointer_cast<GaussianConditional>(pair1.first),
|
|
c2 = std::dynamic_pointer_cast<GaussianConditional>(pair2.first);
|
|
return (!c1 && !c2) || (c1 && c2 && c1->equals(*c2, tol));
|
|
};
|
|
return Base::equals(*e, tol) && factors().equals(e->factors(), compareFunc);
|
|
}
|
|
|
|
/* *******************************************************************************/
|
|
void HybridGaussianConditional::print(const std::string &s,
|
|
const KeyFormatter &formatter) const {
|
|
std::cout << (s.empty() ? "" : s + "\n");
|
|
BaseConditional::print("", formatter);
|
|
std::cout << " Discrete Keys = ";
|
|
for (auto &dk : discreteKeys()) {
|
|
std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
|
|
}
|
|
std::cout << std::endl
|
|
<< " logNormalizationConstant: " << -negLogConstant() << std::endl
|
|
<< std::endl;
|
|
factors().print(
|
|
"", [&](Key k) { return formatter(k); },
|
|
[&](const GaussianFactorValuePair &pair) -> std::string {
|
|
RedirectCout rd;
|
|
if (auto gf = std::dynamic_pointer_cast<GaussianConditional>(pair.first)) {
|
|
gf->print("", formatter);
|
|
return rd.str();
|
|
} else {
|
|
return "nullptr";
|
|
}
|
|
});
|
|
}
|
|
|
|
/* ************************************************************************* */
|
|
/// Return the parent keys that are not discrete keys, in their original
/// order.
KeyVector HybridGaussianConditional::continuousParents() const {
  // A parent counts as discrete when it matches the key of any discrete key.
  const auto isDiscrete = [this](Key candidate) {
    for (const auto &discreteKey : discreteKeys()) {
      if (discreteKey.first == candidate) return true;
    }
    return false;
  };

  KeyVector continuousParentKeys;
  for (const Key parent : parents()) {
    if (!isDiscrete(parent)) continuousParentKeys.push_back(parent);
  }
  return continuousParentKeys;
}
|
|
|
|
/* ************************************************************************* */
|
|
bool HybridGaussianConditional::allFrontalsGiven(
|
|
const VectorValues &given) const {
|
|
for (auto &&kv : given) {
|
|
if (given.find(kv.first) == given.end()) {
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
/* ************************************************************************* */
|
|
std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood(
|
|
const VectorValues &given) const {
|
|
if (!allFrontalsGiven(given)) {
|
|
throw std::runtime_error(
|
|
"HybridGaussianConditional::likelihood: given values are missing some "
|
|
"frontals.");
|
|
}
|
|
|
|
const DiscreteKeys discreteParentKeys = discreteKeys();
|
|
const KeyVector continuousParentKeys = continuousParents();
|
|
const HybridGaussianFactor::FactorValuePairs likelihoods(
|
|
factors(),
|
|
[&](const GaussianFactorValuePair &pair) -> GaussianFactorValuePair {
|
|
if (auto conditional =
|
|
std::dynamic_pointer_cast<GaussianConditional>(pair.first)) {
|
|
const auto likelihood_m = conditional->likelihood(given);
|
|
// scalar is already correct.
|
|
assert(pair.second ==
|
|
conditional->negLogConstant() - negLogConstant_);
|
|
return {likelihood_m, pair.second};
|
|
} else {
|
|
return {nullptr, std::numeric_limits<double>::infinity()};
|
|
}
|
|
});
|
|
return std::make_shared<HybridGaussianFactor>(discreteParentKeys,
|
|
likelihoods);
|
|
}
|
|
|
|
/* ************************************************************************* */
|
|
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
|
|
std::set<DiscreteKey> s(discreteKeys.begin(), discreteKeys.end());
|
|
return s;
|
|
}
|
|
|
|
/* *******************************************************************************/
|
|
/// Return a copy of this conditional in which components with zero discrete
/// probability are removed.
///
/// @param discreteProbs Discrete probabilities; it may involve keys that this
/// conditional does not have.
/// @return A new conditional over the same discrete keys, with every pruned
/// leaf replaced by {nullptr, +infinity}.
HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
    const DecisionTreeFactor &discreteProbs) const {
  // Collect the keys of discreteProbs that do not appear in this->keys().
  const std::set<Key> ownKeys(this->keys().begin(), this->keys().end());
  const std::set<Key> probKeys(discreteProbs.keys().begin(),
                               discreteProbs.keys().end());
  std::vector<Key> extraKeys;
  std::set_difference(probKeys.begin(), probKeys.end(),  //
                      ownKeys.begin(), ownKeys.end(),
                      std::back_inserter(extraKeys));

  // Take the max over those extra keys, leaving the maximum probability
  // value for every combination of our own keys.
  Ordering extraOrdering(extraKeys);
  auto maxProbs = discreteProbs.max(extraOrdering);

  // A leaf whose maximum probability is exactly 0.0 is pruned; every other
  // leaf is passed through unchanged.
  auto pruner =
      [&maxProbs](const Assignment<Key> &choices,
                  const GaussianFactorValuePair &pair) -> GaussianFactorValuePair {
    const bool impossible = (maxProbs->evaluate(choices) == 0.0);
    if (impossible) {
      return {nullptr, std::numeric_limits<double>::infinity()};
    }
    return pair;
  };

  FactorValuePairs prunedConditionals = factors().apply(pruner);
  return std::shared_ptr<HybridGaussianConditional>(
      new HybridGaussianConditional(discreteKeys(), prunedConditionals));
}
|
|
|
|
/* *******************************************************************************/
|
|
double HybridGaussianConditional::logProbability(
|
|
const HybridValues& values) const {
|
|
auto [factor, _] = factors()(values.discrete());
|
|
if (auto conditional = std::dynamic_pointer_cast<GaussianConditional>(factor))
|
|
return conditional->logProbability(values.continuous());
|
|
else
|
|
throw std::logic_error(
|
|
"A HybridGaussianConditional unexpectedly contained a non-conditional");
|
|
}
|
|
|
|
/* *******************************************************************************/
|
|
double HybridGaussianConditional::evaluate(const HybridValues& values) const {
|
|
auto [factor, _] = factors()(values.discrete());
|
|
if (auto conditional = std::dynamic_pointer_cast<GaussianConditional>(factor))
|
|
return conditional->evaluate(values.continuous());
|
|
else
|
|
throw std::logic_error(
|
|
"A HybridGaussianConditional unexpectedly contained a non-conditional");
|
|
}
|
|
|
|
} // namespace gtsam
|