Better documentation and new test

release/4.3a0
Frank Dellaert 2023-06-04 16:04:18 +01:00
parent febeacd686
commit 6cb0fa7cd7
2 changed files with 55 additions and 11 deletions

View File

@ -63,11 +63,46 @@ namespace gtsam {
/** Constructor from DiscreteKeys and AlgebraicDecisionTree */ /** Constructor from DiscreteKeys and AlgebraicDecisionTree */
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials); DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
/** Constructor from doubles */ /**
* @brief Constructor from doubles
*
* @param keys The discrete keys.
* @param table The table of values.
*
* @throw std::invalid_argument if the size of `table` does not match the
* number of assignments.
*
* Example:
* @code{.cpp}
* DiscreteKey X(0,2), Y(1,3);
* const std::vector<double> table {2, 5, 3, 6, 4, 7};
* DecisionTreeFactor f1({X, Y}, table);
* @endcode
*
* The values in the table should be laid out so that the first key varies
* the slowest. and the last key the fastest.
*/
DecisionTreeFactor(const DiscreteKeys& keys, DecisionTreeFactor(const DiscreteKeys& keys,
const std::vector<double>& table); const std::vector<double>& table);
/** Constructor from string */ /**
* @brief Constructor from string
*
* @param keys The discrete keys.
* @param table The table of values.
*
* @throw std::invalid_argument if the size of `table` does not match the
* number of assignments.
*
* Example:
* @code{.cpp}
* DiscreteKey X(0,2), Y(1,3);
* DecisionTreeFactor factor({X, Y}, "2 5 3 6 4 7");
* @endcode
*
* The values in the table should be laid out so that the first key varies
* the slowest. and the last key the fastest.
*/
DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table); DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);
/// Single-key specialization /// Single-key specialization

View File

@ -27,6 +27,18 @@
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
/* ************************************************************************* */
TEST(DecisionTreeFactor, constructors_match) {
// Declare two keys
DiscreteKey X(0, 2), Y(1, 3);
// Create with vector and with string
const std::vector<double> table {2, 5, 3, 6, 4, 7};
DecisionTreeFactor f1({X, Y}, table);
DecisionTreeFactor f2({X, Y}, "2 5 3 6 4 7");
EXPECT(assert_equal(f1, f2));
}
/* ************************************************************************* */ /* ************************************************************************* */
TEST( DecisionTreeFactor, constructors) TEST( DecisionTreeFactor, constructors)
{ {
@ -41,16 +53,13 @@ TEST( DecisionTreeFactor, constructors)
EXPECT_LONGS_EQUAL(2,f2.size()); EXPECT_LONGS_EQUAL(2,f2.size());
EXPECT_LONGS_EQUAL(3,f3.size()); EXPECT_LONGS_EQUAL(3,f3.size());
DiscreteValues values; DiscreteValues x121{{0, 1}, {1, 2}, {2, 1}};
values[0] = 1; // x EXPECT_DOUBLES_EQUAL(8, f1(x121), 1e-9);
values[1] = 2; // y EXPECT_DOUBLES_EQUAL(7, f2(x121), 1e-9);
values[2] = 1; // z EXPECT_DOUBLES_EQUAL(75, f3(x121), 1e-9);
EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9);
EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9);
EXPECT_DOUBLES_EQUAL(75, f3(values), 1e-9);
// Assert that error = -log(value) // Assert that error = -log(value)
EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9); EXPECT_DOUBLES_EQUAL(-log(f1(x121)), f1.error(x121), 1e-9);
} }
/* ************************************************************************* */ /* ************************************************************************* */