Reversed slots so we start from zero
parent
b8f265d69f
commit
9e98b805d6
|
@ -33,16 +33,16 @@ using Solution = DiscreteSearch::Solution;
|
||||||
* conditional to be assigned.
|
* conditional to be assigned.
|
||||||
*/
|
*/
|
||||||
struct SearchNode {
|
struct SearchNode {
|
||||||
DiscreteValues assignment; ///< Partial assignment of discrete variables.
|
DiscreteValues assignment; ///< Partial assignment of discrete variables.
|
||||||
double error; ///< Current error for the partial assignment.
|
double error; ///< Current error for the partial assignment.
|
||||||
double bound; ///< Lower bound on the final error for unassigned variables.
|
double bound; ///< Lower bound on the final error
|
||||||
int nextConditional; ///< Index of the next conditional to be assigned.
|
std::optional<size_t> next; ///< Index of the next factor to be assigned.
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Construct the root node for the search.
|
* @brief Construct the root node for the search.
|
||||||
*/
|
*/
|
||||||
static SearchNode Root(size_t numSlots, double bound) {
|
static SearchNode Root(size_t numSlots, double bound) {
|
||||||
return {DiscreteValues(), 0.0, bound, static_cast<int>(numSlots) - 1};
|
return {DiscreteValues(), 0.0, bound, 0};
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Compare {
|
struct Compare {
|
||||||
|
@ -51,38 +51,22 @@ struct SearchNode {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/// Checks if the node represents a complete assignment.
|
||||||
* @brief Checks if the node represents a complete assignment.
|
inline bool isComplete() const { return !next; }
|
||||||
*
|
|
||||||
* @return True if all variables have been assigned, false otherwise.
|
|
||||||
*/
|
|
||||||
inline bool isComplete() const { return nextConditional < 0; }
|
|
||||||
|
|
||||||
/**
|
/// Expands the node by assigning the next variable(s).
|
||||||
* @brief Expands the node by assigning the next variable.
|
SearchNode expand(const DiscreteValues& fa, const Slot& slot,
|
||||||
*
|
std::optional<size_t> nextSlot) const {
|
||||||
* @param slot The slot to be filled.
|
|
||||||
* @param fa The frontal assignment for the next variable.
|
|
||||||
* @return A new SearchNode representing the expanded state.
|
|
||||||
*/
|
|
||||||
SearchNode expand(const Slot& slot, const DiscreteValues& fa) const {
|
|
||||||
// Combine the new frontal assignment with the current partial assignment
|
// Combine the new frontal assignment with the current partial assignment
|
||||||
DiscreteValues newAssignment = assignment;
|
DiscreteValues newAssignment = assignment;
|
||||||
for (auto& [key, value] : fa) {
|
for (auto& [key, value] : fa) {
|
||||||
newAssignment[key] = value;
|
newAssignment[key] = value;
|
||||||
}
|
}
|
||||||
double errorSoFar = error + slot.factor->error(newAssignment);
|
double errorSoFar = error + slot.factor->error(newAssignment);
|
||||||
return {newAssignment, errorSoFar, errorSoFar + slot.heuristic,
|
return {newAssignment, errorSoFar, errorSoFar + slot.heuristic, nextSlot};
|
||||||
nextConditional - 1};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/// Prints the SearchNode to an output stream.
|
||||||
* @brief Prints the SearchNode to an output stream.
|
|
||||||
*
|
|
||||||
* @param os The output stream.
|
|
||||||
* @param node The SearchNode to be printed.
|
|
||||||
* @return The output stream.
|
|
||||||
*/
|
|
||||||
friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) {
|
friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) {
|
||||||
os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")";
|
os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")";
|
||||||
return os;
|
return os;
|
||||||
|
@ -150,13 +134,18 @@ class Solutions {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// @brief Get the factor associated with a node, possibly product of factors.
|
||||||
|
template <typename NodeType>
|
||||||
|
static auto getFactor(const NodeType& node) {
|
||||||
|
const auto& factors = node->factors;
|
||||||
|
return factors.size() == 1 ? factors.back()
|
||||||
|
: DiscreteFactorGraph(factors).product();
|
||||||
|
}
|
||||||
|
|
||||||
DiscreteSearch::DiscreteSearch(const DiscreteEliminationTree& etree) {
|
DiscreteSearch::DiscreteSearch(const DiscreteEliminationTree& etree) {
|
||||||
using NodePtr = std::shared_ptr<DiscreteEliminationTree::Node>;
|
using NodePtr = std::shared_ptr<DiscreteEliminationTree::Node>;
|
||||||
auto visitor = [this](const NodePtr& node, int data) {
|
auto visitor = [this](const NodePtr& node, int data) {
|
||||||
const auto& factors = node->factors;
|
const auto factor = getFactor(node);
|
||||||
const auto factor = factors.size() == 1
|
|
||||||
? factors.back()
|
|
||||||
: DiscreteFactorGraph(factors).product();
|
|
||||||
const size_t cardinality = factor->cardinality(node->key);
|
const size_t cardinality = factor->cardinality(node->key);
|
||||||
std::vector<std::pair<Key, size_t>> pairs{{node->key, cardinality}};
|
std::vector<std::pair<Key, size_t>> pairs{{node->key, cardinality}};
|
||||||
const Slot slot{factor, DiscreteValues::CartesianProduct(pairs), 0.0};
|
const Slot slot{factor, DiscreteValues::CartesianProduct(pairs), 0.0};
|
||||||
|
@ -164,19 +153,15 @@ DiscreteSearch::DiscreteSearch(const DiscreteEliminationTree& etree) {
|
||||||
return data + 1;
|
return data + 1;
|
||||||
};
|
};
|
||||||
|
|
||||||
const int data = 0; // unused
|
int data = 0; // unused
|
||||||
treeTraversal::DepthFirstForest(etree, data, visitor);
|
treeTraversal::DepthFirstForest(etree, data, visitor);
|
||||||
std::reverse(slots_.begin(), slots_.end()); // reverse slots
|
|
||||||
lowerBound_ = computeHeuristic();
|
lowerBound_ = computeHeuristic();
|
||||||
}
|
}
|
||||||
|
|
||||||
DiscreteSearch::DiscreteSearch(const DiscreteJunctionTree& junctionTree) {
|
DiscreteSearch::DiscreteSearch(const DiscreteJunctionTree& junctionTree) {
|
||||||
using NodePtr = std::shared_ptr<DiscreteJunctionTree::Cluster>;
|
using NodePtr = std::shared_ptr<DiscreteJunctionTree::Cluster>;
|
||||||
auto visitor = [this](const NodePtr& cluster, int data) {
|
auto visitor = [this](const NodePtr& cluster, int data) {
|
||||||
const auto& factors = cluster->factors;
|
const auto factor = getFactor(cluster);
|
||||||
const auto factor = factors.size() == 1
|
|
||||||
? factors.back()
|
|
||||||
: DiscreteFactorGraph(factors).product();
|
|
||||||
std::vector<std::pair<Key, size_t>> pairs;
|
std::vector<std::pair<Key, size_t>> pairs;
|
||||||
for (Key key : cluster->orderedFrontalKeys) {
|
for (Key key : cluster->orderedFrontalKeys) {
|
||||||
pairs.emplace_back(key, factor->cardinality(key));
|
pairs.emplace_back(key, factor->cardinality(key));
|
||||||
|
@ -186,9 +171,8 @@ DiscreteSearch::DiscreteSearch(const DiscreteJunctionTree& junctionTree) {
|
||||||
return data + 1;
|
return data + 1;
|
||||||
};
|
};
|
||||||
|
|
||||||
const int data = 0; // unused
|
int data = 0; // unused
|
||||||
treeTraversal::DepthFirstForest(junctionTree, data, visitor);
|
treeTraversal::DepthFirstForest(junctionTree, data, visitor);
|
||||||
std::reverse(slots_.begin(), slots_.end()); // reverse slots
|
|
||||||
lowerBound_ = computeHeuristic();
|
lowerBound_ = computeHeuristic();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -210,21 +194,21 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) {
|
||||||
const Slot slot{conditional, conditional->frontalAssignments(), 0.0};
|
const Slot slot{conditional, conditional->frontalAssignments(), 0.0};
|
||||||
slots_.emplace_back(std::move(slot));
|
slots_.emplace_back(std::move(slot));
|
||||||
}
|
}
|
||||||
|
std::reverse(slots_.begin(), slots_.end());
|
||||||
lowerBound_ = computeHeuristic();
|
lowerBound_ = computeHeuristic();
|
||||||
}
|
}
|
||||||
|
|
||||||
DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) {
|
DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) {
|
||||||
std::function<void(const DiscreteBayesTree::sharedClique&)>
|
using NodePtr = DiscreteBayesTree::sharedClique;
|
||||||
collectConditionals = [&](const auto& clique) {
|
auto visitor = [this](const NodePtr& clique, int data) {
|
||||||
if (!clique) return;
|
auto conditional = clique->conditional();
|
||||||
for (const auto& child : clique->children) collectConditionals(child);
|
const Slot slot{conditional, conditional->frontalAssignments(), 0.0};
|
||||||
auto conditional = clique->conditional();
|
slots_.emplace_back(std::move(slot));
|
||||||
const Slot slot{conditional, conditional->frontalAssignments(), 0.0};
|
return data + 1;
|
||||||
slots_.emplace_back(std::move(slot));
|
};
|
||||||
};
|
|
||||||
|
|
||||||
slots_.reserve(bayesTree.size());
|
int data = 0; // unused
|
||||||
for (const auto& root : bayesTree.roots()) collectConditionals(root);
|
treeTraversal::DepthFirstForest(bayesTree, data, visitor);
|
||||||
lowerBound_ = computeHeuristic();
|
lowerBound_ = computeHeuristic();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -236,59 +220,48 @@ void DiscreteSearch::print(const std::string& name,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct SearchNodeQueue
|
using SearchNodeQueue = std::priority_queue<SearchNode, std::vector<SearchNode>,
|
||||||
: public std::priority_queue<SearchNode, std::vector<SearchNode>,
|
SearchNode::Compare>;
|
||||||
SearchNode::Compare> {
|
|
||||||
void expandNextNode(const SearchNode& current, const Slot& slot,
|
|
||||||
Solutions* solutions) {
|
|
||||||
// If we already have K solutions, prune if we cannot beat the worst one.
|
|
||||||
if (solutions->prune(current.bound)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if we have a complete assignment
|
|
||||||
if (current.isComplete()) {
|
|
||||||
solutions->maybeAdd(current.error, current.assignment);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto& fa : slot.assignments) {
|
|
||||||
auto childNode = current.expand(slot, fa);
|
|
||||||
|
|
||||||
// Again, prune if we cannot beat the worst solution
|
|
||||||
if (!solutions->prune(childNode.bound)) {
|
|
||||||
emplace(childNode);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
std::vector<Solution> DiscreteSearch::run(size_t K) const {
|
std::vector<Solution> DiscreteSearch::run(size_t K) const {
|
||||||
|
if (slots_.empty()) {
|
||||||
|
return {Solution(0.0, DiscreteValues())};
|
||||||
|
}
|
||||||
|
|
||||||
Solutions solutions(K);
|
Solutions solutions(K);
|
||||||
SearchNodeQueue expansions;
|
SearchNodeQueue expansions;
|
||||||
expansions.push(SearchNode::Root(slots_.size(), lowerBound_));
|
expansions.push(SearchNode::Root(slots_.size(), lowerBound_));
|
||||||
|
|
||||||
#ifdef DISCRETE_SEARCH_DEBUG
|
|
||||||
size_t numExpansions = 0;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// Perform the search
|
// Perform the search
|
||||||
while (!expansions.empty()) {
|
while (!expansions.empty()) {
|
||||||
// Pop the partial assignment with the smallest bound
|
// Pop the partial assignment with the smallest bound
|
||||||
SearchNode current = expansions.top();
|
SearchNode current = expansions.top();
|
||||||
expansions.pop();
|
expansions.pop();
|
||||||
|
|
||||||
// Get the next slot to expand
|
// If we already have K solutions, prune if we cannot beat the worst one.
|
||||||
const auto& slot = slots_[current.nextConditional];
|
if (solutions.prune(current.bound)) {
|
||||||
expansions.expandNextNode(current, slot, &solutions);
|
continue;
|
||||||
#ifdef DISCRETE_SEARCH_DEBUG
|
}
|
||||||
++numExpansions;
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
#ifdef DISCRETE_SEARCH_DEBUG
|
// Check if we have a complete assignment
|
||||||
std::cout << "Number of expansions: " << numExpansions << std::endl;
|
if (current.isComplete()) {
|
||||||
#endif
|
solutions.maybeAdd(current.error, current.assignment);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the next slot to expand
|
||||||
|
const auto& slot = slots_[*current.next];
|
||||||
|
std::optional<size_t> nextSlot = *current.next + 1;
|
||||||
|
if (nextSlot == slots_.size()) nextSlot.reset();
|
||||||
|
for (auto& fa : slot.assignments) {
|
||||||
|
auto childNode = current.expand(fa, slot, nextSlot);
|
||||||
|
|
||||||
|
// Again, prune if we cannot beat the worst solution
|
||||||
|
if (!solutions.prune(childNode.bound)) {
|
||||||
|
expansions.emplace(childNode);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Extract solutions from bestSolutions in ascending order of error
|
// Extract solutions from bestSolutions in ascending order of error
|
||||||
return solutions.extractSolutions();
|
return solutions.extractSolutions();
|
||||||
|
@ -296,17 +269,16 @@ std::vector<Solution> DiscreteSearch::run(size_t K) const {
|
||||||
|
|
||||||
// We have a number of factors, each with a max value, and we want to compute
|
// We have a number of factors, each with a max value, and we want to compute
|
||||||
// a lower-bound on the cost-to-go for each slot, *not* including this factor.
|
// a lower-bound on the cost-to-go for each slot, *not* including this factor.
|
||||||
// For the first slot, this is 0.0, as this is the last slot to be filled, so
|
// For the last slot, this is 0.0, as the cost after that is zero.
|
||||||
// the cost after that is zero. For the second slot, it is h0 =
|
// For the second-to-last slot, it is -log(max(factor[0])), because after we
|
||||||
// -log(max(factor[0])), because after we assign slot[1] we still need to
|
// assign slot[1] we still need to assign slot[0], which will cost *at least*
|
||||||
// assign slot[0], which will cost *at least* h0. We return the estimated
|
// h0. We return the estimated lower bound of the cost for *all* slots.
|
||||||
// lower bound of the cost for *all* slots.
|
|
||||||
double DiscreteSearch::computeHeuristic() {
|
double DiscreteSearch::computeHeuristic() {
|
||||||
double error = 0.0;
|
double error = 0.0;
|
||||||
for (auto& slot : slots_) {
|
for (auto it = slots_.rbegin(); it != slots_.rend(); ++it) {
|
||||||
slot.heuristic = error;
|
it->heuristic = error;
|
||||||
Ordering ordering(slot.factor->begin(), slot.factor->end());
|
Ordering ordering(it->factor->begin(), it->factor->end());
|
||||||
auto maxx = slot.factor->max(ordering);
|
auto maxx = it->factor->max(ordering);
|
||||||
error -= std::log(maxx->evaluate({}));
|
error -= std::log(maxx->evaluate({}));
|
||||||
}
|
}
|
||||||
return error;
|
return error;
|
||||||
|
|
|
@ -125,6 +125,12 @@ class GTSAM_EXPORT DiscreteSearch {
|
||||||
/// @name Standard API
|
/// @name Standard API
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
/// Return lower bound on the cost-to-go for the entire search
|
||||||
|
double lowerBound() const { return lowerBound_; }
|
||||||
|
|
||||||
|
/// Read access to the slots
|
||||||
|
const std::vector<Slot>& slots() const { return slots_; }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Search for the K best solutions.
|
* @brief Search for the K best solutions.
|
||||||
*
|
*
|
||||||
|
|
|
@ -77,6 +77,17 @@ TEST(DiscreteBayesNet, AsiaKBest) {
|
||||||
// Ask for the MPE
|
// Ask for the MPE
|
||||||
auto mpe = search.run();
|
auto mpe = search.run();
|
||||||
|
|
||||||
|
// Regression on error lower bound
|
||||||
|
EXPECT_DOUBLES_EQUAL(1.205536, search.lowerBound(), 1e-5);
|
||||||
|
|
||||||
|
// Check that the cost-to-go heuristic decreases from there
|
||||||
|
auto slots = search.slots();
|
||||||
|
double previousHeuristic = search.lowerBound();
|
||||||
|
for (auto&& slot : slots) {
|
||||||
|
EXPECT(slot.heuristic <= previousHeuristic);
|
||||||
|
previousHeuristic = slot.heuristic;
|
||||||
|
}
|
||||||
|
|
||||||
EXPECT_LONGS_EQUAL(1, mpe.size());
|
EXPECT_LONGS_EQUAL(1, mpe.size());
|
||||||
// Regression test: check the MPE solution
|
// Regression test: check the MPE solution
|
||||||
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5);
|
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5);
|
||||||
|
|
Loading…
Reference in New Issue