Merge pull request #1919 from borglab/discrete-elimination-refactor

release/4.3a0
Varun Agrawal 2025-01-06 20:33:59 -05:00 committed by GitHub
commit 82d0ebc8fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 188 additions and 60 deletions

View File

@ -87,6 +87,25 @@ namespace gtsam {
return result; return result;
} }
/* ************************************************************************ */
DiscreteFactor::shared_ptr DecisionTreeFactor::operator/(
const DiscreteFactor::shared_ptr& f) const {
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
// Check if `f` is a TableFactor. If yes, then
// convert `this` to a TableFactor which is cheaper.
return std::make_shared<TableFactor>(tf->operator/(TableFactor(*this)));
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
// If `f` is a DecisionTreeFactor, divide normally.
return std::make_shared<DecisionTreeFactor>(this->operator/(*dtf));
} else {
// Else, convert `f` to a DecisionTreeFactor so we can divide
return std::make_shared<DecisionTreeFactor>(
this->operator/(f->toDecisionTreeFactor()));
}
}
/* ************************************************************************ */ /* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) { double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum // The use for safe_div is when we divide the product factor by the sum

View File

@ -184,26 +184,30 @@ namespace gtsam {
return apply(f, safe_div); return apply(f, safe_div);
} }
/// divide by DiscreteFactor::shared_ptr f (safely)
DiscreteFactor::shared_ptr operator/(
const DiscreteFactor::shared_ptr& f) const override;
/// Convert into a decision tree /// Convert into a decision tree
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; } DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
/// Create new factor by summing all values with the same separator values /// Create new factor by summing all values with the same separator values
shared_ptr sum(size_t nrFrontals) const { DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::add); return combine(nrFrontals, Ring::add);
} }
/// Create new factor by summing all values with the same separator values /// Create new factor by summing all values with the same separator values
shared_ptr sum(const Ordering& keys) const { DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
return combine(keys, Ring::add); return combine(keys, Ring::add);
} }
/// Create new factor by maximizing over all values with the same separator. /// Create new factor by maximizing over all values with the same separator.
shared_ptr max(size_t nrFrontals) const { DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::max); return combine(nrFrontals, Ring::max);
} }
/// Create new factor by maximizing over all values with the same separator. /// Create new factor by maximizing over all values with the same separator.
shared_ptr max(const Ordering& keys) const { DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
return combine(keys, Ring::max); return combine(keys, Ring::max);
} }
@ -284,6 +288,12 @@ namespace gtsam {
*/ */
DecisionTreeFactor prune(size_t maxNrAssignments) const; DecisionTreeFactor prune(size_t maxNrAssignments) const;
/**
* Get the number of non-zero values contained in this factor.
* It could be much smaller than `prod_{key}(cardinality(key))`.
*/
uint64_t nrValues() const override { return nrLeaves(); }
/// @} /// @}
/// @name Wrapper support /// @name Wrapper support
/// @{ /// @{

View File

