switch between TableFactor and DecisionTreeFactor

release/4.3a0
Varun Agrawal 2025-01-28 16:22:23 -05:00
parent 3f05d203d3
commit 2138113d05
1 changed files with 26 additions and 7 deletions

View File

@ -50,6 +50,8 @@
#include <utility>
#include <vector>
#define GTSAM_HYBRID_WITH_TABLEFACTOR 0
namespace gtsam {
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
@ -253,7 +255,11 @@ static DiscreteFactor::shared_ptr DiscreteFactorFromErrors(
double min_log = errors.min();
AlgebraicDecisionTree<Key> potentials(
errors, [&min_log](const double x) { return exp(-(x - min_log)); });
#if GTSAM_HYBRID_WITH_TABLEFACTOR
return std::make_shared<TableFactor>(discreteKeys, potentials);
#else
return std::make_shared<DecisionTreeFactor>(discreteKeys, potentials);
#endif
}
/* ************************************************************************ */
@ -290,9 +296,13 @@ static DiscreteFactorGraph CollectDiscreteFactors(
/// Get the underlying TableFactor
dfg.push_back(dtc->table());
} else {
#if GTSAM_HYBRID_WITH_TABLEFACTOR
// Convert DiscreteConditional to TableFactor
auto tdc = std::make_shared<TableFactor>(*dc);
dfg.push_back(tdc);
#else
dfg.push_back(dc);
#endif
}
#if GTSAM_HYBRID_TIMING
gttoc_(ConvertConditionalToTableFactor);
@ -309,11 +319,18 @@ static DiscreteFactorGraph CollectDiscreteFactors(
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
discreteElimination(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys) {
#if GTSAM_HYBRID_TIMING
gttic_(CollectDiscreteFactors);
#endif
DiscreteFactorGraph dfg = CollectDiscreteFactors(factors);
#if GTSAM_HYBRID_TIMING
gttoc_(CollectDiscreteFactors);
#endif
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscrete);
#endif
#if GTSAM_HYBRID_WITH_TABLEFACTOR
// Check if separator is empty.
// This is the same as checking if the number of frontal variables
// is the same as the number of variables in the DiscreteFactorGraph.
@ -323,9 +340,6 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
// Get product factor
DiscreteFactor::shared_ptr product = dfg.scaledProduct();
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteFormDiscreteConditional);
#endif
// Check type of product, and get as TableFactor for efficiency.
// Use object instead of pointer since we need it
// for the TableDistribution constructor.
@ -337,19 +351,18 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
}
auto conditional = std::make_shared<TableDistribution>(p);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteFormDiscreteConditional);
#endif
DiscreteFactor::shared_ptr sum = p.sum(frontalKeys);
return {std::make_shared<HybridConditional>(conditional), sum};
} else {
#endif
// Perform sum-product.
auto result = EliminateDiscrete(dfg, frontalKeys);
return {std::make_shared<HybridConditional>(result.first), result.second};
#if GTSAM_HYBRID_WITH_TABLEFACTOR
}
#endif
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscrete);
#endif
@ -411,8 +424,14 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
throw std::runtime_error("createHybridGaussianFactors has mixed NULLs");
}
};
#if GTSAM_HYBRID_TIMING
gttic_(HybridCreateGaussianFactor);
#endif
DecisionTree<Key, GaussianFactorValuePair> newFactors(eliminationResults,
correct);
#if GTSAM_HYBRID_TIMING
gttoc_(HybridCreateGaussianFactor);
#endif
return std::make_shared<HybridGaussianFactor>(discreteSeparator, newFactors);
}