Merge branch 'hybrid-timing' into discrete-table-conditional

release/4.3a0
Varun Agrawal 2025-01-06 20:38:38 -05:00
commit 8658f25edd
18 changed files with 342 additions and 69 deletions

View File

@ -18,9 +18,10 @@
*/
#include <gtsam/base/FastSet.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <utility>
@ -62,6 +63,49 @@ namespace gtsam {
return error(values.discrete());
}
/* ************************************************************************ */
DiscreteFactor::shared_ptr DecisionTreeFactor::multiply(
const DiscreteFactor::shared_ptr& f) const {
DiscreteFactor::shared_ptr result;
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
// If f is a TableFactor, we convert `this` to a TableFactor since this
// conversion is cheaper than converting `f` to a DecisionTreeFactor. We
// then return a TableFactor.
result = std::make_shared<TableFactor>((*tf) * TableFactor(*this));
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
// If `f` is a DecisionTreeFactor, simply call operator*.
result = std::make_shared<DecisionTreeFactor>(this->operator*(*dtf));
} else {
// Simulate double dispatch in C++
// Useful for other classes which inherit from DiscreteFactor and have
// only `operator*(DecisionTreeFactor)` defined. Thus, other classes don't
// need to be updated.
result = std::make_shared<DecisionTreeFactor>(f->operator*(*this));
}
return result;
}
/* ************************************************************************ */
DiscreteFactor::shared_ptr DecisionTreeFactor::operator/(
const DiscreteFactor::shared_ptr& f) const {
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
// Check if `f` is a TableFactor. If yes, then
// convert `this` to a TableFactor which is cheaper.
return std::make_shared<TableFactor>(tf->operator/(TableFactor(*this)));
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
// If `f` is a DecisionTreeFactor, divide normally.
return std::make_shared<DecisionTreeFactor>(this->operator/(*dtf));
} else {
// Else, convert `f` to a DecisionTreeFactor so we can divide
return std::make_shared<DecisionTreeFactor>(
this->operator/(f->toDecisionTreeFactor()));
}
}
/* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum

View File

@ -147,6 +147,23 @@ namespace gtsam {
/// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const override;
/**
* @brief Multiply factors, DiscreteFactor::shared_ptr edition.
*
* This method accepts `DiscreteFactor::shared_ptr` and uses dynamic
* dispatch and specializations to perform the most efficient
* multiplication.
*
* While converting a DecisionTreeFactor to a TableFactor is efficient, the
* reverse is not. Hence we specialize the code to return a TableFactor if
* `f` is a TableFactor, and DecisionTreeFactor otherwise.
*
* @param f The factor to multiply with.
* @return DiscreteFactor::shared_ptr
*/
virtual DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& f) const override;
/// multiply two factors
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
return apply(f, Ring::mul);
@ -154,31 +171,43 @@ namespace gtsam {
static double safe_div(const double& a, const double& b);
/// divide by factor f (safely)
/**
* @brief Divide by factor f (safely).
* Division of a factor \f$f(x, y)\f$ by another factor \f$g(y, z)\f$
* results in a function which involves all keys
* \f$(\frac{f}{g})(x, y, z) = f(x, y) / g(y, z)\f$
*
* @param f The DecisinTreeFactor to divide by.
* @return DecisionTreeFactor
*/
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
return apply(f, safe_div);
}
/// divide by DiscreteFactor::shared_ptr f (safely)
DiscreteFactor::shared_ptr operator/(
const DiscreteFactor::shared_ptr& f) const override;
/// Convert into a decision tree
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
/// Create new factor by summing all values with the same separator values
shared_ptr sum(size_t nrFrontals) const {
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::add);
}
/// Create new factor by summing all values with the same separator values
shared_ptr sum(const Ordering& keys) const {
DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
return combine(keys, Ring::add);
}
/// Create new factor by maximizing over all values with the same separator.
shared_ptr max(size_t nrFrontals) const {
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::max);
}
/// Create new factor by maximizing over all values with the same separator.
shared_ptr max(const Ordering& keys) const {
DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
return combine(keys, Ring::max);
}
@ -259,6 +288,12 @@ namespace gtsam {
*/
DecisionTreeFactor prune(size_t maxNrAssignments) const;
/**
* Get the number of non-zero values contained in this factor.
* It could be much smaller than `prod_{key}(cardinality(key))`.
*/
uint64_t nrValues() const override { return nrLeaves(); }
/// @}
/// @name Wrapper support
/// @{

View File

@ -44,8 +44,9 @@ template class GTSAM_EXPORT
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
const DecisionTreeFactor& f)
: BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {}
const DiscreteFactor& f)
: BaseFactor((f / f.sum(nrFrontals))->toDecisionTreeFactor()),
BaseConditional(nrFrontals) {}
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
@ -150,11 +151,11 @@ void DiscreteConditional::print(const string& s,
/* ************************************************************************** */
bool DiscreteConditional::equals(const DiscreteFactor& other,
double tol) const {
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
if (!dynamic_cast<const BaseFactor*>(&other)) {
return false;
} else {
const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other));
return DecisionTreeFactor::equals(f, tol);
const BaseFactor& f(static_cast<const BaseFactor&>(other));
return BaseFactor::equals(f, tol);
}
}
@ -375,7 +376,7 @@ std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter,
ss << "*\n" << std::endl;
if (nrParents() == 0) {
// We have no parents, call factor method.
ss << DecisionTreeFactor::markdown(keyFormatter, names);
ss << BaseFactor::markdown(keyFormatter, names);
return ss.str();
}
@ -427,7 +428,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,
ss << "</i></p>\n";
if (nrParents() == 0) {
// We have no parents, call factor method.
ss << DecisionTreeFactor::html(keyFormatter, names);
ss << BaseFactor::html(keyFormatter, names);
return ss.str();
}
@ -475,7 +476,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,
/* ************************************************************************* */
double DiscreteConditional::evaluate(const HybridValues& x) const {
return this->evaluate(x.discrete());
return this->operator()(x.discrete());
}
/* ************************************************************************* */