@ -24,13 +24,13 @@
#include <gtsam/hybrid/HybridValues.h> #include <gtsam/hybrid/HybridValues.h>
#include <algorithm> #include <algorithm>
#include <cassert>
#include <random> #include <random>
#include <set> #include <set>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <cassert>
using namespace std; using namespace std;
using std::pair; using std::pair;
@ -44,8 +44,9 @@ template class GTSAM_EXPORT
/* ************************************************************************** */ /* ************************************************************************** */
DiscreteConditional::DiscreteConditional(const size_t nrFrontals, DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
const DecisionTreeFactor& f) const DiscreteFactor& f)
: BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {} : BaseFactor((f / f.sum(nrFrontals))->toDecisionTreeFactor()),
BaseConditional(nrFrontals) {}
/* ************************************************************************** */ /* ************************************************************************** */
DiscreteConditional::DiscreteConditional(size_t nrFrontals, DiscreteConditional::DiscreteConditional(size_t nrFrontals,
@ -150,11 +151,11 @@ void DiscreteConditional::print(const string& s,
/* ************************************************************************** */ /* ************************************************************************** */
bool DiscreteConditional::equals(const DiscreteFactor& other, bool DiscreteConditional::equals(const DiscreteFactor& other,
double tol) const { double tol) const {
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) { if (!dynamic_cast<const BaseFactor*>(&other)) {
return false; return false;
} else { } else {
const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other)); const BaseFactor& f(static_cast<const BaseFactor&>(other));
return DecisionTreeFactor::equals(f, tol); return BaseFactor::equals(f, tol);
} }
} }
@ -375,7 +376,7 @@ std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter,
ss << "*\n" << std::endl; ss << "*\n" << std::endl;
if (nrParents() == 0) { if (nrParents() == 0) {
// We have no parents, call factor method. // We have no parents, call factor method.
ss << DecisionTreeFactor::markdown(keyFormatter, names); ss << BaseFactor::markdown(keyFormatter, names);
return ss.str(); return ss.str();
} }
@ -427,7 +428,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,
ss << "</i></p>\n"; ss << "</i></p>\n";
if (nrParents() == 0) { if (nrParents() == 0) {
// We have no parents, call factor method. // We have no parents, call factor method.
ss << DecisionTreeFactor::html(keyFormatter, names); ss << BaseFactor::html(keyFormatter, names);
return ss.str(); return ss.str();
} }
@ -475,7 +476,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,
/* ************************************************************************* */ /* ************************************************************************* */
double DiscreteConditional::evaluate(const HybridValues& x) const { double DiscreteConditional::evaluate(const HybridValues& x) const {
return this->evaluate(x.discrete()); return this->operator()(x.discrete());
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -54,7 +54,7 @@ class GTSAM_EXPORT DiscreteConditional
DiscreteConditional() {} DiscreteConditional() {}
/// Construct from factor, taking the first `nFrontals` keys as frontals. /// Construct from factor, taking the first `nFrontals` keys as frontals.
DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f); DiscreteConditional(size_t nFrontals, const DiscreteFactor& f);
/** /**
* Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first * Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first

View File

@ -22,6 +22,7 @@
#include <gtsam/discrete/AlgebraicDecisionTree.h> #include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DiscreteValues.h> #include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/inference/Factor.h> #include <gtsam/inference/Factor.h>
#include <gtsam/inference/Ordering.h>
#include <string> #include <string>
namespace gtsam { namespace gtsam {
@ -139,8 +140,30 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
virtual DiscreteFactor::shared_ptr multiply( virtual DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& df) const = 0; const DiscreteFactor::shared_ptr& df) const = 0;
/// divide by DiscreteFactor::shared_ptr f (safely)
virtual DiscreteFactor::shared_ptr operator/(
const DiscreteFactor::shared_ptr& df) const = 0;
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
/// Create new factor by summing all values with the same separator values
virtual DiscreteFactor::shared_ptr sum(size_t nrFrontals) const = 0;
/// Create new factor by summing all values with the same separator values
virtual DiscreteFactor::shared_ptr sum(const Ordering& keys) const = 0;
/// Create new factor by maximizing over all values with the same separator.
virtual DiscreteFactor::shared_ptr max(size_t nrFrontals) const = 0;
/// Create new factor by maximizing over all values with the same separator.
virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const = 0;
/**
* Get the number of non-zero values contained in this factor.
* It could be much smaller than `prod_{key}(cardinality(key))`.
*/
virtual uint64_t nrValues() const = 0;
/// @} /// @}
/// @name Wrapper support /// @name Wrapper support
/// @{ /// @{

View File

@ -64,7 +64,7 @@ namespace gtsam {
} }
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor DiscreteFactorGraph::product() const { DiscreteFactor::shared_ptr DiscreteFactorGraph::product() const {
DiscreteFactor::shared_ptr result; DiscreteFactor::shared_ptr result;
for (auto it = this->begin(); it != this->end(); ++it) { for (auto it = this->begin(); it != this->end(); ++it) {
if (*it) { if (*it) {
@ -76,7 +76,7 @@ namespace gtsam {
} }
} }
} }
return result->toDecisionTreeFactor(); return result;
} }
/* ************************************************************************ */ /* ************************************************************************ */
@ -122,20 +122,20 @@ namespace gtsam {
* @brief Multiply all the `factors`. * @brief Multiply all the `factors`.
* *
* @param factors The factors to multiply as a DiscreteFactorGraph. * @param factors The factors to multiply as a DiscreteFactorGraph.
* @return DecisionTreeFactor * @return DiscreteFactor::shared_ptr
*/ */
static DecisionTreeFactor DiscreteProduct( static DiscreteFactor::shared_ptr DiscreteProduct(
const DiscreteFactorGraph& factors) { const DiscreteFactorGraph& factors) {
// PRODUCT: multiply all factors // PRODUCT: multiply all factors
gttic(product); gttic(product);
DecisionTreeFactor product = factors.product(); DiscreteFactor::shared_ptr product = factors.product();
gttoc(product); gttoc(product);
// Max over all the potentials by pretending all keys are frontal: // Max over all the potentials by pretending all keys are frontal:
auto denominator = product.max(product.size()); auto denominator = product->max(product->size());
// Normalize the product factor to prevent underflow. // Normalize the product factor to prevent underflow.
product = product / (*denominator); product = product->operator/(denominator);
return product; return product;
} }
@ -145,25 +145,25 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> // std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateForMPE(const DiscreteFactorGraph& factors, EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) { const Ordering& frontalKeys) {
DecisionTreeFactor product = DiscreteProduct(factors); DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
// max out frontals, this is the factor on the separator // max out frontals, this is the factor on the separator
gttic(max); gttic(max);
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys); DiscreteFactor::shared_ptr max = product->max(frontalKeys);
gttoc(max); gttoc(max);
// Ordering keys for the conditional so that frontalKeys are really in front // Ordering keys for the conditional so that frontalKeys are really in front
DiscreteKeys orderedKeys; DiscreteKeys orderedKeys;
for (auto&& key : frontalKeys) for (auto&& key : frontalKeys)
orderedKeys.emplace_back(key, product.cardinality(key)); orderedKeys.emplace_back(key, product->cardinality(key));
for (auto&& key : max->keys()) for (auto&& key : max->keys())
orderedKeys.emplace_back(key, product.cardinality(key)); orderedKeys.emplace_back(key, product->cardinality(key));
// Make lookup with product // Make lookup with product
gttic(lookup); gttic(lookup);
size_t nrFrontals = frontalKeys.size(); size_t nrFrontals = frontalKeys.size();
auto lookup = auto lookup = std::make_shared<DiscreteLookupTable>(
std::make_shared<DiscreteLookupTable>(nrFrontals, orderedKeys, product); nrFrontals, orderedKeys, product->toDecisionTreeFactor());
gttoc(lookup); gttoc(lookup);
return {std::dynamic_pointer_cast<DiscreteConditional>(lookup), max}; return {std::dynamic_pointer_cast<DiscreteConditional>(lookup), max};
@ -223,11 +223,11 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> // std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors, EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) { const Ordering& frontalKeys) {
DecisionTreeFactor product = DiscreteProduct(factors); DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
// sum out frontals, this is the factor on the separator // sum out frontals, this is the factor on the separator
gttic(sum); gttic(sum);
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys); DiscreteFactor::shared_ptr sum = product->sum(frontalKeys);
gttoc(sum); gttoc(sum);
// Ordering keys for the conditional so that frontalKeys are really in front // Ordering keys for the conditional so that frontalKeys are really in front
@ -239,8 +239,9 @@ namespace gtsam {
// now divide product/sum to get conditional // now divide product/sum to get conditional
gttic(divide); gttic(divide);
auto conditional = auto conditional = std::make_shared<DiscreteConditional>(
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys); product->toDecisionTreeFactor(), sum->toDecisionTreeFactor(),
orderedKeys);
gttoc(divide); gttoc(divide);
return {conditional, sum}; return {conditional, sum};

