Merge branch 'develop' into hybrid-timing

release/4.3a0
Varun Agrawal 2025-01-06 20:38:10 -05:00
commit edef8c8481
18 changed files with 342 additions and 69 deletions

View File

@ -18,9 +18,10 @@
*/ */
#include <gtsam/base/FastSet.h> #include <gtsam/base/FastSet.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <utility> #include <utility>
@ -62,6 +63,49 @@ namespace gtsam {
return error(values.discrete()); return error(values.discrete());
} }
/* ************************************************************************ */
DiscreteFactor::shared_ptr DecisionTreeFactor::multiply(
const DiscreteFactor::shared_ptr& f) const {
DiscreteFactor::shared_ptr result;
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
// If f is a TableFactor, we convert `this` to a TableFactor since this
// conversion is cheaper than converting `f` to a DecisionTreeFactor. We
// then return a TableFactor.
result = std::make_shared<TableFactor>((*tf) * TableFactor(*this));
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
// If `f` is a DecisionTreeFactor, simply call operator*.
result = std::make_shared<DecisionTreeFactor>(this->operator*(*dtf));
} else {
// Simulate double dispatch in C++
// Useful for other classes which inherit from DiscreteFactor and have
// only `operator*(DecisionTreeFactor)` defined. Thus, other classes don't
// need to be updated.
result = std::make_shared<DecisionTreeFactor>(f->operator*(*this));
}
return result;
}
/* ************************************************************************ */
DiscreteFactor::shared_ptr DecisionTreeFactor::operator/(
const DiscreteFactor::shared_ptr& f) const {
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
// Check if `f` is a TableFactor. If yes, then
// convert `this` to a TableFactor which is cheaper.
return std::make_shared<TableFactor>(tf->operator/(TableFactor(*this)));
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
// If `f` is a DecisionTreeFactor, divide normally.
return std::make_shared<DecisionTreeFactor>(this->operator/(*dtf));
} else {
// Else, convert `f` to a DecisionTreeFactor so we can divide
return std::make_shared<DecisionTreeFactor>(
this->operator/(f->toDecisionTreeFactor()));
}
}
/* ************************************************************************ */ /* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) { double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum // The use for safe_div is when we divide the product factor by the sum

View File

@ -147,6 +147,23 @@ namespace gtsam {
/// Calculate error for DiscreteValues `x`, is -log(probability). /// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const override; double error(const DiscreteValues& values) const override;
/**
* @brief Multiply factors, DiscreteFactor::shared_ptr edition.
*
* This method accepts `DiscreteFactor::shared_ptr` and uses dynamic
* dispatch and specializations to perform the most efficient
* multiplication.
*
* While converting a DecisionTreeFactor to a TableFactor is efficient, the
* reverse is not. Hence we specialize the code to return a TableFactor if
* `f` is a TableFactor, and DecisionTreeFactor otherwise.
*
* @param f The factor to multiply with.
* @return DiscreteFactor::shared_ptr
*/
virtual DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& f) const override;
/// multiply two factors /// multiply two factors
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
return apply(f, Ring::mul); return apply(f, Ring::mul);
@ -154,31 +171,43 @@ namespace gtsam {
static double safe_div(const double& a, const double& b); static double safe_div(const double& a, const double& b);
/// divide by factor f (safely) /**
* @brief Divide by factor f (safely).
* Division of a factor \f$f(x, y)\f$ by another factor \f$g(y, z)\f$
* results in a function which involves all keys
* \f$(\frac{f}{g})(x, y, z) = f(x, y) / g(y, z)\f$
*
* @param f The DecisinTreeFactor to divide by.
* @return DecisionTreeFactor
*/
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const { DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
return apply(f, safe_div); return apply(f, safe_div);
} }
/// divide by DiscreteFactor::shared_ptr f (safely)
DiscreteFactor::shared_ptr operator/(
const DiscreteFactor::shared_ptr& f) const override;
/// Convert into a decision tree /// Convert into a decision tree
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; } DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
/// Create new factor by summing all values with the same separator values /// Create new factor by summing all values with the same separator values
shared_ptr sum(size_t nrFrontals) const { DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::add); return combine(nrFrontals, Ring::add);
} }
/// Create new factor by summing all values with the same separator values /// Create new factor by summing all values with the same separator values
shared_ptr sum(const Ordering& keys) const { DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
return combine(keys, Ring::add); return combine(keys, Ring::add);
} }
/// Create new factor by maximizing over all values with the same separator. /// Create new factor by maximizing over all values with the same separator.
shared_ptr max(size_t nrFrontals) const { DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::max); return combine(nrFrontals, Ring::max);
} }
/// Create new factor by maximizing over all values with the same separator. /// Create new factor by maximizing over all values with the same separator.
shared_ptr max(const Ordering& keys) const { DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
return combine(keys, Ring::max); return combine(keys, Ring::max);
} }
@ -259,6 +288,12 @@ namespace gtsam {
*/ */
DecisionTreeFactor prune(size_t maxNrAssignments) const; DecisionTreeFactor prune(size_t maxNrAssignments) const;
/**
* Get the number of non-zero values contained in this factor.
* It could be much smaller than `prod_{key}(cardinality(key))`.
*/
uint64_t nrValues() const override { return nrLeaves(); }
/// @} /// @}
/// @name Wrapper support /// @name Wrapper support
/// @{ /// @{

View File

@ -44,8 +44,9 @@ template class GTSAM_EXPORT
/* ************************************************************************** */ /* ************************************************************************** */
DiscreteConditional::DiscreteConditional(const size_t nrFrontals, DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
const DecisionTreeFactor& f) const DiscreteFactor& f)
: BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {} : BaseFactor((f / f.sum(nrFrontals))->toDecisionTreeFactor()),
BaseConditional(nrFrontals) {}
/* ************************************************************************** */ /* ************************************************************************** */
DiscreteConditional::DiscreteConditional(size_t nrFrontals, DiscreteConditional::DiscreteConditional(size_t nrFrontals,
@ -150,11 +151,11 @@ void DiscreteConditional::print(const string& s,
/* ************************************************************************** */ /* ************************************************************************** */
bool DiscreteConditional::equals(const DiscreteFactor& other, bool DiscreteConditional::equals(const DiscreteFactor& other,
double tol) const { double tol) const {
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) { if (!dynamic_cast<const BaseFactor*>(&other)) {
return false; return false;
} else { } else {
const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other)); const BaseFactor& f(static_cast<const BaseFactor&>(other));
return DecisionTreeFactor::equals(f, tol); return BaseFactor::equals(f, tol);
} }
} }
@ -375,7 +376,7 @@ std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter,
ss << "*\n" << std::endl; ss << "*\n" << std::endl;
if (nrParents() == 0) { if (nrParents() == 0) {
// We have no parents, call factor method. // We have no parents, call factor method.
ss << DecisionTreeFactor::markdown(keyFormatter, names); ss << BaseFactor::markdown(keyFormatter, names);
return ss.str(); return ss.str();
} }
@ -427,7 +428,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,
ss << "</i></p>\n"; ss << "</i></p>\n";
if (nrParents() == 0) { if (nrParents() == 0) {
// We have no parents, call factor method. // We have no parents, call factor method.
ss << DecisionTreeFactor::html(keyFormatter, names); ss << BaseFactor::html(keyFormatter, names);
return ss.str(); return ss.str();
} }
@ -475,7 +476,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,
/* ************************************************************************* */ /* ************************************************************************* */
double DiscreteConditional::evaluate(const HybridValues& x) const { double DiscreteConditional::evaluate(const HybridValues& x) const {
return this->evaluate(x.discrete()); return this->operator()(x.discrete());
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -54,7 +54,7 @@ class GTSAM_EXPORT DiscreteConditional
DiscreteConditional() {} DiscreteConditional() {}
/// Construct from factor, taking the first `nFrontals` keys as frontals. /// Construct from factor, taking the first `nFrontals` keys as frontals.
DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f); DiscreteConditional(size_t nFrontals, const DiscreteFactor& f);
/** /**
* Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first * Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first

View File

@ -22,6 +22,7 @@
#include <gtsam/discrete/AlgebraicDecisionTree.h> #include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DiscreteValues.h> #include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/inference/Factor.h> #include <gtsam/inference/Factor.h>
#include <gtsam/inference/Ordering.h>
#include <string> #include <string>
namespace gtsam { namespace gtsam {
@ -129,8 +130,40 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
/// DecisionTreeFactor /// DecisionTreeFactor
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
/**
* @brief Multiply in a DiscreteFactor and return the result as
* DiscreteFactor, both via shared pointers.
*
* @param df DiscreteFactor shared_ptr
* @return DiscreteFactor::shared_ptr
*/
virtual DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& df) const = 0;
/// divide by DiscreteFactor::shared_ptr f (safely)
virtual DiscreteFactor::shared_ptr operator/(
const DiscreteFactor::shared_ptr& df) const = 0;
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
/// Create new factor by summing all values with the same separator values
virtual DiscreteFactor::shared_ptr sum(size_t nrFrontals) const = 0;
/// Create new factor by summing all values with the same separator values
virtual DiscreteFactor::shared_ptr sum(const Ordering& keys) const = 0;
/// Create new factor by maximizing over all values with the same separator.
virtual DiscreteFactor::shared_ptr max(size_t nrFrontals) const = 0;
/// Create new factor by maximizing over all values with the same separator.
virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const = 0;
/**
* Get the number of non-zero values contained in this factor.
* It could be much smaller than `prod_{key}(cardinality(key))`.
*/
virtual uint64_t nrValues() const = 0;
/// @} /// @}
/// @name Wrapper support /// @name Wrapper support
/// @{ /// @{

