add type info

release/4.3a0
Varun Agrawal 2024-10-08 12:02:42 -04:00
parent 8e85b68863
commit f39f678c14
3 changed files with 20 additions and 14 deletions

View File

@ -183,9 +183,11 @@ bool HybridGaussianConditional::equals(const HybridFactor &lf,
// Check the base and the factors: // Check the base and the factors:
return BaseFactor::equals(*e, tol) && return BaseFactor::equals(*e, tol) &&
conditionals_.equals( conditionals_.equals(e->conditionals_,
e->conditionals_, [tol](const auto &f1, const auto &f2) { [tol](const GaussianConditional::shared_ptr &f1,
return (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol)); const GaussianConditional::shared_ptr &f2) {
return (!f1 && !f2) ||
(f1 && f2 && f1->equals(*f2, tol));
}); });
} }

View File

@ -51,7 +51,7 @@ struct HybridGaussianFactor::ConstructorHelper {
// Build the FactorValuePairs DecisionTree // Build the FactorValuePairs DecisionTree
pairs = FactorValuePairs( pairs = FactorValuePairs(
DecisionTree<Key, GaussianFactor::shared_ptr>(discreteKeys, factors), DecisionTree<Key, GaussianFactor::shared_ptr>(discreteKeys, factors),
[](const auto& f) { [](const sharedFactor& f) {
return std::pair{f, return std::pair{f,
f ? 0.0 : std::numeric_limits<double>::infinity()}; f ? 0.0 : std::numeric_limits<double>::infinity()};
}); });
@ -63,7 +63,7 @@ struct HybridGaussianFactor::ConstructorHelper {
const std::vector<GaussianFactorValuePair>& factorPairs) const std::vector<GaussianFactorValuePair>& factorPairs)
: discreteKeys({discreteKey}) { : discreteKeys({discreteKey}) {
// Extract continuous keys from the first non-null factor // Extract continuous keys from the first non-null factor
for (const auto& pair : factorPairs) { for (const GaussianFactorValuePair& pair : factorPairs) {
if (pair.first && continuousKeys.empty()) { if (pair.first && continuousKeys.empty()) {
continuousKeys = pair.first->keys(); continuousKeys = pair.first->keys();
break; break;
@ -121,7 +121,8 @@ bool HybridGaussianFactor::equals(const HybridFactor& lf, double tol) const {
if (factors_.empty() ^ e->factors_.empty()) return false; if (factors_.empty() ^ e->factors_.empty()) return false;
// Check the base and the factors: // Check the base and the factors:
auto compareFunc = [tol](const auto& pair1, const auto& pair2) { auto compareFunc = [tol](const GaussianFactorValuePair& pair1,
const GaussianFactorValuePair& pair2) {
auto f1 = pair1.first, f2 = pair2.first; auto f1 = pair1.first, f2 = pair2.first;
bool match = (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol)); bool match = (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol));
return match && gtsam::equal(pair1.second, pair2.second, tol); return match && gtsam::equal(pair1.second, pair2.second, tol);
@ -140,7 +141,7 @@ void HybridGaussianFactor::print(const std::string& s,
} else { } else {
factors_.print( factors_.print(
"", [&](Key k) { return formatter(k); }, "", [&](Key k) { return formatter(k); },
[&](const auto& pair) -> std::string { [&](const GaussianFactorValuePair& pair) -> std::string {
RedirectCout rd; RedirectCout rd;
std::cout << ":\n"; std::cout << ":\n";
if (pair.first) { if (pair.first) {
@ -168,7 +169,8 @@ HybridGaussianProductFactor HybridGaussianFactor::asProductFactor() const {
// - Each leaf converted to a GaussianFactorGraph with just the factor and its // - Each leaf converted to a GaussianFactorGraph with just the factor and its
// scalar. // scalar.
return {{factors_, return {{factors_,
[](const auto& pair) -> std::pair<GaussianFactorGraph, double> { [](const GaussianFactorValuePair& pair)
-> std::pair<GaussianFactorGraph, double> {
return {GaussianFactorGraph{pair.first}, pair.second}; return {GaussianFactorGraph{pair.first}, pair.second};
}}}; }}};
} }
@ -177,7 +179,7 @@ HybridGaussianProductFactor HybridGaussianFactor::asProductFactor() const {
AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree( AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
const VectorValues& continuousValues) const { const VectorValues& continuousValues) const {
// functor to convert from sharedFactor to double error value. // functor to convert from sharedFactor to double error value.
auto errorFunc = [&continuousValues](const auto& pair) { auto errorFunc = [&continuousValues](const GaussianFactorValuePair& pair) {
return pair.first ? pair.first->error(continuousValues) + pair.second return pair.first ? pair.first->error(continuousValues) + pair.second
: std::numeric_limits<double>::infinity(); : std::numeric_limits<double>::infinity();
}; };
@ -188,7 +190,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
/* *******************************************************************************/ /* *******************************************************************************/
double HybridGaussianFactor::error(const HybridValues& values) const { double HybridGaussianFactor::error(const HybridValues& values) const {
// Directly index to get the component, no need to build the whole tree. // Directly index to get the component, no need to build the whole tree.
const auto pair = factors_(values.discrete()); const GaussianFactorValuePair pair = factors_(values.discrete());
return pair.first ? pair.first->error(values.continuous()) + pair.second return pair.first ? pair.first->error(values.continuous()) + pair.second
: std::numeric_limits<double>::infinity(); : std::numeric_limits<double>::infinity();
} }

View File

@ -57,7 +57,8 @@ using std::dynamic_pointer_cast;
using OrphanWrapper = BayesTreeOrphanWrapper<HybridBayesTree::Clique>; using OrphanWrapper = BayesTreeOrphanWrapper<HybridBayesTree::Clique>;
using Result = using Result =
std::pair<std::shared_ptr<GaussianConditional>, GaussianFactor::shared_ptr>; std::pair<std::shared_ptr<GaussianConditional>, GaussianFactor::shared_ptr>;
using ResultTree = DecisionTree<Key, std::pair<Result, double>>; using ResultValuePair = std::pair<Result, double>;
using ResultTree = DecisionTree<Key, ResultValuePair>;
static const VectorValues kEmpty; static const VectorValues kEmpty;
@ -305,7 +306,7 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
const ResultTree &eliminationResults, const ResultTree &eliminationResults,
const DiscreteKeys &discreteSeparator) { const DiscreteKeys &discreteSeparator) {
// Correct for the normalization constant used up by the conditional // Correct for the normalization constant used up by the conditional
auto correct = [&](const auto &pair) -> GaussianFactorValuePair { auto correct = [&](const ResultValuePair &pair) -> GaussianFactorValuePair {
const auto &[conditional, factor] = pair.first; const auto &[conditional, factor] = pair.first;
const double scalar = pair.second; const double scalar = pair.second;
if (conditional && factor) { if (conditional && factor) {
@ -384,7 +385,8 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
// Create the HybridGaussianConditional from the conditionals // Create the HybridGaussianConditional from the conditionals
HybridGaussianConditional::Conditionals conditionals( HybridGaussianConditional::Conditionals conditionals(
eliminationResults, [](const auto &pair) { return pair.first.first; }); eliminationResults,
[](const ResultValuePair &pair) { return pair.first.first; });
auto hybridGaussian = std::make_shared<HybridGaussianConditional>( auto hybridGaussian = std::make_shared<HybridGaussianConditional>(
discreteSeparator, conditionals); discreteSeparator, conditionals);