add type info
parent
8e85b68863
commit
f39f678c14
|
@ -183,10 +183,12 @@ bool HybridGaussianConditional::equals(const HybridFactor &lf,
|
|||
|
||||
// Check the base and the factors:
|
||||
return BaseFactor::equals(*e, tol) &&
|
||||
conditionals_.equals(
|
||||
e->conditionals_, [tol](const auto &f1, const auto &f2) {
|
||||
return (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol));
|
||||
});
|
||||
conditionals_.equals(e->conditionals_,
|
||||
[tol](const GaussianConditional::shared_ptr &f1,
|
||||
const GaussianConditional::shared_ptr &f2) {
|
||||
return (!f1 && !f2) ||
|
||||
(f1 && f2 && f1->equals(*f2, tol));
|
||||
});
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
|
|
|
@ -51,7 +51,7 @@ struct HybridGaussianFactor::ConstructorHelper {
|
|||
// Build the FactorValuePairs DecisionTree
|
||||
pairs = FactorValuePairs(
|
||||
DecisionTree<Key, GaussianFactor::shared_ptr>(discreteKeys, factors),
|
||||
[](const auto& f) {
|
||||
[](const sharedFactor& f) {
|
||||
return std::pair{f,
|
||||
f ? 0.0 : std::numeric_limits<double>::infinity()};
|
||||
});
|
||||
|
@ -63,7 +63,7 @@ struct HybridGaussianFactor::ConstructorHelper {
|
|||
const std::vector<GaussianFactorValuePair>& factorPairs)
|
||||
: discreteKeys({discreteKey}) {
|
||||
// Extract continuous keys from the first non-null factor
|
||||
for (const auto& pair : factorPairs) {
|
||||
for (const GaussianFactorValuePair& pair : factorPairs) {
|
||||
if (pair.first && continuousKeys.empty()) {
|
||||
continuousKeys = pair.first->keys();
|
||||
break;
|
||||
|
@ -121,7 +121,8 @@ bool HybridGaussianFactor::equals(const HybridFactor& lf, double tol) const {
|
|||
if (factors_.empty() ^ e->factors_.empty()) return false;
|
||||
|
||||
// Check the base and the factors:
|
||||
auto compareFunc = [tol](const auto& pair1, const auto& pair2) {
|
||||
auto compareFunc = [tol](const GaussianFactorValuePair& pair1,
|
||||
const GaussianFactorValuePair& pair2) {
|
||||
auto f1 = pair1.first, f2 = pair2.first;
|
||||
bool match = (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol));
|
||||
return match && gtsam::equal(pair1.second, pair2.second, tol);
|
||||
|
@ -140,7 +141,7 @@ void HybridGaussianFactor::print(const std::string& s,
|
|||
} else {
|
||||
factors_.print(
|
||||
"", [&](Key k) { return formatter(k); },
|
||||
[&](const auto& pair) -> std::string {
|
||||
[&](const GaussianFactorValuePair& pair) -> std::string {
|
||||
RedirectCout rd;
|
||||
std::cout << ":\n";
|
||||
if (pair.first) {
|
||||
|
@ -168,7 +169,8 @@ HybridGaussianProductFactor HybridGaussianFactor::asProductFactor() const {
|
|||
// - Each leaf converted to a GaussianFactorGraph with just the factor and its
|
||||
// scalar.
|
||||
return {{factors_,
|
||||
[](const auto& pair) -> std::pair<GaussianFactorGraph, double> {
|
||||
[](const GaussianFactorValuePair& pair)
|
||||
-> std::pair<GaussianFactorGraph, double> {
|
||||
return {GaussianFactorGraph{pair.first}, pair.second};
|
||||
}}};
|
||||
}
|
||||
|
@ -177,7 +179,7 @@ HybridGaussianProductFactor HybridGaussianFactor::asProductFactor() const {
|
|||
AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
|
||||
const VectorValues& continuousValues) const {
|
||||
// functor to convert from sharedFactor to double error value.
|
||||
auto errorFunc = [&continuousValues](const auto& pair) {
|
||||
auto errorFunc = [&continuousValues](const GaussianFactorValuePair& pair) {
|
||||
return pair.first ? pair.first->error(continuousValues) + pair.second
|
||||
: std::numeric_limits<double>::infinity();
|
||||
};
|
||||
|
@ -188,7 +190,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
|
|||
/* *******************************************************************************/
|
||||
double HybridGaussianFactor::error(const HybridValues& values) const {
|
||||
// Directly index to get the component, no need to build the whole tree.
|
||||
const auto pair = factors_(values.discrete());
|
||||
const GaussianFactorValuePair pair = factors_(values.discrete());
|
||||
return pair.first ? pair.first->error(values.continuous()) + pair.second
|
||||
: std::numeric_limits<double>::infinity();
|
||||
}
|
||||
|
|
|
@ -57,7 +57,8 @@ using std::dynamic_pointer_cast;
|
|||
using OrphanWrapper = BayesTreeOrphanWrapper<HybridBayesTree::Clique>;
|
||||
using Result =
|
||||
std::pair<std::shared_ptr<GaussianConditional>, GaussianFactor::shared_ptr>;
|
||||
using ResultTree = DecisionTree<Key, std::pair<Result, double>>;
|
||||
using ResultValuePair = std::pair<Result, double>;
|
||||
using ResultTree = DecisionTree<Key, ResultValuePair>;
|
||||
|
||||
static const VectorValues kEmpty;
|
||||
|
||||
|
@ -305,7 +306,7 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
|
|||
const ResultTree &eliminationResults,
|
||||
const DiscreteKeys &discreteSeparator) {
|
||||
// Correct for the normalization constant used up by the conditional
|
||||
auto correct = [&](const auto &pair) -> GaussianFactorValuePair {
|
||||
auto correct = [&](const ResultValuePair &pair) -> GaussianFactorValuePair {
|
||||
const auto &[conditional, factor] = pair.first;
|
||||
const double scalar = pair.second;
|
||||
if (conditional && factor) {
|
||||
|
@ -384,7 +385,8 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
|
|||
|
||||
// Create the HybridGaussianConditional from the conditionals
|
||||
HybridGaussianConditional::Conditionals conditionals(
|
||||
eliminationResults, [](const auto &pair) { return pair.first.first; });
|
||||
eliminationResults,
|
||||
[](const ResultValuePair &pair) { return pair.first.first; });
|
||||
auto hybridGaussian = std::make_shared<HybridGaussianConditional>(
|
||||
discreteSeparator, conditionals);
|
||||
|
||||
|
|
Loading…
Reference in New Issue