Add shorthand for inserting raw JacobianFactor
parent
3aac52c3d3
commit
53551e051d
|
|
@ -5,19 +5,35 @@
|
|||
|
||||
/**
|
||||
* @file CGMixtureFactor.h
|
||||
* @brief Custom hybrid factor graph for discrete + continuous factors
|
||||
* @author Kevin Doherty, kdoherty@mit.edu
|
||||
* @brief A set of Gaussian factors indexed by a set of discrete keys.
|
||||
* @author Varun Agrawal
|
||||
* @author Fan Jiang
|
||||
* @author Frank Dellaert
|
||||
* @date December 2021
|
||||
*/
|
||||
|
||||
#include <gtsam/hybrid/HybridFactor.h>
|
||||
#include <gtsam/linear/GaussianFactor.h>
|
||||
#include <gtsam/discrete/DecisionTree.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
class CGMixtureFactor : HybridFactor {
|
||||
class CGMixtureFactor : public HybridFactor {
|
||||
public:
|
||||
using Base = HybridFactor;
|
||||
using This = CGMixtureFactor;
|
||||
using shared_ptr = boost::shared_ptr<This>;
|
||||
|
||||
using Factors = DecisionTree<Key, GaussianFactor::shared_ptr>;
|
||||
|
||||
Factors factors_;
|
||||
|
||||
CGMixtureFactor() = default;
|
||||
|
||||
CGMixtureFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, const Factors &factors) : Base(continuousKeys, discreteKeys) {
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ HybridDiscreteFactor::HybridDiscreteFactor(DiscreteFactor::shared_ptr other)
|
|||
}
|
||||
|
||||
HybridDiscreteFactor::HybridDiscreteFactor(DecisionTreeFactor &&dtf)
|
||||
: Base(dtf.keys()),
|
||||
: Base(dtf.discreteKeys()),
|
||||
inner(boost::make_shared<DecisionTreeFactor>(std::move(dtf))) {
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,6 +24,8 @@ namespace gtsam {
|
|||
class HybridDiscreteFactor : public HybridFactor {
|
||||
public:
|
||||
using Base = HybridFactor;
|
||||
using This = HybridDiscreteFactor;
|
||||
using shared_ptr = boost::shared_ptr<This>;
|
||||
|
||||
DiscreteFactor::shared_ptr inner;
|
||||
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <gtsam/nonlinear/Values.h>
|
||||
#include <gtsam/inference/Factor.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
|
||||
#include <string>
|
||||
|
|
@ -36,6 +37,12 @@ public:
|
|||
typedef boost::shared_ptr<HybridFactor> shared_ptr; ///< shared_ptr to this class
|
||||
typedef Factor Base; ///< Our base class
|
||||
|
||||
bool isDiscrete_ = false;
|
||||
bool isContinuous_ = false;
|
||||
bool isHybrid_ = false;
|
||||
|
||||
DiscreteKeys discreteKeys_;
|
||||
|
||||
public:
|
||||
|
||||
/// @name Standard Constructors
|
||||
|
|
@ -46,8 +53,25 @@ public:
|
|||
|
||||
/** Construct from container of keys. This constructor is used internally from derived factor
|
||||
* constructors, either from a container of keys or from a boost::assign::list_of. */
|
||||
template<typename CONTAINER>
|
||||
HybridFactor(const CONTAINER &keys) : Base(keys) {}
|
||||
// template<typename CONTAINER>
|
||||
// HybridFactor(const CONTAINER &keys) : Base(keys) {}
|
||||
|
||||
HybridFactor(const KeyVector &keys) : Base(keys), isContinuous_(true) {}
|
||||
|
||||
static KeyVector CollectKeys(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys) {
|
||||
KeyVector allKeys;
|
||||
std::copy(continuousKeys.begin(), continuousKeys.end(), std::back_inserter(allKeys));
|
||||
std::transform(discreteKeys.begin(),
|
||||
discreteKeys.end(),
|
||||
std::back_inserter(allKeys),
|
||||
[](const DiscreteKey &k) { return k.first; });
|
||||
return allKeys;
|
||||
}
|
||||
|
||||
HybridFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys) : Base(
|
||||
CollectKeys(continuousKeys, discreteKeys)), isHybrid_(true) {}
|
||||
|
||||
HybridFactor(const DiscreteKeys &discreteKeys) : Base(CollectKeys({}, discreteKeys)), isDiscrete_(true) {}
|
||||
|
||||
/// Virtual destructor
|
||||
virtual ~HybridFactor() {
|
||||
|
|
@ -64,7 +88,11 @@ public:
|
|||
void print(
|
||||
const std::string &s = "HybridFactor\n",
|
||||
const KeyFormatter &formatter = DefaultKeyFormatter) const override {
|
||||
Base::print(s, formatter);
|
||||
std::cout << s;
|
||||
if (isContinuous_) std::cout << "Cont. ";
|
||||
if (isDiscrete_) std::cout << "Disc. ";
|
||||
if (isHybrid_) std::cout << "Hybr. ";
|
||||
this->printKeys("", formatter);
|
||||
}
|
||||
|
||||
/// @}
|
||||
|
|
|
|||
|
|
@ -75,4 +75,7 @@ EliminateHybrid(const HybridFactorGraph &factors,
|
|||
boost::make_shared<HybridConditional>(std::move(sum)));
|
||||
}
|
||||
|
||||
void HybridFactorGraph::add(JacobianFactor &&factor) {
|
||||
FactorGraph::add(boost::make_shared<HybridGaussianFactor>(std::move(factor)));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -33,6 +33,8 @@ class HybridEliminationTree;
|
|||
class HybridBayesTree;
|
||||
class HybridJunctionTree;
|
||||
|
||||
class JacobianFactor;
|
||||
|
||||
/** Main elimination function for HybridFactorGraph */
|
||||
GTSAM_EXPORT std::pair<boost::shared_ptr<HybridConditional>, HybridFactor::shared_ptr>
|
||||
EliminateHybrid(const HybridFactorGraph& factors, const Ordering& keys);
|
||||
|
|
@ -77,6 +79,10 @@ class HybridFactorGraph : public FactorGraph<HybridFactor>, public Eliminateable
|
|||
template <class DERIVEDFACTOR>
|
||||
HybridFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
|
||||
|
||||
using FactorGraph::add;
|
||||
|
||||
/// Add a factor directly using a shared_ptr.
|
||||
void add(JacobianFactor &&factor);
|
||||
};
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,6 +24,8 @@ namespace gtsam {
|
|||
class HybridGaussianFactor : public HybridFactor {
|
||||
public:
|
||||
using Base = HybridFactor;
|
||||
using This = HybridGaussianFactor;
|
||||
using shared_ptr = boost::shared_ptr<This>;
|
||||
|
||||
GaussianFactor::shared_ptr inner;
|
||||
|
||||
|
|
|
|||
|
|
@ -26,6 +26,8 @@
|
|||
#include <gtsam/linear/JacobianFactor.h>
|
||||
#include <gtsam/inference/BayesNet.h>
|
||||
|
||||
#include <gtsam/inference/Symbol.h>
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
|
||||
#include <boost/assign/std/map.hpp>
|
||||
|
|
@ -34,6 +36,8 @@ using namespace boost::assign;
|
|||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
||||
using gtsam::symbol_shorthand::X;
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST_UNSAFE(HybridFactorGraph, test) {
|
||||
HybridConditional test;
|
||||
|
|
@ -58,12 +62,12 @@ TEST_UNSAFE(HybridFactorGraph, eliminate) {
|
|||
TEST(HybridFactorGraph, eliminateMultifrontal) {
|
||||
HybridFactorGraph hfg;
|
||||
|
||||
DiscreteKey X(1, 2);
|
||||
DiscreteKey x(X(1), 2);
|
||||
|
||||
hfg.add(HybridGaussianFactor(JacobianFactor(0, I_3x3, Z_3x1)));
|
||||
hfg.add(HybridDiscreteFactor(DecisionTreeFactor(X, {2, 8})));
|
||||
hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1));
|
||||
hfg.add(HybridDiscreteFactor(DecisionTreeFactor(x, {2, 8})));
|
||||
|
||||
auto result = hfg.eliminatePartialMultifrontal({0});
|
||||
auto result = hfg.eliminatePartialMultifrontal({X(0)});
|
||||
|
||||
GTSAM_PRINT(*result.first);
|
||||
GTSAM_PRINT(*result.second);
|
||||
|
|
|
|||
Loading…
Reference in New Issue