formatting and comments

release/4.3a0
Varun Agrawal 2024-10-23 11:37:55 -04:00
parent 977ac0d762
commit 4c74ec113a
2 changed files with 15 additions and 13 deletions

View File

@ -25,12 +25,12 @@
#include <gtsam/hybrid/HybridValues.h> #include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Conditional-inst.h> #include <gtsam/inference/Conditional-inst.h>
#include <gtsam/linear/GaussianBayesNet.h> #include <gtsam/linear/GaussianBayesNet.h>
#include <gtsam/linear/GaussianConditional.h>
#include <gtsam/linear/GaussianFactorGraph.h> #include <gtsam/linear/GaussianFactorGraph.h>
#include <gtsam/linear/JacobianFactor.h> #include <gtsam/linear/JacobianFactor.h>
#include <cstddef> #include <cstddef>
#include <memory> #include <memory>
#include "gtsam/linear/GaussianConditional.h"
namespace gtsam { namespace gtsam {
/* *******************************************************************************/ /* *******************************************************************************/
@ -162,7 +162,7 @@ HybridGaussianConditional::HybridGaussianConditional(
/* *******************************************************************************/ /* *******************************************************************************/
const HybridGaussianConditional::Conditionals const HybridGaussianConditional::Conditionals
HybridGaussianConditional::conditionals() const { HybridGaussianConditional::conditionals() const {
return Conditionals(factors(), [](const auto& pair) { return Conditionals(factors(), [](auto &&pair) {
return std::dynamic_pointer_cast<GaussianConditional>(pair.first); return std::dynamic_pointer_cast<GaussianConditional>(pair.first);
}); });
} }
@ -170,7 +170,7 @@ HybridGaussianConditional::conditionals() const {
/* *******************************************************************************/ /* *******************************************************************************/
size_t HybridGaussianConditional::nrComponents() const { size_t HybridGaussianConditional::nrComponents() const {
size_t total = 0; size_t total = 0;
factors().visit([&total](const auto& node) { factors().visit([&total](auto &&node) {
if (node.first) total += 1; if (node.first) total += 1;
}); });
return total; return total;
@ -178,8 +178,8 @@ size_t HybridGaussianConditional::nrComponents() const {
/* *******************************************************************************/ /* *******************************************************************************/
GaussianConditional::shared_ptr HybridGaussianConditional::choose( GaussianConditional::shared_ptr HybridGaussianConditional::choose(
const DiscreteValues& discreteValues) const { const DiscreteValues &discreteValues) const {
auto& [ptr, _] = factors()(discreteValues); auto &[ptr, _] = factors()(discreteValues);
if (!ptr) return nullptr; if (!ptr) return nullptr;
auto conditional = std::dynamic_pointer_cast<GaussianConditional>(ptr); auto conditional = std::dynamic_pointer_cast<GaussianConditional>(ptr);
if (conditional) if (conditional)
@ -196,9 +196,10 @@ bool HybridGaussianConditional::equals(const HybridFactor &lf,
if (e == nullptr) return false; if (e == nullptr) return false;
// Factors existence and scalar values are checked in BaseFactor::equals. // Factors existence and scalar values are checked in BaseFactor::equals.
// Here we check additionally that the factors *are* conditionals and are equal. // Here we check additionally that the factors *are* conditionals
auto compareFunc = [tol](const GaussianFactorValuePair& pair1, // and are equal.
const GaussianFactorValuePair& pair2) { auto compareFunc = [tol](const GaussianFactorValuePair &pair1,
const GaussianFactorValuePair &pair2) {
auto c1 = std::dynamic_pointer_cast<GaussianConditional>(pair1.first), auto c1 = std::dynamic_pointer_cast<GaussianConditional>(pair1.first),
c2 = std::dynamic_pointer_cast<GaussianConditional>(pair2.first); c2 = std::dynamic_pointer_cast<GaussianConditional>(pair2.first);
return (!c1 && !c2) || (c1 && c2 && c1->equals(*c2, tol)); return (!c1 && !c2) || (c1 && c2 && c1->equals(*c2, tol));
@ -222,7 +223,8 @@ void HybridGaussianConditional::print(const std::string &s,
"", [&](Key k) { return formatter(k); }, "", [&](Key k) { return formatter(k); },
[&](const GaussianFactorValuePair &pair) -> std::string { [&](const GaussianFactorValuePair &pair) -> std::string {
RedirectCout rd; RedirectCout rd;
if (auto gf = std::dynamic_pointer_cast<GaussianConditional>(pair.first)) { if (auto gf =
std::dynamic_pointer_cast<GaussianConditional>(pair.first)) {
gf->print("", formatter); gf->print("", formatter);
return rd.str(); return rd.str();
} else { } else {
@ -323,7 +325,7 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
/* *******************************************************************************/ /* *******************************************************************************/
double HybridGaussianConditional::logProbability( double HybridGaussianConditional::logProbability(
const HybridValues& values) const { const HybridValues &values) const {
auto [factor, _] = factors()(values.discrete()); auto [factor, _] = factors()(values.discrete());
if (auto conditional = std::dynamic_pointer_cast<GaussianConditional>(factor)) if (auto conditional = std::dynamic_pointer_cast<GaussianConditional>(factor))
return conditional->logProbability(values.continuous()); return conditional->logProbability(values.continuous());
@ -333,7 +335,7 @@ double HybridGaussianConditional::logProbability(
} }
/* *******************************************************************************/ /* *******************************************************************************/
double HybridGaussianConditional::evaluate(const HybridValues& values) const { double HybridGaussianConditional::evaluate(const HybridValues &values) const {
auto [factor, _] = factors()(values.discrete()); auto [factor, _] = factors()(values.discrete());
if (auto conditional = std::dynamic_pointer_cast<GaussianConditional>(factor)) if (auto conditional = std::dynamic_pointer_cast<GaussianConditional>(factor))
return conditional->evaluate(values.continuous()); return conditional->evaluate(values.continuous());

View File

@ -20,6 +20,7 @@
#include <gtsam/base/utilities.h> #include <gtsam/base/utilities.h>
#include <gtsam/discrete/Assignment.h> #include <gtsam/discrete/Assignment.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteEliminationTree.h> #include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteJunctionTree.h> #include <gtsam/discrete/DiscreteJunctionTree.h>
@ -48,8 +49,6 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "gtsam/discrete/DecisionTreeFactor.h"
namespace gtsam { namespace gtsam {
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph: /// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
@ -367,6 +366,7 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
// any difference in noise models used. // any difference in noise models used.
HybridGaussianProductFactor productFactor = collectProductFactor(); HybridGaussianProductFactor productFactor = collectProductFactor();
// Check if a factor is null
auto isNull = [](const GaussianFactor::shared_ptr &ptr) { return !ptr; }; auto isNull = [](const GaussianFactor::shared_ptr &ptr) { return !ptr; };
// This is the elimination method on the leaf nodes // This is the elimination method on the leaf nodes