update DiscreteConditional

release/4.3a0
Varun Agrawal 2024-12-07 19:18:42 -05:00
parent 20d6d09e06
commit 32b6bc0a37
2 changed files with 18 additions and 17 deletions

View File

@ -37,8 +37,7 @@ using std::vector;
namespace gtsam {
// Instantiate base class
template class GTSAM_EXPORT
Conditional<DecisionTreeFactor, DiscreteConditional>;
template class GTSAM_EXPORT Conditional<DecisionTreeFactor, DiscreteConditional>;
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
@ -54,15 +53,17 @@ DiscreteConditional::DiscreteConditional(size_t nrFrontals,
: BaseFactor(keys, potentials), BaseConditional(nrFrontals) {}
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal)
: BaseFactor(joint / marginal),
BaseConditional(joint.size() - marginal.size()) {}
DiscreteConditional::DiscreteConditional(
const DiscreteFactor::shared_ptr& joint,
const DiscreteFactor::shared_ptr& marginal)
: BaseFactor(*std::dynamic_pointer_cast<DecisionTreeFactor>(
joint->operator/(marginal))),
BaseConditional(joint->size() - marginal->size()) {}
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal,
const Ordering& orderedKeys)
DiscreteConditional::DiscreteConditional(
const DiscreteFactor::shared_ptr& joint,
const DiscreteFactor::shared_ptr& marginal, const Ordering& orderedKeys)
: DiscreteConditional(joint, marginal) {
keys_.clear();
keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end());

View File

@ -110,16 +110,16 @@ class GTSAM_EXPORT DiscreteConditional
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
*/
DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal);
DiscreteConditional(const DiscreteFactor::shared_ptr& joint,
const DiscreteFactor::shared_ptr& marginal);
/**
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
* Makes sure the keys are ordered as given. Does not check orderedKeys.
*/
DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal,
DiscreteConditional(const DiscreteFactor::shared_ptr& joint,
const DiscreteFactor::shared_ptr& marginal,
const Ordering& orderedKeys);
/**
@ -173,8 +173,8 @@ class GTSAM_EXPORT DiscreteConditional
return ADT::operator()(values);
}
using DecisionTreeFactor::error; ///< DiscreteValues version
using DecisionTreeFactor::operator(); ///< DiscreteValues version
using DiscreteFactor::error; ///< DiscreteValues version
using DiscreteFactor::operator(); ///< DiscreteValues version
/**
* @brief restrict to given *parent* values.
@ -192,11 +192,11 @@ class GTSAM_EXPORT DiscreteConditional
shared_ptr choose(const DiscreteValues& given) const;
/** Convert to a likelihood factor by providing value before bar. */
DecisionTreeFactor::shared_ptr likelihood(
DiscreteFactor::shared_ptr likelihood(
const DiscreteValues& frontalValues) const;
/** Single variable version of likelihood. */
DecisionTreeFactor::shared_ptr likelihood(size_t frontal) const;
DiscreteFactor::shared_ptr likelihood(size_t frontal) const;
/**
* sample