diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index aa7f1d391..ad9a943e3 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -47,6 +47,15 @@ DiscreteConditional::DiscreteConditional(const size_t nrFrontals, const DecisionTreeFactor& f) : BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {} +/* ************************************************************************** */ +DiscreteConditional::DiscreteConditional(size_t nrFrontals, + const DecisionTreeFactor& f, + const Ordering& orderedKeys) + : BaseFactor(f), BaseConditional(nrFrontals) { + keys_.clear(); + keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end()); +} + /* ************************************************************************** */ DiscreteConditional::DiscreteConditional(size_t nrFrontals, const DiscreteKeys& keys, diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 98edcb8c9..89b8d0d56 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -56,6 +56,17 @@ class GTSAM_EXPORT DiscreteConditional /// Construct from factor, taking the first `nFrontals` keys as frontals. DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f); + /** + * @brief Construct from DecisionTreeFactor, + * taking the first `nrFrontals` from `orderedKeys`. + * + * @param nrFrontals The number of frontal variables. + * @param f The DecisionTreeFactor to construct from. + * @param orderedKeys Ordered list of keys involved in the conditional. + */ + DiscreteConditional(size_t nrFrontals, const DecisionTreeFactor& f, + const Ordering& orderedKeys); + /** * Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first * `nFrontals` keys as frontals, in the order given. diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index bf9662e34..a833e1c5e 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -252,6 +252,15 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { DiscreteKeys dkeys = discreteKeys(); + // If no keys, then return empty DecisionTreeFactor + if (dkeys.size() == 0) { + AlgebraicDecisionTree tree; + if (sparse_table_.size() != 0) { + tree = AlgebraicDecisionTree(sparse_table_.coeff(0)); + } + return DecisionTreeFactor(dkeys, tree); + } + std::vector table; for (auto i = 0; i < sparse_table_.size(); i++) { table.push_back(sparse_table_.coeff(i)); diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 8c37298a7..b9f327021 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -360,12 +360,16 @@ discreteElimination(const HybridGaussianFactorGraph &factors, // All the discrete variables should form a single clique, // so we can sum out on all the variables as frontals. // This should give an empty separator. - Ordering orderedKeys(product.keys()); - TableFactor::shared_ptr sum = product.sum(orderedKeys); + TableFactor::shared_ptr sum = product.sum(frontalKeys); #if GTSAM_HYBRID_TIMING gttoc_(EliminateDiscreteSum); #endif + // Ordering keys for the conditional so that frontalKeys are really in front + Ordering orderedKeys; + orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end()); + orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end()); + #if GTSAM_HYBRID_TIMING gttic_(EliminateDiscreteFormDiscreteConditional); #endif diff --git a/gtsam/hybrid/tests/testGaussianMixture.cpp b/gtsam/hybrid/tests/testGaussianMixture.cpp index 2de8d15ec..64336ae8d 100644 --- a/gtsam/hybrid/tests/testGaussianMixture.cpp +++ b/gtsam/hybrid/tests/testGaussianMixture.cpp @@ -169,6 +169,7 @@ TEST(GaussianMixture, GaussianMixtureModel2) { EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8); } } + /* ************************************************************************* */ int main() { TestResult tr;