Added more Python examples

release/4.3a0
Fan Jiang 2022-03-25 23:28:40 -06:00
parent d2dc620b1e
commit 7f2fa61fb5
5 changed files with 71 additions and 23 deletions

View File

@ -19,7 +19,10 @@
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridFactorGraph.h>
#include <gtsam/inference/Conditional.h>
#include <gtsam/inference/Key.h>
#include <gtsam/linear/GaussianConditional.h>
#include <boost/make_shared.hpp>
#include <boost/shared_ptr.hpp>
@ -27,11 +30,8 @@
#include <string>
#include <typeinfo>
#include <vector>
#include "gtsam/hybrid/GaussianMixture.h"
#include <gtsam/hybrid/HybridFactorGraph.h>
#include <gtsam/inference/Key.h>
#include <gtsam/linear/GaussianConditional.h>
#include "gtsam/hybrid/GaussianMixture.h"
namespace gtsam {
@ -44,6 +44,19 @@ class HybridFactorGraph;
* - DiscreteConditional
* - GaussianConditional
* - GaussianMixture
*
* The reason why this is important is that `Conditional<T>` is a CRTP class.
* CRTP is static polymorphism such that all CRTP classes, while bearing the
* same name, are different classes not sharing a vtable. This prevents them
* from being contained in any container, and thus it is impossible to
* dynamically cast between them. A better option, as illustrated here, is
* treating them as an implementation detail - such that the hybrid mechanism
* does not know what is inside the HybridConditional. This prevents us from
* having diamond inheritances, and neutralized the need to change other
* components of GTSAM to make hybrid elimination work.
*
* A great reference to the type-erasure pattern is Edurado Madrid's CppCon
* talk.
*/
class GTSAM_EXPORT HybridConditional
: public HybridFactor,
@ -76,15 +89,20 @@ class GTSAM_EXPORT HybridConditional
const KeyVector& continuousParents,
const DiscreteKeys& discreteParents);
HybridConditional(boost::shared_ptr<GaussianConditional> continuousConditional);
HybridConditional(
boost::shared_ptr<GaussianConditional> continuousConditional);
HybridConditional(boost::shared_ptr<DiscreteConditional> discreteConditional);
HybridConditional(boost::shared_ptr<GaussianMixture> gaussianMixture);
GaussianMixture::shared_ptr asMixture() {
if (!isHybrid_) throw std::invalid_argument("Not a mixture");
return boost::static_pointer_cast<GaussianMixture>(inner);
if (!isHybrid_) throw std::invalid_argument("Not a mixture");
return boost::static_pointer_cast<GaussianMixture>(inner);
}
boost::shared_ptr<Factor> getInner() {
return inner;
}
/// @}

View File

