Merge pull request #1055 from borglab/feature/hybrid_base
commit
9d71c90aff
|
|
@ -0,0 +1,13 @@
|
|||
#include <gtsam/base/utilities.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
std::string RedirectCout::str() const {
|
||||
return ssBuffer_.str();
|
||||
}
|
||||
|
||||
RedirectCout::~RedirectCout() {
|
||||
std::cout.rdbuf(coutBuffer_);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -1,5 +1,9 @@
|
|||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
namespace gtsam {
|
||||
/**
|
||||
* For Python __str__().
|
||||
|
|
@ -12,14 +16,10 @@ struct RedirectCout {
|
|||
RedirectCout() : ssBuffer_(), coutBuffer_(std::cout.rdbuf(ssBuffer_.rdbuf())) {}
|
||||
|
||||
/// return the string
|
||||
std::string str() const {
|
||||
return ssBuffer_.str();
|
||||
}
|
||||
std::string str() const;
|
||||
|
||||
/// destructor -- redirect stdout buffer to its original buffer
|
||||
~RedirectCout() {
|
||||
std::cout.rdbuf(coutBuffer_);
|
||||
}
|
||||
~RedirectCout();
|
||||
|
||||
private:
|
||||
std::stringstream ssBuffer_;
|
||||
|
|
|
|||
|
|
@ -163,7 +163,7 @@ namespace gtsam {
|
|||
const typename Base::LabelFormatter& labelFormatter =
|
||||
&DefaultFormatter) const {
|
||||
auto valueFormatter = [](const double& v) {
|
||||
return (boost::format("%4.2g") % v).str();
|
||||
return (boost::format("%4.4g") % v).str();
|
||||
};
|
||||
Base::print(s, labelFormatter, valueFormatter);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -604,7 +604,7 @@ namespace gtsam {
|
|||
using MXChoice = typename DecisionTree<M, X>::Choice;
|
||||
auto choice = boost::dynamic_pointer_cast<const MXChoice>(f);
|
||||
if (!choice) throw std::invalid_argument(
|
||||
"DecisionTree::Convert: Invalid NodePtr");
|
||||
"DecisionTree::convertFrom: Invalid NodePtr");
|
||||
|
||||
// get new label
|
||||
const M oldLabel = choice->label();
|
||||
|
|
@ -634,6 +634,8 @@ namespace gtsam {
|
|||
|
||||
using Choice = typename DecisionTree<L, Y>::Choice;
|
||||
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
|
||||
if (!choice)
|
||||
throw std::invalid_argument("DecisionTree::Visit: Invalid NodePtr");
|
||||
for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
|
||||
}
|
||||
};
|
||||
|
|
@ -663,6 +665,8 @@ namespace gtsam {
|
|||
|
||||
using Choice = typename DecisionTree<L, Y>::Choice;
|
||||
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
|
||||
if (!choice)
|
||||
throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr");
|
||||
for (size_t i = 0; i < choice->nrChoices(); i++) {
|
||||
choices[choice->label()] = i; // Set assignment for label to i
|
||||
(*this)(choice->branches()[i]); // recurse!
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ namespace gtsam {
|
|||
* Y = function range (any algebra), e.g., bool, int, double
|
||||
*/
|
||||
template<typename L, typename Y>
|
||||
class GTSAM_EXPORT DecisionTree {
|
||||
class DecisionTree {
|
||||
|
||||
protected:
|
||||
/// Default method for comparison of two objects of type Y.
|
||||
|
|
@ -340,4 +340,11 @@ namespace gtsam {
|
|||
return f.apply(g, op);
|
||||
}
|
||||
|
||||
/// unzip a DecisionTree if its leaves are `std::pair`
|
||||
template<typename L, typename T1, typename T2>
|
||||
std::pair<DecisionTree<L, T1>, DecisionTree<L, T2> > unzip(const DecisionTree<L, std::pair<T1, T2> > &input) {
|
||||
return std::make_pair(DecisionTree<L, T1>(input, [](std::pair<T1, T2> i) { return i.first; }),
|
||||
DecisionTree<L, T2>(input, [](std::pair<T1, T2> i) { return i.second; }));
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -72,7 +72,7 @@ namespace gtsam {
|
|||
for (auto&& key : keys())
|
||||
cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key);
|
||||
cout << " ]" << endl;
|
||||
ADT::print("Potentials:", formatter);
|
||||
ADT::print("", formatter);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
@ -168,6 +168,18 @@ namespace gtsam {
|
|||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteKeys DecisionTreeFactor::discreteKeys() const {
|
||||
DiscreteKeys result;
|
||||
for (auto&& key : keys()) {
|
||||
DiscreteKey dkey(key, cardinality(key));
|
||||
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
|
||||
result.push_back(dkey);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
static std::string valueFormatter(const double& v) {
|
||||
return (boost::format("%4.2g") % v).str();
|
||||
|
|
|
|||
|
|
@ -188,6 +188,9 @@ namespace gtsam {
|
|||
/// Enumerate all values into a map from values to double.
|
||||
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
|
||||
|
||||
/// Return all the discrete keys associated with this factor.
|
||||
DiscreteKeys discreteKeys() const;
|
||||
|
||||
/// @}
|
||||
/// @name Wrapper support
|
||||
/// @{
|
||||
|
|
|
|||
|
|
@ -17,12 +17,59 @@
|
|||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <gtsam/base/Vector.h>
|
||||
#include <gtsam/discrete/DiscreteFactor.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <sstream>
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************* */
|
||||
std::vector<double> expNormalize(const std::vector<double>& logProbs) {
|
||||
double maxLogProb = -std::numeric_limits<double>::infinity();
|
||||
for (size_t i = 0; i < logProbs.size(); i++) {
|
||||
double logProb = logProbs[i];
|
||||
if ((logProb != std::numeric_limits<double>::infinity()) &&
|
||||
logProb > maxLogProb) {
|
||||
maxLogProb = logProb;
|
||||
}
|
||||
}
|
||||
|
||||
// After computing the max = "Z" of the log probabilities L_i, we compute
|
||||
// the log of the normalizing constant, log S, where S = sum_j exp(L_j - Z).
|
||||
double total = 0.0;
|
||||
for (size_t i = 0; i < logProbs.size(); i++) {
|
||||
double probPrime = exp(logProbs[i] - maxLogProb);
|
||||
total += probPrime;
|
||||
}
|
||||
double logTotal = log(total);
|
||||
|
||||
// Now we compute the (normalized) probability (for each i):
|
||||
// p_i = exp(L_i - Z - log S)
|
||||
double checkNormalization = 0.0;
|
||||
std::vector<double> probs;
|
||||
for (size_t i = 0; i < logProbs.size(); i++) {
|
||||
double prob = exp(logProbs[i] - maxLogProb - logTotal);
|
||||
probs.push_back(prob);
|
||||
checkNormalization += prob;
|
||||
}
|
||||
|
||||
// Numerical tolerance for floating point comparisons
|
||||
double tol = 1e-9;
|
||||
|
||||
if (!gtsam::fpEqual(checkNormalization, 1.0, tol)) {
|
||||
std::string errMsg =
|
||||
std::string("expNormalize failed to normalize probabilities. ") +
|
||||
std::string("Expected normalization constant = 1.0. Got value: ") +
|
||||
std::to_string(checkNormalization) +
|
||||
std::string(
|
||||
"\n This could have resulted from numerical overflow/underflow.");
|
||||
throw std::logic_error(errMsg);
|
||||
}
|
||||
return probs;
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -122,4 +122,24 @@ public:
|
|||
// traits
|
||||
template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
|
||||
|
||||
|
||||
/**
|
||||
* @brief Normalize a set of log probabilities.
|
||||
*
|
||||
* Normalizing a set of log probabilities in a numerically stable way is
|
||||
* tricky. To avoid overflow/underflow issues, we compute the largest
|
||||
* (finite) log probability and subtract it from each log probability before
|
||||
* normalizing. This comes from the observation that if:
|
||||
* p_i = exp(L_i) / ( sum_j exp(L_j) ),
|
||||
* Then,
|
||||
* p_i = exp(Z) exp(L_i - Z) / (exp(Z) sum_j exp(L_j - Z)),
|
||||
* = exp(L_i - Z) / ( sum_j exp(L_j - Z) )
|
||||
*
|
||||
* Setting Z = max_j L_j, we can avoid numerical issues that arise when all
|
||||
* of the (unnormalized) log probabilities are either very large or very
|
||||
* small.
|
||||
*/
|
||||
std::vector<double> expNormalize(const std::vector<double> &logProbs);
|
||||
|
||||
|
||||
}// namespace gtsam
|
||||
|
|
|
|||
|
|
@ -44,11 +44,25 @@ namespace gtsam {
|
|||
/* ************************************************************************* */
|
||||
KeySet DiscreteFactorGraph::keys() const {
|
||||
KeySet keys;
|
||||
for(const sharedFactor& factor: *this)
|
||||
if (factor) keys.insert(factor->begin(), factor->end());
|
||||
for (const sharedFactor& factor : *this) {
|
||||
if (factor) keys.insert(factor->begin(), factor->end());
|
||||
}
|
||||
return keys;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteKeys DiscreteFactorGraph::discreteKeys() const {
|
||||
DiscreteKeys result;
|
||||
for (auto&& factor : *this) {
|
||||
if (auto p = boost::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
|
||||
DiscreteKeys factor_keys = p->discreteKeys();
|
||||
result.insert(result.end(), factor_keys.begin(), factor_keys.end());
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DecisionTreeFactor DiscreteFactorGraph::product() const {
|
||||
DecisionTreeFactor result;
|
||||
|
|
|
|||
|
|
@ -115,6 +115,9 @@ class GTSAM_EXPORT DiscreteFactorGraph
|
|||
/** Return the set of variables involved in the factors (set union) */
|
||||
KeySet keys() const;
|
||||
|
||||
/// Return the DiscreteKeys in this factor graph.
|
||||
DiscreteKeys discreteKeys() const;
|
||||
|
||||
/** return product of all factors as a single factor */
|
||||
DecisionTreeFactor product() const;
|
||||
|
||||
|
|
|
|||
|
|
@ -28,8 +28,8 @@
|
|||
namespace gtsam {
|
||||
|
||||
/**
|
||||
* Key type for discrete conditionals
|
||||
* Includes name and cardinality
|
||||
* Key type for discrete variables.
|
||||
* Includes Key and cardinality.
|
||||
*/
|
||||
using DiscreteKey = std::pair<Key,size_t>;
|
||||
|
||||
|
|
@ -45,6 +45,11 @@ namespace gtsam {
|
|||
/// Construct from a key
|
||||
explicit DiscreteKeys(const DiscreteKey& key) { push_back(key); }
|
||||
|
||||
/// Construct from cardinalities.
|
||||
explicit DiscreteKeys(std::map<Key, size_t> cardinalities) {
|
||||
for (auto&& kv : cardinalities) emplace_back(kv);
|
||||
}
|
||||
|
||||
/// Construct from a vector of keys
|
||||
DiscreteKeys(const std::vector<DiscreteKey>& keys) :
|
||||
std::vector<DiscreteKey>(keys) {
|
||||
|
|
|
|||
|
|
@ -37,6 +37,8 @@ class GTSAM_EXPORT DiscreteMarginals {
|
|||
|
||||
public:
|
||||
|
||||
DiscreteMarginals() {}
|
||||
|
||||
/** Construct a marginals class.
|
||||
* @param graph The factor graph defining the full joint density on all variables.
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -375,6 +375,31 @@ TEST(DecisionTree, labels) {
|
|||
EXPECT_LONGS_EQUAL(2, labels.size());
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
// Test retrieving all labels.
|
||||
TEST(DecisionTree, unzip) {
|
||||
using DTP = DecisionTree<string, std::pair<int, string>>;
|
||||
using DT1 = DecisionTree<string, int>;
|
||||
using DT2 = DecisionTree<string, string>;
|
||||
|
||||
// Create small two-level tree
|
||||
string A("A"), B("B"), C("C");
|
||||
DTP tree(B,
|
||||
DTP(A, {0, "zero"}, {1, "one"}),
|
||||
DTP(A, {2, "two"}, {1337, "l33t"})
|
||||
);
|
||||
|
||||
DT1 dt1;
|
||||
DT2 dt2;
|
||||
std::tie(dt1, dt2) = unzip(tree);
|
||||
|
||||
DT1 tree1(B, DT1(A, 0, 1), DT1(A, 2, 1337));
|
||||
DT2 tree2(B, DT2(A, "zero", "one"), DT2(A, "two", "l33t"));
|
||||
|
||||
EXPECT(tree1.equals(dt1));
|
||||
EXPECT(tree2.equals(dt2));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
|||
|
|
@ -158,7 +158,6 @@ typedef FastSet<FactorIndex> FactorIndexSet;
|
|||
|
||||
/// @}
|
||||
|
||||
public:
|
||||
/// @name Advanced Interface
|
||||
/// @{
|
||||
|
||||
|
|
|
|||
|
|
@ -128,6 +128,11 @@ class FactorGraph {
|
|||
/** Collection of factors */
|
||||
FastVector<sharedFactor> factors_;
|
||||
|
||||
/// Check exact equality of the factor pointers. Useful for derived ==.
|
||||
bool isEqual(const FactorGraph& other) const {
|
||||
return factors_ == other.factors_;
|
||||
}
|
||||
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
|
|
@ -290,11 +295,11 @@ class FactorGraph {
|
|||
/// @name Testable
|
||||
/// @{
|
||||
|
||||
/// print out graph
|
||||
/// Print out graph to std::cout, with optional key formatter.
|
||||
virtual void print(const std::string& s = "FactorGraph",
|
||||
const KeyFormatter& formatter = DefaultKeyFormatter) const;
|
||||
|
||||
/** Check equality */
|
||||
/// Check equality up to tolerance.
|
||||
bool equals(const This& fg, double tol = 1e-9) const;
|
||||
/// @}
|
||||
|
||||
|
|
|
|||
|
|
@ -23,8 +23,8 @@
|
|||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<class FACTOR>
|
||||
void MetisIndex::augment(const FactorGraph<FACTOR>& factors) {
|
||||
template<class FACTORGRAPH>
|
||||
void MetisIndex::augment(const FACTORGRAPH& factors) {
|
||||
std::map<int32_t, std::set<int32_t> > iAdjMap; // Stores a set of keys that are adjacent to key x, with adjMap.first
|
||||
std::map<int32_t, std::set<int32_t> >::iterator iAdjMapIt;
|
||||
std::set<Key> keySet;
|
||||
|
|
|
|||
|
|
@ -62,8 +62,8 @@ public:
|
|||
nKeys_(0) {
|
||||
}
|
||||
|
||||
template<class FG>
|
||||
MetisIndex(const FG& factorGraph) :
|
||||
template<class FACTORGRAPH>
|
||||
MetisIndex(const FACTORGRAPH& factorGraph) :
|
||||
nKeys_(0) {
|
||||
augment(factorGraph);
|
||||
}
|
||||
|
|
@ -78,8 +78,8 @@ public:
|
|||
* Augment the variable index with new factors. This can be used when
|
||||
* solving problems incrementally.
|
||||
*/
|
||||
template<class FACTOR>
|
||||
void augment(const FactorGraph<FACTOR>& factors);
|
||||
template<class FACTORGRAPH>
|
||||
void augment(const FACTORGRAPH& factors);
|
||||
|
||||
const std::vector<int32_t>& xadj() const {
|
||||
return xadj_;
|
||||
|
|
|
|||
|
|
@ -99,6 +99,12 @@ namespace gtsam {
|
|||
|
||||
/// @}
|
||||
|
||||
/// Check exact equality.
|
||||
friend bool operator==(const GaussianFactorGraph& lhs,
|
||||
const GaussianFactorGraph& rhs) {
|
||||
return lhs.isEqual(rhs);
|
||||
}
|
||||
|
||||
/** Add a factor by value - makes a copy */
|
||||
void add(const GaussianFactor& factor) { push_back(factor.clone()); }
|
||||
|
||||
|
|
@ -414,7 +420,7 @@ namespace gtsam {
|
|||
*/
|
||||
GTSAM_EXPORT bool hasConstraints(const GaussianFactorGraph& factors);
|
||||
|
||||
/****** Linear Algebra Opeations ******/
|
||||
/****** Linear Algebra Operations ******/
|
||||
|
||||
///* matrix-vector operations */
|
||||
//GTSAM_EXPORT void residual(const GaussianFactorGraph& fg, const VectorValues &x, VectorValues &r);
|
||||
|
|
|
|||
Loading…
Reference in New Issue