View File

@ -64,10 +64,17 @@ namespace gtsam {
} }
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor DiscreteFactorGraph::product() const { DiscreteFactor::shared_ptr DiscreteFactorGraph::product() const {
DecisionTreeFactor result; DiscreteFactor::shared_ptr result;
for (const sharedFactor& factor : *this) { for (auto it = this->begin(); it != this->end(); ++it) {
if (factor) result = (*factor) * result; if (*it) {
if (result) {
result = result->multiply(*it);
} else {
// Assign to the first non-null factor
result = *it;
}
}
} }
return result; return result;
} }
@ -115,21 +122,23 @@ namespace gtsam {
* @brief Multiply all the `factors`. * @brief Multiply all the `factors`.
* *
* @param factors The factors to multiply as a DiscreteFactorGraph. * @param factors The factors to multiply as a DiscreteFactorGraph.
* @return DecisionTreeFactor * @return DiscreteFactor::shared_ptr
*/ */
static DecisionTreeFactor DiscreteProduct( static DiscreteFactor::shared_ptr DiscreteProduct(
const DiscreteFactorGraph& factors) { const DiscreteFactorGraph& factors) {
// PRODUCT: multiply all factors // PRODUCT: multiply all factors
DecisionTreeFactor product = factors.product(); gttic(product);
DiscreteFactor::shared_ptr product = factors.product();
gttoc(product);
#if GTSAM_HYBRID_TIMING #if GTSAM_HYBRID_TIMING
gttic_(DiscreteNormalize); gttic_(DiscreteNormalize);
#endif #endif
// Max over all the potentials by pretending all keys are frontal: // Max over all the potentials by pretending all keys are frontal:
auto denominator = product.max(product.size()); auto denominator = product->max(product->size());
// Normalize the product factor to prevent underflow. // Normalize the product factor to prevent underflow.
product = product / (*denominator); product = product->operator/(denominator);
#if GTSAM_HYBRID_TIMING #if GTSAM_HYBRID_TIMING
gttoc_(DiscreteNormalize); gttoc_(DiscreteNormalize);
#endif #endif
@ -142,25 +151,25 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> // std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateForMPE(const DiscreteFactorGraph& factors, EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) { const Ordering& frontalKeys) {
DecisionTreeFactor product = DiscreteProduct(factors); DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
// max out frontals, this is the factor on the separator // max out frontals, this is the factor on the separator
gttic(max); gttic(max);
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys); DiscreteFactor::shared_ptr max = product->max(frontalKeys);
gttoc(max); gttoc(max);
// Ordering keys for the conditional so that frontalKeys are really in front // Ordering keys for the conditional so that frontalKeys are really in front
DiscreteKeys orderedKeys; DiscreteKeys orderedKeys;
for (auto&& key : frontalKeys) for (auto&& key : frontalKeys)
orderedKeys.emplace_back(key, product.cardinality(key)); orderedKeys.emplace_back(key, product->cardinality(key));
for (auto&& key : max->keys()) for (auto&& key : max->keys())
orderedKeys.emplace_back(key, product.cardinality(key)); orderedKeys.emplace_back(key, product->cardinality(key));
// Make lookup with product // Make lookup with product
gttic(lookup); gttic(lookup);
size_t nrFrontals = frontalKeys.size(); size_t nrFrontals = frontalKeys.size();
auto lookup = auto lookup = std::make_shared<DiscreteLookupTable>(
std::make_shared<DiscreteLookupTable>(nrFrontals, orderedKeys, product); nrFrontals, orderedKeys, product->toDecisionTreeFactor());
gttoc(lookup); gttoc(lookup);
return {std::dynamic_pointer_cast<DiscreteConditional>(lookup), max}; return {std::dynamic_pointer_cast<DiscreteConditional>(lookup), max};
@ -220,10 +229,12 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> // std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors, EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) { const Ordering& frontalKeys) {
DecisionTreeFactor product = DiscreteProduct(factors); DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
// sum out frontals, this is the factor on the separator // sum out frontals, this is the factor on the separator
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys); gttic(sum);
DiscreteFactor::shared_ptr sum = product->sum(frontalKeys);
gttoc(sum);
// Ordering keys for the conditional so that frontalKeys are really in front // Ordering keys for the conditional so that frontalKeys are really in front
Ordering orderedKeys; Ordering orderedKeys;
@ -233,8 +244,11 @@ namespace gtsam {
sum->keys().end()); sum->keys().end());
// now divide product/sum to get conditional // now divide product/sum to get conditional
auto conditional = gttic(divide);
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys); auto conditional = std::make_shared<DiscreteConditional>(
product->toDecisionTreeFactor(), sum->toDecisionTreeFactor(),
orderedKeys);
gttoc(divide);
return {conditional, sum}; return {conditional, sum};
} }

