diff --git a/gtsam/hybrid/HybridFactorGraph.cpp b/gtsam/hybrid/HybridFactorGraph.cpp index 235ffc87f..f7b96f694 100644 --- a/gtsam/hybrid/HybridFactorGraph.cpp +++ b/gtsam/hybrid/HybridFactorGraph.cpp @@ -17,7 +17,6 @@ * @date January, 2023 */ -#include #include namespace gtsam { @@ -26,7 +25,7 @@ namespace gtsam { std::set HybridFactorGraph::discreteKeys() const { std::set keys; for (auto& factor : factors_) { - if (auto p = std::dynamic_pointer_cast(factor)) { + if (auto p = std::dynamic_pointer_cast(factor)) { for (const DiscreteKey& key : p->discreteKeys()) { keys.insert(key); } diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index d2ea3d5ef..fb4b69aaf 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -48,8 +48,6 @@ #include #include -// #define HYBRID_TIMING - namespace gtsam { /// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph: @@ -120,7 +118,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const { // TODO(dellaert): in C++20, we can use std::visit. continue; } - } else if (dynamic_pointer_cast(f)) { + } else if (dynamic_pointer_cast(f)) { // Don't do anything for discrete-only factors // since we want to eliminate continuous values only. continue; @@ -167,8 +165,8 @@ discreteElimination(const HybridGaussianFactorGraph &factors, DiscreteFactorGraph dfg; for (auto &f : factors) { - if (auto dtf = dynamic_pointer_cast(f)) { - dfg.push_back(dtf); + if (auto df = dynamic_pointer_cast(f)) { + dfg.push_back(df); } else if (auto orphan = dynamic_pointer_cast(f)) { // Ignore orphaned clique. // TODO(dellaert): is this correct? If so explain here. @@ -262,9 +260,13 @@ hybridElimination(const HybridGaussianFactorGraph &factors, }; DecisionTree probabilities(eliminationResults, probability); + + auto dtf = + std::make_shared(discreteSeparator, probabilities); + return { std::make_shared(gaussianMixture), - std::make_shared(discreteSeparator, probabilities)}; + std::make_shared(discreteSeparator, dtf->probabilities())}; } else { // Otherwise, we create a resulting GaussianMixtureFactor on the separator, // taking care to correct for conditional constant. @@ -433,7 +435,7 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::error( // Add the gaussian factor error to every leaf of the error tree. error_tree = error_tree.apply( [error](double leaf_value) { return leaf_value + error; }); - } else if (dynamic_pointer_cast(f)) { + } else if (dynamic_pointer_cast(f)) { // If factor at `idx` is discrete-only, we skip. continue; } else { diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index 421e69aa0..b3f159150 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -40,6 +40,7 @@ class HybridEliminationTree; class HybridBayesTree; class HybridJunctionTree; class DecisionTreeFactor; +class TableFactor; class JacobianFactor; class HybridValues; diff --git a/gtsam/hybrid/HybridJunctionTree.cpp b/gtsam/hybrid/HybridJunctionTree.cpp index 6f2898bf1..22d3c7dd2 100644 --- a/gtsam/hybrid/HybridJunctionTree.cpp +++ b/gtsam/hybrid/HybridJunctionTree.cpp @@ -66,7 +66,7 @@ struct HybridConstructorTraversalData { for (auto& k : hf->discreteKeys()) { data.discreteKeys.insert(k.first); } - } else if (auto hf = std::dynamic_pointer_cast(f)) { + } else if (auto hf = std::dynamic_pointer_cast(f)) { for (auto& k : hf->discreteKeys()) { data.discreteKeys.insert(k.first); } @@ -161,7 +161,7 @@ HybridJunctionTree::HybridJunctionTree( Data rootData(0); rootData.junctionTreeNode = std::make_shared(); // Make a dummy node to gather - // the junction tree roots + // the junction tree roots treeTraversal::DepthFirstForest(eliminationTree, rootData, Data::ConstructorTraversalVisitorPre, Data::ConstructorTraversalVisitorPost); diff --git a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp index 260f534e3..2459e4ec9 100644 --- a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp @@ -17,6 +17,7 @@ */ #include +#include #include #include #include @@ -67,7 +68,7 @@ HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize( } else if (auto nlf = dynamic_pointer_cast(f)) { const GaussianFactor::shared_ptr& gf = nlf->linearize(continuousValues); linearFG->push_back(gf); - } else if (dynamic_pointer_cast(f)) { + } else if (dynamic_pointer_cast(f)) { // If discrete-only: doesn't need linearization. linearFG->push_back(f); } else if (auto gmf = dynamic_pointer_cast(f)) { diff --git a/gtsam/hybrid/tests/testHybridBayesTree.cpp b/gtsam/hybrid/tests/testHybridBayesTree.cpp index 578f5d605..81b257c32 100644 --- a/gtsam/hybrid/tests/testHybridBayesTree.cpp +++ b/gtsam/hybrid/tests/testHybridBayesTree.cpp @@ -146,7 +146,7 @@ TEST(HybridBayesTree, Optimize) { DiscreteFactorGraph dfg; for (auto&& f : *remainingFactorGraph) { - auto discreteFactor = dynamic_pointer_cast(f); + auto discreteFactor = dynamic_pointer_cast(f); assert(discreteFactor); dfg.push_back(discreteFactor); }