Get rid of double storage
parent
05af66296d
commit
b56595c6f8
|
@ -45,7 +45,6 @@ namespace gtsam {
|
||||||
struct HybridGaussianConditional::Helper {
|
struct HybridGaussianConditional::Helper {
|
||||||
std::optional<size_t> nrFrontals;
|
std::optional<size_t> nrFrontals;
|
||||||
FactorValuePairs pairs;
|
FactorValuePairs pairs;
|
||||||
Conditionals conditionals;
|
|
||||||
double minNegLogConstant;
|
double minNegLogConstant;
|
||||||
|
|
||||||
using GC = GaussianConditional;
|
using GC = GaussianConditional;
|
||||||
|
@ -70,14 +69,12 @@ struct HybridGaussianConditional::Helper {
|
||||||
gcs.push_back(gaussianConditional);
|
gcs.push_back(gaussianConditional);
|
||||||
}
|
}
|
||||||
|
|
||||||
conditionals = Conditionals({mode}, gcs);
|
|
||||||
pairs = FactorValuePairs({mode}, fvs);
|
pairs = FactorValuePairs({mode}, fvs);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Construct from tree of GaussianConditionals.
|
/// Construct from tree of GaussianConditionals.
|
||||||
explicit Helper(const Conditionals &conditionals)
|
explicit Helper(const Conditionals &conditionals)
|
||||||
: conditionals(conditionals),
|
: minNegLogConstant(std::numeric_limits<double>::infinity()) {
|
||||||
minNegLogConstant(std::numeric_limits<double>::infinity()) {
|
|
||||||
auto func = [this](const GC::shared_ptr &gc) -> GaussianFactorValuePair {
|
auto func = [this](const GC::shared_ptr &gc) -> GaussianFactorValuePair {
|
||||||
if (!gc) return {nullptr, std::numeric_limits<double>::infinity()};
|
if (!gc) return {nullptr, std::numeric_limits<double>::infinity()};
|
||||||
if (!nrFrontals) nrFrontals = gc->nrFrontals();
|
if (!nrFrontals) nrFrontals = gc->nrFrontals();
|
||||||
|
@ -106,7 +103,6 @@ HybridGaussianConditional::HybridGaussianConditional(
|
||||||
pair.second - helper.minNegLogConstant};
|
pair.second - helper.minNegLogConstant};
|
||||||
})),
|
})),
|
||||||
BaseConditional(*helper.nrFrontals),
|
BaseConditional(*helper.nrFrontals),
|
||||||
conditionals_(helper.conditionals),
|
|
||||||
negLogConstant_(helper.minNegLogConstant) {}
|
negLogConstant_(helper.minNegLogConstant) {}
|
||||||
|
|
||||||
HybridGaussianConditional::HybridGaussianConditional(
|
HybridGaussianConditional::HybridGaussianConditional(
|
||||||
|
@ -143,24 +139,26 @@ HybridGaussianConditional::HybridGaussianConditional(
|
||||||
: HybridGaussianConditional(discreteParents, Helper(conditionals)) {}
|
: HybridGaussianConditional(discreteParents, Helper(conditionals)) {}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
const HybridGaussianConditional::Conditionals &
|
const HybridGaussianConditional::Conditionals
|
||||||
HybridGaussianConditional::conditionals() const {
|
HybridGaussianConditional::conditionals() const {
|
||||||
return conditionals_;
|
return Conditionals(factors(), [](const auto& pair) {
|
||||||
|
return std::dynamic_pointer_cast<GaussianConditional>(pair.first);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
size_t HybridGaussianConditional::nrComponents() const {
|
size_t HybridGaussianConditional::nrComponents() const {
|
||||||
size_t total = 0;
|
size_t total = 0;
|
||||||
conditionals_.visit([&total](const GaussianFactor::shared_ptr &node) {
|
factors().visit([&total](const auto& node) {
|
||||||
if (node) total += 1;
|
if (node.first) total += 1;
|
||||||
});
|
});
|
||||||
return total;
|
return total;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
GaussianConditional::shared_ptr HybridGaussianConditional::choose(
|
GaussianConditional::shared_ptr HybridGaussianConditional::choose(
|
||||||
const DiscreteValues &discreteValues) const {
|
const DiscreteValues& discreteValues) const {
|
||||||
auto &ptr = conditionals_(discreteValues);
|
auto& [ptr, _] = factors()(discreteValues);
|
||||||
if (!ptr) return nullptr;
|
if (!ptr) return nullptr;
|
||||||
auto conditional = std::dynamic_pointer_cast<GaussianConditional>(ptr);
|
auto conditional = std::dynamic_pointer_cast<GaussianConditional>(ptr);
|
||||||
if (conditional)
|
if (conditional)
|
||||||
|
@ -176,18 +174,15 @@ bool HybridGaussianConditional::equals(const HybridFactor &lf,
|
||||||
const This *e = dynamic_cast<const This *>(&lf);
|
const This *e = dynamic_cast<const This *>(&lf);
|
||||||
if (e == nullptr) return false;
|
if (e == nullptr) return false;
|
||||||
|
|
||||||
// This will return false if either conditionals_ is empty or e->conditionals_
|
// Factors existence and scalar values are checked in BaseFactor::equals.
|
||||||
// is empty, but not if both are empty or both are not empty:
|
// Here we check additionally that the factors *are* conditionals and are equal.
|
||||||
if (conditionals_.empty() ^ e->conditionals_.empty()) return false;
|
auto compareFunc = [tol](const GaussianFactorValuePair& pair1,
|
||||||
|
const GaussianFactorValuePair& pair2) {
|
||||||
// Check the base and the factors:
|
auto c1 = std::dynamic_pointer_cast<GaussianConditional>(pair1.first),
|
||||||
return BaseFactor::equals(*e, tol) &&
|
c2 = std::dynamic_pointer_cast<GaussianConditional>(pair2.first);
|
||||||
conditionals_.equals(e->conditionals_,
|
return (!c1 && !c2) || (c1 && c2 && c1->equals(*c2, tol));
|
||||||
[tol](const GaussianConditional::shared_ptr &f1,
|
};
|
||||||
const GaussianConditional::shared_ptr &f2) {
|
return Base::equals(*e, tol) && factors().equals(e->factors(), compareFunc);
|
||||||
return (!f1 && !f2) ||
|
|
||||||
(f1 && f2 && f1->equals(*f2, tol));
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
|
@ -202,7 +197,7 @@ void HybridGaussianConditional::print(const std::string &s,
|
||||||
std::cout << std::endl
|
std::cout << std::endl
|
||||||
<< " logNormalizationConstant: " << -negLogConstant() << std::endl
|
<< " logNormalizationConstant: " << -negLogConstant() << std::endl
|
||||||
<< std::endl;
|
<< std::endl;
|
||||||
conditionals_.print(
|
conditionals().print(
|
||||||
"", [&](Key k) { return formatter(k); },
|
"", [&](Key k) { return formatter(k); },
|
||||||
[&](const GaussianConditional::shared_ptr &gf) -> std::string {
|
[&](const GaussianConditional::shared_ptr &gf) -> std::string {
|
||||||
RedirectCout rd;
|
RedirectCout rd;
|
||||||
|
@ -254,7 +249,7 @@ std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood(
|
||||||
const DiscreteKeys discreteParentKeys = discreteKeys();
|
const DiscreteKeys discreteParentKeys = discreteKeys();
|
||||||
const KeyVector continuousParentKeys = continuousParents();
|
const KeyVector continuousParentKeys = continuousParents();
|
||||||
const HybridGaussianFactor::FactorValuePairs likelihoods(
|
const HybridGaussianFactor::FactorValuePairs likelihoods(
|
||||||
conditionals_,
|
conditionals(),
|
||||||
[&](const GaussianConditional::shared_ptr &conditional)
|
[&](const GaussianConditional::shared_ptr &conditional)
|
||||||
-> GaussianFactorValuePair {
|
-> GaussianFactorValuePair {
|
||||||
const auto likelihood_m = conditional->likelihood(given);
|
const auto likelihood_m = conditional->likelihood(given);
|
||||||
|
@ -294,22 +289,30 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
|
||||||
return (max->evaluate(choices) == 0.0) ? nullptr : conditional;
|
return (max->evaluate(choices) == 0.0) ? nullptr : conditional;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto pruned_conditionals = conditionals_.apply(pruner);
|
auto pruned_conditionals = conditionals().apply(pruner);
|
||||||
return std::make_shared<HybridGaussianConditional>(discreteKeys(),
|
return std::make_shared<HybridGaussianConditional>(discreteKeys(),
|
||||||
pruned_conditionals);
|
pruned_conditionals);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
double HybridGaussianConditional::logProbability(
|
double HybridGaussianConditional::logProbability(
|
||||||
const HybridValues &values) const {
|
const HybridValues& values) const {
|
||||||
auto conditional = conditionals_(values.discrete());
|
auto [factor, _] = factors()(values.discrete());
|
||||||
|
if (auto conditional = std::dynamic_pointer_cast<GaussianConditional>(factor))
|
||||||
return conditional->logProbability(values.continuous());
|
return conditional->logProbability(values.continuous());
|
||||||
|
else
|
||||||
|
throw std::logic_error(
|
||||||
|
"A HybridGaussianConditional unexpectedly contained a non-conditional");
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
double HybridGaussianConditional::evaluate(const HybridValues &values) const {
|
double HybridGaussianConditional::evaluate(const HybridValues& values) const {
|
||||||
auto conditional = conditionals_(values.discrete());
|
auto [factor, _] = factors()(values.discrete());
|
||||||
|
if (auto conditional = std::dynamic_pointer_cast<GaussianConditional>(factor))
|
||||||
return conditional->evaluate(values.continuous());
|
return conditional->evaluate(values.continuous());
|
||||||
|
else
|
||||||
|
throw std::logic_error(
|
||||||
|
"A HybridGaussianConditional unexpectedly contained a non-conditional");
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -64,8 +64,6 @@ class GTSAM_EXPORT HybridGaussianConditional
|
||||||
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
|
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Conditionals conditionals_; ///< a decision tree of Gaussian conditionals.
|
|
||||||
|
|
||||||
///< Negative-log of the normalization constant (log(\sqrt(|2πΣ|))).
|
///< Negative-log of the normalization constant (log(\sqrt(|2πΣ|))).
|
||||||
///< Take advantage of the neg-log space so everything is a minimization
|
///< Take advantage of the neg-log space so everything is a minimization
|
||||||
double negLogConstant_;
|
double negLogConstant_;
|
||||||
|
@ -192,8 +190,8 @@ class GTSAM_EXPORT HybridGaussianConditional
|
||||||
std::shared_ptr<HybridGaussianFactor> likelihood(
|
std::shared_ptr<HybridGaussianFactor> likelihood(
|
||||||
const VectorValues &given) const;
|
const VectorValues &given) const;
|
||||||
|
|
||||||
/// Getter for the underlying Conditionals DecisionTree
|
/// Get Conditionals DecisionTree (dynamic cast from factors)
|
||||||
const Conditionals &conditionals() const;
|
const Conditionals conditionals() const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Compute the logProbability of this hybrid Gaussian conditional.
|
* @brief Compute the logProbability of this hybrid Gaussian conditional.
|
||||||
|
@ -241,7 +239,6 @@ class GTSAM_EXPORT HybridGaussianConditional
|
||||||
void serialize(Archive &ar, const unsigned int /*version*/) {
|
void serialize(Archive &ar, const unsigned int /*version*/) {
|
||||||
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
|
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
|
||||||
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
|
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
|
||||||
ar &BOOST_SERIALIZATION_NVP(conditionals_);
|
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
};
|
};
|
||||||
|
|
Loading…
Reference in New Issue