make functional

release/4.3a0
Varun Agrawal 2022-05-27 15:16:28 -04:00
parent 865d10da9c
commit 573448f126
7 changed files with 35 additions and 19 deletions

View File

@ -35,7 +35,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
using sharedConditional = boost::shared_ptr<ConditionalType>; using sharedConditional = boost::shared_ptr<ConditionalType>;
/** Construct empty bayes net */ /** Construct empty bayes net */
HybridBayesNet() : Base() {} HybridBayesNet() = default;
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -33,5 +33,4 @@ bool HybridBayesTree::equals(const This& other, double tol) const {
return Base::equals(other, tol); return Base::equals(other, tol);
} }
/* **************************************************************************/
} // namespace gtsam } // namespace gtsam

View File

@ -70,7 +70,13 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
/// @} /// @}
}; };
/* This does special stuff for the hybrid case */ /**
* @brief Class for Hybrid Bayes tree orphan subtrees.
*
* This does special stuff for the hybrid case
*
* @tparam CLIQUE
*/
template <class CLIQUE> template <class CLIQUE>
class BayesTreeOrphanWrapper< class BayesTreeOrphanWrapper<
CLIQUE, typename std::enable_if< CLIQUE, typename std::enable_if<
@ -82,16 +88,22 @@ class BayesTreeOrphanWrapper<
boost::shared_ptr<CliqueType> clique; boost::shared_ptr<CliqueType> clique;
/**
* @brief Construct a new Bayes Tree Orphan Wrapper object.
*
* @param clique Bayes tree clique.
*/
BayesTreeOrphanWrapper(const boost::shared_ptr<CliqueType>& clique) BayesTreeOrphanWrapper(const boost::shared_ptr<CliqueType>& clique)
: clique(clique) { : clique(clique) {
// Store parent keys in our base type factor so that eliminating those // Store parent keys in our base type factor so that eliminating those
// parent keys will pull this subtree into the elimination. // parent keys will pull this subtree into the elimination.
this->keys_.assign(clique->conditional()->beginParents(), this->keys_.assign(clique->conditional()->beginParents(),
clique->conditional()->endParents()); clique->conditional()->endParents());
this->discreteKeys_.assign(clique->conditional()->discreteKeys_.begin(), this->discreteKeys_.assign(clique->conditional()->discreteKeys().begin(),
clique->conditional()->discreteKeys_.end()); clique->conditional()->discreteKeys().end());
} }
/// print utility
void print( void print(
const std::string& s = "", const std::string& s = "",
const KeyFormatter& formatter = DefaultKeyFormatter) const override { const KeyFormatter& formatter = DefaultKeyFormatter) const override {

View File

@ -28,12 +28,18 @@ namespace gtsam {
*/ */
class GTSAM_EXPORT HybridEliminationTree class GTSAM_EXPORT HybridEliminationTree
: public EliminationTree<HybridBayesNet, HybridFactorGraph> { : public EliminationTree<HybridBayesNet, HybridFactorGraph> {
private:
friend class ::EliminationTreeTester;
public: public:
typedef EliminationTree<HybridBayesNet, HybridFactorGraph> typedef EliminationTree<HybridBayesNet, HybridFactorGraph>
Base; ///< Base class Base; ///< Base class
typedef HybridEliminationTree This; ///< This class typedef HybridEliminationTree This; ///< This class
typedef boost::shared_ptr<This> shared_ptr; ///< Shared pointer to this class typedef boost::shared_ptr<This> shared_ptr; ///< Shared pointer to this class
/// @name Constructors
/// @{
/** /**
* Build the elimination tree of a factor graph using pre-computed column * Build the elimination tree of a factor graph using pre-computed column
* structure. * structure.
@ -54,11 +60,10 @@ class GTSAM_EXPORT HybridEliminationTree
HybridEliminationTree(const HybridFactorGraph& factorGraph, HybridEliminationTree(const HybridFactorGraph& factorGraph,
const Ordering& order); const Ordering& order);
/// @}
/** Test whether the tree is equal to another */ /** Test whether the tree is equal to another */
bool equals(const This& other, double tol = 1e-9) const; bool equals(const This& other, double tol = 1e-9) const;
private:
friend class ::EliminationTreeTester;
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -150,8 +150,8 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
std::cout << RESET; std::cout << RESET;
} }
separatorKeys.insert(factor->begin(), factor->end()); separatorKeys.insert(factor->begin(), factor->end());
if (!factor->isContinuous_) { if (!factor->isContinuous()) {
for (auto &k : factor->discreteKeys_) { for (auto &k : factor->discreteKeys()) {
mapFromKeyToDiscreteKey[k.first] = k; mapFromKeyToDiscreteKey[k.first] = k;
} }
} }
@ -223,9 +223,9 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
for (auto &fp : factors) { for (auto &fp : factors) {
auto ptr = boost::dynamic_pointer_cast<HybridGaussianFactor>(fp); auto ptr = boost::dynamic_pointer_cast<HybridGaussianFactor>(fp);
if (ptr) { if (ptr) {
gfg.push_back(ptr->inner); gfg.push_back(ptr->inner());
} else { } else {
auto p = boost::static_pointer_cast<HybridConditional>(fp)->inner; auto p = boost::static_pointer_cast<HybridConditional>(fp)->inner();
if (p) { if (p) {
gfg.push_back(boost::static_pointer_cast<GaussianConditional>(p)); gfg.push_back(boost::static_pointer_cast<GaussianConditional>(p));
} else { } else {
@ -251,9 +251,9 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
for (auto &fp : factors) { for (auto &fp : factors) {
auto ptr = boost::dynamic_pointer_cast<HybridDiscreteFactor>(fp); auto ptr = boost::dynamic_pointer_cast<HybridDiscreteFactor>(fp);
if (ptr) { if (ptr) {
dfg.push_back(ptr->inner); dfg.push_back(ptr->inner());
} else { } else {
auto p = boost::static_pointer_cast<HybridConditional>(fp)->inner; auto p = boost::static_pointer_cast<HybridConditional>(fp)->inner();
if (p) { if (p) {
dfg.push_back(boost::static_pointer_cast<DiscreteConditional>(p)); dfg.push_back(boost::static_pointer_cast<DiscreteConditional>(p));
} else { } else {
@ -288,7 +288,7 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
std::vector<GaussianFactor::shared_ptr> deferredFactors; std::vector<GaussianFactor::shared_ptr> deferredFactors;
for (auto &f : factors) { for (auto &f : factors) {
if (f->isHybrid_) { if (f->isHybrid()) {
auto cgmf = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f); auto cgmf = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f);
if (cgmf) { if (cgmf) {
sum = cgmf->add(sum); sum = cgmf->add(sum);
@ -299,9 +299,9 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
sum = gm->asMixture()->add(sum); sum = gm->asMixture()->add(sum);
} }
} else if (f->isContinuous_) { } else if (f->isContinuous()) {
deferredFactors.push_back( deferredFactors.push_back(
boost::dynamic_pointer_cast<HybridGaussianFactor>(f)->inner); boost::dynamic_pointer_cast<HybridGaussianFactor>(f)->inner());
} else { } else {
// We need to handle the case where the object is actually an // We need to handle the case where the object is actually an
// BayesTreeOrphanWrapper! // BayesTreeOrphanWrapper!

View File

@ -58,7 +58,7 @@ void HybridISAM::updateInternal(const HybridFactorGraph& newFactors,
KeySet allDiscrete; KeySet allDiscrete;
for (auto& factor : factors) { for (auto& factor : factors) {
for (auto& k : factor->discreteKeys_) { for (auto& k : factor->discreteKeys()) {
allDiscrete.insert(k.first); allDiscrete.insert(k.first);
} }
} }

View File

@ -63,7 +63,7 @@ struct HybridConstructorTraversalData {
std::cout << "Getting discrete info: "; std::cout << "Getting discrete info: ";
#endif #endif
for (HybridFactor::shared_ptr& f : node->factors) { for (HybridFactor::shared_ptr& f : node->factors) {
for (auto& k : f->discreteKeys_) { for (auto& k : f->discreteKeys()) {
#ifdef GTSAM_HYBRID_JUNCTIONTREE_DEBUG #ifdef GTSAM_HYBRID_JUNCTIONTREE_DEBUG
std::cout << "DK: " << DefaultKeyFormatter(k.first) << "\n"; std::cout << "DK: " << DefaultKeyFormatter(k.first) << "\n";
#endif #endif