Merge pull request #1998 from borglab/fix/refactor_cartesian_product

Fix/refactor cartesian product
release/4.3a0
Frank Dellaert 2025-01-24 15:42:52 -05:00 committed by GitHub
commit 99c4f7f258
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 27 additions and 19 deletions

View File

@ -85,25 +85,31 @@ class Assignment : public std::map<L, size_t> {
* variables with each having cardinalities 4, we get 4096 possible * variables with each having cardinalities 4, we get 4096 possible
* configurations!! * configurations!!
*/ */
template <typename Derived = Assignment<L>> template <typename AssignmentType = Assignment<L>>
static std::vector<Derived> CartesianProduct( static std::vector<AssignmentType> CartesianProduct(
const std::vector<std::pair<L, size_t>>& keys) { const std::vector<std::pair<L, size_t>>& keys) {
std::vector<Derived> allPossValues; std::vector<AssignmentType> allPossValues;
Derived values; AssignmentType assignment;
typedef std::pair<L, size_t> DiscreteKey; for (const auto [idx, _] : keys) assignment[idx] = 0; // Initialize from 0
for (const DiscreteKey& key : keys)
values[key.first] = 0; // Initialize from 0 const size_t nrKeys = keys.size();
while (1) { while (true) {
allPossValues.push_back(values); allPossValues.push_back(assignment);
// Increment the assignment. This generalizes incrementing a binary number
size_t j = 0; size_t j = 0;
for (j = 0; j < keys.size(); j++) { for (j = 0; j < nrKeys; j++) {
L idx = keys[j].first; auto [idx, cardinality] = keys[j];
values[idx]++; // Most of the time, we just increment the value for the first key, j=0:
if (values[idx] < keys[j].second) break; assignment[idx]++;
// Wrap condition // But if this key is done, we increment next key.
values[idx] = 0; const bool carry = (assignment[idx] == cardinality);
if (!carry) break;
assignment[idx] = 0; // wrap on carry, and continue to next variable
} }
if (j == keys.size()) break;
// If we propagated carry past the last key, exit:
if (j == nrKeys) break;
} }
return allPossValues; return allPossValues;
} }

View File

@ -338,6 +338,7 @@ namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
// Find the lowest common ancestor of two cliques // Find the lowest common ancestor of two cliques
// TODO(Varun): consider implementing this as a Range Minimum Query
template <class CLIQUE> template <class CLIQUE>
static std::shared_ptr<CLIQUE> findLowestCommonAncestor( static std::shared_ptr<CLIQUE> findLowestCommonAncestor(
const std::shared_ptr<CLIQUE>& C1, const std::shared_ptr<CLIQUE>& C2) { const std::shared_ptr<CLIQUE>& C1, const std::shared_ptr<CLIQUE>& C2) {
@ -360,7 +361,7 @@ namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
// Given the clique P(F:S) and the ancestor clique B // Given the clique P(F:S) and the ancestor clique B
// Return the Bayes tree P(S\B | S \cap B) // Return the Bayes tree P(S\B | S \cap B), where \cap is intersection
template <class CLIQUE> template <class CLIQUE>
static auto factorInto( static auto factorInto(
const std::shared_ptr<CLIQUE>& p_F_S, const std::shared_ptr<CLIQUE>& B, const std::shared_ptr<CLIQUE>& p_F_S, const std::shared_ptr<CLIQUE>& B,

View File

@ -107,7 +107,7 @@ namespace gtsam {
// The shortcut density is a conditional P(S|B) of the separator of this // The shortcut density is a conditional P(S|B) of the separator of this
// clique on the root or common ancestor B. We can compute it recursively from // clique on the root or common ancestor B. We can compute it recursively from
// the parent shortcut P(Sp|B) as \int P(Fp|Sp) P(Sp|B), where Fp are the // the parent shortcut P(Sp|B) as \int P(Fp|Sp) P(Sp|B), where Fp are the
// frontal nodes in p // frontal nodes in the parent p, and Sp the separator of the parent.
/* ************************************************************************* /* *************************************************************************
*/ */
template <class DERIVED, class FACTORGRAPH> template <class DERIVED, class FACTORGRAPH>
@ -141,7 +141,8 @@ namespace gtsam {
/* *********************************************************************** */ /* *********************************************************************** */
// Separator marginal, uses separator marginal of parent recursively // Separator marginal, uses separator marginal of parent recursively
// Calculates P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp) // Calculates P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp)
// if P(Sp) is not cached, it will call separatorMarginal on the parent // if P(Sp) is not cached, it will call separatorMarginal on the parent.
// Here again, Fp and Sp are the frontal nodes and separator in the parent p.
/* *********************************************************************** */ /* *********************************************************************** */
template <class DERIVED, class FACTORGRAPH> template <class DERIVED, class FACTORGRAPH>
typename BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::FactorGraphType typename BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::FactorGraphType