Merge pull request #1556 from borglab/hybrid-tablefactor

release/4.3a0
Varun Agrawal 2023-07-17 11:54:20 -04:00 committed by GitHub
commit 016f77ba55
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 179 additions and 155 deletions

View File

@ -33,16 +33,13 @@ namespace gtsam {
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const ADT& potentials) const ADT& potentials)
: DiscreteFactor(keys.indices()), : DiscreteFactor(keys.indices(), keys.cardinalities()), ADT(potentials) {}
ADT(potentials),
cardinalities_(keys.cardinalities()) {}
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c)
: DiscreteFactor(c.keys()), : DiscreteFactor(c.keys(), c.cardinalities()),
AlgebraicDecisionTree<Key>(c), AlgebraicDecisionTree<Key>(c) {}
cardinalities_(c.cardinalities_) {}
/* ************************************************************************ */ /* ************************************************************************ */
bool DecisionTreeFactor::equals(const DiscreteFactor& other, bool DecisionTreeFactor::equals(const DiscreteFactor& other,
@ -182,15 +179,12 @@ namespace gtsam {
} }
/* ************************************************************************ */ /* ************************************************************************ */
DiscreteKeys DecisionTreeFactor::discreteKeys() const { std::vector<double> DecisionTreeFactor::probabilities() const {
DiscreteKeys result; std::vector<double> probs;
for (auto&& key : keys()) { for (auto&& [key, value] : enumerate()) {
DiscreteKey dkey(key, cardinality(key)); probs.push_back(value);
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
result.push_back(dkey);
}
} }
return result; return probs;
} }
/* ************************************************************************ */ /* ************************************************************************ */
@ -288,17 +282,15 @@ namespace gtsam {
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const vector<double>& table) const vector<double>& table)
: DiscreteFactor(keys.indices()), : DiscreteFactor(keys.indices(), keys.cardinalities()),
AlgebraicDecisionTree<Key>(keys, table), AlgebraicDecisionTree<Key>(keys, table) {}
cardinalities_(keys.cardinalities()) {}
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const string& table) const string& table)
: DiscreteFactor(keys.indices()), : DiscreteFactor(keys.indices(), keys.cardinalities()),
AlgebraicDecisionTree<Key>(keys, table), AlgebraicDecisionTree<Key>(keys, table) {}
cardinalities_(keys.cardinalities()) {}
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const { DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const {
@ -306,11 +298,10 @@ namespace gtsam {
// Get the probabilities in the decision tree so we can threshold. // Get the probabilities in the decision tree so we can threshold.
std::vector<double> probabilities; std::vector<double> probabilities;
this->visitLeaf([&](const Leaf& leaf) { // NOTE(Varun) this is potentially slow due to the cartesian product
size_t nrAssignments = leaf.nrAssignments(); for (auto&& [assignment, prob] : this->enumerate()) {
double prob = leaf.constant(); probabilities.push_back(prob);
probabilities.insert(probabilities.end(), nrAssignments, prob); }
});
// The number of probabilities can be lower than max_leaves // The number of probabilities can be lower than max_leaves
if (probabilities.size() <= N) { if (probabilities.size() <= N) {

View File

@ -50,10 +50,6 @@ namespace gtsam {
typedef std::shared_ptr<DecisionTreeFactor> shared_ptr; typedef std::shared_ptr<DecisionTreeFactor> shared_ptr;
typedef AlgebraicDecisionTree<Key> ADT; typedef AlgebraicDecisionTree<Key> ADT;
protected:
std::map<Key, size_t> cardinalities_;
public:
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{
@ -119,8 +115,6 @@ namespace gtsam {
static double safe_div(const double& a, const double& b); static double safe_div(const double& a, const double& b);
size_t cardinality(Key j) const { return cardinalities_.at(j); }
/// divide by factor f (safely) /// divide by factor f (safely)
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const { DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
return apply(f, safe_div); return apply(f, safe_div);
@ -179,8 +173,8 @@ namespace gtsam {
/// Enumerate all values into a map from values to double. /// Enumerate all values into a map from values to double.
std::vector<std::pair<DiscreteValues, double>> enumerate() const; std::vector<std::pair<DiscreteValues, double>> enumerate() const;
/// Return all the discrete keys associated with this factor. /// Get all the probabilities in order of assignment values
DiscreteKeys discreteKeys() const; std::vector<double> probabilities() const;
/** /**
* @brief Prune the decision tree of discrete variables. * @brief Prune the decision tree of discrete variables.
@ -260,7 +254,6 @@ namespace gtsam {
void serialize(ARCHIVE& ar, const unsigned int /*version*/) { void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(ADT); ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(ADT);
ar& BOOST_SERIALIZATION_NVP(cardinalities_);
} }
#endif #endif
}; };

View File

@ -28,6 +28,18 @@ using namespace std;
namespace gtsam { namespace gtsam {
/* ************************************************************************ */
DiscreteKeys DiscreteFactor::discreteKeys() const {
DiscreteKeys result;
for (auto&& key : keys()) {
DiscreteKey dkey(key, cardinality(key));
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
result.push_back(dkey);
}
}
return result;
}
/* ************************************************************************* */ /* ************************************************************************* */
double DiscreteFactor::error(const DiscreteValues& values) const { double DiscreteFactor::error(const DiscreteValues& values) const {
return -std::log((*this)(values)); return -std::log((*this)(values));

View File

@ -36,28 +36,35 @@ class HybridValues;
* @ingroup discrete * @ingroup discrete
*/ */
class GTSAM_EXPORT DiscreteFactor: public Factor { class GTSAM_EXPORT DiscreteFactor: public Factor {
public:
public:
// typedefs needed to play nice with gtsam // typedefs needed to play nice with gtsam
typedef DiscreteFactor This; ///< This class typedef DiscreteFactor This; ///< This class
typedef std::shared_ptr<DiscreteFactor> shared_ptr; ///< shared_ptr to this class typedef std::shared_ptr<DiscreteFactor>
typedef Factor Base; ///< Our base class shared_ptr; ///< shared_ptr to this class
typedef Factor Base; ///< Our base class
using Values = DiscreteValues; ///< backwards compatibility using Values = DiscreteValues; ///< backwards compatibility
public: protected:
/// Map of Keys and their cardinalities.
std::map<Key, size_t> cardinalities_;
public:
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{
/** Default constructor creates empty factor */ /** Default constructor creates empty factor */
DiscreteFactor() {} DiscreteFactor() {}
/** Construct from container of keys. This constructor is used internally from derived factor /**
* constructors, either from a container of keys or from a boost::assign::list_of. */ * Construct from container of keys and map of cardinalities.
template<typename CONTAINER> * This constructor is used internally from derived factor constructors,
DiscreteFactor(const CONTAINER& keys) : Base(keys) {} * either from a container of keys or from a boost::assign::list_of.
*/
template <typename CONTAINER>
DiscreteFactor(const CONTAINER& keys,
const std::map<Key, size_t> cardinalities = {})
: Base(keys), cardinalities_(cardinalities) {}
/// @} /// @}
/// @name Testable /// @name Testable
@ -77,6 +84,13 @@ public:
/// @name Standard Interface /// @name Standard Interface
/// @{ /// @{
/// Return all the discrete keys associated with this factor.
DiscreteKeys discreteKeys() const;
std::map<Key, size_t> cardinalities() const { return cardinalities_; }
size_t cardinality(Key j) const { return cardinalities_.at(j); }
/// Find value for given assignment of values to variables /// Find value for given assignment of values to variables
virtual double operator()(const DiscreteValues&) const = 0; virtual double operator()(const DiscreteValues&) const = 0;
@ -124,6 +138,17 @@ public:
const Names& names = {}) const = 0; const Names& names = {}) const = 0;
/// @} /// @}
private:
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_NVP(cardinalities_);
}
#endif
}; };
// DiscreteFactor // DiscreteFactor

View File

@ -13,11 +13,12 @@
* @file TableFactor.cpp * @file TableFactor.cpp
* @brief discrete factor * @brief discrete factor
* @date May 4, 2023 * @date May 4, 2023
* @author Yoonwoo Kim * @author Yoonwoo Kim, Varun Agrawal
*/ */
#include <gtsam/base/FastSet.h> #include <gtsam/base/FastSet.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/TableFactor.h> #include <gtsam/discrete/TableFactor.h>
#include <gtsam/hybrid/HybridValues.h> #include <gtsam/hybrid/HybridValues.h>
@ -33,8 +34,7 @@ TableFactor::TableFactor() {}
/* ************************************************************************ */ /* ************************************************************************ */
TableFactor::TableFactor(const DiscreteKeys& dkeys, TableFactor::TableFactor(const DiscreteKeys& dkeys,
const TableFactor& potentials) const TableFactor& potentials)
: DiscreteFactor(dkeys.indices()), : DiscreteFactor(dkeys.indices(), dkeys.cardinalities()) {
cardinalities_(potentials.cardinalities_) {
sparse_table_ = potentials.sparse_table_; sparse_table_ = potentials.sparse_table_;
denominators_ = potentials.denominators_; denominators_ = potentials.denominators_;
sorted_dkeys_ = discreteKeys(); sorted_dkeys_ = discreteKeys();
@ -44,11 +44,11 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
/* ************************************************************************ */ /* ************************************************************************ */
TableFactor::TableFactor(const DiscreteKeys& dkeys, TableFactor::TableFactor(const DiscreteKeys& dkeys,
const Eigen::SparseVector<double>& table) const Eigen::SparseVector<double>& table)
: DiscreteFactor(dkeys.indices()), sparse_table_(table.size()) { : DiscreteFactor(dkeys.indices(), dkeys.cardinalities()),
sparse_table_(table.size()) {
sparse_table_ = table; sparse_table_ = table;
double denom = table.size(); double denom = table.size();
for (const DiscreteKey& dkey : dkeys) { for (const DiscreteKey& dkey : dkeys) {
cardinalities_.insert(dkey);
denom /= dkey.second; denom /= dkey.second;
denominators_.insert(std::pair<Key, double>(dkey.first, denom)); denominators_.insert(std::pair<Key, double>(dkey.first, denom));
} }
@ -56,6 +56,10 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
} }
/* ************************************************************************ */
TableFactor::TableFactor(const DiscreteConditional& c)
: TableFactor(c.discreteKeys(), c.probabilities()) {}
/* ************************************************************************ */ /* ************************************************************************ */
Eigen::SparseVector<double> TableFactor::Convert( Eigen::SparseVector<double> TableFactor::Convert(
const std::vector<double>& table) { const std::vector<double>& table) {
@ -435,18 +439,6 @@ std::vector<std::pair<DiscreteValues, double>> TableFactor::enumerate() const {
return result; return result;
} }
/* ************************************************************************ */
DiscreteKeys TableFactor::discreteKeys() const {
DiscreteKeys result;
for (auto&& key : keys()) {
DiscreteKey dkey(key, cardinality(key));
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
result.push_back(dkey);
}
}
return result;
}
// Print out header. // Print out header.
/* ************************************************************************ */ /* ************************************************************************ */
string TableFactor::markdown(const KeyFormatter& keyFormatter, string TableFactor::markdown(const KeyFormatter& keyFormatter,

View File

@ -12,7 +12,7 @@
/** /**
* @file TableFactor.h * @file TableFactor.h
* @date May 4, 2023 * @date May 4, 2023
* @author Yoonwoo Kim * @author Yoonwoo Kim, Varun Agrawal
*/ */
#pragma once #pragma once
@ -32,6 +32,7 @@
namespace gtsam { namespace gtsam {
class DiscreteConditional;
class HybridValues; class HybridValues;
/** /**
@ -44,8 +45,6 @@ class HybridValues;
*/ */
class GTSAM_EXPORT TableFactor : public DiscreteFactor { class GTSAM_EXPORT TableFactor : public DiscreteFactor {
protected: protected:
/// Map of Keys and their cardinalities.
std::map<Key, size_t> cardinalities_;
/// SparseVector of nonzero probabilities. /// SparseVector of nonzero probabilities.
Eigen::SparseVector<double> sparse_table_; Eigen::SparseVector<double> sparse_table_;
@ -57,10 +56,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
/** /**
* @brief Uses lazy cartesian product to find nth entry in the cartesian * @brief Uses lazy cartesian product to find nth entry in the cartesian
* product of arrays in O(1) * product of arrays in O(1)
* Example) * Example)
* v0 | v1 | val * v0 | v1 | val
* 0 | 0 | 10 * 0 | 0 | 10
* 0 | 1 | 21 * 0 | 1 | 21
* 1 | 0 | 32 * 1 | 0 | 32
* 1 | 1 | 43 * 1 | 1 | 43
@ -75,13 +74,13 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
* @brief Return ith key in keys_ as a DiscreteKey * @brief Return ith key in keys_ as a DiscreteKey
* @param i ith key in keys_ * @param i ith key in keys_
* @return DiscreteKey * @return DiscreteKey
* */ */
DiscreteKey discreteKey(size_t i) const { DiscreteKey discreteKey(size_t i) const {
return DiscreteKey(keys_[i], cardinalities_.at(keys_[i])); return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
} }
/// Convert probability table given as doubles to SparseVector. /// Convert probability table given as doubles to SparseVector.
/// Example) {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5} /// Example) {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5}
static Eigen::SparseVector<double> Convert(const std::vector<double>& table); static Eigen::SparseVector<double> Convert(const std::vector<double>& table);
/// Convert probability table given as string to SparseVector. /// Convert probability table given as string to SparseVector.
@ -142,6 +141,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
TableFactor(const DiscreteKey& key, const std::vector<double>& row) TableFactor(const DiscreteKey& key, const std::vector<double>& row)
: TableFactor(DiscreteKeys{key}, row) {} : TableFactor(DiscreteKeys{key}, row) {}
/** Construct from a DiscreteConditional type */
explicit TableFactor(const DiscreteConditional& c);
/// @} /// @}
/// @name Testable /// @name Testable
/// @{ /// @{
@ -180,8 +182,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
static double safe_div(const double& a, const double& b); static double safe_div(const double& a, const double& b);
size_t cardinality(Key j) const { return cardinalities_.at(j); }
/// divide by factor f (safely) /// divide by factor f (safely)
TableFactor operator/(const TableFactor& f) const { TableFactor operator/(const TableFactor& f) const {
return apply(f, safe_div); return apply(f, safe_div);
@ -274,9 +274,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
/// Enumerate all values into a map from values to double. /// Enumerate all values into a map from values to double.
std::vector<std::pair<DiscreteValues, double>> enumerate() const; std::vector<std::pair<DiscreteValues, double>> enumerate() const;
/// Return all the discrete keys associated with this factor.
DiscreteKeys discreteKeys() const;
/** /**
* @brief Prune the decision tree of discrete variables. * @brief Prune the decision tree of discrete variables.
* *

View File

@ -51,6 +51,11 @@ TEST( DecisionTreeFactor, constructors)
// Assert that error = -log(value) // Assert that error = -log(value)
EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9); EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);
// Construct from DiscreteConditional
DiscreteConditional conditional(X | Y = "1/1 2/3 1/4");
DecisionTreeFactor f4(conditional);
EXPECT_DOUBLES_EQUAL(0.8, f4(values), 1e-9);
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -93,7 +93,8 @@ void printTime(map<double, pair<chrono::microseconds, chrono::microseconds>>
for (auto&& kv : measured_time) { for (auto&& kv : measured_time) {
cout << "dropout: " << kv.first cout << "dropout: " << kv.first
<< " | TableFactor time: " << kv.second.first.count() << " | TableFactor time: " << kv.second.first.count()
<< " | DecisionTreeFactor time: " << kv.second.second.count() << endl; << " | DecisionTreeFactor time: " << kv.second.second.count() <<
endl;
} }
} }
@ -124,6 +125,13 @@ TEST(TableFactor, constructors) {
// Assert that error = -log(value) // Assert that error = -log(value)
EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9); EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);
// Construct from DiscreteConditional
DiscreteConditional conditional(X | Y = "1/1 2/3 1/4");
TableFactor f4(conditional);
// Manually constructed via inspection and comparison to DecisionTreeFactor
TableFactor expected(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
EXPECT(assert_equal(expected, f4));
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -156,7 +164,8 @@ TEST(TableFactor, multiplication) {
/* ************************************************************************* */ /* ************************************************************************* */
// Benchmark which compares runtime of multiplication of two TableFactors // Benchmark which compares runtime of multiplication of two TableFactors
// and two DecisionTreeFactors given sparsity from dense to 90% sparsity. // and two DecisionTreeFactors given sparsity from dense to 90% sparsity.
TEST(TableFactor, benchmark) { // NOTE: Enable to run.
TEST_DISABLED(TableFactor, benchmark) {
DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), F(5, 2), G(6, 3), DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), F(5, 2), G(6, 3),
H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3); H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3);

View File

@ -228,19 +228,19 @@ std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
/** /**
* @brief Helper function to get the pruner functional. * @brief Helper function to get the pruner functional.
* *
* @param decisionTree The probability decision tree of only discrete keys. * @param discreteProbs The probabilities of only discrete keys.
* @return std::function<GaussianConditional::shared_ptr( * @return std::function<GaussianConditional::shared_ptr(
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)> * const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
*/ */
std::function<GaussianConditional::shared_ptr( std::function<GaussianConditional::shared_ptr(
const Assignment<Key> &, const GaussianConditional::shared_ptr &)> const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) { GaussianMixture::prunerFunc(const DecisionTreeFactor &discreteProbs) {
// Get the discrete keys as sets for the decision tree // Get the discrete keys as sets for the decision tree
// and the gaussian mixture. // and the gaussian mixture.
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys()); auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
auto gaussianMixtureKeySet = DiscreteKeysAsSet(this->discreteKeys()); auto gaussianMixtureKeySet = DiscreteKeysAsSet(this->discreteKeys());
auto pruner = [decisionTree, decisionTreeKeySet, gaussianMixtureKeySet]( auto pruner = [discreteProbs, discreteProbsKeySet, gaussianMixtureKeySet](
const Assignment<Key> &choices, const Assignment<Key> &choices,
const GaussianConditional::shared_ptr &conditional) const GaussianConditional::shared_ptr &conditional)
-> GaussianConditional::shared_ptr { -> GaussianConditional::shared_ptr {
@ -249,8 +249,8 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
// Case where the gaussian mixture has the same // Case where the gaussian mixture has the same
// discrete keys as the decision tree. // discrete keys as the decision tree.
if (gaussianMixtureKeySet == decisionTreeKeySet) { if (gaussianMixtureKeySet == discreteProbsKeySet) {
if (decisionTree(values) == 0.0) { if (discreteProbs(values) == 0.0) {
// empty aka null pointer // empty aka null pointer
std::shared_ptr<GaussianConditional> null; std::shared_ptr<GaussianConditional> null;
return null; return null;
@ -259,10 +259,10 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
} }
} else { } else {
std::vector<DiscreteKey> set_diff; std::vector<DiscreteKey> set_diff;
std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(), std::set_difference(
gaussianMixtureKeySet.begin(), discreteProbsKeySet.begin(), discreteProbsKeySet.end(),
gaussianMixtureKeySet.end(), gaussianMixtureKeySet.begin(), gaussianMixtureKeySet.end(),
std::back_inserter(set_diff)); std::back_inserter(set_diff));
const std::vector<DiscreteValues> assignments = const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(set_diff); DiscreteValues::CartesianProduct(set_diff);
@ -272,7 +272,7 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
// If any one of the sub-branches are non-zero, // If any one of the sub-branches are non-zero,
// we need this conditional. // we need this conditional.
if (decisionTree(augmented_values) > 0.0) { if (discreteProbs(augmented_values) > 0.0) {
return conditional; return conditional;
} }
} }
@ -285,12 +285,12 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
} }
/* *******************************************************************************/ /* *******************************************************************************/
void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) { void GaussianMixture::prune(const DecisionTreeFactor &discreteProbs) {
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys()); auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys()); auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys());
// Functional which loops over all assignments and create a set of // Functional which loops over all assignments and create a set of
// GaussianConditionals // GaussianConditionals
auto pruner = prunerFunc(decisionTree); auto pruner = prunerFunc(discreteProbs);
auto pruned_conditionals = conditionals_.apply(pruner); auto pruned_conditionals = conditionals_.apply(pruner);
conditionals_.root_ = pruned_conditionals.root_; conditionals_.root_ = pruned_conditionals.root_;

View File

@ -74,13 +74,13 @@ class GTSAM_EXPORT GaussianMixture
/** /**
* @brief Helper function to get the pruner functor. * @brief Helper function to get the pruner functor.
* *
* @param decisionTree The pruned discrete probability decision tree. * @param discreteProbs The pruned discrete probabilities.
* @return std::function<GaussianConditional::shared_ptr( * @return std::function<GaussianConditional::shared_ptr(
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)> * const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
*/ */
std::function<GaussianConditional::shared_ptr( std::function<GaussianConditional::shared_ptr(
const Assignment<Key> &, const GaussianConditional::shared_ptr &)> const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
prunerFunc(const DecisionTreeFactor &decisionTree); prunerFunc(const DecisionTreeFactor &discreteProbs);
public: public:
/// @name Constructors /// @name Constructors
@ -234,12 +234,11 @@ class GTSAM_EXPORT GaussianMixture
/** /**
* @brief Prune the decision tree of Gaussian factors as per the discrete * @brief Prune the decision tree of Gaussian factors as per the discrete
* `decisionTree`. * `discreteProbs`.
* *
* @param decisionTree A pruned decision tree of discrete keys where the * @param discreteProbs A pruned set of probabilities for the discrete keys.
* leaves are probabilities.
*/ */
void prune(const DecisionTreeFactor &decisionTree); void prune(const DecisionTreeFactor &discreteProbs);
/** /**
* @brief Merge the Gaussian Factor Graphs in `this` and `sum` while * @brief Merge the Gaussian Factor Graphs in `this` and `sum` while

View File

@ -39,41 +39,41 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
/* ************************************************************************* */ /* ************************************************************************* */
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const { DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
AlgebraicDecisionTree<Key> decisionTree; AlgebraicDecisionTree<Key> discreteProbs;
// The canonical decision tree factor which will get // The canonical decision tree factor which will get
// the discrete conditionals added to it. // the discrete conditionals added to it.
DecisionTreeFactor dtFactor; DecisionTreeFactor discreteProbsFactor;
for (auto &&conditional : *this) { for (auto &&conditional : *this) {
if (conditional->isDiscrete()) { if (conditional->isDiscrete()) {
// Convert to a DecisionTreeFactor and add it to the main factor. // Convert to a DecisionTreeFactor and add it to the main factor.
DecisionTreeFactor f(*conditional->asDiscrete()); DecisionTreeFactor f(*conditional->asDiscrete());
dtFactor = dtFactor * f; discreteProbsFactor = discreteProbsFactor * f;
} }
} }
return std::make_shared<DecisionTreeFactor>(dtFactor); return std::make_shared<DecisionTreeFactor>(discreteProbsFactor);
} }
/* ************************************************************************* */ /* ************************************************************************* */
/** /**
* @brief Helper function to get the pruner functional. * @brief Helper function to get the pruner functional.
* *
* @param prunedDecisionTree The prob. decision tree of only discrete keys. * @param prunedDiscreteProbs The prob. decision tree of only discrete keys.
* @param conditional Conditional to prune. Used to get full assignment. * @param conditional Conditional to prune. Used to get full assignment.
* @return std::function<double(const Assignment<Key> &, double)> * @return std::function<double(const Assignment<Key> &, double)>
*/ */
std::function<double(const Assignment<Key> &, double)> prunerFunc( std::function<double(const Assignment<Key> &, double)> prunerFunc(
const DecisionTreeFactor &prunedDecisionTree, const DecisionTreeFactor &prunedDiscreteProbs,
const HybridConditional &conditional) { const HybridConditional &conditional) {
// Get the discrete keys as sets for the decision tree // Get the discrete keys as sets for the decision tree
// and the Gaussian mixture. // and the Gaussian mixture.
std::set<DiscreteKey> decisionTreeKeySet = std::set<DiscreteKey> discreteProbsKeySet =
DiscreteKeysAsSet(prunedDecisionTree.discreteKeys()); DiscreteKeysAsSet(prunedDiscreteProbs.discreteKeys());
std::set<DiscreteKey> conditionalKeySet = std::set<DiscreteKey> conditionalKeySet =
DiscreteKeysAsSet(conditional.discreteKeys()); DiscreteKeysAsSet(conditional.discreteKeys());
auto pruner = [prunedDecisionTree, decisionTreeKeySet, conditionalKeySet]( auto pruner = [prunedDiscreteProbs, discreteProbsKeySet, conditionalKeySet](
const Assignment<Key> &choices, const Assignment<Key> &choices,
double probability) -> double { double probability) -> double {
// This corresponds to 0 probability // This corresponds to 0 probability
@ -83,8 +83,8 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
DiscreteValues values(choices); DiscreteValues values(choices);
// Case where the Gaussian mixture has the same // Case where the Gaussian mixture has the same
// discrete keys as the decision tree. // discrete keys as the decision tree.
if (conditionalKeySet == decisionTreeKeySet) { if (conditionalKeySet == discreteProbsKeySet) {
if (prunedDecisionTree(values) == 0) { if (prunedDiscreteProbs(values) == 0) {
return pruned_prob; return pruned_prob;
} else { } else {
return probability; return probability;
@ -114,11 +114,12 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
} }
// Now we generate the full assignment by enumerating // Now we generate the full assignment by enumerating
// over all keys in the prunedDecisionTree. // over all keys in the prunedDiscreteProbs.
// First we find the differing keys // First we find the differing keys
std::vector<DiscreteKey> set_diff; std::vector<DiscreteKey> set_diff;
std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(), std::set_difference(discreteProbsKeySet.begin(),
conditionalKeySet.begin(), conditionalKeySet.end(), discreteProbsKeySet.end(), conditionalKeySet.begin(),
conditionalKeySet.end(),
std::back_inserter(set_diff)); std::back_inserter(set_diff));
// Now enumerate over all assignments of the differing keys // Now enumerate over all assignments of the differing keys
@ -130,7 +131,7 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
// If any one of the sub-branches are non-zero, // If any one of the sub-branches are non-zero,
// we need this probability. // we need this probability.
if (prunedDecisionTree(augmented_values) > 0.0) { if (prunedDiscreteProbs(augmented_values) > 0.0) {
return probability; return probability;
} }
} }
@ -144,8 +145,8 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
/* ************************************************************************* */ /* ************************************************************************* */
void HybridBayesNet::updateDiscreteConditionals( void HybridBayesNet::updateDiscreteConditionals(
const DecisionTreeFactor &prunedDecisionTree) { const DecisionTreeFactor &prunedDiscreteProbs) {
KeyVector prunedTreeKeys = prunedDecisionTree.keys(); KeyVector prunedTreeKeys = prunedDiscreteProbs.keys();
// Loop with index since we need it later. // Loop with index since we need it later.
for (size_t i = 0; i < this->size(); i++) { for (size_t i = 0; i < this->size(); i++) {
@ -153,18 +154,21 @@ void HybridBayesNet::updateDiscreteConditionals(
if (conditional->isDiscrete()) { if (conditional->isDiscrete()) {
auto discrete = conditional->asDiscrete(); auto discrete = conditional->asDiscrete();
// Apply prunerFunc to the underlying AlgebraicDecisionTree // Convert pointer from conditional to factor
auto discreteTree = auto discreteTree =
std::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete); std::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete);
// Apply prunerFunc to the underlying AlgebraicDecisionTree
DecisionTreeFactor::ADT prunedDiscreteTree = DecisionTreeFactor::ADT prunedDiscreteTree =
discreteTree->apply(prunerFunc(prunedDecisionTree, *conditional)); discreteTree->apply(prunerFunc(prunedDiscreteProbs, *conditional));
gttic_(HybridBayesNet_MakeConditional);
// Create the new (hybrid) conditional // Create the new (hybrid) conditional
KeyVector frontals(discrete->frontals().begin(), KeyVector frontals(discrete->frontals().begin(),
discrete->frontals().end()); discrete->frontals().end());
auto prunedDiscrete = std::make_shared<DiscreteLookupTable>( auto prunedDiscrete = std::make_shared<DiscreteLookupTable>(
frontals.size(), conditional->discreteKeys(), prunedDiscreteTree); frontals.size(), conditional->discreteKeys(), prunedDiscreteTree);
conditional = std::make_shared<HybridConditional>(prunedDiscrete); conditional = std::make_shared<HybridConditional>(prunedDiscrete);
gttoc_(HybridBayesNet_MakeConditional);
// Add it back to the BayesNet // Add it back to the BayesNet
this->at(i) = conditional; this->at(i) = conditional;
@ -175,10 +179,16 @@ void HybridBayesNet::updateDiscreteConditionals(
/* ************************************************************************* */ /* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) { HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
// Get the decision tree of only the discrete keys // Get the decision tree of only the discrete keys
auto discreteConditionals = this->discreteConditionals(); gttic_(HybridBayesNet_PruneDiscreteConditionals);
const auto decisionTree = discreteConditionals->prune(maxNrLeaves); DecisionTreeFactor::shared_ptr discreteConditionals =
this->discreteConditionals();
const DecisionTreeFactor prunedDiscreteProbs =
discreteConditionals->prune(maxNrLeaves);
gttoc_(HybridBayesNet_PruneDiscreteConditionals);
this->updateDiscreteConditionals(decisionTree); gttic_(HybridBayesNet_UpdateDiscreteConditionals);
this->updateDiscreteConditionals(prunedDiscreteProbs);
gttoc_(HybridBayesNet_UpdateDiscreteConditionals);
/* To Prune, we visitWith every leaf in the GaussianMixture. /* To Prune, we visitWith every leaf in the GaussianMixture.
* For each leaf, using the assignment we can check the discrete decision tree * For each leaf, using the assignment we can check the discrete decision tree
@ -189,13 +199,14 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
HybridBayesNet prunedBayesNetFragment; HybridBayesNet prunedBayesNetFragment;
gttic_(HybridBayesNet_PruneMixtures);
// Go through all the conditionals in the // Go through all the conditionals in the
// Bayes Net and prune them as per decisionTree. // Bayes Net and prune them as per prunedDiscreteProbs.
for (auto &&conditional : *this) { for (auto &&conditional : *this) {
if (auto gm = conditional->asMixture()) { if (auto gm = conditional->asMixture()) {
// Make a copy of the Gaussian mixture and prune it! // Make a copy of the Gaussian mixture and prune it!
auto prunedGaussianMixture = std::make_shared<GaussianMixture>(*gm); auto prunedGaussianMixture = std::make_shared<GaussianMixture>(*gm);
prunedGaussianMixture->prune(decisionTree); // imperative :-( prunedGaussianMixture->prune(prunedDiscreteProbs); // imperative :-(
// Type-erase and add to the pruned Bayes Net fragment. // Type-erase and add to the pruned Bayes Net fragment.
prunedBayesNetFragment.push_back(prunedGaussianMixture); prunedBayesNetFragment.push_back(prunedGaussianMixture);
@ -205,6 +216,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
prunedBayesNetFragment.push_back(conditional); prunedBayesNetFragment.push_back(conditional);
} }
} }
gttoc_(HybridBayesNet_PruneMixtures);
return prunedBayesNetFragment; return prunedBayesNetFragment;
} }

View File

@ -224,9 +224,9 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/** /**
* @brief Update the discrete conditionals with the pruned versions. * @brief Update the discrete conditionals with the pruned versions.
* *
* @param prunedDecisionTree * @param prunedDiscreteProbs
*/ */
void updateDiscreteConditionals(const DecisionTreeFactor &prunedDecisionTree); void updateDiscreteConditionals(const DecisionTreeFactor &prunedDiscreteProbs);
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */ /** Serialization function */

View File

@ -173,19 +173,18 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
/* ************************************************************************* */ /* ************************************************************************* */
void HybridBayesTree::prune(const size_t maxNrLeaves) { void HybridBayesTree::prune(const size_t maxNrLeaves) {
auto decisionTree = auto discreteProbs = this->roots_.at(0)->conditional()->asDiscrete();
this->roots_.at(0)->conditional()->asDiscrete();
DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves); DecisionTreeFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves);
decisionTree->root_ = prunedDecisionTree.root_; discreteProbs->root_ = prunedDiscreteProbs.root_;
/// Helper struct for pruning the hybrid bayes tree. /// Helper struct for pruning the hybrid bayes tree.
struct HybridPrunerData { struct HybridPrunerData {
/// The discrete decision tree after pruning. /// The discrete decision tree after pruning.
DecisionTreeFactor prunedDecisionTree; DecisionTreeFactor prunedDiscreteProbs;
HybridPrunerData(const DecisionTreeFactor& prunedDecisionTree, HybridPrunerData(const DecisionTreeFactor& prunedDiscreteProbs,
const HybridBayesTree::sharedNode& parentClique) const HybridBayesTree::sharedNode& parentClique)
: prunedDecisionTree(prunedDecisionTree) {} : prunedDiscreteProbs(prunedDiscreteProbs) {}
/** /**
* @brief A function used during tree traversal that operates on each node * @brief A function used during tree traversal that operates on each node
@ -205,13 +204,13 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
if (conditional->isHybrid()) { if (conditional->isHybrid()) {
auto gaussianMixture = conditional->asMixture(); auto gaussianMixture = conditional->asMixture();
gaussianMixture->prune(parentData.prunedDecisionTree); gaussianMixture->prune(parentData.prunedDiscreteProbs);
} }
return parentData; return parentData;
} }
}; };
HybridPrunerData rootData(prunedDecisionTree, 0); HybridPrunerData rootData(prunedDiscreteProbs, 0);
{ {
treeTraversal::no_op visitorPost; treeTraversal::no_op visitorPost;
// Limits OpenMP threads since we're mixing TBB and OpenMP // Limits OpenMP threads since we're mixing TBB and OpenMP

View File

@ -98,7 +98,7 @@ static GaussianFactorGraphTree addGaussian(
// TODO(dellaert): it's probably more efficient to first collect the discrete // TODO(dellaert): it's probably more efficient to first collect the discrete
// keys, and then loop over all assignments to populate a vector. // keys, and then loop over all assignments to populate a vector.
GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const { GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
gttic(assembleGraphTree); gttic_(assembleGraphTree);
GaussianFactorGraphTree result; GaussianFactorGraphTree result;
@ -131,7 +131,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
} }
} }
gttoc(assembleGraphTree); gttoc_(assembleGraphTree);
return result; return result;
} }
@ -190,7 +190,8 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
/* ************************************************************************ */ /* ************************************************************************ */
// If any GaussianFactorGraph in the decision tree contains a nullptr, convert // If any GaussianFactorGraph in the decision tree contains a nullptr, convert
// that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will // that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will
// otherwise create a GFG with a single (null) factor, which doesn't register as null. // otherwise create a GFG with a single (null) factor,
// which doesn't register as null.
GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) { GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
auto emptyGaussian = [](const GaussianFactorGraph &graph) { auto emptyGaussian = [](const GaussianFactorGraph &graph) {
bool hasNull = bool hasNull =
@ -230,26 +231,14 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
return {nullptr, nullptr}; return {nullptr, nullptr};
} }
#ifdef HYBRID_TIMING
gttic_(hybrid_eliminate);
#endif
auto result = EliminatePreferCholesky(graph, frontalKeys); auto result = EliminatePreferCholesky(graph, frontalKeys);
#ifdef HYBRID_TIMING
gttoc_(hybrid_eliminate);
#endif
return result; return result;
}; };
// Perform elimination! // Perform elimination!
DecisionTree<Key, Result> eliminationResults(factorGraphTree, eliminate); DecisionTree<Key, Result> eliminationResults(factorGraphTree, eliminate);
#ifdef HYBRID_TIMING
tictoc_print_();
#endif
// Separate out decision tree into conditionals and remaining factors. // Separate out decision tree into conditionals and remaining factors.
const auto [conditionals, newFactors] = unzip(eliminationResults); const auto [conditionals, newFactors] = unzip(eliminationResults);

View File

@ -112,8 +112,8 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
public: public:
using Base = HybridFactorGraph; using Base = HybridFactorGraph;
using This = HybridGaussianFactorGraph; ///< this class using This = HybridGaussianFactorGraph; ///< this class
using BaseEliminateable = ///< for elimination
EliminateableFactorGraph<This>; ///< for elimination using BaseEliminateable = EliminateableFactorGraph<This>;
using shared_ptr = std::shared_ptr<This>; ///< shared_ptr to This using shared_ptr = std::shared_ptr<This>; ///< shared_ptr to This
using Values = gtsam::Values; ///< backwards compatibility using Values = gtsam::Values; ///< backwards compatibility
@ -148,7 +148,8 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
/// @name Standard Interface /// @name Standard Interface
/// @{ /// @{
using Base::error; // Expose error(const HybridValues&) method.. /// Expose error(const HybridValues&) method.
using Base::error;
/** /**
* @brief Compute error for each discrete assignment, * @brief Compute error for each discrete assignment,