add type info

release/4.3a0
Varun Agrawal 2024-10-08 12:02:42 -04:00
parent 8e85b68863
commit f39f678c14
3 changed files with 20 additions and 14 deletions

View File

@ -183,10 +183,12 @@ bool HybridGaussianConditional::equals(const HybridFactor &lf,
// Check the base and the factors:
return BaseFactor::equals(*e, tol) &&
conditionals_.equals(
e->conditionals_, [tol](const auto &f1, const auto &f2) {
return (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol));
});
conditionals_.equals(e->conditionals_,
[tol](const GaussianConditional::shared_ptr &f1,
const GaussianConditional::shared_ptr &f2) {
return (!f1 && !f2) ||
(f1 && f2 && f1->equals(*f2, tol));
});
}
/* *******************************************************************************/

View File

@ -51,7 +51,7 @@ struct HybridGaussianFactor::ConstructorHelper {
// Build the FactorValuePairs DecisionTree
pairs = FactorValuePairs(
DecisionTree<Key, GaussianFactor::shared_ptr>(discreteKeys, factors),
[](const auto& f) {
[](const sharedFactor& f) {
return std::pair{f,
f ? 0.0 : std::numeric_limits<double>::infinity()};
});
@ -63,7 +63,7 @@ struct HybridGaussianFactor::ConstructorHelper {
const std::vector<GaussianFactorValuePair>& factorPairs)
: discreteKeys({discreteKey}) {
// Extract continuous keys from the first non-null factor
for (const auto& pair : factorPairs) {
for (const GaussianFactorValuePair& pair : factorPairs) {
if (pair.first && continuousKeys.empty()) {
continuousKeys = pair.first->keys();
break;
@ -121,7 +121,8 @@ bool HybridGaussianFactor::equals(const HybridFactor& lf, double tol) const {
if (factors_.empty() ^ e->factors_.empty()) return false;
// Check the base and the factors:
auto compareFunc = [tol](const auto& pair1, const auto& pair2) {
auto compareFunc = [tol](const GaussianFactorValuePair& pair1,
const GaussianFactorValuePair& pair2) {
auto f1 = pair1.first, f2 = pair2.first;
bool match = (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol));
return match && gtsam::equal(pair1.second, pair2.second, tol);
@ -140,7 +141,7 @@ void HybridGaussianFactor::print(const std::string& s,
} else {
factors_.print(
"", [&](Key k) { return formatter(k); },
[&](const auto& pair) -> std::string {
[&](const GaussianFactorValuePair& pair) -> std::string {
RedirectCout rd;
std::cout << ":\n";
if (pair.first) {
@ -168,7 +169,8 @@ HybridGaussianProductFactor HybridGaussianFactor::asProductFactor() const {
// - Each leaf converted to a GaussianFactorGraph with just the factor and its
// scalar.
return {{factors_,
[](const auto& pair) -> std::pair<GaussianFactorGraph, double> {
[](const GaussianFactorValuePair& pair)
-> std::pair<GaussianFactorGraph, double> {
return {GaussianFactorGraph{pair.first}, pair.second};
}}};
}
@ -177,7 +179,7 @@ HybridGaussianProductFactor HybridGaussianFactor::asProductFactor() const {
AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
const VectorValues& continuousValues) const {
// functor to convert from sharedFactor to double error value.
auto errorFunc = [&continuousValues](const auto& pair) {
auto errorFunc = [&continuousValues](const GaussianFactorValuePair& pair) {
return pair.first ? pair.first->error(continuousValues) + pair.second
: std::numeric_limits<double>::infinity();
};
@ -188,7 +190,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
/* *******************************************************************************/
double HybridGaussianFactor::error(const HybridValues& values) const {
// Directly index to get the component, no need to build the whole tree.
const auto pair = factors_(values.discrete());
const GaussianFactorValuePair pair = factors_(values.discrete());
return pair.first ? pair.first->error(values.continuous()) + pair.second
: std::numeric_limits<double>::infinity();
}

View File

@ -57,7 +57,8 @@ using std::dynamic_pointer_cast;
using OrphanWrapper = BayesTreeOrphanWrapper<HybridBayesTree::Clique>;
using Result =
std::pair<std::shared_ptr<GaussianConditional>, GaussianFactor::shared_ptr>;
using ResultTree = DecisionTree<Key, std::pair<Result, double>>;
using ResultValuePair = std::pair<Result, double>;
using ResultTree = DecisionTree<Key, ResultValuePair>;
static const VectorValues kEmpty;
@ -305,7 +306,7 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
const ResultTree &eliminationResults,
const DiscreteKeys &discreteSeparator) {
// Correct for the normalization constant used up by the conditional
auto correct = [&](const auto &pair) -> GaussianFactorValuePair {
auto correct = [&](const ResultValuePair &pair) -> GaussianFactorValuePair {
const auto &[conditional, factor] = pair.first;
const double scalar = pair.second;
if (conditional && factor) {
@ -384,7 +385,8 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
// Create the HybridGaussianConditional from the conditionals
HybridGaussianConditional::Conditionals conditionals(
eliminationResults, [](const auto &pair) { return pair.first.first; });
eliminationResults,
[](const ResultValuePair &pair) { return pair.first.first; });
auto hybridGaussian = std::make_shared<HybridGaussianConditional>(
discreteSeparator, conditionals);