Remove Potentials

release/4.3a0
Fan Jiang 2022-01-06 20:10:03 -05:00
parent 0683dfaa67
commit f65bd4d90d
10 changed files with 159 additions and 188 deletions

View File

@ -179,5 +179,6 @@ namespace gtsam {
}; };
// AlgebraicDecisionTree // AlgebraicDecisionTree
template<typename T> struct traits<AlgebraicDecisionTree<T>> : public Testable<AlgebraicDecisionTree<T>> {};
} }
// namespace gtsam // namespace gtsam

View File

@ -34,12 +34,13 @@ namespace gtsam {
/* ******************************************************************************** */ /* ******************************************************************************** */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const ADT& potentials) : const ADT& potentials) :
DiscreteFactor(keys.indices()), Potentials(keys, potentials) { DiscreteFactor(keys.indices()), ADT(potentials),
cardinalities_(keys.cardinalities()) {
} }
/* *************************************************************************/ /* *************************************************************************/
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) : DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) :
DiscreteFactor(c.keys()), Potentials(c) { DiscreteFactor(c.keys()), AlgebraicDecisionTree<Key>(c), cardinalities_(c.cardinalities_) {
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -48,16 +49,25 @@ namespace gtsam {
return false; return false;
} }
else { else {
const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other)); const auto& f(static_cast<const DecisionTreeFactor&>(other));
return Potentials::equals(f, tol); return ADT::equals(f, tol);
} }
} }
/* ************************************************************************* */
double DecisionTreeFactor::safe_div(const double &a, const double &b) {
// cout << boost::format("%g / %g = %g\n") % a % b % ((a == 0) ? 0 : (a / b));
// The use for safe_div is when we divide the product factor by the sum
// factor. If the product or sum is zero, we accord zero probability to the
// event.
return (a == 0 || b == 0) ? 0 : (a / b);
}
/* ************************************************************************* */ /* ************************************************************************* */
void DecisionTreeFactor::print(const string& s, void DecisionTreeFactor::print(const string& s,
const KeyFormatter& formatter) const { const KeyFormatter& formatter) const {
cout << s; cout << s;
Potentials::print("Potentials:",formatter); ADT::print("Potentials:",formatter);
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -162,20 +172,20 @@ namespace gtsam {
void DecisionTreeFactor::dot(std::ostream& os, void DecisionTreeFactor::dot(std::ostream& os,
const KeyFormatter& keyFormatter, const KeyFormatter& keyFormatter,
bool showZero) const { bool showZero) const {
Potentials::dot(os, keyFormatter, valueFormatter, showZero); ADT::dot(os, keyFormatter, valueFormatter, showZero);
} }
/** output to graphviz format, open a file */ /** output to graphviz format, open a file */
void DecisionTreeFactor::dot(const std::string& name, void DecisionTreeFactor::dot(const std::string& name,
const KeyFormatter& keyFormatter, const KeyFormatter& keyFormatter,
bool showZero) const { bool showZero) const {
Potentials::dot(name, keyFormatter, valueFormatter, showZero); ADT::dot(name, keyFormatter, valueFormatter, showZero);
} }
/** output to graphviz format string */ /** output to graphviz format string */
std::string DecisionTreeFactor::dot(const KeyFormatter& keyFormatter, std::string DecisionTreeFactor::dot(const KeyFormatter& keyFormatter,
bool showZero) const { bool showZero) const {
return Potentials::dot(keyFormatter, valueFormatter, showZero); return ADT::dot(keyFormatter, valueFormatter, showZero);
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -209,5 +219,15 @@ namespace gtsam {
return ss.str(); return ss.str();
} }
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const vector<double> &table) :
DiscreteFactor(keys.indices()), AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {
}
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const string &table) :
DiscreteFactor(keys.indices()), AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {
}
/* ************************************************************************* */ /* ************************************************************************* */
} // namespace gtsam } // namespace gtsam

View File

@ -19,7 +19,8 @@
#pragma once #pragma once
#include <gtsam/discrete/DiscreteFactor.h> #include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/discrete/Potentials.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/inference/Ordering.h> #include <gtsam/inference/Ordering.h>
#include <boost/shared_ptr.hpp> #include <boost/shared_ptr.hpp>
@ -35,7 +36,7 @@ namespace gtsam {
/** /**
* A discrete probabilistic factor * A discrete probabilistic factor
*/ */
class GTSAM_EXPORT DecisionTreeFactor: public DiscreteFactor, public Potentials { class GTSAM_EXPORT DecisionTreeFactor: public DiscreteFactor, public AlgebraicDecisionTree<Key> {
public: public:
@ -43,6 +44,10 @@ namespace gtsam {
typedef DecisionTreeFactor This; typedef DecisionTreeFactor This;
typedef DiscreteFactor Base; ///< Typedef to base class typedef DiscreteFactor Base; ///< Typedef to base class
typedef boost::shared_ptr<DecisionTreeFactor> shared_ptr; typedef boost::shared_ptr<DecisionTreeFactor> shared_ptr;
typedef AlgebraicDecisionTree<Key> ADT;
protected:
std::map<Key,size_t> cardinalities_;
public: public:
@ -55,11 +60,11 @@ namespace gtsam {
/** Constructor from Indices, Ordering, and AlgebraicDecisionDiagram */ /** Constructor from Indices, Ordering, and AlgebraicDecisionDiagram */
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials); DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
/** Constructor from Indices and (string or doubles) */ /** Constructor from doubles */
template<class SOURCE> DecisionTreeFactor(const DiscreteKeys& keys, const std::vector<double>& table);
DecisionTreeFactor(const DiscreteKeys& keys, SOURCE table) :
DiscreteFactor(keys.indices()), Potentials(keys, table) { /** Constructor from string */
} DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);
/// Single-key specialization /// Single-key specialization
template <class SOURCE> template <class SOURCE>
@ -71,7 +76,7 @@ namespace gtsam {
: DecisionTreeFactor(DiscreteKeys{key}, row) {} : DecisionTreeFactor(DiscreteKeys{key}, row) {}
/** Construct from a DiscreteConditional type */ /** Construct from a DiscreteConditional type */
DecisionTreeFactor(const DiscreteConditional& c); explicit DecisionTreeFactor(const DiscreteConditional& c);
/// @} /// @}
/// @name Testable /// @name Testable
@ -90,7 +95,7 @@ namespace gtsam {
/// Value is just look up in AlgebraicDecisonTree /// Value is just look up in AlgebraicDecisonTree
double operator()(const DiscreteValues& values) const override { double operator()(const DiscreteValues& values) const override {
return Potentials::operator()(values); return ADT::operator()(values);
} }
/// multiply two factors /// multiply two factors
@ -98,6 +103,10 @@ namespace gtsam {
return apply(f, ADT::Ring::mul); return apply(f, ADT::Ring::mul);
} }
static double safe_div(const double& a, const double& b);
size_t cardinality(Key j) const { return cardinalities_.at(j);}
/// divide by factor f (safely) /// divide by factor f (safely)
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const { DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
return apply(f, safe_div); return apply(f, safe_div);

View File

@ -80,7 +80,7 @@ void DiscreteConditional::print(const string& s,
} }
} }
cout << ")"; cout << ")";
Potentials::print(""); ADT::print("");
cout << endl; cout << endl;
} }

View File

@ -128,7 +128,7 @@ public:
/// Evaluate, just look up in AlgebraicDecisonTree /// Evaluate, just look up in AlgebraicDecisonTree
double operator()(const DiscreteValues& values) const override { double operator()(const DiscreteValues& values) const override {
return Potentials::operator()(values); return ADT::operator()(values);
} }
/** Convert to a factor */ /** Convert to a factor */

View File

@ -1,96 +1,96 @@
/* ---------------------------------------------------------------------------- ///* ----------------------------------------------------------------------------
//
* GTSAM Copyright 2010, Georgia Tech Research Corporation, // * GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415 // * Atlanta, Georgia 30332-0415
* All Rights Reserved // * All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list) // * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
//
* See LICENSE for the license information // * See LICENSE for the license information
//
* -------------------------------------------------------------------------- */ // * -------------------------------------------------------------------------- */
//
/** ///**
* @file Potentials.cpp // * @file Potentials.cpp
* @date March 24, 2011 // * @date March 24, 2011
* @author Frank Dellaert // * @author Frank Dellaert
*/ // */
//
#include <gtsam/discrete/DecisionTree-inl.h> //#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/Potentials.h> //#include <gtsam/discrete/Potentials.h>
//
#include <boost/format.hpp> //#include <boost/format.hpp>
//
#include <string> //#include <string>
//
using namespace std; //using namespace std;
//
namespace gtsam { //namespace gtsam {
//
/* ************************************************************************* */ ///* ************************************************************************* */
double Potentials::safe_div(const double& a, const double& b) { //double Potentials::safe_div(const double& a, const double& b) {
// cout << boost::format("%g / %g = %g\n") % a % b % ((a == 0) ? 0 : (a / b)); // // cout << boost::format("%g / %g = %g\n") % a % b % ((a == 0) ? 0 : (a / b));
// The use for safe_div is when we divide the product factor by the sum // // The use for safe_div is when we divide the product factor by the sum
// factor. If the product or sum is zero, we accord zero probability to the // // factor. If the product or sum is zero, we accord zero probability to the
// event. // // event.
return (a == 0 || b == 0) ? 0 : (a / b); // return (a == 0 || b == 0) ? 0 : (a / b);
} //}
//
/* ******************************************************************************** ///* ********************************************************************************
*/ // */
Potentials::Potentials() : ADT(1.0) {} //Potentials::Potentials() : ADT(1.0) {}
//
/* ******************************************************************************** ///* ********************************************************************************
*/ // */
Potentials::Potentials(const DiscreteKeys& keys, const ADT& decisionTree) //Potentials::Potentials(const DiscreteKeys& keys, const ADT& decisionTree)
: ADT(decisionTree), cardinalities_(keys.cardinalities()) {} // : ADT(decisionTree), cardinalities_(keys.cardinalities()) {}
//
/* ************************************************************************* */ ///* ************************************************************************* */
bool Potentials::equals(const Potentials& other, double tol) const { //bool Potentials::equals(const Potentials& other, double tol) const {
return ADT::equals(other, tol); // return ADT::equals(other, tol);
} //}
//
/* ************************************************************************* */ ///* ************************************************************************* */
void Potentials::print(const string& s, const KeyFormatter& formatter) const { //void Potentials::print(const string& s, const KeyFormatter& formatter) const {
cout << s << "\n Cardinalities: { "; // cout << s << "\n Cardinalities: { ";
for (const std::pair<const Key,size_t>& key : cardinalities_) // for (const std::pair<const Key,size_t>& key : cardinalities_)
cout << formatter(key.first) << ":" << key.second << ", "; // cout << formatter(key.first) << ":" << key.second << ", ";
cout << "}" << endl; // cout << "}" << endl;
ADT::print(" ", formatter); // ADT::print(" ", formatter);
} //}
////
//// /* ************************************************************************* */
//// template<class P>
//// void Potentials::remapIndices(const P& remapping) {
//// // Permute the _cardinalities (TODO: Inefficient Consider Improving)
//// DiscreteKeys keys;
//// map<Key, Key> ordering;
////
//// // Get the original keys from cardinalities_
//// for(const DiscreteKey& key: cardinalities_)
//// keys & key;
////
//// // Perform Permutation
//// for(DiscreteKey& key: keys) {
//// ordering[key.first] = remapping[key.first];
//// key.first = ordering[key.first];
//// }
////
//// // Change *this
//// AlgebraicDecisionTree<Key> permuted((*this), ordering);
//// *this = permuted;
//// cardinalities_ = keys.cardinalities();
//// }
////
//// /* ************************************************************************* */
//// void Potentials::permuteWithInverse(const Permutation& inversePermutation) {
//// remapIndices(inversePermutation);
//// }
////
//// /* ************************************************************************* */
//// void Potentials::reduceWithInverse(const internal::Reduction& inverseReduction) {
//// remapIndices(inverseReduction);
//// }
// //
// /* ************************************************************************* */ // /* ************************************************************************* */
// template<class P>
// void Potentials::remapIndices(const P& remapping) {
// // Permute the _cardinalities (TODO: Inefficient Consider Improving)
// DiscreteKeys keys;
// map<Key, Key> ordering;
// //
// // Get the original keys from cardinalities_ //} // namespace gtsam
// for(const DiscreteKey& key: cardinalities_)
// keys & key;
//
// // Perform Permutation
// for(DiscreteKey& key: keys) {
// ordering[key.first] = remapping[key.first];
// key.first = ordering[key.first];
// }
//
// // Change *this
// AlgebraicDecisionTree<Key> permuted((*this), ordering);
// *this = permuted;
// cardinalities_ = keys.cardinalities();
// }
//
// /* ************************************************************************* */
// void Potentials::permuteWithInverse(const Permutation& inversePermutation) {
// remapIndices(inversePermutation);
// }
//
// /* ************************************************************************* */
// void Potentials::reduceWithInverse(const internal::Reduction& inverseReduction) {
// remapIndices(inverseReduction);
// }
/* ************************************************************************* */
} // namespace gtsam

View File

@ -26,72 +26,10 @@
namespace gtsam { namespace gtsam {
/** /*
* A base class for both DiscreteFactor and DiscreteConditional * @deprecated
*/ * @brief Deprecated class for storing an ADT with some convenience methods
class GTSAM_EXPORT Potentials: public AlgebraicDecisionTree<Key> { * */
typedef GTSAM_DEPRECATED AlgebraicDecisionTree<Key> Potentials;
public:
typedef AlgebraicDecisionTree<Key> ADT;
protected:
/// Cardinality for each key, used in combine
std::map<Key,size_t> cardinalities_;
/** Constructor from ColumnIndex, and ADT */
Potentials(const ADT& potentials) :
ADT(potentials) {
}
// Safe division for probabilities
static double safe_div(const double& a, const double& b);
// // Apply either a permutation or a reduction
// template<class P>
// void remapIndices(const P& remapping);
public:
/** Default constructor for I/O */
Potentials();
/** Constructor from Indices and ADT */
Potentials(const DiscreteKeys& keys, const ADT& decisionTree);
/** Constructor from Indices and (string or doubles) */
template<class SOURCE>
Potentials(const DiscreteKeys& keys, SOURCE table) :
ADT(keys, table), cardinalities_(keys.cardinalities()) {
}
// Testable
bool equals(const Potentials& other, double tol = 1e-9) const;
void print(const std::string& s = "Potentials: ",
const KeyFormatter& formatter = DefaultKeyFormatter) const;
size_t cardinality(Key j) const { return cardinalities_.at(j);}
// /**
// * @brief Permutes the keys in Potentials
// *
// * This permutes the Indices and performs necessary re-ordering of ADD.
// * This is virtual so that derived types e.g. DecisionTreeFactor can
// * re-implement it.
// */
// GTSAM_EXPORT virtual void permuteWithInverse(const Permutation& inversePermutation);
//
// /**
// * Apply a reduction, which is a remapping of variable indices.
// */
// GTSAM_EXPORT virtual void reduceWithInverse(const internal::Reduction& inverseReduction);
}; // Potentials
// traits
template<> struct traits<Potentials> : public Testable<Potentials> {};
template<> struct traits<Potentials::ADT> : public Testable<Potentials::ADT> {};
} // namespace gtsam } // namespace gtsam

View File

@ -41,21 +41,23 @@ using namespace gtsam;
static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2),
LungCancer(6, 2), Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2); LungCancer(6, 2), Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2);
using ADT = AlgebraicDecisionTree<Key>;
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DiscreteBayesNet, bayesNet) { TEST(DiscreteBayesNet, bayesNet) {
DiscreteBayesNet bayesNet; DiscreteBayesNet bayesNet;
DiscreteKey Parent(0, 2), Child(1, 2); DiscreteKey Parent(0, 2), Child(1, 2);
auto prior = boost::make_shared<DiscreteConditional>(Parent % "6/4"); auto prior = boost::make_shared<DiscreteConditional>(Parent % "6/4");
CHECK(assert_equal(Potentials::ADT({Parent}, "0.6 0.4"), CHECK(assert_equal(ADT({Parent}, "0.6 0.4"),
(Potentials::ADT)*prior)); (ADT)*prior));
bayesNet.push_back(prior); bayesNet.push_back(prior);
auto conditional = auto conditional =
boost::make_shared<DiscreteConditional>(Child | Parent = "7/3 8/2"); boost::make_shared<DiscreteConditional>(Child | Parent = "7/3 8/2");
EXPECT_LONGS_EQUAL(1, *(conditional->beginFrontals())); EXPECT_LONGS_EQUAL(1, *(conditional->beginFrontals()));
Potentials::ADT expected(Child & Parent, "0.7 0.8 0.3 0.2"); ADT expected(Child & Parent, "0.7 0.8 0.3 0.2");
CHECK(assert_equal(expected, (Potentials::ADT)*conditional)); CHECK(assert_equal(expected, (ADT)*conditional));
bayesNet.push_back(conditional); bayesNet.push_back(conditional);
DiscreteFactorGraph fg(bayesNet); DiscreteFactorGraph fg(bayesNet);

View File

@ -72,8 +72,9 @@ namespace gtsam {
public: public:
typedef CLIQUE Clique; ///< The clique type, normally BayesTreeClique typedef CLIQUE Clique; ///< The clique type, normally BayesTreeClique
typedef boost::shared_ptr<Clique> sharedClique; ///< Shared pointer to a clique typedef boost::shared_ptr<Clique> sharedClique; ///< Shared pointer to a clique
typedef Clique Node; ///< Synonym for Clique (TODO: remove)
typedef sharedClique sharedNode; ///< Synonym for sharedClique (TODO: remove) typedef GTSAM_DEPRECATED Clique Node; ///< Synonym for Clique (TODO: remove)
typedef GTSAM_DEPRECATED sharedClique sharedNode; ///< Synonym for sharedClique (TODO: remove)
typedef typename CLIQUE::ConditionalType ConditionalType; typedef typename CLIQUE::ConditionalType ConditionalType;
typedef boost::shared_ptr<ConditionalType> sharedConditional; typedef boost::shared_ptr<ConditionalType> sharedConditional;
typedef typename CLIQUE::BayesNetType BayesNetType; typedef typename CLIQUE::BayesNetType BayesNetType;
@ -143,7 +144,7 @@ namespace gtsam {
const Nodes& nodes() const { return nodes_; } const Nodes& nodes() const { return nodes_; }
/** Access node by variable */ /** Access node by variable */
const sharedNode operator[](Key j) const { return nodes_.at(j); } sharedClique operator[](Key j) const { return nodes_.at(j); }
/** return root cliques */ /** return root cliques */
const Roots& roots() const { return roots_; } const Roots& roots() const { return roots_; }

View File

@ -130,9 +130,9 @@ void Scheduler::addStudentSpecificConstraints(size_t i,
// get all constraints then specialize to slot // get all constraints then specialize to slot
size_t dummyIndex = maxNrStudents_ * 3 + maxNrStudents_; size_t dummyIndex = maxNrStudents_ * 3 + maxNrStudents_;
DiscreteKey dummy(dummyIndex, nrTimeSlots()); DiscreteKey dummy(dummyIndex, nrTimeSlots());
Potentials::ADT p(dummy & areaKey, AlgebraicDecisionTree<Key> p(dummy & areaKey,
available_); // available_ is Doodle string available_); // available_ is Doodle string
Potentials::ADT q = p.choose(dummyIndex, *slot); auto q = p.choose(dummyIndex, *slot);
CSP::add(areaKey, q); CSP::add(areaKey, q);
} else { } else {
DiscreteKeys keys {s.key_, areaKey}; DiscreteKeys keys {s.key_, areaKey};