Address comments

release/4.3a0
Fan Jiang 2022-03-23 20:36:18 -04:00
parent 1e8aae3f06
commit b4f8eea231
10 changed files with 53 additions and 38 deletions

View File

@ -10,7 +10,7 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
/** /**
* @file CGMixtureFactor.cpp * @file GaussianMixtureFactor.cpp
* @brief A set of Gaussian factors indexed by a set of discrete keys. * @brief A set of Gaussian factors indexed by a set of discrete keys.
* @author Fan Jiang * @author Fan Jiang
* @author Varun Agrawal * @author Varun Agrawal
@ -18,7 +18,7 @@
* @date Mar 12, 2022 * @date Mar 12, 2022
*/ */
#include <gtsam/hybrid/CGMixtureFactor.h> #include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/discrete/DecisionTree.h> #include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DecisionTree-inl.h> #include <gtsam/discrete/DecisionTree-inl.h>
@ -27,15 +27,15 @@
namespace gtsam { namespace gtsam {
CGMixtureFactor::CGMixtureFactor(const KeyVector &continuousKeys, GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys, const DiscreteKeys &discreteKeys,
const Factors &factors) : Base(continuousKeys, discreteKeys), const Factors &factors) : Base(continuousKeys, discreteKeys),
factors_(factors) {} factors_(factors) {}
bool CGMixtureFactor::equals(const HybridFactor &lf, double tol) const { bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
return false; return false;
} }
void CGMixtureFactor::print(const std::string &s, const KeyFormatter &formatter) const { void GaussianMixtureFactor::print(const std::string &s, const KeyFormatter &formatter) const {
HybridFactor::print(s, formatter); HybridFactor::print(s, formatter);
factors_.print( factors_.print(
"mixture = ", "mixture = ",
@ -49,12 +49,12 @@ void CGMixtureFactor::print(const std::string &s, const KeyFormatter &formatter)
}); });
} }
const CGMixtureFactor::Factors& CGMixtureFactor::factors() { const GaussianMixtureFactor::Factors& GaussianMixtureFactor::factors() {
return factors_; return factors_;
} }
/* *******************************************************************************/ /* *******************************************************************************/
CGMixtureFactor::Sum CGMixtureFactor::addTo(const CGMixtureFactor::Sum &sum) const { GaussianMixtureFactor::Sum GaussianMixtureFactor::addTo(const GaussianMixtureFactor::Sum &sum) const {
using Y = GaussianFactorGraph; using Y = GaussianFactorGraph;
auto add = [](const Y &graph1, const Y &graph2) { auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1; auto result = graph1;
@ -66,7 +66,7 @@ CGMixtureFactor::Sum CGMixtureFactor::addTo(const CGMixtureFactor::Sum &sum) con
} }
/* *******************************************************************************/ /* *******************************************************************************/
CGMixtureFactor::Sum CGMixtureFactor::wrappedFactors() const { GaussianMixtureFactor::Sum GaussianMixtureFactor::wrappedFactors() const {
auto wrap = [](const GaussianFactor::shared_ptr &factor) { auto wrap = [](const GaussianFactor::shared_ptr &factor) {
GaussianFactorGraph result; GaussianFactorGraph result;
result.push_back(factor); result.push_back(factor);

View File

@ -10,7 +10,7 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
/** /**
* @file CGMixtureFactor.h * @file GaussianMixtureFactor.h
* @brief A set of Gaussian factors indexed by a set of discrete keys. * @brief A set of Gaussian factors indexed by a set of discrete keys.
* @author Fan Jiang * @author Fan Jiang
* @author Varun Agrawal * @author Varun Agrawal
@ -29,19 +29,19 @@ namespace gtsam {
class GaussianFactorGraph; class GaussianFactorGraph;
class CGMixtureFactor : public HybridFactor { class GaussianMixtureFactor : public HybridFactor {
public: public:
using Base = HybridFactor; using Base = HybridFactor;
using This = CGMixtureFactor; using This = GaussianMixtureFactor;
using shared_ptr = boost::shared_ptr<This>; using shared_ptr = boost::shared_ptr<This>;
using Factors = DecisionTree<Key, GaussianFactor::shared_ptr>; using Factors = DecisionTree<Key, GaussianFactor::shared_ptr>;
Factors factors_; Factors factors_;
CGMixtureFactor() = default; GaussianMixtureFactor() = default;
CGMixtureFactor(const KeyVector &continuousKeys, GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys, const Factors &factors); const DiscreteKeys &discreteKeys, const Factors &factors);
using Sum = DecisionTree<Key, GaussianFactorGraph>; using Sum = DecisionTree<Key, GaussianFactorGraph>;

View File

@ -25,8 +25,8 @@
namespace gtsam { namespace gtsam {
/** /**
* A hybrid Bayes net can have discrete conditionals, Gaussian mixtures, * A hybrid Bayes net is a collection of HybridConditionals, which can have
* or pure Gaussian conditionals. * discrete conditionals, Gaussian mixtures, or pure Gaussian conditionals.
*/ */
class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> { class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
public: public:

View File

@ -33,7 +33,9 @@ class HybridConditional;
class VectorValues; class VectorValues;
/* ************************************************************************* */ /* ************************************************************************* */
/** A clique in a HybridBayesTree */ /** A clique in a HybridBayesTree
* which is a HybridConditional internally.
*/
class GTSAM_EXPORT HybridBayesTreeClique class GTSAM_EXPORT HybridBayesTreeClique
: public BayesTreeCliqueBase<HybridBayesTreeClique, HybridFactorGraph> { : public BayesTreeCliqueBase<HybridBayesTreeClique, HybridFactorGraph> {
public: public:

View File

@ -23,6 +23,10 @@
namespace gtsam { namespace gtsam {
/**
* A HybridDiscreteFactor is a wrapper for DiscreteFactor, so we hide the
* implementation of DiscreteFactor, and thus avoiding diamond inheritance.
*/
class HybridDiscreteFactor : public HybridFactor { class HybridDiscreteFactor : public HybridFactor {
public: public:
using Base = HybridFactor; using Base = HybridFactor;

View File

@ -34,6 +34,11 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
/** /**
* Base class for hybrid probabilistic factors * Base class for hybrid probabilistic factors
* Examples:
* - HybridGaussianFactor
* - HybridDiscreteFactor
* - GaussianMixtureFactor
* - GaussianMixture
*/ */
class GTSAM_EXPORT HybridFactor : public Factor { class GTSAM_EXPORT HybridFactor : public Factor {
public: public:

View File

@ -21,7 +21,7 @@
#include <gtsam/base/utilities.h> #include <gtsam/base/utilities.h>
#include <gtsam/discrete/Assignment.h> #include <gtsam/discrete/Assignment.h>
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/hybrid/CGMixtureFactor.h> #include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/GaussianMixture.h> #include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/HybridConditional.h> #include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridDiscreteFactor.h> #include <gtsam/hybrid/HybridDiscreteFactor.h>
@ -56,14 +56,14 @@ static std::string GREEN = "\033[0;32m";
static std::string GREEN_BOLD = "\033[1;32m"; static std::string GREEN_BOLD = "\033[1;32m";
static std::string RESET = "\033[0m"; static std::string RESET = "\033[0m";
static CGMixtureFactor::Sum &addGaussian( static GaussianMixtureFactor::Sum &addGaussian(
CGMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) { GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) {
using Y = GaussianFactorGraph; using Y = GaussianFactorGraph;
// If the decision tree is not intiialized, then intialize it. // If the decision tree is not intiialized, then intialize it.
if (sum.empty()) { if (sum.empty()) {
GaussianFactorGraph result; GaussianFactorGraph result;
result.push_back(factor); result.push_back(factor);
sum = CGMixtureFactor::Sum(result); sum = GaussianMixtureFactor::Sum(result);
} else { } else {
auto add = [&factor](const Y &graph) { auto add = [&factor](const Y &graph) {
@ -307,7 +307,7 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
// continue; // continue;
// } // }
// auto ptr_mf = boost::dynamic_pointer_cast<CGMixtureFactor>(factor); // auto ptr_mf = boost::dynamic_pointer_cast<GaussianMixtureFactor>(factor);
// if (ptr_mf) gf.push_back(ptr_mf->factors_(new_assignment)); // if (ptr_mf) gf.push_back(ptr_mf->factors_(new_assignment));
// auto ptr_gm = boost::dynamic_pointer_cast<GaussianMixture>(factor); // auto ptr_gm = boost::dynamic_pointer_cast<GaussianMixture>(factor);
@ -329,13 +329,13 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
std::cout << RED_BOLD << "HYBRID ELIM." << RESET << "\n"; std::cout << RED_BOLD << "HYBRID ELIM." << RESET << "\n";
CGMixtureFactor::Sum sum; GaussianMixtureFactor::Sum sum;
std::vector<GaussianFactor::shared_ptr> deferredFactors; std::vector<GaussianFactor::shared_ptr> deferredFactors;
for (auto &f : factors) { for (auto &f : factors) {
if (f->isHybrid_) { if (f->isHybrid_) {
auto cgmf = boost::dynamic_pointer_cast<CGMixtureFactor>(f); auto cgmf = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f);
if (cgmf) { if (cgmf) {
sum = cgmf->addTo(sum); sum = cgmf->addTo(sum);
} }
@ -395,7 +395,7 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
auto pair = unzip(eliminationResults); auto pair = unzip(eliminationResults);
const GaussianMixture::Conditionals &conditionals = pair.first; const GaussianMixture::Conditionals &conditionals = pair.first;
const CGMixtureFactor::Factors &separatorFactors = pair.second; const GaussianMixtureFactor::Factors &separatorFactors = pair.second;
// Create the GaussianMixture from the conditionals // Create the GaussianMixture from the conditionals
auto conditional = boost::make_shared<GaussianMixture>( auto conditional = boost::make_shared<GaussianMixture>(
@ -429,7 +429,7 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
} else { } else {
// Create a resulting DCGaussianMixture on the separator. // Create a resulting DCGaussianMixture on the separator.
auto factor = boost::make_shared<CGMixtureFactor>( auto factor = boost::make_shared<GaussianMixtureFactor>(
frontalKeys, discreteSeparator, separatorFactors); frontalKeys, discreteSeparator, separatorFactors);
return {boost::make_shared<HybridConditional>(conditional), factor}; return {boost::make_shared<HybridConditional>(conditional), factor};
} }

View File

@ -23,6 +23,10 @@
namespace gtsam { namespace gtsam {
/**
* A HybridGaussianFactor is a wrapper for GaussianFactor so that we do not have
* a diamond inheritance.
*/
class HybridGaussianFactor : public HybridFactor { class HybridGaussianFactor : public HybridFactor {
public: public:
using Base = HybridFactor; using Base = HybridFactor;

View File

@ -17,7 +17,7 @@
#include <CppUnitLite/Test.h> #include <CppUnitLite/Test.h>
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam/hybrid/CGMixtureFactor.h> #include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/GaussianMixture.h> #include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/HybridBayesNet.h> #include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h> #include <gtsam/hybrid/HybridBayesTree.h>
@ -108,8 +108,8 @@ TEST(HybridFactorGraph, eliminateFullSequentialSimple) {
C(1), boost::make_shared<JacobianFactor>(X(1), I_3x3, Z_3x1), C(1), boost::make_shared<JacobianFactor>(X(1), I_3x3, Z_3x1),
boost::make_shared<JacobianFactor>(X(1), I_3x3, Vector3::Ones())); boost::make_shared<JacobianFactor>(X(1), I_3x3, Vector3::Ones()));
hfg.add(CGMixtureFactor({X(1)}, {c1}, dt)); hfg.add(GaussianMixtureFactor({X(1)}, {c1}, dt));
// hfg.add(CGMixtureFactor({X(0)}, {c1}, dt)); // hfg.add(GaussianMixtureFactor({X(0)}, {c1}, dt));
hfg.add(HybridDiscreteFactor(DecisionTreeFactor(c1, {2, 8}))); hfg.add(HybridDiscreteFactor(DecisionTreeFactor(c1, {2, 8})));
hfg.add(HybridDiscreteFactor( hfg.add(HybridDiscreteFactor(
DecisionTreeFactor({{C(1), 2}, {C(2), 2}}, "1 2 3 4"))); DecisionTreeFactor({{C(1), 2}, {C(2), 2}}, "1 2 3 4")));
@ -137,8 +137,8 @@ TEST(HybridFactorGraph, eliminateFullMultifrontalSimple) {
C(1), boost::make_shared<JacobianFactor>(X(1), I_3x3, Z_3x1), C(1), boost::make_shared<JacobianFactor>(X(1), I_3x3, Z_3x1),
boost::make_shared<JacobianFactor>(X(1), I_3x3, Vector3::Ones())); boost::make_shared<JacobianFactor>(X(1), I_3x3, Vector3::Ones()));
hfg.add(CGMixtureFactor({X(1)}, {c1}, dt)); hfg.add(GaussianMixtureFactor({X(1)}, {c1}, dt));
// hfg.add(CGMixtureFactor({X(0)}, {c1}, dt)); // hfg.add(GaussianMixtureFactor({X(0)}, {c1}, dt));
hfg.add(HybridDiscreteFactor(DecisionTreeFactor(c1, {2, 8}))); hfg.add(HybridDiscreteFactor(DecisionTreeFactor(c1, {2, 8})));
hfg.add(HybridDiscreteFactor( hfg.add(HybridDiscreteFactor(
DecisionTreeFactor({{C(1), 2}, {C(2), 2}}, "1 2 3 4"))); DecisionTreeFactor({{C(1), 2}, {C(2), 2}}, "1 2 3 4")));
@ -167,7 +167,7 @@ TEST_DISABLED(HybridFactorGraph, eliminateFullMultifrontalCLG) {
C(1), boost::make_shared<JacobianFactor>(X(1), I_3x3, Z_3x1), C(1), boost::make_shared<JacobianFactor>(X(1), I_3x3, Z_3x1),
boost::make_shared<JacobianFactor>(X(1), I_3x3, Vector3::Ones())); boost::make_shared<JacobianFactor>(X(1), I_3x3, Vector3::Ones()));
hfg.add(CGMixtureFactor({X(1)}, {c}, dt)); hfg.add(GaussianMixtureFactor({X(1)}, {c}, dt));
hfg.add(HybridDiscreteFactor(DecisionTreeFactor(c, {2, 8}))); hfg.add(HybridDiscreteFactor(DecisionTreeFactor(c, {2, 8})));
// hfg.add(HybridDiscreteFactor(DecisionTreeFactor({{C(1), 2}, {C(2), 2}}, "1 // hfg.add(HybridDiscreteFactor(DecisionTreeFactor({{C(1), 2}, {C(2), 2}}, "1
// 2 3 4"))); // 2 3 4")));
@ -203,13 +203,13 @@ TEST_DISABLED(HybridFactorGraph, eliminateFullMultifrontalTwoClique) {
C(0), boost::make_shared<JacobianFactor>(X(0), I_3x3, Z_3x1), C(0), boost::make_shared<JacobianFactor>(X(0), I_3x3, Z_3x1),
boost::make_shared<JacobianFactor>(X(0), I_3x3, Vector3::Ones())); boost::make_shared<JacobianFactor>(X(0), I_3x3, Vector3::Ones()));
hfg.add(CGMixtureFactor({X(0)}, {{C(0), 2}}, dt)); hfg.add(GaussianMixtureFactor({X(0)}, {{C(0), 2}}, dt));
DecisionTree<Key, GaussianFactor::shared_ptr> dt1( DecisionTree<Key, GaussianFactor::shared_ptr> dt1(
C(1), boost::make_shared<JacobianFactor>(X(2), I_3x3, Z_3x1), C(1), boost::make_shared<JacobianFactor>(X(2), I_3x3, Z_3x1),
boost::make_shared<JacobianFactor>(X(2), I_3x3, Vector3::Ones())); boost::make_shared<JacobianFactor>(X(2), I_3x3, Vector3::Ones()));
hfg.add(CGMixtureFactor({X(2)}, {{C(1), 2}}, dt1)); hfg.add(GaussianMixtureFactor({X(2)}, {{C(1), 2}}, dt1));
} }
// hfg.add(HybridDiscreteFactor(DecisionTreeFactor(c, {2, 8}))); // hfg.add(HybridDiscreteFactor(DecisionTreeFactor(c, {2, 8})));
@ -224,13 +224,13 @@ TEST_DISABLED(HybridFactorGraph, eliminateFullMultifrontalTwoClique) {
C(3), boost::make_shared<JacobianFactor>(X(3), I_3x3, Z_3x1), C(3), boost::make_shared<JacobianFactor>(X(3), I_3x3, Z_3x1),
boost::make_shared<JacobianFactor>(X(3), I_3x3, Vector3::Ones())); boost::make_shared<JacobianFactor>(X(3), I_3x3, Vector3::Ones()));
hfg.add(CGMixtureFactor({X(3)}, {{C(3), 2}}, dt)); hfg.add(GaussianMixtureFactor({X(3)}, {{C(3), 2}}, dt));
DecisionTree<Key, GaussianFactor::shared_ptr> dt1( DecisionTree<Key, GaussianFactor::shared_ptr> dt1(
C(2), boost::make_shared<JacobianFactor>(X(5), I_3x3, Z_3x1), C(2), boost::make_shared<JacobianFactor>(X(5), I_3x3, Z_3x1),
boost::make_shared<JacobianFactor>(X(5), I_3x3, Vector3::Ones())); boost::make_shared<JacobianFactor>(X(5), I_3x3, Vector3::Ones()));
hfg.add(CGMixtureFactor({X(5)}, {{C(2), 2}}, dt1)); hfg.add(GaussianMixtureFactor({X(5)}, {{C(2), 2}}, dt1));
} }
auto ordering_full = auto ordering_full =