View File

@ -54,7 +54,7 @@ class GTSAM_EXPORT DiscreteConditional
DiscreteConditional() {}
/// Construct from factor, taking the first `nFrontals` keys as frontals.
DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f);
DiscreteConditional(size_t nFrontals, const DiscreteFactor& f);
/**
* Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first

View File

@ -22,6 +22,7 @@
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/inference/Factor.h>
#include <gtsam/inference/Ordering.h>
#include <string>
namespace gtsam {
@ -129,8 +130,40 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
/// DecisionTreeFactor
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
/**
* @brief Multiply in a DiscreteFactor and return the result as
* DiscreteFactor, both via shared pointers.
*
* @param df DiscreteFactor shared_ptr
* @return DiscreteFactor::shared_ptr
*/
virtual DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& df) const = 0;
/// divide by DiscreteFactor::shared_ptr f (safely)
virtual DiscreteFactor::shared_ptr operator/(
const DiscreteFactor::shared_ptr& df) const = 0;
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
/// Create new factor by summing all values with the same separator values
virtual DiscreteFactor::shared_ptr sum(size_t nrFrontals) const = 0;
/// Create new factor by summing all values with the same separator values
virtual DiscreteFactor::shared_ptr sum(const Ordering& keys) const = 0;
/// Create new factor by maximizing over all values with the same separator.
virtual DiscreteFactor::shared_ptr max(size_t nrFrontals) const = 0;
/// Create new factor by maximizing over all values with the same separator.
virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const = 0;
/**
* Get the number of non-zero values contained in this factor.
* It could be much smaller than `prod_{key}(cardinality(key))`.
*/
virtual uint64_t nrValues() const = 0;
/// @}
/// @name Wrapper support
/// @{

View File

