Clean up, MPE tests

release/4.3a0
Frank Dellaert 2025-01-27 00:07:22 -05:00
parent 70089a0fd4
commit b10ea06626
3 changed files with 61 additions and 58 deletions

View File

@ -31,8 +31,8 @@ SearchNode SearchNode::expand(const DiscreteConditional& conditional,
const DiscreteValues& fa) const { const DiscreteValues& fa) const {
// Combine the new frontal assignment with the current partial assignment // Combine the new frontal assignment with the current partial assignment
DiscreteValues newAssignment = assignment; DiscreteValues newAssignment = assignment;
for (auto& kv : fa) { for (auto& [key, value] : fa) {
newAssignment[kv.first] = kv.second; newAssignment[key] = value;
} }
return {.assignment = newAssignment, return {.assignment = newAssignment,
@ -50,11 +50,10 @@ bool Solutions::maybeAdd(double error, const DiscreteValues& assignment) {
} }
std::ostream& operator<<(std::ostream& os, const Solutions& sn) { std::ostream& operator<<(std::ostream& os, const Solutions& sn) {
os << "Solutions (top " << sn.pq_.size() << "):\n";
auto pq = sn.pq_; auto pq = sn.pq_;
while (!pq.empty()) { while (!pq.empty()) {
const Solution& best = pq.top(); os << pq.top() << "\n";
os << "Error: " << best.error << ", Values: " << best.assignment
<< std::endl;
pq.pop(); pq.pop();
} }
return os; return os;
@ -62,8 +61,7 @@ std::ostream& operator<<(std::ostream& os, const Solutions& sn) {
bool Solutions::prune(double bound) const { bool Solutions::prune(double bound) const {
if (pq_.size() < maxSize_) return false; if (pq_.size() < maxSize_) return false;
double worstError = pq_.top().error; return bound >= pq_.top().error;
return (bound >= worstError);
} }
std::vector<Solution> Solutions::extractSolutions() { std::vector<Solution> Solutions::extractSolutions() {
@ -80,45 +78,23 @@ std::vector<Solution> Solutions::extractSolutions() {
DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K) DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K)
: solutions_(K) { : solutions_(K) {
// Copy out the conditionals std::vector<DiscreteConditional::shared_ptr> conditionals;
for (auto& factor : bayesNet) { for (auto& factor : bayesNet) conditionals.push_back(factor);
conditionals_.push_back(factor); initialize(conditionals);
}
// Calculate the cost-to-go for each conditional
costToGo_ = computeCostToGo(conditionals_);
// Create the root node and push it to the expansions queue
expansions_.push(SearchNode::Root(
conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back()));
} }
DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K) DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K)
: solutions_(K) { : solutions_(K) {
using CliquePtr = DiscreteBayesTree::sharedClique; std::vector<DiscreteConditional::shared_ptr> conditionals;
std::function<void(const CliquePtr&)> collectConditionals = std::function<void(const DiscreteBayesTree::sharedClique&)>
[&](const CliquePtr& clique) -> void { collectConditionals = [&](const auto& clique) {
if (!clique) return; if (!clique) return;
for (const auto& child : clique->children) collectConditionals(child);
// Recursive post-order traversal: process children first conditionals.push_back(clique->conditional());
for (const auto& child : clique->children) { };
collectConditionals(child);
}
// Then add the current clique's conditional
conditionals_.push_back(clique->conditional());
};
// Start traversal from each root in the tree
for (const auto& root : bayesTree.roots()) collectConditionals(root); for (const auto& root : bayesTree.roots()) collectConditionals(root);
initialize(conditionals);
// Calculate the cost-to-go for each conditional };
costToGo_ = computeCostToGo(conditionals_);
// Create the root node and push it to the expansions queue
expansions_.push(SearchNode::Root(
conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back()));
}
std::vector<Solution> DiscreteSearch::run() { std::vector<Solution> DiscreteSearch::run() {
while (!expansions_.empty()) { while (!expansions_.empty()) {
@ -170,7 +146,7 @@ void DiscreteSearch::expandNextNode() {
// Again, prune if we cannot beat the worst solution // Again, prune if we cannot beat the worst solution
if (!solutions_.prune(childNode.bound)) { if (!solutions_.prune(childNode.bound)) {
expansions_.push(childNode); expansions_.emplace(childNode);
} }
} }
} }

View File

@ -19,6 +19,8 @@
#include <gtsam/discrete/DiscreteBayesNet.h> #include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteBayesTree.h> #include <gtsam/discrete/DiscreteBayesTree.h>
#include <queue>
namespace gtsam { namespace gtsam {
using Value = size_t; using Value = size_t;
@ -52,7 +54,7 @@ struct SearchNode {
* *
* @return True if all variables have been assigned, false otherwise. * @return True if all variables have been assigned, false otherwise.
*/ */
bool isComplete() const { return nextConditional < 0; } inline bool isComplete() const { return nextConditional < 0; }
/** /**
* @brief Expands the node by assigning the next variable. * @brief Expands the node by assigning the next variable.
@ -133,12 +135,12 @@ class DiscreteSearch {
/** /**
* Construct from a DiscreteBayesNet and K. * Construct from a DiscreteBayesNet and K.
*/ */
DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K); DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K = 1);
/** /**
* Construct from a DiscreteBayesTree and K. * Construct from a DiscreteBayesTree and K.
*/ */
DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K); DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K = 1);
/** /**
* @brief Search for the K best solutions. * @brief Search for the K best solutions.
@ -153,18 +155,20 @@ class DiscreteSearch {
std::vector<Solution> run(); std::vector<Solution> run();
private: private:
/** /// Initialize the search with the given conditionals.
* @brief Compute the cost-to-go for each conditional. void initialize(
* const std::vector<DiscreteConditional::shared_ptr>& conditionals) {
* @param conditionals The conditionals of the DiscreteBayesNet. conditionals_ = conditionals;
* @return A vector of cost-to-go values. costToGo_ = computeCostToGo(conditionals_);
*/ expansions_.push(SearchNode::Root(
conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back()));
}
/// Compute the cumulative cost-to-go for each conditional slot.
static std::vector<double> computeCostToGo( static std::vector<double> computeCostToGo(
const std::vector<DiscreteConditional::shared_ptr>& conditionals); const std::vector<DiscreteConditional::shared_ptr>& conditionals);
/** /// Expand the next node in the search tree.
* @brief Expand the next node in the search tree.
*/
void expandNextNode(); void expandNextNode();
std::vector<DiscreteConditional::shared_ptr> conditionals_; std::vector<DiscreteConditional::shared_ptr> conditionals_;

View File

@ -46,20 +46,32 @@ TEST(DiscreteBayesNet, EmptyKBest) {
TEST(DiscreteBayesNet, AsiaKBest) { TEST(DiscreteBayesNet, AsiaKBest) {
using namespace asia_example; using namespace asia_example;
DiscreteBayesNet asia = createAsiaExample(); DiscreteBayesNet asia = createAsiaExample();
// Ask for the MPE
DiscreteSearch search1(asia);
auto mpe = search1.run();
// print numExpansions
std::cout << "Number of expansions: " << search1.numExpansions << std::endl;
EXPECT_LONGS_EQUAL(1, mpe.size());
// Regression test: check the MPE solution
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5);
DiscreteSearch search(asia, 4); DiscreteSearch search(asia, 4);
auto solutions = search.run(); auto solutions = search.run();
// print numExpansions // print numExpansions
std::cout << "Number of expansions: " << search.numExpansions << std::endl; std::cout << "Number of expansions: " << search.numExpansions << std::endl;
EXPECT(!solutions.empty()); EXPECT_LONGS_EQUAL(4, solutions.size());
// Regression test: check the first and last solution // Regression test: check the first and last solution
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);
EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5); EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5);
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DiscreteBayesTree, testEmptyTree) { TEST(DiscreteBayesTree, EmptyTree) {
DiscreteBayesTree bt; DiscreteBayesTree bt;
DiscreteSearch search(bt, 3); DiscreteSearch search(bt, 3);
@ -72,12 +84,23 @@ TEST(DiscreteBayesTree, testEmptyTree) {
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DiscreteBayesTree, testTrivialOneClique) { TEST(DiscreteBayesTree, AsiaTreeKBest) {
using namespace asia_example; using namespace asia_example;
DiscreteFactorGraph asia(createAsiaExample()); DiscreteFactorGraph asia(createAsiaExample());
const Ordering ordering{D, X, B, E, L, T, S, A}; const Ordering ordering{D, X, B, E, L, T, S, A};
DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(ordering); DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(ordering);
// Ask for top 4 solutions
DiscreteSearch search1(*bt);
auto mpe = search1.run();
// print numExpansions
std::cout << "Number of expansions: " << search1.numExpansions << std::endl;
EXPECT_LONGS_EQUAL(1, mpe.size());
// Regression test: check the MPE solution
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5);
// Ask for top 4 solutions // Ask for top 4 solutions
DiscreteSearch search(*bt, 4); DiscreteSearch search(*bt, 4);
auto solutions = search.run(); auto solutions = search.run();
@ -85,7 +108,7 @@ TEST(DiscreteBayesTree, testTrivialOneClique) {
// print numExpansions // print numExpansions
std::cout << "Number of expansions: " << search.numExpansions << std::endl; std::cout << "Number of expansions: " << search.numExpansions << std::endl;
EXPECT(!solutions.empty()); EXPECT_LONGS_EQUAL(4, solutions.size());
// Regression test: check the first and last solution // Regression test: check the first and last solution
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);
EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5); EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5);