update DiscreteFactorGraph to use DiscreteFactor::shared_ptr for elimination
parent
fb1d52a18e
commit
e9822a70d2
|
@ -64,7 +64,7 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor DiscreteFactorGraph::product() const {
|
DiscreteFactor::shared_ptr DiscreteFactorGraph::product() const {
|
||||||
DiscreteFactor::shared_ptr result;
|
DiscreteFactor::shared_ptr result;
|
||||||
for (auto it = this->begin(); it != this->end(); ++it) {
|
for (auto it = this->begin(); it != this->end(); ++it) {
|
||||||
if (*it) {
|
if (*it) {
|
||||||
|
@ -76,7 +76,7 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result->toDecisionTreeFactor();
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
|
@ -122,20 +122,20 @@ namespace gtsam {
|
||||||
* @brief Multiply all the `factors`.
|
* @brief Multiply all the `factors`.
|
||||||
*
|
*
|
||||||
* @param factors The factors to multiply as a DiscreteFactorGraph.
|
* @param factors The factors to multiply as a DiscreteFactorGraph.
|
||||||
* @return DecisionTreeFactor
|
* @return DiscreteFactor::shared_ptr
|
||||||
*/
|
*/
|
||||||
static DecisionTreeFactor DiscreteProduct(
|
static DiscreteFactor::shared_ptr DiscreteProduct(
|
||||||
const DiscreteFactorGraph& factors) {
|
const DiscreteFactorGraph& factors) {
|
||||||
// PRODUCT: multiply all factors
|
// PRODUCT: multiply all factors
|
||||||
gttic(product);
|
gttic(product);
|
||||||
DecisionTreeFactor product = factors.product();
|
DiscreteFactor::shared_ptr product = factors.product();
|
||||||
gttoc(product);
|
gttoc(product);
|
||||||
|
|
||||||
// Max over all the potentials by pretending all keys are frontal:
|
// Max over all the potentials by pretending all keys are frontal:
|
||||||
auto denominator = product.max(product.size());
|
auto denominator = product->max(product->size());
|
||||||
|
|
||||||
// Normalize the product factor to prevent underflow.
|
// Normalize the product factor to prevent underflow.
|
||||||
product = product / denominator;
|
product = product->operator/(denominator);
|
||||||
|
|
||||||
return product;
|
return product;
|
||||||
}
|
}
|
||||||
|
@ -145,26 +145,25 @@ namespace gtsam {
|
||||||
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
|
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
|
||||||
EliminateForMPE(const DiscreteFactorGraph& factors,
|
EliminateForMPE(const DiscreteFactorGraph& factors,
|
||||||
const Ordering& frontalKeys) {
|
const Ordering& frontalKeys) {
|
||||||
DecisionTreeFactor product = DiscreteProduct(factors);
|
DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
|
||||||
|
|
||||||
// max out frontals, this is the factor on the separator
|
// max out frontals, this is the factor on the separator
|
||||||
gttic(max);
|
gttic(max);
|
||||||
DecisionTreeFactor::shared_ptr max =
|
DiscreteFactor::shared_ptr max = product->max(frontalKeys);
|
||||||
std::dynamic_pointer_cast<DecisionTreeFactor>(product.max(frontalKeys));
|
|
||||||
gttoc(max);
|
gttoc(max);
|
||||||
|
|
||||||
// Ordering keys for the conditional so that frontalKeys are really in front
|
// Ordering keys for the conditional so that frontalKeys are really in front
|
||||||
DiscreteKeys orderedKeys;
|
DiscreteKeys orderedKeys;
|
||||||
for (auto&& key : frontalKeys)
|
for (auto&& key : frontalKeys)
|
||||||
orderedKeys.emplace_back(key, product.cardinality(key));
|
orderedKeys.emplace_back(key, product->cardinality(key));
|
||||||
for (auto&& key : max->keys())
|
for (auto&& key : max->keys())
|
||||||
orderedKeys.emplace_back(key, product.cardinality(key));
|
orderedKeys.emplace_back(key, product->cardinality(key));
|
||||||
|
|
||||||
// Make lookup with product
|
// Make lookup with product
|
||||||
gttic(lookup);
|
gttic(lookup);
|
||||||
size_t nrFrontals = frontalKeys.size();
|
size_t nrFrontals = frontalKeys.size();
|
||||||
auto lookup =
|
auto lookup = std::make_shared<DiscreteLookupTable>(
|
||||||
std::make_shared<DiscreteLookupTable>(nrFrontals, orderedKeys, product);
|
nrFrontals, orderedKeys, product->toDecisionTreeFactor());
|
||||||
gttoc(lookup);
|
gttoc(lookup);
|
||||||
|
|
||||||
return {std::dynamic_pointer_cast<DiscreteConditional>(lookup), max};
|
return {std::dynamic_pointer_cast<DiscreteConditional>(lookup), max};
|
||||||
|
@ -224,12 +223,11 @@ namespace gtsam {
|
||||||
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
|
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
|
||||||
EliminateDiscrete(const DiscreteFactorGraph& factors,
|
EliminateDiscrete(const DiscreteFactorGraph& factors,
|
||||||
const Ordering& frontalKeys) {
|
const Ordering& frontalKeys) {
|
||||||
DecisionTreeFactor product = DiscreteProduct(factors);
|
DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
|
||||||
|
|
||||||
// sum out frontals, this is the factor on the separator
|
// sum out frontals, this is the factor on the separator
|
||||||
gttic(sum);
|
gttic(sum);
|
||||||
DecisionTreeFactor::shared_ptr sum = std::dynamic_pointer_cast<DecisionTreeFactor>(
|
DiscreteFactor::shared_ptr sum = product->sum(frontalKeys);
|
||||||
product.sum(frontalKeys));
|
|
||||||
gttoc(sum);
|
gttoc(sum);
|
||||||
|
|
||||||
// Ordering keys for the conditional so that frontalKeys are really in front
|
// Ordering keys for the conditional so that frontalKeys are really in front
|
||||||
|
@ -241,8 +239,9 @@ namespace gtsam {
|
||||||
|
|
||||||
// now divide product/sum to get conditional
|
// now divide product/sum to get conditional
|
||||||
gttic(divide);
|
gttic(divide);
|
||||||
auto conditional =
|
auto conditional = std::make_shared<DiscreteConditional>(
|
||||||
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
|
product->toDecisionTreeFactor(), sum->toDecisionTreeFactor(),
|
||||||
|
orderedKeys);
|
||||||
gttoc(divide);
|
gttoc(divide);
|
||||||
|
|
||||||
return {conditional, sum};
|
return {conditional, sum};
|
||||||
|
|
|
@ -148,7 +148,7 @@ class GTSAM_EXPORT DiscreteFactorGraph
|
||||||
DiscreteKeys discreteKeys() const;
|
DiscreteKeys discreteKeys() const;
|
||||||
|
|
||||||
/** return product of all factors as a single factor */
|
/** return product of all factors as a single factor */
|
||||||
DecisionTreeFactor product() const;
|
DiscreteFactor::shared_ptr product() const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Evaluates the factor graph given values, returns the joint probability of
|
* Evaluates the factor graph given values, returns the joint probability of
|
||||||
|
|
Loading…
Reference in New Issue