multiply method for DiscreteFactor

release/4.3a0
Varun Agrawal 2025-01-05 13:22:36 -05:00
parent e9e52ad21f
commit 5d865a8cc7
5 changed files with 43 additions and 0 deletions

View File

@ -62,6 +62,18 @@ namespace gtsam {
return error(values.discrete()); return error(values.discrete());
} }
/* ************************************************************************ */
DiscreteFactor::shared_ptr DecisionTreeFactor::multiply(
const DiscreteFactor::shared_ptr& f) const override {
DiscreteFactor::shared_ptr result;
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
result = std::make_shared<TableFactor>((*tf) * TableFactor(*this));
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
result = std::make_shared<DecisionTreeFactor>(this->operator*(*dtf));
}
return result;
}
/* ************************************************************************ */ /* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) { double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum // The use for safe_div is when we divide the product factor by the sum

View File

@ -22,6 +22,7 @@
#include <gtsam/discrete/DiscreteFactor.h> #include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/Ring.h> #include <gtsam/discrete/Ring.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/inference/Ordering.h> #include <gtsam/inference/Ordering.h>
#include <algorithm> #include <algorithm>
@ -147,6 +148,10 @@ namespace gtsam {
/// Calculate error for DiscreteValues `x`, is -log(probability). /// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const override; double error(const DiscreteValues& values) const override;
/// Multiply factors, DiscreteFactor::shared_ptr edition
virtual DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& f) const override;
/// multiply two factors /// multiply two factors
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
return apply(f, Ring::mul); return apply(f, Ring::mul);

View File

@ -129,6 +129,16 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
/// DecisionTreeFactor /// DecisionTreeFactor
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
/**
* @brief Multiply in a DiscreteFactor and return the result as
* DiscreteFactor, both via shared pointers.
*
* @param df DiscreteFactor shared_ptr
* @return DiscreteFactor::shared_ptr
*/
virtual DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& df) const = 0;
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
/// @} /// @}

View File

@ -254,6 +254,18 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
return toDecisionTreeFactor() * f; return toDecisionTreeFactor() * f;
} }
/* ************************************************************************ */
DiscreteFactor::shared_ptr TableFactor::multiply(
const DiscreteFactor::shared_ptr& f) const override {
DiscreteFactor::shared_ptr result;
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
result = std::make_shared<TableFactor>(this->operator*(*tf));
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
result = std::make_shared<TableFactor>(this->operator*(TableFactor(*dtf)));
}
return result;
}
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
DiscreteKeys dkeys = discreteKeys(); DiscreteKeys dkeys = discreteKeys();

View File

@ -178,6 +178,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
/// multiply with DecisionTreeFactor /// multiply with DecisionTreeFactor
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
/// Multiply factors, DiscreteFactor::shared_ptr edition
virtual DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& f) const override;
static double safe_div(const double& a, const double& b); static double safe_div(const double& a, const double& b);
/// divide by factor f (safely) /// divide by factor f (safely)