undo merge
parent
dd8de1f300
commit
8c69345937
|
@ -29,12 +29,90 @@ using std::vector;
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-(
|
||||||
|
void DiscreteLookupTable::print(const std::string& s,
|
||||||
|
const KeyFormatter& formatter) const {
|
||||||
|
using std::cout;
|
||||||
|
using std::endl;
|
||||||
|
|
||||||
|
cout << s << " g( ";
|
||||||
|
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
|
||||||
|
cout << formatter(*it) << " ";
|
||||||
|
}
|
||||||
|
if (nrParents()) {
|
||||||
|
cout << "; ";
|
||||||
|
for (const_iterator it = beginParents(); it != endParents(); ++it) {
|
||||||
|
cout << formatter(*it) << " ";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cout << "):\n";
|
||||||
|
ADT::print("", formatter);
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const {
|
||||||
|
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
|
||||||
|
|
||||||
|
// Initialize
|
||||||
|
DiscreteValues mpe;
|
||||||
|
double maxP = 0;
|
||||||
|
|
||||||
|
// Get all Possible Configurations
|
||||||
|
const auto allPosbValues = frontalAssignments();
|
||||||
|
|
||||||
|
// Find the maximum
|
||||||
|
for (const auto& frontalVals : allPosbValues) {
|
||||||
|
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
|
||||||
|
// Update maximum solution if better
|
||||||
|
if (pValueS > maxP) {
|
||||||
|
maxP = pValueS;
|
||||||
|
mpe = frontalVals;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// set values (inPlace) to maximum
|
||||||
|
for (Key j : frontals()) {
|
||||||
|
(*values)[j] = mpe[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const {
|
||||||
|
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
||||||
|
|
||||||
|
// Then, find the max over all remaining
|
||||||
|
// TODO(Duy): only works for one key now, seems horribly slow this way
|
||||||
|
size_t mpe = 0;
|
||||||
|
double maxP = 0;
|
||||||
|
DiscreteValues frontals;
|
||||||
|
assert(nrFrontals() == 1);
|
||||||
|
Key j = (firstFrontalKey());
|
||||||
|
for (size_t value = 0; value < cardinality(j); value++) {
|
||||||
|
frontals[j] = value;
|
||||||
|
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
|
||||||
|
// Update MPE solution if better
|
||||||
|
if (pValueS > maxP) {
|
||||||
|
maxP = pValueS;
|
||||||
|
mpe = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return mpe;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteLookupDAG DiscreteLookupDAG::FromBayesNet(
|
DiscreteLookupDAG DiscreteLookupDAG::FromBayesNet(
|
||||||
const DiscreteBayesNet& bayesNet) {
|
const DiscreteBayesNet& bayesNet) {
|
||||||
DiscreteLookupDAG dag;
|
DiscreteLookupDAG dag;
|
||||||
for (auto&& conditional : bayesNet) {
|
for (auto&& conditional : bayesNet) {
|
||||||
dag.push_back(conditional);
|
if (auto lookupTable =
|
||||||
|
std::dynamic_pointer_cast<DiscreteLookupTable>(conditional)) {
|
||||||
|
dag.push_back(lookupTable);
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"DiscreteFactorGraph::maxProduct: Expected look up table.");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return dag;
|
return dag;
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,9 +37,41 @@ class DiscreteBayesNet;
|
||||||
* Inherits from discrete conditional for convenience, but is not normalized.
|
* Inherits from discrete conditional for convenience, but is not normalized.
|
||||||
* Is used in the max-product algorithm.
|
* Is used in the max-product algorithm.
|
||||||
*/
|
*/
|
||||||
// Typedef for backwards compatibility
|
class GTSAM_EXPORT DiscreteLookupTable : public DiscreteConditional {
|
||||||
// TODO(Varun): Remove
|
public:
|
||||||
using DiscreteLookupTable = DiscreteConditional;
|
using This = DiscreteLookupTable;
|
||||||
|
using shared_ptr = std::shared_ptr<This>;
|
||||||
|
using BaseConditional = Conditional<DecisionTreeFactor, This>;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Construct a new Discrete Lookup Table object
|
||||||
|
*
|
||||||
|
* @param nFrontals number of frontal variables
|
||||||
|
* @param keys a sorted list of gtsam::Keys
|
||||||
|
* @param potentials the algebraic decision tree with lookup values
|
||||||
|
*/
|
||||||
|
DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys,
|
||||||
|
const ADT& potentials)
|
||||||
|
: DiscreteConditional(nFrontals, keys, potentials) {}
|
||||||
|
|
||||||
|
/// GTSAM-style print
|
||||||
|
void print(
|
||||||
|
const std::string& s = "Discrete Lookup Table: ",
|
||||||
|
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief return assignment for single frontal variable that maximizes value.
|
||||||
|
* @param parentsValues Known assignments for the parents.
|
||||||
|
* @return maximizing assignment for the frontal variable.
|
||||||
|
*/
|
||||||
|
size_t argmax(const DiscreteValues& parentsValues) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Calculate assignment for frontal variables that maximizes value.
|
||||||
|
* @param (in/out) parentsValues Known assignments for the parents.
|
||||||
|
*/
|
||||||
|
void argmaxInPlace(DiscreteValues* parentsValues) const;
|
||||||
|
};
|
||||||
|
|
||||||
/** A DAG made from lookup tables, as defined above. */
|
/** A DAG made from lookup tables, as defined above. */
|
||||||
class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet<DiscreteLookupTable> {
|
class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet<DiscreteLookupTable> {
|
||||||
|
|
|
@ -262,6 +262,15 @@ class DiscreteBayesTree {
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||||
|
|
||||||
|
class DiscreteLookupTable : gtsam::DiscreteConditional{
|
||||||
|
DiscreteLookupTable(size_t nFrontals, const gtsam::DiscreteKeys& keys,
|
||||||
|
const gtsam::DecisionTreeFactor::ADT& potentials);
|
||||||
|
void print(string s = "Discrete Lookup Table: ",
|
||||||
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
size_t argmax(const gtsam::DiscreteValues& parentsValues) const;
|
||||||
|
};
|
||||||
|
|
||||||
class DiscreteLookupDAG {
|
class DiscreteLookupDAG {
|
||||||
DiscreteLookupDAG();
|
DiscreteLookupDAG();
|
||||||
void push_back(const gtsam::DiscreteLookupTable* table);
|
void push_back(const gtsam::DiscreteLookupTable* table);
|
||||||
|
|
Loading…
Reference in New Issue