Added more mockups and color output of the elimination process

release/4.3a0
Fan Jiang 2022-03-13 20:09:53 -04:00
parent ee4f9d19f0
commit 5ea614a82a
7 changed files with 151 additions and 57 deletions

View File

@ -18,17 +18,20 @@
#include <gtsam/hybrid/CLGaussianConditional.h>
#include <gtsam/base/utilities.h>
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/discrete/DecisionTree-inl.h>
namespace gtsam {
CLGaussianConditional::CLGaussianConditional(const KeyVector &continuousFrontals,
const KeyVector &continuousParents,
const DiscreteKeys &discreteParents,
const CLGaussianConditional::Conditionals &factors)
const CLGaussianConditional::Conditionals &conditionals)
: BaseFactor(
CollectKeys(continuousFrontals, continuousParents), discreteParents),
BaseConditional(continuousFrontals.size()) {
BaseConditional(continuousFrontals.size()), conditionals_(conditionals) {
}
@ -47,5 +50,15 @@ void CLGaussianConditional::print(const std::string &s, const KeyFormatter &form
std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
}
std::cout << "\n";
conditionals_.print(
"",
[&](Key k) {
return formatter(k);
}, [&](const GaussianConditional::shared_ptr &gf) -> std::string {
RedirectCout rd;
if (!gf->empty()) gf->print("", formatter);
else return {"nullptr"};
return rd.str();
});
}
}

View File

@ -32,12 +32,14 @@ public:
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
Conditionals conditionals_;
public:
CLGaussianConditional(const KeyVector &continuousFrontals,
const KeyVector &continuousParents,
const DiscreteKeys &discreteParents,
const Conditionals &factors);
const Conditionals &conditionals);
bool equals(const HybridFactor &lf, double tol = 1e-9) const override;

View File

@ -41,56 +41,59 @@ class GTSAM_EXPORT HybridConditional
: public HybridFactor,
public Conditional<HybridFactor, HybridConditional> {
public:
// typedefs needed to play nice with gtsam
typedef HybridConditional This; ///< Typedef to this class
typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
typedef HybridFactor BaseFactor; ///< Typedef to our factor base class
typedef Conditional<BaseFactor, This>
BaseConditional; ///< Typedef to our conditional base class
// typedefs needed to play nice with gtsam
typedef HybridConditional This; ///< Typedef to this class
typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
typedef HybridFactor BaseFactor; ///< Typedef to our factor base class
typedef Conditional<BaseFactor, This>
BaseConditional; ///< Typedef to our conditional base class
private:
// Type-erased pointer to the inner type
std::unique_ptr<Factor> inner;
// Type-erased pointer to the inner type
std::unique_ptr<Factor> inner;
public:
/// @name Standard Constructors
/// @{
/// Default constructor needed for serialization.
HybridConditional() = default;
/// Default constructor needed for serialization.
HybridConditional() = default;
HybridConditional(size_t nFrontals, const KeyVector& keys) : BaseFactor(keys), BaseConditional(nFrontals) {
HybridConditional(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
size_t nFrontals)
: BaseFactor(continuousKeys, discreteKeys), BaseConditional(nFrontals) {
}
}
/**
* @brief Combine two conditionals, yielding a new conditional with the union
* of the frontal keys, ordered by gtsam::Key.
*
* The two conditionals must make a valid Bayes net fragment, i.e.,
* the frontal variables cannot overlap, and must be acyclic:
* Example of correct use:
* P(A,B) = P(A|B) * P(B)
* P(A,B|C) = P(A|B) * P(B|C)
* P(A,B,C) = P(A,B|C) * P(C)
* Example of incorrect use:
* P(A|B) * P(A|C) = ?
* P(A|B) * P(B|A) = ?
* We check for overlapping frontals, but do *not* check for cyclic.
*/
HybridConditional operator*(const HybridConditional& other) const;
/**
* @brief Combine two conditionals, yielding a new conditional with the union
* of the frontal keys, ordered by gtsam::Key.
*
* The two conditionals must make a valid Bayes net fragment, i.e.,
* the frontal variables cannot overlap, and must be acyclic:
* Example of correct use:
* P(A,B) = P(A|B) * P(B)
* P(A,B|C) = P(A|B) * P(B|C)
* P(A,B,C) = P(A,B|C) * P(C)
* Example of incorrect use:
* P(A|B) * P(A|C) = ?
* P(A|B) * P(B|A) = ?
* We check for overlapping frontals, but do *not* check for cyclic.
*/
HybridConditional operator*(const HybridConditional& other) const;
/// @}
/// @name Testable
/// @{
/// GTSAM-style print
void print(
const std::string& s = "Hybrid Conditional: ",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/// GTSAM-style print
void print(
const std::string& s = "Hybrid Conditional: ",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/// GTSAM-style equals
bool equals(const HybridFactor& other, double tol = 1e-9) const override;
/// GTSAM-style equals
bool equals(const HybridFactor& other, double tol = 1e-9) const override;
/// @}
};

View File

@ -36,6 +36,13 @@ KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2) {
return allKeys;
}
DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1, const DiscreteKeys &key2) {
DiscreteKeys allKeys;
std::copy(key1.begin(), key1.end(), std::back_inserter(allKeys));
std::copy(key2.begin(), key2.end(), std::back_inserter(allKeys));
return allKeys;
}
HybridFactor::HybridFactor() = default;
HybridFactor::HybridFactor(const KeyVector &keys) : Base(keys), isContinuous_(true) {}

