Initial attempt at removing extra arguments

release/4.3a0
Frank Dellaert 2024-09-26 10:41:04 -07:00
parent 8d5ff15c25
commit 1f11b5472f
3 changed files with 49 additions and 47 deletions

View File

@ -28,56 +28,65 @@
#include <gtsam/linear/GaussianFactorGraph.h>
namespace gtsam {
HybridGaussianFactor::FactorValuePairs GetFactorValuePairs(
const HybridGaussianConditional::Conditionals &conditionals) {
auto func = [](const GaussianConditional::shared_ptr &conditional)
-> GaussianFactorValuePair {
double value = 0.0;
// Check if conditional is pruned
if (conditional) {
// Assign log(\sqrt(|2πΣ|)) = -log(1 / sqrt(|2πΣ|))
value = conditional->negLogConstant();
}
return {std::dynamic_pointer_cast<GaussianFactor>(conditional), value};
};
return HybridGaussianFactor::FactorValuePairs(conditionals, func);
}
HybridGaussianConditional::HybridGaussianConditional(
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
const DiscreteKeys &discreteParents,
const HybridGaussianConditional::Conditionals &conditionals)
: BaseFactor(CollectKeys(continuousFrontals, continuousParents),
discreteParents, GetFactorValuePairs(conditionals)),
BaseConditional(continuousFrontals.size()),
conditionals_(conditionals) {
// Calculate negLogConstant_ as the minimum of the negative-log normalizers of
// the conditionals, by visiting the decision tree:
: BaseConditional(0) { // Initialize with zero; we'll set it properly later
// Check if conditionals are empty
if (conditionals.empty()) {
throw std::invalid_argument("Conditionals cannot be empty");
}
KeyVector frontals, parents;
negLogConstant_ = std::numeric_limits<double>::infinity();
conditionals_.visit(
[this](const GaussianConditional::shared_ptr &conditional) {
if (conditional) {
this->negLogConstant_ =
std::min(this->negLogConstant_, conditional->negLogConstant());
}
});
auto func =
[&](const GaussianConditional::shared_ptr &c) -> GaussianFactorValuePair {
double value = 0.0;
// Check if conditional is pruned
if (c) {
KeyVector cf(c->frontals().begin(), c->frontals().end());
KeyVector cp(c->parents().begin(), c->parents().end());
if (frontals.empty()) {
// Get frontal/parent keys from first conditional.
frontals = cf;
parents = cp;
} else if (cf != frontals || cp != parents) {
throw std::invalid_argument(
"All conditionals must have the same frontals and parents");
}
// Assign log(\sqrt(|2πΣ|)) = -log(1 / sqrt(|2πΣ|))
value = c->negLogConstant();
this->negLogConstant_ = std::min(this->negLogConstant_, value);
}
return {std::dynamic_pointer_cast<GaussianFactor>(c), value};
};
BaseFactor::factors_ = HybridGaussianFactor::Factors(conditionals, func);
// Initialize base classes
KeyVector continuousKeys = frontals;
continuousKeys.insert(continuousKeys.end(), parents.begin(), parents.end());
BaseFactor::keys_ = continuousKeys;
BaseFactor::discreteKeys_ = discreteParents;
BaseConditional::nrFrontals_ = frontals.size();
// Assign conditionals
conditionals_ = conditionals; // TODO(frank): a duplicate of factors_ !!!
}
/* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKey &discreteParent,
const std::vector<GaussianConditional::shared_ptr> &conditionals)
: HybridGaussianConditional(DiscreteKeys{discreteParent},
Conditionals({discreteParent}, conditionals)) {}
/* *******************************************************************************/
const HybridGaussianConditional::Conditionals &
HybridGaussianConditional::conditionals() const {
return conditionals_;
}
/* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional(
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
const DiscreteKey &discreteParent,
const std::vector<GaussianConditional::shared_ptr> &conditionals)
: HybridGaussianConditional(continuousFrontals, continuousParents,
DiscreteKeys{discreteParent},
Conditionals({discreteParent}, conditionals)) {}
/* *******************************************************************************/
GaussianFactorGraphTree HybridGaussianConditional::asGaussianFactorGraphTree()
const {

View File

@ -95,17 +95,13 @@ class GTSAM_EXPORT HybridGaussianConditional
/**
* @brief Construct a new HybridGaussianConditional object.
*
* @param continuousFrontals the continuous frontals.
* @param continuousParents the continuous parents.
* @param discreteParents the discrete parents. Will be placed last.
* @param conditionals a decision tree of GaussianConditionals. The number of
* conditionals should be C^(number of discrete parents), where C is the
* cardinality of the DiscreteKeys in discreteParents, since the
* discreteParents will be used as the labels in the decision tree.
*/
HybridGaussianConditional(const KeyVector &continuousFrontals,
const KeyVector &continuousParents,
const DiscreteKeys &discreteParents,
HybridGaussianConditional(const DiscreteKeys &discreteParents,
const Conditionals &conditionals);
/**
@ -113,14 +109,11 @@ class GTSAM_EXPORT HybridGaussianConditional
* a vector of Gaussian conditionals.
* The DecisionTree-based constructor is preferred over this one.
*
* @param continuousFrontals The continuous frontal variables
* @param continuousParents The continuous parent variables
* @param discreteParent Single discrete parent variable
* @param conditionals Vector of conditionals with the same size as the
* cardinality of the discrete parent.
*/
HybridGaussianConditional(
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
const DiscreteKey &discreteParent,
const std::vector<GaussianConditional::shared_ptr> &conditionals);

View File

@ -69,7 +69,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
/// typedef for Decision Tree of Gaussian factors.
using Factors = DecisionTree<Key, sharedFactor>;
private:
protected:
/// Decision tree of Gaussian factors indexed by discrete keys.
Factors factors_;