Merge pull request #1556 from borglab/hybrid-tablefactor
commit
016f77ba55
|
@ -34,15 +34,12 @@ namespace gtsam {
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
||||||
const ADT& potentials)
|
const ADT& potentials)
|
||||||
: DiscreteFactor(keys.indices()),
|
: DiscreteFactor(keys.indices(), keys.cardinalities()), ADT(potentials) {}
|
||||||
ADT(potentials),
|
|
||||||
cardinalities_(keys.cardinalities()) {}
|
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c)
|
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c)
|
||||||
: DiscreteFactor(c.keys()),
|
: DiscreteFactor(c.keys(), c.cardinalities()),
|
||||||
AlgebraicDecisionTree<Key>(c),
|
AlgebraicDecisionTree<Key>(c) {}
|
||||||
cardinalities_(c.cardinalities_) {}
|
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
bool DecisionTreeFactor::equals(const DiscreteFactor& other,
|
bool DecisionTreeFactor::equals(const DiscreteFactor& other,
|
||||||
|
@ -182,15 +179,12 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DiscreteKeys DecisionTreeFactor::discreteKeys() const {
|
std::vector<double> DecisionTreeFactor::probabilities() const {
|
||||||
DiscreteKeys result;
|
std::vector<double> probs;
|
||||||
for (auto&& key : keys()) {
|
for (auto&& [key, value] : enumerate()) {
|
||||||
DiscreteKey dkey(key, cardinality(key));
|
probs.push_back(value);
|
||||||
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
|
|
||||||
result.push_back(dkey);
|
|
||||||
}
|
}
|
||||||
}
|
return probs;
|
||||||
return result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
|
@ -289,16 +283,14 @@ namespace gtsam {
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
||||||
const vector<double>& table)
|
const vector<double>& table)
|
||||||
: DiscreteFactor(keys.indices()),
|
: DiscreteFactor(keys.indices(), keys.cardinalities()),
|
||||||
AlgebraicDecisionTree<Key>(keys, table),
|
AlgebraicDecisionTree<Key>(keys, table) {}
|
||||||
cardinalities_(keys.cardinalities()) {}
|
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
||||||
const string& table)
|
const string& table)
|
||||||
: DiscreteFactor(keys.indices()),
|
: DiscreteFactor(keys.indices(), keys.cardinalities()),
|
||||||
AlgebraicDecisionTree<Key>(keys, table),
|
AlgebraicDecisionTree<Key>(keys, table) {}
|
||||||
cardinalities_(keys.cardinalities()) {}
|
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const {
|
DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const {
|
||||||
|
@ -306,11 +298,10 @@ namespace gtsam {
|
||||||
|
|
||||||
// Get the probabilities in the decision tree so we can threshold.
|
// Get the probabilities in the decision tree so we can threshold.
|
||||||
std::vector<double> probabilities;
|
std::vector<double> probabilities;
|
||||||
this->visitLeaf([&](const Leaf& leaf) {
|
// NOTE(Varun) this is potentially slow due to the cartesian product
|
||||||
size_t nrAssignments = leaf.nrAssignments();
|
for (auto&& [assignment, prob] : this->enumerate()) {
|
||||||
double prob = leaf.constant();
|
probabilities.push_back(prob);
|
||||||
probabilities.insert(probabilities.end(), nrAssignments, prob);
|
}
|
||||||
});
|
|
||||||
|
|
||||||
// The number of probabilities can be lower than max_leaves
|
// The number of probabilities can be lower than max_leaves
|
||||||
if (probabilities.size() <= N) {
|
if (probabilities.size() <= N) {
|
||||||
|
|
|
@ -50,10 +50,6 @@ namespace gtsam {
|
||||||
typedef std::shared_ptr<DecisionTreeFactor> shared_ptr;
|
typedef std::shared_ptr<DecisionTreeFactor> shared_ptr;
|
||||||
typedef AlgebraicDecisionTree<Key> ADT;
|
typedef AlgebraicDecisionTree<Key> ADT;
|
||||||
|
|
||||||
protected:
|
|
||||||
std::map<Key, size_t> cardinalities_;
|
|
||||||
|
|
||||||
public:
|
|
||||||
/// @name Standard Constructors
|
/// @name Standard Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
@ -119,8 +115,6 @@ namespace gtsam {
|
||||||
|
|
||||||
static double safe_div(const double& a, const double& b);
|
static double safe_div(const double& a, const double& b);
|
||||||
|
|
||||||
size_t cardinality(Key j) const { return cardinalities_.at(j); }
|
|
||||||
|
|
||||||
/// divide by factor f (safely)
|
/// divide by factor f (safely)
|
||||||
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
|
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
|
||||||
return apply(f, safe_div);
|
return apply(f, safe_div);
|
||||||
|
@ -179,8 +173,8 @@ namespace gtsam {
|
||||||
/// Enumerate all values into a map from values to double.
|
/// Enumerate all values into a map from values to double.
|
||||||
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
|
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
|
||||||
|
|
||||||
/// Return all the discrete keys associated with this factor.
|
/// Get all the probabilities in order of assignment values
|
||||||
DiscreteKeys discreteKeys() const;
|
std::vector<double> probabilities() const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Prune the decision tree of discrete variables.
|
* @brief Prune the decision tree of discrete variables.
|
||||||
|
@ -260,7 +254,6 @@ namespace gtsam {
|
||||||
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(ADT);
|
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(ADT);
|
||||||
ar& BOOST_SERIALIZATION_NVP(cardinalities_);
|
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
};
|
};
|
||||||
|
|
|
@ -28,6 +28,18 @@ using namespace std;
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
DiscreteKeys DiscreteFactor::discreteKeys() const {
|
||||||
|
DiscreteKeys result;
|
||||||
|
for (auto&& key : keys()) {
|
||||||
|
DiscreteKey dkey(key, cardinality(key));
|
||||||
|
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
|
||||||
|
result.push_back(dkey);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
double DiscreteFactor::error(const DiscreteValues& values) const {
|
double DiscreteFactor::error(const DiscreteValues& values) const {
|
||||||
return -std::log((*this)(values));
|
return -std::log((*this)(values));
|
||||||
|
|
|
@ -36,28 +36,35 @@ class HybridValues;
|
||||||
* @ingroup discrete
|
* @ingroup discrete
|
||||||
*/
|
*/
|
||||||
class GTSAM_EXPORT DiscreteFactor: public Factor {
|
class GTSAM_EXPORT DiscreteFactor: public Factor {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
// typedefs needed to play nice with gtsam
|
// typedefs needed to play nice with gtsam
|
||||||
typedef DiscreteFactor This; ///< This class
|
typedef DiscreteFactor This; ///< This class
|
||||||
typedef std::shared_ptr<DiscreteFactor> shared_ptr; ///< shared_ptr to this class
|
typedef std::shared_ptr<DiscreteFactor>
|
||||||
|
shared_ptr; ///< shared_ptr to this class
|
||||||
typedef Factor Base; ///< Our base class
|
typedef Factor Base; ///< Our base class
|
||||||
|
|
||||||
using Values = DiscreteValues; ///< backwards compatibility
|
using Values = DiscreteValues; ///< backwards compatibility
|
||||||
|
|
||||||
public:
|
protected:
|
||||||
|
/// Map of Keys and their cardinalities.
|
||||||
|
std::map<Key, size_t> cardinalities_;
|
||||||
|
|
||||||
|
public:
|
||||||
/// @name Standard Constructors
|
/// @name Standard Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/** Default constructor creates empty factor */
|
/** Default constructor creates empty factor */
|
||||||
DiscreteFactor() {}
|
DiscreteFactor() {}
|
||||||
|
|
||||||
/** Construct from container of keys. This constructor is used internally from derived factor
|
/**
|
||||||
* constructors, either from a container of keys or from a boost::assign::list_of. */
|
* Construct from container of keys and map of cardinalities.
|
||||||
|
* This constructor is used internally from derived factor constructors,
|
||||||
|
* either from a container of keys or from a boost::assign::list_of.
|
||||||
|
*/
|
||||||
template <typename CONTAINER>
|
template <typename CONTAINER>
|
||||||
DiscreteFactor(const CONTAINER& keys) : Base(keys) {}
|
DiscreteFactor(const CONTAINER& keys,
|
||||||
|
const std::map<Key, size_t> cardinalities = {})
|
||||||
|
: Base(keys), cardinalities_(cardinalities) {}
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
|
@ -77,6 +84,13 @@ public:
|
||||||
/// @name Standard Interface
|
/// @name Standard Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
/// Return all the discrete keys associated with this factor.
|
||||||
|
DiscreteKeys discreteKeys() const;
|
||||||
|
|
||||||
|
std::map<Key, size_t> cardinalities() const { return cardinalities_; }
|
||||||
|
|
||||||
|
size_t cardinality(Key j) const { return cardinalities_.at(j); }
|
||||||
|
|
||||||
/// Find value for given assignment of values to variables
|
/// Find value for given assignment of values to variables
|
||||||
virtual double operator()(const DiscreteValues&) const = 0;
|
virtual double operator()(const DiscreteValues&) const = 0;
|
||||||
|
|
||||||
|
@ -124,6 +138,17 @@ public:
|
||||||
const Names& names = {}) const = 0;
|
const Names& names = {}) const = 0;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
private:
|
||||||
|
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class ARCHIVE>
|
||||||
|
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||||
|
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||||
|
ar& BOOST_SERIALIZATION_NVP(cardinalities_);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
};
|
};
|
||||||
// DiscreteFactor
|
// DiscreteFactor
|
||||||
|
|
||||||
|
|
|
@ -13,11 +13,12 @@
|
||||||
* @file TableFactor.cpp
|
* @file TableFactor.cpp
|
||||||
* @brief discrete factor
|
* @brief discrete factor
|
||||||
* @date May 4, 2023
|
* @date May 4, 2023
|
||||||
* @author Yoonwoo Kim
|
* @author Yoonwoo Kim, Varun Agrawal
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/base/FastSet.h>
|
#include <gtsam/base/FastSet.h>
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
#include <gtsam/discrete/TableFactor.h>
|
#include <gtsam/discrete/TableFactor.h>
|
||||||
#include <gtsam/hybrid/HybridValues.h>
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
|
|
||||||
|
@ -33,8 +34,7 @@ TableFactor::TableFactor() {}
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
TableFactor::TableFactor(const DiscreteKeys& dkeys,
|
TableFactor::TableFactor(const DiscreteKeys& dkeys,
|
||||||
const TableFactor& potentials)
|
const TableFactor& potentials)
|
||||||
: DiscreteFactor(dkeys.indices()),
|
: DiscreteFactor(dkeys.indices(), dkeys.cardinalities()) {
|
||||||
cardinalities_(potentials.cardinalities_) {
|
|
||||||
sparse_table_ = potentials.sparse_table_;
|
sparse_table_ = potentials.sparse_table_;
|
||||||
denominators_ = potentials.denominators_;
|
denominators_ = potentials.denominators_;
|
||||||
sorted_dkeys_ = discreteKeys();
|
sorted_dkeys_ = discreteKeys();
|
||||||
|
@ -44,11 +44,11 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
TableFactor::TableFactor(const DiscreteKeys& dkeys,
|
TableFactor::TableFactor(const DiscreteKeys& dkeys,
|
||||||
const Eigen::SparseVector<double>& table)
|
const Eigen::SparseVector<double>& table)
|
||||||
: DiscreteFactor(dkeys.indices()), sparse_table_(table.size()) {
|
: DiscreteFactor(dkeys.indices(), dkeys.cardinalities()),
|
||||||
|
sparse_table_(table.size()) {
|
||||||
sparse_table_ = table;
|
sparse_table_ = table;
|
||||||
double denom = table.size();
|
double denom = table.size();
|
||||||
for (const DiscreteKey& dkey : dkeys) {
|
for (const DiscreteKey& dkey : dkeys) {
|
||||||
cardinalities_.insert(dkey);
|
|
||||||
denom /= dkey.second;
|
denom /= dkey.second;
|
||||||
denominators_.insert(std::pair<Key, double>(dkey.first, denom));
|
denominators_.insert(std::pair<Key, double>(dkey.first, denom));
|
||||||
}
|
}
|
||||||
|
@ -56,6 +56,10 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
|
||||||
sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
|
sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
TableFactor::TableFactor(const DiscreteConditional& c)
|
||||||
|
: TableFactor(c.discreteKeys(), c.probabilities()) {}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
Eigen::SparseVector<double> TableFactor::Convert(
|
Eigen::SparseVector<double> TableFactor::Convert(
|
||||||
const std::vector<double>& table) {
|
const std::vector<double>& table) {
|
||||||
|
@ -435,18 +439,6 @@ std::vector<std::pair<DiscreteValues, double>> TableFactor::enumerate() const {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
|
||||||
DiscreteKeys TableFactor::discreteKeys() const {
|
|
||||||
DiscreteKeys result;
|
|
||||||
for (auto&& key : keys()) {
|
|
||||||
DiscreteKey dkey(key, cardinality(key));
|
|
||||||
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
|
|
||||||
result.push_back(dkey);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Print out header.
|
// Print out header.
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
string TableFactor::markdown(const KeyFormatter& keyFormatter,
|
string TableFactor::markdown(const KeyFormatter& keyFormatter,
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
/**
|
/**
|
||||||
* @file TableFactor.h
|
* @file TableFactor.h
|
||||||
* @date May 4, 2023
|
* @date May 4, 2023
|
||||||
* @author Yoonwoo Kim
|
* @author Yoonwoo Kim, Varun Agrawal
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
@ -32,6 +32,7 @@
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
class DiscreteConditional;
|
||||||
class HybridValues;
|
class HybridValues;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -44,8 +45,6 @@ class HybridValues;
|
||||||
*/
|
*/
|
||||||
class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
protected:
|
protected:
|
||||||
/// Map of Keys and their cardinalities.
|
|
||||||
std::map<Key, size_t> cardinalities_;
|
|
||||||
/// SparseVector of nonzero probabilities.
|
/// SparseVector of nonzero probabilities.
|
||||||
Eigen::SparseVector<double> sparse_table_;
|
Eigen::SparseVector<double> sparse_table_;
|
||||||
|
|
||||||
|
@ -75,7 +74,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
* @brief Return ith key in keys_ as a DiscreteKey
|
* @brief Return ith key in keys_ as a DiscreteKey
|
||||||
* @param i ith key in keys_
|
* @param i ith key in keys_
|
||||||
* @return DiscreteKey
|
* @return DiscreteKey
|
||||||
* */
|
*/
|
||||||
DiscreteKey discreteKey(size_t i) const {
|
DiscreteKey discreteKey(size_t i) const {
|
||||||
return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
|
return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
|
||||||
}
|
}
|
||||||
|
@ -142,6 +141,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
TableFactor(const DiscreteKey& key, const std::vector<double>& row)
|
TableFactor(const DiscreteKey& key, const std::vector<double>& row)
|
||||||
: TableFactor(DiscreteKeys{key}, row) {}
|
: TableFactor(DiscreteKeys{key}, row) {}
|
||||||
|
|
||||||
|
/** Construct from a DiscreteConditional type */
|
||||||
|
explicit TableFactor(const DiscreteConditional& c);
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
/// @{
|
/// @{
|
||||||
|
@ -180,8 +182,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
|
|
||||||
static double safe_div(const double& a, const double& b);
|
static double safe_div(const double& a, const double& b);
|
||||||
|
|
||||||
size_t cardinality(Key j) const { return cardinalities_.at(j); }
|
|
||||||
|
|
||||||
/// divide by factor f (safely)
|
/// divide by factor f (safely)
|
||||||
TableFactor operator/(const TableFactor& f) const {
|
TableFactor operator/(const TableFactor& f) const {
|
||||||
return apply(f, safe_div);
|
return apply(f, safe_div);
|
||||||
|
@ -274,9 +274,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
/// Enumerate all values into a map from values to double.
|
/// Enumerate all values into a map from values to double.
|
||||||
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
|
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
|
||||||
|
|
||||||
/// Return all the discrete keys associated with this factor.
|
|
||||||
DiscreteKeys discreteKeys() const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Prune the decision tree of discrete variables.
|
* @brief Prune the decision tree of discrete variables.
|
||||||
*
|
*
|
||||||
|
|
|
@ -51,6 +51,11 @@ TEST( DecisionTreeFactor, constructors)
|
||||||
|
|
||||||
// Assert that error = -log(value)
|
// Assert that error = -log(value)
|
||||||
EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);
|
EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);
|
||||||
|
|
||||||
|
// Construct from DiscreteConditional
|
||||||
|
DiscreteConditional conditional(X | Y = "1/1 2/3 1/4");
|
||||||
|
DecisionTreeFactor f4(conditional);
|
||||||
|
EXPECT_DOUBLES_EQUAL(0.8, f4(values), 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -93,7 +93,8 @@ void printTime(map<double, pair<chrono::microseconds, chrono::microseconds>>
|
||||||
for (auto&& kv : measured_time) {
|
for (auto&& kv : measured_time) {
|
||||||
cout << "dropout: " << kv.first
|
cout << "dropout: " << kv.first
|
||||||
<< " | TableFactor time: " << kv.second.first.count()
|
<< " | TableFactor time: " << kv.second.first.count()
|
||||||
<< " | DecisionTreeFactor time: " << kv.second.second.count() << endl;
|
<< " | DecisionTreeFactor time: " << kv.second.second.count() <<
|
||||||
|
endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -124,6 +125,13 @@ TEST(TableFactor, constructors) {
|
||||||
|
|
||||||
// Assert that error = -log(value)
|
// Assert that error = -log(value)
|
||||||
EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);
|
EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);
|
||||||
|
|
||||||
|
// Construct from DiscreteConditional
|
||||||
|
DiscreteConditional conditional(X | Y = "1/1 2/3 1/4");
|
||||||
|
TableFactor f4(conditional);
|
||||||
|
// Manually constructed via inspection and comparison to DecisionTreeFactor
|
||||||
|
TableFactor expected(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
||||||
|
EXPECT(assert_equal(expected, f4));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -156,7 +164,8 @@ TEST(TableFactor, multiplication) {
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// Benchmark which compares runtime of multiplication of two TableFactors
|
// Benchmark which compares runtime of multiplication of two TableFactors
|
||||||
// and two DecisionTreeFactors given sparsity from dense to 90% sparsity.
|
// and two DecisionTreeFactors given sparsity from dense to 90% sparsity.
|
||||||
TEST(TableFactor, benchmark) {
|
// NOTE: Enable to run.
|
||||||
|
TEST_DISABLED(TableFactor, benchmark) {
|
||||||
DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), F(5, 2), G(6, 3),
|
DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), F(5, 2), G(6, 3),
|
||||||
H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3);
|
H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3);
|
||||||
|
|
||||||
|
|
|
@ -228,19 +228,19 @@ std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
|
||||||
/**
|
/**
|
||||||
* @brief Helper function to get the pruner functional.
|
* @brief Helper function to get the pruner functional.
|
||||||
*
|
*
|
||||||
* @param decisionTree The probability decision tree of only discrete keys.
|
* @param discreteProbs The probabilities of only discrete keys.
|
||||||
* @return std::function<GaussianConditional::shared_ptr(
|
* @return std::function<GaussianConditional::shared_ptr(
|
||||||
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
||||||
*/
|
*/
|
||||||
std::function<GaussianConditional::shared_ptr(
|
std::function<GaussianConditional::shared_ptr(
|
||||||
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
||||||
GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
|
GaussianMixture::prunerFunc(const DecisionTreeFactor &discreteProbs) {
|
||||||
// Get the discrete keys as sets for the decision tree
|
// Get the discrete keys as sets for the decision tree
|
||||||
// and the gaussian mixture.
|
// and the gaussian mixture.
|
||||||
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
|
auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
|
||||||
auto gaussianMixtureKeySet = DiscreteKeysAsSet(this->discreteKeys());
|
auto gaussianMixtureKeySet = DiscreteKeysAsSet(this->discreteKeys());
|
||||||
|
|
||||||
auto pruner = [decisionTree, decisionTreeKeySet, gaussianMixtureKeySet](
|
auto pruner = [discreteProbs, discreteProbsKeySet, gaussianMixtureKeySet](
|
||||||
const Assignment<Key> &choices,
|
const Assignment<Key> &choices,
|
||||||
const GaussianConditional::shared_ptr &conditional)
|
const GaussianConditional::shared_ptr &conditional)
|
||||||
-> GaussianConditional::shared_ptr {
|
-> GaussianConditional::shared_ptr {
|
||||||
|
@ -249,8 +249,8 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
|
||||||
|
|
||||||
// Case where the gaussian mixture has the same
|
// Case where the gaussian mixture has the same
|
||||||
// discrete keys as the decision tree.
|
// discrete keys as the decision tree.
|
||||||
if (gaussianMixtureKeySet == decisionTreeKeySet) {
|
if (gaussianMixtureKeySet == discreteProbsKeySet) {
|
||||||
if (decisionTree(values) == 0.0) {
|
if (discreteProbs(values) == 0.0) {
|
||||||
// empty aka null pointer
|
// empty aka null pointer
|
||||||
std::shared_ptr<GaussianConditional> null;
|
std::shared_ptr<GaussianConditional> null;
|
||||||
return null;
|
return null;
|
||||||
|
@ -259,9 +259,9 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
std::vector<DiscreteKey> set_diff;
|
std::vector<DiscreteKey> set_diff;
|
||||||
std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
|
std::set_difference(
|
||||||
gaussianMixtureKeySet.begin(),
|
discreteProbsKeySet.begin(), discreteProbsKeySet.end(),
|
||||||
gaussianMixtureKeySet.end(),
|
gaussianMixtureKeySet.begin(), gaussianMixtureKeySet.end(),
|
||||||
std::back_inserter(set_diff));
|
std::back_inserter(set_diff));
|
||||||
|
|
||||||
const std::vector<DiscreteValues> assignments =
|
const std::vector<DiscreteValues> assignments =
|
||||||
|
@ -272,7 +272,7 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
|
||||||
|
|
||||||
// If any one of the sub-branches are non-zero,
|
// If any one of the sub-branches are non-zero,
|
||||||
// we need this conditional.
|
// we need this conditional.
|
||||||
if (decisionTree(augmented_values) > 0.0) {
|
if (discreteProbs(augmented_values) > 0.0) {
|
||||||
return conditional;
|
return conditional;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -285,12 +285,12 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
|
void GaussianMixture::prune(const DecisionTreeFactor &discreteProbs) {
|
||||||
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
|
auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
|
||||||
auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys());
|
auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys());
|
||||||
// Functional which loops over all assignments and create a set of
|
// Functional which loops over all assignments and create a set of
|
||||||
// GaussianConditionals
|
// GaussianConditionals
|
||||||
auto pruner = prunerFunc(decisionTree);
|
auto pruner = prunerFunc(discreteProbs);
|
||||||
|
|
||||||
auto pruned_conditionals = conditionals_.apply(pruner);
|
auto pruned_conditionals = conditionals_.apply(pruner);
|
||||||
conditionals_.root_ = pruned_conditionals.root_;
|
conditionals_.root_ = pruned_conditionals.root_;
|
||||||
|
|
|
@ -74,13 +74,13 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
/**
|
/**
|
||||||
* @brief Helper function to get the pruner functor.
|
* @brief Helper function to get the pruner functor.
|
||||||
*
|
*
|
||||||
* @param decisionTree The pruned discrete probability decision tree.
|
* @param discreteProbs The pruned discrete probabilities.
|
||||||
* @return std::function<GaussianConditional::shared_ptr(
|
* @return std::function<GaussianConditional::shared_ptr(
|
||||||
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
||||||
*/
|
*/
|
||||||
std::function<GaussianConditional::shared_ptr(
|
std::function<GaussianConditional::shared_ptr(
|
||||||
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
||||||
prunerFunc(const DecisionTreeFactor &decisionTree);
|
prunerFunc(const DecisionTreeFactor &discreteProbs);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/// @name Constructors
|
/// @name Constructors
|
||||||
|
@ -234,12 +234,11 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Prune the decision tree of Gaussian factors as per the discrete
|
* @brief Prune the decision tree of Gaussian factors as per the discrete
|
||||||
* `decisionTree`.
|
* `discreteProbs`.
|
||||||
*
|
*
|
||||||
* @param decisionTree A pruned decision tree of discrete keys where the
|
* @param discreteProbs A pruned set of probabilities for the discrete keys.
|
||||||
* leaves are probabilities.
|
|
||||||
*/
|
*/
|
||||||
void prune(const DecisionTreeFactor &decisionTree);
|
void prune(const DecisionTreeFactor &discreteProbs);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Merge the Gaussian Factor Graphs in `this` and `sum` while
|
* @brief Merge the Gaussian Factor Graphs in `this` and `sum` while
|
||||||
|
|
|
@ -39,41 +39,41 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
|
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
|
||||||
AlgebraicDecisionTree<Key> decisionTree;
|
AlgebraicDecisionTree<Key> discreteProbs;
|
||||||
|
|
||||||
// The canonical decision tree factor which will get
|
// The canonical decision tree factor which will get
|
||||||
// the discrete conditionals added to it.
|
// the discrete conditionals added to it.
|
||||||
DecisionTreeFactor dtFactor;
|
DecisionTreeFactor discreteProbsFactor;
|
||||||
|
|
||||||
for (auto &&conditional : *this) {
|
for (auto &&conditional : *this) {
|
||||||
if (conditional->isDiscrete()) {
|
if (conditional->isDiscrete()) {
|
||||||
// Convert to a DecisionTreeFactor and add it to the main factor.
|
// Convert to a DecisionTreeFactor and add it to the main factor.
|
||||||
DecisionTreeFactor f(*conditional->asDiscrete());
|
DecisionTreeFactor f(*conditional->asDiscrete());
|
||||||
dtFactor = dtFactor * f;
|
discreteProbsFactor = discreteProbsFactor * f;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return std::make_shared<DecisionTreeFactor>(dtFactor);
|
return std::make_shared<DecisionTreeFactor>(discreteProbsFactor);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
/**
|
/**
|
||||||
* @brief Helper function to get the pruner functional.
|
* @brief Helper function to get the pruner functional.
|
||||||
*
|
*
|
||||||
* @param prunedDecisionTree The prob. decision tree of only discrete keys.
|
* @param prunedDiscreteProbs The prob. decision tree of only discrete keys.
|
||||||
* @param conditional Conditional to prune. Used to get full assignment.
|
* @param conditional Conditional to prune. Used to get full assignment.
|
||||||
* @return std::function<double(const Assignment<Key> &, double)>
|
* @return std::function<double(const Assignment<Key> &, double)>
|
||||||
*/
|
*/
|
||||||
std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
||||||
const DecisionTreeFactor &prunedDecisionTree,
|
const DecisionTreeFactor &prunedDiscreteProbs,
|
||||||
const HybridConditional &conditional) {
|
const HybridConditional &conditional) {
|
||||||
// Get the discrete keys as sets for the decision tree
|
// Get the discrete keys as sets for the decision tree
|
||||||
// and the Gaussian mixture.
|
// and the Gaussian mixture.
|
||||||
std::set<DiscreteKey> decisionTreeKeySet =
|
std::set<DiscreteKey> discreteProbsKeySet =
|
||||||
DiscreteKeysAsSet(prunedDecisionTree.discreteKeys());
|
DiscreteKeysAsSet(prunedDiscreteProbs.discreteKeys());
|
||||||
std::set<DiscreteKey> conditionalKeySet =
|
std::set<DiscreteKey> conditionalKeySet =
|
||||||
DiscreteKeysAsSet(conditional.discreteKeys());
|
DiscreteKeysAsSet(conditional.discreteKeys());
|
||||||
|
|
||||||
auto pruner = [prunedDecisionTree, decisionTreeKeySet, conditionalKeySet](
|
auto pruner = [prunedDiscreteProbs, discreteProbsKeySet, conditionalKeySet](
|
||||||
const Assignment<Key> &choices,
|
const Assignment<Key> &choices,
|
||||||
double probability) -> double {
|
double probability) -> double {
|
||||||
// This corresponds to 0 probability
|
// This corresponds to 0 probability
|
||||||
|
@ -83,8 +83,8 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
||||||
DiscreteValues values(choices);
|
DiscreteValues values(choices);
|
||||||
// Case where the Gaussian mixture has the same
|
// Case where the Gaussian mixture has the same
|
||||||
// discrete keys as the decision tree.
|
// discrete keys as the decision tree.
|
||||||
if (conditionalKeySet == decisionTreeKeySet) {
|
if (conditionalKeySet == discreteProbsKeySet) {
|
||||||
if (prunedDecisionTree(values) == 0) {
|
if (prunedDiscreteProbs(values) == 0) {
|
||||||
return pruned_prob;
|
return pruned_prob;
|
||||||
} else {
|
} else {
|
||||||
return probability;
|
return probability;
|
||||||
|
@ -114,11 +114,12 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now we generate the full assignment by enumerating
|
// Now we generate the full assignment by enumerating
|
||||||
// over all keys in the prunedDecisionTree.
|
// over all keys in the prunedDiscreteProbs.
|
||||||
// First we find the differing keys
|
// First we find the differing keys
|
||||||
std::vector<DiscreteKey> set_diff;
|
std::vector<DiscreteKey> set_diff;
|
||||||
std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
|
std::set_difference(discreteProbsKeySet.begin(),
|
||||||
conditionalKeySet.begin(), conditionalKeySet.end(),
|
discreteProbsKeySet.end(), conditionalKeySet.begin(),
|
||||||
|
conditionalKeySet.end(),
|
||||||
std::back_inserter(set_diff));
|
std::back_inserter(set_diff));
|
||||||
|
|
||||||
// Now enumerate over all assignments of the differing keys
|
// Now enumerate over all assignments of the differing keys
|
||||||
|
@ -130,7 +131,7 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
||||||
|
|
||||||
// If any one of the sub-branches are non-zero,
|
// If any one of the sub-branches are non-zero,
|
||||||
// we need this probability.
|
// we need this probability.
|
||||||
if (prunedDecisionTree(augmented_values) > 0.0) {
|
if (prunedDiscreteProbs(augmented_values) > 0.0) {
|
||||||
return probability;
|
return probability;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -144,8 +145,8 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
void HybridBayesNet::updateDiscreteConditionals(
|
void HybridBayesNet::updateDiscreteConditionals(
|
||||||
const DecisionTreeFactor &prunedDecisionTree) {
|
const DecisionTreeFactor &prunedDiscreteProbs) {
|
||||||
KeyVector prunedTreeKeys = prunedDecisionTree.keys();
|
KeyVector prunedTreeKeys = prunedDiscreteProbs.keys();
|
||||||
|
|
||||||
// Loop with index since we need it later.
|
// Loop with index since we need it later.
|
||||||
for (size_t i = 0; i < this->size(); i++) {
|
for (size_t i = 0; i < this->size(); i++) {
|
||||||
|
@ -153,18 +154,21 @@ void HybridBayesNet::updateDiscreteConditionals(
|
||||||
if (conditional->isDiscrete()) {
|
if (conditional->isDiscrete()) {
|
||||||
auto discrete = conditional->asDiscrete();
|
auto discrete = conditional->asDiscrete();
|
||||||
|
|
||||||
// Apply prunerFunc to the underlying AlgebraicDecisionTree
|
// Convert pointer from conditional to factor
|
||||||
auto discreteTree =
|
auto discreteTree =
|
||||||
std::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete);
|
std::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete);
|
||||||
|
// Apply prunerFunc to the underlying AlgebraicDecisionTree
|
||||||
DecisionTreeFactor::ADT prunedDiscreteTree =
|
DecisionTreeFactor::ADT prunedDiscreteTree =
|
||||||
discreteTree->apply(prunerFunc(prunedDecisionTree, *conditional));
|
discreteTree->apply(prunerFunc(prunedDiscreteProbs, *conditional));
|
||||||
|
|
||||||
|
gttic_(HybridBayesNet_MakeConditional);
|
||||||
// Create the new (hybrid) conditional
|
// Create the new (hybrid) conditional
|
||||||
KeyVector frontals(discrete->frontals().begin(),
|
KeyVector frontals(discrete->frontals().begin(),
|
||||||
discrete->frontals().end());
|
discrete->frontals().end());
|
||||||
auto prunedDiscrete = std::make_shared<DiscreteLookupTable>(
|
auto prunedDiscrete = std::make_shared<DiscreteLookupTable>(
|
||||||
frontals.size(), conditional->discreteKeys(), prunedDiscreteTree);
|
frontals.size(), conditional->discreteKeys(), prunedDiscreteTree);
|
||||||
conditional = std::make_shared<HybridConditional>(prunedDiscrete);
|
conditional = std::make_shared<HybridConditional>(prunedDiscrete);
|
||||||
|
gttoc_(HybridBayesNet_MakeConditional);
|
||||||
|
|
||||||
// Add it back to the BayesNet
|
// Add it back to the BayesNet
|
||||||
this->at(i) = conditional;
|
this->at(i) = conditional;
|
||||||
|
@ -175,10 +179,16 @@ void HybridBayesNet::updateDiscreteConditionals(
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
|
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
|
||||||
// Get the decision tree of only the discrete keys
|
// Get the decision tree of only the discrete keys
|
||||||
auto discreteConditionals = this->discreteConditionals();
|
gttic_(HybridBayesNet_PruneDiscreteConditionals);
|
||||||
const auto decisionTree = discreteConditionals->prune(maxNrLeaves);
|
DecisionTreeFactor::shared_ptr discreteConditionals =
|
||||||
|
this->discreteConditionals();
|
||||||
|
const DecisionTreeFactor prunedDiscreteProbs =
|
||||||
|
discreteConditionals->prune(maxNrLeaves);
|
||||||
|
gttoc_(HybridBayesNet_PruneDiscreteConditionals);
|
||||||
|
|
||||||
this->updateDiscreteConditionals(decisionTree);
|
gttic_(HybridBayesNet_UpdateDiscreteConditionals);
|
||||||
|
this->updateDiscreteConditionals(prunedDiscreteProbs);
|
||||||
|
gttoc_(HybridBayesNet_UpdateDiscreteConditionals);
|
||||||
|
|
||||||
/* To Prune, we visitWith every leaf in the GaussianMixture.
|
/* To Prune, we visitWith every leaf in the GaussianMixture.
|
||||||
* For each leaf, using the assignment we can check the discrete decision tree
|
* For each leaf, using the assignment we can check the discrete decision tree
|
||||||
|
@ -189,13 +199,14 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
|
||||||
|
|
||||||
HybridBayesNet prunedBayesNetFragment;
|
HybridBayesNet prunedBayesNetFragment;
|
||||||
|
|
||||||
|
gttic_(HybridBayesNet_PruneMixtures);
|
||||||
// Go through all the conditionals in the
|
// Go through all the conditionals in the
|
||||||
// Bayes Net and prune them as per decisionTree.
|
// Bayes Net and prune them as per prunedDiscreteProbs.
|
||||||
for (auto &&conditional : *this) {
|
for (auto &&conditional : *this) {
|
||||||
if (auto gm = conditional->asMixture()) {
|
if (auto gm = conditional->asMixture()) {
|
||||||
// Make a copy of the Gaussian mixture and prune it!
|
// Make a copy of the Gaussian mixture and prune it!
|
||||||
auto prunedGaussianMixture = std::make_shared<GaussianMixture>(*gm);
|
auto prunedGaussianMixture = std::make_shared<GaussianMixture>(*gm);
|
||||||
prunedGaussianMixture->prune(decisionTree); // imperative :-(
|
prunedGaussianMixture->prune(prunedDiscreteProbs); // imperative :-(
|
||||||
|
|
||||||
// Type-erase and add to the pruned Bayes Net fragment.
|
// Type-erase and add to the pruned Bayes Net fragment.
|
||||||
prunedBayesNetFragment.push_back(prunedGaussianMixture);
|
prunedBayesNetFragment.push_back(prunedGaussianMixture);
|
||||||
|
@ -205,6 +216,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
|
||||||
prunedBayesNetFragment.push_back(conditional);
|
prunedBayesNetFragment.push_back(conditional);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
gttoc_(HybridBayesNet_PruneMixtures);
|
||||||
|
|
||||||
return prunedBayesNetFragment;
|
return prunedBayesNetFragment;
|
||||||
}
|
}
|
||||||
|
|
|
@ -224,9 +224,9 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
/**
|
/**
|
||||||
* @brief Update the discrete conditionals with the pruned versions.
|
* @brief Update the discrete conditionals with the pruned versions.
|
||||||
*
|
*
|
||||||
* @param prunedDecisionTree
|
* @param prunedDiscreteProbs
|
||||||
*/
|
*/
|
||||||
void updateDiscreteConditionals(const DecisionTreeFactor &prunedDecisionTree);
|
void updateDiscreteConditionals(const DecisionTreeFactor &prunedDiscreteProbs);
|
||||||
|
|
||||||
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
|
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
|
||||||
/** Serialization function */
|
/** Serialization function */
|
||||||
|
|
|
@ -173,19 +173,18 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
||||||
auto decisionTree =
|
auto discreteProbs = this->roots_.at(0)->conditional()->asDiscrete();
|
||||||
this->roots_.at(0)->conditional()->asDiscrete();
|
|
||||||
|
|
||||||
DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves);
|
DecisionTreeFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves);
|
||||||
decisionTree->root_ = prunedDecisionTree.root_;
|
discreteProbs->root_ = prunedDiscreteProbs.root_;
|
||||||
|
|
||||||
/// Helper struct for pruning the hybrid bayes tree.
|
/// Helper struct for pruning the hybrid bayes tree.
|
||||||
struct HybridPrunerData {
|
struct HybridPrunerData {
|
||||||
/// The discrete decision tree after pruning.
|
/// The discrete decision tree after pruning.
|
||||||
DecisionTreeFactor prunedDecisionTree;
|
DecisionTreeFactor prunedDiscreteProbs;
|
||||||
HybridPrunerData(const DecisionTreeFactor& prunedDecisionTree,
|
HybridPrunerData(const DecisionTreeFactor& prunedDiscreteProbs,
|
||||||
const HybridBayesTree::sharedNode& parentClique)
|
const HybridBayesTree::sharedNode& parentClique)
|
||||||
: prunedDecisionTree(prunedDecisionTree) {}
|
: prunedDiscreteProbs(prunedDiscreteProbs) {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief A function used during tree traversal that operates on each node
|
* @brief A function used during tree traversal that operates on each node
|
||||||
|
@ -205,13 +204,13 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
||||||
if (conditional->isHybrid()) {
|
if (conditional->isHybrid()) {
|
||||||
auto gaussianMixture = conditional->asMixture();
|
auto gaussianMixture = conditional->asMixture();
|
||||||
|
|
||||||
gaussianMixture->prune(parentData.prunedDecisionTree);
|
gaussianMixture->prune(parentData.prunedDiscreteProbs);
|
||||||
}
|
}
|
||||||
return parentData;
|
return parentData;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
HybridPrunerData rootData(prunedDecisionTree, 0);
|
HybridPrunerData rootData(prunedDiscreteProbs, 0);
|
||||||
{
|
{
|
||||||
treeTraversal::no_op visitorPost;
|
treeTraversal::no_op visitorPost;
|
||||||
// Limits OpenMP threads since we're mixing TBB and OpenMP
|
// Limits OpenMP threads since we're mixing TBB and OpenMP
|
||||||
|
|
|
@ -98,7 +98,7 @@ static GaussianFactorGraphTree addGaussian(
|
||||||
// TODO(dellaert): it's probably more efficient to first collect the discrete
|
// TODO(dellaert): it's probably more efficient to first collect the discrete
|
||||||
// keys, and then loop over all assignments to populate a vector.
|
// keys, and then loop over all assignments to populate a vector.
|
||||||
GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
||||||
gttic(assembleGraphTree);
|
gttic_(assembleGraphTree);
|
||||||
|
|
||||||
GaussianFactorGraphTree result;
|
GaussianFactorGraphTree result;
|
||||||
|
|
||||||
|
@ -131,7 +131,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
gttoc(assembleGraphTree);
|
gttoc_(assembleGraphTree);
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -190,7 +190,8 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
// If any GaussianFactorGraph in the decision tree contains a nullptr, convert
|
// If any GaussianFactorGraph in the decision tree contains a nullptr, convert
|
||||||
// that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will
|
// that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will
|
||||||
// otherwise create a GFG with a single (null) factor, which doesn't register as null.
|
// otherwise create a GFG with a single (null) factor,
|
||||||
|
// which doesn't register as null.
|
||||||
GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
|
GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
|
||||||
auto emptyGaussian = [](const GaussianFactorGraph &graph) {
|
auto emptyGaussian = [](const GaussianFactorGraph &graph) {
|
||||||
bool hasNull =
|
bool hasNull =
|
||||||
|
@ -230,26 +231,14 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
return {nullptr, nullptr};
|
return {nullptr, nullptr};
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef HYBRID_TIMING
|
|
||||||
gttic_(hybrid_eliminate);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto result = EliminatePreferCholesky(graph, frontalKeys);
|
auto result = EliminatePreferCholesky(graph, frontalKeys);
|
||||||
|
|
||||||
#ifdef HYBRID_TIMING
|
|
||||||
gttoc_(hybrid_eliminate);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Perform elimination!
|
// Perform elimination!
|
||||||
DecisionTree<Key, Result> eliminationResults(factorGraphTree, eliminate);
|
DecisionTree<Key, Result> eliminationResults(factorGraphTree, eliminate);
|
||||||
|
|
||||||
#ifdef HYBRID_TIMING
|
|
||||||
tictoc_print_();
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// Separate out decision tree into conditionals and remaining factors.
|
// Separate out decision tree into conditionals and remaining factors.
|
||||||
const auto [conditionals, newFactors] = unzip(eliminationResults);
|
const auto [conditionals, newFactors] = unzip(eliminationResults);
|
||||||
|
|
||||||
|
|
|
@ -112,8 +112,8 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
||||||
public:
|
public:
|
||||||
using Base = HybridFactorGraph;
|
using Base = HybridFactorGraph;
|
||||||
using This = HybridGaussianFactorGraph; ///< this class
|
using This = HybridGaussianFactorGraph; ///< this class
|
||||||
using BaseEliminateable =
|
///< for elimination
|
||||||
EliminateableFactorGraph<This>; ///< for elimination
|
using BaseEliminateable = EliminateableFactorGraph<This>;
|
||||||
using shared_ptr = std::shared_ptr<This>; ///< shared_ptr to This
|
using shared_ptr = std::shared_ptr<This>; ///< shared_ptr to This
|
||||||
|
|
||||||
using Values = gtsam::Values; ///< backwards compatibility
|
using Values = gtsam::Values; ///< backwards compatibility
|
||||||
|
@ -148,7 +148,8 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
||||||
/// @name Standard Interface
|
/// @name Standard Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
using Base::error; // Expose error(const HybridValues&) method..
|
/// Expose error(const HybridValues&) method.
|
||||||
|
using Base::error;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Compute error for each discrete assignment,
|
* @brief Compute error for each discrete assignment,
|
||||||
|
|
Loading…
Reference in New Issue