Merge pull request #1919 from borglab/discrete-elimination-refactor
commit
82d0ebc8fe
|
@ -87,6 +87,25 @@ namespace gtsam {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
DiscreteFactor::shared_ptr DecisionTreeFactor::operator/(
|
||||||
|
const DiscreteFactor::shared_ptr& f) const {
|
||||||
|
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
|
||||||
|
// Check if `f` is a TableFactor. If yes, then
|
||||||
|
// convert `this` to a TableFactor which is cheaper.
|
||||||
|
return std::make_shared<TableFactor>(tf->operator/(TableFactor(*this)));
|
||||||
|
|
||||||
|
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||||
|
// If `f` is a DecisionTreeFactor, divide normally.
|
||||||
|
return std::make_shared<DecisionTreeFactor>(this->operator/(*dtf));
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// Else, convert `f` to a DecisionTreeFactor so we can divide
|
||||||
|
return std::make_shared<DecisionTreeFactor>(
|
||||||
|
this->operator/(f->toDecisionTreeFactor()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
|
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
|
||||||
// The use for safe_div is when we divide the product factor by the sum
|
// The use for safe_div is when we divide the product factor by the sum
|
||||||
|
|
|
@ -184,26 +184,30 @@ namespace gtsam {
|
||||||
return apply(f, safe_div);
|
return apply(f, safe_div);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// divide by DiscreteFactor::shared_ptr f (safely)
|
||||||
|
DiscreteFactor::shared_ptr operator/(
|
||||||
|
const DiscreteFactor::shared_ptr& f) const override;
|
||||||
|
|
||||||
/// Convert into a decision tree
|
/// Convert into a decision tree
|
||||||
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
|
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
|
||||||
|
|
||||||
/// Create new factor by summing all values with the same separator values
|
/// Create new factor by summing all values with the same separator values
|
||||||
shared_ptr sum(size_t nrFrontals) const {
|
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
|
||||||
return combine(nrFrontals, Ring::add);
|
return combine(nrFrontals, Ring::add);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create new factor by summing all values with the same separator values
|
/// Create new factor by summing all values with the same separator values
|
||||||
shared_ptr sum(const Ordering& keys) const {
|
DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
|
||||||
return combine(keys, Ring::add);
|
return combine(keys, Ring::add);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create new factor by maximizing over all values with the same separator.
|
/// Create new factor by maximizing over all values with the same separator.
|
||||||
shared_ptr max(size_t nrFrontals) const {
|
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
|
||||||
return combine(nrFrontals, Ring::max);
|
return combine(nrFrontals, Ring::max);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create new factor by maximizing over all values with the same separator.
|
/// Create new factor by maximizing over all values with the same separator.
|
||||||
shared_ptr max(const Ordering& keys) const {
|
DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
|
||||||
return combine(keys, Ring::max);
|
return combine(keys, Ring::max);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -284,6 +288,12 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
DecisionTreeFactor prune(size_t maxNrAssignments) const;
|
DecisionTreeFactor prune(size_t maxNrAssignments) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the number of non-zero values contained in this factor.
|
||||||
|
* It could be much smaller than `prod_{key}(cardinality(key))`.
|
||||||
|
*/
|
||||||
|
uint64_t nrValues() const override { return nrLeaves(); }
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Wrapper support
|
/// @name Wrapper support
|
||||||
/// @{
|
/// @{
|
||||||
|
|
|
@ -24,13 +24,13 @@
|
||||||
#include <gtsam/hybrid/HybridValues.h>
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <cassert>
|
||||||
#include <random>
|
#include <random>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <cassert>
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using std::pair;
|
using std::pair;
|
||||||
|
@ -44,8 +44,9 @@ template class GTSAM_EXPORT
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
|
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
|
||||||
const DecisionTreeFactor& f)
|
const DiscreteFactor& f)
|
||||||
: BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {}
|
: BaseFactor((f / f.sum(nrFrontals))->toDecisionTreeFactor()),
|
||||||
|
BaseConditional(nrFrontals) {}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
|
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
|
||||||
|
@ -150,11 +151,11 @@ void DiscreteConditional::print(const string& s,
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
bool DiscreteConditional::equals(const DiscreteFactor& other,
|
bool DiscreteConditional::equals(const DiscreteFactor& other,
|
||||||
double tol) const {
|
double tol) const {
|
||||||
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
|
if (!dynamic_cast<const BaseFactor*>(&other)) {
|
||||||
return false;
|
return false;
|
||||||
} else {
|
} else {
|
||||||
const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other));
|
const BaseFactor& f(static_cast<const BaseFactor&>(other));
|
||||||
return DecisionTreeFactor::equals(f, tol);
|
return BaseFactor::equals(f, tol);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -375,7 +376,7 @@ std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter,
|
||||||
ss << "*\n" << std::endl;
|
ss << "*\n" << std::endl;
|
||||||
if (nrParents() == 0) {
|
if (nrParents() == 0) {
|
||||||
// We have no parents, call factor method.
|
// We have no parents, call factor method.
|
||||||
ss << DecisionTreeFactor::markdown(keyFormatter, names);
|
ss << BaseFactor::markdown(keyFormatter, names);
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -427,7 +428,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,
|
||||||
ss << "</i></p>\n";
|
ss << "</i></p>\n";
|
||||||
if (nrParents() == 0) {
|
if (nrParents() == 0) {
|
||||||
// We have no parents, call factor method.
|
// We have no parents, call factor method.
|
||||||
ss << DecisionTreeFactor::html(keyFormatter, names);
|
ss << BaseFactor::html(keyFormatter, names);
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -475,7 +476,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
double DiscreteConditional::evaluate(const HybridValues& x) const {
|
double DiscreteConditional::evaluate(const HybridValues& x) const {
|
||||||
return this->evaluate(x.discrete());
|
return this->operator()(x.discrete());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -54,7 +54,7 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
DiscreteConditional() {}
|
DiscreteConditional() {}
|
||||||
|
|
||||||
/// Construct from factor, taking the first `nFrontals` keys as frontals.
|
/// Construct from factor, taking the first `nFrontals` keys as frontals.
|
||||||
DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f);
|
DiscreteConditional(size_t nFrontals, const DiscreteFactor& f);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first
|
* Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
||||||
#include <gtsam/discrete/DiscreteValues.h>
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
#include <gtsam/inference/Factor.h>
|
#include <gtsam/inference/Factor.h>
|
||||||
|
#include <gtsam/inference/Ordering.h>
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
@ -139,8 +140,30 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
|
||||||
virtual DiscreteFactor::shared_ptr multiply(
|
virtual DiscreteFactor::shared_ptr multiply(
|
||||||
const DiscreteFactor::shared_ptr& df) const = 0;
|
const DiscreteFactor::shared_ptr& df) const = 0;
|
||||||
|
|
||||||
|
/// divide by DiscreteFactor::shared_ptr f (safely)
|
||||||
|
virtual DiscreteFactor::shared_ptr operator/(
|
||||||
|
const DiscreteFactor::shared_ptr& df) const = 0;
|
||||||
|
|
||||||
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
|
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
|
||||||
|
|
||||||
|
/// Create new factor by summing all values with the same separator values
|
||||||
|
virtual DiscreteFactor::shared_ptr sum(size_t nrFrontals) const = 0;
|
||||||
|
|
||||||
|
/// Create new factor by summing all values with the same separator values
|
||||||
|
virtual DiscreteFactor::shared_ptr sum(const Ordering& keys) const = 0;
|
||||||
|
|
||||||
|
/// Create new factor by maximizing over all values with the same separator.
|
||||||
|
virtual DiscreteFactor::shared_ptr max(size_t nrFrontals) const = 0;
|
||||||
|
|
||||||
|
/// Create new factor by maximizing over all values with the same separator.
|
||||||
|
virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const = 0;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the number of non-zero values contained in this factor.
|
||||||
|
* It could be much smaller than `prod_{key}(cardinality(key))`.
|
||||||
|
*/
|
||||||
|
virtual uint64_t nrValues() const = 0;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Wrapper support
|
/// @name Wrapper support
|
||||||
/// @{
|
/// @{
|
||||||
|
|
|
@ -64,7 +64,7 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor DiscreteFactorGraph::product() const {
|
DiscreteFactor::shared_ptr DiscreteFactorGraph::product() const {
|
||||||
DiscreteFactor::shared_ptr result;
|
DiscreteFactor::shared_ptr result;
|
||||||
for (auto it = this->begin(); it != this->end(); ++it) {
|
for (auto it = this->begin(); it != this->end(); ++it) {
|
||||||
if (*it) {
|
if (*it) {
|
||||||
|
@ -76,7 +76,7 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result->toDecisionTreeFactor();
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
|
@ -122,20 +122,20 @@ namespace gtsam {
|
||||||
* @brief Multiply all the `factors`.
|
* @brief Multiply all the `factors`.
|
||||||
*
|
*
|
||||||
* @param factors The factors to multiply as a DiscreteFactorGraph.
|
* @param factors The factors to multiply as a DiscreteFactorGraph.
|
||||||
* @return DecisionTreeFactor
|
* @return DiscreteFactor::shared_ptr
|
||||||
*/
|
*/
|
||||||
static DecisionTreeFactor DiscreteProduct(
|
static DiscreteFactor::shared_ptr DiscreteProduct(
|
||||||
const DiscreteFactorGraph& factors) {
|
const DiscreteFactorGraph& factors) {
|
||||||
// PRODUCT: multiply all factors
|
// PRODUCT: multiply all factors
|
||||||
gttic(product);
|
gttic(product);
|
||||||
DecisionTreeFactor product = factors.product();
|
DiscreteFactor::shared_ptr product = factors.product();
|
||||||
gttoc(product);
|
gttoc(product);
|
||||||
|
|
||||||
// Max over all the potentials by pretending all keys are frontal:
|
// Max over all the potentials by pretending all keys are frontal:
|
||||||
auto denominator = product.max(product.size());
|
auto denominator = product->max(product->size());
|
||||||
|
|
||||||
// Normalize the product factor to prevent underflow.
|
// Normalize the product factor to prevent underflow.
|
||||||
product = product / (*denominator);
|
product = product->operator/(denominator);
|
||||||
|
|
||||||
return product;
|
return product;
|
||||||
}
|
}
|
||||||
|
@ -145,25 +145,25 @@ namespace gtsam {
|
||||||
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
|
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
|
||||||
EliminateForMPE(const DiscreteFactorGraph& factors,
|
EliminateForMPE(const DiscreteFactorGraph& factors,
|
||||||
const Ordering& frontalKeys) {
|
const Ordering& frontalKeys) {
|
||||||
DecisionTreeFactor product = DiscreteProduct(factors);
|
DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
|
||||||
|
|
||||||
// max out frontals, this is the factor on the separator
|
// max out frontals, this is the factor on the separator
|
||||||
gttic(max);
|
gttic(max);
|
||||||
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
|
DiscreteFactor::shared_ptr max = product->max(frontalKeys);
|
||||||
gttoc(max);
|
gttoc(max);
|
||||||
|
|
||||||
// Ordering keys for the conditional so that frontalKeys are really in front
|
// Ordering keys for the conditional so that frontalKeys are really in front
|
||||||
DiscreteKeys orderedKeys;
|
DiscreteKeys orderedKeys;
|
||||||
for (auto&& key : frontalKeys)
|
for (auto&& key : frontalKeys)
|
||||||
orderedKeys.emplace_back(key, product.cardinality(key));
|
orderedKeys.emplace_back(key, product->cardinality(key));
|
||||||
for (auto&& key : max->keys())
|
for (auto&& key : max->keys())
|
||||||
orderedKeys.emplace_back(key, product.cardinality(key));
|
orderedKeys.emplace_back(key, product->cardinality(key));
|
||||||
|
|
||||||
// Make lookup with product
|
// Make lookup with product
|
||||||
gttic(lookup);
|
gttic(lookup);
|
||||||
size_t nrFrontals = frontalKeys.size();
|
size_t nrFrontals = frontalKeys.size();
|
||||||
auto lookup =
|
auto lookup = std::make_shared<DiscreteLookupTable>(
|
||||||
std::make_shared<DiscreteLookupTable>(nrFrontals, orderedKeys, product);
|
nrFrontals, orderedKeys, product->toDecisionTreeFactor());
|
||||||
gttoc(lookup);
|
gttoc(lookup);
|
||||||
|
|
||||||
return {std::dynamic_pointer_cast<DiscreteConditional>(lookup), max};
|
return {std::dynamic_pointer_cast<DiscreteConditional>(lookup), max};
|
||||||
|
@ -223,11 +223,11 @@ namespace gtsam {
|
||||||
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
|
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
|
||||||
EliminateDiscrete(const DiscreteFactorGraph& factors,
|
EliminateDiscrete(const DiscreteFactorGraph& factors,
|
||||||
const Ordering& frontalKeys) {
|
const Ordering& frontalKeys) {
|
||||||
DecisionTreeFactor product = DiscreteProduct(factors);
|
DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
|
||||||
|
|
||||||
// sum out frontals, this is the factor on the separator
|
// sum out frontals, this is the factor on the separator
|
||||||
gttic(sum);
|
gttic(sum);
|
||||||
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
|
DiscreteFactor::shared_ptr sum = product->sum(frontalKeys);
|
||||||
gttoc(sum);
|
gttoc(sum);
|
||||||
|
|
||||||
// Ordering keys for the conditional so that frontalKeys are really in front
|
// Ordering keys for the conditional so that frontalKeys are really in front
|
||||||
|
@ -239,8 +239,9 @@ namespace gtsam {
|
||||||
|
|
||||||
// now divide product/sum to get conditional
|
// now divide product/sum to get conditional
|
||||||
gttic(divide);
|
gttic(divide);
|
||||||
auto conditional =
|
auto conditional = std::make_shared<DiscreteConditional>(
|
||||||
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
|
product->toDecisionTreeFactor(), sum->toDecisionTreeFactor(),
|
||||||
|
orderedKeys);
|
||||||
gttoc(divide);
|
gttoc(divide);
|
||||||
|
|
||||||
return {conditional, sum};
|
return {conditional, sum};
|
||||||
|
|
|
@ -134,6 +134,7 @@ class GTSAM_EXPORT DiscreteFactorGraph
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
//TODO(Varun): Make compatible with TableFactor
|
||||||
/** Add a decision-tree factor */
|
/** Add a decision-tree factor */
|
||||||
template <typename... Args>
|
template <typename... Args>
|
||||||
void add(Args&&... args) {
|
void add(Args&&... args) {
|
||||||
|
@ -147,7 +148,7 @@ class GTSAM_EXPORT DiscreteFactorGraph
|
||||||
DiscreteKeys discreteKeys() const;
|
DiscreteKeys discreteKeys() const;
|
||||||
|
|
||||||
/** return product of all factors as a single factor */
|
/** return product of all factors as a single factor */
|
||||||
DecisionTreeFactor product() const;
|
DiscreteFactor::shared_ptr product() const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Evaluates the factor graph given values, returns the joint probability of
|
* Evaluates the factor graph given values, returns the joint probability of
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||||
|
#include <gtsam/discrete/TableFactor.h>
|
||||||
#include <gtsam/inference/BayesNet.h>
|
#include <gtsam/inference/BayesNet.h>
|
||||||
#include <gtsam/inference/FactorGraph.h>
|
#include <gtsam/inference/FactorGraph.h>
|
||||||
|
|
||||||
|
@ -54,6 +55,18 @@ class GTSAM_EXPORT DiscreteLookupTable : public DiscreteConditional {
|
||||||
const ADT& potentials)
|
const ADT& potentials)
|
||||||
: DiscreteConditional(nFrontals, keys, potentials) {}
|
: DiscreteConditional(nFrontals, keys, potentials) {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Construct a new Discrete Lookup Table object
|
||||||
|
*
|
||||||
|
* @param nFrontals number of frontal variables
|
||||||
|
* @param keys a sorted list of gtsam::Keys
|
||||||
|
* @param potentials Discrete potentials as a TableFactor.
|
||||||
|
*/
|
||||||
|
DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys,
|
||||||
|
const TableFactor& potentials)
|
||||||
|
: DiscreteConditional(nFrontals, keys,
|
||||||
|
potentials.toDecisionTreeFactor()) {}
|
||||||
|
|
||||||
/// GTSAM-style print
|
/// GTSAM-style print
|
||||||
void print(
|
void print(
|
||||||
const std::string& s = "Discrete Lookup Table: ",
|
const std::string& s = "Discrete Lookup Table: ",
|
||||||
|
|
|
@ -280,6 +280,20 @@ DiscreteFactor::shared_ptr TableFactor::multiply(
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
DiscreteFactor::shared_ptr TableFactor::operator/(
|
||||||
|
const DiscreteFactor::shared_ptr& f) const {
|
||||||
|
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
|
||||||
|
return std::make_shared<TableFactor>(this->operator/(*tf));
|
||||||
|
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||||
|
return std::make_shared<TableFactor>(
|
||||||
|
this->operator/(TableFactor(f->discreteKeys(), *dtf)));
|
||||||
|
} else {
|
||||||
|
TableFactor divisor(f->toDecisionTreeFactor());
|
||||||
|
return std::make_shared<TableFactor>(this->operator/(divisor));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
|
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
|
||||||
DiscreteKeys dkeys = discreteKeys();
|
DiscreteKeys dkeys = discreteKeys();
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
#include <gtsam/discrete/DiscreteFactor.h>
|
#include <gtsam/discrete/DiscreteFactor.h>
|
||||||
#include <gtsam/discrete/DiscreteKey.h>
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
#include <gtsam/discrete/Ring.h>
|
#include <gtsam/discrete/Ring.h>
|
||||||
|
@ -202,6 +203,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
return apply(f, safe_div);
|
return apply(f, safe_div);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// divide by DiscreteFactor::shared_ptr f (safely)
|
||||||
|
DiscreteFactor::shared_ptr operator/(
|
||||||
|
const DiscreteFactor::shared_ptr& f) const override;
|
||||||
|
|
||||||
/// Convert into a decisiontree
|
/// Convert into a decisiontree
|
||||||
DecisionTreeFactor toDecisionTreeFactor() const override;
|
DecisionTreeFactor toDecisionTreeFactor() const override;
|
||||||
|
|
||||||
|
@ -210,22 +215,22 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
DiscreteKeys parent_keys) const;
|
DiscreteKeys parent_keys) const;
|
||||||
|
|
||||||
/// Create new factor by summing all values with the same separator values
|
/// Create new factor by summing all values with the same separator values
|
||||||
shared_ptr sum(size_t nrFrontals) const {
|
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
|
||||||
return combine(nrFrontals, Ring::add);
|
return combine(nrFrontals, Ring::add);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create new factor by summing all values with the same separator values
|
/// Create new factor by summing all values with the same separator values
|
||||||
shared_ptr sum(const Ordering& keys) const {
|
DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
|
||||||
return combine(keys, Ring::add);
|
return combine(keys, Ring::add);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create new factor by maximizing over all values with the same separator.
|
/// Create new factor by maximizing over all values with the same separator.
|
||||||
shared_ptr max(size_t nrFrontals) const {
|
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
|
||||||
return combine(nrFrontals, Ring::max);
|
return combine(nrFrontals, Ring::max);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create new factor by maximizing over all values with the same separator.
|
/// Create new factor by maximizing over all values with the same separator.
|
||||||
shared_ptr max(const Ordering& keys) const {
|
DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
|
||||||
return combine(keys, Ring::max);
|
return combine(keys, Ring::max);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -330,6 +335,12 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
*/
|
*/
|
||||||
TableFactor prune(size_t maxNrAssignments) const;
|
TableFactor prune(size_t maxNrAssignments) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the number of non-zero values contained in this factor.
|
||||||
|
* It could be much smaller than `prod_{key}(cardinality(key))`.
|
||||||
|
*/
|
||||||
|
uint64_t nrValues() const override { return sparse_table_.nonZeros(); }
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Wrapper support
|
/// @name Wrapper support
|
||||||
/// @{
|
/// @{
|
||||||
|
|
|
@ -138,15 +138,18 @@ TEST(DecisionTreeFactor, sum_max) {
|
||||||
DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6");
|
DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6");
|
||||||
|
|
||||||
DecisionTreeFactor expected(v1, "9 12");
|
DecisionTreeFactor expected(v1, "9 12");
|
||||||
DecisionTreeFactor::shared_ptr actual = f1.sum(1);
|
auto actual = std::dynamic_pointer_cast<DecisionTreeFactor>(f1.sum(1));
|
||||||
|
CHECK(actual);
|
||||||
CHECK(assert_equal(expected, *actual, 1e-5));
|
CHECK(assert_equal(expected, *actual, 1e-5));
|
||||||
|
|
||||||
DecisionTreeFactor expected2(v1, "5 6");
|
DecisionTreeFactor expected2(v1, "5 6");
|
||||||
DecisionTreeFactor::shared_ptr actual2 = f1.max(1);
|
auto actual2 = std::dynamic_pointer_cast<DecisionTreeFactor>(f1.max(1));
|
||||||
|
CHECK(actual2);
|
||||||
CHECK(assert_equal(expected2, *actual2));
|
CHECK(assert_equal(expected2, *actual2));
|
||||||
|
|
||||||
DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6");
|
DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6");
|
||||||
DecisionTreeFactor::shared_ptr actual22 = f2.sum(1);
|
auto actual22 = std::dynamic_pointer_cast<DecisionTreeFactor>(f2.sum(1));
|
||||||
|
CHECK(actual22);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -46,7 +46,7 @@ TEST(DiscreteConditional, constructors) {
|
||||||
DecisionTreeFactor f2(
|
DecisionTreeFactor f2(
|
||||||
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
||||||
DiscreteConditional actual2(1, f2);
|
DiscreteConditional actual2(1, f2);
|
||||||
DecisionTreeFactor expected2 = f2 / *f2.sum(1);
|
DecisionTreeFactor expected2 = f2 / f2.sum(1)->toDecisionTreeFactor();
|
||||||
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
|
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
|
||||||
|
|
||||||
std::vector<double> probs{0.2, 0.5, 0.3, 0.6, 0.4, 0.7, 0.25, 0.55, 0.35, 0.65, 0.45, 0.75};
|
std::vector<double> probs{0.2, 0.5, 0.3, 0.6, 0.4, 0.7, 0.25, 0.55, 0.35, 0.65, 0.45, 0.75};
|
||||||
|
@ -70,7 +70,7 @@ TEST(DiscreteConditional, constructors_alt_interface) {
|
||||||
DecisionTreeFactor f2(
|
DecisionTreeFactor f2(
|
||||||
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
||||||
DiscreteConditional actual2(1, f2);
|
DiscreteConditional actual2(1, f2);
|
||||||
DecisionTreeFactor expected2 = f2 / *f2.sum(1);
|
DecisionTreeFactor expected2 = f2 / f2.sum(1)->toDecisionTreeFactor();
|
||||||
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
|
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -94,7 +94,7 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) {
|
||||||
EXPECT_DOUBLES_EQUAL( 1.944, graph(values), 1e-9);
|
EXPECT_DOUBLES_EQUAL( 1.944, graph(values), 1e-9);
|
||||||
|
|
||||||
// Check if graph product works
|
// Check if graph product works
|
||||||
DecisionTreeFactor product = graph.product();
|
DecisionTreeFactor product = graph.product()->toDecisionTreeFactor();
|
||||||
EXPECT_DOUBLES_EQUAL( 1.944, product(values), 1e-9);
|
EXPECT_DOUBLES_EQUAL( 1.944, product(values), 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -117,9 +117,9 @@ TEST(DiscreteFactorGraph, test) {
|
||||||
*std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr);
|
*std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr);
|
||||||
|
|
||||||
// Normalize newFactor by max for comparison with expected
|
// Normalize newFactor by max for comparison with expected
|
||||||
auto normalizer = newFactor.max(newFactor.size());
|
auto denominator = newFactor.max(newFactor.size())->toDecisionTreeFactor();
|
||||||
|
|
||||||
newFactor = newFactor / *normalizer;
|
newFactor = newFactor / denominator;
|
||||||
|
|
||||||
// Check Conditional
|
// Check Conditional
|
||||||
CHECK(conditional);
|
CHECK(conditional);
|
||||||
|
@ -131,9 +131,10 @@ TEST(DiscreteFactorGraph, test) {
|
||||||
CHECK(&newFactor);
|
CHECK(&newFactor);
|
||||||
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
|
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
|
||||||
// Normalize by max.
|
// Normalize by max.
|
||||||
normalizer = expectedFactor.max(expectedFactor.size());
|
denominator =
|
||||||
// Ensure normalizer is correct.
|
expectedFactor.max(expectedFactor.size())->toDecisionTreeFactor();
|
||||||
expectedFactor = expectedFactor / *normalizer;
|
// Ensure denominator is correct.
|
||||||
|
expectedFactor = expectedFactor / denominator;
|
||||||
EXPECT(assert_equal(expectedFactor, newFactor));
|
EXPECT(assert_equal(expectedFactor, newFactor));
|
||||||
|
|
||||||
// Test using elimination tree
|
// Test using elimination tree
|
||||||
|
|
|
@ -194,15 +194,17 @@ TEST(TableFactor, Conversion) {
|
||||||
TEST(TableFactor, Empty) {
|
TEST(TableFactor, Empty) {
|
||||||
DiscreteKey X(1, 2);
|
DiscreteKey X(1, 2);
|
||||||
|
|
||||||
TableFactor single = *TableFactor({X}, "1 1").sum(1);
|
auto single = TableFactor({X}, "1 1").sum(1);
|
||||||
// Should not throw a segfault
|
// Should not throw a segfault
|
||||||
EXPECT(assert_equal(*DecisionTreeFactor(X, "1 1").sum(1),
|
auto expected_single = DecisionTreeFactor(X, "1 1").sum(1);
|
||||||
single.toDecisionTreeFactor()));
|
EXPECT(assert_equal(expected_single->toDecisionTreeFactor(),
|
||||||
|
single->toDecisionTreeFactor()));
|
||||||
|
|
||||||
TableFactor empty = *TableFactor({X}, "0 0").sum(1);
|
auto empty = TableFactor({X}, "0 0").sum(1);
|
||||||
// Should not throw a segfault
|
// Should not throw a segfault
|
||||||
EXPECT(assert_equal(*DecisionTreeFactor(X, "0 0").sum(1),
|
auto expected_empty = DecisionTreeFactor(X, "0 0").sum(1);
|
||||||
empty.toDecisionTreeFactor()));
|
EXPECT(assert_equal(expected_empty->toDecisionTreeFactor(),
|
||||||
|
empty->toDecisionTreeFactor()));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -303,15 +305,18 @@ TEST(TableFactor, sum_max) {
|
||||||
TableFactor f1(v0 & v1, "1 2 3 4 5 6");
|
TableFactor f1(v0 & v1, "1 2 3 4 5 6");
|
||||||
|
|
||||||
TableFactor expected(v1, "9 12");
|
TableFactor expected(v1, "9 12");
|
||||||
TableFactor::shared_ptr actual = f1.sum(1);
|
auto actual = std::dynamic_pointer_cast<TableFactor>(f1.sum(1));
|
||||||
|
CHECK(actual);
|
||||||
CHECK(assert_equal(expected, *actual, 1e-5));
|
CHECK(assert_equal(expected, *actual, 1e-5));
|
||||||
|
|
||||||
TableFactor expected2(v1, "5 6");
|
TableFactor expected2(v1, "5 6");
|
||||||
TableFactor::shared_ptr actual2 = f1.max(1);
|
auto actual2 = std::dynamic_pointer_cast<TableFactor>(f1.max(1));
|
||||||
|
CHECK(actual2);
|
||||||
CHECK(assert_equal(expected2, *actual2));
|
CHECK(assert_equal(expected2, *actual2));
|
||||||
|
|
||||||
TableFactor f2(v1 & v0, "1 2 3 4 5 6");
|
TableFactor f2(v1 & v0, "1 2 3 4 5 6");
|
||||||
TableFactor::shared_ptr actual22 = f2.sum(1);
|
auto actual22 = std::dynamic_pointer_cast<TableFactor>(f2.sum(1));
|
||||||
|
CHECK(actual22);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -68,7 +68,8 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor {
|
||||||
/*
|
/*
|
||||||
* Ensure Arc-consistency by checking every possible value of domain j.
|
* Ensure Arc-consistency by checking every possible value of domain j.
|
||||||
* @param j domain to be checked
|
* @param j domain to be checked
|
||||||
* @param (in/out) domains all domains, but only domains->at(j) will be checked.
|
* @param (in/out) domains all domains, but only domains->at(j) will be
|
||||||
|
* checked.
|
||||||
* @return true if domains->at(j) was changed, false otherwise.
|
* @return true if domains->at(j) was changed, false otherwise.
|
||||||
*/
|
*/
|
||||||
virtual bool ensureArcConsistency(Key j, Domains* domains) const = 0;
|
virtual bool ensureArcConsistency(Key j, Domains* domains) const = 0;
|
||||||
|
@ -86,6 +87,31 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor {
|
||||||
this->operator*(df->toDecisionTreeFactor()));
|
this->operator*(df->toDecisionTreeFactor()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// divide by DiscreteFactor::shared_ptr f (safely)
|
||||||
|
DiscreteFactor::shared_ptr operator/(
|
||||||
|
const DiscreteFactor::shared_ptr& df) const override {
|
||||||
|
return this->toDecisionTreeFactor() / df;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the number of non-zero values contained in this factor.
|
||||||
|
uint64_t nrValues() const override { return 1; };
|
||||||
|
|
||||||
|
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
|
||||||
|
return toDecisionTreeFactor().sum(nrFrontals);
|
||||||
|
}
|
||||||
|
|
||||||
|
DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
|
||||||
|
return toDecisionTreeFactor().sum(keys);
|
||||||
|
}
|
||||||
|
|
||||||
|
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
|
||||||
|
return toDecisionTreeFactor().max(nrFrontals);
|
||||||
|
}
|
||||||
|
|
||||||
|
DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
|
||||||
|
return toDecisionTreeFactor().max(keys);
|
||||||
|
}
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Wrapper support
|
/// @name Wrapper support
|
||||||
/// @{
|
/// @{
|
||||||
|
|
|
@ -49,7 +49,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
|
||||||
/// Erase a value, non const :-(
|
/// Erase a value, non const :-(
|
||||||
void erase(size_t value) { values_.erase(value); }
|
void erase(size_t value) { values_.erase(value); }
|
||||||
|
|
||||||
size_t nrValues() const { return values_.size(); }
|
uint64_t nrValues() const override { return values_.size(); }
|
||||||
|
|
||||||
bool isSingleton() const { return nrValues() == 1; }
|
bool isSingleton() const { return nrValues() == 1; }
|
||||||
|
|
||||||
|
|
|
@ -124,7 +124,7 @@ TEST(CSP, allInOne) {
|
||||||
EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9);
|
EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9);
|
||||||
|
|
||||||
// Just for fun, create the product and check it
|
// Just for fun, create the product and check it
|
||||||
DecisionTreeFactor product = csp.product();
|
DecisionTreeFactor product = csp.product()->toDecisionTreeFactor();
|
||||||
// product.dot("product");
|
// product.dot("product");
|
||||||
DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0");
|
DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0");
|
||||||
EXPECT(assert_equal(expectedProduct, product));
|
EXPECT(assert_equal(expectedProduct, product));
|
||||||
|
|
|
@ -113,7 +113,7 @@ TEST(schedulingExample, test) {
|
||||||
EXPECT(assert_equal(expected, (DiscreteFactorGraph)s));
|
EXPECT(assert_equal(expected, (DiscreteFactorGraph)s));
|
||||||
|
|
||||||
// Do brute force product and output that to file
|
// Do brute force product and output that to file
|
||||||
DecisionTreeFactor product = s.product();
|
DecisionTreeFactor product = s.product()->toDecisionTreeFactor();
|
||||||
// product.dot("scheduling", false);
|
// product.dot("scheduling", false);
|
||||||
|
|
||||||
// Do exact inference
|
// Do exact inference
|
||||||
|
|
Loading…
Reference in New Issue