diff --git a/.github/scripts/python.sh b/.github/scripts/python.sh index 3f5701281..6cc62d2b0 100644 --- a/.github/scripts/python.sh +++ b/.github/scripts/python.sh @@ -75,7 +75,7 @@ cmake $GITHUB_WORKSPACE -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \ -DGTSAM_UNSTABLE_BUILD_PYTHON=${GTSAM_BUILD_UNSTABLE:-ON} \ -DGTSAM_PYTHON_VERSION=$PYTHON_VERSION \ -DPYTHON_EXECUTABLE:FILEPATH=$(which $PYTHON) \ - -DGTSAM_ALLOW_DEPRECATED_SINCE_V41=OFF \ + -DGTSAM_ALLOW_DEPRECATED_SINCE_V42=OFF \ -DCMAKE_INSTALL_PREFIX=$GITHUB_WORKSPACE/gtsam_install diff --git a/.github/scripts/unix.sh b/.github/scripts/unix.sh index 9689d346c..d890577b6 100644 --- a/.github/scripts/unix.sh +++ b/.github/scripts/unix.sh @@ -64,7 +64,7 @@ function configure() -DGTSAM_BUILD_UNSTABLE=${GTSAM_BUILD_UNSTABLE:-ON} \ -DGTSAM_WITH_TBB=${GTSAM_WITH_TBB:-OFF} \ -DGTSAM_BUILD_EXAMPLES_ALWAYS=${GTSAM_BUILD_EXAMPLES_ALWAYS:-ON} \ - -DGTSAM_ALLOW_DEPRECATED_SINCE_V41=${GTSAM_ALLOW_DEPRECATED_SINCE_V41:-OFF} \ + -DGTSAM_ALLOW_DEPRECATED_SINCE_V42=${GTSAM_ALLOW_DEPRECATED_SINCE_V42:-OFF} \ -DGTSAM_USE_QUATERNIONS=${GTSAM_USE_QUATERNIONS:-OFF} \ -DGTSAM_ROT3_EXPMAP=${GTSAM_ROT3_EXPMAP:-ON} \ -DGTSAM_POSE3_EXPMAP=${GTSAM_POSE3_EXPMAP:-ON} \ diff --git a/.github/workflows/build-special.yml b/.github/workflows/build-special.yml index 647b9c0f1..d357b9a34 100644 --- a/.github/workflows/build-special.yml +++ b/.github/workflows/build-special.yml @@ -110,7 +110,7 @@ jobs: - name: Set Allow Deprecated Flag if: matrix.flag == 'deprecated' run: | - echo "GTSAM_ALLOW_DEPRECATED_SINCE_V41=ON" >> $GITHUB_ENV + echo "GTSAM_ALLOW_DEPRECATED_SINCE_V42=ON" >> $GITHUB_ENV echo "Allow deprecated since version 4.1" - name: Set Use Quaternions Flag diff --git a/CMakeLists.txt b/CMakeLists.txt index 5fd5d521c..21d8d1b60 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,12 +9,18 @@ endif() # Set the version number for the library set (GTSAM_VERSION_MAJOR 4) -set (GTSAM_VERSION_MINOR 1) +set (GTSAM_VERSION_MINOR 2) set (GTSAM_VERSION_PATCH 0) +set (GTSAM_PRERELEASE_VERSION "a1") math (EXPR GTSAM_VERSION_NUMERIC "10000 * ${GTSAM_VERSION_MAJOR} + 100 * ${GTSAM_VERSION_MINOR} + ${GTSAM_VERSION_PATCH}") -set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}.${GTSAM_VERSION_PATCH}") -set (CMAKE_PROJECT_VERSION ${GTSAM_VERSION_STRING}) +if (${GTSAM_VERSION_PATCH} EQUAL 0) + set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}${GTSAM_PRERELEASE_VERSION}") +else() + set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}.${GTSAM_VERSION_PATCH}${GTSAM_PRERELEASE_VERSION}") +endif() +message(STATUS "GTSAM Version: ${GTSAM_VERSION_STRING}") + set (CMAKE_PROJECT_VERSION_MAJOR ${GTSAM_VERSION_MAJOR}) set (CMAKE_PROJECT_VERSION_MINOR ${GTSAM_VERSION_MINOR}) set (CMAKE_PROJECT_VERSION_PATCH ${GTSAM_VERSION_PATCH}) diff --git a/README.md b/README.md index 046132301..52ac0a5d8 100644 --- a/README.md +++ b/README.md @@ -2,9 +2,9 @@ **Important Note** -As of August 1 2020, the `develop` branch is officially in "Pre 4.1" mode, and features deprecated in 4.0 have been removed. Please use the last [4.0.3 release](https://github.com/borglab/gtsam/releases/tag/4.0.3) if you need those features. +As of Dec 2021, the `develop` branch is officially in "Pre 4.2" mode. A great new feature we will be adding in 4.2 is *hybrid inference* a la DCSLAM (Kevin Doherty et al) and we envision several API-breaking changes will happen in the discrete folder. -However, most are easily converted and can be tracked down (in 4.0.3) by disabling the cmake flag `GTSAM_ALLOW_DEPRECATED_SINCE_V4`. +In addition, features deprecated in 4.1 will be removed. Please use the last [4.1.1 release](https://github.com/borglab/gtsam/releases/tag/4.1.1) if you need those features. However, most (not all, unfortunately) are easily converted and can be tracked down (in 4.1.1) by disabling the cmake flag `GTSAM_ALLOW_DEPRECATED_SINCE_V42`. ## What is GTSAM? @@ -57,7 +57,7 @@ GTSAM 4 introduces several new features, most notably Expressions and a Python t GTSAM 4 also deprecated some legacy functionality and wrongly named methods. If you are on a 4.0.X release, you can define the flag `GTSAM_ALLOW_DEPRECATED_SINCE_V4` to use the deprecated methods. -GTSAM 4.1 added a new pybind wrapper, and **removed** the deprecated functionality. There is a flag `GTSAM_ALLOW_DEPRECATED_SINCE_V41` for newly deprecated methods since the 4.1 release, which is on by default, allowing anyone to just pull version 4.1 and compile. +GTSAM 4.1 added a new pybind wrapper, and **removed** the deprecated functionality. There is a flag `GTSAM_ALLOW_DEPRECATED_SINCE_V42` for newly deprecated methods since the 4.1 release, which is on by default, allowing anyone to just pull version 4.1 and compile. ## Wrappers diff --git a/Using-GTSAM-EXPORT.md b/Using-GTSAM-EXPORT.md index cae1d499c..faeebc97f 100644 --- a/Using-GTSAM-EXPORT.md +++ b/Using-GTSAM-EXPORT.md @@ -29,7 +29,7 @@ Rule #1 doesn't seem very bad, until you combine it with rule #2 ***Compiler Rule #2*** Anything declared in a header file is not included in a DLL. -When these two rules are combined, you get some very confusing results. For example, a class which is completely defined in a header (e.g. LieMatrix) cannot use `GTSAM_EXPORT` in its definition. If LieMatrix is defined with `GTSAM_EXPORT`, then the compiler _must_ find LieMatrix in a DLL. Because LieMatrix is a header-only class, however, it can't find it, leading to a very confusing "I can't find this symbol" type of error. Note that the linker says it can't find the symbol even though the compiler found the header file that completely defines the class. +When these two rules are combined, you get some very confusing results. For example, a class which is completely defined in a header (e.g. Foo) cannot use `GTSAM_EXPORT` in its definition. If Foo is defined with `GTSAM_EXPORT`, then the compiler _must_ find Foo in a DLL. Because Foo is a header-only class, however, it can't find it, leading to a very confusing "I can't find this symbol" type of error. Note that the linker says it can't find the symbol even though the compiler found the header file that completely defines the class. Also note that when a class that you want to export inherits from another class that is not exportable, this can cause significant issues. According to this [MSVC Warning page](https://docs.microsoft.com/en-us/cpp/error-messages/compiler-warnings/compiler-warning-level-2-c4275?view=vs-2019), it may not strictly be a rule, but we have seen several linker errors when a class that is defined with `GTSAM_EXPORT` extended an Eigen class. In general, it appears that any inheritance of non-exportable class by an exportable class is a bad idea. diff --git a/cmake/HandleGeneralOptions.cmake b/cmake/HandleGeneralOptions.cmake index 64c239f39..7c8f8533f 100644 --- a/cmake/HandleGeneralOptions.cmake +++ b/cmake/HandleGeneralOptions.cmake @@ -25,7 +25,7 @@ option(GTSAM_WITH_EIGEN_MKL_OPENMP "Eigen, when using Intel MKL, will a option(GTSAM_THROW_CHEIRALITY_EXCEPTION "Throw exception when a triangulated point is behind a camera" ON) option(GTSAM_BUILD_PYTHON "Enable/Disable building & installation of Python module with pybind11" OFF) option(GTSAM_INSTALL_MATLAB_TOOLBOX "Enable/Disable installation of matlab toolbox" OFF) -option(GTSAM_ALLOW_DEPRECATED_SINCE_V41 "Allow use of methods/functions deprecated in GTSAM 4.1" ON) +option(GTSAM_ALLOW_DEPRECATED_SINCE_V42 "Allow use of methods/functions deprecated in GTSAM 4.1" ON) option(GTSAM_SUPPORT_NESTED_DISSECTION "Support Metis-based nested dissection" ON) option(GTSAM_TANGENT_PREINTEGRATION "Use new ImuFactor with integration on tangent space" ON) option(GTSAM_SLOW_BUT_CORRECT_BETWEENFACTOR "Use the slower but correct version of BetweenFactor" OFF) diff --git a/cmake/HandlePrintConfiguration.cmake b/cmake/HandlePrintConfiguration.cmake index ad6ac5c5c..43ee5b57b 100644 --- a/cmake/HandlePrintConfiguration.cmake +++ b/cmake/HandlePrintConfiguration.cmake @@ -86,7 +86,7 @@ print_enabled_config(${GTSAM_USE_QUATERNIONS} "Quaternions as defaul print_enabled_config(${GTSAM_ENABLE_CONSISTENCY_CHECKS} "Runtime consistency checking ") print_enabled_config(${GTSAM_ROT3_EXPMAP} "Rot3 retract is full ExpMap ") print_enabled_config(${GTSAM_POSE3_EXPMAP} "Pose3 retract is full ExpMap ") -print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V41} "Allow features deprecated in GTSAM 4.1") +print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V42} "Allow features deprecated in GTSAM 4.1") print_enabled_config(${GTSAM_SUPPORT_NESTED_DISSECTION} "Metis-based Nested Dissection ") print_enabled_config(${GTSAM_TANGENT_PREINTEGRATION} "Use tangent-space preintegration") diff --git a/examples/FisheyeExample.cpp b/examples/FisheyeExample.cpp index 223149299..fc0aed0d7 100644 --- a/examples/FisheyeExample.cpp +++ b/examples/FisheyeExample.cpp @@ -122,8 +122,7 @@ int main(int argc, char *argv[]) { std::cout << "initial error=" << graph.error(initialEstimate) << std::endl; std::cout << "final error=" << graph.error(result) << std::endl; - std::ofstream os("examples/vio_batch.dot"); - graph.saveGraph(os, result); + graph.saveGraph("examples/vio_batch.dot", result); return 0; } diff --git a/examples/Pose2SLAMExample_graphviz.cpp b/examples/Pose2SLAMExample_graphviz.cpp index 27d556725..a8768e2b8 100644 --- a/examples/Pose2SLAMExample_graphviz.cpp +++ b/examples/Pose2SLAMExample_graphviz.cpp @@ -60,11 +60,10 @@ int main(int argc, char** argv) { // save factor graph as graphviz dot file // Render to PDF using "fdp Pose2SLAMExample.dot -Tpdf > graph.pdf" - ofstream os("Pose2SLAMExample.dot"); - graph.saveGraph(os, result); + graph.saveGraph("Pose2SLAMExample.dot", result); // Also print out to console - graph.saveGraph(cout, result); + graph.dot(cout, result); return 0; } diff --git a/examples/UGM_small.cpp b/examples/UGM_small.cpp index f4f3f1fd0..3829a5c91 100644 --- a/examples/UGM_small.cpp +++ b/examples/UGM_small.cpp @@ -50,8 +50,7 @@ int main(int argc, char** argv) { // Print the UGM distribution cout << "\nUGM distribution:" << endl; - vector allPosbValues = cartesianProduct( - Cathy & Heather & Mark & Allison); + auto allPosbValues = cartesianProduct(Cathy & Heather & Mark & Allison); for (size_t i = 0; i < allPosbValues.size(); ++i) { DiscreteFactor::Values values = allPosbValues[i]; double prodPot = graph(values); diff --git a/gtsam/3rdparty/Eigen/Eigen/src/Core/TriangularMatrix.h b/gtsam/3rdparty/Eigen/Eigen/src/Core/TriangularMatrix.h index 667ef09dc..9db32744e 100644 --- a/gtsam/3rdparty/Eigen/Eigen/src/Core/TriangularMatrix.h +++ b/gtsam/3rdparty/Eigen/Eigen/src/Core/TriangularMatrix.h @@ -440,7 +440,7 @@ template class TriangularViewImpl<_Mat EIGEN_DEVICE_FUNC void lazyAssign(const TriangularBase& other); - /** \deprecated */ + /** @deprecated */ template EIGEN_DEVICE_FUNC void lazyAssign(const MatrixBase& other); @@ -523,7 +523,7 @@ template class TriangularViewImpl<_Mat call_assignment(derived(), other.const_cast_derived(), internal::swap_assign_op()); } - /** \deprecated + /** @deprecated * Shortcut for \code (*this).swap(other.triangularView<(*this)::Mode>()) \endcode */ template EIGEN_DEVICE_FUNC diff --git a/gtsam/CMakeLists.txt b/gtsam/CMakeLists.txt index 535d60eb1..a293c6ec2 100644 --- a/gtsam/CMakeLists.txt +++ b/gtsam/CMakeLists.txt @@ -15,7 +15,7 @@ set (gtsam_subdirs sam sfm slam - navigation + navigation ) set(gtsam_srcs) diff --git a/gtsam/base/CMakeLists.txt b/gtsam/base/CMakeLists.txt index 99984e7b3..66d3ec721 100644 --- a/gtsam/base/CMakeLists.txt +++ b/gtsam/base/CMakeLists.txt @@ -5,8 +5,5 @@ install(FILES ${base_headers} DESTINATION include/gtsam/base) file(GLOB base_headers_tree "treeTraversal/*.h") install(FILES ${base_headers_tree} DESTINATION include/gtsam/base/treeTraversal) -file(GLOB deprecated_headers "deprecated/*.h") -install(FILES ${deprecated_headers} DESTINATION include/gtsam/base/deprecated) - # Build tests add_subdirectory(tests) diff --git a/gtsam/base/LieMatrix.h b/gtsam/base/LieMatrix.h deleted file mode 100644 index 210bdcc73..000000000 --- a/gtsam/base/LieMatrix.h +++ /dev/null @@ -1,26 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file LieMatrix.h - * @brief External deprecation warning, see deprecated/LieMatrix.h for details - * @author Paul Drews - */ - -#pragma once - -#ifdef _MSC_VER -#pragma message("LieMatrix.h is deprecated. Please use Eigen::Matrix instead.") -#else -#warning "LieMatrix.h is deprecated. Please use Eigen::Matrix instead." -#endif - -#include "gtsam/base/deprecated/LieMatrix.h" diff --git a/gtsam/base/LieScalar.h b/gtsam/base/LieScalar.h deleted file mode 100644 index e159ffa87..000000000 --- a/gtsam/base/LieScalar.h +++ /dev/null @@ -1,26 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file LieScalar.h - * @brief External deprecation warning, see deprecated/LieScalar.h for details - * @author Kai Ni - */ - -#pragma once - -#ifdef _MSC_VER -#pragma message("LieScalar.h is deprecated. Please use double/float instead.") -#else - #warning "LieScalar.h is deprecated. Please use double/float instead." -#endif - -#include diff --git a/gtsam/base/LieVector.h b/gtsam/base/LieVector.h deleted file mode 100644 index a7491d804..000000000 --- a/gtsam/base/LieVector.h +++ /dev/null @@ -1,26 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file LieVector.h - * @brief Deprecation warning for LieVector, see deprecated/LieVector.h for details. - * @author Paul Drews - */ - -#pragma once - -#ifdef _MSC_VER -#pragma message("LieVector.h is deprecated. Please use Eigen::Vector instead.") -#else -#warning "LieVector.h is deprecated. Please use Eigen::Vector instead." -#endif - -#include diff --git a/gtsam/base/TestableAssertions.h b/gtsam/base/TestableAssertions.h index c86fbb6d2..e5bd34d19 100644 --- a/gtsam/base/TestableAssertions.h +++ b/gtsam/base/TestableAssertions.h @@ -80,9 +80,10 @@ bool assert_equal(const V& expected, const boost::optional& actual, do return assert_equal(expected, *actual, tol); } +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 /** * Version of assert_equals to work with vectors - * \deprecated: use container equals instead + * @deprecated: use container equals instead */ template bool GTSAM_DEPRECATED assert_equal(const std::vector& expected, const std::vector& actual, double tol = 1e-9) { @@ -108,6 +109,7 @@ bool GTSAM_DEPRECATED assert_equal(const std::vector& expected, const std::ve } return true; } +#endif /** * Function for comparing maps of testable->testable diff --git a/gtsam/base/Vector.h b/gtsam/base/Vector.h index a057da46b..36dc2288d 100644 --- a/gtsam/base/Vector.h +++ b/gtsam/base/Vector.h @@ -203,15 +203,16 @@ inline double inner_prod(const V1 &a, const V2& b) { return a.dot(b); } +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 /** * BLAS Level 1 scal: x <- alpha*x - * \deprecated: use operators instead + * @deprecated: use operators instead */ inline void GTSAM_DEPRECATED scal(double alpha, Vector& x) { x *= alpha; } /** * BLAS Level 1 axpy: y <- alpha*x + y - * \deprecated: use operators instead + * @deprecated: use operators instead */ template inline void GTSAM_DEPRECATED axpy(double alpha, const V1& x, V2& y) { @@ -222,6 +223,7 @@ inline void axpy(double alpha, const Vector& x, SubVector y) { assert (y.size()==x.size()); y += alpha * x; } +#endif /** * house(x,j) computes HouseHolder vector v and scaling factor beta diff --git a/gtsam/base/deprecated/LieMatrix.h b/gtsam/base/deprecated/LieMatrix.h deleted file mode 100644 index a3d0a4328..000000000 --- a/gtsam/base/deprecated/LieMatrix.h +++ /dev/null @@ -1,152 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file LieMatrix.h - * @brief A wrapper around Matrix providing Lie compatibility - * @author Richard Roberts and Alex Cunningham - */ - -#pragma once - -#include - -#include -#include - -namespace gtsam { - -/** - * @deprecated: LieMatrix, LieVector and LieMatrix are obsolete in GTSAM 4.0 as - * we can directly add double, Vector, and Matrix into values now, because of - * gtsam::traits. - */ -struct LieMatrix : public Matrix { - - /// @name Constructors - /// @{ - enum { dimension = Eigen::Dynamic }; - - /** default constructor - only for serialize */ - LieMatrix() {} - - /** initialize from a normal matrix */ - LieMatrix(const Matrix& v) : Matrix(v) {} - - template - LieMatrix(const M& v) : Matrix(v) {} - -// Currently TMP constructor causes ICE on MSVS 2013 -#if (_MSC_VER < 1800) - /** initialize from a fixed size normal vector */ - template - LieMatrix(const Eigen::Matrix& v) : Matrix(v) {} -#endif - - /** constructor with size and initial data, row order ! */ - LieMatrix(size_t m, size_t n, const double* const data) : - Matrix(Eigen::Map(data, m, n)) {} - - /// @} - /// @name Testable interface - /// @{ - - /** print @param s optional string naming the object */ - void print(const std::string& name = "") const { - gtsam::print(matrix(), name); - } - /** equality up to tolerance */ - inline bool equals(const LieMatrix& expected, double tol=1e-5) const { - return gtsam::equal_with_abs_tol(matrix(), expected.matrix(), tol); - } - - /// @} - /// @name Standard Interface - /// @{ - - /** get the underlying matrix */ - inline Matrix matrix() const { - return static_cast(*this); - } - - /// @} - - /// @name Group - /// @{ - LieMatrix compose(const LieMatrix& q) { return (*this)+q;} - LieMatrix between(const LieMatrix& q) { return q-(*this);} - LieMatrix inverse() { return -(*this);} - /// @} - - /// @name Manifold - /// @{ - Vector localCoordinates(const LieMatrix& q) { return between(q).vector();} - LieMatrix retract(const Vector& v) {return compose(LieMatrix(v));} - /// @} - - /// @name Lie Group - /// @{ - static Vector Logmap(const LieMatrix& p) {return p.vector();} - static LieMatrix Expmap(const Vector& v) { return LieMatrix(v);} - /// @} - - /// @name VectorSpace requirements - /// @{ - - /** Returns dimensionality of the tangent space */ - inline size_t dim() const { return size(); } - - /** Convert to vector, is done row-wise - TODO why? */ - inline Vector vector() const { - Vector result(size()); - typedef Eigen::Matrix RowMajor; - Eigen::Map(&result(0), rows(), cols()) = *this; - return result; - } - - /** identity - NOTE: no known size at compile time - so zero length */ - inline static LieMatrix identity() { - throw std::runtime_error("LieMatrix::identity(): Don't use this function"); - return LieMatrix(); - } - /// @} - -private: - - // Serialization function - friend class boost::serialization::access; - template - void serialize(Archive & ar, const unsigned int /*version*/) { - ar & boost::serialization::make_nvp("Matrix", - boost::serialization::base_object(*this)); - - } - -}; - - -template<> -struct traits : public internal::VectorSpace { - - // Override Retract, as the default version does not know how to initialize - static LieMatrix Retract(const LieMatrix& origin, const TangentVector& v, - ChartJacobian H1 = boost::none, ChartJacobian H2 = boost::none) { - if (H1) *H1 = Eye(origin); - if (H2) *H2 = Eye(origin); - typedef const Eigen::Matrix RowMajor; - return origin + Eigen::Map(&v(0), origin.rows(), origin.cols()); - } - -}; - -} // \namespace gtsam diff --git a/gtsam/base/deprecated/LieScalar.h b/gtsam/base/deprecated/LieScalar.h deleted file mode 100644 index 6c9a5f766..000000000 --- a/gtsam/base/deprecated/LieScalar.h +++ /dev/null @@ -1,88 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file LieScalar.h - * @brief A wrapper around scalar providing Lie compatibility - * @author Kai Ni - */ - -#pragma once - -#include -#include -#include - -namespace gtsam { - - /** - * @deprecated: LieScalar, LieVector and LieMatrix are obsolete in GTSAM 4.0 as - * we can directly add double, Vector, and Matrix into values now, because of - * gtsam::traits. - */ - struct LieScalar { - - enum { dimension = 1 }; - - /** default constructor */ - LieScalar() : d_(0.0) {} - - /** wrap a double */ - /*explicit*/ LieScalar(double d) : d_(d) {} - - /** access the underlying value */ - double value() const { return d_; } - - /** Automatic conversion to underlying value */ - operator double() const { return d_; } - - /** convert vector */ - Vector1 vector() const { Vector1 v; v< - struct traits : public internal::ScalarTraits {}; - -} // \namespace gtsam diff --git a/gtsam/base/deprecated/LieVector.h b/gtsam/base/deprecated/LieVector.h deleted file mode 100644 index 745189c3d..000000000 --- a/gtsam/base/deprecated/LieVector.h +++ /dev/null @@ -1,121 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file LieVector.h - * @brief A wrapper around vector providing Lie compatibility - * @author Alex Cunningham - */ - -#pragma once - -#include -#include - -namespace gtsam { - -/** - * @deprecated: LieVector, LieVector and LieMatrix are obsolete in GTSAM 4.0 as - * we can directly add double, Vector, and Matrix into values now, because of - * gtsam::traits. - */ -struct LieVector : public Vector { - - enum { dimension = Eigen::Dynamic }; - - /** default constructor - should be unnecessary */ - LieVector() {} - - /** initialize from a normal vector */ - LieVector(const Vector& v) : Vector(v) {} - - template - LieVector(const V& v) : Vector(v) {} - -// Currently TMP constructor causes ICE on MSVS 2013 -#if (_MSC_VER < 1800) - /** initialize from a fixed size normal vector */ - template - LieVector(const Eigen::Matrix& v) : Vector(v) {} -#endif - - /** wrap a double */ - LieVector(double d) : Vector((Vector(1) << d).finished()) {} - - /** constructor with size and initial data, row order ! */ - LieVector(size_t m, const double* const data) : Vector(m) { - for (size_t i = 0; i < m; i++) (*this)(i) = data[i]; - } - - /// @name Testable - /// @{ - void print(const std::string& name="") const { - gtsam::print(vector(), name); - } - bool equals(const LieVector& expected, double tol=1e-5) const { - return gtsam::equal(vector(), expected.vector(), tol); - } - /// @} - - /// @name Group - /// @{ - LieVector compose(const LieVector& q) { return (*this)+q;} - LieVector between(const LieVector& q) { return q-(*this);} - LieVector inverse() { return -(*this);} - /// @} - - /// @name Manifold - /// @{ - Vector localCoordinates(const LieVector& q) { return between(q).vector();} - LieVector retract(const Vector& v) {return compose(LieVector(v));} - /// @} - - /// @name Lie Group - /// @{ - static Vector Logmap(const LieVector& p) {return p.vector();} - static LieVector Expmap(const Vector& v) { return LieVector(v);} - /// @} - - /// @name VectorSpace requirements - /// @{ - - /** get the underlying vector */ - Vector vector() const { - return static_cast(*this); - } - - /** Returns dimensionality of the tangent space */ - size_t dim() const { return this->size(); } - - /** identity - NOTE: no known size at compile time - so zero length */ - static LieVector identity() { - throw std::runtime_error("LieVector::identity(): Don't use this function"); - return LieVector(); - } - - /// @} - -private: - - // Serialization function - friend class boost::serialization::access; - template - void serialize(Archive & ar, const unsigned int /*version*/) { - ar & boost::serialization::make_nvp("Vector", - boost::serialization::base_object(*this)); - } -}; - - -template<> -struct traits : public internal::VectorSpace {}; - -} // \namespace gtsam diff --git a/gtsam/base/tests/testLieMatrix.cpp b/gtsam/base/tests/testLieMatrix.cpp deleted file mode 100644 index 8c68bf8a0..000000000 --- a/gtsam/base/tests/testLieMatrix.cpp +++ /dev/null @@ -1,70 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file testLieMatrix.cpp - * @author Richard Roberts - */ - -#include -#include -#include -#include - -using namespace gtsam; - -GTSAM_CONCEPT_TESTABLE_INST(LieMatrix) -GTSAM_CONCEPT_LIE_INST(LieMatrix) - -/* ************************************************************************* */ -TEST( LieMatrix, construction ) { - Matrix m = (Matrix(2,2) << 1.0,2.0, 3.0,4.0).finished(); - LieMatrix lie1(m), lie2(m); - - EXPECT(traits::GetDimension(m) == 4); - EXPECT(assert_equal(m, lie1.matrix())); - EXPECT(assert_equal(lie1, lie2)); -} - -/* ************************************************************************* */ -TEST( LieMatrix, other_constructors ) { - Matrix init = (Matrix(2,2) << 10.0,20.0, 30.0,40.0).finished(); - LieMatrix exp(init); - double data[] = {10,30,20,40}; - LieMatrix b(2,2,data); - EXPECT(assert_equal(exp, b)); -} - -/* ************************************************************************* */ -TEST(LieMatrix, retract) { - LieMatrix init((Matrix(2,2) << 1.0,2.0,3.0,4.0).finished()); - Vector update = (Vector(4) << 3.0, 4.0, 6.0, 7.0).finished(); - - LieMatrix expected((Matrix(2,2) << 4.0, 6.0, 9.0, 11.0).finished()); - LieMatrix actual = traits::Retract(init,update); - - EXPECT(assert_equal(expected, actual)); - - Vector expectedUpdate = update; - Vector actualUpdate = traits::Local(init,actual); - - EXPECT(assert_equal(expectedUpdate, actualUpdate)); - - Vector expectedLogmap = (Vector(4) << 1, 2, 3, 4).finished(); - Vector actualLogmap = traits::Logmap(LieMatrix((Matrix(2,2) << 1.0, 2.0, 3.0, 4.0).finished())); - EXPECT(assert_equal(expectedLogmap, actualLogmap)); -} - -/* ************************************************************************* */ -int main() { TestResult tr; return TestRegistry::runAllTests(tr); } -/* ************************************************************************* */ - - diff --git a/gtsam/base/tests/testLieScalar.cpp b/gtsam/base/tests/testLieScalar.cpp deleted file mode 100644 index 74f5e0d41..000000000 --- a/gtsam/base/tests/testLieScalar.cpp +++ /dev/null @@ -1,64 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file testLieScalar.cpp - * @author Kai Ni - */ - -#include -#include -#include -#include - -using namespace gtsam; - -GTSAM_CONCEPT_TESTABLE_INST(LieScalar) -GTSAM_CONCEPT_LIE_INST(LieScalar) - -const double tol=1e-9; - -//****************************************************************************** -TEST(LieScalar , Concept) { - BOOST_CONCEPT_ASSERT((IsGroup)); - BOOST_CONCEPT_ASSERT((IsManifold)); - BOOST_CONCEPT_ASSERT((IsLieGroup)); -} - -//****************************************************************************** -TEST(LieScalar , Invariants) { - LieScalar lie1(2), lie2(3); - CHECK(check_group_invariants(lie1, lie2)); - CHECK(check_manifold_invariants(lie1, lie2)); -} - -/* ************************************************************************* */ -TEST( testLieScalar, construction ) { - double d = 2.; - LieScalar lie1(d), lie2(d); - - EXPECT_DOUBLES_EQUAL(2., lie1.value(),tol); - EXPECT_DOUBLES_EQUAL(2., lie2.value(),tol); - EXPECT(traits::dimension == 1); - EXPECT(assert_equal(lie1, lie2)); -} - -/* ************************************************************************* */ -TEST( testLieScalar, localCoordinates ) { - LieScalar lie1(1.), lie2(3.); - - Vector1 actual = traits::Local(lie1, lie2); - EXPECT( assert_equal((Vector)(Vector(1) << 2).finished(), actual)); -} - -/* ************************************************************************* */ -int main() { TestResult tr; return TestRegistry::runAllTests(tr); } -/* ************************************************************************* */ diff --git a/gtsam/base/tests/testLieVector.cpp b/gtsam/base/tests/testLieVector.cpp deleted file mode 100644 index 76c4fc490..000000000 --- a/gtsam/base/tests/testLieVector.cpp +++ /dev/null @@ -1,66 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file testLieVector.cpp - * @author Alex Cunningham - */ - -#include -#include -#include -#include - -using namespace gtsam; - -GTSAM_CONCEPT_TESTABLE_INST(LieVector) -GTSAM_CONCEPT_LIE_INST(LieVector) - -//****************************************************************************** -TEST(LieVector , Concept) { - BOOST_CONCEPT_ASSERT((IsGroup)); - BOOST_CONCEPT_ASSERT((IsManifold)); - BOOST_CONCEPT_ASSERT((IsLieGroup)); -} - -//****************************************************************************** -TEST(LieVector , Invariants) { - Vector v = Vector3(1.0, 2.0, 3.0); - LieVector lie1(v), lie2(v); - check_manifold_invariants(lie1, lie2); -} - -//****************************************************************************** -TEST( testLieVector, construction ) { - Vector v = Vector3(1.0, 2.0, 3.0); - LieVector lie1(v), lie2(v); - - EXPECT(lie1.dim() == 3); - EXPECT(assert_equal(v, lie1.vector())); - EXPECT(assert_equal(lie1, lie2)); -} - -//****************************************************************************** -TEST( testLieVector, other_constructors ) { - Vector init = Vector2(10.0, 20.0); - LieVector exp(init); - double data[] = { 10, 20 }; - LieVector b(2, data); - EXPECT(assert_equal(exp, b)); -} - -/* ************************************************************************* */ -int main() { - TestResult tr; - return TestRegistry::runAllTests(tr); -} -/* ************************************************************************* */ - diff --git a/gtsam/base/tests/testMatrix.cpp b/gtsam/base/tests/testMatrix.cpp index a7c218705..7802f27e1 100644 --- a/gtsam/base/tests/testMatrix.cpp +++ b/gtsam/base/tests/testMatrix.cpp @@ -173,7 +173,7 @@ TEST(Matrix, stack ) { Matrix A = (Matrix(2, 2) << -5.0, 3.0, 00.0, -5.0).finished(); Matrix B = (Matrix(3, 2) << -0.5, 2.1, 1.1, 3.4, 2.6, 7.1).finished(); - Matrix AB = stack(2, &A, &B); + Matrix AB = gtsam::stack(2, &A, &B); Matrix C(5, 2); for (int i = 0; i < 2; i++) for (int j = 0; j < 2; j++) @@ -187,7 +187,7 @@ TEST(Matrix, stack ) std::vector matrices; matrices.push_back(A); matrices.push_back(B); - Matrix AB2 = stack(matrices); + Matrix AB2 = gtsam::stack(matrices); EQUALITY(C,AB2); } diff --git a/gtsam/base/tests/testTestableAssertions.cpp b/gtsam/base/tests/testTestableAssertions.cpp deleted file mode 100644 index 305aa7ca9..000000000 --- a/gtsam/base/tests/testTestableAssertions.cpp +++ /dev/null @@ -1,35 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file testTestableAssertions - * @author Alex Cunningham - */ - -#include -#include -#include - -using namespace gtsam; - -/* ************************************************************************* */ -TEST( testTestableAssertions, optional ) { - typedef boost::optional OptionalScalar; - LieScalar x(1.0); - OptionalScalar ox(x), dummy = boost::none; - EXPECT(assert_equal(ox, ox)); - EXPECT(assert_equal(x, ox)); - EXPECT(assert_equal(dummy, dummy)); -} - -/* ************************************************************************* */ -int main() { TestResult tr; return TestRegistry::runAllTests(tr); } -/* ************************************************************************* */ diff --git a/gtsam/basis/ParameterMatrix.h b/gtsam/basis/ParameterMatrix.h index df2d9f62e..eddcbfeae 100644 --- a/gtsam/basis/ParameterMatrix.h +++ b/gtsam/basis/ParameterMatrix.h @@ -153,7 +153,7 @@ class ParameterMatrix { return matrix_ * other; } - /// @name Vector Space requirements, following LieMatrix + /// @name Vector Space requirements /// @{ /** diff --git a/gtsam/config.h.in b/gtsam/config.h.in index e7623c52b..d47329a62 100644 --- a/gtsam/config.h.in +++ b/gtsam/config.h.in @@ -70,7 +70,7 @@ #cmakedefine GTSAM_THROW_CHEIRALITY_EXCEPTION // Make sure dependent projects that want it can see deprecated functions -#cmakedefine GTSAM_ALLOW_DEPRECATED_SINCE_V41 +#cmakedefine GTSAM_ALLOW_DEPRECATED_SINCE_V42 // Support Metis-based nested dissection #cmakedefine GTSAM_SUPPORT_NESTED_DISSECTION diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index 9cc55ed6a..d2e05927a 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -28,11 +28,22 @@ namespace gtsam { * TODO: consider eliminating this class altogether? */ template - class AlgebraicDecisionTree: public DecisionTree { + class GTSAM_EXPORT AlgebraicDecisionTree: public DecisionTree { + /** + * @brief Default method used by `labelFormatter` or `valueFormatter` when printing. + * + * @param x The value passed to format. + * @return std::string + */ + static std::string DefaultFormatter(const L& x) { + std::stringstream ss; + ss << x; + return ss.str(); + } - public: + public: - typedef DecisionTree Super; + using Base = DecisionTree; /** The Real ring with addition and multiplication */ struct Ring { @@ -60,33 +71,33 @@ namespace gtsam { }; AlgebraicDecisionTree() : - Super(1.0) { + Base(1.0) { } - AlgebraicDecisionTree(const Super& add) : - Super(add) { + AlgebraicDecisionTree(const Base& add) : + Base(add) { } /** Create a new leaf function splitting on a variable */ AlgebraicDecisionTree(const L& label, double y1, double y2) : - Super(label, y1, y2) { + Base(label, y1, y2) { } /** Create a new leaf function splitting on a variable */ - AlgebraicDecisionTree(const typename Super::LabelC& labelC, double y1, double y2) : - Super(labelC, y1, y2) { + AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1, double y2) : + Base(labelC, y1, y2) { } /** Create from keys and vector table */ AlgebraicDecisionTree // - (const std::vector& labelCs, const std::vector& ys) { - this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(), + (const std::vector& labelCs, const std::vector& ys) { + this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); } /** Create from keys and string table */ AlgebraicDecisionTree // - (const std::vector& labelCs, const std::string& table) { + (const std::vector& labelCs, const std::string& table) { // Convert string to doubles std::vector ys; std::istringstream iss(table); @@ -94,23 +105,32 @@ namespace gtsam { std::istream_iterator(), std::back_inserter(ys)); // now call recursive Create - this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(), + this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); } /** Create a new function splitting on a variable */ template AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) : - Super(nullptr) { + Base(nullptr) { this->root_ = compose(begin, end, label); } - /** Convert */ + /** + * Convert labels from type M to type L. + * + * @param other: The AlgebraicDecisionTree with label type M to convert. + * @param map: Map from label type M to label type L. + */ template AlgebraicDecisionTree(const AlgebraicDecisionTree& other, - const std::map& map) { - this->root_ = this->template convert(other.root_, map, - Ring::id); + const std::map& map) { + // Functor for label conversion so we can use `convertFrom`. + std::function L_of_M = [&map](const M& label) -> L { + return map.at(label); + }; + std::function op = Ring::id; + this->root_ = this->template convertFrom(other.root_, L_of_M, op); } /** sum */ @@ -134,12 +154,31 @@ namespace gtsam { } /** sum out variable */ - AlgebraicDecisionTree sum(const typename Super::LabelC& labelC) const { + AlgebraicDecisionTree sum(const typename Base::LabelC& labelC) const { return this->combine(labelC, &Ring::add); } + /// print method customized to value type `double`. + void print(const std::string& s, + const typename Base::LabelFormatter& labelFormatter = + &DefaultFormatter) const { + auto valueFormatter = [](const double& v) { + return (boost::format("%4.2g") % v).str(); + }; + Base::print(s, labelFormatter, valueFormatter); + } + + /// Equality method customized to value type `double`. + bool equals(const AlgebraicDecisionTree& other, double tol = 1e-9) const { + // lambda for comparison of two doubles upto some tolerance. + auto compare = [tol](double a, double b) { + return std::abs(a - b) < tol; + }; + return Base::equals(other, compare); + } }; // AlgebraicDecisionTree +template struct traits> : public Testable> {}; } // namespace gtsam diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 439889ebf..ab14b2a72 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -20,21 +20,22 @@ #pragma once #include -#include +#include #include +#include #include #include -#include -using boost::assign::operator+=; +#include #include -#include - -#include +#include #include #include +#include #include +using boost::assign::operator+=; + namespace gtsam { /*********************************************************************************/ @@ -76,23 +77,26 @@ namespace gtsam { } /** equality up to tolerance */ - bool equals(const Node& q, double tol) const override { - const Leaf* other = dynamic_cast (&q); + bool equals(const Node& q, const CompareFunc& compare) const override { + const Leaf* other = dynamic_cast(&q); if (!other) return false; - return std::abs(double(this->constant_ - other->constant_)) < tol; + return compare(this->constant_, other->constant_); } /** print */ - void print(const std::string& s) const override { - bool showZero = true; - if (showZero || constant_) std::cout << s << " Leaf " << constant_ << std::endl; + void print(const std::string& s, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter) const override { + std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl; } - /** to graphviz file */ - void dot(std::ostream& os, bool showZero) const override { - if (showZero || constant_) os << "\"" << this->id() << "\" [label=\"" - << boost::format("%4.2g") % constant_ - << "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; // width=0.55, + /** Write graphviz format to stream `os`. */ + void dot(std::ostream& os, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero) const override { + std::string value = valueFormatter(constant_); + if (showZero || value.compare("0")) + os << "\"" << this->id() << "\" [label=\"" << value + << "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; // width=0.55, } /** evaluate */ @@ -151,7 +155,7 @@ namespace gtsam { /** incremental allSame */ size_t allSame_; - typedef boost::shared_ptr ChoicePtr; + using ChoicePtr = boost::shared_ptr; public: @@ -236,32 +240,38 @@ namespace gtsam { } /** print (as a tree) */ - void print(const std::string& s) const override { + void print(const std::string& s, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter) const override { std::cout << s << " Choice("; - // std::cout << this << ","; - std::cout << label_ << ") " << std::endl; + std::cout << labelFormatter(label_) << ") " << std::endl; for (size_t i = 0; i < branches_.size(); i++) - branches_[i]->print((boost::format("%s %d") % s % i).str()); + branches_[i]->print((boost::format("%s %d") % s % i).str(), + labelFormatter, valueFormatter); } /** output to graphviz (as a a graph) */ - void dot(std::ostream& os, bool showZero) const override { + void dot(std::ostream& os, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero) const override { os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_ << "\"]\n"; - for (size_t i = 0; i < branches_.size(); i++) { - NodePtr branch = branches_[i]; + size_t B = branches_.size(); + for (size_t i = 0; i < B; i++) { + const NodePtr& branch = branches_[i]; // Check if zero if (!showZero) { - const Leaf* leaf = dynamic_cast (branch.get()); - if (leaf && !leaf->constant()) continue; + const Leaf* leaf = dynamic_cast(branch.get()); + if (leaf && valueFormatter(leaf->constant()).compare("0")) continue; } os << "\"" << this->id() << "\" -> \"" << branch->id() << "\""; - if (i == 0) os << " [style=dashed]"; - if (i > 1) os << " [style=bold]"; + if (B == 2) { + if (i == 0) os << " [style=dashed]"; + if (i > 1) os << " [style=bold]"; + } os << std::endl; - branch->dot(os, showZero); + branch->dot(os, labelFormatter, valueFormatter, showZero); } } @@ -275,15 +285,16 @@ namespace gtsam { return (q.isLeaf() && q.sameLeaf(*this)); } - /** equality up to tolerance */ - bool equals(const Node& q, double tol) const override { - const Choice* other = dynamic_cast (&q); + /** equality */ + bool equals(const Node& q, const CompareFunc& compare) const override { + const Choice* other = dynamic_cast(&q); if (!other) return false; if (this->label_ != other->label_) return false; if (branches_.size() != other->branches_.size()) return false; // we don't care about shared pointers being equal here for (size_t i = 0; i < branches_.size(); i++) - if (!(branches_[i]->equals(*(other->branches_[i]), tol))) return false; + if (!(branches_[i]->equals(*(other->branches_[i]), compare))) + return false; return true; } @@ -315,7 +326,7 @@ namespace gtsam { /** apply unary operator */ NodePtr apply(const Unary& op) const override { - boost::shared_ptr r(new Choice(label_, *this, op)); + auto r = boost::make_shared(label_, *this, op); return Unique(r); } @@ -330,24 +341,24 @@ namespace gtsam { // If second argument of binary op is Leaf node, recurse on branches NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override { - boost::shared_ptr h(new Choice(label(), nrChoices())); - for(NodePtr branch: branches_) - h->push_back(fL.apply_f_op_g(*branch, op)); + auto h = boost::make_shared(label(), nrChoices()); + for (auto&& branch : branches_) + h->push_back(fL.apply_f_op_g(*branch, op)); return Unique(h); } // If second argument of binary op is Choice, call constructor NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override { - boost::shared_ptr h(new Choice(fC, *this, op)); + auto h = boost::make_shared(fC, *this, op); return Unique(h); } // If second argument of binary op is Leaf template NodePtr apply_fC_op_gL(const Leaf& gL, OP op) const { - boost::shared_ptr h(new Choice(label(), nrChoices())); - for(const NodePtr& branch: branches_) - h->push_back(branch->apply_f_op_g(gL, op)); + auto h = boost::make_shared(label(), nrChoices()); + for (auto&& branch : branches_) + h->push_back(branch->apply_f_op_g(gL, op)); return Unique(h); } @@ -357,9 +368,9 @@ namespace gtsam { return branches_[index]; // choose branch // second case, not label of interest, just recurse - boost::shared_ptr r(new Choice(label_, branches_.size())); - for(const NodePtr& branch: branches_) - r->push_back(branch->choose(label, index)); + auto r = boost::make_shared(label_, branches_.size()); + for (auto&& branch : branches_) + r->push_back(branch->choose(label, index)); return Unique(r); } @@ -384,10 +395,9 @@ namespace gtsam { } /*********************************************************************************/ - template - DecisionTree::DecisionTree(// - const L& label, const Y& y1, const Y& y2) { - boost::shared_ptr a(new Choice(label, 2)); + template + DecisionTree::DecisionTree(const L& label, const Y& y1, const Y& y2) { + auto a = boost::make_shared(label, 2); NodePtr l1(new Leaf(y1)), l2(new Leaf(y2)); a->push_back(l1); a->push_back(l2); @@ -395,12 +405,12 @@ namespace gtsam { } /*********************************************************************************/ - template - DecisionTree::DecisionTree(// - const LabelC& labelC, const Y& y1, const Y& y2) { + template + DecisionTree::DecisionTree(const LabelC& labelC, const Y& y1, + const Y& y2) { if (labelC.second != 2) throw std::invalid_argument( "DecisionTree: binary constructor called with non-binary label"); - boost::shared_ptr a(new Choice(labelC.first, 2)); + auto a = boost::make_shared(labelC.first, 2); NodePtr l1(new Leaf(y1)), l2(new Leaf(y2)); a->push_back(l1); a->push_back(l2); @@ -447,11 +457,22 @@ namespace gtsam { } /*********************************************************************************/ - template - template + template + template + DecisionTree::DecisionTree(const DecisionTree& other, + Func Y_of_X) { + // Define functor for identity mapping of node label. + auto L_of_L = [](const L& label) { return label; }; + root_ = convertFrom(other.root_, L_of_L, Y_of_X); + } + + /*********************************************************************************/ + template + template DecisionTree::DecisionTree(const DecisionTree& other, - const std::map& map, std::function op) { - root_ = convert(other.root_, map, op); + const std::map& map, Func Y_of_X) { + auto L_of_M = [&map](const M& label) -> L { return map.at(label); }; + root_ = convertFrom(other.root_, L_of_M, Y_of_X); } /*********************************************************************************/ @@ -480,13 +501,14 @@ namespace gtsam { // if label is already in correct order, just put together a choice on label if (!nrChoices || !highestLabel || label > *highestLabel) { - boost::shared_ptr choiceOnLabel(new Choice(label, end - begin)); + auto choiceOnLabel = boost::make_shared(label, end - begin); for (Iterator it = begin; it != end; it++) choiceOnLabel->push_back(it->root_); return Choice::Unique(choiceOnLabel); } else { // Set up a new choice on the highest label - boost::shared_ptr choiceOnHighestLabel(new Choice(*highestLabel, nrChoices)); + auto choiceOnHighestLabel = + boost::make_shared(*highestLabel, nrChoices); // now, for all possible values of highestLabel for (size_t index = 0; index < nrChoices; index++) { // make a new set of functions for composing by iterating over the given @@ -545,7 +567,7 @@ namespace gtsam { std::cout << boost::format("DecisionTree::create: expected %d values but got %d instead") % nrChoices % size << std::endl; throw std::invalid_argument("DecisionTree::create invalid argument"); } - boost::shared_ptr choice(new Choice(begin->first, endY - beginY)); + auto choice = boost::make_shared(begin->first, endY - beginY); for (ValueIt y = beginY; y != endY; y++) choice->push_back(NodePtr(new Leaf(*y))); return Choice::Unique(choice); @@ -558,56 +580,136 @@ namespace gtsam { size_t split = size / nrChoices; for (size_t i = 0; i < nrChoices; i++, beginY += split) { NodePtr f = create(labelC, end, beginY, beginY + split); - functions += DecisionTree(f); + functions.emplace_back(f); } return compose(functions.begin(), functions.end(), begin->first); } /*********************************************************************************/ - template - template - typename DecisionTree::NodePtr DecisionTree::convert( - const typename DecisionTree::NodePtr& f, const std::map& map, - std::function op) { - - typedef DecisionTree MX; - typedef typename MX::Leaf MXLeaf; - typedef typename MX::Choice MXChoice; - typedef typename MX::NodePtr MXNodePtr; - typedef DecisionTree LY; + template + template + typename DecisionTree::NodePtr DecisionTree::convertFrom( + const typename DecisionTree::NodePtr& f, + std::function L_of_M, + std::function Y_of_X) const { + using LY = DecisionTree; // ugliness below because apparently we can't have templated virtual functions // If leaf, apply unary conversion "op" and create a unique leaf - const MXLeaf* leaf = dynamic_cast (f.get()); - if (leaf) return NodePtr(new Leaf(op(leaf->constant()))); + using MXLeaf = typename DecisionTree::Leaf; + if (auto leaf = boost::dynamic_pointer_cast(f)) + return NodePtr(new Leaf(Y_of_X(leaf->constant()))); // Check if Choice - boost::shared_ptr choice = boost::dynamic_pointer_cast (f); + using MXChoice = typename DecisionTree::Choice; + auto choice = boost::dynamic_pointer_cast(f); if (!choice) throw std::invalid_argument( "DecisionTree::Convert: Invalid NodePtr"); // get new label - M oldLabel = choice->label(); - L newLabel = map.at(oldLabel); + const M oldLabel = choice->label(); + const L newLabel = L_of_M(oldLabel); // put together via Shannon expansion otherwise not sorted. std::vector functions; - for(const MXNodePtr& branch: choice->branches()) { - LY converted(convert(branch, map, op)); - functions += converted; + for(auto && branch: choice->branches()) { + functions.emplace_back(convertFrom(branch, L_of_M, Y_of_X)); } return LY::compose(functions.begin(), functions.end(), newLabel); } /*********************************************************************************/ - template - bool DecisionTree::equals(const DecisionTree& other, double tol) const { - return root_->equals(*other.root_, tol); + // Functor performing depth-first visit without Assignment argument. + template + struct Visit { + using F = std::function; + Visit(F f) : f(f) {} ///< Construct from folding function. + F f; ///< folding function object. + + /// Do a depth-first visit on the tree rooted at node. + void operator()(const typename DecisionTree::NodePtr& node) const { + using Leaf = typename DecisionTree::Leaf; + if (auto leaf = boost::dynamic_pointer_cast(node)) + return f(leaf->constant()); + + using Choice = typename DecisionTree::Choice; + auto choice = boost::dynamic_pointer_cast(node); + for (auto&& branch : choice->branches()) (*this)(branch); // recurse! + } + }; + + template + template + void DecisionTree::visit(Func f) const { + Visit visit(f); + visit(root_); } - template - void DecisionTree::print(const std::string& s) const { - root_->print(s); + /*********************************************************************************/ + // Functor performing depth-first visit with Assignment argument. + template + struct VisitWith { + using Choices = Assignment; + using F = std::function; + VisitWith(F f) : f(f) {} ///< Construct from folding function. + Choices choices; ///< Assignment, mutating through recursion. + F f; ///< folding function object. + + /// Do a depth-first visit on the tree rooted at node. + void operator()(const typename DecisionTree::NodePtr& node) { + using Leaf = typename DecisionTree::Leaf; + if (auto leaf = boost::dynamic_pointer_cast(node)) + return f(choices, leaf->constant()); + + using Choice = typename DecisionTree::Choice; + auto choice = boost::dynamic_pointer_cast(node); + for (size_t i = 0; i < choice->nrChoices(); i++) { + choices[choice->label()] = i; // Set assignment for label to i + (*this)(choice->branches()[i]); // recurse! + } + } + }; + + template + template + void DecisionTree::visitWith(Func f) const { + VisitWith visit(f); + visit(root_); + } + + /*********************************************************************************/ + // fold is just done with a visit + template + template + X DecisionTree::fold(Func f, X x0) const { + visit([&](const Y& y) { x0 = f(y, x0); }); + return x0; + } + + /*********************************************************************************/ + // labels is just done with a visit + template + std::set DecisionTree::labels() const { + std::set unique; + auto f = [&](const Assignment& choices, const Y&) { + for (auto&& kv : choices) unique.insert(kv.first); + }; + visitWith(f); + return unique; + } + +/*********************************************************************************/ + template + bool DecisionTree::equals(const DecisionTree& other, + const CompareFunc& compare) const { + return root_->equals(*other.root_, compare); + } + + template + void DecisionTree::print(const std::string& s, + const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter) const { + root_->print(s, labelFormatter, valueFormatter); } template @@ -622,6 +724,11 @@ namespace gtsam { template DecisionTree DecisionTree::apply(const Unary& op) const { + // It is unclear what should happen if tree is empty: + if (empty()) { + throw std::runtime_error( + "DecisionTree::apply(unary op) undefined for empty tree."); + } return DecisionTree(root_->apply(op)); } @@ -629,6 +736,11 @@ namespace gtsam { template DecisionTree DecisionTree::apply(const DecisionTree& g, const Binary& op) const { + // It is unclear what should happen if either tree is empty: + if (empty() || g.empty()) { + throw std::runtime_error( + "DecisionTree::apply(binary op) undefined for empty trees."); + } // apply the operaton on the root of both diagrams NodePtr h = root_->apply_f_op_g(*g.root_, op); // create a new class with the resulting root "h" @@ -657,21 +769,36 @@ namespace gtsam { } /*********************************************************************************/ - template - void DecisionTree::dot(std::ostream& os, bool showZero) const { + template + void DecisionTree::dot(std::ostream& os, + const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero) const { os << "digraph G {\n"; - root_->dot(os, showZero); + root_->dot(os, labelFormatter, valueFormatter, showZero); os << " [ordering=out]}" << std::endl; } - template - void DecisionTree::dot(const std::string& name, bool showZero) const { + template + void DecisionTree::dot(const std::string& name, + const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero) const { std::ofstream os((name + ".dot").c_str()); - dot(os, showZero); + dot(os, labelFormatter, valueFormatter, showZero); int result = system( ("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str()); if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed"); -} + } + + template + std::string DecisionTree::dot(const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero) const { + std::stringstream ss; + dot(ss, labelFormatter, valueFormatter, showZero); + return ss.str(); + } /*********************************************************************************/ diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 0ee0b8be0..9692094e1 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -19,13 +19,16 @@ #pragma once +#include #include #include #include #include #include +#include #include +#include namespace gtsam { @@ -35,16 +38,26 @@ namespace gtsam { * Y = function range (any algebra), e.g., bool, int, double */ template - class DecisionTree { + class GTSAM_EXPORT DecisionTree { + + protected: + /// Default method for comparison of two objects of type Y. + static bool DefaultCompare(const Y& a, const Y& b) { + return a == b; + } public: + using LabelFormatter = std::function; + using ValueFormatter = std::function; + using CompareFunc = std::function; + /** Handy typedefs for unary and binary function types */ - typedef std::function Unary; - typedef std::function Binary; + using Unary = std::function; + using Binary = std::function; /** A label annotated with cardinality */ - typedef std::pair LabelC; + using LabelC = std::pair; /** DTs consist of Leaf and Choice nodes, both subclasses of Node */ class Leaf; @@ -53,7 +66,7 @@ namespace gtsam { /** ------------------------ Node base class --------------------------- */ class Node { public: - typedef boost::shared_ptr Ptr; + using Ptr = boost::shared_ptr; #ifdef DT_DEBUG_MEMORY static int nrNodes; @@ -77,11 +90,16 @@ namespace gtsam { const void* id() const { return this; } // everything else is virtual, no documentation here as internal - virtual void print(const std::string& s = "") const = 0; - virtual void dot(std::ostream& os, bool showZero) const = 0; + virtual void print(const std::string& s, + const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter) const = 0; + virtual void dot(std::ostream& os, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero) const = 0; virtual bool sameLeaf(const Leaf& q) const = 0; virtual bool sameLeaf(const Node& q) const = 0; - virtual bool equals(const Node& other, double tol = 1e-9) const = 0; + virtual bool equals(const Node& other, const CompareFunc& compare = + &DefaultCompare) const = 0; virtual const Y& operator()(const Assignment& x) const = 0; virtual Ptr apply(const Unary& op) const = 0; virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0; @@ -95,9 +113,9 @@ namespace gtsam { public: /** A function is a shared pointer to the root of a DT */ - typedef typename Node::Ptr NodePtr; + using NodePtr = typename Node::Ptr; - /* a DecisionTree just contains the root */ + /// A DecisionTree just contains the root. TODO(dellaert): make protected. NodePtr root_; protected: @@ -106,19 +124,29 @@ namespace gtsam { template NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const; - /** Convert to a different type */ - template NodePtr - convert(const typename DecisionTree::NodePtr& f, const std::map& map, std::function op); + /** + * @brief Convert from a DecisionTree to DecisionTree. + * + * @tparam M The previous label type. + * @tparam X The previous value type. + * @param f The node pointer to the root of the previous DecisionTree. + * @param L_of_M Functor to convert from label type M to type L. + * @param Y_of_X Functor to convert from value type X to type Y. + * @return NodePtr + */ + template + NodePtr convertFrom(const typename DecisionTree::NodePtr& f, + std::function L_of_M, + std::function Y_of_X) const; - /** Default constructor */ - DecisionTree(); - - public: + public: /// @name Standard Constructors /// @{ + /** Default constructor (for serialization) */ + DecisionTree(); + /** Create a constant */ DecisionTree(const Y& y); @@ -142,20 +170,47 @@ namespace gtsam { DecisionTree(const L& label, // const DecisionTree& f0, const DecisionTree& f1); - /** Convert from a different type */ - template - DecisionTree(const DecisionTree& other, - const std::map& map, std::function op); + /** + * @brief Convert from a different value type. + * + * @tparam X The previous value type. + * @param other The DecisionTree to convert from. + * @param Y_of_X Functor to convert from value type X to type Y. + */ + template + DecisionTree(const DecisionTree& other, Func Y_of_X); + + /** + * @brief Convert from a different value type X to value type Y, also transate + * labels via map from type M to L. + * + * @tparam M Previous label type. + * @tparam X Previous value type. + * @param other The decision tree to convert. + * @param L_of_M Map from label type M to type L. + * @param Y_of_X Functor to convert from type X to type Y. + */ + template + DecisionTree(const DecisionTree& other, const std::map& map, + Func Y_of_X); /// @} /// @name Testable /// @{ - /** GTSAM-style print */ - void print(const std::string& s = "DecisionTree") const; + /** + * @brief GTSAM-style print + * + * @param s Prefix string. + * @param labelFormatter Functor to format the node label. + * @param valueFormatter Functor to format the node value. + */ + void print(const std::string& s, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter) const; // Testable - bool equals(const DecisionTree& other, double tol = 1e-9) const; + bool equals(const DecisionTree& other, + const CompareFunc& compare = &DefaultCompare) const; /// @} /// @name Standard Interface @@ -165,12 +220,61 @@ namespace gtsam { virtual ~DecisionTree() { } + /// Check if tree is empty. + bool empty() const { return !root_; } + /** equality */ bool operator==(const DecisionTree& q) const; /** evaluate */ const Y& operator()(const Assignment& x) const; + /** + * @brief Visit all leaves in depth-first fashion. + * + * @param f side-effect taking a value. + * + * Example: + * int sum = 0; + * auto visitor = [&](int y) { sum += y; }; + * tree.visitWith(visitor); + */ + template + void visit(Func f) const; + + /** + * @brief Visit all leaves in depth-first fashion. + * + * @param f side-effect taking an assignment and a value. + * + * Example: + * int sum = 0; + * auto visitor = [&](const Assignment& choices, int y) { sum += y; }; + * tree.visitWith(visitor); + */ + template + void visitWith(Func f) const; + + /** + * @brief Fold a binary function over the tree, returning accumulator. + * + * @tparam X type for accumulator. + * @param f binary function: Y * X -> X returning an updated accumulator. + * @param x0 initial value for accumulator. + * @return X final value for accumulator. + * + * @note X is always passed by value. + * + * Example: + * auto add = [](const double& y, double x) { return y + x; }; + * double sum = tree.fold(add, 0.0); + */ + template + X fold(Func f, X x0) const; + + /** Retrieve all unique labels as a set. */ + std::set labels() const; + /** apply Unary operation "op" to f */ DecisionTree apply(const Unary& op) const; @@ -193,10 +297,17 @@ namespace gtsam { } /** output to graphviz format, stream version */ - void dot(std::ostream& os, bool showZero = true) const; + void dot(std::ostream& os, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, bool showZero = true) const; /** output to graphviz format, open a file */ - void dot(const std::string& name, bool showZero = true) const; + void dot(const std::string& name, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, bool showZero = true) const; + + /** output to graphviz format string */ + std::string dot(const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero = true) const; /// @name Advanced Interface /// @{ @@ -214,13 +325,15 @@ namespace gtsam { /** free versions of apply */ - template + /// Apply unary operator `op` to DecisionTree `f`. + template DecisionTree apply(const DecisionTree& f, const typename DecisionTree::Unary& op) { return f.apply(op); } - template + /// Apply binary operator `op` to DecisionTree `f`. + template DecisionTree apply(const DecisionTree& f, const DecisionTree& g, const typename DecisionTree::Binary& op) { diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index b7b9d7034..2607a80ef 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -34,12 +34,13 @@ namespace gtsam { /* ******************************************************************************** */ DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials) : - DiscreteFactor(keys.indices()), Potentials(keys, potentials) { + DiscreteFactor(keys.indices()), ADT(potentials), + cardinalities_(keys.cardinalities()) { } /* *************************************************************************/ DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) : - DiscreteFactor(c.keys()), Potentials(c) { + DiscreteFactor(c.keys()), AlgebraicDecisionTree(c), cardinalities_(c.cardinalities_) { } /* ************************************************************************* */ @@ -48,16 +49,24 @@ namespace gtsam { return false; } else { - const DecisionTreeFactor& f(static_cast(other)); - return Potentials::equals(f, tol); + const auto& f(static_cast(other)); + return ADT::equals(f, tol); } } + /* ************************************************************************* */ + double DecisionTreeFactor::safe_div(const double &a, const double &b) { + // The use for safe_div is when we divide the product factor by the sum + // factor. If the product or sum is zero, we accord zero probability to the + // event. + return (a == 0 || b == 0) ? 0 : (a / b); + } + /* ************************************************************************* */ void DecisionTreeFactor::print(const string& s, const KeyFormatter& formatter) const { cout << s; - Potentials::print("Potentials:",formatter); + ADT::print("Potentials:",formatter); } /* ************************************************************************* */ @@ -134,5 +143,90 @@ namespace gtsam { return boost::make_shared(dkeys, result); } -/* ************************************************************************* */ + /* ************************************************************************* */ + std::vector> DecisionTreeFactor::enumerate() const { + // Get all possible assignments + std::vector> pairs; + for (auto& key : keys()) { + pairs.emplace_back(key, cardinalities_.at(key)); + } + // Reverse to make cartesianProduct output a more natural ordering. + std::vector> rpairs(pairs.rbegin(), pairs.rend()); + const auto assignments = cartesianProduct(rpairs); + + // Construct unordered_map with values + std::vector> result; + for (const auto& assignment : assignments) { + result.emplace_back(assignment, operator()(assignment)); + } + return result; + } + + /* ************************************************************************* */ + static std::string valueFormatter(const double& v) { + return (boost::format("%4.2g") % v).str(); + } + + /** output to graphviz format, stream version */ + void DecisionTreeFactor::dot(std::ostream& os, + const KeyFormatter& keyFormatter, + bool showZero) const { + ADT::dot(os, keyFormatter, valueFormatter, showZero); + } + + /** output to graphviz format, open a file */ + void DecisionTreeFactor::dot(const std::string& name, + const KeyFormatter& keyFormatter, + bool showZero) const { + ADT::dot(name, keyFormatter, valueFormatter, showZero); + } + + /** output to graphviz format string */ + std::string DecisionTreeFactor::dot(const KeyFormatter& keyFormatter, + bool showZero) const { + return ADT::dot(keyFormatter, valueFormatter, showZero); + } + + /* ************************************************************************* */ + string DecisionTreeFactor::markdown(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; + + // Print out header and construct argument for `cartesianProduct`. + ss << "|"; + for (auto& key : keys()) { + ss << keyFormatter(key) << "|"; + } + ss << "value|\n"; + + // Print out separator with alignment hints. + ss << "|"; + for (size_t j = 0; j < size(); j++) ss << ":-:|"; + ss << ":-:|\n"; + + // Print out all rows. + auto rows = enumerate(); + for (const auto& kv : rows) { + ss << "|"; + auto assignment = kv.first; + for (auto& key : keys()) { + size_t index = assignment.at(key); + ss << Translate(names, key, index) << "|"; + } + ss << kv.second << "|\n"; + } + return ss.str(); + } + + DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const vector &table) : + DiscreteFactor(keys.indices()), AlgebraicDecisionTree(keys, table), + cardinalities_(keys.cardinalities()) { + } + + DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const string &table) : + DiscreteFactor(keys.indices()), AlgebraicDecisionTree(keys, table), + cardinalities_(keys.cardinalities()) { + } + + /* ************************************************************************* */ } // namespace gtsam diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index aa718e35d..f7c50d5b5 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -19,7 +19,8 @@ #pragma once #include -#include +#include +#include #include #include @@ -35,7 +36,7 @@ namespace gtsam { /** * A discrete probabilistic factor */ - class GTSAM_EXPORT DecisionTreeFactor: public DiscreteFactor, public Potentials { + class GTSAM_EXPORT DecisionTreeFactor: public DiscreteFactor, public AlgebraicDecisionTree { public: @@ -43,6 +44,10 @@ namespace gtsam { typedef DecisionTreeFactor This; typedef DiscreteFactor Base; ///< Typedef to base class typedef boost::shared_ptr shared_ptr; + typedef AlgebraicDecisionTree ADT; + + protected: + std::map cardinalities_; public: @@ -55,14 +60,23 @@ namespace gtsam { /** Constructor from Indices, Ordering, and AlgebraicDecisionDiagram */ DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials); - /** Constructor from Indices and (string or doubles) */ - template - DecisionTreeFactor(const DiscreteKeys& keys, SOURCE table) : - DiscreteFactor(keys.indices()), Potentials(keys, table) { - } + /** Constructor from doubles */ + DecisionTreeFactor(const DiscreteKeys& keys, const std::vector& table); + + /** Constructor from string */ + DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table); + + /// Single-key specialization + template + DecisionTreeFactor(const DiscreteKey& key, SOURCE table) + : DecisionTreeFactor(DiscreteKeys{key}, table) {} + + /// Single-key specialization, with vector of doubles. + DecisionTreeFactor(const DiscreteKey& key, const std::vector& row) + : DecisionTreeFactor(DiscreteKeys{key}, row) {} /** Construct from a DiscreteConditional type */ - DecisionTreeFactor(const DiscreteConditional& c); + explicit DecisionTreeFactor(const DiscreteConditional& c); /// @} /// @name Testable @@ -81,7 +95,7 @@ namespace gtsam { /// Value is just look up in AlgebraicDecisonTree double operator()(const DiscreteValues& values) const override { - return Potentials::operator()(values); + return ADT::operator()(values); } /// multiply two factors @@ -89,6 +103,10 @@ namespace gtsam { return apply(f, ADT::Ring::mul); } + static double safe_div(const double& a, const double& b); + + size_t cardinality(Key j) const { return cardinalities_.at(j);} + /// divide by factor f (safely) DecisionTreeFactor operator/(const DecisionTreeFactor& f) const { return apply(f, safe_div); @@ -162,7 +180,39 @@ namespace gtsam { // Potentials::reduceWithInverse(inverseReduction); // } + /// Enumerate all values into a map from values to double. + std::vector> enumerate() const; + /// @} + /// @name Wrapper support + /// @{ + + /** output to graphviz format, stream version */ + void dot(std::ostream& os, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + bool showZero = true) const; + + /** output to graphviz format, open a file */ + void dot(const std::string& name, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + bool showZero = true) const; + + /** output to graphviz format string */ + std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + bool showZero = true) const; + + /** + * @brief Render as markdown table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, category names corresponding to choices. + * @return std::string a markdown string. + */ + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override; + + /// @} + }; // DecisionTreeFactor diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index 219f2d93e..510fb5638 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -38,7 +38,7 @@ namespace gtsam { double DiscreteBayesNet::evaluate(const DiscreteValues & values) const { // evaluate all conditionals and multiply double result = 1.0; - for(DiscreteConditional::shared_ptr conditional: *this) + for(const DiscreteConditional::shared_ptr& conditional: *this) result *= (*conditional)(values); return result; } @@ -61,5 +61,16 @@ namespace gtsam { return result; } + /* ************************************************************************* */ + std::string DiscreteBayesNet::markdown( + const KeyFormatter& keyFormatter, + const DiscreteFactor::Names& names) const { + using std::endl; + std::stringstream ss; + ss << "`DiscreteBayesNet` of size " << size() << endl << endl; + for(const DiscreteConditional::shared_ptr& conditional: *this) + ss << conditional->markdown(keyFormatter, names) << endl; + return ss.str(); + } /* ************************************************************************* */ } // namespace diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index 2d92b72e8..5332b51dd 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -13,6 +13,7 @@ * @file DiscreteBayesNet.h * @date Feb 15, 2011 * @author Duy-Nguyen Ta + * @author Frank dellaert */ #pragma once @@ -22,6 +23,7 @@ #include #include #include +#include #include namespace gtsam { @@ -74,6 +76,11 @@ namespace gtsam { // Add inherited versions of add. using Base::add; + /** Add a DiscretePrior using a table or a string */ + void add(const DiscreteKey& key, const std::string& spec) { + emplace_shared(key, spec); + } + /** Add a DiscreteCondtional */ template void add(Args&&... args) { @@ -97,6 +104,14 @@ namespace gtsam { DiscreteValues sample() const; ///@} + /// @name Wrapper support + /// @{ + + /// Render as markdown table. + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteFactor::Names& names = {}) const; + + /// @} private: /** Serialization function */ diff --git a/gtsam/discrete/DiscreteBayesTree.cpp b/gtsam/discrete/DiscreteBayesTree.cpp index 48413405a..07d6e0f0e 100644 --- a/gtsam/discrete/DiscreteBayesTree.cpp +++ b/gtsam/discrete/DiscreteBayesTree.cpp @@ -55,8 +55,22 @@ namespace gtsam { return result; } -} // \namespace gtsam - - - + /* **************************************************************************/ + std::string DiscreteBayesTree::markdown( + const KeyFormatter& keyFormatter, + const DiscreteFactor::Names& names) const { + using std::endl; + std::stringstream ss; + ss << "`DiscreteBayesTree` of size " << nodes_.size() << endl << endl; + auto visitor = [&](const DiscreteBayesTreeClique::shared_ptr& clique, + size_t& indent) { + ss << "\n" << clique->conditional()->markdown(keyFormatter, names); + return indent + 1; + }; + size_t indent; + treeTraversal::DepthFirstForest(*this, indent, visitor); + return ss.str(); + } + /* **************************************************************************/ + } // namespace gtsam diff --git a/gtsam/discrete/DiscreteBayesTree.h b/gtsam/discrete/DiscreteBayesTree.h index 42ec7d417..6189f25d5 100644 --- a/gtsam/discrete/DiscreteBayesTree.h +++ b/gtsam/discrete/DiscreteBayesTree.h @@ -72,6 +72,8 @@ class GTSAM_EXPORT DiscreteBayesTree typedef DiscreteBayesTree This; typedef boost::shared_ptr shared_ptr; + /// @name Standard interface + /// @{ /** Default constructor, creates an empty Bayes tree */ DiscreteBayesTree() {} @@ -82,10 +84,19 @@ class GTSAM_EXPORT DiscreteBayesTree double evaluate(const DiscreteValues& values) const; //** (Preferred) sugar for the above for given DiscreteValues */ - double operator()(const DiscreteValues & values) const { + double operator()(const DiscreteValues& values) const { return evaluate(values); } + /// @} + /// @name Wrapper support + /// @{ + + /// Render as markdown table. + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteFactor::Names& names = {}) const; + + /// @} }; } // namespace gtsam diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 371b15ac0..951c0b6ca 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -80,7 +80,7 @@ void DiscreteConditional::print(const string& s, } } cout << ")"; - Potentials::print(""); + ADT::print(""); cout << endl; } @@ -97,45 +97,90 @@ bool DiscreteConditional::equals(const DiscreteFactor& other, } /* ******************************************************************************** */ -Potentials::ADT DiscreteConditional::choose( - const DiscreteValues& parentsValues) const { +static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional, + const DiscreteValues& parentsValues) { // Get the big decision tree with all the levels, and then go down the // branches based on the value of the parent variables. - ADT pFS(*this); + DiscreteConditional::ADT adt(conditional); size_t value; - for (Key j : parents()) { + for (Key j : conditional.parents()) { try { value = parentsValues.at(j); - pFS = pFS.choose(j, value); // ADT keeps getting smaller. - } catch (exception&) { - cout << "Key: " << j << " Value: " << value << endl; + adt = adt.choose(j, value); // ADT keeps getting smaller. + } catch (std::out_of_range&) { parentsValues.print("parentsValues: "); throw runtime_error("DiscreteConditional::choose: parent value missing"); }; } - return pFS; + return adt; } /* ******************************************************************************** */ -DecisionTreeFactor::shared_ptr DiscreteConditional::chooseAsFactor( +DecisionTreeFactor::shared_ptr DiscreteConditional::choose( const DiscreteValues& parentsValues) const { - ADT pFS = choose(parentsValues); + // Get the big decision tree with all the levels, and then go down the + // branches based on the value of the parent variables. + ADT adt(*this); + size_t value; + for (Key j : parents()) { + try { + value = parentsValues.at(j); + adt = adt.choose(j, value); // ADT keeps getting smaller. + } catch (exception&) { + parentsValues.print("parentsValues: "); + throw runtime_error("DiscreteConditional::choose: parent value missing"); + }; + } // Convert ADT to factor. - if (nrFrontals() != 1) { - throw std::runtime_error("Expected only one frontal variable in choose."); + DiscreteKeys discreteKeys; + for (Key j : frontals()) { + discreteKeys.emplace_back(j, this->cardinality(j)); } - DiscreteKeys keys; - const Key frontalKey = keys_[0]; - size_t frontalCardinality = this->cardinality(frontalKey); - keys.push_back(DiscreteKey(frontalKey, frontalCardinality)); - return boost::make_shared(keys, pFS); + return boost::make_shared(discreteKeys, adt); +} + +/* ******************************************************************************** */ +DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( + const DiscreteValues& frontalValues) const { + // Get the big decision tree with all the levels, and then go down the + // branches based on the value of the frontal variables. + ADT adt(*this); + size_t value; + for (Key j : frontals()) { + try { + value = frontalValues.at(j); + adt = adt.choose(j, value); // ADT keeps getting smaller. + } catch (exception&) { + frontalValues.print("frontalValues: "); + throw runtime_error("DiscreteConditional::choose: frontal value missing"); + }; + } + + // Convert ADT to factor. + DiscreteKeys discreteKeys; + for (Key j : parents()) { + discreteKeys.emplace_back(j, this->cardinality(j)); + } + return boost::make_shared(discreteKeys, adt); +} + +/* ******************************************************************************** */ +DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( + size_t parent_value) const { + if (nrFrontals() != 1) + throw std::invalid_argument( + "Single value likelihood can only be invoked on single-variable " + "conditional"); + DiscreteValues values; + values.emplace(keys_[0], parent_value); + return likelihood(values); } /* ******************************************************************************** */ void DiscreteConditional::solveInPlace(DiscreteValues* values) const { // TODO: Abhijit asks: is this really the fastest way? He thinks it is. - ADT pFS = choose(*values); // P(F|S=parentsValues) + ADT pFS = Choose(*this, *values); // P(F|S=parentsValues) // Initialize DiscreteValues mpe; @@ -147,10 +192,10 @@ void DiscreteConditional::solveInPlace(DiscreteValues* values) const { keys & dk; } // Get all Possible Configurations - vector allPosbValues = cartesianProduct(keys); + const auto allPosbValues = cartesianProduct(keys); // Find the MPE - for(DiscreteValues& frontalVals: allPosbValues) { + for(const auto& frontalVals: allPosbValues) { double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues) // Update MPE solution if better if (pValueS > maxP) { @@ -177,7 +222,7 @@ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const { size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const { // TODO: is this really the fastest way? I think it is. - ADT pFS = choose(parentsValues); // P(F|S=parentsValues) + ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) // Then, find the max over all remaining // TODO, only works for one key now, seems horribly slow this way @@ -203,10 +248,14 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const { static mt19937 rng(2); // random number generator // Get the correct conditional density - ADT pFS = choose(parentsValues); // P(F|S=parentsValues) + ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) // TODO(Duy): only works for one key now, seems horribly slow this way - assert(nrFrontals() == 1); + if (nrFrontals() != 1) { + throw std::invalid_argument( + "DiscreteConditional::sample can only be called on single variable " + "conditionals"); + } Key key = firstFrontalKey(); size_t nj = cardinality(key); vector p(nj); @@ -223,5 +272,105 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const { } /* ******************************************************************************** */ +size_t DiscreteConditional::sample(size_t parent_value) const { + if (nrParents() != 1) + throw std::invalid_argument( + "Single value sample() can only be invoked on single-parent " + "conditional"); + DiscreteValues values; + values.emplace(keys_.back(), parent_value); + return sample(values); +} -}// namespace +/* ******************************************************************************** */ +size_t DiscreteConditional::sample() const { + if (nrParents() != 0) + throw std::invalid_argument( + "sample() can only be invoked on no-parent prior"); + DiscreteValues values; + return sample(values); +} + +/* ************************************************************************* */ +std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter, + const Names& names) const { + std::stringstream ss; + + // Print out signature. + ss << " *P("; + bool first = true; + for (Key key : frontals()) { + if (!first) ss << ","; + ss << keyFormatter(key); + first = false; + } + if (nrParents() == 0) { + // We have no parents, call factor method. + ss << ")*:\n" << std::endl; + ss << DecisionTreeFactor::markdown(keyFormatter, names); + return ss.str(); + } + + // We have parents, continue signature and do custom print. + ss << "|"; + first = true; + for (Key parent : parents()) { + if (!first) ss << ","; + ss << keyFormatter(parent); + first = false; + } + ss << ")*:\n" << std::endl; + + // Print out header and construct argument for `cartesianProduct`. + std::vector> pairs; + ss << "|"; + const_iterator it; + for(Key parent: parents()) { + ss << "*" << keyFormatter(parent) << "*|"; + pairs.emplace_back(parent, cardinalities_.at(parent)); + } + + size_t n = 1; + for(Key key: frontals()) { + size_t k = cardinalities_.at(key); + pairs.emplace_back(key, k); + n *= k; + } + std::vector> slatnorf(pairs.rbegin(), + pairs.rend() - nrParents()); + const auto frontal_assignments = cartesianProduct(slatnorf); + for (const auto& a : frontal_assignments) { + for (it = beginFrontals(); it != endFrontals(); ++it) { + size_t index = a.at(*it); + ss << Translate(names, *it, index); + } + ss << "|"; + } + ss << "\n"; + + // Print out separator with alignment hints. + ss << "|"; + for (size_t j = 0; j < nrParents() + n; j++) ss << ":-:|"; + ss << "\n"; + + // Print out all rows. + std::vector> rpairs(pairs.rbegin(), pairs.rend()); + const auto assignments = cartesianProduct(rpairs); + size_t count = 0; + for (const auto& a : assignments) { + if (count == 0) { + ss << "|"; + for (it = beginParents(); it != endParents(); ++it) { + size_t index = a.at(*it); + ss << Translate(names, *it, index) << "|"; + } + } + ss << operator()(a) << "|"; + count = (count + 1) % n; + if (count == 0) ss << "\n"; + } + return ss.str(); +} +/* ************************************************************************* */ + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 06928e2e7..4c2e964fd 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -62,8 +62,6 @@ public: * conditional probability table (CPT) in 00 01 10 11 order. For * three-valued, it would be 00 01 02 10 11 12 20 21 22, etc.... * - * The first string is parsed to add a key and parents. - * * Example: DiscreteConditional P(D, {B,E}, table); */ DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents, @@ -75,8 +73,7 @@ public: * probability table (CPT) in 00 01 10 11 order. For three-valued, it would * be 00 01 02 10 11 12 20 21 22, etc.... * - * The first string is parsed to add a key and parents. The second string - * parses into a table. + * The string is parsed into a Signature::Table. * * Example: DiscreteConditional P(D, {B,E}, "9/1 2/8 3/7 1/9"); */ @@ -84,6 +81,10 @@ public: const std::string& spec) : DiscreteConditional(Signature(key, parents, spec)) {} + /// No-parent specialization; can also use DiscretePrior. + DiscreteConditional(const DiscreteKey& key, const std::string& spec) + : DiscreteConditional(Signature(key, {}, spec)) {} + /** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */ DiscreteConditional(const DecisionTreeFactor& joint, const DecisionTreeFactor& marginal); @@ -127,7 +128,7 @@ public: /// Evaluate, just look up in AlgebraicDecisonTree double operator()(const DiscreteValues& values) const override { - return Potentials::operator()(values); + return ADT::operator()(values); } /** Convert to a factor */ @@ -135,13 +136,17 @@ public: return DecisionTreeFactor::shared_ptr(new DecisionTreeFactor(*this)); } - /** Restrict to given parent values, returns AlgebraicDecisionDiagram */ - ADT choose(const DiscreteValues& parentsValues) const; - /** Restrict to given parent values, returns DecisionTreeFactor */ - DecisionTreeFactor::shared_ptr chooseAsFactor( + DecisionTreeFactor::shared_ptr choose( const DiscreteValues& parentsValues) const; + /** Convert to a likelihood factor by providing value before bar. */ + DecisionTreeFactor::shared_ptr likelihood( + const DiscreteValues& frontalValues) const; + + /** Single variable version of likelihood. */ + DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const; + /** * solve a conditional * @param parentsValues Known values of the parents @@ -156,6 +161,13 @@ public: */ size_t sample(const DiscreteValues& parentsValues) const; + + /// Single parent version. + size_t sample(size_t parent_value) const; + + /// Zero parent version. + size_t sample() const; + /// @} /// @name Advanced Interface /// @{ @@ -167,7 +179,14 @@ public: void sampleInPlace(DiscreteValues* parentsValues) const; /// @} + /// @name Wrapper support + /// @{ + /// Render as markdown table. + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override; + + /// @} }; // DiscreteConditional diff --git a/gtsam/discrete/DiscreteFactor.cpp b/gtsam/discrete/DiscreteFactor.cpp index c101653d2..1a12ef405 100644 --- a/gtsam/discrete/DiscreteFactor.cpp +++ b/gtsam/discrete/DiscreteFactor.cpp @@ -19,9 +19,20 @@ #include +#include + using namespace std; namespace gtsam { -/* ************************************************************************* */ -} // namespace gtsam +string DiscreteFactor::Translate(const Names& names, Key key, size_t index) { + if (names.empty()) { + stringstream ss; + ss << index; + return ss.str(); + } else { + return names.at(key)[index]; + } +} + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index e2be94b94..e30c0a6fe 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -73,9 +73,6 @@ public: Base::print(s, formatter); } - /** Test whether the factor is empty */ - virtual bool empty() const { return size() == 0; } - /// @} /// @name Standard Interface /// @{ @@ -88,6 +85,27 @@ public: virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; + /// @} + /// @name Wrapper support + /// @{ + + /// Translation table from values to strings. + using Names = std::map>; + + /// Translate an integer index value for given key to a string. + static std::string Translate(const Names& names, Key key, size_t index); + + /** + * @brief Render as markdown table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, category names corresponding to choices. + * @return std::string a markdown string. + */ + virtual std::string markdown( + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const = 0; + /// @} }; // DiscreteFactor diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 77127ac30..be046d290 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -16,15 +16,17 @@ * @author Frank Dellaert */ -//#define ENABLE_TIMING -#include -#include #include +#include #include +#include #include -#include #include -#include +#include + +using std::vector; +using std::string; +using std::map; namespace gtsam { @@ -64,7 +66,7 @@ namespace gtsam { } /* ************************************************************************* */ - void DiscreteFactorGraph::print(const std::string& s, + void DiscreteFactorGraph::print(const string& s, const KeyFormatter& formatter) const { std::cout << s << std::endl; std::cout << "size: " << size() << std::endl; @@ -129,6 +131,19 @@ namespace gtsam { return std::make_pair(cond, sum); } -/* ************************************************************************* */ -} // namespace + /* ************************************************************************* */ + string DiscreteFactorGraph::markdown( + const KeyFormatter& keyFormatter, + const DiscreteFactor::Names& names) const { + using std::endl; + std::stringstream ss; + ss << "`DiscreteFactorGraph` of size " << size() << endl << endl; + for (size_t i = 0; i < factors_.size(); i++) { + ss << "factor " << i << ":\n"; + ss << factors_[i]->markdown(keyFormatter, names) << endl; + } + return ss.str(); + } + /* ************************************************************************* */ + } // namespace gtsam diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index ff0aaef19..9aa04d649 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -24,7 +24,10 @@ #include #include #include + #include +#include +#include namespace gtsam { @@ -101,29 +104,12 @@ public: /// @} - // Add single key decision-tree factor. - template - void add(const DiscreteKey& j, SOURCE table) { - DiscreteKeys keys; - keys.push_back(j); - emplace_shared(keys, table); + /** Add a decision-tree factor */ + template + void add(Args&&... args) { + emplace_shared(std::forward(args)...); } - - // Add binary key decision-tree factor. - template - void add(const DiscreteKey& j1, const DiscreteKey& j2, SOURCE table) { - DiscreteKeys keys; - keys.push_back(j1); - keys.push_back(j2); - emplace_shared(keys, table); - } - - // Add shared discreteFactor immediately from arguments. - template - void add(const DiscreteKeys& keys, SOURCE table) { - emplace_shared(keys, table); - } - + /** Return the set of variables involved in the factors (set union) */ KeySet keys() const; @@ -154,6 +140,20 @@ public: // /** Apply a reduction, which is a remapping of variable indices. */ // GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& inverseReduction); + /// @name Wrapper support + /// @{ + + /** + * @brief Render as markdown table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, a map from Key to category names. + * @return std::string a (potentially long) markdown string. + */ + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteFactor::Names& names = {}) const; + + /// @} }; // \ DiscreteFactorGraph /// traits diff --git a/gtsam/discrete/DiscreteKey.h b/gtsam/discrete/DiscreteKey.h index 3462166f4..ae4dac38f 100644 --- a/gtsam/discrete/DiscreteKey.h +++ b/gtsam/discrete/DiscreteKey.h @@ -34,32 +34,30 @@ namespace gtsam { using DiscreteKey = std::pair; /// DiscreteKeys is a set of keys that can be assembled using the & operator - struct DiscreteKeys: public std::vector { + struct GTSAM_EXPORT DiscreteKeys: public std::vector { // Forward all constructors. using std::vector::vector; /// Constructor for serialization - GTSAM_EXPORT DiscreteKeys() : std::vector::vector() {} + DiscreteKeys() : std::vector::vector() {} /// Construct from a key - GTSAM_EXPORT DiscreteKeys(const DiscreteKey& key) { - push_back(key); - } + explicit DiscreteKeys(const DiscreteKey& key) { push_back(key); } /// Construct from a vector of keys - GTSAM_EXPORT DiscreteKeys(const std::vector& keys) : + DiscreteKeys(const std::vector& keys) : std::vector(keys) { } /// Construct from cardinalities with default names - GTSAM_EXPORT DiscreteKeys(const std::vector& cs); + DiscreteKeys(const std::vector& cs); /// Return a vector of indices - GTSAM_EXPORT KeyVector indices() const; + KeyVector indices() const; /// Return a map from index to cardinality - GTSAM_EXPORT std::map cardinalities() const; + std::map cardinalities() const; /// Add a key (non-const!) DiscreteKeys& operator&(const DiscreteKey& key) { @@ -69,5 +67,5 @@ namespace gtsam { }; // DiscreteKeys /// Create a list from two keys - GTSAM_EXPORT DiscreteKeys operator&(const DiscreteKey& key1, const DiscreteKey& key2); + DiscreteKeys operator&(const DiscreteKey& key1, const DiscreteKey& key2); } diff --git a/gtsam/discrete/DiscreteMarginals.h b/gtsam/discrete/DiscreteMarginals.h index b118909bc..27352a211 100644 --- a/gtsam/discrete/DiscreteMarginals.h +++ b/gtsam/discrete/DiscreteMarginals.h @@ -29,7 +29,7 @@ namespace gtsam { /** * A class for computing marginals of variables in a DiscreteFactorGraph */ - class DiscreteMarginals { +class GTSAM_EXPORT DiscreteMarginals { protected: diff --git a/gtsam/discrete/DiscretePrior.cpp b/gtsam/discrete/DiscretePrior.cpp new file mode 100644 index 000000000..3941e0199 --- /dev/null +++ b/gtsam/discrete/DiscretePrior.cpp @@ -0,0 +1,50 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscretePrior.cpp + * @date December 2021 + * @author Frank Dellaert + */ + +#include + +namespace gtsam { + +void DiscretePrior::print(const std::string& s, + const KeyFormatter& formatter) const { + Base::print(s, formatter); +} + +double DiscretePrior::operator()(size_t value) const { + if (nrFrontals() != 1) + throw std::invalid_argument( + "Single value operator can only be invoked on single-variable " + "priors"); + DiscreteValues values; + values.emplace(keys_[0], value); + return Base::operator()(values); +} + +std::vector DiscretePrior::pmf() const { + if (nrFrontals() != 1) + throw std::invalid_argument( + "DiscretePrior::pmf only defined for single-variable priors"); + const size_t nrValues = cardinalities_.at(keys_[0]); + std::vector array; + array.reserve(nrValues); + for (size_t v = 0; v < nrValues; v++) { + array.push_back(operator()(v)); + } + return array; +} + +} // namespace gtsam diff --git a/gtsam/discrete/DiscretePrior.h b/gtsam/discrete/DiscretePrior.h new file mode 100644 index 000000000..9ac8acb17 --- /dev/null +++ b/gtsam/discrete/DiscretePrior.h @@ -0,0 +1,111 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscretePrior.h + * @date December 2021 + * @author Frank Dellaert + */ + +#pragma once + +#include + +#include + +namespace gtsam { + +/** + * A prior probability on a set of discrete variables. + * Derives from DiscreteConditional + */ +class GTSAM_EXPORT DiscretePrior : public DiscreteConditional { + public: + using Base = DiscreteConditional; + + /// @name Standard Constructors + /// @{ + + /// Default constructor needed for serialization. + DiscretePrior() {} + + /// Constructor from factor. + DiscretePrior(const DecisionTreeFactor& f) : Base(f.size(), f) {} + + /** + * Construct from a Signature. + * + * Example: DiscretePrior P(D % "3/2"); + */ + DiscretePrior(const Signature& s) : Base(s) {} + + /** + * Construct from key and a Signature::Table specifying the + * conditional probability table (CPT). + * + * Example: DiscretePrior P(D, table); + */ + DiscretePrior(const DiscreteKey& key, const Signature::Table& table) + : Base(Signature(key, {}, table)) {} + + /** + * Construct from key and a string specifying the conditional + * probability table (CPT). + * + * Example: DiscretePrior P(D, "9/1 2/8 3/7 1/9"); + */ + DiscretePrior(const DiscreteKey& key, const std::string& spec) + : DiscretePrior(Signature(key, {}, spec)) {} + + /// @} + /// @name Testable + /// @{ + + /// GTSAM-style print + void print( + const std::string& s = "Discrete Prior: ", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; + + /// @} + /// @name Standard interface + /// @{ + + /// Evaluate given a single value. + double operator()(size_t value) const; + + /// We also want to keep the Base version, taking DiscreteValues: + // TODO(dellaert): does not play well with wrapper! + // using Base::operator(); + + /// Return entire probability mass function. + std::vector pmf() const; + + /** + * solve a conditional + * @return MPE value of the child (1 frontal variable). + */ + size_t solve() const { return Base::solve({}); } + + /** + * sample + * @return sample from conditional + */ + size_t sample() const { return Base::sample(); } + + /// @} +}; +// DiscretePrior + +// traits +template <> +struct traits : public Testable {}; + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteValues.h b/gtsam/discrete/DiscreteValues.h index a1ee22e01..2d9c8d3cf 100644 --- a/gtsam/discrete/DiscreteValues.h +++ b/gtsam/discrete/DiscreteValues.h @@ -32,7 +32,25 @@ namespace gtsam { * stores cardinality of a Discrete variable. It should be handled naturally in * the new class DiscreteValue, as the variable's type (domain) */ -using DiscreteValues = Assignment; +class DiscreteValues : public Assignment { + public: + using Assignment::Assignment; // all constructors + + // Define the implicit default constructor. + DiscreteValues() = default; + + // Construct from assignment. + DiscreteValues(const Assignment& a) : Assignment(a) {} + + void print(const std::string& s = "", + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { + std::cout << s << ": "; + for (const typename Assignment::value_type& keyValue : *this) + std::cout << "(" << keyFormatter(keyValue.first) << ", " + << keyValue.second << ")"; + std::cout << std::endl; + } +}; // traits template<> struct traits : public Testable {}; diff --git a/gtsam/discrete/Potentials.cpp b/gtsam/discrete/Potentials.cpp deleted file mode 100644 index 331a76c13..000000000 --- a/gtsam/discrete/Potentials.cpp +++ /dev/null @@ -1,100 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file Potentials.cpp - * @date March 24, 2011 - * @author Frank Dellaert - */ - -#include -#include - -#include - -#include - -using namespace std; - -namespace gtsam { - -// explicit instantiation -template class DecisionTree; -template class AlgebraicDecisionTree; - -/* ************************************************************************* */ -double Potentials::safe_div(const double& a, const double& b) { - // cout << boost::format("%g / %g = %g\n") % a % b % ((a == 0) ? 0 : (a / b)); - // The use for safe_div is when we divide the product factor by the sum - // factor. If the product or sum is zero, we accord zero probability to the - // event. - return (a == 0 || b == 0) ? 0 : (a / b); -} - -/* ******************************************************************************** - */ -Potentials::Potentials() : ADT(1.0) {} - -/* ******************************************************************************** - */ -Potentials::Potentials(const DiscreteKeys& keys, const ADT& decisionTree) - : ADT(decisionTree), cardinalities_(keys.cardinalities()) {} - -/* ************************************************************************* */ -bool Potentials::equals(const Potentials& other, double tol) const { - return ADT::equals(other, tol); -} - -/* ************************************************************************* */ -void Potentials::print(const string& s, const KeyFormatter& formatter) const { - cout << s << "\n Cardinalities: {"; - for (const std::pair& key : cardinalities_) - cout << formatter(key.first) << ":" << key.second << ", "; - cout << "}" << endl; - ADT::print(" "); -} -// -// /* ************************************************************************* */ -// template -// void Potentials::remapIndices(const P& remapping) { -// // Permute the _cardinalities (TODO: Inefficient Consider Improving) -// DiscreteKeys keys; -// map ordering; -// -// // Get the original keys from cardinalities_ -// for(const DiscreteKey& key: cardinalities_) -// keys & key; -// -// // Perform Permutation -// for(DiscreteKey& key: keys) { -// ordering[key.first] = remapping[key.first]; -// key.first = ordering[key.first]; -// } -// -// // Change *this -// AlgebraicDecisionTree permuted((*this), ordering); -// *this = permuted; -// cardinalities_ = keys.cardinalities(); -// } -// -// /* ************************************************************************* */ -// void Potentials::permuteWithInverse(const Permutation& inversePermutation) { -// remapIndices(inversePermutation); -// } -// -// /* ************************************************************************* */ -// void Potentials::reduceWithInverse(const internal::Reduction& inverseReduction) { -// remapIndices(inverseReduction); -// } - - /* ************************************************************************* */ - -} // namespace gtsam diff --git a/gtsam/discrete/Potentials.h b/gtsam/discrete/Potentials.h deleted file mode 100644 index 1078b4c61..000000000 --- a/gtsam/discrete/Potentials.h +++ /dev/null @@ -1,97 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file Potentials.h - * @date March 24, 2011 - * @author Frank Dellaert - */ - -#pragma once - -#include -#include -#include - -#include -#include - -namespace gtsam { - - /** - * A base class for both DiscreteFactor and DiscreteConditional - */ - class Potentials: public AlgebraicDecisionTree { - - public: - - typedef AlgebraicDecisionTree ADT; - - protected: - - /// Cardinality for each key, used in combine - std::map cardinalities_; - - /** Constructor from ColumnIndex, and ADT */ - Potentials(const ADT& potentials) : - ADT(potentials) { - } - - // Safe division for probabilities - GTSAM_EXPORT static double safe_div(const double& a, const double& b); - -// // Apply either a permutation or a reduction -// template -// void remapIndices(const P& remapping); - - public: - - /** Default constructor for I/O */ - GTSAM_EXPORT Potentials(); - - /** Constructor from Indices and ADT */ - GTSAM_EXPORT Potentials(const DiscreteKeys& keys, const ADT& decisionTree); - - /** Constructor from Indices and (string or doubles) */ - template - Potentials(const DiscreteKeys& keys, SOURCE table) : - ADT(keys, table), cardinalities_(keys.cardinalities()) { - } - - // Testable - GTSAM_EXPORT bool equals(const Potentials& other, double tol = 1e-9) const; - GTSAM_EXPORT void print(const std::string& s = "Potentials: ", - const KeyFormatter& formatter = DefaultKeyFormatter) const; - - size_t cardinality(Key j) const { return cardinalities_.at(j);} - -// /** -// * @brief Permutes the keys in Potentials -// * -// * This permutes the Indices and performs necessary re-ordering of ADD. -// * This is virtual so that derived types e.g. DecisionTreeFactor can -// * re-implement it. -// */ -// GTSAM_EXPORT virtual void permuteWithInverse(const Permutation& inversePermutation); -// -// /** -// * Apply a reduction, which is a remapping of variable indices. -// */ -// GTSAM_EXPORT virtual void reduceWithInverse(const internal::Reduction& inverseReduction); - - }; // Potentials - -// traits -template<> struct traits : public Testable {}; -template<> struct traits : public Testable {}; - - -} // namespace gtsam diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index daea84e70..d17401e44 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -30,30 +30,46 @@ class DiscreteFactor { }; #include -virtual class DecisionTreeFactor: gtsam::DiscreteFactor { +virtual class DecisionTreeFactor : gtsam::DiscreteFactor { DecisionTreeFactor(); + + DecisionTreeFactor(const gtsam::DiscreteKey& key, + const std::vector& spec); + DecisionTreeFactor(const gtsam::DiscreteKey& key, string table); + DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table); + DecisionTreeFactor(const std::vector& keys, string table); + DecisionTreeFactor(const gtsam::DiscreteConditional& c); + void print(string s = "DecisionTreeFactor\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; - double operator()(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat??? + string dot( + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + bool showZero = true) const; + std::vector> enumerate() const; + string markdown(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + string markdown(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; }; #include virtual class DiscreteConditional : gtsam::DecisionTreeFactor { DiscreteConditional(); DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f); + DiscreteConditional(const gtsam::DiscreteKey& key, string spec); DiscreteConditional(const gtsam::DiscreteKey& key, const gtsam::DiscreteKeys& parents, string spec); + DiscreteConditional(const gtsam::DiscreteKey& key, + const std::vector& parents, string spec); DiscreteConditional(const gtsam::DecisionTreeFactor& joint, const gtsam::DecisionTreeFactor& marginal); DiscreteConditional(const gtsam::DecisionTreeFactor& joint, const gtsam::DecisionTreeFactor& marginal, const gtsam::Ordering& orderedKeys); - size_t size() const; // TODO(dellaert): why do I have to repeat??? - double operator()(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat??? void print(string s = "Discrete Conditional\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; @@ -62,18 +78,45 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { string s = "Discrete Conditional: ", const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; gtsam::DecisionTreeFactor* toFactor() const; - gtsam::DecisionTreeFactor* chooseAsFactor(const gtsam::DiscreteValues& parentsValues) const; + gtsam::DecisionTreeFactor* choose( + const gtsam::DiscreteValues& parentsValues) const; + gtsam::DecisionTreeFactor* likelihood( + const gtsam::DiscreteValues& frontalValues) const; + gtsam::DecisionTreeFactor* likelihood(size_t value) const; size_t solve(const gtsam::DiscreteValues& parentsValues) const; size_t sample(const gtsam::DiscreteValues& parentsValues) const; - void solveInPlace(gtsam::DiscreteValues@ parentsValues) const; - void sampleInPlace(gtsam::DiscreteValues@ parentsValues) const; + size_t sample(size_t value) const; + size_t sample() const; + void solveInPlace(gtsam::DiscreteValues @parentsValues) const; + void sampleInPlace(gtsam::DiscreteValues @parentsValues) const; + string markdown(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + string markdown(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; +}; + +#include +virtual class DiscretePrior : gtsam::DiscreteConditional { + DiscretePrior(); + DiscretePrior(const gtsam::DecisionTreeFactor& f); + DiscretePrior(const gtsam::DiscreteKey& key, string spec); + void print(string s = "Discrete Prior\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + double operator()(size_t value) const; + std::vector pmf() const; + size_t solve() const; }; #include -class DiscreteBayesNet { +class DiscreteBayesNet { DiscreteBayesNet(); + void add(const gtsam::DiscreteConditional& s); + void add(const gtsam::DiscreteKey& key, string spec); + void add(const gtsam::DiscreteKey& key, const gtsam::DiscreteKeys& parents, + string spec); void add(const gtsam::DiscreteKey& key, - const gtsam::DiscreteKeys& parents, string spec); + const std::vector& parents, string spec); bool empty() const; size_t size() const; gtsam::KeySet keys() const; @@ -82,34 +125,73 @@ class DiscreteBayesNet { const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const; - void saveGraph(string s, - const gtsam::KeyFormatter& keyFormatter = + string dot(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; - void add(const gtsam::DiscreteConditional& s); + void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; double operator()(const gtsam::DiscreteValues& values) const; gtsam::DiscreteValues optimize() const; gtsam::DiscreteValues sample() const; + string markdown(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + string markdown(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; }; #include +class DiscreteBayesTreeClique { + DiscreteBayesTreeClique(); + DiscreteBayesTreeClique(const gtsam::DiscreteConditional* conditional); + const gtsam::DiscreteConditional* conditional() const; + bool isRoot() const; + void printSignature( + const string& s = "Clique: ", + const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; + double evaluate(const gtsam::DiscreteValues& values) const; +}; + class DiscreteBayesTree { DiscreteBayesTree(); void print(string s = "DiscreteBayesTree\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DiscreteBayesTree& other, double tol = 1e-9) const; + + size_t size() const; + bool empty() const; + const DiscreteBayesTreeClique* operator[](size_t j) const; + + string dot(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + void saveGraph(string s, + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; double operator()(const gtsam::DiscreteValues& values) const; + + string markdown(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + string markdown(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; +}; + +#include +class DotWriter { + DotWriter(double figureWidthInches = 5, double figureHeightInches = 5, + bool plotFactorPoints = true, bool connectKeysToFactor = true, + bool binaryEdges = true); }; #include class DiscreteFactorGraph { DiscreteFactorGraph(); DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet); - + void add(const gtsam::DiscreteKey& j, string table); - void add(const gtsam::DiscreteKey& j1, const gtsam::DiscreteKey& j2, string table); + void add(const gtsam::DiscreteKey& j, const std::vector& spec); + void add(const gtsam::DiscreteKeys& keys, string table); - + void add(const std::vector& keys, string table); + bool empty() const; size_t size() const; gtsam::KeySet keys() const; @@ -117,7 +199,15 @@ class DiscreteFactorGraph { void print(string s = "") const; bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const; - + + string dot( + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& dotWriter = gtsam::DotWriter()) const; + void saveGraph( + string s, + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& dotWriter = gtsam::DotWriter()) const; + gtsam::DecisionTreeFactor product() const; double operator()(const gtsam::DiscreteValues& values) const; gtsam::DiscreteValues optimize() const; @@ -126,6 +216,11 @@ class DiscreteFactorGraph { gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering); gtsam::DiscreteBayesTree eliminateMultifrontal(); gtsam::DiscreteBayesTree eliminateMultifrontal(const gtsam::Ordering& ordering); + + string markdown(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + string markdown(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; }; } // namespace gtsam diff --git a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp index 7a33810c7..910515b5c 100644 --- a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp +++ b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp @@ -136,8 +136,8 @@ ADT create(const Signature& signature) { ADT p(signature.discreteKeys(), signature.cpt()); static size_t count = 0; const DiscreteKey& key = signature.key(); - string dotfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str(); - dot(p, dotfile); + string DOTfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str(); + dot(p, DOTfile); return p; } @@ -414,13 +414,13 @@ TEST(ADT, equality_noparser) // Check straight equality ADT pA1 = create(A % tableA); ADT pA2 = create(A % tableA); - EXPECT(pA1 == pA2); // should be equal + EXPECT(pA1.equals(pA2)); // should be equal // Check equality after apply ADT pB = create(B % tableB); ADT pAB1 = apply(pA1, pB, &mul); ADT pAB2 = apply(pB, pA1, &mul); - EXPECT(pAB2 == pAB1); + EXPECT(pAB2.equals(pAB1)); } /* ************************************************************************* */ @@ -431,13 +431,13 @@ TEST(ADT, equality_parser) // Check straight equality ADT pA1 = create(A % "80/20"); ADT pA2 = create(A % "80/20"); - EXPECT(pA1 == pA2); // should be equal + EXPECT(pA1.equals(pA2)); // should be equal // Check equality after apply ADT pB = create(B % "60/40"); ADT pAB1 = apply(pA1, pB, &mul); ADT pAB2 = apply(pB, pA1, &mul); - EXPECT(pAB2 == pAB1); + EXPECT(pAB2.equals(pAB1)); } /* ******************************************************************************** */ diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index 96f503abc..2e6ec59f7 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -40,25 +40,69 @@ void dot(const T&f, const string& filename) { #define DOT(x)(dot(x,#x)) -struct Crazy { int a; double b; }; -typedef DecisionTree CrazyDecisionTree; // check that DecisionTree is actually generic (as it pretends to be) +struct Crazy { + int a; + double b; +}; + +struct CrazyDecisionTree : public DecisionTree { + /// print to stdout + void print(const std::string& s = "") const { + auto keyFormatter = [](const std::string& s) { return s; }; + auto valueFormatter = [](const Crazy& v) { + return (boost::format("{%d,%4.2g}") % v.a % v.b).str(); + }; + DecisionTree::print("", keyFormatter, valueFormatter); + } + /// Equality method customized to Crazy node type + bool equals(const CrazyDecisionTree& other, double tol = 1e-9) const { + auto compare = [tol](const Crazy& v, const Crazy& w) { + return v.a == w.a && std::abs(v.b - w.b) < tol; + }; + return DecisionTree::equals(other, compare); + } +}; // traits namespace gtsam { template<> struct traits : public Testable {}; } +GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree) + /* ******************************************************************************** */ // Test string labels and int range /* ******************************************************************************** */ -typedef DecisionTree DT; +struct DT : public DecisionTree { + using Base = DecisionTree; + using DecisionTree::DecisionTree; + DT() = default; + + DT(const Base& dt) : Base(dt) {} + + /// print to stdout + void print(const std::string& s = "") const { + auto keyFormatter = [](const std::string& s) { return s; }; + auto valueFormatter = [](const int& v) { + return (boost::format("%d") % v).str(); + }; + Base::print("", keyFormatter, valueFormatter); + } + /// Equality method customized to int node type + bool equals(const Base& other, double tol = 1e-9) const { + auto compare = [](const int& v, const int& w) { return v == w; }; + return Base::equals(other, compare); + } +}; // traits namespace gtsam { template<> struct traits
: public Testable
{}; } +GTSAM_CONCEPT_TESTABLE_INST(DT) + struct Ring { static inline int zero() { return 0; @@ -66,6 +110,9 @@ struct Ring { static inline int one() { return 1; } + static inline int id(const int& a) { + return a; + } static inline int add(const int& a, const int& b) { return a + b; } @@ -76,8 +123,7 @@ struct Ring { /* ******************************************************************************** */ // test DT -TEST(DT, example) -{ +TEST(DecisionTree, example) { // Create labels string A("A"), B("B"), C("C"); @@ -88,6 +134,9 @@ TEST(DT, example) x10[A] = 1, x10[B] = 0; x11[A] = 1, x11[B] = 1; + // empty + DT empty; + // A DT a(A, 0, 5); LONGS_EQUAL(0,a(x00)) @@ -106,6 +155,11 @@ TEST(DT, example) LONGS_EQUAL(5,notb(x10)) DOT(notb); + // Check supplying empty trees yields an exception + CHECK_EXCEPTION(apply(empty, &Ring::id), std::runtime_error); + CHECK_EXCEPTION(apply(empty, a, &Ring::mul), std::runtime_error); + CHECK_EXCEPTION(apply(a, empty, &Ring::mul), std::runtime_error); + // apply, two nodes, in natural order DT anotb = apply(a, notb, &Ring::mul); LONGS_EQUAL(0,anotb(x00)) @@ -175,17 +229,34 @@ TEST(DT, example) } /* ******************************************************************************** */ -// test Conversion +// test Conversion of values +bool bool_of_int(const int& y) { return y != 0; }; +typedef DecisionTree StringBoolTree; + +TEST(DecisionTree, ConvertValuesOnly) { + // Create labels + string A("A"), B("B"); + + // apply, two nodes, in natural order + DT f1 = apply(DT(A, 0, 5), DT(B, 5, 0), &Ring::mul); + + // convert + StringBoolTree f2(f1, bool_of_int); + + // Check a value + Assignment x00; + x00["A"] = 0, x00["B"] = 0; + EXPECT(!f2(x00)); +} + +/* ******************************************************************************** */ +// test Conversion of both values and labels. enum Label { U, V, X, Y, Z }; -typedef DecisionTree BDT; -bool convert(const int& y) { - return y != 0; -} +typedef DecisionTree LabelBoolTree; -TEST(DT, conversion) -{ +TEST(DecisionTree, ConvertBoth) { // Create labels string A("A"), B("B"); @@ -196,12 +267,9 @@ TEST(DT, conversion) map ordering; ordering[A] = X; ordering[B] = Y; - std::function op = convert; - BDT f2(f1, ordering, op); - // f1.print("f1"); - // f2.print("f2"); + LabelBoolTree f2(f1, ordering, &bool_of_int); - // create a value + // Check some values Assignment