Merge pull request #1927 from borglab/discrete-improvements

Various Discrete Improvements
release/4.3a0
Varun Agrawal 2024-12-10 18:30:41 -05:00 committed by GitHub
commit 2c9e315a2c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 49 additions and 47 deletions

View File

@ -168,13 +168,9 @@ class GTSAM_EXPORT DiscreteConditional
static_cast<const BaseConditional*>(this)->print(s, formatter); static_cast<const BaseConditional*>(this)->print(s, formatter);
} }
/// Evaluate, just look up in AlgebraicDecisionTree using BaseFactor::error; ///< DiscreteValues version
virtual double evaluate(const Assignment<Key>& values) const override { using BaseFactor::evaluate; ///< DiscreteValues version
return ADT::operator()(values); using BaseFactor::operator(); ///< DiscreteValues version
}
using DecisionTreeFactor::error; ///< DiscreteValues version
using DiscreteFactor::operator(); ///< DiscreteValues version
/** /**
* @brief restrict to given *parent* values. * @brief restrict to given *parent* values.

View File

@ -40,7 +40,7 @@ class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional {
/// Default constructor needed for serialization. /// Default constructor needed for serialization.
DiscreteDistribution() {} DiscreteDistribution() {}
/// Constructor from factor. /// Constructor from DecisionTreeFactor.
explicit DiscreteDistribution(const DecisionTreeFactor& f) explicit DiscreteDistribution(const DecisionTreeFactor& f)
: Base(f.size(), f) {} : Base(f.size(), f) {}

View File

@ -14,6 +14,7 @@
* @date Feb 14, 2011 * @date Feb 14, 2011
* @author Duy-Nguyen Ta * @author Duy-Nguyen Ta
* @author Frank Dellaert * @author Frank Dellaert
* @author Varun Agrawal
*/ */
#include <gtsam/discrete/DiscreteBayesTree.h> #include <gtsam/discrete/DiscreteBayesTree.h>
@ -35,13 +36,12 @@ namespace gtsam {
template class FactorGraph<DiscreteFactor>; template class FactorGraph<DiscreteFactor>;
template class EliminateableFactorGraph<DiscreteFactorGraph>; template class EliminateableFactorGraph<DiscreteFactorGraph>;
/* ************************************************************************* */ /* ************************************************************************ */
bool DiscreteFactorGraph::equals(const This& fg, double tol) const bool DiscreteFactorGraph::equals(const This& fg, double tol) const {
{
return Base::equals(fg, tol); return Base::equals(fg, tol);
} }
/* ************************************************************************* */ /* ************************************************************************ */
KeySet DiscreteFactorGraph::keys() const { KeySet DiscreteFactorGraph::keys() const {
KeySet keys; KeySet keys;
for (const sharedFactor& factor : *this) { for (const sharedFactor& factor : *this) {
@ -50,11 +50,11 @@ namespace gtsam {
return keys; return keys;
} }
/* ************************************************************************* */ /* ************************************************************************ */
DiscreteKeys DiscreteFactorGraph::discreteKeys() const { DiscreteKeys DiscreteFactorGraph::discreteKeys() const {
DiscreteKeys result; DiscreteKeys result;
for (auto&& factor : *this) { for (auto&& factor : *this) {
if (auto p = std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) { if (auto p = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
DiscreteKeys factor_keys = p->discreteKeys(); DiscreteKeys factor_keys = p->discreteKeys();
result.insert(result.end(), factor_keys.begin(), factor_keys.end()); result.insert(result.end(), factor_keys.begin(), factor_keys.end());
} }
@ -63,26 +63,27 @@ namespace gtsam {
return result; return result;
} }
/* ************************************************************************* */ /* ************************************************************************ */
DecisionTreeFactor DiscreteFactorGraph::product() const { DecisionTreeFactor DiscreteFactorGraph::product() const {
DecisionTreeFactor result; DecisionTreeFactor result;
for(const sharedFactor& factor: *this) for (const sharedFactor& factor : *this) {
if (factor) result = (*factor) * result; if (factor) result = (*factor) * result;
}
return result; return result;
} }
/* ************************************************************************* */ /* ************************************************************************ */
double DiscreteFactorGraph::operator()( double DiscreteFactorGraph::operator()(const DiscreteValues& values) const {
const DiscreteValues &values) const {
double product = 1.0; double product = 1.0;
for( const sharedFactor& factor: factors_ ) for (const sharedFactor& factor : factors_) {
product *= (*factor)(values); if (factor) product *= (*factor)(values);
}
return product; return product;
} }
/* ************************************************************************* */ /* ************************************************************************ */
void DiscreteFactorGraph::print(const string& s, void DiscreteFactorGraph::print(const string& s,
const KeyFormatter& formatter) const { const KeyFormatter& formatter) const {
std::cout << s << std::endl; std::cout << s << std::endl;
std::cout << "size: " << size() << std::endl; std::cout << "size: " << size() << std::endl;
for (size_t i = 0; i < factors_.size(); i++) { for (size_t i = 0; i < factors_.size(); i++) {
@ -110,15 +111,18 @@ namespace gtsam {
// } // }
// } // }
/* ************************************************************************ */ /**
// Alternate eliminate function for MPE * @brief Multiply all the `factors` and normalize the
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> // * product to prevent underflow.
EliminateForMPE(const DiscreteFactorGraph& factors, *
const Ordering& frontalKeys) { * @param factors The factors to multiply as a DiscreteFactorGraph.
* @return DecisionTreeFactor
*/
static DecisionTreeFactor ProductAndNormalize(
const DiscreteFactorGraph& factors) {
// PRODUCT: multiply all factors // PRODUCT: multiply all factors
gttic(product); gttic(product);
DecisionTreeFactor product; DecisionTreeFactor product = factors.product();
for (auto&& factor : factors) product = (*factor) * product;
gttoc(product); gttoc(product);
// Max over all the potentials by pretending all keys are frontal: // Max over all the potentials by pretending all keys are frontal:
@ -127,6 +131,16 @@ namespace gtsam {
// Normalize the product factor to prevent underflow. // Normalize the product factor to prevent underflow.
product = product / (*normalization); product = product / (*normalization);
return product;
}
/* ************************************************************************ */
// Alternate eliminate function for MPE
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
DecisionTreeFactor product = ProductAndNormalize(factors);
// max out frontals, this is the factor on the separator // max out frontals, this is the factor on the separator
gttic(max); gttic(max);
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys); DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
@ -142,8 +156,8 @@ namespace gtsam {
// Make lookup with product // Make lookup with product
gttic(lookup); gttic(lookup);
size_t nrFrontals = frontalKeys.size(); size_t nrFrontals = frontalKeys.size();
auto lookup = std::make_shared<DiscreteLookupTable>(nrFrontals, auto lookup =
orderedKeys, product); std::make_shared<DiscreteLookupTable>(nrFrontals, orderedKeys, product);
gttoc(lookup); gttoc(lookup);
return {std::dynamic_pointer_cast<DiscreteConditional>(lookup), max}; return {std::dynamic_pointer_cast<DiscreteConditional>(lookup), max};
@ -201,20 +215,10 @@ namespace gtsam {
} }
/* ************************************************************************ */ /* ************************************************************************ */
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> // std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors, EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) { const Ordering& frontalKeys) {
// PRODUCT: multiply all factors DecisionTreeFactor product = ProductAndNormalize(factors);
gttic(product);
DecisionTreeFactor product;
for (auto&& factor : factors) product = (*factor) * product;
gttoc(product);
// Max over all the potentials by pretending all keys are frontal:
auto normalization = product.max(product.size());
// Normalize the product factor to prevent underflow.
product = product / (*normalization);
// sum out frontals, this is the factor on the separator // sum out frontals, this is the factor on the separator
gttic(sum); gttic(sum);

View File

@ -14,6 +14,7 @@
* @date Feb 14, 2011 * @date Feb 14, 2011
* @author Duy-Nguyen Ta * @author Duy-Nguyen Ta
* @author Frank Dellaert * @author Frank Dellaert
* @author Varun Agrawal
*/ */
#pragma once #pragma once
@ -48,7 +49,7 @@ class DiscreteJunctionTree;
* @ingroup discrete * @ingroup discrete
*/ */
GTSAM_EXPORT GTSAM_EXPORT
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr>
EliminateDiscrete(const DiscreteFactorGraph& factors, EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys); const Ordering& frontalKeys);
@ -61,7 +62,7 @@ EliminateDiscrete(const DiscreteFactorGraph& factors,
* @ingroup discrete * @ingroup discrete
*/ */
GTSAM_EXPORT GTSAM_EXPORT
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr>
EliminateForMPE(const DiscreteFactorGraph& factors, EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys); const Ordering& frontalKeys);

View File

@ -113,7 +113,8 @@ TEST(DiscreteFactorGraph, test) {
const Ordering frontalKeys{0}; const Ordering frontalKeys{0};
const auto [conditional, newFactorPtr] = EliminateDiscrete(graph, frontalKeys); const auto [conditional, newFactorPtr] = EliminateDiscrete(graph, frontalKeys);
DecisionTreeFactor newFactor = *newFactorPtr; DecisionTreeFactor newFactor =
*std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr);
// Normalize newFactor by max for comparison with expected // Normalize newFactor by max for comparison with expected
auto normalization = newFactor.max(newFactor.size()); auto normalization = newFactor.max(newFactor.size());