diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 92086d143..2f900afbe 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -37,8 +37,7 @@ using std::vector; namespace gtsam { // Instantiate base class -template class GTSAM_EXPORT - Conditional; +template class GTSAM_EXPORT Conditional; /* ************************************************************************** */ DiscreteConditional::DiscreteConditional(const size_t nrFrontals, @@ -54,15 +53,17 @@ DiscreteConditional::DiscreteConditional(size_t nrFrontals, : BaseFactor(keys, potentials), BaseConditional(nrFrontals) {} /* ************************************************************************** */ -DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint, - const DecisionTreeFactor& marginal) - : BaseFactor(joint / marginal), - BaseConditional(joint.size() - marginal.size()) {} +DiscreteConditional::DiscreteConditional( + const DiscreteFactor::shared_ptr& joint, + const DiscreteFactor::shared_ptr& marginal) + : BaseFactor(*std::dynamic_pointer_cast( + joint->operator/(marginal))), + BaseConditional(joint->size() - marginal->size()) {} /* ************************************************************************** */ -DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint, - const DecisionTreeFactor& marginal, - const Ordering& orderedKeys) +DiscreteConditional::DiscreteConditional( + const DiscreteFactor::shared_ptr& joint, + const DiscreteFactor::shared_ptr& marginal, const Ordering& orderedKeys) : DiscreteConditional(joint, marginal) { keys_.clear(); keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end()); diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index ce4fb96e5..ec2c5c38d 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -110,16 +110,16 @@ class GTSAM_EXPORT DiscreteConditional * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) * Assumes but *does not check* that f(Y)=sum_X f(X,Y). */ - DiscreteConditional(const DecisionTreeFactor& joint, - const DecisionTreeFactor& marginal); + DiscreteConditional(const DiscreteFactor::shared_ptr& joint, + const DiscreteFactor::shared_ptr& marginal); /** * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) * Assumes but *does not check* that f(Y)=sum_X f(X,Y). * Makes sure the keys are ordered as given. Does not check orderedKeys. */ - DiscreteConditional(const DecisionTreeFactor& joint, - const DecisionTreeFactor& marginal, + DiscreteConditional(const DiscreteFactor::shared_ptr& joint, + const DiscreteFactor::shared_ptr& marginal, const Ordering& orderedKeys); /** @@ -173,8 +173,8 @@ class GTSAM_EXPORT DiscreteConditional return ADT::operator()(values); } - using DecisionTreeFactor::error; ///< DiscreteValues version - using DecisionTreeFactor::operator(); ///< DiscreteValues version + using DiscreteFactor::error; ///< DiscreteValues version + using DiscreteFactor::operator(); ///< DiscreteValues version /** * @brief restrict to given *parent* values. @@ -192,11 +192,11 @@ class GTSAM_EXPORT DiscreteConditional shared_ptr choose(const DiscreteValues& given) const; /** Convert to a likelihood factor by providing value before bar. */ - DecisionTreeFactor::shared_ptr likelihood( + DiscreteFactor::shared_ptr likelihood( const DiscreteValues& frontalValues) const; /** Single variable version of likelihood. */ - DecisionTreeFactor::shared_ptr likelihood(size_t frontal) const; + DiscreteFactor::shared_ptr likelihood(size_t frontal) const; /** * sample