first converged loopy belief test
parent
9ad033fc45
commit
939f694b33
|
@ -32,12 +32,25 @@ class LoopyBelief {
|
||||||
typedef std::map<Key, size_t> CorrectedBeliefIndices;
|
typedef std::map<Key, size_t> CorrectedBeliefIndices;
|
||||||
struct StarGraph {
|
struct StarGraph {
|
||||||
DiscreteFactorGraph::shared_ptr star;
|
DiscreteFactorGraph::shared_ptr star;
|
||||||
DecisionTreeFactor::shared_ptr unary;
|
|
||||||
CorrectedBeliefIndices correctedBeliefIndices;
|
CorrectedBeliefIndices correctedBeliefIndices;
|
||||||
|
DecisionTreeFactor::shared_ptr unary;
|
||||||
|
VariableIndex varIndex_;
|
||||||
StarGraph(const DiscreteFactorGraph::shared_ptr& _star,
|
StarGraph(const DiscreteFactorGraph::shared_ptr& _star,
|
||||||
const DecisionTreeFactor::shared_ptr& _unary,
|
const CorrectedBeliefIndices& _beliefIndices,
|
||||||
const CorrectedBeliefIndices& _beliefIndices) :
|
const DecisionTreeFactor::shared_ptr& _unary) :
|
||||||
star(_star), unary(_unary), correctedBeliefIndices(_beliefIndices) {
|
star(_star), correctedBeliefIndices(_beliefIndices), unary(_unary), varIndex_(
|
||||||
|
*_star) {
|
||||||
|
}
|
||||||
|
|
||||||
|
void print(const std::string& s = "") const {
|
||||||
|
cout << s << ":" << endl;
|
||||||
|
star->print("Star graph: ");
|
||||||
|
BOOST_FOREACH(Key key, correctedBeliefIndices | boost::adaptors::map_keys) {
|
||||||
|
cout << "Belief factor index for " << key << ": "
|
||||||
|
<< correctedBeliefIndices.at(key) << endl;
|
||||||
|
}
|
||||||
|
if (unary)
|
||||||
|
unary->print("Unary: ");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -51,48 +64,96 @@ public:
|
||||||
*/
|
*/
|
||||||
LoopyBelief(const DiscreteFactorGraph& graph,
|
LoopyBelief(const DiscreteFactorGraph& graph,
|
||||||
const std::map<Key, DiscreteKey>& allDiscreteKeys) :
|
const std::map<Key, DiscreteKey>& allDiscreteKeys) :
|
||||||
starGraphs_(buildStarGraphs(graph, allDiscreteKeys)) {
|
starGraphs_(buildStarGraphs(graph, allDiscreteKeys)) {
|
||||||
|
}
|
||||||
|
|
||||||
|
/// print
|
||||||
|
void print(const std::string& s = "") const {
|
||||||
|
cout << s << ":" << endl;
|
||||||
|
BOOST_FOREACH(Key key, starGraphs_ | boost::adaptors::map_keys) {
|
||||||
|
starGraphs_.at(key).print((boost::format("Node %d:") % key).str());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// One step of belief propagation
|
/// One step of belief propagation
|
||||||
DiscreteFactorGraph::shared_ptr iterate() {
|
DiscreteFactorGraph::shared_ptr iterate(
|
||||||
|
const std::map<Key, DiscreteKey>& allDiscreteKeys) {
|
||||||
static DiscreteConditional::shared_ptr dummyCond; // unused by-product of elimination
|
static DiscreteConditional::shared_ptr dummyCond; // unused by-product of elimination
|
||||||
DiscreteFactorGraph::shared_ptr beliefs(new DiscreteFactorGraph());
|
DiscreteFactorGraph::shared_ptr beliefs(new DiscreteFactorGraph());
|
||||||
|
std::map<Key, std::map<Key, DiscreteFactor::shared_ptr> > allMessages;
|
||||||
// Eliminate each star graph
|
// Eliminate each star graph
|
||||||
BOOST_FOREACH(Key key, starGraphs_ | boost::adaptors::map_keys) {
|
BOOST_FOREACH(Key key, starGraphs_ | boost::adaptors::map_keys) {
|
||||||
|
cout << "***** Node " << key << "*****" << endl;
|
||||||
// initialize belief to the unary factor from the original graph
|
// initialize belief to the unary factor from the original graph
|
||||||
DecisionTreeFactor beliefAtKey = *starGraphs_.at(key).unary;
|
DecisionTreeFactor::shared_ptr beliefAtKey;
|
||||||
|
|
||||||
// keep intermediate messages to divide later
|
// keep intermediate messages to divide later
|
||||||
std::map<Key, DiscreteFactor::shared_ptr> messages;
|
std::map<Key, DiscreteFactor::shared_ptr> messages;
|
||||||
|
|
||||||
// eliminate each neighbor in this star graph one by one
|
// eliminate each neighbor in this star graph one by one
|
||||||
BOOST_FOREACH(Key neighbor, starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) {
|
BOOST_FOREACH(Key neighbor, starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) {
|
||||||
DiscreteFactor::shared_ptr factor;
|
DiscreteFactorGraph subGraph;
|
||||||
boost::tie(dummyCond, factor) = EliminateDiscrete(
|
BOOST_FOREACH(size_t factor, starGraphs_.at(key).varIndex_[neighbor]) {
|
||||||
*starGraphs_.at(key).star, Ordering(list_of(neighbor)));
|
subGraph.push_back(starGraphs_.at(key).star->at(factor));
|
||||||
|
}
|
||||||
|
subGraph.print("------- Subgraph:");
|
||||||
|
DiscreteFactor::shared_ptr message;
|
||||||
|
boost::tie(dummyCond, message) = EliminateDiscrete(subGraph,
|
||||||
|
Ordering(list_of(neighbor)));
|
||||||
// store the new factor into messages
|
// store the new factor into messages
|
||||||
messages.insert(make_pair(neighbor, factor));
|
messages.insert(make_pair(neighbor, message));
|
||||||
|
message->print("------- Message: ");
|
||||||
|
|
||||||
// Belief is the product of all messages and the unary factor
|
// Belief is the product of all messages and the unary factor
|
||||||
// Incorporate new the factor to belief
|
// Incorporate new the factor to belief
|
||||||
beliefAtKey = beliefAtKey
|
if (!beliefAtKey)
|
||||||
* (*boost::dynamic_pointer_cast<DecisionTreeFactor>(factor));
|
beliefAtKey = boost::dynamic_pointer_cast<DecisionTreeFactor>(
|
||||||
|
message);
|
||||||
|
else
|
||||||
|
beliefAtKey =
|
||||||
|
make_shared<DecisionTreeFactor>(
|
||||||
|
(*beliefAtKey)
|
||||||
|
* (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
|
||||||
|
message)));
|
||||||
}
|
}
|
||||||
|
if (starGraphs_.at(key).unary)
|
||||||
|
beliefAtKey = make_shared<DecisionTreeFactor>(
|
||||||
|
(*beliefAtKey) * (*starGraphs_.at(key).unary));
|
||||||
|
beliefAtKey->print("New belief at key: ");
|
||||||
|
// normalize belief
|
||||||
|
double sum = 0.0;
|
||||||
|
for (size_t v = 0; v<allDiscreteKeys.at(key).second; ++v) {
|
||||||
|
DiscreteFactor::Values val;
|
||||||
|
val[key] = v;
|
||||||
|
sum += (*beliefAtKey)(val);
|
||||||
|
}
|
||||||
|
DecisionTreeFactor denomFactor(allDiscreteKeys.at(key), (boost::format("%f %f")%sum%sum).str());
|
||||||
|
denomFactor.print("denomFactor: ");
|
||||||
|
beliefAtKey = make_shared<DecisionTreeFactor>((*beliefAtKey)/denomFactor);
|
||||||
|
beliefAtKey->print("New belief at key normalized: ");
|
||||||
beliefs->push_back(beliefAtKey);
|
beliefs->push_back(beliefAtKey);
|
||||||
|
allMessages[key] = messages;
|
||||||
|
}
|
||||||
|
|
||||||
// Update the corrected belief for the neighbor's stargraph
|
// Update corrected beliefs
|
||||||
|
VariableIndex beliefFactors(*beliefs);
|
||||||
|
BOOST_FOREACH(Key key, starGraphs_ | boost::adaptors::map_keys) {
|
||||||
|
std::map<Key, DiscreteFactor::shared_ptr> messages = allMessages[key];
|
||||||
BOOST_FOREACH(Key neighbor, starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) {
|
BOOST_FOREACH(Key neighbor, starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) {
|
||||||
DecisionTreeFactor correctedBelief = beliefAtKey
|
DecisionTreeFactor
|
||||||
/ (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
|
correctedBelief = (*boost::dynamic_pointer_cast<DecisionTreeFactor>(beliefs->at(beliefFactors[key].front())))
|
||||||
messages.at(neighbor)));
|
/ (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
|
||||||
size_t beliefIndex = starGraphs_.at(neighbor).correctedBeliefIndices.at(
|
messages.at(neighbor)));
|
||||||
key);
|
correctedBelief.print("correctedBelief: ");
|
||||||
|
size_t beliefIndex =
|
||||||
|
starGraphs_.at(neighbor).correctedBeliefIndices.at(key);
|
||||||
starGraphs_.at(neighbor).star->replace(beliefIndex,
|
starGraphs_.at(neighbor).star->replace(beliefIndex,
|
||||||
boost::make_shared<DecisionTreeFactor>(correctedBelief));
|
boost::make_shared<DecisionTreeFactor>(correctedBelief));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
print("After update: ");
|
||||||
|
|
||||||
return beliefs;
|
return beliefs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -106,7 +167,7 @@ private:
|
||||||
VariableIndex varIndex(graph); ///< access to all factors of each node
|
VariableIndex varIndex(graph); ///< access to all factors of each node
|
||||||
BOOST_FOREACH(Key key, varIndex | boost::adaptors::map_keys) {
|
BOOST_FOREACH(Key key, varIndex | boost::adaptors::map_keys) {
|
||||||
// initialize to multiply with other unary factors later
|
// initialize to multiply with other unary factors later
|
||||||
DecisionTreeFactor prodOfUnaries(allDiscreteKeys.at(key), "1 1");
|
DecisionTreeFactor::shared_ptr prodOfUnaries;
|
||||||
|
|
||||||
// collect all factors involving this key in the original graph
|
// collect all factors involving this key in the original graph
|
||||||
DiscreteFactorGraph::shared_ptr star(new DiscreteFactorGraph());
|
DiscreteFactorGraph::shared_ptr star(new DiscreteFactorGraph());
|
||||||
|
@ -115,9 +176,14 @@ private:
|
||||||
|
|
||||||
// accumulate unary factors
|
// accumulate unary factors
|
||||||
if (graph.at(factorIdx)->size() == 1) {
|
if (graph.at(factorIdx)->size() == 1) {
|
||||||
prodOfUnaries = prodOfUnaries
|
if (!prodOfUnaries)
|
||||||
* (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
|
prodOfUnaries = boost::dynamic_pointer_cast<DecisionTreeFactor>(
|
||||||
graph.at(factorIdx)));
|
graph.at(factorIdx));
|
||||||
|
else
|
||||||
|
prodOfUnaries = make_shared<DecisionTreeFactor>(
|
||||||
|
*prodOfUnaries
|
||||||
|
* (*boost::dynamic_pointer_cast<DecisionTreeFactor>(
|
||||||
|
graph.at(factorIdx))));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -129,13 +195,13 @@ private:
|
||||||
BOOST_FOREACH(Key neighbor, neighbors) {
|
BOOST_FOREACH(Key neighbor, neighbors) {
|
||||||
// TODO: default table for keys with more than 2 values?
|
// TODO: default table for keys with more than 2 values?
|
||||||
star->push_back(
|
star->push_back(
|
||||||
DecisionTreeFactor(allDiscreteKeys.at(neighbor), "1.0 0.0"));
|
DecisionTreeFactor(allDiscreteKeys.at(neighbor), "0.0 1.0"));
|
||||||
correctedBeliefIndices.insert(make_pair(neighbor, star->size() - 1));
|
correctedBeliefIndices.insert(
|
||||||
|
make_pair(neighbor, star->size() - 1));
|
||||||
}
|
}
|
||||||
starGraphs.insert(
|
starGraphs.insert(
|
||||||
make_pair(key,
|
make_pair(key,
|
||||||
StarGraph(star, make_shared<DecisionTreeFactor>(prodOfUnaries),
|
StarGraph(star, correctedBeliefIndices, prodOfUnaries)));
|
||||||
correctedBeliefIndices)));
|
|
||||||
}
|
}
|
||||||
return starGraphs;
|
return starGraphs;
|
||||||
}
|
}
|
||||||
|
@ -155,21 +221,24 @@ TEST_UNSAFE(LoopyBelief, construction) {
|
||||||
DecisionTreeFactor pC(C, "0.5 0.5");
|
DecisionTreeFactor pC(C, "0.5 0.5");
|
||||||
DiscreteConditional pSC(S | C = "0.5/0.5 0.9/0.1");
|
DiscreteConditional pSC(S | C = "0.5/0.5 0.9/0.1");
|
||||||
DiscreteConditional pRC(R | C = "0.8/0.2 0.2/0.8");
|
DiscreteConditional pRC(R | C = "0.8/0.2 0.2/0.8");
|
||||||
DiscreteConditional pWSR((W | S, R) = "1.0/0.0 0.1/0.9 0.1/0.9 0.01/0.99");
|
DecisionTreeFactor pSR( S & R, "0.0 0.9 0.9 0.99");
|
||||||
|
|
||||||
DiscreteFactorGraph graph;
|
DiscreteFactorGraph graph;
|
||||||
graph.push_back(pC);
|
graph.push_back(pC);
|
||||||
graph.push_back(pSC);
|
graph.push_back(pSC);
|
||||||
graph.push_back(pRC);
|
graph.push_back(pRC);
|
||||||
graph.push_back(pWSR);
|
graph.push_back(pSR);
|
||||||
|
|
||||||
graph.print("graph: ");
|
graph.print("graph: ");
|
||||||
|
|
||||||
LoopyBelief solver(graph, allKeys);
|
LoopyBelief solver(graph, allKeys);
|
||||||
|
solver.print("Loopy belief: ");
|
||||||
|
|
||||||
// Main loop
|
// Main loop
|
||||||
for (size_t iter = 0; iter < 10; ++iter) {
|
for (size_t iter = 0; iter < 20; ++iter) {
|
||||||
DiscreteFactorGraph::shared_ptr beliefs = solver.iterate();
|
cout << "==================================" << endl;
|
||||||
cout << "iteration: " << iter << endl;
|
cout << "iteration: " << iter << endl;
|
||||||
|
DiscreteFactorGraph::shared_ptr beliefs = solver.iterate(allKeys);
|
||||||
beliefs->print();
|
beliefs->print();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue