update DiscreteFactorGraph to use DiscreteFactor::shared_ptr for elimination

release/4.3a0
Varun Agrawal 2025-01-05 20:50:24 -05:00
parent fb1d52a18e
commit e9822a70d2
2 changed files with 19 additions and 20 deletions

View File

@ -64,7 +64,7 @@ namespace gtsam {
}
/* ************************************************************************ */
DecisionTreeFactor DiscreteFactorGraph::product() const {
DiscreteFactor::shared_ptr DiscreteFactorGraph::product() const {
DiscreteFactor::shared_ptr result;
for (auto it = this->begin(); it != this->end(); ++it) {
if (*it) {
@ -76,7 +76,7 @@ namespace gtsam {
}
}
}
return result->toDecisionTreeFactor();
return result;
}
/* ************************************************************************ */
@ -122,20 +122,20 @@ namespace gtsam {
* @brief Multiply all the `factors`.
*
* @param factors The factors to multiply as a DiscreteFactorGraph.
* @return DecisionTreeFactor
* @return DiscreteFactor::shared_ptr
*/
static DecisionTreeFactor DiscreteProduct(
static DiscreteFactor::shared_ptr DiscreteProduct(
const DiscreteFactorGraph& factors) {
// PRODUCT: multiply all factors
gttic(product);
DecisionTreeFactor product = factors.product();
DiscreteFactor::shared_ptr product = factors.product();
gttoc(product);
// Max over all the potentials by pretending all keys are frontal:
auto denominator = product.max(product.size());
auto denominator = product->max(product->size());
// Normalize the product factor to prevent underflow.
product = product / denominator;
product = product->operator/(denominator);
return product;
}
@ -145,26 +145,25 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
DecisionTreeFactor product = DiscreteProduct(factors);
DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
// max out frontals, this is the factor on the separator
gttic(max);
DecisionTreeFactor::shared_ptr max =
std::dynamic_pointer_cast<DecisionTreeFactor>(product.max(frontalKeys));
DiscreteFactor::shared_ptr max = product->max(frontalKeys);
gttoc(max);
// Ordering keys for the conditional so that frontalKeys are really in front
DiscreteKeys orderedKeys;
for (auto&& key : frontalKeys)
orderedKeys.emplace_back(key, product.cardinality(key));
orderedKeys.emplace_back(key, product->cardinality(key));
for (auto&& key : max->keys())
orderedKeys.emplace_back(key, product.cardinality(key));
orderedKeys.emplace_back(key, product->cardinality(key));
// Make lookup with product
gttic(lookup);
size_t nrFrontals = frontalKeys.size();
auto lookup =
std::make_shared<DiscreteLookupTable>(nrFrontals, orderedKeys, product);
auto lookup = std::make_shared<DiscreteLookupTable>(
nrFrontals, orderedKeys, product->toDecisionTreeFactor());
gttoc(lookup);
return {std::dynamic_pointer_cast<DiscreteConditional>(lookup), max};
@ -224,12 +223,11 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
DecisionTreeFactor product = DiscreteProduct(factors);
DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
// sum out frontals, this is the factor on the separator
gttic(sum);
DecisionTreeFactor::shared_ptr sum = std::dynamic_pointer_cast<DecisionTreeFactor>(
product.sum(frontalKeys));
DiscreteFactor::shared_ptr sum = product->sum(frontalKeys);
gttoc(sum);
// Ordering keys for the conditional so that frontalKeys are really in front
@ -241,8 +239,9 @@ namespace gtsam {
// now divide product/sum to get conditional
gttic(divide);
auto conditional =
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
auto conditional = std::make_shared<DiscreteConditional>(
product->toDecisionTreeFactor(), sum->toDecisionTreeFactor(),
orderedKeys);
gttoc(divide);
return {conditional, sum};

View File

@ -148,7 +148,7 @@ class GTSAM_EXPORT DiscreteFactorGraph
DiscreteKeys discreteKeys() const;
/** return product of all factors as a single factor */
DecisionTreeFactor product() const;
DiscreteFactor::shared_ptr product() const;
/**
* Evaluates the factor graph given values, returns the joint probability of