Get rid of double storage
parent
05af66296d
commit
b56595c6f8
|
@ -45,7 +45,6 @@ namespace gtsam {
|
|||
struct HybridGaussianConditional::Helper {
|
||||
std::optional<size_t> nrFrontals;
|
||||
FactorValuePairs pairs;
|
||||
Conditionals conditionals;
|
||||
double minNegLogConstant;
|
||||
|
||||
using GC = GaussianConditional;
|
||||
|
@ -70,14 +69,12 @@ struct HybridGaussianConditional::Helper {
|
|||
gcs.push_back(gaussianConditional);
|
||||
}
|
||||
|
||||
conditionals = Conditionals({mode}, gcs);
|
||||
pairs = FactorValuePairs({mode}, fvs);
|
||||
}
|
||||
|
||||
/// Construct from tree of GaussianConditionals.
|
||||
explicit Helper(const Conditionals &conditionals)
|
||||
: conditionals(conditionals),
|
||||
minNegLogConstant(std::numeric_limits<double>::infinity()) {
|
||||
: minNegLogConstant(std::numeric_limits<double>::infinity()) {
|
||||
auto func = [this](const GC::shared_ptr &gc) -> GaussianFactorValuePair {
|
||||
if (!gc) return {nullptr, std::numeric_limits<double>::infinity()};
|
||||
if (!nrFrontals) nrFrontals = gc->nrFrontals();
|
||||
|
@ -106,7 +103,6 @@ HybridGaussianConditional::HybridGaussianConditional(
|
|||
pair.second - helper.minNegLogConstant};
|
||||
})),
|
||||
BaseConditional(*helper.nrFrontals),
|
||||
conditionals_(helper.conditionals),
|
||||
negLogConstant_(helper.minNegLogConstant) {}
|
||||
|
||||
HybridGaussianConditional::HybridGaussianConditional(
|
||||
|
@ -143,24 +139,26 @@ HybridGaussianConditional::HybridGaussianConditional(
|
|||
: HybridGaussianConditional(discreteParents, Helper(conditionals)) {}
|
||||
|
||||
/* *******************************************************************************/
|
||||
const HybridGaussianConditional::Conditionals &
|
||||
const HybridGaussianConditional::Conditionals
|
||||
HybridGaussianConditional::conditionals() const {
|
||||
return conditionals_;
|
||||
return Conditionals(factors(), [](const auto& pair) {
|
||||
return std::dynamic_pointer_cast<GaussianConditional>(pair.first);
|
||||
});
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
size_t HybridGaussianConditional::nrComponents() const {
|
||||
size_t total = 0;
|
||||
conditionals_.visit([&total](const GaussianFactor::shared_ptr &node) {
|
||||
if (node) total += 1;
|
||||
factors().visit([&total](const auto& node) {
|
||||
if (node.first) total += 1;
|
||||
});
|
||||
return total;
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
GaussianConditional::shared_ptr HybridGaussianConditional::choose(
|
||||
const DiscreteValues &discreteValues) const {
|
||||
auto &ptr = conditionals_(discreteValues);
|
||||
const DiscreteValues& discreteValues) const {
|
||||
auto& [ptr, _] = factors()(discreteValues);
|
||||
if (!ptr) return nullptr;
|
||||
auto conditional = std::dynamic_pointer_cast<GaussianConditional>(ptr);
|
||||
if (conditional)
|
||||
|
@ -176,18 +174,15 @@ bool HybridGaussianConditional::equals(const HybridFactor &lf,
|
|||
const This *e = dynamic_cast<const This *>(&lf);
|
||||
if (e == nullptr) return false;
|
||||
|
||||
// This will return false if either conditionals_ is empty or e->conditionals_
|
||||
// is empty, but not if both are empty or both are not empty:
|
||||
if (conditionals_.empty() ^ e->conditionals_.empty()) return false;
|
||||
|
||||
// Check the base and the factors:
|
||||
return BaseFactor::equals(*e, tol) &&
|
||||
conditionals_.equals(e->conditionals_,
|
||||
[tol](const GaussianConditional::shared_ptr &f1,
|
||||
const GaussianConditional::shared_ptr &f2) {
|
||||
return (!f1 && !f2) ||
|
||||
(f1 && f2 && f1->equals(*f2, tol));
|
||||
});
|
||||
// Factors existence and scalar values are checked in BaseFactor::equals.
|
||||
// Here we check additionally that the factors *are* conditionals and are equal.
|
||||
auto compareFunc = [tol](const GaussianFactorValuePair& pair1,
|
||||
const GaussianFactorValuePair& pair2) {
|
||||
auto c1 = std::dynamic_pointer_cast<GaussianConditional>(pair1.first),
|
||||
c2 = std::dynamic_pointer_cast<GaussianConditional>(pair2.first);
|
||||
return (!c1 && !c2) || (c1 && c2 && c1->equals(*c2, tol));
|
||||
};
|
||||
return Base::equals(*e, tol) && factors().equals(e->factors(), compareFunc);
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
|
@ -202,7 +197,7 @@ void HybridGaussianConditional::print(const std::string &s,
|
|||
std::cout << std::endl
|
||||
<< " logNormalizationConstant: " << -negLogConstant() << std::endl
|
||||
<< std::endl;
|
||||
conditionals_.print(
|
||||
conditionals().print(
|
||||
"", [&](Key k) { return formatter(k); },
|
||||
[&](const GaussianConditional::shared_ptr &gf) -> std::string {
|
||||
RedirectCout rd;
|
||||
|
@ -254,7 +249,7 @@ std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood(
|
|||
const DiscreteKeys discreteParentKeys = discreteKeys();
|
||||
const KeyVector continuousParentKeys = continuousParents();
|
||||
const HybridGaussianFactor::FactorValuePairs likelihoods(
|
||||
conditionals_,
|
||||
conditionals(),
|
||||
[&](const GaussianConditional::shared_ptr &conditional)
|
||||
-> GaussianFactorValuePair {
|
||||
const auto likelihood_m = conditional->likelihood(given);
|
||||
|
@ -294,22 +289,30 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
|
|||
return (max->evaluate(choices) == 0.0) ? nullptr : conditional;
|
||||
};
|
||||
|
||||
auto pruned_conditionals = conditionals_.apply(pruner);
|
||||
auto pruned_conditionals = conditionals().apply(pruner);
|
||||
return std::make_shared<HybridGaussianConditional>(discreteKeys(),
|
||||
pruned_conditionals);
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
double HybridGaussianConditional::logProbability(
|
||||
const HybridValues &values) const {
|
||||
auto conditional = conditionals_(values.discrete());
|
||||
return conditional->logProbability(values.continuous());
|
||||
const HybridValues& values) const {
|
||||
auto [factor, _] = factors()(values.discrete());
|
||||
if (auto conditional = std::dynamic_pointer_cast<GaussianConditional>(factor))
|
||||
return conditional->logProbability(values.continuous());
|
||||
else
|
||||
throw std::logic_error(
|
||||
"A HybridGaussianConditional unexpectedly contained a non-conditional");
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
double HybridGaussianConditional::evaluate(const HybridValues &values) const {
|
||||
auto conditional = conditionals_(values.discrete());
|
||||
return conditional->evaluate(values.continuous());
|
||||
double HybridGaussianConditional::evaluate(const HybridValues& values) const {
|
||||
auto [factor, _] = factors()(values.discrete());
|
||||
if (auto conditional = std::dynamic_pointer_cast<GaussianConditional>(factor))
|
||||
return conditional->evaluate(values.continuous());
|
||||
else
|
||||
throw std::logic_error(
|
||||
"A HybridGaussianConditional unexpectedly contained a non-conditional");
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -64,8 +64,6 @@ class GTSAM_EXPORT HybridGaussianConditional
|
|||
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
|
||||
|
||||
private:
|
||||
Conditionals conditionals_; ///< a decision tree of Gaussian conditionals.
|
||||
|
||||
///< Negative-log of the normalization constant (log(\sqrt(|2πΣ|))).
|
||||
///< Take advantage of the neg-log space so everything is a minimization
|
||||
double negLogConstant_;
|
||||
|
@ -192,8 +190,8 @@ class GTSAM_EXPORT HybridGaussianConditional
|
|||
std::shared_ptr<HybridGaussianFactor> likelihood(
|
||||
const VectorValues &given) const;
|
||||
|
||||
/// Getter for the underlying Conditionals DecisionTree
|
||||
const Conditionals &conditionals() const;
|
||||
/// Get Conditionals DecisionTree (dynamic cast from factors)
|
||||
const Conditionals conditionals() const;
|
||||
|
||||
/**
|
||||
* @brief Compute the logProbability of this hybrid Gaussian conditional.
|
||||
|
@ -241,7 +239,6 @@ class GTSAM_EXPORT HybridGaussianConditional
|
|||
void serialize(Archive &ar, const unsigned int /*version*/) {
|
||||
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
|
||||
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
|
||||
ar &BOOST_SERIALIZATION_NVP(conditionals_);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue