Rename DiscreteTableConditional to TableDistribution

release/4.3a0
Varun Agrawal 2025-01-03 14:51:32 -05:00
parent 0098112f27
commit 9b1918c085
14 changed files with 95 additions and 95 deletions
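
The commit is essentially a rename: every occurrence of DiscreteTableConditional in the library and its tests becomes TableDistribution, as the symmetric addition/deletion count suggests. As a minimal, hypothetical sketch of what the rename means for calling code (variable names here are illustrative and not part of this commit):

// Before: casting a discrete conditional to the sparse-table type.
// auto dtc = std::dynamic_pointer_cast<DiscreteTableConditional>(conditional);
// After: same call, new class name; the table() accessor is unchanged.
auto dtc = std::dynamic_pointer_cast<TableDistribution>(conditional);
if (dtc) {
  auto table = dtc->table();  // underlying sparse TableFactor
}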

View File

@ -10,14 +10,14 @@
* -------------------------------------------------------------------------- */
/**
* @file DiscreteTableConditional.cpp
* @file TableDistribution.cpp
* @date Dec 22, 2024
* @author Varun Agrawal
*/
#include <gtsam/base/Testable.h>
#include <gtsam/base/debug.h>
#include <gtsam/discrete/DiscreteTableConditional.h>
#include <gtsam/discrete/TableDistribution.h>
#include <gtsam/discrete/Ring.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/hybrid/HybridValues.h>
@ -38,42 +38,42 @@ using std::vector;
namespace gtsam {
/* ************************************************************************** */
DiscreteTableConditional::DiscreteTableConditional(const size_t nrFrontals,
TableDistribution::TableDistribution(const size_t nrFrontals,
const TableFactor& f)
: BaseConditional(nrFrontals, DecisionTreeFactor(f.discreteKeys(), ADT())),
table_(f / (*f.sum(nrFrontals))) {}
/* ************************************************************************** */
DiscreteTableConditional::DiscreteTableConditional(
TableDistribution::TableDistribution(
size_t nrFrontals, const DiscreteKeys& keys,
const Eigen::SparseVector<double>& potentials)
: BaseConditional(nrFrontals, keys, DecisionTreeFactor(keys, ADT())),
table_(TableFactor(keys, potentials)) {}
/* ************************************************************************** */
DiscreteTableConditional::DiscreteTableConditional(const TableFactor& joint,
TableDistribution::TableDistribution(const TableFactor& joint,
const TableFactor& marginal)
: BaseConditional(joint.size() - marginal.size(),
joint.discreteKeys() & marginal.discreteKeys(), ADT()),
table_(joint / marginal) {}
/* ************************************************************************** */
DiscreteTableConditional::DiscreteTableConditional(const TableFactor& joint,
TableDistribution::TableDistribution(const TableFactor& joint,
const TableFactor& marginal,
const Ordering& orderedKeys)
: DiscreteTableConditional(joint, marginal) {
: TableDistribution(joint, marginal) {
keys_.clear();
keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end());
}
/* ************************************************************************** */
DiscreteTableConditional::DiscreteTableConditional(const Signature& signature)
TableDistribution::TableDistribution(const Signature& signature)
: BaseConditional(1, DecisionTreeFactor(DiscreteKeys{{1, 1}}, ADT(1))),
table_(TableFactor(signature.discreteKeys(), signature.cpt())) {}
/* ************************************************************************** */
DiscreteTableConditional DiscreteTableConditional::operator*(
const DiscreteTableConditional& other) const {
TableDistribution TableDistribution::operator*(
const TableDistribution& other) const {
// Take union of frontal keys
std::set<Key> newFrontals;
for (auto&& key : this->frontals()) newFrontals.insert(key);
@ -82,7 +82,7 @@ DiscreteTableConditional DiscreteTableConditional::operator*(
// Check if frontals overlapped
if (nrFrontals() + other.nrFrontals() > newFrontals.size())
throw std::invalid_argument(
"DiscreteTableConditional::operator* called with overlapping frontal "
"TableDistribution::operator* called with overlapping frontal "
"keys.");
// Now, add cardinalities.
@ -106,11 +106,11 @@ DiscreteTableConditional DiscreteTableConditional::operator*(
for (auto&& dk : parents) discreteKeys.push_back(dk);
TableFactor product = this->table_ * other.table();
return DiscreteTableConditional(newFrontals.size(), product);
return TableDistribution(newFrontals.size(), product);
}
/* ************************************************************************** */
void DiscreteTableConditional::print(const string& s,
void TableDistribution::print(const string& s,
const KeyFormatter& formatter) const {
cout << s << " P( ";
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
@ -128,9 +128,9 @@ void DiscreteTableConditional::print(const string& s,
}
/* ************************************************************************** */
bool DiscreteTableConditional::equals(const DiscreteFactor& other,
bool TableDistribution::equals(const DiscreteFactor& other,
double tol) const {
auto dtc = dynamic_cast<const DiscreteTableConditional*>(&other);
auto dtc = dynamic_cast<const TableDistribution*>(&other);
if (!dtc) {
return false;
} else {
@ -142,17 +142,17 @@ bool DiscreteTableConditional::equals(const DiscreteFactor& other,
}
/* ****************************************************************************/
DiscreteConditional::shared_ptr DiscreteTableConditional::max(
DiscreteConditional::shared_ptr TableDistribution::max(
const Ordering& keys) const {
auto m = *table_.max(keys);
return std::make_shared<DiscreteTableConditional>(m.discreteKeys().size(), m);
return std::make_shared<TableDistribution>(m.discreteKeys().size(), m);
}
/* ****************************************************************************/
void DiscreteTableConditional::setData(
void TableDistribution::setData(
const DiscreteConditional::shared_ptr& dc) {
if (auto dtc = std::dynamic_pointer_cast<DiscreteTableConditional>(dc)) {
if (auto dtc = std::dynamic_pointer_cast<TableDistribution>(dc)) {
this->table_ = dtc->table_;
} else {
this->table_ = TableFactor(dc->discreteKeys(), *dc);
@ -160,11 +160,11 @@ void DiscreteTableConditional::setData(
}
/* ****************************************************************************/
DiscreteConditional::shared_ptr DiscreteTableConditional::prune(
DiscreteConditional::shared_ptr TableDistribution::prune(
size_t maxNrAssignments) const {
TableFactor pruned = table_.prune(maxNrAssignments);
return std::make_shared<DiscreteTableConditional>(
return std::make_shared<TableDistribution>(
this->nrFrontals(), this->discreteKeys(), pruned.sparseTable());
}
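
For orientation, prune(maxNrAssignments) above keeps only the largest entries of the underlying sparse table and wraps the result in a new TableDistribution. A hedged usage sketch, where dist is a hypothetical, already-constructed TableDistribution:

// Keep at most the 10 highest-probability assignments (sketch only).
DiscreteConditional::shared_ptr pruned = dist.prune(10);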

View File

@ -10,7 +10,7 @@
* -------------------------------------------------------------------------- */
/**
* @file DiscreteTableConditional.h
* @file TableDistribution.h
* @date Dec 22, 2024
* @author Varun Agrawal
*/
@ -34,7 +34,7 @@ namespace gtsam {
*
* @ingroup discrete
*/
class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
class GTSAM_EXPORT TableDistribution : public DiscreteConditional {
private:
TableFactor table_;
@ -42,7 +42,7 @@ class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
public:
// typedefs needed to play nice with gtsam
typedef DiscreteTableConditional This; ///< Typedef to this class
typedef TableDistribution This; ///< Typedef to this class
typedef std::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
typedef DiscreteConditional
BaseConditional; ///< Typedef to our conditional base class
@ -53,42 +53,42 @@ class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
/// @{
/// Default constructor needed for serialization.
DiscreteTableConditional() {}
TableDistribution() {}
/// Construct from factor, taking the first `nFrontals` keys as frontals.
DiscreteTableConditional(size_t nFrontals, const TableFactor& f);
TableDistribution(size_t nFrontals, const TableFactor& f);
/**
* Construct from DiscreteKeys and SparseVector, taking the first
* `nFrontals` keys as frontals, in the order given.
*/
DiscreteTableConditional(size_t nFrontals, const DiscreteKeys& keys,
TableDistribution(size_t nFrontals, const DiscreteKeys& keys,
const Eigen::SparseVector<double>& potentials);
/** Construct from signature */
explicit DiscreteTableConditional(const Signature& signature);
explicit TableDistribution(const Signature& signature);
/**
* Construct from key, parents, and a Signature::Table specifying the
* conditional probability table (CPT) in 00 01 10 11 order. For
* three-valued, it would be 00 01 02 10 11 12 20 21 22, etc....
*
* Example: DiscreteTableConditional P(D, {B,E}, table);
* Example: TableDistribution P(D, {B,E}, table);
*/
DiscreteTableConditional(const DiscreteKey& key, const DiscreteKeys& parents,
TableDistribution(const DiscreteKey& key, const DiscreteKeys& parents,
const Signature::Table& table)
: DiscreteTableConditional(Signature(key, parents, table)) {}
: TableDistribution(Signature(key, parents, table)) {}
/**
* Construct from key, parents, and a vector<double> specifying the
* conditional probability table (CPT) in 00 01 10 11 order. For
* three-valued, it would be 00 01 02 10 11 12 20 21 22, etc....
*
* Example: DiscreteTableConditional P(D, {B,E}, table);
* Example: TableDistribution P(D, {B,E}, table);
*/
DiscreteTableConditional(const DiscreteKey& key, const DiscreteKeys& parents,
TableDistribution(const DiscreteKey& key, const DiscreteKeys& parents,
const std::vector<double>& table)
: DiscreteTableConditional(
: TableDistribution(
1, TableFactor(DiscreteKeys{key} & parents, table)) {}
/**
@ -98,21 +98,21 @@ class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
*
* The string is parsed into a Signature::Table.
*
* Example: DiscreteTableConditional P(D, {B,E}, "9/1 2/8 3/7 1/9");
* Example: TableDistribution P(D, {B,E}, "9/1 2/8 3/7 1/9");
*/
DiscreteTableConditional(const DiscreteKey& key, const DiscreteKeys& parents,
TableDistribution(const DiscreteKey& key, const DiscreteKeys& parents,
const std::string& spec)
: DiscreteTableConditional(Signature(key, parents, spec)) {}
: TableDistribution(Signature(key, parents, spec)) {}
/// No-parent specialization; can also use DiscreteDistribution.
DiscreteTableConditional(const DiscreteKey& key, const std::string& spec)
: DiscreteTableConditional(Signature(key, {}, spec)) {}
TableDistribution(const DiscreteKey& key, const std::string& spec)
: TableDistribution(Signature(key, {}, spec)) {}
/**
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
*/
DiscreteTableConditional(const TableFactor& joint,
TableDistribution(const TableFactor& joint,
const TableFactor& marginal);
/**
@ -120,7 +120,7 @@ class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
* Makes sure the keys are ordered as given. Does not check orderedKeys.
*/
DiscreteTableConditional(const TableFactor& joint,
TableDistribution(const TableFactor& joint,
const TableFactor& marginal,
const Ordering& orderedKeys);
@ -139,8 +139,8 @@ class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
* P(A|B) * P(B|A) = ?
* We check for overlapping frontals, but do *not* check for cyclic.
*/
DiscreteTableConditional operator*(
const DiscreteTableConditional& other) const;
TableDistribution operator*(
const TableDistribution& other) const;
/// @}
/// @name Testable
@ -210,11 +210,11 @@ class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
}
#endif
};
// DiscreteTableConditional
// TableDistribution
// traits
template <>
struct traits<DiscreteTableConditional>
: public Testable<DiscreteTableConditional> {};
struct traits<TableDistribution>
: public Testable<TableDistribution> {};
} // namespace gtsam
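
To make the constructor overloads above concrete, here is a minimal, hedged sketch; the spec-string form and the call operator follow the patterns used in the tests later in this diff, and the variable names are illustrative only:

// Binary variable with integer key 0 and cardinality 2.
DiscreteKey m(0, 2);
// No-parent specialization (can also use DiscreteDistribution).
TableDistribution prior(m, "74/26");
// Evaluate P(m = 0) through the DiscreteConditional interface.
double p0 = prior({{m.first, 0}});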

View File

@ -52,9 +52,9 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
// Multiply into one big conditional. NOTE: possibly quite expensive.
DiscreteConditional joint;
for (auto &&conditional : marginal) {
// The last discrete conditional may be a DiscreteTableConditional
// The last discrete conditional may be a TableDistribution
if (auto dtc =
std::dynamic_pointer_cast<DiscreteTableConditional>(conditional)) {
std::dynamic_pointer_cast<TableDistribution>(conditional)) {
DiscreteConditional dc(dtc->nrFrontals(), dtc->toDecisionTreeFactor());
joint = joint * dc;
} else {
@ -133,7 +133,7 @@ HybridValues HybridBayesNet::optimize() const {
for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
if (auto dtc = conditional->asDiscrete<DiscreteTableConditional>()) {
if (auto dtc = conditional->asDiscrete<TableDistribution>()) {
// The number of keys should be small so should not
// be expensive to convert to DiscreteConditional.
discrete_fg.push_back(DiscreteConditional(dtc->nrFrontals(),
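
The pattern shown here, casting to TableDistribution and rebuilding a dense DiscreteConditional from toDecisionTreeFactor(), recurs throughout the hybrid code in this commit. A condensed sketch of just that conversion (conditional is a hypothetical DiscreteConditional::shared_ptr):

if (auto dtc = std::dynamic_pointer_cast<TableDistribution>(conditional)) {
  // Rebuild a dense conditional from the sparse table; as the comment above
  // notes, this is cheap when the number of keys is small.
  DiscreteConditional dc(dtc->nrFrontals(), dtc->toDecisionTreeFactor());
}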

View File

@ -20,7 +20,7 @@
#include <gtsam/base/treeTraversal-inst.h>
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteTableConditional.h>
#include <gtsam/discrete/TableDistribution.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/hybrid/HybridConditional.h>
@ -72,7 +72,7 @@ HybridValues HybridBayesTree::optimize() const {
// The root should be discrete only, we compute the MPE
if (root_conditional->isDiscrete()) {
auto discrete = std::dynamic_pointer_cast<DiscreteTableConditional>(
auto discrete = std::dynamic_pointer_cast<TableDistribution>(
root_conditional->asDiscrete());
discrete_fg.push_back(discrete);
mpe = discreteMaxProduct(discrete_fg);
@ -202,7 +202,7 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
/* ************************************************************************* */
void HybridBayesTree::prune(const size_t maxNrLeaves) {
auto discreteProbs =
this->roots_.at(0)->conditional()->asDiscrete<DiscreteTableConditional>();
this->roots_.at(0)->conditional()->asDiscrete<TableDistribution>();
DiscreteConditional::shared_ptr prunedDiscreteProbs =
discreteProbs->prune(maxNrLeaves);

View File

@ -265,7 +265,7 @@ TableFactor TableProduct(const DiscreteFactorGraph &factors) {
for (auto &&factor : factors) {
if (factor) {
if (auto dtc =
std::dynamic_pointer_cast<DiscreteTableConditional>(factor)) {
std::dynamic_pointer_cast<TableDistribution>(factor)) {
product = product * dtc->table();
} else if (auto f = std::dynamic_pointer_cast<TableFactor>(factor)) {
product = product * (*f);
@ -323,7 +323,7 @@ static DiscreteFactorGraph CollectDiscreteFactors(
#if GTSAM_HYBRID_TIMING
gttic_(ConvertConditionalToTableFactor);
#endif
if (auto dtc = std::dynamic_pointer_cast<DiscreteTableConditional>(dc)) {
if (auto dtc = std::dynamic_pointer_cast<TableDistribution>(dc)) {
/// Get the underlying TableFactor
dfg.push_back(dtc->table());
} else {
@ -364,7 +364,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
gttic_(EliminateDiscreteFormDiscreteConditional);
#endif
auto conditional =
std::make_shared<DiscreteTableConditional>(frontalKeys.size(), product);
std::make_shared<TableDistribution>(frontalKeys.size(), product);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteFormDiscreteConditional);
#endif

View File

@ -20,7 +20,7 @@
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteTableConditional.h>
#include <gtsam/discrete/TableDistribution.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridFactorGraph.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>

View File

@ -20,7 +20,7 @@
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteTableConditional.h>
#include <gtsam/discrete/TableDistribution.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridGaussianConditional.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
@ -80,8 +80,8 @@ TEST(GaussianMixture, GaussianMixtureModel) {
double midway = mu1 - mu0;
auto eliminationResult =
gmm.toFactorGraph({{Z(0), Vector1(midway)}}).eliminateSequential();
auto pMid = eliminationResult->at(0)->asDiscrete<DiscreteTableConditional>();
EXPECT(assert_equal(DiscreteTableConditional(m, "60/40"), *pMid));
auto pMid = eliminationResult->at(0)->asDiscrete<TableDistribution>();
EXPECT(assert_equal(TableDistribution(m, "60/40"), *pMid));
// Everywhere else, the result should be a sigmoid.
for (const double shift : {-4, -2, 0, 2, 4}) {
@ -92,7 +92,7 @@ TEST(GaussianMixture, GaussianMixtureModel) {
auto eliminationResult1 =
gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential();
auto posterior1 =
*eliminationResult1->at(0)->asDiscrete<DiscreteTableConditional>();
*eliminationResult1->at(0)->asDiscrete<TableDistribution>();
EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);
// Workflow 2: directly specify HFG and solve
@ -102,7 +102,7 @@ TEST(GaussianMixture, GaussianMixtureModel) {
hfg1.push_back(mixing);
auto eliminationResult2 = hfg1.eliminateSequential();
auto posterior2 =
*eliminationResult2->at(0)->asDiscrete<DiscreteTableConditional>();
*eliminationResult2->at(0)->asDiscrete<TableDistribution>();
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
}
}
@ -142,8 +142,8 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
eliminationResultMax->discretePosterior(vv)));
auto pMax =
*eliminationResultMax->at(0)->asDiscrete<DiscreteTableConditional>();
EXPECT(assert_equal(DiscreteTableConditional(m, "42/58"), pMax, 1e-4));
*eliminationResultMax->at(0)->asDiscrete<TableDistribution>();
EXPECT(assert_equal(TableDistribution(m, "42/58"), pMax, 1e-4));
// Everywhere else, the result should be a bell curve like function.
for (const double shift : {-4, -2, 0, 2, 4}) {
@ -154,7 +154,7 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
auto eliminationResult1 =
gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential();
auto posterior1 =
*eliminationResult1->at(0)->asDiscrete<DiscreteTableConditional>();
*eliminationResult1->at(0)->asDiscrete<TableDistribution>();
EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);
// Workflow 2: directly specify HFG and solve
@ -164,7 +164,7 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
hfg.push_back(mixing);
auto eliminationResult2 = hfg.eliminateSequential();
auto posterior2 =
*eliminationResult2->at(0)->asDiscrete<DiscreteTableConditional>();
*eliminationResult2->at(0)->asDiscrete<TableDistribution>();
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
}
}

View File

@ -450,9 +450,9 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
DiscreteConditional joint;
for (auto&& conditional : posterior->discreteMarginal()) {
// The last discrete conditional may be a DiscreteTableConditional
// The last discrete conditional may be a TableDistribution
if (auto dtc =
std::dynamic_pointer_cast<DiscreteTableConditional>(conditional)) {
std::dynamic_pointer_cast<TableDistribution>(conditional)) {
DiscreteConditional dc(dtc->nrFrontals(), dtc->toDecisionTreeFactor());
joint = joint * dc;
} else {

View File

@ -464,14 +464,14 @@ TEST(HybridEstimation, EliminateSequentialRegression) {
// Create expected discrete conditional on m0.
DiscreteKey m(M(0), 2);
DiscreteTableConditional expected(m % "0.51341712/1"); // regression
TableDistribution expected(m % "0.51341712/1"); // regression
// Eliminate into BN using one ordering
const Ordering ordering1{X(0), X(1), M(0)};
HybridBayesNet::shared_ptr bn1 = fg->eliminateSequential(ordering1);
// Check that the discrete conditional matches the expected.
auto dc1 = bn1->back()->asDiscrete<DiscreteTableConditional>();
auto dc1 = bn1->back()->asDiscrete<TableDistribution>();
EXPECT(assert_equal(expected, *dc1, 1e-9));
// Eliminate into BN using a different ordering
@ -479,7 +479,7 @@ TEST(HybridEstimation, EliminateSequentialRegression) {
HybridBayesNet::shared_ptr bn2 = fg->eliminateSequential(ordering2);
// Check that the discrete conditional matches the expected.
auto dc2 = bn2->back()->asDiscrete<DiscreteTableConditional>();
auto dc2 = bn2->back()->asDiscrete<TableDistribution>();
EXPECT(assert_equal(expected, *dc2, 1e-9));
}

View File

@ -650,7 +650,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) {
mode, std::vector{conditional0, conditional1});
// Add prior on mode.
expectedBayesNet.emplace_shared<DiscreteTableConditional>(mode, "74/26");
expectedBayesNet.emplace_shared<TableDistribution>(mode, "74/26");
// Test elimination
const auto posterior = fg.eliminateSequential();
@ -700,7 +700,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1Swapped) {
m1, std::vector{conditional0, conditional1});
// Add prior on m1.
expectedBayesNet.emplace_shared<DiscreteTableConditional>(
expectedBayesNet.emplace_shared<TableDistribution>(
m1, "0.188638/0.811362");
// Test elimination
@ -738,8 +738,8 @@ TEST(HybridGaussianFactorGraph, EliminateTiny2) {
// Add prior on mode.
// Since this is the only discrete conditional, it is added as a
// DiscreteTableConditional.
expectedBayesNet.emplace_shared<DiscreteTableConditional>(mode, "23/77");
// TableDistribution.
expectedBayesNet.emplace_shared<TableDistribution>(mode, "23/77");
// Test elimination
const auto posterior = fg.eliminateSequential();

View File

@ -142,7 +142,7 @@ TEST(HybridGaussianISAM, IncrementalInference) {
// Test the probability values with regression tests.
auto discrete =
isam[M(1)]->conditional()->asDiscrete<DiscreteTableConditional>();
isam[M(1)]->conditional()->asDiscrete<TableDistribution>();
EXPECT(assert_equal(0.095292, (*discrete)({{M(0), 0}, {M(1), 0}}), 1e-5));
EXPECT(assert_equal(0.282758, (*discrete)({{M(0), 1}, {M(1), 0}}), 1e-5));
EXPECT(assert_equal(0.314175, (*discrete)({{M(0), 0}, {M(1), 1}}), 1e-5));
@ -222,7 +222,7 @@ TEST(HybridGaussianISAM, ApproxInference) {
1 1 1 Leaf 0.5
*/
auto discreteConditional_m0 = *dynamic_pointer_cast<DiscreteTableConditional>(
auto discreteConditional_m0 = *dynamic_pointer_cast<TableDistribution>(
incrementalHybrid[M(0)]->conditional()->inner());
EXPECT(discreteConditional_m0.keys() == KeyVector({M(0), M(1), M(2)}));
@ -474,7 +474,7 @@ TEST(HybridGaussianISAM, NonTrivial) {
// Test if the optimal discrete mode assignment is (1, 1, 1).
DiscreteFactorGraph discreteGraph;
// discreteTree is a DiscreteTableConditional, so we convert to
// discreteTree is a TableDistribution, so we convert to
// DecisionTreeFactor for the DiscreteFactorGraph
discreteGraph.push_back(discreteTree->toDecisionTreeFactor());
DiscreteValues optimal_assignment = discreteGraph.optimize();

View File

@ -21,7 +21,7 @@
#include <gtsam/base/Testable.h>
#include <gtsam/base/TestableAssertions.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteTableConditional.h>
#include <gtsam/discrete/TableDistribution.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridGaussianConditional.h>
@ -144,9 +144,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel) {
// Since no measurement on x1, we hedge our bets
// Importance sampling run with 100k samples gives 50.051/49.949
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
DiscreteTableConditional expected(m1, "50/50");
TableDistribution expected(m1, "50/50");
EXPECT(assert_equal(expected,
*(bn->at(2)->asDiscrete<DiscreteTableConditional>())));
*(bn->at(2)->asDiscrete<TableDistribution>())));
}
{
@ -162,9 +162,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel) {
// Since we have a measurement on x1, we get a definite result
// Values taken from an importance sampling run with 100k samples:
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
DiscreteTableConditional expected(m1, "44.3854/55.6146");
TableDistribution expected(m1, "44.3854/55.6146");
EXPECT(assert_equal(
expected, *(bn->at(2)->asDiscrete<DiscreteTableConditional>()), 0.02));
expected, *(bn->at(2)->asDiscrete<TableDistribution>()), 0.02));
}
}
@ -251,9 +251,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel2) {
// Values taken from an importance sampling run with 100k samples:
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
DiscreteTableConditional expected(m1, "48.3158/51.6842");
TableDistribution expected(m1, "48.3158/51.6842");
EXPECT(assert_equal(
expected, *(eliminated->at(2)->asDiscrete<DiscreteTableConditional>()),
expected, *(eliminated->at(2)->asDiscrete<TableDistribution>()),
0.02));
}
@ -268,9 +268,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel2) {
// Values taken from an importance sampling run with 100k samples:
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
DiscreteTableConditional expected(m1, "55.396/44.604");
TableDistribution expected(m1, "55.396/44.604");
EXPECT(assert_equal(
expected, *(bn->at(2)->asDiscrete<DiscreteTableConditional>()), 0.02));
expected, *(bn->at(2)->asDiscrete<TableDistribution>()), 0.02));
}
}
@ -346,9 +346,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel3) {
// Values taken from an importance sampling run with 100k samples:
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
DiscreteTableConditional expected(m1, "51.7762/48.2238");
TableDistribution expected(m1, "51.7762/48.2238");
EXPECT(assert_equal(
expected, *(bn->at(2)->asDiscrete<DiscreteTableConditional>()), 0.02));
expected, *(bn->at(2)->asDiscrete<TableDistribution>()), 0.02));
}
{
@ -362,9 +362,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel3) {
// Values taken from an importance sampling run with 100k samples:
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
DiscreteTableConditional expected(m1, "49.0762/50.9238");
TableDistribution expected(m1, "49.0762/50.9238");
EXPECT(assert_equal(
expected, *(bn->at(2)->asDiscrete<DiscreteTableConditional>()), 0.05));
expected, *(bn->at(2)->asDiscrete<TableDistribution>()), 0.05));
}
}
@ -389,9 +389,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel4) {
// Values taken from an importance sampling run with 100k samples:
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
DiscreteTableConditional expected(m1, "8.91527/91.0847");
TableDistribution expected(m1, "8.91527/91.0847");
EXPECT(assert_equal(
expected, *(bn->at(2)->asDiscrete<DiscreteTableConditional>()), 0.01));
expected, *(bn->at(2)->asDiscrete<TableDistribution>()), 0.01));
}
/* ************************************************************************* */

View File

@ -512,7 +512,7 @@ TEST(HybridNonlinearFactorGraph, Full_Elimination) {
// P(m1)
EXPECT(hybridBayesNet->at(4)->frontals() == KeyVector{M(1)});
EXPECT_LONGS_EQUAL(0, hybridBayesNet->at(4)->nrParents());
DiscreteTableConditional dtc = *hybridBayesNet->at(4)->asDiscrete<DiscreteTableConditional>();
TableDistribution dtc = *hybridBayesNet->at(4)->asDiscrete<TableDistribution>();
EXPECT(
DiscreteConditional(dtc.nrFrontals(), dtc.toDecisionTreeFactor())
.equals(*discreteBayesNet.at(1)));

View File

@ -265,7 +265,7 @@ TEST(HybridNonlinearISAM, ApproxInference) {
1 1 1 Leaf 0.5
*/
auto discreteConditional_m0 = *dynamic_pointer_cast<DiscreteTableConditional>(
auto discreteConditional_m0 = *dynamic_pointer_cast<TableDistribution>(
bayesTree[M(0)]->conditional()->inner());
EXPECT(discreteConditional_m0.keys() == KeyVector({M(0), M(1), M(2)}));
@ -517,7 +517,7 @@ TEST(HybridNonlinearISAM, NonTrivial) {
// The final discrete graph should not be empty since we have eliminated
// all continuous variables.
auto discreteTree =
bayesTree[M(3)]->conditional()->asDiscrete<DiscreteTableConditional>();
bayesTree[M(3)]->conditional()->asDiscrete<TableDistribution>();
EXPECT_LONGS_EQUAL(3, discreteTree->size());
// Test if the optimal discrete mode assignment is (1, 1, 1).