Make sure we still subtract min()

release/4.3a0
Frank Dellaert 2024-10-17 02:05:56 +09:00
parent 1fe09f5e09
commit 6592f8a8b4
2 changed files with 42 additions and 17 deletions

View File

@ -29,6 +29,8 @@
#include <gtsam/linear/JacobianFactor.h> #include <gtsam/linear/JacobianFactor.h>
#include <cstddef> #include <cstddef>
#include <memory>
#include "gtsam/linear/GaussianConditional.h"
namespace gtsam { namespace gtsam {
/* *******************************************************************************/ /* *******************************************************************************/
@ -38,14 +40,13 @@ namespace gtsam {
* This struct contains the following fields: * This struct contains the following fields:
* - nrFrontals: Optional size_t for number of frontal variables * - nrFrontals: Optional size_t for number of frontal variables
* - pairs: FactorValuePairs for storing conditionals with their negLogConstant * - pairs: FactorValuePairs for storing conditionals with their negLogConstant
* - conditionals: Conditionals for storing conditionals. TODO(frank): kill!
* - minNegLogConstant: minimum negLogConstant, computed here, subtracted in * - minNegLogConstant: minimum negLogConstant, computed here, subtracted in
* constructor * constructor
*/ */
struct HybridGaussianConditional::Helper { struct HybridGaussianConditional::Helper {
std::optional<size_t> nrFrontals;
FactorValuePairs pairs; FactorValuePairs pairs;
double minNegLogConstant; std::optional<size_t> nrFrontals = {};
double minNegLogConstant = std::numeric_limits<double>::infinity();
using GC = GaussianConditional; using GC = GaussianConditional;
using P = std::vector<std::pair<Vector, double>>; using P = std::vector<std::pair<Vector, double>>;
@ -54,8 +55,6 @@ struct HybridGaussianConditional::Helper {
template <typename... Args> template <typename... Args>
explicit Helper(const DiscreteKey &mode, const P &p, Args &&...args) { explicit Helper(const DiscreteKey &mode, const P &p, Args &&...args) {
nrFrontals = 1; nrFrontals = 1;
minNegLogConstant = std::numeric_limits<double>::infinity();
std::vector<GaussianFactorValuePair> fvs; std::vector<GaussianFactorValuePair> fvs;
std::vector<GC::shared_ptr> gcs; std::vector<GC::shared_ptr> gcs;
fvs.reserve(p.size()); fvs.reserve(p.size());
@ -73,8 +72,7 @@ struct HybridGaussianConditional::Helper {
} }
/// Construct from tree of GaussianConditionals. /// Construct from tree of GaussianConditionals.
explicit Helper(const Conditionals &conditionals) explicit Helper(const Conditionals &conditionals) {
: minNegLogConstant(std::numeric_limits<double>::infinity()) {
auto func = [this](const GC::shared_ptr &gc) -> GaussianFactorValuePair { auto func = [this](const GC::shared_ptr &gc) -> GaussianFactorValuePair {
if (!gc) return {nullptr, std::numeric_limits<double>::infinity()}; if (!gc) return {nullptr, std::numeric_limits<double>::infinity()};
if (!nrFrontals) nrFrontals = gc->nrFrontals(); if (!nrFrontals) nrFrontals = gc->nrFrontals();
@ -89,6 +87,25 @@ struct HybridGaussianConditional::Helper {
"Provided conditionals do not contain any frontal variables."); "Provided conditionals do not contain any frontal variables.");
} }
} }
/// Construct from tree of factor/scalar pairs.
explicit Helper(const FactorValuePairs &pairs) : pairs(pairs) {
auto func = [this](const GaussianFactorValuePair &pair) {
if (!pair.first) return;
auto gc = std::dynamic_pointer_cast<GaussianConditional>(pair.first);
if (!gc)
throw std::runtime_error(
"HybridGaussianConditional called with non-conditional.");
if (!nrFrontals) nrFrontals = gc->nrFrontals();
minNegLogConstant = std::min(minNegLogConstant, pair.second);
};
pairs.visit(func);
if (!nrFrontals.has_value()) {
throw std::runtime_error(
"HybridGaussianConditional: need at least one frontal variable. "
"Provided conditionals do not contain any frontal variables.");
}
}
}; };
/* *******************************************************************************/ /* *******************************************************************************/
@ -138,6 +155,10 @@ HybridGaussianConditional::HybridGaussianConditional(
const HybridGaussianConditional::Conditionals &conditionals) const HybridGaussianConditional::Conditionals &conditionals)
: HybridGaussianConditional(discreteParents, Helper(conditionals)) {} : HybridGaussianConditional(discreteParents, Helper(conditionals)) {}
HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKeys &discreteParents, const FactorValuePairs &pairs)
: HybridGaussianConditional(discreteParents, Helper(pairs)) {}
/* *******************************************************************************/ /* *******************************************************************************/
const HybridGaussianConditional::Conditionals const HybridGaussianConditional::Conditionals
HybridGaussianConditional::conditionals() const { HybridGaussianConditional::conditionals() const {
@ -300,8 +321,7 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
FactorValuePairs prunedConditionals = factors().apply(pruner); FactorValuePairs prunedConditionals = factors().apply(pruner);
return std::shared_ptr<HybridGaussianConditional>( return std::shared_ptr<HybridGaussianConditional>(
new HybridGaussianConditional(discreteKeys(), nrFrontals_, new HybridGaussianConditional(discreteKeys(), prunedConditionals));
prunedConditionals, negLogConstant_));
} }
/* *******************************************************************************/ /* *******************************************************************************/

View File

@ -141,6 +141,19 @@ class GTSAM_EXPORT HybridGaussianConditional
HybridGaussianConditional(const DiscreteKeys &discreteParents, HybridGaussianConditional(const DiscreteKeys &discreteParents,
const Conditionals &conditionals); const Conditionals &conditionals);
/**
* @brief Construct from multiple discrete keys M and a tree of
* factor/scalar pairs, where the scalar is assumed to be the
* the negative log constant for each assignment m, up to a constant.
*
* @note Will throw if factors are not actually conditionals.
*
* @param discreteParents the discrete parents. Will be placed last.
* @param conditionalPairs Decision tree of GaussianFactor/scalar pairs.
*/
HybridGaussianConditional(const DiscreteKeys &discreteParents,
const FactorValuePairs &pairs);
/// @} /// @}
/// @name Testable /// @name Testable
/// @{ /// @{
@ -230,14 +243,6 @@ class GTSAM_EXPORT HybridGaussianConditional
HybridGaussianConditional(const DiscreteKeys &discreteParents, HybridGaussianConditional(const DiscreteKeys &discreteParents,
const Helper &helper); const Helper &helper);
/// Private constructor used when constants have already been calculated.
HybridGaussianConditional(const DiscreteKeys &discreteKeys, int nrFrontals,
const FactorValuePairs &factors,
double negLogConstant)
: BaseFactor(discreteKeys, factors),
BaseConditional(nrFrontals),
negLogConstant_(negLogConstant) {}
/// Check whether `given` has values for all frontal keys. /// Check whether `given` has values for all frontal keys.
bool allFrontalsGiven(const VectorValues &given) const; bool allFrontalsGiven(const VectorValues &given) const;