turns out we can merge DiscreteConditional and DiscreteLookupTable
parent
5550537709
commit
52f1aba10c
|
@ -235,7 +235,10 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
|||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
size_t DiscreteConditional::argmax() const {
|
||||
size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const {
|
||||
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
||||
|
||||
// Initialize
|
||||
size_t maxValue = 0;
|
||||
double maxP = 0;
|
||||
assert(nrFrontals() == 1);
|
||||
|
@ -254,6 +257,33 @@ size_t DiscreteConditional::argmax() const {
|
|||
return maxValue;
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
void DiscreteConditional::argmaxInPlace(DiscreteValues* values) const {
|
||||
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
|
||||
|
||||
// Initialize
|
||||
DiscreteValues mpe;
|
||||
double maxP = 0;
|
||||
|
||||
// Get all Possible Configurations
|
||||
const auto allPosbValues = frontalAssignments();
|
||||
|
||||
// Find the maximum
|
||||
for (const auto& frontalVals : allPosbValues) {
|
||||
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
|
||||
// Update maximum solution if better
|
||||
if (pValueS > maxP) {
|
||||
maxP = pValueS;
|
||||
mpe = frontalVals;
|
||||
}
|
||||
}
|
||||
|
||||
// set values (inPlace) to maximum
|
||||
for (Key j : frontals()) {
|
||||
(*values)[j] = mpe[j];
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
|
||||
assert(nrFrontals() == 1);
|
||||
|
@ -459,7 +489,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double DiscreteConditional::evaluate(const HybridValues& x) const{
|
||||
double DiscreteConditional::evaluate(const HybridValues& x) const {
|
||||
return this->evaluate(x.discrete());
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -18,9 +18,9 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/inference/Conditional-inst.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
#include <gtsam/inference/Conditional-inst.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
@ -39,7 +39,7 @@ class GTSAM_EXPORT DiscreteConditional
|
|||
public Conditional<DecisionTreeFactor, DiscreteConditional> {
|
||||
public:
|
||||
// typedefs needed to play nice with gtsam
|
||||
typedef DiscreteConditional This; ///< Typedef to this class
|
||||
typedef DiscreteConditional This; ///< Typedef to this class
|
||||
typedef std::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
|
||||
typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class
|
||||
typedef Conditional<BaseFactor, This>
|
||||
|
@ -159,9 +159,7 @@ class GTSAM_EXPORT DiscreteConditional
|
|||
/// @{
|
||||
|
||||
/// Log-probability is just -error(x).
|
||||
double logProbability(const DiscreteValues& x) const {
|
||||
return -error(x);
|
||||
}
|
||||
double logProbability(const DiscreteValues& x) const { return -error(x); }
|
||||
|
||||
/// print index signature only
|
||||
void printSignature(
|
||||
|
@ -214,11 +212,18 @@ class GTSAM_EXPORT DiscreteConditional
|
|||
size_t sample() const;
|
||||
|
||||
/**
|
||||
* @brief Return assignment that maximizes distribution.
|
||||
* @return Optimal assignment (1 frontal variable).
|
||||
* @brief Return assignment for single frontal variable that maximizes value.
|
||||
* @param parentsValues Known assignments for the parents.
|
||||
* @return maximizing assignment for the frontal variable.
|
||||
*/
|
||||
size_t argmax() const;
|
||||
|
||||
/**
|
||||
* @brief Calculate assignment for frontal variables that maximizes value.
|
||||
* @param (in/out) parentsValues Known assignments for the parents.
|
||||
*/
|
||||
void argmaxInPlace(DiscreteValues* parentsValues) const;
|
||||
|
||||
/// @}
|
||||
/// @name Advanced Interface
|
||||
/// @{
|
||||
|
@ -244,7 +249,6 @@ class GTSAM_EXPORT DiscreteConditional
|
|||
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const Names& names = {}) const override;
|
||||
|
||||
|
||||
/// @}
|
||||
/// @name HybridValues methods.
|
||||
/// @{
|
||||
|
|
|
@ -29,97 +29,20 @@ using std::vector;
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************** */
|
||||
// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-(
|
||||
void DiscreteLookupTable::print(const std::string& s,
|
||||
const KeyFormatter& formatter) const {
|
||||
using std::cout;
|
||||
using std::endl;
|
||||
|
||||
cout << s << " g( ";
|
||||
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
|
||||
cout << formatter(*it) << " ";
|
||||
}
|
||||
if (nrParents()) {
|
||||
cout << "; ";
|
||||
for (const_iterator it = beginParents(); it != endParents(); ++it) {
|
||||
cout << formatter(*it) << " ";
|
||||
}
|
||||
}
|
||||
cout << "):\n";
|
||||
ADT::print("", formatter);
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const {
|
||||
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
|
||||
|
||||
// Initialize
|
||||
DiscreteValues mpe;
|
||||
double maxP = 0;
|
||||
|
||||
// Get all Possible Configurations
|
||||
const auto allPosbValues = frontalAssignments();
|
||||
|
||||
// Find the maximum
|
||||
for (const auto& frontalVals : allPosbValues) {
|
||||
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
|
||||
// Update maximum solution if better
|
||||
if (pValueS > maxP) {
|
||||
maxP = pValueS;
|
||||
mpe = frontalVals;
|
||||
}
|
||||
}
|
||||
|
||||
// set values (inPlace) to maximum
|
||||
for (Key j : frontals()) {
|
||||
(*values)[j] = mpe[j];
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const {
|
||||
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
||||
|
||||
// Then, find the max over all remaining
|
||||
// TODO(Duy): only works for one key now, seems horribly slow this way
|
||||
size_t mpe = 0;
|
||||
double maxP = 0;
|
||||
DiscreteValues frontals;
|
||||
assert(nrFrontals() == 1);
|
||||
Key j = (firstFrontalKey());
|
||||
for (size_t value = 0; value < cardinality(j); value++) {
|
||||
frontals[j] = value;
|
||||
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
|
||||
// Update MPE solution if better
|
||||
if (pValueS > maxP) {
|
||||
maxP = pValueS;
|
||||
mpe = value;
|
||||
}
|
||||
}
|
||||
return mpe;
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
DiscreteLookupDAG DiscreteLookupDAG::FromBayesNet(
|
||||
const DiscreteBayesNet& bayesNet) {
|
||||
DiscreteLookupDAG dag;
|
||||
for (auto&& conditional : bayesNet) {
|
||||
if (auto lookupTable =
|
||||
std::dynamic_pointer_cast<DiscreteLookupTable>(conditional)) {
|
||||
dag.push_back(lookupTable);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"DiscreteFactorGraph::maxProduct: Expected look up table.");
|
||||
}
|
||||
dag.push_back(conditional);
|
||||
}
|
||||
return dag;
|
||||
}
|
||||
|
||||
DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const {
|
||||
// Argmax each node in turn in topological sort order (parents first).
|
||||
for (auto it = std::make_reverse_iterator(end()); it != std::make_reverse_iterator(begin()); ++it) {
|
||||
for (auto it = std::make_reverse_iterator(end());
|
||||
it != std::make_reverse_iterator(begin()); ++it) {
|
||||
// dereference to get the sharedFactor to the lookup table
|
||||
(*it)->argmaxInPlace(&result);
|
||||
}
|
||||
|
|
|
@ -37,41 +37,9 @@ class DiscreteBayesNet;
|
|||
* Inherits from discrete conditional for convenience, but is not normalized.
|
||||
* Is used in the max-product algorithm.
|
||||
*/
|
||||
class GTSAM_EXPORT DiscreteLookupTable : public DiscreteConditional {
|
||||
public:
|
||||
using This = DiscreteLookupTable;
|
||||
using shared_ptr = std::shared_ptr<This>;
|
||||
using BaseConditional = Conditional<DecisionTreeFactor, This>;
|
||||
|
||||
/**
|
||||
* @brief Construct a new Discrete Lookup Table object
|
||||
*
|
||||
* @param nFrontals number of frontal variables
|
||||
* @param keys a sorted list of gtsam::Keys
|
||||
* @param potentials the algebraic decision tree with lookup values
|
||||
*/
|
||||
DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys,
|
||||
const ADT& potentials)
|
||||
: DiscreteConditional(nFrontals, keys, potentials) {}
|
||||
|
||||
/// GTSAM-style print
|
||||
void print(
|
||||
const std::string& s = "Discrete Lookup Table: ",
|
||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||
|
||||
/**
|
||||
* @brief return assignment for single frontal variable that maximizes value.
|
||||
* @param parentsValues Known assignments for the parents.
|
||||
* @return maximizing assignment for the frontal variable.
|
||||
*/
|
||||
size_t argmax(const DiscreteValues& parentsValues) const;
|
||||
|
||||
/**
|
||||
* @brief Calculate assignment for frontal variables that maximizes value.
|
||||
* @param (in/out) parentsValues Known assignments for the parents.
|
||||
*/
|
||||
void argmaxInPlace(DiscreteValues* parentsValues) const;
|
||||
};
|
||||
// Typedef for backwards compatibility
|
||||
// TODO(Varun): Remove
|
||||
using DiscreteLookupTable = DiscreteConditional;
|
||||
|
||||
/** A DAG made from lookup tables, as defined above. */
|
||||
class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet<DiscreteLookupTable> {
|
||||
|
|
Loading…
Reference in New Issue