Clean up, MPE tests
parent
70089a0fd4
commit
b10ea06626
|
@ -31,8 +31,8 @@ SearchNode SearchNode::expand(const DiscreteConditional& conditional,
|
||||||
const DiscreteValues& fa) const {
|
const DiscreteValues& fa) const {
|
||||||
// Combine the new frontal assignment with the current partial assignment
|
// Combine the new frontal assignment with the current partial assignment
|
||||||
DiscreteValues newAssignment = assignment;
|
DiscreteValues newAssignment = assignment;
|
||||||
for (auto& kv : fa) {
|
for (auto& [key, value] : fa) {
|
||||||
newAssignment[kv.first] = kv.second;
|
newAssignment[key] = value;
|
||||||
}
|
}
|
||||||
|
|
||||||
return {.assignment = newAssignment,
|
return {.assignment = newAssignment,
|
||||||
|
@ -50,11 +50,10 @@ bool Solutions::maybeAdd(double error, const DiscreteValues& assignment) {
|
||||||
}
|
}
|
||||||
|
|
||||||
std::ostream& operator<<(std::ostream& os, const Solutions& sn) {
|
std::ostream& operator<<(std::ostream& os, const Solutions& sn) {
|
||||||
|
os << "Solutions (top " << sn.pq_.size() << "):\n";
|
||||||
auto pq = sn.pq_;
|
auto pq = sn.pq_;
|
||||||
while (!pq.empty()) {
|
while (!pq.empty()) {
|
||||||
const Solution& best = pq.top();
|
os << pq.top() << "\n";
|
||||||
os << "Error: " << best.error << ", Values: " << best.assignment
|
|
||||||
<< std::endl;
|
|
||||||
pq.pop();
|
pq.pop();
|
||||||
}
|
}
|
||||||
return os;
|
return os;
|
||||||
|
@ -62,8 +61,7 @@ std::ostream& operator<<(std::ostream& os, const Solutions& sn) {
|
||||||
|
|
||||||
bool Solutions::prune(double bound) const {
|
bool Solutions::prune(double bound) const {
|
||||||
if (pq_.size() < maxSize_) return false;
|
if (pq_.size() < maxSize_) return false;
|
||||||
double worstError = pq_.top().error;
|
return bound >= pq_.top().error;
|
||||||
return (bound >= worstError);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Solution> Solutions::extractSolutions() {
|
std::vector<Solution> Solutions::extractSolutions() {
|
||||||
|
@ -80,45 +78,23 @@ std::vector<Solution> Solutions::extractSolutions() {
|
||||||
|
|
||||||
DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K)
|
DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K)
|
||||||
: solutions_(K) {
|
: solutions_(K) {
|
||||||
// Copy out the conditionals
|
std::vector<DiscreteConditional::shared_ptr> conditionals;
|
||||||
for (auto& factor : bayesNet) {
|
for (auto& factor : bayesNet) conditionals.push_back(factor);
|
||||||
conditionals_.push_back(factor);
|
initialize(conditionals);
|
||||||
}
|
|
||||||
|
|
||||||
// Calculate the cost-to-go for each conditional
|
|
||||||
costToGo_ = computeCostToGo(conditionals_);
|
|
||||||
|
|
||||||
// Create the root node and push it to the expansions queue
|
|
||||||
expansions_.push(SearchNode::Root(
|
|
||||||
conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back()));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K)
|
DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K)
|
||||||
: solutions_(K) {
|
: solutions_(K) {
|
||||||
using CliquePtr = DiscreteBayesTree::sharedClique;
|
std::vector<DiscreteConditional::shared_ptr> conditionals;
|
||||||
std::function<void(const CliquePtr&)> collectConditionals =
|
std::function<void(const DiscreteBayesTree::sharedClique&)>
|
||||||
[&](const CliquePtr& clique) -> void {
|
collectConditionals = [&](const auto& clique) {
|
||||||
if (!clique) return;
|
if (!clique) return;
|
||||||
|
for (const auto& child : clique->children) collectConditionals(child);
|
||||||
// Recursive post-order traversal: process children first
|
conditionals.push_back(clique->conditional());
|
||||||
for (const auto& child : clique->children) {
|
};
|
||||||
collectConditionals(child);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Then add the current clique's conditional
|
|
||||||
conditionals_.push_back(clique->conditional());
|
|
||||||
};
|
|
||||||
|
|
||||||
// Start traversal from each root in the tree
|
|
||||||
for (const auto& root : bayesTree.roots()) collectConditionals(root);
|
for (const auto& root : bayesTree.roots()) collectConditionals(root);
|
||||||
|
initialize(conditionals);
|
||||||
// Calculate the cost-to-go for each conditional
|
};
|
||||||
costToGo_ = computeCostToGo(conditionals_);
|
|
||||||
|
|
||||||
// Create the root node and push it to the expansions queue
|
|
||||||
expansions_.push(SearchNode::Root(
|
|
||||||
conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back()));
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<Solution> DiscreteSearch::run() {
|
std::vector<Solution> DiscreteSearch::run() {
|
||||||
while (!expansions_.empty()) {
|
while (!expansions_.empty()) {
|
||||||
|
@ -170,7 +146,7 @@ void DiscreteSearch::expandNextNode() {
|
||||||
|
|
||||||
// Again, prune if we cannot beat the worst solution
|
// Again, prune if we cannot beat the worst solution
|
||||||
if (!solutions_.prune(childNode.bound)) {
|
if (!solutions_.prune(childNode.bound)) {
|
||||||
expansions_.push(childNode);
|
expansions_.emplace(childNode);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,8 @@
|
||||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
#include <gtsam/discrete/DiscreteBayesTree.h>
|
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||||
|
|
||||||
|
#include <queue>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
using Value = size_t;
|
using Value = size_t;
|
||||||
|
@ -52,7 +54,7 @@ struct SearchNode {
|
||||||
*
|
*
|
||||||
* @return True if all variables have been assigned, false otherwise.
|
* @return True if all variables have been assigned, false otherwise.
|
||||||
*/
|
*/
|
||||||
bool isComplete() const { return nextConditional < 0; }
|
inline bool isComplete() const { return nextConditional < 0; }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Expands the node by assigning the next variable.
|
* @brief Expands the node by assigning the next variable.
|
||||||
|
@ -133,12 +135,12 @@ class DiscreteSearch {
|
||||||
/**
|
/**
|
||||||
* Construct from a DiscreteBayesNet and K.
|
* Construct from a DiscreteBayesNet and K.
|
||||||
*/
|
*/
|
||||||
DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K);
|
DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K = 1);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Construct from a DiscreteBayesTree and K.
|
* Construct from a DiscreteBayesTree and K.
|
||||||
*/
|
*/
|
||||||
DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K);
|
DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K = 1);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Search for the K best solutions.
|
* @brief Search for the K best solutions.
|
||||||
|
@ -153,18 +155,20 @@ class DiscreteSearch {
|
||||||
std::vector<Solution> run();
|
std::vector<Solution> run();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/**
|
/// Initialize the search with the given conditionals.
|
||||||
* @brief Compute the cost-to-go for each conditional.
|
void initialize(
|
||||||
*
|
const std::vector<DiscreteConditional::shared_ptr>& conditionals) {
|
||||||
* @param conditionals The conditionals of the DiscreteBayesNet.
|
conditionals_ = conditionals;
|
||||||
* @return A vector of cost-to-go values.
|
costToGo_ = computeCostToGo(conditionals_);
|
||||||
*/
|
expansions_.push(SearchNode::Root(
|
||||||
|
conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back()));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Compute the cumulative cost-to-go for each conditional slot.
|
||||||
static std::vector<double> computeCostToGo(
|
static std::vector<double> computeCostToGo(
|
||||||
const std::vector<DiscreteConditional::shared_ptr>& conditionals);
|
const std::vector<DiscreteConditional::shared_ptr>& conditionals);
|
||||||
|
|
||||||
/**
|
/// Expand the next node in the search tree.
|
||||||
* @brief Expand the next node in the search tree.
|
|
||||||
*/
|
|
||||||
void expandNextNode();
|
void expandNextNode();
|
||||||
|
|
||||||
std::vector<DiscreteConditional::shared_ptr> conditionals_;
|
std::vector<DiscreteConditional::shared_ptr> conditionals_;
|
||||||
|
|
|
@ -46,20 +46,32 @@ TEST(DiscreteBayesNet, EmptyKBest) {
|
||||||
TEST(DiscreteBayesNet, AsiaKBest) {
|
TEST(DiscreteBayesNet, AsiaKBest) {
|
||||||
using namespace asia_example;
|
using namespace asia_example;
|
||||||
DiscreteBayesNet asia = createAsiaExample();
|
DiscreteBayesNet asia = createAsiaExample();
|
||||||
|
|
||||||
|
// Ask for the MPE
|
||||||
|
DiscreteSearch search1(asia);
|
||||||
|
auto mpe = search1.run();
|
||||||
|
|
||||||
|
// print numExpansions
|
||||||
|
std::cout << "Number of expansions: " << search1.numExpansions << std::endl;
|
||||||
|
|
||||||
|
EXPECT_LONGS_EQUAL(1, mpe.size());
|
||||||
|
// Regression test: check the MPE solution
|
||||||
|
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5);
|
||||||
|
|
||||||
DiscreteSearch search(asia, 4);
|
DiscreteSearch search(asia, 4);
|
||||||
auto solutions = search.run();
|
auto solutions = search.run();
|
||||||
|
|
||||||
// print numExpansions
|
// print numExpansions
|
||||||
std::cout << "Number of expansions: " << search.numExpansions << std::endl;
|
std::cout << "Number of expansions: " << search.numExpansions << std::endl;
|
||||||
|
|
||||||
EXPECT(!solutions.empty());
|
EXPECT_LONGS_EQUAL(4, solutions.size());
|
||||||
// Regression test: check the first and last solution
|
// Regression test: check the first and last solution
|
||||||
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);
|
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);
|
||||||
EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5);
|
EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscreteBayesTree, testEmptyTree) {
|
TEST(DiscreteBayesTree, EmptyTree) {
|
||||||
DiscreteBayesTree bt;
|
DiscreteBayesTree bt;
|
||||||
|
|
||||||
DiscreteSearch search(bt, 3);
|
DiscreteSearch search(bt, 3);
|
||||||
|
@ -72,12 +84,23 @@ TEST(DiscreteBayesTree, testEmptyTree) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscreteBayesTree, testTrivialOneClique) {
|
TEST(DiscreteBayesTree, AsiaTreeKBest) {
|
||||||
using namespace asia_example;
|
using namespace asia_example;
|
||||||
DiscreteFactorGraph asia(createAsiaExample());
|
DiscreteFactorGraph asia(createAsiaExample());
|
||||||
const Ordering ordering{D, X, B, E, L, T, S, A};
|
const Ordering ordering{D, X, B, E, L, T, S, A};
|
||||||
DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(ordering);
|
DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(ordering);
|
||||||
|
|
||||||
|
// Ask for top 4 solutions
|
||||||
|
DiscreteSearch search1(*bt);
|
||||||
|
auto mpe = search1.run();
|
||||||
|
|
||||||
|
// print numExpansions
|
||||||
|
std::cout << "Number of expansions: " << search1.numExpansions << std::endl;
|
||||||
|
|
||||||
|
EXPECT_LONGS_EQUAL(1, mpe.size());
|
||||||
|
// Regression test: check the MPE solution
|
||||||
|
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5);
|
||||||
|
|
||||||
// Ask for top 4 solutions
|
// Ask for top 4 solutions
|
||||||
DiscreteSearch search(*bt, 4);
|
DiscreteSearch search(*bt, 4);
|
||||||
auto solutions = search.run();
|
auto solutions = search.run();
|
||||||
|
@ -85,7 +108,7 @@ TEST(DiscreteBayesTree, testTrivialOneClique) {
|
||||||
// print numExpansions
|
// print numExpansions
|
||||||
std::cout << "Number of expansions: " << search.numExpansions << std::endl;
|
std::cout << "Number of expansions: " << search.numExpansions << std::endl;
|
||||||
|
|
||||||
EXPECT(!solutions.empty());
|
EXPECT_LONGS_EQUAL(4, solutions.size());
|
||||||
// Regression test: check the first and last solution
|
// Regression test: check the first and last solution
|
||||||
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);
|
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);
|
||||||
EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5);
|
EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5);
|
||||||
|
|
Loading…
Reference in New Issue