Merge branch 'hybrid-custom-discrete' into discrete-table-conditional

release/4.3a0
Varun Agrawal 2025-01-01 19:26:41 -05:00
commit cc237a2f43
4 changed files with 27 additions and 58 deletions

View File

@ -47,15 +47,6 @@ DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
const DecisionTreeFactor& f)
: BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {}
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
const DecisionTreeFactor& f,
const Ordering& orderedKeys)
: BaseFactor(f), BaseConditional(nrFrontals) {
keys_.clear();
keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end());
}
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
const DiscreteKeys& keys,

View File

@ -56,17 +56,6 @@ class GTSAM_EXPORT DiscreteConditional
/// Construct from factor, taking the first `nFrontals` keys as frontals.
DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f);
/**
* @brief Construct from DecisionTreeFactor,
* taking the first `nrFrontals` from `orderedKeys`.
*
* @param nrFrontals The number of frontal variables.
* @param f The DecisionTreeFactor to construct from.
* @param orderedKeys Ordered list of keys involved in the conditional.
*/
DiscreteConditional(size_t nrFrontals, const DecisionTreeFactor& f,
const Ordering& orderedKeys);
/**
* Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first
* `nFrontals` keys as frontals, in the order given.

View File

@ -252,15 +252,6 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
DiscreteKeys dkeys = discreteKeys();
// If no keys, then return empty DecisionTreeFactor
if (dkeys.size() == 0) {
AlgebraicDecisionTree<Key> tree;
if (sparse_table_.size() != 0) {
tree = AlgebraicDecisionTree<Key>(sparse_table_.coeff(0));
}
return DecisionTreeFactor(dkeys, tree);
}
std::vector<double> table;
for (auto i = 0; i < sparse_table_.size(); i++) {
table.push_back(sparse_table_.coeff(i));

View File

@ -256,7 +256,7 @@ static TableFactor::shared_ptr DiscreteFactorFromErrors(
}
/* ************************************************************************ */
TableFactor TableProductAndNormalize(const DiscreteFactorGraph &factors) {
TableFactor TableProduct(const DiscreteFactorGraph &factors) {
// PRODUCT: multiply all factors
#if GTSAM_HYBRID_TIMING
gttic_(DiscreteProduct);
@ -279,14 +279,13 @@ TableFactor TableProductAndNormalize(const DiscreteFactorGraph &factors) {
gttoc_(DiscreteProduct);
#endif
// Max over all the potentials by pretending all keys are frontal:
auto normalizer = product.max(product.size());
#if GTSAM_HYBRID_TIMING
gttic_(DiscreteNormalize);
#endif
// Max over all the potentials by pretending all keys are frontal:
auto denominator = product.max(product.size());
// Normalize the product factor to prevent underflow.
product = product / (*normalizer);
product = product / (*denominator);
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteNormalize);
#endif
@ -343,41 +342,40 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscrete);
#endif
/**** NOTE: This does sum-product. ****/
// Check if separator is empty
Ordering allKeys(dfg.keyVector());
Ordering separator;
std::set_difference(allKeys.begin(), allKeys.end(), frontalKeys.begin(),
frontalKeys.end(),
std::inserter(separator, separator.begin()));
// If the separator is empty, we have a clique of all the discrete variables
// so we can use the TableFactor for efficiency.
if (separator.size() == 0) {
// Get product factor
TableFactor product = TableProductAndNormalize(dfg);
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteSum);
#endif
// All the discrete variables should form a single clique,
// so we can sum out on all the variables as frontals.
// This should give an empty separator.
TableFactor::shared_ptr sum = product.sum(frontalKeys);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteSum);
#endif
// Ordering keys for the conditional so that frontalKeys are really in front
Ordering orderedKeys;
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end());
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end());
TableFactor product = TableProduct(dfg);
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteFormDiscreteConditional);
#endif
// Finally, get the conditional
auto conditional =
std::make_shared<DiscreteTableConditional>(product, *sum, orderedKeys);
auto conditional = std::make_shared<DiscreteConditional>(
frontalKeys.size(), product.toDecisionTreeFactor());
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteFormDiscreteConditional);
#endif
TableFactor::shared_ptr sum = product.sum(frontalKeys);
return {std::make_shared<HybridConditional>(conditional), sum};
} else {
// Perform sum-product.
auto result = EliminateDiscrete(dfg, frontalKeys);
return {std::make_shared<HybridConditional>(result.first), result.second};
}
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscrete);
#endif
return {std::make_shared<HybridConditional>(conditional), sum};
}
/* ************************************************************************ */