Address comments

release/4.3a0
Fan Jiang 2022-03-23 20:36:18 -04:00
parent 1e8aae3f06
commit b4f8eea231
10 changed files with 53 additions and 38 deletions

View File

@ -10,7 +10,7 @@
* -------------------------------------------------------------------------- */
/**
* @file CGMixtureFactor.cpp
* @file GaussianMixtureFactor.cpp
* @brief A set of Gaussian factors indexed by a set of discrete keys.
* @author Fan Jiang
* @author Varun Agrawal
@ -18,7 +18,7 @@
* @date Mar 12, 2022
*/
#include <gtsam/hybrid/CGMixtureFactor.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DecisionTree-inl.h>
@ -27,15 +27,15 @@
namespace gtsam {
CGMixtureFactor::CGMixtureFactor(const KeyVector &continuousKeys,
GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
const Factors &factors) : Base(continuousKeys, discreteKeys),
factors_(factors) {}
bool CGMixtureFactor::equals(const HybridFactor &lf, double tol) const {
bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
return false;
}
void CGMixtureFactor::print(const std::string &s, const KeyFormatter &formatter) const {
void GaussianMixtureFactor::print(const std::string &s, const KeyFormatter &formatter) const {
HybridFactor::print(s, formatter);
factors_.print(
"mixture = ",
@ -49,12 +49,12 @@ void CGMixtureFactor::print(const std::string &s, const KeyFormatter &formatter)
});
}
const CGMixtureFactor::Factors& CGMixtureFactor::factors() {
const GaussianMixtureFactor::Factors& GaussianMixtureFactor::factors() {
return factors_;
}
/* *******************************************************************************/
CGMixtureFactor::Sum CGMixtureFactor::addTo(const CGMixtureFactor::Sum &sum) const {
GaussianMixtureFactor::Sum GaussianMixtureFactor::addTo(const GaussianMixtureFactor::Sum &sum) const {
using Y = GaussianFactorGraph;
auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1;
@ -66,7 +66,7 @@ CGMixtureFactor::Sum CGMixtureFactor::addTo(const CGMixtureFactor::Sum &sum) con
}
/* *******************************************************************************/
CGMixtureFactor::Sum CGMixtureFactor::wrappedFactors() const {
GaussianMixtureFactor::Sum GaussianMixtureFactor::wrappedFactors() const {
auto wrap = [](const GaussianFactor::shared_ptr &factor) {
GaussianFactorGraph result;
result.push_back(factor);

View File

@ -10,7 +10,7 @@
* -------------------------------------------------------------------------- */
/**
* @file CGMixtureFactor.h
* @file GaussianMixtureFactor.h
* @brief A set of Gaussian factors indexed by a set of discrete keys.
* @author Fan Jiang
* @author Varun Agrawal
@ -29,19 +29,19 @@ namespace gtsam {
class GaussianFactorGraph;
class CGMixtureFactor : public HybridFactor {
class GaussianMixtureFactor : public HybridFactor {
public:
using Base = HybridFactor;
using This = CGMixtureFactor;
using This = GaussianMixtureFactor;
using shared_ptr = boost::shared_ptr<This>;
using Factors = DecisionTree<Key, GaussianFactor::shared_ptr>;
Factors factors_;
CGMixtureFactor() = default;
GaussianMixtureFactor() = default;
CGMixtureFactor(const KeyVector &continuousKeys,
GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys, const Factors &factors);
using Sum = DecisionTree<Key, GaussianFactorGraph>;

View File

@ -25,8 +25,8 @@
namespace gtsam {
/**
* A hybrid Bayes net can have discrete conditionals, Gaussian mixtures,
* or pure Gaussian conditionals.
* A hybrid Bayes net is a collection of HybridConditionals, which can have
* discrete conditionals, Gaussian mixtures, or pure Gaussian conditionals.
*/
class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
public:

View File

@ -33,7 +33,9 @@ class HybridConditional;
class VectorValues;
/* ************************************************************************* */
/** A clique in a HybridBayesTree */
/** A clique in a HybridBayesTree
* which is a HybridConditional internally.
*/
class GTSAM_EXPORT HybridBayesTreeClique
: public BayesTreeCliqueBase<HybridBayesTreeClique, HybridFactorGraph> {
public:

View File

@ -23,6 +23,10 @@
namespace gtsam {
/**
* A HybridDiscreteFactor is a wrapper for DiscreteFactor, so we hide the
* implementation of DiscreteFactor, and thus avoiding diamond inheritance.
*/
class HybridDiscreteFactor : public HybridFactor {
public:
using Base = HybridFactor;

View File

@ -34,6 +34,11 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
/**
* Base class for hybrid probabilistic factors
* Examples:
* - HybridGaussianFactor
* - HybridDiscreteFactor
* - GaussianMixtureFactor
* - GaussianMixture
*/
class GTSAM_EXPORT HybridFactor : public Factor {
public:

View File

@ -21,7 +21,7 @@
#include <gtsam/base/utilities.h>
#include <gtsam/discrete/Assignment.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/hybrid/CGMixtureFactor.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridDiscreteFactor.h>
@ -56,14 +56,14 @@ static std::string GREEN = "\033[0;32m";
static std::string GREEN_BOLD = "\033[1;32m";
static std::string RESET = "\033[0m";
static CGMixtureFactor::Sum &addGaussian(
CGMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) {
static GaussianMixtureFactor::Sum &addGaussian(
GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) {
using Y = GaussianFactorGraph;
// If the decision tree is not intiialized, then intialize it.
if (sum.empty()) {
GaussianFactorGraph result;
result.push_back(factor);
sum = CGMixtureFactor::Sum(result);
sum = GaussianMixtureFactor::Sum(result);
} else {
auto add = [&factor](const Y &graph) {
@ -307,7 +307,7 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
// continue;
// }
// auto ptr_mf = boost::dynamic_pointer_cast<CGMixtureFactor>(factor);
// auto ptr_mf = boost::dynamic_pointer_cast<GaussianMixtureFactor>(factor);
// if (ptr_mf) gf.push_back(ptr_mf->factors_(new_assignment));
// auto ptr_gm = boost::dynamic_pointer_cast<GaussianMixture>(factor);
@ -329,13 +329,13 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
std::cout << RED_BOLD << "HYBRID ELIM." << RESET << "\n";
CGMixtureFactor::Sum sum;
GaussianMixtureFactor::Sum sum;
std::vector<GaussianFactor::shared_ptr> deferredFactors;
for (auto &f : factors) {
if (f->isHybrid_) {
auto cgmf = boost::dynamic_pointer_cast<CGMixtureFactor>(f);
auto cgmf = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f);
if (cgmf) {
sum = cgmf->addTo(sum);
}
@ -395,7 +395,7 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
auto pair = unzip(eliminationResults);
const GaussianMixture::Conditionals &conditionals = pair.first;
const CGMixtureFactor::Factors &separatorFactors = pair.second;
const GaussianMixtureFactor::Factors &separatorFactors = pair.second;
// Create the GaussianMixture from the conditionals
auto conditional = boost::make_shared<GaussianMixture>(
@ -429,7 +429,7 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
} else {
// Create a resulting DCGaussianMixture on the separator.
auto factor = boost::make_shared<CGMixtureFactor>(
auto factor = boost::make_shared<GaussianMixtureFactor>(
frontalKeys, discreteSeparator, separatorFactors);
return {boost::make_shared<HybridConditional>(conditional), factor};
}

View File

@ -23,6 +23,10 @@
namespace gtsam {
/**
* A HybridGaussianFactor is a wrapper for GaussianFactor so that we do not have
* a diamond inheritance.
*/
class HybridGaussianFactor : public HybridFactor {
public:
using Base = HybridFactor;

View File

@ -17,7 +17,7 @@
#include <CppUnitLite/Test.h>
#include <CppUnitLite/TestHarness.h>
#include <gtsam/hybrid/CGMixtureFactor.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h>
@ -108,8 +108,8 @@ TEST(HybridFactorGraph, eliminateFullSequentialSimple) {
C(1), boost::make_shared<JacobianFactor>(X(1), I_3x3, Z_3x1),
boost::make_shared<JacobianFactor>(X(1), I_3x3, Vector3::Ones()));
hfg.add(CGMixtureFactor({X(1)}, {c1}, dt));
// hfg.add(CGMixtureFactor({X(0)}, {c1}, dt));
hfg.add(GaussianMixtureFactor({X(1)}, {c1}, dt));
// hfg.add(GaussianMixtureFactor({X(0)}, {c1}, dt));
hfg.add(HybridDiscreteFactor(DecisionTreeFactor(c1, {2, 8})));
hfg.add(HybridDiscreteFactor(
DecisionTreeFactor({{C(1), 2}, {C(2), 2}}, "1 2 3 4")));
@ -137,8 +137,8 @@ TEST(HybridFactorGraph, eliminateFullMultifrontalSimple) {
C(1), boost::make_shared<JacobianFactor>(X(1), I_3x3, Z_3x1),
boost::make_shared<JacobianFactor>(X(1), I_3x3, Vector3::Ones()));
hfg.add(CGMixtureFactor({X(1)}, {c1}, dt));
// hfg.add(CGMixtureFactor({X(0)}, {c1}, dt));
hfg.add(GaussianMixtureFactor({X(1)}, {c1}, dt));
// hfg.add(GaussianMixtureFactor({X(0)}, {c1}, dt));
hfg.add(HybridDiscreteFactor(DecisionTreeFactor(c1, {2, 8})));
hfg.add(HybridDiscreteFactor(
DecisionTreeFactor({{C(1), 2}, {C(2), 2}}, "1 2 3 4")));
@ -167,7 +167,7 @@ TEST_DISABLED(HybridFactorGraph, eliminateFullMultifrontalCLG) {
C(1), boost::make_shared<JacobianFactor>(X(1), I_3x3, Z_3x1),
boost::make_shared<JacobianFactor>(X(1), I_3x3, Vector3::Ones()));
hfg.add(CGMixtureFactor({X(1)}, {c}, dt));
hfg.add(GaussianMixtureFactor({X(1)}, {c}, dt));
hfg.add(HybridDiscreteFactor(DecisionTreeFactor(c, {2, 8})));
// hfg.add(HybridDiscreteFactor(DecisionTreeFactor({{C(1), 2}, {C(2), 2}}, "1
// 2 3 4")));
@ -203,13 +203,13 @@ TEST_DISABLED(HybridFactorGraph, eliminateFullMultifrontalTwoClique) {
C(0), boost::make_shared<JacobianFactor>(X(0), I_3x3, Z_3x1),
boost::make_shared<JacobianFactor>(X(0), I_3x3, Vector3::Ones()));
hfg.add(CGMixtureFactor({X(0)}, {{C(0), 2}}, dt));
hfg.add(GaussianMixtureFactor({X(0)}, {{C(0), 2}}, dt));
DecisionTree<Key, GaussianFactor::shared_ptr> dt1(
C(1), boost::make_shared<JacobianFactor>(X(2), I_3x3, Z_3x1),
boost::make_shared<JacobianFactor>(X(2), I_3x3, Vector3::Ones()));
hfg.add(CGMixtureFactor({X(2)}, {{C(1), 2}}, dt1));
hfg.add(GaussianMixtureFactor({X(2)}, {{C(1), 2}}, dt1));
}
// hfg.add(HybridDiscreteFactor(DecisionTreeFactor(c, {2, 8})));
@ -224,13 +224,13 @@ TEST_DISABLED(HybridFactorGraph, eliminateFullMultifrontalTwoClique) {
C(3), boost::make_shared<JacobianFactor>(X(3), I_3x3, Z_3x1),
boost::make_shared<JacobianFactor>(X(3), I_3x3, Vector3::Ones()));
hfg.add(CGMixtureFactor({X(3)}, {{C(3), 2}}, dt));
hfg.add(GaussianMixtureFactor({X(3)}, {{C(3), 2}}, dt));
DecisionTree<Key, GaussianFactor::shared_ptr> dt1(
C(2), boost::make_shared<JacobianFactor>(X(5), I_3x3, Z_3x1),
boost::make_shared<JacobianFactor>(X(5), I_3x3, Vector3::Ones()));
hfg.add(CGMixtureFactor({X(5)}, {{C(2), 2}}, dt1));
hfg.add(GaussianMixtureFactor({X(5)}, {{C(2), 2}}, dt1));
}
auto ordering_full =