add type info
parent
8e85b68863
commit
f39f678c14
|
@ -183,9 +183,11 @@ bool HybridGaussianConditional::equals(const HybridFactor &lf,
|
||||||
|
|
||||||
// Check the base and the factors:
|
// Check the base and the factors:
|
||||||
return BaseFactor::equals(*e, tol) &&
|
return BaseFactor::equals(*e, tol) &&
|
||||||
conditionals_.equals(
|
conditionals_.equals(e->conditionals_,
|
||||||
e->conditionals_, [tol](const auto &f1, const auto &f2) {
|
[tol](const GaussianConditional::shared_ptr &f1,
|
||||||
return (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol));
|
const GaussianConditional::shared_ptr &f2) {
|
||||||
|
return (!f1 && !f2) ||
|
||||||
|
(f1 && f2 && f1->equals(*f2, tol));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -51,7 +51,7 @@ struct HybridGaussianFactor::ConstructorHelper {
|
||||||
// Build the FactorValuePairs DecisionTree
|
// Build the FactorValuePairs DecisionTree
|
||||||
pairs = FactorValuePairs(
|
pairs = FactorValuePairs(
|
||||||
DecisionTree<Key, GaussianFactor::shared_ptr>(discreteKeys, factors),
|
DecisionTree<Key, GaussianFactor::shared_ptr>(discreteKeys, factors),
|
||||||
[](const auto& f) {
|
[](const sharedFactor& f) {
|
||||||
return std::pair{f,
|
return std::pair{f,
|
||||||
f ? 0.0 : std::numeric_limits<double>::infinity()};
|
f ? 0.0 : std::numeric_limits<double>::infinity()};
|
||||||
});
|
});
|
||||||
|
@ -63,7 +63,7 @@ struct HybridGaussianFactor::ConstructorHelper {
|
||||||
const std::vector<GaussianFactorValuePair>& factorPairs)
|
const std::vector<GaussianFactorValuePair>& factorPairs)
|
||||||
: discreteKeys({discreteKey}) {
|
: discreteKeys({discreteKey}) {
|
||||||
// Extract continuous keys from the first non-null factor
|
// Extract continuous keys from the first non-null factor
|
||||||
for (const auto& pair : factorPairs) {
|
for (const GaussianFactorValuePair& pair : factorPairs) {
|
||||||
if (pair.first && continuousKeys.empty()) {
|
if (pair.first && continuousKeys.empty()) {
|
||||||
continuousKeys = pair.first->keys();
|
continuousKeys = pair.first->keys();
|
||||||
break;
|
break;
|
||||||
|
@ -121,7 +121,8 @@ bool HybridGaussianFactor::equals(const HybridFactor& lf, double tol) const {
|
||||||
if (factors_.empty() ^ e->factors_.empty()) return false;
|
if (factors_.empty() ^ e->factors_.empty()) return false;
|
||||||
|
|
||||||
// Check the base and the factors:
|
// Check the base and the factors:
|
||||||
auto compareFunc = [tol](const auto& pair1, const auto& pair2) {
|
auto compareFunc = [tol](const GaussianFactorValuePair& pair1,
|
||||||
|
const GaussianFactorValuePair& pair2) {
|
||||||
auto f1 = pair1.first, f2 = pair2.first;
|
auto f1 = pair1.first, f2 = pair2.first;
|
||||||
bool match = (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol));
|
bool match = (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol));
|
||||||
return match && gtsam::equal(pair1.second, pair2.second, tol);
|
return match && gtsam::equal(pair1.second, pair2.second, tol);
|
||||||
|
@ -140,7 +141,7 @@ void HybridGaussianFactor::print(const std::string& s,
|
||||||
} else {
|
} else {
|
||||||
factors_.print(
|
factors_.print(
|
||||||
"", [&](Key k) { return formatter(k); },
|
"", [&](Key k) { return formatter(k); },
|
||||||
[&](const auto& pair) -> std::string {
|
[&](const GaussianFactorValuePair& pair) -> std::string {
|
||||||
RedirectCout rd;
|
RedirectCout rd;
|
||||||
std::cout << ":\n";
|
std::cout << ":\n";
|
||||||
if (pair.first) {
|
if (pair.first) {
|
||||||
|
@ -168,7 +169,8 @@ HybridGaussianProductFactor HybridGaussianFactor::asProductFactor() const {
|
||||||
// - Each leaf converted to a GaussianFactorGraph with just the factor and its
|
// - Each leaf converted to a GaussianFactorGraph with just the factor and its
|
||||||
// scalar.
|
// scalar.
|
||||||
return {{factors_,
|
return {{factors_,
|
||||||
[](const auto& pair) -> std::pair<GaussianFactorGraph, double> {
|
[](const GaussianFactorValuePair& pair)
|
||||||
|
-> std::pair<GaussianFactorGraph, double> {
|
||||||
return {GaussianFactorGraph{pair.first}, pair.second};
|
return {GaussianFactorGraph{pair.first}, pair.second};
|
||||||
}}};
|
}}};
|
||||||
}
|
}
|
||||||
|
@ -177,7 +179,7 @@ HybridGaussianProductFactor HybridGaussianFactor::asProductFactor() const {
|
||||||
AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
|
AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
|
||||||
const VectorValues& continuousValues) const {
|
const VectorValues& continuousValues) const {
|
||||||
// functor to convert from sharedFactor to double error value.
|
// functor to convert from sharedFactor to double error value.
|
||||||
auto errorFunc = [&continuousValues](const auto& pair) {
|
auto errorFunc = [&continuousValues](const GaussianFactorValuePair& pair) {
|
||||||
return pair.first ? pair.first->error(continuousValues) + pair.second
|
return pair.first ? pair.first->error(continuousValues) + pair.second
|
||||||
: std::numeric_limits<double>::infinity();
|
: std::numeric_limits<double>::infinity();
|
||||||
};
|
};
|
||||||
|
@ -188,7 +190,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
double HybridGaussianFactor::error(const HybridValues& values) const {
|
double HybridGaussianFactor::error(const HybridValues& values) const {
|
||||||
// Directly index to get the component, no need to build the whole tree.
|
// Directly index to get the component, no need to build the whole tree.
|
||||||
const auto pair = factors_(values.discrete());
|
const GaussianFactorValuePair pair = factors_(values.discrete());
|
||||||
return pair.first ? pair.first->error(values.continuous()) + pair.second
|
return pair.first ? pair.first->error(values.continuous()) + pair.second
|
||||||
: std::numeric_limits<double>::infinity();
|
: std::numeric_limits<double>::infinity();
|
||||||
}
|
}
|
||||||
|
|
|
@ -57,7 +57,8 @@ using std::dynamic_pointer_cast;
|
||||||
using OrphanWrapper = BayesTreeOrphanWrapper<HybridBayesTree::Clique>;
|
using OrphanWrapper = BayesTreeOrphanWrapper<HybridBayesTree::Clique>;
|
||||||
using Result =
|
using Result =
|
||||||
std::pair<std::shared_ptr<GaussianConditional>, GaussianFactor::shared_ptr>;
|
std::pair<std::shared_ptr<GaussianConditional>, GaussianFactor::shared_ptr>;
|
||||||
using ResultTree = DecisionTree<Key, std::pair<Result, double>>;
|
using ResultValuePair = std::pair<Result, double>;
|
||||||
|
using ResultTree = DecisionTree<Key, ResultValuePair>;
|
||||||
|
|
||||||
static const VectorValues kEmpty;
|
static const VectorValues kEmpty;
|
||||||
|
|
||||||
|
@ -305,7 +306,7 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
|
||||||
const ResultTree &eliminationResults,
|
const ResultTree &eliminationResults,
|
||||||
const DiscreteKeys &discreteSeparator) {
|
const DiscreteKeys &discreteSeparator) {
|
||||||
// Correct for the normalization constant used up by the conditional
|
// Correct for the normalization constant used up by the conditional
|
||||||
auto correct = [&](const auto &pair) -> GaussianFactorValuePair {
|
auto correct = [&](const ResultValuePair &pair) -> GaussianFactorValuePair {
|
||||||
const auto &[conditional, factor] = pair.first;
|
const auto &[conditional, factor] = pair.first;
|
||||||
const double scalar = pair.second;
|
const double scalar = pair.second;
|
||||||
if (conditional && factor) {
|
if (conditional && factor) {
|
||||||
|
@ -384,7 +385,8 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
|
||||||
|
|
||||||
// Create the HybridGaussianConditional from the conditionals
|
// Create the HybridGaussianConditional from the conditionals
|
||||||
HybridGaussianConditional::Conditionals conditionals(
|
HybridGaussianConditional::Conditionals conditionals(
|
||||||
eliminationResults, [](const auto &pair) { return pair.first.first; });
|
eliminationResults,
|
||||||
|
[](const ResultValuePair &pair) { return pair.first.first; });
|
||||||
auto hybridGaussian = std::make_shared<HybridGaussianConditional>(
|
auto hybridGaussian = std::make_shared<HybridGaussianConditional>(
|
||||||
discreteSeparator, conditionals);
|
discreteSeparator, conditionals);
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue