update DiscreteConditional

release/4.3a0
Varun Agrawal 2024-12-07 19:18:42 -05:00
parent 20d6d09e06
commit 32b6bc0a37
2 changed files with 18 additions and 17 deletions

View File

@ -37,8 +37,7 @@ using std::vector;
namespace gtsam { namespace gtsam {
// Instantiate base class // Instantiate base class
template class GTSAM_EXPORT template class GTSAM_EXPORT Conditional<DecisionTreeFactor, DiscreteConditional>;
Conditional<DecisionTreeFactor, DiscreteConditional>;
/* ************************************************************************** */ /* ************************************************************************** */
DiscreteConditional::DiscreteConditional(const size_t nrFrontals, DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
@ -54,15 +53,17 @@ DiscreteConditional::DiscreteConditional(size_t nrFrontals,
: BaseFactor(keys, potentials), BaseConditional(nrFrontals) {} : BaseFactor(keys, potentials), BaseConditional(nrFrontals) {}
/* ************************************************************************** */ /* ************************************************************************** */
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint, DiscreteConditional::DiscreteConditional(
const DecisionTreeFactor& marginal) const DiscreteFactor::shared_ptr& joint,
: BaseFactor(joint / marginal), const DiscreteFactor::shared_ptr& marginal)
BaseConditional(joint.size() - marginal.size()) {} : BaseFactor(*std::dynamic_pointer_cast<DecisionTreeFactor>(
joint->operator/(marginal))),
BaseConditional(joint->size() - marginal->size()) {}
/* ************************************************************************** */ /* ************************************************************************** */
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint, DiscreteConditional::DiscreteConditional(
const DecisionTreeFactor& marginal, const DiscreteFactor::shared_ptr& joint,
const Ordering& orderedKeys) const DiscreteFactor::shared_ptr& marginal, const Ordering& orderedKeys)
: DiscreteConditional(joint, marginal) { : DiscreteConditional(joint, marginal) {
keys_.clear(); keys_.clear();
keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end()); keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end());

View File

@ -110,16 +110,16 @@ class GTSAM_EXPORT DiscreteConditional
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
* Assumes but *does not check* that f(Y)=sum_X f(X,Y). * Assumes but *does not check* that f(Y)=sum_X f(X,Y).
*/ */
DiscreteConditional(const DecisionTreeFactor& joint, DiscreteConditional(const DiscreteFactor::shared_ptr& joint,
const DecisionTreeFactor& marginal); const DiscreteFactor::shared_ptr& marginal);
/** /**
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
* Assumes but *does not check* that f(Y)=sum_X f(X,Y). * Assumes but *does not check* that f(Y)=sum_X f(X,Y).
* Makes sure the keys are ordered as given. Does not check orderedKeys. * Makes sure the keys are ordered as given. Does not check orderedKeys.
*/ */
DiscreteConditional(const DecisionTreeFactor& joint, DiscreteConditional(const DiscreteFactor::shared_ptr& joint,
const DecisionTreeFactor& marginal, const DiscreteFactor::shared_ptr& marginal,
const Ordering& orderedKeys); const Ordering& orderedKeys);
/** /**
@ -173,8 +173,8 @@ class GTSAM_EXPORT DiscreteConditional
return ADT::operator()(values); return ADT::operator()(values);
} }
using DecisionTreeFactor::error; ///< DiscreteValues version using DiscreteFactor::error; ///< DiscreteValues version
using DecisionTreeFactor::operator(); ///< DiscreteValues version using DiscreteFactor::operator(); ///< DiscreteValues version
/** /**
* @brief restrict to given *parent* values. * @brief restrict to given *parent* values.
@ -192,11 +192,11 @@ class GTSAM_EXPORT DiscreteConditional
shared_ptr choose(const DiscreteValues& given) const; shared_ptr choose(const DiscreteValues& given) const;
/** Convert to a likelihood factor by providing value before bar. */ /** Convert to a likelihood factor by providing value before bar. */
DecisionTreeFactor::shared_ptr likelihood( DiscreteFactor::shared_ptr likelihood(
const DiscreteValues& frontalValues) const; const DiscreteValues& frontalValues) const;
/** Single variable version of likelihood. */ /** Single variable version of likelihood. */
DecisionTreeFactor::shared_ptr likelihood(size_t frontal) const; DiscreteFactor::shared_ptr likelihood(size_t frontal) const;
/** /**
* sample * sample