implemented Combine to we can create a BayesTree from a DiscreteBayesNet
parent
b95210a5f0
commit
b819b7c446
|
|
@ -22,6 +22,7 @@
|
|||
#include <gtsam/discrete/Signature.h>
|
||||
#include <gtsam/inference/IndexConditional.h>
|
||||
#include <boost/shared_ptr.hpp>
|
||||
#include <boost/make_shared.hpp>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
|
|
@ -58,6 +59,16 @@ namespace gtsam {
|
|||
DiscreteConditional(const DecisionTreeFactor& joint,
|
||||
const DecisionTreeFactor& marginal);
|
||||
|
||||
/**
|
||||
* Combine several conditional into a single one.
|
||||
* The conditionals must be given in increasing order, meaning that the parents
|
||||
* of any conditional may not include a conditional coming before it.
|
||||
* @param firstConditional Iterator to the first conditional to combine, must dereference to a shared_ptr<DiscreteConditional>.
|
||||
* @param lastConditional Iterator to after the last conditional to combine, must dereference to a shared_ptr<DiscreteConditional>.
|
||||
* */
|
||||
template<typename ITERATOR>
|
||||
static shared_ptr Combine(ITERATOR firstConditional, ITERATOR lastConditional);
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
|
@ -123,5 +134,19 @@ namespace gtsam {
|
|||
};
|
||||
// DiscreteConditional
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<typename ITERATOR>
|
||||
DiscreteConditional::shared_ptr DiscreteConditional::Combine(
|
||||
ITERATOR firstConditional, ITERATOR lastConditional) {
|
||||
// TODO: check for being a clique
|
||||
DecisionTreeFactor product;
|
||||
for(ITERATOR it = firstConditional; it != lastConditional; ++it) {
|
||||
DiscreteConditional::shared_ptr c = *it;
|
||||
DecisionTreeFactor::shared_ptr factor = c->toFactor();
|
||||
product = (*factor) * product;
|
||||
}
|
||||
return boost::make_shared<DiscreteConditional>(1,product);
|
||||
}
|
||||
|
||||
}// gtsam
|
||||
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ using namespace std;
|
|||
using namespace gtsam;
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST( DiscreteConditionalTest, constructors)
|
||||
TEST( DiscreteConditional, constructors)
|
||||
{
|
||||
DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
|
||||
|
||||
|
|
@ -48,7 +48,7 @@ TEST( DiscreteConditionalTest, constructors)
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST( DiscreteConditionalTest, constructors_alt_interface)
|
||||
TEST( DiscreteConditional, constructors_alt_interface)
|
||||
{
|
||||
DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
|
||||
|
||||
|
|
@ -71,7 +71,7 @@ TEST( DiscreteConditionalTest, constructors_alt_interface)
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST( DiscreteConditionalTest, constructors2)
|
||||
TEST( DiscreteConditional, constructors2)
|
||||
{
|
||||
// Declare keys and ordering
|
||||
DiscreteKey C(0,2), B(1,2);
|
||||
|
|
@ -83,15 +83,28 @@ TEST( DiscreteConditionalTest, constructors2)
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST( DiscreteConditionalTest, constructors3)
|
||||
TEST( DiscreteConditional, constructors3)
|
||||
{
|
||||
// Declare keys and ordering
|
||||
DiscreteKey C(0,2), B(1,2), A(2,2);
|
||||
DecisionTreeFactor expected(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8");
|
||||
Signature signature((C | B, A) = "4/1 1/1 1/1 1/4");
|
||||
DiscreteConditional actual(signature);
|
||||
DecisionTreeFactor::shared_ptr actualFactor = actual.toFactor();
|
||||
EXPECT(assert_equal(expected, *actualFactor));
|
||||
// Declare keys and ordering
|
||||
DiscreteKey C(0,2), B(1,2), A(2,2);
|
||||
DecisionTreeFactor expected(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8");
|
||||
Signature signature((C | B, A) = "4/1 1/1 1/1 1/4");
|
||||
DiscreteConditional actual(signature);
|
||||
DecisionTreeFactor::shared_ptr actualFactor = actual.toFactor();
|
||||
EXPECT(assert_equal(expected, *actualFactor));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST( DiscreteConditional, Combine)
|
||||
{
|
||||
DiscreteKey A(0,2), B(1,2);
|
||||
vector<DiscreteConditional::shared_ptr> c;
|
||||
c.push_back(boost::make_shared<DiscreteConditional>(A | B = "1/2 2/1"));
|
||||
c.push_back(boost::make_shared<DiscreteConditional>(B % "1/2"));
|
||||
DiscreteConditional::shared_ptr actual = DiscreteConditional::Combine(c.begin(), c.end());
|
||||
DecisionTreeFactor::shared_ptr actualFactor = actual->toFactor();
|
||||
DecisionTreeFactor expected(A & B, "0.333333 0.666667 0.666667 0.333333");
|
||||
EXPECT(assert_equal(expected, *actualFactor,1e-5));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
|||
Loading…
Reference in New Issue