Added more Python examples
parent
d2dc620b1e
commit
7f2fa61fb5
|
@ -19,7 +19,10 @@
|
|||
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/hybrid/HybridFactor.h>
|
||||
#include <gtsam/hybrid/HybridFactorGraph.h>
|
||||
#include <gtsam/inference/Conditional.h>
|
||||
#include <gtsam/inference/Key.h>
|
||||
#include <gtsam/linear/GaussianConditional.h>
|
||||
|
||||
#include <boost/make_shared.hpp>
|
||||
#include <boost/shared_ptr.hpp>
|
||||
|
@ -27,11 +30,8 @@
|
|||
#include <string>
|
||||
#include <typeinfo>
|
||||
#include <vector>
|
||||
#include "gtsam/hybrid/GaussianMixture.h"
|
||||
|
||||
#include <gtsam/hybrid/HybridFactorGraph.h>
|
||||
#include <gtsam/inference/Key.h>
|
||||
#include <gtsam/linear/GaussianConditional.h>
|
||||
#include "gtsam/hybrid/GaussianMixture.h"
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
|
@ -44,6 +44,19 @@ class HybridFactorGraph;
|
|||
* - DiscreteConditional
|
||||
* - GaussianConditional
|
||||
* - GaussianMixture
|
||||
*
|
||||
* The reason why this is important is that `Conditional<T>` is a CRTP class.
|
||||
* CRTP is static polymorphism such that all CRTP classes, while bearing the
|
||||
* same name, are different classes not sharing a vtable. This prevents them
|
||||
* from being contained in any container, and thus it is impossible to
|
||||
* dynamically cast between them. A better option, as illustrated here, is
|
||||
* treating them as an implementation detail - such that the hybrid mechanism
|
||||
* does not know what is inside the HybridConditional. This prevents us from
|
||||
* having diamond inheritances, and neutralized the need to change other
|
||||
* components of GTSAM to make hybrid elimination work.
|
||||
*
|
||||
* A great reference to the type-erasure pattern is Edurado Madrid's CppCon
|
||||
* talk.
|
||||
*/
|
||||
class GTSAM_EXPORT HybridConditional
|
||||
: public HybridFactor,
|
||||
|
@ -76,15 +89,20 @@ class GTSAM_EXPORT HybridConditional
|
|||
const KeyVector& continuousParents,
|
||||
const DiscreteKeys& discreteParents);
|
||||
|
||||
HybridConditional(boost::shared_ptr<GaussianConditional> continuousConditional);
|
||||
HybridConditional(
|
||||
boost::shared_ptr<GaussianConditional> continuousConditional);
|
||||
|
||||
HybridConditional(boost::shared_ptr<DiscreteConditional> discreteConditional);
|
||||
|
||||
|
||||
HybridConditional(boost::shared_ptr<GaussianMixture> gaussianMixture);
|
||||
|
||||
GaussianMixture::shared_ptr asMixture() {
|
||||
if (!isHybrid_) throw std::invalid_argument("Not a mixture");
|
||||
return boost::static_pointer_cast<GaussianMixture>(inner);
|
||||
if (!isHybrid_) throw std::invalid_argument("Not a mixture");
|
||||
return boost::static_pointer_cast<GaussianMixture>(inner);
|
||||
}
|
||||
|
||||
boost::shared_ptr<Factor> getInner() {
|
||||
return inner;
|
||||
}
|
||||
|
||||
/// @}
|
||||
|
|
|
@ -57,7 +57,7 @@ static std::string GREEN = "\033[0;32m";
|
|||
static std::string GREEN_BOLD = "\033[1;32m";
|
||||
static std::string RESET = "\033[0m";
|
||||
|
||||
static bool DEBUG = false;
|
||||
constexpr bool DEBUG = false;
|
||||
|
||||
static GaussianMixtureFactor::Sum &addGaussian(
|
||||
GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) {
|
||||
|
@ -123,7 +123,7 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
|
|||
// However this is also the case with iSAM2, so no pressure :)
|
||||
|
||||
// PREPROCESS: Identify the nature of the current elimination
|
||||
std::unordered_map<Key, DiscreteKey> discreteCardinalities;
|
||||
std::unordered_map<Key, DiscreteKey> mapFromKeyToDiscreteKey;
|
||||
std::set<DiscreteKey> discreteSeparatorSet;
|
||||
std::set<DiscreteKey> discreteFrontals;
|
||||
|
||||
|
@ -137,7 +137,7 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
|
|||
frontalKeys.print();
|
||||
}
|
||||
|
||||
// This initializes separatorKeys and discreteCardinalities
|
||||
// This initializes separatorKeys and mapFromKeyToDiscreteKey
|
||||
for (auto &&factor : factors) {
|
||||
if (DEBUG) {
|
||||
std::cout << ">>> Adding factor: " << GREEN;
|
||||
|
@ -147,7 +147,7 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
|
|||
separatorKeys.insert(factor->begin(), factor->end());
|
||||
if (!factor->isContinuous_) {
|
||||
for (auto &k : factor->discreteKeys_) {
|
||||
discreteCardinalities[k.first] = k;
|
||||
mapFromKeyToDiscreteKey[k.first] = k;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -159,8 +159,8 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
|
|||
|
||||
// Fill in discrete frontals and continuous frontals for the end result
|
||||
for (auto &k : frontalKeys) {
|
||||
if (discreteCardinalities.find(k) != discreteCardinalities.end()) {
|
||||
discreteFrontals.insert(discreteCardinalities.at(k));
|
||||
if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
|
||||
discreteFrontals.insert(mapFromKeyToDiscreteKey.at(k));
|
||||
} else {
|
||||
continuousFrontals.insert(k);
|
||||
allContinuousKeys.insert(k);
|
||||
|
@ -169,8 +169,8 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
|
|||
|
||||
// Fill in discrete frontals and continuous frontals for the end result
|
||||
for (auto &k : separatorKeys) {
|
||||
if (discreteCardinalities.find(k) != discreteCardinalities.end()) {
|
||||
discreteSeparatorSet.insert(discreteCardinalities.at(k));
|
||||
if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
|
||||
discreteSeparatorSet.insert(mapFromKeyToDiscreteKey.at(k));
|
||||
} else {
|
||||
continuousSeparator.insert(k);
|
||||
allContinuousKeys.insert(k);
|
||||
|
@ -181,8 +181,8 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
|
|||
if (DEBUG) {
|
||||
std::cout << RED_BOLD << "Keys: " << RESET;
|
||||
for (auto &f : frontalKeys) {
|
||||
if (discreteCardinalities.find(f) != discreteCardinalities.end()) {
|
||||
auto &key = discreteCardinalities.at(f);
|
||||
if (mapFromKeyToDiscreteKey.find(f) != mapFromKeyToDiscreteKey.end()) {
|
||||
auto &key = mapFromKeyToDiscreteKey.at(f);
|
||||
std::cout << boost::format(" (%1%,%2%),") %
|
||||
DefaultKeyFormatter(key.first) % key.second;
|
||||
} else {
|
||||
|
@ -195,8 +195,8 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
|
|||
}
|
||||
|
||||
for (auto &f : separatorKeys) {
|
||||
if (discreteCardinalities.find(f) != discreteCardinalities.end()) {
|
||||
auto &key = discreteCardinalities.at(f);
|
||||
if (mapFromKeyToDiscreteKey.find(f) != mapFromKeyToDiscreteKey.end()) {
|
||||
auto &key = mapFromKeyToDiscreteKey.at(f);
|
||||
std::cout << boost::format(" (%1%,%2%),") %
|
||||
DefaultKeyFormatter(key.first) % key.second;
|
||||
} else {
|
||||
|
@ -209,7 +209,7 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
|
|||
// NOTE: We should really defer the product here because of pruning
|
||||
|
||||
// Case 1: we are only dealing with continuous
|
||||
if (discreteCardinalities.empty() && !allContinuousKeys.empty()) {
|
||||
if (mapFromKeyToDiscreteKey.empty() && !allContinuousKeys.empty()) {
|
||||
if (DEBUG) {
|
||||
std::cout << RED_BOLD << "CONT. ONLY" << RESET << "\n";
|
||||
}
|
||||
|
|
|
@ -59,10 +59,15 @@ struct HybridConstructorTraversalData {
|
|||
myData.myJTNode = boost::make_shared<Node>(node->key, node->factors);
|
||||
parentData.myJTNode->addChild(myData.myJTNode);
|
||||
|
||||
#ifndef NDEBUG
|
||||
std::cout << "Getting discrete info: ";
|
||||
#endif
|
||||
for (HybridFactor::shared_ptr& f : node->factors) {
|
||||
for (auto& k : f->discreteKeys_) {
|
||||
#ifndef NDEBUG
|
||||
std::cout << "DK: " << DefaultKeyFormatter(k.first) << "\n";
|
||||
#endif
|
||||
|
||||
myData.discreteKeys.insert(k.first);
|
||||
}
|
||||
}
|
||||
|
@ -99,8 +104,10 @@ struct HybridConstructorTraversalData {
|
|||
boost::tie(myConditional, mySeparatorFactor) =
|
||||
internal::EliminateSymbolic(symbolicFactors, keyAsOrdering);
|
||||
|
||||
#ifndef NDEBUG
|
||||
std::cout << "Symbolic: ";
|
||||
myConditional->print();
|
||||
#endif
|
||||
|
||||
// Store symbolic elimination results in the parent
|
||||
myData.parentData->childSymbolicConditionals.push_back(myConditional);
|
||||
|
@ -129,15 +136,19 @@ struct HybridConstructorTraversalData {
|
|||
myData.discreteKeys.exists(myConditional->frontals()[0]);
|
||||
const bool theirType =
|
||||
myData.discreteKeys.exists(childConditionals[i]->frontals()[0]);
|
||||
#ifndef NDEBUG
|
||||
std::cout << "Type: "
|
||||
<< DefaultKeyFormatter(myConditional->frontals()[0]) << " vs "
|
||||
<< DefaultKeyFormatter(childConditionals[i]->frontals()[0])
|
||||
<< "\n";
|
||||
#endif
|
||||
if (myType == theirType) {
|
||||
// Increment number of frontal variables
|
||||
myNrFrontals += nrFrontals[i];
|
||||
#ifndef NDEBUG
|
||||
std::cout << "Merging ";
|
||||
childConditionals[i]->print();
|
||||
#endif
|
||||
merge[i] = true;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,6 +15,17 @@ virtual class HybridFactor {
|
|||
gtsam::KeyVector keys() const;
|
||||
};
|
||||
|
||||
#include <gtsam/hybrid/HybridConditional.h>
|
||||
virtual class HybridConditional {
|
||||
void print(string s = "Hybrid Conditional\n",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
bool equals(const gtsam::HybridConditional& other, double tol = 1e-9) const;
|
||||
size_t nrFrontals() const;
|
||||
size_t nrParents() const;
|
||||
Factor* getInner();
|
||||
};
|
||||
|
||||
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||
class GaussianMixtureFactor : gtsam::HybridFactor {
|
||||
static GaussianMixtureFactor FromFactorList(
|
||||
|
|
|
@ -40,9 +40,17 @@ class TestHybridFactorGraph(GtsamTestCase):
|
|||
hfg.add(jf2)
|
||||
hfg.push_back(gmf)
|
||||
|
||||
hfg.eliminateSequential(
|
||||
hbn = hfg.eliminateSequential(
|
||||
gtsam.Ordering.ColamdConstrainedLastHybridFactorGraph(
|
||||
hfg, [C(0)])).print()
|
||||
hfg, [C(0)]))
|
||||
|
||||
print("hbn = ", hbn)
|
||||
|
||||
mixture = hbn.at(0).getInner()
|
||||
print(mixture)
|
||||
|
||||
discrete_conditional = hbn.at(hbn.size()-1).getInner()
|
||||
print(discrete_conditional)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue