switch between TableFactor and DecisionTreeFactor

release/4.3a0
Varun Agrawal 2025-01-28 16:22:23 -05:00
parent 3f05d203d3
commit 2138113d05
1 changed files with 26 additions and 7 deletions

View File

@ -50,6 +50,8 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#define GTSAM_HYBRID_WITH_TABLEFACTOR 0
namespace gtsam { namespace gtsam {
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph: /// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
@ -253,7 +255,11 @@ static DiscreteFactor::shared_ptr DiscreteFactorFromErrors(
double min_log = errors.min(); double min_log = errors.min();
AlgebraicDecisionTree<Key> potentials( AlgebraicDecisionTree<Key> potentials(
errors, [&min_log](const double x) { return exp(-(x - min_log)); }); errors, [&min_log](const double x) { return exp(-(x - min_log)); });
#if GTSAM_HYBRID_WITH_TABLEFACTOR
return std::make_shared<TableFactor>(discreteKeys, potentials); return std::make_shared<TableFactor>(discreteKeys, potentials);
#else
return std::make_shared<DecisionTreeFactor>(discreteKeys, potentials);
#endif
} }
/* ************************************************************************ */ /* ************************************************************************ */
@ -290,9 +296,13 @@ static DiscreteFactorGraph CollectDiscreteFactors(
/// Get the underlying TableFactor /// Get the underlying TableFactor
dfg.push_back(dtc->table()); dfg.push_back(dtc->table());
} else { } else {
#if GTSAM_HYBRID_WITH_TABLEFACTOR
// Convert DiscreteConditional to TableFactor // Convert DiscreteConditional to TableFactor
auto tdc = std::make_shared<TableFactor>(*dc); auto tdc = std::make_shared<TableFactor>(*dc);
dfg.push_back(tdc); dfg.push_back(tdc);
#else
dfg.push_back(dc);
#endif
} }
#if GTSAM_HYBRID_TIMING #if GTSAM_HYBRID_TIMING
gttoc_(ConvertConditionalToTableFactor); gttoc_(ConvertConditionalToTableFactor);
@ -309,11 +319,18 @@ static DiscreteFactorGraph CollectDiscreteFactors(
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>> static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
discreteElimination(const HybridGaussianFactorGraph &factors, discreteElimination(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys) { const Ordering &frontalKeys) {
#if GTSAM_HYBRID_TIMING
gttic_(CollectDiscreteFactors);
#endif
DiscreteFactorGraph dfg = CollectDiscreteFactors(factors); DiscreteFactorGraph dfg = CollectDiscreteFactors(factors);
#if GTSAM_HYBRID_TIMING
gttoc_(CollectDiscreteFactors);
#endif
#if GTSAM_HYBRID_TIMING #if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscrete); gttic_(EliminateDiscrete);
#endif #endif
#if GTSAM_HYBRID_WITH_TABLEFACTOR
// Check if separator is empty. // Check if separator is empty.
// This is the same as checking if the number of frontal variables // This is the same as checking if the number of frontal variables
// is the same as the number of variables in the DiscreteFactorGraph. // is the same as the number of variables in the DiscreteFactorGraph.
@ -323,9 +340,6 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
// Get product factor // Get product factor
DiscreteFactor::shared_ptr product = dfg.scaledProduct(); DiscreteFactor::shared_ptr product = dfg.scaledProduct();
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteFormDiscreteConditional);
#endif
// Check type of product, and get as TableFactor for efficiency. // Check type of product, and get as TableFactor for efficiency.
// Use object instead of pointer since we need it // Use object instead of pointer since we need it
// for the TableDistribution constructor. // for the TableDistribution constructor.
@ -337,19 +351,18 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
} }
auto conditional = std::make_shared<TableDistribution>(p); auto conditional = std::make_shared<TableDistribution>(p);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteFormDiscreteConditional);
#endif
DiscreteFactor::shared_ptr sum = p.sum(frontalKeys); DiscreteFactor::shared_ptr sum = p.sum(frontalKeys);
return {std::make_shared<HybridConditional>(conditional), sum}; return {std::make_shared<HybridConditional>(conditional), sum};
} else { } else {
#endif
// Perform sum-product. // Perform sum-product.
auto result = EliminateDiscrete(dfg, frontalKeys); auto result = EliminateDiscrete(dfg, frontalKeys);
return {std::make_shared<HybridConditional>(result.first), result.second}; return {std::make_shared<HybridConditional>(result.first), result.second};
#if GTSAM_HYBRID_WITH_TABLEFACTOR
} }
#endif
#if GTSAM_HYBRID_TIMING #if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscrete); gttoc_(EliminateDiscrete);
#endif #endif
@ -411,8 +424,14 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
throw std::runtime_error("createHybridGaussianFactors has mixed NULLs"); throw std::runtime_error("createHybridGaussianFactors has mixed NULLs");
} }
}; };
#if GTSAM_HYBRID_TIMING
gttic_(HybridCreateGaussianFactor);
#endif
DecisionTree<Key, GaussianFactorValuePair> newFactors(eliminationResults, DecisionTree<Key, GaussianFactorValuePair> newFactors(eliminationResults,
correct); correct);
#if GTSAM_HYBRID_TIMING
gttoc_(HybridCreateGaussianFactor);
#endif
return std::make_shared<HybridGaussianFactor>(discreteSeparator, newFactors); return std::make_shared<HybridGaussianFactor>(discreteSeparator, newFactors);
} }