Merge branch 'develop' into discrete-elimination-refactor
commit
0b3f0587c8
|
|
@ -14,6 +14,7 @@
|
|||
* @date Feb 14, 2011
|
||||
* @author Duy-Nguyen Ta
|
||||
* @author Frank Dellaert
|
||||
* @author Varun Agrawal
|
||||
*/
|
||||
|
||||
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||
|
|
@ -35,13 +36,12 @@ namespace gtsam {
|
|||
template class FactorGraph<DiscreteFactor>;
|
||||
template class EliminateableFactorGraph<DiscreteFactorGraph>;
|
||||
|
||||
/* ************************************************************************* */
|
||||
bool DiscreteFactorGraph::equals(const This& fg, double tol) const
|
||||
{
|
||||
/* ************************************************************************ */
|
||||
bool DiscreteFactorGraph::equals(const This& fg, double tol) const {
|
||||
return Base::equals(fg, tol);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
KeySet DiscreteFactorGraph::keys() const {
|
||||
KeySet keys;
|
||||
for (const sharedFactor& factor : *this) {
|
||||
|
|
@ -50,7 +50,7 @@ namespace gtsam {
|
|||
return keys;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
DiscreteKeys DiscreteFactorGraph::discreteKeys() const {
|
||||
DiscreteKeys result;
|
||||
for (auto&& factor : *this) {
|
||||
|
|
@ -63,7 +63,7 @@ namespace gtsam {
|
|||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor DiscreteFactorGraph::product() const {
|
||||
DecisionTreeFactor result;
|
||||
for (const sharedFactor& factor : *this) {
|
||||
|
|
@ -72,18 +72,18 @@ namespace gtsam {
|
|||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double DiscreteFactorGraph::operator()(
|
||||
const DiscreteValues &values) const {
|
||||
/* ************************************************************************ */
|
||||
double DiscreteFactorGraph::operator()(const DiscreteValues& values) const {
|
||||
double product = 1.0;
|
||||
for( const sharedFactor& factor: factors_ )
|
||||
product *= (*factor)(values);
|
||||
for (const sharedFactor& factor : factors_) {
|
||||
if (factor) product *= (*factor)(values);
|
||||
}
|
||||
return product;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
void DiscreteFactorGraph::print(const string& s,
|
||||
const KeyFormatter& formatter) const {
|
||||
const KeyFormatter& formatter) const {
|
||||
std::cout << s << std::endl;
|
||||
std::cout << "size: " << size() << std::endl;
|
||||
for (size_t i = 0; i < factors_.size(); i++) {
|
||||
|
|
@ -112,43 +112,36 @@ namespace gtsam {
|
|||
// }
|
||||
|
||||
/**
|
||||
* @brief Helper method to normalize the product factor by
|
||||
* the max value to prevent underflow
|
||||
* @brief Multiply all the `factors` and normalize the
|
||||
* product to prevent underflow.
|
||||
*
|
||||
* @param product The product discrete factor.
|
||||
* @return DiscreteFactor::shared_ptr
|
||||
* @param factors The factors to multiply as a DiscreteFactorGraph.
|
||||
* @return DecisionTreeFactor
|
||||
*/
|
||||
static DecisionTreeFactor Normalize(const DecisionTreeFactor& product) {
|
||||
// Max over all the potentials by pretending all keys are frontal:
|
||||
gttic_(DiscreteFindMax);
|
||||
auto normalization = product.max(product.size());
|
||||
gttoc_(DiscreteFindMax);
|
||||
static DecisionTreeFactor ProductAndNormalize(
|
||||
const DiscreteFactorGraph& factors) {
|
||||
// PRODUCT: multiply all factors
|
||||
gttic(product);
|
||||
DecisionTreeFactor product = factors.product();
|
||||
gttoc(product);
|
||||
|
||||
// Max over all the potentials by pretending all keys are frontal:
|
||||
auto normalization = product.max(product.size());
|
||||
|
||||
gttic_(DiscreteNormalization);
|
||||
// Normalize the product factor to prevent underflow.
|
||||
auto normalized_product =
|
||||
product /
|
||||
(*std::dynamic_pointer_cast<DecisionTreeFactor>(normalization));
|
||||
gttoc_(DiscreteNormalization);
|
||||
|
||||
return normalized_product;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
// Alternate eliminate function for MPE
|
||||
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr>
|
||||
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
|
||||
EliminateForMPE(const DiscreteFactorGraph& factors,
|
||||
const Ordering& frontalKeys) {
|
||||
// PRODUCT: multiply all factors
|
||||
gttic_(MPEProduct);
|
||||
DecisionTreeFactor product = factors.product();
|
||||
gttoc_(MPEProduct);
|
||||
|
||||
gttic_(Normalize);
|
||||
|
||||
// Normalize the product
|
||||
product = Normalize(product);
|
||||
gttoc_(Normalize);
|
||||
DecisionTreeFactor product = ProductAndNormalize(factors);
|
||||
|
||||
// max out frontals, this is the factor on the separator
|
||||
gttic(max);
|
||||
|
|
@ -225,18 +218,10 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr>
|
||||
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
|
||||
EliminateDiscrete(const DiscreteFactorGraph& factors,
|
||||
const Ordering& frontalKeys) {
|
||||
// PRODUCT: multiply all factors
|
||||
gttic_(product);
|
||||
DecisionTreeFactor product = factors.product();
|
||||
gttoc_(product);
|
||||
|
||||
gttic_(Normalize);
|
||||
// Normalize the product
|
||||
product = Normalize(product);
|
||||
gttoc_(Normalize);
|
||||
DecisionTreeFactor product = ProductAndNormalize(factors);
|
||||
|
||||
// sum out frontals, this is the factor on the separator
|
||||
gttic_(sum);
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@
|
|||
* @date Feb 14, 2011
|
||||
* @author Duy-Nguyen Ta
|
||||
* @author Frank Dellaert
|
||||
* @author Varun Agrawal
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
|
|
|||
|
|
@ -113,7 +113,8 @@ TEST(DiscreteFactorGraph, test) {
|
|||
const Ordering frontalKeys{0};
|
||||
const auto [conditional, newFactorPtr] = EliminateDiscrete(graph, frontalKeys);
|
||||
|
||||
auto newFactor = *std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr);
|
||||
DecisionTreeFactor newFactor =
|
||||
*std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr);
|
||||
|
||||
// Normalize newFactor by max for comparison with expected
|
||||
auto normalization = newFactor.max(newFactor.size());
|
||||
|
|
|
|||
Loading…
Reference in New Issue