Works for Bayes trees now as well
parent
54f493358d
commit
14eeaf93db
|
@ -36,6 +36,16 @@ struct SearchNode {
|
||||||
double bound; ///< Lower bound on the final error for unassigned variables.
|
double bound; ///< Lower bound on the final error for unassigned variables.
|
||||||
int nextConditional; ///< Index of the next conditional to be assigned.
|
int nextConditional; ///< Index of the next conditional to be assigned.
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Construct the root node for the search.
|
||||||
|
*/
|
||||||
|
static SearchNode Root(size_t numConditionals, double bound) {
|
||||||
|
return {.assignment = DiscreteValues(),
|
||||||
|
.error = 0.0,
|
||||||
|
.bound = bound,
|
||||||
|
.nextConditional = static_cast<int>(numConditionals) - 1};
|
||||||
|
}
|
||||||
|
|
||||||
struct CompareByBound {
|
struct CompareByBound {
|
||||||
bool operator()(const SearchNode& a, const SearchNode& b) const {
|
bool operator()(const SearchNode& a, const SearchNode& b) const {
|
||||||
return a.bound > b.bound; // smallest bound -> highest priority
|
return a.bound > b.bound; // smallest bound -> highest priority
|
||||||
|
@ -49,20 +59,6 @@ struct SearchNode {
|
||||||
*/
|
*/
|
||||||
bool isComplete() const { return nextConditional < 0; }
|
bool isComplete() const { return nextConditional < 0; }
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Computes a lower bound on the final error for unassigned variables.
|
|
||||||
*
|
|
||||||
* @details This is a stub implementation that returns 0. Real implementations
|
|
||||||
* might perform partial factor analysis or use heuristics to compute the
|
|
||||||
* bound.
|
|
||||||
*
|
|
||||||
* @return A lower bound on the final error.
|
|
||||||
*/
|
|
||||||
double computeBound() const {
|
|
||||||
// Real code might do partial factor analysis or heuristics.
|
|
||||||
return 0.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Expands the node by assigning the next variable.
|
* @brief Expands the node by assigning the next variable.
|
||||||
*
|
*
|
||||||
|
@ -74,22 +70,15 @@ struct SearchNode {
|
||||||
SearchNode expand(const DiscreteConditional& conditional,
|
SearchNode expand(const DiscreteConditional& conditional,
|
||||||
const DiscreteValues& fa) const {
|
const DiscreteValues& fa) const {
|
||||||
// Combine the new frontal assignment with the current partial assignment
|
// Combine the new frontal assignment with the current partial assignment
|
||||||
SearchNode child;
|
DiscreteValues newAssignment = assignment;
|
||||||
child.assignment = assignment;
|
|
||||||
for (auto& kv : fa) {
|
for (auto& kv : fa) {
|
||||||
child.assignment[kv.first] = kv.second;
|
newAssignment[kv.first] = kv.second;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compute the incremental error for this factor
|
return {.assignment = newAssignment,
|
||||||
child.error = error + conditional.error(child.assignment);
|
.error = error + conditional.error(newAssignment),
|
||||||
|
.bound = 0.0,
|
||||||
// Compute new bound
|
.nextConditional = nextConditional - 1};
|
||||||
child.bound = computeBound();
|
|
||||||
|
|
||||||
// Update the index of the next conditional
|
|
||||||
child.nextConditional = nextConditional - 1;
|
|
||||||
|
|
||||||
return child;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -184,6 +173,8 @@ class Solutions {
|
||||||
*/
|
*/
|
||||||
class DiscreteSearch {
|
class DiscreteSearch {
|
||||||
public:
|
public:
|
||||||
|
size_t numExpansions = 0;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Construct from a DiscreteBayesNet and K.
|
* Construct from a DiscreteBayesNet and K.
|
||||||
*/
|
*/
|
||||||
|
@ -193,33 +184,42 @@ class DiscreteSearch {
|
||||||
conditionals_.push_back(factor);
|
conditionals_.push_back(factor);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculate the cost-to-go for each conditional. If there are n
|
// Calculate the cost-to-go for each conditional
|
||||||
// conditionals, we start with nextConditional = n-1, and the minimum error
|
costToGo_ = computeCostToGo(conditionals_);
|
||||||
// obtainable is the sum of all the minimum errors. We start with
|
|
||||||
// 0, and that is the minimum error of the conditional with that index:
|
|
||||||
double error = 0.0;
|
|
||||||
for (const auto& conditional : conditionals_) {
|
|
||||||
Ordering ordering(conditional->begin(), conditional->end());
|
|
||||||
auto maxx = conditional->max(ordering);
|
|
||||||
assert(maxx->size() == 1);
|
|
||||||
error -= std::log(maxx->evaluate({}));
|
|
||||||
costToGo_.push_back(error);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create the root node: no variables assigned, nextConditional = last.
|
// Create the root node and push it to the expansions queue
|
||||||
SearchNode root{
|
expansions_.push(SearchNode::Root(
|
||||||
.assignment = DiscreteValues(),
|
conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back()));
|
||||||
.error = 0.0,
|
|
||||||
.nextConditional = static_cast<int>(conditionals_.size()) - 1};
|
|
||||||
if (!costToGo_.empty()) root.bound = costToGo_.back();
|
|
||||||
expansions_.push(root);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Construct from a DiscreteBayesNet and K.
|
* Construct from a DiscreteBayesTree and K.
|
||||||
*/
|
*/
|
||||||
DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K)
|
DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K) : solutions_(K) {
|
||||||
: solutions_(K) {}
|
using CliquePtr = DiscreteBayesTree::sharedClique;
|
||||||
|
std::function<void(const CliquePtr&)> collectConditionals =
|
||||||
|
[&](const CliquePtr& clique) -> void {
|
||||||
|
if (!clique) return;
|
||||||
|
|
||||||
|
// Recursive post-order traversal: process children first
|
||||||
|
for (const auto& child : clique->children) {
|
||||||
|
collectConditionals(child);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then add the current clique's conditional
|
||||||
|
conditionals_.push_back(clique->conditional());
|
||||||
|
};
|
||||||
|
|
||||||
|
// Start traversal from each root in the tree
|
||||||
|
for (const auto& root : bayesTree.roots()) collectConditionals(root);
|
||||||
|
|
||||||
|
// Calculate the cost-to-go for each conditional
|
||||||
|
costToGo_ = computeCostToGo(conditionals_);
|
||||||
|
|
||||||
|
// Create the root node and push it to the expansions queue
|
||||||
|
expansions_.push(SearchNode::Root(
|
||||||
|
conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back()));
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Search for the K best solutions.
|
* @brief Search for the K best solutions.
|
||||||
|
@ -233,6 +233,7 @@ class DiscreteSearch {
|
||||||
*/
|
*/
|
||||||
std::vector<Solution> run() {
|
std::vector<Solution> run() {
|
||||||
while (!expansions_.empty()) {
|
while (!expansions_.empty()) {
|
||||||
|
numExpansions++;
|
||||||
expandNextNode();
|
expandNextNode();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -241,6 +242,29 @@ class DiscreteSearch {
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
/**
|
||||||
|
* @brief Compute the cost-to-go for each conditional.
|
||||||
|
*
|
||||||
|
* @param conditionals The conditionals of the DiscreteBayesNet.
|
||||||
|
* @return A vector of cost-to-go values.
|
||||||
|
*/
|
||||||
|
static std::vector<double> computeCostToGo(
|
||||||
|
const std::vector<DiscreteConditional::shared_ptr>& conditionals) {
|
||||||
|
std::vector<double> costToGo;
|
||||||
|
double error = 0.0;
|
||||||
|
for (const auto& conditional : conditionals) {
|
||||||
|
Ordering ordering(conditional->begin(), conditional->end());
|
||||||
|
auto maxx = conditional->max(ordering);
|
||||||
|
assert(maxx->size() == 1);
|
||||||
|
error -= std::log(maxx->evaluate({}));
|
||||||
|
costToGo.push_back(error);
|
||||||
|
}
|
||||||
|
return costToGo;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Expand the next node in the search tree.
|
||||||
|
*/
|
||||||
void expandNextNode() {
|
void expandNextNode() {
|
||||||
// Pop the partial assignment with the smallest bound
|
// Pop the partial assignment with the smallest bound
|
||||||
SearchNode current = expansions_.top();
|
SearchNode current = expansions_.top();
|
||||||
|
@ -273,7 +297,7 @@ class DiscreteSearch {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::shared_ptr<DiscreteConditional>> conditionals_;
|
std::vector<DiscreteConditional::shared_ptr> conditionals_;
|
||||||
std::vector<double> costToGo_;
|
std::vector<double> costToGo_;
|
||||||
std::priority_queue<SearchNode, std::vector<SearchNode>,
|
std::priority_queue<SearchNode, std::vector<SearchNode>,
|
||||||
SearchNode::CompareByBound>
|
SearchNode::CompareByBound>
|
||||||
|
|
|
@ -30,7 +30,7 @@ static const DiscreteKey Dyspnea(D, 2), XRay(X, 2), Either(E, 2),
|
||||||
Bronchitis(B, 2), LungCancer(L, 2), Tuberculosis(T, 2), Smoking(S, 2),
|
Bronchitis(B, 2), LungCancer(L, 2), Tuberculosis(T, 2), Smoking(S, 2),
|
||||||
Asia(A, 2);
|
Asia(A, 2);
|
||||||
|
|
||||||
// Function to construct the incomplete Asia example
|
// Function to construct the Asia priors
|
||||||
DiscreteBayesNet createPriors() {
|
DiscreteBayesNet createPriors() {
|
||||||
DiscreteBayesNet priors;
|
DiscreteBayesNet priors;
|
||||||
priors.add(Smoking % "50/50");
|
priors.add(Smoking % "50/50");
|
||||||
|
|
|
@ -49,8 +49,9 @@ TEST(DiscreteBayesNet, AsiaKBest) {
|
||||||
DiscreteSearch search(asia, 4);
|
DiscreteSearch search(asia, 4);
|
||||||
auto solutions = search.run();
|
auto solutions = search.run();
|
||||||
EXPECT(!solutions.empty());
|
EXPECT(!solutions.empty());
|
||||||
// Regression test: check the first solution
|
// Regression test: check the first and last solution
|
||||||
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);
|
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);
|
||||||
|
EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -70,16 +71,21 @@ TEST(DiscreteBayesTree, testEmptyTree) {
|
||||||
TEST(DiscreteBayesTree, testTrivialOneClique) {
|
TEST(DiscreteBayesTree, testTrivialOneClique) {
|
||||||
using namespace asia_example;
|
using namespace asia_example;
|
||||||
DiscreteFactorGraph asia(createAsiaExample());
|
DiscreteFactorGraph asia(createAsiaExample());
|
||||||
DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal();
|
const Ordering ordering{D, X, B, E, L, T, S, A};
|
||||||
|
DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(ordering);
|
||||||
GTSAM_PRINT(*bt);
|
GTSAM_PRINT(*bt);
|
||||||
|
|
||||||
// Ask for top 4 solutions
|
// Ask for top 4 solutions
|
||||||
DiscreteSearch search(*bt, 4);
|
DiscreteSearch search(*bt, 4);
|
||||||
auto solutions = search.run();
|
auto solutions = search.run();
|
||||||
|
|
||||||
|
// print numExpansions
|
||||||
|
std::cout << "Number of expansions: " << search.numExpansions << std::endl;
|
||||||
|
|
||||||
EXPECT(!solutions.empty());
|
EXPECT(!solutions.empty());
|
||||||
// Regression test: check the first solution
|
// Regression test: check the first and last solution
|
||||||
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);
|
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);
|
||||||
|
EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
Loading…
Reference in New Issue