Works for Bayes trees now as well

release/4.3a0
Frank Dellaert 2025-01-26 23:24:09 -05:00
parent 54f493358d
commit 14eeaf93db
3 changed files with 84 additions and 54 deletions

View File

@@ -36,6 +36,16 @@ struct SearchNode {
   double bound;         ///< Lower bound on the final error for unassigned variables.
   int nextConditional;  ///< Index of the next conditional to be assigned.
+  /**
+   * @brief Construct the root node for the search.
+   */
+  static SearchNode Root(size_t numConditionals, double bound) {
+    return {.assignment = DiscreteValues(),
+            .error = 0.0,
+            .bound = bound,
+            .nextConditional = static_cast<int>(numConditionals) - 1};
+  }
   struct CompareByBound {
     bool operator()(const SearchNode& a, const SearchNode& b) const {
       return a.bound > b.bound;  // smallest bound -> highest priority
@@ -49,20 +59,6 @@
    */
   bool isComplete() const { return nextConditional < 0; }
-  /**
-   * @brief Computes a lower bound on the final error for unassigned variables.
-   *
-   * @details This is a stub implementation that returns 0. Real implementations
-   * might perform partial factor analysis or use heuristics to compute the
-   * bound.
-   *
-   * @return A lower bound on the final error.
-   */
-  double computeBound() const {
-    // Real code might do partial factor analysis or heuristics.
-    return 0.0;
-  }
   /**
   * @brief Expands the node by assigning the next variable.
   *
@@ -74,22 +70,15 @@
   SearchNode expand(const DiscreteConditional& conditional,
                     const DiscreteValues& fa) const {
     // Combine the new frontal assignment with the current partial assignment
-    SearchNode child;
-    child.assignment = assignment;
+    DiscreteValues newAssignment = assignment;
     for (auto& kv : fa) {
-      child.assignment[kv.first] = kv.second;
+      newAssignment[kv.first] = kv.second;
     }
-    // Compute the incremental error for this factor
-    child.error = error + conditional.error(child.assignment);
-    // Compute new bound
-    child.bound = computeBound();
-    // Update the index of the next conditional
-    child.nextConditional = nextConditional - 1;
-    return child;
+    return {.assignment = newAssignment,
+            .error = error + conditional.error(newAssignment),
+            .bound = 0.0,
+            .nextConditional = nextConditional - 1};
   }
   /**
@@ -184,6 +173,8 @@ class Solutions {
  */
 class DiscreteSearch {
  public:
+  size_t numExpansions = 0;
+
   /**
   * Construct from a DiscreteBayesNet and K.
   */
@@ -193,33 +184,42 @@ class DiscreteSearch {
       conditionals_.push_back(factor);
     }
-    // Calculate the cost-to-go for each conditional. If there are n
-    // conditionals, we start with nextConditional = n-1, and the minimum error
-    // obtainable is the sum of all the minimum errors. We start with
-    // 0, and that is the minimum error of the conditional with that index:
-    double error = 0.0;
-    for (const auto& conditional : conditionals_) {
-      Ordering ordering(conditional->begin(), conditional->end());
-      auto maxx = conditional->max(ordering);
-      assert(maxx->size() == 1);
-      error -= std::log(maxx->evaluate({}));
-      costToGo_.push_back(error);
-    }
+    // Calculate the cost-to-go for each conditional
+    costToGo_ = computeCostToGo(conditionals_);
-    // Create the root node: no variables assigned, nextConditional = last.
-    SearchNode root{
-        .assignment = DiscreteValues(),
-        .error = 0.0,
-        .nextConditional = static_cast<int>(conditionals_.size()) - 1};
-    if (!costToGo_.empty()) root.bound = costToGo_.back();
-    expansions_.push(root);
+    // Create the root node and push it to the expansions queue
+    expansions_.push(SearchNode::Root(
+        conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back()));
   }
   /**
-   * Construct from a DiscreteBayesNet and K.
+   * Construct from a DiscreteBayesTree and K.
   */
-  DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K)
-      : solutions_(K) {}
+  DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K) : solutions_(K) {
+    using CliquePtr = DiscreteBayesTree::sharedClique;
+    std::function<void(const CliquePtr&)> collectConditionals =
+        [&](const CliquePtr& clique) -> void {
+      if (!clique) return;
+      // Recursive post-order traversal: process children first
+      for (const auto& child : clique->children) {
+        collectConditionals(child);
+      }
+      // Then add the current clique's conditional
+      conditionals_.push_back(clique->conditional());
+    };
+    // Start traversal from each root in the tree
+    for (const auto& root : bayesTree.roots()) collectConditionals(root);
+    // Calculate the cost-to-go for each conditional
+    costToGo_ = computeCostToGo(conditionals_);
+    // Create the root node and push it to the expansions queue
+    expansions_.push(SearchNode::Root(
+        conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back()));
+  }
   /**
   * @brief Search for the K best solutions.
@@ -233,6 +233,7 @@
   */
  std::vector<Solution> run() {
    while (!expansions_.empty()) {
+      numExpansions++;
      expandNextNode();
    }
@@ -241,6 +242,29 @@
  }
 private:
+  /**
+   * @brief Compute the cost-to-go for each conditional.
+   *
+   * @param conditionals The conditionals of the DiscreteBayesNet.
+   * @return A vector of cost-to-go values.
+   */
+  static std::vector<double> computeCostToGo(
+      const std::vector<DiscreteConditional::shared_ptr>& conditionals) {
+    std::vector<double> costToGo;
+    double error = 0.0;
+    for (const auto& conditional : conditionals) {
+      Ordering ordering(conditional->begin(), conditional->end());
+      auto maxx = conditional->max(ordering);
+      assert(maxx->size() == 1);
+      error -= std::log(maxx->evaluate({}));
+      costToGo.push_back(error);
+    }
+    return costToGo;
+  }
+
+  /**
+   * @brief Expand the next node in the search tree.
+   */
  void expandNextNode() {
    // Pop the partial assignment with the smallest bound
    SearchNode current = expansions_.top();
@@ -273,7 +297,7 @@
    }
  }
-  std::vector<std::shared_ptr<DiscreteConditional>> conditionals_;
+  std::vector<DiscreteConditional::shared_ptr> conditionals_;
  std::vector<double> costToGo_;
  std::priority_queue<SearchNode, std::vector<SearchNode>,
                      SearchNode::CompareByBound>
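For orientation, here is a minimal usage sketch of the class as it stands after this change: eliminate a discrete factor graph into a Bayes tree, hand it to `DiscreteSearch`, and read back the K best solutions. The `main()` wrapper, the header paths, and the availability of `asia_example::createAsiaExample()` (borrowed from the test file below) are assumptions for illustration, not part of the diff.

```cpp
// Sketch only: assumes DiscreteSearch is visible through the includes below
// and that asia_example::createAsiaExample() from the tests is available.
#include <gtsam/discrete/DiscreteBayesTree.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>

#include <iostream>

using namespace gtsam;

int main() {
  // Build the Asia network and eliminate it into a Bayes tree, as in the tests.
  DiscreteFactorGraph asia(asia_example::createAsiaExample());  // assumed helper
  DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal();

  // Ask for the 4 best assignments; run() repeatedly expands the queued node
  // with the smallest lower bound until the queue is empty.
  DiscreteSearch search(*bt, 4);
  auto solutions = search.run();

  // Each Solution carries the accumulated conditional error of its assignment.
  for (const auto& s : solutions) std::cout << s.error << "\n";
  std::cout << "expansions: " << search.numExpansions << std::endl;
}
```

Note on the Bayes-tree constructor: the post-order traversal pushes each clique's conditional after its children's, and the search assigns conditionals from the back of the vector (`nextConditional` starts at `size() - 1` and counts down), so root cliques are assigned before their descendants.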

View File

@@ -30,7 +30,7 @@ static const DiscreteKey Dyspnea(D, 2), XRay(X, 2), Either(E, 2),
     Bronchitis(B, 2), LungCancer(L, 2), Tuberculosis(T, 2), Smoking(S, 2),
     Asia(A, 2);
-// Function to construct the incomplete Asia example
+// Function to construct the Asia priors
 DiscreteBayesNet createPriors() {
   DiscreteBayesNet priors;
   priors.add(Smoking % "50/50");

View File

@@ -49,8 +49,9 @@ TEST(DiscreteBayesNet, AsiaKBest) {
   DiscreteSearch search(asia, 4);
   auto solutions = search.run();
   EXPECT(!solutions.empty());
-  // Regression test: check the first solution
+  // Regression test: check the first and last solution
   EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);
+  EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5);
 }

 /* ************************************************************************* */
@@ -70,16 +71,21 @@ TEST(DiscreteBayesTree, testEmptyTree) {
 TEST(DiscreteBayesTree, testTrivialOneClique) {
   using namespace asia_example;
   DiscreteFactorGraph asia(createAsiaExample());
-  DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal();
+  const Ordering ordering{D, X, B, E, L, T, S, A};
+  DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(ordering);
   GTSAM_PRINT(*bt);
   // Ask for top 4 solutions
   DiscreteSearch search(*bt, 4);
   auto solutions = search.run();
+  // print numExpansions
+  std::cout << "Number of expansions: " << search.numExpansions << std::endl;
   EXPECT(!solutions.empty());
-  // Regression test: check the first solution
+  // Regression test: check the first and last solution
   EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);
+  EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5);
 }

 /* ************************************************************************* */
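As a sanity check on those regression values: if a conditional's error is the negative log of its probability (which is how the errors are summed in `expand()`), a solution's total error is the negative log of the joint probability of that assignment. A throwaway sketch under that assumption, using only the two constants from the tests above:

```cpp
// Sketch: interpreting Solution::error as -log P(assignment); this reading of
// the error convention is an assumption, not stated in the diff itself.
#include <cmath>
#include <cstdio>

int main() {
  const double bestError = 1.236627;    // solutions[0].error in both tests
  const double fourthError = 2.201708;  // solutions[3].error in both tests
  std::printf("P(best)   ~ %.3f\n", std::exp(-bestError));    // ~0.290
  std::printf("P(fourth) ~ %.3f\n", std::exp(-fourthError));  // ~0.111
}
```

That both tests report the same four values is consistent with the Bayes net and the Bayes tree encoding the same joint distribution over the Asia variables.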