View File

@ -134,6 +134,7 @@ class GTSAM_EXPORT DiscreteFactorGraph
/// @} /// @}
//TODO(Varun): Make compatible with TableFactor
/** Add a decision-tree factor */ /** Add a decision-tree factor */
template <typename... Args> template <typename... Args>
void add(Args&&... args) { void add(Args&&... args) {
@ -147,7 +148,7 @@ class GTSAM_EXPORT DiscreteFactorGraph
DiscreteKeys discreteKeys() const; DiscreteKeys discreteKeys() const;
/** return product of all factors as a single factor */ /** return product of all factors as a single factor */
DecisionTreeFactor product() const; DiscreteFactor::shared_ptr product() const;
/** /**
* Evaluates the factor graph given values, returns the joint probability of * Evaluates the factor graph given values, returns the joint probability of

View File

@ -18,6 +18,7 @@
#pragma once #pragma once
#include <gtsam/discrete/DiscreteDistribution.h> #include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/inference/BayesNet.h> #include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h> #include <gtsam/inference/FactorGraph.h>
@ -54,6 +55,18 @@ class GTSAM_EXPORT DiscreteLookupTable : public DiscreteConditional {
const ADT& potentials) const ADT& potentials)
: DiscreteConditional(nFrontals, keys, potentials) {} : DiscreteConditional(nFrontals, keys, potentials) {}
/**
* @brief Construct a new Discrete Lookup Table object
*
* @param nFrontals number of frontal variables
* @param keys a sorted list of gtsam::Keys
* @param potentials Discrete potentials as a TableFactor.
*/
DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys,
const TableFactor& potentials)
: DiscreteConditional(nFrontals, keys,
potentials.toDecisionTreeFactor()) {}
/// GTSAM-style print /// GTSAM-style print
void print( void print(
const std::string& s = "Discrete Lookup Table: ", const std::string& s = "Discrete Lookup Table: ",

View File

@ -280,6 +280,20 @@ DiscreteFactor::shared_ptr TableFactor::multiply(
return result; return result;
} }
/* ************************************************************************ */
DiscreteFactor::shared_ptr TableFactor::operator/(
const DiscreteFactor::shared_ptr& f) const {
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
return std::make_shared<TableFactor>(this->operator/(*tf));
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
return std::make_shared<TableFactor>(
this->operator/(TableFactor(f->discreteKeys(), *dtf)));
} else {
TableFactor divisor(f->toDecisionTreeFactor());
return std::make_shared<TableFactor>(this->operator/(divisor));
}
}
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
DiscreteKeys dkeys = discreteKeys(); DiscreteKeys dkeys = discreteKeys();

View File

@ -17,6 +17,7 @@
#pragma once #pragma once
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteFactor.h> #include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/Ring.h> #include <gtsam/discrete/Ring.h>
@ -202,6 +203,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
return apply(f, safe_div); return apply(f, safe_div);
} }
/// divide by DiscreteFactor::shared_ptr f (safely)
DiscreteFactor::shared_ptr operator/(
const DiscreteFactor::shared_ptr& f) const override;
/// Convert into a decisiontree /// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override; DecisionTreeFactor toDecisionTreeFactor() const override;
@ -210,22 +215,22 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
DiscreteKeys parent_keys) const; DiscreteKeys parent_keys) const;
/// Create new factor by summing all values with the same separator values /// Create new factor by summing all values with the same separator values
shared_ptr sum(size_t nrFrontals) const { DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::add); return combine(nrFrontals, Ring::add);
} }
/// Create new factor by summing all values with the same separator values /// Create new factor by summing all values with the same separator values
shared_ptr sum(const Ordering& keys) const { DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
return combine(keys, Ring::add); return combine(keys, Ring::add);
} }
/// Create new factor by maximizing over all values with the same separator. /// Create new factor by maximizing over all values with the same separator.
shared_ptr max(size_t nrFrontals) const { DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::max); return combine(nrFrontals, Ring::max);
} }
/// Create new factor by maximizing over all values with the same separator. /// Create new factor by maximizing over all values with the same separator.
shared_ptr max(const Ordering& keys) const { DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
return combine(keys, Ring::max); return combine(keys, Ring::max);
} }
@ -330,6 +335,12 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
*/ */
TableFactor prune(size_t maxNrAssignments) const; TableFactor prune(size_t maxNrAssignments) const;
/**
* Get the number of non-zero values contained in this factor.
* It could be much smaller than `prod_{key}(cardinality(key))`.
*/
uint64_t nrValues() const override { return sparse_table_.nonZeros(); }
/// @} /// @}
/// @name Wrapper support /// @name Wrapper support
/// @{ /// @{

