Added more mockups and color output of the elimination process
parent
ee4f9d19f0
commit
5ea614a82a
|
|
@ -18,17 +18,20 @@
|
|||
|
||||
#include <gtsam/hybrid/CLGaussianConditional.h>
|
||||
|
||||
#include <gtsam/base/utilities.h>
|
||||
|
||||
#include <gtsam/inference/Conditional-inst.h>
|
||||
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
CLGaussianConditional::CLGaussianConditional(const KeyVector &continuousFrontals,
|
||||
const KeyVector &continuousParents,
|
||||
const DiscreteKeys &discreteParents,
|
||||
const CLGaussianConditional::Conditionals &factors)
|
||||
const CLGaussianConditional::Conditionals &conditionals)
|
||||
: BaseFactor(
|
||||
CollectKeys(continuousFrontals, continuousParents), discreteParents),
|
||||
BaseConditional(continuousFrontals.size()) {
|
||||
BaseConditional(continuousFrontals.size()), conditionals_(conditionals) {
|
||||
|
||||
}
|
||||
|
||||
|
|
@ -47,5 +50,15 @@ void CLGaussianConditional::print(const std::string &s, const KeyFormatter &form
|
|||
std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
|
||||
}
|
||||
std::cout << "\n";
|
||||
conditionals_.print(
|
||||
"",
|
||||
[&](Key k) {
|
||||
return formatter(k);
|
||||
}, [&](const GaussianConditional::shared_ptr &gf) -> std::string {
|
||||
RedirectCout rd;
|
||||
if (!gf->empty()) gf->print("", formatter);
|
||||
else return {"nullptr"};
|
||||
return rd.str();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
@ -32,12 +32,14 @@ public:
|
|||
|
||||
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
|
||||
|
||||
Conditionals conditionals_;
|
||||
|
||||
public:
|
||||
|
||||
CLGaussianConditional(const KeyVector &continuousFrontals,
|
||||
const KeyVector &continuousParents,
|
||||
const DiscreteKeys &discreteParents,
|
||||
const Conditionals &factors);
|
||||
const Conditionals &conditionals);
|
||||
|
||||
bool equals(const HybridFactor &lf, double tol = 1e-9) const override;
|
||||
|
||||
|
|
|
|||
|
|
@ -41,56 +41,59 @@ class GTSAM_EXPORT HybridConditional
|
|||
: public HybridFactor,
|
||||
public Conditional<HybridFactor, HybridConditional> {
|
||||
public:
|
||||
// typedefs needed to play nice with gtsam
|
||||
typedef HybridConditional This; ///< Typedef to this class
|
||||
typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
|
||||
typedef HybridFactor BaseFactor; ///< Typedef to our factor base class
|
||||
typedef Conditional<BaseFactor, This>
|
||||
BaseConditional; ///< Typedef to our conditional base class
|
||||
// typedefs needed to play nice with gtsam
|
||||
typedef HybridConditional This; ///< Typedef to this class
|
||||
typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
|
||||
typedef HybridFactor BaseFactor; ///< Typedef to our factor base class
|
||||
typedef Conditional<BaseFactor, This>
|
||||
BaseConditional; ///< Typedef to our conditional base class
|
||||
|
||||
private:
|
||||
// Type-erased pointer to the inner type
|
||||
std::unique_ptr<Factor> inner;
|
||||
// Type-erased pointer to the inner type
|
||||
std::unique_ptr<Factor> inner;
|
||||
|
||||
public:
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
/// Default constructor needed for serialization.
|
||||
HybridConditional() = default;
|
||||
/// Default constructor needed for serialization.
|
||||
HybridConditional() = default;
|
||||
|
||||
HybridConditional(size_t nFrontals, const KeyVector& keys) : BaseFactor(keys), BaseConditional(nFrontals) {
|
||||
HybridConditional(const KeyVector &continuousKeys,
|
||||
const DiscreteKeys &discreteKeys,
|
||||
size_t nFrontals)
|
||||
: BaseFactor(continuousKeys, discreteKeys), BaseConditional(nFrontals) {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Combine two conditionals, yielding a new conditional with the union
|
||||
* of the frontal keys, ordered by gtsam::Key.
|
||||
*
|
||||
* The two conditionals must make a valid Bayes net fragment, i.e.,
|
||||
* the frontal variables cannot overlap, and must be acyclic:
|
||||
* Example of correct use:
|
||||
* P(A,B) = P(A|B) * P(B)
|
||||
* P(A,B|C) = P(A|B) * P(B|C)
|
||||
* P(A,B,C) = P(A,B|C) * P(C)
|
||||
* Example of incorrect use:
|
||||
* P(A|B) * P(A|C) = ?
|
||||
* P(A|B) * P(B|A) = ?
|
||||
* We check for overlapping frontals, but do *not* check for cyclic.
|
||||
*/
|
||||
HybridConditional operator*(const HybridConditional& other) const;
|
||||
/**
|
||||
* @brief Combine two conditionals, yielding a new conditional with the union
|
||||
* of the frontal keys, ordered by gtsam::Key.
|
||||
*
|
||||
* The two conditionals must make a valid Bayes net fragment, i.e.,
|
||||
* the frontal variables cannot overlap, and must be acyclic:
|
||||
* Example of correct use:
|
||||
* P(A,B) = P(A|B) * P(B)
|
||||
* P(A,B|C) = P(A|B) * P(B|C)
|
||||
* P(A,B,C) = P(A,B|C) * P(C)
|
||||
* Example of incorrect use:
|
||||
* P(A|B) * P(A|C) = ?
|
||||
* P(A|B) * P(B|A) = ?
|
||||
* We check for overlapping frontals, but do *not* check for cyclic.
|
||||
*/
|
||||
HybridConditional operator*(const HybridConditional& other) const;
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
||||
/// GTSAM-style print
|
||||
void print(
|
||||
const std::string& s = "Hybrid Conditional: ",
|
||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||
/// GTSAM-style print
|
||||
void print(
|
||||
const std::string& s = "Hybrid Conditional: ",
|
||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||
|
||||
/// GTSAM-style equals
|
||||
bool equals(const HybridFactor& other, double tol = 1e-9) const override;
|
||||
/// GTSAM-style equals
|
||||
bool equals(const HybridFactor& other, double tol = 1e-9) const override;
|
||||
|
||||
/// @}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -36,6 +36,13 @@ KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2) {
|
|||
return allKeys;
|
||||
}
|
||||
|
||||
DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1, const DiscreteKeys &key2) {
|
||||
DiscreteKeys allKeys;
|
||||
std::copy(key1.begin(), key1.end(), std::back_inserter(allKeys));
|
||||
std::copy(key2.begin(), key2.end(), std::back_inserter(allKeys));
|
||||
return allKeys;
|
||||
}
|
||||
|
||||
HybridFactor::HybridFactor() = default;
|
||||
|
||||
HybridFactor::HybridFactor(const KeyVector &keys) : Base(keys), isContinuous_(true) {}
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ namespace gtsam {
|
|||
|
||||
KeyVector CollectKeys(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys);
|
||||
KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2);
|
||||
DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1, const DiscreteKeys &key2);
|
||||
|
||||
/**
|
||||
* Base class for hybrid probabilistic factors
|
||||
|
|
|
|||
|
|
@ -5,17 +5,26 @@
|
|||
#include <gtsam/hybrid/HybridEliminationTree.h>
|
||||
#include <gtsam/hybrid/HybridJunctionTree.h>
|
||||
#include <gtsam/hybrid/HybridFactorGraph.h>
|
||||
#include <gtsam/hybrid/HybridFactor.h>
|
||||
|
||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||
#include <gtsam/hybrid/HybridDiscreteFactor.h>
|
||||
|
||||
#include <gtsam/inference/EliminateableFactorGraph-inst.h>
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
template
|
||||
class EliminateableFactorGraph<HybridFactorGraph>;
|
||||
|
||||
static std::string BLACK_BOLD = "\033[1;30m";
|
||||
static std::string RED_BOLD = "\033[1;31m";
|
||||
static std::string GREEN = "\033[0;32m";
|
||||
static std::string GREEN_BOLD = "\033[1;32m";
|
||||
static std::string RESET = "\033[0m";
|
||||
|
||||
/* ************************************************************************ */
|
||||
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> //
|
||||
EliminateHybrid(const HybridFactorGraph &factors,
|
||||
|
|
@ -33,26 +42,66 @@ EliminateHybrid(const HybridFactorGraph &factors,
|
|||
// so that the discrete parts will be guaranteed to be eliminated last!
|
||||
|
||||
// PREPROCESS: Identify the nature of the current elimination
|
||||
KeySet allKeys;
|
||||
std::unordered_map<Key, DiscreteKey> discreteCardinalities;
|
||||
std::set<DiscreteKey> discreteSeparator;
|
||||
std::set<DiscreteKey> discreteFrontals;
|
||||
|
||||
KeySet separatorKeys;
|
||||
KeySet allContinuousKeys;
|
||||
KeySet continuousFrontals;
|
||||
KeySet continuousSeparator;
|
||||
|
||||
// TODO: we do a mock by just doing the correct key thing
|
||||
std::cout << "Begin Eliminate: ";
|
||||
|
||||
// This initializes separatorKeys and discreteCardinalities
|
||||
for (auto &&factor : factors) {
|
||||
std::cout << ">>> Adding factor: " << GREEN;
|
||||
factor->print();
|
||||
std::cout << RESET;
|
||||
separatorKeys.insert(factor->begin(), factor->end());
|
||||
if (!factor->isContinuous_) {
|
||||
for (auto &k : factor->discreteKeys_) {
|
||||
discreteCardinalities[k.first] = k;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// remove frontals from separator
|
||||
for (auto &k : frontalKeys) {
|
||||
separatorKeys.erase(k);
|
||||
}
|
||||
|
||||
// Fill in discrete frontals and continuous frontals for the end result
|
||||
for (auto &k : frontalKeys) {
|
||||
if (discreteCardinalities.find(k) != discreteCardinalities.end()) {
|
||||
discreteFrontals.insert(discreteCardinalities.at(k));
|
||||
} else {
|
||||
continuousFrontals.insert(k);
|
||||
}
|
||||
}
|
||||
|
||||
// Fill in discrete frontals and continuous frontals for the end result
|
||||
for (auto &k : separatorKeys) {
|
||||
if (discreteCardinalities.find(k) != discreteCardinalities.end()) {
|
||||
discreteSeparator.insert(discreteCardinalities.at(k));
|
||||
} else {
|
||||
continuousSeparator.insert(k);
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << RED_BOLD << "Begin Eliminate: " << RESET;
|
||||
frontalKeys.print();
|
||||
|
||||
for (auto &&factor : factors) {
|
||||
std::cout << ">>> Adding factor: ";
|
||||
factor->print();
|
||||
allKeys.insert(factor->begin(), factor->end());
|
||||
}
|
||||
|
||||
for (auto &k : frontalKeys) {
|
||||
allKeys.erase(k);
|
||||
}
|
||||
|
||||
std::cout << RED_BOLD << "Discrete Keys: " << RESET;
|
||||
for (auto &&key : discreteCardinalities)
|
||||
std::cout << boost::format(" (%1%,%2%),") % DefaultKeyFormatter(key.second.first) % key.second.second;
|
||||
std::cout << "\n" << RESET;
|
||||
// PRODUCT: multiply all factors
|
||||
gttic(product);
|
||||
|
||||
HybridConditional sum(allKeys.size(), Ordering(allKeys));
|
||||
HybridConditional sum(KeyVector(continuousSeparator.begin(), continuousSeparator.end()),
|
||||
DiscreteKeys(discreteSeparator.begin(), discreteSeparator.end()),
|
||||
separatorKeys.size());
|
||||
|
||||
// HybridDiscreteFactor product(DiscreteConditional());
|
||||
// for (auto&& factor : factors) product = (*factor) * product;
|
||||
|
|
@ -76,10 +125,22 @@ EliminateHybrid(const HybridFactorGraph &factors,
|
|||
// boost::make_shared<HybridConditional>(product, *sum, orderedKeys);
|
||||
gttoc(divide);
|
||||
|
||||
auto conditional = boost::make_shared<HybridConditional>(
|
||||
CollectKeys({continuousFrontals.begin(), continuousFrontals.end()},
|
||||
{continuousSeparator.begin(), continuousSeparator.end()}),
|
||||
CollectDiscreteKeys({discreteFrontals.begin(), discreteFrontals.end()},
|
||||
{discreteSeparator.begin(), discreteSeparator.end()}),
|
||||
continuousFrontals.size() + discreteFrontals.size());
|
||||
std::cout << GREEN_BOLD << "[Conditional]\n" << RESET;
|
||||
conditional->print();
|
||||
std::cout << GREEN_BOLD << "[Separator]\n" << RESET;
|
||||
sum.print();
|
||||
std::cout << RED_BOLD << "[End Eliminate]\n" << RESET;
|
||||
|
||||
// return std::make_pair(conditional, sum);
|
||||
return std::make_pair(boost::make_shared<HybridConditional>(frontalKeys.size(),
|
||||
orderedKeys),
|
||||
boost::make_shared<HybridConditional>(std::move(sum)));
|
||||
return std::make_pair(
|
||||
conditional,
|
||||
boost::make_shared<HybridConditional>(std::move(sum)));
|
||||
}
|
||||
|
||||
void HybridFactorGraph::add(JacobianFactor &&factor) {
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ using gtsam::symbol_shorthand::X;
|
|||
using gtsam::symbol_shorthand::C;
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST_UNSAFE(HybridFactorGraph, creation) {
|
||||
TEST(HybridFactorGraph, creation) {
|
||||
HybridConditional test;
|
||||
|
||||
HybridFactorGraph hfg;
|
||||
|
|
@ -52,12 +52,19 @@ TEST_UNSAFE(HybridFactorGraph, creation) {
|
|||
CLGaussianConditional clgc(
|
||||
{X(0)}, {X(1)},
|
||||
DiscreteKeys(DiscreteKey{C(0), 2}),
|
||||
CLGaussianConditional::Conditionals()
|
||||
CLGaussianConditional::Conditionals(
|
||||
C(0),
|
||||
boost::make_shared<GaussianConditional>(
|
||||
X(0), Z_3x1, I_3x3, X(1), I_3x3),
|
||||
boost::make_shared<GaussianConditional>(
|
||||
X(0), Vector3::Ones(),
|
||||
I_3x3, X(1),
|
||||
I_3x3))
|
||||
);
|
||||
GTSAM_PRINT(clgc);
|
||||
}
|
||||
|
||||
TEST_UNSAFE(HybridFactorGraph, eliminate) {
|
||||
TEST_DISABLED(HybridFactorGraph, eliminate) {
|
||||
HybridFactorGraph hfg;
|
||||
|
||||
hfg.add(HybridGaussianFactor(JacobianFactor(0, I_3x3, Z_3x1)));
|
||||
|
|
@ -67,7 +74,7 @@ TEST_UNSAFE(HybridFactorGraph, eliminate) {
|
|||
EXPECT_LONGS_EQUAL(result.first->size(), 1);
|
||||
}
|
||||
|
||||
TEST(HybridFactorGraph, eliminateMultifrontal) {
|
||||
TEST_DISABLED(HybridFactorGraph, eliminateMultifrontal) {
|
||||
HybridFactorGraph hfg;
|
||||
|
||||
DiscreteKey x(X(1), 2);
|
||||
|
|
|
|||
Loading…
Reference in New Issue