Remove Potentials

release/4.3a0
Fan Jiang 2022-01-06 20:10:03 -05:00
parent 0683dfaa67
commit f65bd4d90d
10 changed files with 159 additions and 188 deletions

View File

@ -179,5 +179,6 @@ namespace gtsam {
};
// AlgebraicDecisionTree
template<typename T> struct traits<AlgebraicDecisionTree<T>> : public Testable<AlgebraicDecisionTree<T>> {};
}
// namespace gtsam

View File

@ -34,12 +34,13 @@ namespace gtsam {
/* ******************************************************************************** */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const ADT& potentials) :
DiscreteFactor(keys.indices()), Potentials(keys, potentials) {
DiscreteFactor(keys.indices()), ADT(potentials),
cardinalities_(keys.cardinalities()) {
}
/* *************************************************************************/
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) :
DiscreteFactor(c.keys()), Potentials(c) {
DiscreteFactor(c.keys()), AlgebraicDecisionTree<Key>(c), cardinalities_(c.cardinalities_) {
}
/* ************************************************************************* */
@ -48,16 +49,25 @@ namespace gtsam {
return false;
}
else {
const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other));
return Potentials::equals(f, tol);
const auto& f(static_cast<const DecisionTreeFactor&>(other));
return ADT::equals(f, tol);
}
}
/* ************************************************************************* */
double DecisionTreeFactor::safe_div(const double &a, const double &b) {
// cout << boost::format("%g / %g = %g\n") % a % b % ((a == 0) ? 0 : (a / b));
// The use for safe_div is when we divide the product factor by the sum
// factor. If the product or sum is zero, we accord zero probability to the
// event.
return (a == 0 || b == 0) ? 0 : (a / b);
}
/* ************************************************************************* */
void DecisionTreeFactor::print(const string& s,
const KeyFormatter& formatter) const {
cout << s;
Potentials::print("Potentials:",formatter);
ADT::print("Potentials:",formatter);
}
/* ************************************************************************* */
@ -162,20 +172,20 @@ namespace gtsam {
void DecisionTreeFactor::dot(std::ostream& os,
const KeyFormatter& keyFormatter,
bool showZero) const {
Potentials::dot(os, keyFormatter, valueFormatter, showZero);
ADT::dot(os, keyFormatter, valueFormatter, showZero);
}
/** output to graphviz format, open a file */
void DecisionTreeFactor::dot(const std::string& name,
const KeyFormatter& keyFormatter,
bool showZero) const {
Potentials::dot(name, keyFormatter, valueFormatter, showZero);
ADT::dot(name, keyFormatter, valueFormatter, showZero);
}
/** output to graphviz format string */
std::string DecisionTreeFactor::dot(const KeyFormatter& keyFormatter,
bool showZero) const {
return Potentials::dot(keyFormatter, valueFormatter, showZero);
return ADT::dot(keyFormatter, valueFormatter, showZero);
}
/* ************************************************************************* */
@ -209,5 +219,15 @@ namespace gtsam {
return ss.str();
}
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const vector<double> &table) :
DiscreteFactor(keys.indices()), AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {
}
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const string &table) :
DiscreteFactor(keys.indices()), AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {
}
/* ************************************************************************* */
} // namespace gtsam

View File

