Made run const, take K

release/4.3a0
Frank Dellaert 2025-01-27 00:26:35 -05:00
parent b10ea06626
commit f174a38eed
3 changed files with 79 additions and 87 deletions

View File

@ -76,34 +76,84 @@ std::vector<Solution> Solutions::extractSolutions() {
return result; return result;
} }
DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K) DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) {
: solutions_(K) {
std::vector<DiscreteConditional::shared_ptr> conditionals; std::vector<DiscreteConditional::shared_ptr> conditionals;
for (auto& factor : bayesNet) conditionals.push_back(factor); for (auto& factor : bayesNet) conditionals_.push_back(factor);
initialize(conditionals); costToGo_ = computeCostToGo(conditionals_);
} }
DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K) DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) {
: solutions_(K) {
std::vector<DiscreteConditional::shared_ptr> conditionals;
std::function<void(const DiscreteBayesTree::sharedClique&)> std::function<void(const DiscreteBayesTree::sharedClique&)>
collectConditionals = [&](const auto& clique) { collectConditionals = [&](const auto& clique) {
if (!clique) return; if (!clique) return;
for (const auto& child : clique->children) collectConditionals(child); for (const auto& child : clique->children) collectConditionals(child);
conditionals.push_back(clique->conditional()); conditionals_.push_back(clique->conditional());
}; };
for (const auto& root : bayesTree.roots()) collectConditionals(root); for (const auto& root : bayesTree.roots()) collectConditionals(root);
initialize(conditionals); costToGo_ = computeCostToGo(conditionals_);
}; };
std::vector<Solution> DiscreteSearch::run() { using SearchNodeQueue = std::priority_queue<SearchNode, std::vector<SearchNode>,
while (!expansions_.empty()) { SearchNode::Compare>;
numExpansions++;
std::vector<Solution> DiscreteSearch::run(size_t K) const {
SearchNodeQueue expansions;
expansions.push(SearchNode::Root(conditionals_.size(),
costToGo_.empty() ? 0.0 : costToGo_.back()));
Solutions solutions(K);
auto expandNextNode = [&] {
// Pop the partial assignment with the smallest bound
SearchNode current = expansions.top();
expansions.pop();
// If we already have K solutions, prune if we cannot beat the worst
// one.
if (solutions.prune(current.bound)) {
return;
}
// Check if we have a complete assignment
if (current.isComplete()) {
solutions.maybeAdd(current.error, current.assignment);
return;
}
// Expand on the next factor
const auto& conditional = conditionals_[current.nextConditional];
for (auto& fa : conditional->frontalAssignments()) {
auto childNode = current.expand(*conditional, fa);
if (childNode.nextConditional >= 0)
childNode.bound =
childNode.error + costToGo_[childNode.nextConditional];
// Again, prune if we cannot beat the worst solution
if (!solutions.prune(childNode.bound)) {
expansions.emplace(childNode);
}
}
};
#ifdef DISCRETE_SEARCH_DEBUG
size_t numExpansions = 0;
#endif
// Perform the search
while (!expansions.empty()) {
expandNextNode(); expandNextNode();
#ifdef DISCRETE_SEARCH_DEBUG
++numExpansions;
#endif
} }
#ifdef DISCRETE_SEARCH_DEBUG
std::cout << "Number of expansions: " << numExpansions << std::endl;
#endif
// Extract solutions from bestSolutions in ascending order of error // Extract solutions from bestSolutions in ascending order of error
return solutions_.extractSolutions(); return solutions.extractSolutions();
} }
std::vector<double> DiscreteSearch::computeCostToGo( std::vector<double> DiscreteSearch::computeCostToGo(
@ -120,35 +170,4 @@ std::vector<double> DiscreteSearch::computeCostToGo(
return costToGo; return costToGo;
} }
void DiscreteSearch::expandNextNode() {
// Pop the partial assignment with the smallest bound
SearchNode current = expansions_.top();
expansions_.pop();
// If we already have K solutions, prune if we cannot beat the worst one.
if (solutions_.prune(current.bound)) {
return;
}
// Check if we have a complete assignment
if (current.isComplete()) {
solutions_.maybeAdd(current.error, current.assignment);
return;
}
// Expand on the next factor
const auto& conditional = conditionals_[current.nextConditional];
for (auto& fa : conditional->frontalAssignments()) {
auto childNode = current.expand(*conditional, fa);
if (childNode.nextConditional >= 0)
childNode.bound = childNode.error + costToGo_[childNode.nextConditional];
// Again, prune if we cannot beat the worst solution
if (!solutions_.prune(childNode.bound)) {
expansions_.emplace(childNode);
}
}
}
} // namespace gtsam } // namespace gtsam

View File