View File

@ -134,6 +134,7 @@ class GTSAM_EXPORT DiscreteFactorGraph
/// @} /// @}
//TODO(Varun): Make compatible with TableFactor
/** Add a decision-tree factor */ /** Add a decision-tree factor */
template <typename... Args> template <typename... Args>
void add(Args&&... args) { void add(Args&&... args) {
@ -147,7 +148,7 @@ class GTSAM_EXPORT DiscreteFactorGraph
DiscreteKeys discreteKeys() const; DiscreteKeys discreteKeys() const;
/** return product of all factors as a single factor */ /** return product of all factors as a single factor */
DecisionTreeFactor product() const; DiscreteFactor::shared_ptr product() const;
/** /**
* Evaluates the factor graph given values, returns the joint probability of * Evaluates the factor graph given values, returns the joint probability of

View File

@ -18,6 +18,7 @@
#pragma once #pragma once
#include <gtsam/discrete/DiscreteDistribution.h> #include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/inference/BayesNet.h> #include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h> #include <gtsam/inference/FactorGraph.h>
@ -54,6 +55,18 @@ class GTSAM_EXPORT DiscreteLookupTable : public DiscreteConditional {
const ADT& potentials) const ADT& potentials)
: DiscreteConditional(nFrontals, keys, potentials) {} : DiscreteConditional(nFrontals, keys, potentials) {}
/**
* @brief Construct a new Discrete Lookup Table object
*
* @param nFrontals number of frontal variables
* @param keys a sorted list of gtsam::Keys
* @param potentials Discrete potentials as a TableFactor.
*/
DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys,
const TableFactor& potentials)
: DiscreteConditional(nFrontals, keys,
potentials.toDecisionTreeFactor()) {}
/// GTSAM-style print /// GTSAM-style print
void print( void print(
const std::string& s = "Discrete Lookup Table: ", const std::string& s = "Discrete Lookup Table: ",

View File

@ -254,6 +254,46 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
return toDecisionTreeFactor() * f; return toDecisionTreeFactor() * f;
} }
/* ************************************************************************ */
DiscreteFactor::shared_ptr TableFactor::multiply(
const DiscreteFactor::shared_ptr& f) const {
DiscreteFactor::shared_ptr result;
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
// If `f` is a TableFactor, we can simply call `operator*`.
result = std::make_shared<TableFactor>(this->operator*(*tf));
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
// If `f` is a DecisionTreeFactor, we convert to a TableFactor which is
// cheaper than converting `this` to a DecisionTreeFactor.
result = std::make_shared<TableFactor>(this->operator*(TableFactor(*dtf)));
} else {
// Simulate double dispatch in C++
// Useful for other classes which inherit from DiscreteFactor and have
// only `operator*(DecisionTreeFactor)` defined. Thus, other classes don't
// need to be updated to know about TableFactor.
// Those classes can be specialized to use TableFactor
// if efficiency is a problem.
result = std::make_shared<DecisionTreeFactor>(
f->operator*(this->toDecisionTreeFactor()));
}
return result;
}
/* ************************************************************************ */
DiscreteFactor::shared_ptr TableFactor::operator/(
const DiscreteFactor::shared_ptr& f) const {
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
return std::make_shared<TableFactor>(this->operator/(*tf));
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
return std::make_shared<TableFactor>(
this->operator/(TableFactor(f->discreteKeys(), *dtf)));
} else {
TableFactor divisor(f->toDecisionTreeFactor());
return std::make_shared<TableFactor>(this->operator/(divisor));
}
}
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
DiscreteKeys dkeys = discreteKeys(); DiscreteKeys dkeys = discreteKeys();