View File

@ -138,15 +138,18 @@ TEST(DecisionTreeFactor, sum_max) {
DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6"); DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6");
DecisionTreeFactor expected(v1, "9 12"); DecisionTreeFactor expected(v1, "9 12");
DecisionTreeFactor::shared_ptr actual = f1.sum(1); auto actual = std::dynamic_pointer_cast<DecisionTreeFactor>(f1.sum(1));
CHECK(actual);
CHECK(assert_equal(expected, *actual, 1e-5)); CHECK(assert_equal(expected, *actual, 1e-5));
DecisionTreeFactor expected2(v1, "5 6"); DecisionTreeFactor expected2(v1, "5 6");
DecisionTreeFactor::shared_ptr actual2 = f1.max(1); auto actual2 = std::dynamic_pointer_cast<DecisionTreeFactor>(f1.max(1));
CHECK(actual2);
CHECK(assert_equal(expected2, *actual2)); CHECK(assert_equal(expected2, *actual2));
DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6"); DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6");
DecisionTreeFactor::shared_ptr actual22 = f2.sum(1); auto actual22 = std::dynamic_pointer_cast<DecisionTreeFactor>(f2.sum(1));
CHECK(actual22);
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -46,7 +46,7 @@ TEST(DiscreteConditional, constructors) {
DecisionTreeFactor f2( DecisionTreeFactor f2(
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
DiscreteConditional actual2(1, f2); DiscreteConditional actual2(1, f2);
DecisionTreeFactor expected2 = f2 / *f2.sum(1); DecisionTreeFactor expected2 = f2 / f2.sum(1)->toDecisionTreeFactor();
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2))); EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
std::vector<double> probs{0.2, 0.5, 0.3, 0.6, 0.4, 0.7, 0.25, 0.55, 0.35, 0.65, 0.45, 0.75}; std::vector<double> probs{0.2, 0.5, 0.3, 0.6, 0.4, 0.7, 0.25, 0.55, 0.35, 0.65, 0.45, 0.75};
@ -70,7 +70,7 @@ TEST(DiscreteConditional, constructors_alt_interface) {
DecisionTreeFactor f2( DecisionTreeFactor f2(
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
DiscreteConditional actual2(1, f2); DiscreteConditional actual2(1, f2);
DecisionTreeFactor expected2 = f2 / *f2.sum(1); DecisionTreeFactor expected2 = f2 / f2.sum(1)->toDecisionTreeFactor();
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2))); EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
} }

