Get rid of double storage

release/4.3a0
Frank Dellaert 2024-10-08 18:10:27 +09:00
parent 05af66296d
commit b56595c6f8
2 changed files with 37 additions and 37 deletions

View File

@ -45,7 +45,6 @@ namespace gtsam {
struct HybridGaussianConditional::Helper {
std::optional<size_t> nrFrontals;
FactorValuePairs pairs;
Conditionals conditionals;
double minNegLogConstant;
using GC = GaussianConditional;
@ -70,14 +69,12 @@ struct HybridGaussianConditional::Helper {
gcs.push_back(gaussianConditional);
}
conditionals = Conditionals({mode}, gcs);
pairs = FactorValuePairs({mode}, fvs);
}
/// Construct from tree of GaussianConditionals.
explicit Helper(const Conditionals &conditionals)
: conditionals(conditionals),
minNegLogConstant(std::numeric_limits<double>::infinity()) {
: minNegLogConstant(std::numeric_limits<double>::infinity()) {
auto func = [this](const GC::shared_ptr &gc) -> GaussianFactorValuePair {
if (!gc) return {nullptr, std::numeric_limits<double>::infinity()};
if (!nrFrontals) nrFrontals = gc->nrFrontals();
@ -106,7 +103,6 @@ HybridGaussianConditional::HybridGaussianConditional(
pair.second - helper.minNegLogConstant};
})),
BaseConditional(*helper.nrFrontals),
conditionals_(helper.conditionals),
negLogConstant_(helper.minNegLogConstant) {}
HybridGaussianConditional::HybridGaussianConditional(
@ -143,24 +139,26 @@ HybridGaussianConditional::HybridGaussianConditional(
: HybridGaussianConditional(discreteParents, Helper(conditionals)) {}
/* *******************************************************************************/
const HybridGaussianConditional::Conditionals &
const HybridGaussianConditional::Conditionals
HybridGaussianConditional::conditionals() const {
return conditionals_;
return Conditionals(factors(), [](const auto& pair) {
return std::dynamic_pointer_cast<GaussianConditional>(pair.first);
});
}
/* *******************************************************************************/
size_t HybridGaussianConditional::nrComponents() const {
size_t total = 0;
conditionals_.visit([&total](const GaussianFactor::shared_ptr &node) {
if (node) total += 1;
factors().visit([&total](const auto& node) {
if (node.first) total += 1;
});
return total;
}
/* *******************************************************************************/
GaussianConditional::shared_ptr HybridGaussianConditional::choose(
const DiscreteValues &discreteValues) const {
auto &ptr = conditionals_(discreteValues);
const DiscreteValues& discreteValues) const {
auto& [ptr, _] = factors()(discreteValues);
if (!ptr) return nullptr;
auto conditional = std::dynamic_pointer_cast<GaussianConditional>(ptr);
if (conditional)
@ -176,18 +174,15 @@ bool HybridGaussianConditional::equals(const HybridFactor &lf,
const This *e = dynamic_cast<const This *>(&lf);
if (e == nullptr) return false;
// This will return false if either conditionals_ is empty or e->conditionals_
// is empty, but not if both are empty or both are not empty:
if (conditionals_.empty() ^ e->conditionals_.empty()) return false;
// Check the base and the factors:
return BaseFactor::equals(*e, tol) &&
conditionals_.equals(e->conditionals_,
[tol](const GaussianConditional::shared_ptr &f1,
const GaussianConditional::shared_ptr &f2) {
return (!f1 && !f2) ||
(f1 && f2 && f1->equals(*f2, tol));
});
// Factors existence and scalar values are checked in BaseFactor::equals.
// Here we check additionally that the factors *are* conditionals and are equal.
auto compareFunc = [tol](const GaussianFactorValuePair& pair1,
const GaussianFactorValuePair& pair2) {
auto c1 = std::dynamic_pointer_cast<GaussianConditional>(pair1.first),
c2 = std::dynamic_pointer_cast<GaussianConditional>(pair2.first);
return (!c1 && !c2) || (c1 && c2 && c1->equals(*c2, tol));
};
return Base::equals(*e, tol) && factors().equals(e->factors(), compareFunc);
}
/* *******************************************************************************/
@ -202,7 +197,7 @@ void HybridGaussianConditional::print(const std::string &s,
std::cout << std::endl
<< " logNormalizationConstant: " << -negLogConstant() << std::endl
<< std::endl;
conditionals_.print(
conditionals().print(
"", [&](Key k) { return formatter(k); },
[&](const GaussianConditional::shared_ptr &gf) -> std::string {
RedirectCout rd;
@ -254,7 +249,7 @@ std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood(
const DiscreteKeys discreteParentKeys = discreteKeys();
const KeyVector continuousParentKeys = continuousParents();
const HybridGaussianFactor::FactorValuePairs likelihoods(
conditionals_,
conditionals(),
[&](const GaussianConditional::shared_ptr &conditional)
-> GaussianFactorValuePair {
const auto likelihood_m = conditional->likelihood(given);
@ -294,22 +289,30 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
return (max->evaluate(choices) == 0.0) ? nullptr : conditional;
};
auto pruned_conditionals = conditionals_.apply(pruner);
auto pruned_conditionals = conditionals().apply(pruner);
return std::make_shared<HybridGaussianConditional>(discreteKeys(),
pruned_conditionals);
}
/* *******************************************************************************/
double HybridGaussianConditional::logProbability(
const HybridValues &values) const {
auto conditional = conditionals_(values.discrete());
return conditional->logProbability(values.continuous());
const HybridValues& values) const {
auto [factor, _] = factors()(values.discrete());
if (auto conditional = std::dynamic_pointer_cast<GaussianConditional>(factor))
return conditional->logProbability(values.continuous());
else
throw std::logic_error(
"A HybridGaussianConditional unexpectedly contained a non-conditional");
}
/* *******************************************************************************/
double HybridGaussianConditional::evaluate(const HybridValues &values) const {
auto conditional = conditionals_(values.discrete());
return conditional->evaluate(values.continuous());
double HybridGaussianConditional::evaluate(const HybridValues& values) const {
auto [factor, _] = factors()(values.discrete());
if (auto conditional = std::dynamic_pointer_cast<GaussianConditional>(factor))
return conditional->evaluate(values.continuous());
else
throw std::logic_error(
"A HybridGaussianConditional unexpectedly contained a non-conditional");
}
} // namespace gtsam

View File

@ -64,8 +64,6 @@ class GTSAM_EXPORT HybridGaussianConditional
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
private:
Conditionals conditionals_; ///< a decision tree of Gaussian conditionals.
///< Negative-log of the normalization constant (log(\sqrt(|2πΣ|))).
///< Take advantage of the neg-log space so everything is a minimization
double negLogConstant_;
@ -192,8 +190,8 @@ class GTSAM_EXPORT HybridGaussianConditional
std::shared_ptr<HybridGaussianFactor> likelihood(
const VectorValues &given) const;
/// Getter for the underlying Conditionals DecisionTree
const Conditionals &conditionals() const;
/// Get Conditionals DecisionTree (dynamic cast from factors)
const Conditionals conditionals() const;
/**
* @brief Compute the logProbability of this hybrid Gaussian conditional.
@ -241,7 +239,6 @@ class GTSAM_EXPORT HybridGaussianConditional
void serialize(Archive &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
ar &BOOST_SERIALIZATION_NVP(conditionals_);
}
#endif
};