From dfc91469bceb135cbec6e0f038be45c89dcde576 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 29 Oct 2024 14:45:19 -0400 Subject: [PATCH] discreteFactors method --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 11 +++++++++++ gtsam/hybrid/HybridGaussianFactorGraph.h | 8 ++++++++ gtsam/hybrid/tests/testHybridBayesTree.cpp | 7 +------ gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp | 9 ++------- 4 files changed, 22 insertions(+), 13 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index ceabe0871..049e6c38d 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -580,4 +580,15 @@ GaussianFactorGraph HybridGaussianFactorGraph::choose( return gfg; } +/* ************************************************************************ */ +DiscreteFactorGraph HybridGaussianFactorGraph::discreteFactors() const { + DiscreteFactorGraph dfg; + for (auto &&f : factors_) { + auto discreteFactor = std::dynamic_pointer_cast(f); + assert(discreteFactor); + dfg.push_back(discreteFactor); + } + return dfg; +} + } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index 048fd2701..c2e50ace8 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -254,6 +254,14 @@ class GTSAM_EXPORT HybridGaussianFactorGraph GaussianFactorGraph operator()(const DiscreteValues& assignment) const { return choose(assignment); } + + /** + * @brief Helper method to get all the discrete factors + * as a DiscreteFactorGraph. + * + * @return DiscreteFactorGraph + */ + DiscreteFactorGraph discreteFactors() const; }; // traits diff --git a/gtsam/hybrid/tests/testHybridBayesTree.cpp b/gtsam/hybrid/tests/testHybridBayesTree.cpp index db298e6fc..4f5583bf5 100644 --- a/gtsam/hybrid/tests/testHybridBayesTree.cpp +++ b/gtsam/hybrid/tests/testHybridBayesTree.cpp @@ -443,12 +443,7 @@ TEST(HybridBayesTree, Optimize) { const auto [hybridBayesNet, remainingFactorGraph] = s.linearizedFactorGraph.eliminatePartialSequential(ordering); - DiscreteFactorGraph dfg; - for (auto&& f : *remainingFactorGraph) { - auto discreteFactor = dynamic_pointer_cast(f); - assert(discreteFactor); - dfg.push_back(discreteFactor); - } + DiscreteFactorGraph dfg = remainingFactorGraph->discreteFactors(); // Add the probabilities for each branch DiscreteKeys discrete_keys = {{M(0), 2}, {M(1), 2}, {M(2), 2}}; diff --git a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp index dd4128034..100eb024a 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp @@ -479,13 +479,8 @@ TEST(HybridNonlinearFactorGraph, Full_Elimination) { const auto [hybridBayesNet_partial, remainingFactorGraph_partial] = linearizedFactorGraph.eliminatePartialSequential(ordering); - DiscreteFactorGraph discrete_fg; - // TODO(Varun) Make this a function of HybridGaussianFactorGraph? - for (auto &factor : (*remainingFactorGraph_partial)) { - auto df = dynamic_pointer_cast(factor); - assert(df); - discrete_fg.push_back(df); - } + DiscreteFactorGraph discrete_fg = + remainingFactorGraph_partial->discreteFactors(); ordering.clear(); for (size_t k = 0; k < self.K - 1; k++) ordering.push_back(M(k));