Made run const, take K
parent
b10ea06626
commit
f174a38eed
|
@ -76,34 +76,84 @@ std::vector<Solution> Solutions::extractSolutions() {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K)
|
DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) {
|
||||||
: solutions_(K) {
|
|
||||||
std::vector<DiscreteConditional::shared_ptr> conditionals;
|
std::vector<DiscreteConditional::shared_ptr> conditionals;
|
||||||
for (auto& factor : bayesNet) conditionals.push_back(factor);
|
for (auto& factor : bayesNet) conditionals_.push_back(factor);
|
||||||
initialize(conditionals);
|
costToGo_ = computeCostToGo(conditionals_);
|
||||||
}
|
}
|
||||||
|
|
||||||
DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K)
|
DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) {
|
||||||
: solutions_(K) {
|
|
||||||
std::vector<DiscreteConditional::shared_ptr> conditionals;
|
|
||||||
std::function<void(const DiscreteBayesTree::sharedClique&)>
|
std::function<void(const DiscreteBayesTree::sharedClique&)>
|
||||||
collectConditionals = [&](const auto& clique) {
|
collectConditionals = [&](const auto& clique) {
|
||||||
if (!clique) return;
|
if (!clique) return;
|
||||||
for (const auto& child : clique->children) collectConditionals(child);
|
for (const auto& child : clique->children) collectConditionals(child);
|
||||||
conditionals.push_back(clique->conditional());
|
conditionals_.push_back(clique->conditional());
|
||||||
};
|
};
|
||||||
for (const auto& root : bayesTree.roots()) collectConditionals(root);
|
for (const auto& root : bayesTree.roots()) collectConditionals(root);
|
||||||
initialize(conditionals);
|
costToGo_ = computeCostToGo(conditionals_);
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<Solution> DiscreteSearch::run() {
|
using SearchNodeQueue = std::priority_queue<SearchNode, std::vector<SearchNode>,
|
||||||
while (!expansions_.empty()) {
|
SearchNode::Compare>;
|
||||||
numExpansions++;
|
|
||||||
|
std::vector<Solution> DiscreteSearch::run(size_t K) const {
|
||||||
|
SearchNodeQueue expansions;
|
||||||
|
expansions.push(SearchNode::Root(conditionals_.size(),
|
||||||
|
costToGo_.empty() ? 0.0 : costToGo_.back()));
|
||||||
|
|
||||||
|
Solutions solutions(K);
|
||||||
|
|
||||||
|
auto expandNextNode = [&] {
|
||||||
|
// Pop the partial assignment with the smallest bound
|
||||||
|
SearchNode current = expansions.top();
|
||||||
|
expansions.pop();
|
||||||
|
|
||||||
|
// If we already have K solutions, prune if we cannot beat the worst
|
||||||
|
// one.
|
||||||
|
if (solutions.prune(current.bound)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if we have a complete assignment
|
||||||
|
if (current.isComplete()) {
|
||||||
|
solutions.maybeAdd(current.error, current.assignment);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expand on the next factor
|
||||||
|
const auto& conditional = conditionals_[current.nextConditional];
|
||||||
|
|
||||||
|
for (auto& fa : conditional->frontalAssignments()) {
|
||||||
|
auto childNode = current.expand(*conditional, fa);
|
||||||
|
if (childNode.nextConditional >= 0)
|
||||||
|
childNode.bound =
|
||||||
|
childNode.error + costToGo_[childNode.nextConditional];
|
||||||
|
|
||||||
|
// Again, prune if we cannot beat the worst solution
|
||||||
|
if (!solutions.prune(childNode.bound)) {
|
||||||
|
expansions.emplace(childNode);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#ifdef DISCRETE_SEARCH_DEBUG
|
||||||
|
size_t numExpansions = 0;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Perform the search
|
||||||
|
while (!expansions.empty()) {
|
||||||
expandNextNode();
|
expandNextNode();
|
||||||
|
#ifdef DISCRETE_SEARCH_DEBUG
|
||||||
|
++numExpansions;
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef DISCRETE_SEARCH_DEBUG
|
||||||
|
std::cout << "Number of expansions: " << numExpansions << std::endl;
|
||||||
|
#endif
|
||||||
|
|
||||||
// Extract solutions from bestSolutions in ascending order of error
|
// Extract solutions from bestSolutions in ascending order of error
|
||||||
return solutions_.extractSolutions();
|
return solutions.extractSolutions();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<double> DiscreteSearch::computeCostToGo(
|
std::vector<double> DiscreteSearch::computeCostToGo(
|
||||||
|
@ -120,35 +170,4 @@ std::vector<double> DiscreteSearch::computeCostToGo(
|
||||||
return costToGo;
|
return costToGo;
|
||||||
}
|
}
|
||||||
|
|
||||||
void DiscreteSearch::expandNextNode() {
|
|
||||||
// Pop the partial assignment with the smallest bound
|
|
||||||
SearchNode current = expansions_.top();
|
|
||||||
expansions_.pop();
|
|
||||||
|
|
||||||
// If we already have K solutions, prune if we cannot beat the worst one.
|
|
||||||
if (solutions_.prune(current.bound)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if we have a complete assignment
|
|
||||||
if (current.isComplete()) {
|
|
||||||
solutions_.maybeAdd(current.error, current.assignment);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Expand on the next factor
|
|
||||||
const auto& conditional = conditionals_[current.nextConditional];
|
|
||||||
|
|
||||||
for (auto& fa : conditional->frontalAssignments()) {
|
|
||||||
auto childNode = current.expand(*conditional, fa);
|
|
||||||
if (childNode.nextConditional >= 0)
|
|
||||||
childNode.bound = childNode.error + costToGo_[childNode.nextConditional];
|
|
||||||
|
|
||||||
// Again, prune if we cannot beat the worst solution
|
|
||||||
if (!solutions_.prune(childNode.bound)) {
|
|
||||||
expansions_.emplace(childNode);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -130,17 +130,15 @@ class Solutions {
|
||||||
*/
|
*/
|
||||||
class DiscreteSearch {
|
class DiscreteSearch {
|
||||||
public:
|
public:
|
||||||
size_t numExpansions = 0;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Construct from a DiscreteBayesNet and K.
|
* Construct from a DiscreteBayesNet and K.
|
||||||
*/
|
*/
|
||||||
DiscreteSearch(const DiscreteBayesNet& bayesNet, size_t K = 1);
|
DiscreteSearch(const DiscreteBayesNet& bayesNet);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Construct from a DiscreteBayesTree and K.
|
* Construct from a DiscreteBayesTree and K.
|
||||||
*/
|
*/
|
||||||
DiscreteSearch(const DiscreteBayesTree& bayesTree, size_t K = 1);
|
DiscreteSearch(const DiscreteBayesTree& bayesTree);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Search for the K best solutions.
|
* @brief Search for the K best solutions.
|
||||||
|
@ -152,29 +150,17 @@ class DiscreteSearch {
|
||||||
*
|
*
|
||||||
* @return A vector of the K best solutions found during the search.
|
* @return A vector of the K best solutions found during the search.
|
||||||
*/
|
*/
|
||||||
std::vector<Solution> run();
|
std::vector<Solution> run(size_t K = 1) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/// Initialize the search with the given conditionals.
|
|
||||||
void initialize(
|
|
||||||
const std::vector<DiscreteConditional::shared_ptr>& conditionals) {
|
|
||||||
conditionals_ = conditionals;
|
|
||||||
costToGo_ = computeCostToGo(conditionals_);
|
|
||||||
expansions_.push(SearchNode::Root(
|
|
||||||
conditionals_.size(), costToGo_.empty() ? 0.0 : costToGo_.back()));
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Compute the cumulative cost-to-go for each conditional slot.
|
/// Compute the cumulative cost-to-go for each conditional slot.
|
||||||
static std::vector<double> computeCostToGo(
|
static std::vector<double> computeCostToGo(
|
||||||
const std::vector<DiscreteConditional::shared_ptr>& conditionals);
|
const std::vector<DiscreteConditional::shared_ptr>& conditionals);
|
||||||
|
|
||||||
/// Expand the next node in the search tree.
|
/// Expand the next node in the search tree.
|
||||||
void expandNextNode();
|
void expandNextNode() const;
|
||||||
|
|
||||||
std::vector<DiscreteConditional::shared_ptr> conditionals_;
|
std::vector<DiscreteConditional::shared_ptr> conditionals_;
|
||||||
std::vector<double> costToGo_;
|
std::vector<double> costToGo_;
|
||||||
std::priority_queue<SearchNode, std::vector<SearchNode>, SearchNode::Compare>
|
|
||||||
expansions_;
|
|
||||||
Solutions solutions_;
|
|
||||||
};
|
};
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -35,8 +35,8 @@ using namespace gtsam;
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscreteBayesNet, EmptyKBest) {
|
TEST(DiscreteBayesNet, EmptyKBest) {
|
||||||
DiscreteBayesNet net; // no factors
|
DiscreteBayesNet net; // no factors
|
||||||
DiscreteSearch search(net, 3);
|
DiscreteSearch search(net);
|
||||||
auto solutions = search.run();
|
auto solutions = search.run(3);
|
||||||
// Expect one solution with empty assignment, error=0
|
// Expect one solution with empty assignment, error=0
|
||||||
EXPECT_LONGS_EQUAL(1, solutions.size());
|
EXPECT_LONGS_EQUAL(1, solutions.size());
|
||||||
EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9);
|
EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9);
|
||||||
|
@ -46,23 +46,17 @@ TEST(DiscreteBayesNet, EmptyKBest) {
|
||||||
TEST(DiscreteBayesNet, AsiaKBest) {
|
TEST(DiscreteBayesNet, AsiaKBest) {
|
||||||
using namespace asia_example;
|
using namespace asia_example;
|
||||||
DiscreteBayesNet asia = createAsiaExample();
|
DiscreteBayesNet asia = createAsiaExample();
|
||||||
|
DiscreteSearch search(asia);
|
||||||
|
|
||||||
// Ask for the MPE
|
// Ask for the MPE
|
||||||
DiscreteSearch search1(asia);
|
auto mpe = search.run();
|
||||||
auto mpe = search1.run();
|
|
||||||
|
|
||||||
// print numExpansions
|
|
||||||
std::cout << "Number of expansions: " << search1.numExpansions << std::endl;
|
|
||||||
|
|
||||||
EXPECT_LONGS_EQUAL(1, mpe.size());
|
EXPECT_LONGS_EQUAL(1, mpe.size());
|
||||||
// Regression test: check the MPE solution
|
// Regression test: check the MPE solution
|
||||||
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5);
|
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5);
|
||||||
|
|
||||||
DiscreteSearch search(asia, 4);
|
// Ask for top 4 solutions
|
||||||
auto solutions = search.run();
|
auto solutions = search.run(4);
|
||||||
|
|
||||||
// print numExpansions
|
|
||||||
std::cout << "Number of expansions: " << search.numExpansions << std::endl;
|
|
||||||
|
|
||||||
EXPECT_LONGS_EQUAL(4, solutions.size());
|
EXPECT_LONGS_EQUAL(4, solutions.size());
|
||||||
// Regression test: check the first and last solution
|
// Regression test: check the first and last solution
|
||||||
|
@ -74,8 +68,8 @@ TEST(DiscreteBayesNet, AsiaKBest) {
|
||||||
TEST(DiscreteBayesTree, EmptyTree) {
|
TEST(DiscreteBayesTree, EmptyTree) {
|
||||||
DiscreteBayesTree bt;
|
DiscreteBayesTree bt;
|
||||||
|
|
||||||
DiscreteSearch search(bt, 3);
|
DiscreteSearch search(bt);
|
||||||
auto solutions = search.run();
|
auto solutions = search.run(3);
|
||||||
|
|
||||||
// We expect exactly 1 solution with error = 0.0 (the empty assignment).
|
// We expect exactly 1 solution with error = 0.0 (the empty assignment).
|
||||||
assert(solutions.size() == 1 && "There should be exactly one empty solution");
|
assert(solutions.size() == 1 && "There should be exactly one empty solution");
|
||||||
|
@ -89,24 +83,17 @@ TEST(DiscreteBayesTree, AsiaTreeKBest) {
|
||||||
DiscreteFactorGraph asia(createAsiaExample());
|
DiscreteFactorGraph asia(createAsiaExample());
|
||||||
const Ordering ordering{D, X, B, E, L, T, S, A};
|
const Ordering ordering{D, X, B, E, L, T, S, A};
|
||||||
DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(ordering);
|
DiscreteBayesTree::shared_ptr bt = asia.eliminateMultifrontal(ordering);
|
||||||
|
DiscreteSearch search(*bt);
|
||||||
|
|
||||||
// Ask for top 4 solutions
|
// Ask for MPE
|
||||||
DiscreteSearch search1(*bt);
|
auto mpe = search.run();
|
||||||
auto mpe = search1.run();
|
|
||||||
|
|
||||||
// print numExpansions
|
|
||||||
std::cout << "Number of expansions: " << search1.numExpansions << std::endl;
|
|
||||||
|
|
||||||
EXPECT_LONGS_EQUAL(1, mpe.size());
|
EXPECT_LONGS_EQUAL(1, mpe.size());
|
||||||
// Regression test: check the MPE solution
|
// Regression test: check the MPE solution
|
||||||
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5);
|
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5);
|
||||||
|
|
||||||
// Ask for top 4 solutions
|
// Ask for top 4 solutions
|
||||||
DiscreteSearch search(*bt, 4);
|
auto solutions = search.run(4);
|
||||||
auto solutions = search.run();
|
|
||||||
|
|
||||||
// print numExpansions
|
|
||||||
std::cout << "Number of expansions: " << search.numExpansions << std::endl;
|
|
||||||
|
|
||||||
EXPECT_LONGS_EQUAL(4, solutions.size());
|
EXPECT_LONGS_EQUAL(4, solutions.size());
|
||||||
// Regression test: check the first and last solution
|
// Regression test: check the first and last solution
|
||||||
|
|
Loading…
Reference in New Issue