View File

@ -17,6 +17,7 @@
#pragma once #pragma once
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteFactor.h> #include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/Ring.h> #include <gtsam/discrete/Ring.h>
@ -178,6 +179,23 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
/// multiply with DecisionTreeFactor /// multiply with DecisionTreeFactor
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
/**
* @brief Multiply factors, DiscreteFactor::shared_ptr edition.
*
* This method accepts `DiscreteFactor::shared_ptr` and uses dynamic
* dispatch and specializations to perform the most efficient
* multiplication.
*
* While converting a DecisionTreeFactor to a TableFactor is efficient, the
* reverse is not.
* Hence we specialize the code to return a TableFactor always.
*
* @param f The factor to multiply with.
* @return DiscreteFactor::shared_ptr
*/
virtual DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& f) const override;
static double safe_div(const double& a, const double& b); static double safe_div(const double& a, const double& b);
/// divide by factor f (safely) /// divide by factor f (safely)
@ -185,6 +203,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
return apply(f, safe_div); return apply(f, safe_div);
} }
/// divide by DiscreteFactor::shared_ptr f (safely)
DiscreteFactor::shared_ptr operator/(
const DiscreteFactor::shared_ptr& f) const override;
/// Convert into a decisiontree /// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override; DecisionTreeFactor toDecisionTreeFactor() const override;
@ -193,22 +215,22 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
DiscreteKeys parent_keys) const; DiscreteKeys parent_keys) const;
/// Create new factor by summing all values with the same separator values /// Create new factor by summing all values with the same separator values
shared_ptr sum(size_t nrFrontals) const { DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::add); return combine(nrFrontals, Ring::add);
} }
/// Create new factor by summing all values with the same separator values /// Create new factor by summing all values with the same separator values
shared_ptr sum(const Ordering& keys) const { DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
return combine(keys, Ring::add); return combine(keys, Ring::add);
} }
/// Create new factor by maximizing over all values with the same separator. /// Create new factor by maximizing over all values with the same separator.
shared_ptr max(size_t nrFrontals) const { DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::max); return combine(nrFrontals, Ring::max);
} }
/// Create new factor by maximizing over all values with the same separator. /// Create new factor by maximizing over all values with the same separator.
shared_ptr max(const Ordering& keys) const { DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
return combine(keys, Ring::max); return combine(keys, Ring::max);
} }
@ -313,6 +335,12 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
*/ */
TableFactor prune(size_t maxNrAssignments) const; TableFactor prune(size_t maxNrAssignments) const;
/**
* Get the number of non-zero values contained in this factor.
* It could be much smaller than `prod_{key}(cardinality(key))`.
*/
uint64_t nrValues() const override { return sparse_table_.nonZeros(); }
/// @} /// @}
/// @name Wrapper support /// @name Wrapper support
/// @{ /// @{

