commit
3f6ae48dfb
|
@ -0,0 +1,246 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/*
|
||||||
|
* DiscreteSearch.cpp
|
||||||
|
*
|
||||||
|
* @date January, 2025
|
||||||
|
* @author Frank Dellaert
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteSearch.h>
|
||||||
|
|
||||||
|
namespace gtsam {
|
||||||
|
|
||||||
|
using Solution = DiscreteSearch::Solution;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Represents a node in the search tree for discrete search algorithms.
|
||||||
|
*
|
||||||
|
* @details Each SearchNode contains a partial assignment of discrete variables,
|
||||||
|
* the current error, a bound on the final error, and the index of the next
|
||||||
|
* conditional to be assigned.
|
||||||
|
*/
|
||||||
|
struct SearchNode {
|
||||||
|
DiscreteValues assignment; ///< Partial assignment of discrete variables.
|
||||||
|
double error; ///< Current error for the partial assignment.
|
||||||
|
double bound; ///< Lower bound on the final error for unassigned variables.
|
||||||
|
int nextConditional; ///< Index of the next conditional to be assigned.
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Construct the root node for the search.
|
||||||
|
*/
|
||||||
|
static SearchNode Root(size_t numConditionals, double bound) {
|
||||||
|
return {DiscreteValues(), 0.0, bound,
|
||||||
|
static_cast<int>(numConditionals) - 1};
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Compare {
|
||||||
|
bool operator()(const SearchNode& a, const SearchNode& b) const {
|
||||||
|
return a.bound > b.bound; // smallest bound -> highest priority
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Checks if the node represents a complete assignment.
|
||||||
|
*
|
||||||
|
* @return True if all variables have been assigned, false otherwise.
|
||||||
|
*/
|
||||||
|
inline bool isComplete() const { return nextConditional < 0; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Expands the node by assigning the next variable.
|
||||||
|
*
|
||||||
|
* @param conditional The discrete conditional representing the next variable
|
||||||
|
* to be assigned.
|
||||||
|
* @param fa The frontal assignment for the next variable.
|
||||||
|
* @return A new SearchNode representing the expanded state.
|
||||||
|
*/
|
||||||
|
SearchNode expand(const DiscreteConditional& conditional,
|
||||||
|
const DiscreteValues& fa) const {
|
||||||
|
// Combine the new frontal assignment with the current partial assignment
|
||||||
|
DiscreteValues newAssignment = assignment;
|
||||||
|
for (auto& [key, value] : fa) {
|
||||||
|
newAssignment[key] = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
return {newAssignment, error + conditional.error(newAssignment), 0.0,
|
||||||
|
nextConditional - 1};
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Prints the SearchNode to an output stream.
|
||||||
|
*
|
||||||
|
* @param os The output stream.
|
||||||
|
* @param node The SearchNode to be printed.
|
||||||
|
* @return The output stream.
|
||||||
|
*/
|
||||||
|
friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) {
|
||||||
|
os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")";
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct CompareSolution {
|
||||||
|
bool operator()(const Solution& a, const Solution& b) const {
|
||||||
|
return a.error < b.error;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Define the Solutions class
|
||||||
|
class Solutions {
|
||||||
|
private:
|
||||||
|
size_t maxSize_;
|
||||||
|
std::priority_queue<Solution, std::vector<Solution>, CompareSolution> pq_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Solutions(size_t maxSize) : maxSize_(maxSize) {}
|
||||||
|
|
||||||
|
/// Add a solution to the priority queue, possibly evicting the worst one.
|
||||||
|
/// Return true if we added the solution.
|
||||||
|
bool maybeAdd(double error, const DiscreteValues& assignment) {
|
||||||
|
const bool full = pq_.size() == maxSize_;
|
||||||
|
if (full && error >= pq_.top().error) return false;
|
||||||
|
if (full) pq_.pop();
|
||||||
|
pq_.emplace(error, assignment);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if we have any solutions
|
||||||
|
bool empty() const { return pq_.empty(); }
|
||||||
|
|
||||||
|
// Method to print all solutions
|
||||||
|
friend std::ostream& operator<<(std::ostream& os, const Solutions& sn) {
|
||||||
|
os << "Solutions (top " << sn.pq_.size() << "):\n";
|
||||||
|
auto pq = sn.pq_;
|
||||||
|
while (!pq.empty()) {
|
||||||
|
os << pq.top() << "\n";
|
||||||
|
pq.pop();
|
||||||
|
}
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if (partial) solution with given bound can be pruned. If we have
|
||||||
|
/// room, we never prune. Otherwise, prune if lower bound on error is worse
|
||||||
|
/// than our current worst error.
|
||||||
|
bool prune(double bound) const {
|
||||||
|
if (pq_.size() < maxSize_) return false;
|
||||||
|
return bound >= pq_.top().error;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Method to extract solutions in ascending order of error
|
||||||
|
std::vector<Solution> extractSolutions() {
|
||||||
|
std::vector<Solution> result;
|
||||||
|
while (!pq_.empty()) {
|
||||||
|
result.push_back(pq_.top());
|
||||||
|
pq_.pop();
|
||||||
|
}
|
||||||
|
std::sort(
|
||||||
|
result.begin(), result.end(),
|
||||||
|
[](const Solution& a, const Solution& b) { return a.error < b.error; });
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) {
|
||||||
|
std::vector<DiscreteConditional::shared_ptr> conditionals;
|
||||||
|
for (auto& factor : bayesNet) conditionals_.push_back(factor);
|
||||||
|
costToGo_ = computeCostToGo(conditionals_);
|
||||||
|
}
|
||||||
|
|
||||||
|
DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) {
|
||||||
|
std::function<void(const DiscreteBayesTree::sharedClique&)>
|
||||||
|
collectConditionals = [&](const auto& clique) {
|
||||||
|
if (!clique) return;
|
||||||
|
for (const auto& child : clique->children) collectConditionals(child);
|
||||||
|
conditionals_.push_back(clique->conditional());
|
||||||
|
};
|
||||||
|
for (const auto& root : bayesTree.roots()) collectConditionals(root);
|
||||||
|
costToGo_ = computeCostToGo(conditionals_);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct SearchNodeQueue
|
||||||
|
: public std::priority_queue<SearchNode, std::vector<SearchNode>,
|
||||||
|
SearchNode::Compare> {
|
||||||
|
void expandNextNode(
|
||||||
|
const std::vector<DiscreteConditional::shared_ptr>& conditionals,
|
||||||
|
const std::vector<double>& costToGo, Solutions* solutions) {
|
||||||
|
// Pop the partial assignment with the smallest bound
|
||||||
|
SearchNode current = top();
|
||||||
|
pop();
|
||||||
|
|
||||||
|
// If we already have K solutions, prune if we cannot beat the worst one.
|
||||||
|
if (solutions->prune(current.bound)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if we have a complete assignment
|
||||||
|
if (current.isComplete()) {
|
||||||
|
solutions->maybeAdd(current.error, current.assignment);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expand on the next factor
|
||||||
|
const auto& conditional = conditionals[current.nextConditional];
|
||||||
|
|
||||||
|
for (auto& fa : conditional->frontalAssignments()) {
|
||||||
|
auto childNode = current.expand(*conditional, fa);
|
||||||
|
if (childNode.nextConditional >= 0)
|
||||||
|
childNode.bound = childNode.error + costToGo[childNode.nextConditional];
|
||||||
|
|
||||||
|
// Again, prune if we cannot beat the worst solution
|
||||||
|
if (!solutions->prune(childNode.bound)) {
|
||||||
|
emplace(childNode);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<Solution> DiscreteSearch::run(size_t K) const {
|
||||||
|
Solutions solutions(K);
|
||||||
|
SearchNodeQueue expansions;
|
||||||
|
expansions.push(SearchNode::Root(conditionals_.size(),
|
||||||
|
costToGo_.empty() ? 0.0 : costToGo_.back()));
|
||||||
|
|
||||||
|
#ifdef DISCRETE_SEARCH_DEBUG
|
||||||
|
size_t numExpansions = 0;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Perform the search
|
||||||
|
while (!expansions.empty()) {
|
||||||
|
expansions.expandNextNode(conditionals_, costToGo_, &solutions);
|
||||||
|
#ifdef DISCRETE_SEARCH_DEBUG
|
||||||
|
++numExpansions;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef DISCRETE_SEARCH_DEBUG
|
||||||
|
std::cout << "Number of expansions: " << numExpansions << std::endl;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Extract solutions from bestSolutions in ascending order of error
|
||||||
|
return solutions.extractSolutions();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<double> DiscreteSearch::computeCostToGo(
|
||||||
|
const std::vector<DiscreteConditional::shared_ptr>& conditionals) {
|
||||||
|
std::vector<double> costToGo;
|
||||||
|
double error = 0.0;
|
||||||
|
for (const auto& conditional : conditionals) {
|
||||||
|
Ordering ordering(conditional->begin(), conditional->end());
|
||||||
|
auto maxx = conditional->max(ordering);
|
||||||
|
error -= std::log(maxx->evaluate({}));
|
||||||
|
costToGo.push_back(error);
|
||||||
|
}
|
||||||
|
return costToGo;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gtsam
|
|
@ -0,0 +1,78 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/*
|
||||||
|
* DiscreteSearch.cpp
|
||||||
|
*
|
||||||
|
* @date January, 2025
|
||||||
|
* @author Frank Dellaert
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
|
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||||
|
|
||||||
|
#include <queue>
|
||||||
|
|
||||||
|
namespace gtsam {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* DiscreteSearch: Search for the K best solutions.
|
||||||
|
*/
|
||||||
|
class GTSAM_EXPORT DiscreteSearch {
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* @brief A solution to a discrete search problem.
|
||||||
|
*/
|
||||||
|
struct Solution {
|
||||||
|
double error;
|
||||||
|
DiscreteValues assignment;
|
||||||
|
Solution(double err, const DiscreteValues& assign)
|
||||||
|
: error(err), assignment(assign) {}
|
||||||
|
friend std::ostream& operator<<(std::ostream& os, const Solution& sn) {
|
||||||
|
os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]";
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Construct from a DiscreteBayesNet and K.
|
||||||
|
*/
|
||||||
|
DiscreteSearch(const DiscreteBayesNet& bayesNet);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Construct from a DiscreteBayesTree and K.
|
||||||
|
*/
|
||||||
|
DiscreteSearch(const DiscreteBayesTree& bayesTree);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Search for the K best solutions.
|
||||||
|
*
|
||||||
|
* This method performs a search to find the K best solutions for the given
|
||||||
|
* DiscreteBayesNet. It uses a priority queue to manage the search nodes,
|
||||||
|
* expanding nodes with the smallest bound first. The search continues until
|
||||||
|
* all possible nodes have been expanded or pruned.
|
||||||
|
*
|
||||||
|
* @return A vector of the K best solutions found during the search.
|
||||||
|
*/
|
||||||
|
std::vector<Solution> run(size_t K = 1) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
/// Compute the cumulative cost-to-go for each conditional slot.
|
||||||
|
static std::vector<double> computeCostToGo(
|
||||||
|
const std::vector<DiscreteConditional::shared_ptr>& conditionals);
|
||||||
|
|
||||||
|
/// Expand the next node in the search tree.
|
||||||
|
void expandNextNode() const;
|
||||||
|
|
||||||
|
std::vector<DiscreteConditional::shared_ptr> conditionals_;
|
||||||
|
std::vector<double> costToGo_;
|
||||||
|
};
|
||||||
|
} // namespace gtsam
|
|
@ -26,12 +26,24 @@ using std::stringstream;
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
static void stream(std::ostream& os, const DiscreteValues& x,
|
||||||
|
const KeyFormatter& keyFormatter) {
|
||||||
|
for (const auto& kv : x)
|
||||||
|
os << "(" << keyFormatter(kv.first) << ", " << kv.second << ")";
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
std::ostream& operator<<(std::ostream& os, const DiscreteValues& x) {
|
||||||
|
stream(os, x, DefaultKeyFormatter);
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
void DiscreteValues::print(const string& s,
|
void DiscreteValues::print(const string& s,
|
||||||
const KeyFormatter& keyFormatter) const {
|
const KeyFormatter& keyFormatter) const {
|
||||||
cout << s << ": ";
|
cout << s << ": ";
|
||||||
for (auto&& kv : *this)
|
stream(cout, *this, keyFormatter);
|
||||||
cout << "(" << keyFormatter(kv.first) << ", " << kv.second << ")";
|
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -64,6 +64,9 @@ class GTSAM_EXPORT DiscreteValues : public Assignment<Key> {
|
||||||
/// @name Standard Interface
|
/// @name Standard Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
/// ostream operator:
|
||||||
|
friend std::ostream& operator<<(std::ostream& os, const DiscreteValues& x);
|
||||||
|
|
||||||
// insert in base class;
|
// insert in base class;
|
||||||
std::pair<iterator, bool> insert( const value_type& value ){
|
std::pair<iterator, bool> insert( const value_type& value ){
|
||||||
return Base::insert(value);
|
return Base::insert(value);
|
||||||
|
|
|
@ -0,0 +1,61 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/*
|
||||||
|
* AsiaExample.h
|
||||||
|
*
|
||||||
|
* @date Jan, 2025
|
||||||
|
* @author Frank Dellaert
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
|
#include <gtsam/inference/Symbol.h>
|
||||||
|
|
||||||
|
namespace gtsam {
|
||||||
|
namespace asia_example {
|
||||||
|
|
||||||
|
static const Key D = Symbol('D', 1), X = Symbol('X', 2), E = Symbol('E', 3),
|
||||||
|
B = Symbol('B', 4), L = Symbol('L', 5), T = Symbol('T', 6),
|
||||||
|
S = Symbol('S', 7), A = Symbol('A', 8);
|
||||||
|
|
||||||
|
static const DiscreteKey Dyspnea(D, 2), XRay(X, 2), Either(E, 2),
|
||||||
|
Bronchitis(B, 2), LungCancer(L, 2), Tuberculosis(T, 2), Smoking(S, 2),
|
||||||
|
Asia(A, 2);
|
||||||
|
|
||||||
|
// Function to construct the Asia priors
|
||||||
|
DiscreteBayesNet createPriors() {
|
||||||
|
DiscreteBayesNet priors;
|
||||||
|
priors.add(Smoking % "50/50");
|
||||||
|
priors.add(Asia, "99/1");
|
||||||
|
return priors;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Function to construct the incomplete Asia example
|
||||||
|
DiscreteBayesNet createFragment() {
|
||||||
|
DiscreteBayesNet fragment;
|
||||||
|
fragment.add((Either | Tuberculosis, LungCancer) = "F T T T");
|
||||||
|
fragment.add(LungCancer | Smoking = "99/1 90/10");
|
||||||
|
fragment.add(Tuberculosis | Asia = "99/1 95/5");
|
||||||
|
for (const auto& factor : createPriors()) fragment.push_back(factor);
|
||||||
|
return fragment;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Function to construct the Asia example
|
||||||
|
DiscreteBayesNet createAsiaExample() {
|
||||||
|
DiscreteBayesNet asia;
|
||||||
|
asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9");
|
||||||
|
asia.add(XRay | Either = "95/5 2/98");
|
||||||
|
asia.add(Bronchitis | Smoking = "70/30 40/60");
|
||||||
|
for (const auto& factor : createFragment()) asia.push_back(factor);
|
||||||
|
return asia;
|
||||||
|
}
|
||||||
|
} // namespace asia_example
|
||||||
|
} // namespace gtsam
|
|
@ -23,40 +23,19 @@
|
||||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
#include <gtsam/discrete/DiscreteMarginals.h>
|
#include <gtsam/discrete/DiscreteMarginals.h>
|
||||||
|
#include <gtsam/inference/Symbol.h>
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
using namespace std;
|
#include "AsiaExample.h"
|
||||||
|
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
|
||||||
static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2),
|
|
||||||
LungCancer(6, 2), Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2);
|
|
||||||
|
|
||||||
using ADT = AlgebraicDecisionTree<Key>;
|
|
||||||
|
|
||||||
// Function to construct the Asia example
|
|
||||||
DiscreteBayesNet constructAsiaExample() {
|
|
||||||
DiscreteBayesNet asia;
|
|
||||||
|
|
||||||
asia.add(Asia, "99/1");
|
|
||||||
asia.add(Smoking % "50/50"); // Signature version
|
|
||||||
|
|
||||||
asia.add(Tuberculosis | Asia = "99/1 95/5");
|
|
||||||
asia.add(LungCancer | Smoking = "99/1 90/10");
|
|
||||||
asia.add(Bronchitis | Smoking = "70/30 40/60");
|
|
||||||
|
|
||||||
asia.add((Either | Tuberculosis, LungCancer) = "F T T T");
|
|
||||||
|
|
||||||
asia.add(XRay | Either = "95/5 2/98");
|
|
||||||
asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9");
|
|
||||||
|
|
||||||
return asia;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscreteBayesNet, bayesNet) {
|
TEST(DiscreteBayesNet, bayesNet) {
|
||||||
|
using ADT = AlgebraicDecisionTree<Key>;
|
||||||
DiscreteBayesNet bayesNet;
|
DiscreteBayesNet bayesNet;
|
||||||
DiscreteKey Parent(0, 2), Child(1, 2);
|
DiscreteKey Parent(0, 2), Child(1, 2);
|
||||||
|
|
||||||
|
@ -86,11 +65,12 @@ TEST(DiscreteBayesNet, bayesNet) {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscreteBayesNet, Asia) {
|
TEST(DiscreteBayesNet, Asia) {
|
||||||
DiscreteBayesNet asia = constructAsiaExample();
|
using namespace asia_example;
|
||||||
|
const DiscreteBayesNet asia = createAsiaExample();
|
||||||
|
|
||||||
// Convert to factor graph
|
// Convert to factor graph
|
||||||
DiscreteFactorGraph fg(asia);
|
DiscreteFactorGraph fg(asia);
|
||||||
LONGS_EQUAL(3, fg.back()->size());
|
LONGS_EQUAL(1, fg.back()->size());
|
||||||
|
|
||||||
// Check the marginals we know (of the parent-less nodes)
|
// Check the marginals we know (of the parent-less nodes)
|
||||||
DiscreteMarginals marginals(fg);
|
DiscreteMarginals marginals(fg);
|
||||||
|
@ -99,7 +79,7 @@ TEST(DiscreteBayesNet, Asia) {
|
||||||
EXPECT(assert_equal(vs, marginals.marginalProbabilities(Smoking)));
|
EXPECT(assert_equal(vs, marginals.marginalProbabilities(Smoking)));
|
||||||
|
|
||||||
// Create solver and eliminate
|
// Create solver and eliminate
|
||||||
const Ordering ordering{0, 1, 2, 3, 4, 5, 6, 7};
|
const Ordering ordering{A, D, T, X, S, E, L, B};
|
||||||
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
|
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
|
||||||
DiscreteConditional expected2(Bronchitis % "11/9");
|
DiscreteConditional expected2(Bronchitis % "11/9");
|
||||||
EXPECT(assert_equal(expected2, *chordal->back()));
|
EXPECT(assert_equal(expected2, *chordal->back()));
|
||||||
|
@ -144,55 +124,50 @@ TEST(DiscreteBayesNet, Sugar) {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscreteBayesNet, Dot) {
|
TEST(DiscreteBayesNet, Dot) {
|
||||||
DiscreteBayesNet fragment;
|
using namespace asia_example;
|
||||||
fragment.add(Asia % "99/1");
|
const DiscreteBayesNet fragment = createFragment();
|
||||||
fragment.add(Smoking % "50/50");
|
|
||||||
|
|
||||||
fragment.add(Tuberculosis | Asia = "99/1 95/5");
|
std::string expected =
|
||||||
fragment.add(LungCancer | Smoking = "99/1 90/10");
|
"digraph {\n"
|
||||||
fragment.add((Either | Tuberculosis, LungCancer) = "F T T T");
|
" size=\"5,5\";\n"
|
||||||
|
"\n"
|
||||||
string actual = fragment.dot();
|
" var4683743612465315848[label=\"A8\"];\n"
|
||||||
EXPECT(actual ==
|
" var4971973988617027587[label=\"E3\"];\n"
|
||||||
"digraph {\n"
|
" var5476377146882523141[label=\"L5\"];\n"
|
||||||
" size=\"5,5\";\n"
|
" var5980780305148018695[label=\"S7\"];\n"
|
||||||
"\n"
|
" var6052837899185946630[label=\"T6\"];\n"
|
||||||
" var0[label=\"0\"];\n"
|
"\n"
|
||||||
" var3[label=\"3\"];\n"
|
" var4683743612465315848->var6052837899185946630\n"
|
||||||
" var4[label=\"4\"];\n"
|
" var5980780305148018695->var5476377146882523141\n"
|
||||||
" var5[label=\"5\"];\n"
|
" var6052837899185946630->var4971973988617027587\n"
|
||||||
" var6[label=\"6\"];\n"
|
" var5476377146882523141->var4971973988617027587\n"
|
||||||
"\n"
|
"}";
|
||||||
" var3->var5\n"
|
std::string actual = fragment.dot();
|
||||||
" var6->var5\n"
|
EXPECT(actual.compare(expected) == 0);
|
||||||
" var4->var6\n"
|
|
||||||
" var0->var3\n"
|
|
||||||
"}");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// Check markdown representation looks as expected.
|
// Check markdown representation looks as expected.
|
||||||
TEST(DiscreteBayesNet, markdown) {
|
TEST(DiscreteBayesNet, markdown) {
|
||||||
DiscreteBayesNet fragment;
|
using namespace asia_example;
|
||||||
fragment.add(Asia % "99/1");
|
DiscreteBayesNet priors = createPriors();
|
||||||
fragment.add(Smoking | Asia = "8/2 7/3");
|
|
||||||
|
|
||||||
string expected =
|
std::string expected =
|
||||||
"`DiscreteBayesNet` of size 2\n"
|
"`DiscreteBayesNet` of size 2\n"
|
||||||
"\n"
|
"\n"
|
||||||
|
" *P(Smoking):*\n\n"
|
||||||
|
"|Smoking|value|\n"
|
||||||
|
"|:-:|:-:|\n"
|
||||||
|
"|0|0.5|\n"
|
||||||
|
"|1|0.5|\n"
|
||||||
|
"\n"
|
||||||
" *P(Asia):*\n\n"
|
" *P(Asia):*\n\n"
|
||||||
"|Asia|value|\n"
|
"|Asia|value|\n"
|
||||||
"|:-:|:-:|\n"
|
"|:-:|:-:|\n"
|
||||||
"|0|0.99|\n"
|
"|0|0.99|\n"
|
||||||
"|1|0.01|\n"
|
"|1|0.01|\n\n";
|
||||||
"\n"
|
auto formatter = [](Key key) { return key == A ? "Asia" : "Smoking"; };
|
||||||
" *P(Smoking|Asia):*\n\n"
|
std::string actual = priors.markdown(formatter);
|
||||||
"|*Asia*|0|1|\n"
|
|
||||||
"|:-:|:-:|:-:|\n"
|
|
||||||
"|0|0.8|0.2|\n"
|
|
||||||
"|1|0.7|0.3|\n\n";
|
|
||||||
auto formatter = [](Key key) { return key == 0 ? "Asia" : "Smoking"; };
|
|
||||||
string actual = fragment.markdown(formatter);
|
|
||||||
EXPECT(actual == expected);
|
EXPECT(actual == expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,111 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/*
|
||||||
|
* testDiscreteSearch.cpp
|
||||||
|
*
|
||||||
|
* @date January, 2025
|
||||||
|
* @author Frank Dellaert
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
|
#include <gtsam/discrete/DiscreteSearch.h>
|
||||||
|
|
||||||
|
#include "AsiaExample.h"
|
||||||
|
|
||||||
|
using namespace gtsam;
|
||||||
|
|
||||||
|
// Create Asia Bayes net, FG, and Bayes tree once
|
||||||
|
namespace asia {
|
||||||
|
using namespace asia_example;
|
||||||
|
static const DiscreteBayesNet bayesNet = createAsiaExample();
|
||||||
|
static const DiscreteFactorGraph factorGraph(bayesNet);
|
||||||
|
static const DiscreteValues mpe = factorGraph.optimize();
|
||||||
|
static const Ordering ordering{D, X, B, E, L, T, S, A};
|
||||||
|
static const DiscreteBayesTree bayesTree =
|
||||||
|
*factorGraph.eliminateMultifrontal(ordering);
|
||||||
|
} // namespace asia
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscreteBayesNet, EmptyKBest) {
|
||||||
|
DiscreteBayesNet net; // no factors
|
||||||
|
DiscreteSearch search(net);
|
||||||
|
auto solutions = search.run(3);
|
||||||
|
// Expect one solution with empty assignment, error=0
|
||||||
|
EXPECT_LONGS_EQUAL(1, solutions.size());
|
||||||
|
EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscreteBayesNet, AsiaKBest) {
|
||||||
|
const DiscreteSearch search(asia::bayesNet);
|
||||||
|
|
||||||
|
// Ask for the MPE
|
||||||
|
auto mpe = search.run();
|
||||||
|
|
||||||
|
EXPECT_LONGS_EQUAL(1, mpe.size());
|
||||||
|
// Regression test: check the MPE solution
|
||||||
|
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5);
|
||||||
|
|
||||||
|
// Check it is equal to MPE via inference
|
||||||
|
EXPECT(assert_equal(asia::mpe, mpe[0].assignment));
|
||||||
|
|
||||||
|
// Ask for top 4 solutions
|
||||||
|
auto solutions = search.run(4);
|
||||||
|
|
||||||
|
EXPECT_LONGS_EQUAL(4, solutions.size());
|
||||||
|
// Regression test: check the first and last solution
|
||||||
|
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);
|
||||||
|
EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscreteBayesTree, EmptyTree) {
|
||||||
|
DiscreteBayesTree bt;
|
||||||
|
|
||||||
|
DiscreteSearch search(bt);
|
||||||
|
auto solutions = search.run(3);
|
||||||
|
|
||||||
|
// We expect exactly 1 solution with error = 0.0 (the empty assignment).
|
||||||
|
EXPECT_LONGS_EQUAL(1, solutions.size());
|
||||||
|
EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscreteBayesTree, AsiaTreeKBest) {
|
||||||
|
DiscreteSearch search(asia::bayesTree);
|
||||||
|
|
||||||
|
// Ask for MPE
|
||||||
|
auto mpe = search.run();
|
||||||
|
|
||||||
|
EXPECT_LONGS_EQUAL(1, mpe.size());
|
||||||
|
// Regression test: check the MPE solution
|
||||||
|
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5);
|
||||||
|
|
||||||
|
// Check it is equal to MPE via inference
|
||||||
|
EXPECT(assert_equal(asia::mpe, mpe[0].assignment));
|
||||||
|
|
||||||
|
// Ask for top 4 solutions
|
||||||
|
auto solutions = search.run(4);
|
||||||
|
|
||||||
|
EXPECT_LONGS_EQUAL(4, solutions.size());
|
||||||
|
// Regression test: check the first and last solution
|
||||||
|
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);
|
||||||
|
EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
int main() {
|
||||||
|
TestResult tr;
|
||||||
|
return TestRegistry::runAllTests(tr);
|
||||||
|
}
|
||||||
|
/* ************************************************************************* */
|
Loading…
Reference in New Issue