From 6cb0fa7cd7f9d981b843945e8f71183d611412aa Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 4 Jun 2023 16:04:18 +0100 Subject: [PATCH] Better documentation and new test --- gtsam/discrete/DecisionTreeFactor.h | 41 +++++++++++++++++-- .../discrete/tests/testDecisionTreeFactor.cpp | 25 +++++++---- 2 files changed, 55 insertions(+), 11 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index dd292cae8..891030f49 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -63,11 +63,46 @@ namespace gtsam { /** Constructor from DiscreteKeys and AlgebraicDecisionTree */ DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials); - /** Constructor from doubles */ + /** + * @brief Constructor from doubles + * + * @param keys The discrete keys. + * @param table The table of values. + * + * @throw std::invalid_argument if the size of `table` does not match the + * number of assignments. + * + * Example: + * @code{.cpp} + * DiscreteKey X(0,2), Y(1,3); + * const std::vector table {2, 5, 3, 6, 4, 7}; + * DecisionTreeFactor f1({X, Y}, table); + * @endcode + * + * The values in the table should be laid out so that the first key varies + * the slowest. and the last key the fastest. + */ DecisionTreeFactor(const DiscreteKeys& keys, - const std::vector& table); + const std::vector& table); - /** Constructor from string */ + /** + * @brief Constructor from string + * + * @param keys The discrete keys. + * @param table The table of values. + * + * @throw std::invalid_argument if the size of `table` does not match the + * number of assignments. + * + * Example: + * @code{.cpp} + * DiscreteKey X(0,2), Y(1,3); + * DecisionTreeFactor factor({X, Y}, "2 5 3 6 4 7"); + * @endcode + * + * The values in the table should be laid out so that the first key varies + * the slowest. and the last key the fastest. + */ DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table); /// Single-key specialization diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 3dbb3e64f..8e203c56a 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -27,6 +27,18 @@ using namespace std; using namespace gtsam; +/* ************************************************************************* */ +TEST(DecisionTreeFactor, constructors_match) { + // Declare two keys + DiscreteKey X(0, 2), Y(1, 3); + + // Create with vector and with string + const std::vector table {2, 5, 3, 6, 4, 7}; + DecisionTreeFactor f1({X, Y}, table); + DecisionTreeFactor f2({X, Y}, "2 5 3 6 4 7"); + EXPECT(assert_equal(f1, f2)); +} + /* ************************************************************************* */ TEST( DecisionTreeFactor, constructors) { @@ -41,16 +53,13 @@ TEST( DecisionTreeFactor, constructors) EXPECT_LONGS_EQUAL(2,f2.size()); EXPECT_LONGS_EQUAL(3,f3.size()); - DiscreteValues values; - values[0] = 1; // x - values[1] = 2; // y - values[2] = 1; // z - EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9); - EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9); - EXPECT_DOUBLES_EQUAL(75, f3(values), 1e-9); + DiscreteValues x121{{0, 1}, {1, 2}, {2, 1}}; + EXPECT_DOUBLES_EQUAL(8, f1(x121), 1e-9); + EXPECT_DOUBLES_EQUAL(7, f2(x121), 1e-9); + EXPECT_DOUBLES_EQUAL(75, f3(x121), 1e-9); // Assert that error = -log(value) - EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9); + EXPECT_DOUBLES_EQUAL(-log(f1(x121)), f1.error(x121), 1e-9); } /* ************************************************************************* */