Merge branch 'develop' into city10000

release/4.3a0
Varun Agrawal 2025-01-29 12:09:54 -05:00
commit fab06a33a0
17 changed files with 967 additions and 177 deletions

View File

@ -166,7 +166,7 @@ int main(int argc, char* argv[]) {
clock_t after_update = clock();
smoother_update_times.push_back({index, after_update - before_update});
size_t key_s, key_t;
size_t key_s, key_t{0};
clock_t start_time = clock();
std::string str;

View File

@ -214,7 +214,10 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
DiscreteFactor::shared_ptr product = factors.scaledProduct();
gttic(product);
// `product` is scaled later to prevent underflow.
DiscreteFactor::shared_ptr product = factors.product();
gttoc(product);
// sum out frontals, this is the factor on the separator
gttic(sum);
@ -223,6 +226,16 @@ namespace gtsam {
sum = sum->scale();
gttoc(sum);
// Normalize/scale to prevent underflow.
// We divide both `product` and `sum` by `max(sum)`
// since it is faster to compute and when the conditional
// is formed by `product/sum`, the scaling term cancels out.
gttic(scale);
DiscreteFactor::shared_ptr denominator = sum->max(sum->size());
product = product->operator/(denominator);
sum = sum->operator/(denominator);
gttoc(scale);
// Ordering keys for the conditional so that frontalKeys are really in front
Ordering orderedKeys;
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(),

View File

@ -16,19 +16,35 @@
* @author Richard Roberts
*/
#include <gtsam/inference/JunctionTree-inst.h>
#include <gtsam/discrete/DiscreteJunctionTree.h>
#include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteJunctionTree.h>
#include <gtsam/inference/JunctionTree-inst.h>
namespace gtsam {
// Instantiate base classes
template class EliminatableClusterTree<DiscreteBayesTree, DiscreteFactorGraph>;
template class JunctionTree<DiscreteBayesTree, DiscreteFactorGraph>;
// Instantiate base classes
template class EliminatableClusterTree<DiscreteBayesTree, DiscreteFactorGraph>;
template class JunctionTree<DiscreteBayesTree, DiscreteFactorGraph>;
/* ************************************************************************* */
DiscreteJunctionTree::DiscreteJunctionTree(
const DiscreteEliminationTree& eliminationTree) :
Base(eliminationTree) {}
/* ************************************************************************* */
DiscreteJunctionTree::DiscreteJunctionTree(
const DiscreteEliminationTree& eliminationTree)
: Base(eliminationTree) {}
/* ************************************************************************* */
void DiscreteJunctionTree::print(const std::string& s,
const KeyFormatter& keyFormatter) const {
auto visitor = [&keyFormatter](
const std::shared_ptr<DiscreteJunctionTree::Cluster>& node,
const std::string& parentString) {
// Print the current node
node->print(parentString + "-", keyFormatter);
node->factors.print(parentString + "-", keyFormatter);
std::cout << std::endl;
return parentString + "| "; // Increment the indentation
};
std::string parentString = s;
treeTraversal::DepthFirstForest(*this, parentString, visitor);
}
} // namespace gtsam

View File

@ -18,54 +18,71 @@
#pragma once
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteBayesTree.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/inference/JunctionTree.h>
namespace gtsam {
// Forward declarations
class DiscreteEliminationTree;
// Forward declarations
class DiscreteEliminationTree;
/**
* An EliminatableClusterTree, i.e., a set of variable clusters with factors,
* arranged in a tree, with the additional property that it represents the
* clique tree associated with a Bayes net.
*
* In GTSAM a junction tree is an intermediate data structure in multifrontal
* variable elimination. Each node is a cluster of factors, along with a
* clique of variables that are eliminated all at once. In detail, every node k
* represents a clique (maximal fully connected subset) of an associated chordal
* graph, such as a chordal Bayes net resulting from elimination.
*
* The difference with the BayesTree is that a JunctionTree stores factors,
* whereas a BayesTree stores conditionals, that are the product of eliminating
* the factors in the corresponding JunctionTree cliques.
*
* The tree structure and elimination method are exactly analogous to the
* EliminationTree, except that in the JunctionTree, at each node multiple
* variables are eliminated at a time.
*
* \ingroup Multifrontal
* @ingroup discrete
* \nosubgrouping
*/
class GTSAM_EXPORT DiscreteJunctionTree
: public JunctionTree<DiscreteBayesTree, DiscreteFactorGraph> {
public:
typedef JunctionTree<DiscreteBayesTree, DiscreteFactorGraph>
Base; ///< Base class
typedef DiscreteJunctionTree This; ///< This class
typedef std::shared_ptr<This> shared_ptr; ///< Shared pointer to this class
/// @name Constructors
/// @{
/**
* An EliminatableClusterTree, i.e., a set of variable clusters with factors, arranged in a tree,
* with the additional property that it represents the clique tree associated with a Bayes net.
*
* In GTSAM a junction tree is an intermediate data structure in multifrontal
* variable elimination. Each node is a cluster of factors, along with a
* clique of variables that are eliminated all at once. In detail, every node k represents
* a clique (maximal fully connected subset) of an associated chordal graph, such as a
* chordal Bayes net resulting from elimination.
*
* The difference with the BayesTree is that a JunctionTree stores factors, whereas a
* BayesTree stores conditionals, that are the product of eliminating the factors in the
* corresponding JunctionTree cliques.
*
* The tree structure and elimination method are exactly analogous to the EliminationTree,
* except that in the JunctionTree, at each node multiple variables are eliminated at a time.
*
* \ingroup Multifrontal
* @ingroup discrete
* \nosubgrouping
* Build the elimination tree of a factor graph using precomputed column
* structure.
* @param factorGraph The factor graph for which to build the elimination tree
* @param structure The set of factors involving each variable. If this is
* not precomputed, you can call the Create(const FactorGraph<DERIVEDFACTOR>&)
* named constructor instead.
* @return The elimination tree
*/
class GTSAM_EXPORT DiscreteJunctionTree :
public JunctionTree<DiscreteBayesTree, DiscreteFactorGraph> {
public:
typedef JunctionTree<DiscreteBayesTree, DiscreteFactorGraph> Base; ///< Base class
typedef DiscreteJunctionTree This; ///< This class
typedef std::shared_ptr<This> shared_ptr; ///< Shared pointer to this class
DiscreteJunctionTree(const DiscreteEliminationTree& eliminationTree);
/**
* Build the elimination tree of a factor graph using precomputed column structure.
* @param factorGraph The factor graph for which to build the elimination tree
* @param structure The set of factors involving each variable. If this is not
* precomputed, you can call the Create(const FactorGraph<DERIVEDFACTOR>&)
* named constructor instead.
* @return The elimination tree
*/
DiscreteJunctionTree(const DiscreteEliminationTree& eliminationTree);
};
/// @}
/// @name Testable
/// @{
/// typedef for wrapper:
using DiscreteCluster = DiscreteJunctionTree::Cluster;
}
/** Print the tree to cout */
void print(const std::string& name = "DiscreteJunctionTree: ",
const KeyFormatter& formatter = DefaultKeyFormatter) const;
/// @}
};
/// typedef for wrapper:
using DiscreteCluster = DiscreteJunctionTree::Cluster;
} // namespace gtsam

View File

@ -0,0 +1,288 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* DiscreteSearch.cpp
*
* @date January, 2025
* @author Frank Dellaert
*/
#include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteJunctionTree.h>
#include <gtsam/discrete/DiscreteSearch.h>
namespace gtsam {
using Slot = DiscreteSearch::Slot;
using Solution = DiscreteSearch::Solution;
/*
* A SearchNode represents a node in the search tree for the search algorithm.
* Each SearchNode contains a partial assignment of discrete variables, the
* current error, a bound on the final error, and the index of the next
* slot to be assigned.
*/
struct SearchNode {
DiscreteValues assignment; // Partial assignment of discrete variables.
double error; // Current error for the partial assignment.
double bound; // Lower bound on the final error
std::optional<size_t> next; // Index of the next slot to be assigned.
// Construct the root node for the search.
static SearchNode Root(size_t numSlots, double bound) {
return {DiscreteValues(), 0.0, bound, 0};
}
struct Compare {
bool operator()(const SearchNode& a, const SearchNode& b) const {
return a.bound > b.bound; // smallest bound -> highest priority
}
};
// Checks if the node represents a complete assignment.
inline bool isComplete() const { return !next; }
// Expands the node by assigning the next variable(s).
SearchNode expand(const DiscreteValues& fa, const Slot& slot,
std::optional<size_t> nextSlot) const {
// Combine the new frontal assignment with the current partial assignment
DiscreteValues newAssignment = assignment;
for (auto& [key, value] : fa) {
newAssignment[key] = value;
}
double errorSoFar = error + slot.factor->error(newAssignment);
return {newAssignment, errorSoFar, errorSoFar + slot.heuristic, nextSlot};
}
// Prints the SearchNode to an output stream.
friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) {
os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")";
return os;
}
};
struct CompareSolution {
bool operator()(const Solution& a, const Solution& b) const {
return a.error < b.error;
}
};
/*
* A Solutions object maintains a priority queue of the best solutions found
* during the search. The priority queue is limited to a maximum size, and
* solutions are only added if they are better than the worst solution.
*/
class Solutions {
size_t maxSize_; // Maximum number of solutions to keep
std::priority_queue<Solution, std::vector<Solution>, CompareSolution> pq_;
public:
Solutions(size_t maxSize) : maxSize_(maxSize) {}
// Add a solution to the priority queue, possibly evicting the worst one.
// Return true if we added the solution.
bool maybeAdd(double error, const DiscreteValues& assignment) {
const bool full = pq_.size() == maxSize_;
if (full && error >= pq_.top().error) return false;
if (full) pq_.pop();
pq_.emplace(error, assignment);
return true;
}
// Check if we have any solutions
bool empty() const { return pq_.empty(); }
// Method to print all solutions
friend std::ostream& operator<<(std::ostream& os, const Solutions& sn) {
os << "Solutions (top " << sn.pq_.size() << "):\n";
auto pq = sn.pq_;
while (!pq.empty()) {
os << pq.top() << "\n";
pq.pop();
}
return os;
}
// Check if (partial) solution with given bound can be pruned. If we have
// room, we never prune. Otherwise, prune if lower bound on error is worse
// than our current worst error.
bool prune(double bound) const {
if (pq_.size() < maxSize_) return false;
return bound >= pq_.top().error;
}
// Method to extract solutions in ascending order of error
std::vector<Solution> extractSolutions() {
std::vector<Solution> result;
while (!pq_.empty()) {
result.push_back(pq_.top());
pq_.pop();
}
std::sort(
result.begin(), result.end(),
[](const Solution& a, const Solution& b) { return a.error < b.error; });
return result;
}
};
// Get the factor associated with a node, possibly product of factors.
template <typename NodeType>
static DiscreteFactor::shared_ptr getFactor(const NodeType& node) {
const auto& factors = node->factors;
return factors.size() == 1 ? factors.back()
: DiscreteFactorGraph(factors).product();
}
DiscreteSearch::DiscreteSearch(const DiscreteEliminationTree& etree) {
using NodePtr = std::shared_ptr<DiscreteEliminationTree::Node>;
auto visitor = [this](const NodePtr& node, int data) {
const DiscreteFactor::shared_ptr factor = getFactor(node);
const size_t cardinality = factor->cardinality(node->key);
std::vector<std::pair<Key, size_t>> pairs{{node->key, cardinality}};
const Slot slot{factor, DiscreteValues::CartesianProduct(pairs), 0.0};
slots_.emplace_back(std::move(slot));
return data + 1;
};
int data = 0; // unused
treeTraversal::DepthFirstForest(etree, data, visitor);
lowerBound_ = computeHeuristic();
}
DiscreteSearch::DiscreteSearch(const DiscreteJunctionTree& junctionTree) {
using NodePtr = std::shared_ptr<DiscreteJunctionTree::Cluster>;
auto visitor = [this](const NodePtr& cluster, int data) {
const auto factor = getFactor(cluster);
std::vector<std::pair<Key, size_t>> pairs;
for (Key key : cluster->orderedFrontalKeys) {
pairs.emplace_back(key, factor->cardinality(key));
}
const Slot slot{factor, DiscreteValues::CartesianProduct(pairs), 0.0};
slots_.emplace_back(std::move(slot));
return data + 1;
};
int data = 0; // unused
treeTraversal::DepthFirstForest(junctionTree, data, visitor);
lowerBound_ = computeHeuristic();
}
DiscreteSearch DiscreteSearch::FromFactorGraph(
const DiscreteFactorGraph& factorGraph, const Ordering& ordering,
bool buildJunctionTree) {
const DiscreteEliminationTree etree(factorGraph, ordering);
if (buildJunctionTree) {
const DiscreteJunctionTree junctionTree(etree);
return DiscreteSearch(junctionTree);
} else {
return DiscreteSearch(etree);
}
}
DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) {
slots_.reserve(bayesNet.size());
for (auto& conditional : bayesNet) {
const Slot slot{conditional, conditional->frontalAssignments(), 0.0};
slots_.emplace_back(std::move(slot));
}
std::reverse(slots_.begin(), slots_.end());
lowerBound_ = computeHeuristic();
}
DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) {
using NodePtr = DiscreteBayesTree::sharedClique;
auto visitor = [this](const NodePtr& clique, int data) {
auto conditional = clique->conditional();
const Slot slot{conditional, conditional->frontalAssignments(), 0.0};
slots_.emplace_back(std::move(slot));
return data + 1;
};
int data = 0; // unused
treeTraversal::DepthFirstForest(bayesTree, data, visitor);
lowerBound_ = computeHeuristic();
}
void DiscreteSearch::print(const std::string& name,
const KeyFormatter& formatter) const {
std::cout << name << " with " << slots_.size() << " slots:\n";
for (size_t i = 0; i < slots_.size(); ++i) {
std::cout << i << ": " << slots_[i] << std::endl;
}
}
using SearchNodeQueue = std::priority_queue<SearchNode, std::vector<SearchNode>,
SearchNode::Compare>;
std::vector<Solution> DiscreteSearch::run(size_t K) const {
if (slots_.empty()) {
return {Solution(0.0, DiscreteValues())};
}
Solutions solutions(K);
SearchNodeQueue expansions;
expansions.push(SearchNode::Root(slots_.size(), lowerBound_));
// Perform the search
while (!expansions.empty()) {
// Pop the partial assignment with the smallest bound
SearchNode current = expansions.top();
expansions.pop();
// If we already have K solutions, prune if we cannot beat the worst one.
if (solutions.prune(current.bound)) {
continue;
}
// Check if we have a complete assignment
if (current.isComplete()) {
solutions.maybeAdd(current.error, current.assignment);
continue;
}
// Get the next slot to expand
const auto& slot = slots_[*current.next];
std::optional<size_t> nextSlot = *current.next + 1;
if (nextSlot == slots_.size()) nextSlot.reset();
for (auto& fa : slot.assignments) {
auto childNode = current.expand(fa, slot, nextSlot);
// Again, prune if we cannot beat the worst solution
if (!solutions.prune(childNode.bound)) {
expansions.emplace(childNode);
}
}
}
// Extract solutions from bestSolutions in ascending order of error
return solutions.extractSolutions();
}
/*
* We have a number of factors, each with a max value, and we want to compute
* a lower-bound on the cost-to-go for each slot, *not* including this factor.
* For the last slot[n-1], this is 0.0, as the cost after that is zero.
* For the second-to-last slot, it is h = -log(max(factor[n-1])), because after
* we assign slot[n-2] we still need to assign slot[n-1], which will cost *at
* least* h. We return the estimated lower bound of the cost for *all* slots.
*/
double DiscreteSearch::computeHeuristic() {
double error = 0.0;
for (auto it = slots_.rbegin(); it != slots_.rend(); ++it) {
it->heuristic = error;
Ordering ordering(it->factor->begin(), it->factor->end());
auto maxx = it->factor->max(ordering);
error -= std::log(maxx->evaluate({}));
}
return error;
}
} // namespace gtsam