View File

@ -27,6 +27,7 @@ namespace gtsam {
KeyVector CollectKeys(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys);
KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2);
DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1, const DiscreteKeys &key2);
/**
* Base class for hybrid probabilistic factors

View File

@ -5,17 +5,26 @@
#include <gtsam/hybrid/HybridEliminationTree.h>
#include <gtsam/hybrid/HybridJunctionTree.h>
#include <gtsam/hybrid/HybridFactorGraph.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridDiscreteFactor.h>
#include <gtsam/inference/EliminateableFactorGraph-inst.h>
#include <unordered_map>
namespace gtsam {
template
class EliminateableFactorGraph<HybridFactorGraph>;
static std::string BLACK_BOLD = "\033[1;30m";
static std::string RED_BOLD = "\033[1;31m";
static std::string GREEN = "\033[0;32m";
static std::string GREEN_BOLD = "\033[1;32m";
static std::string RESET = "\033[0m";
/* ************************************************************************ */
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> //
EliminateHybrid(const HybridFactorGraph &factors,
@ -33,26 +42,66 @@ EliminateHybrid(const HybridFactorGraph &factors,
// so that the discrete parts will be guaranteed to be eliminated last!
// PREPROCESS: Identify the nature of the current elimination
KeySet allKeys;
std::unordered_map<Key, DiscreteKey> discreteCardinalities;
std::set<DiscreteKey> discreteSeparator;
std::set<DiscreteKey> discreteFrontals;
KeySet separatorKeys;
KeySet allContinuousKeys;
KeySet continuousFrontals;
KeySet continuousSeparator;
// TODO: we do a mock by just doing the correct key thing
std::cout << "Begin Eliminate: ";
// This initializes separatorKeys and discreteCardinalities
for (auto &&factor : factors) {
std::cout << ">>> Adding factor: " << GREEN;
factor->print();
std::cout << RESET;
separatorKeys.insert(factor->begin(), factor->end());
if (!factor->isContinuous_) {
for (auto &k : factor->discreteKeys_) {
discreteCardinalities[k.first] = k;
}
}
}
// remove frontals from separator
for (auto &k : frontalKeys) {
separatorKeys.erase(k);
}
// Fill in discrete frontals and continuous frontals for the end result
for (auto &k : frontalKeys) {
if (discreteCardinalities.find(k) != discreteCardinalities.end()) {
discreteFrontals.insert(discreteCardinalities.at(k));
} else {
continuousFrontals.insert(k);
}
}
// Fill in discrete frontals and continuous frontals for the end result
for (auto &k : separatorKeys) {
if (discreteCardinalities.find(k) != discreteCardinalities.end()) {
discreteSeparator.insert(discreteCardinalities.at(k));
} else {
continuousSeparator.insert(k);
}
}
std::cout << RED_BOLD << "Begin Eliminate: " << RESET;
frontalKeys.print();
for (auto &&factor : factors) {
std::cout << ">>> Adding factor: ";
factor->print();
allKeys.insert(factor->begin(), factor->end());
}
for (auto &k : frontalKeys) {
allKeys.erase(k);
}
std::cout << RED_BOLD << "Discrete Keys: " << RESET;
for (auto &&key : discreteCardinalities)
std::cout << boost::format(" (%1%,%2%),") % DefaultKeyFormatter(key.second.first) % key.second.second;
std::cout << "\n" << RESET;
// PRODUCT: multiply all factors
gttic(product);
HybridConditional sum(allKeys.size(), Ordering(allKeys));
HybridConditional sum(KeyVector(continuousSeparator.begin(), continuousSeparator.end()),
DiscreteKeys(discreteSeparator.begin(), discreteSeparator.end()),
separatorKeys.size());
// HybridDiscreteFactor product(DiscreteConditional());
// for (auto&& factor : factors) product = (*factor) * product;
@ -76,10 +125,22 @@ EliminateHybrid(const HybridFactorGraph &factors,
// boost::make_shared<HybridConditional>(product, *sum, orderedKeys);
gttoc(divide);
auto conditional = boost::make_shared<HybridConditional>(
CollectKeys({continuousFrontals.begin(), continuousFrontals.end()},
{continuousSeparator.begin(), continuousSeparator.end()}),
CollectDiscreteKeys({discreteFrontals.begin(), discreteFrontals.end()},
{discreteSeparator.begin(), discreteSeparator.end()}),
continuousFrontals.size() + discreteFrontals.size());
std::cout << GREEN_BOLD << "[Conditional]\n" << RESET;
conditional->print();
std::cout << GREEN_BOLD << "[Separator]\n" << RESET;
sum.print();
std::cout << RED_BOLD << "[End Eliminate]\n" << RESET;
// return std::make_pair(conditional, sum);
return std::make_pair(boost::make_shared<HybridConditional>(frontalKeys.size(),
orderedKeys),
boost::make_shared<HybridConditional>(std::move(sum)));
return std::make_pair(
conditional,
boost::make_shared<HybridConditional>(std::move(sum)));
}
void HybridFactorGraph::add(JacobianFactor &&factor) {

View File

@ -42,7 +42,7 @@ using gtsam::symbol_shorthand::X;
using gtsam::symbol_shorthand::C;
/* ************************************************************************* */
TEST_UNSAFE(HybridFactorGraph, creation) {
TEST(HybridFactorGraph, creation) {
HybridConditional test;
HybridFactorGraph hfg;
@ -52,12 +52,19 @@ TEST_UNSAFE(HybridFactorGraph, creation) {
CLGaussianConditional clgc(
{X(0)}, {X(1)},
DiscreteKeys(DiscreteKey{C(0), 2}),
CLGaussianConditional::Conditionals()
CLGaussianConditional::Conditionals(
C(0),
boost::make_shared<GaussianConditional>(
X(0), Z_3x1, I_3x3, X(1), I_3x3),
boost::make_shared<GaussianConditional>(
X(0), Vector3::Ones(),
I_3x3, X(1),
I_3x3))
);
GTSAM_PRINT(clgc);
}
TEST_UNSAFE(HybridFactorGraph, eliminate) {
TEST_DISABLED(HybridFactorGraph, eliminate) {
HybridFactorGraph hfg;
hfg.add(HybridGaussianFactor(JacobianFactor(0, I_3x3, Z_3x1)));
@ -67,7 +74,7 @@ TEST_UNSAFE(HybridFactorGraph, eliminate) {
EXPECT_LONGS_EQUAL(result.first->size(), 1);
}
TEST(HybridFactorGraph, eliminateMultifrontal) {
TEST_DISABLED(HybridFactorGraph, eliminateMultifrontal) {
HybridFactorGraph hfg;
DiscreteKey x(X(1), 2);