diff --git a/.cproject b/.cproject index 5916e18da..c81094465 100644 --- a/.cproject +++ b/.cproject @@ -362,6 +362,14 @@ true true + + make + -j2 + testGaussianFactor.run + true + true + true + make -j2 @@ -388,7 +396,6 @@ make - tests/testBayesTree.run true false @@ -396,7 +403,6 @@ make - testBinaryBayesNet.run true false @@ -444,7 +450,6 @@ make - testSymbolicBayesNet.run true false @@ -452,7 +457,6 @@ make - tests/testSymbolicFactor.run true false @@ -460,7 +464,6 @@ make - testSymbolicFactorGraph.run true false @@ -476,20 +479,11 @@ make - tests/testBayesTree true false true - - make - -j2 - testGaussianFactor.run - true - true - true - make -j2 @@ -516,6 +510,7 @@ make + testGraph.run true false @@ -587,6 +582,7 @@ make + testInference.run true false @@ -594,6 +590,7 @@ make + testGaussianFactor.run true false @@ -601,6 +598,7 @@ make + testJunctionTree.run true false @@ -608,6 +606,7 @@ make + testSymbolicBayesNet.run true false @@ -615,6 +614,7 @@ make + testSymbolicFactorGraph.run true false @@ -684,22 +684,6 @@ false true - - make - -j2 - tests/testPose2.run - true - true - true - - - make - -j2 - tests/testPose3.run - true - true - true - make -j2 @@ -716,6 +700,22 @@ true true + + make + -j2 + tests/testPose2.run + true + true + true + + + make + -j2 + tests/testPose3.run + true + true + true + make -j2 @@ -740,26 +740,18 @@ true true - + make - -j2 - all + -j5 + nonlinear.testValues.run true true true - + make - -j2 - check - true - true - true - - - make - -j2 - clean + -j5 + nonlinear.testOrdering.run true true true @@ -796,18 +788,26 @@ true true - + make - -j5 - nonlinear.testValues.run + -j2 + all true true true - + make - -j5 - nonlinear.testOrdering.run + -j2 + check + true + true + true + + + make + -j2 + clean true true true @@ -844,30 +844,14 @@ true true - + make - -j2 - install - true - true - true - - - make - -j2 + -j5 check true true true - - make - -j2 - clean - true - true - true - make -j2 @@ -1044,14 +1028,6 @@ true true - - make - -j2 - SimpleRotation.run - true - true - true - make -j2 @@ -1142,7 +1118,6 @@ make - testErrors.run true false @@ -1598,6 +1573,7 @@ make + testSimulated2DOriented.run true false @@ -1637,6 +1613,7 @@ make + testSimulated2D.run true false @@ -1644,6 +1621,7 @@ make + testSimulated3D.run true false @@ -1657,6 +1635,14 @@ true true + + make + -j5 + testVector.run + true + true + true + make -j2 @@ -1817,6 +1803,14 @@ true true + + make + -j5 + UGM_small.run + true + true + true + make -j2 @@ -1827,6 +1821,7 @@ make + tests/testGaussianISAM2 true false @@ -1848,6 +1843,102 @@ true true + + make + -j2 + testRot3.run + true + true + true + + + make + -j2 + testRot2.run + true + true + true + + + make + -j2 + testPose3.run + true + true + true + + + make + -j2 + timeRot3.run + true + true + true + + + make + -j2 + testPose2.run + true + true + true + + + make + -j2 + testCal3_S2.run + true + true + true + + + make + -j2 + testSimpleCamera.run + true + true + true + + + make + -j2 + testHomography2.run + true + true + true + + + make + -j2 + testCalibratedCamera.run + true + true + true + + + make + -j2 + check + true + true + true + + + make + -j2 + clean + true + true + true + + + make + -j2 + testPoint2.run + true + true + true + make -j2 @@ -2031,98 +2122,134 @@ true true - + make - -j2 - testRot3.run + -j5 + wrap_gtsam true true true - + make - -j2 - testRot2.run + VERBOSE=1 + wrap_gtsam + true + false + true + + + cpack + -G DEB + true + false + true + + + cpack + -G RPM + true + false + true + + + cpack + -G TGZ + true + false + true + + + cpack + --config CPackSourceConfig.cmake + true + false + true + + + make + -j5 + check.discrete true true true - + make - -j2 - testPose3.run + -j5 + wrap_gtsam_unstable true true true - + make - -j2 - timeRot3.run + -j5 + check.wrap true true true - + make - -j2 - testPose2.run + -j5 + check.dynamics_unstable true true true - + make - -j2 - testCal3_S2.run + -j5 + check.slam_unstable true true true - + make - -j2 - testSimpleCamera.run + -j5 + check.base_unstable true true true - + make - -j2 - testHomography2.run + -j5 + testSpirit.run true true true - + make - -j2 - testCalibratedCamera.run + -j5 + testWrap.run true true true - + make - -j2 - check + -j5 + check.wrap true true true - + make - -j2 - clean + -j5 + wrap_gtsam true true true - + make - -j2 - testPoint2.run + -j5 + wrap true true true @@ -2166,46 +2293,6 @@ false true - - make - -j5 - wrap.testSpirit.run - true - true - true - - - make - -j5 - wrap.testWrap.run - true - true - true - - - make - -j5 - check.wrap - true - true - true - - - make - -j5 - wrap_gtsam - true - true - true - - - make - -j5 - wrap - true - true - true - diff --git a/CMakeLists.txt b/CMakeLists.txt index ebfd16e5f..45f907e14 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -7,10 +7,10 @@ set (GTSAM_VERSION_MINOR 9) set (GTSAM_VERSION_PATCH 0) # Set the default install path to home -set (CMAKE_INSTALL_PREFIX ${HOME} CACHE DOCSTRING "Install prefix for library") +#set (CMAKE_INSTALL_PREFIX ${HOME} CACHE PATH "Install prefix for library") # Use macros for creating tests/timing scripts -set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake) +set(CMAKE_MODULE_PATH "${CMAKE_MODULE_PATH}" "${PROJECT_SOURCE_DIR}/cmake") include(GtsamTesting) include(GtsamPrinting) @@ -19,12 +19,8 @@ if(${CMAKE_SOURCE_DIR} STREQUAL ${CMAKE_BINARY_DIR}) message(FATAL_ERROR "In-source builds not allowed. Please make a new directory (called a build directory) and run CMake from there. You may need to remove CMakeCache.txt. ") endif() -# Default to Debug mode -if(NOT FIRST_PASS_DONE AND NOT CMAKE_BUILD_TYPE) - set(CMAKE_BUILD_TYPE "Debug" CACHE STRING - "Choose the type of build, options are: None Debug Release Timing Profiling RelWithDebInfo." - FORCE) -endif() +# Load build type flags and default to Debug mode +include(GtsamBuildTypes) # Check build types if(${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION} VERSION_GREATER 2.8 OR ${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION} VERSION_EQUAL 2.8) @@ -41,33 +37,29 @@ if( NOT cmake_build_type_tolower STREQUAL "" message(FATAL_ERROR "Unknown build type \"${CMAKE_BUILD_TYPE}\". Allowed values are None, Debug, Release, Timing, Profiling, RelWithDebInfo (case-insensitive).") endif() -# Add debugging flags but only on the first pass -if(NOT FIRST_PASS_DONE) - set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -fno-inline -Wall" CACHE STRING "Flags used by the compiler during debug builds." FORCE) - set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fno-inline -Wall" CACHE STRING "Flags used by the compiler during debug builds." FORCE) - set(CMAKE_C_FLAGS_RELWITHDEBINFO "-g -fno-inline -Wall -DNDEBUG" CACHE STRING "Flags used by the compiler during relwithdebinfo builds." FORCE) - set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-g -fno-inline -Wall -DNDEBUG" CACHE STRING "Flags used by the compiler during relwithdebinfo builds." FORCE) - set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} -Wall" CACHE STRING "Flags used by the compiler during release builds." FORCE) - set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -Wall" CACHE STRING "Flags used by the compiler during release builds." FORCE) - set(CMAKE_C_FLAGS_TIMING "${CMAKE_C_FLAGS_RELEASE} -DENABLE_TIMING" CACHE STRING "Flags used by the compiler during timing builds." FORCE) - set(CMAKE_CXX_FLAGS_TIMING "${CMAKE_CXX_FLAGS_RELEASE} -DENABLE_TIMING" CACHE STRING "Flags used by the compiler during timing builds." FORCE) - mark_as_advanced(CMAKE_C_FLAGS_TIMING CMAKE_CXX_FLAGS_TIMING) - set(CMAKE_C_FLAGS_PROFILING "-g -O2 -Wall -DNDEBUG" CACHE STRING "Flags used by the compiler during profiling builds." FORCE) - set(CMAKE_CXX_FLAGS_PROFILING "-g -O2 -Wall -DNDEBUG" CACHE STRING "Flags used by the compiler during profiling builds." FORCE) - mark_as_advanced(CMAKE_C_FLAGS_PROFILING CMAKE_CXX_FLAGS_PROFILING) -endif() - # Configurable Options -option(GTSAM_BUILD_TESTS "Enable/Disable building of tests" ON) -option(GTSAM_BUILD_TIMING "Enable/Disable building of timing scripts" ON) -option(GTSAM_BUILD_EXAMPLES "Enable/Disable building of examples" ON) -option(GTSAM_BUILD_WRAP "Enable/Disable building of matlab wrap utility (necessary for matlab interface)" ON) -option(GTSAM_USE_QUATERNIONS "Enable/Disable using an internal Quaternion representation for rotations instead of rotation matrices" OFF) +option(GTSAM_BUILD_TESTS "Enable/Disable building of tests" ON) +option(GTSAM_BUILD_TIMING "Enable/Disable building of timing scripts" ON) +option(GTSAM_BUILD_EXAMPLES "Enable/Disable building of examples" ON) +option(GTSAM_BUILD_UNSTABLE "Enable/Disable libgtsam_unstable" OFF) +option(GTSAM_BUILD_WRAP "Enable/Disable building of matlab wrap utility (necessary for matlab interface)" ON) +option(GTSAM_BUILD_SHARED_LIBRARY "Enable/Disable building of a shared version of gtsam" ON) +option(GTSAM_BUILD_STATIC_LIBRARY "Enable/Disable building of a static version of gtsam" ON) +option(GTSAM_USE_QUATERNIONS "Enable/Disable using an internal Quaternion representation for rotations instead of rotation matrices" OFF) option(GTSAM_BUILD_CONVENIENCE_LIBRARIES "Enable/Disable use of convenience libraries for faster development rebuilds, but slower install" ON) -option(GTSAM_INSTALL_MATLAB_TOOLBOX "Enable/Disable installation of matlab toolbox" ON) -option(GTSAM_INSTALL_MATLAB_EXAMPLES "Enable/Disable installation of matlab examples" ON) -option(GTSAM_INSTALL_MATLAB_TESTS "Enable/Disable installation of matlab tests" ON) -option(GTSAM_INSTALL_WRAP "Enable/Disable installation of wrap utility" ON) +option(GTSAM_INSTALL_MATLAB_TOOLBOX "Enable/Disable installation of matlab toolbox" ON) +option(GTSAM_INSTALL_MATLAB_EXAMPLES "Enable/Disable installation of matlab examples" ON) +option(GTSAM_INSTALL_MATLAB_TESTS "Enable/Disable installation of matlab tests" ON) +option(GTSAM_INSTALL_WRAP "Enable/Disable installation of wrap utility" ON) + +# Flags for choosing default packaging tools +set(CPACK_SOURCE_GENERATOR "TGZ" CACHE STRING "CPack Default Source Generator") +set(CPACK_GENERATOR "TGZ" CACHE STRING "CPack Default Binary Generator") + +# Sanity check building of libraries +if (NOT GTSAM_BUILD_SHARED_LIBRARY AND NOT GTSAM_BUILD_STATIC_LIBRARY) + message(FATAL_ERROR "Both shared and static version of GTSAM library disabled - need to choose at least one!") +endif() # Add the Quaternion Build Flag if requested if (GTSAM_USE_QUATERNIONS) @@ -80,6 +72,10 @@ endif(GTSAM_USE_QUATERNIONS) # FIXME: can't add install dependencies, so libraries never get built #set(CMAKE_SKIP_INSTALL_ALL_DEPENDENCY TRUE) +# Alternative version to keep tests from building during make install +# Use the EXCLUDE_FROM_ALL property on test executables +option(GTSAM_ENABLE_INSTALL_TEST_FIX "Enable/Disable fix to remove dependency of tests on 'all' target" ON) + # Pull in infrastructure if (GTSAM_BUILD_TESTS) enable_testing() @@ -87,10 +83,6 @@ if (GTSAM_BUILD_TESTS) include(CTest) endif() -# Enable make check (http://www.cmake.org/Wiki/CMakeEmulateMakeCheck) -add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND}) -add_custom_target(timing) - # Find boost find_package(Boost 1.40 COMPONENTS serialization REQUIRED) @@ -122,24 +114,52 @@ if (GTSAM_BUILD_EXAMPLES) add_subdirectory(examples) endif(GTSAM_BUILD_EXAMPLES) -# Mark that first pass is done -set(FIRST_PASS_DONE true CACHE BOOL "Internally used to mark whether cmake has been run multiple times") -mark_as_advanced(FIRST_PASS_DONE) +# Build gtsam_unstable +if (GTSAM_BUILD_UNSTABLE) + add_subdirectory(gtsam_unstable) +endif(GTSAM_BUILD_UNSTABLE) + +# Set up CPack +set(CPACK_PACKAGE_DESCRIPTION_SUMMARY "GTSAM") +set(CPACK_PACKAGE_VENDOR "Frank Dellaert, Georgia Institute of Technology") +set(CPACK_PACKAGE_CONTACT "Frank Dellaert, dellaert@cc.gatech.edu") +set(CPACK_PACKAGE_DESCRIPTION_FILE "${CMAKE_CURRENT_SOURCE_DIR}/README") +set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE") +set(CPACK_PACKAGE_VERSION_MAJOR ${GTSAM_VERSION_MAJOR}) +set(CPACK_PACKAGE_VERSION_MINOR ${GTSAM_VERSION_MINOR}) +set(CPACK_PACKAGE_VERSION_PATCH ${GTSAM_VERSION_PATCH}) +set(CPACK_PACKAGE_INSTALL_DIRECTORY "CMake ${CMake_VERSION_MAJOR}.${CMake_VERSION_MINOR}") +set(CPACK_INSTALLED_DIRECTORIES "doc;.") # Include doc directory +set(CPACK_SOURCE_IGNORE_FILES "/build;/\\\\.;/makedoc.sh$") +set(CPACK_SOURCE_PACKAGE_FILE_NAME "gtsam-${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}.${GTSAM_VERSION_PATCH}") +#set(CPACK_SOURCE_PACKAGE_FILE_NAME "gtsam-aspn${GTSAM_VERSION_PATCH}") # Used for creating ASPN tarballs + +# Record the root dir for gtsam - needed during external builds, e.g., ROS +set(GTSAM_SOURCE_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR}) +message(STATUS "GTSAM_SOURCE_ROOT_DIR: [${GTSAM_SOURCE_ROOT_DIR}]") # print configuration variables message(STATUS "===============================================================") message(STATUS "================ Configuration Options ======================") message(STATUS "Build flags ") -print_config_flag(${GTSAM_BUILD_TIMING} "Build Timing scripts ") -print_config_flag(${GTSAM_BUILD_EXAMPLES} "Build Examples ") -print_config_flag(${GTSAM_BUILD_TESTS} "Build Tests ") -print_config_flag(${GTSAM_BUILD_WRAP} "Build Wrap ") -print_config_flag(${GTSAM_BUILD_CONVENIENCE_LIBRARIES} "Build Convenience Libraries") +print_config_flag(${GTSAM_BUILD_TIMING} "Build Timing scripts ") +print_config_flag(${GTSAM_BUILD_EXAMPLES} "Build Examples ") +print_config_flag(${GTSAM_BUILD_TESTS} "Build Tests ") +print_config_flag(${GTSAM_BUILD_WRAP} "Build Wrap ") +print_config_flag(${GTSAM_BUILD_SHARED_LIBRARY} "Build shared GTSAM Library ") +print_config_flag(${GTSAM_BUILD_STATIC_LIBRARY} "Build static GTSAM Library ") +print_config_flag(${GTSAM_BUILD_CONVENIENCE_LIBRARIES} "Build Convenience Libraries ") +print_config_flag(${GTSAM_BUILD_UNSTABLE} "Build libgtsam_unstable ") +print_config_flag(${GTSAM_ENABLE_INSTALL_TEST_FIX} "Tests excluded from all target ") string(TOUPPER "${CMAKE_BUILD_TYPE}" cmake_build_type_toupper) message(STATUS " Build type : ${CMAKE_BUILD_TYPE}") message(STATUS " C compilation flags : ${CMAKE_C_FLAGS} ${CMAKE_C_FLAGS_${cmake_build_type_toupper}}") message(STATUS " C++ compilation flags : ${CMAKE_CXX_FLAGS} ${CMAKE_CXX_FLAGS_${cmake_build_type_toupper}}") +message(STATUS "Packaging flags ") +message(STATUS " CPack Source Generator : ${CPACK_SOURCE_GENERATOR}") +message(STATUS " CPack Generator : ${CPACK_GENERATOR}") + message(STATUS "GTSAM flags ") print_config_flag(${GTSAM_USE_QUATERNIONS} "Quaternions as default Rot3") @@ -150,19 +170,5 @@ print_config_flag(${GTSAM_INSTALL_MATLAB_TESTS} "Install matlab tests print_config_flag(${GTSAM_INSTALL_WRAP} "Install wrap utility ") message(STATUS "===============================================================") -# Set up CPack -set(CPACK_PACKAGE_DESCRIPTION_SUMMARY "GTSAM") -set(CPACK_PACKAGE_VENDOR "Frank Dellaert, Georgia Institute of Technology") -set(CPACK_PACKAGE_DESCRIPTION_FILE "${CMAKE_CURRENT_SOURCE_DIR}/README") -set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE") -set(CPACK_PACKAGE_VERSION_MAJOR ${GTSAM_VERSION_MAJOR}) -set(CPACK_PACKAGE_VERSION_MINOR ${GTSAM_VERSION_MINOR}) -set(CPACK_PACKAGE_VERSION_PATCH ${GTSAM_VERSION_PATCH}) -set(CPACK_PACKAGE_INSTALL_DIRECTORY "CMake ${CMake_VERSION_MAJOR}.${CMake_VERSION_MINOR}") -set(CPACK_INSTALLED_DIRECTORIES "doc" ".") # Include doc directory -set(CPACK_SOURCE_IGNORE_FILES "/build;/\\\\.;/makedoc.sh$") -set(CPACK_SOURCE_PACKAGE_FILE_NAME "gtsam-${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}.${GTSAM_VERSION_PATCH}") -#set(CPACK_SOURCE_PACKAGE_FILE_NAME "gtsam-aspn${GTSAM_VERSION_PATCH}") # Used for creating ASPN tarballs -set(CPACK_SOURCE_GENERATOR "TGZ") -set(CPACK_GENERATOR "TGZ") +# Include CPack *after* all flags include(CPack) diff --git a/CppUnitLite/SimpleString.cpp b/CppUnitLite/SimpleString.cpp index 90356183a..19dc7f258 100644 --- a/CppUnitLite/SimpleString.cpp +++ b/CppUnitLite/SimpleString.cpp @@ -102,7 +102,7 @@ SimpleString StringFrom (long value) SimpleString StringFrom (double value) { char buffer [DEFAULT_SIZE]; - sprintf (buffer, "%lf", value); + sprintf (buffer, "%lg", value); return SimpleString(buffer); } diff --git a/Doxyfile b/Doxyfile index 2ecf956f5..2bd88de6c 100644 --- a/Doxyfile +++ b/Doxyfile @@ -32,7 +32,7 @@ PROJECT_NAME = gtsam # This could be handy for archiving the generated documentation or # if some version control system is used. -PROJECT_NUMBER = 0.9.3 +PROJECT_NUMBER = 1.9.0 # Using the PROJECT_BRIEF tag one can provide an optional one line description # for a project that appears at the top of each page and should give viewer @@ -641,10 +641,15 @@ WARN_LOGFILE = INPUT = gtsam/base \ gtsam/geometry \ gtsam/inference \ + gtsam/discrete \ gtsam/linear \ gtsam/nonlinear \ - gtsam/slam \ - gtsam + gtsam \ + gtsam_unstable/slam \ + gtsam_unstable/base \ + gtsam_unstable/geometry \ + gtsam_unstable/dynamics \ + gtsam_unstable # This tag can be used to specify the character encoding of the source files # that doxygen parses. Internally doxygen uses the UTF-8 encoding, which is diff --git a/README b/README index 095d7b79c..595bbccf4 100644 --- a/README +++ b/README @@ -11,7 +11,10 @@ What is GTSAM ? GTSAM is not (yet) open source: See COPYING & LICENSE Please see USAGE for an example on how to use GTSAM. -The code is organized according to the following directory structure: +The core GTSAM code within the folder gtsam, with source, headers, and +unit tests. After building, this will generate a single library "libgtsam" + +The libgtsam code is organized according to the following directory structure: 3rdparty local copies of third party libraries - Eigen3 and CCOLAMD base provides some base Math and data structures, as well as test-related utilities @@ -20,7 +23,13 @@ The code is organized according to the following directory structure: linear inference specialized to Gaussian linear case, GaussianFactorGraph etc... nonlinear non-linear factor graphs and non-linear optimization slam SLAM and visual SLAM application code - + +Additionally, in the SVN development version of GTSAM, there is an area for +unstable code directly under development in the folder gtsam_unstable, which contains +a directory structuring mirroring the libgtsam structure as necessary. This section produces +a single library "libgtsam_unstable". Building of gtsam_unstable is disabled by default, see +CMake configuration options for information on enabling building. + This library contains unchanged copies of two third party libraries, with documentation of licensing in LICENSE and as follows: - CCOLAMD 2.73: Tim Davis' constrained column approximate minimum degree ordering library @@ -28,11 +37,10 @@ of licensing in LICENSE and as follows: - Licenced under LGPL v2.1, provided in gtsam/3rdparty/CCOLAMD/Doc/lesser.txt - Eigen 3.0.5: General C++ matrix and linear algebra library - Licenced under LGPL v3, provided in gtsam/3rdparty/Eigen/COPYING.LGPL - -All of the above contain code and tests, and produce a single library libgtsam. + After this is built, you can also run the more involved tests, which test the entire library: - tests more involved tests that depend on slam + tests More involved unit tests that depend on slam examples Demo applications as a tutorial for using gtsam cmake CMake scripts used within the library, as well as for finding GTSAM by dependent projects @@ -158,6 +166,16 @@ $] cmake -DGTSAM_BUILD_CONVENIENCE_LIBRARIES:OPTION=ON .. link all of the tests at once. This option is best for users of GTSAM, as it avoids rebuilding the entirety of gtsam an extra time. +GTSAM_BUILD_UNSTABLE: Enable build and install for libgtsam_unstable library. +Set with the command line as follows: +$] cmake -DGTSAM_BUILD_UNSTABLE:OPTION=ON .. + ON When enabled, libgtsam_unstable will be built and installed with the + same options as libgtsam. In addition, if tests are enabled, the + unit tests will be built as well. The Matlab toolbox will also + be generated if the matlab toolbox is enabled, installing into a + folder called "gtsam_unstable". + OFF (Default) If disabled, no gtsam_unstable code will be included in build or install. + CMAKE_BUILD_TYPE: We support several build configurations for GTSAM (case insensitive) Debug (default) All error checking options on, no optimization. Use for development. Release Optimizations turned on, no debug symbols. @@ -176,12 +194,13 @@ Running "make install" will install the library to the prefix location. Check -As with autotools, "make check" will build and run all of the tests. You can also -run "make timing" to build all of the timing scripts. To run check on a particular -subsection, there is a convention of "make check.[subfolder]", so to run just the -geometry tests, run "make check.geometry". Individual tests can be run by -appending ".run" to the name of the test, for example, to run testMatrix, -run "make testMatrix.run". +As with autotools, "make check" will build and run all of the tests. Note that the +tests will only be built when using the "check" targets, to prevent "make install" from +building the tests unnecessarily. You can also run "make timing" to build all of +the timing scripts. To run check on a particular subsection, there is a convention +of "make check.[subfolder]", so to run just the geometry tests, +run "make check.geometry". Individual tests can be run by appending ".run" to the +name of the test, for example, to run testMatrix, run "make testMatrix.run". The make target "wrap" will build the wrap binary, and the "wrap_gtsam" target will generate code for the toolbox. By default, the toolbox will be created and installed diff --git a/examples/Pose2SLAMExample_advanced.cpp b/examples/Pose2SLAMExample_advanced.cpp index c495dba97..9cc8c2040 100644 --- a/examples/Pose2SLAMExample_advanced.cpp +++ b/examples/Pose2SLAMExample_advanced.cpp @@ -65,16 +65,16 @@ int main(int argc, char** argv) { Ordering::shared_ptr ordering = graph->orderingCOLAMD(*initial); /* 4.2.2 set up solver and optimize */ - LevenbergMarquardtParams params; - params.relativeErrorTol = 1e-15; - params.absoluteErrorTol = 1e-15; - pose2SLAM::Values result = *LevenbergMarquardtOptimizer(graph, initial, params, ordering).optimized(); + NonlinearOptimizationParameters::shared_ptr params = NonlinearOptimizationParameters::newDecreaseThresholds(1e-15, 1e-15); + Optimizer optimizer(graph, initial, ordering, params); + Optimizer optimizer_result = optimizer.levenbergMarquardt(); + + pose2SLAM::Values result = *optimizer_result.values(); result.print("final result"); /* Get covariances */ - GaussianMultifrontalSolver solver(*graph->linearize(result, *ordering)); - Matrix covariance1 = solver.marginalCovariance(ordering->at(PoseKey(1))); - Matrix covariance2 = solver.marginalCovariance(ordering->at(PoseKey(1))); + Matrix covariance1 = optimizer_result.marginalCovariance(PoseKey(1)); + Matrix covariance2 = optimizer_result.marginalCovariance(PoseKey(2)); print(covariance1, "Covariance1"); print(covariance2, "Covariance2"); diff --git a/examples/README b/examples/README index 87368b2ef..d7706a0b8 100644 --- a/examples/README +++ b/examples/README @@ -26,4 +26,10 @@ Visual SLAM The directory vSLAMexample includes 2 simple examples using GTSAM: - vSFMexample using visualSLAM in for structure-from-motion (SFM), and - vISAMexample using visualSLAM and ISAM for incremental SLAM updates -See the separate README file there. \ No newline at end of file +See the separate README file there. + +Undirected Graphical Models (UGM) +================================= +The best representation for a Markov Random Field is a factor graph :-) +This is illustrated with some discrete examples from the UGM MATLAB toolbox, which +can be found at http://www.di.ens.fr/~mschmidt/Software/UGM \ No newline at end of file diff --git a/examples/UGM_small.cpp b/examples/UGM_small.cpp new file mode 100644 index 000000000..a4655d2ad --- /dev/null +++ b/examples/UGM_small.cpp @@ -0,0 +1,74 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file small.cpp + * @brief UGM (undirected graphical model) examples: small + * @author Frank Dellaert + * + * See http://www.di.ens.fr/~mschmidt/Software/UGM/small.html + */ + +#include +#include + +using namespace std; +using namespace gtsam; + +int main(int argc, char** argv) { + + // We will assume 2-state variables, where, to conform to the "small" example + // we have 0 == "right answer" and 1 == "wrong answer" + size_t nrStates = 2; + + // define variables + DiscreteKey Cathy(1, nrStates), Heather(2, nrStates), Mark(3, nrStates), + Allison(4, nrStates); + + // create graph + DiscreteFactorGraph graph; + + // add node potentials + graph.add(Cathy, "1 3"); + graph.add(Heather, "9 1"); + graph.add(Mark, "1 3"); + graph.add(Allison, "9 1"); + + // add edge potentials + graph.add(Cathy & Heather, "2 1 1 2"); + graph.add(Heather & Mark, "2 1 1 2"); + graph.add(Mark & Allison, "2 1 1 2"); + + // Print the UGM distribution + cout << "\nUGM distribution:" << endl; + for (size_t a = 0; a < nrStates; a++) + for (size_t m = 0; m < nrStates; m++) + for (size_t h = 0; h < nrStates; h++) + for (size_t c = 0; c < nrStates; c++) { + DiscreteFactor::Values values; + values[1] = c; + values[2] = h; + values[3] = m; + values[4] = a; + double prodPot = graph(values); + cout << c << " " << h << " " << m << " " << a << " :\t" + << prodPot << "\t" << prodPot/3790 << endl; + } + + // "Decoding", i.e., configuration with largest value + // We use sequential variable elimination + DiscreteSequentialSolver solver(graph); + DiscreteFactor::sharedValues optimalDecoding = solver.optimize(); + optimalDecoding->print("\noptimalDecoding"); + + return 0; +} + diff --git a/gtsam.h b/gtsam.h index 26f3457c7..0e4d70e16 100644 --- a/gtsam.h +++ b/gtsam.h @@ -32,6 +32,9 @@ * Namespace usage * - Namespaces can be specified for classes in arguments and return values * - In each case, the namespace must be fully specified, e.g., "namespace1::namespace2::ClassName" + * Using namespace + * - To use a namespace (e.g., generate a "using namespace x" line in cpp files), add "using namespace x;" + * - This declaration applies to all classes *after* the declaration, regardless of brackets * Methods must start with a lowercase letter * Static methods must start with a letter (upper or lowercase) and use the "static" keyword * Includes in C++ wrappers diff --git a/gtsam/3rdparty/CMakeLists.txt b/gtsam/3rdparty/CMakeLists.txt index 673411a58..8ee371fe8 100644 --- a/gtsam/3rdparty/CMakeLists.txt +++ b/gtsam/3rdparty/CMakeLists.txt @@ -1,10 +1,10 @@ # install CCOLAMD headers -install(FILES CCOLAMD/Include/ccolamd.h DESTINATION ${CMAKE_INSTALL_PREFIX}/include/gtsam/3rdparty/CCOLAMD) -install(FILES UFconfig/UFconfig.h DESTINATION ${CMAKE_INSTALL_PREFIX}/include/gtsam/3rdparty/UFconfig) +install(FILES CCOLAMD/Include/ccolamd.h DESTINATION include/gtsam/3rdparty/CCOLAMD) +install(FILES UFconfig/UFconfig.h DESTINATION include/gtsam/3rdparty/UFconfig) # install Eigen - only the headers install(DIRECTORY Eigen/Eigen - DESTINATION ${CMAKE_INSTALL_PREFIX}/include/gtsam/3rdparty/Eigen + DESTINATION include/gtsam/3rdparty/Eigen FILES_MATCHING PATTERN "*.h") file(GLOB eigen_dir_headers_all "Eigen/Eigen/*") @@ -12,6 +12,6 @@ file(GLOB eigen_dir_headers_all "Eigen/Eigen/*") foreach(eigen_dir ${eigen_dir_headers_all}) get_filename_component(filename ${eigen_dir} NAME) if (NOT ((${filename} MATCHES "CMakeLists.txt") OR (${filename} MATCHES "src") OR (${filename} MATCHES ".svn"))) - install(FILES Eigen/Eigen/${filename} DESTINATION ${CMAKE_INSTALL_PREFIX}/include/gtsam/3rdparty/Eigen/Eigen) + install(FILES Eigen/Eigen/${filename} DESTINATION include/gtsam/3rdparty/Eigen/Eigen) endif() endforeach(eigen_dir) diff --git a/gtsam/CMakeLists.txt b/gtsam/CMakeLists.txt index 0ca316b36..dcbd10075 100644 --- a/gtsam/CMakeLists.txt +++ b/gtsam/CMakeLists.txt @@ -4,6 +4,7 @@ set (gtsam_subdirs base geometry inference + discrete linear nonlinear slam @@ -25,10 +26,24 @@ if (GTSAM_BUILD_CONVENIENCE_LIBRARIES) add_library(ccolamd STATIC ${3rdparty_srcs}) endif() +# Sources to remove from builds +set (excluded_sources + "${CMAKE_CURRENT_SOURCE_DIR}/discrete/TypedDiscreteFactor.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/discrete/TypedDiscreteFactorGraph.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/discrete/parseUAI.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/discrete/PotentialTable.cpp") + +if(GTSAM_USE_QUATERNIONS) + set(excluded_sources ${excluded_sources} "${CMAKE_CURRENT_SOURCE_DIR}/geometry/Rot3M.cpp") +else() + set(excluded_sources ${excluded_sources} "${CMAKE_CURRENT_SOURCE_DIR}/geometry/Rot3Q.cpp") +endif() + # assemble core libaries foreach(subdir ${gtsam_subdirs}) # Build convenience libraries file(GLOB subdir_srcs "${subdir}/*.cpp") + list(REMOVE_ITEM subdir_srcs ${excluded_sources}) set(${subdir}_srcs ${subdir_srcs}) if (GTSAM_BUILD_CONVENIENCE_LIBRARIES) message(STATUS "Building Convenience Library: ${subdir}") @@ -46,13 +61,12 @@ set(gtsam_srcs ${base_srcs} ${geometry_srcs} ${inference_srcs} + ${discrete_srcs} ${linear_srcs} ${nonlinear_srcs} ${slam_srcs} ) -option (GTSAM_BUILD_SHARED_LIBRARY "Enable/Disable building of a shared version of gtsam" ON) - # Versions set(gtsam_version ${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}.${GTSAM_VERSION_PATCH}) set(gtsam_soversion ${GTSAM_VERSION_MAJOR}) @@ -60,14 +74,16 @@ message(STATUS "GTSAM Version: ${gtsam_version}") message(STATUS "Install prefix: ${CMAKE_INSTALL_PREFIX}") # build shared and static versions of the library -message(STATUS "Building GTSAM - static") -add_library(gtsam-static STATIC ${gtsam_srcs}) -set_target_properties(gtsam-static PROPERTIES - OUTPUT_NAME gtsam - CLEAN_DIRECT_OUTPUT 1 - VERSION ${gtsam_version} - SOVERSION ${gtsam_soversion}) -install(TARGETS gtsam-static ARCHIVE DESTINATION lib) +if (GTSAM_BUILD_STATIC_LIBRARY) + message(STATUS "Building GTSAM - static") + add_library(gtsam-static STATIC ${gtsam_srcs}) + set_target_properties(gtsam-static PROPERTIES + OUTPUT_NAME gtsam + CLEAN_DIRECT_OUTPUT 1 + VERSION ${gtsam_version} + SOVERSION ${gtsam_soversion}) + install(TARGETS gtsam-static ARCHIVE DESTINATION lib) +endif (GTSAM_BUILD_STATIC_LIBRARY) if (GTSAM_BUILD_SHARED_LIBRARY) message(STATUS "Building GTSAM - shared") diff --git a/gtsam/base/DSFVector.cpp b/gtsam/base/DSFVector.cpp deleted file mode 100644 index ceb33fc23..000000000 --- a/gtsam/base/DSFVector.cpp +++ /dev/null @@ -1,97 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file DSFVector.cpp - * @date Jun 25, 2010 - * @author Kai Ni - * @brief a faster implementation for DSF, which uses vector rather than btree. - */ - -#include -#include -#include - -using namespace std; - -namespace gtsam { - - /* ************************************************************************* */ - DSFVector::DSFVector (const size_t numNodes) { - v_ = boost::make_shared(numNodes); - int index = 0; - keys_.reserve(numNodes); - for(V::iterator it = v_->begin(); it!=v_->end(); it++, index++) { - *it = index; - keys_.push_back(index); - } - } - - /* ************************************************************************* */ - DSFVector::DSFVector(const boost::shared_ptr& v_in, const std::vector& keys) : keys_(keys) { - v_ = v_in; - BOOST_FOREACH(const size_t key, keys) - (*v_)[key] = key; - } - - /* ************************************************************************* */ - bool DSFVector::isSingleton(const Label& label) const { - bool result = false; - V::const_iterator it = keys_.begin(); - for (; it != keys_.end(); ++it) { - if(findSet(*it) == label) { - if (!result) // find the first occurrence - result = true; - else - return false; - } - } - return result; - } - - /* ************************************************************************* */ - std::set DSFVector::set(const Label& label) const { - std::set set; - V::const_iterator it = keys_.begin(); - for (; it != keys_.end(); it++) { - if (findSet(*it) == label) - set.insert(*it); - } - return set; - } - - /* ************************************************************************* */ - std::map > DSFVector::sets() const { - std::map > sets; - V::const_iterator it = keys_.begin(); - for (; it != keys_.end(); it++) { - sets[findSet(*it)].insert(*it); - } - return sets; - } - - /* ************************************************************************* */ - std::map > DSFVector::arrays() const { - std::map > arrays; - V::const_iterator it = keys_.begin(); - for (; it != keys_.end(); it++) { - arrays[findSet(*it)].push_back(*it); - } - return arrays; - } - - /* ************************************************************************* */ - void DSFVector::makeUnionInPlace(const size_t& i1, const size_t& i2) { - (*v_)[findSet(i2)] = findSet(i1); - } - -} // namespace - diff --git a/gtsam/base/FastSet.h b/gtsam/base/FastSet.h index 05fb879a1..6abd7efb8 100644 --- a/gtsam/base/FastSet.h +++ b/gtsam/base/FastSet.h @@ -105,7 +105,7 @@ struct FastSetTestableHelper { typename Set::const_iterator it2 = set2.begin(); while (it1 != set1.end()) { if (it2 == set2.end() || - fabs((double)(*it1) - (double)(*it2) > tol)) + fabs((double)(*it1) - (double)(*it2)) > tol) return false; ++it1; ++it2; diff --git a/gtsam/base/LieVector.h b/gtsam/base/LieVector.h index 9a8e5480f..82455e173 100644 --- a/gtsam/base/LieVector.h +++ b/gtsam/base/LieVector.h @@ -34,6 +34,10 @@ struct LieVector : public Vector, public DerivedValue { /** initialize from a normal vector */ LieVector(const Vector& v) : Vector(v) {} + /** initialize from a fixed size normal vector */ + template + LieVector(const Eigen::Matrix& v) : Vector(v) {} + /** wrap a double */ LieVector(double d) : Vector(Vector_(1, d)) {} diff --git a/gtsam/base/Matrix.cpp b/gtsam/base/Matrix.cpp index b3c51325d..849dc3f5a 100644 --- a/gtsam/base/Matrix.cpp +++ b/gtsam/base/Matrix.cpp @@ -541,12 +541,12 @@ Matrix vector_scale(const Matrix& A, const Vector& v, bool inf_mask) { } /* ************************************************************************* */ -Matrix skewSymmetric(double wx, double wy, double wz) +Matrix3 skewSymmetric(double wx, double wy, double wz) { - return Matrix_(3,3, + return (Matrix3() << 0.0, -wz, +wy, +wz, 0.0, -wx, - -wy, +wx, 0.0); + -wy, +wx, 0.0).finished(); } /* ************************************************************************* */ diff --git a/gtsam/base/Matrix.h b/gtsam/base/Matrix.h index fee61f62b..693f1e3da 100644 --- a/gtsam/base/Matrix.h +++ b/gtsam/base/Matrix.h @@ -37,6 +37,10 @@ namespace gtsam { typedef Eigen::MatrixXd Matrix; typedef Eigen::Matrix MatrixRowMajor; +typedef Eigen::Matrix3d Matrix3; +typedef Eigen::Matrix4d Matrix4; +typedef Eigen::Matrix Matrix6; + // Matrix expressions for accessing parts of matrices typedef Eigen::Block SubMatrix; typedef Eigen::Block ConstSubMatrix; @@ -393,8 +397,8 @@ Matrix vector_scale(const Matrix& A, const Vector& v, bool inf_mask = false); // * @param wz * @return a 3*3 skew symmetric matrix */ -Matrix skewSymmetric(double wx, double wy, double wz); -inline Matrix skewSymmetric(const Vector& w) { return skewSymmetric(w(0),w(1),w(2));} +Matrix3 skewSymmetric(double wx, double wy, double wz); +inline Matrix3 skewSymmetric(const Vector& w) { return skewSymmetric(w(0),w(1),w(2));} /** Use SVD to calculate inverse square root of a matrix */ Matrix inverse_square_root(const Matrix& A); diff --git a/gtsam/base/Vector.cpp b/gtsam/base/Vector.cpp index 5367cec09..27b56830c 100644 --- a/gtsam/base/Vector.cpp +++ b/gtsam/base/Vector.cpp @@ -38,6 +38,8 @@ using namespace std; +boost::minstd_rand generator(42u); + namespace gtsam { /* ************************************************************************* */ @@ -197,6 +199,15 @@ bool assert_equal(const Vector& expected, const Vector& actual, double tol) { return false; } +/* ************************************************************************* */ +bool assert_inequal(const Vector& expected, const Vector& actual, double tol) { + if (!equal_with_abs_tol(expected,actual,tol)) return true; + cout << "Erroneously equal:" << endl; + print(expected, "expected"); + print(actual, "actual"); + return false; +} + /* ************************************************************************* */ bool assert_equal(const SubVector& expected, const SubVector& actual, double tol) { if (equal_with_abs_tol(expected,actual,tol)) return true; diff --git a/gtsam/base/Vector.h b/gtsam/base/Vector.h index abfa44364..84fd506cb 100644 --- a/gtsam/base/Vector.h +++ b/gtsam/base/Vector.h @@ -25,14 +25,27 @@ #include #include -// Vector is just a typedef of the Eigen dynamic vector type -// TODO: make a version that works for matlab wrapping +/** + * Static random number generator - needs to maintain a state + * over time, hence the static generator. Be careful in + * cases where multiple processes (as is frequently the case with + * multi-robot scenarios) are using the sample() facilities + * in NoiseModel, as they will each have the same seed. + */ +// FIXME: make this go away - use the Sampler class instead +extern boost::minstd_rand generator; namespace gtsam { +// Vector is just a typedef of the Eigen dynamic vector type + // Typedef arbitary length vector typedef Eigen::VectorXd Vector; +// Commonly used fixed size vectors +typedef Eigen::Vector3d Vector3; +typedef Eigen::Matrix Vector6; + typedef Eigen::VectorBlock SubVector; typedef Eigen::VectorBlock ConstSubVector; @@ -156,6 +169,15 @@ inline bool equal(const Vector& vec1, const Vector& vec2) { */ bool assert_equal(const Vector& vec1, const Vector& vec2, double tol=1e-9); +/** + * Not the same, prints if error + * @param vec1 Vector + * @param vec2 Vector + * @param tol 1e-9 + * @return bool + */ +bool assert_inequal(const Vector& vec1, const Vector& vec2, double tol=1e-9); + /** * Same, prints if error * @param vec1 Vector @@ -340,12 +362,17 @@ Vector concatVectors(size_t nrVectors, ...); */ Vector rand_vector_norm(size_t dim, double mean = 0, double sigma = 1); +/** + * Sets the generator to use a different seed value. + * Default argument resets the RNG + * @param seed is the new seed + */ +inline void seedRNG(unsigned int seed = 42u) { + generator.seed(seed); +} + } // namespace gtsam -// FIXME: make this go away - use the Sampler class instead -static boost::minstd_rand generator(42u); - - #include #include diff --git a/gtsam/base/tests/testCholesky.cpp b/gtsam/base/tests/testCholesky.cpp index 68b2bf660..5232fc6ba 100644 --- a/gtsam/base/tests/testCholesky.cpp +++ b/gtsam/base/tests/testCholesky.cpp @@ -160,6 +160,55 @@ TEST(cholesky, ldlPartial2) { EXPECT(assert_equal(IexpectedR, p.transpose()*I)); } +/* ************************************************************************* */ +TEST(cholesky, BadScalingCholesky) { + Matrix A = Matrix_(2,2, + 1e-40, 0.0, + 0.0, 1.0); + + Matrix R(A.transpose() * A); + choleskyPartial(R, 2); + + double expectedSqrtCondition = 1e-40; + double actualSqrtCondition = R(0,0) / R(1,1); + + DOUBLES_EQUAL(expectedSqrtCondition, actualSqrtCondition, 1e-41); +} + +/* ************************************************************************* */ +TEST(cholesky, BadScalingLDL) { + Matrix A = Matrix_(2,2, + 1.0, 0.0, + 0.0, 1e-40); + + Matrix R(A.transpose() * A); + Eigen::LDLT::TranspositionType permutation = ldlPartial(R, 2); + + EXPECT(permutation.indices()(0) == 0); + EXPECT(permutation.indices()(1) == 1); + + double expectedCondition = 1e40; + double actualCondition = R(0,0) / R(1,1); + + DOUBLES_EQUAL(expectedCondition, actualCondition, 1e-41); +} + +/* ************************************************************************* */ +TEST(cholesky, BadScalingSVD) { + Matrix A = Matrix_(2,2, + 1.0, 0.0, + 0.0, 1e-40); + + Matrix U, V; + Vector S; + gtsam::svd(A, U, S, V); + + double expectedCondition = 1e40; + double actualCondition = S(0) / S(1); + + DOUBLES_EQUAL(expectedCondition, actualCondition, 1e-41); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/base/tests/testFixedVector.cpp b/gtsam/base/tests/testFixedVector.cpp deleted file mode 100644 index 85a963c89..000000000 --- a/gtsam/base/tests/testFixedVector.cpp +++ /dev/null @@ -1,87 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file testFixedVector.cpp - * @author Alex Cunningham - */ - -#include - -#include - -using namespace gtsam; - -typedef FixedVector<5> Vector5; -typedef FixedVector<3> Vector3; - -static const double tol = 1e-9; - -/* ************************************************************************* */ -TEST( testFixedVector, conversions ) { - double data1[] = {1.0, 2.0, 3.0}; - Vector v1 = Vector_(3, data1); - Vector3 fv1(v1), fv2(data1); - - Vector actFv2(fv2); - CHECK(assert_equal(v1, actFv2)); -} - -/* ************************************************************************* */ -TEST( testFixedVector, variable_constructor ) { - Vector3 act(3, 1.0, 2.0, 3.0); - DOUBLES_EQUAL(1.0, act(0), tol); - DOUBLES_EQUAL(2.0, act(1), tol); - DOUBLES_EQUAL(3.0, act(2), tol); -} - -/* ************************************************************************* */ -TEST( testFixedVector, equals ) { - Vector3 vec1(3, 1.0, 2.0, 3.0), vec2(3, 1.0, 2.0, 3.0), vec3(3, 2.0, 3.0, 4.0); - Vector5 vec4(5, 1.0, 2.0, 3.0, 4.0, 5.0); - - CHECK(assert_equal(vec1, vec1, tol)); - CHECK(assert_equal(vec1, vec2, tol)); - CHECK(assert_equal(vec2, vec1, tol)); - CHECK(!vec1.equals(vec3, tol)); - CHECK(!vec3.equals(vec1, tol)); - CHECK(!vec1.equals(vec4, tol)); - CHECK(!vec4.equals(vec1, tol)); -} - -/* ************************************************************************* */ -TEST( testFixedVector, static_constructors ) { - Vector3 actZero = Vector3::zero(); - Vector3 expZero(3, 0.0, 0.0, 0.0); - CHECK(assert_equal(expZero, actZero, tol)); - - Vector3 actOnes = Vector3::ones(); - Vector3 expOnes(3, 1.0, 1.0, 1.0); - CHECK(assert_equal(expOnes, actOnes, tol)); - - Vector3 actRepeat = Vector3::repeat(2.3); - Vector3 expRepeat(3, 2.3, 2.3, 2.3); - CHECK(assert_equal(expRepeat, actRepeat, tol)); - - Vector3 actBasis = Vector3::basis(1); - Vector3 expBasis(3, 0.0, 1.0, 0.0); - CHECK(assert_equal(expBasis, actBasis, tol)); - - Vector3 actDelta = Vector3::delta(1, 2.3); - Vector3 expDelta(3, 0.0, 2.3, 0.0); - CHECK(assert_equal(expDelta, actDelta, tol)); -} - -/* ************************************************************************* */ -int main() { TestResult tr; return TestRegistry::runAllTests(tr); } -/* ************************************************************************* */ - - diff --git a/gtsam/base/tests/testVector.cpp b/gtsam/base/tests/testVector.cpp index 7506043c3..8b8757027 100644 --- a/gtsam/base/tests/testVector.cpp +++ b/gtsam/base/tests/testVector.cpp @@ -297,6 +297,27 @@ TEST( TestVector, linear_dependent3 ) EXPECT(!linear_dependent(v1, v2)); } +/* ************************************************************************* */ +TEST( TestVector, random ) +{ + // Assumes seed not previously reset during this test + seedRNG(); + Vector v1_42 = rand_vector_norm(5); + + // verify that resetting the RNG produces the same value + seedRNG(); + Vector v2_42 = rand_vector_norm(5); + + EXPECT(assert_equal(v1_42, v2_42, 1e-6)); + + // verify that different seed produces a different value + seedRNG(41u); + + Vector v3_41 = rand_vector_norm(5); + + EXPECT(assert_inequal(v1_42, v3_41, 1e-6)); +} + /* ************************************************************************* */ int main() { TestResult tr; return TestRegistry::runAllTests(tr); } /* ************************************************************************* */ diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h new file mode 100644 index 000000000..029223ef3 --- /dev/null +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -0,0 +1,134 @@ +/* + * @file AlgebraicDecisionTree.h + * @brief Algebraic Decision Trees + * @author Frank Dellaert + * @date Mar 14, 2011 + */ + +#pragma once + +#include + +namespace gtsam { + + /** + * Algebraic Decision Trees fix the range to double + * Just has some nice constructors and some syntactic sugar + * TODO: consider eliminating this class altogether? + */ + template + class AlgebraicDecisionTree: public DecisionTree { + + public: + + typedef DecisionTree Super; + + /** The Real ring with addition and multiplication */ + struct Ring { + static inline double zero() { + return 0.0; + } + static inline double one() { + return 1.0; + } + static inline double add(const double& a, const double& b) { + return a + b; + } + static inline double max(const double& a, const double& b) { + return std::max(a, b); + } + static inline double mul(const double& a, const double& b) { + return a * b; + } + static inline double div(const double& a, const double& b) { + return a / b; + } + static inline double id(const double& x) { + return x; + } + }; + + AlgebraicDecisionTree() : + Super(1.0) { + } + + AlgebraicDecisionTree(const Super& add) : + Super(add) { + } + + /** Create a new leaf function splitting on a variable */ + AlgebraicDecisionTree(const L& label, double y1, double y2) : + Super(label, y1, y2) { + } + + /** Create a new leaf function splitting on a variable */ + AlgebraicDecisionTree(const typename Super::LabelC& labelC, double y1, double y2) : + Super(labelC, y1, y2) { + } + + /** Create from keys and vector table */ + AlgebraicDecisionTree // + (const std::vector& labelCs, const std::vector& ys) { + this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(), + ys.end()); + } + + /** Create from keys and string table */ + AlgebraicDecisionTree // + (const std::vector& labelCs, const std::string& table) { + // Convert string to doubles + std::vector ys; + std::istringstream iss(table); + std::copy(std::istream_iterator(iss), + std::istream_iterator(), std::back_inserter(ys)); + + // now call recursive Create + this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(), + ys.end()); + } + + /** Create a new function splitting on a variable */ + template + AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) : + Super(NULL) { + this->root_ = compose(begin, end, label); + } + + /** Convert */ + template + AlgebraicDecisionTree(const AlgebraicDecisionTree& other, + const std::map& map) { + this->root_ = this->template convert(other.root_, map, + Ring::id); + } + + /** sum */ + AlgebraicDecisionTree operator+(const AlgebraicDecisionTree& g) const { + return this->apply(g, &Ring::add); + } + + /** product */ + AlgebraicDecisionTree operator*(const AlgebraicDecisionTree& g) const { + return this->apply(g, &Ring::mul); + } + + /** division */ + AlgebraicDecisionTree operator/(const AlgebraicDecisionTree& g) const { + return this->apply(g, &Ring::div); + } + + /** sum out variable */ + AlgebraicDecisionTree sum(const L& label, size_t cardinality) const { + return this->combine(label, cardinality, &Ring::add); + } + + /** sum out variable */ + AlgebraicDecisionTree sum(const typename Super::LabelC& labelC) const { + return this->combine(labelC, &Ring::add); + } + + }; +// AlgebraicDecisionTree + +} +// namespace gtsam diff --git a/gtsam/discrete/AllDiff.cpp b/gtsam/discrete/AllDiff.cpp new file mode 100644 index 000000000..064e0d1c8 --- /dev/null +++ b/gtsam/discrete/AllDiff.cpp @@ -0,0 +1,110 @@ +/* + * AllDiff.cpp + * @brief General "all-different" constraint + * @date Feb 6, 2012 + * @author Frank Dellaert + */ + +#include +#include +#include +#include + +namespace gtsam { + + /* ************************************************************************* */ + AllDiff::AllDiff(const DiscreteKeys& dkeys) : + DiscreteFactor(dkeys.indices()) { + BOOST_FOREACH(const DiscreteKey& dkey, dkeys) + cardinalities_.insert(dkey); + } + + /* ************************************************************************* */ + void AllDiff::print(const std::string& s) const { + std::cout << s << ": AllDiff on "; + BOOST_FOREACH (Index dkey, keys_) + std::cout << dkey << " "; + std::cout << std::endl; + } + + /* ************************************************************************* */ + double AllDiff::operator()(const Values& values) const { + std::set < size_t > taken; // record values taken by keys + BOOST_FOREACH(Index dkey, keys_) { + size_t value = values.at(dkey); // get the value for that key + if (taken.count(value)) return 0.0;// check if value alreday taken + taken.insert(value);// if not, record it as taken and keep checking + } + return 1.0; + } + + /* ************************************************************************* */ + AllDiff::operator DecisionTreeFactor() const { + // We will do this by converting the allDif into many BinaryAllDiff constraints + DecisionTreeFactor converted; + size_t nrKeys = keys_.size(); + for (size_t i1 = 0; i1 < nrKeys; i1++) + for (size_t i2 = i1 + 1; i2 < nrKeys; i2++) { + BinaryAllDiff binary12(discreteKey(i1),discreteKey(i2)); + converted = converted * binary12; + } + return converted; + } + + /* ************************************************************************* */ + DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const { + // TODO: can we do this more efficiently? + return DecisionTreeFactor(*this) * f; + } + + /* ************************************************************************* */ + bool AllDiff::ensureArcConsistency(size_t j, std::vector& domains) const { + // Though strictly not part of allDiff, we check for + // a value in domains[j] that does not occur in any other connected domain. + // If found, we make this a singleton... + // TODO: make a new constraint where this really is true + Domain& Dj = domains[j]; + if (Dj.checkAllDiff(keys_, domains)) return true; + + // Check all other domains for singletons and erase corresponding values + // This is the same as arc-consistency on the equivalent binary constraints + bool changed = false; + BOOST_FOREACH(Index k, keys_) + if (k != j) { + const Domain& Dk = domains[k]; + if (Dk.isSingleton()) { // check if singleton + size_t value = Dk.firstValue(); + if (Dj.contains(value)) { + Dj.erase(value); // erase value if true + changed = true; + } + } + } + return changed; + } + + /* ************************************************************************* */ + DiscreteFactor::shared_ptr AllDiff::partiallyApply(const Values& values) const { + DiscreteKeys newKeys; + // loop over keys and add them only if they do not appear in values + BOOST_FOREACH(Index k, keys_) + if (values.find(k) == values.end()) { + newKeys.push_back(DiscreteKey(k,cardinalities_.at(k))); + } + return boost::make_shared(newKeys); + } + + /* ************************************************************************* */ + DiscreteFactor::shared_ptr AllDiff::partiallyApply( + const std::vector& domains) const { + DiscreteFactor::Values known; + BOOST_FOREACH(Index k, keys_) { + const Domain& Dk = domains[k]; + if (Dk.isSingleton()) + known[k] = Dk.firstValue(); + } + return partiallyApply(known); + } + + /* ************************************************************************* */ +} // namespace gtsam diff --git a/gtsam/discrete/AllDiff.h b/gtsam/discrete/AllDiff.h new file mode 100644 index 000000000..846dc335b --- /dev/null +++ b/gtsam/discrete/AllDiff.h @@ -0,0 +1,64 @@ +/* + * AllDiff.h + * @brief General "all-different" constraint + * @date Feb 6, 2012 + * @author Frank Dellaert + */ + +#pragma once + +#include +#include + +namespace gtsam { + + /** + * General AllDiff constraint + * Returns 1 if values for all keys are different, 0 otherwise + * DiscreteFactors are all awkward in that they have to store two types of keys: + * for each variable we have a Index and an Index. In this factor, we + * keep the Indices locally, and the Indices are stored in IndexFactor. + */ + class AllDiff: public DiscreteFactor { + + std::map cardinalities_; + + DiscreteKey discreteKey(size_t i) const { + Index j = keys_[i]; + return DiscreteKey(j,cardinalities_.at(j)); + } + + public: + + /// Constructor + AllDiff(const DiscreteKeys& dkeys); + + // print + virtual void print(const std::string& s = "") const; + + /// Calculate value = expensive ! + virtual double operator()(const Values& values) const; + + /// Convert into a decisiontree, can be *very* expensive ! + virtual operator DecisionTreeFactor() const; + + /// Multiply into a decisiontree + virtual DecisionTreeFactor operator*(const DecisionTreeFactor& f) const; + + /* + * Ensure Arc-consistency + * Arc-consistency involves creating binaryAllDiff constraints + * In which case the combinatorial hyper-arc explosion disappears. + * @param j domain to be checked + * @param domains all other domains + */ + bool ensureArcConsistency(size_t j, std::vector& domains) const; + + /// Partially apply known values + virtual DiscreteFactor::shared_ptr partiallyApply(const Values&) const; + + /// Partially apply known values, domain version + virtual DiscreteFactor::shared_ptr partiallyApply(const std::vector&) const; + }; + +} // namespace gtsam diff --git a/gtsam/discrete/Assignment.h b/gtsam/discrete/Assignment.h new file mode 100644 index 000000000..0150f6ff9 --- /dev/null +++ b/gtsam/discrete/Assignment.h @@ -0,0 +1,36 @@ +/* + * @file Assignment.h + * @brief An assignment from labels to a discrete value index (size_t) + * @author Frank Dellaert + * @date Feb 5, 2012 + */ + +#pragma once + +#include +#include +#include + +namespace gtsam { + + /** + * An assignment from labels to value index (size_t). + * Assigns to each label a value. Implemented as a simple map. + * A discrete factor takes an Assignment and returns a value. + */ + template + class Assignment: public std::map { + public: + void print(const std::string& s = "Assignment: ") const { + std::cout << s << ": "; + BOOST_FOREACH(const typename Assignment::value_type& keyValue, *this) + std::cout << "(" << keyValue.first << ", " << keyValue.second << ")"; + std::cout << std::endl; + } + + bool equals(const Assignment& other, double tol = 1e-9) const { + return (*this == other); + } + }; + +} // namespace gtsam diff --git a/gtsam/discrete/BinaryAllDiff.h b/gtsam/discrete/BinaryAllDiff.h new file mode 100644 index 000000000..31fe070c2 --- /dev/null +++ b/gtsam/discrete/BinaryAllDiff.h @@ -0,0 +1,87 @@ +/* + * BinaryAllDiff.h + * @brief Binary "all-different" constraint + * @date Feb 6, 2012 + * @author Frank Dellaert + */ + +#pragma once + +#include + +namespace gtsam { + + /** + * Binary AllDiff constraint + * Returns 1 if values for two keys are different, 0 otherwise + * DiscreteFactors are all awkward in that they have to store two types of keys: + * for each variable we have a Index and an Index. In this factor, we + * keep the Indices locally, and the Indices are stored in IndexFactor. + */ + class BinaryAllDiff: public DiscreteFactor { + + size_t cardinality0_, cardinality1_; /// cardinality + + public: + + /// Constructor + BinaryAllDiff(const DiscreteKey& key1, const DiscreteKey& key2) : + DiscreteFactor(key1.first, key2.first), + cardinality0_(key1.second), cardinality1_(key2.second) { + } + + // print + virtual void print(const std::string& s = "") const { + std::cout << s << ": BinaryAllDiff on " << keys_[0] << " and " << keys_[1] + << std::endl; + } + + /// Calculate value + virtual double operator()(const Values& values) const { + return (double) (values.at(keys_[0]) != values.at(keys_[1])); + } + + /// Convert into a decisiontree + virtual operator DecisionTreeFactor() const { + DiscreteKeys keys; + keys.push_back(DiscreteKey(keys_[0],cardinality0_)); + keys.push_back(DiscreteKey(keys_[1],cardinality1_)); + std::vector table; + for (size_t i1 = 0; i1 < cardinality0_; i1++) + for (size_t i2 = 0; i2 < cardinality1_; i2++) + table.push_back(i1 != i2); + DecisionTreeFactor converted(keys, table); + return converted; + } + + /// Multiply into a decisiontree + virtual DecisionTreeFactor operator*(const DecisionTreeFactor& f) const { + // TODO: can we do this more efficiently? + return DecisionTreeFactor(*this) * f; + } + + /* + * Ensure Arc-consistency + * @param j domain to be checked + * @param domains all other domains + */ + /// + bool ensureArcConsistency(size_t j, std::vector& domains) const { +// throw std::runtime_error( +// "BinaryAllDiff::ensureArcConsistency not implemented"); + return false; + } + + /// Partially apply known values + virtual DiscreteFactor::shared_ptr partiallyApply(const Values&) const { + throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); + } + + /// Partially apply known values, domain version + virtual DiscreteFactor::shared_ptr partiallyApply( + const std::vector&) const { + throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); + } + }; + +} // namespace gtsam diff --git a/gtsam/discrete/CMakeLists.txt b/gtsam/discrete/CMakeLists.txt new file mode 100644 index 000000000..d45340990 --- /dev/null +++ b/gtsam/discrete/CMakeLists.txt @@ -0,0 +1,38 @@ +# Install headers +set(subdir discrete) +file(GLOB discrete_headers "*.h") +# FIXME: exclude headers +install(FILES ${discrete_headers} DESTINATION include/gtsam/discrete) + +# Set up library dependencies +set (discrete_local_libs + discrete + inference + base + ccolamd +) + +# Exclude tests that don't work +set (discrete_excluded_tests +"${CMAKE_CURRENT_SOURCE_DIR}/tests/testTypedDiscreteFactor.cpp" +"${CMAKE_CURRENT_SOURCE_DIR}/tests/testTypedDiscreteFactorGraph.cpp" +"${CMAKE_CURRENT_SOURCE_DIR}/tests/testPotentialTable.cpp") + +# Add all tests +if (GTSAM_BUILD_TESTS) + gtsam_add_subdir_tests(discrete "${discrete_local_libs}" "gtsam-static" "${discrete_excluded_tests}") +endif() + +# add examples +foreach(example schedulingExample schedulingQuals12) + add_executable(${example} "examples/${example}.cpp") + add_dependencies(${example} gtsam-static) + target_link_libraries(${example} gtsam-static) + add_custom_target(${example}.run ${EXECUTABLE_OUTPUT_PATH}${example} ${ARGN}) +endforeach(example) + +# Build timing scripts +#if (GTSAM_BUILD_TIMING) +# gtsam_add_timing(discrete "${discrete_local_libs}") +#endif(GTSAM_BUILD_TIMING) + diff --git a/gtsam/discrete/CSP.cpp b/gtsam/discrete/CSP.cpp new file mode 100644 index 000000000..c0d57f320 --- /dev/null +++ b/gtsam/discrete/CSP.cpp @@ -0,0 +1,94 @@ +/* + * CSP.cpp + * @brief Constraint Satisfaction Problem class + * @date Feb 6, 2012 + * @author Frank Dellaert + */ + +#include +#include +#include +#include +#include + +namespace gtsam { + + /// Find the best total assignment - can be expensive + CSP::sharedValues CSP::optimalAssignment() const { + DiscreteSequentialSolver solver(*this); + DiscreteBayesNet::shared_ptr chordal = solver.eliminate(); + sharedValues mpe = optimize(*chordal); + return mpe; + } + + void CSP::runArcConsistency(size_t cardinality, size_t nrIterations, bool print) const { + // Create VariableIndex + VariableIndex index(*this); + // index.print(); + + size_t n = index.size(); + + // Initialize domains + std::vector < Domain > domains; + for (size_t j = 0; j < n; j++) + domains.push_back(Domain(DiscreteKey(j,cardinality))); + + // Create array of flags indicating a domain changed or not + std::vector changed(n); + + // iterate nrIterations over entire grid + for (size_t it = 0; it < nrIterations; it++) { + bool anyChange = false; + // iterate over all cells + for (size_t v = 0; v < n; v++) { + // keep track of which domains changed + changed[v] = false; + // loop over all factors/constraints for variable v + const VariableIndex::Factors& factors = index[v]; + BOOST_FOREACH(size_t f,factors) { + // if not already a singleton + if (!domains[v].isSingleton()) { + // get the constraint and call its ensureArcConsistency method + DiscreteFactor::shared_ptr factor = (*this)[f]; + changed[v] = factor->ensureArcConsistency(v,domains) || changed[v]; + } + } // f + if (changed[v]) anyChange = true; + } // v + if (!anyChange) break; + // TODO: Sudoku specific hack + if (print) { + if (cardinality == 9 && n == 81) { + for (size_t i = 0, v = 0; i < sqrt(n); i++) { + for (size_t j = 0; j < sqrt(n); j++, v++) { + if (changed[v]) cout << "*"; + domains[v].print(); + cout << "\t"; + } // i + cout << endl; + } // j + } else { + for (size_t v = 0; v < n; v++) { + if (changed[v]) cout << "*"; + domains[v].print(); + cout << "\t"; + } // v + } + cout << endl; + } // print + } // it + +#ifndef INPROGRESS + // Now create new problem with all singleton variables removed + // We do this by adding simplifying all factors using parial application + // TODO: create a new ordering as we go, to ensure a connected graph + // KeyOrdering ordering; + // vector dkeys; + BOOST_FOREACH(const DiscreteFactor::shared_ptr& factor, factors_) { + DiscreteFactor::shared_ptr reduced = factor->partiallyApply(domains); + if (print) reduced->print(); + } +#endif + } +} // gtsam + diff --git a/gtsam/discrete/CSP.h b/gtsam/discrete/CSP.h new file mode 100644 index 000000000..d423426fd --- /dev/null +++ b/gtsam/discrete/CSP.h @@ -0,0 +1,71 @@ +/* + * CSP.h + * @brief Constraint Satisfaction Problem class + * @date Feb 6, 2012 + * @author Frank Dellaert + */ + +#pragma once + +#include +#include +#include + +namespace gtsam { + + /** + * Constraint Satisfaction Problem class + * A specialization of a DiscreteFactorGraph. + * It knows about CSP-specific constraints and algorithms + */ + class CSP: public DiscreteFactorGraph { + + public: + /// Constructor + CSP() { + } + + /// Add a unary constraint, allowing only a single value + void addSingleValue(const DiscreteKey& dkey, size_t value) { + boost::shared_ptr factor(new SingleValue(dkey, value)); + push_back(factor); + } + + /// Add a binary AllDiff constraint + void addAllDiff(const DiscreteKey& key1, const DiscreteKey& key2) { + boost::shared_ptr factor( + new BinaryAllDiff(key1, key2)); + push_back(factor); + } + + /// Add a general AllDiff constraint + void addAllDiff(const DiscreteKeys& dkeys) { + boost::shared_ptr factor(new AllDiff(dkeys)); + push_back(factor); + } + + /// Find the best total assignment - can be expensive + sharedValues optimalAssignment() const; + + /* + * Perform loopy belief propagation + * True belief propagation would check for each value in domain + * whether any satisfying separator assignment can be found. + * This corresponds to hyper-arc consistency in CSP speak. + * This can be done by creating a mini-factor graph and search. + * For a nine-by-nine Sudoku, the search tree will be 8+6+6=20 levels deep. + * It will be very expensive to exclude values that way. + */ + // void applyBeliefPropagation(size_t nrIterations = 10) const; + /* + * Apply arc-consistency ~ Approximate loopy belief propagation + * We need to give the domains to a constraint, and it returns + * a domain whose values don't conflict in the arc-consistency way. + * TODO: should get cardinality from Indices + */ + void runArcConsistency(size_t cardinality, size_t nrIterations = 10, + bool print = false) const; + }; + +} // gtsam + diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h new file mode 100644 index 000000000..867a9c824 --- /dev/null +++ b/gtsam/discrete/DecisionTree-inl.h @@ -0,0 +1,667 @@ +/* + * @file DecisionTree.h + * @brief Decision Tree for use in DiscreteFactors + * @author Frank Dellaert + * @author Can Erdogan + * @date Jan 30, 2012 + */ + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace gtsam { + + using namespace boost::assign; + + /*********************************************************************************/ + // Node + /*********************************************************************************/ +#ifdef DT_DEBUG_MEMORY + template + int DecisionTree::Node::nrNodes = 0; +#endif + + /*********************************************************************************/ + // Leaf + /*********************************************************************************/ + template + class DecisionTree::Leaf: public DecisionTree::Node { + + /** constant stored in this leaf */ + Y constant_; + + public: + + /** Constructor from constant */ + Leaf(const Y& constant) : + constant_(constant) {} + + /** return the constant */ + const Y& constant() const { + return constant_; + } + + /// Leaf-Leaf equality + bool sameLeaf(const Leaf& q) const { + return constant_ == q.constant_; + } + + /// polymorphic equality: is q is a leaf, could be + bool sameLeaf(const Node& q) const { + return (q.isLeaf() && q.sameLeaf(*this)); + } + + /** equality up to tolerance */ + bool equals(const Node& q, double tol) const { + const Leaf* other = dynamic_cast (&q); + if (!other) return false; + return fabs(this->constant_ - other->constant_) < tol; + } + + /** print */ + void print(const std::string& s) const { + bool showZero = true; + if (showZero || constant_) std::cout << s << " Leaf " << constant_ << std::endl; + } + + /** to graphviz file */ + void dot(std::ostream& os, bool showZero) const { + if (showZero || constant_) os << "\"" << this->id() << "\" [label=\"" + << boost::format("%4.2g") % constant_ + << "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; // width=0.55, + } + + /** evaluate */ + const Y& operator()(const Assignment& x) const { + return constant_; + } + + /** apply unary operator */ + NodePtr apply(const Unary& op) const { + NodePtr f(new Leaf(op(constant_))); + return f; + } + + // Apply binary operator "h = f op g" on Leaf node + // Note op is not assumed commutative so we need to keep track of order + // Simply calls apply on argument to call correct virtual method: + // fL.apply_f_op_g(gL) -> gL.apply_g_op_fL(fL) (below) + // fL.apply_f_op_g(gC) -> gC.apply_g_op_fL(fL) (Choice) + NodePtr apply_f_op_g(const Node& g, const Binary& op) const { + return g.apply_g_op_fL(*this, op); + } + + // Applying binary operator to two leaves results in a leaf + NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const { + NodePtr h(new Leaf(op(fL.constant_, constant_))); // fL op gL + return h; + } + + // If second argument is a Choice node, call it's apply with leaf as second + NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const { + return fC.apply_fC_op_gL(*this, op); // operand order back to normal + } + + /** choose a branch, create new memory ! */ + NodePtr choose(const L& label, size_t index) const { + return NodePtr(new Leaf(constant())); + } + + bool isLeaf() const { return true; } + + }; // Leaf + + /*********************************************************************************/ + // Choice + /*********************************************************************************/ + template + class DecisionTree::Choice: public DecisionTree::Node { + + /** the label of the variable on which we split */ + L label_; + + /** The children of this Choice node. */ + std::vector branches_; + + private: + /** incremental allSame */ + size_t allSame_; + + typedef boost::shared_ptr ChoicePtr; + + public: + + virtual ~Choice() { +#ifdef DT_DEBUG_MEMORY + std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id() << std::std::endl; +#endif + } + + /** If all branches of a choice node f are the same, just return a branch */ + static NodePtr Unique(const ChoicePtr& f) { +#ifndef DT_NO_PRUNING + if (f->allSame_) { + assert(f->branches().size() > 0); + NodePtr f0 = f->branches_[0]; + assert(f0->isLeaf()); + NodePtr newLeaf(new Leaf(boost::dynamic_pointer_cast(f0)->constant())); + return newLeaf; + } else +#endif + return f; + } + + bool isLeaf() const { return false; } + + /** Constructor, given choice label and mandatory expected branch count */ + Choice(const L& label, size_t count) : + label_(label), allSame_(true) { + branches_.reserve(count); + } + + /** + * Construct from applying binary op to two Choice nodes + */ + Choice(const Choice& f, const Choice& g, const Binary& op) : + allSame_(true) { + + // Choose what to do based on label + if (f.label() > g.label()) { + // f higher than g + label_ = f.label(); + size_t count = f.nrChoices(); + branches_.reserve(count); + for (size_t i = 0; i < count; i++) + push_back(f.branches_[i]->apply_f_op_g(g, op)); + } else if (g.label() > f.label()) { + // f lower than g + label_ = g.label(); + size_t count = g.nrChoices(); + branches_.reserve(count); + for (size_t i = 0; i < count; i++) + push_back(g.branches_[i]->apply_g_op_fC(f, op)); + } else { + // f same level as g + label_ = f.label(); + size_t count = f.nrChoices(); + branches_.reserve(count); + for (size_t i = 0; i < count; i++) + push_back(f.branches_[i]->apply_f_op_g(*g.branches_[i], op)); + } + } + + const L& label() const { + return label_; + } + + size_t nrChoices() const { + return branches_.size(); + } + + const std::vector& branches() const { + return branches_; + } + + /** add a branch: TODO merge into constructor */ + void push_back(const NodePtr& node) { + // allSame_ is restricted to leaf nodes in a decision tree + if (allSame_ && !branches_.empty()) { + allSame_ = node->sameLeaf(*branches_.back()); + } + branches_.push_back(node); + } + + /** print (as a tree) */ + void print(const std::string& s) const { + std::cout << s << " Choice("; + // std::cout << this << ","; + std::cout << label_ << ") " << std::endl; + for (size_t i = 0; i < branches_.size(); i++) + branches_[i]->print((boost::format("%s %d") % s % i).str()); + } + + /** output to graphviz (as a a graph) */ + void dot(std::ostream& os, bool showZero) const { + os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_ + << "\"]\n"; + for (size_t i = 0; i < branches_.size(); i++) { + NodePtr branch = branches_[i]; + + // Check if zero + if (!showZero) { + const Leaf* leaf = dynamic_cast (branch.get()); + if (leaf && !leaf->constant()) continue; + } + + os << "\"" << this->id() << "\" -> \"" << branch->id() << "\""; + if (i == 0) os << " [style=dashed]"; + if (i > 1) os << " [style=bold]"; + os << std::endl; + branch->dot(os, showZero); + } + } + + /// Choice-Leaf equality: always false + bool sameLeaf(const Leaf& q) const { + return false; + } + + /// polymorphic equality: if q is a leaf, could be... + bool sameLeaf(const Node& q) const { + return (q.isLeaf() && q.sameLeaf(*this)); + } + + /** equality up to tolerance */ + bool equals(const Node& q, double tol) const { + const Choice* other = dynamic_cast (&q); + if (!other) return false; + if (this->label_ != other->label_) return false; + if (branches_.size() != other->branches_.size()) return false; + // we don't care about shared pointers being equal here + for (size_t i = 0; i < branches_.size(); i++) + if (!(branches_[i]->equals(*(other->branches_[i]), tol))) return false; + return true; + } + + /** evaluate */ + const Y& operator()(const Assignment& x) const { +#ifndef NDEBUG + typename Assignment::const_iterator it = x.find(label_); + if (it == x.end()) { + std::cout << "Trying to find value for " << label_ << std::endl; + throw std::invalid_argument( + "DecisionTree::operator(): value undefined for a label"); + } +#endif + size_t index = x.at(label_); + NodePtr child = branches_[index]; + return (*child)(x); + } + + /** + * Construct from applying unary op to a Choice node + */ + Choice(const L& label, const Choice& f, const Unary& op) : + label_(label), allSame_(true) { + + branches_.reserve(f.branches_.size()); // reserve space + BOOST_FOREACH (const NodePtr& branch, f.branches_) + push_back(branch->apply(op)); + } + + /** apply unary operator */ + NodePtr apply(const Unary& op) const { + boost::shared_ptr r(new Choice(label_, *this, op)); + return Unique(r); + } + + // Apply binary operator "h = f op g" on Choice node + // Note op is not assumed commutative so we need to keep track of order + // Simply calls apply on argument to call correct virtual method: + // fC.apply_f_op_g(gL) -> gL.apply_g_op_fC(fC) -> (Leaf) + // fC.apply_f_op_g(gC) -> gC.apply_g_op_fC(fC) -> (below) + NodePtr apply_f_op_g(const Node& g, const Binary& op) const { + return g.apply_g_op_fC(*this, op); + } + + // If second argument of binary op is Leaf node, recurse on branches + NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const { + boost::shared_ptr h(new Choice(label(), nrChoices())); + BOOST_FOREACH(NodePtr branch, branches_) + h->push_back(fL.apply_f_op_g(*branch, op)); + return Unique(h); + } + + // If second argument of binary op is Choice, call constructor + NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const { + boost::shared_ptr h(new Choice(fC, *this, op)); + return Unique(h); + } + + // If second argument of binary op is Leaf + template + NodePtr apply_fC_op_gL(const Leaf& gL, OP op) const { + boost::shared_ptr h(new Choice(label(), nrChoices())); + BOOST_FOREACH(const NodePtr& branch, branches_) + h->push_back(branch->apply_f_op_g(gL, op)); + return Unique(h); + } + + /** choose a branch, recursively */ + NodePtr choose(const L& label, size_t index) const { + if (label_ == label) + return branches_[index]; // choose branch + + // second case, not label of interest, just recurse + boost::shared_ptr r(new Choice(label_, branches_.size())); + BOOST_FOREACH(const NodePtr& branch, branches_) + r->push_back(branch->choose(label, index)); + return Unique(r); + } + + }; // Choice + + /*********************************************************************************/ + // DecisionTree + /*********************************************************************************/ + template + DecisionTree::DecisionTree() { + } + + template + DecisionTree::DecisionTree(const NodePtr& root) : + root_(root) { + } + + /*********************************************************************************/ + template + DecisionTree::DecisionTree(const Y& y) { + root_ = NodePtr(new Leaf(y)); + } + + /*********************************************************************************/ + template + DecisionTree::DecisionTree(// + const L& label, const Y& y1, const Y& y2) { + boost::shared_ptr a(new Choice(label, 2)); + NodePtr l1(new Leaf(y1)), l2(new Leaf(y2)); + a->push_back(l1); + a->push_back(l2); + root_ = Choice::Unique(a); + } + + /*********************************************************************************/ + template + DecisionTree::DecisionTree(// + const LabelC& labelC, const Y& y1, const Y& y2) { + if (labelC.second != 2) throw std::invalid_argument( + "DecisionTree: binary constructor called with non-binary label"); + boost::shared_ptr a(new Choice(labelC.first, 2)); + NodePtr l1(new Leaf(y1)), l2(new Leaf(y2)); + a->push_back(l1); + a->push_back(l2); + root_ = Choice::Unique(a); + } + + /*********************************************************************************/ + template + DecisionTree::DecisionTree(const std::vector& labelCs, + const std::vector& ys) { + // call recursive Create + root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); + } + + /*********************************************************************************/ + template + DecisionTree::DecisionTree(const std::vector& labelCs, + const std::string& table) { + + // Convert std::string to doubles + std::vector ys; + std::istringstream iss(table); + copy(std::istream_iterator(iss), std::istream_iterator(), + back_inserter(ys)); + + // now call recursive Create + root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); + } + + /*********************************************************************************/ + template + template DecisionTree::DecisionTree( + Iterator begin, Iterator end, const L& label) { + root_ = compose(begin, end, label); + } + + /*********************************************************************************/ + template + DecisionTree::DecisionTree(const L& label, + const DecisionTree& f0, const DecisionTree& f1) { + std::vector functions; + functions += f0, f1; + root_ = compose(functions.begin(), functions.end(), label); + } + + /*********************************************************************************/ + template + template + DecisionTree::DecisionTree(const DecisionTree& other, + const std::map& map, boost::function op) { + root_ = convert(other.root_, map, op); + } + + /*********************************************************************************/ + // Called by two constructors above. + // Takes a label and a corresponding range of decision trees, and creates a new + // decision tree. However, the order of the labels needs to be respected, so we + // cannot just create a root Choice node on the label: if the label is not the + // highest label, we need to do a complicated and expensive recursive call. + template template + typename DecisionTree::NodePtr DecisionTree::compose( + Iterator begin, Iterator end, const L& label) const { + + // find highest label among branches + boost::optional highestLabel; + boost::optional nrChoices; + for (Iterator it = begin; it != end; it++) { + if (it->root_->isLeaf()) continue; + boost::shared_ptr c = boost::dynamic_pointer_cast (it->root_); + if (!highestLabel || c->label() > *highestLabel) { + highestLabel.reset(c->label()); + nrChoices.reset(c->nrChoices()); + } + } + + // if label is already in correct order, just put together a choice on label + if (!highestLabel || label > *highestLabel) { + boost::shared_ptr choiceOnLabel(new Choice(label, end - begin)); + for (Iterator it = begin; it != end; it++) + choiceOnLabel->push_back(it->root_); + return Choice::Unique(choiceOnLabel); + } + + // Set up a new choice on the highest label + boost::shared_ptr choiceOnHighestLabel(new Choice(*highestLabel, *nrChoices)); + // now, for all possible values of highestLabel + for (size_t index = 0; index < *nrChoices; index++) { + // make a new set of functions for composing by iterating over the given + // functions, and selecting the appropriate branch. + std::vector functions; + for (Iterator it = begin; it != end; it++) { + // by restricting the input functions to value i for labelBelow + DecisionTree chosen = it->choose(*highestLabel, index); + functions.push_back(chosen); + } + // We then recurse, for all values of the highest label + NodePtr fi = compose(functions.begin(), functions.end(), label); + choiceOnHighestLabel->push_back(fi); + } + return Choice::Unique(choiceOnHighestLabel); + } + + /*********************************************************************************/ + // "create" is a bit of a complicated thing, but very useful. + // It takes a range of labels and a corresponding range of values, + // and creates a decision tree, as follows: + // - if there is only one label, creates a choice node with values in leaves + // - otherwise, it evenly splits up the range of values and creates a tree for + // each sub-range, and assigns that tree to first label's choices + // Example: + // create([B A],[1 2 3 4]) would call + // create([A],[1 2]) + // create([A],[3 4]) + // and produce + // B=0 + // A=0: 1 + // A=1: 2 + // B=1 + // A=0: 3 + // A=1: 4 + // Note, through the magic of "compose", create([A B],[1 2 3 4]) will produce + // exactly the same tree as above: the highest label is always the root. + // However, it will be *way* faster if labels are given highest to lowest. + template + template + typename DecisionTree::NodePtr DecisionTree::create( + It begin, It end, ValueIt beginY, ValueIt endY) const { + + // get crucial counts + size_t nrChoices = begin->second; + size_t size = endY - beginY; + + // Find the next key to work on + It labelC = begin + 1; + if (labelC == end) { + // Base case: only one key left + // Create a simple choice node with values as leaves. + if (size != nrChoices) { + std::cout << "Trying to create DD on " << begin->first << std::endl; + std::cout << boost::format("DecisionTree::create: expected %d values but got %d instead") % nrChoices % size << std::endl; + throw std::invalid_argument("DecisionTree::create invalid argument"); + } + boost::shared_ptr choice(new Choice(begin->first, endY - beginY)); + for (ValueIt y = beginY; y != endY; y++) + choice->push_back(NodePtr(new Leaf(*y))); + return Choice::Unique(choice); + } + + // Recursive case: perform "Shannon expansion" + // Creates one tree (i.e.,function) for each choice of current key + // by calling create recursively, and then puts them all together. + std::vector functions; + size_t split = size / nrChoices; + for (size_t i = 0; i < nrChoices; i++, beginY += split) { + NodePtr f = create(labelC, end, beginY, beginY + split); + functions += DecisionTree(f); + } + return compose(functions.begin(), functions.end(), begin->first); + } + + /*********************************************************************************/ + template + template + typename DecisionTree::NodePtr DecisionTree::convert( + const typename DecisionTree::NodePtr& f, const std::map& map, + boost::function op) { + + typedef DecisionTree MX; + typedef typename MX::Leaf MXLeaf; + typedef typename MX::Choice MXChoice; + typedef typename MX::NodePtr MXNodePtr; + typedef DecisionTree LY; + + // ugliness below because apparently we can't have templated virtual functions + // If leaf, apply unary conversion "op" and create a unique leaf + const MXLeaf* leaf = dynamic_cast (f.get()); + if (leaf) return NodePtr(new Leaf(op(leaf->constant()))); + + // Check if Choice + boost::shared_ptr choice = boost::dynamic_pointer_cast (f); + if (!choice) throw std::invalid_argument( + "DecisionTree::Convert: Invalid NodePtr"); + + // get new label + M oldLabel = choice->label(); + L newLabel = map.at(oldLabel); + + // put together via Shannon expansion otherwise not sorted. + std::vector functions; + BOOST_FOREACH(const MXNodePtr& branch, choice->branches()) { + LY converted(convert(branch, map, op)); + functions += converted; + } + return LY::compose(functions.begin(), functions.end(), newLabel); + } + + /*********************************************************************************/ + template + bool DecisionTree::equals(const DecisionTree& other, double tol) const { + return root_->equals(*other.root_, tol); + } + + template + void DecisionTree::print(const std::string& s) const { + root_->print(s); + } + + template + bool DecisionTree::operator==(const DecisionTree& other) const { + return root_->equals(*other.root_); + } + + template + const Y& DecisionTree::operator()(const Assignment& x) const { + return root_->operator ()(x); + } + + template + DecisionTree DecisionTree::apply(const Unary& op) const { + return DecisionTree(root_->apply(op)); + } + + /*********************************************************************************/ + template + DecisionTree DecisionTree::apply(const DecisionTree& g, + const Binary& op) const { + // apply the operaton on the root of both diagrams + NodePtr h = root_->apply_f_op_g(*g.root_, op); + // create a new class with the resulting root "h" + DecisionTree result(h); + return result; + } + + /*********************************************************************************/ + // The way this works: + // We have an ADT, picture it as a tree. + // At a certain depth, we have a branch on "label". + // The function "choose(label,index)" will return a tree of one less depth, + // where there is no more branch on "label": only the subtree under that + // branch point corresponding to the value "index" is left instead. + // The function below get all these smaller trees and "ops" them together. + template + DecisionTree DecisionTree::combine(const L& label, + size_t cardinality, const Binary& op) const { + DecisionTree result = choose(label, 0); + for (size_t index = 1; index < cardinality; index++) { + DecisionTree chosen = choose(label, index); + result = result.apply(chosen, op); + } + return result; + } + + /*********************************************************************************/ + template + void DecisionTree::dot(std::ostream& os, bool showZero) const { + os << "digraph G {\n"; + root_->dot(os, showZero); + os << " [ordering=out]}" << std::endl; + } + + template + void DecisionTree::dot(const std::string& name, bool showZero) const { + std::ofstream os((name + ".dot").c_str()); + dot(os, showZero); + system( + ("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str()); + } + +/*********************************************************************************/ + +} // namespace gtsam + + diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h new file mode 100644 index 000000000..003656945 --- /dev/null +++ b/gtsam/discrete/DecisionTree.h @@ -0,0 +1,218 @@ +/* + * @file DecisionTree.h + * @brief Decision Tree for use in DiscreteFactors + * @author Frank Dellaert + * @author Can Erdogan + * @date Jan 30, 2012 + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace gtsam { + + /** + * Algebraic Decision Trees + * L = label for variables + * Y = function range (any algebra), e.g., bool, int, double + */ + template + class DecisionTree { + + public: + + /** Handy typedefs for unary and binary function types */ + typedef boost::function Unary; + typedef boost::function Binary; + + /** A label annotated with cardinality */ + typedef std::pair LabelC; + + /** DD's consist of Leaf and Choice nodes, both subclasses of Node */ + class Leaf; + class Choice; + + /** ------------------------ Node base class --------------------------- */ + class Node { + public: + typedef boost::shared_ptr Ptr; + +#ifdef DT_DEBUG_MEMORY + static int nrNodes; +#endif + + // Constructor + Node() { +#ifdef DT_DEBUG_MEMORY + std::cout << ++nrNodes << " constructed " << id() << std::endl; std::cout.flush(); + +#endif + } + + // Destructor + virtual ~Node() { +#ifdef DT_DEBUG_MEMORY + std::cout << --nrNodes << " destructed " << id() << std::endl; std::cout.flush(); + +#endif + } + + // Unique ID for dot files + const void* id() const { return this; } + + // everything else is virtual, no documentation here as internal + virtual void print(const std::string& s = "") const = 0; + virtual void dot(std::ostream& os, bool showZero) const = 0; + virtual bool sameLeaf(const Leaf& q) const = 0; + virtual bool sameLeaf(const Node& q) const = 0; + virtual bool equals(const Node& other, double tol = 1e-9) const = 0; + virtual const Y& operator()(const Assignment& x) const = 0; + virtual Ptr apply(const Unary& op) const = 0; + virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0; + virtual Ptr apply_g_op_fL(const Leaf&, const Binary&) const = 0; + virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0; + virtual Ptr choose(const L& label, size_t index) const = 0; + virtual bool isLeaf() const = 0; + }; + /** ------------------------ Node base class --------------------------- */ + + public: + + /** A function is a shared pointer to the root of an ADD */ + typedef typename Node::Ptr NodePtr; + + /* an AlgebraicDecisionTree just contains the root */ + NodePtr root_; + + protected: + + /** Internal recursive function to create from keys, cardinalities, and Y values */ + template + NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const; + + /** Convert to a different type */ + template NodePtr + convert(const typename DecisionTree::NodePtr& f, const std::map& map, boost::function op); + + /** Default constructor */ + DecisionTree(); + + public: + + /// @name Standard Constructors + /// @{ + + /** Create a constant */ + DecisionTree(const Y& y); + + /** Create a new leaf function splitting on a variable */ + DecisionTree(const L& label, const Y& y1, const Y& y2); + + /** Allow Label+Cardinality for convenience */ + DecisionTree(const LabelC& label, const Y& y1, const Y& y2); + + /** Create from keys and string table */ + DecisionTree(const std::vector& labelCs, const std::vector& ys); + + /** Create from keys and string table */ + DecisionTree(const std::vector& labelCs, const std::string& table); + + /** Create DecisionTree from others */ + template + DecisionTree(Iterator begin, Iterator end, const L& label); + + /** Create DecisionTree from others others (binary version) */ + DecisionTree(const L& label, // + const DecisionTree& f0, const DecisionTree& f1); + + /** Convert from a different type */ + template + DecisionTree(const DecisionTree& other, + const std::map& map, boost::function op); + + /// @} + /// @name Testable + /// @{ + + /** GTSAM-style print */ + void print(const std::string& s = "DecisionTree") const; + + // Testable + bool equals(const DecisionTree& other, double tol = 1e-9) const; + + /// @} + /// @name Standard Interface + /// @{ + + /** Make virtual */ + virtual ~DecisionTree() { + } + + /** equality */ + bool operator==(const DecisionTree& q) const; + + /** evaluate */ + const Y& operator()(const Assignment& x) const; + + /** apply Unary operation "op" to f */ + DecisionTree apply(const Unary& op) const; + + /** apply binary operation "op" to f and g */ + DecisionTree apply(const DecisionTree& g, const Binary& op) const; + + /** create a new function where value(label)==index */ + DecisionTree choose(const L& label, size_t index) const { + NodePtr newRoot = root_->choose(label, index); + return DecisionTree(newRoot); + } + + /** combine subtrees on key with binary operation "op" */ + DecisionTree combine(const L& label, size_t cardinality, const Binary& op) const; + + /** combine with LabelC for convenience */ + DecisionTree combine(const LabelC& labelC, const Binary& op) const { + return combine(labelC.first, labelC.second, op); + } + + /** output to graphviz format, stream version */ + void dot(std::ostream& os, bool showZero = true) const; + + /** output to graphviz format, open a file */ + void dot(const std::string& name, bool showZero = true) const; + + /// @name Advanced Interface + /// @{ + + // internal use only + DecisionTree(const NodePtr& root); + + // internal use only + template NodePtr + compose(Iterator begin, Iterator end, const L& label) const; + + /// @} + + }; // DecisionTree + + /** free versions of apply */ + + template + DecisionTree apply(const DecisionTree& f, + const typename DecisionTree::Unary& op) { + return f.apply(op); + } + + template + DecisionTree apply(const DecisionTree& f, + const DecisionTree& g, + const typename DecisionTree::Binary& op) { + return f.apply(g, op); + } + +} // namespace gtsam diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp new file mode 100644 index 000000000..d66d16d99 --- /dev/null +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -0,0 +1,91 @@ +/* + * DecisionTreeFactor.cpp + * @brief: discrete factor + * @date Feb 14, 2011 + * @author Duy-Nguyen Ta + * @author Frank Dellaert + */ + +#include +#include +#include + +#include + +using namespace std; + +namespace gtsam { + + /* ******************************************************************************** */ + DecisionTreeFactor::DecisionTreeFactor() { + } + + /* ******************************************************************************** */ + DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, + const ADT& potentials) : + DiscreteFactor(keys.indices()), Potentials(keys, potentials) { + } + + /* *************************************************************************/ + DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) : + DiscreteFactor(c.keys()), Potentials(c) { + } + + /* ************************************************************************* */ + bool DecisionTreeFactor::equals(const This& other, double tol) const { + return IndexFactor::equals(other, tol) && Potentials::equals(other, tol); + } + + /* ************************************************************************* */ + void DecisionTreeFactor::print(const string& s) const { + cout << s << ":\n"; + IndexFactor::print("IndexFactor:"); + Potentials::print("Potentials:"); + } + + /* ************************************************************************* */ + DecisionTreeFactor DecisionTreeFactor::apply // + (const DecisionTreeFactor& f, ADT::Binary op) const { + map cs; // new cardinalities + // make unique key-cardinality map + BOOST_FOREACH(Index j, keys()) cs[j] = cardinality(j); + BOOST_FOREACH(Index j, f.keys()) cs[j] = f.cardinality(j); + // Convert map into keys + DiscreteKeys keys; + BOOST_FOREACH(const DiscreteKey& key, cs) + keys.push_back(key); + // apply operand + ADT result = ADT::apply(f, op); + // Make a new factor + return DecisionTreeFactor(keys, result); + } + + /* ************************************************************************* */ + DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine // + (size_t nrFrontals, ADT::Binary op) const { + + if (nrFrontals == 0 || nrFrontals > size()) throw invalid_argument( + (boost::format( + "DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d") + % nrFrontals % size()).str()); + + // sum over nrFrontals keys + size_t i; + ADT result(*this); + for (i = 0; i < nrFrontals; i++) { + Index j = keys()[i]; + result = result.combine(j, cardinality(j), op); + } + + // create new factor, note we start keys after nrFrontals + DiscreteKeys dkeys; + for (; i < keys().size(); i++) { + Index j = keys()[i]; + dkeys.push_back(DiscreteKey(j,cardinality(j))); + } + shared_ptr f(new DecisionTreeFactor(dkeys, result)); + return f; + } + +/* ************************************************************************* */ +} // namespace gtsam diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h new file mode 100644 index 000000000..e98c7020b --- /dev/null +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -0,0 +1,145 @@ +/* + * DecisionTreeFactor.h + * + * @date Feb 14, 2011 + * @author Duy-Nguyen Ta + */ + +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include + +namespace gtsam { + + class DiscreteConditional; + + /** + * A discrete probabilistic factor + */ + class DecisionTreeFactor: public DiscreteFactor, public Potentials { + + public: + + // typedefs needed to play nice with gtsam + typedef DecisionTreeFactor This; + typedef DiscreteConditional ConditionalType; + typedef boost::shared_ptr shared_ptr; + + public: + + /// @name Standard Constructors + /// @{ + + /** Default constructor for I/O */ + DecisionTreeFactor(); + + /** Constructor from Indices, Ordering, and AlgebraicDecisionDiagram */ + DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials); + + /** Constructor from Indices and (string or doubles) */ + template + DecisionTreeFactor(const DiscreteKeys& keys, SOURCE table) : + DiscreteFactor(keys.indices()), Potentials(keys, table) { + } + + /** Construct from a DiscreteConditional type */ + DecisionTreeFactor(const DiscreteConditional& c); + + /// @} + /// @name Testable + /// @{ + + /// equality + bool equals(const DecisionTreeFactor& other, double tol = 1e-9) const; + + // print + void print(const std::string& s = "DecisionTreeFactor: ") const; + + /// @} + /// @name Standard Interface + /// @{ + + /// Value is just look up in AlgebraicDecisonTree + virtual double operator()(const Values& values) const { + return Potentials::operator()(values); + } + + /// multiply two factors + DecisionTreeFactor operator*(const DecisionTreeFactor& f) const { + return apply(f, ADT::Ring::mul); + } + + /// divide by factor f (safely) + DecisionTreeFactor operator/(const DecisionTreeFactor& f) const { + return apply(f, safe_div); + } + + /// Convert into a decisiontree + virtual operator DecisionTreeFactor() const { + return *this; + } + + /// Create new factor by summing all values with the same separator values + shared_ptr sum(size_t nrFrontals) const { + return combine(nrFrontals, ADT::Ring::add); + } + + /// Create new factor by maximizing over all values with the same separator values + shared_ptr max(size_t nrFrontals) const { + return combine(nrFrontals, ADT::Ring::max); + } + + /// @} + /// @name Advanced Interface + /// @{ + + /** + * Apply binary operator (*this) "op" f + * @param f the second argument for op + * @param op a binary operator that operates on AlgebraicDecisionDiagram potentials + */ + DecisionTreeFactor apply(const DecisionTreeFactor& f, ADT::Binary op) const; + + /** + * Combine frontal variables using binary operator "op" + * @param nrFrontals nr. of frontal to combine variables in this factor + * @param op a binary operator that operates on AlgebraicDecisionDiagram potentials + * @return shared pointer to newly created DecisionTreeFactor + */ + shared_ptr combine(size_t nrFrontals, ADT::Binary op) const; + + /* + * Ensure Arc-consistency + * @param j domain to be checked + * @param domains all other domains + */ + /// + bool ensureArcConsistency(size_t j, std::vector& domains) const { +// throw std::runtime_error( +// "DecisionTreeFactor::ensureArcConsistency not implemented"); + return false; + } + + /// Partially apply known values + virtual DiscreteFactor::shared_ptr partiallyApply(const Values&) const { + throw std::runtime_error("DecisionTreeFactor::partiallyApply not implemented"); + } + + /// Partially apply known values, domain version + virtual DiscreteFactor::shared_ptr partiallyApply( + const std::vector&) const { + throw std::runtime_error("DecisionTreeFactor::partiallyApply not implemented"); + } + /// @} + }; +// DecisionTreeFactor + +}// namespace gtsam diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp new file mode 100644 index 000000000..c7f09d3c2 --- /dev/null +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -0,0 +1,48 @@ +/* + * DiscreteBayesNet.cpp + * + * @date Feb 15, 2011 + * @author Duy-Nguyen Ta + * @author Frank Dellaert + */ + +#include +#include +#include +#include + +namespace gtsam { + + // Explicitly instantiate so we don't have to include everywhere + template class BayesNet ; + + /* ************************************************************************* */ + void add_front(DiscreteBayesNet& bayesNet, const Signature& s) { + bayesNet.push_front(boost::make_shared(s)); + } + + /* ************************************************************************* */ + void add(DiscreteBayesNet& bayesNet, const Signature& s) { + bayesNet.push_back(boost::make_shared(s)); + } + + /* ************************************************************************* */ + DiscreteFactor::sharedValues optimize(const DiscreteBayesNet& bn) { + // solve each node in turn in topological sort order (parents first) + DiscreteFactor::sharedValues result(new DiscreteFactor::Values()); + BOOST_REVERSE_FOREACH (DiscreteConditional::shared_ptr conditional, bn) + conditional->solveInPlace(*result); + return result; + } + + /* ************************************************************************* */ + DiscreteFactor::sharedValues sample(const DiscreteBayesNet& bn) { + // sample each node in turn in topological sort order (parents first) + DiscreteFactor::sharedValues result(new DiscreteFactor::Values()); + BOOST_REVERSE_FOREACH(DiscreteConditional::shared_ptr conditional, bn) + conditional->sampleInPlace(*result); + return result; + } + +/* ************************************************************************* */ +} // namespace diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h new file mode 100644 index 000000000..418a7aa2d --- /dev/null +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -0,0 +1,33 @@ +/* + * DiscreteBayesNet.h + * + * @date Feb 15, 2011 + * @author Duy-Nguyen Ta + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace gtsam { + + typedef BayesNet DiscreteBayesNet; + + /** Add a DiscreteCondtional */ + void add(DiscreteBayesNet&, const Signature& s); + + /** Add a DiscreteCondtional in front, when listing parents first*/ + void add_front(DiscreteBayesNet&, const Signature& s); + + /** Optimize function for back-substitution. */ + DiscreteFactor::sharedValues optimize(const DiscreteBayesNet& bn); + + /** Do ancestral sampling */ + DiscreteFactor::sharedValues sample(const DiscreteBayesNet& bn); + +} // namespace + diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp new file mode 100644 index 000000000..4791af6e8 --- /dev/null +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -0,0 +1,152 @@ +/* + * DiscreteConditional.cpp + * + * @date Feb 14, 2011 + * @author Duy-Nguyen Ta + * @author Frank Dellaert + */ + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +using namespace std; + +namespace gtsam { + + /* ******************************************************************************** */ + DiscreteConditional::DiscreteConditional(const size_t nrFrontals, + const DecisionTreeFactor& f) : + IndexConditional(f.keys(), nrFrontals), Potentials( + f / (*f.sum(nrFrontals))) { + } + + /* ******************************************************************************** */ + DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint, + const DecisionTreeFactor& marginal) : + IndexConditional(joint.keys(), joint.size() - marginal.size()), Potentials( + ISDEBUG("DiscreteConditional::COUNT") ? joint : joint / marginal) { + assert(nrFrontals() == 1); + if (ISDEBUG("DiscreteConditional::DiscreteConditional")) cout + << (firstFrontalKey()) << endl; + } + + /* ******************************************************************************** */ + DiscreteConditional::DiscreteConditional(const Signature& signature) : + IndexConditional(signature.indices(), 1), Potentials( + signature.discreteKeysParentsFirst(), signature.cpt()) { + } + + /* ******************************************************************************** */ + Potentials::ADT DiscreteConditional::choose( + const Values& parentsValues) const { + ADT pFS(*this); + BOOST_FOREACH(Index key, parents()) + try { + Index j = (key); + size_t value = parentsValues.at(j); + pFS = pFS.choose(j, value); + } catch (exception& e) { + throw runtime_error( + "DiscreteConditional::choose: parent value missing"); + }; + return pFS; + } + + /* ******************************************************************************** */ + void DiscreteConditional::solveInPlace(Values& values) const { + assert(nrFrontals() == 1); + Index j = (firstFrontalKey()); + size_t mpe = solve(values); // Solve for variable + values[j] = mpe; // store result in partial solution + } + + /* ******************************************************************************** */ + void DiscreteConditional::sampleInPlace(Values& values) const { + assert(nrFrontals() == 1); + Index j = (firstFrontalKey()); + size_t sampled = sample(values); // Sample variable + values[j] = sampled; // store result in partial solution + } + + /* ******************************************************************************** */ + size_t DiscreteConditional::solve(const Values& parentsValues) const { + + // TODO: is this really the fastest way? I think it is. + ADT pFS = choose(parentsValues); // P(F|S=parentsValues) + + // Then, find the max over all remaining + // TODO, only works for one key now, seems horribly slow this way + size_t mpe = 0; + Values frontals; + double maxP = 0; + assert(nrFrontals() == 1); + Index j = (firstFrontalKey()); + for (size_t value = 0; value < cardinality(j); value++) { + frontals[j] = value; + double pValueS = pFS(frontals); // P(F=value|S=parentsValues) + // Update MPE solution if better + if (pValueS > maxP) { + maxP = pValueS; + mpe = value; + } + } + return mpe; + } + + /* ******************************************************************************** */ + size_t DiscreteConditional::sample(const Values& parentsValues) const { + + using boost::uniform_real; + static boost::mt19937 gen(2); // random number generator + + bool debug = ISDEBUG("DiscreteConditional::sample"); + + // Get the correct conditional density + ADT pFS = choose(parentsValues); // P(F|S=parentsValues) + if (debug) GTSAM_PRINT(pFS); + + // get cumulative distribution function (cdf) + // TODO, only works for one key now, seems horribly slow this way + assert(nrFrontals() == 1); + Index j = (firstFrontalKey()); + size_t nj = cardinality(j); + vector cdf(nj); + Values frontals; + double sum = 0; + for (size_t value = 0; value < nj; value++) { + frontals[j] = value; + double pValueS = pFS(frontals); // P(F=value|S=parentsValues) + sum += pValueS; // accumulate + if (debug) cout << sum << " "; + if (pValueS == 1) { + if (debug) cout << "--> " << value << endl; + return value; // shortcut exit + } + cdf[value] = sum; + } + + // inspired by http://www.boost.org/doc/libs/1_46_1/doc/html/boost_random/tutorial.html + uniform_real<> dist(0, cdf.back()); + boost::variate_generator > die(gen, dist); + size_t sampled = lower_bound(cdf.begin(), cdf.end(), die()) - cdf.begin(); + if (debug) cout << "-> " << sampled << endl; + + return sampled; + + return 0; + } + +/* ******************************************************************************** */ + +} // namespace diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h new file mode 100644 index 000000000..a11a6368f --- /dev/null +++ b/gtsam/discrete/DiscreteConditional.h @@ -0,0 +1,110 @@ +/* + * DiscreteConditional.h + * + * @date Feb 14, 2011 + * @author Duy-Nguyen Ta + * @author Frank Dellaert + */ + +#pragma once + +#include +#include +#include +#include + +namespace gtsam { + + /** + * Discrete Conditional Density + * Derives from DecisionTreeFactor + */ + class DiscreteConditional: public IndexConditional, public Potentials { + + public: + // typedefs needed to play nice with gtsam + typedef DiscreteFactor FactorType; + typedef boost::shared_ptr shared_ptr; + typedef IndexConditional Base; + + /** A map from keys to values */ + typedef Assignment Values; + typedef boost::shared_ptr sharedValues; + + /// @name Standard Constructors + /// @{ + + /** default constructor needed for serialization */ + DiscreteConditional() { + } + + /** constructor from factor */ + DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f); + + /** Construct from signature */ + DiscreteConditional(const Signature& signature); + + /** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */ + DiscreteConditional(const DecisionTreeFactor& joint, + const DecisionTreeFactor& marginal); + + /// @} + /// @name Testable + /// @{ + + /** GTSAM-style print */ + void print(const std::string& s = "Discrete Conditional: ") const { + std::cout << s << std::endl; + IndexConditional::print(s); + Potentials::print(s); + } + + /** GTSAM-style equals */ + bool equals(const DiscreteConditional& other, double tol = 1e-9) const { + return IndexConditional::equals(other, tol) + && Potentials::equals(other, tol); + } + + /// @} + /// @name Standard Interface + /// @{ + + /** Convert to a factor */ + DecisionTreeFactor::shared_ptr toFactor() const { + return DecisionTreeFactor::shared_ptr(new DecisionTreeFactor(*this)); + } + + /** Restrict to given parent values, returns AlgebraicDecisionDiagram */ + ADT choose(const Assignment& parentsValues) const; + + /** + * solve a conditional + * @param parentsAssignment Known values of the parents + * @return MPE value of the child (1 frontal variable). + */ + size_t solve(const Values& parentsValues) const; + + /** + * sample + * @param parentsAssignment Known values of the parents + * @return sample from conditional + */ + size_t sample(const Values& parentsValues) const; + + /// @} + /// @name Advanced Interface + /// @{ + + /// solve a conditional, in place + void solveInPlace(Values& parentsValues) const; + + /// sample in place, stores result in partial solution + void sampleInPlace(Values& parentsValues) const; + + /// @} + + }; +// DiscreteConditional + +}// gtsam + diff --git a/gtsam/discrete/DiscreteFactor.cpp b/gtsam/discrete/DiscreteFactor.cpp new file mode 100644 index 000000000..6112cfea9 --- /dev/null +++ b/gtsam/discrete/DiscreteFactor.cpp @@ -0,0 +1,21 @@ +/* + * DiscreteFactor.cpp + * @brief: discrete factor + * @date Feb 14, 2011 + * @author Duy-Nguyen Ta + * @author Frank Dellaert + */ + +#include +#include + +using namespace std; + +namespace gtsam { + + /* ******************************************************************************** */ + DiscreteFactor::DiscreteFactor() { + } + +/* ************************************************************************* */ +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h new file mode 100644 index 000000000..a66d5b522 --- /dev/null +++ b/gtsam/discrete/DiscreteFactor.h @@ -0,0 +1,108 @@ +/* + * DiscreteFactor.h + * + * @date Feb 14, 2011 + * @author Duy-Nguyen Ta + */ + +#pragma once + +#include +#include + +namespace gtsam { + + class DecisionTreeFactor; + class DiscreteConditional; + class Domain; + + /** + * Base class for discrete probabilistic factors + * The most general one is the derived DecisionTreeFactor + */ + class DiscreteFactor: public IndexFactor { + + public: + + // typedefs needed to play nice with gtsam + typedef DiscreteFactor This; + typedef DiscreteConditional ConditionalType; + typedef boost::shared_ptr shared_ptr; + + /** A map from keys to values */ + typedef Assignment Values; + typedef boost::shared_ptr sharedValues; + + protected: + + /// Construct n-way factor + DiscreteFactor(const std::vector& js) : + IndexFactor(js) { + } + + /// Construct unary factor + DiscreteFactor(Index j) : + IndexFactor(j) { + } + + /// Construct binary factor + DiscreteFactor(Index j1, Index j2) : + IndexFactor(j1, j2) { + } + + /// construct from container + template + DiscreteFactor(KeyIterator beginKey, KeyIterator endKey) : + IndexFactor(beginKey, endKey) { + } + + public: + + /// @name Standard Constructors + /// @{ + + /// Default constructor for I/O + DiscreteFactor(); + + /// Virtual destructor + virtual ~DiscreteFactor() {} + + /// @} + /// @name Testable + /// @{ + + // print + virtual void print(const std::string& s = "DiscreteFactor") const { + IndexFactor::print(s); + } + + /// @} + /// @name Standard Interface + /// @{ + + /// Find value for given assignment of values to variables + virtual double operator()(const Values&) const = 0; + + /// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor + virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; + + virtual operator DecisionTreeFactor() const = 0; + + /* + * Ensure Arc-consistency + * @param j domain to be checked + * @param domains all other domains + */ + virtual bool ensureArcConsistency(size_t j, std::vector& domains) const = 0; + + /// Partially apply known values + virtual shared_ptr partiallyApply(const Values&) const = 0; + + + /// Partially apply known values, domain version + virtual shared_ptr partiallyApply(const std::vector&) const = 0; + /// @} + }; +// DiscreteFactor + +}// namespace gtsam diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp new file mode 100644 index 000000000..479bcc45a --- /dev/null +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -0,0 +1,82 @@ +/* + * DiscreteFactorGraph.cpp + * + * @date Feb 14, 2011 + * @author Duy-Nguyen Ta + */ + +//#define ENABLE_TIMING +#include +#include +#include +#include + +namespace gtsam { + +// Explicitly instantiate so we don't have to include everywhere +template class FactorGraph ; +template class EliminationTree ; + +/* ************************************************************************* */ +DiscreteFactorGraph::DiscreteFactorGraph() { +} + +/* ************************************************************************* */ +DiscreteFactorGraph::DiscreteFactorGraph( + const BayesNet& bayesNet) : + FactorGraph(bayesNet) { +} + +/* ************************************************************************* */ +FastSet DiscreteFactorGraph::keys() const { + FastSet keys; + BOOST_FOREACH(const sharedFactor& factor, *this) + if (factor) keys.insert(factor->begin(), factor->end()); + return keys; +} + +/* ************************************************************************* */ +DecisionTreeFactor DiscreteFactorGraph::product() const { + DecisionTreeFactor result; + BOOST_FOREACH(const sharedFactor& factor, *this) + if (factor) result = (*factor) * result; + return result; +} + +/* ************************************************************************* */ +double DiscreteFactorGraph::operator()( + const DiscreteFactor::Values &values) const { + double product = 1.0; + BOOST_FOREACH( const sharedFactor& factor, factors_ ) + product *= (*factor)(values); + return product; +} + +/* ************************************************************************* */ +pair // +EliminateDiscrete(const FactorGraph& factors, size_t num) { + + // PRODUCT: multiply all factors + tic(1, "product"); + DecisionTreeFactor product; + BOOST_FOREACH(const DiscreteFactor::shared_ptr& factor, factors) + product = (*factor) * product; + toc(1, "product"); + + // sum out frontals, this is the factor on the separator + tic(2, "sum"); + DecisionTreeFactor::shared_ptr sum = product.sum(num); + toc(2, "sum"); + + // now divide product/sum to get conditional + tic(3, "divide"); + DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum)); + toc(3, "divide"); + tictoc_finishedIteration(); + + return make_pair(cond, sum); +} + +/* ************************************************************************* */ +} // namespace + diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h new file mode 100644 index 000000000..83a065361 --- /dev/null +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -0,0 +1,87 @@ +/* + * DiscreteFactorGraph.h + * + * @date Feb 14, 2011 + * @author Duy-Nguyen Ta + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace gtsam { + +class DiscreteFactorGraph: public FactorGraph { +public: + + /** A map from keys to values */ + typedef std::vector Indices; + typedef Assignment Values; + typedef boost::shared_ptr sharedValues; + + /** Construct empty factor graph */ + DiscreteFactorGraph(); + + /** Constructor from a factor graph of GaussianFactor or a derived type */ + template + DiscreteFactorGraph(const FactorGraph& fg) { + push_back(fg); + } + + /** construct from a BayesNet */ + DiscreteFactorGraph(const BayesNet& bayesNet); + + template + void add(const DiscreteKey& j, SOURCE table) { + DiscreteKeys keys; + keys.push_back(j); + push_back(boost::make_shared(keys, table)); + } + + template + void add(const DiscreteKey& j1, const DiscreteKey& j2, SOURCE table) { + DiscreteKeys keys; + keys.push_back(j1); + keys.push_back(j2); + push_back(boost::make_shared(keys, table)); + } + + /** add shared discreteFactor immediately from arguments */ + template + void add(const DiscreteKeys& keys, SOURCE table) { + push_back(boost::make_shared(keys, table)); + } + + /** Return the set of variables involved in the factors (set union) */ + FastSet keys() const; + + /** return product of all factors as a single factor */ + DecisionTreeFactor product() const; + + /** Evaluates the factor graph given values, returns the joint probability of the factor graph given specific instantiation of values*/ + double operator()(const DiscreteFactor::Values & values) const; + + /// print + void print(const std::string& s = "DiscreteFactorGraph") const { + std::cout << s << std::endl; + std::cout << "size: " << size() << std::endl; + for (size_t i = 0; i < factors_.size(); i++) { + std::stringstream ss; + ss << "factor " << i << ": "; + if (factors_[i] != NULL) factors_[i]->print(ss.str()); + } + } + +}; +// DiscreteFactorGraph + +/** Main elimination function for DiscreteFactorGraph */ +std::pair, DecisionTreeFactor::shared_ptr> +EliminateDiscrete(const FactorGraph& factors, + size_t nrFrontals = 1); + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteKey.cpp b/gtsam/discrete/DiscreteKey.cpp new file mode 100644 index 000000000..01b025d80 --- /dev/null +++ b/gtsam/discrete/DiscreteKey.cpp @@ -0,0 +1,44 @@ +/* + * DiscreteKey.h + * @brief specialized key for discrete variables + * @author Frank Dellaert + * @date Feb 28, 2011 + */ + +#include +#include // for key names +#include // FOREACH +#include "DiscreteKey.h" + +namespace gtsam { + + using namespace std; + + DiscreteKeys::DiscreteKeys(const vector& cs) { + for (size_t i = 0; i < cs.size(); i++) { + string name = boost::str(boost::format("v%1%") % i); + push_back(DiscreteKey(i, cs[i])); + } + } + + vector DiscreteKeys::indices() const { + vector < Index > js; + BOOST_FOREACH(const DiscreteKey& key, *this) + js.push_back(key.first); + return js; + } + + map DiscreteKeys::cardinalities() const { + map cs; + cs.insert(begin(),end()); +// BOOST_FOREACH(const DiscreteKey& key, *this) +// cs.insert(key); + return cs; + } + + DiscreteKeys operator&(const DiscreteKey& key1, const DiscreteKey& key2) { + DiscreteKeys keys(key1); + return keys & key2; + } + +} diff --git a/gtsam/discrete/DiscreteKey.h b/gtsam/discrete/DiscreteKey.h new file mode 100644 index 000000000..87e6e96a6 --- /dev/null +++ b/gtsam/discrete/DiscreteKey.h @@ -0,0 +1,59 @@ +/* + * DiscreteKey.h + * @brief specialized key for discrete variables + * @author Frank Dellaert + * @date Feb 28, 2011 + */ + +#pragma once + +#include + +#include +#include +#include + +namespace gtsam { + + /** + * Key type for discrete conditionals + * Includes name and cardinality + */ + typedef std::pair DiscreteKey; + + /// DiscreteKeys is a set of keys that can be assembled using the & operator + struct DiscreteKeys: public std::vector { + + /// Default constructor + DiscreteKeys() { + } + + /// Construct from a key + DiscreteKeys(const DiscreteKey& key) { + push_back(key); + } + + /// Construct from a vector of keys + DiscreteKeys(const std::vector& keys) : + std::vector(keys) { + } + + /// Construct from cardinalities with default names + DiscreteKeys(const std::vector& cs); + + /// Return a vector of indices + std::vector indices() const; + + /// Return a map from index to cardinality + std::map cardinalities() const; + + /// Add a key (non-const!) + DiscreteKeys& operator&(const DiscreteKey& key) { + push_back(key); + return *this; + } + }; // DiscreteKeys + + /// Create a list from two keys + DiscreteKeys operator&(const DiscreteKey& key1, const DiscreteKey& key2); +} diff --git a/gtsam/discrete/DiscreteSequentialSolver.cpp b/gtsam/discrete/DiscreteSequentialSolver.cpp new file mode 100644 index 000000000..b4d16a24f --- /dev/null +++ b/gtsam/discrete/DiscreteSequentialSolver.cpp @@ -0,0 +1,47 @@ +/* + * DiscreteSequentialSolver.cpp + * + * @date Feb 16, 2011 + * @author Duy-Nguyen Ta + */ + +//#define ENABLE_TIMING +#include +#include +#include +#include + +namespace gtsam { + + template class GenericSequentialSolver ; + + /* ************************************************************************* */ + DiscreteFactor::sharedValues DiscreteSequentialSolver::optimize() const { + + static const bool debug = false; + + if (debug) this->factors_->print("DiscreteSequentialSolver, eliminating "); + if (debug) this->eliminationTree_->print( + "DiscreteSequentialSolver, elimination tree "); + + // Eliminate using the elimination tree + tic(1, "eliminate"); + DiscreteBayesNet::shared_ptr bayesNet = eliminate(); + toc(1, "eliminate"); + + if (debug) bayesNet->print("DiscreteSequentialSolver, Bayes net "); + + // Allocate the solution vector if it is not already allocated + + // Back-substitute + tic(2, "optimize"); + DiscreteFactor::sharedValues solution = gtsam::optimize(*bayesNet); + toc(2, "optimize"); + + if (debug) solution->print("DiscreteSequentialSolver, solution "); + + return solution; + } +/* ************************************************************************* */ + +} diff --git a/gtsam/discrete/DiscreteSequentialSolver.h b/gtsam/discrete/DiscreteSequentialSolver.h new file mode 100644 index 000000000..60512b873 --- /dev/null +++ b/gtsam/discrete/DiscreteSequentialSolver.h @@ -0,0 +1,97 @@ +/* + * DiscreteSequentialSolver.h + * + * @date Feb 16, 2011 + * @author Duy-Nguyen Ta + */ + +#pragma once + +#include +#include +#include + +namespace gtsam { + // The base class provides all of the needed functionality + + class DiscreteSequentialSolver: public GenericSequentialSolver { + + protected: + typedef GenericSequentialSolver Base; + typedef boost::shared_ptr shared_ptr; + + public: + + /** + * The problem we are trying to solve (SUM or MPE). + */ + typedef enum { + BEL, // Belief updating (or conditional updating) + MPE, // Most-Probable-Explanation + MAP + // Maximum A Posteriori hypothesis + } ProblemType; + + /** + * Construct the solver for a factor graph. This builds the elimination + * tree, which already does some of the work of elimination. + */ + DiscreteSequentialSolver(const FactorGraph& factorGraph) : + Base(factorGraph) { + } + + /** + * Construct the solver with a shared pointer to a factor graph and to a + * VariableIndex. The solver will store these pointers, so this constructor + * is the fastest. + */ + DiscreteSequentialSolver( + const FactorGraph::shared_ptr& factorGraph, + const VariableIndex::shared_ptr& variableIndex) : + Base(factorGraph, variableIndex) { + } + + const EliminationTree& eliminationTree() const { + return *eliminationTree_; + } + + /** + * Eliminate the factor graph sequentially. Uses a column elimination tree + * to recursively eliminate. + */ + BayesNet::shared_ptr eliminate() const { + return Base::eliminate(&EliminateDiscrete); + } + +#ifdef BROKEN + /** + * Compute the marginal joint over a set of variables, by integrating out + * all of the other variables. This function returns the result as a factor + * graph. + */ + DiscreteFactorGraph::shared_ptr jointFactorGraph( + const std::vector& js) const { + DiscreteFactorGraph::shared_ptr results(new DiscreteFactorGraph( + *Base::jointFactorGraph(js, &EliminateDiscrete))); + return results; + } + + /** + * Compute the marginal density over a variable, by integrating out + * all of the other variables. This function returns the result as a factor. + */ + DiscreteFactor::shared_ptr marginalFactor(Index j) const { + return Base::marginalFactor(j, &EliminateDiscrete); + } +#endif + + /** + * Compute the MPE solution of the DiscreteFactorGraph. This + * eliminates to create a BayesNet and then back-substitutes this BayesNet to + * obtain the solution. + */ + DiscreteFactor::sharedValues optimize() const; + + }; + +} // gtsam diff --git a/gtsam/discrete/Domain.cpp b/gtsam/discrete/Domain.cpp new file mode 100644 index 000000000..130bd71ff --- /dev/null +++ b/gtsam/discrete/Domain.cpp @@ -0,0 +1,95 @@ +/* + * Domain.cpp + * @brief Domain restriction constraint + * @date Feb 13, 2012 + * @author Frank Dellaert + */ + +#include +#include +#include +#include + +namespace gtsam { + + using namespace std; + + /* ************************************************************************* */ + void Domain::print(const string& s) const { +// cout << s << ": Domain on " << keys_[0] << " (j=" << keys_[0] +// << ") with values"; +// BOOST_FOREACH (size_t v,values_) cout << " " << v; +// cout << endl; + BOOST_FOREACH (size_t v,values_) cout << v; + } + + /* ************************************************************************* */ + double Domain::operator()(const Values& values) const { + return contains(values.at(keys_[0])); + } + + /* ************************************************************************* */ + Domain::operator DecisionTreeFactor() const { + DiscreteKeys keys; + keys += DiscreteKey(keys_[0],cardinality_); + vector table; + for (size_t i1 = 0; i1 < cardinality_; ++i1) + table.push_back(contains(i1)); + DecisionTreeFactor converted(keys, table); + return converted; + } + + /* ************************************************************************* */ + DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const { + // TODO: can we do this more efficiently? + return DecisionTreeFactor(*this) * f; + } + + /* ************************************************************************* */ + bool Domain::ensureArcConsistency(size_t j, vector& domains) const { + if (j != keys_[0]) throw invalid_argument("Domain check on wrong domain"); + Domain& D = domains[j]; + BOOST_FOREACH(size_t value, values_) + if (!D.contains(value)) throw runtime_error("Unsatisfiable"); + D = *this; + return true; + } + + /* ************************************************************************* */ + bool Domain::checkAllDiff(const vector keys, vector& domains) { + Index j = keys_[0]; + // for all values in this domain + BOOST_FOREACH(size_t value, values_) { + // for all connected domains + BOOST_FOREACH(Index k, keys) + // if any domain contains the value we cannot make this domain singleton + if (k!=j && domains[k].contains(value)) + goto found; + values_.clear(); + values_.insert(value); + return true; // we changed it + found:; + } + return false; // we did not change it + } + + /* ************************************************************************* */ + DiscreteFactor::shared_ptr Domain::partiallyApply( + const Values& values) const { + Values::const_iterator it = values.find(keys_[0]); + if (it != values.end() && !contains(it->second)) throw runtime_error( + "Domain::partiallyApply: unsatisfiable"); + return boost::make_shared < Domain > (*this); + } + + /* ************************************************************************* */ + DiscreteFactor::shared_ptr Domain::partiallyApply( + const vector& domains) const { + const Domain& Dk = domains[keys_[0]]; + if (Dk.isSingleton() && !contains(*Dk.begin())) throw runtime_error( + "Domain::partiallyApply: unsatisfiable"); + return boost::make_shared < Domain > (Dk); + } + +/* ************************************************************************* */ +} // namespace gtsam diff --git a/gtsam/discrete/Domain.h b/gtsam/discrete/Domain.h new file mode 100644 index 000000000..934f0c306 --- /dev/null +++ b/gtsam/discrete/Domain.h @@ -0,0 +1,107 @@ +/* + * Domain.h + * @brief Domain restriction constraint + * @date Feb 13, 2012 + * @author Frank Dellaert + */ + +#pragma once + +#include +#include + +namespace gtsam { + + /** + * Domain restriction constraint + */ + class Domain: public DiscreteFactor { + + size_t cardinality_; /// Cardinality + std::set values_; /// allowed values + + public: + + typedef boost::shared_ptr shared_ptr; + + // Constructor on Discrete Key initializes an "all-allowed" domain + Domain(const DiscreteKey& dkey) : + DiscreteFactor(dkey.first), cardinality_(dkey.second) { + for (size_t v = 0; v < cardinality_; v++) + values_.insert(v); + } + + // Constructor on Discrete Key with single allowed value + // Consider SingleValue constraint + Domain(const DiscreteKey& dkey, size_t v) : + DiscreteFactor(dkey.first), cardinality_(dkey.second) { + values_.insert(v); + } + + /// Constructor + Domain(const Domain& other) : + DiscreteFactor(other.keys_[0]), values_(other.values_) { + } + + /// insert a value, non const :-( + void insert(size_t value) { + values_.insert(value); + } + + /// erase a value, non const :-( + void erase(size_t value) { + values_.erase(value); + } + + size_t nrValues() const { + return values_.size(); + } + + bool isSingleton() const { + return nrValues() == 1; + } + + size_t firstValue() const { + return *values_.begin(); + } + + // print + virtual void print(const std::string& s = "") const; + + bool contains(size_t value) const { + return values_.count(value)>0; + } + + /// Calculate value + virtual double operator()(const Values& values) const; + + /// Convert into a decisiontree + virtual operator DecisionTreeFactor() const; + + /// Multiply into a decisiontree + virtual DecisionTreeFactor operator*(const DecisionTreeFactor& f) const; + + /* + * Ensure Arc-consistency + * @param j domain to be checked + * @param domains all other domains + */ + bool ensureArcConsistency(size_t j, std::vector& domains) const; + + /** + * Check for a value in domain that does not occur in any other connected domain. + * If found, we make this a singleton... Called in AllDiff::ensureArcConsistency + * @param keys connected domains through alldiff + */ + bool checkAllDiff(const std::vector keys, std::vector& domains); + + /// Partially apply known values + virtual DiscreteFactor::shared_ptr partiallyApply( + const Values& values) const; + + /// Partially apply known values, domain version + virtual DiscreteFactor::shared_ptr partiallyApply( + const std::vector& domains) const; + }; + +} // namespace gtsam diff --git a/gtsam/discrete/PotentialTable.cpp b/gtsam/discrete/PotentialTable.cpp new file mode 100644 index 000000000..07d2b2880 --- /dev/null +++ b/gtsam/discrete/PotentialTable.cpp @@ -0,0 +1,162 @@ +/* + * Potentials.cpp + * + * @date Feb 21, 2011 + * @author Duy-Nguyen Ta + */ + +#include +#include +#include +#include +#include +#include +#include + +using namespace std; + +namespace gtsam { + + /* ************************************************************************* */ + void PotentialTable::Iterator::operator++() { + // note size_t is unsigned and i>=0 is always true, so strange-looking loop: + for (size_t i = size(); i--; ) { + if (++at(i) < cardinalities_[i]) + return; + else + at(i) = 0; + } + } + + /* ************************************************************************* */ + size_t PotentialTable::computeTableSize( + const std::vector& cardinalities) { + size_t tableSize = 1; + BOOST_FOREACH(const size_t c, cardinalities) + tableSize *= c; + return tableSize; + } + + /* ************************************************************************* */ + PotentialTable::PotentialTable(const std::vector& cs) : + cardinalities_(cs), table_(computeTableSize(cs)) { + generateKeyFactors(); + } + + /* ************************************************************************* */ + PotentialTable::PotentialTable(const std::vector& cardinalities, + const Table& table) : cardinalities_(cardinalities),table_(table) { + generateKeyFactors(); + } + + /* ************************************************************************* */ + PotentialTable::PotentialTable(const std::vector& cardinalities, + const std::string& tableString) : cardinalities_(cardinalities) { + parse(tableString); + generateKeyFactors(); + } + + /* ************************************************************************* */ + bool PotentialTable::equals(const PotentialTable& other, double tol) const { + //TODO: compare potentials in a more general sense with arbitrary order of keys??? + if ((cardinalities_ == other.cardinalities_) && (table_.size() + == other.table_.size()) && (keyFactors_ == other.keyFactors_)) { + for (size_t i = 0; i < table_.size(); i++) { + if (fabs(table_[i] - other.table_[i]) > tol) { + return false; + } + return true; + } + } + return false; + } + + /* ************************************************************************* */ + void PotentialTable::print(const std::string& s) const { + cout << s << endl; + for (size_t i = 0; i < cardinalities_.size(); i++) + cout << boost::format("[%d,%d]") % cardinalities_[i] % keyFactors_[i] << " "; + cout << endl; + Iterator assignment(cardinalities_); + for (size_t idx = 0; idx < table_.size(); ++idx, ++assignment) { + for (size_t k = 0; k < assignment.size(); k++) + cout << assignment[k] << "\t\t"; + cout << table_[idx] << endl; + } + } + + /* ************************************************************************* */ + const double& PotentialTable::operator()(const Assignment& var) const { + return table_[tableIndexFromAssignment(var)]; + } + + /* ************************************************************************* */ + const double& PotentialTable::operator[](const size_t index) const { + return table_.at(index); + } + + + /* ************************************************************************* */ + void PotentialTable::setPotential(const PotentialTable::Assignment& asg, const double potential) { + size_t idx = tableIndexFromAssignment(asg); + assert(idx (iss), istream_iterator (), + back_inserter(table_)); + +#ifndef NDEBUG + size_t expectedSize = computeTableSize(cardinalities_); + if (table_.size() != expectedSize) throw invalid_argument( + boost::str( + boost::format( + "String specification \"%s\" for table only contains %d doubles instead of %d") + % tableString % table_.size() % expectedSize)); +#endif + } + +} // namespace diff --git a/gtsam/discrete/PotentialTable.h b/gtsam/discrete/PotentialTable.h new file mode 100644 index 000000000..b7741ba1e --- /dev/null +++ b/gtsam/discrete/PotentialTable.h @@ -0,0 +1,95 @@ +/* + * Potentials.h + * + * @date Feb 21, 2011 + * @author Duy-Nguyen Ta + */ + +#ifndef POTENTIALS_H_ +#define POTENTIALS_H_ + +#include +#include +#include +#include +#include +#include + +namespace gtsam +{ +/** + * PotentialTable holds the real-valued potentials for Factors or Conditionals + */ +class PotentialTable { +public: + typedef std::vector Table; // container type for potentials f(x1,x2,..) + typedef std::vector Cardinalities; // just a typedef + typedef std::vector Assignment; // just a typedef + + /** + * An assignment that can be incemented + */ + struct Iterator: std::vector { + Cardinalities cardinalities_; + Iterator(const Cardinalities& cs):cardinalities_(cs) { + for(size_t i=0;i cardinalities_; // cardinalities of variables + Table table_; // Potential values of all instantiations of the variables, following the variables' order in vector Keys. + std::vector keyFactors_; // factors to multiply a key's assignment with, to access the potential table + + void generateKeyFactors(); + void parse(const std::string& tableString); + +public: + + /** compute table size from variable cardinalities */ + static size_t computeTableSize(const std::vector& cardinalities); + + /** construct an empty potential */ + PotentialTable() {} + + /** Dangerous empty n-ary potential. */ + PotentialTable(const std::vector& cardinalities); + + /** n-ary potential. */ + PotentialTable(const std::vector& cardinalities, + const Table& table); + + /** n-ary potential. */ + PotentialTable(const std::vector& cardinalities, + const std::string& tableString); + + /** return iterator to first element */ + Iterator begin() const { return Iterator(cardinalities_);} + + /** equality */ + bool equals(const PotentialTable& other, double tol = 1e-9) const; + + /** print */ + void print(const std::string& s = "Potential Table: ") const; + + /** return cardinality of a variable */ + size_t cardinality(size_t var) const { return cardinalities_[var]; } + size_t tableSize() const { return table_.size(); } + + /** accessors to potential values in the table given the assignment */ + const double& operator()(const Assignment& var) const; + const double& operator[](const size_t index) const; + + void setPotential(const Assignment& asg, const double potential); + void setPotential(const size_t tableIndex, const double potential); + + /** convert between assignment and where it is in the table */ + size_t tableIndexFromAssignment(const Assignment& var) const; + Assignment assignmentFromTableIndex(const size_t i) const; +}; + + +} // namespace + +#endif /* POTENTIALS_H_ */ diff --git a/gtsam/discrete/Potentials.cpp b/gtsam/discrete/Potentials.cpp new file mode 100644 index 000000000..2684e6cce --- /dev/null +++ b/gtsam/discrete/Potentials.cpp @@ -0,0 +1,53 @@ +/* + * Potentials.cpp + * @date March 24, 2011 + * @author Frank Dellaert + */ + +#include +#include +#include + +using namespace std; + +namespace gtsam { + + // explicit instantiation + template class DecisionTree ; + template class AlgebraicDecisionTree ; + + /* ************************************************************************* */ + double Potentials::safe_div(const double& a, const double& b) { + // cout << boost::format("%g / %g = %g\n") % a % b % ((a == 0) ? 0 : (a / b)); + // The use for safe_div is when we divide the product factor by the sum factor. + // If the product or sum is zero, we accord zero probability to the event. + return (a == 0 || b == 0) ? 0 : (a / b); + } + + /* ******************************************************************************** */ + Potentials::Potentials() : + ADT(1.0) { + } + + /* ******************************************************************************** */ + Potentials::Potentials(const DiscreteKeys& keys, const ADT& decisionTree) : + ADT(decisionTree), cardinalities_(keys.cardinalities()) { + } + + /* ************************************************************************* */ + bool Potentials::equals(const Potentials& other, double tol) const { + return ADT::equals(other, tol); + } + + /* ************************************************************************* */ + void Potentials::print(const string&s) const { + cout << s << "\n Cardinalities: "; + BOOST_FOREACH(const DiscreteKey& key, cardinalities_) + cout << key.first << "=" << key.second << " "; + cout << endl; + ADT::print(" "); + } + + /* ************************************************************************* */ + +} // namespace gtsam diff --git a/gtsam/discrete/Potentials.h b/gtsam/discrete/Potentials.h new file mode 100644 index 000000000..9468745e1 --- /dev/null +++ b/gtsam/discrete/Potentials.h @@ -0,0 +1,62 @@ +/* + * Potentials.h + * @date March 24, 2011 + * @author Frank Dellaert + */ + +#pragma once + +#include +#include +#include + +#include +#include + +namespace gtsam { + + /** + * A base class for both DiscreteFactor and DiscreteConditional + */ + class Potentials: public AlgebraicDecisionTree { + + public: + + typedef AlgebraicDecisionTree ADT; + + protected: + + /// Cardinality for each key, used in combine + std::map cardinalities_; + + /** Constructor from ColumnIndex, and ADT */ + Potentials(const ADT& potentials) : + ADT(potentials) { + } + + // Safe division for probabilities + static double safe_div(const double& a, const double& b); + + public: + + /** Default constructor for I/O */ + Potentials(); + + /** Constructor from Indices and ADT */ + Potentials(const DiscreteKeys& keys, const ADT& decisionTree); + + /** Constructor from Indices and (string or doubles) */ + template + Potentials(const DiscreteKeys& keys, SOURCE table) : + ADT(keys, table), cardinalities_(keys.cardinalities()) { + } + + // Testable + bool equals(const Potentials& other, double tol = 1e-9) const; + void print(const std::string& s = "Potentials: ") const; + + size_t cardinality(Index j) const { return cardinalities_.at(j);} + + }; // Potentials + +} // namespace gtsam diff --git a/gtsam/discrete/RefCounted.cpp b/gtsam/discrete/RefCounted.cpp new file mode 100644 index 000000000..521e58c6a --- /dev/null +++ b/gtsam/discrete/RefCounted.cpp @@ -0,0 +1,9 @@ +/* + * @file RefCounted.cpp + * @brief Simple reference-counted base class + * @author Frank Dellaert + * @date Mar 29, 2011 + */ + +#include + diff --git a/gtsam/discrete/RefCounted.h b/gtsam/discrete/RefCounted.h new file mode 100644 index 000000000..03d086ab6 --- /dev/null +++ b/gtsam/discrete/RefCounted.h @@ -0,0 +1,86 @@ +/* + * @file RefCounted.h + * @brief Simple reference-counted base class + * @author Frank Dellaert + * @date Mar 29, 2011 + */ + +#include + +// Forward Declarations +namespace gtsam { + struct RefCounted; +} + +namespace boost { + void intrusive_ptr_add_ref(const gtsam::RefCounted * p); + void intrusive_ptr_release(const gtsam::RefCounted * p); +} + +namespace gtsam { + + /** + * Simple reference counted class inspired by + * http://www.codeproject.com/KB/stl/boostsmartptr.aspx + */ + struct RefCounted { + private: + mutable long references_; + friend void ::boost::intrusive_ptr_add_ref(const RefCounted * p); + friend void ::boost::intrusive_ptr_release(const RefCounted * p); + public: + RefCounted() : + references_(0) { + } + virtual ~RefCounted() { + } + }; + +} // namespace gtsam + +// Intrusive Pointer free functions +#ifndef DEBUG_REFCOUNT + +namespace boost { + + // increment reference count of object *p + inline void intrusive_ptr_add_ref(const gtsam::RefCounted * p) { + ++(p->references_); + } + + // decrement reference count, and delete object when reference count reaches 0 + inline void intrusive_ptr_release(const gtsam::RefCounted * p) { + if (--(p->references_) == 0) + delete p; + } + +} // namespace boost + +#else + +#include + + namespace gtsam { + static long GlobalRefCount = 0; + } + + namespace boost { + inline void intrusive_ptr_add_ref(const gtsam::RefCounted * p) { + ++(p->references_); + gtsam::GlobalRefCount++; + std::cout << "add_ref " << p << " " << p->references_ << // + " " << gtsam::GlobalRefCount << std::endl; + } + + inline void intrusive_ptr_release(const gtsam::RefCounted * p) { + gtsam::GlobalRefCount--; + std::cout << "release " << p << " " << (p->references_ - 1) << // + " " << gtsam::GlobalRefCount << std::endl; + if (--(p->references_) == 0) + delete p; + } + + } // namespace boost + +#endif + diff --git a/gtsam/discrete/Scheduler.cpp b/gtsam/discrete/Scheduler.cpp new file mode 100644 index 000000000..dd578930d --- /dev/null +++ b/gtsam/discrete/Scheduler.cpp @@ -0,0 +1,297 @@ +/* + * Scheduler.h + * @brief an example how inference can be used for scheduling qualifiers + * @date Mar 26, 2011 + * @author Frank Dellaert + */ + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +namespace gtsam { + + using namespace std; + + Scheduler::Scheduler(size_t maxNrStudents, const string& filename): + maxNrStudents_(maxNrStudents) + { + typedef boost::tokenizer > Tokenizer; + + // open file + ifstream is(filename.c_str()); + + string line; // buffer + + // process first line with faculty + if (getline(is, line, '\r')) { + Tokenizer tok(line); + Tokenizer::iterator it = tok.begin(); + for (++it; it != tok.end(); ++it) + addFaculty(*it); + } + + // for all remaining lines + size_t count = 0; + while (getline(is, line, '\r')) { + if (count++ > 100) throw runtime_error("reached 100 lines, exiting"); + Tokenizer tok(line); + Tokenizer::iterator it = tok.begin(); + addSlot(*it++); // add slot + // add availability + for (; it != tok.end(); ++it) + available_ += (it->empty()) ? "0 " : "1 "; + available_ += '\n'; + } + } // constructor + + /** addStudent has to be called after adding slots and faculty */ + void Scheduler::addStudent(const string& studentName, + const string& area1, const string& area2, + const string& area3, const string& advisor) { + assert(nrStudents() area) const { + return area ? students_[s].keys_[*area] : students_[s].key_; + } + + const string& Scheduler::studentName(size_t i) const { + assert(i slot) { + bool debug = ISDEBUG("Scheduler::buildGraph"); + + assert(iat(j); + cout << studentName(s) << " slot: " << slotName_[slot] << endl; + Index base = 3*s; + for (size_t area = 0; area < 3; area++) { + size_t faculty = assignment->at(base+area); + cout << setw(12) << studentArea(s,area) << ": " << facultyName_[faculty] + << endl; + } + cout << endl; + } + } + + /** Special print for single-student case */ + void Scheduler::printSpecial(sharedValues assignment) const { + Values::const_iterator it = assignment->begin(); + for (size_t area = 0; area < 3; area++, it++) { + size_t f = it->second; + cout << setw(12) << it->first << ": " << facultyName_[f] << endl; + } + cout << endl; + } + + /** Accumulate faculty stats */ + void Scheduler::accumulateStats(sharedValues assignment, vector< + size_t>& stats) const { + for (size_t s = 0; s < nrStudents(); s++) { + Index base = 3*s; + for (size_t area = 0; area < 3; area++) { + size_t f = assignment->at(base+area); + assert(frbegin(); + const Student & student = students_.front(); + cout << endl; + (*it)->print(student.name_); + } + + tic(3, "my_optimize"); + sharedValues mpe = optimize(*chordal); + toc(3, "my_optimize"); + return mpe; + } + + /** find the assignment of students to slots with most possible committees */ + Scheduler::sharedValues Scheduler::bestSchedule() const { + sharedValues best; + throw runtime_error("bestSchedule not implemented"); + return best; + } + + /** find the corresponding most desirable committee assignment */ + Scheduler::sharedValues Scheduler::bestAssignment( + sharedValues bestSchedule) const { + sharedValues best; + throw runtime_error("bestAssignment not implemented"); + return best; + } + +} // gtsam + + diff --git a/gtsam/discrete/Scheduler.h b/gtsam/discrete/Scheduler.h new file mode 100644 index 000000000..f01b1591e --- /dev/null +++ b/gtsam/discrete/Scheduler.h @@ -0,0 +1,171 @@ +/* + * Scheduler.h + * @brief an example how inference can be used for scheduling qualifiers + * @date Mar 26, 2011 + * @author Frank Dellaert + */ + +#pragma once + +#include + +namespace gtsam { + + /** + * Scheduler class + * Creates one variable for each student, and three variables for each + * of the student's areas, for a total of 4*nrStudents variables. + * The "student" variable will determine when the student takes the qual. + * The "area" variables determine which faculty are on his/her committee. + */ + class Scheduler : public CSP { + + private: + + /** Internal data structure for students */ + struct Student { + std::string name_; + DiscreteKey key_; // key for student + std::vector keys_; // key for areas + std::vector areaName_; + std::vector advisor_; + Student(size_t nrFaculty, size_t advisorIndex) : + keys_(3), areaName_(3), advisor_(nrFaculty, 1.0) { + advisor_[advisorIndex] = 0.0; + } + void print() const { + using std::cout; + cout << name_ << ": "; + for (size_t area = 0; area < 3; area++) + cout << areaName_[area] << " "; + cout << std::endl; + } + }; + + /** Maximum number of students */ + size_t maxNrStudents_; + + /** discrete keys, indexed by student and area index */ + std::vector students_; + + /** faculty identifiers */ + std::map facultyIndex_; + std::vector facultyName_, slotName_, areaName_; + + /** area constraints */ + typedef std::map > FacultyInArea; + FacultyInArea facultyInArea_; + + /** nrTimeSlots * nrFaculty availability constraints */ + std::string available_; + + /** which slots are good */ + std::vector slotsAvailable_; + + public: + + /** + * Constructor + * WE need to know the number of students in advance for ordering keys. + * then add faculty, slots, areas, availability, students, in that order + */ + Scheduler(size_t maxNrStudents):maxNrStudents_(maxNrStudents) { + } + + void addFaculty(const std::string& facultyName) { + facultyIndex_[facultyName] = nrFaculty(); + facultyName_.push_back(facultyName); + } + + size_t nrFaculty() const { + return facultyName_.size(); + } + + /** boolean std::string of nrTimeSlots * nrFaculty */ + void setAvailability(const std::string& available) { + available_ = available; + } + + void addSlot(const std::string& slotName) { + slotName_.push_back(slotName); + } + + size_t nrTimeSlots() const { + return slotName_.size(); + } + + const std::string& slotName(size_t s) const { + return slotName_[s]; + } + + /** slots available, boolean */ + void setSlotsAvailable(const std::vector& slotsAvailable) { + slotsAvailable_ = slotsAvailable; + } + + void addArea(const std::string& facultyName, const std::string& areaName) { + areaName_.push_back(areaName); + std::vector& table = facultyInArea_[areaName]; // will create if needed + if (table.empty()) table.resize(nrFaculty(), 0); + table[facultyIndex_[facultyName]] = 1; + } + + /** + * Constructor that reads in faculty, slots, availibility. + * Still need to add areas and students after this + */ + Scheduler(size_t maxNrStudents, const std::string& filename); + + /** get key for student and area, 0 is time slot itself */ + const DiscreteKey& key(size_t s, boost::optional area = boost::none) const; + + /** addStudent has to be called after adding slots and faculty */ + void addStudent(const std::string& studentName, const std::string& area1, + const std::string& area2, const std::string& area3, + const std::string& advisor); + + /// current number of students + size_t nrStudents() const { + return students_.size(); + } + + const std::string& studentName(size_t i) const; + const DiscreteKey& studentKey(size_t i) const; + const std::string& studentArea(size_t i, size_t area) const; + + /** Add student-specific constraints to the graph */ + void addStudentSpecificConstraints(size_t i, boost::optional slot = boost::none); + + /** Main routine that builds factor graph */ + void buildGraph(size_t mutexBound = 7); + + /** print */ + void print(const std::string& s = "Scheduler") const; + + /** Print readable form of assignment */ + void printAssignment(sharedValues assignment) const; + + /** Special print for single-student case */ + void printSpecial(sharedValues assignment) const; + + /** Accumulate faculty stats */ + void accumulateStats(sharedValues assignment, + std::vector& stats) const; + + /** Eliminate, return a Bayes net */ + DiscreteBayesNet::shared_ptr eliminate() const; + + /** Find the best total assignment - can be expensive */ + sharedValues optimalAssignment() const; + + /** find the assignment of students to slots with most possible committees */ + sharedValues bestSchedule() const; + + /** find the corresponding most desirable committee assignment */ + sharedValues bestAssignment(sharedValues bestSchedule) const; + + }; // Scheduler + +} // gtsam + + diff --git a/gtsam/discrete/Signature.cpp b/gtsam/discrete/Signature.cpp new file mode 100644 index 000000000..4d808543a --- /dev/null +++ b/gtsam/discrete/Signature.cpp @@ -0,0 +1,217 @@ +/* + * Signature.cpp + * @brief: signatures for conditional densities + * @author: Frank dellaert + * @date Feb 27, 2011 + */ + +#include +#include + +#include "Signature.h" + +#ifdef BOOST_HAVE_PARSER +#include // for parsing +#include // for qi::_val +#endif + +namespace gtsam { + + using namespace std; + + +#ifdef BOOST_HAVE_PARSER + namespace qi = boost::spirit::qi; + + // parser for strings of form "99/1 80/20" etc... + namespace parser { + typedef string::const_iterator It; + using boost::phoenix::val; + using boost::phoenix::ref; + using boost::phoenix::push_back; + + // Special rows, true and false + Signature::Row createF() { + Signature::Row r(2); + r[0] = 1; + r[1] = 0; + return r; + } + Signature::Row createT() { + Signature::Row r(2); + r[0] = 0; + r[1] = 1; + return r; + } + Signature::Row T = createT(), F = createF(); + + // Special tables (inefficient, but do we care for user input?) + Signature::Table logic(bool ff, bool ft, bool tf, bool tt) { + Signature::Table t(4); + t[0] = ff ? T : F; + t[1] = ft ? T : F; + t[2] = tf ? T : F; + t[3] = tt ? T : F; + return t; + } + + struct Grammar { + qi::rule table, or_, and_, rows; + qi::rule true_, false_, row; + Grammar() { + table = or_ | and_ | rows; + or_ = qi::lit("OR")[qi::_val = logic(false, true, true, true)]; + and_ = qi::lit("AND")[qi::_val = logic(false, false, false, true)]; + rows = +(row | true_ | false_); // only loads first of the rows under boost 1.42 + row = qi::double_ >> +("/" >> qi::double_); + true_ = qi::lit("T")[qi::_val = T]; + false_ = qi::lit("F")[qi::_val = F]; + } + } grammar; + + // Create simpler parsing function to avoid the issue of only parsing a single row + bool parse_table(const string& spec, Signature::Table& table) { + // check for OR, AND on whole phrase + It f = spec.begin(), l = spec.end(); + if (qi::parse(f, l, + qi::lit("OR")[ref(table) = logic(false, true, true, true)]) || + qi::parse(f, l, + qi::lit("AND")[ref(table) = logic(false, false, false, true)])) + return true; + + // tokenize into separate rows + istringstream iss(spec); + string token; + while (iss >> token) { + Signature::Row values; + It tf = token.begin(), tl = token.end(); + bool r = qi::parse(tf, tl, + qi::double_[push_back(ref(values), qi::_1)] >> +("/" >> qi::double_[push_back(ref(values), qi::_1)]) | + qi::lit("T")[ref(values) = T] | + qi::lit("F")[ref(values) = F] ); + if (!r) + return false; + table.push_back(values); + } + + return true; + } + } // \namespace parser +#endif + + ostream& operator <<(ostream &os, const Signature::Row &row) { + os << row[0]; + for (size_t i = 1; i < row.size(); i++) + os << " " << row[i]; + return os; + } + + ostream& operator <<(ostream &os, const Signature::Table &table) { + for (size_t i = 0; i < table.size(); i++) + os << table[i] << endl; + return os; + } + + Signature::Signature(const DiscreteKey& key) : + key_(key) { + } + + DiscreteKeys Signature::discreteKeysParentsFirst() const { + DiscreteKeys keys; + BOOST_FOREACH(const DiscreteKey& key, parents_) + keys.push_back(key); + keys.push_back(key_); + return keys; + } + + vector Signature::indices() const { + vector js; + js.push_back(key_.first); + BOOST_FOREACH(const DiscreteKey& key, parents_) + js.push_back(key.first); + return js; + } + + vector Signature::cpt() const { + vector cpt; + if (table_) { + BOOST_FOREACH(const Row& row, *table_) + BOOST_FOREACH(const double& x, row) + cpt.push_back(x); + } + return cpt; + } + + Signature& Signature::operator,(const DiscreteKey& parent) { + parents_.push_back(parent); + return *this; + } + + static void normalize(Signature::Row& row) { + double sum = 0; + for (size_t i = 0; i < row.size(); i++) + sum += row[i]; + for (size_t i = 0; i < row.size(); i++) + row[i] /= sum; + } + + Signature& Signature::operator=(const string& spec) { + spec_.reset(spec); +#ifdef BOOST_HAVE_PARSER + Table table; + // NOTE: using simpler parse function to ensure boost back compatibility +// parser::It f = spec.begin(), l = spec.end(); + bool success = // +// qi::phrase_parse(f, l, parser::grammar.table, qi::space, table); // using full grammar + parser::parse_table(spec, table); + if (success) { + BOOST_FOREACH(Row& row, table) + normalize(row); + table_.reset(table); + } +#endif + return *this; + } + + Signature& Signature::operator=(const Table& t) { + Table table = t; + BOOST_FOREACH(Row& row, table) + normalize(row); + table_.reset(table); + return *this; + } + + ostream& operator <<(ostream &os, const Signature &s) { + os << s.key_.first; + if (s.parents_.empty()) { + os << " % "; + } else { + os << " | " << s.parents_[0].first; + for (size_t i = 1; i < s.parents_.size(); i++) + os << " && " << s.parents_[i].first; + os << " = "; + } + os << (s.spec_ ? *s.spec_ : "no spec") << endl; + if (s.table_) + os << (*s.table_); + else + os << "spec could not be parsed" << endl; + return os; + } + + Signature operator|(const DiscreteKey& key, const DiscreteKey& parent) { + Signature s(key); + return s, parent; + } + + Signature operator%(const DiscreteKey& key, const string& parent) { + Signature s(key); + return s = parent; + } + + Signature operator%(const DiscreteKey& key, const Signature::Table& parent) { + Signature s(key); + return s = parent; + } + +} // namespace gtsam diff --git a/gtsam/discrete/Signature.h b/gtsam/discrete/Signature.h new file mode 100644 index 000000000..84d34c1fa --- /dev/null +++ b/gtsam/discrete/Signature.h @@ -0,0 +1,129 @@ +/* + * Signature.h + * @brief: signatures for conditional densities + * @author: Frank dellaert + * @date Feb 27, 2011 + */ + +#pragma once +#include +#include +#include +#include + +#include // for checking whether we are using boost 1.40 +#if BOOST_VERSION >= 104200 +#define BOOST_HAVE_PARSER +#endif + +namespace gtsam { + + /** + * Signature for a discrete conditional density, used to construct conditionals. + * + * The format is (Key % string) for nodes with no parents, + * and (Key | Key, Key = string) for nodes with parents. + * + * The string specifies a conditional probability spec in the 00 01 10 11 order. + * For three-valued, it would be 00 01 02 10 11 12 20 21 22, etc... + * + * For example, given the following keys + * + * Key A("Asia"), S("Smoking"), T("Tuberculosis"), L("LungCancer"), + * B("Bronchitis"), E("Either"), X("XRay"), D("Dyspnoea"); + * + * These are all valid signatures (Asia network example): + * + * A % "99/1" + * S % "50/50" + * T|A = "99/1 95/5" + * L|S = "99/1 90/10" + * B|S = "70/30 40/60" + * E|T,L = "F F F 1" + * X|E = "95/5 2/98" + * D|E,B = "9/1 2/8 3/7 1/9" + */ + class Signature { + + public: + + /** Data type for the CPT */ + typedef std::vector Row; + typedef std::vector Table; + + private: + + /** the variable key */ + DiscreteKey key_; + + /** the parent keys */ + DiscreteKeys parents_; + + // the given CPT specification string + boost::optional spec_; + + // the CPT as parsed, if successful + boost::optional table_; + + public: + + /** Constructor from DiscreteKey */ + Signature(const DiscreteKey& key); + + /** the variable key */ + const DiscreteKey& key() const { + return key_; + } + + /** the parent keys */ + const DiscreteKeys& parents() const { + return parents_; + } + + /** All keys, with variable key last */ + DiscreteKeys discreteKeysParentsFirst() const; + + /** All key indices, with variable key first */ + std::vector indices() const; + + // the CPT as parsed, if successful + const boost::optional
& table() const { + return table_; + } + + // the CPT as a vector of doubles, with key's values most rapidly changing + std::vector cpt() const; + + /** Add a parent */ + Signature& operator,(const DiscreteKey& parent); + + /** Add the CPT spec - Fails in boost 1.40 */ + Signature& operator=(const std::string& spec); + + /** Add the CPT spec directly as a table */ + Signature& operator=(const Table& table); + + /** provide streaming */ + friend std::ostream& operator <<(std::ostream &os, const Signature &s); + }; + + /** + * Helper function to create Signature objects + * example: Signature s = D | E; + */ + Signature operator|(const DiscreteKey& key, const DiscreteKey& parent); + + /** + * Helper function to create Signature objects + * example: Signature s(D % "99/1"); + * Uses string parser, which requires BOOST 1.42 or higher + */ + Signature operator%(const DiscreteKey& key, const std::string& parent); + + /** + * Helper function to create Signature objects, using table construction directly + * example: Signature s(D % table); + */ + Signature operator%(const DiscreteKey& key, const Signature::Table& parent); + +} diff --git a/gtsam/discrete/SingleValue.cpp b/gtsam/discrete/SingleValue.cpp new file mode 100644 index 000000000..8d5fd0d8d --- /dev/null +++ b/gtsam/discrete/SingleValue.cpp @@ -0,0 +1,78 @@ +/* + * SingleValue.cpp + * @brief domain constraint + * @date Feb 13, 2012 + * @author Frank Dellaert + */ + +#include +#include +#include +#include +#include + +namespace gtsam { + + using namespace std; + + /* ************************************************************************* */ + void SingleValue::print(const string& s) const { + cout << s << ": SingleValue on " << keys_[0] << " (j=" << keys_[0] + << ") with value " << value_ << endl; + } + + /* ************************************************************************* */ + double SingleValue::operator()(const Values& values) const { + return (double) (values.at(keys_[0]) == value_); + } + + /* ************************************************************************* */ + SingleValue::operator DecisionTreeFactor() const { + DiscreteKeys keys; + keys += DiscreteKey(keys_[0],cardinality_); + vector table; + for (size_t i1 = 0; i1 < cardinality_; i1++) + table.push_back(i1 == value_); + DecisionTreeFactor converted(keys, table); + return converted; + } + + /* ************************************************************************* */ + DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const { + // TODO: can we do this more efficiently? + return DecisionTreeFactor(*this) * f; + } + + /* ************************************************************************* */ + bool SingleValue::ensureArcConsistency(size_t j, + vector& domains) const { + if (j != keys_[0]) throw invalid_argument( + "SingleValue check on wrong domain"); + Domain& D = domains[j]; + if (D.isSingleton()) { + if (D.firstValue() != value_) throw runtime_error("Unsatisfiable"); + return false; + } + D = Domain(discreteKey(),value_); + return true; + } + + /* ************************************************************************* */ + DiscreteFactor::shared_ptr SingleValue::partiallyApply(const Values& values) const { + Values::const_iterator it = values.find(keys_[0]); + if (it != values.end() && it->second != value_) throw runtime_error( + "SingleValue::partiallyApply: unsatisfiable"); + return boost::make_shared < SingleValue > (keys_[0], cardinality_, value_); + } + + /* ************************************************************************* */ + DiscreteFactor::shared_ptr SingleValue::partiallyApply( + const vector& domains) const { + const Domain& Dk = domains[keys_[0]]; + if (Dk.isSingleton() && !Dk.contains(value_)) throw runtime_error( + "SingleValue::partiallyApply: unsatisfiable"); + return boost::make_shared < SingleValue > (discreteKey(), value_); + } + +/* ************************************************************************* */ +} // namespace gtsam diff --git a/gtsam/discrete/SingleValue.h b/gtsam/discrete/SingleValue.h new file mode 100644 index 000000000..fc3d166fd --- /dev/null +++ b/gtsam/discrete/SingleValue.h @@ -0,0 +1,72 @@ +/* + * SingleValue.h + * @brief domain constraint + * @date Feb 6, 2012 + * @author Frank Dellaert + */ + +#pragma once + +#include +#include + +namespace gtsam { + + /** + * SingleValue constraint + */ + class SingleValue: public DiscreteFactor { + + /// Number of values + size_t cardinality_; + + /// allowed value + size_t value_; + + DiscreteKey discreteKey() const { + return DiscreteKey(keys_[0],cardinality_); + } + + public: + + typedef boost::shared_ptr shared_ptr; + + /// Constructor + SingleValue(Index key, size_t n, size_t value) : + DiscreteFactor(key), cardinality_(n), value_(value) { + } + + /// Constructor + SingleValue(const DiscreteKey& dkey, size_t value) : + DiscreteFactor(dkey.first), cardinality_(dkey.second), value_(value) { + } + + // print + virtual void print(const std::string& s = "") const; + + /// Calculate value + virtual double operator()(const Values& values) const; + + /// Convert into a decisiontree + virtual operator DecisionTreeFactor() const; + + /// Multiply into a decisiontree + virtual DecisionTreeFactor operator*(const DecisionTreeFactor& f) const; + + /* + * Ensure Arc-consistency + * @param j domain to be checked + * @param domains all other domains + */ + bool ensureArcConsistency(size_t j, std::vector& domains) const; + + /// Partially apply known values + virtual DiscreteFactor::shared_ptr partiallyApply( + const Values& values) const; + + /// Partially apply known values, domain version + virtual DiscreteFactor::shared_ptr partiallyApply( + const std::vector& domains) const; + }; + +} // namespace gtsam diff --git a/gtsam/discrete/TypedDiscreteFactor.cpp b/gtsam/discrete/TypedDiscreteFactor.cpp new file mode 100644 index 000000000..86f902e89 --- /dev/null +++ b/gtsam/discrete/TypedDiscreteFactor.cpp @@ -0,0 +1,117 @@ +/* + * @file TypedDiscreteFactor.cpp + * @brief + * @author Duy-Nguyen Ta + * @date Mar 5, 2011 + */ + +#include +#include +#include +#include + +using namespace std; + +namespace gtsam { + + /* ******************************************************************************** */ + TypedDiscreteFactor::TypedDiscreteFactor(const Indices& keys, + const string& table) : + Factor (keys.begin(), keys.end()), potentials_(keys, table) { + } + + /* ******************************************************************************** */ + TypedDiscreteFactor::TypedDiscreteFactor(const Indices& keys, + const vector& table) : + Factor (keys.begin(), keys.end()), potentials_(keys, table) { + //#define DEBUG_FACTORS +#ifdef DEBUG_FACTORS + static size_t count = 0; + string dotfile = (boost::format("Factor-%03d") % ++count).str(); + potentials_.dot(dotfile); + if (count == 57) potentials_.print("57"); +#endif + } + + /* ************************************************************************* */ + double TypedDiscreteFactor::operator()(const Values& values) const { + return potentials_(values); + } + + /* ************************************************************************* */ + void TypedDiscreteFactor::print(const string&s) const { + Factor::print(s); + potentials_.print(); + } + + /* ************************************************************************* */ + bool TypedDiscreteFactor::equals(const TypedDiscreteFactor& other, double tol) const { + return potentials_.equals(other.potentials_, tol); + } + + /* ******************************************************************************** */ + DiscreteFactor::shared_ptr TypedDiscreteFactor::toDiscreteFactor( + const KeyOrdering& ordering) const { + throw std::runtime_error("broken"); + //return boost::make_shared(keys(), ordering, potentials_); + } + +#ifdef OLD +DiscreteFactor TypedDiscreteFactor::toDiscreteFactor( + const KeyOrdering& ordering, const ProblemType problemType) const { + { + static bool debug = false; + + // instantiate vector keys and column index in order + DiscreteFactor::ColumnIndex orderColumnIndex; + vector keys; + BOOST_FOREACH(const KeyOrdering::value_type& ord, ordering) + { + if (debug) cout << "Key: " << ord.first; + + // find the key with ord.first in this factor + vector::const_iterator it = std::find(keys_.begin(), + keys_.end(), ord.first); + + // if found + if (it != keys_.end()) { + if (debug) cout << "it found: " << (*it) << ", index: " + << ord.second << endl; + + keys.push_back(ord.second); // push back the ordering index + orderColumnIndex[ord.second] = columnIndex_.at(ord.first.name()); + + if (debug) cout << "map " << ord.second << " with name: " + << ord.first.name() << " to " << columnIndex_.at( + ord.first.name()) << endl; + } + } + + DiscreteFactor f(keys, potentials_, orderColumnIndex, problemType); + return f; + } + + /* ******************************************************************************** */ + std::vector TypedDiscreteFactor::init(const Indices& keys) { + vector cardinalities; + for (size_t j = 0; j < keys.size(); j++) { + Index key = keys[j]; + keys_.push_back(key); + columnIndex_[key.name()] = j; + cardinalities.push_back(key.cardinality()); + } + return cardinalities; + } + + /* ******************************************************************************** */ + double TypedDiscreteFactor::potential(const TypedValues& values) const { + vector assignment(values.size()); + BOOST_FOREACH(const TypedValues::value_type& val, values) + if (columnIndex_.find(val.first) != columnIndex_.end()) assignment[columnIndex_.at( + val.first)] = val.second; + return potentials_(assignment); + } + +#endif + +} // namespace diff --git a/gtsam/discrete/TypedDiscreteFactor.h b/gtsam/discrete/TypedDiscreteFactor.h new file mode 100644 index 000000000..3d9fd6ee6 --- /dev/null +++ b/gtsam/discrete/TypedDiscreteFactor.h @@ -0,0 +1,68 @@ +/* + * @file TypedDiscreteFactor.h + * @brief + * @author Duy-Nguyen Ta + * @date Mar 5, 2011 + */ + +#pragma once + +#include +#include +#include +#include + +namespace gtsam { + + /** + * A factor on discrete variables with string keys + */ + class TypedDiscreteFactor: public Factor { + + typedef AlgebraicDecisionDiagram ADD; + + /** potentials of the factor */ + ADD potentials_; + + public: + + /** A map from keys to values */ + typedef ADD::Assignment Values; + + /** Constructor from keys and string table */ + TypedDiscreteFactor(const Indices& keys, const std::string& table); + + /** Constructor from keys and doubles */ + TypedDiscreteFactor(const Indices& keys, + const std::vector& table); + + /** Evaluate */ + double operator()(const Values& values) const; + + // Testable + bool equals(const TypedDiscreteFactor& other, double tol = 1e-9) const; + void print(const std::string& s = "DiscreteFactor: ") const; + + DiscreteFactor::shared_ptr toDiscreteFactor(const KeyOrdering& ordering) const; + +#ifdef OLD + /** map each variable name to its column index in the potential table */ + typedef std::map Index2IndexMap; + Index2IndexMap columnIndex_; + + /** Initialize keys, column index, and return cardinalities */ + std::vector init(const Indices& keys); + + public: + + /** Default constructor */ + TypedDiscreteFactor() {} + + /** Evaluate potential of a given assignment of values */ + double potential(const TypedValues& values) const; + +#endif + + }; // TypedDiscreteFactor + +} // namespace diff --git a/gtsam/discrete/TypedDiscreteFactorGraph.cpp b/gtsam/discrete/TypedDiscreteFactorGraph.cpp new file mode 100644 index 000000000..2f8c2d22b --- /dev/null +++ b/gtsam/discrete/TypedDiscreteFactorGraph.cpp @@ -0,0 +1,68 @@ +/* + * @file TypedDiscreteFactorGraph.cpp + * @brief + * @author Duy-Nguyen Ta + * @date Mar 1, 2011 + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +using namespace std; + +namespace gtsam { + + /* ************************************************************************* */ + TypedDiscreteFactorGraph::TypedDiscreteFactorGraph() { + } + + /* ************************************************************************* */ + TypedDiscreteFactorGraph::TypedDiscreteFactorGraph(const string& filename) { + bool success = parseUAI(filename, *this); + if (!success) throw runtime_error( + "TypedDiscreteFactorGraph constructor from filename failed"); + } + + /* ************************************************************************* */ + void TypedDiscreteFactorGraph::add// + (const Indices& keys, const string& table) { + push_back(boost::shared_ptr// + (new TypedDiscreteFactor(keys, table))); + } + + /* ************************************************************************* */ + void TypedDiscreteFactorGraph::add// + (const Indices& keys, const vector& table) { + push_back(boost::shared_ptr// + (new TypedDiscreteFactor(keys, table))); + } + + /* ************************************************************************* */ + void TypedDiscreteFactorGraph::print(const string s) { + cout << s << endl; + cout << "Factors: " << endl; + BOOST_FOREACH(const sharedFactor factor, factors_) + factor->print(); + } + + /* ************************************************************************* */ + double TypedDiscreteFactorGraph::operator()( + const TypedDiscreteFactor::Values& values) const { + // Loop over all factors and multiply their probabilities + double p = 1.0; + BOOST_FOREACH(const sharedFactor& factor, *this) + p *= (*factor)(values); + return p; + } + +/* ************************************************************************* */ + +} diff --git a/gtsam/discrete/TypedDiscreteFactorGraph.h b/gtsam/discrete/TypedDiscreteFactorGraph.h new file mode 100644 index 000000000..010113bda --- /dev/null +++ b/gtsam/discrete/TypedDiscreteFactorGraph.h @@ -0,0 +1,50 @@ +/* + * @file TypedDiscreteFactorGraph.h + * @brief Factor graph with typed factors (with Index keys) + * @author Duy-Nguyen Ta + * @author Frank Dellaert + * @date Mar 1, 2011 + */ + +#pragma once + +#include +#include +#include +#include + +namespace gtsam { + + /** + * Typed discrete factor graph, where keys are strings + */ + class TypedDiscreteFactorGraph: public FactorGraph { + + public: + + /** + * Default constructor + */ + TypedDiscreteFactorGraph(); + + /** + * Constructor from file + * For now assumes in .uai format from UAI'08 Probablistic Inference Evaluation + * See http://graphmod.ics.uci.edu/uai08/FileFormat + */ + TypedDiscreteFactorGraph(const std::string& filename); + + // Add factors without shared pointer ugliness + void add(const Indices& keys, const std::string& table); + void add(const Indices& keys, const std::vector& table); + + /** print */ + void print(const std::string s); + + /** Evaluate potential of a given assignment of values */ + double operator()(const TypedDiscreteFactor::Values& values) const; + + }; // TypedDiscreteFactorGraph + + +} // namespace diff --git a/gtsam/discrete/examples/Doodle.csv b/gtsam/discrete/examples/Doodle.csv new file mode 100644 index 000000000..1ce4ecebb --- /dev/null +++ b/gtsam/discrete/examples/Doodle.csv @@ -0,0 +1 @@ +,Ron Arkin,Andrea Thomaz,Ayanna Howard,Wayne Book,Mike Stilman,Charlie Kemp,Jun Ueda,Patricio Vela,Magnus Egerstedt,Harvey Lipkin,Frank Dellaert,Irfan Essa,Aaron Bobick,Jim Rehg,Henrik Christensen,Tucker Balch,Karen Feigh,N/A 1,N/A 2 Mon 9:00-10.30,,1,1,1,1,,1,1,,,1,,,,1,,,1,1 Mon 10:30-12:00,,1,1,1,1,,,1,1,,,,,,,,,1,1 Mon 1:30-3:00,,,1,,,1,1,1,1,1,1,,,,1,,,1,1 Mon 3:00-4:30,,,,1,,1,1,1,,1,1,1,,1,1,,,1,1 Tue 9:00-10.30,,,1,,,,,1,,1,1,,,1,1,,,1,1 Tue 10:30-12:00,,,1,1,1,,1,1,,1,1,,,1,,1,,1,1 Tue 1:30-3:00,,1,,1,1,,1,1,1,1,1,,,,,1,,1,1 Tue 3:00-4:30,,1,1,,,,,1,1,1,1,,,,1,1,,1,1 Wed 9:00-10.30,,,1,1,,,,,1,,1,,1,,1,,1,1,1 Wed 10:30-12:00,,,,1,1,,1,1,1,,,1,1,,1,1,1,1,1 Wed 1:30-3:00,,,,,1,1,,1,,1,1,1,1,,,1,,1,1 Wed 3:00-4:30,,,,,1,1,1,1,1,,1,1,,1,,1,,1,1 Thu 9:00-10.30,,,1,,,,,1,,1,,,1,1,,,,1,1 Thu 10:30-12:00,,,1,1,1,,1,1,,1,,,,1,,1,,1,1 Thu 1:30-3:00,,,1,1,1,,1,1,1,1,,,,1,,1,1,1,1 Thu 3:00-4:30,,,1,,,,,,1,1,,,,,1,1,1,1,1 Fri 9:00-10.30,,,1,1,1,1,1,1,,,1,,,,1,,,1,1 Fri 10:30-12:00,,,1,1,1,,1,1,,,,1,,,1,,,1,1 Fri 1:30-3:00,,,1,1,1,,,1,,1,1,1,1,,,,,1,1 Fri 3:00-4:30,,,,,,,,,,1,1,,1,,,,,1,1 \ No newline at end of file diff --git a/gtsam/discrete/examples/Doodle.xls b/gtsam/discrete/examples/Doodle.xls new file mode 100644 index 000000000..c607581e9 Binary files /dev/null and b/gtsam/discrete/examples/Doodle.xls differ diff --git a/gtsam/discrete/examples/Doodle2012.csv b/gtsam/discrete/examples/Doodle2012.csv new file mode 100644 index 000000000..54520b614 --- /dev/null +++ b/gtsam/discrete/examples/Doodle2012.csv @@ -0,0 +1 @@ +,Karen Feigh,Henrik Christensen,Panos Tsiotras,Ron Arkin,Andrea Thomaz,Magnus Egerstedt,Charles Isbell,Fumin Zhang,Mike Stilman,Jun Ueda,Aaron Bobick,Ayanna Howard,Patricio Vela,Charlie Kemp,Tucker Balch Mon 9:00 AM - 10:30 AM,,,1,1,1,1,1,,,,1,,,, Mon 10:30 AM - 12:00 PM,1,,,1,1,,1,1,1,,1,,1,1,1 Mon 1:30 PM - 3:00 PM,1,1,1,,,1,1,1,1,1,1,1,1,,1 Mon 3:00 PM - 4:30 PM,,,1,1,,,1,,1,,1,1,1,,1 Mon 4:30 PM - 6:00 PM,,1,1,,,,,1,,1,1,,1,, Tue 9:00 AM - 10:30 AM,,1,1,,1,1,1,,,,1,1,,, Tue 10:30 AM - 12:00 PM,1,1,1,1,1,,1,1,,1,1,,1,,1 Tue 1:30 PM - 3:00 PM,1,1,1,,1,1,,1,1,1,1,,,1, Tue 3:00 PM - 4:30 PM,,1,,,1,1,,1,,,,,,, Tue 4:30 PM - 6:00 PM,,,,,1,,,1,1,,1,,,, Wed 9:00 AM - 10:30 AM,1,1,1,,1,,1,,,,,1,,, Wed 10:30 AM - 12:00 PM,1,,,,1,1,1,1,1,1,,,1,1, Wed 1:30 PM - 3:00 PM,1,,1,,,1,,1,1,1,,,1,, Wed 3:00 PM - 4:30 PM,,,1,,,,,,1,,,,1,,1 Wed 4:30 PM - 6:00 PM,,,1,,,,,1,,,,,1,, Thu 9:00 AM - 10:30 AM,,1,1,,,1,,,,,,,,, Thu 10:30 AM - 12:00 PM,1,1,,,,1,,1,,1,,,,,1 Thu 1:30 PM - 3:00 PM,1,,,,,1,,1,,,,,,, Thu 3:00 PM - 4:30 PM,,1,1,,,1,1,1,,,,,,, Thu 4:30 PM - 6:00 PM,,1,1,,,,,1,,,,,,, Fri 9:00 AM - 10:30 AM,1,1,,,,,1,,,,,1,,, Fri 10:30 AM - 12:00 PM,1,1,,,,,,1,1,1,,,,,1 Fri 1:30 PM - 3:00 PM,1,,,,,1,,1,1,1,,,1,1,1 Fri 3:00 PM - 4:30 PM,1,,,,,,1,1,1,1,,,,,1 Fri 4:30 PM - 6:00 PM,,1,,,,,,1,,,,,1,, \ No newline at end of file diff --git a/gtsam/discrete/examples/Doodle2012.xls b/gtsam/discrete/examples/Doodle2012.xls new file mode 100644 index 000000000..981e2dc25 Binary files /dev/null and b/gtsam/discrete/examples/Doodle2012.xls differ diff --git a/gtsam/discrete/examples/intrusive.xlsx b/gtsam/discrete/examples/intrusive.xlsx new file mode 100644 index 000000000..53fd048e2 Binary files /dev/null and b/gtsam/discrete/examples/intrusive.xlsx differ diff --git a/gtsam/discrete/examples/schedulingExample.cpp b/gtsam/discrete/examples/schedulingExample.cpp new file mode 100644 index 000000000..ff3f8a26f --- /dev/null +++ b/gtsam/discrete/examples/schedulingExample.cpp @@ -0,0 +1,344 @@ +/* + * schedulingExample.cpp + * @brief hard scheduling example + * @date March 25, 2011 + * @author Frank Dellaert + */ + +//#define ENABLE_TIMING +#define ADD_NO_CACHING +#define ADD_NO_PRUNING +#include +#include +#include + +#include +#include +#include +#include +#include + +#include + +using namespace boost::assign; +using namespace std; +using namespace gtsam; + +/* ************************************************************************* */ +void addStudent(Scheduler& s, size_t i) { + switch (i) { + case 0: + s.addStudent("Michael N", "AI", "Autonomy", "Perception", "Tucker Balch"); + break; + case 1: + s.addStudent("Tucker H", "Controls", "AI", "Perception", "Jim Rehg"); + break; + case 2: + s.addStudent("Jake H", "Controls", "AI", "Perception", "Henrik Christensen"); + break; + case 3: + s.addStudent("Tobias K", "Controls", "AI", "Autonomy", "Mike Stilman"); + break; + case 4: + s.addStudent("Shu J", "Controls", "AI", "HRI", "N/A 1"); + break; + case 5: + s.addStudent("Akansel C", "AI", "Autonomy", "Mechanics", + "Henrik Christensen"); + break; + case 6: + s.addStudent("Tiffany C", "Controls", "N/A 1", "N/A 2", "Charlie Kemp"); + break; + } +} +/* ************************************************************************* */ +Scheduler largeExample(size_t nrStudents = 7) { + string path("/Users/dellaert/borg/gtsam/gtsam/discrete/examples/"); + Scheduler s(nrStudents, path + "Doodle.csv"); + + s.addArea("Harvey Lipkin", "Mechanics"); + s.addArea("Wayne Book", "Mechanics"); + s.addArea("Jun Ueda", "Mechanics"); + + // s.addArea("Wayne Book", "Controls"); + s.addArea("Patricio Vela", "Controls"); + s.addArea("Magnus Egerstedt", "Controls"); + s.addArea("Jun Ueda", "Controls"); + + // s.addArea("Frank Dellaert", "Perception"); + s.addArea("Jim Rehg", "Perception"); + s.addArea("Irfan Essa", "Perception"); + s.addArea("Aaron Bobick", "Perception"); + s.addArea("Henrik Christensen", "Perception"); + + s.addArea("Mike Stilman", "AI"); + s.addArea("Henrik Christensen", "AI"); + s.addArea("Frank Dellaert", "AI"); + s.addArea("Ayanna Howard", "AI"); + // s.addArea("Tucker Balch", "AI"); + + s.addArea("Ayanna Howard", "Autonomy"); + // s.addArea("Andrea Thomaz", "Autonomy"); + s.addArea("Charlie Kemp", "Autonomy"); + s.addArea("Tucker Balch", "Autonomy"); + s.addArea("Ron Arkin", "Autonomy"); + + s.addArea("Andrea Thomaz", "HRI"); + s.addArea("Karen Feigh", "HRI"); + s.addArea("Charlie Kemp", "HRI"); + + // Allow students not to take three areas + s.addArea("N/A 1", "N/A 1"); + s.addArea("N/A 2", "N/A 2"); + + // add students + for (size_t i = 0; i < nrStudents; i++) + addStudent(s, i); + + return s; +} + +/* ************************************************************************* */ +void runLargeExample() { + + Scheduler scheduler = largeExample(); + scheduler.print(); + + // BUILD THE GRAPH ! + size_t addMutex = 2; + scheduler.buildGraph(addMutex); + + // Do brute force product and output that to file + if (scheduler.nrStudents() == 1) { // otherwise too slow + DecisionTreeFactor product = scheduler.product(); + product.dot("scheduling-large", false); + } + + // Do exact inference + // SETDEBUG("timing-verbose", true); + SETDEBUG("DiscreteConditional::DiscreteConditional", true); + tic(2, "large"); + DiscreteFactor::sharedValues MPE = scheduler.optimalAssignment(); + toc(2, "large"); + tictoc_finishedIteration(); + tictoc_print(); + scheduler.printAssignment(MPE); +} + +/* ************************************************************************* */ +// Solve a series of relaxed problems for maximum flexibility solution +void solveStaged(size_t addMutex = 2) { + + // super-hack! just count... + bool debug = false; + SETDEBUG("DiscreteConditional::COUNT", true); + SETDEBUG("DiscreteConditional::DiscreteConditional", debug); // progress + + // make a vector with slot availability, initially all 1 + // Reads file to get count :-) + vector slotsAvailable(largeExample(0).nrTimeSlots(), 1.0); + + // now, find optimal value for each student, using relaxed mutex constraints + for (size_t s = 0; s < 7; s++) { + // add all students first time, then drop last one second time, etc... + Scheduler scheduler = largeExample(7 - s); + //scheduler.print(str(boost::format("Scheduler %d") % (7-s))); + + // only allow slots not yet taken + scheduler.setSlotsAvailable(slotsAvailable); + + // BUILD THE GRAPH ! + scheduler.buildGraph(addMutex); + + // Do EXACT INFERENCE + tic_("eliminate"); + DiscreteBayesNet::shared_ptr chordal = scheduler.eliminate(); + toc_("eliminate"); + + // find root node + DiscreteConditional::shared_ptr root = *(chordal->rbegin()); + if (debug) + root->print(""/*scheduler.studentName(s)*/); + + // solve root node only + Scheduler::Values values; + size_t bestSlot = root->solve(values); + + // get corresponding count + DiscreteKey dkey = scheduler.studentKey(6 - s); + values[dkey.first] = bestSlot; + size_t count = (*root)(values); + + // remove this slot from consideration + slotsAvailable[bestSlot] = 0.0; + cout << boost::format("%s = %d (%d), count = %d") % scheduler.studentName(6-s) + % scheduler.slotName(bestSlot) % bestSlot % count << endl; + } + tictoc_print_(); + + // Solution with addMutex = 2: (20 secs) + // TC = Wed 2 (9), count = 96375041778 + // AC = Tue 2 (5), count = 4076088090 + // SJ = Mon 1 (0), count = 29596704 + // TK = Mon 3 (2), count = 755370 + // JH = Wed 4 (11), count = 12000 + // TH = Fri 2 (17), count = 220 + // MN = Fri 1 (16), count = 5 + // + // Mutex does make a difference !! + +} + +/* ************************************************************************* */ +// Sample from solution found above and evaluate cost function +bool NonZero(size_t i) { + return i > 0; +} + +DiscreteBayesNet::shared_ptr createSampler(size_t i, + size_t slot, vector& schedulers) { + Scheduler scheduler = largeExample(0); // todo: wrong nr students + addStudent(scheduler, i); + SETDEBUG("Scheduler::buildGraph", false); + scheduler.addStudentSpecificConstraints(0, slot); + DiscreteBayesNet::shared_ptr chordal = scheduler.eliminate(); + // chordal->print(scheduler[i].studentKey(0).name()); // large ! + schedulers.push_back(scheduler); + return chordal; +} + +void sampleSolutions() { + + vector schedulers; + vector samplers(7); + + // Given the time-slots, we can create 7 independent samplers + vector slots; + slots += 16, 17, 11, 2, 0, 5, 9; // given slots + for (size_t i = 0; i < 7; i++) + samplers[i] = createSampler(i, slots[i], schedulers); + + // now, sample schedules + for (size_t n = 0; n < 500; n++) { + vector stats(19, 0); + vector samples; + for (size_t i = 0; i < 7; i++) { + samples.push_back(sample(*samplers[i])); + schedulers[i].accumulateStats(samples[i], stats); + } + size_t max = *max_element(stats.begin(), stats.end()); + size_t min = *min_element(stats.begin(), stats.end()); + size_t nz = count_if(stats.begin(), stats.end(), NonZero); + if (nz >= 15 && max <= 2) { + cout << boost::format( + "Sampled schedule %d, min = %d, nz = %d, max = %d\n") % (n + 1) % min + % nz % max; + for (size_t i = 0; i < 7; i++) { + cout << schedulers[i].studentName(0) << " : " << schedulers[i].slotName( + slots[i]) << endl; + schedulers[i].printSpecial(samples[i]); + } + } + } + // Output was + // Sampled schedule 359, min = 0, nz = 15, max = 2 + // Michael N : Fri 9:00-10.30 + // Michael N AI: Frank Dellaert + // Michael N Autonomy: Charlie Kemp + // Michael N Perception: Henrik Christensen + // + // Tucker H : Fri 10:30-12:00 + // Tucker H AI: Ayanna Howard + // Tucker H Controls: Patricio Vela + // Tucker H Perception: Irfan Essa + // + // Jake H : Wed 3:00-4:30 + // Jake H AI: Mike Stilman + // Jake H Controls: Magnus Egerstedt + // Jake H Perception: Jim Rehg + // + // Tobias K : Mon 1:30-3:00 + // Tobias K AI: Ayanna Howard + // Tobias K Autonomy: Charlie Kemp + // Tobias K Controls: Magnus Egerstedt + // + // Shu J : Mon 9:00-10.30 + // Shu J AI: Mike Stilman + // Shu J Controls: Jun Ueda + // Shu J HRI: Andrea Thomaz + // + // Akansel C : Tue 10:30-12:00 + // Akansel C AI: Frank Dellaert + // Akansel C Autonomy: Tucker Balch + // Akansel C Mechanics: Harvey Lipkin + // + // Tiffany C : Wed 10:30-12:00 + // Tiffany C Controls: Patricio Vela + // Tiffany C N/A 1: N/A 1 + // Tiffany C N/A 2: N/A 2 + +} + +/* ************************************************************************* */ +void accomodateStudent() { + + // super-hack! just count... + bool debug = false; + // SETDEBUG("DiscreteConditional::COUNT",true); + SETDEBUG("DiscreteConditional::DiscreteConditional", debug); // progress + + Scheduler scheduler = largeExample(0); + // scheduler.addStudent("Victor E", "Autonomy", "HRI", "AI", + // "Henrik Christensen"); + scheduler.addStudent("Carlos N", "Perception", "AI", "Autonomy", + "Henrik Christensen"); + scheduler.print("scheduler"); + + // rule out all occupied slots + vector slots; + slots += 16, 17, 11, 2, 0, 5, 9, 14; + vector slotsAvailable(scheduler.nrTimeSlots(), 1.0); + BOOST_FOREACH(size_t s, slots) + slotsAvailable[s] = 0; + scheduler.setSlotsAvailable(slotsAvailable); + + // BUILD THE GRAPH ! + scheduler.buildGraph(1); + + // Do EXACT INFERENCE + DiscreteBayesNet::shared_ptr chordal = scheduler.eliminate(); + + // find root node + DiscreteConditional::shared_ptr root = *(chordal->rbegin()); + if (debug) + root->print(""/*scheduler.studentName(s)*/); + // GTSAM_PRINT(*chordal); + + // solve root node only + Scheduler::Values values; + size_t bestSlot = root->solve(values); + + // get corresponding count + DiscreteKey dkey = scheduler.studentKey(0); + values[dkey.first] = bestSlot; + size_t count = (*root)(values); + cout << boost::format("%s = %d (%d), count = %d") % scheduler.studentName(0) + % scheduler.slotName(bestSlot) % bestSlot % count << endl; + + // sample schedules + for (size_t n = 0; n < 10; n++) { + Scheduler::sharedValues sample0 = sample(*chordal); + scheduler.printAssignment(sample0); + } +} + +/* ************************************************************************* */ +int main() { + runLargeExample(); + solveStaged(3); +// sampleSolutions(); + // accomodateStudent(); + return 0; +} +/* ************************************************************************* */ + diff --git a/gtsam/discrete/examples/schedulingQuals12.cpp b/gtsam/discrete/examples/schedulingQuals12.cpp new file mode 100644 index 000000000..7571fbc58 --- /dev/null +++ b/gtsam/discrete/examples/schedulingQuals12.cpp @@ -0,0 +1,264 @@ +/* + * schedulingExample.cpp + * @brief hard scheduling example + * @date March 25, 2011 + * @author Frank Dellaert + */ + +#define ENABLE_TIMING +#define ADD_NO_CACHING +#define ADD_NO_PRUNING +#include +#include +#include + +#include +#include +#include +#include +#include + +#include + +using namespace boost::assign; +using namespace std; +using namespace gtsam; + +size_t NRSTUDENTS = 9; + +bool NonZero(size_t i) { + return i > 0; +} + +/* ************************************************************************* */ +void addStudent(Scheduler& s, size_t i) { + switch (i) { + case 0: + s.addStudent("Pan, Yunpeng", "Controls", "Perception", "Mechanics", "Eric Johnson"); + break; + case 1: + s.addStudent("Sawhney, Rahul", "Controls", "AI", "Perception", "Henrik Christensen"); + break; + case 2: + s.addStudent("Akgun, Baris", "Controls", "AI", "HRI", "Andrea Thomaz"); + break; + case 3: + s.addStudent("Jiang, Shu", "Controls", "AI", "Perception", "Ron Arkin"); + break; + case 4: + s.addStudent("Grice, Phillip", "Controls", "Perception", "HRI", "Charlie Kemp"); + break; + case 5: + s.addStudent("Huaman, Ana", "Controls", "AI", "Perception", "Mike Stilman"); + break; + case 6: + s.addStudent("Levihn, Martin", "AI", "Autonomy", "Perception", "Mike Stilman"); + break; + case 7: + s.addStudent("Nieto, Carlos", "AI", "Autonomy", "Perception", "Henrik Christensen"); + break; + case 8: + s.addStudent("Robinette, Paul", "Controls", "AI", "HRI", "Ayanna Howard"); + break; + } +} + +/* ************************************************************************* */ +Scheduler largeExample(size_t nrStudents = NRSTUDENTS) { + string path("/Users/dellaert/borg/gtsam/gtsam/discrete/examples/"); + Scheduler s(nrStudents, path + "Doodle2012.csv"); + + s.addArea("Harvey Lipkin", "Mechanics"); + s.addArea("Jun Ueda", "Mechanics"); + + s.addArea("Patricio Vela", "Controls"); + s.addArea("Magnus Egerstedt", "Controls"); + s.addArea("Jun Ueda", "Controls"); + s.addArea("Panos Tsiotras", "Controls"); + s.addArea("Fumin Zhang", "Controls"); + + s.addArea("Henrik Christensen", "Perception"); + s.addArea("Aaron Bobick", "Perception"); + + s.addArea("Mike Stilman", "AI"); +// s.addArea("Henrik Christensen", "AI"); + s.addArea("Ayanna Howard", "AI"); + s.addArea("Charles Isbell", "AI"); + s.addArea("Tucker Balch", "AI"); + + s.addArea("Ayanna Howard", "Autonomy"); + s.addArea("Charlie Kemp", "Autonomy"); + s.addArea("Tucker Balch", "Autonomy"); + s.addArea("Ron Arkin", "Autonomy"); + + s.addArea("Andrea Thomaz", "HRI"); + s.addArea("Karen Feigh", "HRI"); + s.addArea("Charlie Kemp", "HRI"); + + // add students + for (size_t i = 0; i < nrStudents; i++) + addStudent(s, i); + + return s; +} + +/* ************************************************************************* */ +void runLargeExample() { + + Scheduler scheduler = largeExample(); + scheduler.print(); + + // BUILD THE GRAPH ! + size_t addMutex = 3; + // SETDEBUG("Scheduler::buildGraph", true); + scheduler.buildGraph(addMutex); + + // Do brute force product and output that to file + if (scheduler.nrStudents() == 1) { // otherwise too slow + DecisionTreeFactor product = scheduler.product(); + product.dot("scheduling-large", false); + } + + // Do exact inference + // SETDEBUG("timing-verbose", true); + SETDEBUG("DiscreteConditional::DiscreteConditional", true); +#define SAMPLE +#ifdef SAMPLE + tic(2, "large"); + DiscreteBayesNet::shared_ptr chordal = scheduler.eliminate(); + toc(2, "large"); + tictoc_finishedIteration(); + tictoc_print(); + for (size_t i=0;i<100;i++) { + DiscreteFactor::sharedValues assignment = sample(*chordal); + vector stats(scheduler.nrFaculty()); + scheduler.accumulateStats(assignment, stats); + size_t max = *max_element(stats.begin(), stats.end()); + size_t min = *min_element(stats.begin(), stats.end()); + size_t nz = count_if(stats.begin(), stats.end(), NonZero); +// cout << min << ", " << max << ", " << nz << endl; + if (nz >= 13 && min >=1 && max <= 4) { + cout << "======================================================\n"; + scheduler.printAssignment(assignment); + } + } +#else + tic(2, "large"); + DiscreteFactor::sharedValues MPE = scheduler.optimalAssignment(); + toc(2, "large"); + tictoc_finishedIteration(); + tictoc_print(); + scheduler.printAssignment(MPE); +#endif +} + +/* ************************************************************************* */ +// Solve a series of relaxed problems for maximum flexibility solution +void solveStaged(size_t addMutex = 2) { + + // super-hack! just count... + bool debug = false; + SETDEBUG("DiscreteConditional::COUNT", true); + SETDEBUG("DiscreteConditional::DiscreteConditional", debug); // progress + + // make a vector with slot availability, initially all 1 + // Reads file to get count :-) + vector slotsAvailable(largeExample(0).nrTimeSlots(), 1.0); + + // now, find optimal value for each student, using relaxed mutex constraints + for (size_t s = 0; s < NRSTUDENTS; s++) { + // add all students first time, then drop last one second time, etc... + Scheduler scheduler = largeExample(NRSTUDENTS - s); + //scheduler.print(str(boost::format("Scheduler %d") % (NRSTUDENTS-s))); + + // only allow slots not yet taken + scheduler.setSlotsAvailable(slotsAvailable); + + // BUILD THE GRAPH ! + scheduler.buildGraph(addMutex); + + // Do EXACT INFERENCE + tic_("eliminate"); + DiscreteBayesNet::shared_ptr chordal = scheduler.eliminate(); + toc_("eliminate"); + + // find root node + DiscreteConditional::shared_ptr root = *(chordal->rbegin()); + if (debug) + root->print(""/*scheduler.studentName(s)*/); + + // solve root node only + Scheduler::Values values; + size_t bestSlot = root->solve(values); + + // get corresponding count + DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s); + values[dkey.first] = bestSlot; + size_t count = (*root)(values); + + // remove this slot from consideration + slotsAvailable[bestSlot] = 0.0; + cout << boost::format("%s = %d (%d), count = %d") % scheduler.studentName(NRSTUDENTS-1-s) + % scheduler.slotName(bestSlot) % bestSlot % count << endl; + } + tictoc_print_(); +} + +/* ************************************************************************* */ +// Sample from solution found above and evaluate cost function +DiscreteBayesNet::shared_ptr createSampler(size_t i, + size_t slot, vector& schedulers) { + Scheduler scheduler = largeExample(0); // todo: wrong nr students + addStudent(scheduler, i); + SETDEBUG("Scheduler::buildGraph", false); + scheduler.addStudentSpecificConstraints(0, slot); + DiscreteBayesNet::shared_ptr chordal = scheduler.eliminate(); + // chordal->print(scheduler[i].studentKey(0).name()); // large ! + schedulers.push_back(scheduler); + return chordal; +} + +void sampleSolutions() { + + vector schedulers; + vector samplers(NRSTUDENTS); + + // Given the time-slots, we can create NRSTUDENTS independent samplers + vector slots; + slots += 3, 20, 2, 6, 5, 11, 1, 4; // given slots + for (size_t i = 0; i < NRSTUDENTS; i++) + samplers[i] = createSampler(i, slots[i], schedulers); + + // now, sample schedules + for (size_t n = 0; n < 500; n++) { + vector stats(19, 0); + vector samples; + for (size_t i = 0; i < NRSTUDENTS; i++) { + samples.push_back(sample(*samplers[i])); + schedulers[i].accumulateStats(samples[i], stats); + } + size_t max = *max_element(stats.begin(), stats.end()); + size_t min = *min_element(stats.begin(), stats.end()); + size_t nz = count_if(stats.begin(), stats.end(), NonZero); + if (nz >= 15 && max <= 2) { + cout << boost::format( + "Sampled schedule %d, min = %d, nz = %d, max = %d\n") % (n + 1) % min + % nz % max; + for (size_t i = 0; i < NRSTUDENTS; i++) { + cout << schedulers[i].studentName(0) << " : " << schedulers[i].slotName( + slots[i]) << endl; + schedulers[i].printSpecial(samples[i]); + } + } + } +} + +/* ************************************************************************* */ +int main() { + runLargeExample(); +// solveStaged(3); +// sampleSolutions(); + return 0; +} +/* ************************************************************************* */ + diff --git a/gtsam/discrete/examples/small.csv b/gtsam/discrete/examples/small.csv new file mode 100644 index 000000000..144ead08c --- /dev/null +++ b/gtsam/discrete/examples/small.csv @@ -0,0 +1 @@ +,Frank,Harvey,Magnus,Andrea Mon,1,1,1, Wed,1,1,1,1 Fri,,1,1,1 \ No newline at end of file diff --git a/gtsam/discrete/parseUAI.cpp b/gtsam/discrete/parseUAI.cpp new file mode 100644 index 000000000..91296ab9b --- /dev/null +++ b/gtsam/discrete/parseUAI.cpp @@ -0,0 +1,157 @@ +/* + * parseUAI.cpp + * @brief: parse UAI 2008 format + * @date March 5, 2011 + * @author Duy-Nguyen Ta + * @author Frank Dellaert + */ + +//#define PARSE +#ifdef PARSE +#include +#include // for parsing +#include // for ref +#include +#include +#include +#include + +#include + +using namespace std; +namespace qi = boost::spirit::qi; + +namespace gtsam { + + /* ************************************************************************* */ + // Keys are the vars of variables connected to a factor + // subclass of Indices with special constructor + struct Keys: public Indices { + Keys() { + } + // Pick correct vars based on indices + Keys(const Indices& vars, const vector& indices) { + BOOST_FOREACH(int i, indices) + push_back(vars[i]); + } + }; + + /* ************************************************************************* */ + // The UAI grammar is defined in a class + // Spirit local variables are used, see + // http://boost-spirit.com/home/2010/01/21/what-are-rule-bound-semantic-actions + /* ************************************************************************* */ + struct Grammar { + + // declare all parsers as instance variables + typedef vector Table; + typedef boost::spirit::istream_iterator It; + qi::rule uai, preamble, type, vars, factors, tables; + qi::rule > keys; + qi::rule > table; + + // Variables filled by preamble parser + size_t nrVars_, nrFactors_; + Indices vars_; + vector factors_; + + // Variables filled by tables parser + vector
tables_; + + // The constructor defines the parser rules (declared below) + // To debug, just say debug(rule) after defining the rule + Grammar() { + using boost::phoenix::val; + using boost::phoenix::ref; + using boost::phoenix::construct; + using namespace boost::spirit::qi; + + //--------------- high level parsers with side-effects :-( ----------------- + + // A uai file consists of preamble followed by tables + uai = preamble >> tables; + + // The preamble defines the variables and factors + // The parser fills in the first set of variables above, + // including the vector of factor "Neighborhoods" + preamble = type >> vars >> int_[ref(nrFactors_) = _1] >> factors; + + // type string, does not seem to matter + type = lit("BAYES") | lit("MARKOV"); + + // vars parses "3 2 2 3" and synthesizes a Keys class, in this case + // containing Indices {v0,2}, {v1,2}, and {v2,3} + vars = int_[ref(nrVars_) = _1] >> (repeat(ref(nrVars_))[int_]) // + [ref(vars_) = construct (_1)]; + + // Parse a list of Neighborhoods and fill factors_ + factors = (repeat(ref(nrFactors_))[keys])// + [ref(factors_) = _1]; + + // The tables parser fills in the tables_ + tables = (repeat(ref(nrFactors_))[table])// + [ref(tables_) = _1]; + + //----------- basic parsers with synthesized attributes :-) ----------------- + + // keys parses strings like "2 1 2", indicating + // a binary factor (2) on variables v1 and v2. + // It returns a Keys class as attribute + keys = int_[_a = _1] >> repeat(_a)[int_] // + [_val = construct (ref(vars_), _1)]; + + // The tables are a list of doubles preceded by a count, e.g. "4 1.0 2.0 3.0 4.0" + // The table parser returns a PotentialTable::Table attribute + table = int_[_a = _1] >> repeat(_a)[double_] // + [_val = construct
(_1)]; + } + + // Add the factors to the graph + void addFactorsToGraph(TypedDiscreteFactorGraph& graph) { + assert(factors_.size()==nrFactors_); + assert(tables_.size()==nrFactors_); + for (size_t i = 0; i < nrFactors_; i++) + graph.add(factors_[i], tables_[i]); + } + + }; + + /* ************************************************************************* */ + bool parseUAI(const std::string& filename, TypedDiscreteFactorGraph& graph) { + + // open file, disable skipping of whitespace + std::ifstream in(filename.c_str()); + if (!in) { + cerr << "Could not open " << filename << endl; + return false; + } + + in.unsetf(std::ios::skipws); + + // wrap istream into iterator + boost::spirit::istream_iterator first(in); + boost::spirit::istream_iterator last; + + // Parse and add factors into the graph + Grammar grammar; + bool success = qi::phrase_parse(first, last, grammar.uai, qi::space); + if (success) grammar.addFactorsToGraph(graph); + + return success; + } +/* ************************************************************************* */ + +}// gtsam +#else + +#include + +namespace gtsam { + +/** Dummy version of function - otherwise, missing symbol */ +bool parseUAI(const std::string& filename, TypedDiscreteFactorGraph& graph) { + return false; +} + +} // \namespace gtsam +#endif diff --git a/gtsam/discrete/parseUAI.h b/gtsam/discrete/parseUAI.h new file mode 100644 index 000000000..da9cb0f47 --- /dev/null +++ b/gtsam/discrete/parseUAI.h @@ -0,0 +1,22 @@ +/* + * parseUAI.h + * @brief: parse UAI 2008 format + * @date March 5, 2011 + * @author Duy-Nguyen Ta + * @author Frank Dellaert + */ + +#include +#include + +namespace gtsam { + + /** + * Constructor from file + * For now assumes in .uai format from UAI'08 Probablistic Inference Evaluation + * See http://graphmod.ics.uci.edu/uai08/FileFormat + */ + bool parseUAI(const std::string& filename, + gtsam::TypedDiscreteFactorGraph& graph); + +} // gtsam diff --git a/gtsam/discrete/tests/data/FG/alarm.fg b/gtsam/discrete/tests/data/FG/alarm.fg new file mode 100644 index 000000000..40fbb6f9d --- /dev/null +++ b/gtsam/discrete/tests/data/FG/alarm.fg @@ -0,0 +1,935 @@ +# ALARM network +# from http://compbio.cs.huji.ac.il/Repository/Datasets/alarm/alarm.dsc +37 + +2 +0 5 +2 2 +4 +0 0.9 +1 0.1 +2 0.01 +3 0.99 + +2 +1 4 +3 3 +9 +0 0.95 +1 0.04 +2 0.01 +3 0.04 +4 0.95 +5 0.01 +6 0.01 +7 0.29 +8 0.7 + +2 +2 4 +3 3 +9 +0 0.95 +1 0.04 +2 0.01 +3 0.04 +4 0.95 +5 0.01 +6 0.01 +7 0.04 +8 0.95 + +1 +3 +2 +2 +0 0.2 +1 0.8 + +3 +3 4 5 +2 3 2 +12 +0 0.95 +1 0.01 +2 0.04 +3 0.09 +4 0.01 +5 0.9 +6 0.98 +7 0.05 +8 0.01 +9 0.9 +10 0.01 +11 0.05 + +1 +5 +2 +2 +0 0.05 +1 0.95 + +3 +3 5 6 +2 2 3 +12 +0 0.98 +1 0.5 +2 0.95 +3 0.05 +4 0.01 +5 0.49 +6 0.04 +7 0.9 +8 0.01 +9 0.01 +10 0.01 +11 0.05 + +1 +7 +2 +2 +0 0.05 +1 0.95 + +3 +7 8 34 +2 3 3 +18 +0 0.98 +1 0.98 +2 0.01 +3 0.01 +4 0.01 +5 0.01 +6 0.4 +7 0.01 +8 0.59 +9 0.98 +10 0.01 +11 0.01 +12 0.3 +13 0.01 +14 0.4 +15 0.01 +16 0.3 +17 0.98 + +3 +9 10 34 +3 2 3 +18 +0 0.333 +1 0.333 +2 0.333 +3 0.98 +4 0.01 +5 0.01 +6 0.333 +7 0.333 +8 0.333 +9 0.01 +10 0.98 +11 0.01 +12 0.333 +13 0.333 +14 0.333 +15 0.01 +16 0.01 +17 0.98 + +1 +10 +2 +2 +0 0.1 +1 0.9 + +3 +10 11 34 +2 3 3 +18 +0 0.333 +1 0.98 +2 0.333 +3 0.01 +4 0.333 +5 0.01 +6 0.333 +7 0.01 +8 0.333 +9 0.98 +10 0.333 +11 0.01 +12 0.333 +13 0.01 +14 0.333 +15 0.01 +16 0.333 +17 0.98 + +1 +12 +2 +2 +0 0.1 +1 0.9 + +1 +13 +2 +2 +0 0.01 +1 0.99 + +2 +13 14 +2 3 +6 +0 0.98 +1 0.3 +2 0.01 +3 0.4 +4 0.01 +5 0.3 + +3 +15 30 32 +4 4 3 +48 +0 0.97 +1 0.01 +2 0.01 +3 0.01 +4 0.01 +5 0.97 +6 0.01 +7 0.01 +8 0.01 +9 0.01 +10 0.97 +11 0.01 +12 0.01 +13 0.01 +14 0.01 +15 0.97 +16 0.01 +17 0.97 +18 0.01 +19 0.01 +20 0.97 +21 0.01 +22 0.01 +23 0.01 +24 0.01 +25 0.01 +26 0.97 +27 0.01 +28 0.01 +29 0.01 +30 0.01 +31 0.97 +32 0.01 +33 0.97 +34 0.01 +35 0.01 +36 0.01 +37 0.01 +38 0.97 +39 0.01 +40 0.97 +41 0.01 +42 0.01 +43 0.01 +44 0.01 +45 0.01 +46 0.01 +47 0.97 + +1 +16 +2 +2 +0 0.04 +1 0.96 + +3 +17 24 30 +4 3 4 +48 +0 0.97 +1 0.01 +2 0.01 +3 0.01 +4 0.97 +5 0.01 +6 0.01 +7 0.01 +8 0.97 +9 0.01 +10 0.01 +11 0.01 +12 0.01 +13 0.97 +14 0.01 +15 0.01 +16 0.6 +17 0.38 +18 0.01 +19 0.01 +20 0.01 +21 0.97 +22 0.01 +23 0.01 +24 0.01 +25 0.01 +26 0.97 +27 0.01 +28 0.5 +29 0.48 +30 0.01 +31 0.01 +32 0.01 +33 0.01 +34 0.97 +35 0.01 +36 0.01 +37 0.01 +38 0.01 +39 0.97 +40 0.5 +41 0.48 +42 0.01 +43 0.01 +44 0.01 +45 0.01 +46 0.01 +47 0.97 + +1 +18 +2 +2 +0 0.05 +1 0.95 + +3 +18 19 31 +2 3 4 +19 +0 1 +1 1 +6 0.99 +7 0.95 +8 0.01 +9 0.04 +11 0.01 +12 0.95 +13 0.01 +14 0.04 +15 0.95 +16 0.01 +17 0.04 +18 0.95 +19 0.01 +20 0.04 +21 0.01 +22 0.01 +23 0.98 + +3 +19 20 23 +3 3 2 +18 +0 0.98 +1 0.01 +2 0.98 +3 0.01 +4 0.01 +5 0.01 +6 0.01 +7 0.98 +8 0.01 +9 0.01 +10 0.98 +11 0.69 +12 0.98 +13 0.01 +14 0.3 +15 0.01 +16 0.01 +17 0.01 + +2 +21 22 +3 2 +6 +0 0.01 +1 0.19 +2 0.8 +3 0.05 +4 0.9 +5 0.05 + +1 +22 +2 +2 +0 0.01 +1 0.99 + +3 +22 23 24 +2 2 3 +12 +0 0.1 +1 0.95 +2 0.9 +3 0.05 +4 0.1 +5 0.95 +6 0.9 +7 0.05 +8 0.01 +9 0.05 +10 0.99 +11 0.95 + +1 +24 +3 +3 +0 0.92 +1 0.03 +2 0.05 + +4 +16 24 25 29 +2 3 4 4 +96 +0 0.97 +1 0.97 +2 0.97 +3 0.97 +4 0.97 +5 0.97 +6 0.01 +7 0.01 +8 0.01 +9 0.01 +10 0.01 +11 0.01 +12 0.01 +13 0.01 +14 0.01 +15 0.01 +16 0.01 +17 0.01 +18 0.01 +19 0.01 +20 0.01 +21 0.01 +22 0.01 +23 0.01 +24 0.01 +25 0.01 +26 0.1 +27 0.4 +28 0.01 +29 0.01 +30 0.3 +31 0.97 +32 0.84 +33 0.58 +34 0.29 +35 0.9 +36 0.49 +37 0.01 +38 0.05 +39 0.01 +40 0.3 +41 0.08 +42 0.2 +43 0.01 +44 0.01 +45 0.01 +46 0.4 +47 0.01 +48 0.01 +49 0.01 +50 0.05 +51 0.2 +52 0.01 +53 0.01 +54 0.01 +55 0.01 +56 0.25 +57 0.75 +58 0.01 +59 0.01 +60 0.08 +61 0.97 +62 0.25 +63 0.04 +64 0.08 +65 0.38 +66 0.9 +67 0.01 +68 0.45 +69 0.01 +70 0.9 +71 0.6 +72 0.01 +73 0.01 +74 0.01 +75 0.2 +76 0.01 +77 0.01 +78 0.01 +79 0.01 +80 0.15 +81 0.7 +82 0.01 +83 0.01 +84 0.01 +85 0.01 +86 0.25 +87 0.09 +88 0.01 +89 0.01 +90 0.97 +91 0.97 +92 0.59 +93 0.01 +94 0.97 +95 0.97 + +1 +26 +2 +2 +0 0.1 +1 0.9 + +1 +27 +3 +3 +0 0.05 +1 0.9 +2 0.05 + +2 +27 28 +3 4 +12 +0 0.05 +1 0.05 +2 0.05 +3 0.93 +4 0.01 +5 0.01 +6 0.01 +7 0.93 +8 0.01 +9 0.01 +10 0.01 +11 0.93 + +3 +26 28 29 +2 4 4 +32 +0 0.97 +1 0.97 +2 0.97 +3 0.01 +4 0.97 +5 0.01 +6 0.97 +7 0.01 +8 0.01 +9 0.01 +10 0.01 +11 0.97 +12 0.01 +13 0.01 +14 0.01 +15 0.01 +16 0.01 +17 0.01 +18 0.01 +19 0.01 +20 0.01 +21 0.97 +22 0.01 +23 0.01 +24 0.01 +25 0.01 +26 0.01 +27 0.01 +28 0.01 +29 0.01 +30 0.01 +31 0.97 + +4 +16 24 29 30 +2 3 4 4 +96 +0 0.97 +1 0.97 +2 0.97 +3 0.97 +4 0.97 +5 0.97 +6 0.95 +7 0.01 +8 0.97 +9 0.97 +10 0.95 +11 0.01 +12 0.4 +13 0.01 +14 0.97 +15 0.97 +16 0.5 +17 0.01 +18 0.3 +19 0.01 +20 0.97 +21 0.97 +22 0.3 +23 0.01 +24 0.01 +25 0.01 +26 0.01 +27 0.01 +28 0.01 +29 0.01 +30 0.03 +31 0.97 +32 0.01 +33 0.01 +34 0.03 +35 0.97 +36 0.58 +37 0.01 +38 0.01 +39 0.01 +40 0.48 +41 0.01 +42 0.68 +43 0.01 +44 0.01 +45 0.01 +46 0.68 +47 0.01 +48 0.01 +49 0.01 +50 0.01 +51 0.01 +52 0.01 +53 0.01 +54 0.01 +55 0.01 +56 0.01 +57 0.01 +58 0.01 +59 0.01 +60 0.01 +61 0.97 +62 0.01 +63 0.01 +64 0.01 +65 0.97 +66 0.01 +67 0.01 +68 0.01 +69 0.01 +70 0.01 +71 0.01 +72 0.01 +73 0.01 +74 0.01 +75 0.01 +76 0.01 +77 0.01 +78 0.01 +79 0.01 +80 0.01 +81 0.01 +82 0.01 +83 0.01 +84 0.01 +85 0.01 +86 0.01 +87 0.01 +88 0.01 +89 0.01 +90 0.01 +91 0.97 +92 0.01 +93 0.01 +94 0.01 +95 0.97 + +3 +24 30 31 +3 4 4 +48 +0 0.97 +1 0.97 +2 0.97 +3 0.01 +4 0.01 +5 0.03 +6 0.01 +7 0.01 +8 0.01 +9 0.01 +10 0.01 +11 0.01 +12 0.01 +13 0.01 +14 0.01 +15 0.97 +16 0.97 +17 0.95 +18 0.01 +19 0.01 +20 0.94 +21 0.01 +22 0.01 +23 0.88 +24 0.01 +25 0.01 +26 0.01 +27 0.01 +28 0.01 +29 0.01 +30 0.97 +31 0.97 +32 0.04 +33 0.01 +34 0.01 +35 0.1 +36 0.01 +37 0.01 +38 0.01 +39 0.01 +40 0.01 +41 0.01 +42 0.01 +43 0.01 +44 0.01 +45 0.97 +46 0.97 +47 0.01 + +2 +31 32 +4 3 +12 +0 0.01 +1 0.01 +2 0.04 +3 0.9 +4 0.01 +5 0.01 +6 0.92 +7 0.09 +8 0.98 +9 0.98 +10 0.04 +11 0.01 + +5 +12 14 20 32 33 +2 3 3 3 2 +108 +0 0.01 +1 0.05 +2 0.01 +3 0.7 +4 0.01 +5 0.95 +6 0.01 +7 0.05 +8 0.01 +9 0.7 +10 0.05 +11 0.95 +12 0.01 +13 0.05 +14 0.05 +15 0.7 +16 0.05 +17 0.95 +18 0.01 +19 0.05 +20 0.01 +21 0.7 +22 0.01 +23 0.99 +24 0.01 +25 0.05 +26 0.01 +27 0.7 +28 0.05 +29 0.99 +30 0.01 +31 0.05 +32 0.05 +33 0.7 +34 0.05 +35 0.99 +36 0.01 +37 0.01 +38 0.01 +39 0.1 +40 0.01 +41 0.3 +42 0.01 +43 0.01 +44 0.01 +45 0.1 +46 0.01 +47 0.3 +48 0.01 +49 0.01 +50 0.01 +51 0.1 +52 0.01 +53 0.3 +54 0.99 +55 0.95 +56 0.99 +57 0.3 +58 0.99 +59 0.05 +60 0.99 +61 0.95 +62 0.99 +63 0.3 +64 0.95 +65 0.05 +66 0.99 +67 0.95 +68 0.95 +69 0.3 +70 0.95 +71 0.05 +72 0.99 +73 0.95 +74 0.99 +75 0.3 +76 0.99 +77 0.01 +78 0.99 +79 0.95 +80 0.99 +81 0.3 +82 0.95 +83 0.01 +84 0.99 +85 0.95 +86 0.95 +87 0.3 +88 0.95 +89 0.01 +90 0.99 +91 0.99 +92 0.99 +93 0.9 +94 0.99 +95 0.7 +96 0.99 +97 0.99 +98 0.99 +99 0.9 +100 0.99 +101 0.7 +102 0.99 +103 0.99 +104 0.99 +105 0.9 +106 0.99 +107 0.7 + +2 +33 34 +2 3 +6 +0 0.05 +1 0.01 +2 0.9 +3 0.09 +4 0.05 +5 0.9 + +3 +6 34 35 +3 3 3 +27 +0 0.98 +1 0.95 +2 0.3 +3 0.95 +4 0.04 +5 0.01 +6 0.8 +7 0.01 +8 0.01 +9 0.01 +10 0.04 +11 0.69 +12 0.04 +13 0.95 +14 0.3 +15 0.19 +16 0.04 +17 0.01 +18 0.01 +19 0.01 +20 0.01 +21 0.01 +22 0.01 +23 0.69 +24 0.01 +25 0.95 +26 0.98 + +3 +14 35 36 +3 3 3 +27 +0 0.98 +1 0.98 +2 0.3 +3 0.98 +4 0.1 +5 0.05 +6 0.9 +7 0.05 +8 0.01 +9 0.01 +10 0.01 +11 0.6 +12 0.01 +13 0.85 +14 0.4 +15 0.09 +16 0.2 +17 0.09 +18 0.01 +19 0.01 +20 0.1 +21 0.01 +22 0.05 +23 0.55 +24 0.01 +25 0.75 +26 0.9 diff --git a/gtsam/discrete/tests/data/UAI/sampleMARKOV.uai b/gtsam/discrete/tests/data/UAI/sampleMARKOV.uai new file mode 100644 index 000000000..aacf458ed --- /dev/null +++ b/gtsam/discrete/tests/data/UAI/sampleMARKOV.uai @@ -0,0 +1,18 @@ +MARKOV +3 +2 2 3 +3 +1 0 +2 0 1 +2 1 2 + +2 + 0.436 0.564 + +4 + 0.128 0.872 + 0.920 0.080 + +6 + 0.210 0.333 0.457 + 0.811 0.000 0.189 \ No newline at end of file diff --git a/gtsam/discrete/tests/data/UAI/sampleMARKOV.uai.evid b/gtsam/discrete/tests/data/UAI/sampleMARKOV.uai.evid new file mode 100644 index 000000000..59f3e67a5 --- /dev/null +++ b/gtsam/discrete/tests/data/UAI/sampleMARKOV.uai.evid @@ -0,0 +1,3 @@ +2 + 1 0 + 2 1 \ No newline at end of file diff --git a/gtsam/discrete/tests/data/UAI/uai08_test1.uai b/gtsam/discrete/tests/data/UAI/uai08_test1.uai new file mode 100644 index 000000000..d205773fc --- /dev/null +++ b/gtsam/discrete/tests/data/UAI/uai08_test1.uai @@ -0,0 +1,996 @@ +BAYES +54 +2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 +54 +4 1 17 27 0 +4 49 32 38 1 +2 22 2 +2 19 3 +1 4 +2 49 5 +3 17 19 6 +4 36 26 28 7 +1 8 +5 22 51 4 28 9 +4 8 23 26 10 +2 5 11 +1 12 +3 27 53 13 +3 38 12 14 +3 32 29 15 +3 23 45 16 +4 49 38 23 17 +6 32 36 33 43 39 18 +4 23 26 44 19 +5 31 22 38 23 20 +6 22 29 34 37 40 21 +1 22 +5 49 31 2 29 23 +2 4 24 +3 49 5 25 +4 22 5 29 26 +2 2 27 +3 25 47 28 +1 29 +2 2 30 +1 31 +1 32 +3 42 41 33 +3 36 25 34 +5 32 38 5 41 35 +1 36 +10 36 27 19 53 18 46 50 35 11 37 +2 22 38 +3 38 26 39 +1 40 +1 41 +3 2 25 42 +4 1 23 4 43 +1 44 +4 34 7 0 45 +4 2 30 33 46 +5 49 23 26 27 47 +3 31 44 48 +1 49 +3 26 41 50 +1 51 +4 36 8 23 52 +3 32 26 53 + +16 + 0.285714 0.714286 + 0.461538 0.538462 + 0.307692 0.692308 + 0.300000 0.700000 + 0.333333 0.666667 + 0.714286 0.285714 + 0.588235 0.411765 + 0.588235 0.411765 + +16 + 0.625000 0.375000 + 0.750000 0.250000 + 0.625000 0.375000 + 0.166667 0.833333 + 0.555556 0.444444 + 0.545455 0.454545 + 0.500000 0.500000 + 0.428571 0.571429 + +4 + 0.666667 0.333333 + 0.571429 0.428571 + +4 + 0.461538 0.538462 + 0.272727 0.727273 + +2 + 0.230769 0.769231 + +4 + 0.625000 0.375000 + 0.583333 0.416667 + +8 + 0.800000 0.200000 + 0.500000 0.500000 + 0.333333 0.666667 + 0.250000 0.750000 + +16 + 0.411765 0.588235 + 0.500000 0.500000 + 0.769231 0.230769 + 0.692308 0.307692 + 0.625000 0.375000 + 0.600000 0.400000 + 0.833333 0.166667 + 0.571429 0.428571 + +2 + 0.700000 0.300000 + +32 + 0.833333 0.166667 + 0.250000 0.750000 + 0.500000 0.500000 + 0.500000 0.500000 + 0.526316 0.473684 + 0.500000 0.500000 + 0.428571 0.571429 + 0.500000 0.500000 + 0.444444 0.555556 + 0.666667 0.333333 + 0.636364 0.363636 + 0.384615 0.615385 + 0.222222 0.777778 + 0.411765 0.588235 + 0.526316 0.473684 + 0.583333 0.416667 + +16 + 0.357143 0.642857 + 0.363636 0.636364 + 0.166667 0.833333 + 0.777778 0.222222 + 0.473684 0.526316 + 0.538462 0.461538 + 0.500000 0.500000 + 0.470588 0.529412 + +4 + 0.470588 0.529412 + 0.400000 0.600000 + +2 + 0.500000 0.500000 + +8 + 0.473684 0.526316 + 0.555556 0.444444 + 0.411765 0.588235 + 0.714286 0.285714 + +8 + 0.800000 0.200000 + 0.444444 0.555556 + 0.600000 0.400000 + 0.666667 0.333333 + +8 + 0.500000 0.500000 + 0.333333 0.666667 + 0.625000 0.375000 + 0.692308 0.307692 + +8 + 0.285714 0.714286 + 0.714286 0.285714 + 0.500000 0.500000 + 0.500000 0.500000 + +16 + 0.833333 0.166667 + 0.454545 0.545455 + 0.625000 0.375000 + 0.250000 0.750000 + 0.727273 0.272727 + 0.588235 0.411765 + 0.400000 0.600000 + 0.500000 0.500000 + +64 + 0.166667 0.833333 + 0.666667 0.333333 + 0.692308 0.307692 + 0.538462 0.461538 + 0.500000 0.500000 + 0.250000 0.750000 + 0.437500 0.562500 + 0.473684 0.526316 + 0.769231 0.230769 + 0.400000 0.600000 + 0.555556 0.444444 + 0.272727 0.727273 + 0.473684 0.526316 + 0.818182 0.181818 + 0.750000 0.250000 + 0.416667 0.583333 + 0.588235 0.411765 + 0.769231 0.230769 + 0.500000 0.500000 + 0.473684 0.526316 + 0.833333 0.166667 + 0.444444 0.555556 + 0.600000 0.400000 + 0.529412 0.470588 + 0.727273 0.272727 + 0.615385 0.384615 + 0.444444 0.555556 + 0.400000 0.600000 + 0.642857 0.357143 + 0.200000 0.800000 + 0.333333 0.666667 + 0.437500 0.562500 + +16 + 0.666667 0.333333 + 0.250000 0.750000 + 0.625000 0.375000 + 0.357143 0.642857 + 0.500000 0.500000 + 0.300000 0.700000 + 0.526316 0.473684 + 0.600000 0.400000 + +32 + 0.444444 0.555556 + 0.583333 0.416667 + 0.500000 0.500000 + 0.571429 0.428571 + 0.400000 0.600000 + 0.500000 0.500000 + 0.333333 0.666667 + 0.666667 0.333333 + 0.473684 0.526316 + 0.500000 0.500000 + 0.545455 0.454545 + 0.454545 0.545455 + 0.500000 0.500000 + 0.466667 0.533333 + 0.777778 0.222222 + 0.222222 0.777778 + +64 + 0.333333 0.666667 + 0.818182 0.181818 + 0.526316 0.473684 + 0.375000 0.625000 + 0.625000 0.375000 + 0.444444 0.555556 + 0.473684 0.526316 + 0.533333 0.466667 + 0.500000 0.500000 + 0.500000 0.500000 + 0.363636 0.636364 + 0.300000 0.700000 + 0.250000 0.750000 + 0.562500 0.437500 + 0.571429 0.428571 + 0.642857 0.357143 + 0.666667 0.333333 + 0.363636 0.636364 + 0.384615 0.615385 + 0.600000 0.400000 + 0.818182 0.181818 + 0.428571 0.571429 + 0.625000 0.375000 + 0.562500 0.437500 + 0.583333 0.416667 + 0.529412 0.470588 + 0.529412 0.470588 + 0.545455 0.454545 + 0.333333 0.666667 + 0.230769 0.769231 + 0.500000 0.500000 + 0.230769 0.769231 + +2 + 0.588235 0.411765 + +32 + 0.333333 0.666667 + 0.333333 0.666667 + 0.428571 0.571429 + 0.600000 0.400000 + 0.750000 0.250000 + 0.666667 0.333333 + 0.411765 0.588235 + 0.583333 0.416667 + 0.800000 0.200000 + 0.545455 0.454545 + 0.333333 0.666667 + 0.375000 0.625000 + 0.571429 0.428571 + 0.285714 0.714286 + 0.555556 0.444444 + 0.461538 0.538462 + +4 + 0.500000 0.500000 + 0.625000 0.375000 + +8 + 0.727273 0.272727 + 0.461538 0.538462 + 0.777778 0.222222 + 0.400000 0.600000 + +16 + 0.555556 0.444444 + 0.600000 0.400000 + 0.571429 0.428571 + 0.833333 0.166667 + 0.777778 0.222222 + 0.357143 0.642857 + 0.285714 0.714286 + 0.642857 0.357143 + +4 + 0.461538 0.538462 + 0.250000 0.750000 + +8 + 0.692308 0.307692 + 0.529412 0.470588 + 0.437500 0.562500 + 0.666667 0.333333 + +2 + 0.727273 0.272727 + +4 + 0.500000 0.500000 + 0.571429 0.428571 + +2 + 0.375000 0.625000 + +2 + 0.428571 0.571429 + +8 + 0.666667 0.333333 + 0.444444 0.555556 + 0.500000 0.500000 + 0.416667 0.583333 + +8 + 0.357143 0.642857 + 0.461538 0.538462 + 0.272727 0.727273 + 0.411765 0.588235 + +32 + 0.470588 0.529412 + 0.466667 0.533333 + 0.700000 0.300000 + 0.555556 0.444444 + 0.444444 0.555556 + 0.666667 0.333333 + 0.466667 0.533333 + 0.466667 0.533333 + 0.200000 0.800000 + 0.588235 0.411765 + 0.166667 0.833333 + 0.333333 0.666667 + 0.526316 0.473684 + 0.562500 0.437500 + 0.333333 0.666667 + 0.700000 0.300000 + +2 + 0.166667 0.833333 + +1024 + 0.250000 0.750000 + 0.307692 0.692308 + 0.500000 0.500000 + 0.666667 0.333333 + 0.818182 0.181818 + 0.500000 0.500000 + 0.625000 0.375000 + 0.615385 0.384615 + 0.500000 0.500000 + 0.285714 0.714286 + 0.230769 0.769231 + 0.692308 0.307692 + 0.333333 0.666667 + 0.625000 0.375000 + 0.437500 0.562500 + 0.625000 0.375000 + 0.272727 0.727273 + 0.636364 0.363636 + 0.181818 0.818182 + 0.500000 0.500000 + 0.500000 0.500000 + 0.500000 0.500000 + 0.818182 0.181818 + 0.437500 0.562500 + 0.500000 0.500000 + 0.750000 0.250000 + 0.375000 0.625000 + 0.625000 0.375000 + 0.700000 0.300000 + 0.466667 0.533333 + 0.411765 0.588235 + 0.666667 0.333333 + 0.750000 0.250000 + 0.285714 0.714286 + 0.250000 0.750000 + 0.571429 0.428571 + 0.555556 0.444444 + 0.428571 0.571429 + 0.500000 0.500000 + 0.666667 0.333333 + 0.571429 0.428571 + 0.222222 0.777778 + 0.615385 0.384615 + 0.461538 0.538462 + 0.250000 0.750000 + 0.666667 0.333333 + 0.200000 0.800000 + 0.384615 0.615385 + 0.300000 0.700000 + 0.466667 0.533333 + 0.625000 0.375000 + 0.562500 0.437500 + 0.583333 0.416667 + 0.500000 0.500000 + 0.727273 0.272727 + 0.571429 0.428571 + 0.250000 0.750000 + 0.333333 0.666667 + 0.500000 0.500000 + 0.545455 0.454545 + 0.333333 0.666667 + 0.666667 0.333333 + 0.461538 0.538462 + 0.181818 0.818182 + 0.714286 0.285714 + 0.666667 0.333333 + 0.470588 0.529412 + 0.500000 0.500000 + 0.470588 0.529412 + 0.500000 0.500000 + 0.416667 0.583333 + 0.625000 0.375000 + 0.625000 0.375000 + 0.692308 0.307692 + 0.500000 0.500000 + 0.666667 0.333333 + 0.714286 0.285714 + 0.600000 0.400000 + 0.461538 0.538462 + 0.500000 0.500000 + 0.500000 0.500000 + 0.181818 0.818182 + 0.750000 0.250000 + 0.357143 0.642857 + 0.400000 0.600000 + 0.625000 0.375000 + 0.250000 0.750000 + 0.461538 0.538462 + 0.250000 0.750000 + 0.333333 0.666667 + 0.272727 0.727273 + 0.428571 0.571429 + 0.166667 0.833333 + 0.600000 0.400000 + 0.750000 0.250000 + 0.583333 0.416667 + 0.769231 0.230769 + 0.769231 0.230769 + 0.545455 0.454545 + 0.470588 0.529412 + 0.454545 0.545455 + 0.555556 0.444444 + 0.714286 0.285714 + 0.384615 0.615385 + 0.428571 0.571429 + 0.636364 0.363636 + 0.583333 0.416667 + 0.384615 0.615385 + 0.357143 0.642857 + 0.571429 0.428571 + 0.642857 0.357143 + 0.636364 0.363636 + 0.714286 0.285714 + 0.230769 0.769231 + 0.333333 0.666667 + 0.428571 0.571429 + 0.533333 0.466667 + 0.625000 0.375000 + 0.444444 0.555556 + 0.357143 0.642857 + 0.555556 0.444444 + 0.500000 0.500000 + 0.333333 0.666667 + 0.384615 0.615385 + 0.600000 0.400000 + 0.333333 0.666667 + 0.700000 0.300000 + 0.500000 0.500000 + 0.545455 0.454545 + 0.800000 0.200000 + 0.625000 0.375000 + 0.250000 0.750000 + 0.500000 0.500000 + 0.500000 0.500000 + 0.500000 0.500000 + 0.666667 0.333333 + 0.666667 0.333333 + 0.692308 0.307692 + 0.400000 0.600000 + 0.692308 0.307692 + 0.666667 0.333333 + 0.555556 0.444444 + 0.666667 0.333333 + 0.222222 0.777778 + 0.562500 0.437500 + 0.500000 0.500000 + 0.666667 0.333333 + 0.230769 0.769231 + 0.555556 0.444444 + 0.307692 0.692308 + 0.800000 0.200000 + 0.400000 0.600000 + 0.666667 0.333333 + 0.285714 0.714286 + 0.500000 0.500000 + 0.444444 0.555556 + 0.555556 0.444444 + 0.272727 0.727273 + 0.600000 0.400000 + 0.428571 0.571429 + 0.400000 0.600000 + 0.526316 0.473684 + 0.333333 0.666667 + 0.750000 0.250000 + 0.636364 0.363636 + 0.333333 0.666667 + 0.750000 0.250000 + 0.500000 0.500000 + 0.818182 0.181818 + 0.375000 0.625000 + 0.333333 0.666667 + 0.625000 0.375000 + 0.583333 0.416667 + 0.230769 0.769231 + 0.769231 0.230769 + 0.800000 0.200000 + 0.636364 0.363636 + 0.384615 0.615385 + 0.562500 0.437500 + 0.727273 0.272727 + 0.250000 0.750000 + 0.600000 0.400000 + 0.538462 0.461538 + 0.750000 0.250000 + 0.428571 0.571429 + 0.300000 0.700000 + 0.555556 0.444444 + 0.692308 0.307692 + 0.230769 0.769231 + 0.333333 0.666667 + 0.454545 0.545455 + 0.666667 0.333333 + 0.583333 0.416667 + 0.454545 0.545455 + 0.562500 0.437500 + 0.666667 0.333333 + 0.500000 0.500000 + 0.250000 0.750000 + 0.625000 0.375000 + 0.588235 0.411765 + 0.818182 0.181818 + 0.500000 0.500000 + 0.250000 0.750000 + 0.636364 0.363636 + 0.181818 0.818182 + 0.333333 0.666667 + 0.411765 0.588235 + 0.500000 0.500000 + 0.428571 0.571429 + 0.230769 0.769231 + 0.333333 0.666667 + 0.562500 0.437500 + 0.666667 0.333333 + 0.600000 0.400000 + 0.333333 0.666667 + 0.500000 0.500000 + 0.333333 0.666667 + 0.714286 0.285714 + 0.333333 0.666667 + 0.714286 0.285714 + 0.454545 0.545455 + 0.181818 0.818182 + 0.400000 0.600000 + 0.750000 0.250000 + 0.636364 0.363636 + 0.300000 0.700000 + 0.222222 0.777778 + 0.200000 0.800000 + 0.777778 0.222222 + 0.500000 0.500000 + 0.384615 0.615385 + 0.411765 0.588235 + 0.818182 0.181818 + 0.357143 0.642857 + 0.588235 0.411765 + 0.285714 0.714286 + 0.562500 0.437500 + 0.529412 0.470588 + 0.466667 0.533333 + 0.454545 0.545455 + 0.800000 0.200000 + 0.571429 0.428571 + 0.250000 0.750000 + 0.500000 0.500000 + 0.400000 0.600000 + 0.444444 0.555556 + 0.600000 0.400000 + 0.500000 0.500000 + 0.200000 0.800000 + 0.642857 0.357143 + 0.666667 0.333333 + 0.600000 0.400000 + 0.250000 0.750000 + 0.500000 0.500000 + 0.600000 0.400000 + 0.300000 0.700000 + 0.363636 0.636364 + 0.727273 0.272727 + 0.250000 0.750000 + 0.500000 0.500000 + 0.666667 0.333333 + 0.615385 0.384615 + 0.642857 0.357143 + 0.473684 0.526316 + 0.437500 0.562500 + 0.545455 0.454545 + 0.411765 0.588235 + 0.466667 0.533333 + 0.666667 0.333333 + 0.333333 0.666667 + 0.562500 0.437500 + 0.700000 0.300000 + 0.500000 0.500000 + 0.473684 0.526316 + 0.357143 0.642857 + 0.571429 0.428571 + 0.416667 0.583333 + 0.555556 0.444444 + 0.833333 0.166667 + 0.727273 0.272727 + 0.181818 0.818182 + 0.750000 0.250000 + 0.200000 0.800000 + 0.470588 0.529412 + 0.583333 0.416667 + 0.625000 0.375000 + 0.800000 0.200000 + 0.400000 0.600000 + 0.437500 0.562500 + 0.400000 0.600000 + 0.444444 0.555556 + 0.454545 0.545455 + 0.181818 0.818182 + 0.615385 0.384615 + 0.533333 0.466667 + 0.428571 0.571429 + 0.625000 0.375000 + 0.777778 0.222222 + 0.333333 0.666667 + 0.588235 0.411765 + 0.285714 0.714286 + 0.500000 0.500000 + 0.636364 0.363636 + 0.428571 0.571429 + 0.727273 0.272727 + 0.500000 0.500000 + 0.285714 0.714286 + 0.818182 0.181818 + 0.250000 0.750000 + 0.555556 0.444444 + 0.181818 0.818182 + 0.727273 0.272727 + 0.529412 0.470588 + 0.625000 0.375000 + 0.555556 0.444444 + 0.777778 0.222222 + 0.714286 0.285714 + 0.727273 0.272727 + 0.300000 0.700000 + 0.411765 0.588235 + 0.222222 0.777778 + 0.800000 0.200000 + 0.642857 0.357143 + 0.769231 0.230769 + 0.562500 0.437500 + 0.600000 0.400000 + 0.400000 0.600000 + 0.600000 0.400000 + 0.461538 0.538462 + 0.500000 0.500000 + 0.461538 0.538462 + 0.750000 0.250000 + 0.307692 0.692308 + 0.444444 0.555556 + 0.400000 0.600000 + 0.666667 0.333333 + 0.727273 0.272727 + 0.250000 0.750000 + 0.666667 0.333333 + 0.500000 0.500000 + 0.473684 0.526316 + 0.727273 0.272727 + 0.444444 0.555556 + 0.428571 0.571429 + 0.285714 0.714286 + 0.500000 0.500000 + 0.470588 0.529412 + 0.500000 0.500000 + 0.363636 0.636364 + 0.428571 0.571429 + 0.615385 0.384615 + 0.500000 0.500000 + 0.555556 0.444444 + 0.500000 0.500000 + 0.250000 0.750000 + 0.642857 0.357143 + 0.400000 0.600000 + 0.411765 0.588235 + 0.250000 0.750000 + 0.700000 0.300000 + 0.500000 0.500000 + 0.416667 0.583333 + 0.692308 0.307692 + 0.500000 0.500000 + 0.357143 0.642857 + 0.750000 0.250000 + 0.181818 0.818182 + 0.166667 0.833333 + 0.250000 0.750000 + 0.714286 0.285714 + 0.769231 0.230769 + 0.666667 0.333333 + 0.714286 0.285714 + 0.333333 0.666667 + 0.285714 0.714286 + 0.750000 0.250000 + 0.166667 0.833333 + 0.500000 0.500000 + 0.466667 0.533333 + 0.714286 0.285714 + 0.545455 0.454545 + 0.166667 0.833333 + 0.428571 0.571429 + 0.750000 0.250000 + 0.307692 0.692308 + 0.428571 0.571429 + 0.818182 0.181818 + 0.375000 0.625000 + 0.625000 0.375000 + 0.250000 0.750000 + 0.700000 0.300000 + 0.300000 0.700000 + 0.625000 0.375000 + 0.642857 0.357143 + 0.428571 0.571429 + 0.500000 0.500000 + 0.777778 0.222222 + 0.444444 0.555556 + 0.333333 0.666667 + 0.428571 0.571429 + 0.307692 0.692308 + 0.333333 0.666667 + 0.166667 0.833333 + 0.571429 0.428571 + 0.333333 0.666667 + 0.500000 0.500000 + 0.538462 0.461538 + 0.250000 0.750000 + 0.416667 0.583333 + 0.500000 0.500000 + 0.500000 0.500000 + 0.625000 0.375000 + 0.473684 0.526316 + 0.375000 0.625000 + 0.470588 0.529412 + 0.454545 0.545455 + 0.500000 0.500000 + 0.333333 0.666667 + 0.500000 0.500000 + 0.363636 0.636364 + 0.600000 0.400000 + 0.166667 0.833333 + 0.769231 0.230769 + 0.588235 0.411765 + 0.642857 0.357143 + 0.636364 0.363636 + 0.833333 0.166667 + 0.166667 0.833333 + 0.470588 0.529412 + 0.700000 0.300000 + 0.700000 0.300000 + 0.666667 0.333333 + 0.714286 0.285714 + 0.384615 0.615385 + 0.500000 0.500000 + 0.777778 0.222222 + 0.454545 0.545455 + 0.500000 0.500000 + 0.181818 0.818182 + 0.526316 0.473684 + 0.700000 0.300000 + 0.777778 0.222222 + 0.529412 0.470588 + 0.714286 0.285714 + 0.428571 0.571429 + 0.500000 0.500000 + 0.588235 0.411765 + 0.571429 0.428571 + 0.750000 0.250000 + 0.500000 0.500000 + 0.666667 0.333333 + 0.363636 0.636364 + 0.571429 0.428571 + 0.454545 0.545455 + 0.444444 0.555556 + 0.250000 0.750000 + 0.363636 0.636364 + 0.272727 0.727273 + 0.333333 0.666667 + 0.615385 0.384615 + 0.615385 0.384615 + 0.333333 0.666667 + 0.583333 0.416667 + 0.166667 0.833333 + 0.428571 0.571429 + 0.400000 0.600000 + 0.454545 0.545455 + 0.500000 0.500000 + 0.714286 0.285714 + 0.500000 0.500000 + 0.800000 0.200000 + 0.500000 0.500000 + 0.500000 0.500000 + 0.300000 0.700000 + 0.454545 0.545455 + 0.416667 0.583333 + 0.615385 0.384615 + 0.600000 0.400000 + 0.357143 0.642857 + 0.454545 0.545455 + 0.230769 0.769231 + 0.428571 0.571429 + 0.500000 0.500000 + 0.562500 0.437500 + 0.555556 0.444444 + 0.571429 0.428571 + 0.750000 0.250000 + 0.166667 0.833333 + 0.285714 0.714286 + 0.400000 0.600000 + 0.461538 0.538462 + 0.333333 0.666667 + 0.555556 0.444444 + 0.416667 0.583333 + 0.466667 0.533333 + 0.333333 0.666667 + 0.444444 0.555556 + 0.375000 0.625000 + 0.642857 0.357143 + 0.727273 0.272727 + 0.470588 0.529412 + 0.363636 0.636364 + 0.714286 0.285714 + 0.666667 0.333333 + 0.411765 0.588235 + 0.250000 0.750000 + 0.437500 0.562500 + 0.500000 0.500000 + 0.400000 0.600000 + 0.400000 0.600000 + 0.428571 0.571429 + 0.222222 0.777778 + +4 + 0.454545 0.545455 + 0.363636 0.636364 + +8 + 0.285714 0.714286 + 0.625000 0.375000 + 0.400000 0.600000 + 0.727273 0.272727 + +2 + 0.526316 0.473684 + +2 + 0.454545 0.545455 + +8 + 0.529412 0.470588 + 0.500000 0.500000 + 0.538462 0.461538 + 0.500000 0.500000 + +16 + 0.526316 0.473684 + 0.571429 0.428571 + 0.562500 0.437500 + 0.230769 0.769231 + 0.333333 0.666667 + 0.750000 0.250000 + 0.333333 0.666667 + 0.600000 0.400000 + +2 + 0.714286 0.285714 + +16 + 0.230769 0.769231 + 0.454545 0.545455 + 0.571429 0.428571 + 0.777778 0.222222 + 0.466667 0.533333 + 0.250000 0.750000 + 0.384615 0.615385 + 0.571429 0.428571 + +16 + 0.666667 0.333333 + 0.555556 0.444444 + 0.363636 0.636364 + 0.833333 0.166667 + 0.400000 0.600000 + 0.818182 0.181818 + 0.692308 0.307692 + 0.692308 0.307692 + +32 + 0.533333 0.466667 + 0.400000 0.600000 + 0.666667 0.333333 + 0.333333 0.666667 + 0.588235 0.411765 + 0.363636 0.636364 + 0.470588 0.529412 + 0.500000 0.500000 + 0.636364 0.363636 + 0.400000 0.600000 + 0.636364 0.363636 + 0.428571 0.571429 + 0.500000 0.500000 + 0.714286 0.285714 + 0.272727 0.727273 + 0.357143 0.642857 + +8 + 0.250000 0.750000 + 0.285714 0.714286 + 0.583333 0.416667 + 0.571429 0.428571 + +2 + 0.375000 0.625000 + +8 + 0.666667 0.333333 + 0.300000 0.700000 + 0.529412 0.470588 + 0.473684 0.526316 + +2 + 0.500000 0.500000 + +16 + 0.666667 0.333333 + 0.200000 0.800000 + 0.500000 0.500000 + 0.500000 0.500000 + 0.666667 0.333333 + 0.714286 0.285714 + 0.470588 0.529412 + 0.533333 0.466667 + +8 + 0.307692 0.692308 + 0.470588 0.529412 + 0.333333 0.666667 + 0.333333 0.666667 + diff --git a/gtsam/discrete/tests/data/UAI/uai08_test1.uai.evid b/gtsam/discrete/tests/data/UAI/uai08_test1.uai.evid new file mode 100644 index 000000000..5ca206d95 --- /dev/null +++ b/gtsam/discrete/tests/data/UAI/uai08_test1.uai.evid @@ -0,0 +1,11 @@ +10 + 0 1 + 2 0 + 9 0 + 16 1 + 20 1 + 21 1 + 22 0 + 26 1 + 39 0 + 41 1 diff --git a/gtsam/discrete/tests/data/UAI/uai08_test1.uai.output b/gtsam/discrete/tests/data/UAI/uai08_test1.uai.output new file mode 100644 index 000000000..c376783fe --- /dev/null +++ b/gtsam/discrete/tests/data/UAI/uai08_test1.uai.output @@ -0,0 +1,3 @@ +z -2.7351873 +m 54 2 0.0 1.0 2 0.5995855 0.40041456 2 1.0 0.0 2 0.3761365 0.62386346 2 0.25656807 0.7434319 2 0.6449692 0.35503078 2 0.4957979 0.5042021 2 0.69854456 0.30145544 2 0.7 0.3 2 1.0 0.0 2 0.5303537 0.46964625 2 0.44570237 0.5542976 2 0.5 0.5 2 0.55686617 0.4431338 2 0.6284742 0.3715258 2 0.5607879 0.43921205 2 0.0 1.0 2 0.54289234 0.4571077 2 0.5770133 0.42298666 2 0.547688 0.452312 2 0.0 1.0 2 0.0 1.0 2 1.0 0.0 2 0.5760513 0.4239487 2 0.592929 0.40707102 2 0.63438964 0.36561036 2 0.0 1.0 2 0.52899235 0.47100765 2 0.5998554 0.40014458 2 0.7750039 0.22499608 2 0.50000435 0.49999565 2 0.36475798 0.63524204 2 0.44666538 0.55333465 2 0.43111995 0.56888 2 0.37207335 0.62792665 2 0.5581817 0.4418183 2 0.16809757 0.83190244 2 0.4813641 0.5186359 2 0.43732184 0.56267816 2 1.0 0.0 2 0.54721755 0.45278242 2 0.0 1.0 2 0.51865995 0.48134002 2 0.51229435 0.48770565 2 0.7142385 0.2857615 2 0.53666514 0.46333483 2 0.6171147 0.38288528 2 0.46532288 0.5346771 2 0.46330887 0.5366911 2 0.36718464 0.63281536 2 0.4735739 0.5264261 2 0.5244508 0.4755492 2 0.604569 0.395431 2 0.3945428 0.6054572 +s -11.533098 54 1 0 0 1 1 0 1 0 0 0 0 1 1 0 0 0 1 1 1 1 1 1 0 1 0 0 1 1 0 0 1 1 1 1 1 0 1 1 1 0 0 1 0 1 0 1 0 1 0 1 1 0 0 1 diff --git a/gtsam/discrete/tests/data/UAI/uai08_test2.uai b/gtsam/discrete/tests/data/UAI/uai08_test2.uai new file mode 100644 index 000000000..a75b376ed --- /dev/null +++ b/gtsam/discrete/tests/data/UAI/uai08_test2.uai @@ -0,0 +1,269 @@ +BAYES +21 +4 4 4 4 4 4 4 4 4 2 2 2 2 2 2 2 2 2 2 2 2 +21 +1 0 +1 1 +1 2 +1 3 +1 4 +1 5 +1 6 +1 7 +1 8 +3 1 0 9 +3 1 3 10 +3 5 1 11 +3 2 6 12 +3 6 4 13 +3 3 6 14 +3 5 7 15 +3 7 2 16 +3 7 3 17 +3 0 8 18 +3 3 8 19 +3 8 4 20 + +4 + 0.25 0.25 0.25 0.25 + +4 + 0.25 0.25 0.25 0.25 + +4 + 0.25 0.25 0.25 0.25 + +4 + 0.25 0.25 0.25 0.25 + +4 + 0.25 0.25 0.25 0.25 + +4 + 0.25 0.25 0.25 0.25 + +4 + 0.25 0.25 0.25 0.25 + +4 + 0.1 0.2 0.3 0.4 + +4 + 0.25 0.25 0.25 0.25 + +32 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + +32 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + +32 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + +32 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + +32 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + +32 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + +32 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + +32 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + +32 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + +32 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + +32 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + +32 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + diff --git a/gtsam/discrete/tests/data/UAI/uai08_test2.uai.evid b/gtsam/discrete/tests/data/UAI/uai08_test2.uai.evid new file mode 100644 index 000000000..1214f3c1b --- /dev/null +++ b/gtsam/discrete/tests/data/UAI/uai08_test2.uai.evid @@ -0,0 +1,13 @@ +12 + 17 0 + 10 0 + 19 0 + 18 0 + 11 0 + 13 0 + 15 0 + 20 0 + 9 0 + 12 0 + 16 0 + 14 0 diff --git a/gtsam/discrete/tests/data/UAI/uai08_test2.uai.output b/gtsam/discrete/tests/data/UAI/uai08_test2.uai.output new file mode 100644 index 000000000..a124d2b4c --- /dev/null +++ b/gtsam/discrete/tests/data/UAI/uai08_test2.uai.output @@ -0,0 +1,3 @@ +z -5.264346 +m 21 4 0.116109975 0.20537 0.29463002 0.38389003 4 0.10865768 0.2028859 0.2971141 0.3913423 4 0.11159538 0.20386513 0.29613486 0.3884046 4 0.105094366 0.20169812 0.29830188 0.39490563 4 0.116109975 0.20537 0.29463002 0.38389003 4 0.11159538 0.20386513 0.29613486 0.3884046 4 0.10865768 0.2028859 0.2971141 0.3913423 4 0.1 0.2 0.3 0.4 4 0.10956474 0.20318824 0.29681176 0.39043528 2 1.0 0.0 2 1.0 0.0 2 1.0 0.0 2 1.0 0.0 2 1.0 0.0 2 1.0 0.0 2 1.0 0.0 2 1.0 0.0 2 1.0 0.0 2 1.0 0.0 2 1.0 0.0 2 1.0 0.0 +s -5.7635098 21 3 3 3 3 3 3 3 3 3 0 0 0 0 0 0 0 0 0 0 0 0 diff --git a/gtsam/discrete/tests/data/UAI/uai08_test3.uai b/gtsam/discrete/tests/data/UAI/uai08_test3.uai new file mode 100644 index 000000000..2abb99bc2 --- /dev/null +++ b/gtsam/discrete/tests/data/UAI/uai08_test3.uai @@ -0,0 +1,94 @@ +MARKOV +9 +4 4 4 4 4 4 4 4 4 +13 +1 7 +2 1 0 +2 1 3 +2 5 1 +2 2 6 +2 6 4 +2 3 6 +2 5 7 +2 7 2 +2 7 3 +2 0 8 +2 3 8 +2 8 4 + +4 + 0.1 0.2 0.3 0.4 + +16 + 0.9 0.1 0.1 0.1 + 0.1 0.9 0.1 0.1 + 0.1 0.1 0.9 0.1 + 0.1 0.1 0.1 0.9 + +16 + 0.9 0.1 0.1 0.1 + 0.1 0.9 0.1 0.1 + 0.1 0.1 0.9 0.1 + 0.1 0.1 0.1 0.9 + +16 + 0.9 0.1 0.1 0.1 + 0.1 0.9 0.1 0.1 + 0.1 0.1 0.9 0.1 + 0.1 0.1 0.1 0.9 + +16 + 0.9 0.1 0.1 0.1 + 0.1 0.9 0.1 0.1 + 0.1 0.1 0.9 0.1 + 0.1 0.1 0.1 0.9 + +16 + 0.9 0.1 0.1 0.1 + 0.1 0.9 0.1 0.1 + 0.1 0.1 0.9 0.1 + 0.1 0.1 0.1 0.9 + +16 + 0.9 0.1 0.1 0.1 + 0.1 0.9 0.1 0.1 + 0.1 0.1 0.9 0.1 + 0.1 0.1 0.1 0.9 + +16 + 0.9 0.1 0.1 0.1 + 0.1 0.9 0.1 0.1 + 0.1 0.1 0.9 0.1 + 0.1 0.1 0.1 0.9 + +16 + 0.9 0.1 0.1 0.1 + 0.1 0.9 0.1 0.1 + 0.1 0.1 0.9 0.1 + 0.1 0.1 0.1 0.9 + +16 + 0.9 0.1 0.1 0.1 + 0.1 0.9 0.1 0.1 + 0.1 0.1 0.9 0.1 + 0.1 0.1 0.1 0.9 + +16 + 0.9 0.1 0.1 0.1 + 0.1 0.9 0.1 0.1 + 0.1 0.1 0.9 0.1 + 0.1 0.1 0.1 0.9 + +16 + 0.9 0.1 0.1 0.1 + 0.1 0.9 0.1 0.1 + 0.1 0.1 0.9 0.1 + 0.1 0.1 0.1 0.9 + +16 + 0.9 0.1 0.1 0.1 + 0.1 0.9 0.1 0.1 + 0.1 0.1 0.9 0.1 + 0.1 0.1 0.1 0.9 + + diff --git a/gtsam/discrete/tests/data/UAI/uai08_test3.uai.evid b/gtsam/discrete/tests/data/UAI/uai08_test3.uai.evid new file mode 100644 index 000000000..18748286e --- /dev/null +++ b/gtsam/discrete/tests/data/UAI/uai08_test3.uai.evid @@ -0,0 +1 @@ +0 diff --git a/gtsam/discrete/tests/data/UAI/uai08_test3.uai.output b/gtsam/discrete/tests/data/UAI/uai08_test3.uai.output new file mode 100644 index 000000000..1ddb8297a --- /dev/null +++ b/gtsam/discrete/tests/data/UAI/uai08_test3.uai.output @@ -0,0 +1,3 @@ +z -0.44786617 +m 9 4 0.116109975 0.20537 0.29463002 0.38389003 4 0.10865768 0.2028859 0.2971141 0.3913423 4 0.11159538 0.20386513 0.29613486 0.3884046 4 0.105094366 0.20169812 0.29830188 0.39490563 4 0.116109975 0.20537 0.29463002 0.38389003 4 0.11159538 0.20386513 0.29613486 0.3884046 4 0.10865768 0.2028859 0.2971141 0.3913423 4 0.1 0.2 0.3 0.4 4 0.10956474 0.20318824 0.29681176 0.39043528 +s -0.9470299 9 3 3 3 3 3 3 3 3 3 diff --git a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp new file mode 100644 index 000000000..0ca98f620 --- /dev/null +++ b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp @@ -0,0 +1,524 @@ +/* + * @file testDecisionTree.cpp + * @brief Develop DecisionTree + * @author Frank Dellaert + * @date Mar 6, 2011 + */ + +#include +#include // make sure we have traits +// headers first to make sure no missing headers +//#define DT_NO_PRUNING +#include +#include // for convert only +#define DISABLE_TIMING +#include // for checking whether we are using boost 1.40 +#if BOOST_VERSION >= 104200 +#define BOOST_HAVE_PARSER +#endif + +#include +#include +#include +#include +#include +using namespace boost::assign; + +#include +#include + +using namespace std; +using namespace gtsam; + +/* ******************************************************************************** */ +typedef AlgebraicDecisionTree ADT; + +template class DecisionTree; +template class AlgebraicDecisionTree; + +#define DISABLE_DOT + +template +void dot(const T&f, const string& filename) { +#ifndef DISABLE_DOT + f.dot(filename); +#endif +} + +/** I can't get this to work ! + class Mul: boost::function { + inline double operator()(const double& a, const double& b) { + return a * b; + } + }; + + // If second argument of binary op is Leaf + template + typename DecisionTree::Node::Ptr DecisionTree::Choice::apply_fC_op_gL( + Cache& cache, const Leaf& gL, Mul op) const { + Ptr h(new Choice(label(), cardinality())); + BOOST_FOREACH(const NodePtr& branch, branches_) + h->push_back(branch->apply_f_op_g(cache, gL, op)); + return Unique(cache, h); + } + */ + +/* ******************************************************************************** */ +// instrumented operators +/* ******************************************************************************** */ +size_t muls = 0, adds = 0; +boost::timer timer; +void resetCounts() { + muls = 0; + adds = 0; + timer.restart(); +} +void printCounts(const string& s) { +#ifndef DISABLE_TIMING + cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds + % (1000 * timer.elapsed()) << endl; +#endif + resetCounts(); +} +double mul(const double& a, const double& b) { + muls++; + return a * b; +} +double add_(const double& a, const double& b) { + adds++; + return a + b; +} + +/* ******************************************************************************** */ +// test ADT +TEST(ADT, example3) +{ + // Create labels + DiscreteKey A(0,2), B(1,2), C(2,2), D(3,2), E(4,2); + + // Literals + ADT a(A, 0.5, 0.5); + ADT notb(B, 1, 0); + ADT c(C, 0.1, 0.9); + ADT d(D, 0.1, 0.9); + ADT note(E, 0.9, 0.1); + + ADT cnotb = c * notb; + dot(cnotb, "ADT-cnotb"); + +// a.print("a: "); +// cnotb.print("cnotb: "); + ADT acnotb = a * cnotb; +// acnotb.print("acnotb: "); +// acnotb.printCache("acnotb Cache:"); + + dot(acnotb, "ADT-acnotb"); + + + ADT big = apply(apply(d, note, &mul), acnotb, &add_); + dot(big, "ADT-big"); +} + +/* ******************************************************************************** */ +// Asia Bayes Network +/* ******************************************************************************** */ + +/** Convert Signature into CPT */ +ADT create(const Signature& signature) { + ADT p(signature.discreteKeysParentsFirst(), signature.cpt()); + static size_t count = 0; + const DiscreteKey& key = signature.key(); + string dotfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str(); + dot(p, dotfile); + return p; +} + +/* ************************************************************************* */ +// test Asia Joint +TEST(ADT, joint) +{ + DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), D(7, 2); + +#ifdef BOOST_HAVE_PARSER + resetCounts(); + ADT pA = create(A % "99/1"); + ADT pS = create(S % "50/50"); + ADT pT = create(T | A = "99/1 95/5"); + ADT pL = create(L | S = "99/1 90/10"); + ADT pB = create(B | S = "70/30 40/60"); + ADT pE = create((E | T, L) = "F T T T"); + ADT pX = create(X | E = "95/5 2/98"); + ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9"); + printCounts("Asia CPTs"); + + // Create joint + resetCounts(); + ADT joint = pA; + dot(joint, "Asia-A"); + joint = apply(joint, pS, &mul); + dot(joint, "Asia-AS"); + joint = apply(joint, pT, &mul); + dot(joint, "Asia-AST"); + joint = apply(joint, pL, &mul); + dot(joint, "Asia-ASTL"); + joint = apply(joint, pB, &mul); + dot(joint, "Asia-ASTLB"); + joint = apply(joint, pE, &mul); + dot(joint, "Asia-ASTLBE"); + joint = apply(joint, pX, &mul); + dot(joint, "Asia-ASTLBEX"); + joint = apply(joint, pD, &mul); + dot(joint, "Asia-ASTLBEXD"); + EXPECT_LONGS_EQUAL(346, muls); + printCounts("Asia joint"); + + ADT pASTL = pA; + pASTL = apply(pASTL, pS, &mul); + pASTL = apply(pASTL, pT, &mul); + pASTL = apply(pASTL, pL, &mul); + + // test combine + ADT fAa = pASTL.combine(L, &add_).combine(T, &add_).combine(S, &add_); + EXPECT(assert_equal(pA, fAa)); + ADT fAb = pASTL.combine(S, &add_).combine(T, &add_).combine(L, &add_); + EXPECT(assert_equal(pA, fAb)); +#endif +} + +/* ************************************************************************* */ +// test Inference with joint +TEST(ADT, inference) +{ + DiscreteKey A(0,2), D(1,2),// + B(2,2), L(3,2), E(4,2), S(5,2), T(6,2), X(7,2); + +#ifdef BOOST_HAVE_PARSER + resetCounts(); + ADT pA = create(A % "99/1"); + ADT pS = create(S % "50/50"); + ADT pT = create(T | A = "99/1 95/5"); + ADT pL = create(L | S = "99/1 90/10"); + ADT pB = create(B | S = "70/30 40/60"); + ADT pE = create((E | T, L) = "F T T T"); + ADT pX = create(X | E = "95/5 2/98"); + ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9"); + // printCounts("Inference CPTs"); + + // Create joint + resetCounts(); + ADT joint = pA; + dot(joint, "Joint-Product-A"); + joint = apply(joint, pS, &mul); + dot(joint, "Joint-Product-AS"); + joint = apply(joint, pT, &mul); + dot(joint, "Joint-Product-AST"); + joint = apply(joint, pL, &mul); + dot(joint, "Joint-Product-ASTL"); + joint = apply(joint, pB, &mul); + dot(joint, "Joint-Product-ASTLB"); + joint = apply(joint, pE, &mul); + dot(joint, "Joint-Product-ASTLBE"); + joint = apply(joint, pX, &mul); + dot(joint, "Joint-Product-ASTLBEX"); + joint = apply(joint, pD, &mul); + dot(joint, "Joint-Product-ASTLBEXD"); + EXPECT_LONGS_EQUAL(370, muls); // different ordering + printCounts("Asia product"); + + ADT marginal = joint; + marginal = marginal.combine(X, &add_); + dot(marginal, "Joint-Sum-ADBLEST"); + marginal = marginal.combine(T, &add_); + dot(marginal, "Joint-Sum-ADBLES"); + marginal = marginal.combine(S, &add_); + dot(marginal, "Joint-Sum-ADBLE"); + marginal = marginal.combine(E, &add_); + dot(marginal, "Joint-Sum-ADBL"); + EXPECT_LONGS_EQUAL(161, adds); + printCounts("Asia sum"); +#endif +} + +/* ************************************************************************* */ +TEST(ADT, factor_graph) +{ + DiscreteKey B(0,2), L(1,2), E(2,2), S(3,2), T(4,2), X(5,2); + +#ifdef BOOST_HAVE_PARSER + resetCounts(); + ADT pS = create(S % "50/50"); + ADT pT = create(T % "95/5"); + ADT pL = create(L | S = "99/1 90/10"); + ADT pE = create((E | T, L) = "F T T T"); + ADT pX = create(X | E = "95/5 2/98"); + ADT pD = create(B | E = "1/8 7/9"); + ADT pB = create(B | S = "70/30 40/60"); + // printCounts("Create CPTs"); + + // Create joint + resetCounts(); + ADT fg = pS; + fg = apply(fg, pT, &mul); + fg = apply(fg, pL, &mul); + fg = apply(fg, pB, &mul); + fg = apply(fg, pE, &mul); + fg = apply(fg, pX, &mul); + fg = apply(fg, pD, &mul); + dot(fg, "FactorGraph"); + EXPECT_LONGS_EQUAL(158, muls); + printCounts("Asia FG"); + + fg = fg.combine(X, &add_); + dot(fg, "Marginalized-6X"); + fg = fg.combine(T, &add_); + dot(fg, "Marginalized-5T"); + fg = fg.combine(S, &add_); + dot(fg, "Marginalized-4S"); + fg = fg.combine(E, &add_); + dot(fg, "Marginalized-3E"); + fg = fg.combine(L, &add_); + dot(fg, "Marginalized-2L"); + EXPECT(adds = 54); + printCounts("marginalize"); + + // BLESTX + + // Eliminate X + ADT fE = pX; + dot(fE, "Eliminate-01-fEX"); + fE = fE.combine(X, &add_); + dot(fE, "Eliminate-02-fE"); + printCounts("Eliminate X"); + + // Eliminate T + ADT fLE = pT; + fLE = apply(fLE, pE, &mul); + dot(fLE, "Eliminate-03-fLET"); + fLE = fLE.combine(T, &add_); + dot(fLE, "Eliminate-04-fLE"); + printCounts("Eliminate T"); + + // Eliminate S + ADT fBL = pS; + fBL = apply(fBL, pL, &mul); + fBL = apply(fBL, pB, &mul); + dot(fBL, "Eliminate-05-fBLS"); + fBL = fBL.combine(S, &add_); + dot(fBL, "Eliminate-06-fBL"); + printCounts("Eliminate S"); + + // Eliminate E + ADT fBL2 = fE; + fBL2 = apply(fBL2, fLE, &mul); + fBL2 = apply(fBL2, pD, &mul); + dot(fBL2, "Eliminate-07-fBLE"); + fBL2 = fBL2.combine(E, &add_); + dot(fBL2, "Eliminate-08-fBL2"); + printCounts("Eliminate E"); + + // Eliminate L + ADT fB = fBL; + fB = apply(fB, fBL2, &mul); + dot(fB, "Eliminate-09-fBL"); + fB = fB.combine(L, &add_); + dot(fB, "Eliminate-10-fB"); + printCounts("Eliminate L"); +#endif +} + +/* ************************************************************************* */ +// test equality +TEST(ADT, equality_noparser) +{ + DiscreteKey A(0,2), B(1,2); + Signature::Table tableA, tableB; + Signature::Row rA, rB; + rA += 80, 20; rB += 60, 40; + tableA += rA; tableB += rB; + + // Check straight equality + ADT pA1 = create(A % tableA); + ADT pA2 = create(A % tableA); + EXPECT(pA1 == pA2); // should be equal + + // Check equality after apply + ADT pB = create(B % tableB); + ADT pAB1 = apply(pA1, pB, &mul); + ADT pAB2 = apply(pB, pA1, &mul); + EXPECT(pAB2 == pAB1); +} + +/* ************************************************************************* */ +#ifdef BOOST_HAVE_PARSER +// test equality +TEST(ADT, equality_parser) +{ + DiscreteKey A(0,2), B(1,2); + // Check straight equality + ADT pA1 = create(A % "80/20"); + ADT pA2 = create(A % "80/20"); + EXPECT(pA1 == pA2); // should be equal + + // Check equality after apply + ADT pB = create(B % "60/40"); + ADT pAB1 = apply(pA1, pB, &mul); + ADT pAB2 = apply(pB, pA1, &mul); + EXPECT(pAB2 == pAB1); +} +#endif + +/* ******************************************************************************** */ +// Factor graph construction +// test constructor from strings +TEST(ADT, constructor) +{ + DiscreteKey v0(0,2), v1(1,3); + Assignment x00, x01, x02, x10, x11, x12; + x00[0] = 0, x00[1] = 0; + x01[0] = 0, x01[1] = 1; + x02[0] = 0, x02[1] = 2; + x10[0] = 1, x10[1] = 0; + x11[0] = 1, x11[1] = 1; + x12[0] = 1, x12[1] = 2; + + ADT f1(v0 & v1, "0 1 2 3 4 5"); + EXPECT_DOUBLES_EQUAL(0, f1(x00), 1e-9); + EXPECT_DOUBLES_EQUAL(1, f1(x01), 1e-9); + EXPECT_DOUBLES_EQUAL(2, f1(x02), 1e-9); + EXPECT_DOUBLES_EQUAL(3, f1(x10), 1e-9); + EXPECT_DOUBLES_EQUAL(4, f1(x11), 1e-9); + EXPECT_DOUBLES_EQUAL(5, f1(x12), 1e-9); + + ADT f2(v1 & v0, "0 1 2 3 4 5"); + EXPECT_DOUBLES_EQUAL(0, f2(x00), 1e-9); + EXPECT_DOUBLES_EQUAL(2, f2(x01), 1e-9); + EXPECT_DOUBLES_EQUAL(4, f2(x02), 1e-9); + EXPECT_DOUBLES_EQUAL(1, f2(x10), 1e-9); + EXPECT_DOUBLES_EQUAL(3, f2(x11), 1e-9); + EXPECT_DOUBLES_EQUAL(5, f2(x12), 1e-9); + + DiscreteKey z0(0,5), z1(1,4), z2(2,3), z3(3,2); + vector table(5 * 4 * 3 * 2); + double x = 0; + BOOST_FOREACH(double& t, table) + t = x++; + ADT f3(z0 & z1 & z2 & z3, table); + Assignment assignment; + assignment[0] = 0; + assignment[1] = 0; + assignment[2] = 0; + assignment[3] = 1; + EXPECT_DOUBLES_EQUAL(1, f3(assignment), 1e-9); +} + +/* ************************************************************************* */ +// test conversion to integer indices +// Only works if DiscreteKeys are binary, as size_t has binary cardinality! +TEST(ADT, conversion) +{ + DiscreteKey X(0,2), Y(1,2); + ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6"); + dot(fDiscreteKey, "conversion-f1"); + + std::map ordering; + ordering[0] = 5; + ordering[1] = 2; + + AlgebraicDecisionTree fIndexKey(fDiscreteKey, ordering); + // f1.print("f1"); + // f2.print("f2"); + dot(fIndexKey, "conversion-f2"); + + Assignment x00, x01, x02, x10, x11, x12; + x00[5] = 0, x00[2] = 0; + x01[5] = 0, x01[2] = 1; + x10[5] = 1, x10[2] = 0; + x11[5] = 1, x11[2] = 1; + EXPECT_DOUBLES_EQUAL(0.2, fIndexKey(x00), 1e-9); + EXPECT_DOUBLES_EQUAL(0.5, fIndexKey(x01), 1e-9); + EXPECT_DOUBLES_EQUAL(0.3, fIndexKey(x10), 1e-9); + EXPECT_DOUBLES_EQUAL(0.6, fIndexKey(x11), 1e-9); +} + +/* ******************************************************************************** */ +// test operations in elimination +TEST(ADT, elimination) +{ + DiscreteKey A(0,2), B(1,3), C(2,2); + ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5"); + dot(f1, "elimination-f1"); + + { + // sum out lower key + ADT actualSum = f1.sum(C); + ADT expectedSum(A & B, "3 7 11 9 6 10"); + CHECK(assert_equal(expectedSum,actualSum)); + + // normalize + ADT actual = f1 / actualSum; + vector cpt; + cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, // + 1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10; + ADT expected(A & B & C, cpt); + CHECK(assert_equal(expected,actual)); + } + + { + // sum out lower 2 keys + ADT actualSum = f1.sum(C).sum(B); + ADT expectedSum(A, 21, 25); + CHECK(assert_equal(expectedSum,actualSum)); + + // normalize + ADT actual = f1 / actualSum; + vector cpt; + cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, // + 1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25; + ADT expected(A & B & C, cpt); + CHECK(assert_equal(expected,actual)); + } +} + +/* ******************************************************************************** */ +// Test non-commutative op +TEST(ADT, div) +{ + DiscreteKey A(0,2), B(1,2); + + // Literals + ADT a(A, 8, 16); + ADT b(B, 2, 4); + ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4 + ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16 + EXPECT(assert_equal(expected_a_div_b, a / b)); + EXPECT(assert_equal(expected_b_div_a, b / a)); +} + +/* ******************************************************************************** */ +// test zero shortcut +TEST(ADT, zero) +{ + DiscreteKey A(0,2), B(1,2); + + // Literals + ADT a(A, 0, 1); + ADT notb(B, 1, 0); + ADT anotb = a * notb; + // GTSAM_PRINT(anotb); + Assignment x00, x01, x10, x11; + x00[0] = 0, x00[1] = 0; + x01[0] = 0, x01[1] = 1; + x10[0] = 1, x10[1] = 0; + x11[0] = 1, x11[1] = 1; + EXPECT_DOUBLES_EQUAL(0, anotb(x00), 1e-9); + EXPECT_DOUBLES_EQUAL(0, anotb(x01), 1e-9); + EXPECT_DOUBLES_EQUAL(1, anotb(x10), 1e-9); + EXPECT_DOUBLES_EQUAL(0, anotb(x11), 1e-9); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/gtsam/discrete/tests/testCSP.cpp b/gtsam/discrete/tests/testCSP.cpp new file mode 100644 index 000000000..cce32f09f --- /dev/null +++ b/gtsam/discrete/tests/testCSP.cpp @@ -0,0 +1,224 @@ +/* + * testCSP.cpp + * @brief develop code for CSP solver + * @date Feb 5, 2012 + * @author Frank Dellaert + */ + +#include +#include +#include +#include +#include + +using namespace std; +using namespace gtsam; + +/* ************************************************************************* */ +TEST_UNSAFE( BinaryAllDif, allInOne) +{ + // Create keys and ordering + size_t nrColors = 2; +// DiscreteKey ID("Idaho", nrColors), UT("Utah", nrColors), AZ("Arizona", nrColors); + DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors); + + // Check construction and conversion + BinaryAllDiff c1(ID, UT); + DecisionTreeFactor f1(ID & UT, "0 1 1 0"); + EXPECT(assert_equal(f1,(DecisionTreeFactor)c1)); + + // Check construction and conversion + BinaryAllDiff c2(UT, AZ); + DecisionTreeFactor f2(UT & AZ, "0 1 1 0"); + EXPECT(assert_equal(f2,(DecisionTreeFactor)c2)); + + DecisionTreeFactor f3 = f1*f2; + EXPECT(assert_equal(f3,c1*f2)); + EXPECT(assert_equal(f3,c2*f1)); +} + +/* ************************************************************************* */ +TEST_UNSAFE( CSP, allInOne) +{ + // Create keys and ordering + size_t nrColors = 2; + DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors); + + // Create the CSP + CSP csp; + csp.addAllDiff(ID,UT); + csp.addAllDiff(UT,AZ); + + // Check an invalid combination, with ID==UT==AZ all same color + DiscreteFactor::Values invalid; + invalid[ID.first] = 0; + invalid[UT.first] = 0; + invalid[AZ.first] = 0; + EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9); + + // Check a valid combination + DiscreteFactor::Values valid; + valid[ID.first] = 0; + valid[UT.first] = 1; + valid[AZ.first] = 0; + EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9); + + // Just for fun, create the product and check it + DecisionTreeFactor product = csp.product(); + // product.dot("product"); + DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0"); + EXPECT(assert_equal(expectedProduct,product)); + + // Solve + CSP::sharedValues mpe = csp.optimalAssignment(); + CSP::Values expected; + insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 1); + EXPECT(assert_equal(expected,*mpe)); + EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); +} + +/* ************************************************************************* */ +TEST_UNSAFE( CSP, WesternUS) +{ + // Create keys + size_t nrColors = 4; + DiscreteKey + // Create ordering according to example in ND-CSP.lyx + WA(0, nrColors), OR(3, nrColors), CA(1, nrColors),NV(2, nrColors), + ID(8, nrColors), UT(9, nrColors), AZ(10, nrColors), + MT(4, nrColors), WY(5, nrColors), CO(7, nrColors), NM(6, nrColors); + + // Create the CSP + CSP csp; + csp.addAllDiff(WA,ID); + csp.addAllDiff(WA,OR); + csp.addAllDiff(OR,ID); + csp.addAllDiff(OR,CA); + csp.addAllDiff(OR,NV); + csp.addAllDiff(CA,NV); + csp.addAllDiff(CA,AZ); + csp.addAllDiff(ID,MT); + csp.addAllDiff(ID,WY); + csp.addAllDiff(ID,UT); + csp.addAllDiff(ID,NV); + csp.addAllDiff(NV,UT); + csp.addAllDiff(NV,AZ); + csp.addAllDiff(UT,WY); + csp.addAllDiff(UT,CO); + csp.addAllDiff(UT,NM); + csp.addAllDiff(UT,AZ); + csp.addAllDiff(AZ,CO); + csp.addAllDiff(AZ,NM); + csp.addAllDiff(MT,WY); + csp.addAllDiff(WY,CO); + csp.addAllDiff(CO,NM); + + // Solve + CSP::sharedValues mpe = csp.optimalAssignment(); + // GTSAM_PRINT(*mpe); + CSP::Values expected; + insert(expected) + (WA.first,1)(CA.first,1)(NV.first,3)(OR.first,0) + (MT.first,1)(WY.first,0)(NM.first,3)(CO.first,2) + (ID.first,2)(UT.first,1)(AZ.first,0); + EXPECT(assert_equal(expected,*mpe)); + EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); + + // Write out the dual graph for hmetis +#ifdef DUAL + VariableIndex index(csp); + index.print("index"); + ofstream os("/Users/dellaert/src/hmetis-1.5-osx-i686/US-West-dual.txt"); + index.outputMetisFormat(os); +#endif +} + +/* ************************************************************************* */ +TEST_UNSAFE( CSP, AllDiff) +{ + // Create keys and ordering + size_t nrColors = 3; + DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors); + + // Create the CSP + CSP csp; + vector dkeys; + dkeys += ID,UT,AZ; + csp.addAllDiff(dkeys); + csp.addSingleValue(AZ,2); + //GTSAM_PRINT(csp); + + // Check construction and conversion + SingleValue s(AZ,2); + DecisionTreeFactor f1(AZ,"0 0 1"); + EXPECT(assert_equal(f1,(DecisionTreeFactor)s)); + + // Check construction and conversion + AllDiff alldiff(dkeys); + DecisionTreeFactor actual = (DecisionTreeFactor)alldiff; +// GTSAM_PRINT(actual); +// actual.dot("actual"); + DecisionTreeFactor f2(ID & AZ & UT, + "0 0 0 0 0 1 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 1 0 0 0 0 0"); + EXPECT(assert_equal(f2,actual)); + + // Check an invalid combination, with ID==UT==AZ all same color + DiscreteFactor::Values invalid; + invalid[ID.first] = 0; + invalid[UT.first] = 1; + invalid[AZ.first] = 0; + EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9); + + // Check a valid combination + DiscreteFactor::Values valid; + valid[ID.first] = 0; + valid[UT.first] = 1; + valid[AZ.first] = 2; + EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9); + + // Solve + CSP::sharedValues mpe = csp.optimalAssignment(); + CSP::Values expected; + insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 2); + EXPECT(assert_equal(expected,*mpe)); + EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); + + // Arc-consistency + vector domains; + domains += Domain(ID), Domain(AZ), Domain(UT); + SingleValue singleValue(AZ,2); + EXPECT(singleValue.ensureArcConsistency(1,domains)); + EXPECT(alldiff.ensureArcConsistency(0,domains)); + EXPECT(!alldiff.ensureArcConsistency(1,domains)); + EXPECT(alldiff.ensureArcConsistency(2,domains)); + LONGS_EQUAL(2,domains[0].nrValues()); + LONGS_EQUAL(1,domains[1].nrValues()); + LONGS_EQUAL(2,domains[2].nrValues()); + + // Parial application, version 1 + DiscreteFactor::Values known; + known[AZ.first] = 2; + DiscreteFactor::shared_ptr reduced1 = alldiff.partiallyApply(known); + DecisionTreeFactor f3(ID & UT, "0 1 1 1 0 1 1 1 0"); + EXPECT(assert_equal(f3,reduced1->operator DecisionTreeFactor())); + DiscreteFactor::shared_ptr reduced2 = singleValue.partiallyApply(known); + DecisionTreeFactor f4(AZ, "0 0 1"); + EXPECT(assert_equal(f4,reduced2->operator DecisionTreeFactor())); + + // Parial application, version 2 + DiscreteFactor::shared_ptr reduced3 = alldiff.partiallyApply(domains); + EXPECT(assert_equal(f3,reduced3->operator DecisionTreeFactor())); + DiscreteFactor::shared_ptr reduced4 = singleValue.partiallyApply(domains); + EXPECT(assert_equal(f4,reduced4->operator DecisionTreeFactor())); + + // full arc-consistency test + csp.runArcConsistency(nrColors); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ + diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp new file mode 100644 index 000000000..fa7336aa6 --- /dev/null +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -0,0 +1,226 @@ +/* + * @file testDecisionTree.cpp + * @brief Develop DecisionTree + * @author Frank Dellaert + * @author Can Erdogan + * @date Jan 30, 2012 + */ + +#include +#include +using namespace boost::assign; + +#include +#include +#include + +//#define DT_DEBUG_MEMORY +//#define DT_NO_PRUNING +#define DISABLE_DOT +#include +using namespace std; +using namespace gtsam; + +template +void dot(const T&f, const string& filename) { +#ifndef DISABLE_DOT + f.dot(filename); +#endif +} + +#define DOT(x)(dot(x,#x)) + +/* ******************************************************************************** */ +// Test string labels and int range +/* ******************************************************************************** */ + +typedef DecisionTree DT; + +struct Ring { + static inline int zero() { + return 0; + } + static inline int one() { + return 1; + } + static inline int add(const int& a, const int& b) { + return a + b; + } + static inline int mul(const int& a, const int& b) { + return a * b; + } +}; + +/* ******************************************************************************** */ +// test DT +TEST(DT, example) +{ + // Create labels + string A("A"), B("B"), C("C"); + + // create a value + Assignment x00, x01, x10, x11; + x00[A] = 0, x00[B] = 0; + x01[A] = 0, x01[B] = 1; + x10[A] = 1, x10[B] = 0; + x11[A] = 1, x11[B] = 1; + + // A + DT a(A, 0, 5); + LONGS_EQUAL(0,a(x00)) + LONGS_EQUAL(5,a(x10)) + DOT(a); + + // pruned + DT p(A, 2, 2); + LONGS_EQUAL(2,p(x00)) + LONGS_EQUAL(2,p(x10)) + DOT(p); + + // \neg B + DT notb(B, 5, 0); + LONGS_EQUAL(5,notb(x00)) + LONGS_EQUAL(5,notb(x10)) + DOT(notb); + + // apply, two nodes, in natural order + DT anotb = apply(a, notb, &Ring::mul); + LONGS_EQUAL(0,anotb(x00)) + LONGS_EQUAL(0,anotb(x01)) + LONGS_EQUAL(25,anotb(x10)) + LONGS_EQUAL(0,anotb(x11)) + DOT(anotb); + + // check pruning + DT pnotb = apply(p, notb, &Ring::mul); + LONGS_EQUAL(10,pnotb(x00)) + LONGS_EQUAL( 0,pnotb(x01)) + LONGS_EQUAL(10,pnotb(x10)) + LONGS_EQUAL( 0,pnotb(x11)) + DOT(pnotb); + + // check pruning + DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul); + LONGS_EQUAL(0,zeros(x00)) + LONGS_EQUAL(0,zeros(x01)) + LONGS_EQUAL(0,zeros(x10)) + LONGS_EQUAL(0,zeros(x11)) + DOT(zeros); + + // apply, two nodes, in switched order + DT notba = apply(a, notb, &Ring::mul); + LONGS_EQUAL(0,notba(x00)) + LONGS_EQUAL(0,notba(x01)) + LONGS_EQUAL(25,notba(x10)) + LONGS_EQUAL(0,notba(x11)) + DOT(notba); + + // Test choose 0 + DT actual0 = notba.choose(A, 0); + EXPECT(assert_equal(DT(0.0), actual0)); + DOT(actual0); + + // Test choose 1 + DT actual1 = notba.choose(A, 1); + EXPECT(assert_equal(DT(B, 25, 0), actual1)); + DOT(actual1); + + // apply, two nodes at same level + DT a_and_a = apply(a, a, &Ring::mul); + LONGS_EQUAL(0,a_and_a(x00)) + LONGS_EQUAL(0,a_and_a(x01)) + LONGS_EQUAL(25,a_and_a(x10)) + LONGS_EQUAL(25,a_and_a(x11)) + DOT(a_and_a); + + // create a function on C + DT c(C, 0, 5); + + // and a model assigning stuff to C + Assignment x101; + x101[A] = 1, x101[B] = 0, x101[C] = 1; + + // mul notba with C + DT notbac = apply(notba, c, &Ring::mul); + LONGS_EQUAL(125,notbac(x101)) + DOT(notbac); + + // mul now in different order + DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul); + LONGS_EQUAL(125,acnotb(x101)) + DOT(acnotb); +} + +/* ******************************************************************************** */ +// test Conversion +enum Label { + U, V, X, Y, Z +}; +typedef DecisionTree BDT; +bool convert(const int& y) { + return y != 0; +} + +TEST(DT, conversion) +{ + // Create labels + string A("A"), B("B"); + + // apply, two nodes, in natural order + DT f1 = apply(DT(A, 0, 5), DT(B, 5, 0), &Ring::mul); + + // convert + map ordering; + ordering[A] = X; + ordering[B] = Y; + boost::function op = convert; + BDT f2(f1, ordering, op); + // f1.print("f1"); + // f2.print("f2"); + + // create a value + Assignment