Merge pull request #1055 from borglab/feature/hybrid_base

release/4.3a0
Frank Dellaert 2022-01-22 19:46:16 -05:00 committed by GitHub
commit 9d71c90aff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 189 additions and 24 deletions

13
gtsam/base/utilities.cpp Normal file
View File

@ -0,0 +1,13 @@
#include <gtsam/base/utilities.h>
namespace gtsam {
std::string RedirectCout::str() const {
return ssBuffer_.str();
}
RedirectCout::~RedirectCout() {
std::cout.rdbuf(coutBuffer_);
}
}

View File

@ -1,5 +1,9 @@
#pragma once
#include <string>
#include <iostream>
#include <sstream>
namespace gtsam {
/**
* For Python __str__().
@ -12,14 +16,10 @@ struct RedirectCout {
RedirectCout() : ssBuffer_(), coutBuffer_(std::cout.rdbuf(ssBuffer_.rdbuf())) {}
/// return the string
std::string str() const {
return ssBuffer_.str();
}
std::string str() const;
/// destructor -- redirect stdout buffer to its original buffer
~RedirectCout() {
std::cout.rdbuf(coutBuffer_);
}
~RedirectCout();
private:
std::stringstream ssBuffer_;

View File

@ -163,7 +163,7 @@ namespace gtsam {
const typename Base::LabelFormatter& labelFormatter =
&DefaultFormatter) const {
auto valueFormatter = [](const double& v) {
return (boost::format("%4.2g") % v).str();
return (boost::format("%4.4g") % v).str();
};
Base::print(s, labelFormatter, valueFormatter);
}

View File

@ -604,7 +604,7 @@ namespace gtsam {
using MXChoice = typename DecisionTree<M, X>::Choice;
auto choice = boost::dynamic_pointer_cast<const MXChoice>(f);
if (!choice) throw std::invalid_argument(
"DecisionTree::Convert: Invalid NodePtr");
"DecisionTree::convertFrom: Invalid NodePtr");
// get new label
const M oldLabel = choice->label();
@ -634,6 +634,8 @@ namespace gtsam {
using Choice = typename DecisionTree<L, Y>::Choice;
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
if (!choice)
throw std::invalid_argument("DecisionTree::Visit: Invalid NodePtr");
for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
}
};
@ -663,6 +665,8 @@ namespace gtsam {
using Choice = typename DecisionTree<L, Y>::Choice;
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
if (!choice)
throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr");
for (size_t i = 0; i < choice->nrChoices(); i++) {
choices[choice->label()] = i; // Set assignment for label to i
(*this)(choice->branches()[i]); // recurse!

View File

@ -38,7 +38,7 @@ namespace gtsam {
* Y = function range (any algebra), e.g., bool, int, double
*/
template<typename L, typename Y>
class GTSAM_EXPORT DecisionTree {
class DecisionTree {
protected:
/// Default method for comparison of two objects of type Y.
@ -340,4 +340,11 @@ namespace gtsam {
return f.apply(g, op);
}
/// unzip a DecisionTree if its leaves are `std::pair`
template<typename L, typename T1, typename T2>
std::pair<DecisionTree<L, T1>, DecisionTree<L, T2> > unzip(const DecisionTree<L, std::pair<T1, T2> > &input) {
return std::make_pair(DecisionTree<L, T1>(input, [](std::pair<T1, T2> i) { return i.first; }),
DecisionTree<L, T2>(input, [](std::pair<T1, T2> i) { return i.second; }));
}
} // namespace gtsam

View File

@ -72,7 +72,7 @@ namespace gtsam {
for (auto&& key : keys())
cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key);
cout << " ]" << endl;
ADT::print("Potentials:", formatter);
ADT::print("", formatter);
}
/* ************************************************************************* */
@ -168,6 +168,18 @@ namespace gtsam {
return result;
}
/* ************************************************************************* */
DiscreteKeys DecisionTreeFactor::discreteKeys() const {
DiscreteKeys result;
for (auto&& key : keys()) {
DiscreteKey dkey(key, cardinality(key));
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
result.push_back(dkey);
}
}
return result;
}
/* ************************************************************************* */
static std::string valueFormatter(const double& v) {
return (boost::format("%4.2g") % v).str();

View File

@ -188,6 +188,9 @@ namespace gtsam {
/// Enumerate all values into a map from values to double.
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
/// Return all the discrete keys associated with this factor.
DiscreteKeys discreteKeys() const;
/// @}
/// @name Wrapper support
/// @{

View File

@ -17,12 +17,59 @@
* @author Frank Dellaert
*/
#include <gtsam/base/Vector.h>
#include <gtsam/discrete/DiscreteFactor.h>
#include <cmath>
#include <sstream>
using namespace std;
namespace gtsam {
/* ************************************************************************* */
std::vector<double> expNormalize(const std::vector<double>& logProbs) {
double maxLogProb = -std::numeric_limits<double>::infinity();
for (size_t i = 0; i < logProbs.size(); i++) {
double logProb = logProbs[i];
if ((logProb != std::numeric_limits<double>::infinity()) &&
logProb > maxLogProb) {
maxLogProb = logProb;
}
}
// After computing the max = "Z" of the log probabilities L_i, we compute
// the log of the normalizing constant, log S, where S = sum_j exp(L_j - Z).
double total = 0.0;
for (size_t i = 0; i < logProbs.size(); i++) {
double probPrime = exp(logProbs[i] - maxLogProb);
total += probPrime;
}
double logTotal = log(total);
// Now we compute the (normalized) probability (for each i):
// p_i = exp(L_i - Z - log S)
double checkNormalization = 0.0;
std::vector<double> probs;
for (size_t i = 0; i < logProbs.size(); i++) {
double prob = exp(logProbs[i] - maxLogProb - logTotal);
probs.push_back(prob);
checkNormalization += prob;
}
// Numerical tolerance for floating point comparisons
double tol = 1e-9;
if (!gtsam::fpEqual(checkNormalization, 1.0, tol)) {
std::string errMsg =
std::string("expNormalize failed to normalize probabilities. ") +
std::string("Expected normalization constant = 1.0. Got value: ") +
std::to_string(checkNormalization) +
std::string(
"\n This could have resulted from numerical overflow/underflow.");
throw std::logic_error(errMsg);
}
return probs;
}
} // namespace gtsam

View File

@ -122,4 +122,24 @@ public:
// traits
template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
/**
* @brief Normalize a set of log probabilities.
*
* Normalizing a set of log probabilities in a numerically stable way is
* tricky. To avoid overflow/underflow issues, we compute the largest
* (finite) log probability and subtract it from each log probability before
* normalizing. This comes from the observation that if:
* p_i = exp(L_i) / ( sum_j exp(L_j) ),
* Then,
* p_i = exp(Z) exp(L_i - Z) / (exp(Z) sum_j exp(L_j - Z)),
* = exp(L_i - Z) / ( sum_j exp(L_j - Z) )
*
* Setting Z = max_j L_j, we can avoid numerical issues that arise when all
* of the (unnormalized) log probabilities are either very large or very
* small.
*/
std::vector<double> expNormalize(const std::vector<double> &logProbs);
}// namespace gtsam

View File

@ -44,11 +44,25 @@ namespace gtsam {
/* ************************************************************************* */
KeySet DiscreteFactorGraph::keys() const {
KeySet keys;
for(const sharedFactor& factor: *this)
if (factor) keys.insert(factor->begin(), factor->end());
for (const sharedFactor& factor : *this) {
if (factor) keys.insert(factor->begin(), factor->end());
}
return keys;
}
/* ************************************************************************* */
DiscreteKeys DiscreteFactorGraph::discreteKeys() const {
DiscreteKeys result;
for (auto&& factor : *this) {
if (auto p = boost::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
DiscreteKeys factor_keys = p->discreteKeys();
result.insert(result.end(), factor_keys.begin(), factor_keys.end());
}
}
return result;
}
/* ************************************************************************* */
DecisionTreeFactor DiscreteFactorGraph::product() const {
DecisionTreeFactor result;

View File

@ -115,6 +115,9 @@ class GTSAM_EXPORT DiscreteFactorGraph
/** Return the set of variables involved in the factors (set union) */
KeySet keys() const;
/// Return the DiscreteKeys in this factor graph.
DiscreteKeys discreteKeys() const;
/** return product of all factors as a single factor */
DecisionTreeFactor product() const;

View File

@ -28,8 +28,8 @@
namespace gtsam {
/**
* Key type for discrete conditionals
* Includes name and cardinality
* Key type for discrete variables.
* Includes Key and cardinality.
*/
using DiscreteKey = std::pair<Key,size_t>;
@ -45,6 +45,11 @@ namespace gtsam {
/// Construct from a key
explicit DiscreteKeys(const DiscreteKey& key) { push_back(key); }
/// Construct from cardinalities.
explicit DiscreteKeys(std::map<Key, size_t> cardinalities) {
for (auto&& kv : cardinalities) emplace_back(kv);
}
/// Construct from a vector of keys
DiscreteKeys(const std::vector<DiscreteKey>& keys) :
std::vector<DiscreteKey>(keys) {

View File

@ -37,6 +37,8 @@ class GTSAM_EXPORT DiscreteMarginals {
public:
DiscreteMarginals() {}
/** Construct a marginals class.
* @param graph The factor graph defining the full joint density on all variables.
*/

View File

@ -375,6 +375,31 @@ TEST(DecisionTree, labels) {
EXPECT_LONGS_EQUAL(2, labels.size());
}
/* ******************************************************************************** */
// Test retrieving all labels.
TEST(DecisionTree, unzip) {
using DTP = DecisionTree<string, std::pair<int, string>>;
using DT1 = DecisionTree<string, int>;
using DT2 = DecisionTree<string, string>;
// Create small two-level tree
string A("A"), B("B"), C("C");
DTP tree(B,
DTP(A, {0, "zero"}, {1, "one"}),
DTP(A, {2, "two"}, {1337, "l33t"})
);
DT1 dt1;
DT2 dt2;
std::tie(dt1, dt2) = unzip(tree);
DT1 tree1(B, DT1(A, 0, 1), DT1(A, 2, 1337));
DT2 tree2(B, DT2(A, "zero", "one"), DT2(A, "two", "l33t"));
EXPECT(tree1.equals(dt1));
EXPECT(tree2.equals(dt2));
}
/* ************************************************************************* */
int main() {
TestResult tr;

View File

@ -158,7 +158,6 @@ typedef FastSet<FactorIndex> FactorIndexSet;
/// @}
public:
/// @name Advanced Interface
/// @{

View File

@ -128,6 +128,11 @@ class FactorGraph {
/** Collection of factors */
FastVector<sharedFactor> factors_;
/// Check exact equality of the factor pointers. Useful for derived ==.
bool isEqual(const FactorGraph& other) const {
return factors_ == other.factors_;
}
/// @name Standard Constructors
/// @{
@ -290,11 +295,11 @@ class FactorGraph {
/// @name Testable
/// @{
/// print out graph
/// Print out graph to std::cout, with optional key formatter.
virtual void print(const std::string& s = "FactorGraph",
const KeyFormatter& formatter = DefaultKeyFormatter) const;
/** Check equality */
/// Check equality up to tolerance.
bool equals(const This& fg, double tol = 1e-9) const;
/// @}

View File

@ -23,8 +23,8 @@
namespace gtsam {
/* ************************************************************************* */
template<class FACTOR>
void MetisIndex::augment(const FactorGraph<FACTOR>& factors) {
template<class FACTORGRAPH>
void MetisIndex::augment(const FACTORGRAPH& factors) {
std::map<int32_t, std::set<int32_t> > iAdjMap; // Stores a set of keys that are adjacent to key x, with adjMap.first
std::map<int32_t, std::set<int32_t> >::iterator iAdjMapIt;
std::set<Key> keySet;

View File

@ -62,8 +62,8 @@ public:
nKeys_(0) {
}
template<class FG>
MetisIndex(const FG& factorGraph) :
template<class FACTORGRAPH>
MetisIndex(const FACTORGRAPH& factorGraph) :
nKeys_(0) {
augment(factorGraph);
}
@ -78,8 +78,8 @@ public:
* Augment the variable index with new factors. This can be used when
* solving problems incrementally.
*/
template<class FACTOR>
void augment(const FactorGraph<FACTOR>& factors);
template<class FACTORGRAPH>
void augment(const FACTORGRAPH& factors);
const std::vector<int32_t>& xadj() const {
return xadj_;

View File

@ -99,6 +99,12 @@ namespace gtsam {
/// @}
/// Check exact equality.
friend bool operator==(const GaussianFactorGraph& lhs,
const GaussianFactorGraph& rhs) {
return lhs.isEqual(rhs);
}
/** Add a factor by value - makes a copy */
void add(const GaussianFactor& factor) { push_back(factor.clone()); }
@ -414,7 +420,7 @@ namespace gtsam {
*/
GTSAM_EXPORT bool hasConstraints(const GaussianFactorGraph& factors);
/****** Linear Algebra Opeations ******/
/****** Linear Algebra Operations ******/
///* matrix-vector operations */
//GTSAM_EXPORT void residual(const GaussianFactorGraph& fg, const VectorValues &x, VectorValues &r);