View File

@ -0,0 +1,166 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteSearch.h
* @brief Defines the DiscreteSearch class for discrete search algorithms.
*
* @details This file contains the definition of the DiscreteSearch class, which
* is used in discrete search algorithms to find the K best solutions.
*
* @date January, 2025
* @author Frank Dellaert
*/
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteBayesTree.h>
#include <queue>
namespace gtsam {
/**
* @brief DiscreteSearch: Search for the K best solutions.
*
* This class is used to search for the K best solutions in a DiscreteBayesNet.
* This is implemented with a modified A* search algorithm that uses a priority
* queue to manage the search nodes. That machinery is defined in the .cpp file.
* The heuristic we use is the sum of the log-probabilities of the
* maximum-probability assignments for each slot, for all slots to the right of
* the current slot.
*
* TODO: The heuristic could be refined by using the partial assignment in
* search node to refine the max-probability assignment for the remaining slots.
* This would incur more computation but will lead to fewer expansions.
*/
class GTSAM_EXPORT DiscreteSearch {
public:
/**
* We structure the search as a set of slots, each with a factor and
* a set of variable assignments that need to be chosen. In addition, each
* slot has a heuristic associated with it.
*
* Example:
* The factors in the search problem (always parents before descendents!):
* [P(A), P(B|A), P(C|A,B)]
* The assignments for each factor.
* [[A0,A1], [B0,B1], [C0,C1,C2]]
* A lower bound on the cost-to-go after each slot, e.g.,
* [-log(max_B P(B|A)) -log(max_C P(C|A,B)), -log(max_C P(C|A,B)), 0.0]
* Note that these decrease as we move from right to left.
* We keep the global lower bound as lowerBound_. In the example, it is:
* -log(max_B P(B|A)) -log(max_C P(C|A,B)) -log(max_C P(C|A,B))
*/
struct Slot {
DiscreteFactor::shared_ptr factor;
std::vector<DiscreteValues> assignments;
double heuristic;
friend std::ostream& operator<<(std::ostream& os, const Slot& slot) {
os << "Slot with " << slot.assignments.size()
<< " assignments, heuristic=" << slot.heuristic;
os << ", factor:\n" << slot.factor->markdown() << std::endl;
return os;
}
};
/**
* A solution is a set of assignments, covering all the slots.
* as well as an associated error = -log(probability)
*/
struct Solution {
double error;
DiscreteValues assignment;
Solution(double err, const DiscreteValues& assign)
: error(err), assignment(assign) {}
friend std::ostream& operator<<(std::ostream& os, const Solution& sn) {
os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]";
return os;
}
};
public:
/// @name Standard Constructors
/// @{
/**
* Construct from a DiscreteFactorGraph.
*
* Internally creates either an elimination tree or a junction tree. The
* latter incurs more up-front computation but the search itself might be
* faster. Then again, for the elimination tree, the heuristic will be more
* fine-grained (more slots).
*
* @param factorGraph The factor graph to search over.
* @param ordering The ordering used to create etree (and maybe jtree).
* @param buildJunctionTree Whether to build a junction tree or not.
*/
static DiscreteSearch FromFactorGraph(const DiscreteFactorGraph& factorGraph,
const Ordering& ordering,
bool buildJunctionTree = false);
/// Construct from a DiscreteEliminationTree.
DiscreteSearch(const DiscreteEliminationTree& etree);
/// Construct from a DiscreteJunctionTree.
DiscreteSearch(const DiscreteJunctionTree& junctionTree);
//// Construct from a DiscreteBayesNet.
DiscreteSearch(const DiscreteBayesNet& bayesNet);
/// Construct from a DiscreteBayesTree.
DiscreteSearch(const DiscreteBayesTree& bayesTree);
/// @}
/// @name Testable
/// @{
/** Print the tree to cout */
void print(const std::string& name = "DiscreteSearch: ",
const KeyFormatter& formatter = DefaultKeyFormatter) const;
/// @}
/// @name Standard API
/// @{
/// Return lower bound on the cost-to-go for the entire search
double lowerBound() const { return lowerBound_; }
/// Read access to the slots
const std::vector<Slot>& slots() const { return slots_; }
/**
* @brief Search for the K best solutions.
*
* This method performs a search to find the K best solutions for the given
* DiscreteBayesNet. It uses a priority queue to manage the search nodes,
* expanding nodes with the smallest bound first. The search continues until
* all possible nodes have been expanded or pruned.
*
* @return A vector of the K best solutions found during the search.
*/
std::vector<Solution> run(size_t K = 1) const;
/// @}
private:
/**
* Compute the cumulative lower-bound cost-to-go after each slot is filled.
* @return the estimated lower bound of the cost for *all* slots.
*/
double computeHeuristic();
double lowerBound_; ///< Lower bound on the cost-to-go for the entire search.
std::vector<Slot> slots_; ///< The slots to fill in the search.
};
using DiscreteSearchSolution = DiscreteSearch::Solution; // for wrapping
} // namespace gtsam