@ -19,7 +19,8 @@
#pragma once
#include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/discrete/Potentials.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/inference/Ordering.h>
#include <boost/shared_ptr.hpp>
@ -35,7 +36,7 @@ namespace gtsam {
/**
* A discrete probabilistic factor
*/
class GTSAM_EXPORT DecisionTreeFactor: public DiscreteFactor, public Potentials {
class GTSAM_EXPORT DecisionTreeFactor: public DiscreteFactor, public AlgebraicDecisionTree<Key> {
public:
@ -43,6 +44,10 @@ namespace gtsam {
typedef DecisionTreeFactor This;
typedef DiscreteFactor Base; ///< Typedef to base class
typedef boost::shared_ptr<DecisionTreeFactor> shared_ptr;
typedef AlgebraicDecisionTree<Key> ADT;
protected:
std::map<Key,size_t> cardinalities_;
public:
@ -55,11 +60,11 @@ namespace gtsam {
/** Constructor from Indices, Ordering, and AlgebraicDecisionDiagram */
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
/** Constructor from Indices and (string or doubles) */
template<class SOURCE>
DecisionTreeFactor(const DiscreteKeys& keys, SOURCE table) :
DiscreteFactor(keys.indices()), Potentials(keys, table) {
}
/** Constructor from doubles */
DecisionTreeFactor(const DiscreteKeys& keys, const std::vector<double>& table);
/** Constructor from string */
DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);
/// Single-key specialization
template <class SOURCE>
@ -71,7 +76,7 @@ namespace gtsam {
: DecisionTreeFactor(DiscreteKeys{key}, row) {}
/** Construct from a DiscreteConditional type */
DecisionTreeFactor(const DiscreteConditional& c);
explicit DecisionTreeFactor(const DiscreteConditional& c);
/// @}
/// @name Testable
@ -90,7 +95,7 @@ namespace gtsam {
/// Value is just look up in AlgebraicDecisonTree
double operator()(const DiscreteValues& values) const override {
return Potentials::operator()(values);
return ADT::operator()(values);
}
/// multiply two factors
@ -98,6 +103,10 @@ namespace gtsam {
return apply(f, ADT::Ring::mul);
}
static double safe_div(const double& a, const double& b);
size_t cardinality(Key j) const { return cardinalities_.at(j);}
/// divide by factor f (safely)
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
return apply(f, safe_div);

View File

@ -80,7 +80,7 @@ void DiscreteConditional::print(const string& s,
}
}
cout << ")";
Potentials::print("");
ADT::print("");
cout << endl;
}

View File

@ -128,7 +128,7 @@ public:
/// Evaluate, just look up in AlgebraicDecisonTree
double operator()(const DiscreteValues& values) const override {
return Potentials::operator()(values);
return ADT::operator()(values);
}
/** Convert to a factor */

View File

@ -1,96 +1,96 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file Potentials.cpp
* @date March 24, 2011
* @author Frank Dellaert
*/
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/Potentials.h>
#include <boost/format.hpp>
#include <string>
using namespace std;
namespace gtsam {
/* ************************************************************************* */
double Potentials::safe_div(const double& a, const double& b) {
// cout << boost::format("%g / %g = %g\n") % a % b % ((a == 0) ? 0 : (a / b));
// The use for safe_div is when we divide the product factor by the sum
// factor. If the product or sum is zero, we accord zero probability to the
// event.
return (a == 0 || b == 0) ? 0 : (a / b);
}
/* ********************************************************************************
*/
Potentials::Potentials() : ADT(1.0) {}
/* ********************************************************************************
*/
Potentials::Potentials(const DiscreteKeys& keys, const ADT& decisionTree)
: ADT(decisionTree), cardinalities_(keys.cardinalities()) {}
/* ************************************************************************* */
bool Potentials::equals(const Potentials& other, double tol) const {
return ADT::equals(other, tol);
}
/* ************************************************************************* */
void Potentials::print(const string& s, const KeyFormatter& formatter) const {
cout << s << "\n Cardinalities: { ";
for (const std::pair<const Key,size_t>& key : cardinalities_)
cout << formatter(key.first) << ":" << key.second << ", ";
cout << "}" << endl;
ADT::print(" ", formatter);
}
///* ----------------------------------------------------------------------------
//
// * GTSAM Copyright 2010, Georgia Tech Research Corporation,
// * Atlanta, Georgia 30332-0415
// * All Rights Reserved
// * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
//
// * See LICENSE for the license information
//
// * -------------------------------------------------------------------------- */
//
///**
// * @file Potentials.cpp
// * @date March 24, 2011
// * @author Frank Dellaert
// */
//
//#include <gtsam/discrete/DecisionTree-inl.h>
//#include <gtsam/discrete/Potentials.h>
//
//#include <boost/format.hpp>
//
//#include <string>
//
//using namespace std;
//
//namespace gtsam {
//
///* ************************************************************************* */
//double Potentials::safe_div(const double& a, const double& b) {
// // cout << boost::format("%g / %g = %g\n") % a % b % ((a == 0) ? 0 : (a / b));
// // The use for safe_div is when we divide the product factor by the sum
// // factor. If the product or sum is zero, we accord zero probability to the
// // event.
// return (a == 0 || b == 0) ? 0 : (a / b);
//}
//
///* ********************************************************************************
// */
//Potentials::Potentials() : ADT(1.0) {}
//
///* ********************************************************************************
// */
//Potentials::Potentials(const DiscreteKeys& keys, const ADT& decisionTree)
// : ADT(decisionTree), cardinalities_(keys.cardinalities()) {}
//
///* ************************************************************************* */
//bool Potentials::equals(const Potentials& other, double tol) const {
// return ADT::equals(other, tol);
//}
//
///* ************************************************************************* */
//void Potentials::print(const string& s, const KeyFormatter& formatter) const {
// cout << s << "\n Cardinalities: { ";
// for (const std::pair<const Key,size_t>& key : cardinalities_)
// cout << formatter(key.first) << ":" << key.second << ", ";
// cout << "}" << endl;
// ADT::print(" ", formatter);
//}
////
//// /* ************************************************************************* */
//// template<class P>
//// void Potentials::remapIndices(const P& remapping) {
//// // Permute the _cardinalities (TODO: Inefficient Consider Improving)
//// DiscreteKeys keys;
//// map<Key, Key> ordering;
////
//// // Get the original keys from cardinalities_
//// for(const DiscreteKey& key: cardinalities_)
//// keys & key;
////
//// // Perform Permutation
//// for(DiscreteKey& key: keys) {
//// ordering[key.first] = remapping[key.first];
//// key.first = ordering[key.first];
//// }
////
//// // Change *this
//// AlgebraicDecisionTree<Key> permuted((*this), ordering);
//// *this = permuted;
//// cardinalities_ = keys.cardinalities();
//// }
////
//// /* ************************************************************************* */
//// void Potentials::permuteWithInverse(const Permutation& inversePermutation) {
//// remapIndices(inversePermutation);
//// }
////
//// /* ************************************************************************* */
//// void Potentials::reduceWithInverse(const internal::Reduction& inverseReduction) {
//// remapIndices(inverseReduction);
//// }
//
// /* ************************************************************************* */
// template<class P>
// void Potentials::remapIndices(const P& remapping) {
// // Permute the _cardinalities (TODO: Inefficient Consider Improving)
// DiscreteKeys keys;
// map<Key, Key> ordering;
//
// // Get the original keys from cardinalities_
// for(const DiscreteKey& key: cardinalities_)
// keys & key;
//
// // Perform Permutation
// for(DiscreteKey& key: keys) {
// ordering[key.first] = remapping[key.first];
// key.first = ordering[key.first];
// }
//
// // Change *this
// AlgebraicDecisionTree<Key> permuted((*this), ordering);
// *this = permuted;
// cardinalities_ = keys.cardinalities();
// }
//
// /* ************************************************************************* */
// void Potentials::permuteWithInverse(const Permutation& inversePermutation) {
// remapIndices(inversePermutation);
// }
//
// /* ************************************************************************* */
// void Potentials::reduceWithInverse(const internal::Reduction& inverseReduction) {
// remapIndices(inverseReduction);
// }
/* ************************************************************************* */
} // namespace gtsam
//} // namespace gtsam

View File

@ -26,72 +26,10 @@
namespace gtsam {
/**
* A base class for both DiscreteFactor and DiscreteConditional
*/
class GTSAM_EXPORT Potentials: public AlgebraicDecisionTree<Key> {
public:
typedef AlgebraicDecisionTree<Key> ADT;
protected:
/// Cardinality for each key, used in combine
std::map<Key,size_t> cardinalities_;
/** Constructor from ColumnIndex, and ADT */
Potentials(const ADT& potentials) :
ADT(potentials) {
}
// Safe division for probabilities
static double safe_div(const double& a, const double& b);
// // Apply either a permutation or a reduction
// template<class P>
// void remapIndices(const P& remapping);
public:
/** Default constructor for I/O */
Potentials();
/** Constructor from Indices and ADT */
Potentials(const DiscreteKeys& keys, const ADT& decisionTree);
/** Constructor from Indices and (string or doubles) */
template<class SOURCE>
Potentials(const DiscreteKeys& keys, SOURCE table) :
ADT(keys, table), cardinalities_(keys.cardinalities()) {
}
// Testable
bool equals(const Potentials& other, double tol = 1e-9) const;
void print(const std::string& s = "Potentials: ",
const KeyFormatter& formatter = DefaultKeyFormatter) const;
size_t cardinality(Key j) const { return cardinalities_.at(j);}
// /**
// * @brief Permutes the keys in Potentials
// *
// * This permutes the Indices and performs necessary re-ordering of ADD.
// * This is virtual so that derived types e.g. DecisionTreeFactor can
// * re-implement it.
// */
// GTSAM_EXPORT virtual void permuteWithInverse(const Permutation& inversePermutation);
//
// /**
// * Apply a reduction, which is a remapping of variable indices.
// */
// GTSAM_EXPORT virtual void reduceWithInverse(const internal::Reduction& inverseReduction);
}; // Potentials
// traits
template<> struct traits<Potentials> : public Testable<Potentials> {};
template<> struct traits<Potentials::ADT> : public Testable<Potentials::ADT> {};
/*
* @deprecated
* @brief Deprecated class for storing an ADT with some convenience methods
* */
typedef GTSAM_DEPRECATED AlgebraicDecisionTree<Key> Potentials;
} // namespace gtsam

View File

@ -41,21 +41,23 @@ using namespace gtsam;
static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2),
LungCancer(6, 2), Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2);
using ADT = AlgebraicDecisionTree<Key>;
/* ************************************************************************* */
TEST(DiscreteBayesNet, bayesNet) {
DiscreteBayesNet bayesNet;
DiscreteKey Parent(0, 2), Child(1, 2);
auto prior = boost::make_shared<DiscreteConditional>(Parent % "6/4");
CHECK(assert_equal(Potentials::ADT({Parent}, "0.6 0.4"),
(Potentials::ADT)*prior));
CHECK(assert_equal(ADT({Parent}, "0.6 0.4"),
(ADT)*prior));
bayesNet.push_back(prior);
auto conditional =
boost::make_shared<DiscreteConditional>(Child | Parent = "7/3 8/2");
EXPECT_LONGS_EQUAL(1, *(conditional->beginFrontals()));
Potentials::ADT expected(Child & Parent, "0.7 0.8 0.3 0.2");
CHECK(assert_equal(expected, (Potentials::ADT)*conditional));
ADT expected(Child & Parent, "0.7 0.8 0.3 0.2");
CHECK(assert_equal(expected, (ADT)*conditional));
bayesNet.push_back(conditional);
DiscreteFactorGraph fg(bayesNet);

View File

@ -72,8 +72,9 @@ namespace gtsam {
public:
typedef CLIQUE Clique; ///< The clique type, normally BayesTreeClique
typedef boost::shared_ptr<Clique> sharedClique; ///< Shared pointer to a clique
typedef Clique Node; ///< Synonym for Clique (TODO: remove)
typedef sharedClique sharedNode; ///< Synonym for sharedClique (TODO: remove)
typedef GTSAM_DEPRECATED Clique Node; ///< Synonym for Clique (TODO: remove)
typedef GTSAM_DEPRECATED sharedClique sharedNode; ///< Synonym for sharedClique (TODO: remove)
typedef typename CLIQUE::ConditionalType ConditionalType;
typedef boost::shared_ptr<ConditionalType> sharedConditional;
typedef typename CLIQUE::BayesNetType BayesNetType;
@ -143,7 +144,7 @@ namespace gtsam {
const Nodes& nodes() const { return nodes_; }
/** Access node by variable */
const sharedNode operator[](Key j) const { return nodes_.at(j); }
sharedClique operator[](Key j) const { return nodes_.at(j); }
/** return root cliques */
const Roots& roots() const { return roots_; }

View File

@ -130,9 +130,9 @@ void Scheduler::addStudentSpecificConstraints(size_t i,
// get all constraints then specialize to slot
size_t dummyIndex = maxNrStudents_ * 3 + maxNrStudents_;
DiscreteKey dummy(dummyIndex, nrTimeSlots());
Potentials::ADT p(dummy & areaKey,
AlgebraicDecisionTree<Key> p(dummy & areaKey,
available_); // available_ is Doodle string
Potentials::ADT q = p.choose(dummyIndex, *slot);
auto q = p.choose(dummyIndex, *slot);
CSP::add(areaKey, q);
} else {
DiscreteKeys keys {s.key_, areaKey};