New lookup classes
parent
ec39197cc3
commit
34a3b022d9
|
@ -0,0 +1,153 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @file DiscreteLookupTable.cpp
|
||||||
|
* @date Feb 14, 2011
|
||||||
|
* @author Duy-Nguyen Ta
|
||||||
|
* @author Frank Dellaert
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||||
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
using std::pair;
|
||||||
|
using std::vector;
|
||||||
|
|
||||||
|
namespace gtsam {
|
||||||
|
|
||||||
|
// Instantiate base class
|
||||||
|
template class GTSAM_EXPORT
|
||||||
|
Conditional<DecisionTreeFactor, DiscreteLookupTable>;
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-(
|
||||||
|
void DiscreteLookupTable::print(const std::string& s,
|
||||||
|
const KeyFormatter& formatter) const {
|
||||||
|
using std::cout;
|
||||||
|
using std::endl;
|
||||||
|
|
||||||
|
cout << s << " g( ";
|
||||||
|
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
|
||||||
|
cout << formatter(*it) << " ";
|
||||||
|
}
|
||||||
|
if (nrParents()) {
|
||||||
|
cout << "; ";
|
||||||
|
for (const_iterator it = beginParents(); it != endParents(); ++it) {
|
||||||
|
cout << formatter(*it) << " ";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cout << "):\n";
|
||||||
|
ADT::print("", formatter);
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-(
|
||||||
|
vector<DiscreteValues> DiscreteLookupTable::frontalAssignments() const {
|
||||||
|
vector<pair<Key, size_t>> pairs;
|
||||||
|
for (Key key : frontals()) pairs.emplace_back(key, cardinalities_.at(key));
|
||||||
|
vector<pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
|
||||||
|
return DiscreteValues::CartesianProduct(rpairs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-(
|
||||||
|
static DiscreteLookupTable::ADT Choose(const DiscreteLookupTable& conditional,
|
||||||
|
const DiscreteValues& given,
|
||||||
|
bool forceComplete = true) {
|
||||||
|
// Get the big decision tree with all the levels, and then go down the
|
||||||
|
// branches based on the value of the parent variables.
|
||||||
|
DiscreteLookupTable::ADT adt(conditional);
|
||||||
|
size_t value;
|
||||||
|
for (Key j : conditional.parents()) {
|
||||||
|
try {
|
||||||
|
value = given.at(j);
|
||||||
|
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
||||||
|
} catch (std::out_of_range&) {
|
||||||
|
if (forceComplete) {
|
||||||
|
given.print("parentsValues: ");
|
||||||
|
throw std::runtime_error(
|
||||||
|
"DiscreteLookupTable::Choose: parent value missing");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return adt;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const {
|
||||||
|
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)
|
||||||
|
|
||||||
|
// Initialize
|
||||||
|
DiscreteValues mpe;
|
||||||
|
double maxP = 0;
|
||||||
|
|
||||||
|
// Get all Possible Configurations
|
||||||
|
const auto allPosbValues = frontalAssignments();
|
||||||
|
|
||||||
|
// Find the maximum
|
||||||
|
for (const auto& frontalVals : allPosbValues) {
|
||||||
|
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
|
||||||
|
// Update maximum solution if better
|
||||||
|
if (pValueS > maxP) {
|
||||||
|
maxP = pValueS;
|
||||||
|
mpe = frontalVals;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// set values (inPlace) to maximum
|
||||||
|
for (Key j : frontals()) {
|
||||||
|
(*values)[j] = mpe[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const {
|
||||||
|
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
|
||||||
|
|
||||||
|
// Then, find the max over all remaining
|
||||||
|
// TODO(Duy): only works for one key now, seems horribly slow this way
|
||||||
|
size_t mpe = 0;
|
||||||
|
DiscreteValues frontals;
|
||||||
|
double maxP = 0;
|
||||||
|
assert(nrFrontals() == 1);
|
||||||
|
Key j = (firstFrontalKey());
|
||||||
|
for (size_t value = 0; value < cardinality(j); value++) {
|
||||||
|
frontals[j] = value;
|
||||||
|
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
|
||||||
|
// Update MPE solution if better
|
||||||
|
if (pValueS > maxP) {
|
||||||
|
maxP = pValueS;
|
||||||
|
mpe = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return mpe;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
DiscreteValues DiscreteLookupDAG::argmax() const {
|
||||||
|
DiscreteValues result;
|
||||||
|
return argmax(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const {
|
||||||
|
// Argmax each node in turn in topological sort order (parents first).
|
||||||
|
for (auto lookupTable : boost::adaptors::reverse(*this))
|
||||||
|
lookupTable->argmaxInPlace(&result);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
/* ************************************************************************** */
|
||||||
|
|
||||||
|
} // namespace gtsam
|
|
@ -0,0 +1,138 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @file DiscreteLookupDAG.h
|
||||||
|
* @date JAnuary, 2022
|
||||||
|
* @author Frank dellaert
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||||
|
#include <gtsam/inference/BayesNet.h>
|
||||||
|
#include <gtsam/inference/FactorGraph.h>
|
||||||
|
|
||||||
|
#include <boost/shared_ptr.hpp>
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace gtsam {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief DiscreteLookupTable table for max-product
|
||||||
|
*/
|
||||||
|
class DiscreteLookupTable
|
||||||
|
: public DecisionTreeFactor,
|
||||||
|
public Conditional<DecisionTreeFactor, DiscreteLookupTable> {
|
||||||
|
public:
|
||||||
|
using This = DiscreteLookupTable;
|
||||||
|
using shared_ptr = boost::shared_ptr<This>;
|
||||||
|
using BaseConditional = Conditional<DecisionTreeFactor, This>;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Construct a new Discrete Lookup Table object
|
||||||
|
*
|
||||||
|
* @param nFrontals number of frontal variables
|
||||||
|
* @param keys a orted list of gtsam::Keys
|
||||||
|
* @param potentials the algebraic decision tree with lookup values
|
||||||
|
*/
|
||||||
|
DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys,
|
||||||
|
const ADT& potentials)
|
||||||
|
: DecisionTreeFactor(keys, potentials), BaseConditional(nFrontals) {}
|
||||||
|
|
||||||
|
/// GTSAM-style print
|
||||||
|
void print(
|
||||||
|
const std::string& s = "Discrete Lookup Table: ",
|
||||||
|
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief return assignment for single frontal variable that maximizes value.
|
||||||
|
* @param parentsValues Known assignments for the parents.
|
||||||
|
* @return maximizing assignment for the frontal variable.
|
||||||
|
*/
|
||||||
|
size_t argmax(const DiscreteValues& parentsValues) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Calculate assignment for frontal variables that maximizes value.
|
||||||
|
* @param (in/out) parentsValues Known assignments for the parents.
|
||||||
|
*/
|
||||||
|
void argmaxInPlace(DiscreteValues* parentsValues) const;
|
||||||
|
|
||||||
|
/// Return all assignments for frontal variables.
|
||||||
|
std::vector<DiscreteValues> frontalAssignments() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
/** A DAG made from lookup tables, as defined above. */
|
||||||
|
class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet<DiscreteLookupTable> {
|
||||||
|
public:
|
||||||
|
using Base = BayesNet<DiscreteLookupTable>;
|
||||||
|
using This = DiscreteLookupDAG;
|
||||||
|
using shared_ptr = boost::shared_ptr<This>;
|
||||||
|
|
||||||
|
/// @name Standard Constructors
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/// Construct empty DAG.
|
||||||
|
DiscreteLookupDAG() {}
|
||||||
|
|
||||||
|
/// Destructor
|
||||||
|
virtual ~DiscreteLookupDAG() {}
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
/// @name Testable
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/** Check equality */
|
||||||
|
bool equals(const This& bn, double tol = 1e-9) const;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
/// @name Standard Interface
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief argmax by back-substitution.
|
||||||
|
*
|
||||||
|
* Assumes the DAG is reverse topologically sorted, i.e. last
|
||||||
|
* conditional will be optimized first. If the DAG resulted from
|
||||||
|
* eliminating a factor graph, this is true for the elimination ordering.
|
||||||
|
*
|
||||||
|
* @return optimal assignment for all variables.
|
||||||
|
*/
|
||||||
|
DiscreteValues argmax() const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief argmax by back-substitution, given certain variables.
|
||||||
|
*
|
||||||
|
* Assumes the DAG is reverse topologically sorted *and* that the
|
||||||
|
* DAG does not contain any conditionals for the given variables.
|
||||||
|
*
|
||||||
|
* @return given assignment extended w. optimal assignment for all variables.
|
||||||
|
*/
|
||||||
|
DiscreteValues argmax(DiscreteValues given) const;
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class ARCHIVE>
|
||||||
|
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||||
|
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// traits
|
||||||
|
template <>
|
||||||
|
struct traits<DiscreteLookupDAG> : public Testable<DiscreteLookupDAG> {};
|
||||||
|
|
||||||
|
} // namespace gtsam
|
Loading…
Reference in New Issue