use DiscreteTableConditional in EliminateDiscrete

release/4.3a0
Varun Agrawal 2024-12-31 00:20:09 -05:00
parent 60945c8e32
commit e46e9d67c5
1 changed files with 6 additions and 11 deletions

View File

@ -19,6 +19,7 @@
#include <gtsam/discrete/DiscreteBayesTree.h> #include <gtsam/discrete/DiscreteBayesTree.h>
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteTableConditional.h>
#include <gtsam/discrete/DiscreteEliminationTree.h> #include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteJunctionTree.h> #include <gtsam/discrete/DiscreteJunctionTree.h>
@ -70,8 +71,7 @@ namespace gtsam {
if (factor) { if (factor) {
if (auto f = std::dynamic_pointer_cast<TableFactor>(factor)) { if (auto f = std::dynamic_pointer_cast<TableFactor>(factor)) {
result = result * (*f); result = result * (*f);
} } else if (auto dtf =
else if (auto dtf =
std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) { std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
result = TableFactor(result * (*dtf)); result = TableFactor(result * (*dtf));
} }
@ -253,18 +253,13 @@ namespace gtsam {
sum->keys().end()); sum->keys().end());
// now divide product/sum to get conditional // now divide product/sum to get conditional
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteDivide);
#endif
auto c = product / (*sum);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteDivide);
#endif
#if GTSAM_HYBRID_TIMING #if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteToDiscreteConditional); gttic_(EliminateDiscreteToDiscreteConditional);
#endif #endif
auto conditional = std::make_shared<DiscreteConditional>( // auto conditional = std::make_shared<DiscreteConditional>(
orderedKeys.size(), c.toDecisionTreeFactor()); // orderedKeys.size(), (product / (*sum)).toDecisionTreeFactor());
auto conditional =
std::make_shared<DiscreteTableConditional>(product, *sum, orderedKeys);
#if GTSAM_HYBRID_TIMING #if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteToDiscreteConditional); gttoc_(EliminateDiscreteToDiscreteConditional);
#endif #endif