Refactor to slots
parent
35e7acbf16
commit
d8ed60aead
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
using Slot = DiscreteSearch::Slot;
|
||||||
using Solution = DiscreteSearch::Solution;
|
using Solution = DiscreteSearch::Solution;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -59,12 +60,12 @@ struct SearchNode {
|
||||||
/**
|
/**
|
||||||
* @brief Expands the node by assigning the next variable.
|
* @brief Expands the node by assigning the next variable.
|
||||||
*
|
*
|
||||||
* @param conditional The discrete conditional representing the next variable
|
* @param factor The discrete factor associated with the next variable
|
||||||
* to be assigned.
|
* to be assigned.
|
||||||
* @param fa The frontal assignment for the next variable.
|
* @param fa The frontal assignment for the next variable.
|
||||||
* @return A new SearchNode representing the expanded state.
|
* @return A new SearchNode representing the expanded state.
|
||||||
*/
|
*/
|
||||||
SearchNode expand(const DiscreteConditional& conditional,
|
SearchNode expand(const DiscreteFactor& factor,
|
||||||
const DiscreteValues& fa) const {
|
const DiscreteValues& fa) const {
|
||||||
// Combine the new frontal assignment with the current partial assignment
|
// Combine the new frontal assignment with the current partial assignment
|
||||||
DiscreteValues newAssignment = assignment;
|
DiscreteValues newAssignment = assignment;
|
||||||
|
@ -72,7 +73,7 @@ struct SearchNode {
|
||||||
newAssignment[key] = value;
|
newAssignment[key] = value;
|
||||||
}
|
}
|
||||||
|
|
||||||
return {newAssignment, error + conditional.error(newAssignment), 0.0,
|
return {newAssignment, error + factor.error(newAssignment), 0.0,
|
||||||
nextConditional - 1};
|
nextConditional - 1};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -150,10 +151,20 @@ class Solutions {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
DiscreteSearch::DiscreteSearch(const DiscreteFactorGraph& factorGraph) {
|
||||||
|
slots_.reserve(factorGraph.size());
|
||||||
|
for (auto& factor : factorGraph) {
|
||||||
|
slots_.emplace_back(factor, std::vector<DiscreteValues>{}, 0.0);
|
||||||
|
}
|
||||||
|
computeHeuristic();
|
||||||
|
}
|
||||||
|
|
||||||
DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) {
|
DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) {
|
||||||
std::vector<DiscreteConditional::shared_ptr> conditionals;
|
slots_.reserve(bayesNet.size());
|
||||||
for (auto& factor : bayesNet) conditionals_.push_back(factor);
|
for (auto& conditional : bayesNet) {
|
||||||
costToGo_ = computeCostToGo(conditionals_);
|
slots_.emplace_back(conditional, conditional->frontalAssignments(), 0.0);
|
||||||
|
}
|
||||||
|
computeHeuristic();
|
||||||
}
|
}
|
||||||
|
|
||||||
DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) {
|
DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) {
|
||||||
|
@ -161,22 +172,21 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) {
|
||||||
collectConditionals = [&](const auto& clique) {
|
collectConditionals = [&](const auto& clique) {
|
||||||
if (!clique) return;
|
if (!clique) return;
|
||||||
for (const auto& child : clique->children) collectConditionals(child);
|
for (const auto& child : clique->children) collectConditionals(child);
|
||||||
conditionals_.push_back(clique->conditional());
|
auto conditional = clique->conditional();
|
||||||
|
slots_.emplace_back(conditional, conditional->frontalAssignments(),
|
||||||
|
0.0);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
slots_.reserve(bayesTree.size());
|
||||||
for (const auto& root : bayesTree.roots()) collectConditionals(root);
|
for (const auto& root : bayesTree.roots()) collectConditionals(root);
|
||||||
costToGo_ = computeCostToGo(conditionals_);
|
computeHeuristic();
|
||||||
}
|
}
|
||||||
|
|
||||||
struct SearchNodeQueue
|
struct SearchNodeQueue
|
||||||
: public std::priority_queue<SearchNode, std::vector<SearchNode>,
|
: public std::priority_queue<SearchNode, std::vector<SearchNode>,
|
||||||
SearchNode::Compare> {
|
SearchNode::Compare> {
|
||||||
void expandNextNode(
|
void expandNextNode(const SearchNode& current, const Slot& slot,
|
||||||
const std::vector<DiscreteConditional::shared_ptr>& conditionals,
|
Solutions* solutions) {
|
||||||
const std::vector<double>& costToGo, Solutions* solutions) {
|
|
||||||
// Pop the partial assignment with the smallest bound
|
|
||||||
SearchNode current = top();
|
|
||||||
pop();
|
|
||||||
|
|
||||||
// If we already have K solutions, prune if we cannot beat the worst one.
|
// If we already have K solutions, prune if we cannot beat the worst one.
|
||||||
if (solutions->prune(current.bound)) {
|
if (solutions->prune(current.bound)) {
|
||||||
return;
|
return;
|
||||||
|
@ -188,13 +198,11 @@ struct SearchNodeQueue
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Expand on the next factor
|
for (auto& fa : slot.assignments) {
|
||||||
const auto& conditional = conditionals[current.nextConditional];
|
auto childNode = current.expand(*slot.factor, fa);
|
||||||
|
|
||||||
for (auto& fa : conditional->frontalAssignments()) {
|
|
||||||
auto childNode = current.expand(*conditional, fa);
|
|
||||||
if (childNode.nextConditional >= 0)
|
if (childNode.nextConditional >= 0)
|
||||||
childNode.bound = childNode.error + costToGo[childNode.nextConditional];
|
// TODO(frank): this might be wrong !
|
||||||
|
childNode.bound = childNode.error + slot.heuristic;
|
||||||
|
|
||||||
// Again, prune if we cannot beat the worst solution
|
// Again, prune if we cannot beat the worst solution
|
||||||
if (!solutions->prune(childNode.bound)) {
|
if (!solutions->prune(childNode.bound)) {
|
||||||
|
@ -207,8 +215,7 @@ struct SearchNodeQueue
|
||||||
std::vector<Solution> DiscreteSearch::run(size_t K) const {
|
std::vector<Solution> DiscreteSearch::run(size_t K) const {
|
||||||
Solutions solutions(K);
|
Solutions solutions(K);
|
||||||
SearchNodeQueue expansions;
|
SearchNodeQueue expansions;
|
||||||
expansions.push(SearchNode::Root(conditionals_.size(),
|
expansions.push(SearchNode::Root(slots_.size(), slots_.back().heuristic));
|
||||||
costToGo_.empty() ? 0.0 : costToGo_.back()));
|
|
||||||
|
|
||||||
#ifdef DISCRETE_SEARCH_DEBUG
|
#ifdef DISCRETE_SEARCH_DEBUG
|
||||||
size_t numExpansions = 0;
|
size_t numExpansions = 0;
|
||||||
|
@ -216,7 +223,13 @@ std::vector<Solution> DiscreteSearch::run(size_t K) const {
|
||||||
|
|
||||||
// Perform the search
|
// Perform the search
|
||||||
while (!expansions.empty()) {
|
while (!expansions.empty()) {
|
||||||
expansions.expandNextNode(conditionals_, costToGo_, &solutions);
|
// Pop the partial assignment with the smallest bound
|
||||||
|
SearchNode current = expansions.top();
|
||||||
|
expansions.pop();
|
||||||
|
|
||||||
|
// Get the next slot to expand
|
||||||
|
const auto& slot = slots_[current.nextConditional];
|
||||||
|
expansions.expandNextNode(current, slot, &solutions);
|
||||||
#ifdef DISCRETE_SEARCH_DEBUG
|
#ifdef DISCRETE_SEARCH_DEBUG
|
||||||
++numExpansions;
|
++numExpansions;
|
||||||
#endif
|
#endif
|
||||||
|
@ -230,17 +243,19 @@ std::vector<Solution> DiscreteSearch::run(size_t K) const {
|
||||||
return solutions.extractSolutions();
|
return solutions.extractSolutions();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<double> DiscreteSearch::computeCostToGo(
|
// We have a number of factors, each with a max value, and we want to compute
|
||||||
const std::vector<DiscreteConditional::shared_ptr>& conditionals) {
|
// the a lower-bound on the cost-to-go for each slot. For the first slot, this
|
||||||
std::vector<double> costToGo;
|
// -log(max(factor[0])), as we only have one factor to resolve. For the second
|
||||||
|
// slot, we need to add -log(max(factor[1])) to it, etc...
|
||||||
|
void DiscreteSearch::computeHeuristic() {
|
||||||
double error = 0.0;
|
double error = 0.0;
|
||||||
for (const auto& conditional : conditionals) {
|
for (size_t i = 0; i < slots_.size(); ++i) {
|
||||||
Ordering ordering(conditional->begin(), conditional->end());
|
const auto& factor = slots_[i].factor;
|
||||||
auto maxx = conditional->max(ordering);
|
Ordering ordering(factor->begin(), factor->end());
|
||||||
|
auto maxx = factor->max(ordering);
|
||||||
error -= std::log(maxx->evaluate({}));
|
error -= std::log(maxx->evaluate({}));
|
||||||
costToGo.push_back(error);
|
slots_[i].heuristic = error;
|
||||||
}
|
}
|
||||||
return costToGo;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -28,9 +28,25 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
class GTSAM_EXPORT DiscreteSearch {
|
class GTSAM_EXPORT DiscreteSearch {
|
||||||
public:
|
public:
|
||||||
/**
|
/// We structure the search as a set of slots, each with a factor and
|
||||||
* @brief A solution to a discrete search problem.
|
/// a set of variable assignments that need to be chosen. In addition, each
|
||||||
*/
|
/// slot has a heuristic associated with it.
|
||||||
|
struct Slot {
|
||||||
|
/// The factors in the search problem,
|
||||||
|
/// e.g., [P(B|A),P(A)]
|
||||||
|
DiscreteFactor::shared_ptr factor;
|
||||||
|
|
||||||
|
/// The assignments for each factor,
|
||||||
|
/// e.g., [[B0,B1] [A0,A1]]
|
||||||
|
std::vector<DiscreteValues> assignments;
|
||||||
|
|
||||||
|
/// A lower bound on the cost-to-go for each slot, e.g.,
|
||||||
|
/// [-log(max_B P(B|A)), -log(max_A P(A))]
|
||||||
|
double heuristic;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// A solution is then a set of assignments, covering all the slots.
|
||||||
|
/// as well as an associated error = -log(probability)
|
||||||
struct Solution {
|
struct Solution {
|
||||||
double error;
|
double error;
|
||||||
DiscreteValues assignment;
|
DiscreteValues assignment;
|
||||||
|
@ -42,13 +58,19 @@ class GTSAM_EXPORT DiscreteSearch {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
public:
|
||||||
/**
|
/**
|
||||||
* Construct from a DiscreteBayesNet and K.
|
* Construct from a DiscreteFactorGraph.
|
||||||
|
*/
|
||||||
|
DiscreteSearch(const DiscreteFactorGraph& bayesNet);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Construct from a DiscreteBayesNet.
|
||||||
*/
|
*/
|
||||||
DiscreteSearch(const DiscreteBayesNet& bayesNet);
|
DiscreteSearch(const DiscreteBayesNet& bayesNet);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Construct from a DiscreteBayesTree and K.
|
* Construct from a DiscreteBayesTree.
|
||||||
*/
|
*/
|
||||||
DiscreteSearch(const DiscreteBayesTree& bayesTree);
|
DiscreteSearch(const DiscreteBayesTree& bayesTree);
|
||||||
|
|
||||||
|
@ -65,14 +87,9 @@ class GTSAM_EXPORT DiscreteSearch {
|
||||||
std::vector<Solution> run(size_t K = 1) const;
|
std::vector<Solution> run(size_t K = 1) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/// Compute the cumulative cost-to-go for each conditional slot.
|
/// Compute the cumulative lower-bound cost-to-go for each slot.
|
||||||
static std::vector<double> computeCostToGo(
|
void computeHeuristic();
|
||||||
const std::vector<DiscreteConditional::shared_ptr>& conditionals);
|
|
||||||
|
|
||||||
/// Expand the next node in the search tree.
|
std::vector<Slot> slots_;
|
||||||
void expandNextNode() const;
|
|
||||||
|
|
||||||
std::vector<DiscreteConditional::shared_ptr> conditionals_;
|
|
||||||
std::vector<double> costToGo_;
|
|
||||||
};
|
};
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
Loading…
Reference in New Issue