switch between TableFactor and DecisionTreeFactor
parent
3f05d203d3
commit
2138113d05
|
@ -50,6 +50,8 @@
|
|||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#define GTSAM_HYBRID_WITH_TABLEFACTOR 0
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
|
||||
|
@ -253,7 +255,11 @@ static DiscreteFactor::shared_ptr DiscreteFactorFromErrors(
|
|||
double min_log = errors.min();
|
||||
AlgebraicDecisionTree<Key> potentials(
|
||||
errors, [&min_log](const double x) { return exp(-(x - min_log)); });
|
||||
#if GTSAM_HYBRID_WITH_TABLEFACTOR
|
||||
return std::make_shared<TableFactor>(discreteKeys, potentials);
|
||||
#else
|
||||
return std::make_shared<DecisionTreeFactor>(discreteKeys, potentials);
|
||||
#endif
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
|
@ -290,9 +296,13 @@ static DiscreteFactorGraph CollectDiscreteFactors(
|
|||
/// Get the underlying TableFactor
|
||||
dfg.push_back(dtc->table());
|
||||
} else {
|
||||
#if GTSAM_HYBRID_WITH_TABLEFACTOR
|
||||
// Convert DiscreteConditional to TableFactor
|
||||
auto tdc = std::make_shared<TableFactor>(*dc);
|
||||
dfg.push_back(tdc);
|
||||
#else
|
||||
dfg.push_back(dc);
|
||||
#endif
|
||||
}
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(ConvertConditionalToTableFactor);
|
||||
|
@ -309,11 +319,18 @@ static DiscreteFactorGraph CollectDiscreteFactors(
|
|||
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
|
||||
discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||
const Ordering &frontalKeys) {
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(CollectDiscreteFactors);
|
||||
#endif
|
||||
DiscreteFactorGraph dfg = CollectDiscreteFactors(factors);
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(CollectDiscreteFactors);
|
||||
#endif
|
||||
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(EliminateDiscrete);
|
||||
#endif
|
||||
#if GTSAM_HYBRID_WITH_TABLEFACTOR
|
||||
// Check if separator is empty.
|
||||
// This is the same as checking if the number of frontal variables
|
||||
// is the same as the number of variables in the DiscreteFactorGraph.
|
||||
|
@ -323,9 +340,6 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
|||
// Get product factor
|
||||
DiscreteFactor::shared_ptr product = dfg.scaledProduct();
|
||||
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(EliminateDiscreteFormDiscreteConditional);
|
||||
#endif
|
||||
// Check type of product, and get as TableFactor for efficiency.
|
||||
// Use object instead of pointer since we need it
|
||||
// for the TableDistribution constructor.
|
||||
|
@ -337,19 +351,18 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
|||
}
|
||||
auto conditional = std::make_shared<TableDistribution>(p);
|
||||
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(EliminateDiscreteFormDiscreteConditional);
|
||||
#endif
|
||||
|
||||
DiscreteFactor::shared_ptr sum = p.sum(frontalKeys);
|
||||
|
||||
return {std::make_shared<HybridConditional>(conditional), sum};
|
||||
|
||||
} else {
|
||||
#endif
|
||||
// Perform sum-product.
|
||||
auto result = EliminateDiscrete(dfg, frontalKeys);
|
||||
return {std::make_shared<HybridConditional>(result.first), result.second};
|
||||
#if GTSAM_HYBRID_WITH_TABLEFACTOR
|
||||
}
|
||||
#endif
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(EliminateDiscrete);
|
||||
#endif
|
||||
|
@ -411,8 +424,14 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
|
|||
throw std::runtime_error("createHybridGaussianFactors has mixed NULLs");
|
||||
}
|
||||
};
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(HybridCreateGaussianFactor);
|
||||
#endif
|
||||
DecisionTree<Key, GaussianFactorValuePair> newFactors(eliminationResults,
|
||||
correct);
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(HybridCreateGaussianFactor);
|
||||
#endif
|
||||
|
||||
return std::make_shared<HybridGaussianFactor>(discreteSeparator, newFactors);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue