560 lines
19 KiB
C++
560 lines
19 KiB
C++
/* ----------------------------------------------------------------------------
|
|
|
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
|
* Atlanta, Georgia 30332-0415
|
|
* All Rights Reserved
|
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
|
|
|
* See LICENSE for the license information
|
|
|
|
* -------------------------------------------------------------------------- */
|
|
|
|
/**
|
|
* @file DiscreteConditional.cpp
|
|
* @date Feb 14, 2011
|
|
* @author Duy-Nguyen Ta
|
|
* @author Frank Dellaert
|
|
*/
|
|
|
|
#include <gtsam/base/Testable.h>
|
|
#include <gtsam/base/debug.h>
|
|
#include <gtsam/discrete/DiscreteConditional.h>
|
|
#include <gtsam/discrete/Ring.h>
|
|
#include <gtsam/discrete/Signature.h>
|
|
#include <gtsam/hybrid/HybridValues.h>
|
|
|
|
#include <algorithm>
|
|
#include <cassert>
|
|
#include <random>
|
|
#include <set>
|
|
#include <stdexcept>
|
|
#include <string>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
// In wrappers we can access std::mt19937_64 via gtsam.MT19937
|
|
// File-local generator handed to the rng-taking sample() overloads whenever
// the caller does not supply one. It is always constructed with the seed 2.
static std::mt19937_64 kRandomNumberGenerator(2);
|
|
|
|
using namespace std;
|
|
using std::pair;
|
|
using std::stringstream;
|
|
using std::vector;
|
|
namespace gtsam {
|
|
|
|
// Instantiate base class
|
|
template class GTSAM_EXPORT
|
|
Conditional<DecisionTreeFactor, DiscreteConditional>;
|
|
|
|
/* ************************************************************************** */
|
|
// Build a conditional from an arbitrary discrete factor.
//
// The factor `f` is divided by `f.sum(nrFrontals)` and the quotient is
// converted to a DecisionTreeFactor, which initializes the factor base.
//
// @param nrFrontals number of frontal variables, forwarded to the
//                   conditional base.
// @param f          the discrete factor to convert.
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
                                         const DiscreteFactor& f)
    : BaseFactor((f / f.sum(nrFrontals))->toDecisionTreeFactor()),
      BaseConditional(nrFrontals) {}
|
|
|
|
/* ************************************************************************** */
|
|
// Build a conditional directly from discrete keys and a decision tree.
//
// `keys` and `potentials` are forwarded unchanged to the factor base; no
// division by a marginal is performed in this constructor.
//
// @param nrFrontals number of frontal variables.
// @param keys       discrete keys (key + cardinality) of the factor.
// @param potentials algebraic decision tree holding the values.
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
                                         const DiscreteKeys& keys,
                                         const ADT& potentials)
    : BaseFactor(keys, potentials),
      BaseConditional(nrFrontals) {}
|
|
|
|
/* ************************************************************************** */
|
|
// Build a conditional as the quotient joint / marginal.
//
// The number of frontal variables is the difference between the sizes of
// the two factors (joint.size() - marginal.size()).
//
// @param joint    factor in the numerator.
// @param marginal factor in the denominator.
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
                                         const DecisionTreeFactor& marginal)
    : BaseFactor(joint / marginal),
      BaseConditional(joint.size() - marginal.size()) {}
|
|
|
|
/* ************************************************************************** */
|
|
// Same as the (joint, marginal) constructor, but with an explicit key order.
//
// Construction is delegated to the two-argument constructor; afterwards the
// stored key list is replaced by the contents of `orderedKeys`.
//
// @param joint       factor in the numerator.
// @param marginal    factor in the denominator.
// @param orderedKeys keys to store, in the order given.
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
                                         const DecisionTreeFactor& marginal,
                                         const Ordering& orderedKeys)
    : DiscreteConditional(joint, marginal) {
  // Drop whatever keys the delegated constructor produced and copy the
  // requested ordering in their place.
  keys_.assign(orderedKeys.begin(), orderedKeys.end());
}
|
|
|
|
/* ************************************************************************** */
|
|
// Build a conditional from a Signature.
//
// The factor base is initialized from signature.discreteKeys() and
// signature.cpt(); the conditional always has exactly one frontal variable.
DiscreteConditional::DiscreteConditional(const Signature& signature)
    : BaseFactor(signature.discreteKeys(), signature.cpt()),
      BaseConditional(1) {}
