rename from DiscreteTableConditional to TableDistribution
parent
0098112f27
commit
9b1918c085
|
@ -10,14 +10,14 @@
|
|||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file DiscreteTableConditional.cpp
|
||||
* @file TableDistribution.cpp
|
||||
* @date Dec 22, 2024
|
||||
* @author Varun Agrawal
|
||||
*/
|
||||
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/debug.h>
|
||||
#include <gtsam/discrete/DiscreteTableConditional.h>
|
||||
#include <gtsam/discrete/TableDistribution.h>
|
||||
#include <gtsam/discrete/Ring.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
|
@ -38,42 +38,42 @@ using std::vector;
|
|||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************** */
|
||||
DiscreteTableConditional::DiscreteTableConditional(const size_t nrFrontals,
|
||||
TableDistribution::TableDistribution(const size_t nrFrontals,
|
||||
const TableFactor& f)
|
||||
: BaseConditional(nrFrontals, DecisionTreeFactor(f.discreteKeys(), ADT())),
|
||||
table_(f / (*f.sum(nrFrontals))) {}
|
||||
|
||||
/* ************************************************************************** */
|
||||
DiscreteTableConditional::DiscreteTableConditional(
|
||||
TableDistribution::TableDistribution(
|
||||
size_t nrFrontals, const DiscreteKeys& keys,
|
||||
const Eigen::SparseVector<double>& potentials)
|
||||
: BaseConditional(nrFrontals, keys, DecisionTreeFactor(keys, ADT())),
|
||||
table_(TableFactor(keys, potentials)) {}
|
||||
|
||||
/* ************************************************************************** */
|
||||
DiscreteTableConditional::DiscreteTableConditional(const TableFactor& joint,
|
||||
TableDistribution::TableDistribution(const TableFactor& joint,
|
||||
const TableFactor& marginal)
|
||||
: BaseConditional(joint.size() - marginal.size(),
|
||||
joint.discreteKeys() & marginal.discreteKeys(), ADT()),
|
||||
table_(joint / marginal) {}
|
||||
|
||||
/* ************************************************************************** */
|
||||
DiscreteTableConditional::DiscreteTableConditional(const TableFactor& joint,
|
||||
TableDistribution::TableDistribution(const TableFactor& joint,
|
||||
const TableFactor& marginal,
|
||||
const Ordering& orderedKeys)
|
||||
: DiscreteTableConditional(joint, marginal) {
|
||||
: TableDistribution(joint, marginal) {
|
||||
keys_.clear();
|
||||
keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end());
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
DiscreteTableConditional::DiscreteTableConditional(const Signature& signature)
|
||||
TableDistribution::TableDistribution(const Signature& signature)
|
||||
: BaseConditional(1, DecisionTreeFactor(DiscreteKeys{{1, 1}}, ADT(1))),
|
||||
table_(TableFactor(signature.discreteKeys(), signature.cpt())) {}
|
||||
|
||||
/* ************************************************************************** */
|
||||
DiscreteTableConditional DiscreteTableConditional::operator*(
|
||||
const DiscreteTableConditional& other) const {
|
||||
TableDistribution TableDistribution::operator*(
|
||||
const TableDistribution& other) const {
|
||||
// Take union of frontal keys
|
||||
std::set<Key> newFrontals;
|
||||
for (auto&& key : this->frontals()) newFrontals.insert(key);
|
||||
|
@ -82,7 +82,7 @@ DiscreteTableConditional DiscreteTableConditional::operator*(
|
|||
// Check if frontals overlapped
|
||||
if (nrFrontals() + other.nrFrontals() > newFrontals.size())
|
||||
throw std::invalid_argument(
|
||||
"DiscreteTableConditional::operator* called with overlapping frontal "
|
||||
"TableDistribution::operator* called with overlapping frontal "
|
||||
"keys.");
|
||||
|
||||
// Now, add cardinalities.
|
||||
|
@ -106,11 +106,11 @@ DiscreteTableConditional DiscreteTableConditional::operator*(
|
|||
for (auto&& dk : parents) discreteKeys.push_back(dk);
|
||||
|
||||
TableFactor product = this->table_ * other.table();
|
||||
return DiscreteTableConditional(newFrontals.size(), product);
|
||||
return TableDistribution(newFrontals.size(), product);
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
void DiscreteTableConditional::print(const string& s,
|
||||
void TableDistribution::print(const string& s,
|
||||
const KeyFormatter& formatter) const {
|
||||
cout << s << " P( ";
|
||||
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
|
||||
|
@ -128,9 +128,9 @@ void DiscreteTableConditional::print(const string& s,
|
|||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
bool DiscreteTableConditional::equals(const DiscreteFactor& other,
|
||||
bool TableDistribution::equals(const DiscreteFactor& other,
|
||||
double tol) const {
|
||||
auto dtc = dynamic_cast<const DiscreteTableConditional*>(&other);
|
||||
auto dtc = dynamic_cast<const TableDistribution*>(&other);
|
||||
if (!dtc) {
|
||||
return false;
|
||||
} else {
|
||||
|
@ -142,17 +142,17 @@ bool DiscreteTableConditional::equals(const DiscreteFactor& other,
|
|||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
DiscreteConditional::shared_ptr DiscreteTableConditional::max(
|
||||
DiscreteConditional::shared_ptr TableDistribution::max(
|
||||
const Ordering& keys) const {
|
||||
auto m = *table_.max(keys);
|
||||
|
||||
return std::make_shared<DiscreteTableConditional>(m.discreteKeys().size(), m);
|
||||
return std::make_shared<TableDistribution>(m.discreteKeys().size(), m);
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
void DiscreteTableConditional::setData(
|
||||
void TableDistribution::setData(
|
||||
const DiscreteConditional::shared_ptr& dc) {
|
||||
if (auto dtc = std::dynamic_pointer_cast<DiscreteTableConditional>(dc)) {
|
||||
if (auto dtc = std::dynamic_pointer_cast<TableDistribution>(dc)) {
|
||||
this->table_ = dtc->table_;
|
||||
} else {
|
||||
this->table_ = TableFactor(dc->discreteKeys(), *dc);
|
||||
|
@ -160,11 +160,11 @@ void DiscreteTableConditional::setData(
|
|||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
DiscreteConditional::shared_ptr DiscreteTableConditional::prune(
|
||||
DiscreteConditional::shared_ptr TableDistribution::prune(
|
||||
size_t maxNrAssignments) const {
|
||||
TableFactor pruned = table_.prune(maxNrAssignments);
|
||||
|
||||
return std::make_shared<DiscreteTableConditional>(
|
||||
return std::make_shared<TableDistribution>(
|
||||
this->nrFrontals(), this->discreteKeys(), pruned.sparseTable());
|
||||
}
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@
|
|||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file DiscreteTableConditional.h
|
||||
* @file TableDistribution.h
|
||||
* @date Dec 22, 2024
|
||||
* @author Varun Agrawal
|
||||
*/
|
||||
|
@ -34,7 +34,7 @@ namespace gtsam {
|
|||
*
|
||||
* @ingroup discrete
|
||||
*/
|
||||
class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
|
||||
class GTSAM_EXPORT TableDistribution : public DiscreteConditional {
|
||||
private:
|
||||
TableFactor table_;
|
||||
|
||||
|
@ -42,7 +42,7 @@ class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
|
|||
|
||||
public:
|
||||
// typedefs needed to play nice with gtsam
|
||||
typedef DiscreteTableConditional This; ///< Typedef to this class
|
||||
typedef TableDistribution This; ///< Typedef to this class
|
||||
typedef std::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
|
||||
typedef DiscreteConditional
|
||||
BaseConditional; ///< Typedef to our conditional base class
|
||||
|
@ -53,42 +53,42 @@ class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
|
|||
/// @{
|
||||
|
||||
/// Default constructor needed for serialization.
|
||||
DiscreteTableConditional() {}
|
||||
TableDistribution() {}
|
||||
|
||||
/// Construct from factor, taking the first `nFrontals` keys as frontals.
|
||||
DiscreteTableConditional(size_t nFrontals, const TableFactor& f);
|
||||
TableDistribution(size_t nFrontals, const TableFactor& f);
|
||||
|
||||
/**
|
||||
* Construct from DiscreteKeys and SparseVector, taking the first
|
||||
* `nFrontals` keys as frontals, in the order given.
|
||||
*/
|
||||
DiscreteTableConditional(size_t nFrontals, const DiscreteKeys& keys,
|
||||
TableDistribution(size_t nFrontals, const DiscreteKeys& keys,
|
||||
const Eigen::SparseVector<double>& potentials);
|
||||
|
||||
/** Construct from signature */
|
||||
explicit DiscreteTableConditional(const Signature& signature);
|
||||
explicit TableDistribution(const Signature& signature);
|
||||
|
||||
/**
|
||||
* Construct from key, parents, and a Signature::Table specifying the
|
||||
* conditional probability table (CPT) in 00 01 10 11 order. For
|
||||
* three-valued, it would be 00 01 02 10 11 12 20 21 22, etc....
|
||||
*
|
||||
* Example: DiscreteTableConditional P(D, {B,E}, table);
|
||||
* Example: TableDistribution P(D, {B,E}, table);
|
||||
*/
|
||||
DiscreteTableConditional(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||
TableDistribution(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||
const Signature::Table& table)
|
||||
: DiscreteTableConditional(Signature(key, parents, table)) {}
|
||||
: TableDistribution(Signature(key, parents, table)) {}
|
||||
|
||||
/**
|
||||
* Construct from key, parents, and a vector<double> specifying the
|
||||
* conditional probability table (CPT) in 00 01 10 11 order. For
|
||||
* three-valued, it would be 00 01 02 10 11 12 20 21 22, etc....
|
||||
*
|
||||
* Example: DiscreteTableConditional P(D, {B,E}, table);
|
||||
* Example: TableDistribution P(D, {B,E}, table);
|
||||
*/
|
||||
DiscreteTableConditional(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||
TableDistribution(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||
const std::vector<double>& table)
|
||||
: DiscreteTableConditional(
|
||||
: TableDistribution(
|
||||
1, TableFactor(DiscreteKeys{key} & parents, table)) {}
|
||||
|
||||
/**
|
||||
|
@ -98,21 +98,21 @@ class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
|
|||
*
|
||||
* The string is parsed into a Signature::Table.
|
||||
*
|
||||
* Example: DiscreteTableConditional P(D, {B,E}, "9/1 2/8 3/7 1/9");
|
||||
* Example: TableDistribution P(D, {B,E}, "9/1 2/8 3/7 1/9");
|
||||
*/
|
||||
DiscreteTableConditional(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||
TableDistribution(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||
const std::string& spec)
|
||||
: DiscreteTableConditional(Signature(key, parents, spec)) {}
|
||||
: TableDistribution(Signature(key, parents, spec)) {}
|
||||
|
||||
/// No-parent specialization; can also use DiscreteDistribution.
|
||||
DiscreteTableConditional(const DiscreteKey& key, const std::string& spec)
|
||||
: DiscreteTableConditional(Signature(key, {}, spec)) {}
|
||||
TableDistribution(const DiscreteKey& key, const std::string& spec)
|
||||
: TableDistribution(Signature(key, {}, spec)) {}
|
||||
|
||||
/**
|
||||
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
|
||||
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
|
||||
*/
|
||||
DiscreteTableConditional(const TableFactor& joint,
|
||||
TableDistribution(const TableFactor& joint,
|
||||
const TableFactor& marginal);
|
||||
|
||||
/**
|
||||
|
@ -120,7 +120,7 @@ class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
|
|||
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
|
||||
* Makes sure the keys are ordered as given. Does not check orderedKeys.
|
||||
*/
|
||||
DiscreteTableConditional(const TableFactor& joint,
|
||||
TableDistribution(const TableFactor& joint,
|
||||
const TableFactor& marginal,
|
||||
const Ordering& orderedKeys);
|
||||
|
||||
|
@ -139,8 +139,8 @@ class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
|
|||
* P(A|B) * P(B|A) = ?
|
||||
* We check for overlapping frontals, but do *not* check for cyclic.
|
||||
*/
|
||||
DiscreteTableConditional operator*(
|
||||
const DiscreteTableConditional& other) const;
|
||||
TableDistribution operator*(
|
||||
const TableDistribution& other) const;
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
|
@ -210,11 +210,11 @@ class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
|
|||
}
|
||||
#endif
|
||||
};
|
||||
// DiscreteTableConditional
|
||||
// TableDistribution
|
||||
|
||||
// traits
|
||||
template <>
|
||||
struct traits<DiscreteTableConditional>
|
||||
: public Testable<DiscreteTableConditional> {};
|
||||
struct traits<TableDistribution>
|
||||
: public Testable<TableDistribution> {};
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -52,9 +52,9 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
|
|||
// Multiply into one big conditional. NOTE: possibly quite expensive.
|
||||
DiscreteConditional joint;
|
||||
for (auto &&conditional : marginal) {
|
||||
// The last discrete conditional may be a DiscreteTableConditional
|
||||
// The last discrete conditional may be a TableDistribution
|
||||
if (auto dtc =
|
||||
std::dynamic_pointer_cast<DiscreteTableConditional>(conditional)) {
|
||||
std::dynamic_pointer_cast<TableDistribution>(conditional)) {
|
||||
DiscreteConditional dc(dtc->nrFrontals(), dtc->toDecisionTreeFactor());
|
||||
joint = joint * dc;
|
||||
} else {
|
||||
|
@ -133,7 +133,7 @@ HybridValues HybridBayesNet::optimize() const {
|
|||
|
||||
for (auto &&conditional : *this) {
|
||||
if (conditional->isDiscrete()) {
|
||||
if (auto dtc = conditional->asDiscrete<DiscreteTableConditional>()) {
|
||||
if (auto dtc = conditional->asDiscrete<TableDistribution>()) {
|
||||
// The number of keys should be small so should not
|
||||
// be expensive to convert to DiscreteConditional.
|
||||
discrete_fg.push_back(DiscreteConditional(dtc->nrFrontals(),
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
#include <gtsam/base/treeTraversal-inst.h>
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/discrete/DiscreteTableConditional.h>
|
||||
#include <gtsam/discrete/TableDistribution.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/hybrid/HybridBayesTree.h>
|
||||
#include <gtsam/hybrid/HybridConditional.h>
|
||||
|
@ -72,7 +72,7 @@ HybridValues HybridBayesTree::optimize() const {
|
|||
|
||||
// The root should be discrete only, we compute the MPE
|
||||
if (root_conditional->isDiscrete()) {
|
||||
auto discrete = std::dynamic_pointer_cast<DiscreteTableConditional>(
|
||||
auto discrete = std::dynamic_pointer_cast<TableDistribution>(
|
||||
root_conditional->asDiscrete());
|
||||
discrete_fg.push_back(discrete);
|
||||
mpe = discreteMaxProduct(discrete_fg);
|
||||
|
@ -202,7 +202,7 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
|||
/* ************************************************************************* */
|
||||
void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
||||
auto discreteProbs =
|
||||
this->roots_.at(0)->conditional()->asDiscrete<DiscreteTableConditional>();
|
||||
this->roots_.at(0)->conditional()->asDiscrete<TableDistribution>();
|
||||
|
||||
DiscreteConditional::shared_ptr prunedDiscreteProbs =
|
||||
discreteProbs->prune(maxNrLeaves);
|
||||
|
|
|
@ -265,7 +265,7 @@ TableFactor TableProduct(const DiscreteFactorGraph &factors) {
|
|||
for (auto &&factor : factors) {
|
||||
if (factor) {
|
||||
if (auto dtc =
|
||||
std::dynamic_pointer_cast<DiscreteTableConditional>(factor)) {
|
||||
std::dynamic_pointer_cast<TableDistribution>(factor)) {
|
||||
product = product * dtc->table();
|
||||
} else if (auto f = std::dynamic_pointer_cast<TableFactor>(factor)) {
|
||||
product = product * (*f);
|
||||
|
@ -323,7 +323,7 @@ static DiscreteFactorGraph CollectDiscreteFactors(
|
|||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(ConvertConditionalToTableFactor);
|
||||
#endif
|
||||
if (auto dtc = std::dynamic_pointer_cast<DiscreteTableConditional>(dc)) {
|
||||
if (auto dtc = std::dynamic_pointer_cast<TableDistribution>(dc)) {
|
||||
/// Get the underlying TableFactor
|
||||
dfg.push_back(dtc->table());
|
||||
} else {
|
||||
|
@ -364,7 +364,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
|||
gttic_(EliminateDiscreteFormDiscreteConditional);
|
||||
#endif
|
||||
auto conditional =
|
||||
std::make_shared<DiscreteTableConditional>(frontalKeys.size(), product);
|
||||
std::make_shared<TableDistribution>(frontalKeys.size(), product);
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(EliminateDiscreteFormDiscreteConditional);
|
||||
#endif
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/discrete/DiscreteTableConditional.h>
|
||||
#include <gtsam/discrete/TableDistribution.h>
|
||||
#include <gtsam/hybrid/HybridFactor.h>
|
||||
#include <gtsam/hybrid/HybridFactorGraph.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/discrete/DiscreteTableConditional.h>
|
||||
#include <gtsam/discrete/TableDistribution.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/hybrid/HybridGaussianConditional.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
||||
|
@ -80,8 +80,8 @@ TEST(GaussianMixture, GaussianMixtureModel) {
|
|||
double midway = mu1 - mu0;
|
||||
auto eliminationResult =
|
||||
gmm.toFactorGraph({{Z(0), Vector1(midway)}}).eliminateSequential();
|
||||
auto pMid = eliminationResult->at(0)->asDiscrete<DiscreteTableConditional>();
|
||||
EXPECT(assert_equal(DiscreteTableConditional(m, "60/40"), *pMid));
|
||||
auto pMid = eliminationResult->at(0)->asDiscrete<TableDistribution>();
|
||||
EXPECT(assert_equal(TableDistribution(m, "60/40"), *pMid));
|
||||
|
||||
// Everywhere else, the result should be a sigmoid.
|
||||
for (const double shift : {-4, -2, 0, 2, 4}) {
|
||||
|
@ -92,7 +92,7 @@ TEST(GaussianMixture, GaussianMixtureModel) {
|
|||
auto eliminationResult1 =
|
||||
gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential();
|
||||
auto posterior1 =
|
||||
*eliminationResult1->at(0)->asDiscrete<DiscreteTableConditional>();
|
||||
*eliminationResult1->at(0)->asDiscrete<TableDistribution>();
|
||||
EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);
|
||||
|
||||
// Workflow 2: directly specify HFG and solve
|
||||
|
@ -102,7 +102,7 @@ TEST(GaussianMixture, GaussianMixtureModel) {
|
|||
hfg1.push_back(mixing);
|
||||
auto eliminationResult2 = hfg1.eliminateSequential();
|
||||
auto posterior2 =
|
||||
*eliminationResult2->at(0)->asDiscrete<DiscreteTableConditional>();
|
||||
*eliminationResult2->at(0)->asDiscrete<TableDistribution>();
|
||||
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
|
||||
}
|
||||
}
|
||||
|
@ -142,8 +142,8 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
|
|||
eliminationResultMax->discretePosterior(vv)));
|
||||
|
||||
auto pMax =
|
||||
*eliminationResultMax->at(0)->asDiscrete<DiscreteTableConditional>();
|
||||
EXPECT(assert_equal(DiscreteTableConditional(m, "42/58"), pMax, 1e-4));
|
||||
*eliminationResultMax->at(0)->asDiscrete<TableDistribution>();
|
||||
EXPECT(assert_equal(TableDistribution(m, "42/58"), pMax, 1e-4));
|
||||
|
||||
// Everywhere else, the result should be a bell curve like function.
|
||||
for (const double shift : {-4, -2, 0, 2, 4}) {
|
||||
|
@ -154,7 +154,7 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
|
|||
auto eliminationResult1 =
|
||||
gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential();
|
||||
auto posterior1 =
|
||||
*eliminationResult1->at(0)->asDiscrete<DiscreteTableConditional>();
|
||||
*eliminationResult1->at(0)->asDiscrete<TableDistribution>();
|
||||
EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);
|
||||
|
||||
// Workflow 2: directly specify HFG and solve
|
||||
|
@ -164,7 +164,7 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
|
|||
hfg.push_back(mixing);
|
||||
auto eliminationResult2 = hfg.eliminateSequential();
|
||||
auto posterior2 =
|
||||
*eliminationResult2->at(0)->asDiscrete<DiscreteTableConditional>();
|
||||
*eliminationResult2->at(0)->asDiscrete<TableDistribution>();
|
||||
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -450,9 +450,9 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
|||
|
||||
DiscreteConditional joint;
|
||||
for (auto&& conditional : posterior->discreteMarginal()) {
|
||||
// The last discrete conditional may be a DiscreteTableConditional
|
||||
// The last discrete conditional may be a TableDistribution
|
||||
if (auto dtc =
|
||||
std::dynamic_pointer_cast<DiscreteTableConditional>(conditional)) {
|
||||
std::dynamic_pointer_cast<TableDistribution>(conditional)) {
|
||||
DiscreteConditional dc(dtc->nrFrontals(), dtc->toDecisionTreeFactor());
|
||||
joint = joint * dc;
|
||||
} else {
|
||||
|
|
|
@ -464,14 +464,14 @@ TEST(HybridEstimation, EliminateSequentialRegression) {
|
|||
|
||||
// Create expected discrete conditional on m0.
|
||||
DiscreteKey m(M(0), 2);
|
||||
DiscreteTableConditional expected(m % "0.51341712/1"); // regression
|
||||
TableDistribution expected(m % "0.51341712/1"); // regression
|
||||
|
||||
// Eliminate into BN using one ordering
|
||||
const Ordering ordering1{X(0), X(1), M(0)};
|
||||
HybridBayesNet::shared_ptr bn1 = fg->eliminateSequential(ordering1);
|
||||
|
||||
// Check that the discrete conditional matches the expected.
|
||||
auto dc1 = bn1->back()->asDiscrete<DiscreteTableConditional>();
|
||||
auto dc1 = bn1->back()->asDiscrete<TableDistribution>();
|
||||
EXPECT(assert_equal(expected, *dc1, 1e-9));
|
||||
|
||||
// Eliminate into BN using a different ordering
|
||||
|
@ -479,7 +479,7 @@ TEST(HybridEstimation, EliminateSequentialRegression) {
|
|||
HybridBayesNet::shared_ptr bn2 = fg->eliminateSequential(ordering2);
|
||||
|
||||
// Check that the discrete conditional matches the expected.
|
||||
auto dc2 = bn2->back()->asDiscrete<DiscreteTableConditional>();
|
||||
auto dc2 = bn2->back()->asDiscrete<TableDistribution>();
|
||||
EXPECT(assert_equal(expected, *dc2, 1e-9));
|
||||
}
|
||||
|
||||
|
|
|
@ -650,7 +650,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) {
|
|||
mode, std::vector{conditional0, conditional1});
|
||||
|
||||
// Add prior on mode.
|
||||
expectedBayesNet.emplace_shared<DiscreteTableConditional>(mode, "74/26");
|
||||
expectedBayesNet.emplace_shared<TableDistribution>(mode, "74/26");
|
||||
|
||||
// Test elimination
|
||||
const auto posterior = fg.eliminateSequential();
|
||||
|
@ -700,7 +700,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1Swapped) {
|
|||
m1, std::vector{conditional0, conditional1});
|
||||
|
||||
// Add prior on m1.
|
||||
expectedBayesNet.emplace_shared<DiscreteTableConditional>(
|
||||
expectedBayesNet.emplace_shared<TableDistribution>(
|
||||
m1, "0.188638/0.811362");
|
||||
|
||||
// Test elimination
|
||||
|
@ -738,8 +738,8 @@ TEST(HybridGaussianFactorGraph, EliminateTiny2) {
|
|||
|
||||
// Add prior on mode.
|
||||
// Since this is the only discrete conditional, it is added as a
|
||||
// DiscreteTableConditional.
|
||||
expectedBayesNet.emplace_shared<DiscreteTableConditional>(mode, "23/77");
|
||||
// TableDistribution.
|
||||
expectedBayesNet.emplace_shared<TableDistribution>(mode, "23/77");
|
||||
|
||||
// Test elimination
|
||||
const auto posterior = fg.eliminateSequential();
|
||||
|
|
|
@ -142,7 +142,7 @@ TEST(HybridGaussianISAM, IncrementalInference) {
|
|||
|
||||
// Test the probability values with regression tests.
|
||||
auto discrete =
|
||||
isam[M(1)]->conditional()->asDiscrete<DiscreteTableConditional>();
|
||||
isam[M(1)]->conditional()->asDiscrete<TableDistribution>();
|
||||
EXPECT(assert_equal(0.095292, (*discrete)({{M(0), 0}, {M(1), 0}}), 1e-5));
|
||||
EXPECT(assert_equal(0.282758, (*discrete)({{M(0), 1}, {M(1), 0}}), 1e-5));
|
||||
EXPECT(assert_equal(0.314175, (*discrete)({{M(0), 0}, {M(1), 1}}), 1e-5));
|
||||
|
@ -222,7 +222,7 @@ TEST(HybridGaussianISAM, ApproxInference) {
|
|||
1 1 1 Leaf 0.5
|
||||
*/
|
||||
|
||||
auto discreteConditional_m0 = *dynamic_pointer_cast<DiscreteTableConditional>(
|
||||
auto discreteConditional_m0 = *dynamic_pointer_cast<TableDistribution>(
|
||||
incrementalHybrid[M(0)]->conditional()->inner());
|
||||
EXPECT(discreteConditional_m0.keys() == KeyVector({M(0), M(1), M(2)}));
|
||||
|
||||
|
@ -474,7 +474,7 @@ TEST(HybridGaussianISAM, NonTrivial) {
|
|||
|
||||
// Test if the optimal discrete mode assignment is (1, 1, 1).
|
||||
DiscreteFactorGraph discreteGraph;
|
||||
// discreteTree is a DiscreteTableConditional, so we convert to
|
||||
// discreteTree is a TableDistribution, so we convert to
|
||||
// DecisionTreeFactor for the DiscreteFactorGraph
|
||||
discreteGraph.push_back(discreteTree->toDecisionTreeFactor());
|
||||
DiscreteValues optimal_assignment = discreteGraph.optimize();
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/TestableAssertions.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/discrete/DiscreteTableConditional.h>
|
||||
#include <gtsam/discrete/TableDistribution.h>
|
||||
#include <gtsam/discrete/DiscreteValues.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/hybrid/HybridGaussianConditional.h>
|
||||
|
@ -144,9 +144,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel) {
|
|||
// Since no measurement on x1, we hedge our bets
|
||||
// Importance sampling run with 100k samples gives 50.051/49.949
|
||||
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
|
||||
DiscreteTableConditional expected(m1, "50/50");
|
||||
TableDistribution expected(m1, "50/50");
|
||||
EXPECT(assert_equal(expected,
|
||||
*(bn->at(2)->asDiscrete<DiscreteTableConditional>())));
|
||||
*(bn->at(2)->asDiscrete<TableDistribution>())));
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -162,9 +162,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel) {
|
|||
// Since we have a measurement on x1, we get a definite result
|
||||
// Values taken from an importance sampling run with 100k samples:
|
||||
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
|
||||
DiscreteTableConditional expected(m1, "44.3854/55.6146");
|
||||
TableDistribution expected(m1, "44.3854/55.6146");
|
||||
EXPECT(assert_equal(
|
||||
expected, *(bn->at(2)->asDiscrete<DiscreteTableConditional>()), 0.02));
|
||||
expected, *(bn->at(2)->asDiscrete<TableDistribution>()), 0.02));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -251,9 +251,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel2) {
|
|||
|
||||
// Values taken from an importance sampling run with 100k samples:
|
||||
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
|
||||
DiscreteTableConditional expected(m1, "48.3158/51.6842");
|
||||
TableDistribution expected(m1, "48.3158/51.6842");
|
||||
EXPECT(assert_equal(
|
||||
expected, *(eliminated->at(2)->asDiscrete<DiscreteTableConditional>()),
|
||||
expected, *(eliminated->at(2)->asDiscrete<TableDistribution>()),
|
||||
0.02));
|
||||
}
|
||||
|
||||
|
@ -268,9 +268,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel2) {
|
|||
|
||||
// Values taken from an importance sampling run with 100k samples:
|
||||
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
|
||||
DiscreteTableConditional expected(m1, "55.396/44.604");
|
||||
TableDistribution expected(m1, "55.396/44.604");
|
||||
EXPECT(assert_equal(
|
||||
expected, *(bn->at(2)->asDiscrete<DiscreteTableConditional>()), 0.02));
|
||||
expected, *(bn->at(2)->asDiscrete<TableDistribution>()), 0.02));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -346,9 +346,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel3) {
|
|||
|
||||
// Values taken from an importance sampling run with 100k samples:
|
||||
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
|
||||
DiscreteTableConditional expected(m1, "51.7762/48.2238");
|
||||
TableDistribution expected(m1, "51.7762/48.2238");
|
||||
EXPECT(assert_equal(
|
||||
expected, *(bn->at(2)->asDiscrete<DiscreteTableConditional>()), 0.02));
|
||||
expected, *(bn->at(2)->asDiscrete<TableDistribution>()), 0.02));
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -362,9 +362,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel3) {
|
|||
|
||||
// Values taken from an importance sampling run with 100k samples:
|
||||
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
|
||||
DiscreteTableConditional expected(m1, "49.0762/50.9238");
|
||||
TableDistribution expected(m1, "49.0762/50.9238");
|
||||
EXPECT(assert_equal(
|
||||
expected, *(bn->at(2)->asDiscrete<DiscreteTableConditional>()), 0.05));
|
||||
expected, *(bn->at(2)->asDiscrete<TableDistribution>()), 0.05));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -389,9 +389,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel4) {
|
|||
|
||||
// Values taken from an importance sampling run with 100k samples:
|
||||
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
|
||||
DiscreteTableConditional expected(m1, "8.91527/91.0847");
|
||||
TableDistribution expected(m1, "8.91527/91.0847");
|
||||
EXPECT(assert_equal(
|
||||
expected, *(bn->at(2)->asDiscrete<DiscreteTableConditional>()), 0.01));
|
||||
expected, *(bn->at(2)->asDiscrete<TableDistribution>()), 0.01));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -512,7 +512,7 @@ TEST(HybridNonlinearFactorGraph, Full_Elimination) {
|
|||
// P(m1)
|
||||
EXPECT(hybridBayesNet->at(4)->frontals() == KeyVector{M(1)});
|
||||
EXPECT_LONGS_EQUAL(0, hybridBayesNet->at(4)->nrParents());
|
||||
DiscreteTableConditional dtc = *hybridBayesNet->at(4)->asDiscrete<DiscreteTableConditional>();
|
||||
TableDistribution dtc = *hybridBayesNet->at(4)->asDiscrete<TableDistribution>();
|
||||
EXPECT(
|
||||
DiscreteConditional(dtc.nrFrontals(), dtc.toDecisionTreeFactor())
|
||||
.equals(*discreteBayesNet.at(1)));
|
||||
|
|
|
@ -265,7 +265,7 @@ TEST(HybridNonlinearISAM, ApproxInference) {
|
|||
1 1 1 Leaf 0.5
|
||||
*/
|
||||
|
||||
auto discreteConditional_m0 = *dynamic_pointer_cast<DiscreteTableConditional>(
|
||||
auto discreteConditional_m0 = *dynamic_pointer_cast<TableDistribution>(
|
||||
bayesTree[M(0)]->conditional()->inner());
|
||||
EXPECT(discreteConditional_m0.keys() == KeyVector({M(0), M(1), M(2)}));
|
||||
|
||||
|
@ -517,7 +517,7 @@ TEST(HybridNonlinearISAM, NonTrivial) {
|
|||
// The final discrete graph should not be empty since we have eliminated
|
||||
// all continuous variables.
|
||||
auto discreteTree =
|
||||
bayesTree[M(3)]->conditional()->asDiscrete<DiscreteTableConditional>();
|
||||
bayesTree[M(3)]->conditional()->asDiscrete<TableDistribution>();
|
||||
EXPECT_LONGS_EQUAL(3, discreteTree->size());
|
||||
|
||||
// Test if the optimal discrete mode assignment is (1, 1, 1).
|
||||
|
|
Loading…
Reference in New Issue