Make sure we still subtract min()
parent
1fe09f5e09
commit
6592f8a8b4
|
@ -29,6 +29,8 @@
|
|||
#include <gtsam/linear/JacobianFactor.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <memory>
|
||||
#include "gtsam/linear/GaussianConditional.h"
|
||||
|
||||
namespace gtsam {
|
||||
/* *******************************************************************************/
|
||||
|
@ -38,14 +40,13 @@ namespace gtsam {
|
|||
* This struct contains the following fields:
|
||||
* - nrFrontals: Optional size_t for number of frontal variables
|
||||
* - pairs: FactorValuePairs for storing conditionals with their negLogConstant
|
||||
* - conditionals: Conditionals for storing conditionals. TODO(frank): kill!
|
||||
* - minNegLogConstant: minimum negLogConstant, computed here, subtracted in
|
||||
* constructor
|
||||
*/
|
||||
struct HybridGaussianConditional::Helper {
|
||||
std::optional<size_t> nrFrontals;
|
||||
FactorValuePairs pairs;
|
||||
double minNegLogConstant;
|
||||
std::optional<size_t> nrFrontals = {};
|
||||
double minNegLogConstant = std::numeric_limits<double>::infinity();
|
||||
|
||||
using GC = GaussianConditional;
|
||||
using P = std::vector<std::pair<Vector, double>>;
|
||||
|
@ -54,8 +55,6 @@ struct HybridGaussianConditional::Helper {
|
|||
template <typename... Args>
|
||||
explicit Helper(const DiscreteKey &mode, const P &p, Args &&...args) {
|
||||
nrFrontals = 1;
|
||||
minNegLogConstant = std::numeric_limits<double>::infinity();
|
||||
|
||||
std::vector<GaussianFactorValuePair> fvs;
|
||||
std::vector<GC::shared_ptr> gcs;
|
||||
fvs.reserve(p.size());
|
||||
|
@ -73,8 +72,7 @@ struct HybridGaussianConditional::Helper {
|
|||
}
|
||||
|
||||
/// Construct from tree of GaussianConditionals.
|
||||
explicit Helper(const Conditionals &conditionals)
|
||||
: minNegLogConstant(std::numeric_limits<double>::infinity()) {
|
||||
explicit Helper(const Conditionals &conditionals) {
|
||||
auto func = [this](const GC::shared_ptr &gc) -> GaussianFactorValuePair {
|
||||
if (!gc) return {nullptr, std::numeric_limits<double>::infinity()};
|
||||
if (!nrFrontals) nrFrontals = gc->nrFrontals();
|
||||
|
@ -89,6 +87,25 @@ struct HybridGaussianConditional::Helper {
|
|||
"Provided conditionals do not contain any frontal variables.");
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct from tree of factor/scalar pairs.
|
||||
explicit Helper(const FactorValuePairs &pairs) : pairs(pairs) {
|
||||
auto func = [this](const GaussianFactorValuePair &pair) {
|
||||
if (!pair.first) return;
|
||||
auto gc = std::dynamic_pointer_cast<GaussianConditional>(pair.first);
|
||||
if (!gc)
|
||||
throw std::runtime_error(
|
||||
"HybridGaussianConditional called with non-conditional.");
|
||||
if (!nrFrontals) nrFrontals = gc->nrFrontals();
|
||||
minNegLogConstant = std::min(minNegLogConstant, pair.second);
|
||||
};
|
||||
pairs.visit(func);
|
||||
if (!nrFrontals.has_value()) {
|
||||
throw std::runtime_error(
|
||||
"HybridGaussianConditional: need at least one frontal variable. "
|
||||
"Provided conditionals do not contain any frontal variables.");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/* *******************************************************************************/
|
||||
|
@ -138,6 +155,10 @@ HybridGaussianConditional::HybridGaussianConditional(
|
|||
const HybridGaussianConditional::Conditionals &conditionals)
|
||||
: HybridGaussianConditional(discreteParents, Helper(conditionals)) {}
|
||||
|
||||
HybridGaussianConditional::HybridGaussianConditional(
|
||||
const DiscreteKeys &discreteParents, const FactorValuePairs &pairs)
|
||||
: HybridGaussianConditional(discreteParents, Helper(pairs)) {}
|
||||
|
||||
/* *******************************************************************************/
|
||||
const HybridGaussianConditional::Conditionals
|
||||
HybridGaussianConditional::conditionals() const {
|
||||
|
@ -300,8 +321,7 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
|
|||
|
||||
FactorValuePairs prunedConditionals = factors().apply(pruner);
|
||||
return std::shared_ptr<HybridGaussianConditional>(
|
||||
new HybridGaussianConditional(discreteKeys(), nrFrontals_,
|
||||
prunedConditionals, negLogConstant_));
|
||||
new HybridGaussianConditional(discreteKeys(), prunedConditionals));
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
|
|
|
@ -141,6 +141,19 @@ class GTSAM_EXPORT HybridGaussianConditional
|
|||
HybridGaussianConditional(const DiscreteKeys &discreteParents,
|
||||
const Conditionals &conditionals);
|
||||
|
||||
/**
|
||||
* @brief Construct from multiple discrete keys M and a tree of
|
||||
* factor/scalar pairs, where the scalar is assumed to be the
|
||||
* the negative log constant for each assignment m, up to a constant.
|
||||
*
|
||||
* @note Will throw if factors are not actually conditionals.
|
||||
*
|
||||
* @param discreteParents the discrete parents. Will be placed last.
|
||||
* @param conditionalPairs Decision tree of GaussianFactor/scalar pairs.
|
||||
*/
|
||||
HybridGaussianConditional(const DiscreteKeys &discreteParents,
|
||||
const FactorValuePairs &pairs);
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
@ -230,14 +243,6 @@ class GTSAM_EXPORT HybridGaussianConditional
|
|||
HybridGaussianConditional(const DiscreteKeys &discreteParents,
|
||||
const Helper &helper);
|
||||
|
||||
/// Private constructor used when constants have already been calculated.
|
||||
HybridGaussianConditional(const DiscreteKeys &discreteKeys, int nrFrontals,
|
||||
const FactorValuePairs &factors,
|
||||
double negLogConstant)
|
||||
: BaseFactor(discreteKeys, factors),
|
||||
BaseConditional(nrFrontals),
|
||||
negLogConstant_(negLogConstant) {}
|
||||
|
||||
/// Check whether `given` has values for all frontal keys.
|
||||
bool allFrontalsGiven(const VectorValues &given) const;
|
||||
|
||||
|
|
Loading…
Reference in New Issue