turns out we can merge DiscreteConditional and DiscreteLookupTable

release/4.3a0
Varun Agrawal 2024-07-14 10:30:23 -04:00
parent 5550537709
commit 52f1aba10c
4 changed files with 50 additions and 125 deletions

View File

@ -235,7 +235,10 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
} }
/* ************************************************************************** */ /* ************************************************************************** */
size_t DiscreteConditional::argmax() const { size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const {
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
// Initialize
size_t maxValue = 0; size_t maxValue = 0;
double maxP = 0; double maxP = 0;
assert(nrFrontals() == 1); assert(nrFrontals() == 1);
@ -254,6 +257,33 @@ size_t DiscreteConditional::argmax() const {
return maxValue; return maxValue;
} }
/* ************************************************************************** */
void DiscreteConditional::argmaxInPlace(DiscreteValues* values) const {
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
// Initialize
DiscreteValues mpe;
double maxP = 0;
// Get all Possible Configurations
const auto allPosbValues = frontalAssignments();
// Find the maximum
for (const auto& frontalVals : allPosbValues) {
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
// Update maximum solution if better
if (pValueS > maxP) {
maxP = pValueS;
mpe = frontalVals;
}
}
// set values (inPlace) to maximum
for (Key j : frontals()) {
(*values)[j] = mpe[j];
}
}
/* ************************************************************************** */ /* ************************************************************************** */
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const { void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
assert(nrFrontals() == 1); assert(nrFrontals() == 1);
@ -459,7 +489,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,
} }
/* ************************************************************************* */ /* ************************************************************************* */
double DiscreteConditional::evaluate(const HybridValues& x) const{ double DiscreteConditional::evaluate(const HybridValues& x) const {
return this->evaluate(x.discrete()); return this->evaluate(x.discrete());
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -18,9 +18,9 @@
#pragma once #pragma once
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Conditional-inst.h>
#include <memory> #include <memory>
#include <string> #include <string>
@ -39,7 +39,7 @@ class GTSAM_EXPORT DiscreteConditional
public Conditional<DecisionTreeFactor, DiscreteConditional> { public Conditional<DecisionTreeFactor, DiscreteConditional> {
public: public:
// typedefs needed to play nice with gtsam // typedefs needed to play nice with gtsam
typedef DiscreteConditional This; ///< Typedef to this class typedef DiscreteConditional This; ///< Typedef to this class
typedef std::shared_ptr<This> shared_ptr; ///< shared_ptr to this class typedef std::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class
typedef Conditional<BaseFactor, This> typedef Conditional<BaseFactor, This>
@ -159,9 +159,7 @@ class GTSAM_EXPORT DiscreteConditional
/// @{ /// @{
/// Log-probability is just -error(x). /// Log-probability is just -error(x).
double logProbability(const DiscreteValues& x) const { double logProbability(const DiscreteValues& x) const { return -error(x); }
return -error(x);
}
/// print index signature only /// print index signature only
void printSignature( void printSignature(
@ -214,11 +212,18 @@ class GTSAM_EXPORT DiscreteConditional
size_t sample() const; size_t sample() const;
/** /**
* @brief Return assignment that maximizes distribution. * @brief Return assignment for single frontal variable that maximizes value.
* @return Optimal assignment (1 frontal variable). * @param parentsValues Known assignments for the parents.
* @return maximizing assignment for the frontal variable.
*/ */
size_t argmax() const; size_t argmax() const;
/**
* @brief Calculate assignment for frontal variables that maximizes value.
* @param (in/out) parentsValues Known assignments for the parents.
*/
void argmaxInPlace(DiscreteValues* parentsValues) const;
/// @} /// @}
/// @name Advanced Interface /// @name Advanced Interface
/// @{ /// @{
@ -244,7 +249,6 @@ class GTSAM_EXPORT DiscreteConditional
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override; const Names& names = {}) const override;
/// @} /// @}
/// @name HybridValues methods. /// @name HybridValues methods.
/// @{ /// @{

View File

@ -29,97 +29,20 @@ using std::vector;
namespace gtsam { namespace gtsam {
/* ************************************************************************** */
// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-(
void DiscreteLookupTable::print(const std::string& s,
const KeyFormatter& formatter) const {
using std::cout;
using std::endl;
cout << s << " g( ";
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
cout << formatter(*it) << " ";
}
if (nrParents()) {
cout << "; ";
for (const_iterator it = beginParents(); it != endParents(); ++it) {
cout << formatter(*it) << " ";
}
}
cout << "):\n";
ADT::print("", formatter);
cout << endl;
}
/* ************************************************************************** */
void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const {
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
// Initialize
DiscreteValues mpe;
double maxP = 0;
// Get all Possible Configurations
const auto allPosbValues = frontalAssignments();
// Find the maximum
for (const auto& frontalVals : allPosbValues) {
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
// Update maximum solution if better
if (pValueS > maxP) {
maxP = pValueS;
mpe = frontalVals;
}
}
// set values (inPlace) to maximum
for (Key j : frontals()) {
(*values)[j] = mpe[j];
}
}
/* ************************************************************************** */
size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const {
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
// Then, find the max over all remaining
// TODO(Duy): only works for one key now, seems horribly slow this way
size_t mpe = 0;
double maxP = 0;
DiscreteValues frontals;
assert(nrFrontals() == 1);
Key j = (firstFrontalKey());
for (size_t value = 0; value < cardinality(j); value++) {
frontals[j] = value;
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
// Update MPE solution if better
if (pValueS > maxP) {
maxP = pValueS;
mpe = value;
}
}
return mpe;
}
/* ************************************************************************** */ /* ************************************************************************** */
DiscreteLookupDAG DiscreteLookupDAG::FromBayesNet( DiscreteLookupDAG DiscreteLookupDAG::FromBayesNet(
const DiscreteBayesNet& bayesNet) { const DiscreteBayesNet& bayesNet) {
DiscreteLookupDAG dag; DiscreteLookupDAG dag;
for (auto&& conditional : bayesNet) { for (auto&& conditional : bayesNet) {
if (auto lookupTable = dag.push_back(conditional);
std::dynamic_pointer_cast<DiscreteLookupTable>(conditional)) {
dag.push_back(lookupTable);
} else {
throw std::runtime_error(
"DiscreteFactorGraph::maxProduct: Expected look up table.");
}
} }
return dag; return dag;
} }
DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const { DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const {
// Argmax each node in turn in topological sort order (parents first). // Argmax each node in turn in topological sort order (parents first).
for (auto it = std::make_reverse_iterator(end()); it != std::make_reverse_iterator(begin()); ++it) { for (auto it = std::make_reverse_iterator(end());
it != std::make_reverse_iterator(begin()); ++it) {
// dereference to get the sharedFactor to the lookup table // dereference to get the sharedFactor to the lookup table
(*it)->argmaxInPlace(&result); (*it)->argmaxInPlace(&result);
} }

View File

@ -37,41 +37,9 @@ class DiscreteBayesNet;
* Inherits from discrete conditional for convenience, but is not normalized. * Inherits from discrete conditional for convenience, but is not normalized.
* Is used in the max-product algorithm. * Is used in the max-product algorithm.
*/ */
class GTSAM_EXPORT DiscreteLookupTable : public DiscreteConditional { // Typedef for backwards compatibility
public: // TODO(Varun): Remove
using This = DiscreteLookupTable; using DiscreteLookupTable = DiscreteConditional;
using shared_ptr = std::shared_ptr<This>;
using BaseConditional = Conditional<DecisionTreeFactor, This>;
/**
* @brief Construct a new Discrete Lookup Table object
*
* @param nFrontals number of frontal variables
* @param keys a sorted list of gtsam::Keys
* @param potentials the algebraic decision tree with lookup values
*/
DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys,
const ADT& potentials)
: DiscreteConditional(nFrontals, keys, potentials) {}
/// GTSAM-style print
void print(
const std::string& s = "Discrete Lookup Table: ",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/**
* @brief return assignment for single frontal variable that maximizes value.
* @param parentsValues Known assignments for the parents.
* @return maximizing assignment for the frontal variable.
*/
size_t argmax(const DiscreteValues& parentsValues) const;
/**
* @brief Calculate assignment for frontal variables that maximizes value.
* @param (in/out) parentsValues Known assignments for the parents.
*/
void argmaxInPlace(DiscreteValues* parentsValues) const;
};
/** A DAG made from lookup tables, as defined above. */ /** A DAG made from lookup tables, as defined above. */
class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet<DiscreteLookupTable> { class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet<DiscreteLookupTable> {