diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index e05bfd669..d54bf5518 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -22,6 +22,7 @@ #include #include #include +#include namespace gtsam { @@ -58,6 +59,16 @@ namespace gtsam { DiscreteConditional(const DecisionTreeFactor& joint, const DecisionTreeFactor& marginal); + /** + * Combine several conditional into a single one. + * The conditionals must be given in increasing order, meaning that the parents + * of any conditional may not include a conditional coming before it. + * @param firstConditional Iterator to the first conditional to combine, must dereference to a shared_ptr. + * @param lastConditional Iterator to after the last conditional to combine, must dereference to a shared_ptr. + * */ + template + static shared_ptr Combine(ITERATOR firstConditional, ITERATOR lastConditional); + /// @} /// @name Testable /// @{ @@ -123,5 +134,19 @@ namespace gtsam { }; // DiscreteConditional + /* ************************************************************************* */ + template + DiscreteConditional::shared_ptr DiscreteConditional::Combine( + ITERATOR firstConditional, ITERATOR lastConditional) { + // TODO: check for being a clique + DecisionTreeFactor product; + for(ITERATOR it = firstConditional; it != lastConditional; ++it) { + DiscreteConditional::shared_ptr c = *it; + DecisionTreeFactor::shared_ptr factor = c->toFactor(); + product = (*factor) * product; + } + return boost::make_shared(1,product); + } + }// gtsam diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 8af23e4f8..04bc47f60 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -29,7 +29,7 @@ using namespace std; using namespace gtsam; /* ************************************************************************* */ -TEST( DiscreteConditionalTest, constructors) +TEST( DiscreteConditional, constructors) { DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering ! @@ -48,7 +48,7 @@ TEST( DiscreteConditionalTest, constructors) } /* ************************************************************************* */ -TEST( DiscreteConditionalTest, constructors_alt_interface) +TEST( DiscreteConditional, constructors_alt_interface) { DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering ! @@ -71,7 +71,7 @@ TEST( DiscreteConditionalTest, constructors_alt_interface) } /* ************************************************************************* */ -TEST( DiscreteConditionalTest, constructors2) +TEST( DiscreteConditional, constructors2) { // Declare keys and ordering DiscreteKey C(0,2), B(1,2); @@ -83,15 +83,28 @@ TEST( DiscreteConditionalTest, constructors2) } /* ************************************************************************* */ -TEST( DiscreteConditionalTest, constructors3) +TEST( DiscreteConditional, constructors3) { - // Declare keys and ordering - DiscreteKey C(0,2), B(1,2), A(2,2); - DecisionTreeFactor expected(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8"); - Signature signature((C | B, A) = "4/1 1/1 1/1 1/4"); - DiscreteConditional actual(signature); - DecisionTreeFactor::shared_ptr actualFactor = actual.toFactor(); - EXPECT(assert_equal(expected, *actualFactor)); + // Declare keys and ordering + DiscreteKey C(0,2), B(1,2), A(2,2); + DecisionTreeFactor expected(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8"); + Signature signature((C | B, A) = "4/1 1/1 1/1 1/4"); + DiscreteConditional actual(signature); + DecisionTreeFactor::shared_ptr actualFactor = actual.toFactor(); + EXPECT(assert_equal(expected, *actualFactor)); +} + +/* ************************************************************************* */ +TEST( DiscreteConditional, Combine) +{ + DiscreteKey A(0,2), B(1,2); + vector c; + c.push_back(boost::make_shared(A | B = "1/2 2/1")); + c.push_back(boost::make_shared(B % "1/2")); + DiscreteConditional::shared_ptr actual = DiscreteConditional::Combine(c.begin(), c.end()); + DecisionTreeFactor::shared_ptr actualFactor = actual->toFactor(); + DecisionTreeFactor expected(A & B, "0.333333 0.666667 0.666667 0.333333"); + EXPECT(assert_equal(expected, *actualFactor,1e-5)); } /* ************************************************************************* */