Merge pull request #1284 from borglab/hybrid/misc
commit
7c84020bbc
|
|
@ -69,6 +69,16 @@ namespace gtsam {
|
||||||
push_back(key);
|
push_back(key);
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Print the keys and cardinalities.
|
||||||
|
void print(const std::string& s = "",
|
||||||
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
|
||||||
|
for (auto&& dkey : *this) {
|
||||||
|
std::cout << DefaultKeyFormatter(dkey.first) << " " << dkey.second
|
||||||
|
<< std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}; // DiscreteKeys
|
}; // DiscreteKeys
|
||||||
|
|
||||||
/// Create a list from two keys
|
/// Create a list from two keys
|
||||||
|
|
|
||||||
|
|
@ -402,14 +402,35 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
const Ordering HybridGaussianFactorGraph::getHybridOrdering(
|
const KeySet HybridGaussianFactorGraph::getDiscreteKeys() const {
|
||||||
OptionalOrderingType orderingType) const {
|
|
||||||
KeySet discrete_keys;
|
KeySet discrete_keys;
|
||||||
for (auto &factor : factors_) {
|
for (auto &factor : factors_) {
|
||||||
for (const DiscreteKey &k : factor->discreteKeys()) {
|
for (const DiscreteKey &k : factor->discreteKeys()) {
|
||||||
discrete_keys.insert(k.first);
|
discrete_keys.insert(k.first);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return discrete_keys;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
const KeySet HybridGaussianFactorGraph::getContinuousKeys() const {
|
||||||
|
KeySet keys;
|
||||||
|
for (auto &factor : factors_) {
|
||||||
|
for (const Key &key : factor->continuousKeys()) {
|
||||||
|
keys.insert(key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return keys;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
|
||||||
|
KeySet discrete_keys = getDiscreteKeys();
|
||||||
|
for (auto &factor : factors_) {
|
||||||
|
for (const DiscreteKey &k : factor->discreteKeys()) {
|
||||||
|
discrete_keys.insert(k.first);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const VariableIndex index(factors_);
|
const VariableIndex index(factors_);
|
||||||
Ordering ordering = Ordering::ColamdConstrainedLast(
|
Ordering ordering = Ordering::ColamdConstrainedLast(
|
||||||
|
|
|
||||||
|
|
@ -161,14 +161,19 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get all the discrete keys in the factor graph.
|
||||||
|
const KeySet getDiscreteKeys() const;
|
||||||
|
|
||||||
|
/// Get all the continuous keys in the factor graph.
|
||||||
|
const KeySet getContinuousKeys() const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief
|
* @brief Return a Colamd constrained ordering where the discrete keys are
|
||||||
|
* eliminated after the continuous keys.
|
||||||
*
|
*
|
||||||
* @param orderingType
|
|
||||||
* @return const Ordering
|
* @return const Ordering
|
||||||
*/
|
*/
|
||||||
const Ordering getHybridOrdering(
|
const Ordering getHybridOrdering() const;
|
||||||
OptionalOrderingType orderingType = boost::none) const;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -31,8 +31,8 @@
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* HybridValues represents a collection of DiscreteValues and VectorValues. It
|
* HybridValues represents a collection of DiscreteValues and VectorValues.
|
||||||
* is typically used to store the variables of a HybridGaussianFactorGraph.
|
* It is typically used to store the variables of a HybridGaussianFactorGraph.
|
||||||
* Optimizing a HybridGaussianBayesNet returns this class.
|
* Optimizing a HybridGaussianBayesNet returns this class.
|
||||||
*/
|
*/
|
||||||
class GTSAM_EXPORT HybridValues {
|
class GTSAM_EXPORT HybridValues {
|
||||||
|
|
@ -47,10 +47,10 @@ class GTSAM_EXPORT HybridValues {
|
||||||
/// @name Standard Constructors
|
/// @name Standard Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
// Default constructor creates an empty HybridValues.
|
/// Default constructor creates an empty HybridValues.
|
||||||
HybridValues() = default;
|
HybridValues() = default;
|
||||||
|
|
||||||
// Construct from DiscreteValues and VectorValues.
|
/// Construct from DiscreteValues and VectorValues.
|
||||||
HybridValues(const DiscreteValues& dv, const VectorValues& cv)
|
HybridValues(const DiscreteValues& dv, const VectorValues& cv)
|
||||||
: discrete_(dv), continuous_(cv){};
|
: discrete_(dv), continuous_(cv){};
|
||||||
|
|
||||||
|
|
@ -58,7 +58,7 @@ class GTSAM_EXPORT HybridValues {
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
// print required by Testable for unit testing
|
/// print required by Testable for unit testing
|
||||||
void print(const std::string& s = "HybridValues",
|
void print(const std::string& s = "HybridValues",
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
|
||||||
std::cout << s << ": \n";
|
std::cout << s << ": \n";
|
||||||
|
|
@ -67,7 +67,7 @@ class GTSAM_EXPORT HybridValues {
|
||||||
keyFormatter); // print continuous components
|
keyFormatter); // print continuous components
|
||||||
};
|
};
|
||||||
|
|
||||||
// equals required by Testable for unit testing
|
/// equals required by Testable for unit testing
|
||||||
bool equals(const HybridValues& other, double tol = 1e-9) const {
|
bool equals(const HybridValues& other, double tol = 1e-9) const {
|
||||||
return discrete_.equals(other.discrete_, tol) &&
|
return discrete_.equals(other.discrete_, tol) &&
|
||||||
continuous_.equals(other.continuous_, tol);
|
continuous_.equals(other.continuous_, tol);
|
||||||
|
|
@ -83,13 +83,13 @@ class GTSAM_EXPORT HybridValues {
|
||||||
/// Return the delta update for the continuous vectors
|
/// Return the delta update for the continuous vectors
|
||||||
VectorValues continuous() const { return continuous_; }
|
VectorValues continuous() const { return continuous_; }
|
||||||
|
|
||||||
// Check whether a variable with key \c j exists in DiscreteValue.
|
/// Check whether a variable with key \c j exists in DiscreteValue.
|
||||||
bool existsDiscrete(Key j) { return (discrete_.find(j) != discrete_.end()); };
|
bool existsDiscrete(Key j) { return (discrete_.find(j) != discrete_.end()); };
|
||||||
|
|
||||||
// Check whether a variable with key \c j exists in VectorValue.
|
/// Check whether a variable with key \c j exists in VectorValue.
|
||||||
bool existsVector(Key j) { return continuous_.exists(j); };
|
bool existsVector(Key j) { return continuous_.exists(j); };
|
||||||
|
|
||||||
// Check whether a variable with key \c j exists.
|
/// Check whether a variable with key \c j exists.
|
||||||
bool exists(Key j) { return existsDiscrete(j) || existsVector(j); };
|
bool exists(Key j) { return existsDiscrete(j) || existsVector(j); };
|
||||||
|
|
||||||
/** Insert a discrete \c value with key \c j. Replaces the existing value if
|
/** Insert a discrete \c value with key \c j. Replaces the existing value if
|
||||||
|
|
|
||||||
|
|
@ -99,6 +99,8 @@ class HybridBayesTree {
|
||||||
bool empty() const;
|
bool empty() const;
|
||||||
const HybridBayesTreeClique* operator[](size_t j) const;
|
const HybridBayesTreeClique* operator[](size_t j) const;
|
||||||
|
|
||||||
|
gtsam::HybridValues optimize() const;
|
||||||
|
|
||||||
string dot(const gtsam::KeyFormatter& keyFormatter =
|
string dot(const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -184,8 +184,8 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalSimple) {
|
||||||
hfg.add(DecisionTreeFactor(m1, {2, 8}));
|
hfg.add(DecisionTreeFactor(m1, {2, 8}));
|
||||||
hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4"));
|
hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4"));
|
||||||
|
|
||||||
HybridBayesTree::shared_ptr result = hfg.eliminateMultifrontal(
|
HybridBayesTree::shared_ptr result =
|
||||||
Ordering::ColamdConstrainedLast(hfg, {M(1), M(2)}));
|
hfg.eliminateMultifrontal(hfg.getHybridOrdering());
|
||||||
|
|
||||||
// The bayes tree should have 3 cliques
|
// The bayes tree should have 3 cliques
|
||||||
EXPECT_LONGS_EQUAL(3, result->size());
|
EXPECT_LONGS_EQUAL(3, result->size());
|
||||||
|
|
@ -215,7 +215,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalCLG) {
|
||||||
hfg.add(HybridDiscreteFactor(DecisionTreeFactor(m, {2, 8})));
|
hfg.add(HybridDiscreteFactor(DecisionTreeFactor(m, {2, 8})));
|
||||||
|
|
||||||
// Get a constrained ordering keeping c1 last
|
// Get a constrained ordering keeping c1 last
|
||||||
auto ordering_full = Ordering::ColamdConstrainedLast(hfg, {M(1)});
|
auto ordering_full = hfg.getHybridOrdering();
|
||||||
|
|
||||||
// Returns a Hybrid Bayes Tree with distribution P(x0|x1)P(x1|c1)P(c1)
|
// Returns a Hybrid Bayes Tree with distribution P(x0|x1)P(x1|c1)P(c1)
|
||||||
HybridBayesTree::shared_ptr hbt = hfg.eliminateMultifrontal(ordering_full);
|
HybridBayesTree::shared_ptr hbt = hfg.eliminateMultifrontal(ordering_full);
|
||||||
|
|
@ -484,8 +484,7 @@ TEST(HybridGaussianFactorGraph, SwitchingTwoVar) {
|
||||||
}
|
}
|
||||||
HybridBayesNet::shared_ptr hbn;
|
HybridBayesNet::shared_ptr hbn;
|
||||||
HybridGaussianFactorGraph::shared_ptr remaining;
|
HybridGaussianFactorGraph::shared_ptr remaining;
|
||||||
std::tie(hbn, remaining) =
|
std::tie(hbn, remaining) = hfg->eliminatePartialSequential(ordering_partial);
|
||||||
hfg->eliminatePartialSequential(ordering_partial);
|
|
||||||
|
|
||||||
EXPECT_LONGS_EQUAL(14, hbn->size());
|
EXPECT_LONGS_EQUAL(14, hbn->size());
|
||||||
EXPECT_LONGS_EQUAL(11, remaining->size());
|
EXPECT_LONGS_EQUAL(11, remaining->size());
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue