Get rid of double storage

release/4.3a0
Frank Dellaert 2024-10-08 18:10:27 +09:00
parent 05af66296d
commit b56595c6f8
2 changed files with 37 additions and 37 deletions

View File

@ -45,7 +45,6 @@ namespace gtsam {
struct HybridGaussianConditional::Helper { struct HybridGaussianConditional::Helper {
std::optional<size_t> nrFrontals; std::optional<size_t> nrFrontals;
FactorValuePairs pairs; FactorValuePairs pairs;
Conditionals conditionals;
double minNegLogConstant; double minNegLogConstant;
using GC = GaussianConditional; using GC = GaussianConditional;
@ -70,14 +69,12 @@ struct HybridGaussianConditional::Helper {
gcs.push_back(gaussianConditional); gcs.push_back(gaussianConditional);
} }
conditionals = Conditionals({mode}, gcs);
pairs = FactorValuePairs({mode}, fvs); pairs = FactorValuePairs({mode}, fvs);
} }
/// Construct from tree of GaussianConditionals. /// Construct from tree of GaussianConditionals.
explicit Helper(const Conditionals &conditionals) explicit Helper(const Conditionals &conditionals)
: conditionals(conditionals), : minNegLogConstant(std::numeric_limits<double>::infinity()) {
minNegLogConstant(std::numeric_limits<double>::infinity()) {
auto func = [this](const GC::shared_ptr &gc) -> GaussianFactorValuePair { auto func = [this](const GC::shared_ptr &gc) -> GaussianFactorValuePair {
if (!gc) return {nullptr, std::numeric_limits<double>::infinity()}; if (!gc) return {nullptr, std::numeric_limits<double>::infinity()};
if (!nrFrontals) nrFrontals = gc->nrFrontals(); if (!nrFrontals) nrFrontals = gc->nrFrontals();
@ -106,7 +103,6 @@ HybridGaussianConditional::HybridGaussianConditional(
pair.second - helper.minNegLogConstant}; pair.second - helper.minNegLogConstant};
})), })),
BaseConditional(*helper.nrFrontals), BaseConditional(*helper.nrFrontals),
conditionals_(helper.conditionals),
negLogConstant_(helper.minNegLogConstant) {} negLogConstant_(helper.minNegLogConstant) {}
HybridGaussianConditional::HybridGaussianConditional( HybridGaussianConditional::HybridGaussianConditional(
@ -143,16 +139,18 @@ HybridGaussianConditional::HybridGaussianConditional(
: HybridGaussianConditional(discreteParents, Helper(conditionals)) {} : HybridGaussianConditional(discreteParents, Helper(conditionals)) {}
/* *******************************************************************************/ /* *******************************************************************************/
const HybridGaussianConditional::Conditionals & const HybridGaussianConditional::Conditionals
HybridGaussianConditional::conditionals() const { HybridGaussianConditional::conditionals() const {
return conditionals_; return Conditionals(factors(), [](const auto& pair) {
return std::dynamic_pointer_cast<GaussianConditional>(pair.first);
});
} }
/* *******************************************************************************/ /* *******************************************************************************/
size_t HybridGaussianConditional::nrComponents() const { size_t HybridGaussianConditional::nrComponents() const {
size_t total = 0; size_t total = 0;
conditionals_.visit([&total](const GaussianFactor::shared_ptr &node) { factors().visit([&total](const auto& node) {
if (node) total += 1; if (node.first) total += 1;
}); });
return total; return total;
} }
@ -160,7 +158,7 @@ size_t HybridGaussianConditional::nrComponents() const {
/* *******************************************************************************/ /* *******************************************************************************/
GaussianConditional::shared_ptr HybridGaussianConditional::choose( GaussianConditional::shared_ptr HybridGaussianConditional::choose(
const DiscreteValues& discreteValues) const { const DiscreteValues& discreteValues) const {
auto &ptr = conditionals_(discreteValues); auto& [ptr, _] = factors()(discreteValues);
if (!ptr) return nullptr; if (!ptr) return nullptr;
auto conditional = std::dynamic_pointer_cast<GaussianConditional>(ptr); auto conditional = std::dynamic_pointer_cast<GaussianConditional>(ptr);
if (conditional) if (conditional)
@ -176,18 +174,15 @@ bool HybridGaussianConditional::equals(const HybridFactor &lf,
const This *e = dynamic_cast<const This *>(&lf); const This *e = dynamic_cast<const This *>(&lf);
if (e == nullptr) return false; if (e == nullptr) return false;
// This will return false if either conditionals_ is empty or e->conditionals_ // Factors existence and scalar values are checked in BaseFactor::equals.
// is empty, but not if both are empty or both are not empty: // Here we check additionally that the factors *are* conditionals and are equal.
if (conditionals_.empty() ^ e->conditionals_.empty()) return false; auto compareFunc = [tol](const GaussianFactorValuePair& pair1,
const GaussianFactorValuePair& pair2) {
// Check the base and the factors: auto c1 = std::dynamic_pointer_cast<GaussianConditional>(pair1.first),
return BaseFactor::equals(*e, tol) && c2 = std::dynamic_pointer_cast<GaussianConditional>(pair2.first);
conditionals_.equals(e->conditionals_, return (!c1 && !c2) || (c1 && c2 && c1->equals(*c2, tol));
[tol](const GaussianConditional::shared_ptr &f1, };
const GaussianConditional::shared_ptr &f2) { return Base::equals(*e, tol) && factors().equals(e->factors(), compareFunc);
return (!f1 && !f2) ||
(f1 && f2 && f1->equals(*f2, tol));
});
} }
/* *******************************************************************************/ /* *******************************************************************************/
@ -202,7 +197,7 @@ void HybridGaussianConditional::print(const std::string &s,
std::cout << std::endl std::cout << std::endl
<< " logNormalizationConstant: " << -negLogConstant() << std::endl << " logNormalizationConstant: " << -negLogConstant() << std::endl
<< std::endl; << std::endl;
conditionals_.print( conditionals().print(
"", [&](Key k) { return formatter(k); }, "", [&](Key k) { return formatter(k); },
[&](const GaussianConditional::shared_ptr &gf) -> std::string { [&](const GaussianConditional::shared_ptr &gf) -> std::string {
RedirectCout rd; RedirectCout rd;
@ -254,7 +249,7 @@ std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood(
const DiscreteKeys discreteParentKeys = discreteKeys(); const DiscreteKeys discreteParentKeys = discreteKeys();
const KeyVector continuousParentKeys = continuousParents(); const KeyVector continuousParentKeys = continuousParents();
const HybridGaussianFactor::FactorValuePairs likelihoods( const HybridGaussianFactor::FactorValuePairs likelihoods(
conditionals_, conditionals(),
[&](const GaussianConditional::shared_ptr &conditional) [&](const GaussianConditional::shared_ptr &conditional)
-> GaussianFactorValuePair { -> GaussianFactorValuePair {
const auto likelihood_m = conditional->likelihood(given); const auto likelihood_m = conditional->likelihood(given);
@ -294,7 +289,7 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
return (max->evaluate(choices) == 0.0) ? nullptr : conditional; return (max->evaluate(choices) == 0.0) ? nullptr : conditional;
}; };
auto pruned_conditionals = conditionals_.apply(pruner); auto pruned_conditionals = conditionals().apply(pruner);
return std::make_shared<HybridGaussianConditional>(discreteKeys(), return std::make_shared<HybridGaussianConditional>(discreteKeys(),
pruned_conditionals); pruned_conditionals);
} }
@ -302,14 +297,22 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
/* *******************************************************************************/ /* *******************************************************************************/
double HybridGaussianConditional::logProbability( double HybridGaussianConditional::logProbability(
const HybridValues& values) const { const HybridValues& values) const {
auto conditional = conditionals_(values.discrete()); auto [factor, _] = factors()(values.discrete());
if (auto conditional = std::dynamic_pointer_cast<GaussianConditional>(factor))
return conditional->logProbability(values.continuous()); return conditional->logProbability(values.continuous());
else
throw std::logic_error(
"A HybridGaussianConditional unexpectedly contained a non-conditional");
} }
/* *******************************************************************************/ /* *******************************************************************************/
double HybridGaussianConditional::evaluate(const HybridValues& values) const { double HybridGaussianConditional::evaluate(const HybridValues& values) const {
auto conditional = conditionals_(values.discrete()); auto [factor, _] = factors()(values.discrete());
if (auto conditional = std::dynamic_pointer_cast<GaussianConditional>(factor))
return conditional->evaluate(values.continuous()); return conditional->evaluate(values.continuous());
else
throw std::logic_error(
"A HybridGaussianConditional unexpectedly contained a non-conditional");
} }
} // namespace gtsam } // namespace gtsam

View File

@ -64,8 +64,6 @@ class GTSAM_EXPORT HybridGaussianConditional
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>; using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
private: private:
Conditionals conditionals_; ///< a decision tree of Gaussian conditionals.
///< Negative-log of the normalization constant (log(\sqrt(|2πΣ|))). ///< Negative-log of the normalization constant (log(\sqrt(|2πΣ|))).
///< Take advantage of the neg-log space so everything is a minimization ///< Take advantage of the neg-log space so everything is a minimization
double negLogConstant_; double negLogConstant_;
@ -192,8 +190,8 @@ class GTSAM_EXPORT HybridGaussianConditional
std::shared_ptr<HybridGaussianFactor> likelihood( std::shared_ptr<HybridGaussianFactor> likelihood(
const VectorValues &given) const; const VectorValues &given) const;
/// Getter for the underlying Conditionals DecisionTree /// Get Conditionals DecisionTree (dynamic cast from factors)
const Conditionals &conditionals() const; const Conditionals conditionals() const;
/** /**
* @brief Compute the logProbability of this hybrid Gaussian conditional. * @brief Compute the logProbability of this hybrid Gaussian conditional.
@ -241,7 +239,6 @@ class GTSAM_EXPORT HybridGaussianConditional
void serialize(Archive &ar, const unsigned int /*version*/) { void serialize(Archive &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor); ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional); ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
ar &BOOST_SERIALIZATION_NVP(conditionals_);
} }
#endif #endif
}; };