rename from DiscreteTableConditional to TableDistribution
parent
0098112f27
commit
9b1918c085
|
@ -10,14 +10,14 @@
|
||||||
* -------------------------------------------------------------------------- */
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @file DiscreteTableConditional.cpp
|
* @file TableDistribution.cpp
|
||||||
* @date Dec 22, 2024
|
* @date Dec 22, 2024
|
||||||
* @author Varun Agrawal
|
* @author Varun Agrawal
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/base/Testable.h>
|
#include <gtsam/base/Testable.h>
|
||||||
#include <gtsam/base/debug.h>
|
#include <gtsam/base/debug.h>
|
||||||
#include <gtsam/discrete/DiscreteTableConditional.h>
|
#include <gtsam/discrete/TableDistribution.h>
|
||||||
#include <gtsam/discrete/Ring.h>
|
#include <gtsam/discrete/Ring.h>
|
||||||
#include <gtsam/discrete/Signature.h>
|
#include <gtsam/discrete/Signature.h>
|
||||||
#include <gtsam/hybrid/HybridValues.h>
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
|
@ -38,42 +38,42 @@ using std::vector;
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteTableConditional::DiscreteTableConditional(const size_t nrFrontals,
|
TableDistribution::TableDistribution(const size_t nrFrontals,
|
||||||
const TableFactor& f)
|
const TableFactor& f)
|
||||||
: BaseConditional(nrFrontals, DecisionTreeFactor(f.discreteKeys(), ADT())),
|
: BaseConditional(nrFrontals, DecisionTreeFactor(f.discreteKeys(), ADT())),
|
||||||
table_(f / (*f.sum(nrFrontals))) {}
|
table_(f / (*f.sum(nrFrontals))) {}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteTableConditional::DiscreteTableConditional(
|
TableDistribution::TableDistribution(
|
||||||
size_t nrFrontals, const DiscreteKeys& keys,
|
size_t nrFrontals, const DiscreteKeys& keys,
|
||||||
const Eigen::SparseVector<double>& potentials)
|
const Eigen::SparseVector<double>& potentials)
|
||||||
: BaseConditional(nrFrontals, keys, DecisionTreeFactor(keys, ADT())),
|
: BaseConditional(nrFrontals, keys, DecisionTreeFactor(keys, ADT())),
|
||||||
table_(TableFactor(keys, potentials)) {}
|
table_(TableFactor(keys, potentials)) {}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteTableConditional::DiscreteTableConditional(const TableFactor& joint,
|
TableDistribution::TableDistribution(const TableFactor& joint,
|
||||||
const TableFactor& marginal)
|
const TableFactor& marginal)
|
||||||
: BaseConditional(joint.size() - marginal.size(),
|
: BaseConditional(joint.size() - marginal.size(),
|
||||||
joint.discreteKeys() & marginal.discreteKeys(), ADT()),
|
joint.discreteKeys() & marginal.discreteKeys(), ADT()),
|
||||||
table_(joint / marginal) {}
|
table_(joint / marginal) {}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteTableConditional::DiscreteTableConditional(const TableFactor& joint,
|
TableDistribution::TableDistribution(const TableFactor& joint,
|
||||||
const TableFactor& marginal,
|
const TableFactor& marginal,
|
||||||
const Ordering& orderedKeys)
|
const Ordering& orderedKeys)
|
||||||
: DiscreteTableConditional(joint, marginal) {
|
: TableDistribution(joint, marginal) {
|
||||||
keys_.clear();
|
keys_.clear();
|
||||||
keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end());
|
keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteTableConditional::DiscreteTableConditional(const Signature& signature)
|
TableDistribution::TableDistribution(const Signature& signature)
|
||||||
: BaseConditional(1, DecisionTreeFactor(DiscreteKeys{{1, 1}}, ADT(1))),
|
: BaseConditional(1, DecisionTreeFactor(DiscreteKeys{{1, 1}}, ADT(1))),
|
||||||
table_(TableFactor(signature.discreteKeys(), signature.cpt())) {}
|
table_(TableFactor(signature.discreteKeys(), signature.cpt())) {}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteTableConditional DiscreteTableConditional::operator*(
|
TableDistribution TableDistribution::operator*(
|
||||||
const DiscreteTableConditional& other) const {
|
const TableDistribution& other) const {
|
||||||
// Take union of frontal keys
|
// Take union of frontal keys
|
||||||
std::set<Key> newFrontals;
|
std::set<Key> newFrontals;
|
||||||
for (auto&& key : this->frontals()) newFrontals.insert(key);
|
for (auto&& key : this->frontals()) newFrontals.insert(key);
|
||||||
|
@ -82,7 +82,7 @@ DiscreteTableConditional DiscreteTableConditional::operator*(
|
||||||
// Check if frontals overlapped
|
// Check if frontals overlapped
|
||||||
if (nrFrontals() + other.nrFrontals() > newFrontals.size())
|
if (nrFrontals() + other.nrFrontals() > newFrontals.size())
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"DiscreteTableConditional::operator* called with overlapping frontal "
|
"TableDistribution::operator* called with overlapping frontal "
|
||||||
"keys.");
|
"keys.");
|
||||||
|
|
||||||
// Now, add cardinalities.
|
// Now, add cardinalities.
|
||||||
|
@ -106,11 +106,11 @@ DiscreteTableConditional DiscreteTableConditional::operator*(
|
||||||
for (auto&& dk : parents) discreteKeys.push_back(dk);
|
for (auto&& dk : parents) discreteKeys.push_back(dk);
|
||||||
|
|
||||||
TableFactor product = this->table_ * other.table();
|
TableFactor product = this->table_ * other.table();
|
||||||
return DiscreteTableConditional(newFrontals.size(), product);
|
return TableDistribution(newFrontals.size(), product);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
void DiscreteTableConditional::print(const string& s,
|
void TableDistribution::print(const string& s,
|
||||||
const KeyFormatter& formatter) const {
|
const KeyFormatter& formatter) const {
|
||||||
cout << s << " P( ";
|
cout << s << " P( ";
|
||||||
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
|
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
|
||||||
|
@ -128,9 +128,9 @@ void DiscreteTableConditional::print(const string& s,
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
bool DiscreteTableConditional::equals(const DiscreteFactor& other,
|
bool TableDistribution::equals(const DiscreteFactor& other,
|
||||||
double tol) const {
|
double tol) const {
|
||||||
auto dtc = dynamic_cast<const DiscreteTableConditional*>(&other);
|
auto dtc = dynamic_cast<const TableDistribution*>(&other);
|
||||||
if (!dtc) {
|
if (!dtc) {
|
||||||
return false;
|
return false;
|
||||||
} else {
|
} else {
|
||||||
|
@ -142,17 +142,17 @@ bool DiscreteTableConditional::equals(const DiscreteFactor& other,
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
DiscreteConditional::shared_ptr DiscreteTableConditional::max(
|
DiscreteConditional::shared_ptr TableDistribution::max(
|
||||||
const Ordering& keys) const {
|
const Ordering& keys) const {
|
||||||
auto m = *table_.max(keys);
|
auto m = *table_.max(keys);
|
||||||
|
|
||||||
return std::make_shared<DiscreteTableConditional>(m.discreteKeys().size(), m);
|
return std::make_shared<TableDistribution>(m.discreteKeys().size(), m);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
void DiscreteTableConditional::setData(
|
void TableDistribution::setData(
|
||||||
const DiscreteConditional::shared_ptr& dc) {
|
const DiscreteConditional::shared_ptr& dc) {
|
||||||
if (auto dtc = std::dynamic_pointer_cast<DiscreteTableConditional>(dc)) {
|
if (auto dtc = std::dynamic_pointer_cast<TableDistribution>(dc)) {
|
||||||
this->table_ = dtc->table_;
|
this->table_ = dtc->table_;
|
||||||
} else {
|
} else {
|
||||||
this->table_ = TableFactor(dc->discreteKeys(), *dc);
|
this->table_ = TableFactor(dc->discreteKeys(), *dc);
|
||||||
|
@ -160,11 +160,11 @@ void DiscreteTableConditional::setData(
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
DiscreteConditional::shared_ptr DiscreteTableConditional::prune(
|
DiscreteConditional::shared_ptr TableDistribution::prune(
|
||||||
size_t maxNrAssignments) const {
|
size_t maxNrAssignments) const {
|
||||||
TableFactor pruned = table_.prune(maxNrAssignments);
|
TableFactor pruned = table_.prune(maxNrAssignments);
|
||||||
|
|
||||||
return std::make_shared<DiscreteTableConditional>(
|
return std::make_shared<TableDistribution>(
|
||||||
this->nrFrontals(), this->discreteKeys(), pruned.sparseTable());
|
this->nrFrontals(), this->discreteKeys(), pruned.sparseTable());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,7 @@
|
||||||
* -------------------------------------------------------------------------- */
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @file DiscreteTableConditional.h
|
* @file TableDistribution.h
|
||||||
* @date Dec 22, 2024
|
* @date Dec 22, 2024
|
||||||
* @author Varun Agrawal
|
* @author Varun Agrawal
|
||||||
*/
|
*/
|
||||||
|
@ -34,7 +34,7 @@ namespace gtsam {
|
||||||
*
|
*
|
||||||
* @ingroup discrete
|
* @ingroup discrete
|
||||||
*/
|
*/
|
||||||
class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
|
class GTSAM_EXPORT TableDistribution : public DiscreteConditional {
|
||||||
private:
|
private:
|
||||||
TableFactor table_;
|
TableFactor table_;
|
||||||
|
|
||||||
|
@ -42,7 +42,7 @@ class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
// typedefs needed to play nice with gtsam
|
// typedefs needed to play nice with gtsam
|
||||||
typedef DiscreteTableConditional This; ///< Typedef to this class
|
typedef TableDistribution This; ///< Typedef to this class
|
||||||
typedef std::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
|
typedef std::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
|
||||||
typedef DiscreteConditional
|
typedef DiscreteConditional
|
||||||
BaseConditional; ///< Typedef to our conditional base class
|
BaseConditional; ///< Typedef to our conditional base class
|
||||||
|
@ -53,42 +53,42 @@ class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// Default constructor needed for serialization.
|
/// Default constructor needed for serialization.
|
||||||
DiscreteTableConditional() {}
|
TableDistribution() {}
|
||||||
|
|
||||||
/// Construct from factor, taking the first `nFrontals` keys as frontals.
|
/// Construct from factor, taking the first `nFrontals` keys as frontals.
|
||||||
DiscreteTableConditional(size_t nFrontals, const TableFactor& f);
|
TableDistribution(size_t nFrontals, const TableFactor& f);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Construct from DiscreteKeys and SparseVector, taking the first
|
* Construct from DiscreteKeys and SparseVector, taking the first
|
||||||
* `nFrontals` keys as frontals, in the order given.
|
* `nFrontals` keys as frontals, in the order given.
|
||||||
*/
|
*/
|
||||||
DiscreteTableConditional(size_t nFrontals, const DiscreteKeys& keys,
|
TableDistribution(size_t nFrontals, const DiscreteKeys& keys,
|
||||||
const Eigen::SparseVector<double>& potentials);
|
const Eigen::SparseVector<double>& potentials);
|
||||||
|
|
||||||
/** Construct from signature */
|
/** Construct from signature */
|
||||||
explicit DiscreteTableConditional(const Signature& signature);
|
explicit TableDistribution(const Signature& signature);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Construct from key, parents, and a Signature::Table specifying the
|
* Construct from key, parents, and a Signature::Table specifying the
|
||||||
* conditional probability table (CPT) in 00 01 10 11 order. For
|
* conditional probability table (CPT) in 00 01 10 11 order. For
|
||||||
* three-valued, it would be 00 01 02 10 11 12 20 21 22, etc....
|
* three-valued, it would be 00 01 02 10 11 12 20 21 22, etc....
|
||||||
*
|
*
|
||||||
* Example: DiscreteTableConditional P(D, {B,E}, table);
|
* Example: TableDistribution P(D, {B,E}, table);
|
||||||
*/
|
*/
|
||||||
DiscreteTableConditional(const DiscreteKey& key, const DiscreteKeys& parents,
|
TableDistribution(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||||
const Signature::Table& table)
|
const Signature::Table& table)
|
||||||
: DiscreteTableConditional(Signature(key, parents, table)) {}
|
: TableDistribution(Signature(key, parents, table)) {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Construct from key, parents, and a vector<double> specifying the
|
* Construct from key, parents, and a vector<double> specifying the
|
||||||
* conditional probability table (CPT) in 00 01 10 11 order. For
|
* conditional probability table (CPT) in 00 01 10 11 order. For
|
||||||
* three-valued, it would be 00 01 02 10 11 12 20 21 22, etc....
|
* three-valued, it would be 00 01 02 10 11 12 20 21 22, etc....
|
||||||
*
|
*
|
||||||
* Example: DiscreteTableConditional P(D, {B,E}, table);
|
* Example: TableDistribution P(D, {B,E}, table);
|
||||||
*/
|
*/
|
||||||
DiscreteTableConditional(const DiscreteKey& key, const DiscreteKeys& parents,
|
TableDistribution(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||||
const std::vector<double>& table)
|
const std::vector<double>& table)
|
||||||
: DiscreteTableConditional(
|
: TableDistribution(
|
||||||
1, TableFactor(DiscreteKeys{key} & parents, table)) {}
|
1, TableFactor(DiscreteKeys{key} & parents, table)) {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -98,21 +98,21 @@ class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
|
||||||
*
|
*
|
||||||
* The string is parsed into a Signature::Table.
|
* The string is parsed into a Signature::Table.
|
||||||
*
|
*
|
||||||
* Example: DiscreteTableConditional P(D, {B,E}, "9/1 2/8 3/7 1/9");
|
* Example: TableDistribution P(D, {B,E}, "9/1 2/8 3/7 1/9");
|
||||||
*/
|
*/
|
||||||
DiscreteTableConditional(const DiscreteKey& key, const DiscreteKeys& parents,
|
TableDistribution(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||||
const std::string& spec)
|
const std::string& spec)
|
||||||
: DiscreteTableConditional(Signature(key, parents, spec)) {}
|
: TableDistribution(Signature(key, parents, spec)) {}
|
||||||
|
|
||||||
/// No-parent specialization; can also use DiscreteDistribution.
|
/// No-parent specialization; can also use DiscreteDistribution.
|
||||||
DiscreteTableConditional(const DiscreteKey& key, const std::string& spec)
|
TableDistribution(const DiscreteKey& key, const std::string& spec)
|
||||||
: DiscreteTableConditional(Signature(key, {}, spec)) {}
|
: TableDistribution(Signature(key, {}, spec)) {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
|
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
|
||||||
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
|
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
|
||||||
*/
|
*/
|
||||||
DiscreteTableConditional(const TableFactor& joint,
|
TableDistribution(const TableFactor& joint,
|
||||||
const TableFactor& marginal);
|
const TableFactor& marginal);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -120,7 +120,7 @@ class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
|
||||||
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
|
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
|
||||||
* Makes sure the keys are ordered as given. Does not check orderedKeys.
|
* Makes sure the keys are ordered as given. Does not check orderedKeys.
|
||||||
*/
|
*/
|
||||||
DiscreteTableConditional(const TableFactor& joint,
|
TableDistribution(const TableFactor& joint,
|
||||||
const TableFactor& marginal,
|
const TableFactor& marginal,
|
||||||
const Ordering& orderedKeys);
|
const Ordering& orderedKeys);
|
||||||
|
|
||||||
|
@ -139,8 +139,8 @@ class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
|
||||||
* P(A|B) * P(B|A) = ?
|
* P(A|B) * P(B|A) = ?
|
||||||
* We check for overlapping frontals, but do *not* check for cyclic.
|
* We check for overlapping frontals, but do *not* check for cyclic.
|
||||||
*/
|
*/
|
||||||
DiscreteTableConditional operator*(
|
TableDistribution operator*(
|
||||||
const DiscreteTableConditional& other) const;
|
const TableDistribution& other) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
|
@ -210,11 +210,11 @@ class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
};
|
};
|
||||||
// DiscreteTableConditional
|
// TableDistribution
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
template <>
|
template <>
|
||||||
struct traits<DiscreteTableConditional>
|
struct traits<TableDistribution>
|
||||||
: public Testable<DiscreteTableConditional> {};
|
: public Testable<TableDistribution> {};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -52,9 +52,9 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
|
||||||
// Multiply into one big conditional. NOTE: possibly quite expensive.
|
// Multiply into one big conditional. NOTE: possibly quite expensive.
|
||||||
DiscreteConditional joint;
|
DiscreteConditional joint;
|
||||||
for (auto &&conditional : marginal) {
|
for (auto &&conditional : marginal) {
|
||||||
// The last discrete conditional may be a DiscreteTableConditional
|
// The last discrete conditional may be a TableDistribution
|
||||||
if (auto dtc =
|
if (auto dtc =
|
||||||
std::dynamic_pointer_cast<DiscreteTableConditional>(conditional)) {
|
std::dynamic_pointer_cast<TableDistribution>(conditional)) {
|
||||||
DiscreteConditional dc(dtc->nrFrontals(), dtc->toDecisionTreeFactor());
|
DiscreteConditional dc(dtc->nrFrontals(), dtc->toDecisionTreeFactor());
|
||||||
joint = joint * dc;
|
joint = joint * dc;
|
||||||
} else {
|
} else {
|
||||||
|
@ -133,7 +133,7 @@ HybridValues HybridBayesNet::optimize() const {
|
||||||
|
|
||||||
for (auto &&conditional : *this) {
|
for (auto &&conditional : *this) {
|
||||||
if (conditional->isDiscrete()) {
|
if (conditional->isDiscrete()) {
|
||||||
if (auto dtc = conditional->asDiscrete<DiscreteTableConditional>()) {
|
if (auto dtc = conditional->asDiscrete<TableDistribution>()) {
|
||||||
// The number of keys should be small so should not
|
// The number of keys should be small so should not
|
||||||
// be expensive to convert to DiscreteConditional.
|
// be expensive to convert to DiscreteConditional.
|
||||||
discrete_fg.push_back(DiscreteConditional(dtc->nrFrontals(),
|
discrete_fg.push_back(DiscreteConditional(dtc->nrFrontals(),
|
||||||
|
|
|
@ -20,7 +20,7 @@
|
||||||
#include <gtsam/base/treeTraversal-inst.h>
|
#include <gtsam/base/treeTraversal-inst.h>
|
||||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
#include <gtsam/discrete/DiscreteTableConditional.h>
|
#include <gtsam/discrete/TableDistribution.h>
|
||||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||||
#include <gtsam/hybrid/HybridBayesTree.h>
|
#include <gtsam/hybrid/HybridBayesTree.h>
|
||||||
#include <gtsam/hybrid/HybridConditional.h>
|
#include <gtsam/hybrid/HybridConditional.h>
|
||||||
|
@ -72,7 +72,7 @@ HybridValues HybridBayesTree::optimize() const {
|
||||||
|
|
||||||
// The root should be discrete only, we compute the MPE
|
// The root should be discrete only, we compute the MPE
|
||||||
if (root_conditional->isDiscrete()) {
|
if (root_conditional->isDiscrete()) {
|
||||||
auto discrete = std::dynamic_pointer_cast<DiscreteTableConditional>(
|
auto discrete = std::dynamic_pointer_cast<TableDistribution>(
|
||||||
root_conditional->asDiscrete());
|
root_conditional->asDiscrete());
|
||||||
discrete_fg.push_back(discrete);
|
discrete_fg.push_back(discrete);
|
||||||
mpe = discreteMaxProduct(discrete_fg);
|
mpe = discreteMaxProduct(discrete_fg);
|
||||||
|
@ -202,7 +202,7 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
||||||
auto discreteProbs =
|
auto discreteProbs =
|
||||||
this->roots_.at(0)->conditional()->asDiscrete<DiscreteTableConditional>();
|
this->roots_.at(0)->conditional()->asDiscrete<TableDistribution>();
|
||||||
|
|
||||||
DiscreteConditional::shared_ptr prunedDiscreteProbs =
|
DiscreteConditional::shared_ptr prunedDiscreteProbs =
|
||||||
discreteProbs->prune(maxNrLeaves);
|
discreteProbs->prune(maxNrLeaves);
|
||||||
|
|
|
@ -265,7 +265,7 @@ TableFactor TableProduct(const DiscreteFactorGraph &factors) {
|
||||||
for (auto &&factor : factors) {
|
for (auto &&factor : factors) {
|
||||||
if (factor) {
|
if (factor) {
|
||||||
if (auto dtc =
|
if (auto dtc =
|
||||||
std::dynamic_pointer_cast<DiscreteTableConditional>(factor)) {
|
std::dynamic_pointer_cast<TableDistribution>(factor)) {
|
||||||
product = product * dtc->table();
|
product = product * dtc->table();
|
||||||
} else if (auto f = std::dynamic_pointer_cast<TableFactor>(factor)) {
|
} else if (auto f = std::dynamic_pointer_cast<TableFactor>(factor)) {
|
||||||
product = product * (*f);
|
product = product * (*f);
|
||||||
|
@ -323,7 +323,7 @@ static DiscreteFactorGraph CollectDiscreteFactors(
|
||||||
#if GTSAM_HYBRID_TIMING
|
#if GTSAM_HYBRID_TIMING
|
||||||
gttic_(ConvertConditionalToTableFactor);
|
gttic_(ConvertConditionalToTableFactor);
|
||||||
#endif
|
#endif
|
||||||
if (auto dtc = std::dynamic_pointer_cast<DiscreteTableConditional>(dc)) {
|
if (auto dtc = std::dynamic_pointer_cast<TableDistribution>(dc)) {
|
||||||
/// Get the underlying TableFactor
|
/// Get the underlying TableFactor
|
||||||
dfg.push_back(dtc->table());
|
dfg.push_back(dtc->table());
|
||||||
} else {
|
} else {
|
||||||
|
@ -364,7 +364,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
gttic_(EliminateDiscreteFormDiscreteConditional);
|
gttic_(EliminateDiscreteFormDiscreteConditional);
|
||||||
#endif
|
#endif
|
||||||
auto conditional =
|
auto conditional =
|
||||||
std::make_shared<DiscreteTableConditional>(frontalKeys.size(), product);
|
std::make_shared<TableDistribution>(frontalKeys.size(), product);
|
||||||
#if GTSAM_HYBRID_TIMING
|
#if GTSAM_HYBRID_TIMING
|
||||||
gttoc_(EliminateDiscreteFormDiscreteConditional);
|
gttoc_(EliminateDiscreteFormDiscreteConditional);
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -20,7 +20,7 @@
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
#include <gtsam/discrete/DiscreteKey.h>
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
#include <gtsam/discrete/DiscreteTableConditional.h>
|
#include <gtsam/discrete/TableDistribution.h>
|
||||||
#include <gtsam/hybrid/HybridFactor.h>
|
#include <gtsam/hybrid/HybridFactor.h>
|
||||||
#include <gtsam/hybrid/HybridFactorGraph.h>
|
#include <gtsam/hybrid/HybridFactorGraph.h>
|
||||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||||
|
|
|
@ -20,7 +20,7 @@
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
#include <gtsam/discrete/DiscreteConditional.h>
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
#include <gtsam/discrete/DiscreteKey.h>
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
#include <gtsam/discrete/DiscreteTableConditional.h>
|
#include <gtsam/discrete/TableDistribution.h>
|
||||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||||
#include <gtsam/hybrid/HybridGaussianConditional.h>
|
#include <gtsam/hybrid/HybridGaussianConditional.h>
|
||||||
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
||||||
|
@ -80,8 +80,8 @@ TEST(GaussianMixture, GaussianMixtureModel) {
|
||||||
double midway = mu1 - mu0;
|
double midway = mu1 - mu0;
|
||||||
auto eliminationResult =
|
auto eliminationResult =
|
||||||
gmm.toFactorGraph({{Z(0), Vector1(midway)}}).eliminateSequential();
|
gmm.toFactorGraph({{Z(0), Vector1(midway)}}).eliminateSequential();
|
||||||
auto pMid = eliminationResult->at(0)->asDiscrete<DiscreteTableConditional>();
|
auto pMid = eliminationResult->at(0)->asDiscrete<TableDistribution>();
|
||||||
EXPECT(assert_equal(DiscreteTableConditional(m, "60/40"), *pMid));
|
EXPECT(assert_equal(TableDistribution(m, "60/40"), *pMid));
|
||||||
|
|
||||||
// Everywhere else, the result should be a sigmoid.
|
// Everywhere else, the result should be a sigmoid.
|
||||||
for (const double shift : {-4, -2, 0, 2, 4}) {
|
for (const double shift : {-4, -2, 0, 2, 4}) {
|
||||||
|
@ -92,7 +92,7 @@ TEST(GaussianMixture, GaussianMixtureModel) {
|
||||||
auto eliminationResult1 =
|
auto eliminationResult1 =
|
||||||
gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential();
|
gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential();
|
||||||
auto posterior1 =
|
auto posterior1 =
|
||||||
*eliminationResult1->at(0)->asDiscrete<DiscreteTableConditional>();
|
*eliminationResult1->at(0)->asDiscrete<TableDistribution>();
|
||||||
EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);
|
EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);
|
||||||
|
|
||||||
// Workflow 2: directly specify HFG and solve
|
// Workflow 2: directly specify HFG and solve
|
||||||
|
@ -102,7 +102,7 @@ TEST(GaussianMixture, GaussianMixtureModel) {
|
||||||
hfg1.push_back(mixing);
|
hfg1.push_back(mixing);
|
||||||
auto eliminationResult2 = hfg1.eliminateSequential();
|
auto eliminationResult2 = hfg1.eliminateSequential();
|
||||||
auto posterior2 =
|
auto posterior2 =
|
||||||
*eliminationResult2->at(0)->asDiscrete<DiscreteTableConditional>();
|
*eliminationResult2->at(0)->asDiscrete<TableDistribution>();
|
||||||
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
|
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -142,8 +142,8 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
|
||||||
eliminationResultMax->discretePosterior(vv)));
|
eliminationResultMax->discretePosterior(vv)));
|
||||||
|
|
||||||
auto pMax =
|
auto pMax =
|
||||||
*eliminationResultMax->at(0)->asDiscrete<DiscreteTableConditional>();
|
*eliminationResultMax->at(0)->asDiscrete<TableDistribution>();
|
||||||
EXPECT(assert_equal(DiscreteTableConditional(m, "42/58"), pMax, 1e-4));
|
EXPECT(assert_equal(TableDistribution(m, "42/58"), pMax, 1e-4));
|
||||||
|
|
||||||
// Everywhere else, the result should be a bell curve like function.
|
// Everywhere else, the result should be a bell curve like function.
|
||||||
for (const double shift : {-4, -2, 0, 2, 4}) {
|
for (const double shift : {-4, -2, 0, 2, 4}) {
|
||||||
|
@ -154,7 +154,7 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
|
||||||
auto eliminationResult1 =
|
auto eliminationResult1 =
|
||||||
gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential();
|
gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential();
|
||||||
auto posterior1 =
|
auto posterior1 =
|
||||||
*eliminationResult1->at(0)->asDiscrete<DiscreteTableConditional>();
|
*eliminationResult1->at(0)->asDiscrete<TableDistribution>();
|
||||||
EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);
|
EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);
|
||||||
|
|
||||||
// Workflow 2: directly specify HFG and solve
|
// Workflow 2: directly specify HFG and solve
|
||||||
|
@ -164,7 +164,7 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
|
||||||
hfg.push_back(mixing);
|
hfg.push_back(mixing);
|
||||||
auto eliminationResult2 = hfg.eliminateSequential();
|
auto eliminationResult2 = hfg.eliminateSequential();
|
||||||
auto posterior2 =
|
auto posterior2 =
|
||||||
*eliminationResult2->at(0)->asDiscrete<DiscreteTableConditional>();
|
*eliminationResult2->at(0)->asDiscrete<TableDistribution>();
|
||||||
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
|
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -450,9 +450,9 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
||||||
|
|
||||||
DiscreteConditional joint;
|
DiscreteConditional joint;
|
||||||
for (auto&& conditional : posterior->discreteMarginal()) {
|
for (auto&& conditional : posterior->discreteMarginal()) {
|
||||||
// The last discrete conditional may be a DiscreteTableConditional
|
// The last discrete conditional may be a TableDistribution
|
||||||
if (auto dtc =
|
if (auto dtc =
|
||||||
std::dynamic_pointer_cast<DiscreteTableConditional>(conditional)) {
|
std::dynamic_pointer_cast<TableDistribution>(conditional)) {
|
||||||
DiscreteConditional dc(dtc->nrFrontals(), dtc->toDecisionTreeFactor());
|
DiscreteConditional dc(dtc->nrFrontals(), dtc->toDecisionTreeFactor());
|
||||||
joint = joint * dc;
|
joint = joint * dc;
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -464,14 +464,14 @@ TEST(HybridEstimation, EliminateSequentialRegression) {
|
||||||
|
|
||||||
// Create expected discrete conditional on m0.
|
// Create expected discrete conditional on m0.
|
||||||
DiscreteKey m(M(0), 2);
|
DiscreteKey m(M(0), 2);
|
||||||
DiscreteTableConditional expected(m % "0.51341712/1"); // regression
|
TableDistribution expected(m % "0.51341712/1"); // regression
|
||||||
|
|
||||||
// Eliminate into BN using one ordering
|
// Eliminate into BN using one ordering
|
||||||
const Ordering ordering1{X(0), X(1), M(0)};
|
const Ordering ordering1{X(0), X(1), M(0)};
|
||||||
HybridBayesNet::shared_ptr bn1 = fg->eliminateSequential(ordering1);
|
HybridBayesNet::shared_ptr bn1 = fg->eliminateSequential(ordering1);
|
||||||
|
|
||||||
// Check that the discrete conditional matches the expected.
|
// Check that the discrete conditional matches the expected.
|
||||||
auto dc1 = bn1->back()->asDiscrete<DiscreteTableConditional>();
|
auto dc1 = bn1->back()->asDiscrete<TableDistribution>();
|
||||||
EXPECT(assert_equal(expected, *dc1, 1e-9));
|
EXPECT(assert_equal(expected, *dc1, 1e-9));
|
||||||
|
|
||||||
// Eliminate into BN using a different ordering
|
// Eliminate into BN using a different ordering
|
||||||
|
@ -479,7 +479,7 @@ TEST(HybridEstimation, EliminateSequentialRegression) {
|
||||||
HybridBayesNet::shared_ptr bn2 = fg->eliminateSequential(ordering2);
|
HybridBayesNet::shared_ptr bn2 = fg->eliminateSequential(ordering2);
|
||||||
|
|
||||||
// Check that the discrete conditional matches the expected.
|
// Check that the discrete conditional matches the expected.
|
||||||
auto dc2 = bn2->back()->asDiscrete<DiscreteTableConditional>();
|
auto dc2 = bn2->back()->asDiscrete<TableDistribution>();
|
||||||
EXPECT(assert_equal(expected, *dc2, 1e-9));
|
EXPECT(assert_equal(expected, *dc2, 1e-9));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -650,7 +650,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) {
|
||||||
mode, std::vector{conditional0, conditional1});
|
mode, std::vector{conditional0, conditional1});
|
||||||
|
|
||||||
// Add prior on mode.
|
// Add prior on mode.
|
||||||
expectedBayesNet.emplace_shared<DiscreteTableConditional>(mode, "74/26");
|
expectedBayesNet.emplace_shared<TableDistribution>(mode, "74/26");
|
||||||
|
|
||||||
// Test elimination
|
// Test elimination
|
||||||
const auto posterior = fg.eliminateSequential();
|
const auto posterior = fg.eliminateSequential();
|
||||||
|
@ -700,7 +700,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1Swapped) {
|
||||||
m1, std::vector{conditional0, conditional1});
|
m1, std::vector{conditional0, conditional1});
|
||||||
|
|
||||||
// Add prior on m1.
|
// Add prior on m1.
|
||||||
expectedBayesNet.emplace_shared<DiscreteTableConditional>(
|
expectedBayesNet.emplace_shared<TableDistribution>(
|
||||||
m1, "0.188638/0.811362");
|
m1, "0.188638/0.811362");
|
||||||
|
|
||||||
// Test elimination
|
// Test elimination
|
||||||
|
@ -738,8 +738,8 @@ TEST(HybridGaussianFactorGraph, EliminateTiny2) {
|
||||||
|
|
||||||
// Add prior on mode.
|
// Add prior on mode.
|
||||||
// Since this is the only discrete conditional, it is added as a
|
// Since this is the only discrete conditional, it is added as a
|
||||||
// DiscreteTableConditional.
|
// TableDistribution.
|
||||||
expectedBayesNet.emplace_shared<DiscreteTableConditional>(mode, "23/77");
|
expectedBayesNet.emplace_shared<TableDistribution>(mode, "23/77");
|
||||||
|
|
||||||
// Test elimination
|
// Test elimination
|
||||||
const auto posterior = fg.eliminateSequential();
|
const auto posterior = fg.eliminateSequential();
|
||||||
|
|
|
@ -142,7 +142,7 @@ TEST(HybridGaussianISAM, IncrementalInference) {
|
||||||
|
|
||||||
// Test the probability values with regression tests.
|
// Test the probability values with regression tests.
|
||||||
auto discrete =
|
auto discrete =
|
||||||
isam[M(1)]->conditional()->asDiscrete<DiscreteTableConditional>();
|
isam[M(1)]->conditional()->asDiscrete<TableDistribution>();
|
||||||
EXPECT(assert_equal(0.095292, (*discrete)({{M(0), 0}, {M(1), 0}}), 1e-5));
|
EXPECT(assert_equal(0.095292, (*discrete)({{M(0), 0}, {M(1), 0}}), 1e-5));
|
||||||
EXPECT(assert_equal(0.282758, (*discrete)({{M(0), 1}, {M(1), 0}}), 1e-5));
|
EXPECT(assert_equal(0.282758, (*discrete)({{M(0), 1}, {M(1), 0}}), 1e-5));
|
||||||
EXPECT(assert_equal(0.314175, (*discrete)({{M(0), 0}, {M(1), 1}}), 1e-5));
|
EXPECT(assert_equal(0.314175, (*discrete)({{M(0), 0}, {M(1), 1}}), 1e-5));
|
||||||
|
@ -222,7 +222,7 @@ TEST(HybridGaussianISAM, ApproxInference) {
|
||||||
1 1 1 Leaf 0.5
|
1 1 1 Leaf 0.5
|
||||||
*/
|
*/
|
||||||
|
|
||||||
auto discreteConditional_m0 = *dynamic_pointer_cast<DiscreteTableConditional>(
|
auto discreteConditional_m0 = *dynamic_pointer_cast<TableDistribution>(
|
||||||
incrementalHybrid[M(0)]->conditional()->inner());
|
incrementalHybrid[M(0)]->conditional()->inner());
|
||||||
EXPECT(discreteConditional_m0.keys() == KeyVector({M(0), M(1), M(2)}));
|
EXPECT(discreteConditional_m0.keys() == KeyVector({M(0), M(1), M(2)}));
|
||||||
|
|
||||||
|
@ -474,7 +474,7 @@ TEST(HybridGaussianISAM, NonTrivial) {
|
||||||
|
|
||||||
// Test if the optimal discrete mode assignment is (1, 1, 1).
|
// Test if the optimal discrete mode assignment is (1, 1, 1).
|
||||||
DiscreteFactorGraph discreteGraph;
|
DiscreteFactorGraph discreteGraph;
|
||||||
// discreteTree is a DiscreteTableConditional, so we convert to
|
// discreteTree is a TableDistribution, so we convert to
|
||||||
// DecisionTreeFactor for the DiscreteFactorGraph
|
// DecisionTreeFactor for the DiscreteFactorGraph
|
||||||
discreteGraph.push_back(discreteTree->toDecisionTreeFactor());
|
discreteGraph.push_back(discreteTree->toDecisionTreeFactor());
|
||||||
DiscreteValues optimal_assignment = discreteGraph.optimize();
|
DiscreteValues optimal_assignment = discreteGraph.optimize();
|
||||||
|
|
|
@ -21,7 +21,7 @@
|
||||||
#include <gtsam/base/Testable.h>
|
#include <gtsam/base/Testable.h>
|
||||||
#include <gtsam/base/TestableAssertions.h>
|
#include <gtsam/base/TestableAssertions.h>
|
||||||
#include <gtsam/discrete/DiscreteConditional.h>
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
#include <gtsam/discrete/DiscreteTableConditional.h>
|
#include <gtsam/discrete/TableDistribution.h>
|
||||||
#include <gtsam/discrete/DiscreteValues.h>
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||||
#include <gtsam/hybrid/HybridGaussianConditional.h>
|
#include <gtsam/hybrid/HybridGaussianConditional.h>
|
||||||
|
@ -144,9 +144,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel) {
|
||||||
// Since no measurement on x1, we hedge our bets
|
// Since no measurement on x1, we hedge our bets
|
||||||
// Importance sampling run with 100k samples gives 50.051/49.949
|
// Importance sampling run with 100k samples gives 50.051/49.949
|
||||||
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
|
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
|
||||||
DiscreteTableConditional expected(m1, "50/50");
|
TableDistribution expected(m1, "50/50");
|
||||||
EXPECT(assert_equal(expected,
|
EXPECT(assert_equal(expected,
|
||||||
*(bn->at(2)->asDiscrete<DiscreteTableConditional>())));
|
*(bn->at(2)->asDiscrete<TableDistribution>())));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -162,9 +162,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel) {
|
||||||
// Since we have a measurement on x1, we get a definite result
|
// Since we have a measurement on x1, we get a definite result
|
||||||
// Values taken from an importance sampling run with 100k samples:
|
// Values taken from an importance sampling run with 100k samples:
|
||||||
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
|
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
|
||||||
DiscreteTableConditional expected(m1, "44.3854/55.6146");
|
TableDistribution expected(m1, "44.3854/55.6146");
|
||||||
EXPECT(assert_equal(
|
EXPECT(assert_equal(
|
||||||
expected, *(bn->at(2)->asDiscrete<DiscreteTableConditional>()), 0.02));
|
expected, *(bn->at(2)->asDiscrete<TableDistribution>()), 0.02));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -251,9 +251,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel2) {
|
||||||
|
|
||||||
// Values taken from an importance sampling run with 100k samples:
|
// Values taken from an importance sampling run with 100k samples:
|
||||||
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
|
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
|
||||||
DiscreteTableConditional expected(m1, "48.3158/51.6842");
|
TableDistribution expected(m1, "48.3158/51.6842");
|
||||||
EXPECT(assert_equal(
|
EXPECT(assert_equal(
|
||||||
expected, *(eliminated->at(2)->asDiscrete<DiscreteTableConditional>()),
|
expected, *(eliminated->at(2)->asDiscrete<TableDistribution>()),
|
||||||
0.02));
|
0.02));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -268,9 +268,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel2) {
|
||||||
|
|
||||||
// Values taken from an importance sampling run with 100k samples:
|
// Values taken from an importance sampling run with 100k samples:
|
||||||
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
|
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
|
||||||
DiscreteTableConditional expected(m1, "55.396/44.604");
|
TableDistribution expected(m1, "55.396/44.604");
|
||||||
EXPECT(assert_equal(
|
EXPECT(assert_equal(
|
||||||
expected, *(bn->at(2)->asDiscrete<DiscreteTableConditional>()), 0.02));
|
expected, *(bn->at(2)->asDiscrete<TableDistribution>()), 0.02));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -346,9 +346,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel3) {
|
||||||
|
|
||||||
// Values taken from an importance sampling run with 100k samples:
|
// Values taken from an importance sampling run with 100k samples:
|
||||||
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
|
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
|
||||||
DiscreteTableConditional expected(m1, "51.7762/48.2238");
|
TableDistribution expected(m1, "51.7762/48.2238");
|
||||||
EXPECT(assert_equal(
|
EXPECT(assert_equal(
|
||||||
expected, *(bn->at(2)->asDiscrete<DiscreteTableConditional>()), 0.02));
|
expected, *(bn->at(2)->asDiscrete<TableDistribution>()), 0.02));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -362,9 +362,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel3) {
|
||||||
|
|
||||||
// Values taken from an importance sampling run with 100k samples:
|
// Values taken from an importance sampling run with 100k samples:
|
||||||
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
|
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
|
||||||
DiscreteTableConditional expected(m1, "49.0762/50.9238");
|
TableDistribution expected(m1, "49.0762/50.9238");
|
||||||
EXPECT(assert_equal(
|
EXPECT(assert_equal(
|
||||||
expected, *(bn->at(2)->asDiscrete<DiscreteTableConditional>()), 0.05));
|
expected, *(bn->at(2)->asDiscrete<TableDistribution>()), 0.05));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -389,9 +389,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel4) {
|
||||||
|
|
||||||
// Values taken from an importance sampling run with 100k samples:
|
// Values taken from an importance sampling run with 100k samples:
|
||||||
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
|
// approximateDiscreteMarginal(hbn, hybridMotionModel, given);
|
||||||
DiscreteTableConditional expected(m1, "8.91527/91.0847");
|
TableDistribution expected(m1, "8.91527/91.0847");
|
||||||
EXPECT(assert_equal(
|
EXPECT(assert_equal(
|
||||||
expected, *(bn->at(2)->asDiscrete<DiscreteTableConditional>()), 0.01));
|
expected, *(bn->at(2)->asDiscrete<TableDistribution>()), 0.01));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -512,7 +512,7 @@ TEST(HybridNonlinearFactorGraph, Full_Elimination) {
|
||||||
// P(m1)
|
// P(m1)
|
||||||
EXPECT(hybridBayesNet->at(4)->frontals() == KeyVector{M(1)});
|
EXPECT(hybridBayesNet->at(4)->frontals() == KeyVector{M(1)});
|
||||||
EXPECT_LONGS_EQUAL(0, hybridBayesNet->at(4)->nrParents());
|
EXPECT_LONGS_EQUAL(0, hybridBayesNet->at(4)->nrParents());
|
||||||
DiscreteTableConditional dtc = *hybridBayesNet->at(4)->asDiscrete<DiscreteTableConditional>();
|
TableDistribution dtc = *hybridBayesNet->at(4)->asDiscrete<TableDistribution>();
|
||||||
EXPECT(
|
EXPECT(
|
||||||
DiscreteConditional(dtc.nrFrontals(), dtc.toDecisionTreeFactor())
|
DiscreteConditional(dtc.nrFrontals(), dtc.toDecisionTreeFactor())
|
||||||
.equals(*discreteBayesNet.at(1)));
|
.equals(*discreteBayesNet.at(1)));
|
||||||
|
|
|
@ -265,7 +265,7 @@ TEST(HybridNonlinearISAM, ApproxInference) {
|
||||||
1 1 1 Leaf 0.5
|
1 1 1 Leaf 0.5
|
||||||
*/
|
*/
|
||||||
|
|
||||||
auto discreteConditional_m0 = *dynamic_pointer_cast<DiscreteTableConditional>(
|
auto discreteConditional_m0 = *dynamic_pointer_cast<TableDistribution>(
|
||||||
bayesTree[M(0)]->conditional()->inner());
|
bayesTree[M(0)]->conditional()->inner());
|
||||||
EXPECT(discreteConditional_m0.keys() == KeyVector({M(0), M(1), M(2)}));
|
EXPECT(discreteConditional_m0.keys() == KeyVector({M(0), M(1), M(2)}));
|
||||||
|
|
||||||
|
@ -517,7 +517,7 @@ TEST(HybridNonlinearISAM, NonTrivial) {
|
||||||
// The final discrete graph should not be empty since we have eliminated
|
// The final discrete graph should not be empty since we have eliminated
|
||||||
// all continuous variables.
|
// all continuous variables.
|
||||||
auto discreteTree =
|
auto discreteTree =
|
||||||
bayesTree[M(3)]->conditional()->asDiscrete<DiscreteTableConditional>();
|
bayesTree[M(3)]->conditional()->asDiscrete<TableDistribution>();
|
||||||
EXPECT_LONGS_EQUAL(3, discreteTree->size());
|
EXPECT_LONGS_EQUAL(3, discreteTree->size());
|
||||||
|
|
||||||
// Test if the optimal discrete mode assignment is (1, 1, 1).
|
// Test if the optimal discrete mode assignment is (1, 1, 1).
|
||||||
|
|
Loading…
Reference in New Issue