multiply method for DiscreteFactor
parent
e9e52ad21f
commit
5d865a8cc7
|
@ -62,6 +62,18 @@ namespace gtsam {
|
||||||
return error(values.discrete());
|
return error(values.discrete());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
DiscreteFactor::shared_ptr DecisionTreeFactor::multiply(
|
||||||
|
const DiscreteFactor::shared_ptr& f) const override {
|
||||||
|
DiscreteFactor::shared_ptr result;
|
||||||
|
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
|
||||||
|
result = std::make_shared<TableFactor>((*tf) * TableFactor(*this));
|
||||||
|
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||||
|
result = std::make_shared<DecisionTreeFactor>(this->operator*(*dtf));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
|
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
|
||||||
// The use for safe_div is when we divide the product factor by the sum
|
// The use for safe_div is when we divide the product factor by the sum
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include <gtsam/discrete/DiscreteFactor.h>
|
#include <gtsam/discrete/DiscreteFactor.h>
|
||||||
#include <gtsam/discrete/DiscreteKey.h>
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
#include <gtsam/discrete/Ring.h>
|
#include <gtsam/discrete/Ring.h>
|
||||||
|
#include <gtsam/discrete/TableFactor.h>
|
||||||
#include <gtsam/inference/Ordering.h>
|
#include <gtsam/inference/Ordering.h>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
@ -147,6 +148,10 @@ namespace gtsam {
|
||||||
/// Calculate error for DiscreteValues `x`, is -log(probability).
|
/// Calculate error for DiscreteValues `x`, is -log(probability).
|
||||||
double error(const DiscreteValues& values) const override;
|
double error(const DiscreteValues& values) const override;
|
||||||
|
|
||||||
|
/// Multiply factors, DiscreteFactor::shared_ptr edition
|
||||||
|
virtual DiscreteFactor::shared_ptr multiply(
|
||||||
|
const DiscreteFactor::shared_ptr& f) const override;
|
||||||
|
|
||||||
/// multiply two factors
|
/// multiply two factors
|
||||||
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
|
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
|
||||||
return apply(f, Ring::mul);
|
return apply(f, Ring::mul);
|
||||||
|
|
|
@ -129,6 +129,16 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
|
||||||
/// DecisionTreeFactor
|
/// DecisionTreeFactor
|
||||||
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
|
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Multiply in a DiscreteFactor and return the result as
|
||||||
|
* DiscreteFactor, both via shared pointers.
|
||||||
|
*
|
||||||
|
* @param df DiscreteFactor shared_ptr
|
||||||
|
* @return DiscreteFactor::shared_ptr
|
||||||
|
*/
|
||||||
|
virtual DiscreteFactor::shared_ptr multiply(
|
||||||
|
const DiscreteFactor::shared_ptr& df) const = 0;
|
||||||
|
|
||||||
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
|
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
|
@ -254,6 +254,18 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
|
||||||
return toDecisionTreeFactor() * f;
|
return toDecisionTreeFactor() * f;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
DiscreteFactor::shared_ptr TableFactor::multiply(
|
||||||
|
const DiscreteFactor::shared_ptr& f) const override {
|
||||||
|
DiscreteFactor::shared_ptr result;
|
||||||
|
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
|
||||||
|
result = std::make_shared<TableFactor>(this->operator*(*tf));
|
||||||
|
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||||
|
result = std::make_shared<TableFactor>(this->operator*(TableFactor(*dtf)));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
|
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
|
||||||
DiscreteKeys dkeys = discreteKeys();
|
DiscreteKeys dkeys = discreteKeys();
|
||||||
|
|
|
@ -178,6 +178,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
/// multiply with DecisionTreeFactor
|
/// multiply with DecisionTreeFactor
|
||||||
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
|
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
|
||||||
|
|
||||||
|
/// Multiply factors, DiscreteFactor::shared_ptr edition
|
||||||
|
virtual DiscreteFactor::shared_ptr multiply(
|
||||||
|
const DiscreteFactor::shared_ptr& f) const override;
|
||||||
|
|
||||||
static double safe_div(const double& a, const double& b);
|
static double safe_div(const double& a, const double& b);
|
||||||
|
|
||||||
/// divide by factor f (safely)
|
/// divide by factor f (safely)
|
||||||
|
|
Loading…
Reference in New Issue