Merge pull request #1927 from borglab/discrete-improvements
Various Discrete Improvementsrelease/4.3a0
commit
2c9e315a2c
|
@ -168,13 +168,9 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
static_cast<const BaseConditional*>(this)->print(s, formatter);
|
static_cast<const BaseConditional*>(this)->print(s, formatter);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Evaluate, just look up in AlgebraicDecisionTree
|
using BaseFactor::error; ///< DiscreteValues version
|
||||||
virtual double evaluate(const Assignment<Key>& values) const override {
|
using BaseFactor::evaluate; ///< DiscreteValues version
|
||||||
return ADT::operator()(values);
|
using BaseFactor::operator(); ///< DiscreteValues version
|
||||||
}
|
|
||||||
|
|
||||||
using DecisionTreeFactor::error; ///< DiscreteValues version
|
|
||||||
using DiscreteFactor::operator(); ///< DiscreteValues version
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief restrict to given *parent* values.
|
* @brief restrict to given *parent* values.
|
||||||
|
|
|
@ -40,7 +40,7 @@ class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional {
|
||||||
/// Default constructor needed for serialization.
|
/// Default constructor needed for serialization.
|
||||||
DiscreteDistribution() {}
|
DiscreteDistribution() {}
|
||||||
|
|
||||||
/// Constructor from factor.
|
/// Constructor from DecisionTreeFactor.
|
||||||
explicit DiscreteDistribution(const DecisionTreeFactor& f)
|
explicit DiscreteDistribution(const DecisionTreeFactor& f)
|
||||||
: Base(f.size(), f) {}
|
: Base(f.size(), f) {}
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
* @date Feb 14, 2011
|
* @date Feb 14, 2011
|
||||||
* @author Duy-Nguyen Ta
|
* @author Duy-Nguyen Ta
|
||||||
* @author Frank Dellaert
|
* @author Frank Dellaert
|
||||||
|
* @author Varun Agrawal
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteBayesTree.h>
|
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||||
|
@ -35,13 +36,12 @@ namespace gtsam {
|
||||||
template class FactorGraph<DiscreteFactor>;
|
template class FactorGraph<DiscreteFactor>;
|
||||||
template class EliminateableFactorGraph<DiscreteFactorGraph>;
|
template class EliminateableFactorGraph<DiscreteFactorGraph>;
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
bool DiscreteFactorGraph::equals(const This& fg, double tol) const
|
bool DiscreteFactorGraph::equals(const This& fg, double tol) const {
|
||||||
{
|
|
||||||
return Base::equals(fg, tol);
|
return Base::equals(fg, tol);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
KeySet DiscreteFactorGraph::keys() const {
|
KeySet DiscreteFactorGraph::keys() const {
|
||||||
KeySet keys;
|
KeySet keys;
|
||||||
for (const sharedFactor& factor : *this) {
|
for (const sharedFactor& factor : *this) {
|
||||||
|
@ -50,11 +50,11 @@ namespace gtsam {
|
||||||
return keys;
|
return keys;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
DiscreteKeys DiscreteFactorGraph::discreteKeys() const {
|
DiscreteKeys DiscreteFactorGraph::discreteKeys() const {
|
||||||
DiscreteKeys result;
|
DiscreteKeys result;
|
||||||
for (auto&& factor : *this) {
|
for (auto&& factor : *this) {
|
||||||
if (auto p = std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
|
if (auto p = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
|
||||||
DiscreteKeys factor_keys = p->discreteKeys();
|
DiscreteKeys factor_keys = p->discreteKeys();
|
||||||
result.insert(result.end(), factor_keys.begin(), factor_keys.end());
|
result.insert(result.end(), factor_keys.begin(), factor_keys.end());
|
||||||
}
|
}
|
||||||
|
@ -63,24 +63,25 @@ namespace gtsam {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor DiscreteFactorGraph::product() const {
|
DecisionTreeFactor DiscreteFactorGraph::product() const {
|
||||||
DecisionTreeFactor result;
|
DecisionTreeFactor result;
|
||||||
for(const sharedFactor& factor: *this)
|
for (const sharedFactor& factor : *this) {
|
||||||
if (factor) result = (*factor) * result;
|
if (factor) result = (*factor) * result;
|
||||||
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
double DiscreteFactorGraph::operator()(
|
double DiscreteFactorGraph::operator()(const DiscreteValues& values) const {
|
||||||
const DiscreteValues &values) const {
|
|
||||||
double product = 1.0;
|
double product = 1.0;
|
||||||
for( const sharedFactor& factor: factors_ )
|
for (const sharedFactor& factor : factors_) {
|
||||||
product *= (*factor)(values);
|
if (factor) product *= (*factor)(values);
|
||||||
|
}
|
||||||
return product;
|
return product;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
void DiscreteFactorGraph::print(const string& s,
|
void DiscreteFactorGraph::print(const string& s,
|
||||||
const KeyFormatter& formatter) const {
|
const KeyFormatter& formatter) const {
|
||||||
std::cout << s << std::endl;
|
std::cout << s << std::endl;
|
||||||
|
@ -110,15 +111,18 @@ namespace gtsam {
|
||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/**
|
||||||
// Alternate eliminate function for MPE
|
* @brief Multiply all the `factors` and normalize the
|
||||||
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
|
* product to prevent underflow.
|
||||||
EliminateForMPE(const DiscreteFactorGraph& factors,
|
*
|
||||||
const Ordering& frontalKeys) {
|
* @param factors The factors to multiply as a DiscreteFactorGraph.
|
||||||
|
* @return DecisionTreeFactor
|
||||||
|
*/
|
||||||
|
static DecisionTreeFactor ProductAndNormalize(
|
||||||
|
const DiscreteFactorGraph& factors) {
|
||||||
// PRODUCT: multiply all factors
|
// PRODUCT: multiply all factors
|
||||||
gttic(product);
|
gttic(product);
|
||||||
DecisionTreeFactor product;
|
DecisionTreeFactor product = factors.product();
|
||||||
for (auto&& factor : factors) product = (*factor) * product;
|
|
||||||
gttoc(product);
|
gttoc(product);
|
||||||
|
|
||||||
// Max over all the potentials by pretending all keys are frontal:
|
// Max over all the potentials by pretending all keys are frontal:
|
||||||
|
@ -127,6 +131,16 @@ namespace gtsam {
|
||||||
// Normalize the product factor to prevent underflow.
|
// Normalize the product factor to prevent underflow.
|
||||||
product = product / (*normalization);
|
product = product / (*normalization);
|
||||||
|
|
||||||
|
return product;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
// Alternate eliminate function for MPE
|
||||||
|
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
|
||||||
|
EliminateForMPE(const DiscreteFactorGraph& factors,
|
||||||
|
const Ordering& frontalKeys) {
|
||||||
|
DecisionTreeFactor product = ProductAndNormalize(factors);
|
||||||
|
|
||||||
// max out frontals, this is the factor on the separator
|
// max out frontals, this is the factor on the separator
|
||||||
gttic(max);
|
gttic(max);
|
||||||
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
|
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
|
||||||
|
@ -142,8 +156,8 @@ namespace gtsam {
|
||||||
// Make lookup with product
|
// Make lookup with product
|
||||||
gttic(lookup);
|
gttic(lookup);
|
||||||
size_t nrFrontals = frontalKeys.size();
|
size_t nrFrontals = frontalKeys.size();
|
||||||
auto lookup = std::make_shared<DiscreteLookupTable>(nrFrontals,
|
auto lookup =
|
||||||
orderedKeys, product);
|
std::make_shared<DiscreteLookupTable>(nrFrontals, orderedKeys, product);
|
||||||
gttoc(lookup);
|
gttoc(lookup);
|
||||||
|
|
||||||
return {std::dynamic_pointer_cast<DiscreteConditional>(lookup), max};
|
return {std::dynamic_pointer_cast<DiscreteConditional>(lookup), max};
|
||||||
|
@ -201,20 +215,10 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
|
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
|
||||||
EliminateDiscrete(const DiscreteFactorGraph& factors,
|
EliminateDiscrete(const DiscreteFactorGraph& factors,
|
||||||
const Ordering& frontalKeys) {
|
const Ordering& frontalKeys) {
|
||||||
// PRODUCT: multiply all factors
|
DecisionTreeFactor product = ProductAndNormalize(factors);
|
||||||
gttic(product);
|
|
||||||
DecisionTreeFactor product;
|
|
||||||
for (auto&& factor : factors) product = (*factor) * product;
|
|
||||||
gttoc(product);
|
|
||||||
|
|
||||||
// Max over all the potentials by pretending all keys are frontal:
|
|
||||||
auto normalization = product.max(product.size());
|
|
||||||
|
|
||||||
// Normalize the product factor to prevent underflow.
|
|
||||||
product = product / (*normalization);
|
|
||||||
|
|
||||||
// sum out frontals, this is the factor on the separator
|
// sum out frontals, this is the factor on the separator
|
||||||
gttic(sum);
|
gttic(sum);
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
* @date Feb 14, 2011
|
* @date Feb 14, 2011
|
||||||
* @author Duy-Nguyen Ta
|
* @author Duy-Nguyen Ta
|
||||||
* @author Frank Dellaert
|
* @author Frank Dellaert
|
||||||
|
* @author Varun Agrawal
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
@ -48,7 +49,7 @@ class DiscreteJunctionTree;
|
||||||
* @ingroup discrete
|
* @ingroup discrete
|
||||||
*/
|
*/
|
||||||
GTSAM_EXPORT
|
GTSAM_EXPORT
|
||||||
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr>
|
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr>
|
||||||
EliminateDiscrete(const DiscreteFactorGraph& factors,
|
EliminateDiscrete(const DiscreteFactorGraph& factors,
|
||||||
const Ordering& frontalKeys);
|
const Ordering& frontalKeys);
|
||||||
|
|
||||||
|
@ -61,7 +62,7 @@ EliminateDiscrete(const DiscreteFactorGraph& factors,
|
||||||
* @ingroup discrete
|
* @ingroup discrete
|
||||||
*/
|
*/
|
||||||
GTSAM_EXPORT
|
GTSAM_EXPORT
|
||||||
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr>
|
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr>
|
||||||
EliminateForMPE(const DiscreteFactorGraph& factors,
|
EliminateForMPE(const DiscreteFactorGraph& factors,
|
||||||
const Ordering& frontalKeys);
|
const Ordering& frontalKeys);
|
||||||
|
|
||||||
|
|
|
@ -113,7 +113,8 @@ TEST(DiscreteFactorGraph, test) {
|
||||||
const Ordering frontalKeys{0};
|
const Ordering frontalKeys{0};
|
||||||
const auto [conditional, newFactorPtr] = EliminateDiscrete(graph, frontalKeys);
|
const auto [conditional, newFactorPtr] = EliminateDiscrete(graph, frontalKeys);
|
||||||
|
|
||||||
DecisionTreeFactor newFactor = *newFactorPtr;
|
DecisionTreeFactor newFactor =
|
||||||
|
*std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr);
|
||||||
|
|
||||||
// Normalize newFactor by max for comparison with expected
|
// Normalize newFactor by max for comparison with expected
|
||||||
auto normalization = newFactor.max(newFactor.size());
|
auto normalization = newFactor.max(newFactor.size());
|
||||||
|
|
Loading…
Reference in New Issue