View File

@ -94,7 +94,7 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) {
EXPECT_DOUBLES_EQUAL( 1.944, graph(values), 1e-9); EXPECT_DOUBLES_EQUAL( 1.944, graph(values), 1e-9);
// Check if graph product works // Check if graph product works
DecisionTreeFactor product = graph.product(); DecisionTreeFactor product = graph.product()->toDecisionTreeFactor();
EXPECT_DOUBLES_EQUAL( 1.944, product(values), 1e-9); EXPECT_DOUBLES_EQUAL( 1.944, product(values), 1e-9);
} }
@ -117,9 +117,9 @@ TEST(DiscreteFactorGraph, test) {
*std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr); *std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr);
// Normalize newFactor by max for comparison with expected // Normalize newFactor by max for comparison with expected
auto normalizer = newFactor.max(newFactor.size()); auto denominator = newFactor.max(newFactor.size())->toDecisionTreeFactor();
newFactor = newFactor / *normalizer; newFactor = newFactor / denominator;
// Check Conditional // Check Conditional
CHECK(conditional); CHECK(conditional);
@ -131,9 +131,10 @@ TEST(DiscreteFactorGraph, test) {
CHECK(&newFactor); CHECK(&newFactor);
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10"); DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
// Normalize by max. // Normalize by max.
normalizer = expectedFactor.max(expectedFactor.size()); denominator =
// Ensure normalizer is correct. expectedFactor.max(expectedFactor.size())->toDecisionTreeFactor();
expectedFactor = expectedFactor / *normalizer; // Ensure denominator is correct.
expectedFactor = expectedFactor / denominator;
EXPECT(assert_equal(expectedFactor, newFactor)); EXPECT(assert_equal(expectedFactor, newFactor));
// Test using elimination tree // Test using elimination tree

View File

@ -194,15 +194,17 @@ TEST(TableFactor, Conversion) {
TEST(TableFactor, Empty) { TEST(TableFactor, Empty) {
DiscreteKey X(1, 2); DiscreteKey X(1, 2);
TableFactor single = *TableFactor({X}, "1 1").sum(1); auto single = TableFactor({X}, "1 1").sum(1);
// Should not throw a segfault // Should not throw a segfault
EXPECT(assert_equal(*DecisionTreeFactor(X, "1 1").sum(1), auto expected_single = DecisionTreeFactor(X, "1 1").sum(1);
single.toDecisionTreeFactor())); EXPECT(assert_equal(expected_single->toDecisionTreeFactor(),
single->toDecisionTreeFactor()));
TableFactor empty = *TableFactor({X}, "0 0").sum(1); auto empty = TableFactor({X}, "0 0").sum(1);
// Should not throw a segfault // Should not throw a segfault
EXPECT(assert_equal(*DecisionTreeFactor(X, "0 0").sum(1), auto expected_empty = DecisionTreeFactor(X, "0 0").sum(1);
empty.toDecisionTreeFactor())); EXPECT(assert_equal(expected_empty->toDecisionTreeFactor(),
empty->toDecisionTreeFactor()));
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -303,15 +305,18 @@ TEST(TableFactor, sum_max) {
TableFactor f1(v0 & v1, "1 2 3 4 5 6"); TableFactor f1(v0 & v1, "1 2 3 4 5 6");
TableFactor expected(v1, "9 12"); TableFactor expected(v1, "9 12");
TableFactor::shared_ptr actual = f1.sum(1); auto actual = std::dynamic_pointer_cast<TableFactor>(f1.sum(1));
CHECK(actual);
CHECK(assert_equal(expected, *actual, 1e-5)); CHECK(assert_equal(expected, *actual, 1e-5));
TableFactor expected2(v1, "5 6"); TableFactor expected2(v1, "5 6");
TableFactor::shared_ptr actual2 = f1.max(1); auto actual2 = std::dynamic_pointer_cast<TableFactor>(f1.max(1));
CHECK(actual2);
CHECK(assert_equal(expected2, *actual2)); CHECK(assert_equal(expected2, *actual2));
TableFactor f2(v1 & v0, "1 2 3 4 5 6"); TableFactor f2(v1 & v0, "1 2 3 4 5 6");
TableFactor::shared_ptr actual22 = f2.sum(1); auto actual22 = std::dynamic_pointer_cast<TableFactor>(f2.sum(1));
CHECK(actual22);
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -68,7 +68,8 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor {
/* /*
* Ensure Arc-consistency by checking every possible value of domain j. * Ensure Arc-consistency by checking every possible value of domain j.
* @param j domain to be checked * @param j domain to be checked
* @param (in/out) domains all domains, but only domains->at(j) will be checked. * @param (in/out) domains all domains, but only domains->at(j) will be
* checked.
* @return true if domains->at(j) was changed, false otherwise. * @return true if domains->at(j) was changed, false otherwise.
*/ */
virtual bool ensureArcConsistency(Key j, Domains* domains) const = 0; virtual bool ensureArcConsistency(Key j, Domains* domains) const = 0;
@ -86,6 +87,31 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor {
this->operator*(df->toDecisionTreeFactor())); this->operator*(df->toDecisionTreeFactor()));
} }
/// divide by DiscreteFactor::shared_ptr f (safely)
DiscreteFactor::shared_ptr operator/(
const DiscreteFactor::shared_ptr& df) const override {
return this->toDecisionTreeFactor() / df;
}
/// Get the number of non-zero values contained in this factor.
uint64_t nrValues() const override { return 1; };
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
return toDecisionTreeFactor().sum(nrFrontals);
}
DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
return toDecisionTreeFactor().sum(keys);
}
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
return toDecisionTreeFactor().max(nrFrontals);
}
DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
return toDecisionTreeFactor().max(keys);
}
/// @} /// @}
/// @name Wrapper support /// @name Wrapper support
/// @{ /// @{

