first converged loopy belief test

release/4.3a0
Duy-Nguyen Ta 2013-10-12 20:06:02 +00:00
parent 9ad033fc45
commit 939f694b33
1 changed files with 100 additions and 31 deletions

View File

@ -32,12 +32,25 @@ class LoopyBelief {
typedef std::map<Key, size_t> CorrectedBeliefIndices; typedef std::map<Key, size_t> CorrectedBeliefIndices;
struct StarGraph { struct StarGraph {
DiscreteFactorGraph::shared_ptr star; DiscreteFactorGraph::shared_ptr star;
DecisionTreeFactor::shared_ptr unary;
CorrectedBeliefIndices correctedBeliefIndices; CorrectedBeliefIndices correctedBeliefIndices;
DecisionTreeFactor::shared_ptr unary;
VariableIndex varIndex_;
StarGraph(const DiscreteFactorGraph::shared_ptr& _star, StarGraph(const DiscreteFactorGraph::shared_ptr& _star,
const DecisionTreeFactor::shared_ptr& _unary, const CorrectedBeliefIndices& _beliefIndices,
const CorrectedBeliefIndices& _beliefIndices) : const DecisionTreeFactor::shared_ptr& _unary) :
star(_star), unary(_unary), correctedBeliefIndices(_beliefIndices) { star(_star), correctedBeliefIndices(_beliefIndices), unary(_unary), varIndex_(
*_star) {
}
void print(const std::string& s = "") const {
cout << s << ":" << endl;
star->print("Star graph: ");
BOOST_FOREACH(Key key, correctedBeliefIndices | boost::adaptors::map_keys) {
cout << "Belief factor index for " << key << ": "
<< correctedBeliefIndices.at(key) << endl;
}
if (unary)
unary->print("Unary: ");
} }
}; };
@ -51,48 +64,96 @@ public:
*/ */
LoopyBelief(const DiscreteFactorGraph& graph, LoopyBelief(const DiscreteFactorGraph& graph,
const std::map<Key, DiscreteKey>& allDiscreteKeys) : const std::map<Key, DiscreteKey>& allDiscreteKeys) :
starGraphs_(buildStarGraphs(graph, allDiscreteKeys)) { starGraphs_(buildStarGraphs(graph, allDiscreteKeys)) {
}
/// print
void print(const std::string& s = "") const {
cout << s << ":" << endl;
BOOST_FOREACH(Key key, starGraphs_ | boost::adaptors::map_keys) {
starGraphs_.at(key).print((boost::format("Node %d:") % key).str());
}
} }
/// One step of belief propagation /// One step of belief propagation
DiscreteFactorGraph::shared_ptr iterate() { DiscreteFactorGraph::shared_ptr iterate(
const std::map<Key, DiscreteKey>& allDiscreteKeys) {
static DiscreteConditional::shared_ptr dummyCond; // unused by-product of elimination static DiscreteConditional::shared_ptr dummyCond; // unused by-product of elimination
DiscreteFactorGraph::shared_ptr beliefs(new DiscreteFactorGraph()); DiscreteFactorGraph::shared_ptr beliefs(new DiscreteFactorGraph());
std::map<Key, std::map<Key, DiscreteFactor::shared_ptr> > allMessages;
// Eliminate each star graph // Eliminate each star graph
BOOST_FOREACH(Key key, starGraphs_ | boost::adaptors::map_keys) { BOOST_FOREACH(Key key, starGraphs_ | boost::adaptors::map_keys) {
cout << "***** Node " << key << "*****" << endl;
// initialize belief to the unary factor from the original graph // initialize belief to the unary factor from the original graph
DecisionTreeFactor beliefAtKey = *starGraphs_.at(key).unary; DecisionTreeFactor::shared_ptr beliefAtKey;
// keep intermediate messages to divide later // keep intermediate messages to divide later
std::map<Key, DiscreteFactor::shared_ptr> messages; std::map<Key, DiscreteFactor::shared_ptr> messages;
// eliminate each neighbor in this star graph one by one // eliminate each neighbor in this star graph one by one
BOOST_FOREACH(Key neighbor, starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) { BOOST_FOREACH(Key neighbor, starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) {
DiscreteFactor::shared_ptr factor; DiscreteFactorGraph subGraph;
boost::tie(dummyCond, factor) = EliminateDiscrete( BOOST_FOREACH(size_t factor, starGraphs_.at(key).varIndex_[neighbor]) {
*starGraphs_.at(key).star, Ordering(list_of(neighbor))); subGraph.push_back(starGraphs_.at(key).star->at(factor));
}
subGraph.print("------- Subgraph:");
DiscreteFactor::shared_ptr message;
boost::tie(dummyCond, message) = EliminateDiscrete(subGraph,
Ordering(list_of(neighbor)));
// store the new factor into messages // store the new factor into messages
messages.insert(make_pair(neighbor, factor)); messages.insert(make_pair(neighbor, message));
message->print("------- Message: ");
// Belief is the product of all messages and the unary factor // Belief is the product of all messages and the unary factor
// Incorporate new the factor to belief // Incorporate new the factor to belief
beliefAtKey = beliefAtKey if (!beliefAtKey)
* (*boost::dynamic_pointer_cast<DecisionTreeFactor>(factor)); beliefAtKey = boost::dynamic_pointer_cast<DecisionTreeFactor>(
message);
else
beliefAtKey =
make_shared<DecisionTreeFactor>(
(*beliefAtKey)
* (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
message)));
} }
if (starGraphs_.at(key).unary)
beliefAtKey = make_shared<DecisionTreeFactor>(
(*beliefAtKey) * (*starGraphs_.at(key).unary));
beliefAtKey->print("New belief at key: ");
// normalize belief
double sum = 0.0;
for (size_t v = 0; v<allDiscreteKeys.at(key).second; ++v) {
DiscreteFactor::Values val;
val[key] = v;
sum += (*beliefAtKey)(val);
}
DecisionTreeFactor denomFactor(allDiscreteKeys.at(key), (boost::format("%f %f")%sum%sum).str());
denomFactor.print("denomFactor: ");
beliefAtKey = make_shared<DecisionTreeFactor>((*beliefAtKey)/denomFactor);
beliefAtKey->print("New belief at key normalized: ");
beliefs->push_back(beliefAtKey); beliefs->push_back(beliefAtKey);
allMessages[key] = messages;
}
// Update the corrected belief for the neighbor's stargraph // Update corrected beliefs
VariableIndex beliefFactors(*beliefs);
BOOST_FOREACH(Key key, starGraphs_ | boost::adaptors::map_keys) {
std::map<Key, DiscreteFactor::shared_ptr> messages = allMessages[key];
BOOST_FOREACH(Key neighbor, starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) { BOOST_FOREACH(Key neighbor, starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) {
DecisionTreeFactor correctedBelief = beliefAtKey DecisionTreeFactor
/ (*boost::dynamic_pointer_cast<DecisionTreeFactor>( correctedBelief = (*boost::dynamic_pointer_cast<DecisionTreeFactor>(beliefs->at(beliefFactors[key].front())))
messages.at(neighbor))); / (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
size_t beliefIndex = starGraphs_.at(neighbor).correctedBeliefIndices.at( messages.at(neighbor)));
key); correctedBelief.print("correctedBelief: ");
size_t beliefIndex =
starGraphs_.at(neighbor).correctedBeliefIndices.at(key);
starGraphs_.at(neighbor).star->replace(beliefIndex, starGraphs_.at(neighbor).star->replace(beliefIndex,
boost::make_shared<DecisionTreeFactor>(correctedBelief)); boost::make_shared<DecisionTreeFactor>(correctedBelief));
} }
} }
print("After update: ");
return beliefs; return beliefs;
} }
@ -106,7 +167,7 @@ private:
VariableIndex varIndex(graph); ///< access to all factors of each node VariableIndex varIndex(graph); ///< access to all factors of each node
BOOST_FOREACH(Key key, varIndex | boost::adaptors::map_keys) { BOOST_FOREACH(Key key, varIndex | boost::adaptors::map_keys) {
// initialize to multiply with other unary factors later // initialize to multiply with other unary factors later
DecisionTreeFactor prodOfUnaries(allDiscreteKeys.at(key), "1 1"); DecisionTreeFactor::shared_ptr prodOfUnaries;
// collect all factors involving this key in the original graph // collect all factors involving this key in the original graph
DiscreteFactorGraph::shared_ptr star(new DiscreteFactorGraph()); DiscreteFactorGraph::shared_ptr star(new DiscreteFactorGraph());
@ -115,9 +176,14 @@ private:
// accumulate unary factors // accumulate unary factors
if (graph.at(factorIdx)->size() == 1) { if (graph.at(factorIdx)->size() == 1) {
prodOfUnaries = prodOfUnaries if (!prodOfUnaries)
* (*boost::dynamic_pointer_cast<DecisionTreeFactor>( prodOfUnaries = boost::dynamic_pointer_cast<DecisionTreeFactor>(
graph.at(factorIdx))); graph.at(factorIdx));
else
prodOfUnaries = make_shared<DecisionTreeFactor>(
*prodOfUnaries
* (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
graph.at(factorIdx))));
} }
} }
@ -129,13 +195,13 @@ private:
BOOST_FOREACH(Key neighbor, neighbors) { BOOST_FOREACH(Key neighbor, neighbors) {
// TODO: default table for keys with more than 2 values? // TODO: default table for keys with more than 2 values?
star->push_back( star->push_back(
DecisionTreeFactor(allDiscreteKeys.at(neighbor), "1.0 0.0")); DecisionTreeFactor(allDiscreteKeys.at(neighbor), "0.0 1.0"));
correctedBeliefIndices.insert(make_pair(neighbor, star->size() - 1)); correctedBeliefIndices.insert(
make_pair(neighbor, star->size() - 1));
} }
starGraphs.insert( starGraphs.insert(
make_pair(key, make_pair(key,
StarGraph(star, make_shared<DecisionTreeFactor>(prodOfUnaries), StarGraph(star, correctedBeliefIndices, prodOfUnaries)));
correctedBeliefIndices)));
} }
return starGraphs; return starGraphs;
} }
@ -155,21 +221,24 @@ TEST_UNSAFE(LoopyBelief, construction) {
DecisionTreeFactor pC(C, "0.5 0.5"); DecisionTreeFactor pC(C, "0.5 0.5");
DiscreteConditional pSC(S | C = "0.5/0.5 0.9/0.1"); DiscreteConditional pSC(S | C = "0.5/0.5 0.9/0.1");
DiscreteConditional pRC(R | C = "0.8/0.2 0.2/0.8"); DiscreteConditional pRC(R | C = "0.8/0.2 0.2/0.8");
DiscreteConditional pWSR((W | S, R) = "1.0/0.0 0.1/0.9 0.1/0.9 0.01/0.99"); DecisionTreeFactor pSR( S & R, "0.0 0.9 0.9 0.99");
DiscreteFactorGraph graph; DiscreteFactorGraph graph;
graph.push_back(pC); graph.push_back(pC);
graph.push_back(pSC); graph.push_back(pSC);
graph.push_back(pRC); graph.push_back(pRC);
graph.push_back(pWSR); graph.push_back(pSR);
graph.print("graph: "); graph.print("graph: ");
LoopyBelief solver(graph, allKeys); LoopyBelief solver(graph, allKeys);
solver.print("Loopy belief: ");
// Main loop // Main loop
for (size_t iter = 0; iter < 10; ++iter) { for (size_t iter = 0; iter < 20; ++iter) {
DiscreteFactorGraph::shared_ptr beliefs = solver.iterate(); cout << "==================================" << endl;
cout << "iteration: " << iter << endl; cout << "iteration: " << iter << endl;
DiscreteFactorGraph::shared_ptr beliefs = solver.iterate(allKeys);
beliefs->print(); beliefs->print();
} }