Refactor to slots

release/4.3a0
Frank Dellaert 2025-01-27 14:01:20 -05:00
parent 35e7acbf16
commit d8ed60aead
2 changed files with 77 additions and 45 deletions

View File

@ -20,6 +20,7 @@
namespace gtsam { namespace gtsam {
using Slot = DiscreteSearch::Slot;
using Solution = DiscreteSearch::Solution; using Solution = DiscreteSearch::Solution;
/** /**
@ -59,12 +60,12 @@ struct SearchNode {
/** /**
* @brief Expands the node by assigning the next variable. * @brief Expands the node by assigning the next variable.
* *
* @param conditional The discrete conditional representing the next variable * @param factor The discrete factor associated with the next variable
* to be assigned. * to be assigned.
* @param fa The frontal assignment for the next variable. * @param fa The frontal assignment for the next variable.
* @return A new SearchNode representing the expanded state. * @return A new SearchNode representing the expanded state.
*/ */
SearchNode expand(const DiscreteConditional& conditional, SearchNode expand(const DiscreteFactor& factor,
const DiscreteValues& fa) const { const DiscreteValues& fa) const {
// Combine the new frontal assignment with the current partial assignment // Combine the new frontal assignment with the current partial assignment
DiscreteValues newAssignment = assignment; DiscreteValues newAssignment = assignment;
@ -72,7 +73,7 @@ struct SearchNode {
newAssignment[key] = value; newAssignment[key] = value;
} }
return {newAssignment, error + conditional.error(newAssignment), 0.0, return {newAssignment, error + factor.error(newAssignment), 0.0,
nextConditional - 1}; nextConditional - 1};
} }
@ -150,10 +151,20 @@ class Solutions {
} }
}; };
DiscreteSearch::DiscreteSearch(const DiscreteFactorGraph& factorGraph) {
slots_.reserve(factorGraph.size());
for (auto& factor : factorGraph) {
slots_.emplace_back(factor, std::vector<DiscreteValues>{}, 0.0);
}
computeHeuristic();
}
DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) { DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) {
std::vector<DiscreteConditional::shared_ptr> conditionals; slots_.reserve(bayesNet.size());
for (auto& factor : bayesNet) conditionals_.push_back(factor); for (auto& conditional : bayesNet) {
costToGo_ = computeCostToGo(conditionals_); slots_.emplace_back(conditional, conditional->frontalAssignments(), 0.0);
}
computeHeuristic();
} }
DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) {
@ -161,22 +172,21 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) {
collectConditionals = [&](const auto& clique) { collectConditionals = [&](const auto& clique) {
if (!clique) return; if (!clique) return;
for (const auto& child : clique->children) collectConditionals(child); for (const auto& child : clique->children) collectConditionals(child);
conditionals_.push_back(clique->conditional()); auto conditional = clique->conditional();
slots_.emplace_back(conditional, conditional->frontalAssignments(),
0.0);
}; };
slots_.reserve(bayesTree.size());
for (const auto& root : bayesTree.roots()) collectConditionals(root); for (const auto& root : bayesTree.roots()) collectConditionals(root);
costToGo_ = computeCostToGo(conditionals_); computeHeuristic();
} }
struct SearchNodeQueue struct SearchNodeQueue
: public std::priority_queue<SearchNode, std::vector<SearchNode>, : public std::priority_queue<SearchNode, std::vector<SearchNode>,
SearchNode::Compare> { SearchNode::Compare> {
void expandNextNode( void expandNextNode(const SearchNode& current, const Slot& slot,
const std::vector<DiscreteConditional::shared_ptr>& conditionals, Solutions* solutions) {
const std::vector<double>& costToGo, Solutions* solutions) {
// Pop the partial assignment with the smallest bound
SearchNode current = top();
pop();
// If we already have K solutions, prune if we cannot beat the worst one. // If we already have K solutions, prune if we cannot beat the worst one.
if (solutions->prune(current.bound)) { if (solutions->prune(current.bound)) {
return; return;
@ -188,13 +198,11 @@ struct SearchNodeQueue
return; return;
} }
// Expand on the next factor for (auto& fa : slot.assignments) {
const auto& conditional = conditionals[current.nextConditional]; auto childNode = current.expand(*slot.factor, fa);
for (auto& fa : conditional->frontalAssignments()) {
auto childNode = current.expand(*conditional, fa);
if (childNode.nextConditional >= 0) if (childNode.nextConditional >= 0)
childNode.bound = childNode.error + costToGo[childNode.nextConditional]; // TODO(frank): this might be wrong !
childNode.bound = childNode.error + slot.heuristic;
// Again, prune if we cannot beat the worst solution // Again, prune if we cannot beat the worst solution
if (!solutions->prune(childNode.bound)) { if (!solutions->prune(childNode.bound)) {
@ -207,8 +215,7 @@ struct SearchNodeQueue
std::vector<Solution> DiscreteSearch::run(size_t K) const { std::vector<Solution> DiscreteSearch::run(size_t K) const {
Solutions solutions(K); Solutions solutions(K);
SearchNodeQueue expansions; SearchNodeQueue expansions;
expansions.push(SearchNode::Root(conditionals_.size(), expansions.push(SearchNode::Root(slots_.size(), slots_.back().heuristic));
costToGo_.empty() ? 0.0 : costToGo_.back()));
#ifdef DISCRETE_SEARCH_DEBUG #ifdef DISCRETE_SEARCH_DEBUG
size_t numExpansions = 0; size_t numExpansions = 0;
@ -216,7 +223,13 @@ std::vector<Solution> DiscreteSearch::run(size_t K) const {
// Perform the search // Perform the search
while (!expansions.empty()) { while (!expansions.empty()) {
expansions.expandNextNode(conditionals_, costToGo_, &solutions); // Pop the partial assignment with the smallest bound
SearchNode current = expansions.top();
expansions.pop();
// Get the next slot to expand
const auto& slot = slots_[current.nextConditional];
expansions.expandNextNode(current, slot, &solutions);
#ifdef DISCRETE_SEARCH_DEBUG #ifdef DISCRETE_SEARCH_DEBUG
++numExpansions; ++numExpansions;
#endif #endif
@ -230,17 +243,19 @@ std::vector<Solution> DiscreteSearch::run(size_t K) const {
return solutions.extractSolutions(); return solutions.extractSolutions();
} }
std::vector<double> DiscreteSearch::computeCostToGo( // We have a number of factors, each with a max value, and we want to compute
const std::vector<DiscreteConditional::shared_ptr>& conditionals) { // the a lower-bound on the cost-to-go for each slot. For the first slot, this
std::vector<double> costToGo; // -log(max(factor[0])), as we only have one factor to resolve. For the second
// slot, we need to add -log(max(factor[1])) to it, etc...
void DiscreteSearch::computeHeuristic() {
double error = 0.0; double error = 0.0;
for (const auto& conditional : conditionals) { for (size_t i = 0; i < slots_.size(); ++i) {
Ordering ordering(conditional->begin(), conditional->end()); const auto& factor = slots_[i].factor;
auto maxx = conditional->max(ordering); Ordering ordering(factor->begin(), factor->end());
auto maxx = factor->max(ordering);
error -= std::log(maxx->evaluate({})); error -= std::log(maxx->evaluate({}));
costToGo.push_back(error); slots_[i].heuristic = error;
} }
return costToGo;
} }
} // namespace gtsam } // namespace gtsam

View File

@ -28,9 +28,25 @@ namespace gtsam {
*/ */
class GTSAM_EXPORT DiscreteSearch { class GTSAM_EXPORT DiscreteSearch {
public: public:
/** /// We structure the search as a set of slots, each with a factor and
* @brief A solution to a discrete search problem. /// a set of variable assignments that need to be chosen. In addition, each
*/ /// slot has a heuristic associated with it.
struct Slot {
/// The factors in the search problem,
/// e.g., [P(B|A),P(A)]
DiscreteFactor::shared_ptr factor;
/// The assignments for each factor,
/// e.g., [[B0,B1] [A0,A1]]
std::vector<DiscreteValues> assignments;
/// A lower bound on the cost-to-go for each slot, e.g.,
/// [-log(max_B P(B|A)), -log(max_A P(A))]
double heuristic;
};
/// A solution is then a set of assignments, covering all the slots.
/// as well as an associated error = -log(probability)
struct Solution { struct Solution {
double error; double error;
DiscreteValues assignment; DiscreteValues assignment;
@ -42,13 +58,19 @@ class GTSAM_EXPORT DiscreteSearch {
} }
}; };
public:
/** /**
* Construct from a DiscreteBayesNet and K. * Construct from a DiscreteFactorGraph.
*/
DiscreteSearch(const DiscreteFactorGraph& bayesNet);
/**
* Construct from a DiscreteBayesNet.
*/ */
DiscreteSearch(const DiscreteBayesNet& bayesNet); DiscreteSearch(const DiscreteBayesNet& bayesNet);
/** /**
* Construct from a DiscreteBayesTree and K. * Construct from a DiscreteBayesTree.
*/ */
DiscreteSearch(const DiscreteBayesTree& bayesTree); DiscreteSearch(const DiscreteBayesTree& bayesTree);
@ -65,14 +87,9 @@ class GTSAM_EXPORT DiscreteSearch {
std::vector<Solution> run(size_t K = 1) const; std::vector<Solution> run(size_t K = 1) const;
private: private:
/// Compute the cumulative cost-to-go for each conditional slot. /// Compute the cumulative lower-bound cost-to-go for each slot.
static std::vector<double> computeCostToGo( void computeHeuristic();
const std::vector<DiscreteConditional::shared_ptr>& conditionals);
/// Expand the next node in the search tree. std::vector<Slot> slots_;
void expandNextNode() const;
std::vector<DiscreteConditional::shared_ptr> conditionals_;
std::vector<double> costToGo_;
}; };
} // namespace gtsam } // namespace gtsam