discreteFactors method

release/4.3a0
Varun Agrawal 2024-10-29 14:45:19 -04:00
parent ae95c6e84a
commit dfc91469bc
4 changed files with 22 additions and 13 deletions

View File

@ -580,4 +580,15 @@ GaussianFactorGraph HybridGaussianFactorGraph::choose(
return gfg;
}
/* ************************************************************************ */
DiscreteFactorGraph HybridGaussianFactorGraph::discreteFactors() const {
DiscreteFactorGraph dfg;
for (auto &&f : factors_) {
auto discreteFactor = std::dynamic_pointer_cast<DiscreteFactor>(f);
assert(discreteFactor);
dfg.push_back(discreteFactor);
}
return dfg;
}
} // namespace gtsam

View File

@ -254,6 +254,14 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
GaussianFactorGraph operator()(const DiscreteValues& assignment) const {
return choose(assignment);
}
/**
* @brief Helper method to get all the discrete factors
* as a DiscreteFactorGraph.
*
* @return DiscreteFactorGraph
*/
DiscreteFactorGraph discreteFactors() const;
};
// traits

View File

@ -443,12 +443,7 @@ TEST(HybridBayesTree, Optimize) {
const auto [hybridBayesNet, remainingFactorGraph] =
s.linearizedFactorGraph.eliminatePartialSequential(ordering);
DiscreteFactorGraph dfg;
for (auto&& f : *remainingFactorGraph) {
auto discreteFactor = dynamic_pointer_cast<DiscreteFactor>(f);
assert(discreteFactor);
dfg.push_back(discreteFactor);
}
DiscreteFactorGraph dfg = remainingFactorGraph->discreteFactors();
// Add the probabilities for each branch
DiscreteKeys discrete_keys = {{M(0), 2}, {M(1), 2}, {M(2), 2}};

View File

@ -479,13 +479,8 @@ TEST(HybridNonlinearFactorGraph, Full_Elimination) {
const auto [hybridBayesNet_partial, remainingFactorGraph_partial] =
linearizedFactorGraph.eliminatePartialSequential(ordering);
DiscreteFactorGraph discrete_fg;
// TODO(Varun) Make this a function of HybridGaussianFactorGraph?
for (auto &factor : (*remainingFactorGraph_partial)) {
auto df = dynamic_pointer_cast<DiscreteFactor>(factor);
assert(df);
discrete_fg.push_back(df);
}
DiscreteFactorGraph discrete_fg =
remainingFactorGraph_partial->discreteFactors();
ordering.clear();
for (size_t k = 0; k < self.K - 1; k++) ordering.push_back(M(k));