use TableFactor for discrete elimination

release/4.3a0
Varun Agrawal 2024-12-31 00:27:04 -05:00
parent 214043d60d
commit dfec8409fe
1 changed files with 14 additions and 7 deletions

View File

@ -25,7 +25,9 @@
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteJunctionTree.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteTableConditional.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridEliminationTree.h>
#include <gtsam/hybrid/HybridFactor.h>
@ -241,18 +243,18 @@ continuousElimination(const HybridGaussianFactorGraph &factors,
/* ************************************************************************ */
/**
* @brief Take negative log-values, shift them so that the minimum value is 0,
* and then exponentiate to create a DecisionTreeFactor (not normalized yet!).
* and then exponentiate to create a TableFactor (not normalized yet!).
*
* @param errors DecisionTree of (unnormalized) errors.
* @return DecisionTreeFactor::shared_ptr
* @return TableFactor::shared_ptr
*/
static DecisionTreeFactor::shared_ptr DiscreteFactorFromErrors(
static TableFactor::shared_ptr DiscreteFactorFromErrors(
const DiscreteKeys &discreteKeys,
const AlgebraicDecisionTree<Key> &errors) {
double min_log = errors.min();
AlgebraicDecisionTree<Key> potentials(
errors, [&min_log](const double x) { return exp(-(x - min_log)); });
return std::make_shared<DecisionTreeFactor>(discreteKeys, potentials);
return std::make_shared<TableFactor>(discreteKeys, potentials);
}
/* ************************************************************************ */
@ -285,12 +287,17 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
#if GTSAM_HYBRID_TIMING
gttic_(ConvertConditionalToTableFactor);
#endif
// Convert DiscreteConditional to TableFactor
auto tdc = std::make_shared<TableFactor>(*dc);
if (auto dtc = std::dynamic_pointer_cast<DiscreteTableConditional>(dc)) {
/// Get the underlying TableFactor
dfg.push_back(dtc->table());
} else {
// Convert DiscreteConditional to TableFactor
auto tdc = std::make_shared<TableFactor>(*dc);
dfg.push_back(tdc);
}
#if GTSAM_HYBRID_TIMING
gttoc_(ConvertConditionalToTableFactor);
#endif
dfg.push_back(tdc);
} else {
throwRuntimeError("discreteElimination", f);
}