Merge branch 'hybrid-custom-discrete' into discrete-table-conditional
commit
cc237a2f43
|
|
@ -47,15 +47,6 @@ DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
|
|||
const DecisionTreeFactor& f)
|
||||
: BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {}
|
||||
|
||||
/* ************************************************************************** */
|
||||
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
|
||||
const DecisionTreeFactor& f,
|
||||
const Ordering& orderedKeys)
|
||||
: BaseFactor(f), BaseConditional(nrFrontals) {
|
||||
keys_.clear();
|
||||
keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end());
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
|
||||
const DiscreteKeys& keys,
|
||||
|
|
|
|||
|
|
@ -56,17 +56,6 @@ class GTSAM_EXPORT DiscreteConditional
|
|||
/// Construct from factor, taking the first `nFrontals` keys as frontals.
|
||||
DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f);
|
||||
|
||||
/**
|
||||
* @brief Construct from DecisionTreeFactor,
|
||||
* taking the first `nrFrontals` from `orderedKeys`.
|
||||
*
|
||||
* @param nrFrontals The number of frontal variables.
|
||||
* @param f The DecisionTreeFactor to construct from.
|
||||
* @param orderedKeys Ordered list of keys involved in the conditional.
|
||||
*/
|
||||
DiscreteConditional(size_t nrFrontals, const DecisionTreeFactor& f,
|
||||
const Ordering& orderedKeys);
|
||||
|
||||
/**
|
||||
* Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first
|
||||
* `nFrontals` keys as frontals, in the order given.
|
||||
|
|
|
|||
|
|
@ -252,15 +252,6 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
|
|||
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
|
||||
DiscreteKeys dkeys = discreteKeys();
|
||||
|
||||
// If no keys, then return empty DecisionTreeFactor
|
||||
if (dkeys.size() == 0) {
|
||||
AlgebraicDecisionTree<Key> tree;
|
||||
if (sparse_table_.size() != 0) {
|
||||
tree = AlgebraicDecisionTree<Key>(sparse_table_.coeff(0));
|
||||
}
|
||||
return DecisionTreeFactor(dkeys, tree);
|
||||
}
|
||||
|
||||
std::vector<double> table;
|
||||
for (auto i = 0; i < sparse_table_.size(); i++) {
|
||||
table.push_back(sparse_table_.coeff(i));
|
||||
|
|
|
|||
|
|
@ -256,7 +256,7 @@ static TableFactor::shared_ptr DiscreteFactorFromErrors(
|
|||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
TableFactor TableProductAndNormalize(const DiscreteFactorGraph &factors) {
|
||||
TableFactor TableProduct(const DiscreteFactorGraph &factors) {
|
||||
// PRODUCT: multiply all factors
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(DiscreteProduct);
|
||||
|
|
@ -279,14 +279,13 @@ TableFactor TableProductAndNormalize(const DiscreteFactorGraph &factors) {
|
|||
gttoc_(DiscreteProduct);
|
||||
#endif
|
||||
|
||||
// Max over all the potentials by pretending all keys are frontal:
|
||||
auto normalizer = product.max(product.size());
|
||||
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(DiscreteNormalize);
|
||||
#endif
|
||||
// Max over all the potentials by pretending all keys are frontal:
|
||||
auto denominator = product.max(product.size());
|
||||
// Normalize the product factor to prevent underflow.
|
||||
product = product / (*normalizer);
|
||||
product = product / (*denominator);
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(DiscreteNormalize);
|
||||
#endif
|
||||
|
|
@ -343,41 +342,40 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
|||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(EliminateDiscrete);
|
||||
#endif
|
||||
/**** NOTE: This does sum-product. ****/
|
||||
// Get product factor
|
||||
TableFactor product = TableProductAndNormalize(dfg);
|
||||
// Check if separator is empty
|
||||
Ordering allKeys(dfg.keyVector());
|
||||
Ordering separator;
|
||||
std::set_difference(allKeys.begin(), allKeys.end(), frontalKeys.begin(),
|
||||
frontalKeys.end(),
|
||||
std::inserter(separator, separator.begin()));
|
||||
|
||||
// If the separator is empty, we have a clique of all the discrete variables
|
||||
// so we can use the TableFactor for efficiency.
|
||||
if (separator.size() == 0) {
|
||||
// Get product factor
|
||||
TableFactor product = TableProduct(dfg);
|
||||
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(EliminateDiscreteSum);
|
||||
gttic_(EliminateDiscreteFormDiscreteConditional);
|
||||
#endif
|
||||
// All the discrete variables should form a single clique,
|
||||
// so we can sum out on all the variables as frontals.
|
||||
// This should give an empty separator.
|
||||
TableFactor::shared_ptr sum = product.sum(frontalKeys);
|
||||
auto conditional = std::make_shared<DiscreteConditional>(
|
||||
frontalKeys.size(), product.toDecisionTreeFactor());
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(EliminateDiscreteSum);
|
||||
gttoc_(EliminateDiscreteFormDiscreteConditional);
|
||||
#endif
|
||||
|
||||
// Ordering keys for the conditional so that frontalKeys are really in front
|
||||
Ordering orderedKeys;
|
||||
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end());
|
||||
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end());
|
||||
TableFactor::shared_ptr sum = product.sum(frontalKeys);
|
||||
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(EliminateDiscreteFormDiscreteConditional);
|
||||
#endif
|
||||
// Finally, get the conditional
|
||||
auto conditional =
|
||||
std::make_shared<DiscreteTableConditional>(product, *sum, orderedKeys);
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(EliminateDiscreteFormDiscreteConditional);
|
||||
#endif
|
||||
return {std::make_shared<HybridConditional>(conditional), sum};
|
||||
|
||||
} else {
|
||||
// Perform sum-product.
|
||||
auto result = EliminateDiscrete(dfg, frontalKeys);
|
||||
return {std::make_shared<HybridConditional>(result.first), result.second};
|
||||
}
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(EliminateDiscrete);
|
||||
#endif
|
||||
|
||||
return {std::make_shared<HybridConditional>(conditional), sum};
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
|
|
|
|||
Loading…
Reference in New Issue