Better and more consistent documentation.
parent
8c7e75bb25
commit
1afb089143
|
|
@ -9,7 +9,7 @@
|
||||||
|
|
||||||
* -------------------------------------------------------------------------- */
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
/*
|
/**
|
||||||
* DiscreteSearch.cpp
|
* DiscreteSearch.cpp
|
||||||
*
|
*
|
||||||
* @date January, 2025
|
* @date January, 2025
|
||||||
|
|
@ -25,22 +25,19 @@ namespace gtsam {
|
||||||
using Slot = DiscreteSearch::Slot;
|
using Slot = DiscreteSearch::Slot;
|
||||||
using Solution = DiscreteSearch::Solution;
|
using Solution = DiscreteSearch::Solution;
|
||||||
|
|
||||||
/**
|
/*
|
||||||
* @brief Represents a node in the search tree for discrete search algorithms.
|
* A SearchNode represents a node in the search tree for the search algorithm.
|
||||||
*
|
* Each SearchNode contains a partial assignment of discrete variables, the
|
||||||
* @details Each SearchNode contains a partial assignment of discrete variables,
|
* current error, a bound on the final error, and the index of the next
|
||||||
* the current error, a bound on the final error, and the index of the next
|
* slot to be assigned.
|
||||||
* conditional to be assigned.
|
|
||||||
*/
|
*/
|
||||||
struct SearchNode {
|
struct SearchNode {
|
||||||
DiscreteValues assignment; ///< Partial assignment of discrete variables.
|
DiscreteValues assignment; // Partial assignment of discrete variables.
|
||||||
double error; ///< Current error for the partial assignment.
|
double error; // Current error for the partial assignment.
|
||||||
double bound; ///< Lower bound on the final error
|
double bound; // Lower bound on the final error
|
||||||
std::optional<size_t> next; ///< Index of the next factor to be assigned.
|
std::optional<size_t> next; // Index of the next slot to be assigned.
|
||||||
|
|
||||||
/**
|
// Construct the root node for the search.
|
||||||
* @brief Construct the root node for the search.
|
|
||||||
*/
|
|
||||||
static SearchNode Root(size_t numSlots, double bound) {
|
static SearchNode Root(size_t numSlots, double bound) {
|
||||||
return {DiscreteValues(), 0.0, bound, 0};
|
return {DiscreteValues(), 0.0, bound, 0};
|
||||||
}
|
}
|
||||||
|
|
@ -51,10 +48,10 @@ struct SearchNode {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Checks if the node represents a complete assignment.
|
// Checks if the node represents a complete assignment.
|
||||||
inline bool isComplete() const { return !next; }
|
inline bool isComplete() const { return !next; }
|
||||||
|
|
||||||
/// Expands the node by assigning the next variable(s).
|
// Expands the node by assigning the next variable(s).
|
||||||
SearchNode expand(const DiscreteValues& fa, const Slot& slot,
|
SearchNode expand(const DiscreteValues& fa, const Slot& slot,
|
||||||
std::optional<size_t> nextSlot) const {
|
std::optional<size_t> nextSlot) const {
|
||||||
// Combine the new frontal assignment with the current partial assignment
|
// Combine the new frontal assignment with the current partial assignment
|
||||||
|
|
@ -66,7 +63,7 @@ struct SearchNode {
|
||||||
return {newAssignment, errorSoFar, errorSoFar + slot.heuristic, nextSlot};
|
return {newAssignment, errorSoFar, errorSoFar + slot.heuristic, nextSlot};
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Prints the SearchNode to an output stream.
|
// Prints the SearchNode to an output stream.
|
||||||
friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) {
|
friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) {
|
||||||
os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")";
|
os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")";
|
||||||
return os;
|
return os;
|
||||||
|
|
@ -79,17 +76,20 @@ struct CompareSolution {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Define the Solutions class
|
/*
|
||||||
|
* A Solutions object maintains a priority queue of the best solutions found
|
||||||
|
* during the search. The priority queue is limited to a maximum size, and
|
||||||
|
* solutions are only added if they are better than the worst solution.
|
||||||
|
*/
|
||||||
class Solutions {
|
class Solutions {
|
||||||
private:
|
size_t maxSize_; // Maximum number of solutions to keep
|
||||||
size_t maxSize_;
|
|
||||||
std::priority_queue<Solution, std::vector<Solution>, CompareSolution> pq_;
|
std::priority_queue<Solution, std::vector<Solution>, CompareSolution> pq_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
Solutions(size_t maxSize) : maxSize_(maxSize) {}
|
Solutions(size_t maxSize) : maxSize_(maxSize) {}
|
||||||
|
|
||||||
/// Add a solution to the priority queue, possibly evicting the worst one.
|
// Add a solution to the priority queue, possibly evicting the worst one.
|
||||||
/// Return true if we added the solution.
|
// Return true if we added the solution.
|
||||||
bool maybeAdd(double error, const DiscreteValues& assignment) {
|
bool maybeAdd(double error, const DiscreteValues& assignment) {
|
||||||
const bool full = pq_.size() == maxSize_;
|
const bool full = pq_.size() == maxSize_;
|
||||||
if (full && error >= pq_.top().error) return false;
|
if (full && error >= pq_.top().error) return false;
|
||||||
|
|
@ -98,7 +98,7 @@ class Solutions {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check if we have any solutions
|
// Check if we have any solutions
|
||||||
bool empty() const { return pq_.empty(); }
|
bool empty() const { return pq_.empty(); }
|
||||||
|
|
||||||
// Method to print all solutions
|
// Method to print all solutions
|
||||||
|
|
@ -112,9 +112,9 @@ class Solutions {
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check if (partial) solution with given bound can be pruned. If we have
|
// Check if (partial) solution with given bound can be pruned. If we have
|
||||||
/// room, we never prune. Otherwise, prune if lower bound on error is worse
|
// room, we never prune. Otherwise, prune if lower bound on error is worse
|
||||||
/// than our current worst error.
|
// than our current worst error.
|
||||||
bool prune(double bound) const {
|
bool prune(double bound) const {
|
||||||
if (pq_.size() < maxSize_) return false;
|
if (pq_.size() < maxSize_) return false;
|
||||||
return bound >= pq_.top().error;
|
return bound >= pq_.top().error;
|
||||||
|
|
@ -134,9 +134,9 @@ class Solutions {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// @brief Get the factor associated with a node, possibly product of factors.
|
// Get the factor associated with a node, possibly product of factors.
|
||||||
template <typename NodeType>
|
template <typename NodeType>
|
||||||
static auto getFactor(const NodeType& node) {
|
static DiscreteFactor::shared_ptr getFactor(const NodeType& node) {
|
||||||
const auto& factors = node->factors;
|
const auto& factors = node->factors;
|
||||||
return factors.size() == 1 ? factors.back()
|
return factors.size() == 1 ? factors.back()
|
||||||
: DiscreteFactorGraph(factors).product();
|
: DiscreteFactorGraph(factors).product();
|
||||||
|
|
@ -145,7 +145,7 @@ static auto getFactor(const NodeType& node) {
|
||||||
DiscreteSearch::DiscreteSearch(const DiscreteEliminationTree& etree) {
|
DiscreteSearch::DiscreteSearch(const DiscreteEliminationTree& etree) {
|
||||||
using NodePtr = std::shared_ptr<DiscreteEliminationTree::Node>;
|
using NodePtr = std::shared_ptr<DiscreteEliminationTree::Node>;
|
||||||
auto visitor = [this](const NodePtr& node, int data) {
|
auto visitor = [this](const NodePtr& node, int data) {
|
||||||
const auto factor = getFactor(node);
|
const DiscreteFactor::shared_ptr factor = getFactor(node);
|
||||||
const size_t cardinality = factor->cardinality(node->key);
|
const size_t cardinality = factor->cardinality(node->key);
|
||||||
std::vector<std::pair<Key, size_t>> pairs{{node->key, cardinality}};
|
std::vector<std::pair<Key, size_t>> pairs{{node->key, cardinality}};
|
||||||
const Slot slot{factor, DiscreteValues::CartesianProduct(pairs), 0.0};
|
const Slot slot{factor, DiscreteValues::CartesianProduct(pairs), 0.0};
|
||||||
|
|
@ -266,13 +266,14 @@ std::vector<Solution> DiscreteSearch::run(size_t K) const {
|
||||||
// Extract solutions from bestSolutions in ascending order of error
|
// Extract solutions from bestSolutions in ascending order of error
|
||||||
return solutions.extractSolutions();
|
return solutions.extractSolutions();
|
||||||
}
|
}
|
||||||
|
/*
|
||||||
// We have a number of factors, each with a max value, and we want to compute
|
* We have a number of factors, each with a max value, and we want to compute
|
||||||
// a lower-bound on the cost-to-go for each slot, *not* including this factor.
|
* a lower-bound on the cost-to-go for each slot, *not* including this factor.
|
||||||
// For the last slot, this is 0.0, as the cost after that is zero.
|
* For the last slot[n-1], this is 0.0, as the cost after that is zero.
|
||||||
// For the second-to-last slot, it is -log(max(factor[0])), because after we
|
* For the second-to-last slot, it is h = -log(max(factor[n-1])), because after
|
||||||
// assign slot[1] we still need to assign slot[0], which will cost *at least*
|
* we assign slot[n-2] we still need to assign slot[n-1], which will cost *at
|
||||||
// h0. We return the estimated lower bound of the cost for *all* slots.
|
* least* h. We return the estimated lower bound of the cost for *all* slots.
|
||||||
|
*/
|
||||||
double DiscreteSearch::computeHeuristic() {
|
double DiscreteSearch::computeHeuristic() {
|
||||||
double error = 0.0;
|
double error = 0.0;
|
||||||
for (auto it = slots_.rbegin(); it != slots_.rend(); ++it) {
|
for (auto it = slots_.rbegin(); it != slots_.rend(); ++it) {
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@
|
||||||
|
|
||||||
* -------------------------------------------------------------------------- */
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
/*
|
/**
|
||||||
* @file DiscreteSearch.h
|
* @file DiscreteSearch.h
|
||||||
* @brief Defines the DiscreteSearch class for discrete search algorithms.
|
* @brief Defines the DiscreteSearch class for discrete search algorithms.
|
||||||
*
|
*
|
||||||
|
|
@ -28,24 +28,40 @@
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* DiscreteSearch: Search for the K best solutions.
|
* @brief DiscreteSearch: Search for the K best solutions.
|
||||||
|
*
|
||||||
|
* This class is used to search for the K best solutions in a DiscreteBayesNet.
|
||||||
|
* This is implemented with a modified A* search algorithm that uses a priority
|
||||||
|
* queue to manage the search nodes. That machinery is defined in the .cpp file.
|
||||||
|
* The heuristic we use is the sum of the log-probabilities of the
|
||||||
|
* maximum-probability assignments for each slot, for all slots to the right of
|
||||||
|
* the current slot.
|
||||||
|
*
|
||||||
|
* TODO: The heuristic could be refined by using the partial assignment in
|
||||||
|
* search node to refine the max-probability assignment for the remaining slots.
|
||||||
|
* This would incur more computation but will lead to fewer expansions.
|
||||||
*/
|
*/
|
||||||
class GTSAM_EXPORT DiscreteSearch {
|
class GTSAM_EXPORT DiscreteSearch {
|
||||||
public:
|
public:
|
||||||
/// We structure the search as a set of slots, each with a factor and
|
/**
|
||||||
/// a set of variable assignments that need to be chosen. In addition, each
|
* We structure the search as a set of slots, each with a factor and
|
||||||
/// slot has a heuristic associated with it.
|
* a set of variable assignments that need to be chosen. In addition, each
|
||||||
|
* slot has a heuristic associated with it.
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
* The factors in the search problem (always parents before descendents!):
|
||||||
|
* [P(A), P(B|A), P(C|A,B)]
|
||||||
|
* The assignments for each factor.
|
||||||
|
* [[A0,A1], [B0,B1], [C0,C1,C2]]
|
||||||
|
* A lower bound on the cost-to-go after each slot, e.g.,
|
||||||
|
* [-log(max_B P(B|A)) -log(max_C P(C|A,B)), -log(max_C P(C|A,B)), 0.0]
|
||||||
|
* Note that these decrease as we move from right to left.
|
||||||
|
* We keep the global lower bound as lowerBound_. In the example, it is:
|
||||||
|
* -log(max_B P(B|A)) -log(max_C P(C|A,B)) -log(max_C P(C|A,B))
|
||||||
|
*/
|
||||||
struct Slot {
|
struct Slot {
|
||||||
/// The factors in the search problem,
|
|
||||||
/// e.g., [P(B|A),P(A)]
|
|
||||||
DiscreteFactor::shared_ptr factor;
|
DiscreteFactor::shared_ptr factor;
|
||||||
|
|
||||||
/// The assignments for each factor,
|
|
||||||
/// e.g., [[B0,B1] [A0,A1]]
|
|
||||||
std::vector<DiscreteValues> assignments;
|
std::vector<DiscreteValues> assignments;
|
||||||
|
|
||||||
/// A lower bound on the cost-to-go for each slot, e.g.,
|
|
||||||
/// [-log(max_B P(B|A)), -log(max_A P(A))]
|
|
||||||
double heuristic;
|
double heuristic;
|
||||||
|
|
||||||
friend std::ostream& operator<<(std::ostream& os, const Slot& slot) {
|
friend std::ostream& operator<<(std::ostream& os, const Slot& slot) {
|
||||||
|
|
@ -56,8 +72,10 @@ class GTSAM_EXPORT DiscreteSearch {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// A solution is then a set of assignments, covering all the slots.
|
/**
|
||||||
/// as well as an associated error = -log(probability)
|
* A solution is a set of assignments, covering all the slots.
|
||||||
|
* as well as an associated error = -log(probability)
|
||||||
|
*/
|
||||||
struct Solution {
|
struct Solution {
|
||||||
double error;
|
double error;
|
||||||
DiscreteValues assignment;
|
DiscreteValues assignment;
|
||||||
|
|
@ -89,28 +107,16 @@ class GTSAM_EXPORT DiscreteSearch {
|
||||||
const Ordering& ordering,
|
const Ordering& ordering,
|
||||||
bool buildJunctionTree = false);
|
bool buildJunctionTree = false);
|
||||||
|
|
||||||
/**
|
/// Construct from a DiscreteEliminationTree.
|
||||||
* @brief Constructor from a DiscreteEliminationTree.
|
|
||||||
*
|
|
||||||
* @param etree The DiscreteEliminationTree to initialize from.
|
|
||||||
*/
|
|
||||||
DiscreteSearch(const DiscreteEliminationTree& etree);
|
DiscreteSearch(const DiscreteEliminationTree& etree);
|
||||||
|
|
||||||
/**
|
/// Construct from a DiscreteJunctionTree.
|
||||||
* @brief Constructor from a DiscreteJunctionTree.
|
|
||||||
*
|
|
||||||
* @param junctionTree The DiscreteJunctionTree to initialize from.
|
|
||||||
*/
|
|
||||||
DiscreteSearch(const DiscreteJunctionTree& junctionTree);
|
DiscreteSearch(const DiscreteJunctionTree& junctionTree);
|
||||||
|
|
||||||
/**
|
//// Construct from a DiscreteBayesNet.
|
||||||
* Construct from a DiscreteBayesNet.
|
|
||||||
*/
|
|
||||||
DiscreteSearch(const DiscreteBayesNet& bayesNet);
|
DiscreteSearch(const DiscreteBayesNet& bayesNet);
|
||||||
|
|
||||||
/**
|
/// Construct from a DiscreteBayesTree.
|
||||||
* Construct from a DiscreteBayesTree.
|
|
||||||
*/
|
|
||||||
DiscreteSearch(const DiscreteBayesTree& bayesTree);
|
DiscreteSearch(const DiscreteBayesTree& bayesTree);
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
@ -146,8 +152,10 @@ class GTSAM_EXPORT DiscreteSearch {
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/// Compute the cumulative lower-bound cost-to-go after each slot is filled.
|
/**
|
||||||
/// @return the estimated lower bound of the cost for *all* slots.
|
* Compute the cumulative lower-bound cost-to-go after each slot is filled.
|
||||||
|
* @return the estimated lower bound of the cost for *all* slots.
|
||||||
|
*/
|
||||||
double computeHeuristic();
|
double computeHeuristic();
|
||||||
|
|
||||||
double lowerBound_; ///< Lower bound on the cost-to-go for the entire search.
|
double lowerBound_; ///< Lower bound on the cost-to-go for the entire search.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue