implemented Combine to we can create a BayesTree from a DiscreteBayesNet

release/4.3a0
Frank Dellaert 2012-09-15 11:11:57 +00:00
parent b95210a5f0
commit b819b7c446
2 changed files with 49 additions and 11 deletions

View File

@ -22,6 +22,7 @@
#include <gtsam/discrete/Signature.h>
#include <gtsam/inference/IndexConditional.h>
#include <boost/shared_ptr.hpp>
#include <boost/make_shared.hpp>
namespace gtsam {
@ -58,6 +59,16 @@ namespace gtsam {
DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal);
/**
* Combine several conditional into a single one.
* The conditionals must be given in increasing order, meaning that the parents
* of any conditional may not include a conditional coming before it.
* @param firstConditional Iterator to the first conditional to combine, must dereference to a shared_ptr<DiscreteConditional>.
* @param lastConditional Iterator to after the last conditional to combine, must dereference to a shared_ptr<DiscreteConditional>.
* */
template<typename ITERATOR>
static shared_ptr Combine(ITERATOR firstConditional, ITERATOR lastConditional);
/// @}
/// @name Testable
/// @{
@ -123,5 +134,19 @@ namespace gtsam {
};
// DiscreteConditional
/* ************************************************************************* */
template<typename ITERATOR>
DiscreteConditional::shared_ptr DiscreteConditional::Combine(
ITERATOR firstConditional, ITERATOR lastConditional) {
// TODO: check for being a clique
DecisionTreeFactor product;
for(ITERATOR it = firstConditional; it != lastConditional; ++it) {
DiscreteConditional::shared_ptr c = *it;
DecisionTreeFactor::shared_ptr factor = c->toFactor();
product = (*factor) * product;
}
return boost::make_shared<DiscreteConditional>(1,product);
}
}// gtsam

View File

@ -29,7 +29,7 @@ using namespace std;
using namespace gtsam;
/* ************************************************************************* */
TEST( DiscreteConditionalTest, constructors)
TEST( DiscreteConditional, constructors)
{
DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
@ -48,7 +48,7 @@ TEST( DiscreteConditionalTest, constructors)
}
/* ************************************************************************* */
TEST( DiscreteConditionalTest, constructors_alt_interface)
TEST( DiscreteConditional, constructors_alt_interface)
{
DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
@ -71,7 +71,7 @@ TEST( DiscreteConditionalTest, constructors_alt_interface)
}
/* ************************************************************************* */
TEST( DiscreteConditionalTest, constructors2)
TEST( DiscreteConditional, constructors2)
{
// Declare keys and ordering
DiscreteKey C(0,2), B(1,2);
@ -83,15 +83,28 @@ TEST( DiscreteConditionalTest, constructors2)
}
/* ************************************************************************* */
TEST( DiscreteConditionalTest, constructors3)
TEST( DiscreteConditional, constructors3)
{
// Declare keys and ordering
DiscreteKey C(0,2), B(1,2), A(2,2);
DecisionTreeFactor expected(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8");
Signature signature((C | B, A) = "4/1 1/1 1/1 1/4");
DiscreteConditional actual(signature);
DecisionTreeFactor::shared_ptr actualFactor = actual.toFactor();
EXPECT(assert_equal(expected, *actualFactor));
// Declare keys and ordering
DiscreteKey C(0,2), B(1,2), A(2,2);
DecisionTreeFactor expected(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8");
Signature signature((C | B, A) = "4/1 1/1 1/1 1/4");
DiscreteConditional actual(signature);
DecisionTreeFactor::shared_ptr actualFactor = actual.toFactor();
EXPECT(assert_equal(expected, *actualFactor));
}
/* ************************************************************************* */
TEST( DiscreteConditional, Combine)
{
DiscreteKey A(0,2), B(1,2);
vector<DiscreteConditional::shared_ptr> c;
c.push_back(boost::make_shared<DiscreteConditional>(A | B = "1/2 2/1"));
c.push_back(boost::make_shared<DiscreteConditional>(B % "1/2"));
DiscreteConditional::shared_ptr actual = DiscreteConditional::Combine(c.begin(), c.end());
DecisionTreeFactor::shared_ptr actualFactor = actual->toFactor();
DecisionTreeFactor expected(A & B, "0.333333 0.666667 0.666667 0.333333");
EXPECT(assert_equal(expected, *actualFactor,1e-5));
}
/* ************************************************************************* */