View File

@ -30,6 +30,12 @@
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
/** Convert Signature into CPT */
DecisionTreeFactor create(const Signature& signature) {
DecisionTreeFactor p(signature.discreteKeys(), signature.cpt());
return p;
}
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DecisionTreeFactor, ConstructorsMatch) { TEST(DecisionTreeFactor, ConstructorsMatch) {
// Declare two keys // Declare two keys
@ -105,21 +111,45 @@ TEST(DecisionTreeFactor, multiplication) {
CHECK(assert_equal(expected2, actual)); CHECK(assert_equal(expected2, actual));
} }
/* ************************************************************************* */
TEST(DecisionTreeFactor, Divide) {
DiscreteKey A(0, 2), S(1, 2);
DecisionTreeFactor pA = create(A % "99/1"), pS = create(S % "50/50");
DecisionTreeFactor joint = pA * pS;
DecisionTreeFactor s = joint / pA;
// Factors are not equal due to difference in keys
EXPECT(assert_inequal(pS, s));
// The underlying data should be the same
using ADT = AlgebraicDecisionTree<Key>;
EXPECT(assert_equal(ADT(pS), ADT(s)));
KeySet keys(joint.keys());
keys.insert(pA.keys().begin(), pA.keys().end());
EXPECT(assert_inequal(KeySet(pS.keys()), keys));
}
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DecisionTreeFactor, sum_max) { TEST(DecisionTreeFactor, sum_max) {
DiscreteKey v0(0, 3), v1(1, 2); DiscreteKey v0(0, 3), v1(1, 2);
DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6"); DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6");
DecisionTreeFactor expected(v1, "9 12"); DecisionTreeFactor expected(v1, "9 12");
DecisionTreeFactor::shared_ptr actual = f1.sum(1); auto actual = std::dynamic_pointer_cast<DecisionTreeFactor>(f1.sum(1));
CHECK(actual);
CHECK(assert_equal(expected, *actual, 1e-5)); CHECK(assert_equal(expected, *actual, 1e-5));
DecisionTreeFactor expected2(v1, "5 6"); DecisionTreeFactor expected2(v1, "5 6");
DecisionTreeFactor::shared_ptr actual2 = f1.max(1); auto actual2 = std::dynamic_pointer_cast<DecisionTreeFactor>(f1.max(1));
CHECK(actual2);
CHECK(assert_equal(expected2, *actual2)); CHECK(assert_equal(expected2, *actual2));
DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6"); DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6");
DecisionTreeFactor::shared_ptr actual22 = f2.sum(1); auto actual22 = std::dynamic_pointer_cast<DecisionTreeFactor>(f2.sum(1));
CHECK(actual22);
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -217,12 +247,6 @@ void maybeSaveDotFile(const DecisionTreeFactor& f, const string& filename) {
#endif #endif
} }
/** Convert Signature into CPT */
DecisionTreeFactor create(const Signature& signature) {
DecisionTreeFactor p(signature.discreteKeys(), signature.cpt());
return p;
}
/* ************************************************************************* */ /* ************************************************************************* */
// test Asia Joint // test Asia Joint
TEST(DecisionTreeFactor, joint) { TEST(DecisionTreeFactor, joint) {

View File

@ -46,7 +46,7 @@ TEST(DiscreteConditional, constructors) {
DecisionTreeFactor f2( DecisionTreeFactor f2(
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
DiscreteConditional actual2(1, f2); DiscreteConditional actual2(1, f2);
DecisionTreeFactor expected2 = f2 / *f2.sum(1); DecisionTreeFactor expected2 = f2 / f2.sum(1)->toDecisionTreeFactor();
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2))); EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
std::vector<double> probs{0.2, 0.5, 0.3, 0.6, 0.4, 0.7, 0.25, 0.55, 0.35, 0.65, 0.45, 0.75}; std::vector<double> probs{0.2, 0.5, 0.3, 0.6, 0.4, 0.7, 0.25, 0.55, 0.35, 0.65, 0.45, 0.75};
@ -70,7 +70,7 @@ TEST(DiscreteConditional, constructors_alt_interface) {
DecisionTreeFactor f2( DecisionTreeFactor f2(
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
DiscreteConditional actual2(1, f2); DiscreteConditional actual2(1, f2);
DecisionTreeFactor expected2 = f2 / *f2.sum(1); DecisionTreeFactor expected2 = f2 / f2.sum(1)->toDecisionTreeFactor();
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2))); EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
} }