@ -64,10 +64,17 @@ namespace gtsam {
}
/* ************************************************************************ */
DecisionTreeFactor DiscreteFactorGraph::product() const {
DecisionTreeFactor result;
for (const sharedFactor& factor : *this) {
if (factor) result = (*factor) * result;
DiscreteFactor::shared_ptr DiscreteFactorGraph::product() const {
DiscreteFactor::shared_ptr result;
for (auto it = this->begin(); it != this->end(); ++it) {
if (*it) {
if (result) {
result = result->multiply(*it);
} else {
// Assign to the first non-null factor
result = *it;
}
}
}
return result;
}
@ -115,21 +122,23 @@ namespace gtsam {
* @brief Multiply all the `factors`.
*
* @param factors The factors to multiply as a DiscreteFactorGraph.
* @return DecisionTreeFactor
* @return DiscreteFactor::shared_ptr
*/
static DecisionTreeFactor DiscreteProduct(
static DiscreteFactor::shared_ptr DiscreteProduct(
const DiscreteFactorGraph& factors) {
// PRODUCT: multiply all factors
DecisionTreeFactor product = factors.product();
gttic(product);
DiscreteFactor::shared_ptr product = factors.product();
gttoc(product);
#if GTSAM_HYBRID_TIMING
gttic_(DiscreteNormalize);
#endif
// Max over all the potentials by pretending all keys are frontal:
auto denominator = product.max(product.size());
auto denominator = product->max(product->size());
// Normalize the product factor to prevent underflow.
product = product / (*denominator);
product = product->operator/(denominator);
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteNormalize);
#endif
@ -142,25 +151,25 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
DecisionTreeFactor product = DiscreteProduct(factors);
DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
// max out frontals, this is the factor on the separator
gttic(max);
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
DiscreteFactor::shared_ptr max = product->max(frontalKeys);
gttoc(max);
// Ordering keys for the conditional so that frontalKeys are really in front
DiscreteKeys orderedKeys;
for (auto&& key : frontalKeys)
orderedKeys.emplace_back(key, product.cardinality(key));
orderedKeys.emplace_back(key, product->cardinality(key));
for (auto&& key : max->keys())
orderedKeys.emplace_back(key, product.cardinality(key));
orderedKeys.emplace_back(key, product->cardinality(key));
// Make lookup with product
gttic(lookup);
size_t nrFrontals = frontalKeys.size();
auto lookup =
std::make_shared<DiscreteLookupTable>(nrFrontals, orderedKeys, product);
auto lookup = std::make_shared<DiscreteLookupTable>(
nrFrontals, orderedKeys, product->toDecisionTreeFactor());
gttoc(lookup);
return {std::dynamic_pointer_cast<DiscreteConditional>(lookup), max};
@ -220,10 +229,12 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
DecisionTreeFactor product = DiscreteProduct(factors);
DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
// sum out frontals, this is the factor on the separator
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
gttic(sum);
DiscreteFactor::shared_ptr sum = product->sum(frontalKeys);
gttoc(sum);
// Ordering keys for the conditional so that frontalKeys are really in front
Ordering orderedKeys;
@ -233,8 +244,11 @@ namespace gtsam {
sum->keys().end());
// now divide product/sum to get conditional
auto conditional =
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
gttic(divide);
auto conditional = std::make_shared<DiscreteConditional>(
product->toDecisionTreeFactor(), sum->toDecisionTreeFactor(),
orderedKeys);
gttoc(divide);
return {conditional, sum};
}

View File

@ -134,6 +134,7 @@ class GTSAM_EXPORT DiscreteFactorGraph
/// @}
//TODO(Varun): Make compatible with TableFactor
/** Add a decision-tree factor */
template <typename... Args>
void add(Args&&... args) {
@ -147,7 +148,7 @@ class GTSAM_EXPORT DiscreteFactorGraph
DiscreteKeys discreteKeys() const;
/** return product of all factors as a single factor */
DecisionTreeFactor product() const;
DiscreteFactor::shared_ptr product() const;
/**
* Evaluates the factor graph given values, returns the joint probability of

View File

@ -18,6 +18,7 @@
#pragma once
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h>
@ -54,6 +55,18 @@ class GTSAM_EXPORT DiscreteLookupTable : public DiscreteConditional {
const ADT& potentials)
: DiscreteConditional(nFrontals, keys, potentials) {}
/**
* @brief Construct a new Discrete Lookup Table object
*
* @param nFrontals number of frontal variables
* @param keys a sorted list of gtsam::Keys
* @param potentials Discrete potentials as a TableFactor.
*/
DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys,
const TableFactor& potentials)
: DiscreteConditional(nFrontals, keys,
potentials.toDecisionTreeFactor()) {}
/// GTSAM-style print
void print(
const std::string& s = "Discrete Lookup Table: ",

View File

@ -254,6 +254,46 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
return toDecisionTreeFactor() * f;
}
/* ************************************************************************ */
DiscreteFactor::shared_ptr TableFactor::multiply(
const DiscreteFactor::shared_ptr& f) const {
DiscreteFactor::shared_ptr result;
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
// If `f` is a TableFactor, we can simply call `operator*`.
result = std::make_shared<TableFactor>(this->operator*(*tf));
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
// If `f` is a DecisionTreeFactor, we convert to a TableFactor which is
// cheaper than converting `this` to a DecisionTreeFactor.
result = std::make_shared<TableFactor>(this->operator*(TableFactor(*dtf)));
} else {
// Simulate double dispatch in C++
// Useful for other classes which inherit from DiscreteFactor and have
// only `operator*(DecisionTreeFactor)` defined. Thus, other classes don't
// need to be updated to know about TableFactor.
// Those classes can be specialized to use TableFactor
// if efficiency is a problem.
result = std::make_shared<DecisionTreeFactor>(
f->operator*(this->toDecisionTreeFactor()));
}
return result;
}
/* ************************************************************************ */
DiscreteFactor::shared_ptr TableFactor::operator/(
const DiscreteFactor::shared_ptr& f) const {
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
return std::make_shared<TableFactor>(this->operator/(*tf));
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
return std::make_shared<TableFactor>(
this->operator/(TableFactor(f->discreteKeys(), *dtf)));
} else {
TableFactor divisor(f->toDecisionTreeFactor());
return std::make_shared<TableFactor>(this->operator/(divisor));
}
}
/* ************************************************************************ */
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
DiscreteKeys dkeys = discreteKeys();

View File

@ -17,6 +17,7 @@
#pragma once
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/Ring.h>
@ -178,6 +179,23 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
/// multiply with DecisionTreeFactor
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
/**
* @brief Multiply factors, DiscreteFactor::shared_ptr edition.
*
* This method accepts `DiscreteFactor::shared_ptr` and uses dynamic
* dispatch and specializations to perform the most efficient
* multiplication.
*
* While converting a DecisionTreeFactor to a TableFactor is efficient, the
* reverse is not.
* Hence we specialize the code to return a TableFactor always.
*
* @param f The factor to multiply with.
* @return DiscreteFactor::shared_ptr
*/
virtual DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& f) const override;
static double safe_div(const double& a, const double& b);
/// divide by factor f (safely)
@ -185,6 +203,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
return apply(f, safe_div);
}
/// divide by DiscreteFactor::shared_ptr f (safely)
DiscreteFactor::shared_ptr operator/(
const DiscreteFactor::shared_ptr& f) const override;
/// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override;
@ -193,22 +215,22 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
DiscreteKeys parent_keys) const;
/// Create new factor by summing all values with the same separator values
shared_ptr sum(size_t nrFrontals) const {
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::add);
}
/// Create new factor by summing all values with the same separator values
shared_ptr sum(const Ordering& keys) const {
DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
return combine(keys, Ring::add);
}
/// Create new factor by maximizing over all values with the same separator.
shared_ptr max(size_t nrFrontals) const {
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::max);
}
/// Create new factor by maximizing over all values with the same separator.
shared_ptr max(const Ordering& keys) const {
DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
return combine(keys, Ring::max);
}
@ -313,6 +335,12 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
*/
TableFactor prune(size_t maxNrAssignments) const;
/**
* Get the number of non-zero values contained in this factor.
* It could be much smaller than `prod_{key}(cardinality(key))`.
*/
uint64_t nrValues() const override { return sparse_table_.nonZeros(); }
/// @}
/// @name Wrapper support
/// @{