View File

@ -49,7 +49,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
/// Erase a value, non const :-( /// Erase a value, non const :-(
void erase(size_t value) { values_.erase(value); } void erase(size_t value) { values_.erase(value); }
size_t nrValues() const { return values_.size(); } uint64_t nrValues() const override { return values_.size(); }
bool isSingleton() const { return nrValues() == 1; } bool isSingleton() const { return nrValues() == 1; }

View File

@ -124,7 +124,7 @@ TEST(CSP, allInOne) {
EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9); EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9);
// Just for fun, create the product and check it // Just for fun, create the product and check it
DecisionTreeFactor product = csp.product(); DecisionTreeFactor product = csp.product()->toDecisionTreeFactor();
// product.dot("product"); // product.dot("product");
DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0"); DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0");
EXPECT(assert_equal(expectedProduct, product)); EXPECT(assert_equal(expectedProduct, product));

View File

@ -113,7 +113,7 @@ TEST(schedulingExample, test) {
EXPECT(assert_equal(expected, (DiscreteFactorGraph)s)); EXPECT(assert_equal(expected, (DiscreteFactorGraph)s));
// Do brute force product and output that to file // Do brute force product and output that to file
DecisionTreeFactor product = s.product(); DecisionTreeFactor product = s.product()->toDecisionTreeFactor();
// product.dot("scheduling", false); // product.dot("scheduling", false);
// Do exact inference // Do exact inference