Add shorthand for inserting raw JacobianFactor

release/4.3a0
Fan Jiang 2022-03-12 11:51:48 -05:00
parent 3aac52c3d3
commit 53551e051d
8 changed files with 72 additions and 11 deletions

View File

@ -5,19 +5,35 @@
/**
* @file CGMixtureFactor.h
* @brief Custom hybrid factor graph for discrete + continuous factors
* @author Kevin Doherty, kdoherty@mit.edu
* @brief A set of Gaussian factors indexed by a set of discrete keys.
* @author Varun Agrawal
* @author Fan Jiang
* @author Frank Dellaert
* @date December 2021
*/
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/linear/GaussianFactor.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h>
namespace gtsam {
class CGMixtureFactor : HybridFactor {
class CGMixtureFactor : public HybridFactor {
public:
using Base = HybridFactor;
using This = CGMixtureFactor;
using shared_ptr = boost::shared_ptr<This>;
using Factors = DecisionTree<Key, GaussianFactor::shared_ptr>;
Factors factors_;
CGMixtureFactor() = default;
CGMixtureFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, const Factors &factors) : Base(continuousKeys, discreteKeys) {
}
};
}

View File

@ -14,7 +14,7 @@ HybridDiscreteFactor::HybridDiscreteFactor(DiscreteFactor::shared_ptr other)
}
HybridDiscreteFactor::HybridDiscreteFactor(DecisionTreeFactor &&dtf)
: Base(dtf.keys()),
: Base(dtf.discreteKeys()),
inner(boost::make_shared<DecisionTreeFactor>(std::move(dtf))) {
}

View File

@ -24,6 +24,8 @@ namespace gtsam {
class HybridDiscreteFactor : public HybridFactor {
public:
using Base = HybridFactor;
using This = HybridDiscreteFactor;
using shared_ptr = boost::shared_ptr<This>;
DiscreteFactor::shared_ptr inner;

View File

@ -19,6 +19,7 @@
#include <gtsam/nonlinear/Values.h>
#include <gtsam/inference/Factor.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/base/Testable.h>
#include <string>
@ -36,6 +37,12 @@ public:
typedef boost::shared_ptr<HybridFactor> shared_ptr; ///< shared_ptr to this class
typedef Factor Base; ///< Our base class
bool isDiscrete_ = false;
bool isContinuous_ = false;
bool isHybrid_ = false;
DiscreteKeys discreteKeys_;
public:
/// @name Standard Constructors
@ -46,8 +53,25 @@ public:
/** Construct from container of keys. This constructor is used internally from derived factor
* constructors, either from a container of keys or from a boost::assign::list_of. */
template<typename CONTAINER>
HybridFactor(const CONTAINER &keys) : Base(keys) {}
// template<typename CONTAINER>
// HybridFactor(const CONTAINER &keys) : Base(keys) {}
HybridFactor(const KeyVector &keys) : Base(keys), isContinuous_(true) {}
static KeyVector CollectKeys(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys) {
KeyVector allKeys;
std::copy(continuousKeys.begin(), continuousKeys.end(), std::back_inserter(allKeys));
std::transform(discreteKeys.begin(),
discreteKeys.end(),
std::back_inserter(allKeys),
[](const DiscreteKey &k) { return k.first; });
return allKeys;
}
HybridFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys) : Base(
CollectKeys(continuousKeys, discreteKeys)), isHybrid_(true) {}
HybridFactor(const DiscreteKeys &discreteKeys) : Base(CollectKeys({}, discreteKeys)), isDiscrete_(true) {}
/// Virtual destructor
virtual ~HybridFactor() {
@ -64,7 +88,11 @@ public:
void print(
const std::string &s = "HybridFactor\n",
const KeyFormatter &formatter = DefaultKeyFormatter) const override {
Base::print(s, formatter);
std::cout << s;
if (isContinuous_) std::cout << "Cont. ";
if (isDiscrete_) std::cout << "Disc. ";
if (isHybrid_) std::cout << "Hybr. ";
this->printKeys("", formatter);
}
/// @}

View File

@ -75,4 +75,7 @@ EliminateHybrid(const HybridFactorGraph &factors,
boost::make_shared<HybridConditional>(std::move(sum)));
}
void HybridFactorGraph::add(JacobianFactor &&factor) {
FactorGraph::add(boost::make_shared<HybridGaussianFactor>(std::move(factor)));
}
}

View File

@ -33,6 +33,8 @@ class HybridEliminationTree;
class HybridBayesTree;
class HybridJunctionTree;
class JacobianFactor;
/** Main elimination function for HybridFactorGraph */
GTSAM_EXPORT std::pair<boost::shared_ptr<HybridConditional>, HybridFactor::shared_ptr>
EliminateHybrid(const HybridFactorGraph& factors, const Ordering& keys);
@ -77,6 +79,10 @@ class HybridFactorGraph : public FactorGraph<HybridFactor>, public Eliminateable
template <class DERIVEDFACTOR>
HybridFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
using FactorGraph::add;
/// Add a factor directly using a shared_ptr.
void add(JacobianFactor &&factor);
};
}

View File

@ -24,6 +24,8 @@ namespace gtsam {
class HybridGaussianFactor : public HybridFactor {
public:
using Base = HybridFactor;
using This = HybridGaussianFactor;
using shared_ptr = boost::shared_ptr<This>;
GaussianFactor::shared_ptr inner;

View File

@ -26,6 +26,8 @@
#include <gtsam/linear/JacobianFactor.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/Symbol.h>
#include <CppUnitLite/TestHarness.h>
#include <boost/assign/std/map.hpp>
@ -34,6 +36,8 @@ using namespace boost::assign;
using namespace std;
using namespace gtsam;
using gtsam::symbol_shorthand::X;
/* ************************************************************************* */
TEST_UNSAFE(HybridFactorGraph, test) {
HybridConditional test;
@ -58,12 +62,12 @@ TEST_UNSAFE(HybridFactorGraph, eliminate) {
TEST(HybridFactorGraph, eliminateMultifrontal) {
HybridFactorGraph hfg;
DiscreteKey X(1, 2);
DiscreteKey x(X(1), 2);
hfg.add(HybridGaussianFactor(JacobianFactor(0, I_3x3, Z_3x1)));
hfg.add(HybridDiscreteFactor(DecisionTreeFactor(X, {2, 8})));
hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1));
hfg.add(HybridDiscreteFactor(DecisionTreeFactor(x, {2, 8})));
auto result = hfg.eliminatePartialMultifrontal({0});
auto result = hfg.eliminatePartialMultifrontal({X(0)});
GTSAM_PRINT(*result.first);
GTSAM_PRINT(*result.second);