rename from DiscreteTableConditional to TableDistribution

release/4.3a0
Varun Agrawal 2025-01-03 14:51:32 -05:00
parent 0098112f27
commit 9b1918c085
14 changed files with 95 additions and 95 deletions

View File

@ -10,14 +10,14 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
/** /**
* @file DiscreteTableConditional.cpp * @file TableDistribution.cpp
* @date Dec 22, 2024 * @date Dec 22, 2024
* @author Varun Agrawal * @author Varun Agrawal
*/ */
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/base/debug.h> #include <gtsam/base/debug.h>
#include <gtsam/discrete/DiscreteTableConditional.h> #include <gtsam/discrete/TableDistribution.h>
#include <gtsam/discrete/Ring.h> #include <gtsam/discrete/Ring.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
#include <gtsam/hybrid/HybridValues.h> #include <gtsam/hybrid/HybridValues.h>
@ -38,42 +38,42 @@ using std::vector;
namespace gtsam { namespace gtsam {
/* ************************************************************************** */ /* ************************************************************************** */
DiscreteTableConditional::DiscreteTableConditional(const size_t nrFrontals, TableDistribution::TableDistribution(const size_t nrFrontals,
const TableFactor& f) const TableFactor& f)
: BaseConditional(nrFrontals, DecisionTreeFactor(f.discreteKeys(), ADT())), : BaseConditional(nrFrontals, DecisionTreeFactor(f.discreteKeys(), ADT())),
table_(f / (*f.sum(nrFrontals))) {} table_(f / (*f.sum(nrFrontals))) {}
/* ************************************************************************** */ /* ************************************************************************** */
DiscreteTableConditional::DiscreteTableConditional( TableDistribution::TableDistribution(
size_t nrFrontals, const DiscreteKeys& keys, size_t nrFrontals, const DiscreteKeys& keys,
const Eigen::SparseVector<double>& potentials) const Eigen::SparseVector<double>& potentials)
: BaseConditional(nrFrontals, keys, DecisionTreeFactor(keys, ADT())), : BaseConditional(nrFrontals, keys, DecisionTreeFactor(keys, ADT())),
table_(TableFactor(keys, potentials)) {} table_(TableFactor(keys, potentials)) {}
/* ************************************************************************** */ /* ************************************************************************** */
DiscreteTableConditional::DiscreteTableConditional(const TableFactor& joint, TableDistribution::TableDistribution(const TableFactor& joint,
const TableFactor& marginal) const TableFactor& marginal)
: BaseConditional(joint.size() - marginal.size(), : BaseConditional(joint.size() - marginal.size(),
joint.discreteKeys() & marginal.discreteKeys(), ADT()), joint.discreteKeys() & marginal.discreteKeys(), ADT()),
table_(joint / marginal) {} table_(joint / marginal) {}
/* ************************************************************************** */ /* ************************************************************************** */
DiscreteTableConditional::DiscreteTableConditional(const TableFactor& joint, TableDistribution::TableDistribution(const TableFactor& joint,
const TableFactor& marginal, const TableFactor& marginal,
const Ordering& orderedKeys) const Ordering& orderedKeys)
: DiscreteTableConditional(joint, marginal) { : TableDistribution(joint, marginal) {
keys_.clear(); keys_.clear();
keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end()); keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end());
} }
/* ************************************************************************** */ /* ************************************************************************** */
DiscreteTableConditional::DiscreteTableConditional(const Signature& signature) TableDistribution::TableDistribution(const Signature& signature)
: BaseConditional(1, DecisionTreeFactor(DiscreteKeys{{1, 1}}, ADT(1))), : BaseConditional(1, DecisionTreeFactor(DiscreteKeys{{1, 1}}, ADT(1))),
table_(TableFactor(signature.discreteKeys(), signature.cpt())) {} table_(TableFactor(signature.discreteKeys(), signature.cpt())) {}
/* ************************************************************************** */ /* ************************************************************************** */
DiscreteTableConditional DiscreteTableConditional::operator*( TableDistribution TableDistribution::operator*(
const DiscreteTableConditional& other) const { const TableDistribution& other) const {
// Take union of frontal keys // Take union of frontal keys
std::set<Key> newFrontals; std::set<Key> newFrontals;
for (auto&& key : this->frontals()) newFrontals.insert(key); for (auto&& key : this->frontals()) newFrontals.insert(key);
@ -82,7 +82,7 @@ DiscreteTableConditional DiscreteTableConditional::operator*(
// Check if frontals overlapped // Check if frontals overlapped
if (nrFrontals() + other.nrFrontals() > newFrontals.size()) if (nrFrontals() + other.nrFrontals() > newFrontals.size())
throw std::invalid_argument( throw std::invalid_argument(
"DiscreteTableConditional::operator* called with overlapping frontal " "TableDistribution::operator* called with overlapping frontal "
"keys."); "keys.");
// Now, add cardinalities. // Now, add cardinalities.
@ -106,11 +106,11 @@ DiscreteTableConditional DiscreteTableConditional::operator*(
for (auto&& dk : parents) discreteKeys.push_back(dk); for (auto&& dk : parents) discreteKeys.push_back(dk);
TableFactor product = this->table_ * other.table(); TableFactor product = this->table_ * other.table();
return DiscreteTableConditional(newFrontals.size(), product); return TableDistribution(newFrontals.size(), product);
} }
/* ************************************************************************** */ /* ************************************************************************** */
void DiscreteTableConditional::print(const string& s, void TableDistribution::print(const string& s,
const KeyFormatter& formatter) const { const KeyFormatter& formatter) const {
cout << s << " P( "; cout << s << " P( ";
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) { for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
@ -128,9 +128,9 @@ void DiscreteTableConditional::print(const string& s,
} }
/* ************************************************************************** */ /* ************************************************************************** */
bool DiscreteTableConditional::equals(const DiscreteFactor& other, bool TableDistribution::equals(const DiscreteFactor& other,
double tol) const { double tol) const {
auto dtc = dynamic_cast<const DiscreteTableConditional*>(&other); auto dtc = dynamic_cast<const TableDistribution*>(&other);
if (!dtc) { if (!dtc) {
return false; return false;
} else { } else {
@ -142,17 +142,17 @@ bool DiscreteTableConditional::equals(const DiscreteFactor& other,
} }
/* ****************************************************************************/ /* ****************************************************************************/
DiscreteConditional::shared_ptr DiscreteTableConditional::max( DiscreteConditional::shared_ptr TableDistribution::max(
const Ordering& keys) const { const Ordering& keys) const {
auto m = *table_.max(keys); auto m = *table_.max(keys);
return std::make_shared<DiscreteTableConditional>(m.discreteKeys().size(), m); return std::make_shared<TableDistribution>(m.discreteKeys().size(), m);
} }
/* ****************************************************************************/ /* ****************************************************************************/
void DiscreteTableConditional::setData( void TableDistribution::setData(
const DiscreteConditional::shared_ptr& dc) { const DiscreteConditional::shared_ptr& dc) {
if (auto dtc = std::dynamic_pointer_cast<DiscreteTableConditional>(dc)) { if (auto dtc = std::dynamic_pointer_cast<TableDistribution>(dc)) {
this->table_ = dtc->table_; this->table_ = dtc->table_;
} else { } else {
this->table_ = TableFactor(dc->discreteKeys(), *dc); this->table_ = TableFactor(dc->discreteKeys(), *dc);
@ -160,11 +160,11 @@ void DiscreteTableConditional::setData(
} }
/* ****************************************************************************/ /* ****************************************************************************/
DiscreteConditional::shared_ptr DiscreteTableConditional::prune( DiscreteConditional::shared_ptr TableDistribution::prune(
size_t maxNrAssignments) const { size_t maxNrAssignments) const {
TableFactor pruned = table_.prune(maxNrAssignments); TableFactor pruned = table_.prune(maxNrAssignments);
return std::make_shared<DiscreteTableConditional>( return std::make_shared<TableDistribution>(
this->nrFrontals(), this->discreteKeys(), pruned.sparseTable()); this->nrFrontals(), this->discreteKeys(), pruned.sparseTable());
} }

View File

@ -10,7 +10,7 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
/** /**
* @file DiscreteTableConditional.h * @file TableDistribution.h
* @date Dec 22, 2024 * @date Dec 22, 2024
* @author Varun Agrawal * @author Varun Agrawal
*/ */
@ -34,7 +34,7 @@ namespace gtsam {
* *
* @ingroup discrete * @ingroup discrete
*/ */
class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional { class GTSAM_EXPORT TableDistribution : public DiscreteConditional {
private: private:
TableFactor table_; TableFactor table_;
@ -42,7 +42,7 @@ class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
public: public:
// typedefs needed to play nice with gtsam // typedefs needed to play nice with gtsam
typedef DiscreteTableConditional This; ///< Typedef to this class typedef TableDistribution This; ///< Typedef to this class
typedef std::shared_ptr<This> shared_ptr; ///< shared_ptr to this class typedef std::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
typedef DiscreteConditional typedef DiscreteConditional
BaseConditional; ///< Typedef to our conditional base class BaseConditional; ///< Typedef to our conditional base class
@ -53,42 +53,42 @@ class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
/// @{ /// @{
/// Default constructor needed for serialization. /// Default constructor needed for serialization.
DiscreteTableConditional() {} TableDistribution() {}
/// Construct from factor, taking the first `nFrontals` keys as frontals. /// Construct from factor, taking the first `nFrontals` keys as frontals.
DiscreteTableConditional(size_t nFrontals, const TableFactor& f); TableDistribution(size_t nFrontals, const TableFactor& f);
/** /**
* Construct from DiscreteKeys and SparseVector, taking the first * Construct from DiscreteKeys and SparseVector, taking the first
* `nFrontals` keys as frontals, in the order given. * `nFrontals` keys as frontals, in the order given.
*/ */
DiscreteTableConditional(size_t nFrontals, const DiscreteKeys& keys, TableDistribution(size_t nFrontals, const DiscreteKeys& keys,
const Eigen::SparseVector<double>& potentials); const Eigen::SparseVector<double>& potentials);
/** Construct from signature */ /** Construct from signature */
explicit DiscreteTableConditional(const Signature& signature); explicit TableDistribution(const Signature& signature);
/** /**
* Construct from key, parents, and a Signature::Table specifying the * Construct from key, parents, and a Signature::Table specifying the
* conditional probability table (CPT) in 00 01 10 11 order. For * conditional probability table (CPT) in 00 01 10 11 order. For
* three-valued, it would be 00 01 02 10 11 12 20 21 22, etc.... * three-valued, it would be 00 01 02 10 11 12 20 21 22, etc....
* *
* Example: DiscreteTableConditional P(D, {B,E}, table); * Example: TableDistribution P(D, {B,E}, table);
*/ */
DiscreteTableConditional(const DiscreteKey& key, const DiscreteKeys& parents, TableDistribution(const DiscreteKey& key, const DiscreteKeys& parents,
const Signature::Table& table) const Signature::Table& table)
: DiscreteTableConditional(Signature(key, parents, table)) {} : TableDistribution(Signature(key, parents, table)) {}
/** /**
* Construct from key, parents, and a vector<double> specifying the * Construct from key, parents, and a vector<double> specifying the
* conditional probability table (CPT) in 00 01 10 11 order. For * conditional probability table (CPT) in 00 01 10 11 order. For
* three-valued, it would be 00 01 02 10 11 12 20 21 22, etc.... * three-valued, it would be 00 01 02 10 11 12 20 21 22, etc....
* *
* Example: DiscreteTableConditional P(D, {B,E}, table); * Example: TableDistribution P(D, {B,E}, table);
*/ */
DiscreteTableConditional(const DiscreteKey& key, const DiscreteKeys& parents, TableDistribution(const DiscreteKey& key, const DiscreteKeys& parents,
const std::vector<double>& table) const std::vector<double>& table)
: DiscreteTableConditional( : TableDistribution(
1, TableFactor(DiscreteKeys{key} & parents, table)) {} 1, TableFactor(DiscreteKeys{key} & parents, table)) {}
/** /**
@ -98,21 +98,21 @@ class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
* *
* The string is parsed into a Signature::Table. * The string is parsed into a Signature::Table.
* *
* Example: DiscreteTableConditional P(D, {B,E}, "9/1 2/8 3/7 1/9"); * Example: TableDistribution P(D, {B,E}, "9/1 2/8 3/7 1/9");
*/ */
DiscreteTableConditional(const DiscreteKey& key, const DiscreteKeys& parents, TableDistribution(const DiscreteKey& key, const DiscreteKeys& parents,
const std::string& spec) const std::string& spec)
: DiscreteTableConditional(Signature(key, parents, spec)) {} : TableDistribution(Signature(key, parents, spec)) {}
/// No-parent specialization; can also use DiscreteDistribution. /// No-parent specialization; can also use DiscreteDistribution.
DiscreteTableConditional(const DiscreteKey& key, const std::string& spec) TableDistribution(const DiscreteKey& key, const std::string& spec)
: DiscreteTableConditional(Signature(key, {}, spec)) {} : TableDistribution(Signature(key, {}, spec)) {}
/** /**
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
* Assumes but *does not check* that f(Y)=sum_X f(X,Y). * Assumes but *does not check* that f(Y)=sum_X f(X,Y).
*/ */
DiscreteTableConditional(const TableFactor& joint, TableDistribution(const TableFactor& joint,
const TableFactor& marginal); const TableFactor& marginal);
/** /**
@ -120,7 +120,7 @@ class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
* Assumes but *does not check* that f(Y)=sum_X f(X,Y). * Assumes but *does not check* that f(Y)=sum_X f(X,Y).
* Makes sure the keys are ordered as given. Does not check orderedKeys. * Makes sure the keys are ordered as given. Does not check orderedKeys.
*/ */
DiscreteTableConditional(const TableFactor& joint, TableDistribution(const TableFactor& joint,
const TableFactor& marginal, const TableFactor& marginal,
const Ordering& orderedKeys); const Ordering& orderedKeys);
@ -139,8 +139,8 @@ class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
* P(A|B) * P(B|A) = ? * P(A|B) * P(B|A) = ?
* We check for overlapping frontals, but do *not* check for cyclic. * We check for overlapping frontals, but do *not* check for cyclic.
*/ */
DiscreteTableConditional operator*( TableDistribution operator*(
const DiscreteTableConditional& other) const; const TableDistribution& other) const;
/// @} /// @}
/// @name Testable /// @name Testable
@ -210,11 +210,11 @@ class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
} }
#endif #endif
}; };
// DiscreteTableConditional // TableDistribution
// traits // traits
template <> template <>
struct traits<DiscreteTableConditional> struct traits<TableDistribution>
: public Testable<DiscreteTableConditional> {}; : public Testable<TableDistribution> {};
} // namespace gtsam } // namespace gtsam

View File

@ -52,9 +52,9 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
// Multiply into one big conditional. NOTE: possibly quite expensive. // Multiply into one big conditional. NOTE: possibly quite expensive.
DiscreteConditional joint; DiscreteConditional joint;
for (auto &&conditional : marginal) { for (auto &&conditional : marginal) {
// The last discrete conditional may be a DiscreteTableConditional // The last discrete conditional may be a TableDistribution
if (auto dtc = if (auto dtc =
std::dynamic_pointer_cast<DiscreteTableConditional>(conditional)) { std::dynamic_pointer_cast<TableDistribution>(conditional)) {
DiscreteConditional dc(dtc->nrFrontals(), dtc->toDecisionTreeFactor()); DiscreteConditional dc(dtc->nrFrontals(), dtc->toDecisionTreeFactor());
joint = joint * dc; joint = joint * dc;
} else { } else {
@ -133,7 +133,7 @@ HybridValues HybridBayesNet::optimize() const {
for (auto &&conditional : *this) { for (auto &&conditional : *this) {
if (conditional->isDiscrete()) { if (conditional->isDiscrete()) {
if (auto dtc = conditional->asDiscrete<DiscreteTableConditional>()) { if (auto dtc = conditional->asDiscrete<TableDistribution>()) {
// The number of keys should be small so should not // The number of keys should be small so should not
// be expensive to convert to DiscreteConditional. // be expensive to convert to DiscreteConditional.
discrete_fg.push_back(DiscreteConditional(dtc->nrFrontals(), discrete_fg.push_back(DiscreteConditional(dtc->nrFrontals(),

View File

@ -20,7 +20,7 @@
#include <gtsam/base/treeTraversal-inst.h> #include <gtsam/base/treeTraversal-inst.h>
#include <gtsam/discrete/DiscreteBayesNet.h> #include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteTableConditional.h> #include <gtsam/discrete/TableDistribution.h>
#include <gtsam/hybrid/HybridBayesNet.h> #include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h> #include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/hybrid/HybridConditional.h> #include <gtsam/hybrid/HybridConditional.h>
@ -72,7 +72,7 @@ HybridValues HybridBayesTree::optimize() const {
// The root should be discrete only, we compute the MPE // The root should be discrete only, we compute the MPE
if (root_conditional->isDiscrete()) { if (root_conditional->isDiscrete()) {
auto discrete = std::dynamic_pointer_cast<DiscreteTableConditional>( auto discrete = std::dynamic_pointer_cast<TableDistribution>(
root_conditional->asDiscrete()); root_conditional->asDiscrete());
discrete_fg.push_back(discrete); discrete_fg.push_back(discrete);
mpe = discreteMaxProduct(discrete_fg); mpe = discreteMaxProduct(discrete_fg);
@ -202,7 +202,7 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
/* ************************************************************************* */ /* ************************************************************************* */
void HybridBayesTree::prune(const size_t maxNrLeaves) { void HybridBayesTree::prune(const size_t maxNrLeaves) {
auto discreteProbs = auto discreteProbs =
this->roots_.at(0)->conditional()->asDiscrete<DiscreteTableConditional>(); this->roots_.at(0)->conditional()->asDiscrete<TableDistribution>();
DiscreteConditional::shared_ptr prunedDiscreteProbs = DiscreteConditional::shared_ptr prunedDiscreteProbs =
discreteProbs->prune(maxNrLeaves); discreteProbs->prune(maxNrLeaves);

View File

@ -265,7 +265,7 @@ TableFactor TableProduct(const DiscreteFactorGraph &factors) {
for (auto &&factor : factors) { for (auto &&factor : factors) {
if (factor) { if (factor) {
if (auto dtc = if (auto dtc =
std::dynamic_pointer_cast<DiscreteTableConditional>(factor)) { std::dynamic_pointer_cast<TableDistribution>(factor)) {
product = product * dtc->table(); product = product * dtc->table();
} else if (auto f = std::dynamic_pointer_cast<TableFactor>(factor)) { } else if (auto f = std::dynamic_pointer_cast<TableFactor>(factor)) {
product = product * (*f); product = product * (*f);
@ -323,7 +323,7 @@ static DiscreteFactorGraph CollectDiscreteFactors(
#if GTSAM_HYBRID_TIMING #if GTSAM_HYBRID_TIMING
gttic_(ConvertConditionalToTableFactor); gttic_(ConvertConditionalToTableFactor);
#endif #endif
if (auto dtc = std::dynamic_pointer_cast<DiscreteTableConditional>(dc)) { if (auto dtc = std::dynamic_pointer_cast<TableDistribution>(dc)) {
/// Get the underlying TableFactor /// Get the underlying TableFactor
dfg.push_back(dtc->table()); dfg.push_back(dtc->table());
} else { } else {
@ -364,7 +364,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
gttic_(EliminateDiscreteFormDiscreteConditional); gttic_(EliminateDiscreteFormDiscreteConditional);
#endif #endif
auto conditional = auto conditional =
std::make_shared<DiscreteTableConditional>(frontalKeys.size(), product); std::make_shared<TableDistribution>(frontalKeys.size(), product);
#if GTSAM_HYBRID_TIMING #if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteFormDiscreteConditional); gttoc_(EliminateDiscreteFormDiscreteConditional);
#endif #endif

View File

@ -20,7 +20,7 @@
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteTableConditional.h> #include <gtsam/discrete/TableDistribution.h>
#include <gtsam/hybrid/HybridFactor.h> #include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridFactorGraph.h> #include <gtsam/hybrid/HybridFactorGraph.h>
#include <gtsam/hybrid/HybridGaussianFactor.h> #include <gtsam/hybrid/HybridGaussianFactor.h>

View File

@ -20,7 +20,7 @@
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteTableConditional.h> #include <gtsam/discrete/TableDistribution.h>
#include <gtsam/hybrid/HybridBayesNet.h> #include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridGaussianConditional.h> #include <gtsam/hybrid/HybridGaussianConditional.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h> #include <gtsam/hybrid/HybridGaussianFactorGraph.h>
@ -80,8 +80,8 @@ TEST(GaussianMixture, GaussianMixtureModel) {
double midway = mu1 - mu0; double midway = mu1 - mu0;
auto eliminationResult = auto eliminationResult =
gmm.toFactorGraph({{Z(0), Vector1(midway)}}).eliminateSequential(); gmm.toFactorGraph({{Z(0), Vector1(midway)}}).eliminateSequential();
auto pMid = eliminationResult->at(0)->asDiscrete<DiscreteTableConditional>(); auto pMid = eliminationResult->at(0)->asDiscrete<TableDistribution>();
EXPECT(assert_equal(DiscreteTableConditional(m, "60/40"), *pMid)); EXPECT(assert_equal(TableDistribution(m, "60/40"), *pMid));
// Everywhere else, the result should be a sigmoid. // Everywhere else, the result should be a sigmoid.
for (const double shift : {-4, -2, 0, 2, 4}) { for (const double shift : {-4, -2, 0, 2, 4}) {
@ -92,7 +92,7 @@ TEST(GaussianMixture, GaussianMixtureModel) {
auto eliminationResult1 = auto eliminationResult1 =
gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential(); gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential();
auto posterior1 = auto posterior1 =
*eliminationResult1->at(0)->asDiscrete<DiscreteTableConditional>(); *eliminationResult1->at(0)->asDiscrete<TableDistribution>();
EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8); EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);
// Workflow 2: directly specify HFG and solve // Workflow 2: directly specify HFG and solve
@ -102,7 +102,7 @@ TEST(GaussianMixture, GaussianMixtureModel) {
hfg1.push_back(mixing); hfg1.push_back(mixing);
auto eliminationResult2 = hfg1.eliminateSequential(); auto eliminationResult2 = hfg1.eliminateSequential();
auto posterior2 = auto posterior2 =
*eliminationResult2->at(0)->asDiscrete<DiscreteTableConditional>(); *eliminationResult2->at(0)->asDiscrete<TableDistribution>();
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8); EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
} }
} }
@ -142,8 +142,8 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
eliminationResultMax->discretePosterior(vv))); eliminationResultMax->discretePosterior(vv)));
auto pMax = auto pMax =
*eliminationResultMax->at(0)->asDiscrete<DiscreteTableConditional>(); *eliminationResultMax->at(0)->asDiscrete<TableDistribution>();
EXPECT(assert_equal(DiscreteTableConditional(m, "42/58"), pMax, 1e-4)); EXPECT(assert_equal(TableDistribution(m, "42/58"), pMax, 1e-4));
// Everywhere else, the result should be a bell curve like function. // Everywhere else, the result should be a bell curve like function.
for (const double shift : {-4, -2, 0, 2, 4}) { for (const double shift : {-4, -2, 0, 2, 4}) {
@ -154,7 +154,7 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
auto eliminationResult1 = auto eliminationResult1 =
gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential(); gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential();
auto posterior1 = auto posterior1 =
*eliminationResult1->at(0)->asDiscrete<DiscreteTableConditional>(); *eliminationResult1->at(0)->asDiscrete<TableDistribution>();
EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8); EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);
// Workflow 2: directly specify HFG and solve // Workflow 2: directly specify HFG and solve
@ -164,7 +164,7 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
hfg.push_back(mixing); hfg.push_back(mixing);
auto eliminationResult2 = hfg.eliminateSequential(); auto eliminationResult2 = hfg.eliminateSequential();
auto posterior2 = auto posterior2 =
*eliminationResult2->at(0)->asDiscrete<DiscreteTableConditional>(); *eliminationResult2->at(0)->asDiscrete<TableDistribution>();
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8); EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
} }
} }

View File

@ -450,9 +450,9 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
DiscreteConditional joint; DiscreteConditional joint;
for (auto&& conditional : posterior->discreteMarginal()) { for (auto&& conditional : posterior->discreteMarginal()) {
// The last discrete conditional may be a DiscreteTableConditional // The last discrete conditional may be a TableDistribution
if (auto dtc = if (auto dtc =
std::dynamic_pointer_cast<DiscreteTableConditional>(conditional)) { std::dynamic_pointer_cast<TableDistribution>(conditional)) {
DiscreteConditional dc(dtc->nrFrontals(), dtc->toDecisionTreeFactor()); DiscreteConditional dc(dtc->nrFrontals(), dtc->toDecisionTreeFactor());
joint = joint * dc; joint = joint * dc;
} else { } else {

View File

@ -464,14 +464,14 @@ TEST(HybridEstimation, EliminateSequentialRegression) {
// Create expected discrete conditional on m0. // Create expected discrete conditional on m0.
DiscreteKey m(M(0), 2); DiscreteKey m(M(0), 2);
DiscreteTableConditional expected(m % "0.51341712/1"); // regression TableDistribution expected(m % "0.51341712/1"); // regression
// Eliminate into BN using one ordering // Eliminate into BN using one ordering
const Ordering ordering1{X(0), X(1), M(0)}; const Ordering ordering1{X(0), X(1), M(0)};
HybridBayesNet::shared_ptr bn1 = fg->eliminateSequential(ordering1); HybridBayesNet::shared_ptr bn1 = fg->eliminateSequential(ordering1);
// Check that the discrete conditional matches the expected. // Check that the discrete conditional matches the expected.
auto dc1 = bn1->back()->asDiscrete<DiscreteTableConditional>(); auto dc1 = bn1->back()->asDiscrete<TableDistribution>();
EXPECT(assert_equal(expected, *dc1, 1e-9)); EXPECT(assert_equal(expected, *dc1, 1e-9));
// Eliminate into BN using a different ordering // Eliminate into BN using a different ordering
@ -479,7 +479,7 @@ TEST(HybridEstimation, EliminateSequentialRegression) {
HybridBayesNet::shared_ptr bn2 = fg->eliminateSequential(ordering2); HybridBayesNet::shared_ptr bn2 = fg->eliminateSequential(ordering2);
// Check that the discrete conditional matches the expected. // Check that the discrete conditional matches the expected.
auto dc2 = bn2->back()->asDiscrete<DiscreteTableConditional>(); auto dc2 = bn2->back()->asDiscrete<TableDistribution>();
EXPECT(assert_equal(expected, *dc2, 1e-9)); EXPECT(assert_equal(expected, *dc2, 1e-9));
} }

View File

@ -650,7 +650,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) {
mode, std::vector{conditional0, conditional1}); mode, std::vector{conditional0, conditional1});
// Add prior on mode. // Add prior on mode.
expectedBayesNet.emplace_shared<DiscreteTableConditional>(mode, "74/26"); expectedBayesNet.emplace_shared<TableDistribution>(mode, "74/26");
// Test elimination // Test elimination
const auto posterior = fg.eliminateSequential(); const auto posterior = fg.eliminateSequential();
@ -700,7 +700,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1Swapped) {
m1, std::vector{conditional0, conditional1}); m1, std::vector{conditional0, conditional1});
// Add prior on m1. // Add prior on m1.
expectedBayesNet.emplace_shared<DiscreteTableConditional>( expectedBayesNet.emplace_shared<TableDistribution>(
m1, "0.188638/0.811362"); m1, "0.188638/0.811362");
// Test elimination // Test elimination
@ -738,8 +738,8 @@ TEST(HybridGaussianFactorGraph, EliminateTiny2) {
// Add prior on mode. // Add prior on mode.
// Since this is the only discrete conditional, it is added as a // Since this is the only discrete conditional, it is added as a
// DiscreteTableConditional. // TableDistribution.
expectedBayesNet.emplace_shared<DiscreteTableConditional>(mode, "23/77"); expectedBayesNet.emplace_shared<TableDistribution>(mode, "23/77");
// Test elimination // Test elimination
const auto posterior = fg.eliminateSequential(); const auto posterior = fg.eliminateSequential();

View File

@ -142,7 +142,7 @@ TEST(HybridGaussianISAM, IncrementalInference) {
// Test the probability values with regression tests. // Test the probability values with regression tests.
auto discrete = auto discrete =
isam[M(1)]->conditional()->asDiscrete<DiscreteTableConditional>(); isam[M(1)]->conditional()->asDiscrete<TableDistribution>();
EXPECT(assert_equal(0.095292, (*discrete)({{M(0), 0}, {M(1), 0}}), 1e-5)); EXPECT(assert_equal(0.095292, (*discrete)({{M(0), 0}, {M(1), 0}}), 1e-5));
EXPECT(assert_equal(0.282758, (*discrete)({{M(0), 1}, {M(1), 0}}), 1e-5)); EXPECT(assert_equal(0.282758, (*discrete)({{M(0), 1}, {M(1), 0}}), 1e-5));
EXPECT(assert_equal(0.314175, (*discrete)({{M(0), 0}, {M(1), 1}}), 1e-5)); EXPECT(assert_equal(0.314175, (*discrete)({{M(0), 0}, {M(1), 1}}), 1e-5));
@ -222,7 +222,7 @@ TEST(HybridGaussianISAM, ApproxInference) {
1 1 1 Leaf 0.5 1 1 1 Leaf 0.5
*/ */
auto discreteConditional_m0 = *dynamic_pointer_cast<DiscreteTableConditional>( auto discreteConditional_m0 = *dynamic_pointer_cast<TableDistribution>(
incrementalHybrid[M(0)]->conditional()->inner()); incrementalHybrid[M(0)]->conditional()->inner());
EXPECT(discreteConditional_m0.keys() == KeyVector({M(0), M(1), M(2)})); EXPECT(discreteConditional_m0.keys() == KeyVector({M(0), M(1), M(2)}));
@ -474,7 +474,7 @@ TEST(HybridGaussianISAM, NonTrivial) {
// Test if the optimal discrete mode assignment is (1, 1, 1). // Test if the optimal discrete mode assignment is (1, 1, 1).
DiscreteFactorGraph discreteGraph; DiscreteFactorGraph discreteGraph;
// discreteTree is a DiscreteTableConditional, so we convert to // discreteTree is a TableDistribution, so we convert to
// DecisionTreeFactor for the DiscreteFactorGraph // DecisionTreeFactor for the DiscreteFactorGraph
discreteGraph.push_back(discreteTree->toDecisionTreeFactor()); discreteGraph.push_back(discreteTree->toDecisionTreeFactor());
DiscreteValues optimal_assignment = discreteGraph.optimize(); DiscreteValues optimal_assignment = discreteGraph.optimize();

View File

@ -21,7 +21,7 @@
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/base/TestableAssertions.h> #include <gtsam/base/TestableAssertions.h>
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteTableConditional.h> #include <gtsam/discrete/TableDistribution.h>
#include <gtsam/discrete/DiscreteValues.h> #include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/HybridBayesNet.h> #include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridGaussianConditional.h> #include <gtsam/hybrid/HybridGaussianConditional.h>
@ -144,9 +144,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel) {
// Since no measurement on x1, we hedge our bets // Since no measurement on x1, we hedge our bets
// Importance sampling run with 100k samples gives 50.051/49.949 // Importance sampling run with 100k samples gives 50.051/49.949
// approximateDiscreteMarginal(hbn, hybridMotionModel, given); // approximateDiscreteMarginal(hbn, hybridMotionModel, given);
DiscreteTableConditional expected(m1, "50/50"); TableDistribution expected(m1, "50/50");
EXPECT(assert_equal(expected, EXPECT(assert_equal(expected,
*(bn->at(2)->asDiscrete<DiscreteTableConditional>()))); *(bn->at(2)->asDiscrete<TableDistribution>())));
} }
{ {
@ -162,9 +162,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel) {
// Since we have a measurement on x1, we get a definite result // Since we have a measurement on x1, we get a definite result
// Values taken from an importance sampling run with 100k samples: // Values taken from an importance sampling run with 100k samples:
// approximateDiscreteMarginal(hbn, hybridMotionModel, given); // approximateDiscreteMarginal(hbn, hybridMotionModel, given);
DiscreteTableConditional expected(m1, "44.3854/55.6146"); TableDistribution expected(m1, "44.3854/55.6146");
EXPECT(assert_equal( EXPECT(assert_equal(
expected, *(bn->at(2)->asDiscrete<DiscreteTableConditional>()), 0.02)); expected, *(bn->at(2)->asDiscrete<TableDistribution>()), 0.02));
} }
} }
@ -251,9 +251,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel2) {
// Values taken from an importance sampling run with 100k samples: // Values taken from an importance sampling run with 100k samples:
// approximateDiscreteMarginal(hbn, hybridMotionModel, given); // approximateDiscreteMarginal(hbn, hybridMotionModel, given);
DiscreteTableConditional expected(m1, "48.3158/51.6842"); TableDistribution expected(m1, "48.3158/51.6842");
EXPECT(assert_equal( EXPECT(assert_equal(
expected, *(eliminated->at(2)->asDiscrete<DiscreteTableConditional>()), expected, *(eliminated->at(2)->asDiscrete<TableDistribution>()),
0.02)); 0.02));
} }
@ -268,9 +268,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel2) {
// Values taken from an importance sampling run with 100k samples: // Values taken from an importance sampling run with 100k samples:
// approximateDiscreteMarginal(hbn, hybridMotionModel, given); // approximateDiscreteMarginal(hbn, hybridMotionModel, given);
DiscreteTableConditional expected(m1, "55.396/44.604"); TableDistribution expected(m1, "55.396/44.604");
EXPECT(assert_equal( EXPECT(assert_equal(
expected, *(bn->at(2)->asDiscrete<DiscreteTableConditional>()), 0.02)); expected, *(bn->at(2)->asDiscrete<TableDistribution>()), 0.02));
} }
} }
@ -346,9 +346,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel3) {
// Values taken from an importance sampling run with 100k samples: // Values taken from an importance sampling run with 100k samples:
// approximateDiscreteMarginal(hbn, hybridMotionModel, given); // approximateDiscreteMarginal(hbn, hybridMotionModel, given);
DiscreteTableConditional expected(m1, "51.7762/48.2238"); TableDistribution expected(m1, "51.7762/48.2238");
EXPECT(assert_equal( EXPECT(assert_equal(
expected, *(bn->at(2)->asDiscrete<DiscreteTableConditional>()), 0.02)); expected, *(bn->at(2)->asDiscrete<TableDistribution>()), 0.02));
} }
{ {
@ -362,9 +362,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel3) {
// Values taken from an importance sampling run with 100k samples: // Values taken from an importance sampling run with 100k samples:
// approximateDiscreteMarginal(hbn, hybridMotionModel, given); // approximateDiscreteMarginal(hbn, hybridMotionModel, given);
DiscreteTableConditional expected(m1, "49.0762/50.9238"); TableDistribution expected(m1, "49.0762/50.9238");
EXPECT(assert_equal( EXPECT(assert_equal(
expected, *(bn->at(2)->asDiscrete<DiscreteTableConditional>()), 0.05)); expected, *(bn->at(2)->asDiscrete<TableDistribution>()), 0.05));
} }
} }
@ -389,9 +389,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel4) {
// Values taken from an importance sampling run with 100k samples: // Values taken from an importance sampling run with 100k samples:
// approximateDiscreteMarginal(hbn, hybridMotionModel, given); // approximateDiscreteMarginal(hbn, hybridMotionModel, given);
DiscreteTableConditional expected(m1, "8.91527/91.0847"); TableDistribution expected(m1, "8.91527/91.0847");
EXPECT(assert_equal( EXPECT(assert_equal(
expected, *(bn->at(2)->asDiscrete<DiscreteTableConditional>()), 0.01)); expected, *(bn->at(2)->asDiscrete<TableDistribution>()), 0.01));
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -512,7 +512,7 @@ TEST(HybridNonlinearFactorGraph, Full_Elimination) {
// P(m1) // P(m1)
EXPECT(hybridBayesNet->at(4)->frontals() == KeyVector{M(1)}); EXPECT(hybridBayesNet->at(4)->frontals() == KeyVector{M(1)});
EXPECT_LONGS_EQUAL(0, hybridBayesNet->at(4)->nrParents()); EXPECT_LONGS_EQUAL(0, hybridBayesNet->at(4)->nrParents());
DiscreteTableConditional dtc = *hybridBayesNet->at(4)->asDiscrete<DiscreteTableConditional>(); TableDistribution dtc = *hybridBayesNet->at(4)->asDiscrete<TableDistribution>();
EXPECT( EXPECT(
DiscreteConditional(dtc.nrFrontals(), dtc.toDecisionTreeFactor()) DiscreteConditional(dtc.nrFrontals(), dtc.toDecisionTreeFactor())
.equals(*discreteBayesNet.at(1))); .equals(*discreteBayesNet.at(1)));

View File

@ -265,7 +265,7 @@ TEST(HybridNonlinearISAM, ApproxInference) {
1 1 1 Leaf 0.5 1 1 1 Leaf 0.5
*/ */
auto discreteConditional_m0 = *dynamic_pointer_cast<DiscreteTableConditional>( auto discreteConditional_m0 = *dynamic_pointer_cast<TableDistribution>(
bayesTree[M(0)]->conditional()->inner()); bayesTree[M(0)]->conditional()->inner());
EXPECT(discreteConditional_m0.keys() == KeyVector({M(0), M(1), M(2)})); EXPECT(discreteConditional_m0.keys() == KeyVector({M(0), M(1), M(2)}));
@ -517,7 +517,7 @@ TEST(HybridNonlinearISAM, NonTrivial) {
// The final discrete graph should not be empty since we have eliminated // The final discrete graph should not be empty since we have eliminated
// all continuous variables. // all continuous variables.
auto discreteTree = auto discreteTree =
bayesTree[M(3)]->conditional()->asDiscrete<DiscreteTableConditional>(); bayesTree[M(3)]->conditional()->asDiscrete<TableDistribution>();
EXPECT_LONGS_EQUAL(3, discreteTree->size()); EXPECT_LONGS_EQUAL(3, discreteTree->size());
// Test if the optimal discrete mode assignment is (1, 1, 1). // Test if the optimal discrete mode assignment is (1, 1, 1).