Merge pull request #1953 from borglab/discrete-table-conditional

TableFactor and TableDistribution
release/4.3a0
Varun Agrawal 2025-01-08 11:25:27 -05:00 committed by GitHub
commit 8cf2123b5c
28 changed files with 615 additions and 156 deletions

@@ -57,6 +57,9 @@ namespace gtsam {
AlgebraicDecisionTree(double leaf = 1.0) : Base(leaf) {}
+ /// Constructor which accepts root pointer
+ AlgebraicDecisionTree(const typename Base::NodePtr root) : Base(root) {}
// Explicitly non-explicit constructor
AlgebraicDecisionTree(const Base& add) : Base(add) {}

@@ -77,6 +77,13 @@ DiscreteConditional::DiscreteConditional(const Signature& signature)
/* ************************************************************************** */
DiscreteConditional DiscreteConditional::operator*(
const DiscreteConditional& other) const {
+ // If the root is a nullptr, we have a TableDistribution
+ // TODO(Varun) Revisit this hack after RSS2025 submission
+ if (!other.root_) {
+ DiscreteConditional dc(other.nrFrontals(), other.toDecisionTreeFactor());
+ return dc * (*this);
+ }
// Take union of frontal keys
std::set<Key> newFrontals;
for (auto&& key : this->frontals()) newFrontals.insert(key);

@@ -479,6 +486,19 @@ double DiscreteConditional::evaluate(const HybridValues& x) const {
return this->operator()(x.discrete());
}
+ /* ************************************************************************* */
+ DiscreteFactor::shared_ptr DiscreteConditional::max(
+ const Ordering& keys) const {
+ return BaseFactor::max(keys);
+ }
+ /* ************************************************************************* */
+ void DiscreteConditional::prune(size_t maxNrAssignments) {
+ // Get as DiscreteConditional so the probabilities are normalized
+ DiscreteConditional pruned(nrFrontals(), BaseFactor::prune(maxNrAssignments));
+ this->root_ = pruned.root_;
+ }
/* ************************************************************************* */
double DiscreteConditional::negLogConstant() const { return 0.0; }

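The two additions above give every DiscreteConditional an in-place prune(maxNrAssignments) and an Ordering-based max() that forwards to the DecisionTreeFactor base. A minimal usage sketch, not part of the PR; the keys, cardinalities, and potentials are invented for illustration:

#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Ordering.h>

using namespace gtsam;

int main() {
  DiscreteKey m(0, 2), n(1, 2);
  // P(m | n), written with the usual GTSAM signature syntax.
  DiscreteConditional conditional(m | n = "4/1 1/4");

  // New in this PR: keep only the 2 largest assignments, modifying the conditional in place.
  conditional.prune(2);

  // New in this PR: maximize over the frontal key m, leaving a factor on the separator n.
  Ordering frontal;
  frontal.push_back(m.first);
  DiscreteFactor::shared_ptr maxed = conditional.max(frontal);
  maxed->print("max over m: ");
  return 0;
}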
@@ -199,7 +199,7 @@ class GTSAM_EXPORT DiscreteConditional
* @param parentsValues Known values of the parents
* @return sample from conditional
*/
- size_t sample(const DiscreteValues& parentsValues) const;
+ virtual size_t sample(const DiscreteValues& parentsValues) const;
/// Single parent version.
size_t sample(size_t parent_value) const;

@@ -214,6 +214,15 @@ class GTSAM_EXPORT DiscreteConditional
*/
size_t argmax(const DiscreteValues& parentsValues = DiscreteValues()) const;
+ /**
+ * @brief Create new factor by maximizing over all
+ * values with the same separator.
+ *
+ * @param keys The keys to sum over.
+ * @return DiscreteFactor::shared_ptr
+ */
+ virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const override;
/// @}
/// @name Advanced Interface
/// @{

@@ -267,6 +276,9 @@ class GTSAM_EXPORT DiscreteConditional
*/
double negLogConstant() const override;
+ /// Prune the conditional
+ virtual void prune(size_t maxNrAssignments);
/// @}
protected:

@@ -118,30 +118,18 @@ namespace gtsam {
// }
// }
- /**
- * @brief Multiply all the `factors`.
- *
- * @param factors The factors to multiply as a DiscreteFactorGraph.
- * @return DiscreteFactor::shared_ptr
- */
- static DiscreteFactor::shared_ptr DiscreteProduct(
- const DiscreteFactorGraph& factors) {
+ /* ************************************************************************ */
+ DiscreteFactor::shared_ptr DiscreteFactorGraph::scaledProduct() const {
// PRODUCT: multiply all factors
gttic(product);
- DiscreteFactor::shared_ptr product = factors.product();
+ DiscreteFactor::shared_ptr product = this->product();
gttoc(product);
- #if GTSAM_HYBRID_TIMING
- gttic_(DiscreteNormalize);
- #endif
// Max over all the potentials by pretending all keys are frontal:
auto denominator = product->max(product->size());
// Normalize the product factor to prevent underflow.
product = product->operator/(denominator);
- #if GTSAM_HYBRID_TIMING
- gttoc_(DiscreteNormalize);
- #endif
return product;
}

@@ -151,7 +139,7 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
- DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
+ DiscreteFactor::shared_ptr product = factors.scaledProduct();
// max out frontals, this is the factor on the separator
gttic(max);

@@ -229,7 +217,7 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
- DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
+ DiscreteFactor::shared_ptr product = factors.scaledProduct();
// sum out frontals, this is the factor on the separator
gttic(sum);

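scaledProduct() is the old static DiscreteProduct helper promoted to a member: multiply every factor, then divide by the maximum potential so that long products of small probabilities do not underflow. A hedged sketch of calling it directly; the graph, keys, and potentials below are made up for illustration:

#include <gtsam/discrete/DiscreteFactorGraph.h>

using namespace gtsam;

int main() {
  DiscreteKey m(0, 2), n(1, 2);
  DiscreteFactorGraph graph;
  graph.add(m, "0.001 0.002");   // small potentials whose repeated products would shrink further
  graph.add(m & n, "1 2 3 4");

  // Product of all factors, normalized by its max so the entries stay well scaled.
  DiscreteFactor::shared_ptr scaled = graph.scaledProduct();
  scaled->print("scaled product: ");
  return 0;
}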
@@ -150,6 +150,15 @@ class GTSAM_EXPORT DiscreteFactorGraph
/** return product of all factors as a single factor */
DiscreteFactor::shared_ptr product() const;
+ /**
+ * @brief Return product of all `factors` as a single factor,
+ * which is scaled by the max value to prevent underflow
+ *
+ * @param factors The factors to multiply as a DiscreteFactorGraph.
+ * @return DiscreteFactor::shared_ptr
+ */
+ DiscreteFactor::shared_ptr scaledProduct() const;
/**
* Evaluates the factor graph given values, returns the joint probability of
* the factor graph given specific instantiation of values

@@ -0,0 +1,174 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file TableDistribution.cpp
* @date Dec 22, 2024
* @author Varun Agrawal
*/
#include <gtsam/base/Testable.h>
#include <gtsam/base/debug.h>
#include <gtsam/discrete/Ring.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/discrete/TableDistribution.h>
#include <gtsam/hybrid/HybridValues.h>
#include <algorithm>
#include <cassert>
#include <random>
#include <set>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
using namespace std;
using std::pair;
using std::stringstream;
using std::vector;
namespace gtsam {
/// Normalize sparse_table
static Eigen::SparseVector<double> normalizeSparseTable(
const Eigen::SparseVector<double>& sparse_table) {
return sparse_table / sparse_table.sum();
}
/* ************************************************************************** */
TableDistribution::TableDistribution(const TableFactor& f)
: BaseConditional(f.keys().size(), f.discreteKeys(), ADT(nullptr)),
table_(f / (*std::dynamic_pointer_cast<TableFactor>(
f.sum(f.keys().size())))) {}
/* ************************************************************************** */
TableDistribution::TableDistribution(const DiscreteKeys& keys,
const std::vector<double>& potentials)
: BaseConditional(keys.size(), keys, ADT(nullptr)),
table_(TableFactor(
keys, normalizeSparseTable(TableFactor::Convert(keys, potentials)))) {
}
/* ************************************************************************** */
TableDistribution::TableDistribution(const DiscreteKeys& keys,
const std::string& potentials)
: BaseConditional(keys.size(), keys, ADT(nullptr)),
table_(TableFactor(
keys, normalizeSparseTable(TableFactor::Convert(keys, potentials)))) {
}
/* ************************************************************************** */
void TableDistribution::print(const string& s,
const KeyFormatter& formatter) const {
cout << s << " P( ";
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
cout << formatter(*it) << " ";
}
cout << "):\n";
table_.print("", formatter);
cout << endl;
}
/* ************************************************************************** */
bool TableDistribution::equals(const DiscreteFactor& other, double tol) const {
auto dtc = dynamic_cast<const TableDistribution*>(&other);
if (!dtc) {
return false;
} else {
const DiscreteConditional& f(
static_cast<const DiscreteConditional&>(other));
return table_.equals(dtc->table_, tol) &&
DiscreteConditional::BaseConditional::equals(f, tol);
}
}
/* ****************************************************************************/
DiscreteFactor::shared_ptr TableDistribution::sum(size_t nrFrontals) const {
return table_.sum(nrFrontals);
}
/* ****************************************************************************/
DiscreteFactor::shared_ptr TableDistribution::sum(const Ordering& keys) const {
return table_.sum(keys);
}
/* ****************************************************************************/
DiscreteFactor::shared_ptr TableDistribution::max(size_t nrFrontals) const {
return table_.max(nrFrontals);
}
/* ****************************************************************************/
DiscreteFactor::shared_ptr TableDistribution::max(const Ordering& keys) const {
return table_.max(keys);
}
/* ****************************************************************************/
DiscreteFactor::shared_ptr TableDistribution::operator/(
const DiscreteFactor::shared_ptr& f) const {
return table_ / f;
}
/* ************************************************************************ */
DiscreteValues TableDistribution::argmax() const {
uint64_t maxIdx = 0;
double maxValue = 0.0;
Eigen::SparseVector<double> sparseTable = table_.sparseTable();
for (SparseIt it(sparseTable); it; ++it) {
if (it.value() > maxValue) {
maxIdx = it.index();
maxValue = it.value();
}
}
return table_.findAssignments(maxIdx);
}
/* ****************************************************************************/
void TableDistribution::prune(size_t maxNrAssignments) {
table_ = table_.prune(maxNrAssignments);
}
/* ****************************************************************************/
size_t TableDistribution::sample(const DiscreteValues& parentsValues) const {
static mt19937 rng(2); // random number generator
DiscreteKeys parentsKeys;
for (auto&& [key, _] : parentsValues) {
parentsKeys.push_back({key, table_.cardinality(key)});
}
// Get the correct conditional distribution: P(F|S=parentsValues)
TableFactor pFS = table_.choose(parentsValues, parentsKeys);
// TODO(Duy): only works for one key now, seems horribly slow this way
if (nrFrontals() != 1) {
throw std::invalid_argument(
"TableDistribution::sample can only be called on single variable "
"conditionals");
}
Key key = firstFrontalKey();
size_t nj = cardinality(key);
vector<double> p(nj);
DiscreteValues frontals;
for (size_t value = 0; value < nj; value++) {
frontals[key] = value;
p[value] = pFS(frontals); // P(F=value|S=parentsValues)
if (p[value] == 1.0) {
return value; // shortcut exit
}
}
std::discrete_distribution<size_t> distribution(p.begin(), p.end());
return distribution(rng);
}
} // namespace gtsam

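For readers who have not used the class yet, here is a small usage sketch, not taken from the PR; the key, cardinality, and potentials are invented. It exercises the normalizing constructor together with argmax, evaluate, sample, and prune defined above:

#include <gtsam/discrete/TableDistribution.h>

#include <iostream>

using namespace gtsam;

int main() {
  DiscreteKey m(0, 2);
  // Potentials are normalized internally (see normalizeSparseTable above).
  TableDistribution p(m, "60 40");

  DiscreteValues best = p.argmax();                      // assignment with the largest probability
  double p0 = p.evaluate(DiscreteValues{{m.first, 0}});  // should be close to 0.6
  size_t draw = p.sample(DiscreteValues());              // sample the single frontal variable

  p.prune(1);  // keep only the largest table entry
  std::cout << "P(m=0) = " << p0 << ", argmax = " << best.at(m.first)
            << ", sample = " << draw << std::endl;
  return 0;
}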
@@ -0,0 +1,177 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file TableDistribution.h
* @date Dec 22, 2024
* @author Varun Agrawal
*/
#pragma once
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/inference/Conditional-inst.h>
#include <memory>
#include <string>
#include <vector>
namespace gtsam {
/**
* Distribution which uses a SparseVector as the internal
* representation, similar to the TableFactor.
*
* This is primarily used in the case when we have a clique in the BayesTree
* which consists of all the discrete variables, e.g. in hybrid elimination.
*
* @ingroup discrete
*/
class GTSAM_EXPORT TableDistribution : public DiscreteConditional {
private:
TableFactor table_;
typedef Eigen::SparseVector<double>::InnerIterator SparseIt;
public:
// typedefs needed to play nice with gtsam
typedef TableDistribution This; ///< Typedef to this class
typedef std::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
typedef DiscreteConditional
BaseConditional; ///< Typedef to our conditional base class
using Values = DiscreteValues; ///< backwards compatibility
/// @name Standard Constructors
/// @{
/// Default constructor needed for serialization.
TableDistribution() {}
/// Construct from TableFactor.
TableDistribution(const TableFactor& f);
/**
* Construct from DiscreteKeys and std::vector.
*/
TableDistribution(const DiscreteKeys& keys,
const std::vector<double>& potentials);
/**
* Construct from single DiscreteKey and std::vector.
*/
TableDistribution(const DiscreteKey& key,
const std::vector<double>& potentials)
: TableDistribution(DiscreteKeys(key), potentials) {}
/**
* Construct from DiscreteKey and std::string.
*/
TableDistribution(const DiscreteKeys& keys, const std::string& potentials);
/**
* Construct from single DiscreteKey and std::string.
*/
TableDistribution(const DiscreteKey& key, const std::string& potentials)
: TableDistribution(DiscreteKeys(key), potentials) {}
/// @}
/// @name Testable
/// @{
/// GTSAM-style print
void print(
const std::string& s = "Table Distribution: ",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/// GTSAM-style equals
bool equals(const DiscreteFactor& other, double tol = 1e-9) const override;
/// @}
/// @name Standard Interface
/// @{
/// Return the underlying TableFactor
TableFactor table() const { return table_; }
using BaseConditional::evaluate; // HybridValues version
/// Evaluate the conditional given the values.
virtual double evaluate(const Assignment<Key>& values) const override {
return table_.evaluate(values);
}
/// Create new factor by summing all values with the same separator values
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override;
/// Create new factor by summing all values with the same separator values
DiscreteFactor::shared_ptr sum(const Ordering& keys) const override;
/// Create new factor by maximizing over all values with the same separator.
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override;
/// Create new factor by maximizing over all values with the same separator.
DiscreteFactor::shared_ptr max(const Ordering& keys) const override;
/// divide by DiscreteFactor::shared_ptr f (safely)
DiscreteFactor::shared_ptr operator/(
const DiscreteFactor::shared_ptr& f) const override;
/**
* @brief Return assignment that maximizes value.
*
* @return maximizing assignment for the variables.
*/
DiscreteValues argmax() const;
/**
* sample
* @param parentsValues Known values of the parents
* @return sample from conditional
*/
virtual size_t sample(const DiscreteValues& parentsValues) const override;
/// @}
/// @name Advanced Interface
/// @{
/// Prune the conditional
virtual void prune(size_t maxNrAssignments) override;
/// Get a DecisionTreeFactor representation.
DecisionTreeFactor toDecisionTreeFactor() const override {
return table_.toDecisionTreeFactor();
}
/// Get the number of non-zero values.
uint64_t nrValues() const override { return table_.sparseTable().nonZeros(); }
/// @}
private:
#if GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */
friend class boost::serialization::access;
template <class Archive>
void serialize(Archive& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
ar& BOOST_SERIALIZATION_NVP(table_);
}
#endif
};
// TableDistribution
// traits
template <>
struct traits<TableDistribution> : public Testable<TableDistribution> {};
} // namespace gtsam

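Because prune, sample, max, sum, and the other overrides above go through virtuals on the discrete base classes (sample and the new prune were made virtual on DiscreteConditional earlier in this PR), hybrid code can keep holding the conditional through the base class. A hedged sketch with invented keys and potentials:

#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/TableDistribution.h>

#include <iostream>

using namespace gtsam;

int main() {
  DiscreteKey m(0, 2), n(1, 2);
  auto table = std::make_shared<TableDistribution>(m & n, "1 2 3 4");

  // Held through the base class, as HybridBayesNet / HybridBayesTree do.
  DiscreteConditional::shared_ptr conditional = table;
  conditional->prune(2);  // virtual dispatch into TableDistribution::prune

  // Sparse count after pruning, and a dense fallback when one is needed.
  std::cout << "non-zero entries: " << table->nrValues() << std::endl;
  DecisionTreeFactor dense = conditional->toDecisionTreeFactor();
  dense.print("dense view: ");
  return 0;
}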
@@ -87,6 +87,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
}
+ public:
/**
* Convert probability table given as doubles to SparseVector.
* Example: {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5}

@@ -98,7 +99,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
static Eigen::SparseVector<double> Convert(const DiscreteKeys& keys,
const std::string& table);
- public:
// typedefs needed to play nice with gtsam
typedef TableFactor This;
typedef DiscreteFactor Base; ///< Typedef to base class

@@ -211,7 +211,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
DecisionTreeFactor toDecisionTreeFactor() const override;
/// Create a TableFactor that is a subset of this TableFactor
- TableFactor choose(const DiscreteValues assignments,
+ TableFactor choose(const DiscreteValues parentAssignments,
DiscreteKeys parent_keys) const;
/// Create new factor by summing all values with the same separator values

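Two small notes on this header: the moved public: makes the static Convert helpers reachable from TableDistribution's constructors, and choose only gained a clearer argument name. A hedged sketch of choose, which slices the table at fixed parent values (this is what TableDistribution::sample relies on); the keys and numbers are invented:

#include <gtsam/discrete/TableFactor.h>

using namespace gtsam;

int main() {
  DiscreteKey m(0, 2), n(1, 2);
  TableFactor joint(m & n, "0.1 0.9 0.4 0.6");

  // Fix the parent n = 1, leaving the slice over m.
  DiscreteValues parentAssignments{{n.first, 1}};
  TableFactor sliced = joint.choose(parentAssignments, DiscreteKeys(n));
  sliced.print("table sliced at n = 1: ");
  return 0;
}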
@@ -168,6 +168,43 @@ virtual class DiscreteDistribution : gtsam::DiscreteConditional {
std::vector<double> pmf() const;
};
+ #include <gtsam/discrete/TableFactor.h>
+ virtual class TableFactor : gtsam::DiscreteFactor {
+ TableFactor();
+ TableFactor(const gtsam::DiscreteKeys& keys,
+ const gtsam::TableFactor& potentials);
+ TableFactor(const gtsam::DiscreteKeys& keys, std::vector<double>& table);
+ TableFactor(const gtsam::DiscreteKeys& keys, string spec);
+ TableFactor(const gtsam::DiscreteKeys& keys,
+ const gtsam::DecisionTreeFactor& dtf);
+ TableFactor(const gtsam::DecisionTreeFactor& dtf);
+ void print(string s = "TableFactor\n",
+ const gtsam::KeyFormatter& keyFormatter =
+ gtsam::DefaultKeyFormatter) const;
+ double evaluate(const gtsam::DiscreteValues& values) const;
+ double error(const gtsam::DiscreteValues& values) const;
+ };
+ #include <gtsam/discrete/TableDistribution.h>
+ virtual class TableDistribution : gtsam::DiscreteConditional {
+ TableDistribution();
+ TableDistribution(const gtsam::TableFactor& f);
+ TableDistribution(const gtsam::DiscreteKey& key, std::vector<double> spec);
+ TableDistribution(const gtsam::DiscreteKeys& keys, std::vector<double> spec);
+ TableDistribution(const gtsam::DiscreteKeys& keys, string spec);
+ TableDistribution(const gtsam::DiscreteKey& key, string spec);
+ void print(string s = "Table Distribution\n",
+ const gtsam::KeyFormatter& keyFormatter =
+ gtsam::DefaultKeyFormatter) const;
+ gtsam::TableFactor table() const;
+ double evaluate(const gtsam::DiscreteValues& values) const;
+ size_t nrValues() const;
+ };
#include <gtsam/discrete/DiscreteBayesNet.h>
class DiscreteBayesNet {
DiscreteBayesNet();

@@ -19,6 +19,7 @@
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
+ #include <gtsam/discrete/TableDistribution.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridValues.h>

@@ -55,12 +56,15 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
joint = joint * (*conditional);
}
- // Prune the joint. NOTE: again, possibly quite expensive.
- const DecisionTreeFactor pruned = joint.prune(maxNrLeaves);
- // Create a the result starting with the pruned joint.
+ // Create the result starting with the pruned joint.
HybridBayesNet result;
- result.emplace_shared<DiscreteConditional>(pruned.size(), pruned);
+ result.emplace_shared<DiscreteConditional>(joint);
+ // Prune the joint. NOTE: imperative and, again, possibly quite expensive.
+ result.back()->asDiscrete()->prune(maxNrLeaves);
+ // Get pruned discrete probabilities so
+ // we can prune HybridGaussianConditionals.
+ DiscreteConditional pruned = *result.back()->asDiscrete();
/* To prune, we visitWith every leaf in the HybridGaussianConditional.
* For each leaf, using the assignment we can check the discrete decision tree

@@ -126,7 +130,14 @@ HybridValues HybridBayesNet::optimize() const {
for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
- discrete_fg.push_back(conditional->asDiscrete());
+ if (auto dtc = conditional->asDiscrete<TableDistribution>()) {
+ // The number of keys should be small so should not
+ // be expensive to convert to DiscreteConditional.
+ discrete_fg.push_back(DiscreteConditional(dtc->nrFrontals(),
+ dtc->toDecisionTreeFactor()));
+ } else {
+ discrete_fg.push_back(conditional->asDiscrete());
+ }
}
}

@@ -20,6 +20,7 @@
#include <gtsam/base/treeTraversal-inst.h>
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
+ #include <gtsam/discrete/TableDistribution.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/hybrid/HybridConditional.h>

@@ -41,6 +42,22 @@ bool HybridBayesTree::equals(const This& other, double tol) const {
return Base::equals(other, tol);
}
+ /* ************************************************************************* */
+ DiscreteValues HybridBayesTree::discreteMaxProduct(
+ const DiscreteFactorGraph& dfg) const {
+ DiscreteFactor::shared_ptr product = dfg.scaledProduct();
+ // Check type of product, and get as TableFactor for efficiency.
+ TableFactor p;
+ if (auto tf = std::dynamic_pointer_cast<TableFactor>(product)) {
+ p = *tf;
+ } else {
+ p = TableFactor(product->toDecisionTreeFactor());
+ }
+ DiscreteValues assignment = TableDistribution(p).argmax();
+ return assignment;
+ }
/* ************************************************************************* */
HybridValues HybridBayesTree::optimize() const {
DiscreteFactorGraph discrete_fg;

@@ -52,8 +69,9 @@ HybridValues HybridBayesTree::optimize() const {
// The root should be discrete only, we compute the MPE
if (root_conditional->isDiscrete()) {
- discrete_fg.push_back(root_conditional->asDiscrete());
- mpe = discrete_fg.optimize();
+ auto discrete = root_conditional->asDiscrete<TableDistribution>();
+ discrete_fg.push_back(discrete);
+ mpe = discreteMaxProduct(discrete_fg);
} else {
throw std::runtime_error(
"HybridBayesTree root is not discrete-only. Please check elimination "

@@ -179,16 +197,17 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
/* ************************************************************************* */
void HybridBayesTree::prune(const size_t maxNrLeaves) {
- auto discreteProbs = this->roots_.at(0)->conditional()->asDiscrete();
- DecisionTreeFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves);
- discreteProbs->root_ = prunedDiscreteProbs.root_;
+ auto prunedDiscreteProbs =
+ this->roots_.at(0)->conditional()->asDiscrete<TableDistribution>();
+ // Imperative pruning
+ prunedDiscreteProbs->prune(maxNrLeaves);
/// Helper struct for pruning the hybrid bayes tree.
struct HybridPrunerData {
/// The discrete decision tree after pruning.
- DecisionTreeFactor prunedDiscreteProbs;
+ DiscreteConditional::shared_ptr prunedDiscreteProbs;
- HybridPrunerData(const DecisionTreeFactor& prunedDiscreteProbs,
+ HybridPrunerData(const DiscreteConditional::shared_ptr& prunedDiscreteProbs,
const HybridBayesTree::sharedNode& parentClique)
: prunedDiscreteProbs(prunedDiscreteProbs) {}

@@ -213,7 +232,7 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
if (!hybridGaussianCond->pruned()) {
// Imperative
clique->conditional() = std::make_shared<HybridConditional>(
- hybridGaussianCond->prune(parentData.prunedDiscreteProbs));
+ hybridGaussianCond->prune(*parentData.prunedDiscreteProbs));
}
}
return parentData;

@@ -115,6 +115,10 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
/// @}
private:
+ /// Helper method to compute the max product assignment
+ /// given a DiscreteFactorGraph
+ DiscreteValues discreteMaxProduct(const DiscreteFactorGraph& dfg) const;
#if GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */
friend class boost::serialization::access;

@@ -166,12 +166,13 @@ class GTSAM_EXPORT HybridConditional
}
/**
- * @brief Return conditional as a DiscreteConditional
+ * @brief Return conditional as a DiscreteConditional or specified type T.
* @return nullptr if not a DiscreteConditional
* @return DiscreteConditional::shared_ptr
*/
- DiscreteConditional::shared_ptr asDiscrete() const {
- return std::dynamic_pointer_cast<DiscreteConditional>(inner_);
+ template <typename T = DiscreteConditional>
+ typename T::shared_ptr asDiscrete() const {
+ return std::dynamic_pointer_cast<T>(inner_);
}
/// Get the type-erased pointer to the inner type

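asDiscrete() is now a template whose parameter defaults to DiscreteConditional, so existing call sites compile unchanged while hybrid code can ask for the concrete TableDistribution. A hedged sketch of the pattern used throughout the hybrid changes below; the helper name is invented:

#include <gtsam/discrete/TableDistribution.h>
#include <gtsam/hybrid/HybridConditional.h>

using namespace gtsam;

// Hypothetical helper: print the discrete part of a hybrid conditional,
// preferring the sparse TableDistribution view when it is available.
void printDiscretePart(const HybridConditional& hc) {
  if (auto table = hc.asDiscrete<TableDistribution>()) {
    table->print("table distribution: ");
  } else if (auto dc = hc.asDiscrete()) {
    dc->print("discrete conditional: ");
  }
}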
@@ -304,7 +304,7 @@ std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
/* *******************************************************************************/
HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
- const DecisionTreeFactor &discreteProbs) const {
+ const DiscreteConditional &discreteProbs) const {
// Find keys in discreteProbs.keys() but not in this->keys():
std::set<Key> mine(this->keys().begin(), this->keys().end());
std::set<Key> theirs(discreteProbs.keys().begin(),

@@ -23,6 +23,7 @@
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
+ #include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>

@@ -235,7 +236,7 @@ class GTSAM_EXPORT HybridGaussianConditional
* @return Shared pointer to possibly a pruned HybridGaussianConditional
*/
HybridGaussianConditional::shared_ptr prune(
- const DecisionTreeFactor &discreteProbs) const;
+ const DiscreteConditional &discreteProbs) const;
/// Return true if the conditional has already been pruned.
bool pruned() const { return pruned_; }

@@ -25,6 +25,7 @@
#include <gtsam/discrete/DiscreteJunctionTree.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h>
+ #include <gtsam/discrete/TableDistribution.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridEliminationTree.h>

@@ -255,46 +256,6 @@ static TableFactor::shared_ptr DiscreteFactorFromErrors(
return std::make_shared<TableFactor>(discreteKeys, potentials);
}
- /**
- * @brief Multiply all the `factors` using the machinery of the TableFactor.
- *
- * @param factors The factors to multiply as a DiscreteFactorGraph.
- * @return TableFactor
- */
- static TableFactor TableProduct(const DiscreteFactorGraph &factors) {
- // PRODUCT: multiply all factors
- #if GTSAM_HYBRID_TIMING
- gttic_(DiscreteProduct);
- #endif
- TableFactor product;
- for (auto &&factor : factors) {
- if (factor) {
- if (auto f = std::dynamic_pointer_cast<TableFactor>(factor)) {
- product = product * (*f);
- } else if (auto dtf =
- std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
- product = product * TableFactor(*dtf);
- }
- }
- }
- #if GTSAM_HYBRID_TIMING
- gttoc_(DiscreteProduct);
- #endif
- #if GTSAM_HYBRID_TIMING
- gttic_(DiscreteNormalize);
- #endif
- // Max over all the potentials by pretending all keys are frontal:
- auto denominator = product.max(product.size());
- // Normalize the product factor to prevent underflow.
- product = product / (*denominator);
- #if GTSAM_HYBRID_TIMING
- gttoc_(DiscreteNormalize);
- #endif
- return product;
- }
/* ************************************************************************ */
static DiscreteFactorGraph CollectDiscreteFactors(
const HybridGaussianFactorGraph &factors) {

@@ -325,12 +286,17 @@ static DiscreteFactorGraph CollectDiscreteFactors(
#if GTSAM_HYBRID_TIMING
gttic_(ConvertConditionalToTableFactor);
#endif
- // Convert DiscreteConditional to TableFactor
- auto tdc = std::make_shared<TableFactor>(*dc);
+ if (auto dtc = std::dynamic_pointer_cast<TableDistribution>(dc)) {
+ /// Get the underlying TableFactor
+ dfg.push_back(dtc->table());
+ } else {
+ // Convert DiscreteConditional to TableFactor
+ auto tdc = std::make_shared<TableFactor>(*dc);
+ dfg.push_back(tdc);
+ }
#if GTSAM_HYBRID_TIMING
gttoc_(ConvertConditionalToTableFactor);
#endif
- dfg.push_back(tdc);
} else {
throwRuntimeError("discreteElimination", f);
}

@@ -355,21 +321,24 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
// so we can use the TableFactor for efficiency.
if (frontalKeys.size() == dfg.keys().size()) {
// Get product factor
- TableFactor product = TableProduct(dfg);
+ DiscreteFactor::shared_ptr product = dfg.scaledProduct();
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteFormDiscreteConditional);
#endif
- auto conditional = std::make_shared<DiscreteConditional>(
- frontalKeys.size(), product.toDecisionTreeFactor());
+ // Check type of product, and get as TableFactor for efficiency.
+ TableFactor p;
+ if (auto tf = std::dynamic_pointer_cast<TableFactor>(product)) {
+ p = *tf;
+ } else {
+ p = TableFactor(product->toDecisionTreeFactor());
+ }
+ auto conditional = std::make_shared<TableDistribution>(p);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteFormDiscreteConditional);
#endif
- TableFactor::shared_ptr sum = product.sum(frontalKeys);
- #if GTSAM_HYBRID_TIMING
- gttoc_(EliminateDiscrete);
- #endif
+ DiscreteFactor::shared_ptr sum = product->sum(frontalKeys);
return {std::make_shared<HybridConditional>(conditional), sum};

@@ -378,6 +347,9 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
auto result = EliminateDiscrete(dfg, frontalKeys);
return {std::make_shared<HybridConditional>(result.first), result.second};
}
+ #if GTSAM_HYBRID_TIMING
+ gttoc_(EliminateDiscrete);
+ #endif
}
/* ************************************************************************ */

@@ -15,6 +15,7 @@
* @author Varun Agrawal
*/
+ #include <gtsam/discrete/TableDistribution.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/hybrid/HybridNonlinearFactor.h>
#include <gtsam/hybrid/HybridNonlinearISAM.h>

@@ -65,7 +66,14 @@ void HybridNonlinearISAM::reorderRelinearize() {
// Obtain the new linearization point
const Values newLinPoint = estimate();
- auto discreteProbs = *(isam_.roots().at(0)->conditional()->asDiscrete());
+ DiscreteConditional::shared_ptr discreteProbabilities;
+ auto discreteRoot = isam_.roots().at(0)->conditional();
+ if (discreteRoot->asDiscrete<TableDistribution>()) {
+ discreteProbabilities = discreteRoot->asDiscrete<TableDistribution>();
+ } else {
+ discreteProbabilities = discreteRoot->asDiscrete();
+ }
isam_.clear();

@@ -73,7 +81,7 @@ void HybridNonlinearISAM::reorderRelinearize() {
HybridNonlinearFactorGraph pruned_factors;
for (auto&& factor : factors_) {
if (auto nf = std::dynamic_pointer_cast<HybridNonlinearFactor>(factor)) {
- pruned_factors.push_back(nf->prune(discreteProbs));
+ pruned_factors.push_back(nf->prune(*discreteProbabilities));
} else {
pruned_factors.push_back(factor);
}

@@ -20,6 +20,7 @@
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteKey.h>
+ #include <gtsam/discrete/TableDistribution.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridGaussianConditional.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>

@@ -79,8 +80,8 @@ TEST(GaussianMixture, GaussianMixtureModel) {
double midway = mu1 - mu0;
auto eliminationResult =
gmm.toFactorGraph({{Z(0), Vector1(midway)}}).eliminateSequential();
- auto pMid = *eliminationResult->at(0)->asDiscrete();
- EXPECT(assert_equal(DiscreteConditional(m, "60/40"), pMid));
+ auto pMid = eliminationResult->at(0)->asDiscrete<TableDistribution>();
+ EXPECT(assert_equal(TableDistribution(m, "60 40"), *pMid));
// Everywhere else, the result should be a sigmoid.
for (const double shift : {-4, -2, 0, 2, 4}) {

@@ -90,7 +91,8 @@ TEST(GaussianMixture, GaussianMixtureModel) {
// Workflow 1: convert HBN to HFG and solve
auto eliminationResult1 =
gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential();
- auto posterior1 = *eliminationResult1->at(0)->asDiscrete();
+ auto posterior1 =
+ *eliminationResult1->at(0)->asDiscrete<TableDistribution>();
EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);
// Workflow 2: directly specify HFG and solve

@@ -99,7 +101,8 @@ TEST(GaussianMixture, GaussianMixtureModel) {
m, std::vector{Gaussian(mu0, sigma, z), Gaussian(mu1, sigma, z)});
hfg1.push_back(mixing);
auto eliminationResult2 = hfg1.eliminateSequential();
- auto posterior2 = *eliminationResult2->at(0)->asDiscrete();
+ auto posterior2 =
+ *eliminationResult2->at(0)->asDiscrete<TableDistribution>();
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
}
}

@@ -133,13 +136,13 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
// Eliminate the graph!
auto eliminationResultMax = gfg.eliminateSequential();
- // Equality of posteriors asserts that the elimination is correct (same ratios
- // for all modes)
+ // Equality of posteriors asserts that the elimination is correct
+ // (same ratios for all modes)
EXPECT(assert_equal(expectedDiscretePosterior,
eliminationResultMax->discretePosterior(vv)));
- auto pMax = *eliminationResultMax->at(0)->asDiscrete();
- EXPECT(assert_equal(DiscreteConditional(m, "42/58"), pMax, 1e-4));
+ auto pMax = *eliminationResultMax->at(0)->asDiscrete<TableDistribution>();
+ EXPECT(assert_equal(TableDistribution(m, "42 58"), pMax, 1e-4));
// Everywhere else, the result should be a bell curve like function.
for (const double shift : {-4, -2, 0, 2, 4}) {

@@ -149,7 +152,8 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
// Workflow 1: convert HBN to HFG and solve
auto eliminationResult1 =
gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential();
- auto posterior1 = *eliminationResult1->at(0)->asDiscrete();
+ auto posterior1 =
+ *eliminationResult1->at(0)->asDiscrete<TableDistribution>();
EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);
// Workflow 2: directly specify HFG and solve

@@ -158,7 +162,8 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
m, std::vector{Gaussian(mu0, sigma0, z), Gaussian(mu1, sigma1, z)});
hfg.push_back(mixing);
auto eliminationResult2 = hfg.eliminateSequential();
- auto posterior2 = *eliminationResult2->at(0)->asDiscrete();
+ auto posterior2 =
+ *eliminationResult2->at(0)->asDiscrete<TableDistribution>();
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
}
}

@@ -20,6 +20,7 @@
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DiscreteFactor.h>
+ #include <gtsam/discrete/TableDistribution.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/hybrid/HybridConditional.h>

@@ -454,7 +455,8 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
}
size_t maxNrLeaves = 3;
- auto prunedDecisionTree = joint.prune(maxNrLeaves);
+ DiscreteConditional prunedDecisionTree(joint);
+ prunedDecisionTree.prune(maxNrLeaves);
#ifdef GTSAM_DT_MERGING
EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,

@@ -16,6 +16,7 @@
*/
#include <gtsam/discrete/DiscreteBayesNet.h>
+ #include <gtsam/discrete/TableDistribution.h>
#include <gtsam/geometry/Pose2.h>
#include <gtsam/geometry/Pose3.h>
#include <gtsam/hybrid/HybridBayesNet.h>

@@ -464,14 +465,14 @@ TEST(HybridEstimation, EliminateSequentialRegression) {
// Create expected discrete conditional on m0.
DiscreteKey m(M(0), 2);
- DiscreteConditional expected(m % "0.51341712/1"); // regression
+ TableDistribution expected(m, "0.51341712 1"); // regression
// Eliminate into BN using one ordering
const Ordering ordering1{X(0), X(1), M(0)};
HybridBayesNet::shared_ptr bn1 = fg->eliminateSequential(ordering1);
// Check that the discrete conditional matches the expected.
- auto dc1 = bn1->back()->asDiscrete();
+ auto dc1 = bn1->back()->asDiscrete<TableDistribution>();
EXPECT(assert_equal(expected, *dc1, 1e-9));
// Eliminate into BN using a different ordering

@@ -479,7 +480,7 @@ TEST(HybridEstimation, EliminateSequentialRegression) {
HybridBayesNet::shared_ptr bn2 = fg->eliminateSequential(ordering2);
// Check that the discrete conditional matches the expected.
- auto dc2 = bn2->back()->asDiscrete();
+ auto dc2 = bn2->back()->asDiscrete<TableDistribution>();
EXPECT(assert_equal(expected, *dc2, 1e-9));
}

@@ -261,7 +261,8 @@ TEST(HybridGaussianConditional, Prune) {
potentials[i] = 1;
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
// Prune the HybridGaussianConditional
- const auto pruned = hgc.prune(decisionTreeFactor);
+ const auto pruned =
+ hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
// Check that the pruned HybridGaussianConditional has 1 conditional
EXPECT_LONGS_EQUAL(1, pruned->nrComponents());
}

@@ -271,7 +272,8 @@ TEST(HybridGaussianConditional, Prune) {
0, 0, 0.5, 0};
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
- const auto pruned = hgc.prune(decisionTreeFactor);
+ const auto pruned =
+ hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
// Check that the pruned HybridGaussianConditional has 2 conditionals
EXPECT_LONGS_EQUAL(2, pruned->nrComponents());

@@ -286,7 +288,8 @@ TEST(HybridGaussianConditional, Prune) {
0, 0, 0.5, 0};
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
- const auto pruned = hgc.prune(decisionTreeFactor);
+ const auto pruned =
+ hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
// Check that the pruned HybridGaussianConditional has 3 conditionals
EXPECT_LONGS_EQUAL(3, pruned->nrComponents());

@@ -25,6 +25,7 @@
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h>
+ #include <gtsam/discrete/TableDistribution.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridFactor.h>

@@ -650,7 +651,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) {
mode, std::vector{conditional0, conditional1});
// Add prior on mode.
- expectedBayesNet.emplace_shared<DiscreteConditional>(mode, "74/26");
+ expectedBayesNet.emplace_shared<TableDistribution>(mode, "74 26");
// Test elimination
const auto posterior = fg.eliminateSequential();

@@ -700,11 +701,11 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1Swapped) {
m1, std::vector{conditional0, conditional1});
// Add prior on m1.
- expectedBayesNet.emplace_shared<DiscreteConditional>(m1, "1/1");
+ expectedBayesNet.emplace_shared<TableDistribution>(m1, "0.188638 0.811362");
// Test elimination
const auto posterior = fg.eliminateSequential();
- // EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
+ EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
EXPECT(ratioTest(bn, measurements, *posterior));

@@ -736,7 +737,9 @@ TEST(HybridGaussianFactorGraph, EliminateTiny2) {
mode, std::vector{conditional0, conditional1});
// Add prior on mode.
- expectedBayesNet.emplace_shared<DiscreteConditional>(mode, "23/77");
+ // Since this is the only discrete conditional, it is added as a
+ // TableDistribution.
+ expectedBayesNet.emplace_shared<TableDistribution>(mode, "23 77");
// Test elimination
const auto posterior = fg.eliminateSequential();

@@ -19,6 +19,7 @@
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
+ #include <gtsam/discrete/TableDistribution.h>
#include <gtsam/geometry/Pose2.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridGaussianISAM.h>

@@ -141,7 +142,7 @@ TEST(HybridGaussianISAM, IncrementalInference) {
expectedRemainingGraph->eliminateMultifrontal(discreteOrdering);
// Test the probability values with regression tests.
- auto discrete = isam[M(1)]->conditional()->asDiscrete();
+ auto discrete = isam[M(1)]->conditional()->asDiscrete<TableDistribution>();
EXPECT(assert_equal(0.095292, (*discrete)({{M(0), 0}, {M(1), 0}}), 1e-5));
EXPECT(assert_equal(0.282758, (*discrete)({{M(0), 1}, {M(1), 0}}), 1e-5));
EXPECT(assert_equal(0.314175, (*discrete)({{M(0), 0}, {M(1), 1}}), 1e-5));

@@ -221,16 +222,12 @@ TEST(HybridGaussianISAM, ApproxInference) {
1 1 1 Leaf 0.5
*/
- auto discreteConditional_m0 = *dynamic_pointer_cast<DiscreteConditional>(
+ auto discreteConditional_m0 = *dynamic_pointer_cast<TableDistribution>(
incrementalHybrid[M(0)]->conditional()->inner());
EXPECT(discreteConditional_m0.keys() == KeyVector({M(0), M(1), M(2)}));
- // Get the number of elements which are greater than 0.
- auto count = [](const double &value, int count) {
- return value > 0 ? count + 1 : count;
- };
// Check that the number of leaves after pruning is 5.
- EXPECT_LONGS_EQUAL(5, discreteConditional_m0.fold(count, 0));
+ EXPECT_LONGS_EQUAL(5, discreteConditional_m0.nrValues());
// Check that the hybrid nodes of the bayes net match those of the pre-pruning
// bayes net, at the same positions.

@@ -477,7 +474,9 @@ TEST(HybridGaussianISAM, NonTrivial) {
// Test if the optimal discrete mode assignment is (1, 1, 1).
DiscreteFactorGraph discreteGraph;
- discreteGraph.push_back(discreteTree);
+ // discreteTree is a TableDistribution, so we convert to
+ // DecisionTreeFactor for the DiscreteFactorGraph
+ discreteGraph.push_back(discreteTree->toDecisionTreeFactor());
DiscreteValues optimal_assignment = discreteGraph.optimize();
DiscreteValues expected_assignment;

View File

@ -22,6 +22,7 @@
#include <gtsam/base/TestableAssertions.h> #include <gtsam/base/TestableAssertions.h>
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteValues.h> #include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/discrete/TableDistribution.h>
#include <gtsam/hybrid/HybridBayesNet.h> #include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridGaussianConditional.h> #include <gtsam/hybrid/HybridGaussianConditional.h>
#include <gtsam/hybrid/HybridGaussianFactor.h> #include <gtsam/hybrid/HybridGaussianFactor.h>
@ -143,8 +144,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel) {
// Since no measurement on x1, we hedge our bets // Since no measurement on x1, we hedge our bets
// Importance sampling run with 100k samples gives 50.051/49.949 // Importance sampling run with 100k samples gives 50.051/49.949
// approximateDiscreteMarginal(hbn, hybridMotionModel, given); // approximateDiscreteMarginal(hbn, hybridMotionModel, given);
DiscreteConditional expected(m1, "50/50"); TableDistribution expected(m1, "50 50");
EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()))); EXPECT(
assert_equal(expected, *(bn->at(2)->asDiscrete<TableDistribution>())));
} }
{ {
@ -160,8 +162,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel) {
// Since we have a measurement on x1, we get a definite result // Since we have a measurement on x1, we get a definite result
// Values taken from an importance sampling run with 100k samples: // Values taken from an importance sampling run with 100k samples:
// approximateDiscreteMarginal(hbn, hybridMotionModel, given); // approximateDiscreteMarginal(hbn, hybridMotionModel, given);
DiscreteConditional expected(m1, "44.3854/55.6146"); TableDistribution expected(m1, "44.3854 55.6146");
EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 0.002)); EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete<TableDistribution>()),
0.02));
} }
} }
@ -248,8 +251,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel2) {
// Values taken from an importance sampling run with 100k samples: // Values taken from an importance sampling run with 100k samples:
// approximateDiscreteMarginal(hbn, hybridMotionModel, given); // approximateDiscreteMarginal(hbn, hybridMotionModel, given);
DiscreteConditional expected(m1, "48.3158/51.6842"); TableDistribution expected(m1, "48.3158 51.6842");
EXPECT(assert_equal(expected, *(eliminated->at(2)->asDiscrete()), 0.002)); EXPECT(assert_equal(
expected, *(eliminated->at(2)->asDiscrete<TableDistribution>()), 0.02));
} }
{ {
@ -263,8 +267,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel2) {
// Values taken from an importance sampling run with 100k samples: // Values taken from an importance sampling run with 100k samples:
// approximateDiscreteMarginal(hbn, hybridMotionModel, given); // approximateDiscreteMarginal(hbn, hybridMotionModel, given);
DiscreteConditional expected(m1, "55.396/44.604"); TableDistribution expected(m1, "55.396 44.604");
EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 0.002)); EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete<TableDistribution>()),
0.02));
} }
} }
@ -340,8 +345,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel3) {
// Values taken from an importance sampling run with 100k samples: // Values taken from an importance sampling run with 100k samples:
// approximateDiscreteMarginal(hbn, hybridMotionModel, given); // approximateDiscreteMarginal(hbn, hybridMotionModel, given);
DiscreteConditional expected(m1, "51.7762/48.2238"); TableDistribution expected(m1, "51.7762 48.2238");
EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 0.002)); EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete<TableDistribution>()),
0.02));
} }
{ {
@@ -355,8 +361,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel3) {
     // Values taken from an importance sampling run with 100k samples:
     // approximateDiscreteMarginal(hbn, hybridMotionModel, given);
-    DiscreteConditional expected(m1, "49.0762/50.9238");
-    EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 0.005));
+    TableDistribution expected(m1, "49.0762 50.9238");
+    EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete<TableDistribution>()),
+                        0.05));
   }
 }
@@ -381,8 +388,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel4) {
   // Values taken from an importance sampling run with 100k samples:
   // approximateDiscreteMarginal(hbn, hybridMotionModel, given);
-  DiscreteConditional expected(m1, "8.91527/91.0847");
-  EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 0.002));
+  TableDistribution expected(m1, "8.91527 91.0847");
+  EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete<TableDistribution>()),
+                      0.01));
 }
 /* ************************************************************************* */
@@ -537,8 +545,8 @@ TEST(HybridGaussianFactorGraph, DifferentCovariances) {
   DiscreteValues dv0{{M(1), 0}};
   DiscreteValues dv1{{M(1), 1}};
-  DiscreteConditional expected_m1(m1, "0.5/0.5");
-  DiscreteConditional actual_m1 = *(hbn->at(2)->asDiscrete());
+  TableDistribution expected_m1(m1, "0.5 0.5");
+  TableDistribution actual_m1 = *(hbn->at(2)->asDiscrete<TableDistribution>());
   EXPECT(assert_equal(expected_m1, actual_m1));
 }

View File

@@ -20,6 +20,7 @@
 #include <gtsam/discrete/DiscreteBayesNet.h>
 #include <gtsam/discrete/DiscreteDistribution.h>
 #include <gtsam/discrete/DiscreteFactorGraph.h>
+#include <gtsam/discrete/TableDistribution.h>
 #include <gtsam/geometry/Pose2.h>
 #include <gtsam/hybrid/HybridEliminationTree.h>
 #include <gtsam/hybrid/HybridFactor.h>
@@ -512,9 +513,10 @@ TEST(HybridNonlinearFactorGraph, Full_Elimination) {
   // P(m1)
   EXPECT(hybridBayesNet->at(4)->frontals() == KeyVector{M(1)});
   EXPECT_LONGS_EQUAL(0, hybridBayesNet->at(4)->nrParents());
-  EXPECT(
-      dynamic_pointer_cast<DiscreteConditional>(hybridBayesNet->at(4)->inner())
-          ->equals(*discreteBayesNet.at(1)));
+  TableDistribution dtc =
+      *hybridBayesNet->at(4)->asDiscrete<TableDistribution>();
+  EXPECT(DiscreteConditional(dtc.nrFrontals(), dtc.toDecisionTreeFactor())
+             .equals(*discreteBayesNet.at(1)));
 }
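The Full_Elimination change above also encodes the recipe for comparing the new table-backed posterior against a DiscreteConditional obtained from a purely discrete elimination: rebuild a DiscreteConditional from the table's decision-tree view. A sketch under the same assumptions (`hybridBayesNet` and `discreteBayesNet` come from the surrounding test):

```cpp
// Pull the discrete posterior P(m1) out of the hybrid Bayes net; after this
// PR it is stored as a TableDistribution.
TableDistribution dtc =
    *hybridBayesNet->at(4)->asDiscrete<TableDistribution>();

// Convert back to a DecisionTreeFactor and wrap it as a DiscreteConditional
// so the existing equals() comparison against the discrete-only result works.
DiscreteConditional asConditional(dtc.nrFrontals(), dtc.toDecisionTreeFactor());
EXPECT(asConditional.equals(*discreteBayesNet.at(1)));
```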
 /****************************************************************************
@@ -1061,8 +1063,8 @@ TEST(HybridNonlinearFactorGraph, DifferentCovariances) {
   DiscreteValues dv0{{M(1), 0}};
   DiscreteValues dv1{{M(1), 1}};
-  DiscreteConditional expected_m1(m1, "0.5/0.5");
-  DiscreteConditional actual_m1 = *(hbn->at(2)->asDiscrete());
+  TableDistribution expected_m1(m1, "0.5 0.5");
+  TableDistribution actual_m1 = *(hbn->at(2)->asDiscrete<TableDistribution>());
   EXPECT(assert_equal(expected_m1, actual_m1));
 }

View File

@@ -19,6 +19,7 @@
 #include <gtsam/discrete/DiscreteBayesNet.h>
 #include <gtsam/discrete/DiscreteDistribution.h>
 #include <gtsam/discrete/DiscreteFactorGraph.h>
+#include <gtsam/discrete/TableDistribution.h>
 #include <gtsam/geometry/Pose2.h>
 #include <gtsam/hybrid/HybridConditional.h>
 #include <gtsam/hybrid/HybridNonlinearISAM.h>
@@ -265,16 +266,12 @@ TEST(HybridNonlinearISAM, ApproxInference) {
    1 1 1 Leaf 0.5
   */
-  auto discreteConditional_m0 = *dynamic_pointer_cast<DiscreteConditional>(
+  auto discreteConditional_m0 = *dynamic_pointer_cast<TableDistribution>(
       bayesTree[M(0)]->conditional()->inner());
   EXPECT(discreteConditional_m0.keys() == KeyVector({M(0), M(1), M(2)}));
-  // Get the number of elements which are greater than 0.
-  auto count = [](const double &value, int count) {
-    return value > 0 ? count + 1 : count;
-  };
   // Check that the number of leaves after pruning is 5.
-  EXPECT_LONGS_EQUAL(5, discreteConditional_m0.fold(count, 0));
+  EXPECT_LONGS_EQUAL(5, discreteConditional_m0.nrValues());
   // Check that the hybrid nodes of the bayes net match those of the pre-pruning
   // bayes net, at the same positions.
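The counting change above is worth a note: with the decision-tree representation the test had to fold over the leaves and count nonzeros, whereas a TableDistribution stores only nonzero entries and can report their number directly via nrValues(). A hedged sketch of the two styles (the helper names are illustrative, not part of the library):

```cpp
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/TableDistribution.h>

using namespace gtsam;

// Old style: fold over the decision-tree leaves and count those > 0.
size_t CountNonzeroLeaves(const DiscreteConditional& conditional) {
  auto count = [](const double& value, int total) {
    return value > 0 ? total + 1 : total;
  };
  return conditional.fold(count, 0);
}

// New style: a TableDistribution keeps a sparse table, so the number of
// stored (nonzero) assignments is available directly.
size_t CountNonzeroEntries(const TableDistribution& table) {
  return table.nrValues();
}
```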
@@ -520,12 +517,13 @@ TEST(HybridNonlinearISAM, NonTrivial) {
   // The final discrete graph should not be empty since we have eliminated
   // all continuous variables.
-  auto discreteTree = bayesTree[M(3)]->conditional()->asDiscrete();
+  auto discreteTree =
+      bayesTree[M(3)]->conditional()->asDiscrete<TableDistribution>();
   EXPECT_LONGS_EQUAL(3, discreteTree->size());
   // Test if the optimal discrete mode assignment is (1, 1, 1).
   DiscreteFactorGraph discreteGraph;
-  discreteGraph.push_back(discreteTree);
+  discreteGraph.push_back(discreteTree->toDecisionTreeFactor());
   DiscreteValues optimal_assignment = discreteGraph.optimize();
   DiscreteValues expected_assignment;
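As the NonTrivial hunk shows, the table-backed conditional is no longer pushed into the discrete factor graph directly; it is first converted with toDecisionTreeFactor(). A minimal sketch of that flow, assuming `bayesTree` from the surrounding test:

```cpp
// Extract the pruned discrete posterior; after this PR it is a
// TableDistribution rather than a plain DiscreteConditional.
auto discreteTree =
    bayesTree[M(3)]->conditional()->asDiscrete<TableDistribution>();

// Convert the sparse table to its DecisionTreeFactor form before adding it
// to the factor graph, then solve for the most probable mode assignment.
DiscreteFactorGraph discreteGraph;
discreteGraph.push_back(discreteTree->toDecisionTreeFactor());
DiscreteValues optimalAssignment = discreteGraph.optimize();
```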

View File

@@ -18,6 +18,7 @@
 #include <gtsam/base/serializationTestHelpers.h>
 #include <gtsam/discrete/DiscreteConditional.h>
+#include <gtsam/discrete/TableDistribution.h>
 #include <gtsam/hybrid/HybridBayesNet.h>
 #include <gtsam/hybrid/HybridBayesTree.h>
 #include <gtsam/hybrid/HybridConditional.h>
@@ -44,6 +45,7 @@ BOOST_CLASS_EXPORT_GUID(HybridFactor, "gtsam_HybridFactor");
 BOOST_CLASS_EXPORT_GUID(JacobianFactor, "gtsam_JacobianFactor");
 BOOST_CLASS_EXPORT_GUID(GaussianConditional, "gtsam_GaussianConditional");
 BOOST_CLASS_EXPORT_GUID(DiscreteConditional, "gtsam_DiscreteConditional");
+BOOST_CLASS_EXPORT_GUID(TableDistribution, "gtsam_TableDistribution");
 BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor");
 using ADT = AlgebraicDecisionTree<Key>;

View File

@@ -13,14 +13,14 @@ Author: Fan Jiang, Varun Agrawal, Frank Dellaert
 import unittest
 import numpy as np
-from gtsam.symbol_shorthand import C, M, X, Z
-from gtsam.utils.test_case import GtsamTestCase
 import gtsam
-from gtsam import (DiscreteConditional, GaussianConditional,
-                   HybridBayesNet, HybridGaussianConditional,
-                   HybridGaussianFactor, HybridGaussianFactorGraph,
-                   HybridValues, JacobianFactor, noiseModel)
+from gtsam import (DiscreteConditional, GaussianConditional, HybridBayesNet,
+                   HybridGaussianConditional, HybridGaussianFactor,
+                   HybridGaussianFactorGraph, HybridValues, JacobianFactor,
+                   TableDistribution, noiseModel)
+from gtsam.symbol_shorthand import C, M, X, Z
+from gtsam.utils.test_case import GtsamTestCase
 DEBUG_MARGINALS = False
@@ -51,7 +51,7 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
         self.assertEqual(len(hybridCond.keys()), 2)
         discrete_conditional = hbn.at(hbn.size() - 1).inner()
-        self.assertIsInstance(discrete_conditional, DiscreteConditional)
+        self.assertIsInstance(discrete_conditional, TableDistribution)
     def test_optimize(self):
         """Test construction of hybrid factor graph."""