switch between TableFactor and DecisionTreeFactor
parent
3f05d203d3
commit
2138113d05
|
@ -50,6 +50,8 @@
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#define GTSAM_HYBRID_WITH_TABLEFACTOR 0
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
|
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
|
||||||
|
@ -253,7 +255,11 @@ static DiscreteFactor::shared_ptr DiscreteFactorFromErrors(
|
||||||
double min_log = errors.min();
|
double min_log = errors.min();
|
||||||
AlgebraicDecisionTree<Key> potentials(
|
AlgebraicDecisionTree<Key> potentials(
|
||||||
errors, [&min_log](const double x) { return exp(-(x - min_log)); });
|
errors, [&min_log](const double x) { return exp(-(x - min_log)); });
|
||||||
|
#if GTSAM_HYBRID_WITH_TABLEFACTOR
|
||||||
return std::make_shared<TableFactor>(discreteKeys, potentials);
|
return std::make_shared<TableFactor>(discreteKeys, potentials);
|
||||||
|
#else
|
||||||
|
return std::make_shared<DecisionTreeFactor>(discreteKeys, potentials);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
|
@ -290,9 +296,13 @@ static DiscreteFactorGraph CollectDiscreteFactors(
|
||||||
/// Get the underlying TableFactor
|
/// Get the underlying TableFactor
|
||||||
dfg.push_back(dtc->table());
|
dfg.push_back(dtc->table());
|
||||||
} else {
|
} else {
|
||||||
|
#if GTSAM_HYBRID_WITH_TABLEFACTOR
|
||||||
// Convert DiscreteConditional to TableFactor
|
// Convert DiscreteConditional to TableFactor
|
||||||
auto tdc = std::make_shared<TableFactor>(*dc);
|
auto tdc = std::make_shared<TableFactor>(*dc);
|
||||||
dfg.push_back(tdc);
|
dfg.push_back(tdc);
|
||||||
|
#else
|
||||||
|
dfg.push_back(dc);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
#if GTSAM_HYBRID_TIMING
|
#if GTSAM_HYBRID_TIMING
|
||||||
gttoc_(ConvertConditionalToTableFactor);
|
gttoc_(ConvertConditionalToTableFactor);
|
||||||
|
@ -309,11 +319,18 @@ static DiscreteFactorGraph CollectDiscreteFactors(
|
||||||
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
|
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
|
||||||
discreteElimination(const HybridGaussianFactorGraph &factors,
|
discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
const Ordering &frontalKeys) {
|
const Ordering &frontalKeys) {
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttic_(CollectDiscreteFactors);
|
||||||
|
#endif
|
||||||
DiscreteFactorGraph dfg = CollectDiscreteFactors(factors);
|
DiscreteFactorGraph dfg = CollectDiscreteFactors(factors);
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttoc_(CollectDiscreteFactors);
|
||||||
|
#endif
|
||||||
|
|
||||||
#if GTSAM_HYBRID_TIMING
|
#if GTSAM_HYBRID_TIMING
|
||||||
gttic_(EliminateDiscrete);
|
gttic_(EliminateDiscrete);
|
||||||
#endif
|
#endif
|
||||||
|
#if GTSAM_HYBRID_WITH_TABLEFACTOR
|
||||||
// Check if separator is empty.
|
// Check if separator is empty.
|
||||||
// This is the same as checking if the number of frontal variables
|
// This is the same as checking if the number of frontal variables
|
||||||
// is the same as the number of variables in the DiscreteFactorGraph.
|
// is the same as the number of variables in the DiscreteFactorGraph.
|
||||||
|
@ -323,9 +340,6 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
// Get product factor
|
// Get product factor
|
||||||
DiscreteFactor::shared_ptr product = dfg.scaledProduct();
|
DiscreteFactor::shared_ptr product = dfg.scaledProduct();
|
||||||
|
|
||||||
#if GTSAM_HYBRID_TIMING
|
|
||||||
gttic_(EliminateDiscreteFormDiscreteConditional);
|
|
||||||
#endif
|
|
||||||
// Check type of product, and get as TableFactor for efficiency.
|
// Check type of product, and get as TableFactor for efficiency.
|
||||||
// Use object instead of pointer since we need it
|
// Use object instead of pointer since we need it
|
||||||
// for the TableDistribution constructor.
|
// for the TableDistribution constructor.
|
||||||
|
@ -337,19 +351,18 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
}
|
}
|
||||||
auto conditional = std::make_shared<TableDistribution>(p);
|
auto conditional = std::make_shared<TableDistribution>(p);
|
||||||
|
|
||||||
#if GTSAM_HYBRID_TIMING
|
|
||||||
gttoc_(EliminateDiscreteFormDiscreteConditional);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
DiscreteFactor::shared_ptr sum = p.sum(frontalKeys);
|
DiscreteFactor::shared_ptr sum = p.sum(frontalKeys);
|
||||||
|
|
||||||
return {std::make_shared<HybridConditional>(conditional), sum};
|
return {std::make_shared<HybridConditional>(conditional), sum};
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
#endif
|
||||||
// Perform sum-product.
|
// Perform sum-product.
|
||||||
auto result = EliminateDiscrete(dfg, frontalKeys);
|
auto result = EliminateDiscrete(dfg, frontalKeys);
|
||||||
return {std::make_shared<HybridConditional>(result.first), result.second};
|
return {std::make_shared<HybridConditional>(result.first), result.second};
|
||||||
|
#if GTSAM_HYBRID_WITH_TABLEFACTOR
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
#if GTSAM_HYBRID_TIMING
|
#if GTSAM_HYBRID_TIMING
|
||||||
gttoc_(EliminateDiscrete);
|
gttoc_(EliminateDiscrete);
|
||||||
#endif
|
#endif
|
||||||
|
@ -411,8 +424,14 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
|
||||||
throw std::runtime_error("createHybridGaussianFactors has mixed NULLs");
|
throw std::runtime_error("createHybridGaussianFactors has mixed NULLs");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttic_(HybridCreateGaussianFactor);
|
||||||
|
#endif
|
||||||
DecisionTree<Key, GaussianFactorValuePair> newFactors(eliminationResults,
|
DecisionTree<Key, GaussianFactorValuePair> newFactors(eliminationResults,
|
||||||
correct);
|
correct);
|
||||||
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
gttoc_(HybridCreateGaussianFactor);
|
||||||
|
#endif
|
||||||
|
|
||||||
return std::make_shared<HybridGaussianFactor>(discreteSeparator, newFactors);
|
return std::make_shared<HybridGaussianFactor>(discreteSeparator, newFactors);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue