discreteFactors method
parent
ae95c6e84a
commit
dfc91469bc
|
@ -580,4 +580,15 @@ GaussianFactorGraph HybridGaussianFactorGraph::choose(
|
||||||
return gfg;
|
return gfg;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
DiscreteFactorGraph HybridGaussianFactorGraph::discreteFactors() const {
|
||||||
|
DiscreteFactorGraph dfg;
|
||||||
|
for (auto &&f : factors_) {
|
||||||
|
auto discreteFactor = std::dynamic_pointer_cast<DiscreteFactor>(f);
|
||||||
|
assert(discreteFactor);
|
||||||
|
dfg.push_back(discreteFactor);
|
||||||
|
}
|
||||||
|
return dfg;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -254,6 +254,14 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
||||||
GaussianFactorGraph operator()(const DiscreteValues& assignment) const {
|
GaussianFactorGraph operator()(const DiscreteValues& assignment) const {
|
||||||
return choose(assignment);
|
return choose(assignment);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Helper method to get all the discrete factors
|
||||||
|
* as a DiscreteFactorGraph.
|
||||||
|
*
|
||||||
|
* @return DiscreteFactorGraph
|
||||||
|
*/
|
||||||
|
DiscreteFactorGraph discreteFactors() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
|
|
|
@ -443,12 +443,7 @@ TEST(HybridBayesTree, Optimize) {
|
||||||
const auto [hybridBayesNet, remainingFactorGraph] =
|
const auto [hybridBayesNet, remainingFactorGraph] =
|
||||||
s.linearizedFactorGraph.eliminatePartialSequential(ordering);
|
s.linearizedFactorGraph.eliminatePartialSequential(ordering);
|
||||||
|
|
||||||
DiscreteFactorGraph dfg;
|
DiscreteFactorGraph dfg = remainingFactorGraph->discreteFactors();
|
||||||
for (auto&& f : *remainingFactorGraph) {
|
|
||||||
auto discreteFactor = dynamic_pointer_cast<DiscreteFactor>(f);
|
|
||||||
assert(discreteFactor);
|
|
||||||
dfg.push_back(discreteFactor);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add the probabilities for each branch
|
// Add the probabilities for each branch
|
||||||
DiscreteKeys discrete_keys = {{M(0), 2}, {M(1), 2}, {M(2), 2}};
|
DiscreteKeys discrete_keys = {{M(0), 2}, {M(1), 2}, {M(2), 2}};
|
||||||
|
|
|
@ -479,13 +479,8 @@ TEST(HybridNonlinearFactorGraph, Full_Elimination) {
|
||||||
const auto [hybridBayesNet_partial, remainingFactorGraph_partial] =
|
const auto [hybridBayesNet_partial, remainingFactorGraph_partial] =
|
||||||
linearizedFactorGraph.eliminatePartialSequential(ordering);
|
linearizedFactorGraph.eliminatePartialSequential(ordering);
|
||||||
|
|
||||||
DiscreteFactorGraph discrete_fg;
|
DiscreteFactorGraph discrete_fg =
|
||||||
// TODO(Varun) Make this a function of HybridGaussianFactorGraph?
|
remainingFactorGraph_partial->discreteFactors();
|
||||||
for (auto &factor : (*remainingFactorGraph_partial)) {
|
|
||||||
auto df = dynamic_pointer_cast<DiscreteFactor>(factor);
|
|
||||||
assert(df);
|
|
||||||
discrete_fg.push_back(df);
|
|
||||||
}
|
|
||||||
|
|
||||||
ordering.clear();
|
ordering.clear();
|
||||||
for (size_t k = 0; k < self.K - 1; k++) ordering.push_back(M(k));
|
for (size_t k = 0; k < self.K - 1; k++) ordering.push_back(M(k));
|
||||||
|
|
Loading…
Reference in New Issue