Make sure we still subtract min()
parent
1fe09f5e09
commit
6592f8a8b4
|
@ -29,6 +29,8 @@
|
||||||
#include <gtsam/linear/JacobianFactor.h>
|
#include <gtsam/linear/JacobianFactor.h>
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
|
#include <memory>
|
||||||
|
#include "gtsam/linear/GaussianConditional.h"
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
|
@ -38,14 +40,13 @@ namespace gtsam {
|
||||||
* This struct contains the following fields:
|
* This struct contains the following fields:
|
||||||
* - nrFrontals: Optional size_t for number of frontal variables
|
* - nrFrontals: Optional size_t for number of frontal variables
|
||||||
* - pairs: FactorValuePairs for storing conditionals with their negLogConstant
|
* - pairs: FactorValuePairs for storing conditionals with their negLogConstant
|
||||||
* - conditionals: Conditionals for storing conditionals. TODO(frank): kill!
|
|
||||||
* - minNegLogConstant: minimum negLogConstant, computed here, subtracted in
|
* - minNegLogConstant: minimum negLogConstant, computed here, subtracted in
|
||||||
* constructor
|
* constructor
|
||||||
*/
|
*/
|
||||||
struct HybridGaussianConditional::Helper {
|
struct HybridGaussianConditional::Helper {
|
||||||
std::optional<size_t> nrFrontals;
|
|
||||||
FactorValuePairs pairs;
|
FactorValuePairs pairs;
|
||||||
double minNegLogConstant;
|
std::optional<size_t> nrFrontals = {};
|
||||||
|
double minNegLogConstant = std::numeric_limits<double>::infinity();
|
||||||
|
|
||||||
using GC = GaussianConditional;
|
using GC = GaussianConditional;
|
||||||
using P = std::vector<std::pair<Vector, double>>;
|
using P = std::vector<std::pair<Vector, double>>;
|
||||||
|
@ -54,8 +55,6 @@ struct HybridGaussianConditional::Helper {
|
||||||
template <typename... Args>
|
template <typename... Args>
|
||||||
explicit Helper(const DiscreteKey &mode, const P &p, Args &&...args) {
|
explicit Helper(const DiscreteKey &mode, const P &p, Args &&...args) {
|
||||||
nrFrontals = 1;
|
nrFrontals = 1;
|
||||||
minNegLogConstant = std::numeric_limits<double>::infinity();
|
|
||||||
|
|
||||||
std::vector<GaussianFactorValuePair> fvs;
|
std::vector<GaussianFactorValuePair> fvs;
|
||||||
std::vector<GC::shared_ptr> gcs;
|
std::vector<GC::shared_ptr> gcs;
|
||||||
fvs.reserve(p.size());
|
fvs.reserve(p.size());
|
||||||
|
@ -73,8 +72,7 @@ struct HybridGaussianConditional::Helper {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Construct from tree of GaussianConditionals.
|
/// Construct from tree of GaussianConditionals.
|
||||||
explicit Helper(const Conditionals &conditionals)
|
explicit Helper(const Conditionals &conditionals) {
|
||||||
: minNegLogConstant(std::numeric_limits<double>::infinity()) {
|
|
||||||
auto func = [this](const GC::shared_ptr &gc) -> GaussianFactorValuePair {
|
auto func = [this](const GC::shared_ptr &gc) -> GaussianFactorValuePair {
|
||||||
if (!gc) return {nullptr, std::numeric_limits<double>::infinity()};
|
if (!gc) return {nullptr, std::numeric_limits<double>::infinity()};
|
||||||
if (!nrFrontals) nrFrontals = gc->nrFrontals();
|
if (!nrFrontals) nrFrontals = gc->nrFrontals();
|
||||||
|
@ -89,6 +87,25 @@ struct HybridGaussianConditional::Helper {
|
||||||
"Provided conditionals do not contain any frontal variables.");
|
"Provided conditionals do not contain any frontal variables.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Construct from tree of factor/scalar pairs.
|
||||||
|
explicit Helper(const FactorValuePairs &pairs) : pairs(pairs) {
|
||||||
|
auto func = [this](const GaussianFactorValuePair &pair) {
|
||||||
|
if (!pair.first) return;
|
||||||
|
auto gc = std::dynamic_pointer_cast<GaussianConditional>(pair.first);
|
||||||
|
if (!gc)
|
||||||
|
throw std::runtime_error(
|
||||||
|
"HybridGaussianConditional called with non-conditional.");
|
||||||
|
if (!nrFrontals) nrFrontals = gc->nrFrontals();
|
||||||
|
minNegLogConstant = std::min(minNegLogConstant, pair.second);
|
||||||
|
};
|
||||||
|
pairs.visit(func);
|
||||||
|
if (!nrFrontals.has_value()) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"HybridGaussianConditional: need at least one frontal variable. "
|
||||||
|
"Provided conditionals do not contain any frontal variables.");
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
|
@ -138,6 +155,10 @@ HybridGaussianConditional::HybridGaussianConditional(
|
||||||
const HybridGaussianConditional::Conditionals &conditionals)
|
const HybridGaussianConditional::Conditionals &conditionals)
|
||||||
: HybridGaussianConditional(discreteParents, Helper(conditionals)) {}
|
: HybridGaussianConditional(discreteParents, Helper(conditionals)) {}
|
||||||
|
|
||||||
|
HybridGaussianConditional::HybridGaussianConditional(
|
||||||
|
const DiscreteKeys &discreteParents, const FactorValuePairs &pairs)
|
||||||
|
: HybridGaussianConditional(discreteParents, Helper(pairs)) {}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
const HybridGaussianConditional::Conditionals
|
const HybridGaussianConditional::Conditionals
|
||||||
HybridGaussianConditional::conditionals() const {
|
HybridGaussianConditional::conditionals() const {
|
||||||
|
@ -300,8 +321,7 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
|
||||||
|
|
||||||
FactorValuePairs prunedConditionals = factors().apply(pruner);
|
FactorValuePairs prunedConditionals = factors().apply(pruner);
|
||||||
return std::shared_ptr<HybridGaussianConditional>(
|
return std::shared_ptr<HybridGaussianConditional>(
|
||||||
new HybridGaussianConditional(discreteKeys(), nrFrontals_,
|
new HybridGaussianConditional(discreteKeys(), prunedConditionals));
|
||||||
prunedConditionals, negLogConstant_));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
|
|
|
@ -141,6 +141,19 @@ class GTSAM_EXPORT HybridGaussianConditional
|
||||||
HybridGaussianConditional(const DiscreteKeys &discreteParents,
|
HybridGaussianConditional(const DiscreteKeys &discreteParents,
|
||||||
const Conditionals &conditionals);
|
const Conditionals &conditionals);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Construct from multiple discrete keys M and a tree of
|
||||||
|
* factor/scalar pairs, where the scalar is assumed to be the
|
||||||
|
* the negative log constant for each assignment m, up to a constant.
|
||||||
|
*
|
||||||
|
* @note Will throw if factors are not actually conditionals.
|
||||||
|
*
|
||||||
|
* @param discreteParents the discrete parents. Will be placed last.
|
||||||
|
* @param conditionalPairs Decision tree of GaussianFactor/scalar pairs.
|
||||||
|
*/
|
||||||
|
HybridGaussianConditional(const DiscreteKeys &discreteParents,
|
||||||
|
const FactorValuePairs &pairs);
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
/// @{
|
/// @{
|
||||||
|
@ -230,14 +243,6 @@ class GTSAM_EXPORT HybridGaussianConditional
|
||||||
HybridGaussianConditional(const DiscreteKeys &discreteParents,
|
HybridGaussianConditional(const DiscreteKeys &discreteParents,
|
||||||
const Helper &helper);
|
const Helper &helper);
|
||||||
|
|
||||||
/// Private constructor used when constants have already been calculated.
|
|
||||||
HybridGaussianConditional(const DiscreteKeys &discreteKeys, int nrFrontals,
|
|
||||||
const FactorValuePairs &factors,
|
|
||||||
double negLogConstant)
|
|
||||||
: BaseFactor(discreteKeys, factors),
|
|
||||||
BaseConditional(nrFrontals),
|
|
||||||
negLogConstant_(negLogConstant) {}
|
|
||||||
|
|
||||||
/// Check whether `given` has values for all frontal keys.
|
/// Check whether `given` has values for all frontal keys.
|
||||||
bool allFrontalsGiven(const VectorValues &given) const;
|
bool allFrontalsGiven(const VectorValues &given) const;
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue