Merge pull request #1963 from borglab/discrete-multiply

DiscreteFactor multiply method
release/4.3a0
Varun Agrawal 2025-01-06 15:21:52 -05:00 committed by GitHub
commit 47074bd0c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 115 additions and 5 deletions

View File

@ -18,9 +18,10 @@
*/
#include <gtsam/base/FastSet.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <utility>
@ -62,6 +63,30 @@ namespace gtsam {
return error(values.discrete());
}
/* ************************************************************************ */
DiscreteFactor::shared_ptr DecisionTreeFactor::multiply(
const DiscreteFactor::shared_ptr& f) const {
DiscreteFactor::shared_ptr result;
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
// If f is a TableFactor, we convert `this` to a TableFactor since this
// conversion is cheaper than converting `f` to a DecisionTreeFactor. We
// then return a TableFactor.
result = std::make_shared<TableFactor>((*tf) * TableFactor(*this));
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
// If `f` is a DecisionTreeFactor, simply call operator*.
result = std::make_shared<DecisionTreeFactor>(this->operator*(*dtf));
} else {
// Simulate double dispatch in C++
// Useful for other classes which inherit from DiscreteFactor and have
// only `operator*(DecisionTreeFactor)` defined. Thus, other classes don't
// need to be updated.
result = std::make_shared<DecisionTreeFactor>(f->operator*(*this));
}
return result;
}
/* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum

View File

@ -147,6 +147,23 @@ namespace gtsam {
/// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const override;
/**
* @brief Multiply factors, DiscreteFactor::shared_ptr edition.
*
* This method accepts `DiscreteFactor::shared_ptr` and uses dynamic
* dispatch and specializations to perform the most efficient
* multiplication.
*
* While converting a DecisionTreeFactor to a TableFactor is efficient, the
* reverse is not. Hence we specialize the code to return a TableFactor if
* `f` is a TableFactor, and DecisionTreeFactor otherwise.
*
* @param f The factor to multiply with.
* @return DiscreteFactor::shared_ptr
*/
virtual DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& f) const override;
/// multiply two factors
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
return apply(f, Ring::mul);

View File

@ -129,6 +129,16 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
/// DecisionTreeFactor
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
/**
* @brief Multiply in a DiscreteFactor and return the result as
* DiscreteFactor, both via shared pointers.
*
* @param df DiscreteFactor shared_ptr
* @return DiscreteFactor::shared_ptr
*/
virtual DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& df) const = 0;
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
/// @}

View File

@ -65,11 +65,18 @@ namespace gtsam {
/* ************************************************************************ */
DecisionTreeFactor DiscreteFactorGraph::product() const {
DecisionTreeFactor result;
for (const sharedFactor& factor : *this) {
if (factor) result = (*factor) * result;
DiscreteFactor::shared_ptr result;
for (auto it = this->begin(); it != this->end(); ++it) {
if (*it) {
if (result) {
result = result->multiply(*it);
} else {
// Assign to the first non-null factor
result = *it;
}
return result;
}
}
return result->toDecisionTreeFactor();
}
/* ************************************************************************ */

View File

@ -254,6 +254,32 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
return toDecisionTreeFactor() * f;
}
/* ************************************************************************ */
DiscreteFactor::shared_ptr TableFactor::multiply(
const DiscreteFactor::shared_ptr& f) const {
DiscreteFactor::shared_ptr result;
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
// If `f` is a TableFactor, we can simply call `operator*`.
result = std::make_shared<TableFactor>(this->operator*(*tf));
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
// If `f` is a DecisionTreeFactor, we convert to a TableFactor which is
// cheaper than converting `this` to a DecisionTreeFactor.
result = std::make_shared<TableFactor>(this->operator*(TableFactor(*dtf)));
} else {
// Simulate double dispatch in C++
// Useful for other classes which inherit from DiscreteFactor and have
// only `operator*(DecisionTreeFactor)` defined. Thus, other classes don't
// need to be updated to know about TableFactor.
// Those classes can be specialized to use TableFactor
// if efficiency is a problem.
result = std::make_shared<DecisionTreeFactor>(
f->operator*(this->toDecisionTreeFactor()));
}
return result;
}
/* ************************************************************************ */
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
DiscreteKeys dkeys = discreteKeys();

View File

@ -178,6 +178,23 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
/// multiply with DecisionTreeFactor
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
/**
* @brief Multiply factors, DiscreteFactor::shared_ptr edition.
*
* This method accepts `DiscreteFactor::shared_ptr` and uses dynamic
* dispatch and specializations to perform the most efficient
* multiplication.
*
* While converting a DecisionTreeFactor to a TableFactor is efficient, the
* reverse is not.
* Hence we specialize the code to return a TableFactor always.
*
* @param f The factor to multiply with.
* @return DiscreteFactor::shared_ptr
*/
virtual DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& f) const override;
static double safe_div(const double& a, const double& b);
/// divide by factor f (safely)

View File

@ -78,6 +78,14 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor {
/// Partially apply known values, domain version
virtual shared_ptr partiallyApply(const Domains&) const = 0;
/// Multiply factors, DiscreteFactor::shared_ptr edition
DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& df) const override {
return std::make_shared<DecisionTreeFactor>(
this->operator*(df->toDecisionTreeFactor()));
}
/// @}
/// @name Wrapper support
/// @{