update DiscreteFactorGraph to use DiscreteFactor::shared_ptr for elimination
parent
fb1d52a18e
commit
e9822a70d2
|
@ -64,7 +64,7 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor DiscreteFactorGraph::product() const {
|
||||
DiscreteFactor::shared_ptr DiscreteFactorGraph::product() const {
|
||||
DiscreteFactor::shared_ptr result;
|
||||
for (auto it = this->begin(); it != this->end(); ++it) {
|
||||
if (*it) {
|
||||
|
@ -76,7 +76,7 @@ namespace gtsam {
|
|||
}
|
||||
}
|
||||
}
|
||||
return result->toDecisionTreeFactor();
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
|
@ -122,20 +122,20 @@ namespace gtsam {
|
|||
* @brief Multiply all the `factors`.
|
||||
*
|
||||
* @param factors The factors to multiply as a DiscreteFactorGraph.
|
||||
* @return DecisionTreeFactor
|
||||
* @return DiscreteFactor::shared_ptr
|
||||
*/
|
||||
static DecisionTreeFactor DiscreteProduct(
|
||||
static DiscreteFactor::shared_ptr DiscreteProduct(
|
||||
const DiscreteFactorGraph& factors) {
|
||||
// PRODUCT: multiply all factors
|
||||
gttic(product);
|
||||
DecisionTreeFactor product = factors.product();
|
||||
DiscreteFactor::shared_ptr product = factors.product();
|
||||
gttoc(product);
|
||||
|
||||
// Max over all the potentials by pretending all keys are frontal:
|
||||
auto denominator = product.max(product.size());
|
||||
auto denominator = product->max(product->size());
|
||||
|
||||
// Normalize the product factor to prevent underflow.
|
||||
product = product / denominator;
|
||||
product = product->operator/(denominator);
|
||||
|
||||
return product;
|
||||
}
|
||||
|
@ -145,26 +145,25 @@ namespace gtsam {
|
|||
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
|
||||
EliminateForMPE(const DiscreteFactorGraph& factors,
|
||||
const Ordering& frontalKeys) {
|
||||
DecisionTreeFactor product = DiscreteProduct(factors);
|
||||
DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
|
||||
|
||||
// max out frontals, this is the factor on the separator
|
||||
gttic(max);
|
||||
DecisionTreeFactor::shared_ptr max =
|
||||
std::dynamic_pointer_cast<DecisionTreeFactor>(product.max(frontalKeys));
|
||||
DiscreteFactor::shared_ptr max = product->max(frontalKeys);
|
||||
gttoc(max);
|
||||
|
||||
// Ordering keys for the conditional so that frontalKeys are really in front
|
||||
DiscreteKeys orderedKeys;
|
||||
for (auto&& key : frontalKeys)
|
||||
orderedKeys.emplace_back(key, product.cardinality(key));
|
||||
orderedKeys.emplace_back(key, product->cardinality(key));
|
||||
for (auto&& key : max->keys())
|
||||
orderedKeys.emplace_back(key, product.cardinality(key));
|
||||
orderedKeys.emplace_back(key, product->cardinality(key));
|
||||
|
||||
// Make lookup with product
|
||||
gttic(lookup);
|
||||
size_t nrFrontals = frontalKeys.size();
|
||||
auto lookup =
|
||||
std::make_shared<DiscreteLookupTable>(nrFrontals, orderedKeys, product);
|
||||
auto lookup = std::make_shared<DiscreteLookupTable>(
|
||||
nrFrontals, orderedKeys, product->toDecisionTreeFactor());
|
||||
gttoc(lookup);
|
||||
|
||||
return {std::dynamic_pointer_cast<DiscreteConditional>(lookup), max};
|
||||
|
@ -224,12 +223,11 @@ namespace gtsam {
|
|||
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
|
||||
EliminateDiscrete(const DiscreteFactorGraph& factors,
|
||||
const Ordering& frontalKeys) {
|
||||
DecisionTreeFactor product = DiscreteProduct(factors);
|
||||
DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
|
||||
|
||||
// sum out frontals, this is the factor on the separator
|
||||
gttic(sum);
|
||||
DecisionTreeFactor::shared_ptr sum = std::dynamic_pointer_cast<DecisionTreeFactor>(
|
||||
product.sum(frontalKeys));
|
||||
DiscreteFactor::shared_ptr sum = product->sum(frontalKeys);
|
||||
gttoc(sum);
|
||||
|
||||
// Ordering keys for the conditional so that frontalKeys are really in front
|
||||
|
@ -241,8 +239,9 @@ namespace gtsam {
|
|||
|
||||
// now divide product/sum to get conditional
|
||||
gttic(divide);
|
||||
auto conditional =
|
||||
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
|
||||
auto conditional = std::make_shared<DiscreteConditional>(
|
||||
product->toDecisionTreeFactor(), sum->toDecisionTreeFactor(),
|
||||
orderedKeys);
|
||||
gttoc(divide);
|
||||
|
||||
return {conditional, sum};
|
||||
|
|
|
@ -148,7 +148,7 @@ class GTSAM_EXPORT DiscreteFactorGraph
|
|||
DiscreteKeys discreteKeys() const;
|
||||
|
||||
/** return product of all factors as a single factor */
|
||||
DecisionTreeFactor product() const;
|
||||
DiscreteFactor::shared_ptr product() const;
|
||||
|
||||
/**
|
||||
* Evaluates the factor graph given values, returns the joint probability of
|
||||
|
|
Loading…
Reference in New Issue