View File

@ -30,6 +30,12 @@
using namespace std;
using namespace gtsam;
/** Convert Signature into CPT */
DecisionTreeFactor create(const Signature& signature) {
DecisionTreeFactor p(signature.discreteKeys(), signature.cpt());
return p;
}
/* ************************************************************************* */
TEST(DecisionTreeFactor, ConstructorsMatch) {
// Declare two keys
@ -105,21 +111,45 @@ TEST(DecisionTreeFactor, multiplication) {
CHECK(assert_equal(expected2, actual));
}
/* ************************************************************************* */
TEST(DecisionTreeFactor, Divide) {
DiscreteKey A(0, 2), S(1, 2);
DecisionTreeFactor pA = create(A % "99/1"), pS = create(S % "50/50");
DecisionTreeFactor joint = pA * pS;
DecisionTreeFactor s = joint / pA;
// Factors are not equal due to difference in keys
EXPECT(assert_inequal(pS, s));
// The underlying data should be the same
using ADT = AlgebraicDecisionTree<Key>;
EXPECT(assert_equal(ADT(pS), ADT(s)));
KeySet keys(joint.keys());
keys.insert(pA.keys().begin(), pA.keys().end());
EXPECT(assert_inequal(KeySet(pS.keys()), keys));
}
/* ************************************************************************* */
TEST(DecisionTreeFactor, sum_max) {
DiscreteKey v0(0, 3), v1(1, 2);
DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6");
DecisionTreeFactor expected(v1, "9 12");
DecisionTreeFactor::shared_ptr actual = f1.sum(1);
auto actual = std::dynamic_pointer_cast<DecisionTreeFactor>(f1.sum(1));
CHECK(actual);
CHECK(assert_equal(expected, *actual, 1e-5));
DecisionTreeFactor expected2(v1, "5 6");
DecisionTreeFactor::shared_ptr actual2 = f1.max(1);
auto actual2 = std::dynamic_pointer_cast<DecisionTreeFactor>(f1.max(1));
CHECK(actual2);
CHECK(assert_equal(expected2, *actual2));
DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6");
DecisionTreeFactor::shared_ptr actual22 = f2.sum(1);
auto actual22 = std::dynamic_pointer_cast<DecisionTreeFactor>(f2.sum(1));
CHECK(actual22);
}
/* ************************************************************************* */
@ -217,12 +247,6 @@ void maybeSaveDotFile(const DecisionTreeFactor& f, const string& filename) {
#endif
}
/** Convert Signature into CPT */
DecisionTreeFactor create(const Signature& signature) {
DecisionTreeFactor p(signature.discreteKeys(), signature.cpt());
return p;
}
/* ************************************************************************* */
// test Asia Joint
TEST(DecisionTreeFactor, joint) {

View File

@ -46,7 +46,7 @@ TEST(DiscreteConditional, constructors) {
DecisionTreeFactor f2(
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
DiscreteConditional actual2(1, f2);
DecisionTreeFactor expected2 = f2 / *f2.sum(1);
DecisionTreeFactor expected2 = f2 / f2.sum(1)->toDecisionTreeFactor();
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
std::vector<double> probs{0.2, 0.5, 0.3, 0.6, 0.4, 0.7, 0.25, 0.55, 0.35, 0.65, 0.45, 0.75};
@ -70,7 +70,7 @@ TEST(DiscreteConditional, constructors_alt_interface) {
DecisionTreeFactor f2(
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
DiscreteConditional actual2(1, f2);
DecisionTreeFactor expected2 = f2 / *f2.sum(1);
DecisionTreeFactor expected2 = f2 / f2.sum(1)->toDecisionTreeFactor();
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
}

View File

@ -94,7 +94,7 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) {
EXPECT_DOUBLES_EQUAL( 1.944, graph(values), 1e-9);
// Check if graph product works
DecisionTreeFactor product = graph.product();
DecisionTreeFactor product = graph.product()->toDecisionTreeFactor();
EXPECT_DOUBLES_EQUAL( 1.944, product(values), 1e-9);
}
@ -117,9 +117,9 @@ TEST(DiscreteFactorGraph, test) {
*std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr);
// Normalize newFactor by max for comparison with expected
auto normalizer = newFactor.max(newFactor.size());
auto denominator = newFactor.max(newFactor.size())->toDecisionTreeFactor();
newFactor = newFactor / *normalizer;
newFactor = newFactor / denominator;
// Check Conditional
CHECK(conditional);
@ -131,9 +131,10 @@ TEST(DiscreteFactorGraph, test) {
CHECK(&newFactor);
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
// Normalize by max.
normalizer = expectedFactor.max(expectedFactor.size());
// Ensure normalizer is correct.
expectedFactor = expectedFactor / *normalizer;
denominator =
expectedFactor.max(expectedFactor.size())->toDecisionTreeFactor();
// Ensure denominator is correct.
expectedFactor = expectedFactor / denominator;
EXPECT(assert_equal(expectedFactor, newFactor));
// Test using elimination tree

View File

@ -194,15 +194,17 @@ TEST(TableFactor, Conversion) {
TEST(TableFactor, Empty) {
DiscreteKey X(1, 2);
TableFactor single = *TableFactor({X}, "1 1").sum(1);
auto single = TableFactor({X}, "1 1").sum(1);
// Should not throw a segfault
EXPECT(assert_equal(*DecisionTreeFactor(X, "1 1").sum(1),
single.toDecisionTreeFactor()));
auto expected_single = DecisionTreeFactor(X, "1 1").sum(1);
EXPECT(assert_equal(expected_single->toDecisionTreeFactor(),
single->toDecisionTreeFactor()));
TableFactor empty = *TableFactor({X}, "0 0").sum(1);
auto empty = TableFactor({X}, "0 0").sum(1);
// Should not throw a segfault
EXPECT(assert_equal(*DecisionTreeFactor(X, "0 0").sum(1),
empty.toDecisionTreeFactor()));
auto expected_empty = DecisionTreeFactor(X, "0 0").sum(1);
EXPECT(assert_equal(expected_empty->toDecisionTreeFactor(),
empty->toDecisionTreeFactor()));
}
/* ************************************************************************* */
@ -303,15 +305,18 @@ TEST(TableFactor, sum_max) {
TableFactor f1(v0 & v1, "1 2 3 4 5 6");
TableFactor expected(v1, "9 12");
TableFactor::shared_ptr actual = f1.sum(1);
auto actual = std::dynamic_pointer_cast<TableFactor>(f1.sum(1));
CHECK(actual);
CHECK(assert_equal(expected, *actual, 1e-5));
TableFactor expected2(v1, "5 6");
TableFactor::shared_ptr actual2 = f1.max(1);
auto actual2 = std::dynamic_pointer_cast<TableFactor>(f1.max(1));
CHECK(actual2);
CHECK(assert_equal(expected2, *actual2));
TableFactor f2(v1 & v0, "1 2 3 4 5 6");
TableFactor::shared_ptr actual22 = f2.sum(1);
auto actual22 = std::dynamic_pointer_cast<TableFactor>(f2.sum(1));
CHECK(actual22);
}
/* ************************************************************************* */

View File

@ -68,7 +68,8 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor {
/*
* Ensure Arc-consistency by checking every possible value of domain j.
* @param j domain to be checked
* @param (in/out) domains all domains, but only domains->at(j) will be checked.
* @param (in/out) domains all domains, but only domains->at(j) will be
* checked.
* @return true if domains->at(j) was changed, false otherwise.
*/
virtual bool ensureArcConsistency(Key j, Domains* domains) const = 0;
@ -78,6 +79,39 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor {
/// Partially apply known values, domain version
virtual shared_ptr partiallyApply(const Domains&) const = 0;
/// Multiply factors, DiscreteFactor::shared_ptr edition
DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& df) const override {
return std::make_shared<DecisionTreeFactor>(
this->operator*(df->toDecisionTreeFactor()));
}
/// divide by DiscreteFactor::shared_ptr f (safely)
DiscreteFactor::shared_ptr operator/(
const DiscreteFactor::shared_ptr& df) const override {
return this->toDecisionTreeFactor() / df;
}
/// Get the number of non-zero values contained in this factor.
uint64_t nrValues() const override { return 1; };
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
return toDecisionTreeFactor().sum(nrFrontals);
}
DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
return toDecisionTreeFactor().sum(keys);
}
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
return toDecisionTreeFactor().max(nrFrontals);
}
DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
return toDecisionTreeFactor().max(keys);
}
/// @}
/// @name Wrapper support
/// @{

View File

@ -49,7 +49,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
/// Erase a value, non const :-(
void erase(size_t value) { values_.erase(value); }
size_t nrValues() const { return values_.size(); }
uint64_t nrValues() const override { return values_.size(); }
bool isSingleton() const { return nrValues() == 1; }

View File

@ -124,7 +124,7 @@ TEST(CSP, allInOne) {
EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9);
// Just for fun, create the product and check it
DecisionTreeFactor product = csp.product();
DecisionTreeFactor product = csp.product()->toDecisionTreeFactor();
// product.dot("product");
DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0");
EXPECT(assert_equal(expectedProduct, product));

View File

@ -113,7 +113,7 @@ TEST(schedulingExample, test) {
EXPECT(assert_equal(expected, (DiscreteFactorGraph)s));
// Do brute force product and output that to file
DecisionTreeFactor product = s.product();
DecisionTreeFactor product = s.product()->toDecisionTreeFactor();
// product.dot("scheduling", false);
// Do exact inference