address review comments

release/4.3a0
Varun Agrawal 2025-01-06 14:08:08 -05:00
parent ab90e0b0f3
commit f043ac43a7
4 changed files with 47 additions and 2 deletions

View File

@ -68,11 +68,20 @@ namespace gtsam {
const DiscreteFactor::shared_ptr& f) const { const DiscreteFactor::shared_ptr& f) const {
DiscreteFactor::shared_ptr result; DiscreteFactor::shared_ptr result;
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) { if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
// If f is a TableFactor, we convert `this` to a TableFactor since this
// conversion is cheaper than converting `f` to a DecisionTreeFactor. We
// then return a TableFactor.
result = std::make_shared<TableFactor>((*tf) * TableFactor(*this)); result = std::make_shared<TableFactor>((*tf) * TableFactor(*this));
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) { } else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
// If `f` is a DecisionTreeFactor, simply call operator*.
result = std::make_shared<DecisionTreeFactor>(this->operator*(*dtf)); result = std::make_shared<DecisionTreeFactor>(this->operator*(*dtf));
} else { } else {
// Simulate double dispatch in C++ // Simulate double dispatch in C++
// Useful for other classes which inherit from DiscreteFactor and have
// only `operator*(DecisionTreeFactor)` defined. Thus, other classes don't
// need to be updated.
result = std::make_shared<DecisionTreeFactor>(f->operator*(*this)); result = std::make_shared<DecisionTreeFactor>(f->operator*(*this));
} }
return result; return result;

View File

@ -147,7 +147,20 @@ namespace gtsam {
/// Calculate error for DiscreteValues `x`, is -log(probability). /// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const override; double error(const DiscreteValues& values) const override;
/// Multiply factors, DiscreteFactor::shared_ptr edition /**
* @brief Multiply factors, DiscreteFactor::shared_ptr edition.
*
* This method accepts `DiscreteFactor::shared_ptr` and uses dynamic
* dispatch and specializations to perform the most efficient
* multiplication.
*
* While converting a DecisionTreeFactor to a TableFactor is efficient, the
* reverse is not. Hence we specialize the code to return a TableFactor if
* `f` is a TableFactor, and DecisionTreeFactor otherwise.
*
* @param f The factor to multiply with.
* @return DiscreteFactor::shared_ptr
*/
virtual DiscreteFactor::shared_ptr multiply( virtual DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& f) const override; const DiscreteFactor::shared_ptr& f) const override;

View File

@ -259,11 +259,21 @@ DiscreteFactor::shared_ptr TableFactor::multiply(
const DiscreteFactor::shared_ptr& f) const { const DiscreteFactor::shared_ptr& f) const {
DiscreteFactor::shared_ptr result; DiscreteFactor::shared_ptr result;
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) { if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
// If `f` is a TableFactor, we can simply call `operator*`.
result = std::make_shared<TableFactor>(this->operator*(*tf)); result = std::make_shared<TableFactor>(this->operator*(*tf));
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) { } else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
// If `f` is a DecisionTreeFactor, we convert to a TableFactor which is
// cheaper than converting `this` to a DecisionTreeFactor.
result = std::make_shared<TableFactor>(this->operator*(TableFactor(*dtf))); result = std::make_shared<TableFactor>(this->operator*(TableFactor(*dtf)));
} else { } else {
// Simulate double dispatch in C++ // Simulate double dispatch in C++
// Useful for other classes which inherit from DiscreteFactor and have
// only `operator*(DecisionTreeFactor)` defined. Thus, other classes don't
// need to be updated to know about TableFactor.
// Those classes can be specialized to use TableFactor
// if efficiency is a problem.
result = std::make_shared<DecisionTreeFactor>( result = std::make_shared<DecisionTreeFactor>(
f->operator*(this->toDecisionTreeFactor())); f->operator*(this->toDecisionTreeFactor()));
} }

View File

@ -178,7 +178,20 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
/// multiply with DecisionTreeFactor /// multiply with DecisionTreeFactor
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
/// Multiply factors, DiscreteFactor::shared_ptr edition /**
* @brief Multiply factors, DiscreteFactor::shared_ptr edition.
*
* This method accepts `DiscreteFactor::shared_ptr` and uses dynamic
* dispatch and specializations to perform the most efficient
* multiplication.
*
* While converting a DecisionTreeFactor to a TableFactor is efficient, the
* reverse is not.
* Hence we specialize the code to return a TableFactor always.
*
* @param f The factor to multiply with.
* @return DiscreteFactor::shared_ptr
*/
virtual DiscreteFactor::shared_ptr multiply( virtual DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& f) const override; const DiscreteFactor::shared_ptr& f) const override;