Better and more consistent documentation.
parent
8c7e75bb25
commit
1afb089143
|
|
@ -9,7 +9,7 @@
|
|||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/*
|
||||
/**
|
||||
* DiscreteSearch.cpp
|
||||
*
|
||||
* @date January, 2025
|
||||
|
|
@ -25,22 +25,19 @@ namespace gtsam {
|
|||
using Slot = DiscreteSearch::Slot;
|
||||
using Solution = DiscreteSearch::Solution;
|
||||
|
||||
/**
|
||||
* @brief Represents a node in the search tree for discrete search algorithms.
|
||||
*
|
||||
* @details Each SearchNode contains a partial assignment of discrete variables,
|
||||
* the current error, a bound on the final error, and the index of the next
|
||||
* conditional to be assigned.
|
||||
/*
|
||||
* A SearchNode represents a node in the search tree for the search algorithm.
|
||||
* Each SearchNode contains a partial assignment of discrete variables, the
|
||||
* current error, a bound on the final error, and the index of the next
|
||||
* slot to be assigned.
|
||||
*/
|
||||
struct SearchNode {
|
||||
DiscreteValues assignment; ///< Partial assignment of discrete variables.
|
||||
double error; ///< Current error for the partial assignment.
|
||||
double bound; ///< Lower bound on the final error
|
||||
std::optional<size_t> next; ///< Index of the next factor to be assigned.
|
||||
DiscreteValues assignment; // Partial assignment of discrete variables.
|
||||
double error; // Current error for the partial assignment.
|
||||
double bound; // Lower bound on the final error
|
||||
std::optional<size_t> next; // Index of the next slot to be assigned.
|
||||
|
||||
/**
|
||||
* @brief Construct the root node for the search.
|
||||
*/
|
||||
// Construct the root node for the search.
|
||||
static SearchNode Root(size_t numSlots, double bound) {
|
||||
return {DiscreteValues(), 0.0, bound, 0};
|
||||
}
|
||||
|
|
@ -51,10 +48,10 @@ struct SearchNode {
|
|||
}
|
||||
};
|
||||
|
||||
/// Checks if the node represents a complete assignment.
|
||||
// Checks if the node represents a complete assignment.
|
||||
inline bool isComplete() const { return !next; }
|
||||
|
||||
/// Expands the node by assigning the next variable(s).
|
||||
// Expands the node by assigning the next variable(s).
|
||||
SearchNode expand(const DiscreteValues& fa, const Slot& slot,
|
||||
std::optional<size_t> nextSlot) const {
|
||||
// Combine the new frontal assignment with the current partial assignment
|
||||
|
|
@ -66,7 +63,7 @@ struct SearchNode {
|
|||
return {newAssignment, errorSoFar, errorSoFar + slot.heuristic, nextSlot};
|
||||
}
|
||||
|
||||
/// Prints the SearchNode to an output stream.
|
||||
// Prints the SearchNode to an output stream.
|
||||
friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) {
|
||||
os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")";
|
||||
return os;
|
||||
|
|
@ -79,17 +76,20 @@ struct CompareSolution {
|
|||
}
|
||||
};
|
||||
|
||||
// Define the Solutions class
|
||||
/*
|
||||
* A Solutions object maintains a priority queue of the best solutions found
|
||||
* during the search. The priority queue is limited to a maximum size, and
|
||||
* solutions are only added if they are better than the worst solution.
|
||||
*/
|
||||
class Solutions {
|
||||
private:
|
||||
size_t maxSize_;
|
||||
size_t maxSize_; // Maximum number of solutions to keep
|
||||
std::priority_queue<Solution, std::vector<Solution>, CompareSolution> pq_;
|
||||
|
||||
public:
|
||||
Solutions(size_t maxSize) : maxSize_(maxSize) {}
|
||||
|
||||
/// Add a solution to the priority queue, possibly evicting the worst one.
|
||||
/// Return true if we added the solution.
|
||||
// Add a solution to the priority queue, possibly evicting the worst one.
|
||||
// Return true if we added the solution.
|
||||
bool maybeAdd(double error, const DiscreteValues& assignment) {
|
||||
const bool full = pq_.size() == maxSize_;
|
||||
if (full && error >= pq_.top().error) return false;
|
||||
|
|
@ -98,7 +98,7 @@ class Solutions {
|
|||
return true;
|
||||
}
|
||||
|
||||
/// Check if we have any solutions
|
||||
// Check if we have any solutions
|
||||
bool empty() const { return pq_.empty(); }
|
||||
|
||||
// Method to print all solutions
|
||||
|
|
@ -112,9 +112,9 @@ class Solutions {
|
|||
return os;
|
||||
}
|
||||
|
||||
/// Check if (partial) solution with given bound can be pruned. If we have
|
||||
/// room, we never prune. Otherwise, prune if lower bound on error is worse
|
||||
/// than our current worst error.
|
||||
// Check if (partial) solution with given bound can be pruned. If we have
|
||||
// room, we never prune. Otherwise, prune if lower bound on error is worse
|
||||
// than our current worst error.
|
||||
bool prune(double bound) const {
|
||||
if (pq_.size() < maxSize_) return false;
|
||||
return bound >= pq_.top().error;
|
||||
|
|
@ -134,9 +134,9 @@ class Solutions {
|
|||
}
|
||||
};
|
||||
|
||||
/// @brief Get the factor associated with a node, possibly product of factors.
|
||||
// Get the factor associated with a node, possibly product of factors.
|
||||
template <typename NodeType>
|
||||
static auto getFactor(const NodeType& node) {
|
||||
static DiscreteFactor::shared_ptr getFactor(const NodeType& node) {
|
||||
const auto& factors = node->factors;
|
||||
return factors.size() == 1 ? factors.back()
|
||||
: DiscreteFactorGraph(factors).product();
|
||||
|
|
@ -145,7 +145,7 @@ static auto getFactor(const NodeType& node) {
|
|||
DiscreteSearch::DiscreteSearch(const DiscreteEliminationTree& etree) {
|
||||
using NodePtr = std::shared_ptr<DiscreteEliminationTree::Node>;
|
||||
auto visitor = [this](const NodePtr& node, int data) {
|
||||
const auto factor = getFactor(node);
|
||||
const DiscreteFactor::shared_ptr factor = getFactor(node);
|
||||
const size_t cardinality = factor->cardinality(node->key);
|
||||
std::vector<std::pair<Key, size_t>> pairs{{node->key, cardinality}};
|
||||
const Slot slot{factor, DiscreteValues::CartesianProduct(pairs), 0.0};
|
||||
|
|
@ -266,13 +266,14 @@ std::vector<Solution> DiscreteSearch::run(size_t K) const {
|
|||
// Extract solutions from bestSolutions in ascending order of error
|
||||
return solutions.extractSolutions();
|
||||
}
|
||||
|
||||
// We have a number of factors, each with a max value, and we want to compute
|
||||
// a lower-bound on the cost-to-go for each slot, *not* including this factor.
|
||||
// For the last slot, this is 0.0, as the cost after that is zero.
|
||||
// For the second-to-last slot, it is -log(max(factor[0])), because after we
|
||||
// assign slot[1] we still need to assign slot[0], which will cost *at least*
|
||||
// h0. We return the estimated lower bound of the cost for *all* slots.
|
||||
/*
|
||||
* We have a number of factors, each with a max value, and we want to compute
|
||||
* a lower-bound on the cost-to-go for each slot, *not* including this factor.
|
||||
* For the last slot[n-1], this is 0.0, as the cost after that is zero.
|
||||
* For the second-to-last slot, it is h = -log(max(factor[n-1])), because after
|
||||
* we assign slot[n-2] we still need to assign slot[n-1], which will cost *at
|
||||
* least* h. We return the estimated lower bound of the cost for *all* slots.
|
||||
*/
|
||||
double DiscreteSearch::computeHeuristic() {
|
||||
double error = 0.0;
|
||||
for (auto it = slots_.rbegin(); it != slots_.rend(); ++it) {
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@
|
|||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/*
|
||||
/**
|
||||
* @file DiscreteSearch.h
|
||||
* @brief Defines the DiscreteSearch class for discrete search algorithms.
|
||||
*
|
||||
|
|
@ -28,24 +28,40 @@
|
|||
namespace gtsam {
|
||||
|
||||
/**
|
||||
* DiscreteSearch: Search for the K best solutions.
|
||||
* @brief DiscreteSearch: Search for the K best solutions.
|
||||
*
|
||||
* This class is used to search for the K best solutions in a DiscreteBayesNet.
|
||||
* This is implemented with a modified A* search algorithm that uses a priority
|
||||
* queue to manage the search nodes. That machinery is defined in the .cpp file.
|
||||
* The heuristic we use is the sum of the log-probabilities of the
|
||||
* maximum-probability assignments for each slot, for all slots to the right of
|
||||
* the current slot.
|
||||
*
|
||||
* TODO: The heuristic could be refined by using the partial assignment in
|
||||
* search node to refine the max-probability assignment for the remaining slots.
|
||||
* This would incur more computation but will lead to fewer expansions.
|
||||
*/
|
||||
class GTSAM_EXPORT DiscreteSearch {
|
||||
public:
|
||||
/// We structure the search as a set of slots, each with a factor and
|
||||
/// a set of variable assignments that need to be chosen. In addition, each
|
||||
/// slot has a heuristic associated with it.
|
||||
/**
|
||||
* We structure the search as a set of slots, each with a factor and
|
||||
* a set of variable assignments that need to be chosen. In addition, each
|
||||
* slot has a heuristic associated with it.
|
||||
*
|
||||
* Example:
|
||||
* The factors in the search problem (always parents before descendents!):
|
||||
* [P(A), P(B|A), P(C|A,B)]
|
||||
* The assignments for each factor.
|
||||
* [[A0,A1], [B0,B1], [C0,C1,C2]]
|
||||
* A lower bound on the cost-to-go after each slot, e.g.,
|
||||
* [-log(max_B P(B|A)) -log(max_C P(C|A,B)), -log(max_C P(C|A,B)), 0.0]
|
||||
* Note that these decrease as we move from right to left.
|
||||
* We keep the global lower bound as lowerBound_. In the example, it is:
|
||||
* -log(max_B P(B|A)) -log(max_C P(C|A,B)) -log(max_C P(C|A,B))
|
||||
*/
|
||||
struct Slot {
|
||||
/// The factors in the search problem,
|
||||
/// e.g., [P(B|A),P(A)]
|
||||
DiscreteFactor::shared_ptr factor;
|
||||
|
||||
/// The assignments for each factor,
|
||||
/// e.g., [[B0,B1] [A0,A1]]
|
||||
std::vector<DiscreteValues> assignments;
|
||||
|
||||
/// A lower bound on the cost-to-go for each slot, e.g.,
|
||||
/// [-log(max_B P(B|A)), -log(max_A P(A))]
|
||||
double heuristic;
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const Slot& slot) {
|
||||
|
|
@ -56,8 +72,10 @@ class GTSAM_EXPORT DiscreteSearch {
|
|||
}
|
||||
};
|
||||
|
||||
/// A solution is then a set of assignments, covering all the slots.
|
||||
/// as well as an associated error = -log(probability)
|
||||
/**
|
||||
* A solution is a set of assignments, covering all the slots.
|
||||
* as well as an associated error = -log(probability)
|
||||
*/
|
||||
struct Solution {
|
||||
double error;
|
||||
DiscreteValues assignment;
|
||||
|
|
@ -89,28 +107,16 @@ class GTSAM_EXPORT DiscreteSearch {
|
|||
const Ordering& ordering,
|
||||
bool buildJunctionTree = false);
|
||||
|
||||
/**
|
||||
* @brief Constructor from a DiscreteEliminationTree.
|
||||
*
|
||||
* @param etree The DiscreteEliminationTree to initialize from.
|
||||
*/
|
||||
/// Construct from a DiscreteEliminationTree.
|
||||
DiscreteSearch(const DiscreteEliminationTree& etree);
|
||||
|
||||
/**
|
||||
* @brief Constructor from a DiscreteJunctionTree.
|
||||
*
|
||||
* @param junctionTree The DiscreteJunctionTree to initialize from.
|
||||
*/
|
||||
/// Construct from a DiscreteJunctionTree.
|
||||
DiscreteSearch(const DiscreteJunctionTree& junctionTree);
|
||||
|
||||
/**
|
||||
* Construct from a DiscreteBayesNet.
|
||||
*/
|
||||
//// Construct from a DiscreteBayesNet.
|
||||
DiscreteSearch(const DiscreteBayesNet& bayesNet);
|
||||
|
||||
/**
|
||||
* Construct from a DiscreteBayesTree.
|
||||
*/
|
||||
/// Construct from a DiscreteBayesTree.
|
||||
DiscreteSearch(const DiscreteBayesTree& bayesTree);
|
||||
|
||||
/// @}
|
||||
|
|
@ -146,8 +152,10 @@ class GTSAM_EXPORT DiscreteSearch {
|
|||
/// @}
|
||||
|
||||
private:
|
||||
/// Compute the cumulative lower-bound cost-to-go after each slot is filled.
|
||||
/// @return the estimated lower bound of the cost for *all* slots.
|
||||
/**
|
||||
* Compute the cumulative lower-bound cost-to-go after each slot is filled.
|
||||
* @return the estimated lower bound of the cost for *all* slots.
|
||||
*/
|
||||
double computeHeuristic();
|
||||
|
||||
double lowerBound_; ///< Lower bound on the cost-to-go for the entire search.
|
||||
|
|
|
|||
Loading…
Reference in New Issue