@ -57,7 +57,7 @@ static std::string GREEN = "\033[0;32m";
static std::string GREEN_BOLD = "\033[1;32m";
static std::string RESET = "\033[0m";
static bool DEBUG = false;
constexpr bool DEBUG = false;
static GaussianMixtureFactor::Sum &addGaussian(
GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) {
@ -123,7 +123,7 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
// However this is also the case with iSAM2, so no pressure :)
// PREPROCESS: Identify the nature of the current elimination
std::unordered_map<Key, DiscreteKey> discreteCardinalities;
std::unordered_map<Key, DiscreteKey> mapFromKeyToDiscreteKey;
std::set<DiscreteKey> discreteSeparatorSet;
std::set<DiscreteKey> discreteFrontals;
@ -137,7 +137,7 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
frontalKeys.print();
}
// This initializes separatorKeys and discreteCardinalities
// This initializes separatorKeys and mapFromKeyToDiscreteKey
for (auto &&factor : factors) {
if (DEBUG) {
std::cout << ">>> Adding factor: " << GREEN;
@ -147,7 +147,7 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
separatorKeys.insert(factor->begin(), factor->end());
if (!factor->isContinuous_) {
for (auto &k : factor->discreteKeys_) {
discreteCardinalities[k.first] = k;
mapFromKeyToDiscreteKey[k.first] = k;
}
}
}
@ -159,8 +159,8 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
// Fill in discrete frontals and continuous frontals for the end result
for (auto &k : frontalKeys) {
if (discreteCardinalities.find(k) != discreteCardinalities.end()) {
discreteFrontals.insert(discreteCardinalities.at(k));
if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
discreteFrontals.insert(mapFromKeyToDiscreteKey.at(k));
} else {
continuousFrontals.insert(k);
allContinuousKeys.insert(k);
@ -169,8 +169,8 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
// Fill in discrete frontals and continuous frontals for the end result
for (auto &k : separatorKeys) {
if (discreteCardinalities.find(k) != discreteCardinalities.end()) {
discreteSeparatorSet.insert(discreteCardinalities.at(k));
if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
discreteSeparatorSet.insert(mapFromKeyToDiscreteKey.at(k));
} else {
continuousSeparator.insert(k);
allContinuousKeys.insert(k);
@ -181,8 +181,8 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
if (DEBUG) {
std::cout << RED_BOLD << "Keys: " << RESET;
for (auto &f : frontalKeys) {
if (discreteCardinalities.find(f) != discreteCardinalities.end()) {
auto &key = discreteCardinalities.at(f);
if (mapFromKeyToDiscreteKey.find(f) != mapFromKeyToDiscreteKey.end()) {
auto &key = mapFromKeyToDiscreteKey.at(f);
std::cout << boost::format(" (%1%,%2%),") %
DefaultKeyFormatter(key.first) % key.second;
} else {
@ -195,8 +195,8 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
}
for (auto &f : separatorKeys) {
if (discreteCardinalities.find(f) != discreteCardinalities.end()) {
auto &key = discreteCardinalities.at(f);
if (mapFromKeyToDiscreteKey.find(f) != mapFromKeyToDiscreteKey.end()) {
auto &key = mapFromKeyToDiscreteKey.at(f);
std::cout << boost::format(" (%1%,%2%),") %
DefaultKeyFormatter(key.first) % key.second;
} else {
@ -209,7 +209,7 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
// NOTE: We should really defer the product here because of pruning
// Case 1: we are only dealing with continuous
if (discreteCardinalities.empty() && !allContinuousKeys.empty()) {
if (mapFromKeyToDiscreteKey.empty() && !allContinuousKeys.empty()) {
if (DEBUG) {
std::cout << RED_BOLD << "CONT. ONLY" << RESET << "\n";
}

View File

@ -59,10 +59,15 @@ struct HybridConstructorTraversalData {
myData.myJTNode = boost::make_shared<Node>(node->key, node->factors);
parentData.myJTNode->addChild(myData.myJTNode);
#ifndef NDEBUG
std::cout << "Getting discrete info: ";
#endif
for (HybridFactor::shared_ptr& f : node->factors) {
for (auto& k : f->discreteKeys_) {
#ifndef NDEBUG
std::cout << "DK: " << DefaultKeyFormatter(k.first) << "\n";
#endif
myData.discreteKeys.insert(k.first);
}
}
@ -99,8 +104,10 @@ struct HybridConstructorTraversalData {
boost::tie(myConditional, mySeparatorFactor) =
internal::EliminateSymbolic(symbolicFactors, keyAsOrdering);
#ifndef NDEBUG
std::cout << "Symbolic: ";
myConditional->print();
#endif
// Store symbolic elimination results in the parent
myData.parentData->childSymbolicConditionals.push_back(myConditional);
@ -129,15 +136,19 @@ struct HybridConstructorTraversalData {
myData.discreteKeys.exists(myConditional->frontals()[0]);
const bool theirType =
myData.discreteKeys.exists(childConditionals[i]->frontals()[0]);
#ifndef NDEBUG
std::cout << "Type: "
<< DefaultKeyFormatter(myConditional->frontals()[0]) << " vs "
<< DefaultKeyFormatter(childConditionals[i]->frontals()[0])
<< "\n";
#endif
if (myType == theirType) {
// Increment number of frontal variables
myNrFrontals += nrFrontals[i];
#ifndef NDEBUG
std::cout << "Merging ";
childConditionals[i]->print();
#endif
merge[i] = true;
}
}

View File

@ -15,6 +15,17 @@ virtual class HybridFactor {
gtsam::KeyVector keys() const;
};
#include <gtsam/hybrid/HybridConditional.h>
virtual class HybridConditional {
void print(string s = "Hybrid Conditional\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::HybridConditional& other, double tol = 1e-9) const;
size_t nrFrontals() const;
size_t nrParents() const;
Factor* getInner();
};
#include <gtsam/hybrid/GaussianMixtureFactor.h>
class GaussianMixtureFactor : gtsam::HybridFactor {
static GaussianMixtureFactor FromFactorList(

View File

@ -40,9 +40,17 @@ class TestHybridFactorGraph(GtsamTestCase):
hfg.add(jf2)
hfg.push_back(gmf)
hfg.eliminateSequential(
hbn = hfg.eliminateSequential(
gtsam.Ordering.ColamdConstrainedLastHybridFactorGraph(
hfg, [C(0)])).print()
hfg, [C(0)]))
print("hbn = ", hbn)
mixture = hbn.at(0).getInner()
print(mixture)
discrete_conditional = hbn.at(hbn.size()-1).getInner()
print(discrete_conditional)
if __name__ == "__main__":