@ -130,17 +130,15 @@ class Solutions {
*/ */
class DiscreteSearch { class DiscreteSearch {
public: public:
size_t numExpansions = 0;
/** /**
* Construct from a DiscreteBayesNet and K. * Construct from a DiscreteBayesNet and K.
*/ */
DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K = 1); DiscreteSearch(const DiscreteBayesNet& bayesNet);
/** /**
* Construct from a DiscreteBayesTree and K. * Construct from a DiscreteBayesTree and K.
*/ */
DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K = 1); DiscreteSearch(const DiscreteBayesTree& bayesTree);
/** /**
* @brief Search for the K best solutions. * @brief Search for the K best solutions.
@ -152,29 +150,17 @@ class DiscreteSearch {
* *
* @return A vector of the K best solutions found during the search. * @return A vector of the K best solutions found during the search.
*/ */
std::vector<Solution> run(); std::vector<Solution> run(size_t K = 1) const;
private: private:
/// Initialize the search with the given conditionals.
void initialize(
const std::vector<DiscreteConditional::shared_ptr>& conditionals) {
conditionals_ = conditionals;
costToGo_ = computeCostToGo(conditionals_);
expansions_.push(SearchNode::Root(
conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back()));
}
/// Compute the cumulative cost-to-go for each conditional slot. /// Compute the cumulative cost-to-go for each conditional slot.
static std::vector<double> computeCostToGo( static std::vector<double> computeCostToGo(
const std::vector<DiscreteConditional::shared_ptr>& conditionals); const std::vector<DiscreteConditional::shared_ptr>& conditionals);
/// Expand the next node in the search tree. /// Expand the next node in the search tree.
void expandNextNode(); void expandNextNode() const;
std::vector<DiscreteConditional::shared_ptr> conditionals_; std::vector<DiscreteConditional::shared_ptr> conditionals_;
std::vector<double> costToGo_; std::vector<double> costToGo_;
std::priority_queue<SearchNode, std::vector<SearchNode>, SearchNode::Compare>
expansions_;
Solutions solutions_;
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -35,8 +35,8 @@ using namespace gtsam;
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DiscreteBayesNet, EmptyKBest) { TEST(DiscreteBayesNet, EmptyKBest) {
DiscreteBayesNet net; // no factors DiscreteBayesNet net; // no factors
DiscreteSearch search(net, 3); DiscreteSearch search(net);
auto solutions = search.run(); auto solutions = search.run(3);
// Expect one solution with empty assignment, error=0 // Expect one solution with empty assignment, error=0
EXPECT_LONGS_EQUAL(1, solutions.size()); EXPECT_LONGS_EQUAL(1, solutions.size());
EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9);
@ -46,23 +46,17 @@ TEST(DiscreteBayesNet, EmptyKBest) {
TEST(DiscreteBayesNet, AsiaKBest) { TEST(DiscreteBayesNet, AsiaKBest) {
using namespace asia_example; using namespace asia_example;
DiscreteBayesNet asia = createAsiaExample(); DiscreteBayesNet asia = createAsiaExample();
DiscreteSearch search(asia);
// Ask for the MPE // Ask for the MPE
DiscreteSearch search1(asia); auto mpe = search.run();
auto mpe = search1.run();
// print numExpansions
std::cout << "Number of expansions: " << search1.numExpansions << std::endl;
EXPECT_LONGS_EQUAL(1, mpe.size()); EXPECT_LONGS_EQUAL(1, mpe.size());
// Regression test: check the MPE solution // Regression test: check the MPE solution
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5);
DiscreteSearch search(asia, 4); // Ask for top 4 solutions
auto solutions = search.run(); auto solutions = search.run(4);
// print numExpansions
std::cout << "Number of expansions: " << search.numExpansions << std::endl;
EXPECT_LONGS_EQUAL(4, solutions.size()); EXPECT_LONGS_EQUAL(4, solutions.size());
// Regression test: check the first and last solution // Regression test: check the first and last solution
@ -74,8 +68,8 @@ TEST(DiscreteBayesNet, AsiaKBest) {
TEST(DiscreteBayesTree, EmptyTree) { TEST(DiscreteBayesTree, EmptyTree) {
DiscreteBayesTree bt; DiscreteBayesTree bt;
DiscreteSearch search(bt, 3); DiscreteSearch search(bt);
auto solutions = search.run(); auto solutions = search.run(3);
// We expect exactly 1 solution with error = 0.0 (the empty assignment). // We expect exactly 1 solution with error = 0.0 (the empty assignment).
assert(solutions.size() == 1 && "There should be exactly one empty solution"); assert(solutions.size() == 1 && "There should be exactly one empty solution");
@ -89,24 +83,17 @@ TEST(DiscreteBayesTree, AsiaTreeKBest) {
DiscreteFactorGraph asia(createAsiaExample()); DiscreteFactorGraph asia(createAsiaExample());
const Ordering ordering{D, X, B, E, L, T, S, A}; const Ordering ordering{D, X, B, E, L, T, S, A};
DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(ordering); DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(ordering);
DiscreteSearch search(*bt);
// Ask for top 4 solutions // Ask for MPE
DiscreteSearch search1(*bt); auto mpe = search.run();
auto mpe = search1.run();
// print numExpansions
std::cout << "Number of expansions: " << search1.numExpansions << std::endl;
EXPECT_LONGS_EQUAL(1, mpe.size()); EXPECT_LONGS_EQUAL(1, mpe.size());
// Regression test: check the MPE solution // Regression test: check the MPE solution
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5);
// Ask for top 4 solutions // Ask for top 4 solutions
DiscreteSearch search(*bt, 4); auto solutions = search.run(4);
auto solutions = search.run();
// print numExpansions
std::cout << "Number of expansions: " << search.numExpansions << std::endl;
EXPECT_LONGS_EQUAL(4, solutions.size()); EXPECT_LONGS_EQUAL(4, solutions.size());
// Regression test: check the first and last solution // Regression test: check the first and last solution