View File

@ -94,7 +94,7 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) {
EXPECT_DOUBLES_EQUAL( 1.944, graph(values), 1e-9); EXPECT_DOUBLES_EQUAL( 1.944, graph(values), 1e-9);
// Check if graph product works // Check if graph product works
DecisionTreeFactor product = graph.product(); DecisionTreeFactor product = graph.product()->toDecisionTreeFactor();
EXPECT_DOUBLES_EQUAL( 1.944, product(values), 1e-9); EXPECT_DOUBLES_EQUAL( 1.944, product(values), 1e-9);
} }
@ -117,9 +117,9 @@ TEST(DiscreteFactorGraph, test) {
*std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr); *std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr);
// Normalize newFactor by max for comparison with expected // Normalize newFactor by max for comparison with expected
auto normalizer = newFactor.max(newFactor.size()); auto denominator = newFactor.max(newFactor.size())->toDecisionTreeFactor();
newFactor = newFactor / *normalizer; newFactor = newFactor / denominator;
// Check Conditional // Check Conditional
CHECK(conditional); CHECK(conditional);
@ -131,9 +131,10 @@ TEST(DiscreteFactorGraph, test) {
CHECK(&newFactor); CHECK(&newFactor);
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10"); DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
// Normalize by max. // Normalize by max.
normalizer = expectedFactor.max(expectedFactor.size()); denominator =
// Ensure normalizer is correct. expectedFactor.max(expectedFactor.size())->toDecisionTreeFactor();
expectedFactor = expectedFactor / *normalizer; // Ensure denominator is correct.
expectedFactor = expectedFactor / denominator;
EXPECT(assert_equal(expectedFactor, newFactor)); EXPECT(assert_equal(expectedFactor, newFactor));
// Test using elimination tree // Test using elimination tree

View File

@ -194,15 +194,17 @@ TEST(TableFactor, Conversion) {
TEST(TableFactor, Empty) { TEST(TableFactor, Empty) {
DiscreteKey X(1, 2); DiscreteKey X(1, 2);
TableFactor single = *TableFactor({X}, "1 1").sum(1); auto single = TableFactor({X}, "1 1").sum(1);
// Should not throw a segfault // Should not throw a segfault
EXPECT(assert_equal(*DecisionTreeFactor(X, "1 1").sum(1), auto expected_single = DecisionTreeFactor(X, "1 1").sum(1);
single.toDecisionTreeFactor())); EXPECT(assert_equal(expected_single->toDecisionTreeFactor(),
single->toDecisionTreeFactor()));
TableFactor empty = *TableFactor({X}, "0 0").sum(1); auto empty = TableFactor({X}, "0 0").sum(1);
// Should not throw a segfault // Should not throw a segfault
EXPECT(assert_equal(*DecisionTreeFactor(X, "0 0").sum(1), auto expected_empty = DecisionTreeFactor(X, "0 0").sum(1);
empty.toDecisionTreeFactor())); EXPECT(assert_equal(expected_empty->toDecisionTreeFactor(),
empty->toDecisionTreeFactor()));
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -303,15 +305,18 @@ TEST(TableFactor, sum_max) {
TableFactor f1(v0 & v1, "1 2 3 4 5 6"); TableFactor f1(v0 & v1, "1 2 3 4 5 6");
TableFactor expected(v1, "9 12"); TableFactor expected(v1, "9 12");
TableFactor::shared_ptr actual = f1.sum(1); auto actual = std::dynamic_pointer_cast<TableFactor>(f1.sum(1));
CHECK(actual);
CHECK(assert_equal(expected, *actual, 1e-5)); CHECK(assert_equal(expected, *actual, 1e-5));
TableFactor expected2(v1, "5 6"); TableFactor expected2(v1, "5 6");
TableFactor::shared_ptr actual2 = f1.max(1); auto actual2 = std::dynamic_pointer_cast<TableFactor>(f1.max(1));
CHECK(actual2);
CHECK(assert_equal(expected2, *actual2)); CHECK(assert_equal(expected2, *actual2));
TableFactor f2(v1 & v0, "1 2 3 4 5 6"); TableFactor f2(v1 & v0, "1 2 3 4 5 6");
TableFactor::shared_ptr actual22 = f2.sum(1); auto actual22 = std::dynamic_pointer_cast<TableFactor>(f2.sum(1));
CHECK(actual22);
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -68,7 +68,8 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor {
/* /*
* Ensure Arc-consistency by checking every possible value of domain j. * Ensure Arc-consistency by checking every possible value of domain j.
* @param j domain to be checked * @param j domain to be checked
* @param (in/out) domains all domains, but only domains->at(j) will be checked. * @param (in/out) domains all domains, but only domains->at(j) will be
* checked.
* @return true if domains->at(j) was changed, false otherwise. * @return true if domains->at(j) was changed, false otherwise.
*/ */
virtual bool ensureArcConsistency(Key j, Domains* domains) const = 0; virtual bool ensureArcConsistency(Key j, Domains* domains) const = 0;
@ -78,6 +79,39 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor {
/// Partially apply known values, domain version /// Partially apply known values, domain version
virtual shared_ptr partiallyApply(const Domains&) const = 0; virtual shared_ptr partiallyApply(const Domains&) const = 0;
/// Multiply factors, DiscreteFactor::shared_ptr edition
DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& df) const override {
return std::make_shared<DecisionTreeFactor>(
this->operator*(df->toDecisionTreeFactor()));
}
/// divide by DiscreteFactor::shared_ptr f (safely)
DiscreteFactor::shared_ptr operator/(
const DiscreteFactor::shared_ptr& df) const override {
return this->toDecisionTreeFactor() / df;
}
/// Get the number of non-zero values contained in this factor.
uint64_t nrValues() const override { return 1; };
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
return toDecisionTreeFactor().sum(nrFrontals);
}
DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
return toDecisionTreeFactor().sum(keys);
}
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
return toDecisionTreeFactor().max(nrFrontals);
}
DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
return toDecisionTreeFactor().max(keys);
}
/// @} /// @}
/// @name Wrapper support /// @name Wrapper support
/// @{ /// @{

View File

@ -49,7 +49,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
/// Erase a value, non const :-( /// Erase a value, non const :-(
void erase(size_t value) { values_.erase(value); } void erase(size_t value) { values_.erase(value); }
size_t nrValues() const { return values_.size(); } uint64_t nrValues() const override { return values_.size(); }
bool isSingleton() const { return nrValues() == 1; } bool isSingleton() const { return nrValues() == 1; }

View File

@ -124,7 +124,7 @@ TEST(CSP, allInOne) {
EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9); EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9);
// Just for fun, create the product and check it // Just for fun, create the product and check it
DecisionTreeFactor product = csp.product(); DecisionTreeFactor product = csp.product()->toDecisionTreeFactor();
// product.dot("product"); // product.dot("product");
DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0"); DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0");
EXPECT(assert_equal(expectedProduct, product)); EXPECT(assert_equal(expectedProduct, product));

View File

@ -113,7 +113,7 @@ TEST(schedulingExample, test) {
EXPECT(assert_equal(expected, (DiscreteFactorGraph)s)); EXPECT(assert_equal(expected, (DiscreteFactorGraph)s));
// Do brute force product and output that to file // Do brute force product and output that to file
DecisionTreeFactor product = s.product(); DecisionTreeFactor product = s.product()->toDecisionTreeFactor();
// product.dot("scheduling", false); // product.dot("scheduling", false);
// Do exact inference // Do exact inference