View File

@ -26,12 +26,24 @@ using std::stringstream;
namespace gtsam {
/* ************************************************************************ */
static void stream(std::ostream& os, const DiscreteValues& x,
const KeyFormatter& keyFormatter) {
for (const auto& kv : x)
os << "(" << keyFormatter(kv.first) << ", " << kv.second << ")";
}
/* ************************************************************************ */
std::ostream& operator<<(std::ostream& os, const DiscreteValues& x) {
stream(os, x, DefaultKeyFormatter);
return os;
}
/* ************************************************************************ */
void DiscreteValues::print(const string& s,
const KeyFormatter& keyFormatter) const {
cout << s << ": ";
for (auto&& kv : *this)
cout << "(" << keyFormatter(kv.first) << ", " << kv.second << ")";
stream(cout, *this, keyFormatter);
cout << endl;
}

View File

@ -64,6 +64,9 @@ class GTSAM_EXPORT DiscreteValues : public Assignment<Key> {
/// @name Standard Interface
/// @{
/// ostream operator:
friend std::ostream& operator<<(std::ostream& os, const DiscreteValues& x);
// insert in base class;
std::pair<iterator, bool> insert( const value_type& value ){
return Base::insert(value);

View File

@ -464,4 +464,29 @@ class DiscreteJunctionTree {
const gtsam::DiscreteCluster& operator[](size_t i) const;
};
#include <gtsam/discrete/DiscreteSearch.h>
class DiscreteSearchSolution {
double error;
gtsam::DiscreteValues assignment;
DiscreteSearchSolution(double error, const gtsam::DiscreteValues& assignment);
};
class DiscreteSearch {
static DiscreteSearch FromFactorGraph(const gtsam::DiscreteFactorGraph& factorGraph,
const gtsam::Ordering& ordering,
bool buildJunctionTree = false);
DiscreteSearch(const gtsam::DiscreteEliminationTree& etree);
DiscreteSearch(const gtsam::DiscreteJunctionTree& junctionTree);
DiscreteSearch(const gtsam::DiscreteBayesNet& bayesNet);
DiscreteSearch(const gtsam::DiscreteBayesTree& bayesTree);
void print(string name = "DiscreteSearch: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
double lowerBound() const;
std::vector<gtsam::DiscreteSearchSolution> run(size_t K = 1) const;
};
} // namespace gtsam

View File

@ -0,0 +1,61 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/*
* AsiaExample.h
*
* @date Jan, 2025
* @author Frank Dellaert
*/
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/inference/Symbol.h>
namespace gtsam {
namespace asia_example {
static const Key D = Symbol('D', 1), X = Symbol('X', 2), E = Symbol('E', 3),
B = Symbol('B', 4), L = Symbol('L', 5), T = Symbol('T', 6),
S = Symbol('S', 7), A = Symbol('A', 8);
static const DiscreteKey Dyspnea(D, 2), XRay(X, 2), Either(E, 2),
Bronchitis(B, 2), LungCancer(L, 2), Tuberculosis(T, 2), Smoking(S, 2),
Asia(A, 2);
// Function to construct the Asia priors
DiscreteBayesNet createPriors() {
DiscreteBayesNet priors;
priors.add(Smoking % "50/50");
priors.add(Asia, "99/1");
return priors;
}
// Function to construct the incomplete Asia example
DiscreteBayesNet createFragment() {
DiscreteBayesNet fragment;
fragment.add((Either | Tuberculosis, LungCancer) = "F T T T");
fragment.add(LungCancer | Smoking = "99/1 90/10");
fragment.add(Tuberculosis | Asia = "99/1 95/5");
for (const auto& factor : createPriors()) fragment.push_back(factor);
return fragment;
}
// Function to construct the Asia example
DiscreteBayesNet createAsiaExample() {
DiscreteBayesNet asia;
asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9");
asia.add(XRay | Either = "95/5 2/98");
asia.add(Bronchitis | Smoking = "70/30 40/60");
for (const auto& factor : createFragment()) asia.push_back(factor);
return asia;
}
} // namespace asia_example
} // namespace gtsam

View File

@ -23,40 +23,19 @@
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteMarginals.h>
#include <gtsam/inference/Symbol.h>
#include <iostream>
#include <string>
#include <vector>
using namespace std;
#include "AsiaExample.h"
using namespace gtsam;
static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2),
LungCancer(6, 2), Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2);
using ADT = AlgebraicDecisionTree<Key>;
// Function to construct the Asia example
DiscreteBayesNet constructAsiaExample() {
DiscreteBayesNet asia;
asia.add(Asia, "99/1");
asia.add(Smoking % "50/50"); // Signature version
asia.add(Tuberculosis | Asia = "99/1 95/5");
asia.add(LungCancer | Smoking = "99/1 90/10");
asia.add(Bronchitis | Smoking = "70/30 40/60");
asia.add((Either | Tuberculosis, LungCancer) = "F T T T");
asia.add(XRay | Either = "95/5 2/98");
asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9");
return asia;
}
/* ************************************************************************* */
TEST(DiscreteBayesNet, bayesNet) {
using ADT = AlgebraicDecisionTree<Key>;
DiscreteBayesNet bayesNet;
DiscreteKey Parent(0, 2), Child(1, 2);
@ -86,11 +65,12 @@ TEST(DiscreteBayesNet, bayesNet) {
/* ************************************************************************* */
TEST(DiscreteBayesNet, Asia) {
DiscreteBayesNet asia = constructAsiaExample();
using namespace asia_example;
const DiscreteBayesNet asia = createAsiaExample();
// Convert to factor graph
DiscreteFactorGraph fg(asia);
LONGS_EQUAL(3, fg.back()->size());
LONGS_EQUAL(1, fg.back()->size());
// Check the marginals we know (of the parent-less nodes)
DiscreteMarginals marginals(fg);
@ -99,7 +79,7 @@ TEST(DiscreteBayesNet, Asia) {
EXPECT(assert_equal(vs, marginals.marginalProbabilities(Smoking)));
// Create solver and eliminate
const Ordering ordering{0, 1, 2, 3, 4, 5, 6, 7};
const Ordering ordering{A, D, T, X, S, E, L, B};
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
DiscreteConditional expected2(Bronchitis % "11/9");
EXPECT(assert_equal(expected2, *chordal->back()));
@ -144,55 +124,50 @@ TEST(DiscreteBayesNet, Sugar) {
/* ************************************************************************* */
TEST(DiscreteBayesNet, Dot) {
DiscreteBayesNet fragment;
fragment.add(Asia % "99/1");
fragment.add(Smoking % "50/50");
using namespace asia_example;
const DiscreteBayesNet fragment = createFragment();
fragment.add(Tuberculosis | Asia = "99/1 95/5");
fragment.add(LungCancer | Smoking = "99/1 90/10");
fragment.add((Either | Tuberculosis, LungCancer) = "F T T T");
string actual = fragment.dot();
EXPECT(actual ==
"digraph {\n"
" size=\"5,5\";\n"
"\n"
" var0[label=\"0\"];\n"
" var3[label=\"3\"];\n"
" var4[label=\"4\"];\n"
" var5[label=\"5\"];\n"
" var6[label=\"6\"];\n"
"\n"
" var3->var5\n"
" var6->var5\n"
" var4->var6\n"
" var0->var3\n"
"}");
std::string expected =
"digraph {\n"
" size=\"5,5\";\n"
"\n"
" var4683743612465315848[label=\"A8\"];\n"
" var4971973988617027587[label=\"E3\"];\n"
" var5476377146882523141[label=\"L5\"];\n"
" var5980780305148018695[label=\"S7\"];\n"
" var6052837899185946630[label=\"T6\"];\n"
"\n"
" var4683743612465315848->var6052837899185946630\n"
" var5980780305148018695->var5476377146882523141\n"
" var6052837899185946630->var4971973988617027587\n"
" var5476377146882523141->var4971973988617027587\n"
"}";
std::string actual = fragment.dot();
EXPECT(actual.compare(expected) == 0);
}
/* ************************************************************************* */
// Check markdown representation looks as expected.
TEST(DiscreteBayesNet, markdown) {
DiscreteBayesNet fragment;
fragment.add(Asia % "99/1");
fragment.add(Smoking | Asia = "8/2 7/3");
using namespace asia_example;
DiscreteBayesNet priors = createPriors();
string expected =
std::string expected =
"`DiscreteBayesNet` of size 2\n"
"\n"
" *P(Smoking):*\n\n"
"|Smoking|value|\n"
"|:-:|:-:|\n"
"|0|0.5|\n"
"|1|0.5|\n"
"\n"
" *P(Asia):*\n\n"
"|Asia|value|\n"
"|:-:|:-:|\n"
"|0|0.99|\n"
"|1|0.01|\n"
"\n"
" *P(Smoking|Asia):*\n\n"
"|*Asia*|0|1|\n"
"|:-:|:-:|:-:|\n"
"|0|0.8|0.2|\n"
"|1|0.7|0.3|\n\n";
auto formatter = [](Key key) { return key == 0 ? "Asia" : "Smoking"; };
string actual = fragment.markdown(formatter);
"|1|0.01|\n\n";
auto formatter = [](Key key) { return key == A ? "Asia" : "Smoking"; };
std::string actual = priors.markdown(formatter);
EXPECT(actual == expected);
}

View File

@ -0,0 +1,113 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/*
* testDiscreteSearch.cpp
*
* @date January, 2025
* @author Frank Dellaert
*/
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DiscreteSearch.h>
#include "AsiaExample.h"
using namespace gtsam;
// Create Asia Bayes net, FG, and Bayes tree once
namespace asia {
using namespace asia_example;
static const DiscreteBayesNet bayesNet = createAsiaExample();
// Create factor graph and optimize with max-product for MPE
static const DiscreteFactorGraph factorGraph(bayesNet);
static const DiscreteValues mpe = factorGraph.optimize();
// Create ordering
static const Ordering ordering{D, X, B, E, L, T, S, A};
// Create Bayes tree
static const DiscreteBayesTree bayesTree =
*factorGraph.eliminateMultifrontal(ordering);
} // namespace asia
/* ************************************************************************* */
TEST(DiscreteBayesNet, EmptyKBest) {
DiscreteBayesNet net; // no factors
DiscreteSearch search(net);
auto solutions = search.run(3);
// Expect one solution with empty assignment, error=0
EXPECT_LONGS_EQUAL(1, solutions.size());
EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9);
}
/* ************************************************************************* */
TEST(DiscreteBayesTree, EmptyTree) {
DiscreteBayesTree bt;
DiscreteSearch search(bt);
auto solutions = search.run(3);
// We expect exactly 1 solution with error = 0.0 (the empty assignment).
EXPECT_LONGS_EQUAL(1, solutions.size());
EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9);
}
/* ************************************************************************* */
TEST(DiscreteBayesNet, AsiaKBest) {
auto fromETree =
DiscreteSearch::FromFactorGraph(asia::factorGraph, asia::ordering);
auto fromJunctionTree =
DiscreteSearch::FromFactorGraph(asia::factorGraph, asia::ordering, true);
const DiscreteSearch fromBayesNet(asia::bayesNet);
const DiscreteSearch fromBayesTree(asia::bayesTree);
for (auto& search :
{fromETree, fromJunctionTree, fromBayesNet, fromBayesTree}) {
// Ask for the MPE
auto mpe = search.run();
// Regression on error lower bound
EXPECT_DOUBLES_EQUAL(1.205536, search.lowerBound(), 1e-5);
// Check that the cost-to-go heuristic decreases from there
auto slots = search.slots();
double previousHeuristic = search.lowerBound();
for (auto&& slot : slots) {
EXPECT(slot.heuristic <= previousHeuristic);
previousHeuristic = slot.heuristic;
}
EXPECT_LONGS_EQUAL(1, mpe.size());
// Regression test: check the MPE solution
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5);
// Check it is equal to MPE via inference
EXPECT(assert_equal(asia::mpe, mpe[0].assignment));
// Ask for top 4 solutions
auto solutions = search.run(4);
EXPECT_LONGS_EQUAL(4, solutions.size());
// Regression test: check the first and last solution
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);
EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5);
}
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */

View File

@ -202,6 +202,11 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
/* ************************************************************************* */
void HybridBayesTree::prune(const size_t maxNrLeaves) {
if (!this->roots_.at(0)->conditional()->asDiscrete()) {
// Root of the BayesTree is not a discrete clique, so we do nothing.
return;
}
auto prunedDiscreteProbs =
this->roots_.at(0)->conditional()->asDiscrete<TableDistribution>();

View File

@ -99,7 +99,8 @@ AlgebraicDecisionTree<Key> HybridNonlinearFactor::errorTree(
auto errorFunc =
[continuousValues](const std::pair<sharedFactor, double>& f) {
auto [factor, val] = f;
return factor->error(continuousValues) + val;
return factor ? factor->error(continuousValues) + val
: std::numeric_limits<double>::infinity();
};
return {factors_, errorFunc};
}

View File

@ -0,0 +1,35 @@
import numpy as np
from gtsam import Symbol
def make_key(character, index, cardinality):
"""
Helper function to mimic the behavior of gtbook.Variables discrete_series function.
"""
symbol = Symbol(character, index)
key = symbol.key()
return (key, cardinality)
def generate_transition_cpt(num_states, transitions=None):
"""
Generate a row-wise CPT for a transition matrix.
"""
if transitions is None:
# Default to identity matrix with slight regularization
transitions = np.eye(num_states) + 0.1 / num_states
# Ensure transitions sum to 1 if not already normalized
transitions /= np.sum(transitions, axis=1, keepdims=True)
return " ".join(["/".join(map(str, row)) for row in transitions])
def generate_observation_cpt(num_states, num_obs, desired_state):
"""
Generate a row-wise CPT for observations with contrived probabilities.
"""
obs = np.zeros((num_states, num_obs + 1))
obs[:, -1] = 1 # All states default to measurement num_obs
obs[desired_state, 0:-1] = 1
obs[desired_state, -1] = 0
return " ".join(["/".join(map(str, row)) for row in obs])

View File

@ -15,10 +15,16 @@ import unittest
import numpy as np
from gtsam.utils.test_case import GtsamTestCase
from dfg_utils import make_key, generate_transition_cpt, generate_observation_cpt
from gtsam import (DecisionTreeFactor, DiscreteConditional,
DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering,
Symbol)
from gtsam import (
DecisionTreeFactor,
DiscreteConditional,
DiscreteFactorGraph,
DiscreteKeys,
DiscreteValues,
Ordering,
)
OrderingType = Ordering.OrderingType
@ -50,7 +56,7 @@ class TestDiscreteFactorGraph(GtsamTestCase):
assignment[1] = 1
# Check if graph evaluation works ( 0.3*0.6*4 )
self.assertAlmostEqual(.72, graph(assignment))
self.assertAlmostEqual(0.72, graph(assignment))
# Create a new test with third node and adding unary and ternary factor
graph.add(P3, "0.9 0.2 0.5")
@ -100,8 +106,7 @@ class TestDiscreteFactorGraph(GtsamTestCase):
expectedValues[1] = 0
expectedValues[2] = 0
actualValues = graph.optimize()
self.assertEqual(list(actualValues.items()),
list(expectedValues.items()))
self.assertEqual(list(actualValues.items()), list(expectedValues.items()))
def test_MPE(self):
"""Test maximum probable explanation (MPE): same as optimize."""
@ -123,13 +128,11 @@ class TestDiscreteFactorGraph(GtsamTestCase):
# Use maxProduct
dag = graph.maxProduct(OrderingType.COLAMD)
actualMPE = dag.argmax()
self.assertEqual(list(actualMPE.items()),
list(mpe.items()))
self.assertEqual(list(actualMPE.items()), list(mpe.items()))
# All in one
actualMPE2 = graph.optimize()
self.assertEqual(list(actualMPE2.items()),
list(mpe.items()))
self.assertEqual(list(actualMPE2.items()), list(mpe.items()))
def test_sumProduct(self):
"""Test sumProduct."""
@ -154,11 +157,17 @@ class TestDiscreteFactorGraph(GtsamTestCase):
self.assertAlmostEqual(mpeProbability, 0.36) # regression
# Use sumProduct
for ordering_type in [OrderingType.COLAMD, OrderingType.METIS, OrderingType.NATURAL,
OrderingType.CUSTOM]:
for ordering_type in [
OrderingType.COLAMD,
OrderingType.METIS,
OrderingType.NATURAL,
OrderingType.CUSTOM,
]:
bayesNet = graph.sumProduct(ordering_type)
self.assertEqual(bayesNet(mpe), mpeProbability)
class TestChains(GtsamTestCase):
def test_MPE_chain(self):
"""
Test for numerical underflow in EliminateMPE on long chains.
@ -170,46 +179,22 @@ class TestDiscreteFactorGraph(GtsamTestCase):
desired_state = 1
states = list(range(num_states))
# Helper function to mimic the behavior of gtbook.Variables discrete_series function
def make_key(character, index, cardinality):
symbol = Symbol(character, index)
key = symbol.key()
return (key, cardinality)
X = {index: make_key("X", index, len(states)) for index in range(num_obs)}
Z = {index: make_key("Z", index, num_obs + 1) for index in range(num_obs)}
graph = DiscreteFactorGraph()
# Mostly identity transition matrix
transitions = np.eye(num_states)
# Needed otherwise mpe is always state 0?
transitions += 0.1/(num_states)
transition_cpt = []
for i in range(0, num_states):
transition_row = "/".join([str(x) for x in transitions[i]])
transition_cpt.append(transition_row)
transition_cpt = " ".join(transition_cpt)
transition_cpt = generate_transition_cpt(num_states)
for i in reversed(range(1, num_obs)):
transition_conditional = DiscreteConditional(X[i], [X[i-1]], transition_cpt)
transition_conditional = DiscreteConditional(
X[i], [X[i - 1]], transition_cpt
)
graph.push_back(transition_conditional)
# Contrived example such that the desired state gives measurements [0, num_obs) with equal probability
# but all other states always give measurement num_obs
obs = np.zeros((num_states, num_obs+1))
obs[:,-1] = 1
obs[desired_state,0: -1] = 1
obs[desired_state,-1] = 0
obs_cpt_list = []
for i in range(0, num_states):
obs_row = "/".join([str(z) for z in obs[i]])
obs_cpt_list.append(obs_row)
obs_cpt = " ".join(obs_cpt_list)
obs_cpt = generate_observation_cpt(num_states, num_obs, desired_state)
# Contrived example where each measurement is its own index
for i in range(0, num_obs):
for i in range(num_obs):
obs_conditional = DiscreteConditional(Z[i], [X[i]], obs_cpt)
factor = obs_conditional.likelihood(i)
graph.push_back(factor)
@ -217,7 +202,7 @@ class TestDiscreteFactorGraph(GtsamTestCase):
mpe = graph.optimize()
vals = [mpe[X[i][0]] for i in range(num_obs)]
self.assertEqual(vals, [desired_state]*num_obs)
self.assertEqual(vals, [desired_state] * num_obs)
def test_sumProduct_chain(self):
"""
@ -227,15 +212,8 @@ class TestDiscreteFactorGraph(GtsamTestCase):
"""
num_states = 3
chain_length = 400
desired_state = 1
states = list(range(num_states))
# Helper function to mimic the behavior of gtbook.Variables discrete_series function
def make_key(character, index, cardinality):
symbol = Symbol(character, index)
key = symbol.key()
return (key, cardinality)
X = {index: make_key("X", index, len(states)) for index in range(chain_length)}
graph = DiscreteFactorGraph()
@ -253,18 +231,15 @@ class TestDiscreteFactorGraph(GtsamTestCase):
# Ensure that the stationary distribution is positive and normalized
stationary_dist /= np.sum(stationary_dist)
expected = DecisionTreeFactor(X[chain_length-1], stationary_dist.flatten())
expected = DecisionTreeFactor(X[chain_length - 1], stationary_dist.ravel())
# The transition matrix parsed by DiscreteConditional is a row-wise CPT
transitions = transitions.T
transition_cpt = []
for i in range(0, num_states):
transition_row = "/".join([str(x) for x in transitions[i]])
transition_cpt.append(transition_row)
transition_cpt = " ".join(transition_cpt)
transition_cpt = generate_transition_cpt(num_states, transitions.T)
for i in reversed(range(1, chain_length)):
transition_conditional = DiscreteConditional(X[i], [X[i-1]], transition_cpt)
transition_conditional = DiscreteConditional(
X[i], [X[i - 1]], transition_cpt
)
graph.push_back(transition_conditional)
# Run sum product using natural ordering so the resulting Bayes net has the form:
@ -277,5 +252,6 @@ class TestDiscreteFactorGraph(GtsamTestCase):
# Ensure marginal probabilities are close to the stationary distribution
self.gtsamAssertEquals(expected, last_marginal)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,84 @@
"""
GTSAM Copyright 2010-2019, Georgia Tech Research Corporation,
Atlanta, Georgia 30332-0415
All Rights Reserved
See LICENSE for the license information
Unit tests for Discrete Search.
Author: Frank Dellaert
"""
# pylint: disable=no-name-in-module, invalid-name
import unittest
from dfg_utils import generate_observation_cpt, generate_transition_cpt, make_key
from gtsam.utils.test_case import GtsamTestCase
from gtsam import (
DiscreteConditional,
DiscreteFactorGraph,
DiscreteSearch,
Ordering,
DefaultKeyFormatter,
)
OrderingType = Ordering.OrderingType
class TestDiscreteSearch(GtsamTestCase):
"""Tests for Discrete Factor Graphs."""
def test_MPE_chain(self):
"""
Test for numerical underflow in EliminateMPE on long chains.
Adapted from the toy problem of @pcl15423
Ref: https://github.com/borglab/gtsam/issues/1448
"""
num_states = 3
num_obs = 200
desired_state = 1
states = list(range(num_states))
X = {index: make_key("X", index, len(states)) for index in range(num_obs)}
Z = {index: make_key("Z", index, num_obs + 1) for index in range(num_obs)}
graph = DiscreteFactorGraph()
transition_cpt = generate_transition_cpt(num_states)
for i in reversed(range(1, num_obs)):
transition_conditional = DiscreteConditional(
X[i], [X[i - 1]], transition_cpt
)
graph.push_back(transition_conditional)
# Contrived example such that the desired state gives measurements [0, num_obs) with equal
# probability but all other states always give measurement num_obs
obs_cpt = generate_observation_cpt(num_states, num_obs, desired_state)
# Contrived example where each measurement is its own index
for i in range(num_obs):
obs_conditional = DiscreteConditional(Z[i], [X[i]], obs_cpt)
factor = obs_conditional.likelihood(i)
graph.push_back(factor)
# Check MPE
mpe = graph.optimize()
vals = [mpe[X[i][0]] for i in range(num_obs)]
self.assertEqual(vals, [desired_state] * num_obs)
# Create an ordering:
ordering = Ordering()
for i in reversed(range(num_obs)):
ordering.push_back(X[i][0])
# Now do Search
search = DiscreteSearch.FromFactorGraph(graph, ordering)
solutions = search.run(K=1)
mpe2 = solutions[0].assignment
# print({DefaultKeyFormatter(key): value for key, value in mpe2.items()})
vals = [mpe2[X[i][0]] for i in range(num_obs)]
self.assertEqual(vals, [desired_state] * num_obs)
if __name__ == "__main__":
unittest.main()