|
|
|
|
/* ************************************************************************** */
|
|
// Multiply this conditional with `other`.
//
// The result has the union of both frontal sets as its frontals (sorted),
// followed by every parent of either operand that is not a frontal of the
// product (deduplicated and ordered by std::set).
//
// @param other the conditional to multiply with.
// @return the product conditional.
// @throws std::invalid_argument if a frontal key occurs more than once across
//         the two operands.
DiscreteConditional DiscreteConditional::operator*(
    const DiscreteConditional& other) const {
  // If the root is a nullptr, we have a TableDistribution
  // TODO(Varun) Revisit this hack after RSS2025 submission
  if (!other.root_) {
    DiscreteConditional converted(other.nrFrontals(),
                                  other.toDecisionTreeFactor());
    return converted * (*this);
  }

  // Union of the frontal keys of both operands.
  std::set<Key> jointFrontals;
  for (const Key key : frontals()) jointFrontals.insert(key);
  for (const Key key : other.frontals()) jointFrontals.insert(key);

  // If the union is smaller than the two counts combined, some frontal key
  // appeared more than once.
  if (jointFrontals.size() < nrFrontals() + other.nrFrontals()) {
    throw std::invalid_argument(
        "DiscreteConditional::operator* called with overlapping frontal keys.");
  }

  // Frontal keys paired with their cardinalities, then sorted.
  DiscreteKeys productKeys;
  for (const Key key : frontals()) {
    productKeys.emplace_back(key, cardinality(key));
  }
  for (const Key key : other.frontals()) {
    productKeys.emplace_back(key, other.cardinality(key));
  }
  std::sort(productKeys.begin(), productKeys.end());

  // Parents of either operand that did not become frontals of the product.
  // The set keeps them unique and ordered.
  std::set<DiscreteKey> productParents;
  for (const Key key : parents()) {
    if (jointFrontals.count(key) == 0) {
      productParents.emplace(key, cardinality(key));
    }
  }
  for (const Key key : other.parents()) {
    if (jointFrontals.count(key) == 0) {
      productParents.emplace(key, other.cardinality(key));
    }
  }

  // Parents go after the frontals, in set order.
  productKeys.insert(productKeys.end(), productParents.begin(),
                     productParents.end());

  ADT product = ADT::apply(other, Ring::mul);
  return DiscreteConditional(jointFrontals.size(), productKeys, product);
}
|
|
|
|
/* ************************************************************************** */
|
|
// Marginal on a single frontal variable of a parent-less conditional.
//
// Every frontal variable other than `key` is summed out of a copy of the
// decision tree; the result is returned as a conditional with `key` as its
// only key and one frontal.
//
// @param key the variable to keep.
// @throws std::invalid_argument if this conditional has any parents.
DiscreteConditional DiscreteConditional::marginal(Key key) const {
  if (nrParents() > 0) {
    throw std::invalid_argument(
        "DiscreteConditional::marginal: single argument version only valid for "
        "fully specified joint distributions (i.e., no parents).");
  }

  // The result is defined over `key` alone.
  DiscreteKeys marginalKeys{{key, cardinality(key)}};

  // Sum out all other frontal variables, one at a time.
  ADT summed(*this);
  for (const Key frontal : frontals()) {
    if (frontal == key) continue;
    summed = summed.sum(frontal, cardinality(frontal));
  }

  return DiscreteConditional(1, marginalKeys, summed);
}
|
|
|
|
/* ************************************************************************** */
|
|
void DiscreteConditional::print(const string& s,
|
|
const KeyFormatter& formatter) const {
|
|
cout << s << " P( ";
|
|
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
|
|
cout << formatter(*it) << " ";
|
|
}
|
|
if (nrParents()) {
|
|
cout << "| ";
|
|
for (const_iterator it = beginParents(); it != endParents(); ++it) {
|
|
cout << formatter(*it) << " ";
|
|
}
|
|
}
|
|
cout << "):\n";
|
|
ADT::print("", formatter);
|
|
cout << endl;
|
|
}
|
|
|
|
/* ************************************************************************** */
|
|
bool DiscreteConditional::equals(const DiscreteFactor& other,
|
|
double tol) const {
|
|
if (!dynamic_cast<const BaseFactor*>(&other)) {
|
|
return false;
|
|
} else {
|
|
const BaseFactor& f(static_cast<const BaseFactor&>(other));
|
|
return BaseFactor::equals(f, tol);
|
|
}
|
|
}
|
|
|
|
/* ************************************************************************** */
|
|
// Restrict the decision tree to the parent values found in `given`.
//
// Starting from a copy of the full tree, each parent that has a value in
// `given` is fixed to that value via ADT::choose.
//
// @param given         assignment providing (some of) the parent values.
// @param forceComplete if true, a parent without a value is an error; if
//                      false, such a parent is left in the tree.
// @return the restricted decision tree.
// @throws std::runtime_error when forceComplete is set and a parent lookup
//         throws std::out_of_range (after printing `given`).
DiscreteConditional::ADT DiscreteConditional::choose(
    const DiscreteValues& given, bool forceComplete) const {
  ADT restricted(*this);
  for (const Key parent : parents()) {
    try {
      const size_t assigned = given.at(parent);
      // Each choose() removes one variable, so the tree keeps shrinking.
      restricted = restricted.choose(parent, assigned);
    } catch (const std::out_of_range&) {
      if (!forceComplete) continue;  // tolerated: leave this parent free
      given.print("parentsValues: ");
      throw std::runtime_error(
          "DiscreteConditional::choose: parent value missing");
    }
  }
  return restricted;
}
|
|
|
|
/* ************************************************************************** */
|
|
// Condition on the parent values present in `given` and return the result
// as a new conditional.
//
// The tree is restricted with choose(given, false), so parents missing from
// `given` are allowed. The new key list holds all frontals followed by the
// keys at positions nrFrontals()..size()-1 that do not appear in `given`.
//
// @param given assignment providing (some of) the parent values.
// @return shared pointer to the conditioned conditional, with the same
//         number of frontals as this one.
DiscreteConditional::shared_ptr DiscreteConditional::choose(
    const DiscreteValues& given) const {
  ADT adt = choose(given, false);  // P(F|S=given)

  DiscreteKeys remainingKeys;

  // Frontals are always kept.
  for (const Key frontal : frontals()) {
    remainingKeys.emplace_back(frontal, cardinality(frontal));
  }

  // Non-frontal keys are kept only when `given` did not fix them.
  const size_t nrKeys = size();
  for (size_t idx = nrFrontals(); idx < nrKeys; ++idx) {
    const Key candidate = keys_[idx];
    if (given.count(candidate) != 0) continue;
    remainingKeys.emplace_back(candidate, cardinality(candidate));
  }

  return std::make_shared<DiscreteConditional>(nrFrontals(), remainingKeys,
                                               adt);
}
|
|
|
|
/* ************************************************************************** */
|
|
// Likelihood factor on the parents for fixed frontal values.
//
// Starting from a copy of the full decision tree, every frontal variable is
// fixed to its value in `frontalValues` via ADT::choose. The remaining tree
// is wrapped in a DecisionTreeFactor whose keys are this conditional's
// parents.
//
// @param frontalValues assignment that must contain every frontal key.
// @return shared pointer to the factor over the parents.
// @throws std::runtime_error if fixing a frontal variable throws (e.g. its
//         value is missing from `frontalValues`); `frontalValues` is printed
//         first.
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
    const DiscreteValues& frontalValues) const {
  // Get the big decision tree with all the levels, and then go down the
  // branches based on the value of the frontal variables.
  ADT adt(*this);
  for (const Key j : frontals()) {
    try {
      const size_t value = frontalValues.at(j);
      adt = adt.choose(j, value);  // ADT keeps getting smaller.
    } catch (const std::exception&) {
      // The catch stays broad so callers keep seeing std::runtime_error for
      // any failure here; the message now names the function that failed.
      frontalValues.print("frontalValues: ");
      throw std::runtime_error(
          "DiscreteConditional::likelihood: frontal value missing");
    }
  }

  // Convert ADT to factor.
  DiscreteKeys discreteKeys;
  for (const Key j : parents()) {
    discreteKeys.emplace_back(j, this->cardinality(j));
  }
  return std::make_shared<DecisionTreeFactor>(discreteKeys, adt);
}
|
|
|
|
/* ****************************************************************************/
|
|
// Single-value convenience overload of likelihood().
//
// Builds an assignment that maps the first stored key to `frontal` and
// forwards it to likelihood(const DiscreteValues&).
//
// @param frontal value of the single frontal variable.
// @throws std::invalid_argument unless there is exactly one frontal.
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
    size_t frontal) const {
  if (nrFrontals() != 1) {
    throw std::invalid_argument(
        "Single value likelihood can only be invoked on single-variable "
        "conditional");
  }
  DiscreteValues frontalAssignment;
  frontalAssignment.emplace(keys_.front(), frontal);
  return likelihood(frontalAssignment);
}
|
|
|
|
/* ************************************************************************** */
|
|
size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const {
|
|
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
|
|
|
// Initialize
|
|
size_t maxValue = 0;
|
|
double maxP = 0;
|
|
DiscreteValues values = parentsValues;
|
|
|
|
assert(nrFrontals() == 1);
|
|
Key j = firstFrontalKey();
|
|
for (size_t value = 0; value < cardinality(j); value++) {
|
|
values[j] = value;
|
|
double pValueS = (*this)(values);
|
|
// Update MPE solution if better
|
|
if (pValueS > maxP) {
|
|
maxP = pValueS;
|
|
maxValue = value;
|
|
}
|
|
}
|
|
return maxValue;
|
|
}
|
|
|
|
/* ************************************************************************** */
|
|
// Sample the single frontal variable and store the result in `values`.
//
// `values` is passed to sample() as the parent assignment; the sampled value
// is then written under the frontal key.
//
// @param values in/out assignment; must not already contain the frontal key.
// @throws std::invalid_argument if there is not exactly one frontal, or if
//         `values` already holds the frontal key.
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
  if (nrFrontals() != 1) {
    throw std::invalid_argument(
        "DiscreteConditional::sampleInPlace can only be called on single "
        "variable conditionals");
  }

  const Key frontal = firstFrontalKey();
  if (values->count(frontal) != 0) {
    throw std::invalid_argument(
        "DiscreteConditional::sampleInPlace: values already contains j");
  }

  // Draw first, then store, so the draw sees `values` without the frontal.
  const size_t drawn = sample(*values);
  (*values)[frontal] = drawn;
}
|
|
|
|
/* ************************************************************************** */
|
|
size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
|
|
return sample(parentsValues, &kRandomNumberGenerator);
|
|
}
|
|
|
|
/* ************************************************************************** */
|
|
size_t DiscreteConditional::sample(const DiscreteValues& parentsValues,
|
|
std::mt19937_64* rng) const {
|
|
// Get the correct conditional distribution
|
|
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
|
|
|
// TODO(Duy): only works for one key now, seems horribly slow this way
|
|
if (nrFrontals() != 1) {
|
|
throw std::invalid_argument(
|
|
"DiscreteConditional::sample can only be called on single variable "
|
|
"conditionals");
|
|
}
|
|
Key key = firstFrontalKey();
|
|
size_t nj = cardinality(key);
|
|
vector<double> p(nj);
|
|
DiscreteValues frontals;
|
|
for (size_t value = 0; value < nj; value++) {
|
|
frontals[key] = value;
|
|
p[value] = pFS(frontals); // P(F=value|S=parentsValues)
|
|
if (p[value] == 1.0) {
|
|
return value; // shortcut exit
|
|
}
|
|
}
|
|
std::discrete_distribution<size_t> distribution(p.begin(), p.end());
|
|
return distribution(*rng);
|
|
}
|
|
|
|
/* ************************************************************************** */
|
|
size_t DiscreteConditional::sample(size_t parent_value) const {
|
|
return sample(parent_value, &kRandomNumberGenerator);
|
|
}
|
|
|
|
/* ************************************************************************** */
|
|
size_t DiscreteConditional::sample(size_t parent_value,
|
|
std::mt19937_64* rng) const {
|
|
if (nrParents() != 1)
|
|
throw std::invalid_argument(
|
|
"Single value sample() can only be invoked on single-parent "
|
|
"conditional");
|
|
DiscreteValues values;
|
|
values.emplace(keys_.back(), parent_value);
|
|
return sample(values, rng);
|
|
}
|
|
|
|
/* ************************************************************************** */
|
|
size_t DiscreteConditional::sample() const {
|
|
return sample(&kRandomNumberGenerator);
|
|
}
|
|
|
|
/* ************************************************************************** */
|
|
size_t DiscreteConditional::sample(std::mt19937_64* rng) const {
|
|
if (nrParents() != 0)
|
|
throw std::invalid_argument(
|
|
"sample() can only be invoked on no-parent prior");
|
|
DiscreteValues values;
|
|
return sample(values, rng);
|
|
}
|
|
|
|
/* ************************************************************************* */
|
|
// Enumerate all assignments to the frontal variables.
//
// Collects (key, cardinality) for every frontal, reverses that list, and
// returns DiscreteValues::CartesianProduct of the reversed list.
vector<DiscreteValues> DiscreteConditional::frontalAssignments() const {
  vector<pair<Key, size_t>> keyCards;
  for (const Key key : frontals()) {
    keyCards.emplace_back(key, cardinalities_.at(key));
  }
  // CartesianProduct is handed the keys in reverse order.
  std::reverse(keyCards.begin(), keyCards.end());
  return DiscreteValues::CartesianProduct(keyCards);
}
|
|
|
|
/* ************************************************************************* */
|
|
// Enumerate all assignments to parents and frontals together.
//
// Collects (key, cardinality) for the parents first and the frontals second,
// reverses that list, and returns DiscreteValues::CartesianProduct of the
// reversed list.
vector<DiscreteValues> DiscreteConditional::allAssignments() const {
  vector<pair<Key, size_t>> keyCards;
  for (const Key key : parents()) {
    keyCards.emplace_back(key, cardinalities_.at(key));
  }
  for (const Key key : frontals()) {
    keyCards.emplace_back(key, cardinalities_.at(key));
  }
  // CartesianProduct is handed the keys in reverse order.
  std::reverse(keyCards.begin(), keyCards.end());
  return DiscreteValues::CartesianProduct(keyCards);
}
|
|
|
|
/* ************************************************************************* */
|
|
// Print out signature.
|
|
static void streamSignature(const DiscreteConditional& conditional,
|
|
const KeyFormatter& keyFormatter,
|
|
stringstream* ss) {
|
|
*ss << "P(";
|
|
bool first = true;
|
|
for (Key key : conditional.frontals()) {
|
|
if (!first) *ss << ",";
|
|
*ss << keyFormatter(key);
|
|
first = false;
|
|
}
|
|
if (conditional.nrParents() > 0) {
|
|
*ss << "|";
|
|
bool first = true;
|
|
for (Key parent : conditional.parents()) {
|
|
if (!first) *ss << ",";
|
|
*ss << keyFormatter(parent);
|
|
first = false;
|
|
}
|
|
}
|
|
*ss << "):";
|
|
}
|
|
|
|
/* ************************************************************************* */
|
|
std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter,
|
|
const Names& names) const {
|
|
stringstream ss;
|
|
ss << " *";
|
|
streamSignature(*this, keyFormatter, &ss);
|
|
ss << "*\n" << std::endl;
|
|
if (nrParents() == 0) {
|
|
// We have no parents, call factor method.
|
|
ss << BaseFactor::markdown(keyFormatter, names);
|
|
return ss.str();
|
|
}
|
|
|
|
// Print out header.
|
|
ss << "|";
|
|
for (Key parent : parents()) {
|
|
ss << "*" << keyFormatter(parent) << "*|";
|
|
}
|
|
|
|
auto frontalAssignments = this->frontalAssignments();
|
|
for (const auto& a : frontalAssignments) {
|
|
for (auto&& it = beginFrontals(); it != endFrontals(); ++it) {
|
|
size_t index = a.at(*it);
|
|
ss << DiscreteValues::Translate(names, *it, index);
|
|
}
|
|
ss << "|";
|
|
}
|
|
ss << "\n";
|
|
|
|
// Print out separator with alignment hints.
|
|
ss << "|";
|
|
size_t n = frontalAssignments.size();
|
|
for (size_t j = 0; j < nrParents() + n; j++) ss << ":-:|";
|
|
ss << "\n";
|
|
|
|
// Print out all rows.
|
|
size_t count = 0;
|
|
for (const auto& a : allAssignments()) {
|
|
if (count == 0) {
|
|
ss << "|";
|
|
for (auto&& it = beginParents(); it != endParents(); ++it) {
|
|
size_t index = a.at(*it);
|
|
ss << DiscreteValues::Translate(names, *it, index) << "|";
|
|
}
|
|
}
|
|
ss << operator()(a) << "|";
|
|
count = (count + 1) % n;
|
|
if (count == 0) ss << "\n";
|
|
}
|
|
return ss.str();
|
|
}
|
|
|
|
/* ************************************************************************ */
|
|
// Render the conditional as an HTML fragment.
//
// The output starts with a <div> holding the italicized signature. Without
// parents, the table comes from BaseFactor::html and the text is returned as
// is. With parents, a <table class='DiscreteConditional'> is written with
// one header cell per parent and per frontal assignment, and one body row
// per block of frontalAssignments().size() consecutive entries of
// allAssignments().
//
// @param keyFormatter converts keys to their printed form.
// @param names        passed to DiscreteValues::Translate to render values.
// @return the HTML text.
string DiscreteConditional::html(const KeyFormatter& keyFormatter,
                                 const Names& names) const {
  stringstream ss;
  ss << "<div>\n<p> <i>";
  streamSignature(*this, keyFormatter, &ss);
  ss << "</i></p>\n";

  if (nrParents() == 0) {
    // We have no parents, call factor method.
    ss << BaseFactor::html(keyFormatter, names);
    return ss.str();
  }

  // Preamble and start of the header row.
  ss << "<table class='DiscreteConditional'>\n <thead>\n"
     << " <tr>";

  // Header cells: parents first, then one cell per frontal assignment.
  for (const Key parent : parents()) {
    ss << "<th><i>" << keyFormatter(parent) << "</i></th>";
  }
  const auto columns = frontalAssignments();
  for (const auto& column : columns) {
    ss << "<th>";
    for (const Key frontal : frontals()) {
      ss << DiscreteValues::Translate(names, frontal, column.at(frontal));
    }
    ss << "</th>";
  }
  ss << "</tr>\n";

  // Finish header and start body.
  ss << " </thead>\n <tbody>\n";

  // Body: `column` counts cells within the current row; a row starts when
  // it is 0 and ends when it wraps back to 0.
  const size_t n = columns.size();
  size_t column = 0;
  const auto rows = allAssignments();
  for (const auto& assignment : rows) {
    if (column == 0) {
      ss << " <tr>";
      for (const Key parent : parents()) {
        ss << "<th>"
           << DiscreteValues::Translate(names, parent, assignment.at(parent))
           << "</th>";
      }
    }
    ss << "<td>" << (*this)(assignment) << "</td>";  // value
    column = (column + 1) % n;
    if (column == 0) ss << "</tr>\n";
  }

  // Finish up
  ss << " </tbody>\n</table>\n</div>";
  return ss.str();
}
|
|
|
|
/* ************************************************************************* */
|
|
double DiscreteConditional::evaluate(const HybridValues& x) const {
|
|
return this->operator()(x.discrete());
|
|
}
|
|
|
|
/* ************************************************************************* */
|
|
// Forward to BaseFactor::max for the given keys and return its result.
DiscreteFactor::shared_ptr DiscreteConditional::max(
    const Ordering& keys) const {
  DiscreteFactor::shared_ptr maximized = BaseFactor::max(keys);
  return maximized;
}
|
|
|
|
/* ************************************************************************* */
|
|
void DiscreteConditional::prune(size_t maxNrAssignments) {
|
|
// Get as DiscreteConditional so the probabilities are normalized
|
|
DiscreteConditional pruned(nrFrontals(), BaseFactor::prune(maxNrAssignments));
|
|
this->root_ = pruned.root_;
|
|
}
|
|
|
|
/* ************************************************************************ */
|
|
void DiscreteConditional::removeDiscreteModes(const DiscreteValues& given) {
|
|
AlgebraicDecisionTree<Key> tree(*this);
|
|
for (auto [key, value] : given) {
|
|
tree = tree.choose(key, value);
|
|
}
|
|
|
|
// Get the leftover DiscreteKey frontals
|
|
DiscreteKeys frontals;
|
|
std::for_each(this->frontals().begin(), this->frontals().end(), [&](Key key) {
|
|
// Check if frontal key exists in given, if not add to new frontals
|
|
if (given.count(key) == 0) {
|
|
frontals.emplace_back(key, this->cardinalities_.at(key));
|
|
}
|
|
});
|
|
// Get the leftover DiscreteKey parents
|
|
DiscreteKeys parents;
|
|
std::for_each(this->parents().begin(), this->parents().end(), [&](Key key) {
|
|
// Check if parent key exists in given, if not add to new parents
|
|
if (given.count(key) == 0) {
|
|
parents.emplace_back(key, this->cardinalities_.at(key));
|
|
}
|
|
});
|
|
|
|
DiscreteKeys allDkeys(frontals);
|
|
allDkeys.insert(allDkeys.end(), parents.begin(), parents.end());
|
|
|
|
// Update the conditional
|
|
this->keys_ = allDkeys.indices();
|
|
this->cardinalities_ = allDkeys.cardinalities();
|
|
this->root_ = tree.root_;
|
|
this->nrFrontals_ = frontals.size();
|
|
}
|
|
|
|
/* ************************************************************************* */
|
|
// Negative log normalization constant; this implementation always returns 0.0.
double DiscreteConditional::negLogConstant() const { return 0.0; }
|
|
|
|
/* ************************************************************************* */
|
|
|
|
} // namespace gtsam
|