Merge branch 'hybrid-timing' into discrete-table-conditional
commit
8658f25edd
|
@ -18,9 +18,10 @@
|
|||
*/
|
||||
|
||||
#include <gtsam/base/FastSet.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/discrete/TableFactor.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
|
||||
#include <utility>
|
||||
|
||||
|
@ -62,6 +63,49 @@ namespace gtsam {
|
|||
return error(values.discrete());
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DiscreteFactor::shared_ptr DecisionTreeFactor::multiply(
|
||||
const DiscreteFactor::shared_ptr& f) const {
|
||||
DiscreteFactor::shared_ptr result;
|
||||
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
|
||||
// If f is a TableFactor, we convert `this` to a TableFactor since this
|
||||
// conversion is cheaper than converting `f` to a DecisionTreeFactor. We
|
||||
// then return a TableFactor.
|
||||
result = std::make_shared<TableFactor>((*tf) * TableFactor(*this));
|
||||
|
||||
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||
// If `f` is a DecisionTreeFactor, simply call operator*.
|
||||
result = std::make_shared<DecisionTreeFactor>(this->operator*(*dtf));
|
||||
|
||||
} else {
|
||||
// Simulate double dispatch in C++
|
||||
// Useful for other classes which inherit from DiscreteFactor and have
|
||||
// only `operator*(DecisionTreeFactor)` defined. Thus, other classes don't
|
||||
// need to be updated.
|
||||
result = std::make_shared<DecisionTreeFactor>(f->operator*(*this));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DiscreteFactor::shared_ptr DecisionTreeFactor::operator/(
|
||||
const DiscreteFactor::shared_ptr& f) const {
|
||||
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
|
||||
// Check if `f` is a TableFactor. If yes, then
|
||||
// convert `this` to a TableFactor which is cheaper.
|
||||
return std::make_shared<TableFactor>(tf->operator/(TableFactor(*this)));
|
||||
|
||||
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||
// If `f` is a DecisionTreeFactor, divide normally.
|
||||
return std::make_shared<DecisionTreeFactor>(this->operator/(*dtf));
|
||||
|
||||
} else {
|
||||
// Else, convert `f` to a DecisionTreeFactor so we can divide
|
||||
return std::make_shared<DecisionTreeFactor>(
|
||||
this->operator/(f->toDecisionTreeFactor()));
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
|
||||
// The use for safe_div is when we divide the product factor by the sum
|
||||
|
|
|
@ -147,6 +147,23 @@ namespace gtsam {
|
|||
/// Calculate error for DiscreteValues `x`, is -log(probability).
|
||||
double error(const DiscreteValues& values) const override;
|
||||
|
||||
/**
|
||||
* @brief Multiply factors, DiscreteFactor::shared_ptr edition.
|
||||
*
|
||||
* This method accepts `DiscreteFactor::shared_ptr` and uses dynamic
|
||||
* dispatch and specializations to perform the most efficient
|
||||
* multiplication.
|
||||
*
|
||||
* While converting a DecisionTreeFactor to a TableFactor is efficient, the
|
||||
* reverse is not. Hence we specialize the code to return a TableFactor if
|
||||
* `f` is a TableFactor, and DecisionTreeFactor otherwise.
|
||||
*
|
||||
* @param f The factor to multiply with.
|
||||
* @return DiscreteFactor::shared_ptr
|
||||
*/
|
||||
virtual DiscreteFactor::shared_ptr multiply(
|
||||
const DiscreteFactor::shared_ptr& f) const override;
|
||||
|
||||
/// multiply two factors
|
||||
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
|
||||
return apply(f, Ring::mul);
|
||||
|
@ -154,31 +171,43 @@ namespace gtsam {
|
|||
|
||||
static double safe_div(const double& a, const double& b);
|
||||
|
||||
/// divide by factor f (safely)
|
||||
/**
|
||||
* @brief Divide by factor f (safely).
|
||||
* Division of a factor \f$f(x, y)\f$ by another factor \f$g(y, z)\f$
|
||||
* results in a function which involves all keys
|
||||
* \f$(\frac{f}{g})(x, y, z) = f(x, y) / g(y, z)\f$
|
||||
*
|
||||
* @param f The DecisinTreeFactor to divide by.
|
||||
* @return DecisionTreeFactor
|
||||
*/
|
||||
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
|
||||
return apply(f, safe_div);
|
||||
}
|
||||
|
||||
/// divide by DiscreteFactor::shared_ptr f (safely)
|
||||
DiscreteFactor::shared_ptr operator/(
|
||||
const DiscreteFactor::shared_ptr& f) const override;
|
||||
|
||||
/// Convert into a decision tree
|
||||
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
|
||||
|
||||
/// Create new factor by summing all values with the same separator values
|
||||
shared_ptr sum(size_t nrFrontals) const {
|
||||
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
|
||||
return combine(nrFrontals, Ring::add);
|
||||
}
|
||||
|
||||
/// Create new factor by summing all values with the same separator values
|
||||
shared_ptr sum(const Ordering& keys) const {
|
||||
DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
|
||||
return combine(keys, Ring::add);
|
||||
}
|
||||
|
||||
/// Create new factor by maximizing over all values with the same separator.
|
||||
shared_ptr max(size_t nrFrontals) const {
|
||||
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
|
||||
return combine(nrFrontals, Ring::max);
|
||||
}
|
||||
|
||||
/// Create new factor by maximizing over all values with the same separator.
|
||||
shared_ptr max(const Ordering& keys) const {
|
||||
DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
|
||||
return combine(keys, Ring::max);
|
||||
}
|
||||
|
||||
|
@ -259,6 +288,12 @@ namespace gtsam {
|
|||
*/
|
||||
DecisionTreeFactor prune(size_t maxNrAssignments) const;
|
||||
|
||||
/**
|
||||
* Get the number of non-zero values contained in this factor.
|
||||
* It could be much smaller than `prod_{key}(cardinality(key))`.
|
||||
*/
|
||||
uint64_t nrValues() const override { return nrLeaves(); }
|
||||
|
||||
/// @}
|
||||
/// @name Wrapper support
|
||||
/// @{
|
||||
|
|
|
@ -44,8 +44,9 @@ template class GTSAM_EXPORT
|
|||
|
||||
/* ************************************************************************** */
|
||||
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
|
||||
const DecisionTreeFactor& f)
|
||||
: BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {}
|
||||
const DiscreteFactor& f)
|
||||
: BaseFactor((f / f.sum(nrFrontals))->toDecisionTreeFactor()),
|
||||
BaseConditional(nrFrontals) {}
|
||||
|
||||
/* ************************************************************************** */
|
||||
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
|
||||
|
@ -150,11 +151,11 @@ void DiscreteConditional::print(const string& s,
|
|||
/* ************************************************************************** */
|
||||
bool DiscreteConditional::equals(const DiscreteFactor& other,
|
||||
double tol) const {
|
||||
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
|
||||
if (!dynamic_cast<const BaseFactor*>(&other)) {
|
||||
return false;
|
||||
} else {
|
||||
const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other));
|
||||
return DecisionTreeFactor::equals(f, tol);
|
||||
const BaseFactor& f(static_cast<const BaseFactor&>(other));
|
||||
return BaseFactor::equals(f, tol);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -375,7 +376,7 @@ std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter,
|
|||
ss << "*\n" << std::endl;
|
||||
if (nrParents() == 0) {
|
||||
// We have no parents, call factor method.
|
||||
ss << DecisionTreeFactor::markdown(keyFormatter, names);
|
||||
ss << BaseFactor::markdown(keyFormatter, names);
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
@ -427,7 +428,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,
|
|||
ss << "</i></p>\n";
|
||||
if (nrParents() == 0) {
|
||||
// We have no parents, call factor method.
|
||||
ss << DecisionTreeFactor::html(keyFormatter, names);
|
||||
ss << BaseFactor::html(keyFormatter, names);
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
@ -475,7 +476,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,
|
|||
|
||||
/* ************************************************************************* */
|
||||
double DiscreteConditional::evaluate(const HybridValues& x) const {
|
||||
return this->evaluate(x.discrete());
|
||||
return this->operator()(x.discrete());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -54,7 +54,7 @@ class GTSAM_EXPORT DiscreteConditional
|
|||
DiscreteConditional() {}
|
||||
|
||||
/// Construct from factor, taking the first `nFrontals` keys as frontals.
|
||||
DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f);
|
||||
DiscreteConditional(size_t nFrontals, const DiscreteFactor& f);
|
||||
|
||||
/**
|
||||
* Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
||||
#include <gtsam/discrete/DiscreteValues.h>
|
||||
#include <gtsam/inference/Factor.h>
|
||||
#include <gtsam/inference/Ordering.h>
|
||||
|
||||
#include <string>
|
||||
namespace gtsam {
|
||||
|
@ -129,8 +130,40 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
|
|||
/// DecisionTreeFactor
|
||||
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
|
||||
|
||||
/**
|
||||
* @brief Multiply in a DiscreteFactor and return the result as
|
||||
* DiscreteFactor, both via shared pointers.
|
||||
*
|
||||
* @param df DiscreteFactor shared_ptr
|
||||
* @return DiscreteFactor::shared_ptr
|
||||
*/
|
||||
virtual DiscreteFactor::shared_ptr multiply(
|
||||
const DiscreteFactor::shared_ptr& df) const = 0;
|
||||
|
||||
/// divide by DiscreteFactor::shared_ptr f (safely)
|
||||
virtual DiscreteFactor::shared_ptr operator/(
|
||||
const DiscreteFactor::shared_ptr& df) const = 0;
|
||||
|
||||
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
|
||||
|
||||
/// Create new factor by summing all values with the same separator values
|
||||
virtual DiscreteFactor::shared_ptr sum(size_t nrFrontals) const = 0;
|
||||
|
||||
/// Create new factor by summing all values with the same separator values
|
||||
virtual DiscreteFactor::shared_ptr sum(const Ordering& keys) const = 0;
|
||||
|
||||
/// Create new factor by maximizing over all values with the same separator.
|
||||
virtual DiscreteFactor::shared_ptr max(size_t nrFrontals) const = 0;
|
||||
|
||||
/// Create new factor by maximizing over all values with the same separator.
|
||||
virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const = 0;
|
||||
|
||||
/**
|
||||
* Get the number of non-zero values contained in this factor.
|
||||
* It could be much smaller than `prod_{key}(cardinality(key))`.
|
||||
*/
|
||||
virtual uint64_t nrValues() const = 0;
|
||||
|
||||
/// @}
|
||||
/// @name Wrapper support
|
||||
/// @{
|
||||
|
|
|
@ -64,10 +64,17 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor DiscreteFactorGraph::product() const {
|
||||
DecisionTreeFactor result;
|
||||
for (const sharedFactor& factor : *this) {
|
||||
if (factor) result = (*factor) * result;
|
||||
DiscreteFactor::shared_ptr DiscreteFactorGraph::product() const {
|
||||
DiscreteFactor::shared_ptr result;
|
||||
for (auto it = this->begin(); it != this->end(); ++it) {
|
||||
if (*it) {
|
||||
if (result) {
|
||||
result = result->multiply(*it);
|
||||
} else {
|
||||
// Assign to the first non-null factor
|
||||
result = *it;
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
@ -115,21 +122,23 @@ namespace gtsam {
|
|||
* @brief Multiply all the `factors`.
|
||||
*
|
||||
* @param factors The factors to multiply as a DiscreteFactorGraph.
|
||||
* @return DecisionTreeFactor
|
||||
* @return DiscreteFactor::shared_ptr
|
||||
*/
|
||||
static DecisionTreeFactor DiscreteProduct(
|
||||
static DiscreteFactor::shared_ptr DiscreteProduct(
|
||||
const DiscreteFactorGraph& factors) {
|
||||
// PRODUCT: multiply all factors
|
||||
DecisionTreeFactor product = factors.product();
|
||||
gttic(product);
|
||||
DiscreteFactor::shared_ptr product = factors.product();
|
||||
gttoc(product);
|
||||
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(DiscreteNormalize);
|
||||
#endif
|
||||
// Max over all the potentials by pretending all keys are frontal:
|
||||
auto denominator = product.max(product.size());
|
||||
auto denominator = product->max(product->size());
|
||||
|
||||
// Normalize the product factor to prevent underflow.
|
||||
product = product / (*denominator);
|
||||
product = product->operator/(denominator);
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(DiscreteNormalize);
|
||||
#endif
|
||||
|
@ -142,25 +151,25 @@ namespace gtsam {
|
|||
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
|
||||
EliminateForMPE(const DiscreteFactorGraph& factors,
|
||||
const Ordering& frontalKeys) {
|
||||
DecisionTreeFactor product = DiscreteProduct(factors);
|
||||
DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
|
||||
|
||||
// max out frontals, this is the factor on the separator
|
||||
gttic(max);
|
||||
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
|
||||
DiscreteFactor::shared_ptr max = product->max(frontalKeys);
|
||||
gttoc(max);
|
||||
|
||||
// Ordering keys for the conditional so that frontalKeys are really in front
|
||||
DiscreteKeys orderedKeys;
|
||||
for (auto&& key : frontalKeys)
|
||||
orderedKeys.emplace_back(key, product.cardinality(key));
|
||||
orderedKeys.emplace_back(key, product->cardinality(key));
|
||||
for (auto&& key : max->keys())
|
||||
orderedKeys.emplace_back(key, product.cardinality(key));
|
||||
orderedKeys.emplace_back(key, product->cardinality(key));
|
||||
|
||||
// Make lookup with product
|
||||
gttic(lookup);
|
||||
size_t nrFrontals = frontalKeys.size();
|
||||
auto lookup =
|
||||
std::make_shared<DiscreteLookupTable>(nrFrontals, orderedKeys, product);
|
||||
auto lookup = std::make_shared<DiscreteLookupTable>(
|
||||
nrFrontals, orderedKeys, product->toDecisionTreeFactor());
|
||||
gttoc(lookup);
|
||||
|
||||
return {std::dynamic_pointer_cast<DiscreteConditional>(lookup), max};
|
||||
|
@ -220,10 +229,12 @@ namespace gtsam {
|
|||
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
|
||||
EliminateDiscrete(const DiscreteFactorGraph& factors,
|
||||
const Ordering& frontalKeys) {
|
||||
DecisionTreeFactor product = DiscreteProduct(factors);
|
||||
DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
|
||||
|
||||
// sum out frontals, this is the factor on the separator
|
||||
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
|
||||
gttic(sum);
|
||||
DiscreteFactor::shared_ptr sum = product->sum(frontalKeys);
|
||||
gttoc(sum);
|
||||
|
||||
// Ordering keys for the conditional so that frontalKeys are really in front
|
||||
Ordering orderedKeys;
|
||||
|
@ -233,8 +244,11 @@ namespace gtsam {
|
|||
sum->keys().end());
|
||||
|
||||
// now divide product/sum to get conditional
|
||||
auto conditional =
|
||||
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
|
||||
gttic(divide);
|
||||
auto conditional = std::make_shared<DiscreteConditional>(
|
||||
product->toDecisionTreeFactor(), sum->toDecisionTreeFactor(),
|
||||
orderedKeys);
|
||||
gttoc(divide);
|
||||
|
||||
return {conditional, sum};
|
||||
}
|
||||
|
|
|
@ -134,6 +134,7 @@ class GTSAM_EXPORT DiscreteFactorGraph
|
|||
|
||||
/// @}
|
||||
|
||||
//TODO(Varun): Make compatible with TableFactor
|
||||
/** Add a decision-tree factor */
|
||||
template <typename... Args>
|
||||
void add(Args&&... args) {
|
||||
|
@ -147,7 +148,7 @@ class GTSAM_EXPORT DiscreteFactorGraph
|
|||
DiscreteKeys discreteKeys() const;
|
||||
|
||||
/** return product of all factors as a single factor */
|
||||
DecisionTreeFactor product() const;
|
||||
DiscreteFactor::shared_ptr product() const;
|
||||
|
||||
/**
|
||||
* Evaluates the factor graph given values, returns the joint probability of
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||
#include <gtsam/discrete/TableFactor.h>
|
||||
#include <gtsam/inference/BayesNet.h>
|
||||
#include <gtsam/inference/FactorGraph.h>
|
||||
|
||||
|
@ -54,6 +55,18 @@ class GTSAM_EXPORT DiscreteLookupTable : public DiscreteConditional {
|
|||
const ADT& potentials)
|
||||
: DiscreteConditional(nFrontals, keys, potentials) {}
|
||||
|
||||
/**
|
||||
* @brief Construct a new Discrete Lookup Table object
|
||||
*
|
||||
* @param nFrontals number of frontal variables
|
||||
* @param keys a sorted list of gtsam::Keys
|
||||
* @param potentials Discrete potentials as a TableFactor.
|
||||
*/
|
||||
DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys,
|
||||
const TableFactor& potentials)
|
||||
: DiscreteConditional(nFrontals, keys,
|
||||
potentials.toDecisionTreeFactor()) {}
|
||||
|
||||
/// GTSAM-style print
|
||||
void print(
|
||||
const std::string& s = "Discrete Lookup Table: ",
|
||||
|
|
|
@ -254,6 +254,46 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
|
|||
return toDecisionTreeFactor() * f;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DiscreteFactor::shared_ptr TableFactor::multiply(
|
||||
const DiscreteFactor::shared_ptr& f) const {
|
||||
DiscreteFactor::shared_ptr result;
|
||||
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
|
||||
// If `f` is a TableFactor, we can simply call `operator*`.
|
||||
result = std::make_shared<TableFactor>(this->operator*(*tf));
|
||||
|
||||
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||
// If `f` is a DecisionTreeFactor, we convert to a TableFactor which is
|
||||
// cheaper than converting `this` to a DecisionTreeFactor.
|
||||
result = std::make_shared<TableFactor>(this->operator*(TableFactor(*dtf)));
|
||||
|
||||
} else {
|
||||
// Simulate double dispatch in C++
|
||||
// Useful for other classes which inherit from DiscreteFactor and have
|
||||
// only `operator*(DecisionTreeFactor)` defined. Thus, other classes don't
|
||||
// need to be updated to know about TableFactor.
|
||||
// Those classes can be specialized to use TableFactor
|
||||
// if efficiency is a problem.
|
||||
result = std::make_shared<DecisionTreeFactor>(
|
||||
f->operator*(this->toDecisionTreeFactor()));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DiscreteFactor::shared_ptr TableFactor::operator/(
|
||||
const DiscreteFactor::shared_ptr& f) const {
|
||||
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
|
||||
return std::make_shared<TableFactor>(this->operator/(*tf));
|
||||
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||
return std::make_shared<TableFactor>(
|
||||
this->operator/(TableFactor(f->discreteKeys(), *dtf)));
|
||||
} else {
|
||||
TableFactor divisor(f->toDecisionTreeFactor());
|
||||
return std::make_shared<TableFactor>(this->operator/(divisor));
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
|
||||
DiscreteKeys dkeys = discreteKeys();
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteFactor.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/discrete/Ring.h>
|
||||
|
@ -178,6 +179,23 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
|||
/// multiply with DecisionTreeFactor
|
||||
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
|
||||
|
||||
/**
|
||||
* @brief Multiply factors, DiscreteFactor::shared_ptr edition.
|
||||
*
|
||||
* This method accepts `DiscreteFactor::shared_ptr` and uses dynamic
|
||||
* dispatch and specializations to perform the most efficient
|
||||
* multiplication.
|
||||
*
|
||||
* While converting a DecisionTreeFactor to a TableFactor is efficient, the
|
||||
* reverse is not.
|
||||
* Hence we specialize the code to return a TableFactor always.
|
||||
*
|
||||
* @param f The factor to multiply with.
|
||||
* @return DiscreteFactor::shared_ptr
|
||||
*/
|
||||
virtual DiscreteFactor::shared_ptr multiply(
|
||||
const DiscreteFactor::shared_ptr& f) const override;
|
||||
|
||||
static double safe_div(const double& a, const double& b);
|
||||
|
||||
/// divide by factor f (safely)
|
||||
|
@ -185,6 +203,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
|||
return apply(f, safe_div);
|
||||
}
|
||||
|
||||
/// divide by DiscreteFactor::shared_ptr f (safely)
|
||||
DiscreteFactor::shared_ptr operator/(
|
||||
const DiscreteFactor::shared_ptr& f) const override;
|
||||
|
||||
/// Convert into a decisiontree
|
||||
DecisionTreeFactor toDecisionTreeFactor() const override;
|
||||
|
||||
|
@ -193,22 +215,22 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
|||
DiscreteKeys parent_keys) const;
|
||||
|
||||
/// Create new factor by summing all values with the same separator values
|
||||
shared_ptr sum(size_t nrFrontals) const {
|
||||
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
|
||||
return combine(nrFrontals, Ring::add);
|
||||
}
|
||||
|
||||
/// Create new factor by summing all values with the same separator values
|
||||
shared_ptr sum(const Ordering& keys) const {
|
||||
DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
|
||||
return combine(keys, Ring::add);
|
||||
}
|
||||
|
||||
/// Create new factor by maximizing over all values with the same separator.
|
||||
shared_ptr max(size_t nrFrontals) const {
|
||||
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
|
||||
return combine(nrFrontals, Ring::max);
|
||||
}
|
||||
|
||||
/// Create new factor by maximizing over all values with the same separator.
|
||||
shared_ptr max(const Ordering& keys) const {
|
||||
DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
|
||||
return combine(keys, Ring::max);
|
||||
}
|
||||
|
||||
|
@ -313,6 +335,12 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
|||
*/
|
||||
TableFactor prune(size_t maxNrAssignments) const;
|
||||
|
||||
/**
|
||||
* Get the number of non-zero values contained in this factor.
|
||||
* It could be much smaller than `prod_{key}(cardinality(key))`.
|
||||
*/
|
||||
uint64_t nrValues() const override { return sparse_table_.nonZeros(); }
|
||||
|
||||
/// @}
|
||||
/// @name Wrapper support
|
||||
/// @{
|
||||
|
|
|
@ -30,6 +30,12 @@
|
|||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
||||
/** Convert Signature into CPT */
|
||||
DecisionTreeFactor create(const Signature& signature) {
|
||||
DecisionTreeFactor p(signature.discreteKeys(), signature.cpt());
|
||||
return p;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DecisionTreeFactor, ConstructorsMatch) {
|
||||
// Declare two keys
|
||||
|
@ -105,21 +111,45 @@ TEST(DecisionTreeFactor, multiplication) {
|
|||
CHECK(assert_equal(expected2, actual));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DecisionTreeFactor, Divide) {
|
||||
DiscreteKey A(0, 2), S(1, 2);
|
||||
DecisionTreeFactor pA = create(A % "99/1"), pS = create(S % "50/50");
|
||||
DecisionTreeFactor joint = pA * pS;
|
||||
|
||||
DecisionTreeFactor s = joint / pA;
|
||||
|
||||
// Factors are not equal due to difference in keys
|
||||
EXPECT(assert_inequal(pS, s));
|
||||
|
||||
// The underlying data should be the same
|
||||
using ADT = AlgebraicDecisionTree<Key>;
|
||||
EXPECT(assert_equal(ADT(pS), ADT(s)));
|
||||
|
||||
KeySet keys(joint.keys());
|
||||
keys.insert(pA.keys().begin(), pA.keys().end());
|
||||
EXPECT(assert_inequal(KeySet(pS.keys()), keys));
|
||||
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DecisionTreeFactor, sum_max) {
|
||||
DiscreteKey v0(0, 3), v1(1, 2);
|
||||
DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6");
|
||||
|
||||
DecisionTreeFactor expected(v1, "9 12");
|
||||
DecisionTreeFactor::shared_ptr actual = f1.sum(1);
|
||||
auto actual = std::dynamic_pointer_cast<DecisionTreeFactor>(f1.sum(1));
|
||||
CHECK(actual);
|
||||
CHECK(assert_equal(expected, *actual, 1e-5));
|
||||
|
||||
DecisionTreeFactor expected2(v1, "5 6");
|
||||
DecisionTreeFactor::shared_ptr actual2 = f1.max(1);
|
||||
auto actual2 = std::dynamic_pointer_cast<DecisionTreeFactor>(f1.max(1));
|
||||
CHECK(actual2);
|
||||
CHECK(assert_equal(expected2, *actual2));
|
||||
|
||||
DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6");
|
||||
DecisionTreeFactor::shared_ptr actual22 = f2.sum(1);
|
||||
auto actual22 = std::dynamic_pointer_cast<DecisionTreeFactor>(f2.sum(1));
|
||||
CHECK(actual22);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -217,12 +247,6 @@ void maybeSaveDotFile(const DecisionTreeFactor& f, const string& filename) {
|
|||
#endif
|
||||
}
|
||||
|
||||
/** Convert Signature into CPT */
|
||||
DecisionTreeFactor create(const Signature& signature) {
|
||||
DecisionTreeFactor p(signature.discreteKeys(), signature.cpt());
|
||||
return p;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// test Asia Joint
|
||||
TEST(DecisionTreeFactor, joint) {
|
||||
|
|
|
@ -46,7 +46,7 @@ TEST(DiscreteConditional, constructors) {
|
|||
DecisionTreeFactor f2(
|
||||
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
||||
DiscreteConditional actual2(1, f2);
|
||||
DecisionTreeFactor expected2 = f2 / *f2.sum(1);
|
||||
DecisionTreeFactor expected2 = f2 / f2.sum(1)->toDecisionTreeFactor();
|
||||
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
|
||||
|
||||
std::vector<double> probs{0.2, 0.5, 0.3, 0.6, 0.4, 0.7, 0.25, 0.55, 0.35, 0.65, 0.45, 0.75};
|
||||
|
@ -70,7 +70,7 @@ TEST(DiscreteConditional, constructors_alt_interface) {
|
|||
DecisionTreeFactor f2(
|
||||
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
||||
DiscreteConditional actual2(1, f2);
|
||||
DecisionTreeFactor expected2 = f2 / *f2.sum(1);
|
||||
DecisionTreeFactor expected2 = f2 / f2.sum(1)->toDecisionTreeFactor();
|
||||
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
|
||||
}
|
||||
|
||||
|
|
|
@ -94,7 +94,7 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) {
|
|||
EXPECT_DOUBLES_EQUAL( 1.944, graph(values), 1e-9);
|
||||
|
||||
// Check if graph product works
|
||||
DecisionTreeFactor product = graph.product();
|
||||
DecisionTreeFactor product = graph.product()->toDecisionTreeFactor();
|
||||
EXPECT_DOUBLES_EQUAL( 1.944, product(values), 1e-9);
|
||||
}
|
||||
|
||||
|
@ -117,9 +117,9 @@ TEST(DiscreteFactorGraph, test) {
|
|||
*std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr);
|
||||
|
||||
// Normalize newFactor by max for comparison with expected
|
||||
auto normalizer = newFactor.max(newFactor.size());
|
||||
auto denominator = newFactor.max(newFactor.size())->toDecisionTreeFactor();
|
||||
|
||||
newFactor = newFactor / *normalizer;
|
||||
newFactor = newFactor / denominator;
|
||||
|
||||
// Check Conditional
|
||||
CHECK(conditional);
|
||||
|
@ -131,9 +131,10 @@ TEST(DiscreteFactorGraph, test) {
|
|||
CHECK(&newFactor);
|
||||
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
|
||||
// Normalize by max.
|
||||
normalizer = expectedFactor.max(expectedFactor.size());
|
||||
// Ensure normalizer is correct.
|
||||
expectedFactor = expectedFactor / *normalizer;
|
||||
denominator =
|
||||
expectedFactor.max(expectedFactor.size())->toDecisionTreeFactor();
|
||||
// Ensure denominator is correct.
|
||||
expectedFactor = expectedFactor / denominator;
|
||||
EXPECT(assert_equal(expectedFactor, newFactor));
|
||||
|
||||
// Test using elimination tree
|
||||
|
|
|
@ -194,15 +194,17 @@ TEST(TableFactor, Conversion) {
|
|||
TEST(TableFactor, Empty) {
|
||||
DiscreteKey X(1, 2);
|
||||
|
||||
TableFactor single = *TableFactor({X}, "1 1").sum(1);
|
||||
auto single = TableFactor({X}, "1 1").sum(1);
|
||||
// Should not throw a segfault
|
||||
EXPECT(assert_equal(*DecisionTreeFactor(X, "1 1").sum(1),
|
||||
single.toDecisionTreeFactor()));
|
||||
auto expected_single = DecisionTreeFactor(X, "1 1").sum(1);
|
||||
EXPECT(assert_equal(expected_single->toDecisionTreeFactor(),
|
||||
single->toDecisionTreeFactor()));
|
||||
|
||||
TableFactor empty = *TableFactor({X}, "0 0").sum(1);
|
||||
auto empty = TableFactor({X}, "0 0").sum(1);
|
||||
// Should not throw a segfault
|
||||
EXPECT(assert_equal(*DecisionTreeFactor(X, "0 0").sum(1),
|
||||
empty.toDecisionTreeFactor()));
|
||||
auto expected_empty = DecisionTreeFactor(X, "0 0").sum(1);
|
||||
EXPECT(assert_equal(expected_empty->toDecisionTreeFactor(),
|
||||
empty->toDecisionTreeFactor()));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -303,15 +305,18 @@ TEST(TableFactor, sum_max) {
|
|||
TableFactor f1(v0 & v1, "1 2 3 4 5 6");
|
||||
|
||||
TableFactor expected(v1, "9 12");
|
||||
TableFactor::shared_ptr actual = f1.sum(1);
|
||||
auto actual = std::dynamic_pointer_cast<TableFactor>(f1.sum(1));
|
||||
CHECK(actual);
|
||||
CHECK(assert_equal(expected, *actual, 1e-5));
|
||||
|
||||
TableFactor expected2(v1, "5 6");
|
||||
TableFactor::shared_ptr actual2 = f1.max(1);
|
||||
auto actual2 = std::dynamic_pointer_cast<TableFactor>(f1.max(1));
|
||||
CHECK(actual2);
|
||||
CHECK(assert_equal(expected2, *actual2));
|
||||
|
||||
TableFactor f2(v1 & v0, "1 2 3 4 5 6");
|
||||
TableFactor::shared_ptr actual22 = f2.sum(1);
|
||||
auto actual22 = std::dynamic_pointer_cast<TableFactor>(f2.sum(1));
|
||||
CHECK(actual22);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -68,7 +68,8 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor {
|
|||
/*
|
||||
* Ensure Arc-consistency by checking every possible value of domain j.
|
||||
* @param j domain to be checked
|
||||
* @param (in/out) domains all domains, but only domains->at(j) will be checked.
|
||||
* @param (in/out) domains all domains, but only domains->at(j) will be
|
||||
* checked.
|
||||
* @return true if domains->at(j) was changed, false otherwise.
|
||||
*/
|
||||
virtual bool ensureArcConsistency(Key j, Domains* domains) const = 0;
|
||||
|
@ -78,6 +79,39 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor {
|
|||
|
||||
/// Partially apply known values, domain version
|
||||
virtual shared_ptr partiallyApply(const Domains&) const = 0;
|
||||
|
||||
/// Multiply factors, DiscreteFactor::shared_ptr edition
|
||||
DiscreteFactor::shared_ptr multiply(
|
||||
const DiscreteFactor::shared_ptr& df) const override {
|
||||
return std::make_shared<DecisionTreeFactor>(
|
||||
this->operator*(df->toDecisionTreeFactor()));
|
||||
}
|
||||
|
||||
/// divide by DiscreteFactor::shared_ptr f (safely)
|
||||
DiscreteFactor::shared_ptr operator/(
|
||||
const DiscreteFactor::shared_ptr& df) const override {
|
||||
return this->toDecisionTreeFactor() / df;
|
||||
}
|
||||
|
||||
/// Get the number of non-zero values contained in this factor.
|
||||
uint64_t nrValues() const override { return 1; };
|
||||
|
||||
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
|
||||
return toDecisionTreeFactor().sum(nrFrontals);
|
||||
}
|
||||
|
||||
DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
|
||||
return toDecisionTreeFactor().sum(keys);
|
||||
}
|
||||
|
||||
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
|
||||
return toDecisionTreeFactor().max(nrFrontals);
|
||||
}
|
||||
|
||||
DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
|
||||
return toDecisionTreeFactor().max(keys);
|
||||
}
|
||||
|
||||
/// @}
|
||||
/// @name Wrapper support
|
||||
/// @{
|
||||
|
|
|
@ -49,7 +49,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
|
|||
/// Erase a value, non const :-(
|
||||
void erase(size_t value) { values_.erase(value); }
|
||||
|
||||
size_t nrValues() const { return values_.size(); }
|
||||
uint64_t nrValues() const override { return values_.size(); }
|
||||
|
||||
bool isSingleton() const { return nrValues() == 1; }
|
||||
|
||||
|
|
|
@ -124,7 +124,7 @@ TEST(CSP, allInOne) {
|
|||
EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9);
|
||||
|
||||
// Just for fun, create the product and check it
|
||||
DecisionTreeFactor product = csp.product();
|
||||
DecisionTreeFactor product = csp.product()->toDecisionTreeFactor();
|
||||
// product.dot("product");
|
||||
DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0");
|
||||
EXPECT(assert_equal(expectedProduct, product));
|
||||
|
|
|
@ -113,7 +113,7 @@ TEST(schedulingExample, test) {
|
|||
EXPECT(assert_equal(expected, (DiscreteFactorGraph)s));
|
||||
|
||||
// Do brute force product and output that to file
|
||||
DecisionTreeFactor product = s.product();
|
||||
DecisionTreeFactor product = s.product()->toDecisionTreeFactor();
|
||||
// product.dot("scheduling", false);
|
||||
|
||||
// Do exact inference
|
||||
|
|
Loading…
Reference in New Issue