discreteFactors method

release/4.3a0
Varun Agrawal 2024-10-29 14:45:19 -04:00
parent ae95c6e84a
commit dfc91469bc
4 changed files with 22 additions and 13 deletions

View File

@ -580,4 +580,15 @@ GaussianFactorGraph HybridGaussianFactorGraph::choose(
return gfg; return gfg;
} }
/* ************************************************************************ */
DiscreteFactorGraph HybridGaussianFactorGraph::discreteFactors() const {
DiscreteFactorGraph dfg;
for (auto &&f : factors_) {
auto discreteFactor = std::dynamic_pointer_cast<DiscreteFactor>(f);
assert(discreteFactor);
dfg.push_back(discreteFactor);
}
return dfg;
}
} // namespace gtsam } // namespace gtsam

View File

@ -254,6 +254,14 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
GaussianFactorGraph operator()(const DiscreteValues& assignment) const { GaussianFactorGraph operator()(const DiscreteValues& assignment) const {
return choose(assignment); return choose(assignment);
} }
/**
* @brief Helper method to get all the discrete factors
* as a DiscreteFactorGraph.
*
* @return DiscreteFactorGraph
*/
DiscreteFactorGraph discreteFactors() const;
}; };
// traits // traits

View File

@ -443,12 +443,7 @@ TEST(HybridBayesTree, Optimize) {
const auto [hybridBayesNet, remainingFactorGraph] = const auto [hybridBayesNet, remainingFactorGraph] =
s.linearizedFactorGraph.eliminatePartialSequential(ordering); s.linearizedFactorGraph.eliminatePartialSequential(ordering);
DiscreteFactorGraph dfg; DiscreteFactorGraph dfg = remainingFactorGraph->discreteFactors();
for (auto&& f : *remainingFactorGraph) {
auto discreteFactor = dynamic_pointer_cast<DiscreteFactor>(f);
assert(discreteFactor);
dfg.push_back(discreteFactor);
}
// Add the probabilities for each branch // Add the probabilities for each branch
DiscreteKeys discrete_keys = {{M(0), 2}, {M(1), 2}, {M(2), 2}}; DiscreteKeys discrete_keys = {{M(0), 2}, {M(1), 2}, {M(2), 2}};

View File

@ -479,13 +479,8 @@ TEST(HybridNonlinearFactorGraph, Full_Elimination) {
const auto [hybridBayesNet_partial, remainingFactorGraph_partial] = const auto [hybridBayesNet_partial, remainingFactorGraph_partial] =
linearizedFactorGraph.eliminatePartialSequential(ordering); linearizedFactorGraph.eliminatePartialSequential(ordering);
DiscreteFactorGraph discrete_fg; DiscreteFactorGraph discrete_fg =
// TODO(Varun) Make this a function of HybridGaussianFactorGraph? remainingFactorGraph_partial->discreteFactors();
for (auto &factor : (*remainingFactorGraph_partial)) {
auto df = dynamic_pointer_cast<DiscreteFactor>(factor);
assert(df);
discrete_fg.push_back(df);
}
ordering.clear(); ordering.clear();
for (size_t k = 0; k < self.K - 1; k++) ordering.push_back(M(k)); for (size_t k = 0; k < self.K - 1; k++) ordering.push_back(M(k));