diff --git a/.github/scripts/python.sh b/.github/scripts/python.sh index 3f5701281..0855dbc21 100644 --- a/.github/scripts/python.sh +++ b/.github/scripts/python.sh @@ -75,7 +75,7 @@ cmake $GITHUB_WORKSPACE -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \ -DGTSAM_UNSTABLE_BUILD_PYTHON=${GTSAM_BUILD_UNSTABLE:-ON} \ -DGTSAM_PYTHON_VERSION=$PYTHON_VERSION \ -DPYTHON_EXECUTABLE:FILEPATH=$(which $PYTHON) \ - -DGTSAM_ALLOW_DEPRECATED_SINCE_V41=OFF \ + -DGTSAM_ALLOW_DEPRECATED_SINCE_V42=OFF \ -DCMAKE_INSTALL_PREFIX=$GITHUB_WORKSPACE/gtsam_install @@ -83,6 +83,6 @@ cmake $GITHUB_WORKSPACE -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \ make -j2 install cd $GITHUB_WORKSPACE/build/python -$PYTHON setup.py install --user --prefix= +$PYTHON -m pip install --user . cd $GITHUB_WORKSPACE/python/gtsam/tests $PYTHON -m unittest discover -v diff --git a/.github/scripts/unix.sh b/.github/scripts/unix.sh index 7fb925593..b5a559df5 100644 --- a/.github/scripts/unix.sh +++ b/.github/scripts/unix.sh @@ -64,13 +64,14 @@ function configure() -DGTSAM_BUILD_UNSTABLE=${GTSAM_BUILD_UNSTABLE:-ON} \ -DGTSAM_WITH_TBB=${GTSAM_WITH_TBB:-OFF} \ -DGTSAM_BUILD_EXAMPLES_ALWAYS=${GTSAM_BUILD_EXAMPLES_ALWAYS:-ON} \ - -DGTSAM_ALLOW_DEPRECATED_SINCE_V41=${GTSAM_ALLOW_DEPRECATED_SINCE_V41:-OFF} \ + -DGTSAM_ALLOW_DEPRECATED_SINCE_V42=${GTSAM_ALLOW_DEPRECATED_SINCE_V42:-OFF} \ -DGTSAM_USE_QUATERNIONS=${GTSAM_USE_QUATERNIONS:-OFF} \ -DGTSAM_ROT3_EXPMAP=${GTSAM_ROT3_EXPMAP:-ON} \ -DGTSAM_POSE3_EXPMAP=${GTSAM_POSE3_EXPMAP:-ON} \ -DGTSAM_USE_SYSTEM_EIGEN=${GTSAM_USE_SYSTEM_EIGEN:-OFF} \ -DGTSAM_USE_SYSTEM_METIS=${GTSAM_USE_SYSTEM_METIS:-OFF} \ -DGTSAM_BUILD_WITH_MARCH_NATIVE=OFF \ + -DGTSAM_SINGLE_TEST_EXE=ON \ -DBOOST_ROOT=$BOOST_ROOT \ -DBoost_NO_SYSTEM_PATHS=ON \ -DBoost_ARCHITECTURE=-x64 @@ -95,7 +96,11 @@ function build () configure if [ "$(uname)" == "Linux" ]; then - make -j$(nproc) + if (($(nproc) > 2)); then + make -j$(nproc) + else + make -j2 + fi elif [ "$(uname)" == "Darwin" ]; then make -j$(sysctl -n hw.physicalcpu) fi @@ -113,9 +118,13 @@ function test () # Actual testing if [ "$(uname)" == "Linux" ]; then - make -j$(nproc) + if (($(nproc) > 2)); then + make -j$(nproc) check + else + make -j2 check + fi elif [ "$(uname)" == "Darwin" ]; then - make -j$(sysctl -n hw.physicalcpu) + make -j$(sysctl -n hw.physicalcpu) check fi finish diff --git a/.github/workflows/build-linux.yml b/.github/workflows/build-linux.yml index f52e5eec3..7b13b6646 100644 --- a/.github/workflows/build-linux.yml +++ b/.github/workflows/build-linux.yml @@ -15,7 +15,7 @@ jobs: BOOST_VERSION: 1.67.0 strategy: - fail-fast: false + fail-fast: true matrix: # Github Actions requires a single row to be added to the build matrix. # See https://help.github.com/en/articles/workflow-syntax-for-github-actions. diff --git a/.github/workflows/build-special.yml b/.github/workflows/build-special.yml index 647b9c0f1..d357b9a34 100644 --- a/.github/workflows/build-special.yml +++ b/.github/workflows/build-special.yml @@ -110,7 +110,7 @@ jobs: - name: Set Allow Deprecated Flag if: matrix.flag == 'deprecated' run: | - echo "GTSAM_ALLOW_DEPRECATED_SINCE_V41=ON" >> $GITHUB_ENV + echo "GTSAM_ALLOW_DEPRECATED_SINCE_V42=ON" >> $GITHUB_ENV echo "Allow deprecated since version 4.1" - name: Set Use Quaternions Flag diff --git a/.github/workflows/build-windows.yml b/.github/workflows/build-windows.yml index 5dfdcd013..ef2500b46 100644 --- a/.github/workflows/build-windows.yml +++ b/.github/workflows/build-windows.yml @@ -26,7 +26,11 @@ jobs: windows-2019-cl, ] - build_type: [Debug, Release] + build_type: [ + Debug, + #TODO(Varun) The release build takes over 2.5 hours, need to figure out why. + # Release + ] build_unstable: [ON] include: #TODO This build fails, need to understand why. @@ -44,7 +48,9 @@ jobs: - name: Install Dependencies shell: powershell run: | - Invoke-Expression (New-Object System.Net.WebClient).DownloadString('https://get.scoop.sh') + iwr -useb get.scoop.sh -outfile 'install_scoop.ps1' + .\install_scoop.ps1 -RunAsAdmin + scoop install cmake --global # So we don't get issues with CMP0074 policy scoop install ninja --global @@ -90,13 +96,33 @@ jobs: - name: Checkout uses: actions/checkout@v2 - - name: Build + - name: Configuration run: | cmake -E remove_directory build cmake -B build -S . -DGTSAM_BUILD_EXAMPLES_ALWAYS=OFF -DBOOST_ROOT="${env:BOOST_ROOT}" -DBOOST_INCLUDEDIR="${env:BOOST_ROOT}\boost\include" -DBOOST_LIBRARYDIR="${env:BOOST_ROOT}\lib" - cmake --build build --config ${{ matrix.build_type }} --target gtsam - cmake --build build --config ${{ matrix.build_type }} --target gtsam_unstable - cmake --build build --config ${{ matrix.build_type }} --target wrap - cmake --build build --config ${{ matrix.build_type }} --target check.base - cmake --build build --config ${{ matrix.build_type }} --target check.base_unstable - cmake --build build --config ${{ matrix.build_type }} --target check.linear + + - name: Build + run: | + # Since Visual Studio is a multi-generator, we need to use --config + # https://stackoverflow.com/a/24470998/1236990 + cmake --build build -j 4 --config ${{ matrix.build_type }} --target gtsam + cmake --build build -j 4 --config ${{ matrix.build_type }} --target gtsam_unstable + cmake --build build -j 4 --config ${{ matrix.build_type }} --target wrap + + # Run GTSAM tests + cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.base + cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.basis + cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.discrete + #cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.geometry + cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.inference + cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.linear + cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.navigation + #cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.nonlinear + #cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.sam + cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.sfm + #cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.slam + cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.symbolic + + # Run GTSAM_UNSTABLE tests + #cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.base_unstable + diff --git a/.gitignore b/.gitignore index cde059767..0e34eed34 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ .idea *.pyc *.DS_Store +*.swp /examples/Data/dubrovnik-3-7-pre-rewritten.txt /examples/Data/pose2example-rewritten.txt /examples/Data/pose3example-rewritten.txt @@ -16,3 +17,4 @@ # for QtCreator: CMakeLists.txt.user* xcode/ +/Dockerfile diff --git a/CMakeLists.txt b/CMakeLists.txt index d2559705d..cfb251663 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,3 @@ -project(GTSAM CXX C) cmake_minimum_required(VERSION 3.0) # new feature to Cmake Version > 2.8.12 @@ -9,12 +8,23 @@ endif() # Set the version number for the library set (GTSAM_VERSION_MAJOR 4) -set (GTSAM_VERSION_MINOR 1) +set (GTSAM_VERSION_MINOR 2) set (GTSAM_VERSION_PATCH 0) +set (GTSAM_PRERELEASE_VERSION "a6") math (EXPR GTSAM_VERSION_NUMERIC "10000 * ${GTSAM_VERSION_MAJOR} + 100 * ${GTSAM_VERSION_MINOR} + ${GTSAM_VERSION_PATCH}") -set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}.${GTSAM_VERSION_PATCH}") -set (CMAKE_PROJECT_VERSION ${GTSAM_VERSION_STRING}) +if (${GTSAM_VERSION_PATCH} EQUAL 0) + set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}${GTSAM_PRERELEASE_VERSION}") +else() + set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}.${GTSAM_VERSION_PATCH}${GTSAM_PRERELEASE_VERSION}") +endif() + +project(GTSAM + LANGUAGES CXX C + VERSION "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}.${GTSAM_VERSION_PATCH}") + +message(STATUS "GTSAM Version: ${GTSAM_VERSION_STRING}") + set (CMAKE_PROJECT_VERSION_MAJOR ${GTSAM_VERSION_MAJOR}) set (CMAKE_PROJECT_VERSION_MINOR ${GTSAM_VERSION_MINOR}) set (CMAKE_PROJECT_VERSION_PATCH ${GTSAM_VERSION_PATCH}) @@ -87,6 +97,13 @@ if(GTSAM_BUILD_PYTHON OR GTSAM_INSTALL_MATLAB_TOOLBOX) CACHE STRING "The Python version to use for wrapping") # Set the include directory for matlab.h set(GTWRAP_INCLUDE_NAME "wrap") + + # Copy matlab.h to the correct folder. + configure_file(${PROJECT_SOURCE_DIR}/wrap/matlab.h + ${PROJECT_BINARY_DIR}/wrap/matlab.h COPYONLY) + # Add the include directories so that matlab.h can be found + include_directories("${PROJECT_BINARY_DIR}" "${GTSAM_EIGEN_INCLUDE_FOR_BUILD}") + add_subdirectory(wrap) list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/wrap/cmake") endif() @@ -105,6 +122,11 @@ endif() GtsamMakeConfigFile(GTSAM "${CMAKE_CURRENT_SOURCE_DIR}/gtsam_extra.cmake.in") export(TARGETS ${GTSAM_EXPORTED_TARGETS} FILE GTSAM-exports.cmake) +if (GTSAM_BUILD_UNSTABLE) + GtsamMakeConfigFile(GTSAM_UNSTABLE "${CMAKE_CURRENT_SOURCE_DIR}/gtsam_extra.cmake.in") + export(TARGETS ${GTSAM_UNSTABLE_EXPORTED_TARGETS} FILE GTSAM_UNSTABLE-exports.cmake) +endif() + # Check for doxygen availability - optional dependency find_package(Doxygen) diff --git a/DEVELOP.md b/DEVELOP.md index 8604afe0f..7cd303373 100644 --- a/DEVELOP.md +++ b/DEVELOP.md @@ -15,7 +15,7 @@ For example: ```cpp class GTSAM_EXPORT MyClass { ... }; -GTSAM_EXPORT myFunction(); +GTSAM_EXPORT return_type myFunction(); ``` More details [here](Using-GTSAM-EXPORT.md). diff --git a/INSTALL.md b/INSTALL.md index 965246304..1edccd3cd 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -13,7 +13,7 @@ $ make install ## Important Installation Notes 1. GTSAM requires the following libraries to be installed on your system: - - BOOST version 1.65 or greater (install through Linux repositories or MacPorts). Please see [Boost Notes](#boost-notes). + - BOOST version 1.65 or greater (install through Linux repositories or MacPorts). Please see [Boost Notes](#boost-notes) for version recommendations based on your compiler. - Cmake version 3.0 or higher - Support for XCode 4.3 command line tools on Mac requires CMake 2.8.8 or higher @@ -72,7 +72,7 @@ execute commands as follows for an out-of-source build: Versions of Boost prior to 1.65 have a known bug that prevents proper "deep" serialization of objects, which means that objects encapsulated inside other objects don't get serialized. This is particularly seen when using `clang` as the C++ compiler. -For this reason we require Boost>=1.65, and recommend installing it through alternative channels when it is not available through your operating system's primary package manager. +For this reason we recommend Boost>=1.65, and recommend installing it through alternative channels when it is not available through your operating system's primary package manager. ## Known Issues diff --git a/README.md b/README.md index 046132301..52ac0a5d8 100644 --- a/README.md +++ b/README.md @@ -2,9 +2,9 @@ **Important Note** -As of August 1 2020, the `develop` branch is officially in "Pre 4.1" mode, and features deprecated in 4.0 have been removed. Please use the last [4.0.3 release](https://github.com/borglab/gtsam/releases/tag/4.0.3) if you need those features. +As of Dec 2021, the `develop` branch is officially in "Pre 4.2" mode. A great new feature we will be adding in 4.2 is *hybrid inference* a la DCSLAM (Kevin Doherty et al) and we envision several API-breaking changes will happen in the discrete folder. -However, most are easily converted and can be tracked down (in 4.0.3) by disabling the cmake flag `GTSAM_ALLOW_DEPRECATED_SINCE_V4`. +In addition, features deprecated in 4.1 will be removed. Please use the last [4.1.1 release](https://github.com/borglab/gtsam/releases/tag/4.1.1) if you need those features. However, most (not all, unfortunately) are easily converted and can be tracked down (in 4.1.1) by disabling the cmake flag `GTSAM_ALLOW_DEPRECATED_SINCE_V42`. ## What is GTSAM? @@ -57,7 +57,7 @@ GTSAM 4 introduces several new features, most notably Expressions and a Python t GTSAM 4 also deprecated some legacy functionality and wrongly named methods. If you are on a 4.0.X release, you can define the flag `GTSAM_ALLOW_DEPRECATED_SINCE_V4` to use the deprecated methods. -GTSAM 4.1 added a new pybind wrapper, and **removed** the deprecated functionality. There is a flag `GTSAM_ALLOW_DEPRECATED_SINCE_V41` for newly deprecated methods since the 4.1 release, which is on by default, allowing anyone to just pull version 4.1 and compile. +GTSAM 4.1 added a new pybind wrapper, and **removed** the deprecated functionality. There is a flag `GTSAM_ALLOW_DEPRECATED_SINCE_V42` for newly deprecated methods since the 4.1 release, which is on by default, allowing anyone to just pull version 4.1 and compile. ## Wrappers diff --git a/Using-GTSAM-EXPORT.md b/Using-GTSAM-EXPORT.md index cae1d499c..24a29f96b 100644 --- a/Using-GTSAM-EXPORT.md +++ b/Using-GTSAM-EXPORT.md @@ -8,6 +8,7 @@ To create a DLL in windows, the `GTSAM_EXPORT` keyword has been created and need * At least one of the functions inside that class is declared in a .cpp file and not just the .h file. * You can `GTSAM_EXPORT` any class it inherits from as well. (Note that this implictly requires the class does not derive from a "header-only" class. Note that Eigen is a "header-only" library, so if your class derives from Eigen, _do not_ use `GTSAM_EXPORT` in the class definition!) 3. If you have defined a class using `GTSAM_EXPORT`, do not use `GTSAM_EXPORT` in any of its individual function declarations. (Note that you _can_ put `GTSAM_EXPORT` in the definition of individual functions within a class as long as you don't put `GTSAM_EXPORT` in the class definition.) +4. For template specializations, you need to add `GTSAM_EXPORT` to each individual specialization. ## When is GTSAM_EXPORT being used incorrectly Unfortunately, using `GTSAM_EXPORT` incorrectly often does not cause a compiler or linker error in the library that is being compiled, but only when you try to use that DLL in a different library. For example, an error in `gtsam/base` will often show up when compiling the `check_base_program` or the MATLAB wrapper, but not when compiling/linking gtsam itself. The most common errors will say something like: @@ -29,7 +30,7 @@ Rule #1 doesn't seem very bad, until you combine it with rule #2 ***Compiler Rule #2*** Anything declared in a header file is not included in a DLL. -When these two rules are combined, you get some very confusing results. For example, a class which is completely defined in a header (e.g. LieMatrix) cannot use `GTSAM_EXPORT` in its definition. If LieMatrix is defined with `GTSAM_EXPORT`, then the compiler _must_ find LieMatrix in a DLL. Because LieMatrix is a header-only class, however, it can't find it, leading to a very confusing "I can't find this symbol" type of error. Note that the linker says it can't find the symbol even though the compiler found the header file that completely defines the class. +When these two rules are combined, you get some very confusing results. For example, a class which is completely defined in a header (e.g. Foo) cannot use `GTSAM_EXPORT` in its definition. If Foo is defined with `GTSAM_EXPORT`, then the compiler _must_ find Foo in a DLL. Because Foo is a header-only class, however, it can't find it, leading to a very confusing "I can't find this symbol" type of error. Note that the linker says it can't find the symbol even though the compiler found the header file that completely defines the class. Also note that when a class that you want to export inherits from another class that is not exportable, this can cause significant issues. According to this [MSVC Warning page](https://docs.microsoft.com/en-us/cpp/error-messages/compiler-warnings/compiler-warning-level-2-c4275?view=vs-2019), it may not strictly be a rule, but we have seen several linker errors when a class that is defined with `GTSAM_EXPORT` extended an Eigen class. In general, it appears that any inheritance of non-exportable class by an exportable class is a bad idea. diff --git a/cmake/GtsamBuildTypes.cmake b/cmake/GtsamBuildTypes.cmake index 4b179d128..9058807ad 100644 --- a/cmake/GtsamBuildTypes.cmake +++ b/cmake/GtsamBuildTypes.cmake @@ -93,6 +93,10 @@ if(MSVC) /wd4267 # warning C4267: 'initializing': conversion from 'size_t' to 'int', possible loss of data ) + add_compile_options(/wd4005) + add_compile_options(/wd4101) + add_compile_options(/wd4834) + endif() # Other (non-preprocessor macros) compiler flags: @@ -187,7 +191,7 @@ endif() if (NOT MSVC) option(GTSAM_BUILD_WITH_MARCH_NATIVE "Enable/Disable building with all instructions supported by native architecture (binary may not be portable!)" ON) - if(GTSAM_BUILD_WITH_MARCH_NATIVE) + if(GTSAM_BUILD_WITH_MARCH_NATIVE AND (APPLE AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")) # Add as public flag so all dependant projects also use it, as required # by Eigen to avid crashes due to SIMD vectorization: list_append_cache(GTSAM_COMPILE_OPTIONS_PUBLIC "-march=native") diff --git a/cmake/GtsamMakeConfigFile.cmake b/cmake/GtsamMakeConfigFile.cmake index 0479a2524..91cb98a8c 100644 --- a/cmake/GtsamMakeConfigFile.cmake +++ b/cmake/GtsamMakeConfigFile.cmake @@ -27,6 +27,8 @@ function(GtsamMakeConfigFile PACKAGE_NAME) # here. if(NOT DEFINED ${PACKAGE_NAME}_VERSION AND DEFINED ${PACKAGE_NAME}_VERSION_STRING) set(${PACKAGE_NAME}_VERSION ${${PACKAGE_NAME}_VERSION_STRING}) + elseif(NOT DEFINED ${PACKAGE_NAME}_VERSION_STRING) + set(${PACKAGE_NAME}_VERSION ${GTSAM_VERSION_STRING}) endif() # Version file diff --git a/cmake/HandleGeneralOptions.cmake b/cmake/HandleGeneralOptions.cmake index ee86066a2..7c8f8533f 100644 --- a/cmake/HandleGeneralOptions.cmake +++ b/cmake/HandleGeneralOptions.cmake @@ -14,20 +14,21 @@ if(GTSAM_UNSTABLE_AVAILABLE) option(GTSAM_UNSTABLE_BUILD_PYTHON "Enable/Disable Python wrapper for libgtsam_unstable" ON) option(GTSAM_UNSTABLE_INSTALL_MATLAB_TOOLBOX "Enable/Disable MATLAB wrapper for libgtsam_unstable" OFF) endif() -option(BUILD_SHARED_LIBS "Build shared gtsam library, instead of static" ON) -option(GTSAM_USE_QUATERNIONS "Enable/Disable using an internal Quaternion representation for rotations instead of rotation matrices. If enable, Rot3::EXPMAP is enforced by default." OFF) -option(GTSAM_POSE3_EXPMAP "Enable/Disable using Pose3::EXPMAP as the default mode. If disabled, Pose3::FIRST_ORDER will be used." ON) -option(GTSAM_ROT3_EXPMAP "Ignore if GTSAM_USE_QUATERNIONS is OFF (Rot3::EXPMAP by default). Otherwise, enable Rot3::EXPMAP, or if disabled, use Rot3::CAYLEY." ON) -option(GTSAM_ENABLE_CONSISTENCY_CHECKS "Enable/Disable expensive consistency checks" OFF) -option(GTSAM_WITH_TBB "Use Intel Threaded Building Blocks (TBB) if available" ON) -option(GTSAM_WITH_EIGEN_MKL "Eigen will use Intel MKL if available" OFF) -option(GTSAM_WITH_EIGEN_MKL_OPENMP "Eigen, when using Intel MKL, will also use OpenMP for multithreading if available" OFF) -option(GTSAM_THROW_CHEIRALITY_EXCEPTION "Throw exception when a triangulated point is behind a camera" ON) -option(GTSAM_BUILD_PYTHON "Enable/Disable building & installation of Python module with pybind11" OFF) -option(GTSAM_INSTALL_MATLAB_TOOLBOX "Enable/Disable installation of matlab toolbox" OFF) -option(GTSAM_ALLOW_DEPRECATED_SINCE_V41 "Allow use of methods/functions deprecated in GTSAM 4.1" ON) -option(GTSAM_SUPPORT_NESTED_DISSECTION "Support Metis-based nested dissection" ON) -option(GTSAM_TANGENT_PREINTEGRATION "Use new ImuFactor with integration on tangent space" ON) +option(BUILD_SHARED_LIBS "Build shared gtsam library, instead of static" ON) +option(GTSAM_USE_QUATERNIONS "Enable/Disable using an internal Quaternion representation for rotations instead of rotation matrices. If enable, Rot3::EXPMAP is enforced by default." OFF) +option(GTSAM_POSE3_EXPMAP "Enable/Disable using Pose3::EXPMAP as the default mode. If disabled, Pose3::FIRST_ORDER will be used." ON) +option(GTSAM_ROT3_EXPMAP "Ignore if GTSAM_USE_QUATERNIONS is OFF (Rot3::EXPMAP by default). Otherwise, enable Rot3::EXPMAP, or if disabled, use Rot3::CAYLEY." ON) +option(GTSAM_ENABLE_CONSISTENCY_CHECKS "Enable/Disable expensive consistency checks" OFF) +option(GTSAM_WITH_TBB "Use Intel Threaded Building Blocks (TBB) if available" ON) +option(GTSAM_WITH_EIGEN_MKL "Eigen will use Intel MKL if available" OFF) +option(GTSAM_WITH_EIGEN_MKL_OPENMP "Eigen, when using Intel MKL, will also use OpenMP for multithreading if available" OFF) +option(GTSAM_THROW_CHEIRALITY_EXCEPTION "Throw exception when a triangulated point is behind a camera" ON) +option(GTSAM_BUILD_PYTHON "Enable/Disable building & installation of Python module with pybind11" OFF) +option(GTSAM_INSTALL_MATLAB_TOOLBOX "Enable/Disable installation of matlab toolbox" OFF) +option(GTSAM_ALLOW_DEPRECATED_SINCE_V42 "Allow use of methods/functions deprecated in GTSAM 4.1" ON) +option(GTSAM_SUPPORT_NESTED_DISSECTION "Support Metis-based nested dissection" ON) +option(GTSAM_TANGENT_PREINTEGRATION "Use new ImuFactor with integration on tangent space" ON) +option(GTSAM_SLOW_BUT_CORRECT_BETWEENFACTOR "Use the slower but correct version of BetweenFactor" OFF) if(NOT MSVC AND NOT XCODE_VERSION) option(GTSAM_BUILD_WITH_CCACHE "Use ccache compiler cache" ON) endif() diff --git a/cmake/HandleMetis.cmake b/cmake/HandleMetis.cmake index 9c29e5776..5cbec4ff5 100644 --- a/cmake/HandleMetis.cmake +++ b/cmake/HandleMetis.cmake @@ -21,7 +21,12 @@ if(GTSAM_USE_SYSTEM_METIS) mark_as_advanced(METIS_LIBRARY) add_library(metis-gtsam-if INTERFACE) - target_include_directories(metis-gtsam-if BEFORE INTERFACE ${METIS_INCLUDE_DIR}) + target_include_directories(metis-gtsam-if BEFORE INTERFACE ${METIS_INCLUDE_DIR} + # gtsam_unstable/partition/FindSeparator-inl.h uses internal metislib.h API + # via extern "C" + $ + $ + ) target_link_libraries(metis-gtsam-if INTERFACE ${METIS_LIBRARY}) endif() else() @@ -30,10 +35,12 @@ else() add_subdirectory(${GTSAM_SOURCE_DIR}/gtsam/3rdparty/metis) target_include_directories(metis-gtsam BEFORE PUBLIC + $ $ + # gtsam_unstable/partition/FindSeparator-inl.h uses internal metislib.h API + # via extern "C" $ $ - $ ) add_library(metis-gtsam-if INTERFACE) diff --git a/cmake/HandlePrintConfiguration.cmake b/cmake/HandlePrintConfiguration.cmake index ad6ac5c5c..43ee5b57b 100644 --- a/cmake/HandlePrintConfiguration.cmake +++ b/cmake/HandlePrintConfiguration.cmake @@ -86,7 +86,7 @@ print_enabled_config(${GTSAM_USE_QUATERNIONS} "Quaternions as defaul print_enabled_config(${GTSAM_ENABLE_CONSISTENCY_CHECKS} "Runtime consistency checking ") print_enabled_config(${GTSAM_ROT3_EXPMAP} "Rot3 retract is full ExpMap ") print_enabled_config(${GTSAM_POSE3_EXPMAP} "Pose3 retract is full ExpMap ") -print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V41} "Allow features deprecated in GTSAM 4.1") +print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V42} "Allow features deprecated in GTSAM 4.1") print_enabled_config(${GTSAM_SUPPORT_NESTED_DISSECTION} "Metis-based Nested Dissection ") print_enabled_config(${GTSAM_TANGENT_PREINTEGRATION} "Use tangent-space preintegration") diff --git a/cmake/HandleTBB.cmake b/cmake/HandleTBB.cmake index 118dc4dac..52ee75494 100644 --- a/cmake/HandleTBB.cmake +++ b/cmake/HandleTBB.cmake @@ -7,9 +7,9 @@ if (GTSAM_WITH_TBB) if(TBB_FOUND) set(GTSAM_USE_TBB 1) # This will go into config.h - if ((${TBB_VERSION} VERSION_GREATER "2021.1") OR (${TBB_VERSION} VERSION_EQUAL "2021.1")) - message(FATAL_ERROR "TBB version greater than 2021.1 (oneTBB API) is not yet supported. Use an older version instead.") - endif() +# if ((${TBB_VERSION} VERSION_GREATER "2021.1") OR (${TBB_VERSION} VERSION_EQUAL "2021.1")) +# message(FATAL_ERROR "TBB version greater than 2021.1 (oneTBB API) is not yet supported. Use an older version instead.") +# endif() if ((${TBB_VERSION_MAJOR} GREATER 2020) OR (${TBB_VERSION_MAJOR} EQUAL 2020)) set(TBB_GREATER_EQUAL_2020 1) diff --git a/doc/Code/LocalizationExample2.cpp b/doc/Code/LocalizationExample2.cpp index d22180314..df9469a64 100644 --- a/doc/Code/LocalizationExample2.cpp +++ b/doc/Code/LocalizationExample2.cpp @@ -1,7 +1,7 @@ // add unary measurement factors, like GPS, on all three poses -noiseModel::Diagonal::shared_ptr unaryNoise = +auto unaryNoise = noiseModel::Diagonal::Sigmas(Vector2(0.1, 0.1)); // 10cm std on x,y -graph.add(boost::make_shared(1, 0.0, 0.0, unaryNoise)); -graph.add(boost::make_shared(2, 2.0, 0.0, unaryNoise)); -graph.add(boost::make_shared(3, 4.0, 0.0, unaryNoise)); +graph.emplace_shared(1, 0.0, 0.0, unaryNoise); +graph.emplace_shared(2, 2.0, 0.0, unaryNoise); +graph.emplace_shared(3, 4.0, 0.0, unaryNoise); diff --git a/doc/Code/LocalizationFactor.cpp b/doc/Code/LocalizationFactor.cpp index d298091dc..2c1f01c43 100644 --- a/doc/Code/LocalizationFactor.cpp +++ b/doc/Code/LocalizationFactor.cpp @@ -1,13 +1,12 @@ class UnaryFactor: public NoiseModelFactor1 { double mx_, my_; ///< X and Y measurements - + public: UnaryFactor(Key j, double x, double y, const SharedNoiseModel& model): NoiseModelFactor1(model, j), mx_(x), my_(y) {} - Vector evaluateError(const Pose2& q, - boost::optional H = boost::none) const - { + Vector evaluateError(const Pose2& q, + boost::optional H = boost::none) const override { const Rot2& R = q.rotation(); if (H) (*H) = (gtsam::Matrix(2, 3) << R.c(), -R.s(), 0.0, diff --git a/doc/Code/OdometryExample.cpp b/doc/Code/OdometryExample.cpp index 2befa9dc2..7af27c60c 100644 --- a/doc/Code/OdometryExample.cpp +++ b/doc/Code/OdometryExample.cpp @@ -3,13 +3,11 @@ NonlinearFactorGraph graph; // Add a Gaussian prior on pose x_1 Pose2 priorMean(0.0, 0.0, 0.0); -noiseModel::Diagonal::shared_ptr priorNoise = - noiseModel::Diagonal::Sigmas(Vector3(0.3, 0.3, 0.1)); -graph.addPrior(1, priorMean, priorNoise); +auto priorNoise = noiseModel::Diagonal::Sigmas(Vector3(0.3, 0.3, 0.1)); +graph.add(PriorFactor(1, priorMean, priorNoise)); // Add two odometry factors Pose2 odometry(2.0, 0.0, 0.0); -noiseModel::Diagonal::shared_ptr odometryNoise = - noiseModel::Diagonal::Sigmas(Vector3(0.2, 0.2, 0.1)); +auto odometryNoise = noiseModel::Diagonal::Sigmas(Vector3(0.2, 0.2, 0.1)); graph.add(BetweenFactor(1, 2, odometry, odometryNoise)); graph.add(BetweenFactor(2, 3, odometry, odometryNoise)); diff --git a/doc/Code/OdometryOutput1.txt b/doc/Code/OdometryOutput1.txt index cc34e8ef2..70aba38ee 100644 --- a/doc/Code/OdometryOutput1.txt +++ b/doc/Code/OdometryOutput1.txt @@ -1,11 +1,14 @@ Factor Graph: size: 3 -factor 0: PriorFactor on 1 - prior mean: (0, 0, 0) + +Factor 0: PriorFactor on 1 + prior mean: (0, 0, 0) noise model: diagonal sigmas [0.3; 0.3; 0.1]; -factor 1: BetweenFactor(1,2) - measured: (2, 0, 0) - noise model: diagonal sigmas [0.2; 0.2; 0.1]; -factor 2: BetweenFactor(2,3) - measured: (2, 0, 0) + +Factor 1: BetweenFactor(1,2) + measured: (2, 0, 0) noise model: diagonal sigmas [0.2; 0.2; 0.1]; + +Factor 2: BetweenFactor(2,3) + measured: (2, 0, 0) + noise model: diagonal sigmas [0.2; 0.2; 0.1]; \ No newline at end of file diff --git a/doc/Code/OdometryOutput2.txt b/doc/Code/OdometryOutput2.txt index acfa0b95d..6567bea6c 100644 --- a/doc/Code/OdometryOutput2.txt +++ b/doc/Code/OdometryOutput2.txt @@ -1,11 +1,23 @@ Initial Estimate: + Values with 3 values: -Value 1: (0.5, 0, 0.2) -Value 2: (2.3, 0.1, -0.2) -Value 3: (4.1, 0.1, 0.1) +Value 1: (gtsam::Pose2) +(0.5, 0, 0.2) + +Value 2: (gtsam::Pose2) +(2.3, 0.1, -0.2) + +Value 3: (gtsam::Pose2) +(4.1, 0.1, 0.1) Final Result: + Values with 3 values: -Value 1: (-1.8e-16, 8.7e-18, -9.1e-19) -Value 2: (2, 7.4e-18, -2.5e-18) -Value 3: (4, -1.8e-18, -3.1e-18) +Value 1: (gtsam::Pose2) +(7.5-16, -5.3-16, -1.8-16) + +Value 2: (gtsam::Pose2) +(2, -1.1-15, -2.5-16) + +Value 3: (gtsam::Pose2) +(4, -1.7-15, -2.5-16) diff --git a/doc/Code/OdometryOutput3.txt b/doc/Code/OdometryOutput3.txt index e346ccb4d..514d804cd 100644 --- a/doc/Code/OdometryOutput3.txt +++ b/doc/Code/OdometryOutput3.txt @@ -1,12 +1,12 @@ x1 covariance: - 0.09 1.1e-47 5.7e-33 - 1.1e-47 0.09 1.9e-17 - 5.7e-33 1.9e-17 0.01 + 0.09 1.7e-33 2.8e-33 +1.7e-33 0.09 2.6e-17 +2.8e-33 2.6e-17 0.01 x2 covariance: - 0.13 4.7e-18 2.4e-18 - 4.7e-18 0.17 0.02 - 2.4e-18 0.02 0.02 + 0.13 1.2e-18 6.1e-19 +1.2e-18 0.17 0.02 +6.1e-19 0.02 0.02 x3 covariance: - 0.17 2.7e-17 8.4e-18 - 2.7e-17 0.37 0.06 - 8.4e-18 0.06 0.03 + 0.17 8.6e-18 2.7e-18 +8.6e-18 0.37 0.06 +2.7e-18 0.06 0.03 \ No newline at end of file diff --git a/doc/Doxyfile.in b/doc/Doxyfile.in index fd7f4e5f6..12193d0be 100644 --- a/doc/Doxyfile.in +++ b/doc/Doxyfile.in @@ -1188,7 +1188,7 @@ USE_MATHJAX = YES # MathJax, but it is strongly recommended to install a local copy of MathJax # before deployment. -MATHJAX_RELPATH = https://cdn.mathjax.org/mathjax/latest +# MATHJAX_RELPATH = https://cdn.mathjax.org/mathjax/latest # The MATHJAX_EXTENSIONS tag can be used to specify one or MathJax extension # names that should be enabled during MathJax rendering. diff --git a/doc/gtsam.lyx b/doc/gtsam.lyx index a5adc2b60..705a84911 100644 --- a/doc/gtsam.lyx +++ b/doc/gtsam.lyx @@ -1,5 +1,5 @@ -#LyX 2.2 created this file. For more info see http://www.lyx.org/ -\lyxformat 508 +#LyX 2.3 created this file. For more info see http://www.lyx.org/ +\lyxformat 544 \begin_document \begin_header \save_transient_properties true @@ -62,6 +62,8 @@ \font_osf false \font_sf_scale 100 100 \font_tt_scale 100 100 +\use_microtype false +\use_dash_ligatures true \graphics default \default_output_format default \output_sync 0 @@ -91,6 +93,7 @@ \suppress_date false \justification true \use_refstyle 0 +\use_minted 0 \index Index \shortcut idx \color #008000 @@ -105,7 +108,10 @@ \tocdepth 3 \paragraph_separation indent \paragraph_indentation default -\quotes_language english +\is_math_indent 0 +\math_numbering_side default +\quotes_style english +\dynamic_quotes 0 \papercolumns 1 \papersides 1 \paperpagestyle default @@ -128,14 +134,10 @@ A Hands-on Introduction \begin_layout Author Frank Dellaert -\begin_inset Newline newline -\end_inset - -Technical Report number GT-RIM-CP&R-2014-XXX \end_layout \begin_layout Date -September 2014 +Updated Last March 2022 \end_layout \begin_layout Standard @@ -148,18 +150,14 @@ filename "common_macros.tex" \end_layout -\begin_layout Section* -Overview -\end_layout - -\begin_layout Standard +\begin_layout Abstract In this document I provide a hands-on introduction to both factor graphs and GTSAM. This is an updated version from the 2012 TR that is tailored to our GTSAM - 3.0 library and beyond. + 4.0 library and beyond. \end_layout -\begin_layout Standard +\begin_layout Abstract \series bold Factor graphs @@ -168,6 +166,7 @@ Factor graphs \begin_inset CommandInset citation LatexCommand citep key "Koller09book" +literal "true" \end_inset @@ -199,7 +198,7 @@ ts or prior knowledge. robotics and vision. \end_layout -\begin_layout Standard +\begin_layout Abstract The GTSAM toolbox (GTSAM stands for \begin_inset Quotes eld \end_inset @@ -214,11 +213,13 @@ Georgia Tech Smoothing and Mapping It provides state of the art solutions to the SLAM and SFM problems, but can also be used to model and solve both simpler and more complex estimation problems. - It also provides a MATLAB interface which allows for rapid prototype developmen -t, visualization, and user interaction. + It also provides MATLAB and Python wrappers which allow for rapid prototype + development, visualization, and user interaction. + In addition, it is easy to use in Jupyter notebooks and/or Google's coLaborator +y. \end_layout -\begin_layout Standard +\begin_layout Abstract GTSAM exploits sparsity to be computationally efficient. Typically measurements only provide information on the relationship between a handful of variables, and hence the resulting factor graph will be sparsely @@ -229,14 +230,17 @@ l complexity. GTSAM provides iterative methods that are quite efficient regardless. \end_layout -\begin_layout Standard -You can download the latest version of GTSAM at +\begin_layout Abstract +You can download the latest version of GTSAM from GitHub at +\end_layout + +\begin_layout Abstract \begin_inset Flex URL status open \begin_layout Plain Layout -http://tinyurl.com/gtsam +https://github.com/borglab/gtsam \end_layout \end_inset @@ -270,6 +274,7 @@ Let us start with a one-page primer on factor graphs, which in no way replaces \begin_inset CommandInset citation LatexCommand citet key "Kschischang01it" +literal "true" \end_inset @@ -277,6 +282,7 @@ key "Kschischang01it" \begin_inset CommandInset citation LatexCommand citet key "Loeliger04spm" +literal "true" \end_inset @@ -732,7 +738,7 @@ noindent \begin_inset Formula $f_{0}(x_{1})$ \end_inset - on lines 5-8 as an instance of + on lines 5-7 as an instance of \series bold \emph on PriorFactor @@ -764,7 +770,7 @@ Pose2, noiseModel::Diagonal \series default \emph default - by specifying three standard deviations in line 7, respectively 30 cm. + by specifying three standard deviations in line 6, respectively 30 cm. \begin_inset space ~ \end_inset @@ -786,7 +792,7 @@ Similarly, odometry measurements are specified as Pose2 \series default \emph default - on line 11, with a slightly different noise model defined on line 12-13. + on line 10, with a slightly different noise model defined on line 11. We then add the two factors \begin_inset Formula $f_{1}(x_{1},x_{2};o_{1})$ \end_inset @@ -795,7 +801,7 @@ Pose2 \begin_inset Formula $f_{2}(x_{2},x_{3};o_{2})$ \end_inset - on lines 14-15, as instances of yet another templated class, + on lines 12-13, as instances of yet another templated class, \series bold \emph on BetweenFactor @@ -866,7 +872,7 @@ smoothing and mapping . Later in this document we will talk about how we can also use GTSAM to - do filtering (which you often do + do filtering (which often you do \emph on not \emph default @@ -919,7 +925,11 @@ Values \begin_layout Standard The latter point is often a point of confusion with beginning users of GTSAM. It helps to remember that when designing GTSAM we took a functional approach - of classes corresponding to mathematical objects, which are usually immutable. + of classes corresponding to mathematical objects, which are usually +\emph on +immutable +\emph default +. You should think of a factor graph as a \emph on function @@ -1018,7 +1028,7 @@ NonlinearFactorGraph \end_layout \begin_layout Standard -The relevant output from running the example is as follows: +The relevant output from running the example is as follows: \family typewriter \size small @@ -1321,6 +1331,7 @@ r in a pre-existing map, or indeed the presence of absence of ceiling lights \begin_inset CommandInset citation LatexCommand citet key "Dellaert99b" +literal "true" \end_inset @@ -1353,14 +1364,18 @@ where \end_inset is the measurement, -\begin_inset Formula $q$ +\begin_inset Formula $q\in SE(2)$ \end_inset is the unknown variable, \begin_inset Formula $h(q)$ \end_inset - is a (possibly nonlinear) measurement function, and + is a +\series bold +measurement function +\series default +, and \begin_inset Formula $\Sigma$ \end_inset @@ -1536,12 +1551,13 @@ E(q)\define h(q)-m \end_inset -which is done on line 12. +which is done on line 14. Importantly, because we want to use this factor for nonlinear optimization (see e.g., \begin_inset CommandInset citation LatexCommand citealt key "Dellaert06ijrr" +literal "true" \end_inset @@ -1588,11 +1604,11 @@ q_{y} \begin_inset Formula $q=\left(q_{x},q_{y},q_{\theta}\right)$ \end_inset -, yields the following simple +, yields the following \begin_inset Formula $2\times3$ \end_inset - matrix in tangent space which is the same the as the rotation matrix: + matrix: \end_layout \begin_layout Standard @@ -1607,6 +1623,171 @@ H=\left[\begin{array}{ccc} \end_inset +\end_layout + +\begin_layout Paragraph* +Important Note +\end_layout + +\begin_layout Standard +Many of our users, when attempting to create a custom factor, are initially + surprised at the Jacobian matrix not agreeing with their intuition. + For example, above you might simply expect a +\begin_inset Formula $2\times3$ +\end_inset + + identity matrix. + This +\emph on +would +\emph default + be true for variables belonging to a vector space. + However, in GTSAM we define the Jacobian more generally to be the matrix + +\begin_inset Formula $H$ +\end_inset + + such that +\begin_inset Formula +\[ +h(q\exp\hat{\xi})\approx h(q)+H\xi +\] + +\end_inset + +where +\begin_inset Formula $\xi=(\delta x,\delta y,\delta\theta)$ +\end_inset + + is an incremental update and +\begin_inset Formula $\exp\hat{\xi}$ +\end_inset + + is the +\series bold +exponential map +\series default + for the variable we want to update. + In this case +\begin_inset Formula $q\in SE(2)$ +\end_inset + +, where +\begin_inset Formula $SE(2)$ +\end_inset + + is the group of 2D rigid transforms, implemented by +\series bold +\emph on +Pose2 +\emph default +. + +\series default +The exponential map for +\begin_inset Formula $SE(2)$ +\end_inset + + can be approximated to first order as +\begin_inset Formula +\[ +\exp\hat{\xi}\approx\left[\begin{array}{ccc} +1 & -\delta\theta & \delta x\\ +\delta\theta & 1 & \delta y\\ +0 & 0 & 1 +\end{array}\right] +\] + +\end_inset + +when using the +\begin_inset Formula $3\times3$ +\end_inset + + matrix representation for 2D poses, and hence +\begin_inset Formula +\[ +h(qe^{\hat{\xi}})\approx h\left(\left[\begin{array}{ccc} +\cos(q_{\theta}) & -\sin(q_{\theta}) & q_{x}\\ +\sin(q_{\theta}) & \cos(q_{\theta}) & q_{y}\\ +0 & 0 & 1 +\end{array}\right]\left[\begin{array}{ccc} +1 & -\delta\theta & \delta x\\ +\delta\theta & 1 & \delta y\\ +0 & 0 & 1 +\end{array}\right]\right)=\left[\begin{array}{c} +q_{x}+\cos(q_{\theta})\delta x-\sin(q_{\theta})\delta y\\ +q_{y}+\sin(q_{\theta})\delta x+\cos(q_{\theta})\delta y +\end{array}\right] +\] + +\end_inset + +which then explains the Jacobian +\begin_inset Formula $H$ +\end_inset + +. +\end_layout + +\begin_layout Standard +Lie groups are very relevant in the robotics context, and you can read more + here: +\end_layout + +\begin_layout Itemize +\begin_inset Flex URL +status open + +\begin_layout Plain Layout + +https://github.com/borglab/gtsam/blob/develop/doc/LieGroups.pdf +\end_layout + +\end_inset + + +\end_layout + +\begin_layout Itemize +\begin_inset Flex URL +status open + +\begin_layout Plain Layout + +https://github.com/borglab/gtsam/blob/develop/doc/math.pdf +\end_layout + +\end_inset + + +\end_layout + +\begin_layout Standard +In some cases you want to go even beyond Lie groups to a looser concept, + +\series bold +manifolds +\series default +, because not all unknown variables behave like a group, e.g., the space of + 3D planes, 2D lines, directions in space, etc. + For manifolds we do not always have an exponential map, but we have a retractio +n that plays the same role. + Some of this is explained here: +\end_layout + +\begin_layout Itemize +\begin_inset Flex URL +status open + +\begin_layout Plain Layout + +https://gtsam.org/notes/GTSAM-Concepts.html +\end_layout + +\end_inset + + \end_layout \begin_layout Subsection @@ -1669,13 +1850,13 @@ UnaryFactor \series default \emph default instances, and add them to graph. - GTSAM uses shared pointers to refer to factors in factor graphs, and + GTSAM uses shared pointers to refer to factors, and \series bold \emph on -boost::make_shared +emplace_shared \series default \emph default - is a convenience function to simultaneously construct a class and create + is a convenience method to simultaneously construct a class and create a \series bold \emph on @@ -1683,22 +1864,6 @@ shared_ptr \series default \emph default to it. - -\begin_inset Note Note -status collapsed - -\begin_layout Plain Layout -and on lines 4-6 we add three newly created -\series bold -\emph on -UnaryFactor -\series default -\emph default - instances to the graph. -\end_layout - -\end_inset - We obtain the factor graph from Figure \begin_inset CommandInset ref LatexCommand vref @@ -1936,8 +2101,8 @@ reference "fig:CompareMarginals" \end_inset -, where I show the marginals on position as covariance ellipses that contain - 68.26% of all probability mass. +, where I show the marginals on position as 5-sigma covariance ellipses + that contain 99.9996% of all probability mass. For the odometry marginals, it is immediately apparent from the figure that (1) the uncertainty on pose keeps growing, and (2) the uncertainty on angular odometry translates into increasing uncertainty on y. @@ -1992,6 +2157,7 @@ PoseSLAM \begin_inset CommandInset citation LatexCommand citep key "DurrantWhyte06ram" +literal "true" \end_inset @@ -2190,9 +2356,9 @@ reference "fig:example" \end_inset , along with covariance ellipses shown in green. - These covariance ellipses in 2D indicate the marginal over position, over - all possible orientations, and show the area which contain 68.26% of the - probability mass (in 1D this would correspond to one standard deviation). + These 5-sigma covariance ellipses in 2D indicate the marginal over position, + over all possible orientations, and show the area which contain 99.9996% + of the probability mass. The graph shows in a clear manner that the uncertainty on pose \begin_inset Formula $x_{5}$ \end_inset @@ -3076,6 +3242,7 @@ reference "fig:Victoria-1" \begin_inset CommandInset citation LatexCommand citep key "Kaess09ras" +literal "true" \end_inset @@ -3088,6 +3255,7 @@ key "Kaess09ras" \begin_inset CommandInset citation LatexCommand citep key "Kaess08tro" +literal "true" \end_inset @@ -3355,6 +3523,7 @@ iSAM \begin_inset CommandInset citation LatexCommand citet key "Kaess08tro,Kaess12ijrr" +literal "true" \end_inset @@ -3606,6 +3775,7 @@ subgraph preconditioning \begin_inset CommandInset citation LatexCommand citet key "Dellaert10iros,Jian11iccv" +literal "true" \end_inset @@ -3638,6 +3808,7 @@ Visual Odometry \begin_inset CommandInset citation LatexCommand citet key "Nister04cvpr2" +literal "true" \end_inset @@ -3661,6 +3832,7 @@ Visual SLAM \begin_inset CommandInset citation LatexCommand citet key "Davison03iccv" +literal "true" \end_inset @@ -3711,6 +3883,7 @@ Filtering \begin_inset CommandInset citation LatexCommand citep key "Smith87b" +literal "true" \end_inset diff --git a/doc/gtsam.pdf b/doc/gtsam.pdf index c6a39a79c..961c808d0 100644 Binary files a/doc/gtsam.pdf and b/doc/gtsam.pdf differ diff --git a/doc/math.lyx b/doc/math.lyx index 2533822a7..86ed2b220 100644 --- a/doc/math.lyx +++ b/doc/math.lyx @@ -2668,7 +2668,7 @@ reference "eq:pushforward" \begin{eqnarray*} \varphi(a)e^{\yhat} & = & \varphi(ae^{\xhat})\\ a^{-1}e^{\yhat} & = & \left(ae^{\xhat}\right)^{-1}\\ -e^{\yhat} & = & -ae^{\xhat}a^{-1}\\ +e^{\yhat} & = & ae^{-\xhat}a^{-1}\\ \yhat & = & -\Ad a\xhat \end{eqnarray*} @@ -3003,8 +3003,8 @@ between \begin_inset Formula \begin{align} \varphi(g,h)e^{\yhat} & =\varphi(ge^{\xhat},h)\nonumber \\ -g^{-1}he^{\yhat} & =\left(ge^{\xhat}\right)^{-1}h=-e^{\xhat}g^{-1}h\nonumber \\ -e^{\yhat} & =-\left(h^{-1}g\right)e^{\xhat}\left(h^{-1}g\right)^{-1}=-\exp\Ad{\left(h^{-1}g\right)}\xhat\nonumber \\ +g^{-1}he^{\yhat} & =\left(ge^{\xhat}\right)^{-1}h=e^{-\xhat}g^{-1}h\nonumber \\ +e^{\yhat} & =\left(h^{-1}g\right)e^{-\xhat}\left(h^{-1}g\right)^{-1}=\exp\Ad{\left(h^{-1}g\right)}(-\xhat)\nonumber \\ \yhat & =-\Ad{\left(h^{-1}g\right)}\xhat=-\Ad{\varphi\left(h,g\right)}\xhat\label{eq:Dbetween1} \end{align} @@ -5082,6 +5082,394 @@ reference "ex:projection" \end_inset +\end_layout + +\begin_layout Subsection +Derivative of Adjoint +\begin_inset CommandInset label +LatexCommand label +name "subsec:pose3_adjoint_deriv" + +\end_inset + + +\end_layout + +\begin_layout Standard +Consider +\begin_inset Formula $f:SE(3)\times\mathbb{R}^{6}\rightarrow\mathbb{R}^{6}$ +\end_inset + + is defined as +\begin_inset Formula $f(T,\xi_{b})=Ad_{T}\hat{\xi}_{b}$ +\end_inset + +. + The derivative is notated (see Section +\begin_inset CommandInset ref +LatexCommand ref +reference "sec:Derivatives-of-Actions" +plural "false" +caps "false" +noprefix "false" + +\end_inset + +): +\end_layout + +\begin_layout Standard +\begin_inset Formula +\[ +Df_{(T,\xi_{b})}(\xi,\delta\xi_{b})=D_{1}f_{(T,\xi_{b})}(\xi)+D_{2}f_{(T,\xi_{b})}(\delta\xi_{b}) +\] + +\end_inset + +First, computing +\begin_inset Formula $D_{2}f_{(T,\xi_{b})}(\xi_{b})$ +\end_inset + + is easy, as its matrix is simply +\begin_inset Formula $Ad_{T}$ +\end_inset + +: +\end_layout + +\begin_layout Standard +\begin_inset Formula +\[ +f(T,\xi_{b}+\delta\xi_{b})=Ad_{T}(\widehat{\xi_{b}+\delta\xi_{b}})=Ad_{T}(\hat{\xi}_{b})+Ad_{T}(\delta\hat{\xi}_{b}) +\] + +\end_inset + + +\end_layout + +\begin_layout Standard +\begin_inset Formula +\[ +D_{2}f_{(T,\xi_{b})}(\xi_{b})=Ad_{T} +\] + +\end_inset + +We will derive +\begin_inset Formula $D_{1}f_{(T,\xi_{b})}(\xi)$ +\end_inset + + using two approaches. + In the first, we'll define +\begin_inset Formula $g(T,\xi)\triangleq T\exp\hat{\xi}$ +\end_inset + +. + From Section +\begin_inset CommandInset ref +LatexCommand ref +reference "sec:Derivatives-of-Actions" +plural "false" +caps "false" +noprefix "false" + +\end_inset + +, +\end_layout + +\begin_layout Standard +\begin_inset Formula +\begin{align*} +D_{2}g_{(T,\xi)}(\xi) & =T\hat{\xi}\\ +D_{2}g_{(T,\xi)}^{-1}(\xi) & =-\hat{\xi}T^{-1} +\end{align*} + +\end_inset + +Now we can use the definition of the Adjoint representation +\begin_inset Formula $Ad_{g}\hat{\xi}=g\hat{\xi}g^{-1}$ +\end_inset + + (aka conjugation by +\begin_inset Formula $g$ +\end_inset + +) then apply product rule and simplify: +\end_layout + +\begin_layout Standard +\begin_inset Formula +\begin{align*} +D_{1}f_{(T,\xi_{b})}(\xi)=D_{1}\left(Ad_{T\exp(\hat{\xi})}\hat{\xi}_{b}\right)(\xi) & =D_{1}\left(g\hat{\xi}_{b}g^{-1}\right)(\xi)\\ + & =\left(D_{2}g_{(T,\xi)}(\xi)\right)\hat{\xi}_{b}g^{-1}(T,0)+g(T,0)\hat{\xi}_{b}\left(D_{2}g_{(T,\xi)}^{-1}(\xi)\right)\\ + & =T\hat{\xi}\hat{\xi}_{b}T^{-1}-T\hat{\xi}_{b}\hat{\xi}T^{-1}\\ + & =T\left(\hat{\xi}\hat{\xi}_{b}-\hat{\xi}_{b}\hat{\xi}\right)T^{-1}\\ + & =Ad_{T}(ad_{\hat{\xi}}\hat{\xi}_{b})\\ + & =-Ad_{T}(ad_{\hat{\xi}_{b}}\hat{\xi})\\ +D_{1}F_{(T,\xi_{b})} & =-(Ad_{T})(ad_{\hat{\xi}_{b}}) +\end{align*} + +\end_inset + +Where +\begin_inset Formula $ad_{\hat{\xi}}:\mathfrak{g}\rightarrow\mathfrak{g}$ +\end_inset + + is the adjoint map of the lie algebra. +\end_layout + +\begin_layout Standard +The second, perhaps more intuitive way of deriving +\begin_inset Formula $D_{1}f_{(T,\xi_{b})}(\xi_{b})$ +\end_inset + +, would be to use the fact that the derivative at the origin +\begin_inset Formula $D_{1}Ad_{I}\hat{\xi}_{b}=ad_{\hat{\xi}_{b}}$ +\end_inset + + by definition of the adjoint +\begin_inset Formula $ad_{\xi}$ +\end_inset + +. + Then applying the property +\begin_inset Formula $Ad_{AB}=Ad_{A}Ad_{B}$ +\end_inset + +, +\end_layout + +\begin_layout Standard +\begin_inset Formula +\[ +D_{1}Ad_{T}\hat{\xi}_{b}(\xi)=D_{1}Ad_{T*I}\hat{\xi}_{b}(\xi)=Ad_{T}\left(D_{1}Ad_{I}\hat{\xi}_{b}(\xi)\right)=Ad_{T}\left(ad_{\hat{\xi}}(\hat{\xi}_{b})\right)=-Ad_{T}\left(ad_{\hat{\xi}_{b}}(\hat{\xi})\right) +\] + +\end_inset + + +\end_layout + +\begin_layout Subsection +Derivative of AdjointTranspose +\end_layout + +\begin_layout Standard +The transpose of the Adjoint, +\family roman +\series medium +\shape up +\size normal +\emph off +\bar no +\strikeout off +\xout off +\uuline off +\uwave off +\noun off +\color none + +\begin_inset Formula $Ad_{T}^{T}:\mathfrak{g^{*}\rightarrow g^{*}}$ +\end_inset + +, is useful as a way to change the reference frame of vectors in the dual + space +\family default +\series default +\shape default +\size default +\emph default +\bar default +\strikeout default +\xout default +\uuline default +\uwave default +\noun default +\color inherit +(note the +\begin_inset Formula $^{*}$ +\end_inset + + denoting that we are now in the dual space) +\family roman +\series medium +\shape up +\size normal +\emph off +\bar no +\strikeout off +\xout off +\uuline off +\uwave off +\noun off +\color none +. + To be more concrete, where +\family default +\series default +\shape default +\size default +\emph default +\bar default +\strikeout default +\xout default +\uuline default +\uwave default +\noun default +\color inherit +as +\begin_inset Formula $Ad_{T}\hat{\xi}_{b}$ +\end_inset + + converts the +\emph on +twist +\emph default + +\family roman +\series medium +\shape up +\size normal +\emph off +\bar no +\strikeout off +\xout off +\uuline off +\uwave off +\noun off +\color none + +\begin_inset Formula $\xi_{b}$ +\end_inset + + from the +\begin_inset Formula $T$ +\end_inset + + frame, +\family default +\series default +\shape default +\size default +\emph default +\bar default +\strikeout default +\xout default +\uuline default +\uwave default +\noun default +\color inherit + +\family roman +\series medium +\shape up +\size normal +\emph off +\bar no +\strikeout off +\xout off +\uuline off +\uwave off +\noun off +\color none + +\begin_inset Formula $Ad_{T}^{T}\hat{\xi}_{b}^{*}$ +\end_inset + + converts the +\family default +\series default +\shape default +\size default +\emph on +\bar default +\strikeout default +\xout default +\uuline default +\uwave default +\noun default +\color inherit +wrench +\emph default + +\family roman +\series medium +\shape up +\size normal +\emph off +\bar no +\strikeout off +\xout off +\uuline off +\uwave off +\noun off +\color none + +\begin_inset Formula $\xi_{b}^{*}$ +\end_inset + + from the +\begin_inset Formula $T$ +\end_inset + + frame +\family default +\series default +\shape default +\size default +\emph default +\bar default +\strikeout default +\xout default +\uuline default +\uwave default +\noun default +\color inherit +. + It's difficult to apply a similar derivation as in Section +\begin_inset CommandInset ref +LatexCommand ref +reference "subsec:pose3_adjoint_deriv" +plural "false" +caps "false" +noprefix "false" + +\end_inset + + for the derivative of +\begin_inset Formula $Ad_{T}^{T}\hat{\xi}_{b}^{*}$ +\end_inset + + because +\begin_inset Formula $Ad_{T}^{T}$ +\end_inset + + cannot be naturally defined as a conjugation, so we resort to crunching + through the algebra. + The details are omitted but the result is a form that vaguely resembles + (but does not exactly match) +\begin_inset Formula $ad(Ad_{T}^{T}\hat{\xi}_{b}^{*})$ +\end_inset + +: +\end_layout + +\begin_layout Standard +\begin_inset Formula +\begin{align*} +\begin{bmatrix}\omega_{T}\\ +v_{T} +\end{bmatrix}^{*} & \triangleq Ad_{T}^{T}\hat{\xi}_{b}^{*}\\ +D_{1}Ad_{T}^{T}\hat{\xi}_{b}^{*}(\xi) & =\begin{bmatrix}\hat{\omega}_{T} & \hat{v}_{T}\\ +\hat{v}_{T} & 0 +\end{bmatrix} +\end{align*} + +\end_inset + + \end_layout \begin_layout Subsection @@ -6286,7 +6674,7 @@ One representation of a line is through 2 vectors \begin_inset Formula $d$ \end_inset - points from the orgin to the closest point on the line. + points from the origin to the closest point on the line. \end_layout \begin_layout Standard diff --git a/doc/math.pdf b/doc/math.pdf index 8dc7270f1..71533e1e8 100644 Binary files a/doc/math.pdf and b/doc/math.pdf differ diff --git a/examples/Data/randomGrid3D.xml b/examples/Data/randomGrid3D.xml index 6a82ce31c..42eb473be 100644 --- a/examples/Data/randomGrid3D.xml +++ b/examples/Data/randomGrid3D.xml @@ -7,7 +7,7 @@ 32 1 - + diff --git a/examples/Data/toy3D.xml b/examples/Data/toy3D.xml index 13dbcbe6c..26bd231ca 100644 --- a/examples/Data/toy3D.xml +++ b/examples/Data/toy3D.xml @@ -7,7 +7,7 @@ 2 1 - + diff --git a/examples/DiscreteBayesNetExample.cpp b/examples/DiscreteBayesNetExample.cpp index 5dca116c3..dfd7beb63 100644 --- a/examples/DiscreteBayesNetExample.cpp +++ b/examples/DiscreteBayesNetExample.cpp @@ -53,11 +53,10 @@ int main(int argc, char **argv) { // Create solver and eliminate Ordering ordering; ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7); - DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering); // solve - DiscreteFactor::sharedValues mpe = chordal->optimize(); - GTSAM_PRINT(*mpe); + auto mpe = fg.optimize(); + GTSAM_PRINT(mpe); // We can also build a Bayes tree (directed junction tree). // The elimination order above will do fine: @@ -69,15 +68,15 @@ int main(int argc, char **argv) { fg.add(Dyspnea, "0 1"); // solve again, now with evidence - DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering); - DiscreteFactor::sharedValues mpe2 = chordal2->optimize(); - GTSAM_PRINT(*mpe2); + auto mpe2 = fg.optimize(); + GTSAM_PRINT(mpe2); // We can also sample from it + DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering); cout << "\n10 samples:" << endl; for (size_t i = 0; i < 10; i++) { - DiscreteFactor::sharedValues sample = chordal2->sample(); - GTSAM_PRINT(*sample); + auto sample = chordal->sample(); + GTSAM_PRINT(sample); } return 0; } diff --git a/examples/DiscreteBayesNet_FG.cpp b/examples/DiscreteBayesNet_FG.cpp index 121df4bef..88904001a 100644 --- a/examples/DiscreteBayesNet_FG.cpp +++ b/examples/DiscreteBayesNet_FG.cpp @@ -33,11 +33,11 @@ using namespace gtsam; int main(int argc, char **argv) { // Define keys and a print function Key C(1), S(2), R(3), W(4); - auto print = [=](DiscreteFactor::sharedValues values) { - cout << boolalpha << "Cloudy = " << static_cast((*values)[C]) - << " Sprinkler = " << static_cast((*values)[S]) - << " Rain = " << boolalpha << static_cast((*values)[R]) - << " WetGrass = " << static_cast((*values)[W]) << endl; + auto print = [=](const DiscreteFactor::Values& values) { + cout << boolalpha << "Cloudy = " << static_cast(values.at(C)) + << " Sprinkler = " << static_cast(values.at(S)) + << " Rain = " << boolalpha << static_cast(values.at(R)) + << " WetGrass = " << static_cast(values.at(W)) << endl; }; // We assume binary state variables @@ -85,7 +85,7 @@ int main(int argc, char **argv) { } // "Most Probable Explanation", i.e., configuration with largest value - DiscreteFactor::sharedValues mpe = graph.eliminateSequential()->optimize(); + auto mpe = graph.optimize(); cout << "\nMost Probable Explanation (MPE):" << endl; print(mpe); @@ -96,8 +96,7 @@ int main(int argc, char **argv) { graph.add(Cloudy, "1 0"); // solve again, now with evidence - DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); - DiscreteFactor::sharedValues mpe_with_evidence = chordal->optimize(); + auto mpe_with_evidence = graph.optimize(); cout << "\nMPE given C=0:" << endl; print(mpe_with_evidence); @@ -110,10 +109,11 @@ int main(int argc, char **argv) { cout << "\nP(W=1|C=0):" << marginals.marginalProbabilities(WetGrass)[1] << endl; - // We can also sample from it + // We can also sample from the eliminated graph + DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); cout << "\n10 samples:" << endl; for (size_t i = 0; i < 10; i++) { - DiscreteFactor::sharedValues sample = chordal->sample(); + auto sample = chordal->sample(); print(sample); } return 0; diff --git a/examples/FisheyeExample.cpp b/examples/FisheyeExample.cpp index 223149299..fc0aed0d7 100644 --- a/examples/FisheyeExample.cpp +++ b/examples/FisheyeExample.cpp @@ -122,8 +122,7 @@ int main(int argc, char *argv[]) { std::cout << "initial error=" << graph.error(initialEstimate) << std::endl; std::cout << "final error=" << graph.error(result) << std::endl; - std::ofstream os("examples/vio_batch.dot"); - graph.saveGraph(os, result); + graph.saveGraph("examples/vio_batch.dot", result); return 0; } diff --git a/examples/HMMExample.cpp b/examples/HMMExample.cpp index ee861e381..3a7673001 100644 --- a/examples/HMMExample.cpp +++ b/examples/HMMExample.cpp @@ -59,21 +59,21 @@ int main(int argc, char **argv) { // Convert to factor graph DiscreteFactorGraph factorGraph(hmm); + // Do max-prodcut + auto mpe = factorGraph.optimize(); + GTSAM_PRINT(mpe); + // Create solver and eliminate // This will create a DAG ordered with arrow of time reversed DiscreteBayesNet::shared_ptr chordal = factorGraph.eliminateSequential(ordering); chordal->print("Eliminated"); - // solve - DiscreteFactor::sharedValues mpe = chordal->optimize(); - GTSAM_PRINT(*mpe); - // We can also sample from it cout << "\n10 samples:" << endl; for (size_t k = 0; k < 10; k++) { - DiscreteFactor::sharedValues sample = chordal->sample(); - GTSAM_PRINT(*sample); + auto sample = chordal->sample(); + GTSAM_PRINT(sample); } // Or compute the marginals. This re-eliminates the FG into a Bayes tree diff --git a/examples/IMUKittiExampleGPS.cpp b/examples/IMUKittiExampleGPS.cpp index e2ca49647..cb60b2516 100644 --- a/examples/IMUKittiExampleGPS.cpp +++ b/examples/IMUKittiExampleGPS.cpp @@ -11,21 +11,23 @@ /** * @file IMUKittiExampleGPS - * @brief Example of application of ISAM2 for GPS-aided navigation on the KITTI VISION BENCHMARK SUITE - * @author Ported by Thomas Jespersen (thomasj@tkjelectronics.dk), TKJ Electronics + * @brief Example of application of ISAM2 for GPS-aided navigation on the KITTI + * VISION BENCHMARK SUITE + * @author Ported by Thomas Jespersen (thomasj@tkjelectronics.dk), TKJ + * Electronics */ // GTSAM related includes. +#include #include #include #include -#include -#include -#include #include #include #include -#include +#include +#include +#include #include #include @@ -34,35 +36,35 @@ using namespace std; using namespace gtsam; -using symbol_shorthand::X; // Pose3 (x,y,z,r,p,y) -using symbol_shorthand::V; // Vel (xdot,ydot,zdot) using symbol_shorthand::B; // Bias (ax,ay,az,gx,gy,gz) +using symbol_shorthand::V; // Vel (xdot,ydot,zdot) +using symbol_shorthand::X; // Pose3 (x,y,z,r,p,y) struct KittiCalibration { - double body_ptx; - double body_pty; - double body_ptz; - double body_prx; - double body_pry; - double body_prz; - double accelerometer_sigma; - double gyroscope_sigma; - double integration_sigma; - double accelerometer_bias_sigma; - double gyroscope_bias_sigma; - double average_delta_t; + double body_ptx; + double body_pty; + double body_ptz; + double body_prx; + double body_pry; + double body_prz; + double accelerometer_sigma; + double gyroscope_sigma; + double integration_sigma; + double accelerometer_bias_sigma; + double gyroscope_bias_sigma; + double average_delta_t; }; struct ImuMeasurement { - double time; - double dt; - Vector3 accelerometer; - Vector3 gyroscope; // omega + double time; + double dt; + Vector3 accelerometer; + Vector3 gyroscope; // omega }; struct GpsMeasurement { - double time; - Vector3 position; // x,y,z + double time; + Vector3 position; // x,y,z }; const string output_filename = "IMUKittiExampleGPSResults.csv"; @@ -70,290 +72,313 @@ const string output_filename = "IMUKittiExampleGPSResults.csv"; void loadKittiData(KittiCalibration& kitti_calibration, vector& imu_measurements, vector& gps_measurements) { - string line; + string line; - // Read IMU metadata and compute relative sensor pose transforms - // BodyPtx BodyPty BodyPtz BodyPrx BodyPry BodyPrz AccelerometerSigma GyroscopeSigma IntegrationSigma - // AccelerometerBiasSigma GyroscopeBiasSigma AverageDeltaT - string imu_metadata_file = findExampleDataFile("KittiEquivBiasedImu_metadata.txt"); - ifstream imu_metadata(imu_metadata_file.c_str()); + // Read IMU metadata and compute relative sensor pose transforms + // BodyPtx BodyPty BodyPtz BodyPrx BodyPry BodyPrz AccelerometerSigma + // GyroscopeSigma IntegrationSigma AccelerometerBiasSigma GyroscopeBiasSigma + // AverageDeltaT + string imu_metadata_file = + findExampleDataFile("KittiEquivBiasedImu_metadata.txt"); + ifstream imu_metadata(imu_metadata_file.c_str()); - printf("-- Reading sensor metadata\n"); + printf("-- Reading sensor metadata\n"); - getline(imu_metadata, line, '\n'); // ignore the first line + getline(imu_metadata, line, '\n'); // ignore the first line - // Load Kitti calibration - getline(imu_metadata, line, '\n'); - sscanf(line.c_str(), "%lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf", - &kitti_calibration.body_ptx, - &kitti_calibration.body_pty, - &kitti_calibration.body_ptz, - &kitti_calibration.body_prx, - &kitti_calibration.body_pry, - &kitti_calibration.body_prz, - &kitti_calibration.accelerometer_sigma, - &kitti_calibration.gyroscope_sigma, - &kitti_calibration.integration_sigma, - &kitti_calibration.accelerometer_bias_sigma, - &kitti_calibration.gyroscope_bias_sigma, - &kitti_calibration.average_delta_t); - printf("IMU metadata: %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf\n", - kitti_calibration.body_ptx, - kitti_calibration.body_pty, - kitti_calibration.body_ptz, - kitti_calibration.body_prx, - kitti_calibration.body_pry, - kitti_calibration.body_prz, - kitti_calibration.accelerometer_sigma, - kitti_calibration.gyroscope_sigma, - kitti_calibration.integration_sigma, - kitti_calibration.accelerometer_bias_sigma, - kitti_calibration.gyroscope_bias_sigma, - kitti_calibration.average_delta_t); + // Load Kitti calibration + getline(imu_metadata, line, '\n'); + sscanf(line.c_str(), "%lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf", + &kitti_calibration.body_ptx, &kitti_calibration.body_pty, + &kitti_calibration.body_ptz, &kitti_calibration.body_prx, + &kitti_calibration.body_pry, &kitti_calibration.body_prz, + &kitti_calibration.accelerometer_sigma, + &kitti_calibration.gyroscope_sigma, + &kitti_calibration.integration_sigma, + &kitti_calibration.accelerometer_bias_sigma, + &kitti_calibration.gyroscope_bias_sigma, + &kitti_calibration.average_delta_t); + printf("IMU metadata: %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf\n", + kitti_calibration.body_ptx, kitti_calibration.body_pty, + kitti_calibration.body_ptz, kitti_calibration.body_prx, + kitti_calibration.body_pry, kitti_calibration.body_prz, + kitti_calibration.accelerometer_sigma, + kitti_calibration.gyroscope_sigma, kitti_calibration.integration_sigma, + kitti_calibration.accelerometer_bias_sigma, + kitti_calibration.gyroscope_bias_sigma, + kitti_calibration.average_delta_t); - // Read IMU data - // Time dt accelX accelY accelZ omegaX omegaY omegaZ - string imu_data_file = findExampleDataFile("KittiEquivBiasedImu.txt"); - printf("-- Reading IMU measurements from file\n"); - { - ifstream imu_data(imu_data_file.c_str()); - getline(imu_data, line, '\n'); // ignore the first line + // Read IMU data + // Time dt accelX accelY accelZ omegaX omegaY omegaZ + string imu_data_file = findExampleDataFile("KittiEquivBiasedImu.txt"); + printf("-- Reading IMU measurements from file\n"); + { + ifstream imu_data(imu_data_file.c_str()); + getline(imu_data, line, '\n'); // ignore the first line - double time = 0, dt = 0, acc_x = 0, acc_y = 0, acc_z = 0, gyro_x = 0, gyro_y = 0, gyro_z = 0; - while (!imu_data.eof()) { - getline(imu_data, line, '\n'); - sscanf(line.c_str(), "%lf %lf %lf %lf %lf %lf %lf %lf", - &time, &dt, - &acc_x, &acc_y, &acc_z, - &gyro_x, &gyro_y, &gyro_z); + double time = 0, dt = 0, acc_x = 0, acc_y = 0, acc_z = 0, gyro_x = 0, + gyro_y = 0, gyro_z = 0; + while (!imu_data.eof()) { + getline(imu_data, line, '\n'); + sscanf(line.c_str(), "%lf %lf %lf %lf %lf %lf %lf %lf", &time, &dt, + &acc_x, &acc_y, &acc_z, &gyro_x, &gyro_y, &gyro_z); - ImuMeasurement measurement; - measurement.time = time; - measurement.dt = dt; - measurement.accelerometer = Vector3(acc_x, acc_y, acc_z); - measurement.gyroscope = Vector3(gyro_x, gyro_y, gyro_z); - imu_measurements.push_back(measurement); - } + ImuMeasurement measurement; + measurement.time = time; + measurement.dt = dt; + measurement.accelerometer = Vector3(acc_x, acc_y, acc_z); + measurement.gyroscope = Vector3(gyro_x, gyro_y, gyro_z); + imu_measurements.push_back(measurement); } + } - // Read GPS data - // Time,X,Y,Z - string gps_data_file = findExampleDataFile("KittiGps_converted.txt"); - printf("-- Reading GPS measurements from file\n"); - { - ifstream gps_data(gps_data_file.c_str()); - getline(gps_data, line, '\n'); // ignore the first line + // Read GPS data + // Time,X,Y,Z + string gps_data_file = findExampleDataFile("KittiGps_converted.txt"); + printf("-- Reading GPS measurements from file\n"); + { + ifstream gps_data(gps_data_file.c_str()); + getline(gps_data, line, '\n'); // ignore the first line - double time = 0, gps_x = 0, gps_y = 0, gps_z = 0; - while (!gps_data.eof()) { - getline(gps_data, line, '\n'); - sscanf(line.c_str(), "%lf,%lf,%lf,%lf", &time, &gps_x, &gps_y, &gps_z); + double time = 0, gps_x = 0, gps_y = 0, gps_z = 0; + while (!gps_data.eof()) { + getline(gps_data, line, '\n'); + sscanf(line.c_str(), "%lf,%lf,%lf,%lf", &time, &gps_x, &gps_y, &gps_z); - GpsMeasurement measurement; - measurement.time = time; - measurement.position = Vector3(gps_x, gps_y, gps_z); - gps_measurements.push_back(measurement); - } + GpsMeasurement measurement; + measurement.time = time; + measurement.position = Vector3(gps_x, gps_y, gps_z); + gps_measurements.push_back(measurement); } + } } int main(int argc, char* argv[]) { - KittiCalibration kitti_calibration; - vector imu_measurements; - vector gps_measurements; - loadKittiData(kitti_calibration, imu_measurements, gps_measurements); + KittiCalibration kitti_calibration; + vector imu_measurements; + vector gps_measurements; + loadKittiData(kitti_calibration, imu_measurements, gps_measurements); - Vector6 BodyP = (Vector6() << kitti_calibration.body_ptx, kitti_calibration.body_pty, kitti_calibration.body_ptz, - kitti_calibration.body_prx, kitti_calibration.body_pry, kitti_calibration.body_prz) - .finished(); - auto body_T_imu = Pose3::Expmap(BodyP); - if (!body_T_imu.equals(Pose3(), 1e-5)) { - printf("Currently only support IMUinBody is identity, i.e. IMU and body frame are the same"); - exit(-1); - } + Vector6 BodyP = + (Vector6() << kitti_calibration.body_ptx, kitti_calibration.body_pty, + kitti_calibration.body_ptz, kitti_calibration.body_prx, + kitti_calibration.body_pry, kitti_calibration.body_prz) + .finished(); + auto body_T_imu = Pose3::Expmap(BodyP); + if (!body_T_imu.equals(Pose3(), 1e-5)) { + printf( + "Currently only support IMUinBody is identity, i.e. IMU and body frame " + "are the same"); + exit(-1); + } - // Configure different variables - // double t_offset = gps_measurements[0].time; - size_t first_gps_pose = 1; - size_t gps_skip = 10; // Skip this many GPS measurements each time - double g = 9.8; - auto w_coriolis = Vector3::Zero(); // zero vector + // Configure different variables + // double t_offset = gps_measurements[0].time; + size_t first_gps_pose = 1; + size_t gps_skip = 10; // Skip this many GPS measurements each time + double g = 9.8; + auto w_coriolis = Vector3::Zero(); // zero vector - // Configure noise models - auto noise_model_gps = noiseModel::Diagonal::Precisions((Vector6() << Vector3::Constant(0), - Vector3::Constant(1.0/0.07)) - .finished()); + // Configure noise models + auto noise_model_gps = noiseModel::Diagonal::Precisions( + (Vector6() << Vector3::Constant(0), Vector3::Constant(1.0 / 0.07)) + .finished()); - // Set initial conditions for the estimated trajectory - // initial pose is the reference frame (navigation frame) - auto current_pose_global = Pose3(Rot3(), gps_measurements[first_gps_pose].position); - // the vehicle is stationary at the beginning at position 0,0,0 - Vector3 current_velocity_global = Vector3::Zero(); - auto current_bias = imuBias::ConstantBias(); // init with zero bias + // Set initial conditions for the estimated trajectory + // initial pose is the reference frame (navigation frame) + auto current_pose_global = + Pose3(Rot3(), gps_measurements[first_gps_pose].position); + // the vehicle is stationary at the beginning at position 0,0,0 + Vector3 current_velocity_global = Vector3::Zero(); + auto current_bias = imuBias::ConstantBias(); // init with zero bias - auto sigma_init_x = noiseModel::Diagonal::Precisions((Vector6() << Vector3::Constant(0), - Vector3::Constant(1.0)) - .finished()); - auto sigma_init_v = noiseModel::Diagonal::Sigmas(Vector3::Constant(1000.0)); - auto sigma_init_b = noiseModel::Diagonal::Sigmas((Vector6() << Vector3::Constant(0.100), - Vector3::Constant(5.00e-05)) - .finished()); + auto sigma_init_x = noiseModel::Diagonal::Precisions( + (Vector6() << Vector3::Constant(0), Vector3::Constant(1.0)).finished()); + auto sigma_init_v = noiseModel::Diagonal::Sigmas(Vector3::Constant(1000.0)); + auto sigma_init_b = noiseModel::Diagonal::Sigmas( + (Vector6() << Vector3::Constant(0.100), Vector3::Constant(5.00e-05)) + .finished()); - // Set IMU preintegration parameters - Matrix33 measured_acc_cov = I_3x3 * pow(kitti_calibration.accelerometer_sigma, 2); - Matrix33 measured_omega_cov = I_3x3 * pow(kitti_calibration.gyroscope_sigma, 2); - // error committed in integrating position from velocities - Matrix33 integration_error_cov = I_3x3 * pow(kitti_calibration.integration_sigma, 2); + // Set IMU preintegration parameters + Matrix33 measured_acc_cov = + I_3x3 * pow(kitti_calibration.accelerometer_sigma, 2); + Matrix33 measured_omega_cov = + I_3x3 * pow(kitti_calibration.gyroscope_sigma, 2); + // error committed in integrating position from velocities + Matrix33 integration_error_cov = + I_3x3 * pow(kitti_calibration.integration_sigma, 2); - auto imu_params = PreintegratedImuMeasurements::Params::MakeSharedU(g); - imu_params->accelerometerCovariance = measured_acc_cov; // acc white noise in continuous - imu_params->integrationCovariance = integration_error_cov; // integration uncertainty continuous - imu_params->gyroscopeCovariance = measured_omega_cov; // gyro white noise in continuous - imu_params->omegaCoriolis = w_coriolis; + auto imu_params = PreintegratedImuMeasurements::Params::MakeSharedU(g); + imu_params->accelerometerCovariance = + measured_acc_cov; // acc white noise in continuous + imu_params->integrationCovariance = + integration_error_cov; // integration uncertainty continuous + imu_params->gyroscopeCovariance = + measured_omega_cov; // gyro white noise in continuous + imu_params->omegaCoriolis = w_coriolis; - std::shared_ptr current_summarized_measurement = nullptr; + std::shared_ptr current_summarized_measurement = + nullptr; - // Set ISAM2 parameters and create ISAM2 solver object - ISAM2Params isam_params; - isam_params.factorization = ISAM2Params::CHOLESKY; - isam_params.relinearizeSkip = 10; + // Set ISAM2 parameters and create ISAM2 solver object + ISAM2Params isam_params; + isam_params.factorization = ISAM2Params::CHOLESKY; + isam_params.relinearizeSkip = 10; - ISAM2 isam(isam_params); + ISAM2 isam(isam_params); - // Create the factor graph and values object that will store new factors and values to add to the incremental graph - NonlinearFactorGraph new_factors; - Values new_values; // values storing the initial estimates of new nodes in the factor graph + // Create the factor graph and values object that will store new factors and + // values to add to the incremental graph + NonlinearFactorGraph new_factors; + Values new_values; // values storing the initial estimates of new nodes in + // the factor graph - /// Main loop: - /// (1) we read the measurements - /// (2) we create the corresponding factors in the graph - /// (3) we solve the graph to obtain and optimal estimate of robot trajectory - printf("-- Starting main loop: inference is performed at each time step, but we plot trajectory every 10 steps\n"); - size_t j = 0; - for (size_t i = first_gps_pose; i < gps_measurements.size() - 1; i++) { - // At each non=IMU measurement we initialize a new node in the graph - auto current_pose_key = X(i); - auto current_vel_key = V(i); - auto current_bias_key = B(i); - double t = gps_measurements[i].time; + /// Main loop: + /// (1) we read the measurements + /// (2) we create the corresponding factors in the graph + /// (3) we solve the graph to obtain and optimal estimate of robot trajectory + printf( + "-- Starting main loop: inference is performed at each time step, but we " + "plot trajectory every 10 steps\n"); + size_t j = 0; + size_t included_imu_measurement_count = 0; - if (i == first_gps_pose) { - // Create initial estimate and prior on initial pose, velocity, and biases - new_values.insert(current_pose_key, current_pose_global); - new_values.insert(current_vel_key, current_velocity_global); - new_values.insert(current_bias_key, current_bias); - new_factors.emplace_shared>(current_pose_key, current_pose_global, sigma_init_x); - new_factors.emplace_shared>(current_vel_key, current_velocity_global, sigma_init_v); - new_factors.emplace_shared>(current_bias_key, current_bias, sigma_init_b); - } else { - double t_previous = gps_measurements[i-1].time; + for (size_t i = first_gps_pose; i < gps_measurements.size() - 1; i++) { + // At each non=IMU measurement we initialize a new node in the graph + auto current_pose_key = X(i); + auto current_vel_key = V(i); + auto current_bias_key = B(i); + double t = gps_measurements[i].time; - // Summarize IMU data between the previous GPS measurement and now - current_summarized_measurement = std::make_shared(imu_params, current_bias); - static size_t included_imu_measurement_count = 0; - while (j < imu_measurements.size() && imu_measurements[j].time <= t) { - if (imu_measurements[j].time >= t_previous) { - current_summarized_measurement->integrateMeasurement(imu_measurements[j].accelerometer, - imu_measurements[j].gyroscope, - imu_measurements[j].dt); - included_imu_measurement_count++; - } - j++; - } + if (i == first_gps_pose) { + // Create initial estimate and prior on initial pose, velocity, and biases + new_values.insert(current_pose_key, current_pose_global); + new_values.insert(current_vel_key, current_velocity_global); + new_values.insert(current_bias_key, current_bias); + new_factors.emplace_shared>( + current_pose_key, current_pose_global, sigma_init_x); + new_factors.emplace_shared>( + current_vel_key, current_velocity_global, sigma_init_v); + new_factors.emplace_shared>( + current_bias_key, current_bias, sigma_init_b); + } else { + double t_previous = gps_measurements[i - 1].time; - // Create IMU factor - auto previous_pose_key = X(i-1); - auto previous_vel_key = V(i-1); - auto previous_bias_key = B(i-1); + // Summarize IMU data between the previous GPS measurement and now + current_summarized_measurement = + std::make_shared(imu_params, + current_bias); - new_factors.emplace_shared(previous_pose_key, previous_vel_key, - current_pose_key, current_vel_key, - previous_bias_key, *current_summarized_measurement); - - // Bias evolution as given in the IMU metadata - auto sigma_between_b = noiseModel::Diagonal::Sigmas((Vector6() << - Vector3::Constant(sqrt(included_imu_measurement_count) * kitti_calibration.accelerometer_bias_sigma), - Vector3::Constant(sqrt(included_imu_measurement_count) * kitti_calibration.gyroscope_bias_sigma)) - .finished()); - new_factors.emplace_shared>(previous_bias_key, - current_bias_key, - imuBias::ConstantBias(), - sigma_between_b); - - // Create GPS factor - auto gps_pose = Pose3(current_pose_global.rotation(), gps_measurements[i].position); - if ((i % gps_skip) == 0) { - new_factors.emplace_shared>(current_pose_key, gps_pose, noise_model_gps); - new_values.insert(current_pose_key, gps_pose); - - printf("################ POSE INCLUDED AT TIME %lf ################\n", t); - cout << gps_pose.translation(); - printf("\n\n"); - } else { - new_values.insert(current_pose_key, current_pose_global); - } - - // Add initial values for velocity and bias based on the previous estimates - new_values.insert(current_vel_key, current_velocity_global); - new_values.insert(current_bias_key, current_bias); - - // Update solver - // ======================================================================= - // We accumulate 2*GPSskip GPS measurements before updating the solver at - // first so that the heading becomes observable. - if (i > (first_gps_pose + 2*gps_skip)) { - printf("################ NEW FACTORS AT TIME %lf ################\n", t); - new_factors.print(); - - isam.update(new_factors, new_values); - - // Reset the newFactors and newValues list - new_factors.resize(0); - new_values.clear(); - - // Extract the result/current estimates - Values result = isam.calculateEstimate(); - - current_pose_global = result.at(current_pose_key); - current_velocity_global = result.at(current_vel_key); - current_bias = result.at(current_bias_key); - - printf("\n################ POSE AT TIME %lf ################\n", t); - current_pose_global.print(); - printf("\n\n"); - } + while (j < imu_measurements.size() && imu_measurements[j].time <= t) { + if (imu_measurements[j].time >= t_previous) { + current_summarized_measurement->integrateMeasurement( + imu_measurements[j].accelerometer, imu_measurements[j].gyroscope, + imu_measurements[j].dt); + included_imu_measurement_count++; } + j++; + } + + // Create IMU factor + auto previous_pose_key = X(i - 1); + auto previous_vel_key = V(i - 1); + auto previous_bias_key = B(i - 1); + + new_factors.emplace_shared( + previous_pose_key, previous_vel_key, current_pose_key, + current_vel_key, previous_bias_key, *current_summarized_measurement); + + // Bias evolution as given in the IMU metadata + auto sigma_between_b = noiseModel::Diagonal::Sigmas( + (Vector6() << Vector3::Constant( + sqrt(included_imu_measurement_count) * + kitti_calibration.accelerometer_bias_sigma), + Vector3::Constant(sqrt(included_imu_measurement_count) * + kitti_calibration.gyroscope_bias_sigma)) + .finished()); + new_factors.emplace_shared>( + previous_bias_key, current_bias_key, imuBias::ConstantBias(), + sigma_between_b); + + // Create GPS factor + auto gps_pose = + Pose3(current_pose_global.rotation(), gps_measurements[i].position); + if ((i % gps_skip) == 0) { + new_factors.emplace_shared>( + current_pose_key, gps_pose, noise_model_gps); + new_values.insert(current_pose_key, gps_pose); + + printf("############ POSE INCLUDED AT TIME %.6lf ############\n", + t); + cout << gps_pose.translation(); + printf("\n\n"); + } else { + new_values.insert(current_pose_key, current_pose_global); + } + + // Add initial values for velocity and bias based on the previous + // estimates + new_values.insert(current_vel_key, current_velocity_global); + new_values.insert(current_bias_key, current_bias); + + // Update solver + // ======================================================================= + // We accumulate 2*GPSskip GPS measurements before updating the solver at + // first so that the heading becomes observable. + if (i > (first_gps_pose + 2 * gps_skip)) { + printf("############ NEW FACTORS AT TIME %.6lf ############\n", + t); + new_factors.print(); + + isam.update(new_factors, new_values); + + // Reset the newFactors and newValues list + new_factors.resize(0); + new_values.clear(); + + // Extract the result/current estimates + Values result = isam.calculateEstimate(); + + current_pose_global = result.at(current_pose_key); + current_velocity_global = result.at(current_vel_key); + current_bias = result.at(current_bias_key); + + printf("\n############ POSE AT TIME %lf ############\n", t); + current_pose_global.print(); + printf("\n\n"); + } } + } - // Save results to file - printf("\nWriting results to file...\n"); - FILE* fp_out = fopen(output_filename.c_str(), "w+"); - fprintf(fp_out, "#time(s),x(m),y(m),z(m),qx,qy,qz,qw,gt_x(m),gt_y(m),gt_z(m)\n"); + // Save results to file + printf("\nWriting results to file...\n"); + FILE* fp_out = fopen(output_filename.c_str(), "w+"); + fprintf(fp_out, + "#time(s),x(m),y(m),z(m),qx,qy,qz,qw,gt_x(m),gt_y(m),gt_z(m)\n"); - Values result = isam.calculateEstimate(); - for (size_t i = first_gps_pose; i < gps_measurements.size() - 1; i++) { - auto pose_key = X(i); - auto vel_key = V(i); - auto bias_key = B(i); + Values result = isam.calculateEstimate(); + for (size_t i = first_gps_pose; i < gps_measurements.size() - 1; i++) { + auto pose_key = X(i); + auto vel_key = V(i); + auto bias_key = B(i); - auto pose = result.at(pose_key); - auto velocity = result.at(vel_key); - auto bias = result.at(bias_key); + auto pose = result.at(pose_key); + auto velocity = result.at(vel_key); + auto bias = result.at(bias_key); - auto pose_quat = pose.rotation().toQuaternion(); - auto gps = gps_measurements[i].position; + auto pose_quat = pose.rotation().toQuaternion(); + auto gps = gps_measurements[i].position; - cout << "State at #" << i << endl; - cout << "Pose:" << endl << pose << endl; - cout << "Velocity:" << endl << velocity << endl; - cout << "Bias:" << endl << bias << endl; + cout << "State at #" << i << endl; + cout << "Pose:" << endl << pose << endl; + cout << "Velocity:" << endl << velocity << endl; + cout << "Bias:" << endl << bias << endl; - fprintf(fp_out, "%f,%f,%f,%f,%f,%f,%f,%f,%f,%f,%f\n", - gps_measurements[i].time, - pose.x(), pose.y(), pose.z(), - pose_quat.x(), pose_quat.y(), pose_quat.z(), pose_quat.w(), - gps(0), gps(1), gps(2)); - } + fprintf(fp_out, "%f,%f,%f,%f,%f,%f,%f,%f,%f,%f,%f\n", + gps_measurements[i].time, pose.x(), pose.y(), pose.z(), + pose_quat.x(), pose_quat.y(), pose_quat.z(), pose_quat.w(), gps(0), + gps(1), gps(2)); + } - fclose(fp_out); + fclose(fp_out); } diff --git a/examples/Pose2SLAMExample_graphviz.cpp b/examples/Pose2SLAMExample_graphviz.cpp index 27d556725..a8768e2b8 100644 --- a/examples/Pose2SLAMExample_graphviz.cpp +++ b/examples/Pose2SLAMExample_graphviz.cpp @@ -60,11 +60,10 @@ int main(int argc, char** argv) { // save factor graph as graphviz dot file // Render to PDF using "fdp Pose2SLAMExample.dot -Tpdf > graph.pdf" - ofstream os("Pose2SLAMExample.dot"); - graph.saveGraph(os, result); + graph.saveGraph("Pose2SLAMExample.dot", result); // Also print out to console - graph.saveGraph(cout, result); + graph.dot(cout, result); return 0; } diff --git a/examples/RangeISAMExample_plaza2.cpp b/examples/RangeISAMExample_plaza2.cpp index c8e31e1c5..aa61ffc6c 100644 --- a/examples/RangeISAMExample_plaza2.cpp +++ b/examples/RangeISAMExample_plaza2.cpp @@ -10,62 +10,81 @@ * -------------------------------------------------------------------------- */ /** - * @file RangeISAMExample_plaza1.cpp + * @file RangeISAMExample_plaza2.cpp * @brief A 2D Range SLAM example * @date June 20, 2013 - * @author FRank Dellaert + * @author Frank Dellaert */ -// Both relative poses and recovered trajectory poses will be stored as Pose2 objects +// Both relative poses and recovered trajectory poses will be stored as Pose2. #include +using gtsam::Pose2; -// Each variable in the system (poses and landmarks) must be identified with a unique key. -// We can either use simple integer keys (1, 2, 3, ...) or symbols (X1, X2, L1). -// Here we will use Symbols +// gtsam::Vectors are dynamic Eigen vectors, Vector3 is statically sized. +#include +using gtsam::Vector; +using gtsam::Vector3; + +// Unknown landmarks are of type Point2 (which is just a 2D Eigen vector). +#include +using gtsam::Point2; + +// Each variable in the system (poses and landmarks) must be identified with a +// unique key. We can either use simple integer keys (1, 2, 3, ...) or symbols +// (X1, X2, L1). Here we will use Symbols. #include +using gtsam::Symbol; -// We want to use iSAM2 to solve the range-SLAM problem incrementally +// We want to use iSAM2 to solve the range-SLAM problem incrementally. #include -// iSAM2 requires as input a set set of new factors to be added stored in a factor graph, -// and initial guesses for any new variables used in the added factors +// iSAM2 requires as input a set set of new factors to be added stored in a +// factor graph, and initial guesses for any new variables in the added factors. #include #include -// We will use a non-liear solver to batch-inituialize from the first 150 frames +// We will use a non-linear solver to batch-initialize from the first 150 frames #include -// In GTSAM, measurement functions are represented as 'factors'. Several common factors -// have been provided with the library for solving robotics SLAM problems. -#include +// Measurement functions are represented as 'factors'. Several common factors +// have been provided with the library for solving robotics SLAM problems: #include +#include #include -// Standard headers, added last, so we know headers above work on their own +// Timing, with functions below, provides nice facilities to benchmark. +#include +using gtsam::tictoc_print_; + +// Standard headers, added last, so we know headers above work on their own. #include #include +#include +#include +#include +#include +#include -using namespace std; -using namespace gtsam; namespace NM = gtsam::noiseModel; -// data available at http://www.frc.ri.cmu.edu/projects/emergencyresponse/RangeData/ -// Datafile format (from http://www.frc.ri.cmu.edu/projects/emergencyresponse/RangeData/log.html) +// Data is second UWB ranging dataset, B2 or "plaza 2", from +// "Navigating with Ranging Radios: Five Data Sets with Ground Truth" +// by Joseph Djugash, Bradley Hamner, and Stephan Roth +// https://www.ri.cmu.edu/pub_files/2009/9/Final_5datasetsRangingRadios.pdf // load the odometry // DR: Odometry Input (delta distance traveled and delta heading change) -// Time (sec) Delta Dist. Trav. (m) Delta Heading (rad) -typedef pair TimedOdometry; -list readOdometry() { - list odometryList; - string data_file = findExampleDataFile("Plaza2_DR.txt"); - ifstream is(data_file.c_str()); +// Time (sec) Delta Distance Traveled (m) Delta Heading (rad) +using TimedOdometry = std::pair; +std::list readOdometry() { + std::list odometryList; + std::string data_file = gtsam::findExampleDataFile("Plaza2_DR.txt"); + std::ifstream is(data_file.c_str()); while (is) { double t, distance_traveled, delta_heading; is >> t >> distance_traveled >> delta_heading; - odometryList.push_back( - TimedOdometry(t, Pose2(distance_traveled, 0, delta_heading))); + odometryList.emplace_back(t, Pose2(distance_traveled, 0, delta_heading)); } is.clear(); /* clears the end-of-file and error flags */ return odometryList; @@ -73,90 +92,85 @@ list readOdometry() { // load the ranges from TD // Time (sec) Sender / Antenna ID Receiver Node ID Range (m) -typedef boost::tuple RangeTriple; -vector readTriples() { - vector triples; - string data_file = findExampleDataFile("Plaza2_TD.txt"); - ifstream is(data_file.c_str()); +using RangeTriple = boost::tuple; +std::vector readTriples() { + std::vector triples; + std::string data_file = gtsam::findExampleDataFile("Plaza2_TD.txt"); + std::ifstream is(data_file.c_str()); while (is) { - double t, sender, range; - size_t receiver; + double t, range, sender, receiver; is >> t >> sender >> receiver >> range; - triples.push_back(RangeTriple(t, receiver, range)); + triples.emplace_back(t, receiver, range); } is.clear(); /* clears the end-of-file and error flags */ return triples; } // main -int main (int argc, char** argv) { - +int main(int argc, char** argv) { // load Plaza2 data - list odometry = readOdometry(); -// size_t M = odometry.size(); + std::list odometry = readOdometry(); + size_t M = odometry.size(); + std::cout << "Read " << M << " odometry entries." << std::endl; - vector triples = readTriples(); + std::vector triples = readTriples(); size_t K = triples.size(); + std::cout << "Read " << K << " range triples." << std::endl; // parameters - size_t minK = 150; // minimum number of range measurements to process initially - size_t incK = 25; // minimum number of range measurements to process after - bool groundTruth = false; + size_t minK = + 150; // minimum number of range measurements to process initially + size_t incK = 25; // minimum number of range measurements to process after bool robust = true; // Set Noise parameters - Vector priorSigmas = Vector3(1,1,M_PI); + Vector priorSigmas = Vector3(1, 1, M_PI); Vector odoSigmas = Vector3(0.05, 0.01, 0.1); - double sigmaR = 100; // range standard deviation - const NM::Base::shared_ptr // all same type - priorNoise = NM::Diagonal::Sigmas(priorSigmas), //prior - odoNoise = NM::Diagonal::Sigmas(odoSigmas), // odometry - gaussian = NM::Isotropic::Sigma(1, sigmaR), // non-robust - tukey = NM::Robust::Create(NM::mEstimator::Tukey::Create(15), gaussian), //robust - rangeNoise = robust ? tukey : gaussian; + double sigmaR = 100; // range standard deviation + const NM::Base::shared_ptr // all same type + priorNoise = NM::Diagonal::Sigmas(priorSigmas), // prior + looseNoise = NM::Isotropic::Sigma(2, 1000), // loose LM prior + odoNoise = NM::Diagonal::Sigmas(odoSigmas), // odometry + gaussian = NM::Isotropic::Sigma(1, sigmaR), // non-robust + tukey = NM::Robust::Create(NM::mEstimator::Tukey::Create(15), + gaussian), // robust + rangeNoise = robust ? tukey : gaussian; // Initialize iSAM - ISAM2 isam; + gtsam::ISAM2 isam; // Add prior on first pose - Pose2 pose0 = Pose2(-34.2086489999201, 45.3007639991120, - M_PI - 2.02108900000000); - NonlinearFactorGraph newFactors; + Pose2 pose0 = Pose2(-34.2086489999201, 45.3007639991120, M_PI - 2.021089); + gtsam::NonlinearFactorGraph newFactors; newFactors.addPrior(0, pose0, priorNoise); - Values initial; + gtsam::Values initial; initial.insert(0, pose0); - // initialize points - if (groundTruth) { // from TL file - initial.insert(symbol('L', 1), Point2(-68.9265, 18.3778)); - initial.insert(symbol('L', 6), Point2(-37.5805, 69.2278)); - initial.insert(symbol('L', 0), Point2(-33.6205, 26.9678)); - initial.insert(symbol('L', 5), Point2(1.7095, -5.8122)); - } else { // drawn from sigma=1 Gaussian in matlab version - initial.insert(symbol('L', 1), Point2(3.5784, 2.76944)); - initial.insert(symbol('L', 6), Point2(-1.34989, 3.03492)); - initial.insert(symbol('L', 0), Point2(0.725404, -0.0630549)); - initial.insert(symbol('L', 5), Point2(0.714743, -0.204966)); - } + // We will initialize landmarks randomly, and keep track of which landmarks we + // already added with a set. + std::mt19937_64 rng; + std::normal_distribution normal(0.0, 100.0); + std::set initializedLandmarks; // set some loop variables - size_t i = 1; // step counter - size_t k = 0; // range measurement counter + size_t i = 1; // step counter + size_t k = 0; // range measurement counter bool initialized = false; Pose2 lastPose = pose0; size_t countK = 0; // Loop over odometry gttic_(iSAM); - for(const TimedOdometry& timedOdometry: odometry) { - //--------------------------------- odometry loop ----------------------------------------- + for (const TimedOdometry& timedOdometry : odometry) { + //--------------------------------- odometry loop -------------------------- double t; Pose2 odometry; boost::tie(t, odometry) = timedOdometry; // add odometry factor - newFactors.push_back(BetweenFactor(i-1, i, odometry, odoNoise)); + newFactors.emplace_shared>(i - 1, i, odometry, + odoNoise); // predict pose and add as initial estimate Pose2 predictedPose = lastPose.compose(odometry); @@ -166,17 +180,30 @@ int main (int argc, char** argv) { // Check if there are range factors to be added while (k < K && t >= boost::get<0>(triples[k])) { size_t j = boost::get<1>(triples[k]); + Symbol landmark_key('L', j); double range = boost::get<2>(triples[k]); - newFactors.push_back(RangeFactor(i, symbol('L', j), range,rangeNoise)); + newFactors.emplace_shared>( + i, landmark_key, range, rangeNoise); + if (initializedLandmarks.count(landmark_key) == 0) { + std::cout << "adding landmark " << j << std::endl; + double x = normal(rng), y = normal(rng); + initial.insert(landmark_key, Point2(x, y)); + initializedLandmarks.insert(landmark_key); + // We also add a very loose prior on the landmark in case there is only + // one sighting, which cannot fully determine the landmark. + newFactors.emplace_shared>( + landmark_key, Point2(0, 0), looseNoise); + } k = k + 1; countK = countK + 1; } // Check whether to update iSAM 2 if ((k > minK) && (countK > incK)) { - if (!initialized) { // Do a full optimize for first minK ranges + if (!initialized) { // Do a full optimize for first minK ranges + std::cout << "Initializing at time " << k << std::endl; gttic_(batchInitialization); - LevenbergMarquardtOptimizer batchOptimizer(newFactors, initial); + gtsam::LevenbergMarquardtOptimizer batchOptimizer(newFactors, initial); initial = batchOptimizer.optimize(); gttoc_(batchInitialization); initialized = true; @@ -185,21 +212,27 @@ int main (int argc, char** argv) { isam.update(newFactors, initial); gttoc_(update); gttic_(calculateEstimate); - Values result = isam.calculateEstimate(); + gtsam::Values result = isam.calculateEstimate(); gttoc_(calculateEstimate); lastPose = result.at(i); - newFactors = NonlinearFactorGraph(); - initial = Values(); + newFactors = gtsam::NonlinearFactorGraph(); + initial = gtsam::Values(); countK = 0; } i += 1; - //--------------------------------- odometry loop ----------------------------------------- - } // end for + //--------------------------------- odometry loop -------------------------- + } // end for gttoc_(iSAM); // Print timings tictoc_print_(); + // Print optimized landmarks: + gtsam::Values finalResult = isam.calculateEstimate(); + for (auto&& landmark_key : initializedLandmarks) { + Point2 p = finalResult.at(landmark_key); + std::cout << landmark_key << ":" << p.transpose() << "\n"; + } + exit(0); } - diff --git a/examples/SFMExampleExpressions_bal.cpp b/examples/SFMExampleExpressions_bal.cpp index 3768ee2a3..8a5a12e56 100644 --- a/examples/SFMExampleExpressions_bal.cpp +++ b/examples/SFMExampleExpressions_bal.cpp @@ -26,9 +26,12 @@ #include // Header order is close to far -#include +#include // for loading BAL datasets ! +#include #include -#include // for loading BAL datasets ! +#include + +#include #include using namespace std; @@ -46,10 +49,9 @@ int main(int argc, char* argv[]) { if (argc > 1) filename = string(argv[1]); // Load the SfM data from file - SfmData mydata; - readBAL(filename, mydata); + SfmData mydata = SfmData::FromBalFile(filename); cout << boost::format("read %1% tracks on %2% cameras\n") % - mydata.number_tracks() % mydata.number_cameras(); + mydata.numberTracks() % mydata.numberCameras(); // Create a factor graph ExpressionFactorGraph graph; diff --git a/examples/SFMExample_bal.cpp b/examples/SFMExample_bal.cpp index ffb5b195b..10563760d 100644 --- a/examples/SFMExample_bal.cpp +++ b/examples/SFMExample_bal.cpp @@ -10,17 +10,20 @@ * -------------------------------------------------------------------------- */ /** - * @file SFMExample.cpp + * @file SFMExample_bal.cpp * @brief Solve a structure-from-motion problem from a "Bundle Adjustment in the Large" file * @author Frank Dellaert */ // For an explanation of headers, see SFMExample.cpp -#include +#include // for loading BAL datasets ! +#include +#include #include #include -#include -#include // for loading BAL datasets ! +#include + +#include #include using namespace std; @@ -41,9 +44,8 @@ int main (int argc, char* argv[]) { if (argc>1) filename = string(argv[1]); // Load the SfM data from file - SfmData mydata; - readBAL(filename, mydata); - cout << boost::format("read %1% tracks on %2% cameras\n") % mydata.number_tracks() % mydata.number_cameras(); + SfmData mydata = SfmData::FromBalFile(filename); + cout << boost::format("read %1% tracks on %2% cameras\n") % mydata.numberTracks() % mydata.numberCameras(); // Create a factor graph NonlinearFactorGraph graph; diff --git a/examples/SFMExample_bal_COLAMD_METIS.cpp b/examples/SFMExample_bal_COLAMD_METIS.cpp index b79a9fa28..92d779a56 100644 --- a/examples/SFMExample_bal_COLAMD_METIS.cpp +++ b/examples/SFMExample_bal_COLAMD_METIS.cpp @@ -17,15 +17,16 @@ */ // For an explanation of headers, see SFMExample.cpp -#include -#include +#include +#include // for loading BAL datasets ! +#include #include #include -#include -#include // for loading BAL datasets ! - +#include +#include #include +#include #include using namespace std; @@ -45,10 +46,9 @@ int main(int argc, char* argv[]) { if (argc > 1) filename = string(argv[1]); // Load the SfM data from file - SfmData mydata; - readBAL(filename, mydata); + SfmData mydata = SfmData::FromBalFile(filename); cout << boost::format("read %1% tracks on %2% cameras\n") % - mydata.number_tracks() % mydata.number_cameras(); + mydata.numberTracks() % mydata.numberCameras(); // Create a factor graph NonlinearFactorGraph graph; @@ -131,7 +131,7 @@ int main(int argc, char* argv[]) { cout << "Time comparison by solving " << filename << " results:" << endl; cout << boost::format("%1% point tracks and %2% cameras\n") % - mydata.number_tracks() % mydata.number_cameras() + mydata.numberTracks() % mydata.numberCameras() << endl; tictoc_print_(); diff --git a/examples/SFMdata.h b/examples/SFMdata.h index 04d3c9e47..3031828f1 100644 --- a/examples/SFMdata.h +++ b/examples/SFMdata.h @@ -22,6 +22,8 @@ * Passing function argument allows to specificy an initial position, a pose increment and step count. */ +#pragma once + // As this is a full 3D problem, we will use Pose3 variables to represent the camera // positions and Point3 variables (x, y, z) to represent the landmark coordinates. // Camera observations of landmarks (i.e. pixel coordinates) will be stored as Point2 (x, y). @@ -66,4 +68,4 @@ std::vector createPoses( } return poses; -} \ No newline at end of file +} diff --git a/examples/UGM_chain.cpp b/examples/UGM_chain.cpp index 3a885a844..ad21af9fa 100644 --- a/examples/UGM_chain.cpp +++ b/examples/UGM_chain.cpp @@ -68,10 +68,9 @@ int main(int argc, char** argv) { << graph.size() << " factors (Unary+Edge)."; // "Decoding", i.e., configuration with largest value - // We use sequential variable elimination - DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); - DiscreteFactor::sharedValues optimalDecoding = chordal->optimize(); - optimalDecoding->print("\nMost Probable Explanation (optimalDecoding)\n"); + // Uses max-product. + auto optimalDecoding = graph.optimize(); + optimalDecoding.print("\nMost Probable Explanation (optimalDecoding)\n"); // "Inference" Computing marginals for each node // Here we'll make use of DiscreteMarginals class, which makes use of diff --git a/examples/UGM_small.cpp b/examples/UGM_small.cpp index 27a6205a3..bc6a41317 100644 --- a/examples/UGM_small.cpp +++ b/examples/UGM_small.cpp @@ -50,8 +50,8 @@ int main(int argc, char** argv) { // Print the UGM distribution cout << "\nUGM distribution:" << endl; - vector allPosbValues = cartesianProduct( - Cathy & Heather & Mark & Allison); + auto allPosbValues = + DiscreteValues::CartesianProduct(Cathy & Heather & Mark & Allison); for (size_t i = 0; i < allPosbValues.size(); ++i) { DiscreteFactor::Values values = allPosbValues[i]; double prodPot = graph(values); @@ -61,10 +61,9 @@ int main(int argc, char** argv) { } // "Decoding", i.e., configuration with largest value (MPE) - // We use sequential variable elimination - DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); - DiscreteFactor::sharedValues optimalDecoding = chordal->optimize(); - optimalDecoding->print("\noptimalDecoding"); + // Uses max-product + auto optimalDecoding = graph.optimize(); + GTSAM_PRINT(optimalDecoding); // "Inference" Computing marginals cout << "\nComputing Node Marginals .." << endl; diff --git a/gtsam/3rdparty/Eigen/Eigen/src/Core/TriangularMatrix.h b/gtsam/3rdparty/Eigen/Eigen/src/Core/TriangularMatrix.h index 667ef09dc..9db32744e 100644 --- a/gtsam/3rdparty/Eigen/Eigen/src/Core/TriangularMatrix.h +++ b/gtsam/3rdparty/Eigen/Eigen/src/Core/TriangularMatrix.h @@ -440,7 +440,7 @@ template class TriangularViewImpl<_Mat EIGEN_DEVICE_FUNC void lazyAssign(const TriangularBase& other); - /** \deprecated */ + /** @deprecated */ template EIGEN_DEVICE_FUNC void lazyAssign(const MatrixBase& other); @@ -523,7 +523,7 @@ template class TriangularViewImpl<_Mat call_assignment(derived(), other.const_cast_derived(), internal::swap_assign_op()); } - /** \deprecated + /** @deprecated * Shortcut for \code (*this).swap(other.triangularView<(*this)::Mode>()) \endcode */ template EIGEN_DEVICE_FUNC diff --git a/gtsam/CMakeLists.txt b/gtsam/CMakeLists.txt index 535d60eb1..a293c6ec2 100644 --- a/gtsam/CMakeLists.txt +++ b/gtsam/CMakeLists.txt @@ -15,7 +15,7 @@ set (gtsam_subdirs sam sfm slam - navigation + navigation ) set(gtsam_srcs) diff --git a/gtsam/base/CMakeLists.txt b/gtsam/base/CMakeLists.txt index 99984e7b3..66d3ec721 100644 --- a/gtsam/base/CMakeLists.txt +++ b/gtsam/base/CMakeLists.txt @@ -5,8 +5,5 @@ install(FILES ${base_headers} DESTINATION include/gtsam/base) file(GLOB base_headers_tree "treeTraversal/*.h") install(FILES ${base_headers_tree} DESTINATION include/gtsam/base/treeTraversal) -file(GLOB deprecated_headers "deprecated/*.h") -install(FILES ${deprecated_headers} DESTINATION include/gtsam/base/deprecated) - # Build tests add_subdirectory(tests) diff --git a/gtsam/base/FastSet.h b/gtsam/base/FastSet.h index 8c23ae9e5..6fe2d06e3 100644 --- a/gtsam/base/FastSet.h +++ b/gtsam/base/FastSet.h @@ -18,6 +18,10 @@ #pragma once +#include +#if BOOST_VERSION >= 107400 +#include +#endif #include #include #include diff --git a/gtsam/base/Lie.h b/gtsam/base/Lie.h index ac7c2a9a5..cb8e7d017 100644 --- a/gtsam/base/Lie.h +++ b/gtsam/base/Lie.h @@ -370,4 +370,4 @@ public: * the gtsam namespace to be more easily enforced as testable */ #define GTSAM_CONCEPT_LIE_INST(T) template class gtsam::IsLieGroup; -#define GTSAM_CONCEPT_LIE_TYPE(T) typedef gtsam::IsLieGroup _gtsam_IsLieGroup_##T; +#define GTSAM_CONCEPT_LIE_TYPE(T) using _gtsam_IsLieGroup_##T = gtsam::IsLieGroup; diff --git a/gtsam/base/LieMatrix.h b/gtsam/base/LieMatrix.h deleted file mode 100644 index 210bdcc73..000000000 --- a/gtsam/base/LieMatrix.h +++ /dev/null @@ -1,26 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file LieMatrix.h - * @brief External deprecation warning, see deprecated/LieMatrix.h for details - * @author Paul Drews - */ - -#pragma once - -#ifdef _MSC_VER -#pragma message("LieMatrix.h is deprecated. Please use Eigen::Matrix instead.") -#else -#warning "LieMatrix.h is deprecated. Please use Eigen::Matrix instead." -#endif - -#include "gtsam/base/deprecated/LieMatrix.h" diff --git a/gtsam/base/LieVector.h b/gtsam/base/LieVector.h deleted file mode 100644 index a7491d804..000000000 --- a/gtsam/base/LieVector.h +++ /dev/null @@ -1,26 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file LieVector.h - * @brief Deprecation warning for LieVector, see deprecated/LieVector.h for details. - * @author Paul Drews - */ - -#pragma once - -#ifdef _MSC_VER -#pragma message("LieVector.h is deprecated. Please use Eigen::Vector instead.") -#else -#warning "LieVector.h is deprecated. Please use Eigen::Vector instead." -#endif - -#include diff --git a/gtsam/base/Manifold.h b/gtsam/base/Manifold.h index dbe497005..962dc8269 100644 --- a/gtsam/base/Manifold.h +++ b/gtsam/base/Manifold.h @@ -178,4 +178,4 @@ struct FixedDimension { // * the gtsam namespace to be more easily enforced as testable // */ #define GTSAM_CONCEPT_MANIFOLD_INST(T) template class gtsam::IsManifold; -#define GTSAM_CONCEPT_MANIFOLD_TYPE(T) typedef gtsam::IsManifold _gtsam_IsManifold_##T; +#define GTSAM_CONCEPT_MANIFOLD_TYPE(T) using _gtsam_IsManifold_##T = gtsam::IsManifold; diff --git a/gtsam/base/Matrix.cpp b/gtsam/base/Matrix.cpp index 41a80629b..5b8a021d4 100644 --- a/gtsam/base/Matrix.cpp +++ b/gtsam/base/Matrix.cpp @@ -25,6 +25,7 @@ #include #include +#include #include #include diff --git a/gtsam/base/Matrix.h b/gtsam/base/Matrix.h index 013947bbd..cfedf6d8c 100644 --- a/gtsam/base/Matrix.h +++ b/gtsam/base/Matrix.h @@ -26,12 +26,9 @@ #include #include -#include - -#include -#include #include -#include + +#include /** * Matrix is a typedef in the gtsam namespace @@ -46,28 +43,28 @@ typedef Eigen::Matrix M // Create handy typedefs and constants for square-size matrices // MatrixMN, MatrixN = MatrixNN, I_NxN, and Z_NxN, for M,N=1..9 #define GTSAM_MAKE_MATRIX_DEFS(N) \ -typedef Eigen::Matrix Matrix##N; \ -typedef Eigen::Matrix Matrix1##N; \ -typedef Eigen::Matrix Matrix2##N; \ -typedef Eigen::Matrix Matrix3##N; \ -typedef Eigen::Matrix Matrix4##N; \ -typedef Eigen::Matrix Matrix5##N; \ -typedef Eigen::Matrix Matrix6##N; \ -typedef Eigen::Matrix Matrix7##N; \ -typedef Eigen::Matrix Matrix8##N; \ -typedef Eigen::Matrix Matrix9##N; \ +using Matrix##N = Eigen::Matrix; \ +using Matrix1##N = Eigen::Matrix; \ +using Matrix2##N = Eigen::Matrix; \ +using Matrix3##N = Eigen::Matrix; \ +using Matrix4##N = Eigen::Matrix; \ +using Matrix5##N = Eigen::Matrix; \ +using Matrix6##N = Eigen::Matrix; \ +using Matrix7##N = Eigen::Matrix; \ +using Matrix8##N = Eigen::Matrix; \ +using Matrix9##N = Eigen::Matrix; \ static const Eigen::MatrixBase::IdentityReturnType I_##N##x##N = Matrix##N::Identity(); \ static const Eigen::MatrixBase::ConstantReturnType Z_##N##x##N = Matrix##N::Zero(); -GTSAM_MAKE_MATRIX_DEFS(1); -GTSAM_MAKE_MATRIX_DEFS(2); -GTSAM_MAKE_MATRIX_DEFS(3); -GTSAM_MAKE_MATRIX_DEFS(4); -GTSAM_MAKE_MATRIX_DEFS(5); -GTSAM_MAKE_MATRIX_DEFS(6); -GTSAM_MAKE_MATRIX_DEFS(7); -GTSAM_MAKE_MATRIX_DEFS(8); -GTSAM_MAKE_MATRIX_DEFS(9); +GTSAM_MAKE_MATRIX_DEFS(1) +GTSAM_MAKE_MATRIX_DEFS(2) +GTSAM_MAKE_MATRIX_DEFS(3) +GTSAM_MAKE_MATRIX_DEFS(4) +GTSAM_MAKE_MATRIX_DEFS(5) +GTSAM_MAKE_MATRIX_DEFS(6) +GTSAM_MAKE_MATRIX_DEFS(7) +GTSAM_MAKE_MATRIX_DEFS(8) +GTSAM_MAKE_MATRIX_DEFS(9) // Matrix expressions for accessing parts of matrices typedef Eigen::Block SubMatrix; @@ -523,82 +520,4 @@ GTSAM_EXPORT Matrix LLt(const Matrix& A); GTSAM_EXPORT Matrix RtR(const Matrix& A); GTSAM_EXPORT Vector columnNormSquare(const Matrix &A); -} // namespace gtsam - -#include -#include -#include - -namespace boost { - namespace serialization { - - /** - * Ref. https://stackoverflow.com/questions/18382457/eigen-and-boostserialize/22903063#22903063 - * - * Eigen supports calling resize() on both static and dynamic matrices. - * This allows for a uniform API, with resize having no effect if the static matrix - * is already the correct size. - * https://eigen.tuxfamily.org/dox/group__TutorialMatrixClass.html#TutorialMatrixSizesResizing - * - * We use all the Matrix template parameters to ensure wide compatibility. - * - * eigen_typekit in ROS uses the same code - * http://docs.ros.org/lunar/api/eigen_typekit/html/eigen__mqueue_8cpp_source.html - */ - - // split version - sends sizes ahead - template - void save(Archive & ar, - const Eigen::Matrix & m, - const unsigned int /*version*/) { - const size_t rows = m.rows(), cols = m.cols(); - ar << BOOST_SERIALIZATION_NVP(rows); - ar << BOOST_SERIALIZATION_NVP(cols); - ar << make_nvp("data", make_array(m.data(), m.size())); - } - - template - void load(Archive & ar, - Eigen::Matrix & m, - const unsigned int /*version*/) { - size_t rows, cols; - ar >> BOOST_SERIALIZATION_NVP(rows); - ar >> BOOST_SERIALIZATION_NVP(cols); - m.resize(rows, cols); - ar >> make_nvp("data", make_array(m.data(), m.size())); - } - - // templated version of BOOST_SERIALIZATION_SPLIT_FREE(Eigen::Matrix); - template - void serialize(Archive & ar, - Eigen::Matrix & m, - const unsigned int version) { - split_free(ar, m, version); - } - - // specialized to Matrix for MATLAB wrapper - template - void serialize(Archive& ar, gtsam::Matrix& m, const unsigned int version) { - split_free(ar, m, version); - } - - } // namespace serialization -} // namespace boost +} // namespace gtsam diff --git a/gtsam/base/MatrixSerialization.h b/gtsam/base/MatrixSerialization.h new file mode 100644 index 000000000..f79d7b27f --- /dev/null +++ b/gtsam/base/MatrixSerialization.h @@ -0,0 +1,89 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file MatrixSerialization.h + * @brief Serialization for matrices + * @author Frank Dellaert + * @date February 2022 + */ + +// \callgraph + +#pragma once + +#include + +#include +#include +#include + +namespace boost { +namespace serialization { + +/** + * Ref. + * https://stackoverflow.com/questions/18382457/eigen-and-boostserialize/22903063#22903063 + * + * Eigen supports calling resize() on both static and dynamic matrices. + * This allows for a uniform API, with resize having no effect if the static + * matrix is already the correct size. + * https://eigen.tuxfamily.org/dox/group__TutorialMatrixClass.html#TutorialMatrixSizesResizing + * + * We use all the Matrix template parameters to ensure wide compatibility. + * + * eigen_typekit in ROS uses the same code + * http://docs.ros.org/lunar/api/eigen_typekit/html/eigen__mqueue_8cpp_source.html + */ + +// split version - sends sizes ahead +template +void save( + Archive& ar, + const Eigen::Matrix& m, + const unsigned int /*version*/) { + const size_t rows = m.rows(), cols = m.cols(); + ar << BOOST_SERIALIZATION_NVP(rows); + ar << BOOST_SERIALIZATION_NVP(cols); + ar << make_nvp("data", make_array(m.data(), m.size())); +} + +template +void load(Archive& ar, + Eigen::Matrix& m, + const unsigned int /*version*/) { + size_t rows, cols; + ar >> BOOST_SERIALIZATION_NVP(rows); + ar >> BOOST_SERIALIZATION_NVP(cols); + m.resize(rows, cols); + ar >> make_nvp("data", make_array(m.data(), m.size())); +} + +// templated version of BOOST_SERIALIZATION_SPLIT_FREE(Eigen::Matrix); +template +void serialize( + Archive& ar, + Eigen::Matrix& m, + const unsigned int version) { + split_free(ar, m, version); +} + +// specialized to Matrix for MATLAB wrapper +template +void serialize(Archive& ar, gtsam::Matrix& m, const unsigned int version) { + split_free(ar, m, version); +} + +} // namespace serialization +} // namespace boost diff --git a/gtsam/base/OptionalJacobian.h b/gtsam/base/OptionalJacobian.h index 4b580f82e..c9a960a89 100644 --- a/gtsam/base/OptionalJacobian.h +++ b/gtsam/base/OptionalJacobian.h @@ -20,6 +20,8 @@ #pragma once #include // Configuration from CMake #include +#include +#include #ifndef OPTIONALJACOBIAN_NOBOOST #include @@ -89,6 +91,31 @@ public: usurp(dynamic.data()); } + /// Constructor that will resize a dynamic matrix (unless already correct) + OptionalJacobian(Eigen::MatrixXd* dynamic) : + map_(nullptr) { + dynamic->resize(Rows, Cols); // no malloc if correct size + usurp(dynamic->data()); + } + + /** + * @brief Constructor from an Eigen::Ref *value*. Will not usurp if dimension is wrong + * @note This is important so we don't overwrite someone else's memory! + */ + template + OptionalJacobian(Eigen::Ref dynamic_ref) : + map_(nullptr) { + if (dynamic_ref.rows() == Rows && dynamic_ref.cols() == Cols && !dynamic_ref.IsRowMajor) { + usurp(dynamic_ref.data()); + } else { + throw std::invalid_argument( + std::string("OptionalJacobian called with wrong dimensions or " + "storage order.\n" + "Expected: ") + + "(" + std::to_string(Rows) + ", " + std::to_string(Cols) + ")"); + } + } + #ifndef OPTIONALJACOBIAN_NOBOOST /// Constructor with boost::none just makes empty diff --git a/gtsam/base/Testable.h b/gtsam/base/Testable.h index 6062c7ae1..d50d62c1f 100644 --- a/gtsam/base/Testable.h +++ b/gtsam/base/Testable.h @@ -173,4 +173,4 @@ namespace gtsam { * @deprecated please use BOOST_CONCEPT_ASSERT and */ #define GTSAM_CONCEPT_TESTABLE_INST(T) template class gtsam::IsTestable; -#define GTSAM_CONCEPT_TESTABLE_TYPE(T) typedef gtsam::IsTestable _gtsam_Testable_##T; +#define GTSAM_CONCEPT_TESTABLE_TYPE(T) using _gtsam_Testable_##T = gtsam::IsTestable; diff --git a/gtsam/base/TestableAssertions.h b/gtsam/base/TestableAssertions.h index 0e6e1c276..e5bd34d19 100644 --- a/gtsam/base/TestableAssertions.h +++ b/gtsam/base/TestableAssertions.h @@ -80,12 +80,13 @@ bool assert_equal(const V& expected, const boost::optional& actual, do return assert_equal(expected, *actual, tol); } +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 /** * Version of assert_equals to work with vectors - * \deprecated: use container equals instead + * @deprecated: use container equals instead */ template -bool assert_equal(const std::vector& expected, const std::vector& actual, double tol = 1e-9) { +bool GTSAM_DEPRECATED assert_equal(const std::vector& expected, const std::vector& actual, double tol = 1e-9) { bool match = true; if (expected.size() != actual.size()) match = false; @@ -108,6 +109,7 @@ bool assert_equal(const std::vector& expected, const std::vector& actual, } return true; } +#endif /** * Function for comparing maps of testable->testable diff --git a/gtsam/base/Value.h b/gtsam/base/Value.h index a19fbe176..697c4f3be 100644 --- a/gtsam/base/Value.h +++ b/gtsam/base/Value.h @@ -21,6 +21,7 @@ #include // Configuration from CMake #include +#include #include #include diff --git a/gtsam/base/Vector.h b/gtsam/base/Vector.h index 9567d9980..f7923ff88 100644 --- a/gtsam/base/Vector.h +++ b/gtsam/base/Vector.h @@ -48,19 +48,19 @@ static const Eigen::MatrixBase::ConstantReturnType Z_3x1 = Vector3::Zer // Create handy typedefs and constants for vectors with N>3 // VectorN and Z_Nx1, for N=1..9 #define GTSAM_MAKE_VECTOR_DEFS(N) \ - typedef Eigen::Matrix Vector##N; \ + using Vector##N = Eigen::Matrix; \ static const Eigen::MatrixBase::ConstantReturnType Z_##N##x1 = Vector##N::Zero(); -GTSAM_MAKE_VECTOR_DEFS(4); -GTSAM_MAKE_VECTOR_DEFS(5); -GTSAM_MAKE_VECTOR_DEFS(6); -GTSAM_MAKE_VECTOR_DEFS(7); -GTSAM_MAKE_VECTOR_DEFS(8); -GTSAM_MAKE_VECTOR_DEFS(9); -GTSAM_MAKE_VECTOR_DEFS(10); -GTSAM_MAKE_VECTOR_DEFS(11); -GTSAM_MAKE_VECTOR_DEFS(12); -GTSAM_MAKE_VECTOR_DEFS(15); +GTSAM_MAKE_VECTOR_DEFS(4) +GTSAM_MAKE_VECTOR_DEFS(5) +GTSAM_MAKE_VECTOR_DEFS(6) +GTSAM_MAKE_VECTOR_DEFS(7) +GTSAM_MAKE_VECTOR_DEFS(8) +GTSAM_MAKE_VECTOR_DEFS(9) +GTSAM_MAKE_VECTOR_DEFS(10) +GTSAM_MAKE_VECTOR_DEFS(11) +GTSAM_MAKE_VECTOR_DEFS(12) +GTSAM_MAKE_VECTOR_DEFS(15) typedef Eigen::VectorBlock SubVector; typedef Eigen::VectorBlock ConstSubVector; @@ -204,18 +204,19 @@ inline double inner_prod(const V1 &a, const V2& b) { return a.dot(b); } +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 /** * BLAS Level 1 scal: x <- alpha*x - * \deprecated: use operators instead + * @deprecated: use operators instead */ -inline void scal(double alpha, Vector& x) { x *= alpha; } +inline void GTSAM_DEPRECATED scal(double alpha, Vector& x) { x *= alpha; } /** * BLAS Level 1 axpy: y <- alpha*x + y - * \deprecated: use operators instead + * @deprecated: use operators instead */ template -inline void axpy(double alpha, const V1& x, V2& y) { +inline void GTSAM_DEPRECATED axpy(double alpha, const V1& x, V2& y) { assert (y.size()==x.size()); y += alpha * x; } @@ -223,6 +224,7 @@ inline void axpy(double alpha, const Vector& x, SubVector y) { assert (y.size()==x.size()); y += alpha * x; } +#endif /** * house(x,j) computes HouseHolder vector v and scaling factor beta @@ -263,46 +265,4 @@ GTSAM_EXPORT Vector concatVectors(const std::list& vs); * concatenate Vectors */ GTSAM_EXPORT Vector concatVectors(size_t nrVectors, ...); -} // namespace gtsam - -#include -#include -#include - -namespace boost { - namespace serialization { - - // split version - copies into an STL vector for serialization - template - void save(Archive & ar, const gtsam::Vector & v, unsigned int /*version*/) { - const size_t size = v.size(); - ar << BOOST_SERIALIZATION_NVP(size); - ar << make_nvp("data", make_array(v.data(), v.size())); - } - - template - void load(Archive & ar, gtsam::Vector & v, unsigned int /*version*/) { - size_t size; - ar >> BOOST_SERIALIZATION_NVP(size); - v.resize(size); - ar >> make_nvp("data", make_array(v.data(), v.size())); - } - - // split version - copies into an STL vector for serialization - template - void save(Archive & ar, const Eigen::Matrix & v, unsigned int /*version*/) { - ar << make_nvp("data", make_array(v.data(), v.RowsAtCompileTime)); - } - - template - void load(Archive & ar, Eigen::Matrix & v, unsigned int /*version*/) { - ar >> make_nvp("data", make_array(v.data(), v.RowsAtCompileTime)); - } - - } // namespace serialization -} // namespace boost - -BOOST_SERIALIZATION_SPLIT_FREE(gtsam::Vector) -BOOST_SERIALIZATION_SPLIT_FREE(gtsam::Vector2) -BOOST_SERIALIZATION_SPLIT_FREE(gtsam::Vector3) -BOOST_SERIALIZATION_SPLIT_FREE(gtsam::Vector6) +} // namespace gtsam diff --git a/gtsam/base/VectorSerialization.h b/gtsam/base/VectorSerialization.h new file mode 100644 index 000000000..97df02a75 --- /dev/null +++ b/gtsam/base/VectorSerialization.h @@ -0,0 +1,65 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file VectorSerialization.h + * @brief serialization for Vectors + * @author Frank Dellaert + * @date February 2022 + */ + +#pragma once + +#include + +#include +#include +#include + +namespace boost { +namespace serialization { + +// split version - copies into an STL vector for serialization +template +void save(Archive& ar, const gtsam::Vector& v, unsigned int /*version*/) { + const size_t size = v.size(); + ar << BOOST_SERIALIZATION_NVP(size); + ar << make_nvp("data", make_array(v.data(), v.size())); +} + +template +void load(Archive& ar, gtsam::Vector& v, unsigned int /*version*/) { + size_t size; + ar >> BOOST_SERIALIZATION_NVP(size); + v.resize(size); + ar >> make_nvp("data", make_array(v.data(), v.size())); +} + +// split version - copies into an STL vector for serialization +template +void save(Archive& ar, const Eigen::Matrix& v, + unsigned int /*version*/) { + ar << make_nvp("data", make_array(v.data(), v.RowsAtCompileTime)); +} + +template +void load(Archive& ar, Eigen::Matrix& v, + unsigned int /*version*/) { + ar >> make_nvp("data", make_array(v.data(), v.RowsAtCompileTime)); +} + +} // namespace serialization +} // namespace boost + +BOOST_SERIALIZATION_SPLIT_FREE(gtsam::Vector) +BOOST_SERIALIZATION_SPLIT_FREE(gtsam::Vector2) +BOOST_SERIALIZATION_SPLIT_FREE(gtsam::Vector3) +BOOST_SERIALIZATION_SPLIT_FREE(gtsam::Vector6) diff --git a/gtsam/base/VerticalBlockMatrix.h b/gtsam/base/VerticalBlockMatrix.h index 92031db2b..0d8d69df8 100644 --- a/gtsam/base/VerticalBlockMatrix.h +++ b/gtsam/base/VerticalBlockMatrix.h @@ -18,6 +18,7 @@ #pragma once #include +#include #include namespace gtsam { diff --git a/gtsam/base/base.i b/gtsam/base/base.i index d9c51fbe8..9b9f351ce 100644 --- a/gtsam/base/base.i +++ b/gtsam/base/base.i @@ -38,7 +38,7 @@ class DSFMap { DSFMap(); KEY find(const KEY& key) const; void merge(const KEY& x, const KEY& y); - std::map sets(); + std::map sets(); }; class IndexPairSet { @@ -82,6 +82,7 @@ class IndexPairSetMap { }; #include +#include bool linear_independent(Matrix A, Matrix B, double tol); #include diff --git a/gtsam/base/chartTesting.h b/gtsam/base/chartTesting.h index f63054a5b..8f5213f91 100644 --- a/gtsam/base/chartTesting.h +++ b/gtsam/base/chartTesting.h @@ -32,7 +32,7 @@ void testDefaultChart(TestResult& result_, const std::string& name_, const T& value) { - GTSAM_CONCEPT_TESTABLE_TYPE(T); + GTSAM_CONCEPT_TESTABLE_TYPE(T) typedef typename gtsam::DefaultChart Chart; typedef typename Chart::vector Vector; diff --git a/gtsam/base/cholesky.h b/gtsam/base/cholesky.h index 5e3276ff0..bf7d18a1d 100644 --- a/gtsam/base/cholesky.h +++ b/gtsam/base/cholesky.h @@ -18,7 +18,6 @@ #pragma once #include -#include namespace gtsam { diff --git a/gtsam/base/deprecated/LieMatrix.h b/gtsam/base/deprecated/LieMatrix.h deleted file mode 100644 index a3d0a4328..000000000 --- a/gtsam/base/deprecated/LieMatrix.h +++ /dev/null @@ -1,152 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file LieMatrix.h - * @brief A wrapper around Matrix providing Lie compatibility - * @author Richard Roberts and Alex Cunningham - */ - -#pragma once - -#include - -#include -#include - -namespace gtsam { - -/** - * @deprecated: LieMatrix, LieVector and LieMatrix are obsolete in GTSAM 4.0 as - * we can directly add double, Vector, and Matrix into values now, because of - * gtsam::traits. - */ -struct LieMatrix : public Matrix { - - /// @name Constructors - /// @{ - enum { dimension = Eigen::Dynamic }; - - /** default constructor - only for serialize */ - LieMatrix() {} - - /** initialize from a normal matrix */ - LieMatrix(const Matrix& v) : Matrix(v) {} - - template - LieMatrix(const M& v) : Matrix(v) {} - -// Currently TMP constructor causes ICE on MSVS 2013 -#if (_MSC_VER < 1800) - /** initialize from a fixed size normal vector */ - template - LieMatrix(const Eigen::Matrix& v) : Matrix(v) {} -#endif - - /** constructor with size and initial data, row order ! */ - LieMatrix(size_t m, size_t n, const double* const data) : - Matrix(Eigen::Map(data, m, n)) {} - - /// @} - /// @name Testable interface - /// @{ - - /** print @param s optional string naming the object */ - void print(const std::string& name = "") const { - gtsam::print(matrix(), name); - } - /** equality up to tolerance */ - inline bool equals(const LieMatrix& expected, double tol=1e-5) const { - return gtsam::equal_with_abs_tol(matrix(), expected.matrix(), tol); - } - - /// @} - /// @name Standard Interface - /// @{ - - /** get the underlying matrix */ - inline Matrix matrix() const { - return static_cast(*this); - } - - /// @} - - /// @name Group - /// @{ - LieMatrix compose(const LieMatrix& q) { return (*this)+q;} - LieMatrix between(const LieMatrix& q) { return q-(*this);} - LieMatrix inverse() { return -(*this);} - /// @} - - /// @name Manifold - /// @{ - Vector localCoordinates(const LieMatrix& q) { return between(q).vector();} - LieMatrix retract(const Vector& v) {return compose(LieMatrix(v));} - /// @} - - /// @name Lie Group - /// @{ - static Vector Logmap(const LieMatrix& p) {return p.vector();} - static LieMatrix Expmap(const Vector& v) { return LieMatrix(v);} - /// @} - - /// @name VectorSpace requirements - /// @{ - - /** Returns dimensionality of the tangent space */ - inline size_t dim() const { return size(); } - - /** Convert to vector, is done row-wise - TODO why? */ - inline Vector vector() const { - Vector result(size()); - typedef Eigen::Matrix RowMajor; - Eigen::Map(&result(0), rows(), cols()) = *this; - return result; - } - - /** identity - NOTE: no known size at compile time - so zero length */ - inline static LieMatrix identity() { - throw std::runtime_error("LieMatrix::identity(): Don't use this function"); - return LieMatrix(); - } - /// @} - -private: - - // Serialization function - friend class boost::serialization::access; - template - void serialize(Archive & ar, const unsigned int /*version*/) { - ar & boost::serialization::make_nvp("Matrix", - boost::serialization::base_object(*this)); - - } - -}; - - -template<> -struct traits : public internal::VectorSpace { - - // Override Retract, as the default version does not know how to initialize - static LieMatrix Retract(const LieMatrix& origin, const TangentVector& v, - ChartJacobian H1 = boost::none, ChartJacobian H2 = boost::none) { - if (H1) *H1 = Eye(origin); - if (H2) *H2 = Eye(origin); - typedef const Eigen::Matrix RowMajor; - return origin + Eigen::Map(&v(0), origin.rows(), origin.cols()); - } - -}; - -} // \namespace gtsam diff --git a/gtsam/base/deprecated/LieScalar.h b/gtsam/base/deprecated/LieScalar.h deleted file mode 100644 index 6c9a5f766..000000000 --- a/gtsam/base/deprecated/LieScalar.h +++ /dev/null @@ -1,88 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file LieScalar.h - * @brief A wrapper around scalar providing Lie compatibility - * @author Kai Ni - */ - -#pragma once - -#include -#include -#include - -namespace gtsam { - - /** - * @deprecated: LieScalar, LieVector and LieMatrix are obsolete in GTSAM 4.0 as - * we can directly add double, Vector, and Matrix into values now, because of - * gtsam::traits. - */ - struct LieScalar { - - enum { dimension = 1 }; - - /** default constructor */ - LieScalar() : d_(0.0) {} - - /** wrap a double */ - /*explicit*/ LieScalar(double d) : d_(d) {} - - /** access the underlying value */ - double value() const { return d_; } - - /** Automatic conversion to underlying value */ - operator double() const { return d_; } - - /** convert vector */ - Vector1 vector() const { Vector1 v; v< - struct traits : public internal::ScalarTraits {}; - -} // \namespace gtsam diff --git a/gtsam/base/deprecated/LieVector.h b/gtsam/base/deprecated/LieVector.h deleted file mode 100644 index 745189c3d..000000000 --- a/gtsam/base/deprecated/LieVector.h +++ /dev/null @@ -1,121 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file LieVector.h - * @brief A wrapper around vector providing Lie compatibility - * @author Alex Cunningham - */ - -#pragma once - -#include -#include - -namespace gtsam { - -/** - * @deprecated: LieVector, LieVector and LieMatrix are obsolete in GTSAM 4.0 as - * we can directly add double, Vector, and Matrix into values now, because of - * gtsam::traits. - */ -struct LieVector : public Vector { - - enum { dimension = Eigen::Dynamic }; - - /** default constructor - should be unnecessary */ - LieVector() {} - - /** initialize from a normal vector */ - LieVector(const Vector& v) : Vector(v) {} - - template - LieVector(const V& v) : Vector(v) {} - -// Currently TMP constructor causes ICE on MSVS 2013 -#if (_MSC_VER < 1800) - /** initialize from a fixed size normal vector */ - template - LieVector(const Eigen::Matrix& v) : Vector(v) {} -#endif - - /** wrap a double */ - LieVector(double d) : Vector((Vector(1) << d).finished()) {} - - /** constructor with size and initial data, row order ! */ - LieVector(size_t m, const double* const data) : Vector(m) { - for (size_t i = 0; i < m; i++) (*this)(i) = data[i]; - } - - /// @name Testable - /// @{ - void print(const std::string& name="") const { - gtsam::print(vector(), name); - } - bool equals(const LieVector& expected, double tol=1e-5) const { - return gtsam::equal(vector(), expected.vector(), tol); - } - /// @} - - /// @name Group - /// @{ - LieVector compose(const LieVector& q) { return (*this)+q;} - LieVector between(const LieVector& q) { return q-(*this);} - LieVector inverse() { return -(*this);} - /// @} - - /// @name Manifold - /// @{ - Vector localCoordinates(const LieVector& q) { return between(q).vector();} - LieVector retract(const Vector& v) {return compose(LieVector(v));} - /// @} - - /// @name Lie Group - /// @{ - static Vector Logmap(const LieVector& p) {return p.vector();} - static LieVector Expmap(const Vector& v) { return LieVector(v);} - /// @} - - /// @name VectorSpace requirements - /// @{ - - /** get the underlying vector */ - Vector vector() const { - return static_cast(*this); - } - - /** Returns dimensionality of the tangent space */ - size_t dim() const { return this->size(); } - - /** identity - NOTE: no known size at compile time - so zero length */ - static LieVector identity() { - throw std::runtime_error("LieVector::identity(): Don't use this function"); - return LieVector(); - } - - /// @} - -private: - - // Serialization function - friend class boost::serialization::access; - template - void serialize(Archive & ar, const unsigned int /*version*/) { - ar & boost::serialization::make_nvp("Vector", - boost::serialization::base_object(*this)); - } -}; - - -template<> -struct traits : public internal::VectorSpace {}; - -} // \namespace gtsam diff --git a/gtsam/base/serialization.h b/gtsam/base/serialization.h index f589ecc5e..e615afe83 100644 --- a/gtsam/base/serialization.h +++ b/gtsam/base/serialization.h @@ -19,11 +19,13 @@ #pragma once -#include +#include #include +#include #include // includes for standard serialization types +#include #include #include #include @@ -40,6 +42,17 @@ #include #include +// Workaround a bug in GCC >= 7 and C++17 +// ref. https://gitlab.com/libeigen/eigen/-/issues/1676 +#ifdef __GNUC__ +#if __GNUC__ >= 7 && __cplusplus >= 201703L +namespace boost { namespace serialization { struct U; } } +namespace Eigen { namespace internal { +template<> struct traits {enum {Flags=0};}; +} } +#endif +#endif + namespace gtsam { /** @name Standard serialization diff --git a/gtsam/base/serializationTestHelpers.h b/gtsam/base/serializationTestHelpers.h index 5994a5e51..bb8574245 100644 --- a/gtsam/base/serializationTestHelpers.h +++ b/gtsam/base/serializationTestHelpers.h @@ -42,7 +42,7 @@ T create() { } // Creates or empties a folder in the build folder and returns the relative path -boost::filesystem::path resetFilesystem( +inline boost::filesystem::path resetFilesystem( boost::filesystem::path folder = "actual") { boost::filesystem::remove_all(folder); boost::filesystem::create_directory(folder); diff --git a/gtsam/base/tests/testLieMatrix.cpp b/gtsam/base/tests/testLieMatrix.cpp deleted file mode 100644 index 8c68bf8a0..000000000 --- a/gtsam/base/tests/testLieMatrix.cpp +++ /dev/null @@ -1,70 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file testLieMatrix.cpp - * @author Richard Roberts - */ - -#include -#include -#include -#include - -using namespace gtsam; - -GTSAM_CONCEPT_TESTABLE_INST(LieMatrix) -GTSAM_CONCEPT_LIE_INST(LieMatrix) - -/* ************************************************************************* */ -TEST( LieMatrix, construction ) { - Matrix m = (Matrix(2,2) << 1.0,2.0, 3.0,4.0).finished(); - LieMatrix lie1(m), lie2(m); - - EXPECT(traits::GetDimension(m) == 4); - EXPECT(assert_equal(m, lie1.matrix())); - EXPECT(assert_equal(lie1, lie2)); -} - -/* ************************************************************************* */ -TEST( LieMatrix, other_constructors ) { - Matrix init = (Matrix(2,2) << 10.0,20.0, 30.0,40.0).finished(); - LieMatrix exp(init); - double data[] = {10,30,20,40}; - LieMatrix b(2,2,data); - EXPECT(assert_equal(exp, b)); -} - -/* ************************************************************************* */ -TEST(LieMatrix, retract) { - LieMatrix init((Matrix(2,2) << 1.0,2.0,3.0,4.0).finished()); - Vector update = (Vector(4) << 3.0, 4.0, 6.0, 7.0).finished(); - - LieMatrix expected((Matrix(2,2) << 4.0, 6.0, 9.0, 11.0).finished()); - LieMatrix actual = traits::Retract(init,update); - - EXPECT(assert_equal(expected, actual)); - - Vector expectedUpdate = update; - Vector actualUpdate = traits::Local(init,actual); - - EXPECT(assert_equal(expectedUpdate, actualUpdate)); - - Vector expectedLogmap = (Vector(4) << 1, 2, 3, 4).finished(); - Vector actualLogmap = traits::Logmap(LieMatrix((Matrix(2,2) << 1.0, 2.0, 3.0, 4.0).finished())); - EXPECT(assert_equal(expectedLogmap, actualLogmap)); -} - -/* ************************************************************************* */ -int main() { TestResult tr; return TestRegistry::runAllTests(tr); } -/* ************************************************************************* */ - - diff --git a/gtsam/base/tests/testLieScalar.cpp b/gtsam/base/tests/testLieScalar.cpp deleted file mode 100644 index 74f5e0d41..000000000 --- a/gtsam/base/tests/testLieScalar.cpp +++ /dev/null @@ -1,64 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file testLieScalar.cpp - * @author Kai Ni - */ - -#include -#include -#include -#include - -using namespace gtsam; - -GTSAM_CONCEPT_TESTABLE_INST(LieScalar) -GTSAM_CONCEPT_LIE_INST(LieScalar) - -const double tol=1e-9; - -//****************************************************************************** -TEST(LieScalar , Concept) { - BOOST_CONCEPT_ASSERT((IsGroup)); - BOOST_CONCEPT_ASSERT((IsManifold)); - BOOST_CONCEPT_ASSERT((IsLieGroup)); -} - -//****************************************************************************** -TEST(LieScalar , Invariants) { - LieScalar lie1(2), lie2(3); - CHECK(check_group_invariants(lie1, lie2)); - CHECK(check_manifold_invariants(lie1, lie2)); -} - -/* ************************************************************************* */ -TEST( testLieScalar, construction ) { - double d = 2.; - LieScalar lie1(d), lie2(d); - - EXPECT_DOUBLES_EQUAL(2., lie1.value(),tol); - EXPECT_DOUBLES_EQUAL(2., lie2.value(),tol); - EXPECT(traits::dimension == 1); - EXPECT(assert_equal(lie1, lie2)); -} - -/* ************************************************************************* */ -TEST( testLieScalar, localCoordinates ) { - LieScalar lie1(1.), lie2(3.); - - Vector1 actual = traits::Local(lie1, lie2); - EXPECT( assert_equal((Vector)(Vector(1) << 2).finished(), actual)); -} - -/* ************************************************************************* */ -int main() { TestResult tr; return TestRegistry::runAllTests(tr); } -/* ************************************************************************* */ diff --git a/gtsam/base/tests/testLieVector.cpp b/gtsam/base/tests/testLieVector.cpp deleted file mode 100644 index 76c4fc490..000000000 --- a/gtsam/base/tests/testLieVector.cpp +++ /dev/null @@ -1,66 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file testLieVector.cpp - * @author Alex Cunningham - */ - -#include -#include -#include -#include - -using namespace gtsam; - -GTSAM_CONCEPT_TESTABLE_INST(LieVector) -GTSAM_CONCEPT_LIE_INST(LieVector) - -//****************************************************************************** -TEST(LieVector , Concept) { - BOOST_CONCEPT_ASSERT((IsGroup)); - BOOST_CONCEPT_ASSERT((IsManifold)); - BOOST_CONCEPT_ASSERT((IsLieGroup)); -} - -//****************************************************************************** -TEST(LieVector , Invariants) { - Vector v = Vector3(1.0, 2.0, 3.0); - LieVector lie1(v), lie2(v); - check_manifold_invariants(lie1, lie2); -} - -//****************************************************************************** -TEST( testLieVector, construction ) { - Vector v = Vector3(1.0, 2.0, 3.0); - LieVector lie1(v), lie2(v); - - EXPECT(lie1.dim() == 3); - EXPECT(assert_equal(v, lie1.vector())); - EXPECT(assert_equal(lie1, lie2)); -} - -//****************************************************************************** -TEST( testLieVector, other_constructors ) { - Vector init = Vector2(10.0, 20.0); - LieVector exp(init); - double data[] = { 10, 20 }; - LieVector b(2, data); - EXPECT(assert_equal(exp, b)); -} - -/* ************************************************************************* */ -int main() { - TestResult tr; - return TestRegistry::runAllTests(tr); -} -/* ************************************************************************* */ - diff --git a/gtsam/base/tests/testMatrix.cpp b/gtsam/base/tests/testMatrix.cpp index a7c218705..7802f27e1 100644 --- a/gtsam/base/tests/testMatrix.cpp +++ b/gtsam/base/tests/testMatrix.cpp @@ -173,7 +173,7 @@ TEST(Matrix, stack ) { Matrix A = (Matrix(2, 2) << -5.0, 3.0, 00.0, -5.0).finished(); Matrix B = (Matrix(3, 2) << -0.5, 2.1, 1.1, 3.4, 2.6, 7.1).finished(); - Matrix AB = stack(2, &A, &B); + Matrix AB = gtsam::stack(2, &A, &B); Matrix C(5, 2); for (int i = 0; i < 2; i++) for (int j = 0; j < 2; j++) @@ -187,7 +187,7 @@ TEST(Matrix, stack ) std::vector matrices; matrices.push_back(A); matrices.push_back(B); - Matrix AB2 = stack(matrices); + Matrix AB2 = gtsam::stack(matrices); EQUALITY(C,AB2); } diff --git a/gtsam/base/tests/testOptionalJacobian.cpp b/gtsam/base/tests/testOptionalJacobian.cpp index 128576107..ae91642f4 100644 --- a/gtsam/base/tests/testOptionalJacobian.cpp +++ b/gtsam/base/tests/testOptionalJacobian.cpp @@ -24,40 +24,33 @@ using namespace std; using namespace gtsam; //****************************************************************************** +#define TEST_CONSTRUCTOR(DIM1, DIM2, X, TRUTHY) \ + { \ + OptionalJacobian H(X); \ + EXPECT(H == TRUTHY); \ + } TEST( OptionalJacobian, Constructors ) { Matrix23 fixed; - - OptionalJacobian<2, 3> H1; - EXPECT(!H1); - - OptionalJacobian<2, 3> H2(fixed); - EXPECT(H2); - - OptionalJacobian<2, 3> H3(&fixed); - EXPECT(H3); - Matrix dynamic; - OptionalJacobian<2, 3> H4(dynamic); - EXPECT(H4); - - OptionalJacobian<2, 3> H5(boost::none); - EXPECT(!H5); - boost::optional optional(dynamic); - OptionalJacobian<2, 3> H6(optional); - EXPECT(H6); + OptionalJacobian<2, 3> H; + EXPECT(!H); + + TEST_CONSTRUCTOR(2, 3, fixed, true); + TEST_CONSTRUCTOR(2, 3, &fixed, true); + TEST_CONSTRUCTOR(2, 3, dynamic, true); + TEST_CONSTRUCTOR(2, 3, &dynamic, true); + TEST_CONSTRUCTOR(2, 3, boost::none, false); + TEST_CONSTRUCTOR(2, 3, optional, true); + + // Test dynamic OptionalJacobian<-1, -1> H7; EXPECT(!H7); - OptionalJacobian<-1, -1> H8(dynamic); - EXPECT(H8); - - OptionalJacobian<-1, -1> H9(boost::none); - EXPECT(!H9); - - OptionalJacobian<-1, -1> H10(optional); - EXPECT(H10); + TEST_CONSTRUCTOR(-1, -1, dynamic, true); + TEST_CONSTRUCTOR(-1, -1, boost::none, false); + TEST_CONSTRUCTOR(-1, -1, optional, true); } //****************************************************************************** @@ -101,6 +94,25 @@ TEST( OptionalJacobian, Fixed) { dynamic2.setOnes(); test(dynamic2); EXPECT(assert_equal(kTestMatrix, dynamic2)); + + { // Dynamic pointer + // Passing in an empty matrix means we want it resized + Matrix dynamic0; + test(&dynamic0); + EXPECT(assert_equal(kTestMatrix, dynamic0)); + + // Dynamic wrong size + Matrix dynamic1(3, 5); + dynamic1.setOnes(); + test(&dynamic1); + EXPECT(assert_equal(kTestMatrix, dynamic1)); + + // Dynamic right size + Matrix dynamic2(2, 5); + dynamic2.setOnes(); + test(&dynamic2); + EXPECT(assert_equal(kTestMatrix, dynamic2)); + } } //****************************************************************************** diff --git a/gtsam/base/tests/testSerializationBase.cpp b/gtsam/base/tests/testSerializationBase.cpp index d863eaba3..f7aa97b31 100644 --- a/gtsam/base/tests/testSerializationBase.cpp +++ b/gtsam/base/tests/testSerializationBase.cpp @@ -19,6 +19,7 @@ #include #include +#include #include #include #include diff --git a/gtsam/base/tests/testTestableAssertions.cpp b/gtsam/base/tests/testTestableAssertions.cpp deleted file mode 100644 index 305aa7ca9..000000000 --- a/gtsam/base/tests/testTestableAssertions.cpp +++ /dev/null @@ -1,35 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file testTestableAssertions - * @author Alex Cunningham - */ - -#include -#include -#include - -using namespace gtsam; - -/* ************************************************************************* */ -TEST( testTestableAssertions, optional ) { - typedef boost::optional OptionalScalar; - LieScalar x(1.0); - OptionalScalar ox(x), dummy = boost::none; - EXPECT(assert_equal(ox, ox)); - EXPECT(assert_equal(x, ox)); - EXPECT(assert_equal(dummy, dummy)); -} - -/* ************************************************************************* */ -int main() { TestResult tr; return TestRegistry::runAllTests(tr); } -/* ************************************************************************* */ diff --git a/gtsam/base/tests/testVector.cpp b/gtsam/base/tests/testVector.cpp index bd715e3cb..c87732b09 100644 --- a/gtsam/base/tests/testVector.cpp +++ b/gtsam/base/tests/testVector.cpp @@ -220,8 +220,8 @@ TEST(Vector, axpy ) Vector x = Vector3(10., 20., 30.); Vector y0 = Vector3(2.0, 5.0, 6.0); Vector y1 = y0, y2 = y0; - axpy(0.1,x,y1); - axpy(0.1,x,y2.head(3)); + y1 += 0.1 * x; + y2.head(3) += 0.1 * x; Vector expected = Vector3(3.0, 7.0, 9.0); EXPECT(assert_equal(expected,y1)); EXPECT(assert_equal(expected,Vector(y2))); diff --git a/gtsam/base/treeTraversal-inst.h b/gtsam/base/treeTraversal-inst.h index 7a88f72eb..30cec3b9a 100644 --- a/gtsam/base/treeTraversal-inst.h +++ b/gtsam/base/treeTraversal-inst.h @@ -158,9 +158,8 @@ void DepthFirstForestParallel(FOREST& forest, DATA& rootData, // Typedefs typedef typename FOREST::Node Node; - tbb::task::spawn_root_and_wait( - internal::CreateRootTask(forest.roots(), rootData, visitorPre, - visitorPost, problemSizeThreshold)); + internal::CreateRootTask(forest.roots(), rootData, visitorPre, + visitorPost, problemSizeThreshold); #else DepthFirstForest(forest, rootData, visitorPre, visitorPost); #endif diff --git a/gtsam/base/treeTraversal/parallelTraversalTasks.h b/gtsam/base/treeTraversal/parallelTraversalTasks.h index 87d5b0d4c..dc1b45906 100644 --- a/gtsam/base/treeTraversal/parallelTraversalTasks.h +++ b/gtsam/base/treeTraversal/parallelTraversalTasks.h @@ -22,7 +22,7 @@ #include #ifdef GTSAM_USE_TBB -#include // tbb::task, tbb::task_list +#include // tbb::task_group #include // tbb::scalable_allocator namespace gtsam { @@ -34,7 +34,7 @@ namespace gtsam { /* ************************************************************************* */ template - class PreOrderTask : public tbb::task + class PreOrderTask { public: const boost::shared_ptr& treeNode; @@ -42,28 +42,30 @@ namespace gtsam { VISITOR_PRE& visitorPre; VISITOR_POST& visitorPost; int problemSizeThreshold; + tbb::task_group& tg; bool makeNewTasks; - bool isPostOrderPhase; + // Keep track of order phase across multiple calls to the same functor + mutable bool isPostOrderPhase; PreOrderTask(const boost::shared_ptr& treeNode, const boost::shared_ptr& myData, VISITOR_PRE& visitorPre, VISITOR_POST& visitorPost, int problemSizeThreshold, - bool makeNewTasks = true) + tbb::task_group& tg, bool makeNewTasks = true) : treeNode(treeNode), myData(myData), visitorPre(visitorPre), visitorPost(visitorPost), problemSizeThreshold(problemSizeThreshold), + tg(tg), makeNewTasks(makeNewTasks), isPostOrderPhase(false) {} - tbb::task* execute() override + void operator()() const { if(isPostOrderPhase) { // Run the post-order visitor since this task was recycled to run the post-order visitor (void) visitorPost(treeNode, *myData); - return nullptr; } else { @@ -71,14 +73,10 @@ namespace gtsam { { if(!treeNode->children.empty()) { - // Allocate post-order task as a continuation - isPostOrderPhase = true; - recycle_as_continuation(); - bool overThreshold = (treeNode->problemSize() >= problemSizeThreshold); - tbb::task* firstChild = 0; - tbb::task_list childTasks; + // If we have child tasks, start subtasks and wait for them to complete + tbb::task_group ctg; for(const boost::shared_ptr& child: treeNode->children) { // Process child in a subtask. Important: Run visitorPre before calling @@ -86,37 +84,30 @@ namespace gtsam { // allocated an extra child, this causes a TBB error. boost::shared_ptr childData = boost::allocate_shared( tbb::scalable_allocator(), visitorPre(child, *myData)); - tbb::task* childTask = - new (allocate_child()) PreOrderTask(child, childData, visitorPre, visitorPost, - problemSizeThreshold, overThreshold); - if (firstChild) - childTasks.push_back(*childTask); - else - firstChild = childTask; + ctg.run(PreOrderTask(child, childData, visitorPre, visitorPost, + problemSizeThreshold, ctg, overThreshold)); } + ctg.wait(); - // If we have child tasks, start subtasks and wait for them to complete - set_ref_count((int)treeNode->children.size()); - spawn(childTasks); - return firstChild; + // Allocate post-order task as a continuation + isPostOrderPhase = true; + tg.run(*this); } else { // Run the post-order visitor in this task if we have no children (void) visitorPost(treeNode, *myData); - return nullptr; } } else { // Process this node and its children in this task processNodeRecursively(treeNode, *myData); - return nullptr; } } } - void processNodeRecursively(const boost::shared_ptr& node, DATA& myData) + void processNodeRecursively(const boost::shared_ptr& node, DATA& myData) const { for(const boost::shared_ptr& child: node->children) { @@ -131,7 +122,7 @@ namespace gtsam { /* ************************************************************************* */ template - class RootTask : public tbb::task + class RootTask { public: const ROOTS& roots; @@ -139,38 +130,31 @@ namespace gtsam { VISITOR_PRE& visitorPre; VISITOR_POST& visitorPost; int problemSizeThreshold; + tbb::task_group& tg; RootTask(const ROOTS& roots, DATA& myData, VISITOR_PRE& visitorPre, VISITOR_POST& visitorPost, - int problemSizeThreshold) : + int problemSizeThreshold, tbb::task_group& tg) : roots(roots), myData(myData), visitorPre(visitorPre), visitorPost(visitorPost), - problemSizeThreshold(problemSizeThreshold) {} + problemSizeThreshold(problemSizeThreshold), tg(tg) {} - tbb::task* execute() override + void operator()() const { typedef PreOrderTask PreOrderTask; // Create data and tasks for our children - tbb::task_list tasks; for(const boost::shared_ptr& root: roots) { boost::shared_ptr rootData = boost::allocate_shared(tbb::scalable_allocator(), visitorPre(root, myData)); - tasks.push_back(*new(allocate_child()) - PreOrderTask(root, rootData, visitorPre, visitorPost, problemSizeThreshold)); + tg.run(PreOrderTask(root, rootData, visitorPre, visitorPost, problemSizeThreshold, tg)); } - // Set TBB ref count - set_ref_count(1 + (int) roots.size()); - // Spawn tasks - spawn_and_wait_for_all(tasks); - // Return nullptr - return nullptr; } }; template - RootTask& - CreateRootTask(const ROOTS& roots, DATA& rootData, VISITOR_PRE& visitorPre, VISITOR_POST& visitorPost, int problemSizeThreshold) + void CreateRootTask(const ROOTS& roots, DATA& rootData, VISITOR_PRE& visitorPre, VISITOR_POST& visitorPost, int problemSizeThreshold) { typedef RootTask RootTask; - return *new(tbb::task::allocate_root()) RootTask(roots, rootData, visitorPre, visitorPost, problemSizeThreshold); - } + tbb::task_group tg; + tg.run_and_wait(RootTask(roots, rootData, visitorPre, visitorPost, problemSizeThreshold, tg)); + } } diff --git a/gtsam/base/types.h b/gtsam/base/types.h index aaada3cee..a0d24f1a6 100644 --- a/gtsam/base/types.h +++ b/gtsam/base/types.h @@ -34,6 +34,14 @@ #include #endif +#if defined(__GNUC__) || defined(__clang__) +#define GTSAM_DEPRECATED __attribute__((deprecated)) +#elif defined(_MSC_VER) +#define GTSAM_DEPRECATED __declspec(deprecated) +#else +#define GTSAM_DEPRECATED +#endif + #ifdef GTSAM_USE_EIGEN_MKL_OPENMP #include #endif diff --git a/gtsam/base/utilities.cpp b/gtsam/base/utilities.cpp new file mode 100644 index 000000000..189156c91 --- /dev/null +++ b/gtsam/base/utilities.cpp @@ -0,0 +1,13 @@ +#include + +namespace gtsam { + +std::string RedirectCout::str() const { + return ssBuffer_.str(); +} + +RedirectCout::~RedirectCout() { + std::cout.rdbuf(coutBuffer_); +} + +} diff --git a/gtsam/base/utilities.h b/gtsam/base/utilities.h index 8eb5617a8..d9b92b8aa 100644 --- a/gtsam/base/utilities.h +++ b/gtsam/base/utilities.h @@ -1,5 +1,9 @@ #pragma once +#include +#include +#include + namespace gtsam { /** * For Python __str__(). @@ -12,14 +16,10 @@ struct RedirectCout { RedirectCout() : ssBuffer_(), coutBuffer_(std::cout.rdbuf(ssBuffer_.rdbuf())) {} /// return the string - std::string str() const { - return ssBuffer_.str(); - } + std::string str() const; /// destructor -- redirect stdout buffer to its original buffer - ~RedirectCout() { - std::cout.rdbuf(coutBuffer_); - } + ~RedirectCout(); private: std::stringstream ssBuffer_; diff --git a/gtsam/basis/Basis.h b/gtsam/basis/Basis.h index d8bd28c1a..765a2f645 100644 --- a/gtsam/basis/Basis.h +++ b/gtsam/basis/Basis.h @@ -92,7 +92,7 @@ Matrix kroneckerProductIdentity(const Weights& w) { /// CRTP Base class for function bases template -class GTSAM_EXPORT Basis { +class Basis { public: /** * Calculate weights for all x in vector X. @@ -497,11 +497,6 @@ class GTSAM_EXPORT Basis { } }; - // Vector version for MATLAB :-( - static double Derivative(double x, const Vector& p, // - OptionalJacobian H = boost::none) { - return DerivativeFunctor(x)(p.transpose(), H); - } }; } // namespace gtsam diff --git a/gtsam/basis/BasisFactors.h b/gtsam/basis/BasisFactors.h index 0b3d4c1a0..648bcd510 100644 --- a/gtsam/basis/BasisFactors.h +++ b/gtsam/basis/BasisFactors.h @@ -29,9 +29,12 @@ namespace gtsam { * pseudo-spectral parameterization. * * @tparam BASIS The basis class to use e.g. Chebyshev2 + * + * Example, degree 8 Chebyshev polynomial measured at x=0.5: + * EvaluationFactor factor(key, measured, model, 8, 0.5); */ template -class GTSAM_EXPORT EvaluationFactor : public FunctorizedFactor { +class EvaluationFactor : public FunctorizedFactor { private: using Base = FunctorizedFactor; @@ -47,7 +50,7 @@ class GTSAM_EXPORT EvaluationFactor : public FunctorizedFactor { * @param N The degree of the polynomial. * @param x The point at which to evaluate the polynomial. */ - EvaluationFactor(Key key, const double &z, const SharedNoiseModel &model, + EvaluationFactor(Key key, double z, const SharedNoiseModel &model, const size_t N, double x) : Base(key, z, model, typename BASIS::EvaluationFunctor(N, x)) {} @@ -62,7 +65,7 @@ class GTSAM_EXPORT EvaluationFactor : public FunctorizedFactor { * @param a Lower bound for the polynomial. * @param b Upper bound for the polynomial. */ - EvaluationFactor(Key key, const double &z, const SharedNoiseModel &model, + EvaluationFactor(Key key, double z, const SharedNoiseModel &model, const size_t N, double x, double a, double b) : Base(key, z, model, typename BASIS::EvaluationFunctor(N, x, a, b)) {} @@ -85,7 +88,7 @@ class GTSAM_EXPORT EvaluationFactor : public FunctorizedFactor { * @param M: Size of the evaluated state vector. */ template -class GTSAM_EXPORT VectorEvaluationFactor +class VectorEvaluationFactor : public FunctorizedFactor> { private: using Base = FunctorizedFactor>; @@ -148,7 +151,7 @@ class GTSAM_EXPORT VectorEvaluationFactor * where N is the degree and i is the component index. */ template -class GTSAM_EXPORT VectorComponentFactor +class VectorComponentFactor : public FunctorizedFactor> { private: using Base = FunctorizedFactor>; @@ -217,7 +220,7 @@ class GTSAM_EXPORT VectorComponentFactor * where `x` is the value (e.g. timestep) at which the rotation was evaluated. */ template -class GTSAM_EXPORT ManifoldEvaluationFactor +class ManifoldEvaluationFactor : public FunctorizedFactor::dimension>> { private: using Base = FunctorizedFactor::dimension>>; @@ -269,7 +272,7 @@ class GTSAM_EXPORT ManifoldEvaluationFactor * @param BASIS: The basis class to use e.g. Chebyshev2 */ template -class GTSAM_EXPORT DerivativeFactor +class DerivativeFactor : public FunctorizedFactor { private: using Base = FunctorizedFactor; @@ -318,7 +321,7 @@ class GTSAM_EXPORT DerivativeFactor * @param M: Size of the evaluated state vector derivative. */ template -class GTSAM_EXPORT VectorDerivativeFactor +class VectorDerivativeFactor : public FunctorizedFactor> { private: using Base = FunctorizedFactor>; @@ -371,7 +374,7 @@ class GTSAM_EXPORT VectorDerivativeFactor * @param P: Size of the control component derivative. */ template -class GTSAM_EXPORT ComponentDerivativeFactor +class ComponentDerivativeFactor : public FunctorizedFactor> { private: using Base = FunctorizedFactor>; diff --git a/gtsam/basis/Chebyshev.h b/gtsam/basis/Chebyshev.h index d16ccfaac..1c16c47bf 100644 --- a/gtsam/basis/Chebyshev.h +++ b/gtsam/basis/Chebyshev.h @@ -21,8 +21,6 @@ #include #include -#include - namespace gtsam { /** @@ -31,7 +29,7 @@ namespace gtsam { * These are typically denoted with the symbol T_n, where n is the degree. * The parameter N is the number of coefficients, i.e., N = n+1. */ -struct Chebyshev1Basis : Basis { +struct GTSAM_EXPORT Chebyshev1Basis : Basis { using Parameters = Eigen::Matrix; Parameters parameters_; @@ -79,7 +77,7 @@ struct Chebyshev1Basis : Basis { * functions. In this sense, they are like the sines and cosines of the Fourier * basis. */ -struct Chebyshev2Basis : Basis { +struct GTSAM_EXPORT Chebyshev2Basis : Basis { using Parameters = Eigen::Matrix; /** diff --git a/gtsam/basis/Chebyshev2.h b/gtsam/basis/Chebyshev2.h index 28590961d..e306c93d5 100644 --- a/gtsam/basis/Chebyshev2.h +++ b/gtsam/basis/Chebyshev2.h @@ -22,8 +22,7 @@ * * This is different from Chebyshev.h since it leverage ideas from * pseudo-spectral optimization, i.e. we don't decompose into basis functions, - * rather estimate function parameters that enforce function nodes at Chebyshev - * points. + * rather estimate function values at the Chebyshev points. * * Please refer to Agrawal21icra for more details. * diff --git a/gtsam/basis/Fourier.h b/gtsam/basis/Fourier.h index d264e182d..eb259bd8a 100644 --- a/gtsam/basis/Fourier.h +++ b/gtsam/basis/Fourier.h @@ -24,7 +24,7 @@ namespace gtsam { /// Fourier basis -class GTSAM_EXPORT FourierBasis : public Basis { +class FourierBasis : public Basis { public: using Parameters = Eigen::Matrix; using DiffMatrix = Eigen::Matrix; diff --git a/gtsam/basis/ParameterMatrix.h b/gtsam/basis/ParameterMatrix.h index df2d9f62e..eddcbfeae 100644 --- a/gtsam/basis/ParameterMatrix.h +++ b/gtsam/basis/ParameterMatrix.h @@ -153,7 +153,7 @@ class ParameterMatrix { return matrix_ * other; } - /// @name Vector Space requirements, following LieMatrix + /// @name Vector Space requirements /// @{ /** diff --git a/gtsam/basis/basis.i b/gtsam/basis/basis.i index 8f06fd2e1..a6c9d87ee 100644 --- a/gtsam/basis/basis.i +++ b/gtsam/basis/basis.i @@ -44,9 +44,6 @@ class Chebyshev2 { static Matrix DerivativeWeights(size_t N, double x, double a, double b); static Matrix IntegrationWeights(size_t N, double a, double b); static Matrix DifferentiationMatrix(size_t N, double a, double b); - - // TODO Needs OptionalJacobian - // static double Derivative(double x, Vector f); }; #include @@ -140,7 +137,7 @@ class FitBasis { static gtsam::GaussianFactorGraph::shared_ptr LinearGraph( const std::map& sequence, const gtsam::noiseModel::Base* model, size_t N); - Parameters parameters() const; + This::Parameters parameters() const; }; } // namespace gtsam diff --git a/gtsam/basis/tests/testBasisFactors.cpp b/gtsam/basis/tests/testBasisFactors.cpp new file mode 100644 index 000000000..18a389da5 --- /dev/null +++ b/gtsam/basis/tests/testBasisFactors.cpp @@ -0,0 +1,230 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------1------------------------------------------- + */ + +/** + * @file testBasisFactors.cpp + * @date May 31, 2020 + * @author Varun Agrawal + * @brief unit tests for factors in BasisFactors.h + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +using gtsam::noiseModel::Isotropic; +using gtsam::Pose2; +using gtsam::Vector; +using gtsam::Values; +using gtsam::Chebyshev2; +using gtsam::ParameterMatrix; +using gtsam::LevenbergMarquardtParams; +using gtsam::LevenbergMarquardtOptimizer; +using gtsam::NonlinearFactorGraph; +using gtsam::NonlinearOptimizerParams; + +constexpr size_t N = 2; + +// Key used in all tests +const gtsam::Symbol key('X', 0); + +//****************************************************************************** +TEST(BasisFactors, EvaluationFactor) { + using gtsam::EvaluationFactor; + + double measured = 0; + + auto model = Isotropic::Sigma(1, 1.0); + EvaluationFactor factor(key, measured, model, N, 0); + + NonlinearFactorGraph graph; + graph.add(factor); + + Vector functionValues(N); + functionValues.setZero(); + + Values initial; + initial.insert(key, functionValues); + + LevenbergMarquardtParams parameters; + parameters.setMaxIterations(20); + Values result = + LevenbergMarquardtOptimizer(graph, initial, parameters).optimize(); + + EXPECT_DOUBLES_EQUAL(0, graph.error(result), 1e-9); +} + +//****************************************************************************** +TEST(BasisFactors, VectorEvaluationFactor) { + using gtsam::VectorEvaluationFactor; + const size_t M = 4; + + const Vector measured = Vector::Zero(M); + + auto model = Isotropic::Sigma(M, 1.0); + VectorEvaluationFactor factor(key, measured, model, N, 0); + + NonlinearFactorGraph graph; + graph.add(factor); + + ParameterMatrix stateMatrix(N); + + Values initial; + initial.insert>(key, stateMatrix); + + LevenbergMarquardtParams parameters; + parameters.setMaxIterations(20); + Values result = + LevenbergMarquardtOptimizer(graph, initial, parameters).optimize(); + + EXPECT_DOUBLES_EQUAL(0, graph.error(result), 1e-9); +} + +//****************************************************************************** +TEST(BasisFactors, Print) { + using gtsam::VectorEvaluationFactor; + const size_t M = 1; + + const Vector measured = Vector::Ones(M) * 42; + + auto model = Isotropic::Sigma(M, 1.0); + VectorEvaluationFactor factor(key, measured, model, N, 0); + + std::string expected = + " keys = { X0 }\n" + " noise model: unit (1) \n" + "FunctorizedFactor(X0)\n" + " measurement: [\n" + " 42\n" + "]\n" + " noise model sigmas: 1\n"; + + EXPECT(assert_print_equal(expected, factor)); +} + +//****************************************************************************** +TEST(BasisFactors, VectorComponentFactor) { + using gtsam::VectorComponentFactor; + const int P = 4; + const size_t i = 2; + const double measured = 0.0, t = 3.0, a = 2.0, b = 4.0; + auto model = Isotropic::Sigma(1, 1.0); + VectorComponentFactor factor(key, measured, model, N, i, + t, a, b); + + NonlinearFactorGraph graph; + graph.add(factor); + + ParameterMatrix

stateMatrix(N); + + Values initial; + initial.insert>(key, stateMatrix); + + LevenbergMarquardtParams parameters; + parameters.setMaxIterations(20); + Values result = + LevenbergMarquardtOptimizer(graph, initial, parameters).optimize(); + + EXPECT_DOUBLES_EQUAL(0, graph.error(result), 1e-9); +} + +//****************************************************************************** +TEST(BasisFactors, ManifoldEvaluationFactor) { + using gtsam::ManifoldEvaluationFactor; + const Pose2 measured; + const double t = 3.0, a = 2.0, b = 4.0; + auto model = Isotropic::Sigma(3, 1.0); + ManifoldEvaluationFactor factor(key, measured, model, N, + t, a, b); + + NonlinearFactorGraph graph; + graph.add(factor); + + ParameterMatrix<3> stateMatrix(N); + + Values initial; + initial.insert>(key, stateMatrix); + + LevenbergMarquardtParams parameters; + parameters.setMaxIterations(20); + Values result = + LevenbergMarquardtOptimizer(graph, initial, parameters).optimize(); + + EXPECT_DOUBLES_EQUAL(0, graph.error(result), 1e-9); +} + +//****************************************************************************** +TEST(BasisFactors, VecDerivativePrior) { + using gtsam::VectorDerivativeFactor; + const size_t M = 4; + + const Vector measured = Vector::Zero(M); + auto model = Isotropic::Sigma(M, 1.0); + VectorDerivativeFactor vecDPrior(key, measured, model, N, 0); + + NonlinearFactorGraph graph; + graph.add(vecDPrior); + + ParameterMatrix stateMatrix(N); + + Values initial; + initial.insert>(key, stateMatrix); + + LevenbergMarquardtParams parameters; + parameters.setMaxIterations(20); + Values result = + LevenbergMarquardtOptimizer(graph, initial, parameters).optimize(); + + EXPECT_DOUBLES_EQUAL(0, graph.error(result), 1e-9); +} + +//****************************************************************************** +TEST(BasisFactors, ComponentDerivativeFactor) { + using gtsam::ComponentDerivativeFactor; + const size_t M = 4; + + double measured = 0; + auto model = Isotropic::Sigma(1, 1.0); + ComponentDerivativeFactor controlDPrior(key, measured, model, + N, 0, 0); + + NonlinearFactorGraph graph; + graph.add(controlDPrior); + + Values initial; + ParameterMatrix stateMatrix(N); + initial.insert>(key, stateMatrix); + + LevenbergMarquardtParams parameters; + parameters.setMaxIterations(20); + Values result = + LevenbergMarquardtOptimizer(graph, initial, parameters).optimize(); + + EXPECT_DOUBLES_EQUAL(0, graph.error(result), 1e-9); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/gtsam/basis/tests/testChebyshev.cpp b/gtsam/basis/tests/testChebyshev.cpp index 64c925886..7d7f9323d 100644 --- a/gtsam/basis/tests/testChebyshev.cpp +++ b/gtsam/basis/tests/testChebyshev.cpp @@ -25,9 +25,10 @@ using namespace std; using namespace gtsam; +namespace { auto model = noiseModel::Unit::Create(1); - const size_t N = 3; +} // namespace //****************************************************************************** TEST(Chebyshev, Chebyshev1) { diff --git a/gtsam/basis/tests/testChebyshev2.cpp b/gtsam/basis/tests/testChebyshev2.cpp index 4cee70daf..9090757f4 100644 --- a/gtsam/basis/tests/testChebyshev2.cpp +++ b/gtsam/basis/tests/testChebyshev2.cpp @@ -10,26 +10,30 @@ * -------------------------------------------------------------------------- */ /** - * @file testChebyshev.cpp + * @file testChebyshev2.cpp * @date July 4, 2020 * @author Varun Agrawal * @brief Unit tests for Chebyshev Basis Decompositions via pseudo-spectral * methods */ -#include -#include #include #include +#include #include +#include + +#include using namespace std; using namespace gtsam; using namespace boost::placeholders; +namespace { noiseModel::Diagonal::shared_ptr model = noiseModel::Unit::Create(1); const size_t N = 32; +} // namespace //****************************************************************************** TEST(Chebyshev2, Point) { @@ -121,12 +125,30 @@ TEST(Chebyshev2, InterpolateVector) { EXPECT(assert_equal(numericalH, actualH, 1e-9)); } +//****************************************************************************** +// Interpolating poses using the exponential map +TEST(Chebyshev2, InterpolatePose2) { + double t = 30, a = 0, b = 100; + + ParameterMatrix<3> X(N); + X.row(0) = Chebyshev2::Points(N, a, b); // slope 1 ramp + X.row(1) = Vector::Zero(N); + X.row(2) = 0.1 * Vector::Ones(N); + + Vector xi(3); + xi << t, 0, 0.1; + Chebyshev2::ManifoldEvaluationFunctor fx(N, t, a, b); + // We use xi as canonical coordinates via exponential map + Pose2 expected = Pose2::ChartAtOrigin::Retract(xi); + EXPECT(assert_equal(expected, fx(X))); +} + //****************************************************************************** TEST(Chebyshev2, Decomposition) { // Create example sequence Sequence sequence; for (size_t i = 0; i < 16; i++) { - double x = (double)i / 16. - 0.99, y = x; + double x = (1.0/ 16)*i - 0.99, y = x; sequence[x] = y; } @@ -144,11 +166,11 @@ TEST(Chebyshev2, DifferentiationMatrix3) { // Trefethen00book, p.55 const size_t N = 3; Matrix expected(N, N); - // Differentiation matrix computed from Chebfun + // Differentiation matrix computed from chebfun expected << 1.5000, -2.0000, 0.5000, // 0.5000, -0.0000, -0.5000, // -0.5000, 2.0000, -1.5000; - // multiply by -1 since the cheb points have a phase shift wrt Trefethen + // multiply by -1 since the chebyshev points have a phase shift wrt Trefethen // This was verified with chebfun expected = -expected; @@ -167,7 +189,7 @@ TEST(Chebyshev2, DerivativeMatrix6) { 0.3820, -0.8944, 1.6180, 0.1708, -2.0000, 0.7236, // -0.2764, 0.6180, -0.8944, 2.0000, 1.1708, -2.6180, // 0.5000, -1.1056, 1.5279, -2.8944, 10.4721, -8.5000; - // multiply by -1 since the cheb points have a phase shift wrt Trefethen + // multiply by -1 since the chebyshev points have a phase shift wrt Trefethen // This was verified with chebfun expected = -expected; @@ -252,7 +274,7 @@ TEST(Chebyshev2, DerivativeWeights2) { Weights dWeights2 = Chebyshev2::DerivativeWeights(N, x2, a, b); EXPECT_DOUBLES_EQUAL(fprime(x2), dWeights2 * fvals, 1e-8); - // test if derivative calculation and cheb point is correct + // test if derivative calculation and Chebyshev point is correct double x3 = Chebyshev2::Point(N, 3, a, b); Weights dWeights3 = Chebyshev2::DerivativeWeights(N, x3, a, b); EXPECT_DOUBLES_EQUAL(fprime(x3), dWeights3 * fvals, 1e-8); diff --git a/gtsam/config.h.in b/gtsam/config.h.in index 9f7106187..d47329a62 100644 --- a/gtsam/config.h.in +++ b/gtsam/config.h.in @@ -70,7 +70,7 @@ #cmakedefine GTSAM_THROW_CHEIRALITY_EXCEPTION // Make sure dependent projects that want it can see deprecated functions -#cmakedefine GTSAM_ALLOW_DEPRECATED_SINCE_V41 +#cmakedefine GTSAM_ALLOW_DEPRECATED_SINCE_V42 // Support Metis-based nested dissection #cmakedefine GTSAM_SUPPORT_NESTED_DISSECTION @@ -80,3 +80,6 @@ // Whether to use the system installed Metis instead of the provided one #cmakedefine GTSAM_USE_SYSTEM_METIS + +// Toggle switch for BetweenFactor jacobian computation +#cmakedefine GTSAM_SLOW_BUT_CORRECT_BETWEENFACTOR diff --git a/gtsam/base/LieScalar.h b/gtsam/discrete/AlgebraicDecisionTree.cpp similarity index 53% rename from gtsam/base/LieScalar.h rename to gtsam/discrete/AlgebraicDecisionTree.cpp index e159ffa87..83ee4051a 100644 --- a/gtsam/base/LieScalar.h +++ b/gtsam/discrete/AlgebraicDecisionTree.cpp @@ -10,17 +10,19 @@ * -------------------------------------------------------------------------- */ /** - * @file LieScalar.h - * @brief External deprecation warning, see deprecated/LieScalar.h for details - * @author Kai Ni + * @file AlgebraicDecisionTree.cpp + * @date Feb 20, 2022 + * @author Mike Sheffler + * @author Duy-Nguyen Ta + * @author Frank Dellaert */ -#pragma once +#include "AlgebraicDecisionTree.h" -#ifdef _MSC_VER -#pragma message("LieScalar.h is deprecated. Please use double/float instead.") -#else - #warning "LieScalar.h is deprecated. Please use double/float instead." -#endif +#include -#include +namespace gtsam { + + template class AlgebraicDecisionTree; + +} // namespace gtsam diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index 9cc55ed6a..9769715a1 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -18,8 +18,13 @@ #pragma once +#include #include +#include +#include +#include +#include namespace gtsam { /** @@ -27,21 +32,28 @@ namespace gtsam { * Just has some nice constructors and some syntactic sugar * TODO: consider eliminating this class altogether? */ - template - class AlgebraicDecisionTree: public DecisionTree { + template + class GTSAM_EXPORT AlgebraicDecisionTree : public DecisionTree { + /** + * @brief Default method used by `labelFormatter` or `valueFormatter` when + * printing. + * + * @param x The value passed to format. + * @return std::string + */ + static std::string DefaultFormatter(const L& x) { + std::stringstream ss; + ss << x; + return ss.str(); + } - public: - - typedef DecisionTree Super; + public: + using Base = DecisionTree; /** The Real ring with addition and multiplication */ struct Ring { - static inline double zero() { - return 0.0; - } - static inline double one() { - return 1.0; - } + static inline double zero() { return 0.0; } + static inline double one() { return 1.0; } static inline double add(const double& a, const double& b) { return a + b; } @@ -54,63 +66,68 @@ namespace gtsam { static inline double div(const double& a, const double& b) { return a / b; } - static inline double id(const double& x) { - return x; - } + static inline double id(const double& x) { return x; } }; - AlgebraicDecisionTree() : - Super(1.0) { - } + AlgebraicDecisionTree() : Base(1.0) {} - AlgebraicDecisionTree(const Super& add) : - Super(add) { - } + // Explicitly non-explicit constructor + AlgebraicDecisionTree(const Base& add) : Base(add) {} /** Create a new leaf function splitting on a variable */ - AlgebraicDecisionTree(const L& label, double y1, double y2) : - Super(label, y1, y2) { - } + AlgebraicDecisionTree(const L& label, double y1, double y2) + : Base(label, y1, y2) {} /** Create a new leaf function splitting on a variable */ - AlgebraicDecisionTree(const typename Super::LabelC& labelC, double y1, double y2) : - Super(labelC, y1, y2) { - } + AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1, + double y2) + : Base(labelC, y1, y2) {} /** Create from keys and vector table */ - AlgebraicDecisionTree // - (const std::vector& labelCs, const std::vector& ys) { - this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(), - ys.end()); + AlgebraicDecisionTree // + (const std::vector& labelCs, + const std::vector& ys) { + this->root_ = + Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); } /** Create from keys and string table */ - AlgebraicDecisionTree // - (const std::vector& labelCs, const std::string& table) { + AlgebraicDecisionTree // + (const std::vector& labelCs, + const std::string& table) { // Convert string to doubles std::vector ys; std::istringstream iss(table); std::copy(std::istream_iterator(iss), - std::istream_iterator(), std::back_inserter(ys)); + std::istream_iterator(), std::back_inserter(ys)); // now call recursive Create - this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(), - ys.end()); + this->root_ = + Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); } /** Create a new function splitting on a variable */ - template - AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) : - Super(nullptr) { + template + AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) + : Base(nullptr) { this->root_ = compose(begin, end, label); } - /** Convert */ - template + /** + * Convert labels from type M to type L. + * + * @param other: The AlgebraicDecisionTree with label type M to convert. + * @param map: Map from label type M to label type L. + */ + template AlgebraicDecisionTree(const AlgebraicDecisionTree& other, - const std::map& map) { - this->root_ = this->template convert(other.root_, map, - Ring::id); + const std::map& map) { + // Functor for label conversion so we can use `convertFrom`. + std::function L_of_M = [&map](const M& label) -> L { + return map.at(label); + }; + std::function op = Ring::id; + this->root_ = DecisionTree::convertFrom(other.root_, L_of_M, op); } /** sum */ @@ -134,12 +151,31 @@ namespace gtsam { } /** sum out variable */ - AlgebraicDecisionTree sum(const typename Super::LabelC& labelC) const { + AlgebraicDecisionTree sum(const typename Base::LabelC& labelC) const { return this->combine(labelC, &Ring::add); } - }; -// AlgebraicDecisionTree + /// print method customized to value type `double`. + void print(const std::string& s, + const typename Base::LabelFormatter& labelFormatter = + &DefaultFormatter) const { + auto valueFormatter = [](const double& v) { + return (boost::format("%4.8g") % v).str(); + }; + Base::print(s, labelFormatter, valueFormatter); + } -} -// namespace gtsam + /// Equality method customized to value type `double`. + bool equals(const AlgebraicDecisionTree& other, double tol = 1e-9) const { + // lambda for comparison of two doubles upto some tolerance. + auto compare = [tol](double a, double b) { + return std::abs(a - b) < tol; + }; + return Base::equals(other, compare); + } + }; + +template +struct traits> + : public Testable> {}; +} // namespace gtsam diff --git a/gtsam/discrete/Assignment.h b/gtsam/discrete/Assignment.h index 3665d6dfa..90e2dbdd8 100644 --- a/gtsam/discrete/Assignment.h +++ b/gtsam/discrete/Assignment.h @@ -19,32 +19,32 @@ #pragma once #include -#include #include - +#include +#include namespace gtsam { - /** - * An assignment from labels to value index (size_t). - * Assigns to each label a value. Implemented as a simple map. - * A discrete factor takes an Assignment and returns a value. - */ - template - class Assignment: public std::map { - public: - void print(const std::string& s = "Assignment: ") const { - std::cout << s << ": "; - for(const typename Assignment::value_type& keyValue: *this) - std::cout << "(" << keyValue.first << ", " << keyValue.second << ")"; - std::cout << std::endl; - } +/** + * An assignment from labels to value index (size_t). + * Assigns to each label a value. Implemented as a simple map. + * A discrete factor takes an Assignment and returns a value. + */ +template +class Assignment : public std::map { + public: + using std::map::operator=; - bool equals(const Assignment& other, double tol = 1e-9) const { - return (*this == other); - } - }; //Assignment + void print(const std::string& s = "Assignment: ") const { + std::cout << s << ": "; + for (const typename Assignment::value_type& keyValue : *this) + std::cout << "(" << keyValue.first << ", " << keyValue.second << ")"; + std::cout << std::endl; + } + bool equals(const Assignment& other, double tol = 1e-9) const { + return (*this == other); + } /** * @brief Get Cartesian product consisting all possible configurations @@ -58,29 +58,28 @@ namespace gtsam { * variables with each having cardinalities 4, we get 4096 possible * configurations!! */ - template - std::vector > cartesianProduct( - const std::vector >& keys) { - std::vector > allPossValues; - Assignment values; + template > + static std::vector CartesianProduct( + const std::vector>& keys) { + std::vector allPossValues; + Derived values; typedef std::pair DiscreteKey; - for(const DiscreteKey& key: keys) - values[key.first] = 0; //Initialize from 0 + for (const DiscreteKey& key : keys) + values[key.first] = 0; // Initialize from 0 while (1) { allPossValues.push_back(values); size_t j = 0; for (j = 0; j < keys.size(); j++) { L idx = keys[j].first; values[idx]++; - if (values[idx] < keys[j].second) - break; - //Wrap condition + if (values[idx] < keys[j].second) break; + // Wrap condition values[idx] = 0; } - if (j == keys.size()) - break; + if (j == keys.size()) break; } return allPossValues; } +}; // Assignment -} // namespace gtsam +} // namespace gtsam diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 439889ebf..99f29b8e5 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -20,79 +20,93 @@ #pragma once #include -#include +#include +#include #include +#include +#include #include #include -#include -using boost::assign::operator+=; +#include #include -#include - -#include #include #include +#include +#include +#include #include +#include +#include + +using boost::assign::operator+=; namespace gtsam { - /*********************************************************************************/ + /****************************************************************************/ // Node - /*********************************************************************************/ + /****************************************************************************/ #ifdef DT_DEBUG_MEMORY template int DecisionTree::Node::nrNodes = 0; #endif - /*********************************************************************************/ + /****************************************************************************/ // Leaf - /*********************************************************************************/ - template - class DecisionTree::Leaf: public DecisionTree::Node { - + /****************************************************************************/ + template + struct DecisionTree::Leaf : public DecisionTree::Node { /** constant stored in this leaf */ Y constant_; - public: + /** The number of assignments contained within this leaf. + * Particularly useful when leaves have been pruned. + */ + size_t nrAssignments_; - /** Constructor from constant */ - Leaf(const Y& constant) : - constant_(constant) {} + /// Constructor from constant + Leaf(const Y& constant, size_t nrAssignments = 1) + : constant_(constant), nrAssignments_(nrAssignments) {} - /** return the constant */ + /// Return the constant const Y& constant() const { return constant_; } + /// Return the number of assignments contained within this leaf. + size_t nrAssignments() const { return nrAssignments_; } + /// Leaf-Leaf equality bool sameLeaf(const Leaf& q) const override { return constant_ == q.constant_; } - /// polymorphic equality: is q is a leaf, could be + /// polymorphic equality: is q a leaf and is it the same as this leaf? bool sameLeaf(const Node& q) const override { return (q.isLeaf() && q.sameLeaf(*this)); } - /** equality up to tolerance */ - bool equals(const Node& q, double tol) const override { - const Leaf* other = dynamic_cast (&q); + /// equality up to tolerance + bool equals(const Node& q, const CompareFunc& compare) const override { + const Leaf* other = dynamic_cast(&q); if (!other) return false; - return std::abs(double(this->constant_ - other->constant_)) < tol; + return compare(this->constant_, other->constant_); } - /** print */ - void print(const std::string& s) const override { - bool showZero = true; - if (showZero || constant_) std::cout << s << " Leaf " << constant_ << std::endl; + /// print + void print(const std::string& s, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter) const override { + std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl; } - /** to graphviz file */ - void dot(std::ostream& os, bool showZero) const override { - if (showZero || constant_) os << "\"" << this->id() << "\" [label=\"" - << boost::format("%4.2g") % constant_ - << "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; // width=0.55, + /** Write graphviz format to stream `os`. */ + void dot(std::ostream& os, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero) const override { + std::string value = valueFormatter(constant_); + if (showZero || value.compare("0")) + os << "\"" << this->id() << "\" [label=\"" << value + << "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; } /** evaluate */ @@ -102,7 +116,14 @@ namespace gtsam { /** apply unary operator */ NodePtr apply(const Unary& op) const override { - NodePtr f(new Leaf(op(constant_))); + NodePtr f(new Leaf(op(constant_), nrAssignments_)); + return f; + } + + /// Apply unary operator with assignment + NodePtr apply(const UnaryAssignment& op, + const Assignment& assignment) const override { + NodePtr f(new Leaf(op(assignment, constant_), nrAssignments_)); return f; } @@ -117,58 +138,68 @@ namespace gtsam { // Applying binary operator to two leaves results in a leaf NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override { - NodePtr h(new Leaf(op(fL.constant_, constant_))); // fL op gL + // fL op gL + NodePtr h(new Leaf(op(fL.constant_, constant_), nrAssignments_)); return h; } // If second argument is a Choice node, call it's apply with leaf as second NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override { - return fC.apply_fC_op_gL(*this, op); // operand order back to normal + return fC.apply_fC_op_gL(*this, op); // operand order back to normal } /** choose a branch, create new memory ! */ NodePtr choose(const L& label, size_t index) const override { - return NodePtr(new Leaf(constant())); + return NodePtr(new Leaf(constant(), nrAssignments())); } bool isLeaf() const override { return true; } + }; // Leaf - }; // Leaf - - /*********************************************************************************/ + /****************************************************************************/ // Choice - /*********************************************************************************/ + /****************************************************************************/ template - class DecisionTree::Choice: public DecisionTree::Node { - + struct DecisionTree::Choice: public DecisionTree::Node { /** the label of the variable on which we split */ L label_; /** The children of this Choice node. */ std::vector branches_; - private: - /** incremental allSame */ + private: + /** + * Incremental allSame. + * Records if all the branches are the same leaf. + */ size_t allSame_; - typedef boost::shared_ptr ChoicePtr; - - public: + using ChoicePtr = boost::shared_ptr; + public: ~Choice() override { #ifdef DT_DEBUG_MEMORY - std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id() << std::std::endl; + std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id() + << std::std::endl; #endif } - /** If all branches of a choice node f are the same, just return a branch */ + /// If all branches of a choice node f are the same, just return a branch. static NodePtr Unique(const ChoicePtr& f) { -#ifndef DT_NO_PRUNING +#ifndef GTSAM_DT_NO_PRUNING if (f->allSame_) { assert(f->branches().size() > 0); NodePtr f0 = f->branches_[0]; - assert(f0->isLeaf()); - NodePtr newLeaf(new Leaf(boost::dynamic_pointer_cast(f0)->constant())); + + size_t nrAssignments = 0; + for(auto branch: f->branches()) { + assert(branch->isLeaf()); + nrAssignments += + boost::dynamic_pointer_cast(branch)->nrAssignments(); + } + NodePtr newLeaf( + new Leaf(boost::dynamic_pointer_cast(f0)->constant(), + nrAssignments)); return newLeaf; } else #endif @@ -177,18 +208,15 @@ namespace gtsam { bool isLeaf() const override { return false; } - /** Constructor, given choice label and mandatory expected branch count */ + /// Constructor, given choice label and mandatory expected branch count. Choice(const L& label, size_t count) : label_(label), allSame_(true) { branches_.reserve(count); } - /** - * Construct from applying binary op to two Choice nodes - */ + /// Construct from applying binary op to two Choice nodes. Choice(const Choice& f, const Choice& g, const Binary& op) : allSame_(true) { - // Choose what to do based on label if (f.label() > g.label()) { // f higher than g @@ -214,6 +242,7 @@ namespace gtsam { } } + /// Return the label of this choice node. const L& label() const { return label_; } @@ -235,33 +264,39 @@ namespace gtsam { branches_.push_back(node); } - /** print (as a tree) */ - void print(const std::string& s) const override { + /// print (as a tree). + void print(const std::string& s, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter) const override { std::cout << s << " Choice("; - // std::cout << this << ","; - std::cout << label_ << ") " << std::endl; + std::cout << labelFormatter(label_) << ") " << std::endl; for (size_t i = 0; i < branches_.size(); i++) - branches_[i]->print((boost::format("%s %d") % s % i).str()); + branches_[i]->print((boost::format("%s %d") % s % i).str(), + labelFormatter, valueFormatter); } /** output to graphviz (as a a graph) */ - void dot(std::ostream& os, bool showZero) const override { + void dot(std::ostream& os, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero) const override { os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_ << "\"]\n"; - for (size_t i = 0; i < branches_.size(); i++) { - NodePtr branch = branches_[i]; + size_t B = branches_.size(); + for (size_t i = 0; i < B; i++) { + const NodePtr& branch = branches_[i]; // Check if zero if (!showZero) { - const Leaf* leaf = dynamic_cast (branch.get()); - if (leaf && !leaf->constant()) continue; + const Leaf* leaf = dynamic_cast(branch.get()); + if (leaf && valueFormatter(leaf->constant()).compare("0")) continue; } os << "\"" << this->id() << "\" -> \"" << branch->id() << "\""; - if (i == 0) os << " [style=dashed]"; - if (i > 1) os << " [style=bold]"; + if (B == 2) { + if (i == 0) os << " [style=dashed]"; + if (i > 1) os << " [style=bold]"; + } os << std::endl; - branch->dot(os, showZero); + branch->dot(os, labelFormatter, valueFormatter, showZero); } } @@ -275,19 +310,20 @@ namespace gtsam { return (q.isLeaf() && q.sameLeaf(*this)); } - /** equality up to tolerance */ - bool equals(const Node& q, double tol) const override { - const Choice* other = dynamic_cast (&q); + /// equality + bool equals(const Node& q, const CompareFunc& compare) const override { + const Choice* other = dynamic_cast(&q); if (!other) return false; if (this->label_ != other->label_) return false; if (branches_.size() != other->branches_.size()) return false; // we don't care about shared pointers being equal here for (size_t i = 0; i < branches_.size(); i++) - if (!(branches_[i]->equals(*(other->branches_[i]), tol))) return false; + if (!(branches_[i]->equals(*(other->branches_[i]), compare))) + return false; return true; } - /** evaluate */ + /// evaluate const Y& operator()(const Assignment& x) const override { #ifndef NDEBUG typename Assignment::const_iterator it = x.find(label_); @@ -302,20 +338,54 @@ namespace gtsam { return (*child)(x); } - /** - * Construct from applying unary op to a Choice node - */ + /// Construct from applying unary op to a Choice node. Choice(const L& label, const Choice& f, const Unary& op) : label_(label), allSame_(true) { - - branches_.reserve(f.branches_.size()); // reserve space - for (const NodePtr& branch: f.branches_) - push_back(branch->apply(op)); + branches_.reserve(f.branches_.size()); // reserve space + for (const NodePtr& branch : f.branches_) { + push_back(branch->apply(op)); + } } - /** apply unary operator */ + /** + * @brief Constructor which accepts a UnaryAssignment op and the + * corresponding assignment. + * + * @param label The label for this node. + * @param f The original choice node to apply the op on. + * @param op Function to apply on the choice node. Takes Assignment and + * value as arguments. + * @param assignment The Assignment that will go to op. + */ + Choice(const L& label, const Choice& f, const UnaryAssignment& op, + const Assignment& assignment) + : label_(label), allSame_(true) { + branches_.reserve(f.branches_.size()); // reserve space + + Assignment assignment_ = assignment; + + for (size_t i = 0; i < f.branches_.size(); i++) { + assignment_[label_] = i; // Set assignment for label to i + + const NodePtr branch = f.branches_[i]; + push_back(branch->apply(op, assignment_)); + + // Remove the assignment so we are backtracking + auto assignment_it = assignment_.find(label_); + assignment_.erase(assignment_it); + } + } + + /// apply unary operator. NodePtr apply(const Unary& op) const override { - boost::shared_ptr r(new Choice(label_, *this, op)); + auto r = boost::make_shared(label_, *this, op); + return Unique(r); + } + + /// Apply unary operator with assignment + NodePtr apply(const UnaryAssignment& op, + const Assignment& assignment) const override { + auto r = boost::make_shared(label_, *this, op, assignment); return Unique(r); } @@ -330,44 +400,42 @@ namespace gtsam { // If second argument of binary op is Leaf node, recurse on branches NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override { - boost::shared_ptr h(new Choice(label(), nrChoices())); - for(NodePtr branch: branches_) - h->push_back(fL.apply_f_op_g(*branch, op)); + auto h = boost::make_shared(label(), nrChoices()); + for (auto&& branch : branches_) + h->push_back(fL.apply_f_op_g(*branch, op)); return Unique(h); } // If second argument of binary op is Choice, call constructor NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override { - boost::shared_ptr h(new Choice(fC, *this, op)); + auto h = boost::make_shared(fC, *this, op); return Unique(h); } // If second argument of binary op is Leaf template NodePtr apply_fC_op_gL(const Leaf& gL, OP op) const { - boost::shared_ptr h(new Choice(label(), nrChoices())); - for(const NodePtr& branch: branches_) - h->push_back(branch->apply_f_op_g(gL, op)); + auto h = boost::make_shared(label(), nrChoices()); + for (auto&& branch : branches_) + h->push_back(branch->apply_f_op_g(gL, op)); return Unique(h); } /** choose a branch, recursively */ NodePtr choose(const L& label, size_t index) const override { - if (label_ == label) - return branches_[index]; // choose branch + if (label_ == label) return branches_[index]; // choose branch // second case, not label of interest, just recurse - boost::shared_ptr r(new Choice(label_, branches_.size())); - for(const NodePtr& branch: branches_) - r->push_back(branch->choose(label, index)); + auto r = boost::make_shared(label_, branches_.size()); + for (auto&& branch : branches_) + r->push_back(branch->choose(label, index)); return Unique(r); } + }; // Choice - }; // Choice - - /*********************************************************************************/ + /****************************************************************************/ // DecisionTree - /*********************************************************************************/ + /****************************************************************************/ template DecisionTree::DecisionTree() { } @@ -377,37 +445,36 @@ namespace gtsam { root_(root) { } - /*********************************************************************************/ + /****************************************************************************/ template DecisionTree::DecisionTree(const Y& y) { root_ = NodePtr(new Leaf(y)); } - /*********************************************************************************/ - template - DecisionTree::DecisionTree(// - const L& label, const Y& y1, const Y& y2) { - boost::shared_ptr a(new Choice(label, 2)); + /****************************************************************************/ + template + DecisionTree::DecisionTree(const L& label, const Y& y1, const Y& y2) { + auto a = boost::make_shared(label, 2); NodePtr l1(new Leaf(y1)), l2(new Leaf(y2)); a->push_back(l1); a->push_back(l2); root_ = Choice::Unique(a); } - /*********************************************************************************/ - template - DecisionTree::DecisionTree(// - const LabelC& labelC, const Y& y1, const Y& y2) { + /****************************************************************************/ + template + DecisionTree::DecisionTree(const LabelC& labelC, const Y& y1, + const Y& y2) { if (labelC.second != 2) throw std::invalid_argument( "DecisionTree: binary constructor called with non-binary label"); - boost::shared_ptr a(new Choice(labelC.first, 2)); + auto a = boost::make_shared(labelC.first, 2); NodePtr l1(new Leaf(y1)), l2(new Leaf(y2)); a->push_back(l1); a->push_back(l2); root_ = Choice::Unique(a); } - /*********************************************************************************/ + /****************************************************************************/ template DecisionTree::DecisionTree(const std::vector& labelCs, const std::vector& ys) { @@ -415,29 +482,28 @@ namespace gtsam { root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); } - /*********************************************************************************/ + /****************************************************************************/ template DecisionTree::DecisionTree(const std::vector& labelCs, const std::string& table) { - // Convert std::string to values of type Y std::vector ys; std::istringstream iss(table); copy(std::istream_iterator(iss), std::istream_iterator(), - back_inserter(ys)); + back_inserter(ys)); // now call recursive Create root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); } - /*********************************************************************************/ + /****************************************************************************/ template template DecisionTree::DecisionTree( Iterator begin, Iterator end, const L& label) { root_ = compose(begin, end, label); } - /*********************************************************************************/ + /****************************************************************************/ template DecisionTree::DecisionTree(const L& label, const DecisionTree& f0, const DecisionTree& f1) { @@ -446,24 +512,35 @@ namespace gtsam { root_ = compose(functions.begin(), functions.end(), label); } - /*********************************************************************************/ - template - template - DecisionTree::DecisionTree(const DecisionTree& other, - const std::map& map, std::function op) { - root_ = convert(other.root_, map, op); + /****************************************************************************/ + template + template + DecisionTree::DecisionTree(const DecisionTree& other, + Func Y_of_X) { + // Define functor for identity mapping of node label. + auto L_of_L = [](const L& label) { return label; }; + root_ = convertFrom(other.root_, L_of_L, Y_of_X); } - /*********************************************************************************/ - // Called by two constructors above. - // Takes a label and a corresponding range of decision trees, and creates a new - // decision tree. However, the order of the labels needs to be respected, so we - // cannot just create a root Choice node on the label: if the label is not the - // highest label, we need to do a complicated and expensive recursive call. - template template - typename DecisionTree::NodePtr DecisionTree::compose(Iterator begin, - Iterator end, const L& label) const { + /****************************************************************************/ + template + template + DecisionTree::DecisionTree(const DecisionTree& other, + const std::map& map, Func Y_of_X) { + auto L_of_M = [&map](const M& label) -> L { return map.at(label); }; + root_ = convertFrom(other.root_, L_of_M, Y_of_X); + } + /****************************************************************************/ + // Called by two constructors above. + // Takes a label and a corresponding range of decision trees, and creates a + // new decision tree. However, the order of the labels needs to be respected, + // so we cannot just create a root Choice node on the label: if the label is + // not the highest label, we need a complicated/ expensive recursive call. + template + template + typename DecisionTree::NodePtr DecisionTree::compose( + Iterator begin, Iterator end, const L& label) const { // find highest label among branches boost::optional highestLabel; size_t nrChoices = 0; @@ -480,13 +557,14 @@ namespace gtsam { // if label is already in correct order, just put together a choice on label if (!nrChoices || !highestLabel || label > *highestLabel) { - boost::shared_ptr choiceOnLabel(new Choice(label, end - begin)); + auto choiceOnLabel = boost::make_shared(label, end - begin); for (Iterator it = begin; it != end; it++) choiceOnLabel->push_back(it->root_); return Choice::Unique(choiceOnLabel); } else { // Set up a new choice on the highest label - boost::shared_ptr choiceOnHighestLabel(new Choice(*highestLabel, nrChoices)); + auto choiceOnHighestLabel = + boost::make_shared(*highestLabel, nrChoices); // now, for all possible values of highestLabel for (size_t index = 0; index < nrChoices; index++) { // make a new set of functions for composing by iterating over the given @@ -505,7 +583,7 @@ namespace gtsam { } } - /*********************************************************************************/ + /****************************************************************************/ // "create" is a bit of a complicated thing, but very useful. // It takes a range of labels and a corresponding range of values, // and creates a decision tree, as follows: @@ -530,7 +608,6 @@ namespace gtsam { template typename DecisionTree::NodePtr DecisionTree::create( It begin, It end, ValueIt beginY, ValueIt endY) const { - // get crucial counts size_t nrChoices = begin->second; size_t size = endY - beginY; @@ -542,10 +619,14 @@ namespace gtsam { // Create a simple choice node with values as leaves. if (size != nrChoices) { std::cout << "Trying to create DD on " << begin->first << std::endl; - std::cout << boost::format("DecisionTree::create: expected %d values but got %d instead") % nrChoices % size << std::endl; + std::cout << boost::format( + "DecisionTree::create: expected %d values but got %d " + "instead") % + nrChoices % size + << std::endl; throw std::invalid_argument("DecisionTree::create invalid argument"); } - boost::shared_ptr choice(new Choice(begin->first, endY - beginY)); + auto choice = boost::make_shared(begin->first, endY - beginY); for (ValueIt y = beginY; y != endY; y++) choice->push_back(NodePtr(new Leaf(*y))); return Choice::Unique(choice); @@ -558,56 +639,219 @@ namespace gtsam { size_t split = size / nrChoices; for (size_t i = 0; i < nrChoices; i++, beginY += split) { NodePtr f = create(labelC, end, beginY, beginY + split); - functions += DecisionTree(f); + functions.emplace_back(f); } return compose(functions.begin(), functions.end(), begin->first); } - /*********************************************************************************/ - template - template - typename DecisionTree::NodePtr DecisionTree::convert( - const typename DecisionTree::NodePtr& f, const std::map& map, - std::function op) { + /****************************************************************************/ + template + template + typename DecisionTree::NodePtr DecisionTree::convertFrom( + const typename DecisionTree::NodePtr& f, + std::function L_of_M, + std::function Y_of_X) const { + using LY = DecisionTree; - typedef DecisionTree MX; - typedef typename MX::Leaf MXLeaf; - typedef typename MX::Choice MXChoice; - typedef typename MX::NodePtr MXNodePtr; - typedef DecisionTree LY; - - // ugliness below because apparently we can't have templated virtual functions - // If leaf, apply unary conversion "op" and create a unique leaf - const MXLeaf* leaf = dynamic_cast (f.get()); - if (leaf) return NodePtr(new Leaf(op(leaf->constant()))); + // Ugliness below because apparently we can't have templated virtual + // functions. + // If leaf, apply unary conversion "op" and create a unique leaf. + using MXLeaf = typename DecisionTree::Leaf; + if (auto leaf = boost::dynamic_pointer_cast(f)) { + return NodePtr(new Leaf(Y_of_X(leaf->constant()), leaf->nrAssignments())); + } // Check if Choice - boost::shared_ptr choice = boost::dynamic_pointer_cast (f); + using MXChoice = typename DecisionTree::Choice; + auto choice = boost::dynamic_pointer_cast(f); if (!choice) throw std::invalid_argument( - "DecisionTree::Convert: Invalid NodePtr"); + "DecisionTree::convertFrom: Invalid NodePtr"); // get new label - M oldLabel = choice->label(); - L newLabel = map.at(oldLabel); + const M oldLabel = choice->label(); + const L newLabel = L_of_M(oldLabel); // put together via Shannon expansion otherwise not sorted. std::vector functions; - for(const MXNodePtr& branch: choice->branches()) { - LY converted(convert(branch, map, op)); - functions += converted; + for (auto&& branch : choice->branches()) { + functions.emplace_back(convertFrom(branch, L_of_M, Y_of_X)); } return LY::compose(functions.begin(), functions.end(), newLabel); } - /*********************************************************************************/ - template - bool DecisionTree::equals(const DecisionTree& other, double tol) const { - return root_->equals(*other.root_, tol); + /****************************************************************************/ + /** + * Functor performing depth-first visit to each leaf with the leaf value as + * the argument. + * + * NOTE: We differentiate between leaves and assignments. Concretely, a 3 + * binary variable tree will have 2^3=8 assignments, but based on pruning, it + * can have less than 8 leaves. For example, if a tree has all assignment + * values as 1, then pruning will cause the tree to have only 1 leaf yet 8 + * assignments. + */ + template + struct Visit { + using F = std::function; + explicit Visit(F f) : f(f) {} ///< Construct from folding function. + F f; ///< folding function object. + + /// Do a depth-first visit on the tree rooted at node. + void operator()(const typename DecisionTree::NodePtr& node) const { + using Leaf = typename DecisionTree::Leaf; + if (auto leaf = boost::dynamic_pointer_cast(node)) + return f(leaf->constant()); + + using Choice = typename DecisionTree::Choice; + auto choice = boost::dynamic_pointer_cast(node); + if (!choice) + throw std::invalid_argument("DecisionTree::Visit: Invalid NodePtr"); + for (auto&& branch : choice->branches()) (*this)(branch); // recurse! + } + }; + + template + template + void DecisionTree::visit(Func f) const { + Visit visit(f); + visit(root_); } - template - void DecisionTree::print(const std::string& s) const { - root_->print(s); + /****************************************************************************/ + /** + * Functor performing depth-first visit to each leaf with the Leaf object + * passed as an argument. + * + * NOTE: We differentiate between leaves and assignments. Concretely, a 3 + * binary variable tree will have 2^3=8 assignments, but based on pruning, it + * can have <8 leaves. For example, if a tree has all assignment values as 1, + * then pruning will cause the tree to have only 1 leaf yet 8 assignments. + */ + template + struct VisitLeaf { + using F = std::function::Leaf&)>; + explicit VisitLeaf(F f) : f(f) {} ///< Construct from folding function. + F f; ///< folding function object. + + /// Do a depth-first visit on the tree rooted at node. + void operator()(const typename DecisionTree::NodePtr& node) const { + using Leaf = typename DecisionTree::Leaf; + if (auto leaf = boost::dynamic_pointer_cast(node)) + return f(*leaf); + + using Choice = typename DecisionTree::Choice; + auto choice = boost::dynamic_pointer_cast(node); + if (!choice) + throw std::invalid_argument("DecisionTree::VisitLeaf: Invalid NodePtr"); + for (auto&& branch : choice->branches()) (*this)(branch); // recurse! + } + }; + + template + template + void DecisionTree::visitLeaf(Func f) const { + VisitLeaf visit(f); + visit(root_); + } + + /****************************************************************************/ + /** + * Functor performing depth-first visit to each leaf with the leaf's + * `Assignment` and value passed as arguments. + * + * NOTE: Follows the same pruning semantics as `visit`. + */ + template + struct VisitWith { + using F = std::function&, const Y&)>; + explicit VisitWith(F f) : f(f) {} ///< Construct from folding function. + Assignment assignment; ///< Assignment, mutating through recursion. + F f; ///< folding function object. + + /// Do a depth-first visit on the tree rooted at node. + void operator()(const typename DecisionTree::NodePtr& node) { + using Leaf = typename DecisionTree::Leaf; + if (auto leaf = boost::dynamic_pointer_cast(node)) + return f(assignment, leaf->constant()); + + using Choice = typename DecisionTree::Choice; + auto choice = boost::dynamic_pointer_cast(node); + if (!choice) + throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr"); + for (size_t i = 0; i < choice->nrChoices(); i++) { + assignment[choice->label()] = i; // Set assignment for label to i + + (*this)(choice->branches()[i]); // recurse! + + // Remove the choice so we are backtracking + auto choice_it = assignment.find(choice->label()); + assignment.erase(choice_it); + } + } + }; + + template + template + void DecisionTree::visitWith(Func f) const { + VisitWith visit(f); + visit(root_); + } + + /****************************************************************************/ + template + size_t DecisionTree::nrLeaves() const { + size_t total = 0; + visit([&total](const Y& node) { total += 1; }); + return total; + } + + /****************************************************************************/ + // fold is just done with a visit + template + template + X DecisionTree::fold(Func f, X x0) const { + visit([&](const Y& y) { x0 = f(y, x0); }); + return x0; + } + + /****************************************************************************/ + /** + * Get (partial) labels by performing a visit. + * + * This method performs a depth-first search to go to every leaf and records + * the keys assignment which leads to that leaf. Since the tree can be pruned, + * there might be a leaf at a lower depth which results in a partial + * assignment (i.e. not all keys are specified). + * + * E.g. given a tree with 3 keys, there may be a branch where the 3rd key has + * the same values for all the leaves. This leads to the branch being pruned + * so we get a leaf which is arrived at by just the first 2 keys and their + * assignments. + */ + template + std::set DecisionTree::labels() const { + std::set unique; + auto f = [&](const Assignment& assignment, const Y&) { + for (auto&& kv : assignment) { + unique.insert(kv.first); + } + }; + visitWith(f); + return unique; + } + +/****************************************************************************/ + template + bool DecisionTree::equals(const DecisionTree& other, + const CompareFunc& compare) const { + return root_->equals(*other.root_, compare); + } + + template + void DecisionTree::print(const std::string& s, + const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter) const { + root_->print(s, labelFormatter, valueFormatter); } template @@ -622,13 +866,36 @@ namespace gtsam { template DecisionTree DecisionTree::apply(const Unary& op) const { + // It is unclear what should happen if tree is empty: + if (empty()) { + throw std::runtime_error( + "DecisionTree::apply(unary op) undefined for empty tree."); + } return DecisionTree(root_->apply(op)); } - /*********************************************************************************/ + /// Apply unary operator with assignment + template + DecisionTree DecisionTree::apply( + const UnaryAssignment& op) const { + // It is unclear what should happen if tree is empty: + if (empty()) { + throw std::runtime_error( + "DecisionTree::apply(unary op) undefined for empty tree."); + } + Assignment assignment; + return DecisionTree(root_->apply(op, assignment)); + } + + /****************************************************************************/ template DecisionTree DecisionTree::apply(const DecisionTree& g, const Binary& op) const { + // It is unclear what should happen if either tree is empty: + if (empty() || g.empty()) { + throw std::runtime_error( + "DecisionTree::apply(binary op) undefined for empty trees."); + } // apply the operaton on the root of both diagrams NodePtr h = root_->apply_f_op_g(*g.root_, op); // create a new class with the resulting root "h" @@ -636,7 +903,7 @@ namespace gtsam { return result; } - /*********************************************************************************/ + /****************************************************************************/ // The way this works: // We have an ADT, picture it as a tree. // At a certain depth, we have a branch on "label". @@ -656,25 +923,40 @@ namespace gtsam { return result; } - /*********************************************************************************/ - template - void DecisionTree::dot(std::ostream& os, bool showZero) const { + /****************************************************************************/ + template + void DecisionTree::dot(std::ostream& os, + const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero) const { os << "digraph G {\n"; - root_->dot(os, showZero); + root_->dot(os, labelFormatter, valueFormatter, showZero); os << " [ordering=out]}" << std::endl; } - template - void DecisionTree::dot(const std::string& name, bool showZero) const { + template + void DecisionTree::dot(const std::string& name, + const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero) const { std::ofstream os((name + ".dot").c_str()); - dot(os, showZero); - int result = system( - ("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str()); - if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed"); -} + dot(os, labelFormatter, valueFormatter, showZero); + int result = + system(("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null") + .c_str()); + if (result == -1) + throw std::runtime_error("DecisionTree::dot system call failed"); + } -/*********************************************************************************/ - -} // namespace gtsam + template + std::string DecisionTree::dot(const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero) const { + std::stringstream ss; + dot(ss, labelFormatter, valueFormatter, showZero); + return ss.str(); + } +/******************************************************************************/ + } // namespace gtsam diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 0ee0b8be0..1f45d320b 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -19,12 +19,17 @@ #pragma once +#include #include #include #include #include #include +#include +#include +#include +#include #include namespace gtsam { @@ -36,24 +41,32 @@ namespace gtsam { */ template class DecisionTree { + protected: + /// Default method for comparison of two objects of type Y. + static bool DefaultCompare(const Y& a, const Y& b) { + return a == b; + } - public: + public: + using LabelFormatter = std::function; + using ValueFormatter = std::function; + using CompareFunc = std::function; /** Handy typedefs for unary and binary function types */ - typedef std::function Unary; - typedef std::function Binary; + using Unary = std::function; + using UnaryAssignment = std::function&, const Y&)>; + using Binary = std::function; /** A label annotated with cardinality */ - typedef std::pair LabelC; + using LabelC = std::pair; /** DTs consist of Leaf and Choice nodes, both subclasses of Node */ - class Leaf; - class Choice; + struct Leaf; + struct Choice; /** ------------------------ Node base class --------------------------- */ - class Node { - public: - typedef boost::shared_ptr Ptr; + struct Node { + using Ptr = boost::shared_ptr; #ifdef DT_DEBUG_MEMORY static int nrNodes; @@ -62,14 +75,16 @@ namespace gtsam { // Constructor Node() { #ifdef DT_DEBUG_MEMORY - std::cout << ++nrNodes << " constructed " << id() << std::endl; std::cout.flush(); + std::cout << ++nrNodes << " constructed " << id() << std::endl; + std::cout.flush(); #endif } // Destructor virtual ~Node() { #ifdef DT_DEBUG_MEMORY - std::cout << --nrNodes << " destructed " << id() << std::endl; std::cout.flush(); + std::cout << --nrNodes << " destructed " << id() << std::endl; + std::cout.flush(); #endif } @@ -77,13 +92,20 @@ namespace gtsam { const void* id() const { return this; } // everything else is virtual, no documentation here as internal - virtual void print(const std::string& s = "") const = 0; - virtual void dot(std::ostream& os, bool showZero) const = 0; + virtual void print(const std::string& s, + const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter) const = 0; + virtual void dot(std::ostream& os, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero) const = 0; virtual bool sameLeaf(const Leaf& q) const = 0; virtual bool sameLeaf(const Node& q) const = 0; - virtual bool equals(const Node& other, double tol = 1e-9) const = 0; + virtual bool equals(const Node& other, const CompareFunc& compare = + &DefaultCompare) const = 0; virtual const Y& operator()(const Assignment& x) const = 0; virtual Ptr apply(const Unary& op) const = 0; + virtual Ptr apply(const UnaryAssignment& op, + const Assignment& assignment) const = 0; virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0; virtual Ptr apply_g_op_fL(const Leaf&, const Binary&) const = 0; virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0; @@ -92,37 +114,46 @@ namespace gtsam { }; /** ------------------------ Node base class --------------------------- */ - public: - + public: /** A function is a shared pointer to the root of a DT */ - typedef typename Node::Ptr NodePtr; + using NodePtr = typename Node::Ptr; - /* a DecisionTree just contains the root */ + /// A DecisionTree just contains the root. TODO(dellaert): make protected. NodePtr root_; - protected: - - /** Internal recursive function to create from keys, cardinalities, and Y values */ + protected: + /** Internal recursive function to create from keys, cardinalities, + * and Y values + */ template NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const; - /** Convert to a different type */ - template NodePtr - convert(const typename DecisionTree::NodePtr& f, const std::map& map, std::function op); - - /** Default constructor */ - DecisionTree(); - - public: + /** + * @brief Convert from a DecisionTree to DecisionTree. + * + * @tparam M The previous label type. + * @tparam X The previous value type. + * @param f The node pointer to the root of the previous DecisionTree. + * @param L_of_M Functor to convert from label type M to type L. + * @param Y_of_X Functor to convert from value type X to type Y. + * @return NodePtr + */ + template + NodePtr convertFrom(const typename DecisionTree::NodePtr& f, + std::function L_of_M, + std::function Y_of_X) const; + public: /// @name Standard Constructors /// @{ - /** Create a constant */ - DecisionTree(const Y& y); + /** Default constructor (for serialization) */ + DecisionTree(); - /** Create a new leaf function splitting on a variable */ + /** Create a constant */ + explicit DecisionTree(const Y& y); + + /// Create tree with 2 assignments `y1`, `y2`, splitting on variable `label` DecisionTree(const L& label, const Y& y1, const Y& y2); /** Allow Label+Cardinality for convenience */ @@ -139,31 +170,60 @@ namespace gtsam { DecisionTree(Iterator begin, Iterator end, const L& label); /** Create DecisionTree from two others */ - DecisionTree(const L& label, // - const DecisionTree& f0, const DecisionTree& f1); + DecisionTree(const L& label, const DecisionTree& f0, + const DecisionTree& f1); - /** Convert from a different type */ - template - DecisionTree(const DecisionTree& other, - const std::map& map, std::function op); + /** + * @brief Convert from a different value type. + * + * @tparam X The previous value type. + * @param other The DecisionTree to convert from. + * @param Y_of_X Functor to convert from value type X to type Y. + */ + template + DecisionTree(const DecisionTree& other, Func Y_of_X); + + /** + * @brief Convert from a different value type X to value type Y, also transate + * labels via map from type M to L. + * + * @tparam M Previous label type. + * @tparam X Previous value type. + * @param other The decision tree to convert. + * @param L_of_M Map from label type M to type L. + * @param Y_of_X Functor to convert from type X to type Y. + */ + template + DecisionTree(const DecisionTree& other, const std::map& map, + Func Y_of_X); /// @} /// @name Testable /// @{ - /** GTSAM-style print */ - void print(const std::string& s = "DecisionTree") const; + /** + * @brief GTSAM-style print + * + * @param s Prefix string. + * @param labelFormatter Functor to format the node label. + * @param valueFormatter Functor to format the node value. + */ + void print(const std::string& s, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter) const; // Testable - bool equals(const DecisionTree& other, double tol = 1e-9) const; + bool equals(const DecisionTree& other, + const CompareFunc& compare = &DefaultCompare) const; /// @} /// @name Standard Interface /// @{ - /** Make virtual */ - virtual ~DecisionTree() { - } + /// Make virtual + virtual ~DecisionTree() {} + + /// Check if tree is empty. + bool empty() const { return !root_; } /** equality */ bool operator==(const DecisionTree& q) const; @@ -171,9 +231,94 @@ namespace gtsam { /** evaluate */ const Y& operator()(const Assignment& x) const; + /** + * @brief Visit all leaves in depth-first fashion. + * + * @param f (side-effect) Function taking a value. + * + * @note Due to pruning, the number of leaves may not be the same as the + * number of assignments. E.g. if we have a tree on 2 binary variables with + * all values being 1, then there are 2^2=4 assignments, but only 1 leaf. + * + * Example: + * int sum = 0; + * auto visitor = [&](int y) { sum += y; }; + * tree.visitWith(visitor); + */ + template + void visit(Func f) const; + + /** + * @brief Visit all leaves in depth-first fashion. + * + * @param f (side-effect) Function taking the leaf node pointer. + * + * @note Due to pruning, the number of leaves may not be the same as the + * number of assignments. E.g. if we have a tree on 2 binary variables with + * all values being 1, then there are 2^2=4 assignments, but only 1 leaf. + * + * Example: + * int sum = 0; + * auto visitor = [&](int y) { sum += y; }; + * tree.visitWith(visitor); + */ + template + void visitLeaf(Func f) const; + + /** + * @brief Visit all leaves in depth-first fashion. + * + * @param f (side-effect) Function taking an assignment and a value. + * + * @note Due to pruning, the number of leaves may not be the same as the + * number of assignments. E.g. if we have a tree on 2 binary variables with + * all values being 1, then there are 2^2=4 assignments, but only 1 leaf. + * + * Example: + * int sum = 0; + * auto visitor = [&](const Assignment& assignment, int y) { sum += y; }; + * tree.visitWith(visitor); + */ + template + void visitWith(Func f) const; + + /// Return the number of leaves in the tree. + size_t nrLeaves() const; + + /** + * @brief Fold a binary function over the tree, returning accumulator. + * + * @tparam X type for accumulator. + * @param f binary function: Y * X -> X returning an updated accumulator. + * @param x0 initial value for accumulator. + * @return X final value for accumulator. + * + * @note X is always passed by value. + * @note Due to pruning, leaves might not exhaust choices. + * + * Example: + * auto add = [](const double& y, double x) { return y + x; }; + * double sum = tree.fold(add, 0.0); + */ + template + X fold(Func f, X x0) const; + + /** Retrieve all unique labels as a set. */ + std::set labels() const; + /** apply Unary operation "op" to f */ DecisionTree apply(const Unary& op) const; + /** + * @brief Apply Unary operation "op" to f while also providing the + * corresponding assignment. + * + * @param op Function which takes Assignment and Y as input and returns + * object of type Y. + * @return DecisionTree + */ + DecisionTree apply(const UnaryAssignment& op) const; + /** apply binary operation "op" to f and g */ DecisionTree apply(const DecisionTree& g, const Binary& op) const; @@ -185,7 +330,8 @@ namespace gtsam { } /** combine subtrees on key with binary operation "op" */ - DecisionTree combine(const L& label, size_t cardinality, const Binary& op) const; + DecisionTree combine(const L& label, size_t cardinality, + const Binary& op) const; /** combine with LabelC for convenience */ DecisionTree combine(const LabelC& labelC, const Binary& op) const { @@ -193,38 +339,68 @@ namespace gtsam { } /** output to graphviz format, stream version */ - void dot(std::ostream& os, bool showZero = true) const; + void dot(std::ostream& os, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, bool showZero = true) const; /** output to graphviz format, open a file */ - void dot(const std::string& name, bool showZero = true) const; + void dot(const std::string& name, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, bool showZero = true) const; + + /** output to graphviz format string */ + std::string dot(const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero = true) const; /// @name Advanced Interface /// @{ // internal use only - DecisionTree(const NodePtr& root); + explicit DecisionTree(const NodePtr& root); // internal use only template NodePtr compose(Iterator begin, Iterator end, const L& label) const; /// @} - - }; // DecisionTree + }; // DecisionTree /** free versions of apply */ - template + /// Apply unary operator `op` to DecisionTree `f`. + template DecisionTree apply(const DecisionTree& f, const typename DecisionTree::Unary& op) { return f.apply(op); } - template + /// Apply unary operator `op` with Assignment to DecisionTree `f`. + template + DecisionTree apply(const DecisionTree& f, + const typename DecisionTree::UnaryAssignment& op) { + return f.apply(op); + } + + /// Apply binary operator `op` to DecisionTree `f`. + template DecisionTree apply(const DecisionTree& f, const DecisionTree& g, const typename DecisionTree::Binary& op) { return f.apply(g, op); } -} // namespace gtsam + /** + * @brief unzip a DecisionTree with `std::pair` values. + * + * @param input the DecisionTree with `(T1,T2)` values. + * @return a pair of DecisionTree on T1 and T2, respectively. + */ + template + std::pair, DecisionTree > unzip( + const DecisionTree >& input) { + return std::make_pair( + DecisionTree(input, [](std::pair i) { return i.first; }), + DecisionTree(input, + [](std::pair i) { return i.second; })); + } + +} // namespace gtsam diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index b7b9d7034..4e16fc689 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -17,74 +17,90 @@ * @author Frank Dellaert */ +#include #include #include -#include #include +#include +#include using namespace std; namespace gtsam { - /* ******************************************************************************** */ - DecisionTreeFactor::DecisionTreeFactor() { - } + /* ************************************************************************ */ + DecisionTreeFactor::DecisionTreeFactor() {} - /* ******************************************************************************** */ + /* ************************************************************************ */ DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, - const ADT& potentials) : - DiscreteFactor(keys.indices()), Potentials(keys, potentials) { - } + const ADT& potentials) + : DiscreteFactor(keys.indices()), + ADT(potentials), + cardinalities_(keys.cardinalities()) {} - /* *************************************************************************/ - DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) : - DiscreteFactor(c.keys()), Potentials(c) { - } + /* ************************************************************************ */ + DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) + : DiscreteFactor(c.keys()), + AlgebraicDecisionTree(c), + cardinalities_(c.cardinalities_) {} - /* ************************************************************************* */ - bool DecisionTreeFactor::equals(const DiscreteFactor& other, double tol) const { - if(!dynamic_cast(&other)) { + /* ************************************************************************ */ + bool DecisionTreeFactor::equals(const DiscreteFactor& other, + double tol) const { + if (!dynamic_cast(&other)) { return false; - } - else { - const DecisionTreeFactor& f(static_cast(other)); - return Potentials::equals(f, tol); + } else { + const auto& f(static_cast(other)); + return ADT::equals(f, tol); } } - /* ************************************************************************* */ + /* ************************************************************************ */ + double DecisionTreeFactor::safe_div(const double& a, const double& b) { + // The use for safe_div is when we divide the product factor by the sum + // factor. If the product or sum is zero, we accord zero probability to the + // event. + return (a == 0 || b == 0) ? 0 : (a / b); + } + + /* ************************************************************************ */ void DecisionTreeFactor::print(const string& s, - const KeyFormatter& formatter) const { + const KeyFormatter& formatter) const { cout << s; - Potentials::print("Potentials:",formatter); + cout << " f["; + for (auto&& key : keys()) + cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key); + cout << " ]" << endl; + ADT::print("", formatter); } - /* ************************************************************************* */ + /* ************************************************************************ */ DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f, - ADT::Binary op) const { - map cs; // new cardinalities + ADT::Binary op) const { + map cs; // new cardinalities // make unique key-cardinality map - for(Key j: keys()) cs[j] = cardinality(j); - for(Key j: f.keys()) cs[j] = f.cardinality(j); + for (Key j : keys()) cs[j] = cardinality(j); + for (Key j : f.keys()) cs[j] = f.cardinality(j); // Convert map into keys DiscreteKeys keys; - for(const std::pair& key: cs) - keys.push_back(key); + for (const std::pair& key : cs) keys.push_back(key); // apply operand ADT result = ADT::apply(f, op); // Make a new factor return DecisionTreeFactor(keys, result); } - /* ************************************************************************* */ - DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(size_t nrFrontals, - ADT::Binary op) const { - - if (nrFrontals > size()) throw invalid_argument( - (boost::format( - "DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d") - % nrFrontals % size()).str()); + /* ************************************************************************ */ + DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine( + size_t nrFrontals, ADT::Binary op) const { + if (nrFrontals > size()) + throw invalid_argument( + (boost::format( + "DecisionTreeFactor::combine: invalid number of frontal " + "keys %d, nr.keys=%d") % + nrFrontals % size()) + .str()); // sum over nrFrontals keys size_t i; @@ -98,20 +114,21 @@ namespace gtsam { DiscreteKeys dkeys; for (; i < keys().size(); i++) { Key j = keys()[i]; - dkeys.push_back(DiscreteKey(j,cardinality(j))); + dkeys.push_back(DiscreteKey(j, cardinality(j))); } return boost::make_shared(dkeys, result); } - - /* ************************************************************************* */ - DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(const Ordering& frontalKeys, - ADT::Binary op) const { - - if (frontalKeys.size() > size()) throw invalid_argument( - (boost::format( - "DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d") - % frontalKeys.size() % size()).str()); + /* ************************************************************************ */ + DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine( + const Ordering& frontalKeys, ADT::Binary op) const { + if (frontalKeys.size() > size()) + throw invalid_argument( + (boost::format( + "DecisionTreeFactor::combine: invalid number of frontal " + "keys %d, nr.keys=%d") % + frontalKeys.size() % size()) + .str()); // sum over nrFrontals keys size_t i; @@ -122,17 +139,190 @@ namespace gtsam { } // create new factor, note we collect keys that are not in frontalKeys - // TODO: why do we need this??? result should contain correct keys!!! + // TODO(frank): why do we need this??? result should contain correct keys!!! DiscreteKeys dkeys; for (i = 0; i < keys().size(); i++) { Key j = keys()[i]; - // TODO: inefficient! - if (std::find(frontalKeys.begin(), frontalKeys.end(), j) != frontalKeys.end()) + // TODO(frank): inefficient! + if (std::find(frontalKeys.begin(), frontalKeys.end(), j) != + frontalKeys.end()) continue; - dkeys.push_back(DiscreteKey(j,cardinality(j))); + dkeys.push_back(DiscreteKey(j, cardinality(j))); } return boost::make_shared(dkeys, result); } -/* ************************************************************************* */ -} // namespace gtsam + /* ************************************************************************ */ + std::vector> DecisionTreeFactor::enumerate() + const { + // Get all possible assignments + std::vector> pairs = discreteKeys(); + // Reverse to make cartesian product output a more natural ordering. + std::vector> rpairs(pairs.rbegin(), pairs.rend()); + const auto assignments = DiscreteValues::CartesianProduct(rpairs); + + // Construct unordered_map with values + std::vector> result; + for (const auto& assignment : assignments) { + result.emplace_back(assignment, operator()(assignment)); + } + return result; + } + + /* ************************************************************************ */ + DiscreteKeys DecisionTreeFactor::discreteKeys() const { + DiscreteKeys result; + for (auto&& key : keys()) { + DiscreteKey dkey(key, cardinality(key)); + if (std::find(result.begin(), result.end(), dkey) == result.end()) { + result.push_back(dkey); + } + } + return result; + } + + /* ************************************************************************ */ + static std::string valueFormatter(const double& v) { + return (boost::format("%4.2g") % v).str(); + } + + /** output to graphviz format, stream version */ + void DecisionTreeFactor::dot(std::ostream& os, + const KeyFormatter& keyFormatter, + bool showZero) const { + ADT::dot(os, keyFormatter, valueFormatter, showZero); + } + + /** output to graphviz format, open a file */ + void DecisionTreeFactor::dot(const std::string& name, + const KeyFormatter& keyFormatter, + bool showZero) const { + ADT::dot(name, keyFormatter, valueFormatter, showZero); + } + + /** output to graphviz format string */ + std::string DecisionTreeFactor::dot(const KeyFormatter& keyFormatter, + bool showZero) const { + return ADT::dot(keyFormatter, valueFormatter, showZero); + } + + // Print out header. + /* ************************************************************************ */ + string DecisionTreeFactor::markdown(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; + + // Print out header. + ss << "|"; + for (auto& key : keys()) { + ss << keyFormatter(key) << "|"; + } + ss << "value|\n"; + + // Print out separator with alignment hints. + ss << "|"; + for (size_t j = 0; j < size(); j++) ss << ":-:|"; + ss << ":-:|\n"; + + // Print out all rows. + auto rows = enumerate(); + for (const auto& kv : rows) { + ss << "|"; + auto assignment = kv.first; + for (auto& key : keys()) { + size_t index = assignment.at(key); + ss << DiscreteValues::Translate(names, key, index) << "|"; + } + ss << kv.second << "|\n"; + } + return ss.str(); + } + + /* ************************************************************************ */ + string DecisionTreeFactor::html(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; + + // Print out preamble. + ss << "

\n\n \n"; + + // Print out header row. + ss << " "; + for (auto& key : keys()) { + ss << ""; + } + ss << "\n"; + + // Finish header and start body. + ss << " \n \n"; + + // Print out all rows. + auto rows = enumerate(); + for (const auto& kv : rows) { + ss << " "; + auto assignment = kv.first; + for (auto& key : keys()) { + size_t index = assignment.at(key); + ss << ""; + } + ss << ""; // value + ss << "\n"; + } + ss << " \n
" << keyFormatter(key) << "value
" << DiscreteValues::Translate(names, key, index) << "" << kv.second << "
\n
"; + return ss.str(); + } + + /* ************************************************************************ */ + DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, + const vector& table) + : DiscreteFactor(keys.indices()), + AlgebraicDecisionTree(keys, table), + cardinalities_(keys.cardinalities()) {} + + /* ************************************************************************ */ + DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, + const string& table) + : DiscreteFactor(keys.indices()), + AlgebraicDecisionTree(keys, table), + cardinalities_(keys.cardinalities()) {} + + /* ************************************************************************ */ + DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const { + const size_t N = maxNrAssignments; + + // Get the probabilities in the decision tree so we can threshold. + std::vector probabilities; + this->visitLeaf([&](const Leaf& leaf) { + size_t nrAssignments = leaf.nrAssignments(); + double prob = leaf.constant(); + probabilities.insert(probabilities.end(), nrAssignments, prob); + }); + + // The number of probabilities can be lower than max_leaves + if (probabilities.size() <= N) { + return *this; + } + + std::sort(probabilities.begin(), probabilities.end(), + std::greater{}); + + double threshold = probabilities[N - 1]; + + // Now threshold the decision tree + size_t total = 0; + auto thresholdFunc = [threshold, &total, N](const double& value) { + if (value < threshold || total >= N) { + return 0.0; + } else { + total += 1; + return value; + } + }; + DecisionTree thresholded(*this, thresholdFunc); + + // Create pruned decision tree factor and return. + return DecisionTreeFactor(this->discreteKeys(), thresholded); + } + + /* ************************************************************************ */ +} // namespace gtsam diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index d1696a281..86fa44649 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -18,15 +18,18 @@ #pragma once +#include #include -#include +#include #include +#include #include - -#include -#include +#include #include +#include +#include +#include namespace gtsam { @@ -35,34 +38,46 @@ namespace gtsam { /** * A discrete probabilistic factor */ - class GTSAM_EXPORT DecisionTreeFactor: public DiscreteFactor, public Potentials { - - public: - + class GTSAM_EXPORT DecisionTreeFactor : public DiscreteFactor, + public AlgebraicDecisionTree { + public: // typedefs needed to play nice with gtsam typedef DecisionTreeFactor This; - typedef DiscreteFactor Base; ///< Typedef to base class + typedef DiscreteFactor Base; ///< Typedef to base class typedef boost::shared_ptr shared_ptr; + typedef AlgebraicDecisionTree ADT; - public: + protected: + std::map cardinalities_; + public: /// @name Standard Constructors /// @{ /** Default constructor for I/O */ DecisionTreeFactor(); - /** Constructor from Indices, Ordering, and AlgebraicDecisionDiagram */ + /** Constructor from DiscreteKeys and AlgebraicDecisionTree */ DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials); - /** Constructor from Indices and (string or doubles) */ - template - DecisionTreeFactor(const DiscreteKeys& keys, SOURCE table) : - DiscreteFactor(keys.indices()), Potentials(keys, table) { - } + /** Constructor from doubles */ + DecisionTreeFactor(const DiscreteKeys& keys, + const std::vector& table); + + /** Constructor from string */ + DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table); + + /// Single-key specialization + template + DecisionTreeFactor(const DiscreteKey& key, SOURCE table) + : DecisionTreeFactor(DiscreteKeys{key}, table) {} + + /// Single-key specialization, with vector of doubles. + DecisionTreeFactor(const DiscreteKey& key, const std::vector& row) + : DecisionTreeFactor(DiscreteKeys{key}, row) {} /** Construct from a DiscreteConditional type */ - DecisionTreeFactor(const DiscreteConditional& c); + explicit DecisionTreeFactor(const DiscreteConditional& c); /// @} /// @name Testable @@ -72,7 +87,8 @@ namespace gtsam { bool equals(const DiscreteFactor& other, double tol = 1e-9) const override; // print - void print(const std::string& s = "DecisionTreeFactor:\n", + void print( + const std::string& s = "DecisionTreeFactor:\n", const KeyFormatter& formatter = DefaultKeyFormatter) const override; /// @} @@ -80,8 +96,8 @@ namespace gtsam { /// @{ /// Value is just look up in AlgebraicDecisonTree - double operator()(const Values& values) const override { - return Potentials::operator()(values); + double operator()(const DiscreteValues& values) const override { + return ADT::operator()(values); } /// multiply two factors @@ -89,15 +105,17 @@ namespace gtsam { return apply(f, ADT::Ring::mul); } + static double safe_div(const double& a, const double& b); + + size_t cardinality(Key j) const { return cardinalities_.at(j); } + /// divide by factor f (safely) DecisionTreeFactor operator/(const DecisionTreeFactor& f) const { return apply(f, safe_div); } /// Convert into a decisiontree - DecisionTreeFactor toDecisionTreeFactor() const override { - return *this; - } + DecisionTreeFactor toDecisionTreeFactor() const override { return *this; } /// Create new factor by summing all values with the same separator values shared_ptr sum(size_t nrFrontals) const { @@ -109,11 +127,16 @@ namespace gtsam { return combine(keys, ADT::Ring::add); } - /// Create new factor by maximizing over all values with the same separator values + /// Create new factor by maximizing over all values with the same separator. shared_ptr max(size_t nrFrontals) const { return combine(nrFrontals, ADT::Ring::max); } + /// Create new factor by maximizing over all values with the same separator. + shared_ptr max(const Ordering& keys) const { + return combine(keys, ADT::Ring::max); + } + /// @} /// @name Advanced Interface /// @{ @@ -121,14 +144,14 @@ namespace gtsam { /** * Apply binary operator (*this) "op" f * @param f the second argument for op - * @param op a binary operator that operates on AlgebraicDecisionDiagram potentials + * @param op a binary operator that operates on AlgebraicDecisionTree */ DecisionTreeFactor apply(const DecisionTreeFactor& f, ADT::Binary op) const; /** * Combine frontal variables using binary operator "op" * @param nrFrontals nr. of frontal to combine variables in this factor - * @param op a binary operator that operates on AlgebraicDecisionDiagram potentials + * @param op a binary operator that operates on AlgebraicDecisionTree * @return shared pointer to newly created DecisionTreeFactor */ shared_ptr combine(size_t nrFrontals, ADT::Binary op) const; @@ -136,37 +159,80 @@ namespace gtsam { /** * Combine frontal variables in an Ordering using binary operator "op" * @param nrFrontals nr. of frontal to combine variables in this factor - * @param op a binary operator that operates on AlgebraicDecisionDiagram potentials + * @param op a binary operator that operates on AlgebraicDecisionTree * @return shared pointer to newly created DecisionTreeFactor */ shared_ptr combine(const Ordering& keys, ADT::Binary op) const; + /// Enumerate all values into a map from values to double. + std::vector> enumerate() const; -// /** -// * @brief Permutes the keys in Potentials and DiscreteFactor -// * -// * This re-implements the permuteWithInverse() in both Potentials -// * and DiscreteFactor by doing both of them together. -// */ -// -// void permuteWithInverse(const Permutation& inversePermutation){ -// DiscreteFactor::permuteWithInverse(inversePermutation); -// Potentials::permuteWithInverse(inversePermutation); -// } -// -// /** -// * Apply a reduction, which is a remapping of variable indices. -// */ -// virtual void reduceWithInverse(const internal::Reduction& inverseReduction) { -// DiscreteFactor::reduceWithInverse(inverseReduction); -// Potentials::reduceWithInverse(inverseReduction); -// } + /// Return all the discrete keys associated with this factor. + DiscreteKeys discreteKeys() const; + + /** + * @brief Prune the decision tree of discrete variables. + * + * Pruning will set the leaves to be "pruned" to 0 indicating a 0 + * probability. An assignment is pruned if it is not in the top + * `maxNrAssignments` values. + * + * A violation can occur if there are more + * duplicate values than `maxNrAssignments`. A violation here is the need to + * un-prune the decision tree (e.g. all assignment values are 1.0). We could + * have another case where some subset of duplicates exist (e.g. for a tree + * with 8 assignments we have 1, 1, 1, 1, 0.8, 0.7, 0.6, 0.5), but this is + * not a violation since the for `maxNrAssignments=5` the top values are (1, + * 0.8). + * + * @param maxNrAssignments The maximum number of assignments to keep. + * @return DecisionTreeFactor + */ + DecisionTreeFactor prune(size_t maxNrAssignments) const; /// @} -}; -// DecisionTreeFactor + /// @name Wrapper support + /// @{ + + /** output to graphviz format, stream version */ + void dot(std::ostream& os, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + bool showZero = true) const; + + /** output to graphviz format, open a file */ + void dot(const std::string& name, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + bool showZero = true) const; + + /** output to graphviz format string */ + std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + bool showZero = true) const; + + /** + * @brief Render as markdown table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, category names corresponding to choices. + * @return std::string a markdown string. + */ + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override; + + /** + * @brief Render as html table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, category names corresponding to choices. + * @return std::string a html string. + */ + std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override; + + /// @} + }; // traits -template<> struct traits : public Testable {}; +template <> +struct traits : public Testable {}; -}// namespace gtsam +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index 84a80c565..ccc52585e 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -25,51 +25,78 @@ namespace gtsam { - // Instantiate base class - template class FactorGraph; - - /* ************************************************************************* */ - bool DiscreteBayesNet::equals(const This& bn, double tol) const - { - return Base::equals(bn, tol); - } - - /* ************************************************************************* */ -// void DiscreteBayesNet::add_front(const Signature& s) { -// push_front(boost::make_shared(s)); -// } - - /* ************************************************************************* */ - void DiscreteBayesNet::add(const Signature& s) { - push_back(boost::make_shared(s)); - } - - /* ************************************************************************* */ - double DiscreteBayesNet::evaluate(const DiscreteConditional::Values & values) const { - // evaluate all conditionals and multiply - double result = 1.0; - for(DiscreteConditional::shared_ptr conditional: *this) - result *= (*conditional)(values); - return result; - } - - /* ************************************************************************* */ - DiscreteFactor::sharedValues DiscreteBayesNet::optimize() const { - // solve each node in turn in topological sort order (parents first) - DiscreteFactor::sharedValues result(new DiscreteFactor::Values()); - for (auto conditional: boost::adaptors::reverse(*this)) - conditional->solveInPlace(*result); - return result; - } - - /* ************************************************************************* */ - DiscreteFactor::sharedValues DiscreteBayesNet::sample() const { - // sample each node in turn in topological sort order (parents first) - DiscreteFactor::sharedValues result(new DiscreteFactor::Values()); - for (auto conditional: boost::adaptors::reverse(*this)) - conditional->sampleInPlace(*result); - return result; - } +// Instantiate base class +template class FactorGraph; /* ************************************************************************* */ -} // namespace +bool DiscreteBayesNet::equals(const This& bn, double tol) const { + return Base::equals(bn, tol); +} + +/* ************************************************************************* */ +double DiscreteBayesNet::evaluate(const DiscreteValues& values) const { + // evaluate all conditionals and multiply + double result = 1.0; + for (const DiscreteConditional::shared_ptr& conditional : *this) + result *= (*conditional)(values); + return result; +} + +/* ************************************************************************* */ +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 +DiscreteValues DiscreteBayesNet::optimize() const { + DiscreteValues result; + return optimize(result); +} + +DiscreteValues DiscreteBayesNet::optimize(DiscreteValues result) const { + // solve each node in turn in topological sort order (parents first) +#ifdef _MSC_VER +#pragma message("DiscreteBayesNet::optimize (deprecated) does not compute MPE!") +#else +#warning "DiscreteBayesNet::optimize (deprecated) does not compute MPE!" +#endif + for (auto conditional : boost::adaptors::reverse(*this)) + conditional->solveInPlace(&result); + return result; +} +#endif + +/* ************************************************************************* */ +DiscreteValues DiscreteBayesNet::sample() const { + DiscreteValues result; + return sample(result); +} + +DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const { + // sample each node in turn in topological sort order (parents first) + for (auto conditional : boost::adaptors::reverse(*this)) + conditional->sampleInPlace(&result); + return result; +} + +/* *********************************************************************** */ +std::string DiscreteBayesNet::markdown( + const KeyFormatter& keyFormatter, + const DiscreteFactor::Names& names) const { + using std::endl; + std::stringstream ss; + ss << "`DiscreteBayesNet` of size " << size() << endl << endl; + for (const DiscreteConditional::shared_ptr& conditional : *this) + ss << conditional->markdown(keyFormatter, names) << endl; + return ss.str(); +} + +/* *********************************************************************** */ +std::string DiscreteBayesNet::html(const KeyFormatter& keyFormatter, + const DiscreteFactor::Names& names) const { + using std::endl; + std::stringstream ss; + ss << "

DiscreteBayesNet of size " << size() << "

"; + for (const DiscreteConditional::shared_ptr& conditional : *this) + ss << conditional->html(keyFormatter, names) << endl; + return ss.str(); +} + +/* ************************************************************************* */ +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index d5ba30584..df94d6908 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -13,25 +13,31 @@ * @file DiscreteBayesNet.h * @date Feb 15, 2011 * @author Duy-Nguyen Ta + * @author Frank dellaert */ #pragma once -#include -#include -#include +#include +#include #include #include -#include + +#include +#include +#include +#include +#include namespace gtsam { -/** A Bayes net made from linear-Discrete densities */ - class GTSAM_EXPORT DiscreteBayesNet: public BayesNet - { - public: - - typedef FactorGraph Base; +/** + * A Bayes net made from discrete conditional distributions. + * @addtogroup discrete + */ +class GTSAM_EXPORT DiscreteBayesNet: public BayesNet { + public: + typedef BayesNet Base; typedef DiscreteBayesNet This; typedef DiscreteConditional ConditionalType; typedef boost::shared_ptr shared_ptr; @@ -40,20 +46,24 @@ namespace gtsam { /// @name Standard Constructors /// @{ - /** Construct empty factor graph */ + /// Construct empty Bayes net. DiscreteBayesNet() {} /** Construct from iterator over conditionals */ - template - DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {} + template + DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) + : Base(firstConditional, lastConditional) {} /** Construct from container of factors (shared_ptr or plain objects) */ - template - explicit DiscreteBayesNet(const CONTAINER& conditionals) : Base(conditionals) {} + template + explicit DiscreteBayesNet(const CONTAINER& conditionals) + : Base(conditionals) {} - /** Implicit copy/downcast constructor to override explicit template container constructor */ - template - DiscreteBayesNet(const FactorGraph& graph) : Base(graph) {} + /** Implicit copy/downcast constructor to override explicit template + * container constructor */ + template + DiscreteBayesNet(const FactorGraph& graph) + : Base(graph) {} /// Destructor virtual ~DiscreteBayesNet() {} @@ -71,26 +81,73 @@ namespace gtsam { /// @name Standard Interface /// @{ + // Add inherited versions of add. + using Base::add; + + /** Add a DiscreteDistribution using a table or a string */ + void add(const DiscreteKey& key, const std::string& spec) { + emplace_shared(key, spec); + } + /** Add a DiscreteCondtional */ - void add(const Signature& s); + template + void add(Args&&... args) { + emplace_shared(std::forward(args)...); + } + + //** evaluate for given DiscreteValues */ + double evaluate(const DiscreteValues & values) const; -// /** Add a DiscreteCondtional in front, when listing parents first*/ -// GTSAM_EXPORT void add_front(const Signature& s); - - //** evaluate for given Values */ - double evaluate(const DiscreteConditional::Values & values) const; + //** (Preferred) sugar for the above for given DiscreteValues */ + double operator()(const DiscreteValues & values) const { + return evaluate(values); + } /** - * Solve the DiscreteBayesNet by back-substitution - */ - DiscreteFactor::sharedValues optimize() const; + * @brief do ancestral sampling + * + * Assumes the Bayes net is reverse topologically sorted, i.e. last + * conditional will be sampled first. If the Bayes net resulted from + * eliminating a factor graph, this is true for the elimination ordering. + * + * @return a sampled value for all variables. + */ + DiscreteValues sample() const; - /** Do ancestral sampling */ - DiscreteFactor::sharedValues sample() const; + /** + * @brief do ancestral sampling, given certain variables. + * + * Assumes the Bayes net is reverse topologically sorted *and* that the + * Bayes net does not contain any conditionals for the given values. + * + * @return given values extended with sampled value for all other variables. + */ + DiscreteValues sample(DiscreteValues given) const; + + ///@} + /// @name Wrapper support + /// @{ + + /// Render as markdown tables. + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteFactor::Names& names = {}) const; + + /// Render as html tables. + std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteFactor::Names& names = {}) const; ///@} - private: +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + /// @name Deprecated functionality + /// @{ + + DiscreteValues GTSAM_DEPRECATED optimize() const; + DiscreteValues GTSAM_DEPRECATED optimize(DiscreteValues given) const; + /// @} +#endif + + private: /** Serialization function */ friend class boost::serialization::access; template diff --git a/gtsam/discrete/DiscreteBayesTree.cpp b/gtsam/discrete/DiscreteBayesTree.cpp index 990d10dbe..139292eee 100644 --- a/gtsam/discrete/DiscreteBayesTree.cpp +++ b/gtsam/discrete/DiscreteBayesTree.cpp @@ -31,7 +31,7 @@ namespace gtsam { /* ************************************************************************* */ double DiscreteBayesTreeClique::evaluate( - const DiscreteConditional::Values& values) const { + const DiscreteValues& values) const { // evaluate all conditionals and multiply double result = (*conditional_)(values); for (const auto& child : children) { @@ -47,7 +47,7 @@ namespace gtsam { /* ************************************************************************* */ double DiscreteBayesTree::evaluate( - const DiscreteConditional::Values& values) const { + const DiscreteValues& values) const { double result = 1.0; for (const auto& root : roots_) { result *= root->evaluate(values); @@ -55,8 +55,40 @@ namespace gtsam { return result; } -} // \namespace gtsam - - + /* **************************************************************************/ + std::string DiscreteBayesTree::markdown( + const KeyFormatter& keyFormatter, + const DiscreteFactor::Names& names) const { + using std::endl; + std::stringstream ss; + ss << "`DiscreteBayesTree` of size " << nodes_.size() << endl << endl; + auto visitor = [&](const DiscreteBayesTreeClique::shared_ptr& clique, + size_t& indent) { + ss << "\n" << clique->conditional()->markdown(keyFormatter, names); + return indent + 1; + }; + size_t indent; + treeTraversal::DepthFirstForest(*this, indent, visitor); + return ss.str(); + } + /* **************************************************************************/ + std::string DiscreteBayesTree::html( + const KeyFormatter& keyFormatter, + const DiscreteFactor::Names& names) const { + using std::endl; + std::stringstream ss; + ss << "

DiscreteBayesTree of size " << nodes_.size() + << "

"; + auto visitor = [&](const DiscreteBayesTreeClique::shared_ptr& clique, + size_t& indent) { + ss << clique->conditional()->html(keyFormatter, names); + return indent + 1; + }; + size_t indent; + treeTraversal::DepthFirstForest(*this, indent, visitor); + return ss.str(); + } + /* **************************************************************************/ + } // namespace gtsam diff --git a/gtsam/discrete/DiscreteBayesTree.h b/gtsam/discrete/DiscreteBayesTree.h index 29da5817e..809ce9c83 100644 --- a/gtsam/discrete/DiscreteBayesTree.h +++ b/gtsam/discrete/DiscreteBayesTree.h @@ -57,8 +57,8 @@ class GTSAM_EXPORT DiscreteBayesTreeClique conditional_->printSignature(s, formatter); } - //** evaluate conditional probability of subtree for given Values */ - double evaluate(const DiscreteConditional::Values& values) const; + //** evaluate conditional probability of subtree for given DiscreteValues */ + double evaluate(const DiscreteValues& values) const; }; /* ************************************************************************* */ @@ -72,14 +72,35 @@ class GTSAM_EXPORT DiscreteBayesTree typedef DiscreteBayesTree This; typedef boost::shared_ptr shared_ptr; + /// @name Standard interface + /// @{ /** Default constructor, creates an empty Bayes tree */ DiscreteBayesTree() {} /** Check equality */ bool equals(const This& other, double tol = 1e-9) const; - //** evaluate probability for given Values */ - double evaluate(const DiscreteConditional::Values& values) const; + //** evaluate probability for given DiscreteValues */ + double evaluate(const DiscreteValues& values) const; + + //** (Preferred) sugar for the above for given DiscreteValues */ + double operator()(const DiscreteValues& values) const { + return evaluate(values); + } + + /// @} + /// @name Wrapper support + /// @{ + + /// Render as markdown tables. + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteFactor::Names& names = {}) const; + + /// Render as html tables. + std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteFactor::Names& names = {}) const; + + /// @} }; } // namespace gtsam diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index ac7c58405..0d6c5e976 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -16,57 +16,119 @@ * @author Frank Dellaert */ +#include +#include #include #include #include -#include -#include - -#include #include +#include #include +#include #include #include +#include #include using namespace std; - +using std::pair; +using std::stringstream; +using std::vector; namespace gtsam { // Instantiate base class -template class Conditional ; +template class GTSAM_EXPORT + Conditional; -/* ******************************************************************************** */ +/* ************************************************************************** */ DiscreteConditional::DiscreteConditional(const size_t nrFrontals, - const DecisionTreeFactor& f) : - BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) { -} + const DecisionTreeFactor& f) + : BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {} -/* ******************************************************************************** */ -DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint, - const DecisionTreeFactor& marginal) : - BaseFactor( - ISDEBUG("DiscreteConditional::COUNT") ? joint : joint / marginal), BaseConditional( - joint.size()-marginal.size()) { - if (ISDEBUG("DiscreteConditional::DiscreteConditional")) - cout << (firstFrontalKey()) << endl; //TODO Print all keys -} +/* ************************************************************************** */ +DiscreteConditional::DiscreteConditional(size_t nrFrontals, + const DiscreteKeys& keys, + const ADT& potentials) + : BaseFactor(keys, potentials), BaseConditional(nrFrontals) {} -/* ******************************************************************************** */ +/* ************************************************************************** */ DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint, - const DecisionTreeFactor& marginal, const Ordering& orderedKeys) : - DiscreteConditional(joint, marginal) { + const DecisionTreeFactor& marginal) + : BaseFactor(joint / marginal), + BaseConditional(joint.size() - marginal.size()) {} + +/* ************************************************************************** */ +DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint, + const DecisionTreeFactor& marginal, + const Ordering& orderedKeys) + : DiscreteConditional(joint, marginal) { keys_.clear(); keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end()); } -/* ******************************************************************************** */ +/* ************************************************************************** */ DiscreteConditional::DiscreteConditional(const Signature& signature) : BaseFactor(signature.discreteKeys(), signature.cpt()), BaseConditional(1) {} -/* ******************************************************************************** */ +/* ************************************************************************** */ +DiscreteConditional DiscreteConditional::operator*( + const DiscreteConditional& other) const { + // Take union of frontal keys + std::set newFrontals; + for (auto&& key : this->frontals()) newFrontals.insert(key); + for (auto&& key : other.frontals()) newFrontals.insert(key); + + // Check if frontals overlapped + if (nrFrontals() + other.nrFrontals() > newFrontals.size()) + throw std::invalid_argument( + "DiscreteConditional::operator* called with overlapping frontal keys."); + + // Now, add cardinalities. + DiscreteKeys discreteKeys; + for (auto&& key : frontals()) + discreteKeys.emplace_back(key, cardinality(key)); + for (auto&& key : other.frontals()) + discreteKeys.emplace_back(key, other.cardinality(key)); + + // Sort + std::sort(discreteKeys.begin(), discreteKeys.end()); + + // Add parents to set, to make them unique + std::set parents; + for (auto&& key : this->parents()) + if (!newFrontals.count(key)) parents.emplace(key, cardinality(key)); + for (auto&& key : other.parents()) + if (!newFrontals.count(key)) parents.emplace(key, other.cardinality(key)); + + // Finally, add parents to keys, in order + for (auto&& dk : parents) discreteKeys.push_back(dk); + + ADT product = ADT::apply(other, ADT::Ring::mul); + return DiscreteConditional(newFrontals.size(), discreteKeys, product); +} + +/* ************************************************************************** */ +DiscreteConditional DiscreteConditional::marginal(Key key) const { + if (nrParents() > 0) + throw std::invalid_argument( + "DiscreteConditional::marginal: single argument version only valid for " + "fully specified joint distributions (i.e., no parents)."); + + // Calculate the keys as the frontal keys without the given key. + DiscreteKeys discreteKeys{{key, cardinality(key)}}; + + // Calculate sum + ADT adt(*this); + for (auto&& k : frontals()) + if (k != key) adt = adt.sum(k, cardinality(k)); + + // Return new factor + return DiscreteConditional(1, discreteKeys, adt); +} + +/* ************************************************************************** */ void DiscreteConditional::print(const string& s, const KeyFormatter& formatter) const { cout << s << " P( "; @@ -79,122 +141,196 @@ void DiscreteConditional::print(const string& s, cout << formatter(*it) << " "; } } - cout << ")"; - Potentials::print(""); + cout << "):\n"; + ADT::print("", formatter); cout << endl; } -/* ******************************************************************************** */ +/* ************************************************************************** */ bool DiscreteConditional::equals(const DiscreteFactor& other, - double tol) const { - if (!dynamic_cast(&other)) + double tol) const { + if (!dynamic_cast(&other)) { return false; - else { - const DecisionTreeFactor& f( - static_cast(other)); + } else { + const DecisionTreeFactor& f(static_cast(other)); return DecisionTreeFactor::equals(f, tol); } } -/* ******************************************************************************** */ -Potentials::ADT DiscreteConditional::choose(const Values& parentsValues) const { - ADT pFS(*this); - Key j; size_t value; - for(Key key: parents()) { +/* ************************************************************************** */ +DiscreteConditional::ADT DiscreteConditional::choose( + const DiscreteValues& given, bool forceComplete) const { + // Get the big decision tree with all the levels, and then go down the + // branches based on the value of the parent variables. + DiscreteConditional::ADT adt(*this); + size_t value; + for (Key j : parents()) { try { - j = (key); - value = parentsValues.at(j); - pFS = pFS.choose(j, value); - } catch (exception&) { - cout << "Key: " << j << " Value: " << value << endl; - parentsValues.print("parentsValues: "); - // pFS.print("pFS: "); - throw runtime_error("DiscreteConditional::choose: parent value missing"); - }; + value = given.at(j); + adt = adt.choose(j, value); // ADT keeps getting smaller. + } catch (std::out_of_range&) { + if (forceComplete) { + given.print("parentsValues: "); + throw runtime_error( + "DiscreteConditional::choose: parent value missing"); + } + } } - - return pFS; + return adt; } -/* ******************************************************************************** */ -void DiscreteConditional::solveInPlace(Values& values) const { - // TODO: Abhijit asks: is this really the fastest way? He thinks it is. - ADT pFS = choose(values); // P(F|S=parentsValues) +/* ************************************************************************** */ +DiscreteConditional::shared_ptr DiscreteConditional::choose( + const DiscreteValues& given) const { + ADT adt = choose(given, false); // P(F|S=given) + + // Collect all keys not in given. + DiscreteKeys dKeys; + for (Key j : frontals()) { + dKeys.emplace_back(j, this->cardinality(j)); + } + for (size_t i = nrFrontals(); i < size(); i++) { + Key j = keys_[i]; + if (given.count(j) == 0) { + dKeys.emplace_back(j, this->cardinality(j)); + } + } + return boost::make_shared(nrFrontals(), dKeys, adt); +} + +/* ************************************************************************** */ +DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( + const DiscreteValues& frontalValues) const { + // Get the big decision tree with all the levels, and then go down the + // branches based on the value of the frontal variables. + ADT adt(*this); + size_t value; + for (Key j : frontals()) { + try { + value = frontalValues.at(j); + adt = adt.choose(j, value); // ADT keeps getting smaller. + } catch (exception&) { + frontalValues.print("frontalValues: "); + throw runtime_error("DiscreteConditional::choose: frontal value missing"); + } + } + + // Convert ADT to factor. + DiscreteKeys discreteKeys; + for (Key j : parents()) { + discreteKeys.emplace_back(j, this->cardinality(j)); + } + return boost::make_shared(discreteKeys, adt); +} + +/* ****************************************************************************/ +DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( + size_t frontal) const { + if (nrFrontals() != 1) + throw std::invalid_argument( + "Single value likelihood can only be invoked on single-variable " + "conditional"); + DiscreteValues values; + values.emplace(keys_[0], frontal); + return likelihood(values); +} + +/* ************************************************************************** */ +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 +void DiscreteConditional::solveInPlace(DiscreteValues* values) const { + ADT pFS = choose(*values, true); // P(F|S=parentsValues) // Initialize - Values mpe; + DiscreteValues mpe; double maxP = 0; - DiscreteKeys keys; - for(Key idx: frontals()) { - DiscreteKey dk(idx, cardinality(idx)); - keys & dk; - } // Get all Possible Configurations - vector allPosbValues = cartesianProduct(keys); + const auto allPosbValues = frontalAssignments(); - // Find the MPE - for(Values& frontalVals: allPosbValues) { - double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues) - // Update MPE solution if better + // Find the maximum + for (const auto& frontalVals : allPosbValues) { + double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues) + // Update maximum solution if better if (pValueS > maxP) { maxP = pValueS; mpe = frontalVals; } } - //set values (inPlace) to mpe - for(Key j: frontals()) { - values[j] = mpe[j]; + // set values (inPlace) to maximum + for (Key j : frontals()) { + (*values)[j] = mpe[j]; } } -/* ******************************************************************************** */ -void DiscreteConditional::sampleInPlace(Values& values) const { - assert(nrFrontals() == 1); - Key j = (firstFrontalKey()); - size_t sampled = sample(values); // Sample variable - values[j] = sampled; // store result in partial solution -} - -/* ******************************************************************************** */ -size_t DiscreteConditional::solve(const Values& parentsValues) const { - - // TODO: is this really the fastest way? I think it is. - ADT pFS = choose(parentsValues); // P(F|S=parentsValues) +/* ************************************************************************** */ +size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const { + ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues) // Then, find the max over all remaining - // TODO, only works for one key now, seems horribly slow this way - size_t mpe = 0; - Values frontals; + size_t max = 0; double maxP = 0; + DiscreteValues frontals; assert(nrFrontals() == 1); Key j = (firstFrontalKey()); for (size_t value = 0; value < cardinality(j); value++) { frontals[j] = value; double pValueS = pFS(frontals); // P(F=value|S=parentsValues) + // Update solution if better + if (pValueS > maxP) { + maxP = pValueS; + max = value; + } + } + return max; +} +#endif + +/* ************************************************************************** */ +size_t DiscreteConditional::argmax() const { + size_t maxValue = 0; + double maxP = 0; + assert(nrFrontals() == 1); + assert(nrParents() == 0); + DiscreteValues frontals; + Key j = firstFrontalKey(); + for (size_t value = 0; value < cardinality(j); value++) { + frontals[j] = value; + double pValueS = (*this)(frontals); // Update MPE solution if better if (pValueS > maxP) { maxP = pValueS; - mpe = value; + maxValue = value; } } - return mpe; + return maxValue; } -/* ******************************************************************************** */ -size_t DiscreteConditional::sample(const Values& parentsValues) const { +/* ************************************************************************** */ +void DiscreteConditional::sampleInPlace(DiscreteValues* values) const { + assert(nrFrontals() == 1); + Key j = (firstFrontalKey()); + size_t sampled = sample(*values); // Sample variable given parents + (*values)[j] = sampled; // store result in partial solution +} + +/* ************************************************************************** */ +size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const { static mt19937 rng(2); // random number generator // Get the correct conditional density - ADT pFS = choose(parentsValues); // P(F|S=parentsValues) + ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues) // TODO(Duy): only works for one key now, seems horribly slow this way - assert(nrFrontals() == 1); + if (nrFrontals() != 1) { + throw std::invalid_argument( + "DiscreteConditional::sample can only be called on single variable " + "conditionals"); + } Key key = firstFrontalKey(); size_t nj = cardinality(key); vector p(nj); - Values frontals; + DiscreteValues frontals; for (size_t value = 0; value < nj; value++) { frontals[key] = value; p[value] = pFS(frontals); // P(F=value|S=parentsValues) @@ -206,6 +342,174 @@ size_t DiscreteConditional::sample(const Values& parentsValues) const { return distribution(rng); } -/* ******************************************************************************** */ +/* ************************************************************************** */ +size_t DiscreteConditional::sample(size_t parent_value) const { + if (nrParents() != 1) + throw std::invalid_argument( + "Single value sample() can only be invoked on single-parent " + "conditional"); + DiscreteValues values; + values.emplace(keys_.back(), parent_value); + return sample(values); +} -}// namespace +/* ************************************************************************** */ +size_t DiscreteConditional::sample() const { + if (nrParents() != 0) + throw std::invalid_argument( + "sample() can only be invoked on no-parent prior"); + DiscreteValues values; + return sample(values); +} + +/* ************************************************************************* */ +vector DiscreteConditional::frontalAssignments() const { + vector> pairs; + for (Key key : frontals()) pairs.emplace_back(key, cardinalities_.at(key)); + vector> rpairs(pairs.rbegin(), pairs.rend()); + return DiscreteValues::CartesianProduct(rpairs); +} + +/* ************************************************************************* */ +vector DiscreteConditional::allAssignments() const { + vector> pairs; + for (Key key : parents()) pairs.emplace_back(key, cardinalities_.at(key)); + for (Key key : frontals()) pairs.emplace_back(key, cardinalities_.at(key)); + vector> rpairs(pairs.rbegin(), pairs.rend()); + return DiscreteValues::CartesianProduct(rpairs); +} + +/* ************************************************************************* */ +// Print out signature. +static void streamSignature(const DiscreteConditional& conditional, + const KeyFormatter& keyFormatter, + stringstream* ss) { + *ss << "P("; + bool first = true; + for (Key key : conditional.frontals()) { + if (!first) *ss << ","; + *ss << keyFormatter(key); + first = false; + } + if (conditional.nrParents() > 0) { + *ss << "|"; + bool first = true; + for (Key parent : conditional.parents()) { + if (!first) *ss << ","; + *ss << keyFormatter(parent); + first = false; + } + } + *ss << "):"; +} + +/* ************************************************************************* */ +std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; + ss << " *"; + streamSignature(*this, keyFormatter, &ss); + ss << "*\n" << std::endl; + if (nrParents() == 0) { + // We have no parents, call factor method. + ss << DecisionTreeFactor::markdown(keyFormatter, names); + return ss.str(); + } + + // Print out header. + ss << "|"; + for (Key parent : parents()) { + ss << "*" << keyFormatter(parent) << "*|"; + } + + auto frontalAssignments = this->frontalAssignments(); + for (const auto& a : frontalAssignments) { + for (auto&& it = beginFrontals(); it != endFrontals(); ++it) { + size_t index = a.at(*it); + ss << DiscreteValues::Translate(names, *it, index); + } + ss << "|"; + } + ss << "\n"; + + // Print out separator with alignment hints. + ss << "|"; + size_t n = frontalAssignments.size(); + for (size_t j = 0; j < nrParents() + n; j++) ss << ":-:|"; + ss << "\n"; + + // Print out all rows. + size_t count = 0; + for (const auto& a : allAssignments()) { + if (count == 0) { + ss << "|"; + for (auto&& it = beginParents(); it != endParents(); ++it) { + size_t index = a.at(*it); + ss << DiscreteValues::Translate(names, *it, index) << "|"; + } + } + ss << operator()(a) << "|"; + count = (count + 1) % n; + if (count == 0) ss << "\n"; + } + return ss.str(); +} + +/* ************************************************************************ */ +string DiscreteConditional::html(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; + ss << "
\n

"; + streamSignature(*this, keyFormatter, &ss); + ss << "

\n"; + if (nrParents() == 0) { + // We have no parents, call factor method. + ss << DecisionTreeFactor::html(keyFormatter, names); + return ss.str(); + } + + // Print out preamble. + ss << "\n \n"; + + // Print out header row. + ss << " "; + for (Key parent : parents()) { + ss << ""; + } + auto frontalAssignments = this->frontalAssignments(); + for (const auto& a : frontalAssignments) { + ss << ""; + } + ss << "\n"; + + // Finish header and start body. + ss << " \n \n"; + + // Output all rows, one per assignment: + size_t count = 0, n = frontalAssignments.size(); + for (const auto& a : allAssignments()) { + if (count == 0) { + ss << " "; + for (auto&& it = beginParents(); it != endParents(); ++it) { + size_t index = a.at(*it); + ss << ""; + } + } + ss << ""; // value + count = (count + 1) % n; + if (count == 0) ss << "\n"; + } + + // Finish up + ss << " \n
" << keyFormatter(parent) << ""; + for (auto&& it = beginFrontals(); it != endFrontals(); ++it) { + size_t index = a.at(*it); + ss << DiscreteValues::Translate(names, *it, index); + } + ss << "
" << DiscreteValues::Translate(names, *it, index) << "" << operator()(a) << "
\n
"; + return ss.str(); +} + +/* ************************************************************************* */ + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 8299fab2c..cff1b69a6 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -21,10 +21,11 @@ #include #include #include -#include -#include +#include +#include #include +#include namespace gtsam { @@ -32,59 +33,109 @@ namespace gtsam { * Discrete Conditional Density * Derives from DecisionTreeFactor */ -class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor, - public Conditional { - -public: +class GTSAM_EXPORT DiscreteConditional + : public DecisionTreeFactor, + public Conditional { + public: // typedefs needed to play nice with gtsam - typedef DiscreteConditional This; ///< Typedef to this class - typedef boost::shared_ptr shared_ptr; ///< shared_ptr to this class - typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class - typedef Conditional BaseConditional; ///< Typedef to our conditional base class + typedef DiscreteConditional This; ///< Typedef to this class + typedef boost::shared_ptr shared_ptr; ///< shared_ptr to this class + typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class + typedef Conditional + BaseConditional; ///< Typedef to our conditional base class - /** A map from keys to values.. - * TODO: Again, do we need this??? */ - typedef Assignment Values; - typedef boost::shared_ptr sharedValues; + using Values = DiscreteValues; ///< backwards compatibility /// @name Standard Constructors /// @{ - /** default constructor needed for serialization */ - DiscreteConditional() { - } + /// Default constructor needed for serialization. + DiscreteConditional() {} - /** constructor from factor */ + /// Construct from factor, taking the first `nFrontals` keys as frontals. DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f); + /** + * Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first + * `nFrontals` keys as frontals, in the order given. + */ + DiscreteConditional(size_t nFrontals, const DiscreteKeys& keys, + const ADT& potentials); + /** Construct from signature */ - DiscreteConditional(const Signature& signature); - - /** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */ - DiscreteConditional(const DecisionTreeFactor& joint, - const DecisionTreeFactor& marginal); - - /** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */ - DiscreteConditional(const DecisionTreeFactor& joint, - const DecisionTreeFactor& marginal, const Ordering& orderedKeys); + explicit DiscreteConditional(const Signature& signature); /** - * Combine several conditional into a single one. - * The conditionals must be given in increasing order, meaning that the parents - * of any conditional may not include a conditional coming before it. - * @param firstConditional Iterator to the first conditional to combine, must dereference to a shared_ptr. - * @param lastConditional Iterator to after the last conditional to combine, must dereference to a shared_ptr. - * */ - template - static shared_ptr Combine(ITERATOR firstConditional, - ITERATOR lastConditional); + * Construct from key, parents, and a Signature::Table specifying the + * conditional probability table (CPT) in 00 01 10 11 order. For + * three-valued, it would be 00 01 02 10 11 12 20 21 22, etc.... + * + * Example: DiscreteConditional P(D, {B,E}, table); + */ + DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents, + const Signature::Table& table) + : DiscreteConditional(Signature(key, parents, table)) {} + + /** + * Construct from key, parents, and a string specifying the conditional + * probability table (CPT) in 00 01 10 11 order. For three-valued, it would + * be 00 01 02 10 11 12 20 21 22, etc.... + * + * The string is parsed into a Signature::Table. + * + * Example: DiscreteConditional P(D, {B,E}, "9/1 2/8 3/7 1/9"); + */ + DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents, + const std::string& spec) + : DiscreteConditional(Signature(key, parents, spec)) {} + + /// No-parent specialization; can also use DiscreteDistribution. + DiscreteConditional(const DiscreteKey& key, const std::string& spec) + : DiscreteConditional(Signature(key, {}, spec)) {} + + /** + * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) + * Assumes but *does not check* that f(Y)=sum_X f(X,Y). + */ + DiscreteConditional(const DecisionTreeFactor& joint, + const DecisionTreeFactor& marginal); + + /** + * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) + * Assumes but *does not check* that f(Y)=sum_X f(X,Y). + * Makes sure the keys are ordered as given. Does not check orderedKeys. + */ + DiscreteConditional(const DecisionTreeFactor& joint, + const DecisionTreeFactor& marginal, + const Ordering& orderedKeys); + + /** + * @brief Combine two conditionals, yielding a new conditional with the union + * of the frontal keys, ordered by gtsam::Key. + * + * The two conditionals must make a valid Bayes net fragment, i.e., + * the frontal variables cannot overlap, and must be acyclic: + * Example of correct use: + * P(A,B) = P(A|B) * P(B) + * P(A,B|C) = P(A|B) * P(B|C) + * P(A,B,C) = P(A,B|C) * P(C) + * Example of incorrect use: + * P(A|B) * P(A|C) = ? + * P(A|B) * P(B|A) = ? + * We check for overlapping frontals, but do *not* check for cyclic. + */ + DiscreteConditional operator*(const DiscreteConditional& other) const; + + /** Calculate marginal on given key, no parent case. */ + DiscreteConditional marginal(Key key) const; /// @} /// @name Testable /// @{ /// GTSAM-style print - void print(const std::string& s = "Discrete Conditional: ", + void print( + const std::string& s = "Discrete Conditional: ", const KeyFormatter& formatter = DefaultKeyFormatter) const override; /// GTSAM-style equals @@ -102,68 +153,95 @@ public: } /// Evaluate, just look up in AlgebraicDecisonTree - double operator()(const Values& values) const override { - return Potentials::operator()(values); + double operator()(const DiscreteValues& values) const override { + return ADT::operator()(values); } - /** Convert to a factor */ - DecisionTreeFactor::shared_ptr toFactor() const { - return DecisionTreeFactor::shared_ptr(new DecisionTreeFactor(*this)); - } - - /** Restrict to given parent values, returns AlgebraicDecisionDiagram */ - ADT choose(const Assignment& parentsValues) const; - /** - * solve a conditional - * @param parentsValues Known values of the parents - * @return MPE value of the child (1 frontal variable). + * @brief restrict to given *parent* values. + * + * Note: does not need be complete set. Examples: + * + * P(C|D,E) + . -> P(C|D,E) + * P(C|D,E) + E -> P(C|D) + * P(C|D,E) + D -> P(C|E) + * P(C|D,E) + D,E -> P(C) + * P(C|D,E) + C -> error! + * + * @return a shared_ptr to a new DiscreteConditional */ - size_t solve(const Values& parentsValues) const; + shared_ptr choose(const DiscreteValues& given) const; + + /** Convert to a likelihood factor by providing value before bar. */ + DecisionTreeFactor::shared_ptr likelihood( + const DiscreteValues& frontalValues) const; + + /** Single variable version of likelihood. */ + DecisionTreeFactor::shared_ptr likelihood(size_t frontal) const; /** * sample * @param parentsValues Known values of the parents * @return sample from conditional */ - size_t sample(const Values& parentsValues) const; + size_t sample(const DiscreteValues& parentsValues) const; + + /// Single parent version. + size_t sample(size_t parent_value) const; + + /// Zero parent version. + size_t sample() const; + + /** + * @brief Return assignment that maximizes distribution. + * @return Optimal assignment (1 frontal variable). + */ + size_t argmax() const; /// @} /// @name Advanced Interface /// @{ - /// solve a conditional, in place - void solveInPlace(Values& parentsValues) const; - /// sample in place, stores result in partial solution - void sampleInPlace(Values& parentsValues) const; + void sampleInPlace(DiscreteValues* parentsValues) const; + + /// Return all assignments for frontal variables. + std::vector frontalAssignments() const; + + /// Return all assignments for frontal *and* parent variables. + std::vector allAssignments() const; + + /// @} + /// @name Wrapper support + /// @{ + + /// Render as markdown table. + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override; + + /// Render as html table. + std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override; /// @} +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + /// @name Deprecated functionality + /// @{ + size_t GTSAM_DEPRECATED solve(const DiscreteValues& parentsValues) const; + void GTSAM_DEPRECATED solveInPlace(DiscreteValues* parentsValues) const; + /// @} +#endif + + protected: + /// Internal version of choose + DiscreteConditional::ADT choose(const DiscreteValues& given, + bool forceComplete) const; }; // DiscreteConditional // traits -template<> struct traits : public Testable {}; - -/* ************************************************************************* */ -template -DiscreteConditional::shared_ptr DiscreteConditional::Combine( - ITERATOR firstConditional, ITERATOR lastConditional) { - // TODO: check for being a clique - - // multiply all the potentials of the given conditionals - size_t nrFrontals = 0; - DecisionTreeFactor product; - for (ITERATOR it = firstConditional; it != lastConditional; - ++it, ++nrFrontals) { - DiscreteConditional::shared_ptr c = *it; - DecisionTreeFactor::shared_ptr factor = c->toFactor(); - product = (*factor) * product; - } - // and then create a new multi-frontal conditional - return boost::make_shared(nrFrontals, product); -} - -} // gtsam +template <> +struct traits : public Testable {}; +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteDistribution.cpp b/gtsam/discrete/DiscreteDistribution.cpp new file mode 100644 index 000000000..739771470 --- /dev/null +++ b/gtsam/discrete/DiscreteDistribution.cpp @@ -0,0 +1,52 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscreteDistribution.cpp + * @date December 2021 + * @author Frank Dellaert + */ + +#include + +#include + +namespace gtsam { + +void DiscreteDistribution::print(const std::string& s, + const KeyFormatter& formatter) const { + Base::print(s, formatter); +} + +double DiscreteDistribution::operator()(size_t value) const { + if (nrFrontals() != 1) + throw std::invalid_argument( + "Single value operator can only be invoked on single-variable " + "priors"); + DiscreteValues values; + values.emplace(keys_[0], value); + return Base::operator()(values); +} + +std::vector DiscreteDistribution::pmf() const { + if (nrFrontals() != 1) + throw std::invalid_argument( + "DiscreteDistribution::pmf only defined for single-variable priors"); + const size_t nrValues = cardinalities_.at(keys_[0]); + std::vector array; + array.reserve(nrValues); + for (size_t v = 0; v < nrValues; v++) { + array.push_back(operator()(v)); + } + return array; +} + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteDistribution.h b/gtsam/discrete/DiscreteDistribution.h new file mode 100644 index 000000000..c5147dbc1 --- /dev/null +++ b/gtsam/discrete/DiscreteDistribution.h @@ -0,0 +1,107 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscreteDistribution.h + * @date December 2021 + * @author Frank Dellaert + */ + +#pragma once + +#include + +#include +#include + +namespace gtsam { + +/** + * A prior probability on a set of discrete variables. + * Derives from DiscreteConditional + */ +class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional { + public: + using Base = DiscreteConditional; + + /// @name Standard Constructors + /// @{ + + /// Default constructor needed for serialization. + DiscreteDistribution() {} + + /// Constructor from factor. + explicit DiscreteDistribution(const DecisionTreeFactor& f) + : Base(f.size(), f) {} + + /** + * Construct from a Signature. + * + * Example: DiscreteDistribution P(D % "3/2"); + */ + explicit DiscreteDistribution(const Signature& s) : Base(s) {} + + /** + * Construct from key and a vector of floats specifying the probability mass + * function (PMF). + * + * Example: DiscreteDistribution P(D, {0.4, 0.6}); + */ + DiscreteDistribution(const DiscreteKey& key, const std::vector& spec) + : DiscreteDistribution(Signature(key, {}, Signature::Table{spec})) {} + + /** + * Construct from key and a string specifying the probability mass function + * (PMF). + * + * Example: DiscreteDistribution P(D, "9/1 2/8 3/7 1/9"); + */ + DiscreteDistribution(const DiscreteKey& key, const std::string& spec) + : DiscreteDistribution(Signature(key, {}, spec)) {} + + /// @} + /// @name Testable + /// @{ + + /// GTSAM-style print + void print( + const std::string& s = "Discrete Prior: ", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; + + /// @} + /// @name Standard interface + /// @{ + + /// Evaluate given a single value. + double operator()(size_t value) const; + + /// We also want to keep the Base version, taking DiscreteValues: + // TODO(dellaert): does not play well with wrapper! + // using Base::operator(); + + /// Return entire probability mass function. + std::vector pmf() const; + + /// @} +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + /// @name Deprecated functionality + /// @{ + size_t GTSAM_DEPRECATED solve() const { return Base::solve({}); } + /// @} +#endif +}; +// DiscreteDistribution + +// traits +template <> +struct traits : public Testable {}; + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteFactor.cpp b/gtsam/discrete/DiscreteFactor.cpp index c101653d2..08309e2e1 100644 --- a/gtsam/discrete/DiscreteFactor.cpp +++ b/gtsam/discrete/DiscreteFactor.cpp @@ -17,11 +17,59 @@ * @author Frank Dellaert */ +#include #include +#include +#include + using namespace std; namespace gtsam { /* ************************************************************************* */ -} // namespace gtsam +std::vector expNormalize(const std::vector& logProbs) { + double maxLogProb = -std::numeric_limits::infinity(); + for (size_t i = 0; i < logProbs.size(); i++) { + double logProb = logProbs[i]; + if ((logProb != std::numeric_limits::infinity()) && + logProb > maxLogProb) { + maxLogProb = logProb; + } + } + + // After computing the max = "Z" of the log probabilities L_i, we compute + // the log of the normalizing constant, log S, where S = sum_j exp(L_j - Z). + double total = 0.0; + for (size_t i = 0; i < logProbs.size(); i++) { + double probPrime = exp(logProbs[i] - maxLogProb); + total += probPrime; + } + double logTotal = log(total); + + // Now we compute the (normalized) probability (for each i): + // p_i = exp(L_i - Z - log S) + double checkNormalization = 0.0; + std::vector probs; + for (size_t i = 0; i < logProbs.size(); i++) { + double prob = exp(logProbs[i] - maxLogProb - logTotal); + probs.push_back(prob); + checkNormalization += prob; + } + + // Numerical tolerance for floating point comparisons + double tol = 1e-9; + + if (!gtsam::fpEqual(checkNormalization, 1.0, tol)) { + std::string errMsg = + std::string("expNormalize failed to normalize probabilities. ") + + std::string("Expected normalization constant = 1.0. Got value: ") + + std::to_string(checkNormalization) + + std::string( + "\n This could have resulted from numerical overflow/underflow."); + throw std::logic_error(errMsg); + } + return probs; +} + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 6b0919507..212ade8cf 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -18,10 +18,11 @@ #pragma once -#include +#include #include #include +#include namespace gtsam { class DecisionTreeFactor; @@ -40,18 +41,7 @@ public: typedef boost::shared_ptr shared_ptr; ///< shared_ptr to this class typedef Factor Base; ///< Our base class - /** A map from keys to values - * TODO: Do we need this? Should we just use gtsam::Values? - * We just need another special DiscreteValue to represent labels, - * However, all other Lie's operators are undefined in this class. - * The good thing is we can have a Hybrid graph of discrete/continuous variables - * together.. - * Another good thing is we don't need to have the special DiscreteKey which stores - * cardinality of a Discrete variable. It should be handled naturally in - * the new class DiscreteValue, as the varible's type (domain) - */ - typedef Assignment Values; - typedef boost::shared_ptr sharedValues; + using Values = DiscreteValues; ///< backwards compatibility public: @@ -84,27 +74,72 @@ public: Base::print(s, formatter); } - /** Test whether the factor is empty */ - virtual bool empty() const { return size() == 0; } - /// @} /// @name Standard Interface /// @{ /// Find value for given assignment of values to variables - virtual double operator()(const Values&) const = 0; + virtual double operator()(const DiscreteValues&) const = 0; /// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; + /// @} + /// @name Wrapper support + /// @{ + + /// Translation table from values to strings. + using Names = DiscreteValues::Names; + + /** + * @brief Render as markdown table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, category names corresponding to choices. + * @return std::string a markdown string. + */ + virtual std::string markdown( + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const = 0; + + /** + * @brief Render as html table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, category names corresponding to choices. + * @return std::string a html string. + */ + virtual std::string html( + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const = 0; + /// @} }; // DiscreteFactor // traits template<> struct traits : public Testable {}; -template<> struct traits : public Testable {}; + + +/** + * @brief Normalize a set of log probabilities. + * + * Normalizing a set of log probabilities in a numerically stable way is + * tricky. To avoid overflow/underflow issues, we compute the largest + * (finite) log probability and subtract it from each log probability before + * normalizing. This comes from the observation that if: + * p_i = exp(L_i) / ( sum_j exp(L_j) ), + * Then, + * p_i = exp(Z) exp(L_i - Z) / (exp(Z) sum_j exp(L_j - Z)), + * = exp(L_i - Z) / ( sum_j exp(L_j - Z) ) + * + * Setting Z = max_j L_j, we can avoid numerical issues that arise when all + * of the (unnormalized) log probabilities are either very large or very + * small. + */ +std::vector expNormalize(const std::vector &logProbs); + }// namespace gtsam diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index e41968d6b..ebcac445c 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -16,15 +16,18 @@ * @author Frank Dellaert */ -//#define ENABLE_TIMING -#include -#include #include +#include #include +#include #include -#include +#include #include -#include +#include + +using std::vector; +using std::string; +using std::map; namespace gtsam { @@ -41,11 +44,25 @@ namespace gtsam { /* ************************************************************************* */ KeySet DiscreteFactorGraph::keys() const { KeySet keys; - for(const sharedFactor& factor: *this) - if (factor) keys.insert(factor->begin(), factor->end()); + for (const sharedFactor& factor : *this) { + if (factor) keys.insert(factor->begin(), factor->end()); + } return keys; } + /* ************************************************************************* */ + DiscreteKeys DiscreteFactorGraph::discreteKeys() const { + DiscreteKeys result; + for (auto&& factor : *this) { + if (auto p = boost::dynamic_pointer_cast(factor)) { + DiscreteKeys factor_keys = p->discreteKeys(); + result.insert(result.end(), factor_keys.begin(), factor_keys.end()); + } + } + + return result; + } + /* ************************************************************************* */ DecisionTreeFactor DiscreteFactorGraph::product() const { DecisionTreeFactor result; @@ -56,7 +73,7 @@ namespace gtsam { /* ************************************************************************* */ double DiscreteFactorGraph::operator()( - const DiscreteFactor::Values &values) const { + const DiscreteValues &values) const { double product = 1.0; for( const sharedFactor& factor: factors_ ) product *= (*factor)(values); @@ -64,7 +81,7 @@ namespace gtsam { } /* ************************************************************************* */ - void DiscreteFactorGraph::print(const std::string& s, + void DiscreteFactorGraph::print(const string& s, const KeyFormatter& formatter) const { std::cout << s << std::endl; std::cout << "size: " << size() << std::endl; @@ -93,22 +110,99 @@ namespace gtsam { // } // } - /* ************************************************************************* */ - DiscreteFactor::sharedValues DiscreteFactorGraph::optimize() const - { - gttic(DiscreteFactorGraph_optimize); - return BaseEliminateable::eliminateSequential()->optimize(); - } - - /* ************************************************************************* */ + /* ************************************************************************ */ + // Alternate eliminate function for MPE std::pair // - EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - + EliminateForMPE(const DiscreteFactorGraph& factors, + const Ordering& frontalKeys) { // PRODUCT: multiply all factors gttic(product); DecisionTreeFactor product; - for(const DiscreteFactor::shared_ptr& factor: factors) - product = (*factor) * product; + for (auto&& factor : factors) product = (*factor) * product; + gttoc(product); + + // max out frontals, this is the factor on the separator + gttic(max); + DecisionTreeFactor::shared_ptr max = product.max(frontalKeys); + gttoc(max); + + // Ordering keys for the conditional so that frontalKeys are really in front + DiscreteKeys orderedKeys; + for (auto&& key : frontalKeys) + orderedKeys.emplace_back(key, product.cardinality(key)); + for (auto&& key : max->keys()) + orderedKeys.emplace_back(key, product.cardinality(key)); + + // Make lookup with product + gttic(lookup); + size_t nrFrontals = frontalKeys.size(); + auto lookup = boost::make_shared(nrFrontals, + orderedKeys, product); + gttoc(lookup); + + return std::make_pair( + boost::dynamic_pointer_cast(lookup), max); + } + + /* ************************************************************************ */ + // sumProduct is just an alias for regular eliminateSequential. + DiscreteBayesNet DiscreteFactorGraph::sumProduct( + OptionalOrderingType orderingType) const { + gttic(DiscreteFactorGraph_sumProduct); + auto bayesNet = eliminateSequential(orderingType); + return *bayesNet; + } + + DiscreteBayesNet DiscreteFactorGraph::sumProduct( + const Ordering& ordering) const { + gttic(DiscreteFactorGraph_sumProduct); + auto bayesNet = eliminateSequential(ordering); + return *bayesNet; + } + + /* ************************************************************************ */ + // The max-product solution below is a bit clunky: the elimination machinery + // does not allow for differently *typed* versions of elimination, so we + // eliminate into a Bayes Net using the special eliminate function above, and + // then create the DiscreteLookupDAG after the fact, in linear time. + + DiscreteLookupDAG DiscreteFactorGraph::maxProduct( + OptionalOrderingType orderingType) const { + gttic(DiscreteFactorGraph_maxProduct); + auto bayesNet = eliminateSequential(orderingType, EliminateForMPE); + return DiscreteLookupDAG::FromBayesNet(*bayesNet); + } + + DiscreteLookupDAG DiscreteFactorGraph::maxProduct( + const Ordering& ordering) const { + gttic(DiscreteFactorGraph_maxProduct); + auto bayesNet = eliminateSequential(ordering, EliminateForMPE); + return DiscreteLookupDAG::FromBayesNet(*bayesNet); + } + + /* ************************************************************************ */ + DiscreteValues DiscreteFactorGraph::optimize( + OptionalOrderingType orderingType) const { + gttic(DiscreteFactorGraph_optimize); + DiscreteLookupDAG dag = maxProduct(orderingType); + return dag.argmax(); + } + + DiscreteValues DiscreteFactorGraph::optimize( + const Ordering& ordering) const { + gttic(DiscreteFactorGraph_optimize); + DiscreteLookupDAG dag = maxProduct(ordering); + return dag.argmax(); + } + + /* ************************************************************************ */ + std::pair // + EliminateDiscrete(const DiscreteFactorGraph& factors, + const Ordering& frontalKeys) { + // PRODUCT: multiply all factors + gttic(product); + DecisionTreeFactor product; + for (auto&& factor : factors) product = (*factor) * product; gttoc(product); // sum out frontals, this is the factor on the separator @@ -118,17 +212,46 @@ namespace gtsam { // Ordering keys for the conditional so that frontalKeys are really in front Ordering orderedKeys; - orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end()); - orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end()); + orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), + frontalKeys.end()); + orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), + sum->keys().end()); // now divide product/sum to get conditional gttic(divide); - DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum, orderedKeys)); + auto conditional = + boost::make_shared(product, *sum, orderedKeys); gttoc(divide); - return std::make_pair(cond, sum); + return std::make_pair(conditional, sum); } -/* ************************************************************************* */ -} // namespace + /* ************************************************************************ */ + string DiscreteFactorGraph::markdown( + const KeyFormatter& keyFormatter, + const DiscreteFactor::Names& names) const { + using std::endl; + std::stringstream ss; + ss << "`DiscreteFactorGraph` of size " << size() << endl << endl; + for (size_t i = 0; i < factors_.size(); i++) { + ss << "factor " << i << ":\n"; + ss << factors_[i]->markdown(keyFormatter, names) << endl; + } + return ss.str(); + } + /* ************************************************************************ */ + string DiscreteFactorGraph::html(const KeyFormatter& keyFormatter, + const DiscreteFactor::Names& names) const { + using std::endl; + std::stringstream ss; + ss << "

DiscreteFactorGraph of size " << size() << "

"; + for (size_t i = 0; i < factors_.size(); i++) { + ss << "

factor " << i << ":

"; + ss << factors_[i]->html(keyFormatter, names) << endl; + } + return ss.str(); + } + + /* ************************************************************************ */ + } // namespace gtsam diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index f39adc9a8..f962b1802 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -18,19 +18,22 @@ #pragma once -#include -#include -#include #include -#include +#include +#include +#include +#include #include + #include +#include +#include +#include namespace gtsam { // Forward declarations class DiscreteFactorGraph; -class DiscreteFactor; class DiscreteConditional; class DiscreteBayesNet; class DiscreteEliminationTree; @@ -62,33 +65,35 @@ template<> struct EliminationTraits * A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e. * Factor == DiscreteFactor */ -class GTSAM_EXPORT DiscreteFactorGraph: public FactorGraph, -public EliminateableFactorGraph { -public: +class GTSAM_EXPORT DiscreteFactorGraph + : public FactorGraph, + public EliminateableFactorGraph { + public: + using This = DiscreteFactorGraph; ///< this class + using Base = FactorGraph; ///< base factor graph type + using BaseEliminateable = + EliminateableFactorGraph; ///< for elimination + using shared_ptr = boost::shared_ptr; ///< shared_ptr to This - typedef DiscreteFactorGraph This; ///< Typedef to this class - typedef FactorGraph Base; ///< Typedef to base factor graph type - typedef EliminateableFactorGraph BaseEliminateable; ///< Typedef to base elimination class - typedef boost::shared_ptr shared_ptr; ///< shared_ptr to this class + using Values = DiscreteValues; ///< backwards compatibility - /** A map from keys to values */ - typedef KeyVector Indices; - typedef Assignment Values; - typedef boost::shared_ptr sharedValues; + using Indices = KeyVector; ///> map from keys to values /** Default constructor */ DiscreteFactorGraph() {} /** Construct from iterator over factors */ - template - DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor) : Base(firstFactor, lastFactor) {} + template + DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor) + : Base(firstFactor, lastFactor) {} /** Construct from container of factors (shared_ptr or plain objects) */ - template + template explicit DiscreteFactorGraph(const CONTAINER& factors) : Base(factors) {} - /** Implicit copy/downcast constructor to override explicit template container constructor */ - template + /** Implicit copy/downcast constructor to override explicit template container + * constructor */ + template DiscreteFactorGraph(const FactorGraph& graph) : Base(graph) {} /// Destructor @@ -101,57 +106,111 @@ public: /// @} - template - void add(const DiscreteKey& j, SOURCE table) { - DiscreteKeys keys; - keys.push_back(j); - push_back(boost::make_shared(keys, table)); - } - - template - void add(const DiscreteKey& j1, const DiscreteKey& j2, SOURCE table) { - DiscreteKeys keys; - keys.push_back(j1); - keys.push_back(j2); - push_back(boost::make_shared(keys, table)); - } - - /** add shared discreteFactor immediately from arguments */ - template - void add(const DiscreteKeys& keys, SOURCE table) { - push_back(boost::make_shared(keys, table)); + /** Add a decision-tree factor */ + template + void add(Args&&... args) { + emplace_shared(std::forward(args)...); } /** Return the set of variables involved in the factors (set union) */ KeySet keys() const; + /// Return the DiscreteKeys in this factor graph. + DiscreteKeys discreteKeys() const; + /** return product of all factors as a single factor */ DecisionTreeFactor product() const; - /** Evaluates the factor graph given values, returns the joint probability of the factor graph given specific instantiation of values*/ - double operator()(const DiscreteFactor::Values & values) const; + /** + * Evaluates the factor graph given values, returns the joint probability of + * the factor graph given specific instantiation of values + */ + double operator()(const DiscreteValues& values) const; /// print void print( const std::string& s = "DiscreteFactorGraph", const KeyFormatter& formatter = DefaultKeyFormatter) const override; - /** Solve the factor graph by performing variable elimination in COLAMD order using - * the dense elimination function specified in \c function, - * followed by back-substitution resulting from elimination. Is equivalent - * to calling graph.eliminateSequential()->optimize(). */ - DiscreteFactor::sharedValues optimize() const; + /** + * @brief Implement the sum-product algorithm + * + * @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM + * @return DiscreteBayesNet encoding posterior P(X|Z) + */ + DiscreteBayesNet sumProduct( + OptionalOrderingType orderingType = boost::none) const; + /** + * @brief Implement the sum-product algorithm + * + * @param ordering + * @return DiscreteBayesNet encoding posterior P(X|Z) + */ + DiscreteBayesNet sumProduct(const Ordering& ordering) const; -// /** Permute the variables in the factors */ -// GTSAM_EXPORT void permuteWithInverse(const Permutation& inversePermutation); -// -// /** Apply a reduction, which is a remapping of variable indices. */ -// GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& inverseReduction); + /** + * @brief Implement the max-product algorithm + * + * @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM + * @return DiscreteLookupDAG DAG with lookup tables + */ + DiscreteLookupDAG maxProduct( + OptionalOrderingType orderingType = boost::none) const; -}; // \ DiscreteFactorGraph + /** + * @brief Implement the max-product algorithm + * + * @param ordering + * @return DiscreteLookupDAG `DAG with lookup tables + */ + DiscreteLookupDAG maxProduct(const Ordering& ordering) const; + + /** + * @brief Find the maximum probable explanation (MPE) by doing max-product. + * + * @param orderingType + * @return DiscreteValues : MPE + */ + DiscreteValues optimize( + OptionalOrderingType orderingType = boost::none) const; + + /** + * @brief Find the maximum probable explanation (MPE) by doing max-product. + * + * @param ordering + * @return DiscreteValues : MPE + */ + DiscreteValues optimize(const Ordering& ordering) const; + + /// @name Wrapper support + /// @{ + + /** + * @brief Render as markdown tables + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, a map from Key to category names. + * @return std::string a (potentially long) markdown string. + */ + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteFactor::Names& names = {}) const; + + /** + * @brief Render as html tables + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, a map from Key to category names. + * @return std::string a (potentially long) html string. + */ + std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteFactor::Names& names = {}) const; + + /// @} +}; // \ DiscreteFactorGraph /// traits -template<> struct traits : public Testable {}; +template <> +struct traits : public Testable {}; -} // \ namespace gtsam +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteJunctionTree.h b/gtsam/discrete/DiscreteJunctionTree.h index f5bc9be1d..c74ad3cc2 100644 --- a/gtsam/discrete/DiscreteJunctionTree.h +++ b/gtsam/discrete/DiscreteJunctionTree.h @@ -16,6 +16,8 @@ * @author Richard Roberts */ +#pragma once + #include #include #include diff --git a/gtsam/discrete/DiscreteKey.cpp b/gtsam/discrete/DiscreteKey.cpp index 5ddad22b0..121d61103 100644 --- a/gtsam/discrete/DiscreteKey.cpp +++ b/gtsam/discrete/DiscreteKey.cpp @@ -33,16 +33,13 @@ namespace gtsam { KeyVector DiscreteKeys::indices() const { KeyVector js; - for(const DiscreteKey& key: *this) - js.push_back(key.first); + for (const DiscreteKey& key : *this) js.push_back(key.first); return js; } - map DiscreteKeys::cardinalities() const { - map cs; - cs.insert(begin(),end()); -// for(const DiscreteKey& key: *this) -// cs.insert(key); + map DiscreteKeys::cardinalities() const { + map cs; + cs.insert(begin(), end()); return cs; } diff --git a/gtsam/discrete/DiscreteKey.h b/gtsam/discrete/DiscreteKey.h index c041c7e8e..dea00074d 100644 --- a/gtsam/discrete/DiscreteKey.h +++ b/gtsam/discrete/DiscreteKey.h @@ -28,21 +28,26 @@ namespace gtsam { /** - * Key type for discrete conditionals - * Includes name and cardinality + * Key type for discrete variables. + * Includes Key and cardinality. */ - typedef std::pair DiscreteKey; + using DiscreteKey = std::pair; /// DiscreteKeys is a set of keys that can be assembled using the & operator - struct DiscreteKeys: public std::vector { + struct GTSAM_EXPORT DiscreteKeys: public std::vector { - /// Default constructor - DiscreteKeys() { - } + // Forward all constructors. + using std::vector::vector; + + /// Constructor for serialization + DiscreteKeys() : std::vector::vector() {} /// Construct from a key - DiscreteKeys(const DiscreteKey& key) { - push_back(key); + explicit DiscreteKeys(const DiscreteKey& key) { push_back(key); } + + /// Construct from cardinalities. + explicit DiscreteKeys(std::map cardinalities) { + for (auto&& kv : cardinalities) emplace_back(kv); } /// Construct from a vector of keys @@ -51,13 +56,13 @@ namespace gtsam { } /// Construct from cardinalities with default names - GTSAM_EXPORT DiscreteKeys(const std::vector& cs); + DiscreteKeys(const std::vector& cs); /// Return a vector of indices - GTSAM_EXPORT KeyVector indices() const; + KeyVector indices() const; /// Return a map from index to cardinality - GTSAM_EXPORT std::map cardinalities() const; + std::map cardinalities() const; /// Add a key (non-const!) DiscreteKeys& operator&(const DiscreteKey& key) { diff --git a/gtsam/discrete/DiscreteLookupDAG.cpp b/gtsam/discrete/DiscreteLookupDAG.cpp new file mode 100644 index 000000000..d96b38b0e --- /dev/null +++ b/gtsam/discrete/DiscreteLookupDAG.cpp @@ -0,0 +1,127 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscreteLookupDAG.cpp + * @date Feb 14, 2011 + * @author Duy-Nguyen Ta + * @author Frank Dellaert + */ + +#include +#include +#include + +#include +#include + +using std::pair; +using std::vector; + +namespace gtsam { + +/* ************************************************************************** */ +// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-( +void DiscreteLookupTable::print(const std::string& s, + const KeyFormatter& formatter) const { + using std::cout; + using std::endl; + + cout << s << " g( "; + for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) { + cout << formatter(*it) << " "; + } + if (nrParents()) { + cout << "; "; + for (const_iterator it = beginParents(); it != endParents(); ++it) { + cout << formatter(*it) << " "; + } + } + cout << "):\n"; + ADT::print("", formatter); + cout << endl; +} + +/* ************************************************************************** */ +void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const { + ADT pFS = choose(*values, true); // P(F|S=parentsValues) + + // Initialize + DiscreteValues mpe; + double maxP = 0; + + // Get all Possible Configurations + const auto allPosbValues = frontalAssignments(); + + // Find the maximum + for (const auto& frontalVals : allPosbValues) { + double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues) + // Update maximum solution if better + if (pValueS > maxP) { + maxP = pValueS; + mpe = frontalVals; + } + } + + // set values (inPlace) to maximum + for (Key j : frontals()) { + (*values)[j] = mpe[j]; + } +} + +/* ************************************************************************** */ +size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const { + ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues) + + // Then, find the max over all remaining + // TODO(Duy): only works for one key now, seems horribly slow this way + size_t mpe = 0; + double maxP = 0; + DiscreteValues frontals; + assert(nrFrontals() == 1); + Key j = (firstFrontalKey()); + for (size_t value = 0; value < cardinality(j); value++) { + frontals[j] = value; + double pValueS = pFS(frontals); // P(F=value|S=parentsValues) + // Update MPE solution if better + if (pValueS > maxP) { + maxP = pValueS; + mpe = value; + } + } + return mpe; +} + +/* ************************************************************************** */ +DiscreteLookupDAG DiscreteLookupDAG::FromBayesNet( + const DiscreteBayesNet& bayesNet) { + DiscreteLookupDAG dag; + for (auto&& conditional : bayesNet) { + if (auto lookupTable = + boost::dynamic_pointer_cast(conditional)) { + dag.push_back(lookupTable); + } else { + throw std::runtime_error( + "DiscreteFactorGraph::maxProduct: Expected look up table."); + } + } + return dag; +} + +DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const { + // Argmax each node in turn in topological sort order (parents first). + for (auto lookupTable : boost::adaptors::reverse(*this)) + lookupTable->argmaxInPlace(&result); + return result; +} +/* ************************************************************************** */ + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteLookupDAG.h b/gtsam/discrete/DiscreteLookupDAG.h new file mode 100644 index 000000000..15169a1dc --- /dev/null +++ b/gtsam/discrete/DiscreteLookupDAG.h @@ -0,0 +1,140 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscreteLookupDAG.h + * @date January, 2022 + * @author Frank dellaert + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include + +namespace gtsam { + +class DiscreteBayesNet; + +/** + * @brief DiscreteLookupTable table for max-product + * + * Inherits from discrete conditional for convenience, but is not normalized. + * Is used in the max-product algorithm. + */ +class GTSAM_EXPORT DiscreteLookupTable : public DiscreteConditional { + public: + using This = DiscreteLookupTable; + using shared_ptr = boost::shared_ptr; + using BaseConditional = Conditional; + + /** + * @brief Construct a new Discrete Lookup Table object + * + * @param nFrontals number of frontal variables + * @param keys a sorted list of gtsam::Keys + * @param potentials the algebraic decision tree with lookup values + */ + DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys, + const ADT& potentials) + : DiscreteConditional(nFrontals, keys, potentials) {} + + /// GTSAM-style print + void print( + const std::string& s = "Discrete Lookup Table: ", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; + + /** + * @brief return assignment for single frontal variable that maximizes value. + * @param parentsValues Known assignments for the parents. + * @return maximizing assignment for the frontal variable. + */ + size_t argmax(const DiscreteValues& parentsValues) const; + + /** + * @brief Calculate assignment for frontal variables that maximizes value. + * @param (in/out) parentsValues Known assignments for the parents. + */ + void argmaxInPlace(DiscreteValues* parentsValues) const; +}; + +/** A DAG made from lookup tables, as defined above. */ +class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet { + public: + using Base = BayesNet; + using This = DiscreteLookupDAG; + using shared_ptr = boost::shared_ptr; + + /// @name Standard Constructors + /// @{ + + /// Construct empty DAG. + DiscreteLookupDAG() {} + + /// Create from BayesNet with LookupTables + static DiscreteLookupDAG FromBayesNet(const DiscreteBayesNet& bayesNet); + + /// Destructor + virtual ~DiscreteLookupDAG() {} + + /// @} + + /// @name Testable + /// @{ + + /** Check equality */ + bool equals(const This& bn, double tol = 1e-9) const; + + /// @} + + /// @name Standard Interface + /// @{ + + /** Add a DiscreteLookupTable */ + template + void add(Args&&... args) { + emplace_shared(std::forward(args)...); + } + + /** + * @brief argmax by back-substitution, optionally given certain variables. + * + * Assumes the DAG is reverse topologically sorted, i.e. last + * conditional will be optimized first *and* that the + * DAG does not contain any conditionals for the given variables. If the DAG + * resulted from eliminating a factor graph, this is true for the elimination + * ordering. + * + * @return given assignment extended w. optimal assignment for all variables. + */ + DiscreteValues argmax(DiscreteValues given = DiscreteValues()) const; + /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + } +}; + +// traits +template <> +struct traits : public Testable {}; + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteMarginals.h b/gtsam/discrete/DiscreteMarginals.h index c2a188e08..dc87f665e 100644 --- a/gtsam/discrete/DiscreteMarginals.h +++ b/gtsam/discrete/DiscreteMarginals.h @@ -29,7 +29,7 @@ namespace gtsam { /** * A class for computing marginals of variables in a DiscreteFactorGraph */ - class DiscreteMarginals { +class DiscreteMarginals { protected: @@ -37,6 +37,8 @@ namespace gtsam { public: + DiscreteMarginals() {} + /** Construct a marginals class. * @param graph The factor graph defining the full joint density on all variables. */ @@ -64,7 +66,7 @@ namespace gtsam { //Create result Vector vResult(key.second); for (size_t state = 0; state < key.second ; ++ state) { - DiscreteFactor::Values values; + DiscreteValues values; values[key.first] = state; vResult(state) = (*marginalFactor)(values); } diff --git a/gtsam/discrete/DiscreteValues.cpp b/gtsam/discrete/DiscreteValues.cpp new file mode 100644 index 000000000..5d0c8dd3d --- /dev/null +++ b/gtsam/discrete/DiscreteValues.cpp @@ -0,0 +1,97 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscreteValues.cpp + * @date January, 2022 + * @author Frank Dellaert + */ + +#include + +#include + +using std::cout; +using std::endl; +using std::string; +using std::stringstream; + +namespace gtsam { + +void DiscreteValues::print(const string& s, + const KeyFormatter& keyFormatter) const { + cout << s << ": "; + for (auto&& kv : *this) + cout << "(" << keyFormatter(kv.first) << ", " << kv.second << ")"; + cout << endl; +} + +string DiscreteValues::Translate(const Names& names, Key key, size_t index) { + if (names.empty()) { + stringstream ss; + ss << index; + return ss.str(); + } else { + return names.at(key)[index]; + } +} + +string DiscreteValues::markdown(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; + + // Print out header and separator with alignment hints. + ss << "|Variable|value|\n|:-:|:-:|\n"; + + // Print out all rows. + for (const auto& kv : *this) { + ss << "|" << keyFormatter(kv.first) << "|" + << Translate(names, kv.first, kv.second) << "|\n"; + } + + return ss.str(); +} + +string DiscreteValues::html(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; + + // Print out preamble. + ss << "
\n\n \n"; + + // Print out header row. + ss << " \n"; + + // Finish header and start body. + ss << " \n \n"; + + // Print out all rows. + for (const auto& kv : *this) { + ss << " "; + ss << ""; + ss << "\n"; + } + ss << " \n
Variablevalue
" << keyFormatter(kv.first) << "" + << Translate(names, kv.first, kv.second) << "
\n
"; + return ss.str(); +} + +string markdown(const DiscreteValues& values, const KeyFormatter& keyFormatter, + const DiscreteValues::Names& names) { + return values.markdown(keyFormatter, names); +} + +string html(const DiscreteValues& values, const KeyFormatter& keyFormatter, + const DiscreteValues::Names& names) { + return values.html(keyFormatter, names); +} + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteValues.h b/gtsam/discrete/DiscreteValues.h new file mode 100644 index 000000000..cb17bf833 --- /dev/null +++ b/gtsam/discrete/DiscreteValues.h @@ -0,0 +1,106 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscreteValues.h + * @date Dec 13, 2021 + * @author Frank Dellaert + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include + +namespace gtsam { + +/** A map from keys to values + * TODO(dellaert): Do we need this? Should we just use gtsam::DiscreteValues? + * We just need another special DiscreteValue to represent labels, + * However, all other Lie's operators are undefined in this class. + * The good thing is we can have a Hybrid graph of discrete/continuous variables + * together.. + * Another good thing is we don't need to have the special DiscreteKey which + * stores cardinality of a Discrete variable. It should be handled naturally in + * the new class DiscreteValue, as the variable's type (domain) + */ +class GTSAM_EXPORT DiscreteValues : public Assignment { + public: + using Base = Assignment; // base class + + using Assignment::Assignment; // all constructors + + // Define the implicit default constructor. + DiscreteValues() = default; + + // Construct from assignment. + explicit DiscreteValues(const Base& a) : Base(a) {} + + void print(const std::string& s = "", + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + + static std::vector CartesianProduct( + const DiscreteKeys& keys) { + return Base::CartesianProduct(keys); + } + + /// @name Wrapper support + /// @{ + + /// Translation table from values to strings. + using Names = std::map>; + + /// Translate an integer index value for given key to a string. + static std::string Translate(const Names& names, Key key, size_t index); + + /** + * @brief Output as a markdown table. + * + * @param keyFormatter function that formats keys. + * @param names translation table for values. + * @return string markdown output. + */ + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const; + + /** + * @brief Output as a html table. + * + * @param keyFormatter function that formats keys. + * @param names translation table for values. + * @return string html output. + */ + std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const; + + /// @} +}; + +/// Free version of markdown. +std::string markdown(const DiscreteValues& values, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteValues::Names& names = {}); + +/// Free version of html. +std::string html(const DiscreteValues& values, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteValues::Names& names = {}); + +// traits +template <> +struct traits : public Testable {}; + +} // namespace gtsam diff --git a/gtsam/discrete/Potentials.cpp b/gtsam/discrete/Potentials.cpp deleted file mode 100644 index 331a76c13..000000000 --- a/gtsam/discrete/Potentials.cpp +++ /dev/null @@ -1,100 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file Potentials.cpp - * @date March 24, 2011 - * @author Frank Dellaert - */ - -#include -#include - -#include - -#include - -using namespace std; - -namespace gtsam { - -// explicit instantiation -template class DecisionTree; -template class AlgebraicDecisionTree; - -/* ************************************************************************* */ -double Potentials::safe_div(const double& a, const double& b) { - // cout << boost::format("%g / %g = %g\n") % a % b % ((a == 0) ? 0 : (a / b)); - // The use for safe_div is when we divide the product factor by the sum - // factor. If the product or sum is zero, we accord zero probability to the - // event. - return (a == 0 || b == 0) ? 0 : (a / b); -} - -/* ******************************************************************************** - */ -Potentials::Potentials() : ADT(1.0) {} - -/* ******************************************************************************** - */ -Potentials::Potentials(const DiscreteKeys& keys, const ADT& decisionTree) - : ADT(decisionTree), cardinalities_(keys.cardinalities()) {} - -/* ************************************************************************* */ -bool Potentials::equals(const Potentials& other, double tol) const { - return ADT::equals(other, tol); -} - -/* ************************************************************************* */ -void Potentials::print(const string& s, const KeyFormatter& formatter) const { - cout << s << "\n Cardinalities: {"; - for (const std::pair& key : cardinalities_) - cout << formatter(key.first) << ":" << key.second << ", "; - cout << "}" << endl; - ADT::print(" "); -} -// -// /* ************************************************************************* */ -// template -// void Potentials::remapIndices(const P& remapping) { -// // Permute the _cardinalities (TODO: Inefficient Consider Improving) -// DiscreteKeys keys; -// map ordering; -// -// // Get the original keys from cardinalities_ -// for(const DiscreteKey& key: cardinalities_) -// keys & key; -// -// // Perform Permutation -// for(DiscreteKey& key: keys) { -// ordering[key.first] = remapping[key.first]; -// key.first = ordering[key.first]; -// } -// -// // Change *this -// AlgebraicDecisionTree permuted((*this), ordering); -// *this = permuted; -// cardinalities_ = keys.cardinalities(); -// } -// -// /* ************************************************************************* */ -// void Potentials::permuteWithInverse(const Permutation& inversePermutation) { -// remapIndices(inversePermutation); -// } -// -// /* ************************************************************************* */ -// void Potentials::reduceWithInverse(const internal::Reduction& inverseReduction) { -// remapIndices(inverseReduction); -// } - - /* ************************************************************************* */ - -} // namespace gtsam diff --git a/gtsam/discrete/Potentials.h b/gtsam/discrete/Potentials.h deleted file mode 100644 index 1078b4c61..000000000 --- a/gtsam/discrete/Potentials.h +++ /dev/null @@ -1,97 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file Potentials.h - * @date March 24, 2011 - * @author Frank Dellaert - */ - -#pragma once - -#include -#include -#include - -#include -#include - -namespace gtsam { - - /** - * A base class for both DiscreteFactor and DiscreteConditional - */ - class Potentials: public AlgebraicDecisionTree { - - public: - - typedef AlgebraicDecisionTree ADT; - - protected: - - /// Cardinality for each key, used in combine - std::map cardinalities_; - - /** Constructor from ColumnIndex, and ADT */ - Potentials(const ADT& potentials) : - ADT(potentials) { - } - - // Safe division for probabilities - GTSAM_EXPORT static double safe_div(const double& a, const double& b); - -// // Apply either a permutation or a reduction -// template -// void remapIndices(const P& remapping); - - public: - - /** Default constructor for I/O */ - GTSAM_EXPORT Potentials(); - - /** Constructor from Indices and ADT */ - GTSAM_EXPORT Potentials(const DiscreteKeys& keys, const ADT& decisionTree); - - /** Constructor from Indices and (string or doubles) */ - template - Potentials(const DiscreteKeys& keys, SOURCE table) : - ADT(keys, table), cardinalities_(keys.cardinalities()) { - } - - // Testable - GTSAM_EXPORT bool equals(const Potentials& other, double tol = 1e-9) const; - GTSAM_EXPORT void print(const std::string& s = "Potentials: ", - const KeyFormatter& formatter = DefaultKeyFormatter) const; - - size_t cardinality(Key j) const { return cardinalities_.at(j);} - -// /** -// * @brief Permutes the keys in Potentials -// * -// * This permutes the Indices and performs necessary re-ordering of ADD. -// * This is virtual so that derived types e.g. DecisionTreeFactor can -// * re-implement it. -// */ -// GTSAM_EXPORT virtual void permuteWithInverse(const Permutation& inversePermutation); -// -// /** -// * Apply a reduction, which is a remapping of variable indices. -// */ -// GTSAM_EXPORT virtual void reduceWithInverse(const internal::Reduction& inverseReduction); - - }; // Potentials - -// traits -template<> struct traits : public Testable {}; -template<> struct traits : public Testable {}; - - -} // namespace gtsam diff --git a/gtsam/discrete/Signature.cpp b/gtsam/discrete/Signature.cpp index 94b160a29..146555898 100644 --- a/gtsam/discrete/Signature.cpp +++ b/gtsam/discrete/Signature.cpp @@ -38,19 +38,7 @@ namespace gtsam { using boost::phoenix::push_back; // Special rows, true and false - Signature::Row createF() { - Signature::Row r(2); - r[0] = 1; - r[1] = 0; - return r; - } - Signature::Row createT() { - Signature::Row r(2); - r[0] = 0; - r[1] = 1; - return r; - } - Signature::Row T = createT(), F = createF(); + Signature::Row F{1, 0}, T{0, 1}; // Special tables (inefficient, but do we care for user input?) Signature::Table logic(bool ff, bool ft, bool tf, bool tt) { @@ -69,40 +57,13 @@ namespace gtsam { table = or_ | and_ | rows; or_ = qi::lit("OR")[qi::_val = logic(false, true, true, true)]; and_ = qi::lit("AND")[qi::_val = logic(false, false, false, true)]; - rows = +(row | true_ | false_); // only loads first of the rows under boost 1.42 + rows = +(row | true_ | false_); row = qi::double_ >> +("/" >> qi::double_); true_ = qi::lit("T")[qi::_val = T]; false_ = qi::lit("F")[qi::_val = F]; } } grammar; - // Create simpler parsing function to avoid the issue of only parsing a single row - bool parse_table(const string& spec, Signature::Table& table) { - // check for OR, AND on whole phrase - It f = spec.begin(), l = spec.end(); - if (qi::parse(f, l, - qi::lit("OR")[ph::ref(table) = logic(false, true, true, true)]) || - qi::parse(f, l, - qi::lit("AND")[ph::ref(table) = logic(false, false, false, true)])) - return true; - - // tokenize into separate rows - istringstream iss(spec); - string token; - while (iss >> token) { - Signature::Row values; - It tf = token.begin(), tl = token.end(); - bool r = qi::parse(tf, tl, - qi::double_[push_back(ph::ref(values), qi::_1)] >> +("/" >> qi::double_[push_back(ph::ref(values), qi::_1)]) | - qi::lit("T")[ph::ref(values) = T] | - qi::lit("F")[ph::ref(values) = F] ); - if (!r) - return false; - table.push_back(values); - } - - return true; - } } // \namespace parser ostream& operator <<(ostream &os, const Signature::Row &row) { @@ -118,6 +79,18 @@ namespace gtsam { return os; } + Signature::Signature(const DiscreteKey& key, const DiscreteKeys& parents, + const Table& table) + : key_(key), parents_(parents) { + operator=(table); + } + + Signature::Signature(const DiscreteKey& key, const DiscreteKeys& parents, + const std::string& spec) + : key_(key), parents_(parents) { + operator=(spec); + } + Signature::Signature(const DiscreteKey& key) : key_(key) { } @@ -166,14 +139,11 @@ namespace gtsam { Signature& Signature::operator=(const string& spec) { spec_.reset(spec); Table table; - // NOTE: using simpler parse function to ensure boost back compatibility -// parser::It f = spec.begin(), l = spec.end(); - bool success = // -// qi::phrase_parse(f, l, parser::grammar.table, qi::space, table); // using full grammar - parser::parse_table(spec, table); + parser::It f = spec.begin(), l = spec.end(); + bool success = + qi::phrase_parse(f, l, parser::grammar.table, qi::space, table); if (success) { - for(Row& row: table) - normalize(row); + for (Row& row : table) normalize(row); table_.reset(table); } return *this; diff --git a/gtsam/discrete/Signature.h b/gtsam/discrete/Signature.h index 6c59b5bff..ff83caa53 100644 --- a/gtsam/discrete/Signature.h +++ b/gtsam/discrete/Signature.h @@ -30,7 +30,7 @@ namespace gtsam { * The format is (Key % string) for nodes with no parents, * and (Key | Key, Key = string) for nodes with parents. * - * The string specifies a conditional probability spec in the 00 01 10 11 order. + * The string specifies a conditional probability table in 00 01 10 11 order. * For three-valued, it would be 00 01 02 10 11 12 20 21 22, etc... * * For example, given the following keys @@ -45,9 +45,9 @@ namespace gtsam { * T|A = "99/1 95/5" * L|S = "99/1 90/10" * B|S = "70/30 40/60" - * E|T,L = "F F F 1" + * (E|T,L) = "F F F 1" * X|E = "95/5 2/98" - * D|E,B = "9/1 2/8 3/7 1/9" + * (D|E,B) = "9/1 2/8 3/7 1/9" */ class GTSAM_EXPORT Signature { @@ -72,45 +72,73 @@ namespace gtsam { boost::optional table_; public: + /** + * Construct from key, parents, and a Signature::Table specifying the + * conditional probability table (CPT) in 00 01 10 11 order. For + * three-valued, it would be 00 01 02 10 11 12 20 21 22, etc.... + * + * The first string is parsed to add a key and parents. + * + * Example: + * Signature::Table table{{0.9, 0.1}, {0.2, 0.8}, {0.3, 0.7}, {0.1, 0.9}}; + * Signature sig(D, {E, B}, table); + */ + Signature(const DiscreteKey& key, const DiscreteKeys& parents, + const Table& table); - /** Constructor from DiscreteKey */ - Signature(const DiscreteKey& key); + /** + * Construct from key, parents, and a string specifying the conditional + * probability table (CPT) in 00 01 10 11 order. For three-valued, it would + * be 00 01 02 10 11 12 20 21 22, etc.... + * + * The first string is parsed to add a key and parents. The second string + * parses into a table. + * + * Example (same CPT as above): + * Signature sig(D, {B,E}, "9/1 2/8 3/7 1/9"); + */ + Signature(const DiscreteKey& key, const DiscreteKeys& parents, + const std::string& spec); - /** the variable key */ - const DiscreteKey& key() const { - return key_; - } + /** + * Construct from a single DiscreteKey. + * + * The resulting signature has no parents or CPT table. Typical use then + * either adds parents with | and , operators below, or assigns a table with + * operator=(). + */ + Signature(const DiscreteKey& key); - /** the parent keys */ - const DiscreteKeys& parents() const { - return parents_; - } + /** the variable key */ + const DiscreteKey& key() const { return key_; } - /** All keys, with variable key first */ - DiscreteKeys discreteKeys() const; + /** the parent keys */ + const DiscreteKeys& parents() const { return parents_; } - /** All key indices, with variable key first */ - KeyVector indices() const; + /** All keys, with variable key first */ + DiscreteKeys discreteKeys() const; - // the CPT as parsed, if successful - const boost::optional
& table() const { - return table_; - } + /** All key indices, with variable key first */ + KeyVector indices() const; - // the CPT as a vector of doubles, with key's values most rapidly changing - std::vector cpt() const; + // the CPT as parsed, if successful + const boost::optional
& table() const { return table_; } - /** Add a parent */ - Signature& operator,(const DiscreteKey& parent); + // the CPT as a vector of doubles, with key's values most rapidly changing + std::vector cpt() const; - /** Add the CPT spec - Fails in boost 1.40 */ - Signature& operator=(const std::string& spec); + /** Add a parent */ + Signature& operator,(const DiscreteKey& parent); - /** Add the CPT spec directly as a table */ - Signature& operator=(const Table& table); + /** Add the CPT spec */ + Signature& operator=(const std::string& spec); - /** provide streaming */ - GTSAM_EXPORT friend std::ostream& operator <<(std::ostream &os, const Signature &s); + /** Add the CPT spec directly as a table */ + Signature& operator=(const Table& table); + + /** provide streaming */ + GTSAM_EXPORT friend std::ostream& operator<<(std::ostream& os, + const Signature& s); }; /** @@ -122,7 +150,6 @@ namespace gtsam { /** * Helper function to create Signature objects * example: Signature s(D % "99/1"); - * Uses string parser, which requires BOOST 1.42 or higher */ GTSAM_EXPORT Signature operator%(const DiscreteKey& key, const std::string& parent); diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i new file mode 100644 index 000000000..b3a12a8d5 --- /dev/null +++ b/gtsam/discrete/discrete.i @@ -0,0 +1,302 @@ +//************************************************************************* +// discrete +//************************************************************************* + +namespace gtsam { + + +#include +class DiscreteKey {}; + +class DiscreteKeys { + DiscreteKeys(); + size_t size() const; + bool empty() const; + gtsam::DiscreteKey at(size_t n) const; + void push_back(const gtsam::DiscreteKey& point_pair); +}; + +// DiscreteValues is added in specializations/discrete.h as a std::map +string markdown( + const gtsam::DiscreteValues& values, + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter); +string markdown(const gtsam::DiscreteValues& values, + const gtsam::KeyFormatter& keyFormatter, + std::map> names); +string html( + const gtsam::DiscreteValues& values, + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter); +string html(const gtsam::DiscreteValues& values, + const gtsam::KeyFormatter& keyFormatter, + std::map> names); + +#include +class DiscreteFactor { + void print(string s = "DiscreteFactor\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const; + bool empty() const; + size_t size() const; + double operator()(const gtsam::DiscreteValues& values) const; +}; + +#include +virtual class DecisionTreeFactor : gtsam::DiscreteFactor { + DecisionTreeFactor(); + + DecisionTreeFactor(const gtsam::DiscreteKey& key, + const std::vector& spec); + DecisionTreeFactor(const gtsam::DiscreteKey& key, string table); + + DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table); + DecisionTreeFactor(const std::vector& keys, string table); + + DecisionTreeFactor(const gtsam::DiscreteConditional& c); + + void print(string s = "DecisionTreeFactor\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; + + double operator()(const gtsam::DiscreteValues& values) const; + gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const; + size_t cardinality(gtsam::Key j) const; + gtsam::DecisionTreeFactor operator/(const gtsam::DecisionTreeFactor& f) const; + gtsam::DecisionTreeFactor* sum(size_t nrFrontals) const; + gtsam::DecisionTreeFactor* sum(const gtsam::Ordering& keys) const; + gtsam::DecisionTreeFactor* max(size_t nrFrontals) const; + + string dot( + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + bool showZero = true) const; + std::vector> enumerate() const; + string markdown(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + string markdown(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; + string html(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + string html(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; +}; + +#include +virtual class DiscreteConditional : gtsam::DecisionTreeFactor { + DiscreteConditional(); + DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f); + DiscreteConditional(const gtsam::DiscreteKey& key, string spec); + DiscreteConditional(const gtsam::DiscreteKey& key, + const gtsam::DiscreteKeys& parents, string spec); + DiscreteConditional(const gtsam::DiscreteKey& key, + const std::vector& parents, string spec); + DiscreteConditional(const gtsam::DecisionTreeFactor& joint, + const gtsam::DecisionTreeFactor& marginal); + DiscreteConditional(const gtsam::DecisionTreeFactor& joint, + const gtsam::DecisionTreeFactor& marginal, + const gtsam::Ordering& orderedKeys); + gtsam::DiscreteConditional operator*( + const gtsam::DiscreteConditional& other) const; + gtsam::DiscreteConditional marginal(gtsam::Key key) const; + void print(string s = "Discrete Conditional\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const; + gtsam::Key firstFrontalKey() const; + size_t nrFrontals() const; + size_t nrParents() const; + void printSignature( + string s = "Discrete Conditional: ", + const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; + gtsam::DecisionTreeFactor* choose(const gtsam::DiscreteValues& given) const; + gtsam::DecisionTreeFactor* likelihood( + const gtsam::DiscreteValues& frontalValues) const; + gtsam::DecisionTreeFactor* likelihood(size_t value) const; + size_t sample(const gtsam::DiscreteValues& parentsValues) const; + size_t sample(size_t value) const; + size_t sample() const; + void sampleInPlace(gtsam::DiscreteValues @parentsValues) const; + string markdown(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + string markdown(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; + string html(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + string html(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; +}; + +#include +virtual class DiscreteDistribution : gtsam::DiscreteConditional { + DiscreteDistribution(); + DiscreteDistribution(const gtsam::DecisionTreeFactor& f); + DiscreteDistribution(const gtsam::DiscreteKey& key, string spec); + DiscreteDistribution(const gtsam::DiscreteKey& key, std::vector spec); + void print(string s = "Discrete Prior\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + double operator()(size_t value) const; + std::vector pmf() const; + size_t argmax() const; +}; + +#include +class DiscreteBayesNet { + DiscreteBayesNet(); + void add(const gtsam::DiscreteConditional& s); + void add(const gtsam::DiscreteKey& key, string spec); + void add(const gtsam::DiscreteKey& key, const gtsam::DiscreteKeys& parents, + string spec); + void add(const gtsam::DiscreteKey& key, + const std::vector& parents, string spec); + bool empty() const; + size_t size() const; + gtsam::KeySet keys() const; + const gtsam::DiscreteConditional* at(size_t i) const; + void print(string s = "DiscreteBayesNet\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const; + double operator()(const gtsam::DiscreteValues& values) const; + gtsam::DiscreteValues sample() const; + gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const; + + string dot( + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; + void saveGraph( + string s, + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; + string markdown(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + string markdown(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; + string html(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + string html(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; +}; + +#include +class DiscreteBayesTreeClique { + DiscreteBayesTreeClique(); + DiscreteBayesTreeClique(const gtsam::DiscreteConditional* conditional); + const gtsam::DiscreteConditional* conditional() const; + bool isRoot() const; + void printSignature( + const string& s = "Clique: ", + const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; + double evaluate(const gtsam::DiscreteValues& values) const; +}; + +class DiscreteBayesTree { + DiscreteBayesTree(); + void print(string s = "DiscreteBayesTree\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::DiscreteBayesTree& other, double tol = 1e-9) const; + + size_t size() const; + bool empty() const; + const DiscreteBayesTreeClique* operator[](size_t j) const; + + string dot(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + void saveGraph(string s, + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + double operator()(const gtsam::DiscreteValues& values) const; + + string markdown(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + string markdown(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; + string html(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + string html(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; +}; + +#include +class DiscreteLookupDAG { + DiscreteLookupDAG(); + void push_back(const gtsam::DiscreteLookupTable* table); + bool empty() const; + size_t size() const; + gtsam::KeySet keys() const; + const gtsam::DiscreteLookupTable* at(size_t i) const; + void print(string s = "DiscreteLookupDAG\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + gtsam::DiscreteValues argmax() const; + gtsam::DiscreteValues argmax(gtsam::DiscreteValues given) const; +}; + +#include +class DiscreteFactorGraph { + DiscreteFactorGraph(); + DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet); + + // Building the graph + void push_back(const gtsam::DiscreteFactor* factor); + void push_back(const gtsam::DiscreteConditional* conditional); + void push_back(const gtsam::DiscreteFactorGraph& graph); + void push_back(const gtsam::DiscreteBayesNet& bayesNet); + void push_back(const gtsam::DiscreteBayesTree& bayesTree); + void add(const gtsam::DiscreteKey& j, string spec); + void add(const gtsam::DiscreteKey& j, const std::vector& spec); + void add(const gtsam::DiscreteKeys& keys, string spec); + void add(const std::vector& keys, string spec); + + bool empty() const; + size_t size() const; + gtsam::KeySet keys() const; + const gtsam::DiscreteFactor* at(size_t i) const; + + void print(string s = "") const; + bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const; + + gtsam::DecisionTreeFactor product() const; + double operator()(const gtsam::DiscreteValues& values) const; + gtsam::DiscreteValues optimize() const; + + gtsam::DiscreteBayesNet sumProduct(); + gtsam::DiscreteBayesNet sumProduct(gtsam::Ordering::OrderingType type); + gtsam::DiscreteBayesNet sumProduct(const gtsam::Ordering& ordering); + + gtsam::DiscreteLookupDAG maxProduct(); + gtsam::DiscreteLookupDAG maxProduct(gtsam::Ordering::OrderingType type); + gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering); + + gtsam::DiscreteBayesNet* eliminateSequential(); + gtsam::DiscreteBayesNet* eliminateSequential(gtsam::Ordering::OrderingType type); + gtsam::DiscreteBayesNet* eliminateSequential(const gtsam::Ordering& ordering); + pair + eliminatePartialSequential(const gtsam::Ordering& ordering); + + gtsam::DiscreteBayesTree* eliminateMultifrontal(); + gtsam::DiscreteBayesTree* eliminateMultifrontal(gtsam::Ordering::OrderingType type); + gtsam::DiscreteBayesTree* eliminateMultifrontal(const gtsam::Ordering& ordering); + pair + eliminatePartialMultifrontal(const gtsam::Ordering& ordering); + + string dot( + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; + void saveGraph( + string s, + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; + + string markdown(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + string markdown(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; + string html(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + string html(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; +}; + +} // namespace gtsam diff --git a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp index be720dbca..6a3fb2388 100644 --- a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp +++ b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp @@ -17,37 +17,39 @@ */ #include -#include // make sure we have traits +#include // make sure we have traits +#include // headers first to make sure no missing headers -//#define DT_NO_PRUNING +//#define GTSAM_DT_NO_PRUNING #include -#include // for convert only +#include // for convert only #define DISABLE_TIMING -#include #include #include +#include using namespace boost::assign; #include -#include #include +#include using namespace std; using namespace gtsam; -/* ******************************************************************************** */ +/* ************************************************************************** */ typedef AlgebraicDecisionTree ADT; // traits namespace gtsam { -template<> struct traits : public Testable {}; -} +template <> +struct traits : public Testable {}; +} // namespace gtsam #define DISABLE_DOT -template -void dot(const T&f, const string& filename) { +template +void dot(const T& f, const string& filename) { #ifndef DISABLE_DOT f.dot(filename); #endif @@ -62,8 +64,8 @@ void dot(const T&f, const string& filename) { // If second argument of binary op is Leaf template - typename DecisionTree::Node::Ptr DecisionTree::Choice::apply_fC_op_gL( - Cache& cache, const Leaf& gL, Mul op) const { + typename DecisionTree::Node::Ptr DecisionTree::Choice::apply_fC_op_gL( Cache& cache, const Leaf& gL, Mul op) const { Ptr h(new Choice(label(), cardinality())); for(const NodePtr& branch: branches_) h->push_back(branch->apply_f_op_g(cache, gL, op)); @@ -71,9 +73,9 @@ void dot(const T&f, const string& filename) { } */ -/* ******************************************************************************** */ +/* ************************************************************************** */ // instrumented operators -/* ******************************************************************************** */ +/* ************************************************************************** */ size_t muls = 0, adds = 0; double elapsed; void resetCounts() { @@ -82,8 +84,9 @@ void resetCounts() { } void printCounts(const string& s) { #ifndef DISABLE_TIMING - cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds - % (1000 * elapsed) << endl; + cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds % + (1000 * elapsed) + << endl; #endif resetCounts(); } @@ -96,12 +99,11 @@ double add_(const double& a, const double& b) { return a + b; } -/* ******************************************************************************** */ +/* ************************************************************************** */ // test ADT -TEST(ADT, example3) -{ +TEST(ADT, example3) { // Create labels - DiscreteKey A(0,2), B(1,2), C(2,2), D(3,2), E(4,2); + DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(3, 2), E(4, 2); // Literals ADT a(A, 0.5, 0.5); @@ -113,38 +115,37 @@ TEST(ADT, example3) ADT cnotb = c * notb; dot(cnotb, "ADT-cnotb"); -// a.print("a: "); -// cnotb.print("cnotb: "); + // a.print("a: "); + // cnotb.print("cnotb: "); ADT acnotb = a * cnotb; -// acnotb.print("acnotb: "); -// acnotb.printCache("acnotb Cache:"); + // acnotb.print("acnotb: "); + // acnotb.printCache("acnotb Cache:"); dot(acnotb, "ADT-acnotb"); - ADT big = apply(apply(d, note, &mul), acnotb, &add_); dot(big, "ADT-big"); } -/* ******************************************************************************** */ +/* ************************************************************************** */ // Asia Bayes Network -/* ******************************************************************************** */ +/* ************************************************************************** */ /** Convert Signature into CPT */ ADT create(const Signature& signature) { ADT p(signature.discreteKeys(), signature.cpt()); static size_t count = 0; const DiscreteKey& key = signature.key(); - string dotfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str(); - dot(p, dotfile); + string DOTfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str(); + dot(p, DOTfile); return p; } /* ************************************************************************* */ // test Asia Joint -TEST(ADT, joint) -{ - DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), D(7, 2); +TEST(ADT, joint) { + DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), + D(7, 2); resetCounts(); gttic_(asiaCPTs); @@ -203,10 +204,9 @@ TEST(ADT, joint) /* ************************************************************************* */ // test Inference with joint -TEST(ADT, inference) -{ - DiscreteKey A(0,2), D(1,2),// - B(2,2), L(3,2), E(4,2), S(5,2), T(6,2), X(7,2); +TEST(ADT, inference) { + DiscreteKey A(0, 2), D(1, 2), // + B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2); resetCounts(); gttic_(infCPTs); @@ -243,7 +243,7 @@ TEST(ADT, inference) dot(joint, "Joint-Product-ASTLBEX"); joint = apply(joint, pD, &mul); dot(joint, "Joint-Product-ASTLBEXD"); - EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering + EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering gttoc_(asiaProd); tictoc_getNode(asiaProdNode, asiaProd); elapsed = asiaProdNode->secs() + asiaProdNode->wall(); @@ -270,9 +270,8 @@ TEST(ADT, inference) } /* ************************************************************************* */ -TEST(ADT, factor_graph) -{ - DiscreteKey B(0,2), L(1,2), E(2,2), S(3,2), T(4,2), X(5,2); +TEST(ADT, factor_graph) { + DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2); resetCounts(); gttic_(createCPTs); @@ -319,7 +318,7 @@ TEST(ADT, factor_graph) dot(fg, "Marginalized-3E"); fg = fg.combine(L, &add_); dot(fg, "Marginalized-2L"); - EXPECT(adds = 54); + LONGS_EQUAL(49, adds); gttoc_(marg); tictoc_getNode(margNode, marg); elapsed = margNode->secs() + margNode->wall(); @@ -402,50 +401,49 @@ TEST(ADT, factor_graph) /* ************************************************************************* */ // test equality -TEST(ADT, equality_noparser) -{ - DiscreteKey A(0,2), B(1,2); +TEST(ADT, equality_noparser) { + DiscreteKey A(0, 2), B(1, 2); Signature::Table tableA, tableB; Signature::Row rA, rB; - rA += 80, 20; rB += 60, 40; - tableA += rA; tableB += rB; + rA += 80, 20; + rB += 60, 40; + tableA += rA; + tableB += rB; // Check straight equality ADT pA1 = create(A % tableA); ADT pA2 = create(A % tableA); - EXPECT(pA1 == pA2); // should be equal + EXPECT(pA1.equals(pA2)); // should be equal // Check equality after apply ADT pB = create(B % tableB); ADT pAB1 = apply(pA1, pB, &mul); ADT pAB2 = apply(pB, pA1, &mul); - EXPECT(pAB2 == pAB1); + EXPECT(pAB2.equals(pAB1)); } /* ************************************************************************* */ // test equality -TEST(ADT, equality_parser) -{ - DiscreteKey A(0,2), B(1,2); +TEST(ADT, equality_parser) { + DiscreteKey A(0, 2), B(1, 2); // Check straight equality ADT pA1 = create(A % "80/20"); ADT pA2 = create(A % "80/20"); - EXPECT(pA1 == pA2); // should be equal + EXPECT(pA1.equals(pA2)); // should be equal // Check equality after apply ADT pB = create(B % "60/40"); ADT pAB1 = apply(pA1, pB, &mul); ADT pAB2 = apply(pB, pA1, &mul); - EXPECT(pAB2 == pAB1); + EXPECT(pAB2.equals(pAB1)); } -/* ******************************************************************************** */ +/* ************************************************************************** */ // Factor graph construction // test constructor from strings -TEST(ADT, constructor) -{ - DiscreteKey v0(0,2), v1(1,3); - Assignment x00, x01, x02, x10, x11, x12; +TEST(ADT, constructor) { + DiscreteKey v0(0, 2), v1(1, 3); + DiscreteValues x00, x01, x02, x10, x11, x12; x00[0] = 0, x00[1] = 0; x01[0] = 0, x01[1] = 1; x02[0] = 0, x02[1] = 2; @@ -469,13 +467,12 @@ TEST(ADT, constructor) EXPECT_DOUBLES_EQUAL(3, f2(x11), 1e-9); EXPECT_DOUBLES_EQUAL(5, f2(x12), 1e-9); - DiscreteKey z0(0,5), z1(1,4), z2(2,3), z3(3,2); + DiscreteKey z0(0, 5), z1(1, 4), z2(2, 3), z3(3, 2); vector table(5 * 4 * 3 * 2); double x = 0; - for(double& t: table) - t = x++; + for (double& t : table) t = x++; ADT f3(z0 & z1 & z2 & z3, table); - Assignment assignment; + DiscreteValues assignment; assignment[0] = 0; assignment[1] = 0; assignment[2] = 0; @@ -486,9 +483,8 @@ TEST(ADT, constructor) /* ************************************************************************* */ // test conversion to integer indices // Only works if DiscreteKeys are binary, as size_t has binary cardinality! -TEST(ADT, conversion) -{ - DiscreteKey X(0,2), Y(1,2); +TEST(ADT, conversion) { + DiscreteKey X(0, 2), Y(1, 2); ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6"); dot(fDiscreteKey, "conversion-f1"); @@ -501,7 +497,7 @@ TEST(ADT, conversion) // f2.print("f2"); dot(fIndexKey, "conversion-f2"); - Assignment x00, x01, x02, x10, x11, x12; + DiscreteValues x00, x01, x02, x10, x11, x12; x00[5] = 0, x00[2] = 0; x01[5] = 0, x01[2] = 1; x10[5] = 1, x10[2] = 0; @@ -512,11 +508,10 @@ TEST(ADT, conversion) EXPECT_DOUBLES_EQUAL(0.6, fIndexKey(x11), 1e-9); } -/* ******************************************************************************** */ +/* ************************************************************************** */ // test operations in elimination -TEST(ADT, elimination) -{ - DiscreteKey A(0,2), B(1,3), C(2,2); +TEST(ADT, elimination) { + DiscreteKey A(0, 2), B(1, 3), C(2, 2); ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5"); dot(f1, "elimination-f1"); @@ -524,60 +519,58 @@ TEST(ADT, elimination) // sum out lower key ADT actualSum = f1.sum(C); ADT expectedSum(A & B, "3 7 11 9 6 10"); - CHECK(assert_equal(expectedSum,actualSum)); + CHECK(assert_equal(expectedSum, actualSum)); // normalize ADT actual = f1 / actualSum; vector cpt; - cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, // - 1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10; + cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, // + 1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10; ADT expected(A & B & C, cpt); - CHECK(assert_equal(expected,actual)); + CHECK(assert_equal(expected, actual)); } { // sum out lower 2 keys ADT actualSum = f1.sum(C).sum(B); ADT expectedSum(A, 21, 25); - CHECK(assert_equal(expectedSum,actualSum)); + CHECK(assert_equal(expectedSum, actualSum)); // normalize ADT actual = f1 / actualSum; vector cpt; - cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, // - 1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25; + cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, // + 1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25; ADT expected(A & B & C, cpt); - CHECK(assert_equal(expected,actual)); + CHECK(assert_equal(expected, actual)); } } -/* ******************************************************************************** */ +/* ************************************************************************** */ // Test non-commutative op -TEST(ADT, div) -{ - DiscreteKey A(0,2), B(1,2); +TEST(ADT, div) { + DiscreteKey A(0, 2), B(1, 2); // Literals ADT a(A, 8, 16); ADT b(B, 2, 4); - ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4 - ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16 + ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4 + ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16 EXPECT(assert_equal(expected_a_div_b, a / b)); EXPECT(assert_equal(expected_b_div_a, b / a)); } -/* ******************************************************************************** */ +/* ************************************************************************** */ // test zero shortcut -TEST(ADT, zero) -{ - DiscreteKey A(0,2), B(1,2); +TEST(ADT, zero) { + DiscreteKey A(0, 2), B(1, 2); // Literals ADT a(A, 0, 1); ADT notb(B, 1, 0); ADT anotb = a * notb; // GTSAM_PRINT(anotb); - Assignment x00, x01, x10, x11; + DiscreteValues x00, x01, x10, x11; x00[0] = 0, x00[1] = 0; x01[0] = 0, x01[1] = 1; x10[0] = 1, x10[1] = 0; diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index 96f503abc..5ccbcf916 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -17,67 +17,108 @@ * @date Jan 30, 2012 */ -#include -using namespace boost::assign; +// #define DT_DEBUG_MEMORY +// #define GTSAM_DT_NO_PRUNING +#define DISABLE_DOT +#include -#include #include #include -//#define DT_DEBUG_MEMORY -//#define DT_NO_PRUNING -#define DISABLE_DOT -#include +#include + +#include +using namespace boost::assign; + using namespace std; using namespace gtsam; -template -void dot(const T&f, const string& filename) { +template +void dot(const T& f, const string& filename) { #ifndef DISABLE_DOT f.dot(filename); #endif } -#define DOT(x)(dot(x,#x)) +#define DOT(x) (dot(x, #x)) -struct Crazy { int a; double b; }; -typedef DecisionTree CrazyDecisionTree; // check that DecisionTree is actually generic (as it pretends to be) +struct Crazy { + int a; + double b; +}; -// traits -namespace gtsam { -template<> struct traits : public Testable {}; -} - -/* ******************************************************************************** */ -// Test string labels and int range -/* ******************************************************************************** */ - -typedef DecisionTree DT; - -// traits -namespace gtsam { -template<> struct traits
: public Testable
{}; -} - -struct Ring { - static inline int zero() { - return 0; +struct CrazyDecisionTree : public DecisionTree { + /// print to stdout + void print(const std::string& s = "") const { + auto keyFormatter = [](const std::string& s) { return s; }; + auto valueFormatter = [](const Crazy& v) { + return (boost::format("{%d,%4.2g}") % v.a % v.b).str(); + }; + DecisionTree::print("", keyFormatter, valueFormatter); } - static inline int one() { - return 1; - } - static inline int add(const int& a, const int& b) { - return a + b; - } - static inline int mul(const int& a, const int& b) { - return a * b; + /// Equality method customized to Crazy node type + bool equals(const CrazyDecisionTree& other, double tol = 1e-9) const { + auto compare = [tol](const Crazy& v, const Crazy& w) { + return v.a == w.a && std::abs(v.b - w.b) < tol; + }; + return DecisionTree::equals(other, compare); } }; -/* ******************************************************************************** */ +// traits +namespace gtsam { +template <> +struct traits : public Testable {}; +} // namespace gtsam + +GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree) + +/* ************************************************************************** */ +// Test string labels and int range +/* ************************************************************************** */ + +struct DT : public DecisionTree { + using Base = DecisionTree; + using DecisionTree::DecisionTree; + DT() = default; + + DT(const Base& dt) : Base(dt) {} + + /// print to stdout + void print(const std::string& s = "") const { + auto keyFormatter = [](const std::string& s) { return s; }; + auto valueFormatter = [](const int& v) { + return (boost::format("%d") % v).str(); + }; + std::cout << s; + Base::print("", keyFormatter, valueFormatter); + } + /// Equality method customized to int node type + bool equals(const Base& other, double tol = 1e-9) const { + auto compare = [](const int& v, const int& w) { return v == w; }; + return Base::equals(other, compare); + } +}; + +// traits +namespace gtsam { +template <> +struct traits
: public Testable
{}; +} // namespace gtsam + +GTSAM_CONCEPT_TESTABLE_INST(DT) + +struct Ring { + static inline int zero() { return 0; } + static inline int one() { return 1; } + static inline int id(const int& a) { return a; } + static inline int add(const int& a, const int& b) { return a + b; } + static inline int mul(const int& a, const int& b) { return a * b; } +}; + +/* ************************************************************************** */ // test DT -TEST(DT, example) -{ +TEST(DecisionTree, example) { // Create labels string A("A"), B("B"), C("C"); @@ -88,54 +129,62 @@ TEST(DT, example) x10[A] = 1, x10[B] = 0; x11[A] = 1, x11[B] = 1; + // empty + DT empty; + // A DT a(A, 0, 5); - LONGS_EQUAL(0,a(x00)) - LONGS_EQUAL(5,a(x10)) + LONGS_EQUAL(0, a(x00)) + LONGS_EQUAL(5, a(x10)) DOT(a); // pruned DT p(A, 2, 2); - LONGS_EQUAL(2,p(x00)) - LONGS_EQUAL(2,p(x10)) + LONGS_EQUAL(2, p(x00)) + LONGS_EQUAL(2, p(x10)) DOT(p); // \neg B DT notb(B, 5, 0); - LONGS_EQUAL(5,notb(x00)) - LONGS_EQUAL(5,notb(x10)) + LONGS_EQUAL(5, notb(x00)) + LONGS_EQUAL(5, notb(x10)) DOT(notb); + // Check supplying empty trees yields an exception + CHECK_EXCEPTION(gtsam::apply(empty, &Ring::id), std::runtime_error); + CHECK_EXCEPTION(gtsam::apply(empty, a, &Ring::mul), std::runtime_error); + CHECK_EXCEPTION(gtsam::apply(a, empty, &Ring::mul), std::runtime_error); + // apply, two nodes, in natural order DT anotb = apply(a, notb, &Ring::mul); - LONGS_EQUAL(0,anotb(x00)) - LONGS_EQUAL(0,anotb(x01)) - LONGS_EQUAL(25,anotb(x10)) - LONGS_EQUAL(0,anotb(x11)) + LONGS_EQUAL(0, anotb(x00)) + LONGS_EQUAL(0, anotb(x01)) + LONGS_EQUAL(25, anotb(x10)) + LONGS_EQUAL(0, anotb(x11)) DOT(anotb); // check pruning DT pnotb = apply(p, notb, &Ring::mul); - LONGS_EQUAL(10,pnotb(x00)) - LONGS_EQUAL( 0,pnotb(x01)) - LONGS_EQUAL(10,pnotb(x10)) - LONGS_EQUAL( 0,pnotb(x11)) + LONGS_EQUAL(10, pnotb(x00)) + LONGS_EQUAL(0, pnotb(x01)) + LONGS_EQUAL(10, pnotb(x10)) + LONGS_EQUAL(0, pnotb(x11)) DOT(pnotb); // check pruning DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul); - LONGS_EQUAL(0,zeros(x00)) - LONGS_EQUAL(0,zeros(x01)) - LONGS_EQUAL(0,zeros(x10)) - LONGS_EQUAL(0,zeros(x11)) + LONGS_EQUAL(0, zeros(x00)) + LONGS_EQUAL(0, zeros(x01)) + LONGS_EQUAL(0, zeros(x10)) + LONGS_EQUAL(0, zeros(x11)) DOT(zeros); // apply, two nodes, in switched order DT notba = apply(a, notb, &Ring::mul); - LONGS_EQUAL(0,notba(x00)) - LONGS_EQUAL(0,notba(x01)) - LONGS_EQUAL(25,notba(x10)) - LONGS_EQUAL(0,notba(x11)) + LONGS_EQUAL(0, notba(x00)) + LONGS_EQUAL(0, notba(x01)) + LONGS_EQUAL(25, notba(x10)) + LONGS_EQUAL(0, notba(x11)) DOT(notba); // Test choose 0 @@ -150,10 +199,10 @@ TEST(DT, example) // apply, two nodes at same level DT a_and_a = apply(a, a, &Ring::mul); - LONGS_EQUAL(0,a_and_a(x00)) - LONGS_EQUAL(0,a_and_a(x01)) - LONGS_EQUAL(25,a_and_a(x10)) - LONGS_EQUAL(25,a_and_a(x11)) + LONGS_EQUAL(0, a_and_a(x00)) + LONGS_EQUAL(0, a_and_a(x01)) + LONGS_EQUAL(25, a_and_a(x10)) + LONGS_EQUAL(25, a_and_a(x11)) DOT(a_and_a); // create a function on C @@ -165,27 +214,42 @@ TEST(DT, example) // mul notba with C DT notbac = apply(notba, c, &Ring::mul); - LONGS_EQUAL(125,notbac(x101)) + LONGS_EQUAL(125, notbac(x101)) DOT(notbac); // mul now in different order DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul); - LONGS_EQUAL(125,acnotb(x101)) + LONGS_EQUAL(125, acnotb(x101)) DOT(acnotb); } -/* ******************************************************************************** */ -// test Conversion -enum Label { - U, V, X, Y, Z -}; -typedef DecisionTree BDT; -bool convert(const int& y) { - return y != 0; +/* ************************************************************************** */ +// test Conversion of values +bool bool_of_int(const int& y) { return y != 0; }; +typedef DecisionTree StringBoolTree; + +TEST(DecisionTree, ConvertValuesOnly) { + // Create labels + string A("A"), B("B"); + + // apply, two nodes, in natural order + DT f1 = apply(DT(A, 0, 5), DT(B, 5, 0), &Ring::mul); + + // convert + StringBoolTree f2(f1, bool_of_int); + + // Check a value + Assignment x00; + x00["A"] = 0, x00["B"] = 0; + EXPECT(!f2(x00)); } -TEST(DT, conversion) -{ +/* ************************************************************************** */ +// test Conversion of both values and labels. +enum Label { U, V, X, Y, Z }; +typedef DecisionTree LabelBoolTree; + +TEST(DecisionTree, ConvertBoth) { // Create labels string A("A"), B("B"); @@ -196,12 +260,9 @@ TEST(DT, conversion) map ordering; ordering[A] = X; ordering[B] = Y; - std::function op = convert; - BDT f2(f1, ordering, op); - // f1.print("f1"); - // f2.print("f2"); + LabelBoolTree f2(f1, ordering, &bool_of_int); - // create a value + // Check some values Assignment
\n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + "
ABvalue
Zero-1
Zero+2
One-3
One+4
Two-5
Two+6
\n" + "
"; + auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; + DecisionTreeFactor::Names names{{12, {"Zero", "One", "Two"}}, + {5, {"-", "+"}}}; + string actual = f.html(keyFormatter, names); + EXPECT(actual == expected); } /* ************************************************************************* */ @@ -104,4 +229,3 @@ int main() { return TestRegistry::runAllTests(tr); } /* ************************************************************************* */ - diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index 2b440e5a0..19af676f7 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -38,21 +38,26 @@ using namespace boost::assign; using namespace std; using namespace gtsam; +static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), + LungCancer(6, 2), Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2); + +using ADT = AlgebraicDecisionTree; + /* ************************************************************************* */ TEST(DiscreteBayesNet, bayesNet) { DiscreteBayesNet bayesNet; DiscreteKey Parent(0, 2), Child(1, 2); auto prior = boost::make_shared(Parent % "6/4"); - CHECK(assert_equal(Potentials::ADT({Parent}, "0.6 0.4"), - (Potentials::ADT)*prior)); + CHECK(assert_equal(ADT({Parent}, "0.6 0.4"), + (ADT)*prior)); bayesNet.push_back(prior); auto conditional = boost::make_shared(Child | Parent = "7/3 8/2"); EXPECT_LONGS_EQUAL(1, *(conditional->beginFrontals())); - Potentials::ADT expected(Child & Parent, "0.7 0.8 0.3 0.2"); - CHECK(assert_equal(expected, (Potentials::ADT)*conditional)); + ADT expected(Child & Parent, "0.7 0.8 0.3 0.2"); + CHECK(assert_equal(expected, (ADT)*conditional)); bayesNet.push_back(conditional); DiscreteFactorGraph fg(bayesNet); @@ -71,11 +76,9 @@ TEST(DiscreteBayesNet, bayesNet) { /* ************************************************************************* */ TEST(DiscreteBayesNet, Asia) { DiscreteBayesNet asia; - DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), LungCancer(6, 2), - Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2); - asia.add(Asia % "99/1"); - asia.add(Smoking % "50/50"); + asia.add(Asia, "99/1"); + asia.add(Smoking % "50/50"); // Signature version asia.add(Tuberculosis | Asia = "99/1 95/5"); asia.add(LungCancer | Smoking = "99/1 90/10"); @@ -103,39 +106,26 @@ TEST(DiscreteBayesNet, Asia) { DiscreteConditional expected2(Bronchitis % "11/9"); EXPECT(assert_equal(expected2, *chordal->back())); - // solve - DiscreteFactor::sharedValues actualMPE = chordal->optimize(); - DiscreteFactor::Values expectedMPE; - insert(expectedMPE)(Asia.first, 0)(Dyspnea.first, 0)(XRay.first, 0)( - Tuberculosis.first, 0)(Smoking.first, 0)(Either.first, 0)( - LungCancer.first, 0)(Bronchitis.first, 0); - EXPECT(assert_equal(expectedMPE, *actualMPE)); - // add evidence, we were in Asia and we have dyspnea fg.add(Asia, "0 1"); fg.add(Dyspnea, "0 1"); // solve again, now with evidence DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering); - DiscreteFactor::sharedValues actualMPE2 = chordal2->optimize(); - DiscreteFactor::Values expectedMPE2; - insert(expectedMPE2)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 0)( - Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 0)( - LungCancer.first, 0)(Bronchitis.first, 1); - EXPECT(assert_equal(expectedMPE2, *actualMPE2)); + EXPECT(assert_equal(expected2, *chordal->back())); // now sample from it - DiscreteFactor::Values expectedSample; + DiscreteValues expectedSample; SETDEBUG("DiscreteConditional::sample", false); insert(expectedSample)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 1)( Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 1)( LungCancer.first, 1)(Bronchitis.first, 0); - DiscreteFactor::sharedValues actualSample = chordal2->sample(); - EXPECT(assert_equal(expectedSample, *actualSample)); + auto actualSample = chordal2->sample(); + EXPECT(assert_equal(expectedSample, actualSample)); } /* ************************************************************************* */ -TEST_UNSAFE(DiscreteBayesNet, Sugar) { +TEST(DiscreteBayesNet, Sugar) { DiscreteKey T(0, 2), L(1, 2), E(2, 2), C(8, 3), S(7, 2); DiscreteBayesNet bn; @@ -149,6 +139,60 @@ TEST_UNSAFE(DiscreteBayesNet, Sugar) { bn.add(C | S = "1/1/2 5/2/3"); } +/* ************************************************************************* */ +TEST(DiscreteBayesNet, Dot) { + DiscreteBayesNet fragment; + fragment.add(Asia % "99/1"); + fragment.add(Smoking % "50/50"); + + fragment.add(Tuberculosis | Asia = "99/1 95/5"); + fragment.add(LungCancer | Smoking = "99/1 90/10"); + fragment.add((Either | Tuberculosis, LungCancer) = "F T T T"); + + string actual = fragment.dot(); + EXPECT(actual == + "digraph {\n" + " size=\"5,5\";\n" + "\n" + " var0[label=\"0\"];\n" + " var3[label=\"3\"];\n" + " var4[label=\"4\"];\n" + " var5[label=\"5\"];\n" + " var6[label=\"6\"];\n" + "\n" + " var3->var5\n" + " var6->var5\n" + " var4->var6\n" + " var0->var3\n" + "}"); +} + +/* ************************************************************************* */ +// Check markdown representation looks as expected. +TEST(DiscreteBayesNet, markdown) { + DiscreteBayesNet fragment; + fragment.add(Asia % "99/1"); + fragment.add(Smoking | Asia = "8/2 7/3"); + + string expected = + "`DiscreteBayesNet` of size 2\n" + "\n" + " *P(Asia):*\n\n" + "|Asia|value|\n" + "|:-:|:-:|\n" + "|0|0.99|\n" + "|1|0.01|\n" + "\n" + " *P(Smoking|Asia):*\n\n" + "|*Asia*|0|1|\n" + "|:-:|:-:|:-:|\n" + "|0|0.8|0.2|\n" + "|1|0.7|0.3|\n\n"; + auto formatter = [](Key key) { return key == 0 ? "Asia" : "Smoking"; }; + string actual = fragment.markdown(formatter); + EXPECT(actual == expected); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/discrete/tests/testDiscreteBayesTree.cpp b/gtsam/discrete/tests/testDiscreteBayesTree.cpp index ecf485036..6635633a2 100644 --- a/gtsam/discrete/tests/testDiscreteBayesTree.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesTree.cpp @@ -26,88 +26,101 @@ using namespace boost::assign; #include +#include #include using namespace std; using namespace gtsam; - -static bool debug = false; +static constexpr bool debug = false; /* ************************************************************************* */ - -TEST_UNSAFE(DiscreteBayesTree, ThinTree) { - const int nrNodes = 15; - const size_t nrStates = 2; - - // define variables - vector key; - for (int i = 0; i < nrNodes; i++) { - DiscreteKey key_i(i, nrStates); - key.push_back(key_i); - } - - // create a thin-tree Bayesnet, a la Jean-Guillaume +struct TestFixture { + vector keys; DiscreteBayesNet bayesNet; - bayesNet.add(key[14] % "1/3"); + boost::shared_ptr bayesTree; - bayesNet.add(key[13] | key[14] = "1/3 3/1"); - bayesNet.add(key[12] | key[14] = "3/1 3/1"); + /** + * Create a thin-tree Bayesnet, a la Jean-Guillaume Durand (former student), + * and then create the Bayes tree from it. + */ + TestFixture() { + // Define variables. + for (int i = 0; i < 15; i++) { + DiscreteKey key_i(i, 2); + keys.push_back(key_i); + } - bayesNet.add((key[11] | key[13], key[14]) = "1/4 2/3 3/2 4/1"); - bayesNet.add((key[10] | key[13], key[14]) = "1/4 3/2 2/3 4/1"); - bayesNet.add((key[9] | key[12], key[14]) = "4/1 2/3 F 1/4"); - bayesNet.add((key[8] | key[12], key[14]) = "T 1/4 3/2 4/1"); + // Create thin-tree Bayesnet. + bayesNet.add(keys[14] % "1/3"); - bayesNet.add((key[7] | key[11], key[13]) = "1/4 2/3 3/2 4/1"); - bayesNet.add((key[6] | key[11], key[13]) = "1/4 3/2 2/3 4/1"); - bayesNet.add((key[5] | key[10], key[13]) = "4/1 2/3 3/2 1/4"); - bayesNet.add((key[4] | key[10], key[13]) = "2/3 1/4 3/2 4/1"); + bayesNet.add(keys[13] | keys[14] = "1/3 3/1"); + bayesNet.add(keys[12] | keys[14] = "3/1 3/1"); - bayesNet.add((key[3] | key[9], key[12]) = "1/4 2/3 3/2 4/1"); - bayesNet.add((key[2] | key[9], key[12]) = "1/4 8/2 2/3 4/1"); - bayesNet.add((key[1] | key[8], key[12]) = "4/1 2/3 3/2 1/4"); - bayesNet.add((key[0] | key[8], key[12]) = "2/3 1/4 3/2 4/1"); + bayesNet.add((keys[11] | keys[13], keys[14]) = "1/4 2/3 3/2 4/1"); + bayesNet.add((keys[10] | keys[13], keys[14]) = "1/4 3/2 2/3 4/1"); + bayesNet.add((keys[9] | keys[12], keys[14]) = "4/1 2/3 F 1/4"); + bayesNet.add((keys[8] | keys[12], keys[14]) = "T 1/4 3/2 4/1"); + + bayesNet.add((keys[7] | keys[11], keys[13]) = "1/4 2/3 3/2 4/1"); + bayesNet.add((keys[6] | keys[11], keys[13]) = "1/4 3/2 2/3 4/1"); + bayesNet.add((keys[5] | keys[10], keys[13]) = "4/1 2/3 3/2 1/4"); + bayesNet.add((keys[4] | keys[10], keys[13]) = "2/3 1/4 3/2 4/1"); + + bayesNet.add((keys[3] | keys[9], keys[12]) = "1/4 2/3 3/2 4/1"); + bayesNet.add((keys[2] | keys[9], keys[12]) = "1/4 8/2 2/3 4/1"); + bayesNet.add((keys[1] | keys[8], keys[12]) = "4/1 2/3 3/2 1/4"); + bayesNet.add((keys[0] | keys[8], keys[12]) = "2/3 1/4 3/2 4/1"); + + // Create a BayesTree out of the Bayes net. + bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal(); + } +}; + +/* ************************************************************************* */ +TEST(DiscreteBayesTree, ThinTree) { + const TestFixture self; + const auto& keys = self.keys; if (debug) { - GTSAM_PRINT(bayesNet); - bayesNet.saveGraph("/tmp/discreteBayesNet.dot"); + GTSAM_PRINT(self.bayesNet); + self.bayesNet.saveGraph("/tmp/discreteBayesNet.dot"); } // create a BayesTree out of a Bayes net - auto bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal(); if (debug) { - GTSAM_PRINT(*bayesTree); - bayesTree->saveGraph("/tmp/discreteBayesTree.dot"); + GTSAM_PRINT(*self.bayesTree); + self.bayesTree->saveGraph("/tmp/discreteBayesTree.dot"); } // Check frontals and parents for (size_t i : {13, 14, 9, 3, 2, 8, 1, 0, 10, 5, 4}) { - auto clique_i = (*bayesTree)[i]; + auto clique_i = (*self.bayesTree)[i]; EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals())); } - auto R = bayesTree->roots().front(); + auto R = self.bayesTree->roots().front(); // Check whether BN and BT give the same answer on all configurations - vector allPosbValues = cartesianProduct( - key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7] & - key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]); + auto allPosbValues = DiscreteValues::CartesianProduct( + keys[0] & keys[1] & keys[2] & keys[3] & keys[4] & keys[5] & keys[6] & + keys[7] & keys[8] & keys[9] & keys[10] & keys[11] & keys[12] & keys[13] & + keys[14]); for (size_t i = 0; i < allPosbValues.size(); ++i) { - DiscreteFactor::Values x = allPosbValues[i]; - double expected = bayesNet.evaluate(x); - double actual = bayesTree->evaluate(x); + DiscreteValues x = allPosbValues[i]; + double expected = self.bayesNet.evaluate(x); + double actual = self.bayesTree->evaluate(x); DOUBLES_EQUAL(expected, actual, 1e-9); } - // Calculate all some marginals for Values==all1 + // Calculate all some marginals for DiscreteValues==all1 Vector marginals = Vector::Zero(15); double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0, joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0, joint_4_11 = 0, joint_11_13 = 0, joint_11_13_14 = 0, joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0; for (size_t i = 0; i < allPosbValues.size(); ++i) { - DiscreteFactor::Values x = allPosbValues[i]; - double px = bayesTree->evaluate(x); + DiscreteValues x = allPosbValues[i]; + double px = self.bayesTree->evaluate(x); for (size_t i = 0; i < 15; i++) if (x[i]) marginals[i] += px; if (x[12] && x[14]) { @@ -138,49 +151,49 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) { } } } - DiscreteFactor::Values all1 = allPosbValues.back(); + DiscreteValues all1 = allPosbValues.back(); // check separator marginal P(S0) - auto clique = (*bayesTree)[0]; + auto clique = (*self.bayesTree)[0]; DiscreteFactorGraph separatorMarginal0 = clique->separatorMarginal(EliminateDiscrete); DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9); // check separator marginal P(S9), should be P(14) - clique = (*bayesTree)[9]; + clique = (*self.bayesTree)[9]; DiscreteFactorGraph separatorMarginal9 = clique->separatorMarginal(EliminateDiscrete); DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9); // check separator marginal of root, should be empty - clique = (*bayesTree)[11]; + clique = (*self.bayesTree)[11]; DiscreteFactorGraph separatorMarginal11 = clique->separatorMarginal(EliminateDiscrete); LONGS_EQUAL(0, separatorMarginal11.size()); // check shortcut P(S9||R) to root - clique = (*bayesTree)[9]; + clique = (*self.bayesTree)[9]; DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete); LONGS_EQUAL(1, shortcut.size()); DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9); // check shortcut P(S8||R) to root - clique = (*bayesTree)[8]; + clique = (*self.bayesTree)[8]; shortcut = clique->shortcut(R, EliminateDiscrete); DOUBLES_EQUAL(joint_11_12_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9); // check shortcut P(S2||R) to root - clique = (*bayesTree)[2]; + clique = (*self.bayesTree)[2]; shortcut = clique->shortcut(R, EliminateDiscrete); DOUBLES_EQUAL(joint_9_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9); // check shortcut P(S0||R) to root - clique = (*bayesTree)[0]; + clique = (*self.bayesTree)[0]; shortcut = clique->shortcut(R, EliminateDiscrete); DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9); // calculate all shortcuts to root - DiscreteBayesTree::Nodes cliques = bayesTree->nodes(); + DiscreteBayesTree::Nodes cliques = self.bayesTree->nodes(); for (auto clique : cliques) { DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete); if (debug) { @@ -192,7 +205,7 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) { // Check all marginals DiscreteFactor::shared_ptr marginalFactor; for (size_t i = 0; i < 15; i++) { - marginalFactor = bayesTree->marginalFactor(i, EliminateDiscrete); + marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete); double actual = (*marginalFactor)(all1); DOUBLES_EQUAL(marginals[i], actual, 1e-9); } @@ -200,30 +213,60 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) { DiscreteBayesNet::shared_ptr actualJoint; // Check joint P(8, 2) - actualJoint = bayesTree->jointBayesNet(8, 2, EliminateDiscrete); + actualJoint = self.bayesTree->jointBayesNet(8, 2, EliminateDiscrete); DOUBLES_EQUAL(joint82, actualJoint->evaluate(all1), 1e-9); // Check joint P(1, 2) - actualJoint = bayesTree->jointBayesNet(1, 2, EliminateDiscrete); + actualJoint = self.bayesTree->jointBayesNet(1, 2, EliminateDiscrete); DOUBLES_EQUAL(joint12, actualJoint->evaluate(all1), 1e-9); // Check joint P(2, 4) - actualJoint = bayesTree->jointBayesNet(2, 4, EliminateDiscrete); + actualJoint = self.bayesTree->jointBayesNet(2, 4, EliminateDiscrete); DOUBLES_EQUAL(joint24, actualJoint->evaluate(all1), 1e-9); // Check joint P(4, 5) - actualJoint = bayesTree->jointBayesNet(4, 5, EliminateDiscrete); + actualJoint = self.bayesTree->jointBayesNet(4, 5, EliminateDiscrete); DOUBLES_EQUAL(joint45, actualJoint->evaluate(all1), 1e-9); // Check joint P(4, 6) - actualJoint = bayesTree->jointBayesNet(4, 6, EliminateDiscrete); + actualJoint = self.bayesTree->jointBayesNet(4, 6, EliminateDiscrete); DOUBLES_EQUAL(joint46, actualJoint->evaluate(all1), 1e-9); // Check joint P(4, 11) - actualJoint = bayesTree->jointBayesNet(4, 11, EliminateDiscrete); + actualJoint = self.bayesTree->jointBayesNet(4, 11, EliminateDiscrete); DOUBLES_EQUAL(joint_4_11, actualJoint->evaluate(all1), 1e-9); } +/* ************************************************************************* */ +TEST(DiscreteBayesTree, Dot) { + const TestFixture self; + string actual = self.bayesTree->dot(); + EXPECT(actual == + "digraph G{\n" + "0[label=\"13, 11, 6, 7\"];\n" + "0->1\n" + "1[label=\"14 : 11, 13\"];\n" + "1->2\n" + "2[label=\"9, 12 : 14\"];\n" + "2->3\n" + "3[label=\"3 : 9, 12\"];\n" + "2->4\n" + "4[label=\"2 : 9, 12\"];\n" + "2->5\n" + "5[label=\"8 : 12, 14\"];\n" + "5->6\n" + "6[label=\"1 : 8, 12\"];\n" + "5->7\n" + "7[label=\"0 : 8, 12\"];\n" + "1->8\n" + "8[label=\"10 : 13, 14\"];\n" + "8->9\n" + "9[label=\"5 : 10, 13\"];\n" + "8->10\n" + "10[label=\"4 : 10, 13\"];\n" + "}"); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 3ac3ffc9e..13a34dd19 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -10,10 +10,11 @@ * -------------------------------------------------------------------------- */ /* - * @file testDecisionTreeFactor.cpp + * @file testDiscreteConditional.cpp * @brief unit tests for DiscreteConditional * @author Duy-Nguyen Ta - * @date Feb 14, 2011 + * @author Frank dellaert + * @date Feb 14, 2011 */ #include @@ -24,31 +25,30 @@ using namespace boost::assign; #include #include #include +#include using namespace std; using namespace gtsam; /* ************************************************************************* */ -TEST( DiscreteConditional, constructors) -{ - DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering ! +TEST(DiscreteConditional, constructors) { + DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering ! + + DiscreteConditional actual(X | Y = "1/1 2/3 1/4"); + EXPECT_LONGS_EQUAL(0, *(actual.beginFrontals())); + EXPECT_LONGS_EQUAL(2, *(actual.beginParents())); + EXPECT(actual.endParents() == actual.end()); + EXPECT(actual.endFrontals() == actual.beginParents()); - DiscreteConditional::shared_ptr expected1 = // - boost::make_shared(X | Y = "1/1 2/3 1/4"); - EXPECT(expected1); - EXPECT_LONGS_EQUAL(0, *(expected1->beginFrontals())); - EXPECT_LONGS_EQUAL(2, *(expected1->beginParents())); - EXPECT(expected1->endParents() == expected1->end()); - EXPECT(expected1->endFrontals() == expected1->beginParents()); - DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); - DiscreteConditional actual1(1, f1); - EXPECT(assert_equal(*expected1, actual1, 1e-9)); + DiscreteConditional expected1(1, f1); + EXPECT(assert_equal(expected1, actual, 1e-9)); - DecisionTreeFactor f2(X & Y & Z, - "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); + DecisionTreeFactor f2( + X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); DiscreteConditional actual2(1, f2); - EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9)); + DecisionTreeFactor expected2 = f2 / *f2.sum(1); + EXPECT(assert_equal(expected2, static_cast(actual2))); } /* ************************************************************************* */ @@ -61,50 +61,314 @@ TEST(DiscreteConditional, constructors_alt_interface) { r2 += 2.0, 3.0; r3 += 1.0, 4.0; table += r1, r2, r3; - auto actual1 = boost::make_shared(X | Y = table); - EXPECT(actual1); + DiscreteConditional actual1(X, {Y}, table); + DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); DiscreteConditional expected1(1, f1); - EXPECT(assert_equal(expected1, *actual1, 1e-9)); + EXPECT(assert_equal(expected1, actual1, 1e-9)); DecisionTreeFactor f2( X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); DiscreteConditional actual2(1, f2); - EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9)); + DecisionTreeFactor expected2 = f2 / *f2.sum(1); + EXPECT(assert_equal(expected2, static_cast(actual2))); } /* ************************************************************************* */ TEST(DiscreteConditional, constructors2) { - // Declare keys and ordering DiscreteKey C(0, 2), B(1, 2); - DecisionTreeFactor actual(C & B, "0.8 0.75 0.2 0.25"); Signature signature((C | B) = "4/1 3/1"); - DiscreteConditional expected(signature); - DecisionTreeFactor::shared_ptr expectedFactor = expected.toFactor(); - EXPECT(assert_equal(*expectedFactor, actual)); + DiscreteConditional actual(signature); + + DecisionTreeFactor expected(C & B, "0.8 0.75 0.2 0.25"); + EXPECT(assert_equal(expected, static_cast(actual))); } /* ************************************************************************* */ TEST(DiscreteConditional, constructors3) { - // Declare keys and ordering DiscreteKey C(0, 2), B(1, 2), A(2, 2); - DecisionTreeFactor actual(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8"); Signature signature((C | B, A) = "4/1 1/1 1/1 1/4"); - DiscreteConditional expected(signature); - DecisionTreeFactor::shared_ptr expectedFactor = expected.toFactor(); - EXPECT(assert_equal(*expectedFactor, actual)); + DiscreteConditional actual(signature); + + DecisionTreeFactor expected(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8"); + EXPECT(assert_equal(expected, static_cast(actual))); } /* ************************************************************************* */ -TEST(DiscreteConditional, Combine) { - DiscreteKey A(0, 2), B(1, 2); - vector c; - c.push_back(boost::make_shared(A | B = "1/2 2/1")); - c.push_back(boost::make_shared(B % "1/2")); - DecisionTreeFactor factor(A & B, "0.111111 0.444444 0.222222 0.222222"); - DiscreteConditional actual(2, factor); - auto expected = DiscreteConditional::Combine(c.begin(), c.end()); - EXPECT(assert_equal(*expected, actual, 1e-5)); +// Check calculation of joint P(A,B) +TEST(DiscreteConditional, Multiply) { + DiscreteKey A(1, 2), B(0, 2); + DiscreteConditional conditional(A | B = "1/2 2/1"); + DiscreteConditional prior(B % "1/2"); + + // The expected factor + DecisionTreeFactor f(A & B, "1 4 2 2"); + DiscreteConditional expected(2, f); + + // P(A,B) = P(A|B) * P(B) = P(B) * P(A|B) + for (auto&& actual : {prior * conditional, conditional * prior}) { + EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); + KeyVector frontals(actual.beginFrontals(), actual.endFrontals()); + EXPECT((frontals == KeyVector{0, 1})); + for (auto&& it : actual.enumerate()) { + const DiscreteValues& v = it.first; + EXPECT_DOUBLES_EQUAL(actual(v), conditional(v) * prior(v), 1e-9); + } + // And for good measure: + EXPECT(assert_equal(expected, actual)); + } +} + +/* ************************************************************************* */ +// Check calculation of conditional joint P(A,B|C) +TEST(DiscreteConditional, Multiply2) { + DiscreteKey A(0, 2), B(1, 2), C(2, 2); + DiscreteConditional A_given_B(A | B = "1/3 3/1"); + DiscreteConditional B_given_C(B | C = "1/3 3/1"); + + // P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B) + for (auto&& actual : {A_given_B * B_given_C, B_given_C * A_given_B}) { + EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); + EXPECT_LONGS_EQUAL(1, actual.nrParents()); + KeyVector frontals(actual.beginFrontals(), actual.endFrontals()); + EXPECT((frontals == KeyVector{0, 1})); + for (auto&& it : actual.enumerate()) { + const DiscreteValues& v = it.first; + EXPECT_DOUBLES_EQUAL(actual(v), A_given_B(v) * B_given_C(v), 1e-9); + } + } +} + +/* ************************************************************************* */ +// Check calculation of conditional joint P(A,B|C), double check keys +TEST(DiscreteConditional, Multiply3) { + DiscreteKey A(1, 2), B(2, 2), C(0, 2); // different keys!!! + DiscreteConditional A_given_B(A | B = "1/3 3/1"); + DiscreteConditional B_given_C(B | C = "1/3 3/1"); + + // P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B) + for (auto&& actual : {A_given_B * B_given_C, B_given_C * A_given_B}) { + EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); + EXPECT_LONGS_EQUAL(1, actual.nrParents()); + KeyVector frontals(actual.beginFrontals(), actual.endFrontals()); + EXPECT((frontals == KeyVector{1, 2})); + for (auto&& it : actual.enumerate()) { + const DiscreteValues& v = it.first; + EXPECT_DOUBLES_EQUAL(actual(v), A_given_B(v) * B_given_C(v), 1e-9); + } + } +} + +/* ************************************************************************* */ +// Check calculation of conditional joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E) +TEST(DiscreteConditional, Multiply4) { + DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(4, 2), E(3, 2); + DiscreteConditional A_given_B(A | B = "1/3 3/1"); + DiscreteConditional B_given_D(B | D = "1/3 3/1"); + DiscreteConditional AB_given_D = A_given_B * B_given_D; + DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4"); + + // P(A,B,C|D,E) = P(A,B|D) P(C|D,E) = P(C|D,E) P(A,B|D) + for (auto&& actual : {AB_given_D * C_given_DE, C_given_DE * AB_given_D}) { + EXPECT_LONGS_EQUAL(3, actual.nrFrontals()); + EXPECT_LONGS_EQUAL(2, actual.nrParents()); + KeyVector frontals(actual.beginFrontals(), actual.endFrontals()); + EXPECT((frontals == KeyVector{0, 1, 2})); + KeyVector parents(actual.beginParents(), actual.endParents()); + EXPECT((parents == KeyVector{3, 4})); + for (auto&& it : actual.enumerate()) { + const DiscreteValues& v = it.first; + EXPECT_DOUBLES_EQUAL(actual(v), AB_given_D(v) * C_given_DE(v), 1e-9); + } + } +} + +/* ************************************************************************* */ +// Check calculation of marginals for joint P(A,B) +TEST(DiscreteConditional, marginals) { + DiscreteKey A(1, 2), B(0, 2); + DiscreteConditional conditional(A | B = "1/2 2/1"); + DiscreteConditional prior(B % "1/2"); + DiscreteConditional pAB = prior * conditional; + + // P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 1*1 + 2*2 = 5 + // P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4 + DiscreteConditional actualA = pAB.marginal(A.first); + DiscreteConditional pA(A % "5/4"); + EXPECT(assert_equal(pA, actualA)); + EXPECT(actualA.frontals() == KeyVector{1}); + EXPECT_LONGS_EQUAL(0, actualA.nrParents()); + + DiscreteConditional actualB = pAB.marginal(B.first); + EXPECT(assert_equal(prior, actualB)); + EXPECT(actualB.frontals() == KeyVector{0}); + EXPECT_LONGS_EQUAL(0, actualB.nrParents()); +} + +/* ************************************************************************* */ +// Check calculation of marginals in case branches are pruned +TEST(DiscreteConditional, marginals2) { + DiscreteKey A(0, 2), B(1, 2); // changing keys need to make pruning happen! + DiscreteConditional conditional(A | B = "2/2 3/1"); + DiscreteConditional prior(B % "1/2"); + DiscreteConditional pAB = prior * conditional; + GTSAM_PRINT(pAB); + // P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8 + // P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4 + DiscreteConditional actualA = pAB.marginal(A.first); + DiscreteConditional pA(A % "8/4"); + EXPECT(assert_equal(pA, actualA)); + + DiscreteConditional actualB = pAB.marginal(B.first); + EXPECT(assert_equal(prior, actualB)); +} + +/* ************************************************************************* */ +TEST(DiscreteConditional, likelihood) { + DiscreteKey X(0, 2), Y(1, 3); + DiscreteConditional conditional(X | Y = "2/8 4/6 5/5"); + + auto actual0 = conditional.likelihood(0); + DecisionTreeFactor expected0(Y, "0.2 0.4 0.5"); + EXPECT(assert_equal(expected0, *actual0, 1e-9)); + + auto actual1 = conditional.likelihood(1); + DecisionTreeFactor expected1(Y, "0.8 0.6 0.5"); + EXPECT(assert_equal(expected1, *actual1, 1e-9)); +} + +/* ************************************************************************* */ +// Check choose on P(C|D,E) +TEST(DiscreteConditional, choose) { + DiscreteKey C(2, 2), D(4, 2), E(3, 2); + DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4"); + + // Case 1: no given values: no-op + DiscreteValues given; + auto actual1 = C_given_DE.choose(given); + EXPECT(assert_equal(C_given_DE, *actual1, 1e-9)); + + // Case 2: 1 given value + given[D.first] = 1; + auto actual2 = C_given_DE.choose(given); + EXPECT_LONGS_EQUAL(1, actual2->nrFrontals()); + EXPECT_LONGS_EQUAL(1, actual2->nrParents()); + DiscreteConditional expected2(C | E = "1/1 1/4"); + EXPECT(assert_equal(expected2, *actual2, 1e-9)); + + // Case 2: 2 given values + given[E.first] = 0; + auto actual3 = C_given_DE.choose(given); + EXPECT_LONGS_EQUAL(1, actual3->nrFrontals()); + EXPECT_LONGS_EQUAL(0, actual3->nrParents()); + DiscreteConditional expected3(C % "1/1"); + EXPECT(assert_equal(expected3, *actual3, 1e-9)); +} + +/* ************************************************************************* */ +// Check markdown representation looks as expected, no parents. +TEST(DiscreteConditional, markdown_prior) { + DiscreteKey A(Symbol('x', 1), 3); + DiscreteConditional conditional(A % "1/2/2"); + string expected = + " *P(x1):*\n\n" + "|x1|value|\n" + "|:-:|:-:|\n" + "|0|0.2|\n" + "|1|0.4|\n" + "|2|0.4|\n"; + string actual = conditional.markdown(); + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +// Check markdown representation looks as expected, no parents + names. +TEST(DiscreteConditional, markdown_prior_names) { + Symbol x1('x', 1); + DiscreteKey A(x1, 3); + DiscreteConditional conditional(A % "1/2/2"); + string expected = + " *P(x1):*\n\n" + "|x1|value|\n" + "|:-:|:-:|\n" + "|A0|0.2|\n" + "|A1|0.4|\n" + "|A2|0.4|\n"; + DecisionTreeFactor::Names names{{x1, {"A0", "A1", "A2"}}}; + string actual = conditional.markdown(DefaultKeyFormatter, names); + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +// Check markdown representation looks as expected, multivalued. +TEST(DiscreteConditional, markdown_multivalued) { + DiscreteKey A(Symbol('a', 1), 3), B(Symbol('b', 1), 5); + DiscreteConditional conditional( + A | B = "2/88/10 2/20/78 33/33/34 33/33/34 95/2/3"); + string expected = + " *P(a1|b1):*\n\n" + "|*b1*|0|1|2|\n" + "|:-:|:-:|:-:|:-:|\n" + "|0|0.02|0.88|0.1|\n" + "|1|0.02|0.2|0.78|\n" + "|2|0.33|0.33|0.34|\n" + "|3|0.33|0.33|0.34|\n" + "|4|0.95|0.02|0.03|\n"; + string actual = conditional.markdown(); + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +// Check markdown representation looks as expected, two parents + names. +TEST(DiscreteConditional, markdown) { + DiscreteKey A(2, 2), B(1, 2), C(0, 3); + DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0"); + string expected = + " *P(A|B,C):*\n\n" + "|*B*|*C*|T|F|\n" + "|:-:|:-:|:-:|:-:|\n" + "|-|Zero|0|1|\n" + "|-|One|0.25|0.75|\n" + "|-|Two|0.5|0.5|\n" + "|+|Zero|0.75|0.25|\n" + "|+|One|0|1|\n" + "|+|Two|1|0|\n"; + vector keyNames{"C", "B", "A"}; + auto formatter = [keyNames](Key key) { return keyNames[key]; }; + DecisionTreeFactor::Names names{ + {0, {"Zero", "One", "Two"}}, {1, {"-", "+"}}, {2, {"T", "F"}}}; + string actual = conditional.markdown(formatter, names); + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +// Check html representation looks as expected, two parents + names. +TEST(DiscreteConditional, html) { + DiscreteKey A(2, 2), B(1, 2), C(0, 3); + DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0"); + string expected = + "
\n" + "

P(A|B,C):

\n" + "\n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + "
BCTF
-Zero01
-One0.250.75
-Two0.50.5
+Zero0.750.25
+One01
+Two10
\n" + "
"; + vector keyNames{"C", "B", "A"}; + auto formatter = [keyNames](Key key) { return keyNames[key]; }; + DecisionTreeFactor::Names names{ + {0, {"Zero", "One", "Two"}}, {1, {"-", "+"}}, {2, {"T", "F"}}}; + string actual = conditional.html(formatter, names); + EXPECT(actual == expected); } /* ************************************************************************* */ diff --git a/gtsam/discrete/tests/testDiscreteDistribution.cpp b/gtsam/discrete/tests/testDiscreteDistribution.cpp new file mode 100644 index 000000000..d88b510f8 --- /dev/null +++ b/gtsam/discrete/tests/testDiscreteDistribution.cpp @@ -0,0 +1,88 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * @file testDiscreteDistribution.cpp + * @brief unit tests for DiscreteDistribution + * @author Frank dellaert + * @date December 2021 + */ + +#include +#include +#include + +using namespace gtsam; + +static const DiscreteKey X(0, 2); + +/* ************************************************************************* */ +TEST(DiscreteDistribution, constructors) { + DecisionTreeFactor f(X, "0.4 0.6"); + DiscreteDistribution expected(f); + + DiscreteDistribution actual(X % "2/3"); + EXPECT_LONGS_EQUAL(1, actual.nrFrontals()); + EXPECT_LONGS_EQUAL(0, actual.nrParents()); + EXPECT(assert_equal(expected, actual, 1e-9)); + + const std::vector pmf{0.4, 0.6}; + DiscreteDistribution actual2(X, pmf); + EXPECT_LONGS_EQUAL(1, actual2.nrFrontals()); + EXPECT_LONGS_EQUAL(0, actual2.nrParents()); + EXPECT(assert_equal(expected, actual2, 1e-9)); +} + +/* ************************************************************************* */ +TEST(DiscreteDistribution, Multiply) { + DiscreteKey A(0, 2), B(1, 2); + DiscreteConditional conditional(A | B = "1/2 2/1"); + DiscreteDistribution prior(B, "1/2"); + DiscreteConditional actual = prior * conditional; // P(A|B) * P(B) + + EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); // = P(A,B) + DecisionTreeFactor factor(A & B, "1 4 2 2"); + DiscreteConditional expected(2, factor); + EXPECT(assert_equal(expected, actual, 1e-5)); +} + +/* ************************************************************************* */ +TEST(DiscreteDistribution, operator) { + DiscreteDistribution prior(X % "2/3"); + EXPECT_DOUBLES_EQUAL(prior(0), 0.4, 1e-9); + EXPECT_DOUBLES_EQUAL(prior(1), 0.6, 1e-9); +} + +/* ************************************************************************* */ +TEST(DiscreteDistribution, pmf) { + DiscreteDistribution prior(X % "2/3"); + std::vector expected{0.4, 0.6}; + EXPECT(prior.pmf() == expected); +} + +/* ************************************************************************* */ +TEST(DiscreteDistribution, sample) { + DiscreteDistribution prior(X % "2/3"); + prior.sample(); +} + +/* ************************************************************************* */ +TEST(DiscreteDistribution, argmax) { + DiscreteDistribution prior(X % "2/3"); + EXPECT_LONGS_EQUAL(prior.argmax(), 1); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index 1defd5acf..0a7d869ec 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -30,8 +30,8 @@ using namespace std; using namespace gtsam; /* ************************************************************************* */ -TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) { - DiscreteKey PC(0,4), ME(1, 4), AI(2, 4), A(3, 3); +TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) { + DiscreteKey PC(0, 4), ME(1, 4), AI(2, 4), A(3, 3); DiscreteFactorGraph graph; graph.add(AI, "1 0 0 1"); @@ -47,25 +47,11 @@ TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) { graph.add(PC & ME, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0"); graph.add(PC & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0"); -// graph.print("Graph: "); - DecisionTreeFactor product = graph.product(); - DecisionTreeFactor::shared_ptr sum = product.sum(1); -// sum->print("Debug SUM: "); - DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum)); - -// cond->print("marginal:"); - -// pair result = EliminateDiscrete(graph, 1); -// result.first->print("BayesNet: "); -// result.second->print("New factor: "); -// - Ordering ordering; - ordering += Key(0),Key(1),Key(2),Key(3); - DiscreteEliminationTree eliminationTree(graph, ordering); -// eliminationTree.print("Elimination tree: "); - eliminationTree.eliminate(EliminateDiscrete); -// solver.optimize(); -// DiscreteBayesNet::shared_ptr bayesNet = solver.eliminate(); + // Check MPE. + auto actualMPE = graph.optimize(); + DiscreteValues mpe; + insert(mpe)(0, 2)(1, 1)(2, 0)(3, 0); + EXPECT(assert_equal(mpe, actualMPE)); } /* ************************************************************************* */ @@ -81,8 +67,8 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) { graph.add(P2, "0.9 0.6"); graph.add(P1 & P2, "4 1 10 4"); - // Instantiate Values - DiscreteFactor::Values values; + // Instantiate DiscreteValues + DiscreteValues values; values[0] = 1; values[1] = 1; @@ -115,10 +101,9 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) { } /* ************************************************************************* */ -TEST( DiscreteFactorGraph, test) -{ +TEST(DiscreteFactorGraph, test) { // Declare keys and ordering - DiscreteKey C(0,2), B(1,2), A(2,2); + DiscreteKey C(0, 2), B(1, 2), A(2, 2); // A simple factor graph (A)-fAC-(C)-fBC-(B) // with smoothness priors @@ -127,77 +112,124 @@ TEST( DiscreteFactorGraph, test) graph.add(C & B, "3 1 1 3"); // Test EliminateDiscrete - // FIXME: apparently Eliminate returns a conditional rather than a net Ordering frontalKeys; frontalKeys += Key(0); DiscreteConditional::shared_ptr conditional; DecisionTreeFactor::shared_ptr newFactor; boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys); - // Check Bayes net + // Check Conditional CHECK(conditional); - DiscreteBayesNet expected; Signature signature((C | B, A) = "9/1 1/1 1/1 1/9"); - // cout << signature << endl; DiscreteConditional expectedConditional(signature); EXPECT(assert_equal(expectedConditional, *conditional)); - expected.add(signature); // Check Factor CHECK(newFactor); DecisionTreeFactor expectedFactor(B & A, "10 6 6 10"); EXPECT(assert_equal(expectedFactor, *newFactor)); - // add conditionals to complete expected Bayes net - expected.add(B | A = "5/3 3/5"); - expected.add(A % "1/1"); - // GTSAM_PRINT(expected); - - // Test elimination tree + // Test using elimination tree Ordering ordering; ordering += Key(0), Key(1), Key(2); DiscreteEliminationTree etree(graph, ordering); DiscreteBayesNet::shared_ptr actual; DiscreteFactorGraph::shared_ptr remainingGraph; boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete); - EXPECT(assert_equal(expected, *actual)); -// // Test solver -// DiscreteBayesNet::shared_ptr actual2 = solver.eliminate(); -// EXPECT(assert_equal(expected, *actual2)); + // Check Bayes net + DiscreteBayesNet expectedBayesNet; + expectedBayesNet.add(signature); + expectedBayesNet.add(B | A = "5/3 3/5"); + expectedBayesNet.add(A % "1/1"); + EXPECT(assert_equal(expectedBayesNet, *actual)); - // Test optimization - DiscreteFactor::Values expectedValues; - insert(expectedValues)(0, 0)(1, 0)(2, 0); - DiscreteFactor::sharedValues actualValues = graph.optimize(); - EXPECT(assert_equal(expectedValues, *actualValues)); + // Test eliminateSequential + DiscreteBayesNet::shared_ptr actual2 = graph.eliminateSequential(ordering); + EXPECT(assert_equal(expectedBayesNet, *actual2)); + + // Test mpe + DiscreteValues mpe; + insert(mpe)(0, 0)(1, 0)(2, 0); + auto actualMPE = graph.optimize(); + EXPECT(assert_equal(mpe, actualMPE)); + EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression + + // Test sumProduct alias with all orderings: + auto mpeProbability = expectedBayesNet(mpe); + EXPECT_DOUBLES_EQUAL(0.28125, mpeProbability, 1e-5); // regression + + // Using custom ordering + DiscreteBayesNet bayesNet = graph.sumProduct(ordering); + EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5); + + for (Ordering::OrderingType orderingType : + {Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL, + Ordering::CUSTOM}) { + auto bayesNet = graph.sumProduct(orderingType); + EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5); + } } /* ************************************************************************* */ -TEST( DiscreteFactorGraph, testMPE) -{ +TEST_UNSAFE(DiscreteFactorGraph, testMaxProduct) { // Declare a bunch of keys - DiscreteKey C(0,2), A(1,2), B(2,2); + DiscreteKey C(0, 2), A(1, 2), B(2, 2); // Create Factor graph DiscreteFactorGraph graph; graph.add(C & A, "0.2 0.8 0.3 0.7"); graph.add(C & B, "0.1 0.9 0.4 0.6"); - // graph.product().print(); - // DiscreteSequentialSolver(graph).eliminate()->print(); - DiscreteFactor::sharedValues actualMPE = graph.optimize(); + // Created expected MPE + DiscreteValues mpe; + insert(mpe)(0, 0)(1, 1)(2, 1); - DiscreteFactor::Values expectedMPE; - insert(expectedMPE)(0, 0)(1, 1)(2, 1); - EXPECT(assert_equal(expectedMPE, *actualMPE)); + // Do max-product with different orderings + for (Ordering::OrderingType orderingType : + {Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL, + Ordering::CUSTOM}) { + DiscreteLookupDAG dag = graph.maxProduct(orderingType); + auto actualMPE = dag.argmax(); + EXPECT(assert_equal(mpe, actualMPE)); + auto actualMPE2 = graph.optimize(); // all in one + EXPECT(assert_equal(mpe, actualMPE2)); + } } /* ************************************************************************* */ -TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244) -{ +TEST(DiscreteFactorGraph, marginalIsNotMPE) { + // Declare 2 keys + DiscreteKey A(0, 2), B(1, 2); + + // Create Bayes net such that marginal on A is bigger for 0 than 1, but the + // MPE does not have A=0. + DiscreteBayesNet bayesNet; + bayesNet.add(B | A = "1/1 1/2"); + bayesNet.add(A % "10/9"); + + // The expected MPE is A=1, B=1 + DiscreteValues mpe; + insert(mpe)(0, 1)(1, 1); + + // Which we verify using max-product: + DiscreteFactorGraph graph(bayesNet); + auto actualMPE = graph.optimize(); + EXPECT(assert_equal(mpe, actualMPE)); + EXPECT_DOUBLES_EQUAL(0.315789, graph(mpe), 1e-5); // regression + +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + // Optimize on BayesNet maximizes marginal, then the conditional marginals: + auto notOptimal = bayesNet.optimize(); + EXPECT(graph(notOptimal) < graph(mpe)); + EXPECT_DOUBLES_EQUAL(0.263158, graph(notOptimal), 1e-5); // regression +#endif +} + +/* ************************************************************************* */ +TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) { // The factor graph in Darwiche09book, page 244 - DiscreteKey A(4,2), C(3,2), S(2,2), T1(0,2), T2(1,2); + DiscreteKey A(4, 2), C(3, 2), S(2, 2), T1(0, 2), T2(1, 2); // Create Factor graph DiscreteFactorGraph graph; @@ -206,53 +238,35 @@ TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244) graph.add(C & T1, "0.80 0.20 0.20 0.80"); graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95"); graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0"); - graph.add(A, "1 0");// evidence, A = yes (first choice in Darwiche) - //graph.product().print("Darwiche-product"); - // graph.product().potentials().dot("Darwiche-product"); - // DiscreteSequentialSolver(graph).eliminate()->print(); + graph.add(A, "1 0"); // evidence, A = yes (first choice in Darwiche) - DiscreteFactor::Values expectedMPE; - insert(expectedMPE)(4, 0)(2, 0)(3, 1)(0, 1)(1, 1); + DiscreteValues mpe; + insert(mpe)(4, 0)(2, 1)(3, 1)(0, 1)(1, 1); + EXPECT_DOUBLES_EQUAL(0.33858, graph(mpe), 1e-5); // regression + // You can check visually by printing product: + // graph.product().print("Darwiche-product"); - // Use the solver machinery. - DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); - DiscreteFactor::sharedValues actualMPE = chordal->optimize(); - EXPECT(assert_equal(expectedMPE, *actualMPE)); -// DiscreteConditional::shared_ptr root = chordal->back(); -// EXPECT_DOUBLES_EQUAL(0.4, (*root)(*actualMPE), 1e-9); - - // Let us create the Bayes tree here, just for fun, because we don't use it now -// typedef JunctionTreeOrdered JT; -// GenericMultifrontalSolver solver(graph); -// BayesTreeOrdered::shared_ptr bayesTree = solver.eliminate(&EliminateDiscrete); -//// bayesTree->print("Bayes Tree"); -// EXPECT_LONGS_EQUAL(2,bayesTree->size()); + // Check MPE. + auto actualMPE = graph.optimize(); + EXPECT(assert_equal(mpe, actualMPE)); + // Check Bayes Net Ordering ordering; - ordering += Key(0),Key(1),Key(2),Key(3),Key(4); - DiscreteBayesTree::shared_ptr bayesTree = graph.eliminateMultifrontal(ordering); - // bayesTree->print("Bayes Tree"); - EXPECT_LONGS_EQUAL(2,bayesTree->size()); - -#ifdef OLD -// Create the elimination tree manually -VariableIndexOrdered structure(graph); -typedef EliminationTreeOrdered ETree; -ETree::shared_ptr eTree = ETree::Create(graph, structure); -//eTree->print(">>>>>>>>>>> Elimination Tree <<<<<<<<<<<<<<<<<"); - -// eliminate normally and check solution -DiscreteBayesNet::shared_ptr bayesNet = eTree->eliminate(&EliminateDiscrete); -// bayesNet->print(">>>>>>>>>>>>>> Bayes Net <<<<<<<<<<<<<<<<<<"); -DiscreteFactor::sharedValues actualMPE = optimize(*bayesNet); -EXPECT(assert_equal(expectedMPE, *actualMPE)); - -// Approximate and check solution -// DiscreteBayesNet::shared_ptr approximateNet = eTree->approximate(); -// approximateNet->print(">>>>>>>>>>>>>> Approximate Net <<<<<<<<<<<<<<<<<<"); -// EXPECT(assert_equal(expectedMPE, *actualMPE)); + ordering += Key(0), Key(1), Key(2), Key(3), Key(4); + auto chordal = graph.eliminateSequential(ordering); + EXPECT_LONGS_EQUAL(5, chordal->size()); +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + auto notOptimal = chordal->optimize(); // not MPE ! + EXPECT(graph(notOptimal) < graph(mpe)); #endif + + // Let us create the Bayes tree here, just for fun, because we don't use it + DiscreteBayesTree::shared_ptr bayesTree = + graph.eliminateMultifrontal(ordering); + // bayesTree->print("Bayes Tree"); + EXPECT_LONGS_EQUAL(2, bayesTree->size()); } + #ifdef OLD /* ************************************************************************* */ @@ -359,6 +373,100 @@ cout << unicorns; } #endif +/* ************************************************************************* */ +TEST(DiscreteFactorGraph, Dot) { + // Create Factor graph + DiscreteFactorGraph graph; + DiscreteKey C(0, 2), A(1, 2), B(2, 2); + graph.add(C & A, "0.2 0.8 0.3 0.7"); + graph.add(C & B, "0.1 0.9 0.4 0.6"); + + string actual = graph.dot(); + string expected = + "graph {\n" + " size=\"5,5\";\n" + "\n" + " var0[label=\"0\"];\n" + " var1[label=\"1\"];\n" + " var2[label=\"2\"];\n" + "\n" + " factor0[label=\"\", shape=point];\n" + " var0--factor0;\n" + " var1--factor0;\n" + " factor1[label=\"\", shape=point];\n" + " var0--factor1;\n" + " var2--factor1;\n" + "}\n"; + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +TEST(DiscreteFactorGraph, DotWithNames) { + // Create Factor graph + DiscreteFactorGraph graph; + DiscreteKey C(0, 2), A(1, 2), B(2, 2); + graph.add(C & A, "0.2 0.8 0.3 0.7"); + graph.add(C & B, "0.1 0.9 0.4 0.6"); + + vector names{"C", "A", "B"}; + auto formatter = [names](Key key) { return names[key]; }; + string actual = graph.dot(formatter); + string expected = + "graph {\n" + " size=\"5,5\";\n" + "\n" + " varC[label=\"C\"];\n" + " varA[label=\"A\"];\n" + " varB[label=\"B\"];\n" + "\n" + " factor0[label=\"\", shape=point];\n" + " varC--factor0;\n" + " varA--factor0;\n" + " factor1[label=\"\", shape=point];\n" + " varC--factor1;\n" + " varB--factor1;\n" + "}\n"; + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +// Check markdown representation looks as expected. +TEST(DiscreteFactorGraph, markdown) { + // Create Factor graph + DiscreteFactorGraph graph; + DiscreteKey C(0, 2), A(1, 2), B(2, 2); + graph.add(C & A, "0.2 0.8 0.3 0.7"); + graph.add(C & B, "0.1 0.9 0.4 0.6"); + + string expected = + "`DiscreteFactorGraph` of size 2\n" + "\n" + "factor 0:\n" + "|C|A|value|\n" + "|:-:|:-:|:-:|\n" + "|0|0|0.2|\n" + "|0|1|0.8|\n" + "|1|0|0.3|\n" + "|1|1|0.7|\n" + "\n" + "factor 1:\n" + "|C|B|value|\n" + "|:-:|:-:|:-:|\n" + "|0|0|0.1|\n" + "|0|1|0.9|\n" + "|1|0|0.4|\n" + "|1|1|0.6|\n\n"; + vector names{"C", "A", "B"}; + auto formatter = [names](Key key) { return names[key]; }; + string actual = graph.markdown(formatter); + EXPECT(actual == expected); + + // Make sure values are correctly displayed. + DiscreteValues values; + values[0] = 1; + values[1] = 0; + EXPECT_DOUBLES_EQUAL(0.3, graph[0]->operator()(values), 1e-9); +} /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/discrete/tests/testDiscreteLookupDAG.cpp b/gtsam/discrete/tests/testDiscreteLookupDAG.cpp new file mode 100644 index 000000000..04b859780 --- /dev/null +++ b/gtsam/discrete/tests/testDiscreteLookupDAG.cpp @@ -0,0 +1,58 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * testDiscreteLookupDAG.cpp + * + * @date January, 2022 + * @author Frank Dellaert + */ + +#include +#include +#include + +#include +#include + +using namespace gtsam; +using namespace boost::assign; + +/* ************************************************************************* */ +TEST(DiscreteLookupDAG, argmax) { + using ADT = AlgebraicDecisionTree; + + // Declare 2 keys + DiscreteKey A(0, 2), B(1, 2); + + // Create lookup table corresponding to "marginalIsNotMPE" in testDFG. + DiscreteLookupDAG dag; + + ADT adtB(DiscreteKeys{B, A}, std::vector{0.5, 1. / 3, 0.5, 2. / 3}); + dag.add(1, DiscreteKeys{B, A}, adtB); + + ADT adtA(A, 0.5 * 10 / 19, (2. / 3) * (9. / 19)); + dag.add(1, DiscreteKeys{A}, adtA); + + // The expected MPE is A=1, B=1 + DiscreteValues mpe; + insert(mpe)(0, 1)(1, 1); + + // check: + auto actualMPE = dag.argmax(); + EXPECT(assert_equal(mpe, actualMPE)); +} +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/gtsam/discrete/tests/testDiscreteMarginals.cpp b/gtsam/discrete/tests/testDiscreteMarginals.cpp index e1eb92af3..3208f81c5 100644 --- a/gtsam/discrete/tests/testDiscreteMarginals.cpp +++ b/gtsam/discrete/tests/testDiscreteMarginals.cpp @@ -47,7 +47,7 @@ TEST_UNSAFE( DiscreteMarginals, UGM_small ) { DiscreteMarginals marginals(graph); DiscreteFactor::shared_ptr actualC = marginals(Cathy.first); - DiscreteFactor::Values values; + DiscreteValues values; values[Cathy.first] = 0; EXPECT_DOUBLES_EQUAL( 0.359631, (*actualC)(values), 1e-6); @@ -94,7 +94,7 @@ TEST_UNSAFE( DiscreteMarginals, UGM_chain ) { DiscreteMarginals marginals(graph); DiscreteFactor::shared_ptr actualC = marginals(key[2].first); - DiscreteFactor::Values values; + DiscreteValues values; values[key[2].first] = 0; EXPECT_DOUBLES_EQUAL( 0.03426, (*actualC)(values), 1e-4); @@ -164,11 +164,11 @@ TEST_UNSAFE(DiscreteMarginals, truss2) { graph.add(key[2] & key[3] & key[4], "1 2 3 4 5 6 7 8"); // Calculate the marginals by brute force - vector allPosbValues = - cartesianProduct(key[0] & key[1] & key[2] & key[3] & key[4]); + auto allPosbValues = DiscreteValues::CartesianProduct( + key[0] & key[1] & key[2] & key[3] & key[4]); Vector T = Z_5x1, F = Z_5x1; for (size_t i = 0; i < allPosbValues.size(); ++i) { - DiscreteFactor::Values x = allPosbValues[i]; + DiscreteValues x = allPosbValues[i]; double px = graph(x); for (size_t j = 0; j < 5; j++) if (x[j]) diff --git a/gtsam/discrete/tests/testDiscreteValues.cpp b/gtsam/discrete/tests/testDiscreteValues.cpp new file mode 100644 index 000000000..c8a1fa168 --- /dev/null +++ b/gtsam/discrete/tests/testDiscreteValues.cpp @@ -0,0 +1,76 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * testDiscreteValues.cpp + * + * @date Jan, 2022 + * @author Frank Dellaert + */ + +#include +#include +#include +#include + +#include +using namespace boost::assign; + +using namespace std; +using namespace gtsam; + +/* ************************************************************************* */ +// Check markdown representation with a value formatter. +TEST(DiscreteValues, markdownWithValueFormatter) { + DiscreteValues values; + values[12] = 1; // A + values[5] = 0; // B + string expected = + "|Variable|value|\n" + "|:-:|:-:|\n" + "|B|-|\n" + "|A|One|\n"; + auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; + DiscreteValues::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}}; + string actual = values.markdown(keyFormatter, names); + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +// Check html representation with a value formatter. +TEST(DiscreteValues, htmlWithValueFormatter) { + DiscreteValues values; + values[12] = 1; // A + values[5] = 0; // B + string expected = + "
\n" + "\n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + "
Variablevalue
B-
AOne
\n" + "
"; + auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; + DiscreteValues::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}}; + string actual = values.html(keyFormatter, names); + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/gtsam/discrete/tests/testSignature.cpp b/gtsam/discrete/tests/testSignature.cpp index 049c455f7..737bd8aef 100644 --- a/gtsam/discrete/tests/testSignature.cpp +++ b/gtsam/discrete/tests/testSignature.cpp @@ -32,22 +32,27 @@ DiscreteKey X(0, 2), Y(1, 3), Z(2, 2); /* ************************************************************************* */ TEST(testSignature, simple_conditional) { - Signature sig(X | Y = "1/1 2/3 1/4"); + Signature sig(X, {Y}, "1/1 2/3 1/4"); + CHECK(sig.table()); Signature::Table table = *sig.table(); vector row[3]{{0.5, 0.5}, {0.4, 0.6}, {0.2, 0.8}}; + LONGS_EQUAL(3, table.size()); CHECK(row[0] == table[0]); CHECK(row[1] == table[1]); CHECK(row[2] == table[2]); - DiscreteKey actKey = sig.key(); - LONGS_EQUAL(X.first, actKey.first); - DiscreteKeys actKeys = sig.discreteKeys(); - LONGS_EQUAL(2, actKeys.size()); - LONGS_EQUAL(X.first, actKeys.front().first); - LONGS_EQUAL(Y.first, actKeys.back().first); + CHECK(sig.key() == X); - vector actCpt = sig.cpt(); - EXPECT_LONGS_EQUAL(6, actCpt.size()); + DiscreteKeys keys = sig.discreteKeys(); + LONGS_EQUAL(2, keys.size()); + CHECK(keys[0] == X); + CHECK(keys[1] == Y); + + DiscreteKeys parents = sig.parents(); + LONGS_EQUAL(1, parents.size()); + CHECK(parents[0] == Y); + + EXPECT_LONGS_EQUAL(6, sig.cpt().size()); } /* ************************************************************************* */ @@ -60,16 +65,56 @@ TEST(testSignature, simple_conditional_nonparser) { table += row1, row2, row3; Signature sig(X | Y = table); - DiscreteKey actKey = sig.key(); - EXPECT_LONGS_EQUAL(X.first, actKey.first); + CHECK(sig.key() == X); - DiscreteKeys actKeys = sig.discreteKeys(); - LONGS_EQUAL(2, actKeys.size()); - LONGS_EQUAL(X.first, actKeys.front().first); - LONGS_EQUAL(Y.first, actKeys.back().first); + DiscreteKeys keys = sig.discreteKeys(); + LONGS_EQUAL(2, keys.size()); + CHECK(keys[0] == X); + CHECK(keys[1] == Y); - vector actCpt = sig.cpt(); - EXPECT_LONGS_EQUAL(6, actCpt.size()); + DiscreteKeys parents = sig.parents(); + LONGS_EQUAL(1, parents.size()); + CHECK(parents[0] == Y); + + EXPECT_LONGS_EQUAL(6, sig.cpt().size()); +} + +/* ************************************************************************* */ +DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), D(7, 2); + +// Make sure we can create all signatures for Asia network with constructor. +TEST(testSignature, all_examples) { + DiscreteKey X(6, 2); + Signature a(A, {}, "99/1"); + Signature s(S, {}, "50/50"); + Signature t(T, {A}, "99/1 95/5"); + Signature l(L, {S}, "99/1 90/10"); + Signature b(B, {S}, "70/30 40/60"); + Signature e(E, {T, L}, "F F F 1"); + Signature x(X, {E}, "95/5 2/98"); +} + +// Make sure we can create all signatures for Asia network with operator magic. +TEST(testSignature, all_examples_magic) { + DiscreteKey X(6, 2); + Signature a(A % "99/1"); + Signature s(S % "50/50"); + Signature t(T | A = "99/1 95/5"); + Signature l(L | S = "99/1 90/10"); + Signature b(B | S = "70/30 40/60"); + Signature e((E | T, L) = "F F F 1"); + Signature x(X | E = "95/5 2/98"); +} + +// Check example from docs. +TEST(testSignature, doxygen_example) { + Signature::Table table{{0.9, 0.1}, {0.2, 0.8}, {0.3, 0.7}, {0.1, 0.9}}; + Signature d1(D, {E, B}, table); + Signature d2((D | E, B) = "9/1 2/8 3/7 1/9"); + Signature d3(D, {E, B}, "9/1 2/8 3/7 1/9"); + EXPECT(*(d1.table()) == table); + EXPECT(*(d2.table()) == table); + EXPECT(*(d3.table()) == table); } /* ************************************************************************* */ diff --git a/gtsam/geometry/BearingRange.h b/gtsam/geometry/BearingRange.h index 8db7abffe..95b0232f0 100644 --- a/gtsam/geometry/BearingRange.h +++ b/gtsam/geometry/BearingRange.h @@ -22,6 +22,7 @@ #include #include #include +#include #include namespace gtsam { diff --git a/gtsam/geometry/Cal3.h b/gtsam/geometry/Cal3.h index 08ce4c1e6..1690615dd 100644 --- a/gtsam/geometry/Cal3.h +++ b/gtsam/geometry/Cal3.h @@ -170,9 +170,9 @@ class GTSAM_EXPORT Cal3 { return K; } -#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V41 +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 /** @deprecated The following function has been deprecated, use K above */ - Matrix3 matrix() const { return K(); } + Matrix3 GTSAM_DEPRECATED matrix() const { return K(); } #endif /// Return inverted calibration matrix inv(K) diff --git a/gtsam/geometry/Cal3Bundler.h b/gtsam/geometry/Cal3Bundler.h index 0d7c1be9d..82b5ec91d 100644 --- a/gtsam/geometry/Cal3Bundler.h +++ b/gtsam/geometry/Cal3Bundler.h @@ -41,6 +41,9 @@ class GTSAM_EXPORT Cal3Bundler : public Cal3 { public: enum { dimension = 3 }; + ///< shared pointer to stereo calibration object + using shared_ptr = boost::shared_ptr; + /// @name Standard Constructors /// @{ @@ -97,12 +100,12 @@ class GTSAM_EXPORT Cal3Bundler : public Cal3 { Vector3 vector() const; -#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V41 +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 /// get parameter u0 - inline double u0() const { return u0_; } + inline double GTSAM_DEPRECATED u0() const { return u0_; } /// get parameter v0 - inline double v0() const { return v0_; } + inline double GTSAM_DEPRECATED v0() const { return v0_; } #endif /** diff --git a/gtsam/geometry/Cal3DS2.h b/gtsam/geometry/Cal3DS2.h index f756cba5e..039497cc9 100644 --- a/gtsam/geometry/Cal3DS2.h +++ b/gtsam/geometry/Cal3DS2.h @@ -21,6 +21,7 @@ #pragma once #include +#include namespace gtsam { @@ -37,6 +38,9 @@ class GTSAM_EXPORT Cal3DS2 : public Cal3DS2_Base { public: enum { dimension = 9 }; + ///< shared pointer to stereo calibration object + using shared_ptr = boost::shared_ptr; + /// @name Standard Constructors /// @{ diff --git a/gtsam/geometry/Cal3DS2_Base.h b/gtsam/geometry/Cal3DS2_Base.h index a9b09cf46..1b2291e07 100644 --- a/gtsam/geometry/Cal3DS2_Base.h +++ b/gtsam/geometry/Cal3DS2_Base.h @@ -21,6 +21,7 @@ #include #include +#include namespace gtsam { @@ -47,6 +48,9 @@ class GTSAM_EXPORT Cal3DS2_Base : public Cal3 { public: enum { dimension = 9 }; + ///< shared pointer to stereo calibration object + using shared_ptr = boost::shared_ptr; + /// @name Standard Constructors /// @{ diff --git a/gtsam/geometry/Cal3Fisheye.cpp b/gtsam/geometry/Cal3Fisheye.cpp index 52d475d5d..fd2c7ab65 100644 --- a/gtsam/geometry/Cal3Fisheye.cpp +++ b/gtsam/geometry/Cal3Fisheye.cpp @@ -46,9 +46,9 @@ double Cal3Fisheye::Scaling(double r) { /* ************************************************************************* */ Point2 Cal3Fisheye::uncalibrate(const Point2& p, OptionalJacobian<2, 9> H1, OptionalJacobian<2, 2> H2) const { - const double xi = p.x(), yi = p.y(); + const double xi = p.x(), yi = p.y(), zi = 1; const double r2 = xi * xi + yi * yi, r = sqrt(r2); - const double t = atan(r); + const double t = atan2(r, zi); const double t2 = t * t, t4 = t2 * t2, t6 = t2 * t4, t8 = t4 * t4; Vector5 K, T; K << 1, k1_, k2_, k3_, k4_; @@ -76,28 +76,32 @@ Point2 Cal3Fisheye::uncalibrate(const Point2& p, OptionalJacobian<2, 9> H1, // Derivative for points in intrinsic coords (2 by 2) if (H2) { - const double dtd_dt = - 1 + 3 * k1_ * t2 + 5 * k2_ * t4 + 7 * k3_ * t6 + 9 * k4_ * t8; - const double dt_dr = 1 / (1 + r2); - const double rinv = 1 / r; - const double dr_dxi = xi * rinv; - const double dr_dyi = yi * rinv; - const double dtd_dxi = dtd_dt * dt_dr * dr_dxi; - const double dtd_dyi = dtd_dt * dt_dr * dr_dyi; + if (r2==0) { + *H2 = DK; + } else { + const double dtd_dt = + 1 + 3 * k1_ * t2 + 5 * k2_ * t4 + 7 * k3_ * t6 + 9 * k4_ * t8; + const double R2 = r2 + zi*zi; + const double dt_dr = zi / R2; + const double rinv = 1 / r; + const double dr_dxi = xi * rinv; + const double dr_dyi = yi * rinv; + const double dtd_dr = dtd_dt * dt_dr; + + const double c2 = dr_dxi * dr_dxi; + const double s2 = dr_dyi * dr_dyi; + const double cs = dr_dxi * dr_dyi; - const double td = t * K.dot(T); - const double rrinv = 1 / r2; - const double dxd_dxi = - dtd_dxi * dr_dxi + td * rinv - td * xi * rrinv * dr_dxi; - const double dxd_dyi = dtd_dyi * dr_dxi - td * xi * rrinv * dr_dyi; - const double dyd_dxi = dtd_dxi * dr_dyi - td * yi * rrinv * dr_dxi; - const double dyd_dyi = - dtd_dyi * dr_dyi + td * rinv - td * yi * rrinv * dr_dyi; + const double dxd_dxi = dtd_dr * c2 + s * (1 - c2); + const double dxd_dyi = (dtd_dr - s) * cs; + const double dyd_dxi = dxd_dyi; + const double dyd_dyi = dtd_dr * s2 + s * (1 - s2); - Matrix2 DR; - DR << dxd_dxi, dxd_dyi, dyd_dxi, dyd_dyi; + Matrix2 DR; + DR << dxd_dxi, dxd_dyi, dyd_dxi, dyd_dyi; - *H2 = DK * DR; + *H2 = DK * DR; + } } return uv; diff --git a/gtsam/geometry/Cal3Fisheye.h b/gtsam/geometry/Cal3Fisheye.h index a8c9fa182..c0caecaa1 100644 --- a/gtsam/geometry/Cal3Fisheye.h +++ b/gtsam/geometry/Cal3Fisheye.h @@ -22,6 +22,8 @@ #include #include +#include + #include namespace gtsam { diff --git a/gtsam/geometry/Cal3Unified.h b/gtsam/geometry/Cal3Unified.h index f07ca0a54..e93d313c8 100644 --- a/gtsam/geometry/Cal3Unified.h +++ b/gtsam/geometry/Cal3Unified.h @@ -52,6 +52,9 @@ class GTSAM_EXPORT Cal3Unified : public Cal3DS2_Base { public: enum { dimension = 10 }; + ///< shared pointer to stereo calibration object + using shared_ptr = boost::shared_ptr; + /// @name Standard Constructors /// @{ diff --git a/gtsam/geometry/Cyclic.h b/gtsam/geometry/Cyclic.h index 35578ffe0..065cd0140 100644 --- a/gtsam/geometry/Cyclic.h +++ b/gtsam/geometry/Cyclic.h @@ -15,6 +15,8 @@ * @author Frank Dellaert **/ +#pragma once + #include #include diff --git a/gtsam/geometry/Line3.cpp b/gtsam/geometry/Line3.cpp index e3b4841e0..9e7b2e13e 100644 --- a/gtsam/geometry/Line3.cpp +++ b/gtsam/geometry/Line3.cpp @@ -117,4 +117,4 @@ Line3 transformTo(const Pose3 &wTc, const Line3 &wL, return Line3(cRl, c_ab[0], c_ab[1]); } -} \ No newline at end of file +} // namespace gtsam diff --git a/gtsam/geometry/Line3.h b/gtsam/geometry/Line3.h index f70e13ca7..78827274a 100644 --- a/gtsam/geometry/Line3.h +++ b/gtsam/geometry/Line3.h @@ -21,12 +21,27 @@ namespace gtsam { +class Line3; + +/** + * Transform a line from world to camera frame + * @param wTc - Pose3 of camera in world frame + * @param wL - Line3 in world frame + * @param Dpose - OptionalJacobian of transformed line with respect to p + * @param Dline - OptionalJacobian of transformed line with respect to l + * @return Transformed line in camera frame + */ +GTSAM_EXPORT Line3 transformTo(const Pose3 &wTc, const Line3 &wL, + OptionalJacobian<4, 6> Dpose = boost::none, + OptionalJacobian<4, 4> Dline = boost::none); + + /** * A 3D line (R,a,b) : (Rot3,Scalar,Scalar) * @addtogroup geometry * \nosubgrouping */ -class Line3 { +class GTSAM_EXPORT Line3 { private: Rot3 R_; // Rotation of line about x and y in world frame double a_, b_; // Intersection of line with the world x-y plane rotated by R_ @@ -136,18 +151,6 @@ class Line3 { OptionalJacobian<4, 4> Dline); }; -/** - * Transform a line from world to camera frame - * @param wTc - Pose3 of camera in world frame - * @param wL - Line3 in world frame - * @param Dpose - OptionalJacobian of transformed line with respect to p - * @param Dline - OptionalJacobian of transformed line with respect to l - * @return Transformed line in camera frame - */ -Line3 transformTo(const Pose3 &wTc, const Line3 &wL, - OptionalJacobian<4, 6> Dpose = boost::none, - OptionalJacobian<4, 4> Dline = boost::none); - template<> struct traits : public internal::Manifold {}; diff --git a/gtsam/geometry/PinholeCamera.h b/gtsam/geometry/PinholeCamera.h index c1f0b6b3f..c20e90819 100644 --- a/gtsam/geometry/PinholeCamera.h +++ b/gtsam/geometry/PinholeCamera.h @@ -30,7 +30,7 @@ namespace gtsam { * \nosubgrouping */ template -class GTSAM_EXPORT PinholeCamera: public PinholeBaseK { +class PinholeCamera: public PinholeBaseK { public: @@ -230,13 +230,15 @@ public: Point2 _project2(const POINT& pw, OptionalJacobian<2, dimension> Dcamera, OptionalJacobian<2, FixedDimension::value> Dpoint) const { // We just call 3-derivative version in Base - Matrix26 Dpose; - Eigen::Matrix Dcal; - Point2 pi = Base::project(pw, Dcamera ? &Dpose : 0, Dpoint, - Dcamera ? &Dcal : 0); - if (Dcamera) + if (Dcamera){ + Matrix26 Dpose; + Eigen::Matrix Dcal; + const Point2 pi = Base::project(pw, Dpose, Dpoint, Dcal); *Dcamera << Dpose, Dcal; - return pi; + return pi; + } else { + return Base::project(pw, boost::none, Dpoint, boost::none); + } } /// project a 3D point from world coordinates into the image @@ -312,6 +314,16 @@ public: return range(camera.pose(), Dcamera, Dother); } + /// for Linear Triangulation + Matrix34 cameraProjectionMatrix() const { + return K_.K() * PinholeBase::pose().inverse().matrix().block(0, 0, 3, 4); + } + + /// for Nonlinear Triangulation + Vector defaultErrorWhenTriangulatingBehindCamera() const { + return Eigen::Matrix::dimension,1>::Constant(2.0 * K_.fx());; + } + private: /** Serialization function */ diff --git a/gtsam/geometry/PinholePose.h b/gtsam/geometry/PinholePose.h index cc6435a58..7b92be5d5 100644 --- a/gtsam/geometry/PinholePose.h +++ b/gtsam/geometry/PinholePose.h @@ -31,11 +31,11 @@ namespace gtsam { * \nosubgrouping */ template -class GTSAM_EXPORT PinholeBaseK: public PinholeBase { +class PinholeBaseK: public PinholeBase { private: - GTSAM_CONCEPT_MANIFOLD_TYPE(CALIBRATION); + GTSAM_CONCEPT_MANIFOLD_TYPE(CALIBRATION) // Get dimensions of calibration type at compile time static const int DimK = FixedDimension::value; @@ -121,6 +121,13 @@ public: return _project(pw, Dpose, Dpoint, Dcal); } + /// project a 3D point from world coordinates into the image + Point2 reprojectionError(const Point3& pw, const Point2& measured, OptionalJacobian<2, 6> Dpose = boost::none, + OptionalJacobian<2, 3> Dpoint = boost::none, + OptionalJacobian<2, DimK> Dcal = boost::none) const { + return Point2(_project(pw, Dpose, Dpoint, Dcal) - measured); + } + /// project a point at infinity from world coordinates into the image Point2 project(const Unit3& pw, OptionalJacobian<2, 6> Dpose = boost::none, OptionalJacobian<2, 2> Dpoint = boost::none, @@ -159,7 +166,6 @@ public: return result; } - /// backproject a 2-dimensional point to a 3-dimensional point at infinity Unit3 backprojectPointAtInfinity(const Point2& p) const { const Point2 pn = calibration().calibrate(p); @@ -410,6 +416,16 @@ public: return PinholePose(); // assumes that the default constructor is valid } + /// for Linear Triangulation + Matrix34 cameraProjectionMatrix() const { + Matrix34 P = Matrix34(PinholeBase::pose().inverse().matrix().block(0, 0, 3, 4)); + return K_->K() * P; + } + + /// for Nonlinear Triangulation + Vector defaultErrorWhenTriangulatingBehindCamera() const { + return Eigen::Matrix::dimension,1>::Constant(2.0 * K_->fx());; + } /// @} private: diff --git a/gtsam/geometry/Point2.cpp b/gtsam/geometry/Point2.cpp index d8060cfcf..06c32526b 100644 --- a/gtsam/geometry/Point2.cpp +++ b/gtsam/geometry/Point2.cpp @@ -113,6 +113,18 @@ list circleCircleIntersection(Point2 c1, double r1, Point2 c2, return circleCircleIntersection(c1, c2, fh); } +Point2Pair means(const std::vector &abPointPairs) { + const size_t n = abPointPairs.size(); + if (n == 0) throw std::invalid_argument("Point2::mean input Point2Pair vector is empty"); + Point2 aSum(0, 0), bSum(0, 0); + for (const Point2Pair &abPair : abPointPairs) { + aSum += abPair.first; + bSum += abPair.second; + } + const double f = 1.0 / n; + return {aSum * f, bSum * f}; +} + /* ************************************************************************* */ ostream &operator<<(ostream &os, const gtsam::Point2Pair &p) { os << p.first << " <-> " << p.second; diff --git a/gtsam/geometry/Point2.h b/gtsam/geometry/Point2.h index cdb9f4480..d8b6daca8 100644 --- a/gtsam/geometry/Point2.h +++ b/gtsam/geometry/Point2.h @@ -71,6 +71,9 @@ GTSAM_EXPORT boost::optional circleCircleIntersection(double R_d, double * @return list of solutions (0,1, or 2). Identical circles will return empty list, as well. */ GTSAM_EXPORT std::list circleCircleIntersection(Point2 c1, Point2 c2, boost::optional fh); + +/// Calculate the two means of a set of Point2 pairs +GTSAM_EXPORT Point2Pair means(const std::vector &abPointPairs); /** * @brief Intersect 2 circles diff --git a/gtsam/geometry/Point3.cpp b/gtsam/geometry/Point3.cpp index a565ac140..ef91108eb 100644 --- a/gtsam/geometry/Point3.cpp +++ b/gtsam/geometry/Point3.cpp @@ -17,6 +17,7 @@ #include #include #include +#include using namespace std; diff --git a/gtsam/geometry/Pose2.cpp b/gtsam/geometry/Pose2.cpp index 8dafffee8..b37674b92 100644 --- a/gtsam/geometry/Pose2.cpp +++ b/gtsam/geometry/Pose2.cpp @@ -28,7 +28,7 @@ using namespace std; namespace gtsam { /** instantiate concept checks */ -GTSAM_CONCEPT_POSE_INST(Pose2); +GTSAM_CONCEPT_POSE_INST(Pose2) static const Rot2 R_PI_2(Rot2::fromCosSin(0., 1.)); @@ -213,6 +213,14 @@ Point2 Pose2::transformTo(const Point2& point, return q; } +Matrix Pose2::transformTo(const Matrix& points) const { + if (points.rows() != 2) { + throw std::invalid_argument("Pose2:transformTo expects 2*N matrix."); + } + const Matrix2 Rt = rotation().transpose(); + return Rt * (points.colwise() - t_); // Eigen broadcasting! +} + /* ************************************************************************* */ // see doc/math.lyx, SE(2) section Point2 Pose2::transformFrom(const Point2& point, @@ -224,6 +232,15 @@ Point2 Pose2::transformFrom(const Point2& point, return q + t_; } + +Matrix Pose2::transformFrom(const Matrix& points) const { + if (points.rows() != 2) { + throw std::invalid_argument("Pose2:transformFrom expects 2*N matrix."); + } + const Matrix2 R = rotation().matrix(); + return (R * points).colwise() + t_; // Eigen broadcasting! +} + /* ************************************************************************* */ Rot2 Pose2::bearing(const Point2& point, OptionalJacobian<1, 3> Hpose, OptionalJacobian<1, 2> Hpoint) const { @@ -292,54 +309,77 @@ double Pose2::range(const Pose2& pose, } /* ************************************************************************* - * New explanation, from scan.ml - * It finds the angle using a linear method: - * q = Pose2::transformFrom(p) = t + R*p + * Align finds the angle using a linear method: + * a = Pose2::transformFrom(b) = t + R*b * We need to remove the centroids from the data to find the rotation - * using dp=[dpx;dpy] and q=[dqx;dqy] we have - * |dqx| |c -s| |dpx| |dpx -dpy| |c| + * using db=[dbx;dby] and a=[dax;day] we have + * |dax| |c -s| |dbx| |dbx -dby| |c| * | | = | | * | | = | | * | | = H_i*cs - * |dqy| |s c| |dpy| |dpy dpx| |s| + * |day| |s c| |dby| |dby dbx| |s| * where the Hi are the 2*2 matrices. Then we will minimize the criterion - * J = \sum_i norm(q_i - H_i * cs) + * J = \sum_i norm(a_i - H_i * cs) * Taking the derivative with respect to cs and setting to zero we have - * cs = (\sum_i H_i' * q_i)/(\sum H_i'*H_i) + * cs = (\sum_i H_i' * a_i)/(\sum H_i'*H_i) * The hessian is diagonal and just divides by a constant, but this * normalization constant is irrelevant, since we take atan2. - * i.e., cos ~ sum(dpx*dqx + dpy*dqy) and sin ~ sum(-dpy*dqx + dpx*dqy) + * i.e., cos ~ sum(dbx*dax + dby*day) and sin ~ sum(-dby*dax + dbx*day) * The translation is then found from the centroids - * as they also satisfy cq = t + R*cp, hence t = cq - R*cp + * as they also satisfy ca = t + R*cb, hence t = ca - R*cb */ -boost::optional align(const vector& pairs) { - - size_t n = pairs.size(); - if (n<2) return boost::none; // we need at least two pairs +boost::optional Pose2::Align(const Point2Pairs &ab_pairs) { + const size_t n = ab_pairs.size(); + if (n < 2) { + return boost::none; // we need at least 2 pairs + } // calculate centroids - Point2 cp(0,0), cq(0,0); - for(const Point2Pair& pair: pairs) { - cp += pair.first; - cq += pair.second; + Point2 ca(0, 0), cb(0, 0); + for (const Point2Pair& pair : ab_pairs) { + ca += pair.first; + cb += pair.second; } - double f = 1.0/n; - cp *= f; cq *= f; + const double f = 1.0/n; + ca *= f; + cb *= f; // calculate cos and sin - double c=0,s=0; - for(const Point2Pair& pair: pairs) { - Point2 dp = pair.first - cp; - Point2 dq = pair.second - cq; - c += dp.x() * dq.x() + dp.y() * dq.y(); - s += -dp.y() * dq.x() + dp.x() * dq.y(); + double c = 0, s = 0; + for (const Point2Pair& pair : ab_pairs) { + Point2 da = pair.first - ca; + Point2 db = pair.second - cb; + c += db.x() * da.x() + db.y() * da.y(); + s += -db.y() * da.x() + db.x() * da.y(); } // calculate angle and translation - double theta = atan2(s,c); - Rot2 R = Rot2::fromAngle(theta); - Point2 t = cq - R*cp; + const double theta = atan2(s, c); + const Rot2 R = Rot2::fromAngle(theta); + const Point2 t = ca - R*cb; return Pose2(R, t); } +boost::optional Pose2::Align(const Matrix& a, const Matrix& b) { + if (a.rows() != 2 || b.rows() != 2 || a.cols() != b.cols()) { + throw std::invalid_argument( + "Pose2:Align expects 2*N matrices of equal shape."); + } + Point2Pairs ab_pairs; + for (Eigen::Index j = 0; j < a.cols(); j++) { + ab_pairs.emplace_back(a.col(j), b.col(j)); + } + return Pose2::Align(ab_pairs); +} + +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 +boost::optional align(const Point2Pairs& ba_pairs) { + Point2Pairs ab_pairs; + for (const Point2Pair &baPair : ba_pairs) { + ab_pairs.emplace_back(baPair.second, baPair.first); + } + return Pose2::Align(ab_pairs); +} +#endif + /* ************************************************************************* */ } // namespace gtsam diff --git a/gtsam/geometry/Pose2.h b/gtsam/geometry/Pose2.h index a54951728..466c5a42a 100644 --- a/gtsam/geometry/Pose2.h +++ b/gtsam/geometry/Pose2.h @@ -92,6 +92,18 @@ public: *this = Expmap(v); } + /** + * Create Pose2 by aligning two point pairs + * A pose aTb is estimated between pairs (a_point, b_point) such that + * a_point = aTb * b_point + * Note this allows for noise on the points but in that case the mapping + * will not be exact. + */ + static boost::optional Align(const Point2Pairs& abPointPairs); + + // Version of Pose2::Align that takes 2 matrices. + static boost::optional Align(const Matrix& a, const Matrix& b); + /// @} /// @name Testable /// @{ @@ -199,13 +211,29 @@ public: OptionalJacobian<2, 3> Dpose = boost::none, OptionalJacobian<2, 2> Dpoint = boost::none) const; + /** + * @brief transform many points in world coordinates and transform to Pose. + * @param points 2*N matrix in world coordinates + * @return points in Pose coordinates, as 2*N Matrix + */ + Matrix transformTo(const Matrix& points) const; + /** Return point coordinates in global frame */ GTSAM_EXPORT Point2 transformFrom(const Point2& point, OptionalJacobian<2, 3> Dpose = boost::none, OptionalJacobian<2, 2> Dpoint = boost::none) const; + /** + * @brief transform many points in Pose coordinates and transform to world. + * @param points 2*N matrix in Pose coordinates + * @return points in world coordinates, as 2*N Matrix + */ + Matrix transformFrom(const Matrix& points) const; + /** syntactic sugar for transformFrom */ - inline Point2 operator*(const Point2& point) const { return transformFrom(point);} + inline Point2 operator*(const Point2& point) const { + return transformFrom(point); + } /// @} /// @name Standard Interface @@ -315,12 +343,19 @@ inline Matrix wedge(const Vector& xi) { return Matrix(Pose2::wedge(xi(0),xi(1),xi(2))).eval(); } +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 /** + * @deprecated Use static constructor (with reversed pairs!) * Calculate pose between a vector of 2D point correspondences (p,q) * where q = Pose2::transformFrom(p) = t + R*p */ -typedef std::pair Point2Pair; -GTSAM_EXPORT boost::optional align(const std::vector& pairs); +GTSAM_EXPORT boost::optional +GTSAM_DEPRECATED align(const Point2Pairs& pairs); +#endif + +// Convenience typedef +using Pose2Pair = std::pair; +using Pose2Pairs = std::vector; template <> struct traits : public internal::LieGroup {}; diff --git a/gtsam/geometry/Pose3.cpp b/gtsam/geometry/Pose3.cpp index c183e32ed..2da51a625 100644 --- a/gtsam/geometry/Pose3.cpp +++ b/gtsam/geometry/Pose3.cpp @@ -30,7 +30,7 @@ using std::vector; using Point3Pairs = vector; /** instantiate concept checks */ -GTSAM_CONCEPT_POSE_INST(Pose3); +GTSAM_CONCEPT_POSE_INST(Pose3) /* ************************************************************************* */ Pose3::Pose3(const Pose2& pose2) : @@ -59,10 +59,50 @@ Matrix6 Pose3::AdjointMap() const { const Matrix3 R = R_.matrix(); Matrix3 A = skewSymmetric(t_.x(), t_.y(), t_.z()) * R; Matrix6 adj; - adj << R, Z_3x3, A, R; + adj << R, Z_3x3, A, R; // Gives [R 0; A R] return adj; } +/* ************************************************************************* */ +// Calculate AdjointMap applied to xi_b, with Jacobians +Vector6 Pose3::Adjoint(const Vector6& xi_b, OptionalJacobian<6, 6> H_pose, + OptionalJacobian<6, 6> H_xib) const { + const Matrix6 Ad = AdjointMap(); + + // Jacobians + // D1 Ad_T(xi_b) = D1 Ad_T Ad_I(xi_b) = Ad_T * D1 Ad_I(xi_b) = Ad_T * ad_xi_b + // D2 Ad_T(xi_b) = Ad_T + // See docs/math.pdf for more details. + // In D1 calculation, we could be more efficient by writing it out, but do not + // for readability + if (H_pose) *H_pose = -Ad * adjointMap(xi_b); + if (H_xib) *H_xib = Ad; + + return Ad * xi_b; +} + +/* ************************************************************************* */ +/// The dual version of Adjoint +Vector6 Pose3::AdjointTranspose(const Vector6& x, OptionalJacobian<6, 6> H_pose, + OptionalJacobian<6, 6> H_x) const { + const Matrix6 Ad = AdjointMap(); + const Vector6 AdTx = Ad.transpose() * x; + + // Jacobians + // See docs/math.pdf for more details. + if (H_pose) { + const auto w_T_hat = skewSymmetric(AdTx.head<3>()), + v_T_hat = skewSymmetric(AdTx.tail<3>()); + *H_pose << w_T_hat, v_T_hat, // + /* */ v_T_hat, Z_3x3; + } + if (H_x) { + *H_x = Ad.transpose(); + } + + return AdTx; +} + /* ************************************************************************* */ Matrix6 Pose3::adjointMap(const Vector6& xi) { Matrix3 w_hat = skewSymmetric(xi(0), xi(1), xi(2)); @@ -75,7 +115,7 @@ Matrix6 Pose3::adjointMap(const Vector6& xi) { /* ************************************************************************* */ Vector6 Pose3::adjoint(const Vector6& xi, const Vector6& y, - OptionalJacobian<6, 6> Hxi) { + OptionalJacobian<6, 6> Hxi, OptionalJacobian<6, 6> H_y) { if (Hxi) { Hxi->setZero(); for (int i = 0; i < 6; ++i) { @@ -86,12 +126,14 @@ Vector6 Pose3::adjoint(const Vector6& xi, const Vector6& y, Hxi->col(i) = Gi * y; } } - return adjointMap(xi) * y; + const Matrix6& ad_xi = adjointMap(xi); + if (H_y) *H_y = ad_xi; + return ad_xi * y; } /* ************************************************************************* */ Vector6 Pose3::adjointTranspose(const Vector6& xi, const Vector6& y, - OptionalJacobian<6, 6> Hxi) { + OptionalJacobian<6, 6> Hxi, OptionalJacobian<6, 6> H_y) { if (Hxi) { Hxi->setZero(); for (int i = 0; i < 6; ++i) { @@ -102,7 +144,9 @@ Vector6 Pose3::adjointTranspose(const Vector6& xi, const Vector6& y, Hxi->col(i) = GTi * y; } } - return adjointMap(xi).transpose() * y; + const Matrix6& adT_xi = adjointMap(xi).transpose(); + if (H_y) *H_y = adT_xi; + return adT_xi * y; } /* ************************************************************************* */ @@ -310,6 +354,14 @@ Point3 Pose3::transformFrom(const Point3& point, OptionalJacobian<3, 6> Hself, return R_ * point + t_; } +Matrix Pose3::transformFrom(const Matrix& points) const { + if (points.rows() != 3) { + throw std::invalid_argument("Pose3:transformFrom expects 3*N matrix."); + } + const Matrix3 R = R_.matrix(); + return (R * points).colwise() + t_; // Eigen broadcasting! +} + /* ************************************************************************* */ Point3 Pose3::transformTo(const Point3& point, OptionalJacobian<3, 6> Hself, OptionalJacobian<3, 3> Hpoint) const { @@ -330,6 +382,14 @@ Point3 Pose3::transformTo(const Point3& point, OptionalJacobian<3, 6> Hself, return q; } +Matrix Pose3::transformTo(const Matrix& points) const { + if (points.rows() != 3) { + throw std::invalid_argument("Pose3:transformTo expects 3*N matrix."); + } + const Matrix3 Rt = R_.transpose(); + return Rt * (points.colwise() - t_); // Eigen broadcasting! +} + /* ************************************************************************* */ double Pose3::range(const Point3& point, OptionalJacobian<1, 6> Hself, OptionalJacobian<1, 3> Hpoint) const { @@ -387,7 +447,7 @@ Unit3 Pose3::bearing(const Pose3& pose, OptionalJacobian<2, 6> Hself, boost::optional Pose3::Align(const Point3Pairs &abPointPairs) { const size_t n = abPointPairs.size(); if (n < 3) { - return boost::none; // we need at least three pairs + return boost::none; // we need at least three pairs } // calculate centroids @@ -407,6 +467,19 @@ boost::optional Pose3::Align(const Point3Pairs &abPointPairs) { return Pose3(aRb, aTb); } +boost::optional Pose3::Align(const Matrix& a, const Matrix& b) { + if (a.rows() != 3 || b.rows() != 3 || a.cols() != b.cols()) { + throw std::invalid_argument( + "Pose3:Align expects 3*N matrices of equal shape."); + } + Point3Pairs abPointPairs; + for (Eigen::Index j = 0; j < a.cols(); j++) { + abPointPairs.emplace_back(a.col(j), b.col(j)); + } + return Pose3::Align(abPointPairs); +} + +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 boost::optional align(const Point3Pairs &baPointPairs) { Point3Pairs abPointPairs; for (const Point3Pair &baPair : baPointPairs) { @@ -414,6 +487,7 @@ boost::optional align(const Point3Pairs &baPointPairs) { } return Pose3::Align(abPointPairs); } +#endif /* ************************************************************************* */ std::ostream &operator<<(std::ostream &os, const Pose3& pose) { diff --git a/gtsam/geometry/Pose3.h b/gtsam/geometry/Pose3.h index d76e1b41a..c36047349 100644 --- a/gtsam/geometry/Pose3.h +++ b/gtsam/geometry/Pose3.h @@ -85,6 +85,9 @@ public: */ static boost::optional Align(const std::vector& abPointPairs); + // Version of Pose3::Align that takes 2 matrices. + static boost::optional Align(const Matrix& a, const Matrix& b); + /// @} /// @name Testable /// @{ @@ -145,15 +148,22 @@ public: * Calculate Adjoint map, transforming a twist in this pose's (i.e, body) frame to the world spatial frame * Ad_pose is 6*6 matrix that when applied to twist xi \f$ [R_x,R_y,R_z,T_x,T_y,T_z] \f$, returns Ad_pose(xi) */ - Matrix6 AdjointMap() const; /// FIXME Not tested - marked as incorrect + Matrix6 AdjointMap() const; /** - * Apply this pose's AdjointMap Ad_g to a twist \f$ \xi_b \f$, i.e. a body-fixed velocity, transforming it to the spatial frame + * Apply this pose's AdjointMap Ad_g to a twist \f$ \xi_b \f$, i.e. a + * body-fixed velocity, transforming it to the spatial frame * \f$ \xi^s = g*\xi^b*g^{-1} = Ad_g * \xi^b \f$ + * Note that H_xib = AdjointMap() */ - Vector6 Adjoint(const Vector6& xi_b) const { - return AdjointMap() * xi_b; - } /// FIXME Not tested - marked as incorrect + Vector6 Adjoint(const Vector6& xi_b, + OptionalJacobian<6, 6> H_this = boost::none, + OptionalJacobian<6, 6> H_xib = boost::none) const; + + /// The dual version of Adjoint + Vector6 AdjointTranspose(const Vector6& x, + OptionalJacobian<6, 6> H_this = boost::none, + OptionalJacobian<6, 6> H_x = boost::none) const; /** * Compute the [ad(w,v)] operator as defined in [Kobilarov09siggraph], pg 11 @@ -170,13 +180,14 @@ public: * and its inverse transpose in the discrete Euler Poincare' (DEP) operator. * */ - static Matrix6 adjointMap(const Vector6 &xi); + static Matrix6 adjointMap(const Vector6& xi); /** * Action of the adjointMap on a Lie-algebra vector y, with optional derivatives */ - static Vector6 adjoint(const Vector6 &xi, const Vector6 &y, - OptionalJacobian<6, 6> Hxi = boost::none); + static Vector6 adjoint(const Vector6& xi, const Vector6& y, + OptionalJacobian<6, 6> Hxi = boost::none, + OptionalJacobian<6, 6> H_y = boost::none); // temporary fix for wrappers until case issue is resolved static Matrix6 adjointMap_(const Vector6 &xi) { return adjointMap(xi);} @@ -186,7 +197,8 @@ public: * The dual version of adjoint action, acting on the dual space of the Lie-algebra vector space. */ static Vector6 adjointTranspose(const Vector6& xi, const Vector6& y, - OptionalJacobian<6, 6> Hxi = boost::none); + OptionalJacobian<6, 6> Hxi = boost::none, + OptionalJacobian<6, 6> H_y = boost::none); /// Derivative of Expmap static Matrix6 ExpmapDerivative(const Vector6& xi); @@ -240,6 +252,13 @@ public: Point3 transformFrom(const Point3& point, OptionalJacobian<3, 6> Hself = boost::none, OptionalJacobian<3, 3> Hpoint = boost::none) const; + /** + * @brief transform many points in Pose coordinates and transform to world. + * @param points 3*N matrix in Pose coordinates + * @return points in world coordinates, as 3*N Matrix + */ + Matrix transformFrom(const Matrix& points) const; + /** syntactic sugar for transformFrom */ inline Point3 operator*(const Point3& point) const { return transformFrom(point); @@ -255,6 +274,13 @@ public: Point3 transformTo(const Point3& point, OptionalJacobian<3, 6> Hself = boost::none, OptionalJacobian<3, 3> Hpoint = boost::none) const; + /** + * @brief transform many points in world coordinates and transform to Pose. + * @param points 3*N matrix in world coordinates + * @return points in Pose coordinates, as 3*N Matrix + */ + Matrix transformTo(const Matrix& points) const; + /// @} /// @name Standard Interface /// @{ diff --git a/gtsam/geometry/Quaternion.h b/gtsam/geometry/Quaternion.h index 1557a09db..2ef47d58e 100644 --- a/gtsam/geometry/Quaternion.h +++ b/gtsam/geometry/Quaternion.h @@ -117,13 +117,23 @@ struct traits { omega = (-8. / 3. - 2. / 3. * qw) * q.vec(); } else { // Normal, away from zero case - _Scalar angle = 2 * acos(qw), s = sqrt(1 - qw * qw); - // Important: convert to [-pi,pi] to keep error continuous - if (angle > M_PI) - angle -= twoPi; - else if (angle < -M_PI) - angle += twoPi; - omega = (angle / s) * q.vec(); + if (qw > 0) { + _Scalar angle = 2 * acos(qw), s = sqrt(1 - qw * qw); + // Important: convert to [-pi,pi] to keep error continuous + if (angle > M_PI) + angle -= twoPi; + else if (angle < -M_PI) + angle += twoPi; + omega = (angle / s) * q.vec(); + } else { + // Make sure that we are using a canonical quaternion with w > 0 + _Scalar angle = 2 * acos(-qw), s = sqrt(1 - qw * qw); + if (angle > M_PI) + angle -= twoPi; + else if (angle < -M_PI) + angle += twoPi; + omega = (angle / s) * -q.vec(); + } } if(H) *H = SO3::LogmapDerivative(omega.template cast()); diff --git a/gtsam/geometry/Rot2.cpp b/gtsam/geometry/Rot2.cpp index 283147e4c..9bf631e50 100644 --- a/gtsam/geometry/Rot2.cpp +++ b/gtsam/geometry/Rot2.cpp @@ -129,6 +129,19 @@ Rot2 Rot2::relativeBearing(const Point2& d, OptionalJacobian<1, 2> H) { } } +/* ************************************************************************* */ +Rot2 Rot2::ClosestTo(const Matrix2& M) { + Eigen::JacobiSVD svd(M, Eigen::ComputeFullU | Eigen::ComputeFullV); + const Matrix2& U = svd.matrixU(); + const Matrix2& V = svd.matrixV(); + const double det = (U * V.transpose()).determinant(); + Matrix2 M_prime = (U * Vector2(1, det).asDiagonal() * V.transpose()); + + double c = M_prime(0, 0); + double s = M_prime(1, 0); + return Rot2::fromCosSin(c, s); +} + /* ************************************************************************* */ } // gtsam diff --git a/gtsam/geometry/Rot2.h b/gtsam/geometry/Rot2.h index ec30c6657..2690ca248 100644 --- a/gtsam/geometry/Rot2.h +++ b/gtsam/geometry/Rot2.h @@ -14,6 +14,7 @@ * @brief 2D rotation * @date Dec 9, 2009 * @author Frank Dellaert + * @author John Lambert */ #pragma once @@ -209,6 +210,9 @@ namespace gtsam { /** return 2*2 transpose (inverse) rotation matrix */ Matrix2 transpose() const; + /** Find closest valid rotation matrix, given a 2x2 matrix */ + static Rot2 ClosestTo(const Matrix2& M); + private: /** Serialization function */ friend class boost::serialization::access; diff --git a/gtsam/geometry/Rot3.h b/gtsam/geometry/Rot3.h index abd74e063..18bd88b52 100644 --- a/gtsam/geometry/Rot3.h +++ b/gtsam/geometry/Rot3.h @@ -49,16 +49,14 @@ namespace gtsam { - /** - * @brief A 3D rotation represented as a rotation matrix if the preprocessor - * symbol GTSAM_USE_QUATERNIONS is not defined, or as a quaternion if it - * is defined. - * @addtogroup geometry - * \nosubgrouping - */ - class GTSAM_EXPORT Rot3 : public LieGroup { - - private: +/** + * @brief Rot3 is a 3D rotation represented as a rotation matrix if the + * preprocessor symbol GTSAM_USE_QUATERNIONS is not defined, or as a quaternion + * if it is defined. + * @addtogroup geometry + */ +class GTSAM_EXPORT Rot3 : public LieGroup { + private: #ifdef GTSAM_USE_QUATERNIONS /** Internal Eigen Quaternion */ @@ -67,8 +65,7 @@ namespace gtsam { SO3 rot_; #endif - public: - + public: /// @name Constructors and named constructors /// @{ @@ -83,7 +80,7 @@ namespace gtsam { */ Rot3(const Point3& col1, const Point3& col2, const Point3& col3); - /** constructor from a rotation matrix, as doubles in *row-major* order !!! */ + /// Construct from a rotation matrix, as doubles in *row-major* order !!! Rot3(double R11, double R12, double R13, double R21, double R22, double R23, double R31, double R32, double R33); @@ -567,6 +564,9 @@ namespace gtsam { #endif }; + /// std::vector of Rot3s, mainly for wrapper + using Rot3Vector = std::vector >; + /** * [RQ] receives a 3 by 3 matrix and returns an upper triangular matrix R * and 3 rotation angles corresponding to the rotation matrix Q=Qz'*Qy'*Qx' @@ -585,5 +585,6 @@ namespace gtsam { template<> struct traits : public internal::LieGroup {}; -} + +} // namespace gtsam diff --git a/gtsam/geometry/SO3.cpp b/gtsam/geometry/SO3.cpp index 80c0171ad..2585c37be 100644 --- a/gtsam/geometry/SO3.cpp +++ b/gtsam/geometry/SO3.cpp @@ -261,25 +261,59 @@ Vector3 SO3::Logmap(const SO3& Q, ChartJacobian H) { // when trace == -1, i.e., when theta = +-pi, +-3pi, +-5pi, etc. // we do something special - if (tr + 1.0 < 1e-10) { - if (std::abs(R33 + 1.0) > 1e-5) - omega = (M_PI / sqrt(2.0 + 2.0 * R33)) * Vector3(R13, R23, 1.0 + R33); - else if (std::abs(R22 + 1.0) > 1e-5) - omega = (M_PI / sqrt(2.0 + 2.0 * R22)) * Vector3(R12, 1.0 + R22, R32); - else - // if(std::abs(R.r1_.x()+1.0) > 1e-5) This is implicit - omega = (M_PI / sqrt(2.0 + 2.0 * R11)) * Vector3(1.0 + R11, R21, R31); + if (tr + 1.0 < 1e-3) { + if (R33 > R22 && R33 > R11) { + // R33 is the largest diagonal, a=3, b=1, c=2 + const double W = R21 - R12; + const double Q1 = 2.0 + 2.0 * R33; + const double Q2 = R31 + R13; + const double Q3 = R23 + R32; + const double r = sqrt(Q1); + const double one_over_r = 1 / r; + const double norm = sqrt(Q1*Q1 + Q2*Q2 + Q3*Q3 + W*W); + const double sgn_w = W < 0 ? -1.0 : 1.0; + const double mag = M_PI - (2 * sgn_w * W) / norm; + const double scale = 0.5 * one_over_r * mag; + omega = sgn_w * scale * Vector3(Q2, Q3, Q1); + } else if (R22 > R11) { + // R22 is the largest diagonal, a=2, b=3, c=1 + const double W = R13 - R31; + const double Q1 = 2.0 + 2.0 * R22; + const double Q2 = R23 + R32; + const double Q3 = R12 + R21; + const double r = sqrt(Q1); + const double one_over_r = 1 / r; + const double norm = sqrt(Q1*Q1 + Q2*Q2 + Q3*Q3 + W*W); + const double sgn_w = W < 0 ? -1.0 : 1.0; + const double mag = M_PI - (2 * sgn_w * W) / norm; + const double scale = 0.5 * one_over_r * mag; + omega = sgn_w * scale * Vector3(Q3, Q1, Q2); + } else { + // R11 is the largest diagonal, a=1, b=2, c=3 + const double W = R32 - R23; + const double Q1 = 2.0 + 2.0 * R11; + const double Q2 = R12 + R21; + const double Q3 = R31 + R13; + const double r = sqrt(Q1); + const double one_over_r = 1 / r; + const double norm = sqrt(Q1*Q1 + Q2*Q2 + Q3*Q3 + W*W); + const double sgn_w = W < 0 ? -1.0 : 1.0; + const double mag = M_PI - (2 * sgn_w * W) / norm; + const double scale = 0.5 * one_over_r * mag; + omega = sgn_w * scale * Vector3(Q1, Q2, Q3); + } } else { double magnitude; - const double tr_3 = tr - 3.0; // always negative - if (tr_3 < -1e-7) { + const double tr_3 = tr - 3.0; // could be non-negative if the matrix is off orthogonal + if (tr_3 < -1e-6) { + // this is the normal case -1 < trace < 3 double theta = acos((tr - 1.0) / 2.0); magnitude = theta / (2.0 * sin(theta)); } else { // when theta near 0, +-2pi, +-4pi, etc. (trace near 3.0) // use Taylor expansion: theta \approx 1/2-(t-3)/12 + O((t-3)^2) // see https://github.com/borglab/gtsam/issues/746 for details - magnitude = 0.5 - tr_3 / 12.0; + magnitude = 0.5 - tr_3 / 12.0 + tr_3*tr_3/60.0; } omega = magnitude * Vector3(R32 - R23, R13 - R31, R21 - R12); } diff --git a/gtsam/geometry/SOn.cpp b/gtsam/geometry/SOn.cpp index c6cff4214..7088513d5 100644 --- a/gtsam/geometry/SOn.cpp +++ b/gtsam/geometry/SOn.cpp @@ -22,7 +22,7 @@ namespace gtsam { template <> -GTSAM_EXPORT void SOn::Hat(const Vector &xi, Eigen::Ref X) { +void SOn::Hat(const Vector &xi, Eigen::Ref X) { size_t n = AmbientDim(xi.size()); if (n < 2) throw std::invalid_argument("SO::Hat: n<2 not supported"); @@ -48,7 +48,7 @@ GTSAM_EXPORT void SOn::Hat(const Vector &xi, Eigen::Ref X) { } } -template <> GTSAM_EXPORT Matrix SOn::Hat(const Vector &xi) { +template <> Matrix SOn::Hat(const Vector &xi) { size_t n = AmbientDim(xi.size()); Matrix X(n, n); // allocate space for n*n skew-symmetric matrix SOn::Hat(xi, X); @@ -56,7 +56,6 @@ template <> GTSAM_EXPORT Matrix SOn::Hat(const Vector &xi) { } template <> -GTSAM_EXPORT Vector SOn::Vee(const Matrix& X) { const size_t n = X.rows(); if (n < 2) throw std::invalid_argument("SO::Hat: n<2 not supported"); @@ -104,7 +103,9 @@ SOn LieGroup::between(const SOn& g, DynamicJacobian H1, } // Dynamic version of vec -template <> typename SOn::VectorN2 SOn::vec(DynamicJacobian H) const { +template <> +typename SOn::VectorN2 SOn::vec(DynamicJacobian H) const +{ const size_t n = rows(), n2 = n * n; // Vectorize diff --git a/gtsam/geometry/SOn.h b/gtsam/geometry/SOn.h index 86b6019e1..af0e7a3cf 100644 --- a/gtsam/geometry/SOn.h +++ b/gtsam/geometry/SOn.h @@ -24,6 +24,8 @@ #include #include +#include + #include // TODO(frank): how to avoid? #include #include @@ -356,17 +358,21 @@ Vector SOn::Vee(const Matrix& X); using DynamicJacobian = OptionalJacobian; template <> +GTSAM_EXPORT SOn LieGroup::compose(const SOn& g, DynamicJacobian H1, DynamicJacobian H2) const; template <> +GTSAM_EXPORT SOn LieGroup::between(const SOn& g, DynamicJacobian H1, DynamicJacobian H2) const; /* * Specialize dynamic vec. */ -template <> typename SOn::VectorN2 SOn::vec(DynamicJacobian H) const; +template <> +GTSAM_EXPORT +typename SOn::VectorN2 SOn::vec(DynamicJacobian H) const; /** Serialization function */ template diff --git a/gtsam/geometry/Similarity2.cpp b/gtsam/geometry/Similarity2.cpp new file mode 100644 index 000000000..4ed3351f8 --- /dev/null +++ b/gtsam/geometry/Similarity2.cpp @@ -0,0 +1,242 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file Similarity2.cpp + * @brief Implementation of Similarity2 transform + * @author John Lambert, Varun Agrawal + */ + +#include +#include +#include +#include +#include + +namespace gtsam { + +using std::vector; + +namespace internal { + +/// Subtract centroids from point pairs. +static Point2Pairs SubtractCentroids(const Point2Pairs& abPointPairs, + const Point2Pair& centroids) { + Point2Pairs d_abPointPairs; + for (const Point2Pair& abPair : abPointPairs) { + Point2 da = abPair.first - centroids.first; + Point2 db = abPair.second - centroids.second; + d_abPointPairs.emplace_back(da, db); + } + return d_abPointPairs; +} + +/// Form inner products x and y and calculate scale. +static double CalculateScale(const Point2Pairs& d_abPointPairs, + const Rot2& aRb) { + double x = 0, y = 0; + Point2 da, db; + + for (const Point2Pair& d_abPair : d_abPointPairs) { + std::tie(da, db) = d_abPair; + const Vector2 da_prime = aRb * db; + y += da.transpose() * da_prime; + x += da_prime.transpose() * da_prime; + } + const double s = y / x; + return s; +} + +/// Form outer product H. +static Matrix2 CalculateH(const Point2Pairs& d_abPointPairs) { + Matrix2 H = Z_2x2; + for (const Point2Pair& d_abPair : d_abPointPairs) { + H += d_abPair.first * d_abPair.second.transpose(); + } + return H; +} + +/** + * @brief This method estimates the similarity transform from differences point + * pairs, given a known or estimated rotation and point centroids. + * + * @param d_abPointPairs + * @param aRb + * @param centroids + * @return Similarity2 + */ +static Similarity2 Align(const Point2Pairs& d_abPointPairs, const Rot2& aRb, + const Point2Pair& centroids) { + const double s = CalculateScale(d_abPointPairs, aRb); + // dividing aTb by s is required because the registration cost function + // minimizes ||a - sRb - t||, whereas Sim(2) computes s(Rb + t) + const Point2 aTb = (centroids.first - s * (aRb * centroids.second)) / s; + return Similarity2(aRb, aTb, s); +} + +/** + * @brief This method estimates the similarity transform from point pairs, + * given a known or estimated rotation. + * Refer to: + * http://www5.informatik.uni-erlangen.de/Forschung/Publikationen/2005/Zinsser05-PSR.pdf + * Chapter 3 + * + * @param abPointPairs + * @param aRb + * @return Similarity2 + */ +static Similarity2 AlignGivenR(const Point2Pairs& abPointPairs, + const Rot2& aRb) { + auto centroids = means(abPointPairs); + auto d_abPointPairs = internal::SubtractCentroids(abPointPairs, centroids); + return internal::Align(d_abPointPairs, aRb, centroids); +} +} // namespace internal + +Similarity2::Similarity2() : t_(0, 0), s_(1) {} + +Similarity2::Similarity2(double s) : t_(0, 0), s_(s) {} + +Similarity2::Similarity2(const Rot2& R, const Point2& t, double s) + : R_(R), t_(t), s_(s) {} + +Similarity2::Similarity2(const Matrix2& R, const Vector2& t, double s) + : R_(Rot2::ClosestTo(R)), t_(t), s_(s) {} + +Similarity2::Similarity2(const Matrix3& T) + : R_(Rot2::ClosestTo(T.topLeftCorner<2, 2>())), + t_(T.topRightCorner<2, 1>()), + s_(1.0 / T(2, 2)) {} + +bool Similarity2::equals(const Similarity2& other, double tol) const { + return R_.equals(other.R_, tol) && + traits::Equals(t_, other.t_, tol) && s_ < (other.s_ + tol) && + s_ > (other.s_ - tol); +} + +bool Similarity2::operator==(const Similarity2& other) const { + return R_.matrix() == other.R_.matrix() && t_ == other.t_ && s_ == other.s_; +} + +void Similarity2::print(const std::string& s) const { + std::cout << std::endl; + std::cout << s; + rotation().print("\nR:\n"); + std::cout << "t: " << translation().transpose() << " s: " << scale() + << std::endl; +} + +Similarity2 Similarity2::identity() { return Similarity2(); } + +Similarity2 Similarity2::operator*(const Similarity2& S) const { + return Similarity2(R_ * S.R_, ((1.0 / S.s_) * t_) + R_ * S.t_, s_ * S.s_); +} + +Similarity2 Similarity2::inverse() const { + const Rot2 Rt = R_.inverse(); + const Point2 sRt = Rt * (-s_ * t_); + return Similarity2(Rt, sRt, 1.0 / s_); +} + +Point2 Similarity2::transformFrom(const Point2& p) const { + const Point2 q = R_ * p + t_; + return s_ * q; +} + +Pose2 Similarity2::transformFrom(const Pose2& T) const { + Rot2 R = R_.compose(T.rotation()); + Point2 t = Point2(s_ * (R_ * T.translation() + t_)); + return Pose2(R, t); +} + +Point2 Similarity2::operator*(const Point2& p) const { + return transformFrom(p); +} + +Similarity2 Similarity2::Align(const Point2Pairs& abPointPairs) { + // Refer to Chapter 3 of + // http://www5.informatik.uni-erlangen.de/Forschung/Publikationen/2005/Zinsser05-PSR.pdf + if (abPointPairs.size() < 2) + throw std::runtime_error("input should have at least 2 pairs of points"); + auto centroids = means(abPointPairs); + auto d_abPointPairs = internal::SubtractCentroids(abPointPairs, centroids); + Matrix2 H = internal::CalculateH(d_abPointPairs); + // ClosestTo finds rotation matrix closest to H in Frobenius sense + Rot2 aRb = Rot2::ClosestTo(H); + return internal::Align(d_abPointPairs, aRb, centroids); +} + +Similarity2 Similarity2::Align(const Pose2Pairs& abPosePairs) { + const size_t n = abPosePairs.size(); + if (n < 2) + throw std::runtime_error("input should have at least 2 pairs of poses"); + + // calculate rotation + vector rotations; + Point2Pairs abPointPairs; + rotations.reserve(n); + abPointPairs.reserve(n); + // Below denotes the pose of the i'th object/camera/etc + // in frame "a" or frame "b". + Pose2 aTi, bTi; + for (const Pose2Pair& abPair : abPosePairs) { + std::tie(aTi, bTi) = abPair; + const Rot2 aRb = aTi.rotation().compose(bTi.rotation().inverse()); + rotations.emplace_back(aRb); + abPointPairs.emplace_back(aTi.translation(), bTi.translation()); + } + const Rot2 aRb_estimate = FindKarcherMean(rotations); + + return internal::AlignGivenR(abPointPairs, aRb_estimate); +} + +Vector4 Similarity2::Logmap(const Similarity2& S, // + OptionalJacobian<4, 4> Hm) { + const Vector2 u = S.t_; + const Vector1 w = Rot2::Logmap(S.R_); + const double s = log(S.s_); + Vector4 result; + result << u, w, s; + if (Hm) { + throw std::runtime_error("Similarity2::Logmap: derivative not implemented"); + } + return result; +} + +Similarity2 Similarity2::Expmap(const Vector4& v, // + OptionalJacobian<4, 4> Hm) { + const Vector2 t = v.head<2>(); + const Rot2 R = Rot2::Expmap(v.segment<1>(2)); + const double s = v[3]; + if (Hm) { + throw std::runtime_error("Similarity2::Expmap: derivative not implemented"); + } + return Similarity2(R, t, s); +} + +Matrix4 Similarity2::AdjointMap() const { + throw std::runtime_error("Similarity2::AdjointMap not implemented"); +} + +std::ostream& operator<<(std::ostream& os, const Similarity2& p) { + os << "[" << p.rotation().theta() << " " << p.translation().transpose() << " " + << p.scale() << "]\';"; + return os; +} + +Matrix3 Similarity2::matrix() const { + Matrix3 T; + T.topRows<2>() << R_.matrix(), t_; + T.bottomRows<1>() << 0, 0, 1.0 / s_; + return T; +} + +} // namespace gtsam diff --git a/gtsam/geometry/Similarity2.h b/gtsam/geometry/Similarity2.h new file mode 100644 index 000000000..05f10d149 --- /dev/null +++ b/gtsam/geometry/Similarity2.h @@ -0,0 +1,200 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file Similarity2.h + * @brief Implementation of Similarity2 transform + * @author John Lambert, Varun Agrawal + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace gtsam { + +// Forward declarations +class Pose2; + +/** + * 2D similarity transform + */ +class GTSAM_EXPORT Similarity2 : public LieGroup { + /// @name Pose Concept + /// @{ + typedef Rot2 Rotation; + typedef Point2 Translation; + /// @} + + private: + Rot2 R_; + Point2 t_; + double s_; + + public: + /// @name Constructors + /// @{ + + /// Default constructor + Similarity2(); + + /// Construct pure scaling + Similarity2(double s); + + /// Construct from GTSAM types + Similarity2(const Rot2& R, const Point2& t, double s); + + /// Construct from Eigen types + Similarity2(const Matrix2& R, const Vector2& t, double s); + + /// Construct from matrix [R t; 0 s^-1] + Similarity2(const Matrix3& T); + + /// @} + /// @name Testable + /// @{ + + /// Compare with tolerance + bool equals(const Similarity2& sim, double tol) const; + + /// Exact equality + bool operator==(const Similarity2& other) const; + + /// Print with optional string + void print(const std::string& s) const; + + friend std::ostream& operator<<(std::ostream& os, const Similarity2& p); + + /// @} + /// @name Group + /// @{ + + /// Return an identity transform + static Similarity2 identity(); + + /// Composition + Similarity2 operator*(const Similarity2& S) const; + + /// Return the inverse + Similarity2 inverse() const; + + /// @} + /// @name Group action on Point2 + /// @{ + + /// Action on a point p is s*(R*p+t) + Point2 transformFrom(const Point2& p) const; + + /** + * Action on a pose T. + * |Rs ts| |R t| |Rs*R Rs*t+ts| + * |0 1/s| * |0 1| = | 0 1/s |, the result is still a Sim2 object. + * To retrieve a Pose2, we normalized the scale value into 1. + * |Rs*R Rs*t+ts| |Rs*R s(Rs*t+ts)| + * | 0 1/s | = | 0 1 | + * + * This group action satisfies the compatibility condition. + * For more details, refer to: https://en.wikipedia.org/wiki/Group_action + */ + Pose2 transformFrom(const Pose2& T) const; + + /* syntactic sugar for transformFrom */ + Point2 operator*(const Point2& p) const; + + /** + * Create Similarity2 by aligning at least two point pairs + */ + static Similarity2 Align(const Point2Pairs& abPointPairs); + + /** + * Create the Similarity2 object that aligns at least two pose pairs. + * Each pair is of the form (aTi, bTi). + * Given a list of pairs in frame a, and a list of pairs in frame b, + Align() + * will compute the best-fit Similarity2 aSb transformation to align them. + * First, the rotation aRb will be computed as the average (Karcher mean) + of + * many estimates aRb (from each pair). Afterwards, the scale factor will + be computed + * using the algorithm described here: + * http://www5.informatik.uni-erlangen.de/Forschung/Publikationen/2005/Zinsser05-PSR.pdf + */ + static Similarity2 Align(const std::vector& abPosePairs); + + /// @} + /// @name Lie Group + /// @{ + + /** + * Log map at the identity + * \f$ [t_x, t_y, \delta, \lambda] \f$ + */ + static Vector4 Logmap(const Similarity2& S, // + OptionalJacobian<4, 4> Hm = boost::none); + + /// Exponential map at the identity + static Similarity2 Expmap(const Vector4& v, // + OptionalJacobian<4, 4> Hm = boost::none); + + /// Chart at the origin + struct ChartAtOrigin { + static Similarity2 Retract(const Vector4& v, + ChartJacobian H = boost::none) { + return Similarity2::Expmap(v, H); + } + static Vector4 Local(const Similarity2& other, + ChartJacobian H = boost::none) { + return Similarity2::Logmap(other, H); + } + }; + + /// Project from one tangent space to another + Matrix4 AdjointMap() const; + + using LieGroup::inverse; + + /// @} + /// @name Standard interface + /// @{ + + /// Calculate 4*4 matrix group equivalent + Matrix3 matrix() const; + + /// Return a GTSAM rotation + Rot2 rotation() const { return R_; } + + /// Return a GTSAM translation + Point2 translation() const { return t_; } + + /// Return the scale + double scale() const { return s_; } + + /// Dimensionality of tangent space = 4 DOF - used to autodetect sizes + inline static size_t Dim() { return 4; } + + /// Dimensionality of tangent space = 4 DOF + inline size_t dim() const { return 4; } + + /// @} +}; + +template <> +struct traits : public internal::LieGroup {}; + +template <> +struct traits : public internal::LieGroup {}; + +} // namespace gtsam diff --git a/gtsam/geometry/Similarity3.cpp b/gtsam/geometry/Similarity3.cpp index fcaf0c874..7fde974c5 100644 --- a/gtsam/geometry/Similarity3.cpp +++ b/gtsam/geometry/Similarity3.cpp @@ -26,7 +26,7 @@ namespace gtsam { using std::vector; -namespace { +namespace internal { /// Subtract centroids from point pairs. static Point3Pairs subtractCentroids(const Point3Pairs &abPointPairs, const Point3Pair ¢roids) { @@ -40,8 +40,10 @@ static Point3Pairs subtractCentroids(const Point3Pairs &abPointPairs, } /// Form inner products x and y and calculate scale. -static const double calculateScale(const Point3Pairs &d_abPointPairs, - const Rot3 &aRb) { +// We force the scale to be a non-negative quantity +// (see Section 10.1 of https://ethaneade.com/lie_groups.pdf) +static double calculateScale(const Point3Pairs &d_abPointPairs, + const Rot3 &aRb) { double x = 0, y = 0; Point3 da, db; for (const Point3Pair& d_abPair : d_abPointPairs) { @@ -50,7 +52,7 @@ static const double calculateScale(const Point3Pairs &d_abPointPairs, y += da.transpose() * da_prime; x += da_prime.transpose() * da_prime; } - const double s = y / x; + const double s = std::fabs(y / x); return s; } @@ -79,10 +81,10 @@ static Similarity3 align(const Point3Pairs &d_abPointPairs, const Rot3 &aRb, static Similarity3 alignGivenR(const Point3Pairs &abPointPairs, const Rot3 &aRb) { auto centroids = means(abPointPairs); - auto d_abPointPairs = subtractCentroids(abPointPairs, centroids); + auto d_abPointPairs = internal::subtractCentroids(abPointPairs, centroids); return align(d_abPointPairs, aRb, centroids); } -} // namespace +} // namespace internal Similarity3::Similarity3() : t_(0,0,0), s_(1) { @@ -163,11 +165,11 @@ Similarity3 Similarity3::Align(const Point3Pairs &abPointPairs) { if (abPointPairs.size() < 3) throw std::runtime_error("input should have at least 3 pairs of points"); auto centroids = means(abPointPairs); - auto d_abPointPairs = subtractCentroids(abPointPairs, centroids); - Matrix3 H = calculateH(d_abPointPairs); + auto d_abPointPairs = internal::subtractCentroids(abPointPairs, centroids); + Matrix3 H = internal::calculateH(d_abPointPairs); // ClosestTo finds rotation matrix closest to H in Frobenius sense Rot3 aRb = Rot3::ClosestTo(H); - return align(d_abPointPairs, aRb, centroids); + return internal::align(d_abPointPairs, aRb, centroids); } Similarity3 Similarity3::Align(const vector &abPosePairs) { @@ -190,7 +192,7 @@ Similarity3 Similarity3::Align(const vector &abPosePairs) { } const Rot3 aRb_estimate = FindKarcherMean(rotations); - return alignGivenR(abPointPairs, aRb_estimate); + return internal::alignGivenR(abPointPairs, aRb_estimate); } Matrix4 Similarity3::wedge(const Vector7 &xi) { @@ -281,15 +283,11 @@ std::ostream &operator<<(std::ostream &os, const Similarity3& p) { return os; } -const Matrix4 Similarity3::matrix() const { +Matrix4 Similarity3::matrix() const { Matrix4 T; T.topRows<3>() << R_.matrix(), t_; T.bottomRows<1>() << 0, 0, 0, 1.0 / s_; return T; } -Similarity3::operator Pose3() const { - return Pose3(R_, s_ * t_); -} - } // namespace gtsam diff --git a/gtsam/geometry/Similarity3.h b/gtsam/geometry/Similarity3.h index 0ef787b05..845d4c810 100644 --- a/gtsam/geometry/Similarity3.h +++ b/gtsam/geometry/Similarity3.h @@ -18,13 +18,12 @@ #pragma once -#include -#include -#include #include #include #include - +#include +#include +#include namespace gtsam { @@ -34,108 +33,106 @@ class Pose3; /** * 3D similarity transform */ -class Similarity3: public LieGroup { - +class GTSAM_EXPORT Similarity3 : public LieGroup { /// @name Pose Concept /// @{ typedef Rot3 Rotation; typedef Point3 Translation; /// @} -private: + private: Rot3 R_; Point3 t_; double s_; -public: - + public: /// @name Constructors /// @{ /// Default constructor - GTSAM_EXPORT Similarity3(); + Similarity3(); /// Construct pure scaling - GTSAM_EXPORT Similarity3(double s); + Similarity3(double s); /// Construct from GTSAM types - GTSAM_EXPORT Similarity3(const Rot3& R, const Point3& t, double s); + Similarity3(const Rot3& R, const Point3& t, double s); /// Construct from Eigen types - GTSAM_EXPORT Similarity3(const Matrix3& R, const Vector3& t, double s); + Similarity3(const Matrix3& R, const Vector3& t, double s); /// Construct from matrix [R t; 0 s^-1] - GTSAM_EXPORT Similarity3(const Matrix4& T); + Similarity3(const Matrix4& T); /// @} /// @name Testable /// @{ /// Compare with tolerance - GTSAM_EXPORT bool equals(const Similarity3& sim, double tol) const; + bool equals(const Similarity3& sim, double tol) const; /// Exact equality - GTSAM_EXPORT bool operator==(const Similarity3& other) const; + bool operator==(const Similarity3& other) const; /// Print with optional string - GTSAM_EXPORT void print(const std::string& s) const; + void print(const std::string& s) const; - GTSAM_EXPORT friend std::ostream &operator<<(std::ostream &os, const Similarity3& p); + friend std::ostream& operator<<(std::ostream& os, const Similarity3& p); /// @} /// @name Group /// @{ /// Return an identity transform - GTSAM_EXPORT static Similarity3 identity(); + static Similarity3 identity(); /// Composition - GTSAM_EXPORT Similarity3 operator*(const Similarity3& S) const; + Similarity3 operator*(const Similarity3& S) const; /// Return the inverse - GTSAM_EXPORT Similarity3 inverse() const; + Similarity3 inverse() const; /// @} /// @name Group action on Point3 /// @{ /// Action on a point p is s*(R*p+t) - GTSAM_EXPORT Point3 transformFrom(const Point3& p, // - OptionalJacobian<3, 7> H1 = boost::none, // - OptionalJacobian<3, 3> H2 = boost::none) const; + Point3 transformFrom(const Point3& p, // + OptionalJacobian<3, 7> H1 = boost::none, // + OptionalJacobian<3, 3> H2 = boost::none) const; - /** + /** * Action on a pose T. - * |Rs ts| |R t| |Rs*R Rs*t+ts| + * |Rs ts| |R t| |Rs*R Rs*t+ts| * |0 1/s| * |0 1| = | 0 1/s |, the result is still a Sim3 object. * To retrieve a Pose3, we normalized the scale value into 1. * |Rs*R Rs*t+ts| |Rs*R s(Rs*t+ts)| * | 0 1/s | = | 0 1 | - * - * This group action satisfies the compatibility condition. + * + * This group action satisfies the compatibility condition. * For more details, refer to: https://en.wikipedia.org/wiki/Group_action */ - GTSAM_EXPORT Pose3 transformFrom(const Pose3& T) const; + Pose3 transformFrom(const Pose3& T) const; /** syntactic sugar for transformFrom */ - GTSAM_EXPORT Point3 operator*(const Point3& p) const; + Point3 operator*(const Point3& p) const; /** * Create Similarity3 by aligning at least three point pairs */ - GTSAM_EXPORT static Similarity3 Align(const std::vector& abPointPairs); - + static Similarity3 Align(const std::vector& abPointPairs); + /** * Create the Similarity3 object that aligns at least two pose pairs. * Each pair is of the form (aTi, bTi). * Given a list of pairs in frame a, and a list of pairs in frame b, Align() * will compute the best-fit Similarity3 aSb transformation to align them. * First, the rotation aRb will be computed as the average (Karcher mean) of - * many estimates aRb (from each pair). Afterwards, the scale factor will be computed - * using the algorithm described here: + * many estimates aRb (from each pair). Afterwards, the scale factor will be + * computed using the algorithm described here: * http://www5.informatik.uni-erlangen.de/Forschung/Publikationen/2005/Zinsser05-PSR.pdf */ - GTSAM_EXPORT static Similarity3 Align(const std::vector& abPosePairs); + static Similarity3 Align(const std::vector& abPosePairs); /// @} /// @name Lie Group @@ -144,20 +141,22 @@ public: /** Log map at the identity * \f$ [R_x,R_y,R_z, t_x, t_y, t_z, \lambda] \f$ */ - GTSAM_EXPORT static Vector7 Logmap(const Similarity3& s, // - OptionalJacobian<7, 7> Hm = boost::none); + static Vector7 Logmap(const Similarity3& s, // + OptionalJacobian<7, 7> Hm = boost::none); /** Exponential map at the identity */ - GTSAM_EXPORT static Similarity3 Expmap(const Vector7& v, // - OptionalJacobian<7, 7> Hm = boost::none); + static Similarity3 Expmap(const Vector7& v, // + OptionalJacobian<7, 7> Hm = boost::none); /// Chart at the origin struct ChartAtOrigin { - static Similarity3 Retract(const Vector7& v, ChartJacobian H = boost::none) { + static Similarity3 Retract(const Vector7& v, + ChartJacobian H = boost::none) { return Similarity3::Expmap(v, H); } - static Vector7 Local(const Similarity3& other, ChartJacobian H = boost::none) { + static Vector7 Local(const Similarity3& other, + ChartJacobian H = boost::none) { return Similarity3::Logmap(other, H); } }; @@ -170,67 +169,53 @@ public: * @return 4*4 element of Lie algebra that can be exponentiated * TODO(frank): rename to Hat, make part of traits */ - GTSAM_EXPORT static Matrix4 wedge(const Vector7& xi); + static Matrix4 wedge(const Vector7& xi); /// Project from one tangent space to another - GTSAM_EXPORT Matrix7 AdjointMap() const; + Matrix7 AdjointMap() const; /// @} /// @name Standard interface /// @{ /// Calculate 4*4 matrix group equivalent - GTSAM_EXPORT const Matrix4 matrix() const; + Matrix4 matrix() const; /// Return a GTSAM rotation - const Rot3& rotation() const { - return R_; - } + Rot3 rotation() const { return R_; } /// Return a GTSAM translation - const Point3& translation() const { - return t_; - } + Point3 translation() const { return t_; } /// Return the scale - double scale() const { - return s_; - } - - /// Convert to a rigid body pose (R, s*t) - /// TODO(frank): why is this here? Red flag! Definitely don't have it as a cast. - GTSAM_EXPORT operator Pose3() const; + double scale() const { return s_; } /// Dimensionality of tangent space = 7 DOF - used to autodetect sizes - inline static size_t Dim() { - return 7; - } + inline static size_t Dim() { return 7; } /// Dimensionality of tangent space = 7 DOF - inline size_t dim() const { - return 7; - } + inline size_t dim() const { return 7; } /// @} /// @name Helper functions /// @{ -private: + private: /// Calculate expmap and logmap coefficients. static Matrix3 GetV(Vector3 w, double lambda); /// @} }; -template<> +template <> inline Matrix wedge(const Vector& xi) { return Similarity3::wedge(xi); } -template<> +template <> struct traits : public internal::LieGroup {}; -template<> +template <> struct traits : public internal::LieGroup {}; -} // namespace gtsam +} // namespace gtsam diff --git a/gtsam/geometry/SimpleCamera.cpp b/gtsam/geometry/SimpleCamera.cpp index d1a5ed330..be6a010b2 100644 --- a/gtsam/geometry/SimpleCamera.cpp +++ b/gtsam/geometry/SimpleCamera.cpp @@ -21,8 +21,8 @@ namespace gtsam { -#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V41 - SimpleCamera simpleCamera(const Matrix34& P) { +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + SimpleCamera GTSAM_DEPRECATED simpleCamera(const Matrix34& P) { // P = [A|a] = s K cRw [I|-T], with s the unknown scale Matrix3 A = P.topLeftCorner(3, 3); diff --git a/gtsam/geometry/SimpleCamera.h b/gtsam/geometry/SimpleCamera.h index 5ff6b9816..f0776c2e2 100644 --- a/gtsam/geometry/SimpleCamera.h +++ b/gtsam/geometry/SimpleCamera.h @@ -37,7 +37,7 @@ namespace gtsam { using PinholeCameraCal3Unified = gtsam::PinholeCamera; using PinholeCameraCal3Fisheye = gtsam::PinholeCamera; -#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V41 +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 /** * @deprecated: SimpleCamera for backwards compatability with GTSAM 3.x * Use PinholeCameraCal3_S2 instead diff --git a/gtsam/geometry/SphericalCamera.cpp b/gtsam/geometry/SphericalCamera.cpp new file mode 100644 index 000000000..58a29dc09 --- /dev/null +++ b/gtsam/geometry/SphericalCamera.cpp @@ -0,0 +1,109 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file SphericalCamera.h + * @brief Calibrated camera with spherical projection + * @date Aug 26, 2021 + * @author Luca Carlone + */ + +#include + +using namespace std; + +namespace gtsam { + +/* ************************************************************************* */ +bool SphericalCamera::equals(const SphericalCamera& camera, double tol) const { + return pose_.equals(camera.pose(), tol); +} + +/* ************************************************************************* */ +void SphericalCamera::print(const string& s) const { pose_.print(s + ".pose"); } + +/* ************************************************************************* */ +pair SphericalCamera::projectSafe(const Point3& pw) const { + const Point3 pc = pose().transformTo(pw); + Unit3 pu = Unit3::FromPoint3(pc); + return make_pair(pu, pc.norm() > 1e-8); +} + +/* ************************************************************************* */ +Unit3 SphericalCamera::project2(const Point3& pw, OptionalJacobian<2, 6> Dpose, + OptionalJacobian<2, 3> Dpoint) const { + Matrix36 Dtf_pose; + Matrix3 Dtf_point; // calculated by transformTo if needed + const Point3 pc = + pose().transformTo(pw, Dpose ? &Dtf_pose : 0, Dpoint ? &Dtf_point : 0); + + if (pc.norm() <= 1e-8) throw("point cannot be at the center of the camera"); + + Matrix23 Dunit; // calculated by FromPoint3 if needed + Unit3 pu = Unit3::FromPoint3(Point3(pc), Dpoint ? &Dunit : 0); + + if (Dpose) *Dpose = Dunit * Dtf_pose; // 2x3 * 3x6 = 2x6 + if (Dpoint) *Dpoint = Dunit * Dtf_point; // 2x3 * 3x3 = 2x3 + return pu; +} + +/* ************************************************************************* */ +Unit3 SphericalCamera::project2(const Unit3& pwu, OptionalJacobian<2, 6> Dpose, + OptionalJacobian<2, 2> Dpoint) const { + Matrix23 Dtf_rot; + Matrix2 Dtf_point; // calculated by transformTo if needed + const Unit3 pu = pose().rotation().unrotate(pwu, Dpose ? &Dtf_rot : 0, + Dpoint ? &Dtf_point : 0); + + if (Dpose) + *Dpose << Dtf_rot, Matrix::Zero(2, 3); // 2x6 (translation part is zero) + if (Dpoint) *Dpoint = Dtf_point; // 2x2 + return pu; +} + +/* ************************************************************************* */ +Point3 SphericalCamera::backproject(const Unit3& pu, const double depth) const { + return pose().transformFrom(depth * pu); +} + +/* ************************************************************************* */ +Unit3 SphericalCamera::backprojectPointAtInfinity(const Unit3& p) const { + return pose().rotation().rotate(p); +} + +/* ************************************************************************* */ +Unit3 SphericalCamera::project(const Point3& point, + OptionalJacobian<2, 6> Dcamera, + OptionalJacobian<2, 3> Dpoint) const { + return project2(point, Dcamera, Dpoint); +} + +/* ************************************************************************* */ +Vector2 SphericalCamera::reprojectionError( + const Point3& point, const Unit3& measured, OptionalJacobian<2, 6> Dpose, + OptionalJacobian<2, 3> Dpoint) const { + // project point + if (Dpose || Dpoint) { + Matrix26 H_project_pose; + Matrix23 H_project_point; + Matrix22 H_error; + Unit3 projected = project2(point, H_project_pose, H_project_point); + Vector2 error = measured.errorVector(projected, boost::none, H_error); + if (Dpose) *Dpose = H_error * H_project_pose; + if (Dpoint) *Dpoint = H_error * H_project_point; + return error; + } else { + return measured.errorVector(project2(point, Dpose, Dpoint)); + } +} + +/* ************************************************************************* */ +} // namespace gtsam diff --git a/gtsam/geometry/SphericalCamera.h b/gtsam/geometry/SphericalCamera.h new file mode 100644 index 000000000..4880423d3 --- /dev/null +++ b/gtsam/geometry/SphericalCamera.h @@ -0,0 +1,241 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file SphericalCamera.h + * @brief Calibrated camera with spherical projection + * @date Aug 26, 2021 + * @author Luca Carlone + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace gtsam { + +/** + * Empty calibration. Only needed to play well with other cameras + * (e.g., when templating functions wrt cameras), since other cameras + * have constuctors in the form ‘camera(pose,calibration)’ + * @addtogroup geometry + * \nosubgrouping + */ +class GTSAM_EXPORT EmptyCal { + public: + enum { dimension = 0 }; + EmptyCal() {} + virtual ~EmptyCal() = default; + using shared_ptr = boost::shared_ptr; + + /// return DOF, dimensionality of tangent space + inline static size_t Dim() { return dimension; } + + void print(const std::string& s) const { + std::cout << "empty calibration: " << s << std::endl; + } + + private: + /// Serialization function + friend class boost::serialization::access; + template + void serialize(Archive& ar, const unsigned int /*version*/) { + ar& boost::serialization::make_nvp( + "EmptyCal", boost::serialization::base_object(*this)); + } +}; + +/** + * A spherical camera class that has a Pose3 and measures bearing vectors. + * The camera has an ‘Empty’ calibration and the only 6 dof are the pose + * @addtogroup geometry + * \nosubgrouping + */ +class GTSAM_EXPORT SphericalCamera { + public: + enum { dimension = 6 }; + + using Measurement = Unit3; + using MeasurementVector = std::vector; + using CalibrationType = EmptyCal; + + private: + Pose3 pose_; ///< 3D pose of camera + + protected: + EmptyCal::shared_ptr emptyCal_; + + public: + /// @} + /// @name Standard Constructors + /// @{ + + /// Default constructor + SphericalCamera() + : pose_(Pose3::identity()), emptyCal_(boost::make_shared()) {} + + /// Constructor with pose + explicit SphericalCamera(const Pose3& pose) + : pose_(pose), emptyCal_(boost::make_shared()) {} + + /// Constructor with empty intrinsics (needed for smart factors) + explicit SphericalCamera(const Pose3& pose, + const EmptyCal::shared_ptr& cal) + : pose_(pose), emptyCal_(cal) {} + + /// @} + /// @name Advanced Constructors + /// @{ + explicit SphericalCamera(const Vector& v) : pose_(Pose3::Expmap(v)) {} + + /// Default destructor + virtual ~SphericalCamera() = default; + + /// return shared pointer to calibration + const EmptyCal::shared_ptr& sharedCalibration() const { + return emptyCal_; + } + + /// return calibration + const EmptyCal& calibration() const { return *emptyCal_; } + + /// @} + /// @name Testable + /// @{ + + /// assert equality up to a tolerance + bool equals(const SphericalCamera& camera, double tol = 1e-9) const; + + /// print + virtual void print(const std::string& s = "SphericalCamera") const; + + /// @} + /// @name Standard Interface + /// @{ + + /// return pose, constant version + const Pose3& pose() const { return pose_; } + + /// get rotation + const Rot3& rotation() const { return pose_.rotation(); } + + /// get translation + const Point3& translation() const { return pose_.translation(); } + + // /// return pose, with derivative + // const Pose3& getPose(OptionalJacobian<6, 6> H) const; + + /// @} + /// @name Transformations and measurement functions + /// @{ + + /// Project a point into the image and check depth + std::pair projectSafe(const Point3& pw) const; + + /** Project point into the image + * (note: there is no CheiralityException for a spherical camera) + * @param point 3D point in world coordinates + * @return the intrinsic coordinates of the projected point + */ + Unit3 project2(const Point3& pw, OptionalJacobian<2, 6> Dpose = boost::none, + OptionalJacobian<2, 3> Dpoint = boost::none) const; + + /** Project point into the image + * (note: there is no CheiralityException for a spherical camera) + * @param point 3D direction in world coordinates + * @return the intrinsic coordinates of the projected point + */ + Unit3 project2(const Unit3& pwu, OptionalJacobian<2, 6> Dpose = boost::none, + OptionalJacobian<2, 2> Dpoint = boost::none) const; + + /// backproject a 2-dimensional point to a 3-dimensional point at given depth + Point3 backproject(const Unit3& p, const double depth) const; + + /// backproject point at infinity + Unit3 backprojectPointAtInfinity(const Unit3& p) const; + + /** Project point into the image + * (note: there is no CheiralityException for a spherical camera) + * @param point 3D point in world coordinates + * @return the intrinsic coordinates of the projected point + */ + Unit3 project(const Point3& point, OptionalJacobian<2, 6> Dpose = boost::none, + OptionalJacobian<2, 3> Dpoint = boost::none) const; + + /** Compute reprojection error for a given 3D point in world coordinates + * @param point 3D point in world coordinates + * @return the tangent space error between the projection and the measurement + */ + Vector2 reprojectionError(const Point3& point, const Unit3& measured, + OptionalJacobian<2, 6> Dpose = boost::none, + OptionalJacobian<2, 3> Dpoint = boost::none) const; + /// @} + + /// move a cameras according to d + SphericalCamera retract(const Vector6& d) const { + return SphericalCamera(pose().retract(d)); + } + + /// return canonical coordinate + Vector6 localCoordinates(const SphericalCamera& p) const { + return pose().localCoordinates(p.pose()); + } + + /// for Canonical + static SphericalCamera identity() { + return SphericalCamera( + Pose3::identity()); // assumes that the default constructor is valid + } + + /// for Linear Triangulation + Matrix34 cameraProjectionMatrix() const { + return Matrix34(pose_.inverse().matrix().block(0, 0, 3, 4)); + } + + /// for Nonlinear Triangulation + Vector defaultErrorWhenTriangulatingBehindCamera() const { + return Eigen::Matrix::dimension, 1>::Constant(0.0); + } + + /// @deprecated + size_t dim() const { return 6; } + + /// @deprecated + static size_t Dim() { return 6; } + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(Archive& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_NVP(pose_); + } + + public: + GTSAM_MAKE_ALIGNED_OPERATOR_NEW +}; +// end of class SphericalCamera + +template <> +struct traits : public internal::LieGroup {}; + +template <> +struct traits : public internal::LieGroup {}; + +} // namespace gtsam diff --git a/gtsam/geometry/StereoCamera.h b/gtsam/geometry/StereoCamera.h index 3b5bdaefc..c53fc11c9 100644 --- a/gtsam/geometry/StereoCamera.h +++ b/gtsam/geometry/StereoCamera.h @@ -170,6 +170,11 @@ public: OptionalJacobian<3, 3> H2 = boost::none, OptionalJacobian<3, 0> H3 = boost::none) const; + /// for Nonlinear Triangulation + Vector defaultErrorWhenTriangulatingBehindCamera() const { + return Eigen::Matrix::dimension,1>::Constant(2.0 * K_->fx());; + } + /// @} private: diff --git a/gtsam/geometry/Unit3.h b/gtsam/geometry/Unit3.h index 27d41a014..ebd5c63c9 100644 --- a/gtsam/geometry/Unit3.h +++ b/gtsam/geometry/Unit3.h @@ -23,11 +23,12 @@ #include #include #include +#include +#include #include #include #include -#include #include #include @@ -39,7 +40,7 @@ namespace gtsam { /// Represents a 3D point on a unit sphere. -class Unit3 { +class GTSAM_EXPORT Unit3 { private: @@ -96,7 +97,7 @@ public: } /// Named constructor from Point3 with optional Jacobian - GTSAM_EXPORT static Unit3 FromPoint3(const Point3& point, // + static Unit3 FromPoint3(const Point3& point, // OptionalJacobian<2, 3> H = boost::none); /** @@ -105,7 +106,7 @@ public: * std::mt19937 engine(42); * Unit3 unit = Unit3::Random(engine); */ - GTSAM_EXPORT static Unit3 Random(std::mt19937 & rng); + static Unit3 Random(std::mt19937 & rng); /// @} @@ -115,7 +116,7 @@ public: friend std::ostream& operator<<(std::ostream& os, const Unit3& pair); /// The print fuction - GTSAM_EXPORT void print(const std::string& s = std::string()) const; + void print(const std::string& s = std::string()) const; /// The equals function with tolerance bool equals(const Unit3& s, double tol = 1e-9) const { @@ -132,16 +133,16 @@ public: * tangent to the sphere at the current direction. * Provides derivatives of the basis with the two basis vectors stacked up as a 6x1. */ - GTSAM_EXPORT const Matrix32& basis(OptionalJacobian<6, 2> H = boost::none) const; + const Matrix32& basis(OptionalJacobian<6, 2> H = boost::none) const; /// Return skew-symmetric associated with 3D point on unit sphere - GTSAM_EXPORT Matrix3 skew() const; + Matrix3 skew() const; /// Return unit-norm Point3 - GTSAM_EXPORT Point3 point3(OptionalJacobian<3, 2> H = boost::none) const; + Point3 point3(OptionalJacobian<3, 2> H = boost::none) const; /// Return unit-norm Vector - GTSAM_EXPORT Vector3 unitVector(OptionalJacobian<3, 2> H = boost::none) const; + Vector3 unitVector(OptionalJacobian<3, 2> H = boost::none) const; /// Return scaled direction as Point3 friend Point3 operator*(double s, const Unit3& d) { @@ -149,20 +150,20 @@ public: } /// Return dot product with q - GTSAM_EXPORT double dot(const Unit3& q, OptionalJacobian<1,2> H1 = boost::none, // + double dot(const Unit3& q, OptionalJacobian<1,2> H1 = boost::none, // OptionalJacobian<1,2> H2 = boost::none) const; /// Signed, vector-valued error between two directions /// @deprecated, errorVector has the proper derivatives, this confusingly has only the second. - GTSAM_EXPORT Vector2 error(const Unit3& q, OptionalJacobian<2, 2> H_q = boost::none) const; + Vector2 error(const Unit3& q, OptionalJacobian<2, 2> H_q = boost::none) const; /// Signed, vector-valued error between two directions /// NOTE(hayk): This method has zero derivatives if this (p) and q are orthogonal. - GTSAM_EXPORT Vector2 errorVector(const Unit3& q, OptionalJacobian<2, 2> H_p = boost::none, // + Vector2 errorVector(const Unit3& q, OptionalJacobian<2, 2> H_p = boost::none, // OptionalJacobian<2, 2> H_q = boost::none) const; /// Distance between two directions - GTSAM_EXPORT double distance(const Unit3& q, OptionalJacobian<1, 2> H = boost::none) const; + double distance(const Unit3& q, OptionalJacobian<1, 2> H = boost::none) const; /// Cross-product between two Unit3s Unit3 cross(const Unit3& q) const { @@ -195,10 +196,10 @@ public: }; /// The retract function - GTSAM_EXPORT Unit3 retract(const Vector2& v, OptionalJacobian<2,2> H = boost::none) const; + Unit3 retract(const Vector2& v, OptionalJacobian<2,2> H = boost::none) const; /// The local coordinates function - GTSAM_EXPORT Vector2 localCoordinates(const Unit3& s) const; + Vector2 localCoordinates(const Unit3& s) const; /// @} diff --git a/gtsam/geometry/concepts.h b/gtsam/geometry/concepts.h index 207b48f56..bafb62418 100644 --- a/gtsam/geometry/concepts.h +++ b/gtsam/geometry/concepts.h @@ -72,5 +72,5 @@ private: /** Pose Concept macros */ #define GTSAM_CONCEPT_POSE_INST(T) template class gtsam::PoseConcept; -#define GTSAM_CONCEPT_POSE_TYPE(T) typedef gtsam::PoseConcept _gtsam_PoseConcept##T; +#define GTSAM_CONCEPT_POSE_TYPE(T) using _gtsam_PoseConcept##T = gtsam::PoseConcept; diff --git a/gtsam/geometry/geometry.i b/gtsam/geometry/geometry.i index 9baa49e8e..415aa0dc4 100644 --- a/gtsam/geometry/geometry.i +++ b/gtsam/geometry/geometry.i @@ -27,9 +27,6 @@ class Point2 { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; class Point2Pairs { @@ -104,9 +101,6 @@ class StereoPoint2 { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -131,9 +125,6 @@ class Point3 { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; class Point3Pairs { @@ -191,9 +182,6 @@ class Rot2 { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -372,9 +360,6 @@ class Rot3 { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -387,6 +372,9 @@ class Pose2 { Pose2(const gtsam::Rot2& r, const gtsam::Point2& t); Pose2(Vector v); + static boost::optional Align(const gtsam::Point2Pairs& abPointPairs); + static boost::optional Align(const gtsam::Matrix& a, const gtsam::Matrix& b); + // Testable void print(string s = "") const; bool equals(const gtsam::Pose2& pose, double tol) const; @@ -421,6 +409,10 @@ class Pose2 { gtsam::Point2 transformFrom(const gtsam::Point2& p) const; gtsam::Point2 transformTo(const gtsam::Point2& p) const; + // Matrix versions + Matrix transformFrom(const Matrix& points) const; + Matrix transformTo(const Matrix& points) const; + // Standard Interface double x() const; double y() const; @@ -433,13 +425,8 @@ class Pose2 { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; -boost::optional align(const gtsam::Point2Pairs& pairs); - #include class Pose3 { // Standard Constructors @@ -449,6 +436,9 @@ class Pose3 { Pose3(const gtsam::Pose2& pose2); Pose3(Matrix mat); + static boost::optional Align(const gtsam::Point3Pairs& abPointPairs); + static boost::optional Align(const gtsam::Matrix& a, const gtsam::Matrix& b); + // Testable void print(string s = "") const; bool equals(const gtsam::Pose3& pose, double tol) const; @@ -473,6 +463,9 @@ class Pose3 { Vector logmap(const gtsam::Pose3& pose); Matrix AdjointMap() const; Vector Adjoint(Vector xi) const; + Vector AdjointTranspose(Vector xi) const; + static Matrix adjointMap(Vector xi); + static Vector adjoint(Vector xi, Vector y); static Matrix adjointMap_(Vector xi); static Vector adjoint_(Vector xi, Vector y); static Vector adjointTranspose(Vector xi, Vector y); @@ -485,6 +478,10 @@ class Pose3 { gtsam::Point3 transformFrom(const gtsam::Point3& point) const; gtsam::Point3 transformTo(const gtsam::Point3& point) const; + // Matrix versions + Matrix transformFrom(const Matrix& points) const; + Matrix transformTo(const Matrix& points) const; + // Standard Interface gtsam::Rot3 rotation() const; gtsam::Point3 translation() const; @@ -499,9 +496,6 @@ class Pose3 { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; class Pose3Pairs { @@ -544,9 +538,6 @@ class Unit3 { // enabling serialization functionality void serialize() const; - // enable pickling in python - void pickle() const; - // enabling function to compare objects bool equals(const gtsam::Unit3& expected, double tol) const; }; @@ -593,7 +584,13 @@ class Cal3_S2 { // Action on Point2 gtsam::Point2 calibrate(const gtsam::Point2& p) const; + gtsam::Point2 calibrate(const gtsam::Point2& p, + Eigen::Ref Dcal, + Eigen::Ref Dp) const; gtsam::Point2 uncalibrate(const gtsam::Point2& p) const; + gtsam::Point2 uncalibrate(const gtsam::Point2& p, + Eigen::Ref Dcal, + Eigen::Ref Dp) const; // Standard Interface double fx() const; @@ -608,9 +605,6 @@ class Cal3_S2 { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -635,13 +629,16 @@ virtual class Cal3DS2_Base { // Action on Point2 gtsam::Point2 uncalibrate(const gtsam::Point2& p) const; + gtsam::Point2 uncalibrate(const gtsam::Point2& p, + Eigen::Ref Dcal, + Eigen::Ref Dp) const; gtsam::Point2 calibrate(const gtsam::Point2& p) const; + gtsam::Point2 calibrate(const gtsam::Point2& p, + Eigen::Ref Dcal, + Eigen::Ref Dp) const; // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -665,9 +662,6 @@ virtual class Cal3DS2 : gtsam::Cal3DS2_Base { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -698,13 +692,16 @@ virtual class Cal3Unified : gtsam::Cal3DS2_Base { // Note: the signature of this functions differ from the functions // with equal name in the base class. gtsam::Point2 calibrate(const gtsam::Point2& p) const; + gtsam::Point2 calibrate(const gtsam::Point2& p, + Eigen::Ref Dcal, + Eigen::Ref Dp) const; gtsam::Point2 uncalibrate(const gtsam::Point2& p) const; + gtsam::Point2 uncalibrate(const gtsam::Point2& p, + Eigen::Ref Dcal, + Eigen::Ref Dp) const; // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -727,7 +724,13 @@ class Cal3Fisheye { // Action on Point2 gtsam::Point2 calibrate(const gtsam::Point2& p) const; + gtsam::Point2 calibrate(const gtsam::Point2& p, + Eigen::Ref Dcal, + Eigen::Ref Dp) const; gtsam::Point2 uncalibrate(const gtsam::Point2& p) const; + gtsam::Point2 uncalibrate(const gtsam::Point2& p, + Eigen::Ref Dcal, + Eigen::Ref Dp) const; // Standard Interface double fx() const; @@ -747,9 +750,6 @@ class Cal3Fisheye { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -793,7 +793,13 @@ class Cal3Bundler { // Action on Point2 gtsam::Point2 calibrate(const gtsam::Point2& p) const; + gtsam::Point2 calibrate(const gtsam::Point2& p, + Eigen::Ref Dcal, + Eigen::Ref Dp) const; gtsam::Point2 uncalibrate(const gtsam::Point2& p) const; + gtsam::Point2 uncalibrate(const gtsam::Point2& p, + Eigen::Ref Dcal, + Eigen::Ref Dp) const; // Standard Interface double fx() const; @@ -808,9 +814,6 @@ class Cal3Bundler { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -834,19 +837,29 @@ class CalibratedCamera { // Action on Point3 gtsam::Point2 project(const gtsam::Point3& point) const; + gtsam::Point2 project(const gtsam::Point3& point, + Eigen::Ref Dcamera, + Eigen::Ref Dpoint); + gtsam::Point3 backproject(const gtsam::Point2& p, double depth) const; + gtsam::Point3 backproject(const gtsam::Point2& p, double depth, + Eigen::Ref Dresult_dpose, + Eigen::Ref Dresult_dp, + Eigen::Ref Dresult_ddepth); + static gtsam::Point2 Project(const gtsam::Point3& cameraPoint); // Standard Interface gtsam::Pose3 pose() const; double range(const gtsam::Point3& point) const; + double range(const gtsam::Point3& point, Eigen::Ref Dcamera, + Eigen::Ref Dpoint); double range(const gtsam::Pose3& pose) const; + double range(const gtsam::Pose3& point, Eigen::Ref Dcamera, + Eigen::Ref Dpose); double range(const gtsam::CalibratedCamera& camera) const; // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -854,6 +867,7 @@ template class PinholeCamera { // Standard Constructors and Named Constructors PinholeCamera(); + PinholeCamera(const gtsam::PinholeCamera other); PinholeCamera(const gtsam::Pose3& pose); PinholeCamera(const gtsam::Pose3& pose, const CALIBRATION& K); static This Level(const CALIBRATION& K, const gtsam::Pose2& pose, @@ -880,15 +894,48 @@ class PinholeCamera { static gtsam::Point2 Project(const gtsam::Point3& cameraPoint); pair projectSafe(const gtsam::Point3& pw) const; gtsam::Point2 project(const gtsam::Point3& point); + gtsam::Point2 project(const gtsam::Point3& point, + Eigen::Ref Dpose, + Eigen::Ref Dpoint, + Eigen::Ref Dcal); gtsam::Point3 backproject(const gtsam::Point2& p, double depth) const; + gtsam::Point3 backproject(const gtsam::Point2& p, double depth, + Eigen::Ref Dresult_dpose, + Eigen::Ref Dresult_dp, + Eigen::Ref Dresult_ddepth, + Eigen::Ref Dresult_dcal); double range(const gtsam::Point3& point); + double range(const gtsam::Point3& point, Eigen::Ref Dcamera, + Eigen::Ref Dpoint); double range(const gtsam::Pose3& pose); + double range(const gtsam::Pose3& pose, Eigen::Ref Dcamera, + Eigen::Ref Dpose); // enabling serialization functionality void serialize() const; +}; + +#include +class Similarity2 { + // Standard Constructors + Similarity2(); + Similarity2(double s); + Similarity2(const gtsam::Rot2& R, const gtsam::Point2& t, double s); + Similarity2(const Matrix& R, const Vector& t, double s); + Similarity2(const Matrix& T); - // enable pickling in python - void pickle() const; + gtsam::Point2 transformFrom(const gtsam::Point2& p) const; + gtsam::Pose2 transformFrom(const gtsam::Pose2& T); + + static gtsam::Similarity2 Align(const gtsam::Point2Pairs& abPointPairs); + static gtsam::Similarity2 Align(const gtsam::Pose2Pairs& abPosePairs); + + // Standard Interface + bool equals(const gtsam::Similarity2& sim, double tol) const; + Matrix matrix() const; + gtsam::Rot2& rotation(); + gtsam::Point2& translation(); + double scale() const; }; #include @@ -907,9 +954,10 @@ class Similarity3 { static gtsam::Similarity3 Align(const gtsam::Pose3Pairs& abPosePairs); // Standard Interface - const Matrix matrix() const; - const gtsam::Rot3& rotation(); - const gtsam::Point3& translation(); + bool equals(const gtsam::Similarity3& sim, double tol) const; + Matrix matrix() const; + gtsam::Rot3& rotation(); + gtsam::Point3& translation(); double scale() const; }; @@ -954,14 +1002,20 @@ class StereoCamera { static size_t Dim(); // Transformations and measurement functions - gtsam::StereoPoint2 project(const gtsam::Point3& point); + gtsam::StereoPoint2 project(const gtsam::Point3& point) const; gtsam::Point3 backproject(const gtsam::StereoPoint2& p) const; + // project with Jacobian + gtsam::StereoPoint2 project2(const gtsam::Point3& point, + Eigen::Ref H1, + Eigen::Ref H2) const; + + gtsam::Point3 backproject2(const gtsam::StereoPoint2& p, + Eigen::Ref H1, + Eigen::Ref H2) const; + // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -971,27 +1025,34 @@ class StereoCamera { gtsam::Point3 triangulatePoint3(const gtsam::Pose3Vector& poses, gtsam::Cal3_S2* sharedCal, const gtsam::Point2Vector& measurements, - double rank_tol, bool optimize); + double rank_tol, bool optimize, + const gtsam::SharedNoiseModel& model = nullptr); gtsam::Point3 triangulatePoint3(const gtsam::Pose3Vector& poses, gtsam::Cal3DS2* sharedCal, const gtsam::Point2Vector& measurements, - double rank_tol, bool optimize); + double rank_tol, bool optimize, + const gtsam::SharedNoiseModel& model = nullptr); gtsam::Point3 triangulatePoint3(const gtsam::Pose3Vector& poses, gtsam::Cal3Bundler* sharedCal, const gtsam::Point2Vector& measurements, - double rank_tol, bool optimize); + double rank_tol, bool optimize, + const gtsam::SharedNoiseModel& model = nullptr); gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3_S2& cameras, const gtsam::Point2Vector& measurements, - double rank_tol, bool optimize); + double rank_tol, bool optimize, + const gtsam::SharedNoiseModel& model = nullptr); gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3Bundler& cameras, const gtsam::Point2Vector& measurements, - double rank_tol, bool optimize); + double rank_tol, bool optimize, + const gtsam::SharedNoiseModel& model = nullptr); gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3Fisheye& cameras, const gtsam::Point2Vector& measurements, - double rank_tol, bool optimize); + double rank_tol, bool optimize, + const gtsam::SharedNoiseModel& model = nullptr); gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3Unified& cameras, const gtsam::Point2Vector& measurements, - double rank_tol, bool optimize); + double rank_tol, bool optimize, + const gtsam::SharedNoiseModel& model = nullptr); gtsam::Point3 triangulateNonlinear(const gtsam::Pose3Vector& poses, gtsam::Cal3_S2* sharedCal, const gtsam::Point2Vector& measurements, diff --git a/gtsam/geometry/tests/testCal3Bundler.cpp b/gtsam/geometry/tests/testCal3Bundler.cpp index cd576f900..020cab2f9 100644 --- a/gtsam/geometry/tests/testCal3Bundler.cpp +++ b/gtsam/geometry/tests/testCal3Bundler.cpp @@ -160,7 +160,7 @@ TEST(Cal3Bundler, retract) { } /* ************************************************************************* */ -TEST(Cal3_S2, Print) { +TEST(Cal3Bundler, Print) { Cal3Bundler cal(1, 2, 3, 4, 5); std::stringstream os; os << "f: " << cal.fx() << ", k1: " << cal.k1() << ", k2: " << cal.k2() diff --git a/gtsam/geometry/tests/testPose2.cpp b/gtsam/geometry/tests/testPose2.cpp index d17dc7689..de779cc75 100644 --- a/gtsam/geometry/tests/testPose2.cpp +++ b/gtsam/geometry/tests/testPose2.cpp @@ -717,123 +717,112 @@ TEST( Pose2, range_pose ) /* ************************************************************************* */ TEST(Pose2, align_1) { - Pose2 expected(Rot2::fromAngle(0), Point2(10,10)); - - vector correspondences; - Point2Pair pq1(make_pair(Point2(0,0), Point2(10,10))); - Point2Pair pq2(make_pair(Point2(20,10), Point2(30,20))); - correspondences += pq1, pq2; - - boost::optional actual = align(correspondences); - EXPECT(assert_equal(expected, *actual)); + Pose2 expected(Rot2::fromAngle(0), Point2(10, 10)); + Point2Pairs ab_pairs {{Point2(10, 10), Point2(0, 0)}, + {Point2(30, 20), Point2(20, 10)}}; + boost::optional aTb = Pose2::Align(ab_pairs); + EXPECT(assert_equal(expected, *aTb)); } TEST(Pose2, align_2) { - Point2 t(20,10); + Point2 t(20, 10); Rot2 R = Rot2::fromAngle(M_PI/2.0); Pose2 expected(R, t); - vector correspondences; - Point2 p1(0,0), p2(10,0); - Point2 q1 = expected.transformFrom(p1), q2 = expected.transformFrom(p2); - EXPECT(assert_equal(Point2(20,10),q1)); - EXPECT(assert_equal(Point2(20,20),q2)); - Point2Pair pq1(make_pair(p1, q1)); - Point2Pair pq2(make_pair(p2, q2)); - correspondences += pq1, pq2; + Point2 b1(0, 0), b2(10, 0); + Point2Pairs ab_pairs {{expected.transformFrom(b1), b1}, + {expected.transformFrom(b2), b2}}; - boost::optional actual = align(correspondences); - EXPECT(assert_equal(expected, *actual)); + boost::optional aTb = Pose2::Align(ab_pairs); + EXPECT(assert_equal(expected, *aTb)); } namespace align_3 { - Point2 t(10,10); + Point2 t(10, 10); Pose2 expected(Rot2::fromAngle(2*M_PI/3), t); - Point2 p1(0,0), p2(10,0), p3(10,10); - Point2 q1 = expected.transformFrom(p1), q2 = expected.transformFrom(p2), q3 = expected.transformFrom(p3); + Point2 b1(0, 0), b2(10, 0), b3(10, 10); + Point2 a1 = expected.transformFrom(b1), + a2 = expected.transformFrom(b2), + a3 = expected.transformFrom(b3); } TEST(Pose2, align_3) { using namespace align_3; - vector correspondences; - Point2Pair pq1(make_pair(p1, q1)); - Point2Pair pq2(make_pair(p2, q2)); - Point2Pair pq3(make_pair(p3, q3)); - correspondences += pq1, pq2, pq3; + Point2Pairs ab_pairs; + Point2Pair ab1(make_pair(a1, b1)); + Point2Pair ab2(make_pair(a2, b2)); + Point2Pair ab3(make_pair(a3, b3)); + ab_pairs += ab1, ab2, ab3; - boost::optional actual = align(correspondences); - EXPECT(assert_equal(expected, *actual)); + boost::optional aTb = Pose2::Align(ab_pairs); + EXPECT(assert_equal(expected, *aTb)); } namespace { /* ************************************************************************* */ // Prototype code to align two triangles using a rigid transform /* ************************************************************************* */ - struct Triangle { size_t i_,j_,k_;}; + struct Triangle { size_t i_, j_, k_;}; - boost::optional align2(const Point2Vector& ps, const Point2Vector& qs, + boost::optional align2(const Point2Vector& as, const Point2Vector& bs, const pair& trianglePair) { const Triangle& t1 = trianglePair.first, t2 = trianglePair.second; - vector correspondences; - correspondences += make_pair(ps[t1.i_],qs[t2.i_]), make_pair(ps[t1.j_],qs[t2.j_]), make_pair(ps[t1.k_],qs[t2.k_]); - return align(correspondences); + Point2Pairs ab_pairs = {{as[t1.i_], bs[t2.i_]}, + {as[t1.j_], bs[t2.j_]}, + {as[t1.k_], bs[t2.k_]}}; + return Pose2::Align(ab_pairs); } } TEST(Pose2, align_4) { using namespace align_3; - Point2Vector ps,qs; - ps += p1, p2, p3; - qs += q3, q1, q2; // note in 3,1,2 order ! + Point2Vector as, bs; + as += a1, a2, a3; + bs += b3, b1, b2; // note in 3,1,2 order ! Triangle t1; t1.i_=0; t1.j_=1; t1.k_=2; Triangle t2; t2.i_=1; t2.j_=2; t2.k_=0; - boost::optional actual = align2(ps, qs, make_pair(t1,t2)); + boost::optional actual = align2(as, bs, {t1, t2}); EXPECT(assert_equal(expected, *actual)); } //****************************************************************************** +namespace { +Pose2 id; Pose2 T1(M_PI / 4.0, Point2(sqrt(0.5), sqrt(0.5))); Pose2 T2(M_PI / 2.0, Point2(0.0, 2.0)); +} // namespace //****************************************************************************** -TEST(Pose2 , Invariants) { - Pose2 id; - - EXPECT(check_group_invariants(id,id)); - EXPECT(check_group_invariants(id,T1)); - EXPECT(check_group_invariants(T2,id)); - EXPECT(check_group_invariants(T2,T1)); - - EXPECT(check_manifold_invariants(id,id)); - EXPECT(check_manifold_invariants(id,T1)); - EXPECT(check_manifold_invariants(T2,id)); - EXPECT(check_manifold_invariants(T2,T1)); +TEST(Pose2, Invariants) { + EXPECT(check_group_invariants(id, id)); + EXPECT(check_group_invariants(id, T1)); + EXPECT(check_group_invariants(T2, id)); + EXPECT(check_group_invariants(T2, T1)); + EXPECT(check_manifold_invariants(id, id)); + EXPECT(check_manifold_invariants(id, T1)); + EXPECT(check_manifold_invariants(T2, id)); + EXPECT(check_manifold_invariants(T2, T1)); } //****************************************************************************** -TEST(Pose2 , LieGroupDerivatives) { - Pose2 id; - - CHECK_LIE_GROUP_DERIVATIVES(id,id); - CHECK_LIE_GROUP_DERIVATIVES(id,T2); - CHECK_LIE_GROUP_DERIVATIVES(T2,id); - CHECK_LIE_GROUP_DERIVATIVES(T2,T1); - +TEST(Pose2, LieGroupDerivatives) { + CHECK_LIE_GROUP_DERIVATIVES(id, id); + CHECK_LIE_GROUP_DERIVATIVES(id, T2); + CHECK_LIE_GROUP_DERIVATIVES(T2, id); + CHECK_LIE_GROUP_DERIVATIVES(T2, T1); } //****************************************************************************** -TEST(Pose2 , ChartDerivatives) { - Pose2 id; - - CHECK_CHART_DERIVATIVES(id,id); - CHECK_CHART_DERIVATIVES(id,T2); - CHECK_CHART_DERIVATIVES(T2,id); - CHECK_CHART_DERIVATIVES(T2,T1); +TEST(Pose2, ChartDerivatives) { + CHECK_CHART_DERIVATIVES(id, id); + CHECK_CHART_DERIVATIVES(id, T2); + CHECK_CHART_DERIVATIVES(T2, id); + CHECK_CHART_DERIVATIVES(T2, T1); } //****************************************************************************** diff --git a/gtsam/geometry/tests/testPose3.cpp b/gtsam/geometry/tests/testPose3.cpp index 7c1fa81e6..e1d3d5ea2 100644 --- a/gtsam/geometry/tests/testPose3.cpp +++ b/gtsam/geometry/tests/testPose3.cpp @@ -145,6 +145,81 @@ TEST(Pose3, Adjoint_full) EXPECT(assert_equal(expected3, Pose3::Expmap(xiprime3), 1e-6)); } +/* ************************************************************************* */ +// Check Adjoint numerical derivatives +TEST(Pose3, Adjoint_jacobians) +{ + Vector6 xi = (Vector6() << 0.1, 1.2, 2.3, 3.1, 1.4, 4.5).finished(); + + // Check evaluation sanity check + EQUALITY(static_cast(T.AdjointMap() * xi), T.Adjoint(xi)); + EQUALITY(static_cast(T2.AdjointMap() * xi), T2.Adjoint(xi)); + EQUALITY(static_cast(T3.AdjointMap() * xi), T3.Adjoint(xi)); + + // Check jacobians + Matrix6 actualH1, actualH2, expectedH1, expectedH2; + std::function Adjoint_proxy = + [&](const Pose3& T, const Vector6& xi) { return T.Adjoint(xi); }; + + T.Adjoint(xi, actualH1, actualH2); + expectedH1 = numericalDerivative21(Adjoint_proxy, T, xi); + expectedH2 = numericalDerivative22(Adjoint_proxy, T, xi); + EXPECT(assert_equal(expectedH1, actualH1)); + EXPECT(assert_equal(expectedH2, actualH2)); + + T2.Adjoint(xi, actualH1, actualH2); + expectedH1 = numericalDerivative21(Adjoint_proxy, T2, xi); + expectedH2 = numericalDerivative22(Adjoint_proxy, T2, xi); + EXPECT(assert_equal(expectedH1, actualH1)); + EXPECT(assert_equal(expectedH2, actualH2)); + + T3.Adjoint(xi, actualH1, actualH2); + expectedH1 = numericalDerivative21(Adjoint_proxy, T3, xi); + expectedH2 = numericalDerivative22(Adjoint_proxy, T3, xi); + EXPECT(assert_equal(expectedH1, actualH1)); + EXPECT(assert_equal(expectedH2, actualH2)); +} + +/* ************************************************************************* */ +// Check AdjointTranspose and jacobians +TEST(Pose3, AdjointTranspose) +{ + Vector6 xi = (Vector6() << 0.1, 1.2, 2.3, 3.1, 1.4, 4.5).finished(); + + // Check evaluation + EQUALITY(static_cast(T.AdjointMap().transpose() * xi), + T.AdjointTranspose(xi)); + EQUALITY(static_cast(T2.AdjointMap().transpose() * xi), + T2.AdjointTranspose(xi)); + EQUALITY(static_cast(T3.AdjointMap().transpose() * xi), + T3.AdjointTranspose(xi)); + + // Check jacobians + Matrix6 actualH1, actualH2, expectedH1, expectedH2; + std::function AdjointTranspose_proxy = + [&](const Pose3& T, const Vector6& xi) { + return T.AdjointTranspose(xi); + }; + + T.AdjointTranspose(xi, actualH1, actualH2); + expectedH1 = numericalDerivative21(AdjointTranspose_proxy, T, xi); + expectedH2 = numericalDerivative22(AdjointTranspose_proxy, T, xi); + EXPECT(assert_equal(expectedH1, actualH1, 1e-8)); + EXPECT(assert_equal(expectedH2, actualH2)); + + T2.AdjointTranspose(xi, actualH1, actualH2); + expectedH1 = numericalDerivative21(AdjointTranspose_proxy, T2, xi); + expectedH2 = numericalDerivative22(AdjointTranspose_proxy, T2, xi); + EXPECT(assert_equal(expectedH1, actualH1, 1e-8)); + EXPECT(assert_equal(expectedH2, actualH2)); + + T3.AdjointTranspose(xi, actualH1, actualH2); + expectedH1 = numericalDerivative21(AdjointTranspose_proxy, T3, xi); + expectedH2 = numericalDerivative22(AdjointTranspose_proxy, T3, xi); + EXPECT(assert_equal(expectedH1, actualH1, 1e-8)); + EXPECT(assert_equal(expectedH2, actualH2)); +} + /* ************************************************************************* */ // assert that T*wedge(xi)*T^-1 is equal to wedge(Ad_T(xi)) TEST(Pose3, Adjoint_hat) @@ -837,16 +912,20 @@ Vector6 testDerivAdjoint(const Vector6& xi, const Vector6& v) { } TEST( Pose3, adjoint) { - Vector expected = testDerivAdjoint(screwPose3::xi, screwPose3::xi); + Vector6 v = (Vector6() << 1, 2, 3, 4, 5, 6).finished(); + Vector expected = testDerivAdjoint(screwPose3::xi, v); - Matrix actualH; - Vector actual = Pose3::adjoint(screwPose3::xi, screwPose3::xi, actualH); + Matrix actualH1, actualH2; + Vector actual = Pose3::adjoint(screwPose3::xi, v, actualH1, actualH2); - Matrix numericalH = numericalDerivative21( - testDerivAdjoint, screwPose3::xi, screwPose3::xi, 1e-5); + Matrix numericalH1 = numericalDerivative21( + testDerivAdjoint, screwPose3::xi, v, 1e-5); + Matrix numericalH2 = numericalDerivative22( + testDerivAdjoint, screwPose3::xi, v, 1e-5); EXPECT(assert_equal(expected,actual,1e-5)); - EXPECT(assert_equal(numericalH,actualH,1e-5)); + EXPECT(assert_equal(numericalH1,actualH1,1e-5)); + EXPECT(assert_equal(numericalH2,actualH2,1e-5)); } /* ************************************************************************* */ @@ -859,14 +938,17 @@ TEST( Pose3, adjointTranspose) { Vector v = (Vector(6) << 0.04, 0.05, 0.06, 4.0, 5.0, 6.0).finished(); Vector expected = testDerivAdjointTranspose(xi, v); - Matrix actualH; - Vector actual = Pose3::adjointTranspose(xi, v, actualH); + Matrix actualH1, actualH2; + Vector actual = Pose3::adjointTranspose(xi, v, actualH1, actualH2); - Matrix numericalH = numericalDerivative21( + Matrix numericalH1 = numericalDerivative21( + testDerivAdjointTranspose, xi, v, 1e-5); + Matrix numericalH2 = numericalDerivative22( testDerivAdjointTranspose, xi, v, 1e-5); EXPECT(assert_equal(expected,actual,1e-15)); - EXPECT(assert_equal(numericalH,actualH,1e-5)); + EXPECT(assert_equal(numericalH1,actualH1,1e-5)); + EXPECT(assert_equal(numericalH2,actualH2,1e-5)); } /* ************************************************************************* */ diff --git a/gtsam/geometry/tests/testQuaternion.cpp b/gtsam/geometry/tests/testQuaternion.cpp index e862b94ad..281c71f7c 100644 --- a/gtsam/geometry/tests/testQuaternion.cpp +++ b/gtsam/geometry/tests/testQuaternion.cpp @@ -80,12 +80,6 @@ TEST(Quaternion , Compose) { EXPECT(traits::Equals(expected, actual)); } -//****************************************************************************** -Vector3 Q_z_axis(0, 0, 1); -Q id(Eigen::AngleAxisd(0, Q_z_axis)); -Q R1(Eigen::AngleAxisd(1, Q_z_axis)); -Q R2(Eigen::AngleAxisd(2, Vector3(0, 1, 0))); - //****************************************************************************** TEST(Quaternion , Between) { Vector3 z_axis(0, 0, 1); @@ -108,7 +102,15 @@ TEST(Quaternion , Inverse) { } //****************************************************************************** -TEST(Quaternion , Invariants) { +namespace { +Vector3 Q_z_axis(0, 0, 1); +Q id(Eigen::AngleAxisd(0, Q_z_axis)); +Q R1(Eigen::AngleAxisd(1, Q_z_axis)); +Q R2(Eigen::AngleAxisd(2, Vector3(0, 1, 0))); +} // namespace + +//****************************************************************************** +TEST(Quaternion, Invariants) { EXPECT(check_group_invariants(id, id)); EXPECT(check_group_invariants(id, R1)); EXPECT(check_group_invariants(R2, id)); @@ -121,7 +123,7 @@ TEST(Quaternion , Invariants) { } //****************************************************************************** -TEST(Quaternion , LieGroupDerivatives) { +TEST(Quaternion, LieGroupDerivatives) { CHECK_LIE_GROUP_DERIVATIVES(id, id); CHECK_LIE_GROUP_DERIVATIVES(id, R2); CHECK_LIE_GROUP_DERIVATIVES(R2, id); @@ -129,7 +131,7 @@ TEST(Quaternion , LieGroupDerivatives) { } //****************************************************************************** -TEST(Quaternion , ChartDerivatives) { +TEST(Quaternion, ChartDerivatives) { CHECK_CHART_DERIVATIVES(id, id); CHECK_CHART_DERIVATIVES(id, R2); CHECK_CHART_DERIVATIVES(R2, id); diff --git a/gtsam/geometry/tests/testRot2.cpp b/gtsam/geometry/tests/testRot2.cpp index 7cd27a9da..5a087edcd 100644 --- a/gtsam/geometry/tests/testRot2.cpp +++ b/gtsam/geometry/tests/testRot2.cpp @@ -156,44 +156,39 @@ TEST( Rot2, relativeBearing ) } //****************************************************************************** +namespace { +Rot2 id; Rot2 T1(0.1); Rot2 T2(0.2); +} // namespace //****************************************************************************** -TEST(Rot2 , Invariants) { - Rot2 id; - - EXPECT(check_group_invariants(id,id)); - EXPECT(check_group_invariants(id,T1)); - EXPECT(check_group_invariants(T2,id)); - EXPECT(check_group_invariants(T2,T1)); - - EXPECT(check_manifold_invariants(id,id)); - EXPECT(check_manifold_invariants(id,T1)); - EXPECT(check_manifold_invariants(T2,id)); - EXPECT(check_manifold_invariants(T2,T1)); +TEST(Rot2, Invariants) { + EXPECT(check_group_invariants(id, id)); + EXPECT(check_group_invariants(id, T1)); + EXPECT(check_group_invariants(T2, id)); + EXPECT(check_group_invariants(T2, T1)); + EXPECT(check_manifold_invariants(id, id)); + EXPECT(check_manifold_invariants(id, T1)); + EXPECT(check_manifold_invariants(T2, id)); + EXPECT(check_manifold_invariants(T2, T1)); } //****************************************************************************** -TEST(Rot2 , LieGroupDerivatives) { - Rot2 id; - - CHECK_LIE_GROUP_DERIVATIVES(id,id); - CHECK_LIE_GROUP_DERIVATIVES(id,T2); - CHECK_LIE_GROUP_DERIVATIVES(T2,id); - CHECK_LIE_GROUP_DERIVATIVES(T2,T1); - +TEST(Rot2, LieGroupDerivatives) { + CHECK_LIE_GROUP_DERIVATIVES(id, id); + CHECK_LIE_GROUP_DERIVATIVES(id, T2); + CHECK_LIE_GROUP_DERIVATIVES(T2, id); + CHECK_LIE_GROUP_DERIVATIVES(T2, T1); } //****************************************************************************** -TEST(Rot2 , ChartDerivatives) { - Rot2 id; - - CHECK_CHART_DERIVATIVES(id,id); - CHECK_CHART_DERIVATIVES(id,T2); - CHECK_CHART_DERIVATIVES(T2,id); - CHECK_CHART_DERIVATIVES(T2,T1); +TEST(Rot2, ChartDerivatives) { + CHECK_CHART_DERIVATIVES(id, id); + CHECK_CHART_DERIVATIVES(id, T2); + CHECK_CHART_DERIVATIVES(T2, id); + CHECK_CHART_DERIVATIVES(T2, T1); } /* ************************************************************************* */ diff --git a/gtsam/geometry/tests/testRot3.cpp b/gtsam/geometry/tests/testRot3.cpp index 34f90c8cc..1df342d57 100644 --- a/gtsam/geometry/tests/testRot3.cpp +++ b/gtsam/geometry/tests/testRot3.cpp @@ -122,6 +122,21 @@ TEST( Rot3, AxisAngle) CHECK(assert_equal(expected,actual3,1e-5)); } +/* ************************************************************************* */ +TEST( Rot3, AxisAngle2) +{ + // constructor from a rotation matrix, as doubles in *row-major* order. + Rot3 R1(-0.999957, 0.00922903, 0.00203116, 0.00926964, 0.999739, 0.0208927, -0.0018374, 0.0209105, -0.999781); + + Unit3 actualAxis; + double actualAngle; + // convert Rot3 to quaternion using GTSAM + std::tie(actualAxis, actualAngle) = R1.axisAngle(); + + double expectedAngle = 3.1396582; + CHECK(assert_equal(expectedAngle, actualAngle, 1e-5)); +} + /* ************************************************************************* */ TEST( Rot3, Rodrigues) { @@ -181,13 +196,13 @@ TEST( Rot3, retract) } /* ************************************************************************* */ -TEST(Rot3, log) { +TEST( Rot3, log) { static const double PI = boost::math::constants::pi(); Vector w; Rot3 R; #define CHECK_OMEGA(X, Y, Z) \ - w = (Vector(3) << X, Y, Z).finished(); \ + w = (Vector(3) << (X), (Y), (Z)).finished(); \ R = Rot3::Rodrigues(w); \ EXPECT(assert_equal(w, Rot3::Logmap(R), 1e-12)); @@ -219,17 +234,17 @@ TEST(Rot3, log) { CHECK_OMEGA(0, 0, PI) // Windows and Linux have flipped sign in quaternion mode -#if !defined(__APPLE__) && defined(GTSAM_USE_QUATERNIONS) +//#if !defined(__APPLE__) && defined(GTSAM_USE_QUATERNIONS) w = (Vector(3) << x * PI, y * PI, z * PI).finished(); R = Rot3::Rodrigues(w); EXPECT(assert_equal(Vector(-w), Rot3::Logmap(R), 1e-12)); -#else - CHECK_OMEGA(x * PI, y * PI, z * PI) -#endif +//#else +// CHECK_OMEGA(x * PI, y * PI, z * PI) +//#endif // Check 360 degree rotations #define CHECK_OMEGA_ZERO(X, Y, Z) \ - w = (Vector(3) << X, Y, Z).finished(); \ + w = (Vector(3) << (X), (Y), (Z)).finished(); \ R = Rot3::Rodrigues(w); \ EXPECT(assert_equal((Vector)Z_3x1, Rot3::Logmap(R))); @@ -247,15 +262,15 @@ TEST(Rot3, log) { // Rot3's Logmap returns different, but equivalent compacted // axis-angle vectors depending on whether Rot3 is implemented // by Quaternions or SO3. - #if defined(GTSAM_USE_QUATERNIONS) - // Quaternion bounds angle to [-pi, pi] resulting in ~179.9 degrees - EXPECT(assert_equal(Vector3(0.264451979, -0.742197651, -3.04098211), +#if defined(GTSAM_USE_QUATERNIONS) + // Quaternion bounds angle to [-pi, pi] resulting in ~179.9 degrees + EXPECT(assert_equal(Vector3(0.264451979, -0.742197651, -3.04098211), + (Vector)Rot3::Logmap(Rlund), 1e-8)); +#else + // SO3 will be approximate because of the non-orthogonality + EXPECT(assert_equal(Vector3(0.264452, -0.742197708, -3.04098184), (Vector)Rot3::Logmap(Rlund), 1e-8)); - #else - // SO3 does not bound angle resulting in ~180.1 degrees - EXPECT(assert_equal(Vector3(-0.264544406, 0.742217405, 3.04117314), - (Vector)Rot3::Logmap(Rlund), 1e-8)); - #endif +#endif } /* ************************************************************************* */ @@ -625,46 +640,44 @@ TEST( Rot3, slerp) } //****************************************************************************** +namespace { +Rot3 id; Rot3 T1(Rot3::AxisAngle(Vector3(0, 0, 1), 1)); Rot3 T2(Rot3::AxisAngle(Vector3(0, 1, 0), 2)); +} // namespace //****************************************************************************** -TEST(Rot3 , Invariants) { - Rot3 id; +TEST(Rot3, Invariants) { + EXPECT(check_group_invariants(id, id)); + EXPECT(check_group_invariants(id, T1)); + EXPECT(check_group_invariants(T2, id)); + EXPECT(check_group_invariants(T2, T1)); + EXPECT(check_group_invariants(T1, T2)); - EXPECT(check_group_invariants(id,id)); - EXPECT(check_group_invariants(id,T1)); - EXPECT(check_group_invariants(T2,id)); - EXPECT(check_group_invariants(T2,T1)); - EXPECT(check_group_invariants(T1,T2)); - - EXPECT(check_manifold_invariants(id,id)); - EXPECT(check_manifold_invariants(id,T1)); - EXPECT(check_manifold_invariants(T2,id)); - EXPECT(check_manifold_invariants(T2,T1)); - EXPECT(check_manifold_invariants(T1,T2)); + EXPECT(check_manifold_invariants(id, id)); + EXPECT(check_manifold_invariants(id, T1)); + EXPECT(check_manifold_invariants(T2, id)); + EXPECT(check_manifold_invariants(T2, T1)); + EXPECT(check_manifold_invariants(T1, T2)); } //****************************************************************************** -TEST(Rot3 , LieGroupDerivatives) { - Rot3 id; - - CHECK_LIE_GROUP_DERIVATIVES(id,id); - CHECK_LIE_GROUP_DERIVATIVES(id,T2); - CHECK_LIE_GROUP_DERIVATIVES(T2,id); - CHECK_LIE_GROUP_DERIVATIVES(T1,T2); - CHECK_LIE_GROUP_DERIVATIVES(T2,T1); +TEST(Rot3, LieGroupDerivatives) { + CHECK_LIE_GROUP_DERIVATIVES(id, id); + CHECK_LIE_GROUP_DERIVATIVES(id, T2); + CHECK_LIE_GROUP_DERIVATIVES(T2, id); + CHECK_LIE_GROUP_DERIVATIVES(T1, T2); + CHECK_LIE_GROUP_DERIVATIVES(T2, T1); } //****************************************************************************** -TEST(Rot3 , ChartDerivatives) { - Rot3 id; +TEST(Rot3, ChartDerivatives) { if (ROT3_DEFAULT_COORDINATES_MODE == Rot3::EXPMAP) { - CHECK_CHART_DERIVATIVES(id,id); - CHECK_CHART_DERIVATIVES(id,T2); - CHECK_CHART_DERIVATIVES(T2,id); - CHECK_CHART_DERIVATIVES(T1,T2); - CHECK_CHART_DERIVATIVES(T2,T1); + CHECK_CHART_DERIVATIVES(id, id); + CHECK_CHART_DERIVATIVES(id, T2); + CHECK_CHART_DERIVATIVES(T2, id); + CHECK_CHART_DERIVATIVES(T1, T2); + CHECK_CHART_DERIVATIVES(T2, T1); } } diff --git a/gtsam/geometry/tests/testSO3.cpp b/gtsam/geometry/tests/testSO3.cpp index 910d482b0..96c1cce32 100644 --- a/gtsam/geometry/tests/testSO3.cpp +++ b/gtsam/geometry/tests/testSO3.cpp @@ -67,10 +67,12 @@ TEST(SO3, ClosestTo) { } //****************************************************************************** +namespace { SO3 id; Vector3 z_axis(0, 0, 1), v2(1, 2, 0), v3(1, 2, 3); SO3 R1(Eigen::AngleAxisd(0.1, z_axis)); SO3 R2(Eigen::AngleAxisd(0.2, z_axis)); +} // namespace /* ************************************************************************* */ TEST(SO3, ChordalMean) { @@ -79,16 +81,16 @@ TEST(SO3, ChordalMean) { } //****************************************************************************** +// Check that Hat specialization is equal to dynamic version TEST(SO3, Hat) { - // Check that Hat specialization is equal to dynamic version EXPECT(assert_equal(SO3::Hat(z_axis), SOn::Hat(z_axis))); EXPECT(assert_equal(SO3::Hat(v2), SOn::Hat(v2))); EXPECT(assert_equal(SO3::Hat(v3), SOn::Hat(v3))); } //****************************************************************************** +// Check that Hat specialization is equal to dynamic version TEST(SO3, Vee) { - // Check that Hat specialization is equal to dynamic version auto X1 = SOn::Hat(z_axis), X2 = SOn::Hat(v2), X3 = SOn::Hat(v3); EXPECT(assert_equal(SO3::Vee(X1), SOn::Vee(X1))); EXPECT(assert_equal(SO3::Vee(X2), SOn::Vee(X2))); diff --git a/gtsam/geometry/tests/testSO4.cpp b/gtsam/geometry/tests/testSO4.cpp index 5486755f7..fa550723a 100644 --- a/gtsam/geometry/tests/testSO4.cpp +++ b/gtsam/geometry/tests/testSO4.cpp @@ -48,6 +48,14 @@ TEST(SO4, Concept) { } //****************************************************************************** +TEST(SO4, Random) { + std::mt19937 rng(42); + auto Q = SO4::Random(rng); + EXPECT_LONGS_EQUAL(4, Q.matrix().rows()); +} + +//****************************************************************************** +namespace { SO4 id; Vector6 v1 = (Vector(6) << 0, 0, 0, 0.1, 0, 0).finished(); SO4 Q1 = SO4::Expmap(v1); @@ -55,13 +63,8 @@ Vector6 v2 = (Vector(6) << 0.00, 0.00, 0.00, 0.01, 0.02, 0.03).finished(); SO4 Q2 = SO4::Expmap(v2); Vector6 v3 = (Vector(6) << 1, 2, 3, 4, 5, 6).finished(); SO4 Q3 = SO4::Expmap(v3); +} // namespace -//****************************************************************************** -TEST(SO4, Random) { - std::mt19937 rng(42); - auto Q = SO4::Random(rng); - EXPECT_LONGS_EQUAL(4, Q.matrix().rows()); -} //****************************************************************************** TEST(SO4, Expmap) { // If we do exponential map in SO(3) subgroup, topleft should be equal to R1. @@ -84,16 +87,16 @@ TEST(SO4, Expmap) { } //****************************************************************************** +// Check that Hat specialization is equal to dynamic version TEST(SO4, Hat) { - // Check that Hat specialization is equal to dynamic version EXPECT(assert_equal(SO4::Hat(v1), SOn::Hat(v1))); EXPECT(assert_equal(SO4::Hat(v2), SOn::Hat(v2))); EXPECT(assert_equal(SO4::Hat(v3), SOn::Hat(v3))); } //****************************************************************************** +// Check that Hat specialization is equal to dynamic version TEST(SO4, Vee) { - // Check that Hat specialization is equal to dynamic version auto X1 = SOn::Hat(v1), X2 = SOn::Hat(v2), X3 = SOn::Hat(v3); EXPECT(assert_equal(SO4::Vee(X1), SOn::Vee(X1))); EXPECT(assert_equal(SO4::Vee(X2), SOn::Vee(X2))); @@ -116,8 +119,8 @@ TEST(SO4, Retract) { } //****************************************************************************** +// Check that Cayley is identical to dynamic version TEST(SO4, Local) { - // Check that Cayley is identical to dynamic version EXPECT( assert_equal(id.localCoordinates(Q1), SOn(4).localCoordinates(SOn(Q1)))); EXPECT( @@ -166,9 +169,7 @@ TEST(SO4, vec) { Matrix actualH; const Vector16 actual = Q2.vec(actualH); EXPECT(assert_equal(expected, actual)); - std::function f = [](const SO4& Q) { - return Q.vec(); - }; + std::function f = [](const SO4& Q) { return Q.vec(); }; const Matrix numericalH = numericalDerivative11(f, Q2, 1e-5); EXPECT(assert_equal(numericalH, actualH)); } diff --git a/gtsam/geometry/tests/testSimilarity2.cpp b/gtsam/geometry/tests/testSimilarity2.cpp new file mode 100644 index 000000000..dd4fd0efd --- /dev/null +++ b/gtsam/geometry/tests/testSimilarity2.cpp @@ -0,0 +1,66 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file testSimilarity2.cpp + * @brief Unit tests for Similarity2 class + * @author Varun Agrawal + */ + +#include +#include +#include +#include +#include + +#include + +using namespace std::placeholders; +using namespace gtsam; +using namespace std; + +GTSAM_CONCEPT_TESTABLE_INST(Similarity2) + +static const Point2 P(0.2, 0.7); +static const Rot2 R = Rot2::fromAngle(0.3); +static const double s = 4; + +const double degree = M_PI / 180; + +//****************************************************************************** +TEST(Similarity2, Concepts) { + BOOST_CONCEPT_ASSERT((IsGroup)); + BOOST_CONCEPT_ASSERT((IsManifold)); + BOOST_CONCEPT_ASSERT((IsLieGroup)); +} + +//****************************************************************************** +TEST(Similarity2, Constructors) { + Similarity2 sim2_Construct1; + Similarity2 sim2_Construct2(s); + Similarity2 sim2_Construct3(R, P, s); + Similarity2 sim2_Construct4(R.matrix(), P, s); +} + +//****************************************************************************** +TEST(Similarity2, Getters) { + Similarity2 sim2_default; + EXPECT(assert_equal(Rot2(), sim2_default.rotation())); + EXPECT(assert_equal(Point2(0, 0), sim2_default.translation())); + EXPECT_DOUBLES_EQUAL(1.0, sim2_default.scale(), 1e-9); +} + +//****************************************************************************** +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +//****************************************************************************** diff --git a/gtsam/geometry/tests/testSimilarity3.cpp b/gtsam/geometry/tests/testSimilarity3.cpp index 428422072..7a134f6ef 100644 --- a/gtsam/geometry/tests/testSimilarity3.cpp +++ b/gtsam/geometry/tests/testSimilarity3.cpp @@ -458,18 +458,18 @@ TEST(Similarity3, Optimization2) { Values result; result = LevenbergMarquardtOptimizer(graph, initial).optimize(); //result.print("Optimized Estimate\n"); - Pose3 p1, p2, p3, p4, p5; - p1 = Pose3(result.at(X(1))); - p2 = Pose3(result.at(X(2))); - p3 = Pose3(result.at(X(3))); - p4 = Pose3(result.at(X(4))); - p5 = Pose3(result.at(X(5))); + Similarity3 p1, p2, p3, p4, p5; + p1 = result.at(X(1)); + p2 = result.at(X(2)); + p3 = result.at(X(3)); + p4 = result.at(X(4)); + p5 = result.at(X(5)); - //p1.print("Pose1"); - //p2.print("Pose2"); - //p3.print("Pose3"); - //p4.print("Pose4"); - //p5.print("Pose5"); + //p1.print("Similarity1"); + //p2.print("Similarity2"); + //p3.print("Similarity3"); + //p4.print("Similarity4"); + //p5.print("Similarity5"); Similarity3 expected(0.7); EXPECT(assert_equal(expected, result.at(X(5)), 0.4)); diff --git a/gtsam/geometry/tests/testSimpleCamera.cpp b/gtsam/geometry/tests/testSimpleCamera.cpp index 18a25c553..173ccf05b 100644 --- a/gtsam/geometry/tests/testSimpleCamera.cpp +++ b/gtsam/geometry/tests/testSimpleCamera.cpp @@ -26,7 +26,7 @@ using namespace std; using namespace gtsam; -#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V41 +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 static const Cal3_S2 K(625, 625, 0, 0, 0); diff --git a/gtsam/geometry/tests/testSphericalCamera.cpp b/gtsam/geometry/tests/testSphericalCamera.cpp new file mode 100644 index 000000000..4bc851f35 --- /dev/null +++ b/gtsam/geometry/tests/testSphericalCamera.cpp @@ -0,0 +1,163 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file SphericalCamera.h + * @brief Calibrated camera with spherical projection + * @date Aug 26, 2021 + * @author Luca Carlone + */ + +#include +#include +#include +#include + +#include +#include + +using namespace std::placeholders; +using namespace std; +using namespace gtsam; + +typedef SphericalCamera Camera; + +// static const Cal3_S2 K(625, 625, 0, 0, 0); +// +static const Pose3 pose(Rot3(Vector3(1, -1, -1).asDiagonal()), + Point3(0, 0, 0.5)); +static const Camera camera(pose); +// +static const Pose3 pose1(Rot3(), Point3(0, 1, 0.5)); +static const Camera camera1(pose1); + +static const Point3 point1(-0.08, -0.08, 0.0); +static const Point3 point2(-0.08, 0.08, 0.0); +static const Point3 point3(0.08, 0.08, 0.0); +static const Point3 point4(0.08, -0.08, 0.0); + +// manually computed in matlab +static const Unit3 bearing1(-0.156054862928174, 0.156054862928174, + 0.975342893301088); +static const Unit3 bearing2(-0.156054862928174, -0.156054862928174, + 0.975342893301088); +static const Unit3 bearing3(0.156054862928174, -0.156054862928174, + 0.975342893301088); +static const Unit3 bearing4(0.156054862928174, 0.156054862928174, + 0.975342893301088); + +static double depth = 0.512640224719052; +/* ************************************************************************* */ +TEST(SphericalCamera, constructor) { + EXPECT(assert_equal(pose, camera.pose())); +} + +/* ************************************************************************* */ +TEST(SphericalCamera, project) { + // expected from manual calculation in Matlab + EXPECT(assert_equal(camera.project(point1), bearing1)); + EXPECT(assert_equal(camera.project(point2), bearing2)); + EXPECT(assert_equal(camera.project(point3), bearing3)); + EXPECT(assert_equal(camera.project(point4), bearing4)); +} + +/* ************************************************************************* */ +TEST(SphericalCamera, backproject) { + EXPECT(assert_equal(camera.backproject(bearing1, depth), point1)); + EXPECT(assert_equal(camera.backproject(bearing2, depth), point2)); + EXPECT(assert_equal(camera.backproject(bearing3, depth), point3)); + EXPECT(assert_equal(camera.backproject(bearing4, depth), point4)); +} + +/* ************************************************************************* */ +TEST(SphericalCamera, backproject2) { + Point3 origin(0, 0, 0); + Rot3 rot(1., 0., 0., 0., 0., 1., 0., -1., 0.); // a camera1 looking down + Camera camera(Pose3(rot, origin)); + + Point3 actual = camera.backproject(Unit3(0, 0, 1), 1.); + Point3 expected(0., 1., 0.); + pair x = camera.projectSafe(expected); + + EXPECT(assert_equal(expected, actual)); + EXPECT(assert_equal(Unit3(0, 0, 1), x.first)); + EXPECT(x.second); +} + +/* ************************************************************************* */ +static Unit3 project3(const Pose3& pose, const Point3& point) { + return Camera(pose).project(point); +} + +/* ************************************************************************* */ +TEST(SphericalCamera, Dproject) { + Matrix Dpose, Dpoint; + Unit3 result = camera.project(point1, Dpose, Dpoint); + Matrix numerical_pose = numericalDerivative21(project3, pose, point1); + Matrix numerical_point = numericalDerivative22(project3, pose, point1); + EXPECT(assert_equal(bearing1, result)); + EXPECT(assert_equal(numerical_pose, Dpose, 1e-7)); + EXPECT(assert_equal(numerical_point, Dpoint, 1e-7)); +} + +/* ************************************************************************* */ +static Vector2 reprojectionError2(const Pose3& pose, const Point3& point, + const Unit3& measured) { + return Camera(pose).reprojectionError(point, measured); +} + +/* ************************************************************************* */ +TEST(SphericalCamera, reprojectionError) { + Matrix Dpose, Dpoint; + Vector2 result = camera.reprojectionError(point1, bearing1, Dpose, Dpoint); + Matrix numerical_pose = + numericalDerivative31(reprojectionError2, pose, point1, bearing1); + Matrix numerical_point = + numericalDerivative32(reprojectionError2, pose, point1, bearing1); + EXPECT(assert_equal(Vector2(0.0, 0.0), result)); + EXPECT(assert_equal(numerical_pose, Dpose, 1e-7)); + EXPECT(assert_equal(numerical_point, Dpoint, 1e-7)); +} + +/* ************************************************************************* */ +TEST(SphericalCamera, reprojectionError_noisy) { + Matrix Dpose, Dpoint; + Unit3 bearing_noisy = bearing1.retract(Vector2(0.01, 0.05)); + Vector2 result = + camera.reprojectionError(point1, bearing_noisy, Dpose, Dpoint); + Matrix numerical_pose = + numericalDerivative31(reprojectionError2, pose, point1, bearing_noisy); + Matrix numerical_point = + numericalDerivative32(reprojectionError2, pose, point1, bearing_noisy); + EXPECT(assert_equal(Vector2(-0.050282, 0.00833482), result, 1e-5)); + EXPECT(assert_equal(numerical_pose, Dpose, 1e-7)); + EXPECT(assert_equal(numerical_point, Dpoint, 1e-7)); +} + +/* ************************************************************************* */ +// Add a test with more arbitrary rotation +TEST(SphericalCamera, Dproject2) { + static const Pose3 pose1(Rot3::Ypr(0.1, -0.1, 0.4), Point3(0, 0, -10)); + static const Camera camera(pose1); + Matrix Dpose, Dpoint; + camera.project2(point1, Dpose, Dpoint); + Matrix numerical_pose = numericalDerivative21(project3, pose1, point1); + Matrix numerical_point = numericalDerivative22(project3, pose1, point1); + CHECK(assert_equal(numerical_pose, Dpose, 1e-7)); + CHECK(assert_equal(numerical_point, Dpoint, 1e-7)); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/gtsam/geometry/tests/testTriangulation.cpp b/gtsam/geometry/tests/testTriangulation.cpp index 4f71a48da..fb66fb6a2 100644 --- a/gtsam/geometry/tests/testTriangulation.cpp +++ b/gtsam/geometry/tests/testTriangulation.cpp @@ -10,22 +10,24 @@ * -------------------------------------------------------------------------- */ /** - * testTriangulation.cpp - * - * Created on: July 30th, 2013 - * Author: cbeall3 + * @file testTriangulation.cpp + * @brief triangulation utilities + * @date July 30th, 2013 + * @author Chris Beall (cbeall3) + * @author Luca Carlone */ -#include -#include -#include -#include -#include -#include -#include -#include #include - +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include #include @@ -36,8 +38,8 @@ using namespace boost::assign; // Some common constants -static const boost::shared_ptr sharedCal = // - boost::make_shared(1500, 1200, 0, 640, 480); +static const boost::shared_ptr sharedCal = // + boost::make_shared(1500, 1200, 0.1, 640, 480); // Looking along X-axis, 1 meter above ground plane (x-y) static const Rot3 upright = Rot3::Ypr(-M_PI / 2, 0., -M_PI / 2); @@ -57,8 +59,7 @@ Point2 z2 = camera2.project(landmark); //****************************************************************************** // Simple test with a well-behaved two camera situation -TEST( triangulation, twoPoses) { - +TEST(triangulation, twoPoses) { vector poses; Point2Vector measurements; @@ -69,37 +70,149 @@ TEST( triangulation, twoPoses) { // 1. Test simple DLT, perfect in no noise situation bool optimize = false; - boost::optional actual1 = // - triangulatePoint3(poses, sharedCal, measurements, rank_tol, optimize); + boost::optional actual1 = // + triangulatePoint3(poses, sharedCal, measurements, rank_tol, optimize); EXPECT(assert_equal(landmark, *actual1, 1e-7)); // 2. test with optimization on, same answer optimize = true; - boost::optional actual2 = // - triangulatePoint3(poses, sharedCal, measurements, rank_tol, optimize); + boost::optional actual2 = // + triangulatePoint3(poses, sharedCal, measurements, rank_tol, optimize); EXPECT(assert_equal(landmark, *actual2, 1e-7)); - // 3. Add some noise and try again: result should be ~ (4.995, 0.499167, 1.19814) + // 3. Add some noise and try again: result should be ~ (4.995, + // 0.499167, 1.19814) measurements.at(0) += Point2(0.1, 0.5); measurements.at(1) += Point2(-0.2, 0.3); optimize = false; - boost::optional actual3 = // - triangulatePoint3(poses, sharedCal, measurements, rank_tol, optimize); + boost::optional actual3 = // + triangulatePoint3(poses, sharedCal, measurements, rank_tol, optimize); EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19814), *actual3, 1e-4)); // 4. Now with optimization on optimize = true; - boost::optional actual4 = // - triangulatePoint3(poses, sharedCal, measurements, rank_tol, optimize); + boost::optional actual4 = // + triangulatePoint3(poses, sharedCal, measurements, rank_tol, optimize); EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19814), *actual4, 1e-4)); } //****************************************************************************** -// Similar, but now with Bundler calibration -TEST( triangulation, twoPosesBundler) { +// Simple test with a well-behaved two camera situation with Cal3DS2 calibration. +TEST(triangulation, twoPosesCal3DS2) { + static const boost::shared_ptr sharedDistortedCal = // + boost::make_shared(1500, 1200, 0, 640, 480, -.3, 0.1, 0.0001, + -0.0003); - boost::shared_ptr bundlerCal = // - boost::make_shared(1500, 0, 0, 640, 480); + PinholeCamera camera1Distorted(pose1, *sharedDistortedCal); + + PinholeCamera camera2Distorted(pose2, *sharedDistortedCal); + + // 0. Project two landmarks into two cameras and triangulate + Point2 z1Distorted = camera1Distorted.project(landmark); + Point2 z2Distorted = camera2Distorted.project(landmark); + + vector poses; + Point2Vector measurements; + + poses += pose1, pose2; + measurements += z1Distorted, z2Distorted; + + double rank_tol = 1e-9; + + // 1. Test simple DLT, perfect in no noise situation + bool optimize = false; + boost::optional actual1 = // + triangulatePoint3(poses, sharedDistortedCal, measurements, + rank_tol, optimize); + EXPECT(assert_equal(landmark, *actual1, 1e-7)); + + // 2. test with optimization on, same answer + optimize = true; + boost::optional actual2 = // + triangulatePoint3(poses, sharedDistortedCal, measurements, + rank_tol, optimize); + EXPECT(assert_equal(landmark, *actual2, 1e-7)); + + // 3. Add some noise and try again: result should be ~ (4.995, + // 0.499167, 1.19814) + measurements.at(0) += Point2(0.1, 0.5); + measurements.at(1) += Point2(-0.2, 0.3); + optimize = false; + boost::optional actual3 = // + triangulatePoint3(poses, sharedDistortedCal, measurements, + rank_tol, optimize); + EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19814), *actual3, 1e-3)); + + // 4. Now with optimization on + optimize = true; + boost::optional actual4 = // + triangulatePoint3(poses, sharedDistortedCal, measurements, + rank_tol, optimize); + EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19814), *actual4, 1e-3)); +} + +//****************************************************************************** +// Simple test with a well-behaved two camera situation with Fisheye +// calibration. +TEST(triangulation, twoPosesFisheye) { + using Calibration = Cal3Fisheye; + static const boost::shared_ptr sharedDistortedCal = // + boost::make_shared(1500, 1200, .1, 640, 480, -.3, 0.1, + 0.0001, -0.0003); + + PinholeCamera camera1Distorted(pose1, *sharedDistortedCal); + + PinholeCamera camera2Distorted(pose2, *sharedDistortedCal); + + // 0. Project two landmarks into two cameras and triangulate + Point2 z1Distorted = camera1Distorted.project(landmark); + Point2 z2Distorted = camera2Distorted.project(landmark); + + vector poses; + Point2Vector measurements; + + poses += pose1, pose2; + measurements += z1Distorted, z2Distorted; + + double rank_tol = 1e-9; + + // 1. Test simple DLT, perfect in no noise situation + bool optimize = false; + boost::optional actual1 = // + triangulatePoint3(poses, sharedDistortedCal, measurements, + rank_tol, optimize); + EXPECT(assert_equal(landmark, *actual1, 1e-7)); + + // 2. test with optimization on, same answer + optimize = true; + boost::optional actual2 = // + triangulatePoint3(poses, sharedDistortedCal, measurements, + rank_tol, optimize); + EXPECT(assert_equal(landmark, *actual2, 1e-7)); + + // 3. Add some noise and try again: result should be ~ (4.995, + // 0.499167, 1.19814) + measurements.at(0) += Point2(0.1, 0.5); + measurements.at(1) += Point2(-0.2, 0.3); + optimize = false; + boost::optional actual3 = // + triangulatePoint3(poses, sharedDistortedCal, measurements, + rank_tol, optimize); + EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19814), *actual3, 1e-3)); + + // 4. Now with optimization on + optimize = true; + boost::optional actual4 = // + triangulatePoint3(poses, sharedDistortedCal, measurements, + rank_tol, optimize); + EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19814), *actual4, 1e-3)); +} + +//****************************************************************************** +// Similar, but now with Bundler calibration +TEST(triangulation, twoPosesBundler) { + boost::shared_ptr bundlerCal = // + boost::make_shared(1500, 0.1, 0.2, 640, 480); PinholeCamera camera1(pose1, *bundlerCal); PinholeCamera camera2(pose2, *bundlerCal); @@ -116,37 +229,40 @@ TEST( triangulation, twoPosesBundler) { bool optimize = true; double rank_tol = 1e-9; - boost::optional actual = // - triangulatePoint3(poses, bundlerCal, measurements, rank_tol, optimize); + boost::optional actual = // + triangulatePoint3(poses, bundlerCal, measurements, rank_tol, + optimize); EXPECT(assert_equal(landmark, *actual, 1e-7)); // Add some noise and try again measurements.at(0) += Point2(0.1, 0.5); measurements.at(1) += Point2(-0.2, 0.3); - boost::optional actual2 = // - triangulatePoint3(poses, bundlerCal, measurements, rank_tol, optimize); - EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19847), *actual2, 1e-4)); + boost::optional actual2 = // + triangulatePoint3(poses, bundlerCal, measurements, rank_tol, + optimize); + EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19847), *actual2, 1e-3)); } //****************************************************************************** -TEST( triangulation, fourPoses) { +TEST(triangulation, fourPoses) { vector poses; Point2Vector measurements; poses += pose1, pose2; measurements += z1, z2; - boost::optional actual = triangulatePoint3(poses, sharedCal, - measurements); + boost::optional actual = + triangulatePoint3(poses, sharedCal, measurements); EXPECT(assert_equal(landmark, *actual, 1e-2)); - // 2. Add some noise and try again: result should be ~ (4.995, 0.499167, 1.19814) + // 2. Add some noise and try again: result should be ~ (4.995, + // 0.499167, 1.19814) measurements.at(0) += Point2(0.1, 0.5); measurements.at(1) += Point2(-0.2, 0.3); - boost::optional actual2 = // - triangulatePoint3(poses, sharedCal, measurements); + boost::optional actual2 = // + triangulatePoint3(poses, sharedCal, measurements); EXPECT(assert_equal(landmark, *actual2, 1e-2)); // 3. Add a slightly rotated third camera above, again with measurement noise @@ -157,13 +273,13 @@ TEST( triangulation, fourPoses) { poses += pose3; measurements += z3 + Point2(0.1, -0.1); - boost::optional triangulated_3cameras = // - triangulatePoint3(poses, sharedCal, measurements); + boost::optional triangulated_3cameras = // + triangulatePoint3(poses, sharedCal, measurements); EXPECT(assert_equal(landmark, *triangulated_3cameras, 1e-2)); // Again with nonlinear optimization - boost::optional triangulated_3cameras_opt = triangulatePoint3(poses, - sharedCal, measurements, 1e-9, true); + boost::optional triangulated_3cameras_opt = + triangulatePoint3(poses, sharedCal, measurements, 1e-9, true); EXPECT(assert_equal(landmark, *triangulated_3cameras_opt, 1e-2)); // 4. Test failure: Add a 4th camera facing the wrong way @@ -176,13 +292,101 @@ TEST( triangulation, fourPoses) { poses += pose4; measurements += Point2(400, 400); - CHECK_EXCEPTION(triangulatePoint3(poses, sharedCal, measurements), - TriangulationCheiralityException); + CHECK_EXCEPTION(triangulatePoint3(poses, sharedCal, measurements), + TriangulationCheiralityException); #endif } //****************************************************************************** -TEST( triangulation, fourPoses_distinct_Ks) { +TEST(triangulation, threePoses_robustNoiseModel) { + + Pose3 pose3 = pose1 * Pose3(Rot3::Ypr(0.1, 0.2, 0.1), Point3(0.1, -2, -.1)); + PinholeCamera camera3(pose3, *sharedCal); + Point2 z3 = camera3.project(landmark); + + vector poses; + Point2Vector measurements; + poses += pose1, pose2, pose3; + measurements += z1, z2, z3; + + // noise free, so should give exactly the landmark + boost::optional actual = + triangulatePoint3(poses, sharedCal, measurements); + EXPECT(assert_equal(landmark, *actual, 1e-2)); + + // Add outlier + measurements.at(0) += Point2(100, 120); // very large pixel noise! + + // now estimate does not match landmark + boost::optional actual2 = // + triangulatePoint3(poses, sharedCal, measurements); + // DLT is surprisingly robust, but still off (actual error is around 0.26m): + EXPECT( (landmark - *actual2).norm() >= 0.2); + EXPECT( (landmark - *actual2).norm() <= 0.5); + + // Again with nonlinear optimization + boost::optional actual3 = + triangulatePoint3(poses, sharedCal, measurements, 1e-9, true); + // result from nonlinear (but non-robust optimization) is close to DLT and still off + EXPECT(assert_equal(*actual2, *actual3, 0.1)); + + // Again with nonlinear optimization, this time with robust loss + auto model = noiseModel::Robust::Create( + noiseModel::mEstimator::Huber::Create(1.345), noiseModel::Unit::Create(2)); + boost::optional actual4 = triangulatePoint3( + poses, sharedCal, measurements, 1e-9, true, model); + // using the Huber loss we now have a quite small error!! nice! + EXPECT(assert_equal(landmark, *actual4, 0.05)); +} + +//****************************************************************************** +TEST(triangulation, fourPoses_robustNoiseModel) { + + Pose3 pose3 = pose1 * Pose3(Rot3::Ypr(0.1, 0.2, 0.1), Point3(0.1, -2, -.1)); + PinholeCamera camera3(pose3, *sharedCal); + Point2 z3 = camera3.project(landmark); + + vector poses; + Point2Vector measurements; + poses += pose1, pose1, pose2, pose3; // 2 measurements from pose 1 + measurements += z1, z1, z2, z3; + + // noise free, so should give exactly the landmark + boost::optional actual = + triangulatePoint3(poses, sharedCal, measurements); + EXPECT(assert_equal(landmark, *actual, 1e-2)); + + // Add outlier + measurements.at(0) += Point2(100, 120); // very large pixel noise! + // add noise on other measurements: + measurements.at(1) += Point2(0.1, 0.2); // small noise + measurements.at(2) += Point2(0.2, 0.2); + measurements.at(3) += Point2(0.3, 0.1); + + // now estimate does not match landmark + boost::optional actual2 = // + triangulatePoint3(poses, sharedCal, measurements); + // DLT is surprisingly robust, but still off (actual error is around 0.17m): + EXPECT( (landmark - *actual2).norm() >= 0.1); + EXPECT( (landmark - *actual2).norm() <= 0.5); + + // Again with nonlinear optimization + boost::optional actual3 = + triangulatePoint3(poses, sharedCal, measurements, 1e-9, true); + // result from nonlinear (but non-robust optimization) is close to DLT and still off + EXPECT(assert_equal(*actual2, *actual3, 0.1)); + + // Again with nonlinear optimization, this time with robust loss + auto model = noiseModel::Robust::Create( + noiseModel::mEstimator::Huber::Create(1.345), noiseModel::Unit::Create(2)); + boost::optional actual4 = triangulatePoint3( + poses, sharedCal, measurements, 1e-9, true, model); + // using the Huber loss we now have a quite small error!! nice! + EXPECT(assert_equal(landmark, *actual4, 0.05)); +} + +//****************************************************************************** +TEST(triangulation, fourPoses_distinct_Ks) { Cal3_S2 K1(1500, 1200, 0, 640, 480); // create first camera. Looking along X-axis, 1 meter above ground plane (x-y) PinholeCamera camera1(pose1, K1); @@ -195,22 +399,23 @@ TEST( triangulation, fourPoses_distinct_Ks) { Point2 z1 = camera1.project(landmark); Point2 z2 = camera2.project(landmark); - CameraSet > cameras; + CameraSet> cameras; Point2Vector measurements; cameras += camera1, camera2; measurements += z1, z2; - boost::optional actual = // - triangulatePoint3(cameras, measurements); + boost::optional actual = // + triangulatePoint3(cameras, measurements); EXPECT(assert_equal(landmark, *actual, 1e-2)); - // 2. Add some noise and try again: result should be ~ (4.995, 0.499167, 1.19814) + // 2. Add some noise and try again: result should be ~ (4.995, + // 0.499167, 1.19814) measurements.at(0) += Point2(0.1, 0.5); measurements.at(1) += Point2(-0.2, 0.3); - boost::optional actual2 = // - triangulatePoint3(cameras, measurements); + boost::optional actual2 = // + triangulatePoint3(cameras, measurements); EXPECT(assert_equal(landmark, *actual2, 1e-2)); // 3. Add a slightly rotated third camera above, again with measurement noise @@ -222,13 +427,13 @@ TEST( triangulation, fourPoses_distinct_Ks) { cameras += camera3; measurements += z3 + Point2(0.1, -0.1); - boost::optional triangulated_3cameras = // - triangulatePoint3(cameras, measurements); + boost::optional triangulated_3cameras = // + triangulatePoint3(cameras, measurements); EXPECT(assert_equal(landmark, *triangulated_3cameras, 1e-2)); // Again with nonlinear optimization - boost::optional triangulated_3cameras_opt = triangulatePoint3(cameras, - measurements, 1e-9, true); + boost::optional triangulated_3cameras_opt = + triangulatePoint3(cameras, measurements, 1e-9, true); EXPECT(assert_equal(landmark, *triangulated_3cameras_opt, 1e-2)); // 4. Test failure: Add a 4th camera facing the wrong way @@ -241,13 +446,38 @@ TEST( triangulation, fourPoses_distinct_Ks) { cameras += camera4; measurements += Point2(400, 400); - CHECK_EXCEPTION(triangulatePoint3(cameras, measurements), - TriangulationCheiralityException); + CHECK_EXCEPTION(triangulatePoint3(cameras, measurements), + TriangulationCheiralityException); #endif } //****************************************************************************** -TEST( triangulation, outliersAndFarLandmarks) { +TEST(triangulation, fourPoses_distinct_Ks_distortion) { + Cal3DS2 K1(1500, 1200, 0, 640, 480, -.3, 0.1, 0.0001, -0.0003); + // create first camera. Looking along X-axis, 1 meter above ground plane (x-y) + PinholeCamera camera1(pose1, K1); + + // create second camera 1 meter to the right of first camera + Cal3DS2 K2(1600, 1300, 0, 650, 440, -.2, 0.05, 0.0002, -0.0001); + PinholeCamera camera2(pose2, K2); + + // 1. Project two landmarks into two cameras and triangulate + Point2 z1 = camera1.project(landmark); + Point2 z2 = camera2.project(landmark); + + CameraSet> cameras; + Point2Vector measurements; + + cameras += camera1, camera2; + measurements += z1, z2; + + boost::optional actual = // + triangulatePoint3(cameras, measurements); + EXPECT(assert_equal(landmark, *actual, 1e-2)); +} + +//****************************************************************************** +TEST(triangulation, outliersAndFarLandmarks) { Cal3_S2 K1(1500, 1200, 0, 640, 480); // create first camera. Looking along X-axis, 1 meter above ground plane (x-y) PinholeCamera camera1(pose1, K1); @@ -260,24 +490,29 @@ TEST( triangulation, outliersAndFarLandmarks) { Point2 z1 = camera1.project(landmark); Point2 z2 = camera2.project(landmark); - CameraSet > cameras; + CameraSet> cameras; Point2Vector measurements; cameras += camera1, camera2; measurements += z1, z2; - double landmarkDistanceThreshold = 10; // landmark is closer than that - TriangulationParameters params(1.0, false, landmarkDistanceThreshold); // all default except landmarkDistanceThreshold - TriangulationResult actual = triangulateSafe(cameras,measurements,params); + double landmarkDistanceThreshold = 10; // landmark is closer than that + TriangulationParameters params( + 1.0, false, landmarkDistanceThreshold); // all default except + // landmarkDistanceThreshold + TriangulationResult actual = triangulateSafe(cameras, measurements, params); EXPECT(assert_equal(landmark, *actual, 1e-2)); EXPECT(actual.valid()); - landmarkDistanceThreshold = 4; // landmark is farther than that - TriangulationParameters params2(1.0, false, landmarkDistanceThreshold); // all default except landmarkDistanceThreshold - actual = triangulateSafe(cameras,measurements,params2); + landmarkDistanceThreshold = 4; // landmark is farther than that + TriangulationParameters params2( + 1.0, false, landmarkDistanceThreshold); // all default except + // landmarkDistanceThreshold + actual = triangulateSafe(cameras, measurements, params2); EXPECT(actual.farPoint()); - // 3. Add a slightly rotated third camera above with a wrong measurement (OUTLIER) + // 3. Add a slightly rotated third camera above with a wrong measurement + // (OUTLIER) Pose3 pose3 = pose1 * Pose3(Rot3::Ypr(0.1, 0.2, 0.1), Point3(0.1, -2, -.1)); Cal3_S2 K3(700, 500, 0, 640, 480); PinholeCamera camera3(pose3, K3); @@ -286,21 +521,23 @@ TEST( triangulation, outliersAndFarLandmarks) { cameras += camera3; measurements += z3 + Point2(10, -10); - landmarkDistanceThreshold = 10; // landmark is closer than that - double outlierThreshold = 100; // loose, the outlier is going to pass - TriangulationParameters params3(1.0, false, landmarkDistanceThreshold,outlierThreshold); - actual = triangulateSafe(cameras,measurements,params3); + landmarkDistanceThreshold = 10; // landmark is closer than that + double outlierThreshold = 100; // loose, the outlier is going to pass + TriangulationParameters params3(1.0, false, landmarkDistanceThreshold, + outlierThreshold); + actual = triangulateSafe(cameras, measurements, params3); EXPECT(actual.valid()); // now set stricter threshold for outlier rejection - outlierThreshold = 5; // tighter, the outlier is not going to pass - TriangulationParameters params4(1.0, false, landmarkDistanceThreshold,outlierThreshold); - actual = triangulateSafe(cameras,measurements,params4); + outlierThreshold = 5; // tighter, the outlier is not going to pass + TriangulationParameters params4(1.0, false, landmarkDistanceThreshold, + outlierThreshold); + actual = triangulateSafe(cameras, measurements, params4); EXPECT(actual.outlier()); } //****************************************************************************** -TEST( triangulation, twoIdenticalPoses) { +TEST(triangulation, twoIdenticalPoses) { // create first camera. Looking along X-axis, 1 meter above ground plane (x-y) PinholeCamera camera1(pose1, *sharedCal); @@ -313,12 +550,12 @@ TEST( triangulation, twoIdenticalPoses) { poses += pose1, pose1; measurements += z1, z1; - CHECK_EXCEPTION(triangulatePoint3(poses, sharedCal, measurements), - TriangulationUnderconstrainedException); + CHECK_EXCEPTION(triangulatePoint3(poses, sharedCal, measurements), + TriangulationUnderconstrainedException); } //****************************************************************************** -TEST( triangulation, onePose) { +TEST(triangulation, onePose) { // we expect this test to fail with a TriangulationUnderconstrainedException // because there's only one camera observation @@ -326,28 +563,26 @@ TEST( triangulation, onePose) { Point2Vector measurements; poses += Pose3(); - measurements += Point2(0,0); + measurements += Point2(0, 0); - CHECK_EXCEPTION(triangulatePoint3(poses, sharedCal, measurements), - TriangulationUnderconstrainedException); + CHECK_EXCEPTION(triangulatePoint3(poses, sharedCal, measurements), + TriangulationUnderconstrainedException); } //****************************************************************************** -TEST( triangulation, StereotriangulateNonlinear ) { - - auto stereoK = boost::make_shared(1733.75, 1733.75, 0, 689.645, 508.835, 0.0699612); +TEST(triangulation, StereotriangulateNonlinear) { + auto stereoK = boost::make_shared(1733.75, 1733.75, 0, 689.645, + 508.835, 0.0699612); // two camera poses m1, m2 Matrix4 m1, m2; - m1 << 0.796888717, 0.603404026, -0.0295271487, 46.6673779, - 0.592783835, -0.77156583, 0.230856632, 66.2186159, - 0.116517574, -0.201470143, -0.9725393, -4.28382528, - 0, 0, 0, 1; + m1 << 0.796888717, 0.603404026, -0.0295271487, 46.6673779, 0.592783835, + -0.77156583, 0.230856632, 66.2186159, 0.116517574, -0.201470143, + -0.9725393, -4.28382528, 0, 0, 0, 1; - m2 << -0.955959025, -0.29288915, -0.0189328569, 45.7169799, - -0.29277519, 0.947083213, 0.131587097, 65.843136, - -0.0206094928, 0.131334858, -0.991123524, -4.3525033, - 0, 0, 0, 1; + m2 << -0.955959025, -0.29288915, -0.0189328569, 45.7169799, -0.29277519, + 0.947083213, 0.131587097, 65.843136, -0.0206094928, 0.131334858, + -0.991123524, -4.3525033, 0, 0, 0, 1; typedef CameraSet Cameras; Cameras cameras; @@ -358,18 +593,19 @@ TEST( triangulation, StereotriangulateNonlinear ) { measurements += StereoPoint2(226.936, 175.212, 424.469); measurements += StereoPoint2(339.571, 285.547, 669.973); - Point3 initial = Point3(46.0536958, 66.4621179, -6.56285929); // error: 96.5715555191 + Point3 initial = + Point3(46.0536958, 66.4621179, -6.56285929); // error: 96.5715555191 - Point3 actual = triangulateNonlinear(cameras, measurements, initial); + Point3 actual = triangulateNonlinear(cameras, measurements, initial); - Point3 expected(46.0484569, 66.4710686, -6.55046613); // error: 0.763510644187 + Point3 expected(46.0484569, 66.4710686, + -6.55046613); // error: 0.763510644187 EXPECT(assert_equal(expected, actual, 1e-4)); - // regular stereo factor comparison - expect very similar result as above { - typedef GenericStereoFactor StereoFactor; + typedef GenericStereoFactor StereoFactor; Values values; values.insert(Symbol('x', 1), Pose3(m1)); @@ -378,17 +614,19 @@ TEST( triangulation, StereotriangulateNonlinear ) { NonlinearFactorGraph graph; static SharedNoiseModel unit(noiseModel::Unit::Create(3)); - graph.emplace_shared(measurements[0], unit, Symbol('x',1), Symbol('l',1), stereoK); - graph.emplace_shared(measurements[1], unit, Symbol('x',2), Symbol('l',1), stereoK); + graph.emplace_shared(measurements[0], unit, Symbol('x', 1), + Symbol('l', 1), stereoK); + graph.emplace_shared(measurements[1], unit, Symbol('x', 2), + Symbol('l', 1), stereoK); const SharedDiagonal posePrior = noiseModel::Isotropic::Sigma(6, 1e-9); - graph.addPrior(Symbol('x',1), Pose3(m1), posePrior); - graph.addPrior(Symbol('x',2), Pose3(m2), posePrior); + graph.addPrior(Symbol('x', 1), Pose3(m1), posePrior); + graph.addPrior(Symbol('x', 2), Pose3(m2), posePrior); LevenbergMarquardtOptimizer optimizer(graph, values); Values result = optimizer.optimize(); - EXPECT(assert_equal(expected, result.at(Symbol('l',1)), 1e-4)); + EXPECT(assert_equal(expected, result.at(Symbol('l', 1)), 1e-4)); } // use Triangulation Factor directly - expect same result as above @@ -399,13 +637,15 @@ TEST( triangulation, StereotriangulateNonlinear ) { NonlinearFactorGraph graph; static SharedNoiseModel unit(noiseModel::Unit::Create(3)); - graph.emplace_shared >(cameras[0], measurements[0], unit, Symbol('l',1)); - graph.emplace_shared >(cameras[1], measurements[1], unit, Symbol('l',1)); + graph.emplace_shared>( + cameras[0], measurements[0], unit, Symbol('l', 1)); + graph.emplace_shared>( + cameras[1], measurements[1], unit, Symbol('l', 1)); LevenbergMarquardtOptimizer optimizer(graph, values); Values result = optimizer.optimize(); - EXPECT(assert_equal(expected, result.at(Symbol('l',1)), 1e-4)); + EXPECT(assert_equal(expected, result.at(Symbol('l', 1)), 1e-4)); } // use ExpressionFactor - expect same result as above @@ -416,11 +656,13 @@ TEST( triangulation, StereotriangulateNonlinear ) { NonlinearFactorGraph graph; static SharedNoiseModel unit(noiseModel::Unit::Create(3)); - Expression point_(Symbol('l',1)); + Expression point_(Symbol('l', 1)); Expression camera0_(cameras[0]); Expression camera1_(cameras[1]); - Expression project0_(camera0_, &StereoCamera::project2, point_); - Expression project1_(camera1_, &StereoCamera::project2, point_); + Expression project0_(camera0_, &StereoCamera::project2, + point_); + Expression project1_(camera1_, &StereoCamera::project2, + point_); graph.addExpressionFactor(unit, measurements[0], project0_); graph.addExpressionFactor(unit, measurements[1], project1_); @@ -428,10 +670,172 @@ TEST( triangulation, StereotriangulateNonlinear ) { LevenbergMarquardtOptimizer optimizer(graph, values); Values result = optimizer.optimize(); - EXPECT(assert_equal(expected, result.at(Symbol('l',1)), 1e-4)); + EXPECT(assert_equal(expected, result.at(Symbol('l', 1)), 1e-4)); } } +//****************************************************************************** +// Simple test with a well-behaved two camera situation +TEST(triangulation, twoPoses_sphericalCamera) { + vector poses; + std::vector measurements; + + // Project landmark into two cameras and triangulate + SphericalCamera cam1(pose1); + SphericalCamera cam2(pose2); + Unit3 u1 = cam1.project(landmark); + Unit3 u2 = cam2.project(landmark); + + poses += pose1, pose2; + measurements += u1, u2; + + CameraSet cameras; + cameras.push_back(cam1); + cameras.push_back(cam2); + + double rank_tol = 1e-9; + + // 1. Test linear triangulation via DLT + auto projection_matrices = projectionMatricesFromCameras(cameras); + Point3 point = triangulateDLT(projection_matrices, measurements, rank_tol); + EXPECT(assert_equal(landmark, point, 1e-7)); + + // 2. Test nonlinear triangulation + point = triangulateNonlinear(cameras, measurements, point); + EXPECT(assert_equal(landmark, point, 1e-7)); + + // 3. Test simple DLT, now within triangulatePoint3 + bool optimize = false; + boost::optional actual1 = // + triangulatePoint3(cameras, measurements, rank_tol, + optimize); + EXPECT(assert_equal(landmark, *actual1, 1e-7)); + + // 4. test with optimization on, same answer + optimize = true; + boost::optional actual2 = // + triangulatePoint3(cameras, measurements, rank_tol, + optimize); + EXPECT(assert_equal(landmark, *actual2, 1e-7)); + + // 5. Add some noise and try again: result should be ~ (4.995, + // 0.499167, 1.19814) + measurements.at(0) = + u1.retract(Vector2(0.01, 0.05)); // note: perturbation smaller for Unit3 + measurements.at(1) = u2.retract(Vector2(-0.02, 0.03)); + optimize = false; + boost::optional actual3 = // + triangulatePoint3(cameras, measurements, rank_tol, + optimize); + EXPECT(assert_equal(Point3(5.9432, 0.654319, 1.48192), *actual3, 1e-3)); + + // 6. Now with optimization on + optimize = true; + boost::optional actual4 = // + triangulatePoint3(cameras, measurements, rank_tol, + optimize); + EXPECT(assert_equal(Point3(5.9432, 0.654334, 1.48192), *actual4, 1e-3)); +} + +//****************************************************************************** +TEST(triangulation, twoPoses_sphericalCamera_extremeFOV) { + vector poses; + std::vector measurements; + + // Project landmark into two cameras and triangulate + Pose3 poseA = Pose3( + Rot3::Ypr(-M_PI / 2, 0., -M_PI / 2), + Point3(0.0, 0.0, 0.0)); // with z pointing along x axis of global frame + Pose3 poseB = Pose3(Rot3::Ypr(-M_PI / 2, 0., -M_PI / 2), + Point3(2.0, 0.0, 0.0)); // 2m in front of poseA + Point3 landmarkL( + 1.0, -1.0, + 0.0); // 1m to the right of both cameras, in front of poseA, behind poseB + SphericalCamera cam1(poseA); + SphericalCamera cam2(poseB); + Unit3 u1 = cam1.project(landmarkL); + Unit3 u2 = cam2.project(landmarkL); + + EXPECT(assert_equal(Unit3(Point3(1.0, 0.0, 1.0)), u1, + 1e-7)); // in front and to the right of PoseA + EXPECT(assert_equal(Unit3(Point3(1.0, 0.0, -1.0)), u2, + 1e-7)); // behind and to the right of PoseB + + poses += pose1, pose2; + measurements += u1, u2; + + CameraSet cameras; + cameras.push_back(cam1); + cameras.push_back(cam2); + + double rank_tol = 1e-9; + + { + // 1. Test simple DLT, when 1 point is behind spherical camera + bool optimize = false; +#ifdef GTSAM_THROW_CHEIRALITY_EXCEPTION + CHECK_EXCEPTION(triangulatePoint3(cameras, measurements, + rank_tol, optimize), + TriangulationCheiralityException); +#else // otherwise project should not throw the exception + boost::optional actual1 = // + triangulatePoint3(cameras, measurements, rank_tol, + optimize); + EXPECT(assert_equal(landmarkL, *actual1, 1e-7)); +#endif + } + { + // 2. test with optimization on, same answer + bool optimize = true; +#ifdef GTSAM_THROW_CHEIRALITY_EXCEPTION + CHECK_EXCEPTION(triangulatePoint3(cameras, measurements, + rank_tol, optimize), + TriangulationCheiralityException); +#else // otherwise project should not throw the exception + boost::optional actual1 = // + triangulatePoint3(cameras, measurements, rank_tol, + optimize); + EXPECT(assert_equal(landmarkL, *actual1, 1e-7)); +#endif + } +} + +//****************************************************************************** +TEST(triangulation, reprojectionError_cameraComparison) { + Pose3 poseA = Pose3( + Rot3::Ypr(-M_PI / 2, 0., -M_PI / 2), + Point3(0.0, 0.0, 0.0)); // with z pointing along x axis of global frame + Point3 landmarkL(5.0, 0.0, 0.0); // 1m in front of poseA + SphericalCamera sphericalCamera(poseA); + Unit3 u = sphericalCamera.project(landmarkL); + + static Cal3_S2::shared_ptr sharedK(new Cal3_S2(60, 640, 480)); + PinholePose pinholeCamera(poseA, sharedK); + Vector2 px = pinholeCamera.project(landmarkL); + + // add perturbation and compare error in both cameras + Vector2 px_noise(1.0, 2.0); // px perturbation vertically and horizontally + Vector2 measured_px = px + px_noise; + Vector2 measured_px_calibrated = sharedK->calibrate(measured_px); + Unit3 measured_u = + Unit3(measured_px_calibrated[0], measured_px_calibrated[1], 1.0); + Unit3 expected_measured_u = + Unit3(px_noise[0] / sharedK->fx(), px_noise[1] / sharedK->fy(), 1.0); + EXPECT(assert_equal(expected_measured_u, measured_u, 1e-7)); + + Vector2 actualErrorPinhole = + pinholeCamera.reprojectionError(landmarkL, measured_px); + Vector2 expectedErrorPinhole = Vector2(-px_noise[0], -px_noise[1]); + EXPECT(assert_equal(expectedErrorPinhole, actualErrorPinhole, + 1e-7)); //- sign due to definition of error + + Vector2 actualErrorSpherical = + sphericalCamera.reprojectionError(landmarkL, measured_u); + // expectedError: not easy to calculate, since it involves the unit3 basis + Vector2 expectedErrorSpherical(-0.00360842, 0.00180419); + EXPECT(assert_equal(expectedErrorSpherical, actualErrorSpherical, 1e-7)); +} + //****************************************************************************** int main() { TestResult tr; diff --git a/gtsam/geometry/triangulation.cpp b/gtsam/geometry/triangulation.cpp index a5d2e04cd..026afef24 100644 --- a/gtsam/geometry/triangulation.cpp +++ b/gtsam/geometry/triangulation.cpp @@ -53,15 +53,57 @@ Vector4 triangulateHomogeneousDLT( return v; } -Point3 triangulateDLT(const std::vector>& projection_matrices, +Vector4 triangulateHomogeneousDLT( + const std::vector>& projection_matrices, + const std::vector& measurements, double rank_tol) { + + // number of cameras + size_t m = projection_matrices.size(); + + // Allocate DLT matrix + Matrix A = Matrix::Zero(m * 2, 4); + + for (size_t i = 0; i < m; i++) { + size_t row = i * 2; + const Matrix34& projection = projection_matrices.at(i); + const Point3& p = measurements.at(i).point3(); // to get access to x,y,z of the bearing vector + + // build system of equations + A.row(row) = p.x() * projection.row(2) - p.z() * projection.row(0); + A.row(row + 1) = p.y() * projection.row(2) - p.z() * projection.row(1); + } + int rank; + double error; + Vector v; + boost::tie(rank, error, v) = DLT(A, rank_tol); + + if (rank < 3) + throw(TriangulationUnderconstrainedException()); + + return v; +} + +Point3 triangulateDLT( + const std::vector>& projection_matrices, const Point2Vector& measurements, double rank_tol) { - Vector4 v = triangulateHomogeneousDLT(projection_matrices, measurements, rank_tol); - + Vector4 v = triangulateHomogeneousDLT(projection_matrices, measurements, + rank_tol); // Create 3D point from homogeneous coordinates return Point3(v.head<3>() / v[3]); } +Point3 triangulateDLT( + const std::vector>& projection_matrices, + const std::vector& measurements, double rank_tol) { + + // contrary to previous triangulateDLT, this is now taking Unit3 inputs + Vector4 v = triangulateHomogeneousDLT(projection_matrices, measurements, + rank_tol); + // Create 3D point from homogeneous coordinates + return Point3(v.head<3>() / v[3]); +} + /// /** * Optimize for triangulation @@ -71,7 +113,7 @@ Point3 triangulateDLT(const std::vector #include #include +#include #include #include #include @@ -59,6 +61,18 @@ GTSAM_EXPORT Vector4 triangulateHomogeneousDLT( const std::vector>& projection_matrices, const Point2Vector& measurements, double rank_tol = 1e-9); +/** + * Same math as Hartley and Zisserman, 2nd Ed., page 312, but with unit-norm bearing vectors + * (contrarily to pinhole projection, the z entry is not assumed to be 1 as in Hartley and Zisserman) + * @param projection_matrices Projection matrices (K*P^-1) + * @param measurements Unit3 bearing measurements + * @param rank_tol SVD rank tolerance + * @return Triangulated point, in homogeneous coordinates + */ +GTSAM_EXPORT Vector4 triangulateHomogeneousDLT( + const std::vector>& projection_matrices, + const std::vector& measurements, double rank_tol = 1e-9); + /** * DLT triangulation: See Hartley and Zisserman, 2nd Ed., page 312 * @param projection_matrices Projection matrices (K*P^-1) @@ -71,6 +85,14 @@ GTSAM_EXPORT Point3 triangulateDLT( const Point2Vector& measurements, double rank_tol = 1e-9); +/** + * overload of previous function to work with Unit3 (projected to canonical camera) + */ +GTSAM_EXPORT Point3 triangulateDLT( + const std::vector>& projection_matrices, + const std::vector& measurements, + double rank_tol = 1e-9); + /** * Create a factor graph with projection factors from poses and one calibration * @param poses Camera poses @@ -84,18 +106,18 @@ template std::pair triangulationGraph( const std::vector& poses, boost::shared_ptr sharedCal, const Point2Vector& measurements, Key landmarkKey, - const Point3& initialEstimate) { + const Point3& initialEstimate, + const SharedNoiseModel& model = nullptr) { Values values; values.insert(landmarkKey, initialEstimate); // Initial landmark value NonlinearFactorGraph graph; static SharedNoiseModel unit2(noiseModel::Unit::Create(2)); - static SharedNoiseModel prior_model(noiseModel::Isotropic::Sigma(6, 1e-6)); for (size_t i = 0; i < measurements.size(); i++) { const Pose3& pose_i = poses[i]; typedef PinholePose Camera; Camera camera_i(pose_i, sharedCal); graph.emplace_shared > // - (camera_i, measurements[i], unit2, landmarkKey); + (camera_i, measurements[i], model? model : unit2, landmarkKey); } return std::make_pair(graph, values); } @@ -113,7 +135,8 @@ template std::pair triangulationGraph( const CameraSet& cameras, const typename CAMERA::MeasurementVector& measurements, Key landmarkKey, - const Point3& initialEstimate) { + const Point3& initialEstimate, + const SharedNoiseModel& model = nullptr) { Values values; values.insert(landmarkKey, initialEstimate); // Initial landmark value NonlinearFactorGraph graph; @@ -122,7 +145,7 @@ std::pair triangulationGraph( for (size_t i = 0; i < measurements.size(); i++) { const CAMERA& camera_i = cameras[i]; graph.emplace_shared > // - (camera_i, measurements[i], unit, landmarkKey); + (camera_i, measurements[i], model? model : unit, landmarkKey); } return std::make_pair(graph, values); } @@ -148,13 +171,14 @@ GTSAM_EXPORT Point3 optimize(const NonlinearFactorGraph& graph, template Point3 triangulateNonlinear(const std::vector& poses, boost::shared_ptr sharedCal, - const Point2Vector& measurements, const Point3& initialEstimate) { + const Point2Vector& measurements, const Point3& initialEstimate, + const SharedNoiseModel& model = nullptr) { // Create a factor graph and initial values Values values; NonlinearFactorGraph graph; boost::tie(graph, values) = triangulationGraph // - (poses, sharedCal, measurements, Symbol('p', 0), initialEstimate); + (poses, sharedCal, measurements, Symbol('p', 0), initialEstimate, model); return optimize(graph, values, Symbol('p', 0)); } @@ -169,37 +193,142 @@ Point3 triangulateNonlinear(const std::vector& poses, template Point3 triangulateNonlinear( const CameraSet& cameras, - const typename CAMERA::MeasurementVector& measurements, const Point3& initialEstimate) { + const typename CAMERA::MeasurementVector& measurements, const Point3& initialEstimate, + const SharedNoiseModel& model = nullptr) { // Create a factor graph and initial values Values values; NonlinearFactorGraph graph; boost::tie(graph, values) = triangulationGraph // - (cameras, measurements, Symbol('p', 0), initialEstimate); + (cameras, measurements, Symbol('p', 0), initialEstimate, model); return optimize(graph, values, Symbol('p', 0)); } -/** - * Create a 3*4 camera projection matrix from calibration and pose. - * Functor for partial application on calibration - * @param pose The camera pose - * @param cal The calibration - * @return Returns a Matrix34 - */ +template +std::vector> +projectionMatricesFromCameras(const CameraSet &cameras) { + std::vector> projection_matrices; + for (const CAMERA &camera: cameras) { + projection_matrices.push_back(camera.cameraProjectionMatrix()); + } + return projection_matrices; +} + +// overload, assuming pinholePose template -struct CameraProjectionMatrix { - CameraProjectionMatrix(const CALIBRATION& calibration) : - K_(calibration.K()) { +std::vector> projectionMatricesFromPoses( + const std::vector &poses, boost::shared_ptr sharedCal) { + std::vector> projection_matrices; + for (size_t i = 0; i < poses.size(); i++) { + PinholePose camera(poses.at(i), sharedCal); + projection_matrices.push_back(camera.cameraProjectionMatrix()); } - Matrix34 operator()(const Pose3& pose) const { - return K_ * (pose.inverse().matrix()).block<3, 4>(0, 0); + return projection_matrices; +} + +/** Create a pinhole calibration from a different Cal3 object, removing + * distortion. + * + * @tparam CALIBRATION Original calibration object. + * @param cal Input calibration object. + * @return Cal3_S2 with only the pinhole elements of cal. + */ +template +Cal3_S2 createPinholeCalibration(const CALIBRATION& cal) { + const auto& K = cal.K(); + return Cal3_S2(K(0, 0), K(1, 1), K(0, 1), K(0, 2), K(1, 2)); +} + +/** Internal undistortMeasurement to be used by undistortMeasurement and + * undistortMeasurements */ +template +MEASUREMENT undistortMeasurementInternal( + const CALIBRATION& cal, const MEASUREMENT& measurement, + boost::optional pinholeCal = boost::none) { + if (!pinholeCal) { + pinholeCal = createPinholeCalibration(cal); } -private: - const Matrix3 K_; -public: - GTSAM_MAKE_ALIGNED_OPERATOR_NEW -}; + return pinholeCal->uncalibrate(cal.calibrate(measurement)); +} + +/** Remove distortion for measurements so as if the measurements came from a + * pinhole camera. + * + * Removes distortion but maintains the K matrix of the initial cal. Operates by + * calibrating using full calibration and uncalibrating with only the pinhole + * component of the calibration. + * @tparam CALIBRATION Calibration type to use. + * @param cal Calibration with which measurements were taken. + * @param measurements Vector of measurements to undistort. + * @return measurements with the effect of the distortion of sharedCal removed. + */ +template +Point2Vector undistortMeasurements(const CALIBRATION& cal, + const Point2Vector& measurements) { + Cal3_S2 pinholeCalibration = createPinholeCalibration(cal); + Point2Vector undistortedMeasurements; + // Calibrate with cal and uncalibrate with pinhole version of cal so that + // measurements are undistorted. + std::transform(measurements.begin(), measurements.end(), + std::back_inserter(undistortedMeasurements), + [&cal, &pinholeCalibration](const Point2& measurement) { + return undistortMeasurementInternal( + cal, measurement, pinholeCalibration); + }); + return undistortedMeasurements; +} + +/** Specialization for Cal3_S2 as it doesn't need to be undistorted. */ +template <> +inline Point2Vector undistortMeasurements(const Cal3_S2& cal, + const Point2Vector& measurements) { + return measurements; +} + +/** Remove distortion for measurements so as if the measurements came from a + * pinhole camera. + * + * Removes distortion but maintains the K matrix of the initial calibrations. + * Operates by calibrating using full calibration and uncalibrating with only + * the pinhole component of the calibration. + * @tparam CAMERA Camera type to use. + * @param cameras Cameras corresponding to each measurement. + * @param measurements Vector of measurements to undistort. + * @return measurements with the effect of the distortion of the camera removed. + */ +template +typename CAMERA::MeasurementVector undistortMeasurements( + const CameraSet& cameras, + const typename CAMERA::MeasurementVector& measurements) { + const size_t num_meas = cameras.size(); + assert(num_meas == measurements.size()); + typename CAMERA::MeasurementVector undistortedMeasurements(num_meas); + for (size_t ii = 0; ii < num_meas; ++ii) { + // Calibrate with cal and uncalibrate with pinhole version of cal so that + // measurements are undistorted. + undistortedMeasurements[ii] = + undistortMeasurementInternal( + cameras[ii].calibration(), measurements[ii]); + } + return undistortedMeasurements; +} + +/** Specialize for Cal3_S2 to do nothing. */ +template > +inline PinholeCamera::MeasurementVector undistortMeasurements( + const CameraSet>& cameras, + const PinholeCamera::MeasurementVector& measurements) { + return measurements; +} + +/** Specialize for SphericalCamera to do nothing. */ +template +inline SphericalCamera::MeasurementVector undistortMeasurements( + const CameraSet& cameras, + const SphericalCamera::MeasurementVector& measurements) { + return measurements; +} /** * Function to triangulate 3D landmark point from an arbitrary number @@ -217,25 +346,28 @@ template Point3 triangulatePoint3(const std::vector& poses, boost::shared_ptr sharedCal, const Point2Vector& measurements, double rank_tol = 1e-9, - bool optimize = false) { + bool optimize = false, + const SharedNoiseModel& model = nullptr) { assert(poses.size() == measurements.size()); if (poses.size() < 2) throw(TriangulationUnderconstrainedException()); // construct projection matrices from poses & calibration - std::vector> projection_matrices; - CameraProjectionMatrix createP(*sharedCal); // partially apply - for(const Pose3& pose: poses) - projection_matrices.push_back(createP(pose)); + auto projection_matrices = projectionMatricesFromPoses(poses, sharedCal); + + // Undistort the measurements, leaving only the pinhole elements in effect. + auto undistortedMeasurements = + undistortMeasurements(*sharedCal, measurements); // Triangulate linearly - Point3 point = triangulateDLT(projection_matrices, measurements, rank_tol); + Point3 point = + triangulateDLT(projection_matrices, undistortedMeasurements, rank_tol); // Then refine using non-linear optimization if (optimize) point = triangulateNonlinear // - (poses, sharedCal, measurements, point); + (poses, sharedCal, measurements, point, model); #ifdef GTSAM_THROW_CHEIRALITY_EXCEPTION // verify that the triangulated point lies in front of all cameras @@ -265,7 +397,8 @@ template Point3 triangulatePoint3( const CameraSet& cameras, const typename CAMERA::MeasurementVector& measurements, double rank_tol = 1e-9, - bool optimize = false) { + bool optimize = false, + const SharedNoiseModel& model = nullptr) { size_t m = cameras.size(); assert(measurements.size() == m); @@ -274,16 +407,18 @@ Point3 triangulatePoint3( throw(TriangulationUnderconstrainedException()); // construct projection matrices from poses & calibration - std::vector> projection_matrices; - for(const CAMERA& camera: cameras) - projection_matrices.push_back( - CameraProjectionMatrix(camera.calibration())( - camera.pose())); - Point3 point = triangulateDLT(projection_matrices, measurements, rank_tol); + auto projection_matrices = projectionMatricesFromCameras(cameras); + + // Undistort the measurements, leaving only the pinhole elements in effect. + auto undistortedMeasurements = + undistortMeasurements(cameras, measurements); + + Point3 point = + triangulateDLT(projection_matrices, undistortedMeasurements, rank_tol); // The n refine using non-linear optimization if (optimize) - point = triangulateNonlinear(cameras, measurements, point); + point = triangulateNonlinear(cameras, measurements, point, model); #ifdef GTSAM_THROW_CHEIRALITY_EXCEPTION // verify that the triangulated point lies in front of all cameras @@ -302,9 +437,10 @@ template Point3 triangulatePoint3( const CameraSet >& cameras, const Point2Vector& measurements, double rank_tol = 1e-9, - bool optimize = false) { + bool optimize = false, + const SharedNoiseModel& model = nullptr) { return triangulatePoint3 > // - (cameras, measurements, rank_tol, optimize); + (cameras, measurements, rank_tol, optimize, model); } struct GTSAM_EXPORT TriangulationParameters { @@ -326,20 +462,25 @@ struct GTSAM_EXPORT TriangulationParameters { */ double dynamicOutlierRejectionThreshold; + SharedNoiseModel noiseModel; ///< used in the nonlinear triangulation + /** * Constructor * @param rankTol tolerance used to check if point triangulation is degenerate * @param enableEPI if true refine triangulation with embedded LM iterations * @param landmarkDistanceThreshold flag as degenerate if point further than this * @param dynamicOutlierRejectionThreshold or if average error larger than this + * @param noiseModel noise model to use during nonlinear triangulation * */ TriangulationParameters(const double _rankTolerance = 1.0, const bool _enableEPI = false, double _landmarkDistanceThreshold = -1, - double _dynamicOutlierRejectionThreshold = -1) : + double _dynamicOutlierRejectionThreshold = -1, + const SharedNoiseModel& _noiseModel = nullptr) : rankTolerance(_rankTolerance), enableEPI(_enableEPI), // landmarkDistanceThreshold(_landmarkDistanceThreshold), // - dynamicOutlierRejectionThreshold(_dynamicOutlierRejectionThreshold) { + dynamicOutlierRejectionThreshold(_dynamicOutlierRejectionThreshold), + noiseModel(_noiseModel){ } // stream to output @@ -351,6 +492,7 @@ struct GTSAM_EXPORT TriangulationParameters { << std::endl; os << "dynamicOutlierRejectionThreshold = " << p.dynamicOutlierRejectionThreshold << std::endl; + os << "noise model" << std::endl; return os; } @@ -453,8 +595,9 @@ TriangulationResult triangulateSafe(const CameraSet& cameras, else // We triangulate the 3D position of the landmark try { - Point3 point = triangulatePoint3(cameras, measured, - params.rankTolerance, params.enableEPI); + Point3 point = + triangulatePoint3(cameras, measured, params.rankTolerance, + params.enableEPI, params.noiseModel); // Check landmark distance and re-projection errors to avoid outliers size_t i = 0; @@ -474,8 +617,8 @@ TriangulationResult triangulateSafe(const CameraSet& cameras, #endif // Check reprojection error if (params.dynamicOutlierRejectionThreshold > 0) { - const Point2& zi = measured.at(i); - Point2 reprojectionError(camera.project(point) - zi); + const typename CAMERA::Measurement& zi = measured.at(i); + Point2 reprojectionError = camera.reprojectionError(point, zi); maxReprojError = std::max(maxReprojError, reprojectionError.norm()); } i += 1; @@ -503,6 +646,6 @@ using CameraSetCal3Bundler = CameraSet>; using CameraSetCal3_S2 = CameraSet>; using CameraSetCal3Fisheye = CameraSet>; using CameraSetCal3Unified = CameraSet>; - +using CameraSetSpherical = CameraSet; } // \namespace gtsam diff --git a/gtsam/gtsam.i b/gtsam/gtsam.i index 67c3278a3..d4e959c3d 100644 --- a/gtsam/gtsam.i +++ b/gtsam/gtsam.i @@ -39,9 +39,6 @@ class KeyList { void remove(size_t key); void serialize() const; - - // enable pickling in python - void pickle() const; }; // Actually a FastSet @@ -67,9 +64,6 @@ class KeySet { bool count(size_t key) const; // returns true if value exists void serialize() const; - - // enable pickling in python - void pickle() const; }; // Actually a vector @@ -91,9 +85,6 @@ class KeyVector { void push_back(size_t key) const; void serialize() const; - - // enable pickling in python - void pickle() const; }; // Actually a FastMap @@ -165,6 +156,7 @@ gtsam::Values allPose2s(gtsam::Values& values); Matrix extractPose2(const gtsam::Values& values); gtsam::Values allPose3s(gtsam::Values& values); Matrix extractPose3(const gtsam::Values& values); +Matrix extractVectors(const gtsam::Values& values, char c); void perturbPoint2(gtsam::Values& values, double sigma, int seed = 42u); void perturbPose2(gtsam::Values& values, double sigmaT, double sigmaR, int seed = 42u); diff --git a/gtsam/inference/BayesNet-inst.h b/gtsam/inference/BayesNet-inst.h index a73762258..afde5498d 100644 --- a/gtsam/inference/BayesNet-inst.h +++ b/gtsam/inference/BayesNet-inst.h @@ -10,46 +10,76 @@ * -------------------------------------------------------------------------- */ /** -* @file BayesNet.h -* @brief Bayes network -* @author Frank Dellaert -* @author Richard Roberts -*/ + * @file BayesNet.h + * @brief Bayes network + * @author Frank Dellaert + * @author Richard Roberts + */ #pragma once -#include #include +#include #include #include +#include namespace gtsam { /* ************************************************************************* */ template -void BayesNet::print( - const std::string& s, const KeyFormatter& formatter) const { +void BayesNet::print(const std::string& s, + const KeyFormatter& formatter) const { Base::print(s, formatter); } /* ************************************************************************* */ template -void BayesNet::saveGraph(const std::string& s, - const KeyFormatter& keyFormatter) const { - std::ofstream of(s.c_str()); - of << "digraph G{\n"; +void BayesNet::dot(std::ostream& os, + const KeyFormatter& keyFormatter, + const DotWriter& writer) const { + writer.digraphPreamble(&os); + // Create nodes for each variable in the graph + for (Key key : this->keys()) { + auto position = writer.variablePos(key); + writer.drawVariable(key, keyFormatter, position, &os); + } + os << "\n"; + + // Reverse order as typically Bayes nets stored in reverse topological sort. for (auto conditional : boost::adaptors::reverse(*this)) { - typename CONDITIONAL::Frontals frontals = conditional->frontals(); - Key me = frontals.front(); - typename CONDITIONAL::Parents parents = conditional->parents(); - for (Key p : parents) - of << keyFormatter(p) << "->" << keyFormatter(me) << std::endl; + auto frontals = conditional->frontals(); + const Key me = frontals.front(); + auto parents = conditional->parents(); + for (const Key& p : parents) + os << " var" << keyFormatter(p) << "->var" << keyFormatter(me) << "\n"; } - of << "}"; + os << "}"; + std::flush(os); +} + +/* ************************************************************************* */ +template +std::string BayesNet::dot(const KeyFormatter& keyFormatter, + const DotWriter& writer) const { + std::stringstream ss; + dot(ss, keyFormatter, writer); + return ss.str(); +} + +/* ************************************************************************* */ +template +void BayesNet::saveGraph(const std::string& filename, + const KeyFormatter& keyFormatter, + const DotWriter& writer) const { + std::ofstream of(filename.c_str()); + dot(of, keyFormatter, writer); of.close(); } +/* ************************************************************************* */ + } // namespace gtsam diff --git a/gtsam/inference/BayesNet.h b/gtsam/inference/BayesNet.h index 938278d5a..219864c54 100644 --- a/gtsam/inference/BayesNet.h +++ b/gtsam/inference/BayesNet.h @@ -10,67 +10,79 @@ * -------------------------------------------------------------------------- */ /** -* @file BayesNet.h -* @brief Bayes network -* @author Frank Dellaert -* @author Richard Roberts -*/ + * @file BayesNet.h + * @brief Bayes network + * @author Frank Dellaert + * @author Richard Roberts + */ #pragma once -#include - #include +#include +#include + namespace gtsam { - /** - * A BayesNet is a tree of conditionals, stored in elimination order. - * - * todo: how to handle Bayes nets with an optimize function? Currently using global functions. - * \nosubgrouping - */ - template - class BayesNet : public FactorGraph { +/** + * A BayesNet is a tree of conditionals, stored in elimination order. + * @addtogroup inference + */ +template +class BayesNet : public FactorGraph { + private: + typedef FactorGraph Base; - private: + public: + typedef typename boost::shared_ptr + sharedConditional; ///< A shared pointer to a conditional - typedef FactorGraph Base; + protected: + /// @name Standard Constructors + /// @{ - public: - typedef typename boost::shared_ptr sharedConditional; ///< A shared pointer to a conditional + /** Default constructor as an empty BayesNet */ + BayesNet() {} - protected: - /// @name Standard Constructors - /// @{ + /** Construct from iterator over conditionals */ + template + BayesNet(ITERATOR firstConditional, ITERATOR lastConditional) + : Base(firstConditional, lastConditional) {} - /** Default constructor as an empty BayesNet */ - BayesNet() {}; + /// @} - /** Construct from iterator over conditionals */ - template - BayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {} + public: + /// @name Testable + /// @{ - /// @} + /** print out graph */ + void print( + const std::string& s = "BayesNet", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; - public: - /// @name Testable - /// @{ + /// @} - /** print out graph */ - void print( - const std::string& s = "BayesNet", - const KeyFormatter& formatter = DefaultKeyFormatter) const override; + /// @name Graph Display + /// @{ - /// @} + /// Output to graphviz format, stream version. + void dot(std::ostream& os, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DotWriter& writer = DotWriter()) const; - /// @name Standard Interface - /// @{ + /// Output to graphviz format string. + std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DotWriter& writer = DotWriter()) const; - void saveGraph(const std::string& s, - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; - }; + /// output to file with graphviz format. + void saveGraph(const std::string& filename, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DotWriter& writer = DotWriter()) const; -} + /// @} +}; + +} // namespace gtsam #include diff --git a/gtsam/inference/BayesTree-inst.h b/gtsam/inference/BayesTree-inst.h index 5b53a5719..b341c1d5a 100644 --- a/gtsam/inference/BayesTree-inst.h +++ b/gtsam/inference/BayesTree-inst.h @@ -63,20 +63,40 @@ namespace gtsam { } /* ************************************************************************* */ - template - void BayesTree::saveGraph(const std::string &s, const KeyFormatter& keyFormatter) const { - if (roots_.empty()) throw std::invalid_argument("the root of Bayes tree has not been initialized!"); - std::ofstream of(s.c_str()); - of<< "digraph G{\n"; - for(const sharedClique& root: roots_) - saveGraph(of, root, keyFormatter); - of<<"}"; + template + void BayesTree::dot(std::ostream& os, + const KeyFormatter& keyFormatter) const { + if (roots_.empty()) + throw std::invalid_argument( + "the root of Bayes tree has not been initialized!"); + os << "digraph G{\n"; + for (const sharedClique& root : roots_) dot(os, root, keyFormatter); + os << "}"; + std::flush(os); + } + + /* ************************************************************************* */ + template + std::string BayesTree::dot(const KeyFormatter& keyFormatter) const { + std::stringstream ss; + dot(ss, keyFormatter); + return ss.str(); + } + + /* ************************************************************************* */ + template + void BayesTree::saveGraph(const std::string& filename, + const KeyFormatter& keyFormatter) const { + std::ofstream of(filename.c_str()); + dot(of, keyFormatter); of.close(); } /* ************************************************************************* */ - template - void BayesTree::saveGraph(std::ostream &s, sharedClique clique, const KeyFormatter& indexFormatter, int parentnum) const { + template + void BayesTree::dot(std::ostream& s, sharedClique clique, + const KeyFormatter& keyFormatter, + int parentnum) const { static int num = 0; bool first = true; std::stringstream out; @@ -84,10 +104,10 @@ namespace gtsam { std::string parent = out.str(); parent += "[label=\""; - for (Key index : clique->conditional_->frontals()) { - if (!first) parent += ","; + for (Key key : clique->conditional_->frontals()) { + if (!first) parent += ", "; first = false; - parent += indexFormatter(index); + parent += keyFormatter(key); } if (clique->parent()) { @@ -96,10 +116,10 @@ namespace gtsam { } first = true; - for (Key sep : clique->conditional_->parents()) { - if (!first) parent += ","; + for (Key parentKey : clique->conditional_->parents()) { + if (!first) parent += ", "; first = false; - parent += indexFormatter(sep); + parent += keyFormatter(parentKey); } parent += "\"];\n"; s << parent; @@ -107,7 +127,7 @@ namespace gtsam { for (sharedClique c : clique->children) { num++; - saveGraph(s, c, indexFormatter, parentnum); + dot(s, c, keyFormatter, parentnum); } } diff --git a/gtsam/inference/BayesTree.h b/gtsam/inference/BayesTree.h index 7199da0ad..5b053ebee 100644 --- a/gtsam/inference/BayesTree.h +++ b/gtsam/inference/BayesTree.h @@ -19,6 +19,8 @@ #pragma once +#include + #include #include #include @@ -137,11 +139,11 @@ namespace gtsam { return nodes_.empty(); } - /** return nodes */ + /** Return nodes. Each node is a clique of variables obtained after elimination. */ const Nodes& nodes() const { return nodes_; } /** Access node by variable */ - const sharedNode operator[](Key j) const { return nodes_.at(j); } + sharedClique operator[](Key j) const { return nodes_.at(j); } /** return root cliques */ const Roots& roots() const { return roots_; } @@ -180,13 +182,20 @@ namespace gtsam { */ sharedBayesNet jointBayesNet(Key j1, Key j2, const Eliminate& function = EliminationTraitsType::DefaultEliminate) const; - /** - * Read only with side effects - */ + /// @name Graph Display + /// @{ - /** saves the Tree to a text file in GraphViz format */ - void saveGraph(const std::string& s, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + /// Output to graphviz format, stream version. + void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + /// Output to graphviz format string. + std::string dot( + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + + /// output to file with graphviz format. + void saveGraph(const std::string& filename, + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + /// @} /// @name Advanced Interface /// @{ @@ -234,8 +243,8 @@ namespace gtsam { protected: /** private helper method for saving the Tree to a text file in GraphViz format */ - void saveGraph(std::ostream &s, sharedClique clique, const KeyFormatter& keyFormatter, - int parentnum = 0) const; + void dot(std::ostream &s, sharedClique clique, const KeyFormatter& keyFormatter, + int parentnum = 0) const; /** Gather data on a single clique */ void getCliqueData(sharedClique clique, BayesTreeCliqueData* stats) const; @@ -247,7 +256,7 @@ namespace gtsam { void fillNodesIndex(const sharedClique& subtree); // Friend JunctionTree because it directly fills roots and nodes index. - template friend class EliminatableClusterTree; + template friend class EliminatableClusterTree; private: /** Serialization function */ diff --git a/gtsam/inference/ClusterTree-inst.h b/gtsam/inference/ClusterTree-inst.h index b042c0c8e..9bc141955 100644 --- a/gtsam/inference/ClusterTree-inst.h +++ b/gtsam/inference/ClusterTree-inst.h @@ -15,6 +15,10 @@ #include #include +#ifdef GTSAM_USE_TBB +#include +#endif + namespace gtsam { /* ************************************************************************* */ @@ -120,12 +124,25 @@ struct EliminationData { size_t myIndexInParent; FastVector childFactors; boost::shared_ptr bayesTreeNode; +#ifdef GTSAM_USE_TBB + boost::shared_ptr writeLock; +#endif EliminationData(EliminationData* _parentData, size_t nChildren) : - parentData(_parentData), bayesTreeNode(boost::make_shared()) { + parentData(_parentData), bayesTreeNode(boost::make_shared()) +#ifdef GTSAM_USE_TBB + , writeLock(boost::make_shared()) +#endif + { if (parentData) { +#ifdef GTSAM_USE_TBB + parentData->writeLock->lock(); +#endif myIndexInParent = parentData->childFactors.size(); parentData->childFactors.push_back(sharedFactor()); +#ifdef GTSAM_USE_TBB + parentData->writeLock->unlock(); +#endif } else { myIndexInParent = 0; } @@ -196,8 +213,15 @@ struct EliminationData { nodesIndex_.insert(std::make_pair(j, myData.bayesTreeNode)); // Store remaining factor in parent's gathered factors - if (!eliminationResult.second->empty()) + if (!eliminationResult.second->empty()) { +#ifdef GTSAM_USE_TBB + myData.parentData->writeLock->lock(); +#endif myData.parentData->childFactors[myData.myIndexInParent] = eliminationResult.second; +#ifdef GTSAM_USE_TBB + myData.parentData->writeLock->unlock(); +#endif + } } }; }; diff --git a/gtsam/inference/ClusterTree.h b/gtsam/inference/ClusterTree.h index e225bac5f..7dd414193 100644 --- a/gtsam/inference/ClusterTree.h +++ b/gtsam/inference/ClusterTree.h @@ -110,7 +110,7 @@ class ClusterTree { typedef sharedCluster sharedNode; /** concept check */ - GTSAM_CONCEPT_TESTABLE_TYPE(FactorType); + GTSAM_CONCEPT_TESTABLE_TYPE(FactorType) protected: FastVector roots_; diff --git a/gtsam/inference/Conditional.h b/gtsam/inference/Conditional.h index 295122879..7594da78d 100644 --- a/gtsam/inference/Conditional.h +++ b/gtsam/inference/Conditional.h @@ -25,15 +25,12 @@ namespace gtsam { /** - * TODO: Update comments. The following comments are out of date!!! - * - * Base class for conditional densities, templated on KEY type. This class - * provides storage for the keys involved in a conditional, and iterators and + * Base class for conditional densities. This class iterators and * access to the frontal and separator keys. * * Derived classes *must* redefine the Factor and shared_ptr typedefs to refer * to the associated factor type and shared_ptr type of the derived class. See - * IndexConditional and GaussianConditional for examples. + * SymbolicConditional and GaussianConditional for examples. * \nosubgrouping */ template diff --git a/gtsam/inference/DotWriter.cpp b/gtsam/inference/DotWriter.cpp new file mode 100644 index 000000000..ad5330575 --- /dev/null +++ b/gtsam/inference/DotWriter.cpp @@ -0,0 +1,129 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DotWriter.cpp + * @brief Graphviz formatting for factor graphs. + * @author Frank Dellaert + * @date December, 2021 + */ + +#include + +#include +#include + +#include + +using namespace std; + +namespace gtsam { + +void DotWriter::graphPreamble(ostream* os) const { + *os << "graph {\n"; + *os << " size=\"" << figureWidthInches << "," << figureHeightInches + << "\";\n\n"; +} + +void DotWriter::digraphPreamble(ostream* os) const { + *os << "digraph {\n"; + *os << " size=\"" << figureWidthInches << "," << figureHeightInches + << "\";\n\n"; +} + +void DotWriter::drawVariable(Key key, const KeyFormatter& keyFormatter, + const boost::optional& position, + ostream* os) const { + // Label the node with the label from the KeyFormatter + *os << " var" << keyFormatter(key) << "[label=\"" << keyFormatter(key) + << "\""; + if (position) { + *os << ", pos=\"" << position->x() << "," << position->y() << "!\""; + } + if (boxes.count(key)) { + *os << ", shape=box"; + } + *os << "];\n"; +} + +void DotWriter::DrawFactor(size_t i, const boost::optional& position, + ostream* os) { + *os << " factor" << i << "[label=\"\", shape=point"; + if (position) { + *os << ", pos=\"" << position->x() << "," << position->y() << "!\""; + } + *os << "];\n"; +} + +static void ConnectVariables(Key key1, Key key2, + const KeyFormatter& keyFormatter, ostream* os) { + *os << " var" << keyFormatter(key1) << "--" + << "var" << keyFormatter(key2) << ";\n"; +} + +static void ConnectVariableFactor(Key key, const KeyFormatter& keyFormatter, + size_t i, ostream* os) { + *os << " var" << keyFormatter(key) << "--" + << "factor" << i << ";\n"; +} + +/// Return variable position or none +boost::optional DotWriter::variablePos(Key key) const { + boost::optional result = boost::none; + + // Check position hint + Symbol symbol(key); + auto hint = positionHints.find(symbol.chr()); + if (hint != positionHints.end()) + result.reset(Vector2(symbol.index(), hint->second)); + + // Override with explicit position, if given. + auto pos = variablePositions.find(key); + if (pos != variablePositions.end()) + result.reset(pos->second); + + return result; +} + +void DotWriter::processFactor(size_t i, const KeyVector& keys, + const KeyFormatter& keyFormatter, + const boost::optional& position, + ostream* os) const { + if (plotFactorPoints) { + if (binaryEdges && keys.size() == 2) { + ConnectVariables(keys[0], keys[1], keyFormatter, os); + } else { + // Create dot for the factor. + if (!position && factorPositions.count(i)) + DrawFactor(i, factorPositions.at(i), os); + else + DrawFactor(i, position, os); + + // Make factor-variable connections + if (connectKeysToFactor) { + for (Key key : keys) { + ConnectVariableFactor(key, keyFormatter, i, os); + } + } + } + } else { + // just connect variables in a clique + for (Key key1 : keys) { + for (Key key2 : keys) { + if (key2 > key1) { + ConnectVariables(key1, key2, keyFormatter, os); + } + } + } + } +} + +} // namespace gtsam diff --git a/gtsam/inference/DotWriter.h b/gtsam/inference/DotWriter.h new file mode 100644 index 000000000..23302ee60 --- /dev/null +++ b/gtsam/inference/DotWriter.h @@ -0,0 +1,100 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DotWriter.h + * @brief Graphviz formatter + * @author Frank Dellaert + * @date December, 2021 + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include + +namespace gtsam { + +/** + * @brief DotWriter is a helper class for writing graphviz .dot files. + * @addtogroup inference + */ +struct GTSAM_EXPORT DotWriter { + double figureWidthInches; ///< The figure width on paper in inches + double figureHeightInches; ///< The figure height on paper in inches + bool plotFactorPoints; ///< Plots each factor as a dot between the variables + bool connectKeysToFactor; ///< Draw a line from each key within a factor to + ///< the dot of the factor + bool binaryEdges; ///< just use non-dotted edges for binary factors + + /** + * Variable positions can be optionally specified and will be included in the + * dot file with a "!' sign, so "neato" can use it to render them. + */ + std::map variablePositions; + + /** + * The position hints allow one to use symbol character and index to specify + * position. Unless variable positions are specified, if a hint is present for + * a given symbol, it will be used to calculate the positions as (index,hint). + */ + std::map positionHints; + + /** A set of keys that will be displayed as a box */ + std::set boxes; + + /** + * Factor positions can be optionally specified and will be included in the + * dot file with a "!' sign, so "neato" can use it to render them. + */ + std::map factorPositions; + + explicit DotWriter(double figureWidthInches = 5, + double figureHeightInches = 5, + bool plotFactorPoints = true, + bool connectKeysToFactor = true, bool binaryEdges = false) + : figureWidthInches(figureWidthInches), + figureHeightInches(figureHeightInches), + plotFactorPoints(plotFactorPoints), + connectKeysToFactor(connectKeysToFactor), + binaryEdges(binaryEdges) {} + + /// Write out preamble for graph, including size. + void graphPreamble(std::ostream* os) const; + + /// Write out preamble for digraph, including size. + void digraphPreamble(std::ostream* os) const; + + /// Create a variable dot fragment. + void drawVariable(Key key, const KeyFormatter& keyFormatter, + const boost::optional& position, + std::ostream* os) const; + + /// Create factor dot. + static void DrawFactor(size_t i, const boost::optional& position, + std::ostream* os); + + /// Return variable position or none + boost::optional variablePos(Key key) const; + + /// Draw a single factor, specified by its index i and its variable keys. + void processFactor(size_t i, const KeyVector& keys, + const KeyFormatter& keyFormatter, + const boost::optional& position, + std::ostream* os) const; +}; + +} // namespace gtsam diff --git a/gtsam/inference/EliminateableFactorGraph-inst.h b/gtsam/inference/EliminateableFactorGraph-inst.h index 81f4047a1..35e7505c9 100644 --- a/gtsam/inference/EliminateableFactorGraph-inst.h +++ b/gtsam/inference/EliminateableFactorGraph-inst.h @@ -36,17 +36,17 @@ namespace gtsam { // no Ordering is provided. When removing optional from VariableIndex, create VariableIndex // before creating ordering. VariableIndex computedVariableIndex(asDerived()); - return eliminateSequential(function, computedVariableIndex, orderingType); + return eliminateSequential(orderingType, function, computedVariableIndex); } else { // Compute an ordering and call this function again. We are guaranteed to have a // VariableIndex already here because we computed one if needed in the previous 'if' block. if (orderingType == Ordering::METIS) { Ordering computedOrdering = Ordering::Metis(asDerived()); - return eliminateSequential(computedOrdering, function, variableIndex, orderingType); + return eliminateSequential(computedOrdering, function, variableIndex); } else { Ordering computedOrdering = Ordering::Colamd(*variableIndex); - return eliminateSequential(computedOrdering, function, variableIndex, orderingType); + return eliminateSequential(computedOrdering, function, variableIndex); } } } @@ -78,29 +78,31 @@ namespace gtsam { } /* ************************************************************************* */ - template - boost::shared_ptr::BayesTreeType> - EliminateableFactorGraph::eliminateMultifrontal( - OptionalOrderingType orderingType, const Eliminate& function, - OptionalVariableIndex variableIndex) const - { - if(!variableIndex) { - // If no VariableIndex provided, compute one and call this function again IMPORTANT: we check - // for no variable index first so that it's always computed if we need to call COLAMD because - // no Ordering is provided. When removing optional from VariableIndex, create VariableIndex - // before creating ordering. + template + boost::shared_ptr< + typename EliminateableFactorGraph::BayesTreeType> + EliminateableFactorGraph::eliminateMultifrontal( + OptionalOrderingType orderingType, const Eliminate& function, + OptionalVariableIndex variableIndex) const { + if (!variableIndex) { + // If no VariableIndex provided, compute one and call this function again + // IMPORTANT: we check for no variable index first so that it's always + // computed if we need to call COLAMD because no Ordering is provided. + // When removing optional from VariableIndex, create VariableIndex before + // creating ordering. VariableIndex computedVariableIndex(asDerived()); - return eliminateMultifrontal(function, computedVariableIndex, orderingType); - } - else { - // Compute an ordering and call this function again. We are guaranteed to have a - // VariableIndex already here because we computed one if needed in the previous 'if' block. + return eliminateMultifrontal(orderingType, function, + computedVariableIndex); + } else { + // Compute an ordering and call this function again. We are guaranteed to + // have a VariableIndex already here because we computed one if needed in + // the previous 'if' block. if (orderingType == Ordering::METIS) { Ordering computedOrdering = Ordering::Metis(asDerived()); - return eliminateMultifrontal(computedOrdering, function, variableIndex, orderingType); + return eliminateMultifrontal(computedOrdering, function, variableIndex); } else { Ordering computedOrdering = Ordering::Colamd(*variableIndex); - return eliminateMultifrontal(computedOrdering, function, variableIndex, orderingType); + return eliminateMultifrontal(computedOrdering, function, variableIndex); } } } @@ -273,7 +275,7 @@ namespace gtsam { else { // No ordering was provided for the unmarginalized variables, so order them with COLAMD. - return factorGraph->eliminateSequential(function); + return factorGraph->eliminateSequential(Ordering::COLAMD, function); } } } @@ -340,7 +342,7 @@ namespace gtsam { else { // No ordering was provided for the unmarginalized variables, so order them with COLAMD. - return factorGraph->eliminateMultifrontal(function); + return factorGraph->eliminateMultifrontal(Ordering::COLAMD, function); } } } diff --git a/gtsam/inference/EliminateableFactorGraph.h b/gtsam/inference/EliminateableFactorGraph.h index edc4883e7..c904d2f7f 100644 --- a/gtsam/inference/EliminateableFactorGraph.h +++ b/gtsam/inference/EliminateableFactorGraph.h @@ -288,8 +288,9 @@ namespace gtsam { FactorGraphType& asDerived() { return static_cast(*this); } public: - /** \deprecated ordering and orderingType shouldn't both be specified */ - boost::shared_ptr eliminateSequential( + #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + /** @deprecated ordering and orderingType shouldn't both be specified */ + boost::shared_ptr GTSAM_DEPRECATED eliminateSequential( const Ordering& ordering, const Eliminate& function, OptionalVariableIndex variableIndex, @@ -297,16 +298,16 @@ namespace gtsam { return eliminateSequential(ordering, function, variableIndex); } - /** \deprecated orderingType specified first for consistency */ - boost::shared_ptr eliminateSequential( + /** @deprecated orderingType specified first for consistency */ + boost::shared_ptr GTSAM_DEPRECATED eliminateSequential( const Eliminate& function, OptionalVariableIndex variableIndex = boost::none, OptionalOrderingType orderingType = boost::none) const { return eliminateSequential(orderingType, function, variableIndex); } - /** \deprecated ordering and orderingType shouldn't both be specified */ - boost::shared_ptr eliminateMultifrontal( + /** @deprecated ordering and orderingType shouldn't both be specified */ + boost::shared_ptr GTSAM_DEPRECATED eliminateMultifrontal( const Ordering& ordering, const Eliminate& function, OptionalVariableIndex variableIndex, @@ -314,16 +315,16 @@ namespace gtsam { return eliminateMultifrontal(ordering, function, variableIndex); } - /** \deprecated orderingType specified first for consistency */ - boost::shared_ptr eliminateMultifrontal( + /** @deprecated orderingType specified first for consistency */ + boost::shared_ptr GTSAM_DEPRECATED eliminateMultifrontal( const Eliminate& function, OptionalVariableIndex variableIndex = boost::none, OptionalOrderingType orderingType = boost::none) const { return eliminateMultifrontal(orderingType, function, variableIndex); } - /** \deprecated */ - boost::shared_ptr marginalMultifrontalBayesNet( + /** @deprecated */ + boost::shared_ptr GTSAM_DEPRECATED marginalMultifrontalBayesNet( boost::variant variables, boost::none_t, const Eliminate& function = EliminationTraitsType::DefaultEliminate, @@ -331,14 +332,15 @@ namespace gtsam { return marginalMultifrontalBayesNet(variables, function, variableIndex); } - /** \deprecated */ - boost::shared_ptr marginalMultifrontalBayesTree( + /** @deprecated */ + boost::shared_ptr GTSAM_DEPRECATED marginalMultifrontalBayesTree( boost::variant variables, boost::none_t, const Eliminate& function = EliminationTraitsType::DefaultEliminate, OptionalVariableIndex variableIndex = boost::none) const { return marginalMultifrontalBayesTree(variables, function, variableIndex); } + #endif }; } diff --git a/gtsam/inference/EliminationTree.h b/gtsam/inference/EliminationTree.h index e4a64c589..70e10b3bd 100644 --- a/gtsam/inference/EliminationTree.h +++ b/gtsam/inference/EliminationTree.h @@ -81,7 +81,7 @@ namespace gtsam { protected: /** concept check */ - GTSAM_CONCEPT_TESTABLE_TYPE(FactorType); + GTSAM_CONCEPT_TESTABLE_TYPE(FactorType) FastVector roots_; FastVector remainingFactors_; diff --git a/gtsam/inference/Factor.h b/gtsam/inference/Factor.h index 6ea81030a..27b85ef67 100644 --- a/gtsam/inference/Factor.h +++ b/gtsam/inference/Factor.h @@ -22,6 +22,7 @@ #pragma once #include +#include #include #include @@ -111,6 +112,9 @@ typedef FastSet FactorIndexSet; /// @name Standard Interface /// @{ + /// Whether the factor is empty (involves zero variables). + bool empty() const { return keys_.empty(); } + /// First key Key front() const { return keys_.front(); } @@ -149,13 +153,11 @@ typedef FastSet FactorIndexSet; const std::string& s = "Factor", const KeyFormatter& formatter = DefaultKeyFormatter) const; - protected: /// check equality bool equals(const This& other, double tol = 1e-9) const; /// @} - public: /// @name Advanced Interface /// @{ diff --git a/gtsam/inference/FactorGraph-inst.h b/gtsam/inference/FactorGraph-inst.h index 166ae41f4..a2ae07101 100644 --- a/gtsam/inference/FactorGraph-inst.h +++ b/gtsam/inference/FactorGraph-inst.h @@ -26,6 +26,7 @@ #include #include #include // for cout :-( +#include #include #include @@ -125,4 +126,50 @@ FactorIndices FactorGraph::add_factors(const CONTAINER& factors, return newFactorIndices; } +/* ************************************************************************* */ +template +void FactorGraph::dot(std::ostream& os, + const KeyFormatter& keyFormatter, + const DotWriter& writer) const { + writer.graphPreamble(&os); + + // Create nodes for each variable in the graph + for (Key key : keys()) { + auto position = writer.variablePos(key); + writer.drawVariable(key, keyFormatter, position, &os); + } + os << "\n"; + + // Create factors and variable connections + for (size_t i = 0; i < size(); ++i) { + const auto& factor = at(i); + if (factor) { + const KeyVector& factorKeys = factor->keys(); + writer.processFactor(i, factorKeys, keyFormatter, boost::none, &os); + } + } + + os << "}\n"; + std::flush(os); +} + +/* ************************************************************************* */ +template +std::string FactorGraph::dot(const KeyFormatter& keyFormatter, + const DotWriter& writer) const { + std::stringstream ss; + dot(ss, keyFormatter, writer); + return ss.str(); +} + +/* ************************************************************************* */ +template +void FactorGraph::saveGraph(const std::string& filename, + const KeyFormatter& keyFormatter, + const DotWriter& writer) const { + std::ofstream of(filename.c_str()); + dot(of, keyFormatter, writer); + of.close(); +} + } // namespace gtsam diff --git a/gtsam/inference/FactorGraph.h b/gtsam/inference/FactorGraph.h index e337e3249..afea63da8 100644 --- a/gtsam/inference/FactorGraph.h +++ b/gtsam/inference/FactorGraph.h @@ -22,9 +22,10 @@ #pragma once +#include +#include #include #include -#include #include // for Eigen::aligned_allocator @@ -36,6 +37,7 @@ #include #include #include +#include namespace gtsam { /// Define collection type: @@ -126,6 +128,11 @@ class FactorGraph { /** Collection of factors */ FastVector factors_; + /// Check exact equality of the factor pointers. Useful for derived ==. + bool isEqual(const FactorGraph& other) const { + return factors_ == other.factors_; + } + /// @name Standard Constructors /// @{ @@ -288,11 +295,11 @@ class FactorGraph { /// @name Testable /// @{ - /// print out graph + /// Print out graph to std::cout, with optional key formatter. virtual void print(const std::string& s = "FactorGraph", const KeyFormatter& formatter = DefaultKeyFormatter) const; - /** Check equality */ + /// Check equality up to tolerance. bool equals(const This& fg, double tol = 1e-9) const; /// @} @@ -371,6 +378,24 @@ class FactorGraph { return factors_.erase(first, last); } + /// @} + /// @name Graph Display + /// @{ + + /// Output to graphviz format, stream version. + void dot(std::ostream& os, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DotWriter& writer = DotWriter()) const; + + /// Output to graphviz format string. + std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DotWriter& writer = DotWriter()) const; + + /// output to file with graphviz format. + void saveGraph(const std::string& filename, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DotWriter& writer = DotWriter()) const; + /// @} /// @name Advanced Interface /// @{ diff --git a/gtsam/inference/MetisIndex-inl.h b/gtsam/inference/MetisIndex-inl.h index eb9670254..646523372 100644 --- a/gtsam/inference/MetisIndex-inl.h +++ b/gtsam/inference/MetisIndex-inl.h @@ -23,8 +23,8 @@ namespace gtsam { /* ************************************************************************* */ -template -void MetisIndex::augment(const FactorGraph& factors) { +template +void MetisIndex::augment(const FACTORGRAPH& factors) { std::map > iAdjMap; // Stores a set of keys that are adjacent to key x, with adjMap.first std::map >::iterator iAdjMapIt; std::set keySet; diff --git a/gtsam/inference/MetisIndex.h b/gtsam/inference/MetisIndex.h index 7ec435caa..7431bff4c 100644 --- a/gtsam/inference/MetisIndex.h +++ b/gtsam/inference/MetisIndex.h @@ -62,8 +62,8 @@ public: nKeys_(0) { } - template - MetisIndex(const FG& factorGraph) : + template + MetisIndex(const FACTORGRAPH& factorGraph) : nKeys_(0) { augment(factorGraph); } @@ -78,8 +78,8 @@ public: * Augment the variable index with new factors. This can be used when * solving problems incrementally. */ - template - void augment(const FactorGraph& factors); + template + void augment(const FACTORGRAPH& factors); const std::vector& xadj() const { return xadj_; diff --git a/gtsam/inference/Ordering.cpp b/gtsam/inference/Ordering.cpp index 440d2b828..2ac2c0dde 100644 --- a/gtsam/inference/Ordering.cpp +++ b/gtsam/inference/Ordering.cpp @@ -25,11 +25,7 @@ #include #ifdef GTSAM_SUPPORT_NESTED_DISSECTION -#ifdef GTSAM_USE_SYSTEM_METIS #include -#else -#include -#endif #endif using namespace std; diff --git a/gtsam/inference/inference.i b/gtsam/inference/inference.i new file mode 100644 index 000000000..e7b074ec4 --- /dev/null +++ b/gtsam/inference/inference.i @@ -0,0 +1,199 @@ +//************************************************************************* +// inference +//************************************************************************* + +namespace gtsam { + +// Headers for overloaded methods below, break hierarchy :-/ +#include +#include +#include +#include + +#include + +// Default keyformatter +void PrintKeyList( + const gtsam::KeyList& keys, const string& s = "", + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter); +void PrintKeyVector( + const gtsam::KeyVector& keys, const string& s = "", + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter); +void PrintKeySet( + const gtsam::KeySet& keys, const string& s = "", + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter); + +#include +class Symbol { + Symbol(); + Symbol(char c, uint64_t j); + Symbol(size_t key); + + size_t key() const; + void print(const string& s = "") const; + bool equals(const gtsam::Symbol& expected, double tol) const; + + char chr() const; + uint64_t index() const; + string string() const; +}; + +size_t symbol(char chr, size_t index); +char symbolChr(size_t key); +size_t symbolIndex(size_t key); + +namespace symbol_shorthand { +size_t A(size_t j); +size_t B(size_t j); +size_t C(size_t j); +size_t D(size_t j); +size_t E(size_t j); +size_t F(size_t j); +size_t G(size_t j); +size_t H(size_t j); +size_t I(size_t j); +size_t J(size_t j); +size_t K(size_t j); +size_t L(size_t j); +size_t M(size_t j); +size_t N(size_t j); +size_t O(size_t j); +size_t P(size_t j); +size_t Q(size_t j); +size_t R(size_t j); +size_t S(size_t j); +size_t T(size_t j); +size_t U(size_t j); +size_t V(size_t j); +size_t W(size_t j); +size_t X(size_t j); +size_t Y(size_t j); +size_t Z(size_t j); +} // namespace symbol_shorthand + +#include +class LabeledSymbol { + LabeledSymbol(size_t full_key); + LabeledSymbol(const gtsam::LabeledSymbol& key); + LabeledSymbol(unsigned char valType, unsigned char label, size_t j); + + size_t key() const; + unsigned char label() const; + unsigned char chr() const; + size_t index() const; + + gtsam::LabeledSymbol upper() const; + gtsam::LabeledSymbol lower() const; + gtsam::LabeledSymbol newChr(unsigned char c) const; + gtsam::LabeledSymbol newLabel(unsigned char label) const; + + void print(string s = "") const; +}; + +size_t mrsymbol(unsigned char c, unsigned char label, size_t j); +unsigned char mrsymbolChr(size_t key); +unsigned char mrsymbolLabel(size_t key); +size_t mrsymbolIndex(size_t key); + +#include +class Ordering { + /// Type of ordering to use + enum OrderingType { COLAMD, METIS, NATURAL, CUSTOM }; + + // Standard Constructors and Named Constructors + Ordering(); + Ordering(const gtsam::Ordering& other); + + template < + FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph, + gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph}> + static gtsam::Ordering Colamd(const FACTOR_GRAPH& graph); + + template < + FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph, + gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph}> + static gtsam::Ordering ColamdConstrainedLast( + const FACTOR_GRAPH& graph, const gtsam::KeyVector& constrainLast, + bool forceOrder = false); + + template < + FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph, + gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph}> + static gtsam::Ordering ColamdConstrainedFirst( + const FACTOR_GRAPH& graph, const gtsam::KeyVector& constrainFirst, + bool forceOrder = false); + + template < + FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph, + gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph}> + static gtsam::Ordering Natural(const FACTOR_GRAPH& graph); + + template < + FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph, + gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph}> + static gtsam::Ordering Metis(const FACTOR_GRAPH& graph); + + template < + FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph, + gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph}> + static gtsam::Ordering Create(gtsam::Ordering::OrderingType orderingType, + const FACTOR_GRAPH& graph); + + // Testable + void print(string s = "", const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::Ordering& ord, double tol) const; + + // Standard interface + size_t size() const; + size_t at(size_t key) const; + void push_back(size_t key); + + // enabling serialization functionality + void serialize() const; +}; + +#include +class DotWriter { + DotWriter(double figureWidthInches = 5, double figureHeightInches = 5, + bool plotFactorPoints = true, bool connectKeysToFactor = true, + bool binaryEdges = true); + + double figureWidthInches; + double figureHeightInches; + bool plotFactorPoints; + bool connectKeysToFactor; + bool binaryEdges; + + std::map variablePositions; + std::map positionHints; + std::set boxes; + std::map factorPositions; +}; + +#include +class VariableIndex { + // Standard Constructors and Named Constructors + VariableIndex(); + // TODO: Templetize constructor when wrap supports it + // template + // VariableIndex(const T& factorGraph, size_t nVariables); + // VariableIndex(const T& factorGraph); + VariableIndex(const gtsam::SymbolicFactorGraph& sfg); + VariableIndex(const gtsam::GaussianFactorGraph& gfg); + VariableIndex(const gtsam::NonlinearFactorGraph& fg); + VariableIndex(const gtsam::VariableIndex& other); + + // Testable + bool equals(const gtsam::VariableIndex& other, double tol) const; + void print(string s = "VariableIndex: ", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + + // Standard interface + size_t size() const; + size_t nFactors() const; + size_t nEntries() const; +}; + +} // namespace gtsam diff --git a/gtsam/inference/tests/testOrdering.cpp b/gtsam/inference/tests/testOrdering.cpp index 0305218af..6fdca0d89 100644 --- a/gtsam/inference/tests/testOrdering.cpp +++ b/gtsam/inference/tests/testOrdering.cpp @@ -270,17 +270,7 @@ TEST(Ordering, MetisLoop) { symbolicGraph.push_factor(0, 5); // METIS -#if !defined(__APPLE__) - { - Ordering actual = Ordering::Create(Ordering::METIS, symbolicGraph); - // - P( 0 4 1) - // | - P( 2 | 4 1) - // | | - P( 3 | 4 2) - // | - P( 5 | 0 1) - Ordering expected = Ordering(list_of(3)(2)(5)(0)(4)(1)); - EXPECT(assert_equal(expected, actual)); - } -#else +#if defined(__APPLE__) { Ordering actual = Ordering::Create(Ordering::METIS, symbolicGraph); // - P( 1 0 3) @@ -290,6 +280,26 @@ TEST(Ordering, MetisLoop) { Ordering expected = Ordering(list_of(5)(4)(2)(1)(0)(3)); EXPECT(assert_equal(expected, actual)); } +#elif defined(_WIN32) + { + Ordering actual = Ordering::Create(Ordering::METIS, symbolicGraph); + // - P( 0 5 2) + // | - P( 3 | 5 2) + // | | - P( 4 | 5 3) + // | - P( 1 | 0 2) + Ordering expected = Ordering(list_of(4)(3)(1)(0)(5)(2)); + EXPECT(assert_equal(expected, actual)); + } +#else + { + Ordering actual = Ordering::Create(Ordering::METIS, symbolicGraph); + // - P( 0 4 1) + // | - P( 2 | 4 1) + // | | - P( 3 | 4 2) + // | - P( 5 | 0 1) + Ordering expected = Ordering(list_of(3)(2)(5)(0)(4)(1)); + EXPECT(assert_equal(expected, actual)); + } #endif } #endif diff --git a/gtsam/linear/Errors.cpp b/gtsam/linear/Errors.cpp index 3fe2f3307..41c6c3d09 100644 --- a/gtsam/linear/Errors.cpp +++ b/gtsam/linear/Errors.cpp @@ -110,11 +110,10 @@ double dot(const Errors& a, const Errors& b) { } /* ************************************************************************* */ -template<> -void axpy(double alpha, const Errors& x, Errors& y) { +void axpy(double alpha, const Errors& x, Errors& y) { Errors::const_iterator it = x.begin(); for(Vector& yi: y) - axpy(alpha,*(it++),yi); + yi += alpha * (*(it++)); } /* ************************************************************************* */ diff --git a/gtsam/linear/Errors.h b/gtsam/linear/Errors.h index eb844e04d..f6e147084 100644 --- a/gtsam/linear/Errors.h +++ b/gtsam/linear/Errors.h @@ -65,8 +65,7 @@ namespace gtsam { /** * BLAS level 2 style */ - template <> - GTSAM_EXPORT void axpy(double alpha, const Errors& x, Errors& y); + GTSAM_EXPORT void axpy(double alpha, const Errors& x, Errors& y); /** print with optional string */ GTSAM_EXPORT void print(const Errors& a, const std::string& s = "Error"); diff --git a/gtsam/linear/GaussianBayesNet.cpp b/gtsam/linear/GaussianBayesNet.cpp index 1e790d0f1..6dcf662a9 100644 --- a/gtsam/linear/GaussianBayesNet.cpp +++ b/gtsam/linear/GaussianBayesNet.cpp @@ -26,6 +26,9 @@ using namespace std; using namespace gtsam; +// In Wrappers we have no access to this so have a default ready +static std::mt19937_64 kRandomNumberGenerator(42); + namespace gtsam { // Instantiate base class @@ -37,28 +40,50 @@ namespace gtsam { return Base::equals(bn, tol); } - /* ************************************************************************* */ - VectorValues GaussianBayesNet::optimize() const - { - VectorValues soln; // no missing variables -> just create an empty vector - return optimize(soln); + /* ************************************************************************ */ + VectorValues GaussianBayesNet::optimize() const { + VectorValues solution; // no missing variables -> create an empty vector + return optimize(solution); } - /* ************************************************************************* */ - VectorValues GaussianBayesNet::optimize( - const VectorValues& solutionForMissing) const { - VectorValues soln(solutionForMissing); // possibly empty + VectorValues GaussianBayesNet::optimize(VectorValues solution) const { // (R*x)./sigmas = y by solving x=inv(R)*(y.*sigmas) - /** solve each node in turn in topological sort order (parents first)*/ - for (auto cg: boost::adaptors::reverse(*this)) { + // solve each node in reverse topological sort order (parents first) + for (auto cg : boost::adaptors::reverse(*this)) { // i^th part of R*x=y, x=inv(R)*y - // (Rii*xi + R_i*x(i+1:))./si = yi <-> xi = inv(Rii)*(yi.*si - R_i*x(i+1:)) - soln.insert(cg->solve(soln)); + // (Rii*xi + R_i*x(i+1:))./si = yi => + // xi = inv(Rii)*(yi.*si - R_i*x(i+1:)) + solution.insert(cg->solve(solution)); } - return soln; + return solution; } - /* ************************************************************************* */ + /* ************************************************************************ */ + VectorValues GaussianBayesNet::sample(std::mt19937_64* rng) const { + VectorValues result; // no missing variables -> create an empty vector + return sample(result, rng); + } + + VectorValues GaussianBayesNet::sample(VectorValues result, + std::mt19937_64* rng) const { + // sample each node in reverse topological sort order (parents first) + for (auto cg : boost::adaptors::reverse(*this)) { + const VectorValues sampled = cg->sample(result, rng); + result.insert(sampled); + } + return result; + } + + /* ************************************************************************ */ + VectorValues GaussianBayesNet::sample() const { + return sample(&kRandomNumberGenerator); + } + + VectorValues GaussianBayesNet::sample(VectorValues given) const { + return sample(given, &kRandomNumberGenerator); + } + + /* ************************************************************************ */ VectorValues GaussianBayesNet::optimizeGradientSearch() const { gttic(GaussianBayesTree_optimizeGradientSearch); @@ -205,23 +230,5 @@ namespace gtsam { } /* ************************************************************************* */ - void GaussianBayesNet::saveGraph(const std::string& s, - const KeyFormatter& keyFormatter) const { - std::ofstream of(s.c_str()); - of << "digraph G{\n"; - - for (auto conditional : boost::adaptors::reverse(*this)) { - typename GaussianConditional::Frontals frontals = conditional->frontals(); - Key me = frontals.front(); - typename GaussianConditional::Parents parents = conditional->parents(); - for (Key p : parents) - of << keyFormatter(p) << "->" << keyFormatter(me) << std::endl; - } - - of << "}"; - of.close(); - } - - /* ************************************************************************* */ } // namespace gtsam diff --git a/gtsam/linear/GaussianBayesNet.h b/gtsam/linear/GaussianBayesNet.h index e55a89bcd..940ffd882 100644 --- a/gtsam/linear/GaussianBayesNet.h +++ b/gtsam/linear/GaussianBayesNet.h @@ -21,17 +21,22 @@ #pragma once #include +#include #include #include +#include namespace gtsam { - /** A Bayes net made from linear-Gaussian densities */ - class GTSAM_EXPORT GaussianBayesNet: public FactorGraph + /** + * GaussianBayesNet is a Bayes net made from linear-Gaussian conditionals. + * @addtogroup linear + */ + class GTSAM_EXPORT GaussianBayesNet: public BayesNet { public: - typedef FactorGraph Base; + typedef BayesNet Base; typedef GaussianBayesNet This; typedef GaussianConditional ConditionalType; typedef boost::shared_ptr shared_ptr; @@ -44,16 +49,21 @@ namespace gtsam { GaussianBayesNet() {} /** Construct from iterator over conditionals */ - template - GaussianBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {} + template + GaussianBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) + : Base(firstConditional, lastConditional) {} /** Construct from container of factors (shared_ptr or plain objects) */ - template - explicit GaussianBayesNet(const CONTAINER& conditionals) : Base(conditionals) {} + template + explicit GaussianBayesNet(const CONTAINER& conditionals) { + push_back(conditionals); + } - /** Implicit copy/downcast constructor to override explicit template container constructor */ - template - GaussianBayesNet(const FactorGraph& graph) : Base(graph) {} + /** Implicit copy/downcast constructor to override explicit template + * container constructor */ + template + explicit GaussianBayesNet(const FactorGraph& graph) + : Base(graph) {} /// Destructor virtual ~GaussianBayesNet() {} @@ -66,16 +76,47 @@ namespace gtsam { /** Check equality */ bool equals(const This& bn, double tol = 1e-9) const; + /// print graph + void print( + const std::string& s = "", + const KeyFormatter& formatter = DefaultKeyFormatter) const override { + Base::print(s, formatter); + } + /// @} /// @name Standard Interface /// @{ - /// Solve the GaussianBayesNet, i.e. return \f$ x = R^{-1}*d \f$, by back-substitution + /// Solve the GaussianBayesNet, i.e. return \f$ x = R^{-1}*d \f$, by + /// back-substitution VectorValues optimize() const; - /// Version of optimize for incomplete BayesNet, needs solution for missing variables - VectorValues optimize(const VectorValues& solutionForMissing) const; + /// Version of optimize for incomplete BayesNet, given missing variables + VectorValues optimize(const VectorValues given) const; + + /** + * Sample using ancestral sampling + * Example: + * std::mt19937_64 rng(42); + * auto sample = gbn.sample(&rng); + */ + VectorValues sample(std::mt19937_64* rng) const; + + /** + * Sample from an incomplete BayesNet, given missing variables + * Example: + * std::mt19937_64 rng(42); + * VectorValues given = ...; + * auto sample = gbn.sample(given, &rng); + */ + VectorValues sample(VectorValues given, std::mt19937_64* rng) const; + + /// Sample using ancestral sampling, use default rng + VectorValues sample() const; + + /// Sample from an incomplete BayesNet, use default rng + VectorValues sample(VectorValues given) const; /** * Return ordering corresponding to a topological sort. @@ -180,23 +221,6 @@ namespace gtsam { */ VectorValues backSubstituteTranspose(const VectorValues& gx) const; - /// print graph - void print( - const std::string& s = "", - const KeyFormatter& formatter = DefaultKeyFormatter) const override { - Base::print(s, formatter); - } - - /** - * @brief Save the GaussianBayesNet as an image. Requires `dot` to be - * installed. - * - * @param s The name of the figure. - * @param keyFormatter Formatter to use for styling keys in the graph. - */ - void saveGraph(const std::string& s, const KeyFormatter& keyFormatter = - DefaultKeyFormatter) const; - /// @} private: diff --git a/gtsam/linear/GaussianConditional.cpp b/gtsam/linear/GaussianConditional.cpp index 9297d6461..6199f91a7 100644 --- a/gtsam/linear/GaussianConditional.cpp +++ b/gtsam/linear/GaussianConditional.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #ifdef __GNUC__ @@ -34,6 +35,9 @@ #include #include +// In Wrappers we have no access to this so have a default ready +static std::mt19937_64 kRandomNumberGenerator(42); + using namespace std; namespace gtsam { @@ -43,25 +47,62 @@ namespace gtsam { Key key, const Vector& d, const Matrix& R, const SharedDiagonal& sigmas) : BaseFactor(key, R, d, sigmas), BaseConditional(1) {} - /* ************************************************************************* */ - GaussianConditional::GaussianConditional( - Key key, const Vector& d, const Matrix& R, - Key name1, const Matrix& S, const SharedDiagonal& sigmas) : - BaseFactor(key, R, name1, S, d, sigmas), BaseConditional(1) {} + /* ************************************************************************ */ + GaussianConditional::GaussianConditional(Key key, const Vector& d, + const Matrix& R, Key parent1, + const Matrix& S, + const SharedDiagonal& sigmas) + : BaseFactor(key, R, parent1, S, d, sigmas), BaseConditional(1) {} - /* ************************************************************************* */ - GaussianConditional::GaussianConditional( - Key key, const Vector& d, const Matrix& R, - Key name1, const Matrix& S, Key name2, const Matrix& T, const SharedDiagonal& sigmas) : - BaseFactor(key, R, name1, S, name2, T, d, sigmas), BaseConditional(1) {} + /* ************************************************************************ */ + GaussianConditional::GaussianConditional(Key key, const Vector& d, + const Matrix& R, Key parent1, + const Matrix& S, Key parent2, + const Matrix& T, + const SharedDiagonal& sigmas) + : BaseFactor(key, R, parent1, S, parent2, T, d, sigmas), + BaseConditional(1) {} - /* ************************************************************************* */ + /* ************************************************************************ */ + GaussianConditional GaussianConditional::FromMeanAndStddev( + Key key, const Matrix& A, Key parent, const Vector& b, double sigma) { + // |Rx + Sy - d| = |x-(Ay + b)|/sigma + const Matrix R = Matrix::Identity(b.size(), b.size()); + const Matrix S = -A; + const Vector d = b; + return GaussianConditional(key, d, R, parent, S, + noiseModel::Isotropic::Sigma(b.size(), sigma)); + } + + /* ************************************************************************ */ + GaussianConditional GaussianConditional::FromMeanAndStddev( + Key key, const Matrix& A1, Key parent1, const Matrix& A2, Key parent2, + const Vector& b, double sigma) { + // |Rx + Sy + Tz - d| = |x-(A1 y + A2 z + b)|/sigma + const Matrix R = Matrix::Identity(b.size(), b.size()); + const Matrix S = -A1; + const Matrix T = -A2; + const Vector d = b; + return GaussianConditional(key, d, R, parent1, S, parent2, T, + noiseModel::Isotropic::Sigma(b.size(), sigma)); + } + + /* ************************************************************************ */ void GaussianConditional::print(const string &s, const KeyFormatter& formatter) const { - cout << s << " Conditional density "; + cout << s << " p("; for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) { - cout << (boost::format("[%1%]")%(formatter(*it))).str() << " "; + cout << (boost::format("%1%") % (formatter(*it))).str() + << (nrFrontals() > 1 ? " " : ""); } - cout << endl; + + if (nrParents()) { + cout << " |"; + for (const_iterator it = beginParents(); it != endParents(); ++it) { + cout << " " << (boost::format("%1%") % (formatter(*it))).str(); + } + } + cout << ")" << endl; + cout << formatMatrixIndented(" R = ", R()) << endl; for (const_iterator it = beginParents() ; it != endParents() ; ++it) { cout << formatMatrixIndented((boost::format(" S[%1%] = ")%(formatter(*it))).str(), getA(it)) @@ -192,13 +233,97 @@ namespace gtsam { } } - /* ************************************************************************* */ - void GaussianConditional::scaleFrontalsBySigma(VectorValues& gy) const { + /* ************************************************************************ */ + JacobianFactor::shared_ptr GaussianConditional::likelihood( + const VectorValues& frontalValues) const { + // Error is |Rx - (d - Sy - Tz - ...)|^2 + // so when we instantiate x (which has to be completely known) we beget: + // |Sy + Tz + ... - (d - Rx)|^2 + // The noise model just transfers over! + + // Get frontalValues as vector + const Vector x = + frontalValues.vector(KeyVector(beginFrontals(), endFrontals())); + + // Copy the augmented Jacobian matrix: + auto newAb = Ab_; + + // Restrict view to parent blocks + newAb.firstBlock() += nrFrontals_; + + // Update right-hand-side (last column) + auto last = newAb.matrix().cols() - 1; + const auto RR = R().triangularView(); + newAb.matrix().col(last) -= RR * x; + + // The keys now do not include the frontal keys: + KeyVector newKeys; + newKeys.reserve(nrParents()); + for (auto&& key : parents()) newKeys.push_back(key); + + // Hopefully second newAb copy below is optimized out... + return boost::make_shared(newKeys, newAb, model_); + } + + /* **************************************************************************/ + JacobianFactor::shared_ptr GaussianConditional::likelihood( + const Vector& frontal) const { + if (nrFrontals() != 1) + throw std::invalid_argument( + "GaussianConditional Single value likelihood can only be invoked on " + "single-variable conditional"); + VectorValues values; + values.insert(keys_[0], frontal); + return likelihood(values); + } + + /* ************************************************************************ */ + VectorValues GaussianConditional::sample(const VectorValues& parentsValues, + std::mt19937_64* rng) const { + if (nrFrontals() != 1) { + throw std::invalid_argument( + "GaussianConditional::sample can only be called on single variable " + "conditionals"); + } + if (!model_) { + throw std::invalid_argument( + "GaussianConditional::sample can only be called if a diagonal noise " + "model was specified at construction."); + } + VectorValues solution = solve(parentsValues); + Key key = firstFrontalKey(); + const Vector& sigmas = model_->sigmas(); + solution[key] += Sampler::sampleDiagonal(sigmas, rng); + return solution; + } + + VectorValues GaussianConditional::sample(std::mt19937_64* rng) const { + if (nrParents() != 0) + throw std::invalid_argument( + "sample() can only be invoked on no-parent prior"); + VectorValues values; + return sample(values); + } + + /* ************************************************************************ */ + VectorValues GaussianConditional::sample() const { + return sample(&kRandomNumberGenerator); + } + + VectorValues GaussianConditional::sample(const VectorValues& given) const { + return sample(given, &kRandomNumberGenerator); + } + + /* ************************************************************************ */ +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + void GTSAM_DEPRECATED + GaussianConditional::scaleFrontalsBySigma(VectorValues& gy) const { DenseIndex vectorPosition = 0; for (const_iterator frontal = beginFrontals(); frontal != endFrontals(); ++frontal) { gy[*frontal].array() *= model_->sigmas().segment(vectorPosition, getDim(frontal)).array(); vectorPosition += getDim(frontal); } } +#endif } // namespace gtsam diff --git a/gtsam/linear/GaussianConditional.h b/gtsam/linear/GaussianConditional.h index 0ea597f99..b2b616dab 100644 --- a/gtsam/linear/GaussianConditional.h +++ b/gtsam/linear/GaussianConditional.h @@ -26,12 +26,15 @@ #include #include +#include // for std::mt19937_64 + namespace gtsam { /** - * A conditional Gaussian functions as the node in a Bayes network + * A GaussianConditional functions as the node in a Bayes network. * It has a set of parents y,z, etc. and implements a probability density on x. * The negative log-probability is given by \f$ \frac{1}{2} |Rx - (d - Sy - Tz - ...)|^2 \f$ + * @addtogroup linear */ class GTSAM_EXPORT GaussianConditional : public JacobianFactor, @@ -43,6 +46,9 @@ namespace gtsam { typedef JacobianFactor BaseFactor; ///< Typedef to our factor base class typedef Conditional BaseConditional; ///< Typedef to our conditional base class + /// @name Constructors + /// @{ + /** default constructor needed for serialization */ GaussianConditional() {} @@ -51,13 +57,14 @@ namespace gtsam { const SharedDiagonal& sigmas = SharedDiagonal()); /** constructor with only one parent |Rx+Sy-d| */ - GaussianConditional(Key key, const Vector& d, const Matrix& R, - Key name1, const Matrix& S, const SharedDiagonal& sigmas = SharedDiagonal()); + GaussianConditional(Key key, const Vector& d, const Matrix& R, Key parent1, + const Matrix& S, + const SharedDiagonal& sigmas = SharedDiagonal()); /** constructor with two parents |Rx+Sy+Tz-d| */ - GaussianConditional(Key key, const Vector& d, const Matrix& R, - Key name1, const Matrix& S, Key name2, const Matrix& T, - const SharedDiagonal& sigmas = SharedDiagonal()); + GaussianConditional(Key key, const Vector& d, const Matrix& R, Key parent1, + const Matrix& S, Key parent2, const Matrix& T, + const SharedDiagonal& sigmas = SharedDiagonal()); /** Constructor with arbitrary number of frontals and parents. * @tparam TERMS A container whose value type is std::pair, specifying the @@ -76,6 +83,17 @@ namespace gtsam { const KEYS& keys, size_t nrFrontals, const VerticalBlockMatrix& augmentedMatrix, const SharedDiagonal& sigmas = SharedDiagonal()); + /// Construct from mean A1 p1 + b and standard deviation. + static GaussianConditional FromMeanAndStddev(Key key, const Matrix& A, + Key parent, const Vector& b, + double sigma); + + /// Construct from mean A1 p1 + A2 p2 + b and standard deviation. + static GaussianConditional FromMeanAndStddev(Key key, // + const Matrix& A1, Key parent1, + const Matrix& A2, Key parent2, + const Vector& b, double sigma); + /** Combine several GaussianConditional into a single dense GC. The conditionals enumerated by * \c first and \c last must be in increasing order, meaning that the parents of any * conditional may not include a conditional coming before it. @@ -86,13 +104,22 @@ namespace gtsam { template static shared_ptr Combine(ITERATOR firstConditional, ITERATOR lastConditional); + /// @} + /// @name Testable + /// @{ + /** print */ - void print(const std::string& = "GaussianConditional", - const KeyFormatter& formatter = DefaultKeyFormatter) const override; + void print( + const std::string& = "GaussianConditional", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; /** equals function */ bool equals(const GaussianFactor&cg, double tol = 1e-9) const override; + /// @} + /// @name Standard Interface + /// @{ + /** Return a view of the upper-triangular R block of the conditional */ constABlock R() const { return Ab_.range(0, nrFrontals()); } @@ -125,12 +152,47 @@ namespace gtsam { /** Performs transpose backsubstition in place on values */ void solveTransposeInPlace(VectorValues& gy) const; + /** Convert to a likelihood factor by providing value before bar. */ + JacobianFactor::shared_ptr likelihood( + const VectorValues& frontalValues) const; + + /** Single variable version of likelihood. */ + JacobianFactor::shared_ptr likelihood(const Vector& frontal) const; + + /** + * Sample from conditional, zero parent version + * Example: + * std::mt19937_64 rng(42); + * auto sample = gbn.sample(&rng); + */ + VectorValues sample(std::mt19937_64* rng) const; + + /** + * Sample from conditional, given missing variables + * Example: + * std::mt19937_64 rng(42); + * VectorValues given = ...; + * auto sample = gbn.sample(given, &rng); + */ + VectorValues sample(const VectorValues& parentsValues, + std::mt19937_64* rng) const; + + /// Sample, use default rng + VectorValues sample() const; + + /// Sample with given values, use default rng + VectorValues sample(const VectorValues& parentsValues) const; + + /// @} + +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + /// @name Deprecated + /// @{ /** Scale the values in \c gy according to the sigmas for the frontal variables in this * conditional. */ - void scaleFrontalsBySigma(VectorValues& gy) const; - - // FIXME: deprecated flag doesn't appear to exist? - // __declspec(deprecated) void scaleFrontalsBySigma(VectorValues& gy) const; + void GTSAM_DEPRECATED scaleFrontalsBySigma(VectorValues& gy) const; + /// @} +#endif private: /** Serialization function */ diff --git a/gtsam/linear/GaussianDensity.cpp b/gtsam/linear/GaussianDensity.cpp index d9cde9b91..343396c0a 100644 --- a/gtsam/linear/GaussianDensity.cpp +++ b/gtsam/linear/GaussianDensity.cpp @@ -17,39 +17,41 @@ */ #include +#include +#include -using namespace std; +using std::cout; +using std::endl; +using std::string; namespace gtsam { - /* ************************************************************************* */ - GaussianDensity GaussianDensity::FromMeanAndStddev(Key key, const Vector& mean, const double& sigma) - { - return GaussianDensity(key, mean / sigma, Matrix::Identity(mean.size(), mean.size()) / sigma); - } +/* ************************************************************************* */ +GaussianDensity GaussianDensity::FromMeanAndStddev(Key key, const Vector& mean, + double sigma) { + return GaussianDensity(key, mean, Matrix::Identity(mean.size(), mean.size()), + noiseModel::Isotropic::Sigma(mean.size(), sigma)); +} - /* ************************************************************************* */ - void GaussianDensity::print(const string &s, const KeyFormatter& formatter) const - { - cout << s << ": density on "; - for(const_iterator it = beginFrontals(); it != endFrontals(); ++it) - cout << (boost::format("[%1%]")%(formatter(*it))).str() << " "; - cout << endl; - gtsam::print(Matrix(R()), "R: "); - gtsam::print(Vector(d()), "d: "); - if(model_) - model_->print("Noise model: "); - } +/* ************************************************************************* */ +void GaussianDensity::print(const string& s, + const KeyFormatter& formatter) const { + cout << s << ": density on "; + for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) + cout << (boost::format("[%1%]") % (formatter(*it))).str() << " "; + cout << endl; + gtsam::print(mean(), "mean: "); + gtsam::print(covariance(), "covariance: "); + if (model_) model_->print("Noise model: "); +} - /* ************************************************************************* */ - Vector GaussianDensity::mean() const { - VectorValues soln = this->solve(VectorValues()); - return soln[firstFrontalKey()]; - } +/* ************************************************************************* */ +Vector GaussianDensity::mean() const { + VectorValues soln = this->solve(VectorValues()); + return soln[firstFrontalKey()]; +} - /* ************************************************************************* */ - Matrix GaussianDensity::covariance() const { - return information().inverse(); - } +/* ************************************************************************* */ +Matrix GaussianDensity::covariance() const { return information().inverse(); } -} // gtsam +} // namespace gtsam diff --git a/gtsam/linear/GaussianDensity.h b/gtsam/linear/GaussianDensity.h index 71af704ab..f078d5db6 100644 --- a/gtsam/linear/GaussianDensity.h +++ b/gtsam/linear/GaussianDensity.h @@ -24,11 +24,10 @@ namespace gtsam { /** - * A Gaussian density. - * - * It is implemented as a GaussianConditional without parents. + * A GaussianDensity is a GaussianConditional without parents. * The negative log-probability is given by \f$ |Rx - d|^2 \f$ * with \f$ \Lambda = \Sigma^{-1} = R^T R \f$ and \f$ \mu = R^{-1} d \f$ + * @addtogroup linear */ class GTSAM_EXPORT GaussianDensity : public GaussianConditional { @@ -52,8 +51,9 @@ namespace gtsam { GaussianDensity(Key key, const Vector& d, const Matrix& R, const SharedDiagonal& noiseModel = SharedDiagonal()) : GaussianConditional(key, d, R, noiseModel) {} - /// Construct using a mean and variance - static GaussianDensity FromMeanAndStddev(Key key, const Vector& mean, const double& sigma); + /// Construct using a mean and standard deviation + static GaussianDensity FromMeanAndStddev(Key key, const Vector& mean, + double sigma); /// print void print(const std::string& = "GaussianDensity", diff --git a/gtsam/linear/GaussianFactor.h b/gtsam/linear/GaussianFactor.h index 334722868..672f5aa0d 100644 --- a/gtsam/linear/GaussianFactor.h +++ b/gtsam/linear/GaussianFactor.h @@ -117,9 +117,6 @@ namespace gtsam { /** Clone a factor (make a deep copy) */ virtual GaussianFactor::shared_ptr clone() const = 0; - /** Test whether the factor is empty */ - virtual bool empty() const = 0; - /** * Construct the corresponding anti-factor to negate information * stored stored in this factor. diff --git a/gtsam/linear/GaussianFactorGraph.cpp b/gtsam/linear/GaussianFactorGraph.cpp index 13eaee7e3..72eb107d0 100644 --- a/gtsam/linear/GaussianFactorGraph.cpp +++ b/gtsam/linear/GaussianFactorGraph.cpp @@ -19,7 +19,6 @@ */ #include -#include #include #include #include @@ -290,10 +289,11 @@ namespace gtsam { return blocks; } - /* ************************************************************************* */ + /* ************************************************************************ */ VectorValues GaussianFactorGraph::optimize(const Eliminate& function) const { gttic(GaussianFactorGraph_optimize); - return BaseEliminateable::eliminateMultifrontal(function)->optimize(); + return BaseEliminateable::eliminateMultifrontal(Ordering::COLAMD, function) + ->optimize(); } /* ************************************************************************* */ @@ -379,7 +379,7 @@ namespace gtsam { gttic(Compute_minimizing_step_size); // Compute minimizing step size - double step = -gradientSqNorm / dot(Rg, Rg); + double step = -gradientSqNorm / gtsam::dot(Rg, Rg); gttoc(Compute_minimizing_step_size); gttic(Compute_point); @@ -504,10 +504,32 @@ namespace gtsam { } /* ************************************************************************* */ - /** \deprecated */ - VectorValues GaussianFactorGraph::optimize(boost::none_t, - const Eliminate& function) const { - return optimize(function); + void GaussianFactorGraph::printErrors( + const VectorValues& values, const std::string& str, + const KeyFormatter& keyFormatter, + const std::function& + printCondition) const { + cout << str << "size: " << size() << endl << endl; + for (size_t i = 0; i < (*this).size(); i++) { + const sharedFactor& factor = (*this)[i]; + const double errorValue = + (factor != nullptr ? (*this)[i]->error(values) : .0); + if (!printCondition(factor.get(), errorValue, i)) + continue; // User-provided filter did not pass + + stringstream ss; + ss << "Factor " << i << ": "; + if (factor == nullptr) { + cout << "nullptr" + << "\n"; + } else { + factor->print(ss.str(), keyFormatter); + cout << "error = " << errorValue << "\n"; + } + cout << endl; // only one "endl" at end might be faster, \n for each + // factor + } } } // namespace gtsam diff --git a/gtsam/linear/GaussianFactorGraph.h b/gtsam/linear/GaussianFactorGraph.h index e3304d5e8..0d5057aa8 100644 --- a/gtsam/linear/GaussianFactorGraph.h +++ b/gtsam/linear/GaussianFactorGraph.h @@ -21,12 +21,13 @@ #pragma once -#include #include +#include +#include // Included here instead of fw-declared so we can use Errors::iterator #include -#include #include -#include // Included here instead of fw-declared so we can use Errors::iterator +#include +#include namespace gtsam { @@ -98,6 +99,12 @@ namespace gtsam { /// @} + /// Check exact equality. + friend bool operator==(const GaussianFactorGraph& lhs, + const GaussianFactorGraph& rhs) { + return lhs.isEqual(rhs); + } + /** Add a factor by value - makes a copy */ void add(const GaussianFactor& factor) { push_back(factor.clone()); } @@ -153,7 +160,8 @@ namespace gtsam { /** Unnormalized probability. O(n) */ double probPrime(const VectorValues& c) const { - return exp(-0.5 * error(c)); + // NOTE the 0.5 constant is handled by the factor error. + return exp(-error(c)); } /** @@ -375,6 +383,14 @@ namespace gtsam { /** In-place version e <- A*x that takes an iterator. */ void multiplyInPlace(const VectorValues& x, const Errors::iterator& e) const; + void printErrors( + const VectorValues& x, + const std::string& str = "GaussianFactorGraph: ", + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const std::function& + printCondition = + [](const Factor*, double, size_t) { return true; }) const; /// @} private: @@ -387,9 +403,14 @@ namespace gtsam { public: - /** \deprecated */ - VectorValues optimize(boost::none_t, - const Eliminate& function = EliminationTraitsType::DefaultEliminate) const; +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + /** @deprecated */ + VectorValues GTSAM_DEPRECATED + optimize(boost::none_t, const Eliminate& function = + EliminationTraitsType::DefaultEliminate) const { + return optimize(function); + } +#endif }; @@ -399,7 +420,7 @@ namespace gtsam { */ GTSAM_EXPORT bool hasConstraints(const GaussianFactorGraph& factors); - /****** Linear Algebra Opeations ******/ + /****** Linear Algebra Operations ******/ ///* matrix-vector operations */ //GTSAM_EXPORT void residual(const GaussianFactorGraph& fg, const VectorValues &x, VectorValues &r); diff --git a/gtsam/linear/GaussianJunctionTree.h b/gtsam/linear/GaussianJunctionTree.h index 67e5a374b..c7f13ea5c 100644 --- a/gtsam/linear/GaussianJunctionTree.h +++ b/gtsam/linear/GaussianJunctionTree.h @@ -16,6 +16,8 @@ * @author Richard Roberts */ +#pragma once + #include #include #include diff --git a/gtsam/linear/HessianFactor.h b/gtsam/linear/HessianFactor.h index 0f4c993fe..7020d6edd 100644 --- a/gtsam/linear/HessianFactor.h +++ b/gtsam/linear/HessianFactor.h @@ -221,9 +221,6 @@ namespace gtsam { */ GaussianFactor::shared_ptr negate() const override; - /** Check if the factor is empty. TODO: How should this be defined? */ - bool empty() const override { return size() == 0 /*|| rows() == 0*/; } - /** Return the constant term \f$ f \f$ as described above * @return The constant term \f$ f \f$ */ diff --git a/gtsam/linear/JacobianFactor.h b/gtsam/linear/JacobianFactor.h index 4d4480d32..ddf614910 100644 --- a/gtsam/linear/JacobianFactor.h +++ b/gtsam/linear/JacobianFactor.h @@ -260,9 +260,6 @@ namespace gtsam { */ GaussianFactor::shared_ptr negate() const override; - /** Check if the factor is empty. TODO: How should this be defined? */ - bool empty() const override { return size() == 0 /*|| rows() == 0*/; } - /** is noise model constrained ? */ bool isConstrained() const { return model_ && model_->isConstrained(); diff --git a/gtsam/linear/LossFunctions.cpp b/gtsam/linear/LossFunctions.cpp index bf799a2ba..7307c4a68 100644 --- a/gtsam/linear/LossFunctions.cpp +++ b/gtsam/linear/LossFunctions.cpp @@ -19,6 +19,7 @@ #include #include +#include using namespace std; diff --git a/gtsam/linear/LossFunctions.h b/gtsam/linear/LossFunctions.h index c3d7d64db..d9cfc1f3c 100644 --- a/gtsam/linear/LossFunctions.h +++ b/gtsam/linear/LossFunctions.h @@ -54,23 +54,31 @@ namespace noiseModel { // clang-format on namespace mEstimator { -//--------------------------------------------------------------------------------------- - +/** + * Pure virtual class for all robust error function classes. + * + * It provides the machinery for block vs scalar reweighting strategies, in + * addition to defining the interface of derived classes. + */ class GTSAM_EXPORT Base { public: + /** the rows can be weighted independently according to the error + * or uniformly with the norm of the right hand side */ enum ReweightScheme { Scalar, Block }; typedef boost::shared_ptr shared_ptr; protected: - /** the rows can be weighted independently according to the error - * or uniformly with the norm of the right hand side */ + /// Strategy for reweighting \sa ReweightScheme ReweightScheme reweight_; public: Base(const ReweightScheme reweight = Block) : reweight_(reweight) {} virtual ~Base() {} - /* + /// Returns the reweight scheme, as explained in ReweightScheme + ReweightScheme reweightScheme() const { return reweight_; } + + /** * This method is responsible for returning the total penalty for a given * amount of error. For example, this method is responsible for implementing * the quadratic function for an L2 penalty, the absolute value function for @@ -80,16 +88,20 @@ class GTSAM_EXPORT Base { * error vector, then it prevents implementations of asymmeric loss * functions. It would be better for this function to accept the vector and * internally call the norm if necessary. + * + * This returns \rho(x) in \ref mEstimator */ - virtual double loss(double distance) const { return 0; }; + virtual double loss(double distance) const { return 0; } - /* + /** * This method is responsible for returning the weight function for a given * amount of error. The weight function is related to the analytic derivative * of the loss function. See * https://members.loria.fr/MOBerger/Enseignement/Master2/Documents/ZhangIVC-97-01.pdf * for details. This method is required when optimizing cost functions with * robust penalties using iteratively re-weighted least squares. + * + * This returns w(x) in \ref mEstimator */ virtual double weight(double distance) const = 0; @@ -124,7 +136,15 @@ class GTSAM_EXPORT Base { } }; -/// Null class should behave as Gaussian +/** "Null" robust loss function, equivalent to a Gaussian pdf noise model, or + * plain least-squares (non-robust). + * + * This model has no additional parameters. + * + * - Loss \rho(x) = 0.5 x² + * - Derivative \phi(x) = x + * - Weight w(x) = \phi(x)/x = 1 + */ class GTSAM_EXPORT Null : public Base { public: typedef boost::shared_ptr shared_ptr; @@ -146,7 +166,14 @@ class GTSAM_EXPORT Null : public Base { } }; -/// Fair implements the "Fair" robust error model (Zhang97ivc) +/** Implementation of the "Fair" robust error model (Zhang97ivc) + * + * This model has a scalar parameter "c". + * + * - Loss \rho(x) = c² (|x|/c - log(1+|x|/c)) + * - Derivative \phi(x) = x/(1+|x|/c) + * - Weight w(x) = \phi(x)/x = 1/(1+|x|/c) + */ class GTSAM_EXPORT Fair : public Base { protected: double c_; @@ -160,6 +187,7 @@ class GTSAM_EXPORT Fair : public Base { void print(const std::string &s) const override; bool equals(const Base &expected, double tol = 1e-8) const override; static shared_ptr Create(double c, const ReweightScheme reweight = Block); + double modelParameter() const { return c_; } private: /** Serialization function */ @@ -171,7 +199,14 @@ class GTSAM_EXPORT Fair : public Base { } }; -/// Huber implements the "Huber" robust error model (Zhang97ivc) +/** The "Huber" robust error model (Zhang97ivc). + * + * This model has a scalar parameter "k". + * + * - Loss \rho(x) = 0.5 x² if |x| shared_ptr; @@ -293,6 +359,7 @@ class GTSAM_EXPORT GemanMcClure : public Base { void print(const std::string &s) const override; bool equals(const Base &expected, double tol = 1e-8) const override; static shared_ptr Create(double k, const ReweightScheme reweight = Block); + double modelParameter() const { return c_; } protected: double c_; @@ -307,11 +374,18 @@ class GTSAM_EXPORT GemanMcClure : public Base { } }; -/// DCS implements the Dynamic Covariance Scaling robust error model -/// from the paper Robust Map Optimization (Agarwal13icra). -/// -/// Under the special condition of the parameter c == 1.0 and not -/// forcing the output weight s <= 1.0, DCS is similar to Geman-McClure. +/** DCS implements the Dynamic Covariance Scaling robust error model + * from the paper Robust Map Optimization (Agarwal13icra). + * + * Under the special condition of the parameter c == 1.0 and not + * forcing the output weight s <= 1.0, DCS is similar to Geman-McClure. + * + * This model has a scalar parameter "c" (with "units" of squared error). + * + * - Loss \rho(x) = (c²x² + cx⁴)/(x²+c)² (for any "x") + * - Derivative \phi(x) = 2c²x/(x²+c)² + * - Weight w(x) = \phi(x)/x = 2c²/(x²+c)² if x²>c, 1 otherwise + */ class GTSAM_EXPORT DCS : public Base { public: typedef boost::shared_ptr shared_ptr; @@ -323,6 +397,7 @@ class GTSAM_EXPORT DCS : public Base { void print(const std::string &s) const override; bool equals(const Base &expected, double tol = 1e-8) const override; static shared_ptr Create(double k, const ReweightScheme reweight = Block); + double modelParameter() const { return c_; } protected: double c_; @@ -337,12 +412,19 @@ class GTSAM_EXPORT DCS : public Base { } }; -/// L2WithDeadZone implements a standard L2 penalty, but with a dead zone of -/// width 2*k, centered at the origin. The resulting penalty within the dead -/// zone is always zero, and grows quadratically outside the dead zone. In this -/// sense, the L2WithDeadZone penalty is "robust to inliers", rather than being -/// robust to outliers. This penalty can be used to create barrier functions in -/// a general way. +/** L2WithDeadZone implements a standard L2 penalty, but with a dead zone of + * width 2*k, centered at the origin. The resulting penalty within the dead + * zone is always zero, and grows quadratically outside the dead zone. In this + * sense, the L2WithDeadZone penalty is "robust to inliers", rather than being + * robust to outliers. This penalty can be used to create barrier functions in + * a general way. + * + * This model has a scalar parameter "k". + * + * - Loss \rho(x) = 0 if |x|k, (k+x) if x<-k + * - Weight w(x) = \phi(x)/x = 0 if |x|k, (k+x)/x if x<-k + */ class GTSAM_EXPORT L2WithDeadZone : public Base { protected: double k_; @@ -356,6 +438,7 @@ class GTSAM_EXPORT L2WithDeadZone : public Base { void print(const std::string &s) const override; bool equals(const Base &expected, double tol = 1e-8) const override; static shared_ptr Create(double k, const ReweightScheme reweight = Block); + double modelParameter() const { return k_; } private: /** Serialization function */ diff --git a/gtsam/linear/NoiseModel.cpp b/gtsam/linear/NoiseModel.cpp index cf10cf115..8bcef6fcc 100644 --- a/gtsam/linear/NoiseModel.cpp +++ b/gtsam/linear/NoiseModel.cpp @@ -134,7 +134,7 @@ Gaussian::shared_ptr Gaussian::Covariance(const Matrix& covariance, /* ************************************************************************* */ void Gaussian::print(const string& name) const { - gtsam::print(thisR(), name + "Gaussian"); + gtsam::print(thisR(), name + "Gaussian "); } /* ************************************************************************* */ @@ -285,7 +285,7 @@ Diagonal::shared_ptr Diagonal::Sigmas(const Vector& sigmas, bool smart) { /* ************************************************************************* */ void Diagonal::print(const string& name) const { - gtsam::print(sigmas_, name + "diagonal sigmas"); + gtsam::print(sigmas_, name + "diagonal sigmas "); } /* ************************************************************************* */ @@ -355,8 +355,8 @@ bool Constrained::constrained(size_t i) const { /* ************************************************************************* */ void Constrained::print(const std::string& name) const { - gtsam::print(sigmas_, name + "constrained sigmas"); - gtsam::print(mu_, name + "constrained mu"); + gtsam::print(sigmas_, name + "constrained sigmas "); + gtsam::print(mu_, name + "constrained mu "); } /* ************************************************************************* */ diff --git a/gtsam/linear/NoiseModel.h b/gtsam/linear/NoiseModel.h index 2fb54d329..5c379beb8 100644 --- a/gtsam/linear/NoiseModel.h +++ b/gtsam/linear/NoiseModel.h @@ -177,17 +177,16 @@ namespace gtsam { return *sqrt_information_; } - protected: - - /** protected constructor takes square root information matrix */ - Gaussian(size_t dim = 1, const boost::optional& sqrt_information = boost::none) : - Base(dim), sqrt_information_(sqrt_information) { - } public: typedef boost::shared_ptr shared_ptr; + /** constructor takes square root information matrix */ + Gaussian(size_t dim = 1, + const boost::optional& sqrt_information = boost::none) + : Base(dim), sqrt_information_(sqrt_information) {} + ~Gaussian() override {} /** @@ -290,13 +289,13 @@ namespace gtsam { Vector sigmas_, invsigmas_, precisions_; protected: - /** protected constructor - no initializations */ - Diagonal(); /** constructor to allow for disabling initialization of invsigmas */ Diagonal(const Vector& sigmas); public: + /** constructor - no initializations, for serialization */ + Diagonal(); typedef boost::shared_ptr shared_ptr; @@ -387,14 +386,6 @@ namespace gtsam { // Sigmas are contained in the base class Vector mu_; ///< Penalty function weight - needs to be large enough to dominate soft constraints - /** - * protected constructor takes sigmas. - * prevents any inf values - * from appearing in invsigmas or precisions. - * mu set to large default value (1000.0) - */ - Constrained(const Vector& sigmas = Z_1x1); - /** * Constructor that prevents any inf values * from appearing in invsigmas or precisions. @@ -406,6 +397,14 @@ namespace gtsam { typedef boost::shared_ptr shared_ptr; + /** + * protected constructor takes sigmas. + * prevents any inf values + * from appearing in invsigmas or precisions. + * mu set to large default value (1000.0) + */ + Constrained(const Vector& sigmas = Z_1x1); + ~Constrained() override {} /// true if a constrained noise mode, saves slow/clumsy dynamic casting @@ -461,6 +460,11 @@ namespace gtsam { return MixedVariances(precisions.array().inverse()); } + /** + * The squaredMahalanobisDistance function for a constrained noisemodel, + * for non-constrained versions, uses sigmas, otherwise + * uses the penalty function with mu + */ double squaredMahalanobisDistance(const Vector& v) const override; /** Fully constrained variations */ @@ -531,11 +535,11 @@ namespace gtsam { Isotropic(size_t dim, double sigma) : Diagonal(Vector::Constant(dim, sigma)),sigma_(sigma),invsigma_(1.0/sigma) {} + public: + /* dummy constructor to allow for serialization */ Isotropic() : Diagonal(Vector1::Constant(1.0)),sigma_(1.0),invsigma_(1.0) {} - public: - ~Isotropic() override {} typedef boost::shared_ptr shared_ptr; @@ -592,14 +596,13 @@ namespace gtsam { * Unit: i.i.d. unit-variance noise on all m dimensions. */ class GTSAM_EXPORT Unit : public Isotropic { - protected: - - Unit(size_t dim=1): Isotropic(dim,1.0) {} - public: typedef boost::shared_ptr shared_ptr; + /** constructor for serialization */ + Unit(size_t dim=1): Isotropic(dim,1.0) {} + ~Unit() override {} /** @@ -682,19 +685,19 @@ namespace gtsam { /// Return the contained noise model const NoiseModel::shared_ptr& noise() const { return noise_; } - // TODO: functions below are dummy but necessary for the noiseModel::Base + // Functions below are dummy but necessary for the noiseModel::Base inline Vector whiten(const Vector& v) const override { Vector r = v; this->WhitenSystem(r); return r; } inline Matrix Whiten(const Matrix& A) const override { Vector b; Matrix B=A; this->WhitenSystem(B,b); return B; } inline Vector unwhiten(const Vector& /*v*/) const override { throw std::invalid_argument("unwhiten is not currently supported for robust noise models."); } - + /// Compute loss from the m-estimator using the Mahalanobis distance. double loss(const double squared_distance) const override { return robust_->loss(std::sqrt(squared_distance)); } - // TODO: these are really robust iterated re-weighting support functions + // These are really robust iterated re-weighting support functions virtual void WhitenSystem(Vector& b) const; void WhitenSystem(std::vector& A, Vector& b) const override; void WhitenSystem(Matrix& A, Vector& b) const override; @@ -705,7 +708,6 @@ namespace gtsam { return noise_->unweightedWhiten(v); } double weight(const Vector& v) const override { - // Todo(mikebosse): make the robust weight function input a vector. return robust_->weight(v.norm()); } @@ -728,8 +730,8 @@ namespace gtsam { } // namespace noiseModel - /** Note, deliberately not in noiseModel namespace. - * Deprecated. Only for compatibility with previous version. + /** + * Aliases. Deliberately not in noiseModel namespace. */ typedef noiseModel::Base::shared_ptr SharedNoiseModel; typedef noiseModel::Gaussian::shared_ptr SharedGaussian; diff --git a/gtsam/linear/Sampler.cpp b/gtsam/linear/Sampler.cpp index 4957dfa14..20d4c955b 100644 --- a/gtsam/linear/Sampler.cpp +++ b/gtsam/linear/Sampler.cpp @@ -22,14 +22,18 @@ namespace gtsam { /* ************************************************************************* */ Sampler::Sampler(const noiseModel::Diagonal::shared_ptr& model, uint_fast64_t seed) - : model_(model), generator_(seed) {} + : model_(model), generator_(seed) { + if (!model) { + throw std::invalid_argument("Sampler::Sampler needs a non-null model."); + } +} /* ************************************************************************* */ Sampler::Sampler(const Vector& sigmas, uint_fast64_t seed) : model_(noiseModel::Diagonal::Sigmas(sigmas, true)), generator_(seed) {} /* ************************************************************************* */ -Vector Sampler::sampleDiagonal(const Vector& sigmas) const { +Vector Sampler::sampleDiagonal(const Vector& sigmas, std::mt19937_64* rng) { size_t d = sigmas.size(); Vector result(d); for (size_t i = 0; i < d; i++) { @@ -39,14 +43,18 @@ Vector Sampler::sampleDiagonal(const Vector& sigmas) const { if (sigma == 0.0) { result(i) = 0.0; } else { - typedef std::normal_distribution Normal; - Normal dist(0.0, sigma); - result(i) = dist(generator_); + std::normal_distribution dist(0.0, sigma); + result(i) = dist(*rng); } } return result; } +/* ************************************************************************* */ +Vector Sampler::sampleDiagonal(const Vector& sigmas) const { + return sampleDiagonal(sigmas, &generator_); +} + /* ************************************************************************* */ Vector Sampler::sample() const { assert(model_.get()); diff --git a/gtsam/linear/Sampler.h b/gtsam/linear/Sampler.h index bb5098f34..5be8b445d 100644 --- a/gtsam/linear/Sampler.h +++ b/gtsam/linear/Sampler.h @@ -63,15 +63,9 @@ class GTSAM_EXPORT Sampler { /// @name access functions /// @{ - size_t dim() const { - assert(model_.get()); - return model_->dim(); - } + size_t dim() const { return model_->dim(); } - Vector sigmas() const { - assert(model_.get()); - return model_->sigmas(); - } + Vector sigmas() const { return model_->sigmas(); } const noiseModel::Diagonal::shared_ptr& model() const { return model_; } @@ -82,6 +76,8 @@ class GTSAM_EXPORT Sampler { /// sample from distribution Vector sample() const; + /// sample with given random number generator + static Vector sampleDiagonal(const Vector& sigmas, std::mt19937_64* rng); /// @} protected: diff --git a/gtsam/linear/SubgraphBuilder.cpp b/gtsam/linear/SubgraphBuilder.cpp index 1919d38be..18e19cd20 100644 --- a/gtsam/linear/SubgraphBuilder.cpp +++ b/gtsam/linear/SubgraphBuilder.cpp @@ -446,30 +446,29 @@ SubgraphBuilder::Weights SubgraphBuilder::weights( } /*****************************************************************************/ -GaussianFactorGraph::shared_ptr buildFactorSubgraph( - const GaussianFactorGraph &gfg, const Subgraph &subgraph, - const bool clone) { - auto subgraphFactors = boost::make_shared(); - subgraphFactors->reserve(subgraph.size()); +GaussianFactorGraph buildFactorSubgraph(const GaussianFactorGraph &gfg, + const Subgraph &subgraph, + const bool clone) { + GaussianFactorGraph subgraphFactors; + subgraphFactors.reserve(subgraph.size()); for (const auto &e : subgraph) { const auto factor = gfg[e.index]; - subgraphFactors->push_back(clone ? factor->clone() : factor); + subgraphFactors.push_back(clone ? factor->clone() : factor); } return subgraphFactors; } /**************************************************************************************************/ -std::pair // -splitFactorGraph(const GaussianFactorGraph &factorGraph, - const Subgraph &subgraph) { +std::pair splitFactorGraph( + const GaussianFactorGraph &factorGraph, const Subgraph &subgraph) { // Get the subgraph by calling cheaper method auto subgraphFactors = buildFactorSubgraph(factorGraph, subgraph, false); // Now, copy all factors then set subGraph factors to zero - auto remaining = boost::make_shared(factorGraph); + GaussianFactorGraph remaining = factorGraph; for (const auto &e : subgraph) { - remaining->remove(e.index); + remaining.remove(e.index); } return std::make_pair(subgraphFactors, remaining); diff --git a/gtsam/linear/SubgraphBuilder.h b/gtsam/linear/SubgraphBuilder.h index 84a477a5e..a900c7531 100644 --- a/gtsam/linear/SubgraphBuilder.h +++ b/gtsam/linear/SubgraphBuilder.h @@ -172,12 +172,13 @@ class GTSAM_EXPORT SubgraphBuilder { }; /** Select the factors in a factor graph according to the subgraph. */ -boost::shared_ptr buildFactorSubgraph( - const GaussianFactorGraph &gfg, const Subgraph &subgraph, const bool clone); +GaussianFactorGraph buildFactorSubgraph(const GaussianFactorGraph &gfg, + const Subgraph &subgraph, + const bool clone); /** Split the graph into a subgraph and the remaining edges. * Note that the remaining factorgraph has null factors. */ -std::pair, boost::shared_ptr > -splitFactorGraph(const GaussianFactorGraph &factorGraph, const Subgraph &subgraph); +std::pair splitFactorGraph( + const GaussianFactorGraph &factorGraph, const Subgraph &subgraph); } // namespace gtsam diff --git a/gtsam/linear/SubgraphPreconditioner.cpp b/gtsam/linear/SubgraphPreconditioner.cpp index fdcb4f7ac..6689cdbed 100644 --- a/gtsam/linear/SubgraphPreconditioner.cpp +++ b/gtsam/linear/SubgraphPreconditioner.cpp @@ -77,16 +77,16 @@ static void setSubvector(const Vector &src, const KeyInfo &keyInfo, /* ************************************************************************* */ // Convert any non-Jacobian factors to Jacobians (e.g. Hessian -> Jacobian with // Cholesky) -static GaussianFactorGraph::shared_ptr convertToJacobianFactors( +static GaussianFactorGraph convertToJacobianFactors( const GaussianFactorGraph &gfg) { - auto result = boost::make_shared(); + GaussianFactorGraph result; for (const auto &factor : gfg) if (factor) { auto jf = boost::dynamic_pointer_cast(factor); if (!jf) { jf = boost::make_shared(*factor); } - result->push_back(jf); + result.push_back(jf); } return result; } @@ -96,42 +96,42 @@ SubgraphPreconditioner::SubgraphPreconditioner(const SubgraphPreconditionerParam parameters_(p) {} /* ************************************************************************* */ -SubgraphPreconditioner::SubgraphPreconditioner(const sharedFG& Ab2, - const sharedBayesNet& Rc1, const sharedValues& xbar, const SubgraphPreconditionerParameters &p) : - Ab2_(convertToJacobianFactors(*Ab2)), Rc1_(Rc1), xbar_(xbar), - b2bar_(new Errors(-Ab2_->gaussianErrors(*xbar))), parameters_(p) { +SubgraphPreconditioner::SubgraphPreconditioner(const GaussianFactorGraph& Ab2, + const GaussianBayesNet& Rc1, const VectorValues& xbar, const SubgraphPreconditionerParameters &p) : + Ab2_(convertToJacobianFactors(Ab2)), Rc1_(Rc1), xbar_(xbar), + b2bar_(-Ab2_.gaussianErrors(xbar)), parameters_(p) { } /* ************************************************************************* */ // x = xbar + inv(R1)*y VectorValues SubgraphPreconditioner::x(const VectorValues& y) const { - return *xbar_ + Rc1_->backSubstitute(y); + return xbar_ + Rc1_.backSubstitute(y); } /* ************************************************************************* */ double SubgraphPreconditioner::error(const VectorValues& y) const { Errors e(y); VectorValues x = this->x(y); - Errors e2 = Ab2()->gaussianErrors(x); + Errors e2 = Ab2_.gaussianErrors(x); return 0.5 * (dot(e, e) + dot(e2,e2)); } /* ************************************************************************* */ // gradient is y + inv(R1')*A2'*(A2*inv(R1)*y-b2bar), VectorValues SubgraphPreconditioner::gradient(const VectorValues &y) const { - VectorValues x = Rc1()->backSubstitute(y); /* inv(R1)*y */ - Errors e = (*Ab2() * x - *b2bar()); /* (A2*inv(R1)*y-b2bar) */ + VectorValues x = Rc1_.backSubstitute(y); /* inv(R1)*y */ + Errors e = Ab2_ * x - b2bar_; /* (A2*inv(R1)*y-b2bar) */ VectorValues v = VectorValues::Zero(x); - Ab2()->transposeMultiplyAdd(1.0, e, v); /* A2'*(A2*inv(R1)*y-b2bar) */ - return y + Rc1()->backSubstituteTranspose(v); + Ab2_.transposeMultiplyAdd(1.0, e, v); /* A2'*(A2*inv(R1)*y-b2bar) */ + return y + Rc1_.backSubstituteTranspose(v); } /* ************************************************************************* */ // Apply operator A, A*y = [I;A2*inv(R1)]*y = [y; A2*inv(R1)*y] -Errors SubgraphPreconditioner::operator*(const VectorValues& y) const { +Errors SubgraphPreconditioner::operator*(const VectorValues &y) const { Errors e(y); - VectorValues x = Rc1()->backSubstitute(y); /* x=inv(R1)*y */ - Errors e2 = *Ab2() * x; /* A2*x */ + VectorValues x = Rc1_.backSubstitute(y); /* x=inv(R1)*y */ + Errors e2 = Ab2_ * x; /* A2*x */ e.splice(e.end(), e2); return e; } @@ -147,8 +147,8 @@ void SubgraphPreconditioner::multiplyInPlace(const VectorValues& y, Errors& e) c } // Add A2 contribution - VectorValues x = Rc1()->backSubstitute(y); // x=inv(R1)*y - Ab2()->multiplyInPlace(x, ei); // use iterator version + VectorValues x = Rc1_.backSubstitute(y); // x=inv(R1)*y + Ab2_.multiplyInPlace(x, ei); // use iterator version } /* ************************************************************************* */ @@ -173,7 +173,7 @@ void SubgraphPreconditioner::transposeMultiplyAdd Errors::const_iterator it = e.begin(); for(auto& key_value: y) { const Vector& ei = *it; - axpy(alpha, ei, key_value.second); + key_value.second += alpha * ei; ++it; } transposeMultiplyAdd2(alpha, it, e.end(), y); @@ -190,14 +190,14 @@ void SubgraphPreconditioner::transposeMultiplyAdd2 (double alpha, while (it != end) e2.push_back(*(it++)); VectorValues x = VectorValues::Zero(y); // x = 0 - Ab2_->transposeMultiplyAdd(1.0,e2,x); // x += A2'*e2 - axpy(alpha, Rc1_->backSubstituteTranspose(x), y); // y += alpha*inv(R1')*x + Ab2_.transposeMultiplyAdd(1.0,e2,x); // x += A2'*e2 + y += alpha * Rc1_.backSubstituteTranspose(x); // y += alpha*inv(R1')*x } /* ************************************************************************* */ void SubgraphPreconditioner::print(const std::string& s) const { cout << s << endl; - Ab2_->print(); + Ab2_.print(); } /*****************************************************************************/ @@ -205,7 +205,7 @@ void SubgraphPreconditioner::solve(const Vector &y, Vector &x) const { assert(x.size() == y.size()); /* back substitute */ - for (const auto &cg : boost::adaptors::reverse(*Rc1_)) { + for (const auto &cg : boost::adaptors::reverse(Rc1_)) { /* collect a subvector of x that consists of the parents of cg (S) */ const KeyVector parentKeys(cg->beginParents(), cg->endParents()); const KeyVector frontalKeys(cg->beginFrontals(), cg->endFrontals()); @@ -228,7 +228,7 @@ void SubgraphPreconditioner::transposeSolve(const Vector &y, Vector &x) const { std::copy(y.data(), y.data() + y.rows(), x.data()); /* in place back substitute */ - for (const auto &cg : *Rc1_) { + for (const auto &cg : Rc1_) { const KeyVector frontalKeys(cg->beginFrontals(), cg->endFrontals()); const Vector rhsFrontal = getSubvector(x, keyInfo_, frontalKeys); const Vector solFrontal = @@ -261,10 +261,10 @@ void SubgraphPreconditioner::build(const GaussianFactorGraph &gfg, const KeyInfo keyInfo_ = keyInfo; /* build factor subgraph */ - GaussianFactorGraph::shared_ptr gfg_subgraph = buildFactorSubgraph(gfg, subgraph, true); + auto gfg_subgraph = buildFactorSubgraph(gfg, subgraph, true); /* factorize and cache BayesNet */ - Rc1_ = gfg_subgraph->eliminateSequential(); + Rc1_ = *gfg_subgraph.eliminateSequential(); } /*****************************************************************************/ diff --git a/gtsam/linear/SubgraphPreconditioner.h b/gtsam/linear/SubgraphPreconditioner.h index 681c12e40..81c8968b1 100644 --- a/gtsam/linear/SubgraphPreconditioner.h +++ b/gtsam/linear/SubgraphPreconditioner.h @@ -19,6 +19,8 @@ #include #include +#include +#include #include #include #include @@ -53,16 +55,12 @@ namespace gtsam { public: typedef boost::shared_ptr shared_ptr; - typedef boost::shared_ptr sharedBayesNet; - typedef boost::shared_ptr sharedFG; - typedef boost::shared_ptr sharedValues; - typedef boost::shared_ptr sharedErrors; private: - sharedFG Ab2_; - sharedBayesNet Rc1_; - sharedValues xbar_; ///< A1 \ b1 - sharedErrors b2bar_; ///< A2*xbar - b2 + GaussianFactorGraph Ab2_; + GaussianBayesNet Rc1_; + VectorValues xbar_; ///< A1 \ b1 + Errors b2bar_; ///< A2*xbar - b2 KeyInfo keyInfo_; SubgraphPreconditionerParameters parameters_; @@ -77,7 +75,7 @@ namespace gtsam { * @param Rc1: the Bayes Net R1*x=c1 * @param xbar: the solution to R1*x=c1 */ - SubgraphPreconditioner(const sharedFG& Ab2, const sharedBayesNet& Rc1, const sharedValues& xbar, + SubgraphPreconditioner(const GaussianFactorGraph& Ab2, const GaussianBayesNet& Rc1, const VectorValues& xbar, const SubgraphPreconditionerParameters &p = SubgraphPreconditionerParameters()); ~SubgraphPreconditioner() override {} @@ -86,13 +84,13 @@ namespace gtsam { void print(const std::string& s = "SubgraphPreconditioner") const; /** Access Ab2 */ - const sharedFG& Ab2() const { return Ab2_; } + const GaussianFactorGraph& Ab2() const { return Ab2_; } /** Access Rc1 */ - const sharedBayesNet& Rc1() const { return Rc1_; } + const GaussianBayesNet& Rc1() const { return Rc1_; } /** Access b2bar */ - const sharedErrors b2bar() const { return b2bar_; } + const Errors b2bar() const { return b2bar_; } /** * Add zero-mean i.i.d. Gaussian prior terms to each variable @@ -104,8 +102,7 @@ namespace gtsam { /* A zero VectorValues with the structure of xbar */ VectorValues zero() const { - assert(xbar_); - return VectorValues::Zero(*xbar_); + return VectorValues::Zero(xbar_); } /** diff --git a/gtsam/linear/SubgraphSolver.cpp b/gtsam/linear/SubgraphSolver.cpp index f49f9a135..0156c717e 100644 --- a/gtsam/linear/SubgraphSolver.cpp +++ b/gtsam/linear/SubgraphSolver.cpp @@ -34,24 +34,24 @@ namespace gtsam { SubgraphSolver::SubgraphSolver(const GaussianFactorGraph &Ab, const Parameters ¶meters, const Ordering& ordering) : parameters_(parameters) { - GaussianFactorGraph::shared_ptr Ab1,Ab2; + GaussianFactorGraph Ab1, Ab2; std::tie(Ab1, Ab2) = splitGraph(Ab); if (parameters_.verbosity()) - cout << "Split A into (A1) " << Ab1->size() << " and (A2) " << Ab2->size() + cout << "Split A into (A1) " << Ab1.size() << " and (A2) " << Ab2.size() << " factors" << endl; - auto Rc1 = Ab1->eliminateSequential(ordering, EliminateQR); - auto xbar = boost::make_shared(Rc1->optimize()); + auto Rc1 = *Ab1.eliminateSequential(ordering, EliminateQR); + auto xbar = Rc1.optimize(); pc_ = boost::make_shared(Ab2, Rc1, xbar); } /**************************************************************************************************/ // Taking eliminated tree [R1|c] and constraint graph [A2|b2] -SubgraphSolver::SubgraphSolver(const GaussianBayesNet::shared_ptr &Rc1, - const GaussianFactorGraph::shared_ptr &Ab2, +SubgraphSolver::SubgraphSolver(const GaussianBayesNet &Rc1, + const GaussianFactorGraph &Ab2, const Parameters ¶meters) : parameters_(parameters) { - auto xbar = boost::make_shared(Rc1->optimize()); + auto xbar = Rc1.optimize(); pc_ = boost::make_shared(Ab2, Rc1, xbar); } @@ -59,10 +59,10 @@ SubgraphSolver::SubgraphSolver(const GaussianBayesNet::shared_ptr &Rc1, // Taking subgraphs [A1|b1] and [A2|b2] // delegate up SubgraphSolver::SubgraphSolver(const GaussianFactorGraph &Ab1, - const GaussianFactorGraph::shared_ptr &Ab2, + const GaussianFactorGraph &Ab2, const Parameters ¶meters, const Ordering &ordering) - : SubgraphSolver(Ab1.eliminateSequential(ordering, EliminateQR), Ab2, + : SubgraphSolver(*Ab1.eliminateSequential(ordering, EliminateQR), Ab2, parameters) {} /**************************************************************************************************/ @@ -78,7 +78,7 @@ VectorValues SubgraphSolver::optimize(const GaussianFactorGraph &gfg, return VectorValues(); } /**************************************************************************************************/ -pair // +pair // SubgraphSolver::splitGraph(const GaussianFactorGraph &factorGraph) { /* identify the subgraph structure */ diff --git a/gtsam/linear/SubgraphSolver.h b/gtsam/linear/SubgraphSolver.h index a41738321..0598b3321 100644 --- a/gtsam/linear/SubgraphSolver.h +++ b/gtsam/linear/SubgraphSolver.h @@ -99,15 +99,13 @@ class GTSAM_EXPORT SubgraphSolver : public IterativeSolver { * eliminate Ab1. We take Ab1 as a const reference, as it will be transformed * into Rc1, but take Ab2 as a shared pointer as we need to keep it around. */ - SubgraphSolver(const GaussianFactorGraph &Ab1, - const boost::shared_ptr &Ab2, + SubgraphSolver(const GaussianFactorGraph &Ab1, const GaussianFactorGraph &Ab2, const Parameters ¶meters, const Ordering &ordering); /** * The same as above, but we assume A1 was solved by caller. * We take two shared pointers as we keep both around. */ - SubgraphSolver(const boost::shared_ptr &Rc1, - const boost::shared_ptr &Ab2, + SubgraphSolver(const GaussianBayesNet &Rc1, const GaussianFactorGraph &Ab2, const Parameters ¶meters); /// Destructor @@ -131,9 +129,8 @@ class GTSAM_EXPORT SubgraphSolver : public IterativeSolver { /// @{ /// Split graph using Kruskal algorithm, treating binary factors as edges. - std::pair < boost::shared_ptr, - boost::shared_ptr > splitGraph( - const GaussianFactorGraph &gfg); + std::pair splitGraph( + const GaussianFactorGraph &gfg); /// @} }; diff --git a/gtsam/linear/VectorValues.cpp b/gtsam/linear/VectorValues.cpp index 6a2514b35..62996af27 100644 --- a/gtsam/linear/VectorValues.cpp +++ b/gtsam/linear/VectorValues.cpp @@ -33,7 +33,7 @@ namespace gtsam { using boost::adaptors::map_values; using boost::accumulate; - /* ************************************************************************* */ + /* ************************************************************************ */ VectorValues::VectorValues(const VectorValues& first, const VectorValues& second) { // Merge using predicate for comparing first of pair @@ -44,7 +44,7 @@ namespace gtsam { throw invalid_argument("Requested to merge two VectorValues that have one or more variables in common."); } - /* ************************************************************************* */ + /* ************************************************************************ */ VectorValues::VectorValues(const Vector& x, const Dims& dims) { using Pair = pair; size_t j = 0; @@ -61,7 +61,7 @@ namespace gtsam { } } - /* ************************************************************************* */ + /* ************************************************************************ */ VectorValues::VectorValues(const Vector& x, const Scatter& scatter) { size_t j = 0; for (const SlotEntry& v : scatter) { @@ -74,7 +74,7 @@ namespace gtsam { } } - /* ************************************************************************* */ + /* ************************************************************************ */ VectorValues VectorValues::Zero(const VectorValues& other) { VectorValues result; @@ -87,7 +87,7 @@ namespace gtsam { return result; } - /* ************************************************************************* */ + /* ************************************************************************ */ VectorValues::iterator VectorValues::insert(const std::pair& key_value) { std::pair result = values_.insert(key_value); if(!result.second) @@ -97,7 +97,7 @@ namespace gtsam { return result.first; } - /* ************************************************************************* */ + /* ************************************************************************ */ void VectorValues::update(const VectorValues& values) { iterator hint = begin(); @@ -115,7 +115,7 @@ namespace gtsam { } } - /* ************************************************************************* */ + /* ************************************************************************ */ void VectorValues::insert(const VectorValues& values) { size_t originalSize = size(); @@ -124,14 +124,14 @@ namespace gtsam { throw invalid_argument("Requested to insert a VectorValues into another VectorValues that already contains one or more of its keys."); } - /* ************************************************************************* */ + /* ************************************************************************ */ void VectorValues::setZero() { for(Vector& v: values_ | map_values) v.setZero(); } - /* ************************************************************************* */ + /* ************************************************************************ */ GTSAM_EXPORT ostream& operator<<(ostream& os, const VectorValues& v) { // Change print depending on whether we are using TBB #ifdef GTSAM_USE_TBB @@ -150,7 +150,7 @@ namespace gtsam { return os; } - /* ************************************************************************* */ + /* ************************************************************************ */ void VectorValues::print(const string& str, const KeyFormatter& formatter) const { cout << str << ": " << size() << " elements\n"; @@ -158,7 +158,7 @@ namespace gtsam { cout.flush(); } - /* ************************************************************************* */ + /* ************************************************************************ */ bool VectorValues::equals(const VectorValues& x, double tol) const { if(this->size() != x.size()) return false; @@ -170,7 +170,7 @@ namespace gtsam { return true; } - /* ************************************************************************* */ + /* ************************************************************************ */ Vector VectorValues::vector() const { // Count dimensions DenseIndex totalDim = 0; @@ -187,7 +187,7 @@ namespace gtsam { return result; } - /* ************************************************************************* */ + /* ************************************************************************ */ Vector VectorValues::vector(const Dims& keys) const { // Count dimensions @@ -203,12 +203,12 @@ namespace gtsam { return result; } - /* ************************************************************************* */ + /* ************************************************************************ */ void VectorValues::swap(VectorValues& other) { this->values_.swap(other.values_); } - /* ************************************************************************* */ + /* ************************************************************************ */ namespace internal { bool structureCompareOp(const boost::tuple()); } - /* ************************************************************************* */ + /* ************************************************************************ */ double VectorValues::dot(const VectorValues& v) const { if(this->size() != v.size()) @@ -244,12 +244,12 @@ namespace gtsam { return result; } - /* ************************************************************************* */ + /* ************************************************************************ */ double VectorValues::norm() const { return std::sqrt(this->squaredNorm()); } - /* ************************************************************************* */ + /* ************************************************************************ */ double VectorValues::squaredNorm() const { double sumSquares = 0.0; using boost::adaptors::map_values; @@ -258,7 +258,7 @@ namespace gtsam { return sumSquares; } - /* ************************************************************************* */ + /* ************************************************************************ */ VectorValues VectorValues::operator+(const VectorValues& c) const { if(this->size() != c.size()) @@ -278,13 +278,13 @@ namespace gtsam { return result; } - /* ************************************************************************* */ + /* ************************************************************************ */ VectorValues VectorValues::add(const VectorValues& c) const { return *this + c; } - /* ************************************************************************* */ + /* ************************************************************************ */ VectorValues& VectorValues::operator+=(const VectorValues& c) { if(this->size() != c.size()) @@ -301,13 +301,13 @@ namespace gtsam { return *this; } - /* ************************************************************************* */ + /* ************************************************************************ */ VectorValues& VectorValues::addInPlace(const VectorValues& c) { return *this += c; } - /* ************************************************************************* */ + /* ************************************************************************ */ VectorValues& VectorValues::addInPlace_(const VectorValues& c) { for(const_iterator j2 = c.begin(); j2 != c.end(); ++j2) { @@ -320,7 +320,7 @@ namespace gtsam { return *this; } - /* ************************************************************************* */ + /* ************************************************************************ */ VectorValues VectorValues::operator-(const VectorValues& c) const { if(this->size() != c.size()) @@ -340,13 +340,13 @@ namespace gtsam { return result; } - /* ************************************************************************* */ + /* ************************************************************************ */ VectorValues VectorValues::subtract(const VectorValues& c) const { return *this - c; } - /* ************************************************************************* */ + /* ************************************************************************ */ VectorValues operator*(const double a, const VectorValues &v) { VectorValues result; @@ -359,13 +359,13 @@ namespace gtsam { return result; } - /* ************************************************************************* */ + /* ************************************************************************ */ VectorValues VectorValues::scale(const double a) const { return a * *this; } - /* ************************************************************************* */ + /* ************************************************************************ */ VectorValues& VectorValues::operator*=(double alpha) { for(Vector& v: *this | map_values) @@ -373,12 +373,43 @@ namespace gtsam { return *this; } - /* ************************************************************************* */ + /* ************************************************************************ */ VectorValues& VectorValues::scaleInPlace(double alpha) { return *this *= alpha; } - /* ************************************************************************* */ + /* ************************************************************************ */ + string VectorValues::html(const KeyFormatter& keyFormatter) const { + stringstream ss; + + // Print out preamble. + ss << "
\n\n \n"; + + // Print out header row. + ss << " \n"; + + // Finish header and start body. + ss << " \n \n"; + + // Print out all rows. +#ifdef GTSAM_USE_TBB + // TBB uses un-ordered map, so inefficiently order them: + std::map ordered; + for (const auto& kv : *this) ordered.emplace(kv); + for (const auto& kv : ordered) { +#else + for (const auto& kv : *this) { +#endif + ss << " "; + ss << ""; + ss << "\n"; + } + ss << " \n
Variablevalue
" << keyFormatter(kv.first) << "" + << kv.second.transpose() << "
\n
"; + return ss.str(); + } + + /* ************************************************************************ */ } // \namespace gtsam diff --git a/gtsam/linear/VectorValues.h b/gtsam/linear/VectorValues.h index 9e60ff2aa..1ff8c5c16 100644 --- a/gtsam/linear/VectorValues.h +++ b/gtsam/linear/VectorValues.h @@ -34,7 +34,7 @@ namespace gtsam { /** - * This class represents a collection of vector-valued variables associated + * VectorValues represents a collection of vector-valued variables associated * each with a unique integer index. It is typically used to store the variables * of a GaussianFactorGraph. Optimizing a GaussianFactorGraph or GaussianBayesNet * returns this class. @@ -69,7 +69,7 @@ namespace gtsam { * which is a view on the underlying data structure. * * This class is additionally used in gradient descent and dog leg to store the gradient. - * \nosubgrouping + * @addtogroup linear */ class GTSAM_EXPORT VectorValues { protected: @@ -344,11 +344,16 @@ namespace gtsam { /// @} - /// @} - /// @name Matlab syntactic sugar for linear algebra operations + /// @name Wrapper support /// @{ - //inline VectorValues scale(const double a, const VectorValues& c) const { return a * (*this); } + /** + * @brief Output as a html table. + * + * @param keyFormatter function that formats keys. + */ + std::string html( + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; /// @} diff --git a/gtsam/linear/iterative-inl.h b/gtsam/linear/iterative-inl.h index 58ef7d733..906ee80fd 100644 --- a/gtsam/linear/iterative-inl.h +++ b/gtsam/linear/iterative-inl.h @@ -72,7 +72,7 @@ namespace gtsam { double takeOptimalStep(V& x) { // TODO: can we use gamma instead of dot(d,g) ????? Answer not trivial double alpha = -dot(d, g) / dot(Ad, Ad); // calculate optimal step-size - axpy(alpha, d, x); // // do step in new search direction, x += alpha*d + x += alpha * d; // do step in new search direction, x += alpha*d return alpha; } @@ -106,7 +106,7 @@ namespace gtsam { double beta = new_gamma / gamma; // d = g + d*beta; d *= beta; - axpy(1.0, g, d); + d += 1.0 * g; } gamma = new_gamma; diff --git a/gtsam/linear/linear.i b/gtsam/linear/linear.i index 8635c55f8..f1bc92f69 100644 --- a/gtsam/linear/linear.i +++ b/gtsam/linear/linear.i @@ -221,6 +221,7 @@ class VectorValues { //Constructors VectorValues(); VectorValues(const gtsam::VectorValues& other); + VectorValues(const gtsam::VectorValues& first, const gtsam::VectorValues& second); //Named Constructors static gtsam::VectorValues Zero(const gtsam::VectorValues& model); @@ -254,9 +255,7 @@ class VectorValues { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; + string html() const; }; #include @@ -301,6 +300,7 @@ virtual class JacobianFactor : gtsam::GaussianFactor { void print(string s = "", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; void printKeys(string s) const; + gtsam::KeyVector& keys() const; bool equals(const gtsam::GaussianFactor& lf, double tol) const; size_t size() const; Vector unweighted_error(const gtsam::VectorValues& c) const; @@ -327,9 +327,6 @@ virtual class JacobianFactor : gtsam::GaussianFactor { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -362,9 +359,6 @@ virtual class HessianFactor : gtsam::GaussianFactor { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -400,6 +394,7 @@ class GaussianFactorGraph { // error and probability double error(const gtsam::VectorValues& c) const; double probPrime(const gtsam::VectorValues& c) const; + void printErrors(const gtsam::VectorValues& c, string str = "GaussianFactorGraph: ", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; gtsam::GaussianFactorGraph clone() const; gtsam::GaussianFactorGraph negate() const; @@ -413,8 +408,10 @@ class GaussianFactorGraph { // Elimination and marginals gtsam::GaussianBayesNet* eliminateSequential(); + gtsam::GaussianBayesNet* eliminateSequential(gtsam::Ordering::OrderingType type); gtsam::GaussianBayesNet* eliminateSequential(const gtsam::Ordering& ordering); gtsam::GaussianBayesTree* eliminateMultifrontal(); + gtsam::GaussianBayesTree* eliminateMultifrontal(gtsam::Ordering::OrderingType type); gtsam::GaussianBayesTree* eliminateMultifrontal(const gtsam::Ordering& ordering); pair eliminatePartialSequential( const gtsam::Ordering& ordering); @@ -443,58 +440,93 @@ class GaussianFactorGraph { pair hessian() const; pair hessian(const gtsam::Ordering& ordering) const; + string dot( + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; + void saveGraph( + string s, + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; + // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include virtual class GaussianConditional : gtsam::JacobianFactor { - //Constructors - GaussianConditional(size_t key, Vector d, Matrix R, const gtsam::noiseModel::Diagonal* sigmas); + // Constructors + GaussianConditional(size_t key, Vector d, Matrix R, + const gtsam::noiseModel::Diagonal* sigmas); GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S, - const gtsam::noiseModel::Diagonal* sigmas); + const gtsam::noiseModel::Diagonal* sigmas); GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S, - size_t name2, Matrix T, const gtsam::noiseModel::Diagonal* sigmas); + size_t name2, Matrix T, + const gtsam::noiseModel::Diagonal* sigmas); - //Constructors with no noise model + // Constructors with no noise model GaussianConditional(size_t key, Vector d, Matrix R); - GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S); - GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S, - size_t name2, Matrix T); + GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S); + GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S, + size_t name2, Matrix T); - //Standard Interface - void print(string s = "GaussianConditional", - const gtsam::KeyFormatter& keyFormatter = - gtsam::DefaultKeyFormatter) const; - bool equals(const gtsam::GaussianConditional& cg, double tol) const; + // Named constructors + static gtsam::GaussianConditional FromMeanAndStddev(gtsam::Key key, + const Matrix& A, + gtsam::Key parent, + const Vector& b, + double sigma); - // Advanced Interface - gtsam::VectorValues solve(const gtsam::VectorValues& parents) const; - gtsam::VectorValues solveOtherRHS(const gtsam::VectorValues& parents, - const gtsam::VectorValues& rhs) const; - void solveTransposeInPlace(gtsam::VectorValues& gy) const; - void scaleFrontalsBySigma(gtsam::VectorValues& gy) const; - Matrix R() const; - Matrix S() const; - Vector d() const; + static gtsam::GaussianConditional FromMeanAndStddev(gtsam::Key key, + const Matrix& A1, + gtsam::Key parent1, + const Matrix& A2, + gtsam::Key parent2, + const Vector& b, + double sigma); + // Testable + void print(string s = "GaussianConditional", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::GaussianConditional& cg, double tol) const; + + // Standard Interface + gtsam::Key firstFrontalKey() const; + gtsam::VectorValues solve(const gtsam::VectorValues& parents) const; + gtsam::JacobianFactor* likelihood( + const gtsam::VectorValues& frontalValues) const; + gtsam::JacobianFactor* likelihood(Vector frontal) const; + gtsam::VectorValues sample(const gtsam::VectorValues& parents) const; + gtsam::VectorValues sample() const; + + // Advanced Interface + gtsam::VectorValues solveOtherRHS(const gtsam::VectorValues& parents, + const gtsam::VectorValues& rhs) const; + void solveTransposeInPlace(gtsam::VectorValues& gy) const; + Matrix R() const; + Matrix S() const; + Vector d() const; - // enabling serialization functionality - void serialize() const; + // enabling serialization functionality + void serialize() const; }; #include virtual class GaussianDensity : gtsam::GaussianConditional { - //Constructors - GaussianDensity(size_t key, Vector d, Matrix R, const gtsam::noiseModel::Diagonal* sigmas); + // Constructors + GaussianDensity(gtsam::Key key, Vector d, Matrix R, + const gtsam::noiseModel::Diagonal* sigmas); - //Standard Interface + static gtsam::GaussianDensity FromMeanAndStddev(gtsam::Key key, + const Vector& mean, + double sigma); + + // Testable void print(string s = "GaussianDensity", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; - bool equals(const gtsam::GaussianDensity &cg, double tol) const; + bool equals(const gtsam::GaussianDensity& cg, double tol) const; + + // Standard Interface Vector mean() const; Matrix covariance() const; }; @@ -511,29 +543,43 @@ virtual class GaussianBayesNet { bool equals(const gtsam::GaussianBayesNet& other, double tol) const; size_t size() const; + // Standard interface + void push_back(gtsam::GaussianConditional* conditional); + void push_back(const gtsam::GaussianBayesNet& bayesNet); + gtsam::GaussianConditional* front() const; + gtsam::GaussianConditional* back() const; + + gtsam::VectorValues optimize() const; + gtsam::VectorValues optimize(gtsam::VectorValues given) const; + gtsam::VectorValues optimizeGradientSearch() const; + + gtsam::VectorValues sample(gtsam::VectorValues given) const; + gtsam::VectorValues sample() const; + gtsam::VectorValues backSubstitute(const gtsam::VectorValues& gx) const; + gtsam::VectorValues backSubstituteTranspose(const gtsam::VectorValues& gx) const; + // FactorGraph derived interface - // size_t size() const; gtsam::GaussianConditional* at(size_t idx) const; gtsam::KeySet keys() const; + gtsam::KeyVector keyVector() const; bool exists(size_t idx) const; void saveGraph(const string& s) const; - gtsam::GaussianConditional* front() const; - gtsam::GaussianConditional* back() const; - void push_back(gtsam::GaussianConditional* conditional); - void push_back(const gtsam::GaussianBayesNet& bayesNet); - - gtsam::VectorValues optimize() const; - gtsam::VectorValues optimize(gtsam::VectorValues& solutionForMissing) const; - gtsam::VectorValues optimizeGradientSearch() const; + std::pair matrix() const; gtsam::VectorValues gradient(const gtsam::VectorValues& x0) const; gtsam::VectorValues gradientAtZero() const; double error(const gtsam::VectorValues& x) const; double determinant() const; double logDeterminant() const; - gtsam::VectorValues backSubstitute(const gtsam::VectorValues& gx) const; - gtsam::VectorValues backSubstituteTranspose(const gtsam::VectorValues& gx) const; + + string dot( + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; + void saveGraph( + string s, + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; }; #include @@ -547,7 +593,12 @@ virtual class GaussianBayesTree { size_t size() const; bool empty() const; size_t numCachedSeparatorMarginals() const; - void saveGraph(string s) const; + + string dot(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + void saveGraph(string s, + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; gtsam::VectorValues optimize() const; gtsam::VectorValues optimizeGradientSearch() const; @@ -634,7 +685,7 @@ virtual class SubgraphSolverParameters : gtsam::ConjugateGradientParameters { virtual class SubgraphSolver { SubgraphSolver(const gtsam::GaussianFactorGraph &A, const gtsam::SubgraphSolverParameters ¶meters, const gtsam::Ordering& ordering); - SubgraphSolver(const gtsam::GaussianFactorGraph &Ab1, const gtsam::GaussianFactorGraph* Ab2, const gtsam::SubgraphSolverParameters ¶meters, const gtsam::Ordering& ordering); + SubgraphSolver(const gtsam::GaussianFactorGraph &Ab1, const gtsam::GaussianFactorGraph& Ab2, const gtsam::SubgraphSolverParameters ¶meters, const gtsam::Ordering& ordering); gtsam::VectorValues optimize() const; }; diff --git a/gtsam/linear/linearAlgorithms-inst.h b/gtsam/linear/linearAlgorithms-inst.h index 811e07121..d19ac6de5 100644 --- a/gtsam/linear/linearAlgorithms-inst.h +++ b/gtsam/linear/linearAlgorithms-inst.h @@ -15,6 +15,8 @@ * @author Richard Roberts */ +#pragma once + #include #include #include diff --git a/gtsam/linear/tests/powerMethodExample.h b/gtsam/linear/tests/powerMethodExample.h index f80299386..994fcc640 100644 --- a/gtsam/linear/tests/powerMethodExample.h +++ b/gtsam/linear/tests/powerMethodExample.h @@ -19,6 +19,8 @@ * PowerMethod/AcceleratedPowerMethod */ +#pragma once + #include #include diff --git a/gtsam/linear/tests/testErrors.cpp b/gtsam/linear/tests/testErrors.cpp index 74eef9a2c..f11fb90b9 100644 --- a/gtsam/linear/tests/testErrors.cpp +++ b/gtsam/linear/tests/testErrors.cpp @@ -32,7 +32,7 @@ TEST( Errors, arithmetic ) e += Vector2(1.0,2.0), Vector3(3.0,4.0,5.0); DOUBLES_EQUAL(1+4+9+16+25,dot(e,e),1e-9); - axpy(2.0,e,e); + axpy(2.0, e, e); Errors expected; expected += Vector2(3.0,6.0), Vector3(9.0,12.0,15.0); CHECK(assert_equal(expected,e)); diff --git a/gtsam/linear/tests/testGaussianBayesNet.cpp b/gtsam/linear/tests/testGaussianBayesNet.cpp index 00a338e54..2b125265f 100644 --- a/gtsam/linear/tests/testGaussianBayesNet.cpp +++ b/gtsam/linear/tests/testGaussianBayesNet.cpp @@ -16,10 +16,12 @@ */ #include +#include #include #include #include #include +#include #include #include @@ -35,6 +37,7 @@ using namespace boost::assign; using namespace std::placeholders; using namespace std; using namespace gtsam; +using symbol_shorthand::X; static const Key _x_ = 11, _y_ = 22, _z_ = 33; @@ -138,6 +141,30 @@ TEST( GaussianBayesNet, optimize3 ) EXPECT(assert_equal(expected, actual)); } +/* ************************************************************************* */ +TEST(GaussianBayesNet, sample) { + GaussianBayesNet gbn; + Matrix A1 = (Matrix(2, 2) << 1., 2., 3., 4.).finished(); + const Vector2 mean(20, 40), b(10, 10); + const double sigma = 0.01; + + gbn.add(GaussianConditional::FromMeanAndStddev(X(0), A1, X(1), b, sigma)); + gbn.add(GaussianDensity::FromMeanAndStddev(X(1), mean, sigma)); + + auto actual = gbn.sample(); + EXPECT_LONGS_EQUAL(2, actual.size()); + EXPECT(assert_equal(mean, actual[X(1)], 50 * sigma)); + EXPECT(assert_equal(A1 * mean + b, actual[X(0)], 50 * sigma)); + + // Use a specific random generator + std::mt19937_64 rng(4242); + auto actual3 = gbn.sample(&rng); + EXPECT_LONGS_EQUAL(2, actual.size()); + // regression is not repeatable across platforms/versions :-( + // EXPECT(assert_equal(Vector2(20.0129382, 40.0039798), actual[X(1)], 1e-5)); + // EXPECT(assert_equal(Vector2(110.032083, 230.039811), actual[X(0)], 1e-5)); +} + /* ************************************************************************* */ TEST(GaussianBayesNet, ordering) { @@ -301,5 +328,31 @@ TEST(GaussianBayesNet, ComputeSteepestDescentPoint) { } /* ************************************************************************* */ -int main() { TestResult tr; return TestRegistry::runAllTests(tr);} +TEST(GaussianBayesNet, Dot) { + GaussianBayesNet fragment; + DotWriter writer; + writer.variablePositions.emplace(_x_, Vector2(10, 20)); + writer.variablePositions.emplace(_y_, Vector2(50, 20)); + + auto position = writer.variablePos(_x_); + CHECK(position); + EXPECT(assert_equal(Vector2(10, 20), *position, 1e-5)); + + string actual = noisyBayesNet.dot(DefaultKeyFormatter, writer); + EXPECT(actual == + "digraph {\n" + " size=\"5,5\";\n" + "\n" + " var11[label=\"11\", pos=\"10,20!\"];\n" + " var22[label=\"22\", pos=\"50,20!\"];\n" + "\n" + " var22->var11\n" + "}"); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} /* ************************************************************************* */ diff --git a/gtsam/linear/tests/testGaussianConditional.cpp b/gtsam/linear/tests/testGaussianConditional.cpp index fae00e1e4..6ec06a0ce 100644 --- a/gtsam/linear/tests/testGaussianConditional.cpp +++ b/gtsam/linear/tests/testGaussianConditional.cpp @@ -20,9 +20,10 @@ #include #include -#include +#include #include #include +#include #include #include @@ -38,6 +39,8 @@ using namespace gtsam; using namespace std; using namespace boost::assign; +using symbol_shorthand::X; +using symbol_shorthand::Y; static const double tol = 1e-5; @@ -316,5 +319,136 @@ TEST( GaussianConditional, isGaussianFactor ) { } /* ************************************************************************* */ -int main() { TestResult tr; return TestRegistry::runAllTests(tr);} +// Test FromMeanAndStddev named constructors +TEST(GaussianConditional, FromMeanAndStddev) { + Matrix A1 = (Matrix(2, 2) << 1., 2., 3., 4.).finished(); + Matrix A2 = (Matrix(2, 2) << 5., 6., 7., 8.).finished(); + const Vector2 b(20, 40), x0(1, 2), x1(3, 4), x2(5, 6); + const double sigma = 3; + + VectorValues values = map_list_of(X(0), x0)(X(1), x1)(X(2), x2); + + auto conditional1 = + GaussianConditional::FromMeanAndStddev(X(0), A1, X(1), b, sigma); + Vector2 e1 = (x0 - (A1 * x1 + b)) / sigma; + double expected1 = 0.5 * e1.dot(e1); + EXPECT_DOUBLES_EQUAL(expected1, conditional1.error(values), 1e-9); + + auto conditional2 = GaussianConditional::FromMeanAndStddev(X(0), A1, X(1), A2, + X(2), b, sigma); + Vector2 e2 = (x0 - (A1 * x1 + A2 * x2 + b)) / sigma; + double expected2 = 0.5 * e2.dot(e2); + EXPECT_DOUBLES_EQUAL(expected2, conditional2.error(values), 1e-9); +} + +/* ************************************************************************* */ +// Test likelihood method (conversion to JacobianFactor) +TEST(GaussianConditional, likelihood) { + Matrix A1 = (Matrix(2, 2) << 1., 2., 3., 4.).finished(); + const Vector2 b(20, 40), x0(1, 2); + const double sigma = 0.01; + + // |x0 - A1 x1 - b|^2 + auto conditional = + GaussianConditional::FromMeanAndStddev(X(0), A1, X(1), b, sigma); + + VectorValues frontalValues; + frontalValues.insert(X(0), x0); + auto actual1 = conditional.likelihood(frontalValues); + CHECK(actual1); + + // |(-A1) x1 - (b - x0)|^2 + JacobianFactor expected(X(1), -A1, b - x0, + noiseModel::Isotropic::Sigma(2, sigma)); + EXPECT(assert_equal(expected, *actual1, tol)); + + // Check single vector version + auto actual2 = conditional.likelihood(x0); + CHECK(actual2); + EXPECT(assert_equal(expected, *actual2, tol)); +} + +/* ************************************************************************* */ +// Test sampling +TEST(GaussianConditional, sample) { + Matrix A1 = (Matrix(2, 2) << 1., 2., 3., 4.).finished(); + const Vector2 b(20, 40), x1(3, 4); + const double sigma = 0.01; + + auto density = GaussianDensity::FromMeanAndStddev(X(0), b, sigma); + auto actual1 = density.sample(); + EXPECT_LONGS_EQUAL(1, actual1.size()); + EXPECT(assert_equal(b, actual1[X(0)], 50 * sigma)); + + VectorValues given; + given.insert(X(1), x1); + + auto conditional = + GaussianConditional::FromMeanAndStddev(X(0), A1, X(1), b, sigma); + auto actual2 = conditional.sample(given); + EXPECT_LONGS_EQUAL(1, actual2.size()); + EXPECT(assert_equal(A1 * x1 + b, actual2[X(0)], 50 * sigma)); + + // Use a specific random generator + std::mt19937_64 rng(4242); + auto actual3 = conditional.sample(given, &rng); + EXPECT_LONGS_EQUAL(1, actual2.size()); + // regression is not repeatable across platforms/versions :-( + // EXPECT(assert_equal(Vector2(31.0111856, 64.9850775), actual2[X(0)], 1e-5)); +} + +/* ************************************************************************* */ +TEST(GaussianConditional, Print) { + Matrix A1 = (Matrix(2, 2) << 1., 2., 3., 4.).finished(); + Matrix A2 = (Matrix(2, 2) << 5., 6., 7., 8.).finished(); + const Vector2 b(20, 40); + const double sigma = 3; + + GaussianConditional conditional(X(0), b, Matrix2::Identity(), + noiseModel::Isotropic::Sigma(2, sigma)); + + // Test printing for no parents. + std::string expected = + "GaussianConditional p(x0)\n" + " R = [ 1 0 ]\n" + " [ 0 1 ]\n" + " d = [ 20 40 ]\n" + "isotropic dim=2 sigma=3\n"; + EXPECT(assert_print_equal(expected, conditional, "GaussianConditional")); + + auto conditional1 = + GaussianConditional::FromMeanAndStddev(X(0), A1, X(1), b, sigma); + + // Test printing for single parent. + std::string expected1 = + "GaussianConditional p(x0 | x1)\n" + " R = [ 1 0 ]\n" + " [ 0 1 ]\n" + " S[x1] = [ -1 -2 ]\n" + " [ -3 -4 ]\n" + " d = [ 20 40 ]\n" + "isotropic dim=2 sigma=3\n"; + EXPECT(assert_print_equal(expected1, conditional1, "GaussianConditional")); + + // Test printing for multiple parents. + auto conditional2 = GaussianConditional::FromMeanAndStddev(X(0), A1, Y(0), A2, + Y(1), b, sigma); + std::string expected2 = + "GaussianConditional p(x0 | y0 y1)\n" + " R = [ 1 0 ]\n" + " [ 0 1 ]\n" + " S[y0] = [ -1 -2 ]\n" + " [ -3 -4 ]\n" + " S[y1] = [ -5 -6 ]\n" + " [ -7 -8 ]\n" + " d = [ 20 40 ]\n" + "isotropic dim=2 sigma=3\n"; + EXPECT(assert_print_equal(expected2, conditional2, "GaussianConditional")); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} /* ************************************************************************* */ diff --git a/gtsam/linear/tests/testGaussianDensity.cpp b/gtsam/linear/tests/testGaussianDensity.cpp index 29dc49591..14608e794 100644 --- a/gtsam/linear/tests/testGaussianDensity.cpp +++ b/gtsam/linear/tests/testGaussianDensity.cpp @@ -17,10 +17,13 @@ **/ #include +#include + #include using namespace gtsam; using namespace std; +using symbol_shorthand::X; /* ************************************************************************* */ TEST(GaussianDensity, constructor) @@ -37,6 +40,22 @@ TEST(GaussianDensity, constructor) EXPECT(assert_equal(s, copied.get_model()->sigmas())); } +/* ************************************************************************* */ +// Test FromMeanAndStddev named constructor +TEST(GaussianDensity, FromMeanAndStddev) { + Matrix A1 = (Matrix(2, 2) << 1., 2., 3., 4.).finished(); + const Vector2 b(20, 40), x0(1, 2); + const double sigma = 3; + + VectorValues values; + values.insert(X(0), x0); + + auto density = GaussianDensity::FromMeanAndStddev(X(0), b, sigma); + Vector2 e = (x0 - b) / sigma; + double expected = 0.5 * e.dot(e); + EXPECT_DOUBLES_EQUAL(expected, density.error(values), 1e-9); +} + /* ************************************************************************* */ int main() { TestResult tr; return TestRegistry::runAllTests(tr);} /* ************************************************************************* */ diff --git a/gtsam/linear/tests/testGaussianFactorGraph.cpp b/gtsam/linear/tests/testGaussianFactorGraph.cpp index bb07a36aa..41464a110 100644 --- a/gtsam/linear/tests/testGaussianFactorGraph.cpp +++ b/gtsam/linear/tests/testGaussianFactorGraph.cpp @@ -426,6 +426,7 @@ TEST(GaussianFactorGraph, hessianDiagonal) { EXPECT(assert_equal(expected, actual)); } +/* ************************************************************************* */ TEST(GaussianFactorGraph, DenseSolve) { GaussianFactorGraph fg = createSimpleGaussianFactorGraph(); VectorValues expected = fg.optimize(); @@ -433,6 +434,28 @@ TEST(GaussianFactorGraph, DenseSolve) { EXPECT(assert_equal(expected, actual)); } +/* ************************************************************************* */ +TEST(GaussianFactorGraph, ProbPrime) { + GaussianFactorGraph gfg; + gfg.emplace_shared(1, I_1x1, Z_1x1, + noiseModel::Isotropic::Sigma(1, 1.0)); + + VectorValues values; + values.insert(1, I_1x1); + + // We are testing the normal distribution PDF where info matrix Σ = 1, + // mean mu = 0 and x = 1. + // Therefore factor squared error: y = 0.5 * (Σ*x - mu)^2 = + // 0.5 * (1.0 - 0)^2 = 0.5 + // NOTE the 0.5 constant is a part of the factor error. + EXPECT_DOUBLES_EQUAL(0.5, gfg.error(values), 1e-12); + + // The gaussian PDF value is: exp^(-0.5 * (Σ*x - mu)^2) / sqrt(2 * PI) + // Ignore the denominator and we get: exp^(-0.5 * (1.0)^2) = exp^(-0.5) + double expected = exp(-0.5); + EXPECT_DOUBLES_EQUAL(expected, gfg.probPrime(values), 1e-12); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/linear/tests/testNoiseModel.cpp b/gtsam/linear/tests/testNoiseModel.cpp index 42d68a603..b974b6cd5 100644 --- a/gtsam/linear/tests/testNoiseModel.cpp +++ b/gtsam/linear/tests/testNoiseModel.cpp @@ -662,25 +662,14 @@ TEST(NoiseModel, robustNoiseL2WithDeadZone) { double dead_zone_size = 1.0; SharedNoiseModel robust = noiseModel::Robust::Create( - noiseModel::mEstimator::L2WithDeadZone::Create(dead_zone_size), - Unit::Create(3)); - -/* - * TODO(mike): There is currently a bug in GTSAM, where none of the mEstimator classes - * implement a loss function, and GTSAM calls the weight function to evaluate the - * total penalty, rather than calling the loss function. The weight function should be - * used during iteratively reweighted least squares optimization, but should not be used to - * evaluate the total penalty. The long-term solution is for all mEstimators to implement - * both a weight and a loss function, and for GTSAM to call the loss function when - * evaluating the total penalty. This bug causes the test below to fail, so I'm leaving it - * commented out until the underlying bug in GTSAM is fixed. - * - * for (int i = 0; i < 5; i++) { - * Vector3 error = Vector3(i, 0, 0); - * DOUBLES_EQUAL(0.5*max(0,i-1)*max(0,i-1), robust->distance(error), 1e-8); - * } - */ + noiseModel::mEstimator::L2WithDeadZone::Create(dead_zone_size), + Unit::Create(3)); + for (int i = 0; i < 5; i++) { + Vector3 error = Vector3(i, 0, 0); + DOUBLES_EQUAL(std::fmax(0, i - dead_zone_size) * i, + robust->squaredMahalanobisDistance(error), 1e-8); + } } TEST(NoiseModel, lossFunctionAtZero) @@ -707,9 +696,9 @@ TEST(NoiseModel, lossFunctionAtZero) auto dcs = mEstimator::DCS::Create(k); DOUBLES_EQUAL(dcs->loss(0), 0, 1e-8); DOUBLES_EQUAL(dcs->weight(0), 1, 1e-8); - // auto lsdz = mEstimator::L2WithDeadZone::Create(k); - // DOUBLES_EQUAL(lsdz->loss(0), 0, 1e-8); - // DOUBLES_EQUAL(lsdz->weight(0), 1, 1e-8); + auto lsdz = mEstimator::L2WithDeadZone::Create(k); + DOUBLES_EQUAL(lsdz->loss(0), 0, 1e-8); + DOUBLES_EQUAL(lsdz->weight(0), 0, 1e-8); } diff --git a/gtsam/linear/tests/testSerializationLinear.cpp b/gtsam/linear/tests/testSerializationLinear.cpp index c5b3dab37..881b2830e 100644 --- a/gtsam/linear/tests/testSerializationLinear.cpp +++ b/gtsam/linear/tests/testSerializationLinear.cpp @@ -39,14 +39,14 @@ using namespace gtsam::serializationTestHelpers; /* ************************************************************************* */ // Export Noisemodels // See http://www.boost.org/doc/libs/1_32_0/libs/serialization/doc/special.html -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Constrained, "gtsam_noiseModel_Constrained"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Diagonal, "gtsam_noiseModel_Diagonal"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Gaussian, "gtsam_noiseModel_Gaussian"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Unit, "gtsam_noiseModel_Unit"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Isotropic, "gtsam_noiseModel_Isotropic"); +BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Constrained, "gtsam_noiseModel_Constrained") +BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Diagonal, "gtsam_noiseModel_Diagonal") +BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Gaussian, "gtsam_noiseModel_Gaussian") +BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Unit, "gtsam_noiseModel_Unit") +BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Isotropic, "gtsam_noiseModel_Isotropic") -BOOST_CLASS_EXPORT_GUID(gtsam::SharedNoiseModel, "gtsam_SharedNoiseModel"); -BOOST_CLASS_EXPORT_GUID(gtsam::SharedDiagonal, "gtsam_SharedDiagonal"); +BOOST_CLASS_EXPORT_GUID(gtsam::SharedNoiseModel, "gtsam_SharedNoiseModel") +BOOST_CLASS_EXPORT_GUID(gtsam::SharedDiagonal, "gtsam_SharedDiagonal") /* ************************************************************************* */ // example noise models @@ -129,9 +129,9 @@ TEST (Serialization, SharedDiagonal_noiseModels) { /* Create GUIDs for factors */ /* ************************************************************************* */ -BOOST_CLASS_EXPORT_GUID(gtsam::JacobianFactor, "gtsam::JacobianFactor"); -BOOST_CLASS_EXPORT_GUID(gtsam::HessianFactor , "gtsam::HessianFactor"); -BOOST_CLASS_EXPORT_GUID(gtsam::GaussianConditional , "gtsam::GaussianConditional"); +BOOST_CLASS_EXPORT_GUID(gtsam::JacobianFactor, "gtsam::JacobianFactor") +BOOST_CLASS_EXPORT_GUID(gtsam::HessianFactor , "gtsam::HessianFactor") +BOOST_CLASS_EXPORT_GUID(gtsam::GaussianConditional , "gtsam::GaussianConditional") /* ************************************************************************* */ TEST (Serialization, linear_factors) { diff --git a/gtsam/linear/tests/testVectorValues.cpp b/gtsam/linear/tests/testVectorValues.cpp index f97f96aaf..521cc2289 100644 --- a/gtsam/linear/tests/testVectorValues.cpp +++ b/gtsam/linear/tests/testVectorValues.cpp @@ -17,7 +17,7 @@ #include #include -#include +#include #include @@ -248,6 +248,33 @@ TEST(VectorValues, print) EXPECT(expected == actual.str()); } +/* ************************************************************************* */ +// Check html representation. +TEST(VectorValues, html) { + VectorValues vv; + using symbol_shorthand::X; + vv.insert(X(1), Vector2(2, 3.1)); + vv.insert(X(2), Vector2(4, 5.2)); + vv.insert(X(5), Vector2(6, 7.3)); + vv.insert(X(7), Vector2(8, 9.4)); + string expected = + "
\n" + "\n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + "
Variablevalue
x1 2 3.1
x2 4 5.2
x5 6 7.3
x7 8 9.4
\n" + "
"; + string actual = vv.html(); + EXPECT(actual == expected); +} + /* ************************************************************************* */ int main() { TestResult tr; return TestRegistry::runAllTests(tr); } /* ************************************************************************* */ diff --git a/gtsam/navigation/AHRSFactor.cpp b/gtsam/navigation/AHRSFactor.cpp index 4604a55dd..f4db42d0f 100644 --- a/gtsam/navigation/AHRSFactor.cpp +++ b/gtsam/navigation/AHRSFactor.cpp @@ -168,13 +168,12 @@ Vector AHRSFactor::evaluateError(const Rot3& Ri, const Rot3& Rj, } //------------------------------------------------------------------------------ -Rot3 AHRSFactor::Predict( - const Rot3& rot_i, const Vector3& bias, - const PreintegratedAhrsMeasurements preintegratedMeasurements) { - const Vector3 biascorrectedOmega = preintegratedMeasurements.predict(bias); +Rot3 AHRSFactor::Predict(const Rot3& rot_i, const Vector3& bias, + const PreintegratedAhrsMeasurements& pim) { + const Vector3 biascorrectedOmega = pim.predict(bias); // Coriolis term - const Vector3 coriolis = preintegratedMeasurements.integrateCoriolis(rot_i); + const Vector3 coriolis = pim.integrateCoriolis(rot_i); const Vector3 correctedOmega = biascorrectedOmega - coriolis; const Rot3 correctedDeltaRij = Rot3::Expmap(correctedOmega); @@ -184,27 +183,26 @@ Rot3 AHRSFactor::Predict( //------------------------------------------------------------------------------ AHRSFactor::AHRSFactor(Key rot_i, Key rot_j, Key bias, - const PreintegratedMeasurements& pim, + const PreintegratedAhrsMeasurements& pim, const Vector3& omegaCoriolis, const boost::optional& body_P_sensor) - : Base(noiseModel::Gaussian::Covariance(pim.preintMeasCov_), rot_i, rot_j, bias), + : Base(noiseModel::Gaussian::Covariance(pim.preintMeasCov_), rot_i, rot_j, + bias), _PIM_(pim) { - boost::shared_ptr p = - boost::make_shared(pim.p()); + auto p = boost::make_shared(pim.p()); p->body_P_sensor = body_P_sensor; _PIM_.p_ = p; } //------------------------------------------------------------------------------ Rot3 AHRSFactor::predict(const Rot3& rot_i, const Vector3& bias, - const PreintegratedMeasurements pim, + const PreintegratedAhrsMeasurements& pim, const Vector3& omegaCoriolis, const boost::optional& body_P_sensor) { - boost::shared_ptr p = - boost::make_shared(pim.p()); + auto p = boost::make_shared(pim.p()); p->omegaCoriolis = omegaCoriolis; p->body_P_sensor = body_P_sensor; - PreintegratedMeasurements newPim = pim; + PreintegratedAhrsMeasurements newPim = pim; newPim.p_ = p; return Predict(rot_i, bias, newPim); } diff --git a/gtsam/navigation/AHRSFactor.h b/gtsam/navigation/AHRSFactor.h index 1ab2d7cdc..c7d82975a 100644 --- a/gtsam/navigation/AHRSFactor.h +++ b/gtsam/navigation/AHRSFactor.h @@ -90,7 +90,11 @@ class GTSAM_EXPORT PreintegratedAhrsMeasurements : public PreintegratedRotation /** * Add a single Gyroscope measurement to the preintegration. - * @param measureOmedga Measured angular velocity (in body frame) + * Measurements are taken to be in the sensor + * frame and conversion to the body frame is handled by `body_P_sensor` in + * `PreintegratedRotationParams` (if provided). + * + * @param measuredOmega Measured angular velocity (as given by the sensor) * @param deltaT Time step */ void integrateMeasurement(const Vector3& measuredOmega, double deltaT); @@ -104,11 +108,10 @@ class GTSAM_EXPORT PreintegratedAhrsMeasurements : public PreintegratedRotation static Vector DeltaAngles(const Vector& msr_gyro_t, const double msr_dt, const Vector3& delta_angles); - /// @deprecated constructor + /// @deprecated constructor, but used in tests. PreintegratedAhrsMeasurements(const Vector3& biasHat, const Matrix3& measuredOmegaCovariance) - : PreintegratedRotation(boost::make_shared()), - biasHat_(biasHat) { + : PreintegratedRotation(boost::make_shared()), biasHat_(biasHat) { p_->gyroscopeCovariance = measuredOmegaCovariance; resetIntegration(); } @@ -182,24 +185,26 @@ public: /// predicted states from IMU /// TODO(frank): relationship with PIM predict ?? - static Rot3 Predict( - const Rot3& rot_i, const Vector3& bias, - const PreintegratedAhrsMeasurements preintegratedMeasurements); + static Rot3 Predict(const Rot3& rot_i, const Vector3& bias, + const PreintegratedAhrsMeasurements& pim); + /// @deprecated constructor, but used in tests. + AHRSFactor(Key rot_i, Key rot_j, Key bias, + const PreintegratedAhrsMeasurements& pim, + const Vector3& omegaCoriolis, + const boost::optional& body_P_sensor = boost::none); + + /// @deprecated static function, but used in tests. + static Rot3 predict( + const Rot3& rot_i, const Vector3& bias, + const PreintegratedAhrsMeasurements& pim, const Vector3& omegaCoriolis, + const boost::optional& body_P_sensor = boost::none); + +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 /// @deprecated name typedef PreintegratedAhrsMeasurements PreintegratedMeasurements; - /// @deprecated constructor - AHRSFactor(Key rot_i, Key rot_j, Key bias, - const PreintegratedMeasurements& preintegratedMeasurements, - const Vector3& omegaCoriolis, - const boost::optional& body_P_sensor = boost::none); - - /// @deprecated static function - static Rot3 predict(const Rot3& rot_i, const Vector3& bias, - const PreintegratedMeasurements preintegratedMeasurements, - const Vector3& omegaCoriolis, - const boost::optional& body_P_sensor = boost::none); +#endif private: diff --git a/gtsam/navigation/BarometricFactor.cpp b/gtsam/navigation/BarometricFactor.cpp new file mode 100644 index 000000000..2f0ff7436 --- /dev/null +++ b/gtsam/navigation/BarometricFactor.cpp @@ -0,0 +1,55 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file BarometricFactor.cpp + * @author Peter Milani + * @brief Implementation file for Barometric factor + * @date December 16, 2021 + **/ + +#include "BarometricFactor.h" + +using namespace std; + +namespace gtsam { + +//*************************************************************************** +void BarometricFactor::print(const string& s, + const KeyFormatter& keyFormatter) const { + cout << (s.empty() ? "" : s + " ") << "Barometric Factor on " + << keyFormatter(key1()) << "Barometric Bias on " + << keyFormatter(key2()) << "\n"; + + cout << " Baro measurement: " << nT_ << "\n"; + noiseModel_->print(" noise model: "); +} + +//*************************************************************************** +bool BarometricFactor::equals(const NonlinearFactor& expected, + double tol) const { + const This* e = dynamic_cast(&expected); + return e != nullptr && Base::equals(*e, tol) && + traits::Equals(nT_, e->nT_, tol); +} + +//*************************************************************************** +Vector BarometricFactor::evaluateError(const Pose3& p, const double& bias, + boost::optional H, + boost::optional H2) const { + Matrix tH; + Vector ret = (Vector(1) << (p.translation(tH).z() + bias - nT_)).finished(); + if (H) (*H) = tH.block<1, 6>(2, 0); + if (H2) (*H2) = (Matrix(1, 1) << 1.0).finished(); + return ret; +} + +} // namespace gtsam diff --git a/gtsam/navigation/BarometricFactor.h b/gtsam/navigation/BarometricFactor.h new file mode 100644 index 000000000..e7bf6f998 --- /dev/null +++ b/gtsam/navigation/BarometricFactor.h @@ -0,0 +1,109 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file BarometricFactor.h + * @author Peter Milani + * @brief Header file for Barometric factor + * @date December 16, 2021 + **/ +#pragma once + +#include +#include +#include + +namespace gtsam { + +/** + * Prior on height in a cartesian frame. + * Receive barometric pressure in kilopascals + * Model with a slowly moving bias to capture differences + * between the height and the standard atmosphere + * https://www.grc.nasa.gov/www/k-12/airplane/atmosmet.html + * @addtogroup Navigation + */ +class GTSAM_EXPORT BarometricFactor : public NoiseModelFactor2 { + private: + typedef NoiseModelFactor2 Base; + + double nT_; ///< Height Measurement based on a standard atmosphere + + public: + /// shorthand for a smart pointer to a factor + typedef boost::shared_ptr shared_ptr; + + /// Typedef to this class + typedef BarometricFactor This; + + /** default constructor - only use for serialization */ + BarometricFactor() : nT_(0) {} + + ~BarometricFactor() override {} + + /** + * @brief Constructor from a measurement of pressure in KPa. + * @param key of the Pose3 variable that will be constrained + * @param key of the barometric bias that will be constrained + * @param baroIn measurement in KPa + * @param model Gaussian noise model 1 dimension + */ + BarometricFactor(Key key, Key baroKey, const double& baroIn, + const SharedNoiseModel& model) + : Base(model, key, baroKey), nT_(heightOut(baroIn)) {} + + /// @return a deep copy of this factor + gtsam::NonlinearFactor::shared_ptr clone() const override { + return boost::static_pointer_cast( + gtsam::NonlinearFactor::shared_ptr(new This(*this))); + } + + /// print + void print( + const std::string& s = "", + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override; + + /// equals + bool equals(const NonlinearFactor& expected, + double tol = 1e-9) const override; + + /// vector of errors + Vector evaluateError( + const Pose3& p, const double& b, + boost::optional H = boost::none, + boost::optional H2 = boost::none) const override; + + inline const double& measurementIn() const { return nT_; } + + inline double heightOut(double n) const { + // From https://www.grc.nasa.gov/www/k-12/airplane/atmosmet.html + return (std::pow(n / 101.29, 1. / 5.256) * 288.08 - 273.1 - 15.04) / + -0.00649; + }; + + inline double baroOut(const double& meters) { + double temp = 15.04 - 0.00649 * meters; + return 101.29 * std::pow(((temp + 273.1) / 288.08), 5.256); + }; + + private: + /// Serialization function + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) { + ar& boost::serialization::make_nvp( + "NoiseModelFactor1", + boost::serialization::base_object(*this)); + ar& BOOST_SERIALIZATION_NVP(nT_); + } +}; + +} // namespace gtsam diff --git a/gtsam/navigation/CombinedImuFactor.cpp b/gtsam/navigation/CombinedImuFactor.cpp index 64d555e2e..f2349a96e 100644 --- a/gtsam/navigation/CombinedImuFactor.cpp +++ b/gtsam/navigation/CombinedImuFactor.cpp @@ -294,6 +294,3 @@ std::ostream& operator<<(std::ostream& os, const CombinedImuFactor& f) { } /// namespace gtsam -/// Boost serialization export definition for derived class -BOOST_CLASS_EXPORT_IMPLEMENT(gtsam::PreintegrationCombinedParams); - diff --git a/gtsam/navigation/CombinedImuFactor.h b/gtsam/navigation/CombinedImuFactor.h index 5849f6f24..28a314247 100644 --- a/gtsam/navigation/CombinedImuFactor.h +++ b/gtsam/navigation/CombinedImuFactor.h @@ -209,8 +209,11 @@ public: /** * Add a single IMU measurement to the preintegration. - * @param measuredAcc Measured acceleration (in body frame, as given by the - * sensor) + * Both accelerometer and gyroscope measurements are taken to be in the sensor + * frame and conversion to the body frame is handled by `body_P_sensor` in + * `PreintegrationParams`. + * + * @param measuredAcc Measured acceleration (as given by the sensor) * @param measuredOmega Measured angular velocity (as given by the sensor) * @param dt Time interval between two consecutive IMU measurements */ @@ -352,6 +355,3 @@ template <> struct traits : public Testable {}; } // namespace gtsam - -/// Add Boost serialization export key (declaration) for derived class -BOOST_CLASS_EXPORT_KEY(gtsam::PreintegrationCombinedParams); diff --git a/gtsam/navigation/ConstantVelocityFactor.h b/gtsam/navigation/ConstantVelocityFactor.h index ed68ac077..6ab4c2f02 100644 --- a/gtsam/navigation/ConstantVelocityFactor.h +++ b/gtsam/navigation/ConstantVelocityFactor.h @@ -15,6 +15,8 @@ * @author Asa Hammond */ +#pragma once + #include #include diff --git a/gtsam/navigation/ImuBias.cpp b/gtsam/navigation/ImuBias.cpp index 0dbc5736f..bc11f95f8 100644 --- a/gtsam/navigation/ImuBias.cpp +++ b/gtsam/navigation/ImuBias.cpp @@ -66,8 +66,8 @@ namespace imuBias { // } /// ostream operator std::ostream& operator<<(std::ostream& os, const ConstantBias& bias) { - os << "acc = " << Point3(bias.accelerometer()); - os << " gyro = " << Point3(bias.gyroscope()); + os << "acc = " << bias.accelerometer().transpose(); + os << " gyro = " << bias.gyroscope().transpose(); return os; } diff --git a/gtsam/navigation/ImuBias.h b/gtsam/navigation/ImuBias.h index fad952232..9346a4a77 100644 --- a/gtsam/navigation/ImuBias.h +++ b/gtsam/navigation/ImuBias.h @@ -131,30 +131,30 @@ public: /// @} +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 /// @name Deprecated /// @{ - ConstantBias inverse() { - return -(*this); - } - ConstantBias compose(const ConstantBias& q) { + ConstantBias GTSAM_DEPRECATED inverse() { return -(*this); } + ConstantBias GTSAM_DEPRECATED compose(const ConstantBias& q) { return (*this) + q; } - ConstantBias between(const ConstantBias& q) { + ConstantBias GTSAM_DEPRECATED between(const ConstantBias& q) { return q - (*this); } - Vector6 localCoordinates(const ConstantBias& q) { - return between(q).vector(); + Vector6 GTSAM_DEPRECATED localCoordinates(const ConstantBias& q) { + return (q - (*this)).vector(); } - ConstantBias retract(const Vector6& v) { - return compose(ConstantBias(v)); + ConstantBias GTSAM_DEPRECATED retract(const Vector6& v) { + return (*this) + ConstantBias(v); } - static Vector6 Logmap(const ConstantBias& p) { + static Vector6 GTSAM_DEPRECATED Logmap(const ConstantBias& p) { return p.vector(); } - static ConstantBias Expmap(const Vector6& v) { + static ConstantBias GTSAM_DEPRECATED Expmap(const Vector6& v) { return ConstantBias(v); } /// @} +#endif private: diff --git a/gtsam/navigation/ImuFactor.h b/gtsam/navigation/ImuFactor.h index 266f2a547..e89c2afb5 100644 --- a/gtsam/navigation/ImuFactor.h +++ b/gtsam/navigation/ImuFactor.h @@ -122,7 +122,11 @@ public: /** * Add a single IMU measurement to the preintegration. - * @param measuredAcc Measured acceleration (in body frame, as given by the sensor) + * Both accelerometer and gyroscope measurements are taken to be in the sensor + * frame and conversion to the body frame is handled by `body_P_sensor` in + * `PreintegrationParams`. + * + * @param measuredAcc Measured acceleration (as given by the sensor) * @param measuredOmega Measured angular velocity (as given by the sensor) * @param dt Time interval between this and the last IMU measurement */ diff --git a/gtsam/navigation/MagFactor.h b/gtsam/navigation/MagFactor.h index 74e9177d5..895ac6512 100644 --- a/gtsam/navigation/MagFactor.h +++ b/gtsam/navigation/MagFactor.h @@ -16,6 +16,8 @@ * @date January 29, 2014 */ +#pragma once + #include #include #include diff --git a/gtsam/navigation/PreintegrationParams.cpp b/gtsam/navigation/PreintegrationParams.cpp index 2298bb696..2548f95a6 100644 --- a/gtsam/navigation/PreintegrationParams.cpp +++ b/gtsam/navigation/PreintegrationParams.cpp @@ -34,7 +34,6 @@ void PreintegrationParams::print(const string& s) const { << endl; if (omegaCoriolis && use2ndOrderCoriolis) cout << "Using 2nd-order Coriolis" << endl; - if (body_P_sensor) body_P_sensor->print(" "); cout << "n_gravity = (" << n_gravity.transpose() << ")" << endl; } diff --git a/gtsam/navigation/navigation.i b/gtsam/navigation/navigation.i index 48a5a35de..6ede1645f 100644 --- a/gtsam/navigation/navigation.i +++ b/gtsam/navigation/navigation.i @@ -18,29 +18,21 @@ class ConstantBias { // Group static gtsam::imuBias::ConstantBias identity(); - gtsam::imuBias::ConstantBias inverse() const; - gtsam::imuBias::ConstantBias compose(const gtsam::imuBias::ConstantBias& b) const; - gtsam::imuBias::ConstantBias between(const gtsam::imuBias::ConstantBias& b) const; // Operator Overloads gtsam::imuBias::ConstantBias operator-() const; gtsam::imuBias::ConstantBias operator+(const gtsam::imuBias::ConstantBias& b) const; gtsam::imuBias::ConstantBias operator-(const gtsam::imuBias::ConstantBias& b) const; - // Manifold - gtsam::imuBias::ConstantBias retract(Vector v) const; - Vector localCoordinates(const gtsam::imuBias::ConstantBias& b) const; - - // Lie Group - static gtsam::imuBias::ConstantBias Expmap(Vector v); - static Vector Logmap(const gtsam::imuBias::ConstantBias& b); - // Standard Interface Vector vector() const; Vector accelerometer() const; Vector gyroscope() const; Vector correctAccelerometer(Vector measurement) const; Vector correctGyroscope(Vector measurement) const; + + // enabling serialization functionality + void serialize() const; }; }///\namespace imuBias @@ -64,6 +56,9 @@ class NavState { gtsam::NavState retract(const Vector& x) const; Vector localCoordinates(const gtsam::NavState& g) const; + + // enabling serialization functionality + void serialize() const; }; #include @@ -88,6 +83,8 @@ virtual class PreintegratedRotationParams { virtual class PreintegrationParams : gtsam::PreintegratedRotationParams { PreintegrationParams(Vector n_gravity); + gtsam::Vector n_gravity; + static gtsam::PreintegrationParams* MakeSharedD(double g); static gtsam::PreintegrationParams* MakeSharedU(double g); static gtsam::PreintegrationParams* MakeSharedD(); // default g = 9.81 @@ -104,6 +101,9 @@ virtual class PreintegrationParams : gtsam::PreintegratedRotationParams { Matrix getAccelerometerCovariance() const; Matrix getIntegrationCovariance() const; bool getUse2ndOrderCoriolis() const; + + // enabling serialization functionality + void serialize() const; }; #include @@ -133,6 +133,9 @@ class PreintegratedImuMeasurements { Vector biasHatVector() const; gtsam::NavState predict(const gtsam::NavState& state_i, const gtsam::imuBias::ConstantBias& bias) const; + + // enabling serialization functionality + void serialize() const; }; virtual class ImuFactor: gtsam::NonlinearFactor { diff --git a/gtsam/navigation/tests/imuFactorTesting.h b/gtsam/navigation/tests/imuFactorTesting.h index 5aa83e87e..6160db2a1 100644 --- a/gtsam/navigation/tests/imuFactorTesting.h +++ b/gtsam/navigation/tests/imuFactorTesting.h @@ -28,6 +28,7 @@ using symbol_shorthand::X; using symbol_shorthand::V; using symbol_shorthand::B; +namespace { static const Vector3 kZero = Z_3x1; typedef imuBias::ConstantBias Bias; static const Bias kZeroBiasHat, kZeroBias; @@ -43,6 +44,7 @@ static const Vector3 kGravityAlongNavZDown(0, 0, kGravity); auto radians = [](double t) { return t * M_PI / 180; }; static const double kGyroSigma = radians(0.5) / 60; // 0.5 degree ARW static const double kAccelSigma = 0.1 / 60; // 10 cm VRW +} namespace testing { diff --git a/gtsam/navigation/tests/testAHRSFactor.cpp b/gtsam/navigation/tests/testAHRSFactor.cpp index a4d06d01a..779f6abcc 100644 --- a/gtsam/navigation/tests/testAHRSFactor.cpp +++ b/gtsam/navigation/tests/testAHRSFactor.cpp @@ -54,11 +54,11 @@ Rot3 evaluateRotationError(const AHRSFactor& factor, const Rot3 rot_i, return Rot3::Expmap(factor.evaluateError(rot_i, rot_j, bias).tail(3)); } -AHRSFactor::PreintegratedMeasurements evaluatePreintegratedMeasurements( +PreintegratedAhrsMeasurements evaluatePreintegratedMeasurements( const Vector3& bias, const list& measuredOmegas, const list& deltaTs, const Vector3& initialRotationRate = Vector3::Zero()) { - AHRSFactor::PreintegratedMeasurements result(bias, I_3x3); + PreintegratedAhrsMeasurements result(bias, I_3x3); list::const_iterator itOmega = measuredOmegas.begin(); list::const_iterator itDeltaT = deltaTs.begin(); @@ -86,10 +86,10 @@ Rot3 evaluateRotation(const Vector3 measuredOmega, const Vector3 biasOmega, Vector3 evaluateLogRotation(const Vector3 thetahat, const Vector3 deltatheta) { return Rot3::Logmap(Rot3::Expmap(thetahat).compose(Rot3::Expmap(deltatheta))); } - } + //****************************************************************************** -TEST( AHRSFactor, PreintegratedMeasurements ) { +TEST( AHRSFactor, PreintegratedAhrsMeasurements ) { // Linearization point Vector3 bias(0,0,0); ///< Current estimate of angular rate bias @@ -102,7 +102,7 @@ TEST( AHRSFactor, PreintegratedMeasurements ) { double expectedDeltaT1(0.5); // Actual preintegrated values - AHRSFactor::PreintegratedMeasurements actual1(bias, Z_3x3); + PreintegratedAhrsMeasurements actual1(bias, Z_3x3); actual1.integrateMeasurement(measuredOmega, deltaT); EXPECT(assert_equal(expectedDeltaR1, Rot3(actual1.deltaRij()), 1e-6)); @@ -113,7 +113,7 @@ TEST( AHRSFactor, PreintegratedMeasurements ) { double expectedDeltaT2(1); // Actual preintegrated values - AHRSFactor::PreintegratedMeasurements actual2 = actual1; + PreintegratedAhrsMeasurements actual2 = actual1; actual2.integrateMeasurement(measuredOmega, deltaT); EXPECT(assert_equal(expectedDeltaR2, Rot3(actual2.deltaRij()), 1e-6)); @@ -159,7 +159,7 @@ TEST(AHRSFactor, Error) { Vector3 measuredOmega; measuredOmega << M_PI / 100, 0, 0; double deltaT = 1.0; - AHRSFactor::PreintegratedMeasurements pim(bias, Z_3x3); + PreintegratedAhrsMeasurements pim(bias, Z_3x3); pim.integrateMeasurement(measuredOmega, deltaT); // Create factor @@ -217,7 +217,7 @@ TEST(AHRSFactor, ErrorWithBiases) { measuredOmega << 0, 0, M_PI / 10.0 + 0.3; double deltaT = 1.0; - AHRSFactor::PreintegratedMeasurements pim(Vector3(0,0,0), + PreintegratedAhrsMeasurements pim(Vector3(0,0,0), Z_3x3); pim.integrateMeasurement(measuredOmega, deltaT); @@ -360,7 +360,7 @@ TEST( AHRSFactor, FirstOrderPreIntegratedMeasurements ) { } // Actual preintegrated values - AHRSFactor::PreintegratedMeasurements preintegrated = + PreintegratedAhrsMeasurements preintegrated = evaluatePreintegratedMeasurements(bias, measuredOmegas, deltaTs, Vector3(M_PI / 100.0, 0.0, 0.0)); @@ -397,7 +397,7 @@ TEST( AHRSFactor, ErrorWithBiasesAndSensorBodyDisplacement ) { const Pose3 body_P_sensor(Rot3::Expmap(Vector3(0, 0.10, 0.10)), Point3(1, 0, 0)); - AHRSFactor::PreintegratedMeasurements pim(Vector3::Zero(), kMeasuredAccCovariance); + PreintegratedAhrsMeasurements pim(Vector3::Zero(), kMeasuredAccCovariance); pim.integrateMeasurement(measuredOmega, deltaT); @@ -439,7 +439,7 @@ TEST (AHRSFactor, predictTest) { Vector3 measuredOmega; measuredOmega << 0, 0, M_PI / 10.0; double deltaT = 0.2; - AHRSFactor::PreintegratedMeasurements pim(bias, kMeasuredAccCovariance); + PreintegratedAhrsMeasurements pim(bias, kMeasuredAccCovariance); for (int i = 0; i < 1000; ++i) { pim.integrateMeasurement(measuredOmega, deltaT); } @@ -456,9 +456,9 @@ TEST (AHRSFactor, predictTest) { Rot3 actualRot = factor.predict(x, bias, pim, kZeroOmegaCoriolis); EXPECT(assert_equal(expectedRot, actualRot, 1e-6)); - // AHRSFactor::PreintegratedMeasurements::predict + // PreintegratedAhrsMeasurements::predict Matrix expectedH = numericalDerivative11( - std::bind(&AHRSFactor::PreintegratedMeasurements::predict, + std::bind(&PreintegratedAhrsMeasurements::predict, &pim, std::placeholders::_1, boost::none), bias); // Actual Jacobians @@ -478,7 +478,7 @@ TEST (AHRSFactor, graphTest) { // PreIntegrator Vector3 biasHat(0, 0, 0); - AHRSFactor::PreintegratedMeasurements pim(biasHat, kMeasuredAccCovariance); + PreintegratedAhrsMeasurements pim(biasHat, kMeasuredAccCovariance); // Pre-integrate measurements Vector3 measuredOmega(0, M_PI / 20, 0); diff --git a/gtsam/navigation/tests/testAttitudeFactor.cpp b/gtsam/navigation/tests/testAttitudeFactor.cpp index d49907cbf..26d793528 100644 --- a/gtsam/navigation/tests/testAttitudeFactor.cpp +++ b/gtsam/navigation/tests/testAttitudeFactor.cpp @@ -19,8 +19,6 @@ #include #include #include -#include -#include #include #include @@ -63,22 +61,6 @@ TEST( Rot3AttitudeFactor, Constructor ) { EXPECT(assert_equal(expectedH, actualH, 1e-8)); } -/* ************************************************************************* */ -// Export Noisemodels -// See http://www.boost.org/doc/libs/1_32_0/libs/serialization/doc/special.html -BOOST_CLASS_EXPORT(gtsam::noiseModel::Isotropic); - -/* ************************************************************************* */ -TEST(Rot3AttitudeFactor, Serialization) { - Unit3 nDown(0, 0, -1); - SharedNoiseModel model = noiseModel::Isotropic::Sigma(2, 0.25); - Rot3AttitudeFactor factor(0, nDown, model); - - EXPECT(serializationTestHelpers::equalsObj(factor)); - EXPECT(serializationTestHelpers::equalsXML(factor)); - EXPECT(serializationTestHelpers::equalsBinary(factor)); -} - /* ************************************************************************* */ TEST(Rot3AttitudeFactor, CopyAndMove) { Unit3 nDown(0, 0, -1); @@ -129,17 +111,6 @@ TEST( Pose3AttitudeFactor, Constructor ) { EXPECT(assert_equal(expectedH, actualH, 1e-8)); } -/* ************************************************************************* */ -TEST(Pose3AttitudeFactor, Serialization) { - Unit3 nDown(0, 0, -1); - SharedNoiseModel model = noiseModel::Isotropic::Sigma(2, 0.25); - Pose3AttitudeFactor factor(0, nDown, model); - - EXPECT(serializationTestHelpers::equalsObj(factor)); - EXPECT(serializationTestHelpers::equalsXML(factor)); - EXPECT(serializationTestHelpers::equalsBinary(factor)); -} - /* ************************************************************************* */ TEST(Pose3AttitudeFactor, CopyAndMove) { Unit3 nDown(0, 0, -1); diff --git a/gtsam/navigation/tests/testBarometricFactor.cpp b/gtsam/navigation/tests/testBarometricFactor.cpp new file mode 100644 index 000000000..47f4824c1 --- /dev/null +++ b/gtsam/navigation/tests/testBarometricFactor.cpp @@ -0,0 +1,129 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file testBarometricFactor.cpp + * @brief Unit test for BarometricFactor + * @author Peter Milani + * @date 16 Dec, 2021 + */ + +#include +#include +#include +#include + +#include + +using namespace std::placeholders; +using namespace std; +using namespace gtsam; + +// ************************************************************************* +namespace example {} + +double metersToBaro(const double& meters) { + double temp = 15.04 - 0.00649 * meters; + return 101.29 * std::pow(((temp + 273.1) / 288.08), 5.256); +} + +// ************************************************************************* +TEST(BarometricFactor, Constructor) { + using namespace example; + + // meters to barometric. + + double baroMeasurement = metersToBaro(10.); + + // Factor + Key key(1); + Key key2(2); + SharedNoiseModel model = noiseModel::Isotropic::Sigma(1, 0.25); + BarometricFactor factor(key, key2, baroMeasurement, model); + + // Create a linearization point at zero error + Pose3 T(Rot3::RzRyRx(0., 0., 0.), Point3(0., 0., 10.)); + double baroBias = 0.; + Vector1 zero; + zero << 0.; + EXPECT(assert_equal(zero, factor.evaluateError(T, baroBias), 1e-5)); + + // Calculate numerical derivatives + Matrix expectedH = numericalDerivative21( + std::bind(&BarometricFactor::evaluateError, &factor, + std::placeholders::_1, std::placeholders::_2, boost::none, + boost::none), + T, baroBias); + + Matrix expectedH2 = numericalDerivative22( + std::bind(&BarometricFactor::evaluateError, &factor, + std::placeholders::_1, std::placeholders::_2, boost::none, + boost::none), + T, baroBias); + + // Use the factor to calculate the derivative + Matrix actualH, actualH2; + factor.evaluateError(T, baroBias, actualH, actualH2); + + // Verify we get the expected error + EXPECT(assert_equal(expectedH, actualH, 1e-8)); + EXPECT(assert_equal(expectedH2, actualH2, 1e-8)); +} + +// ************************************************************************* + +//*************************************************************************** +TEST(BarometricFactor, nonZero) { + using namespace example; + + // meters to barometric. + + double baroMeasurement = metersToBaro(10.); + + // Factor + Key key(1); + Key key2(2); + SharedNoiseModel model = noiseModel::Isotropic::Sigma(1, 0.25); + BarometricFactor factor(key, key2, baroMeasurement, model); + + Pose3 T(Rot3::RzRyRx(0.5, 1., 1.), Point3(20., 30., 1.)); + double baroBias = 5.; + + // Calculate numerical derivatives + Matrix expectedH = numericalDerivative21( + std::bind(&BarometricFactor::evaluateError, &factor, + std::placeholders::_1, std::placeholders::_2, boost::none, + boost::none), + T, baroBias); + + Matrix expectedH2 = numericalDerivative22( + std::bind(&BarometricFactor::evaluateError, &factor, + std::placeholders::_1, std::placeholders::_2, boost::none, + boost::none), + T, baroBias); + + // Use the factor to calculate the derivative and the error + Matrix actualH, actualH2; + Vector error = factor.evaluateError(T, baroBias, actualH, actualH2); + Vector actual = (Vector(1) << -4.0).finished(); + + // Verify we get the expected error + EXPECT(assert_equal(expectedH, actualH, 1e-8)); + EXPECT(assert_equal(expectedH2, actualH2, 1e-8)); + EXPECT(assert_equal(error, actual, 1e-8)); +} + +// ************************************************************************* +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +// ************************************************************************* diff --git a/gtsam/navigation/tests/testGPSFactor.cpp b/gtsam/navigation/tests/testGPSFactor.cpp index b784c0c94..5607add16 100644 --- a/gtsam/navigation/tests/testGPSFactor.cpp +++ b/gtsam/navigation/tests/testGPSFactor.cpp @@ -27,7 +27,6 @@ #include #include -using namespace std::placeholders; using namespace std; using namespace gtsam; using namespace GeographicLib; @@ -71,8 +70,8 @@ TEST( GPSFactor, Constructor ) { EXPECT(assert_equal(Z_3x1,factor.evaluateError(T),1e-5)); // Calculate numerical derivatives - Matrix expectedH = numericalDerivative11( - std::bind(&GPSFactor::evaluateError, &factor, _1, boost::none), T); + Matrix expectedH = numericalDerivative11( + std::bind(&GPSFactor::evaluateError, &factor, std::placeholders::_1, boost::none), T); // Use the factor to calculate the derivative Matrix actualH; @@ -100,8 +99,8 @@ TEST( GPSFactor2, Constructor ) { EXPECT(assert_equal(Z_3x1,factor.evaluateError(T),1e-5)); // Calculate numerical derivatives - Matrix expectedH = numericalDerivative11( - std::bind(&GPSFactor2::evaluateError, &factor, _1, boost::none), T); + Matrix expectedH = numericalDerivative11( + std::bind(&GPSFactor2::evaluateError, &factor, std::placeholders::_1, boost::none), T); // Use the factor to calculate the derivative Matrix actualH; diff --git a/gtsam/navigation/tests/testImuBias.cpp b/gtsam/navigation/tests/testImuBias.cpp index b486a4a98..81a1a2ceb 100644 --- a/gtsam/navigation/tests/testImuBias.cpp +++ b/gtsam/navigation/tests/testImuBias.cpp @@ -47,20 +47,19 @@ TEST(ImuBias, Constructor) { } /* ************************************************************************* */ +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 TEST(ImuBias, inverse) { Bias biasActual = bias1.inverse(); Bias biasExpected = Bias(-biasAcc1, -biasGyro1); EXPECT(assert_equal(biasExpected, biasActual)); } -/* ************************************************************************* */ TEST(ImuBias, compose) { Bias biasActual = bias2.compose(bias1); Bias biasExpected = Bias(biasAcc1 + biasAcc2, biasGyro1 + biasGyro2); EXPECT(assert_equal(biasExpected, biasActual)); } -/* ************************************************************************* */ TEST(ImuBias, between) { // p.between(q) == q - p Bias biasActual = bias2.between(bias1); @@ -68,7 +67,6 @@ TEST(ImuBias, between) { EXPECT(assert_equal(biasExpected, biasActual)); } -/* ************************************************************************* */ TEST(ImuBias, localCoordinates) { Vector deltaActual = Vector(bias2.localCoordinates(bias1)); Vector deltaExpected = @@ -76,7 +74,6 @@ TEST(ImuBias, localCoordinates) { EXPECT(assert_equal(deltaExpected, deltaActual)); } -/* ************************************************************************* */ TEST(ImuBias, retract) { Vector6 delta; delta << 0.1, 0.2, -0.3, 0.1, -0.1, 0.2; @@ -86,14 +83,12 @@ TEST(ImuBias, retract) { EXPECT(assert_equal(biasExpected, biasActual)); } -/* ************************************************************************* */ TEST(ImuBias, Logmap) { Vector deltaActual = bias2.Logmap(bias1); Vector deltaExpected = bias1.vector(); EXPECT(assert_equal(deltaExpected, deltaActual)); } -/* ************************************************************************* */ TEST(ImuBias, Expmap) { Vector6 delta; delta << 0.1, 0.2, -0.3, 0.1, -0.1, 0.2; @@ -101,6 +96,7 @@ TEST(ImuBias, Expmap) { Bias biasExpected = Bias(delta); EXPECT(assert_equal(biasExpected, biasActual)); } +#endif /* ************************************************************************* */ TEST(ImuBias, operatorSub) { diff --git a/gtsam/navigation/tests/testMagFactor.cpp b/gtsam/navigation/tests/testMagFactor.cpp index 5107b3b6b..971803dbf 100644 --- a/gtsam/navigation/tests/testMagFactor.cpp +++ b/gtsam/navigation/tests/testMagFactor.cpp @@ -26,12 +26,11 @@ #include -using namespace std::placeholders; using namespace std; using namespace gtsam; using namespace GeographicLib; -// ************************************************************************* +namespace { // Convert from Mag to ENU // ENU Origin is where the plane was in hold next to runway // const double lat0 = 33.86998, lon0 = -84.30626, h0 = 274; @@ -51,10 +50,11 @@ Point3 bias(10, -10, 50); Point3 scaled = scale * nM; Point3 measured = nRb.inverse() * (scale * nM) + bias; -double s(scale * nM.norm()); +double s(scale* nM.norm()); Unit3 dir(nM); SharedNoiseModel model = noiseModel::Isotropic::Sigma(3, 0.25); +} // namespace using boost::none; @@ -63,8 +63,8 @@ TEST( MagFactor, unrotate ) { Matrix H; Point3 expected(22735.5, 314.502, 44202.5); EXPECT( assert_equal(expected, MagFactor::unrotate(theta,nM,H),1e-1)); - EXPECT( assert_equal(numericalDerivative11 // - (std::bind(&MagFactor::unrotate, _1, nM, none), theta), H, 1e-6)); + EXPECT(assert_equal(numericalDerivative11 // + (std::bind(&MagFactor::unrotate, std::placeholders::_1, nM, none), theta), H, 1e-6)); } // ************************************************************************* @@ -74,37 +74,37 @@ TEST( MagFactor, Factors ) { // MagFactor MagFactor f(1, measured, s, dir, bias, model); - EXPECT( assert_equal(Z_3x1,f.evaluateError(theta,H1),1e-5)); - EXPECT( assert_equal((Matrix)numericalDerivative11 // - (std::bind(&MagFactor::evaluateError, &f, _1, none), theta), H1, 1e-7)); + EXPECT(assert_equal(Z_3x1,f.evaluateError(theta,H1),1e-5)); + EXPECT(assert_equal((Matrix)numericalDerivative11 // + (std::bind(&MagFactor::evaluateError, &f, std::placeholders::_1, none), theta), H1, 1e-7)); -// MagFactor1 + // MagFactor1 MagFactor1 f1(1, measured, s, dir, bias, model); - EXPECT( assert_equal(Z_3x1,f1.evaluateError(nRb,H1),1e-5)); - EXPECT( assert_equal(numericalDerivative11 // - (std::bind(&MagFactor1::evaluateError, &f1, _1, none), nRb), H1, 1e-7)); + EXPECT(assert_equal(Z_3x1,f1.evaluateError(nRb,H1),1e-5)); + EXPECT(assert_equal(numericalDerivative11 // + (std::bind(&MagFactor1::evaluateError, &f1, std::placeholders::_1, none), nRb), H1, 1e-7)); -// MagFactor2 + // MagFactor2 MagFactor2 f2(1, 2, measured, nRb, model); - EXPECT( assert_equal(Z_3x1,f2.evaluateError(scaled,bias,H1,H2),1e-5)); - EXPECT( assert_equal(numericalDerivative11 // - (std::bind(&MagFactor2::evaluateError, &f2, _1, bias, none, none), scaled),// + EXPECT(assert_equal(Z_3x1,f2.evaluateError(scaled,bias,H1,H2),1e-5)); + EXPECT(assert_equal(numericalDerivative11 // + (std::bind(&MagFactor2::evaluateError, &f2, std::placeholders::_1, bias, none, none), scaled),// H1, 1e-7)); - EXPECT( assert_equal(numericalDerivative11 // - (std::bind(&MagFactor2::evaluateError, &f2, scaled, _1, none, none), bias),// + EXPECT(assert_equal(numericalDerivative11 // + (std::bind(&MagFactor2::evaluateError, &f2, scaled, std::placeholders::_1, none, none), bias),// H2, 1e-7)); -// MagFactor2 + // MagFactor3 MagFactor3 f3(1, 2, 3, measured, nRb, model); EXPECT(assert_equal(Z_3x1,f3.evaluateError(s,dir,bias,H1,H2,H3),1e-5)); EXPECT(assert_equal((Matrix)numericalDerivative11 // - (std::bind(&MagFactor3::evaluateError, &f3, _1, dir, bias, none, none, none), s),// + (std::bind(&MagFactor3::evaluateError, &f3, std::placeholders::_1, dir, bias, none, none, none), s),// H1, 1e-7)); EXPECT(assert_equal(numericalDerivative11 // - (std::bind(&MagFactor3::evaluateError, &f3, s, _1, bias, none, none, none), dir),// + (std::bind(&MagFactor3::evaluateError, &f3, s, std::placeholders::_1, bias, none, none, none), dir),// H2, 1e-7)); EXPECT(assert_equal(numericalDerivative11 // - (std::bind(&MagFactor3::evaluateError, &f3, s, dir, _1, none, none, none), bias),// + (std::bind(&MagFactor3::evaluateError, &f3, s, dir, std::placeholders::_1, none, none, none), bias),// H3, 1e-7)); } diff --git a/gtsam/navigation/tests/testMagPoseFactor.cpp b/gtsam/navigation/tests/testMagPoseFactor.cpp index 204c1d38f..e10409f4c 100644 --- a/gtsam/navigation/tests/testMagPoseFactor.cpp +++ b/gtsam/navigation/tests/testMagPoseFactor.cpp @@ -20,7 +20,7 @@ using namespace std::placeholders; using namespace gtsam; -// ***************************************************************************** +namespace { // Magnetic field in the nav frame (NED), with units of nT. Point3 nM(22653.29982, -1956.83010, 44202.47862); @@ -51,8 +51,9 @@ SharedNoiseModel model3 = noiseModel::Isotropic::Sigma(3, 0.25); // Make up a rotation and offset of the sensor in the body frame. Pose2 body_P2_sensor(Rot2(-0.30), Point2(1.0, -2.0)); -Pose3 body_P3_sensor(Rot3::RzRyRx(Vector3(1.5, 0.9, -1.15)), Point3(-0.1, 0.2, 0.3)); -// ***************************************************************************** +Pose3 body_P3_sensor(Rot3::RzRyRx(Vector3(1.5, 0.9, -1.15)), + Point3(-0.1, 0.2, 0.3)); +} // namespace // ***************************************************************************** TEST(MagPoseFactor, Constructors) { diff --git a/gtsam/navigation/tests/testImuFactorSerialization.cpp b/gtsam/navigation/tests/testSerializationNavigation.cpp similarity index 60% rename from gtsam/navigation/tests/testImuFactorSerialization.cpp rename to gtsam/navigation/tests/testSerializationNavigation.cpp index ed72e18e9..6a2155875 100644 --- a/gtsam/navigation/tests/testImuFactorSerialization.cpp +++ b/gtsam/navigation/tests/testSerializationNavigation.cpp @@ -10,17 +10,19 @@ * -------------------------------------------------------------------------- */ /** - * @file testImuFactor.cpp - * @brief Unit test for ImuFactor + * @file testSerializationNavigation.cpp + * @brief serialization tests for navigation * @author Luca Carlone * @author Frank Dellaert * @author Richard Roberts * @author Stephen Williams * @author Varun Agrawal + * @date February 2022 */ #include #include +#include #include #include @@ -30,17 +32,16 @@ using namespace std; using namespace gtsam; using namespace gtsam::serializationTestHelpers; -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Constrained, - "gtsam_noiseModel_Constrained"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Diagonal, - "gtsam_noiseModel_Diagonal"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Gaussian, - "gtsam_noiseModel_Gaussian"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Unit, "gtsam_noiseModel_Unit"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Isotropic, - "gtsam_noiseModel_Isotropic"); -BOOST_CLASS_EXPORT_GUID(gtsam::SharedNoiseModel, "gtsam_SharedNoiseModel"); -BOOST_CLASS_EXPORT_GUID(gtsam::SharedDiagonal, "gtsam_SharedDiagonal"); +BOOST_CLASS_EXPORT_GUID(noiseModel::Constrained, "gtsam_noiseModel_Constrained") +BOOST_CLASS_EXPORT_GUID(noiseModel::Diagonal, "gtsam_noiseModel_Diagonal") +BOOST_CLASS_EXPORT_GUID(noiseModel::Gaussian, "gtsam_noiseModel_Gaussian") +BOOST_CLASS_EXPORT_GUID(noiseModel::Unit, "gtsam_noiseModel_Unit") +BOOST_CLASS_EXPORT_GUID(noiseModel::Isotropic, "gtsam_noiseModel_Isotropic") +BOOST_CLASS_EXPORT_GUID(SharedNoiseModel, "gtsam_SharedNoiseModel") +BOOST_CLASS_EXPORT_GUID(SharedDiagonal, "gtsam_SharedDiagonal") +BOOST_CLASS_EXPORT_GUID(PreintegratedImuMeasurements, "gtsam_PreintegratedImuMeasurements") +BOOST_CLASS_EXPORT_GUID(PreintegrationCombinedParams, "gtsam_PreintegrationCombinedParams") +BOOST_CLASS_EXPORT_GUID(PreintegratedCombinedMeasurements, "gtsam_PreintegratedCombinedMeasurements") template P getPreintegratedMeasurements() { @@ -69,6 +70,7 @@ P getPreintegratedMeasurements() { return pim; } +/* ************************************************************************* */ TEST(ImuFactor, serialization) { auto pim = getPreintegratedMeasurements(); @@ -83,6 +85,7 @@ TEST(ImuFactor, serialization) { EXPECT(equalsBinary(factor)); } +/* ************************************************************************* */ TEST(ImuFactor2, serialization) { auto pim = getPreintegratedMeasurements(); @@ -93,6 +96,7 @@ TEST(ImuFactor2, serialization) { EXPECT(equalsBinary(factor)); } +/* ************************************************************************* */ TEST(CombinedImuFactor, Serialization) { auto pim = getPreintegratedMeasurements(); @@ -107,6 +111,28 @@ TEST(CombinedImuFactor, Serialization) { EXPECT(equalsBinary(factor)); } +/* ************************************************************************* */ +TEST(Rot3AttitudeFactor, Serialization) { + Unit3 nDown(0, 0, -1); + SharedNoiseModel model = noiseModel::Isotropic::Sigma(2, 0.25); + Rot3AttitudeFactor factor(0, nDown, model); + + EXPECT(serializationTestHelpers::equalsObj(factor)); + EXPECT(serializationTestHelpers::equalsXML(factor)); + EXPECT(serializationTestHelpers::equalsBinary(factor)); +} + +/* ************************************************************************* */ +TEST(Pose3AttitudeFactor, Serialization) { + Unit3 nDown(0, 0, -1); + SharedNoiseModel model = noiseModel::Isotropic::Sigma(2, 0.25); + Pose3AttitudeFactor factor(0, nDown, model); + + EXPECT(serializationTestHelpers::equalsObj(factor)); + EXPECT(serializationTestHelpers::equalsXML(factor)); + EXPECT(serializationTestHelpers::equalsBinary(factor)); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/nonlinear/CustomFactor.h b/gtsam/nonlinear/CustomFactor.h index dbaf31898..615b5418e 100644 --- a/gtsam/nonlinear/CustomFactor.h +++ b/gtsam/nonlinear/CustomFactor.h @@ -101,4 +101,4 @@ private: } }; -}; +} diff --git a/gtsam/nonlinear/DoglegOptimizerImpl.cpp b/gtsam/nonlinear/DoglegOptimizerImpl.cpp index c319f26e6..7e9db6b64 100644 --- a/gtsam/nonlinear/DoglegOptimizerImpl.cpp +++ b/gtsam/nonlinear/DoglegOptimizerImpl.cpp @@ -78,7 +78,8 @@ VectorValues DoglegOptimizerImpl::ComputeBlend(double delta, const VectorValues& // Compute blended point if(verbose) cout << "In blend region with fraction " << tau << " of Newton's method point" << endl; - VectorValues blend = (1. - tau) * x_u; axpy(tau, x_n, blend); + VectorValues blend = (1. - tau) * x_u; + blend += tau * x_n; return blend; } diff --git a/gtsam/nonlinear/Expression-inl.h b/gtsam/nonlinear/Expression-inl.h index cf2462dfc..b2ef6f055 100644 --- a/gtsam/nonlinear/Expression-inl.h +++ b/gtsam/nonlinear/Expression-inl.h @@ -246,6 +246,18 @@ struct apply_compose { return x.compose(y, H1, H2); } }; + +template <> +struct apply_compose { + double operator()(const double& x, const double& y, + OptionalJacobian<1, 1> H1 = boost::none, + OptionalJacobian<1, 1> H2 = boost::none) const { + if (H1) H1->setConstant(y); + if (H2) H2->setConstant(x); + return x * y; + } +}; + } // Global methods: diff --git a/gtsam/nonlinear/ExpressionFactor.h b/gtsam/nonlinear/ExpressionFactor.h index b55d643aa..11bf873e7 100644 --- a/gtsam/nonlinear/ExpressionFactor.h +++ b/gtsam/nonlinear/ExpressionFactor.h @@ -295,17 +295,17 @@ struct traits> // ExpressionFactorN -#if defined(GTSAM_ALLOW_DEPRECATED_SINCE_V41) +#if defined(GTSAM_ALLOW_DEPRECATED_SINCE_V42) /** * Binary specialization of ExpressionFactor meant as a base class for binary * factors. Enforces an 'expression' method with two keys, and provides * 'evaluateError'. Derived class (a binary factor!) needs to call 'initialize'. * * \sa ExpressionFactorN - * \deprecated Prefer the more general ExpressionFactorN<>. + * @deprecated Prefer the more general ExpressionFactorN<>. */ template -class ExpressionFactor2 : public ExpressionFactorN { +class GTSAM_DEPRECATED ExpressionFactor2 : public ExpressionFactorN { public: /// Destructor ~ExpressionFactor2() override {} diff --git a/gtsam/nonlinear/ExtendedKalmanFilter.h b/gtsam/nonlinear/ExtendedKalmanFilter.h index 77bb1ca6c..df27d16ff 100644 --- a/gtsam/nonlinear/ExtendedKalmanFilter.h +++ b/gtsam/nonlinear/ExtendedKalmanFilter.h @@ -51,9 +51,11 @@ class ExtendedKalmanFilter { typedef boost::shared_ptr > shared_ptr; typedef VALUE T; +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 //@deprecated: any NoiseModelFactor will do, as long as they have the right keys typedef NoiseModelFactor2 MotionFactor; typedef NoiseModelFactor1 MeasurementFactor; +#endif protected: T x_; // linearization point diff --git a/gtsam/nonlinear/FunctorizedFactor.h b/gtsam/nonlinear/FunctorizedFactor.h index e1f8ece8d..394b22b6b 100644 --- a/gtsam/nonlinear/FunctorizedFactor.h +++ b/gtsam/nonlinear/FunctorizedFactor.h @@ -56,7 +56,7 @@ namespace gtsam { * MultiplyFunctor(multiplier)); */ template -class GTSAM_EXPORT FunctorizedFactor : public NoiseModelFactor1 { +class FunctorizedFactor : public NoiseModelFactor1 { private: using Base = NoiseModelFactor1; @@ -155,7 +155,7 @@ FunctorizedFactor MakeFunctorizedFactor(Key key, const R &z, * @param T2: The second argument type for the functor. */ template -class GTSAM_EXPORT FunctorizedFactor2 : public NoiseModelFactor2 { +class FunctorizedFactor2 : public NoiseModelFactor2 { private: using Base = NoiseModelFactor2; diff --git a/gtsam/nonlinear/GncOptimizer.h b/gtsam/nonlinear/GncOptimizer.h index 3ddaf4820..cc3fdaf34 100644 --- a/gtsam/nonlinear/GncOptimizer.h +++ b/gtsam/nonlinear/GncOptimizer.h @@ -142,8 +142,9 @@ class GTSAM_EXPORT GncOptimizer { * provides an extra interface for the user to initialize the weightst * */ void setWeights(const Vector w) { - if(w.size() != nfg_.size()){ - throw std::runtime_error("GncOptimizer::setWeights: the number of specified weights" + if (size_t(w.size()) != nfg_.size()) { + throw std::runtime_error( + "GncOptimizer::setWeights: the number of specified weights" " does not match the size of the factor graph."); } weights_ = w; @@ -183,7 +184,8 @@ class GTSAM_EXPORT GncOptimizer { /// Compute optimal solution using graduated non-convexity. Values optimize() { NonlinearFactorGraph graph_initial = this->makeWeightedGraph(weights_); - BaseOptimizer baseOptimizer(graph_initial, state_); + BaseOptimizer baseOptimizer( + graph_initial, state_, params_.baseOptimizerParams); Values result = baseOptimizer.optimize(); double mu = initializeMu(); double prev_cost = graph_initial.error(result); @@ -227,7 +229,8 @@ class GTSAM_EXPORT GncOptimizer { // variable/values update NonlinearFactorGraph graph_iter = this->makeWeightedGraph(weights_); - BaseOptimizer baseOptimizer_iter(graph_iter, state_); + BaseOptimizer baseOptimizer_iter( + graph_iter, state_, params_.baseOptimizerParams); result = baseOptimizer_iter.optimize(); // stopping condition diff --git a/gtsam/nonlinear/GraphvizFormatting.cpp b/gtsam/nonlinear/GraphvizFormatting.cpp new file mode 100644 index 000000000..ca3466b6a --- /dev/null +++ b/gtsam/nonlinear/GraphvizFormatting.cpp @@ -0,0 +1,145 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file GraphvizFormatting.cpp + * @brief Graphviz formatter for NonlinearFactorGraph + * @author Frank Dellaert + * @date December, 2021 + */ + +#include +#include + +// TODO(frank): nonlinear should not depend on geometry: +#include +#include + +#include + +namespace gtsam { + +Vector2 GraphvizFormatting::findBounds(const Values& values, + const KeySet& keys) const { + Vector2 min; + min.x() = std::numeric_limits::infinity(); + min.y() = std::numeric_limits::infinity(); + for (const Key& key : keys) { + if (values.exists(key)) { + boost::optional xy = extractPosition(values.at(key)); + if (xy) { + if (xy->x() < min.x()) min.x() = xy->x(); + if (xy->y() < min.y()) min.y() = xy->y(); + } + } + } + return min; +} + +boost::optional GraphvizFormatting::extractPosition( + const Value& value) const { + Vector3 t; + if (const GenericValue* p = + dynamic_cast*>(&value)) { + t << p->value().x(), p->value().y(), 0; + } else if (const GenericValue* p = + dynamic_cast*>(&value)) { + t << p->value().x(), p->value().y(), 0; + } else if (const GenericValue* p = + dynamic_cast*>(&value)) { + if (p->dim() == 2) { + const Eigen::Ref p_2d(p->value()); + t << p_2d.x(), p_2d.y(), 0; + } else if (p->dim() == 3) { + const Eigen::Ref p_3d(p->value()); + t = p_3d; + } else { + return boost::none; + } + } else if (const GenericValue* p = + dynamic_cast*>(&value)) { + t = p->value().translation(); + } else if (const GenericValue* p = + dynamic_cast*>(&value)) { + t = p->value(); + } else { + return boost::none; + } + double x, y; + switch (paperHorizontalAxis) { + case X: + x = t.x(); + break; + case Y: + x = t.y(); + break; + case Z: + x = t.z(); + break; + case NEGX: + x = -t.x(); + break; + case NEGY: + x = -t.y(); + break; + case NEGZ: + x = -t.z(); + break; + default: + throw std::runtime_error("Invalid enum value"); + } + switch (paperVerticalAxis) { + case X: + y = t.x(); + break; + case Y: + y = t.y(); + break; + case Z: + y = t.z(); + break; + case NEGX: + y = -t.x(); + break; + case NEGY: + y = -t.y(); + break; + case NEGZ: + y = -t.z(); + break; + default: + throw std::runtime_error("Invalid enum value"); + } + return Vector2(x, y); +} + +boost::optional GraphvizFormatting::variablePos(const Values& values, + const Vector2& min, + Key key) const { + if (!values.exists(key)) return DotWriter::variablePos(key); + boost::optional xy = extractPosition(values.at(key)); + if (xy) { + xy->x() = scale * (xy->x() - min.x()); + xy->y() = scale * (xy->y() - min.y()); + } + return xy; +} + +boost::optional GraphvizFormatting::factorPos(const Vector2& min, + size_t i) const { + if (factorPositions.size() == 0) return boost::none; + auto it = factorPositions.find(i); + if (it == factorPositions.end()) return boost::none; + auto pos = it->second; + return Vector2(scale * (pos.x() - min.x()), scale * (pos.y() - min.y())); +} + +} // namespace gtsam diff --git a/gtsam/nonlinear/GraphvizFormatting.h b/gtsam/nonlinear/GraphvizFormatting.h new file mode 100644 index 000000000..03cdb3469 --- /dev/null +++ b/gtsam/nonlinear/GraphvizFormatting.h @@ -0,0 +1,66 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file GraphvizFormatting.h + * @brief Graphviz formatter for NonlinearFactorGraph + * @author Frank Dellaert + * @date December, 2021 + */ + +#pragma once + +#include + +namespace gtsam { + +class Values; +class Value; + +/** + * Formatting options and functions for saving a NonlinearFactorGraph instance + * in GraphViz format. + */ +struct GTSAM_EXPORT GraphvizFormatting : public DotWriter { + /// World axes to be assigned to paper axes + enum Axis { X, Y, Z, NEGX, NEGY, NEGZ }; + + Axis paperHorizontalAxis; ///< The world axis assigned to the horizontal + ///< paper axis + Axis paperVerticalAxis; ///< The world axis assigned to the vertical paper + ///< axis + double scale; ///< Scale all positions to reduce / increase density + bool mergeSimilarFactors; ///< Merge multiple factors that have the same + ///< connectivity + + /// Default constructor sets up robot coordinates. Paper horizontal is robot + /// Y, paper vertical is robot X. Default figure size of 5x5 in. + GraphvizFormatting() + : paperHorizontalAxis(Y), + paperVerticalAxis(X), + scale(1), + mergeSimilarFactors(false) {} + + // Find bounds + Vector2 findBounds(const Values& values, const KeySet& keys) const; + + /// Extract a Vector2 from either Vector2, Pose2, Pose3, or Point3 + boost::optional extractPosition(const Value& value) const; + + /// Return affinely transformed variable position if it exists. + boost::optional variablePos(const Values& values, const Vector2& min, + Key key) const; + + /// Return affinely transformed factor position if it exists. + boost::optional factorPos(const Vector2& min, size_t i) const; +}; + +} // namespace gtsam diff --git a/gtsam/nonlinear/ISAM2.h b/gtsam/nonlinear/ISAM2.h index 92c2142a7..37feee837 100644 --- a/gtsam/nonlinear/ISAM2.h +++ b/gtsam/nonlinear/ISAM2.h @@ -295,6 +295,17 @@ class GTSAM_EXPORT ISAM2 : public BayesTree { const ISAM2UpdateParams& updateParams, const FastList& affectedKeys, const KeySet& relinKeys); + /** + * @brief Perform an incremental update of the factor graph to return a new + * Bayes Tree with affected keys. + * + * @param updateParams Parameters for the ISAM2 update. + * @param relinKeys Keys of variables to relinearize. + * @param affectedKeys The set of keys which are affected in the update. + * @param affectedKeysSet [output] Affected and contaminated keys. + * @param orphans [output] List of orphanes cliques after elimination. + * @param result [output] The result of the incremental update step. + */ void recalculateIncremental(const ISAM2UpdateParams& updateParams, const KeySet& relinKeys, const FastList& affectedKeys, diff --git a/gtsam/nonlinear/ISAM2Params.h b/gtsam/nonlinear/ISAM2Params.h index c6e1001c4..d88afd505 100644 --- a/gtsam/nonlinear/ISAM2Params.h +++ b/gtsam/nonlinear/ISAM2Params.h @@ -300,18 +300,10 @@ struct GTSAM_EXPORT ISAM2Params { RelinearizationThreshold getRelinearizeThreshold() const { return relinearizeThreshold; } - int getRelinearizeSkip() const { return relinearizeSkip; } - bool isEnableRelinearization() const { return enableRelinearization; } - bool isEvaluateNonlinearError() const { return evaluateNonlinearError; } std::string getFactorization() const { return factorizationTranslator(factorization); } - bool isCacheLinearizedFactors() const { return cacheLinearizedFactors; } KeyFormatter getKeyFormatter() const { return keyFormatter; } - bool isEnableDetailedResults() const { return enableDetailedResults; } - bool isEnablePartialRelinearizationCheck() const { - return enablePartialRelinearizationCheck; - } void setOptimizationParams(OptimizationParams optimizationParams) { this->optimizationParams = optimizationParams; @@ -319,31 +311,12 @@ struct GTSAM_EXPORT ISAM2Params { void setRelinearizeThreshold(RelinearizationThreshold relinearizeThreshold) { this->relinearizeThreshold = relinearizeThreshold; } - void setRelinearizeSkip(int relinearizeSkip) { - this->relinearizeSkip = relinearizeSkip; - } - void setEnableRelinearization(bool enableRelinearization) { - this->enableRelinearization = enableRelinearization; - } - void setEvaluateNonlinearError(bool evaluateNonlinearError) { - this->evaluateNonlinearError = evaluateNonlinearError; - } void setFactorization(const std::string& factorization) { this->factorization = factorizationTranslator(factorization); } - void setCacheLinearizedFactors(bool cacheLinearizedFactors) { - this->cacheLinearizedFactors = cacheLinearizedFactors; - } void setKeyFormatter(KeyFormatter keyFormatter) { this->keyFormatter = keyFormatter; } - void setEnableDetailedResults(bool enableDetailedResults) { - this->enableDetailedResults = enableDetailedResults; - } - void setEnablePartialRelinearizationCheck( - bool enablePartialRelinearizationCheck) { - this->enablePartialRelinearizationCheck = enablePartialRelinearizationCheck; - } GaussianFactorGraph::Eliminate getEliminationFunction() const { return factorization == CHOLESKY diff --git a/gtsam/nonlinear/ISAM2Result.h b/gtsam/nonlinear/ISAM2Result.h index b249af5c5..be91f17e2 100644 --- a/gtsam/nonlinear/ISAM2Result.h +++ b/gtsam/nonlinear/ISAM2Result.h @@ -175,6 +175,7 @@ struct ISAM2Result { /** Getters and Setters */ size_t getVariablesRelinearized() const { return variablesRelinearized; } size_t getVariablesReeliminated() const { return variablesReeliminated; } + FactorIndices getNewFactorsIndices() const { return newFactorsIndices; } size_t getCliques() const { return cliques; } double getErrorBefore() const { return errorBefore ? *errorBefore : std::nan(""); } double getErrorAfter() const { return errorAfter ? *errorAfter : std::nan(""); } diff --git a/gtsam/nonlinear/LevenbergMarquardtParams.h b/gtsam/nonlinear/LevenbergMarquardtParams.h index 1e2c6e395..f40443457 100644 --- a/gtsam/nonlinear/LevenbergMarquardtParams.h +++ b/gtsam/nonlinear/LevenbergMarquardtParams.h @@ -35,7 +35,7 @@ class LevenbergMarquardtOptimizer; class GTSAM_EXPORT LevenbergMarquardtParams: public NonlinearOptimizerParams { public: - /** See LevenbergMarquardtParams::lmVerbosity */ + /** See LevenbergMarquardtParams::verbosityLM */ enum VerbosityLM { SILENT = 0, SUMMARY, TERMINATION, LAMBDA, TRYLAMBDA, TRYCONFIG, DAMPED, TRYDELTA }; diff --git a/gtsam/nonlinear/LinearContainerFactor.h b/gtsam/nonlinear/LinearContainerFactor.h index 8c5b34f01..16094b67a 100644 --- a/gtsam/nonlinear/LinearContainerFactor.h +++ b/gtsam/nonlinear/LinearContainerFactor.h @@ -23,17 +23,14 @@ namespace gtsam { * This factor does have the ability to perform relinearization under small-angle and * linearity assumptions if a linearization point is added. */ -class LinearContainerFactor : public NonlinearFactor { +class GTSAM_EXPORT LinearContainerFactor : public NonlinearFactor { protected: GaussianFactor::shared_ptr factor_; boost::optional linearizationPoint_; - /** Default constructor - necessary for serialization */ - LinearContainerFactor() {} - /** direct copy constructor */ - GTSAM_EXPORT LinearContainerFactor(const GaussianFactor::shared_ptr& factor, const boost::optional& linearizationPoint); + LinearContainerFactor(const GaussianFactor::shared_ptr& factor, const boost::optional& linearizationPoint); // Some handy typedefs typedef NonlinearFactor Base; @@ -43,14 +40,17 @@ public: typedef boost::shared_ptr shared_ptr; - /** Primary constructor: store a linear factor with optional linearization point */ - GTSAM_EXPORT LinearContainerFactor(const JacobianFactor& factor, const Values& linearizationPoint = Values()); + /** Default constructor - necessary for serialization */ + LinearContainerFactor() {} /** Primary constructor: store a linear factor with optional linearization point */ - GTSAM_EXPORT LinearContainerFactor(const HessianFactor& factor, const Values& linearizationPoint = Values()); + LinearContainerFactor(const JacobianFactor& factor, const Values& linearizationPoint = Values()); + + /** Primary constructor: store a linear factor with optional linearization point */ + LinearContainerFactor(const HessianFactor& factor, const Values& linearizationPoint = Values()); /** Constructor from shared_ptr */ - GTSAM_EXPORT LinearContainerFactor(const GaussianFactor::shared_ptr& factor, const Values& linearizationPoint = Values()); + LinearContainerFactor(const GaussianFactor::shared_ptr& factor, const Values& linearizationPoint = Values()); // Access @@ -59,10 +59,10 @@ public: // Testable /** print */ - GTSAM_EXPORT void print(const std::string& s = "", const KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const override; + void print(const std::string& s = "", const KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const override; /** Check if two factors are equal */ - GTSAM_EXPORT bool equals(const NonlinearFactor& f, double tol = 1e-9) const override; + bool equals(const NonlinearFactor& f, double tol = 1e-9) const override; // NonlinearFactor @@ -74,10 +74,10 @@ public: * * @return nonlinear error if linearizationPoint present, zero otherwise */ - GTSAM_EXPORT double error(const Values& c) const override; + double error(const Values& c) const override; /** get the dimension of the factor: rows of linear factor */ - GTSAM_EXPORT size_t dim() const override; + size_t dim() const override; /** Extract the linearization point used in recalculating error */ const boost::optional& linearizationPoint() const { return linearizationPoint_; } @@ -98,17 +98,17 @@ public: * TODO: better approximation of relinearization * TODO: switchable modes for approximation technique */ - GTSAM_EXPORT GaussianFactor::shared_ptr linearize(const Values& c) const override; + GaussianFactor::shared_ptr linearize(const Values& c) const override; /** * Creates an anti-factor directly */ - GTSAM_EXPORT GaussianFactor::shared_ptr negateToGaussian() const; + GaussianFactor::shared_ptr negateToGaussian() const; /** * Creates the equivalent anti-factor as another LinearContainerFactor. */ - GTSAM_EXPORT NonlinearFactor::shared_ptr negateToNonlinear() const; + NonlinearFactor::shared_ptr negateToNonlinear() const; /** * Creates a shared_ptr clone of the factor - needs to be specialized to allow @@ -140,25 +140,24 @@ public: /** * Simple checks whether this is a Jacobian or Hessian factor */ - GTSAM_EXPORT bool isJacobian() const; - GTSAM_EXPORT bool isHessian() const; + bool isJacobian() const; + bool isHessian() const; /** Casts to JacobianFactor */ - GTSAM_EXPORT boost::shared_ptr toJacobian() const; + boost::shared_ptr toJacobian() const; /** Casts to HessianFactor */ - GTSAM_EXPORT boost::shared_ptr toHessian() const; + boost::shared_ptr toHessian() const; /** * Utility function for converting linear graphs to nonlinear graphs * consisting of LinearContainerFactors. */ - GTSAM_EXPORT static NonlinearFactorGraph ConvertLinearGraph(const GaussianFactorGraph& linear_graph, const Values& linearizationPoint = Values()); protected: - GTSAM_EXPORT void initializeLinearizationPoint(const Values& linearizationPoint); + void initializeLinearizationPoint(const Values& linearizationPoint); private: /** Serialization function */ diff --git a/gtsam/nonlinear/Marginals.cpp b/gtsam/nonlinear/Marginals.cpp index c29a79623..41212ed76 100644 --- a/gtsam/nonlinear/Marginals.cpp +++ b/gtsam/nonlinear/Marginals.cpp @@ -80,11 +80,15 @@ Marginals::Marginals(const GaussianFactorGraph& graph, const VectorValues& solut /* ************************************************************************* */ void Marginals::computeBayesTree() { + // The default ordering to use. + const Ordering::OrderingType defaultOrderingType = Ordering::COLAMD; // Compute BayesTree - if(factorization_ == CHOLESKY) - bayesTree_ = *graph_.eliminateMultifrontal(EliminatePreferCholesky); - else if(factorization_ == QR) - bayesTree_ = *graph_.eliminateMultifrontal(EliminateQR); + if (factorization_ == CHOLESKY) + bayesTree_ = *graph_.eliminateMultifrontal(defaultOrderingType, + EliminatePreferCholesky); + else if (factorization_ == QR) + bayesTree_ = + *graph_.eliminateMultifrontal(defaultOrderingType, EliminateQR); } /* ************************************************************************* */ diff --git a/gtsam/nonlinear/Marginals.h b/gtsam/nonlinear/Marginals.h index 9935bafdd..028545d01 100644 --- a/gtsam/nonlinear/Marginals.h +++ b/gtsam/nonlinear/Marginals.h @@ -131,17 +131,19 @@ protected: void computeBayesTree(const Ordering& ordering); public: - /** \deprecated argument order changed due to removing boost::optional */ - Marginals(const NonlinearFactorGraph& graph, const Values& solution, Factorization factorization, +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + /** @deprecated argument order changed due to removing boost::optional */ + GTSAM_DEPRECATED Marginals(const NonlinearFactorGraph& graph, const Values& solution, Factorization factorization, const Ordering& ordering) : Marginals(graph, solution, ordering, factorization) {} - /** \deprecated argument order changed due to removing boost::optional */ - Marginals(const GaussianFactorGraph& graph, const Values& solution, Factorization factorization, + /** @deprecated argument order changed due to removing boost::optional */ + GTSAM_DEPRECATED Marginals(const GaussianFactorGraph& graph, const Values& solution, Factorization factorization, const Ordering& ordering) : Marginals(graph, solution, ordering, factorization) {} - /** \deprecated argument order changed due to removing boost::optional */ - Marginals(const GaussianFactorGraph& graph, const VectorValues& solution, Factorization factorization, + /** @deprecated argument order changed due to removing boost::optional */ + GTSAM_DEPRECATED Marginals(const GaussianFactorGraph& graph, const VectorValues& solution, Factorization factorization, const Ordering& ordering) : Marginals(graph, solution, ordering, factorization) {} +#endif }; diff --git a/gtsam/nonlinear/NonlinearConjugateGradientOptimizer.h b/gtsam/nonlinear/NonlinearConjugateGradientOptimizer.h index fd9e49a62..a7a0d724b 100644 --- a/gtsam/nonlinear/NonlinearConjugateGradientOptimizer.h +++ b/gtsam/nonlinear/NonlinearConjugateGradientOptimizer.h @@ -149,7 +149,7 @@ boost::tuple nonlinearConjugateGradient(const S &system, const V &initial, const NonlinearOptimizerParams ¶ms, const bool singleIteration, const bool gradientDescent = false) { - // GTSAM_CONCEPT_MANIFOLD_TYPE(V); + // GTSAM_CONCEPT_MANIFOLD_TYPE(V) size_t iteration = 0; diff --git a/gtsam/nonlinear/NonlinearEquality.h b/gtsam/nonlinear/NonlinearEquality.h index 47083d5d7..43d30254e 100644 --- a/gtsam/nonlinear/NonlinearEquality.h +++ b/gtsam/nonlinear/NonlinearEquality.h @@ -219,7 +219,6 @@ protected: X value_; /// fixed value for variable GTSAM_CONCEPT_MANIFOLD_TYPE(X) - GTSAM_CONCEPT_TESTABLE_TYPE(X) public: diff --git a/gtsam/nonlinear/NonlinearFactor.cpp b/gtsam/nonlinear/NonlinearFactor.cpp index 8b8d2da6c..3d572e970 100644 --- a/gtsam/nonlinear/NonlinearFactor.cpp +++ b/gtsam/nonlinear/NonlinearFactor.cpp @@ -114,7 +114,7 @@ double NoiseModelFactor::weight(const Values& c) const { if (noiseModel_) { const Vector b = unwhitenedError(c); check(noiseModel_, b.size()); - return 0.5 * noiseModel_->weight(b); + return noiseModel_->weight(b); } else return 1.0; diff --git a/gtsam/nonlinear/NonlinearFactorGraph.cpp b/gtsam/nonlinear/NonlinearFactorGraph.cpp index 8e4cf277c..dfa54f26f 100644 --- a/gtsam/nonlinear/NonlinearFactorGraph.cpp +++ b/gtsam/nonlinear/NonlinearFactorGraph.cpp @@ -33,9 +33,10 @@ # include #endif +#include #include #include -#include +#include using namespace std; @@ -46,7 +47,8 @@ template class FactorGraph; /* ************************************************************************* */ double NonlinearFactorGraph::probPrime(const Values& values) const { - return exp(-0.5 * error(values)); + // NOTE the 0.5 constant is handled by the factor error. + return exp(-error(values)); } /* ************************************************************************* */ @@ -55,9 +57,14 @@ void NonlinearFactorGraph::print(const std::string& str, const KeyFormatter& key for (size_t i = 0; i < factors_.size(); i++) { stringstream ss; ss << "Factor " << i << ": "; - if (factors_[i] != nullptr) factors_[i]->print(ss.str(), keyFormatter); - cout << endl; + if (factors_[i] != nullptr) { + factors_[i]->print(ss.str(), keyFormatter); + cout << "\n"; + } else { + cout << ss.str() << "nullptr\n"; + } } + std::cout.flush(); } /* ************************************************************************* */ @@ -81,8 +88,9 @@ void NonlinearFactorGraph::printErrors(const Values& values, const std::string& factor->print(ss.str(), keyFormatter); cout << "error = " << errorValue << "\n"; } - cout << endl; // only one "endl" at end might be faster, \n for each factor + cout << "\n"; } + std::cout.flush(); } /* ************************************************************************* */ @@ -91,89 +99,25 @@ bool NonlinearFactorGraph::equals(const NonlinearFactorGraph& other, double tol) } /* ************************************************************************* */ -void NonlinearFactorGraph::saveGraph(std::ostream &stm, const Values& values, - const GraphvizFormatting& formatting, - const KeyFormatter& keyFormatter) const -{ - stm << "graph {\n"; - stm << " size=\"" << formatting.figureWidthInches << "," << - formatting.figureHeightInches << "\";\n\n"; +void NonlinearFactorGraph::dot(std::ostream& os, const Values& values, + const KeyFormatter& keyFormatter, + const GraphvizFormatting& writer) const { + writer.graphPreamble(&os); + // Find bounds (imperative) KeySet keys = this->keys(); - - // Local utility function to extract x and y coordinates - struct { boost::optional operator()( - const Value& value, const GraphvizFormatting& graphvizFormatting) - { - Vector3 t; - if (const GenericValue* p = dynamic_cast*>(&value)) { - t << p->value().x(), p->value().y(), 0; - } else if (const GenericValue* p = dynamic_cast*>(&value)) { - t << p->value().x(), p->value().y(), 0; - } else if (const GenericValue* p = dynamic_cast*>(&value)) { - t = p->value().translation(); - } else if (const GenericValue* p = dynamic_cast*>(&value)) { - t = p->value(); - } else { - return boost::none; - } - double x, y; - switch (graphvizFormatting.paperHorizontalAxis) { - case GraphvizFormatting::X: x = t.x(); break; - case GraphvizFormatting::Y: x = t.y(); break; - case GraphvizFormatting::Z: x = t.z(); break; - case GraphvizFormatting::NEGX: x = -t.x(); break; - case GraphvizFormatting::NEGY: x = -t.y(); break; - case GraphvizFormatting::NEGZ: x = -t.z(); break; - default: throw std::runtime_error("Invalid enum value"); - } - switch (graphvizFormatting.paperVerticalAxis) { - case GraphvizFormatting::X: y = t.x(); break; - case GraphvizFormatting::Y: y = t.y(); break; - case GraphvizFormatting::Z: y = t.z(); break; - case GraphvizFormatting::NEGX: y = -t.x(); break; - case GraphvizFormatting::NEGY: y = -t.y(); break; - case GraphvizFormatting::NEGZ: y = -t.z(); break; - default: throw std::runtime_error("Invalid enum value"); - } - return Point2(x,y); - }} getXY; - - // Find bounds - double minX = numeric_limits::infinity(), maxX = -numeric_limits::infinity(); - double minY = numeric_limits::infinity(), maxY = -numeric_limits::infinity(); - for (const Key& key : keys) { - if (values.exists(key)) { - boost::optional xy = getXY(values.at(key), formatting); - if(xy) { - if(xy->x() < minX) - minX = xy->x(); - if(xy->x() > maxX) - maxX = xy->x(); - if(xy->y() < minY) - minY = xy->y(); - if(xy->y() > maxY) - maxY = xy->y(); - } - } - } + Vector2 min = writer.findBounds(values, keys); // Create nodes for each variable in the graph - for(Key key: keys){ - // Label the node with the label from the KeyFormatter - stm << " var" << key << "[label=\"" << keyFormatter(key) << "\""; - if(values.exists(key)) { - boost::optional xy = getXY(values.at(key), formatting); - if(xy) - stm << ", pos=\"" << formatting.scale*(xy->x() - minX) << "," << formatting.scale*(xy->y() - minY) << "!\""; - } - stm << "];\n"; + for (Key key : keys) { + auto position = writer.variablePos(values, min, key); + writer.drawVariable(key, keyFormatter, position, &os); } - stm << "\n"; + os << "\n"; - if (formatting.mergeSimilarFactors) { + if (writer.mergeSimilarFactors) { // Remove duplicate factors - std::set structure; + std::set structure; for (const sharedFactor& factor : factors_) { if (factor) { KeyVector factorKeys = factor->keys(); @@ -184,86 +128,41 @@ void NonlinearFactorGraph::saveGraph(std::ostream &stm, const Values& values, // Create factors and variable connections size_t i = 0; - for(const KeyVector& factorKeys: structure){ - // Make each factor a dot - stm << " factor" << i << "[label=\"\", shape=point"; - { - map::const_iterator pos = formatting.factorPositions.find(i); - if(pos != formatting.factorPositions.end()) - stm << ", pos=\"" << formatting.scale*(pos->second.x() - minX) << "," - << formatting.scale*(pos->second.y() - minY) << "!\""; - } - stm << "];\n"; - - // Make factor-variable connections - for(Key key: factorKeys) { - stm << " var" << key << "--" << "factor" << i << ";\n"; - } - - ++ i; + for (const KeyVector& factorKeys : structure) { + writer.processFactor(i++, factorKeys, keyFormatter, boost::none, &os); } } else { // Create factors and variable connections - for(size_t i = 0; i < size(); ++i) { + for (size_t i = 0; i < size(); ++i) { const NonlinearFactor::shared_ptr& factor = at(i); - // If null pointer, move on to the next - if (!factor) { - continue; - } - - if (formatting.plotFactorPoints) { - const KeyVector& keys = factor->keys(); - if (formatting.binaryEdges && keys.size() == 2) { - stm << " var" << keys[0] << "--" - << "var" << keys[1] << ";\n"; - } else { - // Make each factor a dot - stm << " factor" << i << "[label=\"\", shape=point"; - { - map::const_iterator pos = - formatting.factorPositions.find(i); - if (pos != formatting.factorPositions.end()) - stm << ", pos=\"" << formatting.scale * (pos->second.x() - minX) - << "," << formatting.scale * (pos->second.y() - minY) - << "!\""; - } - stm << "];\n"; - - // Make factor-variable connections - if (formatting.connectKeysToFactor && factor) { - for (Key key : *factor) { - stm << " var" << key << "--" - << "factor" << i << ";\n"; - } - } - } - } else { - Key k; - bool firstTime = true; - for (Key key : *this->at(i)) { - if (firstTime) { - k = key; - firstTime = false; - continue; - } - stm << " var" << key << "--" - << "var" << k << ";\n"; - k = key; - } + if (factor) { + const KeyVector& factorKeys = factor->keys(); + writer.processFactor(i, factorKeys, keyFormatter, + writer.factorPos(min, i), &os); } } } - stm << "}\n"; + os << "}\n"; + std::flush(os); } /* ************************************************************************* */ -void NonlinearFactorGraph::saveGraph( - const std::string& file, const Values& values, - const GraphvizFormatting& graphvizFormatting, - const KeyFormatter& keyFormatter) const { - std::ofstream of(file); - saveGraph(of, values, graphvizFormatting, keyFormatter); +std::string NonlinearFactorGraph::dot(const Values& values, + const KeyFormatter& keyFormatter, + const GraphvizFormatting& writer) const { + std::stringstream ss; + dot(ss, values, keyFormatter, writer); + return ss.str(); +} + +/* ************************************************************************* */ +void NonlinearFactorGraph::saveGraph(const std::string& filename, + const Values& values, + const KeyFormatter& keyFormatter, + const GraphvizFormatting& writer) const { + std::ofstream of(filename); + dot(of, values, keyFormatter, writer); of.close(); } diff --git a/gtsam/nonlinear/NonlinearFactorGraph.h b/gtsam/nonlinear/NonlinearFactorGraph.h index 4d321f8ab..3237d7c1e 100644 --- a/gtsam/nonlinear/NonlinearFactorGraph.h +++ b/gtsam/nonlinear/NonlinearFactorGraph.h @@ -23,6 +23,7 @@ #include #include +#include #include #include @@ -42,38 +43,14 @@ namespace gtsam { class ExpressionFactor; /** - * Formatting options when saving in GraphViz format using - * NonlinearFactorGraph::saveGraph. - */ - struct GTSAM_EXPORT GraphvizFormatting { - enum Axis { X, Y, Z, NEGX, NEGY, NEGZ }; ///< World axes to be assigned to paper axes - Axis paperHorizontalAxis; ///< The world axis assigned to the horizontal paper axis - Axis paperVerticalAxis; ///< The world axis assigned to the vertical paper axis - double figureWidthInches; ///< The figure width on paper in inches - double figureHeightInches; ///< The figure height on paper in inches - double scale; ///< Scale all positions to reduce / increase density - bool mergeSimilarFactors; ///< Merge multiple factors that have the same connectivity - bool plotFactorPoints; ///< Plots each factor as a dot between the variables - bool connectKeysToFactor; ///< Draw a line from each key within a factor to the dot of the factor - bool binaryEdges; ///< just use non-dotted edges for binary factors - std::map factorPositions; ///< (optional for each factor) Manually specify factor "dot" positions. - /// Default constructor sets up robot coordinates. Paper horizontal is robot Y, - /// paper vertical is robot X. Default figure size of 5x5 in. - GraphvizFormatting() : - paperHorizontalAxis(Y), paperVerticalAxis(X), - figureWidthInches(5), figureHeightInches(5), scale(1), - mergeSimilarFactors(false), plotFactorPoints(true), - connectKeysToFactor(true), binaryEdges(true) {} - }; - - - /** - * A non-linear factor graph is a graph of non-Gaussian, i.e. non-linear factors, - * which derive from NonlinearFactor. The values structures are typically (in SAM) more general - * than just vectors, e.g., Rot3 or Pose3, which are objects in non-linear manifolds. - * Linearizing the non-linear factor graph creates a linear factor graph on the - * tangent vector space at the linearization point. Because the tangent space is a true - * vector space, the config type will be an VectorValues in that linearized factor graph. + * A NonlinearFactorGraph is a graph of non-Gaussian, i.e. non-linear factors, + * which derive from NonlinearFactor. The values structures are typically (in + * SAM) more general than just vectors, e.g., Rot3 or Pose3, which are objects + * in non-linear manifolds. Linearizing the non-linear factor graph creates a + * linear factor graph on the tangent vector space at the linearization point. + * Because the tangent space is a true vector space, the config type will be + * an VectorValues in that linearized factor graph. + * @addtogroup nonlinear */ class GTSAM_EXPORT NonlinearFactorGraph: public FactorGraph { @@ -83,6 +60,9 @@ namespace gtsam { typedef NonlinearFactorGraph This; typedef boost::shared_ptr shared_ptr; + /// @name Standard Constructors + /// @{ + /** Default constructor */ NonlinearFactorGraph() {} @@ -101,6 +81,10 @@ namespace gtsam { /// Destructor virtual ~NonlinearFactorGraph() {} + /// @} + /// @name Testable + /// @{ + /** print */ void print( const std::string& str = "NonlinearFactorGraph: ", @@ -115,22 +99,11 @@ namespace gtsam { /** Test equality */ bool equals(const NonlinearFactorGraph& other, double tol = 1e-9) const; - /// Write the graph in GraphViz format for visualization - void saveGraph(std::ostream& stm, const Values& values = Values(), - const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(), - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + /// @} + /// @name Standard Interface + /// @{ - /** - * Write the graph in GraphViz format to file for visualization. - * - * This is a wrapper friendly version since wrapped languages don't have - * access to C++ streams. - */ - void saveGraph(const std::string& file, const Values& values = Values(), - const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(), - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; - - /** unnormalized error, \f$ 0.5 \sum_i (h_i(X_i)-z)^2/\sigma^2 \f$ in the most common case */ + /** unnormalized error, \f$ \sum_i 0.5 (h_i(X_i)-z)^2 / \sigma^2 \f$ in the most common case */ double error(const Values& values) const; /** Unnormalized probability. O(n) */ @@ -246,7 +219,32 @@ namespace gtsam { emplace_shared>(key, prior, covariance); } - private: + /// @} + /// @name Graph Display + /// @{ + + using FactorGraph::dot; + using FactorGraph::saveGraph; + + /// Output to graphviz format, stream version, with Values/extra options. + void dot(std::ostream& os, const Values& values, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const GraphvizFormatting& writer = GraphvizFormatting()) const; + + /// Output to graphviz format string, with Values/extra options. + std::string dot( + const Values& values, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const GraphvizFormatting& writer = GraphvizFormatting()) const; + + /// output to file with graphviz format, with Values/extra options. + void saveGraph( + const std::string& filename, const Values& values, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const GraphvizFormatting& writer = GraphvizFormatting()) const; + /// @} + + private: /** * Linearize from Scatter rather than from Ordering. Made private because @@ -265,16 +263,36 @@ namespace gtsam { public: - /** \deprecated */ - boost::shared_ptr linearizeToHessianFactor( +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + /// @name Deprecated + /// @{ + /** @deprecated */ + boost::shared_ptr GTSAM_DEPRECATED linearizeToHessianFactor( const Values& values, boost::none_t, const Dampen& dampen = nullptr) const {return linearizeToHessianFactor(values, dampen);} - /** \deprecated */ - Values updateCholesky(const Values& values, boost::none_t, + /** @deprecated */ + Values GTSAM_DEPRECATED updateCholesky(const Values& values, boost::none_t, const Dampen& dampen = nullptr) const {return updateCholesky(values, dampen);} + /** @deprecated */ + void GTSAM_DEPRECATED saveGraph( + std::ostream& os, const Values& values = Values(), + const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(), + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { + dot(os, values, keyFormatter, graphvizFormatting); + } + /** @deprecated */ + void GTSAM_DEPRECATED + saveGraph(const std::string& filename, const Values& values, + const GraphvizFormatting& graphvizFormatting, + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { + saveGraph(filename, values, keyFormatter, graphvizFormatting); + } + /// @} +#endif + }; /// traits diff --git a/gtsam/nonlinear/NonlinearOptimizer.cpp b/gtsam/nonlinear/NonlinearOptimizer.cpp index 0d7e9e17f..3ce6db4af 100644 --- a/gtsam/nonlinear/NonlinearOptimizer.cpp +++ b/gtsam/nonlinear/NonlinearOptimizer.cpp @@ -147,11 +147,13 @@ VectorValues NonlinearOptimizer::solve(const GaussianFactorGraph& gfg, } else if (params.isSequential()) { // Sequential QR or Cholesky (decided by params.getEliminationFunction()) if (params.ordering) - delta = gfg.eliminateSequential(*params.ordering, params.getEliminationFunction(), - boost::none, params.orderingType)->optimize(); + delta = gfg.eliminateSequential(*params.ordering, + params.getEliminationFunction()) + ->optimize(); else - delta = gfg.eliminateSequential(params.getEliminationFunction(), boost::none, - params.orderingType)->optimize(); + delta = gfg.eliminateSequential(params.orderingType, + params.getEliminationFunction()) + ->optimize(); } else if (params.isIterative()) { // Conjugate Gradient -> needs params.iterativeParams if (!params.iterativeParams) diff --git a/gtsam/nonlinear/Values-inl.h b/gtsam/nonlinear/Values-inl.h index 8ebdcab17..0370c5cee 100644 --- a/gtsam/nonlinear/Values-inl.h +++ b/gtsam/nonlinear/Values-inl.h @@ -279,10 +279,11 @@ namespace gtsam { template struct handle { ValueType operator()(Key j, const Value* const pointer) { - try { + auto ptr = dynamic_cast*>(pointer); + if (ptr) { // value returns a const ValueType&, and the return makes a copy !!!!! - return dynamic_cast&>(*pointer).value(); - } catch (std::bad_cast&) { + return ptr->value(); + } else { throw ValuesIncorrectType(j, typeid(*pointer), typeid(ValueType)); } } @@ -294,11 +295,12 @@ namespace gtsam { // Handle dynamic matrices template struct handle_matrix, true> { - Eigen::Matrix operator()(Key j, const Value* const pointer) { - try { + inline Eigen::Matrix operator()(Key j, const Value* const pointer) { + auto ptr = dynamic_cast>*>(pointer); + if (ptr) { // value returns a const Matrix&, and the return makes a copy !!!!! - return dynamic_cast>&>(*pointer).value(); - } catch (std::bad_cast&) { + return ptr->value(); + } else { // If a fixed matrix was stored, we end up here as well. throw ValuesIncorrectType(j, typeid(*pointer), typeid(Eigen::Matrix)); } @@ -308,16 +310,18 @@ namespace gtsam { // Handle fixed matrices template struct handle_matrix, false> { - Eigen::Matrix operator()(Key j, const Value* const pointer) { - try { + inline Eigen::Matrix operator()(Key j, const Value* const pointer) { + auto ptr = dynamic_cast>*>(pointer); + if (ptr) { // value returns a const MatrixMN&, and the return makes a copy !!!!! - return dynamic_cast>&>(*pointer).value(); - } catch (std::bad_cast&) { + return ptr->value(); + } else { Matrix A; - try { - // Check if a dynamic matrix was stored - A = handle_matrix()(j, pointer); // will throw if not.... - } catch (const ValuesIncorrectType&) { + // Check if a dynamic matrix was stored + auto ptr = dynamic_cast*>(pointer); + if (ptr) { + A = ptr->value(); + } else { // Or a dynamic vector A = handle_matrix()(j, pointer); // will throw if not.... } @@ -364,10 +368,10 @@ namespace gtsam { if(item != values_.end()) { // dynamic cast the type and throw exception if incorrect - const Value& value = *item->second; - try { - return dynamic_cast&>(value).value(); - } catch (std::bad_cast &) { + auto ptr = dynamic_cast*>(item->second); + if (ptr) { + return ptr->value(); + } else { // NOTE(abe): clang warns about potential side effects if done in typeid const Value* value = item->second; throw ValuesIncorrectType(j, typeid(*value), typeid(ValueType)); @@ -391,4 +395,10 @@ namespace gtsam { update(j, static_cast(GenericValue(val))); } + // insert_or_assign with templated value + template + void Values::insert_or_assign(Key j, const ValueType& val) { + insert_or_assign(j, static_cast(GenericValue(val))); + } + } diff --git a/gtsam/nonlinear/Values.cpp b/gtsam/nonlinear/Values.cpp index ebc9c51f6..adadc99c0 100644 --- a/gtsam/nonlinear/Values.cpp +++ b/gtsam/nonlinear/Values.cpp @@ -171,6 +171,25 @@ namespace gtsam { } } + /* ************************************************************************ */ + void Values::insert_or_assign(Key j, const Value& val) { + if (this->exists(j)) { + // If key already exists, perform an update. + this->update(j, val); + } else { + // If key does not exist, perform an insert. + this->insert(j, val); + } + } + + /* ************************************************************************ */ + void Values::insert_or_assign(const Values& values) { + for (const_iterator key_value = values.begin(); key_value != values.end(); + ++key_value) { + this->insert_or_assign(key_value->key, key_value->value); + } + } + /* ************************************************************************* */ void Values::erase(Key j) { KeyValueMap::iterator item = values_.find(j); diff --git a/gtsam/nonlinear/Values.h b/gtsam/nonlinear/Values.h index 33e9e7d82..cfe6347b5 100644 --- a/gtsam/nonlinear/Values.h +++ b/gtsam/nonlinear/Values.h @@ -24,6 +24,7 @@ #pragma once +#include #include #include #include @@ -62,17 +63,18 @@ namespace gtsam { class GTSAM_EXPORT Values { private: - // Internally we store a boost ptr_map, with a ValueCloneAllocator (defined - // below) to clone and deallocate the Value objects, and a boost - // fast_pool_allocator to allocate map nodes. In this way, all memory is - // allocated in a boost memory pool. + // below) to clone and deallocate the Value objects, and our compile-flag- + // dependent FastDefaultAllocator to allocate map nodes. In this way, the + // user defines the allocation details (i.e. optimize for memory pool/arenas + // concurrency). + typedef internal::FastDefaultAllocator>::type KeyValuePtrPairAllocator; typedef boost::ptr_map< Key, Value, std::less, ValueCloneAllocator, - boost::fast_pool_allocator > > KeyValueMap; + KeyValuePtrPairAllocator > KeyValueMap; // The member to store the values, see just above KeyValueMap values_; @@ -283,6 +285,19 @@ namespace gtsam { /** update the current available values without adding new ones */ void update(const Values& values); + /// If key j exists, update value, else perform an insert. + void insert_or_assign(Key j, const Value& val); + + /** + * Update a set of variables. + * If any variable key doe not exist, then perform an insert. + */ + void insert_or_assign(const Values& values); + + /// Templated version to insert_or_assign a variable with the given j. + template + void insert_or_assign(Key j, const ValueType& val); + /** Remove a variable from the config, throws KeyDoesNotExist if j is not present */ void erase(Key j); diff --git a/gtsam/nonlinear/WhiteNoiseFactor.h b/gtsam/nonlinear/WhiteNoiseFactor.h index 95f46ab6c..1cd117437 100644 --- a/gtsam/nonlinear/WhiteNoiseFactor.h +++ b/gtsam/nonlinear/WhiteNoiseFactor.h @@ -17,6 +17,8 @@ * @date September 2011 */ +#pragma once + #include #include #include diff --git a/gtsam/nonlinear/factorTesting.h b/gtsam/nonlinear/factorTesting.h index 74ef87737..266aa841c 100644 --- a/gtsam/nonlinear/factorTesting.h +++ b/gtsam/nonlinear/factorTesting.h @@ -21,6 +21,8 @@ #include #include +#include +#include namespace gtsam { @@ -34,36 +36,36 @@ namespace gtsam { * This is fixable but expensive, and does not matter in practice as most factors will sit near * zero errors anyway. However, it means that below will only be exact for the correct measurement. */ -JacobianFactor linearizeNumerically(const NoiseModelFactor& factor, - const Values& values, double delta = 1e-5) { - +inline JacobianFactor linearizeNumerically(const NoiseModelFactor& factor, + const Values& values, + double delta = 1e-5) { // We will fill a vector of key/Jacobians pairs (a map would sort) std::vector > jacobians; // Get size - const Eigen::VectorXd e = factor.whitenedError(values); + const Vector e = factor.whitenedError(values); const size_t rows = e.size(); // Loop over all variables const double one_over_2delta = 1.0 / (2.0 * delta); - for(Key key: factor) { + for (Key key : factor) { // Compute central differences using the values struct. VectorValues dX = values.zeroVectors(); const size_t cols = dX.dim(key); Matrix J = Matrix::Zero(rows, cols); for (size_t col = 0; col < cols; ++col) { - Eigen::VectorXd dx = Eigen::VectorXd::Zero(cols); - dx[col] = delta; + Vector dx = Vector::Zero(cols); + dx(col) = delta; dX[key] = dx; Values eval_values = values.retract(dX); - const Eigen::VectorXd left = factor.whitenedError(eval_values); - dx[col] = -delta; + const Vector left = factor.whitenedError(eval_values); + dx(col) = -delta; dX[key] = dx; eval_values = values.retract(dX); - const Eigen::VectorXd right = factor.whitenedError(eval_values); + const Vector right = factor.whitenedError(eval_values); J.col(col) = (left - right) * one_over_2delta; } - jacobians.push_back(std::make_pair(key,J)); + jacobians.emplace_back(key, J); } // Next step...return JacobianFactor @@ -72,15 +74,15 @@ JacobianFactor linearizeNumerically(const NoiseModelFactor& factor, namespace internal { // CPPUnitLite-style test for linearization of a factor -bool testFactorJacobians(const std::string& name_, - const NoiseModelFactor& factor, const gtsam::Values& values, double delta, - double tolerance) { - +inline bool testFactorJacobians(const std::string& name_, + const NoiseModelFactor& factor, + const gtsam::Values& values, double delta, + double tolerance) { // Create expected value by numerical differentiation JacobianFactor expected = linearizeNumerically(factor, values, delta); // Create actual value by linearize - boost::shared_ptr actual = // + auto actual = boost::dynamic_pointer_cast(factor.linearize(values)); if (!actual) return false; @@ -90,17 +92,19 @@ bool testFactorJacobians(const std::string& name_, // if not equal, test individual jacobians: if (!equal) { for (size_t i = 0; i < actual->size(); i++) { - bool i_good = assert_equal((Matrix) (expected.getA(expected.begin() + i)), - (Matrix) (actual->getA(actual->begin() + i)), tolerance); + bool i_good = + assert_equal((Matrix)(expected.getA(expected.begin() + i)), + (Matrix)(actual->getA(actual->begin() + i)), tolerance); if (!i_good) { - std::cout << "Mismatch in Jacobian " << i+1 << " (base 1), as shown above" << std::endl; + std::cout << "Mismatch in Jacobian " << i + 1 + << " (base 1), as shown above" << std::endl; } } } return equal; } -} +} // namespace internal /// \brief Check the Jacobians produced by a factor against finite differences. /// \param factor The factor to test. @@ -110,4 +114,4 @@ bool testFactorJacobians(const std::string& name_, #define EXPECT_CORRECT_FACTOR_JACOBIANS(factor, values, numerical_derivative_step, tolerance) \ { EXPECT(gtsam::internal::testFactorJacobians(name_, factor, values, numerical_derivative_step, tolerance)); } -} // namespace gtsam +} // namespace gtsam diff --git a/gtsam/nonlinear/internal/LevenbergMarquardtState.h b/gtsam/nonlinear/internal/LevenbergMarquardtState.h index cee839540..75e5a5135 100644 --- a/gtsam/nonlinear/internal/LevenbergMarquardtState.h +++ b/gtsam/nonlinear/internal/LevenbergMarquardtState.h @@ -16,6 +16,8 @@ * @date April 2016 */ +#pragma once + #include "NonlinearOptimizerState.h" #include diff --git a/gtsam/nonlinear/nonlinear.i b/gtsam/nonlinear/nonlinear.i index d068bd7ee..3fff71978 100644 --- a/gtsam/nonlinear/nonlinear.i +++ b/gtsam/nonlinear/nonlinear.i @@ -23,113 +23,20 @@ namespace gtsam { #include #include #include -#include #include #include +#include -class Symbol { - Symbol(); - Symbol(char c, uint64_t j); - Symbol(size_t key); +#include +class GraphvizFormatting : gtsam::DotWriter { + GraphvizFormatting(); - size_t key() const; - void print(const string& s = "") const; - bool equals(const gtsam::Symbol& expected, double tol) const; + enum Axis { X, Y, Z, NEGX, NEGY, NEGZ }; + Axis paperHorizontalAxis; + Axis paperVerticalAxis; - char chr() const; - uint64_t index() const; - string string() const; -}; - -size_t symbol(char chr, size_t index); -char symbolChr(size_t key); -size_t symbolIndex(size_t key); - -namespace symbol_shorthand { -size_t A(size_t j); -size_t B(size_t j); -size_t C(size_t j); -size_t D(size_t j); -size_t E(size_t j); -size_t F(size_t j); -size_t G(size_t j); -size_t H(size_t j); -size_t I(size_t j); -size_t J(size_t j); -size_t K(size_t j); -size_t L(size_t j); -size_t M(size_t j); -size_t N(size_t j); -size_t O(size_t j); -size_t P(size_t j); -size_t Q(size_t j); -size_t R(size_t j); -size_t S(size_t j); -size_t T(size_t j); -size_t U(size_t j); -size_t V(size_t j); -size_t W(size_t j); -size_t X(size_t j); -size_t Y(size_t j); -size_t Z(size_t j); -} // namespace symbol_shorthand - -// Default keyformatter -void PrintKeyList( - const gtsam::KeyList& keys, const string& s = "", - const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter); -void PrintKeyVector( - const gtsam::KeyVector& keys, const string& s = "", - const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter); -void PrintKeySet( - const gtsam::KeySet& keys, const string& s = "", - const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter); - -#include -class LabeledSymbol { - LabeledSymbol(size_t full_key); - LabeledSymbol(const gtsam::LabeledSymbol& key); - LabeledSymbol(unsigned char valType, unsigned char label, size_t j); - - size_t key() const; - unsigned char label() const; - unsigned char chr() const; - size_t index() const; - - gtsam::LabeledSymbol upper() const; - gtsam::LabeledSymbol lower() const; - gtsam::LabeledSymbol newChr(unsigned char c) const; - gtsam::LabeledSymbol newLabel(unsigned char label) const; - - void print(string s = "") const; -}; - -size_t mrsymbol(unsigned char c, unsigned char label, size_t j); -unsigned char mrsymbolChr(size_t key); -unsigned char mrsymbolLabel(size_t key); -size_t mrsymbolIndex(size_t key); - -#include -class Ordering { - // Standard Constructors and Named Constructors - Ordering(); - Ordering(const gtsam::Ordering& other); - - // Testable - void print(string s = "", const gtsam::KeyFormatter& keyFormatter = - gtsam::DefaultKeyFormatter) const; - bool equals(const gtsam::Ordering& ord, double tol) const; - - // Standard interface - size_t size() const; - size_t at(size_t key) const; - void push_back(size_t key); - - // enabling serialization functionality - void serialize() const; - - // enable pickling in python - void pickle() const; + double scale; + bool mergeSimilarFactors; }; #include @@ -189,13 +96,17 @@ class NonlinearFactorGraph { gtsam::GaussianFactorGraph* linearize(const gtsam::Values& values) const; gtsam::NonlinearFactorGraph clone() const; + string dot( + const gtsam::Values& values, + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const GraphvizFormatting& writer = GraphvizFormatting()); + void saveGraph( + const string& s, const gtsam::Values& values, + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const GraphvizFormatting& writer = GraphvizFormatting()) const; + // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; - - void saveGraph(const string& s) const; }; #include @@ -271,6 +182,7 @@ class Values { void insert(const gtsam::Values& values); void update(const gtsam::Values& values); + void insert_or_assign(const gtsam::Values& values); void erase(size_t j); void swap(gtsam::Values& values); @@ -285,9 +197,6 @@ class Values { // enabling serialization functionality void serialize() const; - // enable pickling in python - void pickle() const; - // New in 4.0, we have to specialize every insert/update/at to generate // wrappers Instead of the old: void insert(size_t j, const gtsam::Value& // value); void update(size_t j, const gtsam::Value& val); gtsam::Value @@ -320,6 +229,25 @@ class Values { void insert(size_t j, const gtsam::imuBias::ConstantBias& constant_bias); void insert(size_t j, const gtsam::NavState& nav_state); void insert(size_t j, double c); + void insert(size_t j, const gtsam::ParameterMatrix<1>& X); + void insert(size_t j, const gtsam::ParameterMatrix<2>& X); + void insert(size_t j, const gtsam::ParameterMatrix<3>& X); + void insert(size_t j, const gtsam::ParameterMatrix<4>& X); + void insert(size_t j, const gtsam::ParameterMatrix<5>& X); + void insert(size_t j, const gtsam::ParameterMatrix<6>& X); + void insert(size_t j, const gtsam::ParameterMatrix<7>& X); + void insert(size_t j, const gtsam::ParameterMatrix<8>& X); + void insert(size_t j, const gtsam::ParameterMatrix<9>& X); + void insert(size_t j, const gtsam::ParameterMatrix<10>& X); + void insert(size_t j, const gtsam::ParameterMatrix<11>& X); + void insert(size_t j, const gtsam::ParameterMatrix<12>& X); + void insert(size_t j, const gtsam::ParameterMatrix<13>& X); + void insert(size_t j, const gtsam::ParameterMatrix<14>& X); + void insert(size_t j, const gtsam::ParameterMatrix<15>& X); + + template + void insert(size_t j, const T& val); void update(size_t j, const gtsam::Point2& point2); void update(size_t j, const gtsam::Point3& point3); @@ -346,6 +274,62 @@ class Values { void update(size_t j, Vector vector); void update(size_t j, Matrix matrix); void update(size_t j, double c); + void update(size_t j, const gtsam::ParameterMatrix<1>& X); + void update(size_t j, const gtsam::ParameterMatrix<2>& X); + void update(size_t j, const gtsam::ParameterMatrix<3>& X); + void update(size_t j, const gtsam::ParameterMatrix<4>& X); + void update(size_t j, const gtsam::ParameterMatrix<5>& X); + void update(size_t j, const gtsam::ParameterMatrix<6>& X); + void update(size_t j, const gtsam::ParameterMatrix<7>& X); + void update(size_t j, const gtsam::ParameterMatrix<8>& X); + void update(size_t j, const gtsam::ParameterMatrix<9>& X); + void update(size_t j, const gtsam::ParameterMatrix<10>& X); + void update(size_t j, const gtsam::ParameterMatrix<11>& X); + void update(size_t j, const gtsam::ParameterMatrix<12>& X); + void update(size_t j, const gtsam::ParameterMatrix<13>& X); + void update(size_t j, const gtsam::ParameterMatrix<14>& X); + void update(size_t j, const gtsam::ParameterMatrix<15>& X); + + void insert_or_assign(size_t j, const gtsam::Point2& point2); + void insert_or_assign(size_t j, const gtsam::Point3& point3); + void insert_or_assign(size_t j, const gtsam::Rot2& rot2); + void insert_or_assign(size_t j, const gtsam::Pose2& pose2); + void insert_or_assign(size_t j, const gtsam::SO3& R); + void insert_or_assign(size_t j, const gtsam::SO4& Q); + void insert_or_assign(size_t j, const gtsam::SOn& P); + void insert_or_assign(size_t j, const gtsam::Rot3& rot3); + void insert_or_assign(size_t j, const gtsam::Pose3& pose3); + void insert_or_assign(size_t j, const gtsam::Unit3& unit3); + void insert_or_assign(size_t j, const gtsam::Cal3_S2& cal3_s2); + void insert_or_assign(size_t j, const gtsam::Cal3DS2& cal3ds2); + void insert_or_assign(size_t j, const gtsam::Cal3Bundler& cal3bundler); + void insert_or_assign(size_t j, const gtsam::Cal3Fisheye& cal3fisheye); + void insert_or_assign(size_t j, const gtsam::Cal3Unified& cal3unified); + void insert_or_assign(size_t j, const gtsam::EssentialMatrix& essential_matrix); + void insert_or_assign(size_t j, const gtsam::PinholeCamera& camera); + void insert_or_assign(size_t j, const gtsam::PinholeCamera& camera); + void insert_or_assign(size_t j, const gtsam::PinholeCamera& camera); + void insert_or_assign(size_t j, const gtsam::PinholeCamera& camera); + void insert_or_assign(size_t j, const gtsam::imuBias::ConstantBias& constant_bias); + void insert_or_assign(size_t j, const gtsam::NavState& nav_state); + void insert_or_assign(size_t j, Vector vector); + void insert_or_assign(size_t j, Matrix matrix); + void insert_or_assign(size_t j, double c); + void insert_or_assign(size_t j, const gtsam::ParameterMatrix<1>& X); + void insert_or_assign(size_t j, const gtsam::ParameterMatrix<2>& X); + void insert_or_assign(size_t j, const gtsam::ParameterMatrix<3>& X); + void insert_or_assign(size_t j, const gtsam::ParameterMatrix<4>& X); + void insert_or_assign(size_t j, const gtsam::ParameterMatrix<5>& X); + void insert_or_assign(size_t j, const gtsam::ParameterMatrix<6>& X); + void insert_or_assign(size_t j, const gtsam::ParameterMatrix<7>& X); + void insert_or_assign(size_t j, const gtsam::ParameterMatrix<8>& X); + void insert_or_assign(size_t j, const gtsam::ParameterMatrix<9>& X); + void insert_or_assign(size_t j, const gtsam::ParameterMatrix<10>& X); + void insert_or_assign(size_t j, const gtsam::ParameterMatrix<11>& X); + void insert_or_assign(size_t j, const gtsam::ParameterMatrix<12>& X); + void insert_or_assign(size_t j, const gtsam::ParameterMatrix<13>& X); + void insert_or_assign(size_t j, const gtsam::ParameterMatrix<14>& X); + void insert_or_assign(size_t j, const gtsam::ParameterMatrix<15>& X); template + double, + gtsam::ParameterMatrix<1>, + gtsam::ParameterMatrix<2>, + gtsam::ParameterMatrix<3>, + gtsam::ParameterMatrix<4>, + gtsam::ParameterMatrix<5>, + gtsam::ParameterMatrix<6>, + gtsam::ParameterMatrix<7>, + gtsam::ParameterMatrix<8>, + gtsam::ParameterMatrix<9>, + gtsam::ParameterMatrix<10>, + gtsam::ParameterMatrix<11>, + gtsam::ParameterMatrix<12>, + gtsam::ParameterMatrix<13>, + gtsam::ParameterMatrix<14>, + gtsam::ParameterMatrix<15>}> T at(size_t j); }; @@ -654,21 +653,19 @@ class ISAM2Params { void setOptimizationParams(const gtsam::ISAM2DoglegParams& dogleg_params); void setRelinearizeThreshold(double threshold); void setRelinearizeThreshold(const gtsam::ISAM2ThresholdMap& threshold_map); - int getRelinearizeSkip() const; - void setRelinearizeSkip(int relinearizeSkip); - bool isEnableRelinearization() const; - void setEnableRelinearization(bool enableRelinearization); - bool isEvaluateNonlinearError() const; - void setEvaluateNonlinearError(bool evaluateNonlinearError); string getFactorization() const; void setFactorization(string factorization); - bool isCacheLinearizedFactors() const; - void setCacheLinearizedFactors(bool cacheLinearizedFactors); - bool isEnableDetailedResults() const; - void setEnableDetailedResults(bool enableDetailedResults); - bool isEnablePartialRelinearizationCheck() const; - void setEnablePartialRelinearizationCheck( - bool enablePartialRelinearizationCheck); + + int relinearizeSkip; + bool enableRelinearization; + bool evaluateNonlinearError; + bool cacheLinearizedFactors; + bool enableDetailedResults; + bool enablePartialRelinearizationCheck; + bool findUnusedFactorSlots; + + enum Factorization { CHOLESKY, QR }; + Factorization factorization; }; class ISAM2Clique { @@ -689,6 +686,7 @@ class ISAM2Result { /** Getters and Setters for all properties */ size_t getVariablesRelinearized() const; size_t getVariablesReeliminated() const; + FactorIndices getNewFactorsIndices() const; size_t getCliques() const; double getErrorBefore() const; double getErrorAfter() const; @@ -734,7 +732,12 @@ class ISAM2 { const gtsam::KeyList& extraReelimKeys, bool force_relinearize); + gtsam::ISAM2Result update(const gtsam::NonlinearFactorGraph& newFactors, + const gtsam::Values& newTheta, + const gtsam::ISAM2UpdateParams& updateParams); + gtsam::Values getLinearizationPoint() const; + bool valueExists(gtsam::Key key) const; gtsam::Values calculateEstimate() const; template , gtsam::PinholeCamera, Vector, Matrix}> VALUE calculateEstimate(size_t key) const; - gtsam::Values calculateBestEstimate() const; Matrix marginalCovariance(size_t key) const; + gtsam::Values calculateBestEstimate() const; gtsam::VectorValues getDelta() const; + double error(const gtsam::VectorValues& x) const; gtsam::NonlinearFactorGraph getFactorsUnsafe() const; gtsam::VariableIndex getVariableIndex() const; + const gtsam::KeySet& getFixedVariables() const; gtsam::ISAM2Params params() const; + + void printStats() const; + gtsam::VectorValues gradientAtZero() const; + + string dot(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + void saveGraph(string s, + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; }; #include @@ -810,9 +824,6 @@ virtual class PriorFactor : gtsam::NoiseModelFactor { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include diff --git a/gtsam/nonlinear/tests/testCallRecord.cpp b/gtsam/nonlinear/tests/testCallRecord.cpp index 5d0d5d5f2..419172f74 100644 --- a/gtsam/nonlinear/tests/testCallRecord.cpp +++ b/gtsam/nonlinear/tests/testCallRecord.cpp @@ -153,7 +153,7 @@ TEST(CallRecord, virtualReverseAdDispatching) { } { const int Rows = 6; - record.CallRecord::reverseAD2(Eigen::Matrix(), NJM); + record.CallRecord::reverseAD2(Eigen::Matrix::Zero(), NJM); EXPECT((assert_equal(record.cc, CallConfig(Rows, Cols)))); record.CallRecord::reverseAD2(DynRowMat(Rows, Cols), NJM); EXPECT((assert_equal(record.cc, CallConfig(Eigen::Dynamic, Cols, Rows, Cols)))); @@ -168,4 +168,3 @@ int main() { return TestRegistry::runAllTests(tr); } /* ************************************************************************* */ - diff --git a/gtsam/nonlinear/tests/testExpression.cpp b/gtsam/nonlinear/tests/testExpression.cpp index 80262ae3f..92f4998a2 100644 --- a/gtsam/nonlinear/tests/testExpression.cpp +++ b/gtsam/nonlinear/tests/testExpression.cpp @@ -293,6 +293,19 @@ TEST(Expression, compose3) { EXPECT(expected == R3.keys()); } +/* ************************************************************************* */ +// Test compose with double type (should be multiplication). +TEST(Expression, compose4) { + // Create expression + gtsam::Key key = 1; + Double_ R1(key), R2(key); + Double_ R3 = R1 * R2; + + // Check keys + set expected = list_of(1); + EXPECT(expected == R3.keys()); +} + /* ************************************************************************* */ // Test with ternary function. Rot3 composeThree(const Rot3& R1, const Rot3& R2, const Rot3& R3, OptionalJacobian<3, 3> H1, diff --git a/gtsam/nonlinear/tests/testFunctorizedFactor.cpp b/gtsam/nonlinear/tests/testFunctorizedFactor.cpp index 14a14fc19..214c5efa7 100644 --- a/gtsam/nonlinear/tests/testFunctorizedFactor.cpp +++ b/gtsam/nonlinear/tests/testFunctorizedFactor.cpp @@ -17,16 +17,14 @@ * @brief unit tests for FunctorizedFactor class */ -#include -#include -#include -#include -#include -#include -#include #include #include #include +#include +#include +#include + +#include using namespace std; using namespace gtsam; @@ -272,135 +270,6 @@ TEST(FunctorizedFactor, Lambda2) { EXPECT(assert_equal(Vector::Zero(3), error, 1e-9)); } -const size_t N = 2; - -//****************************************************************************** -TEST(FunctorizedFactor, Print2) { - const size_t M = 1; - - Vector measured = Vector::Ones(M) * 42; - - auto model = noiseModel::Isotropic::Sigma(M, 1.0); - VectorEvaluationFactor priorFactor(key, measured, model, N, 0); - - string expected = - " keys = { X0 }\n" - " noise model: unit (1) \n" - "FunctorizedFactor(X0)\n" - " measurement: [\n" - " 42\n" - "]\n" - " noise model sigmas: 1\n"; - - EXPECT(assert_print_equal(expected, priorFactor)); -} - -//****************************************************************************** -TEST(FunctorizedFactor, VectorEvaluationFactor) { - const size_t M = 4; - - Vector measured = Vector::Zero(M); - - auto model = noiseModel::Isotropic::Sigma(M, 1.0); - VectorEvaluationFactor priorFactor(key, measured, model, N, 0); - - NonlinearFactorGraph graph; - graph.add(priorFactor); - - ParameterMatrix stateMatrix(N); - - Values initial; - initial.insert>(key, stateMatrix); - - LevenbergMarquardtParams parameters; - parameters.verbosity = NonlinearOptimizerParams::SILENT; - parameters.verbosityLM = LevenbergMarquardtParams::SILENT; - parameters.setMaxIterations(20); - Values result = - LevenbergMarquardtOptimizer(graph, initial, parameters).optimize(); - - EXPECT_DOUBLES_EQUAL(0, graph.error(result), 1e-9); -} - -//****************************************************************************** -TEST(FunctorizedFactor, VectorComponentFactor) { - const int P = 4; - const size_t i = 2; - const double measured = 0.0, t = 3.0, a = 2.0, b = 4.0; - auto model = noiseModel::Isotropic::Sigma(1, 1.0); - VectorComponentFactor controlPrior(key, measured, model, N, i, - t, a, b); - - NonlinearFactorGraph graph; - graph.add(controlPrior); - - ParameterMatrix

stateMatrix(N); - - Values initial; - initial.insert>(key, stateMatrix); - - LevenbergMarquardtParams parameters; - parameters.verbosity = NonlinearOptimizerParams::SILENT; - parameters.verbosityLM = LevenbergMarquardtParams::SILENT; - parameters.setMaxIterations(20); - Values result = - LevenbergMarquardtOptimizer(graph, initial, parameters).optimize(); - - EXPECT_DOUBLES_EQUAL(0, graph.error(result), 1e-9); -} - -//****************************************************************************** -TEST(FunctorizedFactor, VecDerivativePrior) { - const size_t M = 4; - - Vector measured = Vector::Zero(M); - auto model = noiseModel::Isotropic::Sigma(M, 1.0); - VectorDerivativeFactor vecDPrior(key, measured, model, N, 0); - - NonlinearFactorGraph graph; - graph.add(vecDPrior); - - ParameterMatrix stateMatrix(N); - - Values initial; - initial.insert>(key, stateMatrix); - - LevenbergMarquardtParams parameters; - parameters.verbosity = NonlinearOptimizerParams::SILENT; - parameters.verbosityLM = LevenbergMarquardtParams::SILENT; - parameters.setMaxIterations(20); - Values result = - LevenbergMarquardtOptimizer(graph, initial, parameters).optimize(); - - EXPECT_DOUBLES_EQUAL(0, graph.error(result), 1e-9); -} - -//****************************************************************************** -TEST(FunctorizedFactor, ComponentDerivativeFactor) { - const size_t M = 4; - - double measured = 0; - auto model = noiseModel::Isotropic::Sigma(1, 1.0); - ComponentDerivativeFactor controlDPrior(key, measured, model, - N, 0, 0); - - NonlinearFactorGraph graph; - graph.add(controlDPrior); - - Values initial; - ParameterMatrix stateMatrix(N); - initial.insert>(key, stateMatrix); - - LevenbergMarquardtParams parameters; - parameters.verbosity = NonlinearOptimizerParams::SILENT; - parameters.verbosityLM = LevenbergMarquardtParams::SILENT; - parameters.setMaxIterations(20); - Values result = - LevenbergMarquardtOptimizer(graph, initial, parameters).optimize(); - - EXPECT_DOUBLES_EQUAL(0, graph.error(result), 1e-9); -} - /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/nonlinear/tests/testSerializationNonlinear.cpp b/gtsam/nonlinear/tests/testSerializationNonlinear.cpp index f4bb5f4f6..4a73cbb0b 100644 --- a/gtsam/nonlinear/tests/testSerializationNonlinear.cpp +++ b/gtsam/nonlinear/tests/testSerializationNonlinear.cpp @@ -35,37 +35,37 @@ using namespace gtsam::serializationTestHelpers; /* ************************************************************************* */ // Create GUIDs for Noisemodels -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Diagonal, "gtsam_noiseModel_Diagonal"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::mEstimator::Base , "gtsam_noiseModel_mEstimator_Base"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::mEstimator::Null , "gtsam_noiseModel_mEstimator_Null"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::mEstimator::Fair , "gtsam_noiseModel_mEstimator_Fair"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::mEstimator::Huber, "gtsam_noiseModel_mEstimator_Huber"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::mEstimator::Tukey, "gtsam_noiseModel_mEstimator_Tukey"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Constrained, "gtsam_noiseModel_Constrained"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Unit, "gtsam_noiseModel_Unit"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Isotropic,"gtsam_noiseModel_Isotropic"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Robust, "gtsam_noiseModel_Robust"); -BOOST_CLASS_EXPORT_GUID(gtsam::SharedNoiseModel, "gtsam_SharedNoiseModel"); -BOOST_CLASS_EXPORT_GUID(gtsam::SharedDiagonal, "gtsam_SharedDiagonal"); +BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Diagonal, "gtsam_noiseModel_Diagonal") +BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::mEstimator::Base , "gtsam_noiseModel_mEstimator_Base") +BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::mEstimator::Null , "gtsam_noiseModel_mEstimator_Null") +BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::mEstimator::Fair , "gtsam_noiseModel_mEstimator_Fair") +BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::mEstimator::Huber, "gtsam_noiseModel_mEstimator_Huber") +BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::mEstimator::Tukey, "gtsam_noiseModel_mEstimator_Tukey") +BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Constrained, "gtsam_noiseModel_Constrained") +BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Unit, "gtsam_noiseModel_Unit") +BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Isotropic,"gtsam_noiseModel_Isotropic") +BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Robust, "gtsam_noiseModel_Robust") +BOOST_CLASS_EXPORT_GUID(gtsam::SharedNoiseModel, "gtsam_SharedNoiseModel") +BOOST_CLASS_EXPORT_GUID(gtsam::SharedDiagonal, "gtsam_SharedDiagonal") /* ************************************************************************* */ // Create GUIDs for factors -BOOST_CLASS_EXPORT_GUID(gtsam::PriorFactor, "gtsam::PriorFactor"); -BOOST_CLASS_EXPORT_GUID(gtsam::JacobianFactor, "gtsam::JacobianFactor"); -BOOST_CLASS_EXPORT_GUID(gtsam::HessianFactor , "gtsam::HessianFactor"); -BOOST_CLASS_EXPORT_GUID(gtsam::GaussianConditional , "gtsam::GaussianConditional"); +BOOST_CLASS_EXPORT_GUID(gtsam::PriorFactor, "gtsam::PriorFactor") +BOOST_CLASS_EXPORT_GUID(gtsam::JacobianFactor, "gtsam::JacobianFactor") +BOOST_CLASS_EXPORT_GUID(gtsam::HessianFactor , "gtsam::HessianFactor") +BOOST_CLASS_EXPORT_GUID(gtsam::GaussianConditional , "gtsam::GaussianConditional") /* ************************************************************************* */ // Export all classes derived from Value -GTSAM_VALUE_EXPORT(gtsam::Cal3_S2); -GTSAM_VALUE_EXPORT(gtsam::Cal3Bundler); -GTSAM_VALUE_EXPORT(gtsam::Point3); -GTSAM_VALUE_EXPORT(gtsam::Pose3); -GTSAM_VALUE_EXPORT(gtsam::Rot3); -GTSAM_VALUE_EXPORT(gtsam::PinholeCamera); -GTSAM_VALUE_EXPORT(gtsam::PinholeCamera); -GTSAM_VALUE_EXPORT(gtsam::PinholeCamera); +GTSAM_VALUE_EXPORT(gtsam::Cal3_S2) +GTSAM_VALUE_EXPORT(gtsam::Cal3Bundler) +GTSAM_VALUE_EXPORT(gtsam::Point3) +GTSAM_VALUE_EXPORT(gtsam::Pose3) +GTSAM_VALUE_EXPORT(gtsam::Rot3) +GTSAM_VALUE_EXPORT(gtsam::PinholeCamera) +GTSAM_VALUE_EXPORT(gtsam::PinholeCamera) +GTSAM_VALUE_EXPORT(gtsam::PinholeCamera) namespace detail { template struct pack { diff --git a/gtsam/geometry/tests/testUtilities.cpp b/gtsam/nonlinear/tests/testUtilities.cpp similarity index 68% rename from gtsam/geometry/tests/testUtilities.cpp rename to gtsam/nonlinear/tests/testUtilities.cpp index 25ac3acc8..55a7fdb13 100644 --- a/gtsam/geometry/tests/testUtilities.cpp +++ b/gtsam/nonlinear/tests/testUtilities.cpp @@ -21,7 +21,6 @@ #include #include #include -#include #include using namespace gtsam; @@ -55,6 +54,26 @@ TEST(Utilities, ExtractPoint3) { EXPECT_LONGS_EQUAL(2, all_points.rows()); } +/* ************************************************************************* */ +TEST(Utilities, ExtractVector) { + // Test normal case with 3 vectors and 1 non-vector (ignore non-vector) + auto values = Values(); + values.insert(X(0), (Vector(4) << 1, 2, 3, 4).finished()); + values.insert(X(2), (Vector(4) << 13, 14, 15, 16).finished()); + values.insert(X(1), (Vector(4) << 6, 7, 8, 9).finished()); + values.insert(X(3), Pose3()); + auto actual = utilities::extractVectors(values, 'x'); + auto expected = + (Matrix(3, 4) << 1, 2, 3, 4, 6, 7, 8, 9, 13, 14, 15, 16).finished(); + EXPECT(assert_equal(expected, actual)); + + // Check that mis-sized vectors fail + values.insert(X(4), (Vector(2) << 1, 2).finished()); + THROWS_EXCEPTION(utilities::extractVectors(values, 'x')); + values.update(X(4), (Vector(6) << 1, 2, 3, 4, 5, 6).finished()); + THROWS_EXCEPTION(utilities::extractVectors(values, 'x')); +} + /* ************************************************************************* */ int main() { srand(time(nullptr)); diff --git a/gtsam/nonlinear/tests/testValues.cpp b/gtsam/nonlinear/tests/testValues.cpp index b894f4816..bed2a8af9 100644 --- a/gtsam/nonlinear/tests/testValues.cpp +++ b/gtsam/nonlinear/tests/testValues.cpp @@ -172,6 +172,22 @@ TEST( Values, update_element ) CHECK(assert_equal((Vector)v2, cfg.at(key1))); } +TEST(Values, InsertOrAssign) { + Values values; + Key X(0); + double x = 1; + + CHECK(values.size() == 0); + // This should perform an insert. + values.insert_or_assign(X, x); + EXPECT(assert_equal(values.at(X), x)); + + // This should perform an update. + double y = 2; + values.insert_or_assign(X, y); + EXPECT(assert_equal(values.at(X), y)); +} + /* ************************************************************************* */ TEST(Values, basic_functions) { diff --git a/gtsam/nonlinear/utilities.h b/gtsam/nonlinear/utilities.h index fdc1da2c4..d2b38d374 100644 --- a/gtsam/nonlinear/utilities.h +++ b/gtsam/nonlinear/utilities.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -162,6 +163,34 @@ Matrix extractPose3(const Values& values) { return result; } +/// Extract all Vector values with a given symbol character into an mxn matrix, +/// where m is the number of symbols that match the character and n is the +/// dimension of the variables. If not all variables have dimension n, then a +/// runtime error will be thrown. The order of returned values are sorted by +/// the symbol. +/// For example, calling extractVector(values, 'x'), where values contains 200 +/// variables x1, x2, ..., x200 of type Vector each 5-dimensional, will return a +/// 200x5 matrix with row i containing xi. +Matrix extractVectors(const Values& values, char c) { + Values::ConstFiltered vectors = + values.filter(Symbol::ChrTest(c)); + if (vectors.size() == 0) { + return Matrix(); + } + auto dim = vectors.begin()->value.size(); + Matrix result(vectors.size(), dim); + Eigen::Index rowi = 0; + for (const auto& kv : vectors) { + if (kv.value.size() != dim) { + throw std::runtime_error( + "Tried to extract different-sized vectors into a single matrix"); + } + result.row(rowi) = kv.value; + ++rowi; + } + return result; +} + /// Perturb all Point2 values using normally distributed noise void perturbPoint2(Values& values, double sigma, int32_t seed = 42u) { noiseModel::Isotropic::shared_ptr model = diff --git a/gtsam/sam/sam.i b/gtsam/sam/sam.i index 370e1c3ea..90c319ede 100644 --- a/gtsam/sam/sam.i +++ b/gtsam/sam/sam.i @@ -20,12 +20,23 @@ virtual class RangeFactor : gtsam::NoiseModelFactor { // enabling serialization functionality void serialize() const; + + const double measured() const; }; +// between points: +typedef gtsam::RangeFactor RangeFactor2; +typedef gtsam::RangeFactor RangeFactor3; + +// between pose and point: typedef gtsam::RangeFactor RangeFactor2D; -typedef gtsam::RangeFactor RangeFactor3D; typedef gtsam::RangeFactor RangeFactorPose2; + +// between poses: +typedef gtsam::RangeFactor RangeFactor3D; typedef gtsam::RangeFactor RangeFactorPose3; + +// more specialized types: typedef gtsam::RangeFactor RangeFactorCalibratedCameraPoint; typedef gtsam::RangeFactor, gtsam::Point3> @@ -45,6 +56,9 @@ virtual class RangeFactorWithTransform : gtsam::NoiseModelFactor { // enabling serialization functionality void serialize() const; + + // Use `double` instead of template since that is all we need. + const double measured() const; }; typedef gtsam::RangeFactorWithTransform @@ -64,6 +78,8 @@ virtual class BearingFactor : gtsam::NoiseModelFactor { // enabling serialization functionality void serialize() const; + + const BEARING& measured() const; }; typedef gtsam::BearingFactor diff --git a/gtsam/sam/tests/testBearingFactor.cpp b/gtsam/sam/tests/testBearingFactor.cpp index 17a049a1d..904bdba31 100644 --- a/gtsam/sam/tests/testBearingFactor.cpp +++ b/gtsam/sam/tests/testBearingFactor.cpp @@ -21,14 +21,13 @@ #include #include #include -#include -#include #include using namespace std; using namespace gtsam; +namespace { Key poseKey(1); Key pointKey(2); @@ -41,43 +40,18 @@ typedef BearingFactor BearingFactor3D; Unit3 measurement3D = Pose3().bearing(Point3(1, 0, 0)); // has to match values! static SharedNoiseModel model3D(noiseModel::Isotropic::Sigma(2, 0.5)); BearingFactor3D factor3D(poseKey, pointKey, measurement3D, model3D); - -/* ************************************************************************* */ -// Export Noisemodels -// See http://www.boost.org/doc/libs/1_32_0/libs/serialization/doc/special.html -BOOST_CLASS_EXPORT(gtsam::noiseModel::Isotropic); - -/* ************************************************************************* */ -TEST(BearingFactor, Serialization2D) { - EXPECT(serializationTestHelpers::equalsObj(factor2D)); - EXPECT(serializationTestHelpers::equalsXML(factor2D)); - EXPECT(serializationTestHelpers::equalsBinary(factor2D)); } /* ************************************************************************* */ TEST(BearingFactor, 2D) { - // Serialize the factor - std::string serialized = serializeXML(factor2D); - - // And de-serialize it - BearingFactor2D factor; - deserializeXML(serialized, factor); - // Set the linearization point Values values; values.insert(poseKey, Pose2(1.0, 2.0, 0.57)); values.insert(pointKey, Point2(-4.0, 11.0)); - EXPECT_CORRECT_EXPRESSION_JACOBIANS(factor.expression({poseKey, pointKey}), + EXPECT_CORRECT_EXPRESSION_JACOBIANS(factor2D.expression({poseKey, pointKey}), values, 1e-7, 1e-5); - EXPECT_CORRECT_FACTOR_JACOBIANS(factor, values, 1e-7, 1e-5); -} - -/* ************************************************************************* */ -TEST(BearingFactor, Serialization3D) { - EXPECT(serializationTestHelpers::equalsObj(factor3D)); - EXPECT(serializationTestHelpers::equalsXML(factor3D)); - EXPECT(serializationTestHelpers::equalsBinary(factor3D)); + EXPECT_CORRECT_FACTOR_JACOBIANS(factor2D, values, 1e-7, 1e-5); } /* ************************************************************************* */ diff --git a/gtsam/sam/tests/testBearingRangeFactor.cpp b/gtsam/sam/tests/testBearingRangeFactor.cpp index 735358d89..0dcc227c7 100644 --- a/gtsam/sam/tests/testBearingRangeFactor.cpp +++ b/gtsam/sam/tests/testBearingRangeFactor.cpp @@ -21,14 +21,13 @@ #include #include #include -#include -#include #include using namespace std; using namespace gtsam; +namespace { Key poseKey(1); Key pointKey(2); @@ -40,43 +39,18 @@ typedef BearingRangeFactor BearingRangeFactor3D; static SharedNoiseModel model3D(noiseModel::Isotropic::Sigma(3, 0.5)); BearingRangeFactor3D factor3D(poseKey, pointKey, Pose3().bearing(Point3(1, 0, 0)), 1, model3D); - -/* ************************************************************************* */ -// Export Noisemodels -// See http://www.boost.org/doc/libs/1_32_0/libs/serialization/doc/special.html -BOOST_CLASS_EXPORT(gtsam::noiseModel::Isotropic); - -/* ************************************************************************* */ -TEST(BearingRangeFactor, Serialization2D) { - EXPECT(serializationTestHelpers::equalsObj(factor2D)); - EXPECT(serializationTestHelpers::equalsXML(factor2D)); - EXPECT(serializationTestHelpers::equalsBinary(factor2D)); } /* ************************************************************************* */ TEST(BearingRangeFactor, 2D) { - // Serialize the factor - std::string serialized = serializeXML(factor2D); - - // And de-serialize it - BearingRangeFactor2D factor; - deserializeXML(serialized, factor); - // Set the linearization point Values values; values.insert(poseKey, Pose2(1.0, 2.0, 0.57)); values.insert(pointKey, Point2(-4.0, 11.0)); - EXPECT_CORRECT_EXPRESSION_JACOBIANS(factor.expression({poseKey, pointKey}), + EXPECT_CORRECT_EXPRESSION_JACOBIANS(factor2D.expression({poseKey, pointKey}), values, 1e-7, 1e-5); - EXPECT_CORRECT_FACTOR_JACOBIANS(factor, values, 1e-7, 1e-5); -} - -/* ************************************************************************* */ -TEST(BearingRangeFactor, Serialization3D) { - EXPECT(serializationTestHelpers::equalsObj(factor3D)); - EXPECT(serializationTestHelpers::equalsXML(factor3D)); - EXPECT(serializationTestHelpers::equalsBinary(factor3D)); + EXPECT_CORRECT_FACTOR_JACOBIANS(factor2D, values, 1e-7, 1e-5); } /* ************************************************************************* */ diff --git a/gtsam/sam/tests/testRangeFactor.cpp b/gtsam/sam/tests/testRangeFactor.cpp index 5f5d4f4dd..200e1236a 100644 --- a/gtsam/sam/tests/testRangeFactor.cpp +++ b/gtsam/sam/tests/testRangeFactor.cpp @@ -22,7 +22,6 @@ #include #include #include -#include #include #include @@ -32,42 +31,40 @@ using namespace std::placeholders; using namespace std; using namespace gtsam; -// Create a noise model for the pixel error -static SharedNoiseModel model(noiseModel::Unit::Create(1)); - typedef RangeFactor RangeFactor2D; typedef RangeFactor RangeFactor3D; typedef RangeFactorWithTransform RangeFactorWithTransform2D; typedef RangeFactorWithTransform RangeFactorWithTransform3D; // Keys are deliberately *not* in sorted order to test that case. +namespace { +// Create a noise model for the pixel error +static SharedNoiseModel model(noiseModel::Unit::Create(1)); + constexpr Key poseKey(2); constexpr Key pointKey(1); constexpr double measurement(10.0); -/* ************************************************************************* */ Vector factorError2D(const Pose2& pose, const Point2& point, - const RangeFactor2D& factor) { + const RangeFactor2D& factor) { return factor.evaluateError(pose, point); } -/* ************************************************************************* */ Vector factorError3D(const Pose3& pose, const Point3& point, - const RangeFactor3D& factor) { + const RangeFactor3D& factor) { return factor.evaluateError(pose, point); } -/* ************************************************************************* */ Vector factorErrorWithTransform2D(const Pose2& pose, const Point2& point, - const RangeFactorWithTransform2D& factor) { + const RangeFactorWithTransform2D& factor) { return factor.evaluateError(pose, point); } -/* ************************************************************************* */ Vector factorErrorWithTransform3D(const Pose3& pose, const Point3& point, - const RangeFactorWithTransform3D& factor) { + const RangeFactorWithTransform3D& factor) { return factor.evaluateError(pose, point); } +} // namespace /* ************************************************************************* */ TEST( RangeFactor, Constructor) { @@ -75,27 +72,6 @@ TEST( RangeFactor, Constructor) { RangeFactor3D factor3D(poseKey, pointKey, measurement, model); } -/* ************************************************************************* */ -// Export Noisemodels -// See http://www.boost.org/doc/libs/1_32_0/libs/serialization/doc/special.html -BOOST_CLASS_EXPORT(gtsam::noiseModel::Unit); - -/* ************************************************************************* */ -TEST(RangeFactor, Serialization2D) { - RangeFactor2D factor2D(poseKey, pointKey, measurement, model); - EXPECT(serializationTestHelpers::equalsObj(factor2D)); - EXPECT(serializationTestHelpers::equalsXML(factor2D)); - EXPECT(serializationTestHelpers::equalsBinary(factor2D)); -} - -/* ************************************************************************* */ -TEST(RangeFactor, Serialization3D) { - RangeFactor3D factor3D(poseKey, pointKey, measurement, model); - EXPECT(serializationTestHelpers::equalsObj(factor3D)); - EXPECT(serializationTestHelpers::equalsXML(factor3D)); - EXPECT(serializationTestHelpers::equalsBinary(factor3D)); -} - /* ************************************************************************* */ TEST( RangeFactor, ConstructorWithTransform) { Pose2 body_P_sensor_2D(0.25, -0.10, -M_PI_2); @@ -142,28 +118,6 @@ TEST( RangeFactor, EqualsWithTransform ) { body_P_sensor_3D); CHECK(assert_equal(factor3D_1, factor3D_2)); } -/* ************************************************************************* */ -TEST( RangeFactor, EqualsAfterDeserializing) { - // Check that the same results are obtained after deserializing: - Pose3 body_P_sensor_3D(Rot3::RzRyRx(-M_PI_2, 0.0, -M_PI_2), - Point3(0.25, -0.10, 1.0)); - - RangeFactorWithTransform3D factor3D_1(poseKey, pointKey, measurement, model, - body_P_sensor_3D), factor3D_2; - - // check with Equal() trait: - gtsam::serializationTestHelpers::roundtripXML(factor3D_1, factor3D_2); - CHECK(assert_equal(factor3D_1, factor3D_2)); - - const Pose3 pose(Rot3::RzRyRx(0.2, -0.3, 1.75), Point3(1.0, 2.0, -3.0)); - const Point3 point(-2.0, 11.0, 1.0); - const Values values = {{poseKey, genericValue(pose)}, {pointKey, genericValue(point)}}; - - const Vector error_1 = factor3D_1.unwhitenedError(values); - const Vector error_2 = factor3D_2.unwhitenedError(values); - CHECK(assert_equal(error_1, error_2)); -} - /* ************************************************************************* */ TEST( RangeFactor, Error2D ) { // Create a factor @@ -411,7 +365,7 @@ TEST( RangeFactor, Camera) { /* ************************************************************************* */ // Do a test with non GTSAM types -namespace gtsam{ +namespace gtsam { template <> struct Range { typedef double result_type; @@ -421,7 +375,7 @@ struct Range { // derivatives not implemented } }; -} +} // namespace gtsam TEST(RangeFactor, NonGTSAM) { // Create a factor diff --git a/gtsam/sam/tests/testSerializationSam.cpp b/gtsam/sam/tests/testSerializationSam.cpp new file mode 100644 index 000000000..8fdd8f37e --- /dev/null +++ b/gtsam/sam/tests/testSerializationSam.cpp @@ -0,0 +1,140 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file testSerializationSam.cpp + * @brief All serialization test in this directory + * @author Frank Dellaert + * @date February 2022 + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std; +using namespace gtsam; + +namespace { +Key poseKey(1); +Key pointKey(2); +constexpr double rangeMmeasurement(10.0); +} // namespace + +/* ************************************************************************* */ +// Export Noisemodels +// See http://www.boost.org/doc/libs/1_32_0/libs/serialization/doc/special.html +BOOST_CLASS_EXPORT(gtsam::noiseModel::Isotropic); +BOOST_CLASS_EXPORT(gtsam::noiseModel::Unit); + +/* ************************************************************************* */ +TEST(SerializationSam, BearingFactor2D) { + using BearingFactor2D = BearingFactor; + double measurement2D(10.0); + static SharedNoiseModel model2D(noiseModel::Isotropic::Sigma(1, 0.5)); + BearingFactor2D factor2D(poseKey, pointKey, measurement2D, model2D); + EXPECT(serializationTestHelpers::equalsObj(factor2D)); + EXPECT(serializationTestHelpers::equalsXML(factor2D)); + EXPECT(serializationTestHelpers::equalsBinary(factor2D)); +} + +/* ************************************************************************* */ +TEST(SerializationSam, BearingFactor3D) { + using BearingFactor3D = BearingFactor; + Unit3 measurement3D = + Pose3().bearing(Point3(1, 0, 0)); // has to match values! + static SharedNoiseModel model3D(noiseModel::Isotropic::Sigma(2, 0.5)); + BearingFactor3D factor3D(poseKey, pointKey, measurement3D, model3D); + EXPECT(serializationTestHelpers::equalsObj(factor3D)); + EXPECT(serializationTestHelpers::equalsXML(factor3D)); + EXPECT(serializationTestHelpers::equalsBinary(factor3D)); +} + +/* ************************************************************************* */ +namespace { +static SharedNoiseModel rangeNoiseModel(noiseModel::Unit::Create(1)); +} + +TEST(SerializationSam, RangeFactor2D) { + using RangeFactor2D = RangeFactor; + RangeFactor2D factor2D(poseKey, pointKey, rangeMmeasurement, rangeNoiseModel); + EXPECT(serializationTestHelpers::equalsObj(factor2D)); + EXPECT(serializationTestHelpers::equalsXML(factor2D)); + EXPECT(serializationTestHelpers::equalsBinary(factor2D)); +} + +/* ************************************************************************* */ +TEST(SerializationSam, RangeFactor3D) { + using RangeFactor3D = RangeFactor; + RangeFactor3D factor3D(poseKey, pointKey, rangeMmeasurement, rangeNoiseModel); + EXPECT(serializationTestHelpers::equalsObj(factor3D)); + EXPECT(serializationTestHelpers::equalsXML(factor3D)); + EXPECT(serializationTestHelpers::equalsBinary(factor3D)); +} + +/* ************************************************************************* */ +TEST(RangeFactor, EqualsAfterDeserializing) { + // Check that the same results are obtained after deserializing: + Pose3 body_P_sensor_3D(Rot3::RzRyRx(-M_PI_2, 0.0, -M_PI_2), + Point3(0.25, -0.10, 1.0)); + RangeFactorWithTransform factor3D_1( + poseKey, pointKey, rangeMmeasurement, rangeNoiseModel, body_P_sensor_3D), + factor3D_2; + + // check with Equal() trait: + gtsam::serializationTestHelpers::roundtripXML(factor3D_1, factor3D_2); + CHECK(assert_equal(factor3D_1, factor3D_2)); + + const Pose3 pose(Rot3::RzRyRx(0.2, -0.3, 1.75), Point3(1.0, 2.0, -3.0)); + const Point3 point(-2.0, 11.0, 1.0); + const Values values = {{poseKey, genericValue(pose)}, + {pointKey, genericValue(point)}}; + + const Vector error_1 = factor3D_1.unwhitenedError(values); + const Vector error_2 = factor3D_2.unwhitenedError(values); + CHECK(assert_equal(error_1, error_2)); +} + +/* ************************************************************************* */ +TEST(BearingRangeFactor, Serialization2D) { + using BearingRangeFactor2D = BearingRangeFactor; + static SharedNoiseModel model2D(noiseModel::Isotropic::Sigma(2, 0.5)); + BearingRangeFactor2D factor2D(poseKey, pointKey, 1, 2, model2D); + + EXPECT(serializationTestHelpers::equalsObj(factor2D)); + EXPECT(serializationTestHelpers::equalsXML(factor2D)); + EXPECT(serializationTestHelpers::equalsBinary(factor2D)); +} + +/* ************************************************************************* */ +TEST(BearingRangeFactor, Serialization3D) { + using BearingRangeFactor3D = BearingRangeFactor; + static SharedNoiseModel model3D(noiseModel::Isotropic::Sigma(3, 0.5)); + BearingRangeFactor3D factor3D(poseKey, pointKey, + Pose3().bearing(Point3(1, 0, 0)), 1, model3D); + EXPECT(serializationTestHelpers::equalsObj(factor3D)); + EXPECT(serializationTestHelpers::equalsXML(factor3D)); + EXPECT(serializationTestHelpers::equalsBinary(factor3D)); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/gtsam/sfm/MFAS.h b/gtsam/sfm/MFAS.h index decfbed0f..151b318ad 100644 --- a/gtsam/sfm/MFAS.h +++ b/gtsam/sfm/MFAS.h @@ -48,7 +48,7 @@ namespace gtsam { unit translations in a projection direction. @addtogroup SFM */ -class MFAS { +class GTSAM_EXPORT MFAS { public: // used to represent edges between two nodes in the graph. When used in // translation averaging for global SfM diff --git a/gtsam/sfm/SfmData.cpp b/gtsam/sfm/SfmData.cpp new file mode 100644 index 000000000..6c2676e48 --- /dev/null +++ b/gtsam/sfm/SfmData.cpp @@ -0,0 +1,459 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file SfmData.cpp + * @date January 2022 + * @author Frank dellaert + * @brief Data structure for dealing with Structure from Motion data + */ + +#include +#include +#include + +#include +#include + +namespace gtsam { + +using std::cout; +using std::endl; + +using gtsam::symbol_shorthand::P; + +/* ************************************************************************** */ +void SfmData::print(const std::string &s) const { + std::cout << "Number of cameras = " << cameras.size() << std::endl; + std::cout << "Number of tracks = " << tracks.size() << std::endl; +} + +/* ************************************************************************** */ +bool SfmData::equals(const SfmData &sfmData, double tol) const { + // check number of cameras and tracks + if (cameras.size() != sfmData.cameras.size() || + tracks.size() != sfmData.tracks.size()) { + return false; + } + + // check each camera + for (size_t i = 0; i < cameras.size(); ++i) { + if (!camera(i).equals(sfmData.camera(i), tol)) { + return false; + } + } + + // check each track + for (size_t j = 0; j < tracks.size(); ++j) { + if (!track(j).equals(sfmData.track(j), tol)) { + return false; + } + } + + return true; +} + +/* ************************************************************************* */ +Rot3 openGLFixedRotation() { // this is due to different convention for + // cameras in gtsam and openGL + /* R = [ 1 0 0 + * 0 -1 0 + * 0 0 -1] + */ + Matrix3 R_mat = Matrix3::Zero(3, 3); + R_mat(0, 0) = 1.0; + R_mat(1, 1) = -1.0; + R_mat(2, 2) = -1.0; + return Rot3(R_mat); +} + +/* ************************************************************************* */ +Pose3 openGL2gtsam(const Rot3 &R, double tx, double ty, double tz) { + Rot3 R90 = openGLFixedRotation(); + Rot3 wRc = (R.inverse()).compose(R90); + + // Our camera-to-world translation wTc = -R'*t + return Pose3(wRc, R.unrotate(Point3(-tx, -ty, -tz))); +} + +/* ************************************************************************* */ +Pose3 gtsam2openGL(const Rot3 &R, double tx, double ty, double tz) { + Rot3 R90 = openGLFixedRotation(); + Rot3 cRw_openGL = R90.compose(R.inverse()); + Point3 t_openGL = cRw_openGL.rotate(Point3(-tx, -ty, -tz)); + return Pose3(cRw_openGL, t_openGL); +} + +/* ************************************************************************* */ +Pose3 gtsam2openGL(const Pose3 &PoseGTSAM) { + return gtsam2openGL(PoseGTSAM.rotation(), PoseGTSAM.x(), PoseGTSAM.y(), + PoseGTSAM.z()); +} + +/* ************************************************************************** */ +SfmData SfmData::FromBundlerFile(const std::string &filename) { + // Load the data file + std::ifstream is(filename.c_str(), std::ifstream::in); + if (!is) { + throw std::runtime_error( + "Error in FromBundlerFile: can not find the file!!"); + } + + SfmData sfmData; + + // Ignore the first line + char aux[500]; + is.getline(aux, 500); + + // Get the number of camera poses and 3D points + size_t nrPoses, nrPoints; + is >> nrPoses >> nrPoints; + + // Get the information for the camera poses + for (size_t i = 0; i < nrPoses; i++) { + // Get the focal length and the radial distortion parameters + float f, k1, k2; + is >> f >> k1 >> k2; + Cal3Bundler K(f, k1, k2); + + // Get the rotation matrix + float r11, r12, r13; + float r21, r22, r23; + float r31, r32, r33; + is >> r11 >> r12 >> r13 >> r21 >> r22 >> r23 >> r31 >> r32 >> r33; + + // Bundler-OpenGL rotation matrix + Rot3 R(r11, r12, r13, r21, r22, r23, r31, r32, r33); + + // Check for all-zero R, in which case quit + if (r11 == 0 && r12 == 0 && r13 == 0) { + throw std::runtime_error( + "Error in FromBundlerFile: zero rotation matrix"); + } + + // Get the translation vector + float tx, ty, tz; + is >> tx >> ty >> tz; + + Pose3 pose = openGL2gtsam(R, tx, ty, tz); + + sfmData.cameras.emplace_back(pose, K); + } + + // Get the information for the 3D points + sfmData.tracks.reserve(nrPoints); + for (size_t j = 0; j < nrPoints; j++) { + SfmTrack track; + + // Get the 3D position + float x, y, z; + is >> x >> y >> z; + track.p = Point3(x, y, z); + + // Get the color information + float r, g, b; + is >> r >> g >> b; + track.r = r / 255.f; + track.g = g / 255.f; + track.b = b / 255.f; + + // Now get the visibility information + size_t nvisible = 0; + is >> nvisible; + + track.measurements.reserve(nvisible); + track.siftIndices.reserve(nvisible); + for (size_t k = 0; k < nvisible; k++) { + size_t cam_idx = 0, point_idx = 0; + float u, v; + is >> cam_idx >> point_idx >> u >> v; + track.measurements.emplace_back(cam_idx, Point2(u, -v)); + track.siftIndices.emplace_back(cam_idx, point_idx); + } + + sfmData.tracks.push_back(track); + } + + return sfmData; +} + +/* ************************************************************************** */ +SfmData SfmData::FromBalFile(const std::string &filename) { + // Load the data file + std::ifstream is(filename.c_str(), std::ifstream::in); + if (!is) { + throw std::runtime_error("Error in FromBalFile: can not find the file!!"); + } + + SfmData sfmData; + + // Get the number of camera poses and 3D points + size_t nrPoses, nrPoints, nrObservations; + is >> nrPoses >> nrPoints >> nrObservations; + + sfmData.tracks.resize(nrPoints); + + // Get the information for the observations + for (size_t k = 0; k < nrObservations; k++) { + size_t i = 0, j = 0; + float u, v; + is >> i >> j >> u >> v; + sfmData.tracks[j].measurements.emplace_back(i, Point2(u, -v)); + } + + // Get the information for the camera poses + for (size_t i = 0; i < nrPoses; i++) { + // Get the Rodrigues vector + float wx, wy, wz; + is >> wx >> wy >> wz; + Rot3 R = Rot3::Rodrigues(wx, wy, wz); // BAL-OpenGL rotation matrix + + // Get the translation vector + float tx, ty, tz; + is >> tx >> ty >> tz; + + Pose3 pose = openGL2gtsam(R, tx, ty, tz); + + // Get the focal length and the radial distortion parameters + float f, k1, k2; + is >> f >> k1 >> k2; + Cal3Bundler K(f, k1, k2); + + sfmData.cameras.emplace_back(pose, K); + } + + // Get the information for the 3D points + for (size_t j = 0; j < nrPoints; j++) { + // Get the 3D position + float x, y, z; + is >> x >> y >> z; + SfmTrack &track = sfmData.tracks[j]; + track.p = Point3(x, y, z); + track.r = 0.4f; + track.g = 0.4f; + track.b = 0.4f; + } + + return sfmData; +} + +/* ************************************************************************** */ +bool writeBAL(const std::string &filename, const SfmData &data) { + // Open the output file + std::ofstream os; + os.open(filename.c_str()); + os.precision(20); + if (!os.is_open()) { + cout << "Error in writeBAL: can not open the file!!" << endl; + return false; + } + + // Write the number of camera poses and 3D points + size_t nrObservations = 0; + for (size_t j = 0; j < data.tracks.size(); j++) { + nrObservations += data.tracks[j].numberMeasurements(); + } + + // Write observations + os << data.cameras.size() << " " << data.tracks.size() << " " + << nrObservations << endl; + os << endl; + + for (size_t j = 0; j < data.tracks.size(); j++) { // for each 3D point j + const SfmTrack &track = data.tracks[j]; + + for (size_t k = 0; k < track.numberMeasurements(); + k++) { // for each observation of the 3D point j + size_t i = track.measurements[k].first; // camera id + double u0 = data.cameras[i].calibration().px(); + double v0 = data.cameras[i].calibration().py(); + + if (u0 != 0 || v0 != 0) { + cout << "writeBAL has not been tested for calibration with nonzero " + "(u0,v0)" + << endl; + } + + double pixelBALx = track.measurements[k].second.x() - + u0; // center of image is the origin + double pixelBALy = -(track.measurements[k].second.y() - + v0); // center of image is the origin + Point2 pixelMeasurement(pixelBALx, pixelBALy); + os << i /*camera id*/ << " " << j /*point id*/ << " " + << pixelMeasurement.x() /*u of the pixel*/ << " " + << pixelMeasurement.y() /*v of the pixel*/ << endl; + } + } + os << endl; + + // Write cameras + for (size_t i = 0; i < data.cameras.size(); i++) { // for each camera + Pose3 poseGTSAM = data.cameras[i].pose(); + Cal3Bundler cameraCalibration = data.cameras[i].calibration(); + Pose3 poseOpenGL = gtsam2openGL(poseGTSAM); + os << Rot3::Logmap(poseOpenGL.rotation()) << endl; + os << poseOpenGL.translation().x() << endl; + os << poseOpenGL.translation().y() << endl; + os << poseOpenGL.translation().z() << endl; + os << cameraCalibration.fx() << endl; + os << cameraCalibration.k1() << endl; + os << cameraCalibration.k2() << endl; + os << endl; + } + + // Write the points + for (size_t j = 0; j < data.tracks.size(); j++) { // for each 3D point j + Point3 point = data.tracks[j].p; + os << point.x() << endl; + os << point.y() << endl; + os << point.z() << endl; + os << endl; + } + + os.close(); + return true; +} + +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 +bool readBundler(const std::string &filename, SfmData &data) { + try { + data = SfmData::FromBundlerFile(filename); + return true; + } catch (const std::exception & /* e */) { + return false; + } +} +bool readBAL(const std::string &filename, SfmData &data) { + try { + data = SfmData::FromBalFile(filename); + return true; + } catch (const std::exception & /* e */) { + return false; + } +} +#endif + +SfmData readBal(const std::string &filename) { + return SfmData::FromBalFile(filename); +} + +/* ************************************************************************** */ +bool writeBALfromValues(const std::string &filename, const SfmData &data, + const Values &values) { + using Camera = PinholeCamera; + SfmData dataValues = data; + + // Store poses or cameras in SfmData + size_t nrPoses = values.count(); + if (nrPoses == dataValues.cameras.size()) { // we only estimated camera poses + for (size_t i = 0; i < dataValues.cameras.size(); i++) { // for each camera + Pose3 pose = values.at(i); + Cal3Bundler K = dataValues.cameras[i].calibration(); + Camera camera(pose, K); + dataValues.cameras[i] = camera; + } + } else { + size_t nrCameras = values.count(); + if (nrCameras == dataValues.cameras.size()) { // we only estimated camera + // poses and calibration + for (size_t i = 0; i < nrCameras; i++) { // for each camera + Key cameraKey = i; // symbol('c',i); + Camera camera = values.at(cameraKey); + dataValues.cameras[i] = camera; + } + } else { + cout << "writeBALfromValues: different number of cameras in " + "SfM_dataValues (#cameras " + << dataValues.cameras.size() << ") and values (#cameras " << nrPoses + << ", #poses " << nrCameras << ")!!" << endl; + return false; + } + } + + // Store 3D points in SfmData + size_t nrPoints = values.count(), nrTracks = dataValues.tracks.size(); + if (nrPoints != nrTracks) { + cout << "writeBALfromValues: different number of points in " + "SfM_dataValues (#points= " + << nrTracks << ") and values (#points " << nrPoints << ")!!" << endl; + } + + for (size_t j = 0; j < nrTracks; j++) { // for each point + Key pointKey = P(j); + if (values.exists(pointKey)) { + Point3 point = values.at(pointKey); + dataValues.tracks[j].p = point; + } else { + dataValues.tracks[j].r = 1.0; + dataValues.tracks[j].g = 0.0; + dataValues.tracks[j].b = 0.0; + dataValues.tracks[j].p = Point3(0, 0, 0); + } + } + + // Write SfmData to file + return writeBAL(filename, dataValues); +} + +/* ************************************************************************** */ +NonlinearFactorGraph SfmData::generalSfmFactors( + const SharedNoiseModel &model) const { + using ProjectionFactor = GeneralSFMFactor; + NonlinearFactorGraph factors; + + size_t j = 0; + for (const SfmTrack &track : tracks) { + for (const SfmMeasurement &m : track.measurements) { + size_t i = m.first; + Point2 uv = m.second; + factors.emplace_shared(uv, model, i, P(j)); + } + j += 1; + } + + return factors; +} + +/* ************************************************************************** */ +NonlinearFactorGraph SfmData::sfmFactorGraph( + const SharedNoiseModel &model, boost::optional fixedCamera, + boost::optional fixedPoint) const { + NonlinearFactorGraph graph = generalSfmFactors(model); + using noiseModel::Constrained; + if (fixedCamera) { + graph.addPrior(*fixedCamera, cameras[0], Constrained::All(9)); + } + if (fixedPoint) { + graph.addPrior(P(*fixedPoint), tracks[0].p, Constrained::All(3)); + } + return graph; +} + +/* ************************************************************************** */ +Values initialCamerasEstimate(const SfmData &db) { + Values initial; + size_t i = 0; // NO POINTS: j = 0; + for (const SfmCamera &camera : db.cameras) initial.insert(i++, camera); + return initial; +} + +/* ************************************************************************** */ +Values initialCamerasAndPointsEstimate(const SfmData &db) { + Values initial; + size_t i = 0, j = 0; + for (const SfmCamera &camera : db.cameras) initial.insert(i++, camera); + for (const SfmTrack &track : db.tracks) initial.insert(P(j++), track.p); + return initial; +} + +/* ************************************************************************** */ + +} // namespace gtsam diff --git a/gtsam/sfm/SfmData.h b/gtsam/sfm/SfmData.h new file mode 100644 index 000000000..afce12205 --- /dev/null +++ b/gtsam/sfm/SfmData.h @@ -0,0 +1,236 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file SfmData.h + * @date January 2022 + * @author Frank dellaert + * @brief Data structure for dealing with Structure from Motion data + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +namespace gtsam { + +/// Define the structure for the camera poses +typedef PinholeCamera SfmCamera; + +/** + * @brief SfmData stores a bunch of SfmTracks + * @addtogroup sfm + */ +struct GTSAM_EXPORT SfmData { + std::vector cameras; ///< Set of cameras + + std::vector tracks; ///< Sparse set of points + + /// @name Create from file + /// @{ + + /** + * @brief Parses a bundler output file and return result as SfmData instance. + * @param filename The name of the bundler file + * @param data SfM structure where the data is stored + * @return true if the parsing was successful, false otherwise + */ + static SfmData FromBundlerFile(const std::string& filename); + + /** + * @brief Parse a "Bundle Adjustment in the Large" (BAL) file and return + * result as SfmData instance. + * @param filename The name of the BAL file. + * @return SfM structure where the data is stored. + */ + static SfmData FromBalFile(const std::string& filename); + + /// @} + /// @name Standard Interface + /// @{ + + /// Add a track to SfmData + void addTrack(const SfmTrack& t) { tracks.push_back(t); } + + /// Add a camera to SfmData + void addCamera(const SfmCamera& cam) { cameras.push_back(cam); } + + /// The number of reconstructed 3D points + size_t numberTracks() const { return tracks.size(); } + + /// The number of cameras + size_t numberCameras() const { return cameras.size(); } + + /// The track formed by series of landmark measurements + SfmTrack track(size_t idx) const { return tracks[idx]; } + + /// The camera pose at frame index `idx` + SfmCamera camera(size_t idx) const { return cameras[idx]; } + + /** + * @brief Create projection factors using keys i and P(j) + * + * @param model a noise model for projection errors + * @return NonlinearFactorGraph + */ + NonlinearFactorGraph generalSfmFactors( + const SharedNoiseModel& model = noiseModel::Isotropic::Sigma(2, + 1.0)) const; + + /** + * @brief Create factor graph with projection factors and gauge fix. + * + * Note: pose keys are simply integer indices, points use Symbol('p', j). + * + * @param model a noise model for projection errors + * @param fixedCamera which camera to fix, if any (use boost::none if none) + * @param fixedPoint which point to fix, if any (use boost::none if none) + * @return NonlinearFactorGraph + */ + NonlinearFactorGraph sfmFactorGraph( + const SharedNoiseModel& model = noiseModel::Isotropic::Sigma(2, 1.0), + boost::optional fixedCamera = 0, + boost::optional fixedPoint = 0) const; + + /// @} + /// @name Testable + /// @{ + + /// print + void print(const std::string& s = "") const; + + /// assert equality up to a tolerance + bool equals(const SfmData& sfmData, double tol = 1e-9) const; + + /// @} +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + /// @name Deprecated + /// @{ + void GTSAM_DEPRECATED add_track(const SfmTrack& t) { tracks.push_back(t); } + void GTSAM_DEPRECATED add_camera(const SfmCamera& cam) { + cameras.push_back(cam); + } + size_t GTSAM_DEPRECATED number_tracks() const { return tracks.size(); } + size_t GTSAM_DEPRECATED number_cameras() const { return cameras.size(); } + /// @} +#endif + /// @name Serialization + /// @{ + + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(Archive& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_NVP(cameras); + ar& BOOST_SERIALIZATION_NVP(tracks); + } + + /// @} +}; + +/// traits +template <> +struct traits : public Testable {}; + +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 +GTSAM_EXPORT bool GTSAM_DEPRECATED readBundler(const std::string& filename, + SfmData& data); +GTSAM_EXPORT bool GTSAM_DEPRECATED readBAL(const std::string& filename, + SfmData& data); +#endif + +/** + * @brief This function parses a "Bundle Adjustment in the Large" (BAL) file and + * returns the data as a SfmData structure. Mainly used by wrapped code. + * @param filename The name of the BAL file. + * @return SfM structure where the data is stored. + */ +GTSAM_EXPORT SfmData readBal(const std::string& filename); + +/** + * @brief This function writes a "Bundle Adjustment in the Large" (BAL) file + * from a SfmData structure + * @param filename The name of the BAL file to write + * @param data SfM structure where the data is stored + * @return true if the parsing was successful, false otherwise + */ +GTSAM_EXPORT bool writeBAL(const std::string& filename, const SfmData& data); + +/** + * @brief This function writes a "Bundle Adjustment in the Large" (BAL) file + * from a SfmData structure and a value structure (measurements are the same as + * the SfM input data, while camera poses and values are read from Values) + * @param filename The name of the BAL file to write + * @param data SfM structure where the data is stored + * @param values structure where the graph values are stored (values can be + * either Pose3 or PinholeCamera for the cameras, and should be + * Point3 for the 3D points). Note: assumes that the keys are "i" for pose i + * and "Symbol::('p',j)" for landmark j. + * @return true if the parsing was successful, false otherwise + */ +GTSAM_EXPORT bool writeBALfromValues(const std::string& filename, + const SfmData& data, const Values& values); + +/** + * @brief This function converts an openGL camera pose to an GTSAM camera pose + * @param R rotation in openGL + * @param tx x component of the translation in openGL + * @param ty y component of the translation in openGL + * @param tz z component of the translation in openGL + * @return Pose3 in GTSAM format + */ +GTSAM_EXPORT Pose3 openGL2gtsam(const Rot3& R, double tx, double ty, double tz); + +/** + * @brief This function converts a GTSAM camera pose to an openGL camera pose + * @param R rotation in GTSAM + * @param tx x component of the translation in GTSAM + * @param ty y component of the translation in GTSAM + * @param tz z component of the translation in GTSAM + * @return Pose3 in openGL format + */ +GTSAM_EXPORT Pose3 gtsam2openGL(const Rot3& R, double tx, double ty, double tz); + +/** + * @brief This function converts a GTSAM camera pose to an openGL camera pose + * @param PoseGTSAM pose in GTSAM format + * @return Pose3 in openGL format + */ +GTSAM_EXPORT Pose3 gtsam2openGL(const Pose3& PoseGTSAM); + +/** + * @brief This function creates initial values for cameras from db. + * + * No symbol is used, so camera keys are simply integer indices. + * + * @param SfmData + * @return Values + */ +GTSAM_EXPORT Values initialCamerasEstimate(const SfmData& db); + +/** + * @brief This function creates initial values for cameras and points from db + * + * Note: Pose keys are simply integer indices, points use Symbol('p', j). + * + * @param SfmData + * @return Values + */ +GTSAM_EXPORT Values initialCamerasAndPointsEstimate(const SfmData& db); + +} // namespace gtsam diff --git a/gtsam/sfm/SfmTrack.cpp b/gtsam/sfm/SfmTrack.cpp new file mode 100644 index 000000000..d571e7c35 --- /dev/null +++ b/gtsam/sfm/SfmTrack.cpp @@ -0,0 +1,71 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file SfmTrack.cpp + * @date January 2022 + * @author Frank Dellaert + * @brief A simple data structure for a track in Structure from Motion + */ + +#include + +#include + +namespace gtsam { + +void SfmTrack::print(const std::string& s) const { + std::cout << "Track with " << measurements.size(); + std::cout << " measurements of point " << p << std::endl; +} + +bool SfmTrack::equals(const SfmTrack& sfmTrack, double tol) const { + // check the 3D point + if (!p.isApprox(sfmTrack.p)) { + return false; + } + + // check the RGB values + if (r != sfmTrack.r || g != sfmTrack.g || b != sfmTrack.b) { + return false; + } + + // compare size of vectors for measurements and siftIndices + if (numberMeasurements() != sfmTrack.numberMeasurements() || + siftIndices.size() != sfmTrack.siftIndices.size()) { + return false; + } + + // compare measurements (order sensitive) + for (size_t idx = 0; idx < numberMeasurements(); ++idx) { + SfmMeasurement measurement = measurements[idx]; + SfmMeasurement otherMeasurement = sfmTrack.measurements[idx]; + + if (measurement.first != otherMeasurement.first || + !measurement.second.isApprox(otherMeasurement.second)) { + return false; + } + } + + // compare sift indices (order sensitive) + for (size_t idx = 0; idx < siftIndices.size(); ++idx) { + SiftIndex index = siftIndices[idx]; + SiftIndex otherIndex = sfmTrack.siftIndices[idx]; + + if (index.first != otherIndex.first || index.second != otherIndex.second) { + return false; + } + } + + return true; +} + +} // namespace gtsam diff --git a/gtsam/sfm/SfmTrack.h b/gtsam/sfm/SfmTrack.h new file mode 100644 index 000000000..37731b32a --- /dev/null +++ b/gtsam/sfm/SfmTrack.h @@ -0,0 +1,133 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file SfmTrack.h + * @date January 2022 + * @author Frank Dellaert + * @brief A simple data structure for a track in Structure from Motion + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include + +namespace gtsam { + +/// A measurement with its camera index +typedef std::pair SfmMeasurement; + +/// Sift index for SfmTrack +typedef std::pair SiftIndex; + +/** + * @brief An SfmTrack stores SfM measurements grouped in a track + * @addtogroup sfm + */ +struct GTSAM_EXPORT SfmTrack { + Point3 p; ///< 3D position of the point + float r, g, b; ///< RGB color of the 3D point + + /// The 2D image projections (id,(u,v)) + std::vector measurements; + + /// The feature descriptors + std::vector siftIndices; + + /// @name Constructors + /// @{ + + explicit SfmTrack(float r = 0, float g = 0, float b = 0) + : p(0, 0, 0), r(r), g(g), b(b) {} + + explicit SfmTrack(const gtsam::Point3& pt, float r = 0, float g = 0, + float b = 0) + : p(pt), r(r), g(g), b(b) {} + + /// @} + /// @name Standard Interface + /// @{ + + /// Add measurement (camera_idx, Point2) to track + void addMeasurement(size_t idx, const gtsam::Point2& m) { + measurements.emplace_back(idx, m); + } + + /// Total number of measurements in this track + size_t numberMeasurements() const { return measurements.size(); } + + /// Get the measurement (camera index, Point2) at pose index `idx` + const SfmMeasurement& measurement(size_t idx) const { + return measurements[idx]; + } + + /// Get the SIFT feature index corresponding to the measurement at `idx` + const SiftIndex& siftIndex(size_t idx) const { return siftIndices[idx]; } + + /// Get 3D point + const Point3& point3() const { return p; } + + /// Get RGB values describing 3d point + Point3 rgb() const { return Point3(r, g, b); } + + /// @} + /// @name Testable + /// @{ + + /// print + void print(const std::string& s = "") const; + + /// assert equality up to a tolerance + bool equals(const SfmTrack& sfmTrack, double tol = 1e-9) const; + + /// @} +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + /// @name Deprecated + /// @{ + void GTSAM_DEPRECATED add_measurement(size_t idx, const gtsam::Point2& m) { + measurements.emplace_back(idx, m); + } + + size_t GTSAM_DEPRECATED number_measurements() const { + return measurements.size(); + } + /// @} +#endif + /// @name Serialization + /// @{ + + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_NVP(p); + ar& BOOST_SERIALIZATION_NVP(r); + ar& BOOST_SERIALIZATION_NVP(g); + ar& BOOST_SERIALIZATION_NVP(b); + ar& BOOST_SERIALIZATION_NVP(measurements); + ar& BOOST_SERIALIZATION_NVP(siftIndices); + } + /// @} +}; + +template +struct traits; + +template <> +struct traits : public Testable {}; + +} // namespace gtsam diff --git a/gtsam/sfm/ShonanAveraging.h b/gtsam/sfm/ShonanAveraging.h index de12de478..e035da4c7 100644 --- a/gtsam/sfm/ShonanAveraging.h +++ b/gtsam/sfm/ShonanAveraging.h @@ -165,7 +165,7 @@ class GTSAM_EXPORT ShonanAveraging { size_t nrUnknowns() const { return nrUnknowns_; } /// Return number of measurements - size_t nrMeasurements() const { return measurements_.size(); } + size_t numberMeasurements() const { return measurements_.size(); } /// k^th binary measurement const BinaryMeasurement &measurement(size_t k) const { @@ -300,6 +300,7 @@ class GTSAM_EXPORT ShonanAveraging { /** * Create initial Values of type SO(p) * @param p the dimensionality of the rotation manifold + * @param rng random number generator */ Values initializeRandomlyAt(size_t p, std::mt19937 &rng) const; diff --git a/gtsam/sfm/TranslationRecovery.cpp b/gtsam/sfm/TranslationRecovery.cpp index f38c14ba7..2e81c2d56 100644 --- a/gtsam/sfm/TranslationRecovery.cpp +++ b/gtsam/sfm/TranslationRecovery.cpp @@ -35,6 +35,9 @@ using namespace gtsam; using namespace std; +// In Wrappers we have no access to this so have a default ready. +static std::mt19937 kRandomNumberGenerator(42); + TranslationRecovery::TranslationRecovery( const TranslationRecovery::TranslationEdges &relativeTranslations, const LevenbergMarquardtParams &lmParams) @@ -88,13 +91,15 @@ void TranslationRecovery::addPrior( edge->key2(), scale * edge->measured().point3(), edge->noiseModel()); } -Values TranslationRecovery::initalizeRandomly() const { +Values TranslationRecovery::initializeRandomly(std::mt19937 *rng) const { + uniform_real_distribution randomVal(-1, 1); // Create a lambda expression that checks whether value exists and randomly // initializes if not. Values initial; - auto insert = [&initial](Key j) { + auto insert = [&](Key j) { if (!initial.exists(j)) { - initial.insert(j, Vector3::Random()); + initial.insert( + j, Point3(randomVal(*rng), randomVal(*rng), randomVal(*rng))); } }; @@ -115,10 +120,14 @@ Values TranslationRecovery::initalizeRandomly() const { return initial; } +Values TranslationRecovery::initializeRandomly() const { + return initializeRandomly(&kRandomNumberGenerator); +} + Values TranslationRecovery::run(const double scale) const { auto graph = buildGraph(); addPrior(scale, &graph); - const Values initial = initalizeRandomly(); + const Values initial = initializeRandomly(); LevenbergMarquardtOptimizer lm(graph, initial, params_); Values result = lm.optimize(); diff --git a/gtsam/sfm/TranslationRecovery.h b/gtsam/sfm/TranslationRecovery.h index c99836853..30c9a14e3 100644 --- a/gtsam/sfm/TranslationRecovery.h +++ b/gtsam/sfm/TranslationRecovery.h @@ -16,16 +16,16 @@ * @brief Recovering translations in an epipolar graph when rotations are given. */ -#include -#include -#include -#include - #include #include #include #include +#include +#include +#include +#include + namespace gtsam { // Set up an optimization problem for the unknown translations Ti in the world @@ -100,9 +100,17 @@ class TranslationRecovery { /** * @brief Create random initial translations. * + * @param rng random number generator * @return Values */ - Values initalizeRandomly() const; + Values initializeRandomly(std::mt19937 *rng) const; + + /** + * @brief Version of initializeRandomly with a fixed seed. + * + * @return Values + */ + Values initializeRandomly() const; /** * @brief Build and optimize factor graph. diff --git a/gtsam/sfm/sfm.i b/gtsam/sfm/sfm.i index 705892e60..bf9a73ac5 100644 --- a/gtsam/sfm/sfm.i +++ b/gtsam/sfm/sfm.i @@ -4,7 +4,62 @@ namespace gtsam { -// ##### +#include +class SfmTrack { + SfmTrack(); + SfmTrack(const gtsam::Point3& pt); + const Point3& point3() const; + + double r; + double g; + double b; + + std::vector> measurements; + + size_t numberMeasurements() const; + pair measurement(size_t idx) const; + pair siftIndex(size_t idx) const; + void addMeasurement(size_t idx, const gtsam::Point2& m); + + // enabling serialization functionality + void serialize() const; + + // enabling function to compare objects + bool equals(const gtsam::SfmTrack& expected, double tol) const; +}; + +#include +class SfmData { + SfmData(); + static gtsam::SfmData FromBundlerFile(string filename); + static gtsam::SfmData FromBalFile(string filename); + + void addTrack(const gtsam::SfmTrack& t); + void addCamera(const gtsam::SfmCamera& cam); + size_t numberTracks() const; + size_t numberCameras() const; + gtsam::SfmTrack track(size_t idx) const; + gtsam::PinholeCamera camera(size_t idx) const; + + gtsam::NonlinearFactorGraph generalSfmFactors( + const gtsam::SharedNoiseModel& model = + gtsam::noiseModel::Isotropic::Sigma(2, 1.0)) const; + gtsam::NonlinearFactorGraph sfmFactorGraph( + const gtsam::SharedNoiseModel& model = + gtsam::noiseModel::Isotropic::Sigma(2, 1.0), + size_t fixedCamera = 0, size_t fixedPoint = 0) const; + + // enabling serialization functionality + void serialize() const; + + // enabling function to compare objects + bool equals(const gtsam::SfmData& expected, double tol) const; +}; + +gtsam::SfmData readBal(string filename); +bool writeBAL(string filename, gtsam::SfmData& data); +gtsam::Values initialCamerasEstimate(const gtsam::SfmData& db); +gtsam::Values initialCamerasAndPointsEstimate(const gtsam::SfmData& db); #include @@ -92,7 +147,7 @@ class ShonanAveraging2 { // Query properties size_t nrUnknowns() const; - size_t nrMeasurements() const; + size_t numberMeasurements() const; gtsam::Rot2 measured(size_t i); gtsam::KeyVector keys(size_t i); @@ -140,7 +195,7 @@ class ShonanAveraging3 { // Query properties size_t nrUnknowns() const; - size_t nrMeasurements() const; + size_t numberMeasurements() const; gtsam::Rot3 measured(size_t i); gtsam::KeyVector keys(size_t i); diff --git a/gtsam/sfm/tests/testSfmData.cpp b/gtsam/sfm/tests/testSfmData.cpp new file mode 100644 index 000000000..7bd5d27e7 --- /dev/null +++ b/gtsam/sfm/tests/testSfmData.cpp @@ -0,0 +1,214 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file TestSfmData.cpp + * @date January 2022 + * @author Frank dellaert + * @brief tests for SfmData class and associated utilites + */ + +#include +#include +#include +#include + +using namespace std; +using namespace gtsam; + +using gtsam::symbol_shorthand::P; + +namespace gtsam { +GTSAM_EXPORT std::string createRewrittenFileName(const std::string& name); +GTSAM_EXPORT std::string findExampleDataFile(const std::string& name); +} // namespace gtsam + +/* ************************************************************************* */ +TEST(dataSet, Balbianello) { + // The structure where we will save the SfM data + const string filename = findExampleDataFile("Balbianello"); + SfmData sfmData = SfmData::FromBundlerFile(filename); + + // Check number of things + EXPECT_LONGS_EQUAL(5, sfmData.numberCameras()); + EXPECT_LONGS_EQUAL(544, sfmData.numberTracks()); + const SfmTrack& track0 = sfmData.tracks[0]; + EXPECT_LONGS_EQUAL(3, track0.numberMeasurements()); + + // Check projection of a given point + EXPECT_LONGS_EQUAL(0, track0.measurements[0].first); + const SfmCamera& camera0 = sfmData.cameras[0]; + Point2 expected = camera0.project(track0.p), + actual = track0.measurements[0].second; + EXPECT(assert_equal(expected, actual, 1)); + + // We share *one* noiseModel between all projection factors + auto model = noiseModel::Isotropic::Sigma(2, 1.0); // one pixel in u and v + + // Convert to NonlinearFactorGraph + NonlinearFactorGraph graph = sfmData.sfmFactorGraph(model); + EXPECT_LONGS_EQUAL(1419, graph.size()); // regression + + // Get initial estimate + Values values = initialCamerasAndPointsEstimate(sfmData); + EXPECT_LONGS_EQUAL(549, values.size()); // regression +} + +/* ************************************************************************* */ +TEST(dataSet, readBAL_Dubrovnik) { + // The structure where we will save the SfM data + const string filename = findExampleDataFile("dubrovnik-3-7-pre"); + SfmData sfmData = SfmData::FromBalFile(filename); + + // Check number of things + EXPECT_LONGS_EQUAL(3, sfmData.numberCameras()); + EXPECT_LONGS_EQUAL(7, sfmData.numberTracks()); + const SfmTrack& track0 = sfmData.tracks[0]; + EXPECT_LONGS_EQUAL(3, track0.numberMeasurements()); + + // Check projection of a given point + EXPECT_LONGS_EQUAL(0, track0.measurements[0].first); + const SfmCamera& camera0 = sfmData.cameras[0]; + Point2 expected = camera0.project(track0.p), + actual = track0.measurements[0].second; + EXPECT(assert_equal(expected, actual, 12)); +} + +/* ************************************************************************* */ +TEST(dataSet, openGL2gtsam) { + Vector3 rotVec(0.2, 0.7, 1.1); + Rot3 R = Rot3::Expmap(rotVec); + Point3 t = Point3(0.0, 0.0, 0.0); + Pose3 poseGTSAM = Pose3(R, t); + + Pose3 expected = openGL2gtsam(R, t.x(), t.y(), t.z()); + + Point3 r1 = R.r1(), r2 = R.r2(), r3 = R.r3(); // columns! + Rot3 cRw(r1.x(), r2.x(), r3.x(), -r1.y(), -r2.y(), -r3.y(), -r1.z(), -r2.z(), + -r3.z()); + Rot3 wRc = cRw.inverse(); + Pose3 actual = Pose3(wRc, t); + + EXPECT(assert_equal(expected, actual)); +} + +/* ************************************************************************* */ +TEST(dataSet, gtsam2openGL) { + Vector3 rotVec(0.2, 0.7, 1.1); + Rot3 R = Rot3::Expmap(rotVec); + Point3 t = Point3(1.0, 20.0, 10.0); + Pose3 actual = Pose3(R, t); + Pose3 poseGTSAM = openGL2gtsam(R, t.x(), t.y(), t.z()); + + Pose3 expected = gtsam2openGL(poseGTSAM); + EXPECT(assert_equal(expected, actual)); +} + +/* ************************************************************************* */ +TEST(dataSet, writeBAL_Dubrovnik) { + const string filenameToRead = findExampleDataFile("dubrovnik-3-7-pre"); + SfmData readData = SfmData::FromBalFile(filenameToRead); + + // Write readData to file filenameToWrite + const string filenameToWrite = createRewrittenFileName(filenameToRead); + CHECK(writeBAL(filenameToWrite, readData)); + + // Read what we wrote + SfmData writtenData = SfmData::FromBalFile(filenameToWrite); + + // Check that what we read is the same as what we wrote + EXPECT_LONGS_EQUAL(readData.numberCameras(), writtenData.numberCameras()); + EXPECT_LONGS_EQUAL(readData.numberTracks(), writtenData.numberTracks()); + + for (size_t i = 0; i < readData.numberCameras(); i++) { + PinholeCamera expectedCamera = writtenData.cameras[i]; + PinholeCamera actualCamera = readData.cameras[i]; + EXPECT(assert_equal(expectedCamera, actualCamera)); + } + + for (size_t j = 0; j < readData.numberTracks(); j++) { + // check point + SfmTrack expectedTrack = writtenData.tracks[j]; + SfmTrack actualTrack = readData.tracks[j]; + Point3 expectedPoint = expectedTrack.p; + Point3 actualPoint = actualTrack.p; + EXPECT(assert_equal(expectedPoint, actualPoint)); + + // check rgb + Point3 expectedRGB = + Point3(expectedTrack.r, expectedTrack.g, expectedTrack.b); + Point3 actualRGB = Point3(actualTrack.r, actualTrack.g, actualTrack.b); + EXPECT(assert_equal(expectedRGB, actualRGB)); + + // check measurements + for (size_t k = 0; k < actualTrack.numberMeasurements(); k++) { + EXPECT_LONGS_EQUAL(expectedTrack.measurements[k].first, + actualTrack.measurements[k].first); + EXPECT(assert_equal(expectedTrack.measurements[k].second, + actualTrack.measurements[k].second)); + } + } +} + +/* ************************************************************************* */ +TEST(dataSet, writeBALfromValues_Dubrovnik) { + const string filenameToRead = findExampleDataFile("dubrovnik-3-7-pre"); + SfmData readData = SfmData::FromBalFile(filenameToRead); + + Pose3 poseChange = + Pose3(Rot3::Ypr(-M_PI / 10, 0., -M_PI / 10), Point3(0.3, 0.1, 0.3)); + + Values values; + for (size_t i = 0; i < readData.numberCameras(); i++) { // for each camera + Pose3 pose = poseChange.compose(readData.cameras[i].pose()); + values.insert(i, pose); + } + for (size_t j = 0; j < readData.numberTracks(); j++) { // for each point + Point3 point = poseChange.transformFrom(readData.tracks[j].p); + values.insert(P(j), point); + } + + // Write values and readData to a file + const string filenameToWrite = createRewrittenFileName(filenameToRead); + writeBALfromValues(filenameToWrite, readData, values); + + // Read the file we wrote + SfmData writtenData = SfmData::FromBalFile(filenameToWrite); + + // Check that the reprojection errors are the same and the poses are correct + // Check number of things + EXPECT_LONGS_EQUAL(3, writtenData.numberCameras()); + EXPECT_LONGS_EQUAL(7, writtenData.numberTracks()); + const SfmTrack& track0 = writtenData.tracks[0]; + EXPECT_LONGS_EQUAL(3, track0.numberMeasurements()); + + // Check projection of a given point + EXPECT_LONGS_EQUAL(0, track0.measurements[0].first); + const SfmCamera& camera0 = writtenData.cameras[0]; + Point2 expected = camera0.project(track0.p), + actual = track0.measurements[0].second; + EXPECT(assert_equal(expected, actual, 12)); + + Pose3 expectedPose = camera0.pose(); + Pose3 actualPose = values.at(0); + EXPECT(assert_equal(expectedPose, actualPose, 1e-7)); + + Point3 expectedPoint = track0.p; + Point3 actualPoint = values.at(P(0)); + EXPECT(assert_equal(expectedPoint, actualPoint, 1e-6)); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/gtsam/slam/BetweenFactor.h b/gtsam/slam/BetweenFactor.h index aef41d5fd..f80462847 100644 --- a/gtsam/slam/BetweenFactor.h +++ b/gtsam/slam/BetweenFactor.h @@ -80,7 +80,9 @@ namespace gtsam { /// @{ /// print with optional string - void print(const std::string& s, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override { + void print( + const std::string& s = "", + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override { std::cout << s << "BetweenFactor(" << keyFormatter(this->key1()) << "," << keyFormatter(this->key2()) << ")\n"; @@ -103,7 +105,7 @@ namespace gtsam { boost::none, boost::optional H2 = boost::none) const override { T hx = traits::Between(p1, p2, H1, H2); // h(x) // manifold equivalent of h(x)-z -> log(z,h(x)) -#ifdef SLOW_BUT_CORRECT_BETWEENFACTOR +#ifdef GTSAM_SLOW_BUT_CORRECT_BETWEENFACTOR typename traits::ChartJacobian::Jacobian Hlocal; Vector rval = traits::Local(measured_, hx, boost::none, (H1 || H2) ? &Hlocal : 0); if (H1) *H1 = Hlocal * (*H1); diff --git a/gtsam/slam/EssentialMatrixFactor.h b/gtsam/slam/EssentialMatrixFactor.h index 787efac51..5997ad224 100644 --- a/gtsam/slam/EssentialMatrixFactor.h +++ b/gtsam/slam/EssentialMatrixFactor.h @@ -1,7 +1,20 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010-2014, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + /* - * @file EssentialMatrixFactor.cpp + * @file EssentialMatrixFactor.h * @brief EssentialMatrixFactor class * @author Frank Dellaert + * @author Ayush Baid + * @author Akshay Krishnan * @date December 17, 2013 */ diff --git a/gtsam/slam/FrobeniusFactor.h b/gtsam/slam/FrobeniusFactor.h index f17a9e421..05e23ce6d 100644 --- a/gtsam/slam/FrobeniusFactor.h +++ b/gtsam/slam/FrobeniusFactor.h @@ -48,12 +48,14 @@ ConvertNoiseModel(const SharedNoiseModel &model, size_t n, * element of SO(3) or SO(4). */ template -class GTSAM_EXPORT FrobeniusPrior : public NoiseModelFactor1 { +class FrobeniusPrior : public NoiseModelFactor1 { enum { Dim = Rot::VectorN2::RowsAtCompileTime }; using MatrixNN = typename Rot::MatrixNN; Eigen::Matrix vecM_; ///< vectorized matrix to approximate public: + EIGEN_MAKE_ALIGNED_OPERATOR_NEW + /// Constructor FrobeniusPrior(Key j, const MatrixNN& M, const SharedNoiseModel& model = nullptr) @@ -73,7 +75,7 @@ class GTSAM_EXPORT FrobeniusPrior : public NoiseModelFactor1 { * The template argument can be any fixed-size SO. */ template -class GTSAM_EXPORT FrobeniusFactor : public NoiseModelFactor2 { +class FrobeniusFactor : public NoiseModelFactor2 { enum { Dim = Rot::VectorN2::RowsAtCompileTime }; public: @@ -99,13 +101,15 @@ class GTSAM_EXPORT FrobeniusFactor : public NoiseModelFactor2 { * and in fact only SO3 and SO4 really work, as we need SO::AdjointMap. */ template -class GTSAM_EXPORT FrobeniusBetweenFactor : public NoiseModelFactor2 { +class FrobeniusBetweenFactor : public NoiseModelFactor2 { Rot R12_; ///< measured rotation between R1 and R2 Eigen::Matrix R2hat_H_R1_; ///< fixed derivative of R2hat wrpt R1 enum { Dim = Rot::VectorN2::RowsAtCompileTime }; public: + EIGEN_MAKE_ALIGNED_OPERATOR_NEW + /// @name Constructor /// @{ diff --git a/gtsam/slam/GeneralSFMFactor.h b/gtsam/slam/GeneralSFMFactor.h index 2e4543177..bfc3a0f78 100644 --- a/gtsam/slam/GeneralSFMFactor.h +++ b/gtsam/slam/GeneralSFMFactor.h @@ -59,8 +59,8 @@ namespace gtsam { template class GeneralSFMFactor: public NoiseModelFactor2 { - GTSAM_CONCEPT_MANIFOLD_TYPE(CAMERA); - GTSAM_CONCEPT_MANIFOLD_TYPE(LANDMARK); + GTSAM_CONCEPT_MANIFOLD_TYPE(CAMERA) + GTSAM_CONCEPT_MANIFOLD_TYPE(LANDMARK) static const int DimC = FixedDimension::value; static const int DimL = FixedDimension::value; @@ -202,7 +202,7 @@ struct traits > : Testable< template class GeneralSFMFactor2: public NoiseModelFactor3 { - GTSAM_CONCEPT_MANIFOLD_TYPE(CALIBRATION); + GTSAM_CONCEPT_MANIFOLD_TYPE(CALIBRATION) static const int DimK = FixedDimension::value; protected: diff --git a/gtsam/slam/JacobianFactorSVD.h b/gtsam/slam/JacobianFactorSVD.h index bc906d24e..f6bc1dd8c 100644 --- a/gtsam/slam/JacobianFactorSVD.h +++ b/gtsam/slam/JacobianFactorSVD.h @@ -9,20 +9,21 @@ namespace gtsam { /** - * JacobianFactor for Schur complement that uses the "Nullspace Trick" by Mourikis + * JacobianFactor for Schur complement that uses the "Nullspace Trick" by + * Mourikis et al. * * This trick is equivalent to the Schur complement, but can be faster. - * In essence, the linear factor |E*dp + F*dX - b|, where p is point and X are poses, - * is multiplied by Enull, a matrix that spans the left nullspace of E, i.e., - * The mx3 matrix is analyzed with SVD as E = [Erange Enull]*S*V (mxm * mx3 * 3x3) - * where Enull is an m x (m-3) matrix - * Then Enull'*E*dp = 0, and + * In essence, the linear factor |E*dp + F*dX - b|, where p is point and X are + * poses, is multiplied by Enull, a matrix that spans the left nullspace of E, + * i.e., The mx3 matrix is analyzed with SVD as E = [Erange Enull]*S*V (mxm * + * mx3 * 3x3) where Enull is an m x (m-3) matrix Then Enull'*E*dp = 0, and * |Enull'*E*dp + Enull'*F*dX - Enull'*b| == |Enull'*F*dX - Enull'*b| * Normally F is m x 6*numKeys, and Enull'*F yields an (m-3) x 6*numKeys matrix. * - * The code below assumes that F is block diagonal and is given as a vector of ZDim*D blocks. - * Example: m = 4 (2 measurements), Enull = 4*1, F = 4*12 (for D=6) - * Then Enull'*F = 1*4 * 4*12 = 1*12, but each 1*6 piece can be computed as a 1x2 * 2x6 mult + * The code below assumes that F is block diagonal and is given as a vector of + * ZDim*D blocks. Example: m = 4 (2 measurements), Enull = 4*1, F = 4*12 (for + * D=6) Then Enull'*F = 1*4 * 4*12 = 1*12, but each 1*6 piece can be computed as + * a 1x2 * 2x6 multiplication. */ template class JacobianFactorSVD: public RegularJacobianFactor { @@ -37,10 +38,10 @@ public: JacobianFactorSVD() { } - /// Empty constructor with keys - JacobianFactorSVD(const KeyVector& keys, // - const SharedDiagonal& model = SharedDiagonal()) : - Base() { + /// Empty constructor with keys. + JacobianFactorSVD(const KeyVector& keys, + const SharedDiagonal& model = SharedDiagonal()) + : Base() { Matrix zeroMatrix = Matrix::Zero(0, D); Vector zeroVector = Vector::Zero(0); std::vector QF; @@ -51,18 +52,21 @@ public: } /** - * @brief Constructor - * Takes the CameraSet derivatives (as ZDim*D blocks of block-diagonal F) - * and a reduced point derivative, Enull - * and creates a reduced-rank Jacobian factor on the CameraSet + * @brief Construct a new JacobianFactorSVD object, createing a reduced-rank + * Jacobian factor on the CameraSet. * - * @Fblocks: + * @param keys keys associated with F blocks. + * @param Fblocks CameraSet derivatives, ZDim*D blocks of block-diagonal F + * @param Enull a reduced point derivative + * @param b right-hand side + * @param model noise model */ - JacobianFactorSVD(const KeyVector& keys, - const std::vector >& Fblocks, const Matrix& Enull, - const Vector& b, // - const SharedDiagonal& model = SharedDiagonal()) : - Base() { + JacobianFactorSVD( + const KeyVector& keys, + const std::vector >& Fblocks, + const Matrix& Enull, const Vector& b, + const SharedDiagonal& model = SharedDiagonal()) + : Base() { size_t numKeys = Enull.rows() / ZDim; size_t m2 = ZDim * numKeys - 3; // TODO: is this not just Enull.rows()? // PLAIN nullptr SPACE TRICK @@ -74,9 +78,8 @@ public: QF.reserve(numKeys); for (size_t k = 0; k < Fblocks.size(); ++k) { Key key = keys[k]; - QF.push_back( - KeyMatrix(key, - (Enull.transpose()).block(0, ZDim * k, m2, ZDim) * Fblocks[k])); + QF.emplace_back( + key, (Enull.transpose()).block(0, ZDim * k, m2, ZDim) * Fblocks[k]); } JacobianFactor::fillTerms(QF, Enull.transpose() * b, model); } diff --git a/gtsam/slam/KarcherMeanFactor-inl.h b/gtsam/slam/KarcherMeanFactor-inl.h index c81a9adc5..00f741705 100644 --- a/gtsam/slam/KarcherMeanFactor-inl.h +++ b/gtsam/slam/KarcherMeanFactor-inl.h @@ -40,8 +40,7 @@ T FindKarcherMeanImpl(const vector& rotations) { return result.at(kKey); } -template ::value >::type > +template T FindKarcherMean(const std::vector& rotations) { return FindKarcherMeanImpl(rotations); } diff --git a/gtsam/slam/OrientedPlane3Factor.h b/gtsam/slam/OrientedPlane3Factor.h index d7b836dec..81bb790de 100644 --- a/gtsam/slam/OrientedPlane3Factor.h +++ b/gtsam/slam/OrientedPlane3Factor.h @@ -15,7 +15,7 @@ namespace gtsam { /** * Factor to measure a planar landmark from a given pose */ -class OrientedPlane3Factor: public NoiseModelFactor2 { +class GTSAM_EXPORT OrientedPlane3Factor: public NoiseModelFactor2 { protected: OrientedPlane3 measured_p_; typedef NoiseModelFactor2 Base; @@ -49,7 +49,7 @@ class OrientedPlane3Factor: public NoiseModelFactor2 { }; // TODO: Convert this factor to dimension two, three dimensions is redundant for direction prior -class OrientedPlane3DirectionPrior : public NoiseModelFactor1 { +class GTSAM_EXPORT OrientedPlane3DirectionPrior : public NoiseModelFactor1 { protected: OrientedPlane3 measured_p_; /// measured plane parameters typedef NoiseModelFactor1 Base; diff --git a/gtsam/slam/PoseRotationPrior.h b/gtsam/slam/PoseRotationPrior.h index ba4d12a25..20f12dbce 100644 --- a/gtsam/slam/PoseRotationPrior.h +++ b/gtsam/slam/PoseRotationPrior.h @@ -39,6 +39,9 @@ protected: public: + /** default constructor - only use for serialization */ + PoseRotationPrior() {} + /** standard constructor */ PoseRotationPrior(Key key, const Rotation& rot_z, const SharedNoiseModel& model) : Base(model, key), measured_(rot_z) {} diff --git a/gtsam/slam/ProjectionFactor.h b/gtsam/slam/ProjectionFactor.h index ada304f27..42dba8bd0 100644 --- a/gtsam/slam/ProjectionFactor.h +++ b/gtsam/slam/ProjectionFactor.h @@ -11,7 +11,7 @@ /** * @file ProjectionFactor.h - * @brief Basic bearing factor from 2D measurement + * @brief Reprojection of a LANDMARK to a 2D point. * @author Chris Beall * @author Richard Roberts * @author Frank Dellaert @@ -22,17 +22,21 @@ #include #include +#include +#include #include #include namespace gtsam { /** - * Non-linear factor for a constraint derived from a 2D measurement. The calibration is known here. - * i.e. the main building block for visual SLAM. + * Non-linear factor for a constraint derived from a 2D measurement. + * The calibration is known here. + * The main building block for visual SLAM. * @addtogroup SLAM */ - template + template class GenericProjectionFactor: public NoiseModelFactor2 { protected: @@ -57,9 +61,9 @@ namespace gtsam { typedef boost::shared_ptr shared_ptr; /// Default constructor - GenericProjectionFactor() : - measured_(0, 0), throwCheirality_(false), verboseCheirality_(false) { - } + GenericProjectionFactor() : + measured_(0, 0), throwCheirality_(false), verboseCheirality_(false) { + } /** * Constructor diff --git a/gtsam/slam/README.md b/gtsam/slam/README.md new file mode 100644 index 000000000..ae5edfdac --- /dev/null +++ b/gtsam/slam/README.md @@ -0,0 +1,68 @@ +# SLAM Factors + +## GenericProjectionFactor (defined in ProjectionFactor.h) + +Non-linear factor that minimizes the re-projection error with respect to a 2D measurement. +The calibration is assumed known and passed in the constructor. +The main building block for visual SLAM. + +Templated on +- `POSE`, default `Pose3` +- `LANDMARK`, default `Point3` +- `CALIBRATION`, default `Cal3_S2` + +## SmartFactors + +These are "structure-less" factors, i.e., rather than introducing a new variable for an observed 3D point or landmark, a single factor is created that provides a multi-view constraint on several poses and/or cameras. +While one typically adds multiple GenericProjectionFactors (one for each observation of a landmark), a SmartFactor collects all measurements for a landmark, i.e., the factor graph contains 1 smart factor per landmark. + +### SmartFactorBase + +This is the base class for smart factors, templated on a `CAMERA` type. +It has no internal point, but it saves the measurements, keeps a noise model, and an optional sensor pose. + +### SmartProjectionFactor + +Also templated on `CAMERA`. Triangulates a 3D point and keeps an estimate of it around. +This factor operates with monocular cameras, and is used to optimize the camera pose +*and* calibration for each camera, and requires variables of type `CAMERA` in values. + +If the calibration is fixed use `SmartProjectionPoseFactor` instead! + + +### SmartProjectionPoseFactor + +Derives from `SmartProjectionFactor` but is templated on a `CALIBRATION` type, setting `CAMERA = PinholePose`. +This factor assumes that the camera calibration is fixed and the same for all cameras involved in this factor. +The factor only constrains poses. + +If the calibration should be optimized, as well, use `SmartProjectionFactor` instead! + +### SmartProjectionRigFactor + +Same as `SmartProjectionPoseFactor`, except: +- it is templated on `CAMERA`, i.e., it allows cameras beyond pinhole; +- it allows measurements from multiple cameras, each camera with fixed but potentially different intrinsics and extrinsics; +- it allows multiple observations from the same pose/key, again, to model a multi-camera system. + +## Linearized Smart Factors + +The factors below are less likely to be relevant to the user, but result from using the non-linear smart factors above. + + +### RegularImplicitSchurFactor + +A specialization of a GaussianFactor to structure-less SFM, which is very fast in a conjugate gradient (CG) solver. +It is produced by calling `createRegularImplicitSchurFactor` in `SmartFactorBase` or `SmartProjectionFactor`. + +### JacobianFactorQ + +A RegularJacobianFactor that uses some badly documented reduction on the Jacobians. + +### JacobianFactorQR + +A RegularJacobianFactor that eliminates a point using sequential elimination. + +### JacobianFactorQR + +A RegularJacobianFactor that uses the "Nullspace Trick" by Mourikis et al. See the documentation in the file, which *is* well documented. \ No newline at end of file diff --git a/gtsam/slam/RegularImplicitSchurFactor.h b/gtsam/slam/RegularImplicitSchurFactor.h index 2ed6aa491..340f84018 100644 --- a/gtsam/slam/RegularImplicitSchurFactor.h +++ b/gtsam/slam/RegularImplicitSchurFactor.h @@ -1,6 +1,6 @@ /** * @file RegularImplicitSchurFactor.h - * @brief A new type of linear factor (GaussianFactor), which is subclass of GaussianFactor + * @brief A subclass of GaussianFactor specialized to structureless SFM. * @author Frank Dellaert * @author Luca Carlone */ @@ -20,6 +20,20 @@ namespace gtsam { /** * RegularImplicitSchurFactor + * + * A specialization of a GaussianFactor to structure-less SFM, which is very + * fast in a conjugate gradient (CG) solver. Specifically, as measured in + * timeSchurFactors.cpp, it stays very fast for an increasing number of cameras. + * The magic is in multiplyHessianAdd, which does the Hessian-vector multiply at + * the core of CG, and implements + * y += F'*alpha*(I - E*P*E')*F*x + * where + * - F is the 2mx6m Jacobian of the m 2D measurements wrpt m 6DOF poses + * - E is the 2mx3 Jacabian of the m 2D measurements wrpt a 3D point + * - P is the covariance on the point + * The equation above implicitly executes the Schur complement by removing the + * information E*P*E' from the Hessian. It is also very fast as we do not use + * the full 6m*6m F matrix, but rather only it's m 6x6 diagonal blocks. */ template class RegularImplicitSchurFactor: public GaussianFactor { @@ -38,9 +52,10 @@ protected: static const int ZDim = traits::dimension; ///< Measurement dimension typedef Eigen::Matrix MatrixZD; ///< type of an F block - typedef Eigen::Matrix MatrixDD; ///< camera hessian + typedef Eigen::Matrix MatrixDD; ///< camera Hessian + typedef std::vector > FBlocks; - const std::vector > FBlocks_; ///< All ZDim*D F blocks (one for each camera) + FBlocks FBlocks_; ///< All ZDim*D F blocks (one for each camera) const Matrix PointCovariance_; ///< the 3*3 matrix P = inv(E'E) (2*2 if degenerate) const Matrix E_; ///< The 2m*3 E Jacobian with respect to the point const Vector b_; ///< 2m-dimensional RHS vector @@ -52,17 +67,25 @@ public: } /// Construct from blocks of F, E, inv(E'*E), and RHS vector b - RegularImplicitSchurFactor(const KeyVector& keys, - const std::vector >& FBlocks, const Matrix& E, const Matrix& P, - const Vector& b) : - GaussianFactor(keys), FBlocks_(FBlocks), PointCovariance_(P), E_(E), b_(b) { - } + + /** + * @brief Construct a new RegularImplicitSchurFactor object. + * + * @param keys keys corresponding to cameras + * @param Fs All ZDim*D F blocks (one for each camera) + * @param E Jacobian of measurements wrpt point. + * @param P point covariance matrix + * @param b RHS vector + */ + RegularImplicitSchurFactor(const KeyVector& keys, const FBlocks& Fs, + const Matrix& E, const Matrix& P, const Vector& b) + : GaussianFactor(keys), FBlocks_(Fs), PointCovariance_(P), E_(E), b_(b) {} /// Destructor ~RegularImplicitSchurFactor() override { } - std::vector >& FBlocks() const { + const FBlocks& Fs() const { return FBlocks_; } @@ -237,10 +260,6 @@ public: "RegularImplicitSchurFactor::clone non implemented"); } - bool empty() const override { - return false; - } - GaussianFactor::shared_ptr negate() const override { return boost::make_shared >(keys_, FBlocks_, PointCovariance_, E_, b_); diff --git a/gtsam/slam/SmartFactorBase.h b/gtsam/slam/SmartFactorBase.h index 380283141..ca158cc1d 100644 --- a/gtsam/slam/SmartFactorBase.h +++ b/gtsam/slam/SmartFactorBase.h @@ -37,12 +37,14 @@ namespace gtsam { /** - * @brief Base class for smart factors + * @brief Base class for smart factors. * This base class has no internal point, but it has a measurement, noise model * and an optional sensor pose. - * This class mainly computes the derivatives and returns them as a variety of factors. - * The methods take a Cameras argument, which should behave like PinholeCamera, and - * the value of a point, which is kept in the base class. + * This class mainly computes the derivatives and returns them as a variety of + * factors. The methods take a CameraSet argument and the value of a + * point, which is kept in the derived class. + * + * @tparam CAMERA should behave like a PinholeCamera. */ template class SmartFactorBase: public NonlinearFactor { @@ -64,19 +66,20 @@ protected: /** * As of Feb 22, 2015, the noise model is the same for all measurements and * is isotropic. This allows for moving most calculations of Schur complement - * etc to be moved to CameraSet very easily, and also agrees pragmatically + * etc. to be easily moved to CameraSet, and also agrees pragmatically * with what is normally done. */ SharedIsotropic noiseModel_; /** - * 2D measurement and noise model for each of the m views - * We keep a copy of measurements for I/O and computing the error. + * Measurements for each of the m views. + * We keep a copy of the measurements for I/O and computing the error. * The order is kept the same as the keys that we use to create the factor. */ ZVector measured_; - boost::optional body_P_sensor_; ///< Pose of the camera in the body frame + boost::optional + body_P_sensor_; ///< Pose of the camera in the body frame // Cache for Fblocks, to avoid a malloc ever time we re-linearize mutable FBlocks Fs; @@ -84,16 +87,16 @@ protected: public: GTSAM_MAKE_ALIGNED_OPERATOR_NEW - /// shorthand for a smart pointer to a factor + /// shorthand for a smart pointer to a factor. typedef boost::shared_ptr shared_ptr; - /// We use the new CameraSte data structure to refer to a set of cameras + /// The CameraSet data structure is used to refer to a set of cameras. typedef CameraSet Cameras; - /// Default Constructor, for serialization + /// Default Constructor, for serialization. SmartFactorBase() {} - /// Constructor + /// Construct with given noise model and optional arguments. SmartFactorBase(const SharedNoiseModel& sharedNoiseModel, boost::optional body_P_sensor = boost::none, size_t expectedNumberCameras = 10) @@ -111,12 +114,12 @@ protected: noiseModel_ = sharedIsotropic; } - /// Virtual destructor, subclasses from NonlinearFactor + /// Virtual destructor, subclasses from NonlinearFactor. ~SmartFactorBase() override { } /** - * Add a new measurement and pose/camera key + * Add a new measurement and pose/camera key. * @param measured is the 2m dimensional projection of a single landmark * @param key is the index corresponding to the camera observing the landmark */ @@ -129,9 +132,7 @@ protected: this->keys_.push_back(key); } - /** - * Add a bunch of measurements, together with the camera keys - */ + /// Add a bunch of measurements, together with the camera keys. void add(const ZVector& measurements, const KeyVector& cameraKeys) { assert(measurements.size() == cameraKeys.size()); for (size_t i = 0; i < measurements.size(); i++) { @@ -140,28 +141,24 @@ protected: } /** - * Adds an entire SfM_track (collection of cameras observing a single point). - * The noise is assumed to be the same for all measurements + * Add an entire SfM_track (collection of cameras observing a single point). + * The noise is assumed to be the same for all measurements. */ template void add(const SFM_TRACK& trackToAdd) { - for (size_t k = 0; k < trackToAdd.number_measurements(); k++) { + for (size_t k = 0; k < trackToAdd.numberMeasurements(); k++) { this->measured_.push_back(trackToAdd.measurements[k].second); this->keys_.push_back(trackToAdd.measurements[k].first); } } - /// get the dimension (number of rows!) of the factor - size_t dim() const override { - return ZDim * this->measured_.size(); - } + /// Return the dimension (number of rows!) of the factor. + size_t dim() const override { return ZDim * this->measured_.size(); } - /** return the measurements */ - const ZVector& measured() const { - return measured_; - } + /// Return the 2D measurements (ZDim, in general). + const ZVector& measured() const { return measured_; } - /// Collect all cameras: important that in key order + /// Collect all cameras: important that in key order. virtual Cameras cameras(const Values& values) const { Cameras cameras; for(const Key& k: this->keys_) @@ -188,25 +185,30 @@ protected: /// equals bool equals(const NonlinearFactor& p, double tol = 1e-9) const override { - const This *e = dynamic_cast(&p); - - bool areMeasurementsEqual = true; - for (size_t i = 0; i < measured_.size(); i++) { - if (traits::Equals(this->measured_.at(i), e->measured_.at(i), tol) == false) - areMeasurementsEqual = false; - break; + if (const This* e = dynamic_cast(&p)) { + // Check that all measurements are the same. + for (size_t i = 0; i < measured_.size(); i++) { + if (!traits::Equals(this->measured_.at(i), e->measured_.at(i), tol)) + return false; + } + // If so, check base class. + return Base::equals(p, tol); + } else { + return false; } - return e && Base::equals(p, tol) && areMeasurementsEqual; } /// Compute reprojection errors [h(x)-z] = [cameras.project(p)-z] and - /// derivatives + /// derivatives. This is the error before the noise model is applied. template Vector unwhitenedError( const Cameras& cameras, const POINT& point, boost::optional Fs = boost::none, // boost::optional E = boost::none) const { - Vector ue = cameras.reprojectionError(point, measured_, Fs, E); + // Reproject, with optional derivatives. + Vector error = cameras.reprojectionError(point, measured_, Fs, E); + + // Apply chain rule if body_P_sensor_ is given. if (body_P_sensor_ && Fs) { const Pose3 sensor_P_body = body_P_sensor_->inverse(); constexpr int camera_dim = traits::dimension; @@ -224,52 +226,60 @@ protected: Fs->at(i) = Fs->at(i) * J; } } - correctForMissingMeasurements(cameras, ue, Fs, E); - return ue; + + // Correct the Jacobians in case some measurements are missing. + correctForMissingMeasurements(cameras, error, Fs, E); + + return error; } /** - * This corrects the Jacobians for the case in which some pixel measurement is missing (nan) - * In practice, this does not do anything in the monocular case, but it is implemented in the stereo version + * This corrects the Jacobians for the case in which some 2D measurement is + * missing (nan). In practice, this does not do anything in the monocular + * case, but it is implemented in the stereo version. */ - virtual void correctForMissingMeasurements(const Cameras& cameras, Vector& ue, boost::optional Fs = boost::none, - boost::optional E = boost::none) const {} + virtual void correctForMissingMeasurements( + const Cameras& cameras, Vector& ue, + boost::optional Fs = boost::none, + boost::optional E = boost::none) const {} /** - * Calculate vector of re-projection errors [h(x)-z] = [cameras.project(p) - z] - * Noise model applied + * Calculate vector of re-projection errors [h(x)-z] = [cameras.project(p) - + * z], with the noise model applied. */ template Vector whitenedError(const Cameras& cameras, const POINT& point) const { - Vector e = cameras.reprojectionError(point, measured_); + Vector error = cameras.reprojectionError(point, measured_); if (noiseModel_) - noiseModel_->whitenInPlace(e); - return e; + noiseModel_->whitenInPlace(error); + return error; } - /** Calculate the error of the factor. - * This is the log-likelihood, e.g. \f$ 0.5(h(x)-z)^2/\sigma^2 \f$ in case of Gaussian. - * In this class, we take the raw prediction error \f$ h(x)-z \f$, ask the noise model - * to transform it to \f$ (h(x)-z)^2/\sigma^2 \f$, and then multiply by 0.5. - * Will be used in "error(Values)" function required by NonlinearFactor base class + /** + * Calculate the error of the factor. + * This is the log-likelihood, e.g. \f$ 0.5(h(x)-z)^2/\sigma^2 \f$ in case of + * Gaussian. In this class, we take the raw prediction error \f$ h(x)-z \f$, + * ask the noise model to transform it to \f$ (h(x)-z)^2/\sigma^2 \f$, and + * then multiply by 0.5. Will be used in "error(Values)" function required by + * NonlinearFactor base class */ template double totalReprojectionError(const Cameras& cameras, const POINT& point) const { - Vector e = whitenedError(cameras, point); - return 0.5 * e.dot(e); + Vector error = whitenedError(cameras, point); + return 0.5 * error.dot(error); } - /// Computes Point Covariance P from E - static Matrix PointCov(Matrix& E) { + /// Computes Point Covariance P from the "point Jacobian" E. + static Matrix PointCov(const Matrix& E) { return (E.transpose() * E).inverse(); } /** - * Compute F, E, and b (called below in both vanilla and SVD versions), where - * F is a vector of derivatives wrpt the cameras, and E the stacked derivatives - * with respect to the point. The value of cameras/point are passed as parameters. - * TODO: Kill this obsolete method + * Compute F, E, and b (called below in both vanilla and SVD versions), where + * F is a vector of derivatives wrpt the cameras, and E the stacked + * derivatives with respect to the point. The value of cameras/point are + * passed as parameters. */ template void computeJacobians(FBlocks& Fs, Matrix& E, Vector& b, @@ -281,7 +291,11 @@ protected: b = -unwhitenedError(cameras, point, Fs, E); } - /// SVD version + /** + * SVD version that produces smaller Jacobian matrices by doing an SVD + * decomposition on E, and returning the left nulkl-space of E. + * See JacobianFactorSVD for more documentation. + * */ template void computeJacobiansSVD(FBlocks& Fs, Matrix& Enull, Vector& b, const Cameras& cameras, const POINT& point) const { @@ -291,14 +305,14 @@ protected: static const int N = FixedDimension::value; // 2 (Unit3) or 3 (Point3) - // Do SVD on A + // Do SVD on A. Eigen::JacobiSVD svd(E, Eigen::ComputeFullU); - Vector s = svd.singularValues(); size_t m = this->keys_.size(); Enull = svd.matrixU().block(0, N, ZDim * m, ZDim * m - N); // last ZDim*m-N columns } - /// Linearize to a Hessianfactor + /// Linearize to a Hessianfactor. + // TODO(dellaert): Not used/tested anywhere and not properly whitened. boost::shared_ptr > createHessianFactor( const Cameras& cameras, const Point3& point, const double lambda = 0.0, bool diagonalDamping = false) const { @@ -351,9 +365,7 @@ protected: P, b); } - /** - * Return Jacobians as JacobianFactorQ - */ + /// Return Jacobians as JacobianFactorQ. boost::shared_ptr > createJacobianQFactor( const Cameras& cameras, const Point3& point, double lambda = 0.0, bool diagonalDamping = false) const { @@ -368,8 +380,8 @@ protected: } /** - * Return Jacobians as JacobianFactorSVD - * TODO lambda is currently ignored + * Return Jacobians as JacobianFactorSVD. + * TODO(dellaert): lambda is currently ignored */ boost::shared_ptr createJacobianSVDFactor( const Cameras& cameras, const Point3& point, double lambda = 0.0) const { @@ -393,7 +405,7 @@ protected: F.block(ZDim * i, Dim * i) = Fs.at(i); } - + // Return sensor pose. Pose3 body_P_sensor() const{ if(body_P_sensor_) return *body_P_sensor_; diff --git a/gtsam/slam/SmartProjectionFactor.h b/gtsam/slam/SmartProjectionFactor.h index f67ca0740..f9c101cb8 100644 --- a/gtsam/slam/SmartProjectionFactor.h +++ b/gtsam/slam/SmartProjectionFactor.h @@ -61,15 +61,17 @@ protected: /// @name Caching triangulation /// @{ mutable TriangulationResult result_; ///< result from triangulateSafe - mutable std::vector > cameraPosesTriangulation_; ///< current triangulation poses + mutable std::vector > + cameraPosesTriangulation_; ///< current triangulation poses /// @} -public: + public: /// shorthand for a smart pointer to a factor typedef boost::shared_ptr shared_ptr; /// shorthand for a set of cameras + typedef CAMERA Camera; typedef CameraSet Cameras; /** @@ -116,21 +118,31 @@ public: && Base::equals(p, tol); } - /// Check if the new linearization point is the same as the one used for previous triangulation + /** + * @brief Check if the new linearization point is the same as the one used for + * previous triangulation. + * + * @param cameras + * @return true if we need to re-triangulate. + */ bool decideIfTriangulate(const Cameras& cameras) const { - // several calls to linearize will be done from the same linearization point, hence it is not needed to re-triangulate - // Note that this is not yet "selecting linearization", that will come later, and we only check if the - // current linearization is the "same" (up to tolerance) w.r.t. the last time we triangulated the point + // Several calls to linearize will be done from the same linearization + // point, hence it is not needed to re-triangulate. Note that this is not + // yet "selecting linearization", that will come later, and we only check if + // the current linearization is the "same" (up to tolerance) w.r.t. the last + // time we triangulated the point. size_t m = cameras.size(); bool retriangulate = false; - // if we do not have a previous linearization point or the new linearization point includes more poses + // Definitely true if we do not have a previous linearization point or the + // new linearization point includes more poses. if (cameraPosesTriangulation_.empty() || cameras.size() != cameraPosesTriangulation_.size()) retriangulate = true; + // Otherwise, check poses against cache. if (!retriangulate) { for (size_t i = 0; i < cameras.size(); i++) { if (!cameras[i].pose().equals(cameraPosesTriangulation_[i], @@ -141,7 +153,8 @@ public: } } - if (retriangulate) { // we store the current poses used for triangulation + // Store the current poses used for triangulation if we will re-triangulate. + if (retriangulate) { cameraPosesTriangulation_.clear(); cameraPosesTriangulation_.reserve(m); for (size_t i = 0; i < m; i++) @@ -149,10 +162,15 @@ public: cameraPosesTriangulation_.push_back(cameras[i].pose()); } - return retriangulate; // if we arrive to this point all poses are the same and we don't need re-triangulation + return retriangulate; } - /// triangulateSafe + /** + * @brief Call gtsam::triangulateSafe iff we need to re-triangulate. + * + * @param cameras + * @return TriangulationResult + */ TriangulationResult triangulateSafe(const Cameras& cameras) const { size_t m = cameras.size(); @@ -166,17 +184,21 @@ public: return result_; } - /// triangulate + /** + * @brief Possibly re-triangulate before calculating Jacobians. + * + * @param cameras + * @return true if we could safely triangulate + */ bool triangulateForLinearize(const Cameras& cameras) const { triangulateSafe(cameras); // imperative, might reset result_ return bool(result_); } - /// linearize returns a Hessianfactor that is an approximation of error(p) + /// Create a Hessianfactor that is an approximation of error(p). boost::shared_ptr > createHessianFactor( - const Cameras& cameras, const double lambda = 0.0, bool diagonalDamping = - false) const { - + const Cameras& cameras, const double lambda = 0.0, + bool diagonalDamping = false) const { size_t numKeys = this->keys_.size(); // Create structures for Hessian Factors KeyVector js; @@ -184,39 +206,38 @@ public: std::vector gs(numKeys); if (this->measured_.size() != cameras.size()) - throw std::runtime_error("SmartProjectionHessianFactor: this->measured_" - ".size() inconsistent with input"); + throw std::runtime_error( + "SmartProjectionHessianFactor: this->measured_" + ".size() inconsistent with input"); triangulateSafe(cameras); if (params_.degeneracyMode == ZERO_ON_DEGENERACY && !result_) { // failed: return"empty" Hessian - for(Matrix& m: Gs) - m = Matrix::Zero(Base::Dim, Base::Dim); - for(Vector& v: gs) - v = Vector::Zero(Base::Dim); + for (Matrix& m : Gs) m = Matrix::Zero(Base::Dim, Base::Dim); + for (Vector& v : gs) v = Vector::Zero(Base::Dim); return boost::make_shared >(this->keys_, - Gs, gs, 0.0); + Gs, gs, 0.0); } // Jacobian could be 3D Point3 OR 2D Unit3, difference is E.cols(). - std::vector > Fblocks; + typename Base::FBlocks Fs; Matrix E; Vector b; - computeJacobiansWithTriangulatedPoint(Fblocks, E, b, cameras); + computeJacobiansWithTriangulatedPoint(Fs, E, b, cameras); // Whiten using noise model - Base::whitenJacobians(Fblocks, E, b); + Base::whitenJacobians(Fs, E, b); // build augmented hessian - SymmetricBlockMatrix augmentedHessian = // - Cameras::SchurComplement(Fblocks, E, b, lambda, diagonalDamping); + SymmetricBlockMatrix augmentedHessian = // + Cameras::SchurComplement(Fs, E, b, lambda, diagonalDamping); - return boost::make_shared >(this->keys_, - augmentedHessian); + return boost::make_shared >( + this->keys_, augmentedHessian); } - // create factor + // Create RegularImplicitSchurFactor factor. boost::shared_ptr > createRegularImplicitSchurFactor( const Cameras& cameras, double lambda) const { if (triangulateForLinearize(cameras)) @@ -226,7 +247,7 @@ public: return boost::shared_ptr >(); } - /// create factor + /// Create JacobianFactorQ factor. boost::shared_ptr > createJacobianQFactor( const Cameras& cameras, double lambda) const { if (triangulateForLinearize(cameras)) @@ -236,13 +257,13 @@ public: return boost::make_shared >(this->keys_); } - /// Create a factor, takes values + /// Create JacobianFactorQ factor, takes values. boost::shared_ptr > createJacobianQFactor( const Values& values, double lambda) const { return createJacobianQFactor(this->cameras(values), lambda); } - /// different (faster) way to compute Jacobian factor + /// Different (faster) way to compute a JacobianFactorSVD factor. boost::shared_ptr createJacobianSVDFactor( const Cameras& cameras, double lambda) const { if (triangulateForLinearize(cameras)) @@ -252,19 +273,19 @@ public: return boost::make_shared >(this->keys_); } - /// linearize to a Hessianfactor + /// Linearize to a Hessianfactor. virtual boost::shared_ptr > linearizeToHessian( const Values& values, double lambda = 0.0) const { return createHessianFactor(this->cameras(values), lambda); } - /// linearize to an Implicit Schur factor + /// Linearize to an Implicit Schur factor. virtual boost::shared_ptr > linearizeToImplicit( const Values& values, double lambda = 0.0) const { return createRegularImplicitSchurFactor(this->cameras(values), lambda); } - /// linearize to a JacobianfactorQ + /// Linearize to a JacobianfactorQ. virtual boost::shared_ptr > linearizeToJacobian( const Values& values, double lambda = 0.0) const { return createJacobianQFactor(this->cameras(values), lambda); @@ -334,7 +355,7 @@ public: /// Assumes the point has been computed /// Note E can be 2m*3 or 2m*2, in case point is degenerate void computeJacobiansWithTriangulatedPoint( - std::vector >& Fblocks, Matrix& E, Vector& b, + typename Base::FBlocks& Fs, Matrix& E, Vector& b, const Cameras& cameras) const { if (!result_) { @@ -342,32 +363,32 @@ public: // TODO check flag whether we should do this Unit3 backProjected = cameras[0].backprojectPointAtInfinity( this->measured_.at(0)); - Base::computeJacobians(Fblocks, E, b, cameras, backProjected); + Base::computeJacobians(Fs, E, b, cameras, backProjected); } else { // valid result: just return Base version - Base::computeJacobians(Fblocks, E, b, cameras, *result_); + Base::computeJacobians(Fs, E, b, cameras, *result_); } } /// Version that takes values, and creates the point bool triangulateAndComputeJacobians( - std::vector >& Fblocks, Matrix& E, Vector& b, + typename Base::FBlocks& Fs, Matrix& E, Vector& b, const Values& values) const { Cameras cameras = this->cameras(values); bool nonDegenerate = triangulateForLinearize(cameras); if (nonDegenerate) - computeJacobiansWithTriangulatedPoint(Fblocks, E, b, cameras); + computeJacobiansWithTriangulatedPoint(Fs, E, b, cameras); return nonDegenerate; } /// takes values bool triangulateAndComputeJacobiansSVD( - std::vector >& Fblocks, Matrix& Enull, Vector& b, + typename Base::FBlocks& Fs, Matrix& Enull, Vector& b, const Values& values) const { Cameras cameras = this->cameras(values); bool nonDegenerate = triangulateForLinearize(cameras); if (nonDegenerate) - Base::computeJacobiansSVD(Fblocks, Enull, b, cameras, *result_); + Base::computeJacobiansSVD(Fs, Enull, b, cameras, *result_); return nonDegenerate; } diff --git a/gtsam/slam/SmartProjectionPoseFactor.h b/gtsam/slam/SmartProjectionPoseFactor.h index c7b1d5424..f4c0c73aa 100644 --- a/gtsam/slam/SmartProjectionPoseFactor.h +++ b/gtsam/slam/SmartProjectionPoseFactor.h @@ -41,11 +41,10 @@ namespace gtsam { * If the calibration should be optimized, as well, use SmartProjectionFactor instead! * @addtogroup SLAM */ -template -class SmartProjectionPoseFactor: public SmartProjectionFactor< - PinholePose > { - -private: +template +class SmartProjectionPoseFactor + : public SmartProjectionFactor > { + private: typedef PinholePose Camera; typedef SmartProjectionFactor Base; typedef SmartProjectionPoseFactor This; @@ -156,7 +155,6 @@ public: ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); ar & BOOST_SERIALIZATION_NVP(K_); } - }; // end of class declaration diff --git a/gtsam/slam/SmartProjectionRigFactor.h b/gtsam/slam/SmartProjectionRigFactor.h new file mode 100644 index 000000000..149c12928 --- /dev/null +++ b/gtsam/slam/SmartProjectionRigFactor.h @@ -0,0 +1,372 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file SmartProjectionRigFactor.h + * @brief Smart factor on poses, assuming camera calibration is fixed. + * Same as SmartProjectionPoseFactor, except: + * - it is templated on CAMERA (i.e., it allows cameras beyond pinhole) + * - it admits a different calibration for each measurement (i.e., it + * can model a multi-camera rig system) + * - it allows multiple observations from the same pose/key (again, to + * model a multi-camera system) + * @author Luca Carlone + * @author Frank Dellaert + */ + +#pragma once + +#include + +namespace gtsam { +/** + * + * @addtogroup SLAM + * + * If you are using the factor, please cite: + * L. Carlone, Z. Kira, C. Beall, V. Indelman, F. Dellaert, Eliminating + * conditionally independent sets in factor graphs: a unifying perspective based + * on smart factors, Int. Conf. on Robotics and Automation (ICRA), 2014. + */ + +/** + * This factor assumes that camera calibration is fixed (but each measurement + * can be taken by a different camera in the rig, hence can have a different + * extrinsic and intrinsic calibration). The factor only constrains poses + * (variable dimension is 6 for each pose). This factor requires that values + * contains the involved poses (Pose3). If all measurements share the same + * calibration (i.e., are from the same camera), use SmartProjectionPoseFactor + * instead! If the calibration should be optimized, as well, use + * SmartProjectionFactor instead! + * @addtogroup SLAM + */ +template +class SmartProjectionRigFactor : public SmartProjectionFactor { + private: + typedef SmartProjectionFactor Base; + typedef SmartProjectionRigFactor This; + typedef typename CAMERA::CalibrationType CALIBRATION; + typedef typename CAMERA::Measurement MEASUREMENT; + typedef typename CAMERA::MeasurementVector MEASUREMENTS; + + static const int DimPose = 6; ///< Pose3 dimension + static const int ZDim = 2; ///< Measurement dimension + + protected: + /// vector of keys (one for each observation) with potentially repeated keys + KeyVector nonUniqueKeys_; + + /// cameras in the rig (fixed poses wrt body and intrinsics, for each camera) + boost::shared_ptr cameraRig_; + + /// vector of camera Ids (one for each observation, in the same order), + /// identifying which camera took the measurement + FastVector cameraIds_; + + public: + EIGEN_MAKE_ALIGNED_OPERATOR_NEW + + typedef CAMERA Camera; + typedef CameraSet Cameras; + + /// shorthand for a smart pointer to a factor + typedef boost::shared_ptr shared_ptr; + + /// Default constructor, only for serialization + SmartProjectionRigFactor() {} + + /** + * Constructor + * @param sharedNoiseModel isotropic noise model for the 2D feature + * measurements + * @param cameraRig set of cameras (fixed poses wrt body and intrinsics) in + * the camera rig + * @param params parameters for the smart projection factors + */ + SmartProjectionRigFactor( + const SharedNoiseModel& sharedNoiseModel, + const boost::shared_ptr& cameraRig, + const SmartProjectionParams& params = SmartProjectionParams()) + : Base(sharedNoiseModel, params), cameraRig_(cameraRig) { + // throw exception if configuration is not supported by this factor + if (Base::params_.degeneracyMode != gtsam::ZERO_ON_DEGENERACY) + throw std::runtime_error( + "SmartProjectionRigFactor: " + "degeneracyMode must be set to ZERO_ON_DEGENERACY"); + if (Base::params_.linearizationMode != gtsam::HESSIAN) + throw std::runtime_error( + "SmartProjectionRigFactor: " + "linearizationMode must be set to HESSIAN"); + } + + /** Virtual destructor */ + ~SmartProjectionRigFactor() override = default; + + /** + * add a new measurement, corresponding to an observation from pose "poseKey" + * and taken from the camera in the rig identified by "cameraId" + * @param measured 2-dimensional location of the projection of a + * single landmark in a single view (the measurement) + * @param poseKey key corresponding to the body pose of the camera taking the + * measurement + * @param cameraId ID of the camera in the rig taking the measurement (default + * 0) + */ + void add(const MEASUREMENT& measured, const Key& poseKey, + const size_t& cameraId = 0) { + // store measurement and key + this->measured_.push_back(measured); + this->nonUniqueKeys_.push_back(poseKey); + + // also store keys in the keys_ vector: these keys are assumed to be + // unique, so we avoid duplicates here + if (std::find(this->keys_.begin(), this->keys_.end(), poseKey) == + this->keys_.end()) + this->keys_.push_back(poseKey); // add only unique keys + + // store id of the camera taking the measurement + cameraIds_.push_back(cameraId); + } + + /** + * Variant of the previous "add" function in which we include multiple + * measurements + * @param measurements vector of the 2m dimensional location of the projection + * of a single landmark in the m views (the measurements) + * @param poseKeys keys corresponding to the body poses of the cameras taking + * the measurements + * @param cameraIds IDs of the cameras in the rig taking each measurement + * (same order as the measurements) + */ + void add(const MEASUREMENTS& measurements, const KeyVector& poseKeys, + const FastVector& cameraIds = FastVector()) { + if (poseKeys.size() != measurements.size() || + (poseKeys.size() != cameraIds.size() && cameraIds.size() != 0)) { + throw std::runtime_error( + "SmartProjectionRigFactor: " + "trying to add inconsistent inputs"); + } + if (cameraIds.size() == 0 && cameraRig_->size() > 1) { + throw std::runtime_error( + "SmartProjectionRigFactor: " + "camera rig includes multiple camera " + "but add did not input cameraIds"); + } + for (size_t i = 0; i < measurements.size(); i++) { + add(measurements[i], poseKeys[i], + cameraIds.size() == 0 ? 0 : cameraIds[i]); + } + } + + /// return (for each observation) the (possibly non unique) keys involved in + /// the measurements + const KeyVector& nonUniqueKeys() const { return nonUniqueKeys_; } + + /// return the calibration object + const boost::shared_ptr& cameraRig() const { return cameraRig_; } + + /// return the calibration object + const FastVector& cameraIds() const { return cameraIds_; } + + /** + * print + * @param s optional string naming the factor + * @param keyFormatter optional formatter useful for printing Symbols + */ + void print( + const std::string& s = "", + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override { + std::cout << s << "SmartProjectionRigFactor: \n "; + for (size_t i = 0; i < nonUniqueKeys_.size(); i++) { + std::cout << "-- Measurement nr " << i << std::endl; + std::cout << "key: " << keyFormatter(nonUniqueKeys_[i]) << std::endl; + std::cout << "cameraId: " << cameraIds_[i] << std::endl; + (*cameraRig_)[cameraIds_[i]].print("camera in rig:\n"); + } + Base::print("", keyFormatter); + } + + /// equals + bool equals(const NonlinearFactor& p, double tol = 1e-9) const override { + const This* e = dynamic_cast(&p); + return e && Base::equals(p, tol) && nonUniqueKeys_ == e->nonUniqueKeys() && + cameraRig_->equals(*(e->cameraRig())) && + std::equal(cameraIds_.begin(), cameraIds_.end(), + e->cameraIds().begin()); + } + + /** + * Collect all cameras involved in this factor + * @param values Values structure which must contain body poses corresponding + * to keys involved in this factor + * @return vector of cameras + */ + typename Base::Cameras cameras(const Values& values) const override { + typename Base::Cameras cameras; + cameras.reserve(nonUniqueKeys_.size()); // preallocate + for (size_t i = 0; i < nonUniqueKeys_.size(); i++) { + const typename Base::Camera& camera_i = (*cameraRig_)[cameraIds_[i]]; + const Pose3 world_P_sensor_i = + values.at(nonUniqueKeys_[i]) // = world_P_body + * camera_i.pose(); // = body_P_cam_i + cameras.emplace_back(world_P_sensor_i, + make_shared( + camera_i.calibration())); + } + return cameras; + } + + /** + * error calculates the error of the factor. + */ + double error(const Values& values) const override { + if (this->active(values)) { + return this->totalReprojectionError(this->cameras(values)); + } else { // else of active flag + return 0.0; + } + } + + /** + * Compute jacobian F, E and error vector at a given linearization point + * @param values Values structure which must contain camera poses + * corresponding to keys involved in this factor + * @return Return arguments are the camera jacobians Fs (including the + * jacobian with respect to both body poses we interpolate from), the point + * Jacobian E, and the error vector b. Note that the jacobians are computed + * for a given point. + */ + void computeJacobiansWithTriangulatedPoint(typename Base::FBlocks& Fs, + Matrix& E, Vector& b, + const Cameras& cameras) const { + if (!this->result_) { + throw("computeJacobiansWithTriangulatedPoint"); + } else { // valid result: compute jacobians + b = -cameras.reprojectionError(*this->result_, this->measured_, Fs, E); + for (size_t i = 0; i < Fs.size(); i++) { + const Pose3& body_P_sensor = (*cameraRig_)[cameraIds_[i]].pose(); + const Pose3 world_P_body = cameras[i].pose() * body_P_sensor.inverse(); + Eigen::Matrix H; + world_P_body.compose(body_P_sensor, H); + Fs.at(i) = Fs.at(i) * H; + } + } + } + + /// linearize and return a Hessianfactor that is an approximation of error(p) + boost::shared_ptr > createHessianFactor( + const Values& values, const double& lambda = 0.0, + bool diagonalDamping = false) const { + // we may have multiple observation sharing the same keys (e.g., 2 cameras + // measuring from the same body pose), hence the number of unique keys may + // be smaller than nrMeasurements + size_t nrUniqueKeys = + this->keys_ + .size(); // note: by construction, keys_ only contains unique keys + + Cameras cameras = this->cameras(values); + + // Create structures for Hessian Factors + std::vector js; + std::vector Gs(nrUniqueKeys * (nrUniqueKeys + 1) / 2); + std::vector gs(nrUniqueKeys); + + if (this->measured_.size() != cameras.size()) // 1 observation per camera + throw std::runtime_error( + "SmartProjectionRigFactor: " + "measured_.size() inconsistent with input"); + + // triangulate 3D point at given linearization point + this->triangulateSafe(cameras); + + if (!this->result_) { // failed: return "empty/zero" Hessian + if (this->params_.degeneracyMode == ZERO_ON_DEGENERACY) { + for (Matrix& m : Gs) m = Matrix::Zero(DimPose, DimPose); + for (Vector& v : gs) v = Vector::Zero(DimPose); + return boost::make_shared >(this->keys_, + Gs, gs, 0.0); + } else { + throw std::runtime_error( + "SmartProjectionRigFactor: " + "only supported degeneracy mode is ZERO_ON_DEGENERACY"); + } + } + + // compute Jacobian given triangulated 3D Point + typename Base::FBlocks Fs; + Matrix E; + Vector b; + this->computeJacobiansWithTriangulatedPoint(Fs, E, b, cameras); + + // Whiten using noise model + this->noiseModel_->WhitenSystem(E, b); + for (size_t i = 0; i < Fs.size(); i++) { + Fs[i] = this->noiseModel_->Whiten(Fs[i]); + } + + const Matrix3 P = Base::Cameras::PointCov(E, lambda, diagonalDamping); + + // Build augmented Hessian (with last row/column being the information + // vector) Note: we need to get the augumented hessian wrt the unique keys + // in key_ + SymmetricBlockMatrix augmentedHessianUniqueKeys = + Base::Cameras::template SchurComplementAndRearrangeBlocks<3, 6, 6>( + Fs, E, P, b, nonUniqueKeys_, this->keys_); + + return boost::make_shared >( + this->keys_, augmentedHessianUniqueKeys); + } + + /** + * Linearize to Gaussian Factor (possibly adding a damping factor Lambda for + * LM) + * @param values Values structure which must contain camera poses and + * extrinsic pose for this factor + * @return a Gaussian factor + */ + boost::shared_ptr linearizeDamped( + const Values& values, const double& lambda = 0.0) const { + // depending on flag set on construction we may linearize to different + // linear factors + switch (this->params_.linearizationMode) { + case HESSIAN: + return this->createHessianFactor(values, lambda); + default: + throw std::runtime_error( + "SmartProjectionRigFactor: unknown linearization mode"); + } + } + + /// linearize + boost::shared_ptr linearize( + const Values& values) const override { + return this->linearizeDamped(values); + } + + private: + /// Serialization function + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + // ar& BOOST_SERIALIZATION_NVP(nonUniqueKeys_); + // ar& BOOST_SERIALIZATION_NVP(cameraRig_); + // ar& BOOST_SERIALIZATION_NVP(cameraIds_); + } +}; +// end of class declaration + +/// traits +template +struct traits > + : public Testable > {}; + +} // namespace gtsam diff --git a/gtsam/slam/TriangulationFactor.h b/gtsam/slam/TriangulationFactor.h index f12053d29..b6da02d55 100644 --- a/gtsam/slam/TriangulationFactor.h +++ b/gtsam/slam/TriangulationFactor.h @@ -15,6 +15,8 @@ * @author Frank Dellaert */ +#pragma once + #include #include #include @@ -33,18 +35,18 @@ class TriangulationFactor: public NoiseModelFactor1 { public: /// CAMERA type - typedef CAMERA Camera; + using Camera = CAMERA; protected: /// shorthand for base class type - typedef NoiseModelFactor1 Base; + using Base = NoiseModelFactor1; /// shorthand for this class - typedef TriangulationFactor This; + using This = TriangulationFactor; /// shorthand for measurement type, e.g. Point2 or StereoPoint2 - typedef typename CAMERA::Measurement Measurement; + using Measurement = typename CAMERA::Measurement; // Keep a copy of measurement and calibration for I/O const CAMERA camera_; ///< CAMERA in which this landmark was seen @@ -55,9 +57,10 @@ protected: const bool verboseCheirality_; ///< If true, prints text for Cheirality exceptions (default: false) public: + EIGEN_MAKE_ALIGNED_OPERATOR_NEW /// shorthand for a smart pointer to a factor - typedef boost::shared_ptr shared_ptr; + using shared_ptr = boost::shared_ptr; /// Default constructor TriangulationFactor() : @@ -129,7 +132,7 @@ public: << std::endl; if (throwCheirality_) throw e; - return Eigen::Matrix::dimension,1>::Constant(2.0 * camera_.calibration().fx()); + return camera_.defaultErrorWhenTriangulatingBehindCamera(); } } diff --git a/gtsam/slam/dataset.cpp b/gtsam/slam/dataset.cpp index c8a8b15c5..71dd64dbb 100644 --- a/gtsam/slam/dataset.cpp +++ b/gtsam/slam/dataset.cpp @@ -54,8 +54,6 @@ using namespace std; namespace fs = boost::filesystem; using gtsam::symbol_shorthand::L; -using gtsam::symbol_shorthand::P; -using gtsam::symbol_shorthand::X; #define LINESIZE 81920 @@ -179,8 +177,8 @@ boost::optional parseVertexPose(istream &is, const string &tag) { } template <> -std::map parseVariables(const std::string &filename, - size_t maxIndex) { +GTSAM_EXPORT std::map parseVariables( + const std::string &filename, size_t maxIndex) { return parseToMap(filename, parseVertexPose, maxIndex); } @@ -201,22 +199,22 @@ boost::optional parseVertexLandmark(istream &is, } template <> -std::map parseVariables(const std::string &filename, - size_t maxIndex) { +GTSAM_EXPORT std::map parseVariables( + const std::string &filename, size_t maxIndex) { return parseToMap(filename, parseVertexLandmark, maxIndex); } /* ************************************************************************* */ // Interpret noise parameters according to flags -static SharedNoiseModel -createNoiseModel(const Vector6 v, bool smart, NoiseFormat noiseFormat, - KernelFunctionType kernelFunctionType) { +static SharedNoiseModel createNoiseModel( + const Vector6 &v, bool smart, NoiseFormat noiseFormat, + KernelFunctionType kernelFunctionType) { if (noiseFormat == NoiseFormatAUTO) { // Try to guess covariance matrix layout - if (v(0) != 0.0 && v(1) == 0.0 && v(2) != 0.0 && // + if (v(0) != 0.0 && v(1) == 0.0 && v(2) != 0.0 && // v(3) != 0.0 && v(4) == 0.0 && v(5) == 0.0) { noiseFormat = NoiseFormatGRAPH; - } else if (v(0) != 0.0 && v(1) == 0.0 && v(2) == 0.0 && // + } else if (v(0) != 0.0 && v(1) == 0.0 && v(2) == 0.0 && // v(3) != 0.0 && v(4) == 0.0 && v(5) != 0.0) { noiseFormat = NoiseFormatCOV; } else { @@ -386,13 +384,14 @@ boost::shared_ptr createSampler(const SharedNoiseModel &model) { /* ************************************************************************* */ // Implementation of parseMeasurements for Pose2 template <> +GTSAM_EXPORT std::vector> parseMeasurements(const std::string &filename, const noiseModel::Diagonal::shared_ptr &model, size_t maxIndex) { ParseMeasurement parse{model ? createSampler(model) : nullptr, maxIndex, true, NoiseFormatAUTO, - KernelFunctionTypeNONE}; + KernelFunctionTypeNONE, nullptr}; return parseToVector>(filename, parse); } @@ -413,6 +412,7 @@ static BinaryMeasurement convert(const BinaryMeasurement &p) { } template <> +GTSAM_EXPORT std::vector> parseMeasurements(const std::string &filename, const noiseModel::Diagonal::shared_ptr &model, @@ -428,6 +428,7 @@ parseMeasurements(const std::string &filename, /* ************************************************************************* */ // Implementation of parseFactors for Pose2 template <> +GTSAM_EXPORT std::vector::shared_ptr> parseFactors(const std::string &filename, const noiseModel::Diagonal::shared_ptr &model, @@ -777,8 +778,8 @@ boost::optional> parseVertexPose3(istream &is, } template <> -std::map parseVariables(const std::string &filename, - size_t maxIndex) { +GTSAM_EXPORT std::map parseVariables( + const std::string &filename, size_t maxIndex) { return parseToMap(filename, parseVertexPose3, maxIndex); } @@ -795,8 +796,8 @@ boost::optional> parseVertexPoint3(istream &is, } template <> -std::map parseVariables(const std::string &filename, - size_t maxIndex) { +GTSAM_EXPORT std::map parseVariables( + const std::string &filename, size_t maxIndex) { return parseToMap(filename, parseVertexPoint3, maxIndex); } @@ -870,6 +871,7 @@ template <> struct ParseMeasurement { /* ************************************************************************* */ // Implementation of parseMeasurements for Pose3 template <> +GTSAM_EXPORT std::vector> parseMeasurements(const std::string &filename, const noiseModel::Diagonal::shared_ptr &model, @@ -897,6 +899,7 @@ static BinaryMeasurement convert(const BinaryMeasurement &p) { } template <> +GTSAM_EXPORT std::vector> parseMeasurements(const std::string &filename, const noiseModel::Diagonal::shared_ptr &model, @@ -912,6 +915,7 @@ parseMeasurements(const std::string &filename, /* ************************************************************************* */ // Implementation of parseFactors for Pose3 template <> +GTSAM_EXPORT std::vector::shared_ptr> parseFactors(const std::string &filename, const noiseModel::Diagonal::shared_ptr &model, @@ -945,352 +949,6 @@ GraphAndValues load3D(const string &filename) { return make_pair(graph, initial); } -/* ************************************************************************* */ -Rot3 openGLFixedRotation() { // this is due to different convention for - // cameras in gtsam and openGL - /* R = [ 1 0 0 - * 0 -1 0 - * 0 0 -1] - */ - Matrix3 R_mat = Matrix3::Zero(3, 3); - R_mat(0, 0) = 1.0; - R_mat(1, 1) = -1.0; - R_mat(2, 2) = -1.0; - return Rot3(R_mat); -} - -/* ************************************************************************* */ -Pose3 openGL2gtsam(const Rot3 &R, double tx, double ty, double tz) { - Rot3 R90 = openGLFixedRotation(); - Rot3 wRc = (R.inverse()).compose(R90); - - // Our camera-to-world translation wTc = -R'*t - return Pose3(wRc, R.unrotate(Point3(-tx, -ty, -tz))); -} - -/* ************************************************************************* */ -Pose3 gtsam2openGL(const Rot3 &R, double tx, double ty, double tz) { - Rot3 R90 = openGLFixedRotation(); - Rot3 cRw_openGL = R90.compose(R.inverse()); - Point3 t_openGL = cRw_openGL.rotate(Point3(-tx, -ty, -tz)); - return Pose3(cRw_openGL, t_openGL); -} - -/* ************************************************************************* */ -Pose3 gtsam2openGL(const Pose3 &PoseGTSAM) { - return gtsam2openGL(PoseGTSAM.rotation(), PoseGTSAM.x(), PoseGTSAM.y(), - PoseGTSAM.z()); -} - -/* ************************************************************************* */ -bool readBundler(const string &filename, SfmData &data) { - // Load the data file - ifstream is(filename.c_str(), ifstream::in); - if (!is) { - cout << "Error in readBundler: can not find the file!!" << endl; - return false; - } - - // Ignore the first line - char aux[500]; - is.getline(aux, 500); - - // Get the number of camera poses and 3D points - size_t nrPoses, nrPoints; - is >> nrPoses >> nrPoints; - - // Get the information for the camera poses - for (size_t i = 0; i < nrPoses; i++) { - // Get the focal length and the radial distortion parameters - float f, k1, k2; - is >> f >> k1 >> k2; - Cal3Bundler K(f, k1, k2); - - // Get the rotation matrix - float r11, r12, r13; - float r21, r22, r23; - float r31, r32, r33; - is >> r11 >> r12 >> r13 >> r21 >> r22 >> r23 >> r31 >> r32 >> r33; - - // Bundler-OpenGL rotation matrix - Rot3 R(r11, r12, r13, r21, r22, r23, r31, r32, r33); - - // Check for all-zero R, in which case quit - if (r11 == 0 && r12 == 0 && r13 == 0) { - cout << "Error in readBundler: zero rotation matrix for pose " << i - << endl; - return false; - } - - // Get the translation vector - float tx, ty, tz; - is >> tx >> ty >> tz; - - Pose3 pose = openGL2gtsam(R, tx, ty, tz); - - data.cameras.emplace_back(pose, K); - } - - // Get the information for the 3D points - data.tracks.reserve(nrPoints); - for (size_t j = 0; j < nrPoints; j++) { - SfmTrack track; - - // Get the 3D position - float x, y, z; - is >> x >> y >> z; - track.p = Point3(x, y, z); - - // Get the color information - float r, g, b; - is >> r >> g >> b; - track.r = r / 255.f; - track.g = g / 255.f; - track.b = b / 255.f; - - // Now get the visibility information - size_t nvisible = 0; - is >> nvisible; - - track.measurements.reserve(nvisible); - track.siftIndices.reserve(nvisible); - for (size_t k = 0; k < nvisible; k++) { - size_t cam_idx = 0, point_idx = 0; - float u, v; - is >> cam_idx >> point_idx >> u >> v; - track.measurements.emplace_back(cam_idx, Point2(u, -v)); - track.siftIndices.emplace_back(cam_idx, point_idx); - } - - data.tracks.push_back(track); - } - - is.close(); - return true; -} - -/* ************************************************************************* */ -bool readBAL(const string &filename, SfmData &data) { - // Load the data file - ifstream is(filename.c_str(), ifstream::in); - if (!is) { - cout << "Error in readBAL: can not find the file!!" << endl; - return false; - } - - // Get the number of camera poses and 3D points - size_t nrPoses, nrPoints, nrObservations; - is >> nrPoses >> nrPoints >> nrObservations; - - data.tracks.resize(nrPoints); - - // Get the information for the observations - for (size_t k = 0; k < nrObservations; k++) { - size_t i = 0, j = 0; - float u, v; - is >> i >> j >> u >> v; - data.tracks[j].measurements.emplace_back(i, Point2(u, -v)); - } - - // Get the information for the camera poses - for (size_t i = 0; i < nrPoses; i++) { - // Get the Rodrigues vector - float wx, wy, wz; - is >> wx >> wy >> wz; - Rot3 R = Rot3::Rodrigues(wx, wy, wz); // BAL-OpenGL rotation matrix - - // Get the translation vector - float tx, ty, tz; - is >> tx >> ty >> tz; - - Pose3 pose = openGL2gtsam(R, tx, ty, tz); - - // Get the focal length and the radial distortion parameters - float f, k1, k2; - is >> f >> k1 >> k2; - Cal3Bundler K(f, k1, k2); - - data.cameras.emplace_back(pose, K); - } - - // Get the information for the 3D points - for (size_t j = 0; j < nrPoints; j++) { - // Get the 3D position - float x, y, z; - is >> x >> y >> z; - SfmTrack &track = data.tracks[j]; - track.p = Point3(x, y, z); - track.r = 0.4f; - track.g = 0.4f; - track.b = 0.4f; - } - - is.close(); - return true; -} - -/* ************************************************************************* */ -SfmData readBal(const string &filename) { - SfmData data; - readBAL(filename, data); - return data; -} - -/* ************************************************************************* */ -bool writeBAL(const string &filename, SfmData &data) { - // Open the output file - ofstream os; - os.open(filename.c_str()); - os.precision(20); - if (!os.is_open()) { - cout << "Error in writeBAL: can not open the file!!" << endl; - return false; - } - - // Write the number of camera poses and 3D points - size_t nrObservations = 0; - for (size_t j = 0; j < data.number_tracks(); j++) { - nrObservations += data.tracks[j].number_measurements(); - } - - // Write observations - os << data.number_cameras() << " " << data.number_tracks() << " " - << nrObservations << endl; - os << endl; - - for (size_t j = 0; j < data.number_tracks(); j++) { // for each 3D point j - const SfmTrack &track = data.tracks[j]; - - for (size_t k = 0; k < track.number_measurements(); - k++) { // for each observation of the 3D point j - size_t i = track.measurements[k].first; // camera id - double u0 = data.cameras[i].calibration().px(); - double v0 = data.cameras[i].calibration().py(); - - if (u0 != 0 || v0 != 0) { - cout << "writeBAL has not been tested for calibration with nonzero " - "(u0,v0)" - << endl; - } - - double pixelBALx = track.measurements[k].second.x() - - u0; // center of image is the origin - double pixelBALy = -(track.measurements[k].second.y() - - v0); // center of image is the origin - Point2 pixelMeasurement(pixelBALx, pixelBALy); - os << i /*camera id*/ << " " << j /*point id*/ << " " - << pixelMeasurement.x() /*u of the pixel*/ << " " - << pixelMeasurement.y() /*v of the pixel*/ << endl; - } - } - os << endl; - - // Write cameras - for (size_t i = 0; i < data.number_cameras(); i++) { // for each camera - Pose3 poseGTSAM = data.cameras[i].pose(); - Cal3Bundler cameraCalibration = data.cameras[i].calibration(); - Pose3 poseOpenGL = gtsam2openGL(poseGTSAM); - os << Rot3::Logmap(poseOpenGL.rotation()) << endl; - os << poseOpenGL.translation().x() << endl; - os << poseOpenGL.translation().y() << endl; - os << poseOpenGL.translation().z() << endl; - os << cameraCalibration.fx() << endl; - os << cameraCalibration.k1() << endl; - os << cameraCalibration.k2() << endl; - os << endl; - } - - // Write the points - for (size_t j = 0; j < data.number_tracks(); j++) { // for each 3D point j - Point3 point = data.tracks[j].p; - os << point.x() << endl; - os << point.y() << endl; - os << point.z() << endl; - os << endl; - } - - os.close(); - return true; -} - -bool writeBALfromValues(const string &filename, const SfmData &data, - Values &values) { - using Camera = PinholeCamera; - SfmData dataValues = data; - - // Store poses or cameras in SfmData - size_t nrPoses = values.count(); - if (nrPoses == - dataValues.number_cameras()) { // we only estimated camera poses - for (size_t i = 0; i < dataValues.number_cameras(); - i++) { // for each camera - Pose3 pose = values.at(X(i)); - Cal3Bundler K = dataValues.cameras[i].calibration(); - Camera camera(pose, K); - dataValues.cameras[i] = camera; - } - } else { - size_t nrCameras = values.count(); - if (nrCameras == dataValues.number_cameras()) { // we only estimated camera - // poses and calibration - for (size_t i = 0; i < nrCameras; i++) { // for each camera - Key cameraKey = i; // symbol('c',i); - Camera camera = values.at(cameraKey); - dataValues.cameras[i] = camera; - } - } else { - cout << "writeBALfromValues: different number of cameras in " - "SfM_dataValues (#cameras " - << dataValues.number_cameras() << ") and values (#cameras " - << nrPoses << ", #poses " << nrCameras << ")!!" << endl; - return false; - } - } - - // Store 3D points in SfmData - size_t nrPoints = values.count(), - nrTracks = dataValues.number_tracks(); - if (nrPoints != nrTracks) { - cout << "writeBALfromValues: different number of points in " - "SfM_dataValues (#points= " - << nrTracks << ") and values (#points " << nrPoints << ")!!" << endl; - } - - for (size_t j = 0; j < nrTracks; j++) { // for each point - Key pointKey = P(j); - if (values.exists(pointKey)) { - Point3 point = values.at(pointKey); - dataValues.tracks[j].p = point; - } else { - dataValues.tracks[j].r = 1.0; - dataValues.tracks[j].g = 0.0; - dataValues.tracks[j].b = 0.0; - dataValues.tracks[j].p = Point3(0, 0, 0); - } - } - - // Write SfmData to file - return writeBAL(filename, dataValues); -} - -Values initialCamerasEstimate(const SfmData &db) { - Values initial; - size_t i = 0; // NO POINTS: j = 0; - for (const SfmCamera &camera : db.cameras) - initial.insert(i++, camera); - return initial; -} - -Values initialCamerasAndPointsEstimate(const SfmData &db) { - Values initial; - size_t i = 0, j = 0; - for (const SfmCamera &camera : db.cameras) - initial.insert((i++), camera); - for (const SfmTrack &track : db.tracks) - initial.insert(P(j++), track.p); - return initial; -} - // Wrapper-friendly versions of parseFactors and parseFactors BetweenFactorPose2s parse2DFactors(const std::string &filename, @@ -1304,14 +962,14 @@ parse3DFactors(const std::string &filename, return parseFactors(filename, model, maxIndex); } -#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V41 -std::map parse3DPoses(const std::string &filename, - size_t maxIndex) { +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 +std::map GTSAM_DEPRECATED +parse3DPoses(const std::string &filename, size_t maxIndex) { return parseVariables(filename, maxIndex); } -std::map parse3DLandmarks(const std::string &filename, - size_t maxIndex) { +std::map GTSAM_DEPRECATED +parse3DLandmarks(const std::string &filename, size_t maxIndex) { return parseVariables(filename, maxIndex); } #endif diff --git a/gtsam/slam/dataset.h b/gtsam/slam/dataset.h index ec5d6dce9..dc450a9f7 100644 --- a/gtsam/slam/dataset.h +++ b/gtsam/slam/dataset.h @@ -22,6 +22,7 @@ #include #include +#include #include #include #include @@ -167,14 +168,11 @@ GTSAM_EXPORT GraphAndValues load2D( * @param kernelFunctionType whether to wrap the noise model in a robust kernel * @return graph and initial values */ -GTSAM_EXPORT GraphAndValues load2D(const std::string& filename, - SharedNoiseModel model = SharedNoiseModel(), size_t maxIndex = 0, bool addNoise = - false, bool smart = true, NoiseFormat noiseFormat = NoiseFormatAUTO, // - KernelFunctionType kernelFunctionType = KernelFunctionTypeNONE); - -/// @deprecated load2D now allows for arbitrary models and wrapping a robust kernel -GTSAM_EXPORT GraphAndValues load2D_robust(const std::string& filename, - const noiseModel::Base::shared_ptr& model, size_t maxIndex = 0); +GTSAM_EXPORT GraphAndValues +load2D(const std::string& filename, SharedNoiseModel model = SharedNoiseModel(), + size_t maxIndex = 0, bool addNoise = false, bool smart = true, + NoiseFormat noiseFormat = NoiseFormatAUTO, // + KernelFunctionType kernelFunctionType = KernelFunctionTypeNONE); /** save 2d graph */ GTSAM_EXPORT void save2D(const NonlinearFactorGraph& graph, @@ -189,8 +187,9 @@ GTSAM_EXPORT void save2D(const NonlinearFactorGraph& graph, * @param kernelFunctionType whether to wrap the noise model in a robust kernel * @return graph and initial values */ -GTSAM_EXPORT GraphAndValues readG2o(const std::string& g2oFile, const bool is3D = false, - KernelFunctionType kernelFunctionType = KernelFunctionTypeNONE); +GTSAM_EXPORT GraphAndValues +readG2o(const std::string& g2oFile, const bool is3D = false, + KernelFunctionType kernelFunctionType = KernelFunctionTypeNONE); /** * @brief This function writes a g2o file from @@ -210,286 +209,6 @@ GTSAM_EXPORT void writeG2o(const NonlinearFactorGraph& graph, /// Load TORO 3D Graph GTSAM_EXPORT GraphAndValues load3D(const std::string& filename); -/// A measurement with its camera index -typedef std::pair SfmMeasurement; - -/// Sift index for SfmTrack -typedef std::pair SiftIndex; - -/// Define the structure for the 3D points -struct SfmTrack { - SfmTrack(float r = 0, float g = 0, float b = 0): p(0,0,0), r(r), g(g), b(b) {} - SfmTrack(const gtsam::Point3& pt, float r = 0, float g = 0, float b = 0) : p(pt), r(r), g(g), b(b) {} - - Point3 p; ///< 3D position of the point - float r, g, b; ///< RGB color of the 3D point - std::vector measurements; ///< The 2D image projections (id,(u,v)) - std::vector siftIndices; - - /// Get RGB values describing 3d point - const Point3 rgb() const { return Point3(r, g, b); } - - /// Total number of measurements in this track - size_t number_measurements() const { - return measurements.size(); - } - /// Get the measurement (camera index, Point2) at pose index `idx` - SfmMeasurement measurement(size_t idx) const { - return measurements[idx]; - } - /// Get the SIFT feature index corresponding to the measurement at `idx` - SiftIndex siftIndex(size_t idx) const { - return siftIndices[idx]; - } - /// Get 3D point - const Point3& point3() const { - return p; - } - /// Add measurement (camera_idx, Point2) to track - void add_measurement(size_t idx, const gtsam::Point2& m) { - measurements.emplace_back(idx, m); - } - - /** Serialization function */ - friend class boost::serialization::access; - template - void serialize(ARCHIVE & ar, const unsigned int /*version*/) { - ar & BOOST_SERIALIZATION_NVP(p); - ar & BOOST_SERIALIZATION_NVP(r); - ar & BOOST_SERIALIZATION_NVP(g); - ar & BOOST_SERIALIZATION_NVP(b); - ar & BOOST_SERIALIZATION_NVP(measurements); - ar & BOOST_SERIALIZATION_NVP(siftIndices); - } - - /// assert equality up to a tolerance - bool equals(const SfmTrack &sfmTrack, double tol = 1e-9) const { - // check the 3D point - if (!p.isApprox(sfmTrack.p)) { - return false; - } - - // check the RGB values - if (r!=sfmTrack.r || g!=sfmTrack.g || b!=sfmTrack.b) { - return false; - } - - // compare size of vectors for measurements and siftIndices - if (number_measurements() != sfmTrack.number_measurements() || - siftIndices.size() != sfmTrack.siftIndices.size()) { - return false; - } - - // compare measurements (order sensitive) - for (size_t idx = 0; idx < number_measurements(); ++idx) { - SfmMeasurement measurement = measurements[idx]; - SfmMeasurement otherMeasurement = sfmTrack.measurements[idx]; - - if (measurement.first != otherMeasurement.first || - !measurement.second.isApprox(otherMeasurement.second)) { - return false; - } - } - - // compare sift indices (order sensitive) - for (size_t idx = 0; idx < siftIndices.size(); ++idx) { - SiftIndex index = siftIndices[idx]; - SiftIndex otherIndex = sfmTrack.siftIndices[idx]; - - if (index.first != otherIndex.first || - index.second != otherIndex.second) { - return false; - } - } - - return true; - } - - /// print - void print(const std::string& s = "") const { - std::cout << "Track with " << measurements.size(); - std::cout << " measurements of point " << p << std::endl; - } -}; - -/* ************************************************************************* */ -/// traits -template<> -struct traits : public Testable { -}; - - -/// Define the structure for the camera poses -typedef PinholeCamera SfmCamera; - -/// Define the structure for SfM data -struct SfmData { - std::vector cameras; ///< Set of cameras - std::vector tracks; ///< Sparse set of points - size_t number_cameras() const { - return cameras.size(); - } - /// The number of reconstructed 3D points - size_t number_tracks() const { - return tracks.size(); - } - /// The camera pose at frame index `idx` - SfmCamera camera(size_t idx) const { - return cameras[idx]; - } - /// The track formed by series of landmark measurements - SfmTrack track(size_t idx) const { - return tracks[idx]; - } - /// Add a track to SfmData - void add_track(const SfmTrack& t) { - tracks.push_back(t); - } - /// Add a camera to SfmData - void add_camera(const SfmCamera& cam) { - cameras.push_back(cam); - } - - /** Serialization function */ - friend class boost::serialization::access; - template - void serialize(Archive & ar, const unsigned int /*version*/) { - ar & BOOST_SERIALIZATION_NVP(cameras); - ar & BOOST_SERIALIZATION_NVP(tracks); - } - - /// @} - /// @name Testable - /// @{ - - /// assert equality up to a tolerance - bool equals(const SfmData &sfmData, double tol = 1e-9) const { - // check number of cameras and tracks - if (number_cameras() != sfmData.number_cameras() || - number_tracks() != sfmData.number_tracks()) { - return false; - } - - // check each camera - for (size_t i = 0; i < number_cameras(); ++i) { - if (!camera(i).equals(sfmData.camera(i), tol)) { - return false; - } - } - - // check each track - for (size_t j = 0; j < number_tracks(); ++j) { - if (!track(j).equals(sfmData.track(j), tol)) { - return false; - } - } - - return true; - } - - /// print - void print(const std::string& s = "") const { - std::cout << "Number of cameras = " << number_cameras() << std::endl; - std::cout << "Number of tracks = " << number_tracks() << std::endl; - } -}; - -/* ************************************************************************* */ -/// traits -template<> -struct traits : public Testable { -}; - -/** - * @brief This function parses a bundler output file and stores the data into a - * SfmData structure - * @param filename The name of the bundler file - * @param data SfM structure where the data is stored - * @return true if the parsing was successful, false otherwise - */ -GTSAM_EXPORT bool readBundler(const std::string& filename, SfmData &data); - -/** - * @brief This function parses a "Bundle Adjustment in the Large" (BAL) file and stores the data into a - * SfmData structure - * @param filename The name of the BAL file - * @param data SfM structure where the data is stored - * @return true if the parsing was successful, false otherwise - */ -GTSAM_EXPORT bool readBAL(const std::string& filename, SfmData &data); - -/** - * @brief This function parses a "Bundle Adjustment in the Large" (BAL) file and returns the data - * as a SfmData structure. Mainly used by wrapped code. - * @param filename The name of the BAL file. - * @return SfM structure where the data is stored. - */ -GTSAM_EXPORT SfmData readBal(const std::string& filename); - -/** - * @brief This function writes a "Bundle Adjustment in the Large" (BAL) file from a - * SfmData structure - * @param filename The name of the BAL file to write - * @param data SfM structure where the data is stored - * @return true if the parsing was successful, false otherwise - */ -GTSAM_EXPORT bool writeBAL(const std::string& filename, SfmData &data); - -/** - * @brief This function writes a "Bundle Adjustment in the Large" (BAL) file from a - * SfmData structure and a value structure (measurements are the same as the SfM input data, - * while camera poses and values are read from Values) - * @param filename The name of the BAL file to write - * @param data SfM structure where the data is stored - * @param values structure where the graph values are stored (values can be either Pose3 or PinholeCamera for the - * cameras, and should be Point3 for the 3D points). Note that the current version - * assumes that the keys are "x1" for pose 1 (or "c1" for camera 1) and "l1" for landmark 1 - * @return true if the parsing was successful, false otherwise - */ -GTSAM_EXPORT bool writeBALfromValues(const std::string& filename, - const SfmData &data, Values& values); - -/** - * @brief This function converts an openGL camera pose to an GTSAM camera pose - * @param R rotation in openGL - * @param tx x component of the translation in openGL - * @param ty y component of the translation in openGL - * @param tz z component of the translation in openGL - * @return Pose3 in GTSAM format - */ -GTSAM_EXPORT Pose3 openGL2gtsam(const Rot3& R, double tx, double ty, double tz); - -/** - * @brief This function converts a GTSAM camera pose to an openGL camera pose - * @param R rotation in GTSAM - * @param tx x component of the translation in GTSAM - * @param ty y component of the translation in GTSAM - * @param tz z component of the translation in GTSAM - * @return Pose3 in openGL format - */ -GTSAM_EXPORT Pose3 gtsam2openGL(const Rot3& R, double tx, double ty, double tz); - -/** - * @brief This function converts a GTSAM camera pose to an openGL camera pose - * @param PoseGTSAM pose in GTSAM format - * @return Pose3 in openGL format - */ -GTSAM_EXPORT Pose3 gtsam2openGL(const Pose3& PoseGTSAM); - -/** - * @brief This function creates initial values for cameras from db - * @param SfmData - * @return Values - */ -GTSAM_EXPORT Values initialCamerasEstimate(const SfmData& db); - -/** - * @brief This function creates initial values for cameras and points from db - * @param SfmData - * @return Values - */ -GTSAM_EXPORT Values initialCamerasAndPointsEstimate(const SfmData& db); - // Wrapper-friendly versions of parseFactors and parseFactors using BetweenFactorPose2s = std::vector::shared_ptr>; GTSAM_EXPORT BetweenFactorPose2s @@ -504,17 +223,21 @@ parse3DFactors(const std::string &filename, size_t maxIndex = 0); using BinaryMeasurementsUnit3 = std::vector>; -#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V41 -inline boost::optional parseVertex(std::istream &is, - const std::string &tag) { + +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 +inline boost::optional GTSAM_DEPRECATED +parseVertex(std::istream& is, const std::string& tag) { return parseVertexPose(is, tag); } -GTSAM_EXPORT std::map parse3DPoses(const std::string &filename, - size_t maxIndex = 0); +GTSAM_EXPORT std::map GTSAM_DEPRECATED +parse3DPoses(const std::string& filename, size_t maxIndex = 0); -GTSAM_EXPORT std::map -parse3DLandmarks(const std::string &filename, size_t maxIndex = 0); +GTSAM_EXPORT std::map GTSAM_DEPRECATED +parse3DLandmarks(const std::string& filename, size_t maxIndex = 0); +GTSAM_EXPORT GraphAndValues GTSAM_DEPRECATED +load2D_robust(const std::string& filename, + const noiseModel::Base::shared_ptr& model, size_t maxIndex = 0); #endif } // namespace gtsam diff --git a/gtsam/slam/expressions.h b/gtsam/slam/expressions.h index c6aa02774..3b8ea86d3 100644 --- a/gtsam/slam/expressions.h +++ b/gtsam/slam/expressions.h @@ -138,4 +138,21 @@ Point2_ uncalibrate(const Expression& K, const Point2_& xy_hat) { return Point2_(K, &CALIBRATION::uncalibrate, xy_hat); } + +/// logmap +// TODO(dellaert): Should work but fails because of a type deduction conflict. +// template +// gtsam::Expression::TangentVector> logmap( +// const gtsam::Expression &x1, const gtsam::Expression &x2) { +// return gtsam::Expression::TangentVector>( +// x1, &T::logmap, x2); +// } + +template +gtsam::Expression::TangentVector> logmap( + const gtsam::Expression &x1, const gtsam::Expression &x2) { + return Expression::TangentVector>( + gtsam::traits::Logmap, between(x1, x2)); +} + } // \namespace gtsam diff --git a/gtsam/slam/lago.cpp b/gtsam/slam/lago.cpp index 70caa424f..f8b092f86 100644 --- a/gtsam/slam/lago.cpp +++ b/gtsam/slam/lago.cpp @@ -36,7 +36,7 @@ static const Matrix I = I_1x1; static const Matrix I3 = I_3x3; static const noiseModel::Diagonal::shared_ptr priorOrientationNoise = - noiseModel::Diagonal::Sigmas((Vector(1) << 0).finished()); + noiseModel::Diagonal::Sigmas(Vector1(0)); static const noiseModel::Diagonal::shared_ptr priorPose2Noise = noiseModel::Diagonal::Variances(Vector3(1e-6, 1e-6, 1e-8)); diff --git a/gtsam/slam/slam.i b/gtsam/slam/slam.i index 1c04fd14c..4e943253e 100644 --- a/gtsam/slam/slam.i +++ b/gtsam/slam/slam.i @@ -11,7 +11,7 @@ namespace gtsam { // ###### #include -template virtual class BetweenFactor : gtsam::NoiseModelFactor { @@ -21,9 +21,6 @@ virtual class BetweenFactor : gtsam::NoiseModelFactor { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -168,6 +165,10 @@ template virtual class PoseTranslationPrior : gtsam::NoiseModelFactor { PoseTranslationPrior(size_t key, const POSE& pose_z, const gtsam::noiseModel::Base* noiseModel); + POSE::Translation measured() const; + + // enabling serialization functionality + void serialize() const; }; typedef gtsam::PoseTranslationPrior PoseTranslationPrior2D; @@ -178,6 +179,7 @@ template virtual class PoseRotationPrior : gtsam::NoiseModelFactor { PoseRotationPrior(size_t key, const POSE& pose_z, const gtsam::noiseModel::Base* noiseModel); + POSE::Rotation measured() const; }; typedef gtsam::PoseRotationPrior PoseRotationPrior2D; @@ -188,73 +190,46 @@ virtual class EssentialMatrixFactor : gtsam::NoiseModelFactor { EssentialMatrixFactor(size_t key, const gtsam::Point2& pA, const gtsam::Point2& pB, const gtsam::noiseModel::Base* noiseModel); + void print(string s = "", const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::EssentialMatrixFactor& other, double tol) const; + Vector evaluateError(const gtsam::EssentialMatrix& E) const; +}; + +#include +virtual class EssentialMatrixConstraint : gtsam::NoiseModelFactor { + EssentialMatrixConstraint(size_t key1, size_t key2, const gtsam::EssentialMatrix &measuredE, + const gtsam::noiseModel::Base *model); + void print(string s = "", const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::EssentialMatrixConstraint& other, double tol) const; + Vector evaluateError(const gtsam::Pose3& p1, const gtsam::Pose3& p2) const; + const gtsam::EssentialMatrix& measured() const; }; #include -class SfmTrack { - SfmTrack(); - SfmTrack(const gtsam::Point3& pt); - const Point3& point3() const; - - double r; - double g; - double b; - - std::vector> measurements; - - size_t number_measurements() const; - pair measurement(size_t idx) const; - pair siftIndex(size_t idx) const; - void add_measurement(size_t idx, const gtsam::Point2& m); - - // enabling serialization functionality - void serialize() const; - - // enable pickling in python - void pickle() const; - - // enabling function to compare objects - bool equals(const gtsam::SfmTrack& expected, double tol) const; +enum NoiseFormat { + NoiseFormatG2O, + NoiseFormatTORO, + NoiseFormatGRAPH, + NoiseFormatCOV, + NoiseFormatAUTO }; -class SfmData { - SfmData(); - size_t number_cameras() const; - size_t number_tracks() const; - gtsam::PinholeCamera camera(size_t idx) const; - gtsam::SfmTrack track(size_t idx) const; - void add_track(const gtsam::SfmTrack& t); - void add_camera(const gtsam::SfmCamera& cam); - - // enabling serialization functionality - void serialize() const; - - // enable pickling in python - void pickle() const; - - // enabling function to compare objects - bool equals(const gtsam::SfmData& expected, double tol) const; +enum KernelFunctionType { + KernelFunctionTypeNONE, + KernelFunctionTypeHUBER, + KernelFunctionTypeTUKEY }; -gtsam::SfmData readBal(string filename); -bool writeBAL(string filename, gtsam::SfmData& data); -gtsam::Values initialCamerasEstimate(const gtsam::SfmData& db); -gtsam::Values initialCamerasAndPointsEstimate(const gtsam::SfmData& db); +pair load2D( + string filename, gtsam::noiseModel::Diagonal* model = nullptr, + size_t maxIndex = 0, bool addNoise = false, bool smart = true, + gtsam::NoiseFormat noiseFormat = gtsam::NoiseFormat::NoiseFormatAUTO, + gtsam::KernelFunctionType kernelFunctionType = + gtsam::KernelFunctionType::KernelFunctionTypeNONE); -pair load2D( - string filename, gtsam::noiseModel::Diagonal* model, int maxIndex, - bool addNoise, bool smart); -pair load2D( - string filename, gtsam::noiseModel::Diagonal* model, int maxIndex, - bool addNoise); -pair load2D( - string filename, gtsam::noiseModel::Diagonal* model, int maxIndex); -pair load2D( - string filename, gtsam::noiseModel::Diagonal* model); -pair load2D(string filename); -pair load2D_robust( - string filename, gtsam::noiseModel::Base* model, int maxIndex); void save2D(const gtsam::NonlinearFactorGraph& graph, const gtsam::Values& config, gtsam::noiseModel::Diagonal* model, string filename); @@ -281,9 +256,10 @@ gtsam::BetweenFactorPose3s parse3DFactors(string filename); pair load3D(string filename); -pair readG2o(string filename); -pair readG2o(string filename, - bool is3D); +pair readG2o( + string filename, const bool is3D = false, + gtsam::KernelFunctionType kernelFunctionType = + gtsam::KernelFunctionType::KernelFunctionTypeNONE); void writeG2o(const gtsam::NonlinearFactorGraph& graph, const gtsam::Values& estimate, string filename); @@ -314,6 +290,8 @@ virtual class KarcherMeanFactor : gtsam::NonlinearFactor { KarcherMeanFactor(const gtsam::KeyVector& keys); }; +gtsam::Rot3 FindKarcherMean(const gtsam::Rot3Vector& rotations); + #include gtsam::noiseModel::Isotropic* ConvertNoiseModel(gtsam::noiseModel::Base* model, size_t d); @@ -334,5 +312,11 @@ virtual class FrobeniusBetweenFactor : gtsam::NoiseModelFactor { Vector evaluateError(const T& R1, const T& R2); }; - + +#include +namespace lago { + gtsam::Values initialize(const gtsam::NonlinearFactorGraph& graph, bool useOdometricPath = true); + gtsam::Values initialize(const gtsam::NonlinearFactorGraph& graph, const gtsam::Values& initialGuess); +} + } // namespace gtsam diff --git a/gtsam/slam/tests/PinholeFactor.h b/gtsam/slam/tests/PinholeFactor.h new file mode 100644 index 000000000..35500ca35 --- /dev/null +++ b/gtsam/slam/tests/PinholeFactor.h @@ -0,0 +1,52 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file PinholeFactor.h + * @brief helper class for tests + * @author Frank Dellaert + * @date February 2022 + */ + +#pragma once + +namespace gtsam { +template +struct traits; +} + +#include +#include +#include +#include + +namespace gtsam { + +class PinholeFactor : public SmartFactorBase > { + public: + typedef SmartFactorBase > Base; + PinholeFactor() {} + PinholeFactor(const SharedNoiseModel& sharedNoiseModel, + boost::optional body_P_sensor = boost::none, + size_t expectedNumberCameras = 10) + : Base(sharedNoiseModel, body_P_sensor, expectedNumberCameras) {} + double error(const Values& values) const override { return 0.0; } + boost::shared_ptr linearize( + const Values& values) const override { + return boost::shared_ptr(new JacobianFactor()); + } +}; + +/// traits +template <> +struct traits : public Testable {}; + +} // namespace gtsam diff --git a/gtsam/slam/tests/smartFactorScenarios.h b/gtsam/slam/tests/smartFactorScenarios.h index 4abc59305..eff942799 100644 --- a/gtsam/slam/tests/smartFactorScenarios.h +++ b/gtsam/slam/tests/smartFactorScenarios.h @@ -17,15 +17,19 @@ */ #pragma once -#include -#include -#include -#include #include +#include +#include +#include +#include +#include + +#include "../SmartProjectionRigFactor.h" using namespace std; using namespace gtsam; +namespace { // three landmarks ~5 meters infront of camera Point3 landmark1(5, 0.5, 1.2); Point3 landmark2(5, -0.5, 1.2); @@ -43,103 +47,127 @@ Pose3 pose_above = level_pose * Pose3(Rot3(), Point3(0, -1, 0)); // Create a noise unit2 for the pixel error static SharedNoiseModel unit2(noiseModel::Unit::Create(2)); -static double fov = 60; // degrees +static double fov = 60; // degrees static size_t w = 640, h = 480; +} /* ************************************************************************* */ // default Cal3_S2 cameras namespace vanilla { typedef PinholeCamera Camera; typedef SmartProjectionFactor SmartFactor; -static Cal3_S2 K(fov, w, h); -static Cal3_S2 K2(1500, 1200, 0, w, h); -Camera level_camera(level_pose, K2); -Camera level_camera_right(pose_right, K2); -Point2 level_uv = level_camera.project(landmark1); -Point2 level_uv_right = level_camera_right.project(landmark1); -Camera cam1(level_pose, K2); -Camera cam2(pose_right, K2); -Camera cam3(pose_above, K2); +static const Cal3_S2 K(fov, w, h); +static const Cal3_S2 K2(1500, 1200, 0, w, h); +static const Camera level_camera(level_pose, K2); +static const Camera level_camera_right(pose_right, K2); +static const Point2 level_uv = level_camera.project(landmark1); +static const Point2 level_uv_right = level_camera_right.project(landmark1); +static const Camera cam1(level_pose, K2); +static const Camera cam2(pose_right, K2); +static const Camera cam3(pose_above, K2); typedef GeneralSFMFactor SFMFactor; -SmartProjectionParams params; -} +static const SmartProjectionParams params; +} // namespace vanilla /* ************************************************************************* */ // default Cal3_S2 poses namespace vanillaPose { typedef PinholePose Camera; +typedef CameraSet Cameras; typedef SmartProjectionPoseFactor SmartFactor; -static Cal3_S2::shared_ptr sharedK(new Cal3_S2(fov, w, h)); -Camera level_camera(level_pose, sharedK); -Camera level_camera_right(pose_right, sharedK); -Camera cam1(level_pose, sharedK); -Camera cam2(pose_right, sharedK); -Camera cam3(pose_above, sharedK); -} +typedef SmartProjectionRigFactor SmartRigFactor; +static const Cal3_S2::shared_ptr sharedK(new Cal3_S2(fov, w, h)); +static const Camera level_camera(level_pose, sharedK); +static const Camera level_camera_right(pose_right, sharedK); +static const Camera cam1(level_pose, sharedK); +static const Camera cam2(pose_right, sharedK); +static const Camera cam3(pose_above, sharedK); +} // namespace vanillaPose /* ************************************************************************* */ // default Cal3_S2 poses namespace vanillaPose2 { typedef PinholePose Camera; +typedef CameraSet Cameras; typedef SmartProjectionPoseFactor SmartFactor; -static Cal3_S2::shared_ptr sharedK2(new Cal3_S2(1500, 1200, 0, 640, 480)); -Camera level_camera(level_pose, sharedK2); -Camera level_camera_right(pose_right, sharedK2); -Camera cam1(level_pose, sharedK2); -Camera cam2(pose_right, sharedK2); -Camera cam3(pose_above, sharedK2); -} +typedef SmartProjectionRigFactor SmartRigFactor; +static const Cal3_S2::shared_ptr sharedK2(new Cal3_S2(1500, 1200, 0, 640, 480)); +static const Camera level_camera(level_pose, sharedK2); +static const Camera level_camera_right(pose_right, sharedK2); +static const Camera cam1(level_pose, sharedK2); +static const Camera cam2(pose_right, sharedK2); +static const Camera cam3(pose_above, sharedK2); +} // namespace vanillaPose2 /* *************************************************************************/ // Cal3Bundler cameras namespace bundler { typedef PinholeCamera Camera; +typedef CameraSet Cameras; typedef SmartProjectionFactor SmartFactor; -static Cal3Bundler K(500, 1e-3, 1e-3, 0, 0); -Camera level_camera(level_pose, K); -Camera level_camera_right(pose_right, K); -Point2 level_uv = level_camera.project(landmark1); -Point2 level_uv_right = level_camera_right.project(landmark1); -Pose3 pose1 = level_pose; -Camera cam1(level_pose, K); -Camera cam2(pose_right, K); -Camera cam3(pose_above, K); +static const Cal3Bundler K(500, 1e-3, 1e-3, 0, 0); +static const Camera level_camera(level_pose, K); +static const Camera level_camera_right(pose_right, K); +static const Point2 level_uv = level_camera.project(landmark1); +static const Point2 level_uv_right = level_camera_right.project(landmark1); +static const Pose3 pose1 = level_pose; +static const Camera cam1(level_pose, K); +static const Camera cam2(pose_right, K); +static const Camera cam3(pose_above, K); typedef GeneralSFMFactor SFMFactor; -} +} // namespace bundler + /* *************************************************************************/ // Cal3Bundler poses namespace bundlerPose { typedef PinholePose Camera; +typedef CameraSet Cameras; typedef SmartProjectionPoseFactor SmartFactor; -static boost::shared_ptr sharedBundlerK( - new Cal3Bundler(500, 1e-3, 1e-3, 1000, 2000)); -Camera level_camera(level_pose, sharedBundlerK); -Camera level_camera_right(pose_right, sharedBundlerK); -Camera cam1(level_pose, sharedBundlerK); -Camera cam2(pose_right, sharedBundlerK); -Camera cam3(pose_above, sharedBundlerK); -} +typedef SmartProjectionRigFactor SmartRigFactor; +static const boost::shared_ptr sharedBundlerK(new Cal3Bundler(500, 1e-3, + 1e-3, 1000, + 2000)); +static const Camera level_camera(level_pose, sharedBundlerK); +static const Camera level_camera_right(pose_right, sharedBundlerK); +static const Camera cam1(level_pose, sharedBundlerK); +static const Camera cam2(pose_right, sharedBundlerK); +static const Camera cam3(pose_above, sharedBundlerK); +} // namespace bundlerPose + +/* ************************************************************************* */ +// sphericalCamera +namespace sphericalCamera { +typedef SphericalCamera Camera; +typedef CameraSet Cameras; +typedef SmartProjectionRigFactor SmartFactorP; +static const EmptyCal::shared_ptr emptyK(new EmptyCal()); +static const Camera level_camera(level_pose); +static const Camera level_camera_right(pose_right); +static const Camera cam1(level_pose); +static const Camera cam2(pose_right); +static const Camera cam3(pose_above); +} // namespace sphericalCamera /* *************************************************************************/ -template +template CAMERA perturbCameraPose(const CAMERA& camera) { - Pose3 noise_pose = Pose3(Rot3::Ypr(-M_PI / 10, 0., -M_PI / 10), - Point3(0.5, 0.1, 0.3)); + Pose3 noise_pose = + Pose3(Rot3::Ypr(-M_PI / 10, 0., -M_PI / 10), Point3(0.5, 0.1, 0.3)); Pose3 cameraPose = camera.pose(); Pose3 perturbedCameraPose = cameraPose.compose(noise_pose); return CAMERA(perturbedCameraPose, camera.calibration()); } -template -void projectToMultipleCameras(const CAMERA& cam1, const CAMERA& cam2, - const CAMERA& cam3, Point3 landmark, typename CAMERA::MeasurementVector& measurements_cam) { - Point2 cam1_uv1 = cam1.project(landmark); - Point2 cam2_uv1 = cam2.project(landmark); - Point2 cam3_uv1 = cam3.project(landmark); +template +void projectToMultipleCameras( + const CAMERA& cam1, const CAMERA& cam2, const CAMERA& cam3, Point3 landmark, + typename CAMERA::MeasurementVector& measurements_cam) { + typename CAMERA::Measurement cam1_uv1 = cam1.project(landmark); + typename CAMERA::Measurement cam2_uv1 = cam2.project(landmark); + typename CAMERA::Measurement cam3_uv1 = cam3.project(landmark); measurements_cam.push_back(cam1_uv1); measurements_cam.push_back(cam2_uv1); measurements_cam.push_back(cam3_uv1); } /* ************************************************************************* */ - diff --git a/gtsam/slam/tests/testDataset.cpp b/gtsam/slam/tests/testDataset.cpp index aad9124c5..be638d51a 100644 --- a/gtsam/slam/tests/testDataset.cpp +++ b/gtsam/slam/tests/testDataset.cpp @@ -151,27 +151,6 @@ TEST(dataSet, load2DVictoriaPark) { EXPECT_LONGS_EQUAL(L(5), graph->at(4)->keys()[1]); } -/* ************************************************************************* */ -TEST( dataSet, Balbianello) -{ - ///< The structure where we will save the SfM data - const string filename = findExampleDataFile("Balbianello"); - SfmData mydata; - CHECK(readBundler(filename, mydata)); - - // Check number of things - EXPECT_LONGS_EQUAL(5,mydata.number_cameras()); - EXPECT_LONGS_EQUAL(544,mydata.number_tracks()); - const SfmTrack& track0 = mydata.tracks[0]; - EXPECT_LONGS_EQUAL(3,track0.number_measurements()); - - // Check projection of a given point - EXPECT_LONGS_EQUAL(0,track0.measurements[0].first); - const SfmCamera& camera0 = mydata.cameras[0]; - Point2 expected = camera0.project(track0.p), actual = track0.measurements[0].second; - EXPECT(assert_equal(expected,actual,1)); -} - /* ************************************************************************* */ TEST(dataSet, readG2o3D) { const string g2oFile = findExampleDataFile("pose3example"); @@ -461,160 +440,6 @@ TEST( dataSet, writeG2o3DNonDiagonalNoise) EXPECT(assert_equal(*expectedGraph,*actualGraph,1e-4)); } -/* ************************************************************************* */ -TEST( dataSet, readBAL_Dubrovnik) -{ - ///< The structure where we will save the SfM data - const string filename = findExampleDataFile("dubrovnik-3-7-pre"); - SfmData mydata; - CHECK(readBAL(filename, mydata)); - - // Check number of things - EXPECT_LONGS_EQUAL(3,mydata.number_cameras()); - EXPECT_LONGS_EQUAL(7,mydata.number_tracks()); - const SfmTrack& track0 = mydata.tracks[0]; - EXPECT_LONGS_EQUAL(3,track0.number_measurements()); - - // Check projection of a given point - EXPECT_LONGS_EQUAL(0,track0.measurements[0].first); - const SfmCamera& camera0 = mydata.cameras[0]; - Point2 expected = camera0.project(track0.p), actual = track0.measurements[0].second; - EXPECT(assert_equal(expected,actual,12)); -} - -/* ************************************************************************* */ -TEST( dataSet, openGL2gtsam) -{ - Vector3 rotVec(0.2, 0.7, 1.1); - Rot3 R = Rot3::Expmap(rotVec); - Point3 t = Point3(0.0,0.0,0.0); - Pose3 poseGTSAM = Pose3(R,t); - - Pose3 expected = openGL2gtsam(R, t.x(), t.y(), t.z()); - - Point3 r1 = R.r1(), r2 = R.r2(), r3 = R.r3(); //columns! - Rot3 cRw( - r1.x(), r2.x(), r3.x(), - -r1.y(), -r2.y(), -r3.y(), - -r1.z(), -r2.z(), -r3.z()); - Rot3 wRc = cRw.inverse(); - Pose3 actual = Pose3(wRc,t); - - EXPECT(assert_equal(expected,actual)); -} - -/* ************************************************************************* */ -TEST( dataSet, gtsam2openGL) -{ - Vector3 rotVec(0.2, 0.7, 1.1); - Rot3 R = Rot3::Expmap(rotVec); - Point3 t = Point3(1.0,20.0,10.0); - Pose3 actual = Pose3(R,t); - Pose3 poseGTSAM = openGL2gtsam(R, t.x(), t.y(), t.z()); - - Pose3 expected = gtsam2openGL(poseGTSAM); - EXPECT(assert_equal(expected,actual)); -} - -/* ************************************************************************* */ -TEST( dataSet, writeBAL_Dubrovnik) -{ - ///< Read a file using the unit tested readBAL - const string filenameToRead = findExampleDataFile("dubrovnik-3-7-pre"); - SfmData readData; - readBAL(filenameToRead, readData); - - // Write readData to file filenameToWrite - const string filenameToWrite = createRewrittenFileName(filenameToRead); - CHECK(writeBAL(filenameToWrite, readData)); - - // Read what we wrote - SfmData writtenData; - CHECK(readBAL(filenameToWrite, writtenData)); - - // Check that what we read is the same as what we wrote - EXPECT_LONGS_EQUAL(readData.number_cameras(),writtenData.number_cameras()); - EXPECT_LONGS_EQUAL(readData.number_tracks(),writtenData.number_tracks()); - - for (size_t i = 0; i < readData.number_cameras(); i++){ - PinholeCamera expectedCamera = writtenData.cameras[i]; - PinholeCamera actualCamera = readData.cameras[i]; - EXPECT(assert_equal(expectedCamera,actualCamera)); - } - - for (size_t j = 0; j < readData.number_tracks(); j++){ - // check point - SfmTrack expectedTrack = writtenData.tracks[j]; - SfmTrack actualTrack = readData.tracks[j]; - Point3 expectedPoint = expectedTrack.p; - Point3 actualPoint = actualTrack.p; - EXPECT(assert_equal(expectedPoint,actualPoint)); - - // check rgb - Point3 expectedRGB = Point3( expectedTrack.r, expectedTrack.g, expectedTrack.b ); - Point3 actualRGB = Point3( actualTrack.r, actualTrack.g, actualTrack.b); - EXPECT(assert_equal(expectedRGB,actualRGB)); - - // check measurements - for (size_t k = 0; k < actualTrack.number_measurements(); k++){ - EXPECT_LONGS_EQUAL(expectedTrack.measurements[k].first,actualTrack.measurements[k].first); - EXPECT(assert_equal(expectedTrack.measurements[k].second,actualTrack.measurements[k].second)); - } - } -} - - -/* ************************************************************************* */ -TEST( dataSet, writeBALfromValues_Dubrovnik){ - - ///< Read a file using the unit tested readBAL - const string filenameToRead = findExampleDataFile("dubrovnik-3-7-pre"); - SfmData readData; - readBAL(filenameToRead, readData); - - Pose3 poseChange = Pose3(Rot3::Ypr(-M_PI/10, 0., -M_PI/10), Point3(0.3,0.1,0.3)); - - Values value; - for(size_t i=0; i < readData.number_cameras(); i++){ // for each camera - Pose3 pose = poseChange.compose(readData.cameras[i].pose()); - value.insert(X(i), pose); - } - for(size_t j=0; j < readData.number_tracks(); j++){ // for each point - Point3 point = poseChange.transformFrom( readData.tracks[j].p ); - value.insert(P(j), point); - } - - // Write values and readData to a file - const string filenameToWrite = createRewrittenFileName(filenameToRead); - writeBALfromValues(filenameToWrite, readData, value); - - // Read the file we wrote - SfmData writtenData; - readBAL(filenameToWrite, writtenData); - - // Check that the reprojection errors are the same and the poses are correct - // Check number of things - EXPECT_LONGS_EQUAL(3,writtenData.number_cameras()); - EXPECT_LONGS_EQUAL(7,writtenData.number_tracks()); - const SfmTrack& track0 = writtenData.tracks[0]; - EXPECT_LONGS_EQUAL(3,track0.number_measurements()); - - // Check projection of a given point - EXPECT_LONGS_EQUAL(0,track0.measurements[0].first); - const SfmCamera& camera0 = writtenData.cameras[0]; - Point2 expected = camera0.project(track0.p), actual = track0.measurements[0].second; - EXPECT(assert_equal(expected,actual,12)); - - Pose3 expectedPose = camera0.pose(); - Pose3 actualPose = value.at(X(0)); - EXPECT(assert_equal(expectedPose,actualPose, 1e-7)); - - Point3 expectedPoint = track0.p; - Point3 actualPoint = value.at(P(0)); - EXPECT(assert_equal(expectedPoint,actualPoint, 1e-6)); -} - - /* ************************************************************************* */ int main() { TestResult tr; return TestRegistry::runAllTests(tr); } /* ************************************************************************* */ diff --git a/gtsam/slam/tests/testEssentialMatrixConstraint.cpp b/gtsam/slam/tests/testEssentialMatrixConstraint.cpp index 080239b35..2faac24d1 100644 --- a/gtsam/slam/tests/testEssentialMatrixConstraint.cpp +++ b/gtsam/slam/tests/testEssentialMatrixConstraint.cpp @@ -10,7 +10,7 @@ * -------------------------------------------------------------------------- */ /** - * @file testEssentialMatrixConstraint.cpp + * @file TestEssentialMatrixConstraint.cpp * @brief Unit tests for EssentialMatrixConstraint Class * @author Frank Dellaert * @author Pablo Alcantarilla diff --git a/gtsam/slam/tests/testEssentialMatrixFactor.cpp b/gtsam/slam/tests/testEssentialMatrixFactor.cpp index 03775a70f..ef22bad2a 100644 --- a/gtsam/slam/tests/testEssentialMatrixFactor.cpp +++ b/gtsam/slam/tests/testEssentialMatrixFactor.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include using namespace std::placeholders; @@ -34,8 +35,7 @@ gtsam::Rot3 cRb = gtsam::Rot3(bX, bZ, -bY).inverse(); namespace example1 { const string filename = findExampleDataFile("18pointExample1.txt"); -SfmData data; -bool readOK = readBAL(filename, data); +SfmData data = SfmData::FromBalFile(filename); Rot3 c1Rc2 = data.cameras[1].pose().rotation(); Point3 c1Tc2 = data.cameras[1].pose().translation(); // TODO: maybe default value not good; assert with 0th @@ -53,8 +53,6 @@ Vector vB(size_t i) { return EssentialMatrix::Homogeneous(pB(i)); } //************************************************************************* TEST(EssentialMatrixFactor, testData) { - CHECK(readOK); - // Check E matrix Matrix expected(3, 3); expected << 0, 0, 0, 0, 0, -0.1, 0.1, 0, 0; @@ -538,8 +536,7 @@ TEST(EssentialMatrixFactor4, minimizationWithStrongCal3BundlerPrior) { namespace example2 { const string filename = findExampleDataFile("5pointExample2.txt"); -SfmData data; -bool readOK = readBAL(filename, data); +SfmData data = SfmData::FromBalFile(filename); Rot3 aRb = data.cameras[1].pose().rotation(); Point3 aTb = data.cameras[1].pose().translation(); EssentialMatrix trueE(aRb, Unit3(aTb)); @@ -632,14 +629,14 @@ TEST(EssentialMatrixFactor2, extraMinimization) { // We start with a factor graph and add constraints to it // Noise sigma is 1, assuming pixel measurements NonlinearFactorGraph graph; - for (size_t i = 0; i < data.number_tracks(); i++) + for (size_t i = 0; i < data.numberTracks(); i++) graph.emplace_shared(100, i, pA(i), pB(i), model2, K); // Check error at ground truth Values truth; truth.insert(100, trueE); - for (size_t i = 0; i < data.number_tracks(); i++) { + for (size_t i = 0; i < data.numberTracks(); i++) { Point3 P1 = data.tracks[i].p; truth.insert(i, double(baseline / P1.z())); } @@ -654,7 +651,7 @@ TEST(EssentialMatrixFactor2, extraMinimization) { // Check result EssentialMatrix actual = result.at(100); EXPECT(assert_equal(trueE, actual, 1e-1)); - for (size_t i = 0; i < data.number_tracks(); i++) + for (size_t i = 0; i < data.numberTracks(); i++) EXPECT_DOUBLES_EQUAL(truth.at(i), result.at(i), 1e-1); // Check error at result diff --git a/gtsam/slam/tests/testSerializationDataset.cpp b/gtsam/slam/tests/testSerializationDataset.cpp index 6ef82f07f..dcac3d47e 100644 --- a/gtsam/slam/tests/testSerializationDataset.cpp +++ b/gtsam/slam/tests/testSerializationDataset.cpp @@ -16,6 +16,7 @@ * @date Jan 1, 2021 */ +#include #include #include @@ -29,8 +30,7 @@ using namespace gtsam::serializationTestHelpers; TEST(dataSet, sfmDataSerialization) { // Test the serialization of SfmData const string filename = findExampleDataFile("dubrovnik-3-7-pre"); - SfmData mydata; - CHECK(readBAL(filename, mydata)); + SfmData mydata = SfmData::FromBalFile(filename); // round-trip equality check on serialization and subsequent deserialization EXPECT(equalsObj(mydata)); @@ -42,8 +42,7 @@ TEST(dataSet, sfmDataSerialization) { TEST(dataSet, sfmTrackSerialization) { // Test the serialization of SfmTrack const string filename = findExampleDataFile("dubrovnik-3-7-pre"); - SfmData mydata; - CHECK(readBAL(filename, mydata)); + SfmData mydata = SfmData::FromBalFile(filename); SfmTrack track = mydata.track(0); diff --git a/gtsam/slam/tests/testSerializationInSlam.cpp b/gtsam/slam/tests/testSerializationInSlam.cpp new file mode 100644 index 000000000..6aec8ecb0 --- /dev/null +++ b/gtsam/slam/tests/testSerializationInSlam.cpp @@ -0,0 +1,105 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file testSerializationSlam.cpp + * @brief all serialization tests in this directory + * @author Frank Dellaert + * @date February 2022 + */ + +#include "smartFactorScenarios.h" +#include "PinholeFactor.h" + +#include +#include +#include + +#include + +#include +#include + +namespace { +static const double rankTol = 1.0; +static const double sigma = 0.1; +static SharedIsotropic model(noiseModel::Isotropic::Sigma(2, sigma)); +} // namespace + +/* ************************************************************************* */ +BOOST_CLASS_EXPORT_GUID(noiseModel::Constrained, "gtsam_noiseModel_Constrained") +BOOST_CLASS_EXPORT_GUID(noiseModel::Diagonal, "gtsam_noiseModel_Diagonal") +BOOST_CLASS_EXPORT_GUID(noiseModel::Gaussian, "gtsam_noiseModel_Gaussian") +BOOST_CLASS_EXPORT_GUID(noiseModel::Unit, "gtsam_noiseModel_Unit") +BOOST_CLASS_EXPORT_GUID(noiseModel::Isotropic, "gtsam_noiseModel_Isotropic") +BOOST_CLASS_EXPORT_GUID(SharedNoiseModel, "gtsam_SharedNoiseModel") +BOOST_CLASS_EXPORT_GUID(SharedDiagonal, "gtsam_SharedDiagonal") + +/* ************************************************************************* */ +TEST(SmartFactorBase, serialize) { + using namespace serializationTestHelpers; + PinholeFactor factor(unit2); + + EXPECT(equalsObj(factor)); + EXPECT(equalsXML(factor)); + EXPECT(equalsBinary(factor)); +} + +/* ************************************************************************* */ +TEST(SerializationSlam, SmartProjectionFactor) { + using namespace vanilla; + using namespace serializationTestHelpers; + SmartFactor factor(unit2); + + EXPECT(equalsObj(factor)); + EXPECT(equalsXML(factor)); + EXPECT(equalsBinary(factor)); +} + +/* ************************************************************************* */ +TEST(SerializationSlam, SmartProjectionPoseFactor) { + using namespace vanillaPose; + using namespace serializationTestHelpers; + SmartProjectionParams params; + params.setRankTolerance(rankTol); + SmartFactor factor(model, sharedK, params); + + EXPECT(equalsObj(factor)); + EXPECT(equalsXML(factor)); + EXPECT(equalsBinary(factor)); +} + +TEST(SerializationSlam, SmartProjectionPoseFactor2) { + using namespace vanillaPose; + using namespace serializationTestHelpers; + SmartProjectionParams params; + params.setRankTolerance(rankTol); + Pose3 bts; + SmartFactor factor(model, sharedK, bts, params); + + // insert some measurments + KeyVector key_view; + Point2Vector meas_view; + key_view.push_back(Symbol('x', 1)); + meas_view.push_back(Point2(10, 10)); + factor.add(meas_view, key_view); + + EXPECT(equalsObj(factor)); + EXPECT(equalsXML(factor)); + EXPECT(equalsBinary(factor)); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/gtsam/slam/tests/testSlamExpressions.cpp b/gtsam/slam/tests/testSlamExpressions.cpp index 294b821d3..b5298989f 100644 --- a/gtsam/slam/tests/testSlamExpressions.cpp +++ b/gtsam/slam/tests/testSlamExpressions.cpp @@ -58,6 +58,13 @@ TEST(SlamExpressions, unrotate) { const Point3_ q_ = unrotate(R_, p_); } +/* ************************************************************************* */ +TEST(SlamExpressions, logmap) { + Pose3_ T1_(0); + Pose3_ T2_(1); + const Vector6_ l = logmap(T1_, T2_); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/slam/tests/testSmartFactorBase.cpp b/gtsam/slam/tests/testSmartFactorBase.cpp index 951cbf8f4..544fd3264 100644 --- a/gtsam/slam/tests/testSmartFactorBase.cpp +++ b/gtsam/slam/tests/testSmartFactorBase.cpp @@ -16,47 +16,29 @@ * @date Feb 2015 */ -#include -#include -#include #include +#include +#include +#include +#include +#include -using namespace std; +#include "PinholeFactor.h" + +namespace { using namespace gtsam; - static SharedNoiseModel unit2(noiseModel::Unit::Create(2)); static SharedNoiseModel unit3(noiseModel::Unit::Create(3)); +} // namespace -/* ************************************************************************* */ -#include -#include +using namespace std; namespace gtsam { -class PinholeFactor: public SmartFactorBase > { -public: - typedef SmartFactorBase > Base; - PinholeFactor() {} - PinholeFactor(const SharedNoiseModel& sharedNoiseModel, - boost::optional body_P_sensor = boost::none, - size_t expectedNumberCameras = 10) - : Base(sharedNoiseModel, body_P_sensor, expectedNumberCameras) {} - double error(const Values& values) const override { return 0.0; } - boost::shared_ptr linearize( - const Values& values) const override { - return boost::shared_ptr(new JacobianFactor()); - } -}; - -/// traits -template<> -struct traits : public Testable {}; -} - TEST(SmartFactorBase, Pinhole) { - PinholeFactor f= PinholeFactor(unit2); - f.add(Point2(0,0), 1); - f.add(Point2(0,0), 2); + PinholeFactor f = PinholeFactor(unit2); + f.add(Point2(0, 0), 1); + f.add(Point2(0, 0), 2); EXPECT_LONGS_EQUAL(2 * 2, f.dim()); } @@ -71,7 +53,7 @@ TEST(SmartFactorBase, PinholeWithSensor) { // Camera coordinates in world frame. Pose3 wTc = world_P_body * body_P_sensor; cameras.push_back(PinholeCamera(wTc)); - + // Simple point to project slightly off image center Point3 p(0, 0, 10); Point2 measurement = cameras[0].project(p); @@ -81,9 +63,10 @@ TEST(SmartFactorBase, PinholeWithSensor) { Matrix E; Vector error = f.unwhitenedError(cameras, p, Fs, E); - Vector expectedError = Vector::Zero(2); + Vector expectedError = Vector::Zero(2); Matrix29 expectedFs; - expectedFs << -0.001, -1.00001, 0, -0.1, 0, -0.01, 0, 0, 0, 1, 0, 0, 0, -0.1, 0, 0, 0, 0; + expectedFs << -0.001, -1.00001, 0, -0.1, 0, -0.01, 0, 0, 0, 1, 0, 0, 0, -0.1, + 0, 0, 0, 0; Matrix23 expectedE; expectedE << 0.1, 0, 0.01, 0, 0.1, 0; @@ -94,20 +77,13 @@ TEST(SmartFactorBase, PinholeWithSensor) { EXPECT(assert_equal(expectedE, E)); } -/* ************************************************************************* */ -#include - -namespace gtsam { - -class StereoFactor: public SmartFactorBase { -public: +class StereoFactor : public SmartFactorBase { + public: typedef SmartFactorBase Base; StereoFactor() {} - StereoFactor(const SharedNoiseModel& sharedNoiseModel): Base(sharedNoiseModel) { - } - double error(const Values& values) const override { - return 0.0; - } + StereoFactor(const SharedNoiseModel& sharedNoiseModel) + : Base(sharedNoiseModel) {} + double error(const Values& values) const override { return 0.0; } boost::shared_ptr linearize( const Values& values) const override { return boost::shared_ptr(new JacobianFactor()); @@ -115,9 +91,8 @@ public: }; /// traits -template<> +template <> struct traits : public Testable {}; -} TEST(SmartFactorBase, Stereo) { StereoFactor f(unit3); @@ -125,24 +100,7 @@ TEST(SmartFactorBase, Stereo) { f.add(StereoPoint2(), 2); EXPECT_LONGS_EQUAL(2 * 3, f.dim()); } - -/* ************************************************************************* */ -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Constrained, "gtsam_noiseModel_Constrained"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Diagonal, "gtsam_noiseModel_Diagonal"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Gaussian, "gtsam_noiseModel_Gaussian"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Unit, "gtsam_noiseModel_Unit"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Isotropic, "gtsam_noiseModel_Isotropic"); -BOOST_CLASS_EXPORT_GUID(gtsam::SharedNoiseModel, "gtsam_SharedNoiseModel"); -BOOST_CLASS_EXPORT_GUID(gtsam::SharedDiagonal, "gtsam_SharedDiagonal"); - -TEST(SmartFactorBase, serialize) { - using namespace gtsam::serializationTestHelpers; - PinholeFactor factor(unit2); - - EXPECT(equalsObj(factor)); - EXPECT(equalsXML(factor)); - EXPECT(equalsBinary(factor)); -} +} // namespace gtsam /* ************************************************************************* */ int main() { @@ -150,4 +108,3 @@ int main() { return TestRegistry::runAllTests(tr); } /* ************************************************************************* */ - diff --git a/gtsam/slam/tests/testSmartProjectionFactor.cpp b/gtsam/slam/tests/testSmartProjectionFactor.cpp index 1fd06cc9f..ecdb5287f 100644 --- a/gtsam/slam/tests/testSmartProjectionFactor.cpp +++ b/gtsam/slam/tests/testSmartProjectionFactor.cpp @@ -22,18 +22,19 @@ #include "smartFactorScenarios.h" #include #include -#include #include #include #include using namespace boost::assign; +namespace { static const bool isDebugTest = false; static const Symbol l1('l', 1), l2('l', 2), l3('l', 3); static const Key c1 = 1, c2 = 2, c3 = 3; static const Point2 measurement1(323.0, 240.0); static const double rankTol = 1.0; +} template PinholeCamera perturbCameraPoseAndCalibration( @@ -70,8 +71,9 @@ TEST(SmartProjectionFactor, Constructor) { /* ************************************************************************* */ TEST(SmartProjectionFactor, Constructor2) { using namespace vanilla; - params.setRankTolerance(rankTol); - SmartFactor factor1(unit2, params); + auto myParams = params; + myParams.setRankTolerance(rankTol); + SmartFactor factor1(unit2, myParams); } /* ************************************************************************* */ @@ -84,8 +86,9 @@ TEST(SmartProjectionFactor, Constructor3) { /* ************************************************************************* */ TEST(SmartProjectionFactor, Constructor4) { using namespace vanilla; - params.setRankTolerance(rankTol); - SmartFactor factor1(unit2, params); + auto myParams = params; + myParams.setRankTolerance(rankTol); + SmartFactor factor1(unit2, myParams); factor1.add(measurement1, c1); } @@ -810,25 +813,6 @@ TEST(SmartProjectionFactor, implicitJacobianFactor ) { EXPECT(assert_equal(yActual, yExpected, 1e-7)); } -/* ************************************************************************* */ -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Constrained, "gtsam_noiseModel_Constrained"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Diagonal, "gtsam_noiseModel_Diagonal"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Gaussian, "gtsam_noiseModel_Gaussian"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Unit, "gtsam_noiseModel_Unit"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Isotropic, "gtsam_noiseModel_Isotropic"); -BOOST_CLASS_EXPORT_GUID(gtsam::SharedNoiseModel, "gtsam_SharedNoiseModel"); -BOOST_CLASS_EXPORT_GUID(gtsam::SharedDiagonal, "gtsam_SharedDiagonal"); - -TEST(SmartProjectionFactor, serialize) { - using namespace vanilla; - using namespace gtsam::serializationTestHelpers; - SmartFactor factor(unit2); - - EXPECT(equalsObj(factor)); - EXPECT(equalsXML(factor)); - EXPECT(equalsBinary(factor)); -} - /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/slam/tests/testSmartProjectionPoseFactor.cpp b/gtsam/slam/tests/testSmartProjectionPoseFactor.cpp index 997c33846..5c38233c1 100644 --- a/gtsam/slam/tests/testSmartProjectionPoseFactor.cpp +++ b/gtsam/slam/tests/testSmartProjectionPoseFactor.cpp @@ -24,7 +24,6 @@ #include #include #include -#include #include #include #include @@ -32,6 +31,7 @@ using namespace boost::assign; using namespace std::placeholders; +namespace { static const double rankTol = 1.0; // Create a noise model for the pixel error static const double sigma = 0.1; @@ -51,6 +51,7 @@ static Point2 measurement1(323.0, 240.0); LevenbergMarquardtParams lmParams; // Make more verbose like so (in tests): // lmParams.verbosityLM = LevenbergMarquardtParams::SUMMARY; +} /* ************************************************************************* */ TEST( SmartProjectionPoseFactor, Constructor) { @@ -1332,47 +1333,6 @@ TEST( SmartProjectionPoseFactor, Cal3BundlerRotationOnly ) { values.at(x3))); } -/* ************************************************************************* */ -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Constrained, "gtsam_noiseModel_Constrained"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Diagonal, "gtsam_noiseModel_Diagonal"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Gaussian, "gtsam_noiseModel_Gaussian"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Unit, "gtsam_noiseModel_Unit"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Isotropic, "gtsam_noiseModel_Isotropic"); -BOOST_CLASS_EXPORT_GUID(gtsam::SharedNoiseModel, "gtsam_SharedNoiseModel"); -BOOST_CLASS_EXPORT_GUID(gtsam::SharedDiagonal, "gtsam_SharedDiagonal"); - -TEST(SmartProjectionPoseFactor, serialize) { - using namespace vanillaPose; - using namespace gtsam::serializationTestHelpers; - SmartProjectionParams params; - params.setRankTolerance(rankTol); - SmartFactor factor(model, sharedK, params); - - EXPECT(equalsObj(factor)); - EXPECT(equalsXML(factor)); - EXPECT(equalsBinary(factor)); -} - -TEST(SmartProjectionPoseFactor, serialize2) { - using namespace vanillaPose; - using namespace gtsam::serializationTestHelpers; - SmartProjectionParams params; - params.setRankTolerance(rankTol); - Pose3 bts; - SmartFactor factor(model, sharedK, bts, params); - - // insert some measurments - KeyVector key_view; - Point2Vector meas_view; - key_view.push_back(Symbol('x', 1)); - meas_view.push_back(Point2(10, 10)); - factor.add(meas_view, key_view); - - EXPECT(equalsObj(factor)); - EXPECT(equalsXML(factor)); - EXPECT(equalsBinary(factor)); -} - /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/slam/tests/testSmartProjectionRigFactor.cpp b/gtsam/slam/tests/testSmartProjectionRigFactor.cpp new file mode 100644 index 000000000..b4876b27e --- /dev/null +++ b/gtsam/slam/tests/testSmartProjectionRigFactor.cpp @@ -0,0 +1,1603 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file testSmartProjectionRigFactor.cpp + * @brief Unit tests for SmartProjectionRigFactor Class + * @author Chris Beall + * @author Luca Carlone + * @author Zsolt Kira + * @author Frank Dellaert + * @date August 2021 + */ + +#include +#include +#include +#include +#include + +#include +#include + +#include "smartFactorScenarios.h" +#define DISABLE_TIMING + +using namespace boost::assign; +using namespace std::placeholders; + +static const double rankTol = 1.0; +// Create a noise model for the pixel error +static const double sigma = 0.1; +static SharedIsotropic model(noiseModel::Isotropic::Sigma(2, sigma)); + +// Convenience for named keys +using symbol_shorthand::L; +using symbol_shorthand::X; + +// tests data +static Symbol x1('X', 1); +static Symbol x2('X', 2); +static Symbol x3('X', 3); + +Key cameraId1 = 0; // first camera +Key cameraId2 = 1; +Key cameraId3 = 2; + +static Point2 measurement1(323.0, 240.0); + +LevenbergMarquardtParams lmParams; + +/* ************************************************************************* */ +// default Cal3_S2 poses with rolling shutter effect +namespace vanillaRig { +using namespace vanillaPose; +SmartProjectionParams params( + gtsam::HESSIAN, + gtsam::ZERO_ON_DEGENERACY); // only config that works with rig factors +} // namespace vanillaRig + +/* ************************************************************************* */ +TEST(SmartProjectionRigFactor, Constructor) { + using namespace vanillaRig; + boost::shared_ptr cameraRig(new Cameras()); + cameraRig->push_back(Camera(Pose3::identity(), sharedK)); + SmartRigFactor::shared_ptr factor1( + new SmartRigFactor(model, cameraRig, params)); +} + +/* ************************************************************************* */ +TEST(SmartProjectionRigFactor, Constructor2) { + using namespace vanillaRig; + boost::shared_ptr cameraRig(new Cameras()); + SmartProjectionParams params2( + gtsam::HESSIAN, + gtsam::ZERO_ON_DEGENERACY); // only config that works with rig factors + params2.setRankTolerance(rankTol); + SmartRigFactor factor1(model, cameraRig, params2); +} + +/* ************************************************************************* */ +TEST(SmartProjectionRigFactor, Constructor3) { + using namespace vanillaRig; + boost::shared_ptr cameraRig(new Cameras()); + cameraRig->push_back(Camera(Pose3::identity(), sharedK)); + SmartRigFactor::shared_ptr factor1( + new SmartRigFactor(model, cameraRig, params)); + factor1->add(measurement1, x1, cameraId1); +} + +/* ************************************************************************* */ +TEST(SmartProjectionRigFactor, Constructor4) { + using namespace vanillaRig; + boost::shared_ptr cameraRig(new Cameras()); + cameraRig->push_back(Camera(Pose3::identity(), sharedK)); + SmartProjectionParams params2( + gtsam::HESSIAN, + gtsam::ZERO_ON_DEGENERACY); // only config that works with rig factors + params2.setRankTolerance(rankTol); + SmartRigFactor factor1(model, cameraRig, params2); + factor1.add(measurement1, x1, cameraId1); +} + +/* ************************************************************************* */ +TEST(SmartProjectionRigFactor, Equals) { + using namespace vanillaRig; + boost::shared_ptr cameraRig(new Cameras()); // single camera in the rig + cameraRig->push_back(Camera(Pose3::identity(), sharedK)); + + SmartRigFactor::shared_ptr factor1( + new SmartRigFactor(model, cameraRig, params)); + factor1->add(measurement1, x1, cameraId1); + + SmartRigFactor::shared_ptr factor2( + new SmartRigFactor(model, cameraRig, params)); + factor2->add(measurement1, x1, cameraId1); + + CHECK(assert_equal(*factor1, *factor2)); + + SmartRigFactor::shared_ptr factor3( + new SmartRigFactor(model, cameraRig, params)); + factor3->add(measurement1, x1); // now use default camera ID + + CHECK(assert_equal(*factor1, *factor3)); +} + +/* *************************************************************************/ +TEST(SmartProjectionRigFactor, noiseless) { + using namespace vanillaRig; + + // Project two landmarks into two cameras + Point2 level_uv = level_camera.project(landmark1); + Point2 level_uv_right = level_camera_right.project(landmark1); + + boost::shared_ptr cameraRig(new Cameras()); // single camera in the rig + cameraRig->push_back(Camera(Pose3::identity(), sharedK)); + + SmartRigFactor factor(model, cameraRig, params); + factor.add(level_uv, x1); // both taken from the same camera + factor.add(level_uv_right, x2); + + Values values; // it's a pose factor, hence these are poses + values.insert(x1, cam1.pose()); + values.insert(x2, cam2.pose()); + + double actualError = factor.error(values); + double expectedError = 0.0; + EXPECT_DOUBLES_EQUAL(expectedError, actualError, 1e-7); + + SmartRigFactor::Cameras cameras = factor.cameras(values); + double actualError2 = factor.totalReprojectionError(cameras); + EXPECT_DOUBLES_EQUAL(expectedError, actualError2, 1e-7); + + // Calculate expected derivative for point (easiest to check) + std::function f = // + std::bind(&SmartRigFactor::whitenedError, factor, cameras, + std::placeholders::_1); + + // Calculate using computeEP + Matrix actualE; + factor.triangulateAndComputeE(actualE, values); + + // get point + boost::optional point = factor.point(); + CHECK(point); + + // calculate numerical derivative with triangulated point + Matrix expectedE = sigma * numericalDerivative11(f, *point); + EXPECT(assert_equal(expectedE, actualE, 1e-7)); + + // Calculate using reprojectionError + SmartRigFactor::Cameras::FBlocks F; + Matrix E; + Vector actualErrors = factor.unwhitenedError(cameras, *point, F, E); + EXPECT(assert_equal(expectedE, E, 1e-7)); + + EXPECT(assert_equal(Z_4x1, actualErrors, 1e-7)); + + // Calculate using computeJacobians + Vector b; + SmartRigFactor::FBlocks Fs; + factor.computeJacobians(Fs, E, b, cameras, *point); + double actualError3 = b.squaredNorm(); + EXPECT(assert_equal(expectedE, E, 1e-7)); + EXPECT_DOUBLES_EQUAL(expectedError, actualError3, 1e-6); +} + +/* *************************************************************************/ +TEST(SmartProjectionRigFactor, noisy) { + using namespace vanillaRig; + + boost::shared_ptr cameraRig(new Cameras()); // single camera in the rig + cameraRig->push_back(Camera(Pose3::identity(), sharedK)); + + // Project two landmarks into two cameras + Point2 pixelError(0.2, 0.2); + Point2 level_uv = level_camera.project(landmark1) + pixelError; + Point2 level_uv_right = level_camera_right.project(landmark1); + + Values values; + values.insert(x1, cam1.pose()); + Pose3 noise_pose = + Pose3(Rot3::Ypr(-M_PI / 10, 0., -M_PI / 10), Point3(0.5, 0.1, 0.3)); + values.insert(x2, pose_right.compose(noise_pose)); + + SmartRigFactor::shared_ptr factor( + new SmartRigFactor(model, cameraRig, params)); + factor->add(level_uv, x1, cameraId1); + factor->add(level_uv_right, x2, cameraId1); + + double actualError1 = factor->error(values); + + // create other factor by passing multiple measurements + SmartRigFactor::shared_ptr factor2( + new SmartRigFactor(model, cameraRig, params)); + + Point2Vector measurements; + measurements.push_back(level_uv); + measurements.push_back(level_uv_right); + + KeyVector views{x1, x2}; + FastVector cameraIds{0, 0}; + + factor2->add(measurements, views, cameraIds); + double actualError2 = factor2->error(values); + DOUBLES_EQUAL(actualError1, actualError2, 1e-7); +} + +/* *************************************************************************/ +TEST(SmartProjectionRigFactor, smartFactorWithSensorBodyTransform) { + using namespace vanillaRig; + + // create arbitrary body_T_sensor (transforms from sensor to body) + Pose3 body_T_sensor = + Pose3(Rot3::Ypr(-M_PI / 2, 0., -M_PI / 2), Point3(1, 1, 1)); + boost::shared_ptr cameraRig(new Cameras()); // single camera in the rig + cameraRig->push_back(Camera(body_T_sensor, sharedK)); + + // These are the poses we want to estimate, from camera measurements + const Pose3 sensor_T_body = body_T_sensor.inverse(); + Pose3 wTb1 = cam1.pose() * sensor_T_body; + Pose3 wTb2 = cam2.pose() * sensor_T_body; + Pose3 wTb3 = cam3.pose() * sensor_T_body; + + // three landmarks ~5 meters infront of camera + Point3 landmark1(5, 0.5, 1.2), landmark2(5, -0.5, 1.2), landmark3(5, 0, 3.0); + + Point2Vector measurements_cam1, measurements_cam2, measurements_cam3; + + // Project three landmarks into three cameras + projectToMultipleCameras(cam1, cam2, cam3, landmark1, measurements_cam1); + projectToMultipleCameras(cam1, cam2, cam3, landmark2, measurements_cam2); + projectToMultipleCameras(cam1, cam2, cam3, landmark3, measurements_cam3); + + // Create smart factors + KeyVector views{x1, x2, x3}; + FastVector cameraIds{0, 0, 0}; + + SmartProjectionParams params; + params.setRankTolerance(1.0); + params.setDegeneracyMode(ZERO_ON_DEGENERACY); + params.setEnableEPI(false); + + SmartRigFactor smartFactor1(model, cameraRig, params); + smartFactor1.add( + measurements_cam1, + views); // use default CameraIds since we have a single camera + + SmartRigFactor smartFactor2(model, cameraRig, params); + smartFactor2.add(measurements_cam2, views); + + SmartRigFactor smartFactor3(model, cameraRig, params); + smartFactor3.add(measurements_cam3, views); + + const SharedDiagonal noisePrior = noiseModel::Isotropic::Sigma(6, 0.10); + + // Put all factors in factor graph, adding priors + NonlinearFactorGraph graph; + graph.push_back(smartFactor1); + graph.push_back(smartFactor2); + graph.push_back(smartFactor3); + graph.addPrior(x1, wTb1, noisePrior); + graph.addPrior(x2, wTb2, noisePrior); + + // Check errors at ground truth poses + Values gtValues; + gtValues.insert(x1, wTb1); + gtValues.insert(x2, wTb2); + gtValues.insert(x3, wTb3); + double actualError = graph.error(gtValues); + double expectedError = 0.0; + DOUBLES_EQUAL(expectedError, actualError, 1e-7) + + Pose3 noise_pose = + Pose3(Rot3::Ypr(-M_PI / 100, 0., -M_PI / 100), Point3(0.1, 0.1, 0.1)); + Values values; + values.insert(x1, wTb1); + values.insert(x2, wTb2); + // initialize third pose with some noise, we expect it to move back to + // original pose3 + values.insert(x3, wTb3 * noise_pose); + + LevenbergMarquardtParams lmParams; + Values result; + LevenbergMarquardtOptimizer optimizer(graph, values, lmParams); + result = optimizer.optimize(); + EXPECT(assert_equal(wTb3, result.at(x3))); +} + +/* *************************************************************************/ +TEST(SmartProjectionRigFactor, smartFactorWithMultipleCameras) { + using namespace vanillaRig; + + // create arbitrary body_T_sensor (transforms from sensor to body) + Pose3 body_T_sensor1 = + Pose3(Rot3::Ypr(-M_PI / 2, 0., -M_PI / 2), Point3(1, 1, 1)); + Pose3 body_T_sensor2 = + Pose3(Rot3::Ypr(-M_PI / 5, 0., -M_PI / 2), Point3(0, 0, 1)); + Pose3 body_T_sensor3 = Pose3::identity(); + + boost::shared_ptr cameraRig(new Cameras()); // single camera in the rig + cameraRig->push_back(Camera(body_T_sensor1, sharedK)); + cameraRig->push_back(Camera(body_T_sensor2, sharedK)); + cameraRig->push_back(Camera(body_T_sensor3, sharedK)); + + // These are the poses we want to estimate, from camera measurements + const Pose3 sensor_T_body1 = body_T_sensor1.inverse(); + const Pose3 sensor_T_body2 = body_T_sensor2.inverse(); + const Pose3 sensor_T_body3 = body_T_sensor3.inverse(); + Pose3 wTb1 = cam1.pose() * sensor_T_body1; + Pose3 wTb2 = cam2.pose() * sensor_T_body2; + Pose3 wTb3 = cam3.pose() * sensor_T_body3; + + // three landmarks ~5 meters infront of camera + Point3 landmark1(5, 0.5, 1.2), landmark2(5, -0.5, 1.2), landmark3(5, 0, 3.0); + + Point2Vector measurements_cam1, measurements_cam2, measurements_cam3; + + // Project three landmarks into three cameras + projectToMultipleCameras(cam1, cam2, cam3, landmark1, measurements_cam1); + projectToMultipleCameras(cam1, cam2, cam3, landmark2, measurements_cam2); + projectToMultipleCameras(cam1, cam2, cam3, landmark3, measurements_cam3); + + // Create smart factors + KeyVector views{x1, x2, x3}; + FastVector cameraIds{0, 1, 2}; + + SmartProjectionParams params; + params.setRankTolerance(1.0); + params.setDegeneracyMode(ZERO_ON_DEGENERACY); + params.setEnableEPI(false); + + SmartRigFactor smartFactor1(model, cameraRig, params); + smartFactor1.add(measurements_cam1, views, cameraIds); + + SmartRigFactor smartFactor2(model, cameraRig, params); + smartFactor2.add(measurements_cam2, views, cameraIds); + + SmartRigFactor smartFactor3(model, cameraRig, params); + smartFactor3.add(measurements_cam3, views, cameraIds); + + const SharedDiagonal noisePrior = noiseModel::Isotropic::Sigma(6, 0.10); + + // Put all factors in factor graph, adding priors + NonlinearFactorGraph graph; + graph.push_back(smartFactor1); + graph.push_back(smartFactor2); + graph.push_back(smartFactor3); + graph.addPrior(x1, wTb1, noisePrior); + graph.addPrior(x2, wTb2, noisePrior); + + // Check errors at ground truth poses + Values gtValues; + gtValues.insert(x1, wTb1); + gtValues.insert(x2, wTb2); + gtValues.insert(x3, wTb3); + double actualError = graph.error(gtValues); + double expectedError = 0.0; + DOUBLES_EQUAL(expectedError, actualError, 1e-7) + + Pose3 noise_pose = + Pose3(Rot3::Ypr(-M_PI / 100, 0., -M_PI / 100), Point3(0.1, 0.1, 0.1)); + Values values; + values.insert(x1, wTb1); + values.insert(x2, wTb2); + // initialize third pose with some noise, we expect it to move back to + // original pose3 + values.insert(x3, wTb3 * noise_pose); + + LevenbergMarquardtParams lmParams; + Values result; + LevenbergMarquardtOptimizer optimizer(graph, values, lmParams); + result = optimizer.optimize(); + EXPECT(assert_equal(wTb3, result.at(x3))); +} + +/* *************************************************************************/ +TEST(SmartProjectionRigFactor, 3poses_smart_projection_factor) { + using namespace vanillaPose2; + Point2Vector measurements_cam1, measurements_cam2, measurements_cam3; + + boost::shared_ptr cameraRig(new Cameras()); // single camera in the rig + cameraRig->push_back(Camera(Pose3::identity(), sharedK2)); + + // Project three landmarks into three cameras + projectToMultipleCameras(cam1, cam2, cam3, landmark1, measurements_cam1); + projectToMultipleCameras(cam1, cam2, cam3, landmark2, measurements_cam2); + projectToMultipleCameras(cam1, cam2, cam3, landmark3, measurements_cam3); + + KeyVector views{x1, x2, x3}; + FastVector cameraIds{ + 0, 0, 0}; // 3 measurements from the same camera in the rig + + SmartProjectionParams params( + gtsam::HESSIAN, + gtsam::ZERO_ON_DEGENERACY); // only config that works with rig factors + + SmartRigFactor::shared_ptr smartFactor1( + new SmartRigFactor(model, cameraRig, params)); + smartFactor1->add(measurements_cam1, views, cameraIds); + + SmartRigFactor::shared_ptr smartFactor2( + new SmartRigFactor(model, cameraRig, params)); + smartFactor2->add(measurements_cam2, views, cameraIds); + + SmartRigFactor::shared_ptr smartFactor3( + new SmartRigFactor(model, cameraRig, params)); + smartFactor3->add(measurements_cam3, views, cameraIds); + + const SharedDiagonal noisePrior = noiseModel::Isotropic::Sigma(6, 0.10); + + NonlinearFactorGraph graph; + graph.push_back(smartFactor1); + graph.push_back(smartFactor2); + graph.push_back(smartFactor3); + graph.addPrior(x1, cam1.pose(), noisePrior); + graph.addPrior(x2, cam2.pose(), noisePrior); + + Values groundTruth; + groundTruth.insert(x1, cam1.pose()); + groundTruth.insert(x2, cam2.pose()); + groundTruth.insert(x3, cam3.pose()); + DOUBLES_EQUAL(0, graph.error(groundTruth), 1e-9); + + // Pose3 noise_pose = Pose3(Rot3::Ypr(-M_PI/10, 0., -M_PI/10), + // Point3(0.5,0.1,0.3)); // noise from regular projection factor test below + Pose3 noise_pose = Pose3(Rot3::Ypr(-M_PI / 100, 0., -M_PI / 100), + Point3(0.1, 0.1, 0.1)); // smaller noise + Values values; + values.insert(x1, cam1.pose()); + values.insert(x2, cam2.pose()); + // initialize third pose with some noise, we expect it to move back to + // original pose_above + values.insert(x3, pose_above * noise_pose); + EXPECT(assert_equal( + Pose3(Rot3(0, -0.0314107591, 0.99950656, -0.99950656, -0.0313952598, + -0.000986635786, 0.0314107591, -0.999013364, -0.0313952598), + Point3(0.1, -0.1, 1.9)), + values.at(x3))); + + Values result; + LevenbergMarquardtOptimizer optimizer(graph, values, lmParams); + result = optimizer.optimize(); + EXPECT(assert_equal(pose_above, result.at(x3), 1e-6)); +} + +/* *************************************************************************/ +TEST(SmartProjectionRigFactor, Factors) { + using namespace vanillaRig; + + // Default cameras for simple derivatives + Rot3 R; + static Cal3_S2::shared_ptr sharedK(new Cal3_S2(100, 100, 0, 0, 0)); + Camera cam1(Pose3(R, Point3(0, 0, 0)), sharedK), + cam2(Pose3(R, Point3(1, 0, 0)), sharedK); + + // one landmarks 1m in front of camera + Point3 landmark1(0, 0, 10); + + Point2Vector measurements_cam1; + + // Project 2 landmarks into 2 cameras + measurements_cam1.push_back(cam1.project(landmark1)); + measurements_cam1.push_back(cam2.project(landmark1)); + + // Create smart factors + KeyVector views{x1, x2}; + FastVector cameraIds{0, 0}; + + boost::shared_ptr cameraRig(new Cameras()); // single camera in the rig + cameraRig->push_back(Camera(Pose3::identity(), sharedK)); + + SmartRigFactor::shared_ptr smartFactor1 = boost::make_shared( + model, cameraRig, params); + smartFactor1->add(measurements_cam1, + views); // we have a single camera so use default cameraIds + + SmartRigFactor::Cameras cameras; + cameras.push_back(cam1); + cameras.push_back(cam2); + + // Make sure triangulation works + CHECK(smartFactor1->triangulateSafe(cameras)); + CHECK(!smartFactor1->isDegenerate()); + CHECK(!smartFactor1->isPointBehindCamera()); + boost::optional p = smartFactor1->point(); + CHECK(p); + EXPECT(assert_equal(landmark1, *p)); + + VectorValues zeroDelta; + Vector6 delta; + delta.setZero(); + zeroDelta.insert(x1, delta); + zeroDelta.insert(x2, delta); + + VectorValues perturbedDelta; + delta.setOnes(); + perturbedDelta.insert(x1, delta); + perturbedDelta.insert(x2, delta); + double expectedError = 2500; + + // After eliminating the point, A1 and A2 contain 2-rank information on + // cameras: + Matrix16 A1, A2; + A1 << -10, 0, 0, 0, 1, 0; + A2 << 10, 0, 1, 0, -1, 0; + A1 *= 10. / sigma; + A2 *= 10. / sigma; + Matrix expectedInformation; // filled below + { + // createHessianFactor + Matrix66 G11 = 0.5 * A1.transpose() * A1; + Matrix66 G12 = 0.5 * A1.transpose() * A2; + Matrix66 G22 = 0.5 * A2.transpose() * A2; + + Vector6 g1; + g1.setZero(); + Vector6 g2; + g2.setZero(); + + double f = 0; + + RegularHessianFactor<6> expected(x1, x2, G11, G12, g1, G22, g2, f); + expectedInformation = expected.information(); + + Values values; + values.insert(x1, cam1.pose()); + values.insert(x2, cam2.pose()); + + boost::shared_ptr> actual = + smartFactor1->createHessianFactor(values, 0.0); + EXPECT(assert_equal(expectedInformation, actual->information(), 1e-6)); + EXPECT(assert_equal(expected, *actual, 1e-6)); + EXPECT_DOUBLES_EQUAL(0, actual->error(zeroDelta), 1e-6); + EXPECT_DOUBLES_EQUAL(expectedError, actual->error(perturbedDelta), 1e-6); + } +} + +/* *************************************************************************/ +TEST(SmartProjectionRigFactor, 3poses_iterative_smart_projection_factor) { + using namespace vanillaRig; + + KeyVector views{x1, x2, x3}; + + Point2Vector measurements_cam1, measurements_cam2, measurements_cam3; + + // Project three landmarks into three cameras + projectToMultipleCameras(cam1, cam2, cam3, landmark1, measurements_cam1); + projectToMultipleCameras(cam1, cam2, cam3, landmark2, measurements_cam2); + projectToMultipleCameras(cam1, cam2, cam3, landmark3, measurements_cam3); + + // create smart factor + boost::shared_ptr cameraRig(new Cameras()); // single camera in the rig + cameraRig->push_back(Camera(Pose3::identity(), sharedK)); + FastVector cameraIds{0, 0, 0}; + SmartRigFactor::shared_ptr smartFactor1( + new SmartRigFactor(model, cameraRig, params)); + smartFactor1->add(measurements_cam1, views, cameraIds); + + SmartRigFactor::shared_ptr smartFactor2( + new SmartRigFactor(model, cameraRig, params)); + smartFactor2->add(measurements_cam2, views, cameraIds); + + SmartRigFactor::shared_ptr smartFactor3( + new SmartRigFactor(model, cameraRig, params)); + smartFactor3->add(measurements_cam3, views, cameraIds); + + const SharedDiagonal noisePrior = noiseModel::Isotropic::Sigma(6, 0.10); + + NonlinearFactorGraph graph; + graph.push_back(smartFactor1); + graph.push_back(smartFactor2); + graph.push_back(smartFactor3); + graph.addPrior(x1, cam1.pose(), noisePrior); + graph.addPrior(x2, cam2.pose(), noisePrior); + + // Pose3 noise_pose = Pose3(Rot3::Ypr(-M_PI/10, 0., -M_PI/10), + // Point3(0.5,0.1,0.3)); // noise from regular projection factor test below + Pose3 noise_pose = Pose3(Rot3::Ypr(-M_PI / 100, 0., -M_PI / 100), + Point3(0.1, 0.1, 0.1)); // smaller noise + Values values; + values.insert(x1, cam1.pose()); + values.insert(x2, cam2.pose()); + // initialize third pose with some noise, we expect it to move back to + // original pose_above + values.insert(x3, pose_above * noise_pose); + EXPECT(assert_equal(Pose3(Rot3(1.11022302e-16, -0.0314107591, 0.99950656, + -0.99950656, -0.0313952598, -0.000986635786, + 0.0314107591, -0.999013364, -0.0313952598), + Point3(0.1, -0.1, 1.9)), + values.at(x3))); + + Values result; + LevenbergMarquardtOptimizer optimizer(graph, values, lmParams); + result = optimizer.optimize(); + EXPECT(assert_equal(pose_above, result.at(x3), 1e-7)); +} + +/* *************************************************************************/ +TEST(SmartProjectionRigFactor, landmarkDistance) { + using namespace vanillaRig; + + double excludeLandmarksFutherThanDist = 2; + + KeyVector views{x1, x2, x3}; + + Point2Vector measurements_cam1, measurements_cam2, measurements_cam3; + + // Project three landmarks into three cameras + projectToMultipleCameras(cam1, cam2, cam3, landmark1, measurements_cam1); + projectToMultipleCameras(cam1, cam2, cam3, landmark2, measurements_cam2); + projectToMultipleCameras(cam1, cam2, cam3, landmark3, measurements_cam3); + + SmartProjectionParams params; + params.setRankTolerance(1.0); + params.setLinearizationMode(gtsam::HESSIAN); + params.setDegeneracyMode(gtsam::ZERO_ON_DEGENERACY); + params.setLandmarkDistanceThreshold(excludeLandmarksFutherThanDist); + params.setEnableEPI(false); + + boost::shared_ptr cameraRig(new Cameras()); // single camera in the rig + cameraRig->push_back(Camera(Pose3::identity(), sharedK)); + FastVector cameraIds{0, 0, 0}; + + SmartRigFactor::shared_ptr smartFactor1( + new SmartRigFactor(model, cameraRig, params)); + smartFactor1->add(measurements_cam1, views, cameraIds); + + SmartRigFactor::shared_ptr smartFactor2( + new SmartRigFactor(model, cameraRig, params)); + smartFactor2->add(measurements_cam2, views, cameraIds); + + SmartRigFactor::shared_ptr smartFactor3( + new SmartRigFactor(model, cameraRig, params)); + smartFactor3->add(measurements_cam3, views, cameraIds); + + const SharedDiagonal noisePrior = noiseModel::Isotropic::Sigma(6, 0.10); + + NonlinearFactorGraph graph; + graph.push_back(smartFactor1); + graph.push_back(smartFactor2); + graph.push_back(smartFactor3); + graph.addPrior(x1, cam1.pose(), noisePrior); + graph.addPrior(x2, cam2.pose(), noisePrior); + + // Pose3 noise_pose = Pose3(Rot3::Ypr(-M_PI/10, 0., -M_PI/10), + // Point3(0.5,0.1,0.3)); // noise from regular projection factor test below + Pose3 noise_pose = Pose3(Rot3::Ypr(-M_PI / 100, 0., -M_PI / 100), + Point3(0.1, 0.1, 0.1)); // smaller noise + Values values; + values.insert(x1, cam1.pose()); + values.insert(x2, cam2.pose()); + values.insert(x3, pose_above * noise_pose); + + // All factors are disabled and pose should remain where it is + Values result; + LevenbergMarquardtOptimizer optimizer(graph, values, lmParams); + result = optimizer.optimize(); + EXPECT(assert_equal(values.at(x3), result.at(x3))); +} + +/* *************************************************************************/ +TEST(SmartProjectionRigFactor, dynamicOutlierRejection) { + using namespace vanillaRig; + + double excludeLandmarksFutherThanDist = 1e10; + double dynamicOutlierRejectionThreshold = + 1; // max 1 pixel of average reprojection error + + KeyVector views{x1, x2, x3}; + + // add fourth landmark + Point3 landmark4(5, -0.5, 1); + + Point2Vector measurements_cam1, measurements_cam2, measurements_cam3, + measurements_cam4; + + // Project 4 landmarks into three cameras + projectToMultipleCameras(cam1, cam2, cam3, landmark1, measurements_cam1); + projectToMultipleCameras(cam1, cam2, cam3, landmark2, measurements_cam2); + projectToMultipleCameras(cam1, cam2, cam3, landmark3, measurements_cam3); + projectToMultipleCameras(cam1, cam2, cam3, landmark4, measurements_cam4); + measurements_cam4.at(0) = + measurements_cam4.at(0) + Point2(10, 10); // add outlier + + SmartProjectionParams params; + params.setLinearizationMode(gtsam::HESSIAN); + params.setDegeneracyMode(gtsam::ZERO_ON_DEGENERACY); + params.setLandmarkDistanceThreshold(excludeLandmarksFutherThanDist); + params.setDynamicOutlierRejectionThreshold(dynamicOutlierRejectionThreshold); + + boost::shared_ptr cameraRig(new Cameras()); // single camera in the rig + cameraRig->push_back(Camera(Pose3::identity(), sharedK)); + FastVector cameraIds{0, 0, 0}; + + SmartRigFactor::shared_ptr smartFactor1( + new SmartRigFactor(model, cameraRig, params)); + smartFactor1->add(measurements_cam1, views, cameraIds); + + SmartRigFactor::shared_ptr smartFactor2( + new SmartRigFactor(model, cameraRig, params)); + smartFactor2->add(measurements_cam2, views, cameraIds); + + SmartRigFactor::shared_ptr smartFactor3( + new SmartRigFactor(model, cameraRig, params)); + smartFactor3->add(measurements_cam3, views, cameraIds); + + SmartRigFactor::shared_ptr smartFactor4( + new SmartRigFactor(model, cameraRig, params)); + smartFactor4->add(measurements_cam4, views, cameraIds); + + const SharedDiagonal noisePrior = noiseModel::Isotropic::Sigma(6, 0.10); + + NonlinearFactorGraph graph; + graph.push_back(smartFactor1); + graph.push_back(smartFactor2); + graph.push_back(smartFactor3); + graph.push_back(smartFactor4); + graph.addPrior(x1, cam1.pose(), noisePrior); + graph.addPrior(x2, cam2.pose(), noisePrior); + + Values values; + values.insert(x1, cam1.pose()); + values.insert(x2, cam2.pose()); + values.insert(x3, cam3.pose()); + + // All factors are disabled and pose should remain where it is + Values result; + LevenbergMarquardtOptimizer optimizer(graph, values, lmParams); + result = optimizer.optimize(); + EXPECT(assert_equal(cam3.pose(), result.at(x3))); +} + +/* *************************************************************************/ +TEST(SmartProjectionRigFactor, CheckHessian) { + KeyVector views{x1, x2, x3}; + + using namespace vanillaRig; + + // Two slightly different cameras + Pose3 pose2 = + level_pose * Pose3(Rot3::RzRyRx(-0.05, 0.0, -0.05), Point3(0, 0, 0)); + Pose3 pose3 = pose2 * Pose3(Rot3::RzRyRx(-0.05, 0.0, -0.05), Point3(0, 0, 0)); + Camera cam2(pose2, sharedK); + Camera cam3(pose3, sharedK); + + Point2Vector measurements_cam1, measurements_cam2, measurements_cam3; + + // Project three landmarks into three cameras + projectToMultipleCameras(cam1, cam2, cam3, landmark1, measurements_cam1); + projectToMultipleCameras(cam1, cam2, cam3, landmark2, measurements_cam2); + projectToMultipleCameras(cam1, cam2, cam3, landmark3, measurements_cam3); + + SmartProjectionParams params; + params.setRankTolerance(10); + params.setDegeneracyMode(gtsam::ZERO_ON_DEGENERACY); + + boost::shared_ptr cameraRig(new Cameras()); // single camera in the rig + cameraRig->push_back(Camera(Pose3::identity(), sharedK)); + FastVector cameraIds{0, 0, 0}; + + SmartRigFactor::shared_ptr smartFactor1( + new SmartRigFactor(model, cameraRig, params)); // HESSIAN, by default + smartFactor1->add(measurements_cam1, views, cameraIds); + + SmartRigFactor::shared_ptr smartFactor2( + new SmartRigFactor(model, cameraRig, params)); // HESSIAN, by default + smartFactor2->add(measurements_cam2, views, cameraIds); + + SmartRigFactor::shared_ptr smartFactor3( + new SmartRigFactor(model, cameraRig, params)); // HESSIAN, by default + smartFactor3->add(measurements_cam3, views, cameraIds); + + NonlinearFactorGraph graph; + graph.push_back(smartFactor1); + graph.push_back(smartFactor2); + graph.push_back(smartFactor3); + + // Pose3 noise_pose = Pose3(Rot3::Ypr(-M_PI/10, 0., -M_PI/10), + // Point3(0.5,0.1,0.3)); // noise from regular projection factor test below + Pose3 noise_pose = Pose3(Rot3::Ypr(-M_PI / 100, 0., -M_PI / 100), + Point3(0.1, 0.1, 0.1)); // smaller noise + Values values; + values.insert(x1, cam1.pose()); + values.insert(x2, cam2.pose()); + // initialize third pose with some noise, we expect it to move back to + // original pose_above + values.insert(x3, pose3 * noise_pose); + EXPECT(assert_equal(Pose3(Rot3(0.00563056869, -0.130848107, 0.991386438, + -0.991390265, -0.130426831, -0.0115837907, + 0.130819108, -0.98278564, -0.130455917), + Point3(0.0897734171, -0.110201006, 0.901022872)), + values.at(x3))); + + boost::shared_ptr factor1 = smartFactor1->linearize(values); + boost::shared_ptr factor2 = smartFactor2->linearize(values); + boost::shared_ptr factor3 = smartFactor3->linearize(values); + + Matrix CumulativeInformation = + factor1->information() + factor2->information() + factor3->information(); + + boost::shared_ptr GaussianGraph = + graph.linearize(values); + Matrix GraphInformation = GaussianGraph->hessian().first; + + // Check Hessian + EXPECT(assert_equal(GraphInformation, CumulativeInformation, 1e-6)); + + Matrix AugInformationMatrix = factor1->augmentedInformation() + + factor2->augmentedInformation() + + factor3->augmentedInformation(); + + // Check Information vector + Vector InfoVector = AugInformationMatrix.block( + 0, 18, 18, 1); // 18x18 Hessian + information vector + + // Check Hessian + EXPECT(assert_equal(InfoVector, GaussianGraph->hessian().second, 1e-6)); +} + +/* *************************************************************************/ +TEST(SmartProjectionRigFactor, Hessian) { + using namespace vanillaPose2; + + KeyVector views{x1, x2}; + + // Project three landmarks into 2 cameras + Point2 cam1_uv1 = cam1.project(landmark1); + Point2 cam2_uv1 = cam2.project(landmark1); + Point2Vector measurements_cam1; + measurements_cam1.push_back(cam1_uv1); + measurements_cam1.push_back(cam2_uv1); + + boost::shared_ptr cameraRig(new Cameras()); // single camera in the rig + cameraRig->push_back(Camera(Pose3::identity(), sharedK2)); + FastVector cameraIds{0, 0}; + + SmartProjectionParams params( + gtsam::HESSIAN, + gtsam::ZERO_ON_DEGENERACY); // only config that works with rig factors + + SmartRigFactor::shared_ptr smartFactor1( + new SmartRigFactor(model, cameraRig, params)); + smartFactor1->add(measurements_cam1, views, cameraIds); + + Pose3 noise_pose = + Pose3(Rot3::Ypr(-M_PI / 10, 0., -M_PI / 10), Point3(0.5, 0.1, 0.3)); + Values values; + values.insert(x1, cam1.pose()); + values.insert(x2, cam2.pose()); + + boost::shared_ptr factor = smartFactor1->linearize(values); + + // compute triangulation from linearization point + // compute reprojection errors (sum squared) + // compare with factor.info(): the bottom right element is the squared sum of + // the reprojection errors (normalized by the covariance) check that it is + // correctly scaled when using noiseProjection = [1/4 0; 0 1/4] +} + +/* ************************************************************************* */ +TEST(SmartProjectionRigFactor, ConstructorWithCal3Bundler) { + using namespace bundlerPose; + boost::shared_ptr cameraRig(new Cameras()); // single camera in the rig + cameraRig->push_back(Camera(Pose3::identity(), sharedBundlerK)); + + SmartProjectionParams params; + params.setDegeneracyMode(gtsam::ZERO_ON_DEGENERACY); + SmartRigFactor factor(model, cameraRig, params); + factor.add(measurement1, x1, cameraId1); +} + +/* *************************************************************************/ +TEST(SmartProjectionRigFactor, Cal3Bundler) { + using namespace bundlerPose; + SmartProjectionParams params( + gtsam::HESSIAN, + gtsam::ZERO_ON_DEGENERACY); // only config that works with rig factors + + // three landmarks ~5 meters in front of camera + Point3 landmark3(3, 0, 3.0); + + Point2Vector measurements_cam1, measurements_cam2, measurements_cam3; + + // Project three landmarks into three cameras + projectToMultipleCameras(cam1, cam2, cam3, landmark1, measurements_cam1); + projectToMultipleCameras(cam1, cam2, cam3, landmark2, measurements_cam2); + projectToMultipleCameras(cam1, cam2, cam3, landmark3, measurements_cam3); + + KeyVector views{x1, x2, x3}; + + boost::shared_ptr cameraRig(new Cameras()); // single camera in the rig + cameraRig->push_back(Camera(Pose3::identity(), sharedBundlerK)); + FastVector cameraIds{0, 0, 0}; + + SmartRigFactor::shared_ptr smartFactor1( + new SmartRigFactor(model, cameraRig, params)); + smartFactor1->add(measurements_cam1, views, cameraIds); + + SmartRigFactor::shared_ptr smartFactor2( + new SmartRigFactor(model, cameraRig, params)); + smartFactor2->add(measurements_cam2, views, cameraIds); + + SmartRigFactor::shared_ptr smartFactor3( + new SmartRigFactor(model, cameraRig, params)); + smartFactor3->add(measurements_cam3, views, cameraIds); + + const SharedDiagonal noisePrior = noiseModel::Isotropic::Sigma(6, 0.10); + + NonlinearFactorGraph graph; + graph.push_back(smartFactor1); + graph.push_back(smartFactor2); + graph.push_back(smartFactor3); + graph.addPrior(x1, cam1.pose(), noisePrior); + graph.addPrior(x2, cam2.pose(), noisePrior); + + // Pose3 noise_pose = Pose3(Rot3::Ypr(-M_PI/10, 0., -M_PI/10), + // Point3(0.5,0.1,0.3)); // noise from regular projection factor test below + Pose3 noise_pose = Pose3(Rot3::Ypr(-M_PI / 100, 0., -M_PI / 100), + Point3(0.1, 0.1, 0.1)); // smaller noise + Values values; + values.insert(x1, cam1.pose()); + values.insert(x2, cam2.pose()); + // initialize third pose with some noise, we expect it to move back to + // original pose_above + values.insert(x3, pose_above * noise_pose); + EXPECT(assert_equal( + Pose3(Rot3(0, -0.0314107591, 0.99950656, -0.99950656, -0.0313952598, + -0.000986635786, 0.0314107591, -0.999013364, -0.0313952598), + Point3(0.1, -0.1, 1.9)), + values.at(x3))); + + Values result; + LevenbergMarquardtOptimizer optimizer(graph, values, lmParams); + result = optimizer.optimize(); + EXPECT(assert_equal(cam3.pose(), result.at(x3), 1e-6)); +} + +#include +typedef GenericProjectionFactor TestProjectionFactor; +static Symbol l0('L', 0); +/* *************************************************************************/ +TEST(SmartProjectionRigFactor, + hessianComparedToProjFactors_measurementsFromSamePose) { + // in this test we make sure the fact works even if we have multiple pixel + // measurements of the same landmark at a single pose, a setup that occurs in + // multi-camera systems + + using namespace vanillaRig; + Point2Vector measurements_lmk1; + + // Project three landmarks into three cameras + projectToMultipleCameras(cam1, cam2, cam3, landmark1, measurements_lmk1); + + // create redundant measurements: + Camera::MeasurementVector measurements_lmk1_redundant = measurements_lmk1; + measurements_lmk1_redundant.push_back( + measurements_lmk1.at(0)); // we readd the first measurement + + // create inputs + KeyVector keys{x1, x2, x3, x1}; + + boost::shared_ptr cameraRig(new Cameras()); // single camera in the rig + cameraRig->push_back(Camera(Pose3::identity(), sharedK)); + FastVector cameraIds{0, 0, 0, 0}; + + SmartRigFactor::shared_ptr smartFactor1( + new SmartRigFactor(model, cameraRig, params)); + smartFactor1->add(measurements_lmk1_redundant, keys, cameraIds); + + Pose3 noise_pose = Pose3(Rot3::Ypr(-M_PI / 100, 0., -M_PI / 100), + Point3(0.1, 0.1, 0.1)); // smaller noise + Values values; + values.insert(x1, level_pose); + values.insert(x2, pose_right); + // initialize third pose with some noise to get a nontrivial linearization + // point + values.insert(x3, pose_above * noise_pose); + EXPECT( // check that the pose is actually noisy + assert_equal(Pose3(Rot3(0, -0.0314107591, 0.99950656, -0.99950656, + -0.0313952598, -0.000986635786, 0.0314107591, + -0.999013364, -0.0313952598), + Point3(0.1, -0.1, 1.9)), + values.at(x3))); + + // linearization point for the poses + Pose3 pose1 = level_pose; + Pose3 pose2 = pose_right; + Pose3 pose3 = pose_above * noise_pose; + + // ==== check Hessian of smartFactor1 ===== + // -- compute actual Hessian + boost::shared_ptr linearfactor1 = + smartFactor1->linearize(values); + Matrix actualHessian = linearfactor1->information(); + + // -- compute expected Hessian from manual Schur complement from Jacobians + // linearization point for the 3D point + smartFactor1->triangulateSafe(smartFactor1->cameras(values)); + TriangulationResult point = smartFactor1->point(); + EXPECT(point.valid()); // check triangulated point is valid + + // Use standard ProjectionFactor factor to calculate the Jacobians + Matrix F = Matrix::Zero(2 * 4, 6 * 3); + Matrix E = Matrix::Zero(2 * 4, 3); + Vector b = Vector::Zero(2 * 4); + + // create projection factors rolling shutter + TestProjectionFactor factor11(measurements_lmk1_redundant[0], model, x1, l0, + sharedK); + Matrix HPoseActual, HEActual; + // note: b is minus the reprojection error, cf the smart factor jacobian + // computation + b.segment<2>(0) = + -factor11.evaluateError(pose1, *point, HPoseActual, HEActual); + F.block<2, 6>(0, 0) = HPoseActual; + E.block<2, 3>(0, 0) = HEActual; + + TestProjectionFactor factor12(measurements_lmk1_redundant[1], model, x2, l0, + sharedK); + b.segment<2>(2) = + -factor12.evaluateError(pose2, *point, HPoseActual, HEActual); + F.block<2, 6>(2, 6) = HPoseActual; + E.block<2, 3>(2, 0) = HEActual; + + TestProjectionFactor factor13(measurements_lmk1_redundant[2], model, x3, l0, + sharedK); + b.segment<2>(4) = + -factor13.evaluateError(pose3, *point, HPoseActual, HEActual); + F.block<2, 6>(4, 12) = HPoseActual; + E.block<2, 3>(4, 0) = HEActual; + + TestProjectionFactor factor14(measurements_lmk1_redundant[3], model, x1, l0, + sharedK); + b.segment<2>(6) = + -factor11.evaluateError(pose1, *point, HPoseActual, HEActual); + F.block<2, 6>(6, 0) = HPoseActual; + E.block<2, 3>(6, 0) = HEActual; + + // whiten + F = (1 / sigma) * F; + E = (1 / sigma) * E; + b = (1 / sigma) * b; + //* G = F' * F - F' * E * P * E' * F + Matrix P = (E.transpose() * E).inverse(); + Matrix expectedHessian = + F.transpose() * F - (F.transpose() * E * P * E.transpose() * F); + EXPECT(assert_equal(expectedHessian, actualHessian, 1e-6)); + + // ==== check Information vector of smartFactor1 ===== + GaussianFactorGraph gfg; + gfg.add(linearfactor1); + Matrix actualHessian_v2 = gfg.hessian().first; + EXPECT(assert_equal(actualHessian_v2, actualHessian, + 1e-6)); // sanity check on hessian + + // -- compute actual information vector + Vector actualInfoVector = gfg.hessian().second; + + // -- compute expected information vector from manual Schur complement from + // Jacobians + //* g = F' * (b - E * P * E' * b) + Vector expectedInfoVector = F.transpose() * (b - E * P * E.transpose() * b); + EXPECT(assert_equal(expectedInfoVector, actualInfoVector, 1e-6)); + + // ==== check error of smartFactor1 (again) ===== + NonlinearFactorGraph nfg_projFactors; + nfg_projFactors.add(factor11); + nfg_projFactors.add(factor12); + nfg_projFactors.add(factor13); + nfg_projFactors.add(factor14); + values.insert(l0, *point); + + double actualError = smartFactor1->error(values); + double expectedError = nfg_projFactors.error(values); + EXPECT_DOUBLES_EQUAL(expectedError, actualError, 1e-7); +} + +/* *************************************************************************/ +TEST(SmartProjectionRigFactor, optimization_3poses_measurementsFromSamePose) { + using namespace vanillaRig; + Point2Vector measurements_lmk1, measurements_lmk2, measurements_lmk3; + + // Project three landmarks into three cameras + projectToMultipleCameras(cam1, cam2, cam3, landmark1, measurements_lmk1); + projectToMultipleCameras(cam1, cam2, cam3, landmark2, measurements_lmk2); + projectToMultipleCameras(cam1, cam2, cam3, landmark3, measurements_lmk3); + + // create inputs + KeyVector keys{x1, x2, x3}; + boost::shared_ptr cameraRig(new Cameras()); // single camera in the rig + cameraRig->push_back(Camera(Pose3::identity(), sharedK)); + FastVector cameraIds{0, 0, 0}; + FastVector cameraIdsRedundant{0, 0, 0, 0}; + + // For first factor, we create redundant measurement (taken by the same keys + // as factor 1, to make sure the redundancy in the keys does not create + // problems) + Camera::MeasurementVector& measurements_lmk1_redundant = measurements_lmk1; + measurements_lmk1_redundant.push_back( + measurements_lmk1.at(0)); // we readd the first measurement + KeyVector keys_redundant = keys; + keys_redundant.push_back(keys.at(0)); // we readd the first key + + SmartRigFactor::shared_ptr smartFactor1( + new SmartRigFactor(model, cameraRig, params)); + smartFactor1->add(measurements_lmk1_redundant, keys_redundant, + cameraIdsRedundant); + + SmartRigFactor::shared_ptr smartFactor2( + new SmartRigFactor(model, cameraRig, params)); + smartFactor2->add(measurements_lmk2, keys, cameraIds); + + SmartRigFactor::shared_ptr smartFactor3( + new SmartRigFactor(model, cameraRig, params)); + smartFactor3->add(measurements_lmk3, keys, cameraIds); + + const SharedDiagonal noisePrior = noiseModel::Isotropic::Sigma(6, 0.10); + + NonlinearFactorGraph graph; + graph.push_back(smartFactor1); + graph.push_back(smartFactor2); + graph.push_back(smartFactor3); + graph.addPrior(x1, level_pose, noisePrior); + graph.addPrior(x2, pose_right, noisePrior); + + Values groundTruth; + groundTruth.insert(x1, level_pose); + groundTruth.insert(x2, pose_right); + groundTruth.insert(x3, pose_above); + DOUBLES_EQUAL(0, graph.error(groundTruth), 1e-9); + + // Pose3 noise_pose = Pose3(Rot3::Ypr(-M_PI/10, 0., -M_PI/10), + // Point3(0.5,0.1,0.3)); // noise from regular projection factor test below + Pose3 noise_pose = Pose3(Rot3::Ypr(-M_PI / 100, 0., -M_PI / 100), + Point3(0.1, 0.1, 0.1)); // smaller noise + Values values; + values.insert(x1, level_pose); + values.insert(x2, pose_right); + // initialize third pose with some noise, we expect it to move back to + // original pose_above + values.insert(x3, pose_above * noise_pose); + EXPECT( // check that the pose is actually noisy + assert_equal(Pose3(Rot3(0, -0.0314107591, 0.99950656, -0.99950656, + -0.0313952598, -0.000986635786, 0.0314107591, + -0.999013364, -0.0313952598), + Point3(0.1, -0.1, 1.9)), + values.at(x3))); + + Values result; + LevenbergMarquardtOptimizer optimizer(graph, values, lmParams); + result = optimizer.optimize(); + EXPECT(assert_equal(pose_above, result.at(x3), 1e-5)); +} + +#ifndef DISABLE_TIMING +#include +// this factor is slightly slower (but comparable) to original +// SmartProjectionPoseFactor +//-Total: 0 CPU (0 times, 0 wall, 0.17 children, min: 0 max: 0) +//| -SmartRigFactor LINEARIZE: 0.05 CPU (10000 times, 0.057952 wall, 0.05 +// children, min: 0 max: 0) | -SmartPoseFactor LINEARIZE: 0.05 CPU (10000 +// times, 0.069647 wall, 0.05 children, min: 0 max: 0) +/* *************************************************************************/ +TEST(SmartProjectionRigFactor, timing) { + using namespace vanillaRig; + + // Default cameras for simple derivatives + static Cal3_S2::shared_ptr sharedKSimple(new Cal3_S2(100, 100, 0, 0, 0)); + + Rot3 R = Rot3::identity(); + Pose3 pose1 = Pose3(R, Point3(0, 0, 0)); + Pose3 pose2 = Pose3(R, Point3(1, 0, 0)); + Camera cam1(pose1, sharedKSimple), cam2(pose2, sharedKSimple); + Pose3 body_P_sensorId = Pose3::identity(); + + boost::shared_ptr cameraRig(new Cameras()); // single camera in the rig + cameraRig->push_back(Camera(body_P_sensorId, sharedKSimple)); + + // one landmarks 1m in front of camera + Point3 landmark1(0, 0, 10); + + Point2Vector measurements_lmk1; + + // Project 2 landmarks into 2 cameras + measurements_lmk1.push_back(cam1.project(landmark1)); + measurements_lmk1.push_back(cam2.project(landmark1)); + + size_t nrTests = 10000; + + for (size_t i = 0; i < nrTests; i++) { + SmartRigFactor::shared_ptr smartRigFactor( + new SmartRigFactor(model, cameraRig, params)); + smartRigFactor->add(measurements_lmk1[0], x1, cameraId1); + smartRigFactor->add(measurements_lmk1[1], x1, cameraId1); + + Values values; + values.insert(x1, pose1); + values.insert(x2, pose2); + gttic_(SmartRigFactor_LINEARIZE); + smartRigFactor->linearize(values); + gttoc_(SmartRigFactor_LINEARIZE); + } + + for (size_t i = 0; i < nrTests; i++) { + SmartFactor::shared_ptr smartFactor( + new SmartFactor(model, sharedKSimple, params)); + smartFactor->add(measurements_lmk1[0], x1); + smartFactor->add(measurements_lmk1[1], x2); + + Values values; + values.insert(x1, pose1); + values.insert(x2, pose2); + gttic_(SmartPoseFactor_LINEARIZE); + smartFactor->linearize(values); + gttoc_(SmartPoseFactor_LINEARIZE); + } + tictoc_print_(); +} +#endif + +/* *************************************************************************/ +TEST(SmartProjectionFactorP, optimization_3poses_sphericalCamera) { + using namespace sphericalCamera; + Camera::MeasurementVector measurements_lmk1, measurements_lmk2, + measurements_lmk3; + + // Project three landmarks into three cameras + projectToMultipleCameras(cam1, cam2, cam3, landmark1, + measurements_lmk1); + projectToMultipleCameras(cam1, cam2, cam3, landmark2, + measurements_lmk2); + projectToMultipleCameras(cam1, cam2, cam3, landmark3, + measurements_lmk3); + + // create inputs + KeyVector keys; + keys.push_back(x1); + keys.push_back(x2); + keys.push_back(x3); + + boost::shared_ptr cameraRig(new Cameras()); + cameraRig->push_back(Camera(Pose3::identity(), emptyK)); + + SmartProjectionParams params( + gtsam::HESSIAN, + gtsam::ZERO_ON_DEGENERACY); // only config that works with rig factors + params.setRankTolerance(0.1); + + SmartFactorP::shared_ptr smartFactor1( + new SmartFactorP(model, cameraRig, params)); + smartFactor1->add(measurements_lmk1, keys); + + SmartFactorP::shared_ptr smartFactor2( + new SmartFactorP(model, cameraRig, params)); + smartFactor2->add(measurements_lmk2, keys); + + SmartFactorP::shared_ptr smartFactor3( + new SmartFactorP(model, cameraRig, params)); + smartFactor3->add(measurements_lmk3, keys); + + const SharedDiagonal noisePrior = noiseModel::Isotropic::Sigma(6, 0.10); + + NonlinearFactorGraph graph; + graph.push_back(smartFactor1); + graph.push_back(smartFactor2); + graph.push_back(smartFactor3); + graph.addPrior(x1, level_pose, noisePrior); + graph.addPrior(x2, pose_right, noisePrior); + + Values groundTruth; + groundTruth.insert(x1, level_pose); + groundTruth.insert(x2, pose_right); + groundTruth.insert(x3, pose_above); + DOUBLES_EQUAL(0, graph.error(groundTruth), 1e-9); + + Pose3 noise_pose = Pose3(Rot3::Ypr(-M_PI / 10, 0., -M_PI / 100), + Point3(0.2, 0.2, 0.2)); // note: larger noise! + + Values values; + values.insert(x1, level_pose); + values.insert(x2, pose_right); + // initialize third pose with some noise, we expect it to move back to + // original pose_above + values.insert(x3, pose_above * noise_pose); + + DOUBLES_EQUAL(0.94148963675515274, graph.error(values), 1e-9); + + Values result; + LevenbergMarquardtOptimizer optimizer(graph, values, lmParams); + result = optimizer.optimize(); + + EXPECT(assert_equal(pose_above, result.at(x3), 1e-5)); +} + +#ifndef DISABLE_TIMING +#include +// using spherical camera is slightly slower (but comparable) to +// PinholePose +//| -SmartFactorP spherical LINEARIZE: 0.01 CPU (1000 times, 0.008178 wall, +// 0.01 children, min: 0 max: 0) | -SmartFactorP pinhole LINEARIZE: 0.01 CPU +//(1000 times, 0.005717 wall, 0.01 children, min: 0 max: 0) +/* *************************************************************************/ +TEST(SmartProjectionFactorP, timing_sphericalCamera) { + // create common data + Rot3 R = Rot3::identity(); + Pose3 pose1 = Pose3(R, Point3(0, 0, 0)); + Pose3 pose2 = Pose3(R, Point3(1, 0, 0)); + Pose3 body_P_sensorId = Pose3::identity(); + Point3 landmark1(0, 0, 10); + + // create spherical data + EmptyCal::shared_ptr emptyK; + SphericalCamera cam1_sphere(pose1, emptyK), cam2_sphere(pose2, emptyK); + // Project 2 landmarks into 2 cameras + std::vector measurements_lmk1_sphere; + measurements_lmk1_sphere.push_back(cam1_sphere.project(landmark1)); + measurements_lmk1_sphere.push_back(cam2_sphere.project(landmark1)); + + // create Cal3_S2 data + static Cal3_S2::shared_ptr sharedKSimple(new Cal3_S2(100, 100, 0, 0, 0)); + PinholePose cam1(pose1, sharedKSimple), cam2(pose2, sharedKSimple); + // Project 2 landmarks into 2 cameras + std::vector measurements_lmk1; + measurements_lmk1.push_back(cam1.project(landmark1)); + measurements_lmk1.push_back(cam2.project(landmark1)); + + SmartProjectionParams params( + gtsam::HESSIAN, + gtsam::ZERO_ON_DEGENERACY); // only config that works with rig factors + + size_t nrTests = 1000; + + for (size_t i = 0; i < nrTests; i++) { + boost::shared_ptr> cameraRig( + new CameraSet()); // single camera in the rig + cameraRig->push_back(SphericalCamera(body_P_sensorId, emptyK)); + + SmartProjectionRigFactor::shared_ptr smartFactorP( + new SmartProjectionRigFactor(model, cameraRig, + params)); + smartFactorP->add(measurements_lmk1_sphere[0], x1); + smartFactorP->add(measurements_lmk1_sphere[1], x1); + + Values values; + values.insert(x1, pose1); + values.insert(x2, pose2); + gttic_(SmartFactorP_spherical_LINEARIZE); + smartFactorP->linearize(values); + gttoc_(SmartFactorP_spherical_LINEARIZE); + } + + for (size_t i = 0; i < nrTests; i++) { + boost::shared_ptr>> cameraRig( + new CameraSet>()); // single camera in the rig + cameraRig->push_back(PinholePose(body_P_sensorId, sharedKSimple)); + + SmartProjectionRigFactor>::shared_ptr smartFactorP2( + new SmartProjectionRigFactor>(model, cameraRig, + params)); + smartFactorP2->add(measurements_lmk1[0], x1); + smartFactorP2->add(measurements_lmk1[1], x1); + + Values values; + values.insert(x1, pose1); + values.insert(x2, pose2); + gttic_(SmartFactorP_pinhole_LINEARIZE); + smartFactorP2->linearize(values); + gttoc_(SmartFactorP_pinhole_LINEARIZE); + } + tictoc_print_(); +} +#endif + +/* *************************************************************************/ +TEST(SmartProjectionFactorP, 2poses_rankTol) { + Pose3 poseA = Pose3( + Rot3::Ypr(-M_PI / 2, 0., -M_PI / 2), + Point3(0.0, 0.0, 0.0)); // with z pointing along x axis of global frame + Pose3 poseB = Pose3(Rot3::Ypr(-M_PI / 2, 0., -M_PI / 2), + Point3(0.0, -0.1, 0.0)); // 10cm to the right of poseA + Point3 landmarkL = Point3(5.0, 0.0, 0.0); // 5m in front of poseA + + // triangulate from a stereo with 10cm baseline, assuming standard calibration + { // default rankTol = 1 gives a valid point (compare with calibrated and + // spherical cameras below) + using namespace vanillaPose; // pinhole with Cal3_S2 calibration + + Camera cam1(poseA, sharedK); + Camera cam2(poseB, sharedK); + + SmartProjectionParams params( + gtsam::HESSIAN, + gtsam::ZERO_ON_DEGENERACY); // only config that works with rig factors + params.setRankTolerance(1); + + boost::shared_ptr>> cameraRig( + new CameraSet>()); // single camera in the rig + cameraRig->push_back(PinholePose(Pose3::identity(), sharedK)); + + SmartRigFactor::shared_ptr smartFactor1( + new SmartRigFactor(model, cameraRig, params)); + smartFactor1->add(cam1.project(landmarkL), x1); + smartFactor1->add(cam2.project(landmarkL), x2); + + NonlinearFactorGraph graph; + graph.push_back(smartFactor1); + + Values groundTruth; + groundTruth.insert(x1, poseA); + groundTruth.insert(x2, poseB); + DOUBLES_EQUAL(0, graph.error(groundTruth), 1e-9); + + // get point + TriangulationResult point = smartFactor1->point(); + EXPECT(point.valid()); // valid triangulation + EXPECT(assert_equal(landmarkL, *point, 1e-7)); + } + // triangulate from a stereo with 10cm baseline, assuming canonical + // calibration + { // default rankTol = 1 or 0.1 gives a degenerate point, which is + // undesirable for a point 5m away and 10cm baseline + using namespace vanillaPose; // pinhole with Cal3_S2 calibration + static Cal3_S2::shared_ptr canonicalK( + new Cal3_S2(1.0, 1.0, 0.0, 0.0, 0.0)); // canonical camera + + Camera cam1(poseA, canonicalK); + Camera cam2(poseB, canonicalK); + + SmartProjectionParams params( + gtsam::HESSIAN, + gtsam::ZERO_ON_DEGENERACY); // only config that works with rig factors + params.setRankTolerance(0.1); + + boost::shared_ptr>> cameraRig( + new CameraSet>()); // single camera in the rig + cameraRig->push_back(PinholePose(Pose3::identity(), canonicalK)); + + SmartRigFactor::shared_ptr smartFactor1( + new SmartRigFactor(model, cameraRig, params)); + smartFactor1->add(cam1.project(landmarkL), x1); + smartFactor1->add(cam2.project(landmarkL), x2); + + NonlinearFactorGraph graph; + graph.push_back(smartFactor1); + + Values groundTruth; + groundTruth.insert(x1, poseA); + groundTruth.insert(x2, poseB); + DOUBLES_EQUAL(0, graph.error(groundTruth), 1e-9); + + // get point + TriangulationResult point = smartFactor1->point(); + EXPECT(point.degenerate()); // valid triangulation + } + // triangulate from a stereo with 10cm baseline, assuming canonical + // calibration + { // smaller rankTol = 0.01 gives a valid point (compare with calibrated and + // spherical cameras below) + using namespace vanillaPose; // pinhole with Cal3_S2 calibration + static Cal3_S2::shared_ptr canonicalK( + new Cal3_S2(1.0, 1.0, 0.0, 0.0, 0.0)); // canonical camera + + Camera cam1(poseA, canonicalK); + Camera cam2(poseB, canonicalK); + + SmartProjectionParams params( + gtsam::HESSIAN, + gtsam::ZERO_ON_DEGENERACY); // only config that works with rig factors + params.setRankTolerance(0.01); + + boost::shared_ptr>> cameraRig( + new CameraSet>()); // single camera in the rig + cameraRig->push_back(PinholePose(Pose3::identity(), canonicalK)); + + SmartRigFactor::shared_ptr smartFactor1( + new SmartRigFactor(model, cameraRig, params)); + smartFactor1->add(cam1.project(landmarkL), x1); + smartFactor1->add(cam2.project(landmarkL), x2); + + NonlinearFactorGraph graph; + graph.push_back(smartFactor1); + + Values groundTruth; + groundTruth.insert(x1, poseA); + groundTruth.insert(x2, poseB); + DOUBLES_EQUAL(0, graph.error(groundTruth), 1e-9); + + // get point + TriangulationResult point = smartFactor1->point(); + EXPECT(point.valid()); // valid triangulation + EXPECT(assert_equal(landmarkL, *point, 1e-7)); + } +} + +/* *************************************************************************/ +TEST(SmartProjectionFactorP, 2poses_sphericalCamera_rankTol) { + typedef SphericalCamera Camera; + typedef SmartProjectionRigFactor SmartRigFactor; + EmptyCal::shared_ptr emptyK(new EmptyCal()); + Pose3 poseA = Pose3( + Rot3::Ypr(-M_PI / 2, 0., -M_PI / 2), + Point3(0.0, 0.0, 0.0)); // with z pointing along x axis of global frame + Pose3 poseB = Pose3(Rot3::Ypr(-M_PI / 2, 0., -M_PI / 2), + Point3(0.0, -0.1, 0.0)); // 10cm to the right of poseA + Point3 landmarkL = Point3(5.0, 0.0, 0.0); // 5m in front of poseA + + Camera cam1(poseA); + Camera cam2(poseB); + + boost::shared_ptr> cameraRig( + new CameraSet()); // single camera in the rig + cameraRig->push_back(SphericalCamera(Pose3::identity(), emptyK)); + + // TRIANGULATION TEST WITH DEFAULT RANK TOL + { // rankTol = 1 or 0.1 gives a degenerate point, which is undesirable for a + // point 5m away and 10cm baseline + SmartProjectionParams params( + gtsam::HESSIAN, + gtsam::ZERO_ON_DEGENERACY); // only config that works with rig factors + params.setRankTolerance(0.1); + + SmartRigFactor::shared_ptr smartFactor1( + new SmartRigFactor(model, cameraRig, params)); + smartFactor1->add(cam1.project(landmarkL), x1); + smartFactor1->add(cam2.project(landmarkL), x2); + + NonlinearFactorGraph graph; + graph.push_back(smartFactor1); + + Values groundTruth; + groundTruth.insert(x1, poseA); + groundTruth.insert(x2, poseB); + DOUBLES_EQUAL(0, graph.error(groundTruth), 1e-9); + + // get point + TriangulationResult point = smartFactor1->point(); + EXPECT(point.degenerate()); // not enough parallax + } + // SAME TEST WITH SMALLER RANK TOL + { // rankTol = 0.01 gives a valid point + // By playing with this test, we can show we can triangulate also with a + // baseline of 5cm (even for points far away, >100m), but the test fails + // when the baseline becomes 1cm. This suggests using rankTol = 0.01 and + // setting a reasonable max landmark distance to obtain best results. + SmartProjectionParams params( + gtsam::HESSIAN, + gtsam::ZERO_ON_DEGENERACY); // only config that works with rig factors + params.setRankTolerance(0.01); + + SmartRigFactor::shared_ptr smartFactor1( + new SmartRigFactor(model, cameraRig, params)); + smartFactor1->add(cam1.project(landmarkL), x1); + smartFactor1->add(cam2.project(landmarkL), x2); + + NonlinearFactorGraph graph; + graph.push_back(smartFactor1); + + Values groundTruth; + groundTruth.insert(x1, poseA); + groundTruth.insert(x2, poseB); + DOUBLES_EQUAL(0, graph.error(groundTruth), 1e-9); + + // get point + TriangulationResult point = smartFactor1->point(); + EXPECT(point.valid()); // valid triangulation + EXPECT(assert_equal(landmarkL, *point, 1e-7)); + } +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/gtsam/symbolic/SymbolicBayesNet.cpp b/gtsam/symbolic/SymbolicBayesNet.cpp index 5bc20ad12..f7113b23a 100644 --- a/gtsam/symbolic/SymbolicBayesNet.cpp +++ b/gtsam/symbolic/SymbolicBayesNet.cpp @@ -16,41 +16,16 @@ * @author Richard Roberts */ -#include -#include #include - -#include -#include +#include namespace gtsam { - // Instantiate base class - template class FactorGraph; - - /* ************************************************************************* */ - bool SymbolicBayesNet::equals(const This& bn, double tol) const - { - return Base::equals(bn, tol); - } - - /* ************************************************************************* */ - void SymbolicBayesNet::saveGraph(const std::string &s, const KeyFormatter& keyFormatter) const - { - std::ofstream of(s.c_str()); - of << "digraph G{\n"; - - for (auto conditional: boost::adaptors::reverse(*this)) { - SymbolicConditional::Frontals frontals = conditional->frontals(); - Key me = frontals.front(); - SymbolicConditional::Parents parents = conditional->parents(); - for(Key p: parents) - of << p << "->" << me << std::endl; - } - - of << "}"; - of.close(); - } - +// Instantiate base class +template class FactorGraph; +/* ************************************************************************* */ +bool SymbolicBayesNet::equals(const This& bn, double tol) const { + return Base::equals(bn, tol); } +} // namespace gtsam diff --git a/gtsam/symbolic/SymbolicBayesNet.h b/gtsam/symbolic/SymbolicBayesNet.h index 464af060b..2f66b80e2 100644 --- a/gtsam/symbolic/SymbolicBayesNet.h +++ b/gtsam/symbolic/SymbolicBayesNet.h @@ -19,19 +19,19 @@ #pragma once #include +#include #include #include namespace gtsam { - /** Symbolic Bayes Net - * \nosubgrouping + /** + * A SymbolicBayesNet is a Bayes Net of purely symbolic conditionals. + * @addtogroup symbolic */ - class SymbolicBayesNet : public FactorGraph { - - public: - - typedef FactorGraph Base; + class SymbolicBayesNet : public BayesNet { + public: + typedef BayesNet Base; typedef SymbolicBayesNet This; typedef SymbolicConditional ConditionalType; typedef boost::shared_ptr shared_ptr; @@ -44,16 +44,21 @@ namespace gtsam { SymbolicBayesNet() {} /** Construct from iterator over conditionals */ - template - SymbolicBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {} + template + SymbolicBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) + : Base(firstConditional, lastConditional) {} /** Construct from container of factors (shared_ptr or plain objects) */ - template - explicit SymbolicBayesNet(const CONTAINER& conditionals) : Base(conditionals) {} + template + explicit SymbolicBayesNet(const CONTAINER& conditionals) { + push_back(conditionals); + } - /** Implicit copy/downcast constructor to override explicit template container constructor */ - template - SymbolicBayesNet(const FactorGraph& graph) : Base(graph) {} + /** Implicit copy/downcast constructor to override explicit template + * container constructor */ + template + explicit SymbolicBayesNet(const FactorGraph& graph) + : Base(graph) {} /// Destructor virtual ~SymbolicBayesNet() {} @@ -75,13 +80,6 @@ namespace gtsam { /// @} - /// @name Standard Interface - /// @{ - - GTSAM_EXPORT void saveGraph(const std::string &s, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; - - /// @} - private: /** Serialization function */ friend class boost::serialization::access; diff --git a/gtsam/symbolic/SymbolicFactor.h b/gtsam/symbolic/SymbolicFactor.h index 2a488a4da..767998d22 100644 --- a/gtsam/symbolic/SymbolicFactor.h +++ b/gtsam/symbolic/SymbolicFactor.h @@ -144,9 +144,6 @@ namespace gtsam { /// @name Standard Interface /// @{ - /** Whether the factor is empty (involves zero variables). */ - bool empty() const { return keys_.empty(); } - /** Eliminate the variables in \c keys, in the order specified in \c keys, returning a * conditional and marginal. */ std::pair, boost::shared_ptr > diff --git a/gtsam/symbolic/SymbolicJunctionTree.h b/gtsam/symbolic/SymbolicJunctionTree.h index 7a152e532..0dcfae541 100644 --- a/gtsam/symbolic/SymbolicJunctionTree.h +++ b/gtsam/symbolic/SymbolicJunctionTree.h @@ -16,6 +16,8 @@ * @author Richard Roberts */ +#pragma once + #include #include #include diff --git a/gtsam/symbolic/symbolic.i b/gtsam/symbolic/symbolic.i index 4e7cca68a..1f1d4b48f 100644 --- a/gtsam/symbolic/symbolic.i +++ b/gtsam/symbolic/symbolic.i @@ -3,11 +3,6 @@ //************************************************************************* namespace gtsam { -#include -#include - -// ################### - #include virtual class SymbolicFactor { // Standard Constructors and Named Constructors @@ -82,6 +77,14 @@ virtual class SymbolicFactorGraph { const gtsam::KeyVector& key_vector, const gtsam::Ordering& marginalizedVariableOrdering); gtsam::SymbolicFactorGraph* marginal(const gtsam::KeyVector& key_vector); + + string dot( + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; + void saveGraph( + string s, + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; }; #include @@ -103,6 +106,7 @@ virtual class SymbolicConditional : gtsam::SymbolicFactor { bool equals(const gtsam::SymbolicConditional& other, double tol) const; // Standard interface + gtsam::Key firstFrontalKey() const; size_t nrFrontals() const; size_t nrParents() const; }; @@ -125,6 +129,14 @@ class SymbolicBayesNet { gtsam::SymbolicConditional* back() const; void push_back(gtsam::SymbolicConditional* conditional); void push_back(const gtsam::SymbolicBayesNet& bayesNet); + + string dot( + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; + void saveGraph( + string s, + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; }; #include @@ -173,29 +185,4 @@ class SymbolicBayesTreeClique { void deleteCachedShortcuts(); }; -#include -class VariableIndex { - // Standard Constructors and Named Constructors - VariableIndex(); - // TODO: Templetize constructor when wrap supports it - // template - // VariableIndex(const T& factorGraph, size_t nVariables); - // VariableIndex(const T& factorGraph); - VariableIndex(const gtsam::SymbolicFactorGraph& sfg); - VariableIndex(const gtsam::GaussianFactorGraph& gfg); - VariableIndex(const gtsam::NonlinearFactorGraph& fg); - VariableIndex(const gtsam::VariableIndex& other); - - // Testable - bool equals(const gtsam::VariableIndex& other, double tol) const; - void print(string s = "VariableIndex: ", - const gtsam::KeyFormatter& keyFormatter = - gtsam::DefaultKeyFormatter) const; - - // Standard interface - size_t size() const; - size_t nFactors() const; - size_t nEntries() const; -}; - } // namespace gtsam diff --git a/gtsam/symbolic/tests/testSymbolicBayesNet.cpp b/gtsam/symbolic/tests/testSymbolicBayesNet.cpp index a92d66f68..2e13be10e 100644 --- a/gtsam/symbolic/tests/testSymbolicBayesNet.cpp +++ b/gtsam/symbolic/tests/testSymbolicBayesNet.cpp @@ -15,13 +15,16 @@ * @author Frank Dellaert */ -#include +#include +#include +#include +#include +#include +#include #include -#include -#include -#include +#include using namespace std; using namespace gtsam; @@ -30,7 +33,6 @@ static const Key _L_ = 0; static const Key _A_ = 1; static const Key _B_ = 2; static const Key _C_ = 3; -static const Key _D_ = 4; static SymbolicConditional::shared_ptr B(new SymbolicConditional(_B_)), @@ -78,14 +80,41 @@ TEST( SymbolicBayesNet, combine ) } /* ************************************************************************* */ -TEST(SymbolicBayesNet, saveGraph) { +TEST(SymbolicBayesNet, Dot) { + using symbol_shorthand::A; + using symbol_shorthand::X; SymbolicBayesNet bn; - bn += SymbolicConditional(_A_, _B_); - KeyVector keys {_B_, _C_, _D_}; - bn += SymbolicConditional::FromKeys(keys,2); - bn += SymbolicConditional(_D_); + bn += SymbolicConditional(X(3), X(2), A(2)); + bn += SymbolicConditional(X(2), X(1), A(1)); + bn += SymbolicConditional(X(1)); - bn.saveGraph("SymbolicBayesNet.dot"); + DotWriter writer; + writer.positionHints.emplace('a', 2); + writer.positionHints.emplace('x', 1); + writer.boxes.emplace(A(1)); + writer.boxes.emplace(A(2)); + + auto position = writer.variablePos(A(1)); + CHECK(position); + EXPECT(assert_equal(Vector2(1, 2), *position, 1e-5)); + + string actual = bn.dot(DefaultKeyFormatter, writer); + bn.saveGraph("bn.dot", DefaultKeyFormatter, writer); + EXPECT(actual == + "digraph {\n" + " size=\"5,5\";\n" + "\n" + " vara1[label=\"a1\", pos=\"1,2!\", shape=box];\n" + " vara2[label=\"a2\", pos=\"2,2!\", shape=box];\n" + " varx1[label=\"x1\", pos=\"1,1!\"];\n" + " varx2[label=\"x2\", pos=\"2,1!\"];\n" + " varx3[label=\"x3\", pos=\"3,1!\"];\n" + "\n" + " varx1->varx2\n" + " vara1->varx2\n" + " varx2->varx3\n" + " vara2->varx3\n" + "}"); } /* ************************************************************************* */ diff --git a/gtsam/symbolic/tests/testSymbolicBayesTree.cpp b/gtsam/symbolic/tests/testSymbolicBayesTree.cpp index 33fc3243b..ee9b41a5a 100644 --- a/gtsam/symbolic/tests/testSymbolicBayesTree.cpp +++ b/gtsam/symbolic/tests/testSymbolicBayesTree.cpp @@ -731,10 +731,12 @@ TEST(SymbolicBayesTree, COLAMDvsMETIS) { { Ordering ordering = Ordering::Create(Ordering::METIS, sfg); // Linux and Mac split differently when using mettis -#if !defined(__APPLE__) - EXPECT(assert_equal(Ordering(list_of(3)(2)(5)(0)(4)(1)), ordering)); -#else +#if defined(__APPLE__) EXPECT(assert_equal(Ordering(list_of(5)(4)(2)(1)(0)(3)), ordering)); +#elif defined(_WIN32) + EXPECT(assert_equal(Ordering(list_of(4)(3)(1)(0)(5)(2)), ordering)); +#else + EXPECT(assert_equal(Ordering(list_of(3)(2)(5)(0)(4)(1)), ordering)); #endif // - P( 1 0 3) @@ -742,20 +744,27 @@ TEST(SymbolicBayesTree, COLAMDvsMETIS) { // | | - P( 5 | 0 4) // | - P( 2 | 1 3) SymbolicBayesTree expected; -#if !defined(__APPLE__) - expected.insertRoot( - MakeClique(list_of(2)(4)(1), 3, - list_of( - MakeClique(list_of(0)(1)(4), 1, - list_of(MakeClique(list_of(5)(0)(4), 1))))( - MakeClique(list_of(3)(2)(4), 1)))); -#else +#if defined(__APPLE__) expected.insertRoot( MakeClique(list_of(1)(0)(3), 3, list_of( MakeClique(list_of(4)(0)(3), 1, list_of(MakeClique(list_of(5)(0)(4), 1))))( MakeClique(list_of(2)(1)(3), 1)))); +#elif defined(_WIN32) + expected.insertRoot( + MakeClique(list_of(3)(5)(2), 3, + list_of( + MakeClique(list_of(4)(3)(5), 1, + list_of(MakeClique(list_of(0)(2)(5), 1))))( + MakeClique(list_of(1)(0)(2), 1)))); +#else + expected.insertRoot( + MakeClique(list_of(2)(4)(1), 3, + list_of( + MakeClique(list_of(0)(1)(4), 1, + list_of(MakeClique(list_of(5)(0)(4), 1))))( + MakeClique(list_of(3)(2)(4), 1)))); #endif SymbolicBayesTree actual = *sfg.eliminateMultifrontal(ordering); EXPECT(assert_equal(expected, actual)); diff --git a/gtsam_unstable/CMakeLists.txt b/gtsam_unstable/CMakeLists.txt index 13c061b9b..98a1b4ef9 100644 --- a/gtsam_unstable/CMakeLists.txt +++ b/gtsam_unstable/CMakeLists.txt @@ -100,12 +100,12 @@ endif() install( TARGETS gtsam_unstable - EXPORT GTSAM-exports + EXPORT GTSAM_UNSTABLE-exports LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) -list(APPEND GTSAM_EXPORTED_TARGETS gtsam_unstable) -set(GTSAM_EXPORTED_TARGETS "${GTSAM_EXPORTED_TARGETS}" PARENT_SCOPE) +list(APPEND GTSAM_UNSTABLE_EXPORTED_TARGETS gtsam_unstable) +set(GTSAM_UNSTABLE_EXPORTED_TARGETS "${GTSAM_UNSTABLE_EXPORTED_TARGETS}" PARENT_SCOPE) # Build examples add_subdirectory(examples) diff --git a/gtsam_unstable/base/BTree.h b/gtsam_unstable/base/BTree.h index 9d854a169..94e27d6c4 100644 --- a/gtsam_unstable/base/BTree.h +++ b/gtsam_unstable/base/BTree.h @@ -17,6 +17,8 @@ * @date Feb 3, 2010 */ +#pragma once + #include #include #include diff --git a/gtsam_unstable/base/Dummy.h b/gtsam_unstable/base/Dummy.h index a2f544de5..548bce344 100644 --- a/gtsam_unstable/base/Dummy.h +++ b/gtsam_unstable/base/Dummy.h @@ -17,6 +17,8 @@ * @date June 14, 2012 */ +#pragma once + #include #include #include diff --git a/gtsam_unstable/discrete/AllDiff.cpp b/gtsam_unstable/discrete/AllDiff.cpp index 9e124954f..bff524bc2 100644 --- a/gtsam_unstable/discrete/AllDiff.cpp +++ b/gtsam_unstable/discrete/AllDiff.cpp @@ -5,107 +5,109 @@ * @author Frank Dellaert */ -#include -#include #include +#include +#include + #include namespace gtsam { - /* ************************************************************************* */ - AllDiff::AllDiff(const DiscreteKeys& dkeys) : - Constraint(dkeys.indices()) { - for(const DiscreteKey& dkey: dkeys) - cardinalities_.insert(dkey); - } +/* ************************************************************************* */ +AllDiff::AllDiff(const DiscreteKeys& dkeys) : Constraint(dkeys.indices()) { + for (const DiscreteKey& dkey : dkeys) cardinalities_.insert(dkey); +} - /* ************************************************************************* */ - void AllDiff::print(const std::string& s, - const KeyFormatter& formatter) const { - std::cout << s << "AllDiff on "; - for (Key dkey: keys_) - std::cout << formatter(dkey) << " "; - std::cout << std::endl; - } +/* ************************************************************************* */ +void AllDiff::print(const std::string& s, const KeyFormatter& formatter) const { + std::cout << s << "AllDiff on "; + for (Key dkey : keys_) std::cout << formatter(dkey) << " "; + std::cout << std::endl; +} - /* ************************************************************************* */ - double AllDiff::operator()(const Values& values) const { - std::set < size_t > taken; // record values taken by keys - for(Key dkey: keys_) { - size_t value = values.at(dkey); // get the value for that key - if (taken.count(value)) return 0.0;// check if value alreday taken - taken.insert(value);// if not, record it as taken and keep checking +/* ************************************************************************* */ +double AllDiff::operator()(const DiscreteValues& values) const { + std::set taken; // record values taken by keys + for (Key dkey : keys_) { + size_t value = values.at(dkey); // get the value for that key + if (taken.count(value)) return 0.0; // check if value alreday taken + taken.insert(value); // if not, record it as taken and keep checking + } + return 1.0; +} + +/* ************************************************************************* */ +DecisionTreeFactor AllDiff::toDecisionTreeFactor() const { + // We will do this by converting the allDif into many BinaryAllDiff + // constraints + DecisionTreeFactor converted; + size_t nrKeys = keys_.size(); + for (size_t i1 = 0; i1 < nrKeys; i1++) + for (size_t i2 = i1 + 1; i2 < nrKeys; i2++) { + BinaryAllDiff binary12(discreteKey(i1), discreteKey(i2)); + converted = converted * binary12.toDecisionTreeFactor(); } - return 1.0; + return converted; +} + +/* ************************************************************************* */ +DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const { + // TODO: can we do this more efficiently? + return toDecisionTreeFactor() * f; +} + +/* ************************************************************************* */ +bool AllDiff::ensureArcConsistency(Key j, Domains* domains) const { + Domain& Dj = domains->at(j); + + // Though strictly not part of allDiff, we check for + // a value in domains->at(j) that does not occur in any other connected domain. + // If found, we make this a singleton... + // TODO: make a new constraint where this really is true + boost::optional maybeChanged = Dj.checkAllDiff(keys_, *domains); + if (maybeChanged) { + Dj = *maybeChanged; + return true; } - /* ************************************************************************* */ - DecisionTreeFactor AllDiff::toDecisionTreeFactor() const { - // We will do this by converting the allDif into many BinaryAllDiff constraints - DecisionTreeFactor converted; - size_t nrKeys = keys_.size(); - for (size_t i1 = 0; i1 < nrKeys; i1++) - for (size_t i2 = i1 + 1; i2 < nrKeys; i2++) { - BinaryAllDiff binary12(discreteKey(i1),discreteKey(i2)); - converted = converted * binary12.toDecisionTreeFactor(); - } - return converted; - } - - /* ************************************************************************* */ - DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const { - // TODO: can we do this more efficiently? - return toDecisionTreeFactor() * f; - } - - /* ************************************************************************* */ - bool AllDiff::ensureArcConsistency(size_t j, std::vector& domains) const { - // Though strictly not part of allDiff, we check for - // a value in domains[j] that does not occur in any other connected domain. - // If found, we make this a singleton... - // TODO: make a new constraint where this really is true - Domain& Dj = domains[j]; - if (Dj.checkAllDiff(keys_, domains)) return true; - - // Check all other domains for singletons and erase corresponding values - // This is the same as arc-consistency on the equivalent binary constraints - bool changed = false; - for(Key k: keys_) - if (k != j) { - const Domain& Dk = domains[k]; - if (Dk.isSingleton()) { // check if singleton - size_t value = Dk.firstValue(); - if (Dj.contains(value)) { - Dj.erase(value); // erase value if true - changed = true; - } + // Check all other domains for singletons and erase corresponding values. + // This is the same as arc-consistency on the equivalent binary constraints + bool changed = false; + for (Key k : keys_) + if (k != j) { + const Domain& Dk = domains->at(k); + if (Dk.isSingleton()) { // check if singleton + size_t value = Dk.firstValue(); + if (Dj.contains(value)) { + Dj.erase(value); // erase value if true + changed = true; } } - return changed; - } + } + return changed; +} - /* ************************************************************************* */ - Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const { - DiscreteKeys newKeys; - // loop over keys and add them only if they do not appear in values - for(Key k: keys_) - if (values.find(k) == values.end()) { - newKeys.push_back(DiscreteKey(k,cardinalities_.at(k))); - } - return boost::make_shared(newKeys); - } +/* ************************************************************************* */ +Constraint::shared_ptr AllDiff::partiallyApply(const DiscreteValues& values) const { + DiscreteKeys newKeys; + // loop over keys and add them only if they do not appear in values + for (Key k : keys_) + if (values.find(k) == values.end()) { + newKeys.push_back(DiscreteKey(k, cardinalities_.at(k))); + } + return boost::make_shared(newKeys); +} - /* ************************************************************************* */ - Constraint::shared_ptr AllDiff::partiallyApply( - const std::vector& domains) const { - DiscreteFactor::Values known; - for(Key k: keys_) { - const Domain& Dk = domains[k]; - if (Dk.isSingleton()) - known[k] = Dk.firstValue(); - } - return partiallyApply(known); +/* ************************************************************************* */ +Constraint::shared_ptr AllDiff::partiallyApply( + const Domains& domains) const { + DiscreteValues known; + for (Key k : keys_) { + const Domain& Dk = domains.at(k); + if (Dk.isSingleton()) known[k] = Dk.firstValue(); } + return partiallyApply(known); +} - /* ************************************************************************* */ -} // namespace gtsam +/* ************************************************************************* */ +} // namespace gtsam diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index 80e700b29..9496fc1a6 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -7,71 +7,66 @@ #pragma once -#include #include +#include namespace gtsam { - /** - * General AllDiff constraint - * Returns 1 if values for all keys are different, 0 otherwise - * DiscreteFactors are all awkward in that they have to store two types of keys: - * for each variable we have a Key and an Key. In this factor, we - * keep the Indices locally, and the Indices are stored in IndexFactor. +/** + * General AllDiff constraint. + * Returns 1 if values for all keys are different, 0 otherwise. + */ +class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { + std::map cardinalities_; + + DiscreteKey discreteKey(size_t i) const { + Key j = keys_[i]; + return DiscreteKey(j, cardinalities_.at(j)); + } + + public: + /// Construct from keys. + AllDiff(const DiscreteKeys& dkeys); + + // print + void print(const std::string& s = "", const KeyFormatter& formatter = + DefaultKeyFormatter) const override; + + /// equals + bool equals(const DiscreteFactor& other, double tol) const override { + if (!dynamic_cast(&other)) + return false; + else { + const AllDiff& f(static_cast(other)); + return cardinalities_.size() == f.cardinalities_.size() && + std::equal(cardinalities_.begin(), cardinalities_.end(), + f.cardinalities_.begin()); + } + } + + /// Calculate value = expensive ! + double operator()(const DiscreteValues& values) const override; + + /// Convert into a decisiontree, can be *very* expensive ! + DecisionTreeFactor toDecisionTreeFactor() const override; + + /// Multiply into a decisiontree + DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + + /* + * Ensure Arc-consistency by checking every possible value of domain j. + * @param j domain to be checked + * @param (in/out) domains all domains, but only domains->at(j) will be checked. + * @return true if domains->at(j) was changed, false otherwise. */ - class GTSAM_UNSTABLE_EXPORT AllDiff: public Constraint { + bool ensureArcConsistency(Key j, Domains* domains) const override; - std::map cardinalities_; + /// Partially apply known values + Constraint::shared_ptr partiallyApply(const DiscreteValues&) const override; - DiscreteKey discreteKey(size_t i) const { - Key j = keys_[i]; - return DiscreteKey(j,cardinalities_.at(j)); - } + /// Partially apply known values, domain version + Constraint::shared_ptr partiallyApply( + const Domains&) const override; +}; - public: - - /// Constructor - AllDiff(const DiscreteKeys& dkeys); - - // print - void print(const std::string& s = "", - const KeyFormatter& formatter = DefaultKeyFormatter) const override; - - /// equals - bool equals(const DiscreteFactor& other, double tol) const override { - if(!dynamic_cast(&other)) - return false; - else { - const AllDiff& f(static_cast(other)); - return cardinalities_.size() == f.cardinalities_.size() - && std::equal(cardinalities_.begin(), cardinalities_.end(), - f.cardinalities_.begin()); - } - } - - /// Calculate value = expensive ! - double operator()(const Values& values) const override; - - /// Convert into a decisiontree, can be *very* expensive ! - DecisionTreeFactor toDecisionTreeFactor() const override; - - /// Multiply into a decisiontree - DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; - - /* - * Ensure Arc-consistency - * Arc-consistency involves creating binaryAllDiff constraints - * In which case the combinatorial hyper-arc explosion disappears. - * @param j domain to be checked - * @param domains all other domains - */ - bool ensureArcConsistency(size_t j, std::vector& domains) const override; - - /// Partially apply known values - Constraint::shared_ptr partiallyApply(const Values&) const override; - - /// Partially apply known values, domain version - Constraint::shared_ptr partiallyApply(const std::vector&) const override; - }; - -} // namespace gtsam +} // namespace gtsam diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index bbb60e2f1..b207acb9d 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -7,94 +7,90 @@ #pragma once -#include -#include #include +#include +#include namespace gtsam { - /** - * Binary AllDiff constraint - * Returns 1 if values for two keys are different, 0 otherwise - * DiscreteFactors are all awkward in that they have to store two types of keys: - * for each variable we have a Index and an Index. In this factor, we - * keep the Indices locally, and the Indices are stored in IndexFactor. - */ - class BinaryAllDiff: public Constraint { +/** + * Binary AllDiff constraint + * Returns 1 if values for two keys are different, 0 otherwise. + */ +class BinaryAllDiff : public Constraint { + size_t cardinality0_, cardinality1_; /// cardinality - size_t cardinality0_, cardinality1_; /// cardinality + public: + /// Constructor + BinaryAllDiff(const DiscreteKey& key1, const DiscreteKey& key2) + : Constraint(key1.first, key2.first), + cardinality0_(key1.second), + cardinality1_(key2.second) {} - public: + // print + void print( + const std::string& s = "", + const KeyFormatter& formatter = DefaultKeyFormatter) const override { + std::cout << s << "BinaryAllDiff on " << formatter(keys_[0]) << " and " + << formatter(keys_[1]) << std::endl; + } - /// Constructor - BinaryAllDiff(const DiscreteKey& key1, const DiscreteKey& key2) : - Constraint(key1.first, key2.first), - cardinality0_(key1.second), cardinality1_(key2.second) { - } - - // print - void print(const std::string& s = "", - const KeyFormatter& formatter = DefaultKeyFormatter) const override { - std::cout << s << "BinaryAllDiff on " << formatter(keys_[0]) << " and " - << formatter(keys_[1]) << std::endl; - } - - /// equals - bool equals(const DiscreteFactor& other, double tol) const override { - if(!dynamic_cast(&other)) - return false; - else { - const BinaryAllDiff& f(static_cast(other)); - return (cardinality0_==f.cardinality0_) && (cardinality1_==f.cardinality1_); - } - } - - /// Calculate value - double operator()(const Values& values) const override { - return (double) (values.at(keys_[0]) != values.at(keys_[1])); - } - - /// Convert into a decisiontree - DecisionTreeFactor toDecisionTreeFactor() const override { - DiscreteKeys keys; - keys.push_back(DiscreteKey(keys_[0],cardinality0_)); - keys.push_back(DiscreteKey(keys_[1],cardinality1_)); - std::vector table; - for (size_t i1 = 0; i1 < cardinality0_; i1++) - for (size_t i2 = 0; i2 < cardinality1_; i2++) - table.push_back(i1 != i2); - DecisionTreeFactor converted(keys, table); - return converted; - } - - /// Multiply into a decisiontree - DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { - // TODO: can we do this more efficiently? - return toDecisionTreeFactor() * f; - } - - /* - * Ensure Arc-consistency - * @param j domain to be checked - * @param domains all other domains - */ - /// - bool ensureArcConsistency(size_t j, std::vector& domains) const override { -// throw std::runtime_error( -// "BinaryAllDiff::ensureArcConsistency not implemented"); + /// equals + bool equals(const DiscreteFactor& other, double tol) const override { + if (!dynamic_cast(&other)) return false; + else { + const BinaryAllDiff& f(static_cast(other)); + return (cardinality0_ == f.cardinality0_) && + (cardinality1_ == f.cardinality1_); } + } - /// Partially apply known values - Constraint::shared_ptr partiallyApply(const Values&) const override { - throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); - } + /// Calculate value + double operator()(const DiscreteValues& values) const override { + return (double)(values.at(keys_[0]) != values.at(keys_[1])); + } - /// Partially apply known values, domain version - Constraint::shared_ptr partiallyApply( - const std::vector&) const override { - throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); - } - }; + /// Convert into a decisiontree + DecisionTreeFactor toDecisionTreeFactor() const override { + DiscreteKeys keys; + keys.push_back(DiscreteKey(keys_[0], cardinality0_)); + keys.push_back(DiscreteKey(keys_[1], cardinality1_)); + std::vector table; + for (size_t i1 = 0; i1 < cardinality0_; i1++) + for (size_t i2 = 0; i2 < cardinality1_; i2++) table.push_back(i1 != i2); + DecisionTreeFactor converted(keys, table); + return converted; + } -} // namespace gtsam + /// Multiply into a decisiontree + DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { + // TODO: can we do this more efficiently? + return toDecisionTreeFactor() * f; + } + + /* + * Ensure Arc-consistency by checking every possible value of domain j. + * @param j domain to be checked + * @param (in/out) domains all domains, but only domains->at(j) will be checked. + * @return true if domains->at(j) was changed, false otherwise. + */ + bool ensureArcConsistency(Key j, Domains* domains) const override { + throw std::runtime_error( + "BinaryAllDiff::ensureArcConsistency not implemented"); + return false; + } + + /// Partially apply known values + Constraint::shared_ptr partiallyApply(const DiscreteValues&) const override { + throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); + } + + /// Partially apply known values, domain version + Constraint::shared_ptr partiallyApply( + const Domains&) const override { + throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); + } +}; + +} // namespace gtsam diff --git a/gtsam_unstable/discrete/CSP.cpp b/gtsam_unstable/discrete/CSP.cpp index 525abd098..08143c469 100644 --- a/gtsam_unstable/discrete/CSP.cpp +++ b/gtsam_unstable/discrete/CSP.cpp @@ -5,99 +5,84 @@ * @author Frank Dellaert */ -#include -#include #include +#include +#include +#include using namespace std; namespace gtsam { - /// Find the best total assignment - can be expensive - CSP::sharedValues CSP::optimalAssignment() const { - DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential(); - sharedValues mpe = chordal->optimize(); - return mpe; - } +bool CSP::runArcConsistency(const VariableIndex& index, + Domains* domains) const { + bool changed = false; - /// Find the best total assignment - can be expensive - CSP::sharedValues CSP::optimalAssignment(const Ordering& ordering) const { - DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential(ordering); - sharedValues mpe = chordal->optimize(); - return mpe; - } + // iterate over all variables in the index + for (auto entry : index) { + // Get the variable's key and associated factors: + const Key key = entry.first; + const FactorIndices& factors = entry.second; - void CSP::runArcConsistency(size_t cardinality, size_t nrIterations, bool print) const { - // Create VariableIndex - VariableIndex index(*this); - // index.print(); + // If this domain is already a singleton, we do nothing. + if (domains->at(key).isSingleton()) continue; - size_t n = index.size(); - - // Initialize domains - std::vector < Domain > domains; - for (size_t j = 0; j < n; j++) - domains.push_back(Domain(DiscreteKey(j,cardinality))); - - // Create array of flags indicating a domain changed or not - std::vector changed(n); - - // iterate nrIterations over entire grid - for (size_t it = 0; it < nrIterations; it++) { - bool anyChange = false; - // iterate over all cells - for (size_t v = 0; v < n; v++) { - // keep track of which domains changed - changed[v] = false; - // loop over all factors/constraints for variable v - const FactorIndices& factors = index[v]; - for(size_t f: factors) { - // if not already a singleton - if (!domains[v].isSingleton()) { - // get the constraint and call its ensureArcConsistency method - Constraint::shared_ptr constraint = boost::dynamic_pointer_cast((*this)[f]); - if (!constraint) throw runtime_error("CSP:runArcConsistency: non-constraint factor"); - changed[v] = constraint->ensureArcConsistency(v,domains) || changed[v]; - } - } // f - if (changed[v]) anyChange = true; - } // v - if (!anyChange) break; - // TODO: Sudoku specific hack - if (print) { - if (cardinality == 9 && n == 81) { - for (size_t i = 0, v = 0; i < (size_t)std::sqrt((double)n); i++) { - for (size_t j = 0; j < (size_t)std::sqrt((double)n); j++, v++) { - if (changed[v]) cout << "*"; - domains[v].print(); - cout << "\t"; - } // i - cout << endl; - } // j - } else { - for (size_t v = 0; v < n; v++) { - if (changed[v]) cout << "*"; - domains[v].print(); - cout << "\t"; - } // v - } - cout << endl; - } // print - } // it - -#ifndef INPROGRESS - // Now create new problem with all singleton variables removed - // We do this by adding simplifying all factors using parial application - // TODO: create a new ordering as we go, to ensure a connected graph - // KeyOrdering ordering; - // vector dkeys; - for(const DiscreteFactor::shared_ptr& f: factors_) { - Constraint::shared_ptr constraint = boost::dynamic_pointer_cast(f); - if (!constraint) throw runtime_error("CSP:runArcConsistency: non-constraint factor"); - Constraint::shared_ptr reduced = constraint->partiallyApply(domains); - if (print) reduced->print(); + // Otherwise, loop over all factors/constraints for variable with given key. + for (size_t f : factors) { + // If this factor is a constraint, call its ensureArcConsistency method: + auto constraint = boost::dynamic_pointer_cast((*this)[f]); + if (constraint) { + changed = constraint->ensureArcConsistency(key, domains) || changed; + } } -#endif } -} // gtsam + return changed; +} +// TODO(dellaert): This is AC1, which is inefficient as any change will cause +// the algorithm to revisit *all* variables again. Implement AC3. +Domains CSP::runArcConsistency(size_t cardinality, size_t maxIterations) const { + // Create VariableIndex + VariableIndex index(*this); + + // Initialize domains + Domains domains; + for (auto entry : index) { + const Key key = entry.first; + domains.emplace(key, DiscreteKey(key, cardinality)); + } + + // Iterate until convergence or not a single domain changed. + for (size_t it = 0; it < maxIterations; it++) { + bool changed = runArcConsistency(index, &domains); + if (!changed) break; + } + return domains; +} + +CSP CSP::partiallyApply(const Domains& domains) const { + // Create new problem with all singleton variables removed + // We do this by adding simplifying all factors using partial application. + // TODO: create a new ordering as we go, to ensure a connected graph + // KeyOrdering ordering; + // vector dkeys; + CSP new_csp; + + // Add tightened domains as new factors: + for (auto key_domain : domains) { + new_csp.emplace_shared(key_domain.second); + } + + // Reduce all existing factors: + for (const DiscreteFactor::shared_ptr& f : factors_) { + auto constraint = boost::dynamic_pointer_cast(f); + if (!constraint) + throw runtime_error("CSP:runArcConsistency: non-constraint factor"); + Constraint::shared_ptr reduced = constraint->partiallyApply(domains); + if (reduced->size() > 1) { + new_csp.push_back(reduced); + } + } + return new_csp; +} +} // namespace gtsam diff --git a/gtsam_unstable/discrete/CSP.h b/gtsam_unstable/discrete/CSP.h index 9e843f667..40853bed6 100644 --- a/gtsam_unstable/discrete/CSP.h +++ b/gtsam_unstable/discrete/CSP.h @@ -7,84 +7,70 @@ #pragma once +#include #include #include -#include namespace gtsam { - /** - * Constraint Satisfaction Problem class - * A specialization of a DiscreteFactorGraph. - * It knows about CSP-specific constraints and algorithms +/** + * Constraint Satisfaction Problem class + * A specialization of a DiscreteFactorGraph. + * It knows about CSP-specific constraints and algorithms + */ +class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph { + public: + using Values = DiscreteValues; ///< backwards compatibility + + /// Add a unary constraint, allowing only a single value + void addSingleValue(const DiscreteKey& dkey, size_t value) { + emplace_shared(dkey, value); + } + + /// Add a binary AllDiff constraint + void addAllDiff(const DiscreteKey& key1, const DiscreteKey& key2) { + emplace_shared(key1, key2); + } + + /// Add a general AllDiff constraint + void addAllDiff(const DiscreteKeys& dkeys) { emplace_shared(dkeys); } + + // /** return product of all factors as a single factor */ + // DecisionTreeFactor product() const { + // DecisionTreeFactor result; + // for(const sharedFactor& factor: *this) + // if (factor) result = (*factor) * result; + // return result; + // } + + // /* + // * Perform loopy belief propagation + // * True belief propagation would check for each value in domain + // * whether any satisfying separator assignment can be found. + // * This corresponds to hyper-arc consistency in CSP speak. + // * This can be done by creating a mini-factor graph and search. + // * For a nine-by-nine Sudoku, the search tree will be 8+6+6=20 levels + // deep. + // * It will be very expensive to exclude values that way. + // */ + // void applyBeliefPropagation(size_t maxIterations = 10) const; + + /* + * Apply arc-consistency ~ Approximate loopy belief propagation + * We need to give the domains to a constraint, and it returns + * a domain whose values don't conflict in the arc-consistency way. + * TODO: should get cardinality from DiscreteKeys */ - class GTSAM_UNSTABLE_EXPORT CSP: public DiscreteFactorGraph { - public: + Domains runArcConsistency(size_t cardinality, + size_t maxIterations = 10) const; - /** A map from keys to values */ - typedef KeyVector Indices; - typedef Assignment Values; - typedef boost::shared_ptr sharedValues; + /// Run arc consistency for all variables, return true if any domain changed. + bool runArcConsistency(const VariableIndex& index, Domains* domains) const; - public: - -// /// Constructor -// CSP() { -// } - - /// Add a unary constraint, allowing only a single value - void addSingleValue(const DiscreteKey& dkey, size_t value) { - boost::shared_ptr factor(new SingleValue(dkey, value)); - push_back(factor); - } - - /// Add a binary AllDiff constraint - void addAllDiff(const DiscreteKey& key1, const DiscreteKey& key2) { - boost::shared_ptr factor( - new BinaryAllDiff(key1, key2)); - push_back(factor); - } - - /// Add a general AllDiff constraint - void addAllDiff(const DiscreteKeys& dkeys) { - boost::shared_ptr factor(new AllDiff(dkeys)); - push_back(factor); - } - -// /** return product of all factors as a single factor */ -// DecisionTreeFactor product() const { -// DecisionTreeFactor result; -// for(const sharedFactor& factor: *this) -// if (factor) result = (*factor) * result; -// return result; -// } - - /// Find the best total assignment - can be expensive - sharedValues optimalAssignment() const; - - /// Find the best total assignment - can be expensive - sharedValues optimalAssignment(const Ordering& ordering) const; - -// /* -// * Perform loopy belief propagation -// * True belief propagation would check for each value in domain -// * whether any satisfying separator assignment can be found. -// * This corresponds to hyper-arc consistency in CSP speak. -// * This can be done by creating a mini-factor graph and search. -// * For a nine-by-nine Sudoku, the search tree will be 8+6+6=20 levels deep. -// * It will be very expensive to exclude values that way. -// */ -// void applyBeliefPropagation(size_t nrIterations = 10) const; - - /* - * Apply arc-consistency ~ Approximate loopy belief propagation - * We need to give the domains to a constraint, and it returns - * a domain whose values don't conflict in the arc-consistency way. - * TODO: should get cardinality from Indices - */ - void runArcConsistency(size_t cardinality, size_t nrIterations = 10, - bool print = false) const; - }; // CSP - -} // gtsam + /* + * Create a new CSP, applying the given Domain constraints. + */ + CSP partiallyApply(const Domains& domains) const; +}; // CSP +} // namespace gtsam diff --git a/gtsam_unstable/discrete/Constraint.h b/gtsam_unstable/discrete/Constraint.h index c3a26de68..168891e6f 100644 --- a/gtsam_unstable/discrete/Constraint.h +++ b/gtsam_unstable/discrete/Constraint.h @@ -17,77 +17,88 @@ #pragma once -#include #include +#include +#include + #include +#include +#include namespace gtsam { - class Domain; +class Domain; +using Domains = std::map; - /** - * Base class for discrete probabilistic factors - * The most general one is the derived DecisionTreeFactor +/** + * Base class for constraint factors + * Derived classes include SingleValue, BinaryAllDiff, and AllDiff. + */ +class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor { + public: + typedef boost::shared_ptr shared_ptr; + + protected: + /// Construct unary constraint factor. + Constraint(Key j) : DiscreteFactor(boost::assign::cref_list_of<1>(j)) {} + + /// Construct binary constraint factor. + Constraint(Key j1, Key j2) + : DiscreteFactor(boost::assign::cref_list_of<2>(j1)(j2)) {} + + /// Construct n-way constraint factor. + Constraint(const KeyVector& js) : DiscreteFactor(js) {} + + /// construct from container + template + Constraint(KeyIterator beginKey, KeyIterator endKey) + : DiscreteFactor(beginKey, endKey) {} + + public: + /// @name Standard Constructors + /// @{ + + /// Default constructor for I/O + Constraint(); + + /// Virtual destructor + ~Constraint() override {} + + /// @} + /// @name Standard Interface + /// @{ + + /* + * Ensure Arc-consistency by checking every possible value of domain j. + * @param j domain to be checked + * @param (in/out) domains all domains, but only domains->at(j) will be checked. + * @return true if domains->at(j) was changed, false otherwise. */ - class Constraint : public DiscreteFactor { + virtual bool ensureArcConsistency(Key j, Domains* domains) const = 0; - public: + /// Partially apply known values + virtual shared_ptr partiallyApply(const DiscreteValues&) const = 0; - typedef boost::shared_ptr shared_ptr; + /// Partially apply known values, domain version + virtual shared_ptr partiallyApply(const Domains&) const = 0; + /// @} + /// @name Wrapper support + /// @{ - protected: + /// Render as markdown table. + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override { + return (boost::format("`Constraint` on %1% variables\n") % (size())).str(); + } - /// Construct n-way factor - Constraint(const KeyVector& js) : - DiscreteFactor(js) { - } + /// Render as html table. + std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override { + return (boost::format("

Constraint on %1% variables

") % (size())).str(); + } - /// Construct unary factor - Constraint(Key j) : - DiscreteFactor(boost::assign::cref_list_of<1>(j)) { - } - - /// Construct binary factor - Constraint(Key j1, Key j2) : - DiscreteFactor(boost::assign::cref_list_of<2>(j1)(j2)) { - } - - /// construct from container - template - Constraint(KeyIterator beginKey, KeyIterator endKey) : - DiscreteFactor(beginKey, endKey) { - } - - public: - - /// @name Standard Constructors - /// @{ - - /// Default constructor for I/O - Constraint(); - - /// Virtual destructor - ~Constraint() override {} - - /// @} - /// @name Standard Interface - /// @{ - - /* - * Ensure Arc-consistency - * @param j domain to be checked - * @param domains all other domains - */ - virtual bool ensureArcConsistency(size_t j, std::vector& domains) const = 0; - - /// Partially apply known values - virtual shared_ptr partiallyApply(const Values&) const = 0; - - - /// Partially apply known values, domain version - virtual shared_ptr partiallyApply(const std::vector&) const = 0; - /// @} - }; + /// @} +}; // DiscreteFactor -}// namespace gtsam +} // namespace gtsam diff --git a/gtsam_unstable/discrete/Domain.cpp b/gtsam_unstable/discrete/Domain.cpp index 740ef067c..7acc10cb4 100644 --- a/gtsam_unstable/discrete/Domain.cpp +++ b/gtsam_unstable/discrete/Domain.cpp @@ -5,92 +5,94 @@ * @author Frank Dellaert */ -#include -#include #include -#include +#include +#include +#include +#include namespace gtsam { - using namespace std; - - /* ************************************************************************* */ - void Domain::print(const string& s, - const KeyFormatter& formatter) const { -// cout << s << ": Domain on " << formatter(keys_[0]) << " (j=" << -// formatter(keys_[0]) << ") with values"; -// for (size_t v: values_) cout << " " << v; -// cout << endl; - for (size_t v: values_) cout << v; - } - - /* ************************************************************************* */ - double Domain::operator()(const Values& values) const { - return contains(values.at(keys_[0])); - } - - /* ************************************************************************* */ - DecisionTreeFactor Domain::toDecisionTreeFactor() const { - DiscreteKeys keys; - keys += DiscreteKey(keys_[0],cardinality_); - vector table; - for (size_t i1 = 0; i1 < cardinality_; ++i1) - table.push_back(contains(i1)); - DecisionTreeFactor converted(keys, table); - return converted; - } - - /* ************************************************************************* */ - DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const { - // TODO: can we do this more efficiently? - return toDecisionTreeFactor() * f; - } - - /* ************************************************************************* */ - bool Domain::ensureArcConsistency(size_t j, vector& domains) const { - if (j != keys_[0]) throw invalid_argument("Domain check on wrong domain"); - Domain& D = domains[j]; - for(size_t value: values_) - if (!D.contains(value)) throw runtime_error("Unsatisfiable"); - D = *this; - return true; - } - - /* ************************************************************************* */ - bool Domain::checkAllDiff(const KeyVector keys, vector& domains) { - Key j = keys_[0]; - // for all values in this domain - for(size_t value: values_) { - // for all connected domains - for(Key k: keys) - // if any domain contains the value we cannot make this domain singleton - if (k!=j && domains[k].contains(value)) - goto found; - values_.clear(); - values_.insert(value); - return true; // we changed it - found:; - } - return false; // we did not change it - } - - /* ************************************************************************* */ - Constraint::shared_ptr Domain::partiallyApply( - const Values& values) const { - Values::const_iterator it = values.find(keys_[0]); - if (it != values.end() && !contains(it->second)) throw runtime_error( - "Domain::partiallyApply: unsatisfiable"); - return boost::make_shared < Domain > (*this); - } - - /* ************************************************************************* */ - Constraint::shared_ptr Domain::partiallyApply( - const vector& domains) const { - const Domain& Dk = domains[keys_[0]]; - if (Dk.isSingleton() && !contains(*Dk.begin())) throw runtime_error( - "Domain::partiallyApply: unsatisfiable"); - return boost::make_shared < Domain > (Dk); - } +using namespace std; /* ************************************************************************* */ -} // namespace gtsam +void Domain::print(const string& s, const KeyFormatter& formatter) const { + cout << s << ": Domain on " << formatter(key()) << " (j=" << formatter(key()) + << ") with values"; + for (size_t v : values_) cout << " " << v; + cout << endl; +} + +/* ************************************************************************* */ +string Domain::base1Str() const { + stringstream ss; + for (size_t v : values_) ss << v + 1; + return ss.str(); +} + +/* ************************************************************************* */ +double Domain::operator()(const DiscreteValues& values) const { + return contains(values.at(key())); +} + +/* ************************************************************************* */ +DecisionTreeFactor Domain::toDecisionTreeFactor() const { + DiscreteKeys keys; + keys += DiscreteKey(key(), cardinality_); + vector table; + for (size_t i1 = 0; i1 < cardinality_; ++i1) table.push_back(contains(i1)); + DecisionTreeFactor converted(keys, table); + return converted; +} + +/* ************************************************************************* */ +DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const { + // TODO: can we do this more efficiently? + return toDecisionTreeFactor() * f; +} + +/* ************************************************************************* */ +bool Domain::ensureArcConsistency(Key j, Domains* domains) const { + if (j != key()) throw invalid_argument("Domain check on wrong domain"); + Domain& D = domains->at(j); + for (size_t value : values_) + if (!D.contains(value)) throw runtime_error("Unsatisfiable"); + D = *this; + return true; +} + +/* ************************************************************************* */ +boost::optional Domain::checkAllDiff(const KeyVector keys, + const Domains& domains) const { + Key j = key(); + // for all values in this domain + for (const size_t value : values_) { + // for all connected domains + for (const Key k : keys) + // if any domain contains the value we cannot make this domain singleton + if (k != j && domains.at(k).contains(value)) goto found; + // Otherwise: return a singleton: + return Domain(this->discreteKey(), value); + found:; + } + return boost::none; // we did not change it +} + +/* ************************************************************************* */ +Constraint::shared_ptr Domain::partiallyApply(const DiscreteValues& values) const { + DiscreteValues::const_iterator it = values.find(key()); + if (it != values.end() && !contains(it->second)) + throw runtime_error("Domain::partiallyApply: unsatisfiable"); + return boost::make_shared(*this); +} + +/* ************************************************************************* */ +Constraint::shared_ptr Domain::partiallyApply(const Domains& domains) const { + const Domain& Dk = domains.at(key()); + if (Dk.isSingleton() && !contains(*Dk.begin())) + throw runtime_error("Domain::partiallyApply: unsatisfiable"); + return boost::make_shared(Dk); +} + +/* ************************************************************************* */ +} // namespace gtsam diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index 5acc5a08f..1047101c5 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -7,111 +7,107 @@ #pragma once -#include #include +#include namespace gtsam { - /** - * Domain restriction constraint +/** + * The Domain class represents a constraint that restricts the possible values a + * particular variable, with given key, can take on. + */ +class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { + size_t cardinality_; /// Cardinality + std::set values_; /// allowed values + + public: + typedef boost::shared_ptr shared_ptr; + + // Constructor on Discrete Key initializes an "all-allowed" domain + Domain(const DiscreteKey& dkey) + : Constraint(dkey.first), cardinality_(dkey.second) { + for (size_t v = 0; v < cardinality_; v++) values_.insert(v); + } + + // Constructor on Discrete Key with single allowed value + // Consider SingleValue constraint + Domain(const DiscreteKey& dkey, size_t v) + : Constraint(dkey.first), cardinality_(dkey.second) { + values_.insert(v); + } + + /// The one key + Key key() const { return keys_[0]; } + + // The associated discrete key + DiscreteKey discreteKey() const { return DiscreteKey(key(), cardinality_); } + + /// Insert a value, non const :-( + void insert(size_t value) { values_.insert(value); } + + /// Erase a value, non const :-( + void erase(size_t value) { values_.erase(value); } + + size_t nrValues() const { return values_.size(); } + + bool isSingleton() const { return nrValues() == 1; } + + size_t firstValue() const { return *values_.begin(); } + + // print + void print(const std::string& s = "", const KeyFormatter& formatter = + DefaultKeyFormatter) const override; + + /// equals + bool equals(const DiscreteFactor& other, double tol) const override { + if (!dynamic_cast(&other)) + return false; + else { + const Domain& f(static_cast(other)); + return (cardinality_ == f.cardinality_) && (values_ == f.values_); + } + } + + // Return concise string representation, mostly to debug arc consistency. + // Converts from base 0 to base1. + std::string base1Str() const; + + // Check whether domain cotains a specific value. + bool contains(size_t value) const { return values_.count(value) > 0; } + + /// Calculate value + double operator()(const DiscreteValues& values) const override; + + /// Convert into a decisiontree + DecisionTreeFactor toDecisionTreeFactor() const override; + + /// Multiply into a decisiontree + DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + + /* + * Ensure Arc-consistency by checking every possible value of domain j. + * @param j domain to be checked + * @param (in/out) domains all domains, but only domains->at(j) will be + * checked. + * @return true if domains->at(j) was changed, false otherwise. */ - class GTSAM_UNSTABLE_EXPORT Domain: public Constraint { + bool ensureArcConsistency(Key j, Domains* domains) const override; - size_t cardinality_; /// Cardinality - std::set values_; /// allowed values + /** + * Check for a value in domain that does not occur in any other connected + * domain. If found, return a a new singleton domain... + * Called in AllDiff::ensureArcConsistency + * @param keys connected domains through alldiff + * @param keys other domains + */ + boost::optional checkAllDiff(const KeyVector keys, + const Domains& domains) const; - public: + /// Partially apply known values + Constraint::shared_ptr partiallyApply(const DiscreteValues& values) const override; - typedef boost::shared_ptr shared_ptr; + /// Partially apply known values, domain version + Constraint::shared_ptr partiallyApply(const Domains& domains) const override; +}; - // Constructor on Discrete Key initializes an "all-allowed" domain - Domain(const DiscreteKey& dkey) : - Constraint(dkey.first), cardinality_(dkey.second) { - for (size_t v = 0; v < cardinality_; v++) - values_.insert(v); - } - - // Constructor on Discrete Key with single allowed value - // Consider SingleValue constraint - Domain(const DiscreteKey& dkey, size_t v) : - Constraint(dkey.first), cardinality_(dkey.second) { - values_.insert(v); - } - - /// Constructor - Domain(const Domain& other) : - Constraint(other.keys_[0]), values_(other.values_) { - } - - /// insert a value, non const :-( - void insert(size_t value) { - values_.insert(value); - } - - /// erase a value, non const :-( - void erase(size_t value) { - values_.erase(value); - } - - size_t nrValues() const { - return values_.size(); - } - - bool isSingleton() const { - return nrValues() == 1; - } - - size_t firstValue() const { - return *values_.begin(); - } - - // print - void print(const std::string& s = "", - const KeyFormatter& formatter = DefaultKeyFormatter) const override; - - /// equals - bool equals(const DiscreteFactor& other, double tol) const override { - if(!dynamic_cast(&other)) - return false; - else { - const Domain& f(static_cast(other)); - return (cardinality_==f.cardinality_) && (values_==f.values_); - } - } - - bool contains(size_t value) const { - return values_.count(value)>0; - } - - /// Calculate value - double operator()(const Values& values) const override; - - /// Convert into a decisiontree - DecisionTreeFactor toDecisionTreeFactor() const override; - - /// Multiply into a decisiontree - DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; - - /* - * Ensure Arc-consistency - * @param j domain to be checked - * @param domains all other domains - */ - bool ensureArcConsistency(size_t j, std::vector& domains) const override; - - /** - * Check for a value in domain that does not occur in any other connected domain. - * If found, we make this a singleton... Called in AllDiff::ensureArcConsistency - * @param keys connected domains through alldiff - */ - bool checkAllDiff(const KeyVector keys, std::vector& domains); - - /// Partially apply known values - Constraint::shared_ptr partiallyApply(const Values& values) const override; - - /// Partially apply known values, domain version - Constraint::shared_ptr partiallyApply( - const std::vector& domains) const override; - }; - -} // namespace gtsam +} // namespace gtsam diff --git a/gtsam_unstable/discrete/Scheduler.cpp b/gtsam_unstable/discrete/Scheduler.cpp index 3273778c4..b86df6c29 100644 --- a/gtsam_unstable/discrete/Scheduler.cpp +++ b/gtsam_unstable/discrete/Scheduler.cpp @@ -5,298 +5,268 @@ * @author Frank Dellaert */ -#include -#include #include #include +#include +#include #include - +#include #include #include -#include namespace gtsam { - using namespace std; +using namespace std; - Scheduler::Scheduler(size_t maxNrStudents, const string& filename): - maxNrStudents_(maxNrStudents) - { - typedef boost::tokenizer > Tokenizer; +Scheduler::Scheduler(size_t maxNrStudents, const string& filename) + : maxNrStudents_(maxNrStudents) { + typedef boost::tokenizer > Tokenizer; - // open file - ifstream is(filename.c_str()); - if (!is) { - cerr << "Scheduler: could not open file " << filename << endl; - throw runtime_error("Scheduler: could not open file " + filename); - } - - string line; // buffer - - // process first line with faculty - if (getline(is, line, '\r')) { - Tokenizer tok(line); - Tokenizer::iterator it = tok.begin(); - for (++it; it != tok.end(); ++it) - addFaculty(*it); - } - - // for all remaining lines - size_t count = 0; - while (getline(is, line, '\r')) { - if (count++ > 100) throw runtime_error("reached 100 lines, exiting"); - Tokenizer tok(line); - Tokenizer::iterator it = tok.begin(); - addSlot(*it++); // add slot - // add availability - for (; it != tok.end(); ++it) - available_ += (it->empty()) ? "0 " : "1 "; - available_ += '\n'; - } - } // constructor - - /** addStudent has to be called after adding slots and faculty */ - void Scheduler::addStudent(const string& studentName, - const string& area1, const string& area2, - const string& area3, const string& advisor) { - assert(nrStudents() area) const { - return area ? students_[s].keys_[*area] : students_[s].key_; + // open file + ifstream is(filename.c_str()); + if (!is) { + cerr << "Scheduler: could not open file " << filename << endl; + throw runtime_error("Scheduler: could not open file " + filename); } - const string& Scheduler::studentName(size_t i) const { - assert(i 100) throw runtime_error("reached 100 lines, exiting"); + Tokenizer tok(line); + Tokenizer::iterator it = tok.begin(); + addSlot(*it++); // add slot + // add availability + for (; it != tok.end(); ++it) available_ += (it->empty()) ? "0 " : "1 "; + available_ += '\n'; + } +} // constructor + +/** addStudent has to be called after adding slots and faculty */ +void Scheduler::addStudent(const string& studentName, const string& area1, + const string& area2, const string& area3, + const string& advisor) { + assert(nrStudents() < maxNrStudents_); + assert(facultyInArea_.count(area1)); + assert(facultyInArea_.count(area2)); + assert(facultyInArea_.count(area3)); + size_t advisorIndex = facultyIndex_[advisor]; + Student student(nrFaculty(), advisorIndex); + student.name_ = studentName; + // We fix the ordering by assigning a higher index to the student + // and numbering the areas lower + Key j = 3 * maxNrStudents_ + nrStudents(); + student.key_ = DiscreteKey(j, nrTimeSlots()); + Key base = 3 * nrStudents(); + student.keys_[0] = DiscreteKey(base + 0, nrFaculty()); + student.keys_[1] = DiscreteKey(base + 1, nrFaculty()); + student.keys_[2] = DiscreteKey(base + 2, nrFaculty()); + student.areaName_[0] = area1; + student.areaName_[1] = area2; + student.areaName_[2] = area3; + students_.push_back(student); +} + +/** get key for student and area, 0 is time slot itself */ +const DiscreteKey& Scheduler::key(size_t s, + boost::optional area) const { + return area ? students_[s].keys_[*area] : students_[s].key_; +} + +const string& Scheduler::studentName(size_t i) const { + assert(i < nrStudents()); + return students_[i].name_; +} + +const DiscreteKey& Scheduler::studentKey(size_t i) const { + assert(i < nrStudents()); + return students_[i].key_; +} + +const string& Scheduler::studentArea(size_t i, size_t area) const { + assert(i < nrStudents()); + return students_[i].areaName_[area]; +} + +/** Add student-specific constraints to the graph */ +void Scheduler::addStudentSpecificConstraints(size_t i, + boost::optional slot) { + bool debug = ISDEBUG("Scheduler::buildGraph"); + + assert(i < nrStudents()); + const Student& s = students_[i]; + + if (!slot && !slotsAvailable_.empty()) { + if (debug) cout << "Adding availability of slots" << endl; + assert(slotsAvailable_.size() == s.key_.second); + CSP::add(s.key_, slotsAvailable_); } - const string& Scheduler::studentArea(size_t i, size_t area) const { - assert(i slot) { - bool debug = ISDEBUG("Scheduler::buildGraph"); + if (debug) cout << "Area constraints " << areaName << endl; + assert(facultyInArea_[areaName].size() == areaKey.second); + CSP::add(areaKey, facultyInArea_[areaName]); - assert(i p(dummy & areaKey, + available_); // available_ is Doodle string + auto q = p.choose(dummyIndex, *slot); + CSP::add(areaKey, q); } else { - if (debug) cout << "Mutex for Students" << endl; - for (size_t i1 = 0; i1 < nrStudents(); i1++) { - // if mutexBound=1, we only mutex with next student - size_t bound = min((i1 + 1 + mutexBound), nrStudents()); - for (size_t i2 = i1 + 1; i2 < bound; i2++) { - addAllDiff(studentKey(i1), studentKey(i2)); - } + DiscreteKeys keys {s.key_, areaKey}; + CSP::add(keys, available_); // available_ is Doodle string + } + } + + // add mutex + if (debug) cout << "Mutex for faculty" << endl; + addAllDiff(s.keys_[0] & s.keys_[1] & s.keys_[2]); +} + +/** Main routine that builds factor graph */ +void Scheduler::buildGraph(size_t mutexBound) { + bool debug = ISDEBUG("Scheduler::buildGraph"); + + if (debug) cout << "Adding student-specific constraints" << endl; + for (size_t i = 0; i < nrStudents(); i++) addStudentSpecificConstraints(i); + + // special constraint for MN + if (studentName(0) == "Michael N") + CSP::add(studentKey(0), "0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1"); + + if (!mutexBound) { + DiscreteKeys dkeys; + for (const Student& s : students_) dkeys.push_back(s.key_); + addAllDiff(dkeys); + } else { + if (debug) cout << "Mutex for Students" << endl; + for (size_t i1 = 0; i1 < nrStudents(); i1++) { + // if mutexBound=1, we only mutex with next student + size_t bound = min((i1 + 1 + mutexBound), nrStudents()); + for (size_t i2 = i1 + 1; i2 < bound; i2++) { + addAllDiff(studentKey(i1), studentKey(i2)); } } - } // buildGraph - - /** print */ - void Scheduler::print(const string& s, const KeyFormatter& formatter) const { - cout << s << " Faculty:" << endl; - for(const string& name: facultyName_) - cout << name << '\n'; - cout << endl; - - cout << s << " Slots:\n"; - size_t i = 0; - for(const string& name: slotName_) - cout << i++ << " " << name << endl; - cout << endl; - - cout << "Availability:\n" << available_ << '\n'; - - cout << s << " Area constraints:\n"; - for(const FacultyInArea::value_type& it: facultyInArea_) - { - cout << setw(12) << it.first << ": "; - for(double v: it.second) - cout << v << " "; - cout << '\n'; - } - cout << endl; - - cout << s << " Students:\n"; - for (const Student& student: students_) - student.print(); - cout << endl; - - CSP::print(s + " Factor graph"); - cout << endl; - } // print - - /** Print readable form of assignment */ - void Scheduler::printAssignment(sharedValues assignment) const { - // Not intended to be general! Assumes very particular ordering ! - cout << endl; - for (size_t s = 0; s < nrStudents(); s++) { - Key j = 3*maxNrStudents_ + s; - size_t slot = assignment->at(j); - cout << studentName(s) << " slot: " << slotName_[slot] << endl; - Key base = 3*s; - for (size_t area = 0; area < 3; area++) { - size_t faculty = assignment->at(base+area); - cout << setw(12) << studentArea(s,area) << ": " << facultyName_[faculty] - << endl; - } - cout << endl; - } } +} // buildGraph - /** Special print for single-student case */ - void Scheduler::printSpecial(sharedValues assignment) const { - Values::const_iterator it = assignment->begin(); - for (size_t area = 0; area < 3; area++, it++) { - size_t f = it->second; - cout << setw(12) << studentArea(0,area) << ": " << facultyName_[f] << endl; +/** print */ +void Scheduler::print(const string& s, const KeyFormatter& formatter) const { + cout << s << " Faculty:" << endl; + for (const string& name : facultyName_) cout << name << '\n'; + cout << endl; + + cout << s << " Slots:\n"; + size_t i = 0; + for (const string& name : slotName_) cout << i++ << " " << name << endl; + cout << endl; + + cout << "Availability:\n" << available_ << '\n'; + + cout << s << " Area constraints:\n"; + for (const FacultyInArea::value_type& it : facultyInArea_) { + cout << setw(12) << it.first << ": "; + for (double v : it.second) cout << v << " "; + cout << '\n'; + } + cout << endl; + + cout << s << " Students:\n"; + for (const Student& student : students_) student.print(); + cout << endl; + + CSP::print(s + " Factor graph"); + cout << endl; +} // print + +/** Print readable form of assignment */ +void Scheduler::printAssignment(const DiscreteValues& assignment) const { + // Not intended to be general! Assumes very particular ordering ! + cout << endl; + for (size_t s = 0; s < nrStudents(); s++) { + Key j = 3 * maxNrStudents_ + s; + size_t slot = assignment.at(j); + cout << studentName(s) << " slot: " << slotName_[slot] << endl; + Key base = 3 * s; + for (size_t area = 0; area < 3; area++) { + size_t faculty = assignment.at(base + area); + cout << setw(12) << studentArea(s, area) << ": " << facultyName_[faculty] + << endl; } cout << endl; } +} - /** Accumulate faculty stats */ - void Scheduler::accumulateStats(sharedValues assignment, vector< - size_t>& stats) const { - for (size_t s = 0; s < nrStudents(); s++) { - Key base = 3*s; - for (size_t area = 0; area < 3; area++) { - size_t f = assignment->at(base+area); - assert(fsecond; + cout << setw(12) << studentArea(0, area) << ": " << facultyName_[f] << endl; } + cout << endl; +} - /** Eliminate, return a Bayes net */ - DiscreteBayesNet::shared_ptr Scheduler::eliminate() const { - gttic(my_eliminate); - // TODO: fix this!! - size_t maxKey = keys().size(); - Ordering defaultKeyOrdering; - for (size_t i = 0; ieliminateSequential(defaultKeyOrdering); - gttoc(my_eliminate); - return chordal; - } +/** Accumulate faculty stats */ +void Scheduler::accumulateStats(const DiscreteValues& assignment, + vector& stats) const { + for (size_t s = 0; s < nrStudents(); s++) { + Key base = 3 * s; + for (size_t area = 0; area < 3; area++) { + size_t f = assignment.at(base + area); + assert(f < stats.size()); + stats[f]++; + } // area + } // s +} - /** Find the best total assignment - can be expensive */ - Scheduler::sharedValues Scheduler::optimalAssignment() const { - DiscreteBayesNet::shared_ptr chordal = eliminate(); +/** Eliminate, return a Bayes net */ +DiscreteBayesNet::shared_ptr Scheduler::eliminate() const { + gttic(my_eliminate); + // TODO: fix this!! + size_t maxKey = keys().size(); + Ordering defaultKeyOrdering; + for (size_t i = 0; i < maxKey; ++i) defaultKeyOrdering += Key(i); + DiscreteBayesNet::shared_ptr chordal = + this->eliminateSequential(defaultKeyOrdering); + gttoc(my_eliminate); + return chordal; +} - if (ISDEBUG("Scheduler::optimalAssignment")) { - DiscreteBayesNet::const_iterator it = chordal->end()-1; - const Student & student = students_.front(); - cout << endl; - (*it)->print(student.name_); - } - - gttic(my_optimize); - sharedValues mpe = chordal->optimize(); - gttoc(my_optimize); - return mpe; - } - - /** find the assignment of students to slots with most possible committees */ - Scheduler::sharedValues Scheduler::bestSchedule() const { - sharedValues best; - throw runtime_error("bestSchedule not implemented"); - return best; - } - - /** find the corresponding most desirable committee assignment */ - Scheduler::sharedValues Scheduler::bestAssignment( - sharedValues bestSchedule) const { - sharedValues best; - throw runtime_error("bestAssignment not implemented"); - return best; - } - -} // gtsam +/** find the assignment of students to slots with most possible committees */ +DiscreteValues Scheduler::bestSchedule() const { + DiscreteValues best; + throw runtime_error("bestSchedule not implemented"); + return best; +} +/** find the corresponding most desirable committee assignment */ +DiscreteValues Scheduler::bestAssignment(const DiscreteValues& bestSchedule) const { + DiscreteValues best; + throw runtime_error("bestAssignment not implemented"); + return best; +} +} // namespace gtsam diff --git a/gtsam_unstable/discrete/Scheduler.h b/gtsam_unstable/discrete/Scheduler.h index 6faf9956f..8d269e81a 100644 --- a/gtsam_unstable/discrete/Scheduler.h +++ b/gtsam_unstable/discrete/Scheduler.h @@ -8,168 +8,151 @@ #pragma once #include +#include namespace gtsam { +/** + * Scheduler class + * Creates one variable for each student, and three variables for each + * of the student's areas, for a total of 4*nrStudents variables. + * The "student" variable will determine when the student takes the qual. + * The "area" variables determine which faculty are on his/her committee. + */ +class GTSAM_UNSTABLE_EXPORT Scheduler : public CSP { + private: + /** Internal data structure for students */ + struct Student { + std::string name_; + DiscreteKey key_; // key for student + std::vector keys_; // key for areas + std::vector areaName_; + std::vector advisor_; + Student(size_t nrFaculty, size_t advisorIndex) + : keys_(3), areaName_(3), advisor_(nrFaculty, 1.0) { + advisor_[advisorIndex] = 0.0; + } + void print() const { + using std::cout; + cout << name_ << ": "; + for (size_t area = 0; area < 3; area++) cout << areaName_[area] << " "; + cout << std::endl; + } + }; + + /** Maximum number of students */ + size_t maxNrStudents_; + + /** discrete keys, indexed by student and area index */ + std::vector students_; + + /** faculty identifiers */ + std::map facultyIndex_; + std::vector facultyName_, slotName_, areaName_; + + /** area constraints */ + typedef std::map > FacultyInArea; + FacultyInArea facultyInArea_; + + /** nrTimeSlots * nrFaculty availability constraints */ + std::string available_; + + /** which slots are good */ + std::vector slotsAvailable_; + + public: /** - * Scheduler class - * Creates one variable for each student, and three variables for each - * of the student's areas, for a total of 4*nrStudents variables. - * The "student" variable will determine when the student takes the qual. - * The "area" variables determine which faculty are on his/her committee. + * Constructor + * We need to know the number of students in advance for ordering keys. + * then add faculty, slots, areas, availability, students, in that order */ - class GTSAM_UNSTABLE_EXPORT Scheduler : public CSP { + Scheduler(size_t maxNrStudents) : maxNrStudents_(maxNrStudents) {} - private: + /// Destructor + virtual ~Scheduler() {} - /** Internal data structure for students */ - struct Student { - std::string name_; - DiscreteKey key_; // key for student - std::vector keys_; // key for areas - std::vector areaName_; - std::vector advisor_; - Student(size_t nrFaculty, size_t advisorIndex) : - keys_(3), areaName_(3), advisor_(nrFaculty, 1.0) { - advisor_[advisorIndex] = 0.0; - } - void print() const { - using std::cout; - cout << name_ << ": "; - for (size_t area = 0; area < 3; area++) - cout << areaName_[area] << " "; - cout << std::endl; - } - }; + void addFaculty(const std::string& facultyName) { + facultyIndex_[facultyName] = nrFaculty(); + facultyName_.push_back(facultyName); + } - /** Maximum number of students */ - size_t maxNrStudents_; + size_t nrFaculty() const { return facultyName_.size(); } - /** discrete keys, indexed by student and area index */ - std::vector students_; + /** boolean std::string of nrTimeSlots * nrFaculty */ + void setAvailability(const std::string& available) { available_ = available; } - /** faculty identifiers */ - std::map facultyIndex_; - std::vector facultyName_, slotName_, areaName_; + void addSlot(const std::string& slotName) { slotName_.push_back(slotName); } - /** area constraints */ - typedef std::map > FacultyInArea; - FacultyInArea facultyInArea_; + size_t nrTimeSlots() const { return slotName_.size(); } - /** nrTimeSlots * nrFaculty availability constraints */ - std::string available_; + const std::string& slotName(size_t s) const { return slotName_[s]; } - /** which slots are good */ - std::vector slotsAvailable_; + /** slots available, boolean */ + void setSlotsAvailable(const std::vector& slotsAvailable) { + slotsAvailable_ = slotsAvailable; + } - public: + void addArea(const std::string& facultyName, const std::string& areaName) { + areaName_.push_back(areaName); + std::vector& table = + facultyInArea_[areaName]; // will create if needed + if (table.empty()) table.resize(nrFaculty(), 0); + table[facultyIndex_[facultyName]] = 1; + } - /** - * Constructor - * We need to know the number of students in advance for ordering keys. - * then add faculty, slots, areas, availability, students, in that order - */ - Scheduler(size_t maxNrStudents) : maxNrStudents_(maxNrStudents) {} + /** + * Constructor that reads in faculty, slots, availibility. + * Still need to add areas and students after this + */ + Scheduler(size_t maxNrStudents, const std::string& filename); - /// Destructor - virtual ~Scheduler() {} + /** get key for student and area, 0 is time slot itself */ + const DiscreteKey& key(size_t s, + boost::optional area = boost::none) const; - void addFaculty(const std::string& facultyName) { - facultyIndex_[facultyName] = nrFaculty(); - facultyName_.push_back(facultyName); - } + /** addStudent has to be called after adding slots and faculty */ + void addStudent(const std::string& studentName, const std::string& area1, + const std::string& area2, const std::string& area3, + const std::string& advisor); - size_t nrFaculty() const { - return facultyName_.size(); - } + /// current number of students + size_t nrStudents() const { return students_.size(); } - /** boolean std::string of nrTimeSlots * nrFaculty */ - void setAvailability(const std::string& available) { - available_ = available; - } + const std::string& studentName(size_t i) const; + const DiscreteKey& studentKey(size_t i) const; + const std::string& studentArea(size_t i, size_t area) const; - void addSlot(const std::string& slotName) { - slotName_.push_back(slotName); - } + /** Add student-specific constraints to the graph */ + void addStudentSpecificConstraints( + size_t i, boost::optional slot = boost::none); - size_t nrTimeSlots() const { - return slotName_.size(); - } + /** Main routine that builds factor graph */ + void buildGraph(size_t mutexBound = 7); - const std::string& slotName(size_t s) const { - return slotName_[s]; - } + /** print */ + void print( + const std::string& s = "Scheduler", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; - /** slots available, boolean */ - void setSlotsAvailable(const std::vector& slotsAvailable) { - slotsAvailable_ = slotsAvailable; - } + /** Print readable form of assignment */ + void printAssignment(const DiscreteValues& assignment) const; - void addArea(const std::string& facultyName, const std::string& areaName) { - areaName_.push_back(areaName); - std::vector& table = facultyInArea_[areaName]; // will create if needed - if (table.empty()) table.resize(nrFaculty(), 0); - table[facultyIndex_[facultyName]] = 1; - } + /** Special print for single-student case */ + void printSpecial(const DiscreteValues& assignment) const; - /** - * Constructor that reads in faculty, slots, availibility. - * Still need to add areas and students after this - */ - Scheduler(size_t maxNrStudents, const std::string& filename); + /** Accumulate faculty stats */ + void accumulateStats(const DiscreteValues& assignment, + std::vector& stats) const; - /** get key for student and area, 0 is time slot itself */ - const DiscreteKey& key(size_t s, boost::optional area = boost::none) const; + /** Eliminate, return a Bayes net */ + DiscreteBayesNet::shared_ptr eliminate() const; - /** addStudent has to be called after adding slots and faculty */ - void addStudent(const std::string& studentName, const std::string& area1, - const std::string& area2, const std::string& area3, - const std::string& advisor); + /** find the assignment of students to slots with most possible committees */ + DiscreteValues bestSchedule() const; - /// current number of students - size_t nrStudents() const { - return students_.size(); - } - - const std::string& studentName(size_t i) const; - const DiscreteKey& studentKey(size_t i) const; - const std::string& studentArea(size_t i, size_t area) const; - - /** Add student-specific constraints to the graph */ - void addStudentSpecificConstraints(size_t i, boost::optional slot = boost::none); - - /** Main routine that builds factor graph */ - void buildGraph(size_t mutexBound = 7); - - /** print */ - void print( - const std::string& s = "Scheduler", - const KeyFormatter& formatter = DefaultKeyFormatter) const override; - - /** Print readable form of assignment */ - void printAssignment(sharedValues assignment) const; - - /** Special print for single-student case */ - void printSpecial(sharedValues assignment) const; - - /** Accumulate faculty stats */ - void accumulateStats(sharedValues assignment, - std::vector& stats) const; - - /** Eliminate, return a Bayes net */ - DiscreteBayesNet::shared_ptr eliminate() const; - - /** Find the best total assignment - can be expensive */ - sharedValues optimalAssignment() const; - - /** find the assignment of students to slots with most possible committees */ - sharedValues bestSchedule() const; - - /** find the corresponding most desirable committee assignment */ - sharedValues bestAssignment(sharedValues bestSchedule) const; - - }; // Scheduler - -} // gtsam + /** find the corresponding most desirable committee assignment */ + DiscreteValues bestAssignment(const DiscreteValues& bestSchedule) const; +}; // Scheduler +} // namespace gtsam diff --git a/gtsam_unstable/discrete/SingleValue.cpp b/gtsam_unstable/discrete/SingleValue.cpp index 6324f14cd..6dd81a7dc 100644 --- a/gtsam_unstable/discrete/SingleValue.cpp +++ b/gtsam_unstable/discrete/SingleValue.cpp @@ -5,75 +5,73 @@ * @author Frank Dellaert */ -#include -#include -#include #include +#include +#include +#include + #include namespace gtsam { - using namespace std; - - /* ************************************************************************* */ - void SingleValue::print(const string& s, - const KeyFormatter& formatter) const { - cout << s << "SingleValue on " << "j=" << formatter(keys_[0]) - << " with value " << value_ << endl; - } - - /* ************************************************************************* */ - double SingleValue::operator()(const Values& values) const { - return (double) (values.at(keys_[0]) == value_); - } - - /* ************************************************************************* */ - DecisionTreeFactor SingleValue::toDecisionTreeFactor() const { - DiscreteKeys keys; - keys += DiscreteKey(keys_[0],cardinality_); - vector table; - for (size_t i1 = 0; i1 < cardinality_; i1++) - table.push_back(i1 == value_); - DecisionTreeFactor converted(keys, table); - return converted; - } - - /* ************************************************************************* */ - DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const { - // TODO: can we do this more efficiently? - return toDecisionTreeFactor() * f; - } - - /* ************************************************************************* */ - bool SingleValue::ensureArcConsistency(size_t j, - vector& domains) const { - if (j != keys_[0]) throw invalid_argument( - "SingleValue check on wrong domain"); - Domain& D = domains[j]; - if (D.isSingleton()) { - if (D.firstValue() != value_) throw runtime_error("Unsatisfiable"); - return false; - } - D = Domain(discreteKey(),value_); - return true; - } - - /* ************************************************************************* */ - Constraint::shared_ptr SingleValue::partiallyApply(const Values& values) const { - Values::const_iterator it = values.find(keys_[0]); - if (it != values.end() && it->second != value_) throw runtime_error( - "SingleValue::partiallyApply: unsatisfiable"); - return boost::make_shared < SingleValue > (keys_[0], cardinality_, value_); - } - - /* ************************************************************************* */ - Constraint::shared_ptr SingleValue::partiallyApply( - const vector& domains) const { - const Domain& Dk = domains[keys_[0]]; - if (Dk.isSingleton() && !Dk.contains(value_)) throw runtime_error( - "SingleValue::partiallyApply: unsatisfiable"); - return boost::make_shared < SingleValue > (discreteKey(), value_); - } +using namespace std; /* ************************************************************************* */ -} // namespace gtsam +void SingleValue::print(const string& s, const KeyFormatter& formatter) const { + cout << s << "SingleValue on " + << "j=" << formatter(keys_[0]) << " with value " << value_ << endl; +} + +/* ************************************************************************* */ +double SingleValue::operator()(const DiscreteValues& values) const { + return (double)(values.at(keys_[0]) == value_); +} + +/* ************************************************************************* */ +DecisionTreeFactor SingleValue::toDecisionTreeFactor() const { + DiscreteKeys keys; + keys += DiscreteKey(keys_[0], cardinality_); + vector table; + for (size_t i1 = 0; i1 < cardinality_; i1++) table.push_back(i1 == value_); + DecisionTreeFactor converted(keys, table); + return converted; +} + +/* ************************************************************************* */ +DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const { + // TODO: can we do this more efficiently? + return toDecisionTreeFactor() * f; +} + +/* ************************************************************************* */ +bool SingleValue::ensureArcConsistency(Key j, Domains* domains) const { + if (j != keys_[0]) + throw invalid_argument("SingleValue check on wrong domain"); + Domain& D = domains->at(j); + if (D.isSingleton()) { + if (D.firstValue() != value_) throw runtime_error("Unsatisfiable"); + return false; + } + D = Domain(discreteKey(), value_); + return true; +} + +/* ************************************************************************* */ +Constraint::shared_ptr SingleValue::partiallyApply(const DiscreteValues& values) const { + DiscreteValues::const_iterator it = values.find(keys_[0]); + if (it != values.end() && it->second != value_) + throw runtime_error("SingleValue::partiallyApply: unsatisfiable"); + return boost::make_shared(keys_[0], cardinality_, value_); +} + +/* ************************************************************************* */ +Constraint::shared_ptr SingleValue::partiallyApply( + const Domains& domains) const { + const Domain& Dk = domains.at(keys_[0]); + if (Dk.isSingleton() && !Dk.contains(value_)) + throw runtime_error("SingleValue::partiallyApply: unsatisfiable"); + return boost::make_shared(discreteKey(), value_); +} + +/* ************************************************************************* */ +} // namespace gtsam diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h index c4d2addec..3b2d6e80b 100644 --- a/gtsam_unstable/discrete/SingleValue.h +++ b/gtsam_unstable/discrete/SingleValue.h @@ -7,76 +7,71 @@ #pragma once -#include #include +#include namespace gtsam { - /** - * SingleValue constraint +/** + * SingleValue constraint: ensures a variable takes on a certain value. + * This could of course also be implemented by changing its `Domain`. + */ +class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { + size_t cardinality_; /// < Number of values + size_t value_; ///< allowed value + + DiscreteKey discreteKey() const { + return DiscreteKey(keys_[0], cardinality_); + } + + public: + typedef boost::shared_ptr shared_ptr; + + /// Construct from key, cardinality, and given value. + SingleValue(Key key, size_t n, size_t value) + : Constraint(key), cardinality_(n), value_(value) {} + + /// Construct from DiscreteKey and given value. + SingleValue(const DiscreteKey& dkey, size_t value) + : Constraint(dkey.first), cardinality_(dkey.second), value_(value) {} + + // print + void print(const std::string& s = "", const KeyFormatter& formatter = + DefaultKeyFormatter) const override; + + /// equals + bool equals(const DiscreteFactor& other, double tol) const override { + if (!dynamic_cast(&other)) + return false; + else { + const SingleValue& f(static_cast(other)); + return (cardinality_ == f.cardinality_) && (value_ == f.value_); + } + } + + /// Calculate value + double operator()(const DiscreteValues& values) const override; + + /// Convert into a decisiontree + DecisionTreeFactor toDecisionTreeFactor() const override; + + /// Multiply into a decisiontree + DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + + /* + * Ensure Arc-consistency: just sets domain[j] to {value_}. + * @param j domain to be checked + * @param (in/out) domains all domains, but only domains->at(j) will be checked. + * @return true if domains->at(j) was changed, false otherwise. */ - class GTSAM_UNSTABLE_EXPORT SingleValue: public Constraint { + bool ensureArcConsistency(Key j, Domains* domains) const override; - /// Number of values - size_t cardinality_; + /// Partially apply known values + Constraint::shared_ptr partiallyApply(const DiscreteValues& values) const override; - /// allowed value - size_t value_; + /// Partially apply known values, domain version + Constraint::shared_ptr partiallyApply( + const Domains& domains) const override; +}; - DiscreteKey discreteKey() const { - return DiscreteKey(keys_[0],cardinality_); - } - - public: - - typedef boost::shared_ptr shared_ptr; - - /// Constructor - SingleValue(Key key, size_t n, size_t value) : - Constraint(key), cardinality_(n), value_(value) { - } - - /// Constructor - SingleValue(const DiscreteKey& dkey, size_t value) : - Constraint(dkey.first), cardinality_(dkey.second), value_(value) { - } - - // print - void print(const std::string& s = "", - const KeyFormatter& formatter = DefaultKeyFormatter) const override; - - /// equals - bool equals(const DiscreteFactor& other, double tol) const override { - if(!dynamic_cast(&other)) - return false; - else { - const SingleValue& f(static_cast(other)); - return (cardinality_==f.cardinality_) && (value_==f.value_); - } - } - - /// Calculate value - double operator()(const Values& values) const override; - - /// Convert into a decisiontree - DecisionTreeFactor toDecisionTreeFactor() const override; - - /// Multiply into a decisiontree - DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; - - /* - * Ensure Arc-consistency - * @param j domain to be checked - * @param domains all other domains - */ - bool ensureArcConsistency(size_t j, std::vector& domains) const override; - - /// Partially apply known values - Constraint::shared_ptr partiallyApply(const Values& values) const override; - - /// Partially apply known values, domain version - Constraint::shared_ptr partiallyApply( - const std::vector& domains) const override; - }; - -} // namespace gtsam +} // namespace gtsam diff --git a/gtsam_unstable/discrete/examples/schedulingExample.cpp b/gtsam_unstable/discrete/examples/schedulingExample.cpp index e9f63b2d8..487edc97a 100644 --- a/gtsam_unstable/discrete/examples/schedulingExample.cpp +++ b/gtsam_unstable/discrete/examples/schedulingExample.cpp @@ -115,14 +115,14 @@ void runLargeExample() { // Do brute force product and output that to file if (scheduler.nrStudents() == 1) { // otherwise too slow DecisionTreeFactor product = scheduler.product(); - product.dot("scheduling-large", false); + product.dot("scheduling-large", DefaultKeyFormatter, false); } // Do exact inference // SETDEBUG("timing-verbose", true); SETDEBUG("DiscreteConditional::DiscreteConditional", true); gttic(large); - DiscreteFactor::sharedValues MPE = scheduler.optimalAssignment(); + auto MPE = scheduler.optimize(); gttoc(large); tictoc_finishedIteration(); tictoc_print(); @@ -165,11 +165,11 @@ void solveStaged(size_t addMutex = 2) { root->print(""/*scheduler.studentName(s)*/); // solve root node only - Scheduler::Values values; - size_t bestSlot = root->solve(values); + size_t bestSlot = root->argmax(); // get corresponding count DiscreteKey dkey = scheduler.studentKey(6 - s); + DiscreteValues values; values[dkey.first] = bestSlot; size_t count = (*root)(values); @@ -225,7 +225,7 @@ void sampleSolutions() { // now, sample schedules for (size_t n = 0; n < 500; n++) { vector stats(19, 0); - vector samples; + vector samples; for (size_t i = 0; i < 7; i++) { samples.push_back(samplers[i]->sample()); schedulers[i].accumulateStats(samples[i], stats); @@ -319,11 +319,11 @@ void accomodateStudent() { // GTSAM_PRINT(*chordal); // solve root node only - Scheduler::Values values; - size_t bestSlot = root->solve(values); + size_t bestSlot = root->argmax(); // get corresponding count DiscreteKey dkey = scheduler.studentKey(0); + DiscreteValues values; values[dkey.first] = bestSlot; size_t count = (*root)(values); cout << boost::format("%s = %d (%d), count = %d") % scheduler.studentName(0) @@ -331,7 +331,7 @@ void accomodateStudent() { // sample schedules for (size_t n = 0; n < 10; n++) { - Scheduler::sharedValues sample0 = chordal->sample(); + auto sample0 = chordal->sample(); scheduler.printAssignment(sample0); } } diff --git a/gtsam_unstable/discrete/examples/schedulingQuals12.cpp b/gtsam_unstable/discrete/examples/schedulingQuals12.cpp index 1fc4a1459..830d59ba7 100644 --- a/gtsam_unstable/discrete/examples/schedulingQuals12.cpp +++ b/gtsam_unstable/discrete/examples/schedulingQuals12.cpp @@ -115,7 +115,7 @@ void runLargeExample() { // Do brute force product and output that to file if (scheduler.nrStudents() == 1) { // otherwise too slow DecisionTreeFactor product = scheduler.product(); - product.dot("scheduling-large", false); + product.dot("scheduling-large", DefaultKeyFormatter, false); } // Do exact inference @@ -129,7 +129,7 @@ void runLargeExample() { tictoc_finishedIteration(); tictoc_print(); for (size_t i=0;i<100;i++) { - DiscreteFactor::sharedValues assignment = chordal->sample(); + auto assignment = chordal->sample(); vector stats(scheduler.nrFaculty()); scheduler.accumulateStats(assignment, stats); size_t max = *max_element(stats.begin(), stats.end()); @@ -143,7 +143,7 @@ void runLargeExample() { } #else gttic(large); - DiscreteFactor::sharedValues MPE = scheduler.optimalAssignment(); + auto MPE = scheduler.optimize(); gttoc(large); tictoc_finishedIteration(); tictoc_print(); @@ -190,11 +190,11 @@ void solveStaged(size_t addMutex = 2) { root->print(""/*scheduler.studentName(s)*/); // solve root node only - Scheduler::Values values; - size_t bestSlot = root->solve(values); + size_t bestSlot = root->argmax(); // get corresponding count DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s); + DiscreteValues values; values[dkey.first] = bestSlot; size_t count = (*root)(values); @@ -234,7 +234,7 @@ void sampleSolutions() { // now, sample schedules for (size_t n = 0; n < 500; n++) { vector stats(19, 0); - vector samples; + vector samples; for (size_t i = 0; i < NRSTUDENTS; i++) { samples.push_back(samplers[i]->sample()); schedulers[i].accumulateStats(samples[i], stats); diff --git a/gtsam_unstable/discrete/examples/schedulingQuals13.cpp b/gtsam_unstable/discrete/examples/schedulingQuals13.cpp index 95b64f289..b24f9bf0a 100644 --- a/gtsam_unstable/discrete/examples/schedulingQuals13.cpp +++ b/gtsam_unstable/discrete/examples/schedulingQuals13.cpp @@ -139,7 +139,7 @@ void runLargeExample() { // Do brute force product and output that to file if (scheduler.nrStudents() == 1) { // otherwise too slow DecisionTreeFactor product = scheduler.product(); - product.dot("scheduling-large", false); + product.dot("scheduling-large", DefaultKeyFormatter, false); } // Do exact inference @@ -153,7 +153,7 @@ void runLargeExample() { tictoc_finishedIteration(); tictoc_print(); for (size_t i=0;i<100;i++) { - DiscreteFactor::sharedValues assignment = sample(*chordal); + auto assignment = sample(*chordal); vector stats(scheduler.nrFaculty()); scheduler.accumulateStats(assignment, stats); size_t max = *max_element(stats.begin(), stats.end()); @@ -167,7 +167,7 @@ void runLargeExample() { } #else gttic(large); - DiscreteFactor::sharedValues MPE = scheduler.optimalAssignment(); + auto MPE = scheduler.optimize(); gttoc(large); tictoc_finishedIteration(); tictoc_print(); @@ -212,11 +212,11 @@ void solveStaged(size_t addMutex = 2) { root->print(""/*scheduler.studentName(s)*/); // solve root node only - Scheduler::Values values; - size_t bestSlot = root->solve(values); + size_t bestSlot = root->argmax(); // get corresponding count DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s); + DiscreteValues values; values[dkey.first] = bestSlot; double count = (*root)(values); @@ -259,7 +259,7 @@ void sampleSolutions() { // now, sample schedules for (size_t n = 0; n < 10000; n++) { vector stats(nrFaculty, 0); - vector samples; + vector samples; for (size_t i = 0; i < NRSTUDENTS; i++) { samples.push_back(samplers[i]->sample()); schedulers[i].accumulateStats(samples[i], stats); diff --git a/gtsam_unstable/discrete/tests/testCSP.cpp b/gtsam_unstable/discrete/tests/testCSP.cpp index 3dd493b1b..fb386b255 100644 --- a/gtsam_unstable/discrete/tests/testCSP.cpp +++ b/gtsam_unstable/discrete/tests/testCSP.cpp @@ -7,59 +7,119 @@ #include #include + #include using boost::assign::insert; #include -#include + #include +#include using namespace std; using namespace gtsam; /* ************************************************************************* */ -TEST_UNSAFE( BinaryAllDif, allInOne) -{ - // Create keys and ordering +TEST(CSP, SingleValue) { + // Create keys for Idaho, Arizona, and Utah, allowing two colors for each: + size_t nrColors = 3; + DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); + + // Check that a single value is equal to a decision stump with only one "1": + SingleValue singleValue(AZ, 2); + DecisionTreeFactor f1(AZ, "0 0 1"); + EXPECT(assert_equal(f1, singleValue.toDecisionTreeFactor())); + + // Create domains + Domains domains; + domains.emplace(0, Domain(ID)); + domains.emplace(1, Domain(AZ)); + domains.emplace(2, Domain(UT)); + + // Ensure arc-consistency: just wipes out values in AZ domain: + EXPECT(singleValue.ensureArcConsistency(1, &domains)); + LONGS_EQUAL(3, domains.at(0).nrValues()); + LONGS_EQUAL(1, domains.at(1).nrValues()); + LONGS_EQUAL(3, domains.at(2).nrValues()); +} + +/* ************************************************************************* */ +TEST(CSP, BinaryAllDif) { + // Create keys for Idaho, Arizona, and Utah, allowing 2 colors for each: size_t nrColors = 2; -// DiscreteKey ID("Idaho", nrColors), UT("Utah", nrColors), AZ("Arizona", nrColors); - DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors); + DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); // Check construction and conversion BinaryAllDiff c1(ID, UT); DecisionTreeFactor f1(ID & UT, "0 1 1 0"); - EXPECT(assert_equal(f1,c1.toDecisionTreeFactor())); + EXPECT(assert_equal(f1, c1.toDecisionTreeFactor())); // Check construction and conversion BinaryAllDiff c2(UT, AZ); DecisionTreeFactor f2(UT & AZ, "0 1 1 0"); - EXPECT(assert_equal(f2,c2.toDecisionTreeFactor())); + EXPECT(assert_equal(f2, c2.toDecisionTreeFactor())); - DecisionTreeFactor f3 = f1*f2; - EXPECT(assert_equal(f3,c1*f2)); - EXPECT(assert_equal(f3,c2*f1)); + // Check multiplication of factors with constraint: + DecisionTreeFactor f3 = f1 * f2; + EXPECT(assert_equal(f3, c1 * f2)); + EXPECT(assert_equal(f3, c2 * f1)); } /* ************************************************************************* */ -TEST_UNSAFE( CSP, allInOne) -{ - // Create keys and ordering +TEST(CSP, AllDiff) { + // Create keys for Idaho, Arizona, and Utah, allowing two colors for each: + size_t nrColors = 3; + DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); + + // Check construction and conversion + vector dkeys{ID, UT, AZ}; + AllDiff alldiff(dkeys); + DecisionTreeFactor actual = alldiff.toDecisionTreeFactor(); + // GTSAM_PRINT(actual); + actual.dot("actual"); + DecisionTreeFactor f2( + ID & AZ & UT, + "0 0 0 0 0 1 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 1 0 0 0 0 0"); + EXPECT(assert_equal(f2, actual)); + + // Create domains. + Domains domains; + domains.emplace(0, Domain(ID)); + domains.emplace(1, Domain(AZ)); + domains.emplace(2, Domain(UT)); + + // First constrict AZ domain: + SingleValue singleValue(AZ, 2); + EXPECT(singleValue.ensureArcConsistency(1, &domains)); + + // Arc-consistency + EXPECT(alldiff.ensureArcConsistency(0, &domains)); + EXPECT(!alldiff.ensureArcConsistency(1, &domains)); + EXPECT(alldiff.ensureArcConsistency(2, &domains)); + LONGS_EQUAL(2, domains.at(0).nrValues()); + LONGS_EQUAL(1, domains.at(1).nrValues()); + LONGS_EQUAL(2, domains.at(2).nrValues()); +} + +/* ************************************************************************* */ +TEST(CSP, allInOne) { + // Create keys for Idaho, Arizona, and Utah, allowing 3 colors for each: size_t nrColors = 2; - DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors); + DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); // Create the CSP CSP csp; - csp.addAllDiff(ID,UT); - csp.addAllDiff(UT,AZ); + csp.addAllDiff(ID, UT); + csp.addAllDiff(UT, AZ); // Check an invalid combination, with ID==UT==AZ all same color - DiscreteFactor::Values invalid; + DiscreteValues invalid; invalid[ID.first] = 0; invalid[UT.first] = 0; invalid[AZ.first] = 0; EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9); // Check a valid combination - DiscreteFactor::Values valid; + DiscreteValues valid; valid[ID.first] = 0; valid[UT.first] = 1; valid[AZ.first] = 0; @@ -69,68 +129,62 @@ TEST_UNSAFE( CSP, allInOne) DecisionTreeFactor product = csp.product(); // product.dot("product"); DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0"); - EXPECT(assert_equal(expectedProduct,product)); + EXPECT(assert_equal(expectedProduct, product)); // Solve - CSP::sharedValues mpe = csp.optimalAssignment(); - CSP::Values expected; + auto mpe = csp.optimize(); + DiscreteValues expected; insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 1); - EXPECT(assert_equal(expected,*mpe)); - EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); + EXPECT(assert_equal(expected, mpe)); + EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9); } /* ************************************************************************* */ -TEST_UNSAFE( CSP, WesternUS) -{ - // Create keys +TEST(CSP, WesternUS) { + // Create keys for all states in Western US, with 4 color possibilities. size_t nrColors = 4; - DiscreteKey - // Create ordering according to example in ND-CSP.lyx - WA(0, nrColors), OR(3, nrColors), CA(1, nrColors),NV(2, nrColors), - ID(8, nrColors), UT(9, nrColors), AZ(10, nrColors), - MT(4, nrColors), WY(5, nrColors), CO(7, nrColors), NM(6, nrColors); + DiscreteKey WA(0, nrColors), OR(3, nrColors), CA(1, nrColors), + NV(2, nrColors), ID(8, nrColors), UT(9, nrColors), AZ(10, nrColors), + MT(4, nrColors), WY(5, nrColors), CO(7, nrColors), NM(6, nrColors); // Create the CSP CSP csp; - csp.addAllDiff(WA,ID); - csp.addAllDiff(WA,OR); - csp.addAllDiff(OR,ID); - csp.addAllDiff(OR,CA); - csp.addAllDiff(OR,NV); - csp.addAllDiff(CA,NV); - csp.addAllDiff(CA,AZ); - csp.addAllDiff(ID,MT); - csp.addAllDiff(ID,WY); - csp.addAllDiff(ID,UT); - csp.addAllDiff(ID,NV); - csp.addAllDiff(NV,UT); - csp.addAllDiff(NV,AZ); - csp.addAllDiff(UT,WY); - csp.addAllDiff(UT,CO); - csp.addAllDiff(UT,NM); - csp.addAllDiff(UT,AZ); - csp.addAllDiff(AZ,CO); - csp.addAllDiff(AZ,NM); - csp.addAllDiff(MT,WY); - csp.addAllDiff(WY,CO); - csp.addAllDiff(CO,NM); + csp.addAllDiff(WA, ID); + csp.addAllDiff(WA, OR); + csp.addAllDiff(OR, ID); + csp.addAllDiff(OR, CA); + csp.addAllDiff(OR, NV); + csp.addAllDiff(CA, NV); + csp.addAllDiff(CA, AZ); + csp.addAllDiff(ID, MT); + csp.addAllDiff(ID, WY); + csp.addAllDiff(ID, UT); + csp.addAllDiff(ID, NV); + csp.addAllDiff(NV, UT); + csp.addAllDiff(NV, AZ); + csp.addAllDiff(UT, WY); + csp.addAllDiff(UT, CO); + csp.addAllDiff(UT, NM); + csp.addAllDiff(UT, AZ); + csp.addAllDiff(AZ, CO); + csp.addAllDiff(AZ, NM); + csp.addAllDiff(MT, WY); + csp.addAllDiff(WY, CO); + csp.addAllDiff(CO, NM); - // Solve + DiscreteValues mpe; + insert(mpe)(0, 2)(1, 3)(2, 2)(3, 1)(4, 1)(5, 3)(6, 3)(7, 2)(8, 0)(9, 1)(10, 0); + + // Create ordering according to example in ND-CSP.lyx Ordering ordering; - ordering += Key(0),Key(1),Key(2),Key(3),Key(4),Key(5),Key(6),Key(7),Key(8),Key(9),Key(10); - CSP::sharedValues mpe = csp.optimalAssignment(ordering); - // GTSAM_PRINT(*mpe); - CSP::Values expected; - insert(expected) - (WA.first,1)(CA.first,1)(NV.first,3)(OR.first,0) - (MT.first,1)(WY.first,0)(NM.first,3)(CO.first,2) - (ID.first,2)(UT.first,1)(AZ.first,0); + ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7), + Key(8), Key(9), Key(10); - // TODO: Fix me! mpe result seems to be right. (See the printing) - // It has the same prob as the expected solution. - // Is mpe another solution, or the expected solution is unique??? - EXPECT(assert_equal(expected,*mpe)); - EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); + // Solve using that ordering: + auto actualMPE = csp.optimize(ordering); + + EXPECT(assert_equal(mpe, actualMPE)); + EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9); // Write out the dual graph for hmetis #ifdef DUAL @@ -142,85 +196,74 @@ TEST_UNSAFE( CSP, WesternUS) } /* ************************************************************************* */ -TEST_UNSAFE( CSP, AllDiff) -{ - // Create keys and ordering +TEST(CSP, ArcConsistency) { + // Create keys for Idaho, Arizona, and Utah, allowing three colors for each: size_t nrColors = 3; - DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors); + DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); - // Create the CSP + // Create the CSP using just one all-diff constraint, plus constrain Arizona. CSP csp; - vector dkeys; - dkeys += ID,UT,AZ; + vector dkeys{ID, UT, AZ}; csp.addAllDiff(dkeys); - csp.addSingleValue(AZ,2); -// GTSAM_PRINT(csp); - - // Check construction and conversion - SingleValue s(AZ,2); - DecisionTreeFactor f1(AZ,"0 0 1"); - EXPECT(assert_equal(f1,s.toDecisionTreeFactor())); - - // Check construction and conversion - AllDiff alldiff(dkeys); - DecisionTreeFactor actual = alldiff.toDecisionTreeFactor(); -// GTSAM_PRINT(actual); -// actual.dot("actual"); - DecisionTreeFactor f2(ID & AZ & UT, - "0 0 0 0 0 1 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 1 0 0 0 0 0"); - EXPECT(assert_equal(f2,actual)); + csp.addSingleValue(AZ, 2); + // GTSAM_PRINT(csp); // Check an invalid combination, with ID==UT==AZ all same color - DiscreteFactor::Values invalid; + DiscreteValues invalid; invalid[ID.first] = 0; invalid[UT.first] = 1; invalid[AZ.first] = 0; EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9); // Check a valid combination - DiscreteFactor::Values valid; + DiscreteValues valid; valid[ID.first] = 0; valid[UT.first] = 1; valid[AZ.first] = 2; EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9); // Solve - CSP::sharedValues mpe = csp.optimalAssignment(); - CSP::Values expected; + auto mpe = csp.optimize(); + DiscreteValues expected; insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 2); - EXPECT(assert_equal(expected,*mpe)); - EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); + EXPECT(assert_equal(expected, mpe)); + EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9); - // Arc-consistency - vector domains; - domains += Domain(ID), Domain(AZ), Domain(UT); - SingleValue singleValue(AZ,2); - EXPECT(singleValue.ensureArcConsistency(1,domains)); - EXPECT(alldiff.ensureArcConsistency(0,domains)); - EXPECT(!alldiff.ensureArcConsistency(1,domains)); - EXPECT(alldiff.ensureArcConsistency(2,domains)); - LONGS_EQUAL(2,domains[0].nrValues()); - LONGS_EQUAL(1,domains[1].nrValues()); - LONGS_EQUAL(2,domains[2].nrValues()); + // ensure arc-consistency, i.e., narrow domains... + Domains domains; + domains.emplace(0, Domain(ID)); + domains.emplace(1, Domain(AZ)); + domains.emplace(2, Domain(UT)); + + SingleValue singleValue(AZ, 2); + AllDiff alldiff(dkeys); + EXPECT(singleValue.ensureArcConsistency(1, &domains)); + EXPECT(alldiff.ensureArcConsistency(0, &domains)); + EXPECT(!alldiff.ensureArcConsistency(1, &domains)); + EXPECT(alldiff.ensureArcConsistency(2, &domains)); + LONGS_EQUAL(2, domains.at(0).nrValues()); + LONGS_EQUAL(1, domains.at(1).nrValues()); + LONGS_EQUAL(2, domains.at(2).nrValues()); // Parial application, version 1 - DiscreteFactor::Values known; + DiscreteValues known; known[AZ.first] = 2; DiscreteFactor::shared_ptr reduced1 = alldiff.partiallyApply(known); DecisionTreeFactor f3(ID & UT, "0 1 1 1 0 1 1 1 0"); - EXPECT(assert_equal(f3,reduced1->toDecisionTreeFactor())); + EXPECT(assert_equal(f3, reduced1->toDecisionTreeFactor())); DiscreteFactor::shared_ptr reduced2 = singleValue.partiallyApply(known); DecisionTreeFactor f4(AZ, "0 0 1"); - EXPECT(assert_equal(f4,reduced2->toDecisionTreeFactor())); + EXPECT(assert_equal(f4, reduced2->toDecisionTreeFactor())); // Parial application, version 2 DiscreteFactor::shared_ptr reduced3 = alldiff.partiallyApply(domains); - EXPECT(assert_equal(f3,reduced3->toDecisionTreeFactor())); + EXPECT(assert_equal(f3, reduced3->toDecisionTreeFactor())); DiscreteFactor::shared_ptr reduced4 = singleValue.partiallyApply(domains); - EXPECT(assert_equal(f4,reduced4->toDecisionTreeFactor())); + EXPECT(assert_equal(f4, reduced4->toDecisionTreeFactor())); // full arc-consistency test csp.runArcConsistency(nrColors); + // GTSAM_PRINT(csp); } /* ************************************************************************* */ @@ -229,4 +272,3 @@ int main() { return TestRegistry::runAllTests(tr); } /* ************************************************************************* */ - diff --git a/gtsam_unstable/discrete/tests/testLoopyBelief.cpp b/gtsam_unstable/discrete/tests/testLoopyBelief.cpp index 9929938d5..eac0d834e 100644 --- a/gtsam_unstable/discrete/tests/testLoopyBelief.cpp +++ b/gtsam_unstable/discrete/tests/testLoopyBelief.cpp @@ -5,14 +5,16 @@ * @date Oct 11, 2013 */ -#include +#include #include #include -#include -#include +#include +#include + #include -#include +#include #include +#include using namespace std; using namespace boost; @@ -23,11 +25,12 @@ using namespace gtsam; * Loopy belief solver for graphs with only binary and unary factors */ class LoopyBelief { - /** Star graph struct for each node, containing * - the star graph itself - * - the product of original unary factors so we don't have to recompute it later, and - * - the factor indices of the corrected belief factors of the neighboring nodes + * - the product of original unary factors so we don't have to recompute it + * later, and + * - the factor indices of the corrected belief factors of the neighboring + * nodes */ typedef std::map CorrectedBeliefIndices; struct StarGraph { @@ -36,41 +39,41 @@ class LoopyBelief { DecisionTreeFactor::shared_ptr unary; VariableIndex varIndex_; StarGraph(const DiscreteFactorGraph::shared_ptr& _star, - const CorrectedBeliefIndices& _beliefIndices, - const DecisionTreeFactor::shared_ptr& _unary) : - star(_star), correctedBeliefIndices(_beliefIndices), unary(_unary), varIndex_( - *_star) { - } + const CorrectedBeliefIndices& _beliefIndices, + const DecisionTreeFactor::shared_ptr& _unary) + : star(_star), + correctedBeliefIndices(_beliefIndices), + unary(_unary), + varIndex_(*_star) {} void print(const std::string& s = "") const { cout << s << ":" << endl; star->print("Star graph: "); - for(Key key: correctedBeliefIndices | boost::adaptors::map_keys) { + for (Key key : correctedBeliefIndices | boost::adaptors::map_keys) { cout << "Belief factor index for " << key << ": " - << correctedBeliefIndices.at(key) << endl; + << correctedBeliefIndices.at(key) << endl; } - if (unary) - unary->print("Unary: "); + if (unary) unary->print("Unary: "); } }; typedef std::map StarGraphs; - StarGraphs starGraphs_; ///< star graph at each variable + StarGraphs starGraphs_; ///< star graph at each variable -public: + public: /** Constructor - * Need all discrete keys to access node's cardinality for creating belief factors + * Need all discrete keys to access node's cardinality for creating belief + * factors * TODO: so troublesome!! */ LoopyBelief(const DiscreteFactorGraph& graph, - const std::map& allDiscreteKeys) : - starGraphs_(buildStarGraphs(graph, allDiscreteKeys)) { - } + const std::map& allDiscreteKeys) + : starGraphs_(buildStarGraphs(graph, allDiscreteKeys)) {} /// print void print(const std::string& s = "") const { cout << s << ":" << endl; - for(Key key: starGraphs_ | boost::adaptors::map_keys) { + for (Key key : starGraphs_ | boost::adaptors::map_keys) { starGraphs_.at(key).print((boost::format("Node %d:") % key).str()); } } @@ -79,12 +82,13 @@ public: DiscreteFactorGraph::shared_ptr iterate( const std::map& allDiscreteKeys) { static const bool debug = false; - static DiscreteConditional::shared_ptr dummyCond; // unused by-product of elimination + static DiscreteConditional::shared_ptr + dummyCond; // unused by-product of elimination DiscreteFactorGraph::shared_ptr beliefs(new DiscreteFactorGraph()); std::map > allMessages; // Eliminate each star graph - for(Key key: starGraphs_ | boost::adaptors::map_keys) { -// cout << "***** Node " << key << "*****" << endl; + for (Key key : starGraphs_ | boost::adaptors::map_keys) { + // cout << "***** Node " << key << "*****" << endl; // initialize belief to the unary factor from the original graph DecisionTreeFactor::shared_ptr beliefAtKey; @@ -92,15 +96,16 @@ public: std::map messages; // eliminate each neighbor in this star graph one by one - for(Key neighbor: starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) { + for (Key neighbor : starGraphs_.at(key).correctedBeliefIndices | + boost::adaptors::map_keys) { DiscreteFactorGraph subGraph; - for(size_t factor: starGraphs_.at(key).varIndex_[neighbor]) { + for (size_t factor : starGraphs_.at(key).varIndex_[neighbor]) { subGraph.push_back(starGraphs_.at(key).star->at(factor)); } if (debug) subGraph.print("------- Subgraph:"); DiscreteFactor::shared_ptr message; - boost::tie(dummyCond, message) = EliminateDiscrete(subGraph, - Ordering(list_of(neighbor))); + boost::tie(dummyCond, message) = + EliminateDiscrete(subGraph, Ordering(list_of(neighbor))); // store the new factor into messages messages.insert(make_pair(neighbor, message)); if (debug) message->print("------- Message: "); @@ -108,14 +113,12 @@ public: // Belief is the product of all messages and the unary factor // Incorporate new the factor to belief if (!beliefAtKey) - beliefAtKey = boost::dynamic_pointer_cast( - message); - else beliefAtKey = - boost::make_shared( - (*beliefAtKey) - * (*boost::dynamic_pointer_cast( - message))); + boost::dynamic_pointer_cast(message); + else + beliefAtKey = boost::make_shared( + (*beliefAtKey) * + (*boost::dynamic_pointer_cast(message))); } if (starGraphs_.at(key).unary) beliefAtKey = boost::make_shared( @@ -124,7 +127,7 @@ public: // normalize belief double sum = 0.0; for (size_t v = 0; v < allDiscreteKeys.at(key).second; ++v) { - DiscreteFactor::Values val; + DiscreteValues val; val[key] = v; sum += (*beliefAtKey)(val); } @@ -133,7 +136,8 @@ public: sumFactorTable = (boost::format("%s %f") % sumFactorTable % sum).str(); DecisionTreeFactor sumFactor(allDiscreteKeys.at(key), sumFactorTable); if (debug) sumFactor.print("denomFactor: "); - beliefAtKey = boost::make_shared((*beliefAtKey) / sumFactor); + beliefAtKey = + boost::make_shared((*beliefAtKey) / sumFactor); if (debug) beliefAtKey->print("New belief at key normalized: "); beliefs->push_back(beliefAtKey); allMessages[key] = messages; @@ -141,17 +145,20 @@ public: // Update corrected beliefs VariableIndex beliefFactors(*beliefs); - for(Key key: starGraphs_ | boost::adaptors::map_keys) { + for (Key key : starGraphs_ | boost::adaptors::map_keys) { std::map messages = allMessages[key]; - for(Key neighbor: starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) { - DecisionTreeFactor correctedBelief = (*boost::dynamic_pointer_cast< - DecisionTreeFactor>(beliefs->at(beliefFactors[key].front()))) - / (*boost::dynamic_pointer_cast( + for (Key neighbor : starGraphs_.at(key).correctedBeliefIndices | + boost::adaptors::map_keys) { + DecisionTreeFactor correctedBelief = + (*boost::dynamic_pointer_cast( + beliefs->at(beliefFactors[key].front()))) / + (*boost::dynamic_pointer_cast( messages.at(neighbor))); if (debug) correctedBelief.print("correctedBelief: "); - size_t beliefIndex = starGraphs_.at(neighbor).correctedBeliefIndices.at( - key); - starGraphs_.at(neighbor).star->replace(beliefIndex, + size_t beliefIndex = + starGraphs_.at(neighbor).correctedBeliefIndices.at(key); + starGraphs_.at(neighbor).star->replace( + beliefIndex, boost::make_shared(correctedBelief)); } } @@ -161,21 +168,22 @@ public: return beliefs; } -private: + private: /** * Build star graphs for each node. */ - StarGraphs buildStarGraphs(const DiscreteFactorGraph& graph, + StarGraphs buildStarGraphs( + const DiscreteFactorGraph& graph, const std::map& allDiscreteKeys) const { StarGraphs starGraphs; - VariableIndex varIndex(graph); ///< access to all factors of each node - for(Key key: varIndex | boost::adaptors::map_keys) { + VariableIndex varIndex(graph); ///< access to all factors of each node + for (Key key : varIndex | boost::adaptors::map_keys) { // initialize to multiply with other unary factors later DecisionTreeFactor::shared_ptr prodOfUnaries; // collect all factors involving this key in the original graph DiscreteFactorGraph::shared_ptr star(new DiscreteFactorGraph()); - for(size_t factorIndex: varIndex[key]) { + for (size_t factorIndex : varIndex[key]) { star->push_back(graph.at(factorIndex)); // accumulate unary factors @@ -185,9 +193,9 @@ private: graph.at(factorIndex)); else prodOfUnaries = boost::make_shared( - *prodOfUnaries - * (*boost::dynamic_pointer_cast( - graph.at(factorIndex)))); + *prodOfUnaries * + (*boost::dynamic_pointer_cast( + graph.at(factorIndex)))); } } @@ -196,7 +204,7 @@ private: KeySet neighbors = star->keys(); neighbors.erase(key); CorrectedBeliefIndices correctedBeliefIndices; - for(Key neighbor: neighbors) { + for (Key neighbor : neighbors) { // TODO: default table for keys with more than 2 values? string initialBelief; for (size_t v = 0; v < allDiscreteKeys.at(neighbor).second - 1; ++v) { @@ -207,9 +215,8 @@ private: DecisionTreeFactor(allDiscreteKeys.at(neighbor), initialBelief)); correctedBeliefIndices.insert(make_pair(neighbor, star->size() - 1)); } - starGraphs.insert( - make_pair(key, - StarGraph(star, correctedBeliefIndices, prodOfUnaries))); + starGraphs.insert(make_pair( + key, StarGraph(star, correctedBeliefIndices, prodOfUnaries))); } return starGraphs; } @@ -249,7 +256,6 @@ TEST_UNSAFE(LoopyBelief, construction) { DiscreteFactorGraph::shared_ptr beliefs = solver.iterate(allKeys); beliefs->print(); } - } /* ************************************************************************* */ diff --git a/gtsam_unstable/discrete/tests/testScheduler.cpp b/gtsam_unstable/discrete/tests/testScheduler.cpp index 3f6c6a1e0..086057a46 100644 --- a/gtsam_unstable/discrete/tests/testScheduler.cpp +++ b/gtsam_unstable/discrete/tests/testScheduler.cpp @@ -5,14 +5,13 @@ */ //#define ENABLE_TIMING -#include +#include #include #include +#include -#include - -#include #include +#include #include using namespace boost::assign; @@ -22,7 +21,6 @@ using namespace gtsam; /* ************************************************************************* */ // Create the expected graph of constraints DiscreteFactorGraph createExpected() { - // Start building size_t nrFaculty = 4, nrTimeSlots = 3; @@ -47,27 +45,27 @@ DiscreteFactorGraph createExpected() { string available = "1 1 1 0 1 1 1 1 0 1 1 1"; // Akansel - expected.add(A1, faculty_in_A); // Area 1 - expected.add(A1, "1 1 1 0"); // Advisor + expected.add(A1, faculty_in_A); // Area 1 + expected.add(A1, "1 1 1 0"); // Advisor expected.add(A & A1, available); - expected.add(A2, faculty_in_M); // Area 2 - expected.add(A2, "1 1 1 0"); // Advisor + expected.add(A2, faculty_in_M); // Area 2 + expected.add(A2, "1 1 1 0"); // Advisor expected.add(A & A2, available); - expected.add(A3, faculty_in_P); // Area 3 - expected.add(A3, "1 1 1 0"); // Advisor + expected.add(A3, faculty_in_P); // Area 3 + expected.add(A3, "1 1 1 0"); // Advisor expected.add(A & A3, available); // Mutual exclusion for faculty expected.addAllDiff(A1 & A2 & A3); // Jake - expected.add(J1, faculty_in_H); // Area 1 - expected.add(J1, "1 0 1 1"); // Advisor + expected.add(J1, faculty_in_H); // Area 1 + expected.add(J1, "1 0 1 1"); // Advisor expected.add(J & J1, available); - expected.add(J2, faculty_in_C); // Area 2 - expected.add(J2, "1 0 1 1"); // Advisor + expected.add(J2, faculty_in_C); // Area 2 + expected.add(J2, "1 0 1 1"); // Advisor expected.add(J & J2, available); - expected.add(J3, faculty_in_A); // Area 3 - expected.add(J3, "1 0 1 1"); // Advisor + expected.add(J3, faculty_in_A); // Area 3 + expected.add(J3, "1 0 1 1"); // Advisor expected.add(J & J3, available); // Mutual exclusion for faculty expected.addAllDiff(J1 & J2 & J3); @@ -79,8 +77,7 @@ DiscreteFactorGraph createExpected() { } /* ************************************************************************* */ -TEST( schedulingExample, test) -{ +TEST(schedulingExample, test) { Scheduler s(2); // add faculty @@ -121,33 +118,32 @@ TEST( schedulingExample, test) // Do brute force product and output that to file DecisionTreeFactor product = s.product(); - //product.dot("scheduling", false); + // product.dot("scheduling", false); // Do exact inference gttic(small); - DiscreteFactor::sharedValues MPE = s.optimalAssignment(); + auto MPE = s.optimize(); gttoc(small); // print MPE, commented out as unit tests don't print -// s.printAssignment(MPE); + // s.printAssignment(MPE); // Commented out as does not work yet // s.runArcConsistency(8,10,true); // find the assignment of students to slots with most possible committees // Commented out as not implemented yet -// sharedValues bestSchedule = s.bestSchedule(); -// GTSAM_PRINT(*bestSchedule); + // auto bestSchedule = s.bestSchedule(); + // GTSAM_PRINT(bestSchedule); // find the corresponding most desirable committee assignment // Commented out as not implemented yet -// sharedValues bestAssignment = s.bestAssignment(bestSchedule); -// GTSAM_PRINT(*bestAssignment); + // auto bestAssignment = s.bestAssignment(bestSchedule); + // GTSAM_PRINT(bestAssignment); } /* ************************************************************************* */ -TEST( schedulingExample, smallFromFile) -{ +TEST(schedulingExample, smallFromFile) { string path(TOPSRCDIR "/gtsam_unstable/discrete/examples/"); Scheduler s(2, path + "small.csv"); @@ -179,4 +175,3 @@ int main() { return TestRegistry::runAllTests(tr); } /* ************************************************************************* */ - diff --git a/gtsam_unstable/discrete/tests/testSudoku.cpp b/gtsam_unstable/discrete/tests/testSudoku.cpp index e2115e8bc..8b2858169 100644 --- a/gtsam_unstable/discrete/tests/testSudoku.cpp +++ b/gtsam_unstable/discrete/tests/testSudoku.cpp @@ -5,74 +5,69 @@ * @author Frank Dellaert */ -#include #include +#include +#include + #include using boost::assign::insert; +#include + #include #include -#include using namespace std; using namespace gtsam; #define PRINT false -class Sudoku: public CSP { +/// A class that encodes Sudoku's as a CSP problem +class Sudoku : public CSP { + size_t n_; ///< Side of Sudoku, e.g. 4 or 9 - /// sudoku size - size_t n_; - - /// discrete keys - typedef std::pair IJ; + /// Mapping from base i,j coordinates to discrete keys: + using IJ = std::pair; std::map dkeys_; -public: - + public: /// return DiscreteKey for cell(i,j) const DiscreteKey& dkey(size_t i, size_t j) const { return dkeys_.at(IJ(i, j)); } /// return Key for cell(i,j) - Key key(size_t i, size_t j) const { - return dkey(i, j).first; - } + Key key(size_t i, size_t j) const { return dkey(i, j).first; } /// Constructor - Sudoku(size_t n, ...) : - n_(n) { + Sudoku(size_t n, ...) : n_(n) { // Create variables, ordering, and unary constraints va_list ap; va_start(ap, n); - Key k=0; for (size_t i = 0; i < n; ++i) { - for (size_t j = 0; j < n; ++j, ++k) { + for (size_t j = 0; j < n; ++j) { // create the key IJ ij(i, j); - dkeys_[ij] = DiscreteKey(k, n); + Symbol key('1' + i, j + 1); + dkeys_[ij] = DiscreteKey(key, n); // get the unary constraint, if any int value = va_arg(ap, int); - // cout << value << " "; if (value != 0) addSingleValue(dkeys_[ij], value - 1); } - //cout << endl; + // cout << endl; } va_end(ap); // add row constraints for (size_t i = 0; i < n; i++) { DiscreteKeys dkeys; - for (size_t j = 0; j < n; j++) - dkeys += dkey(i, j); + for (size_t j = 0; j < n; j++) dkeys += dkey(i, j); addAllDiff(dkeys); } // add col constraints for (size_t j = 0; j < n; j++) { DiscreteKeys dkeys; - for (size_t i = 0; i < n; i++) - dkeys += dkey(i, j); + for (size_t i = 0; i < n; i++) dkeys += dkey(i, j); addAllDiff(dkeys); } @@ -84,8 +79,7 @@ public: // Box I,J DiscreteKeys dkeys; for (size_t i = i0; i < i0 + N; i++) - for (size_t j = j0; j < j0 + N; j++) - dkeys += dkey(i, j); + for (size_t j = j0; j < j0 + N; j++) dkeys += dkey(i, j); addAllDiff(dkeys); j0 += N; } @@ -94,120 +88,171 @@ public: } /// Print readable form of assignment - void printAssignment(DiscreteFactor::sharedValues assignment) const { + void printAssignment(const DiscreteValues& assignment) const { for (size_t i = 0; i < n_; i++) { for (size_t j = 0; j < n_; j++) { Key k = key(i, j); - cout << 1 + assignment->at(k) << " "; + cout << 1 + assignment.at(k) << " "; } cout << endl; } } /// solve and print solution - void printSolution() { - DiscreteFactor::sharedValues MPE = optimalAssignment(); + void printSolution() const { + auto MPE = optimize(); printAssignment(MPE); } + // Print domain + void printDomains(const Domains& domains) { + for (size_t i = 0; i < n_; i++) { + for (size_t j = 0; j < n_; j++) { + Key k = key(i, j); + cout << domains.at(k).base1Str(); + cout << "\t"; + } // i + cout << endl; + } // j + } }; /* ************************************************************************* */ -TEST_UNSAFE( Sudoku, small) -{ - Sudoku csp(4, - 1,0, 0,4, - 0,0, 0,0, - - 4,0, 2,0, - 0,1, 0,0); - - // Do BP - csp.runArcConsistency(4,10,PRINT); +TEST(Sudoku, small) { + Sudoku csp(4, // + 1, 0, 0, 4, // + 0, 0, 0, 0, // + 4, 0, 2, 0, // + 0, 1, 0, 0); // optimize and check - CSP::sharedValues solution = csp.optimalAssignment(); - CSP::Values expected; - insert(expected) - (csp.key(0,0), 0)(csp.key(0,1), 1)(csp.key(0,2), 2)(csp.key(0,3), 3) - (csp.key(1,0), 2)(csp.key(1,1), 3)(csp.key(1,2), 0)(csp.key(1,3), 1) - (csp.key(2,0), 3)(csp.key(2,1), 2)(csp.key(2,2), 1)(csp.key(2,3), 0) - (csp.key(3,0), 1)(csp.key(3,1), 0)(csp.key(3,2), 3)(csp.key(3,3), 2); - EXPECT(assert_equal(expected,*solution)); - //csp.printAssignment(solution); + auto solution = csp.optimize(); + DiscreteValues expected; + insert(expected)(csp.key(0, 0), 0)(csp.key(0, 1), 1)(csp.key(0, 2), 2)( + csp.key(0, 3), 3)(csp.key(1, 0), 2)(csp.key(1, 1), 3)(csp.key(1, 2), 0)( + csp.key(1, 3), 1)(csp.key(2, 0), 3)(csp.key(2, 1), 2)(csp.key(2, 2), 1)( + csp.key(2, 3), 0)(csp.key(3, 0), 1)(csp.key(3, 1), 0)(csp.key(3, 2), 3)( + csp.key(3, 3), 2); + EXPECT(assert_equal(expected, solution)); + // csp.printAssignment(solution); + + // Do BP (AC1) + auto domains = csp.runArcConsistency(4, 3); + // csp.printDomains(domains); + Domain domain44 = domains.at(Symbol('4', 4)); + EXPECT_LONGS_EQUAL(1, domain44.nrValues()); + + // Test Creation of a new, simpler CSP + CSP new_csp = csp.partiallyApply(domains); + // Should only be 16 new Domains + EXPECT_LONGS_EQUAL(16, new_csp.size()); + + // Check that solution + auto new_solution = new_csp.optimize(); + // csp.printAssignment(new_solution); + EXPECT(assert_equal(expected, new_solution)); } /* ************************************************************************* */ -TEST_UNSAFE( Sudoku, easy) -{ - Sudoku sudoku(9, - 0,0,5, 0,9,0, 0,0,1, - 0,0,0, 0,0,2, 0,7,3, - 7,6,0, 0,0,8, 2,0,0, +TEST(Sudoku, easy) { + Sudoku csp(9, // + 0, 0, 5, 0, 9, 0, 0, 0, 1, // + 0, 0, 0, 0, 0, 2, 0, 7, 3, // + 7, 6, 0, 0, 0, 8, 2, 0, 0, // - 0,1,2, 0,0,9, 0,0,4, - 0,0,0, 2,0,3, 0,0,0, - 3,0,0, 1,0,0, 9,6,0, + 0, 1, 2, 0, 0, 9, 0, 0, 4, // + 0, 0, 0, 2, 0, 3, 0, 0, 0, // + 3, 0, 0, 1, 0, 0, 9, 6, 0, // - 0,0,1, 9,0,0, 0,5,8, - 9,7,0, 5,0,0, 0,0,0, - 5,0,0, 0,3,0, 7,0,0); + 0, 0, 1, 9, 0, 0, 0, 5, 8, // + 9, 7, 0, 5, 0, 0, 0, 0, 0, // + 5, 0, 0, 0, 3, 0, 7, 0, 0); - // Do BP - sudoku.runArcConsistency(4,10,PRINT); + // csp.printSolution(); // don't do it - // sudoku.printSolution(); // don't do it + // Do BP (AC1) + auto domains = csp.runArcConsistency(9, 10); + // csp.printDomains(domains); + Key key99 = Symbol('9', 9); + Domain domain99 = domains.at(key99); + EXPECT_LONGS_EQUAL(1, domain99.nrValues()); + + // Test Creation of a new, simpler CSP + CSP new_csp = csp.partiallyApply(domains); + // 81 new Domains, and still 26 all-diff constraints + EXPECT_LONGS_EQUAL(81 + 26, new_csp.size()); + + // csp.printSolution(); // still don't do it ! :-( } /* ************************************************************************* */ -TEST_UNSAFE( Sudoku, extreme) -{ - Sudoku sudoku(9, - 0,0,9, 7,4,8, 0,0,0, - 7,0,0, 0,0,0, 0,0,0, - 0,2,0, 1,0,9, 0,0,0, - - 0,0,7, 0,0,0, 2,4,0, - 0,6,4, 0,1,0, 5,9,0, - 0,9,8, 0,0,0, 3,0,0, - - 0,0,0, 8,0,3, 0,2,0, - 0,0,0, 0,0,0, 0,0,6, - 0,0,0, 2,7,5, 9,0,0); +TEST(Sudoku, extreme) { + Sudoku csp(9, // + 0, 0, 9, 7, 4, 8, 0, 0, 0, 7, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, // + 0, 1, 0, 9, 0, 0, 0, 0, 0, 7, // + 0, 0, 0, 2, 4, 0, 0, 6, 4, 0, // + 1, 0, 5, 9, 0, 0, 9, 8, 0, 0, // + 0, 3, 0, 0, 0, 0, 0, 8, 0, 3, // + 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 6, 0, 0, 0, 2, 7, 5, 9, 0, 0); // Do BP - sudoku.runArcConsistency(9,10,PRINT); + csp.runArcConsistency(9, 10); #ifdef METIS - VariableIndexOrdered index(sudoku); + VariableIndexOrdered index(csp); index.print("index"); ofstream os("/Users/dellaert/src/hmetis-1.5-osx-i686/extreme-dual.txt"); index.outputMetisFormat(os); #endif - //sudoku.printSolution(); // don't do it + // Do BP (AC1) + auto domains = csp.runArcConsistency(9, 10); + // csp.printDomains(domains); + Key key99 = Symbol('9', 9); + Domain domain99 = domains.at(key99); + EXPECT_LONGS_EQUAL(2, domain99.nrValues()); + + // Test Creation of a new, simpler CSP + CSP new_csp = csp.partiallyApply(domains); + // 81 new Domains, and still 20 all-diff constraints + EXPECT_LONGS_EQUAL(81 + 20, new_csp.size()); + + // csp.printSolution(); // still don't do it ! :-( } /* ************************************************************************* */ -TEST_UNSAFE( Sudoku, AJC_3star_Feb8_2012) -{ - Sudoku sudoku(9, - 9,5,0, 0,0,6, 0,0,0, - 0,8,4, 0,7,0, 0,0,0, - 6,2,0, 5,0,0, 4,0,0, +TEST(Sudoku, AJC_3star_Feb8_2012) { + Sudoku csp(9, // + 9, 5, 0, 0, 0, 6, 0, 0, 0, // + 0, 8, 4, 0, 7, 0, 0, 0, 0, // + 6, 2, 0, 5, 0, 0, 4, 0, 0, // - 0,0,0, 2,9,0, 6,0,0, - 0,9,0, 0,0,0, 0,2,0, - 0,0,2, 0,6,3, 0,0,0, + 0, 0, 0, 2, 9, 0, 6, 0, 0, // + 0, 9, 0, 0, 0, 0, 0, 2, 0, // + 0, 0, 2, 0, 6, 3, 0, 0, 0, // - 0,0,9, 0,0,7, 0,6,8, - 0,0,0, 0,3,0, 2,9,0, - 0,0,0, 1,0,0, 0,3,7); + 0, 0, 9, 0, 0, 7, 0, 6, 8, // + 0, 0, 0, 0, 3, 0, 2, 9, 0, // + 0, 0, 0, 1, 0, 0, 0, 3, 7); - // Do BP - sudoku.runArcConsistency(9,10,PRINT); + // Do BP (AC1) + auto domains = csp.runArcConsistency(9, 10); + // csp.printDomains(domains); + Key key99 = Symbol('9', 9); + Domain domain99 = domains.at(key99); + EXPECT_LONGS_EQUAL(1, domain99.nrValues()); - //sudoku.printSolution(); // don't do it + // Test Creation of a new, simpler CSP + CSP new_csp = csp.partiallyApply(domains); + // Just the 81 new Domains + EXPECT_LONGS_EQUAL(81, new_csp.size()); + + // Check that solution + auto solution = new_csp.optimize(); + // csp.printAssignment(solution); + EXPECT_LONGS_EQUAL(6, solution.at(key99)); } /* ************************************************************************* */ @@ -216,4 +261,3 @@ int main() { return TestRegistry::runAllTests(tr); } /* ************************************************************************* */ - diff --git a/gtsam_unstable/gtsam_unstable.i b/gtsam_unstable/gtsam_unstable.i index 8c9345147..08cd45e18 100644 --- a/gtsam_unstable/gtsam_unstable.i +++ b/gtsam_unstable/gtsam_unstable.i @@ -566,7 +566,13 @@ virtual class FixedLagSmoother { gtsam::FixedLagSmootherKeyTimestampMap timestamps() const; double smootherLag() const; - gtsam::FixedLagSmootherResult update(const gtsam::NonlinearFactorGraph& newFactors, const gtsam::Values& newTheta, const gtsam::FixedLagSmootherKeyTimestampMap& timestamps); + gtsam::FixedLagSmootherResult update(const gtsam::NonlinearFactorGraph &newFactors, + const gtsam::Values &newTheta, + const gtsam::FixedLagSmootherKeyTimestampMap ×tamps); + gtsam::FixedLagSmootherResult update(const gtsam::NonlinearFactorGraph &newFactors, + const gtsam::Values &newTheta, + const gtsam::FixedLagSmootherKeyTimestampMap ×tamps, + const gtsam::FactorIndices &factorsToRemove); gtsam::Values calculateEstimate() const; }; @@ -576,6 +582,8 @@ virtual class BatchFixedLagSmoother : gtsam::FixedLagSmoother { BatchFixedLagSmoother(double smootherLag); BatchFixedLagSmoother(double smootherLag, const gtsam::LevenbergMarquardtParams& params); + void print(string s = "BatchFixedLagSmoother:\n") const; + gtsam::LevenbergMarquardtParams params() const; template @@ -784,4 +797,30 @@ virtual class ProjectionFactorPPPC : gtsam::NoiseModelFactor { typedef gtsam::ProjectionFactorPPPC ProjectionFactorPPPCCal3_S2; typedef gtsam::ProjectionFactorPPPC ProjectionFactorPPPCCal3DS2; +#include +virtual class ProjectionFactorRollingShutter : gtsam::NoiseModelFactor { + ProjectionFactorRollingShutter(const gtsam::Point2& measured, double alpha, const gtsam::noiseModel::Base* noiseModel, + size_t poseKey_a, size_t poseKey_b, size_t pointKey, const gtsam::Cal3_S2* K); + + ProjectionFactorRollingShutter(const gtsam::Point2& measured, double alpha, const gtsam::noiseModel::Base* noiseModel, + size_t poseKey_a, size_t poseKey_b, size_t pointKey, const gtsam::Cal3_S2* K, gtsam::Pose3& body_P_sensor); + + ProjectionFactorRollingShutter(const gtsam::Point2& measured, double alpha, const gtsam::noiseModel::Base* noiseModel, + size_t poseKey_a, size_t poseKey_b, size_t pointKey, const gtsam::Cal3_S2* K, bool throwCheirality, + bool verboseCheirality); + + ProjectionFactorRollingShutter(const gtsam::Point2& measured, double alpha, const gtsam::noiseModel::Base* noiseModel, + size_t poseKey_a, size_t poseKey_b, size_t pointKey, const gtsam::Cal3_S2* K, bool throwCheirality, + bool verboseCheirality, gtsam::Pose3& body_P_sensor); + + gtsam::Point2 measured() const; + double alpha() const; + gtsam::Cal3_S2* calibration() const; + bool verboseCheirality() const; + bool throwCheirality() const; + + // enabling serialization functionality + void serialize() const; +}; + } //\namespace gtsam diff --git a/gtsam_unstable/linear/ActiveSetSolver-inl.h b/gtsam_unstable/linear/ActiveSetSolver-inl.h index 12374ac76..350985cf4 100644 --- a/gtsam_unstable/linear/ActiveSetSolver-inl.h +++ b/gtsam_unstable/linear/ActiveSetSolver-inl.h @@ -17,6 +17,8 @@ * @date 2/11/16 */ +#pragma once + #include /******************************************************************************/ @@ -283,4 +285,4 @@ Template std::pair This::optimize() const { } #undef Template -#undef This \ No newline at end of file +#undef This diff --git a/gtsam_unstable/linear/LPSolver.h b/gtsam_unstable/linear/LPSolver.h index 460b4b7ee..f36462bda 100644 --- a/gtsam_unstable/linear/LPSolver.h +++ b/gtsam_unstable/linear/LPSolver.h @@ -17,6 +17,8 @@ * @date 6/16/16 */ +#pragma once + #include #include #include diff --git a/gtsam_unstable/linear/QPSolver.h b/gtsam_unstable/linear/QPSolver.h index 3854d2a15..ae87b3ab7 100644 --- a/gtsam_unstable/linear/QPSolver.h +++ b/gtsam_unstable/linear/QPSolver.h @@ -17,6 +17,8 @@ * @date 6/16/16 */ +#pragma once + #include #include #include @@ -45,4 +47,4 @@ struct QPPolicy { using QPSolver = ActiveSetSolver; -} \ No newline at end of file +} diff --git a/gtsam_unstable/linear/tests/testQPSolver.cpp b/gtsam_unstable/linear/tests/testQPSolver.cpp index 67a0c971e..12bd93416 100644 --- a/gtsam_unstable/linear/tests/testQPSolver.cpp +++ b/gtsam_unstable/linear/tests/testQPSolver.cpp @@ -226,7 +226,7 @@ pair testParser(QPSParser parser) { expected.inequalities.add(X1, -I_1x1, 0, 2); // x >= 0 expected.inequalities.add(X2, -I_1x1, 0, 3); // y > = 0 return {expected, exampleqp}; -}; +} TEST(QPSolver, ParserSyntaticTest) { auto result = testParser(QPSParser("QPExample.QPS")); diff --git a/gtsam_unstable/nonlinear/IncrementalFixedLagSmoother.h b/gtsam_unstable/nonlinear/IncrementalFixedLagSmoother.h index 79c05a01a..4079dbb23 100644 --- a/gtsam_unstable/nonlinear/IncrementalFixedLagSmoother.h +++ b/gtsam_unstable/nonlinear/IncrementalFixedLagSmoother.h @@ -113,6 +113,9 @@ public: /// Get results of latest isam2 update const ISAM2Result& getISAM2Result() const{ return isamResult_; } + /// Get the iSAM2 object which is used for the inference internally + const ISAM2& getISAM2() const { return isam_; } + protected: /** Create default parameters */ diff --git a/gtsam_unstable/nonlinear/LinearizedFactor.cpp b/gtsam_unstable/nonlinear/LinearizedFactor.cpp index 1a86adbfa..0c821b872 100644 --- a/gtsam_unstable/nonlinear/LinearizedFactor.cpp +++ b/gtsam_unstable/nonlinear/LinearizedFactor.cpp @@ -16,6 +16,7 @@ */ #include +#include #include namespace gtsam { diff --git a/gtsam_unstable/partition/FindSeparator-inl.h b/gtsam_unstable/partition/FindSeparator-inl.h index 2e48b0d45..0e4950b79 100644 --- a/gtsam_unstable/partition/FindSeparator-inl.h +++ b/gtsam_unstable/partition/FindSeparator-inl.h @@ -20,11 +20,10 @@ #include "FindSeparator.h" -#ifndef GTSAM_USE_SYSTEM_METIS +#include extern "C" { -#include -#include "metislib.h" +#include } @@ -566,5 +565,3 @@ namespace gtsam { namespace partition { } }} //namespace - -#endif diff --git a/gtsam_unstable/partition/FindSeparator.h b/gtsam_unstable/partition/FindSeparator.h index 42d971a82..f4342695b 100644 --- a/gtsam_unstable/partition/FindSeparator.h +++ b/gtsam_unstable/partition/FindSeparator.h @@ -6,6 +6,8 @@ * Description: find the separator of bisectioning for a given graph */ +#pragma once + #include #include #include diff --git a/gtsam_unstable/partition/tests/testFindSeparator.cpp b/gtsam_unstable/partition/tests/testFindSeparator.cpp index 63acc8f18..fe49de928 100644 --- a/gtsam_unstable/partition/tests/testFindSeparator.cpp +++ b/gtsam_unstable/partition/tests/testFindSeparator.cpp @@ -20,8 +20,6 @@ using namespace std; using namespace gtsam; using namespace gtsam::partition; -#ifndef GTSAM_USE_SYSTEM_METIS - /* ************************************************************************* */ // x0 - x1 - x2 // l3 l4 @@ -229,8 +227,6 @@ TEST ( Partition, findSeparator3_with_reduced_camera ) LONGS_EQUAL(2, partitionTable[28]); } -#endif - /* ************************************************************************* */ int main() { TestResult tr; return TestRegistry::runAllTests(tr);} /* ************************************************************************* */ diff --git a/gtsam_unstable/slam/AHRS.h b/gtsam_unstable/slam/AHRS.h index 35b4677d5..714e62288 100644 --- a/gtsam_unstable/slam/AHRS.h +++ b/gtsam_unstable/slam/AHRS.h @@ -5,8 +5,7 @@ * Author: cbeall3 */ -#ifndef AHRS_H_ -#define AHRS_H_ +#pragma once #include "Mechanization_bRn2.h" #include @@ -82,4 +81,3 @@ public: }; } /* namespace gtsam */ -#endif /* AHRS_H_ */ diff --git a/gtsam_unstable/slam/BetweenFactorEM.h b/gtsam_unstable/slam/BetweenFactorEM.h index 98ec59fe9..9c19bae8c 100644 --- a/gtsam_unstable/slam/BetweenFactorEM.h +++ b/gtsam_unstable/slam/BetweenFactorEM.h @@ -56,7 +56,8 @@ private: bool flag_bump_up_near_zero_probs_; /** concept check by type */ - GTSAM_CONCEPT_LIE_TYPE(T)GTSAM_CONCEPT_TESTABLE_TYPE(T) + GTSAM_CONCEPT_LIE_TYPE(T) + GTSAM_CONCEPT_TESTABLE_TYPE(T) public: @@ -420,4 +421,8 @@ private: }; // \class BetweenFactorEM +/// traits +template +struct traits > : public Testable > {}; + } // namespace gtsam diff --git a/gtsam_unstable/slam/EquivInertialNavFactor_GlobalVel_NoBias.h b/gtsam_unstable/slam/EquivInertialNavFactor_GlobalVel_NoBias.h index 0e2aebd7f..b053b13f8 100644 --- a/gtsam_unstable/slam/EquivInertialNavFactor_GlobalVel_NoBias.h +++ b/gtsam_unstable/slam/EquivInertialNavFactor_GlobalVel_NoBias.h @@ -372,15 +372,15 @@ public: Matrix Z_3x3 = Z_3x3; Matrix I_3x3 = I_3x3; - Matrix H_pos_pos = numericalDerivative11(std::bind(&PreIntegrateIMUObservations_delta_pos, msr_dt, _1, delta_vel_in_t0), delta_pos_in_t0); - Matrix H_pos_vel = numericalDerivative11(std::bind(&PreIntegrateIMUObservations_delta_pos, msr_dt, delta_pos_in_t0, _1), delta_vel_in_t0); + Matrix H_pos_pos = numericalDerivative11(std::bind(&PreIntegrateIMUObservations_delta_pos, msr_dt, _1, delta_vel_in_t0), delta_pos_in_t0); + Matrix H_pos_vel = numericalDerivative11(std::bind(&PreIntegrateIMUObservations_delta_pos, msr_dt, delta_pos_in_t0, _1), delta_vel_in_t0); Matrix H_pos_angles = Z_3x3; - Matrix H_vel_vel = numericalDerivative11(std::bind(&PreIntegrateIMUObservations_delta_vel, msr_gyro_t, msr_acc_t, msr_dt, delta_angles, _1, flag_use_body_P_sensor, body_P_sensor), delta_vel_in_t0); - Matrix H_vel_angles = numericalDerivative11(std::bind(&PreIntegrateIMUObservations_delta_vel, msr_gyro_t, msr_acc_t, msr_dt, _1, delta_vel_in_t0, flag_use_body_P_sensor, body_P_sensor), delta_angles); + Matrix H_vel_vel = numericalDerivative11(std::bind(&PreIntegrateIMUObservations_delta_vel, msr_gyro_t, msr_acc_t, msr_dt, delta_angles, _1, flag_use_body_P_sensor, body_P_sensor), delta_vel_in_t0); + Matrix H_vel_angles = numericalDerivative11(std::bind(&PreIntegrateIMUObservations_delta_vel, msr_gyro_t, msr_acc_t, msr_dt, _1, delta_vel_in_t0, flag_use_body_P_sensor, body_P_sensor), delta_angles); Matrix H_vel_pos = Z_3x3; - Matrix H_angles_angles = numericalDerivative11(std::bind(&PreIntegrateIMUObservations_delta_angles, msr_gyro_t, msr_dt, _1, flag_use_body_P_sensor, body_P_sensor), delta_angles); + Matrix H_angles_angles = numericalDerivative11(std::bind(&PreIntegrateIMUObservations_delta_angles, msr_gyro_t, msr_dt, _1, flag_use_body_P_sensor, body_P_sensor), delta_angles); Matrix H_angles_pos = Z_3x3; Matrix H_angles_vel = Z_3x3; diff --git a/gtsam_unstable/slam/InvDepthFactor3.h b/gtsam_unstable/slam/InvDepthFactor3.h index 3fd86f271..44d3b8fd0 100644 --- a/gtsam_unstable/slam/InvDepthFactor3.h +++ b/gtsam_unstable/slam/InvDepthFactor3.h @@ -92,7 +92,7 @@ public: } catch( CheiralityException& e) { if (H1) *H1 = Matrix::Zero(2,6); if (H2) *H2 = Matrix::Zero(2,5); - if (H3) *H2 = Matrix::Zero(2,1); + if (H3) *H3 = Matrix::Zero(2,1); std::cout << e.what() << ": Landmark "<< DefaultKeyFormatter(this->key2()) << " moved behind camera " << DefaultKeyFormatter(this->key1()) << std::endl; return Vector::Ones(2) * 2.0 * K_->fx(); diff --git a/gtsam_unstable/slam/LocalOrientedPlane3Factor.h b/gtsam_unstable/slam/LocalOrientedPlane3Factor.h index 5264c8f4b..f81c18bfa 100644 --- a/gtsam_unstable/slam/LocalOrientedPlane3Factor.h +++ b/gtsam_unstable/slam/LocalOrientedPlane3Factor.h @@ -9,6 +9,8 @@ #include #include +#include + #include namespace gtsam { @@ -32,16 +34,16 @@ namespace gtsam { * a local linearisation point for the plane. The plane is representated and * optimized in x1 frame in the optimization. */ -class LocalOrientedPlane3Factor: public NoiseModelFactor3 { -protected: +class GTSAM_UNSTABLE_EXPORT LocalOrientedPlane3Factor + : public NoiseModelFactor3 { + protected: OrientedPlane3 measured_p_; typedef NoiseModelFactor3 Base; public: /// Constructor LocalOrientedPlane3Factor() {} - virtual ~LocalOrientedPlane3Factor() {} + ~LocalOrientedPlane3Factor() override {} /** Constructor with measured plane (a,b,c,d) coefficients * @param z measured plane (a,b,c,d) coefficients as 4D vector @@ -54,12 +56,12 @@ public: * Note: The anchorPoseKey can simply be chosen as the first pose a plane * is observed. */ - LocalOrientedPlane3Factor(const Vector4& z, const SharedGaussian& noiseModel, + LocalOrientedPlane3Factor(const Vector4& z, const SharedNoiseModel& noiseModel, Key poseKey, Key anchorPoseKey, Key landmarkKey) : Base(noiseModel, poseKey, anchorPoseKey, landmarkKey), measured_p_(z) {} LocalOrientedPlane3Factor(const OrientedPlane3& z, - const SharedGaussian& noiseModel, + const SharedNoiseModel& noiseModel, Key poseKey, Key anchorPoseKey, Key landmarkKey) : Base(noiseModel, poseKey, anchorPoseKey, landmarkKey), measured_p_(z) {} diff --git a/gtsam_unstable/slam/PoseToPointFactor.h b/gtsam_unstable/slam/PoseToPointFactor.h index ec7da22ef..cab48e506 100644 --- a/gtsam_unstable/slam/PoseToPointFactor.h +++ b/gtsam_unstable/slam/PoseToPointFactor.h @@ -1,11 +1,14 @@ /** * @file PoseToPointFactor.hpp - * @brief This factor can be used to track a 3D landmark over time by - *providing local measurements of its location. + * @brief This factor can be used to model relative position measurements + * from a (2D or 3D) pose to a landmark * @author David Wisth + * @author Luca Carlone **/ #pragma once +#include +#include #include #include #include @@ -17,12 +20,13 @@ namespace gtsam { * A class for a measurement between a pose and a point. * @addtogroup SLAM */ -class PoseToPointFactor : public NoiseModelFactor2 { +template +class PoseToPointFactor : public NoiseModelFactor2 { private: typedef PoseToPointFactor This; - typedef NoiseModelFactor2 Base; + typedef NoiseModelFactor2 Base; - Point3 measured_; /** the point measurement in local coordinates */ + POINT measured_; /** the point measurement in local coordinates */ public: // shorthand for a smart pointer to a factor @@ -32,7 +36,7 @@ class PoseToPointFactor : public NoiseModelFactor2 { PoseToPointFactor() {} /** Constructor */ - PoseToPointFactor(Key key1, Key key2, const Point3& measured, + PoseToPointFactor(Key key1, Key key2, const POINT& measured, const SharedNoiseModel& model) : Base(model, key1, key2), measured_(measured) {} @@ -41,8 +45,8 @@ class PoseToPointFactor : public NoiseModelFactor2 { /** implement functions needed for Testable */ /** print */ - virtual void print(const std::string& s, const KeyFormatter& keyFormatter = - DefaultKeyFormatter) const { + void print(const std::string& s, const KeyFormatter& keyFormatter = + DefaultKeyFormatter) const override { std::cout << s << "PoseToPointFactor(" << keyFormatter(this->key1()) << "," << keyFormatter(this->key2()) << ")\n" << " measured: " << measured_.transpose() << std::endl; @@ -50,30 +54,31 @@ class PoseToPointFactor : public NoiseModelFactor2 { } /** equals */ - virtual bool equals(const NonlinearFactor& expected, - double tol = 1e-9) const { + bool equals(const NonlinearFactor& expected, + double tol = 1e-9) const override { const This* e = dynamic_cast(&expected); return e != nullptr && Base::equals(*e, tol) && - traits::Equals(this->measured_, e->measured_, tol); + traits::Equals(this->measured_, e->measured_, tol); } /** implement functions needed to derive from Factor */ /** vector of errors - * @brief Error = wTwi.inverse()*wPwp - measured_ - * @param wTwi The pose of the sensor in world coordinates - * @param wPwp The estimated point location in world coordinates + * @brief Error = w_T_b.inverse()*w_P - measured_ + * @param w_T_b The pose of the body in world coordinates + * @param w_P The estimated point location in world coordinates * * Note: measured_ and the error are in local coordiantes. */ - Vector evaluateError(const Pose3& wTwi, const Point3& wPwp, - boost::optional H1 = boost::none, - boost::optional H2 = boost::none) const { - return wTwi.transformTo(wPwp, H1, H2) - measured_; + Vector evaluateError( + const POSE& w_T_b, const POINT& w_P, + boost::optional H1 = boost::none, + boost::optional H2 = boost::none) const override { + return w_T_b.transformTo(w_P, H1, H2) - measured_; } /** return the measured */ - const Point3& measured() const { return measured_; } + const POINT& measured() const { return measured_; } private: /** Serialization function */ diff --git a/gtsam_unstable/slam/ProjectionFactorPPPC.h b/gtsam_unstable/slam/ProjectionFactorPPPC.h index fbc11503c..18ee13b9a 100644 --- a/gtsam_unstable/slam/ProjectionFactorPPPC.h +++ b/gtsam_unstable/slam/ProjectionFactorPPPC.h @@ -18,9 +18,11 @@ #pragma once -#include -#include #include +#include +#include +#include + #include namespace gtsam { @@ -30,60 +32,50 @@ namespace gtsam { * estimates the body pose, body-camera transform, 3D landmark, and calibration. * @addtogroup SLAM */ - template - class ProjectionFactorPPPC: public NoiseModelFactor4 { - protected: +template +class GTSAM_UNSTABLE_EXPORT ProjectionFactorPPPC + : public NoiseModelFactor4 { + protected: + Point2 measured_; ///< 2D measurement - Point2 measured_; ///< 2D measurement + // verbosity handling for Cheirality Exceptions + bool throwCheirality_; ///< If true, rethrows Cheirality exceptions (default: false) + bool verboseCheirality_; ///< If true, prints text for Cheirality exceptions (default: false) - // verbosity handling for Cheirality Exceptions - bool throwCheirality_; ///< If true, rethrows Cheirality exceptions (default: false) - bool verboseCheirality_; ///< If true, prints text for Cheirality exceptions (default: false) + public: + /// shorthand for base class type + typedef NoiseModelFactor4 Base; - public: + /// shorthand for this class + typedef ProjectionFactorPPPC This; - /// shorthand for base class type - typedef NoiseModelFactor4 Base; + /// shorthand for a smart pointer to a factor + typedef boost::shared_ptr shared_ptr; - /// shorthand for this class - typedef ProjectionFactorPPPC This; - - /// shorthand for a smart pointer to a factor - typedef boost::shared_ptr shared_ptr; - - /// Default constructor + /// Default constructor ProjectionFactorPPPC() : measured_(0.0, 0.0), throwCheirality_(false), verboseCheirality_(false) { } - /** - * Constructor - * TODO: Mark argument order standard (keys, measurement, parameters) - * @param measured is the 2 dimensional location of point in image (the measurement) - * @param model is the standard deviation - * @param poseKey is the index of the camera - * @param pointKey is the index of the landmark - * @param K shared pointer to the constant calibration - */ - ProjectionFactorPPPC(const Point2& measured, const SharedNoiseModel& model, - Key poseKey, Key transformKey, Key pointKey, Key calibKey) : - Base(model, poseKey, transformKey, pointKey, calibKey), measured_(measured), - throwCheirality_(false), verboseCheirality_(false) {} /** * Constructor with exception-handling flags * TODO: Mark argument order standard (keys, measurement, parameters) - * @param measured is the 2 dimensional location of point in image (the measurement) + * @param measured is the 2 dimensional location of point in image (the + * measurement) * @param model is the standard deviation * @param poseKey is the index of the camera + * @param transformKey is the index of the extrinsic calibration * @param pointKey is the index of the landmark - * @param K shared pointer to the constant calibration - * @param throwCheirality determines whether Cheirality exceptions are rethrown - * @param verboseCheirality determines whether exceptions are printed for Cheirality + * @param calibKey is the index of the intrinsic calibration + * @param throwCheirality determines whether Cheirality exceptions are + * rethrown + * @param verboseCheirality determines whether exceptions are printed for + * Cheirality */ ProjectionFactorPPPC(const Point2& measured, const SharedNoiseModel& model, Key poseKey, Key transformKey, Key pointKey, Key calibKey, - bool throwCheirality, bool verboseCheirality) : + bool throwCheirality = false, bool verboseCheirality = false) : Base(model, poseKey, transformKey, pointKey, calibKey), measured_(measured), throwCheirality_(throwCheirality), verboseCheirality_(verboseCheirality) {} @@ -123,8 +115,8 @@ namespace gtsam { try { if(H1 || H2 || H3 || H4) { Matrix H0, H02; - PinholeCamera camera(pose.compose(transform, H0, H02), K); - Point2 reprojectionError(camera.project(point, H1, H3, H4) - measured_); + const PinholeCamera camera(pose.compose(transform, H0, H02), K); + const Point2 reprojectionError(camera.project(point, H1, H3, H4) - measured_); *H2 = *H1 * H02; *H1 = *H1 * H0; return reprojectionError; @@ -168,7 +160,7 @@ namespace gtsam { ar & BOOST_SERIALIZATION_NVP(throwCheirality_); ar & BOOST_SERIALIZATION_NVP(verboseCheirality_); } - }; +}; /// traits template diff --git a/gtsam_unstable/slam/ProjectionFactorRollingShutter.h b/gtsam_unstable/slam/ProjectionFactorRollingShutter.h index c92653c13..2aeaa4824 100644 --- a/gtsam_unstable/slam/ProjectionFactorRollingShutter.h +++ b/gtsam_unstable/slam/ProjectionFactorRollingShutter.h @@ -21,6 +21,7 @@ #include #include #include +#include #include @@ -40,7 +41,7 @@ namespace gtsam { * @addtogroup SLAM */ -class ProjectionFactorRollingShutter +class GTSAM_UNSTABLE_EXPORT ProjectionFactorRollingShutter : public NoiseModelFactor3 { protected: // Keep a copy of measurement and calibration for I/O diff --git a/gtsam_unstable/slam/README.md b/gtsam_unstable/slam/README.md new file mode 100644 index 000000000..9aa0fed78 --- /dev/null +++ b/gtsam_unstable/slam/README.md @@ -0,0 +1,40 @@ +# SLAM Factors + +## SmartFactors + +These are "structure-less" factors, i.e., rather than introducing a new variable for an observed 3D point or landmark, a single factor is created that provides a multi-view constraint on several poses and/or cameras. + +### SmartRangeFactor + +An experiment in creating a structure-less 2D range-SLAM factor with range-only measurements. +It uses a sophisticated `triangulate` logic based on circle intersections. + +### SmartStereoProjectionFactor + +Version of `SmartProjectionFactor` for stereo observations, specializes SmartFactorBase for `CAMERA == StereoCamera`. + +TODO: a lot of commented out code and could move a lot to .cpp file. + +### SmartStereoProjectionPoseFactor + +Derives from `SmartStereoProjectionFactor` but adds an array of `Cal3_S2Stereo` calibration objects . + +TODO: Again, as no template arguments, we could move a lot to .cpp file. + +### SmartStereoProjectionFactorPP + +Similar `SmartStereoProjectionPoseFactor` but *additionally* adds an array of body_P_cam poses. The dimensions seem to be hardcoded and the types defined in the SmartFactorBase have been re-defined. +The body_P_cam poses are optimized here! + +TODO: See above, same issues as `SmartStereoProjectionPoseFactor`. + +### SmartProjectionPoseFactorRollingShutter + +Is templated on a `CAMERA` type and derives from `SmartProjectionFactor`. + +This factor optimizes two consecutive poses of a body assuming a rolling +shutter model of the camera with given readout time. The factor requires that +values contain (for each 2D observation) two consecutive camera poses from +which the 2D observation pose can be interpolated. + +TODO: the dimensions seem to be hardcoded and the types defined in the SmartFactorBase have been re-defined. Also, possibly a lot of copy/paste computation of things that (should) happen in base class. \ No newline at end of file diff --git a/gtsam_unstable/slam/SmartProjectionPoseFactorRollingShutter.h b/gtsam_unstable/slam/SmartProjectionPoseFactorRollingShutter.h index 7660ff236..ff84fcd16 100644 --- a/gtsam_unstable/slam/SmartProjectionPoseFactorRollingShutter.h +++ b/gtsam_unstable/slam/SmartProjectionPoseFactorRollingShutter.h @@ -20,6 +20,7 @@ #include #include +#include namespace gtsam { /** @@ -41,15 +42,16 @@ namespace gtsam { * @addtogroup SLAM */ template -class SmartProjectionPoseFactorRollingShutter +class GTSAM_UNSTABLE_EXPORT SmartProjectionPoseFactorRollingShutter : public SmartProjectionFactor { - public: + private: + typedef SmartProjectionFactor Base; + typedef SmartProjectionPoseFactorRollingShutter This; typedef typename CAMERA::CalibrationType CALIBRATION; + typedef typename CAMERA::Measurement MEASUREMENT; + typedef typename CAMERA::MeasurementVector MEASUREMENTS; protected: - /// shared pointer to calibration object (one for each observation) - std::vector> K_all_; - /// The keys of the pose of the body (with respect to an external world /// frame): two consecutive poses for each observation std::vector> world_P_body_key_pairs_; @@ -58,21 +60,17 @@ class SmartProjectionPoseFactorRollingShutter /// pair of consecutive poses std::vector alphas_; - /// Pose of the camera in the body frame - std::vector body_P_sensors_; + /// one or more cameras taking observations (fixed poses wrt body + fixed + /// intrinsics) + boost::shared_ptr cameraRig_; + + /// vector of camera Ids (one for each observation, in the same order), + /// identifying which camera took the measurement + FastVector cameraIds_; public: EIGEN_MAKE_ALIGNED_OPERATOR_NEW - /// shorthand for base class type - typedef SmartProjectionFactor> Base; - - /// shorthand for this class - typedef SmartProjectionPoseFactorRollingShutter This; - - /// shorthand for a smart pointer to a factor - typedef boost::shared_ptr shared_ptr; - static const int DimBlock = 12; ///< size of the variable stacking 2 poses from which the observation ///< pose is interpolated @@ -83,22 +81,43 @@ class SmartProjectionPoseFactorRollingShutter typedef std::vector> FBlocks; // vector of F blocks + typedef CAMERA Camera; + typedef CameraSet Cameras; + + /// shorthand for a smart pointer to a factor + typedef boost::shared_ptr shared_ptr; + + /// Default constructor, only for serialization + SmartProjectionPoseFactorRollingShutter() {} + /** * Constructor * @param Isotropic measurement noise + * @param cameraRig set of cameras (fixed poses wrt body and intrinsics) + * taking the measurements * @param params internal parameters of the smart factors */ SmartProjectionPoseFactorRollingShutter( const SharedNoiseModel& sharedNoiseModel, + const boost::shared_ptr& cameraRig, const SmartProjectionParams& params = SmartProjectionParams()) - : Base(sharedNoiseModel, params) {} + : Base(sharedNoiseModel, params), cameraRig_(cameraRig) { + // throw exception if configuration is not supported by this factor + if (Base::params_.degeneracyMode != gtsam::ZERO_ON_DEGENERACY) + throw std::runtime_error( + "SmartProjectionRigFactor: " + "degeneracyMode must be set to ZERO_ON_DEGENERACY"); + if (Base::params_.linearizationMode != gtsam::HESSIAN) + throw std::runtime_error( + "SmartProjectionRigFactor: " + "linearizationMode must be set to HESSIAN"); + } /** Virtual destructor */ ~SmartProjectionPoseFactorRollingShutter() override = default; /** - * add a new measurement, with 2 pose keys, interpolation factor, camera - * (intrinsic and extrinsic) calibration, and observed pixel. + * add a new measurement, with 2 pose keys, interpolation factor, and cameraId * @param measured 2-dimensional location of the projection of a single * landmark in a single view (the measurement), interpolated from the 2 poses * @param world_P_body_key1 key corresponding to the first body poses (time <= @@ -107,13 +126,11 @@ class SmartProjectionPoseFactorRollingShutter * >= time pixel is acquired) * @param alpha interpolation factor in [0,1], such that if alpha = 0 the * interpolated pose is the same as world_P_body_key1 - * @param K (fixed) camera intrinsic calibration - * @param body_P_sensor (fixed) camera extrinsic calibration + * @param cameraId ID of the camera taking the measurement (default 0) */ - void add(const Point2& measured, const Key& world_P_body_key1, + void add(const MEASUREMENT& measured, const Key& world_P_body_key1, const Key& world_P_body_key2, const double& alpha, - const boost::shared_ptr& K, - const Pose3& body_P_sensor = Pose3::identity()) { + const size_t& cameraId = 0) { // store measurements in base class this->measured_.push_back(measured); @@ -133,11 +150,8 @@ class SmartProjectionPoseFactorRollingShutter // store interpolation factor alphas_.push_back(alpha); - // store fixed intrinsic calibration - K_all_.push_back(K); - - // store fixed extrinsics of the camera - body_P_sensors_.push_back(body_P_sensor); + // store id of the camera taking the measurement + cameraIds_.push_back(cameraId); } /** @@ -150,56 +164,36 @@ class SmartProjectionPoseFactorRollingShutter * for the i0-th measurement can be interpolated * @param alphas vector of interpolation params (in [0,1]), one for each * measurement (in the same order) - * @param Ks vector of (fixed) intrinsic calibration objects - * @param body_P_sensors vector of (fixed) extrinsic calibration objects + * @param cameraIds IDs of the cameras taking each measurement (same order as + * the measurements) */ - void add(const Point2Vector& measurements, + void add(const MEASUREMENTS& measurements, const std::vector>& world_P_body_key_pairs, const std::vector& alphas, - const std::vector>& Ks, - const std::vector& body_P_sensors) { - assert(world_P_body_key_pairs.size() == measurements.size()); - assert(world_P_body_key_pairs.size() == alphas.size()); - assert(world_P_body_key_pairs.size() == Ks.size()); + const FastVector& cameraIds = FastVector()) { + if (world_P_body_key_pairs.size() != measurements.size() || + world_P_body_key_pairs.size() != alphas.size() || + (world_P_body_key_pairs.size() != cameraIds.size() && + cameraIds.size() != 0)) { // cameraIds.size()=0 is default + throw std::runtime_error( + "SmartProjectionPoseFactorRollingShutter: " + "trying to add inconsistent inputs"); + } + if (cameraIds.size() == 0 && cameraRig_->size() > 1) { + throw std::runtime_error( + "SmartProjectionPoseFactorRollingShutter: " + "camera rig includes multiple camera " + "but add did not input cameraIds"); + } for (size_t i = 0; i < measurements.size(); i++) { add(measurements[i], world_P_body_key_pairs[i].first, - world_P_body_key_pairs[i].second, alphas[i], Ks[i], - body_P_sensors[i]); + world_P_body_key_pairs[i].second, alphas[i], + cameraIds.size() == 0 ? 0 + : cameraIds[i]); // use 0 as default if + // cameraIds was not specified } } - /** - * Variant of the previous "add" function in which we include multiple - * measurements with the same (intrinsic and extrinsic) calibration - * @param measurements vector of the 2m dimensional location of the projection - * of a single landmark in the m views (the measurements) - * @param world_P_body_key_pairs vector where the i-th element contains a pair - * of keys corresponding to the pair of poses from which the observation pose - * for the i0-th measurement can be interpolated - * @param alphas vector of interpolation params (in [0,1]), one for each - * measurement (in the same order) - * @param K (fixed) camera intrinsic calibration (same for all measurements) - * @param body_P_sensor (fixed) camera extrinsic calibration (same for all - * measurements) - */ - void add(const Point2Vector& measurements, - const std::vector>& world_P_body_key_pairs, - const std::vector& alphas, - const boost::shared_ptr& K, - const Pose3& body_P_sensor = Pose3::identity()) { - assert(world_P_body_key_pairs.size() == measurements.size()); - assert(world_P_body_key_pairs.size() == alphas.size()); - for (size_t i = 0; i < measurements.size(); i++) { - add(measurements[i], world_P_body_key_pairs[i].first, - world_P_body_key_pairs[i].second, alphas[i], K, body_P_sensor); - } - } - - /// return the calibration object - const std::vector>& calibration() const { - return K_all_; - } - /// return (for each observation) the keys of the pair of poses from which we /// interpolate const std::vector>& world_P_body_key_pairs() const { @@ -209,8 +203,11 @@ class SmartProjectionPoseFactorRollingShutter /// return the interpolation factors alphas const std::vector& alphas() const { return alphas_; } - /// return the extrinsic camera calibration body_P_sensors - const std::vector& body_P_sensors() const { return body_P_sensors_; } + /// return the calibration object + const boost::shared_ptr& cameraRig() const { return cameraRig_; } + + /// return the calibration object + const FastVector& cameraIds() const { return cameraIds_; } /** * print @@ -221,15 +218,15 @@ class SmartProjectionPoseFactorRollingShutter const std::string& s = "", const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override { std::cout << s << "SmartProjectionPoseFactorRollingShutter: \n "; - for (size_t i = 0; i < K_all_.size(); i++) { + for (size_t i = 0; i < cameraIds_.size(); i++) { std::cout << "-- Measurement nr " << i << std::endl; std::cout << " pose1 key: " << keyFormatter(world_P_body_key_pairs_[i].first) << std::endl; std::cout << " pose2 key: " << keyFormatter(world_P_body_key_pairs_[i].second) << std::endl; std::cout << " alpha: " << alphas_[i] << std::endl; - body_P_sensors_[i].print("extrinsic calibration:\n"); - K_all_[i]->print("intrinsic calibration = "); + std::cout << "cameraId: " << cameraIds_[i] << std::endl; + (*cameraRig_)[cameraIds_[i]].print("camera in rig:\n"); } Base::print("", keyFormatter); } @@ -257,20 +254,48 @@ class SmartProjectionPoseFactorRollingShutter keyPairsEqual = false; } - double extrinsicCalibrationEqual = true; - if (this->body_P_sensors_.size() == e->body_P_sensors().size()) { - for (size_t i = 0; i < this->body_P_sensors_.size(); i++) { - if (!body_P_sensors_[i].equals(e->body_P_sensors()[i])) { - extrinsicCalibrationEqual = false; - break; - } - } - } else { - extrinsicCalibrationEqual = false; - } + return e && Base::equals(p, tol) && alphas_ == e->alphas() && + keyPairsEqual && cameraRig_->equals(*(e->cameraRig())) && + std::equal(cameraIds_.begin(), cameraIds_.end(), + e->cameraIds().begin()); + } - return e && Base::equals(p, tol) && K_all_ == e->calibration() && - alphas_ == e->alphas() && keyPairsEqual && extrinsicCalibrationEqual; + /** + * Collect all cameras involved in this factor + * @param values Values structure which must contain camera poses + * corresponding to keys involved in this factor + * @return Cameras + */ + typename Base::Cameras cameras(const Values& values) const override { + typename Base::Cameras cameras; + for (size_t i = 0; i < this->measured_.size(); + i++) { // for each measurement + const Pose3& w_P_body1 = + values.at(world_P_body_key_pairs_[i].first); + const Pose3& w_P_body2 = + values.at(world_P_body_key_pairs_[i].second); + double interpolationFactor = alphas_[i]; + const Pose3& w_P_body = + interpolate(w_P_body1, w_P_body2, interpolationFactor); + const typename Base::Camera& camera_i = (*cameraRig_)[cameraIds_[i]]; + const Pose3& body_P_cam = camera_i.pose(); + const Pose3& w_P_cam = w_P_body.compose(body_P_cam); + cameras.emplace_back(w_P_cam, + make_shared( + camera_i.calibration())); + } + return cameras; + } + + /** + * error calculates the error of the factor. + */ + double error(const Values& values) const override { + if (this->active(values)) { + return this->totalReprojectionError(this->cameras(values)); + } else { // else of active flag + return 0.0; + } } /** @@ -305,14 +330,16 @@ class SmartProjectionPoseFactorRollingShutter auto w_P_body = interpolate(w_P_body1, w_P_body2, interpolationFactor, dInterpPose_dPoseBody1, dInterpPose_dPoseBody2); - auto body_P_cam = body_P_sensors_[i]; + const typename Base::Camera& camera_i = (*cameraRig_)[cameraIds_[i]]; + auto body_P_cam = camera_i.pose(); auto w_P_cam = w_P_body.compose(body_P_cam, dPoseCam_dInterpPose); - PinholeCamera camera(w_P_cam, *K_all_[i]); + typename Base::Camera camera( + w_P_cam, make_shared( + camera_i.calibration())); // get jacobians and error vector for current measurement - Point2 reprojectionError_i = - Point2(camera.project(*this->result_, dProject_dPoseCam, Ei) - - this->measured_.at(i)); + Point2 reprojectionError_i = camera.reprojectionError( + *this->result_, this->measured_.at(i), dProject_dPoseCam, Ei); Eigen::Matrix J; // 2 x 12 J.block(0, 0, ZDim, 6) = dProject_dPoseCam * dPoseCam_dInterpPose * @@ -332,7 +359,7 @@ class SmartProjectionPoseFactorRollingShutter /// linearize and return a Hessianfactor that is an approximation of error(p) boost::shared_ptr> createHessianFactor( - const Values& values, const double lambda = 0.0, + const Values& values, const double& lambda = 0.0, bool diagonalDamping = false) const { // we may have multiple observation sharing the same keys (due to the // rolling shutter interpolation), hence the number of unique keys may be @@ -341,19 +368,21 @@ class SmartProjectionPoseFactorRollingShutter this->keys_ .size(); // note: by construction, keys_ only contains unique keys + typename Base::Cameras cameras = this->cameras(values); + // Create structures for Hessian Factors KeyVector js; std::vector Gs(nrUniqueKeys * (nrUniqueKeys + 1) / 2); std::vector gs(nrUniqueKeys); if (this->measured_.size() != - this->cameras(values).size()) // 1 observation per interpolated camera + cameras.size()) // 1 observation per interpolated camera throw std::runtime_error( "SmartProjectionPoseFactorRollingShutter: " "measured_.size() inconsistent with input"); // triangulate 3D point at given linearization point - this->triangulateSafe(this->cameras(values)); + this->triangulateSafe(cameras); if (!this->result_) { // failed: return "empty/zero" Hessian if (this->params_.degeneracyMode == ZERO_ON_DEGENERACY) { @@ -378,7 +407,7 @@ class SmartProjectionPoseFactorRollingShutter for (size_t i = 0; i < Fs.size(); i++) Fs[i] = this->noiseModel_->Whiten(Fs[i]); - Matrix3 P = Base::Cameras::PointCov(E, lambda, diagonalDamping); + Matrix3 P = Cameras::PointCov(E, lambda, diagonalDamping); // Collect all the key pairs: these are the keys that correspond to the // blocks in Fs (on which we apply the Schur Complement) @@ -399,46 +428,6 @@ class SmartProjectionPoseFactorRollingShutter this->keys_, augmentedHessianUniqueKeys); } - /** - * error calculates the error of the factor. - */ - double error(const Values& values) const override { - if (this->active(values)) { - return this->totalReprojectionError(this->cameras(values)); - } else { // else of active flag - return 0.0; - } - } - - /** - * Collect all cameras involved in this factor - * @param values Values structure which must contain camera poses - * corresponding to keys involved in this factor - * @return Cameras - */ - typename Base::Cameras cameras(const Values& values) const override { - size_t numViews = this->measured_.size(); - assert(numViews == K_all_.size()); - assert(numViews == alphas_.size()); - assert(numViews == body_P_sensors_.size()); - assert(numViews == world_P_body_key_pairs_.size()); - - typename Base::Cameras cameras; - for (size_t i = 0; i < numViews; i++) { // for each measurement - const Pose3& w_P_body1 = - values.at(world_P_body_key_pairs_[i].first); - const Pose3& w_P_body2 = - values.at(world_P_body_key_pairs_[i].second); - double interpolationFactor = alphas_[i]; - const Pose3& w_P_body = - interpolate(w_P_body1, w_P_body2, interpolationFactor); - const Pose3& body_P_cam = body_P_sensors_[i]; - const Pose3& w_P_cam = w_P_body.compose(body_P_cam); - cameras.emplace_back(w_P_cam, K_all_[i]); - } - return cameras; - } - /** * Linearize to Gaussian Factor (possibly adding a damping factor Lambda for * LM) @@ -447,7 +436,7 @@ class SmartProjectionPoseFactorRollingShutter * @return a Gaussian factor */ boost::shared_ptr linearizeDamped( - const Values& values, const double lambda = 0.0) const { + const Values& values, const double& lambda = 0.0) const { // depending on flag set on construction we may linearize to different // linear factors switch (this->params_.linearizationMode) { @@ -455,8 +444,8 @@ class SmartProjectionPoseFactorRollingShutter return this->createHessianFactor(values, lambda); default: throw std::runtime_error( - "SmartProjectionPoseFactorRollingShutter: unknown linearization " - "mode"); + "SmartProjectionPoseFactorRollingShutter: " + "unknown linearization mode"); } } @@ -472,7 +461,6 @@ class SmartProjectionPoseFactorRollingShutter template void serialize(ARCHIVE& ar, const unsigned int /*version*/) { ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); - ar& BOOST_SERIALIZATION_NVP(K_all_); } }; // end of class declaration diff --git a/gtsam_unstable/slam/SmartStereoProjectionFactor.h b/gtsam_unstable/slam/SmartStereoProjectionFactor.h index 52fd99356..61f110d3a 100644 --- a/gtsam_unstable/slam/SmartStereoProjectionFactor.h +++ b/gtsam_unstable/slam/SmartStereoProjectionFactor.h @@ -11,7 +11,7 @@ /** * @file SmartStereoProjectionFactor.h - * @brief Smart stereo factor on StereoCameras (pose + calibration) + * @brief Smart stereo factor on StereoCameras (pose) * @author Luca Carlone * @author Zsolt Kira * @author Frank Dellaert @@ -20,18 +20,18 @@ #pragma once -#include -#include - -#include #include #include -#include +#include #include +#include +#include +#include #include +#include -#include #include +#include #include namespace gtsam { @@ -49,8 +49,9 @@ typedef SmartProjectionParams SmartStereoProjectionParams; * If you'd like to store poses in values instead of cameras, use * SmartStereoProjectionPoseFactor instead */ -class SmartStereoProjectionFactor: public SmartFactorBase { -private: +class SmartStereoProjectionFactor + : public SmartFactorBase { + private: typedef SmartFactorBase Base; @@ -447,23 +448,23 @@ public: } /** - * This corrects the Jacobians and error vector for the case in which the right pixel in the monocular camera is missing (nan) + * This corrects the Jacobians and error vector for the case in which the + * right 2D measurement in the monocular camera is missing (nan). */ - void correctForMissingMeasurements(const Cameras& cameras, Vector& ue, - boost::optional Fs = boost::none, - boost::optional E = boost::none) const override - { + void correctForMissingMeasurements( + const Cameras& cameras, Vector& ue, + boost::optional Fs = boost::none, + boost::optional E = boost::none) const override { // when using stereo cameras, some of the measurements might be missing: - for(size_t i=0; i < cameras.size(); i++){ + for (size_t i = 0; i < cameras.size(); i++) { const StereoPoint2& z = measured_.at(i); - if(std::isnan(z.uR())) // if the right pixel is invalid + if (std::isnan(z.uR())) // if the right 2D measurement is invalid { - if(Fs){ // delete influence of right point on jacobian Fs + if (Fs) { // delete influence of right point on jacobian Fs MatrixZD& Fi = Fs->at(i); - for(size_t ii=0; iirow(ZDim * i + 1) = Matrix::Zero(1, E->cols()); // set the corresponding entry of vector ue to zero diff --git a/gtsam_unstable/slam/SmartStereoProjectionFactorPP.h b/gtsam_unstable/slam/SmartStereoProjectionFactorPP.h index 25be48b0f..e20241a0e 100644 --- a/gtsam_unstable/slam/SmartStereoProjectionFactorPP.h +++ b/gtsam_unstable/slam/SmartStereoProjectionFactorPP.h @@ -33,12 +33,15 @@ namespace gtsam { */ /** - * This factor optimizes the pose of the body as well as the extrinsic camera calibration (pose of camera wrt body). - * Each camera may have its own extrinsic calibration or the same calibration can be shared by multiple cameras. - * This factor requires that values contain the involved poses and extrinsics (both are Pose3 variables). + * This factor optimizes the pose of the body as well as the extrinsic camera + * calibration (pose of camera wrt body). Each camera may have its own extrinsic + * calibration or the same calibration can be shared by multiple cameras. This + * factor requires that values contain the involved poses and extrinsics (both + * are Pose3 variables). * @addtogroup SLAM */ -class SmartStereoProjectionFactorPP : public SmartStereoProjectionFactor { +class GTSAM_UNSTABLE_EXPORT SmartStereoProjectionFactorPP + : public SmartStereoProjectionFactor { protected: /// shared pointer to calibration object (one for each camera) std::vector> K_all_; @@ -292,7 +295,6 @@ class SmartStereoProjectionFactorPP : public SmartStereoProjectionFactor { ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); ar & BOOST_SERIALIZATION_NVP(K_all_); } - }; // end of class declaration diff --git a/gtsam_unstable/slam/SmartStereoProjectionPoseFactor.h b/gtsam_unstable/slam/SmartStereoProjectionPoseFactor.h index 2a8180ac5..a46000a68 100644 --- a/gtsam_unstable/slam/SmartStereoProjectionPoseFactor.h +++ b/gtsam_unstable/slam/SmartStereoProjectionPoseFactor.h @@ -43,7 +43,8 @@ namespace gtsam { * This factor requires that values contains the involved poses (Pose3). * @addtogroup SLAM */ -class SmartStereoProjectionPoseFactor : public SmartStereoProjectionFactor { +class GTSAM_UNSTABLE_EXPORT SmartStereoProjectionPoseFactor + : public SmartStereoProjectionFactor { protected: /// shared pointer to calibration object (one for each camera) std::vector> K_all_; diff --git a/gtsam_unstable/slam/serialization.cpp b/gtsam_unstable/slam/serialization.cpp index 88a94fd51..d87ca6f2d 100644 --- a/gtsam_unstable/slam/serialization.cpp +++ b/gtsam_unstable/slam/serialization.cpp @@ -5,8 +5,6 @@ * @author Alex Cunningham */ -#include -#include #include #include @@ -31,8 +29,6 @@ using namespace gtsam; // Creating as many permutations of factors as possible -typedef PriorFactor PriorFactorLieVector; -typedef PriorFactor PriorFactorLieMatrix; typedef PriorFactor PriorFactorPoint2; typedef PriorFactor PriorFactorStereoPoint2; typedef PriorFactor PriorFactorPoint3; @@ -46,8 +42,6 @@ typedef PriorFactor PriorFactorCalibratedCamera; typedef PriorFactor PriorFactorPinholeCameraCal3_S2; typedef PriorFactor PriorFactorStereoCamera; -typedef BetweenFactor BetweenFactorLieVector; -typedef BetweenFactor BetweenFactorLieMatrix; typedef BetweenFactor BetweenFactorPoint2; typedef BetweenFactor BetweenFactorPoint3; typedef BetweenFactor BetweenFactorRot2; @@ -55,8 +49,6 @@ typedef BetweenFactor BetweenFactorRot3; typedef BetweenFactor BetweenFactorPose2; typedef BetweenFactor BetweenFactorPose3; -typedef NonlinearEquality NonlinearEqualityLieVector; -typedef NonlinearEquality NonlinearEqualityLieMatrix; typedef NonlinearEquality NonlinearEqualityPoint2; typedef NonlinearEquality NonlinearEqualityStereoPoint2; typedef NonlinearEquality NonlinearEqualityPoint3; @@ -112,8 +104,6 @@ BOOST_CLASS_EXPORT_GUID(gtsam::SharedDiagonal, "gtsam_SharedDiagonal"); /* Create GUIDs for geometry */ /* ************************************************************************* */ -GTSAM_VALUE_EXPORT(gtsam::LieVector); -GTSAM_VALUE_EXPORT(gtsam::LieMatrix); GTSAM_VALUE_EXPORT(gtsam::Point2); GTSAM_VALUE_EXPORT(gtsam::StereoPoint2); GTSAM_VALUE_EXPORT(gtsam::Point3); @@ -133,8 +123,6 @@ GTSAM_VALUE_EXPORT(gtsam::StereoCamera); BOOST_CLASS_EXPORT_GUID(gtsam::JacobianFactor, "gtsam::JacobianFactor"); BOOST_CLASS_EXPORT_GUID(gtsam::HessianFactor , "gtsam::HessianFactor"); -BOOST_CLASS_EXPORT_GUID(PriorFactorLieVector, "gtsam::PriorFactorLieVector"); -BOOST_CLASS_EXPORT_GUID(PriorFactorLieMatrix, "gtsam::PriorFactorLieMatrix"); BOOST_CLASS_EXPORT_GUID(PriorFactorPoint2, "gtsam::PriorFactorPoint2"); BOOST_CLASS_EXPORT_GUID(PriorFactorStereoPoint2, "gtsam::PriorFactorStereoPoint2"); BOOST_CLASS_EXPORT_GUID(PriorFactorPoint3, "gtsam::PriorFactorPoint3"); @@ -147,8 +135,6 @@ BOOST_CLASS_EXPORT_GUID(PriorFactorCal3DS2, "gtsam::PriorFactorCal3DS2"); BOOST_CLASS_EXPORT_GUID(PriorFactorCalibratedCamera, "gtsam::PriorFactorCalibratedCamera"); BOOST_CLASS_EXPORT_GUID(PriorFactorStereoCamera, "gtsam::PriorFactorStereoCamera"); -BOOST_CLASS_EXPORT_GUID(BetweenFactorLieVector, "gtsam::BetweenFactorLieVector"); -BOOST_CLASS_EXPORT_GUID(BetweenFactorLieMatrix, "gtsam::BetweenFactorLieMatrix"); BOOST_CLASS_EXPORT_GUID(BetweenFactorPoint2, "gtsam::BetweenFactorPoint2"); BOOST_CLASS_EXPORT_GUID(BetweenFactorPoint3, "gtsam::BetweenFactorPoint3"); BOOST_CLASS_EXPORT_GUID(BetweenFactorRot2, "gtsam::BetweenFactorRot2"); @@ -156,8 +142,6 @@ BOOST_CLASS_EXPORT_GUID(BetweenFactorRot3, "gtsam::BetweenFactorRot3"); BOOST_CLASS_EXPORT_GUID(BetweenFactorPose2, "gtsam::BetweenFactorPose2"); BOOST_CLASS_EXPORT_GUID(BetweenFactorPose3, "gtsam::BetweenFactorPose3"); -BOOST_CLASS_EXPORT_GUID(NonlinearEqualityLieVector, "gtsam::NonlinearEqualityLieVector"); -BOOST_CLASS_EXPORT_GUID(NonlinearEqualityLieMatrix, "gtsam::NonlinearEqualityLieMatrix"); BOOST_CLASS_EXPORT_GUID(NonlinearEqualityPoint2, "gtsam::NonlinearEqualityPoint2"); BOOST_CLASS_EXPORT_GUID(NonlinearEqualityStereoPoint2, "gtsam::NonlinearEqualityStereoPoint2"); BOOST_CLASS_EXPORT_GUID(NonlinearEqualityPoint3, "gtsam::NonlinearEqualityPoint3"); @@ -189,7 +173,7 @@ BOOST_CLASS_EXPORT_GUID(GeneralSFMFactor2Cal3_S2, "gtsam::GeneralSFMFactor2Cal3_ BOOST_CLASS_EXPORT_GUID(GenericStereoFactor3D, "gtsam::GenericStereoFactor3D"); -#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V41 +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 typedef PriorFactor PriorFactorSimpleCamera; typedef NonlinearEquality NonlinearEqualitySimpleCamera; diff --git a/gtsam_unstable/slam/tests/testBetweenFactorEM.cpp b/gtsam_unstable/slam/tests/testBetweenFactorEM.cpp index 4d6e1912a..f43ae293e 100644 --- a/gtsam_unstable/slam/tests/testBetweenFactorEM.cpp +++ b/gtsam_unstable/slam/tests/testBetweenFactorEM.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include @@ -21,26 +22,24 @@ using namespace gtsam; // Disabled this test because it is currently failing - remove the lines "#if 0" and "#endif" below // to reenable the test. -#if 0 +// #if 0 /* ************************************************************************* */ -LieVector predictionError(const Pose2& p1, const Pose2& p2, const gtsam::Key& key1, const gtsam::Key& key2, const BetweenFactorEM& factor){ +Vector predictionError(const Pose2& p1, const Pose2& p2, const gtsam::Key& key1, const gtsam::Key& key2, const BetweenFactorEM& factor){ gtsam::Values values; values.insert(key1, p1); values.insert(key2, p2); - // LieVector err = factor.whitenedError(values); - // return err; - return LieVector::Expmap(factor.whitenedError(values)); + return factor.whitenedError(values); } /* ************************************************************************* */ -LieVector predictionError_standard(const Pose2& p1, const Pose2& p2, const gtsam::Key& key1, const gtsam::Key& key2, const BetweenFactor& factor){ +Vector predictionError_standard(const Pose2& p1, const Pose2& p2, const gtsam::Key& key1, const gtsam::Key& key2, const BetweenFactor& factor){ gtsam::Values values; values.insert(key1, p1); values.insert(key2, p2); - // LieVector err = factor.whitenedError(values); + // Vector err = factor.whitenedError(values); // return err; - return LieVector::Expmap(factor.whitenedError(values)); + return factor.whitenedError(values); } /* ************************************************************************* */ @@ -99,8 +98,8 @@ TEST( BetweenFactorEM, EvaluateError) Vector actual_err_wh = f.whitenedError(values); - Vector actual_err_wh_inlier = (Vector(3) << actual_err_wh[0], actual_err_wh[1], actual_err_wh[2]); - Vector actual_err_wh_outlier = (Vector(3) << actual_err_wh[3], actual_err_wh[4], actual_err_wh[5]); + Vector3 actual_err_wh_inlier = Vector3(actual_err_wh[0], actual_err_wh[1], actual_err_wh[2]); + Vector3 actual_err_wh_outlier = Vector3(actual_err_wh[3], actual_err_wh[4], actual_err_wh[5]); // cout << "Inlier test. norm of actual_err_wh_inlier, actual_err_wh_outlier: "< h_EM(key1, key2, rel_pose_msr, model_inlier, model_outlier, prior_inlier, prior_outlier); actual_err_wh = h_EM.whitenedError(values); - actual_err_wh_inlier = (Vector(3) << actual_err_wh[0], actual_err_wh[1], actual_err_wh[2]); + actual_err_wh_inlier = Vector3(actual_err_wh[0], actual_err_wh[1], actual_err_wh[2]); BetweenFactor h(key1, key2, rel_pose_msr, model_inlier ); Vector actual_err_wh_stnd = h.whitenedError(values); @@ -178,7 +177,7 @@ TEST (BetweenFactorEM, jacobian ) { // compare to standard between factor BetweenFactor h(key1, key2, rel_pose_msr, model_inlier ); Vector actual_err_wh_stnd = h.whitenedError(values); - Vector actual_err_wh_inlier = (Vector(3) << actual_err_wh[0], actual_err_wh[1], actual_err_wh[2]); + Vector actual_err_wh_inlier = Vector3(actual_err_wh[0], actual_err_wh[1], actual_err_wh[2]); // CHECK( assert_equal(actual_err_wh_stnd, actual_err_wh_inlier, 1e-8)); std::vector H_actual_stnd_unwh(2); (void)h.unwhitenedError(values, H_actual_stnd_unwh); @@ -190,12 +189,13 @@ TEST (BetweenFactorEM, jacobian ) { // CHECK( assert_equal(H2_actual_stnd, H2_actual, 1e-8)); double stepsize = 1.0e-9; - Matrix H1_expected = gtsam::numericalDerivative11(std::bind(&predictionError, _1, p2, key1, key2, f), p1, stepsize); - Matrix H2_expected = gtsam::numericalDerivative11(std::bind(&predictionError, p1, _1, key1, key2, f), p2, stepsize); + using std::placeholders::_1; + Matrix H1_expected = gtsam::numericalDerivative11(std::bind(&predictionError, _1, p2, key1, key2, f), p1, stepsize); + Matrix H2_expected = gtsam::numericalDerivative11(std::bind(&predictionError, p1, _1, key1, key2, f), p2, stepsize); // try to check numerical derivatives of a standard between factor - Matrix H1_expected_stnd = gtsam::numericalDerivative11(std::bind(&predictionError_standard, _1, p2, key1, key2, h), p1, stepsize); + Matrix H1_expected_stnd = gtsam::numericalDerivative11(std::bind(&predictionError_standard, _1, p2, key1, key2, h), p1, stepsize); // CHECK( assert_equal(H1_expected_stnd, H1_actual_stnd, 1e-5)); // // @@ -240,8 +240,8 @@ TEST( BetweenFactorEM, CaseStudy) Vector actual_err_unw = f.unwhitenedError(values); Vector actual_err_wh = f.whitenedError(values); - Vector actual_err_wh_inlier = (Vector(3) << actual_err_wh[0], actual_err_wh[1], actual_err_wh[2]); - Vector actual_err_wh_outlier = (Vector(3) << actual_err_wh[3], actual_err_wh[4], actual_err_wh[5]); + Vector3 actual_err_wh_inlier = Vector3(actual_err_wh[0], actual_err_wh[1], actual_err_wh[2]); + Vector3 actual_err_wh_outlier = Vector3(actual_err_wh[3], actual_err_wh[4], actual_err_wh[5]); if (debug){ cout << "p_inlier_outler: "<print("model_inlier:"); - model_outlier->print("model_outlier:"); - model_inlier_new->print("model_inlier_new:"); - model_outlier_new->print("model_outlier_new:"); + // model_inlier->print("model_inlier:"); + // model_outlier->print("model_outlier:"); + // model_inlier_new->print("model_inlier_new:"); + // model_outlier_new->print("model_outlier_new:"); } -#endif +// #endif /* ************************************************************************* */ int main() { TestResult tr; return TestRegistry::runAllTests(tr);} diff --git a/gtsam_unstable/slam/tests/testGaussMarkov1stOrderFactor.cpp b/gtsam_unstable/slam/tests/testGaussMarkov1stOrderFactor.cpp index 8692cf584..ed4092c60 100644 --- a/gtsam_unstable/slam/tests/testGaussMarkov1stOrderFactor.cpp +++ b/gtsam_unstable/slam/tests/testGaussMarkov1stOrderFactor.cpp @@ -16,22 +16,23 @@ * @date Jan 17, 2012 */ -#include -#include -#include -#include #include -#include +#include +#include +#include +#include +#include using namespace std::placeholders; using namespace std; using namespace gtsam; //! Factors -typedef GaussMarkov1stOrderFactor GaussMarkovFactor; +typedef GaussMarkov1stOrderFactor GaussMarkovFactor; /* ************************************************************************* */ -LieVector predictionError(const LieVector& v1, const LieVector& v2, const GaussMarkovFactor factor) { +Vector predictionError(const Vector& v1, const Vector& v2, + const GaussMarkovFactor factor) { return factor.evaluateError(v1, v2); } @@ -58,29 +59,29 @@ TEST( GaussMarkovFactor, error ) Key x1(1); Key x2(2); double delta_t = 0.10; - Vector tau = Vector3(100.0, 150.0, 10.0); + Vector3 tau(100.0, 150.0, 10.0); SharedGaussian model = noiseModel::Isotropic::Sigma(3, 1.0); - LieVector v1 = LieVector(Vector3(10.0, 12.0, 13.0)); - LieVector v2 = LieVector(Vector3(10.0, 15.0, 14.0)); + Vector3 v1(10.0, 12.0, 13.0); + Vector3 v2(10.0, 15.0, 14.0); // Create two nodes linPoint.insert(x1, v1); linPoint.insert(x2, v2); GaussMarkovFactor factor(x1, x2, delta_t, tau, model); - Vector Err1( factor.evaluateError(v1, v2) ); + Vector3 error1 = factor.evaluateError(v1, v2); // Manually calculate the error - Vector alpha(tau.size()); - Vector alpha_v1(tau.size()); + Vector3 alpha(tau.size()); + Vector3 alpha_v1(tau.size()); for(int i=0; i @@ -12,6 +22,7 @@ #include #include #include +#include #include @@ -28,6 +39,11 @@ PinholeCamera level_camera(level_pose, *K); typedef InvDepthFactor3 InverseDepthFactor; typedef NonlinearEquality PoseConstraint; +Matrix factorError(const Pose3& pose, const Vector5& point, double invDepth, + const InverseDepthFactor& factor) { + return factor.evaluateError(pose, point, invDepth); +} + /* ************************************************************************* */ TEST( InvDepthFactor, optimize) { @@ -92,6 +108,55 @@ TEST( InvDepthFactor, optimize) { } +/* ************************************************************************* */ +TEST( InvDepthFactor, Jacobian3D ) { + + // landmark 5 meters infront of camera (camera center at (0,0,1)) + Point3 landmark(5, 0, 1); + + // get expected projection using pinhole camera + Point2 expected_uv = level_camera.project(landmark); + + // get expected landmark representation using backprojection + double inv_depth; + Vector5 inv_landmark; + InvDepthCamera3 inv_camera(level_pose, K); + std::tie(inv_landmark, inv_depth) = inv_camera.backproject(expected_uv, 5); + Vector5 expected_inv_landmark((Vector(5) << 0., 0., 1., 0., 0.).finished()); + + CHECK(assert_equal(expected_inv_landmark, inv_landmark, 1e-6)); + CHECK(assert_equal(inv_depth, 1./5, 1e-6)); + + Symbol poseKey('x',1); + Symbol pointKey('l',1); + Symbol invDepthKey('d',1); + InverseDepthFactor factor(expected_uv, sigma, poseKey, pointKey, invDepthKey, K); + + std::vector actualHs(3); + factor.unwhitenedError({{poseKey, genericValue(level_pose)}, + {pointKey, genericValue(inv_landmark)}, + {invDepthKey,genericValue(inv_depth)}}, + actualHs); + + const Matrix& H1Actual = actualHs.at(0); + const Matrix& H2Actual = actualHs.at(1); + const Matrix& H3Actual = actualHs.at(2); + + // Use numerical derivatives to verify the Jacobians + Matrix H1Expected, H2Expected, H3Expected; + + std::function + func = std::bind(&factorError, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, factor); + H1Expected = numericalDerivative31(func, level_pose, inv_landmark, inv_depth); + H2Expected = numericalDerivative32(func, level_pose, inv_landmark, inv_depth); + H3Expected = numericalDerivative33(func, level_pose, inv_landmark, inv_depth); + + // Verify the Jacobians + CHECK(assert_equal(H1Expected, H1Actual, 1e-6)) + CHECK(assert_equal(H2Expected, H2Actual, 1e-6)) + CHECK(assert_equal(H3Expected, H3Actual, 1e-6)) +} + /* ************************************************************************* */ int main() { TestResult tr; return TestRegistry::runAllTests(tr);} /* ************************************************************************* */ diff --git a/gtsam_unstable/slam/tests/testPoseToPointFactor.cpp b/gtsam_unstable/slam/tests/testPoseToPointFactor.cpp new file mode 100644 index 000000000..5aaaaec53 --- /dev/null +++ b/gtsam_unstable/slam/tests/testPoseToPointFactor.cpp @@ -0,0 +1,161 @@ +/** + * @file testPoseToPointFactor.cpp + * @brief + * @author David Wisth + * @author Luca Carlone + * @date June 20, 2020 + */ + +#include +#include +#include + +using namespace gtsam; +using namespace gtsam::noiseModel; + +/* ************************************************************************* */ +// Verify zero error when there is no noise +TEST(PoseToPointFactor, errorNoiseless_2D) { + Pose2 pose = Pose2::identity(); + Point2 point(1.0, 2.0); + Point2 noise(0.0, 0.0); + Point2 measured = point + noise; + + Key pose_key(1); + Key point_key(2); + PoseToPointFactor factor(pose_key, point_key, measured, + Isotropic::Sigma(2, 0.05)); + Vector expectedError = Vector2(0.0, 0.0); + Vector actualError = factor.evaluateError(pose, point); + EXPECT(assert_equal(expectedError, actualError, 1E-5)); +} + +/* ************************************************************************* */ +// Verify expected error in test scenario +TEST(PoseToPointFactor, errorNoise_2D) { + Pose2 pose = Pose2::identity(); + Point2 point(1.0, 2.0); + Point2 noise(-1.0, 0.5); + Point2 measured = point + noise; + + Key pose_key(1); + Key point_key(2); + PoseToPointFactor factor(pose_key, point_key, measured, + Isotropic::Sigma(2, 0.05)); + Vector expectedError = -noise; + Vector actualError = factor.evaluateError(pose, point); + EXPECT(assert_equal(expectedError, actualError, 1E-5)); +} + +/* ************************************************************************* */ +// Check Jacobians are correct +TEST(PoseToPointFactor, jacobian_2D) { + // Measurement + gtsam::Point2 l_meas(1, 2); + + // Linearisation point + gtsam::Point2 p_t(-5, 12); + gtsam::Rot2 p_R(1.5 * M_PI); + Pose2 p(p_R, p_t); + + gtsam::Point2 l(3, 0); + + // Factor + Key pose_key(1); + Key point_key(2); + SharedGaussian noise = noiseModel::Diagonal::Sigmas(Vector2(0.1, 0.1)); + PoseToPointFactor factor(pose_key, point_key, l_meas, noise); + + // Calculate numerical derivatives + auto f = std::bind(&PoseToPointFactor::evaluateError, factor, + std::placeholders::_1, std::placeholders::_2, boost::none, + boost::none); + Matrix numerical_H1 = numericalDerivative21(f, p, l); + Matrix numerical_H2 = numericalDerivative22(f, p, l); + + // Use the factor to calculate the derivative + Matrix actual_H1; + Matrix actual_H2; + factor.evaluateError(p, l, actual_H1, actual_H2); + + // Verify we get the expected error + EXPECT(assert_equal(numerical_H1, actual_H1, 1e-8)); + EXPECT(assert_equal(numerical_H2, actual_H2, 1e-8)); +} + +/* ************************************************************************* */ +// Verify zero error when there is no noise +TEST(PoseToPointFactor, errorNoiseless_3D) { + Pose3 pose = Pose3::identity(); + Point3 point(1.0, 2.0, 3.0); + Point3 noise(0.0, 0.0, 0.0); + Point3 measured = point + noise; + + Key pose_key(1); + Key point_key(2); + PoseToPointFactor factor(pose_key, point_key, measured, + Isotropic::Sigma(3, 0.05)); + Vector expectedError = Vector3(0.0, 0.0, 0.0); + Vector actualError = factor.evaluateError(pose, point); + EXPECT(assert_equal(expectedError, actualError, 1E-5)); +} + +/* ************************************************************************* */ +// Verify expected error in test scenario +TEST(PoseToPointFactor, errorNoise_3D) { + Pose3 pose = Pose3::identity(); + Point3 point(1.0, 2.0, 3.0); + Point3 noise(-1.0, 0.5, 0.3); + Point3 measured = point + noise; + + Key pose_key(1); + Key point_key(2); + PoseToPointFactor factor(pose_key, point_key, measured, + Isotropic::Sigma(3, 0.05)); + Vector expectedError = -noise; + Vector actualError = factor.evaluateError(pose, point); + EXPECT(assert_equal(expectedError, actualError, 1E-5)); +} + +/* ************************************************************************* */ +// Check Jacobians are correct +TEST(PoseToPointFactor, jacobian_3D) { + // Measurement + gtsam::Point3 l_meas = gtsam::Point3(1, 2, 3); + + // Linearisation point + gtsam::Point3 p_t = gtsam::Point3(-5, 12, 2); + gtsam::Rot3 p_R = gtsam::Rot3::RzRyRx(1.5 * M_PI, -0.3 * M_PI, 0.4 * M_PI); + Pose3 p(p_R, p_t); + + gtsam::Point3 l = gtsam::Point3(3, 0, 5); + + // Factor + Key pose_key(1); + Key point_key(2); + SharedGaussian noise = noiseModel::Diagonal::Sigmas(Vector3(0.1, 0.1, 0.1)); + PoseToPointFactor factor(pose_key, point_key, l_meas, noise); + + // Calculate numerical derivatives + auto f = std::bind(&PoseToPointFactor::evaluateError, factor, + std::placeholders::_1, std::placeholders::_2, boost::none, + boost::none); + Matrix numerical_H1 = numericalDerivative21(f, p, l); + Matrix numerical_H2 = numericalDerivative22(f, p, l); + + // Use the factor to calculate the derivative + Matrix actual_H1; + Matrix actual_H2; + factor.evaluateError(p, l, actual_H1, actual_H2); + + // Verify we get the expected error + EXPECT(assert_equal(numerical_H1, actual_H1, 1e-8)); + EXPECT(assert_equal(numerical_H2, actual_H2, 1e-8)); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/gtsam_unstable/slam/tests/testPoseToPointFactor.h b/gtsam_unstable/slam/tests/testPoseToPointFactor.h deleted file mode 100644 index e0e5c4581..000000000 --- a/gtsam_unstable/slam/tests/testPoseToPointFactor.h +++ /dev/null @@ -1,86 +0,0 @@ -/** - * @file testPoseToPointFactor.cpp - * @brief - * @author David Wisth - * @date June 20, 2020 - */ - -#include -#include -#include - -using namespace gtsam; -using namespace gtsam::noiseModel; - -/// Verify zero error when there is no noise -TEST(PoseToPointFactor, errorNoiseless) { - Pose3 pose = Pose3::identity(); - Point3 point(1.0, 2.0, 3.0); - Point3 noise(0.0, 0.0, 0.0); - Point3 measured = t + noise; - - Key pose_key(1); - Key point_key(2); - PoseToPointFactor factor(pose_key, point_key, measured, - Isotropic::Sigma(3, 0.05)); - Vector expectedError = Vector3(0.0, 0.0, 0.0); - Vector actualError = factor.evaluateError(pose, point); - EXPECT(assert_equal(expectedError, actualError, 1E-5)); -} - -/// Verify expected error in test scenario -TEST(PoseToPointFactor, errorNoise) { - Pose3 pose = Pose3::identity(); - Point3 point(1.0, 2.0, 3.0); - Point3 noise(-1.0, 0.5, 0.3); - Point3 measured = t + noise; - - Key pose_key(1); - Key point_key(2); - PoseToPointFactor factor(pose_key, point_key, measured, - Isotropic::Sigma(3, 0.05)); - Vector expectedError = noise; - Vector actualError = factor.evaluateError(pose, point); - EXPECT(assert_equal(expectedError, actualError, 1E-5)); -} - -/// Check Jacobians are correct -TEST(PoseToPointFactor, jacobian) { - // Measurement - gtsam::Point3 l_meas = gtsam::Point3(1, 2, 3); - - // Linearisation point - gtsam::Point3 p_t = gtsam::Point3(-5, 12, 2); - gtsam::Rot3 p_R = gtsam::Rot3::RzRyRx(1.5 * M_PI, -0.3 * M_PI, 0.4 * M_PI); - Pose3 p(p_R, p_t); - - gtsam::Point3 l = gtsam::Point3(3, 0, 5); - - // Factor - Key pose_key(1); - Key point_key(2); - SharedGaussian noise = noiseModel::Diagonal::Sigmas(Vector3(0.1, 0.1, 0.1)); - PoseToPointFactor factor(pose_key, point_key, l_meas, noise); - - // Calculate numerical derivatives - auto f = std::bind(&PoseToPointFactor::evaluateError, factor, _1, _2, - boost::none, boost::none); - Matrix numerical_H1 = numericalDerivative21(f, p, l); - Matrix numerical_H2 = numericalDerivative22(f, p, l); - - // Use the factor to calculate the derivative - Matrix actual_H1; - Matrix actual_H2; - factor.evaluateError(p, l, actual_H1, actual_H2); - - // Verify we get the expected error - EXPECT_TRUE(assert_equal(numerical_H1, actual_H1, 1e-8)); - EXPECT_TRUE(assert_equal(numerical_H2, actual_H2, 1e-8)); -} - -/* ************************************************************************* */ -int main() { - TestResult tr; - return TestRegistry::runAllTests(tr); -} -/* ************************************************************************* */ diff --git a/gtsam_unstable/slam/tests/testSerialization.cpp b/gtsam_unstable/slam/tests/testSerialization.cpp index 792fd1133..362cf3778 100644 --- a/gtsam_unstable/slam/tests/testSerialization.cpp +++ b/gtsam_unstable/slam/tests/testSerialization.cpp @@ -7,10 +7,7 @@ * @author Alex Cunningham */ -#include -#include - -#include +#include #include #include @@ -18,12 +15,16 @@ #include #include -#include -#include -#include +#include + #include #include +#include +#include +#include +#include + using namespace std; using namespace gtsam; using namespace boost::assign; diff --git a/gtsam_unstable/slam/tests/testSmartProjectionPoseFactorRollingShutter.cpp b/gtsam_unstable/slam/tests/testSmartProjectionPoseFactorRollingShutter.cpp index 0b94d2c3f..b5962d777 100644 --- a/gtsam_unstable/slam/tests/testSmartProjectionPoseFactorRollingShutter.cpp +++ b/gtsam_unstable/slam/tests/testSmartProjectionPoseFactorRollingShutter.cpp @@ -61,10 +61,13 @@ static double interp_factor1 = 0.3; static double interp_factor2 = 0.4; static double interp_factor3 = 0.5; +static size_t cameraId1 = 0; + /* ************************************************************************* */ // default Cal3_S2 poses with rolling shutter effect namespace vanillaPoseRS { typedef PinholePose Camera; +typedef CameraSet Cameras; static Cal3_S2::shared_ptr sharedK(new Cal3_S2(fov, w, h)); Pose3 interp_pose1 = interpolate(level_pose, pose_right, interp_factor1); Pose3 interp_pose2 = interpolate(pose_right, pose_above, interp_factor2); @@ -72,6 +75,9 @@ Pose3 interp_pose3 = interpolate(pose_above, level_pose, interp_factor3); Camera cam1(interp_pose1, sharedK); Camera cam2(interp_pose2, sharedK); Camera cam3(interp_pose3, sharedK); +SmartProjectionParams params( + gtsam::HESSIAN, + gtsam::ZERO_ON_DEGENERACY); // only config that works with RS factors } // namespace vanillaPoseRS LevenbergMarquardtParams lmParams; @@ -80,26 +86,35 @@ typedef SmartProjectionPoseFactorRollingShutter> /* ************************************************************************* */ TEST(SmartProjectionPoseFactorRollingShutter, Constructor) { - SmartFactorRS::shared_ptr factor1(new SmartFactorRS(model)); + using namespace vanillaPoseRS; + boost::shared_ptr cameraRig(new Cameras()); + cameraRig->push_back(Camera(Pose3::identity(), sharedK)); + SmartFactorRS::shared_ptr factor1( + new SmartFactorRS(model, cameraRig, params)); } /* ************************************************************************* */ TEST(SmartProjectionPoseFactorRollingShutter, Constructor2) { - SmartProjectionParams params; + using namespace vanillaPoseRS; + boost::shared_ptr cameraRig(new Cameras()); + cameraRig->push_back(Camera(Pose3::identity(), sharedK)); params.setRankTolerance(rankTol); - SmartFactorRS factor1(model, params); + SmartFactorRS factor1(model, cameraRig, params); } /* ************************************************************************* */ TEST(SmartProjectionPoseFactorRollingShutter, add) { - using namespace vanillaPose; - SmartFactorRS::shared_ptr factor1(new SmartFactorRS(model)); - factor1->add(measurement1, x1, x2, interp_factor, sharedK, body_P_sensor); + using namespace vanillaPoseRS; + boost::shared_ptr cameraRig(new Cameras()); + cameraRig->push_back(Camera(Pose3::identity(), sharedK)); + SmartFactorRS::shared_ptr factor1( + new SmartFactorRS(model, cameraRig, params)); + factor1->add(measurement1, x1, x2, interp_factor); } /* ************************************************************************* */ TEST(SmartProjectionPoseFactorRollingShutter, Equals) { - using namespace vanillaPose; + using namespace vanillaPoseRS; // create fake measurements Point2Vector measurements; @@ -112,68 +127,88 @@ TEST(SmartProjectionPoseFactorRollingShutter, Equals) { key_pairs.push_back(std::make_pair(x2, x3)); key_pairs.push_back(std::make_pair(x3, x4)); - std::vector> intrinsicCalibrations; - intrinsicCalibrations.push_back(sharedK); - intrinsicCalibrations.push_back(sharedK); - intrinsicCalibrations.push_back(sharedK); - - std::vector extrinsicCalibrations; - extrinsicCalibrations.push_back(body_P_sensor); - extrinsicCalibrations.push_back(body_P_sensor); - extrinsicCalibrations.push_back(body_P_sensor); - std::vector interp_factors; interp_factors.push_back(interp_factor1); interp_factors.push_back(interp_factor2); interp_factors.push_back(interp_factor3); + FastVector cameraIds{0, 0, 0}; + + boost::shared_ptr cameraRig(new Cameras()); + cameraRig->push_back(Camera(body_P_sensor, sharedK)); + // create by adding a batch of measurements with a bunch of calibrations - SmartFactorRS::shared_ptr factor2(new SmartFactorRS(model)); - factor2->add(measurements, key_pairs, interp_factors, intrinsicCalibrations, - extrinsicCalibrations); + SmartFactorRS::shared_ptr factor2( + new SmartFactorRS(model, cameraRig, params)); + factor2->add(measurements, key_pairs, interp_factors, cameraIds); // create by adding a batch of measurements with a single calibrations - SmartFactorRS::shared_ptr factor3(new SmartFactorRS(model)); - factor3->add(measurements, key_pairs, interp_factors, sharedK, body_P_sensor); + SmartFactorRS::shared_ptr factor3( + new SmartFactorRS(model, cameraRig, params)); + factor3->add(measurements, key_pairs, interp_factors, cameraIds); { // create equal factors and show equal returns true - SmartFactorRS::shared_ptr factor1(new SmartFactorRS(model)); - factor1->add(measurement1, x1, x2, interp_factor1, sharedK, body_P_sensor); - factor1->add(measurement2, x2, x3, interp_factor2, sharedK, body_P_sensor); - factor1->add(measurement3, x3, x4, interp_factor3, sharedK, body_P_sensor); + SmartFactorRS::shared_ptr factor1( + new SmartFactorRS(model, cameraRig, params)); + factor1->add(measurement1, x1, x2, interp_factor1, cameraId1); + factor1->add(measurement2, x2, x3, interp_factor2, cameraId1); + factor1->add(measurement3, x3, x4, interp_factor3, cameraId1); + + EXPECT(factor1->equals(*factor2)); + EXPECT(factor1->equals(*factor3)); + } + { // create equal factors and show equal returns true (use default cameraId) + SmartFactorRS::shared_ptr factor1( + new SmartFactorRS(model, cameraRig, params)); + factor1->add(measurement1, x1, x2, interp_factor1); + factor1->add(measurement2, x2, x3, interp_factor2); + factor1->add(measurement3, x3, x4, interp_factor3); + + EXPECT(factor1->equals(*factor2)); + EXPECT(factor1->equals(*factor3)); + } + { // create equal factors and show equal returns true (use default cameraId) + SmartFactorRS::shared_ptr factor1( + new SmartFactorRS(model, cameraRig, params)); + factor1->add(measurements, key_pairs, interp_factors); EXPECT(factor1->equals(*factor2)); EXPECT(factor1->equals(*factor3)); } { // create slightly different factors (different keys) and show equal - // returns false - SmartFactorRS::shared_ptr factor1(new SmartFactorRS(model)); - factor1->add(measurement1, x1, x2, interp_factor1, sharedK, body_P_sensor); - factor1->add(measurement2, x2, x2, interp_factor2, sharedK, - body_P_sensor); // different! - factor1->add(measurement3, x3, x4, interp_factor3, sharedK, body_P_sensor); + // returns false (use default cameraIds) + SmartFactorRS::shared_ptr factor1( + new SmartFactorRS(model, cameraRig, params)); + factor1->add(measurement1, x1, x2, interp_factor1, cameraId1); + factor1->add(measurement2, x2, x2, interp_factor2, + cameraId1); // different! + factor1->add(measurement3, x3, x4, interp_factor3, cameraId1); EXPECT(!factor1->equals(*factor2)); EXPECT(!factor1->equals(*factor3)); } { // create slightly different factors (different extrinsics) and show equal // returns false - SmartFactorRS::shared_ptr factor1(new SmartFactorRS(model)); - factor1->add(measurement1, x1, x2, interp_factor1, sharedK, body_P_sensor); - factor1->add(measurement2, x2, x3, interp_factor2, sharedK, - body_P_sensor * body_P_sensor); // different! - factor1->add(measurement3, x3, x4, interp_factor3, sharedK, body_P_sensor); + boost::shared_ptr cameraRig2(new Cameras()); + cameraRig2->push_back(Camera(body_P_sensor * body_P_sensor, sharedK)); + SmartFactorRS::shared_ptr factor1( + new SmartFactorRS(model, cameraRig2, params)); + factor1->add(measurement1, x1, x2, interp_factor1, cameraId1); + factor1->add(measurement2, x2, x3, interp_factor2, + cameraId1); // different! + factor1->add(measurement3, x3, x4, interp_factor3, cameraId1); EXPECT(!factor1->equals(*factor2)); EXPECT(!factor1->equals(*factor3)); } { // create slightly different factors (different interp factors) and show // equal returns false - SmartFactorRS::shared_ptr factor1(new SmartFactorRS(model)); - factor1->add(measurement1, x1, x2, interp_factor1, sharedK, body_P_sensor); - factor1->add(measurement2, x2, x3, interp_factor1, sharedK, - body_P_sensor); // different! - factor1->add(measurement3, x3, x4, interp_factor3, sharedK, body_P_sensor); + SmartFactorRS::shared_ptr factor1( + new SmartFactorRS(model, cameraRig, params)); + factor1->add(measurement1, x1, x2, interp_factor1, cameraId1); + factor1->add(measurement2, x2, x3, interp_factor1, + cameraId1); // different! + factor1->add(measurement3, x3, x4, interp_factor3, cameraId1); EXPECT(!factor1->equals(*factor2)); EXPECT(!factor1->equals(*factor3)); @@ -197,9 +232,12 @@ TEST(SmartProjectionPoseFactorRollingShutter, noiselessErrorAndJacobians) { Point2 level_uv_right = cam2.project(landmark1); Pose3 body_P_sensorId = Pose3::identity(); - SmartFactorRS factor(model); - factor.add(level_uv, x1, x2, interp_factor1, sharedK, body_P_sensorId); - factor.add(level_uv_right, x2, x3, interp_factor2, sharedK, body_P_sensorId); + boost::shared_ptr cameraRig(new Cameras()); + cameraRig->push_back(Camera(body_P_sensorId, sharedK)); + + SmartFactorRS factor(model, cameraRig, params); + factor.add(level_uv, x1, x2, interp_factor1); + factor.add(level_uv_right, x2, x3, interp_factor2); Values values; // it's a pose factor, hence these are poses values.insert(x1, level_pose); @@ -272,10 +310,12 @@ TEST(SmartProjectionPoseFactorRollingShutter, noisyErrorAndJacobians) { Point2 level_uv_right = cam2.project(landmark1); Pose3 body_P_sensorNonId = body_P_sensor; - SmartFactorRS factor(model); - factor.add(level_uv, x1, x2, interp_factor1, sharedK, body_P_sensorNonId); - factor.add(level_uv_right, x2, x3, interp_factor2, sharedK, - body_P_sensorNonId); + boost::shared_ptr cameraRig(new Cameras()); + cameraRig->push_back(Camera(body_P_sensorNonId, sharedK)); + + SmartFactorRS factor(model, cameraRig, params); + factor.add(level_uv, x1, x2, interp_factor1); + factor.add(level_uv_right, x2, x3, interp_factor2); Values values; // it's a pose factor, hence these are poses values.insert(x1, level_pose); @@ -364,14 +404,20 @@ TEST(SmartProjectionPoseFactorRollingShutter, optimization_3poses) { interp_factors.push_back(interp_factor2); interp_factors.push_back(interp_factor3); - SmartFactorRS::shared_ptr smartFactor1(new SmartFactorRS(model)); - smartFactor1->add(measurements_lmk1, key_pairs, interp_factors, sharedK); + boost::shared_ptr cameraRig(new Cameras()); + cameraRig->push_back(Camera(Pose3::identity(), sharedK)); - SmartFactorRS::shared_ptr smartFactor2(new SmartFactorRS(model)); - smartFactor2->add(measurements_lmk2, key_pairs, interp_factors, sharedK); + SmartFactorRS::shared_ptr smartFactor1( + new SmartFactorRS(model, cameraRig, params)); + smartFactor1->add(measurements_lmk1, key_pairs, interp_factors); - SmartFactorRS::shared_ptr smartFactor3(new SmartFactorRS(model)); - smartFactor3->add(measurements_lmk3, key_pairs, interp_factors, sharedK); + SmartFactorRS::shared_ptr smartFactor2( + new SmartFactorRS(model, cameraRig, params)); + smartFactor2->add(measurements_lmk2, key_pairs, interp_factors); + + SmartFactorRS::shared_ptr smartFactor3( + new SmartFactorRS(model, cameraRig, params)); + smartFactor3->add(measurements_lmk3, key_pairs, interp_factors); const SharedDiagonal noisePrior = noiseModel::Isotropic::Sigma(6, 0.10); @@ -411,6 +457,170 @@ TEST(SmartProjectionPoseFactorRollingShutter, optimization_3poses) { EXPECT(assert_equal(pose_above, result.at(x3), 1e-6)); } +/* *************************************************************************/ +TEST(SmartProjectionPoseFactorRollingShutter, optimization_3poses_multiCam) { + using namespace vanillaPoseRS; + Point2Vector measurements_lmk1, measurements_lmk2, measurements_lmk3; + + // Project three landmarks into three cameras + projectToMultipleCameras(cam1, cam2, cam3, landmark1, measurements_lmk1); + projectToMultipleCameras(cam1, cam2, cam3, landmark2, measurements_lmk2); + projectToMultipleCameras(cam1, cam2, cam3, landmark3, measurements_lmk3); + + // create inputs + std::vector> key_pairs; + key_pairs.push_back(std::make_pair(x1, x2)); + key_pairs.push_back(std::make_pair(x2, x3)); + key_pairs.push_back(std::make_pair(x3, x1)); + + std::vector interp_factors; + interp_factors.push_back(interp_factor1); + interp_factors.push_back(interp_factor2); + interp_factors.push_back(interp_factor3); + + boost::shared_ptr cameraRig(new Cameras()); + cameraRig->push_back(Camera(body_P_sensor, sharedK)); + cameraRig->push_back(Camera(Pose3::identity(), sharedK)); + + SmartFactorRS::shared_ptr smartFactor1( + new SmartFactorRS(model, cameraRig, params)); + smartFactor1->add(measurements_lmk1, key_pairs, interp_factors, {1, 1, 1}); + + SmartFactorRS::shared_ptr smartFactor2( + new SmartFactorRS(model, cameraRig, params)); + smartFactor2->add(measurements_lmk2, key_pairs, interp_factors, {1, 1, 1}); + + SmartFactorRS::shared_ptr smartFactor3( + new SmartFactorRS(model, cameraRig, params)); + smartFactor3->add(measurements_lmk3, key_pairs, interp_factors, {1, 1, 1}); + + const SharedDiagonal noisePrior = noiseModel::Isotropic::Sigma(6, 0.10); + + NonlinearFactorGraph graph; + graph.push_back(smartFactor1); + graph.push_back(smartFactor2); + graph.push_back(smartFactor3); + graph.addPrior(x1, level_pose, noisePrior); + graph.addPrior(x2, pose_right, noisePrior); + + Values groundTruth; + groundTruth.insert(x1, level_pose); + groundTruth.insert(x2, pose_right); + groundTruth.insert(x3, pose_above); // pose above is the pose of the camera + DOUBLES_EQUAL(0, graph.error(groundTruth), 1e-9); + + // Pose3 noise_pose = Pose3(Rot3::Ypr(-M_PI/10, 0., -M_PI/10), + // Point3(0.5,0.1,0.3)); // noise from regular projection factor test below + Pose3 noise_pose = Pose3(Rot3::Ypr(-M_PI / 100, 0., -M_PI / 100), + Point3(0.1, 0.1, 0.1)); // smaller noise + Values values; + values.insert(x1, level_pose); + values.insert(x2, pose_right); + // initialize third pose with some noise, we expect it to move back to + // original pose_above + values.insert(x3, pose_above * noise_pose); + EXPECT( // check that the pose is actually noisy + assert_equal(Pose3(Rot3(0, -0.0314107591, 0.99950656, -0.99950656, + -0.0313952598, -0.000986635786, 0.0314107591, + -0.999013364, -0.0313952598), + Point3(0.1, -0.1, 1.9)), + values.at(x3))); + + Values result; + LevenbergMarquardtOptimizer optimizer(graph, values, lmParams); + result = optimizer.optimize(); + EXPECT(assert_equal(pose_above, result.at(x3), 1e-6)); +} + +/* *************************************************************************/ +TEST(SmartProjectionPoseFactorRollingShutter, optimization_3poses_multiCam2) { + using namespace vanillaPoseRS; + + Point2Vector measurements_lmk1, measurements_lmk2, measurements_lmk3; + + // create arbitrary body_T_sensor (transforms from sensor to body) + Pose3 body_T_sensor1 = Pose3(Rot3::Ypr(-0.03, 0., 0.01), Point3(1, 1, 1)); + Pose3 body_T_sensor2 = Pose3(Rot3::Ypr(-0.1, 0., 0.05), Point3(0, 0, 1)); + Pose3 body_T_sensor3 = Pose3(Rot3::Ypr(-0.3, 0., -0.05), Point3(0, 1, 1)); + + Camera camera1(interp_pose1 * body_T_sensor1, sharedK); + Camera camera2(interp_pose2 * body_T_sensor2, sharedK); + Camera camera3(interp_pose3 * body_T_sensor3, sharedK); + + // Project three landmarks into three cameras + projectToMultipleCameras(camera1, camera2, camera3, landmark1, + measurements_lmk1); + projectToMultipleCameras(camera1, camera2, camera3, landmark2, + measurements_lmk2); + projectToMultipleCameras(camera1, camera2, camera3, landmark3, + measurements_lmk3); + + // create inputs + std::vector> key_pairs; + key_pairs.push_back(std::make_pair(x1, x2)); + key_pairs.push_back(std::make_pair(x2, x3)); + key_pairs.push_back(std::make_pair(x3, x1)); + + std::vector interp_factors; + interp_factors.push_back(interp_factor1); + interp_factors.push_back(interp_factor2); + interp_factors.push_back(interp_factor3); + + boost::shared_ptr cameraRig(new Cameras()); + cameraRig->push_back(Camera(body_T_sensor1, sharedK)); + cameraRig->push_back(Camera(body_T_sensor2, sharedK)); + cameraRig->push_back(Camera(body_T_sensor3, sharedK)); + + SmartFactorRS::shared_ptr smartFactor1( + new SmartFactorRS(model, cameraRig, params)); + smartFactor1->add(measurements_lmk1, key_pairs, interp_factors, {0, 1, 2}); + + SmartFactorRS::shared_ptr smartFactor2( + new SmartFactorRS(model, cameraRig, params)); + smartFactor2->add(measurements_lmk2, key_pairs, interp_factors, {0, 1, 2}); + + SmartFactorRS::shared_ptr smartFactor3( + new SmartFactorRS(model, cameraRig, params)); + smartFactor3->add(measurements_lmk3, key_pairs, interp_factors, {0, 1, 2}); + + const SharedDiagonal noisePrior = noiseModel::Isotropic::Sigma(6, 0.10); + + NonlinearFactorGraph graph; + graph.push_back(smartFactor1); + graph.push_back(smartFactor2); + graph.push_back(smartFactor3); + graph.addPrior(x1, level_pose, noisePrior); + graph.addPrior(x2, pose_right, noisePrior); + + Values groundTruth; + groundTruth.insert(x1, level_pose); + groundTruth.insert(x2, pose_right); + groundTruth.insert(x3, pose_above); // pose above is the pose of the camera + DOUBLES_EQUAL(0, graph.error(groundTruth), 1e-9); + + // Pose3 noise_pose = Pose3(Rot3::Ypr(-M_PI/10, 0., -M_PI/10), + // Point3(0.5,0.1,0.3)); // noise from regular projection factor test below + Pose3 noise_pose = Pose3(Rot3::Ypr(-M_PI / 100, 0., -M_PI / 100), + Point3(0.1, 0.1, 0.1)); // smaller noise + Values values; + values.insert(x1, level_pose); + values.insert(x2, pose_right); + // initialize third pose with some noise, we expect it to move back to + // original pose_above + values.insert(x3, pose_above * noise_pose); + EXPECT( // check that the pose is actually noisy + assert_equal(Pose3(Rot3(0, -0.0314107591, 0.99950656, -0.99950656, + -0.0313952598, -0.000986635786, 0.0314107591, + -0.999013364, -0.0313952598), + Point3(0.1, -0.1, 1.9)), + values.at(x3))); + + Values result; + LevenbergMarquardtOptimizer optimizer(graph, values, lmParams); + result = optimizer.optimize(); + EXPECT(assert_equal(pose_above, result.at(x3), 1e-4)); +} + /* *************************************************************************/ TEST(SmartProjectionPoseFactorRollingShutter, hessian_simple_2poses) { // here we replicate a test in SmartProjectionPoseFactor by setting @@ -418,7 +628,7 @@ TEST(SmartProjectionPoseFactorRollingShutter, hessian_simple_2poses) { // falls back to standard pixel measurements) Note: this is a quite extreme // test since in typical camera you would not have more than 1 measurement per // landmark at each interpolated pose - using namespace vanillaPose; + using namespace vanillaPoseRS; // Default cameras for simple derivatives static Cal3_S2::shared_ptr sharedKSimple(new Cal3_S2(100, 100, 0, 0, 0)); @@ -438,15 +648,17 @@ TEST(SmartProjectionPoseFactorRollingShutter, hessian_simple_2poses) { measurements_lmk1.push_back(cam1.project(landmark1)); measurements_lmk1.push_back(cam2.project(landmark1)); - SmartFactorRS::shared_ptr smartFactor1(new SmartFactorRS(model)); - double interp_factor = 0; // equivalent to measurement taken at pose 1 - smartFactor1->add(measurements_lmk1[0], x1, x2, interp_factor, sharedKSimple, - body_P_sensorId); - interp_factor = 1; // equivalent to measurement taken at pose 2 - smartFactor1->add(measurements_lmk1[1], x1, x2, interp_factor, sharedKSimple, - body_P_sensorId); + boost::shared_ptr cameraRig(new Cameras()); + cameraRig->push_back(Camera(body_P_sensorId, sharedKSimple)); - SmartFactor::Cameras cameras; + SmartFactorRS::shared_ptr smartFactor1( + new SmartFactorRS(model, cameraRig, params)); + double interp_factor = 0; // equivalent to measurement taken at pose 1 + smartFactor1->add(measurements_lmk1[0], x1, x2, interp_factor); + interp_factor = 1; // equivalent to measurement taken at pose 2 + smartFactor1->add(measurements_lmk1[1], x1, x2, interp_factor); + + SmartFactorRS::Cameras cameras; cameras.push_back(cam1); cameras.push_back(cam2); @@ -534,14 +746,17 @@ TEST(SmartProjectionPoseFactorRollingShutter, optimization_3poses_EPI) { params.setLandmarkDistanceThreshold(excludeLandmarksFutherThanDist); params.setEnableEPI(true); - SmartFactorRS smartFactor1(model, params); - smartFactor1.add(measurements_lmk1, key_pairs, interp_factors, sharedK); + boost::shared_ptr cameraRig(new Cameras()); + cameraRig->push_back(Camera(Pose3::identity(), sharedK)); - SmartFactorRS smartFactor2(model, params); - smartFactor2.add(measurements_lmk2, key_pairs, interp_factors, sharedK); + SmartFactorRS smartFactor1(model, cameraRig, params); + smartFactor1.add(measurements_lmk1, key_pairs, interp_factors); - SmartFactorRS smartFactor3(model, params); - smartFactor3.add(measurements_lmk3, key_pairs, interp_factors, sharedK); + SmartFactorRS smartFactor2(model, cameraRig, params); + smartFactor2.add(measurements_lmk2, key_pairs, interp_factors); + + SmartFactorRS smartFactor3(model, cameraRig, params); + smartFactor3.add(measurements_lmk3, key_pairs, interp_factors); const SharedDiagonal noisePrior = noiseModel::Isotropic::Sigma(6, 0.10); @@ -594,18 +809,23 @@ TEST(SmartProjectionPoseFactorRollingShutter, SmartProjectionParams params; params.setRankTolerance(1.0); params.setLinearizationMode(gtsam::HESSIAN); - params.setDegeneracyMode(gtsam::IGNORE_DEGENERACY); + // params.setDegeneracyMode(gtsam::IGNORE_DEGENERACY); // this would give an + // exception as expected + params.setDegeneracyMode(gtsam::ZERO_ON_DEGENERACY); params.setLandmarkDistanceThreshold(excludeLandmarksFutherThanDist); params.setEnableEPI(false); - SmartFactorRS smartFactor1(model, params); - smartFactor1.add(measurements_lmk1, key_pairs, interp_factors, sharedK); + boost::shared_ptr cameraRig(new Cameras()); + cameraRig->push_back(Camera(Pose3::identity(), sharedK)); - SmartFactorRS smartFactor2(model, params); - smartFactor2.add(measurements_lmk2, key_pairs, interp_factors, sharedK); + SmartFactorRS smartFactor1(model, cameraRig, params); + smartFactor1.add(measurements_lmk1, key_pairs, interp_factors); - SmartFactorRS smartFactor3(model, params); - smartFactor3.add(measurements_lmk3, key_pairs, interp_factors, sharedK); + SmartFactorRS smartFactor2(model, cameraRig, params); + smartFactor2.add(measurements_lmk2, key_pairs, interp_factors); + + SmartFactorRS smartFactor3(model, cameraRig, params); + smartFactor3.add(measurements_lmk3, key_pairs, interp_factors); const SharedDiagonal noisePrior = noiseModel::Isotropic::Sigma(6, 0.10); @@ -673,17 +893,24 @@ TEST(SmartProjectionPoseFactorRollingShutter, params.setDynamicOutlierRejectionThreshold(dynamicOutlierRejectionThreshold); params.setEnableEPI(false); - SmartFactorRS::shared_ptr smartFactor1(new SmartFactorRS(model, params)); - smartFactor1->add(measurements_lmk1, key_pairs, interp_factors, sharedK); + boost::shared_ptr cameraRig(new Cameras()); + cameraRig->push_back(Camera(Pose3::identity(), sharedK)); - SmartFactorRS::shared_ptr smartFactor2(new SmartFactorRS(model, params)); - smartFactor2->add(measurements_lmk2, key_pairs, interp_factors, sharedK); + SmartFactorRS::shared_ptr smartFactor1( + new SmartFactorRS(model, cameraRig, params)); + smartFactor1->add(measurements_lmk1, key_pairs, interp_factors); - SmartFactorRS::shared_ptr smartFactor3(new SmartFactorRS(model, params)); - smartFactor3->add(measurements_lmk3, key_pairs, interp_factors, sharedK); + SmartFactorRS::shared_ptr smartFactor2( + new SmartFactorRS(model, cameraRig, params)); + smartFactor2->add(measurements_lmk2, key_pairs, interp_factors); - SmartFactorRS::shared_ptr smartFactor4(new SmartFactorRS(model, params)); - smartFactor4->add(measurements_lmk4, key_pairs, interp_factors, sharedK); + SmartFactorRS::shared_ptr smartFactor3( + new SmartFactorRS(model, cameraRig, params)); + smartFactor3->add(measurements_lmk3, key_pairs, interp_factors); + + SmartFactorRS::shared_ptr smartFactor4( + new SmartFactorRS(model, cameraRig, params)); + smartFactor4->add(measurements_lmk4, key_pairs, interp_factors); const SharedDiagonal noisePrior = noiseModel::Isotropic::Sigma(6, 0.10); @@ -733,8 +960,12 @@ TEST(SmartProjectionPoseFactorRollingShutter, interp_factors.push_back(interp_factor2); interp_factors.push_back(interp_factor3); - SmartFactorRS::shared_ptr smartFactor1(new SmartFactorRS(model)); - smartFactor1->add(measurements_lmk1, key_pairs, interp_factors, sharedK); + boost::shared_ptr cameraRig(new Cameras()); + cameraRig->push_back(Camera(Pose3::identity(), sharedK)); + + SmartFactorRS::shared_ptr smartFactor1( + new SmartFactorRS(model, cameraRig, params)); + smartFactor1->add(measurements_lmk1, key_pairs, interp_factors); Pose3 noise_pose = Pose3(Rot3::Ypr(-M_PI / 100, 0., -M_PI / 100), Point3(0.1, 0.1, 0.1)); // smaller noise @@ -870,9 +1101,12 @@ TEST(SmartProjectionPoseFactorRollingShutter, interp_factors.push_back(interp_factor3); interp_factors.push_back(interp_factor1); - SmartFactorRS::shared_ptr smartFactor1(new SmartFactorRS(model)); - smartFactor1->add(measurements_lmk1_redundant, key_pairs, interp_factors, - sharedK); + boost::shared_ptr cameraRig(new Cameras()); + cameraRig->push_back(Camera(Pose3::identity(), sharedK)); + + SmartFactorRS::shared_ptr smartFactor1( + new SmartFactorRS(model, cameraRig, params)); + smartFactor1->add(measurements_lmk1_redundant, key_pairs, interp_factors); Pose3 noise_pose = Pose3(Rot3::Ypr(-M_PI / 100, 0., -M_PI / 100), Point3(0.1, 0.1, 0.1)); // smaller noise @@ -1026,15 +1260,21 @@ TEST(SmartProjectionPoseFactorRollingShutter, interp_factors_redundant.push_back( interp_factors.at(0)); // we readd the first interp factor - SmartFactorRS::shared_ptr smartFactor1(new SmartFactorRS(model)); + boost::shared_ptr cameraRig(new Cameras()); + cameraRig->push_back(Camera(Pose3::identity(), sharedK)); + + SmartFactorRS::shared_ptr smartFactor1( + new SmartFactorRS(model, cameraRig, params)); smartFactor1->add(measurements_lmk1_redundant, key_pairs_redundant, - interp_factors_redundant, sharedK); + interp_factors_redundant); - SmartFactorRS::shared_ptr smartFactor2(new SmartFactorRS(model)); - smartFactor2->add(measurements_lmk2, key_pairs, interp_factors, sharedK); + SmartFactorRS::shared_ptr smartFactor2( + new SmartFactorRS(model, cameraRig, params)); + smartFactor2->add(measurements_lmk2, key_pairs, interp_factors); - SmartFactorRS::shared_ptr smartFactor3(new SmartFactorRS(model)); - smartFactor3->add(measurements_lmk3, key_pairs, interp_factors, sharedK); + SmartFactorRS::shared_ptr smartFactor3( + new SmartFactorRS(model, cameraRig, params)); + smartFactor3->add(measurements_lmk3, key_pairs, interp_factors); const SharedDiagonal noisePrior = noiseModel::Isotropic::Sigma(6, 0.10); @@ -1076,16 +1316,20 @@ TEST(SmartProjectionPoseFactorRollingShutter, #ifndef DISABLE_TIMING #include -// -Total: 0 CPU (0 times, 0 wall, 0.04 children, min: 0 max: 0) -//| -SF RS LINEARIZE: 0.02 CPU (1000 times, 0.017244 wall, 0.02 children, min: -// 0 max: 0) | -RS LINEARIZE: 0.02 CPU (1000 times, 0.009035 wall, 0.02 -// children, min: 0 max: 0) +//-Total: 0 CPU (0 times, 0 wall, 0.21 children, min: 0 max: 0) +//| -SF RS LINEARIZE: 0.14 CPU +//(10000 times, 0.131202 wall, 0.14 children, min: 0 max: 0) +//| -RS LINEARIZE: 0.06 CPU +//(10000 times, 0.066951 wall, 0.06 children, min: 0 max: 0) /* *************************************************************************/ TEST(SmartProjectionPoseFactorRollingShutter, timing) { using namespace vanillaPose; // Default cameras for simple derivatives static Cal3_S2::shared_ptr sharedKSimple(new Cal3_S2(100, 100, 0, 0, 0)); + SmartProjectionParams params( + gtsam::HESSIAN, + gtsam::ZERO_ON_DEGENERACY); // only config that works with RS factors Rot3 R = Rot3::identity(); Pose3 pose1 = Pose3(R, Point3(0, 0, 0)); @@ -1102,16 +1346,18 @@ TEST(SmartProjectionPoseFactorRollingShutter, timing) { measurements_lmk1.push_back(cam1.project(landmark1)); measurements_lmk1.push_back(cam2.project(landmark1)); - size_t nrTests = 1000; + size_t nrTests = 10000; for (size_t i = 0; i < nrTests; i++) { - SmartFactorRS::shared_ptr smartFactorRS(new SmartFactorRS(model)); + boost::shared_ptr cameraRig(new Cameras()); + cameraRig->push_back(Camera(body_P_sensorId, sharedKSimple)); + + SmartFactorRS::shared_ptr smartFactorRS(new SmartFactorRS( + model, cameraRig, params)); double interp_factor = 0; // equivalent to measurement taken at pose 1 - smartFactorRS->add(measurements_lmk1[0], x1, x2, interp_factor, - sharedKSimple, body_P_sensorId); + smartFactorRS->add(measurements_lmk1[0], x1, x2, interp_factor); interp_factor = 1; // equivalent to measurement taken at pose 2 - smartFactorRS->add(measurements_lmk1[1], x1, x2, interp_factor, - sharedKSimple, body_P_sensorId); + smartFactorRS->add(measurements_lmk1[1], x1, x2, interp_factor); Values values; values.insert(x1, pose1); @@ -1122,7 +1368,8 @@ TEST(SmartProjectionPoseFactorRollingShutter, timing) { } for (size_t i = 0; i < nrTests; i++) { - SmartFactor::shared_ptr smartFactor(new SmartFactor(model, sharedKSimple)); + SmartFactor::shared_ptr smartFactor( + new SmartFactor(model, sharedKSimple, params)); smartFactor->add(measurements_lmk1[0], x1); smartFactor->add(measurements_lmk1[1], x2); @@ -1137,6 +1384,105 @@ TEST(SmartProjectionPoseFactorRollingShutter, timing) { } #endif +#include +/* ************************************************************************* */ +// spherical Camera with rolling shutter effect +namespace sphericalCameraRS { +typedef SphericalCamera Camera; +typedef CameraSet Cameras; +typedef SmartProjectionPoseFactorRollingShutter SmartFactorRS_spherical; +Pose3 interp_pose1 = interpolate(level_pose, pose_right, interp_factor1); +Pose3 interp_pose2 = interpolate(pose_right, pose_above, interp_factor2); +Pose3 interp_pose3 = interpolate(pose_above, level_pose, interp_factor3); +static EmptyCal::shared_ptr emptyK(new EmptyCal()); +Camera cam1(interp_pose1, emptyK); +Camera cam2(interp_pose2, emptyK); +Camera cam3(interp_pose3, emptyK); +} // namespace sphericalCameraRS + +/* *************************************************************************/ +TEST(SmartProjectionPoseFactorRollingShutter, + optimization_3poses_sphericalCameras) { + using namespace sphericalCameraRS; + std::vector measurements_lmk1, measurements_lmk2, measurements_lmk3; + + // Project three landmarks into three cameras + projectToMultipleCameras(cam1, cam2, cam3, landmark1, + measurements_lmk1); + projectToMultipleCameras(cam1, cam2, cam3, landmark2, + measurements_lmk2); + projectToMultipleCameras(cam1, cam2, cam3, landmark3, + measurements_lmk3); + + // create inputs + std::vector> key_pairs; + key_pairs.push_back(std::make_pair(x1, x2)); + key_pairs.push_back(std::make_pair(x2, x3)); + key_pairs.push_back(std::make_pair(x3, x1)); + + std::vector interp_factors; + interp_factors.push_back(interp_factor1); + interp_factors.push_back(interp_factor2); + interp_factors.push_back(interp_factor3); + + SmartProjectionParams params( + gtsam::HESSIAN, + gtsam::ZERO_ON_DEGENERACY); // only config that works with RS factors + params.setRankTolerance(0.1); + + boost::shared_ptr cameraRig(new Cameras()); + cameraRig->push_back(Camera(Pose3::identity(), emptyK)); + + SmartFactorRS_spherical::shared_ptr smartFactor1( + new SmartFactorRS_spherical(model, cameraRig, params)); + smartFactor1->add(measurements_lmk1, key_pairs, interp_factors); + + SmartFactorRS_spherical::shared_ptr smartFactor2( + new SmartFactorRS_spherical(model, cameraRig, params)); + smartFactor2->add(measurements_lmk2, key_pairs, interp_factors); + + SmartFactorRS_spherical::shared_ptr smartFactor3( + new SmartFactorRS_spherical(model, cameraRig, params)); + smartFactor3->add(measurements_lmk3, key_pairs, interp_factors); + + const SharedDiagonal noisePrior = noiseModel::Isotropic::Sigma(6, 0.10); + + NonlinearFactorGraph graph; + graph.push_back(smartFactor1); + graph.push_back(smartFactor2); + graph.push_back(smartFactor3); + graph.addPrior(x1, level_pose, noisePrior); + graph.addPrior(x2, pose_right, noisePrior); + + Values groundTruth; + groundTruth.insert(x1, level_pose); + groundTruth.insert(x2, pose_right); + groundTruth.insert(x3, pose_above); + DOUBLES_EQUAL(0, graph.error(groundTruth), 1e-9); + + // Pose3 noise_pose = Pose3(Rot3::Ypr(-M_PI/10, 0., -M_PI/10), + // Point3(0.5,0.1,0.3)); // noise from regular projection factor test below + Pose3 noise_pose = Pose3(Rot3::Ypr(-M_PI / 100, 0., -M_PI / 100), + Point3(0.1, 0.1, 0.1)); // smaller noise + Values values; + values.insert(x1, level_pose); + values.insert(x2, pose_right); + // initialize third pose with some noise, we expect it to move back to + // original pose_above + values.insert(x3, pose_above * noise_pose); + EXPECT( // check that the pose is actually noisy + assert_equal(Pose3(Rot3(0, -0.0314107591, 0.99950656, -0.99950656, + -0.0313952598, -0.000986635786, 0.0314107591, + -0.999013364, -0.0313952598), + Point3(0.1, -0.1, 1.9)), + values.at(x3))); + + Values result; + LevenbergMarquardtOptimizer optimizer(graph, values, lmParams); + result = optimizer.optimize(); + EXPECT(assert_equal(pose_above, result.at(x3), 1e-6)); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam_unstable/slam/tests/testSmartStereoProjectionFactorPP.cpp b/gtsam_unstable/slam/tests/testSmartStereoProjectionFactorPP.cpp index 61836d098..c71c19038 100644 --- a/gtsam_unstable/slam/tests/testSmartStereoProjectionFactorPP.cpp +++ b/gtsam_unstable/slam/tests/testSmartStereoProjectionFactorPP.cpp @@ -31,12 +31,13 @@ using namespace std; using namespace boost::assign; using namespace gtsam; +namespace { // make a realistic calibration matrix static double b = 1; static Cal3_S2Stereo::shared_ptr K(new Cal3_S2Stereo(fov, w, h, b)); -static Cal3_S2Stereo::shared_ptr K2( - new Cal3_S2Stereo(1500, 1200, 0, 640, 480, b)); +static Cal3_S2Stereo::shared_ptr K2(new Cal3_S2Stereo(1500, 1200, 0, 640, 480, + b)); static SmartStereoProjectionParams params; @@ -45,8 +46,8 @@ static SmartStereoProjectionParams params; static SharedNoiseModel model(noiseModel::Isotropic::Sigma(3, 0.1)); // Convenience for named keys -using symbol_shorthand::X; using symbol_shorthand::L; +using symbol_shorthand::X; // tests data static Symbol x1('X', 1); @@ -59,16 +60,19 @@ static Symbol body_P_cam3_key('P', 3); static Key poseKey1(x1); static Key poseExtrinsicKey1(body_P_cam1_key); static Key poseExtrinsicKey2(body_P_cam2_key); -static StereoPoint2 measurement1(323.0, 300.0, 240.0); //potentially use more reasonable measurement value? -static StereoPoint2 measurement2(350.0, 200.0, 240.0); //potentially use more reasonable measurement value? +static StereoPoint2 measurement1( + 323.0, 300.0, 240.0); // potentially use more reasonable measurement value? +static StereoPoint2 measurement2( + 350.0, 200.0, 240.0); // potentially use more reasonable measurement value? static Pose3 body_P_sensor1(Rot3::RzRyRx(-M_PI_2, 0.0, -M_PI_2), - Point3(0.25, -0.10, 1.0)); + Point3(0.25, -0.10, 1.0)); static double missing_uR = std::numeric_limits::quiet_NaN(); vector stereo_projectToMultipleCameras(const StereoCamera& cam1, - const StereoCamera& cam2, const StereoCamera& cam3, Point3 landmark) { - + const StereoCamera& cam2, + const StereoCamera& cam3, + Point3 landmark) { vector measurements_cam; StereoPoint2 cam1_uv1 = cam1.project(landmark); @@ -82,6 +86,7 @@ vector stereo_projectToMultipleCameras(const StereoCamera& cam1, } LevenbergMarquardtParams lm_params; +} // namespace /* ************************************************************************* */ TEST( SmartStereoProjectionFactorPP, params) { diff --git a/gtsam_unstable/slam/tests/testSmartStereoProjectionPoseFactor.cpp b/gtsam_unstable/slam/tests/testSmartStereoProjectionPoseFactor.cpp index a0bfc3649..fc56b1a9f 100644 --- a/gtsam_unstable/slam/tests/testSmartStereoProjectionPoseFactor.cpp +++ b/gtsam_unstable/slam/tests/testSmartStereoProjectionPoseFactor.cpp @@ -32,13 +32,13 @@ using namespace std; using namespace boost::assign; using namespace gtsam; +namespace { // make a realistic calibration matrix static double b = 1; static Cal3_S2Stereo::shared_ptr K(new Cal3_S2Stereo(fov, w, h, b)); -static Cal3_S2Stereo::shared_ptr K2( - new Cal3_S2Stereo(1500, 1200, 0, 640, 480, b)); - +static Cal3_S2Stereo::shared_ptr K2(new Cal3_S2Stereo(1500, 1200, 0, 640, 480, + b)); static SmartStereoProjectionParams params; @@ -47,8 +47,8 @@ static SmartStereoProjectionParams params; static SharedNoiseModel model(noiseModel::Isotropic::Sigma(3, 0.1)); // Convenience for named keys -using symbol_shorthand::X; using symbol_shorthand::L; +using symbol_shorthand::X; // tests data static Symbol x1('X', 1); @@ -56,15 +56,17 @@ static Symbol x2('X', 2); static Symbol x3('X', 3); static Key poseKey1(x1); -static StereoPoint2 measurement1(323.0, 300.0, 240.0); //potentially use more reasonable measurement value? +static StereoPoint2 measurement1( + 323.0, 300.0, 240.0); // potentially use more reasonable measurement value? static Pose3 body_P_sensor1(Rot3::RzRyRx(-M_PI_2, 0.0, -M_PI_2), - Point3(0.25, -0.10, 1.0)); + Point3(0.25, -0.10, 1.0)); static double missing_uR = std::numeric_limits::quiet_NaN(); vector stereo_projectToMultipleCameras(const StereoCamera& cam1, - const StereoCamera& cam2, const StereoCamera& cam3, Point3 landmark) { - + const StereoCamera& cam2, + const StereoCamera& cam3, + Point3 landmark) { vector measurements_cam; StereoPoint2 cam1_uv1 = cam1.project(landmark); @@ -78,6 +80,7 @@ vector stereo_projectToMultipleCameras(const StereoCamera& cam1, } LevenbergMarquardtParams lm_params; +} // namespace /* ************************************************************************* */ TEST( SmartStereoProjectionPoseFactor, params) { diff --git a/matlab/+gtsam/Contents.m b/matlab/+gtsam/Contents.m index fb6d3081e..77536e5c9 100644 --- a/matlab/+gtsam/Contents.m +++ b/matlab/+gtsam/Contents.m @@ -49,9 +49,6 @@ % Ordering - class Ordering, see Doxygen page for details % Value - class Value, see Doxygen page for details % Values - class Values, see Doxygen page for details -% LieScalar - class LieScalar, see Doxygen page for details -% LieVector - class LieVector, see Doxygen page for details -% LieMatrix - class LieMatrix, see Doxygen page for details % NonlinearFactor - class NonlinearFactor, see Doxygen page for details % NonlinearFactorGraph - class NonlinearFactorGraph, see Doxygen page for details % @@ -101,9 +98,6 @@ % BearingFactor2D - class BearingFactor2D, see Doxygen page for details % BearingFactor3D - class BearingFactor3D, see Doxygen page for details % BearingRangeFactor2D - class BearingRangeFactor2D, see Doxygen page for details -% BetweenFactorLieMatrix - class BetweenFactorLieMatrix, see Doxygen page for details -% BetweenFactorLieScalar - class BetweenFactorLieScalar, see Doxygen page for details -% BetweenFactorLieVector - class BetweenFactorLieVector, see Doxygen page for details % BetweenFactorPoint2 - class BetweenFactorPoint2, see Doxygen page for details % BetweenFactorPoint3 - class BetweenFactorPoint3, see Doxygen page for details % BetweenFactorPose2 - class BetweenFactorPose2, see Doxygen page for details @@ -116,9 +110,6 @@ % GenericStereoFactor3D - class GenericStereoFactor3D, see Doxygen page for details % NonlinearEqualityCal3_S2 - class NonlinearEqualityCal3_S2, see Doxygen page for details % NonlinearEqualityCalibratedCamera - class NonlinearEqualityCalibratedCamera, see Doxygen page for details -% NonlinearEqualityLieMatrix - class NonlinearEqualityLieMatrix, see Doxygen page for details -% NonlinearEqualityLieScalar - class NonlinearEqualityLieScalar, see Doxygen page for details -% NonlinearEqualityLieVector - class NonlinearEqualityLieVector, see Doxygen page for details % NonlinearEqualityPoint2 - class NonlinearEqualityPoint2, see Doxygen page for details % NonlinearEqualityPoint3 - class NonlinearEqualityPoint3, see Doxygen page for details % NonlinearEqualityPose2 - class NonlinearEqualityPose2, see Doxygen page for details @@ -129,9 +120,6 @@ % NonlinearEqualityStereoPoint2 - class NonlinearEqualityStereoPoint2, see Doxygen page for details % PriorFactorCal3_S2 - class PriorFactorCal3_S2, see Doxygen page for details % PriorFactorCalibratedCamera - class PriorFactorCalibratedCamera, see Doxygen page for details -% PriorFactorLieMatrix - class PriorFactorLieMatrix, see Doxygen page for details -% PriorFactorLieScalar - class PriorFactorLieScalar, see Doxygen page for details -% PriorFactorLieVector - class PriorFactorLieVector, see Doxygen page for details % PriorFactorPoint2 - class PriorFactorPoint2, see Doxygen page for details % PriorFactorPoint3 - class PriorFactorPoint3, see Doxygen page for details % PriorFactorPose2 - class PriorFactorPose2, see Doxygen page for details diff --git a/matlab/+gtsam/VisualISAMInitialize.m b/matlab/+gtsam/VisualISAMInitialize.m index 29f8b3b46..560503345 100644 --- a/matlab/+gtsam/VisualISAMInitialize.m +++ b/matlab/+gtsam/VisualISAMInitialize.m @@ -7,16 +7,16 @@ import gtsam.* %% Initialize iSAM params = gtsam.ISAM2Params; if options.alwaysRelinearize - params.setRelinearizeSkip(1); + params.relinearizeSkip = 1; end isam = ISAM2(params); %% Set Noise parameters -noiseModels.pose = noiseModel.Diagonal.Sigmas([0.001 0.001 0.001 0.1 0.1 0.1]'); +noiseModels.pose = noiseModel.Diagonal.Sigmas([0.001 0.001 0.001 0.1 0.1 0.1]', true); %noiseModels.odometry = noiseModel.Diagonal.Sigmas([0.001 0.001 0.001 0.1 0.1 0.1]'); -noiseModels.odometry = noiseModel.Diagonal.Sigmas([0.05 0.05 0.05 0.2 0.2 0.2]'); -noiseModels.point = noiseModel.Isotropic.Sigma(3, 0.1); -noiseModels.measurement = noiseModel.Isotropic.Sigma(2, 1.0); +noiseModels.odometry = noiseModel.Diagonal.Sigmas([0.05 0.05 0.05 0.2 0.2 0.2]', true); +noiseModels.point = noiseModel.Isotropic.Sigma(3, 0.1, true); +noiseModels.measurement = noiseModel.Isotropic.Sigma(2, 1.0, true); %% Add constraints/priors % TODO: should not be from ground truth! diff --git a/matlab/CMakeLists.txt b/matlab/CMakeLists.txt index 28e7cce6e..a657c6be7 100644 --- a/matlab/CMakeLists.txt +++ b/matlab/CMakeLists.txt @@ -64,8 +64,23 @@ set(ignore gtsam::Point3 gtsam::CustomFactor) +set(interface_files + ${GTSAM_SOURCE_DIR}/gtsam/gtsam.i + ${GTSAM_SOURCE_DIR}/gtsam/base/base.i + ${GTSAM_SOURCE_DIR}/gtsam/basis/basis.i + ${PROJECT_SOURCE_DIR}/gtsam/inference/inference.i + ${PROJECT_SOURCE_DIR}/gtsam/discrete/discrete.i + ${GTSAM_SOURCE_DIR}/gtsam/geometry/geometry.i + ${GTSAM_SOURCE_DIR}/gtsam/linear/linear.i + ${GTSAM_SOURCE_DIR}/gtsam/nonlinear/nonlinear.i + ${GTSAM_SOURCE_DIR}/gtsam/symbolic/symbolic.i + ${GTSAM_SOURCE_DIR}/gtsam/sam/sam.i + ${GTSAM_SOURCE_DIR}/gtsam/slam/slam.i + ${GTSAM_SOURCE_DIR}/gtsam/sfm/sfm.i + ${GTSAM_SOURCE_DIR}/gtsam/navigation/navigation.i +) # Wrap -matlab_wrap(${GTSAM_SOURCE_DIR}/gtsam/gtsam.i "${GTSAM_ADDITIONAL_LIBRARIES}" +matlab_wrap("${interface_files}" "gtsam" "${GTSAM_ADDITIONAL_LIBRARIES}" "" "${mexFlags}" "${ignore}") # Wrap version for gtsam_unstable @@ -77,8 +92,8 @@ if(GTSAM_UNSTABLE_INSTALL_MATLAB_TOOLBOX) endif() # Wrap - matlab_wrap(${GTSAM_SOURCE_DIR}/gtsam_unstable/gtsam_unstable.i "gtsam" "" - "${mexFlags}" "${ignore}") + matlab_wrap(${GTSAM_SOURCE_DIR}/gtsam_unstable/gtsam_unstable.i "gtsam_unstable" + "${GTSAM_ADDITIONAL_LIBRARIES}" "" "${mexFlags}" "${ignore}") endif(GTSAM_UNSTABLE_INSTALL_MATLAB_TOOLBOX) # Record the root dir for gtsam - needed during external builds, e.g., ROS diff --git a/matlab/gtsam_examples/IMUKittiExampleGPS.m b/matlab/gtsam_examples/IMUKittiExampleGPS.m index 530c3382c..b350618e5 100755 --- a/matlab/gtsam_examples/IMUKittiExampleGPS.m +++ b/matlab/gtsam_examples/IMUKittiExampleGPS.m @@ -52,7 +52,7 @@ IMU_params.setOmegaCoriolis(w_coriolis); %% Solver object isamParams = ISAM2Params; isamParams.setFactorization('CHOLESKY'); -isamParams.setRelinearizeSkip(10); +isamParams.relinearizeSkip = 10; isam = gtsam.ISAM2(isamParams); newFactors = NonlinearFactorGraph; newValues = Values; diff --git a/matlab/gtsam_tests/testUtilities.m b/matlab/gtsam_tests/testUtilities.m index da8dec789..2bfe81a83 100644 --- a/matlab/gtsam_tests/testUtilities.m +++ b/matlab/gtsam_tests/testUtilities.m @@ -45,3 +45,12 @@ CHECK('KeySet', isa(actual,'gtsam.KeySet')); CHECK('size==3', actual.size==3); CHECK('actual.count(x1)', actual.count(x1)); +% test extractVectors +values = Values(); +values.insert(symbol('x', 0), (1:6)'); +values.insert(symbol('x', 1), (7:12)'); +values.insert(symbol('x', 2), (13:18)'); +values.insert(symbol('x', 7), Pose3()); +actual = utilities.extractVectors(values, 'x'); +expected = reshape(1:18, 6, 3)'; +CHECK('extractVectors', all(actual == expected, 'all')); diff --git a/matlab/unstable_examples/+imuSimulator/IMUComparison.m b/matlab/unstable_examples/+imuSimulator/IMUComparison.m index 871f023ef..b753916c6 100644 --- a/matlab/unstable_examples/+imuSimulator/IMUComparison.m +++ b/matlab/unstable_examples/+imuSimulator/IMUComparison.m @@ -46,18 +46,18 @@ posesIMUbody(1).R = poses(1).R; %% Solver object isamParams = ISAM2Params; -isamParams.setRelinearizeSkip(1); +isamParams.relinearizeSkip = 1; isam = gtsam.ISAM2(isamParams); initialValues = Values; initialValues.insert(symbol('x',0), currentPoseGlobal); -initialValues.insert(symbol('v',0), LieVector(currentVelocityGlobal)); +initialValues.insert(symbol('v',0), currentVelocityGlobal); initialValues.insert(symbol('b',0), imuBias.ConstantBias([0;0;0],[0;0;0])); initialFactors = NonlinearFactorGraph; initialFactors.add(PriorFactorPose3(symbol('x',0), ... currentPoseGlobal, noiseModel.Isotropic.Sigma(6, 1.0))); -initialFactors.add(PriorFactorLieVector(symbol('v',0), ... - LieVector(currentVelocityGlobal), noiseModel.Isotropic.Sigma(3, 1.0))); +initialFactors.add(PriorFactorVector(symbol('v',0), ... + currentVelocityGlobal, noiseModel.Isotropic.Sigma(3, 1.0))); initialFactors.add(PriorFactorConstantBias(symbol('b',0), ... imuBias.ConstantBias([0;0;0],[0;0;0]), noiseModel.Isotropic.Sigma(6, 1.0))); @@ -96,7 +96,7 @@ for t = times initialVel = isam.calculateEstimate(symbol('v',lastSummaryIndex)); else initialPose = Pose3; - initialVel = LieVector(velocity); + initialVel = velocity; end initialValues.insert(symbol('x',lastSummaryIndex+1), initialPose); initialValues.insert(symbol('v',lastSummaryIndex+1), initialVel); diff --git a/matlab/unstable_examples/+imuSimulator/IMUComparison_with_cov.m b/matlab/unstable_examples/+imuSimulator/IMUComparison_with_cov.m index 450697de0..689d8a3f5 100644 --- a/matlab/unstable_examples/+imuSimulator/IMUComparison_with_cov.m +++ b/matlab/unstable_examples/+imuSimulator/IMUComparison_with_cov.m @@ -34,7 +34,7 @@ poses(1).R = currentPoseGlobal.rotation.matrix; %% Solver object isamParams = ISAM2Params; -isamParams.setRelinearizeSkip(1); +isamParams.relinearizeSkip = 1; isam = gtsam.ISAM2(isamParams); sigma_init_x = 1.0; @@ -43,15 +43,15 @@ sigma_init_b = 1.0; initialValues = Values; initialValues.insert(symbol('x',0), currentPoseGlobal); -initialValues.insert(symbol('v',0), LieVector(currentVelocityGlobal)); +initialValues.insert(symbol('v',0), currentVelocityGlobal); initialValues.insert(symbol('b',0), imuBias.ConstantBias([0;0;0],[0;0;0])); initialFactors = NonlinearFactorGraph; % Prior on initial pose initialFactors.add(PriorFactorPose3(symbol('x',0), ... currentPoseGlobal, noiseModel.Isotropic.Sigma(6, sigma_init_x))); % Prior on initial velocity -initialFactors.add(PriorFactorLieVector(symbol('v',0), ... - LieVector(currentVelocityGlobal), noiseModel.Isotropic.Sigma(3, sigma_init_v))); +initialFactors.add(PriorFactorVector(symbol('v',0), ... + currentVelocityGlobal, noiseModel.Isotropic.Sigma(3, sigma_init_v))); % Prior on initial bias initialFactors.add(PriorFactorConstantBias(symbol('b',0), ... imuBias.ConstantBias([0;0;0],[0;0;0]), noiseModel.Isotropic.Sigma(6, sigma_init_b))); @@ -91,7 +91,7 @@ for t = times initialVel = isam.calculateEstimate(symbol('v',lastSummaryIndex)); else initialPose = Pose3; - initialVel = LieVector(velocity); + initialVel = velocity; end initialValues.insert(symbol('x',lastSummaryIndex+1), initialPose); initialValues.insert(symbol('v',lastSummaryIndex+1), initialVel); diff --git a/matlab/unstable_examples/+imuSimulator/coriolisExample.m b/matlab/unstable_examples/+imuSimulator/coriolisExample.m index ee4deb433..dd276e2c1 100644 --- a/matlab/unstable_examples/+imuSimulator/coriolisExample.m +++ b/matlab/unstable_examples/+imuSimulator/coriolisExample.m @@ -119,7 +119,7 @@ h = figure; % Solver object isamParams = ISAM2Params; isamParams.setFactorization('CHOLESKY'); -isamParams.setRelinearizeSkip(10); +isamParams.relinearizeSkip = 10; isam = gtsam.ISAM2(isamParams); newFactors = NonlinearFactorGraph; newValues = Values; @@ -175,9 +175,9 @@ for i = 1:length(times) % known initial conditions currentPoseEstimate = currentPoseFixedGT; if navFrameRotating == 1 - currentVelocityEstimate = LieVector(currentVelocityRotatingGT); + currentVelocityEstimate = currentVelocityRotatingGT; else - currentVelocityEstimate = LieVector(currentVelocityFixedGT); + currentVelocityEstimate = currentVelocityFixedGT; end % Set Priors @@ -186,7 +186,7 @@ for i = 1:length(times) newValues.insert(currentBiasKey, zeroBias); % Initial values, same for IMU types 1 and 2 newFactors.add(PriorFactorPose3(currentPoseKey, currentPoseEstimate, sigma_init_x)); - newFactors.add(PriorFactorLieVector(currentVelKey, currentVelocityEstimate, sigma_init_v)); + newFactors.add(PriorFactorVector(currentVelKey, currentVelocityEstimate, sigma_init_v)); newFactors.add(PriorFactorConstantBias(currentBiasKey, zeroBias, sigma_init_b)); % Store data diff --git a/matlab/unstable_examples/+imuSimulator/covarianceAnalysisCreateFactorGraph.m b/matlab/unstable_examples/+imuSimulator/covarianceAnalysisCreateFactorGraph.m index 07f146dcb..037065ac5 100644 --- a/matlab/unstable_examples/+imuSimulator/covarianceAnalysisCreateFactorGraph.m +++ b/matlab/unstable_examples/+imuSimulator/covarianceAnalysisCreateFactorGraph.m @@ -27,7 +27,7 @@ for i=0:length(measurements) if options.includeIMUFactors == 1 currentVelKey = symbol('v', 0); currentVel = values.atPoint3(currentVelKey); - graph.add(PriorFactorLieVector(currentVelKey, LieVector(currentVel), noiseModels.noiseVel)); + graph.add(PriorFactorVector(currentVelKey, currentVel, noiseModels.noiseVel)); currentBiasKey = symbol('b', 0); currentBias = values.atPoint3(currentBiasKey); diff --git a/matlab/unstable_examples/+imuSimulator/covarianceAnalysisCreateTrajectory.m b/matlab/unstable_examples/+imuSimulator/covarianceAnalysisCreateTrajectory.m index 3d8a9b5d2..5fb6589d6 100644 --- a/matlab/unstable_examples/+imuSimulator/covarianceAnalysisCreateTrajectory.m +++ b/matlab/unstable_examples/+imuSimulator/covarianceAnalysisCreateTrajectory.m @@ -82,7 +82,7 @@ if options.useRealData == 1 end % Add Values: velocity and bias - values.insert(currentVelKey, LieVector(currentVel)); + values.insert(currentVelKey, currentVel); values.insert(currentBiasKey, metadata.imu.zeroBias); end diff --git a/matlab/unstable_examples/FlightCameraTransformIMU.m b/matlab/unstable_examples/FlightCameraTransformIMU.m index d2f2bc34d..aeac2e243 100644 --- a/matlab/unstable_examples/FlightCameraTransformIMU.m +++ b/matlab/unstable_examples/FlightCameraTransformIMU.m @@ -167,7 +167,7 @@ for i=1:size(trajectory)-1 %% priors on first two poses if i < 3 - % fg.add(PriorFactorLieVector(currentVelKey, currentVelocityGlobal, sigma_init_v)); + % fg.add(PriorFactorVector(currentVelKey, currentVelocityGlobal, sigma_init_v)); fg.add(PriorFactorConstantBias(currentBiasKey, currentBias, sigma_init_b)); end diff --git a/matlab/unstable_examples/IMUKittiExampleAdvanced.m b/matlab/unstable_examples/IMUKittiExampleAdvanced.m index cb13eacee..b09764ec0 100644 --- a/matlab/unstable_examples/IMUKittiExampleAdvanced.m +++ b/matlab/unstable_examples/IMUKittiExampleAdvanced.m @@ -82,7 +82,7 @@ w_coriolis = [0;0;0]; %% Solver object isamParams = ISAM2Params; isamParams.setFactorization('QR'); -isamParams.setRelinearizeSkip(1); +isamParams.relinearizeSkip = 1; isam = gtsam.ISAM2(isamParams); newFactors = NonlinearFactorGraph; newValues = Values; diff --git a/matlab/unstable_examples/IMUKittiExampleVO.m b/matlab/unstable_examples/IMUKittiExampleVO.m index 6434e750a..4183e439a 100644 --- a/matlab/unstable_examples/IMUKittiExampleVO.m +++ b/matlab/unstable_examples/IMUKittiExampleVO.m @@ -46,7 +46,7 @@ clear logposes relposes %% Get initial conditions for the estimated trajectory currentPoseGlobal = Pose3; -currentVelocityGlobal = LieVector([0;0;0]); % the vehicle is stationary at the beginning +currentVelocityGlobal = [0;0;0]; % the vehicle is stationary at the beginning currentBias = imuBias.ConstantBias(zeros(3,1), zeros(3,1)); sigma_init_x = noiseModel.Isotropic.Sigmas([ 1.0; 1.0; 0.01; 0.01; 0.01; 0.01 ]); sigma_init_v = noiseModel.Isotropic.Sigma(3, 1000.0); @@ -58,7 +58,7 @@ w_coriolis = [0;0;0]; %% Solver object isamParams = ISAM2Params; isamParams.setFactorization('CHOLESKY'); -isamParams.setRelinearizeSkip(10); +isamParams.relinearizeSkip = 10; isam = gtsam.ISAM2(isamParams); newFactors = NonlinearFactorGraph; newValues = Values; @@ -88,7 +88,7 @@ for measurementIndex = 1:length(timestamps) newValues.insert(currentVelKey, currentVelocityGlobal); newValues.insert(currentBiasKey, currentBias); newFactors.add(PriorFactorPose3(currentPoseKey, currentPoseGlobal, sigma_init_x)); - newFactors.add(PriorFactorLieVector(currentVelKey, currentVelocityGlobal, sigma_init_v)); + newFactors.add(PriorFactorVector(currentVelKey, currentVelocityGlobal, sigma_init_v)); newFactors.add(PriorFactorConstantBias(currentBiasKey, currentBias, sigma_init_b)); else t_previous = timestamps(measurementIndex-1, 1); diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index c3524adad..d3cbff32d 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -33,8 +33,6 @@ add_custom_target(gtsam_unstable_header DEPENDS "${PROJECT_SOURCE_DIR}/gtsam_uns set(ignore gtsam::Point2 gtsam::Point3 - gtsam::LieVector - gtsam::LieMatrix gtsam::ISAM2ThresholdMapValue gtsam::FactorIndices gtsam::FactorIndexSet @@ -47,13 +45,17 @@ set(ignore gtsam::Point3Pairs gtsam::Pose3Pairs gtsam::Pose3Vector + gtsam::Rot3Vector gtsam::KeyVector gtsam::BinaryMeasurementsUnit3 + gtsam::DiscreteKey gtsam::KeyPairDoubleMap) set(interface_headers ${PROJECT_SOURCE_DIR}/gtsam/gtsam.i ${PROJECT_SOURCE_DIR}/gtsam/base/base.i + ${PROJECT_SOURCE_DIR}/gtsam/inference/inference.i + ${PROJECT_SOURCE_DIR}/gtsam/discrete/discrete.i ${PROJECT_SOURCE_DIR}/gtsam/geometry/geometry.i ${PROJECT_SOURCE_DIR}/gtsam/linear/linear.i ${PROJECT_SOURCE_DIR}/gtsam/nonlinear/nonlinear.i @@ -65,8 +67,10 @@ set(interface_headers ${PROJECT_SOURCE_DIR}/gtsam/basis/basis.i ) +set(GTSAM_PYTHON_TARGET gtsam_py) +set(GTSAM_PYTHON_UNSTABLE_TARGET gtsam_unstable_py) -pybind_wrap(gtsam_py # target +pybind_wrap(${GTSAM_PYTHON_TARGET} # target "${interface_headers}" # interface_headers "gtsam.cpp" # generated_cpp "gtsam" # module_name @@ -78,7 +82,7 @@ pybind_wrap(gtsam_py # target ON # use_boost ) -set_target_properties(gtsam_py PROPERTIES +set_target_properties(${GTSAM_PYTHON_TARGET} PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib" INSTALL_RPATH_USE_LINK_PATH TRUE OUTPUT_NAME "gtsam" @@ -87,26 +91,35 @@ set_target_properties(gtsam_py PROPERTIES RELWITHDEBINFO_POSTFIX "" # Otherwise you will have a wrong name ) +# Set the path for the GTSAM python module set(GTSAM_MODULE_PATH ${GTSAM_PYTHON_BUILD_DIRECTORY}/gtsam) -# Symlink all tests .py files to build folder. -create_symlinks("${CMAKE_CURRENT_SOURCE_DIR}/gtsam" +# Copy all python files to build folder. +copy_directory("${CMAKE_CURRENT_SOURCE_DIR}/gtsam" "${GTSAM_MODULE_PATH}") +# Hack to get python test and util files copied every time they are modified +file(GLOB GTSAM_PYTHON_TEST_FILES "${CMAKE_CURRENT_SOURCE_DIR}/gtsam/tests/*.py") +foreach(test_file ${GTSAM_PYTHON_TEST_FILES}) + configure_file(${test_file} "${GTSAM_MODULE_PATH}/tests/${test_file}" COPYONLY) +endforeach() +file(GLOB GTSAM_PYTHON_UTIL_FILES "${CMAKE_CURRENT_SOURCE_DIR}/gtsam/utils/*.py") +foreach(util_file ${GTSAM_PYTHON_UTIL_FILES}) + configure_file(${util_file} "${GTSAM_MODULE_PATH}/utils/${test_file}" COPYONLY) +endforeach() + # Common directory for data/datasets stored with the package. # This will store the data in the Python site package directly. file(COPY "${GTSAM_SOURCE_DIR}/examples/Data" DESTINATION "${GTSAM_MODULE_PATH}") # Add gtsam as a dependency to the install target -set(GTSAM_PYTHON_DEPENDENCIES gtsam_py) +set(GTSAM_PYTHON_DEPENDENCIES ${GTSAM_PYTHON_TARGET}) if(GTSAM_UNSTABLE_BUILD_PYTHON) set(ignore gtsam::Point2 gtsam::Point3 - gtsam::LieVector - gtsam::LieMatrix gtsam::ISAM2ThresholdMapValue gtsam::FactorIndices gtsam::FactorIndexSet @@ -122,7 +135,7 @@ if(GTSAM_UNSTABLE_BUILD_PYTHON) gtsam::CameraSetCal3Fisheye gtsam::KeyPairDoubleMap) - pybind_wrap(gtsam_unstable_py # target + pybind_wrap(${GTSAM_PYTHON_UNSTABLE_TARGET} # target ${PROJECT_SOURCE_DIR}/gtsam_unstable/gtsam_unstable.i # interface_header "gtsam_unstable.cpp" # generated_cpp "gtsam_unstable" # module_name @@ -134,7 +147,7 @@ if(GTSAM_UNSTABLE_BUILD_PYTHON) ON # use_boost ) - set_target_properties(gtsam_unstable_py PROPERTIES + set_target_properties(${GTSAM_PYTHON_UNSTABLE_TARGET} PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib" INSTALL_RPATH_USE_LINK_PATH TRUE OUTPUT_NAME "gtsam_unstable" @@ -145,19 +158,25 @@ if(GTSAM_UNSTABLE_BUILD_PYTHON) set(GTSAM_UNSTABLE_MODULE_PATH ${GTSAM_PYTHON_BUILD_DIRECTORY}/gtsam_unstable) - # Symlink all tests .py files to build folder. - create_symlinks("${CMAKE_CURRENT_SOURCE_DIR}/gtsam_unstable" + # Copy all python files to build folder. + copy_directory("${CMAKE_CURRENT_SOURCE_DIR}/gtsam_unstable" "${GTSAM_UNSTABLE_MODULE_PATH}") + # Hack to get python test files copied every time they are modified + file(GLOB GTSAM_UNSTABLE_PYTHON_TEST_FILES "${CMAKE_CURRENT_SOURCE_DIR}/gtsam_unstable/tests/*.py") + foreach(test_file ${GTSAM_UNSTABLE_PYTHON_TEST_FILES}) + configure_file(${test_file} "${GTSAM_UNSTABLE_MODULE_PATH}/tests/${test_file}" COPYONLY) + endforeach() + # Add gtsam_unstable to the install target - list(APPEND GTSAM_PYTHON_DEPENDENCIES gtsam_unstable_py) + list(APPEND GTSAM_PYTHON_DEPENDENCIES ${GTSAM_PYTHON_UNSTABLE_TARGET}) endif() # Add custom target so we can install with `make python-install` set(GTSAM_PYTHON_INSTALL_TARGET python-install) add_custom_target(${GTSAM_PYTHON_INSTALL_TARGET} - COMMAND ${PYTHON_EXECUTABLE} ${GTSAM_PYTHON_BUILD_DIRECTORY}/setup.py install + COMMAND ${PYTHON_EXECUTABLE} -m pip install --user . DEPENDS ${GTSAM_PYTHON_DEPENDENCIES} WORKING_DIRECTORY ${GTSAM_PYTHON_BUILD_DIRECTORY}) @@ -168,5 +187,5 @@ add_custom_target( ${CMAKE_COMMAND} -E env # add package to python path so no need to install "PYTHONPATH=${GTSAM_PYTHON_BUILD_DIRECTORY}/$ENV{PYTHONPATH}" ${PYTHON_EXECUTABLE} -m unittest discover -v -s . - DEPENDS ${GTSAM_PYTHON_DEPENDENCIES} + DEPENDS ${GTSAM_PYTHON_DEPENDENCIES} ${GTSAM_PYTHON_TEST_FILES} WORKING_DIRECTORY "${GTSAM_PYTHON_BUILD_DIRECTORY}/gtsam/tests") diff --git a/python/README.md b/python/README.md index 54436df93..278d62094 100644 --- a/python/README.md +++ b/python/README.md @@ -8,6 +8,7 @@ For instructions on updating the version of the [wrap library](https://github.co ## Requirements +- Cmake >= 3.15 - If you want to build the GTSAM python library for a specific python version (eg 3.6), use the `-DGTSAM_PYTHON_VERSION=3.6` option when running `cmake` otherwise the default interpreter will be used. - If the interpreter is inside an environment (such as an anaconda environment or virtualenv environment), diff --git a/python/gtsam/__init__.py b/python/gtsam/__init__.py index 70a00c3dc..d00e47b65 100644 --- a/python/gtsam/__init__.py +++ b/python/gtsam/__init__.py @@ -1,6 +1,12 @@ -from . import utils -from .gtsam import * -from .utils import findExampleDataFile +"""Module definition file for GTSAM""" + +# pylint: disable=import-outside-toplevel, global-variable-not-assigned, possibly-unused-variable, import-error, import-self + +import sys + +from gtsam import gtsam, utils +from gtsam.gtsam import * +from gtsam.utils import findExampleDataFile def _init(): @@ -13,7 +19,7 @@ def _init(): def Point2(x=np.nan, y=np.nan): """Shim for the deleted Point2 type.""" if isinstance(x, np.ndarray): - assert x.shape == (2,), "Point2 takes 2-vector" + assert x.shape == (2, ), "Point2 takes 2-vector" return x # "copy constructor" return np.array([x, y], dtype=float) @@ -22,7 +28,7 @@ def _init(): def Point3(x=np.nan, y=np.nan, z=np.nan): """Shim for the deleted Point3 type.""" if isinstance(x, np.ndarray): - assert x.shape == (3,), "Point3 takes 3-vector" + assert x.shape == (3, ), "Point3 takes 3-vector" return x # "copy constructor" return np.array([x, y, z], dtype=float) diff --git a/python/gtsam/examples/CustomFactorExample.py b/python/gtsam/examples/CustomFactorExample.py index c7fe1e202..36c1e003d 100644 --- a/python/gtsam/examples/CustomFactorExample.py +++ b/python/gtsam/examples/CustomFactorExample.py @@ -9,15 +9,17 @@ CustomFactor demo that simulates a 1-D sensor fusion task. Author: Fan Jiang, Frank Dellaert """ +from functools import partial +from typing import List, Optional + import gtsam import numpy as np -from typing import List, Optional -from functools import partial +I = np.eye(1) -def simulate_car(): - # Simulate a car for one second +def simulate_car() -> List[float]: + """Simulate a car for one second""" x0 = 0 dt = 0.25 # 4 Hz, typical GPS v = 144 * 1000 / 3600 # 144 km/hour = 90mph, pretty fast @@ -26,46 +28,9 @@ def simulate_car(): return x -x = simulate_car() -print(f"Simulated car trajectory: {x}") - -# %% -add_noise = True # set this to False to run with "perfect" measurements - -# GPS measurements -sigma_gps = 3.0 # assume GPS is +/- 3m -g = [x[k] + (np.random.normal(scale=sigma_gps) if add_noise else 0) - for k in range(5)] - -# Odometry measurements -sigma_odo = 0.1 # assume Odometry is 10cm accurate at 4Hz -o = [x[k + 1] - x[k] + (np.random.normal(scale=sigma_odo) if add_noise else 0) - for k in range(4)] - -# Landmark measurements: -sigma_lm = 1 # assume landmark measurement is accurate up to 1m - -# Assume first landmark is at x=5, we measure it at time k=0 -lm_0 = 5.0 -z_0 = x[0] - lm_0 + (np.random.normal(scale=sigma_lm) if add_noise else 0) - -# Assume other landmark is at x=28, we measure it at time k=3 -lm_3 = 28.0 -z_3 = x[3] - lm_3 + (np.random.normal(scale=sigma_lm) if add_noise else 0) - -unknown = [gtsam.symbol('x', k) for k in range(5)] - -print("unknowns = ", list(map(gtsam.DefaultKeyFormatter, unknown))) - -# We now can use nonlinear factor graphs -factor_graph = gtsam.NonlinearFactorGraph() - -# Add factors for GPS measurements -I = np.eye(1) -gps_model = gtsam.noiseModel.Isotropic.Sigma(1, sigma_gps) - - -def error_gps(measurement: np.ndarray, this: gtsam.CustomFactor, values, jacobians: Optional[List[np.ndarray]]): +def error_gps(measurement: np.ndarray, this: gtsam.CustomFactor, + values: gtsam.Values, + jacobians: Optional[List[np.ndarray]]) -> float: """GPS Factor error function :param measurement: GPS measurement, to be filled with `partial` :param this: gtsam.CustomFactor handle @@ -82,36 +47,9 @@ def error_gps(measurement: np.ndarray, this: gtsam.CustomFactor, values, jacobia return error -# Add the GPS factors -for k in range(5): - gf = gtsam.CustomFactor(gps_model, [unknown[k]], partial(error_gps, np.array([g[k]]))) - factor_graph.add(gf) - -# New Values container -v = gtsam.Values() - -# Add initial estimates to the Values container -for i in range(5): - v.insert(unknown[i], np.array([0.0])) - -# Initialize optimizer -params = gtsam.GaussNewtonParams() -optimizer = gtsam.GaussNewtonOptimizer(factor_graph, v, params) - -# Optimize the factor graph -result = optimizer.optimize() - -# calculate the error from ground truth -error = np.array([(result.atVector(unknown[k]) - x[k])[0] for k in range(5)]) - -print("Result with only GPS") -print(result, np.round(error, 2), f"\nJ(X)={0.5 * np.sum(np.square(error))}") - -# Adding odometry will improve things a lot -odo_model = gtsam.noiseModel.Isotropic.Sigma(1, sigma_odo) - - -def error_odom(measurement: np.ndarray, this: gtsam.CustomFactor, values, jacobians: Optional[List[np.ndarray]]): +def error_odom(measurement: np.ndarray, this: gtsam.CustomFactor, + values: gtsam.Values, + jacobians: Optional[List[np.ndarray]]) -> float: """Odometry Factor error function :param measurement: Odometry measurement, to be filled with `partial` :param this: gtsam.CustomFactor handle @@ -130,25 +68,9 @@ def error_odom(measurement: np.ndarray, this: gtsam.CustomFactor, values, jacobi return error -for k in range(4): - odof = gtsam.CustomFactor(odo_model, [unknown[k], unknown[k + 1]], partial(error_odom, np.array([o[k]]))) - factor_graph.add(odof) - -params = gtsam.GaussNewtonParams() -optimizer = gtsam.GaussNewtonOptimizer(factor_graph, v, params) - -result = optimizer.optimize() - -error = np.array([(result.atVector(unknown[k]) - x[k])[0] for k in range(5)]) - -print("Result with GPS+Odometry") -print(result, np.round(error, 2), f"\nJ(X)={0.5 * np.sum(np.square(error))}") - -# This is great, but GPS noise is still apparent, so now we add the two landmarks -lm_model = gtsam.noiseModel.Isotropic.Sigma(1, sigma_lm) - - -def error_lm(measurement: np.ndarray, this: gtsam.CustomFactor, values, jacobians: Optional[List[np.ndarray]]): +def error_lm(measurement: np.ndarray, this: gtsam.CustomFactor, + values: gtsam.Values, + jacobians: Optional[List[np.ndarray]]) -> float: """Landmark Factor error function :param measurement: Landmark measurement, to be filled with `partial` :param this: gtsam.CustomFactor handle @@ -165,15 +87,120 @@ def error_lm(measurement: np.ndarray, this: gtsam.CustomFactor, values, jacobian return error -factor_graph.add(gtsam.CustomFactor(lm_model, [unknown[0]], partial(error_lm, np.array([lm_0 + z_0])))) -factor_graph.add(gtsam.CustomFactor(lm_model, [unknown[3]], partial(error_lm, np.array([lm_3 + z_3])))) +def main(): + """Main runner.""" -params = gtsam.GaussNewtonParams() -optimizer = gtsam.GaussNewtonOptimizer(factor_graph, v, params) + x = simulate_car() + print(f"Simulated car trajectory: {x}") -result = optimizer.optimize() + add_noise = True # set this to False to run with "perfect" measurements -error = np.array([(result.atVector(unknown[k]) - x[k])[0] for k in range(5)]) + # GPS measurements + sigma_gps = 3.0 # assume GPS is +/- 3m + g = [ + x[k] + (np.random.normal(scale=sigma_gps) if add_noise else 0) + for k in range(5) + ] -print("Result with GPS+Odometry+Landmark") -print(result, np.round(error, 2), f"\nJ(X)={0.5 * np.sum(np.square(error))}") + # Odometry measurements + sigma_odo = 0.1 # assume Odometry is 10cm accurate at 4Hz + o = [ + x[k + 1] - x[k] + + (np.random.normal(scale=sigma_odo) if add_noise else 0) + for k in range(4) + ] + + # Landmark measurements: + sigma_lm = 1 # assume landmark measurement is accurate up to 1m + + # Assume first landmark is at x=5, we measure it at time k=0 + lm_0 = 5.0 + z_0 = x[0] - lm_0 + (np.random.normal(scale=sigma_lm) if add_noise else 0) + + # Assume other landmark is at x=28, we measure it at time k=3 + lm_3 = 28.0 + z_3 = x[3] - lm_3 + (np.random.normal(scale=sigma_lm) if add_noise else 0) + + unknown = [gtsam.symbol('x', k) for k in range(5)] + + print("unknowns = ", list(map(gtsam.DefaultKeyFormatter, unknown))) + + # We now can use nonlinear factor graphs + factor_graph = gtsam.NonlinearFactorGraph() + + # Add factors for GPS measurements + gps_model = gtsam.noiseModel.Isotropic.Sigma(1, sigma_gps) + + # Add the GPS factors + for k in range(5): + gf = gtsam.CustomFactor(gps_model, [unknown[k]], + partial(error_gps, np.array([g[k]]))) + factor_graph.add(gf) + + # New Values container + v = gtsam.Values() + + # Add initial estimates to the Values container + for i in range(5): + v.insert(unknown[i], np.array([0.0])) + + # Initialize optimizer + params = gtsam.GaussNewtonParams() + optimizer = gtsam.GaussNewtonOptimizer(factor_graph, v, params) + + # Optimize the factor graph + result = optimizer.optimize() + + # calculate the error from ground truth + error = np.array([(result.atVector(unknown[k]) - x[k])[0] + for k in range(5)]) + + print("Result with only GPS") + print(result, np.round(error, 2), + f"\nJ(X)={0.5 * np.sum(np.square(error))}") + + # Adding odometry will improve things a lot + odo_model = gtsam.noiseModel.Isotropic.Sigma(1, sigma_odo) + + for k in range(4): + odof = gtsam.CustomFactor(odo_model, [unknown[k], unknown[k + 1]], + partial(error_odom, np.array([o[k]]))) + factor_graph.add(odof) + + params = gtsam.GaussNewtonParams() + optimizer = gtsam.GaussNewtonOptimizer(factor_graph, v, params) + + result = optimizer.optimize() + + error = np.array([(result.atVector(unknown[k]) - x[k])[0] + for k in range(5)]) + + print("Result with GPS+Odometry") + print(result, np.round(error, 2), + f"\nJ(X)={0.5 * np.sum(np.square(error))}") + + # This is great, but GPS noise is still apparent, so now we add the two landmarks + lm_model = gtsam.noiseModel.Isotropic.Sigma(1, sigma_lm) + + factor_graph.add( + gtsam.CustomFactor(lm_model, [unknown[0]], + partial(error_lm, np.array([lm_0 + z_0])))) + factor_graph.add( + gtsam.CustomFactor(lm_model, [unknown[3]], + partial(error_lm, np.array([lm_3 + z_3])))) + + params = gtsam.GaussNewtonParams() + optimizer = gtsam.GaussNewtonOptimizer(factor_graph, v, params) + + result = optimizer.optimize() + + error = np.array([(result.atVector(unknown[k]) - x[k])[0] + for k in range(5)]) + + print("Result with GPS+Odometry+Landmark") + print(result, np.round(error, 2), + f"\nJ(X)={0.5 * np.sum(np.square(error))}") + + +if __name__ == "__main__": + main() diff --git a/python/gtsam/examples/GPSFactorExample.py b/python/gtsam/examples/GPSFactorExample.py index 0bc0d1bf3..8eb663cb4 100644 --- a/python/gtsam/examples/GPSFactorExample.py +++ b/python/gtsam/examples/GPSFactorExample.py @@ -13,13 +13,8 @@ Author: Mandy Xie from __future__ import print_function -import numpy as np - import gtsam -import matplotlib.pyplot as plt -import gtsam.utils.plot as gtsam_plot - # ENU Origin is where the plane was in hold next to runway lat0 = 33.86998 lon0 = -84.30626 @@ -29,28 +24,34 @@ h0 = 274 GPS_NOISE = gtsam.noiseModel.Isotropic.Sigma(3, 0.1) PRIOR_NOISE = gtsam.noiseModel.Isotropic.Sigma(6, 0.25) -# Create an empty nonlinear factor graph -graph = gtsam.NonlinearFactorGraph() -# Add a prior on the first point, setting it to the origin -# A prior factor consists of a mean and a noise model (covariance matrix) -priorMean = gtsam.Pose3() # prior at origin -graph.add(gtsam.PriorFactorPose3(1, priorMean, PRIOR_NOISE)) +def main(): + """Main runner.""" + # Create an empty nonlinear factor graph + graph = gtsam.NonlinearFactorGraph() -# Add GPS factors -gps = gtsam.Point3(lat0, lon0, h0) -graph.add(gtsam.GPSFactor(1, gps, GPS_NOISE)) -print("\nFactor Graph:\n{}".format(graph)) + # Add a prior on the first point, setting it to the origin + # A prior factor consists of a mean and a noise model (covariance matrix) + priorMean = gtsam.Pose3() # prior at origin + graph.add(gtsam.PriorFactorPose3(1, priorMean, PRIOR_NOISE)) -# Create the data structure to hold the initialEstimate estimate to the solution -# For illustrative purposes, these have been deliberately set to incorrect values -initial = gtsam.Values() -initial.insert(1, gtsam.Pose3()) -print("\nInitial Estimate:\n{}".format(initial)) + # Add GPS factors + gps = gtsam.Point3(lat0, lon0, h0) + graph.add(gtsam.GPSFactor(1, gps, GPS_NOISE)) + print("\nFactor Graph:\n{}".format(graph)) -# optimize using Levenberg-Marquardt optimization -params = gtsam.LevenbergMarquardtParams() -optimizer = gtsam.LevenbergMarquardtOptimizer(graph, initial, params) -result = optimizer.optimize() -print("\nFinal Result:\n{}".format(result)) + # Create the data structure to hold the initialEstimate estimate to the solution + # For illustrative purposes, these have been deliberately set to incorrect values + initial = gtsam.Values() + initial.insert(1, gtsam.Pose3()) + print("\nInitial Estimate:\n{}".format(initial)) + # optimize using Levenberg-Marquardt optimization + params = gtsam.LevenbergMarquardtParams() + optimizer = gtsam.LevenbergMarquardtOptimizer(graph, initial, params) + result = optimizer.optimize() + print("\nFinal Result:\n{}".format(result)) + + +if __name__ == "__main__": + main() diff --git a/python/gtsam/examples/IMUKittiExampleGPS.py b/python/gtsam/examples/IMUKittiExampleGPS.py new file mode 100644 index 000000000..d00f633ba --- /dev/null +++ b/python/gtsam/examples/IMUKittiExampleGPS.py @@ -0,0 +1,366 @@ +""" +Example of application of ISAM2 for GPS-aided navigation on the KITTI VISION BENCHMARK SUITE + +Author: Varun Agrawal +""" +import argparse +from typing import List, Tuple + +import gtsam +import numpy as np +from gtsam import ISAM2, Pose3, noiseModel +from gtsam.symbol_shorthand import B, V, X + +GRAVITY = 9.8 + + +class KittiCalibration: + """Class to hold KITTI calibration info.""" + def __init__(self, body_ptx: float, body_pty: float, body_ptz: float, + body_prx: float, body_pry: float, body_prz: float, + accelerometer_sigma: float, gyroscope_sigma: float, + integration_sigma: float, accelerometer_bias_sigma: float, + gyroscope_bias_sigma: float, average_delta_t: float): + self.bodyTimu = Pose3(gtsam.Rot3.RzRyRx(body_prx, body_pry, body_prz), + gtsam.Point3(body_ptx, body_pty, body_ptz)) + self.accelerometer_sigma = accelerometer_sigma + self.gyroscope_sigma = gyroscope_sigma + self.integration_sigma = integration_sigma + self.accelerometer_bias_sigma = accelerometer_bias_sigma + self.gyroscope_bias_sigma = gyroscope_bias_sigma + self.average_delta_t = average_delta_t + + +class ImuMeasurement: + """An instance of an IMU measurement.""" + def __init__(self, time: float, dt: float, accelerometer: gtsam.Point3, + gyroscope: gtsam.Point3): + self.time = time + self.dt = dt + self.accelerometer = accelerometer + self.gyroscope = gyroscope + + +class GpsMeasurement: + """An instance of a GPS measurement.""" + def __init__(self, time: float, position: gtsam.Point3): + self.time = time + self.position = position + + +def loadImuData(imu_data_file: str) -> List[ImuMeasurement]: + """Helper to load the IMU data.""" + # Read IMU data + # Time dt accelX accelY accelZ omegaX omegaY omegaZ + imu_data_file = gtsam.findExampleDataFile(imu_data_file) + imu_measurements = [] + + print("-- Reading IMU measurements from file") + with open(imu_data_file, encoding='UTF-8') as imu_data: + data = imu_data.readlines() + for i in range(1, len(data)): # ignore the first line + time, dt, acc_x, acc_y, acc_z, gyro_x, gyro_y, gyro_z = map( + float, data[i].split(' ')) + imu_measurement = ImuMeasurement( + time, dt, np.asarray([acc_x, acc_y, acc_z]), + np.asarray([gyro_x, gyro_y, gyro_z])) + imu_measurements.append(imu_measurement) + + return imu_measurements + + +def loadGpsData(gps_data_file: str) -> List[GpsMeasurement]: + """Helper to load the GPS data.""" + # Read GPS data + # Time,X,Y,Z + gps_data_file = gtsam.findExampleDataFile(gps_data_file) + gps_measurements = [] + + print("-- Reading GPS measurements from file") + with open(gps_data_file, encoding='UTF-8') as gps_data: + data = gps_data.readlines() + for i in range(1, len(data)): + time, x, y, z = map(float, data[i].split(',')) + gps_measurement = GpsMeasurement(time, np.asarray([x, y, z])) + gps_measurements.append(gps_measurement) + + return gps_measurements + + +def loadKittiData( + imu_data_file: str = "KittiEquivBiasedImu.txt", + gps_data_file: str = "KittiGps_converted.txt", + imu_metadata_file: str = "KittiEquivBiasedImu_metadata.txt" +) -> Tuple[KittiCalibration, List[ImuMeasurement], List[GpsMeasurement]]: + """ + Load the KITTI Dataset. + """ + # Read IMU metadata and compute relative sensor pose transforms + # BodyPtx BodyPty BodyPtz BodyPrx BodyPry BodyPrz AccelerometerSigma + # GyroscopeSigma IntegrationSigma AccelerometerBiasSigma GyroscopeBiasSigma + # AverageDeltaT + imu_metadata_file = gtsam.findExampleDataFile(imu_metadata_file) + with open(imu_metadata_file, encoding='UTF-8') as imu_metadata: + print("-- Reading sensor metadata") + line = imu_metadata.readline() # Ignore the first line + line = imu_metadata.readline().strip() + data = list(map(float, line.split(' '))) + kitti_calibration = KittiCalibration(*data) + print("IMU metadata:", data) + + imu_measurements = loadImuData(imu_data_file) + gps_measurements = loadGpsData(gps_data_file) + + return kitti_calibration, imu_measurements, gps_measurements + + +def getImuParams(kitti_calibration: KittiCalibration): + """Get the IMU parameters from the KITTI calibration data.""" + w_coriolis = np.zeros(3) + + # Set IMU preintegration parameters + measured_acc_cov = np.eye(3) * np.power( + kitti_calibration.accelerometer_sigma, 2) + measured_omega_cov = np.eye(3) * np.power( + kitti_calibration.gyroscope_sigma, 2) + # error committed in integrating position from velocities + integration_error_cov = np.eye(3) * np.power( + kitti_calibration.integration_sigma, 2) + + imu_params = gtsam.PreintegrationParams.MakeSharedU(GRAVITY) + # acc white noise in continuous + imu_params.setAccelerometerCovariance(measured_acc_cov) + # integration uncertainty continuous + imu_params.setIntegrationCovariance(integration_error_cov) + # gyro white noise in continuous + imu_params.setGyroscopeCovariance(measured_omega_cov) + imu_params.setOmegaCoriolis(w_coriolis) + + return imu_params + + +def save_results(isam: gtsam.ISAM2, output_filename: str, first_gps_pose: int, + gps_measurements: List[GpsMeasurement]): + """Write the results from `isam` to `output_filename`.""" + # Save results to file + print("Writing results to file...") + with open(output_filename, 'w', encoding='UTF-8') as fp_out: + fp_out.write( + "#time(s),x(m),y(m),z(m),qx,qy,qz,qw,gt_x(m),gt_y(m),gt_z(m)\n") + + result = isam.calculateEstimate() + for i in range(first_gps_pose, len(gps_measurements)): + pose_key = X(i) + vel_key = V(i) + bias_key = B(i) + + pose = result.atPose3(pose_key) + velocity = result.atVector(vel_key) + bias = result.atConstantBias(bias_key) + + pose_quat = pose.rotation().toQuaternion() + gps = gps_measurements[i].position + + print(f"State at #{i}") + print(f"Pose:\n{pose}") + print(f"Velocity:\n{velocity}") + print(f"Bias:\n{bias}") + + fp_out.write("{},{},{},{},{},{},{},{},{},{},{}\n".format( + gps_measurements[i].time, pose.x(), pose.y(), pose.z(), + pose_quat.x(), pose_quat.y(), pose_quat.z(), pose_quat.w(), + gps[0], gps[1], gps[2])) + + +def parse_args() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument("--output_filename", + default="IMUKittiExampleGPSResults.csv") + return parser.parse_args() + + +def optimize(gps_measurements: List[GpsMeasurement], + imu_measurements: List[ImuMeasurement], + sigma_init_x: gtsam.noiseModel.Diagonal, + sigma_init_v: gtsam.noiseModel.Diagonal, + sigma_init_b: gtsam.noiseModel.Diagonal, + noise_model_gps: gtsam.noiseModel.Diagonal, + kitti_calibration: KittiCalibration, first_gps_pose: int, + gps_skip: int) -> gtsam.ISAM2: + """Run ISAM2 optimization on the measurements.""" + # Set initial conditions for the estimated trajectory + # initial pose is the reference frame (navigation frame) + current_pose_global = Pose3(gtsam.Rot3(), + gps_measurements[first_gps_pose].position) + + # the vehicle is stationary at the beginning at position 0,0,0 + current_velocity_global = np.zeros(3) + current_bias = gtsam.imuBias.ConstantBias() # init with zero bias + + imu_params = getImuParams(kitti_calibration) + + # Set ISAM2 parameters and create ISAM2 solver object + isam_params = gtsam.ISAM2Params() + isam_params.setFactorization("CHOLESKY") + isam_params.relinearizeSkip = 10 + + isam = gtsam.ISAM2(isam_params) + + # Create the factor graph and values object that will store new factors and + # values to add to the incremental graph + new_factors = gtsam.NonlinearFactorGraph() + # values storing the initial estimates of new nodes in the factor graph + new_values = gtsam.Values() + + # Main loop: + # (1) we read the measurements + # (2) we create the corresponding factors in the graph + # (3) we solve the graph to obtain and optimal estimate of robot trajectory + print("-- Starting main loop: inference is performed at each time step, " + "but we plot trajectory every 10 steps") + + j = 0 + included_imu_measurement_count = 0 + + for i in range(first_gps_pose, len(gps_measurements)): + # At each non=IMU measurement we initialize a new node in the graph + current_pose_key = X(i) + current_vel_key = V(i) + current_bias_key = B(i) + t = gps_measurements[i].time + + if i == first_gps_pose: + # Create initial estimate and prior on initial pose, velocity, and biases + new_values.insert(current_pose_key, current_pose_global) + new_values.insert(current_vel_key, current_velocity_global) + new_values.insert(current_bias_key, current_bias) + + new_factors.addPriorPose3(current_pose_key, current_pose_global, + sigma_init_x) + new_factors.addPriorVector(current_vel_key, + current_velocity_global, sigma_init_v) + new_factors.addPriorConstantBias(current_bias_key, current_bias, + sigma_init_b) + else: + t_previous = gps_measurements[i - 1].time + + # Summarize IMU data between the previous GPS measurement and now + current_summarized_measurement = gtsam.PreintegratedImuMeasurements( + imu_params, current_bias) + + while (j < len(imu_measurements) + and imu_measurements[j].time <= t): + if imu_measurements[j].time >= t_previous: + current_summarized_measurement.integrateMeasurement( + imu_measurements[j].accelerometer, + imu_measurements[j].gyroscope, imu_measurements[j].dt) + included_imu_measurement_count += 1 + j += 1 + + # Create IMU factor + previous_pose_key = X(i - 1) + previous_vel_key = V(i - 1) + previous_bias_key = B(i - 1) + + new_factors.push_back( + gtsam.ImuFactor(previous_pose_key, previous_vel_key, + current_pose_key, current_vel_key, + previous_bias_key, + current_summarized_measurement)) + + # Bias evolution as given in the IMU metadata + sigma_between_b = gtsam.noiseModel.Diagonal.Sigmas( + np.asarray([ + np.sqrt(included_imu_measurement_count) * + kitti_calibration.accelerometer_bias_sigma + ] * 3 + [ + np.sqrt(included_imu_measurement_count) * + kitti_calibration.gyroscope_bias_sigma + ] * 3)) + + new_factors.push_back( + gtsam.BetweenFactorConstantBias(previous_bias_key, + current_bias_key, + gtsam.imuBias.ConstantBias(), + sigma_between_b)) + + # Create GPS factor + gps_pose = Pose3(current_pose_global.rotation(), + gps_measurements[i].position) + if (i % gps_skip) == 0: + new_factors.addPriorPose3(current_pose_key, gps_pose, + noise_model_gps) + new_values.insert(current_pose_key, gps_pose) + + print(f"############ POSE INCLUDED AT TIME {t} ############") + print(gps_pose.translation(), "\n") + else: + new_values.insert(current_pose_key, current_pose_global) + + # Add initial values for velocity and bias based on the previous + # estimates + new_values.insert(current_vel_key, current_velocity_global) + new_values.insert(current_bias_key, current_bias) + + # Update solver + # ======================================================================= + # We accumulate 2*GPSskip GPS measurements before updating the solver at + # first so that the heading becomes observable. + if i > (first_gps_pose + 2 * gps_skip): + print(f"############ NEW FACTORS AT TIME {t:.6f} ############") + new_factors.print() + + isam.update(new_factors, new_values) + + # Reset the newFactors and newValues list + new_factors.resize(0) + new_values.clear() + + # Extract the result/current estimates + result = isam.calculateEstimate() + + current_pose_global = result.atPose3(current_pose_key) + current_velocity_global = result.atVector(current_vel_key) + current_bias = result.atConstantBias(current_bias_key) + + print(f"############ POSE AT TIME {t} ############") + current_pose_global.print() + print("\n") + + return isam + + +def main(): + """Main runner.""" + args = parse_args() + kitti_calibration, imu_measurements, gps_measurements = loadKittiData() + + if not kitti_calibration.bodyTimu.equals(Pose3(), 1e-8): + raise ValueError( + "Currently only support IMUinBody is identity, i.e. IMU and body frame are the same" + ) + + # Configure different variables + first_gps_pose = 1 + gps_skip = 10 + + # Configure noise models + noise_model_gps = noiseModel.Diagonal.Precisions( + np.asarray([0, 0, 0] + [1.0 / 0.07] * 3)) + + sigma_init_x = noiseModel.Diagonal.Precisions( + np.asarray([0, 0, 0, 1, 1, 1])) + sigma_init_v = noiseModel.Diagonal.Sigmas(np.ones(3) * 1000.0) + sigma_init_b = noiseModel.Diagonal.Sigmas( + np.asarray([0.1] * 3 + [5.00e-05] * 3)) + + isam = optimize(gps_measurements, imu_measurements, sigma_init_x, + sigma_init_v, sigma_init_b, noise_model_gps, + kitti_calibration, first_gps_pose, gps_skip) + + save_results(isam, args.output_filename, first_gps_pose, gps_measurements) + + +if __name__ == "__main__": + main() diff --git a/python/gtsam/examples/ImuFactorExample.py b/python/gtsam/examples/ImuFactorExample.py index 86613234d..c86a4e216 100644 --- a/python/gtsam/examples/ImuFactorExample.py +++ b/python/gtsam/examples/ImuFactorExample.py @@ -10,31 +10,51 @@ A script validating and demonstrating the ImuFactor inference. Author: Frank Dellaert, Varun Agrawal """ -# pylint: disable=no-name-in-module,unused-import,arguments-differ +# pylint: disable=no-name-in-module,unused-import,arguments-differ,import-error,wrong-import-order from __future__ import print_function import argparse import math +import gtsam import matplotlib.pyplot as plt import numpy as np -from mpl_toolkits.mplot3d import Axes3D - -import gtsam from gtsam.symbol_shorthand import B, V, X from gtsam.utils.plot import plot_pose3 +from mpl_toolkits.mplot3d import Axes3D from PreintegrationExample import POSES_FIG, PreintegrationExample BIAS_KEY = B(0) +GRAVITY = 9.81 np.set_printoptions(precision=3, suppress=True) +def parse_args() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser("ImuFactorExample.py") + parser.add_argument("--twist_scenario", + default="sick_twist", + choices=("zero_twist", "forward_twist", "loop_twist", + "sick_twist")) + parser.add_argument("--time", + "-T", + default=12, + type=int, + help="Total navigation time in seconds") + parser.add_argument("--compute_covariances", + default=False, + action='store_true') + parser.add_argument("--verbose", default=False, action='store_true') + args = parser.parse_args() + return args + + class ImuFactorExample(PreintegrationExample): """Class to run example of the Imu Factor.""" - def __init__(self, twist_scenario="sick_twist"): + def __init__(self, twist_scenario: str = "sick_twist"): self.velocity = np.array([2, 0, 0]) self.priorNoise = gtsam.noiseModel.Isotropic.Sigma(6, 0.1) self.velNoise = gtsam.noiseModel.Isotropic.Sigma(3, 0.1) @@ -51,19 +71,30 @@ class ImuFactorExample(PreintegrationExample): gyroBias = np.array([0.1, 0.3, -0.1]) bias = gtsam.imuBias.ConstantBias(accBias, gyroBias) + params = gtsam.PreintegrationParams.MakeSharedU(GRAVITY) + + # Some arbitrary noise sigmas + gyro_sigma = 1e-3 + accel_sigma = 1e-3 + I_3x3 = np.eye(3) + params.setGyroscopeCovariance(gyro_sigma**2 * I_3x3) + params.setAccelerometerCovariance(accel_sigma**2 * I_3x3) + params.setIntegrationCovariance(1e-7**2 * I_3x3) + dt = 1e-2 super(ImuFactorExample, self).__init__(twist_scenarios[twist_scenario], - bias, dt) + bias, params, dt) - def addPrior(self, i, graph): - """Add priors at time step `i`.""" + def addPrior(self, i: int, graph: gtsam.NonlinearFactorGraph): + """Add a prior on the navigation state at time `i`.""" state = self.scenario.navState(i) graph.push_back( gtsam.PriorFactorPose3(X(i), state.pose(), self.priorNoise)) graph.push_back( gtsam.PriorFactorVector(V(i), state.velocity(), self.velNoise)) - def optimize(self, graph, initial): + def optimize(self, graph: gtsam.NonlinearFactorGraph, + initial: gtsam.Values): """Optimize using Levenberg-Marquardt optimization.""" params = gtsam.LevenbergMarquardtParams() params.setVerbosityLM("SUMMARY") @@ -71,24 +102,49 @@ class ImuFactorExample(PreintegrationExample): result = optimizer.optimize() return result - def plot(self, result): - """Plot resulting poses.""" + def plot(self, + values: gtsam.Values, + title: str = "Estimated Trajectory", + fignum: int = POSES_FIG + 1, + show: bool = False): + """ + Plot poses in values. + + Args: + values: The values object with the poses to plot. + title: The title of the plot. + fignum: The matplotlib figure number. + POSES_FIG is a value from the PreintegrationExample which we simply increment to generate a new figure. + show: Flag indicating whether to display the figure. + """ i = 0 - while result.exists(X(i)): - pose_i = result.atPose3(X(i)) - plot_pose3(POSES_FIG + 1, pose_i, 1) + while values.exists(X(i)): + pose_i = values.atPose3(X(i)) + plot_pose3(fignum, pose_i, 1) i += 1 - plt.title("Estimated Trajectory") + plt.title(title) - gtsam.utils.plot.set_axes_equal(POSES_FIG + 1) + gtsam.utils.plot.set_axes_equal(fignum) - print("Bias Values", result.atConstantBias(BIAS_KEY)) + print("Bias Values", values.atConstantBias(BIAS_KEY)) plt.ioff() - plt.show() - def run(self, T=12, compute_covariances=False, verbose=True): - """Main runner.""" + if show: + plt.show() + + def run(self, + T: int = 12, + compute_covariances: bool = False, + verbose: bool = True): + """ + Main runner. + + Args: + T: Total trajectory time. + compute_covariances: Flag indicating whether to compute marginal covariances. + verbose: Flag indicating if printing should be verbose. + """ graph = gtsam.NonlinearFactorGraph() # initialize data structure for pre-integrated IMU measurements @@ -173,25 +229,11 @@ class ImuFactorExample(PreintegrationExample): print("Covariance on vel {}:\n{}\n".format( i, marginals.marginalCovariance(V(i)))) - self.plot(result) + self.plot(result, show=True) if __name__ == '__main__': - parser = argparse.ArgumentParser("ImuFactorExample.py") - parser.add_argument("--twist_scenario", - default="sick_twist", - choices=("zero_twist", "forward_twist", "loop_twist", - "sick_twist")) - parser.add_argument("--time", - "-T", - default=12, - type=int, - help="Total time in seconds") - parser.add_argument("--compute_covariances", - default=False, - action='store_true') - parser.add_argument("--verbose", default=False, action='store_true') - args = parser.parse_args() + args = parse_args() ImuFactorExample(args.twist_scenario).run(args.time, args.compute_covariances, diff --git a/python/gtsam/examples/OdometryExample.py b/python/gtsam/examples/OdometryExample.py index 8b519ce9a..210aeb808 100644 --- a/python/gtsam/examples/OdometryExample.py +++ b/python/gtsam/examples/OdometryExample.py @@ -13,57 +13,60 @@ Author: Frank Dellaert from __future__ import print_function -import numpy as np - import gtsam - -import matplotlib.pyplot as plt import gtsam.utils.plot as gtsam_plot +import matplotlib.pyplot as plt +import numpy as np # Create noise models ODOMETRY_NOISE = gtsam.noiseModel.Diagonal.Sigmas(np.array([0.2, 0.2, 0.1])) PRIOR_NOISE = gtsam.noiseModel.Diagonal.Sigmas(np.array([0.3, 0.3, 0.1])) -# Create an empty nonlinear factor graph -graph = gtsam.NonlinearFactorGraph() -# Add a prior on the first pose, setting it to the origin -# A prior factor consists of a mean and a noise model (covariance matrix) -priorMean = gtsam.Pose2(0.0, 0.0, 0.0) # prior at origin -graph.add(gtsam.PriorFactorPose2(1, priorMean, PRIOR_NOISE)) +def main(): + """Main runner""" + # Create an empty nonlinear factor graph + graph = gtsam.NonlinearFactorGraph() -# Add odometry factors -odometry = gtsam.Pose2(2.0, 0.0, 0.0) -# For simplicity, we will use the same noise model for each odometry factor -# Create odometry (Between) factors between consecutive poses -graph.add(gtsam.BetweenFactorPose2(1, 2, odometry, ODOMETRY_NOISE)) -graph.add(gtsam.BetweenFactorPose2(2, 3, odometry, ODOMETRY_NOISE)) -print("\nFactor Graph:\n{}".format(graph)) + # Add a prior on the first pose, setting it to the origin + # A prior factor consists of a mean and a noise model (covariance matrix) + priorMean = gtsam.Pose2(0.0, 0.0, 0.0) # prior at origin + graph.add(gtsam.PriorFactorPose2(1, priorMean, PRIOR_NOISE)) -# Create the data structure to hold the initialEstimate estimate to the solution -# For illustrative purposes, these have been deliberately set to incorrect values -initial = gtsam.Values() -initial.insert(1, gtsam.Pose2(0.5, 0.0, 0.2)) -initial.insert(2, gtsam.Pose2(2.3, 0.1, -0.2)) -initial.insert(3, gtsam.Pose2(4.1, 0.1, 0.1)) -print("\nInitial Estimate:\n{}".format(initial)) + # Add odometry factors + odometry = gtsam.Pose2(2.0, 0.0, 0.0) + # For simplicity, we will use the same noise model for each odometry factor + # Create odometry (Between) factors between consecutive poses + graph.add(gtsam.BetweenFactorPose2(1, 2, odometry, ODOMETRY_NOISE)) + graph.add(gtsam.BetweenFactorPose2(2, 3, odometry, ODOMETRY_NOISE)) + print("\nFactor Graph:\n{}".format(graph)) -# optimize using Levenberg-Marquardt optimization -params = gtsam.LevenbergMarquardtParams() -optimizer = gtsam.LevenbergMarquardtOptimizer(graph, initial, params) -result = optimizer.optimize() -print("\nFinal Result:\n{}".format(result)) + # Create the data structure to hold the initialEstimate estimate to the solution + # For illustrative purposes, these have been deliberately set to incorrect values + initial = gtsam.Values() + initial.insert(1, gtsam.Pose2(0.5, 0.0, 0.2)) + initial.insert(2, gtsam.Pose2(2.3, 0.1, -0.2)) + initial.insert(3, gtsam.Pose2(4.1, 0.1, 0.1)) + print("\nInitial Estimate:\n{}".format(initial)) -# 5. Calculate and print marginal covariances for all variables -marginals = gtsam.Marginals(graph, result) -for i in range(1, 4): - print("X{} covariance:\n{}\n".format(i, marginals.marginalCovariance(i))) - -fig = plt.figure(0) -for i in range(1, 4): - gtsam_plot.plot_pose2(0, result.atPose2(i), 0.5, marginals.marginalCovariance(i)) -plt.axis('equal') -plt.show() + # optimize using Levenberg-Marquardt optimization + params = gtsam.LevenbergMarquardtParams() + optimizer = gtsam.LevenbergMarquardtOptimizer(graph, initial, params) + result = optimizer.optimize() + print("\nFinal Result:\n{}".format(result)) + # 5. Calculate and print marginal covariances for all variables + marginals = gtsam.Marginals(graph, result) + for i in range(1, 4): + print("X{} covariance:\n{}\n".format(i, + marginals.marginalCovariance(i))) + + for i in range(1, 4): + gtsam_plot.plot_pose2(0, result.atPose2(i), 0.5, + marginals.marginalCovariance(i)) + plt.axis('equal') + plt.show() +if __name__ == "__main__": + main() diff --git a/python/gtsam/examples/PlanarSLAMExample.py b/python/gtsam/examples/PlanarSLAMExample.py index 5ffdf048d..d2ee92c95 100644 --- a/python/gtsam/examples/PlanarSLAMExample.py +++ b/python/gtsam/examples/PlanarSLAMExample.py @@ -13,69 +13,85 @@ Author: Alex Cunningham (C++), Kevin Deng & Frank Dellaert (Python) from __future__ import print_function -import numpy as np - import gtsam -from gtsam.symbol_shorthand import X, L +import numpy as np +from gtsam.symbol_shorthand import L, X # Create noise models PRIOR_NOISE = gtsam.noiseModel.Diagonal.Sigmas(np.array([0.3, 0.3, 0.1])) ODOMETRY_NOISE = gtsam.noiseModel.Diagonal.Sigmas(np.array([0.2, 0.2, 0.1])) MEASUREMENT_NOISE = gtsam.noiseModel.Diagonal.Sigmas(np.array([0.1, 0.2])) -# Create an empty nonlinear factor graph -graph = gtsam.NonlinearFactorGraph() -# Create the keys corresponding to unknown variables in the factor graph -X1 = X(1) -X2 = X(2) -X3 = X(3) -L1 = L(4) -L2 = L(5) +def main(): + """Main runner""" -# Add a prior on pose X1 at the origin. A prior factor consists of a mean and a noise model -graph.add(gtsam.PriorFactorPose2(X1, gtsam.Pose2(0.0, 0.0, 0.0), PRIOR_NOISE)) + # Create an empty nonlinear factor graph + graph = gtsam.NonlinearFactorGraph() -# Add odometry factors between X1,X2 and X2,X3, respectively -graph.add(gtsam.BetweenFactorPose2( - X1, X2, gtsam.Pose2(2.0, 0.0, 0.0), ODOMETRY_NOISE)) -graph.add(gtsam.BetweenFactorPose2( - X2, X3, gtsam.Pose2(2.0, 0.0, 0.0), ODOMETRY_NOISE)) + # Create the keys corresponding to unknown variables in the factor graph + X1 = X(1) + X2 = X(2) + X3 = X(3) + L1 = L(4) + L2 = L(5) -# Add Range-Bearing measurements to two different landmarks L1 and L2 -graph.add(gtsam.BearingRangeFactor2D( - X1, L1, gtsam.Rot2.fromDegrees(45), np.sqrt(4.0+4.0), MEASUREMENT_NOISE)) -graph.add(gtsam.BearingRangeFactor2D( - X2, L1, gtsam.Rot2.fromDegrees(90), 2.0, MEASUREMENT_NOISE)) -graph.add(gtsam.BearingRangeFactor2D( - X3, L2, gtsam.Rot2.fromDegrees(90), 2.0, MEASUREMENT_NOISE)) + # Add a prior on pose X1 at the origin. A prior factor consists of a mean and a noise model + graph.add( + gtsam.PriorFactorPose2(X1, gtsam.Pose2(0.0, 0.0, 0.0), PRIOR_NOISE)) -# Print graph -print("Factor Graph:\n{}".format(graph)) + # Add odometry factors between X1,X2 and X2,X3, respectively + graph.add( + gtsam.BetweenFactorPose2(X1, X2, gtsam.Pose2(2.0, 0.0, 0.0), + ODOMETRY_NOISE)) + graph.add( + gtsam.BetweenFactorPose2(X2, X3, gtsam.Pose2(2.0, 0.0, 0.0), + ODOMETRY_NOISE)) -# Create (deliberately inaccurate) initial estimate -initial_estimate = gtsam.Values() -initial_estimate.insert(X1, gtsam.Pose2(-0.25, 0.20, 0.15)) -initial_estimate.insert(X2, gtsam.Pose2(2.30, 0.10, -0.20)) -initial_estimate.insert(X3, gtsam.Pose2(4.10, 0.10, 0.10)) -initial_estimate.insert(L1, gtsam.Point2(1.80, 2.10)) -initial_estimate.insert(L2, gtsam.Point2(4.10, 1.80)) + # Add Range-Bearing measurements to two different landmarks L1 and L2 + graph.add( + gtsam.BearingRangeFactor2D(X1, L1, gtsam.Rot2.fromDegrees(45), + np.sqrt(4.0 + 4.0), MEASUREMENT_NOISE)) + graph.add( + gtsam.BearingRangeFactor2D(X2, L1, gtsam.Rot2.fromDegrees(90), 2.0, + MEASUREMENT_NOISE)) + graph.add( + gtsam.BearingRangeFactor2D(X3, L2, gtsam.Rot2.fromDegrees(90), 2.0, + MEASUREMENT_NOISE)) -# Print -print("Initial Estimate:\n{}".format(initial_estimate)) + # Print graph + print("Factor Graph:\n{}".format(graph)) -# Optimize using Levenberg-Marquardt optimization. The optimizer -# accepts an optional set of configuration parameters, controlling -# things like convergence criteria, the type of linear system solver -# to use, and the amount of information displayed during optimization. -# Here we will use the default set of parameters. See the -# documentation for the full set of parameters. -params = gtsam.LevenbergMarquardtParams() -optimizer = gtsam.LevenbergMarquardtOptimizer(graph, initial_estimate, params) -result = optimizer.optimize() -print("\nFinal Result:\n{}".format(result)) + # Create (deliberately inaccurate) initial estimate + initial_estimate = gtsam.Values() + initial_estimate.insert(X1, gtsam.Pose2(-0.25, 0.20, 0.15)) + initial_estimate.insert(X2, gtsam.Pose2(2.30, 0.10, -0.20)) + initial_estimate.insert(X3, gtsam.Pose2(4.10, 0.10, 0.10)) + initial_estimate.insert(L1, gtsam.Point2(1.80, 2.10)) + initial_estimate.insert(L2, gtsam.Point2(4.10, 1.80)) -# Calculate and print marginal covariances for all variables -marginals = gtsam.Marginals(graph, result) -for (key, str) in [(X1, "X1"), (X2, "X2"), (X3, "X3"), (L1, "L1"), (L2, "L2")]: - print("{} covariance:\n{}\n".format(str, marginals.marginalCovariance(key))) + # Print + print("Initial Estimate:\n{}".format(initial_estimate)) + + # Optimize using Levenberg-Marquardt optimization. The optimizer + # accepts an optional set of configuration parameters, controlling + # things like convergence criteria, the type of linear system solver + # to use, and the amount of information displayed during optimization. + # Here we will use the default set of parameters. See the + # documentation for the full set of parameters. + params = gtsam.LevenbergMarquardtParams() + optimizer = gtsam.LevenbergMarquardtOptimizer(graph, initial_estimate, + params) + result = optimizer.optimize() + print("\nFinal Result:\n{}".format(result)) + + # Calculate and print marginal covariances for all variables + marginals = gtsam.Marginals(graph, result) + for (key, s) in [(X1, "X1"), (X2, "X2"), (X3, "X3"), (L1, "L1"), + (L2, "L2")]: + print("{} covariance:\n{}\n".format(s, + marginals.marginalCovariance(key))) + + +if __name__ == "__main__": + main() diff --git a/python/gtsam/examples/Pose2ISAM2Example.py b/python/gtsam/examples/Pose2ISAM2Example.py new file mode 100644 index 000000000..3a8de0317 --- /dev/null +++ b/python/gtsam/examples/Pose2ISAM2Example.py @@ -0,0 +1,178 @@ +""" +GTSAM Copyright 2010-2018, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved +Authors: Frank Dellaert, et al. (see THANKS for the full author list) + +See LICENSE for the license information + +Pose SLAM example using iSAM2 in the 2D plane. +Author: Jerred Chen, Yusuf Ali +Modeled after: + - VisualISAM2Example by: Duy-Nguyen Ta (C++), Frank Dellaert (Python) + - Pose2SLAMExample by: Alex Cunningham (C++), Kevin Deng & Frank Dellaert (Python) +""" + +import math + +import matplotlib.pyplot as plt +import numpy as np + +import gtsam +import gtsam.utils.plot as gtsam_plot + +def report_on_progress(graph: gtsam.NonlinearFactorGraph, current_estimate: gtsam.Values, + key: int): + """Print and plot incremental progress of the robot for 2D Pose SLAM using iSAM2.""" + + # Print the current estimates computed using iSAM2. + print("*"*50 + f"\nInference after State {key+1}:\n") + print(current_estimate) + + # Compute the marginals for all states in the graph. + marginals = gtsam.Marginals(graph, current_estimate) + + # Plot the newly updated iSAM2 inference. + fig = plt.figure(0) + axes = fig.gca() + plt.cla() + + i = 1 + while current_estimate.exists(i): + gtsam_plot.plot_pose2(0, current_estimate.atPose2(i), 0.5, marginals.marginalCovariance(i)) + i += 1 + + plt.axis('equal') + axes.set_xlim(-1, 5) + axes.set_ylim(-1, 3) + plt.pause(1) + +def determine_loop_closure(odom: np.ndarray, current_estimate: gtsam.Values, + key: int, xy_tol=0.6, theta_tol=17) -> int: + """Simple brute force approach which iterates through previous states + and checks for loop closure. + + Args: + odom: Vector representing noisy odometry (x, y, theta) measurement in the body frame. + current_estimate: The current estimates computed by iSAM2. + key: Key corresponding to the current state estimate of the robot. + xy_tol: Optional argument for the x-y measurement tolerance, in meters. + theta_tol: Optional argument for the theta measurement tolerance, in degrees. + Returns: + k: The key of the state which is helping add the loop closure constraint. + If loop closure is not found, then None is returned. + """ + if current_estimate: + prev_est = current_estimate.atPose2(key+1) + rotated_odom = prev_est.rotation().matrix() @ odom[:2] + curr_xy = np.array([prev_est.x() + rotated_odom[0], + prev_est.y() + rotated_odom[1]]) + curr_theta = prev_est.theta() + odom[2] + for k in range(1, key+1): + pose_xy = np.array([current_estimate.atPose2(k).x(), + current_estimate.atPose2(k).y()]) + pose_theta = current_estimate.atPose2(k).theta() + if (abs(pose_xy - curr_xy) <= xy_tol).all() and \ + (abs(pose_theta - curr_theta) <= theta_tol*np.pi/180): + return k + +def Pose2SLAM_ISAM2_example(): + """Perform 2D SLAM given the ground truth changes in pose as well as + simple loop closure detection.""" + plt.ion() + + # Declare the 2D translational standard deviations of the prior factor's Gaussian model, in meters. + prior_xy_sigma = 0.3 + + # Declare the 2D rotational standard deviation of the prior factor's Gaussian model, in degrees. + prior_theta_sigma = 5 + + # Declare the 2D translational standard deviations of the odometry factor's Gaussian model, in meters. + odometry_xy_sigma = 0.2 + + # Declare the 2D rotational standard deviation of the odometry factor's Gaussian model, in degrees. + odometry_theta_sigma = 5 + + # Although this example only uses linear measurements and Gaussian noise models, it is important + # to note that iSAM2 can be utilized to its full potential during nonlinear optimization. This example + # simply showcases how iSAM2 may be applied to a Pose2 SLAM problem. + PRIOR_NOISE = gtsam.noiseModel.Diagonal.Sigmas(np.array([prior_xy_sigma, + prior_xy_sigma, + prior_theta_sigma*np.pi/180])) + ODOMETRY_NOISE = gtsam.noiseModel.Diagonal.Sigmas(np.array([odometry_xy_sigma, + odometry_xy_sigma, + odometry_theta_sigma*np.pi/180])) + + # Create a Nonlinear factor graph as well as the data structure to hold state estimates. + graph = gtsam.NonlinearFactorGraph() + initial_estimate = gtsam.Values() + + # Create iSAM2 parameters which can adjust the threshold necessary to force relinearization and how many + # update calls are required to perform the relinearization. + parameters = gtsam.ISAM2Params() + parameters.setRelinearizeThreshold(0.1) + parameters.relinearizeSkip = 1 + isam = gtsam.ISAM2(parameters) + + # Create the ground truth odometry measurements of the robot during the trajectory. + true_odometry = [(2, 0, 0), + (2, 0, math.pi/2), + (2, 0, math.pi/2), + (2, 0, math.pi/2), + (2, 0, math.pi/2)] + + # Corrupt the odometry measurements with gaussian noise to create noisy odometry measurements. + odometry_measurements = [np.random.multivariate_normal(true_odom, ODOMETRY_NOISE.covariance()) + for true_odom in true_odometry] + + # Add the prior factor to the factor graph, and poorly initialize the prior pose to demonstrate + # iSAM2 incremental optimization. + graph.push_back(gtsam.PriorFactorPose2(1, gtsam.Pose2(0, 0, 0), PRIOR_NOISE)) + initial_estimate.insert(1, gtsam.Pose2(0.5, 0.0, 0.2)) + + # Initialize the current estimate which is used during the incremental inference loop. + current_estimate = initial_estimate + + for i in range(len(true_odometry)): + + # Obtain the noisy odometry that is received by the robot and corrupted by gaussian noise. + noisy_odom_x, noisy_odom_y, noisy_odom_theta = odometry_measurements[i] + + # Determine if there is loop closure based on the odometry measurement and the previous estimate of the state. + loop = determine_loop_closure(odometry_measurements[i], current_estimate, i, xy_tol=0.8, theta_tol=25) + + # Add a binary factor in between two existing states if loop closure is detected. + # Otherwise, add a binary factor between a newly observed state and the previous state. + if loop: + graph.push_back(gtsam.BetweenFactorPose2(i + 1, loop, + gtsam.Pose2(noisy_odom_x, noisy_odom_y, noisy_odom_theta), ODOMETRY_NOISE)) + else: + graph.push_back(gtsam.BetweenFactorPose2(i + 1, i + 2, + gtsam.Pose2(noisy_odom_x, noisy_odom_y, noisy_odom_theta), ODOMETRY_NOISE)) + + # Compute and insert the initialization estimate for the current pose using the noisy odometry measurement. + computed_estimate = current_estimate.atPose2(i + 1).compose(gtsam.Pose2(noisy_odom_x, + noisy_odom_y, + noisy_odom_theta)) + initial_estimate.insert(i + 2, computed_estimate) + + # Perform incremental update to iSAM2's internal Bayes tree, optimizing only the affected variables. + isam.update(graph, initial_estimate) + current_estimate = isam.calculateEstimate() + + # Report all current state estimates from the iSAM2 optimzation. + report_on_progress(graph, current_estimate, i) + initial_estimate.clear() + + # Print the final covariance matrix for each pose after completing inference on the trajectory. + marginals = gtsam.Marginals(graph, current_estimate) + i = 1 + for i in range(1, len(true_odometry)+1): + print(f"X{i} covariance:\n{marginals.marginalCovariance(i)}\n") + + plt.ioff() + plt.show() + + +if __name__ == "__main__": + Pose2SLAM_ISAM2_example() diff --git a/python/gtsam/examples/Pose2SLAMExample.py b/python/gtsam/examples/Pose2SLAMExample.py index 2ec190d73..300a70fbd 100644 --- a/python/gtsam/examples/Pose2SLAMExample.py +++ b/python/gtsam/examples/Pose2SLAMExample.py @@ -15,82 +15,88 @@ from __future__ import print_function import math -import numpy as np - import gtsam - -import matplotlib.pyplot as plt import gtsam.utils.plot as gtsam_plot +import matplotlib.pyplot as plt -def vector3(x, y, z): - """Create 3d double numpy array.""" - return np.array([x, y, z], dtype=float) +def main(): + """Main runner.""" + # Create noise models + PRIOR_NOISE = gtsam.noiseModel.Diagonal.Sigmas(gtsam.Point3(0.3, 0.3, 0.1)) + ODOMETRY_NOISE = gtsam.noiseModel.Diagonal.Sigmas( + gtsam.Point3(0.2, 0.2, 0.1)) -# Create noise models -PRIOR_NOISE = gtsam.noiseModel.Diagonal.Sigmas(vector3(0.3, 0.3, 0.1)) -ODOMETRY_NOISE = gtsam.noiseModel.Diagonal.Sigmas(vector3(0.2, 0.2, 0.1)) + # 1. Create a factor graph container and add factors to it + graph = gtsam.NonlinearFactorGraph() -# 1. Create a factor graph container and add factors to it -graph = gtsam.NonlinearFactorGraph() + # 2a. Add a prior on the first pose, setting it to the origin + # A prior factor consists of a mean and a noise ODOMETRY_NOISE (covariance matrix) + graph.add(gtsam.PriorFactorPose2(1, gtsam.Pose2(0, 0, 0), PRIOR_NOISE)) -# 2a. Add a prior on the first pose, setting it to the origin -# A prior factor consists of a mean and a noise ODOMETRY_NOISE (covariance matrix) -graph.add(gtsam.PriorFactorPose2(1, gtsam.Pose2(0, 0, 0), PRIOR_NOISE)) + # 2b. Add odometry factors + # Create odometry (Between) factors between consecutive poses + graph.add( + gtsam.BetweenFactorPose2(1, 2, gtsam.Pose2(2, 0, 0), ODOMETRY_NOISE)) + graph.add( + gtsam.BetweenFactorPose2(2, 3, gtsam.Pose2(2, 0, math.pi / 2), + ODOMETRY_NOISE)) + graph.add( + gtsam.BetweenFactorPose2(3, 4, gtsam.Pose2(2, 0, math.pi / 2), + ODOMETRY_NOISE)) + graph.add( + gtsam.BetweenFactorPose2(4, 5, gtsam.Pose2(2, 0, math.pi / 2), + ODOMETRY_NOISE)) -# 2b. Add odometry factors -# Create odometry (Between) factors between consecutive poses -graph.add(gtsam.BetweenFactorPose2(1, 2, gtsam.Pose2(2, 0, 0), ODOMETRY_NOISE)) -graph.add(gtsam.BetweenFactorPose2( - 2, 3, gtsam.Pose2(2, 0, math.pi / 2), ODOMETRY_NOISE)) -graph.add(gtsam.BetweenFactorPose2( - 3, 4, gtsam.Pose2(2, 0, math.pi / 2), ODOMETRY_NOISE)) -graph.add(gtsam.BetweenFactorPose2( - 4, 5, gtsam.Pose2(2, 0, math.pi / 2), ODOMETRY_NOISE)) + # 2c. Add the loop closure constraint + # This factor encodes the fact that we have returned to the same pose. In real + # systems, these constraints may be identified in many ways, such as appearance-based + # techniques with camera images. We will use another Between Factor to enforce this constraint: + graph.add( + gtsam.BetweenFactorPose2(5, 2, gtsam.Pose2(2, 0, math.pi / 2), + ODOMETRY_NOISE)) + print("\nFactor Graph:\n{}".format(graph)) # print -# 2c. Add the loop closure constraint -# This factor encodes the fact that we have returned to the same pose. In real -# systems, these constraints may be identified in many ways, such as appearance-based -# techniques with camera images. We will use another Between Factor to enforce this constraint: -graph.add(gtsam.BetweenFactorPose2( - 5, 2, gtsam.Pose2(2, 0, math.pi / 2), ODOMETRY_NOISE)) -print("\nFactor Graph:\n{}".format(graph)) # print + # 3. Create the data structure to hold the initial_estimate estimate to the + # solution. For illustrative purposes, these have been deliberately set to incorrect values + initial_estimate = gtsam.Values() + initial_estimate.insert(1, gtsam.Pose2(0.5, 0.0, 0.2)) + initial_estimate.insert(2, gtsam.Pose2(2.3, 0.1, -0.2)) + initial_estimate.insert(3, gtsam.Pose2(4.1, 0.1, math.pi / 2)) + initial_estimate.insert(4, gtsam.Pose2(4.0, 2.0, math.pi)) + initial_estimate.insert(5, gtsam.Pose2(2.1, 2.1, -math.pi / 2)) + print("\nInitial Estimate:\n{}".format(initial_estimate)) # print -# 3. Create the data structure to hold the initial_estimate estimate to the -# solution. For illustrative purposes, these have been deliberately set to incorrect values -initial_estimate = gtsam.Values() -initial_estimate.insert(1, gtsam.Pose2(0.5, 0.0, 0.2)) -initial_estimate.insert(2, gtsam.Pose2(2.3, 0.1, -0.2)) -initial_estimate.insert(3, gtsam.Pose2(4.1, 0.1, math.pi / 2)) -initial_estimate.insert(4, gtsam.Pose2(4.0, 2.0, math.pi)) -initial_estimate.insert(5, gtsam.Pose2(2.1, 2.1, -math.pi / 2)) -print("\nInitial Estimate:\n{}".format(initial_estimate)) # print + # 4. Optimize the initial values using a Gauss-Newton nonlinear optimizer + # The optimizer accepts an optional set of configuration parameters, + # controlling things like convergence criteria, the type of linear + # system solver to use, and the amount of information displayed during + # optimization. We will set a few parameters as a demonstration. + parameters = gtsam.GaussNewtonParams() -# 4. Optimize the initial values using a Gauss-Newton nonlinear optimizer -# The optimizer accepts an optional set of configuration parameters, -# controlling things like convergence criteria, the type of linear -# system solver to use, and the amount of information displayed during -# optimization. We will set a few parameters as a demonstration. -parameters = gtsam.GaussNewtonParams() + # Stop iterating once the change in error between steps is less than this value + parameters.setRelativeErrorTol(1e-5) + # Do not perform more than N iteration steps + parameters.setMaxIterations(100) + # Create the optimizer ... + optimizer = gtsam.GaussNewtonOptimizer(graph, initial_estimate, parameters) + # ... and optimize + result = optimizer.optimize() + print("Final Result:\n{}".format(result)) -# Stop iterating once the change in error between steps is less than this value -parameters.setRelativeErrorTol(1e-5) -# Do not perform more than N iteration steps -parameters.setMaxIterations(100) -# Create the optimizer ... -optimizer = gtsam.GaussNewtonOptimizer(graph, initial_estimate, parameters) -# ... and optimize -result = optimizer.optimize() -print("Final Result:\n{}".format(result)) + # 5. Calculate and print marginal covariances for all variables + marginals = gtsam.Marginals(graph, result) + for i in range(1, 6): + print("X{} covariance:\n{}\n".format(i, + marginals.marginalCovariance(i))) -# 5. Calculate and print marginal covariances for all variables -marginals = gtsam.Marginals(graph, result) -for i in range(1, 6): - print("X{} covariance:\n{}\n".format(i, marginals.marginalCovariance(i))) + for i in range(1, 6): + gtsam_plot.plot_pose2(0, result.atPose2(i), 0.5, + marginals.marginalCovariance(i)) -fig = plt.figure(0) -for i in range(1, 6): - gtsam_plot.plot_pose2(0, result.atPose2(i), 0.5, marginals.marginalCovariance(i)) + plt.axis('equal') + plt.show() -plt.axis('equal') -plt.show() + +if __name__ == "__main__": + main() diff --git a/python/gtsam/examples/Pose2SLAMExample_g2o.py b/python/gtsam/examples/Pose2SLAMExample_g2o.py index 97fb46c5f..cf029c049 100644 --- a/python/gtsam/examples/Pose2SLAMExample_g2o.py +++ b/python/gtsam/examples/Pose2SLAMExample_g2o.py @@ -12,77 +12,86 @@ and does the optimization. Output is written on a file, in g2o format # pylint: disable=invalid-name, E1101 from __future__ import print_function + import argparse -import math -import numpy as np -import matplotlib.pyplot as plt import gtsam +import matplotlib.pyplot as plt from gtsam.utils import plot -def vector3(x, y, z): - """Create 3d double numpy array.""" - return np.array([x, y, z], dtype=float) +def main(): + """Main runner.""" + + parser = argparse.ArgumentParser( + description="A 2D Pose SLAM example that reads input from g2o, " + "converts it to a factor graph and does the optimization. " + "Output is written on a file, in g2o format") + parser.add_argument('-i', '--input', help='input file g2o format') + parser.add_argument( + '-o', + '--output', + help="the path to the output file with optimized graph") + parser.add_argument('-m', + '--maxiter', + type=int, + help="maximum number of iterations for optimizer") + parser.add_argument('-k', + '--kernel', + choices=['none', 'huber', 'tukey'], + default="none", + help="Type of kernel used") + parser.add_argument("-p", + "--plot", + action="store_true", + help="Flag to plot results") + args = parser.parse_args() + + g2oFile = gtsam.findExampleDataFile("noisyToyGraph.txt") if args.input is None\ + else args.input + + maxIterations = 100 if args.maxiter is None else args.maxiter + + is3D = False + + graph, initial = gtsam.readG2o(g2oFile, is3D) + + assert args.kernel == "none", "Supplied kernel type is not yet implemented" + + # Add prior on the pose having index (key) = 0 + priorModel = gtsam.noiseModel.Diagonal.Variances(gtsam.Point3(1e-6, 1e-6, 1e-8)) + graph.add(gtsam.PriorFactorPose2(0, gtsam.Pose2(), priorModel)) + + params = gtsam.GaussNewtonParams() + params.setVerbosity("Termination") + params.setMaxIterations(maxIterations) + # parameters.setRelativeErrorTol(1e-5) + # Create the optimizer ... + optimizer = gtsam.GaussNewtonOptimizer(graph, initial, params) + # ... and optimize + result = optimizer.optimize() + + print("Optimization complete") + print("initial error = ", graph.error(initial)) + print("final error = ", graph.error(result)) + + if args.output is None: + print("\nFactor Graph:\n{}".format(graph)) + print("\nInitial Estimate:\n{}".format(initial)) + print("Final Result:\n{}".format(result)) + else: + outputFile = args.output + print("Writing results to file: ", outputFile) + graphNoKernel, _ = gtsam.readG2o(g2oFile, is3D) + gtsam.writeG2o(graphNoKernel, result, outputFile) + print("Done!") + + if args.plot: + resultPoses = gtsam.utilities.extractPose2(result) + for i in range(resultPoses.shape[0]): + plot.plot_pose2(1, gtsam.Pose2(resultPoses[i, :])) + plt.show() -parser = argparse.ArgumentParser( - description="A 2D Pose SLAM example that reads input from g2o, " - "converts it to a factor graph and does the optimization. " - "Output is written on a file, in g2o format") -parser.add_argument('-i', '--input', help='input file g2o format') -parser.add_argument('-o', '--output', - help="the path to the output file with optimized graph") -parser.add_argument('-m', '--maxiter', type=int, - help="maximum number of iterations for optimizer") -parser.add_argument('-k', '--kernel', choices=['none', 'huber', 'tukey'], - default="none", help="Type of kernel used") -parser.add_argument("-p", "--plot", action="store_true", - help="Flag to plot results") -args = parser.parse_args() - -g2oFile = gtsam.findExampleDataFile("noisyToyGraph.txt") if args.input is None\ - else args.input - -maxIterations = 100 if args.maxiter is None else args.maxiter - -is3D = False - -graph, initial = gtsam.readG2o(g2oFile, is3D) - -assert args.kernel == "none", "Supplied kernel type is not yet implemented" - -# Add prior on the pose having index (key) = 0 -priorModel = gtsam.noiseModel.Diagonal.Variances(vector3(1e-6, 1e-6, 1e-8)) -graph.add(gtsam.PriorFactorPose2(0, gtsam.Pose2(), priorModel)) - -params = gtsam.GaussNewtonParams() -params.setVerbosity("Termination") -params.setMaxIterations(maxIterations) -# parameters.setRelativeErrorTol(1e-5) -# Create the optimizer ... -optimizer = gtsam.GaussNewtonOptimizer(graph, initial, params) -# ... and optimize -result = optimizer.optimize() - -print("Optimization complete") -print("initial error = ", graph.error(initial)) -print("final error = ", graph.error(result)) - - -if args.output is None: - print("\nFactor Graph:\n{}".format(graph)) - print("\nInitial Estimate:\n{}".format(initial)) - print("Final Result:\n{}".format(result)) -else: - outputFile = args.output - print("Writing results to file: ", outputFile) - graphNoKernel, _ = gtsam.readG2o(g2oFile, is3D) - gtsam.writeG2o(graphNoKernel, result, outputFile) - print ("Done!") - -if args.plot: - resultPoses = gtsam.utilities.extractPose2(result) - for i in range(resultPoses.shape[0]): - plot.plot_pose2(1, gtsam.Pose2(resultPoses[i, :])) - plt.show() +if __name__ == "__main__": + main() diff --git a/python/gtsam/examples/Pose2SLAMExample_lago.py b/python/gtsam/examples/Pose2SLAMExample_lago.py new file mode 100644 index 000000000..d8cddde0b --- /dev/null +++ b/python/gtsam/examples/Pose2SLAMExample_lago.py @@ -0,0 +1,67 @@ +""" +GTSAM Copyright 2010, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved +Authors: Frank Dellaert, et al. (see THANKS for the full author list) +See LICENSE for the license information + +A 2D Pose SLAM example that reads input from g2o, and solve the Pose2 problem +using LAGO (Linear Approximation for Graph Optimization). +Output is written to a file, in g2o format + +Reference: +L. Carlone, R. Aragues, J. Castellanos, and B. Bona, A fast and accurate +approximation for planar pose graph optimization, IJRR, 2014. + +L. Carlone, R. Aragues, J.A. Castellanos, and B. Bona, A linear approximation +for graph-based simultaneous localization and mapping, RSS, 2011. + +Author: Luca Carlone (C++), John Lambert (Python) +""" + +import argparse +from argparse import Namespace + +import numpy as np + +import gtsam +from gtsam import Point3, Pose2, PriorFactorPose2, Values + + +def run(args: Namespace) -> None: + """Run LAGO on input data stored in g2o file.""" + g2oFile = gtsam.findExampleDataFile("noisyToyGraph.txt") if args.input is None else args.input + + graph = gtsam.NonlinearFactorGraph() + graph, initial = gtsam.readG2o(g2oFile) + + # Add prior on the pose having index (key) = 0 + priorModel = gtsam.noiseModel.Diagonal.Variances(Point3(1e-6, 1e-6, 1e-8)) + graph.add(PriorFactorPose2(0, Pose2(), priorModel)) + print(graph) + + print("Computing LAGO estimate") + estimateLago: Values = gtsam.lago.initialize(graph) + print("done!") + + if args.output is None: + estimateLago.print("estimateLago") + else: + outputFile = args.output + print("Writing results to file: ", outputFile) + graphNoKernel = gtsam.NonlinearFactorGraph() + graphNoKernel, initial2 = gtsam.readG2o(g2oFile) + gtsam.writeG2o(graphNoKernel, estimateLago, outputFile) + print("Done! ") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="A 2D Pose SLAM example that reads input from g2o, " + "converts it to a factor graph and does the optimization. " + "Output is written on a file, in g2o format" + ) + parser.add_argument("-i", "--input", help="input file g2o format") + parser.add_argument("-o", "--output", help="the path to the output file with optimized graph") + args = parser.parse_args() + run(args) diff --git a/python/gtsam/examples/Pose3ISAM2Example.py b/python/gtsam/examples/Pose3ISAM2Example.py new file mode 100644 index 000000000..cb71813c5 --- /dev/null +++ b/python/gtsam/examples/Pose3ISAM2Example.py @@ -0,0 +1,208 @@ +""" +GTSAM Copyright 2010-2018, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved +Authors: Frank Dellaert, et al. (see THANKS for the full author list) + +See LICENSE for the license information + +Pose SLAM example using iSAM2 in 3D space. +Author: Jerred Chen +Modeled after: + - VisualISAM2Example by: Duy-Nguyen Ta (C++), Frank Dellaert (Python) + - Pose2SLAMExample by: Alex Cunningham (C++), Kevin Deng & Frank Dellaert (Python) +""" + +from typing import List + +import matplotlib.pyplot as plt +import numpy as np + +import gtsam +import gtsam.utils.plot as gtsam_plot + +def report_on_progress(graph: gtsam.NonlinearFactorGraph, current_estimate: gtsam.Values, + key: int): + """Print and plot incremental progress of the robot for 2D Pose SLAM using iSAM2.""" + + # Print the current estimates computed using iSAM2. + print("*"*50 + f"\nInference after State {key+1}:\n") + print(current_estimate) + + # Compute the marginals for all states in the graph. + marginals = gtsam.Marginals(graph, current_estimate) + + # Plot the newly updated iSAM2 inference. + fig = plt.figure(0) + axes = fig.gca(projection='3d') + plt.cla() + + i = 1 + while current_estimate.exists(i): + gtsam_plot.plot_pose3(0, current_estimate.atPose3(i), 10, + marginals.marginalCovariance(i)) + i += 1 + + axes.set_xlim3d(-30, 45) + axes.set_ylim3d(-30, 45) + axes.set_zlim3d(-30, 45) + plt.pause(1) + +def create_poses() -> List[gtsam.Pose3]: + """Creates ground truth poses of the robot.""" + P0 = np.array([[1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1]]) + P1 = np.array([[0, -1, 0, 15], + [1, 0, 0, 15], + [0, 0, 1, 20], + [0, 0, 0, 1]]) + P2 = np.array([[np.cos(np.pi/4), 0, np.sin(np.pi/4), 30], + [0, 1, 0, 30], + [-np.sin(np.pi/4), 0, np.cos(np.pi/4), 30], + [0, 0, 0, 1]]) + P3 = np.array([[0, 1, 0, 30], + [0, 0, -1, 0], + [-1, 0, 0, -15], + [0, 0, 0, 1]]) + P4 = np.array([[-1, 0, 0, 0], + [0, -1, 0, -10], + [0, 0, 1, -10], + [0, 0, 0, 1]]) + P5 = P0[:] + + return [gtsam.Pose3(P0), gtsam.Pose3(P1), gtsam.Pose3(P2), + gtsam.Pose3(P3), gtsam.Pose3(P4), gtsam.Pose3(P5)] + +def determine_loop_closure(odom_tf: gtsam.Pose3, current_estimate: gtsam.Values, + key: int, xyz_tol=0.6, rot_tol=17) -> int: + """Simple brute force approach which iterates through previous states + and checks for loop closure. + + Args: + odom_tf: The noisy odometry transformation measurement in the body frame. + current_estimate: The current estimates computed by iSAM2. + key: Key corresponding to the current state estimate of the robot. + xyz_tol: Optional argument for the translational tolerance, in meters. + rot_tol: Optional argument for the rotational tolerance, in degrees. + Returns: + k: The key of the state which is helping add the loop closure constraint. + If loop closure is not found, then None is returned. + """ + if current_estimate: + prev_est = current_estimate.atPose3(key+1) + curr_est = prev_est.compose(odom_tf) + for k in range(1, key+1): + pose = current_estimate.atPose3(k) + if (abs(pose.matrix()[:3,:3] - curr_est.matrix()[:3,:3]) <= rot_tol*np.pi/180).all() and \ + (abs(pose.matrix()[:3,3] - curr_est.matrix()[:3,3]) <= xyz_tol).all(): + return k + +def Pose3_ISAM2_example(): + """Perform 3D SLAM given ground truth poses as well as simple + loop closure detection.""" + plt.ion() + + # Declare the 3D translational standard deviations of the prior factor's Gaussian model, in meters. + prior_xyz_sigma = 0.3 + + # Declare the 3D rotational standard deviations of the prior factor's Gaussian model, in degrees. + prior_rpy_sigma = 5 + + # Declare the 3D translational standard deviations of the odometry factor's Gaussian model, in meters. + odometry_xyz_sigma = 0.2 + + # Declare the 3D rotational standard deviations of the odometry factor's Gaussian model, in degrees. + odometry_rpy_sigma = 5 + + # Although this example only uses linear measurements and Gaussian noise models, it is important + # to note that iSAM2 can be utilized to its full potential during nonlinear optimization. This example + # simply showcases how iSAM2 may be applied to a Pose2 SLAM problem. + PRIOR_NOISE = gtsam.noiseModel.Diagonal.Sigmas(np.array([prior_rpy_sigma*np.pi/180, + prior_rpy_sigma*np.pi/180, + prior_rpy_sigma*np.pi/180, + prior_xyz_sigma, + prior_xyz_sigma, + prior_xyz_sigma])) + ODOMETRY_NOISE = gtsam.noiseModel.Diagonal.Sigmas(np.array([odometry_rpy_sigma*np.pi/180, + odometry_rpy_sigma*np.pi/180, + odometry_rpy_sigma*np.pi/180, + odometry_xyz_sigma, + odometry_xyz_sigma, + odometry_xyz_sigma])) + + # Create a Nonlinear factor graph as well as the data structure to hold state estimates. + graph = gtsam.NonlinearFactorGraph() + initial_estimate = gtsam.Values() + + # Create iSAM2 parameters which can adjust the threshold necessary to force relinearization and how many + # update calls are required to perform the relinearization. + parameters = gtsam.ISAM2Params() + parameters.setRelinearizeThreshold(0.1) + parameters.relinearizeSkip = 1 + isam = gtsam.ISAM2(parameters) + + # Create the ground truth poses of the robot trajectory. + true_poses = create_poses() + + # Create the ground truth odometry transformations, xyz translations, and roll-pitch-yaw rotations + # between each robot pose in the trajectory. + odometry_tf = [true_poses[i-1].transformPoseTo(true_poses[i]) for i in range(1, len(true_poses))] + odometry_xyz = [(odometry_tf[i].x(), odometry_tf[i].y(), odometry_tf[i].z()) for i in range(len(odometry_tf))] + odometry_rpy = [odometry_tf[i].rotation().rpy() for i in range(len(odometry_tf))] + + # Corrupt xyz translations and roll-pitch-yaw rotations with gaussian noise to create noisy odometry measurements. + noisy_measurements = [np.random.multivariate_normal(np.hstack((odometry_rpy[i],odometry_xyz[i])), \ + ODOMETRY_NOISE.covariance()) for i in range(len(odometry_tf))] + + # Add the prior factor to the factor graph, and poorly initialize the prior pose to demonstrate + # iSAM2 incremental optimization. + graph.push_back(gtsam.PriorFactorPose3(1, true_poses[0], PRIOR_NOISE)) + initial_estimate.insert(1, true_poses[0].compose(gtsam.Pose3( + gtsam.Rot3.Rodrigues(-0.1, 0.2, 0.25), gtsam.Point3(0.05, -0.10, 0.20)))) + + # Initialize the current estimate which is used during the incremental inference loop. + current_estimate = initial_estimate + for i in range(len(odometry_tf)): + + # Obtain the noisy translation and rotation that is received by the robot and corrupted by gaussian noise. + noisy_odometry = noisy_measurements[i] + + # Compute the noisy odometry transformation according to the xyz translation and roll-pitch-yaw rotation. + noisy_tf = gtsam.Pose3(gtsam.Rot3.RzRyRx(noisy_odometry[:3]), noisy_odometry[3:6].reshape(-1,1)) + + # Determine if there is loop closure based on the odometry measurement and the previous estimate of the state. + loop = determine_loop_closure(noisy_tf, current_estimate, i, xyz_tol=18, rot_tol=30) + + # Add a binary factor in between two existing states if loop closure is detected. + # Otherwise, add a binary factor between a newly observed state and the previous state. + if loop: + graph.push_back(gtsam.BetweenFactorPose3(i + 1, loop, noisy_tf, ODOMETRY_NOISE)) + else: + graph.push_back(gtsam.BetweenFactorPose3(i + 1, i + 2, noisy_tf, ODOMETRY_NOISE)) + + # Compute and insert the initialization estimate for the current pose using a noisy odometry measurement. + noisy_estimate = current_estimate.atPose3(i + 1).compose(noisy_tf) + initial_estimate.insert(i + 2, noisy_estimate) + + # Perform incremental update to iSAM2's internal Bayes tree, optimizing only the affected variables. + isam.update(graph, initial_estimate) + current_estimate = isam.calculateEstimate() + + # Report all current state estimates from the iSAM2 optimization. + report_on_progress(graph, current_estimate, i) + initial_estimate.clear() + + # Print the final covariance matrix for each pose after completing inference. + marginals = gtsam.Marginals(graph, current_estimate) + i = 1 + while current_estimate.exists(i): + print(f"X{i} covariance:\n{marginals.marginalCovariance(i)}\n") + i += 1 + + plt.ioff() + plt.show() + +if __name__ == '__main__': + Pose3_ISAM2_example() diff --git a/python/gtsam/examples/Pose3SLAMExample_g2o.py b/python/gtsam/examples/Pose3SLAMExample_g2o.py index 501a75dc1..dcdfc34a3 100644 --- a/python/gtsam/examples/Pose3SLAMExample_g2o.py +++ b/python/gtsam/examples/Pose3SLAMExample_g2o.py @@ -8,13 +8,14 @@ # pylint: disable=invalid-name, E1101 from __future__ import print_function + import argparse -import numpy as np -import matplotlib.pyplot as plt -from mpl_toolkits.mplot3d import Axes3D import gtsam +import matplotlib.pyplot as plt +import numpy as np from gtsam.utils import plot +from mpl_toolkits.mplot3d import Axes3D def vector6(x, y, z, a, b, c): @@ -22,50 +23,62 @@ def vector6(x, y, z, a, b, c): return np.array([x, y, z, a, b, c], dtype=float) -parser = argparse.ArgumentParser( - description="A 3D Pose SLAM example that reads input from g2o, and " - "initializes Pose3") -parser.add_argument('-i', '--input', help='input file g2o format') -parser.add_argument('-o', '--output', - help="the path to the output file with optimized graph") -parser.add_argument("-p", "--plot", action="store_true", - help="Flag to plot results") -args = parser.parse_args() +def main(): + """Main runner.""" -g2oFile = gtsam.findExampleDataFile("pose3example.txt") if args.input is None \ - else args.input + parser = argparse.ArgumentParser( + description="A 3D Pose SLAM example that reads input from g2o, and " + "initializes Pose3") + parser.add_argument('-i', '--input', help='input file g2o format') + parser.add_argument( + '-o', + '--output', + help="the path to the output file with optimized graph") + parser.add_argument("-p", + "--plot", + action="store_true", + help="Flag to plot results") + args = parser.parse_args() -is3D = True -graph, initial = gtsam.readG2o(g2oFile, is3D) + g2oFile = gtsam.findExampleDataFile("pose3example.txt") if args.input is None \ + else args.input -# Add Prior on the first key -priorModel = gtsam.noiseModel.Diagonal.Variances(vector6(1e-6, 1e-6, 1e-6, - 1e-4, 1e-4, 1e-4)) + is3D = True + graph, initial = gtsam.readG2o(g2oFile, is3D) -print("Adding prior to g2o file ") -firstKey = initial.keys()[0] -graph.add(gtsam.PriorFactorPose3(firstKey, gtsam.Pose3(), priorModel)) + # Add Prior on the first key + priorModel = gtsam.noiseModel.Diagonal.Variances( + vector6(1e-6, 1e-6, 1e-6, 1e-4, 1e-4, 1e-4)) -params = gtsam.GaussNewtonParams() -params.setVerbosity("Termination") # this will show info about stopping conds -optimizer = gtsam.GaussNewtonOptimizer(graph, initial, params) -result = optimizer.optimize() -print("Optimization complete") + print("Adding prior to g2o file ") + firstKey = initial.keys()[0] + graph.add(gtsam.PriorFactorPose3(firstKey, gtsam.Pose3(), priorModel)) -print("initial error = ", graph.error(initial)) -print("final error = ", graph.error(result)) + params = gtsam.GaussNewtonParams() + params.setVerbosity( + "Termination") # this will show info about stopping conds + optimizer = gtsam.GaussNewtonOptimizer(graph, initial, params) + result = optimizer.optimize() + print("Optimization complete") -if args.output is None: - print("Final Result:\n{}".format(result)) -else: - outputFile = args.output - print("Writing results to file: ", outputFile) - graphNoKernel, _ = gtsam.readG2o(g2oFile, is3D) - gtsam.writeG2o(graphNoKernel, result, outputFile) - print ("Done!") + print("initial error = ", graph.error(initial)) + print("final error = ", graph.error(result)) -if args.plot: - resultPoses = gtsam.utilities.allPose3s(result) - for i in range(resultPoses.size()): - plot.plot_pose3(1, resultPoses.atPose3(i)) - plt.show() + if args.output is None: + print("Final Result:\n{}".format(result)) + else: + outputFile = args.output + print("Writing results to file: ", outputFile) + graphNoKernel, _ = gtsam.readG2o(g2oFile, is3D) + gtsam.writeG2o(graphNoKernel, result, outputFile) + print("Done!") + + if args.plot: + resultPoses = gtsam.utilities.allPose3s(result) + for i in range(resultPoses.size()): + plot.plot_pose3(1, resultPoses.atPose3(i)) + plt.show() + + +if __name__ == "__main__": + main() diff --git a/python/gtsam/examples/Pose3SLAMExample_initializePose3Chordal.py b/python/gtsam/examples/Pose3SLAMExample_initializePose3Chordal.py index 2b2c5f991..a96da0774 100644 --- a/python/gtsam/examples/Pose3SLAMExample_initializePose3Chordal.py +++ b/python/gtsam/examples/Pose3SLAMExample_initializePose3Chordal.py @@ -13,23 +13,29 @@ Author: Luca Carlone, Frank Dellaert (python port) from __future__ import print_function +import gtsam import numpy as np -import gtsam -# Read graph from file -g2oFile = gtsam.findExampleDataFile("pose3example.txt") +def main(): + """Main runner.""" + # Read graph from file + g2oFile = gtsam.findExampleDataFile("pose3example.txt") -is3D = True -graph, initial = gtsam.readG2o(g2oFile, is3D) + is3D = True + graph, initial = gtsam.readG2o(g2oFile, is3D) -# Add prior on the first key. TODO: assumes first key ios z -priorModel = gtsam.noiseModel.Diagonal.Variances( - np.array([1e-6, 1e-6, 1e-6, 1e-4, 1e-4, 1e-4])) -firstKey = initial.keys()[0] -graph.add(gtsam.PriorFactorPose3(0, gtsam.Pose3(), priorModel)) + # Add prior on the first key. TODO: assumes first key ios z + priorModel = gtsam.noiseModel.Diagonal.Variances( + np.array([1e-6, 1e-6, 1e-6, 1e-4, 1e-4, 1e-4])) + firstKey = initial.keys()[0] + graph.add(gtsam.PriorFactorPose3(0, gtsam.Pose3(), priorModel)) -# Initializing Pose3 - chordal relaxation" -initialization = gtsam.InitializePose3.initialize(graph) + # Initializing Pose3 - chordal relaxation + initialization = gtsam.InitializePose3.initialize(graph) -print(initialization) + print(initialization) + + +if __name__ == "__main__": + main() diff --git a/python/gtsam/examples/PreintegrationExample.py b/python/gtsam/examples/PreintegrationExample.py index 458248c30..611c536c7 100644 --- a/python/gtsam/examples/PreintegrationExample.py +++ b/python/gtsam/examples/PreintegrationExample.py @@ -5,10 +5,14 @@ All Rights Reserved See LICENSE for the license information -A script validating the Preintegration of IMU measurements +A script validating the Preintegration of IMU measurements. + +Authors: Frank Dellaert, Varun Agrawal. """ -import math +# pylint: disable=invalid-name,unused-import,wrong-import-order + +from typing import Optional, Sequence import gtsam import matplotlib.pyplot as plt @@ -18,25 +22,28 @@ from mpl_toolkits.mplot3d import Axes3D IMU_FIG = 1 POSES_FIG = 2 +GRAVITY = 10 -class PreintegrationExample(object): - +class PreintegrationExample: + """Base class for all preintegration examples.""" @staticmethod - def defaultParams(g): + def defaultParams(g: float): """Create default parameters with Z *up* and realistic noise parameters""" params = gtsam.PreintegrationParams.MakeSharedU(g) - kGyroSigma = math.radians(0.5) / 60 # 0.5 degree ARW + kGyroSigma = np.radians(0.5) / 60 # 0.5 degree ARW kAccelSigma = 0.1 / 60 # 10 cm VRW - params.setGyroscopeCovariance( - kGyroSigma ** 2 * np.identity(3, float)) - params.setAccelerometerCovariance( - kAccelSigma ** 2 * np.identity(3, float)) - params.setIntegrationCovariance( - 0.0000001 ** 2 * np.identity(3, float)) + params.setGyroscopeCovariance(kGyroSigma**2 * np.identity(3, float)) + params.setAccelerometerCovariance(kAccelSigma**2 * + np.identity(3, float)) + params.setIntegrationCovariance(0.0000001**2 * np.identity(3, float)) return params - def __init__(self, twist=None, bias=None, dt=1e-2): + def __init__(self, + twist: Optional[np.ndarray] = None, + bias: Optional[gtsam.imuBias.ConstantBias] = None, + params: Optional[gtsam.PreintegrationParams] = None, + dt: float = 1e-2): """Initialize with given twist, a pair(angularVelocityVector, velocityVector).""" # setup interactive plotting @@ -48,7 +55,7 @@ class PreintegrationExample(object): else: # default = loop with forward velocity 2m/s, while pitching up # with angular velocity 30 degree/sec (negative in FLU) - W = np.array([0, -math.radians(30), 0]) + W = np.array([0, -np.radians(30), 0]) V = np.array([2, 0, 0]) self.scenario = gtsam.ConstantTwistScenario(W, V) @@ -58,9 +65,11 @@ class PreintegrationExample(object): self.labels = list('xyz') self.colors = list('rgb') - # Create runner - self.g = 10 # simple gravity constant - self.params = self.defaultParams(self.g) + if params: + self.params = params + else: + # Default params with simple gravity constant + self.params = self.defaultParams(g=GRAVITY) if bias is not None: self.actualBias = bias @@ -69,13 +78,22 @@ class PreintegrationExample(object): gyroBias = np.array([0, 0, 0]) self.actualBias = gtsam.imuBias.ConstantBias(accBias, gyroBias) - self.runner = gtsam.ScenarioRunner( - self.scenario, self.params, self.dt, self.actualBias) + # Create runner + self.runner = gtsam.ScenarioRunner(self.scenario, self.params, self.dt, + self.actualBias) fig, self.axes = plt.subplots(4, 3) fig.set_tight_layout(True) - def plotImu(self, t, measuredOmega, measuredAcc): + def plotImu(self, t: float, measuredOmega: Sequence, + measuredAcc: Sequence): + """ + Plot IMU measurements. + Args: + t: The time at which the measurement was recoreded. + measuredOmega: Measured angular velocity. + measuredAcc: Measured linear acceleration. + """ plt.figure(IMU_FIG) # plot angular velocity @@ -108,12 +126,21 @@ class PreintegrationExample(object): ax.scatter(t, measuredAcc[i], color=color, marker='.') ax.set_xlabel('specific force ' + label) - def plotGroundTruthPose(self, t, scale=0.3, time_interval=0.01): - # plot ground truth pose, as well as prediction from integrated IMU measurements + def plotGroundTruthPose(self, + t: float, + scale: float = 0.3, + time_interval: float = 0.01): + """ + Plot ground truth pose, as well as prediction from integrated IMU measurements. + Args: + t: Time at which the pose was obtained. + scale: The scaling factor for the pose axes. + time_interval: The time to wait before showing the plot. + """ actualPose = self.scenario.pose(t) plot_pose3(POSES_FIG, actualPose, scale) - t = actualPose.translation() - self.maxDim = max([max(np.abs(t)), self.maxDim]) + translation = actualPose.translation() + self.maxDim = max([max(np.abs(translation)), self.maxDim]) ax = plt.gca() ax.set_xlim3d(-self.maxDim, self.maxDim) ax.set_ylim3d(-self.maxDim, self.maxDim) @@ -121,8 +148,8 @@ class PreintegrationExample(object): plt.pause(time_interval) - def run(self, T=12): - # simulate the loop + def run(self, T: int = 12): + """Simulate the loop.""" for i, t in enumerate(np.arange(0, T, self.dt)): measuredOmega = self.runner.measuredAngularVelocity(t) measuredAcc = self.runner.measuredSpecificForce(t) diff --git a/python/gtsam/examples/RangeISAMExample_plaza2.ipynb b/python/gtsam/examples/RangeISAMExample_plaza2.ipynb new file mode 100644 index 000000000..f11636606 --- /dev/null +++ b/python/gtsam/examples/RangeISAMExample_plaza2.ipynb @@ -0,0 +1,9453 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "GTSAM Copyright 2010-2022, Georgia Tech Research Corporation,\n", + "Atlanta, Georgia 30332-0415\n", + "All Rights Reserved\n", + "\n", + "Authors: Frank Dellaert, et al. (see THANKS for the full author list)\n", + "\n", + "See LICENSE for the license information\n", + "\n", + "A 2D Range SLAM example, with iSAM and smart range factors\n", + "\n", + "Author: Frank Dellaert" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Data is second UWB ranging dataset, B2 or \"plaza 2\", from\n", + "\n", + "> \"Navigating with Ranging Radios: Five Data Sets with Ground Truth\", by Joseph Djugash, Bradley Hamner, and Stephan Roth, available at https://www.ri.cmu.edu/pub_files/2009/9/Final_5datasetsRangingRadios.pdf\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [], + "source": [ + "# pylint: disable=invalid-name, E1101\n", + "\n", + "from gtsam import Point2, Pose2\n", + "import plotly.express as px\n", + "import numpy as np\n", + "import gtsam\n", + "import math\n", + "\n", + "import matplotlib.pyplot as plt\n", + "from gtsam.utils import plot\n", + "from numpy.random import default_rng\n", + "\n", + "rng = default_rng()\n", + "\n", + "NM = gtsam.noiseModel" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Read 4090 odometry entries.\n" + ] + } + ], + "source": [ + "# load the odometry\n", + "# DR: Odometry Input (delta distance traveled and delta heading change)\n", + "# Time (sec) Delta Distance Traveled (m) Delta Heading (rad)\n", + "odometry = {}\n", + "data_file = gtsam.findExampleDataFile(\"Plaza2_DR.txt\")\n", + "for row in np.loadtxt(data_file):\n", + " t, distance_traveled, delta_heading = row\n", + " odometry[t] = Pose2(distance_traveled, 0, delta_heading)\n", + "M = len(odometry)\n", + "print(f\"Read {M} odometry entries.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Read 1816 range triples.\n" + ] + } + ], + "source": [ + "# load the ranges from TD\n", + "# Time (sec) Sender / Antenna ID Receiver Node ID Range (m)\n", + "triples = []\n", + "data_file = gtsam.findExampleDataFile(\"Plaza2_TD.txt\")\n", + "for row in np.loadtxt(data_file):\n", + " t, sender, receiver, _range = row\n", + " triples.append((t, int(receiver), _range))\n", + "K = len(triples)\n", + "print(f\"Read {K} range triples.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [], + "source": [ + "# parameters\n", + "minK = 150 # minimum number of range measurements to process initially\n", + "incK = 25 # minimum number of range measurements to process after\n", + "robust = True" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [], + "source": [ + "# Set Noise parameters\n", + "priorSigmas = gtsam.Point3(1, 1, math.pi)\n", + "odoSigmas = gtsam.Point3(0.05, 0.01, 0.1)\n", + "sigmaR = 100 # range standard deviation\n", + "\n", + "priorNoise = NM.Diagonal.Sigmas(priorSigmas) # prior\n", + "looseNoise = NM.Isotropic.Sigma(2, 1000) # loose LM prior\n", + "odoNoise = NM.Diagonal.Sigmas(odoSigmas) # odometry\n", + "gaussian = NM.Isotropic.Sigma(1, sigmaR) # non-robust\n", + "tukey = NM.Robust.Create(NM.mEstimator.Tukey.Create(15), gaussian) # robust\n", + "rangeNoise = tukey if robust else gaussian" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + ": cliques: 0, variables: 0\n", + "\n" + ] + } + ], + "source": [ + "# Initialize iSAM\n", + "isam = gtsam.ISAM2()\n", + "print(isam)" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "NonlinearFactorGraph: size: 1\n", + "\n", + "Factor 0: PriorFactor on 0\n", + " prior mean: (-34.208649, 45.300764, 1.12050365)\n", + " noise model: diagonal sigmas[1; 1; 3.14159265];\n", + "\n", + " Values with 1 values:\n", + "Value 0: (gtsam::Pose2)\n", + "(-34.208649, 45.300764, 1.12050365)\n", + "\n", + "\n" + ] + } + ], + "source": [ + "# Add prior on first pose\n", + "pose0 = Pose2(-34.2086489999201, 45.3007639991120, math.pi - 2.021089)\n", + "newFactors = gtsam.NonlinearFactorGraph()\n", + "newFactors.addPriorPose2(0, pose0, priorNoise)\n", + "initial = gtsam.Values()\n", + "initial.insert(0, pose0)\n", + "print(newFactors, initial)" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Adding landmark L1\n", + "Adding landmark L6\n", + "Adding landmark L0\n", + "Adding landmark L5\n", + "Initializing at time 151\n" + ] + } + ], + "source": [ + "# set some loop variables\n", + "i = 1 # step counter\n", + "k = 0 # range measurement counter\n", + "initialized = False\n", + "lastPose = pose0\n", + "countK = 0\n", + "\n", + "initializedLandmarks = set()\n", + "\n", + "# Loop over odometry\n", + "for t, relative_pose in odometry.items():\n", + " # add odometry factor\n", + " newFactors.add(gtsam.BetweenFactorPose2(i - 1, i, relative_pose,\n", + " odoNoise))\n", + "\n", + " # predict pose and add as initial estimate\n", + " predictedPose = lastPose.compose(relative_pose)\n", + " lastPose = predictedPose\n", + " initial.insert(i, predictedPose)\n", + "\n", + " # Check if there are range factors to be added\n", + " while (k < K) and (triples[k][0] <= t):\n", + " j = triples[k][1]\n", + " landmark_key = gtsam.symbol('L', j)\n", + " _range = triples[k][2]\n", + " newFactors.add(gtsam.RangeFactor2D(\n", + " i, landmark_key, _range, rangeNoise))\n", + " if landmark_key not in initializedLandmarks:\n", + " p = rng.normal(loc=0, scale=100, size=(2,))\n", + " initial.insert(landmark_key, p)\n", + " print(f\"Adding landmark L{j}\")\n", + " initializedLandmarks.add(landmark_key)\n", + " # We also add a very loose prior on the landmark in case there is only\n", + " # one sighting, which cannot fully determine the landmark.\n", + " newFactors.add(gtsam.PriorFactorPoint2(\n", + " landmark_key, Point2(0, 0), looseNoise))\n", + " k = k + 1\n", + " countK = countK + 1\n", + "\n", + " # Check whether to update iSAM 2\n", + " if (k > minK) and (countK > incK):\n", + " if not initialized: # Do a full optimize for first minK ranges\n", + " print(f\"Initializing at time {k}\")\n", + " batchOptimizer = gtsam.LevenbergMarquardtOptimizer(\n", + " newFactors, initial)\n", + " initial = batchOptimizer.optimize()\n", + " initialized = True\n", + "\n", + " isam.update(newFactors, initial)\n", + " result = isam.calculateEstimate()\n", + " lastPose = result.atPose2(i)\n", + " newFactors = gtsam.NonlinearFactorGraph()\n", + " initial = gtsam.Values()\n", + " countK = 0\n", + "\n", + " i += 1\n", + "\n", + "finalResult = isam.calculateEstimate()" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "5476377146882523136: [-35.97329685 26.31658086]\n", + "5476377146882523137: [-75.1003452 21.01144091]\n", + "5476377146882523141: [ -1.03876425 -12.13811931]\n", + "5476377146882523142: [-36.08926944 72.3500464 ]\n" + ] + } + ], + "source": [ + "# Print optimized landmarks:\n", + "for j in [0,1,5,6]:\n", + " landmark_key = gtsam.symbol('L', j)\n", + " p = finalResult.atPoint2(landmark_key)\n", + " print(f\"{landmark_key}: {p}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(4090, 2)\n" + ] + } + ], + "source": [ + "# plot positions\n", + "poses = gtsam.utilities.allPose2s(finalResult)\n", + "landmarks = gtsam.utilities.extractPoint2(finalResult)\n", + "positions = np.array([poses.atPose2(key).translation()\n", + " for key in poses.keys()])\n", + "print(positions.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "hovertemplate": "x=%{x}
y=%{y}", + "legendgroup": "", + "marker": { + "color": "#636efa", + "symbol": "circle" + }, + "mode": "markers", + "name": "", + "showlegend": false, + "type": "scattergl", + "x": [ + -34.20850096446654, + -34.208238954743685, + -34.20792902100424, + -34.20760757184568, + -34.20738361846996, + -34.20720176000561, + -34.20683416597183, + -34.20650648822083, + -34.20608935588558, + -34.20572966544402, + -34.20539047307245, + -34.204978189129015, + -34.204550257854706, + -34.20419937583317, + -34.203851648593165, + -34.20343169052313, + -34.203042274235045, + -34.20264363032312, + -34.2023726490794, + -34.202098140666784, + -34.20181134675162, + -34.20148933862536, + -34.20121738216338, + -34.2009080994932, + -34.20063643983228, + -34.20034481742374, + -34.2000696237222, + -34.19983311428898, + -34.19961392234484, + -34.19937365514169, + -34.199189402854074, + -34.19900317603461, + -34.19880114105563, + -34.19861066740314, + -34.19837645150039, + -34.19816029290118, + -34.19797297627397, + -34.19781332693034, + -34.197621612689275, + -34.19741383124952, + -34.19727147453903, + -34.19707929682523, + -34.196825570993276, + -34.19668981102975, + -34.1965642221961, + -34.196420248329694, + -34.196235418026966, + -34.196100527762006, + -34.19595528348216, + -34.19577717294135, + -34.19561281660328, + -34.19541564722681, + -34.19522091075678, + -34.19506867319604, + -34.19491322906737, + -34.19475617348167, + -34.19456738946074, + -34.19441257759611, + -34.194305608523415, + -34.19413890108838, + -34.19402563257533, + -34.193871851019416, + -34.19370503405217, + -34.19358027267171, + -34.19344531751981, + -34.19329006230464, + -34.19314564970173, + -34.19281770202517, + -34.19237269478954, + -34.19196721038553, + -34.191578273729284, + -34.19110684869285, + -34.19059806524432, + -34.19014443601172, + -34.18971520292186, + -34.18923602362348, + -34.18869913427268, + -34.18833003245038, + -34.18794166622675, + -34.18751483624891, + -34.187190230565314, + -34.18680050834423, + -34.18632317786081, + -34.185942639398604, + -34.18557941319778, + -34.185093513453985, + -34.18464292833133, + -34.184145833297634, + -34.1835690413308, + -34.183009218187, + -34.18245641784577, + -34.18194165756511, + -34.181411245583405, + -34.180842191387, + -34.18031070625432, + -34.17977878608106, + -34.179274141905964, + -34.178791235572014, + -34.17844353743938, + -34.178133542359454, + -34.17772043770016, + -34.17732671873731, + -34.17705071666225, + -34.17674574279609, + -34.17648907354066, + -34.17609695070887, + -34.175650899786554, + -34.17532133596385, + -34.1749278740639, + -34.1745363562516, + -34.17415109707245, + -34.17368635309034, + -34.173255122177345, + -34.172899756560206, + -34.17251040591449, + -34.172086874520765, + -34.171630360950736, + -34.17117119532719, + -34.17084558026226, + -34.170524084280146, + -34.170110063374906, + -34.16967258385827, + -34.16930042973883, + -34.168970613859635, + -34.16856304460919, + -34.16811576085387, + -34.16768676123481, + -34.16736099295385, + -34.167124898097356, + -34.16672691923807, + -34.16629949645637, + -34.166020544017336, + -34.16567348824885, + -34.165206629486555, + -34.1649102522274, + -34.16442554900806, + -34.16411656055334, + -34.163835743037595, + -34.163667960830786, + -34.163287829216934, + -34.16288151302674, + -34.16269395430048, + -34.16238991455418, + -34.162026429660955, + -34.161650280510415, + -34.16120406524903, + -34.16081468768969, + -34.16051969428689, + -34.16009094929095, + -34.15946914274596, + -34.159176862117704, + -34.15888727213249, + -34.15848397635331, + -34.15805064953907, + -34.15768222447644, + -34.15741075690409, + -34.157199236984766, + -34.15702885116569, + -34.156646530479655, + -34.15631433417576, + -34.15610539694631, + -34.155873450868484, + -34.155492955086224, + -34.15500939848627, + -34.15434250697164, + -34.15403249529436, + -34.15373559707832, + -34.153195077612814, + -34.15237223406272, + -34.151197823002306, + -34.150389061340434, + -34.14972355532398, + -34.14876182425223, + -34.147753205875084, + -34.1468315985221, + -34.14592332052173, + -34.14499865748368, + -34.14413081753902, + -34.14326697595306, + -34.14235363131009, + -34.141504566972635, + -34.140670341648466, + -34.13977560684852, + -34.138843663846956, + -34.13797352118324, + -34.13704209938618, + -34.136104337300594, + -34.13535680475344, + -34.13475717277114, + -34.13405664376714, + -34.13324167350189, + -34.13239793097516, + -34.13162108183035, + -34.13104483686513, + -34.130370325655754, + -34.129575170192076, + -34.12883978107497, + -34.12818004157837, + -34.12769342542458, + -34.12707652242479, + -34.12625640229815, + -34.12573308632649, + -34.1251498038307, + -34.124445847032156, + -34.12353836877944, + -34.122053316887815, + -34.1211257367895, + -34.119724360218534, + -34.11843319269593, + -34.11718493708246, + -34.11338507210205, + -34.10590214862871, + -34.09282073690137, + -34.07236785491298, + -34.04366888084793, + -34.007605138373535, + -33.96775324384806, + -33.92893317560639, + -33.89642176487315, + -33.87043613439064, + -33.85153191238244, + -33.83740516873628, + -33.82726615508978, + -33.820250147386204, + -33.81406643483115, + -33.806553260040644, + -33.79553734291343, + -33.77741136696219, + -33.751353316306194, + -33.71687278795068, + -33.673404994787944, + -33.62118686393082, + -33.56068922257415, + -33.49143699335243, + -33.412834311016205, + -33.3237099791763, + -33.22310971008737, + -33.11154866410422, + -32.990842000614734, + -32.86212511539496, + -32.72309606849177, + -32.57062017013406, + -32.40851697353658, + -32.23720250997644, + -32.05784746076087, + -31.866728622723116, + -31.666281580501536, + -31.45888722393481, + -31.236188791114166, + -31.006924763292396, + -30.77239936272471, + -30.52223313580673, + -30.2651984576044, + -30.000474598875655, + -29.731386707719196, + -29.452770053026384, + -29.16946897839721, + -28.89232263981253, + -28.60861583098035, + -28.320690784708873, + -28.03548898720037, + -27.750887893729637, + -27.462116131387482, + -27.179759975432518, + -26.903824475123297, + -26.631328055718043, + -26.36043103504407, + -26.096833531141332, + -25.850422454635652, + -25.602392293361373, + -25.35759661816489, + -25.14796974324249, + -24.936472171543887, + -24.729593124054873, + -24.537083022661065, + -24.353155730789634, + -24.178930832936064, + -24.00913393888203, + -23.839729121837294, + -23.687506229039517, + -23.537568073364582, + -23.385293532067568, + -23.247396754879265, + -23.113088745898345, + -22.97788261894402, + -22.85033600215608, + -22.7302817180564, + -22.614875683411867, + -22.501606043047857, + -22.39296302095069, + -22.291124620939094, + -22.1931172182229, + -22.10259642363014, + -22.01442660431957, + -21.930835891539587, + -21.854838952711773, + -21.785225767526434, + -21.71575024868301, + -21.648654064018306, + -21.584573001491112, + -21.520141220887233, + -21.45177011272198, + -21.384492813439486, + -21.318955993291997, + -21.25214802678836, + -21.183657070739866, + -21.119719849680337, + -21.054127260365053, + -20.983247893375697, + -20.91436578568897, + -20.84846061043535, + -20.77885808389792, + -20.70787520393427, + -20.64043121965617, + -20.567976117179327, + -20.49802638852945, + -20.425461212722524, + -20.35479298424054, + -20.282674033670123, + -20.20903252554625, + -20.137397231124215, + -20.064022256051153, + -19.989043931523284, + -19.90910114815038, + -19.83232039990529, + -19.753511678281324, + -19.669105733913366, + -19.583314519263745, + -19.494616842451236, + -19.40443270134451, + -19.310538831842955, + -19.213843720805905, + -19.115886189028437, + -19.01472079114958, + -18.911519476095137, + -18.806724464509383, + -18.699383897195215, + -18.587828190276884, + -18.475139327437294, + -18.358508133023715, + -18.23250049997659, + -18.107209041615686, + -17.97898633278654, + -17.839883042643052, + -17.69648363507199, + -17.551859605023964, + -17.39543380713658, + -17.234970818650286, + -17.070046627768086, + -16.896655854204827, + -16.72189078956723, + -16.53903862823957, + -16.345578366064313, + -16.15040970857468, + -15.949004036301039, + -15.73409882318099, + -15.520776286461096, + -15.301925922368234, + -15.070678004620648, + -14.843307150651425, + -14.613679045242476, + -14.375157473101241, + -14.139389072791404, + -13.904988476433443, + -13.661150612368795, + -13.41707736541694, + -13.180820011095541, + -12.940189575595827, + -12.700155724224805, + -12.46287048059889, + -12.225978074095101, + -11.983676868668432, + -11.741119724603255, + -11.498800839246972, + -11.257118257374717, + -11.01223198786857, + -10.771316621036691, + -10.531816609803048, + -10.289242645492118, + -10.051094210070056, + -9.813321328273139, + -9.572194435045677, + -9.335290205071642, + -9.103435364388192, + -8.870725952526021, + -8.643668752575262, + -8.425262228892867, + -8.180860199005918, + -7.957839554013466, + -7.742057345061757, + -7.528013744629356, + -7.314109479110836, + -7.103808525252691, + -6.898801427447551, + -6.693570821226059, + -6.493186831966468, + -6.296214106306464, + -6.099657869405394, + -5.908827113614065, + -5.715615709860129, + -5.521984170694302, + -5.332461478571686, + -5.143684708007091, + -4.953569959626599, + -4.772259464875152, + -4.599443535946358, + -4.422658612841595, + -4.25775388602834, + -4.1065418176812125, + -3.954217993435107, + -3.8120387751928058, + -3.684418011623918, + -3.560233442270346, + -3.4467609954925646, + -3.347455125665888, + -3.256983334164835, + -3.1818452917066, + -3.119373397337299, + -3.0677831078079105, + -3.031257649766328, + -3.0084601096112937, + -2.999507935999349, + -3.0054698298286207, + -3.0281236342209943, + -3.0700352925929155, + -3.128581122683889, + -3.204051529859299, + -3.29715812617134, + -3.410492761151179, + -3.548730766585928, + -3.7148238763535497, + -3.892215883500358, + -4.08410985314791, + -4.303713350962656, + -4.53199807872445, + -4.769603912150116, + -5.039499581466397, + -5.317875891463045, + -5.594221969919161, + -5.890632592115414, + -6.198945169797842, + -6.511674328245961, + -6.833766739908534, + -7.159226307604196, + -7.485880834688631, + -7.8213576057842005, + -8.155947721189817, + -8.490831353136372, + -8.82915319130515, + -9.167136649612592, + -9.50351549191835, + -9.844389215804224, + -10.179525657340298, + -10.515272701348366, + -10.853513119276752, + -11.186384006918765, + -11.51885458278919, + -11.853887739779578, + -12.186218771001966, + -12.518838054018211, + -12.85379646413921, + -13.192524168500752, + -13.531297475935212, + -13.87278381731817, + -14.216843064780582, + -14.565137202540603, + -14.916964860695987, + -15.276638133253954, + -15.649521404743567, + -16.023504645043484, + -16.399218438899965, + -16.787539834221676, + -17.17487956619213, + -17.55780193233755, + -17.953415497509326, + -18.36193162941352, + -18.753242349393854, + -19.1547288146283, + -19.57488363759125, + -19.97196405956268, + -20.382959986893194, + -20.808590477926927, + -21.214446333859488, + -21.63333896306491, + -22.063312427305817, + -22.469330277045458, + -22.883562574578935, + -23.30562305116361, + -23.69693990260586, + -24.089484082222096, + -24.508476133226807, + -24.91173666527541, + -25.300314654764954, + -25.714811859707304, + -26.12548731270093, + -26.513524431697025, + -26.921808472447154, + -27.335390292031736, + -27.730896380438267, + -28.135096223390125, + -28.54305955816263, + -28.936817600015, + -29.33587430829889, + -29.735739337632634, + -30.131784852729698, + -30.531510034799926, + -30.931020949816958, + -31.328154826694753, + -31.73203776645058, + -32.134486553878176, + -32.54050706752538, + -32.949192934150325, + -33.36421956031278, + -33.7768751678339, + -34.18940158319885, + -34.6021492023042, + -35.01047415651771, + -35.41171781082104, + -35.813119692813984, + -36.215050467039944, + -36.621893345222226, + -37.03717132871312, + -37.46002493725777, + -37.883741357496945, + -38.30719995932858, + -38.72938025987588, + -39.14924357296815, + -39.57802315670335, + -40.01043348825386, + -40.439590791865776, + -40.85750222093139, + -41.27182723776138, + -41.68831190734649, + -42.104015239877675, + -42.52170843545669, + -42.93599499431075, + -43.354357210777145, + -43.7750624616118, + -44.180848445444646, + -44.58812097321755, + -45.007713562651716, + -45.417312323038956, + -45.82621709698976, + -46.24179036897837, + -46.659494896046056, + -47.0695288423755, + -47.476973756636134, + -47.88445074356961, + -48.28357380372633, + -48.67844143188854, + -49.08561638451694, + -49.48706722067271, + -49.877482019089115, + -50.279156256808044, + -50.676273600636236, + -51.053999874666324, + -51.447405905939775, + -51.835994120984935, + -52.21242997655724, + -52.60506604706158, + -52.998686597981106, + -53.37501814955399, + -53.768167517412685, + -54.159168318936786, + -54.53551172277161, + -54.92985669052983, + -55.31894439666703, + -55.69133528047126, + -56.07482392120389, + -56.444863159140716, + -56.81320710027343, + -57.19053508956895, + -57.56215251701041, + -57.93099347502464, + -58.30172428484685, + -58.664305554440965, + -59.025359706567336, + -59.39829137353854, + -59.7582141542276, + -60.124530748974514, + -60.4908578766057, + -60.83505696904816, + -61.18130885226007, + -61.52499966132512, + -61.83836490363596, + -62.158196325612444, + -62.47435398172692, + -62.76534189619135, + -63.05109916028611, + -63.32856523564792, + -63.59049106119492, + -63.83443997198932, + -64.07088373841837, + -64.29533435273933, + -64.50526421490805, + -64.7052822956055, + -64.88974984865094, + -65.06757453538047, + -65.23988864735162, + -65.39869442920023, + -65.54776860045122, + -65.68685907210161, + -65.81838416625914, + -65.9438257267132, + -66.0586877840251, + -66.16983067614979, + -66.2802816919717, + -66.3794524945481, + -66.47116560935949, + -66.5580783084223, + -66.63414777564289, + -66.7028493866808, + -66.76637121729858, + -66.81648907329419, + -66.85396394868584, + -66.88622644304414, + -66.9069131925393, + -66.91060786059475, + -66.89911095628779, + -66.87390242235816, + -66.83644811340965, + -66.7854338571656, + -66.71980373895558, + -66.6393075249817, + -66.54347632615395, + -66.43556800084674, + -66.31751313422161, + -66.1879628899378, + -66.0492213591998, + -65.90249473180852, + -65.74234707643778, + -65.56934163451142, + -65.38876145051714, + -65.20278639422463, + -65.00492563650491, + -64.80032711231964, + -64.60508760719165, + -64.39640405531888, + -64.17347324859307, + -63.96347278695096, + -63.74783369003469, + -63.50827319197837, + -63.27817201058187, + -63.049125731267196, + -62.79906825605081, + -62.55571846530885, + -62.314009308642575, + -62.06057816611769, + -61.812002866040984, + -61.568008672865844, + -61.31977697955609, + -61.07544160626518, + -60.83338909845368, + -60.58843697818208, + -60.342020309956226, + -60.0982834257727, + -59.854435946413936, + -59.614697844411104, + -59.37731347930396, + -59.1442368134974, + -58.907917683733785, + -58.67663687502402, + -58.44415651110442, + -58.2105060936788, + -57.98171333830625, + -57.75986034582937, + -57.53099620873651, + -57.31742354180103, + -57.11094708789625, + -56.89404473721648, + -56.686883256420124, + -56.47789349938912, + -56.25906193982285, + -56.04874031747729, + -55.838189439977356, + -55.60030507722978, + -55.389217653260076, + -55.18088735036473, + -54.964419297405115, + -54.74251225436768, + -54.526641540244455, + -54.309039392756404, + -54.08466206621436, + -53.864323502833976, + -53.64917961946586, + -53.42760336382845, + -53.20497730544819, + -52.98581775877011, + -52.76052182351032, + -52.53402751945837, + -52.3150187111637, + -52.09446139112003, + -51.875301855406015, + -51.66458491630855, + -51.44880184949696, + -51.23583887481144, + -51.024251639441495, + -50.8071041651678, + -50.584758876264836, + -50.35673439549038, + -50.1297573599387, + -49.89796426630089, + -49.6561733825751, + -49.408897775010146, + -49.158956206097, + -48.907045966962926, + -48.65259274481218, + -48.39660932186194, + -48.14165349692995, + -47.89462459547674, + -47.649591173534326, + -47.40018989835032, + -47.153616860328476, + -46.910162188884115, + -46.66125059392396, + -46.41004047357298, + -46.160765219588185, + -45.91311081595627, + -45.66671801274094, + -45.40747481662267, + -45.15426593717196, + -44.909388930303805, + -44.647962917352636, + -44.38667975642351, + -44.146999413931816, + -43.889502030617585, + -43.61894147724354, + -43.365040674727965, + -43.10397341963339, + -42.82173856681402, + -42.54942670163961, + -42.29427549203657, + -42.01245170382104, + -41.72086995614185, + -41.44495855163145, + -41.14649300925449, + -40.83513244019241, + -40.53309408854726, + -40.217082860406435, + -39.899885571661386, + -39.58322762333115, + -39.25927891824996, + -38.929237385790294, + -38.60501927350662, + -38.26979751771763, + -37.9385310519133, + -37.61768400267112, + -37.277203470340076, + -36.93210187510192, + -36.5989670409393, + -36.2605671033598, + -35.917965476399594, + -35.58097928977165, + -35.25319314601881, + -34.92046457703637, + -34.60122646328586, + -34.29598638239527, + -33.98750709553549, + -33.68152469734872, + -33.39455183864643, + -33.105969452117364, + -32.81664183529315, + -32.54062425940163, + -32.26072655522608, + -31.99050060497156, + -31.72877191156155, + -31.468227199983055, + -31.21843403205665, + -30.977134869050094, + -30.744551200538588, + -30.520450877430214, + -30.306724802221535, + -30.101948932135542, + -29.90258558513699, + -29.711945872526478, + -29.52797590846506, + -29.350878727858255, + -29.18247502170537, + -29.01909000846517, + -28.861102358688207, + -28.708252894134002, + -28.557369426591453, + -28.409309883669305, + -28.268708922244162, + -28.126963558263956, + -27.984308901106722, + -27.85100925439488, + -27.71529717169798, + -27.576840023148723, + -27.44251871424342, + -27.313256895515355, + -27.180729339362657, + -27.04782848835487, + -26.92002153360609, + -26.790315116923516, + -26.660100898820513, + -26.53355776255497, + -26.405116036858658, + -26.276057964785224, + -26.148177020639906, + -26.02033937766196, + -25.88848134437824, + -25.755681789181292, + -25.624872901363535, + -25.490242505278097, + -25.355470932101024, + -25.22236384459584, + -25.088800442352564, + -24.95535322982862, + -24.82351898930916, + -24.689119954604983, + -24.558511466905593, + -24.425997464165782, + -24.290544088167824, + -24.157214348218936, + -24.021861518407203, + -23.885871920684064, + -23.7523139250111, + -23.619343365392247, + -23.486278389618853, + -23.35571588003516, + -23.227448808905987, + -23.09791383972981, + -22.973946373338865, + -22.853558225518288, + -22.732734754033817, + -22.615414309116364, + -22.50516554436355, + -22.397252798556945, + -22.29470452337946, + -22.194411644474002, + -22.095695812618867, + -22.005828463686132, + -21.918981677629155, + -21.832796787388066, + -21.748707256505124, + -21.6700995782513, + -21.588476863832106, + -21.502785620624614, + -21.42035927917382, + -21.33895745261171, + -21.25746798767443, + -21.173432789420303, + -21.093717056254025, + -21.01421380814389, + -20.92910815212903, + -20.845431538646466, + -20.763570632347665, + -20.683981750056073, + -20.605257257783588, + -20.527823423418333, + -20.450680999068584, + -20.37330676600147, + -20.297762528648224, + -20.22249251447037, + -20.149339684754093, + -20.07675386335212, + -20.001966877314885, + -19.930592511200402, + -19.85838426092376, + -19.784511919787423, + -19.71074907849777, + -19.641326428023383, + -19.57203711207824, + -19.499837444978393, + -19.419880274945527, + -19.3467494212837, + -19.270347377065946, + -19.18931053505767, + -19.105591213435552, + -19.01875416637844, + -18.92788485237338, + -18.8336186850511, + -18.73710595657461, + -18.634141549797043, + -18.525922295281873, + -18.413488281084675, + -18.296599974487968, + -18.171392414075658, + -18.040355925748102, + -17.910027118057833, + -17.77338992847017, + -17.629885914394755, + -17.487061658089928, + -17.339625102593942, + -17.18249906920041, + -17.027878626052093, + -16.873170781463866, + -16.706255793779288, + -16.540318732606554, + -16.376606673490276, + -16.2001022273604, + -16.023159396344724, + -15.850571060076042, + -15.664271270944365, + -15.475896287314427, + -15.298426651813514, + -15.091561042812065, + -14.892259901334407, + -14.705024033893125, + -14.512483072769651, + -14.305026564216309, + -14.10519411144605, + -13.903524572486049, + -13.692464657487488, + -13.483742809265246, + -13.274246955139253, + -13.059759920135956, + -12.84455403674611, + -12.627799045360609, + -12.409685667361273, + -12.1924413750488, + -11.970868912787585, + -11.7464236450359, + -11.522442826727415, + -11.295977029960582, + -11.067532313046188, + -10.839099112981014, + -10.61717452483189, + -10.388719073544767, + -10.16158941914945, + -9.937843429918663, + -9.708727087611281, + -9.480656932105369, + -9.253151528476407, + -9.026050056669934, + -8.799426062753268, + -8.577421694627647, + -8.356533192411291, + -8.134179831465978, + -7.912744714117008, + -7.694252414226817, + -7.476803140464561, + -7.259360915045362, + -7.045289914935967, + -6.835186557510802, + -6.6243444886204585, + -6.419611856274912, + -6.2210353157285665, + -6.02486892025979, + -5.83073140440032, + -5.640612261893859, + -5.4562844787822256, + -5.269901658356259, + -5.084882915776098, + -4.898135437540016, + -4.714547937466565, + -4.533587236559559, + -4.3288310694700245, + -4.143397569531199, + -3.967343203614429, + -3.7935401456925155, + -3.619734248375605, + -3.450697890446953, + -3.2925389530276097, + -3.1389235386498022, + -2.990376432907575, + -2.8535680191276174, + -2.7222266067770535, + -2.6005019413857395, + -2.497352417612733, + -2.4000035272737996, + -2.3089878530450045, + -2.2337449494523565, + -2.1701904299675587, + -2.1159681678033944, + -2.074625610913162, + -2.045475535555629, + -2.029161475575345, + -2.026358986910275, + -2.0371416488528715, + -2.0618610307693253, + -2.100938228557439, + -2.155907372425467, + -2.2255863659585926, + -2.3095800842277896, + -2.407399203609409, + -2.5225881586497594, + -2.6630509837274308, + -2.8185007212722084, + -2.986035983338265, + -3.183263172760066, + -3.397533865505796, + -3.619823420106108, + -3.8637291704136136, + -4.124021877019453, + -4.390613092861507, + -4.6773400070370545, + -4.975224485938064, + -5.27345518278105, + -5.587907304664719, + -5.911992859510572, + -6.230762785194832, + -6.55513967648047, + -6.886529345227692, + -7.206733306233928, + -7.5352635724096775, + -7.869206688399662, + -8.189496714942116, + -8.513422442809402, + -8.83988899464593, + -9.152968752423135, + -9.466355174354861, + -9.780120593229137, + -10.087283830882374, + -10.391203247701847, + -10.695466047400185, + -10.997952737530179, + -11.2968763513574, + -11.59695340044311, + -11.89762892723237, + -12.199047458639765, + -12.503552294087825, + -12.810250490063101, + -13.123043937553804, + -13.436433210200024, + -13.751874414867586, + -14.071036938215897, + -14.390312532211357, + -14.712179025001948, + -15.03737136375139, + -15.36352784876661, + -15.695121715030757, + -16.026625347948894, + -16.355412133457943, + -16.69674625962069, + -17.033640657897166, + -17.369267607448105, + -17.704054002405943, + -18.054108663868952, + -18.40259955279605, + -18.736412744667824, + -19.091415333251987, + -19.449470775356403, + -19.79053783368905, + -20.142884001997782, + -20.511361012524375, + -20.864014383527227, + -21.220138142465128, + -21.592449615714653, + -21.95105644301703, + -22.306083533920376, + -22.67429333385156, + -23.037392076997108, + -23.393195909901355, + -23.75368144445266, + -24.12053574396884, + -24.482355018737177, + -24.84277917128603, + -25.20546346341832, + -25.57296965523862, + -25.933788732782507, + -26.295052429160425, + -26.66664627437255, + -27.035480756971413, + -27.399127734923916, + -27.769060923278527, + -28.146037446226607, + -28.51169084023298, + -28.87468841939414, + -29.24624414963434, + -29.604713050167096, + -29.962627439354968, + -30.32445985346866, + -30.681273837093627, + -31.03950346961145, + -31.39679867260935, + -31.748219274989207, + -32.10410935222879, + -32.46541736589728, + -32.821335797797595, + -33.1796506882653, + -33.54110010763142, + -33.89853684918133, + -34.25586492273972, + -34.61612883746897, + -34.97040605630931, + -35.31793899655061, + -35.672835745619764, + -36.01812177167872, + -36.36670621473768, + -36.72506259706506, + -37.08436397400378, + -37.44385573591024, + -37.816975636283104, + -38.18916070593347, + -38.55455601119412, + -38.92460038130979, + -39.299503653654085, + -39.66411577470547, + -40.03225720112146, + -40.40549389662069, + -40.76736290191831, + -41.12269769746697, + -41.482852300709595, + -41.84225629509178, + -42.1938791629395, + -42.55609226584854, + -42.91934939165613, + -43.27528231397707, + -43.636636751629936, + -43.99890114014338, + -44.34786787571027, + -44.701227001123364, + -45.05899890206312, + -45.412497155596654, + -45.761181757254725, + -46.11438109125427, + -46.468353171673265, + -46.820220999222904, + -47.16840113637365, + -47.51786808464334, + -47.86888628608643, + -48.21737330898203, + -48.56310893062025, + -48.91521831222494, + -49.272464591852795, + -49.62173762377014, + -49.96970628222811, + -50.32443776675892, + -50.67498731428199, + -51.012205085006684, + -51.35865074123352, + -51.706623319588275, + -52.043596672557385, + -52.38645600411932, + -52.73627141639067, + -53.07957442147244, + -53.41971349747408, + -53.77095151892777, + -54.114957890414765, + -54.45019496610639, + -54.79856140679166, + -55.146660878574004, + -55.48081222853083, + -55.81990923461163, + -56.16331143241947, + -56.494210845184845, + -56.82777955894273, + -57.166836628774895, + -57.49731944828853, + -57.82638527590838, + -58.15808213867223, + -58.48476491883403, + -58.803442229655566, + -59.122058537933945, + -59.444892148310636, + -59.762422052013186, + -60.08008914024777, + -60.40578938497902, + -60.72188398200776, + -61.02658827196787, + -61.33811740634257, + -61.64817223693242, + -61.94016895889685, + -62.23021658249741, + -62.52097652742526, + -62.79630336970065, + -63.06227947815229, + -63.33040216818478, + -63.57961685109977, + -63.824155271857705, + -64.06078522539734, + -64.27601863921969, + -64.48611394439426, + -64.68291275551678, + -64.8688910556516, + -65.05883481493532, + -65.2295824985437, + -65.39107094542109, + -65.55888416912086, + -65.71953239347093, + -65.87235787735521, + -66.02122733852457, + -66.16097797595631, + -66.3001078504453, + -66.43607181753445, + -66.5649978184758, + -66.68447963809102, + -66.80496996462325, + -66.91749658058505, + -67.01499682792995, + -67.10451054468302, + -67.18398164698797, + -67.24855957610923, + -67.30529088716048, + -67.3489588700816, + -67.37678516450451, + -67.39878145401124, + -67.41350821822829, + -67.41462716887222, + -67.4014549685066, + -67.37738092142811, + -67.34327272135867, + -67.29916608466613, + -67.24443754844559, + -67.17819032906456, + -67.09504772459013, + -67.00028290392237, + -66.89428747725327, + -66.77724842495371, + -66.64867135313715, + -66.50926243576065, + -66.36805672868711, + -66.21853389972787, + -66.06400771932115, + -65.90164156157225, + -65.734632794016, + -65.56657721606501, + -65.39037351954278, + -65.21468008734657, + -65.03688644861485, + -64.85251657624295, + -64.66359622923198, + -64.47857172903966, + -64.28603113159828, + -64.08911541569569, + -63.89806627292407, + -63.702733993316485, + -63.50088321116002, + -63.29755695886102, + -63.09161931112981, + -62.87983162884179, + -62.6671544312839, + -62.451920772708206, + -62.2337393737835, + -62.0166704822151, + -61.792339473587255, + -61.56352323841773, + -61.33568055930336, + -61.1016579322237, + -60.862786281664306, + -60.62418752885767, + -60.383963344211935, + -60.141195225475855, + -59.901914270296935, + -59.660478242705985, + -59.42293859390448, + -59.18513332557597, + -58.941345177919416, + -58.7002959173493, + -58.46433951249894, + -58.22189301364764, + -57.986011878624765, + -57.75913529554494, + -57.52218633547478, + -57.29423931936483, + -57.0698705213643, + -56.837254235681065, + -56.611813590205564, + -56.388279414302474, + -56.15604413982848, + -55.93087161005591, + -55.71582916372175, + -55.48567875695876, + -55.254269686957365, + -55.03334941884798, + -54.80055829621406, + -54.56311277320943, + -54.33762906335555, + -54.10600016746531, + -53.869014087005894, + -53.63889286127345, + -53.40369522408548, + -53.16219284132628, + -52.931527826394884, + -52.703625976211065, + -52.47509795814298, + -52.25330743116507, + -52.03453983917092, + -51.82054456648979, + -51.602559678881605, + -51.384751277985615, + -51.162011385075836, + -50.93741702373425, + -50.71615399273452, + -50.48993478386764, + -50.25605295810436, + -50.021506529335404, + -49.788406338011605, + -49.54634219897052, + -49.30882398784405, + -49.078007785005084, + -48.843338943819056, + -48.61757268965913, + -48.39566728328335, + -48.160147679826615, + -47.91391012701571, + -47.681482808142384, + -47.445264365310344, + -47.19521380793525, + -46.94829331667921, + -46.705173200194736, + -46.45590249097712, + -46.19414798768018, + -45.938591625021985, + -45.69135464833386, + -45.436556691036294, + -45.1700352242094, + -44.9193464817555, + -44.66370427821693, + -44.38462162809643, + -44.10928850903636, + -43.84475633329836, + -43.56101912259667, + -43.26916780925776, + -43.00215538749537, + -42.730103248848714, + -42.43255454426534, + -42.14271369274558, + -41.85100933746453, + -41.53016148815435, + -41.210831215664115, + -40.88815509054743, + -40.5510344005366, + -40.21464673634862, + -39.873850051081675, + -39.521198215048955, + -39.17320252466675, + -38.81390435094139, + -38.45129600828928, + -38.0975338708504, + -37.73035353362482, + -37.35950643773611, + -36.99679419638933, + -36.632887522917756, + -36.26406605379358, + -35.91209689382184, + -35.565609465375076, + -35.21445744958684, + -34.89036309066244, + -34.57407088684302, + -34.24588188342661, + -33.94459584976613, + -33.647373976645035, + -33.34684672071522, + -33.0639932380401, + -32.78837738725645, + -32.5060827451208, + -32.233616178662146, + -31.974568655034858, + -31.715857675406586, + -31.461831254485197, + -31.22039968696241, + -30.987933228162344, + -30.760172487308616, + -30.53458248163382, + -30.314936479045173, + -30.104151305456135, + -29.899127412438165, + -29.697353012702013, + -29.5040090923272, + -29.31373973405169, + -29.123571697693556, + -28.941167377881264, + -28.76558256105084, + -28.587589482188562, + -28.420649669863987, + -28.264423190323043, + -28.103857859423396, + -27.95094691330261, + -27.807944638035227, + -27.660012091095293, + -27.515775067245535, + -27.37740325508359, + -27.236327458306935, + -27.09665204951298, + -26.959008262330904, + -26.818546998844113, + -26.67694543511828, + -26.53794803076415, + -26.39265680223, + -26.244917938631552, + -26.099260537815184, + -25.947248918197815, + -25.795964651986313, + -25.645215501942417, + -25.491078797437968, + -25.335977085982055, + -25.181077948821404, + -25.029432783995343, + -24.87271441809661, + -24.71553012384757, + -24.560659198297937, + -24.40528157733817, + -24.25222422438953, + -24.099913260358807, + -23.95199826779356, + -23.80705387831147, + -23.66389650129534, + -23.50860153926942, + -23.377205840182665, + -23.247375051793995, + -23.11994141914694, + -23.000168906362543, + -22.885095003156678, + -22.77561164838907, + -22.671021165335954, + -22.570048134501, + -22.475925129539053, + -22.3871268371147, + -22.297129381391294, + -22.208194495110913, + -22.124045343718752, + -22.032814062121165, + -21.938671705458578, + -21.84945388620275, + -21.75913901266592, + -21.66291869108665, + -21.571842556129557, + -21.47889433693878, + -21.377522783740005, + -21.279676525653205, + -21.187195966730886, + -21.092740410870768, + -21.00013433181636, + -20.910009128995, + -20.81803034627194, + -20.725254572870437, + -20.635908500377614, + -20.549458520919444, + -20.462560882591013, + -20.375027708946895, + -20.29179943034236, + -20.207735550713927, + -20.123140092350056, + -20.0422695909282, + -19.963284652474147, + -19.882417677939628, + -19.80006098860956, + -19.7164917979833, + -19.629668419272896, + -19.537748434389247, + -19.44329916286591, + -19.343888268147293, + -19.239331412049115, + -19.132778184945618, + -19.019851930988047, + -18.899693634244414, + -18.7740222361764, + -18.64098764588919, + -18.498694309792743, + -18.35070991187078, + -18.201133407197233, + -18.042508094909714, + -17.880072443923964, + -17.714131263705024, + -17.534272578123293, + -17.352049855298723, + -17.166474852006907, + -16.971031804739926, + -16.77285335055818, + -16.57108523614885, + -16.362789926429176, + -16.1568502058274, + -15.945924725414416, + -15.727011384518159, + -15.52087926321072, + -15.308608777976925, + -15.081023749659833, + -14.867190488523567, + -14.65261469516981, + -14.422873334341448, + -14.19775737997882, + -13.97538784029993, + -13.74029848216648, + -13.514067712925359, + -13.305662879590912, + -13.090133036188979, + -12.8749084736165, + -12.66460687077586, + -12.447283591667928, + -12.226123652789182, + -12.008635692779118, + -11.786103363661667, + -11.563554636491423, + -11.336508136336477, + -11.105139193387451, + -10.87310897973837, + -10.641936490991498, + -10.41504853696415, + -10.187361287970107, + -9.959962900483445, + -9.734518728442552, + -9.512464440594053, + -9.289723557559512, + -9.06962063046247, + -8.850349641018138, + -8.629493815200412, + -8.408026328287127, + -8.190105732462012, + -7.977462748886694, + -7.764403192292031, + -7.555317194741051, + -7.353277302203586, + -7.154073892218664, + -6.959369532699178, + -6.772135414139355, + -6.588362759891173, + -6.407471870227311, + -6.230809235740635, + -6.0514239953809925, + -5.873849213501359, + -5.704583332793421, + -5.517992193649563, + -5.351842156381033, + -5.195978546597032, + -5.0444182955658725, + -4.89452980585661, + -4.75265711854504, + -4.616714693316472, + -4.478987911620376, + -4.347586387492915, + -4.221261991556717, + -4.094105073197255, + -3.973620987203808, + -3.8582581195654524, + -3.743854038062321, + -3.6406250204203503, + -3.543159820531951, + -3.451697622539284, + -3.373083583306613, + -3.302958184860286, + -3.238454729469617, + -3.1840381544859695, + -3.138262606993281, + -3.101446047905101, + -3.075092558166955, + -3.057787655905119, + -3.051154313913769, + -3.057541909161938, + -3.0777058046686556, + -3.11226177300754, + -3.1633028148244224, + -3.2269528478740774, + -3.3055484962963204, + -3.411798257575102, + -3.5327781800075804, + -3.6598205132349197, + -3.8042761382970274, + -3.9637921100409432, + -4.1343610580237, + -4.321488710941954, + -4.5254183007263, + -4.743492198481461, + -4.976546490154955, + -5.217377413034586, + -5.463944656267498, + -5.725466210792774, + -5.9942265008430455, + -6.266880248401289, + -6.547174852306351, + -6.830449296321783, + -7.140692472535568, + -7.4244805791327595, + -7.712915267119155, + -7.99947758519533, + -8.287869958446791, + -8.580775449727524, + -8.870859696915527, + -9.158148441959966, + -9.44699351667771, + -9.732155623590316, + -10.014872063053648, + -10.29775979865664, + -10.5749857424261, + -10.852780150019186, + -11.13579070510385, + -11.414735024671282, + -11.691983571243979, + -12.001988427141596, + -12.284375625399804, + -12.56870016424334, + -12.853058122056988, + -13.137900260390332, + -13.42979842600842, + -13.719020086320706, + -14.009281632255217, + -14.301011028445044, + -14.592116426583539, + -14.890971598247424, + -15.184707720545834, + -15.481093409437367, + -15.785987358772202, + -16.092112434469538, + -16.39191902946486, + -16.69850670541692, + -17.0161434213426, + -17.32450689834855, + -17.639360481835368, + -17.952300054487637, + -18.27492426202884, + -18.60534648107577, + -18.923198927697737, + -19.25420879463818, + -19.595659059684134, + -19.922765883100173, + -20.253464392487235, + -20.601756971249607, + -20.942552896658825, + -21.278023387366545, + -21.630659974387093, + -21.981249616001126, + -22.32017323749989, + -22.668377046161826, + -23.05463993470257, + -23.399669839059996, + -23.744378497142918, + -24.093191328550297, + -24.44436936299327, + -24.795515704871704, + -25.14414383087788, + -25.496241874938153, + -25.849278367345228, + -26.20207562252542, + -26.55782194511243, + -26.9148884183778, + -27.273738204083404, + -27.62996812180062, + -27.988628487833694, + -28.343905979595174, + -28.69105165873182, + -29.037582666102935, + -29.38635666555043, + -29.723594788931088, + -30.056689528810995, + -30.398340778637635, + -30.732889562974144, + -31.064289829857046, + -31.399693377691765, + -31.73114349509177, + -32.06073470265107, + -32.39085993270465, + -32.726477933935236, + -33.06476388178211, + -33.39372387206702, + -33.72672082897999, + -34.065249854126584, + -34.39760966485348, + -34.73058887095751, + -35.06296182746248, + -35.38255055373931, + -35.70814463524432, + -36.039025349592656, + -36.35972932359205, + -36.68348309060314, + -37.015753113562184, + -37.34542599091974, + -37.67911904288027, + -38.020693570149206, + -38.36066205243361, + -38.70234003793076, + -39.0451492945628, + -39.39470491993342, + -39.74042583403011, + -40.08281748005475, + -40.43186199005066, + -40.78057224431363, + -41.12203865389018, + -41.46267864179745, + -41.808416009280755, + -42.1509228039912, + -42.49332936056103, + -42.8435430598219, + -43.18692325830212, + -43.52734250106671, + -43.8776940098036, + -44.21698436032345, + -44.556835639299806, + -44.89983918658318, + -45.242047277374795, + -45.5826423524521, + -45.92626112623457, + -46.26649535067602, + -46.60896010384153, + -46.949825362626704, + -47.28504508563118, + -47.62134759487119, + -47.95862599059289, + -48.29311570298762, + -48.62734834939975, + -48.96608905881075, + -49.30823556860031, + -49.64709563386687, + -49.98457430475821, + -50.32742155888007, + -50.66680723085256, + -50.992571879080906, + -51.32806674987615, + -51.6678732874057, + -51.99491914646086, + -52.33010756345978, + -52.67357247262765, + -53.010980957968805, + -53.346069446378266, + -53.693569866543896, + -54.037595073463436, + -54.373086426955794, + -54.72237993359764, + -55.07397645239609, + -55.412614130024785, + -55.75677926690112, + -56.106190350043306, + -56.43909434686085, + -56.77630047366981, + -57.11950997288579, + -57.455225540470664, + -57.78897197103224, + -58.126435303607316, + -58.45929014889902, + -58.78439546632375, + -59.11397571401777, + -59.45000916190445, + -59.77746170812227, + -60.108321017355244, + -60.440039735420875, + -60.751897208358145, + -61.05539281208393, + -61.35354119299772, + -61.643599581934104, + -61.912007900310435, + -62.18193003148963, + -62.44967261244594, + -62.69734117106019, + -62.94625825298737, + -63.1885117669288, + -63.44629136317529, + -63.71321495812481, + -63.96855143572826, + -64.21750078899412, + -64.45998084573628, + -64.69048974510206, + -64.9097276652017, + -65.13231715848787, + -65.33119631774655, + -65.52088910011457, + -65.70062746247886, + -65.8565819388565, + -66.00708425595732, + -66.15325504641662, + -66.28385190357783, + -66.40717356075399, + -66.52631527291177, + -66.62958093381904, + -66.72924851871147, + -66.82707187290991, + -66.91117047812757, + -66.98971235880302, + -67.06456025223856, + -67.1300129744069, + -67.1889735503678, + -67.23783511245693, + -67.27488175592373, + -67.30723066911459, + -67.32866615617212, + -67.33249111606224, + -67.32121765826327, + -67.29769196841443, + -67.26372753216401, + -67.21770276411456, + -67.15775947255702, + -67.08362594093559, + -66.99374450078747, + -66.88971886776457, + -66.7747525958143, + -66.64499872374087, + -66.5047908956285, + -66.35542672806882, + -66.19544669942078, + -66.02707019249198, + -65.8507498384762, + -65.66671210846613, + -65.4831662845677, + -65.29471575998292, + -65.09597657589894, + -64.90250498287553, + -64.70670321752343, + -64.49651569008526, + -64.28647349591253, + -64.0832888223937, + -63.8636405810657, + -63.6391231646599, + -63.4189590613997, + -63.19149285289527, + -62.95389538462015, + -62.71839446630718, + -62.478440369530034, + -62.23445317242666, + -61.98963337497661, + -61.740036658029396, + -61.49076695590796, + -61.241328126844905, + -60.98683623445296, + -60.73310879462256, + -60.477894447425896, + -60.220686607194, + -59.96636730543988, + -59.714213309905475, + -59.4590148414393, + -59.20951404769694, + -58.96006911126184, + -58.70508896088304, + -58.45649206354154, + -58.20666486943769, + -57.951979833842046, + -57.70589026616218, + -57.464689786537754, + -57.21669082657974, + -56.980934676128975, + -56.745921029818305, + -56.50802576386077, + -56.276750829476114, + -56.04317179109562, + -55.80867142570943, + -55.57667262122025, + -55.341456448501326, + -55.10434011192988, + -54.877842723637585, + -54.643498335668696, + -54.403127624811034, + -54.168984336643426, + -53.93372568497106, + -53.69061024371152, + -53.45504493834885, + -53.21811788454279, + -52.972977071567044, + -52.72832673016044, + -52.47843122474741, + -52.22517971900285, + -51.97870037637053, + -51.73492669184909, + -51.495159189060026, + -51.25994896995887, + -51.02679480745267, + -50.7949423064563, + -50.56273688923523, + -50.322646298805935, + -50.0518369270235, + -49.80940060456466, + -49.56274912715791, + -49.30633051674411, + -49.048153248661066, + -48.788574214159524, + -48.528002933079165, + -48.26586732939847, + -48.0063101472052, + -47.76030630870287, + -47.52441189578046, + -47.291128619127235, + -47.05883153043223, + -46.82799772148867, + -46.59482025492095, + -46.365585376495105, + -46.11109902775143, + -45.87727287335853, + -45.642914140177105, + -45.40265378491914, + -45.152736261013686, + -44.897224738899496, + -44.65296719519171, + -44.40939803456398, + -44.15264580641331, + -43.90177598054978, + -43.663467940761436, + -43.409574537718164, + -43.149188443721904, + -42.91191714832933, + -42.672110917469624, + -42.423003772869414, + -42.17677983247791, + -41.929123357104196, + -41.67955502117077, + -41.413873554753664, + -41.14334513864083, + -40.86956510701395, + -40.5741653030368, + -40.26044947653966, + -39.94043634888155, + -39.610520471411476, + -39.277888904030334, + -38.93790784879036, + -38.59248930617415, + -38.2485112170478, + -37.89317134392239, + -37.53036248814005, + -37.174427806825406, + -36.85794224453516, + -36.53845093440649, + -36.21348196170747, + -35.89285402641411, + -35.582396884716175, + -35.260874017508556, + -34.95264533947338, + -34.657933152357934, + -34.364936308225865, + -34.06839631466887, + -33.75798183397996, + -33.453822376036925, + -33.15327496730229, + -32.85245477930143, + -32.55264967900641, + -32.25956613718065, + -31.968360034488466, + -31.678558658588116, + -31.395253608918217, + -31.12487086303531, + -30.86073486042292, + -30.596911207466043, + -30.35094700172187, + -30.107916916846243, + -29.86662969238843, + -29.64140287514268, + -29.42465730060835, + -29.20873141950385, + -29.003208174876242, + -28.801047166518092, + -28.598559421673098, + -28.408279489559504, + -28.226271545182716, + -28.04055186945441, + -27.87096283596155, + -27.71148765551883, + -27.546167521352587, + -27.39230417727304, + -27.24877580876583, + -27.09783215893978, + -26.954937921799782, + -26.818949965758552, + -26.681676601821323, + -26.547398310076627, + -26.416396387906435, + -26.285743836643963, + -26.154399525411, + -26.028216440474537, + -25.89718916840326, + -25.76383098390717, + -25.63069841007004, + -25.496061982526225, + -25.359107458747367, + -25.224247495385317, + -25.089574969004996, + -24.95285873066305, + -24.819494799318345, + -24.684718566617278, + -24.554116616503265, + -24.419019357854157, + -24.284172149282224, + -24.148707793931205, + -24.01186473981563, + -23.876342746496064, + -23.738868073052906, + -23.601694397583806, + -23.46532239616867, + -23.32692643656262, + -23.18872086142345, + -23.05105840492789, + -22.91885948279839, + -22.785663621276885, + -22.654800919142524, + -22.532057954698, + -22.411929968003665, + -22.29671530099123, + -22.186545449912806, + -22.07959033186948, + -21.979140273166, + -21.88559955764864, + -21.793155405157744, + -21.701638186393176, + -21.617775345367917, + -21.53450567338066, + -21.44536929480654, + -21.36241397007175, + -21.284528674198466, + -21.205495079814423, + -21.127556831766018, + -21.054384220246313, + -20.97955361232512, + -20.90073577389424, + -20.82357507949722, + -20.74731866628342, + -20.6711798956272, + -20.593662158870323, + -20.515250902662043, + -20.435299164524103, + -20.35293541021248, + -20.268813027999187, + -20.185173336563242, + -20.10105356746033, + -20.0118211480889, + -19.92464043858409, + -19.836792297259407, + -19.747523545570328, + -19.6571131930019, + -19.570821403242952, + -19.486117255630514, + -19.40006239938822, + -19.31576568805221, + -19.230069952903875, + -19.143307897505597, + -19.054961879789328, + -18.965652637571274, + -18.873872155913173, + -18.778858428906904, + -18.68578786711409, + -18.58738025704817, + -18.483825961339633, + -18.381112187414274, + -18.2740657788208, + -18.162325042151902, + -18.04407412307408, + -17.919854667959843, + -17.792509316116345, + -17.66184175675464, + -17.523915424889008, + -17.380889757343244, + -17.23131569457251, + -17.07543853247107, + -16.912602931612362, + -16.742816293028714, + -16.57246049417901, + -16.3943147157906, + -16.212966686173175, + -16.032686508658802, + -15.847141521277802, + -15.651425656594203, + -15.464992511939615, + -15.278464908435609, + -15.072448379880425, + -14.8761217062947, + -14.685520947858995, + -14.480019650318107, + -14.27821657976232, + -14.081162336099894, + -13.878548719850825, + -13.676114825254635, + -13.479426329362635, + -13.281653292041767, + -13.081932356207679, + -12.883715506200103, + -12.679829250555638, + -12.474074208665568, + -12.264738759020902, + -12.052432689037936, + -11.836503537782825, + -11.621061684740479, + -11.401713923807238, + -11.180083391260311, + -10.958317772071153, + -10.731046098698549, + -10.505790343847037, + -10.278488349239959, + -10.047943367741272, + -9.817561506108019, + -9.587572898564792, + -9.35766954806178, + -9.126937303434898, + -8.900387525863927, + -8.67631977764867, + -8.448014145550228, + -8.220877478765015, + -7.998791010097954, + -7.773894684448212, + -7.541617244165808, + -7.320168571875958, + -7.098763267260314, + -6.8729716888569605, + -6.652353600687759, + -6.43479038495738, + -6.217797107537136, + -5.998216552350932, + -5.776678779729985, + -5.560223461091686, + -5.345274024557623, + -5.127561761243892, + -4.917491843645876, + -4.715461570545477, + -4.505065395882604, + -4.307361113708217, + -4.124523779496649, + -3.9355579397707596, + -3.7491507672574134, + -3.5787963356111887, + -3.406589393198464, + -3.24175832452202, + -3.086619820093648, + -2.930549832146983, + -2.795652688582039, + -2.6676514831264004, + -2.5401125480067064, + -2.4299029546326705, + -2.333505659562151, + -2.2450534950716903, + -2.17204735922884, + -2.1128176477362017, + -2.064530519347138, + -2.030370074393183, + -2.010393138338829, + -2.004138709737753, + -2.012870078493015, + -2.0370841526207615, + -2.0763602872947198, + -2.131686324194107, + -2.2055013778985, + -2.2978840990695226, + -2.4069163471706303, + -2.5364879960175424, + -2.68534378683897, + -2.848188644199925, + -3.037611048174019, + -3.2441099854616016, + -3.4647746970836804, + -3.703799504275772, + -3.9553835254678953, + -4.218520760246446, + -4.495366437253349, + -4.785639833248068, + -5.08484574406934, + -5.386986180057311, + -5.702574899155793, + -6.024555013090089, + -6.348233572480257, + -6.676596928656058, + -7.002112491822304, + -7.323535097965916, + -7.64907641442559, + -7.974162122243784, + -8.29329386687016, + -8.614724092092834, + -8.967351104418675, + -9.281004572916306, + -9.595351486147798, + -9.906920463817347, + -10.215609104682294, + -10.522452460283825, + -10.822021762894998, + -11.12167338367002, + -11.419731886306183, + -11.71402587006081, + -12.009303977815186, + -12.3039125702886, + -12.59794267786584, + -12.890491044518715, + -13.183914504029614, + -13.481320566748892, + -13.77774003043691, + -14.076319692672689, + -15.135851789820482, + -15.448746547073267, + -15.7701082496401, + -16.09652425161251, + -16.419857209306667, + -16.74923535317632, + -17.087754028882582, + -17.418509385474984, + -17.749829044650806, + -18.081905985797718, + -18.43035359220856, + -18.774654053643356, + -19.111139682495992, + -19.46983396195551, + -19.826603384066924, + -20.169733909783766, + -20.52679978792883, + -20.892862525450653, + -21.24183538689616, + -21.599351761140987, + -21.969785438463287, + -22.3253797601872, + -22.684291974552025, + -23.052360362425834, + -23.41196124967644, + -23.770153047626177, + -24.133781319508433, + -24.50390460074309, + -24.872779437632673, + -25.240713528172346, + -25.613006170135684, + -25.986125133400538, + -26.356572373107657, + -26.732719884490372, + -27.110898722633916, + -27.488921639221125, + -27.867802714125826, + -28.250522161210075, + -28.627673559074477, + -28.999630763154563, + -29.374260600789675, + -29.747133847477823, + -30.108627139695123, + -30.4738748947489, + -30.845345249687618, + -31.20496502543251, + -31.56838718343814, + -31.934373093137385, + -32.29625825293812, + -32.655795528669486, + -33.023786562502075, + -33.394066451795176, + -33.75437144700971, + -34.12279959224543, + -34.49565203273659, + -34.86116745868966, + -35.22769623168896, + -35.58748552490108, + -35.93626152603457, + -36.29294327802148, + -36.64582261485386, + -36.995657073893845, + -37.35707992209258, + -37.719757995762386, + -38.08279784436257, + -38.45545006209987, + -38.82825938742799, + -39.19826192510871, + -39.57148727374065, + -39.94837305387175, + -40.3131466176835, + -40.68148435156973, + -41.053580633180246, + -41.41471059380327, + -41.771110999942785, + -42.13160233797795, + -42.48929319127417, + -42.83258211303702, + -43.17156392946248, + -43.50777876759926, + -43.83621764122727, + -44.17080271291086, + -44.52739005309362, + -44.84879118581973, + -45.172689735288515, + -45.49298321643188, + -45.81118501611773, + -46.13863714923011, + -46.46773403246215, + -46.795166887960754, + -47.1257062847864, + -47.44813465160126, + -47.76764993598219, + -48.086771459399884, + -48.40222941032584, + -48.71907808757432, + -49.03399471655375, + -49.34354732591616, + -49.64813474894097, + -49.94981490289839, + -50.251862128363726, + -50.55493943284664, + -50.85796949254924, + -51.15180674861884, + -51.44529451950592, + -51.74712457823629, + -52.04292433028553, + -52.33742608542614, + -52.64444297727034, + -52.95310408136931, + -53.25735769920424, + -53.559259930307086, + -53.868186045975776, + -54.176431125540695, + -54.478829869137414, + -54.78500535792089, + -55.095673466036885, + -55.40471352965616, + -55.714435944882986, + -56.02605195190667, + -56.33613734445318, + -56.64287796102448, + -56.94404246659396, + -57.2454612627552, + -57.547812439847576, + -57.84562457342142, + -58.13879443261982, + -58.433761921080084, + -58.72899131468643, + -59.01665140201153, + -59.30114472694361, + -59.590252232730855, + -59.87847927387634, + -60.162801188955875, + -60.44947893580497, + -60.738067632631044, + -61.01861847533337, + -61.29090609295296, + -61.55798266103622, + -61.823011438212895, + -62.08601868209725, + -62.332524119480766, + -62.585320398106404, + -62.835515300168986, + -63.07422177960659, + -63.31370164822563, + -63.54605428117169, + -63.7728317320232, + -63.998742709320645, + -64.2148619436213, + -64.4281100223949, + -64.63494551952286, + -64.8312054221813, + -65.02270635297941, + -65.21286313376582, + -65.38030932523728, + -65.55025723287066, + -65.7186392970733, + -65.87307562484358, + -66.02658267148043, + -66.16997682946553, + -66.30036537860236, + -66.42860486157353, + -66.54586257112908, + -66.65080224457779, + -66.75524537867345, + -66.84978197627228, + -66.93204206142951, + -67.00986209234982, + -67.07818786683039, + -67.13766003323333, + -67.18855777233131, + -67.22686182700441, + -67.2561763081435, + -67.28037100105091, + -67.29232540801087, + -67.28902093918421, + -67.27309618532365, + -67.24619829938165, + -67.20977023756896, + -67.16182994888243, + -67.10232748696933, + -67.0303655676146, + -66.94613023023517, + -66.85302319135522, + -66.75040284348367, + -66.63711735946971, + -66.514289986404, + -66.38405686478657, + -66.24356107069106, + -66.09319796269615, + -65.93253183883319, + -65.76682167290186, + -65.60006941227343, + -65.42529806307405, + -65.24472705747665, + -65.06968890862828, + -64.88883274173362, + -64.69673994010479, + -64.50796366224124, + -64.32312159258277, + -64.12314018644014, + -63.92081827337502, + -63.72362476919724, + -63.518495215160286, + -63.30621814697917, + -63.09775156198335, + -62.88512388334975, + -62.66755489876957, + -62.45221622175098, + -62.23244384909481, + -62.01294574095953, + -61.797165848834396, + -61.57844859793764, + -61.35676750292627, + -61.13764904108875, + -60.913958948083966, + -60.68835770839192, + -60.4639032078944, + -60.240140380911946, + -60.01428124793145, + -59.79349400741151, + -59.568453565516144, + -59.34337280034526, + -59.12099930516383, + -58.89234417074142, + -58.66546206150973, + -58.443926654764425, + -58.22197563076419, + -57.99691914914282, + -57.7858573920524, + -57.56798673677686, + -57.347636867085846, + -57.13459913937889, + -56.912562430064085, + -56.687429448579756, + -56.468789942761816, + -56.245386985762075, + -56.018754601727956, + -55.799065418774035, + -55.5855183332723, + -55.36229348495816, + -55.14053947443615, + -54.9222949185408, + -54.70282645857301, + -54.480597882969064, + -54.26499998226399, + -54.05146625314503, + -53.833881587790664, + -53.61799894792032, + -53.40159086033317, + -53.180210132692544, + -52.962174855186966, + -52.747858101984505, + -52.5351623575907, + -52.328901860940086, + -52.12870463886185, + -51.924782676423575, + -51.72756685644262, + -51.53106038072766, + -51.3286602838826, + -51.123117717595456, + -50.91323763770998, + -50.70543784235384, + -50.49438155851113, + -50.27510822028462, + -50.050909480063616, + -49.823669528226226, + -49.59422691319602, + -49.36440612388891, + -49.134583903761175, + -48.90276059387876, + -48.6776976551374, + -48.45875150493637, + -48.237037244812875, + -48.01299505200677, + -47.78735522324146, + -47.56730557931867, + -47.34259502952261, + -47.11120550652267, + -46.88007594379417, + -46.64671633762666, + -46.416150078136745, + -46.172300421953395, + -45.92810876478263, + -45.68903856733507, + -45.446385626242474, + -45.18985324628705, + -44.94372691322733, + -44.70192897308499, + -44.4388573852421, + -44.17362366263393, + -43.92040723303416, + -43.65516753542635, + -43.37830386078054, + -43.12025807725942, + -42.86759601270207, + -42.58731737506445, + -42.30107905713094, + -42.022423625323306, + -41.719980538226665, + -41.41304072763449, + -41.11286088896048, + -40.795923494360935, + -40.48217486048098, + -40.16624475184209, + -39.83954074672723, + -39.51433138503104, + -39.19059499041351, + -38.855692630079304, + -38.52375773208069, + -38.19461421504617, + -37.84480949933679, + -37.49163018486492, + -37.144637803730156, + -36.7925621307796, + -36.43533341187442, + -36.086511974681386, + -35.74411094330811, + -35.3974239198621, + -35.070615321269564, + -34.75612404371996, + -34.433828009503436, + -34.119296782661245, + -33.82917333821225, + -33.53119170654072, + -33.23896565819366, + -32.96298573650396, + -32.68880755570047, + -32.411180307255314, + -32.14730907317692, + -31.89148736574123, + -31.63406399894263, + -31.382816104159236, + -31.14260470511524, + -30.908738528997826, + -30.679175267360588, + -30.453972229323984, + -30.239954435773935, + -30.034953451987597, + -29.83544699147934, + -29.63870010592736, + -29.45051392019687, + -29.26484764906907, + -29.080832426982415, + -28.904547605644584, + -28.733812127457593, + -28.560948097803262, + -28.398201418249936, + -28.244219460857213, + -28.085503298428566, + -27.93239840355711, + -27.788020578850716, + -27.6382945306535, + -27.489944406321268, + -27.347831741300396, + -27.204421570626206, + -27.05296883779879, + -26.898753759572028, + -26.743990246723495, + -26.58693748758553, + -26.434903402143977, + -26.27864859898281, + -26.118987136725572, + -25.960688848867633, + -25.80263384082788, + -25.641587531545945, + -25.490634425886167, + -25.345428899364343, + -25.198233210221513, + -25.054022176188564, + -24.90948308823688, + -24.768150125688518, + -24.622367149289303, + -24.47454313616094, + -24.32689585944346, + -24.177551516639706, + -24.02368924149357, + -23.864973243607917, + -23.70908548011217, + -23.55665080355543, + -23.40380073678331, + -23.255790164156334, + -23.112005237974362, + -22.974921510501254, + -22.836658156324113, + -22.701933896202043, + -22.57860038025221, + -22.460292385077597, + -22.34604813811691, + -22.235431708965255, + -22.12778442554132, + -22.02883321507727, + -21.934526113250715, + -21.841440296149116, + -21.750985679990453, + -21.667079661963893, + -21.5867976110394, + -21.508110597429436, + -21.437055558110558, + -21.36902699900956, + -21.298175186937353, + -21.230834669316433, + -21.167569360273735, + -21.099309576083893, + -21.030590119263348, + -20.963200368387593, + -20.890956437634305, + -20.8155983053334, + -20.738588165632944, + -20.660873903780818, + -20.58267998099796, + -20.502745508009546, + -20.421955384431214, + -20.34330220580916, + -20.265145814853938, + -20.18118020561016, + -20.102040809730145, + -20.026286139676, + -19.94960758496676, + -19.87154028788423, + -19.796544460529947, + -19.722975346652813, + -19.646816421842942, + -19.568828146982252, + -19.485389574699816, + -19.396748886114946, + -19.300280064223347, + -19.196228108034806, + -19.088681697785233, + -18.973792366598712, + -18.855078107823672, + -18.73464300770831, + -18.60742735561582, + -18.47426458226979, + -18.3356953708922, + -18.192468238461107, + -18.05014620355314, + -17.900311344850667, + -17.748259627428386, + -17.59985239845478, + -17.44749018124793, + -17.289831861687734, + -17.13476711030833, + -16.978987429061053, + -16.815182750249186, + -16.649125894399738, + -16.484893876493732, + -16.31141786509426, + -16.13937011537734, + -15.966794840397359, + -15.78756663352844, + -15.60156817702558, + -15.42305278717738, + -15.242249598749698, + -15.043921348952347, + -14.860129157670553, + -14.670611824815383, + -14.465757156182125, + -14.264695572567078, + -14.063136254561105, + -13.855595937981844, + -13.646791011708268, + -13.442235373770762, + -13.233586255091215, + -13.022735948499761, + -12.812340887866805, + -12.599936269082722, + -12.3875157182114, + -12.174185056595656, + -11.961061223735184, + -11.743969719285408, + -11.530653431910961, + -11.314437774434387, + -11.10220201621013, + -10.891674418989599, + -10.678543557240541, + -10.469547248740419, + -10.2624515237825, + -10.051378919587437, + -9.841779070441515, + -9.634490779696469, + -9.427093585360735, + -9.218670150847851, + -9.015585082397052, + -8.813825543884338, + -8.607017514119331, + -8.401703878797997, + -8.20298074403046, + -8.004702741756736, + -7.801652168700154, + -7.610108952407171, + -7.42142288803849, + -7.2264885692473735, + -7.039059526496087, + -6.8538488835666485, + -6.667048806111803, + -6.480069200058203, + -6.290853721787352, + -6.098503309303771, + -5.9105502171822195, + -5.7192761035844, + -5.529391050315129, + -5.346962338968414, + -5.156469172134299, + -4.973915853484037, + -4.797752869141554, + -4.6190817736512875, + -4.447224680955508, + -4.283957503658762, + -4.125194422673401, + -3.970335520706842, + -3.823632292155835, + -3.6807675216413864, + -3.547546651900776, + -3.423187895815954, + -3.3027916263417385, + -3.192411461611369, + -3.091623219459845, + -2.9974353690905984, + -2.9142283907573603, + -2.842398311146486, + -2.78200160182335, + -2.7332929856987773, + -2.697585945913979, + -2.6758406562339077, + -2.6692647170003116, + -2.6794013408390134, + -2.7068345307342736, + -2.750243888382191, + -2.8140204735141774, + -2.9019989163318414, + -3.0102149280354915, + -3.136163434944975, + -3.292116164523131, + -3.4732407107632817, + -3.6673367794577665, + -3.8855819480982174, + -4.121239444135406, + -4.367886537153666, + -4.6310590045092255, + -4.909107602324732, + -5.198241707007494, + -5.493338337837971, + -5.801065701760351, + -6.11715304648553, + -6.436004558689778, + -6.761900661809447, + -7.08484882685639, + -7.406039134764869, + -7.729457535360628, + -8.053267053009991, + -8.37159367129153, + -8.68936576889038, + -9.004949568297512, + -9.31682047245802, + -9.62696774417688, + -9.93245417582924, + -10.235487151275668, + -10.536757563008043, + -10.831021603735033, + -11.126865921930344, + -11.422013064881448, + -11.711720823201341, + -12.003050423904435, + -12.296803887576305, + -12.589869543335947, + -12.882519007259162, + -13.206824304862469, + -13.506153041046147, + -13.804482252710493, + -14.104526524025465, + -14.407232207829418, + -14.713380061745976, + -15.025173389468257, + -15.339392391657894, + -15.657636340223394, + -15.985835400864142, + -16.315854623638366, + -16.645196662538208, + -16.98721157759249, + -17.32813326129069, + -17.670578183215763, + -18.012771106910797, + -18.372457112967986, + -18.728236276014083, + -19.07590688798593, + -19.448794833792853, + -19.819856423742173, + -20.17754088957126, + -20.55025241238078, + -20.93166781440896, + -21.29683990132369, + -21.674990368027238, + -22.06002927686083, + -22.43131262677279, + -22.81070485571307, + -23.193418361000578, + -23.565429945496376, + -23.940763282180978, + -24.326808498268797, + -24.70813635926141, + -25.085573837306793, + -25.471415865390853, + -25.85905754259594, + -26.24067967699594, + -26.62835577756132, + -27.023276222597836, + -27.41673543567052, + -27.810942640593055, + -28.21092489010343, + -28.603232087112556, + -28.99032235059943, + -29.38285130277187, + -29.7640527462422, + -30.138842298441205, + -30.524940716961034, + -30.9031609071591, + -31.275535933580933, + -31.653674064251156, + -32.031825364270404, + -32.40078807806816, + -32.77565139090284, + -33.15572092407154, + -33.524616131218615, + -33.89788555388855, + -34.31205031475888, + -34.67917302969318, + -35.047490717617805, + -35.40843109635998, + -35.765511114477306, + -36.12346065001754, + -36.47838098507164, + -36.83849735418127, + -37.20493953280114, + -37.566328907076304, + -37.93701170042838, + -38.31494730759715, + -38.6832600789177, + -39.04813864890323, + -39.42307988909404, + -39.788370226073305, + -40.14692481413972, + -40.514965166964416, + -40.876722858299054, + -41.22342213010617, + -41.574186828760354, + -41.93368443234105, + -42.28015563616246, + -42.63314166434982, + -42.98871329167694, + -43.336115226771106, + -43.682257706749766, + -44.03144549426208, + -44.37217780132779, + -44.714603335792184, + -45.056444297411836, + -45.39785578436033, + -45.73739926610399, + -46.07722256707363, + -46.41392265976048, + -46.7536884124984, + -47.08837036597805, + -47.41816538407625, + -47.74972236062895, + -48.07810458489749, + -48.406582657969295, + -48.735465316936306, + -49.06695767124381, + -49.39895970424358, + -49.73018564900906, + -50.06126296328641, + -50.39551242883402, + -50.72834905185715, + -51.05150893735426, + -51.38114781399141, + -51.717259812594975, + -52.07301299189745, + -52.39875053442496, + -52.733819938449834, + -53.06355532999838, + -53.38589277702482, + -53.716245957876666, + -54.04869301317649, + -54.37178924855617, + -54.69642329062652, + -55.02574949580002, + -55.34834129208912, + -55.669145147422846, + -55.98960028704136, + -56.3076972647037, + -56.621269058949764, + -56.93013001812104, + -57.24022509237409, + -57.55092377293988, + -57.85341381392597, + -58.15499335547893, + -58.49317394099913, + -58.795631088408726, + -59.091024596135846, + -59.39217311527495, + -59.69419299884416, + -59.99299964095694, + -60.29196990099787, + -60.59386658038928, + -60.88848406793987, + -61.17521208756635, + -61.4546765742016, + -61.735882797629706, + -62.00895634729002, + -62.26878311208972, + -62.535928099095486, + -62.79467599850539, + -63.04577366907901, + -63.29258615431078, + -63.532465026755965, + -63.76603160470231, + -64.01143698884582, + -64.2470503431708, + -64.47776937000921, + -64.69985146036979, + -64.91011890076263, + -65.11988179054424, + -65.3067541269102, + -65.49349294740588, + -65.67787352151993, + -65.84831801820904, + -66.00543237013562, + -66.15149592500164, + -66.28718826595971, + -66.42162106870808, + -66.5449400884666, + -66.65912764232849, + -66.77431064903143, + -66.87844686027412, + -66.97074860087946, + -67.0581201025096, + -67.13291211280782, + -67.19976263984901, + -67.2577201728544, + -67.30180824980346, + -67.33745312389016, + -67.36778895840271, + -67.38511135378526, + -67.38654074476754, + -67.37602965781693, + -67.35486159496755, + -67.32253305632219, + -67.27822350802018, + -67.2214423833344, + -67.15138090897804, + -67.06713408138613, + -66.97115496015253, + -66.86398241166371, + -66.74260535983153, + -66.6124286904359, + -66.47603631568566, + -66.3223538863303, + -66.16347486178844, + -66.00119848845848, + -65.83522823405262, + -65.67165984607001, + -65.50577920561427, + -65.32982397618471, + -65.15726721485673, + -64.9839874745163, + -64.80122020292747, + -64.61961921010575, + -64.43955916572844, + -64.25791070889484, + -64.06429147947685, + -63.86949330469262, + -63.67349152816083, + -63.47230699031968, + -63.26594689905641, + -63.0585791504768, + -62.84804847784432, + -62.63558106177187, + -62.42359538938827, + -62.20750638047554, + -61.99064999614591, + -61.77627137593541, + -61.5591030115546, + -61.3393009777562, + -61.119650829716875, + -60.895726935657365, + -60.668349275660866, + -60.440341380750475, + -60.21114195590213, + -59.9799352263799, + -59.72910276464591, + -59.498188999257664, + -59.26975687989393, + -59.04510407064621, + -58.818231079100684, + -58.5988116024957, + -58.38759521545623, + -58.175723036610236, + -57.967848630035796, + -57.77784702640796, + -57.582916018358446, + -57.38763405398029, + -57.2043625573086, + -57.01231722463997, + -56.81442331127769, + -56.62234054540864, + -56.42458477237226, + -56.224255895006415, + -56.02328811809464, + -55.81992462558698, + -55.62008498502751, + -55.41666243639262, + -55.203566020502805, + -54.99347717623862, + -54.7827257372063, + -54.563907961629525, + -54.34541200182019, + -54.12995911009073, + -53.90694176591345, + -53.68048133160258, + -53.45608410949461, + -53.22444569531539, + -52.988828306748424, + -52.759969627072124, + -52.5307230438456, + -52.30249765484058, + -52.08318313925874, + -51.8617495682587, + -51.64291895230069, + -51.42674271864286, + -51.205374836573014, + -50.978822865355845, + -50.74893558123603, + -50.520451718492126, + -50.286235222182874, + -50.04404205777941, + -49.79841125314252, + -49.55133029598883, + -49.301142538041006, + -49.05171706351374, + -48.80262570546493, + -48.55795012537514, + -48.320503554169306, + -48.08122853844462, + -47.83923244013082, + -47.597854027344184, + -47.36296859969274, + -47.119848164271716, + -46.87179152215019, + -46.62784398461014, + -46.3865185398944, + -46.13758662508226, + -45.882506913206186, + -45.638350768579734, + -45.40202584931485, + -45.15286767081181, + -44.910602981557574, + -44.68683528605533, + -44.44814233486934, + -44.197869411745884, + -43.958137197320205, + -43.7175201928318, + -43.4659316411832, + -43.20799095845028, + -42.958669988965056, + -42.714975562290384, + -42.45226502704055, + -42.18397429266377, + -41.92212531907494, + -41.64374813716246, + -41.35245730080967, + -41.06432149035486, + -40.761467928484095, + -40.454594322204976, + -40.150273450651085, + -39.8354963456903, + -39.51103483054273, + -39.18926132755797, + -38.859115427458796, + -38.52251769893648, + -38.19902540483133, + -37.86218669225267, + -37.50934723842243, + -37.16581799944193, + -36.82223660091939, + -36.47303888228664, + -36.123074121166894, + -35.782204882836574, + -35.442409260937794, + -35.102056320517846, + -34.78629987780564, + -34.471883439756084, + -34.14746347121026, + -33.84605388489217, + -33.553708101905876, + -33.250563454522336, + -32.96460424785478, + -32.68641759692767, + -32.40952811630079, + -32.13647907841946, + -31.875496718347694, + -31.62040906978592, + -31.3692731568516, + -31.128653885588243, + -30.895484754384583, + -30.673021061602277, + -30.4564499076067, + -30.243102858337505, + -30.037110492090004, + -29.840072342414665, + -29.646596192747868, + -29.45520323757677, + -29.27322605585274, + -29.092874082076392, + -28.911547352606625, + -28.736870940234013, + -28.569567648391725, + -28.404831342714697, + -28.246515067547016, + -28.097330488214077, + -27.946981831102516, + -27.797469042688693, + -27.65824620050871, + -27.518667348820248, + -27.377270208336014, + -27.243466225191476, + -27.108572634055996, + -26.970365830634698, + -26.836377997977227, + -26.70039720608668, + -26.56266711343642, + -26.426854755029865, + -26.290415675286066, + -26.149065128213227, + -26.008000462419282, + -25.868400980166317, + -25.72419946234665, + -25.583747645422825, + -25.44457752090553, + -25.304002733636057, + -25.16402169134971, + -25.024850126056847, + -24.886710853992444, + -24.75077367734998, + -24.612215773194514, + -24.475219884022657, + -24.33902673184721, + -24.20212450543729, + -24.068034805090296, + -23.934360297636253, + -23.80302587282976, + -23.671934351942323, + -23.541636228121696, + -23.41150044334196, + -23.28447644210979, + -23.148561389650993, + -23.025401165750644, + -22.90954852775124, + -22.796729788349747, + -22.684810104200672, + -22.57605822622979, + -22.46792378791558, + -22.361120401782223, + -22.261551301205262, + -22.16353455580876, + -22.064981855155263, + -21.966975526131932, + -21.87150168248146, + -21.772723029819876, + -21.66942948268381, + -21.57520608187572, + -21.484674383312882, + -21.39262706670191, + -21.30338594612345, + -21.220403163730513, + -21.136668215983228, + -21.050642316748583, + -20.971266422092764, + -20.893704719430318, + -20.819095103855986, + -20.744356394558064, + -20.668110426183997, + -20.593575790907362, + -20.51913769651664, + -20.44398529835246, + -20.37019633369734, + -20.298325499929867, + -20.222271214700026, + -20.14286692746316, + -20.068865753641326, + -19.994722344735, + -19.921408639803634, + -19.848279216226192, + -19.77733712660915, + -19.708843692405654, + -19.639876490727172, + -19.571363877498232, + -19.500457183976046, + -19.42755711057401, + -19.353821194607914, + -19.277222947707447, + -19.198696655644017, + -19.117528511903302, + -19.032573119760347, + -18.948416948005327, + -18.8617110322628, + -18.76787339255267, + -18.671976260282626, + -18.572269788069963, + -18.465027304774082, + -18.35302653550974, + -18.23322606697562, + -18.109198924002346, + -17.97901877700008, + -17.8424357393355, + -17.699650572565574, + -17.546991433374593, + -17.388572184205284, + -17.22651462273087, + -17.05426537516221, + -16.875136298756033, + -16.696253347316734, + -16.505823524083787, + -16.312044657667926, + -16.121715044231273, + -15.918228961012252, + -15.713662817862469, + -15.520349065238337, + -15.314117388062634, + -15.101981873349454, + -14.901015296769765, + -14.692722931406239, + -14.47507862921057, + -14.264177112491339, + -14.051834892143809, + -13.828488818514472, + -13.610171729805234, + -13.392998744052473, + -13.14431326671235, + -12.919627335562577, + -12.694532034725631, + -12.46382295349791, + -12.231529775858943, + -12.002729134096949, + -11.770840032936897, + -11.545372404882938, + -11.319961015720121, + -11.094155532004333, + -10.870343496667056, + -10.649476851254931, + -10.432298095911957, + -10.216307836805138, + -10.001517123035553, + -9.790703986991717, + -9.580788044014426, + -9.370481683762973, + -9.16321502127577, + -8.961535107073983, + -8.759009779363042, + -8.557318526251073, + -8.355886338897776, + -8.153905396308813, + -7.951682623975034, + -7.75027452716869, + -7.5512766894352366, + -7.350702686380873, + -7.154769967425929, + -6.960011268123255, + -6.766274518296466, + -6.572764187927147, + -6.37801869391982, + -6.186990444418874, + -5.9952547142450285, + -5.8024834919889425, + -5.610113250581571, + -5.421421402540372, + -5.230705688924124, + -5.046683022305792, + -4.872716593779623, + -4.6951301980288855, + -4.511179217587073, + -4.332158443315197, + -4.158427781620778, + -3.9781560199225248, + -3.804624532867186, + -3.6385278102201566, + -3.4657404385840342, + -3.2985543996196376, + -3.143760131127154, + -2.990897550279068, + -2.853284988059845, + -2.728124024093954, + -2.6083346915128622, + -2.4995781005502993, + -2.401952119700064, + -2.3136789494691223, + -2.238971619951746, + -2.178121549782583, + -2.1308086498294294, + -2.0989279496834654, + -2.0811384019076833, + -2.0791462567281584, + -2.0941115217407553, + -2.127357150261988, + -2.1812564930478997, + -2.255843138179999, + -2.351285321542093, + -2.470949885461389, + -2.619418061092849, + -2.7877105289187534, + -2.975052188900128, + -3.186661968551388, + -3.4111859653665855, + -3.653318850390739, + -3.9126615935363556, + -4.1800908551126845, + -4.467339948394169, + -4.76213772483465, + -5.057079131423051, + -5.366431454037695, + -5.686629453715802, + -6.000219054441521, + -6.315818757227649, + -6.636045568944125, + -6.948817367054708, + -7.261130271977808, + -7.577241231467479, + -7.888166201756551, + -8.197871795054352, + -8.50808823068634, + -8.812478521149504, + -9.115105767042023, + -9.414291744215312, + -9.711444317474582, + -10.007502270019703, + -10.299487170473892, + -10.592538344186792, + -10.891153498312034, + -11.187520981503875, + -11.482022983094124, + -11.780385466250532, + -12.081554951147046, + -12.381652289882577, + -12.682140856885521, + -12.990862659199513, + -13.299742955942214, + -13.611617440190054, + -13.934305403208066, + -14.256927111462618, + -14.586113170885733, + -14.917746982119734, + -15.251391235544043, + -15.594129593187766, + -15.940412518175888, + -16.28431813443251, + -16.642108913478406, + -16.995968507434565, + -17.341447380945205, + -17.68624308935871, + -18.046967535864507, + -18.40739677468867, + -18.75423808809474, + -19.122883537464627, + -19.494562086443743, + -19.850467709603112, + -20.217375693444236, + -20.600471912205226, + -20.964774255638137, + -21.334574125398436, + -21.71635453927342, + -22.08321685900946, + -22.451794632413858, + -22.830192417191082, + -23.197398613331025, + -23.56588574348082, + -23.9411828469809, + -24.317713222784892, + -24.691862653126957, + -25.067342381426663, + -25.446518281020253, + -25.823575223588996, + -26.2028494895185, + -26.583118605135837, + -26.966954059517665, + -27.35077087822334, + -27.734676910582905, + -28.118658584974153, + -28.49849501805635, + -28.87466346669121, + -29.25502327108787, + -29.62282677381422, + -29.98810194848637, + -30.36409863788679, + -30.729381056802502, + -31.093416129346256, + -31.462290431553313, + -31.82712487808434, + -32.183114856408764, + -32.54546634304612, + -32.91461519520696, + -33.27211085075718, + -33.634882201375724, + -34.00379296069003, + -34.365538004602506, + -34.729305012827695, + -35.09053582815073, + -35.442051401313115, + -35.80147397332604, + -36.156478428285936, + -36.509943348927166, + -36.87540585234159, + -37.23901072209325, + -37.60450716369108, + -37.98091776311527, + -38.35826093083761, + -38.72586756756725, + -39.099598938468134, + -39.47442917873935, + -39.83691641541903, + -40.204012329161955, + -40.574518240633275, + -40.93284919279329, + -41.286376866942405, + -41.64670181348839, + -42.002212553600565, + -42.3507670717366, + -42.710882038521, + -43.06959132588439, + -43.421435885960044, + -43.7807719502002, + -44.13042410569125, + -44.47974523087207, + -44.83083555170508, + -45.17975029971333, + -45.526137563452274, + -45.875589050399924, + -46.22206441573103, + -46.569064381392046, + -46.91316488006124, + -47.252871500956985, + -47.59460254898223, + -47.93472843239765, + -48.27245989973975, + -48.61316785031234, + -48.95780558244229, + -49.30319506797552, + -49.64442423993319, + -49.98381841805986, + -50.32604717784858, + -50.66631640577856, + -50.99451219306471, + -51.32925243718483, + -51.66956924331703, + -51.998736515246826, + -52.3287767677094, + -52.667858253196044, + -53.01000201520464, + -53.36753357541763, + -53.73099880724471, + -54.095482201916454, + -54.451681350020955, + -54.80973824207325, + -55.17013669730103, + -55.524106384949896, + -55.8756427780986, + -56.22684818609178, + -56.568173790368725, + -56.89164961503797, + -57.210054830820965, + -57.527142887254534, + -57.84681566068695, + -58.1570298021242, + -58.463839291949, + -58.77464051965263, + -59.080840617557335, + -59.379611364919995, + -59.676143211914265, + -59.95963755435715, + -60.239955016541984, + -60.52007216423825, + -60.80366715994708, + -61.08194029433051, + -61.35281209135889, + -61.61752644560808, + -61.88090988763418, + -62.17169681900389, + -62.41810160592986, + -62.67171699605905, + -62.92700286480446, + -63.171093429782275, + -63.40996922133053, + -63.648671884466815, + -63.881023424434204, + -64.11161156677625, + -64.33274081599845, + -64.5484139784681, + -64.76538019545904, + -64.9762134188021, + -65.20740470348781, + -65.40373475442904, + -65.59756447844596, + -65.80873944884215, + -65.99163835726283, + -66.1642621214959, + -66.32658432151234, + -66.48021806007075, + -66.62097019240412, + -66.74939135202578, + -66.87171558534783, + -66.97876032382108, + -67.07211215897057, + -67.15447761613395, + -67.22104224477765, + -67.2741529478326, + -67.31308135995369, + -67.33853409042516, + -67.35370455384614, + -67.35460539923463, + -67.33965814033566, + -67.30902899364848, + -67.26404469862514, + -67.2078001782688, + -67.13824374251357, + -67.05358555831275, + -66.95710528627973, + -66.84693482330073, + -66.72309386440044, + -66.58992031612112, + -66.44468496860021, + -66.29607605628406, + -66.14247628542849, + -65.98439382705837, + -65.82077959233854, + -65.65819730422088, + -65.49709534143815, + -65.32840960359943, + -65.16012022721834, + -64.99283673717963, + -64.81902387014065, + -64.64236745794976, + -64.46675564704715, + -64.28902454181636, + -64.10627233332251, + -63.9245189968719, + -63.739895374679506, + -63.552789948909144, + -63.36255950167904, + -63.171640173338844, + -62.977289859578626, + -62.78089205260735, + -62.58559801107137, + -62.385988271432005, + -62.18664524096537, + -61.98937501139463, + -61.78934089875508, + -61.585390276013854, + -61.38256301127122, + -61.174410387424665, + -60.961679105118876, + -60.74845623489055, + -60.533814350995335, + -60.31690206841151, + -60.10096530121553, + -59.882987962738184, + -59.66207909539635, + -59.44231655656435, + -59.21722167022417, + -58.98802951173427, + -58.76205771207255, + -58.53723530877152, + -58.30593089017669, + -58.08199102346705, + -57.86537305766054, + -57.63728994130255, + -57.417093154127556, + -57.19783982644442, + -56.96809559375092, + -56.745401588101466, + -56.527400212898556, + -56.302759666354554, + -56.08121448281064, + -55.86963629564599, + -55.65688287615789, + -55.43401671720818, + -55.21534072806675, + -54.99965622634113, + -54.77522794277379, + -54.550565532425885, + -54.33252560499146, + -54.10737652938803, + -53.87845931182021, + -53.65398134783925, + -53.423887931321886, + -53.18849279457918, + -52.96138067614139, + -52.73633495991684, + -52.511767140085105, + -52.293945992345314, + -52.07486610261543, + -51.86308025501084, + -51.64710141416969, + -51.4285046739303, + -51.207936041427956, + -50.98112666724028, + -50.75759524732726, + -50.530524647357275, + -50.29611717891968, + -50.05853979565974, + -49.82148790953997, + -49.579964240667685, + -49.337494641666325, + -49.099491975439534, + -48.86009979050132, + -48.63134517445355, + -48.40446603377749, + -48.15076643023613, + -47.91352750034972, + -47.68602777604388, + -47.4606470698996, + -47.22363976674354, + -46.98960292127097, + -46.75719495268253, + -46.526941675163066, + -46.28666461660954, + -46.049829381604496, + -45.818999186174935, + -45.590109752221245, + -45.34974987102058, + -45.11066667581404, + -44.882333624532365, + -44.64177533187349, + -44.38690967518651, + -44.13745327703394, + -43.89167800014418, + -43.6346765855797, + -43.37106746872596, + -43.120170567487335, + -42.87514780108498, + -42.61189937175488, + -42.3444472783157, + -42.08199659401624, + -41.80345028441472, + -41.514048526503196, + -41.22651686968583, + -40.92555981169007, + -40.620679684726795, + -40.31489772743544, + -39.99940914928018, + -39.672597049306184, + -39.34823267667472, + -39.01389573865975, + -38.67111447699712, + -38.3417180772891, + -37.99869991142855, + -37.64066093269598, + -37.29183094018727, + -36.944293079592875, + -36.59102497377393, + -36.240332036105386, + -35.903320575532284, + -35.56515687007894, + -35.23029923900151, + -34.92269420854079, + -34.614763213636685, + -34.29892708616655, + -34.01210940746176, + -33.72535284716105, + -33.43346597443473, + -33.16069412386568, + -32.893568944678464, + -32.618449071051465, + -32.351954603625344, + -32.09740388948943, + -31.843371052113444, + -31.591763948974286, + -31.34816043213504, + -31.111338409549965, + -30.884495494391107, + -30.660592507642257, + -30.43986832898685, + -30.232311921187634, + -30.034832536263558, + -29.836882422200684, + -29.647805933322292, + -29.464161681522267, + -29.278335121264682, + -29.09703501070668, + -28.92426903504328, + -28.752792917301697, + -28.57969357253216, + -28.419568014755612, + -28.264885833516104, + -28.105049263851885, + -27.953223004070953, + -27.809238189289864, + -27.658637193310675, + -27.51092928949594, + -27.368130468550394, + -27.22228630232354, + -27.078306792113878, + -26.93666850605959, + -26.793280175586176, + -26.647893986261852, + -26.506868769852076, + -26.362206492738075, + -26.213789442975205, + -26.065900168254448, + -25.917723340433536, + -25.76644333607956, + -25.617643956734327, + -25.469910724251086, + -25.32087796709313, + -25.175205130333048, + -25.02923169482956, + -24.88858201198964, + -24.744508251223348, + -24.599890378823794, + -24.45614000552284, + -24.31174841617279, + -24.169613655949668, + -24.025554899223497, + -23.88392619934849, + -23.744495502265423, + -23.603328354286806, + -23.46321450382037, + -23.326411948354405, + -23.195045748240727, + -23.06177641883814, + -22.93121302256855, + -22.80615728391779, + -22.68289856818865, + -22.56415621191781, + -22.447976935762057, + -22.333911497446024, + -22.22793228306158, + -22.124055858807413, + -22.01852245969138, + -21.91624251358265, + -21.819172107053053, + -21.71383450795761, + -21.60779321331262, + -21.51057160389064, + -21.413535363175278, + -21.313846004326333, + -21.221748340988928, + -21.132012972374937, + -21.034140729436047, + -20.940020427879332, + -20.84880348241341, + -20.757562254644483, + -20.666031727970818, + -20.574703883397184, + -20.482314164119668, + -20.38841150430447, + -20.29335029544132, + -20.200865709932508, + -20.11118814062886, + -20.017291769726018, + -19.925399536618944, + -19.836184497930848, + -19.74708590869075, + -19.655944044723206, + -19.569553194373647, + -19.48507098397482, + -19.398711486551836, + -19.311303325113908, + -19.218866146453074, + -19.122777395645734, + -19.023391537244862, + -18.921415244909824, + -18.815847394177677, + -18.706022277471703, + -18.59663693758805, + -18.481433923823037, + -18.360150135271645, + -18.23514641089734, + -18.11650148987277, + -17.998273267605136, + -17.873623826276784, + -17.744163833825844, + -17.611702852220468, + -17.475956962954065, + -17.333844386239388, + -17.185815383871923, + -17.036862941336874, + -16.88543948550195, + -16.71508023090852, + -16.539657270708158, + -16.361528991185704, + -16.175309519380576, + -15.992975878334086, + -15.809859644432564, + -15.613581896340465, + -15.425261356390541, + -15.24047317195969, + -15.035840578274309, + -14.84650269721174, + -14.670204811131125, + -14.475004061424876, + -14.277627277337423, + -14.085006727846713, + -13.883696292373584, + -13.6760943845444, + -13.475113325042148, + -13.273863511387828, + -13.06761417066516, + -12.85571265687709, + -12.642648543091907, + -12.426046847370069, + -12.206646246010985, + -11.989051906594568, + -11.767208718892018, + -11.548567712995542, + -11.327059003885726, + -11.108289901515315, + -10.89483712861106, + -10.651596952608092, + -10.402437268413932, + -10.154654641580422, + -9.903982397230568, + -9.655577505284743, + -9.410451765981142, + -9.167145777093761, + -8.922410420456462, + -8.677950678765317, + -8.442878178667304, + -8.224549091238233, + -8.012244108645312, + -7.803377836943842, + -7.5992916658262235, + -7.3905951917552155, + -7.180661574635478, + -6.978917859540822, + -6.775344940328747, + -6.5707139938575985, + -6.369800614404133, + -6.16294179079084, + -5.955294821820569, + -5.7494593187793015, + -5.53920160205076, + -5.327561320334729, + -5.127665088135819, + -4.923495928887167, + -4.716923102821366, + -4.522760306960668, + -4.324679740492083, + -4.1349298904023835, + -3.9574413344576356, + -3.779892597270801, + -3.603220899922884, + -3.435553303622606, + -3.272291094171636, + -3.1088952667081533, + -2.9554104392099503, + -2.8056806730677, + -2.661758505747535, + -2.528416841861101, + -2.397186683423948, + -2.27517888973646, + -2.164584053072101, + -2.0615658461716873, + -1.9662971884350573, + -1.880455250454103, + -1.8062125159846574, + -1.7416060408752518, + -1.6883575246470086, + -1.6464221118954492, + -1.6161019349254333, + -1.598952334573434, + -1.5962239149252784, + -1.6088778091433145, + -1.6365442606316238, + -1.6817466445111175, + -1.7471709239700712, + -1.8304379054912938, + -1.9293634624295373, + -2.0557765917490776, + -2.2106226245532574, + -2.379448232045034, + -2.5642431789587876, + -2.770485432445319, + -2.984658334705341, + -3.2125196137460006, + -3.4560509928231054, + -3.7074353443265613, + -3.9715840351858427, + -4.261318533939599, + -4.524814887744411, + -4.801974158007338, + -5.086049453114505, + -5.3717500378122764, + -5.661918145237242, + -5.951433251046758, + -6.239212786650746, + -6.528322701353272, + -6.816435192112626, + -7.093688572207026, + -7.380497893714439, + -7.669322202948817, + -7.957540775263811, + -8.240558247756539, + -8.526147085033289, + -8.8094926928902, + -9.088670509354857, + -9.370651659488038, + -9.64811978897924, + -9.92364255285066, + -10.204372674559558, + -10.482359883396267, + -10.759157124840334, + -11.041214801482244, + -11.323874997705309, + -11.607346828275382, + -11.892572715090681, + -12.179160979136796, + -12.47143522716537, + -12.764556620481914, + -13.057292124639014, + -13.353928227169805, + -13.654242722158928, + -13.957140782146217, + -14.264127794442901, + -14.573264478042214, + -14.887181306730454, + -15.20907860518733, + -15.526499329387255, + -15.849211824472095, + -16.1837218211788, + -16.51242832777449, + -16.844908402576852, + -17.17604128776223, + -17.52185730380651, + -17.86363848644997, + -18.195146483620313, + -18.546750146028806, + -18.89652427529972, + -19.233422785718236, + -19.58001510150497, + -19.93962503028559, + -20.288186067872886, + -20.63609329753014, + -20.995772633526173, + -21.348988875525382, + -21.69659190603601, + -22.05360912650614, + -22.407707415262625, + -22.75424121571816, + -23.108273307268206, + -23.456800500960327, + -23.81209383041351, + -24.16824531704238, + -24.52007929669434, + -24.875864462154343, + -25.234450167280023, + -25.59249566257504, + -25.94840914936964, + -26.312619561636854, + -26.67225056181053, + -27.03042670043228, + -27.393002228059295, + -27.750145402605973, + -28.056119222177905, + -28.406556474352666, + -28.755716255364483, + -29.09454159159732, + -29.43946894773987, + -29.7860694420215, + -30.12272349580964, + -30.47029680231665, + -30.813465404750218, + -31.14857999508093, + -31.489246573379937, + -31.83343464508342, + -32.180106428740125, + -32.5181274336786, + -32.859070374984235, + -33.2098595497299, + -33.553603501315536, + -33.89825888738154, + -34.237887134155855, + -34.56665357821295, + -34.899538291821145, + -35.22762919113327, + -35.55176383432408, + -35.87623956047461, + -36.201292388560894, + -36.5207841866333, + -36.84531761093447, + -37.165586545365954, + -37.475087034752455, + -37.781226759500434, + -38.07977000686565, + -38.36833329366603, + -38.64527235782367, + -38.95212376368915, + -39.22191254840864, + -39.476727235435135, + -39.72943703874776, + -39.978926081087, + -40.21587167988429, + -40.44927422439422, + -40.67779494125739, + -40.89732908766, + -41.10698439454845, + -41.31169356236648, + -41.50934421279334, + -41.70070147597102, + -41.885089193801406, + -42.0623696183549, + -42.23242790307994, + -42.3975268493152, + -42.554139770735866, + -42.706005443434805, + -42.85119877794946, + -42.99326441183154, + -43.13043244029614, + -43.26536152526621, + -43.396386244919455, + -43.525692046003925, + -43.645579472730375, + -43.76058152027081, + -43.87314104004291, + -43.97967154488688, + -44.08086289472793, + -44.14137311119168, + -44.23687926633128, + -44.32673352167267, + -44.411268496572475, + -44.488690502888126, + -44.563273244780454, + -44.63163749042488, + -44.69691641427846, + -44.76166908041577, + -44.82416156640577, + -44.88329338084764, + -44.938867626121535, + -44.991684912561816, + -45.04115679053738, + -45.084374913991404, + -45.12328908023327, + -45.158615138261865, + -45.18908164182476, + -45.21715773717338, + -45.241843955700894, + -45.26225978074167, + -45.2812397669999, + -45.29912954405049, + -45.31495342648793, + -45.32909914608487, + -45.34129563765416, + -45.35123159378307, + -45.36002550833295, + -45.367429663092, + -45.373450790889294, + -45.37892491568039, + -45.38348451296948, + -45.38709576122565, + -45.38987259486149, + -45.39201223749755, + -45.39378837231216, + -45.39522130842233, + -45.39638158814507, + -45.39745662734854, + -45.39852915472637, + -45.39959623986342, + -45.40060689908244, + -45.4009172382536, + -45.40117706279354, + -45.40140460923213, + -45.40163051313254, + -45.4018254182305, + -45.40199439764787, + -45.40215732635067, + -45.40230463500697, + -45.402431118078326, + -45.40253451824544, + -45.40347533642653, + -45.4045009196911, + -45.40552077161054, + -45.406543967452464, + -45.40756378700323, + -45.40413971615596, + -45.405139804092926, + -45.40612640666854, + -45.4071109824521, + -45.40809951205476, + -45.40833313628159, + -45.40847880006112, + -45.40862585447327, + -45.40876694322939, + -45.40890264220816, + -45.40903780334378, + -45.409165907549266, + -45.40928989672995, + -45.40941046317816, + -45.40952603744997, + -45.409900894210814, + -45.41028605882007, + -45.41065411389173, + -45.41100188576985, + -45.41133135166681, + -45.4116413057063, + -45.411932882404905, + -45.41220387265481, + -45.41245398582146, + -45.41268703353797, + -45.412762329630134, + -45.41281787658448, + -45.412867679042265, + -45.41291217928194, + -45.41295127054923, + -45.41298450778652, + -45.41301281485027, + -45.4130352475178, + -45.41305221925071, + -45.41306374612866, + -45.41309799695799, + -45.413099614007116, + -45.41306145129848, + -45.41298698123988, + -45.41287818974583, + -45.41273854380772, + -45.4125637350658, + -45.41235486001969, + -45.412116954526766, + -45.411841286248574, + -45.41177303123894, + -45.411730072958946, + -45.4116850162911, + -45.411637649842746, + -45.41158623264187, + -45.4115310064707, + -45.41147140268202, + -45.411400723873406, + -45.41133185972873, + -45.41126062736721, + -45.41124040327362, + -45.41122029569904, + -45.411202513724646, + -45.41118365673517, + -45.41116019455542, + -45.411136594624296, + -45.411114116089095, + -45.41108971761326 + ], + "xaxis": "x", + "y": [ + 45.30065635922624, + 45.30124205437229, + 45.30193400755794, + 45.30265046529222, + 45.303148779253284, + 45.30355241309474, + 45.30436763201242, + 45.30509261653337, + 45.30601404203569, + 45.306807104326325, + 45.30755378323268, + 45.3084595882965, + 45.30939829386062, + 45.310166602457215, + 45.31092652284556, + 45.31184263146779, + 45.31269093838768, + 45.3135576272191, + 45.31414556023766, + 45.31474025871411, + 45.315360652914144, + 45.31605557841787, + 45.31664194012994, + 45.31730697535071, + 45.317890613009425, + 45.31851538132313, + 45.319104667384806, + 45.319609469309746, + 45.32007727141195, + 45.320588366136754, + 45.32098028619176, + 45.32137490542775, + 45.3218027568607, + 45.322205865337594, + 45.322700240070716, + 45.323155095223285, + 45.32354950746685, + 45.32388409129931, + 45.32428643617517, + 45.324720620555965, + 45.325017845529025, + 45.32541910711946, + 45.32594672200731, + 45.32622930094886, + 45.32648902428704, + 45.32678770395347, + 45.3271690999394, + 45.327448012220245, + 45.32774646010955, + 45.32811341332325, + 45.32845006239287, + 45.328854892378736, + 45.32925256389109, + 45.32956418507035, + 45.3298872080439, + 45.330206948208186, + 45.33058368137483, + 45.330898532790556, + 45.33111427854165, + 45.33145221325836, + 45.331679937477944, + 45.331990641078995, + 45.33232540056891, + 45.33257673594378, + 45.33284641602727, + 45.33315718987583, + 45.33344664049978, + 45.3341012148372, + 45.33498707332209, + 45.33579373852934, + 45.33656718067833, + 45.337500908131304, + 45.33850912961951, + 45.339404503757684, + 45.34025240597607, + 45.34119504699709, + 45.34225198723166, + 45.342975385452, + 45.34373765087828, + 45.34457141712823, + 45.34520671973865, + 45.3459652914809, + 45.3468944913872, + 45.34763544710913, + 45.34834020280647, + 45.34927992179219, + 45.350152851503346, + 45.35111122525774, + 45.35222320834729, + 45.353302302900545, + 45.35436280634842, + 45.355351858982594, + 45.35636600921168, + 45.357455718978656, + 45.35846858023884, + 45.359484076225094, + 45.36044248198346, + 45.36136159825707, + 45.36201903786664, + 45.362607731819615, + 45.36338685277333, + 45.3641318443685, + 45.36464971564687, + 45.365223329150176, + 45.365707271298, + 45.36644051487314, + 45.36727776004473, + 45.36789140847082, + 45.36862781694657, + 45.36935732697808, + 45.37007171229736, + 45.37093491702046, + 45.37173675951087, + 45.37239239138229, + 45.37311424419294, + 45.373893642577876, + 45.3747371400346, + 45.37557952956787, + 45.37617997045989, + 45.37676707820621, + 45.37752752882974, + 45.378324959843155, + 45.379006673860545, + 45.379605142930025, + 45.38034676039454, + 45.381162078651684, + 45.38194037878425, + 45.382527586665006, + 45.38295425868896, + 45.383675463602145, + 45.384443316479214, + 45.38494781921849, + 45.385568689166135, + 45.38640871629172, + 45.38693699116032, + 45.38780653396654, + 45.38835553836785, + 45.38885891735504, + 45.389154530847094, + 45.3898332667208, + 45.39055161538907, + 45.39088370020398, + 45.39142454130535, + 45.39206677437021, + 45.39272727527996, + 45.39351633320761, + 45.39419795561232, + 45.39471870690721, + 45.395467328939716, + 45.3965562787684, + 45.397069915703106, + 45.397574952630535, + 45.398274735105915, + 45.39902863728136, + 45.39967158075588, + 45.40013911515351, + 45.40050842848147, + 45.400799460712484, + 45.40146284138777, + 45.402032107932655, + 45.402394761819686, + 45.40278992019668, + 45.403446170751124, + 45.4042712768712, + 45.40541550236088, + 45.40594147911467, + 45.40645103491125, + 45.40736818440101, + 45.40877140195812, + 45.41076173057873, + 45.41213695009949, + 45.413260271751, + 45.41488591444327, + 45.41659224102789, + 45.41814194246695, + 45.41967431722273, + 45.42122421702902, + 45.42268026571465, + 45.42412747733741, + 45.4256593487682, + 45.427073589205904, + 45.42846906335461, + 45.42995491671941, + 45.431508868967576, + 45.43294949070756, + 45.43449795571096, + 45.436045896077744, + 45.43728591240688, + 45.43830163804861, + 45.439456093556636, + 45.44079735860992, + 45.44215224015999, + 45.44343120932357, + 45.444375003035816, + 45.445473560628116, + 45.44677225297314, + 45.44797609993086, + 45.44904543456109, + 45.44984120341096, + 45.45083794348416, + 45.452172561035646, + 45.45301431712992, + 45.45396179300497, + 45.45509309570998, + 45.45656097111738, + 45.458946085778955, + 45.460441648941014, + 45.462685237721104, + 45.46475459325952, + 45.466758441210935, + 45.472838358761976, + 45.484788098251514, + 45.50565315221926, + 45.53818408883597, + 45.5837903752552, + 45.64111097475223, + 45.704187093594705, + 45.76522474138874, + 45.815936438604616, + 45.856251405721544, + 45.88558165611295, + 45.907505296726434, + 45.923163721268345, + 45.93400337920274, + 45.94351507725153, + 45.955115010582595, + 45.972050141859455, + 45.999918103421045, + 46.03983971828734, + 46.09248763925718, + 46.158414662017144, + 46.23703736619839, + 46.327072657177446, + 46.42859592850699, + 46.541639647834934, + 46.667231187151486, + 46.8061285382561, + 46.957309432679715, + 47.118302808644636, + 47.28683927792054, + 47.463314218233265, + 47.64834611054478, + 47.83566798625701, + 48.023172248939865, + 48.20956769266305, + 48.39662784137522, + 48.57955770996967, + 48.75555971609832, + 48.92948050858245, + 49.093625856117846, + 49.24649494120659, + 49.393815701230196, + 49.52991360445883, + 49.65438721604678, + 49.76523014549145, + 49.8657542055838, + 49.95354106003255, + 50.02452950595315, + 50.081488363030104, + 50.125347509952476, + 50.153806861592756, + 50.16461030378974, + 50.15624106610237, + 50.12911427774848, + 50.083836488523914, + 50.02017330276769, + 49.93702726391105, + 49.836609476753964, + 49.723397604679526, + 49.58920810115112, + 49.43606536365915, + 49.28353242203603, + 49.10416673639415, + 48.90626942969818, + 48.70003420605666, + 48.478655587866776, + 48.25004021935163, + 48.00576350908627, + 47.74332031709656, + 47.48982967303377, + 47.22313318201167, + 46.93662462568459, + 46.659443183046186, + 46.37493857465736, + 46.07180596975934, + 45.76914529901862, + 45.464300795957406, + 45.15069647264574, + 44.82778404522427, + 44.504214552495256, + 44.175672589939694, + 43.84039272948602, + 43.50903893858218, + 43.174720354105986, + 42.826041969874936, + 42.48349316746876, + 42.139016908716485, + 41.786461125674016, + 41.41556209947847, + 41.064403391404404, + 40.70602954504817, + 40.32807491275439, + 39.959394441926364, + 39.59687493350447, + 39.2248773385048, + 38.84348071036869, + 38.48260607577791, + 38.11372693814414, + 37.73017677472638, + 37.359222962492275, + 36.99840874913466, + 36.624498682198684, + 36.256436845624044, + 35.89994905495483, + 35.532620226739965, + 35.167989300368006, + 34.80099922747861, + 34.44037918294277, + 34.07418353586902, + 33.70399761498599, + 33.34426143382196, + 32.98334220062625, + 32.62887100293793, + 32.2600001708878, + 31.91429267155726, + 31.567862840036128, + 31.20477116469836, + 30.852135211188124, + 30.5041019515858, + 30.15437616855528, + 29.795493418235086, + 29.439773034543396, + 29.092867442925655, + 28.74074003409194, + 28.389689200180523, + 28.04108618783154, + 27.694376883639453, + 27.343390182357584, + 26.99807895619347, + 26.654587155780035, + 26.298581014537678, + 25.96005462606094, + 25.62522181335514, + 25.276650295123694, + 24.933256141845575, + 24.600249322179046, + 24.25324649044085, + 23.912317605463883, + 23.578091177360236, + 23.238691252751593, + 22.913128686706337, + 22.587558377602974, + 22.260670840822804, + 21.945386829709058, + 21.632131528260782, + 21.309847674618073, + 21.00083876508177, + 20.692869915956067, + 20.374059350389338, + 20.069267127553385, + 19.76556702024636, + 19.45555588051149, + 19.152908458260175, + 18.853670810727607, + 18.54332865279199, + 18.233151029185514, + 17.931059701218214, + 17.621971149590077, + 17.314731946458686, + 17.01364025257034, + 16.714380610703095, + 16.411528787587354, + 16.109089073258097, + 15.805912897581694, + 15.502517862735598, + 15.194229158304234, + 14.887743611100595, + 14.580535534255668, + 14.267081270570092, + 13.957935779247167, + 13.645837748807345, + 13.329297699110638, + 13.017084780374136, + 12.708165383851261, + 12.391605224475741, + 12.076588495210615, + 11.771657252356125, + 11.428856918233718, + 11.115085510666857, + 10.811547455719008, + 10.508092171074775, + 10.202100623559728, + 9.902107827661291, + 9.606825690276896, + 9.310224501023288, + 9.016231326587475, + 8.724466119535888, + 8.429646961681588, + 8.141380509861708, + 7.847680658910126, + 7.548387398439552, + 7.2557844820569155, + 6.958131601985776, + 6.648654056656934, + 6.342844172003127, + 6.0429672685303055, + 5.724886961082017, + 5.414892843547536, + 5.1155752002257335, + 4.793823666094012, + 4.475778975369742, + 4.167811483842318, + 3.8436931817113065, + 3.5206956344703415, + 3.205702010388127, + 2.88152243930825, + 2.570958053977248, + 2.260990016202482, + 1.9440542354900103, + 1.6457803016981958, + 1.344874972001644, + 1.0359357387089771, + 0.7431647760945245, + 0.4470713207390033, + 0.14182208713334643, + -0.14927943330410953, + -0.44086000184198426, + -0.7319241037042714, + -1.01638476911375, + -1.293997171603349, + -1.5693275790496173, + -1.815731952007402, + -2.0444292003357356, + -2.274251979854682, + -2.481389733636101, + -2.6646039940176514, + -2.838972470329237, + -2.9884123218117042, + -3.11023864851491, + -3.21472992012814, + -3.2974032758040073, + -3.355915431645531, + -3.3913835382922035, + -3.4048123854802816, + -3.39803858859178, + -3.3710280857549937, + -3.325108369200027, + -3.2611074439519006, + -3.1802804961620454, + -3.084287327127412, + -2.974114920253723, + -2.8483406607594324, + -2.7114471000777796, + -2.5612456990820673, + -2.3972975979451885, + -2.2240683506681793, + -2.0399820767784145, + -1.8454583819492636, + -1.6462925311363519, + -1.4426765317417052, + -1.2355122652243018, + -1.0266833604456562, + -0.8203287147004672, + -0.6151670102511687, + -0.41186524360498994, + -0.21142850065476018, + -0.014788827337959793, + 0.17912239878214273, + 0.3695248181338274, + 0.5497528631733597, + 0.7202484735149064, + 0.8870487712695617, + 1.044087887633294, + 1.1923079643531156, + 1.3375689256931582, + 1.478910789980186, + 1.6021343359196878, + 1.7198753508456497, + 1.8343477615511954, + 1.935784009373079, + 2.034653048186079, + 2.1293593031519724, + 2.209958860284805, + 2.285664476839112, + 2.3577403536442203, + 2.422468884831358, + 2.4857085427293923, + 2.545101229466147, + 2.5939810802864383, + 2.63730522737304, + 2.6793949775599177, + 2.716254932991565, + 2.750329503332266, + 2.7847351056235126, + 2.8149824253505944, + 2.840858525848436, + 2.866832856374559, + 2.89238014379388, + 2.9184729701399044, + 2.9480580670122456, + 2.979556576918602, + 3.0146754054240814, + 3.054826423352543, + 3.096675035659461, + 3.139990556703625, + 3.184667815612314, + 3.2292327942258465, + 3.2743714193132347, + 3.32226300056172, + 3.371700814595118, + 3.424566973691934, + 3.4830876993069837, + 3.5485035527291973, + 3.619961418131344, + 3.6975441393227633, + 3.781726261926542, + 3.8735972866285193, + 3.9715187448803877, + 4.077521685436444, + 4.190621417091774, + 4.311281690122221, + 4.437588575041153, + 4.569883254801034, + 4.7074473641466374, + 4.8516013531248126, + 5.002459568735386, + 5.157973202353776, + 5.32030188422588, + 5.489787275327103, + 5.664176150901903, + 5.841915989421802, + 6.025743569696181, + 6.215099677021708, + 6.4056394329494495, + 6.599747195124951, + 6.796871578543679, + 7.000876578665662, + 7.210515601989129, + 7.414008863247476, + 7.619599845586016, + 7.833515395655514, + 8.04457660123133, + 8.25985916135352, + 8.481300158080627, + 8.703533225016418, + 8.922489518320095, + 9.142777567531969, + 9.366623691436404, + 9.591344715497799, + 9.819215722954493, + 10.057646146450304, + 10.29427445882235, + 10.526748962602603, + 10.7706706591195, + 11.015257917073688, + 11.253385326295108, + 11.506395584240968, + 11.758353643179158, + 12.005630872525707, + 12.267258033357656, + 12.530909801550804, + 12.786219345300823, + 13.057092854792527, + 13.32675583986419, + 13.587365823600374, + 13.862357481099695, + 14.134623496051223, + 14.394828503880833, + 14.66369232146545, + 14.921908316794884, + 15.18178455458598, + 15.44913444616254, + 15.710475691890215, + 15.966310648688792, + 16.220032134403954, + 16.464998920790475, + 16.71058892492495, + 16.964207762201045, + 17.209136022293695, + 17.461327162857966, + 17.716838938047527, + 17.960616386513582, + 18.210827985277778, + 18.46779453122991, + 18.71105572938212, + 18.97297690737294, + 19.241911854540295, + 19.500014477003784, + 19.767568083880867, + 20.039002069566862, + 20.306012026346778, + 20.564985494761704, + 20.83235000158009, + 21.09764463708101, + 21.35937942682989, + 21.6232834573419, + 21.876158423324846, + 22.128987388840248, + 22.384981603051383, + 22.63557889065227, + 22.886990330602675, + 23.1340420780325, + 23.381549899588997, + 23.63432836002035, + 23.884484257210527, + 24.141551206805165, + 24.410597473604646, + 24.668300827657163, + 24.941057687774595, + 25.235389680352345, + 25.52246640034228, + 25.820733265212404, + 26.146261319287532, + 26.459876076891092, + 26.758775995034405, + 27.083479341448516, + 27.416728140297526, + 27.737446902315366, + 28.074198688632993, + 28.424149427074962, + 28.76398193325348, + 29.105992458154173, + 29.455341264256788, + 29.804127060503248, + 30.15726674063583, + 30.5131652108559, + 30.873861402715676, + 31.24506793191564, + 31.61839907168172, + 31.987082870538966, + 32.36529463382383, + 32.74973439340146, + 33.129260166990576, + 33.507090698883566, + 33.89352452630285, + 34.274803994394475, + 34.62557077132737, + 34.97840526938305, + 35.3422523640149, + 35.67228967036853, + 36.00004151381821, + 36.351047157615085, + 36.67870026120099, + 36.99436114487999, + 37.3308200786596, + 37.65797535807802, + 37.98028943410984, + 38.319243502578786, + 38.65300220649891, + 38.980404389671115, + 39.31352333739406, + 39.6443216231381, + 39.9735661614373, + 40.30608505693576, + 40.64112651950941, + 40.97431678441473, + 41.3094050207714, + 41.639694499983264, + 41.970172151397996, + 42.298154479387556, + 42.63252777091632, + 42.961975135376704, + 43.29487399499332, + 43.63078203673385, + 43.960324113806855, + 44.280587363603054, + 44.61022584117488, + 44.91814656157118, + 45.21551535146278, + 45.52393224649588, + 45.816086558073636, + 46.105626265729946, + 46.40708217324153, + 46.695976393416366, + 46.980681696970024, + 47.302051147569, + 47.59138340375597, + 47.8763230782806, + 48.16904590905884, + 48.46962976416081, + 48.76150376781474, + 49.051398520258836, + 49.347994812593775, + 49.6389114935602, + 49.923916276300005, + 50.21596991028357, + 50.50779320964663, + 50.790358821381034, + 51.07948178153422, + 51.37078619788843, + 51.65129645766725, + 51.93142224600338, + 52.209249332542704, + 52.474303990615155, + 52.74191204645305, + 53.004554558054494, + 53.26258271656714, + 53.52397070358046, + 53.786443744982456, + 54.05305086160046, + 54.316936240388834, + 54.58382060666164, + 54.85971421690486, + 55.139833110999646, + 55.422915124787856, + 55.7068940366129, + 55.9926321361938, + 56.28063917988319, + 56.56640100698341, + 56.845524203622524, + 57.1221086914137, + 57.40251863786673, + 57.68043863229484, + 57.95409464108825, + 58.2316791416228, + 58.51290285072393, + 58.794513578866905, + 59.07724275831911, + 59.358543609297875, + 59.65469311588764, + 59.94482815433133, + 60.22598664665075, + 60.52269524291083, + 60.819740480228326, + 61.09034466522357, + 61.37502407232589, + 61.66738107404477, + 61.92984680065179, + 62.18397791345223, + 62.44359295123307, + 62.680022794538566, + 62.88649249020206, + 63.105048476809586, + 63.317588877884944, + 63.502937534715095, + 63.686564391128805, + 63.86294009574444, + 64.0186306609059, + 64.16673301835688, + 64.30334395151725, + 64.42663562399538, + 64.53651397526073, + 64.63161466900597, + 64.70816300520418, + 64.77281153676934, + 64.82311159398937, + 64.86002769221417, + 64.8892987824187, + 64.9063208822544, + 64.90763321466169, + 64.89550866877087, + 64.8720369381132, + 64.83709216836066, + 64.78896813125499, + 64.72646694884126, + 64.65180600879873, + 64.56676316067332, + 64.46485226581336, + 64.34873269231545, + 64.22535653966474, + 64.0865173517071, + 63.93315970541152, + 63.77116117350705, + 63.59061669908269, + 63.3980218639613, + 63.19137370472949, + 62.96576109058118, + 62.728687979886935, + 62.477231154300924, + 62.21413222611236, + 61.94292872182762, + 61.66377917267032, + 61.377245128534945, + 61.08228224831478, + 60.786882322751936, + 60.48741490343282, + 60.18441305967586, + 59.883376067533106, + 59.58182549475365, + 59.27927534524969, + 58.971913215094816, + 58.656278852646786, + 58.33845528784237, + 58.030759945024286, + 57.71592718520604, + 57.39280002548415, + 57.08655392095955, + 56.77130031276584, + 56.4490628529945, + 56.13098410273622, + 55.822499964849065, + 55.50558258918016, + 55.19053868704912, + 54.88597372525349, + 54.570850756057624, + 54.253975835514126, + 53.94573842453985, + 53.63051043994362, + 53.31091129134796, + 52.993804550829, + 52.67858715138811, + 52.35427618518613, + 52.029122339850936, + 51.709388406735044, + 51.37997236972773, + 51.05126997517338, + 50.72869087137583, + 50.40673533921882, + 50.087407082488724, + 49.76918964500379, + 49.44359393141434, + 49.12774868708828, + 48.80936249396381, + 48.482471009638466, + 48.16154030687558, + 47.83975326351295, + 47.51791137988673, + 47.198017846154094, + 46.87762889692828, + 46.55362901235234, + 46.23150805396997, + 45.91080859581661, + 45.58201851787454, + 45.258302837766585, + 44.93505012540116, + 44.605107474019384, + 44.277477323798685, + 43.955016245727585, + 43.62820678836799, + 43.30942836303323, + 42.98600470233816, + 42.65115161151966, + 42.32784807714479, + 42.001137866568904, + 41.66248222985032, + 41.315973962492706, + 40.9869940930252, + 40.64246110903902, + 40.284759235795846, + 39.93720250367378, + 39.58972856230055, + 39.2391737060455, + 38.881433406578324, + 38.53986266002814, + 38.196382868875865, + 37.844457631283646, + 37.50221303275838, + 37.15913681583543, + 36.81304313884099, + 36.465684266328594, + 36.12490782074833, + 35.78431103636534, + 35.43850879148098, + 35.09626933188624, + 34.74763222137042, + 34.40528131864121, + 34.060431499684526, + 33.709413402553835, + 33.374583192523616, + 33.03659172933373, + 32.69472681307677, + 32.352871024716386, + 32.02382317548636, + 31.692640813105342, + 31.351825875548375, + 30.981931236957166, + 30.649926729282427, + 30.308778463062776, + 29.958820918627342, + 29.612163221650928, + 29.27205607700253, + 28.928528663361543, + 28.587289296537417, + 28.252425888312267, + 27.913438454958953, + 27.575693933844544, + 27.243387472339872, + 26.920450503404965, + 26.593324956162824, + 26.268851354846873, + 25.961662584533766, + 25.648983227043615, + 25.330841804496647, + 25.024094124192995, + 24.716810376168564, + 24.397226452860163, + 24.09106675332628, + 23.792899362983363, + 23.47963434262606, + 23.178586035797643, + 22.89036554768616, + 22.588507931956737, + 22.29472900096752, + 22.016067916560125, + 21.72151298888659, + 21.424651910482837, + 21.14884867245253, + 20.831723694209558, + 20.52882144534976, + 20.251926737154022, + 19.969856166318113, + 19.674857734782908, + 19.394901622127772, + 19.115589157764912, + 18.825826774014903, + 18.540882879612287, + 18.25670131857264, + 17.966870319203498, + 17.679928870075006, + 17.395216912925186, + 17.111623677846932, + 16.83327633895371, + 16.553289938794, + 16.272858561236113, + 15.99475508423096, + 15.714673362984552, + 15.434100350148046, + 15.153104332282986, + 14.878737149287456, + 14.59569361776559, + 14.314400851082752, + 14.037187911041322, + 13.752669178415237, + 13.469130969104476, + 13.186252551954205, + 12.899769263558902, + 12.608741436899221, + 12.320483607054424, + 12.033333952864483, + 11.743212867818992, + 11.452400836491277, + 11.163832061493704, + 10.87428313586358, + 10.581298211797025, + 10.289585233925095, + 10.003460194212307, + 9.714023874585337, + 9.431185913029365, + 9.152679304249583, + 8.876103763020335, + 8.60106861338303, + 8.330234351226155, + 8.062399392062407, + 7.789153086387231, + 7.518843322043353, + 7.244486191065995, + 6.972379311833277, + 6.702715121419891, + 6.397657078386688, + 6.119546934832735, + 5.850730903770537, + 5.5778769180651375, + 5.295731491782003, + 5.010585368516014, + 4.732149618759102, + 4.446223579405364, + 4.154052964445426, + 3.8707771410600462, + 3.578668000483556, + 3.283248346031772, + 3.009831509551147, + 2.7294772185752003, + 2.439359255974464, + 2.165069821066807, + 1.8922510554894079, + 1.611875081193677, + 1.3398181623517504, + 1.06716635467668, + 0.7930661827227454, + 0.5213396865892294, + 0.24900753508883527, + -0.02927115132790787, + -0.30631210960349814, + -0.5903008760638585, + -0.8740797373657807, + -1.1557809442866598, + -1.43457206073578, + -1.7109215467472685, + -1.9875241679544764, + -2.243214593134427, + -2.4798193150215364, + -2.716157479213643, + -2.9366919052011538, + -3.1346932650920354, + -3.323224561917978, + -3.4968741879849956, + -3.6479032785146766, + -3.782963066035635, + -3.8967752603163106, + -3.986063585407675, + -4.05570221046406, + -4.102947137383738, + -4.125668714420832, + -4.126023698323803, + -4.106316079237101, + -4.069879750356919, + -4.013507029812954, + -3.9388505583439106, + -3.8521203595204794, + -3.748246727545021, + -3.628567392124721, + -3.4997206517553705, + -3.356917373552429, + -3.2018029807575537, + -3.037824888284141, + -2.8652333256444793, + -2.684198629368167, + -2.4974028485933752, + -2.306800582453958, + -2.1105985560625293, + -1.9120259660815933, + -1.7137098506786699, + -1.5177394157354076, + -1.3267271944641734, + -1.138814197084552, + -0.9579866704342876, + -0.781498247925053, + -0.6072680113377534, + -0.4379243557395881, + -0.2726371571455989, + -0.11020970900653274, + 0.04568686663126684, + 0.19655028292906002, + 0.33990140254323953, + 0.4746665422254469, + 0.6091185402850376, + 0.7351145318647105, + 0.8559473234029513, + 0.9709084656969998, + 1.0877879580709697, + 1.1983330275124655, + 1.2949339891721148, + 1.3926104580928997, + 1.4861922949559723, + 1.5718667115192695, + 1.6556812775728242, + 1.737533576764661, + 1.8103433203707366, + 1.8768386752259385, + 1.94169337644323, + 2.0008644539328135, + 2.058011016547057, + 2.113376084090168, + 2.161917395027421, + 2.206327915063, + 2.2470241984415624, + 2.2844376943959053, + 2.3167971720485934, + 2.3463759940634703, + 2.3732528454295974, + 2.3977015645263875, + 2.419424716496846, + 2.4391257210056154, + 2.4581661721343564, + 2.4760834025730842, + 2.4948209717097534, + 2.5152157292469712, + 2.537726956003219, + 2.563045175832419, + 2.5923259488744463, + 2.6235803001400018, + 2.6547104375531503, + 2.6882468672845015, + 2.722868273296302, + 2.7577078965335047, + 2.7947002260250353, + 2.833303583593914, + 2.8722764445081204, + 2.913744128947523, + 2.9582002371524707, + 3.0043149595698164, + 3.054572368166239, + 3.109635545671697, + 3.169127382480864, + 3.233982551457134, + 3.305067027519865, + 3.381649352350055, + 3.462579502841436, + 3.548881480285042, + 3.6387510749962053, + 3.73422127659939, + 3.837246845655587, + 3.9438950955057437, + 4.055148513043715, + 4.175796563345039, + 4.299857145929124, + 4.427920028760938, + 4.564200864570259, + 4.706662589097586, + 4.845795057310785, + 4.990574450932911, + 5.141452335010133, + 5.2924638623179305, + 5.445462298161452, + 5.6047012143928665, + 5.7656655200045765, + 5.923161432890435, + 6.086715257279508, + 6.253030729176701, + 6.418953847168114, + 6.591517594838, + 6.767784638605135, + 6.939741504314393, + 7.117000458113039, + 7.300423866865699, + 7.485421635925301, + 7.672668670726102, + 7.867973639907908, + 8.064928215884052, + 8.261454641495268, + 8.456323783215645, + 8.653707009385208, + 8.853109363506482, + 9.052095269379818, + 9.251010569352694, + 9.455883497431737, + 9.664548108832998, + 9.86867958138228, + 10.072192968002337, + 10.283140091676069, + 10.493257173499961, + 10.697821767737025, + 10.911771274136825, + 11.127694008643468, + 11.338431291874018, + 11.555933018342088, + 11.77915157388375, + 11.999906705493894, + 12.22073278204977, + 12.450875494394165, + 12.677462465857127, + 12.89944884663247, + 13.131417460797836, + 13.365342006427221, + 13.59036767840702, + 13.819426994262399, + 14.052295347324945, + 14.27679261328338, + 14.505431181488275, + 14.74058925982937, + 14.96915917570018, + 15.197299561430436, + 15.424908144460323, + 15.64782793185962, + 15.86418333806608, + 16.081449408336358, + 16.302276766260317, + 16.520593847042477, + 16.73952000038133, + 16.966763792036904, + 17.187508805665516, + 17.402097786691098, + 17.62437166898507, + 17.849784911959865, + 18.067067204532794, + 18.288807239623967, + 18.519673597989403, + 18.741453074728906, + 18.960201003842627, + 19.187528970108, + 19.40603607314136, + 19.6295802469014, + 19.851471769792692, + 20.063892902360426, + 20.28284044547063, + 20.49506265498528, + 20.70484641453179, + 20.931845452406186, + 21.14380278638029, + 21.351556856093342, + 21.573710934054002, + 21.797372150281266, + 22.025143555702744, + 22.26162635728783, + 22.49555917009494, + 22.74108251568682, + 22.997379529708848, + 23.259893410872117, + 23.521739121800756, + 23.803526124778454, + 24.085909301716974, + 24.364358564513505, + 24.66952019053571, + 24.984296990340198, + 25.288991394864112, + 25.621334529400418, + 25.94923672774287, + 26.231807525533302, + 26.528803991800288, + 26.83558524180938, + 27.13251200071202, + 27.43363846792482, + 27.745063153098563, + 28.050495350322393, + 28.352985451389543, + 28.657802793457517, + 28.965318681392027, + 29.290414036766883, + 29.613305450592385, + 29.935893118228776, + 30.261392835932636, + 30.590872510173433, + 30.923368052022465, + 31.244988057728378, + 31.575803204180385, + 31.911151803658623, + 32.25018089339783, + 32.59480910699614, + 32.937627402566825, + 33.28794588439994, + 33.6279347381299, + 33.96190557936996, + 34.294610113578194, + 34.627339545147244, + 34.95052864881934, + 35.281680386612635, + 35.6130857622301, + 35.9290746696189, + 36.24852163663968, + 36.57422177102859, + 36.899631083515544, + 37.22588313597998, + 37.559215504866344, + 37.892777877735455, + 38.22551890017168, + 38.560538554247486, + 38.89158251079862, + 39.23302529975385, + 39.57820606213428, + 39.91884227485884, + 40.26306141062812, + 40.60937815640735, + 40.951666628629305, + 41.29318404025987, + 41.63657129801098, + 41.975595254807295, + 42.31734029198733, + 42.65353646987068, + 42.98831713780912, + 43.32888139048449, + 43.6655586896088, + 43.99368605046921, + 44.33142138232008, + 44.66215604052806, + 44.98490126596221, + 45.320020446725124, + 45.6423555683974, + 45.957346640807316, + 46.28348996674704, + 46.60040390395959, + 46.911337690664304, + 47.23368236810229, + 47.548562973696995, + 47.84594364853935, + 48.161893501649466, + 48.47966621412509, + 48.78104052066766, + 49.09550960412552, + 49.41365828497713, + 49.71392945487082, + 50.01885034497675, + 50.32869305739655, + 50.62525095914379, + 50.92564264333343, + 51.23246453113018, + 51.524620784434596, + 51.81150141540853, + 52.10062067544182, + 52.380974291115784, + 52.65774195484237, + 52.92921084593508, + 53.204554760149584, + 53.4783903496776, + 53.75658240531161, + 54.0366105377843, + 54.31404765242369, + 54.59787519797507, + 54.88992129520901, + 55.18350011196275, + 55.47300473071834, + 55.77115199960749, + 56.06407678115986, + 56.34957599191736, + 56.640500017847735, + 56.918919577569525, + 57.18651191609925, + 57.464796277014536, + 57.7489816803978, + 58.01228352171607, + 58.27448447590476, + 58.549261764877976, + 58.81966354190635, + 59.0810190597338, + 59.34550804391031, + 59.621904510986425, + 59.89243981498452, + 60.15054633062333, + 60.4146359330284, + 60.690927131716464, + 60.95117336000782, + 61.211926574230255, + 61.49175187232582, + 61.75862622972725, + 62.002538063312464, + 62.25130083747442, + 62.494978305394376, + 62.70413113581055, + 62.90684831747295, + 63.11819867225374, + 63.30876360557971, + 63.48255241562824, + 63.65652732608937, + 63.81188357849256, + 63.952223161903035, + 64.08384835040273, + 64.20228694366901, + 64.3058778678585, + 64.39431052265087, + 64.46181476053984, + 64.5138221164496, + 64.55037676437557, + 64.57202236022114, + 64.58122002040913, + 64.57357281798858, + 64.54846525516317, + 64.5066277721625, + 64.44855207527613, + 64.37828463181846, + 64.29050019600527, + 64.1837097623907, + 64.06820377617622, + 63.9359705388783, + 63.78154042851915, + 63.62139498865959, + 63.443128478639025, + 63.246074115359114, + 63.044716788422015, + 62.830994017778295, + 62.59475675268386, + 62.35145991094609, + 62.102827126482275, + 61.836469170239184, + 61.56052117465298, + 61.28350782294512, + 61.00211248774538, + 60.71319822773284, + 60.415230565970255, + 60.11068377896734, + 59.80297097306492, + 59.487662636397125, + 59.166261821624495, + 58.84489002582075, + 58.51493489560721, + 58.17376645681918, + 57.83577562934087, + 57.50159619448542, + 57.150206968316674, + 56.8098775057813, + 56.480478255854734, + 56.1337826236115, + 55.79212702638441, + 55.46529845716434, + 55.122389420158356, + 54.78620751035339, + 54.45562989734567, + 54.112611391115095, + 53.77183859158204, + 53.4346064857942, + 53.08624326957375, + 52.73705271285156, + 52.396702556113986, + 52.04201990757297, + 51.68831138692674, + 51.341335793576874, + 50.98392454616007, + 50.63082960851566, + 50.2845751874734, + 49.935157892616516, + 49.585502316420516, + 49.23450291787667, + 48.89043106822508, + 48.53718118140634, + 48.183106174437874, + 47.83290180765036, + 47.48342297947092, + 47.13267354616005, + 46.779444550330155, + 46.4290543350682, + 46.079316639733236, + 45.72484143705796, + 45.33074088809445, + 44.97905968459101, + 44.61934566425871, + 44.25404515186802, + 43.89440567514032, + 43.52880276659065, + 43.16432513341936, + 42.80109993962998, + 42.426443112120424, + 42.05458400887959, + 41.68869886475714, + 41.30406811282386, + 40.91732412143384, + 40.55184686255979, + 40.162958205166085, + 39.77007580965446, + 39.39551996242697, + 39.011968802918155, + 38.61155593306669, + 38.237991349744036, + 37.862498253284045, + 37.47087662406409, + 37.09438850324539, + 36.725218789343096, + 36.34145903615912, + 35.962788171007944, + 35.59742720821883, + 35.22314046889542, + 34.845470615037904, + 34.47077348983268, + 34.10078165992983, + 33.720052942159576, + 33.34048087728195, + 32.970353952554824, + 32.59422500184866, + 32.21704072487308, + 31.8485029905479, + 31.48400049576467, + 31.112191483180407, + 30.739825730803798, + 30.37312850259972, + 29.99985756267379, + 29.617963798496188, + 29.2467291159945, + 28.875955840739014, + 28.502309183512967, + 28.135940856742728, + 27.766807861272547, + 27.39749678995351, + 27.034658109848234, + 26.677225788696784, + 26.318530075613232, + 25.967085718463615, + 25.627262362385075, + 25.277091515812, + 24.93020802031985, + 24.58994780320747, + 24.23465416028736, + 23.88850758412448, + 23.548554735994852, + 23.202723948728426, + 22.86842250365915, + 22.538816273231063, + 22.205466008741364, + 21.883740759053744, + 21.559288904454927, + 21.224337741051304, + 20.91460446251487, + 20.59968776191551, + 20.264677799873517, + 19.954420542202634, + 19.642669745042046, + 19.314741441083697, + 18.998021408460353, + 18.688882635627362, + 18.36340240744392, + 18.052109997714656, + 17.76556782272875, + 17.47181725701148, + 17.182570274630805, + 16.904251916910184, + 16.62114939379875, + 16.33708929371887, + 16.059215464919923, + 15.775224779471078, + 15.492091854500226, + 15.202113729578942, + 14.903877747934894, + 14.602521465534105, + 14.301146836855489, + 14.004371042662179, + 13.704510927008922, + 13.404502075527272, + 13.106773714107563, + 12.810173788665038, + 12.50792605386516, + 12.204309055755756, + 11.901595980951306, + 11.597353718742553, + 11.29308072858805, + 10.994020537350567, + 10.699498462628563, + 10.400942490118092, + 10.107273838156436, + 9.819291714596416, + 9.53246740519458, + 9.246960777908608, + 8.967798298246185, + 8.687201947034684, + 8.406735056830456, + 8.129034611020845, + 7.84208767826444, + 7.552771676692099, + 7.270788312748914, + 6.948883786893604, + 6.647360760685175, + 6.352150079582912, + 6.054043535841086, + 5.747904281550427, + 5.450203041801112, + 5.153571962327716, + 4.845539280114005, + 4.547248750200525, + 4.254694092048906, + 3.950410623554936, + 3.653454874272269, + 3.361836168080324, + 3.0649956197929056, + 2.784864028950497, + 2.5030941040493673, + 2.2184336888929423, + 1.9506677645006707, + 1.6831380259049515, + 1.4099914836804952, + 1.148007225055149, + 0.8825854868148396, + 0.6091438528935407, + 0.3329631741793019, + 0.047078976871867295, + -0.2352447315031835, + -0.508756199989569, + -0.7853828673457621, + -1.0603026857941362, + -1.338679442750711, + -1.6007652894089202, + -1.847553644318416, + -2.1040329713990307, + -2.3418535145622625, + -2.5538629993803976, + -2.765133911770426, + -2.9701720576168875, + -3.1595598230605817, + -3.3382909620172394, + -3.504147407326207, + -3.6539679529776756, + -3.78939863520616, + -3.907272219931894, + -4.0066095925570115, + -4.091204898540712, + -4.158262119894641, + -4.206623367713327, + -4.236058617552424, + -4.246644353929049, + -4.239034035612159, + -4.213506574398528, + -4.171128279359379, + -4.113710327681287, + -4.0419241443501, + -3.9559111690968516, + -3.857513719772755, + -3.748090036784706, + -3.626855195420487, + -3.496670562598065, + -3.3563265848874084, + -3.205151737412957, + -3.0488316854022117, + -2.8854476964654956, + -2.7128724275487377, + -2.5379062116092443, + -2.3614629685122734, + -2.163055172621855, + -1.982818367233269, + -1.8038731289925154, + -1.6272554952898974, + -1.4532546008280691, + -1.278990090742027, + -1.1098506009116396, + -0.942136755407772, + -0.776576723408273, + -0.6158370508830519, + -0.4544020997355928, + -0.29942973559885255, + -0.1454188060245573, + 0.007936003205972264, + 0.15565305258712397, + 0.29457625258933406, + 0.4314142580302064, + 0.5683430981854892, + 0.694846679137476, + 0.8188379360171041, + 0.9358809833941538, + 1.0527320108236529, + 1.1671660272613753, + 1.2671443883084748, + 1.3640722172091344, + 1.4577010292601715, + 1.542415850180607, + 1.6223748629558545, + 1.699971255890312, + 1.7685763797422365, + 1.8275838981024881, + 1.8824069707097382, + 1.9315755100308414, + 1.9761177832561962, + 2.0184224135107582, + 2.058772295444945, + 2.090048449146628, + 2.118414284995719, + 2.1433218020241127, + 2.164499847176934, + 2.183624398856786, + 2.201745563914696, + 2.2191163862191376, + 2.2356112568211453, + 2.252205190590705, + 2.2696734386851425, + 2.2861436871219065, + 2.3028388082520315, + 2.320156656784082, + 2.3385900642083643, + 2.3592713493285324, + 2.381709319584404, + 2.4070750850405918, + 2.433536714213401, + 2.460323180625303, + 2.488334737445989, + 2.517470701544321, + 2.546629802554478, + 2.5761035809681236, + 2.605801430812296, + 2.6346378618305946, + 2.664230743227441, + 2.694549804929435, + 2.7258159022823096, + 2.759238603216797, + 2.7944993759776464, + 2.8341113639860507, + 2.878222183433039, + 2.9265506539557307, + 2.9817570389312267, + 3.0424355076743477, + 3.10706981078607, + 3.176268493362981, + 3.2518366092586892, + 3.3306555083175096, + 3.4167003836949985, + 3.5103863409690583, + 3.608075276617331, + 3.712465197951945, + 3.824986779880386, + 3.9414953622228603, + 4.063337721981724, + 4.189228453085594, + 4.32305494845731, + 4.4601335086936915, + 4.599612768708412, + 4.745724565432593, + 4.894244097721296, + 5.041885910037444, + 5.191764910987589, + 5.346624945272, + 5.502136993699961, + 5.66031023985374, + 5.827040420228186, + 5.993915138638752, + 6.163019081791028, + 6.340004190690086, + 6.51480791576448, + 6.693052797442919, + 6.876344788889039, + 7.061744999846965, + 7.249530507137674, + 7.442401885917925, + 7.6363173798228825, + 7.833347531175167, + 8.031939902077687, + 8.231630218605986, + 8.437204769860044, + 8.646940742133532, + 8.857529835015656, + 9.07108788216632, + 9.290443494699435, + 9.513642844314555, + 9.734015257909563, + 9.952667334600143, + 10.176895422656376, + 10.399408406976416, + 10.612948927801359, + 10.835322750369697, + 11.060327826084288, + 11.278301473239495, + 11.503489875629574, + 11.73523294520646, + 11.963468785348505, + 12.191789007678494, + 12.429478131732653, + 12.665443031503617, + 12.89579158070827, + 13.136067873538982, + 13.377967392787282, + 13.609060380777938, + 13.842656718036658, + 14.079282585095763, + 14.303025116185163, + 14.531193772188637, + 14.76343373257601, + 14.988364717265448, + 15.210401108907472, + 15.432194824916941, + 15.648898221378381, + 15.859666504257781, + 16.07159243459752, + 16.28587704705328, + 16.492741049637687, + 16.70271477593439, + 16.916959488029352, + 17.120388676395603, + 17.321872086130636, + 17.52598437352893, + 17.73069132213282, + 17.923541381424986, + 18.12415877679045, + 18.33055247009981, + 18.527339344950683, + 18.728619393999544, + 18.93229651117788, + 19.159708602880198, + 19.4082784942946, + 19.657013882891285, + 19.913688153069966, + 20.1810237350667, + 20.44925984654765, + 20.71943223986166, + 21.013255340780542, + 21.296880996082397, + 21.587008717105277, + 21.87669543522121, + 22.14614390357882, + 22.429699716773424, + 22.72867919352653, + 23.015124808272496, + 23.309689131058693, + 23.62072737599991, + 23.917317997509155, + 24.22728923897416, + 24.549557830741826, + 24.859284779918198, + 25.198737878170633, + 25.55124425892161, + 25.891361522185825, + 26.253819551000504, + 26.613487891441043, + 26.953390314574257, + 27.310638571378494, + 27.675098879004985, + 28.02257609685884, + 28.372024437456844, + 28.722782278393833, + 29.056779778170263, + 29.391639482158382, + 29.733978619258064, + 30.06954199813284, + 30.405578557387145, + 30.74290460478035, + 31.076644811849285, + 31.415921869142643, + 31.75168566214963, + 32.07880615187477, + 32.40706885545298, + 32.73851460197841, + 33.0676807176834, + 33.39832381666462, + 33.722583165076074, + 34.047658623891216, + 34.379238402257826, + 34.69422166225884, + 35.000045231143986, + 35.31662744158378, + 35.626168566857515, + 35.92233434856951, + 36.23308218134385, + 36.54449892409094, + 36.84456016976044, + 37.1509961306791, + 37.46708218572251, + 37.778666332331575, + 38.09528208590343, + 38.41604525917596, + 38.73530274120566, + 39.055625712890404, + 39.37255657030364, + 39.68694055810615, + 40.00704721386262, + 40.32652264170847, + 40.648259894550165, + 40.97597427139211, + 41.301638338918146, + 41.62584438215976, + 41.956297836934525, + 42.284245628285845, + 42.613262425546324, + 42.95240079702139, + 43.28646993462193, + 43.625021677518454, + 43.97328896243987, + 44.312702543278725, + 44.64813124991473, + 44.99439254626033, + 45.3277421467613, + 45.66063501705099, + 45.99941329622676, + 46.3335422645042, + 46.67401168335572, + 47.01662165443693, + 47.35177127067828, + 47.68824468473037, + 48.02732167165485, + 48.34825278519832, + 48.67520754566773, + 49.009570294261664, + 49.330733779477306, + 49.64803478144664, + 49.97272288173262, + 50.28626181294306, + 50.59636292421281, + 50.91164452188793, + 51.217231473830275, + 51.523743506303816, + 51.83079606383554, + 52.12537571031702, + 52.41455454927051, + 52.69755428166409, + 52.97307848949289, + 53.244803055412994, + 53.51315808323412, + 53.77889165725008, + 54.0491643707451, + 54.35163494380637, + 54.62137644655208, + 54.893989465334805, + 55.17536679309078, + 55.45710389409889, + 55.74015307895498, + 56.0220228048467, + 56.303164609917786, + 56.58000819492296, + 56.841236795067466, + 57.09245391592853, + 57.339838256342624, + 57.58516029644259, + 57.82775634694091, + 58.070779490401414, + 58.308651526869305, + 58.57167097058137, + 58.81437714424254, + 59.060321005464154, + 59.31221678831944, + 59.57476904619701, + 59.84574422526881, + 60.10556921586334, + 60.36630394592955, + 60.64091672994618, + 60.911254785589264, + 61.16707885205474, + 61.434349921802685, + 61.70112616186897, + 61.93388386434675, + 62.15830832844092, + 62.38125936047842, + 62.590587097629445, + 62.787725268168, + 62.97481626703805, + 63.16308213907139, + 63.34120860191085, + 63.50523215150407, + 63.665414522459955, + 63.81702987323637, + 63.95334470163948, + 64.07447428964709, + 64.17894658610709, + 64.26765538233062, + 64.33715010727487, + 64.38424431844649, + 64.41170181037019, + 64.42031773573312, + 64.40931520031016, + 64.38339297026431, + 64.34336667198134, + 64.28737658398609, + 64.21855403422586, + 64.14121152797942, + 64.04995761227521, + 63.94846791151717, + 63.838660338765, + 63.71870140813576, + 63.58564694001701, + 63.43286502924833, + 63.26756956130647, + 63.089021454444065, + 62.89696735199743, + 62.69109830514015, + 62.475849872656646, + 62.24723760717158, + 62.00403832883274, + 61.74744926263506, + 61.48441603100128, + 61.21147438509104, + 60.92554547200103, + 60.64556168698664, + 60.35556131366098, + 60.05528774603434, + 59.761106324338954, + 59.45953567556409, + 59.14674387373742, + 58.83877987674616, + 58.52537351931939, + 58.19696747819619, + 57.87673206832023, + 57.55904407515568, + 57.219522834317935, + 56.89825560150076, + 56.583692262369006, + 56.247382950278954, + 55.925194275975464, + 55.61661738404746, + 55.287665670641275, + 54.97358775518657, + 54.667044557843184, + 54.34926065592826, + 54.033403777056186, + 53.72032698465734, + 53.39942408763453, + 53.07425797607542, + 52.76026576865347, + 52.43499851719745, + 52.1044872768535, + 51.77589236535425, + 51.44234818902297, + 51.10348480799015, + 50.77061054705971, + 50.43943035416799, + 50.105447993078876, + 49.775680140952204, + 49.437693937288984, + 49.109475016229915, + 48.774063273912645, + 48.44174250821862, + 48.11031287371997, + 47.780591614667, + 47.455178389408125, + 47.12345403343129, + 46.79250829325339, + 46.463096924711756, + 46.1302104268118, + 45.79628426756038, + 45.46087308387347, + 45.12707494978503, + 44.78291080964345, + 44.435197543313826, + 44.095601761189705, + 43.74644247020221, + 43.398053080714256, + 43.05685402972333, + 42.707021234983905, + 42.35524805506865, + 42.009044387462, + 41.65243165458508, + 41.276103231129696, + 40.9187023488142, + 40.5627451827555, + 40.18094692538712, + 39.81602462825786, + 39.4544890754482, + 39.075386646568006, + 38.70226880539241, + 38.3415785656529, + 37.97676501433701, + 37.60772571361131, + 37.250223049724525, + 36.88744324348057, + 36.52416454726091, + 36.1573556895985, + 35.79788148487995, + 35.437265081393555, + 35.07271719105067, + 34.70233914650486, + 34.334966911426164, + 33.96279699298524, + 33.57711933124139, + 33.20363930786742, + 32.83230262301181, + 32.460658625022646, + 32.08325053537434, + 31.71958169915117, + 31.357203914839488, + 30.990154777772926, + 30.631982239623195, + 30.275221666848807, + 29.915767642342555, + 29.55390361714718, + 29.197115634246998, + 28.838005033417023, + 28.47869388507307, + 28.13208967675225, + 27.77559158249143, + 27.421637809516504, + 27.090678679335795, + 26.75922066017794, + 26.43484794550527, + 26.11289662642745, + 25.79380377776938, + 25.47813120634998, + 25.164360541699196, + 24.847550507204833, + 24.534393432863162, + 24.220401011247223, + 23.90615233246471, + 23.589665253530022, + 23.273895528473595, + 22.971481724084825, + 22.665888937391326, + 22.3620898886782, + 22.069296232048632, + 21.77440371495339, + 21.466856629222264, + 21.17859792836629, + 20.8934541543908, + 20.58147048470993, + 20.285748877176808, + 20.001366350635404, + 19.69549193285993, + 19.395772270006105, + 19.102822298479825, + 18.79866540471908, + 18.49086458184024, + 18.18718747554513, + 17.87562827174696, + 17.561654080135373, + 17.25428736220401, + 16.943385909834873, + 16.635600645459125, + 16.328601752378894, + 16.022285516326736, + 15.713355480669735, + 15.406459366001444, + 15.098179998114757, + 14.788261722135921, + 14.480811970868182, + 14.168821267560622, + 13.861490918040086, + 13.552445228312457, + 13.241443743478781, + 12.931410145368236, + 12.621438580723158, + 12.308900177709276, + 11.991856791088074, + 11.677526845036505, + 11.368176690497656, + 11.05392183631032, + 10.741464037040421, + 10.439526008169324, + 10.136818386414463, + 9.82423644278, + 9.531400298275525, + 9.238686371846137, + 8.941403335004782, + 8.65284802362914, + 8.370242659390138, + 8.090755133450198, + 7.810471819959358, + 7.5290850443663935, + 7.255905312909552, + 6.981271089469676, + 6.696914938425842, + 6.4181863050356345, + 6.146481325295194, + 5.85483052076996, + 5.570950955358325, + 5.297348884398934, + 5.001466191033476, + 4.702468465076711, + 4.418485385786011, + 4.116906583262891, + 3.8174808038538623, + 3.5235553830862942, + 3.2118632499177506, + 2.9281086942201098, + 2.6440876629142287, + 2.344325939053991, + 2.0635779600180695, + 1.787671719069037, + 1.4993269583573747, + 1.2207628985657388, + 0.9454794304059446, + 0.6640958804600767, + 0.3877002089322421, + 0.11448444930974705, + -0.16578620518746093, + -0.44750527714441424, + -0.7258070032432261, + -1.0046909442100316, + -1.2832724040354047, + -1.5631608335006302, + -1.839274903874311, + -2.10599621587619, + -2.3694675277233017, + -2.62343893987978, + -2.858403956492617, + -3.090102958119805, + -3.3069641079040974, + -3.507229477288608, + -3.6938382521769215, + -3.8629814663577817, + -4.01420132205265, + -4.1466584356736265, + -4.259877706509005, + -4.352527189824512, + -4.422964230685253, + -4.473565063110901, + -4.502183007140632, + -4.507021577830109, + -4.489283337148372, + -4.45174440211488, + -4.397050804131656, + -4.323598131913655, + -4.231769913016596, + -4.124768055855302, + -4.000782801956151, + -3.848806574422038, + -3.7008431642453847, + -3.542712038947353, + -3.378535522961384, + -3.2077932292354094, + -3.030690460174938, + -2.8519001098794727, + -2.666681213026716, + -2.477801972577572, + -2.288094039386811, + -2.0954870850756433, + -1.902136836077668, + -1.7092534902537027, + -1.5183759989546781, + -1.32893877896922, + -1.1404759348918994, + -0.9554250517824684, + -0.770610521851054, + -0.1285582442863698, + 0.043080907309530216, + 0.211280767432397, + 0.37327933021129, + 0.5247876104488909, + 0.6714241109417203, + 0.8148578400793581, + 0.9457752361654689, + 1.0699833089646602, + 1.1874050240947238, + 1.3062525644859002, + 1.4148382096503505, + 1.5101383386416976, + 1.6047030472306796, + 1.6933293187395781, + 1.7742607545942723, + 1.8534355266406306, + 1.9285556499646628, + 1.9949586429074053, + 2.057049811045084, + 2.117137136923471, + 2.1715376262881505, + 2.224645268626372, + 2.2751400325408344, + 2.3194737710752404, + 2.3609646532785145, + 2.3993493477777537, + 2.434847778227501, + 2.4667596720468703, + 2.496246389384398, + 2.524205932018945, + 2.5508374102446174, + 2.5768897875550767, + 2.603334615586056, + 2.627617746830503, + 2.651458920424354, + 2.675565875821827, + 2.699632969712007, + 2.724885295219408, + 2.7514870422184745, + 2.7805366985599824, + 2.8100202298576566, + 2.839351985580052, + 2.870126952080361, + 2.902052472474704, + 2.934071944572666, + 2.9667649677538774, + 2.999161281092309, + 3.033247502509151, + 3.069458790047973, + 3.109101654141103, + 3.1521991926926627, + 3.1984060810843857, + 3.2513735416748735, + 3.3096335977130296, + 3.3721912004533228, + 3.442704234125669, + 3.518800508446662, + 3.598955450020016, + 3.6852691404136575, + 3.776690115407541, + 3.873897199025853, + 3.9809590777696595, + 4.093369901196933, + 4.21073543806722, + 4.337228720660737, + 4.467472306282675, + 4.600399174940791, + 4.7378920080515625, + 4.8818043434231715, + 5.025019052801778, + 5.174247070572004, + 5.328205564256003, + 5.479368840755798, + 5.631256669496028, + 5.788513887931863, + 5.948897262006386, + 6.10670218711256, + 6.26844904483016, + 6.432903803230794, + 6.596779259376989, + 6.7669267608152985, + 6.951513019733041, + 7.121582061617247, + 7.296303722578637, + 7.471446604902217, + 7.648390856797305, + 7.8329936775282905, + 8.020794966136831, + 8.209632808845917, + 8.401417620202345, + 8.590646064052626, + 8.783114896662422, + 8.978829742472467, + 9.175300105333713, + 9.375093784722608, + 9.576483388082313, + 9.776822687641996, + 9.974478826265862, + 10.16913283903571, + 10.363268687008691, + 10.55903746225986, + 10.75563130245319, + 10.945306668273254, + 11.135763190050596, + 11.332323488480347, + 11.525328914372237, + 11.720578016653537, + 11.926693929229051, + 12.135399727870412, + 12.342714717150432, + 12.55003234180756, + 12.761801265083607, + 12.97256146778967, + 13.179128127067901, + 13.387694609538476, + 13.597709724968675, + 13.80361797596977, + 14.006064805723163, + 14.207421030004635, + 14.406475207674958, + 14.601841720278095, + 14.793749819797004, + 14.987798478702754, + 15.181954233018319, + 15.372252885129878, + 15.558553037905114, + 15.745354404116084, + 15.931025015334148, + 16.112133567599226, + 16.29169380601665, + 16.474231645900765, + 16.65673077203514, + 16.838598653819968, + 17.024090442972696, + 17.215763953323627, + 17.40575160280916, + 17.593940114722514, + 17.78483760605761, + 17.981528288357495, + 18.182422671778454, + 18.374776814112163, + 18.581111160182836, + 18.79362259697671, + 19.002776113219582, + 19.21626970605701, + 19.43112870827955, + 19.65102080898132, + 19.878853568831616, + 20.10312546864239, + 20.33456798064165, + 20.571455677987018, + 20.805307917268557, + 21.045403500637494, + 21.298626572310877, + 21.53535576431192, + 21.787166760774234, + 22.04704548994729, + 22.30130520146322, + 22.574917940685502, + 22.846884374467052, + 23.11244824903883, + 23.394736111510714, + 23.676559415391722, + 23.953561991223374, + 24.250244848173462, + 24.540664243181794, + 24.833706554567964, + 25.156073221083844, + 25.473212534749933, + 25.796484527816077, + 26.13727266230354, + 26.460167230968434, + 26.77512580126815, + 27.108708116478457, + 27.438704853826536, + 27.759939243778156, + 28.097005722458324, + 28.43352407824126, + 28.759042039978237, + 29.089389232226157, + 29.420449816116722, + 29.74983119488025, + 30.08212078675594, + 30.411269707900978, + 30.742944489098743, + 31.079775794747977, + 31.415306515442328, + 31.741820418039623, + 32.07149075857679, + 32.40361086995983, + 32.73287821237294, + 33.05788248022835, + 33.377945876076446, + 33.703491287061006, + 34.02881086933761, + 34.336350875136155, + 34.63920540415571, + 34.95101710678149, + 35.252506576967846, + 35.54453913345566, + 35.85123718330825, + 36.1556024863718, + 36.44855841864565, + 36.750834738479114, + 37.06178859755466, + 37.36669465341993, + 37.67828627342124, + 37.99722984607634, + 38.3128101331543, + 38.63145517761055, + 38.94999809789752, + 39.262558010339916, + 39.579606480905696, + 39.900656383320424, + 40.21583267579129, + 40.53472068673256, + 40.85465562257616, + 41.171871203786324, + 41.486037357245635, + 41.80341437120019, + 42.115184424916826, + 42.432267817800195, + 42.74869006613391, + 43.06151791935443, + 43.3820080676286, + 43.70140529186662, + 44.013298163509596, + 44.328170401408954, + 44.64674426514315, + 44.947785521041396, + 45.25663632901949, + 45.56604597424447, + 45.86151513692612, + 46.1652166857148, + 46.47234999090252, + 46.76795480259499, + 47.06580954138513, + 47.369754222452364, + 47.66790276213002, + 47.955036340077136, + 48.25304046096536, + 48.55028755018086, + 48.84124349388114, + 49.13120357369446, + 49.425155662794786, + 49.71207103834986, + 49.997415742459815, + 50.28675152243265, + 50.57106337226466, + 50.851292024285996, + 51.13648119935449, + 51.417560945870974, + 51.69166214349666, + 51.96406627724454, + 52.22926991177055, + 52.48582986037931, + 52.745453668301394, + 52.99701452760204, + 53.24552141987308, + 53.4978257580014, + 53.7494697455108, + 54.00373136573797, + 54.255545635500454, + 54.50935013800976, + 54.770800669626155, + 55.036105044860285, + 55.30578103475355, + 55.57712629270071, + 55.84841517886736, + 56.121305819378854, + 56.396127182219146, + 56.66480154268678, + 56.92590628816359, + 57.18698185333048, + 57.44827129411211, + 57.71011908320546, + 57.963999775378376, + 58.218160972638955, + 58.47851432665823, + 58.73886674078835, + 59.0018961803807, + 59.25945908560965, + 59.52850899544941, + 59.7983985531829, + 60.06082253745904, + 60.325324743604384, + 60.603310869586814, + 60.87044228018748, + 61.12857127677023, + 61.40259509282569, + 61.67108245307185, + 61.91466944388023, + 62.15701613438367, + 62.39918205440036, + 62.61468586808511, + 62.814890921539984, + 63.02869703724511, + 63.23373607245903, + 63.41827901302412, + 63.60299309694622, + 63.775328972893604, + 63.92889207470187, + 64.07696865776536, + 64.21234997683516, + 64.33590342676878, + 64.44841181662521, + 64.54345291530073, + 64.61930888948838, + 64.6804588488602, + 64.72498316710659, + 64.75503192542301, + 64.77423687634494, + 64.77803140213041, + 64.76324698744891, + 64.7316474879003, + 64.68467452846008, + 64.6238135747295, + 64.54563229844479, + 64.44741475264918, + 64.3356181562903, + 64.20980251223435, + 64.05914493782261, + 63.89264825264366, + 63.72007165595962, + 63.522447573320285, + 63.31200360543133, + 63.09432787491209, + 62.85875514113911, + 62.60149057913376, + 62.33823598972425, + 62.06428886562031, + 61.77172285159672, + 61.47355105704766, + 61.17496468507821, + 60.87045058716934, + 60.55938185071553, + 60.24397081513541, + 59.93095284858972, + 59.61728220176726, + 59.299333939218656, + 58.97707926350325, + 58.65792783521924, + 58.33271154277494, + 58.0012539725055, + 57.67455351559774, + 57.35069668912867, + 57.0110033390398, + 56.682707863948245, + 56.36397335195614, + 56.03032003075439, + 55.70119001727611, + 55.386535174913206, + 55.05762540355066, + 54.73099797658368, + 54.413186029836865, + 54.08440679283176, + 53.73343743190345, + 53.37357244061748, + 53.004494568369005, + 52.62954228046007, + 52.26446093215443, + 51.8915976868798, + 51.510535483981364, + 51.13531867866375, + 50.75821421464184, + 50.37417727794913, + 50.014487343204124, + 49.669224905922185, + 49.321863973983824, + 48.97802674707963, + 48.62923138335952, + 48.28707355315305, + 47.93763151574009, + 47.5872907197428, + 47.24000687906527, + 46.893839862634835, + 46.53587658051037, + 46.163590570402484, + 45.79241563967316, + 45.42317540460844, + 45.04551180986331, + 44.66978917424429, + 44.29262343794187, + 43.91757027134529, + 43.531587538619334, + 43.14696794743754, + 42.78206108151733, + 42.41824049184275, + 42.060184243998044, + 41.70452003079957, + 41.33805206358612, + 40.9766109893554, + 40.61983341924642, + 40.24887535369847, + 39.87223229958563, + 39.519149902919985, + 39.18455634113386, + 38.860337834576384, + 38.55863130316872, + 38.25385652507213, + 37.934972214089946, + 37.63300200201469, + 37.3396025668475, + 37.03070272139799, + 36.7277037174502, + 36.43124540575485, + 36.10315588004389, + 35.76046113590517, + 35.41316425324658, + 35.07222237259091, + 34.731873722071725, + 34.387032236592475, + 34.036139478930146, + 33.69462576680907, + 33.35530418497045, + 32.99712842178764, + 32.65920215189999, + 32.333351840655396, + 32.0051779050895, + 31.672771298385804, + 31.345114919508905, + 31.023647925277686, + 30.69561782251355, + 30.372857957424703, + 30.04455318675462, + 29.716530650772686, + 29.376775009459283, + 29.032727437247416, + 28.694625437234627, + 28.349447040615075, + 28.011036457098466, + 27.678311777443966, + 27.340469359712845, + 27.009506139351178, + 26.681481223625706, + 26.35123343731874, + 26.030749965565633, + 25.703695028905464, + 25.3729809368927, + 25.05138638577504, + 24.72510824817708, + 24.39159632548782, + 24.06553639184264, + 23.740206121656943, + 23.40739769660998, + 23.081261657381123, + 22.763948496106984, + 22.434754638579886, + 22.11865931239052, + 21.807961662977746, + 21.4905386192929, + 21.164729629764402, + 20.862205917135583, + 20.557598524940502, + 20.22856773223279, + 19.92575997577919, + 19.6153881903489, + 19.286986859541866, + 18.97106000671142, + 18.65889334937452, + 18.33971068441129, + 18.018534079578565, + 17.705214573794848, + 17.385735385576314, + 17.068365578110665, + 16.757617040072198, + 16.4496460625785, + 16.145614271502176, + 15.843183457518938, + 15.539724258363853, + 15.23081033491441, + 14.928168044593754, + 14.619350352711118, + 14.312415204065331, + 14.004900876356503, + 13.691505397129548, + 13.382761226417156, + 13.072808436633192, + 12.75617255762379, + 12.440687864840086, + 12.126135942931526, + 11.807694599494935, + 11.481563043217621, + 11.160286642289627, + 10.843404375716958, + 10.519724274272413, + 10.197474213686554, + 9.885910478792159, + 9.576727165359841, + 9.255507127598605, + 8.954497568968769, + 8.656825556729501, + 8.35151396705426, + 8.061474395638424, + 7.7724433786617295, + 7.479411741392958, + 7.188095097181444, + 6.89573277042006, + 6.601306178114268, + 6.311113766360433, + 6.012875205586886, + 5.711826949516217, + 5.418867168237042, + 5.11094740753219, + 4.812014104697031, + 4.518557033966874, + 4.21565271665293, + 3.9175057084956775, + 3.6254461948470915, + 3.331545966540914, + 3.0376112132400044, + 2.7496280170382965, + 2.458511058356632, + 2.1776750685582287, + 1.9023324270350852, + 1.6219873136478071, + 1.3515619052253622, + 1.085310220601021, + 0.8132238274266218, + 0.542831482048225, + 0.27315570468110895, + 0.0028421577680153853, + -0.27049276703125885, + -0.5463268248329006, + -0.8262004551646298, + -1.107879395286346, + -1.3916557107149121, + -1.6774387069881023, + -1.9574536540348448, + -2.2414835550731893, + -2.526726733690335, + -2.796098289993876, + -3.0485159716884143, + -3.3022453823800215, + -3.545904518964392, + -3.7680426080445755, + -3.9822654167228353, + -4.182143455408361, + -4.362342772072077, + -4.525006759792942, + -4.667669520894381, + -4.786895394067779, + -4.882482773350822, + -4.9567380382589485, + -5.00760632093183, + -5.033904877646841, + -5.035950736923687, + -5.015924590660966, + -4.977009753510907, + -4.918364667515823, + -4.840335445178749, + -4.7456233965048025, + -4.633561589034956, + -4.505079616704708, + -4.36172030627776, + -4.204324694798168, + -4.036262663020632, + -3.8583288185661546, + -3.6709480583204877, + -3.480061577759506, + -3.282265203079415, + -3.079398118769766, + -2.87609039202851, + -2.6689610642012136, + -2.4603357019972623, + -2.2549151712934425, + -2.0555541820113747, + -1.8414780652124227, + -1.651346634081261, + -1.4682549265943972, + -1.2889642340245302, + -1.113857587858442, + -0.9438890907116492, + -0.7775191634099369, + -0.6157830205941762, + -0.4605198406816952, + -0.30980998000948873, + -0.1681397575337732, + -0.03536969681095837, + 0.0949241979294292, + 0.21648296179149223, + 0.33224419728985344, + 0.44184639524177516, + 0.5530670602380464, + 0.6552431131693358, + 0.7450458870719152, + 0.8349279146221444, + 0.9197564848703558, + 0.9981466064915008, + 1.0757093210853481, + 1.1498858268519423, + 1.2157172484173244, + 1.277918760588237, + 1.3369809429183377, + 1.3913095375182345, + 1.4445619806903216, + 1.4930199416312082, + 1.5353990216939157, + 1.57393914736842, + 1.6094500826339255, + 1.640371474278643, + 1.6681913617370199, + 1.6941189600356357, + 1.7183310280475588, + 1.7415256307133822, + 1.7652994469640837, + 1.7876405491265939, + 1.8093238050473628, + 1.8308054348159348, + 1.8520438795362315, + 1.8748257790396303, + 1.8996520904616516, + 1.9262057756298006, + 1.952449525353578, + 1.979242520411555, + 2.008259301776948, + 2.0384574799230704, + 2.069874562285853, + 2.1033989223984486, + 2.139669610737211, + 2.178847733484213, + 2.2224087485069375, + 2.2704188602062985, + 2.321978765945034, + 2.3805472836395496, + 2.451333063077684, + 2.5204186862095352, + 2.596175004887099, + 2.6769282243746972, + 2.7612674531780805, + 2.850935968758286, + 2.944825825949513, + 3.0466485157760412, + 3.1548488331499427, + 3.2655988413770833, + 3.3841870985236198, + 3.5110067392985895, + 3.6403081450036745, + 3.772111482310876, + 3.91300797460865, + 4.053585243976242, + 4.196203948618943, + 4.346386942022684, + 4.4964075473362435, + 4.6438185147866395, + 4.797364815007405, + 4.957656078701385, + 5.113551060641826, + 5.277450170411978, + 5.447738805582302, + 5.61868728094935, + 5.792274115754509, + 5.970106560387088, + 6.1457914048665065, + 6.3240470581112325, + 6.5038232997804295, + 6.684811613747742, + 6.867542189719006, + 7.0527788657473085, + 7.23832605270834, + 7.4271045363039025, + 7.61488157945303, + 7.803947766682915, + 7.997218280776169, + 8.191109069136724, + 8.386710094937602, + 8.584029781015781, + 8.784563043070957, + 8.98616378095811, + 9.186300524490916, + 9.384675712138709, + 9.58578425900947, + 9.785127603053741, + 9.977131089883029, + 10.17423205860926, + 10.375199026198787, + 10.589069677872354, + 10.787897957613902, + 10.9944989318322, + 11.199313719525836, + 11.40220889785868, + 11.611422327244709, + 11.822203056455006, + 12.028078159659541, + 12.236382887514369, + 12.448917898577703, + 12.656943496096835, + 12.86293268568178, + 13.068496729897587, + 13.273911901456566, + 13.476843235127044, + 13.677405638669933, + 13.880626236854065, + 14.083914035531546, + 14.281469296468464, + 14.477584788810997, + 14.696795871184278, + 14.891196227875946, + 15.082311130722983, + 15.276633135204207, + 15.472227614634006, + 15.667924286693639, + 15.866672791542968, + 16.0723923531363, + 16.277184547378898, + 16.48010689593563, + 16.683861003762978, + 16.89619857368472, + 17.108556511388578, + 17.315968162417718, + 17.539772148344706, + 17.765361978006677, + 17.990712281871108, + 18.21992972308881, + 18.451895868100554, + 18.68826151201862, + 18.9436482275444, + 19.19867073009613, + 19.463662386025156, + 19.730008903195234, + 19.9952377679208, + 20.276325961206325, + 20.541583971021918, + 20.819617418066056, + 21.104824436489302, + 21.38463648974693, + 21.66152259805987, + 21.932805429587546, + 22.199802235368654, + 22.482130779534508, + 22.761252586310977, + 23.03887936669711, + 23.334516634642657, + 23.61941361492147, + 23.911167285421456, + 24.228285841106565, + 24.53476624296535, + 24.855426896910537, + 25.18760203848001, + 25.49634142742295, + 25.802876757165475, + 26.12767599834626, + 26.448340980539754, + 26.758791512708367, + 27.084390931040822, + 27.40786261649449, + 27.731069968144634, + 28.05683599811312, + 28.3814308398038, + 28.702588021869904, + 29.024447450977327, + 29.342072286186745, + 29.65736677249564, + 29.980166530694827, + 30.29936164656833, + 30.609351396683735, + 30.942218991646982, + 31.274985758940357, + 31.59873538561554, + 31.917679388804146, + 32.22637499903647, + 32.53261266046398, + 32.84957140299829, + 33.1537604937998, + 33.45175328611491, + 33.75321001672992, + 34.04505669586066, + 34.3315081975387, + 34.61734665505737, + 34.91373649998367, + 35.20612154712153, + 35.4972385862957, + 35.79638421038561, + 36.10143582856517, + 36.40616431502931, + 36.7166564837261, + 37.03092016042375, + 37.34524010017537, + 37.66244537940462, + 37.98099835555076, + 38.29525213913023, + 38.613917564197976, + 38.93655431141074, + 39.25684937883945, + 39.58090054663728, + 39.908278426885154, + 40.2359741012518, + 40.56312289221807, + 40.89323347893464, + 41.252771217355594, + 41.583297146888995, + 41.90962386932896, + 42.22992667364179, + 42.55172726340942, + 42.86374484428655, + 43.163666724218984, + 43.4657626112107, + 43.76206757207292, + 44.03554399500987, + 44.3155751391994, + 44.59400240126925, + 44.85395319498643, + 45.12284975288024, + 45.39909856508403, + 45.66769236792282, + 45.94006227356499, + 46.21462418755764, + 46.4915005790355, + 46.77438105175335, + 47.04990057016058, + 47.32688896385801, + 47.6175106516778, + 47.90280342864727, + 48.18506541471182, + 48.47461299275151, + 48.761709218167425, + 49.043855530041554, + 49.333408220082184, + 49.624954836179626, + 49.907927570110175, + 50.196870979637715, + 50.49018285919311, + 50.77470018970976, + 51.05728309383378, + 51.33949904509417, + 51.61068049658557, + 51.88334215765688, + 52.153806801541364, + 52.42080888823905, + 52.692224989140016, + 52.96637536516427, + 53.24359290174978, + 53.51893086188829, + 53.799865792365004, + 54.088940690466885, + 54.38145622363204, + 54.67559382248205, + 54.972990652781654, + 55.26971689965987, + 55.567887137864005, + 55.861405261790225, + 56.146119876549264, + 56.43012315788755, + 56.71446933117696, + 56.996859523816624, + 57.26842312183835, + 57.544300551709576, + 57.825632761545144, + 58.10239128015138, + 58.375800451310475, + 58.65452129529378, + 58.93909611047069, + 59.21085667673081, + 59.47113500778839, + 59.740987264165724, + 60.00288489650001, + 60.2411062117452, + 60.489616959428595, + 60.74440561630401, + 60.97948612363302, + 61.20377689502732, + 61.42669067933879, + 61.644217562591095, + 61.84154689449571, + 62.02494818540991, + 62.216494455226645, + 62.40301496814696, + 62.5729062602609, + 62.739733953662665, + 62.90127769126194, + 63.0479087832785, + 63.188862066127065, + 63.32009926727753, + 63.43892000765799, + 63.548062032563706, + 63.64495326883894, + 63.723300638418195, + 63.78756769276339, + 63.83887868470011, + 63.87464820900632, + 63.90015377300029, + 63.91353245152039, + 63.91048662090791, + 63.891427781412744, + 63.85912787946155, + 63.813384867752916, + 63.75416329934468, + 63.67813296736693, + 63.584576370320185, + 63.482783152394894, + 63.36265994441006, + 63.22093511317302, + 63.07142173192214, + 62.90890336399922, + 62.72142432038193, + 62.52867142173112, + 62.323333738872115, + 62.09941765674898, + 61.85980415567179, + 61.61067513099217, + 61.347178835717614, + 61.06974311268001, + 60.78741520466235, + 60.497830705641256, + 60.20279873773858, + 59.89873348432303, + 59.585043621096226, + 59.266936513389, + 58.948489803259854, + 58.621791337080595, + 58.28971966682808, + 57.962389710684064, + 57.62597298102287, + 57.27909752193891, + 56.936167214092045, + 56.599916639903, + 56.26025818730859, + 55.923789485388916, + 55.601350547949295, + 55.27176884837666, + 54.93680527794924, + 54.62045518974739, + 54.296942341618085, + 53.96884906228317, + 53.65506611243988, + 53.33251378359047, + 53.00009929624337, + 52.676468408131186, + 52.34445179465412, + 52.00681219218575, + 51.67495333296684, + 51.34402544088789, + 51.001131750417386, + 50.66352177516216, + 50.328657755750555, + 49.984471850724496, + 49.649447251285785, + 49.319524842901195, + 48.98736591985296, + 48.6577174415375, + 48.32583472324404, + 47.993326969072214, + 47.66618385550341, + 47.33024882085921, + 46.996015374056576, + 46.6625747941048, + 46.3288997901906, + 45.99449743748959, + 45.65820973178802, + 45.32367401442467, + 44.98701469094136, + 44.65087242754459, + 44.3128161997339, + 43.97490572725108, + 43.60540030037656, + 43.2693660002701, + 42.9533266225585, + 42.63806078501144, + 42.32052615358233, + 42.0095836811218, + 41.693147034220765, + 41.37030589200015, + 41.056625342652836, + 40.74007253894454, + 40.40955698480483, + 40.07238619335842, + 39.73983985942109, + 39.39260096861668, + 39.029058991457305, + 38.68573293410243, + 38.33982956594653, + 37.9793949940433, + 37.63018870182943, + 37.29310655753665, + 36.95160633019749, + 36.60880867832948, + 36.293045043930555, + 35.97320318330131, + 35.65843250820487, + 35.34143487093785, + 35.023269264007865, + 34.71491328954702, + 34.404730125534805, + 34.08880522075607, + 33.77287508607854, + 33.464978895564755, + 33.13532789499079, + 32.79521185152687, + 32.470538905825705, + 32.142421903556844, + 31.816023894313595, + 31.489715051373466, + 31.162633798062817, + 30.844728987655056, + 30.520512393006193, + 30.200186095790748, + 29.874214475069444, + 29.546783390008734, + 29.22024519972762, + 28.88660565295937, + 28.556615940743562, + 28.225122145310184, + 27.887892089126876, + 27.55833359518771, + 27.225281584961564, + 26.880269079369786, + 26.546483679129313, + 26.211752975406316, + 25.872264985783026, + 25.54079640197197, + 25.206617009471472, + 24.87929742152391, + 24.54970022173698, + 24.21830077512884, + 23.89078380024956, + 23.55858934148418, + 23.229282320037715, + 22.90869003444346, + 22.582503646843886, + 22.259829266745648, + 21.951962768662916, + 21.636761446826245, + 21.325235204922894, + 21.027892349667027, + 20.71510396184291, + 20.400388442243187, + 20.10406009212683, + 19.788973839438082, + 19.465898738205222, + 19.16521682464404, + 18.853977589852196, + 18.533844973163838, + 18.225560236998408, + 17.916390066821002, + 17.593149552008764, + 17.279735159101676, + 16.96738557844563, + 16.61353796299852, + 16.300786009462797, + 15.993030516912416, + 15.681375942368238, + 15.370879265711897, + 15.064397930064088, + 14.749228817483711, + 14.439039644775507, + 14.124585889320837, + 13.804131998714254, + 13.481147050734426, + 13.157733911860888, + 12.834151570500785, + 12.505922526550986, + 12.175900497336459, + 11.84838349465872, + 11.519249127080254, + 11.185757644129255, + 10.852782421353982, + 10.525000382206137, + 10.192836056464515, + 9.866128307061684, + 9.542051026352523, + 9.217309348473334, + 8.895321172734855, + 8.578402333638502, + 8.26170217041076, + 7.939481330769564, + 7.627377744402807, + 7.3177884251394145, + 7.010961355341202, + 6.70668843031132, + 6.403785089924659, + 6.10947597925163, + 5.814662071230183, + 5.519360755593465, + 5.225755566271759, + 4.941538545824498, + 4.654884637718361, + 4.38037694029395, + 4.120549845484478, + 3.8539538181153965, + 3.579046402777778, + 3.313308735996936, + 3.0528396543980385, + 2.7771352753989516, + 2.50866923525166, + 2.2483543666163492, + 1.9711689430605315, + 1.6920598186618878, + 1.418285367364076, + 1.1352485535177619, + 0.8652150394627393, + 0.5996267110920918, + 0.3251281950151414, + 0.053165143872573595, + -0.2163908265782503, + -0.4892134221579513, + -0.754967675341181, + -1.016081762235151, + -1.2804261208534857, + -1.5432177157655607, + -1.8140142664177707, + -2.0899060194822106, + -2.36329393983016, + -2.6378118023299026, + -2.915525029592267, + -3.1890168242133385, + -3.454539702908877, + -3.7163576609103215, + -3.975675678123687, + -4.2169925406172375, + -4.442689768247701, + -4.658572553725139, + -4.852573105338675, + -5.030485206627612, + -5.190462235832259, + -5.3259049753047725, + -5.443858062965672, + -5.538424480407576, + -5.609667089114421, + -5.66232695645199, + -5.694878036607654, + -5.704945724777047, + -5.694117560591922, + -5.664957132654981, + -5.619711295630675, + -5.556593805077401, + -5.475222126440451, + -5.379112331432588, + -5.267802947825157, + -5.141050971444225, + -5.0015617749878665, + -4.8494632509374025, + -4.687618016712599, + -4.517439685474579, + -4.338229332376084, + -4.153840410571848, + -3.9630722745112554, + -3.762739198445909, + -3.559114993832612, + -3.354593878788111, + -3.1470325361190508, + -2.9382350311469287, + -2.7327155931539, + -2.5306624485787, + -2.328507542270429, + -2.132728645912017, + -1.9398885568742124, + -1.7454978122007605, + -1.5582439754376392, + -1.3742513527750697, + -1.1958249393208156, + -1.0243365165770353, + -0.8596568349141067, + -0.7045010362854772, + -0.5612043749577209, + -0.4213395870413378, + -0.29265799824540595, + -0.1747140709160461, + -0.06427383575028402, + 0.04523888440500598, + 0.1470176048847739, + 0.2335592457807787, + 0.31806725376161676, + 0.39725495085779616, + 0.4688920850625567, + 0.5378993094566618, + 0.6042473556930033, + 0.6626326850149867, + 0.7166694726914491, + 0.7682639322389797, + 0.8145747696512449, + 0.8593962881678019, + 0.902068632543713, + 0.9396647149657125, + 0.9750302682532217, + 1.0078196208749777, + 1.0370968429022625, + 1.0632475635657515, + 1.0885355014874818, + 1.1129560030203436, + 1.1368010358379101, + 1.1616439975988324, + 1.185384436834196, + 1.2084431698441171, + 1.231475823346499, + 1.254762415104005, + 1.279003680873743, + 1.304490645361735, + 1.3321634619381912, + 1.3605681174269533, + 1.3892327278655954, + 1.4186765175872023, + 1.449830915578756, + 1.481366538781976, + 1.5139104799467373, + 1.5476033220878807, + 1.583263999062211, + 1.6211933233438354, + 1.6631383793633245, + 1.7085112718607305, + 1.7566245083624843, + 1.8106609482744596, + 1.8701071709030277, + 1.9325622125965751, + 2.000600405276433, + 2.073073719730457, + 2.1487545459062662, + 2.229233110974539, + 2.314248791055779, + 2.404061883807194, + 2.503093932077756, + 2.605972973079022, + 2.713395029001061, + 2.8296484655673715, + 2.9507293024997834, + 3.073381872513424, + 3.202339794245352, + 3.337032920952903, + 3.4721955709684758, + 3.6149854181753436, + 3.7636729912526707, + 3.910178150042658, + 4.058923070848974, + 4.21475164808327, + 4.3719460968550985, + 4.529731913673303, + 4.699013206401141, + 4.872421310561653, + 5.046968437777668, + 5.229006277495889, + 5.409705435527416, + 5.593652208359495, + 5.782655045033461, + 5.973170236368014, + 6.165235802080729, + 6.362082629896817, + 6.559839729215012, + 6.758378728278356, + 6.955605431440671, + 7.152684060701755, + 7.354195314330586, + 7.556177350072287, + 7.757156526317015, + 7.961187881364418, + 8.169003680971045, + 8.377459379121646, + 8.581252383706312, + 8.781756769247409, + 8.98448890738033, + 9.185323285345376, + 9.377647316277997, + 9.574631679098811, + 9.77453789199406, + 9.968845107737579, + 10.165648840966393, + 10.369614482872768, + 10.576969644898908, + 10.796436682491988, + 11.021275163606756, + 11.246864619549918, + 11.468880277362299, + 11.693590573508914, + 11.92143585618974, + 12.146513612658033, + 12.371844212598857, + 12.598693790863255, + 12.822096235578423, + 13.035439877280268, + 13.247246640349253, + 13.460973546419812, + 13.67818491273366, + 13.889488252801932, + 14.09869730671219, + 14.310534066599072, + 14.517203507884803, + 14.719377674133092, + 14.91959255573849, + 15.11129102371526, + 15.302603365919937, + 15.497272166423574, + 15.698383337483369, + 15.900379833012433, + 16.10047374125098, + 16.301025074430772, + 16.509151973958257, + 16.746694732657037, + 16.953910521678782, + 17.177209192109522, + 17.41215293423783, + 17.64433768572708, + 17.877133390770577, + 18.119552844386924, + 18.364348477256137, + 18.615503268685643, + 18.86396729684635, + 19.119776393134917, + 19.387481743891833, + 19.65799693723779, + 19.97138913535241, + 20.252628616256146, + 20.53972116370472, + 20.864599743802877, + 21.167491705195776, + 21.479459035401458, + 21.794086367754364, + 22.11399182898066, + 22.431454120196182, + 22.750859270112684, + 23.08280638027464, + 23.40683337366142, + 23.752317076511204, + 24.111574288181092, + 24.463440793015387, + 24.83238042227069, + 25.195968103920862, + 25.539312503153802, + 25.893310860662872, + 26.24369925261043, + 26.582267823747525, + 26.934328990924996, + 27.28588487239649, + 27.621442569502996, + 27.960117425471747, + 28.303818723395178, + 28.635159948539933, + 28.96563663478973, + 29.302293916387427, + 29.63502560679554, + 29.971363294207848, + 30.29806216363941, + 30.627519894388662, + 30.96068694316108, + 31.294987541140635, + 31.623092608346326, + 31.944797115786923, + 32.27409768959656, + 32.592069090265866, + 32.89812328844543, + 33.2033668790907, + 33.50553426591979, + 33.80315270375641, + 34.10199116335442, + 34.40374116129653, + 34.70037769128913, + 34.99930901872234, + 35.3027502455714, + 35.61019483415942, + 35.91778870110349, + 36.23177032401961, + 36.549318572671865, + 36.86447146895361, + 37.18297074853924, + 37.50051483488108, + 37.812978772462074, + 38.12853715955711, + 38.44735728290711, + 38.761536901895255, + 39.07845222453753, + 39.398809727184464, + 39.71473270566849, + 40.029760352806896, + 40.34504792610765, + 40.656417875372846, + 40.968938220175374, + 41.28462245556623, + 41.595517406598056, + 41.9094211795519, + 42.22432579850013, + 42.533473369050185, + 42.8382291810717, + 43.149608904374205, + 43.451222239978314, + 43.7461698096567, + 44.055228739426425, + 44.3520907788295, + 44.64503373770526, + 44.950728556177175, + 45.24935317712527, + 45.54098436312996, + 45.842943558877685, + 46.14755303458895, + 46.443326146505704, + 46.73854919902926, + 47.04786138192352, + 47.35185725399669, + 47.6496011629088, + 47.95643597017896, + 48.26021279372815, + 48.5526815219461, + 48.85167784575071, + 49.15353998314599, + 49.44431352175967, + 49.73911939311548, + 50.03882848645932, + 50.32720195041783, + 50.61033855379421, + 50.89366193587728, + 51.167884941565404, + 51.442476792158224, + 51.70740478669174, + 51.97455183959459, + 52.2424372398762, + 52.50955159749591, + 52.7824441181921, + 53.051639281404725, + 53.324425481435284, + 53.60486635010587, + 53.88963555239475, + 54.17262292819129, + 54.4593666694105, + 54.746556348026466, + 55.0305015067513, + 55.31720677899201, + 55.59453917403792, + 55.86299504245453, + 56.15888068132426, + 56.43126106967637, + 56.689559026047526, + 56.94065230347727, + 57.20232586978722, + 57.46099179998578, + 57.716605912095446, + 57.9667704481623, + 58.22562514258538, + 58.48279049995408, + 58.73193463906139, + 58.97679471021057, + 59.231726255130766, + 59.48618960316225, + 59.72657914780032, + 59.97475629090222, + 60.23213442959812, + 60.47564590161082, + 60.7049426999666, + 60.93445010429704, + 61.16012701900658, + 61.3628146677447, + 61.550008156074576, + 61.74411336584359, + 61.93120147403796, + 62.101802479904414, + 62.270331330755454, + 62.43413388237453, + 62.58517546562002, + 62.73057604318588, + 62.8658714391403, + 62.989543574071305, + 63.103099285653244, + 63.20404152139367, + 63.28437587397921, + 63.34853626378916, + 63.39811140601858, + 63.43054360930116, + 63.45097372267109, + 63.4567870942156, + 63.4447036226012, + 63.415166511477636, + 63.37023820702246, + 63.31125332955198, + 63.23868736126066, + 63.14754852230307, + 63.03907964589805, + 62.92303318853251, + 62.78694141392481, + 62.63013965343733, + 62.47033102323899, + 62.291344581506586, + 62.09359031103882, + 61.89338210735949, + 61.68071930303039, + 61.44532632829423, + 61.203398502780615, + 60.956911902839714, + 60.69511980115445, + 60.42411300080536, + 60.151264974959936, + 59.874515231224194, + 59.597389428236895, + 59.31315620970049, + 59.021405078511236, + 58.73348260878832, + 58.44388700203228, + 58.14274849214298, + 57.845826988627806, + 57.547992367501905, + 57.23870777395376, + 56.928117561617285, + 56.62382853069168, + 56.31482853049869, + 55.990849826911465, + 55.683937631249485, + 55.379659002614254, + 55.05940881942585, + 54.7478383274686, + 54.446228891852904, + 54.12854095685096, + 53.81675817839018, + 53.51092358754183, + 53.191921350580614, + 52.874242507329065, + 52.560490915928355, + 52.23665547370583, + 51.9077213338404, + 51.586335427714964, + 51.25869299864215, + 50.92217315579611, + 50.58992709820675, + 50.25454632426013, + 49.91197629794347, + 49.57380078223236, + 49.237340611949016, + 48.898109266487566, + 48.561150029770815, + 48.217675361077035, + 47.8814988613294, + 47.538440549060454, + 47.193064510530945, + 46.84970730673195, + 46.50781183219243, + 46.16979702769043, + 45.82343644077698, + 45.479765835147404, + 45.13888378364644, + 44.79218792981562, + 44.44490713561993, + 44.0988326524609, + 43.75639413298596, + 43.405006335112816, + 43.055527911170955, + 42.7086224383385, + 42.35636752516571, + 42.01094060050921, + 41.66224869226204, + 41.30453120325186, + 40.955322427633995, + 40.606122509777826, + 40.23671831132266, + 39.87349021139226, + 39.5275216361299, + 39.15532054602195, + 38.7817917199122, + 38.42861344698394, + 38.061284715268485, + 37.68275186248695, + 37.327194044766365, + 36.97288601560147, + 36.59687620499849, + 36.23856852027886, + 35.88109543832795, + 35.51698266869542, + 35.151454114699064, + 34.794316267355306, + 34.437446570826715, + 34.0771206780121, + 33.71134847630277, + 33.35259399364076, + 33.003896861735065, + 32.640723827922415, + 32.28507059277639, + 31.936035330965904, + 31.586702433075484, + 31.22925195912165, + 30.88262143045321, + 30.53913232170126, + 30.188714527696792, + 29.840928928588454, + 29.48492688340245, + 29.125366374489975, + 28.76159707627873, + 28.402096387015956, + 28.04029167738964, + 27.677618347007837, + 27.324717705275482, + 26.96516969116016, + 26.60742238152953, + 26.25813582849179, + 25.940644512843257, + 25.64412781515667, + 25.348022830332283, + 25.054248705830318, + 24.762940172729472, + 24.474128236011488, + 24.183142438830497, + 23.889583459808744, + 23.60377062481296, + 23.31867446763228, + 23.00573403717028, + 22.69543862677699, + 22.38831817031561, + 22.07397291674327, + 21.77221620303394, + 21.474625655377526, + 21.159855111507262, + 20.860472983457548, + 20.57213799635636, + 20.259065357779058, + 19.974570675468115, + 19.717712944410955, + 19.440432089622206, + 19.16813345630301, + 18.908419499224287, + 18.64000437675221, + 18.36595897269595, + 18.101705674415747, + 17.836092414791384, + 17.56766658637108, + 17.29630341598981, + 17.025696047063235, + 16.75357804943685, + 16.481298915640618, + 16.21239540766006, + 15.937971120749795, + 15.666811668372599, + 15.391689270983614, + 15.117745913680936, + 14.848566220763917, + 14.540019487081059, + 14.222796660348196, + 13.903347266645358, + 13.579127085268404, + 13.257103352453145, + 12.937332081227593, + 12.61581250883567, + 12.287693670984922, + 11.954463566316953, + 11.631394681600423, + 11.329342568479117, + 11.031985211164129, + 10.737523628601888, + 10.45177322208383, + 10.159690898535434, + 9.86528611566276, + 9.585463943766001, + 9.302510829422065, + 9.018771681351822, + 8.740721821328417, + 8.451412847625747, + 8.16055976991773, + 7.940739885595652, + 7.649689547274853, + 7.355041201225335, + 7.073086663607159, + 6.781403498982336, + 6.480975198189406, + 6.193000995912802, + 5.894118683471748, + 5.604056865569893, + 5.33068637775697, + 5.0487330235721695, + 4.759880363605886, + 4.478298131655137, + 4.196421392600951, + 3.905844195969257, + 3.622675096119216, + 3.3335103833622286, + 3.0471742784361973, + 2.7719845417630884, + 2.487867692040898, + 2.2085807352402806, + 1.9386381623244573, + 1.6696997892504515, + 1.4008923722223003, + 1.1325338854488094, + 0.8698931640746329, + 0.6058550815292704, + 0.34178591209782944, + 0.07667844538865048, + -0.19026593451683016, + -0.4554406244379945, + -0.7270459916318173, + -0.9993729438317401, + -1.2657127999327225, + -1.5359056209694093, + -1.80910351197689, + -2.0707859508561195, + -2.3171929634210486, + -2.57303096028214, + -2.8298091729269363, + -3.06656586192069, + -3.2899355731611752, + -3.5058629323482395, + -3.7002218758350507, + -3.8799659060810683, + -4.044154401469258, + -4.188005818183804, + -4.3152184885170035, + -4.431354366367503, + -4.514703010052084, + -4.582047960812328, + -4.631528872229713, + -4.660910020821575, + -4.6708403048030105, + -4.662449926632847, + -4.637595003710751, + -4.596236011176684, + -4.537947385026952, + -4.376888542703274, + -4.288209726394857, + -4.18429492589649, + -4.066784639944364, + -3.9387904539825316, + -3.799571967547997, + -3.653433657150335, + -3.501583167393676, + -3.3401169800524984, + -3.1744481619690124, + -3.0052437357619475, + -2.829155463434466, + -2.6527959456576737, + -2.475772365932245, + -2.2955140655935677, + -2.1163822744969565, + -1.9392996466493302, + -1.7651734640595305, + -1.5942498473022266, + -1.4247547430679854, + -1.2604125493983682, + -1.1006254540448448, + -0.9428929999493842, + -0.7885739354253732, + -0.6394382491121076, + -0.49422949338813715, + -0.3547025436789347, + -0.2220377492865131, + -0.09497282596831103, + 0.021735167796699595, + 0.13285858280024343, + 0.24087676995385654, + 0.33923224114094186, + 0.4329200110885978, + 0.5201430763632404, + 0.6076261201085276, + 0.6879757431943925, + 0.7578119823986187, + 0.8269555418893362, + 0.8913029716133205, + 0.9503621390414663, + 1.0076633251910367, + 1.0619069384761641, + 1.1108600056309395, + 1.154451832954652, + 1.1949563456065757, + 1.2306989835976558, + 1.2636346948814599, + 1.2938187560504752, + 1.3188102651418132, + 1.3400979295708186, + 1.3587029727938362, + 1.3736961944225057, + 1.3874170183239944, + 1.4002813212691323, + 1.413864282540611, + 1.428639923213225, + 1.4439159929580396, + 1.4588185530447295, + 1.4716824922168787, + 1.4835117408193126, + 1.4944732530514964, + 1.505037309620955, + 1.5159365970329328, + 1.5279182777271618, + 1.5710227590004497, + 1.585487163567568, + 1.601561593277515, + 1.6191296376214352, + 1.6392118118171357, + 1.661332817227168, + 1.683818426776412, + 1.7097134479140736, + 1.7383053758480917, + 1.7712132265080818, + 1.8082864177801734, + 1.8495958259515872, + 1.8971048618132835, + 1.9495414443591164, + 2.0089977824310385, + 2.0774488692880864, + 2.152719910738682, + 2.238964543324992, + 2.334237456808289, + 2.4368837415056133, + 2.548951046986737, + 2.6712010864801186, + 2.804199387424994, + 2.9489350548037434, + 3.104840352919012, + 3.2696446307951232, + 3.449537370291477, + 3.6422762528119206, + 3.844964818735719, + 4.064028170637739, + 4.296194383137763, + 4.536128453129303, + 4.780515740575605, + 5.068595331826337, + 5.339755000188887, + 5.614200312412775, + 5.903984607704475, + 6.204051350614986, + 6.502448630503951, + 6.810922050815234, + 7.129056537553518, + 7.449288520433644, + 7.773967266604292, + 8.107895570113833, + 8.446125235123965, + 8.788662654274459, + 9.13733022137939, + 9.492647484169803, + 9.849955649461563, + 10.21254506524584, + 10.577595066343726, + 10.945159109465557, + 11.317011962849532, + 11.695576288011447, + 12.073581835360525, + 12.45651816101892, + 12.83768461737093, + 13.225855182239112, + 13.6045467550301, + 13.982960369263015, + 14.37046998740073, + 14.753653591504062, + 15.128301092639678, + 15.48333670659692, + 15.866854354862518, + 16.245350633101292, + 16.624730674923363, + 17.001686838803085, + 17.37919471421535, + 17.740406988155204, + 18.101336461510847, + 18.468215679311406, + 18.82581382576176, + 19.18161067675971, + 19.535407199454607, + 19.884345087594586, + 20.228653998324123, + 20.548386356183464, + 20.865869997294162, + 21.16885830062872, + 21.444246808429096, + 21.723230991549734, + 21.981530352651177, + 22.224328239948395, + 22.46410538987838, + 22.682047779183893, + 22.87564032840609, + 23.07241333225752, + 23.249333466160344, + 23.402534058412485, + 23.554669446076673, + 23.689115788959498, + 23.803609375044548, + 23.914176643879845, + 24.00819998804275, + 24.090429572476623, + 24.15856535870222, + 24.21198547003588, + 24.258824715173773, + 24.299506484657044, + 24.333927203988285, + 24.36638965023104, + 24.3987771413772, + 24.4317546647794, + 24.463807479798007, + 24.47374998174773, + 24.482187505217095, + 24.489667268472083, + 24.497212999751582, + 24.50380722017657, + 24.50955393584218, + 24.515152386071996, + 24.520318624688684, + 24.5248045628062, + 24.52850286369933, + 24.562741464181546, + 24.600868556033113, + 24.639572183517366, + 24.678999202340567, + 24.71910268613251, + 24.752167541806834, + 24.79364139014063, + 24.835708363767267, + 24.87896621837055, + 24.9232279506598, + 24.934048566691608, + 24.94107921159258, + 24.948355187402903, + 24.955586820345676, + 24.962806447321853, + 24.97021256670934, + 24.977547770122076, + 24.98494706269543, + 24.99242952739627, + 24.99994437339902, + 25.02540271299319, + 25.052916655074828, + 25.08056116530959, + 25.108168293242098, + 25.135808603887792, + 25.163485883998344, + 25.19115746111873, + 25.218876122759994, + 25.24661949756898, + 25.274470489049747, + 25.284313636520654, + 25.292359845297213, + 25.30049146549067, + 25.308518479646892, + 25.316455103602966, + 25.32458102322264, + 25.332738211605577, + 25.340816196491332, + 25.349086424508247, + 25.357295176651835, + 25.403074324577638, + 25.45374928722011, + 25.504408168173335, + 25.55487276861665, + 25.605032335771533, + 25.655136174273608, + 25.705378138595986, + 25.755388975120873, + 25.80516652331133, + 25.855482764168823, + 25.866115316851094, + 25.872209004406198, + 25.87815187371181, + 25.88382182578971, + 25.889485629453922, + 25.895124265033186, + 25.900828708697823, + 25.907149898069797, + 25.912957769240567, + 25.91864769467925, + 25.920183502332854, + 25.921636104878875, + 25.922860010393627, + 25.92409529747196, + 25.925571771507055, + 25.92699075098118, + 25.928285237895963, + 25.929639765365053 + ], + "yaxis": "y" + }, + { + "mode": "markers", + "showlegend": false, + "type": "scatter", + "x": [ + -35.97329684506414, + -75.10034519689322, + -1.038764248227812, + -36.0892694400111 + ], + "y": [ + 26.316580859038513, + 21.011440906439965, + -12.138119313965529, + 72.35004639855165 + ] + } + ], + "layout": { + "legend": { + "tracegroupgap": 0 + }, + "margin": { + "b": 0, + "l": 0, + "r": 0, + "t": 0 + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "xaxis": { + "anchor": "y", + "domain": [ + 0, + 1 + ], + "title": { + "text": "x" + } + }, + "yaxis": { + "anchor": "x", + "domain": [ + 0, + 1 + ], + "scaleanchor": "x", + "scaleratio": 1, + "title": { + "text": "y" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig = px.scatter(x=positions[:,0],y=positions[:,1])\n", + "fig.add_scatter(x=landmarks[:,0], y=landmarks[:,1], mode=\"markers\", showlegend= False)\n", + "fig.update_layout(margin=dict(l=0, r=0, t=0, b=0))\n", + "fig.update_yaxes(scaleanchor = \"x\", scaleratio = 1)\n", + "fig.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[-35.97329685, 26.31658086],\n", + " [-75.1003452 , 21.01144091],\n", + " [ -1.03876425, -12.13811931],\n", + " [-36.08926944, 72.3500464 ]])" + ] + }, + "execution_count": 60, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "landmarks" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "341996cd3f3db7b5e0d1eaea072c5502d80452314e72e6b77c40445f6e9ba101" + }, + "kernelspec": { + "display_name": "Python 3.8.12 ('nbdev')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.12" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/python/gtsam/examples/SFMExample.py b/python/gtsam/examples/SFMExample.py index f0c4c82ba..87bb3cb87 100644 --- a/python/gtsam/examples/SFMExample.py +++ b/python/gtsam/examples/SFMExample.py @@ -8,7 +8,6 @@ See LICENSE for the license information A structure-from-motion problem on a simulated dataset """ -from __future__ import print_function import gtsam import matplotlib.pyplot as plt @@ -89,7 +88,7 @@ def main(): point_noise = gtsam.noiseModel.Isotropic.Sigma(3, 0.1) factor = PriorFactorPoint3(L(0), points[0], point_noise) graph.push_back(factor) - graph.print_('Factor Graph:\n') + graph.print('Factor Graph:\n') # Create the data structure to hold the initial estimate to the solution # Intentionally initialize the variables off from the ground truth @@ -100,7 +99,7 @@ def main(): for j, point in enumerate(points): transformed_point = point + 0.1*np.random.randn(3) initial_estimate.insert(L(j), transformed_point) - initial_estimate.print_('Initial Estimates:\n') + initial_estimate.print('Initial Estimates:\n') # Optimize the graph and print results params = gtsam.DoglegParams() @@ -108,7 +107,7 @@ def main(): optimizer = DoglegOptimizer(graph, initial_estimate, params) print('Optimizing:') result = optimizer.optimize() - result.print_('Final results:\n') + result.print('Final results:\n') print('initial error = {}'.format(graph.error(initial_estimate))) print('final error = {}'.format(graph.error(result))) diff --git a/python/gtsam/examples/SFMExample_bal.py b/python/gtsam/examples/SFMExample_bal.py index dfe8b523c..3d71590a9 100644 --- a/python/gtsam/examples/SFMExample_bal.py +++ b/python/gtsam/examples/SFMExample_bal.py @@ -7,86 +7,89 @@ See LICENSE for the license information Solve a structure-from-motion problem from a "Bundle Adjustment in the Large" file - Author: Frank Dellaert (Python: Akshay Krishnan, John Lambert) + Author: Frank Dellaert (Python: Akshay Krishnan, John Lambert, Varun Agrawal) """ import argparse import logging import sys -import matplotlib.pyplot as plt -import numpy as np - import gtsam -from gtsam import ( - GeneralSFMFactorCal3Bundler, - PinholeCameraCal3Bundler, - PriorFactorPinholeCameraCal3Bundler, - readBal, - symbol_shorthand -) +from gtsam import (GeneralSFMFactorCal3Bundler, SfmData, + PriorFactorPinholeCameraCal3Bundler, PriorFactorPoint3) +from gtsam.symbol_shorthand import P # type: ignore +from gtsam.utils import plot # type: ignore +from matplotlib import pyplot as plt -C = symbol_shorthand.C -P = symbol_shorthand.P +logging.basicConfig(stream=sys.stdout, level=logging.INFO) + +DEFAULT_BAL_DATASET = "dubrovnik-3-7-pre" -logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) +def plot_scene(scene_data: SfmData, result: gtsam.Values) -> None: + """Plot the SFM results.""" + plot_vals = gtsam.Values() + for i in range(scene_data.numberCameras()): + plot_vals.insert(i, result.atPinholeCameraCal3Bundler(i).pose()) + for j in range(scene_data.numberTracks()): + plot_vals.insert(P(j), result.atPoint3(P(j))) -def run(args): + plot.plot_3d_points(0, plot_vals, linespec="g.") + plot.plot_trajectory(0, plot_vals, title="SFM results") + + plt.show() + + +def run(args: argparse.Namespace) -> None: """ Run LM optimization with BAL input data and report resulting error """ - input_file = gtsam.findExampleDataFile(args.input_file) + input_file = args.input_file # Load the SfM data from file - scene_data = readBal(input_file) - logging.info(f"read {scene_data.number_tracks()} tracks on {scene_data.number_cameras()} cameras\n") + scene_data = SfmData.FromBalFile(input_file) + logging.info("read %d tracks on %d cameras\n", scene_data.numberTracks(), + scene_data.numberCameras()) # Create a factor graph graph = gtsam.NonlinearFactorGraph() # We share *one* noiseModel between all projection factors - noise = gtsam.noiseModel.Isotropic.Sigma(2, 1.0) # one pixel in u and v + noise = gtsam.noiseModel.Isotropic.Sigma(2, 1.0) # one pixel in u and v # Add measurements to the factor graph - j = 0 - for t_idx in range(scene_data.number_tracks()): - track = scene_data.track(t_idx) # SfmTrack + for j in range(scene_data.numberTracks()): + track = scene_data.track(j) # SfmTrack # retrieve the SfmMeasurement objects - for m_idx in range(track.number_measurements()): + for m_idx in range(track.numberMeasurements()): # i represents the camera index, and uv is the 2d measurement i, uv = track.measurement(m_idx) # note use of shorthand symbols C and P - graph.add(GeneralSFMFactorCal3Bundler(uv, noise, C(i), P(j))) - j += 1 + graph.add(GeneralSFMFactorCal3Bundler(uv, noise, i, P(j))) # Add a prior on pose x1. This indirectly specifies where the origin is. graph.push_back( - gtsam.PriorFactorPinholeCameraCal3Bundler( - C(0), scene_data.camera(0), gtsam.noiseModel.Isotropic.Sigma(9, 0.1) - ) - ) + PriorFactorPinholeCameraCal3Bundler( + 0, scene_data.camera(0), + gtsam.noiseModel.Isotropic.Sigma(9, 0.1))) # Also add a prior on the position of the first landmark to fix the scale graph.push_back( - gtsam.PriorFactorPoint3( - P(0), scene_data.track(0).point3(), gtsam.noiseModel.Isotropic.Sigma(3, 0.1) - ) - ) + PriorFactorPoint3(P(0), + scene_data.track(0).point3(), + gtsam.noiseModel.Isotropic.Sigma(3, 0.1))) # Create initial estimate initial = gtsam.Values() - + i = 0 # add each PinholeCameraCal3Bundler - for cam_idx in range(scene_data.number_cameras()): - camera = scene_data.camera(cam_idx) - initial.insert(C(i), camera) + for i in range(scene_data.numberCameras()): + camera = scene_data.camera(i) + initial.insert(i, camera) i += 1 - j = 0 # add each SfmTrack - for t_idx in range(scene_data.number_tracks()): - track = scene_data.track(t_idx) + for j in range(scene_data.numberTracks()): + track = scene_data.track(j) initial.insert(P(j), track.point3()) - j += 1 # Optimize the graph and print results try: @@ -94,25 +97,31 @@ def run(args): params.setVerbosityLM("ERROR") lm = gtsam.LevenbergMarquardtOptimizer(graph, initial, params) result = lm.optimize() - except Exception as e: + except RuntimeError: logging.exception("LM Optimization failed") return + # Error drops from ~2764.22 to ~0.046 - logging.info(f"final error: {graph.error(result)}") + logging.info("initial error: %f", graph.error(initial)) + logging.info("final error: %f", graph.error(result)) + + plot_scene(scene_data, result) + + +def main() -> None: + """Main runner.""" + parser = argparse.ArgumentParser() + parser.add_argument('-i', + '--input_file', + type=str, + default=gtsam.findExampleDataFile(DEFAULT_BAL_DATASET), + help="""Read SFM data from the specified BAL file. + The data format is described here: https://grail.cs.washington.edu/projects/bal/. + BAL files contain (nrPoses, nrPoints, nrObservations), followed by (i,j,u,v) tuples, + then (wx,wy,wz,tx,ty,tz,f,k1,k1) as Bundler camera calibrations w/ Rodrigues vector + and (x,y,z) 3d point initializations.""") + run(parser.parse_args()) if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - '-i', - '--input_file', - type=str, - default="dubrovnik-3-7-pre", - help='Read SFM data from the specified BAL file' - 'The data format is described here: https://grail.cs.washington.edu/projects/bal/.' - 'BAL files contain (nrPoses, nrPoints, nrObservations), followed by (i,j,u,v) tuples, ' - 'then (wx,wy,wz,tx,ty,tz,f,k1,k1) as Bundler camera calibrations w/ Rodrigues vector' - 'and (x,y,z) 3d point initializations.' - ) - run(parser.parse_args()) - + main() diff --git a/python/gtsam/examples/SimpleRotation.py b/python/gtsam/examples/SimpleRotation.py index 0fef261f8..3d5fd9e45 100644 --- a/python/gtsam/examples/SimpleRotation.py +++ b/python/gtsam/examples/SimpleRotation.py @@ -31,7 +31,7 @@ def main(): - A measurement model with the correct dimensionality for the factor """ prior = gtsam.Rot2.fromAngle(np.deg2rad(30)) - prior.print_('goal angle') + prior.print('goal angle') model = gtsam.noiseModel.Isotropic.Sigma(dim=1, sigma=np.deg2rad(1)) key = X(1) factor = gtsam.PriorFactorRot2(key, prior, model) @@ -48,7 +48,7 @@ def main(): """ graph = gtsam.NonlinearFactorGraph() graph.push_back(factor) - graph.print_('full graph') + graph.print('full graph') """ Step 3: Create an initial estimate @@ -65,7 +65,7 @@ def main(): """ initial = gtsam.Values() initial.insert(key, gtsam.Rot2.fromAngle(np.deg2rad(20))) - initial.print_('initial estimate') + initial.print('initial estimate') """ Step 4: Optimize @@ -77,7 +77,7 @@ def main(): with the final state of the optimization. """ result = gtsam.LevenbergMarquardtOptimizer(graph, initial).optimize() - result.print_('final result') + result.print('final result') if __name__ == '__main__': diff --git a/python/gtsam/examples/VisualISAM2Example.py b/python/gtsam/examples/VisualISAM2Example.py index bacf510ec..4b480fab7 100644 --- a/python/gtsam/examples/VisualISAM2Example.py +++ b/python/gtsam/examples/VisualISAM2Example.py @@ -81,7 +81,7 @@ def visual_ISAM2_example(): # will approach the batch result. parameters = gtsam.ISAM2Params() parameters.setRelinearizeThreshold(0.01) - parameters.setRelinearizeSkip(1) + parameters.relinearizeSkip = 1 isam = gtsam.ISAM2(parameters) # Create a Factor Graph and Values to hold the new data diff --git a/python/gtsam/examples/VisualISAMExample.py b/python/gtsam/examples/VisualISAMExample.py index f99d3f3e6..9691b3c46 100644 --- a/python/gtsam/examples/VisualISAMExample.py +++ b/python/gtsam/examples/VisualISAMExample.py @@ -10,8 +10,6 @@ A visualSLAM example for the structure-from-motion problem on a simulated datase This version uses iSAM to solve the problem incrementally """ -from __future__ import print_function - import numpy as np import gtsam from gtsam.examples import SFMdata @@ -94,7 +92,7 @@ def main(): current_estimate = isam.estimate() print('*' * 50) print('Frame {}:'.format(i)) - current_estimate.print_('Current estimate: ') + current_estimate.print('Current estimate: ') # Clear the factor graph and values for the next iteration graph.resize(0) diff --git a/python/gtsam/notebooks/DiscreteBayesTree.ipynb b/python/gtsam/notebooks/DiscreteBayesTree.ipynb new file mode 100644 index 000000000..066c31d6a --- /dev/null +++ b/python/gtsam/notebooks/DiscreteBayesTree.ipynb @@ -0,0 +1,200 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# The Discrete Bayes Tree\n", + "\n", + "An example of building a Bayes net, then eliminating it into a Bayes tree. Mirrors the code in `testDiscreteBayesTree.cpp` .\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from gtsam import DiscreteBayesTree, DiscreteBayesNet, DiscreteKeys, DiscreteFactorGraph, Ordering\n", + "from gtsam.symbol_shorthand import S\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "def P(*args):\n", + " \"\"\" Create a DiscreteKeys instances from a variable number of DiscreteKey pairs.\"\"\"\n", + " #TODO: We can make life easier by providing variable argument functions in C++ itself.\n", + " dks = DiscreteKeys()\n", + " for key in args:\n", + " dks.push_back(key)\n", + " return dks" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import graphviz\n", + "class show(graphviz.Source):\n", + " \"\"\" Display an object with a dot method as a graph.\"\"\"\n", + "\n", + " def __init__(self, obj):\n", + " \"\"\"Construct from object with 'dot' method.\"\"\"\n", + " # This small class takes an object, calls its dot function, and uses the\n", + " # resulting string to initialize a graphviz.Source instance. This in turn\n", + " # has a _repr_mimebundle_ method, which then renders it in the notebook.\n", + " super().__init__(obj.dot())" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n\n\n\nG\n\n\n\n8\n\n8\n\n\n\n0\n\n0\n\n\n\n8->0\n\n\n\n\n\n1\n\n1\n\n\n\n8->1\n\n\n\n\n\n12\n\n12\n\n\n\n12->8\n\n\n\n\n\n12->0\n\n\n\n\n\n12->1\n\n\n\n\n\n9\n\n9\n\n\n\n12->9\n\n\n\n\n\n2\n\n2\n\n\n\n12->2\n\n\n\n\n\n3\n\n3\n\n\n\n12->3\n\n\n\n\n\n9->2\n\n\n\n\n\n9->3\n\n\n\n\n\n10\n\n10\n\n\n\n4\n\n4\n\n\n\n10->4\n\n\n\n\n\n5\n\n5\n\n\n\n10->5\n\n\n\n\n\n13\n\n13\n\n\n\n13->10\n\n\n\n\n\n13->4\n\n\n\n\n\n13->5\n\n\n\n\n\n11\n\n11\n\n\n\n13->11\n\n\n\n\n\n6\n\n6\n\n\n\n13->6\n\n\n\n\n\n7\n\n7\n\n\n\n13->7\n\n\n\n\n\n11->6\n\n\n\n\n\n11->7\n\n\n\n\n\n14\n\n14\n\n\n\n14->8\n\n\n\n\n\n14->12\n\n\n\n\n\n14->9\n\n\n\n\n\n14->10\n\n\n\n\n\n14->13\n\n\n\n\n\n14->11\n\n\n\n\n\n", + "text/plain": [ + "<__main__.show at 0x109c615b0>" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Define DiscreteKey pairs.\n", + "keys = [(j, 2) for j in range(15)]\n", + "\n", + "# Create thin-tree Bayesnet.\n", + "bayesNet = DiscreteBayesNet()\n", + "\n", + "\n", + "bayesNet.add(keys[0], P(keys[8], keys[12]), \"2/3 1/4 3/2 4/1\")\n", + "bayesNet.add(keys[1], P(keys[8], keys[12]), \"4/1 2/3 3/2 1/4\")\n", + "bayesNet.add(keys[2], P(keys[9], keys[12]), \"1/4 8/2 2/3 4/1\")\n", + "bayesNet.add(keys[3], P(keys[9], keys[12]), \"1/4 2/3 3/2 4/1\")\n", + "\n", + "bayesNet.add(keys[4], P(keys[10], keys[13]), \"2/3 1/4 3/2 4/1\")\n", + "bayesNet.add(keys[5], P(keys[10], keys[13]), \"4/1 2/3 3/2 1/4\")\n", + "bayesNet.add(keys[6], P(keys[11], keys[13]), \"1/4 3/2 2/3 4/1\")\n", + "bayesNet.add(keys[7], P(keys[11], keys[13]), \"1/4 2/3 3/2 4/1\")\n", + "\n", + "bayesNet.add(keys[8], P(keys[12], keys[14]), \"T 1/4 3/2 4/1\")\n", + "bayesNet.add(keys[9], P(keys[12], keys[14]), \"4/1 2/3 F 1/4\")\n", + "bayesNet.add(keys[10], P(keys[13], keys[14]), \"1/4 3/2 2/3 4/1\")\n", + "bayesNet.add(keys[11], P(keys[13], keys[14]), \"1/4 2/3 3/2 4/1\")\n", + "\n", + "bayesNet.add(keys[12], P(keys[14]), \"3/1 3/1\")\n", + "bayesNet.add(keys[13], P(keys[14]), \"1/3 3/1\")\n", + "\n", + "bayesNet.add(keys[14], P(), \"1/3\")\n", + "\n", + "show(bayesNet)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "DiscreteValues{0: 1, 1: 1, 2: 0, 3: 1, 4: 1, 5: 1, 6: 0, 7: 1, 8: 0, 9: 0, 10: 0, 11: 0, 12: 1, 13: 1, 14: 0}\n", + "DiscreteValues{0: 0, 1: 1, 2: 0, 3: 0, 4: 1, 5: 0, 6: 0, 7: 0, 8: 1, 9: 1, 10: 0, 11: 1, 12: 0, 13: 0, 14: 1}\n", + "DiscreteValues{0: 1, 1: 0, 2: 1, 3: 1, 4: 0, 5: 0, 6: 1, 7: 0, 8: 1, 9: 0, 10: 1, 11: 1, 12: 0, 13: 1, 14: 0}\n", + "DiscreteValues{0: 1, 1: 1, 2: 0, 3: 0, 4: 1, 5: 1, 6: 1, 7: 1, 8: 0, 9: 1, 10: 0, 11: 0, 12: 1, 13: 0, 14: 1}\n", + "DiscreteValues{0: 0, 1: 0, 2: 1, 3: 0, 4: 1, 5: 1, 6: 1, 7: 0, 8: 1, 9: 1, 10: 0, 11: 1, 12: 0, 13: 0, 14: 1}\n" + ] + } + ], + "source": [ + "# Sample Bayes net (needs conditionals added in elimination order!)\n", + "for i in range(5):\n", + " print(bayesNet.sample())" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n\n\n\n\n\n\nvar0\n\n0\n\n\n\nfactor0\n\n\n\n\nvar0--factor0\n\n\n\n\nvar1\n\n1\n\n\n\nfactor1\n\n\n\n\nvar1--factor1\n\n\n\n\nvar2\n\n2\n\n\n\nfactor2\n\n\n\n\nvar2--factor2\n\n\n\n\nvar3\n\n3\n\n\n\nfactor3\n\n\n\n\nvar3--factor3\n\n\n\n\nvar4\n\n4\n\n\n\nfactor4\n\n\n\n\nvar4--factor4\n\n\n\n\nvar5\n\n5\n\n\n\nfactor5\n\n\n\n\nvar5--factor5\n\n\n\n\nvar6\n\n6\n\n\n\nfactor6\n\n\n\n\nvar6--factor6\n\n\n\n\nvar7\n\n7\n\n\n\nfactor7\n\n\n\n\nvar7--factor7\n\n\n\n\nvar8\n\n8\n\n\n\nvar8--factor0\n\n\n\n\nvar8--factor1\n\n\n\n\nfactor8\n\n\n\n\nvar8--factor8\n\n\n\n\nvar9\n\n9\n\n\n\nvar9--factor2\n\n\n\n\nvar9--factor3\n\n\n\n\nfactor9\n\n\n\n\nvar9--factor9\n\n\n\n\nvar10\n\n10\n\n\n\nvar10--factor4\n\n\n\n\nvar10--factor5\n\n\n\n\nfactor10\n\n\n\n\nvar10--factor10\n\n\n\n\nvar11\n\n11\n\n\n\nvar11--factor6\n\n\n\n\nvar11--factor7\n\n\n\n\nfactor11\n\n\n\n\nvar11--factor11\n\n\n\n\nvar12\n\n12\n\n\n\nvar14\n\n14\n\n\n\nvar12--var14\n\n\n\n\nvar12--factor0\n\n\n\n\nvar12--factor1\n\n\n\n\nvar12--factor2\n\n\n\n\nvar12--factor3\n\n\n\n\nvar12--factor8\n\n\n\n\nvar12--factor9\n\n\n\n\nvar13\n\n13\n\n\n\nvar13--var14\n\n\n\n\nvar13--factor4\n\n\n\n\nvar13--factor5\n\n\n\n\nvar13--factor6\n\n\n\n\nvar13--factor7\n\n\n\n\nvar13--factor10\n\n\n\n\nvar13--factor11\n\n\n\n\nvar14--factor8\n\n\n\n\nvar14--factor9\n\n\n\n\nvar14--factor10\n\n\n\n\nvar14--factor11\n\n\n\n\nfactor14\n\n\n\n\nvar14--factor14\n\n\n\n\n", + "text/plain": [ + "<__main__.show at 0x109c61f10>" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Create a factor graph out of the Bayes net.\n", + "factorGraph = DiscreteFactorGraph(bayesNet)\n", + "show(factorGraph)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n\n\n\nG\n\n\n\n0\n\n8,12,14\n\n\n\n1\n\n0 : 8,12\n\n\n\n0->1\n\n\n\n\n\n2\n\n1 : 8,12\n\n\n\n0->2\n\n\n\n\n\n3\n\n9 : 12,14\n\n\n\n0->3\n\n\n\n\n\n6\n\n10,13 : 14\n\n\n\n0->6\n\n\n\n\n\n4\n\n2 : 9,12\n\n\n\n3->4\n\n\n\n\n\n5\n\n3 : 9,12\n\n\n\n3->5\n\n\n\n\n\n7\n\n4 : 10,13\n\n\n\n6->7\n\n\n\n\n\n8\n\n5 : 10,13\n\n\n\n6->8\n\n\n\n\n\n9\n\n11 : 13,14\n\n\n\n6->9\n\n\n\n\n\n10\n\n6 : 11,13\n\n\n\n9->10\n\n\n\n\n\n11\n\n7 : 11,13\n\n\n\n9->11\n\n\n\n\n\n", + "text/plain": [ + "<__main__.show at 0x109c61b50>" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Create a BayesTree out of the factor graph.\n", + "ordering = Ordering()\n", + "for j in range(15): ordering.push_back(j)\n", + "bayesTree = factorGraph.eliminateMultifrontal(ordering)\n", + "show(bayesTree)" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" + }, + "kernelspec": { + "display_name": "Python 3.8.9 64-bit", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/python/gtsam/notebooks/DiscreteSwitching.ipynb b/python/gtsam/notebooks/DiscreteSwitching.ipynb new file mode 100644 index 000000000..6872e78c8 --- /dev/null +++ b/python/gtsam/notebooks/DiscreteSwitching.ipynb @@ -0,0 +1,155 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# A Discrete Switching System\n", + "\n", + "A la MHS, but all discrete.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from gtsam import DiscreteBayesNet, DiscreteKeys, DiscreteFactorGraph, Ordering\n", + "from gtsam.symbol_shorthand import S\n", + "from gtsam.symbol_shorthand import M\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def P(*args):\n", + " \"\"\" Create a DiscreteKeys instances from a variable number of DiscreteKey pairs.\"\"\"\n", + " # TODO: We can make life easier by providing variable argument functions in C++ itself.\n", + " dks = DiscreteKeys()\n", + " for key in args:\n", + " dks.push_back(key)\n", + " return dks\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import graphviz\n", + "\n", + "\n", + "class show(graphviz.Source):\n", + " \"\"\" Display an object with a dot method as a graph.\"\"\"\n", + "\n", + " def __init__(self, obj):\n", + " \"\"\"Construct from object with 'dot' method.\"\"\"\n", + " # This small class takes an object, calls its dot function, and uses the\n", + " # resulting string to initialize a graphviz.Source instance. This in turn\n", + " # has a _repr_mimebundle_ method, which then renders it in the notebook.\n", + " super().__init__(obj.dot())\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "nrStates = 3\n", + "K = 5\n", + "\n", + "bayesNet = DiscreteBayesNet()\n", + "for k in range(1, K):\n", + " key = S(k), nrStates\n", + " key_plus = S(k+1), nrStates\n", + " mode = M(k), 2\n", + " bayesNet.add(key_plus, P(mode, key), \"9/1/0 1/8/1 0/1/9 1/9/0 0/1/9 9/0/1\")\n", + "\n", + "bayesNet" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "show(bayesNet)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a factor graph out of the Bayes net.\n", + "factorGraph = DiscreteFactorGraph(bayesNet)\n", + "show(factorGraph)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a BayesTree out of the factor graph.\n", + "ordering = Ordering()\n", + "# First eliminate \"continuous\" states in time order\n", + "for k in range(1, K+1):\n", + " ordering.push_back(S(k))\n", + "for k in range(1, K):\n", + " ordering.push_back(M(k))\n", + "print(ordering)\n", + "bayesTree = factorGraph.eliminateMultifrontal(ordering)\n", + "bayesTree" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "show(bayesTree)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "interpreter": { + "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" + }, + "kernelspec": { + "display_name": "Python 3.8.9 64-bit", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/python/gtsam/preamble/discrete.h b/python/gtsam/preamble/discrete.h new file mode 100644 index 000000000..608508c32 --- /dev/null +++ b/python/gtsam/preamble/discrete.h @@ -0,0 +1,16 @@ +/* Please refer to: + * https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html + * These are required to save one copy operation on Python calls. + * + * NOTES + * ================= + * + * `PYBIND11_MAKE_OPAQUE` will mark the type as "opaque" for the pybind11 + * automatic STL binding, such that the raw objects can be accessed in Python. + * Without this they will be automatically converted to a Python object, and all + * mutations on Python side will not be reflected on C++. + */ + +#include + +PYBIND11_MAKE_OPAQUE(gtsam::DiscreteKeys); diff --git a/python/gtsam/preamble/geometry.h b/python/gtsam/preamble/geometry.h index 35fe2a577..bd0441d06 100644 --- a/python/gtsam/preamble/geometry.h +++ b/python/gtsam/preamble/geometry.h @@ -23,8 +23,8 @@ PYBIND11_MAKE_OPAQUE( std::vector>); PYBIND11_MAKE_OPAQUE(gtsam::Point2Pairs); PYBIND11_MAKE_OPAQUE(gtsam::Point3Pairs); +PYBIND11_MAKE_OPAQUE(gtsam::Pose2Pairs); PYBIND11_MAKE_OPAQUE(gtsam::Pose3Pairs); PYBIND11_MAKE_OPAQUE(std::vector); -PYBIND11_MAKE_OPAQUE( - gtsam::CameraSet>); +PYBIND11_MAKE_OPAQUE(gtsam::CameraSet>); PYBIND11_MAKE_OPAQUE(gtsam::CameraSet>); diff --git a/python/gtsam/preamble/inference.h b/python/gtsam/preamble/inference.h new file mode 100644 index 000000000..4106c794a --- /dev/null +++ b/python/gtsam/preamble/inference.h @@ -0,0 +1,14 @@ +/* Please refer to: + * https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html + * These are required to save one copy operation on Python calls. + * + * NOTES + * ================= + * + * `PYBIND11_MAKE_OPAQUE` will mark the type as "opaque" for the pybind11 + * automatic STL binding, such that the raw objects can be accessed in Python. + * Without this they will be automatically converted to a Python object, and all + * mutations on Python side will not be reflected on C++. + */ + +#include \ No newline at end of file diff --git a/python/gtsam/preamble/slam.h b/python/gtsam/preamble/slam.h index 34dbb4b7a..f7bf5863c 100644 --- a/python/gtsam/preamble/slam.h +++ b/python/gtsam/preamble/slam.h @@ -15,3 +15,4 @@ PYBIND11_MAKE_OPAQUE( std::vector > >); PYBIND11_MAKE_OPAQUE( std::vector > >); +PYBIND11_MAKE_OPAQUE(gtsam::Rot3Vector); diff --git a/python/gtsam/specializations/discrete.h b/python/gtsam/specializations/discrete.h new file mode 100644 index 000000000..458a2ea4c --- /dev/null +++ b/python/gtsam/specializations/discrete.h @@ -0,0 +1,17 @@ +/* Please refer to: + * https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html + * These are required to save one copy operation on Python calls. + * + * NOTES + * ================= + * + * `py::bind_vector` and similar machinery gives the std container a Python-like + * interface, but without the `` copying mechanism. Combined + * with `PYBIND11_MAKE_OPAQUE` this allows the types to be modified with Python, + * and saves one copy operation. + */ + +// Seems this is not a good idea with inherited stl +//py::bind_vector>(m_, "DiscreteKeys"); + +py::bind_map(m_, "DiscreteValues"); diff --git a/python/gtsam/specializations/geometry.h b/python/gtsam/specializations/geometry.h index a492ce8eb..5a0ea35ef 100644 --- a/python/gtsam/specializations/geometry.h +++ b/python/gtsam/specializations/geometry.h @@ -16,6 +16,7 @@ py::bind_vector< m_, "Point2Vector"); py::bind_vector>(m_, "Point2Pairs"); py::bind_vector>(m_, "Point3Pairs"); +py::bind_vector>(m_, "Pose2Pairs"); py::bind_vector>(m_, "Pose3Pairs"); py::bind_vector>(m_, "Pose3Vector"); py::bind_vector>>( diff --git a/python/gtsam/specializations/inference.h b/python/gtsam/specializations/inference.h new file mode 100644 index 000000000..22fe3beff --- /dev/null +++ b/python/gtsam/specializations/inference.h @@ -0,0 +1,13 @@ +/* Please refer to: + * https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html + * These are required to save one copy operation on Python calls. + * + * NOTES + * ================= + * + * `py::bind_vector` and similar machinery gives the std container a Python-like + * interface, but without the `` copying mechanism. Combined + * with `PYBIND11_MAKE_OPAQUE` this allows the types to be modified with Python, + * and saves one copy operation. + */ + diff --git a/python/gtsam/specializations/slam.h b/python/gtsam/specializations/slam.h index 198485a72..6a439c370 100644 --- a/python/gtsam/specializations/slam.h +++ b/python/gtsam/specializations/slam.h @@ -12,8 +12,9 @@ */ py::bind_vector< - std::vector > > >( + std::vector>>>( m_, "BetweenFactorPose3s"); py::bind_vector< - std::vector > > >( + std::vector>>>( m_, "BetweenFactorPose2s"); +py::bind_vector(m_, "Rot3Vector"); diff --git a/python/gtsam/tests/testEssentialMatrixConstraint.py b/python/gtsam/tests/testEssentialMatrixConstraint.py new file mode 100644 index 000000000..8439ad2e9 --- /dev/null +++ b/python/gtsam/tests/testEssentialMatrixConstraint.py @@ -0,0 +1,47 @@ +""" +GTSAM Copyright 2010-2019, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +visual_isam unit tests. +Author: Frank Dellaert & Pablo Alcantarilla +""" + +import unittest + +import gtsam +import numpy as np +from gtsam import (EssentialMatrix, EssentialMatrixConstraint, Point3, Pose3, + Rot3, Unit3, symbol) +from gtsam.utils.test_case import GtsamTestCase + + +class TestVisualISAMExample(GtsamTestCase): + def test_VisualISAMExample(self): + + # Create a factor + poseKey1 = symbol('x', 1) + poseKey2 = symbol('x', 2) + trueRotation = Rot3.RzRyRx(0.15, 0.15, -0.20) + trueTranslation = Point3(+0.5, -1.0, +1.0) + trueDirection = Unit3(trueTranslation) + E = EssentialMatrix(trueRotation, trueDirection) + model = gtsam.noiseModel.Isotropic.Sigma(5, 0.25) + factor = EssentialMatrixConstraint(poseKey1, poseKey2, E, model) + + # Create a linearization point at the zero-error point + pose1 = Pose3(Rot3.RzRyRx(0.00, -0.15, 0.30), Point3(-4.0, 7.0, -10.0)) + pose2 = Pose3( + Rot3.RzRyRx(0.179693265735950, 0.002945368776519, + 0.102274823253840), + Point3(-3.37493895, 6.14660244, -8.93650986)) + + expected = np.zeros((5, 1)) + actual = factor.evaluateError(pose1, pose2) + self.gtsamAssertEquals(actual, expected, 1e-8) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/gtsam/tests/test_Cal3Fisheye.py b/python/gtsam/tests/test_Cal3Fisheye.py index 298c6e57b..e54afc757 100644 --- a/python/gtsam/tests/test_Cal3Fisheye.py +++ b/python/gtsam/tests/test_Cal3Fisheye.py @@ -17,6 +17,15 @@ import gtsam from gtsam.utils.test_case import GtsamTestCase from gtsam.symbol_shorthand import K, L, P + +def ulp(ftype=np.float64): + """ + Unit in the last place of floating point datatypes + """ + f = np.finfo(ftype) + return f.tiny / ftype(1 << f.nmant) + + class TestCal3Fisheye(GtsamTestCase): @classmethod @@ -105,6 +114,71 @@ class TestCal3Fisheye(GtsamTestCase): score = graph.error(state) self.assertAlmostEqual(score, 0) + def test_jacobian_on_axis(self): + """Check of jacobian at optical axis""" + obj_point_on_axis = np.array([0, 0, 1]) + img_point = np.array([0, 0]) + f, z, H = self.evaluate_jacobian(obj_point_on_axis, img_point) + self.assertAlmostEqual(f, 0) + self.gtsamAssertEquals(z, np.zeros(2)) + self.gtsamAssertEquals(H @ H.T, 3*np.eye(2)) + + def test_jacobian_convergence(self): + """Test stability of jacobian close to optical axis""" + t = ulp(np.float64) + obj_point_close_to_axis = np.array([t, 0, 1]) + img_point = np.array([np.sqrt(t), 0]) + f, z, H = self.evaluate_jacobian(obj_point_close_to_axis, img_point) + self.assertAlmostEqual(f, 0) + self.gtsamAssertEquals(z, np.zeros(2)) + self.gtsamAssertEquals(H @ H.T, 3*np.eye(2)) + + # With a height of sqrt(ulp), this may cause an overflow + t = ulp(np.float64) + obj_point_close_to_axis = np.array([np.sqrt(t), 0, 1]) + img_point = np.array([np.sqrt(t), 0]) + f, z, H = self.evaluate_jacobian(obj_point_close_to_axis, img_point) + self.assertAlmostEqual(f, 0) + self.gtsamAssertEquals(z, np.zeros(2)) + self.gtsamAssertEquals(H @ H.T, 3*np.eye(2)) + + def test_scaling_factor(self): + """Check convergence of atan2(r, z)/r ~ 1/z for small r""" + r = ulp(np.float64) + s = np.arctan(r) / r + self.assertEqual(s, 1.0) + z = 1 + s = self.scaling_factor(r, z) + self.assertEqual(s, 1.0/z) + z = 2 + s = self.scaling_factor(r, z) + self.assertEqual(s, 1.0/z) + s = self.scaling_factor(2*r, z) + self.assertEqual(s, 1.0/z) + + @staticmethod + def scaling_factor(r, z): + """Projection factor theta/r for equidistant fisheye lens model""" + return np.arctan2(r, z) / r if r/z != 0 else 1.0/z + + @staticmethod + def evaluate_jacobian(obj_point, img_point): + """Evaluate jacobian at given object point""" + pose = gtsam.Pose3() + camera = gtsam.Cal3Fisheye() + state = gtsam.Values() + camera_key, pose_key, landmark_key = K(0), P(0), L(0) + state.insert_point3(landmark_key, obj_point) + state.insert_pose3(pose_key, pose) + g = gtsam.NonlinearFactorGraph() + noise_model = gtsam.noiseModel.Unit.Create(2) + factor = gtsam.GenericProjectionFactorCal3Fisheye(img_point, noise_model, pose_key, landmark_key, camera) + g.add(factor) + f = g.error(state) + gaussian_factor_graph = g.linearize(state) + H, z = gaussian_factor_graph.jacobian() + return f, z, H + @unittest.skip("triangulatePoint3 currently seems to require perspective projections.") def test_triangulation_skipped(self): """Estimate spatial point from image measurements""" diff --git a/python/gtsam/tests/test_Cal3Unified.py b/python/gtsam/tests/test_Cal3Unified.py index dab1ae446..630109d66 100644 --- a/python/gtsam/tests/test_Cal3Unified.py +++ b/python/gtsam/tests/test_Cal3Unified.py @@ -117,6 +117,39 @@ class TestCal3Unified(GtsamTestCase): score = graph.error(state) self.assertAlmostEqual(score, 0) + def test_jacobian(self): + """Evaluate jacobian at optical axis""" + obj_point_on_axis = np.array([0, 0, 1]) + img_point = np.array([0.0, 0.0]) + pose = gtsam.Pose3() + camera = gtsam.Cal3Unified() + state = gtsam.Values() + camera_key, pose_key, landmark_key = K(0), P(0), L(0) + state.insert_cal3unified(camera_key, camera) + state.insert_point3(landmark_key, obj_point_on_axis) + state.insert_pose3(pose_key, pose) + g = gtsam.NonlinearFactorGraph() + noise_model = gtsam.noiseModel.Unit.Create(2) + factor = gtsam.GeneralSFMFactor2Cal3Unified(img_point, noise_model, pose_key, landmark_key, camera_key) + g.add(factor) + f = g.error(state) + gaussian_factor_graph = g.linearize(state) + H, z = gaussian_factor_graph.jacobian() + self.assertAlmostEqual(f, 0) + self.gtsamAssertEquals(z, np.zeros(2)) + self.gtsamAssertEquals(H @ H.T, 4*np.eye(2)) + + Dcal = np.zeros((2, 10), order='F') + Dp = np.zeros((2, 2), order='F') + camera.calibrate(img_point, Dcal, Dp) + + self.gtsamAssertEquals(Dcal, np.array( + [[ 0., 0., 0., -1., 0., 0., 0., 0., 0., 0.], + [ 0., 0., 0., 0., -1., 0., 0., 0., 0., 0.]])) + self.gtsamAssertEquals(Dp, np.array( + [[ 1., -0.], + [-0., 1.]])) + @unittest.skip("triangulatePoint3 currently seems to require perspective projections.") def test_triangulation(self): """Estimate spatial point from image measurements""" diff --git a/python/gtsam/tests/test_DecisionTreeFactor.py b/python/gtsam/tests/test_DecisionTreeFactor.py new file mode 100644 index 000000000..0499e7215 --- /dev/null +++ b/python/gtsam/tests/test_DecisionTreeFactor.py @@ -0,0 +1,98 @@ +""" +GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for DecisionTreeFactors. +Author: Frank Dellaert +""" + +# pylint: disable=no-name-in-module, invalid-name + +import unittest + +from gtsam import DecisionTreeFactor, DiscreteValues, DiscreteDistribution, Ordering +from gtsam.utils.test_case import GtsamTestCase + + +class TestDecisionTreeFactor(GtsamTestCase): + """Tests for DecisionTreeFactors.""" + + def setUp(self): + self.A = (12, 3) + self.B = (5, 2) + self.factor = DecisionTreeFactor([self.A, self.B], "1 2 3 4 5 6") + + def test_enumerate(self): + actual = self.factor.enumerate() + _, values = zip(*actual) + self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + + def test_multiplication(self): + """Test whether multiplication works with overloading.""" + v0 = (0, 2) + v1 = (1, 2) + v2 = (2, 2) + + # Multiply with a DiscreteDistribution, i.e., Bayes Law! + prior = DiscreteDistribution(v1, [1, 3]) + f1 = DecisionTreeFactor([v0, v1], "1 2 3 4") + expected = DecisionTreeFactor([v0, v1], "0.25 1.5 0.75 3") + self.gtsamAssertEquals(DecisionTreeFactor(prior) * f1, expected) + self.gtsamAssertEquals(f1 * prior, expected) + + # Multiply two factors + f2 = DecisionTreeFactor([v1, v2], "5 6 7 8") + actual = f1 * f2 + expected2 = DecisionTreeFactor([v0, v1, v2], "5 6 14 16 15 18 28 32") + self.gtsamAssertEquals(actual, expected2) + + def test_methods(self): + """Test whether we can call methods in python.""" + # double operator()(const DiscreteValues& values) const; + values = DiscreteValues() + values[self.A[0]] = 0 + values[self.B[0]] = 0 + self.assertIsInstance(self.factor(values), float) + + # size_t cardinality(Key j) const; + self.assertIsInstance(self.factor.cardinality(self.A[0]), int) + + # DecisionTreeFactor operator/(const DecisionTreeFactor& f) const; + self.assertIsInstance(self.factor / self.factor, DecisionTreeFactor) + + # DecisionTreeFactor* sum(size_t nrFrontals) const; + self.assertIsInstance(self.factor.sum(1), DecisionTreeFactor) + + # DecisionTreeFactor* sum(const Ordering& keys) const; + ordering = Ordering() + ordering.push_back(self.A[0]) + self.assertIsInstance(self.factor.sum(ordering), DecisionTreeFactor) + + # DecisionTreeFactor* max(size_t nrFrontals) const; + self.assertIsInstance(self.factor.max(1), DecisionTreeFactor) + + def test_markdown(self): + """Test whether the _repr_markdown_ method.""" + + expected = \ + "|A|B|value|\n" \ + "|:-:|:-:|:-:|\n" \ + "|0|0|1|\n" \ + "|0|1|2|\n" \ + "|1|0|3|\n" \ + "|1|1|4|\n" \ + "|2|0|5|\n" \ + "|2|1|6|\n" + + def formatter(x: int): + return "A" if x == 12 else "B" + + actual = self.factor._repr_markdown_(formatter) + self.assertEqual(actual, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/gtsam/tests/test_DiscreteBayesNet.py b/python/gtsam/tests/test_DiscreteBayesNet.py new file mode 100644 index 000000000..74191dcc7 --- /dev/null +++ b/python/gtsam/tests/test_DiscreteBayesNet.py @@ -0,0 +1,166 @@ +""" +GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Discrete Bayes Nets. +Author: Frank Dellaert +""" + +# pylint: disable=no-name-in-module, invalid-name + +import unittest +import textwrap + +import gtsam +from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph, + DiscreteKeys, DiscreteDistribution, DiscreteValues, Ordering) +from gtsam.utils.test_case import GtsamTestCase + +# Some keys: +Asia = (0, 2) +Smoking = (4, 2) +Tuberculosis = (3, 2) +LungCancer = (6, 2) + +Bronchitis = (7, 2) +Either = (5, 2) +XRay = (2, 2) +Dyspnea = (1, 2) + + +class TestDiscreteBayesNet(GtsamTestCase): + """Tests for Discrete Bayes Nets.""" + + def test_constructor(self): + """Test constructing a Bayes net.""" + + bayesNet = DiscreteBayesNet() + Parent, Child = (0, 2), (1, 2) + empty = DiscreteKeys() + prior = DiscreteConditional(Parent, empty, "6/4") + bayesNet.add(prior) + + parents = DiscreteKeys() + parents.push_back(Parent) + conditional = DiscreteConditional(Child, parents, "7/3 8/2") + bayesNet.add(conditional) + + # Check conversion to factor graph: + fg = DiscreteFactorGraph(bayesNet) + self.assertEqual(fg.size(), 2) + self.assertEqual(fg.at(1).size(), 2) + + def test_Asia(self): + """Test full Asia example.""" + + asia = DiscreteBayesNet() + asia.add(Asia, "99/1") + asia.add(Smoking, "50/50") + + asia.add(Tuberculosis, [Asia], "99/1 95/5") + asia.add(LungCancer, [Smoking], "99/1 90/10") + asia.add(Bronchitis, [Smoking], "70/30 40/60") + + asia.add(Either, [Tuberculosis, LungCancer], "F T T T") + + asia.add(XRay, [Either], "95/5 2/98") + asia.add(Dyspnea, [Either, Bronchitis], "9/1 2/8 3/7 1/9") + + # Convert to factor graph + fg = DiscreteFactorGraph(asia) + + # Create solver and eliminate + ordering = Ordering() + for j in range(8): + ordering.push_back(j) + chordal = fg.eliminateSequential(ordering) + expected2 = DiscreteDistribution(Bronchitis, "11/9") + self.gtsamAssertEquals(chordal.at(7), expected2) + + # solve + actualMPE = fg.optimize() + expectedMPE = DiscreteValues() + for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]: + expectedMPE[key[0]] = 0 + self.assertEqual(list(actualMPE.items()), + list(expectedMPE.items())) + + # Check value for MPE is the same + self.assertAlmostEqual(asia(actualMPE), fg(actualMPE)) + + # add evidence, we were in Asia and we have dyspnea + fg.add(Asia, "0 1") + fg.add(Dyspnea, "0 1") + + # solve again, now with evidence + actualMPE2 = fg.optimize() + expectedMPE2 = DiscreteValues() + for key in [XRay, Tuberculosis, Either, LungCancer]: + expectedMPE2[key[0]] = 0 + for key in [Asia, Dyspnea, Smoking, Bronchitis]: + expectedMPE2[key[0]] = 1 + self.assertEqual(list(actualMPE2.items()), + list(expectedMPE2.items())) + + # now sample from it + chordal2 = fg.eliminateSequential(ordering) + actualSample = chordal2.sample() + self.assertEqual(len(actualSample), 8) + + def test_fragment(self): + """Test sampling and optimizing for Asia fragment.""" + + # Create a reverse-topologically sorted fragment: + fragment = DiscreteBayesNet() + fragment.add(Either, [Tuberculosis, LungCancer], "F T T T") + fragment.add(Tuberculosis, [Asia], "99/1 95/5") + fragment.add(LungCancer, [Smoking], "99/1 90/10") + + # Create assignment with missing values: + given = DiscreteValues() + for key in [Asia, Smoking]: + given[key[0]] = 0 + + # Now sample from fragment: + actual = fragment.sample(given) + self.assertEqual(len(actual), 5) + + def test_dot(self): + """Check that dot works with position hints.""" + fragment = DiscreteBayesNet() + fragment.add(Either, [Tuberculosis, LungCancer], "F T T T") + MyAsia = gtsam.symbol('a', 0), 2 # use a symbol! + fragment.add(Tuberculosis, [MyAsia], "99/1 95/5") + fragment.add(LungCancer, [Smoking], "99/1 90/10") + + # Make sure we can *update* position hints + writer = gtsam.DotWriter() + ph: dict = writer.positionHints + ph.update({'a': 2}) # hint at symbol position + writer.positionHints = ph + + # Check the output of dot + actual = fragment.dot(writer=writer) + expected_result = """\ + digraph { + size="5,5"; + + var3[label="3"]; + var4[label="4"]; + var5[label="5"]; + var6[label="6"]; + vara0[label="a0", pos="0,2!"]; + + var4->var6 + vara0->var3 + var3->var5 + var6->var5 + }""" + self.assertEqual(actual, textwrap.dedent(expected_result)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/gtsam/tests/test_DiscreteBayesTree.dot b/python/gtsam/tests/test_DiscreteBayesTree.dot new file mode 100644 index 000000000..d7cf7d9bc --- /dev/null +++ b/python/gtsam/tests/test_DiscreteBayesTree.dot @@ -0,0 +1,25 @@ +digraph G{ +0[label="8,12,14"]; +0->1 +1[label="0 : 8,12"]; +0->2 +2[label="1 : 8,12"]; +0->3 +3[label="9 : 12,14"]; +3->4 +4[label="2 : 9,12"]; +3->5 +5[label="3 : 9,12"]; +0->6 +6[label="10,13 : 14"]; +6->7 +7[label="4 : 10,13"]; +6->8 +8[label="5 : 10,13"]; +6->9 +9[label="11 : 13,14"]; +9->10 +10[label="6 : 11,13"]; +9->11 +11[label="7 : 11,13"]; +} \ No newline at end of file diff --git a/python/gtsam/tests/test_DiscreteBayesTree.py b/python/gtsam/tests/test_DiscreteBayesTree.py new file mode 100644 index 000000000..b1ed4fe69 --- /dev/null +++ b/python/gtsam/tests/test_DiscreteBayesTree.py @@ -0,0 +1,79 @@ +""" +GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Discrete Bayes trees. +Author: Frank Dellaert +""" + +# pylint: disable=no-name-in-module, invalid-name + +import unittest + +from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique, + DiscreteConditional, DiscreteFactorGraph, Ordering) +from gtsam.utils.test_case import GtsamTestCase + + +class TestDiscreteBayesNet(GtsamTestCase): + """Tests for Discrete Bayes Nets.""" + + def test_elimination(self): + """Test Multifrontal elimination.""" + + # Define DiscreteKey pairs. + keys = [(j, 2) for j in range(15)] + + # Create thin-tree Bayesnet. + bayesNet = DiscreteBayesNet() + + bayesNet.add(keys[0], [keys[8], keys[12]], "2/3 1/4 3/2 4/1") + bayesNet.add(keys[1], [keys[8], keys[12]], "4/1 2/3 3/2 1/4") + bayesNet.add(keys[2], [keys[9], keys[12]], "1/4 8/2 2/3 4/1") + bayesNet.add(keys[3], [keys[9], keys[12]], "1/4 2/3 3/2 4/1") + + bayesNet.add(keys[4], [keys[10], keys[13]], "2/3 1/4 3/2 4/1") + bayesNet.add(keys[5], [keys[10], keys[13]], "4/1 2/3 3/2 1/4") + bayesNet.add(keys[6], [keys[11], keys[13]], "1/4 3/2 2/3 4/1") + bayesNet.add(keys[7], [keys[11], keys[13]], "1/4 2/3 3/2 4/1") + + bayesNet.add(keys[8], [keys[12], keys[14]], "T 1/4 3/2 4/1") + bayesNet.add(keys[9], [keys[12], keys[14]], "4/1 2/3 F 1/4") + bayesNet.add(keys[10], [keys[13], keys[14]], "1/4 3/2 2/3 4/1") + bayesNet.add(keys[11], [keys[13], keys[14]], "1/4 2/3 3/2 4/1") + + bayesNet.add(keys[12], [keys[14]], "3/1 3/1") + bayesNet.add(keys[13], [keys[14]], "1/3 3/1") + + bayesNet.add(keys[14], "1/3") + + # Create a factor graph out of the Bayes net. + factorGraph = DiscreteFactorGraph(bayesNet) + + # Create a BayesTree out of the factor graph. + ordering = Ordering() + for j in range(15): + ordering.push_back(j) + bayesTree = factorGraph.eliminateMultifrontal(ordering) + + # Uncomment these for visualization: + # print(bayesTree) + # for key in range(15): + # bayesTree[key].printSignature() + # bayesTree.saveGraph("test_DiscreteBayesTree.dot") + + self.assertFalse(bayesTree.empty()) + self.assertEqual(12, bayesTree.size()) + + # The root is P( 8 12 14), we can retrieve it by key: + root = bayesTree[8] + self.assertIsInstance(root, DiscreteBayesTreeClique) + self.assertTrue(root.isRoot()) + self.assertIsInstance(root.conditional(), DiscreteConditional) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/gtsam/tests/test_DiscreteConditional.py b/python/gtsam/tests/test_DiscreteConditional.py new file mode 100644 index 000000000..241a5f0be --- /dev/null +++ b/python/gtsam/tests/test_DiscreteConditional.py @@ -0,0 +1,124 @@ +""" +GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Discrete Conditionals. +Author: Varun Agrawal +""" + +# pylint: disable=no-name-in-module, invalid-name + +import unittest + +from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys +from gtsam.utils.test_case import GtsamTestCase + +# Some DiscreteKeys for binary variables: +A = 0, 2 +B = 1, 2 +C = 2, 2 +D = 4, 2 +E = 3, 2 + + +class TestDiscreteConditional(GtsamTestCase): + """Tests for Discrete Conditionals.""" + + def test_single_value_versions(self): + X = (0, 2) + Y = (1, 3) + conditional = DiscreteConditional(X, [Y], "2/8 4/6 5/5") + + actual0 = conditional.likelihood(0) + expected0 = DecisionTreeFactor(Y, "0.2 0.4 0.5") + self.gtsamAssertEquals(actual0, expected0, 1e-9) + + actual1 = conditional.likelihood(1) + expected1 = DecisionTreeFactor(Y, "0.8 0.6 0.5") + self.gtsamAssertEquals(actual1, expected1, 1e-9) + + actual = conditional.sample(2) + self.assertIsInstance(actual, int) + + def test_multiply(self): + """Check calculation of joint P(A,B)""" + conditional = DiscreteConditional(A, [B], "1/2 2/1") + prior = DiscreteConditional(B, "1/2") + + # P(A,B) = P(A|B) * P(B) = P(B) * P(A|B) + for actual in [prior * conditional, conditional * prior]: + self.assertEqual(2, actual.nrFrontals()) + for v, value in actual.enumerate(): + self.assertAlmostEqual(actual(v), conditional(v) * prior(v)) + + def test_multiply2(self): + """Check calculation of conditional joint P(A,B|C)""" + A_given_B = DiscreteConditional(A, [B], "1/3 3/1") + B_given_C = DiscreteConditional(B, [C], "1/3 3/1") + + # P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B) + for actual in [A_given_B * B_given_C, B_given_C * A_given_B]: + self.assertEqual(2, actual.nrFrontals()) + self.assertEqual(1, actual.nrParents()) + for v, value in actual.enumerate(): + self.assertAlmostEqual(actual(v), A_given_B(v) * B_given_C(v)) + + def test_multiply4(self): + """Check calculation of joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E)""" + A_given_B = DiscreteConditional(A, [B], "1/3 3/1") + B_given_D = DiscreteConditional(B, [D], "1/3 3/1") + AB_given_D = A_given_B * B_given_D + C_given_DE = DiscreteConditional(C, [D, E], "4/1 1/1 1/1 1/4") + + # P(A,B,C|D,E) = P(A,B|D) P(C|D,E) = P(C|D,E) P(A,B|D) + for actual in [AB_given_D * C_given_DE, C_given_DE * AB_given_D]: + self.assertEqual(3, actual.nrFrontals()) + self.assertEqual(2, actual.nrParents()) + for v, value in actual.enumerate(): + self.assertAlmostEqual( + actual(v), AB_given_D(v) * C_given_DE(v)) + + def test_marginals(self): + conditional = DiscreteConditional(A, [B], "1/2 2/1") + prior = DiscreteConditional(B, "1/2") + pAB = prior * conditional + self.gtsamAssertEquals(prior, pAB.marginal(B[0])) + + pA = DiscreteConditional(A, "5/4") + self.gtsamAssertEquals(pA, pAB.marginal(A[0])) + + def test_markdown(self): + """Test whether the _repr_markdown_ method.""" + + A = (2, 2) + B = (1, 2) + C = (0, 3) + parents = DiscreteKeys() + parents.push_back(B) + parents.push_back(C) + + conditional = DiscreteConditional(A, parents, + "0/1 1/3 1/1 3/1 0/1 1/0") + expected = " *P(A|B,C):*\n\n" \ + "|*B*|*C*|0|1|\n" \ + "|:-:|:-:|:-:|:-:|\n" \ + "|0|0|0|1|\n" \ + "|0|1|0.25|0.75|\n" \ + "|0|2|0.5|0.5|\n" \ + "|1|0|0.75|0.25|\n" \ + "|1|1|0|1|\n" \ + "|1|2|1|0|\n" + + def formatter(x: int): + names = ["C", "B", "A"] + return names[x] + + actual = conditional._repr_markdown_(formatter) + self.assertEqual(actual, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/gtsam/tests/test_DiscreteDistribution.py b/python/gtsam/tests/test_DiscreteDistribution.py new file mode 100644 index 000000000..3986bf2df --- /dev/null +++ b/python/gtsam/tests/test_DiscreteDistribution.py @@ -0,0 +1,69 @@ +""" +GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Discrete Priors. +Author: Frank Dellaert +""" + +# pylint: disable=no-name-in-module, invalid-name + +import unittest + +import numpy as np +from gtsam import DecisionTreeFactor, DiscreteKeys, DiscreteDistribution +from gtsam.utils.test_case import GtsamTestCase + +X = 0, 2 + + +class TestDiscreteDistribution(GtsamTestCase): + """Tests for Discrete Priors.""" + + def test_constructor(self): + """Test various constructors.""" + keys = DiscreteKeys() + keys.push_back(X) + f = DecisionTreeFactor(keys, "0.4 0.6") + expected = DiscreteDistribution(f) + + actual = DiscreteDistribution(X, "2/3") + self.gtsamAssertEquals(actual, expected) + + actual2 = DiscreteDistribution(X, [0.4, 0.6]) + self.gtsamAssertEquals(actual2, expected) + + def test_operator(self): + prior = DiscreteDistribution(X, "2/3") + self.assertAlmostEqual(prior(0), 0.4) + self.assertAlmostEqual(prior(1), 0.6) + + def test_pmf(self): + prior = DiscreteDistribution(X, "2/3") + expected = np.array([0.4, 0.6]) + np.testing.assert_allclose(expected, prior.pmf()) + + def test_sample(self): + prior = DiscreteDistribution(X, "2/3") + actual = prior.sample() + self.assertIsInstance(actual, int) + + def test_markdown(self): + """Test the _repr_markdown_ method.""" + + prior = DiscreteDistribution(X, "2/3") + expected = " *P(0):*\n\n" \ + "|0|value|\n" \ + "|:-:|:-:|\n" \ + "|0|0.4|\n" \ + "|1|0.6|\n" \ + + actual = prior._repr_markdown_() + self.assertEqual(actual, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/gtsam/tests/test_DiscreteFactorGraph.py b/python/gtsam/tests/test_DiscreteFactorGraph.py new file mode 100644 index 000000000..ef85fc753 --- /dev/null +++ b/python/gtsam/tests/test_DiscreteFactorGraph.py @@ -0,0 +1,160 @@ +""" +GTSAM Copyright 2010-2019, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Discrete Factor Graphs. +Author: Frank Dellaert +""" + +# pylint: disable=no-name-in-module, invalid-name + +import unittest + +from gtsam import DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering +from gtsam.utils.test_case import GtsamTestCase + +OrderingType = Ordering.OrderingType + + +class TestDiscreteFactorGraph(GtsamTestCase): + """Tests for Discrete Factor Graphs.""" + + def test_evaluation(self): + """Test constructing and evaluating a discrete factor graph.""" + + # Three keys + P1 = (0, 2) + P2 = (1, 2) + P3 = (2, 3) + + # Create the DiscreteFactorGraph + graph = DiscreteFactorGraph() + + # Add two unary factors (priors) + graph.add(P1, [0.9, 0.3]) + graph.add(P2, "0.9 0.6") + + # Add a binary factor + graph.add([P1, P2], "4 1 10 4") + + # Instantiate Values + assignment = DiscreteValues() + assignment[0] = 1 + assignment[1] = 1 + + # Check if graph evaluation works ( 0.3*0.6*4 ) + self.assertAlmostEqual(.72, graph(assignment)) + + # Create a new test with third node and adding unary and ternary factor + graph.add(P3, "0.9 0.2 0.5") + keys = DiscreteKeys() + keys.push_back(P1) + keys.push_back(P2) + keys.push_back(P3) + graph.add(keys, "1 2 3 4 5 6 7 8 9 10 11 12") + + # Below assignment selects the 8th index in the ternary factor table + assignment[0] = 1 + assignment[1] = 0 + assignment[2] = 1 + + # Check if graph evaluation works (0.3*0.9*1*0.2*8) + self.assertAlmostEqual(4.32, graph(assignment)) + + # Below assignment selects the 3rd index in the ternary factor table + assignment[0] = 0 + assignment[1] = 1 + assignment[2] = 0 + + # Check if graph evaluation works (0.9*0.6*1*0.9*4) + self.assertAlmostEqual(1.944, graph(assignment)) + + # Check if graph product works + product = graph.product() + self.assertAlmostEqual(1.944, product(assignment)) + + def test_optimize(self): + """Test constructing and optizing a discrete factor graph.""" + + # Three keys + C = (0, 2) + B = (1, 2) + A = (2, 2) + + # A simple factor graph (A)-fAC-(C)-fBC-(B) + # with smoothness priors + graph = DiscreteFactorGraph() + graph.add([A, C], "3 1 1 3") + graph.add([C, B], "3 1 1 3") + + # Test optimization + expectedValues = DiscreteValues() + expectedValues[0] = 0 + expectedValues[1] = 0 + expectedValues[2] = 0 + actualValues = graph.optimize() + self.assertEqual(list(actualValues.items()), + list(expectedValues.items())) + + def test_MPE(self): + """Test maximum probable explanation (MPE): same as optimize.""" + + # Declare a bunch of keys + C, A, B = (0, 2), (1, 2), (2, 2) + + # Create Factor graph + graph = DiscreteFactorGraph() + graph.add([C, A], "0.2 0.8 0.3 0.7") + graph.add([C, B], "0.1 0.9 0.4 0.6") + + # We know MPE + mpe = DiscreteValues() + mpe[0] = 0 + mpe[1] = 1 + mpe[2] = 1 + + # Use maxProduct + dag = graph.maxProduct(OrderingType.COLAMD) + actualMPE = dag.argmax() + self.assertEqual(list(actualMPE.items()), + list(mpe.items())) + + # All in one + actualMPE2 = graph.optimize() + self.assertEqual(list(actualMPE2.items()), + list(mpe.items())) + + def test_sumProduct(self): + """Test sumProduct.""" + + # Declare a bunch of keys + C, A, B = (0, 2), (1, 2), (2, 2) + + # Create Factor graph + graph = DiscreteFactorGraph() + graph.add([C, A], "0.2 0.8 0.3 0.7") + graph.add([C, B], "0.1 0.9 0.4 0.6") + + # We know MPE + mpe = DiscreteValues() + mpe[0] = 0 + mpe[1] = 1 + mpe[2] = 1 + + # Use default sumProduct + bayesNet = graph.sumProduct() + mpeProbability = bayesNet(mpe) + self.assertAlmostEqual(mpeProbability, 0.36) # regression + + # Use sumProduct + for ordering_type in [OrderingType.COLAMD, OrderingType.METIS, OrderingType.NATURAL, + OrderingType.CUSTOM]: + bayesNet = graph.sumProduct(ordering_type) + self.assertEqual(bayesNet(mpe), mpeProbability) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/gtsam/tests/test_GaussianBayesNet.py b/python/gtsam/tests/test_GaussianBayesNet.py new file mode 100644 index 000000000..e4d396cfe --- /dev/null +++ b/python/gtsam/tests/test_GaussianBayesNet.py @@ -0,0 +1,53 @@ +""" +GTSAM Copyright 2010-2019, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Gaussian Bayes Nets. +Author: Frank Dellaert +""" +# pylint: disable=invalid-name, no-name-in-module, no-member + +from __future__ import print_function + +import unittest + +import gtsam +import numpy as np +from gtsam import GaussianBayesNet, GaussianConditional +from gtsam.utils.test_case import GtsamTestCase + +# some keys +_x_ = 11 +_y_ = 22 +_z_ = 33 + + +def smallBayesNet(): + """Create a small Bayes Net for testing""" + bayesNet = GaussianBayesNet() + I_1x1 = np.eye(1, dtype=float) + bayesNet.push_back(GaussianConditional( + _x_, [9.0], I_1x1, _y_, I_1x1)) + bayesNet.push_back(GaussianConditional(_y_, [5.0], I_1x1)) + return bayesNet + + +class TestGaussianBayesNet(GtsamTestCase): + """Tests for Gaussian Bayes nets.""" + + def test_matrix(self): + """Test matrix method""" + R, d = smallBayesNet().matrix() # get matrix and RHS + R1 = np.array([ + [1.0, 1.0], + [0.0, 1.0]]) + d1 = np.array([9.0, 5.0]) + np.testing.assert_equal(R, R1) + np.testing.assert_equal(d, d1) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/gtsam/tests/test_GaussianFactorGraph.py b/python/gtsam/tests/test_GaussianFactorGraph.py index a29b0f263..09ac4c564 100644 --- a/python/gtsam/tests/test_GaussianFactorGraph.py +++ b/python/gtsam/tests/test_GaussianFactorGraph.py @@ -23,23 +23,23 @@ from gtsam.utils.test_case import GtsamTestCase def create_graph(): """Create a basic linear factor graph for testing""" graph = gtsam.GaussianFactorGraph() - + x0 = X(0) x1 = X(1) x2 = X(2) - + BETWEEN_NOISE = gtsam.noiseModel.Diagonal.Sigmas(np.ones(1)) PRIOR_NOISE = gtsam.noiseModel.Diagonal.Sigmas(np.ones(1)) graph.add(x1, np.eye(1), x0, -np.eye(1), np.ones(1), BETWEEN_NOISE) - graph.add(x2, np.eye(1), x1, -np.eye(1), 2*np.ones(1), BETWEEN_NOISE) + graph.add(x2, np.eye(1), x1, -np.eye(1), 2 * np.ones(1), BETWEEN_NOISE) graph.add(x0, np.eye(1), np.zeros(1), PRIOR_NOISE) return graph, (x0, x1, x2) + class TestGaussianFactorGraph(GtsamTestCase): """Tests for Gaussian Factor Graphs.""" - def test_fg(self): """Test solving a linear factor graph""" graph, X = create_graph() @@ -71,7 +71,7 @@ class TestGaussianFactorGraph(GtsamTestCase): self.assertAlmostEqual(EXPECTEDM[0], m[0], delta=1e-8) self.assertAlmostEqual(EXPECTEDM[1], m[1], delta=1e-8) self.assertAlmostEqual(EXPECTEDM[2], m[2], delta=1e-8) - + def test_linearMarginalization(self): """Marginalize a linear factor graph""" graph, X = create_graph() @@ -88,5 +88,23 @@ class TestGaussianFactorGraph(GtsamTestCase): self.assertAlmostEqual(EXPECTEDM[1], m[1], delta=1e-8) self.assertAlmostEqual(EXPECTEDM[2], m[2], delta=1e-8) + def test_ordering(self): + """Test ordering""" + gfg, keys = create_graph() + ordering = gtsam.Ordering() + for key in keys[::-1]: + ordering.push_back(key) + + bn = gfg.eliminateSequential(ordering) + self.assertEqual(bn.size(), 3) + + keyVector = gtsam.KeyVector() + keyVector.append(keys[2]) + #TODO(Varun) Below code causes segfault in Debug config + # ordering = gtsam.Ordering.ColamdConstrainedLastGaussianFactorGraph(gfg, keyVector) + # bn = gfg.eliminateSequential(ordering) + # self.assertEqual(bn.size(), 3) + + if __name__ == '__main__': unittest.main() diff --git a/python/gtsam/tests/test_GraphvizFormatting.py b/python/gtsam/tests/test_GraphvizFormatting.py new file mode 100644 index 000000000..ecdc23b45 --- /dev/null +++ b/python/gtsam/tests/test_GraphvizFormatting.py @@ -0,0 +1,135 @@ +""" +See LICENSE for the license information + +Unit tests for Graphviz formatting of NonlinearFactorGraph. +Author: senselessDev (contact by mentioning on GitHub, e.g. in PR#1059) +""" + +# pylint: disable=no-member, invalid-name + +import unittest +import textwrap + +import numpy as np + +import gtsam +from gtsam.utils.test_case import GtsamTestCase + + +class TestGraphvizFormatting(GtsamTestCase): + """Tests for saving NonlinearFactorGraph to GraphViz format.""" + + def setUp(self): + self.graph = gtsam.NonlinearFactorGraph() + + odometry = gtsam.Pose2(2.0, 0.0, 0.0) + odometryNoise = gtsam.noiseModel.Diagonal.Sigmas( + np.array([0.2, 0.2, 0.1])) + self.graph.add(gtsam.BetweenFactorPose2(0, 1, odometry, odometryNoise)) + self.graph.add(gtsam.BetweenFactorPose2(1, 2, odometry, odometryNoise)) + + self.values = gtsam.Values() + self.values.insert_pose2(0, gtsam.Pose2(0., 0., 0.)) + self.values.insert_pose2(1, gtsam.Pose2(2., 0., 0.)) + self.values.insert_pose2(2, gtsam.Pose2(4., 0., 0.)) + + def test_default(self): + """Test with default GraphvizFormatting""" + expected_result = """\ + graph { + size="5,5"; + + var0[label="0", pos="0,0!"]; + var1[label="1", pos="0,2!"]; + var2[label="2", pos="0,4!"]; + + factor0[label="", shape=point]; + var0--factor0; + var1--factor0; + factor1[label="", shape=point]; + var1--factor1; + var2--factor1; + } + """ + + self.assertEqual(self.graph.dot(self.values), + textwrap.dedent(expected_result)) + + def test_swapped_axes(self): + """Test with user-defined GraphvizFormatting swapping x and y""" + expected_result = """\ + graph { + size="5,5"; + + var0[label="0", pos="0,0!"]; + var1[label="1", pos="2,0!"]; + var2[label="2", pos="4,0!"]; + + factor0[label="", shape=point]; + var0--factor0; + var1--factor0; + factor1[label="", shape=point]; + var1--factor1; + var2--factor1; + } + """ + + graphviz_formatting = gtsam.GraphvizFormatting() + graphviz_formatting.paperHorizontalAxis = gtsam.GraphvizFormatting.Axis.X + graphviz_formatting.paperVerticalAxis = gtsam.GraphvizFormatting.Axis.Y + self.assertEqual(self.graph.dot(self.values, + writer=graphviz_formatting), + textwrap.dedent(expected_result)) + + def test_factor_points(self): + """Test with user-defined GraphvizFormatting without factor points""" + expected_result = """\ + graph { + size="5,5"; + + var0[label="0", pos="0,0!"]; + var1[label="1", pos="0,2!"]; + var2[label="2", pos="0,4!"]; + + var0--var1; + var1--var2; + } + """ + + graphviz_formatting = gtsam.GraphvizFormatting() + graphviz_formatting.plotFactorPoints = False + + self.assertEqual(self.graph.dot(self.values, + writer=graphviz_formatting), + textwrap.dedent(expected_result)) + + def test_width_height(self): + """Test with user-defined GraphvizFormatting for width and height""" + expected_result = """\ + graph { + size="20,10"; + + var0[label="0", pos="0,0!"]; + var1[label="1", pos="0,2!"]; + var2[label="2", pos="0,4!"]; + + factor0[label="", shape=point]; + var0--factor0; + var1--factor0; + factor1[label="", shape=point]; + var1--factor1; + var2--factor1; + } + """ + + graphviz_formatting = gtsam.GraphvizFormatting() + graphviz_formatting.figureWidthInches = 20 + graphviz_formatting.figureHeightInches = 10 + + self.assertEqual(self.graph.dot(self.values, + writer=graphviz_formatting), + textwrap.dedent(expected_result)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/gtsam/tests/test_KarcherMeanFactor.py b/python/gtsam/tests/test_KarcherMeanFactor.py index a315a506c..f4ec64283 100644 --- a/python/gtsam/tests/test_KarcherMeanFactor.py +++ b/python/gtsam/tests/test_KarcherMeanFactor.py @@ -15,27 +15,15 @@ import unittest import gtsam import numpy as np +from gtsam import Rot3 from gtsam.utils.test_case import GtsamTestCase KEY = 0 MODEL = gtsam.noiseModel.Unit.Create(3) -def find_Karcher_mean_Rot3(rotations): - """Find the Karcher mean of given values.""" - # Cost function C(R) = \sum PriorFactor(R_i)::error(R) - # No closed form solution. - graph = gtsam.NonlinearFactorGraph() - for R in rotations: - graph.add(gtsam.PriorFactorRot3(KEY, R, MODEL)) - initial = gtsam.Values() - initial.insert(KEY, gtsam.Rot3()) - result = gtsam.GaussNewtonOptimizer(graph, initial).optimize() - return result.atRot3(KEY) - - # Rot3 version -R = gtsam.Rot3.Expmap(np.array([0.1, 0, 0])) +R = Rot3.Expmap(np.array([0.1, 0, 0])) class TestKarcherMean(GtsamTestCase): @@ -43,11 +31,23 @@ class TestKarcherMean(GtsamTestCase): def test_find(self): # Check that optimizing for Karcher mean (which minimizes Between distance) # gets correct result. - rotations = {R, R.inverse()} - expected = gtsam.Rot3() - actual = find_Karcher_mean_Rot3(rotations) + rotations = gtsam.Rot3Vector([R, R.inverse()]) + expected = Rot3() + actual = gtsam.FindKarcherMean(rotations) self.gtsamAssertEquals(expected, actual) + def test_find_karcher_mean_identity(self): + """Averaging 3 identity rotations should yield the identity.""" + a1Rb1 = Rot3() + a2Rb2 = Rot3() + a3Rb3 = Rot3() + + aRb_list = gtsam.Rot3Vector([a1Rb1, a2Rb2, a3Rb3]) + aRb_expected = Rot3() + + aRb = gtsam.FindKarcherMean(aRb_list) + self.gtsamAssertEquals(aRb, aRb_expected) + def test_factor(self): """Check that the InnerConstraint factor leaves the mean unchanged.""" # Make a graph with two variables, one between, and one InnerConstraint @@ -66,11 +66,11 @@ class TestKarcherMean(GtsamTestCase): initial = gtsam.Values() initial.insert(1, R.inverse()) initial.insert(2, R) - expected = find_Karcher_mean_Rot3([R, R.inverse()]) + expected = Rot3() result = gtsam.GaussNewtonOptimizer(graph, initial).optimize() - actual = find_Karcher_mean_Rot3( - [result.atRot3(1), result.atRot3(2)]) + actual = gtsam.FindKarcherMean( + gtsam.Rot3Vector([result.atRot3(1), result.atRot3(2)])) self.gtsamAssertEquals(expected, actual) self.gtsamAssertEquals( R12, result.atRot3(1).between(result.atRot3(2))) diff --git a/python/gtsam/tests/test_PinholeCamera.py b/python/gtsam/tests/test_PinholeCamera.py new file mode 100644 index 000000000..392d48d3f --- /dev/null +++ b/python/gtsam/tests/test_PinholeCamera.py @@ -0,0 +1,46 @@ +""" +GTSAM Copyright 2010-2019, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +PinholeCamera unit tests. +Author: Fan Jiang +""" +import unittest +from math import pi + +import numpy as np + +import gtsam +from gtsam.utils.test_case import GtsamTestCase + + +class TestPinholeCamera(GtsamTestCase): + """ + Tests if we can correctly get the camera Jacobians in Python + """ + def test_jacobian(self): + cam1 = gtsam.PinholeCameraCal3Bundler() + + # order is important because Eigen is column major! + Dpose = np.zeros((2, 6), order='F') + Dpoint = np.zeros((2, 3), order='F') + Dcal = np.zeros((2, 3), order='F') + cam1.project(np.array([1, 1, 1]), Dpose, Dpoint, Dcal) + + self.gtsamAssertEquals(Dpoint, np.array([[1, 0, -1], [0, 1, -1]])) + + self.gtsamAssertEquals( + Dpose, + np.array([ + [1., -2., 1., -1., 0., 1.], # + [2., -1., -1., 0., -1., 1.] + ])) + + self.gtsamAssertEquals(Dcal, np.array([[1., 2., 4.], [1., 2., 4.]])) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/gtsam/tests/test_Pose2.py b/python/gtsam/tests/test_Pose2.py index e5ffbad7d..d3a51d638 100644 --- a/python/gtsam/tests/test_Pose2.py +++ b/python/gtsam/tests/test_Pose2.py @@ -8,11 +8,11 @@ See LICENSE for the license information Pose2 unit tests. Author: Frank Dellaert & Duy Nguyen Ta & John Lambert """ +import math import unittest -import numpy as np - import gtsam +import numpy as np from gtsam import Point2, Point2Pairs, Pose2 from gtsam.utils.test_case import GtsamTestCase @@ -26,6 +26,34 @@ class TestPose2(GtsamTestCase): actual = Pose2.adjoint_(xi, xi) np.testing.assert_array_equal(actual, expected) + def test_transformTo(self): + """Test transformTo method.""" + pose = Pose2(2, 4, -math.pi/2) + actual = pose.transformTo(Point2(3, 2)) + expected = Point2(2, 1) + self.gtsamAssertEquals(actual, expected, 1e-6) + + # multi-point version + points = np.stack([Point2(3, 2), Point2(3, 2)]).T + actual_array = pose.transformTo(points) + self.assertEqual(actual_array.shape, (2, 2)) + expected_array = np.stack([expected, expected]).T + np.testing.assert_allclose(actual_array, expected_array, atol=1e-6) + + def test_transformFrom(self): + """Test transformFrom method.""" + pose = Pose2(2, 4, -math.pi/2) + actual = pose.transformFrom(Point2(2, 1)) + expected = Point2(3, 2) + self.gtsamAssertEquals(actual, expected, 1e-6) + + # multi-point version + points = np.stack([Point2(2, 1), Point2(2, 1)]).T + actual_array = pose.transformFrom(points) + self.assertEqual(actual_array.shape, (2, 2)) + expected_array = np.stack([expected, expected]).T + np.testing.assert_allclose(actual_array, expected_array, atol=1e-6) + def test_align(self) -> None: """Ensure estimation of the Pose2 element to align two 2d point clouds succeeds. @@ -42,27 +70,36 @@ class TestPose2(GtsamTestCase): O---O """ pts_a = [ - Point2(3, 1), - Point2(1, 1), - Point2(1, 3), - Point2(3, 3), - ] - pts_b = [ Point2(1, -3), Point2(1, -5), Point2(-1, -5), Point2(-1, -3), ] + pts_b = [ + Point2(3, 1), + Point2(1, 1), + Point2(1, 3), + Point2(3, 3), + ] # fmt: on ab_pairs = Point2Pairs(list(zip(pts_a, pts_b))) - bTa = gtsam.align(ab_pairs) - aTb = bTa.inverse() - assert aTb is not None + aTb = Pose2.Align(ab_pairs) + self.assertIsNotNone(aTb) for pt_a, pt_b in zip(pts_a, pts_b): pt_a_ = aTb.transformFrom(pt_b) - assert np.allclose(pt_a, pt_a_) + np.testing.assert_allclose(pt_a, pt_a_) + + # Matrix version + A = np.array(pts_a).T + B = np.array(pts_b).T + aTb = Pose2.Align(A, B) + self.assertIsNotNone(aTb) + + for pt_a, pt_b in zip(pts_a, pts_b): + pt_a_ = aTb.transformFrom(pt_b) + np.testing.assert_allclose(pt_a, pt_a_) if __name__ == "__main__": diff --git a/python/gtsam/tests/test_Pose3.py b/python/gtsam/tests/test_Pose3.py index e07b904a9..cde71de53 100644 --- a/python/gtsam/tests/test_Pose3.py +++ b/python/gtsam/tests/test_Pose3.py @@ -15,7 +15,7 @@ import unittest import numpy as np import gtsam -from gtsam import Point3, Pose3, Rot3 +from gtsam import Point3, Pose3, Rot3, Point3Pairs from gtsam.utils.test_case import GtsamTestCase @@ -30,13 +30,34 @@ class TestPose3(GtsamTestCase): actual = T2.between(T3) self.gtsamAssertEquals(actual, expected, 1e-6) - def test_transform_to(self): + def test_transformTo(self): """Test transformTo method.""" - transform = Pose3(Rot3.Rodrigues(0, 0, -1.570796), Point3(2, 4, 0)) - actual = transform.transformTo(Point3(3, 2, 10)) + pose = Pose3(Rot3.Rodrigues(0, 0, -math.pi/2), Point3(2, 4, 0)) + actual = pose.transformTo(Point3(3, 2, 10)) expected = Point3(2, 1, 10) self.gtsamAssertEquals(actual, expected, 1e-6) + # multi-point version + points = np.stack([Point3(3, 2, 10), Point3(3, 2, 10)]).T + actual_array = pose.transformTo(points) + self.assertEqual(actual_array.shape, (3, 2)) + expected_array = np.stack([expected, expected]).T + np.testing.assert_allclose(actual_array, expected_array, atol=1e-6) + + def test_transformFrom(self): + """Test transformFrom method.""" + pose = Pose3(Rot3.Rodrigues(0, 0, -math.pi/2), Point3(2, 4, 0)) + actual = pose.transformFrom(Point3(2, 1, 10)) + expected = Point3(3, 2, 10) + self.gtsamAssertEquals(actual, expected, 1e-6) + + # multi-point version + points = np.stack([Point3(2, 1, 10), Point3(2, 1, 10)]).T + actual_array = pose.transformFrom(points) + self.assertEqual(actual_array.shape, (3, 2)) + expected_array = np.stack([expected, expected]).T + np.testing.assert_allclose(actual_array, expected_array, atol=1e-6) + def test_range(self): """Test range method.""" l1 = Point3(1, 0, 0) @@ -59,8 +80,16 @@ class TestPose3(GtsamTestCase): self.assertEqual(math.sqrt(2.0), x1.range(pose=xl2)) def test_adjoint(self): - """Test adjoint method.""" + """Test adjoint methods.""" + T = Pose3() xi = np.array([1, 2, 3, 4, 5, 6]) + # test calling functions + T.AdjointMap() + T.Adjoint(xi) + T.AdjointTranspose(xi) + Pose3.adjointMap(xi) + Pose3.adjoint(xi, xi) + # test correctness of adjoint(x, y) expected = np.dot(Pose3.adjointMap_(xi), xi) actual = Pose3.adjoint_(xi, xi) np.testing.assert_array_equal(actual, expected) @@ -73,6 +102,24 @@ class TestPose3(GtsamTestCase): actual.deserialize(serialized) self.gtsamAssertEquals(expected, actual, 1e-10) + def test_align_squares(self): + """Test if Align method can align 2 squares.""" + square = np.array([[0,0,0],[0,1,0],[1,1,0],[1,0,0]], float).T + sTt = Pose3(Rot3.Rodrigues(0, 0, -math.pi), Point3(2, 4, 0)) + transformed = sTt.transformTo(square) + + st_pairs = Point3Pairs() + for j in range(4): + st_pairs.append((square[:,j], transformed[:,j])) + + # Recover the transformation sTt + estimated_sTt = Pose3.Align(st_pairs) + self.gtsamAssertEquals(estimated_sTt, sTt, 1e-10) + + # Matrix version + estimated_sTt = Pose3.Align(square, transformed) + self.gtsamAssertEquals(estimated_sTt, sTt, 1e-10) + if __name__ == "__main__": unittest.main() diff --git a/python/gtsam/tests/test_Rot3.py b/python/gtsam/tests/test_Rot3.py new file mode 100644 index 000000000..a1ce01ba2 --- /dev/null +++ b/python/gtsam/tests/test_Rot3.py @@ -0,0 +1,2037 @@ +""" +GTSAM Copyright 2010-2019, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved +See LICENSE for the license information +Rot3 unit tests. +Author: John Lambert +""" +# pylint: disable=no-name-in-module + +import unittest + +import numpy as np + +import gtsam +from gtsam import Rot3 +from gtsam.utils.test_case import GtsamTestCase + + +R1_R2_pairs = [ + ( + [ + [0.994283, -0.10356, 0.0260251], + [0.102811, 0.994289, 0.0286205], + [-0.0288404, -0.0257812, 0.999251], + ], + [ + [-0.994235, 0.0918291, -0.0553602], + [-0.0987317, -0.582632, 0.806718], + [0.0418251, 0.807532, 0.588339], + ], + ), + ( + [ + [0.999823, -0.000724729, 0.0187896], + [0.00220672, 0.996874, -0.0789728], + [-0.0186736, 0.0790003, 0.9967], + ], + [ + [-0.99946, -0.0155217, -0.0289749], + [-0.0306159, 0.760422, 0.648707], + [0.0119641, 0.649244, -0.760487], + ], + ), + ( + [ + [0.999976, 0.00455542, -0.00529608], + [-0.00579633, 0.964214, -0.265062], + [0.00389908, 0.265086, 0.964217], + ], + [ + [-0.999912, -0.0123323, -0.00489179], + [-0.00739095, 0.21159, 0.977331], + [-0.0110179, 0.977281, -0.211663], + ], + ), + ( + [ + [0.998801, 0.0449026, 0.019479], + [-0.0448727, 0.998991, -0.00197348], + [-0.0195479, 0.00109704, 0.999808], + ], + [ + [-0.999144, -0.0406154, -0.00800012], + [0.0406017, -0.999174, 0.00185875], + [-0.00806909, 0.00153352, 0.999966], + ], + ), + ( + [ + [0.587202, 0.00034062, -0.80944], + [0.394859, 0.872825, 0.286815], + [0.706597, -0.488034, 0.51239], + ], + [ + [-0.999565, -0.028095, -0.00905389], + [0.0192863, -0.853838, 0.520182], + [-0.0223455, 0.519782, 0.854007], + ], + ), + ( + [ + [0.998798, 0.0370584, 0.0320815], + [-0.0355966, 0.998353, -0.0449943], + [-0.033696, 0.0437982, 0.998472], + ], + [ + [-0.999942, -0.010745, -0.00132538], + [-0.000998705, -0.0304045, 0.999538], + [-0.0107807, 0.999481, 0.0303914], + ], + ), + ( + [ + [0.998755, 0.00708291, -0.0493744], + [-0.00742097, 0.99995, -0.00666709], + [0.0493247, 0.0070252, 0.998758], + ], + [ + [-0.998434, 0.0104672, 0.0549825], + [0.0115323, 0.999751, 0.0190859], + [-0.0547691, 0.01969, -0.998307], + ], + ), + ( + [ + [0.990471, 0.0997485, -0.0949595], + [-0.117924, 0.970427, -0.210631], + [0.0711411, 0.219822, 0.972943], + ], + [ + [-0.99192, -0.125627, 0.0177888], + [0.126478, -0.967866, 0.217348], + [-0.0100874, 0.217839, 0.975933], + ], + ), + ( + [ + [-0.780894, -0.578319, -0.236116], + [0.34478, -0.0838381, -0.934932], + [0.520894, -0.811491, 0.264862], + ], + [ + [-0.99345, 0.00261746, -0.114244], + [-0.112503, 0.152922, 0.981815], + [0.0200403, 0.988236, -0.151626], + ], + ), + ( + [ + [0.968425, 0.0466097, 0.244911], + [-0.218867, 0.629346, 0.745668], + [-0.119378, -0.775726, 0.619676], + ], + [ + [-0.971208, 0.00666431, -0.238143], + [0.0937886, 0.929584, -0.35648], + [0.218998, -0.368551, -0.903444], + ], + ), + ( + [ + [0.998512, 0.0449168, -0.0309146], + [-0.0344032, 0.958823, 0.281914], + [0.0423043, -0.280431, 0.958941], + ], + [ + [-0.999713, 0.00732431, 0.0228168], + [-0.00759688, 0.806166, -0.59164], + [-0.0227275, -0.591644, -0.805879], + ], + ), + ( + [ + [0.981814, 0.00930728, 0.189617], + [-0.0084101, 0.999949, -0.00553563], + [-0.189659, 0.00384026, 0.981843], + ], + [ + [-0.981359, 0.00722349, -0.192051], + [0.00148564, 0.999549, 0.0300036], + [0.192182, 0.0291591, -0.980927], + ], + ), + ( + [ + [0.972544, -0.215591, 0.0876242], + [0.220661, 0.973915, -0.0529018], + [-0.0739333, 0.0707846, 0.994748], + ], + [ + [-0.971294, 0.215675, -0.100371], + [-0.23035, -0.747337, 0.62324], + [0.0594069, 0.628469, 0.775564], + ], + ), + ( + [ + [0.989488, 0.0152447, 0.143808], + [-0.0160974, 0.999859, 0.00476753], + [-0.143715, -0.00703235, 0.989594], + ], + [ + [-0.988492, 0.0124362, -0.150766], + [0.00992423, 0.999799, 0.0174037], + [0.150952, 0.0157072, -0.988417], + ], + ), + ( + [ + [0.99026, 0.109934, -0.0854388], + [-0.0973012, 0.985345, 0.140096], + [0.099588, -0.130418, 0.986445], + ], + [ + [-0.994239, 0.0206112, 0.1052], + [0.0227944, 0.999548, 0.0195944], + [-0.104748, 0.0218794, -0.994259], + ], + ), + ( + [ + [0.988981, 0.132742, -0.0655406], + [-0.113134, 0.963226, 0.243712], + [0.0954813, -0.233612, 0.96763], + ], + [ + [-0.989473, -0.144453, 0.00888153], + [0.112318, -0.727754, 0.67658], + [-0.0912697, 0.670455, 0.736317], + ], + ), + ( + [ + [0.13315, -0.722685, 0.678231], + [0.255831, 0.686195, 0.680946], + [-0.957508, 0.0828446, 0.276252], + ], + [ + [-0.233019, 0.0127274, -0.97239], + [-0.0143824, 0.99976, 0.0165321], + [0.972367, 0.0178377, -0.23278], + ], + ), + ( + [ + [0.906305, -0.0179617, -0.422243], + [0.0246095, 0.999644, 0.0102984], + [0.421908, -0.0197247, 0.906424], + ], + [ + [-0.90393, 0.0136293, 0.427466], + [0.0169755, 0.999848, 0.0040176], + [-0.427346, 0.0108879, -0.904024], + ], + ), + ( + [ + [0.999808, 0.0177784, -0.00826505], + [-0.0177075, 0.999806, 0.00856939], + [0.0084158, -0.00842139, 0.999929], + ], + [ + [-0.999901, -0.0141114, 0.00072392], + [0.00130602, -0.0413336, 0.999145], + [-0.0140699, 0.999047, 0.0413473], + ], + ), + ( + [ + [0.985811, -0.161425, 0.0460375], + [0.154776, 0.980269, 0.12295], + [-0.0649764, -0.11408, 0.991344], + ], + [ + [-0.985689, 0.137931, -0.09692], + [-0.162627, -0.626622, 0.762168], + [0.0443951, 0.767022, 0.640085], + ], + ), + ( + [ + [0.956652, -0.0116044, 0.291001], + [0.05108, 0.990402, -0.128428], + [-0.286718, 0.137726, 0.948064], + ], + [ + [-0.956189, 0.00996594, -0.292585], + [-0.0397033, 0.985772, 0.16333], + [0.29005, 0.167791, -0.942189], + ], + ), + ( + [ + [0.783763, -0.0181248, -0.620796], + [-0.0386421, 0.996214, -0.0778717], + [0.619857, 0.0850218, 0.780095], + ], + [ + [-0.780275, 0.0093644, 0.625368], + [-0.0221791, 0.998845, -0.0426297], + [-0.625045, -0.0471329, -0.779165], + ], + ), + ( + [ + [0.890984, 0.0232464, -0.453439], + [-0.0221215, 0.999725, 0.00778511], + [0.453495, 0.00309433, 0.891253], + ], + [ + [-0.890178, 0.0290103, 0.45469], + [0.0337152, 0.999429, 0.0022403], + [-0.454366, 0.0173244, -0.890648], + ], + ), + ( + [ + [0.998177, -0.0119546, 0.0591504], + [0.00277494, 0.988238, 0.152901], + [-0.0602825, -0.152458, 0.98647], + ], + [ + [-0.997444, 0.00871865, -0.0709414], + [0.0197108, 0.987598, -0.155762], + [0.0687035, -0.156762, -0.985246], + ], + ), + ( + [ + [0.985214, 0.164929, 0.0463837], + [-0.166966, 0.984975, 0.0441225], + [-0.0384096, -0.0512146, 0.997949], + ], + [ + [-0.999472, -0.000819214, -0.0325087], + [-0.00344291, 0.99673, 0.0807324], + [0.0323362, 0.0808019, -0.996206], + ], + ), + ( + [ + [0.998499, 0.0465241, 0.0288955], + [-0.0454764, 0.99832, -0.0359142], + [-0.0305178, 0.0345463, 0.998937], + ], + [ + [-0.999441, 0.00412484, -0.0332105], + [0.00374685, 0.999928, 0.0114307], + [0.0332552, 0.0112999, -0.999384], + ], + ), + ( + [ + [0.10101, -0.943239, -0.316381], + [0.841334, -0.0887423, 0.533182], + [-0.530994, -0.320039, 0.784615], + ], + [ + [-0.725665, 0.0153749, -0.687878], + [-0.304813, 0.889109, 0.34143], + [0.616848, 0.457438, -0.640509], + ], + ), + ( + [ + [0.843428, 0.174952, 0.507958], + [0.0435637, 0.920106, -0.389239], + [-0.535473, 0.350423, 0.768422], + ], + [ + [-0.835464, 0.0040872, -0.549533], + [0.00227251, 0.999989, 0.00398241], + [0.549543, 0.00207845, -0.835464], + ], + ), + ( + [ + [0.999897, -0.0142888, -0.00160177], + [0.0141561, 0.997826, -0.064364], + [0.00251798, 0.0643346, 0.997925], + ], + [ + [-0.999956, 0.00898988, 0.00296485], + [0.00900757, 0.999941, 0.00601779], + [-0.00291058, 0.00604429, -0.999978], + ], + ), + ( + [ + [0.999477, -0.0204859, 0.0250096], + [0.0126204, 0.959462, 0.281557], + [-0.0297637, -0.281094, 0.959219], + ], + [ + [-0.999384, 0.0172602, -0.0305795], + [-0.0254425, 0.24428, 0.969371], + [0.0242012, 0.969551, -0.24369], + ], + ), + ( + [ + [0.984597, 0.173474, -0.0218106], + [-0.15145, 0.783891, -0.602145], + [-0.0873592, 0.596173, 0.798089], + ], + [ + [-0.998608, -0.0432858, 0.0301827], + [-0.00287128, 0.615692, 0.787983], + [-0.0526917, 0.786797, -0.61496], + ], + ), + ( + [ + [0.917099, -0.3072, 0.254083], + [0.303902, 0.951219, 0.0531566], + [-0.258018, 0.0284665, 0.965721], + ], + [ + [-0.931191, 0.347008, -0.111675], + [-0.352102, -0.77686, 0.522032], + [0.0943935, 0.52543, 0.845586], + ], + ), + ( + [ + [0.991706, 0.0721037, -0.106393], + [-0.0995017, 0.954693, -0.280464], + [0.0813505, 0.288725, 0.95395], + ], + [ + [-0.995306, 0.00106317, 0.0967833], + [0.0167662, 0.986717, 0.161583], + [-0.0953259, 0.162447, -0.982103], + ], + ), + ( + [ + [0.997078, 0.0478273, -0.0595641], + [-0.0348316, 0.978617, 0.202719], + [0.067986, -0.200052, 0.977424], + ], + [ + [-0.997925, -0.0439691, 0.0470461], + [0.0643829, -0.695474, 0.715663], + [0.00125305, 0.717207, 0.696861], + ], + ), + ( + [ + [0.972749, -0.0233882, -0.230677], + [0.0255773, 0.999652, 0.00650349], + [0.230445, -0.0122264, 0.973009], + ], + [ + [-0.973286, 0.0147558, 0.229126], + [0.0145644, 0.999891, -0.00252631], + [-0.229138, 0.000878362, -0.973394], + ], + ), + ( + [ + [0.999271, 0.00700481, 0.0375381], + [-0.0348202, 0.57069, 0.820427], + [-0.0156757, -0.821136, 0.570517], + ], + [ + [-0.999805, -0.0198049, 0.000539906], + [0.0179848, -0.89584, 0.444015], + [-0.00831113, 0.443938, 0.89602], + ], + ), + ( + [ + [0.975255, -0.0207895, 0.220104], + [0.0252764, 0.999526, -0.0175888], + [-0.219634, 0.022717, 0.975318], + ], + [ + [-0.975573, 0.0128154, -0.219304], + [0.0106554, 0.999882, 0.0110292], + [0.219419, 0.00842303, -0.975594], + ], + ), + ( + [ + [-0.433961, -0.658151, -0.615236], + [0.610442, -0.717039, 0.336476], + [-0.6626, -0.229548, 0.71293], + ], + [ + [-0.998516, -0.00675119, -0.054067], + [-0.00405539, 0.99875, -0.0498174], + [0.0543358, -0.0495237, -0.997296], + ], + ), + ( + [ + [0.942764, -0.0126807, -0.333221], + [-0.0017175, 0.999079, -0.042879], + [0.333458, 0.0409971, 0.941873], + ], + [ + [-0.942228, -0.0109444, 0.334798], + [0.0110573, 0.997905, 0.0637396], + [-0.334794, 0.0637589, -0.940133], + ], + ), + ( + [ + [0.962038, 0.0147987, 0.272515], + [-0.0185974, 0.999762, 0.0113615], + [-0.272283, -0.0159982, 0.962084], + ], + [ + [-0.959802, 0.0113708, -0.280451], + [0.00982126, 0.999928, 0.00692958], + [0.280509, 0.00389678, -0.959845], + ], + ), + ( + [ + [0.998414, 0.0139348, 0.0545528], + [-0.0226877, 0.986318, 0.163283], + [-0.0515311, -0.164262, 0.98507], + ], + [ + [-0.998641, -0.000695993, -0.0521343], + [0.0182534, 0.931965, -0.362087], + [0.0488394, -0.362547, -0.930686], + ], + ), + ( + [ + [0.999705, -0.0234518, -0.00633743], + [0.0235916, 0.999458, 0.0229643], + [0.00579544, -0.023107, 0.999716], + ], + [ + [-0.999901, 0.004436, 0.0133471], + [-0.00306106, 0.85758, -0.514342], + [-0.0137278, -0.514332, -0.857481], + ], + ), + ( + [ + [0.998617, -0.051704, 0.0094837], + [0.0484292, 0.975079, 0.216506], + [-0.0204416, -0.215748, 0.976235], + ], + [ + [-0.999959, -0.00295958, -0.00862907], + [-0.00313279, 0.999792, 0.0201361], + [0.00856768, 0.0201625, -0.999761], + ], + ), + ( + [ + [0.999121, 0.0345472, -0.023733], + [-0.0333175, 0.998174, 0.0503881], + [0.0254304, -0.0495531, 0.998448], + ], + [ + [-0.999272, -0.0337466, 0.0178065], + [0.0200629, -0.0677624, 0.9975], + [-0.0324556, 0.997131, 0.0683898], + ], + ), + ( + [ + [0.989017, 0.139841, -0.0478525], + [-0.131355, 0.683201, -0.718319], + [-0.0677572, 0.716715, 0.694067], + ], + [ + [-0.995236, 0.00457798, 0.097401], + [0.097488, 0.0258334, 0.994902], + [0.00203912, 0.999657, -0.0261574], + ], + ), + ( + [ + [0.961528, 0.249402, 0.11516], + [-0.204522, 0.9298, -0.306009], + [-0.183395, 0.270684, 0.945038], + ], + [ + [-0.999566, -0.0233216, 0.0180679], + [0.012372, 0.224583, 0.974377], + [-0.0267822, 0.974177, -0.224197], + ], + ), + ( + [ + [0.865581, 0.0252563, -0.500131], + [0.0302583, 0.994265, 0.102578], + [0.499853, -0.103923, 0.859853], + ], + [ + [-0.866693, 0.0042288, 0.498824], + [0.0129331, 0.999818, 0.0139949], + [-0.498674, 0.0185807, -0.866591], + ], + ), + ( + [ + [0.998999, -0.0213419, -0.0393009], + [-0.0007582, 0.870578, -0.492031], + [0.0447153, 0.491568, 0.86969], + ], + [ + [-0.999207, -0.0184688, 0.0353073], + [0.00153266, 0.867625, 0.497218], + [-0.0398164, 0.496876, -0.866908], + ], + ), + ( + [ + [0.96567, -0.00482973, 0.259728], + [0.00349956, 0.999978, 0.00558359], + [-0.259749, -0.00448297, 0.965666], + ], + [ + [-0.962691, -0.00113074, -0.270609], + [-5.93716e-05, 0.999992, -0.00396767], + [0.270612, -0.00380337, -0.962683], + ], + ), + ( + [ + [0.948799, 0.287027, -0.131894], + [-0.300257, 0.949181, -0.0943405], + [0.0981135, 0.129112, 0.986764], + ], + [ + [-0.993593, -0.0406684, 0.105449], + [-0.0506857, 0.994269, -0.0941326], + [-0.101017, -0.0988741, -0.98996], + ], + ), + ( + [ + [0.998935, 0.0451118, 0.0097202], + [-0.0418086, 0.973879, -0.223183], + [-0.0195345, 0.222539, 0.974728], + ], + [ + [-0.999821, 0.00821522, -0.0170658], + [0.00742187, 0.998912, 0.046048], + [0.0174255, 0.0459131, -0.998794], + ], + ), + ( + [ + [0.99577, 0.00458459, 0.0917705], + [-0.00244288, 0.999722, -0.0234365], + [-0.0918524, 0.0231131, 0.995504], + ], + [ + [-0.995956, 0.0137721, -0.0887945], + [0.0122932, 0.999777, 0.0171801], + [0.0890113, 0.0160191, -0.995903], + ], + ), + ( + [ + [0.97931, 0.0219899, 0.201169], + [-0.0159226, 0.99937, -0.0317288], + [-0.201739, 0.0278692, 0.979043], + ], + [ + [-0.980952, 0.00507266, -0.19419], + [0.00310821, 0.999941, 0.010419], + [0.194231, 0.00961706, -0.98091], + ], + ), + ( + [ + [0.999616, 0.00550326, -0.0271537], + [-0.0048286, 0.99968, 0.0248495], + [0.0272817, -0.0247088, 0.999322], + ], + [ + [-0.999689, -0.00054899, 0.0249588], + [-0.00125497, 0.999599, -0.0282774], + [-0.0249333, -0.0282998, -0.999289], + ], + ), + ( + [ + [0.998036, -0.00755259, -0.0621791], + [0.0417502, 0.820234, 0.570502], + [0.0466927, -0.571978, 0.818939], + ], + [ + [-0.999135, -0.0278203, 0.0309173], + [-0.00855238, 0.864892, 0.501886], + [-0.0407029, 0.501187, -0.864382], + ], + ), + ( + [ + [0.958227, 0.00271545, 0.285997], + [-0.00426128, 0.999979, 0.00478282], + [-0.285979, -0.00580174, 0.958218], + ], + [ + [-0.958726, 0.011053, -0.284121], + [0.0138068, 0.999875, -0.00769161], + [0.284001, -0.0112968, -0.958759], + ], + ), + ( + [ + [-0.804547, -0.48558, -0.341929], + [0.517913, -0.855425, -0.00382581], + [-0.290637, -0.180168, 0.939718], + ], + [ + [0.993776, -0.0469383, -0.101033], + [-0.110087, -0.274676, -0.955214], + [0.0170842, 0.96039, -0.278134], + ], + ), + ( + [ + [0.991875, -0.0022313, -0.127195], + [-0.00198041, 0.999454, -0.0329762], + [0.127199, 0.0329602, 0.991329], + ], + [ + [-0.992632, -0.0090772, 0.120844], + [-0.00870494, 0.999956, 0.00360636], + [-0.120871, 0.00252786, -0.992667], + ], + ), + ( + [ + [0.999305, -0.0252534, 0.0274367], + [0.026144, 0.999126, -0.0326002], + [-0.0265895, 0.0332948, 0.999092], + ], + [ + [-0.999314, -0.0038532, -0.0368519], + [-0.00441323, 0.999876, 0.0151263], + [0.036789, 0.0152787, -0.999207], + ], + ), + ( + [ + [0.999843, -0.00958823, 0.0148803], + [0.00982469, 0.999825, -0.0159002], + [-0.0147253, 0.0160439, 0.999763], + ], + [ + [-0.999973, 0.00673608, -0.00308692], + [-0.0067409, -0.999977, 0.00116827], + [-0.00307934, 0.00119013, 0.999995], + ], + ), + ( + [ + [0.981558, -0.00727741, 0.191028], + [-0.00866166, 0.996556, 0.0824708], + [-0.190971, -0.0826044, 0.978114], + ], + [ + [-0.980202, 0.0179519, -0.197188], + [0.00957606, 0.999014, 0.0433472], + [0.197772, 0.0406008, -0.979408], + ], + ), + ( + [ + [0.966044, 0.0143709, 0.257977], + [-0.0157938, 0.999869, 0.00344404], + [-0.257894, -0.00740153, 0.966145], + ], + [ + [-0.965532, 0.0100318, -0.260094], + [0.00950897, 0.999949, 0.00326797], + [0.260113, 0.000682242, -0.965579], + ], + ), + ( + [ + [0.999965, 0.00727991, -0.00412134], + [-0.00802642, 0.973769, -0.227397], + [0.00235781, 0.227422, 0.973794], + ], + [ + [-0.999877, 0.00698241, 0.0141441], + [0.0103867, 0.966295, 0.257228], + [-0.0118713, 0.257343, -0.966248], + ], + ), + ( + [ + [0.951385, -0.0297966, 0.306561], + [-0.0314555, 0.980706, 0.19294], + [-0.306395, -0.193204, 0.932092], + ], + [ + [-0.99981, 0.00389172, -0.0191159], + [-0.00386326, -0.999991, -0.00150593], + [-0.0191215, -0.00143146, 0.999816], + ], + ), + ( + [ + [0.986772, -0.120673, 0.10825], + [0.0543962, 0.875511, 0.480126], + [-0.152713, -0.467887, 0.870495], + ], + [ + [-0.991246, 0.125848, -0.0399414], + [-0.129021, -0.85897, 0.495507], + [0.0280503, 0.496321, 0.867686], + ], + ), + ( + [ + [-0.804799, -0.588418, 0.0778637], + [-0.514399, 0.756902, 0.403104], + [-0.296129, 0.284365, -0.911836], + ], + [ + [0.98676, -0.0939473, 0.132227], + [0.162179, 0.557277, -0.814336], + [0.0028177, 0.824995, 0.565135], + ], + ), + ( + [ + [0.878935, 0.115231, 0.462813], + [0.0845639, 0.917349, -0.388998], + [-0.469386, 0.381041, 0.796546], + ], + [ + [-0.869533, 0.00193279, -0.493873], + [-0.00419575, 0.999927, 0.0113007], + [0.493859, 0.0118986, -0.869462], + ], + ), + ( + [ + [0.951881, 0.20828, 0.224816], + [-0.305582, 0.700797, 0.644595], + [-0.023294, -0.682277, 0.730722], + ], + [ + [-0.999787, 0.0141074, -0.0151097], + [-0.000971554, 0.698061, 0.716038], + [0.0206489, 0.7159, -0.697898], + ], + ), + ( + [ + [0.999538, 0.0192173, 0.0235334], + [-0.0189064, 0.999732, -0.0133635], + [-0.0237839, 0.0129124, 0.999634], + ], + [ + [-0.999807, 0.00286378, -0.0194776], + [0.0026258, 0.999922, 0.0122308], + [0.0195111, 0.0121774, -0.999736], + ], + ), + ( + [ + [0.998468, 0.041362, -0.0367422], + [-0.0364453, 0.991404, 0.125658], + [0.0416238, -0.124127, 0.991393], + ], + [ + [-0.997665, -0.0658235, 0.0183602], + [0.0216855, -0.0501652, 0.998507], + [-0.064804, 0.99657, 0.0514739], + ], + ), + ( + [ + [0.995563, 0.0493669, 0.0801057], + [-0.0272233, 0.966027, -0.257002], + [-0.0900717, 0.253681, 0.963085], + ], + [ + [-0.999228, -0.034399, -0.0190572], + [0.0250208, -0.929986, 0.366743], + [-0.0303386, 0.365984, 0.930127], + ], + ), + ( + [ + [0.952898, 0.0122933, 0.303043], + [-0.00568444, 0.999727, -0.0226807], + [-0.303239, 0.0198898, 0.952707], + ], + [ + [-0.951155, 0.0127759, -0.308452], + [0.000612627, 0.999219, 0.0394978], + [0.308716, 0.0373795, -0.95042], + ], + ), + ( + [ + [0.923096, -0.000313887, 0.38457], + [0.00948258, 0.999714, -0.0219453], + [-0.384454, 0.0239044, 0.922835], + ], + [ + [-0.922662, -0.00403523, -0.385589], + [-0.0119834, 0.999762, 0.0182116], + [0.385424, 0.0214239, -0.922491], + ], + ), + ( + [ + [0.991575, 0.0945042, -0.0885834], + [-0.10112, 0.99216, -0.0734349], + [0.080949, 0.0817738, 0.993358], + ], + [ + [-0.990948, -0.127974, 0.0405639], + [0.096351, -0.467557, 0.878697], + [-0.0934839, 0.874651, 0.475655], + ], + ), + ( + [ + [0.997148, 0.010521, 0.0747407], + [-0.0079726, 0.999379, -0.034313], + [-0.0750553, 0.0336192, 0.996612], + ], + [ + [-0.996543, 0.00988805, -0.0825019], + [0.00939476, 0.999936, 0.0063645], + [0.0825595, 0.00556751, -0.996572], + ], + ), + ( + [ + [0.991261, 0.00474444, -0.131831], + [-0.00205841, 0.999788, 0.0205036], + [0.131901, -0.020053, 0.99106], + ], + [ + [-0.990924, 4.45275e-05, 0.134427], + [0.00614714, 0.998969, 0.0449827], + [-0.134286, 0.0454008, -0.989903], + ], + ), + ( + [ + [0.992266, -0.0947916, 0.0801474], + [0.100889, 0.992006, -0.0757987], + [-0.0723216, 0.0832984, 0.993897], + ], + [ + [-0.992701, 0.0817686, -0.0886652], + [-0.114283, -0.40263, 0.908203], + [0.0385633, 0.911704, 0.409035], + ], + ), + ( + [ + [0.99696, -0.00808565, -0.0774951], + [0.0585083, 0.734519, 0.676061], + [0.0514552, -0.67854, 0.732759], + ], + [ + [-0.9998, 0.0053398, -0.0193164], + [-0.0162677, -0.779206, 0.626556], + [-0.0117055, 0.626745, 0.779137], + ], + ), + ( + [ + [0.961501, 0.0133645, -0.274475], + [-0.016255, 0.999834, -0.00825889], + [0.274319, 0.0124025, 0.961559], + ], + [ + [-0.963687, 0.000179203, 0.267042], + [0.00670194, 0.999701, 0.023515], + [-0.266958, 0.0244509, -0.9634], + ], + ), + ( + [ + [0.99877, 0.0413462, -0.0273572], + [-0.0263673, 0.91029, 0.413131], + [0.0419844, -0.411902, 0.910261], + ], + [ + [-0.998035, -0.0613039, 0.0130407], + [-0.00146496, 0.230815, 0.972998], + [-0.0626594, 0.971065, -0.230452], + ], + ), + ( + [ + [0.999657, 0.0261608, 0.00141675], + [-0.0261957, 0.998937, 0.0379393], + [-0.000422719, -0.0379634, 0.999279], + ], + [ + [-0.998896, -0.0310033, -0.0353275], + [0.0315452, -0.999392, -0.0148857], + [-0.0348445, -0.0159846, 0.999265], + ], + ), + ( + [ + [0.77369, 0.0137861, 0.633415], + [-0.0186509, 0.999826, 0.00102049], + [-0.63329, -0.0126033, 0.773812], + ], + [ + [-0.773069, 0.0156632, -0.634129], + [0.00418312, 0.999799, 0.0195956], + [0.634308, 0.0124961, -0.772979], + ], + ), + ( + [ + [0.952827, -0.024521, -0.302522], + [-0.00541318, 0.9952, -0.0977158], + [0.303465, 0.0947439, 0.94812], + ], + [ + [-0.952266, -0.00806089, 0.305165], + [0.00351941, 0.999295, 0.037378], + [-0.305252, 0.0366678, -0.951567], + ], + ), + ( + [ + [-0.172189, 0.949971, 0.260587], + [-0.86961, -0.0223234, -0.493235], + [-0.462741, -0.311539, 0.829948], + ], + [ + [-0.672964, 0.0127645, -0.739567], + [0.00429523, 0.999902, 0.0133494], + [0.739664, 0.00580721, -0.672953], + ], + ), + ( + [ + [0.637899, -0.440017, 0.632036], + [-0.52883, 0.346333, 0.774849], + [-0.559842, -0.828516, -0.0117683], + ], + [ + [-0.0627307, -0.0314554, -0.997536], + [-0.733537, 0.679201, 0.0247117], + [0.67675, 0.733279, -0.0656804], + ], + ), + ( + [ + [0.998402, 0.00284932, -0.0564372], + [0.000393713, 0.998353, 0.0573683], + [0.0565077, -0.0572989, 0.996757], + ], + [ + [-0.997878, 0.000941416, 0.0651252], + [-2.16756e-05, 0.999891, -0.0147853], + [-0.065132, -0.0147552, -0.997768], + ], + ), + ( + [ + [0.9999, 0.0141438, -0.000431687], + [-0.0140882, 0.9979, 0.063225], + [0.00132502, -0.0632125, 0.997999], + ], + [ + [-0.999515, -0.0308197, -0.00482715], + [-0.00160551, -0.103741, 0.994605], + [-0.0311554, 0.994128, 0.10364], + ], + ), + ( + [ + [-0.201909, 0.0267804, 0.979038], + [-0.0159062, 0.999405, -0.0306179], + [-0.979275, -0.0217548, -0.201363], + ], + [ + [0.261235, 0.951613, -0.161839], + [0.0758567, 0.146901, 0.986239], + [0.962292, -0.269916, -0.03381], + ], + ), + ( + [ + [0.998335, -0.0191576, -0.0544038], + [0.0163271, 0.998513, -0.0520045], + [0.0553192, 0.0510297, 0.997164], + ], + [ + [-0.998811, -0.00846127, 0.0480344], + [-0.0051736, 0.997661, 0.0681593], + [-0.0484988, 0.0678295, -0.996519], + ], + ), + ( + [ + [0.999973, 0.00227282, -0.00699658], + [-0.00137504, 0.992062, 0.125744], + [0.00722684, -0.125731, 0.992038], + ], + [ + [-0.999995, -0.00337061, 4.25756e-05], + [-0.00333677, 0.991528, 0.129853], + [-0.00047993, 0.129852, -0.991534], + ], + ), + ( + [ + [0.998908, 0.0216581, -0.041392], + [-0.0327304, 0.956678, -0.289302], + [0.0333331, 0.290341, 0.956342], + ], + [ + [-0.998254, -0.0377592, 0.0454422], + [0.00744647, 0.682591, 0.730764], + [-0.0586112, 0.729825, -0.681118], + ], + ), + ( + [ + [0.999387, -0.0042571, -0.0347599], + [0.00485203, 0.999843, 0.017049], + [0.0346819, -0.0172072, 0.99925], + ], + [ + [-0.999976, 0.00260242, -0.00669664], + [-0.00250352, -0.999889, -0.0147361], + [-0.00673422, -0.0147175, 0.99987], + ], + ), + ( + [ + [0.906103, -0.398828, -0.141112], + [0.381512, 0.914475, -0.13485], + [0.182826, 0.0683519, 0.980766], + ], + [ + [-0.996568, -0.0321282, -0.0763021], + [-0.0823787, 0.476597, 0.875254], + [0.00824509, 0.878535, -0.477609], + ], + ), + ( + [ + [0.908356, 0.316033, -0.273884], + [-0.231421, -0.165634, -0.95865], + [-0.34833, 0.934178, -0.0773183], + ], + [ + [-0.999889, -0.0146322, -0.00295739], + [-0.0149238, 0.974974, 0.221815], + [-0.000362257, 0.221835, -0.975085], + ], + ), + ( + [ + [0.999507, -0.00834631, 0.0302637], + [0.00899248, 0.999733, -0.0212785], + [-0.030078, 0.0215401, 0.999315], + ], + [ + [-0.999538, 0.00785187, -0.0293621], + [0.00739788, 0.999852, 0.0155394], + [0.0294797, 0.0153149, -0.999448], + ], + ), + ( + [ + [0.999951, -0.00729441, -0.00672921], + [0.00313753, 0.87564, -0.482954], + [0.00941523, 0.48291, 0.87562], + ], + [ + [-0.999984, -0.005202, -0.00277372], + [0.00340465, -0.893745, 0.448565], + [-0.00481353, 0.448548, 0.893747], + ], + ), + ( + [ + [0.998028, -0.0569885, 0.0263322], + [0.0489091, 0.968801, 0.242967], + [-0.039357, -0.2412, 0.969677], + ], + [ + [-0.997066, 0.0422415, -0.0638525], + [-0.0760293, -0.448184, 0.890703], + [0.00900662, 0.892944, 0.45008], + ], + ), + ( + [ + [0.999745, 0.00860777, 0.0208747], + [-0.00827114, 0.999835, -0.0161595], + [-0.0210103, 0.0159827, 0.999651], + ], + [ + [-0.999576, 0.0148733, -0.0251161], + [0.0151027, 0.999846, -0.00898035], + [0.0249787, -0.00935575, -0.999646], + ], + ), + ( + [ + [0.91924, 0.0372116, -0.391934], + [-0.00675798, 0.996868, 0.0787959], + [0.393639, -0.0697837, 0.916613], + ], + [ + [-0.921919, 0.00882585, 0.387286], + [0.00588498, 0.999944, -0.00877866], + [-0.387342, -0.00581387, -0.921919], + ], + ), + ( + [ + [0.998324, -0.0029024, 0.0577924], + [0.00236766, 0.999954, 0.00931901], + [-0.0578167, -0.00916657, 0.998285], + ], + [ + [-0.99892, -0.0025688, -0.0464413], + [-0.00203721, 0.999932, -0.0114927], + [0.0464676, -0.0113855, -0.998857], + ], + ), + ( + [ + [0.993986, 0.0163462, -0.108279], + [-0.0612924, 0.902447, -0.426418], + [0.090746, 0.43049, 0.898022], + ], + [ + [-0.994519, -0.0767804, 0.0709843], + [0.0579273, 0.160607, 0.985318], + [-0.0870543, 0.984028, -0.15528], + ], + ), + ( + [ + [0.997351, 0.0715122, -0.0132892], + [-0.0707087, 0.996067, 0.0533919], + [0.0170551, -0.0523108, 0.998485], + ], + [ + [-0.997704, -0.066002, 0.015281], + [0.064101, -0.846657, 0.528267], + [-0.0219278, 0.528033, 0.848942], + ], + ), + ( + [ + [0.999839, 0.00714662, -0.0164633], + [-0.00859425, 0.99594, -0.0896085], + [0.0157561, 0.0897356, 0.995841], + ], + [ + [-0.999773, 0.0079918, 0.0197854], + [0.00864136, 0.999419, 0.0329623], + [-0.0195105, 0.0331255, -0.999262], + ], + ), + ( + [ + [-0.773738, 0.630074, 0.0658454], + [-0.622848, -0.737618, -0.260731], + [-0.115711, -0.242749, 0.963163], + ], + [ + [-0.740005, 0.000855199, -0.672604], + [-0.0106008, 0.99986, 0.0129348], + [0.672521, 0.0167018, -0.739892], + ], + ), + ( + [ + [0.969039, -0.00110643, -0.246907], + [-0.121454, 0.868509, -0.480564], + [0.214973, 0.495673, 0.841484], + ], + [ + [-0.981168, -0.150714, 0.120811], + [0.172426, -0.401504, 0.89948], + [-0.0870583, 0.903372, 0.419929], + ], + ), + ( + [ + [0.589015, 0.80692, 0.0440651], + [-0.806467, 0.583447, 0.0959135], + [0.0516848, -0.0920316, 0.994414], + ], + [ + [-0.99998, 0.00434293, -0.00486489], + [0.00437139, 0.999973, -0.00588975], + [0.00483918, -0.00591087, -0.999972], + ], + ), + ( + [ + [0.999972, 0.000781564, 0.00750023], + [-0.0031568, 0.946655, 0.322235], + [-0.00684828, -0.322249, 0.94663], + ], + [ + [-0.999817, -0.0178453, -0.00691725], + [-0.0189272, 0.975556, 0.218934], + [0.00284118, 0.219025, -0.975716], + ], + ), + ( + [ + [-0.969668, 0.219101, -0.108345], + [0.172364, 0.298654, -0.938667], + [-0.173305, -0.928871, -0.32736], + ], + [ + [-0.999917, 0.0111423, -0.00656864], + [-0.00977865, -0.318874, 0.947748], + [0.00846644, 0.947733, 0.318955], + ], + ), + ( + [ + [-0.808574, -0.185515, -0.558383], + [0.174641, -0.981898, 0.0733309], + [-0.561879, -0.038223, 0.826336], + ], + [ + [-0.873416, 0.0121808, -0.486824], + [-0.00495714, 0.999413, 0.0338998], + [0.486951, 0.032022, -0.872843], + ], + ), + ( + [ + [0.999295, 0.0295658, -0.0231234], + [-0.0251771, 0.984904, 0.17126], + [0.0278378, -0.170557, 0.984954], + ], + [ + [-0.998834, -0.040128, 0.026921], + [0.0327412, -0.152276, 0.987798], + [-0.0355388, 0.987524, 0.153411], + ], + ), + ( + [ + [0.996021, -0.0050677, -0.0889802], + [0.0042919, 0.999951, -0.00890794], + [0.089021, 0.0084906, 0.995994], + ], + [ + [-0.995726, -0.00858132, 0.0919686], + [-0.00615004, 0.999625, 0.0266854], + [-0.0921631, 0.0260058, -0.995405], + ], + ), + ( + [ + [0.563325, 0.812296, 0.151129], + [-0.316559, 0.381143, -0.868632], + [-0.763188, 0.441481, 0.471847], + ], + [ + [-0.980048, -0.0115108, -0.198437], + [-0.168991, 0.573853, 0.801335], + [0.104649, 0.818877, -0.564348], + ], + ), + ( + [ + [0.984844, -0.0288271, 0.17103], + [0.0260588, 0.999491, 0.0184094], + [-0.171474, -0.0136736, 0.985094], + ], + [ + [-0.984637, -0.00367691, -0.174577], + [-0.00649229, 0.999858, 0.0155587], + [0.174495, 0.0164532, -0.984521], + ], + ), + ( + [ + [0.99985, 0.000720773, -0.0172841], + [-0.00075051, 0.999998, -0.0017141], + [0.0172828, 0.00172682, 0.999849], + ], + [ + [-0.999926, -0.00413456, 0.0114842], + [-0.00368343, 0.999231, 0.0390359], + [-0.0116368, 0.0389908, -0.999172], + ], + ), + ( + [ + [0.997976, 0.0603523, -0.0200139], + [-0.0558618, 0.982551, 0.177404], + [0.0303714, -0.175927, 0.983935], + ], + [ + [-0.996867, -0.0790953, 0.00217996], + [0.0318842, -0.376338, 0.925935], + [-0.0724181, 0.923101, 0.37768], + ], + ), + ( + [ + [0.94678, -0.00538407, -0.321837], + [0.00249113, 0.999953, -0.0094], + [0.321872, 0.008098, 0.946749], + ], + [ + [-0.945694, 0.0255694, 0.324053], + [0.0240377, 0.999673, -0.00872898], + [-0.32417, -0.000465377, -0.945999], + ], + ), + ( + [ + [0.846059, 0.435245, -0.307807], + [0.318073, 0.0512036, 0.946682], + [0.4278, -0.898855, -0.0951187], + ], + [ + [-0.217213, -0.0389124, 0.975352], + [0.742195, 0.642416, 0.190918], + [-0.634011, 0.765368, -0.11066], + ], + ), + ( + [ + [0.914988, -0.0538229, -0.399875], + [-0.0459455, 0.970717, -0.23579], + [0.400857, 0.234117, 0.885722], + ], + [ + [-0.919706, 0.00194642, 0.392606], + [0.105539, 0.964406, 0.242451], + [-0.378159, 0.264418, -0.887176], + ], + ), + ( + [ + [0.970915, -0.183858, 0.153365], + [0.209801, 0.96196, -0.174974], + [-0.115361, 0.202061, 0.972555], + ], + [ + [-0.975509, 0.21077, -0.0629391], + [-0.218082, -0.964089, 0.151576], + [-0.0287314, 0.161588, 0.986441], + ], + ), + ( + [ + [0.99369, -0.00515149, -0.112044], + [0.00366664, 0.999903, -0.0134545], + [0.112102, 0.0129588, 0.993612], + ], + [ + [-0.99406, 0.00631892, 0.108668], + [0.00878985, 0.999713, 0.022273], + [-0.108496, 0.0230956, -0.99383], + ], + ), + ( + [ + [0.995917, 0.0137529, 0.089215], + [-0.0145079, 0.999864, 0.00781912], + [-0.0890954, -0.00908151, 0.995982], + ], + [ + [-0.996188, 0.012059, -0.0864113], + [0.0126654, 0.999899, -0.00647346], + [0.0863245, -0.00754306, -0.99624], + ], + ), + ( + [ + [0.84563, -0.0032436, -0.533759], + [0.0040093, 0.999992, 0.000275049], + [0.533754, -0.00237259, 0.845636], + ], + [ + [-0.849818, -0.00755214, 0.527023], + [-0.00734806, 0.99997, 0.00248074], + [-0.527026, -0.00176415, -0.849848], + ], + ), + ( + [ + [0.736067, -0.212675, -0.642631], + [-0.447028, 0.560168, -0.697408], + [0.508303, 0.800613, 0.31725], + ], + [ + [-0.684029, 0.0061039, 0.729431], + [0.0260275, 0.999532, 0.0160434], + [-0.728992, 0.0299595, -0.683868], + ], + ), + ( + [ + [0.993949, 0.00461705, -0.109742], + [-0.00653155, 0.999833, -0.0170925], + [0.109644, 0.0177058, 0.993813], + ], + [ + [-0.994446, 0.0218439, 0.102965], + [0.0227578, 0.999711, 0.00770966], + [-0.102767, 0.0100102, -0.994656], + ], + ), + ( + [ + [0.996005, -0.0103388, 0.0886959], + [-0.0291635, 0.901147, 0.432531], + [-0.0843999, -0.43339, 0.897246], + ], + [ + [-0.999947, 0.00833193, -0.00598923], + [-0.0101526, -0.887864, 0.459993], + [-0.00148526, 0.46003, 0.887902], + ], + ), + ( + [ + [0.981518, 0.0114609, 0.191025], + [-0.0104683, 0.999926, -0.00620422], + [-0.191082, 0.00408984, 0.981565], + ], + [ + [-0.979556, 0.000134379, -0.201176], + [-0.00817302, 0.999148, 0.0404628], + [0.20101, 0.0412799, -0.97872], + ], + ), + ( + [ + [0.997665, -0.0372296, -0.0572574], + [0.0379027, 0.999224, 0.0107148], + [0.0568141, -0.01286, 0.998302], + ], + [ + [-0.997794, 0.00389749, 0.0662921], + [0.00639122, 0.999278, 0.0374446], + [-0.0660983, 0.0377856, -0.997099], + ], + ), + ( + [ + [0.981618, -0.0105643, -0.190564], + [0.00329498, 0.999256, -0.0384229], + [0.190828, 0.0370887, 0.980923], + ], + [ + [-0.981673, -0.000810695, 0.190576], + [0.00398375, 0.999685, 0.0247729], + [-0.190536, 0.0250779, -0.981361], + ], + ), + ( + [ + [-0.544941, -0.812151, -0.208446], + [0.812337, -0.449791, -0.37121], + [0.207722, -0.371617, 0.90485], + ], + [ + [-0.121327, -0.000366672, -0.992614], + [-0.955208, 0.271977, 0.116655], + [0.269926, 0.962303, -0.0333484], + ], + ), + ( + [ + [0.637701, -0.219537, 0.738336], + [0.735715, 0.457522, -0.499397], + [-0.228168, 0.861671, 0.453279], + ], + [ + [-0.741797, 0.0196167, -0.670339], + [-0.00209087, 0.9995, 0.0315629], + [0.670623, 0.0248149, -0.741385], + ], + ), + ( + [ + [0.99813, -0.0590625, -0.0157485], + [0.0589086, 0.998213, -0.0100649], + [0.0163148, 0.00911833, 0.999825], + ], + [ + [-0.99893, 0.0258783, -0.0383385], + [-0.0440455, -0.279068, 0.959261], + [0.014125, 0.959924, 0.279908], + ], + ), + ( + [ + [0.999558, 0.0028395, -0.0296019], + [-0.00492321, 0.997496, -0.0705578], + [0.0293274, 0.0706723, 0.997068], + ], + [ + [-0.999532, -0.0305627, -0.00231546], + [0.00957406, -0.38309, 0.923664], + [-0.0291167, 0.923206, 0.383202], + ], + ), + ( + [ + [0.99814, -0.0528437, -0.0303853], + [0.0590889, 0.96123, 0.269341], + [0.0149743, -0.270636, 0.962565], + ], + [ + [-0.999464, 0.00781117, 0.0318024], + [-0.000588355, 0.966696, -0.255928], + [-0.0327423, -0.255809, -0.966173], + ], + ), + ( + [ + [-0.936685, 0.234194, 0.260336], + [-0.233325, -0.97178, 0.034698], + [0.261116, -0.0282419, 0.964894], + ], + [ + [0.999511, 0.00582072, 0.0307461], + [0.0289012, 0.204922, -0.978352], + [-0.0119956, 0.978762, 0.204654], + ], + ), + ( + [ + [0.973616, -0.019218, -0.227384], + [0.0030011, 0.99744, -0.0714512], + [0.228175, 0.0688836, 0.97118], + ], + [ + [-0.974738, 0.0190271, 0.222547], + [0.0222378, 0.999682, 0.0119297], + [-0.222249, 0.0165771, -0.97485], + ], + ), + ( + [ + [0.997273, 0.0453471, -0.0582173], + [-0.0234007, 0.942529, 0.333303], + [0.0699858, -0.331032, 0.941021], + ], + [ + [-0.996269, -0.0613496, 0.0607196], + [-0.0100285, 0.780948, 0.624516], + [-0.0857328, 0.621576, -0.77865], + ], + ), + ( + [ + [0.999511, 0.0274482, -0.0149865], + [-0.0305945, 0.957511, -0.286769], + [0.00647846, 0.287087, 0.957883], + ], + [ + [-0.999443, -0.0260559, 0.0209038], + [0.0148505, 0.213942, 0.976734], + [-0.0299225, 0.976499, -0.213437], + ], + ), + ( + [ + [0.621123, 0.722893, 0.302708], + [-0.48353, 0.657448, -0.577894], + [-0.61677, 0.212574, 0.757896], + ], + [ + [-0.996888, -0.0217614, -0.0757776], + [-0.0783897, 0.376159, 0.923234], + [0.00841386, 0.926299, -0.376694], + ], + ), + ( + [ + [0.974426, 0.0128384, -0.224341], + [-0.0123842, 0.999917, 0.00343166], + [0.224367, -0.00056561, 0.974505], + ], + [ + [-0.973234, -0.00506667, 0.229763], + [-0.000498848, 0.999801, 0.0199346], + [-0.229818, 0.0192865, -0.973043], + ], + ), + ( + [ + [0.994721, -0.0881097, 0.0526082], + [0.0972904, 0.972774, -0.210345], + [-0.0326424, 0.214353, 0.976211], + ], + [ + [-0.994309, 0.0920529, -0.0536268], + [-0.105538, -0.782431, 0.613729], + [0.0145358, 0.615896, 0.787694], + ], + ), + ( + [ + [0.998677, -0.0372894, 0.0354002], + [0.0242326, 0.948589, 0.315583], + [-0.0453481, -0.314308, 0.948237], + ], + [ + [-0.999066, -0.00910724, -0.0422707], + [-0.024629, 0.923353, 0.383161], + [0.0355411, 0.383844, -0.922715], + ], + ), + ( + [ + [0.931525, 0.00831028, 0.363583], + [0.0192806, 0.997204, -0.0721909], + [-0.363167, 0.0742577, 0.92876], + ], + [ + [-0.930052, -0.00174384, -0.367425], + [-0.0268673, 0.997634, 0.0632737], + [0.366445, 0.0687194, -0.927899], + ], + ), + ( + [ + [-0.50483, -0.819216, 0.272087], + [0.775688, -0.568816, -0.273414], + [0.378753, 0.0730272, 0.922612], + ], + [ + [-0.981596, 0.00031926, 0.190974], + [0.00652401, 0.999471, 0.0318616], + [-0.190863, 0.0325211, -0.981079], + ], + ), + ( + [ + [0.990518, -0.00195099, -0.137368], + [-0.00164696, 0.999659, -0.0260735], + [0.137372, 0.0260526, 0.990177], + ], + [ + [-0.991078, 0.00934835, 0.132961], + [0.0106057, 0.999905, 0.00875176], + [-0.132866, 0.0100839, -0.991083], + ], + ), + ( + [ + [0.935049, -0.353081, 0.0318997], + [0.257018, 0.737114, 0.624984], + [-0.244184, -0.576192, 0.779985], + ], + [ + [-0.977342, -0.00167896, -0.211667], + [-0.0448634, 0.978894, 0.199386], + [0.206864, 0.204364, -0.956789], + ], + ), + ( + [ + [0.998464, 0.0501172, 0.0236119], + [-0.0498618, 0.998692, -0.0112844], + [-0.0241466, 0.0100898, 0.999658], + ], + [ + [-0.999931, -0.0037971, -0.0112195], + [-0.00640916, 0.970027, 0.242913], + [0.00996085, 0.242968, -0.969984], + ], + ), + ( + [ + [0.999893, -0.0108217, 0.00984537], + [0.011201, 0.999164, -0.0393194], + [-0.00941164, 0.0394255, 0.999178], + ], + [ + [-0.999886, 0.00730461, -0.0133396], + [-0.0118202, -0.925163, 0.379391], + [-0.00956982, 0.379504, 0.925142], + ], + ), + ( + [ + [0.990922, -0.086745, 0.102709], + [0.0847349, 0.99612, 0.0237834], + [-0.104373, -0.0148644, 0.994427], + ], + [ + [-0.994922, -0.00197458, -0.10064], + [-0.00242513, 0.999988, 0.00435525], + [0.10063, 0.00457739, -0.994914], + ], + ), + ( + [ + [0.999856, 0.00210734, -0.0168511], + [-0.00557165, 0.978053, -0.20828], + [0.0160424, 0.208344, 0.977924], + ], + [ + [-0.999698, 0.0048691, 0.0241226], + [-0.00154306, 0.965899, -0.258915], + [-0.0245606, -0.258874, -0.9656], + ], + ), + ( + [ + [0.992858, -0.0249864, -0.116659], + [0.0419872, 0.988447, 0.145634], + [0.111673, -0.149492, 0.982436], + ], + [ + [-0.992324, 0.0357741, 0.118384], + [-0.0419528, 0.803113, -0.594348], + [-0.116338, -0.594752, -0.795447], + ], + ), + ( + [ + [0.986821, -0.00531913, 0.161729], + [0.00797365, 0.999844, -0.0157688], + [-0.16162, 0.0168505, 0.986709], + ], + [ + [-0.985867, 0.0119402, -0.167109], + [0.0141227, 0.99983, -0.0118784], + [0.166939, -0.0140704, -0.985868], + ], + ), + ( + [ + [0.999693, -0.0158939, -0.0190113], + [0.0103599, 0.96501, -0.262007], + [0.0225104, 0.261729, 0.964879], + ], + [ + [-0.999344, -0.0314781, -0.0180051], + [-0.0250895, 0.241673, 0.970034], + [-0.0261833, 0.969847, -0.242305], + ], + ), + ( + [ + [0.977445, 0.0293661, 0.209138], + [-0.0723687, 0.976903, 0.201057], + [-0.198403, -0.211657, 0.956994], + ], + [ + [-0.976437, 0.00895131, -0.215624], + [0.0552894, 0.976169, -0.20985], + [0.208607, -0.216827, -0.953663], + ], + ), + ( + [ + [0.994593, 0.0974797, -0.0358119], + [-0.0822288, 0.949838, 0.301737], + [0.0634288, -0.297161, 0.952718], + ], + [ + [-0.994192, -0.10746, -0.00604622], + [0.078812, -0.7651, 0.639071], + [-0.0733003, 0.634882, 0.769124], + ], + ), + ( + [ + [0.365674, 0.282077, -0.88697], + [-0.609177, 0.793033, 0.00105565], + [0.703694, 0.539936, 0.461826], + ], + [ + [-0.469534, 0.0109062, 0.882848], + [0.0060785, 0.99994, -0.00911984], + [-0.882894, 0.00108445, -0.469572], + ], + ), + ( + [ + [0.999956, 0.00903085, 0.0025358], + [-0.00862738, 0.991574, -0.129252], + [-0.00368169, 0.129224, 0.991609], + ], + [ + [-0.999976, 0.00322491, -0.00637541], + [0.00379751, 0.995755, -0.0919687], + [0.00605176, -0.0919907, -0.995743], + ], + ), + ( + [ + [0.999982, -0.00398882, -0.00441072], + [0.00411881, 0.999545, 0.0298655], + [0.00428959, -0.0298832, 0.999544], + ], + [ + [-0.999931, -0.00315547, -0.0114491], + [0.00300966, -0.999914, 0.0128304], + [-0.0114875, 0.012796, 0.999853], + ], + ), + ( + [ + [0.996613, 0.0781452, -0.0256245], + [-0.0610516, 0.91178, 0.406116], + [0.0550999, -0.403175, 0.913462], + ], + [ + [-0.996368, -0.084671, 0.00909851], + [0.0540149, -0.545774, 0.83619], + [-0.0658352, 0.833644, 0.548365], + ], + ), + ( + [ + [0.961059, 0.139318, 0.238654], + [-0.117488, 0.987672, -0.103448], + [-0.250124, 0.0713812, 0.965579], + ], + [ + [-0.973397, 0.00782581, -0.228998], + [-0.0621109, 0.952986, 0.296581], + [0.220553, 0.302913, -0.927147], + ], + ), + ( + [ + [0.156415, -0.982138, 0.104589], + [-0.568896, -0.176149, -0.803323], + [0.807398, 0.0661518, -0.586287], + ], + [ + [-0.992155, 0.0934304, -0.0830664], + [-0.121171, -0.555137, 0.822887], + [0.0307694, 0.826496, 0.562102], + ], + ), + ( + [ + [0.997973, 0.0130328, -0.0622976], + [-0.011111, 0.999455, 0.0310968], + [0.0626689, -0.0303416, 0.997573], + ], + [ + [-0.997391, -0.00094697, 0.0722014], + [-0.00271076, 0.9997, -0.024334], + [-0.0721567, -0.024466, -0.997094], + ], + ), + ( + [ + [0.913504, -0.0125928, -0.406634], + [-0.108363, 0.95588, -0.27304], + [0.392132, 0.293487, 0.871836], + ], + [ + [-0.909813, 0.0115348, 0.414861], + [0.128636, 0.958223, 0.255464], + [-0.394582, 0.28579, -0.873287], + ], + ), + ( + [ + [0.932595, -0.0693644, 0.354197], + [0.0984415, 0.993036, -0.0647231], + [-0.347241, 0.0952281, 0.932928], + ], + [ + [-0.930498, 0.00578599, -0.366252], + [-0.106202, 0.952666, 0.284867], + [0.350564, 0.303964, -0.885839], + ], + ), + ( + [ + [0.995668, -0.00475737, 0.0928567], + [0.00890154, 0.99898, -0.0442667], + [-0.0925514, 0.0449015, 0.994695], + ], + [ + [-0.996077, -0.0107986, -0.0878355], + [0.00749423, 0.978669, -0.205305], + [0.0881789, -0.205158, -0.974749], + ], + ), + ( + [ + [0.99948, 0.0321999, 0.00146151], + [-0.0321302, 0.998886, -0.0345513], + [-0.00257243, 0.0344864, 0.999402], + ], + [ + [-0.999953, 0.00726142, -0.0065326], + [0.00488529, 0.950962, 0.30927], + [0.00845801, 0.309223, -0.950953], + ], + ), +] + + +class TestRot3(GtsamTestCase): + """Test selected Rot3 methods.""" + + def test_axisangle(self) -> None: + """Test .axisAngle() method.""" + # fmt: off + R = np.array( + [ + [ -0.999957, 0.00922903, 0.00203116], + [ 0.00926964, 0.999739, 0.0208927], + [ -0.0018374, 0.0209105, -0.999781] + ]) + # fmt: on + + # get back angle in radians + _, actual_angle = Rot3(R).axisAngle() + expected_angle = 3.1396582 + np.testing.assert_almost_equal(actual_angle, expected_angle, 1e-7) + + def test_axis_angle_stress_test(self) -> None: + """Test that .axisAngle() yields angles less than 180 degrees for specific inputs.""" + for (R1, R2) in R1_R2_pairs: + R1 = Rot3(np.array(R1)) + R2 = Rot3(np.array(R2)) + + i1Ri2 = R1.between(R2) + + axis, angle = i1Ri2.axisAngle() + angle_deg = np.rad2deg(angle) + assert angle_deg < 180 + + +if __name__ == "__main__": + unittest.main() diff --git a/python/gtsam/tests/test_SfmData.py b/python/gtsam/tests/test_SfmData.py index 9c965ddc5..18c9ef722 100644 --- a/python/gtsam/tests/test_SfmData.py +++ b/python/gtsam/tests/test_SfmData.py @@ -14,9 +14,9 @@ from __future__ import print_function import unittest -import numpy as np - import gtsam +import numpy as np +from gtsam import Point2, Point3, SfmData, SfmTrack from gtsam.utils.test_case import GtsamTestCase @@ -25,55 +25,97 @@ class TestSfmData(GtsamTestCase): def setUp(self): """Initialize SfmData and SfmTrack""" - self.data = gtsam.SfmData() + self.data = SfmData() # initialize SfmTrack with 3D point - self.tracks = gtsam.SfmTrack() + self.tracks = SfmTrack() def test_tracks(self): """Test functions in SfmTrack""" # measurement is of format (camera_idx, imgPoint) # create arbitrary camera indices for two cameras - i1, i2 = 4,5 + i1, i2 = 4, 5 + # create arbitrary image measurements for cameras i1 and i2 - uv_i1 = gtsam.Point2(12.6, 82) + uv_i1 = Point2(12.6, 82) + # translating point uv_i1 along X-axis - uv_i2 = gtsam.Point2(24.88, 82) + uv_i2 = Point2(24.88, 82) + # add measurements to the track - self.tracks.add_measurement(i1, uv_i1) - self.tracks.add_measurement(i2, uv_i2) + self.tracks.addMeasurement(i1, uv_i1) + self.tracks.addMeasurement(i2, uv_i2) + # Number of measurements in the track is 2 - self.assertEqual(self.tracks.number_measurements(), 2) + self.assertEqual(self.tracks.numberMeasurements(), 2) + # camera_idx in the first measurement of the track corresponds to i1 cam_idx, img_measurement = self.tracks.measurement(0) self.assertEqual(cam_idx, i1) np.testing.assert_array_almost_equal( - gtsam.Point3(0.,0.,0.), + Point3(0., 0., 0.), self.tracks.point3() ) - def test_data(self): """Test functions in SfmData""" # Create new track with 3 measurements - i1, i2, i3 = 3,5,6 - uv_i1 = gtsam.Point2(21.23, 45.64) + i1, i2, i3 = 3, 5, 6 + uv_i1 = Point2(21.23, 45.64) + # translating along X-axis - uv_i2 = gtsam.Point2(45.7, 45.64) - uv_i3 = gtsam.Point2(68.35, 45.64) + uv_i2 = Point2(45.7, 45.64) + uv_i3 = Point2(68.35, 45.64) + # add measurements and arbitrary point to the track measurements = [(i1, uv_i1), (i2, uv_i2), (i3, uv_i3)] - pt = gtsam.Point3(1.0, 6.0, 2.0) - track2 = gtsam.SfmTrack(pt) - track2.add_measurement(i1, uv_i1) - track2.add_measurement(i2, uv_i2) - track2.add_measurement(i3, uv_i3) - self.data.add_track(self.tracks) - self.data.add_track(track2) + pt = Point3(1.0, 6.0, 2.0) + track2 = SfmTrack(pt) + track2.addMeasurement(i1, uv_i1) + track2.addMeasurement(i2, uv_i2) + track2.addMeasurement(i3, uv_i3) + self.data.addTrack(self.tracks) + self.data.addTrack(track2) + # Number of tracks in SfmData is 2 - self.assertEqual(self.data.number_tracks(), 2) + self.assertEqual(self.data.numberTracks(), 2) + # camera idx of first measurement of second track corresponds to i1 cam_idx, img_measurement = self.data.track(1).measurement(0) self.assertEqual(cam_idx, i1) + def test_Balbianello(self): + """ Check that we can successfully read a bundler file and create a + factor graph from it + """ + # The structure where we will save the SfM data + filename = gtsam.findExampleDataFile("Balbianello.out") + sfm_data = SfmData.FromBundlerFile(filename) + + # Check number of things + self.assertEqual(5, sfm_data.numberCameras()) + self.assertEqual(544, sfm_data.numberTracks()) + track0 = sfm_data.track(0) + self.assertEqual(3, track0.numberMeasurements()) + + # Check projection of a given point + self.assertEqual(0, track0.measurement(0)[0]) + camera0 = sfm_data.camera(0) + expected = camera0.project(track0.point3()) + actual = track0.measurement(0)[1] + self.gtsamAssertEquals(expected, actual, 1) + + # We share *one* noiseModel between all projection factors + model = gtsam.noiseModel.Isotropic.Sigma( + 2, 1.0) # one pixel in u and v + + # Convert to NonlinearFactorGraph + graph = sfm_data.sfmFactorGraph(model) + self.assertEqual(1419, graph.size()) # regression + + # Get initial estimate + values = gtsam.initialCamerasAndPointsEstimate(sfm_data) + self.assertEqual(549, values.size()) # regression + + if __name__ == '__main__': unittest.main() diff --git a/python/gtsam/tests/test_Sim2.py b/python/gtsam/tests/test_Sim2.py new file mode 100644 index 000000000..ea809b965 --- /dev/null +++ b/python/gtsam/tests/test_Sim2.py @@ -0,0 +1,194 @@ +""" +GTSAM Copyright 2010-2019, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Sim3 unit tests. +Author: John Lambert +""" +# pylint: disable=no-name-in-module +import unittest + +import numpy as np +from gtsam import Pose2, Pose2Pairs, Rot2, Similarity2 +from gtsam.utils.test_case import GtsamTestCase + + +class TestSim2(GtsamTestCase): + """Test selected Sim2 methods.""" + + def test_align_poses_along_straight_line(self) -> None: + """Test Align Pose2Pairs method. + + Scenario: + 3 object poses + same scale (no gauge ambiguity) + world frame has poses rotated about 180 degrees. + world and egovehicle frame translated by 15 meters w.r.t. each other + """ + R180 = Rot2.fromDegrees(180) + + # Create source poses (three objects o1, o2, o3 living in the egovehicle "e" frame) + # Suppose they are 3d cuboids detected by an onboard sensor in the egovehicle frame + eTo0 = Pose2(Rot2(), np.array([5, 0])) + eTo1 = Pose2(Rot2(), np.array([10, 0])) + eTo2 = Pose2(Rot2(), np.array([15, 0])) + + eToi_list = [eTo0, eTo1, eTo2] + + # Create destination poses + # (same three objects, but instead living in the world "w" frame) + wTo0 = Pose2(R180, np.array([-10, 0])) + wTo1 = Pose2(R180, np.array([-5, 0])) + wTo2 = Pose2(R180, np.array([0, 0])) + + wToi_list = [wTo0, wTo1, wTo2] + + we_pairs = Pose2Pairs(list(zip(wToi_list, eToi_list))) + + # Recover the transformation wSe (i.e. world_S_egovehicle) + wSe = Similarity2.Align(we_pairs) + + for wToi, eToi in zip(wToi_list, eToi_list): + self.gtsamAssertEquals(wToi, wSe.transformFrom(eToi)) + + def test_align_poses_along_straight_line_gauge(self): + """Test if Align Pose3Pairs method can account for gauge ambiguity. + + Scenario: + 3 object poses + with gauge ambiguity (2x scale) + world frame has poses rotated by 90 degrees. + world and egovehicle frame translated by 11 meters w.r.t. each other + """ + R90 = Rot2.fromDegrees(90) + + # Create source poses (three objects o1, o2, o3 living in the egovehicle "e" frame) + # Suppose they are 3d cuboids detected by an onboard sensor in the egovehicle frame + eTo0 = Pose2(Rot2(), np.array([1, 0])) + eTo1 = Pose2(Rot2(), np.array([2, 0])) + eTo2 = Pose2(Rot2(), np.array([4, 0])) + + eToi_list = [eTo0, eTo1, eTo2] + + # Create destination poses + # (same three objects, but instead living in the world/city "w" frame) + wTo0 = Pose2(R90, np.array([0, 12])) + wTo1 = Pose2(R90, np.array([0, 14])) + wTo2 = Pose2(R90, np.array([0, 18])) + + wToi_list = [wTo0, wTo1, wTo2] + + we_pairs = Pose2Pairs(list(zip(wToi_list, eToi_list))) + + # Recover the transformation wSe (i.e. world_S_egovehicle) + wSe = Similarity2.Align(we_pairs) + + for wToi, eToi in zip(wToi_list, eToi_list): + self.gtsamAssertEquals(wToi, wSe.transformFrom(eToi)) + + def test_align_poses_scaled_squares(self): + """Test if Align Pose2Pairs method can account for gauge ambiguity. + + Make sure a big and small square can be aligned. + The u's represent a big square (10x10), and v's represents a small square (4x4). + + Scenario: + 4 object poses + with gauge ambiguity (2.5x scale) + """ + # 0, 90, 180, and 270 degrees yaw + R0 = Rot2.fromDegrees(0) + R90 = Rot2.fromDegrees(90) + R180 = Rot2.fromDegrees(180) + R270 = Rot2.fromDegrees(270) + + aTi0 = Pose2(R0, np.array([2, 3])) + aTi1 = Pose2(R90, np.array([12, 3])) + aTi2 = Pose2(R180, np.array([12, 13])) + aTi3 = Pose2(R270, np.array([2, 13])) + + aTi_list = [aTi0, aTi1, aTi2, aTi3] + + bTi0 = Pose2(R0, np.array([4, 3])) + bTi1 = Pose2(R90, np.array([8, 3])) + bTi2 = Pose2(R180, np.array([8, 7])) + bTi3 = Pose2(R270, np.array([4, 7])) + + bTi_list = [bTi0, bTi1, bTi2, bTi3] + + ab_pairs = Pose2Pairs(list(zip(aTi_list, bTi_list))) + + # Recover the transformation wSe (i.e. world_S_egovehicle) + aSb = Similarity2.Align(ab_pairs) + + for aTi, bTi in zip(aTi_list, bTi_list): + self.gtsamAssertEquals(aTi, aSb.transformFrom(bTi)) + + def test_constructor(self) -> None: + """Sim(2) to perform p_b = bSa * p_a""" + bRa = Rot2() + bta = np.array([1, 2]) + bsa = 3.0 + bSa = Similarity2(R=bRa, t=bta, s=bsa) + self.assertIsInstance(bSa, Similarity2) + np.testing.assert_allclose(bSa.rotation().matrix(), bRa.matrix()) + np.testing.assert_allclose(bSa.translation(), bta) + np.testing.assert_allclose(bSa.scale(), bsa) + + def test_is_eq(self) -> None: + """Ensure object equality works properly (are equal).""" + bSa = Similarity2(R=Rot2(), t=np.array([1, 2]), s=3.0) + bSa_ = Similarity2(R=Rot2(), t=np.array([1.0, 2.0]), s=3) + self.gtsamAssertEquals(bSa, bSa_) + + def test_not_eq_translation(self) -> None: + """Ensure object equality works properly (not equal translation).""" + bSa = Similarity2(R=Rot2(), t=np.array([2, 1]), s=3.0) + bSa_ = Similarity2(R=Rot2(), t=np.array([1.0, 2.0]), s=3) + self.assertNotEqual(bSa, bSa_) + + def test_not_eq_rotation(self) -> None: + """Ensure object equality works properly (not equal rotation).""" + bSa = Similarity2(R=Rot2(), t=np.array([2, 1]), s=3.0) + bSa_ = Similarity2(R=Rot2.fromDegrees(180), t=np.array([2.0, 1.0]), s=3) + self.assertNotEqual(bSa, bSa_) + + def test_not_eq_scale(self) -> None: + """Ensure object equality works properly (not equal scale).""" + bSa = Similarity2(R=Rot2(), t=np.array([2, 1]), s=3.0) + bSa_ = Similarity2(R=Rot2(), t=np.array([2.0, 1.0]), s=1.0) + self.assertNotEqual(bSa, bSa_) + + def test_rotation(self) -> None: + """Ensure rotation component is returned properly.""" + R = Rot2.fromDegrees(90) + t = np.array([1, 2]) + bSa = Similarity2(R=R, t=t, s=3.0) + + # evaluates to [[0, -1], [1, 0]] + expected_R = Rot2.fromDegrees(90) + np.testing.assert_allclose(expected_R.matrix(), bSa.rotation().matrix()) + + def test_translation(self) -> None: + """Ensure translation component is returned properly.""" + R = Rot2.fromDegrees(90) + t = np.array([1, 2]) + bSa = Similarity2(R=R, t=t, s=3.0) + + expected_t = np.array([1, 2]) + np.testing.assert_allclose(expected_t, bSa.translation()) + + def test_scale(self) -> None: + """Ensure the scale factor is returned properly.""" + bRa = Rot2() + bta = np.array([1, 2]) + bsa = 3.0 + bSa = Similarity2(R=bRa, t=bta, s=bsa) + self.assertEqual(bSa.scale(), 3.0) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/gtsam/tests/test_Sim3.py b/python/gtsam/tests/test_Sim3.py index 001321e2c..c00a36435 100644 --- a/python/gtsam/tests/test_Sim3.py +++ b/python/gtsam/tests/test_Sim3.py @@ -10,6 +10,7 @@ Author: John Lambert """ # pylint: disable=no-name-in-module import unittest +from typing import List, Optional import numpy as np @@ -129,6 +130,587 @@ class TestSim3(GtsamTestCase): for aTi, bTi in zip(aTi_list, bTi_list): self.gtsamAssertEquals(aTi, aSb.transformFrom(bTi)) + def test_align_via_Sim3_to_poses_skydio32(self) -> None: + """Ensure scale estimate of Sim(3) object is non-negative. + + Comes from real data (from Skydio-32 Crane Mast sequence with a SIFT front-end). + """ + poses_gt = [ + Pose3( + Rot3( + [ + [0.696305769, -0.0106830792, -0.717665705], + [0.00546412488, 0.999939148, -0.00958346857], + [0.717724415, 0.00275160848, 0.696321772], + ] + ), + Point3(5.83077801, -0.94815149, 0.397751679), + ), + Pose3( + Rot3( + [ + [0.692272397, -0.00529704529, -0.721616549], + [0.00634689669, 0.999979075, -0.00125157022], + [0.721608079, -0.0037136016, 0.692291531], + ] + ), + Point3(5.03853323, -0.97547405, -0.348177392), + ), + Pose3( + Rot3( + [ + [0.945991981, -0.00633548292, -0.324128225], + [0.00450436485, 0.999969379, -0.00639931046], + [0.324158843, 0.00459370582, 0.945991552], + ] + ), + Point3(4.13186176, -0.956364218, -0.796029527), + ), + Pose3( + Rot3( + [ + [0.999553623, -0.00346470207, -0.0296740626], + [0.00346104216, 0.999993995, -0.00017469881], + [0.0296744897, 7.19175654e-05, 0.999559612], + ] + ), + Point3(3.1113898, -0.928583423, -0.90539337), + ), + Pose3( + Rot3( + [ + [0.967850252, -0.00144846042, 0.251522892], + [0.000254511591, 0.999988546, 0.00477934325], + [-0.251526934, -0.00456167299, 0.967839535], + ] + ), + Point3(2.10584013, -0.921303194, -0.809322971), + ), + Pose3( + Rot3( + [ + [0.969854065, 0.000629052774, 0.243685716], + [0.000387180179, 0.999991428, -0.00412234326], + [-0.243686221, 0.00409242166, 0.969845508], + ] + ), + Point3(1.0753788, -0.913035975, -0.616584091), + ), + Pose3( + Rot3( + [ + [0.998189342, 0.00110235337, 0.0601400045], + [-0.00110890447, 0.999999382, 7.55559042e-05], + [-0.060139884, -0.000142108649, 0.998189948], + ] + ), + Point3(0.029993558, -0.951495122, -0.425525143), + ), + Pose3( + Rot3( + [ + [0.999999996, -2.62868666e-05, -8.67178281e-05], + [2.62791334e-05, 0.999999996, -8.91767396e-05], + [8.67201719e-05, 8.91744604e-05, 0.999999992], + ] + ), + Point3(-0.973569417, -0.936340994, -0.253464928), + ), + Pose3( + Rot3( + [ + [0.99481227, -0.00153645011, 0.101716252], + [0.000916919443, 0.999980747, 0.00613725239], + [-0.101723724, -0.00601214847, 0.994794525], + ] + ), + Point3(-2.02071256, -0.955446292, -0.240707879), + ), + Pose3( + Rot3( + [ + [0.89795602, -0.00978591184, 0.43997636], + [0.00645921401, 0.999938116, 0.00905779513], + [-0.440037771, -0.00529159974, 0.89796366], + ] + ), + Point3(-2.94096695, -0.939974858, 0.0934225593), + ), + Pose3( + Rot3( + [ + [0.726299119, -0.00916784876, 0.687318077], + [0.00892018672, 0.999952563, 0.0039118575], + [-0.687321336, 0.00328981905, 0.726346444], + ] + ), + Point3(-3.72843416, -0.897889251, 0.685129502), + ), + Pose3( + Rot3( + [ + [0.506756029, -0.000331706105, 0.862089858], + [0.00613841257, 0.999975964, -0.00322354286], + [-0.862068067, 0.00692541035, 0.506745885], + ] + ), + Point3(-4.3909926, -0.890883291, 1.43029524), + ), + Pose3( + Rot3( + [ + [0.129316352, -0.00206958814, 0.991601896], + [0.00515932597, 0.999985691, 0.00141424797], + [-0.991590634, 0.00493310721, 0.129325179], + ] + ), + Point3(-4.58510846, -0.922534227, 2.36884523), + ), + Pose3( + Rot3( + [ + [0.599853194, -0.00890004681, -0.800060263], + [0.00313716318, 0.999956608, -0.00877161373], + [0.800103615, 0.00275175707, 0.599855085], + ] + ), + Point3(5.71559638, 0.486863076, 0.279141372), + ), + Pose3( + Rot3( + [ + [0.762552447, 0.000836438681, -0.646926069], + [0.00211337894, 0.999990607, 0.00378404105], + [0.646923157, -0.00425272942, 0.762543517], + ] + ), + Point3(5.00243443, 0.513321893, -0.466921769), + ), + Pose3( + Rot3( + [ + [0.930381645, -0.00340164355, -0.36657678], + [0.00425636616, 0.999989781, 0.00152338305], + [0.366567852, -0.00297761145, 0.930386617], + ] + ), + Point3(4.05404984, 0.493385291, -0.827904571), + ), + Pose3( + Rot3( + [ + [0.999996073, -0.00278379707, -0.000323508543], + [0.00278790921, 0.999905063, 0.0134941517], + [0.000285912831, -0.0134950006, 0.999908897], + ] + ), + Point3(3.04724478, 0.491451306, -0.989571061), + ), + Pose3( + Rot3( + [ + [0.968578343, -0.002544616, 0.248695527], + [0.000806130148, 0.999974526, 0.00709200332], + [-0.248707238, -0.0066686795, 0.968555721], + ] + ), + Point3(2.05737869, 0.46840177, -0.546344594), + ), + Pose3( + Rot3( + [ + [0.968827882, 0.000182770584, 0.247734722], + [-0.000558107079, 0.9999988, 0.00144484904], + [-0.24773416, -0.00153807255, 0.968826821], + ] + ), + Point3(1.14019947, 0.469674641, -0.0491053805), + ), + Pose3( + Rot3( + [ + [0.991647805, 0.00197867892, 0.128960146], + [-0.00247518407, 0.999990129, 0.00368991165], + [-0.128951572, -0.00397829284, 0.991642914], + ] + ), + Point3(0.150270471, 0.457867448, 0.103628642), + ), + Pose3( + Rot3( + [ + [0.992244594, 0.00477781876, -0.124208847], + [-0.0037682125, 0.999957938, 0.00836195891], + [0.124243574, -0.00782906317, 0.992220862], + ] + ), + Point3(-0.937954641, 0.440532658, 0.154265069), + ), + Pose3( + Rot3( + [ + [0.999591078, 0.00215462857, -0.0285137564], + [-0.00183807224, 0.999936443, 0.0111234301], + [0.028535911, -0.0110664711, 0.999531507], + ] + ), + Point3(-1.95622231, 0.448914367, -0.0859439782), + ), + Pose3( + Rot3( + [ + [0.931835342, 0.000956922238, 0.362880212], + [0.000941640753, 0.99998678, -0.00505501434], + [-0.362880252, 0.00505214382, 0.931822122], + ] + ), + Point3(-2.85557418, 0.434739285, 0.0793777177), + ), + Pose3( + Rot3( + [ + [0.781615218, -0.0109886966, 0.623664238], + [0.00516954657, 0.999924591, 0.011139446], + [-0.623739616, -0.00548270158, 0.781613084], + ] + ), + Point3(-3.67524552, 0.444074681, 0.583718622), + ), + Pose3( + Rot3( + [ + [0.521291761, 0.00264805046, 0.853374051], + [0.00659087718, 0.999952868, -0.00712898365], + [-0.853352707, 0.00934076542, 0.521249738], + ] + ), + Point3(-4.35541796, 0.413479707, 1.31179007), + ), + Pose3( + Rot3( + [ + [0.320164205, -0.00890839482, 0.947319884], + [0.00458409304, 0.999958649, 0.007854118], + [-0.947350678, 0.00182799903, 0.320191803], + ] + ), + Point3(-4.71617526, 0.476674479, 2.16502998), + ), + Pose3( + Rot3( + [ + [0.464861609, 0.0268597443, -0.884976415], + [-0.00947397841, 0.999633409, 0.0253631906], + [0.885333239, -0.00340614699, 0.464945663], + ] + ), + Point3(6.11772094, 1.63029238, 0.491786626), + ), + Pose3( + Rot3( + [ + [0.691647251, 0.0216006293, -0.721912024], + [-0.0093228132, 0.999736395, 0.020981541], + [0.722174939, -0.00778156302, 0.691666308], + ] + ), + Point3(5.46912979, 1.68759322, -0.288499782), + ), + Pose3( + Rot3( + [ + [0.921208931, 0.00622640471, -0.389018433], + [-0.00686296262, 0.999976419, -0.000246683913], + [0.389007724, 0.00289706631, 0.92122994], + ] + ), + Point3(4.70156942, 1.72186229, -0.806181015), + ), + Pose3( + Rot3( + [ + [0.822397705, 0.00276497594, 0.568906142], + [0.00804891535, 0.999831556, -0.016494662], + [-0.568855921, 0.0181442503, 0.822236923], + ] + ), + Point3(-3.51368714, 1.59619714, 0.437437437), + ), + Pose3( + Rot3( + [ + [0.726822937, -0.00545541524, 0.686803193], + [0.00913794245, 0.999956756, -0.00172754968], + [-0.686764068, 0.00753159111, 0.726841357], + ] + ), + Point3(-4.29737821, 1.61462527, 1.11537749), + ), + Pose3( + Rot3( + [ + [0.402595481, 0.00697612855, 0.915351441], + [0.0114113638, 0.999855006, -0.0126391687], + [-0.915306892, 0.0155338804, 0.4024575], + ] + ), + Point3(-4.6516433, 1.6323107, 1.96579585), + ), + ] + # from estimated cameras + unaligned_pose_dict = { + 2: Pose3( + Rot3( + [ + [-0.681949, -0.568276, 0.460444], + [0.572389, -0.0227514, 0.819667], + [-0.455321, 0.822524, 0.34079], + ] + ), + Point3(-1.52189, 0.78906, -1.60608), + ), + 4: Pose3( + Rot3( + [ + [-0.817805393, -0.575044816, 0.022755196], + [0.0478829397, -0.0285875849, 0.998443776], + [-0.573499401, 0.81762229, 0.0509139174], + ] + ), + Point3(-1.22653168, 0.686485651, -1.39294168), + ), + 3: Pose3( + Rot3( + [ + [-0.783051568, -0.571905041, 0.244448085], + [0.314861464, -0.0255673164, 0.948793218], + [-0.536369743, 0.819921299, 0.200091385], + ] + ), + Point3(-1.37620079, 0.721408674, -1.49945316), + ), + 5: Pose3( + Rot3( + [ + [-0.818916586, -0.572896131, 0.0341415873], + [0.0550548476, -0.0192038786, 0.99829864], + [-0.571265778, 0.819402974, 0.0472670839], + ] + ), + Point3(-1.06370243, 0.663084159, -1.27672831), + ), + 6: Pose3( + Rot3( + [ + [-0.798825521, -0.571995242, 0.186277293], + [0.243311017, -0.0240196245, 0.969650869], + [-0.550161372, 0.819905178, 0.158360233], + ] + ), + Point3(-0.896250742, 0.640768239, -1.16984756), + ), + 7: Pose3( + Rot3( + [ + [-0.786416666, -0.570215296, 0.237493882], + [0.305475635, -0.0248440676, 0.951875732], + [-0.536873788, 0.821119534, 0.193724669], + ] + ), + Point3(-0.740385043, 0.613956842, -1.05908579), + ), + 8: Pose3( + Rot3( + [ + [-0.806252832, -0.57019757, 0.157578877], + [0.211046715, -0.0283979846, 0.977063375], + [-0.55264424, 0.821016617, 0.143234279], + ] + ), + Point3(-0.58333517, 0.549832698, -0.9542864), + ), + 9: Pose3( + Rot3( + [ + [-0.821191354, -0.557772774, -0.120558255], + [-0.125347331, -0.0297958331, 0.991665395], + [-0.556716092, 0.829458703, -0.0454472483], + ] + ), + Point3(-0.436483039, 0.55003923, -0.850733187), + ), + 21: Pose3( + Rot3( + [ + [-0.778607603, -0.575075476, 0.251114312], + [0.334920968, -0.0424301164, 0.941290407], + [-0.53065822, 0.816999316, 0.225641247], + ] + ), + Point3(-0.736735967, 0.571415247, -0.738663611), + ), + 17: Pose3( + Rot3( + [ + [-0.818569806, -0.573904529, 0.0240221722], + [0.0512889176, -0.0313725422, 0.998190969], + [-0.572112681, 0.818321059, 0.0551155579], + ] + ), + Point3(-1.36150982, 0.724829031, -1.16055631), + ), + 18: Pose3( + Rot3( + [ + [-0.812668105, -0.582027424, 0.0285417146], + [0.0570298244, -0.0306936169, 0.997900547], + [-0.579929436, 0.812589675, 0.0581366453], + ] + ), + Point3(-1.20484771, 0.762370343, -1.05057127), + ), + 20: Pose3( + Rot3( + [ + [-0.748446406, -0.580905382, 0.319963926], + [0.416860654, -0.0368374152, 0.908223651], + [-0.515805363, 0.813137099, 0.269727429], + ] + ), + Point3(569.449421, -907.892555, -794.585647), + ), + 22: Pose3( + Rot3( + [ + [-0.826878177, -0.559495019, -0.0569017041], + [-0.0452256802, -0.0346974602, 0.99837404], + [-0.560559647, 0.828107125, 0.00338702978], + ] + ), + Point3(-0.591431172, 0.55422253, -0.654656597), + ), + 29: Pose3( + Rot3( + [ + [-0.785759779, -0.574532433, -0.229115805], + [-0.246020939, -0.049553424, 0.967996981], + [-0.567499134, 0.81698038, -0.102409921], + ] + ), + Point3(69.4916073, 240.595227, -493.278045), + ), + 23: Pose3( + Rot3( + [ + [-0.783524382, -0.548569702, -0.291823276], + [-0.316457553, -0.051878563, 0.94718701], + [-0.534737468, 0.834493797, -0.132950906], + ] + ), + Point3(-5.93496204, 41.9304933, -3.06881633), + ), + 10: Pose3( + Rot3( + [ + [-0.766833992, -0.537641809, -0.350580824], + [-0.389506676, -0.0443270797, 0.919956336], + [-0.510147213, 0.84200736, -0.175423563], + ] + ), + Point3(234.185458, 326.007989, -691.769777), + ), + 30: Pose3( + Rot3( + [ + [-0.754844165, -0.559278755, -0.342662459], + [-0.375790683, -0.0594160018, 0.92479787], + [-0.537579435, 0.826847636, -0.165321923], + ] + ), + Point3(-5.93398168, 41.9107972, -3.07385081), + ), + } + + unaligned_pose_list = [] + for i in range(32): + wTi = unaligned_pose_dict.get(i, None) + unaligned_pose_list.append(wTi) + # GT poses are the reference/target + rSe = align_poses_sim3_ignore_missing(aTi_list=poses_gt, bTi_list=unaligned_pose_list) + assert rSe.scale() >= 0 + + +def align_poses_sim3_ignore_missing(aTi_list: List[Optional[Pose3]], bTi_list: List[Optional[Pose3]]) -> Similarity3: + """Align by similarity transformation, but allow missing estimated poses in the input. + + Note: this is a wrapper for align_poses_sim3() that allows for missing poses/dropped cameras. + This is necessary, as align_poses_sim3() requires a valid pose for every input pair. + + We force SIM(3) alignment rather than SE(3) alignment. + We assume the two trajectories are of the exact same length. + + Args: + aTi_list: reference poses in frame "a" which are the targets for alignment + bTi_list: input poses which need to be aligned to frame "a" + + Returns: + aSb: Similarity(3) object that aligns the two pose graphs. + """ + assert len(aTi_list) == len(bTi_list) + + # only choose target poses for which there is a corresponding estimated pose + corresponding_aTi_list = [] + valid_camera_idxs = [] + valid_bTi_list = [] + for i, bTi in enumerate(bTi_list): + if bTi is not None: + valid_camera_idxs.append(i) + valid_bTi_list.append(bTi) + corresponding_aTi_list.append(aTi_list[i]) + + aSb = align_poses_sim3(aTi_list=corresponding_aTi_list, bTi_list=valid_bTi_list) + return aSb + + +def align_poses_sim3(aTi_list: List[Pose3], bTi_list: List[Pose3]) -> Similarity3: + """Align two pose graphs via similarity transformation. Note: poses cannot be missing/invalid. + + We force SIM(3) alignment rather than SE(3) alignment. + We assume the two trajectories are of the exact same length. + + Args: + aTi_list: reference poses in frame "a" which are the targets for alignment + bTi_list: input poses which need to be aligned to frame "a" + + Returns: + aSb: Similarity(3) object that aligns the two pose graphs. + """ + n_to_align = len(aTi_list) + assert len(aTi_list) == len(bTi_list) + assert n_to_align >= 2, "SIM(3) alignment uses at least 2 frames" + + ab_pairs = Pose3Pairs(list(zip(aTi_list, bTi_list))) + + aSb = Similarity3.Align(ab_pairs) + + if np.isnan(aSb.scale()) or aSb.scale() == 0: + # we have run into a case where points have no translation between them (i.e. panorama). + # We will first align the rotations and then align the translation by using centroids. + # TODO: handle it in GTSAM + + # align the rotations first, so that we can find the translation between the two panoramas + aSb = Similarity3(aSb.rotation(), np.zeros((3,)), 1.0) + aTi_list_rot_aligned = [aSb.transformFrom(bTi) for bTi in bTi_list] + + # fit a single translation motion to the centroid + aTi_centroid = np.array([aTi.translation() for aTi in aTi_list]).mean(axis=0) + aTi_rot_aligned_centroid = np.array([aTi.translation() for aTi in aTi_list_rot_aligned]).mean(axis=0) + + # construct the final SIM3 transform + aSb = Similarity3(aSb.rotation(), aTi_centroid - aTi_rot_aligned_centroid, 1.0) + + return aSb + if __name__ == "__main__": unittest.main() diff --git a/python/gtsam/tests/test_Triangulation.py b/python/gtsam/tests/test_Triangulation.py index 399cf019d..0a258a0af 100644 --- a/python/gtsam/tests/test_Triangulation.py +++ b/python/gtsam/tests/test_Triangulation.py @@ -6,28 +6,40 @@ All Rights Reserved See LICENSE for the license information Test Triangulation -Author: Frank Dellaert & Fan Jiang (Python) +Authors: Frank Dellaert & Fan Jiang (Python) & Sushmita Warrier & John Lambert """ import unittest +from typing import Iterable, List, Optional, Tuple, Union import numpy as np import gtsam -from gtsam import (Cal3_S2, Cal3Bundler, CameraSetCal3_S2, - CameraSetCal3Bundler, PinholeCameraCal3_S2, - PinholeCameraCal3Bundler, Point2Vector, Point3, Pose3, - Pose3Vector, Rot3) +from gtsam import ( + Cal3_S2, + Cal3Bundler, + CameraSetCal3_S2, + CameraSetCal3Bundler, + PinholeCameraCal3_S2, + PinholeCameraCal3Bundler, + Point2, + Point2Vector, + Point3, + Pose3, + Pose3Vector, + Rot3, +) from gtsam.utils.test_case import GtsamTestCase +UPRIGHT = Rot3.Ypr(-np.pi / 2, 0.0, -np.pi / 2) -class TestVisualISAMExample(GtsamTestCase): - """ Tests for triangulation with shared and individual calibrations """ + +class TestTriangulationExample(GtsamTestCase): + """Tests for triangulation with shared and individual calibrations""" def setUp(self): - """ Set up two camera poses """ + """Set up two camera poses""" # Looking along X-axis, 1 meter above ground plane (x-y) - upright = Rot3.Ypr(-np.pi / 2, 0., -np.pi / 2) - pose1 = Pose3(upright, Point3(0, 0, 1)) + pose1 = Pose3(UPRIGHT, Point3(0, 0, 1)) # create second camera 1 meter to the right of first camera pose2 = pose1.compose(Pose3(Rot3(), Point3(1, 0, 0))) @@ -39,15 +51,24 @@ class TestVisualISAMExample(GtsamTestCase): # landmark ~5 meters infront of camera self.landmark = Point3(5, 0.5, 1.2) - def generate_measurements(self, calibration, camera_model, cal_params, camera_set=None): + def generate_measurements( + self, + calibration: Union[Cal3Bundler, Cal3_S2], + camera_model: Union[PinholeCameraCal3Bundler, PinholeCameraCal3_S2], + cal_params: Iterable[Iterable[Union[int, float]]], + camera_set: Optional[Union[CameraSetCal3Bundler, + CameraSetCal3_S2]] = None, + ) -> Tuple[Point2Vector, Union[CameraSetCal3Bundler, CameraSetCal3_S2, + List[Cal3Bundler], List[Cal3_S2]]]: """ Generate vector of measurements for given calibration and camera model. - Args: + Args: calibration: Camera calibration e.g. Cal3_S2 camera_model: Camera model e.g. PinholeCameraCal3_S2 cal_params: Iterable of camera parameters for `calibration` e.g. [K1, K2] camera_set: Cameraset object (for individual calibrations) + Returns: list of measurements and list/CameraSet object for cameras """ @@ -66,14 +87,15 @@ class TestVisualISAMExample(GtsamTestCase): return measurements, cameras - def test_TriangulationExample(self): - """ Tests triangulation with shared Cal3_S2 calibration""" + def test_TriangulationExample(self) -> None: + """Tests triangulation with shared Cal3_S2 calibration""" # Some common constants sharedCal = (1500, 1200, 0, 640, 480) - measurements, _ = self.generate_measurements(Cal3_S2, - PinholeCameraCal3_S2, - (sharedCal, sharedCal)) + measurements, _ = self.generate_measurements( + calibration=Cal3_S2, + camera_model=PinholeCameraCal3_S2, + cal_params=(sharedCal, sharedCal)) triangulated_landmark = gtsam.triangulatePoint3(self.poses, Cal3_S2(sharedCal), @@ -95,16 +117,17 @@ class TestVisualISAMExample(GtsamTestCase): self.gtsamAssertEquals(self.landmark, triangulated_landmark, 1e-2) - def test_distinct_Ks(self): - """ Tests triangulation with individual Cal3_S2 calibrations """ + def test_distinct_Ks(self) -> None: + """Tests triangulation with individual Cal3_S2 calibrations""" # two camera parameters K1 = (1500, 1200, 0, 640, 480) K2 = (1600, 1300, 0, 650, 440) - measurements, cameras = self.generate_measurements(Cal3_S2, - PinholeCameraCal3_S2, - (K1, K2), - camera_set=CameraSetCal3_S2) + measurements, cameras = self.generate_measurements( + calibration=Cal3_S2, + camera_model=PinholeCameraCal3_S2, + cal_params=(K1, K2), + camera_set=CameraSetCal3_S2) triangulated_landmark = gtsam.triangulatePoint3(cameras, measurements, @@ -112,16 +135,17 @@ class TestVisualISAMExample(GtsamTestCase): optimize=True) self.gtsamAssertEquals(self.landmark, triangulated_landmark, 1e-9) - def test_distinct_Ks_Bundler(self): - """ Tests triangulation with individual Cal3Bundler calibrations""" + def test_distinct_Ks_Bundler(self) -> None: + """Tests triangulation with individual Cal3Bundler calibrations""" # two camera parameters K1 = (1500, 0, 0, 640, 480) K2 = (1600, 0, 0, 650, 440) - measurements, cameras = self.generate_measurements(Cal3Bundler, - PinholeCameraCal3Bundler, - (K1, K2), - camera_set=CameraSetCal3Bundler) + measurements, cameras = self.generate_measurements( + calibration=Cal3Bundler, + camera_model=PinholeCameraCal3Bundler, + cal_params=(K1, K2), + camera_set=CameraSetCal3Bundler) triangulated_landmark = gtsam.triangulatePoint3(cameras, measurements, @@ -129,6 +153,71 @@ class TestVisualISAMExample(GtsamTestCase): optimize=True) self.gtsamAssertEquals(self.landmark, triangulated_landmark, 1e-9) + def test_triangulation_robust_three_poses(self) -> None: + """Ensure triangulation with a robust model works.""" + sharedCal = Cal3_S2(1500, 1200, 0, 640, 480) + + # landmark ~5 meters infront of camera + landmark = Point3(5, 0.5, 1.2) + + pose1 = Pose3(UPRIGHT, Point3(0, 0, 1)) + pose2 = pose1 * Pose3(Rot3(), Point3(1, 0, 0)) + pose3 = pose1 * Pose3(Rot3.Ypr(0.1, 0.2, 0.1), Point3(0.1, -2, -0.1)) + + camera1 = PinholeCameraCal3_S2(pose1, sharedCal) + camera2 = PinholeCameraCal3_S2(pose2, sharedCal) + camera3 = PinholeCameraCal3_S2(pose3, sharedCal) + + z1: Point2 = camera1.project(landmark) + z2: Point2 = camera2.project(landmark) + z3: Point2 = camera3.project(landmark) + + poses = gtsam.Pose3Vector([pose1, pose2, pose3]) + measurements = Point2Vector([z1, z2, z3]) + + # noise free, so should give exactly the landmark + actual = gtsam.triangulatePoint3(poses, + sharedCal, + measurements, + rank_tol=1e-9, + optimize=False) + self.assertTrue(np.allclose(landmark, actual, atol=1e-2)) + + # Add outlier + measurements[0] += Point2(100, 120) # very large pixel noise! + + # now estimate does not match landmark + actual2 = gtsam.triangulatePoint3(poses, + sharedCal, + measurements, + rank_tol=1e-9, + optimize=False) + # DLT is surprisingly robust, but still off (actual error is around 0.26m) + self.assertTrue(np.linalg.norm(landmark - actual2) >= 0.2) + self.assertTrue(np.linalg.norm(landmark - actual2) <= 0.5) + + # Again with nonlinear optimization + actual3 = gtsam.triangulatePoint3(poses, + sharedCal, + measurements, + rank_tol=1e-9, + optimize=True) + # result from nonlinear (but non-robust optimization) is close to DLT and still off + self.assertTrue(np.allclose(actual2, actual3, atol=0.1)) + + # Again with nonlinear optimization, this time with robust loss + model = gtsam.noiseModel.Robust.Create( + gtsam.noiseModel.mEstimator.Huber.Create(1.345), + gtsam.noiseModel.Unit.Create(2)) + actual4 = gtsam.triangulatePoint3(poses, + sharedCal, + measurements, + rank_tol=1e-9, + optimize=True, + model=model) + # using the Huber loss we now have a quite small error!! nice! + self.assertTrue(np.allclose(landmark, actual4, atol=0.05)) + if __name__ == "__main__": unittest.main() diff --git a/python/gtsam/tests/test_VisualISAMExample.py b/python/gtsam/tests/test_VisualISAMExample.py index 6eb05eeee..ebc77e2e3 100644 --- a/python/gtsam/tests/test_VisualISAMExample.py +++ b/python/gtsam/tests/test_VisualISAMExample.py @@ -10,9 +10,6 @@ Author: Frank Dellaert & Duy Nguyen Ta (Python) """ import unittest -import numpy as np - -import gtsam import gtsam.utils.visual_data_generator as generator import gtsam.utils.visual_isam as visual_isam from gtsam import symbol @@ -20,8 +17,9 @@ from gtsam.utils.test_case import GtsamTestCase class TestVisualISAMExample(GtsamTestCase): - + """Test class for ISAM2 with visual landmarks.""" def test_VisualISAMExample(self): + """Test to see if ISAM works as expected for a simple visual SLAM example.""" # Data Options options = generator.Options() options.triangle = False @@ -39,19 +37,22 @@ class TestVisualISAMExample(GtsamTestCase): data, truth = generator.generate_data(options) # Initialize iSAM with the first pose and points - isam, result, nextPose = visual_isam.initialize(data, truth, isamOptions) + isam, result, nextPose = visual_isam.initialize( + data, truth, isamOptions) # Main loop for iSAM: stepping through all poses for currentPose in range(nextPose, options.nrCameras): - isam, result = visual_isam.step(data, isam, result, truth, currentPose) + isam, result = visual_isam.step(data, isam, result, truth, + currentPose) - for i in range(len(truth.cameras)): + for i, _ in enumerate(truth.cameras): pose_i = result.atPose3(symbol('x', i)) self.gtsamAssertEquals(pose_i, truth.cameras[i].pose(), 1e-5) - for j in range(len(truth.points)): + for j, _ in enumerate(truth.points): point_j = result.atPoint3(symbol('l', j)) self.gtsamAssertEquals(point_j, truth.points[j], 1e-5) + if __name__ == "__main__": unittest.main() diff --git a/python/gtsam/tests/test_lago.py b/python/gtsam/tests/test_lago.py new file mode 100644 index 000000000..8ed5dd943 --- /dev/null +++ b/python/gtsam/tests/test_lago.py @@ -0,0 +1,38 @@ +""" +GTSAM Copyright 2010, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved +Authors: Frank Dellaert, et al. (see THANKS for the full author list) +See LICENSE for the license information + +Author: John Lambert (Python) +""" + +import unittest + +import numpy as np + +import gtsam +from gtsam import Point3, Pose2, PriorFactorPose2, Values + + +class TestLago(unittest.TestCase): + """Test selected LAGO methods.""" + + def test_initialize(self) -> None: + """Smokescreen to ensure LAGO can be imported and run on toy data stored in a g2o file.""" + g2oFile = gtsam.findExampleDataFile("noisyToyGraph.txt") + + graph = gtsam.NonlinearFactorGraph() + graph, initial = gtsam.readG2o(g2oFile) + + # Add prior on the pose having index (key) = 0 + priorModel = gtsam.noiseModel.Diagonal.Variances(Point3(1e-6, 1e-6, 1e-8)) + graph.add(PriorFactorPose2(0, Pose2(), priorModel)) + + estimateLago: Values = gtsam.lago.initialize(graph) + assert isinstance(estimateLago, Values) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/gtsam/tests/test_pickle.py b/python/gtsam/tests/test_pickle.py index 0acbf6765..a6a5745bc 100644 --- a/python/gtsam/tests/test_pickle.py +++ b/python/gtsam/tests/test_pickle.py @@ -37,8 +37,8 @@ class TestPickle(GtsamTestCase): def test_sfmTrack_roundtrip(self): obj = SfmTrack(Point3(1, 1, 0)) - obj.add_measurement(0, Point2(-1, 5)) - obj.add_measurement(1, Point2(6, 2)) + obj.addMeasurement(0, Point2(-1, 5)) + obj.addMeasurement(1, Point2(6, 2)) self.assertEqualityOnPickleRoundtrip(obj) def test_unit3_roundtrip(self): diff --git a/python/gtsam/tests/test_sam.py b/python/gtsam/tests/test_sam.py new file mode 100644 index 000000000..e01b9c1d1 --- /dev/null +++ b/python/gtsam/tests/test_sam.py @@ -0,0 +1,55 @@ +""" +GTSAM Copyright 2010-2019, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Basis unit tests. +Author: Frank Dellaert & Varun Agrawal (Python) +""" +import unittest + +import gtsam +from gtsam.utils.test_case import GtsamTestCase + + +class TestSam(GtsamTestCase): + """ + Tests python binding for classes/functions in `sam.i`. + """ + def test_RangeFactor2D(self): + """ + Test that `measured` works as expected for RangeFactor2D. + """ + measurement = 10.0 + factor = gtsam.RangeFactor2D(1, 2, measurement, + gtsam.noiseModel.Isotropic.Sigma(1, 1)) + self.assertEqual(measurement, factor.measured()) + + def test_BearingFactor2D(self): + """ + Test that `measured` works as expected for BearingFactor2D. + """ + measurement = gtsam.Rot2(.3) + factor = gtsam.BearingFactor2D(1, 2, measurement, + gtsam.noiseModel.Isotropic.Sigma(1, 1)) + self.gtsamAssertEquals(measurement, factor.measured()) + + def test_BearingRangeFactor2D(self): + """ + Test that `measured` works as expected for BearingRangeFactor2D. + """ + range_measurement = 10.0 + bearing_measurement = gtsam.Rot2(0.3) + factor = gtsam.BearingRangeFactor2D( + 1, 2, bearing_measurement, range_measurement, + gtsam.noiseModel.Isotropic.Sigma(2, 1)) + measurement = factor.measured() + + self.assertEqual(range_measurement, measurement.range()) + self.gtsamAssertEquals(bearing_measurement, measurement.bearing()) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/gtsam/utils/plot.py b/python/gtsam/utils/plot.py index 7ea393077..5ff7fd7aa 100644 --- a/python/gtsam/utils/plot.py +++ b/python/gtsam/utils/plot.py @@ -10,8 +10,15 @@ from matplotlib import patches from mpl_toolkits.mplot3d import Axes3D # pylint: disable=unused-import import gtsam -from gtsam import Marginals, Point3, Pose2, Pose3, Values +from gtsam import Marginals, Point2, Point3, Pose2, Pose3, Values +# For future reference: following +# https://www.xarg.org/2018/04/how-to-plot-a-covariance-error-ellipse/ +# we have, in 2D: +# def kk(p): return math.sqrt(-2*math.log(1-p)) # k to get p probability mass +# def pp(k): return 1-math.exp(-float(k**2)/2.0) # p as a function of k +# Some values: +# k = 5 => p = 99.9996 % def set_axes_equal(fignum: int) -> None: """ @@ -108,6 +115,66 @@ def plot_covariance_ellipse_3d(axes, axes.plot_surface(x, y, z, alpha=alpha, cmap='hot') +def plot_point2_on_axes(axes, + point: Point2, + linespec: str, + P: Optional[np.ndarray] = None) -> None: + """ + Plot a 2D point on given axis `axes` with given `linespec`. + + Args: + axes (matplotlib.axes.Axes): Matplotlib axes. + point: The point to be plotted. + linespec: String representing formatting options for Matplotlib. + P: Marginal covariance matrix to plot the uncertainty of the estimation. + """ + axes.plot([point[0]], [point[1]], linespec, marker='.', markersize=10) + if P is not None: + w, v = np.linalg.eig(P) + + # 5 sigma corresponds to 99.9996%, see note above + k = 5.0 + + angle = np.arctan2(v[1, 0], v[0, 0]) + e1 = patches.Ellipse(point, + np.sqrt(w[0] * k), + np.sqrt(w[1] * k), + np.rad2deg(angle), + fill=False) + axes.add_patch(e1) + + +def plot_point2( + fignum: int, + point: Point2, + linespec: str, + P: np.ndarray = None, + axis_labels: Iterable[str] = ("X axis", "Y axis"), +) -> plt.Figure: + """ + Plot a 2D point on given figure with given `linespec`. + + Args: + fignum: Integer representing the figure number to use for plotting. + point: The point to be plotted. + linespec: String representing formatting options for Matplotlib. + P: Marginal covariance matrix to plot the uncertainty of the estimation. + axis_labels: List of axis labels to set. + + Returns: + fig: The matplotlib figure. + + """ + fig = plt.figure(fignum) + axes = fig.gca() + plot_point2_on_axes(axes, point, linespec, P) + + axes.set_xlabel(axis_labels[0]) + axes.set_ylabel(axis_labels[1]) + + return fig + + def plot_pose2_on_axes(axes, pose: Pose2, axis_length: float = 0.1, @@ -142,7 +209,7 @@ def plot_pose2_on_axes(axes, w, v = np.linalg.eig(gPp) - # k = 2.296 + # 5 sigma corresponds to 99.9996%, see note above k = 5.0 angle = np.arctan2(v[1, 0], v[0, 0]) diff --git a/python/gtsam/utils/visual_data_generator.py b/python/gtsam/utils/visual_data_generator.py index 51852760a..972f25477 100644 --- a/python/gtsam/utils/visual_data_generator.py +++ b/python/gtsam/utils/visual_data_generator.py @@ -1,12 +1,12 @@ from __future__ import print_function -from typing import Tuple import math -import numpy as np from math import pi +from typing import Tuple import gtsam -from gtsam import Point3, Pose3, PinholeCameraCal3_S2, Cal3_S2 +import numpy as np +from gtsam import Cal3_S2, PinholeCameraCal3_S2, Point3, Pose3 class Options: @@ -36,7 +36,7 @@ class GroundTruth: self.cameras = [Pose3()] * nrCameras self.points = [Point3(0, 0, 0)] * nrPoints - def print_(self, s="") -> None: + def print(self, s: str = "") -> None: print(s) print("K = ", self.K) print("Cameras: ", len(self.cameras)) @@ -88,7 +88,8 @@ def generate_data(options) -> Tuple[Data, GroundTruth]: r = 10 for j in range(len(truth.points)): theta = j * 2 * pi / nrPoints - truth.points[j] = Point3(r * math.cos(theta), r * math.sin(theta), 0) + truth.points[j] = Point3( + r * math.cos(theta), r * math.sin(theta), 0) else: # 3D landmarks as vertices of a cube truth.points = [ Point3(10, 10, 10), Point3(-10, 10, 10), diff --git a/python/gtsam/utils/visual_isam.py b/python/gtsam/utils/visual_isam.py index a8fed4b23..4ebd8accd 100644 --- a/python/gtsam/utils/visual_isam.py +++ b/python/gtsam/utils/visual_isam.py @@ -17,7 +17,7 @@ def initialize(data, truth, options): # Initialize iSAM params = gtsam.ISAM2Params() if options.alwaysRelinearize: - params.setRelinearizeSkip(1) + params.relinearizeSkip = 1 isam = gtsam.ISAM2(params=params) # Add constraints/priors diff --git a/python/gtsam_unstable/gtsam_unstable.tpl b/python/gtsam_unstable/gtsam_unstable.tpl index aa7ac6bdb..055fcaea7 100644 --- a/python/gtsam_unstable/gtsam_unstable.tpl +++ b/python/gtsam_unstable/gtsam_unstable.tpl @@ -40,7 +40,7 @@ PYBIND11_MODULE({module_name}, m_) {{ {wrapped_namespace} -#include "python/gtsam_unstable/specializations.h" +#include "python/gtsam_unstable/specializations/gtsam_unstable.h" }} diff --git a/python/gtsam_unstable/specializations.h b/python/gtsam_unstable/specializations/gtsam_unstable.h similarity index 100% rename from python/gtsam_unstable/specializations.h rename to python/gtsam_unstable/specializations/gtsam_unstable.h diff --git a/python/gtsam_unstable/tests/test_ProjectionFactorRollingShutter.py b/python/gtsam_unstable/tests/test_ProjectionFactorRollingShutter.py new file mode 100644 index 000000000..0e4db3faf --- /dev/null +++ b/python/gtsam_unstable/tests/test_ProjectionFactorRollingShutter.py @@ -0,0 +1,59 @@ +""" +GTSAM Copyright 2010-2019, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +ProjectionFactorRollingShutter unit tests. +Author: Yotam Stern +""" +import unittest + +import numpy as np + +import gtsam +import gtsam_unstable +from gtsam.utils.test_case import GtsamTestCase + + +pose1 = gtsam.Pose3() +pose2 = gtsam.Pose3(np.array([[ 0.9999375 , 0.00502487, 0.00998725, 0.1 ], + [-0.00497488, 0.999975 , -0.00502487, 0.02 ], + [-0.01001225, 0.00497488, 0.9999375 , 1. ], + [ 0. , 0. , 0. , 1. ]])) +point = np.array([2, 0, 15]) +point_noise = gtsam.noiseModel.Diagonal.Sigmas(np.ones(2)) +cal = gtsam.Cal3_S2() +body_p_sensor = gtsam.Pose3() +alpha = 0.1 +measured = np.array([0.13257015, 0.0004157]) + + +class TestProjectionFactorRollingShutter(GtsamTestCase): + + def test_constructor(self): + ''' + test constructor for the ProjectionFactorRollingShutter + ''' + factor = gtsam_unstable.ProjectionFactorRollingShutter(measured, alpha, point_noise, 0, 1, 2, cal) + factor = gtsam_unstable.ProjectionFactorRollingShutter(measured, alpha, point_noise, 0, 1, 2, cal, + body_p_sensor) + factor = gtsam_unstable.ProjectionFactorRollingShutter(measured, alpha, point_noise, 0, 1, 2, cal, True, False) + factor = gtsam_unstable.ProjectionFactorRollingShutter(measured, alpha, point_noise, 0, 1, 2, cal, True, False, + body_p_sensor) + + def test_error(self): + ''' + test the factor error for a specific example + ''' + values = gtsam.Values() + values.insert(0, pose1) + values.insert(1, pose2) + values.insert(2, point) + factor = gtsam_unstable.ProjectionFactorRollingShutter(measured, alpha, point_noise, 0, 1, 2, cal) + self.gtsamAssertEquals(factor.error(values), np.array(0), tol=1e-9) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 068b39eca..5eaad45dc 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -3,14 +3,14 @@ set (tests_exclude "") if (${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang") # might not be best test - Richard & Jason & Frank - # clang linker segfaults on large testSerializationSLAM - list (APPEND tests_exclude "testSerializationSLAM.cpp") + # clang linker segfaults on large testSerializationSlam + list (APPEND tests_exclude "testSerializationSlam.cpp") endif() # Build tests gtsamAddTestsGlob(tests "test*.cpp" "${tests_exclude}" "gtsam") if(MSVC) - set_property(SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/testSerializationSLAM.cpp" + set_property(SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/testSerializationSlam.cpp" APPEND PROPERTY COMPILE_FLAGS "/bigobj") endif() diff --git a/tests/smallExample.h b/tests/smallExample.h index 944899e70..ca9a8580f 100644 --- a/tests/smallExample.h +++ b/tests/smallExample.h @@ -679,26 +679,25 @@ inline Ordering planarOrdering(size_t N) { } /* ************************************************************************* */ -inline std::pair splitOffPlanarTree(size_t N, - const GaussianFactorGraph& original) { - auto T = boost::make_shared(), C= boost::make_shared(); +inline std::pair splitOffPlanarTree( + size_t N, const GaussianFactorGraph& original) { + GaussianFactorGraph T, C; // Add the x11 constraint to the tree - T->push_back(original[0]); + T.push_back(original[0]); // Add all horizontal constraints to the tree size_t i = 1; for (size_t x = 1; x < N; x++) - for (size_t y = 1; y <= N; y++, i++) - T->push_back(original[i]); + for (size_t y = 1; y <= N; y++, i++) T.push_back(original[i]); // Add first vertical column of constraints to T, others to C for (size_t x = 1; x <= N; x++) for (size_t y = 1; y < N; y++, i++) if (x == 1) - T->push_back(original[i]); + T.push_back(original[i]); else - C->push_back(original[i]); + C.push_back(original[i]); return std::make_pair(T, C); } diff --git a/tests/testExpressionFactor.cpp b/tests/testExpressionFactor.cpp index 66dbed1eb..6d23144aa 100644 --- a/tests/testExpressionFactor.cpp +++ b/tests/testExpressionFactor.cpp @@ -231,7 +231,7 @@ TEST(ExpressionFactor, Shallow) { Pose3_ x_(1); Point3_ p_(2); - // Construct expression, concise evrsion + // Construct expression, concise version Point2_ expression = project(transformTo(x_, p_)); // Get and check keys and dims @@ -606,7 +606,7 @@ Vector3 f(const Point2& a, const Vector3& b, OptionalJacobian<3, 2> H1, if (H1) *H1 << b.y(), b.z(), b.x(), 0, 0, 0; if (H2) *H2 = A; return A * b; -}; +} } TEST(ExpressionFactor, MultiplyWithInverseFunction) { diff --git a/tests/testGaussianBayesTreeB.cpp b/tests/testGaussianBayesTreeB.cpp index b8b6cf284..a321aa25d 100644 --- a/tests/testGaussianBayesTreeB.cpp +++ b/tests/testGaussianBayesTreeB.cpp @@ -112,54 +112,50 @@ TEST( GaussianBayesTree, linear_smoother_shortcuts ) C4 x7 : x6 ************************************************************************* */ -TEST( GaussianBayesTree, balanced_smoother_marginals ) -{ +TEST(GaussianBayesTree, balanced_smoother_marginals) { // Create smoother with 7 nodes GaussianFactorGraph smoother = createSmoother(7); // Create the Bayes tree Ordering ordering; - ordering += X(1),X(3),X(5),X(7),X(2),X(6),X(4); + ordering += X(1), X(3), X(5), X(7), X(2), X(6), X(4); GaussianBayesTree bayesTree = *smoother.eliminateMultifrontal(ordering); VectorValues actualSolution = bayesTree.optimize(); VectorValues expectedSolution = VectorValues::Zero(actualSolution); - EXPECT(assert_equal(expectedSolution,actualSolution,tol)); + EXPECT(assert_equal(expectedSolution, actualSolution, tol)); - LONGS_EQUAL(4, (long)bayesTree.size()); + LONGS_EQUAL(4, bayesTree.size()); - double tol=1e-5; + double tol = 1e-5; // Check marginal on x1 - JacobianFactor expected1 = GaussianDensity::FromMeanAndStddev(X(1), Z_2x1, sigmax1); JacobianFactor actual1 = *bayesTree.marginalFactor(X(1)); - Matrix expectedCovarianceX1 = I_2x2 * (sigmax1 * sigmax1); - Matrix actualCovarianceX1; - GaussianFactor::shared_ptr m = bayesTree.marginalFactor(X(1), EliminateCholesky); - actualCovarianceX1 = bayesTree.marginalFactor(X(1), EliminateCholesky)->information().inverse(); - EXPECT(assert_equal(expectedCovarianceX1, actualCovarianceX1, tol)); - EXPECT(assert_equal(expected1,actual1,tol)); + Matrix expectedCovX1 = I_2x2 * (sigmax1 * sigmax1); + auto m = bayesTree.marginalFactor(X(1), EliminateCholesky); + Matrix actualCovarianceX1 = m->information().inverse(); + EXPECT(assert_equal(expectedCovX1, actualCovarianceX1, tol)); // Check marginal on x2 - double sigx2 = 0.68712938; // FIXME: this should be corrected analytically - JacobianFactor expected2 = GaussianDensity::FromMeanAndStddev(X(2), Z_2x1, sigx2); + double sigmax2 = 0.68712938; // FIXME: this should be corrected analytically JacobianFactor actual2 = *bayesTree.marginalFactor(X(2)); - EXPECT(assert_equal(expected2,actual2,tol)); + Matrix expectedCovX2 = I_2x2 * (sigmax2 * sigmax2); + EXPECT(assert_equal(expectedCovX2, actual2.information().inverse(), tol)); // Check marginal on x3 - JacobianFactor expected3 = GaussianDensity::FromMeanAndStddev(X(3), Z_2x1, sigmax3); JacobianFactor actual3 = *bayesTree.marginalFactor(X(3)); - EXPECT(assert_equal(expected3,actual3,tol)); + Matrix expectedCovX3 = I_2x2 * (sigmax3 * sigmax3); + EXPECT(assert_equal(expectedCovX3, actual3.information().inverse(), tol)); // Check marginal on x4 - JacobianFactor expected4 = GaussianDensity::FromMeanAndStddev(X(4), Z_2x1, sigmax4); JacobianFactor actual4 = *bayesTree.marginalFactor(X(4)); - EXPECT(assert_equal(expected4,actual4,tol)); + Matrix expectedCovX4 = I_2x2 * (sigmax4 * sigmax4); + EXPECT(assert_equal(expectedCovX4, actual4.information().inverse(), tol)); // Check marginal on x7 (should be equal to x1) - JacobianFactor expected7 = GaussianDensity::FromMeanAndStddev(X(7), Z_2x1, sigmax7); JacobianFactor actual7 = *bayesTree.marginalFactor(X(7)); - EXPECT(assert_equal(expected7,actual7,tol)); + Matrix expectedCovX7 = I_2x2 * (sigmax7 * sigmax7); + EXPECT(assert_equal(expectedCovX7, actual7.information().inverse(), tol)); } /* ************************************************************************* */ diff --git a/tests/testGeneralSFMFactorB.cpp b/tests/testGeneralSFMFactorB.cpp index 05b4c7f66..fa27e1370 100644 --- a/tests/testGeneralSFMFactorB.cpp +++ b/tests/testGeneralSFMFactorB.cpp @@ -15,6 +15,7 @@ * @brief test general SFM class, with nonlinear optimization and BAL files */ +#include #include #include #include @@ -42,14 +43,12 @@ using symbol_shorthand::P; /* ************************************************************************* */ TEST(PinholeCamera, BAL) { string filename = findExampleDataFile("dubrovnik-3-7-pre"); - SfmData db; - bool success = readBAL(filename, db); - if (!success) throw runtime_error("Could not access file!"); + SfmData db = SfmData::FromBalFile(filename); SharedNoiseModel unit2 = noiseModel::Unit::Create(2); NonlinearFactorGraph graph; - for (size_t j = 0; j < db.number_tracks(); j++) { + for (size_t j = 0; j < db.numberTracks(); j++) { for (const SfmMeasurement& m: db.tracks[j].measurements) graph.emplace_shared(m.second, unit2, m.first, P(j)); } diff --git a/tests/testGncOptimizer.cpp b/tests/testGncOptimizer.cpp index a3d1e4e9b..c3335ce20 100644 --- a/tests/testGncOptimizer.cpp +++ b/tests/testGncOptimizer.cpp @@ -98,6 +98,30 @@ TEST(GncOptimizer, gncConstructor) { CHECK(gnc.equals(gnc2)); } +/* ************************************************************************* */ +TEST(GncOptimizer, solverParameterParsing) { + // has to have Gaussian noise models ! + auto fg = example::createReallyNonlinearFactorGraph(); // just a unary factor + // on a 2D point + + Point2 p0(3, 3); + Values initial; + initial.insert(X(1), p0); + + LevenbergMarquardtParams lmParams; + lmParams.setMaxIterations(0); // forces not to perform optimization + GncParams gncParams(lmParams); + auto gnc = GncOptimizer>(fg, initial, + gncParams); + Values result = gnc.optimize(); + + // check that LM did not perform optimization and result is the same as the initial guess + DOUBLES_EQUAL(fg.error(initial), fg.error(result), tol); + + // also check the params: + DOUBLES_EQUAL(0.0, gncParams.baseOptimizerParams.maxIterations, tol); +} + /* ************************************************************************* */ TEST(GncOptimizer, gncConstructorWithRobustGraphAsInput) { auto fg = example::sharedNonRobustFactorGraphWithOutliers(); diff --git a/tests/testLie.cpp b/tests/testLie.cpp index 0ef12198b..fe1173f22 100644 --- a/tests/testLie.cpp +++ b/tests/testLie.cpp @@ -129,6 +129,46 @@ TEST( testProduct, Logmap ) { EXPECT(assert_equal(numericH, actH, tol)); } +/* ************************************************************************* */ +Product interpolate_proxy(const Product& x, const Product& y, double t) { + return interpolate(x, y, t); +} + +TEST(Lie, Interpolate) { + Product x(Point2(1, 2), Pose2(3, 4, 5)); + Product y(Point2(6, 7), Pose2(8, 9, 0)); + + double t; + Matrix actH1, numericH1, actH2, numericH2; + + t = 0.0; + interpolate(x, y, t, actH1, actH2); + numericH1 = numericalDerivative31( + interpolate_proxy, x, y, t); + EXPECT(assert_equal(numericH1, actH1, tol)); + numericH2 = numericalDerivative32( + interpolate_proxy, x, y, t); + EXPECT(assert_equal(numericH2, actH2, tol)); + + t = 0.5; + interpolate(x, y, t, actH1, actH2); + numericH1 = numericalDerivative31( + interpolate_proxy, x, y, t); + EXPECT(assert_equal(numericH1, actH1, tol)); + numericH2 = numericalDerivative32( + interpolate_proxy, x, y, t); + EXPECT(assert_equal(numericH2, actH2, tol)); + + t = 1.0; + interpolate(x, y, t, actH1, actH2); + numericH1 = numericalDerivative31( + interpolate_proxy, x, y, t); + EXPECT(assert_equal(numericH1, actH1, tol)); + numericH2 = numericalDerivative32( + interpolate_proxy, x, y, t); + EXPECT(assert_equal(numericH2, actH2, tol)); +} + //****************************************************************************** int main() { TestResult tr; diff --git a/tests/testNonlinearFactor.cpp b/tests/testNonlinearFactor.cpp index 84bba850b..67a23355d 100644 --- a/tests/testNonlinearFactor.cpp +++ b/tests/testNonlinearFactor.cpp @@ -101,6 +101,82 @@ TEST( NonlinearFactor, NonlinearFactor ) DOUBLES_EQUAL(expected,actual,0.00000001); } +/* ************************************************************************* */ +TEST(NonlinearFactor, Weight) { + // create a values structure for the non linear factor graph + Values values; + + // Instantiate a concrete class version of a NoiseModelFactor + PriorFactor factor1(X(1), Point2(0, 0)); + values.insert(X(1), Point2(0.1, 0.1)); + + CHECK(assert_equal(1.0, factor1.weight(values))); + + // Factor with noise model + auto noise = noiseModel::Isotropic::Sigma(2, 0.2); + PriorFactor factor2(X(2), Point2(1, 1), noise); + values.insert(X(2), Point2(1.1, 1.1)); + + CHECK(assert_equal(1.0, factor2.weight(values))); + + Point2 estimate(3, 3), prior(1, 1); + double distance = (estimate - prior).norm(); + + auto gaussian = noiseModel::Isotropic::Sigma(2, 0.2); + + PriorFactor factor; + + // vector to store all the robust models in so we can test iteratively. + vector robust_models; + + // Fair noise model + auto fair = noiseModel::Robust::Create( + noiseModel::mEstimator::Fair::Create(1.3998), gaussian); + robust_models.push_back(fair); + + // Huber noise model + auto huber = noiseModel::Robust::Create( + noiseModel::mEstimator::Huber::Create(1.345), gaussian); + robust_models.push_back(huber); + + // Cauchy noise model + auto cauchy = noiseModel::Robust::Create( + noiseModel::mEstimator::Cauchy::Create(0.1), gaussian); + robust_models.push_back(cauchy); + + // Tukey noise model + auto tukey = noiseModel::Robust::Create( + noiseModel::mEstimator::Tukey::Create(4.6851), gaussian); + robust_models.push_back(tukey); + + // Welsch noise model + auto welsch = noiseModel::Robust::Create( + noiseModel::mEstimator::Welsch::Create(2.9846), gaussian); + robust_models.push_back(welsch); + + // Geman-McClure noise model + auto gm = noiseModel::Robust::Create( + noiseModel::mEstimator::GemanMcClure::Create(1.0), gaussian); + robust_models.push_back(gm); + + // DCS noise model + auto dcs = noiseModel::Robust::Create( + noiseModel::mEstimator::DCS::Create(1.0), gaussian); + robust_models.push_back(dcs); + + // L2WithDeadZone noise model + auto l2 = noiseModel::Robust::Create( + noiseModel::mEstimator::L2WithDeadZone::Create(1.0), gaussian); + robust_models.push_back(l2); + + for(auto&& model: robust_models) { + factor = PriorFactor(X(3), prior, model); + values.clear(); + values.insert(X(3), estimate); + CHECK(assert_equal(model->robust()->weight(distance), factor.weight(values))); + } +} + /* ************************************************************************* */ TEST( NonlinearFactor, linearize_f1 ) { diff --git a/tests/testNonlinearFactorGraph.cpp b/tests/testNonlinearFactorGraph.cpp index fdb080a63..05a6e7f45 100644 --- a/tests/testNonlinearFactorGraph.cpp +++ b/tests/testNonlinearFactorGraph.cpp @@ -15,6 +15,7 @@ * @brief testNonlinearFactorGraph * @author Carlos Nieto * @author Christian Potthast + * @author Frank Dellaert */ #include @@ -106,6 +107,24 @@ TEST( NonlinearFactorGraph, probPrime ) DOUBLES_EQUAL(expected,actual,0); } +/* ************************************************************************* */ +TEST(NonlinearFactorGraph, ProbPrime2) { + NonlinearFactorGraph fg; + fg.emplace_shared>(1, 0.0, + noiseModel::Isotropic::Sigma(1, 1.0)); + + Values values; + values.insert(1, 1.0); + + // The prior factor squared error is: 0.5. + EXPECT_DOUBLES_EQUAL(0.5, fg.error(values), 1e-12); + + // The probability value is: exp^(-factor_error) / sqrt(2 * PI) + // Ignore the denominator and we get: exp^(-factor_error) = exp^(-0.5) + double expected = exp(-0.5); + EXPECT_DOUBLES_EQUAL(expected, fg.probPrime(values), 1e-12); +} + /* ************************************************************************* */ TEST( NonlinearFactorGraph, linearize ) { @@ -285,6 +304,7 @@ TEST(testNonlinearFactorGraph, addPrior) { EXPECT(0 != graph.error(values)); } +/* ************************************************************************* */ TEST(NonlinearFactorGraph, printErrors) { const NonlinearFactorGraph fg = createNonlinearFactorGraph(); @@ -309,6 +329,65 @@ TEST(NonlinearFactorGraph, printErrors) for (bool visit : visited) EXPECT(visit==true); } +/* ************************************************************************* */ +TEST(NonlinearFactorGraph, dot) { + string expected = + "graph {\n" + " size=\"5,5\";\n" + "\n" + " varl1[label=\"l1\"];\n" + " varx1[label=\"x1\"];\n" + " varx2[label=\"x2\"];\n" + "\n" + " factor0[label=\"\", shape=point];\n" + " varx1--factor0;\n" + " factor1[label=\"\", shape=point];\n" + " varx1--factor1;\n" + " varx2--factor1;\n" + " factor2[label=\"\", shape=point];\n" + " varx1--factor2;\n" + " varl1--factor2;\n" + " factor3[label=\"\", shape=point];\n" + " varx2--factor3;\n" + " varl1--factor3;\n" + "}\n"; + + const NonlinearFactorGraph fg = createNonlinearFactorGraph(); + string actual = fg.dot(); + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +TEST(NonlinearFactorGraph, dot_extra) { + string expected = + "graph {\n" + " size=\"5,5\";\n" + "\n" + " varl1[label=\"l1\", pos=\"0,0!\"];\n" + " varx1[label=\"x1\", pos=\"1,0!\"];\n" + " varx2[label=\"x2\", pos=\"1,1.5!\"];\n" + "\n" + " factor0[label=\"\", shape=point];\n" + " varx1--factor0;\n" + " factor1[label=\"\", shape=point];\n" + " varx1--factor1;\n" + " varx2--factor1;\n" + " factor2[label=\"\", shape=point];\n" + " varx1--factor2;\n" + " varl1--factor2;\n" + " factor3[label=\"\", shape=point];\n" + " varx2--factor3;\n" + " varl1--factor3;\n" + "}\n"; + + const NonlinearFactorGraph fg = createNonlinearFactorGraph(); + const Values c = createValues(); + + stringstream ss; + fg.dot(ss, c); + EXPECT(ss.str() == expected); +} + /* ************************************************************************* */ int main() { TestResult tr; return TestRegistry::runAllTests(tr); } /* ************************************************************************* */ diff --git a/tests/testSerializationSLAM.cpp b/tests/testSerializationSlam.cpp similarity index 85% rename from tests/testSerializationSLAM.cpp rename to tests/testSerializationSlam.cpp index 2e99aff71..ea7038635 100644 --- a/tests/testSerializationSLAM.cpp +++ b/tests/testSerializationSlam.cpp @@ -19,16 +19,16 @@ #include #include + +#include #include + #include #include -#include #include -#include +#include #include -#include -#include -#include + #include #include #include @@ -44,8 +44,16 @@ #include #include +#include +#include +#include +#include +#include #include +#include +#include + using namespace std; using namespace gtsam; using namespace gtsam::serializationTestHelpers; @@ -114,94 +122,94 @@ using symbol_shorthand::L; /* Create GUIDs for Noisemodels */ /* ************************************************************************* */ -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Constrained, "gtsam_noiseModel_Constrained"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Diagonal, "gtsam_noiseModel_Diagonal"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Gaussian, "gtsam_noiseModel_Gaussian"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Unit, "gtsam_noiseModel_Unit"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Isotropic, "gtsam_noiseModel_Isotropic"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Robust, "gtsam_noiseModel_Robust"); +BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Constrained, "gtsam_noiseModel_Constrained") +BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Diagonal, "gtsam_noiseModel_Diagonal") +BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Gaussian, "gtsam_noiseModel_Gaussian") +BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Unit, "gtsam_noiseModel_Unit") +BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Isotropic, "gtsam_noiseModel_Isotropic") +BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::Robust, "gtsam_noiseModel_Robust") -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::mEstimator::Base , "gtsam_noiseModel_mEstimator_Base"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::mEstimator::Null , "gtsam_noiseModel_mEstimator_Null"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::mEstimator::Fair , "gtsam_noiseModel_mEstimator_Fair"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::mEstimator::Huber, "gtsam_noiseModel_mEstimator_Huber"); -BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::mEstimator::Tukey, "gtsam_noiseModel_mEstimator_Tukey"); +BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::mEstimator::Base , "gtsam_noiseModel_mEstimator_Base") +BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::mEstimator::Null , "gtsam_noiseModel_mEstimator_Null") +BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::mEstimator::Fair , "gtsam_noiseModel_mEstimator_Fair") +BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::mEstimator::Huber, "gtsam_noiseModel_mEstimator_Huber") +BOOST_CLASS_EXPORT_GUID(gtsam::noiseModel::mEstimator::Tukey, "gtsam_noiseModel_mEstimator_Tukey") -BOOST_CLASS_EXPORT_GUID(gtsam::SharedNoiseModel, "gtsam_SharedNoiseModel"); -BOOST_CLASS_EXPORT_GUID(gtsam::SharedDiagonal, "gtsam_SharedDiagonal"); +BOOST_CLASS_EXPORT_GUID(gtsam::SharedNoiseModel, "gtsam_SharedNoiseModel") +BOOST_CLASS_EXPORT_GUID(gtsam::SharedDiagonal, "gtsam_SharedDiagonal") /* Create GUIDs for geometry */ /* ************************************************************************* */ -GTSAM_VALUE_EXPORT(gtsam::Point2); -GTSAM_VALUE_EXPORT(gtsam::StereoPoint2); -GTSAM_VALUE_EXPORT(gtsam::Point3); -GTSAM_VALUE_EXPORT(gtsam::Rot2); -GTSAM_VALUE_EXPORT(gtsam::Rot3); -GTSAM_VALUE_EXPORT(gtsam::Pose2); -GTSAM_VALUE_EXPORT(gtsam::Pose3); -GTSAM_VALUE_EXPORT(gtsam::Cal3_S2); -GTSAM_VALUE_EXPORT(gtsam::Cal3DS2); -GTSAM_VALUE_EXPORT(gtsam::Cal3_S2Stereo); -GTSAM_VALUE_EXPORT(gtsam::CalibratedCamera); -GTSAM_VALUE_EXPORT(gtsam::PinholeCameraCal3_S2); -GTSAM_VALUE_EXPORT(gtsam::StereoCamera); +GTSAM_VALUE_EXPORT(gtsam::Point2) +GTSAM_VALUE_EXPORT(gtsam::StereoPoint2) +GTSAM_VALUE_EXPORT(gtsam::Point3) +GTSAM_VALUE_EXPORT(gtsam::Rot2) +GTSAM_VALUE_EXPORT(gtsam::Rot3) +GTSAM_VALUE_EXPORT(gtsam::Pose2) +GTSAM_VALUE_EXPORT(gtsam::Pose3) +GTSAM_VALUE_EXPORT(gtsam::Cal3_S2) +GTSAM_VALUE_EXPORT(gtsam::Cal3DS2) +GTSAM_VALUE_EXPORT(gtsam::Cal3_S2Stereo) +GTSAM_VALUE_EXPORT(gtsam::CalibratedCamera) +GTSAM_VALUE_EXPORT(gtsam::PinholeCameraCal3_S2) +GTSAM_VALUE_EXPORT(gtsam::StereoCamera) /* Create GUIDs for factors */ /* ************************************************************************* */ -BOOST_CLASS_EXPORT_GUID(gtsam::JacobianFactor, "gtsam::JacobianFactor"); -BOOST_CLASS_EXPORT_GUID(gtsam::HessianFactor , "gtsam::HessianFactor"); +BOOST_CLASS_EXPORT_GUID(gtsam::JacobianFactor, "gtsam::JacobianFactor") +BOOST_CLASS_EXPORT_GUID(gtsam::HessianFactor , "gtsam::HessianFactor") -BOOST_CLASS_EXPORT_GUID(PriorFactorPoint2, "gtsam::PriorFactorPoint2"); -BOOST_CLASS_EXPORT_GUID(PriorFactorStereoPoint2, "gtsam::PriorFactorStereoPoint2"); -BOOST_CLASS_EXPORT_GUID(PriorFactorPoint3, "gtsam::PriorFactorPoint3"); -BOOST_CLASS_EXPORT_GUID(PriorFactorRot2, "gtsam::PriorFactorRot2"); -BOOST_CLASS_EXPORT_GUID(PriorFactorRot3, "gtsam::PriorFactorRot3"); -BOOST_CLASS_EXPORT_GUID(PriorFactorPose2, "gtsam::PriorFactorPose2"); -BOOST_CLASS_EXPORT_GUID(PriorFactorPose3, "gtsam::PriorFactorPose3"); -BOOST_CLASS_EXPORT_GUID(PriorFactorCal3_S2, "gtsam::PriorFactorCal3_S2"); -BOOST_CLASS_EXPORT_GUID(PriorFactorCal3DS2, "gtsam::PriorFactorCal3DS2"); -BOOST_CLASS_EXPORT_GUID(PriorFactorCalibratedCamera, "gtsam::PriorFactorCalibratedCamera"); -BOOST_CLASS_EXPORT_GUID(PriorFactorStereoCamera, "gtsam::PriorFactorStereoCamera"); +BOOST_CLASS_EXPORT_GUID(PriorFactorPoint2, "gtsam::PriorFactorPoint2") +BOOST_CLASS_EXPORT_GUID(PriorFactorStereoPoint2, "gtsam::PriorFactorStereoPoint2") +BOOST_CLASS_EXPORT_GUID(PriorFactorPoint3, "gtsam::PriorFactorPoint3") +BOOST_CLASS_EXPORT_GUID(PriorFactorRot2, "gtsam::PriorFactorRot2") +BOOST_CLASS_EXPORT_GUID(PriorFactorRot3, "gtsam::PriorFactorRot3") +BOOST_CLASS_EXPORT_GUID(PriorFactorPose2, "gtsam::PriorFactorPose2") +BOOST_CLASS_EXPORT_GUID(PriorFactorPose3, "gtsam::PriorFactorPose3") +BOOST_CLASS_EXPORT_GUID(PriorFactorCal3_S2, "gtsam::PriorFactorCal3_S2") +BOOST_CLASS_EXPORT_GUID(PriorFactorCal3DS2, "gtsam::PriorFactorCal3DS2") +BOOST_CLASS_EXPORT_GUID(PriorFactorCalibratedCamera, "gtsam::PriorFactorCalibratedCamera") +BOOST_CLASS_EXPORT_GUID(PriorFactorStereoCamera, "gtsam::PriorFactorStereoCamera") -BOOST_CLASS_EXPORT_GUID(BetweenFactorPoint2, "gtsam::BetweenFactorPoint2"); -BOOST_CLASS_EXPORT_GUID(BetweenFactorPoint3, "gtsam::BetweenFactorPoint3"); -BOOST_CLASS_EXPORT_GUID(BetweenFactorRot2, "gtsam::BetweenFactorRot2"); -BOOST_CLASS_EXPORT_GUID(BetweenFactorRot3, "gtsam::BetweenFactorRot3"); -BOOST_CLASS_EXPORT_GUID(BetweenFactorPose2, "gtsam::BetweenFactorPose2"); -BOOST_CLASS_EXPORT_GUID(BetweenFactorPose3, "gtsam::BetweenFactorPose3"); +BOOST_CLASS_EXPORT_GUID(BetweenFactorPoint2, "gtsam::BetweenFactorPoint2") +BOOST_CLASS_EXPORT_GUID(BetweenFactorPoint3, "gtsam::BetweenFactorPoint3") +BOOST_CLASS_EXPORT_GUID(BetweenFactorRot2, "gtsam::BetweenFactorRot2") +BOOST_CLASS_EXPORT_GUID(BetweenFactorRot3, "gtsam::BetweenFactorRot3") +BOOST_CLASS_EXPORT_GUID(BetweenFactorPose2, "gtsam::BetweenFactorPose2") +BOOST_CLASS_EXPORT_GUID(BetweenFactorPose3, "gtsam::BetweenFactorPose3") -BOOST_CLASS_EXPORT_GUID(NonlinearEqualityPoint2, "gtsam::NonlinearEqualityPoint2"); -BOOST_CLASS_EXPORT_GUID(NonlinearEqualityStereoPoint2, "gtsam::NonlinearEqualityStereoPoint2"); -BOOST_CLASS_EXPORT_GUID(NonlinearEqualityPoint3, "gtsam::NonlinearEqualityPoint3"); -BOOST_CLASS_EXPORT_GUID(NonlinearEqualityRot2, "gtsam::NonlinearEqualityRot2"); -BOOST_CLASS_EXPORT_GUID(NonlinearEqualityRot3, "gtsam::NonlinearEqualityRot3"); -BOOST_CLASS_EXPORT_GUID(NonlinearEqualityPose2, "gtsam::NonlinearEqualityPose2"); -BOOST_CLASS_EXPORT_GUID(NonlinearEqualityPose3, "gtsam::NonlinearEqualityPose3"); -BOOST_CLASS_EXPORT_GUID(NonlinearEqualityCal3_S2, "gtsam::NonlinearEqualityCal3_S2"); -BOOST_CLASS_EXPORT_GUID(NonlinearEqualityCal3DS2, "gtsam::NonlinearEqualityCal3DS2"); -BOOST_CLASS_EXPORT_GUID(NonlinearEqualityCalibratedCamera, "gtsam::NonlinearEqualityCalibratedCamera"); -BOOST_CLASS_EXPORT_GUID(NonlinearEqualityStereoCamera, "gtsam::NonlinearEqualityStereoCamera"); +BOOST_CLASS_EXPORT_GUID(NonlinearEqualityPoint2, "gtsam::NonlinearEqualityPoint2") +BOOST_CLASS_EXPORT_GUID(NonlinearEqualityStereoPoint2, "gtsam::NonlinearEqualityStereoPoint2") +BOOST_CLASS_EXPORT_GUID(NonlinearEqualityPoint3, "gtsam::NonlinearEqualityPoint3") +BOOST_CLASS_EXPORT_GUID(NonlinearEqualityRot2, "gtsam::NonlinearEqualityRot2") +BOOST_CLASS_EXPORT_GUID(NonlinearEqualityRot3, "gtsam::NonlinearEqualityRot3") +BOOST_CLASS_EXPORT_GUID(NonlinearEqualityPose2, "gtsam::NonlinearEqualityPose2") +BOOST_CLASS_EXPORT_GUID(NonlinearEqualityPose3, "gtsam::NonlinearEqualityPose3") +BOOST_CLASS_EXPORT_GUID(NonlinearEqualityCal3_S2, "gtsam::NonlinearEqualityCal3_S2") +BOOST_CLASS_EXPORT_GUID(NonlinearEqualityCal3DS2, "gtsam::NonlinearEqualityCal3DS2") +BOOST_CLASS_EXPORT_GUID(NonlinearEqualityCalibratedCamera, "gtsam::NonlinearEqualityCalibratedCamera") +BOOST_CLASS_EXPORT_GUID(NonlinearEqualityStereoCamera, "gtsam::NonlinearEqualityStereoCamera") -BOOST_CLASS_EXPORT_GUID(RangeFactor2D, "gtsam::RangeFactor2D"); -BOOST_CLASS_EXPORT_GUID(RangeFactor3D, "gtsam::RangeFactor3D"); -BOOST_CLASS_EXPORT_GUID(RangeFactorPose2, "gtsam::RangeFactorPose2"); -BOOST_CLASS_EXPORT_GUID(RangeFactorPose3, "gtsam::RangeFactorPose3"); -BOOST_CLASS_EXPORT_GUID(RangeFactorCalibratedCameraPoint, "gtsam::RangeFactorCalibratedCameraPoint"); -BOOST_CLASS_EXPORT_GUID(RangeFactorPinholeCameraCal3_S2Point, "gtsam::RangeFactorPinholeCameraCal3_S2Point"); -BOOST_CLASS_EXPORT_GUID(RangeFactorCalibratedCamera, "gtsam::RangeFactorCalibratedCamera"); -BOOST_CLASS_EXPORT_GUID(RangeFactorPinholeCameraCal3_S2, "gtsam::RangeFactorPinholeCameraCal3_S2"); +BOOST_CLASS_EXPORT_GUID(RangeFactor2D, "gtsam::RangeFactor2D") +BOOST_CLASS_EXPORT_GUID(RangeFactor3D, "gtsam::RangeFactor3D") +BOOST_CLASS_EXPORT_GUID(RangeFactorPose2, "gtsam::RangeFactorPose2") +BOOST_CLASS_EXPORT_GUID(RangeFactorPose3, "gtsam::RangeFactorPose3") +BOOST_CLASS_EXPORT_GUID(RangeFactorCalibratedCameraPoint, "gtsam::RangeFactorCalibratedCameraPoint") +BOOST_CLASS_EXPORT_GUID(RangeFactorPinholeCameraCal3_S2Point, "gtsam::RangeFactorPinholeCameraCal3_S2Point") +BOOST_CLASS_EXPORT_GUID(RangeFactorCalibratedCamera, "gtsam::RangeFactorCalibratedCamera") +BOOST_CLASS_EXPORT_GUID(RangeFactorPinholeCameraCal3_S2, "gtsam::RangeFactorPinholeCameraCal3_S2") -BOOST_CLASS_EXPORT_GUID(BearingRangeFactor2D, "gtsam::BearingRangeFactor2D"); +BOOST_CLASS_EXPORT_GUID(BearingRangeFactor2D, "gtsam::BearingRangeFactor2D") -BOOST_CLASS_EXPORT_GUID(GenericProjectionFactorCal3_S2, "gtsam::GenericProjectionFactorCal3_S2"); -BOOST_CLASS_EXPORT_GUID(GenericProjectionFactorCal3DS2, "gtsam::GenericProjectionFactorCal3DS2"); +BOOST_CLASS_EXPORT_GUID(GenericProjectionFactorCal3_S2, "gtsam::GenericProjectionFactorCal3_S2") +BOOST_CLASS_EXPORT_GUID(GenericProjectionFactorCal3DS2, "gtsam::GenericProjectionFactorCal3DS2") -BOOST_CLASS_EXPORT_GUID(GeneralSFMFactorCal3_S2, "gtsam::GeneralSFMFactorCal3_S2"); -BOOST_CLASS_EXPORT_GUID(GeneralSFMFactorCal3DS2, "gtsam::GeneralSFMFactorCal3DS2"); +BOOST_CLASS_EXPORT_GUID(GeneralSFMFactorCal3_S2, "gtsam::GeneralSFMFactorCal3_S2") +BOOST_CLASS_EXPORT_GUID(GeneralSFMFactorCal3DS2, "gtsam::GeneralSFMFactorCal3DS2") -BOOST_CLASS_EXPORT_GUID(GeneralSFMFactor2Cal3_S2, "gtsam::GeneralSFMFactor2Cal3_S2"); +BOOST_CLASS_EXPORT_GUID(GeneralSFMFactor2Cal3_S2, "gtsam::GeneralSFMFactor2Cal3_S2") -BOOST_CLASS_EXPORT_GUID(GenericStereoFactor3D, "gtsam::GenericStereoFactor3D"); +BOOST_CLASS_EXPORT_GUID(GenericStereoFactor3D, "gtsam::GenericStereoFactor3D") /* ************************************************************************* */ @@ -592,6 +600,78 @@ TEST (testSerializationSLAM, factors) { EXPECT(equalsBinary(genericStereoFactor3D)); } +/* ************************************************************************* */ +// Read from XML file +namespace { +static GaussianFactorGraph read(const string& name) { + auto inputFile = findExampleDataFile(name); + ifstream is(inputFile); + if (!is.is_open()) throw runtime_error("Cannot find file " + inputFile); + boost::archive::xml_iarchive in_archive(is); + GaussianFactorGraph Ab; + in_archive >> boost::serialization::make_nvp("graph", Ab); + return Ab; +} +} // namespace + +/* ************************************************************************* */ +// Read from XML file +TEST(SubgraphSolver, Solves) { + using gtsam::example::planarGraph; + + // Create preconditioner + SubgraphPreconditioner system; + + // We test on three different graphs + const auto Ab1 = planarGraph(3).first; + const auto Ab2 = read("toy3D"); + const auto Ab3 = read("randomGrid3D"); + + // For all graphs, test solve and solveTranspose + for (const auto& Ab : {Ab1, Ab2, Ab3}) { + // Call build, a non-const method needed to make solve work :-( + KeyInfo keyInfo(Ab); + std::map lambda; + system.build(Ab, keyInfo, lambda); + + // Create a perturbed (non-zero) RHS + const auto xbar = system.Rc1().optimize(); // merely for use in zero below + auto values_y = VectorValues::Zero(xbar); + auto it = values_y.begin(); + it->second.setConstant(100); + ++it; + it->second.setConstant(-100); + + // Solve the VectorValues way + auto values_x = system.Rc1().backSubstitute(values_y); + + // Solve the matrix way, this really just checks BN::backSubstitute + // This only works with Rc1 ordering, not with keyInfo ! + // TODO(frank): why does this not work with an arbitrary ordering? + const auto ord = system.Rc1().ordering(); + const Matrix R1 = system.Rc1().matrix(ord).first; + auto ord_y = values_y.vector(ord); + auto vector_x = R1.inverse() * ord_y; + EXPECT(assert_equal(vector_x, values_x.vector(ord))); + + // Test that 'solve' does implement x = R^{-1} y + // We do this by asserting it gives same answer as backSubstitute + // Only works with keyInfo ordering: + const auto ordering = keyInfo.ordering(); + auto vector_y = values_y.vector(ordering); + const size_t N = R1.cols(); + Vector solve_x = Vector::Zero(N); + system.solve(vector_y, solve_x); + EXPECT(assert_equal(values_x.vector(ordering), solve_x)); + + // Test that transposeSolve does implement x = R^{-T} y + // We do this by asserting it gives same answer as backSubstituteTranspose + auto values_x2 = system.Rc1().backSubstituteTranspose(values_y); + Vector solveT_x = Vector::Zero(N); + system.transposeSolve(vector_y, solveT_x); + EXPECT(assert_equal(values_x2.vector(ordering), solveT_x)); + } +} /* ************************************************************************* */ int main() { TestResult tr; return TestRegistry::runAllTests(tr); } diff --git a/tests/testSubgraphPreconditioner.cpp b/tests/testSubgraphPreconditioner.cpp index fb9f7a5a2..c5b4e42ec 100644 --- a/tests/testSubgraphPreconditioner.cpp +++ b/tests/testSubgraphPreconditioner.cpp @@ -29,10 +29,8 @@ #include -#include #include #include -#include #include using namespace boost::assign; @@ -77,8 +75,8 @@ TEST(SubgraphPreconditioner, planarGraph) { DOUBLES_EQUAL(0, error(A, xtrue), 1e-9); // check zero error for xtrue // Check that xtrue is optimal - GaussianBayesNet::shared_ptr R1 = A.eliminateSequential(); - VectorValues actual = R1->optimize(); + GaussianBayesNet R1 = *A.eliminateSequential(); + VectorValues actual = R1.optimize(); EXPECT(assert_equal(xtrue, actual)); } @@ -90,14 +88,14 @@ TEST(SubgraphPreconditioner, splitOffPlanarTree) { boost::tie(A, xtrue) = planarGraph(3); // Get the spanning tree and constraints, and check their sizes - GaussianFactorGraph::shared_ptr T, C; + GaussianFactorGraph T, C; boost::tie(T, C) = splitOffPlanarTree(3, A); - LONGS_EQUAL(9, T->size()); - LONGS_EQUAL(4, C->size()); + LONGS_EQUAL(9, T.size()); + LONGS_EQUAL(4, C.size()); // Check that the tree can be solved to give the ground xtrue - GaussianBayesNet::shared_ptr R1 = T->eliminateSequential(); - VectorValues xbar = R1->optimize(); + GaussianBayesNet R1 = *T.eliminateSequential(); + VectorValues xbar = R1.optimize(); EXPECT(assert_equal(xtrue, xbar)); } @@ -110,31 +108,29 @@ TEST(SubgraphPreconditioner, system) { boost::tie(Ab, xtrue) = planarGraph(N); // A*x-b // Get the spanning tree and remaining graph - GaussianFactorGraph::shared_ptr Ab1, Ab2; // A1*x-b1 and A2*x-b2 + GaussianFactorGraph Ab1, Ab2; // A1*x-b1 and A2*x-b2 boost::tie(Ab1, Ab2) = splitOffPlanarTree(N, Ab); // Eliminate the spanning tree to build a prior const Ordering ord = planarOrdering(N); - auto Rc1 = Ab1->eliminateSequential(ord); // R1*x-c1 - VectorValues xbar = Rc1->optimize(); // xbar = inv(R1)*c1 + auto Rc1 = *Ab1.eliminateSequential(ord); // R1*x-c1 + VectorValues xbar = Rc1.optimize(); // xbar = inv(R1)*c1 // Create Subgraph-preconditioned system - VectorValues::shared_ptr xbarShared( - new VectorValues(xbar)); // TODO: horrible - const SubgraphPreconditioner system(Ab2, Rc1, xbarShared); + const SubgraphPreconditioner system(Ab2, Rc1, xbar); // Get corresponding matrices for tests. Add dummy factors to Ab2 to make // sure it works with the ordering. - Ordering ordering = Rc1->ordering(); // not ord in general! - Ab2->add(key(1, 1), Z_2x2, Z_2x1); - Ab2->add(key(1, 2), Z_2x2, Z_2x1); - Ab2->add(key(1, 3), Z_2x2, Z_2x1); + Ordering ordering = Rc1.ordering(); // not ord in general! + Ab2.add(key(1, 1), Z_2x2, Z_2x1); + Ab2.add(key(1, 2), Z_2x2, Z_2x1); + Ab2.add(key(1, 3), Z_2x2, Z_2x1); Matrix A, A1, A2; Vector b, b1, b2; std::tie(A, b) = Ab.jacobian(ordering); - std::tie(A1, b1) = Ab1->jacobian(ordering); - std::tie(A2, b2) = Ab2->jacobian(ordering); - Matrix R1 = Rc1->matrix(ordering).first; + std::tie(A1, b1) = Ab1.jacobian(ordering); + std::tie(A2, b2) = Ab2.jacobian(ordering); + Matrix R1 = Rc1.matrix(ordering).first; Matrix Abar(13 * 2, 9 * 2); Abar.topRows(9 * 2) = Matrix::Identity(9 * 2, 9 * 2); Abar.bottomRows(8) = A2.topRows(8) * R1.inverse(); @@ -151,7 +147,7 @@ TEST(SubgraphPreconditioner, system) { y1[key(3, 3)] = Vector2(1.0, -1.0); // Check backSubstituteTranspose works with R1 - VectorValues actual = Rc1->backSubstituteTranspose(y1); + VectorValues actual = Rc1.backSubstituteTranspose(y1); Vector expected = R1.transpose().inverse() * vec(y1); EXPECT(assert_equal(expected, vec(actual))); @@ -199,75 +195,6 @@ TEST(SubgraphPreconditioner, system) { EXPECT(assert_equal(expected_g, vec(g))); } -/* ************************************************************************* */ -BOOST_CLASS_EXPORT_GUID(gtsam::JacobianFactor, "JacobianFactor"); - -// Read from XML file -static GaussianFactorGraph read(const string& name) { - auto inputFile = findExampleDataFile(name); - ifstream is(inputFile); - if (!is.is_open()) throw runtime_error("Cannot find file " + inputFile); - boost::archive::xml_iarchive in_archive(is); - GaussianFactorGraph Ab; - in_archive >> boost::serialization::make_nvp("graph", Ab); - return Ab; -} - -TEST(SubgraphSolver, Solves) { - // Create preconditioner - SubgraphPreconditioner system; - - // We test on three different graphs - const auto Ab1 = planarGraph(3).first; - const auto Ab2 = read("toy3D"); - const auto Ab3 = read("randomGrid3D"); - - // For all graphs, test solve and solveTranspose - for (const auto& Ab : {Ab1, Ab2, Ab3}) { - // Call build, a non-const method needed to make solve work :-( - KeyInfo keyInfo(Ab); - std::map lambda; - system.build(Ab, keyInfo, lambda); - - // Create a perturbed (non-zero) RHS - const auto xbar = system.Rc1()->optimize(); // merely for use in zero below - auto values_y = VectorValues::Zero(xbar); - auto it = values_y.begin(); - it->second.setConstant(100); - ++it; - it->second.setConstant(-100); - - // Solve the VectorValues way - auto values_x = system.Rc1()->backSubstitute(values_y); - - // Solve the matrix way, this really just checks BN::backSubstitute - // This only works with Rc1 ordering, not with keyInfo ! - // TODO(frank): why does this not work with an arbitrary ordering? - const auto ord = system.Rc1()->ordering(); - const Matrix R1 = system.Rc1()->matrix(ord).first; - auto ord_y = values_y.vector(ord); - auto vector_x = R1.inverse() * ord_y; - EXPECT(assert_equal(vector_x, values_x.vector(ord))); - - // Test that 'solve' does implement x = R^{-1} y - // We do this by asserting it gives same answer as backSubstitute - // Only works with keyInfo ordering: - const auto ordering = keyInfo.ordering(); - auto vector_y = values_y.vector(ordering); - const size_t N = R1.cols(); - Vector solve_x = Vector::Zero(N); - system.solve(vector_y, solve_x); - EXPECT(assert_equal(values_x.vector(ordering), solve_x)); - - // Test that transposeSolve does implement x = R^{-T} y - // We do this by asserting it gives same answer as backSubstituteTranspose - auto values_x2 = system.Rc1()->backSubstituteTranspose(values_y); - Vector solveT_x = Vector::Zero(N); - system.transposeSolve(vector_y, solveT_x); - EXPECT(assert_equal(values_x2.vector(ordering), solveT_x)); - } -} - /* ************************************************************************* */ TEST(SubgraphPreconditioner, conjugateGradients) { // Build a planar graph @@ -277,18 +204,15 @@ TEST(SubgraphPreconditioner, conjugateGradients) { boost::tie(Ab, xtrue) = planarGraph(N); // A*x-b // Get the spanning tree - GaussianFactorGraph::shared_ptr Ab1, Ab2; // A1*x-b1 and A2*x-b2 + GaussianFactorGraph Ab1, Ab2; // A1*x-b1 and A2*x-b2 boost::tie(Ab1, Ab2) = splitOffPlanarTree(N, Ab); // Eliminate the spanning tree to build a prior - SubgraphPreconditioner::sharedBayesNet Rc1 = - Ab1->eliminateSequential(); // R1*x-c1 - VectorValues xbar = Rc1->optimize(); // xbar = inv(R1)*c1 + GaussianBayesNet Rc1 = *Ab1.eliminateSequential(); // R1*x-c1 + VectorValues xbar = Rc1.optimize(); // xbar = inv(R1)*c1 // Create Subgraph-preconditioned system - VectorValues::shared_ptr xbarShared( - new VectorValues(xbar)); // TODO: horrible - SubgraphPreconditioner system(Ab2, Rc1, xbarShared); + SubgraphPreconditioner system(Ab2, Rc1, xbar); // Create zero config y0 and perturbed config y1 VectorValues y0 = VectorValues::Zero(xbar); diff --git a/tests/testSubgraphSolver.cpp b/tests/testSubgraphSolver.cpp index cca13c822..5d8d88775 100644 --- a/tests/testSubgraphSolver.cpp +++ b/tests/testSubgraphSolver.cpp @@ -68,10 +68,10 @@ TEST( SubgraphSolver, splitFactorGraph ) auto subgraph = builder(Ab); EXPECT_LONGS_EQUAL(9, subgraph.size()); - GaussianFactorGraph::shared_ptr Ab1, Ab2; + GaussianFactorGraph Ab1, Ab2; std::tie(Ab1, Ab2) = splitFactorGraph(Ab, subgraph); - EXPECT_LONGS_EQUAL(9, Ab1->size()); - EXPECT_LONGS_EQUAL(13, Ab2->size()); + EXPECT_LONGS_EQUAL(9, Ab1.size()); + EXPECT_LONGS_EQUAL(13, Ab2.size()); } /* ************************************************************************* */ @@ -99,12 +99,12 @@ TEST( SubgraphSolver, constructor2 ) std::tie(Ab, xtrue) = example::planarGraph(N); // A*x-b // Get the spanning tree - GaussianFactorGraph::shared_ptr Ab1, Ab2; // A1*x-b1 and A2*x-b2 + GaussianFactorGraph Ab1, Ab2; // A1*x-b1 and A2*x-b2 std::tie(Ab1, Ab2) = example::splitOffPlanarTree(N, Ab); // The second constructor takes two factor graphs, so the caller can specify // the preconditioner (Ab1) and the constraints that are left out (Ab2) - SubgraphSolver solver(*Ab1, Ab2, kParameters, kOrdering); + SubgraphSolver solver(Ab1, Ab2, kParameters, kOrdering); VectorValues optimized = solver.optimize(); DOUBLES_EQUAL(0.0, error(Ab, optimized), 1e-5); } @@ -119,11 +119,11 @@ TEST( SubgraphSolver, constructor3 ) std::tie(Ab, xtrue) = example::planarGraph(N); // A*x-b // Get the spanning tree and corresponding kOrdering - GaussianFactorGraph::shared_ptr Ab1, Ab2; // A1*x-b1 and A2*x-b2 + GaussianFactorGraph Ab1, Ab2; // A1*x-b1 and A2*x-b2 std::tie(Ab1, Ab2) = example::splitOffPlanarTree(N, Ab); // The caller solves |A1*x-b1|^2 == |R1*x-c1|^2, where R1 is square UT - auto Rc1 = Ab1->eliminateSequential(); + auto Rc1 = *Ab1.eliminateSequential(); // The third constructor allows the caller to pass an already solved preconditioner Rc1_ // as a Bayes net, in addition to the "loop closing constraints" Ab2, as before diff --git a/tests/testTranslationRecovery.cpp b/tests/testTranslationRecovery.cpp index 2915a375e..833f11355 100644 --- a/tests/testTranslationRecovery.cpp +++ b/tests/testTranslationRecovery.cpp @@ -18,6 +18,7 @@ #include #include +#include #include using namespace std; @@ -42,9 +43,7 @@ Unit3 GetDirectionFromPoses(const Values& poses, // sets up an optimization problem for the three unknown translations. TEST(TranslationRecovery, BAL) { const string filename = findExampleDataFile("dubrovnik-3-7-pre"); - SfmData db; - bool success = readBAL(filename, db); - if (!success) throw runtime_error("Could not access file!"); + SfmData db = SfmData::FromBalFile(filename); // Get camera poses, as Values size_t j = 0; @@ -116,8 +115,8 @@ TEST(TranslationRecovery, TwoPoseTest) { const auto result = algorithm.run(/*scale=*/3.0); // Check result for first two translations, determined by prior - EXPECT(assert_equal(Point3(0, 0, 0), result.at(0))); - EXPECT(assert_equal(Point3(3, 0, 0), result.at(1))); + EXPECT(assert_equal(Point3(0, 0, 0), result.at(0), 1e-8)); + EXPECT(assert_equal(Point3(3, 0, 0), result.at(1), 1e-8)); } TEST(TranslationRecovery, ThreePoseTest) { @@ -153,9 +152,9 @@ TEST(TranslationRecovery, ThreePoseTest) { const auto result = algorithm.run(/*scale=*/3.0); // Check result - EXPECT(assert_equal(Point3(0, 0, 0), result.at(0))); - EXPECT(assert_equal(Point3(3, 0, 0), result.at(1))); - EXPECT(assert_equal(Point3(1.5, -1.5, 0), result.at(3))); + EXPECT(assert_equal(Point3(0, 0, 0), result.at(0), 1e-8)); + EXPECT(assert_equal(Point3(3, 0, 0), result.at(1), 1e-8)); + EXPECT(assert_equal(Point3(1.5, -1.5, 0), result.at(3), 1e-8)); } TEST(TranslationRecovery, ThreePosesIncludingZeroTranslation) { @@ -190,9 +189,9 @@ TEST(TranslationRecovery, ThreePosesIncludingZeroTranslation) { const auto result = algorithm.run(/*scale=*/3.0); // Check result - EXPECT(assert_equal(Point3(0, 0, 0), result.at(0))); - EXPECT(assert_equal(Point3(3, 0, 0), result.at(1))); - EXPECT(assert_equal(Point3(3, 0, 0), result.at(2))); + EXPECT(assert_equal(Point3(0, 0, 0), result.at(0), 1e-8)); + EXPECT(assert_equal(Point3(3, 0, 0), result.at(1), 1e-8)); + EXPECT(assert_equal(Point3(3, 0, 0), result.at(2), 1e-8)); } TEST(TranslationRecovery, FourPosesIncludingZeroTranslation) { @@ -231,10 +230,10 @@ TEST(TranslationRecovery, FourPosesIncludingZeroTranslation) { const auto result = algorithm.run(/*scale=*/4.0); // Check result - EXPECT(assert_equal(Point3(0, 0, 0), result.at(0))); - EXPECT(assert_equal(Point3(4, 0, 0), result.at(1))); - EXPECT(assert_equal(Point3(4, 0, 0), result.at(2))); - EXPECT(assert_equal(Point3(2, -2, 0), result.at(3))); + EXPECT(assert_equal(Point3(0, 0, 0), result.at(0), 1e-8)); + EXPECT(assert_equal(Point3(4, 0, 0), result.at(1), 1e-8)); + EXPECT(assert_equal(Point3(4, 0, 0), result.at(2), 1e-8)); + EXPECT(assert_equal(Point3(2, -2, 0), result.at(3), 1e-8)); } TEST(TranslationRecovery, ThreePosesWithZeroTranslation) { @@ -261,9 +260,9 @@ TEST(TranslationRecovery, ThreePosesWithZeroTranslation) { const auto result = algorithm.run(/*scale=*/4.0); // Check result - EXPECT(assert_equal(Point3(0, 0, 0), result.at(0))); - EXPECT(assert_equal(Point3(0, 0, 0), result.at(1))); - EXPECT(assert_equal(Point3(0, 0, 0), result.at(2))); + EXPECT(assert_equal(Point3(0, 0, 0), result.at(0), 1e-8)); + EXPECT(assert_equal(Point3(0, 0, 0), result.at(1), 1e-8)); + EXPECT(assert_equal(Point3(0, 0, 0), result.at(2), 1e-8)); } /* ************************************************************************* */ diff --git a/timing/timeBatch.cpp b/timing/timeBatch.cpp index 4ed1a4555..f59039fa7 100644 --- a/timing/timeBatch.cpp +++ b/timing/timeBatch.cpp @@ -28,7 +28,7 @@ int main(int argc, char *argv[]) { cout << "Loading data..." << endl; - string datasetFile = findExampleDataFile("w10000-odom"); + string datasetFile = findExampleDataFile("w10000"); std::pair data = load2D(datasetFile); diff --git a/timing/timeIncremental.cpp b/timing/timeIncremental.cpp index 6e0f4ccdf..5e3fc9189 100644 --- a/timing/timeIncremental.cpp +++ b/timing/timeIncremental.cpp @@ -72,7 +72,7 @@ int main(int argc, char *argv[]) { cout << "Loading data..." << endl; gttic_(Find_datafile); - //string datasetFile = findExampleDataFile("w10000-odom"); + //string datasetFile = findExampleDataFile("w10000"); string datasetFile = findExampleDataFile("victoria_park"); std::pair data = load2D(datasetFile); diff --git a/timing/timeSFMBAL.cpp b/timing/timeSFMBAL.cpp index 4a58a57a6..c1f36abd0 100644 --- a/timing/timeSFMBAL.cpp +++ b/timing/timeSFMBAL.cpp @@ -36,7 +36,7 @@ int main(int argc, char* argv[]) { // Build graph using conventional GeneralSFMFactor NonlinearFactorGraph graph; - for (size_t j = 0; j < db.number_tracks(); j++) { + for (size_t j = 0; j < db.numberTracks(); j++) { for (const SfmMeasurement& m: db.tracks[j].measurements) { size_t i = m.first; Point2 z = m.second; diff --git a/timing/timeSFMBAL.h b/timing/timeSFMBAL.h index 548c4de70..7af798887 100644 --- a/timing/timeSFMBAL.h +++ b/timing/timeSFMBAL.h @@ -16,6 +16,9 @@ * @date July 5, 2015 */ +#pragma once + +#include #include #include #include @@ -54,9 +57,7 @@ SfmData preamble(int argc, char* argv[]) { filename = argv[argc - 1]; else filename = findExampleDataFile("dubrovnik-16-22106-pre"); - bool success = readBAL(filename, db); - if (!success) throw runtime_error("Could not access file!"); - return db; + return SfmData::FromBalFile(filename); } // Create ordering and optimize @@ -73,8 +74,8 @@ int optimize(const SfmData& db, const NonlinearFactorGraph& graph, if (gUseSchur) { // Create Schur-complement ordering Ordering ordering; - for (size_t j = 0; j < db.number_tracks(); j++) ordering.push_back(P(j)); - for (size_t i = 0; i < db.number_cameras(); i++) { + for (size_t j = 0; j < db.numberTracks(); j++) ordering.push_back(P(j)); + for (size_t i = 0; i < db.numberCameras(); i++) { ordering.push_back(C(i)); if (separateCalibration) ordering.push_back(K(i)); } diff --git a/timing/timeSFMBALautodiff.cpp b/timing/timeSFMBALautodiff.cpp index 2d0f4a1fe..1a7e35929 100644 --- a/timing/timeSFMBALautodiff.cpp +++ b/timing/timeSFMBALautodiff.cpp @@ -44,7 +44,7 @@ int main(int argc, char* argv[]) { // Build graph NonlinearFactorGraph graph; - for (size_t j = 0; j < db.number_tracks(); j++) { + for (size_t j = 0; j < db.numberTracks(); j++) { for (const SfmMeasurement& m: db.tracks[j].measurements) { size_t i = m.first; Point2 z = m.second; @@ -59,7 +59,7 @@ int main(int argc, char* argv[]) { Values initial; size_t i = 0, j = 0; for (const SfmCamera& camera: db.cameras) { - // readBAL converts to GTSAM format, so we need to convert back ! + // SfmData::FromBalFile converts to GTSAM format, so we need to convert back ! Pose3 openGLpose = gtsam2openGL(camera.pose()); Vector9 v9; v9 << Pose3::Logmap(openGLpose), camera.calibration(); diff --git a/timing/timeSFMBALcamTnav.cpp b/timing/timeSFMBALcamTnav.cpp index 355defed9..a564a3a35 100644 --- a/timing/timeSFMBALcamTnav.cpp +++ b/timing/timeSFMBALcamTnav.cpp @@ -33,7 +33,7 @@ int main(int argc, char* argv[]) { // Build graph using conventional GeneralSFMFactor NonlinearFactorGraph graph; - for (size_t j = 0; j < db.number_tracks(); j++) { + for (size_t j = 0; j < db.numberTracks(); j++) { for (const SfmMeasurement& m: db.tracks[j].measurements) { size_t i = m.first; Point2 z = m.second; diff --git a/timing/timeSFMBALnavTcam.cpp b/timing/timeSFMBALnavTcam.cpp index e602ef241..5299c8552 100644 --- a/timing/timeSFMBALnavTcam.cpp +++ b/timing/timeSFMBALnavTcam.cpp @@ -33,7 +33,7 @@ int main(int argc, char* argv[]) { // Build graph using conventional GeneralSFMFactor NonlinearFactorGraph graph; - for (size_t j = 0; j < db.number_tracks(); j++) { + for (size_t j = 0; j < db.numberTracks(); j++) { Point3_ nav_point_(P(j)); for (const SfmMeasurement& m: db.tracks[j].measurements) { size_t i = m.first; diff --git a/timing/timeSFMBALsmart.cpp b/timing/timeSFMBALsmart.cpp index a69d895a5..fe2f7b925 100644 --- a/timing/timeSFMBALsmart.cpp +++ b/timing/timeSFMBALsmart.cpp @@ -35,7 +35,7 @@ int main(int argc, char* argv[]) { // Add smart factors to graph NonlinearFactorGraph graph; - for (size_t j = 0; j < db.number_tracks(); j++) { + for (size_t j = 0; j < db.numberTracks(); j++) { auto smartFactor = boost::make_shared(gNoiseModel); for (const SfmMeasurement& m : db.tracks[j].measurements) { size_t i = m.first; diff --git a/wrap/.github/workflows/macos-ci.yml b/wrap/.github/workflows/macos-ci.yml index 3910d28d8..8119a3acb 100644 --- a/wrap/.github/workflows/macos-ci.yml +++ b/wrap/.github/workflows/macos-ci.yml @@ -27,10 +27,12 @@ jobs: - name: Python Dependencies run: | + pip3 install -U pip setuptools pip3 install -r requirements.txt - name: Build and Test run: | + # Build cmake . cd tests # Use Pytest to run all the tests. diff --git a/wrap/DOCS.md b/wrap/DOCS.md index c8285baef..f08f741ff 100644 --- a/wrap/DOCS.md +++ b/wrap/DOCS.md @@ -133,9 +133,10 @@ The python wrapper supports keyword arguments for functions/methods. Hence, the template class Class2 { ... }; typedef Class2 MyInstantiatedClass; ``` - - Templates can also be defined for methods, properties and static methods. + - Templates can also be defined for constructors, methods, properties and static methods. - In the class definition, appearances of the template argument(s) will be replaced with their instantiated types, e.g. `void setValue(const T& value);`. + - Values scoped within templates are supported. E.g. one can use the form `T::Value` where T is a template, as an argument to a method. - To refer to the instantiation of the template class itself, use `This`, i.e. `static This Create();`. - To create new instantiations in other modules, you must copy-and-paste the whole class definition into the new module, but use only your new instantiation types. diff --git a/wrap/cmake/MatlabWrap.cmake b/wrap/cmake/MatlabWrap.cmake index 083b88566..3cb058102 100644 --- a/wrap/cmake/MatlabWrap.cmake +++ b/wrap/cmake/MatlabWrap.cmake @@ -62,10 +62,10 @@ macro(find_and_configure_matlab) endmacro() # Consistent and user-friendly wrap function -function(matlab_wrap interfaceHeader linkLibraries +function(matlab_wrap interfaceHeader moduleName linkLibraries extraIncludeDirs extraMexFlags ignore_classes) find_and_configure_matlab() - wrap_and_install_library("${interfaceHeader}" "${linkLibraries}" + wrap_and_install_library("${interfaceHeader}" "${moduleName}" "${linkLibraries}" "${extraIncludeDirs}" "${extraMexFlags}" "${ignore_classes}") endfunction() @@ -77,6 +77,7 @@ endfunction() # Arguments: # # interfaceHeader: The relative path to the wrapper interface definition file. +# moduleName: The name of the wrapped module, e.g. gtsam # linkLibraries: Any *additional* libraries to link. Your project library # (e.g. `lba`), libraries it depends on, and any necessary MATLAB libraries will # be linked automatically. So normally, leave this empty. @@ -85,15 +86,15 @@ endfunction() # extraMexFlags: Any *additional* flags to pass to the compiler when building # the wrap code. Normally, leave this empty. # ignore_classes: List of classes to ignore in the wrapping. -function(wrap_and_install_library interfaceHeader linkLibraries +function(wrap_and_install_library interfaceHeader moduleName linkLibraries extraIncludeDirs extraMexFlags ignore_classes) - wrap_library_internal("${interfaceHeader}" "${linkLibraries}" + wrap_library_internal("${interfaceHeader}" "${moduleName}" "${linkLibraries}" "${extraIncludeDirs}" "${mexFlags}") - install_wrapped_library_internal("${interfaceHeader}") + install_wrapped_library_internal("${moduleName}") endfunction() # Internal function that wraps a library and compiles the wrapper -function(wrap_library_internal interfaceHeader linkLibraries extraIncludeDirs +function(wrap_library_internal interfaceHeader moduleName linkLibraries extraIncludeDirs extraMexFlags) if(UNIX AND NOT APPLE) if(CMAKE_SIZEOF_VOID_P EQUAL 8) @@ -120,7 +121,6 @@ function(wrap_library_internal interfaceHeader linkLibraries extraIncludeDirs # Extract module name from interface header file name get_filename_component(interfaceHeader "${interfaceHeader}" ABSOLUTE) get_filename_component(modulePath "${interfaceHeader}" PATH) - get_filename_component(moduleName "${interfaceHeader}" NAME_WE) # Paths for generated files set(generated_files_path "${PROJECT_BINARY_DIR}/wrap/${moduleName}") @@ -136,8 +136,7 @@ function(wrap_library_internal interfaceHeader linkLibraries extraIncludeDirs # explicit link libraries list so that the next block of code can unpack any # static libraries set(automaticDependencies "") - foreach(lib ${moduleName} ${linkLibraries}) - # message("MODULE NAME: ${moduleName}") + foreach(lib ${module} ${linkLibraries}) if(TARGET "${lib}") get_target_property(dependentLibraries ${lib} INTERFACE_LINK_LIBRARIES) # message("DEPENDENT LIBRARIES: ${dependentLibraries}") @@ -176,7 +175,7 @@ function(wrap_library_internal interfaceHeader linkLibraries extraIncludeDirs set(otherLibraryTargets "") set(otherLibraryNontargets "") set(otherSourcesAndObjects "") - foreach(lib ${moduleName} ${linkLibraries} ${automaticDependencies}) + foreach(lib ${module} ${linkLibraries} ${automaticDependencies}) if(TARGET "${lib}") if(WRAP_MEX_BUILD_STATIC_MODULE) get_target_property(target_sources ${lib} SOURCES) @@ -250,7 +249,7 @@ function(wrap_library_internal interfaceHeader linkLibraries extraIncludeDirs COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${GTWRAP_PACKAGE_DIR}${GTWRAP_PATH_SEPARATOR}$ENV{PYTHONPATH}" - ${PYTHON_EXECUTABLE} ${MATLAB_WRAP_SCRIPT} --src ${interfaceHeader} + ${PYTHON_EXECUTABLE} ${MATLAB_WRAP_SCRIPT} --src "${interfaceHeader}" --module_name ${moduleName} --out ${generated_files_path} --top_module_namespaces ${moduleName} --ignore ${ignore_classes} VERBATIM @@ -324,8 +323,8 @@ endfunction() # Internal function that installs a wrap toolbox function(install_wrapped_library_internal interfaceHeader) - get_filename_component(moduleName "${interfaceHeader}" NAME_WE) - set(generated_files_path "${PROJECT_BINARY_DIR}/wrap/${moduleName}") + get_filename_component(module "${interfaceHeader}" NAME_WE) + set(generated_files_path "${PROJECT_BINARY_DIR}/wrap/${module}") # NOTE: only installs .m and mex binary files (not .cpp) - the trailing slash # on the directory name here prevents creating the top-level module name diff --git a/wrap/cmake/PybindWrap.cmake b/wrap/cmake/PybindWrap.cmake index f341c2f98..2008bf2dd 100644 --- a/wrap/cmake/PybindWrap.cmake +++ b/wrap/cmake/PybindWrap.cmake @@ -55,15 +55,44 @@ function( set(GTWRAP_PATH_SEPARATOR ";") endif() + # Create a copy of interface_headers so we can freely manipulate it + set(interface_files ${interface_headers}) + + # Pop the main interface file so that interface_files has only submodules. + list(POP_FRONT interface_files main_interface) + # Convert .i file names to .cpp file names. - foreach(filepath ${interface_headers}) - get_filename_component(interface ${filepath} NAME) - string(REPLACE ".i" ".cpp" cpp_file ${interface}) + foreach(interface_file ${interface_files}) + # This block gets the interface file name and does the replacement + get_filename_component(interface ${interface_file} NAME_WLE) + set(cpp_file "${interface}.cpp") list(APPEND cpp_files ${cpp_file}) + + # Wrap the specific interface header + # This is done so that we can create CMake dependencies in such a way so that when changing a single .i file, + # the others don't need to be regenerated. + # NOTE: We have to use `add_custom_command` so set the dependencies correctly. + # https://stackoverflow.com/questions/40032593/cmake-does-not-rebuild-dependent-after-prerequisite-changes + add_custom_command( + OUTPUT ${cpp_file} + COMMAND + ${CMAKE_COMMAND} -E env + "PYTHONPATH=${GTWRAP_PACKAGE_DIR}${GTWRAP_PATH_SEPARATOR}$ENV{PYTHONPATH}" + ${PYTHON_EXECUTABLE} ${PYBIND_WRAP_SCRIPT} --src "${interface_file}" + --out "${cpp_file}" --module_name ${module_name} + --top_module_namespaces "${top_namespace}" --ignore ${ignore_classes} + --template ${module_template} --is_submodule ${_WRAP_BOOST_ARG} + DEPENDS "${interface_file}" ${module_template} "${module_name}/specializations/${interface}.h" "${module_name}/preamble/${interface}.h" + VERBATIM) + endforeach() + get_filename_component(main_interface_name ${main_interface} NAME_WLE) + set(main_cpp_file "${main_interface_name}.cpp") + list(PREPEND cpp_files ${main_cpp_file}) + add_custom_command( - OUTPUT ${cpp_files} + OUTPUT ${main_cpp_file} COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${GTWRAP_PACKAGE_DIR}${GTWRAP_PATH_SEPARATOR}$ENV{PYTHONPATH}" @@ -71,23 +100,10 @@ function( --out "${generated_cpp}" --module_name ${module_name} --top_module_namespaces "${top_namespace}" --ignore ${ignore_classes} --template ${module_template} ${_WRAP_BOOST_ARG} - DEPENDS "${interface_headers}" ${module_template} + DEPENDS "${main_interface}" ${module_template} "${module_name}/specializations/${main_interface_name}.h" "${module_name}/specializations/${main_interface_name}.h" VERBATIM) - add_custom_target(pybind_wrap_${module_name} ALL DEPENDS ${cpp_files}) - - # Late dependency injection, to make sure this gets called whenever the - # interface header or the wrap library are updated. - # ~~~ - # See: https://stackoverflow.com/questions/40032593/cmake-does-not-rebuild-dependent-after-prerequisite-changes - # ~~~ - add_custom_command( - OUTPUT ${cpp_files} - DEPENDS ${interface_headers} - # @GTWRAP_SOURCE_DIR@/gtwrap/interface_parser.py - # @GTWRAP_SOURCE_DIR@/gtwrap/pybind_wrapper.py - # @GTWRAP_SOURCE_DIR@/gtwrap/template_instantiator.py - APPEND) + add_custom_target(pybind_wrap_${module_name} DEPENDS ${cpp_files}) pybind11_add_module(${target} "${cpp_files}") @@ -192,9 +208,9 @@ function(install_python_files source_files dest_directory) endfunction() # ~~~ -# https://stackoverflow.com/questions/13959434/cmake-out-of-source-build-python-files +# Copy over the directory from source_folder to dest_foler # ~~~ -function(create_symlinks source_folder dest_folder) +function(copy_directory source_folder dest_folder) if(${source_folder} STREQUAL ${dest_folder}) return() endif() @@ -215,31 +231,13 @@ function(create_symlinks source_folder dest_folder) # Create REAL folder file(MAKE_DIRECTORY "${dest_folder}") - # Delete symlink if it exists + # Delete if it exists file(REMOVE "${dest_folder}/${path_file}") - # Get OS dependent path to use in `execute_process` - file(TO_NATIVE_PATH "${dest_folder}/${path_file}" link) + # Get OS dependent path to use in copy file(TO_NATIVE_PATH "${source_folder}/${path_file}" target) - # cmake-format: off - if(UNIX) - set(command ln -s ${target} ${link}) - else() - set(command cmd.exe /c mklink ${link} ${target}) - endif() - # cmake-format: on - - execute_process( - COMMAND ${command} - RESULT_VARIABLE result - ERROR_VARIABLE output) - - if(NOT ${result} EQUAL 0) - message( - FATAL_ERROR - "Could not create symbolic link for: ${target} --> ${output}") - endif() + file(COPY ${target} DESTINATION ${dest_folder}) endforeach(path_file) -endfunction(create_symlinks) +endfunction(copy_directory) diff --git a/wrap/gtwrap/interface_parser/classes.py b/wrap/gtwrap/interface_parser/classes.py index 841c963c2..54beb86c1 100644 --- a/wrap/gtwrap/interface_parser/classes.py +++ b/wrap/gtwrap/interface_parser/classes.py @@ -62,6 +62,10 @@ class Method: self.parent = parent + def to_cpp(self) -> str: + """Generate the C++ code for wrapping.""" + return self.name + def __repr__(self) -> str: return "Method: {} {} {}({}){}".format( self.template, @@ -84,7 +88,8 @@ class StaticMethod: ``` """ rule = ( - STATIC # + Optional(Template.rule("template")) # + + STATIC # + ReturnType.rule("return_type") # + IDENT("name") # + LPAREN # @@ -92,16 +97,18 @@ class StaticMethod: + RPAREN # + SEMI_COLON # BR ).setParseAction( - lambda t: StaticMethod(t.name, t.return_type, t.args_list)) + lambda t: StaticMethod(t.name, t.return_type, t.args_list, t.template)) def __init__(self, name: str, return_type: ReturnType, args: ArgumentList, + template: Union[Template, Any] = None, parent: Union["Class", Any] = ''): self.name = name self.return_type = return_type self.args = args + self.template = template self.parent = parent @@ -221,8 +228,8 @@ class Class: Rule for all the members within a class. """ rule = ZeroOrMore(Constructor.rule # - ^ StaticMethod.rule # ^ Method.rule # + ^ StaticMethod.rule # ^ Variable.rule # ^ Operator.rule # ^ Enum.rule # diff --git a/wrap/gtwrap/interface_parser/type.py b/wrap/gtwrap/interface_parser/type.py index 49315cc56..7aacf0b81 100644 --- a/wrap/gtwrap/interface_parser/type.py +++ b/wrap/gtwrap/interface_parser/type.py @@ -53,6 +53,10 @@ class Typename: self.name = t[-1] # the name is the last element in this list self.namespaces = t[:-1] + # If the first namespace is empty string, just get rid of it. + if self.namespaces and self.namespaces[0] == '': + self.namespaces.pop(0) + if instantiations: if isinstance(instantiations, Sequence): self.instantiations = instantiations # type: ignore @@ -92,8 +96,8 @@ class Typename: else: cpp_name = self.name return '{}{}{}'.format( - "::".join(self.namespaces[idx:]), - "::" if self.namespaces[idx:] else "", + "::".join(self.namespaces), + "::" if self.namespaces else "", cpp_name, ) @@ -158,6 +162,8 @@ class Type: """ Parsed datatype, can be either a fundamental type or a custom datatype. E.g. void, double, size_t, Matrix. + Think of this as a high-level type which encodes the typename and other + characteristics of the type. The type can optionally be a raw pointer, shared pointer or reference. Can also be optionally qualified with a `const`, e.g. `const int`. @@ -240,6 +246,9 @@ class Type: or self.typename.name in ["Matrix", "Vector"]) else "", typename=typename)) + def get_typename(self): + """Convenience method to get the typename of this type.""" + return self.typename.name class TemplatedType: """ diff --git a/wrap/gtwrap/matlab_wrapper/mixins.py b/wrap/gtwrap/matlab_wrapper/mixins.py index 217801ff3..4c2b005b7 100644 --- a/wrap/gtwrap/matlab_wrapper/mixins.py +++ b/wrap/gtwrap/matlab_wrapper/mixins.py @@ -26,25 +26,30 @@ class CheckMixin: return True return False + def can_be_pointer(self, arg_type: parser.Type): + """ + Determine if the `arg_type` can have a pointer to it. + + E.g. `Pose3` can have `Pose3*` but + `Matrix` should not have `Matrix*`. + """ + return (arg_type.typename.name not in self.not_ptr_type + and arg_type.typename.name not in self.ignore_namespace + and arg_type.typename.name != 'string') + def is_shared_ptr(self, arg_type: parser.Type): """ Determine if the `interface_parser.Type` should be treated as a shared pointer in the wrapper. """ - return arg_type.is_shared_ptr or ( - arg_type.typename.name not in self.not_ptr_type - and arg_type.typename.name not in self.ignore_namespace - and arg_type.typename.name != 'string') + return arg_type.is_shared_ptr def is_ptr(self, arg_type: parser.Type): """ Determine if the `interface_parser.Type` should be treated as a raw pointer in the wrapper. """ - return arg_type.is_ptr or ( - arg_type.typename.name not in self.not_ptr_type - and arg_type.typename.name not in self.ignore_namespace - and arg_type.typename.name != 'string') + return arg_type.is_ptr def is_ref(self, arg_type: parser.Type): """ @@ -108,11 +113,11 @@ class FormatMixin: elif is_method: formatted_type_name += self.data_type_param.get(name) or name else: - formatted_type_name += name + formatted_type_name += str(name) if separator == "::": # C++ templates = [] - for idx in range(len(type_name.instantiations)): + for idx, _ in enumerate(type_name.instantiations): template = '{}'.format( self._format_type_name(type_name.instantiations[idx], include_namespace=include_namespace, @@ -124,7 +129,7 @@ class FormatMixin: formatted_type_name += '<{}>'.format(','.join(templates)) else: - for idx in range(len(type_name.instantiations)): + for idx, _ in enumerate(type_name.instantiations): formatted_type_name += '{}'.format( self._format_type_name(type_name.instantiations[idx], separator=separator, @@ -192,10 +197,9 @@ class FormatMixin: method = '' if isinstance(static_method, parser.StaticMethod): - method += "".join([separator + x for x in static_method.parent.namespaces()]) + \ - separator + static_method.parent.name + separator + method += static_method.parent.to_cpp() + separator - return method[2 * len(separator):] + return method def _format_global_function(self, function: Union[parser.GlobalFunction, Any], diff --git a/wrap/gtwrap/matlab_wrapper/templates.py b/wrap/gtwrap/matlab_wrapper/templates.py index 7aaf8f487..3d1306dca 100644 --- a/wrap/gtwrap/matlab_wrapper/templates.py +++ b/wrap/gtwrap/matlab_wrapper/templates.py @@ -66,7 +66,7 @@ class WrapperTemplate: mxDestroyArray(registry); mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL); - if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) {{ + if(mexPutVariable("global", "gtsam_{module_name}_rttiRegistry_created", newAlreadyCreated) != 0) {{ mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); }} mxDestroyArray(newAlreadyCreated); diff --git a/wrap/gtwrap/matlab_wrapper/wrapper.py b/wrap/gtwrap/matlab_wrapper/wrapper.py index 97945f73a..e690cd213 100755 --- a/wrap/gtwrap/matlab_wrapper/wrapper.py +++ b/wrap/gtwrap/matlab_wrapper/wrapper.py @@ -5,6 +5,7 @@ that Matlab's MEX compiler can use. # pylint: disable=too-many-lines, no-self-use, too-many-arguments, too-many-branches, too-many-statements +import copy import os import os.path as osp import textwrap @@ -13,6 +14,7 @@ from typing import Dict, Iterable, List, Union import gtwrap.interface_parser as parser import gtwrap.template_instantiator as instantiator +from gtwrap.interface_parser.function import ArgumentList from gtwrap.matlab_wrapper.mixins import CheckMixin, FormatMixin from gtwrap.matlab_wrapper.templates import WrapperTemplate @@ -137,6 +139,40 @@ class MatlabWrapper(CheckMixin, FormatMixin): """ return x + '\n' + ('' if y == '' else ' ') + y + @staticmethod + def _expand_default_arguments(method, save_backup=True): + """Recursively expand all possibilities for optional default arguments. + We create "overload" functions with fewer arguments, but since we have to "remember" what + the default arguments are for later, we make a backup. + """ + def args_copy(args): + return ArgumentList([copy.copy(arg) for arg in args.list()]) + + def method_copy(method): + method2 = copy.copy(method) + method2.args = args_copy(method.args) + method2.args.backup = method.args.backup + return method2 + + if save_backup: + method.args.backup = args_copy(method.args) + method = method_copy(method) + for arg in reversed(method.args.list()): + if arg.default is not None: + arg.default = None + methodWithArg = method_copy(method) + method.args.list().remove(arg) + return [ + methodWithArg, + *MatlabWrapper._expand_default_arguments(method, + save_backup=False) + ] + break + assert all(arg.default is None for arg in method.args.list()), \ + 'In parsing method {:}: Arguments with default values cannot appear before ones ' \ + 'without default values.'.format(method.name) + return [method] + def _group_methods(self, methods): """Group overloaded methods together""" method_map = {} @@ -147,9 +183,12 @@ class MatlabWrapper(CheckMixin, FormatMixin): if method_index is None: method_map[method.name] = len(method_out) - method_out.append([method]) + method_out.append( + MatlabWrapper._expand_default_arguments(method)) else: - method_out[method_index].append(method) + method_out[ + method_index] += MatlabWrapper._expand_default_arguments( + method) return method_out @@ -239,18 +278,18 @@ class MatlabWrapper(CheckMixin, FormatMixin): return var_list_wrap - def _wrap_method_check_statement(self, args): + def _wrap_method_check_statement(self, args: parser.ArgumentList): """ Wrap the given arguments into either just a varargout call or a call in an if statement that checks if the parameters are accurate. + + TODO Update this method so that default arguments are supported. """ - check_statement = '' arg_id = 1 - if check_statement == '': - check_statement = \ - 'if length(varargin) == {param_count}'.format( - param_count=len(args.list())) + param_count = len(args) + check_statement = 'if length(varargin) == {param_count}'.format( + param_count=param_count) for _, arg in enumerate(args.list()): name = arg.ctype.typename.name @@ -301,56 +340,70 @@ class MatlabWrapper(CheckMixin, FormatMixin): ((a), Test& t = *unwrap_shared_ptr< Test >(in[1], "ptr_Test");), ((a), std::shared_ptr p1 = unwrap_shared_ptr< Test >(in[1], "ptr_Test");) """ - params = '' body_args = '' for arg in args.list(): + ctype_camel = self._format_type_name(arg.ctype.typename, + separator='') + ctype_sep = self._format_type_name(arg.ctype.typename) + + if self.is_ref(arg.ctype): # and not constructor: + arg_type = "{ctype}&".format(ctype=ctype_sep) + unwrap = '*unwrap_shared_ptr< {ctype} >(in[{id}], "ptr_{ctype_camel}");'.format( + ctype=ctype_sep, ctype_camel=ctype_camel, id=arg_id) + + elif self.is_ptr(arg.ctype) and \ + arg.ctype.typename.name not in self.ignore_namespace: + + arg_type = "{ctype_sep}*".format(ctype_sep=ctype_sep) + unwrap = 'unwrap_ptr< {ctype_sep} >(in[{id}], "ptr_{ctype}");'.format( + ctype_sep=ctype_sep, ctype=ctype_camel, id=arg_id) + + elif (self.is_shared_ptr(arg.ctype) or self.can_be_pointer(arg.ctype)) and \ + arg.ctype.typename.name not in self.ignore_namespace: + call_type = arg.ctype.is_shared_ptr + + arg_type = "{std_boost}::shared_ptr<{ctype_sep}>".format( + std_boost='boost' if constructor else 'boost', + ctype_sep=ctype_sep) + unwrap = 'unwrap_shared_ptr< {ctype_sep} >(in[{id}], "ptr_{ctype}");'.format( + ctype_sep=ctype_sep, ctype=ctype_camel, id=arg_id) + + else: + arg_type = "{ctype}".format(ctype=arg.ctype.typename.name) + unwrap = 'unwrap< {ctype} >(in[{id}]);'.format( + ctype=arg.ctype.typename.name, id=arg_id) + + body_args += textwrap.indent(textwrap.dedent('''\ + {arg_type} {name} = {unwrap} + '''.format(arg_type=arg_type, name=arg.name, + unwrap=unwrap)), + prefix=' ') + arg_id += 1 + + params = '' + explicit_arg_names = [arg.name for arg in args.list()] + # when returning the params list, we need to re-include the default args. + for arg in args.backup.list(): if params != '': params += ',' - if self.is_ref(arg.ctype): # and not constructor: - ctype_camel = self._format_type_name(arg.ctype.typename, - separator='') - body_args += textwrap.indent(textwrap.dedent('''\ - {ctype}& {name} = *unwrap_shared_ptr< {ctype} >(in[{id}], "ptr_{ctype_camel}"); - '''.format(ctype=self._format_type_name(arg.ctype.typename), - ctype_camel=ctype_camel, - name=arg.name, - id=arg_id)), - prefix=' ') + if (arg.default is not None) and (arg.name + not in explicit_arg_names): + params += arg.default + continue - elif (self.is_shared_ptr(arg.ctype) or self.is_ptr(arg.ctype)) and \ + if not self.is_ref(arg.ctype) and (self.is_shared_ptr(arg.ctype) or \ + self.is_ptr(arg.ctype) or self.can_be_pointer(arg.ctype))and \ arg.ctype.typename.name not in self.ignore_namespace: if arg.ctype.is_shared_ptr: call_type = arg.ctype.is_shared_ptr else: call_type = arg.ctype.is_ptr - - body_args += textwrap.indent(textwrap.dedent('''\ - {std_boost}::shared_ptr<{ctype_sep}> {name} = unwrap_shared_ptr< {ctype_sep} >(in[{id}], "ptr_{ctype}"); - '''.format(std_boost='boost' if constructor else 'boost', - ctype_sep=self._format_type_name( - arg.ctype.typename), - ctype=self._format_type_name(arg.ctype.typename, - separator=''), - name=arg.name, - id=arg_id)), - prefix=' ') if call_type == "": params += "*" - - else: - body_args += textwrap.indent(textwrap.dedent('''\ - {ctype} {name} = unwrap< {ctype} >(in[{id}]); - '''.format(ctype=arg.ctype.typename.name, - name=arg.name, - id=arg_id)), - prefix=' ') - params += arg.name - arg_id += 1 - return params, body_args @staticmethod @@ -555,6 +608,9 @@ class MatlabWrapper(CheckMixin, FormatMixin): if not isinstance(ctors, Iterable): ctors = [ctors] + ctors = sum((MatlabWrapper._expand_default_arguments(ctor) + for ctor in ctors), []) + methods_wrap = textwrap.indent(textwrap.dedent("""\ methods function obj = {class_name}(varargin) @@ -674,20 +730,7 @@ class MatlabWrapper(CheckMixin, FormatMixin): def _group_class_methods(self, methods): """Group overloaded methods together""" - method_map = {} - method_out = [] - - for method in methods: - method_index = method_map.get(method.name) - - if method_index is None: - method_map[method.name] = len(method_out) - method_out.append([method]) - else: - # print("[_group_methods] Merging {} with {}".format(method_index, method.name)) - method_out[method_index].append(method) - - return method_out + return self._group_methods(methods) @classmethod def _format_varargout(cls, return_type, return_type_formatted): @@ -809,7 +852,7 @@ class MatlabWrapper(CheckMixin, FormatMixin): for static_method in static_methods: format_name = list(static_method[0].name) - format_name[0] = format_name[0].upper() + format_name[0] = format_name[0] if static_method[0].name in self.ignore_methods: continue @@ -855,7 +898,8 @@ class MatlabWrapper(CheckMixin, FormatMixin): end_statement=end_statement), prefix=' ') - #TODO Figure out what is static_overload doing here. + # If the arguments don't match any of the checks above, + # throw an error with the class and method name. method_text += textwrap.indent(textwrap.dedent("""\ error('Arguments do not match any overload of function {class_name}.{method_name}'); """.format(class_name=class_name, @@ -1043,7 +1087,8 @@ class MatlabWrapper(CheckMixin, FormatMixin): pair_value = 'first' if func_id == 0 else 'second' new_line = '\n' if func_id == 0 else '' - if self.is_shared_ptr(return_type) or self.is_ptr(return_type): + if self.is_shared_ptr(return_type) or self.is_ptr(return_type) or \ + self.can_be_pointer(return_type): shared_obj = 'pairResult.' + pair_value if not (return_type.is_shared_ptr or return_type.is_ptr): @@ -1081,7 +1126,6 @@ class MatlabWrapper(CheckMixin, FormatMixin): obj_start = '' if isinstance(method, instantiator.InstantiatedMethod): - # method_name = method.original.name method_name = method.to_cpp() obj_start = 'obj->' @@ -1090,6 +1134,10 @@ class MatlabWrapper(CheckMixin, FormatMixin): # self._format_type_name(method.instantiations)) method = method.to_cpp() + elif isinstance(method, instantiator.InstantiatedStaticMethod): + method_name = self._format_static_method(method, '::') + method_name += method.original.name + elif isinstance(method, parser.GlobalFunction): method_name = self._format_global_function(method, '::') method_name += method.name @@ -1106,7 +1154,8 @@ class MatlabWrapper(CheckMixin, FormatMixin): if return_1_name != 'void': if return_count == 1: - if self.is_shared_ptr(return_1) or self.is_ptr(return_1): + if self.is_shared_ptr(return_1) or self.is_ptr(return_1) or \ + self.can_be_pointer(return_1): sep_method_name = partial(self._format_type_name, return_1.typename, include_namespace=True) @@ -1230,9 +1279,9 @@ class MatlabWrapper(CheckMixin, FormatMixin): Collector_{class_name}::iterator item; item = collector_{class_name}.find(self); if(item != collector_{class_name}.end()) {{ - delete self; collector_{class_name}.erase(item); }} + delete self; ''').format(class_name_sep=class_name_separated, class_name=class_name), prefix=' ') @@ -1250,7 +1299,7 @@ class MatlabWrapper(CheckMixin, FormatMixin): method_name = '' if is_static_method: - method_name = self._format_static_method(extra) + '.' + method_name = self._format_static_method(extra, '.') method_name += extra.name @@ -1567,23 +1616,23 @@ class MatlabWrapper(CheckMixin, FormatMixin): def wrap(self, files, path): """High level function to wrap the project.""" + content = "" modules = {} for file in files: with open(file, 'r') as f: - content = f.read() + content += f.read() - # Parse the contents of the interface file - parsed_result = parser.Module.parseString(content) - # print(parsed_result) + # Parse the contents of the interface file + parsed_result = parser.Module.parseString(content) - # Instantiate the module - module = instantiator.instantiate_namespace(parsed_result) + # Instantiate the module + module = instantiator.instantiate_namespace(parsed_result) - if module.name in modules: - modules[module. - name].content[0].content += module.content[0].content - else: - modules[module.name] = module + if module.name in modules: + modules[ + module.name].content[0].content += module.content[0].content + else: + modules[module.name] = module for module in modules.values(): # Wrap the full namespace diff --git a/wrap/gtwrap/pybind_wrapper.py b/wrap/gtwrap/pybind_wrapper.py index 40571263a..31d8d4444 100755 --- a/wrap/gtwrap/pybind_wrapper.py +++ b/wrap/gtwrap/pybind_wrapper.py @@ -14,6 +14,7 @@ Author: Duy Nguyen Ta, Fan Jiang, Matthew Sklar, Varun Agrawal, and Frank Dellae import re from pathlib import Path +from typing import List import gtwrap.interface_parser as parser import gtwrap.template_instantiator as instantiator @@ -46,6 +47,11 @@ class PybindWrapper: # amount of indentation to add before each function/method declaration. self.method_indent = '\n' + (' ' * 8) + # Special methods which are leveraged by ipython/jupyter notebooks + self._ipython_special_methods = [ + "svg", "png", "jpeg", "html", "javascript", "markdown", "latex" + ] + def _py_args_names(self, args): """Set the argument names in Pybind11 format.""" names = args.names() @@ -86,45 +92,110 @@ class PybindWrapper: )) return res + def _wrap_serialization(self, cpp_class): + """Helper method to add serialize, deserialize and pickle methods to the wrapped class.""" + if not cpp_class in self._serializing_classes: + self._serializing_classes.append(cpp_class) + + serialize_method = self.method_indent + \ + ".def(\"serialize\", []({class_inst} self){{ return gtsam::serialize(*self); }})".format(class_inst=cpp_class + '*') + + deserialize_method = self.method_indent + \ + '.def("deserialize", []({class_inst} self, string serialized)' \ + '{{ gtsam::deserialize(serialized, *self); }}, py::arg("serialized"))' \ + .format(class_inst=cpp_class + '*') + + # Since this class supports serialization, we also add the pickle method. + pickle_method = self.method_indent + \ + ".def(py::pickle({indent} [](const {cpp_class} &a){{ /* __getstate__: Returns a string that encodes the state of the object */ return py::make_tuple(gtsam::serialize(a)); }},{indent} [](py::tuple t){{ /* __setstate__ */ {cpp_class} obj; gtsam::deserialize(t[0].cast(), obj); return obj; }}))" + + return serialize_method + deserialize_method + \ + pickle_method.format(cpp_class=cpp_class, indent=self.method_indent) + + def _wrap_print(self, ret: str, method: parser.Method, cpp_class: str, + args_names: List[str], args_signature_with_names: str, + py_args_names: str, prefix: str, suffix: str): + """ + Update the print method to print to the output stream and append a __repr__ method. + + Args: + ret (str): The result of the parser. + method (parser.Method): The method to be wrapped. + cpp_class (str): The C++ name of the class to which the method belongs. + args_names (List[str]): List of argument variable names passed to the method. + args_signature_with_names (str): C++ arguments containing their names and type signatures. + py_args_names (str): The pybind11 formatted version of the argument list. + prefix (str): Prefix to add to the wrapped method when writing to the cpp file. + suffix (str): Suffix to add to the wrapped method when writing to the cpp file. + + Returns: + str: The wrapped print method. + """ + # Redirect stdout - see pybind docs for why this is a good idea: + # https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html#capturing-standard-output-from-ostream + ret = ret.replace('self->print', + 'py::scoped_ostream_redirect output; self->print') + + # Make __repr__() call .print() internally + ret += '''{prefix}.def("__repr__", + [](const {cpp_class}& self{opt_comma}{args_signature_with_names}){{ + gtsam::RedirectCout redirect; + self.{method_name}({method_args}); + return redirect.str(); + }}{py_args_names}){suffix}'''.format( + prefix=prefix, + cpp_class=cpp_class, + opt_comma=', ' if args_names else '', + args_signature_with_names=args_signature_with_names, + method_name=method.name, + method_args=", ".join(args_names) if args_names else '', + py_args_names=py_args_names, + suffix=suffix) + return ret + def _wrap_method(self, method, cpp_class, prefix, suffix, method_suffix=""): + """ + Wrap the `method` for the class specified by `cpp_class`. + + Args: + method: The method to wrap. + cpp_class: The C++ name of the class to which the method belongs. + prefix: Prefix to add to the wrapped method when writing to the cpp file. + suffix: Suffix to add to the wrapped method when writing to the cpp file. + method_suffix: A string to append to the wrapped method name. + """ py_method = method.name + method_suffix cpp_method = method.to_cpp() - if cpp_method in ["serialize", "serializable"]: - if not cpp_class in self._serializing_classes: - self._serializing_classes.append(cpp_class) - serialize_method = self.method_indent + \ - ".def(\"serialize\", []({class_inst} self){{ return gtsam::serialize(*self); }})".format(class_inst=cpp_class + '*') - deserialize_method = self.method_indent + \ - '.def("deserialize", []({class_inst} self, string serialized)' \ - '{{ gtsam::deserialize(serialized, *self); }}, py::arg("serialized"))' \ - .format(class_inst=cpp_class + '*') - return serialize_method + deserialize_method + args_names = method.args.names() + py_args_names = self._py_args_names(method.args) + args_signature_with_names = self._method_args_signature(method.args) - if cpp_method == "pickle": - if not cpp_class in self._serializing_classes: - raise ValueError( - "Cannot pickle a class which is not serializable") - pickle_method = self.method_indent + \ - ".def(py::pickle({indent} [](const {cpp_class} &a){{ /* __getstate__: Returns a string that encodes the state of the object */ return py::make_tuple(gtsam::serialize(a)); }},{indent} [](py::tuple t){{ /* __setstate__ */ {cpp_class} obj; gtsam::deserialize(t[0].cast(), obj); return obj; }}))" - return pickle_method.format(cpp_class=cpp_class, - indent=self.method_indent) + # Special handling for the serialize/serializable method + if cpp_method in ["serialize", "serializable"]: + return self._wrap_serialization(cpp_class) + + # Special handling of ipython specific methods + # https://ipython.readthedocs.io/en/stable/config/integrating.html + if cpp_method in self._ipython_special_methods: + idx = self._ipython_special_methods.index(cpp_method) + py_method = f"_repr_{self._ipython_special_methods[idx]}_" # Add underscore to disambiguate if the method name matches a python keyword if py_method in self.python_keywords: py_method = py_method + "_" - is_method = isinstance(method, instantiator.InstantiatedMethod) - is_static = isinstance(method, parser.StaticMethod) + is_method = isinstance( + method, (parser.Method, instantiator.InstantiatedMethod)) + is_static = isinstance( + method, + (parser.StaticMethod, instantiator.InstantiatedStaticMethod)) return_void = method.return_type.is_void() - args_names = method.args.names() - py_args_names = self._py_args_names(method.args) - args_signature_with_names = self._method_args_signature(method.args) caller = cpp_class + "::" if not is_method else "self->" function_call = ('{opt_return} {caller}{method_name}' @@ -155,27 +226,9 @@ class PybindWrapper: # Create __repr__ override # We allow all arguments to .print() and let the compiler handle type mismatches. if method.name == 'print': - # Redirect stdout - see pybind docs for why this is a good idea: - # https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html#capturing-standard-output-from-ostream - ret = ret.replace( - 'self->print', - 'py::scoped_ostream_redirect output; self->print') - - # Make __repr__() call .print() internally - ret += '''{prefix}.def("__repr__", - [](const {cpp_class}& self{opt_comma}{args_signature_with_names}){{ - gtsam::RedirectCout redirect; - self.{method_name}({method_args}); - return redirect.str(); - }}{py_args_names}){suffix}'''.format( - prefix=prefix, - cpp_class=cpp_class, - opt_comma=', ' if args_names else '', - args_signature_with_names=args_signature_with_names, - method_name=method.name, - method_args=", ".join(args_names) if args_names else '', - py_args_names=py_args_names, - suffix=suffix) + ret = self._wrap_print(ret, method, cpp_class, args_names, + args_signature_with_names, py_args_names, + prefix, suffix) return ret @@ -359,7 +412,7 @@ class PybindWrapper: def wrap_instantiated_declaration( self, instantiated_decl: instantiator.InstantiatedDeclaration): - """Wrap the class.""" + """Wrap the forward declaration.""" module_var = self._gen_module_var(instantiated_decl.namespaces()) cpp_class = instantiated_decl.to_cpp() if cpp_class in self.ignore_classes: @@ -367,7 +420,7 @@ class PybindWrapper: res = ( '\n py::class_<{cpp_class}, ' - '{shared_ptr_type}::shared_ptr<{cpp_class}>>({module_var}, "{class_name}")' + '{shared_ptr_type}::shared_ptr<{cpp_class}>>({module_var}, "{class_name}");' ).format(shared_ptr_type=('boost' if self.use_boost else 'std'), cpp_class=cpp_class, class_name=instantiated_decl.name, @@ -621,28 +674,47 @@ class PybindWrapper: submodules_init="\n".join(submodules_init), ) - def wrap(self, sources, main_output): + def wrap_submodule(self, source): """ - Wrap all the source interface files. + Wrap a list of submodule files, i.e. a set of interface files which are + in support of a larger wrapping project. + + E.g. This is used in GTSAM where we have a main gtsam.i, but various smaller .i files + which are the submodules. + The benefit of this scheme is that it reduces compute and memory usage during compilation. + + Args: + source: Interface file which forms the submodule. + """ + filename = Path(source).name + module_name = Path(source).stem + + # Read in the complete interface (.i) file + with open(source, "r") as f: + content = f.read() + # Wrap the read-in content + cc_content = self.wrap_file(content, module_name=module_name) + + # Generate the C++ code which Pybind11 will use. + with open(filename.replace(".i", ".cpp"), "w") as f: + f.write(cc_content) + + def wrap(self, sources, main_module_name): + """ + Wrap all the main interface file. Args: sources: List of all interface files. - main_output: The name for the main module. + The first file should be the main module. + main_module_name: The name for the main module. """ main_module = sources[0] + + # Get all the submodule names. submodules = [] for source in sources[1:]: - filename = Path(source).name module_name = Path(source).stem - # Read in the complete interface (.i) file - with open(source, "r") as f: - content = f.read() submodules.append(module_name) - cc_content = self.wrap_file(content, module_name=module_name) - - # Generate the C++ code which Pybind11 will use. - with open(filename.replace(".i", ".cpp"), "w") as f: - f.write(cc_content) with open(main_module, "r") as f: content = f.read() @@ -651,5 +723,5 @@ class PybindWrapper: submodules=submodules) # Generate the C++ code which Pybind11 will use. - with open(main_output, "w") as f: + with open(main_module_name, "w") as f: f.write(cc_content) diff --git a/wrap/gtwrap/template_instantiator.py b/wrap/gtwrap/template_instantiator.py deleted file mode 100644 index f5beb0c69..000000000 --- a/wrap/gtwrap/template_instantiator.py +++ /dev/null @@ -1,681 +0,0 @@ -"""Code to help instantiate templated classes, methods and functions.""" - -# pylint: disable=too-many-arguments, too-many-instance-attributes, no-self-use, no-else-return, too-many-arguments, unused-format-string-argument, unused-variable - -import itertools -from copy import deepcopy -from typing import Any, Iterable, List, Sequence - -import gtwrap.interface_parser as parser - - -def instantiate_type(ctype: parser.Type, - template_typenames: List[str], - instantiations: List[parser.Typename], - cpp_typename: parser.Typename, - instantiated_class=None): - """ - Instantiate template typename for @p ctype. - - Args: - instiated_class (InstantiatedClass): - - @return If ctype's name is in the @p template_typenames, return the - corresponding type to replace in @p instantiations. - If ctype name is `This`, return the new typename @p `cpp_typename`. - Otherwise, return the original ctype. - """ - # make a deep copy so that there is no overwriting of original template params - ctype = deepcopy(ctype) - - # Check if the return type has template parameters - if ctype.typename.instantiations: - for idx, instantiation in enumerate(ctype.typename.instantiations): - if instantiation.name in template_typenames: - template_idx = template_typenames.index(instantiation.name) - ctype.typename.instantiations[ - idx] = instantiations[ # type: ignore - template_idx] - - return ctype - - str_arg_typename = str(ctype.typename) - - # Instantiate templates which have enumerated instantiations in the template. - # E.g. `template`. - if str_arg_typename in template_typenames: - idx = template_typenames.index(str_arg_typename) - return parser.Type( - typename=instantiations[idx], - is_const=ctype.is_const, - is_shared_ptr=ctype.is_shared_ptr, - is_ptr=ctype.is_ptr, - is_ref=ctype.is_ref, - is_basic=ctype.is_basic, - ) - - # If a method has the keyword `This`, we replace it with the (instantiated) class. - elif str_arg_typename == 'This': - # Check if the class is template instantiated - # so we can replace it with the instantiated version. - if instantiated_class: - name = instantiated_class.original.name - namespaces_name = instantiated_class.namespaces() - namespaces_name.append(name) - cpp_typename = parser.Typename( - namespaces_name, - instantiations=instantiated_class.instantiations) - - return parser.Type( - typename=cpp_typename, - is_const=ctype.is_const, - is_shared_ptr=ctype.is_shared_ptr, - is_ptr=ctype.is_ptr, - is_ref=ctype.is_ref, - is_basic=ctype.is_basic, - ) - - # Case when 'This' is present in the type namespace, e.g `This::Subclass`. - elif 'This' in str_arg_typename: - # Simply get the index of `This` in the namespace and replace it with the instantiated name. - namespace_idx = ctype.typename.namespaces.index('This') - ctype.typename.namespaces[namespace_idx] = cpp_typename.name - return ctype - - else: - return ctype - - -def instantiate_args_list(args_list, template_typenames, instantiations, - cpp_typename): - """ - Instantiate template typenames in an argument list. - Type with name `This` will be replaced by @p `cpp_typename`. - - @param[in] args_list A list of `parser.Argument` to instantiate. - @param[in] template_typenames List of template typenames to instantiate, - e.g. ['T', 'U', 'V']. - @param[in] instantiations List of specific types to instantiate, each - associated with each template typename. Each type is a parser.Typename, - including its name and full namespaces. - @param[in] cpp_typename Full-namespace cpp class name of this instantiation - to replace for arguments of type named `This`. - @return A new list of parser.Argument which types are replaced with their - instantiations. - """ - instantiated_args = [] - for arg in args_list: - new_type = instantiate_type(arg.ctype, template_typenames, - instantiations, cpp_typename) - instantiated_args.append( - parser.Argument(name=arg.name, ctype=new_type, - default=arg.default)) - return instantiated_args - - -def instantiate_return_type(return_type, - template_typenames, - instantiations, - cpp_typename, - instantiated_class=None): - """Instantiate the return type.""" - new_type1 = instantiate_type(return_type.type1, - template_typenames, - instantiations, - cpp_typename, - instantiated_class=instantiated_class) - if return_type.type2: - new_type2 = instantiate_type(return_type.type2, - template_typenames, - instantiations, - cpp_typename, - instantiated_class=instantiated_class) - else: - new_type2 = '' - return parser.ReturnType(new_type1, new_type2) - - -def instantiate_name(original_name, instantiations): - """ - Concatenate instantiated types with an @p original name to form a new - instantiated name. - TODO(duy): To avoid conflicts, we should include the instantiation's - namespaces, but I find that too verbose. - """ - instantiated_names = [] - for inst in instantiations: - # Ensure the first character of the type is capitalized - name = inst.instantiated_name() - # Using `capitalize` on the complete name causes other caps to be lower case - instantiated_names.append(name.replace(name[0], name[0].capitalize())) - - return "{}{}".format(original_name, "".join(instantiated_names)) - - -class InstantiatedGlobalFunction(parser.GlobalFunction): - """ - Instantiate global functions. - - E.g. - template - T add(const T& x, const T& y); - """ - def __init__(self, original, instantiations=(), new_name=''): - self.original = original - self.instantiations = instantiations - self.template = '' - self.parent = original.parent - - if not original.template: - self.name = original.name - self.return_type = original.return_type - self.args = original.args - else: - self.name = instantiate_name( - original.name, instantiations) if not new_name else new_name - self.return_type = instantiate_return_type( - original.return_type, - self.original.template.typenames, - self.instantiations, - # Keyword type name `This` should already be replaced in the - # previous class template instantiation round. - cpp_typename='', - ) - instantiated_args = instantiate_args_list( - original.args.list(), - self.original.template.typenames, - self.instantiations, - # Keyword type name `This` should already be replaced in the - # previous class template instantiation round. - cpp_typename='', - ) - self.args = parser.ArgumentList(instantiated_args) - - super().__init__(self.name, - self.return_type, - self.args, - self.template, - parent=self.parent) - - def to_cpp(self): - """Generate the C++ code for wrapping.""" - if self.original.template: - instantiated_names = [ - inst.instantiated_name() for inst in self.instantiations - ] - ret = "{}<{}>".format(self.original.name, - ",".join(instantiated_names)) - else: - ret = self.original.name - return ret - - def __repr__(self): - return "Instantiated {}".format( - super(InstantiatedGlobalFunction, self).__repr__()) - - -class InstantiatedMethod(parser.Method): - """ - Instantiate method with template parameters. - - E.g. - class A { - template - void func(X x, Y y); - } - """ - def __init__(self, - original: parser.Method, - instantiations: Iterable[parser.Typename] = ()): - self.original = original - self.instantiations = instantiations - self.template: Any = '' - self.is_const = original.is_const - self.parent = original.parent - - # Check for typenames if templated. - # This way, we can gracefully handle both templated and non-templated methods. - typenames: Sequence = self.original.template.typenames if self.original.template else [] - self.name = instantiate_name(original.name, self.instantiations) - self.return_type = instantiate_return_type( - original.return_type, - typenames, - self.instantiations, - # Keyword type name `This` should already be replaced in the - # previous class template instantiation round. - cpp_typename='', - ) - - instantiated_args = instantiate_args_list( - original.args.list(), - typenames, - self.instantiations, - # Keyword type name `This` should already be replaced in the - # previous class template instantiation round. - cpp_typename='', - ) - self.args = parser.ArgumentList(instantiated_args) - - super().__init__(self.template, - self.name, - self.return_type, - self.args, - self.is_const, - parent=self.parent) - - def to_cpp(self): - """Generate the C++ code for wrapping.""" - if self.original.template: - # to_cpp will handle all the namespacing and templating - instantiation_list = [x.to_cpp() for x in self.instantiations] - # now can simply combine the instantiations, separated by commas - ret = "{}<{}>".format(self.original.name, - ",".join(instantiation_list)) - else: - ret = self.original.name - return ret - - def __repr__(self): - return "Instantiated {}".format( - super(InstantiatedMethod, self).__repr__()) - - -class InstantiatedClass(parser.Class): - """ - Instantiate the class defined in the interface file. - """ - def __init__(self, original: parser.Class, instantiations=(), new_name=''): - """ - Template - Instantiations: [T1, U1] - """ - self.original = original - self.instantiations = instantiations - - self.template = None - self.is_virtual = original.is_virtual - self.parent_class = original.parent_class - self.parent = original.parent - - # If the class is templated, check if the number of provided instantiations - # match the number of templates, else it's only a partial instantiation which is bad. - if original.template: - assert len(original.template.typenames) == len( - instantiations), "Typenames and instantiations mismatch!" - - # Get the instantiated name of the class. E.g. FuncDouble - self.name = instantiate_name( - original.name, instantiations) if not new_name else new_name - - # Check for typenames if templated. - # By passing in typenames, we can gracefully handle both templated and non-templated classes - # This will allow the `This` keyword to be used in both templated and non-templated classes. - typenames = self.original.template.typenames if self.original.template else [] - - # Instantiate the constructors, static methods, properties, respectively. - self.ctors = self.instantiate_ctors(typenames) - self.static_methods = self.instantiate_static_methods(typenames) - self.properties = self.instantiate_properties(typenames) - - # Instantiate all operator overloads - self.operators = self.instantiate_operators(typenames) - - # Set enums - self.enums = original.enums - - # Instantiate all instance methods - instantiated_methods = \ - self.instantiate_class_templates_in_methods(typenames) - - # Second instantiation round to instantiate templated methods. - # This is done in case both the class and the method are templated. - self.methods = [] - for method in instantiated_methods: - if not method.template: - self.methods.append(InstantiatedMethod(method, ())) - else: - instantiations = [] - # Get all combinations of template parameters - for instantiations in itertools.product( - *method.template.instantiations): - self.methods.append( - InstantiatedMethod(method, instantiations)) - - super().__init__( - self.template, - self.is_virtual, - self.name, - [self.parent_class], - self.ctors, - self.methods, - self.static_methods, - self.properties, - self.operators, - self.enums, - parent=self.parent, - ) - - def __repr__(self): - return "{virtual}Class {cpp_class} : {parent_class}\n"\ - "{ctors}\n{static_methods}\n{methods}\n{operators}".format( - virtual="virtual " if self.is_virtual else '', - cpp_class=self.to_cpp(), - parent_class=self.parent, - ctors="\n".join([repr(ctor) for ctor in self.ctors]), - static_methods="\n".join([repr(m) - for m in self.static_methods]), - methods="\n".join([repr(m) for m in self.methods]), - operators="\n".join([repr(op) for op in self.operators]) - ) - - def instantiate_ctors(self, typenames): - """ - Instantiate the class constructors. - - Args: - typenames: List of template types to instantiate. - - Return: List of constructors instantiated with provided template args. - """ - instantiated_ctors = [] - - def instantiate(instantiated_ctors, ctor, typenames, instantiations): - instantiated_args = instantiate_args_list( - ctor.args.list(), - typenames, - instantiations, - self.cpp_typename(), - ) - instantiated_ctors.append( - parser.Constructor( - name=self.name, - args=parser.ArgumentList(instantiated_args), - template=self.original.template, - parent=self, - )) - return instantiated_ctors - - for ctor in self.original.ctors: - # Add constructor templates to the typenames and instantiations - if isinstance(ctor.template, parser.template.Template): - typenames.extend(ctor.template.typenames) - - # Get all combinations of template args - for instantiations in itertools.product( - *ctor.template.instantiations): - instantiations = self.instantiations + list(instantiations) - - instantiated_ctors = instantiate( - instantiated_ctors, - ctor, - typenames=typenames, - instantiations=instantiations) - - else: - # If no constructor level templates, just use the class templates - instantiated_ctors = instantiate( - instantiated_ctors, - ctor, - typenames=typenames, - instantiations=self.instantiations) - return instantiated_ctors - - def instantiate_static_methods(self, typenames): - """ - Instantiate static methods in the class. - - Args: - typenames: List of template types to instantiate. - - Return: List of static methods instantiated with provided template args. - """ - instantiated_static_methods = [] - for static_method in self.original.static_methods: - instantiated_args = instantiate_args_list( - static_method.args.list(), typenames, self.instantiations, - self.cpp_typename()) - instantiated_static_methods.append( - parser.StaticMethod( - name=static_method.name, - return_type=instantiate_return_type( - static_method.return_type, - typenames, - self.instantiations, - self.cpp_typename(), - instantiated_class=self), - args=parser.ArgumentList(instantiated_args), - parent=self, - )) - return instantiated_static_methods - - def instantiate_class_templates_in_methods(self, typenames): - """ - This function only instantiates the class-level templates in the methods. - Template methods are instantiated in InstantiatedMethod in the second - round. - - E.g. - ``` - template - class Greeter{ - void sayHello(T& name); - }; - - Args: - typenames: List of template types to instantiate. - - Return: List of methods instantiated with provided template args on the class. - """ - class_instantiated_methods = [] - for method in self.original.methods: - instantiated_args = instantiate_args_list( - method.args.list(), - typenames, - self.instantiations, - self.cpp_typename(), - ) - class_instantiated_methods.append( - parser.Method( - template=method.template, - name=method.name, - return_type=instantiate_return_type( - method.return_type, - typenames, - self.instantiations, - self.cpp_typename(), - ), - args=parser.ArgumentList(instantiated_args), - is_const=method.is_const, - parent=self, - )) - return class_instantiated_methods - - def instantiate_operators(self, typenames): - """ - Instantiate the class-level template in the operator overload. - - Args: - typenames: List of template types to instantiate. - - Return: List of methods instantiated with provided template args on the class. - """ - instantiated_operators = [] - for operator in self.original.operators: - instantiated_args = instantiate_args_list( - operator.args.list(), - typenames, - self.instantiations, - self.cpp_typename(), - ) - instantiated_operators.append( - parser.Operator( - name=operator.name, - operator=operator.operator, - return_type=instantiate_return_type( - operator.return_type, - typenames, - self.instantiations, - self.cpp_typename(), - ), - args=parser.ArgumentList(instantiated_args), - is_const=operator.is_const, - parent=self, - )) - return instantiated_operators - - def instantiate_properties(self, typenames): - """ - Instantiate the class properties. - - Args: - typenames: List of template types to instantiate. - - Return: List of properties instantiated with provided template args. - """ - instantiated_properties = instantiate_args_list( - self.original.properties, - typenames, - self.instantiations, - self.cpp_typename(), - ) - return instantiated_properties - - def cpp_typename(self): - """ - Return a parser.Typename including namespaces and cpp name of this - class. - """ - if self.original.template: - name = "{}<{}>".format( - self.original.name, - ", ".join([inst.to_cpp() for inst in self.instantiations])) - else: - name = self.original.name - namespaces_name = self.namespaces() - namespaces_name.append(name) - return parser.Typename(namespaces_name) - - def to_cpp(self): - """Generate the C++ code for wrapping.""" - return self.cpp_typename().to_cpp() - - -class InstantiatedDeclaration(parser.ForwardDeclaration): - """ - Instantiate typedefs of forward declarations. - This is useful when we wish to typedef a templated class - which is not defined in the current project. - - E.g. - class FactorFromAnotherMother; - - typedef FactorFromAnotherMother FactorWeCanUse; - """ - def __init__(self, original, instantiations=(), new_name=''): - super().__init__(original.typename, - original.parent_type, - original.is_virtual, - parent=original.parent) - - self.original = original - self.instantiations = instantiations - self.parent = original.parent - - self.name = instantiate_name( - original.name, instantiations) if not new_name else new_name - - def to_cpp(self): - """Generate the C++ code for wrapping.""" - instantiated_names = [ - inst.qualified_name() for inst in self.instantiations - ] - name = "{}<{}>".format(self.original.name, - ",".join(instantiated_names)) - namespaces_name = self.namespaces() - namespaces_name.append(name) - # Leverage Typename to generate the fully qualified C++ name - return parser.Typename(namespaces_name).to_cpp() - - def __repr__(self): - return "Instantiated {}".format( - super(InstantiatedDeclaration, self).__repr__()) - - -def instantiate_namespace(namespace): - """ - Instantiate the classes and other elements in the `namespace` content and - assign it back to the namespace content attribute. - - @param[in/out] namespace The namespace whose content will be replaced with - the instantiated content. - """ - instantiated_content = [] - typedef_content = [] - - for element in namespace.content: - if isinstance(element, parser.Class): - original_class = element - if not original_class.template: - instantiated_content.append( - InstantiatedClass(original_class, [])) - else: - # This case is for when the templates have enumerated instantiations. - - # Use itertools to get all possible combinations of instantiations - # Works even if one template does not have an instantiation list - for instantiations in itertools.product( - *original_class.template.instantiations): - instantiated_content.append( - InstantiatedClass(original_class, - list(instantiations))) - - elif isinstance(element, parser.GlobalFunction): - original_func = element - if not original_func.template: - instantiated_content.append( - InstantiatedGlobalFunction(original_func, [])) - else: - # Use itertools to get all possible combinations of instantiations - # Works even if one template does not have an instantiation list - for instantiations in itertools.product( - *original_func.template.instantiations): - instantiated_content.append( - InstantiatedGlobalFunction(original_func, - list(instantiations))) - - elif isinstance(element, parser.TypedefTemplateInstantiation): - # This is for the case where `typedef` statements are used - # to specify the template parameters. - typedef_inst = element - top_level = namespace.top_level() - original_element = top_level.find_class_or_function( - typedef_inst.typename) - - # Check if element is a typedef'd class, function or - # forward declaration from another project. - if isinstance(original_element, parser.Class): - typedef_content.append( - InstantiatedClass(original_element, - typedef_inst.typename.instantiations, - typedef_inst.new_name)) - elif isinstance(original_element, parser.GlobalFunction): - typedef_content.append( - InstantiatedGlobalFunction( - original_element, typedef_inst.typename.instantiations, - typedef_inst.new_name)) - elif isinstance(original_element, parser.ForwardDeclaration): - typedef_content.append( - InstantiatedDeclaration( - original_element, typedef_inst.typename.instantiations, - typedef_inst.new_name)) - - elif isinstance(element, parser.Namespace): - element = instantiate_namespace(element) - instantiated_content.append(element) - else: - instantiated_content.append(element) - - instantiated_content.extend(typedef_content) - namespace.content = instantiated_content - - return namespace diff --git a/wrap/gtwrap/template_instantiator/__init__.py b/wrap/gtwrap/template_instantiator/__init__.py new file mode 100644 index 000000000..6a30bb3c3 --- /dev/null +++ b/wrap/gtwrap/template_instantiator/__init__.py @@ -0,0 +1,14 @@ +"""Code to help instantiate templated classes, methods and functions.""" + +# pylint: disable=too-many-arguments, too-many-instance-attributes, no-self-use, no-else-return, too-many-arguments, unused-format-string-argument, unused-variable. unused-argument, too-many-branches + +from typing import Iterable, Sequence, Union + +import gtwrap.interface_parser as parser +from gtwrap.template_instantiator.classes import * +from gtwrap.template_instantiator.constructor import * +from gtwrap.template_instantiator.declaration import * +from gtwrap.template_instantiator.function import * +from gtwrap.template_instantiator.helpers import * +from gtwrap.template_instantiator.method import * +from gtwrap.template_instantiator.namespace import * diff --git a/wrap/gtwrap/template_instantiator/classes.py b/wrap/gtwrap/template_instantiator/classes.py new file mode 100644 index 000000000..af366f80f --- /dev/null +++ b/wrap/gtwrap/template_instantiator/classes.py @@ -0,0 +1,206 @@ +"""Instantiate a class and its members.""" + +import gtwrap.interface_parser as parser +from gtwrap.template_instantiator.constructor import InstantiatedConstructor +from gtwrap.template_instantiator.helpers import (InstantiationHelper, + instantiate_args_list, + instantiate_name, + instantiate_return_type) +from gtwrap.template_instantiator.method import (InstantiatedMethod, + InstantiatedStaticMethod) + + +class InstantiatedClass(parser.Class): + """ + Instantiate the class defined in the interface file. + """ + def __init__(self, original: parser.Class, instantiations=(), new_name=''): + """ + Template + Instantiations: [T1, U1] + """ + self.original = original + self.instantiations = instantiations + + self.template = None + self.is_virtual = original.is_virtual + self.parent_class = original.parent_class + self.parent = original.parent + + # If the class is templated, check if the number of provided instantiations + # match the number of templates, else it's only a partial instantiation which is bad. + if original.template: + assert len(original.template.typenames) == len( + instantiations), "Typenames and instantiations mismatch!" + + # Get the instantiated name of the class. E.g. FuncDouble + self.name = instantiate_name( + original.name, instantiations) if not new_name else new_name + + # Check for typenames if templated. + # By passing in typenames, we can gracefully handle both templated and non-templated classes + # This will allow the `This` keyword to be used in both templated and non-templated classes. + typenames = self.original.template.typenames if self.original.template else [] + + # Instantiate the constructors, static methods, properties, respectively. + self.ctors = self.instantiate_ctors(typenames) + self.static_methods = self.instantiate_static_methods(typenames) + self.properties = self.instantiate_properties(typenames) + + # Instantiate all operator overloads + self.operators = self.instantiate_operators(typenames) + + # Set enums + self.enums = original.enums + + # Instantiate all instance methods + self.methods = self.instantiate_methods(typenames) + + super().__init__( + self.template, + self.is_virtual, + self.name, + [self.parent_class], + self.ctors, + self.methods, + self.static_methods, + self.properties, + self.operators, + self.enums, + parent=self.parent, + ) + + def __repr__(self): + return "{virtual}Class {cpp_class} : {parent_class}\n"\ + "{ctors}\n{static_methods}\n{methods}\n{operators}".format( + virtual="virtual " if self.is_virtual else '', + cpp_class=self.to_cpp(), + parent_class=self.parent, + ctors="\n".join([repr(ctor) for ctor in self.ctors]), + static_methods="\n".join([repr(m) + for m in self.static_methods]), + methods="\n".join([repr(m) for m in self.methods]), + operators="\n".join([repr(op) for op in self.operators]) + ) + + def instantiate_ctors(self, typenames): + """ + Instantiate the class constructors. + + Args: + typenames: List of template types to instantiate. + + Return: List of constructors instantiated with provided template args. + """ + + helper = InstantiationHelper( + instantiation_type=InstantiatedConstructor) + + instantiated_ctors = helper.multilevel_instantiation( + self.original.ctors, typenames, self) + + return instantiated_ctors + + def instantiate_static_methods(self, typenames): + """ + Instantiate static methods in the class. + + Args: + typenames: List of template types to instantiate. + + Return: List of static methods instantiated with provided template args. + """ + helper = InstantiationHelper( + instantiation_type=InstantiatedStaticMethod) + + instantiated_static_methods = helper.multilevel_instantiation( + self.original.static_methods, typenames, self) + + return instantiated_static_methods + + def instantiate_methods(self, typenames): + """ + Instantiate regular methods in the class. + + Args: + typenames: List of template types to instantiate. + + Return: List of methods instantiated with provided template args. + """ + instantiated_methods = [] + + helper = InstantiationHelper(instantiation_type=InstantiatedMethod) + + instantiated_methods = helper.multilevel_instantiation( + self.original.methods, typenames, self) + + return instantiated_methods + + def instantiate_operators(self, typenames): + """ + Instantiate the class-level template in the operator overload. + + Args: + typenames: List of template types to instantiate. + + Return: List of methods instantiated with provided template args on the class. + """ + instantiated_operators = [] + for operator in self.original.operators: + instantiated_args = instantiate_args_list( + operator.args.list(), + typenames, + self.instantiations, + self.cpp_typename(), + ) + instantiated_operators.append( + parser.Operator( + name=operator.name, + operator=operator.operator, + return_type=instantiate_return_type( + operator.return_type, + typenames, + self.instantiations, + self.cpp_typename(), + ), + args=parser.ArgumentList(instantiated_args), + is_const=operator.is_const, + parent=self, + )) + return instantiated_operators + + def instantiate_properties(self, typenames): + """ + Instantiate the class properties. + + Args: + typenames: List of template types to instantiate. + + Return: List of properties instantiated with provided template args. + """ + instantiated_properties = instantiate_args_list( + self.original.properties, + typenames, + self.instantiations, + self.cpp_typename(), + ) + return instantiated_properties + + def cpp_typename(self): + """ + Return a parser.Typename including namespaces and cpp name of this + class. + """ + if self.original.template: + name = "{}<{}>".format( + self.original.name, + ", ".join([inst.to_cpp() for inst in self.instantiations])) + else: + name = self.original.name + namespaces_name = self.namespaces() + namespaces_name.append(name) + return parser.Typename(namespaces_name) + + def to_cpp(self): + """Generate the C++ code for wrapping.""" + return self.cpp_typename().to_cpp() diff --git a/wrap/gtwrap/template_instantiator/constructor.py b/wrap/gtwrap/template_instantiator/constructor.py new file mode 100644 index 000000000..1ea7d22d5 --- /dev/null +++ b/wrap/gtwrap/template_instantiator/constructor.py @@ -0,0 +1,64 @@ +"""Class constructor instantiator.""" + +# pylint: disable=unused-argument + +from typing import Iterable, List + +import gtwrap.interface_parser as parser + + +class InstantiatedConstructor(parser.Constructor): + """ + Instantiate constructor with template parameters. + + E.g. + class A { + template + A(X x, Y y); + } + """ + def __init__(self, + original: parser.Constructor, + instantiations: Iterable[parser.Typename] = ()): + self.original = original + self.instantiations = instantiations + self.name = original.name + self.args = original.args + self.template = original.template + self.parent = original.parent + + super().__init__(self.name, + self.args, + self.template, + parent=self.parent) + + @classmethod + def construct(cls, original: parser.Constructor, typenames: List[str], + class_instantiations: List[parser.Typename], + method_instantiations: List[parser.Typename], + instantiated_args: List[parser.Argument], + parent: 'InstantiatedClass'): + """Class method to construct object as required by InstantiationHelper.""" + method = parser.Constructor( + name=parent.name, + args=parser.ArgumentList(instantiated_args), + template=original.template, + parent=parent, + ) + return InstantiatedConstructor(method, + instantiations=method_instantiations) + + def to_cpp(self): + """Generate the C++ code for wrapping.""" + if self.original.template: + # to_cpp will handle all the namespacing and templating + instantiation_list = [x.to_cpp() for x in self.instantiations] + # now can simply combine the instantiations, separated by commas + ret = "{}<{}>".format(self.original.name, + ",".join(instantiation_list)) + else: + ret = self.original.name + return ret + + def __repr__(self): + return "Instantiated {}".format(super().__repr__()) diff --git a/wrap/gtwrap/template_instantiator/declaration.py b/wrap/gtwrap/template_instantiator/declaration.py new file mode 100644 index 000000000..4fa6b75d8 --- /dev/null +++ b/wrap/gtwrap/template_instantiator/declaration.py @@ -0,0 +1,45 @@ +"""Instantiate a forward declaration.""" + +import gtwrap.interface_parser as parser +from gtwrap.template_instantiator.helpers import instantiate_name + + +class InstantiatedDeclaration(parser.ForwardDeclaration): + """ + Instantiate typedefs of forward declarations. + This is useful when we wish to typedef a templated class + which is not defined in the current project. + + E.g. + class FactorFromAnotherMother; + + typedef FactorFromAnotherMother FactorWeCanUse; + """ + def __init__(self, original, instantiations=(), new_name=''): + super().__init__(original.typename, + original.parent_type, + original.is_virtual, + parent=original.parent) + + self.original = original + self.instantiations = instantiations + self.parent = original.parent + + self.name = instantiate_name( + original.name, instantiations) if not new_name else new_name + + def to_cpp(self): + """Generate the C++ code for wrapping.""" + instantiated_names = [ + inst.qualified_name() for inst in self.instantiations + ] + name = "{}<{}>".format(self.original.name, + ",".join(instantiated_names)) + namespaces_name = self.namespaces() + namespaces_name.append(name) + # Leverage Typename to generate the fully qualified C++ name + return parser.Typename(namespaces_name).to_cpp() + + def __repr__(self): + return "Instantiated {}".format( + super(InstantiatedDeclaration, self).__repr__()) diff --git a/wrap/gtwrap/template_instantiator/function.py b/wrap/gtwrap/template_instantiator/function.py new file mode 100644 index 000000000..3ad5da3f4 --- /dev/null +++ b/wrap/gtwrap/template_instantiator/function.py @@ -0,0 +1,68 @@ +"""Instantiate global function.""" + +import gtwrap.interface_parser as parser +from gtwrap.template_instantiator.helpers import (instantiate_args_list, + instantiate_name, + instantiate_return_type) + + +class InstantiatedGlobalFunction(parser.GlobalFunction): + """ + Instantiate global functions. + + E.g. + template + T add(const T& x, const T& y); + """ + def __init__(self, original, instantiations=(), new_name=''): + self.original = original + self.instantiations = instantiations + self.template = '' + self.parent = original.parent + + if not original.template: + self.name = original.name + self.return_type = original.return_type + self.args = original.args + else: + self.name = instantiate_name( + original.name, instantiations) if not new_name else new_name + self.return_type = instantiate_return_type( + original.return_type, + self.original.template.typenames, + self.instantiations, + # Keyword type name `This` should already be replaced in the + # previous class template instantiation round. + cpp_typename='', + ) + instantiated_args = instantiate_args_list( + original.args.list(), + self.original.template.typenames, + self.instantiations, + # Keyword type name `This` should already be replaced in the + # previous class template instantiation round. + cpp_typename='', + ) + self.args = parser.ArgumentList(instantiated_args) + + super().__init__(self.name, + self.return_type, + self.args, + self.template, + parent=self.parent) + + def to_cpp(self): + """Generate the C++ code for wrapping.""" + if self.original.template: + instantiated_names = [ + inst.instantiated_name() for inst in self.instantiations + ] + ret = "{}<{}>".format(self.original.name, + ",".join(instantiated_names)) + else: + ret = self.original.name + return ret + + def __repr__(self): + return "Instantiated {}".format( + super(InstantiatedGlobalFunction, self).__repr__()) diff --git a/wrap/gtwrap/template_instantiator/helpers.py b/wrap/gtwrap/template_instantiator/helpers.py new file mode 100644 index 000000000..194c6f686 --- /dev/null +++ b/wrap/gtwrap/template_instantiator/helpers.py @@ -0,0 +1,300 @@ +"""Various helpers for instantiation.""" + +import itertools +from copy import deepcopy +from typing import List, Sequence, Union + +import gtwrap.interface_parser as parser + +ClassMembers = Union[parser.Constructor, parser.Method, parser.StaticMethod, + parser.GlobalFunction, parser.Operator, parser.Variable, + parser.Enum] +InstantiatedMembers = Union['InstantiatedConstructor', 'InstantiatedMethod', + 'InstantiatedStaticMethod', + 'InstantiatedGlobalFunction'] + + +def is_scoped_template(template_typenames: Sequence[str], + str_arg_typename: str): + """ + Check if the template given by `str_arg_typename` is a scoped template e.g. T::Value, + and if so, return what template from `template_typenames` and + the corresponding index matches the scoped template correctly. + """ + for idx, template in enumerate(template_typenames): + if "::" in str_arg_typename and \ + template in str_arg_typename.split("::"): + return template, idx + return False, -1 + + +def instantiate_type( + ctype: parser.Type, + template_typenames: Sequence[str], + instantiations: Sequence[parser.Typename], + cpp_typename: parser.Typename, + instantiated_class: 'InstantiatedClass' = None) -> parser.Type: + """ + Instantiate template typename for `ctype`. + + Args: + ctype: The original argument type. + template_typenames: List of strings representing the templates. + instantiations: List of the instantiations of the templates in `template_typenames`. + cpp_typename: Full-namespace cpp class name of this instantiation + to replace for arguments of type named `This`. + instiated_class: The instantiated class which, if provided, + will be used for instantiating `This`. + + Returns: + If `ctype`'s name is in the `template_typenames`, return the + corresponding type to replace in `instantiations`. + If ctype name is `This`, return the new typename `cpp_typename`. + Otherwise, return the original ctype. + """ + # make a deep copy so that there is no overwriting of original template params + ctype = deepcopy(ctype) + + # Check if the return type has template parameters as the typename's name + if ctype.typename.instantiations: + for idx, instantiation in enumerate(ctype.typename.instantiations): + if instantiation.name in template_typenames: + template_idx = template_typenames.index(instantiation.name) + ctype.typename.instantiations[idx].name =\ + instantiations[template_idx] + + + str_arg_typename = str(ctype.typename) + + # Check if template is a scoped template e.g. T::Value where T is the template + scoped_template, scoped_idx = is_scoped_template(template_typenames, + str_arg_typename) + + # Instantiate templates which have enumerated instantiations in the template. + # E.g. `template`. + + # Instantiate scoped templates, e.g. T::Value. + if scoped_template: + # Create a copy of the instantiation so we can modify it. + instantiation = deepcopy(instantiations[scoped_idx]) + # Replace the part of the template with the instantiation + instantiation.name = str_arg_typename.replace(scoped_template, + instantiation.name) + return parser.Type( + typename=instantiation, + is_const=ctype.is_const, + is_shared_ptr=ctype.is_shared_ptr, + is_ptr=ctype.is_ptr, + is_ref=ctype.is_ref, + is_basic=ctype.is_basic, + ) + # Check for exact template match. + elif str_arg_typename in template_typenames: + idx = template_typenames.index(str_arg_typename) + return parser.Type( + typename=instantiations[idx], + is_const=ctype.is_const, + is_shared_ptr=ctype.is_shared_ptr, + is_ptr=ctype.is_ptr, + is_ref=ctype.is_ref, + is_basic=ctype.is_basic, + ) + + # If a method has the keyword `This`, we replace it with the (instantiated) class. + elif str_arg_typename == 'This': + # Check if the class is template instantiated + # so we can replace it with the instantiated version. + if instantiated_class: + name = instantiated_class.original.name + namespaces_name = instantiated_class.namespaces() + namespaces_name.append(name) + cpp_typename = parser.Typename( + namespaces_name, + instantiations=instantiated_class.instantiations) + + return parser.Type( + typename=cpp_typename, + is_const=ctype.is_const, + is_shared_ptr=ctype.is_shared_ptr, + is_ptr=ctype.is_ptr, + is_ref=ctype.is_ref, + is_basic=ctype.is_basic, + ) + + # Case when 'This' is present in the type namespace, e.g `This::Subclass`. + elif 'This' in str_arg_typename: + # Check if `This` is in the namespaces + if 'This' in ctype.typename.namespaces: + # Simply get the index of `This` in the namespace and + # replace it with the instantiated name. + namespace_idx = ctype.typename.namespaces.index('This') + ctype.typename.namespaces[namespace_idx] = cpp_typename.name + # Else check if it is in the template namespace, e.g vector + else: + for idx, instantiation in enumerate(ctype.typename.instantiations): + if 'This' in instantiation.namespaces: + ctype.typename.instantiations[idx].namespaces = \ + cpp_typename.namespaces + [cpp_typename.name] + return ctype + + else: + return ctype + + +def instantiate_args_list( + args_list: Sequence[parser.Argument], + template_typenames: Sequence[parser.template.Typename], + instantiations: Sequence, cpp_typename: parser.Typename): + """ + Instantiate template typenames in an argument list. + Type with name `This` will be replaced by @p `cpp_typename`. + + @param[in] args_list A list of `parser.Argument` to instantiate. + @param[in] template_typenames List of template typenames to instantiate, + e.g. ['T', 'U', 'V']. + @param[in] instantiations List of specific types to instantiate, each + associated with each template typename. Each type is a parser.Typename, + including its name and full namespaces. + @param[in] cpp_typename Full-namespace cpp class name of this instantiation + to replace for arguments of type named `This`. + @return A new list of parser.Argument which types are replaced with their + instantiations. + """ + instantiated_args = [] + for arg in args_list: + new_type = instantiate_type(arg.ctype, template_typenames, + instantiations, cpp_typename) + instantiated_args.append( + parser.Argument(name=arg.name, ctype=new_type, + default=arg.default)) + return instantiated_args + + +def instantiate_return_type( + return_type: parser.ReturnType, + template_typenames: Sequence[parser.template.Typename], + instantiations: Sequence[parser.Typename], + cpp_typename: parser.Typename, + instantiated_class: 'InstantiatedClass' = None): + """Instantiate the return type.""" + new_type1 = instantiate_type(return_type.type1, + template_typenames, + instantiations, + cpp_typename, + instantiated_class=instantiated_class) + if return_type.type2: + new_type2 = instantiate_type(return_type.type2, + template_typenames, + instantiations, + cpp_typename, + instantiated_class=instantiated_class) + else: + new_type2 = '' + return parser.ReturnType(new_type1, new_type2) + + +def instantiate_name(original_name: str, + instantiations: Sequence[parser.Typename]): + """ + Concatenate instantiated types with `original_name` to form a new + instantiated name. + + NOTE: To avoid conflicts, we should include the instantiation's + namespaces, but that is too verbose. + """ + instantiated_names = [] + for inst in instantiations: + # Ensure the first character of the type is capitalized + name = inst.instantiated_name() + # Using `capitalize` on the complete name causes other caps to be lower case + instantiated_names.append(name.replace(name[0], name[0].capitalize())) + + return "{}{}".format(original_name, "".join(instantiated_names)) + + +class InstantiationHelper: + """ + Helper class for instantiation templates. + Requires that `instantiation_type` defines a class method called + `construct` to generate the appropriate object type. + + Signature for `construct` should be + ``` + construct(method, + typenames, + class_instantiations, + method_instantiations, + instantiated_args, + parent=parent) + ``` + """ + def __init__(self, instantiation_type: InstantiatedMembers): + self.instantiation_type = instantiation_type + + def instantiate(self, instantiated_methods: List[InstantiatedMembers], + method: ClassMembers, typenames: Sequence[str], + class_instantiations: Sequence[parser.Typename], + method_instantiations: Sequence[parser.Typename], + parent: 'InstantiatedClass'): + """ + Instantiate both the class and method level templates. + """ + instantiations = class_instantiations + method_instantiations + + instantiated_args = instantiate_args_list(method.args.list(), + typenames, instantiations, + parent.cpp_typename()) + + instantiated_methods.append( + self.instantiation_type.construct(method, + typenames, + class_instantiations, + method_instantiations, + instantiated_args, + parent=parent)) + + return instantiated_methods + + def multilevel_instantiation(self, methods_list: Sequence[ClassMembers], + typenames: Sequence[str], + parent: 'InstantiatedClass'): + """ + Helper to instantiate methods at both the class and method level. + + Args: + methods_list: The list of methods in the class to instantiated. + typenames: List of class level template parameters, e.g. ['T']. + parent: The instantiated class to which `methods_list` belongs. + """ + instantiated_methods = [] + + for method in methods_list: + # We creare a copy since we will modify the typenames list. + method_typenames = deepcopy(typenames) + + if isinstance(method.template, parser.template.Template): + method_typenames.extend(method.template.typenames) + + # Get all combinations of template args + for instantiations in itertools.product( + *method.template.instantiations): + + instantiated_methods = self.instantiate( + instantiated_methods, + method, + typenames=method_typenames, + class_instantiations=parent.instantiations, + method_instantiations=list(instantiations), + parent=parent) + + else: + # If no constructor level templates, just use the class templates + instantiated_methods = self.instantiate( + instantiated_methods, + method, + typenames=method_typenames, + class_instantiations=parent.instantiations, + method_instantiations=[], + parent=parent) + + return instantiated_methods diff --git a/wrap/gtwrap/template_instantiator/method.py b/wrap/gtwrap/template_instantiator/method.py new file mode 100644 index 000000000..cd0a09c90 --- /dev/null +++ b/wrap/gtwrap/template_instantiator/method.py @@ -0,0 +1,124 @@ +"""Class method and static method instantiators.""" + +from typing import Iterable + +import gtwrap.interface_parser as parser +from gtwrap.template_instantiator.helpers import (instantiate_name, + instantiate_return_type) + + +class InstantiatedMethod(parser.Method): + """ + Instantiate method with template parameters. + + E.g. + class A { + template + void func(X x, Y y); + } + """ + def __init__(self, + original: parser.Method, + instantiations: Iterable[parser.Typename] = ()): + self.original = original + self.instantiations = instantiations + self.template = original.template + self.is_const = original.is_const + self.parent = original.parent + + self.name = instantiate_name(original.name, self.instantiations) + self.return_type = original.return_type + self.args = original.args + + super().__init__(self.template, + self.name, + self.return_type, + self.args, + self.is_const, + parent=self.parent) + + @classmethod + def construct(cls, original, typenames, class_instantiations, + method_instantiations, instantiated_args, parent): + """Class method to construct object as required by InstantiationHelper.""" + method = parser.Method( + template=original.template, + name=original.name, + return_type=instantiate_return_type( + original.return_type, typenames, + class_instantiations + method_instantiations, + parent.cpp_typename()), + args=parser.ArgumentList(instantiated_args), + is_const=original.is_const, + parent=parent, + ) + return InstantiatedMethod(method, instantiations=method_instantiations) + + def to_cpp(self): + """Generate the C++ code for wrapping.""" + if self.original.template: + # to_cpp will handle all the namespacing and templating + instantiation_list = [x.to_cpp() for x in self.instantiations] + # now can simply combine the instantiations, separated by commas + ret = "{}<{}>".format(self.original.name, + ",".join(instantiation_list)) + else: + ret = self.original.name + return ret + + def __repr__(self): + return "Instantiated {}".format(super().__repr__()) + + +class InstantiatedStaticMethod(parser.StaticMethod): + """ + Instantiate static method with template parameters. + """ + def __init__(self, + original: parser.StaticMethod, + instantiations: Iterable[parser.Typename] = ()): + self.original = original + self.instantiations = instantiations + + self.name = instantiate_name(original.name, self.instantiations) + self.return_type = original.return_type + self.args = original.args + self.template = original.template + self.parent = original.parent + + super().__init__(self.name, self.return_type, self.args, self.template, + self.parent) + + @classmethod + def construct(cls, original, typenames, class_instantiations, + method_instantiations, instantiated_args, parent): + """Class method to construct object as required by InstantiationHelper.""" + method = parser.StaticMethod( + name=original.name, + return_type=instantiate_return_type(original.return_type, + typenames, + class_instantiations + + method_instantiations, + parent.cpp_typename(), + instantiated_class=parent), + args=parser.ArgumentList(instantiated_args), + template=original.template, + parent=parent, + ) + return InstantiatedStaticMethod(method, + instantiations=method_instantiations) + + def to_cpp(self): + """Generate the C++ code for wrapping.""" + if self.original.template: + # to_cpp will handle all the namespacing and templating + instantiation_list = [x.to_cpp() for x in self.instantiations] + # now can simply combine the instantiations, separated by commas + ret = "{}<{}>".format(self.original.name, + ",".join(instantiation_list)) + else: + ret = self.original.name + return ret + + def __repr__(self): + return "Instantiated {}".format(super().__repr__()) diff --git a/wrap/gtwrap/template_instantiator/namespace.py b/wrap/gtwrap/template_instantiator/namespace.py new file mode 100644 index 000000000..32ba0b95d --- /dev/null +++ b/wrap/gtwrap/template_instantiator/namespace.py @@ -0,0 +1,88 @@ +"""Instantiate a namespace.""" + +import itertools + +import gtwrap.interface_parser as parser +from gtwrap.template_instantiator.classes import InstantiatedClass +from gtwrap.template_instantiator.declaration import InstantiatedDeclaration +from gtwrap.template_instantiator.function import InstantiatedGlobalFunction + + +def instantiate_namespace(namespace): + """ + Instantiate the classes and other elements in the `namespace` content and + assign it back to the namespace content attribute. + + @param[in/out] namespace The namespace whose content will be replaced with + the instantiated content. + """ + instantiated_content = [] + typedef_content = [] + + for element in namespace.content: + if isinstance(element, parser.Class): + original_class = element + if not original_class.template: + instantiated_content.append( + InstantiatedClass(original_class, [])) + else: + # This case is for when the templates have enumerated instantiations. + + # Use itertools to get all possible combinations of instantiations + # Works even if one template does not have an instantiation list + for instantiations in itertools.product( + *original_class.template.instantiations): + instantiated_content.append( + InstantiatedClass(original_class, + list(instantiations))) + + elif isinstance(element, parser.GlobalFunction): + original_func = element + if not original_func.template: + instantiated_content.append( + InstantiatedGlobalFunction(original_func, [])) + else: + # Use itertools to get all possible combinations of instantiations + # Works even if one template does not have an instantiation list + for instantiations in itertools.product( + *original_func.template.instantiations): + instantiated_content.append( + InstantiatedGlobalFunction(original_func, + list(instantiations))) + + elif isinstance(element, parser.TypedefTemplateInstantiation): + # This is for the case where `typedef` statements are used + # to specify the template parameters. + typedef_inst = element + top_level = namespace.top_level() + original_element = top_level.find_class_or_function( + typedef_inst.typename) + + # Check if element is a typedef'd class, function or + # forward declaration from another project. + if isinstance(original_element, parser.Class): + typedef_content.append( + InstantiatedClass(original_element, + typedef_inst.typename.instantiations, + typedef_inst.new_name)) + elif isinstance(original_element, parser.GlobalFunction): + typedef_content.append( + InstantiatedGlobalFunction( + original_element, typedef_inst.typename.instantiations, + typedef_inst.new_name)) + elif isinstance(original_element, parser.ForwardDeclaration): + typedef_content.append( + InstantiatedDeclaration( + original_element, typedef_inst.typename.instantiations, + typedef_inst.new_name)) + + elif isinstance(element, parser.Namespace): + element = instantiate_namespace(element) + instantiated_content.append(element) + else: + instantiated_content.append(element) + + instantiated_content.extend(typedef_content) + namespace.content = instantiated_content + + return namespace diff --git a/wrap/matlab.h b/wrap/matlab.h index bcdef3c8d..645ba8edf 100644 --- a/wrap/matlab.h +++ b/wrap/matlab.h @@ -37,15 +37,16 @@ extern "C" { #include } -#include +#include #include +#include #include -#include -#include -#include #include +#include #include +#include +#include using namespace std; using namespace boost; // not usual, but for conciseness of generated code @@ -477,6 +478,14 @@ boost::shared_ptr unwrap_shared_ptr(const mxArray* obj, const string& pro return *spp; } +template +Class* unwrap_ptr(const mxArray* obj, const string& propertyName) { + + mxArray* mxh = mxGetProperty(obj,0, propertyName.c_str()); + Class* x = reinterpret_cast (mxGetData(mxh)); + return x; +} + //// throw an error if unwrap_shared_ptr is attempted for an Eigen Vector //template <> //Vector unwrap_shared_ptr(const mxArray* obj, const string& propertyName) { diff --git a/wrap/pybind11/.appveyor.yml b/wrap/pybind11/.appveyor.yml index 149a8a3dc..85445d41a 100644 --- a/wrap/pybind11/.appveyor.yml +++ b/wrap/pybind11/.appveyor.yml @@ -19,7 +19,7 @@ install: if ($env:PLATFORM -eq "x64") { $env:PYTHON = "$env:PYTHON-x64" } $env:PATH = "C:\Python$env:PYTHON\;C:\Python$env:PYTHON\Scripts\;$env:PATH" python -W ignore -m pip install --upgrade pip wheel - python -W ignore -m pip install pytest numpy --no-warn-script-location + python -W ignore -m pip install pytest numpy --no-warn-script-location pytest-timeout - ps: | Start-FileDownload 'https://gitlab.com/libeigen/eigen/-/archive/3.3.7/eigen-3.3.7.zip' 7z x eigen-3.3.7.zip -y > $null diff --git a/wrap/pybind11/.clang-format b/wrap/pybind11/.clang-format new file mode 100644 index 000000000..8e0fd8b01 --- /dev/null +++ b/wrap/pybind11/.clang-format @@ -0,0 +1,19 @@ +--- +# See all possible options and defaults with: +# clang-format --style=llvm --dump-config +BasedOnStyle: LLVM +AccessModifierOffset: -4 +AlwaysBreakTemplateDeclarations: Yes +BinPackArguments: false +BinPackParameters: false +BreakBeforeBinaryOperators: All +BreakConstructorInitializers: BeforeColon +ColumnLimit: 99 +IndentCaseLabels: true +IndentPPDirectives: AfterHash +IndentWidth: 4 +Language: Cpp +SpaceAfterCStyleCast: true +Standard: Cpp11 +TabWidth: 4 +... diff --git a/wrap/pybind11/.clang-tidy b/wrap/pybind11/.clang-tidy index e29d92989..d853a703c 100644 --- a/wrap/pybind11/.clang-tidy +++ b/wrap/pybind11/.clang-tidy @@ -1,13 +1,66 @@ FormatStyle: file Checks: ' +*bugprone*, +cppcoreguidelines-init-variables, +cppcoreguidelines-slicing, +clang-analyzer-optin.cplusplus.VirtualCall, +google-explicit-constructor, llvm-namespace-comment, -modernize-use-override, -readability-container-size-empty, -modernize-use-using, -modernize-use-equals-default, +misc-misplaced-const, +misc-non-copyable-objects, +misc-static-assert, +misc-throw-by-value-catch-by-reference, +misc-uniqueptr-reset-release, +misc-unused-parameters, +modernize-avoid-bind, +modernize-make-shared, +modernize-redundant-void-arg, +modernize-replace-auto-ptr, +modernize-replace-disallow-copy-and-assign-macro, +modernize-replace-random-shuffle, +modernize-shrink-to-fit, modernize-use-auto, +modernize-use-bool-literals, +modernize-use-equals-default, +modernize-use-equals-delete, +modernize-use-default-member-init, +modernize-use-noexcept, modernize-use-emplace, +modernize-use-override, +modernize-use-using, +*performance*, +readability-avoid-const-params-in-decls, +readability-const-return-type, +readability-container-size-empty, +readability-delete-null-pointer, +readability-else-after-return, +readability-implicit-bool-conversion, +readability-make-member-function-const, +readability-misplaced-array-index, +readability-non-const-parameter, +readability-redundant-function-ptr-dereference, +readability-redundant-smartptr-get, +readability-redundant-string-cstr, +readability-simplify-subscript-expr, +readability-static-accessed-through-instance, +readability-static-definition-in-anonymous-namespace, +readability-string-compare, +readability-suspicious-call-argument, +readability-uniqueptr-delete-release, +-bugprone-exception-escape, +-bugprone-reserved-identifier, +-bugprone-unused-raii, ' +CheckOptions: +- key: performance-for-range-copy.WarnOnAllAutoCopies + value: true +- key: performance-unnecessary-value-param.AllowedTypes + value: 'exception_ptr$;' +- key: readability-implicit-bool-conversion.AllowPointerConditions + value: true + HeaderFilterRegex: 'pybind11/.*h' + +WarningsAsErrors: '*' diff --git a/wrap/pybind11/.github/CODEOWNERS b/wrap/pybind11/.github/CODEOWNERS new file mode 100644 index 000000000..4e2c66902 --- /dev/null +++ b/wrap/pybind11/.github/CODEOWNERS @@ -0,0 +1,9 @@ +*.cmake @henryiii +CMakeLists.txt @henryiii +*.yml @henryiii +*.yaml @henryiii +/tools/ @henryiii +/pybind11/ @henryiii +noxfile.py @henryiii +.clang-format @henryiii +.clang-tidy @henryiii diff --git a/wrap/pybind11/.github/CONTRIBUTING.md b/wrap/pybind11/.github/CONTRIBUTING.md index 4ced21baa..e8294c83c 100644 --- a/wrap/pybind11/.github/CONTRIBUTING.md +++ b/wrap/pybind11/.github/CONTRIBUTING.md @@ -53,6 +53,33 @@ derivative works thereof, in binary and source code form. ## Development of pybind11 +### Quick setup + +To setup a quick development environment, use [`nox`](https://nox.thea.codes). +This will allow you to do some common tasks with minimal setup effort, but will +take more time to run and be less flexible than a full development environment. +If you use [`pipx run nox`](https://pipx.pypa.io), you don't even need to +install `nox`. Examples: + +```bash +# List all available sessions +nox -l + +# Run linters +nox -s lint + +# Run tests on Python 3.9 +nox -s tests-3.9 + +# Build and preview docs +nox -s docs -- serve + +# Build SDists and wheels +nox -s build +``` + +### Full setup + To setup an ideal development environment, run the following commands on a system with CMake 3.14+: @@ -93,7 +120,7 @@ The valid options are: * `-DPYBIND11_NOPYTHON=ON`: Disable all Python searching (disables tests) * `-DBUILD_TESTING=ON`: Enable the tests * `-DDOWNLOAD_CATCH=ON`: Download catch to build the C++ tests -* `-DOWNLOAD_EIGEN=ON`: Download Eigen for the NumPy tests +* `-DDOWNLOAD_EIGEN=ON`: Download Eigen for the NumPy tests * `-DPYBIND11_INSTALL=ON/OFF`: Enable the install target (on by default for the master project) * `-DUSE_PYTHON_INSTALL_DIR=ON`: Try to install into the python dir @@ -126,13 +153,26 @@ cmake --build build --target check `--target` can be spelled `-t` in CMake 3.15+. You can also run individual tests with these targets: -* `pytest`: Python tests only +* `pytest`: Python tests only, using the +[pytest](https://docs.pytest.org/en/stable/) framework * `cpptest`: C++ tests only * `test_cmake_build`: Install / subdirectory tests If you want to build just a subset of tests, use -`-DPYBIND11_TEST_OVERRIDE="test_callbacks.cpp;test_pickling.cpp"`. If this is -empty, all tests will be built. +`-DPYBIND11_TEST_OVERRIDE="test_callbacks;test_pickling"`. If this is +empty, all tests will be built. Tests are specified without an extension if they need both a .py and +.cpp file. + +You may also pass flags to the `pytest` target by editing `tests/pytest.ini` or +by using the `PYTEST_ADDOPTS` environment variable +(see [`pytest` docs](https://docs.pytest.org/en/2.7.3/customize.html#adding-default-options)). As an example: + +```bash +env PYTEST_ADDOPTS="--capture=no --exitfirst" \ + cmake --build build --target pytest +# Or using abbreviated flags +env PYTEST_ADDOPTS="-s -x" cmake --build build --target pytest +``` ### Formatting @@ -164,16 +204,42 @@ name, pre-commit): pre-commit install ``` -### Clang-Tidy +### Clang-Format -To run Clang tidy, the following recipe should work. Files will be modified in -place, so you can use git to monitor the changes. +As of v2.6.2, pybind11 ships with a [`clang-format`][clang-format] +configuration file at the top level of the repo (the filename is +`.clang-format`). Currently, formatting is NOT applied automatically, but +manually using `clang-format` for newly developed files is highly encouraged. +To check if a file needs formatting: ```bash -docker run --rm -v $PWD:/pybind11 -it silkeh/clang:10 -apt-get update && apt-get install python3-dev python3-pytest -cmake -S pybind11/ -B build -DCMAKE_CXX_CLANG_TIDY="$(which clang-tidy);-fix" -cmake --build build +clang-format -style=file --dry-run some.cpp +``` + +The output will show things to be fixed, if any. To actually format the file: + +```bash +clang-format -style=file -i some.cpp +``` + +Note that the `-style-file` option searches the parent directories for the +`.clang-format` file, i.e. the commands above can be run in any subdirectory +of the pybind11 repo. + +### Clang-Tidy + +[`clang-tidy`][clang-tidy] performs deeper static code analyses and is +more complex to run, compared to `clang-format`, but support for `clang-tidy` +is built into the pybind11 CMake configuration. To run `clang-tidy`, the +following recipe should work. Run the `docker` command from the top-level +directory inside your pybind11 git clone. Files will be modified in place, +so you can use git to monitor the changes. + +```bash +docker run --rm -v $PWD:/mounted_pybind11 -it silkeh/clang:12 +apt-get update && apt-get install -y python3-dev python3-pytest +cmake -S /mounted_pybind11/ -B build -DCMAKE_CXX_CLANG_TIDY="$(which clang-tidy);-fix" -DDOWNLOAD_EIGEN=ON -DDOWNLOAD_CATCH=ON -DCMAKE_CXX_STANDARD=17 +cmake --build build -j 2 -- --keep-going ``` ### Include what you use @@ -186,7 +252,7 @@ cmake -S . -B build-iwyu -DCMAKE_CXX_INCLUDE_WHAT_YOU_USE=$(which include-what-y cmake --build build ``` -The report is sent to stderr; you can pip it into a file if you wish. +The report is sent to stderr; you can pipe it into a file if you wish. ### Build recipes @@ -313,6 +379,8 @@ if you really want to. [pre-commit]: https://pre-commit.com +[clang-format]: https://clang.llvm.org/docs/ClangFormat.html +[clang-tidy]: https://clang.llvm.org/extra/clang-tidy/ [pybind11.readthedocs.org]: http://pybind11.readthedocs.org/en/latest [issue tracker]: https://github.com/pybind/pybind11/issues [gitter]: https://gitter.im/pybind/Lobby diff --git a/wrap/pybind11/.github/ISSUE_TEMPLATE/bug-report.md b/wrap/pybind11/.github/ISSUE_TEMPLATE/bug-report.md deleted file mode 100644 index ae36ea650..000000000 --- a/wrap/pybind11/.github/ISSUE_TEMPLATE/bug-report.md +++ /dev/null @@ -1,28 +0,0 @@ ---- -name: Bug Report -about: File an issue about a bug -title: "[BUG] " ---- - - -Make sure you've completed the following steps before submitting your issue -- thank you! - -1. Make sure you've read the [documentation][]. Your issue may be addressed there. -2. Search the [issue tracker][] to verify that this hasn't already been reported. +1 or comment there if it has. -3. Consider asking first in the [Gitter chat room][]. -4. Include a self-contained and minimal piece of code that reproduces the problem. If that's not possible, try to make the description as clear as possible. - a. If possible, make a PR with a new, failing test to give us a starting point to work on! - -[documentation]: https://pybind11.readthedocs.io -[issue tracker]: https://github.com/pybind/pybind11/issues -[Gitter chat room]: https://gitter.im/pybind/Lobby - -*After reading, remove this checklist and the template text in parentheses below.* - -## Issue description - -(Provide a short description, state the expected behavior and what actually happens.) - -## Reproducible example code - -(The code should be minimal, have no external dependencies, isolate the function(s) that cause breakage. Submit matched and complete C++ and Python snippets that can be easily compiled and run to diagnose the issue.) diff --git a/wrap/pybind11/.github/ISSUE_TEMPLATE/bug-report.yml b/wrap/pybind11/.github/ISSUE_TEMPLATE/bug-report.yml new file mode 100644 index 000000000..bd6a9a8e2 --- /dev/null +++ b/wrap/pybind11/.github/ISSUE_TEMPLATE/bug-report.yml @@ -0,0 +1,45 @@ +name: Bug Report +description: File an issue about a bug +title: "[BUG]: " +labels: [triage] +body: + - type: markdown + attributes: + value: | + Maintainers will only make a best effort to triage PRs. Please do your best to make the issue as easy to act on as possible, and only open if clearly a problem with pybind11 (ask first if unsure). + - type: checkboxes + id: steps + attributes: + label: Required prerequisites + description: Make sure you've completed the following steps before submitting your issue -- thank you! + options: + - label: Make sure you've read the [documentation](https://pybind11.readthedocs.io). Your issue may be addressed there. + required: true + - label: Search the [issue tracker](https://github.com/pybind/pybind11/issues) and [Discussions](https:/pybind/pybind11/discussions) to verify that this hasn't already been reported. +1 or comment there if it has. + required: true + - label: Consider asking first in the [Gitter chat room](https://gitter.im/pybind/Lobby) or in a [Discussion](https:/pybind/pybind11/discussions/new). + required: false + + - type: textarea + id: description + attributes: + label: Problem description + placeholder: >- + Provide a short description, state the expected behavior and what + actually happens. Include relevant information like what version of + pybind11 you are using, what system you are on, and any useful commands + / output. + validations: + required: true + + - type: textarea + id: code + attributes: + label: Reproducible example code + placeholder: >- + The code should be minimal, have no external dependencies, isolate the + function(s) that cause breakage. Submit matched and complete C++ and + Python snippets that can be easily compiled and run to diagnose the + issue. If possible, make a PR with a new, failing test to give us a + starting point to work on! + render: text diff --git a/wrap/pybind11/.github/ISSUE_TEMPLATE/config.yml b/wrap/pybind11/.github/ISSUE_TEMPLATE/config.yml index 20e743136..27f9a8044 100644 --- a/wrap/pybind11/.github/ISSUE_TEMPLATE/config.yml +++ b/wrap/pybind11/.github/ISSUE_TEMPLATE/config.yml @@ -1,5 +1,8 @@ blank_issues_enabled: false contact_links: + - name: Ask a question + url: https://github.com/pybind/pybind11/discussions/new + about: Please ask and answer questions here, or propose new ideas. - name: Gitter room url: https://gitter.im/pybind/Lobby about: A room for discussing pybind11 with an active community diff --git a/wrap/pybind11/.github/ISSUE_TEMPLATE/feature-request.md b/wrap/pybind11/.github/ISSUE_TEMPLATE/feature-request.md deleted file mode 100644 index 5f6ec81ec..000000000 --- a/wrap/pybind11/.github/ISSUE_TEMPLATE/feature-request.md +++ /dev/null @@ -1,16 +0,0 @@ ---- -name: Feature Request -about: File an issue about adding a feature -title: "[FEAT] " ---- - - -Make sure you've completed the following steps before submitting your issue -- thank you! - -1. Check if your feature has already been mentioned / rejected / planned in other issues. -2. If those resources didn't help, consider asking in the [Gitter chat room][] to see if this is interesting / useful to a larger audience and possible to implement reasonably, -4. If you have a useful feature that passes the previous items (or not suitable for chat), please fill in the details below. - -[Gitter chat room]: https://gitter.im/pybind/Lobby - -*After reading, remove this checklist.* diff --git a/wrap/pybind11/.github/ISSUE_TEMPLATE/question.md b/wrap/pybind11/.github/ISSUE_TEMPLATE/question.md deleted file mode 100644 index b199b6ee8..000000000 --- a/wrap/pybind11/.github/ISSUE_TEMPLATE/question.md +++ /dev/null @@ -1,21 +0,0 @@ ---- -name: Question -about: File an issue about unexplained behavior -title: "[QUESTION] " ---- - -If you have a question, please check the following first: - -1. Check if your question has already been answered in the [FAQ][] section. -2. Make sure you've read the [documentation][]. Your issue may be addressed there. -3. If those resources didn't help and you only have a short question (not a bug report), consider asking in the [Gitter chat room][] -4. Search the [issue tracker][], including the closed issues, to see if your question has already been asked/answered. +1 or comment if it has been asked but has no answer. -5. If you have a more complex question which is not answered in the previous items (or not suitable for chat), please fill in the details below. -6. Include a self-contained and minimal piece of code that illustrates your question. If that's not possible, try to make the description as clear as possible. - -[FAQ]: http://pybind11.readthedocs.io/en/latest/faq.html -[documentation]: https://pybind11.readthedocs.io -[issue tracker]: https://github.com/pybind/pybind11/issues -[Gitter chat room]: https://gitter.im/pybind/Lobby - -*After reading, remove this checklist.* diff --git a/wrap/pybind11/.github/dependabot.yml b/wrap/pybind11/.github/dependabot.yml new file mode 100644 index 000000000..73273365c --- /dev/null +++ b/wrap/pybind11/.github/dependabot.yml @@ -0,0 +1,16 @@ +version: 2 +updates: + # Maintain dependencies for GitHub Actions + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "daily" + ignore: + # Official actions have moving tags like v1 + # that are used, so they don't need updates here + - dependency-name: "actions/checkout" + - dependency-name: "actions/setup-python" + - dependency-name: "actions/cache" + - dependency-name: "actions/upload-artifact" + - dependency-name: "actions/download-artifact" + - dependency-name: "actions/labeler" diff --git a/wrap/pybind11/.github/labeler.yml b/wrap/pybind11/.github/labeler.yml new file mode 100644 index 000000000..abb0d05aa --- /dev/null +++ b/wrap/pybind11/.github/labeler.yml @@ -0,0 +1,8 @@ +docs: +- any: + - 'docs/**/*.rst' + - '!docs/changelog.rst' + - '!docs/upgrade.rst' + +ci: +- '.github/workflows/*.yml' diff --git a/wrap/pybind11/.github/labeler_merged.yml b/wrap/pybind11/.github/labeler_merged.yml new file mode 100644 index 000000000..2374ad42e --- /dev/null +++ b/wrap/pybind11/.github/labeler_merged.yml @@ -0,0 +1,3 @@ +needs changelog: +- all: + - '!docs/changelog.rst' diff --git a/wrap/pybind11/.github/pull_request_template.md b/wrap/pybind11/.github/pull_request_template.md new file mode 100644 index 000000000..54b7f5100 --- /dev/null +++ b/wrap/pybind11/.github/pull_request_template.md @@ -0,0 +1,19 @@ + +## Description + + + + +## Suggested changelog entry: + + + +```rst + +``` + + diff --git a/wrap/pybind11/.github/workflows/ci.yml b/wrap/pybind11/.github/workflows/ci.yml index 1749d07f0..050c525ce 100644 --- a/wrap/pybind11/.github/workflows/ci.yml +++ b/wrap/pybind11/.github/workflows/ci.yml @@ -9,6 +9,13 @@ on: - stable - v* +concurrency: + group: test-${{ github.ref }} + cancel-in-progress: true + +env: + PIP_ONLY_BINARY: numpy + jobs: # This is the "main" test suite, which tests a large number of different # versions of default compilers and Python versions in GitHub Actions. @@ -16,71 +23,42 @@ jobs: strategy: fail-fast: false matrix: - runs-on: [ubuntu-latest, windows-latest, macos-latest] - arch: [x64] + runs-on: [ubuntu-latest, windows-2022, macos-latest] python: - - 2.7 - - 3.5 - - 3.8 - - pypy2 - - pypy3 + - '2.7' + - '3.5' + - '3.6' + - '3.9' + - '3.10' + - 'pypy-3.7-v7.3.7' + - 'pypy-3.8-v7.3.7' # Items in here will either be added to the build matrix (if not # present), or add new keys to an existing matrix element if all the # existing keys match. # - # We support three optional keys: args (both build), args1 (first - # build), and args2 (second build). + # We support an optional key: args, for cmake args include: + # Just add a key - runs-on: ubuntu-latest - python: 3.6 - arch: x64 + python: '3.6' args: > -DPYBIND11_FINDPYTHON=ON - - runs-on: windows-2016 - python: 3.7 - arch: x86 - args2: > - -DCMAKE_CXX_FLAGS="/permissive- /EHsc /GR" + -DCMAKE_CXX_FLAGS="-D_=1" - runs-on: windows-latest - python: 3.6 - arch: x64 + python: '3.6' args: > -DPYBIND11_FINDPYTHON=ON - - runs-on: windows-latest - python: 3.7 - arch: x64 - - - runs-on: ubuntu-latest - python: 3.9-dev - arch: x64 - runs-on: macos-latest - python: 3.9-dev - arch: x64 - args: > - -DPYBIND11_FINDPYTHON=ON + python: 'pypy-2.7' + # Inject a couple Windows 2019 runs + - runs-on: windows-2019 + python: '3.9' + - runs-on: windows-2019 + python: '2.7' - # These items will be removed from the build matrix, keys must match. - exclude: - # Currently 32bit only, and we build 64bit - - runs-on: windows-latest - python: pypy2 - arch: x64 - - runs-on: windows-latest - python: pypy3 - arch: x64 - - # Currently broken on embed_test - - runs-on: windows-latest - python: 3.8 - arch: x64 - - runs-on: windows-latest - python: 3.9-dev - arch: x64 - - name: "🐍 ${{ matrix.python }} • ${{ matrix.runs-on }} • ${{ matrix.arch }} ${{ matrix.args }}" + name: "🐍 ${{ matrix.python }} • ${{ matrix.runs-on }} • x64 ${{ matrix.args }}" runs-on: ${{ matrix.runs-on }} - continue-on-error: ${{ endsWith(matrix.python, 'dev') }} steps: - uses: actions/checkout@v2 @@ -89,13 +67,18 @@ jobs: uses: actions/setup-python@v2 with: python-version: ${{ matrix.python }} - architecture: ${{ matrix.arch }} - - name: Setup Boost (Windows / Linux latest) - run: echo "::set-env name=BOOST_ROOT::$BOOST_ROOT_1_72_0" + - name: Setup Boost (Linux) + # Can't use boost + define _ + if: runner.os == 'Linux' && matrix.python != '3.6' + run: sudo apt-get install libboost-dev + + - name: Setup Boost (macOS) + if: runner.os == 'macOS' + run: brew install boost - name: Update CMake - uses: jwlawson/actions-setup-cmake@v1.3 + uses: jwlawson/actions-setup-cmake@v1.12 - name: Cache wheels if: runner.os == 'macOS' @@ -106,10 +89,11 @@ jobs: # for ways to do this more generally path: ~/Library/Caches/pip # Look to see if there is a cache hit for the corresponding requirements file - key: ${{ runner.os }}-pip-${{ matrix.python }}-${{ matrix.arch }}-${{ hashFiles('tests/requirements.txt') }} + key: ${{ runner.os }}-pip-${{ matrix.python }}-x64-${{ hashFiles('tests/requirements.txt') }} - name: Prepare env - run: python -m pip install -r tests/requirements.txt --prefer-binary + run: | + python -m pip install -r tests/requirements.txt - name: Setup annotations on Linux if: runner.os == 'Linux' @@ -132,6 +116,8 @@ jobs: run: cmake --build . --target pytest -j 2 - name: C++11 tests + # TODO: Figure out how to load the DLL on Python 3.8+ + if: "!(runner.os == 'Windows' && (matrix.python == 3.8 || matrix.python == 3.9 || matrix.python == '3.10' || matrix.python == '3.11-dev' || matrix.python == 'pypy-3.8'))" run: cmake --build . --target cpptest -j 2 - name: Interface test C++11 @@ -141,7 +127,7 @@ jobs: run: git clean -fdx # Second build - C++17 mode and in a build directory - - name: Configure ${{ matrix.args2 }} + - name: Configure C++17 run: > cmake -S . -B build2 -DPYBIND11_WERROR=ON @@ -149,7 +135,6 @@ jobs: -DDOWNLOAD_EIGEN=ON -DCMAKE_CXX_STANDARD=17 ${{ matrix.args }} - ${{ matrix.args2 }} - name: Build run: cmake --build build2 -j 2 @@ -158,8 +143,28 @@ jobs: run: cmake --build build2 --target pytest - name: C++ tests + # TODO: Figure out how to load the DLL on Python 3.8+ + if: "!(runner.os == 'Windows' && (matrix.python == 3.8 || matrix.python == 3.9 || matrix.python == '3.10' || matrix.python == '3.11-dev' || matrix.python == 'pypy-3.8'))" run: cmake --build build2 --target cpptest + # Third build - C++17 mode with unstable ABI + - name: Configure (unstable ABI) + run: > + cmake -S . -B build3 + -DPYBIND11_WERROR=ON + -DDOWNLOAD_CATCH=ON + -DDOWNLOAD_EIGEN=ON + -DCMAKE_CXX_STANDARD=17 + -DPYBIND11_INTERNALS_VERSION=10000000 + "-DPYBIND11_TEST_OVERRIDE=test_call_policies.cpp;test_gil_scoped.cpp;test_thread.cpp" + ${{ matrix.args }} + + - name: Build (unstable ABI) + run: cmake --build build3 -j 2 + + - name: Python tests (unstable ABI) + run: cmake --build build3 --target pytest + - name: Interface test run: cmake --build build2 --target test_cmake_build @@ -167,21 +172,105 @@ jobs: # MSVC, but for now, this action works: - name: Prepare compiler environment for Windows 🐍 2.7 if: matrix.python == 2.7 && runner.os == 'Windows' - uses: ilammy/msvc-dev-cmd@v1 + uses: ilammy/msvc-dev-cmd@v1.10.0 with: arch: x64 # This makes two environment variables available in the following step(s) - name: Set Windows 🐍 2.7 environment variables if: matrix.python == 2.7 && runner.os == 'Windows' + shell: bash run: | - echo "::set-env name=DISTUTILS_USE_SDK::1" - echo "::set-env name=MSSdk::1" + echo "DISTUTILS_USE_SDK=1" >> $GITHUB_ENV + echo "MSSdk=1" >> $GITHUB_ENV # This makes sure the setup_helpers module can build packages using # setuptools - name: Setuptools helpers test run: pytest tests/extra_setuptools + if: "!(matrix.python == '3.5' && matrix.runs-on == 'windows-2022')" + + + deadsnakes: + strategy: + fail-fast: false + matrix: + include: + # TODO: Fails on 3.10, investigate + - python-version: "3.9" + python-debug: true + valgrind: true + # - python-version: "3.11-dev" + # python-debug: false + + name: "🐍 ${{ matrix.python-version }}${{ matrix.python-debug && '-dbg' || '' }} (deadsnakes)${{ matrix.valgrind && ' • Valgrind' || '' }} • x64" + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Setup Python ${{ matrix.python-version }} (deadsnakes) + uses: deadsnakes/action@v2.1.1 + with: + python-version: ${{ matrix.python-version }} + debug: ${{ matrix.python-debug }} + + - name: Update CMake + uses: jwlawson/actions-setup-cmake@v1.12 + + - name: Valgrind cache + if: matrix.valgrind + uses: actions/cache@v2 + id: cache-valgrind + with: + path: valgrind + key: 3.16.1 # Valgrind version + + - name: Compile Valgrind + if: matrix.valgrind && steps.cache-valgrind.outputs.cache-hit != 'true' + run: | + VALGRIND_VERSION=3.16.1 + curl https://sourceware.org/pub/valgrind/valgrind-$VALGRIND_VERSION.tar.bz2 -o - | tar xj + mv valgrind-$VALGRIND_VERSION valgrind + cd valgrind + ./configure + make -j 2 > /dev/null + + - name: Install Valgrind + if: matrix.valgrind + working-directory: valgrind + run: | + sudo make install + sudo apt-get update + sudo apt-get install libc6-dbg # Needed by Valgrind + + - name: Prepare env + run: | + python -m pip install -r tests/requirements.txt + + - name: Configure + env: + SETUPTOOLS_USE_DISTUTILS: stdlib + run: > + cmake -S . -B build + -DCMAKE_BUILD_TYPE=Debug + -DPYBIND11_WERROR=ON + -DDOWNLOAD_CATCH=ON + -DDOWNLOAD_EIGEN=ON + -DCMAKE_CXX_STANDARD=17 + + - name: Build + run: cmake --build build -j 2 + + - name: Python tests + run: cmake --build build --target pytest + + - name: C++ tests + run: cmake --build build --target cpptest + + - name: Run Valgrind on Python tests + if: matrix.valgrind + run: cmake --build build --target memcheck # Testing on clang using the excellent silkeh clang docker images @@ -194,12 +283,20 @@ jobs: - 3.6 - 3.7 - 3.9 - - 5 - 7 - 9 - dev + std: + - 11 + include: + - clang: 5 + std: 14 + - clang: 10 + std: 20 + - clang: 10 + std: 17 - name: "🐍 3 • Clang ${{ matrix.clang }} • x64" + name: "🐍 3 • Clang ${{ matrix.clang }} • C++${{ matrix.std }} • x64" container: "silkeh/clang:${{ matrix.clang }}" steps: @@ -214,6 +311,7 @@ jobs: cmake -S . -B build -DPYBIND11_WERROR=ON -DDOWNLOAD_CATCH=ON + -DCMAKE_CXX_STANDARD=${{ matrix.std }} -DPYTHON_EXECUTABLE=$(python3 -c "import sys; print(sys.executable)") - name: Build @@ -252,50 +350,54 @@ jobs: run: cmake --build build --target pytest - # Testing CentOS 8 + PGI compilers - centos-nvhpc8: - runs-on: ubuntu-latest - name: "🐍 3 • CentOS8 / PGI 20.7 • x64" - container: centos:8 - - steps: - - uses: actions/checkout@v2 - - - name: Add Python 3 and a few requirements - run: yum update -y && yum install -y git python3-devel python3-numpy python3-pytest make environment-modules - - - name: Install CMake with pip - run: | - python3 -m pip install --upgrade pip - python3 -m pip install cmake --prefer-binary - - - name: Install NVidia HPC SDK - run: yum -y install https://developer.download.nvidia.com/hpc-sdk/nvhpc-20-7-20.7-1.x86_64.rpm https://developer.download.nvidia.com/hpc-sdk/nvhpc-2020-20.7-1.x86_64.rpm - - - name: Configure - shell: bash - run: | - source /etc/profile.d/modules.sh - module load /opt/nvidia/hpc_sdk/modulefiles/nvhpc/20.7 - cmake -S . -B build -DDOWNLOAD_CATCH=ON -DCMAKE_CXX_STANDARD=14 -DPYTHON_EXECUTABLE=$(python3 -c "import sys; print(sys.executable)") - - - name: Build - run: cmake --build build -j 2 --verbose - - - name: Python tests - run: cmake --build build --target pytest - - - name: C++ tests - run: cmake --build build --target cpptest - - - name: Interface test - run: cmake --build build --target test_cmake_build +# TODO: Internal compiler error - report to NVidia +# # Testing CentOS 8 + PGI compilers +# centos-nvhpc8: +# runs-on: ubuntu-latest +# name: "🐍 3 • CentOS8 / PGI 20.11 • x64" +# container: centos:8 +# +# steps: +# - uses: actions/checkout@v2 +# +# - name: Add Python 3 and a few requirements +# run: yum update -y && yum install -y git python3-devel python3-numpy python3-pytest make environment-modules +# +# - name: Install CMake with pip +# run: | +# python3 -m pip install --upgrade pip +# python3 -m pip install cmake --prefer-binary +# +# - name: Install NVidia HPC SDK +# run: > +# yum -y install +# https://developer.download.nvidia.com/hpc-sdk/20.11/nvhpc-20-11-20.11-1.x86_64.rpm +# https://developer.download.nvidia.com/hpc-sdk/20.11/nvhpc-2020-20.11-1.x86_64.rpm +# +# - name: Configure +# shell: bash +# run: | +# source /etc/profile.d/modules.sh +# module load /opt/nvidia/hpc_sdk/modulefiles/nvhpc/20.11 +# cmake -S . -B build -DDOWNLOAD_CATCH=ON -DCMAKE_CXX_STANDARD=14 -DPYTHON_EXECUTABLE=$(python3 -c "import sys; print(sys.executable)") +# +# - name: Build +# run: cmake --build build -j 2 --verbose +# +# - name: Python tests +# run: cmake --build build --target pytest +# +# - name: C++ tests +# run: cmake --build build --target cpptest +# +# - name: Interface test +# run: cmake --build build --target test_cmake_build # Testing on CentOS 7 + PGI compilers, which seems to require more workarounds centos-nvhpc7: runs-on: ubuntu-latest - name: "🐍 3 • CentOS7 / PGI 20.7 • x64" + name: "🐍 3 • CentOS7 / PGI 20.9 • x64" container: centos:7 steps: @@ -305,17 +407,17 @@ jobs: run: yum update -y && yum install -y epel-release && yum install -y git python3-devel make environment-modules cmake3 - name: Install NVidia HPC SDK - run: yum -y install https://developer.download.nvidia.com/hpc-sdk/nvhpc-20-7-20.7-1.x86_64.rpm https://developer.download.nvidia.com/hpc-sdk/nvhpc-2020-20.7-1.x86_64.rpm + run: yum -y install https://developer.download.nvidia.com/hpc-sdk/20.9/nvhpc-20-9-20.9-1.x86_64.rpm https://developer.download.nvidia.com/hpc-sdk/20.9/nvhpc-2020-20.9-1.x86_64.rpm # On CentOS 7, we have to filter a few tests (compiler internal error) - # and allow deeper templete recursion (not needed on CentOS 8 with a newer + # and allow deeper template recursion (not needed on CentOS 8 with a newer # standard library). On some systems, you many need further workarounds: # https://github.com/pybind/pybind11/pull/2475 - name: Configure shell: bash run: | source /etc/profile.d/modules.sh - module load /opt/nvidia/hpc_sdk/modulefiles/nvhpc/20.7 + module load /opt/nvidia/hpc_sdk/modulefiles/nvhpc/20.9 cmake3 -S . -B build -DDOWNLOAD_CATCH=ON \ -DCMAKE_CXX_STANDARD=11 \ -DPYTHON_EXECUTABLE=$(python3 -c "import sys; print(sys.executable)") \ @@ -340,6 +442,7 @@ jobs: - name: Interface test run: cmake3 --build build --target test_cmake_build + # Testing on GCC using the GCC docker images (only recent images supported) gcc: runs-on: ubuntu-latest @@ -349,8 +452,13 @@ jobs: gcc: - 7 - latest + std: + - 11 + include: + - gcc: 10 + std: 20 - name: "🐍 3 • GCC ${{ matrix.gcc }} • x64" + name: "🐍 3 • GCC ${{ matrix.gcc }} • C++${{ matrix.std }}• x64" container: "gcc:${{ matrix.gcc }}" steps: @@ -362,10 +470,8 @@ jobs: - name: Update pip run: python3 -m pip install --upgrade pip - - name: Setup CMake 3.18 - uses: jwlawson/actions-setup-cmake@v1.3 - with: - cmake-version: 3.18 + - name: Update CMake + uses: jwlawson/actions-setup-cmake@v1.12 - name: Configure shell: bash @@ -373,7 +479,7 @@ jobs: cmake -S . -B build -DPYBIND11_WERROR=ON -DDOWNLOAD_CATCH=ON - -DCMAKE_CXX_STANDARD=11 + -DCMAKE_CXX_STANDARD=${{ matrix.std }} -DPYTHON_EXECUTABLE=$(python3 -c "import sys; print(sys.executable)") - name: Build @@ -389,6 +495,103 @@ jobs: run: cmake --build build --target test_cmake_build + # Testing on ICC using the oneAPI apt repo + icc: + runs-on: ubuntu-20.04 + strategy: + fail-fast: false + + name: "🐍 3 • ICC latest • x64" + + steps: + - uses: actions/checkout@v2 + + - name: Add apt repo + run: | + sudo apt-get update + sudo apt-get install -y wget build-essential pkg-config cmake ca-certificates gnupg + wget https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS-2023.PUB + sudo apt-key add GPG-PUB-KEY-INTEL-SW-PRODUCTS-2023.PUB + echo "deb https://apt.repos.intel.com/oneapi all main" | sudo tee /etc/apt/sources.list.d/oneAPI.list + + - name: Add ICC & Python 3 + run: sudo apt-get update; sudo apt-get install -y intel-oneapi-compiler-dpcpp-cpp-and-cpp-classic cmake python3-dev python3-numpy python3-pytest python3-pip + + - name: Update pip + run: | + set +e; source /opt/intel/oneapi/setvars.sh; set -e + python3 -m pip install --upgrade pip + + - name: Install dependencies + run: | + set +e; source /opt/intel/oneapi/setvars.sh; set -e + python3 -m pip install -r tests/requirements.txt + + - name: Configure C++11 + run: | + set +e; source /opt/intel/oneapi/setvars.sh; set -e + cmake -S . -B build-11 \ + -DPYBIND11_WERROR=ON \ + -DDOWNLOAD_CATCH=ON \ + -DDOWNLOAD_EIGEN=OFF \ + -DCMAKE_CXX_STANDARD=11 \ + -DCMAKE_CXX_COMPILER=$(which icpc) \ + -DPYTHON_EXECUTABLE=$(python3 -c "import sys; print(sys.executable)") + + - name: Build C++11 + run: | + set +e; source /opt/intel/oneapi/setvars.sh; set -e + cmake --build build-11 -j 2 -v + + - name: Python tests C++11 + run: | + set +e; source /opt/intel/oneapi/setvars.sh; set -e + sudo service apport stop + cmake --build build-11 --target check + + - name: C++ tests C++11 + run: | + set +e; source /opt/intel/oneapi/setvars.sh; set -e + cmake --build build-11 --target cpptest + + - name: Interface test C++11 + run: | + set +e; source /opt/intel/oneapi/setvars.sh; set -e + cmake --build build-11 --target test_cmake_build + + - name: Configure C++17 + run: | + set +e; source /opt/intel/oneapi/setvars.sh; set -e + cmake -S . -B build-17 \ + -DPYBIND11_WERROR=ON \ + -DDOWNLOAD_CATCH=ON \ + -DDOWNLOAD_EIGEN=OFF \ + -DCMAKE_CXX_STANDARD=17 \ + -DCMAKE_CXX_COMPILER=$(which icpc) \ + -DPYTHON_EXECUTABLE=$(python3 -c "import sys; print(sys.executable)") + + - name: Build C++17 + run: | + set +e; source /opt/intel/oneapi/setvars.sh; set -e + cmake --build build-17 -j 2 -v + + - name: Python tests C++17 + run: | + set +e; source /opt/intel/oneapi/setvars.sh; set -e + sudo service apport stop + cmake --build build-17 --target check + + - name: C++ tests C++17 + run: | + set +e; source /opt/intel/oneapi/setvars.sh; set -e + cmake --build build-17 --target cpptest + + - name: Interface test C++17 + run: | + set +e; source /opt/intel/oneapi/setvars.sh; set -e + cmake --build build-17 --target test_cmake_build + + # Testing on CentOS (manylinux uses a centos base, and this is an easy way # to get GCC 4.8, which is the manylinux1 compiler). centos: @@ -397,11 +600,11 @@ jobs: fail-fast: false matrix: centos: - - 7 # GCC 4.8 - - 8 + - centos7 # GCC 4.8 + - stream8 name: "🐍 3 • CentOS ${{ matrix.centos }} • x64" - container: "centos:${{ matrix.centos }}" + container: "quay.io/centos/centos:${{ matrix.centos }}" steps: - uses: actions/checkout@v2 @@ -413,12 +616,14 @@ jobs: run: python3 -m pip install --upgrade pip - name: Install dependencies - run: python3 -m pip install cmake -r tests/requirements.txt --prefer-binary + run: | + python3 -m pip install cmake -r tests/requirements.txt - name: Configure shell: bash run: > cmake -S . -B build + -DCMAKE_BUILD_TYPE=MinSizeRel -DPYBIND11_WERROR=ON -DDOWNLOAD_CATCH=ON -DDOWNLOAD_EIGEN=ON @@ -476,7 +681,7 @@ jobs: -DPYTHON_EXECUTABLE=$(python3 -c "import sys; print(sys.executable)") working-directory: /build-tests - - name: Run tests + - name: Python tests run: make pytest -j 2 working-directory: /build-tests @@ -493,16 +698,13 @@ jobs: - uses: actions/setup-python@v2 - name: Install Doxygen - run: sudo apt install -y doxygen - - - name: Install docs & setup requirements - run: python3 -m pip install -r docs/requirements.txt + run: sudo apt-get install -y doxygen librsvg2-bin # Changed to rsvg-convert in 20.04 - name: Build docs - run: python3 -m sphinx -W -b html docs docs/.build + run: pipx run nox -s docs - name: Make SDist - run: python3 setup.py sdist + run: pipx run nox -s build -- --sdist - run: git status --ignored @@ -514,6 +716,250 @@ jobs: - name: Compare Dists (headers only) working-directory: include run: | - python3 -m pip install --user -U ../dist/* + python3 -m pip install --user -U ../dist/*.tar.gz installed=$(python3 -c "import pybind11; print(pybind11.get_include() + '/pybind11')") diff -rq $installed ./pybind11 + + win32: + strategy: + fail-fast: false + matrix: + python: + - 3.5 + - 3.6 + - 3.7 + - 3.8 + - 3.9 + - pypy-3.6 + + include: + - python: 3.9 + args: -DCMAKE_CXX_STANDARD=20 -DDOWNLOAD_EIGEN=OFF + - python: 3.8 + args: -DCMAKE_CXX_STANDARD=17 + + name: "🐍 ${{ matrix.python }} • MSVC 2019 • x86 ${{ matrix.args }}" + runs-on: windows-latest + + steps: + - uses: actions/checkout@v2 + + - name: Setup Python ${{ matrix.python }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python }} + architecture: x86 + + - name: Update CMake + uses: jwlawson/actions-setup-cmake@v1.12 + + - name: Prepare MSVC + uses: ilammy/msvc-dev-cmd@v1.10.0 + with: + arch: x86 + + - name: Prepare env + run: | + python -m pip install -r tests/requirements.txt + + # First build - C++11 mode and inplace + - name: Configure ${{ matrix.args }} + run: > + cmake -S . -B build + -G "Visual Studio 16 2019" -A Win32 + -DPYBIND11_WERROR=ON + -DDOWNLOAD_CATCH=ON + -DDOWNLOAD_EIGEN=ON + ${{ matrix.args }} + - name: Build C++11 + run: cmake --build build -j 2 + + - name: Python tests + run: cmake --build build -t pytest + + win32-msvc2015: + name: "🐍 ${{ matrix.python }} • MSVC 2015 • x64" + runs-on: windows-latest + strategy: + fail-fast: false + matrix: + python: + - 2.7 + - 3.6 + - 3.7 + # todo: check/cpptest does not support 3.8+ yet + + steps: + - uses: actions/checkout@v2 + + - name: Setup 🐍 ${{ matrix.python }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python }} + + - name: Update CMake + uses: jwlawson/actions-setup-cmake@v1.12 + + - name: Prepare MSVC + uses: ilammy/msvc-dev-cmd@v1.10.0 + with: + toolset: 14.0 + + - name: Prepare env + run: | + python -m pip install -r tests/requirements.txt + + # First build - C++11 mode and inplace + - name: Configure + run: > + cmake -S . -B build + -G "Visual Studio 14 2015" -A x64 + -DPYBIND11_WERROR=ON + -DDOWNLOAD_CATCH=ON + -DDOWNLOAD_EIGEN=ON + + - name: Build C++14 + run: cmake --build build -j 2 + + - name: Run all checks + run: cmake --build build -t check + + + win32-msvc2017: + name: "🐍 ${{ matrix.python }} • MSVC 2017 • x64" + runs-on: windows-2016 + strategy: + fail-fast: false + matrix: + python: + - 2.7 + - 3.5 + - 3.7 + std: + - 14 + + include: + - python: 2.7 + std: 17 + args: > + -DCMAKE_CXX_FLAGS="/permissive- /EHsc /GR" + - python: 3.7 + std: 17 + args: > + -DCMAKE_CXX_FLAGS="/permissive- /EHsc /GR" + + steps: + - uses: actions/checkout@v2 + + - name: Setup 🐍 ${{ matrix.python }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python }} + + - name: Update CMake + uses: jwlawson/actions-setup-cmake@v1.12 + + - name: Prepare env + run: | + python -m pip install -r tests/requirements.txt + + # First build - C++11 mode and inplace + - name: Configure + run: > + cmake -S . -B build + -G "Visual Studio 15 2017" -A x64 + -DPYBIND11_WERROR=ON + -DDOWNLOAD_CATCH=ON + -DDOWNLOAD_EIGEN=ON + -DCMAKE_CXX_STANDARD=${{ matrix.std }} + ${{ matrix.args }} + + - name: Build ${{ matrix.std }} + run: cmake --build build -j 2 + + - name: Run all checks + run: cmake --build build -t check + + mingw: + name: "🐍 3 • windows-latest • ${{ matrix.sys }}" + runs-on: windows-latest + defaults: + run: + shell: msys2 {0} + strategy: + fail-fast: false + matrix: + include: + - { sys: mingw64, env: x86_64 } + - { sys: mingw32, env: i686 } + steps: + - uses: msys2/setup-msys2@v2 + with: + msystem: ${{matrix.sys}} + install: >- + git + mingw-w64-${{matrix.env}}-gcc + mingw-w64-${{matrix.env}}-python-pip + mingw-w64-${{matrix.env}}-python-numpy + mingw-w64-${{matrix.env}}-python-scipy + mingw-w64-${{matrix.env}}-cmake + mingw-w64-${{matrix.env}}-make + mingw-w64-${{matrix.env}}-python-pytest + mingw-w64-${{matrix.env}}-eigen3 + mingw-w64-${{matrix.env}}-boost + mingw-w64-${{matrix.env}}-catch + + - uses: actions/checkout@v2 + + - name: Configure C++11 + # LTO leads to many undefined reference like + # `pybind11::detail::function_call::function_call(pybind11::detail::function_call&&) + run: cmake -G "MinGW Makefiles" -DCMAKE_CXX_STANDARD=11 -S . -B build + + - name: Build C++11 + run: cmake --build build -j 2 + + - name: Python tests C++11 + run: cmake --build build --target pytest -j 2 + + - name: C++11 tests + run: PYTHONHOME=/${{matrix.sys}} PYTHONPATH=/${{matrix.sys}} cmake --build build --target cpptest -j 2 + + - name: Interface test C++11 + run: PYTHONHOME=/${{matrix.sys}} PYTHONPATH=/${{matrix.sys}} cmake --build build --target test_cmake_build + + - name: Clean directory + run: git clean -fdx + + - name: Configure C++14 + run: cmake -G "MinGW Makefiles" -DCMAKE_CXX_STANDARD=14 -S . -B build2 + + - name: Build C++14 + run: cmake --build build2 -j 2 + + - name: Python tests C++14 + run: cmake --build build2 --target pytest -j 2 + + - name: C++14 tests + run: PYTHONHOME=/${{matrix.sys}} PYTHONPATH=/${{matrix.sys}} cmake --build build2 --target cpptest -j 2 + + - name: Interface test C++14 + run: PYTHONHOME=/${{matrix.sys}} PYTHONPATH=/${{matrix.sys}} cmake --build build2 --target test_cmake_build + + - name: Clean directory + run: git clean -fdx + + - name: Configure C++17 + run: cmake -G "MinGW Makefiles" -DCMAKE_CXX_STANDARD=17 -S . -B build3 + + - name: Build C++17 + run: cmake --build build3 -j 2 + + - name: Python tests C++17 + run: cmake --build build3 --target pytest -j 2 + + - name: C++17 tests + run: PYTHONHOME=/${{matrix.sys}} PYTHONPATH=/${{matrix.sys}} cmake --build build3 --target cpptest -j 2 + + - name: Interface test C++17 + run: PYTHONHOME=/${{matrix.sys}} PYTHONPATH=/${{matrix.sys}} cmake --build build3 --target test_cmake_build diff --git a/wrap/pybind11/.github/workflows/configure.yml b/wrap/pybind11/.github/workflows/configure.yml index 3dd248e04..66ab0e3d7 100644 --- a/wrap/pybind11/.github/workflows/configure.yml +++ b/wrap/pybind11/.github/workflows/configure.yml @@ -18,7 +18,7 @@ jobs: matrix: runs-on: [ubuntu-latest, macos-latest, windows-latest] arch: [x64] - cmake: [3.18] + cmake: ["3.21"] include: - runs-on: ubuntu-latest @@ -55,7 +55,7 @@ jobs: # An action for adding a specific version of CMake: # https://github.com/jwlawson/actions-setup-cmake - name: Setup CMake ${{ matrix.cmake }} - uses: jwlawson/actions-setup-cmake@v1.3 + uses: jwlawson/actions-setup-cmake@v1.12 with: cmake-version: ${{ matrix.cmake }} @@ -82,57 +82,3 @@ jobs: working-directory: build dir if: github.event_name == 'workflow_dispatch' run: cmake --build . --config Release --target check - - # This builds the sdists and wheels and makes sure the files are exactly as - # expected. Using Windows and Python 2.7, since that is often the most - # challenging matrix element. - test-packaging: - name: 🐍 2.7 • 📦 tests • windows-latest - runs-on: windows-latest - - steps: - - uses: actions/checkout@v2 - - - name: Setup 🐍 2.7 - uses: actions/setup-python@v2 - with: - python-version: 2.7 - - - name: Prepare env - run: python -m pip install -r tests/requirements.txt --prefer-binary - - - name: Python Packaging tests - run: pytest tests/extra_python_package/ - - - # This runs the packaging tests and also builds and saves the packages as - # artifacts. - packaging: - name: 🐍 3.8 • 📦 & 📦 tests • ubuntu-latest - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v2 - - - name: Setup 🐍 3.8 - uses: actions/setup-python@v2 - with: - python-version: 3.8 - - - name: Prepare env - run: python -m pip install -r tests/requirements.txt build twine --prefer-binary - - - name: Python Packaging tests - run: pytest tests/extra_python_package/ - - - name: Build SDist and wheels - run: | - python -m build -s -w . - PYBIND11_GLOBAL_SDIST=1 python -m build -s -w . - - - name: Check metadata - run: twine check dist/* - - - uses: actions/upload-artifact@v2 - with: - path: dist/* diff --git a/wrap/pybind11/.github/workflows/format.yml b/wrap/pybind11/.github/workflows/format.yml index 28cfeb9b7..ab7b40503 100644 --- a/wrap/pybind11/.github/workflows/format.yml +++ b/wrap/pybind11/.github/workflows/format.yml @@ -19,15 +19,17 @@ jobs: steps: - uses: actions/checkout@v2 - uses: actions/setup-python@v2 - - uses: pre-commit/action@v2.0.0 + - uses: pre-commit/action@v2.0.3 with: # Slow hooks are marked with manual - slow is okay here, run them too - extra_args: --hook-stage manual + extra_args: --hook-stage manual --all-files clang-tidy: + # When making changes here, please also review the "Clang-Tidy" section + # in .github/CONTRIBUTING.md and update as needed. name: Clang-Tidy runs-on: ubuntu-latest - container: silkeh/clang:10 + container: silkeh/clang:12 steps: - uses: actions/checkout@v2 @@ -35,7 +37,12 @@ jobs: run: apt-get update && apt-get install -y python3-dev python3-pytest - name: Configure - run: cmake -S . -B build -DCMAKE_CXX_CLANG_TIDY="$(which clang-tidy);--warnings-as-errors=*" + run: > + cmake -S . -B build + -DCMAKE_CXX_CLANG_TIDY="$(which clang-tidy)" + -DDOWNLOAD_EIGEN=ON + -DDOWNLOAD_CATCH=ON + -DCMAKE_CXX_STANDARD=17 - name: Build - run: cmake --build build -j 2 + run: cmake --build build -j 2 -- --keep-going diff --git a/wrap/pybind11/.github/workflows/labeler.yml b/wrap/pybind11/.github/workflows/labeler.yml new file mode 100644 index 000000000..d2b597968 --- /dev/null +++ b/wrap/pybind11/.github/workflows/labeler.yml @@ -0,0 +1,16 @@ +name: Labeler +on: + pull_request_target: + types: [closed] + +jobs: + label: + name: Labeler + runs-on: ubuntu-latest + steps: + + - uses: actions/labeler@main + if: github.event.pull_request.merged == true + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + configuration-path: .github/labeler_merged.yml diff --git a/wrap/pybind11/.github/workflows/pip.yml b/wrap/pybind11/.github/workflows/pip.yml new file mode 100644 index 000000000..f74b79f0c --- /dev/null +++ b/wrap/pybind11/.github/workflows/pip.yml @@ -0,0 +1,108 @@ +name: Pip + +on: + workflow_dispatch: + pull_request: + push: + branches: + - master + - stable + - v* + release: + types: + - published + +env: + PIP_ONLY_BINARY: numpy + +jobs: + # This builds the sdists and wheels and makes sure the files are exactly as + # expected. Using Windows and Python 2.7, since that is often the most + # challenging matrix element. + test-packaging: + name: 🐍 2.7 • 📦 tests • windows-latest + runs-on: windows-latest + + steps: + - uses: actions/checkout@v2 + + - name: Setup 🐍 2.7 + uses: actions/setup-python@v2 + with: + python-version: 2.7 + + - name: Prepare env + run: | + python -m pip install -r tests/requirements.txt + + - name: Python Packaging tests + run: pytest tests/extra_python_package/ + + + # This runs the packaging tests and also builds and saves the packages as + # artifacts. + packaging: + name: 🐍 3.8 • 📦 & 📦 tests • ubuntu-latest + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Setup 🐍 3.8 + uses: actions/setup-python@v2 + with: + python-version: 3.8 + + - name: Prepare env + run: | + python -m pip install -r tests/requirements.txt build twine + + - name: Python Packaging tests + run: pytest tests/extra_python_package/ + + - name: Build SDist and wheels + run: | + python -m build + PYBIND11_GLOBAL_SDIST=1 python -m build + + - name: Check metadata + run: twine check dist/* + + - name: Save standard package + uses: actions/upload-artifact@v2 + with: + name: standard + path: dist/pybind11-* + + - name: Save global package + uses: actions/upload-artifact@v2 + with: + name: global + path: dist/pybind11_global-* + + + + # When a GitHub release is made, upload the artifacts to PyPI + upload: + name: Upload to PyPI + runs-on: ubuntu-latest + if: github.event_name == 'release' && github.event.action == 'published' + needs: [packaging] + + steps: + - uses: actions/setup-python@v2 + + # Downloads all to directories matching the artifact names + - uses: actions/download-artifact@v2 + + - name: Publish standard package + uses: pypa/gh-action-pypi-publish@v1.5.0 + with: + password: ${{ secrets.pypi_password }} + packages_dir: standard/ + + - name: Publish global package + uses: pypa/gh-action-pypi-publish@v1.5.0 + with: + password: ${{ secrets.pypi_password_global }} + packages_dir: global/ diff --git a/wrap/pybind11/.github/workflows/upstream.yml b/wrap/pybind11/.github/workflows/upstream.yml new file mode 100644 index 000000000..138c9ad29 --- /dev/null +++ b/wrap/pybind11/.github/workflows/upstream.yml @@ -0,0 +1,112 @@ + +name: Upstream + +on: + workflow_dispatch: + pull_request: + +concurrency: + group: upstream-${{ github.ref }} + cancel-in-progress: true + +env: + PIP_ONLY_BINARY: numpy + +jobs: + standard: + name: "🐍 3.11 dev • ubuntu-latest • x64" + runs-on: ubuntu-latest + if: "contains(github.event.pull_request.labels.*.name, 'python dev')" + + steps: + - uses: actions/checkout@v2 + + - name: Setup Python 3.11 + uses: actions/setup-python@v2 + with: + python-version: "3.11-dev" + + - name: Setup Boost (Linux) + if: runner.os == 'Linux' + run: sudo apt-get install libboost-dev + + - name: Update CMake + uses: jwlawson/actions-setup-cmake@v1.12 + + - name: Prepare env + run: | + python -m pip install -r tests/requirements.txt + + - name: Setup annotations on Linux + if: runner.os == 'Linux' + run: python -m pip install pytest-github-actions-annotate-failures + + # First build - C++11 mode and inplace + - name: Configure C++11 + run: > + cmake -S . -B . + -DPYBIND11_WERROR=ON + -DDOWNLOAD_CATCH=ON + -DDOWNLOAD_EIGEN=ON + -DCMAKE_CXX_STANDARD=11 + + - name: Build C++11 + run: cmake --build . -j 2 + + - name: Python tests C++11 + run: cmake --build . --target pytest -j 2 + + - name: C++11 tests + run: cmake --build . --target cpptest -j 2 + + - name: Interface test C++11 + run: cmake --build . --target test_cmake_build + + - name: Clean directory + run: git clean -fdx + + # Second build - C++17 mode and in a build directory + - name: Configure C++17 + run: > + cmake -S . -B build2 + -DPYBIND11_WERROR=ON + -DDOWNLOAD_CATCH=ON + -DDOWNLOAD_EIGEN=ON + -DCMAKE_CXX_STANDARD=17 + ${{ matrix.args }} + ${{ matrix.args2 }} + + - name: Build + run: cmake --build build2 -j 2 + + - name: Python tests + run: cmake --build build2 --target pytest + + - name: C++ tests + run: cmake --build build2 --target cpptest + + # Third build - C++17 mode with unstable ABI + - name: Configure (unstable ABI) + run: > + cmake -S . -B build3 + -DPYBIND11_WERROR=ON + -DDOWNLOAD_CATCH=ON + -DDOWNLOAD_EIGEN=ON + -DCMAKE_CXX_STANDARD=17 + -DPYBIND11_INTERNALS_VERSION=10000000 + "-DPYBIND11_TEST_OVERRIDE=test_call_policies.cpp;test_gil_scoped.cpp;test_thread.cpp" + ${{ matrix.args }} + + - name: Build (unstable ABI) + run: cmake --build build3 -j 2 + + - name: Python tests (unstable ABI) + run: cmake --build build3 --target pytest + + - name: Interface test + run: cmake --build build2 --target test_cmake_build + + # This makes sure the setup_helpers module can build packages using + # setuptools + - name: Setuptools helpers test + run: pytest tests/extra_setuptools diff --git a/wrap/pybind11/.gitignore b/wrap/pybind11/.gitignore index 3f36b89e0..3cf4fbbda 100644 --- a/wrap/pybind11/.gitignore +++ b/wrap/pybind11/.gitignore @@ -41,3 +41,5 @@ pybind11Targets.cmake /.vscode /pybind11/include/* /pybind11/share/* +/docs/_build/* +.ipynb_checkpoints/ diff --git a/wrap/pybind11/.pre-commit-config.yaml b/wrap/pybind11/.pre-commit-config.yaml index 71513c991..2014cb2b4 100644 --- a/wrap/pybind11/.pre-commit-config.yaml +++ b/wrap/pybind11/.pre-commit-config.yaml @@ -15,12 +15,14 @@ repos: # Standard hooks - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v3.2.0 + rev: v4.1.0 hooks: - id: check-added-large-files - id: check-case-conflict + - id: check-docstring-first - id: check-merge-conflict - id: check-symlinks + - id: check-toml - id: check-yaml - id: debug-statements - id: end-of-file-fixer @@ -28,54 +30,115 @@ repos: - id: requirements-txt-fixer - id: trailing-whitespace - id: fix-encoding-pragma + exclude: ^noxfile.py$ + +- repo: https://github.com/asottile/pyupgrade + rev: v2.31.0 + hooks: + - id: pyupgrade + +- repo: https://github.com/PyCQA/isort + rev: 5.10.1 + hooks: + - id: isort # Black, the code formatter, natively supports pre-commit - repo: https://github.com/psf/black - rev: 20.8b1 + rev: 21.12b0 # Keep in sync with blacken-docs hooks: - id: black - # Not all Python files are Blacked, yet - files: ^(setup.py|pybind11|tests/extra) + +- repo: https://github.com/asottile/blacken-docs + rev: v1.12.0 + hooks: + - id: blacken-docs + additional_dependencies: + - black==21.12b0 # keep in sync with black hook # Changes tabs to spaces - repo: https://github.com/Lucas-C/pre-commit-hooks - rev: v1.1.9 + rev: v1.1.10 hooks: - id: remove-tabs +# Autoremoves unused imports +- repo: https://github.com/hadialqattan/pycln + rev: v1.1.0 + hooks: + - id: pycln + +- repo: https://github.com/pre-commit/pygrep-hooks + rev: v1.9.0 + hooks: + - id: python-check-blanket-noqa + - id: python-check-blanket-type-ignore + - id: python-no-log-warn + - id: rst-backticks + - id: rst-directive-colons + - id: rst-inline-touching-normal + # Flake8 also supports pre-commit natively (same author) -- repo: https://gitlab.com/pycqa/flake8 - rev: 3.8.3 +- repo: https://github.com/PyCQA/flake8 + rev: 4.0.1 hooks: - id: flake8 - additional_dependencies: [flake8-bugbear, pep8-naming] + additional_dependencies: &flake8_dependencies + - flake8-bugbear + - pep8-naming exclude: ^(docs/.*|tools/.*)$ +- repo: https://github.com/asottile/yesqa + rev: v1.3.0 + hooks: + - id: yesqa + additional_dependencies: *flake8_dependencies + # CMake formatting - repo: https://github.com/cheshirekow/cmake-format-precommit - rev: v0.6.11 + rev: v0.6.13 hooks: - id: cmake-format additional_dependencies: [pyyaml] types: [file] files: (\.cmake|CMakeLists.txt)(.in)?$ +# Check static types with mypy +- repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.931 + hooks: + - id: mypy + # Running per-file misbehaves a bit, so just run on all files, it's fast + pass_filenames: false + additional_dependencies: [typed_ast] + # Checks the manifest for missing files (native support) - repo: https://github.com/mgedmin/check-manifest - rev: "0.42" + rev: "0.47" hooks: - id: check-manifest # This is a slow hook, so only run this if --hook-stage manual is passed stages: [manual] additional_dependencies: [cmake, ninja] +- repo: https://github.com/codespell-project/codespell + rev: v2.1.0 + hooks: + - id: codespell + exclude: ".supp$" + args: ["-L", "nd,ot,thist"] + +- repo: https://github.com/shellcheck-py/shellcheck-py + rev: v0.8.0.3 + hooks: + - id: shellcheck + # The original pybind11 checks for a few C++ style items - repo: local hooks: - id: disallow-caps name: Disallow improper capitalization language: pygrep - entry: PyBind|Numpy|Cmake + entry: PyBind|Numpy|Cmake|CCache|PyTest exclude: .pre-commit-config.yaml - repo: local diff --git a/wrap/pybind11/CMakeLists.txt b/wrap/pybind11/CMakeLists.txt index 123abf77d..3787982cb 100644 --- a/wrap/pybind11/CMakeLists.txt +++ b/wrap/pybind11/CMakeLists.txt @@ -7,13 +7,18 @@ cmake_minimum_required(VERSION 3.4) -# The `cmake_minimum_required(VERSION 3.4...3.18)` syntax does not work with +# The `cmake_minimum_required(VERSION 3.4...3.22)` syntax does not work with # some versions of VS that have a patched CMake 3.11. This forces us to emulate # the behavior using the following workaround: -if(${CMAKE_VERSION} VERSION_LESS 3.18) +if(${CMAKE_VERSION} VERSION_LESS 3.22) cmake_policy(VERSION ${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION}) else() - cmake_policy(VERSION 3.18) + cmake_policy(VERSION 3.22) +endif() + +# Avoid infinite recursion if tests include this as a subdirectory +if(DEFINED PYBIND11_MASTER_PROJECT) + return() endif() # Extract project version from source @@ -73,6 +78,10 @@ if(CMAKE_SOURCE_DIR STREQUAL PROJECT_SOURCE_DIR) set(CMAKE_CXX_EXTENSIONS OFF) set(CMAKE_CXX_STANDARD_REQUIRED ON) endif() + + set(pybind11_system "") + + set_property(GLOBAL PROPERTY USE_FOLDERS ON) else() set(PYBIND11_MASTER_PROJECT OFF) set(pybind11_system SYSTEM) @@ -82,6 +91,9 @@ endif() option(PYBIND11_INSTALL "Install pybind11 header files?" ${PYBIND11_MASTER_PROJECT}) option(PYBIND11_TEST "Build pybind11 test suite?" ${PYBIND11_MASTER_PROJECT}) option(PYBIND11_NOPYTHON "Disable search for Python" OFF) +set(PYBIND11_INTERNALS_VERSION + "" + CACHE STRING "Override the ABI version, may be used to enable the unstable ABI.") cmake_dependent_option( USE_PYTHON_INCLUDE_DIR @@ -98,6 +110,7 @@ set(PYBIND11_HEADERS include/pybind11/detail/descr.h include/pybind11/detail/init.h include/pybind11/detail/internals.h + include/pybind11/detail/type_caster_base.h include/pybind11/detail/typeid.h include/pybind11/attr.h include/pybind11/buffer_info.h @@ -109,6 +122,7 @@ set(PYBIND11_HEADERS include/pybind11/eigen.h include/pybind11/embed.h include/pybind11/eval.h + include/pybind11/gil.h include/pybind11/iostream.h include/pybind11/functional.h include/pybind11/numpy.h @@ -116,7 +130,8 @@ set(PYBIND11_HEADERS include/pybind11/pybind11.h include/pybind11/pytypes.h include/pybind11/stl.h - include/pybind11/stl_bind.h) + include/pybind11/stl_bind.h + include/pybind11/stl/filesystem.h) # Compare with grep and warn if mismatched if(PYBIND11_MASTER_PROJECT AND NOT CMAKE_VERSION VERSION_LESS 3.12) @@ -142,22 +157,45 @@ endif() string(REPLACE "include/" "${CMAKE_CURRENT_SOURCE_DIR}/include/" PYBIND11_HEADERS "${PYBIND11_HEADERS}") -# Cache variables so pybind11_add_module can be used in parent projects -set(PYBIND11_INCLUDE_DIR +# Cache variable so this can be used in parent projects +set(pybind11_INCLUDE_DIR "${CMAKE_CURRENT_LIST_DIR}/include" - CACHE INTERNAL "") + CACHE INTERNAL "Directory where pybind11 headers are located") + +# Backward compatible variable for add_subdirectory mode +if(NOT PYBIND11_MASTER_PROJECT) + set(PYBIND11_INCLUDE_DIR + "${pybind11_INCLUDE_DIR}" + CACHE INTERNAL "") +endif() # Note: when creating targets, you cannot use if statements at configure time - # you need generator expressions, because those will be placed in the target file. # You can also place ifs *in* the Config.in, but not here. # This section builds targets, but does *not* touch Python +# Non-IMPORT targets cannot be defined twice +if(NOT TARGET pybind11_headers) + # Build the headers-only target (no Python included): + # (long name used here to keep this from clashing in subdirectory mode) + add_library(pybind11_headers INTERFACE) + add_library(pybind11::pybind11_headers ALIAS pybind11_headers) # to match exported target + add_library(pybind11::headers ALIAS pybind11_headers) # easier to use/remember -# Build the headers-only target (no Python included): -# (long name used here to keep this from clashing in subdirectory mode) -add_library(pybind11_headers INTERFACE) -add_library(pybind11::pybind11_headers ALIAS pybind11_headers) # to match exported target -add_library(pybind11::headers ALIAS pybind11_headers) # easier to use/remember + target_include_directories( + pybind11_headers ${pybind11_system} INTERFACE $ + $) + + target_compile_features(pybind11_headers INTERFACE cxx_inheriting_constructors cxx_user_literals + cxx_right_angle_brackets) + if(NOT "${PYBIND11_INTERNALS_VERSION}" STREQUAL "") + target_compile_definitions( + pybind11_headers INTERFACE "PYBIND11_INTERNALS_VERSION=${PYBIND11_INTERNALS_VERSION}") + endif() +else() + # It is invalid to install a target twice, too. + set(PYBIND11_INSTALL OFF) +endif() include("${CMAKE_CURRENT_SOURCE_DIR}/tools/pybind11Common.cmake") @@ -168,21 +206,18 @@ elseif(USE_PYTHON_INCLUDE_DIR AND DEFINED PYTHON_INCLUDE_DIR) file(RELATIVE_PATH CMAKE_INSTALL_INCLUDEDIR ${CMAKE_INSTALL_PREFIX} ${PYTHON_INCLUDE_DIRS}) endif() -# Fill in headers target -target_include_directories( - pybind11_headers ${pybind11_system} INTERFACE $ - $) - -target_compile_features(pybind11_headers INTERFACE cxx_inheriting_constructors cxx_user_literals - cxx_right_angle_brackets) - if(PYBIND11_INSTALL) - install(DIRECTORY ${PYBIND11_INCLUDE_DIR}/pybind11 DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) - # GNUInstallDirs "DATADIR" wrong here; CMake search path wants "share". + install(DIRECTORY ${pybind11_INCLUDE_DIR}/pybind11 DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) set(PYBIND11_CMAKECONFIG_INSTALL_DIR - "share/cmake/${PROJECT_NAME}" + "${CMAKE_INSTALL_DATAROOTDIR}/cmake/${PROJECT_NAME}" CACHE STRING "install path for pybind11Config.cmake") + if(IS_ABSOLUTE "${CMAKE_INSTALL_INCLUDEDIR}") + set(pybind11_INCLUDEDIR "${CMAKE_INSTALL_FULL_INCLUDEDIR}") + else() + set(pybind11_INCLUDEDIR "\$\{PACKAGE_PREFIX_DIR\}/${CMAKE_INSTALL_INCLUDEDIR}") + endif() + configure_package_config_file( tools/${PROJECT_NAME}Config.cmake.in "${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}Config.cmake" INSTALL_DESTINATION ${PYBIND11_CMAKECONFIG_INSTALL_DIR}) @@ -260,8 +295,5 @@ endif() if(NOT PYBIND11_MASTER_PROJECT) set(pybind11_FOUND TRUE - CACHE INTERNAL "true if pybind11 and all required components found on the system") - set(pybind11_INCLUDE_DIR - "${PYBIND11_INCLUDE_DIR}" - CACHE INTERNAL "Directory where pybind11 headers are located") + CACHE INTERNAL "True if pybind11 and all required components found on the system") endif() diff --git a/wrap/pybind11/MANIFEST.in b/wrap/pybind11/MANIFEST.in index 9336b6030..aed183e87 100644 --- a/wrap/pybind11/MANIFEST.in +++ b/wrap/pybind11/MANIFEST.in @@ -1,4 +1,6 @@ recursive-include pybind11/include/pybind11 *.h recursive-include pybind11 *.py +recursive-include pybind11 py.typed +recursive-include pybind11 *.pyi include pybind11/share/cmake/pybind11/*.cmake -include LICENSE README.md pyproject.toml setup.py setup.cfg +include LICENSE README.rst pyproject.toml setup.py setup.cfg diff --git a/wrap/pybind11/README.md b/wrap/pybind11/README.md deleted file mode 100644 index 69a0fc90b..000000000 --- a/wrap/pybind11/README.md +++ /dev/null @@ -1,145 +0,0 @@ -![pybind11 logo](https://github.com/pybind/pybind11/raw/master/docs/pybind11-logo.png) - -# pybind11 — Seamless operability between C++11 and Python - -[![Documentation Status](https://readthedocs.org/projects/pybind11/badge/?version=master)](http://pybind11.readthedocs.org/en/master/?badge=master) -[![Documentation Status](https://readthedocs.org/projects/pybind11/badge/?version=stable)](http://pybind11.readthedocs.org/en/stable/?badge=stable) -[![Gitter chat](https://img.shields.io/gitter/room/gitterHQ/gitter.svg)](https://gitter.im/pybind/Lobby) -[![CI](https://github.com/pybind/pybind11/workflows/CI/badge.svg)](https://github.com/pybind/pybind11/actions) -[![Build status](https://ci.appveyor.com/api/projects/status/riaj54pn4h08xy40?svg=true)](https://ci.appveyor.com/project/wjakob/pybind11) - -**pybind11** is a lightweight header-only library that exposes C++ types in -Python and vice versa, mainly to create Python bindings of existing C++ code. -Its goals and syntax are similar to the excellent [Boost.Python][] library by -David Abrahams: to minimize boilerplate code in traditional extension modules -by inferring type information using compile-time introspection. - -The main issue with Boost.Python—and the reason for creating such a similar -project—is Boost. Boost is an enormously large and complex suite of utility -libraries that works with almost every C++ compiler in existence. This -compatibility has its cost: arcane template tricks and workarounds are -necessary to support the oldest and buggiest of compiler specimens. Now that -C++11-compatible compilers are widely available, this heavy machinery has -become an excessively large and unnecessary dependency. - -Think of this library as a tiny self-contained version of Boost.Python with -everything stripped away that isn't relevant for binding generation. Without -comments, the core header files only require ~4K lines of code and depend on -Python (2.7 or 3.5+, or PyPy) and the C++ standard library. This compact -implementation was possible thanks to some of the new C++11 language features -(specifically: tuples, lambda functions and variadic templates). Since its -creation, this library has grown beyond Boost.Python in many ways, leading to -dramatically simpler binding code in many common situations. - -Tutorial and reference documentation is provided at -[pybind11.readthedocs.org][]. A PDF version of the manual is available -[here][docs-pdf]. - -## Core features -pybind11 can map the following core C++ features to Python: - -- Functions accepting and returning custom data structures per value, reference, or pointer -- Instance methods and static methods -- Overloaded functions -- Instance attributes and static attributes -- Arbitrary exception types -- Enumerations -- Callbacks -- Iterators and ranges -- Custom operators -- Single and multiple inheritance -- STL data structures -- Smart pointers with reference counting like `std::shared_ptr` -- Internal references with correct reference counting -- C++ classes with virtual (and pure virtual) methods can be extended in Python - -## Goodies -In addition to the core functionality, pybind11 provides some extra goodies: - -- Python 2.7, 3.5+, and PyPy (tested on 7.3) are supported with an implementation-agnostic - interface. - -- It is possible to bind C++11 lambda functions with captured variables. The - lambda capture data is stored inside the resulting Python function object. - -- pybind11 uses C++11 move constructors and move assignment operators whenever - possible to efficiently transfer custom data types. - -- It's easy to expose the internal storage of custom data types through - Pythons' buffer protocols. This is handy e.g. for fast conversion between - C++ matrix classes like Eigen and NumPy without expensive copy operations. - -- pybind11 can automatically vectorize functions so that they are transparently - applied to all entries of one or more NumPy array arguments. - -- Python's slice-based access and assignment operations can be supported with - just a few lines of code. - -- Everything is contained in just a few header files; there is no need to link - against any additional libraries. - -- Binaries are generally smaller by a factor of at least 2 compared to - equivalent bindings generated by Boost.Python. A recent pybind11 conversion - of PyRosetta, an enormous Boost.Python binding project, - [reported][pyrosetta-report] a binary size reduction of **5.4x** and compile - time reduction by **5.8x**. - -- Function signatures are precomputed at compile time (using `constexpr`), - leading to smaller binaries. - -- With little extra effort, C++ types can be pickled and unpickled similar to - regular Python objects. - -## Supported compilers - -1. Clang/LLVM 3.3 or newer (for Apple Xcode's clang, this is 5.0.0 or newer) -2. GCC 4.8 or newer -3. Microsoft Visual Studio 2015 Update 3 or newer -4. Intel C++ compiler 17 or newer (16 with pybind11 v2.0 and 15 with pybind11 - v2.0 and a [workaround][intel-15-workaround]) -5. Cygwin/GCC (tested on 2.5.1) -6. NVCC (CUDA 11 tested) -7. NVIDIA PGI (20.7 tested) - -## About - -This project was created by [Wenzel Jakob](http://rgl.epfl.ch/people/wjakob). -Significant features and/or improvements to the code were contributed by -Jonas Adler, -Lori A. Burns, -Sylvain Corlay, -Trent Houliston, -Axel Huebl, -@hulucc, -Sergey Lyskov -Johan Mabille, -Tomasz Miąsko, -Dean Moldovan, -Ben Pritchard, -Jason Rhinelander, -Boris Schäling, -Pim Schellart, -Henry Schreiner, -Ivan Smirnov, and -Patrick Stewart. - -### Contributing - -See the [contributing guide][] for information on building and contributing to -pybind11. - - -### License - -pybind11 is provided under a BSD-style license that can be found in the -[`LICENSE`][] file. By using, distributing, or contributing to this project, -you agree to the terms and conditions of this license. - - -[pybind11.readthedocs.org]: http://pybind11.readthedocs.org/en/master -[docs-pdf]: https://media.readthedocs.org/pdf/pybind11/master/pybind11.pdf -[Boost.Python]: http://www.boost.org/doc/libs/1_58_0/libs/python/doc/ -[pyrosetta-report]: http://graylab.jhu.edu/RosettaCon2016/PyRosetta-4.pdf -[contributing guide]: https://github.com/pybind/pybind11/blob/master/.github/CONTRIBUTING.md -[`LICENSE`]: https://github.com/pybind/pybind11/blob/master/LICENSE -[intel-15-workaround]: https://github.com/pybind/pybind11/issues/276 diff --git a/wrap/pybind11/README.rst b/wrap/pybind11/README.rst new file mode 100644 index 000000000..45c4af5a6 --- /dev/null +++ b/wrap/pybind11/README.rst @@ -0,0 +1,180 @@ +.. figure:: https://github.com/pybind/pybind11/raw/master/docs/pybind11-logo.png + :alt: pybind11 logo + +**pybind11 — Seamless operability between C++11 and Python** + +|Latest Documentation Status| |Stable Documentation Status| |Gitter chat| |GitHub Discussions| |CI| |Build status| + +|Repology| |PyPI package| |Conda-forge| |Python Versions| + +`Setuptools example `_ +• `Scikit-build example `_ +• `CMake example `_ + +.. start + + +**pybind11** is a lightweight header-only library that exposes C++ types +in Python and vice versa, mainly to create Python bindings of existing +C++ code. Its goals and syntax are similar to the excellent +`Boost.Python `_ +library by David Abrahams: to minimize boilerplate code in traditional +extension modules by inferring type information using compile-time +introspection. + +The main issue with Boost.Python—and the reason for creating such a +similar project—is Boost. Boost is an enormously large and complex suite +of utility libraries that works with almost every C++ compiler in +existence. This compatibility has its cost: arcane template tricks and +workarounds are necessary to support the oldest and buggiest of compiler +specimens. Now that C++11-compatible compilers are widely available, +this heavy machinery has become an excessively large and unnecessary +dependency. + +Think of this library as a tiny self-contained version of Boost.Python +with everything stripped away that isn’t relevant for binding +generation. Without comments, the core header files only require ~4K +lines of code and depend on Python (2.7 or 3.5+, or PyPy) and the C++ +standard library. This compact implementation was possible thanks to +some of the new C++11 language features (specifically: tuples, lambda +functions and variadic templates). Since its creation, this library has +grown beyond Boost.Python in many ways, leading to dramatically simpler +binding code in many common situations. + +Tutorial and reference documentation is provided at +`pybind11.readthedocs.io `_. +A PDF version of the manual is available +`here `_. +And the source code is always available at +`github.com/pybind/pybind11 `_. + + +Core features +------------- + + +pybind11 can map the following core C++ features to Python: + +- Functions accepting and returning custom data structures per value, + reference, or pointer +- Instance methods and static methods +- Overloaded functions +- Instance attributes and static attributes +- Arbitrary exception types +- Enumerations +- Callbacks +- Iterators and ranges +- Custom operators +- Single and multiple inheritance +- STL data structures +- Smart pointers with reference counting like ``std::shared_ptr`` +- Internal references with correct reference counting +- C++ classes with virtual (and pure virtual) methods can be extended + in Python + +Goodies +------- + +In addition to the core functionality, pybind11 provides some extra +goodies: + +- Python 2.7, 3.5+, and PyPy/PyPy3 7.3 are supported with an + implementation-agnostic interface. + +- It is possible to bind C++11 lambda functions with captured + variables. The lambda capture data is stored inside the resulting + Python function object. + +- pybind11 uses C++11 move constructors and move assignment operators + whenever possible to efficiently transfer custom data types. + +- It’s easy to expose the internal storage of custom data types through + Pythons’ buffer protocols. This is handy e.g. for fast conversion + between C++ matrix classes like Eigen and NumPy without expensive + copy operations. + +- pybind11 can automatically vectorize functions so that they are + transparently applied to all entries of one or more NumPy array + arguments. + +- Python's slice-based access and assignment operations can be + supported with just a few lines of code. + +- Everything is contained in just a few header files; there is no need + to link against any additional libraries. + +- Binaries are generally smaller by a factor of at least 2 compared to + equivalent bindings generated by Boost.Python. A recent pybind11 + conversion of PyRosetta, an enormous Boost.Python binding project, + `reported `_ + a binary size reduction of **5.4x** and compile time reduction by + **5.8x**. + +- Function signatures are precomputed at compile time (using + ``constexpr``), leading to smaller binaries. + +- With little extra effort, C++ types can be pickled and unpickled + similar to regular Python objects. + +Supported compilers +------------------- + +1. Clang/LLVM 3.3 or newer (for Apple Xcode’s clang, this is 5.0.0 or + newer) +2. GCC 4.8 or newer +3. Microsoft Visual Studio 2015 Update 3 or newer +4. Intel classic C++ compiler 18 or newer (ICC 20.2 tested in CI) +5. Cygwin/GCC (previously tested on 2.5.1) +6. NVCC (CUDA 11.0 tested in CI) +7. NVIDIA PGI (20.9 tested in CI) + +About +----- + +This project was created by `Wenzel +Jakob `_. Significant features and/or +improvements to the code were contributed by Jonas Adler, Lori A. Burns, +Sylvain Corlay, Eric Cousineau, Aaron Gokaslan, Ralf Grosse-Kunstleve, Trent Houliston, Axel +Huebl, @hulucc, Yannick Jadoul, Sergey Lyskov Johan Mabille, Tomasz Miąsko, +Dean Moldovan, Ben Pritchard, Jason Rhinelander, Boris Schäling, Pim +Schellart, Henry Schreiner, Ivan Smirnov, Boris Staletic, and Patrick Stewart. + +We thank Google for a generous financial contribution to the continuous +integration infrastructure used by this project. + + +Contributing +~~~~~~~~~~~~ + +See the `contributing +guide `_ +for information on building and contributing to pybind11. + +License +~~~~~~~ + +pybind11 is provided under a BSD-style license that can be found in the +`LICENSE `_ +file. By using, distributing, or contributing to this project, you agree +to the terms and conditions of this license. + +.. |Latest Documentation Status| image:: https://readthedocs.org/projects/pybind11/badge?version=latest + :target: http://pybind11.readthedocs.org/en/latest +.. |Stable Documentation Status| image:: https://img.shields.io/badge/docs-stable-blue.svg + :target: http://pybind11.readthedocs.org/en/stable +.. |Gitter chat| image:: https://img.shields.io/gitter/room/gitterHQ/gitter.svg + :target: https://gitter.im/pybind/Lobby +.. |CI| image:: https://github.com/pybind/pybind11/workflows/CI/badge.svg + :target: https://github.com/pybind/pybind11/actions +.. |Build status| image:: https://ci.appveyor.com/api/projects/status/riaj54pn4h08xy40?svg=true + :target: https://ci.appveyor.com/project/wjakob/pybind11 +.. |PyPI package| image:: https://img.shields.io/pypi/v/pybind11.svg + :target: https://pypi.org/project/pybind11/ +.. |Conda-forge| image:: https://img.shields.io/conda/vn/conda-forge/pybind11.svg + :target: https://github.com/conda-forge/pybind11-feedstock +.. |Repology| image:: https://repology.org/badge/latest-versions/python:pybind11.svg + :target: https://repology.org/project/python:pybind11/versions +.. |Python Versions| image:: https://img.shields.io/pypi/pyversions/pybind11.svg + :target: https://pypi.org/project/pybind11/ +.. |GitHub Discussions| image:: https://img.shields.io/static/v1?label=Discussions&message=Ask&color=blue&logo=github + :target: https://github.com/pybind/pybind11/discussions diff --git a/wrap/pybind11/docs/Doxyfile b/wrap/pybind11/docs/Doxyfile index 24ece0d8d..62c267556 100644 --- a/wrap/pybind11/docs/Doxyfile +++ b/wrap/pybind11/docs/Doxyfile @@ -18,5 +18,5 @@ ALIASES += "endrst=\endverbatim" QUIET = YES WARNINGS = YES WARN_IF_UNDOCUMENTED = NO -PREDEFINED = DOXYGEN_SHOULD_SKIP_THIS \ - PY_MAJOR_VERSION=3 +PREDEFINED = PY_MAJOR_VERSION=3 \ + PYBIND11_NOINLINE diff --git a/wrap/pybind11/docs/advanced/cast/custom.rst b/wrap/pybind11/docs/advanced/cast/custom.rst index a779444c2..1df4d3e14 100644 --- a/wrap/pybind11/docs/advanced/cast/custom.rst +++ b/wrap/pybind11/docs/advanced/cast/custom.rst @@ -26,7 +26,9 @@ The following Python snippet demonstrates the intended usage from the Python sid def __int__(self): return 123 + from example import print + print(A()) To register the necessary conversion routines, it is necessary to add an @@ -44,7 +46,7 @@ type is explicitly allowed. * function signatures and declares a local variable * 'value' of type inty */ - PYBIND11_TYPE_CASTER(inty, _("inty")); + PYBIND11_TYPE_CASTER(inty, const_name("inty")); /** * Conversion part 1 (Python->C++): convert a PyObject into a inty diff --git a/wrap/pybind11/docs/advanced/cast/eigen.rst b/wrap/pybind11/docs/advanced/cast/eigen.rst index e01472d5a..a5c11a3f1 100644 --- a/wrap/pybind11/docs/advanced/cast/eigen.rst +++ b/wrap/pybind11/docs/advanced/cast/eigen.rst @@ -52,7 +52,7 @@ can be mapped *and* if the numpy array is writeable (that is the passed variable will be transparently carried out directly on the ``numpy.ndarray``. -This means you can can write code such as the following and have it work as +This means you can write code such as the following and have it work as expected: .. code-block:: cpp @@ -112,7 +112,7 @@ example: .. code-block:: python a = MyClass() - m = a.get_matrix() # flags.writeable = True, flags.owndata = False + m = a.get_matrix() # flags.writeable = True, flags.owndata = False v = a.view_matrix() # flags.writeable = False, flags.owndata = False c = a.copy_matrix() # flags.writeable = True, flags.owndata = True # m[5,6] and v[5,6] refer to the same element, c[5,6] does not. @@ -203,7 +203,7 @@ adding the ``order='F'`` option when creating an array: .. code-block:: python - myarray = np.array(source, order='F') + myarray = np.array(source, order="F") Such an object will be passable to a bound function accepting an ``Eigen::Ref`` (or similar column-major Eigen type). diff --git a/wrap/pybind11/docs/advanced/cast/overview.rst b/wrap/pybind11/docs/advanced/cast/overview.rst index b0e32a52f..6a834a3e5 100644 --- a/wrap/pybind11/docs/advanced/cast/overview.rst +++ b/wrap/pybind11/docs/advanced/cast/overview.rst @@ -75,91 +75,97 @@ The following basic data types are supported out of the box (some may require an additional extension header to be included). To pass other data structures as arguments and return values, refer to the section on binding :ref:`classes`. -+------------------------------------+---------------------------+-------------------------------+ -| Data type | Description | Header file | -+====================================+===========================+===============================+ -| ``int8_t``, ``uint8_t`` | 8-bit integers | :file:`pybind11/pybind11.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``int16_t``, ``uint16_t`` | 16-bit integers | :file:`pybind11/pybind11.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``int32_t``, ``uint32_t`` | 32-bit integers | :file:`pybind11/pybind11.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``int64_t``, ``uint64_t`` | 64-bit integers | :file:`pybind11/pybind11.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``ssize_t``, ``size_t`` | Platform-dependent size | :file:`pybind11/pybind11.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``float``, ``double`` | Floating point types | :file:`pybind11/pybind11.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``bool`` | Two-state Boolean type | :file:`pybind11/pybind11.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``char`` | Character literal | :file:`pybind11/pybind11.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``char16_t`` | UTF-16 character literal | :file:`pybind11/pybind11.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``char32_t`` | UTF-32 character literal | :file:`pybind11/pybind11.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``wchar_t`` | Wide character literal | :file:`pybind11/pybind11.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``const char *`` | UTF-8 string literal | :file:`pybind11/pybind11.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``const char16_t *`` | UTF-16 string literal | :file:`pybind11/pybind11.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``const char32_t *`` | UTF-32 string literal | :file:`pybind11/pybind11.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``const wchar_t *`` | Wide string literal | :file:`pybind11/pybind11.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``std::string`` | STL dynamic UTF-8 string | :file:`pybind11/pybind11.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``std::u16string`` | STL dynamic UTF-16 string | :file:`pybind11/pybind11.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``std::u32string`` | STL dynamic UTF-32 string | :file:`pybind11/pybind11.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``std::wstring`` | STL dynamic wide string | :file:`pybind11/pybind11.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``std::string_view``, | STL C++17 string views | :file:`pybind11/pybind11.h` | -| ``std::u16string_view``, etc. | | | -+------------------------------------+---------------------------+-------------------------------+ -| ``std::pair`` | Pair of two custom types | :file:`pybind11/pybind11.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``std::tuple<...>`` | Arbitrary tuple of types | :file:`pybind11/pybind11.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``std::reference_wrapper<...>`` | Reference type wrapper | :file:`pybind11/pybind11.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``std::complex`` | Complex numbers | :file:`pybind11/complex.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``std::array`` | STL static array | :file:`pybind11/stl.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``std::vector`` | STL dynamic array | :file:`pybind11/stl.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``std::deque`` | STL double-ended queue | :file:`pybind11/stl.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``std::valarray`` | STL value array | :file:`pybind11/stl.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``std::list`` | STL linked list | :file:`pybind11/stl.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``std::map`` | STL ordered map | :file:`pybind11/stl.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``std::unordered_map`` | STL unordered map | :file:`pybind11/stl.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``std::set`` | STL ordered set | :file:`pybind11/stl.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``std::unordered_set`` | STL unordered set | :file:`pybind11/stl.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``std::optional`` | STL optional type (C++17) | :file:`pybind11/stl.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``std::experimental::optional`` | STL optional type (exp.) | :file:`pybind11/stl.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``std::variant<...>`` | Type-safe union (C++17) | :file:`pybind11/stl.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``std::function<...>`` | STL polymorphic function | :file:`pybind11/functional.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``std::chrono::duration<...>`` | STL time duration | :file:`pybind11/chrono.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``std::chrono::time_point<...>`` | STL date/time | :file:`pybind11/chrono.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``Eigen::Matrix<...>`` | Eigen: dense matrix | :file:`pybind11/eigen.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``Eigen::Map<...>`` | Eigen: mapped memory | :file:`pybind11/eigen.h` | -+------------------------------------+---------------------------+-------------------------------+ -| ``Eigen::SparseMatrix<...>`` | Eigen: sparse matrix | :file:`pybind11/eigen.h` | -+------------------------------------+---------------------------+-------------------------------+ ++------------------------------------+---------------------------+-----------------------------------+ +| Data type | Description | Header file | ++====================================+===========================+===================================+ +| ``int8_t``, ``uint8_t`` | 8-bit integers | :file:`pybind11/pybind11.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``int16_t``, ``uint16_t`` | 16-bit integers | :file:`pybind11/pybind11.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``int32_t``, ``uint32_t`` | 32-bit integers | :file:`pybind11/pybind11.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``int64_t``, ``uint64_t`` | 64-bit integers | :file:`pybind11/pybind11.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``ssize_t``, ``size_t`` | Platform-dependent size | :file:`pybind11/pybind11.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``float``, ``double`` | Floating point types | :file:`pybind11/pybind11.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``bool`` | Two-state Boolean type | :file:`pybind11/pybind11.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``char`` | Character literal | :file:`pybind11/pybind11.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``char16_t`` | UTF-16 character literal | :file:`pybind11/pybind11.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``char32_t`` | UTF-32 character literal | :file:`pybind11/pybind11.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``wchar_t`` | Wide character literal | :file:`pybind11/pybind11.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``const char *`` | UTF-8 string literal | :file:`pybind11/pybind11.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``const char16_t *`` | UTF-16 string literal | :file:`pybind11/pybind11.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``const char32_t *`` | UTF-32 string literal | :file:`pybind11/pybind11.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``const wchar_t *`` | Wide string literal | :file:`pybind11/pybind11.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``std::string`` | STL dynamic UTF-8 string | :file:`pybind11/pybind11.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``std::u16string`` | STL dynamic UTF-16 string | :file:`pybind11/pybind11.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``std::u32string`` | STL dynamic UTF-32 string | :file:`pybind11/pybind11.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``std::wstring`` | STL dynamic wide string | :file:`pybind11/pybind11.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``std::string_view``, | STL C++17 string views | :file:`pybind11/pybind11.h` | +| ``std::u16string_view``, etc. | | | ++------------------------------------+---------------------------+-----------------------------------+ +| ``std::pair`` | Pair of two custom types | :file:`pybind11/pybind11.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``std::tuple<...>`` | Arbitrary tuple of types | :file:`pybind11/pybind11.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``std::reference_wrapper<...>`` | Reference type wrapper | :file:`pybind11/pybind11.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``std::complex`` | Complex numbers | :file:`pybind11/complex.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``std::array`` | STL static array | :file:`pybind11/stl.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``std::vector`` | STL dynamic array | :file:`pybind11/stl.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``std::deque`` | STL double-ended queue | :file:`pybind11/stl.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``std::valarray`` | STL value array | :file:`pybind11/stl.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``std::list`` | STL linked list | :file:`pybind11/stl.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``std::map`` | STL ordered map | :file:`pybind11/stl.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``std::unordered_map`` | STL unordered map | :file:`pybind11/stl.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``std::set`` | STL ordered set | :file:`pybind11/stl.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``std::unordered_set`` | STL unordered set | :file:`pybind11/stl.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``std::optional`` | STL optional type (C++17) | :file:`pybind11/stl.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``std::experimental::optional`` | STL optional type (exp.) | :file:`pybind11/stl.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``std::variant<...>`` | Type-safe union (C++17) | :file:`pybind11/stl.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``std::filesystem::path`` | STL path (C++17) [#]_ | :file:`pybind11/stl/filesystem.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``std::function<...>`` | STL polymorphic function | :file:`pybind11/functional.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``std::chrono::duration<...>`` | STL time duration | :file:`pybind11/chrono.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``std::chrono::time_point<...>`` | STL date/time | :file:`pybind11/chrono.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``Eigen::Matrix<...>`` | Eigen: dense matrix | :file:`pybind11/eigen.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``Eigen::Map<...>`` | Eigen: mapped memory | :file:`pybind11/eigen.h` | ++------------------------------------+---------------------------+-----------------------------------+ +| ``Eigen::SparseMatrix<...>`` | Eigen: sparse matrix | :file:`pybind11/eigen.h` | ++------------------------------------+---------------------------+-----------------------------------+ + +.. [#] ``std::filesystem::path`` is converted to ``pathlib.Path`` and + ``os.PathLike`` is converted to ``std::filesystem::path``, but this requires + Python 3.6 (for ``__fspath__`` support). diff --git a/wrap/pybind11/docs/advanced/cast/stl.rst b/wrap/pybind11/docs/advanced/cast/stl.rst index 7f708b81e..b8622ee09 100644 --- a/wrap/pybind11/docs/advanced/cast/stl.rst +++ b/wrap/pybind11/docs/advanced/cast/stl.rst @@ -5,7 +5,7 @@ Automatic conversion ==================== When including the additional header file :file:`pybind11/stl.h`, conversions -between ``std::vector<>``/``std::deque<>``/``std::list<>``/``std::array<>``, +between ``std::vector<>``/``std::deque<>``/``std::list<>``/``std::array<>``/``std::valarray<>``, ``std::set<>``/``std::unordered_set<>``, and ``std::map<>``/``std::unordered_map<>`` and the Python ``list``, ``set`` and ``dict`` data structures are automatically enabled. The types ``std::pair<>`` @@ -72,6 +72,17 @@ The ``visit_helper`` specialization is not required if your ``name::variant`` pr a ``name::visit()`` function. For any other function name, the specialization must be included to tell pybind11 how to visit the variant. +.. warning:: + + When converting a ``variant`` type, pybind11 follows the same rules as when + determining which function overload to call (:ref:`overload_resolution`), and + so the same caveats hold. In particular, the order in which the ``variant``'s + alternatives are listed is important, since pybind11 will try conversions in + this order. This means that, for example, when converting ``variant``, + the ``bool`` variant will never be selected, as any Python ``bool`` is already + an ``int`` and is convertible to a C++ ``int``. Changing the order of alternatives + (and using ``variant``, in this example) provides a solution. + .. note:: pybind11 only supports the modern implementation of ``boost::variant`` diff --git a/wrap/pybind11/docs/advanced/cast/strings.rst b/wrap/pybind11/docs/advanced/cast/strings.rst index e25701eca..cfd7e7b7a 100644 --- a/wrap/pybind11/docs/advanced/cast/strings.rst +++ b/wrap/pybind11/docs/advanced/cast/strings.rst @@ -36,13 +36,13 @@ everywhere `_. } ); -.. code-block:: python +.. code-block:: pycon - >>> utf8_test('🎂') + >>> utf8_test("🎂") utf-8 is icing on the cake. 🎂 - >>> utf8_charptr('🍕') + >>> utf8_charptr("🍕") My favorite food is 🍕 @@ -80,7 +80,7 @@ raise a ``UnicodeDecodeError``. } ); -.. code-block:: python +.. code-block:: pycon >>> isinstance(example.std_string_return(), str) True @@ -114,7 +114,7 @@ conversion has the same overhead as implicit conversion. } ); -.. code-block:: python +.. code-block:: pycon >>> str_output() 'Send your résumé to Alice in HR' @@ -143,7 +143,7 @@ returned to Python as ``bytes``, then one can return the data as a } ); -.. code-block:: python +.. code-block:: pycon >>> example.return_bytes() b'\xba\xd0\xba\xd0' @@ -160,7 +160,7 @@ encoding, but cannot convert ``std::string`` back to ``bytes`` implicitly. } ); -.. code-block:: python +.. code-block:: pycon >>> isinstance(example.asymmetry(b"have some bytes"), str) True @@ -229,16 +229,16 @@ character. m.def("pass_char", [](char c) { return c; }); m.def("pass_wchar", [](wchar_t w) { return w; }); -.. code-block:: python +.. code-block:: pycon - >>> example.pass_char('A') + >>> example.pass_char("A") 'A' While C++ will cast integers to character types (``char c = 0x65;``), pybind11 does not convert Python integers to characters implicitly. The Python function ``chr()`` can be used to convert integers to characters. -.. code-block:: python +.. code-block:: pycon >>> example.pass_char(0x65) TypeError @@ -259,17 +259,17 @@ a combining acute accent). The combining character will be lost if the two-character sequence is passed as an argument, even though it renders as a single grapheme. -.. code-block:: python +.. code-block:: pycon - >>> example.pass_wchar('é') + >>> example.pass_wchar("é") 'é' - >>> combining_e_acute = 'e' + '\u0301' + >>> combining_e_acute = "e" + "\u0301" >>> combining_e_acute 'é' - >>> combining_e_acute == 'é' + >>> combining_e_acute == "é" False >>> example.pass_wchar(combining_e_acute) @@ -278,9 +278,9 @@ single grapheme. Normalizing combining characters before passing the character literal to C++ may resolve *some* of these issues: -.. code-block:: python +.. code-block:: pycon - >>> example.pass_wchar(unicodedata.normalize('NFC', combining_e_acute)) + >>> example.pass_wchar(unicodedata.normalize("NFC", combining_e_acute)) 'é' In some languages (Thai for example), there are `graphemes that cannot be diff --git a/wrap/pybind11/docs/advanced/classes.rst b/wrap/pybind11/docs/advanced/classes.rst index 492790206..f3339336d 100644 --- a/wrap/pybind11/docs/advanced/classes.rst +++ b/wrap/pybind11/docs/advanced/classes.rst @@ -9,7 +9,7 @@ that you are already familiar with the basics from :doc:`/classes`. Overriding virtual functions in Python ====================================== -Suppose that a C++ class or interface has a virtual function that we'd like to +Suppose that a C++ class or interface has a virtual function that we'd like to override from within Python (we'll focus on the class ``Animal``; ``Dog`` is given as a specific example of how one would do this with traditional C++ code). @@ -136,7 +136,7 @@ a virtual method call. u'woof! woof! woof! ' >>> class Cat(Animal): ... def go(self, n_times): - ... return "meow! " * n_times + ... return "meow! " * n_times ... >>> c = Cat() >>> call_go(c) @@ -159,8 +159,9 @@ Here is an example: class Dachshund(Dog): def __init__(self, name): - Dog.__init__(self) # Without this, a TypeError is raised. + Dog.__init__(self) # Without this, a TypeError is raised. self.name = name + def bark(self): return "yap!" @@ -259,7 +260,7 @@ override the ``name()`` method): .. note:: - Note the trailing commas in the ``PYBIND11_OVERIDE`` calls to ``name()`` + Note the trailing commas in the ``PYBIND11_OVERRIDE`` calls to ``name()`` and ``bark()``. These are needed to portably implement a trampoline for a function that does not take any arguments. For functions that take a nonzero number of arguments, the trailing comma must be omitted. @@ -804,7 +805,7 @@ to bind these two functions: } )); -The ``__setstate__`` part of the ``py::picke()`` definition follows the same +The ``__setstate__`` part of the ``py::pickle()`` definition follows the same rules as the single-argument version of ``py::init()``. The return type can be a value, pointer or holder type. See :ref:`custom_constructors` for details. @@ -1153,12 +1154,65 @@ error: >>> class PyFinalChild(IsFinal): ... pass + ... TypeError: type 'IsFinal' is not an acceptable base type .. note:: This attribute is currently ignored on PyPy .. versionadded:: 2.6 +Binding classes with template parameters +======================================== + +pybind11 can also wrap classes that have template parameters. Consider these classes: + +.. code-block:: cpp + + struct Cat {}; + struct Dog {}; + + template + struct Cage { + Cage(PetType& pet); + PetType& get(); + }; + +C++ templates may only be instantiated at compile time, so pybind11 can only +wrap instantiated templated classes. You cannot wrap a non-instantiated template: + +.. code-block:: cpp + + // BROKEN (this will not compile) + py::class_(m, "Cage"); + .def("get", &Cage::get); + +You must explicitly specify each template/type combination that you want to +wrap separately. + +.. code-block:: cpp + + // ok + py::class_>(m, "CatCage") + .def("get", &Cage::get); + + // ok + py::class_>(m, "DogCage") + .def("get", &Cage::get); + +If your class methods have template parameters you can wrap those as well, +but once again each instantiation must be explicitly specified: + +.. code-block:: cpp + + typename + struct MyClass { + template + T fn(V v); + }; + + py::class>(m, "MyClassT") + .def("fn", &MyClass::fn); + Custom automatic downcasters ============================ @@ -1247,7 +1301,7 @@ Accessing the type object You can get the type object from a C++ class that has already been registered using: -.. code-block:: python +.. code-block:: cpp py::type T_py = py::type::of(); @@ -1259,3 +1313,37 @@ object, just like ``type(ob)`` in Python. Other types, like ``py::type::of()``, do not work, see :ref:`type-conversions`. .. versionadded:: 2.6 + +Custom type setup +================= + +For advanced use cases, such as enabling garbage collection support, you may +wish to directly manipulate the ``PyHeapTypeObject`` corresponding to a +``py::class_`` definition. + +You can do that using ``py::custom_type_setup``: + +.. code-block:: cpp + + struct OwnsPythonObjects { + py::object value = py::none(); + }; + py::class_ cls( + m, "OwnsPythonObjects", py::custom_type_setup([](PyHeapTypeObject *heap_type) { + auto *type = &heap_type->ht_type; + type->tp_flags |= Py_TPFLAGS_HAVE_GC; + type->tp_traverse = [](PyObject *self_base, visitproc visit, void *arg) { + auto &self = py::cast(py::handle(self_base)); + Py_VISIT(self.value.ptr()); + return 0; + }; + type->tp_clear = [](PyObject *self_base) { + auto &self = py::cast(py::handle(self_base)); + self.value = py::none(); + return 0; + }; + })); + cls.def(py::init<>()); + cls.def_readwrite("value", &OwnsPythonObjects::value); + +.. versionadded:: 2.8 diff --git a/wrap/pybind11/docs/advanced/embedding.rst b/wrap/pybind11/docs/advanced/embedding.rst index 98a5c5219..dd980d483 100644 --- a/wrap/pybind11/docs/advanced/embedding.rst +++ b/wrap/pybind11/docs/advanced/embedding.rst @@ -40,15 +40,15 @@ The essential structure of the ``main.cpp`` file looks like this: } The interpreter must be initialized before using any Python API, which includes -all the functions and classes in pybind11. The RAII guard class `scoped_interpreter` +all the functions and classes in pybind11. The RAII guard class ``scoped_interpreter`` takes care of the interpreter lifetime. After the guard is destroyed, the interpreter shuts down and clears its memory. No Python functions can be called after this. Executing Python code ===================== -There are a few different ways to run Python code. One option is to use `eval`, -`exec` or `eval_file`, as explained in :ref:`eval`. Here is a quick example in +There are a few different ways to run Python code. One option is to use ``eval``, +``exec`` or ``eval_file``, as explained in :ref:`eval`. Here is a quick example in the context of an executable with an embedded interpreter: .. code-block:: cpp @@ -108,11 +108,11 @@ The two approaches can also be combined: Importing modules ================= -Python modules can be imported using `module::import()`: +Python modules can be imported using ``module_::import()``: .. code-block:: cpp - py::module sys = py::module::import("sys"); + py::module_ sys = py::module_::import("sys"); py::print(sys.attr("path")); For convenience, the current working directory is included in ``sys.path`` when @@ -122,18 +122,19 @@ embedding the interpreter. This makes it easy to import local Python files: """calc.py located in the working directory""" + def add(i, j): return i + j .. code-block:: cpp - py::module calc = py::module::import("calc"); + py::module_ calc = py::module_::import("calc"); py::object result = calc.attr("add")(1, 2); int n = result.cast(); assert(n == 3); -Modules can be reloaded using `module::reload()` if the source is modified e.g. +Modules can be reloaded using ``module_::reload()`` if the source is modified e.g. by an external process. This can be useful in scenarios where the application imports a user defined data processing script which needs to be updated after changes by the user. Note that this function does not reload modules recursively. @@ -143,7 +144,7 @@ changes by the user. Note that this function does not reload modules recursively Adding embedded modules ======================= -Embedded binary modules can be added using the `PYBIND11_EMBEDDED_MODULE` macro. +Embedded binary modules can be added using the ``PYBIND11_EMBEDDED_MODULE`` macro. Note that the definition must be placed at global scope. They can be imported like any other module. @@ -153,7 +154,7 @@ like any other module. namespace py = pybind11; PYBIND11_EMBEDDED_MODULE(fast_calc, m) { - // `m` is a `py::module` which is used to bind functions and classes + // `m` is a `py::module_` which is used to bind functions and classes m.def("add", [](int i, int j) { return i + j; }); @@ -162,14 +163,14 @@ like any other module. int main() { py::scoped_interpreter guard{}; - auto fast_calc = py::module::import("fast_calc"); + auto fast_calc = py::module_::import("fast_calc"); auto result = fast_calc.attr("add")(1, 2).cast(); assert(result == 3); } Unlike extension modules where only a single binary module can be created, on the embedded side an unlimited number of modules can be added using multiple -`PYBIND11_EMBEDDED_MODULE` definitions (as long as they have unique names). +``PYBIND11_EMBEDDED_MODULE`` definitions (as long as they have unique names). These modules are added to Python's list of builtins, so they can also be imported in pure Python files loaded by the interpreter. Everything interacts @@ -196,7 +197,7 @@ naturally: int main() { py::scoped_interpreter guard{}; - auto py_module = py::module::import("py_module"); + auto py_module = py::module_::import("py_module"); auto locals = py::dict("fmt"_a="{} + {} = {}", **py_module.attr("__dict__")); assert(locals["a"].cast() == 1); @@ -215,9 +216,9 @@ naturally: Interpreter lifetime ==================== -The Python interpreter shuts down when `scoped_interpreter` is destroyed. After +The Python interpreter shuts down when ``scoped_interpreter`` is destroyed. After this, creating a new instance will restart the interpreter. Alternatively, the -`initialize_interpreter` / `finalize_interpreter` pair of functions can be used +``initialize_interpreter`` / ``finalize_interpreter`` pair of functions can be used to directly set the state at any time. Modules created with pybind11 can be safely re-initialized after the interpreter @@ -229,8 +230,8 @@ global data. All the details can be found in the CPython documentation. .. warning:: - Creating two concurrent `scoped_interpreter` guards is a fatal error. So is - calling `initialize_interpreter` for a second time after the interpreter + Creating two concurrent ``scoped_interpreter`` guards is a fatal error. So is + calling ``initialize_interpreter`` for a second time after the interpreter has already been initialized. Do not use the raw CPython API functions ``Py_Initialize`` and @@ -241,7 +242,7 @@ global data. All the details can be found in the CPython documentation. Sub-interpreter support ======================= -Creating multiple copies of `scoped_interpreter` is not possible because it +Creating multiple copies of ``scoped_interpreter`` is not possible because it represents the main Python interpreter. Sub-interpreters are something different and they do permit the existence of multiple interpreters. This is an advanced feature of the CPython API and should be handled with care. pybind11 does not @@ -257,5 +258,5 @@ We'll just mention a couple of caveats the sub-interpreters support in pybind11: 2. Managing multiple threads, multiple interpreters and the GIL can be challenging and there are several caveats here, even within the pure CPython API (please refer to the Python docs for details). As for - pybind11, keep in mind that `gil_scoped_release` and `gil_scoped_acquire` + pybind11, keep in mind that ``gil_scoped_release`` and ``gil_scoped_acquire`` do not take sub-interpreters into account. diff --git a/wrap/pybind11/docs/advanced/exceptions.rst b/wrap/pybind11/docs/advanced/exceptions.rst index a96f8e8f4..7cd8447b9 100644 --- a/wrap/pybind11/docs/advanced/exceptions.rst +++ b/wrap/pybind11/docs/advanced/exceptions.rst @@ -43,18 +43,28 @@ at its exception handler. | | of bounds access in ``__getitem__``, | | | ``__setitem__``, etc.) | +--------------------------------------+--------------------------------------+ -| :class:`pybind11::value_error` | ``ValueError`` (used to indicate | -| | wrong value passed in | -| | ``container.remove(...)``) | -+--------------------------------------+--------------------------------------+ | :class:`pybind11::key_error` | ``KeyError`` (used to indicate out | | | of bounds access in ``__getitem__``, | | | ``__setitem__`` in dict-like | | | objects, etc.) | +--------------------------------------+--------------------------------------+ +| :class:`pybind11::value_error` | ``ValueError`` (used to indicate | +| | wrong value passed in | +| | ``container.remove(...)``) | ++--------------------------------------+--------------------------------------+ +| :class:`pybind11::type_error` | ``TypeError`` | ++--------------------------------------+--------------------------------------+ +| :class:`pybind11::buffer_error` | ``BufferError`` | ++--------------------------------------+--------------------------------------+ +| :class:`pybind11::import_error` | ``ImportError`` | ++--------------------------------------+--------------------------------------+ +| :class:`pybind11::attribute_error` | ``AttributeError`` | ++--------------------------------------+--------------------------------------+ +| Any other exception | ``RuntimeError`` | ++--------------------------------------+--------------------------------------+ Exception translation is not bidirectional. That is, *catching* the C++ -exceptions defined above above will not trap exceptions that originate from +exceptions defined above will not trap exceptions that originate from Python. For that, catch :class:`pybind11::error_already_set`. See :ref:`below ` for further details. @@ -67,9 +77,10 @@ Registering custom translators If the default exception conversion policy described above is insufficient, pybind11 also provides support for registering custom exception translators. -To register a simple exception conversion that translates a C++ exception into -a new Python exception using the C++ exception's ``what()`` method, a helper -function is available: +Similar to pybind11 classes, exception translators can be local to the module +they are defined in or global to the entire python session. To register a simple +exception conversion that translates a C++ exception into a new Python exception +using the C++ exception's ``what()`` method, a helper function is available: .. code-block:: cpp @@ -79,29 +90,39 @@ This call creates a Python exception class with the name ``PyExp`` in the given module and automatically converts any encountered exceptions of type ``CppExp`` into Python exceptions of type ``PyExp``. +A matching function is available for registering a local exception translator: + +.. code-block:: cpp + + py::register_local_exception(module, "PyExp"); + + It is possible to specify base class for the exception using the third -parameter, a `handle`: +parameter, a ``handle``: .. code-block:: cpp py::register_exception(module, "PyExp", PyExc_RuntimeError); + py::register_local_exception(module, "PyExp", PyExc_RuntimeError); -Then `PyExp` can be caught both as `PyExp` and `RuntimeError`. +Then ``PyExp`` can be caught both as ``PyExp`` and ``RuntimeError``. The class objects of the built-in Python exceptions are listed in the Python documentation on `Standard Exceptions `_. -The default base class is `PyExc_Exception`. +The default base class is ``PyExc_Exception``. -When more advanced exception translation is needed, the function -``py::register_exception_translator(translator)`` can be used to register +When more advanced exception translation is needed, the functions +``py::register_exception_translator(translator)`` and +``py::register_local_exception_translator(translator)`` can be used to register functions that can translate arbitrary exception types (and which may include -additional logic to do so). The function takes a stateless callable (e.g. a +additional logic to do so). The functions takes a stateless callable (e.g. a function pointer or a lambda function without captured variables) with the call signature ``void(std::exception_ptr)``. When a C++ exception is thrown, the registered exception translators are tried in reverse order of registration (i.e. the last registered translator gets the -first shot at handling the exception). +first shot at handling the exception). All local translators will be tried +before a global translator is tried. Inside the translator, ``std::rethrow_exception`` should be used within a try block to re-throw the exception. One or more catch clauses to catch @@ -156,6 +177,57 @@ section. may be explicitly (re-)thrown to delegate it to the other, previously-declared existing exception translators. + Note that ``libc++`` and ``libstdc++`` `behave differently `_ + with ``-fvisibility=hidden``. Therefore exceptions that are used across ABI boundaries need to be explicitly exported, as exercised in ``tests/test_exceptions.h``. + See also: "Problems with C++ exceptions" under `GCC Wiki `_. + + +Local vs Global Exception Translators +===================================== + +When a global exception translator is registered, it will be applied across all +modules in the reverse order of registration. This can create behavior where the +order of module import influences how exceptions are translated. + +If module1 has the following translator: + +.. code-block:: cpp + + py::register_exception_translator([](std::exception_ptr p) { + try { + if (p) std::rethrow_exception(p); + } catch (const std::invalid_argument &e) { + PyErr_SetString("module1 handled this") + } + } + +and module2 has the following similar translator: + +.. code-block:: cpp + + py::register_exception_translator([](std::exception_ptr p) { + try { + if (p) std::rethrow_exception(p); + } catch (const std::invalid_argument &e) { + PyErr_SetString("module2 handled this") + } + } + +then which translator handles the invalid_argument will be determined by the +order that module1 and module2 are imported. Since exception translators are +applied in the reverse order of registration, which ever module was imported +last will "win" and that translator will be applied. + +If there are multiple pybind11 modules that share exception types (either +standard built-in or custom) loaded into a single python instance and +consistent error handling behavior is needed, then local translators should be +used. + +Changing the previous example to use ``register_local_exception_translator`` +would mean that when invalid_argument is thrown in the module2 code, the +module2 translator will always handle it, while in module1, the module1 +translator will do the same. + .. _handling_python_exceptions_cpp: Handling exceptions from Python in C++ @@ -182,13 +254,13 @@ For example: try { // open("missing.txt", "r") - auto file = py::module::import("io").attr("open")("missing.txt", "r"); + auto file = py::module_::import("io").attr("open")("missing.txt", "r"); auto text = file.attr("read")(); file.attr("close")(); } catch (py::error_already_set &e) { if (e.matches(PyExc_FileNotFoundError)) { py::print("missing.txt not found"); - } else if (e.match(PyExc_PermissionError)) { + } else if (e.matches(PyExc_PermissionError)) { py::print("missing.txt found but not accessible"); } else { throw; @@ -253,6 +325,34 @@ Alternately, to ignore the error, call `PyErr_Clear Any Python error must be thrown or cleared, or Python/pybind11 will be left in an invalid state. +Chaining exceptions ('raise from') +================================== + +In Python 3.3 a mechanism for indicating that exceptions were caused by other +exceptions was introduced: + +.. code-block:: py + + try: + print(1 / 0) + except Exception as exc: + raise RuntimeError("could not divide by zero") from exc + +To do a similar thing in pybind11, you can use the ``py::raise_from`` function. It +sets the current python error indicator, so to continue propagating the exception +you should ``throw py::error_already_set()`` (Python 3 only). + +.. code-block:: cpp + + try { + py::eval("print(1 / 0")); + } catch (py::error_already_set &e) { + py::raise_from(e, PyExc_RuntimeError, "could not divide by zero"); + throw py::error_already_set(); + } + +.. versionadded:: 2.8 + .. _unraisable_exceptions: Handling unraisable exceptions diff --git a/wrap/pybind11/docs/advanced/functions.rst b/wrap/pybind11/docs/advanced/functions.rst index c895517c5..bf5b5fa00 100644 --- a/wrap/pybind11/docs/advanced/functions.rst +++ b/wrap/pybind11/docs/advanced/functions.rst @@ -17,7 +17,7 @@ bindings for functions that return a non-trivial type. Just by looking at the type information, it is not clear whether Python should take charge of the returned value and eventually free its resources, or if this is handled on the C++ side. For this reason, pybind11 provides a several *return value policy* -annotations that can be passed to the :func:`module::def` and +annotations that can be passed to the :func:`module_::def` and :func:`class_::def` functions. The default policy is :enum:`return_value_policy::automatic`. @@ -50,7 +50,7 @@ implied transfer of ownership, i.e.: .. code-block:: cpp - m.def("get_data", &get_data, return_value_policy::reference); + m.def("get_data", &get_data, py::return_value_policy::reference); On the other hand, this is not the right policy for many other situations, where ignoring ownership could lead to resource leaks. @@ -90,17 +90,18 @@ The following table provides an overview of available policies: | | return value is referenced by Python. This is the default policy for | | | property getters created via ``def_property``, ``def_readwrite``, etc. | +--------------------------------------------------+----------------------------------------------------------------------------+ -| :enum:`return_value_policy::automatic` | **Default policy.** This policy falls back to the policy | +| :enum:`return_value_policy::automatic` | This policy falls back to the policy | | | :enum:`return_value_policy::take_ownership` when the return value is a | | | pointer. Otherwise, it uses :enum:`return_value_policy::move` or | | | :enum:`return_value_policy::copy` for rvalue and lvalue references, | | | respectively. See above for a description of what all of these different | -| | policies do. | +| | policies do. This is the default policy for ``py::class_``-wrapped types. | +--------------------------------------------------+----------------------------------------------------------------------------+ | :enum:`return_value_policy::automatic_reference` | As above, but use policy :enum:`return_value_policy::reference` when the | | | return value is a pointer. This is the default conversion policy for | | | function arguments when calling Python functions manually from C++ code | -| | (i.e. via handle::operator()). You probably won't need to use this. | +| | (i.e. via ``handle::operator()``) and the casters in ``pybind11/stl.h``. | +| | You probably won't need to use this explicitly. | +--------------------------------------------------+----------------------------------------------------------------------------+ Return value policies can also be applied to properties: @@ -119,7 +120,7 @@ targeted arguments can be passed through the :class:`cpp_function` constructor: .. code-block:: cpp class_(m, "MyClass") - .def_property("data" + .def_property("data", py::cpp_function(&MyClass::getData, py::return_value_policy::copy), py::cpp_function(&MyClass::setData) ); @@ -182,6 +183,9 @@ relies on the ability to create a *weak reference* to the nurse object. When the nurse object is not a pybind11-registered type and does not support weak references, an exception will be thrown. +If you use an incorrect argument index, you will get a ``RuntimeError`` saying +``Could not activate keep_alive!``. You should review the indices you're using. + Consider the following example: here, the binding code for a list append operation ties the lifetime of the newly added element to the underlying container: @@ -228,7 +232,7 @@ is equivalent to the following pseudocode: }); The only requirement is that ``T`` is default-constructible, but otherwise any -scope guard will work. This is very useful in combination with `gil_scoped_release`. +scope guard will work. This is very useful in combination with ``gil_scoped_release``. See :ref:`gil`. Multiple guards can also be specified as ``py::call_guard``. The @@ -251,7 +255,7 @@ For instance, the following statement iterates over a Python ``dict``: .. code-block:: cpp - void print_dict(py::dict dict) { + void print_dict(const py::dict& dict) { /* Easily interact with Python types */ for (auto item : dict) std::cout << "key=" << std::string(py::str(item.first)) << ", " @@ -268,7 +272,7 @@ And used in Python as usual: .. code-block:: pycon - >>> print_dict({'foo': 123, 'bar': 'hello'}) + >>> print_dict({"foo": 123, "bar": "hello"}) key=foo, value=123 key=bar, value=hello @@ -289,7 +293,7 @@ Such functions can also be created using pybind11: .. code-block:: cpp - void generic(py::args args, py::kwargs kwargs) { + void generic(py::args args, const py::kwargs& kwargs) { /// .. do something with args if (kwargs) /// .. do something with kwargs @@ -302,8 +306,9 @@ The class ``py::args`` derives from ``py::tuple`` and ``py::kwargs`` derives from ``py::dict``. You may also use just one or the other, and may combine these with other -arguments as long as the ``py::args`` and ``py::kwargs`` arguments are the last -arguments accepted by the function. +arguments. Note, however, that ``py::kwargs`` must always be the last argument +of the function, and ``py::args`` implies that any further arguments are +keyword-only (see :ref:`keyword_only_arguments`). Please refer to the other examples for details on how to iterate over these, and on how to cast their entries into C++ objects. A demonstration is also @@ -362,6 +367,8 @@ like so: py::class_("MyClass") .def("myFunction", py::arg("arg") = static_cast(nullptr)); +.. _keyword_only_arguments: + Keyword-only arguments ====================== @@ -373,10 +380,11 @@ argument in a function definition: def f(a, *, b): # a can be positional or via keyword; b must be via keyword pass + f(a=1, b=2) # good f(b=2, a=1) # good - f(1, b=2) # good - f(1, 2) # TypeError: f() takes 1 positional argument but 2 were given + f(1, b=2) # good + f(1, 2) # TypeError: f() takes 1 positional argument but 2 were given Pybind11 provides a ``py::kw_only`` object that allows you to implement the same behaviour by specifying the object between positional and keyword-only @@ -392,6 +400,15 @@ feature does *not* require Python 3 to work. .. versionadded:: 2.6 +As of pybind11 2.9, a ``py::args`` argument implies that any following arguments +are keyword-only, as if ``py::kw_only()`` had been specified in the same +relative location of the argument list as the ``py::args`` argument. The +``py::kw_only()`` may be included to be explicit about this, but is not +required. (Prior to 2.9 ``py::args`` may only occur at the end of the argument +list, or immediately before a ``py::kwargs`` argument at the end). + +.. versionadded:: 2.9 + Positional-only arguments ========================= @@ -524,6 +541,8 @@ The default behaviour when the tag is unspecified is to allow ``None``. not allow ``None`` as argument. To pass optional argument of these copied types consider using ``std::optional`` +.. _overload_resolution: + Overload resolution order ========================= @@ -540,11 +559,13 @@ an explicit ``py::arg().noconvert()`` attribute in the function definition). If the second pass also fails a ``TypeError`` is raised. Within each pass, overloads are tried in the order they were registered with -pybind11. +pybind11. If the ``py::prepend()`` tag is added to the definition, a function +can be placed at the beginning of the overload sequence instead, allowing user +overloads to proceed built in functions. What this means in practice is that pybind11 will prefer any overload that does -not require conversion of arguments to an overload that does, but otherwise prefers -earlier-defined overloads to later-defined ones. +not require conversion of arguments to an overload that does, but otherwise +prefers earlier-defined overloads to later-defined ones. .. note:: @@ -553,3 +574,42 @@ earlier-defined overloads to later-defined ones. requiring one conversion over one requiring three, but only prioritizes overloads requiring no conversion at all to overloads that require conversion of at least one argument. + +.. versionadded:: 2.6 + + The ``py::prepend()`` tag. + +Binding functions with template parameters +========================================== + +You can bind functions that have template parameters. Here's a function: + +.. code-block:: cpp + + template + void set(T t); + +C++ templates cannot be instantiated at runtime, so you cannot bind the +non-instantiated function: + +.. code-block:: cpp + + // BROKEN (this will not compile) + m.def("set", &set); + +You must bind each instantiated function template separately. You may bind +each instantiation with the same name, which will be treated the same as +an overloaded function: + +.. code-block:: cpp + + m.def("set", &set); + m.def("set", &set); + +Sometimes it's more clear to bind them with separate names, which is also +an option: + +.. code-block:: cpp + + m.def("setInt", &set); + m.def("setString", &set); diff --git a/wrap/pybind11/docs/advanced/misc.rst b/wrap/pybind11/docs/advanced/misc.rst index a5899c67a..edab15fcb 100644 --- a/wrap/pybind11/docs/advanced/misc.rst +++ b/wrap/pybind11/docs/advanced/misc.rst @@ -84,7 +84,7 @@ could be realized as follows (important changes highlighted): }); } -The ``call_go`` wrapper can also be simplified using the `call_guard` policy +The ``call_go`` wrapper can also be simplified using the ``call_guard`` policy (see :ref:`call_policies`) which yields the same result: .. code-block:: cpp @@ -132,7 +132,7 @@ However, it can be acquired as follows: .. code-block:: cpp - py::object pet = (py::object) py::module::import("basic").attr("Pet"); + py::object pet = (py::object) py::module_::import("basic").attr("Pet"); py::class_(m, "Dog", pet) .def(py::init()) @@ -146,7 +146,7 @@ has been executed: .. code-block:: cpp - py::module::import("basic"); + py::module_::import("basic"); py::class_(m, "Dog") .def(py::init()) @@ -223,7 +223,7 @@ avoids this issue involves weak reference with a cleanup callback: .. code-block:: cpp - // Register a callback function that is invoked when the BaseClass object is colelcted + // Register a callback function that is invoked when the BaseClass object is collected py::cpp_function cleanup_callback( [](py::handle weakref) { // perform cleanup here -- this function is called with the GIL held @@ -237,13 +237,13 @@ avoids this issue involves weak reference with a cleanup callback: .. note:: - PyPy (at least version 5.9) does not garbage collect objects when the - interpreter exits. An alternative approach (which also works on CPython) is to use - the :py:mod:`atexit` module [#f7]_, for example: + PyPy does not garbage collect objects when the interpreter exits. An alternative + approach (which also works on CPython) is to use the :py:mod:`atexit` module [#f7]_, + for example: .. code-block:: cpp - auto atexit = py::module::import("atexit"); + auto atexit = py::module_::import("atexit"); atexit.attr("register")(py::cpp_function([]() { // perform cleanup here -- this function is called with the GIL held })); @@ -284,7 +284,7 @@ work, it is important that all lines are indented consistently, i.e.: )mydelimiter"); By default, pybind11 automatically generates and prepends a signature to the docstring of a function -registered with ``module::def()`` and ``class_::def()``. Sometimes this +registered with ``module_::def()`` and ``class_::def()``. Sometimes this behavior is not desirable, because you want to provide your own signature or remove the docstring completely to exclude the function from the Sphinx documentation. The class ``options`` allows you to selectively suppress auto-generated signatures: diff --git a/wrap/pybind11/docs/advanced/pycpp/numpy.rst b/wrap/pybind11/docs/advanced/pycpp/numpy.rst index e50d24a99..30daeefff 100644 --- a/wrap/pybind11/docs/advanced/pycpp/numpy.rst +++ b/wrap/pybind11/docs/advanced/pycpp/numpy.rst @@ -57,11 +57,11 @@ specification. struct buffer_info { void *ptr; - ssize_t itemsize; + py::ssize_t itemsize; std::string format; - ssize_t ndim; - std::vector shape; - std::vector strides; + py::ssize_t ndim; + std::vector shape; + std::vector strides; }; To create a C++ function that can take a Python buffer object as an argument, @@ -150,8 +150,10 @@ NumPy array containing double precision values. When it is invoked with a different type (e.g. an integer or a list of integers), the binding code will attempt to cast the input into a NumPy array -of the requested type. Note that this feature requires the -:file:`pybind11/numpy.h` header to be included. +of the requested type. This feature requires the :file:`pybind11/numpy.h` +header to be included. Note that :file:`pybind11/numpy.h` does not depend on +the NumPy headers, and thus can be used without declaring a build-time +dependency on NumPy; NumPy>=1.7.0 is a runtime dependency. Data in NumPy arrays is not guaranteed to packed in a dense manner; furthermore, entries can be separated by arbitrary column and row strides. @@ -169,6 +171,31 @@ template parameter, and it ensures that non-conforming arguments are converted into an array satisfying the specified requirements instead of trying the next function overload. +There are several methods on arrays; the methods listed below under references +work, as well as the following functions based on the NumPy API: + +- ``.dtype()`` returns the type of the contained values. + +- ``.strides()`` returns a pointer to the strides of the array (optionally pass + an integer axis to get a number). + +- ``.flags()`` returns the flag settings. ``.writable()`` and ``.owndata()`` + are directly available. + +- ``.offset_at()`` returns the offset (optionally pass indices). + +- ``.squeeze()`` returns a view with length-1 axes removed. + +- ``.view(dtype)`` returns a view of the array with a different dtype. + +- ``.reshape({i, j, ...})`` returns a view of the array with a different shape. + ``.resize({...})`` is also available. + +- ``.index_at(i, j, ...)`` gets the count from the beginning to a given index. + + +There are also several methods for getting references (described below). + Structured types ================ @@ -231,8 +258,8 @@ by the compiler. The result is returned as a NumPy array of type .. code-block:: pycon - >>> x = np.array([[1, 3],[5, 7]]) - >>> y = np.array([[2, 4],[6, 8]]) + >>> x = np.array([[1, 3], [5, 7]]) + >>> y = np.array([[2, 4], [6, 8]]) >>> z = 3 >>> result = vectorized_func(x, y, z) @@ -309,17 +336,17 @@ where ``N`` gives the required dimensionality of the array: m.def("sum_3d", [](py::array_t x) { auto r = x.unchecked<3>(); // x must have ndim = 3; can be non-writeable double sum = 0; - for (ssize_t i = 0; i < r.shape(0); i++) - for (ssize_t j = 0; j < r.shape(1); j++) - for (ssize_t k = 0; k < r.shape(2); k++) + for (py::ssize_t i = 0; i < r.shape(0); i++) + for (py::ssize_t j = 0; j < r.shape(1); j++) + for (py::ssize_t k = 0; k < r.shape(2); k++) sum += r(i, j, k); return sum; }); m.def("increment_3d", [](py::array_t x) { auto r = x.mutable_unchecked<3>(); // Will throw if ndim != 3 or flags.writeable is false - for (ssize_t i = 0; i < r.shape(0); i++) - for (ssize_t j = 0; j < r.shape(1); j++) - for (ssize_t k = 0; k < r.shape(2); k++) + for (py::ssize_t i = 0; i < r.shape(0); i++) + for (py::ssize_t j = 0; j < r.shape(1); j++) + for (py::ssize_t k = 0; k < r.shape(2); k++) r(i, j, k) += 1.0; }, py::arg().noconvert()); @@ -343,21 +370,21 @@ The returned proxy object supports some of the same methods as ``py::array`` so that it can be used as a drop-in replacement for some existing, index-checked uses of ``py::array``: -- ``r.ndim()`` returns the number of dimensions +- ``.ndim()`` returns the number of dimensions -- ``r.data(1, 2, ...)`` and ``r.mutable_data(1, 2, ...)``` returns a pointer to +- ``.data(1, 2, ...)`` and ``r.mutable_data(1, 2, ...)``` returns a pointer to the ``const T`` or ``T`` data, respectively, at the given indices. The latter is only available to proxies obtained via ``a.mutable_unchecked()``. -- ``itemsize()`` returns the size of an item in bytes, i.e. ``sizeof(T)``. +- ``.itemsize()`` returns the size of an item in bytes, i.e. ``sizeof(T)``. -- ``ndim()`` returns the number of dimensions. +- ``.ndim()`` returns the number of dimensions. -- ``shape(n)`` returns the size of dimension ``n`` +- ``.shape(n)`` returns the size of dimension ``n`` -- ``size()`` returns the total number of elements (i.e. the product of the shapes). +- ``.size()`` returns the total number of elements (i.e. the product of the shapes). -- ``nbytes()`` returns the number of bytes used by the referenced elements +- ``.nbytes()`` returns the number of bytes used by the referenced elements (i.e. ``itemsize()`` times ``size()``). .. seealso:: @@ -376,7 +403,7 @@ In Python 2, the syntactic sugar ``...`` is not available, but the singleton .. code-block:: python - a = # a NumPy array + a = ... # a NumPy array b = a[0, ..., 0] The function ``py::ellipsis()`` function can be used to perform the same @@ -388,7 +415,7 @@ operation on the C++ side: py::array b = a[py::make_tuple(0, py::ellipsis(), 0)]; .. versionchanged:: 2.6 - ``py::ellipsis()`` is now also avaliable in Python 2. + ``py::ellipsis()`` is now also available in Python 2. Memory view =========== diff --git a/wrap/pybind11/docs/advanced/pycpp/object.rst b/wrap/pybind11/docs/advanced/pycpp/object.rst index 70e493acd..93e1a94d8 100644 --- a/wrap/pybind11/docs/advanced/pycpp/object.rst +++ b/wrap/pybind11/docs/advanced/pycpp/object.rst @@ -20,6 +20,40 @@ Available types include :class:`handle`, :class:`object`, :class:`bool_`, Be sure to review the :ref:`pytypes_gotchas` before using this heavily in your C++ API. +.. _instantiating_compound_types: + +Instantiating compound Python types from C++ +============================================ + +Dictionaries can be initialized in the :class:`dict` constructor: + +.. code-block:: cpp + + using namespace pybind11::literals; // to bring in the `_a` literal + py::dict d("spam"_a=py::none(), "eggs"_a=42); + +A tuple of python objects can be instantiated using :func:`py::make_tuple`: + +.. code-block:: cpp + + py::tuple tup = py::make_tuple(42, py::none(), "spam"); + +Each element is converted to a supported Python type. + +A `simple namespace`_ can be instantiated using + +.. code-block:: cpp + + using namespace pybind11::literals; // to bring in the `_a` literal + py::object SimpleNamespace = py::module_::import("types").attr("SimpleNamespace"); + py::object ns = SimpleNamespace("spam"_a=py::none(), "eggs"_a=42); + +Attributes on a namespace can be modified with the :func:`py::delattr`, +:func:`py::getattr`, and :func:`py::setattr` functions. Simple namespaces can +be useful as lightweight stand-ins for class instances. + +.. _simple namespace: https://docs.python.org/3/library/types.html#types.SimpleNamespace + .. _casting_back_and_forth: Casting back and forth @@ -30,7 +64,7 @@ types to Python, which can be done using :func:`py::cast`: .. code-block:: cpp - MyClass *cls = ..; + MyClass *cls = ...; py::object obj = py::cast(cls); The reverse direction uses the following syntax: @@ -56,12 +90,12 @@ This example obtains a reference to the Python ``Decimal`` class. .. code-block:: cpp // Equivalent to "from decimal import Decimal" - py::object Decimal = py::module::import("decimal").attr("Decimal"); + py::object Decimal = py::module_::import("decimal").attr("Decimal"); .. code-block:: cpp // Try to import scipy - py::object scipy = py::module::import("scipy"); + py::object scipy = py::module_::import("scipy"); return scipy.attr("__version__"); @@ -81,7 +115,7 @@ via ``operator()``. .. code-block:: cpp // Use Python to make our directories - py::object os = py::module::import("os"); + py::object os = py::module_::import("os"); py::object makedirs = os.attr("makedirs"); makedirs("/tmp/path/to/somewhere"); @@ -132,6 +166,7 @@ Keyword arguments are also supported. In Python, there is the usual call syntax: def f(number, say, to): ... # function code + f(1234, say="hello", to=some_instance) # keyword call in Python In C++, the same call can be made using: @@ -196,9 +231,9 @@ C++ functions that require a specific subtype rather than a generic :class:`obje #include using namespace pybind11::literals; - py::module os = py::module::import("os"); - py::module path = py::module::import("os.path"); // like 'import os.path as path' - py::module np = py::module::import("numpy"); // like 'import numpy as np' + py::module_ os = py::module_::import("os"); + py::module_ path = py::module_::import("os.path"); // like 'import os.path as path' + py::module_ np = py::module_::import("numpy"); // like 'import numpy as np' py::str curdir_abs = path.attr("abspath")(path.attr("curdir")); py::print(py::str("Current directory: ") + curdir_abs); diff --git a/wrap/pybind11/docs/advanced/pycpp/utilities.rst b/wrap/pybind11/docs/advanced/pycpp/utilities.rst index 369e7c94d..af0f9cb2b 100644 --- a/wrap/pybind11/docs/advanced/pycpp/utilities.rst +++ b/wrap/pybind11/docs/advanced/pycpp/utilities.rst @@ -28,7 +28,7 @@ Capturing standard output from ostream Often, a library will use the streams ``std::cout`` and ``std::cerr`` to print, but this does not play well with Python's standard ``sys.stdout`` and ``sys.stderr`` -redirection. Replacing a library's printing with `py::print ` may not +redirection. Replacing a library's printing with ``py::print `` may not be feasible. This can be fixed using a guard around the library function that redirects output to the corresponding Python streams: @@ -42,20 +42,31 @@ redirects output to the corresponding Python streams: m.def("noisy_func", []() { py::scoped_ostream_redirect stream( std::cout, // std::ostream& - py::module::import("sys").attr("stdout") // Python output + py::module_::import("sys").attr("stdout") // Python output ); call_noisy_func(); }); +.. warning:: + + The implementation in ``pybind11/iostream.h`` is NOT thread safe. Multiple + threads writing to a redirected ostream concurrently cause data races + and potentially buffer overflows. Therefore it is currently a requirement + that all (possibly) concurrent redirected ostream writes are protected by + a mutex. #HelpAppreciated: Work on iostream.h thread safety. For more + background see the discussions under + `PR #2982 `_ and + `PR #2995 `_. + This method respects flushes on the output streams and will flush if needed when the scoped guard is destroyed. This allows the output to be redirected in real time, such as to a Jupyter notebook. The two arguments, the C++ stream and the Python output, are optional, and default to standard output if not given. An -extra type, `py::scoped_estream_redirect `, is identical +extra type, ``py::scoped_estream_redirect ``, is identical except for defaulting to ``std::cerr`` and ``sys.stderr``; this can be useful with -`py::call_guard`, which allows multiple items, but uses the default constructor: +``py::call_guard``, which allows multiple items, but uses the default constructor: -.. code-block:: py +.. code-block:: cpp // Alternative: Call single function using call guard m.def("noisy_func", &call_noisy_function, @@ -63,7 +74,7 @@ except for defaulting to ``std::cerr`` and ``sys.stderr``; this can be useful wi py::scoped_estream_redirect>()); The redirection can also be done in Python with the addition of a context -manager, using the `py::add_ostream_redirect() ` function: +manager, using the ``py::add_ostream_redirect() `` function: .. code-block:: cpp @@ -92,7 +103,7 @@ arguments to disable one of the streams if needed. Evaluating Python expressions from strings and files ==================================================== -pybind11 provides the `eval`, `exec` and `eval_file` functions to evaluate +pybind11 provides the ``eval``, ``exec`` and ``eval_file`` functions to evaluate Python expressions and statements. The following example illustrates how they can be used. @@ -104,7 +115,7 @@ can be used. ... // Evaluate in scope of main module - py::object scope = py::module::import("__main__").attr("__dict__"); + py::object scope = py::module_::import("__main__").attr("__dict__"); // Evaluate an isolated expression int result = py::eval("my_variable + 10", scope).cast(); diff --git a/wrap/pybind11/docs/advanced/smart_ptrs.rst b/wrap/pybind11/docs/advanced/smart_ptrs.rst index da57748ca..5a2220109 100644 --- a/wrap/pybind11/docs/advanced/smart_ptrs.rst +++ b/wrap/pybind11/docs/advanced/smart_ptrs.rst @@ -77,6 +77,7 @@ segmentation fault). .. code-block:: python from example import Parent + print(Parent().get_child()) The problem is that ``Parent::get_child()`` returns a pointer to an instance of diff --git a/wrap/pybind11/docs/basics.rst b/wrap/pybind11/docs/basics.rst index 71440c9c6..e0479b298 100644 --- a/wrap/pybind11/docs/basics.rst +++ b/wrap/pybind11/docs/basics.rst @@ -39,7 +39,7 @@ on various C++11 language features that break older versions of Visual Studio. To use the C++17 in Visual Studio 2017 (MSVC 14.1), pybind11 requires the flag ``/permissive-`` to be passed to the compiler `to enforce standard conformance`_. When - building with Visual Studio 2019, this is not strictly necessary, but still adviced. + building with Visual Studio 2019, this is not strictly necessary, but still advised. .. _`to enforce standard conformance`: https://docs.microsoft.com/en-us/cpp/build/reference/permissive-standards-conformance?view=vs-2017 @@ -109,7 +109,7 @@ a file named :file:`example.cpp` with the following contents: PYBIND11_MODULE(example, m) { m.doc() = "pybind11 example plugin"; // optional module docstring - m.def("add", &add, "A function which adds two numbers"); + m.def("add", &add, "A function that adds two numbers"); } .. [#f1] In practice, implementation and binding code will generally be located @@ -118,8 +118,8 @@ a file named :file:`example.cpp` with the following contents: The :func:`PYBIND11_MODULE` macro creates a function that will be called when an ``import`` statement is issued from within Python. The module name (``example``) is given as the first macro argument (it should not be in quotes). The second -argument (``m``) defines a variable of type :class:`py::module ` which -is the main interface for creating bindings. The method :func:`module::def` +argument (``m``) defines a variable of type :class:`py::module_ ` which +is the main interface for creating bindings. The method :func:`module_::def` generates binding code that exposes the ``add()`` function to Python. .. note:: @@ -136,7 +136,14 @@ On Linux, the above example can be compiled using the following command: .. code-block:: bash - $ c++ -O3 -Wall -shared -std=c++11 -fPIC `python3 -m pybind11 --includes` example.cpp -o example`python3-config --extension-suffix` + $ c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) example.cpp -o example$(python3-config --extension-suffix) + +.. note:: + + If you used :ref:`include_as_a_submodule` to get the pybind11 source, then + use ``$(python3-config --includes) -Iextern/pybind11/include`` instead of + ``$(python3 -m pybind11 --includes)`` in the above compilation, as + explained in :ref:`building_manually`. For more details on the required compiler flags on Linux and macOS, see :ref:`building_manually`. For complete cross-platform compilation instructions, @@ -181,7 +188,7 @@ names of the arguments ("i" and "j" in this case). py::arg("i"), py::arg("j")); :class:`arg` is one of several special tag classes which can be used to pass -metadata into :func:`module::def`. With this modified binding code, we can now +metadata into :func:`module_::def`. With this modified binding code, we can now call the function using keyword arguments, which is a more readable alternative particularly for functions taking many parameters: diff --git a/wrap/pybind11/docs/benchmark.py b/wrap/pybind11/docs/benchmark.py index 023477212..f19079367 100644 --- a/wrap/pybind11/docs/benchmark.py +++ b/wrap/pybind11/docs/benchmark.py @@ -1,8 +1,7 @@ # -*- coding: utf-8 -*- -import random -import os -import time import datetime as dt +import os +import random nfns = 4 # Functions per class nargs = 4 # Arguments per function @@ -14,7 +13,7 @@ def generate_dummy_code_pybind11(nclasses=10): for cl in range(nclasses): decl += "class cl%03i;\n" % cl - decl += '\n' + decl += "\n" for cl in range(nclasses): decl += "class cl%03i {\n" % cl @@ -22,18 +21,17 @@ def generate_dummy_code_pybind11(nclasses=10): bindings += ' py::class_(m, "cl%03i")\n' % (cl, cl) for fn in range(nfns): ret = random.randint(0, nclasses - 1) - params = [random.randint(0, nclasses - 1) for i in range(nargs)] + params = [random.randint(0, nclasses - 1) for i in range(nargs)] decl += " cl%03i *fn_%03i(" % (ret, fn) decl += ", ".join("cl%03i *" % p for p in params) decl += ");\n" - bindings += ' .def("fn_%03i", &cl%03i::fn_%03i)\n' % \ - (fn, cl, fn) + bindings += ' .def("fn_%03i", &cl%03i::fn_%03i)\n' % (fn, cl, fn) decl += "};\n\n" - bindings += ' ;\n' + bindings += " ;\n" result = "#include \n\n" result += "namespace py = pybind11;\n\n" - result += decl + '\n' + result += decl + "\n" result += "PYBIND11_MODULE(example, m) {\n" result += bindings result += "}" @@ -46,7 +44,7 @@ def generate_dummy_code_boost(nclasses=10): for cl in range(nclasses): decl += "class cl%03i;\n" % cl - decl += '\n' + decl += "\n" for cl in range(nclasses): decl += "class cl%03i {\n" % cl @@ -54,18 +52,20 @@ def generate_dummy_code_boost(nclasses=10): bindings += ' py::class_("cl%03i")\n' % (cl, cl) for fn in range(nfns): ret = random.randint(0, nclasses - 1) - params = [random.randint(0, nclasses - 1) for i in range(nargs)] + params = [random.randint(0, nclasses - 1) for i in range(nargs)] decl += " cl%03i *fn_%03i(" % (ret, fn) decl += ", ".join("cl%03i *" % p for p in params) decl += ");\n" - bindings += ' .def("fn_%03i", &cl%03i::fn_%03i, py::return_value_policy())\n' % \ - (fn, cl, fn) + bindings += ( + ' .def("fn_%03i", &cl%03i::fn_%03i, py::return_value_policy())\n' + % (fn, cl, fn) + ) decl += "};\n\n" - bindings += ' ;\n' + bindings += " ;\n" result = "#include \n\n" result += "namespace py = boost::python;\n\n" - result += decl + '\n' + result += decl + "\n" result += "BOOST_PYTHON_MODULE(example) {\n" result += bindings result += "}" @@ -73,17 +73,19 @@ def generate_dummy_code_boost(nclasses=10): for codegen in [generate_dummy_code_pybind11, generate_dummy_code_boost]: - print ("{") + print("{") for i in range(0, 10): nclasses = 2 ** i with open("test.cpp", "w") as f: f.write(codegen(nclasses)) n1 = dt.datetime.now() - os.system("g++ -Os -shared -rdynamic -undefined dynamic_lookup " + os.system( + "g++ -Os -shared -rdynamic -undefined dynamic_lookup " "-fvisibility=hidden -std=c++14 test.cpp -I include " - "-I /System/Library/Frameworks/Python.framework/Headers -o test.so") + "-I /System/Library/Frameworks/Python.framework/Headers -o test.so" + ) n2 = dt.datetime.now() elapsed = (n2 - n1).total_seconds() - size = os.stat('test.so').st_size + size = os.stat("test.so").st_size print(" {%i, %f, %i}," % (nclasses * nfns, elapsed, size)) - print ("}") + print("}") diff --git a/wrap/pybind11/docs/changelog.rst b/wrap/pybind11/docs/changelog.rst index 8f95c1274..16bf3aa3f 100644 --- a/wrap/pybind11/docs/changelog.rst +++ b/wrap/pybind11/docs/changelog.rst @@ -6,21 +6,697 @@ Changelog Starting with version 1.8.0, pybind11 releases use a `semantic versioning `_ policy. -v2.6.0 (IN PROGRESS) +Version 2.9.1 (Feb 2, 2022) +--------------------------- + +Changes: + +* If possible, attach Python exception with ``py::raise_from`` to ``TypeError`` + when casting from C++ to Python. This will give additional info if Python + exceptions occur in the caster. Adds a test case of trying to convert a set + from C++ to Python when the hash function is not defined in Python. + `#3605 `_ + +* Add a mapping of C++11 nested exceptions to their Python exception + equivalent using ``py::raise_from``. This attaches the nested exceptions in + Python using the ``__cause__`` field. + `#3608 `_ + +* Propagate Python exception traceback using ``raise_from`` if a pybind11 + function runs out of overloads. + `#3671 `_ + +* ``py::multiple_inheritance`` is now only needed when C++ bases are hidden + from pybind11. + `#3650 `_ and + `#3659 `_ + + +Bug fixes: + +* Remove a boolean cast in ``numpy.h`` that causes MSVC C4800 warnings when + compiling against Python 3.10 or newer. + `#3669 `_ + +* Render ``py::bool_`` and ``py::float_`` as ``bool`` and ``float`` + respectively. + `#3622 `_ + +Build system improvements: + +* Fix CMake extension suffix computation on Python 3.10+. + `#3663 `_ + +* Allow ``CMAKE_ARGS`` to override CMake args in pybind11's own ``setup.py``. + `#3577 `_ + +* Remove a few deprecated c-headers. + `#3610 `_ + +* More uniform handling of test targets. + `#3590 `_ + +* Add clang-tidy readability check to catch potentially swapped function args. + `#3611 `_ + + +Version 2.9.0 (Dec 28, 2021) +---------------------------- + +This is the last version to support Python 2.7 and 3.5. + +New Features: + +* Allow ``py::args`` to be followed by other arguments; the remaining arguments + are implicitly keyword-only, as if a ``py::kw_only{}`` annotation had been + used. + `#3402 `_ + +Changes: + +* Make str/bytes/memoryview more interoperable with ``std::string_view``. + `#3521 `_ + +* Replace ``_`` with ``const_name`` in internals, avoid defining ``pybind::_`` + if ``_`` defined as macro (common gettext usage) + `#3423 `_ + + +Bug fixes: + +* Fix a rare warning about extra copy in an Eigen constructor. + `#3486 `_ + +* Fix caching of the C++ overrides. + `#3465 `_ + +* Add missing ``std::forward`` calls to some ``cpp_function`` overloads. + `#3443 `_ + +* Support PyPy 7.3.7 and the PyPy3.8 beta. Test python-3.11 on PRs with the + ``python dev`` label. + `#3419 `_ + +* Replace usage of deprecated ``Eigen::MappedSparseMatrix`` with + ``Eigen::Map>`` for Eigen 3.3+. + `#3499 `_ + +* Tweaks to support Microsoft Visual Studio 2022. + `#3497 `_ + +Build system improvements: + +* Nicer CMake printout and IDE organisation for pybind11's own tests. + `#3479 `_ + +* CMake: report version type as part of the version string to avoid a spurious + space in the package status message. + `#3472 `_ + +* Flags starting with ``-g`` in ``$CFLAGS`` and ``$CPPFLAGS`` are no longer + overridden by ``.Pybind11Extension``. + `#3436 `_ + +* Ensure ThreadPool is closed in ``setup_helpers``. + `#3548 `_ + +* Avoid LTS on ``mips64`` and ``ppc64le`` (reported broken). + `#3557 `_ + + +v2.8.1 (Oct 27, 2021) +--------------------- + +Changes and additions: + +* The simple namespace creation shortcut added in 2.8.0 was deprecated due to + usage of CPython internal API, and will be removed soon. Use + ``py::module_::import("types").attr("SimpleNamespace")``. + `#3374 `_ + +* Add C++ Exception type to throw and catch ``AttributeError``. Useful for + defining custom ``__setattr__`` and ``__getattr__`` methods. + `#3387 `_ + +Fixes: + +* Fixed the potential for dangling references when using properties with + ``std::optional`` types. + `#3376 `_ + +* Modernize usage of ``PyCodeObject`` on Python 3.9+ (moving toward support for + Python 3.11a1) + `#3368 `_ + +* A long-standing bug in ``eigen.h`` was fixed (originally PR #3343). The bug + was unmasked by newly added ``static_assert``'s in the Eigen 3.4.0 release. + `#3352 `_ + +* Support multiple raw inclusion of CMake helper files (Conan.io does this for + multi-config generators). + `#3420 `_ + +* Fix harmless warning on upcoming CMake 3.22. + `#3368 `_ + +* Fix 2.8.0 regression with MSVC 2017 + C++17 mode + Python 3. + `#3407 `_ + +* Fix 2.8.0 regression that caused undefined behavior (typically + segfaults) in ``make_key_iterator``/``make_value_iterator`` if dereferencing + the iterator returned a temporary value instead of a reference. + `#3348 `_ + + +v2.8.0 (Oct 4, 2021) -------------------- +New features: + +* Added ``py::raise_from`` to enable chaining exceptions. + `#3215 `_ + +* Allow exception translators to be optionally registered local to a module + instead of applying globally across all pybind11 modules. Use + ``register_local_exception_translator(ExceptionTranslator&& translator)`` + instead of ``register_exception_translator(ExceptionTranslator&& + translator)`` to keep your exception remapping code local to the module. + `#2650 `_ + +* Add ``make_simple_namespace`` function for instantiating Python + ``SimpleNamespace`` objects. **Deprecated in 2.8.1.** + `#2840 `_ + +* ``pybind11::scoped_interpreter`` and ``initialize_interpreter`` have new + arguments to allow ``sys.argv`` initialization. + `#2341 `_ + +* Allow Python builtins to be used as callbacks in CPython. + `#1413 `_ + +* Added ``view`` to view arrays with a different datatype. + `#987 `_ + +* Implemented ``reshape`` on arrays. + `#984 `_ + +* Enable defining custom ``__new__`` methods on classes by fixing bug + preventing overriding methods if they have non-pybind11 siblings. + `#3265 `_ + +* Add ``make_value_iterator()``, and fix ``make_key_iterator()`` to return + references instead of copies. + `#3293 `_ + +* Improve the classes generated by ``bind_map``: `#3310 `_ + + * Change ``.items`` from an iterator to a dictionary view. + * Add ``.keys`` and ``.values`` (both dictionary views). + * Allow ``__contains__`` to take any object. + +* ``pybind11::custom_type_setup`` was added, for customizing the + ``PyHeapTypeObject`` corresponding to a class, which may be useful for + enabling garbage collection support, among other things. + `#3287 `_ + + +Changes: + +* Set ``__file__`` constant when running ``eval_file`` in an embedded interpreter. + `#3233 `_ + +* Python objects and (C++17) ``std::optional`` now accepted in ``py::slice`` + constructor. + `#1101 `_ + +* The pybind11 proxy types ``str``, ``bytes``, ``bytearray``, ``tuple``, + ``list`` now consistently support passing ``ssize_t`` values for sizes and + indexes. Previously, only ``size_t`` was accepted in several interfaces. + `#3219 `_ + +* Avoid evaluating ``PYBIND11_TLS_REPLACE_VALUE`` arguments more than once. + `#3290 `_ + +Fixes: + +* Bug fix: enum value's ``__int__`` returning non-int when underlying type is + bool or of char type. + `#1334 `_ + +* Fixes bug in setting error state in Capsule's pointer methods. + `#3261 `_ + +* A long-standing memory leak in ``py::cpp_function::initialize`` was fixed. + `#3229 `_ + +* Fixes thread safety for some ``pybind11::type_caster`` which require lifetime + extension, such as for ``std::string_view``. + `#3237 `_ + +* Restore compatibility with gcc 4.8.4 as distributed by ubuntu-trusty, linuxmint-17. + `#3270 `_ + + +Build system improvements: + +* Fix regression in CMake Python package config: improper use of absolute path. + `#3144 `_ + +* Cached Python version information could become stale when CMake was re-run + with a different Python version. The build system now detects this and + updates this information. + `#3299 `_ + +* Specified UTF8-encoding in setup.py calls of open(). + `#3137 `_ + +* Fix a harmless warning from CMake 3.21 with the classic Python discovery. + `#3220 `_ + +* Eigen repo and version can now be specified as cmake options. + `#3324 `_ + + +Backend and tidying up: + +* Reduced thread-local storage required for keeping alive temporary data for + type conversion to one key per ABI version, rather than one key per extension + module. This makes the total thread-local storage required by pybind11 2 + keys per ABI version. + `#3275 `_ + +* Optimize NumPy array construction with additional moves. + `#3183 `_ + +* Conversion to ``std::string`` and ``std::string_view`` now avoids making an + extra copy of the data on Python >= 3.3. + `#3257 `_ + +* Remove const modifier from certain C++ methods on Python collections + (``list``, ``set``, ``dict``) such as (``clear()``, ``append()``, + ``insert()``, etc...) and annotated them with ``py-non-const``. + +* Enable readability ``clang-tidy-const-return`` and remove useless consts. + `#3254 `_ + `#3194 `_ + +* The clang-tidy ``google-explicit-constructor`` option was enabled. + `#3250 `_ + +* Mark a pytype move constructor as noexcept (perf). + `#3236 `_ + +* Enable clang-tidy check to guard against inheritance slicing. + `#3210 `_ + +* Legacy warning suppression pragma were removed from eigen.h. On Unix + platforms, please use -isystem for Eigen include directories, to suppress + compiler warnings originating from Eigen headers. Note that CMake does this + by default. No adjustments are needed for Windows. + `#3198 `_ + +* Format pybind11 with isort consistent ordering of imports + `#3195 `_ + +* The warnings-suppression "pragma clamp" at the top/bottom of pybind11 was + removed, clearing the path to refactoring and IWYU cleanup. + `#3186 `_ + +* Enable most bugprone checks in clang-tidy and fix the found potential bugs + and poor coding styles. + `#3166 `_ + +* Add ``clang-tidy-readability`` rules to make boolean casts explicit improving + code readability. Also enabled other misc and readability clang-tidy checks. + `#3148 `_ + +* Move object in ``.pop()`` for list. + `#3116 `_ + + + + +v2.7.1 (Aug 3, 2021) +--------------------- + +Minor missing functionality added: + +* Allow Python builtins to be used as callbacks in CPython. + `#1413 `_ + +Bug fixes: + +* Fix regression in CMake Python package config: improper use of absolute path. + `#3144 `_ + +* Fix Mingw64 and add to the CI testing matrix. + `#3132 `_ + +* Specified UTF8-encoding in setup.py calls of open(). + `#3137 `_ + +* Add clang-tidy-readability rules to make boolean casts explicit improving + code readability. Also enabled other misc and readability clang-tidy checks. + `#3148 `_ + +* Move object in ``.pop()`` for list. + `#3116 `_ + +Backend and tidying up: + +* Removed and fixed warning suppressions. + `#3127 `_ + `#3129 `_ + `#3135 `_ + `#3141 `_ + `#3142 `_ + `#3150 `_ + `#3152 `_ + `#3160 `_ + `#3161 `_ + + +v2.7.0 (Jul 16, 2021) +--------------------- + +New features: + +* Enable ``py::implicitly_convertible`` for + ``py::class_``-wrapped types. + `#3059 `_ + +* Allow function pointer extraction from overloaded functions. + `#2944 `_ + +* NumPy: added ``.char_()`` to type which gives the NumPy public ``char`` + result, which also distinguishes types by bit length (unlike ``.kind()``). + `#2864 `_ + +* Add ``pybind11::bytearray`` to manipulate ``bytearray`` similar to ``bytes``. + `#2799 `_ + +* ``pybind11/stl/filesystem.h`` registers a type caster that, on C++17/Python + 3.6+, converts ``std::filesystem::path`` to ``pathlib.Path`` and any + ``os.PathLike`` to ``std::filesystem::path``. + `#2730 `_ + +* A ``PYBIND11_VERSION_HEX`` define was added, similar to ``PY_VERSION_HEX``. + `#3120 `_ + + + +Changes: + +* ``py::str`` changed to exclusively hold ``PyUnicodeObject``. Previously + ``py::str`` could also hold ``bytes``, which is probably surprising, was + never documented, and can mask bugs (e.g. accidental use of ``py::str`` + instead of ``py::bytes``). + `#2409 `_ + +* Add a safety guard to ensure that the Python GIL is held when C++ calls back + into Python via ``object_api<>::operator()`` (e.g. ``py::function`` + ``__call__``). (This feature is available for Python 3.6+ only.) + `#2919 `_ + +* Catch a missing ``self`` argument in calls to ``__init__()``. + `#2914 `_ + +* Use ``std::string_view`` if available to avoid a copy when passing an object + to a ``std::ostream``. + `#3042 `_ + +* An important warning about thread safety was added to the ``iostream.h`` + documentation; attempts to make ``py::scoped_ostream_redirect`` thread safe + have been removed, as it was only partially effective. + `#2995 `_ + + +Fixes: + +* Performance: avoid unnecessary strlen calls. + `#3058 `_ + +* Fix auto-generated documentation string when using ``const T`` in + ``pyarray_t``. + `#3020 `_ + +* Unify error messages thrown by ``simple_collector``/``unpacking_collector``. + `#3013 `_ + +* ``pybind11::builtin_exception`` is now explicitly exported, which means the + types included/defined in different modules are identical, and exceptions + raised in different modules can be caught correctly. The documentation was + updated to explain that custom exceptions that are used across module + boundaries need to be explicitly exported as well. + `#2999 `_ + +* Fixed exception when printing UTF-8 to a ``scoped_ostream_redirect``. + `#2982 `_ + +* Pickle support enhancement: ``setstate`` implementation will attempt to + ``setattr`` ``__dict__`` only if the unpickled ``dict`` object is not empty, + to not force use of ``py::dynamic_attr()`` unnecessarily. + `#2972 `_ + +* Allow negative timedelta values to roundtrip. + `#2870 `_ + +* Fix unchecked errors could potentially swallow signals/other exceptions. + `#2863 `_ + +* Add null pointer check with ``std::localtime``. + `#2846 `_ + +* Fix the ``weakref`` constructor from ``py::object`` to create a new + ``weakref`` on conversion. + `#2832 `_ + +* Avoid relying on exceptions in C++17 when getting a ``shared_ptr`` holder + from a ``shared_from_this`` class. + `#2819 `_ + +* Allow the codec's exception to be raised instead of :code:`RuntimeError` when + casting from :code:`py::str` to :code:`std::string`. + `#2903 `_ + + +Build system improvements: + +* In ``setup_helpers.py``, test for platforms that have some multiprocessing + features but lack semaphores, which ``ParallelCompile`` requires. + `#3043 `_ + +* Fix ``pybind11_INCLUDE_DIR`` in case ``CMAKE_INSTALL_INCLUDEDIR`` is + absolute. + `#3005 `_ + +* Fix bug not respecting ``WITH_SOABI`` or ``WITHOUT_SOABI`` to CMake. + `#2938 `_ + +* Fix the default ``Pybind11Extension`` compilation flags with a Mingw64 python. + `#2921 `_ + +* Clang on Windows: do not pass ``/MP`` (ignored flag). + `#2824 `_ + +* ``pybind11.setup_helpers.intree_extensions`` can be used to generate + ``Pybind11Extension`` instances from cpp files placed in the Python package + source tree. + `#2831 `_ + +Backend and tidying up: + +* Enable clang-tidy performance, readability, and modernization checks + throughout the codebase to enforce best coding practices. + `#3046 `_, + `#3049 `_, + `#3051 `_, + `#3052 `_, + `#3080 `_, and + `#3094 `_ + + +* Checks for common misspellings were added to the pre-commit hooks. + `#3076 `_ + +* Changed ``Werror`` to stricter ``Werror-all`` for Intel compiler and fixed + minor issues. + `#2948 `_ + +* Fixed compilation with GCC < 5 when the user defines ``_GLIBCXX_USE_CXX11_ABI``. + `#2956 `_ + +* Added nox support for easier local testing and linting of contributions. + `#3101 `_ and + `#3121 `_ + +* Avoid RTD style issue with docutils 0.17+. + `#3119 `_ + +* Support pipx run, such as ``pipx run pybind11 --include`` for a quick compile. + `#3117 `_ + + + +v2.6.2 (Jan 26, 2021) +--------------------- + +Minor missing functionality added: + +* enum: add missing Enum.value property. + `#2739 `_ + +* Allow thread termination to be avoided during shutdown for CPython 3.7+ via + ``.disarm`` for ``gil_scoped_acquire``/``gil_scoped_release``. + `#2657 `_ + +Fixed or improved behavior in a few special cases: + +* Fix bug where the constructor of ``object`` subclasses would not throw on + being passed a Python object of the wrong type. + `#2701 `_ + +* The ``type_caster`` for integers does not convert Python objects with + ``__int__`` anymore with ``noconvert`` or during the first round of trying + overloads. + `#2698 `_ + +* When casting to a C++ integer, ``__index__`` is always called and not + considered as conversion, consistent with Python 3.8+. + `#2801 `_ + +Build improvements: + +* Setup helpers: ``extra_compile_args`` and ``extra_link_args`` automatically set by + Pybind11Extension are now prepended, which allows them to be overridden + by user-set ``extra_compile_args`` and ``extra_link_args``. + `#2808 `_ + +* Setup helpers: Don't trigger unused parameter warning. + `#2735 `_ + +* CMake: Support running with ``--warn-uninitialized`` active. + `#2806 `_ + +* CMake: Avoid error if included from two submodule directories. + `#2804 `_ + +* CMake: Fix ``STATIC`` / ``SHARED`` being ignored in FindPython mode. + `#2796 `_ + +* CMake: Respect the setting for ``CMAKE_CXX_VISIBILITY_PRESET`` if defined. + `#2793 `_ + +* CMake: Fix issue with FindPython2/FindPython3 not working with ``pybind11::embed``. + `#2662 `_ + +* CMake: mixing local and installed pybind11's would prioritize the installed + one over the local one (regression in 2.6.0). + `#2716 `_ + + +Bug fixes: + +* Fixed segfault in multithreaded environments when using + ``scoped_ostream_redirect``. + `#2675 `_ + +* Leave docstring unset when all docstring-related options are disabled, rather + than set an empty string. + `#2745 `_ + +* The module key in builtins that pybind11 uses to store its internals changed + from std::string to a python str type (more natural on Python 2, no change on + Python 3). + `#2814 `_ + +* Fixed assertion error related to unhandled (later overwritten) exception in + CPython 3.8 and 3.9 debug builds. + `#2685 `_ + +* Fix ``py::gil_scoped_acquire`` assert with CPython 3.9 debug build. + `#2683 `_ + +* Fix issue with a test failing on pytest 6.2. + `#2741 `_ + +Warning fixes: + +* Fix warning modifying constructor parameter 'flag' that shadows a field of + 'set_flag' ``[-Wshadow-field-in-constructor-modified]``. + `#2780 `_ + +* Suppressed some deprecation warnings about old-style + ``__init__``/``__setstate__`` in the tests. + `#2759 `_ + +Valgrind work: + +* Fix invalid access when calling a pybind11 ``__init__`` on a non-pybind11 + class instance. + `#2755 `_ + +* Fixed various minor memory leaks in pybind11's test suite. + `#2758 `_ + +* Resolved memory leak in cpp_function initialization when exceptions occurred. + `#2756 `_ + +* Added a Valgrind build, checking for leaks and memory-related UB, to CI. + `#2746 `_ + +Compiler support: + +* Intel compiler was not activating C++14 support due to a broken define. + `#2679 `_ + +* Support ICC and NVIDIA HPC SDK in C++17 mode. + `#2729 `_ + +* Support Intel OneAPI compiler (ICC 20.2) and add to CI. + `#2573 `_ + + + +v2.6.1 (Nov 11, 2020) +--------------------- + +* ``py::exec``, ``py::eval``, and ``py::eval_file`` now add the builtins module + as ``"__builtins__"`` to their ``globals`` argument, better matching ``exec`` + and ``eval`` in pure Python. + `#2616 `_ + +* ``setup_helpers`` will no longer set a minimum macOS version higher than the + current version. + `#2622 `_ + +* Allow deleting static properties. + `#2629 `_ + +* Seal a leak in ``def_buffer``, cleaning up the ``capture`` object after the + ``class_`` object goes out of scope. + `#2634 `_ + +* ``pybind11_INCLUDE_DIRS`` was incorrect, potentially causing a regression if + it was expected to include ``PYTHON_INCLUDE_DIRS`` (please use targets + instead). + `#2636 `_ + +* Added parameter names to the ``py::enum_`` constructor and methods, avoiding + ``arg0`` in the generated docstrings. + `#2637 `_ + +* Added ``needs_recompile`` optional function to the ``ParallelCompiler`` + helper, to allow a recompile to be skipped based on a user-defined function. + `#2643 `_ + + +v2.6.0 (Oct 21, 2020) +--------------------- + See :ref:`upgrade-guide-2.6` for help upgrading to the new version. -* Provide an additional spelling of ``py::module`` - ``py::module_`` (with a - trailing underscore), for C++20 compatibility. Only relevant when used - unqualified. - `#2489 `_ - -* ``pybind11_add_module()`` now accepts an optional ``OPT_SIZE`` flag that - switches the binding target to size-based optimization regardless global - CMake build type (except in debug mode, where optimizations remain disabled). - This reduces binary size quite substantially (~25%). - `#2463 `_ +New features: * Keyword-only arguments supported in Python 2 or 3 with ``py::kw_only()``. `#2100 `_ @@ -28,11 +704,17 @@ See :ref:`upgrade-guide-2.6` for help upgrading to the new version. * Positional-only arguments supported in Python 2 or 3 with ``py::pos_only()``. `#2459 `_ +* ``py::is_final()`` class modifier to block subclassing (CPython only). + `#2151 `_ + +* Added ``py::prepend()``, allowing a function to be placed at the beginning of + the overload chain. + `#1131 `_ + * Access to the type object now provided with ``py::type::of()`` and ``py::type::of(h)``. `#2364 `_ - * Perfect forwarding support for methods. `#2048 `_ @@ -42,11 +724,48 @@ See :ref:`upgrade-guide-2.6` for help upgrading to the new version. * ``py::hash`` is now public. `#2217 `_ -* ``py::is_final()`` class modifier to block subclassing (CPython only). - `#2151 `_ +* ``py::class_`` is now supported. Note that writing to one data + member of the union and reading another (type punning) is UB in C++. Thus + pybind11-bound enums should never be used for such conversions. + `#2320 `_. -* ``py::memoryview`` update and documentation. - `#2223 `_ +* Classes now check local scope when registering members, allowing a subclass + to have a member with the same name as a parent (such as an enum). + `#2335 `_ + +Code correctness features: + +* Error now thrown when ``__init__`` is forgotten on subclasses. + `#2152 `_ + +* Throw error if conversion to a pybind11 type if the Python object isn't a + valid instance of that type, such as ``py::bytes(o)`` when ``py::object o`` + isn't a bytes instance. + `#2349 `_ + +* Throw if conversion to ``str`` fails. + `#2477 `_ + + +API changes: + +* ``py::module`` was renamed ``py::module_`` to avoid issues with C++20 when + used unqualified, but an alias ``py::module`` is provided for backward + compatibility. + `#2489 `_ + +* Public constructors for ``py::module_`` have been deprecated; please use + ``pybind11::module_::create_extension_module`` if you were using the public + constructor (fairly rare after ``PYBIND11_MODULE`` was introduced). + `#2552 `_ + +* ``PYBIND11_OVERLOAD*`` macros and ``get_overload`` function replaced by + correctly-named ``PYBIND11_OVERRIDE*`` and ``get_override``, fixing + inconsistencies in the presence of a closing ``;`` in these macros. + ``get_type_overload`` is deprecated. + `#2325 `_ + +Packaging / building improvements: * The Python package was reworked to be more powerful and useful. `#2433 `_ @@ -54,7 +773,7 @@ See :ref:`upgrade-guide-2.6` for help upgrading to the new version. * :ref:`build-setuptools` is easier thanks to a new ``pybind11.setup_helpers`` module, which provides utilities to use setuptools with pybind11. It can be used via PEP 518, ``setup_requires``, - or by directly copying ``setup_helpers.py`` into your project. + or by directly importing or copying ``setup_helpers.py`` into your project. * CMake configuration files are now included in the Python package. Use ``pybind11.get_cmake_dir()`` or ``python -m pybind11 --cmakedir`` to get @@ -62,17 +781,21 @@ See :ref:`upgrade-guide-2.6` for help upgrading to the new version. site-packages location in your ``CMAKE_MODULE_PATH``. Or you can use the new ``pybind11[global]`` extra when you install ``pybind11``, which installs the CMake files and headers into your base environment in the - standard location + standard location. * ``pybind11-config`` is another way to write ``python -m pybind11`` if you have your PATH set up. + * Added external typing support to the helper module, code from + ``import pybind11`` can now be type checked. + `#2588 `_ + * Minimum CMake required increased to 3.4. `#2338 `_ and `#2370 `_ - * Full integration with CMake’s C++ standard system replaces - ``PYBIND11_CPP_STANDARD``. + * Full integration with CMake’s C++ standard system and compile features + replaces ``PYBIND11_CPP_STANDARD``. * Generated config file is now portable to different Python/compiler/CMake versions. @@ -85,27 +808,36 @@ See :ref:`upgrade-guide-2.6` for help upgrading to the new version. ``CMAKE_INTERPROCEDURAL_OPTIMIZATION``, ``set(CMAKE_CXX_VISIBILITY_PRESET hidden)``. -* Optional :ref:`find-python-mode` and :ref:`nopython-mode` with CMake. - `#2370 `_ + * ``CUDA`` as a language is now supported. + + * Helper functions ``pybind11_strip``, ``pybind11_extension``, + ``pybind11_find_import`` added, see :doc:`cmake/index`. + + * Optional :ref:`find-python-mode` and :ref:`nopython-mode` with CMake. + `#2370 `_ * Uninstall target added. `#2265 `_ and `#2346 `_ -* ``PYBIND11_OVERLOAD*`` macros and ``get_overload`` function replaced by - correctly-named ``PYBIND11_OVERRIDE*`` and ``get_override``, fixing - inconsistencies in the presene of a closing ``;`` in these macros. - ``get_type_overload`` is deprecated. - `#2325 `_ +* ``pybind11_add_module()`` now accepts an optional ``OPT_SIZE`` flag that + switches the binding target to size-based optimization if the global build + type can not always be fixed to ``MinSizeRel`` (except in debug mode, where + optimizations remain disabled). ``MinSizeRel`` or this flag reduces binary + size quite substantially (~25% on some platforms). + `#2463 `_ -Smaller or developer focused features: +Smaller or developer focused features and fixes: -* Moved ``mkdoc.py`` to a new repo, `pybind11-mkdoc`_. +* Moved ``mkdoc.py`` to a new repo, `pybind11-mkdoc`_. There are no longer + submodules in the main repo. -.. _pybind11-mkdoc: https://github.com/pybind/pybind11-mkdoc +* ``py::memoryview`` segfault fix and update, with new + ``py::memoryview::from_memory`` in Python 3, and documentation. + `#2223 `_ -* Error now thrown when ``__init__`` is forgotten on subclasses. - `#2152 `_ +* Fix for ``buffer_info`` on Python 2. + `#2503 `_ * If ``__eq__`` defined but not ``__hash__``, ``__hash__`` is now set to ``None``. @@ -114,12 +846,6 @@ Smaller or developer focused features: * ``py::ellipsis`` now also works on Python 2. `#2360 `_ -* Throw if conversion to ``str`` fails. - `#2477 `_ - -* Added missing signature for ``py::array``. - `#2363 `_ - * Pointer to ``std::tuple`` & ``std::pair`` supported in cast. `#2334 `_ @@ -127,7 +853,26 @@ Smaller or developer focused features: argument type. `#2293 `_ -* Bugfixes related to more extensive testing +* Added missing signature for ``py::array``. + `#2363 `_ + +* ``unchecked_mutable_reference`` has access to operator ``()`` and ``[]`` when + const. + `#2514 `_ + +* ``py::vectorize`` is now supported on functions that return void. + `#1969 `_ + +* ``py::capsule`` supports ``get_pointer`` and ``set_pointer``. + `#1131 `_ + +* Fix crash when different instances share the same pointer of the same type. + `#2252 `_ + +* Fix for ``py::len`` not clearing Python's error state when it fails and throws. + `#2575 `_ + +* Bugfixes related to more extensive testing, new GitHub Actions CI. `#2321 `_ * Bug in timezone issue in Eastern hemisphere midnight fixed. @@ -141,16 +886,27 @@ Smaller or developer focused features: requested ordering. `#2484 `_ -* PyPy fixes, including support for PyPy3 and PyPy 7. +* Avoid a segfault on some compilers when types are removed in Python. + `#2564 `_ + +* ``py::arg::none()`` is now also respected when passing keyword arguments. + `#2611 `_ + +* PyPy fixes, PyPy 7.3.x now supported, including PyPy3. (Known issue with + PyPy2 and Windows `#2596 `_). `#2146 `_ -* CPython 3.9 fixes. +* CPython 3.9.0 workaround for undefined behavior (macOS segfault). + `#2576 `_ + +* CPython 3.9 warning fixes. `#2253 `_ -* More C++20 support. +* Improved C++20 support, now tested in CI. `#2489 `_ + `#2599 `_ -* Debug Python interpreter support. +* Improved but still incomplete debug Python interpreter support. `#2025 `_ * NVCC (CUDA 11) now supported and tested in CI. @@ -159,11 +915,20 @@ Smaller or developer focused features: * NVIDIA PGI compilers now supported and tested in CI. `#2475 `_ -* Extensive style checking in CI, with `pre-commit`_ support. +* At least Intel 18 now explicitly required when compiling with Intel. + `#2577 `_ + +* Extensive style checking in CI, with `pre-commit`_ support. Code + modernization, checked by clang-tidy. + +* Expanded docs, including new main page, new installing section, and CMake + helpers page, along with over a dozen new sections on existing pages. + +* In GitHub, new docs for contributing and new issue templates. .. _pre-commit: https://pre-commit.com - +.. _pybind11-mkdoc: https://github.com/pybind/pybind11-mkdoc v2.5.0 (Mar 31, 2020) ----------------------------------------------------- @@ -261,7 +1026,7 @@ v2.4.0 (Sep 19, 2019) `#1888 `_. * ``py::details::overload_cast_impl`` is available in C++11 mode, can be used - like ``overload_cast`` with an additional set of parantheses. + like ``overload_cast`` with an additional set of parentheses. `#1581 `_. * Fixed ``get_include()`` on Conda. @@ -520,7 +1285,7 @@ v2.2.2 (February 7, 2018) v2.2.1 (September 14, 2017) ----------------------------------------------------- -* Added ``py::module::reload()`` member function for reloading a module. +* Added ``py::module_::reload()`` member function for reloading a module. `#1040 `_. * Fixed a reference leak in the number converter. @@ -583,6 +1348,7 @@ v2.2.0 (August 31, 2017) from cpp_module import CppBase1, CppBase2 + class PyDerived(CppBase1, CppBase2): def __init__(self): CppBase1.__init__(self) # C++ bases must be initialized explicitly @@ -795,7 +1561,7 @@ v2.2.0 (August 31, 2017) * Intel C++ compiler compatibility fixes. `#937 `_. -* Fixed implicit conversion of `py::enum_` to integer types on Python 2.7. +* Fixed implicit conversion of ``py::enum_`` to integer types on Python 2.7. `#821 `_. * Added ``py::hash`` to fetch the hash value of Python objects, and diff --git a/wrap/pybind11/docs/classes.rst b/wrap/pybind11/docs/classes.rst index f3610ef36..13fa8b538 100644 --- a/wrap/pybind11/docs/classes.rst +++ b/wrap/pybind11/docs/classes.rst @@ -44,12 +44,12 @@ interactive Python session demonstrating this example is shown below: % python >>> import example - >>> p = example.Pet('Molly') + >>> p = example.Pet("Molly") >>> print(p) >>> p.getName() u'Molly' - >>> p.setName('Charly') + >>> p.setName("Charly") >>> p.getName() u'Charly' @@ -122,10 +122,10 @@ This makes it possible to write .. code-block:: pycon - >>> p = example.Pet('Molly') + >>> p = example.Pet("Molly") >>> p.name u'Molly' - >>> p.name = 'Charly' + >>> p.name = "Charly" >>> p.name u'Charly' @@ -174,10 +174,10 @@ Native Python classes can pick up new attributes dynamically: .. code-block:: pycon >>> class Pet: - ... name = 'Molly' + ... name = "Molly" ... >>> p = Pet() - >>> p.name = 'Charly' # overwrite existing + >>> p.name = "Charly" # overwrite existing >>> p.age = 2 # dynamically add a new attribute By default, classes exported from C++ do not support this and the only writable @@ -195,7 +195,7 @@ Trying to set any other attribute results in an error: .. code-block:: pycon >>> p = example.Pet() - >>> p.name = 'Charly' # OK, attribute defined in C++ + >>> p.name = "Charly" # OK, attribute defined in C++ >>> p.age = 2 # fail AttributeError: 'Pet' object has no attribute 'age' @@ -213,7 +213,7 @@ Now everything works as expected: .. code-block:: pycon >>> p = example.Pet() - >>> p.name = 'Charly' # OK, overwrite value in C++ + >>> p.name = "Charly" # OK, overwrite value in C++ >>> p.age = 2 # OK, dynamically add a new attribute >>> p.__dict__ # just like a native Python class {'age': 2} @@ -280,7 +280,7 @@ expose fields and methods of both types: .. code-block:: pycon - >>> p = example.Dog('Molly') + >>> p = example.Dog("Molly") >>> p.name u'Molly' >>> p.bark() @@ -446,8 +446,7 @@ you can use ``py::detail::overload_cast_impl`` with an additional set of parenth Enumerations and internal types =============================== -Let's now suppose that the example class contains an internal enumeration type, -e.g.: +Let's now suppose that the example class contains internal types like enumerations, e.g.: .. code-block:: cpp @@ -457,10 +456,15 @@ e.g.: Cat }; + struct Attributes { + float age = 0; + }; + Pet(const std::string &name, Kind type) : name(name), type(type) { } std::string name; Kind type; + Attributes attr; }; The binding code for this example looks as follows: @@ -471,22 +475,28 @@ The binding code for this example looks as follows: pet.def(py::init()) .def_readwrite("name", &Pet::name) - .def_readwrite("type", &Pet::type); + .def_readwrite("type", &Pet::type) + .def_readwrite("attr", &Pet::attr); py::enum_(pet, "Kind") .value("Dog", Pet::Kind::Dog) .value("Cat", Pet::Kind::Cat) .export_values(); -To ensure that the ``Kind`` type is created within the scope of ``Pet``, the -``pet`` :class:`class_` instance must be supplied to the :class:`enum_`. + py::class_ attributes(pet, "Attributes") + .def(py::init<>()) + .def_readwrite("age", &Pet::Attributes::age); + + +To ensure that the nested types ``Kind`` and ``Attributes`` are created within the scope of ``Pet``, the +``pet`` :class:`class_` instance must be supplied to the :class:`enum_` and :class:`class_` constructor. The :func:`enum_::export_values` function exports the enum entries into the parent scope, which should be skipped for newer C++11-style strongly typed enums. .. code-block:: pycon - >>> p = Pet('Lucy', Pet.Cat) + >>> p = Pet("Lucy", Pet.Cat) >>> p.type Kind.Cat >>> int(p.type) @@ -508,7 +518,7 @@ The ``name`` property returns the name of the enum value as a unicode string. .. code-block:: pycon - >>> p = Pet( "Lucy", Pet.Cat ) + >>> p = Pet("Lucy", Pet.Cat) >>> pet_type = p.type >>> pet_type Pet.Cat diff --git a/wrap/pybind11/docs/cmake/index.rst b/wrap/pybind11/docs/cmake/index.rst new file mode 100644 index 000000000..eaf66d70f --- /dev/null +++ b/wrap/pybind11/docs/cmake/index.rst @@ -0,0 +1,8 @@ +CMake helpers +------------- + +Pybind11 can be used with ``add_subdirectory(extern/pybind11)``, or from an +install with ``find_package(pybind11 CONFIG)``. The interface provided in +either case is functionally identical. + +.. cmake-module:: ../../tools/pybind11Config.cmake.in diff --git a/wrap/pybind11/docs/compiling.rst b/wrap/pybind11/docs/compiling.rst index cbf14a466..75608bd57 100644 --- a/wrap/pybind11/docs/compiling.rst +++ b/wrap/pybind11/docs/compiling.rst @@ -31,20 +31,18 @@ An example of a ``setup.py`` using pybind11's helpers: .. code-block:: python + from glob import glob from setuptools import setup from pybind11.setup_helpers import Pybind11Extension ext_modules = [ Pybind11Extension( "python_example", - ["src/main.cpp"], + sorted(glob("src/*.cpp")), # Sort source files for reproducibility ), ] - setup( - ..., - ext_modules=ext_modules - ) + setup(..., ext_modules=ext_modules) If you want to do an automatic search for the highest supported C++ standard, that is supported via a ``build_ext`` command override; it will only affect @@ -52,21 +50,81 @@ that is supported via a ``build_ext`` command override; it will only affect .. code-block:: python + from glob import glob from setuptools import setup from pybind11.setup_helpers import Pybind11Extension, build_ext ext_modules = [ Pybind11Extension( "python_example", - ["src/main.cpp"], + sorted(glob("src/*.cpp")), ), ] - setup( - ..., - cmdclass={"build_ext": build_ext}, - ext_modules=ext_modules - ) + setup(..., cmdclass={"build_ext": build_ext}, ext_modules=ext_modules) + +If you have single-file extension modules that are directly stored in the +Python source tree (``foo.cpp`` in the same directory as where a ``foo.py`` +would be located), you can also generate ``Pybind11Extensions`` using +``setup_helpers.intree_extensions``: ``intree_extensions(["path/to/foo.cpp", +...])`` returns a list of ``Pybind11Extensions`` which can be passed to +``ext_modules``, possibly after further customizing their attributes +(``libraries``, ``include_dirs``, etc.). By doing so, a ``foo.*.so`` extension +module will be generated and made available upon installation. + +``intree_extension`` will automatically detect if you are using a ``src``-style +layout (as long as no namespace packages are involved), but you can also +explicitly pass ``package_dir`` to it (as in ``setuptools.setup``). + +Since pybind11 does not require NumPy when building, a light-weight replacement +for NumPy's parallel compilation distutils tool is included. Use it like this: + +.. code-block:: python + + from pybind11.setup_helpers import ParallelCompile + + # Optional multithreaded build + ParallelCompile("NPY_NUM_BUILD_JOBS").install() + + setup(...) + +The argument is the name of an environment variable to control the number of +threads, such as ``NPY_NUM_BUILD_JOBS`` (as used by NumPy), though you can set +something different if you want; ``CMAKE_BUILD_PARALLEL_LEVEL`` is another choice +a user might expect. You can also pass ``default=N`` to set the default number +of threads (0 will take the number of threads available) and ``max=N``, the +maximum number of threads; if you have a large extension you may want set this +to a memory dependent number. + +If you are developing rapidly and have a lot of C++ files, you may want to +avoid rebuilding files that have not changed. For simple cases were you are +using ``pip install -e .`` and do not have local headers, you can skip the +rebuild if an object file is newer than its source (headers are not checked!) +with the following: + +.. code-block:: python + + from pybind11.setup_helpers import ParallelCompile, naive_recompile + + ParallelCompile("NPY_NUM_BUILD_JOBS", needs_recompile=naive_recompile).install() + + +If you have a more complex build, you can implement a smarter function and pass +it to ``needs_recompile``, or you can use [Ccache]_ instead. ``CXX="cache g++" +pip install -e .`` would be the way to use it with GCC, for example. Unlike the +simple solution, this even works even when not compiling in editable mode, but +it does require Ccache to be installed. + +Keep in mind that Pip will not even attempt to rebuild if it thinks it has +already built a copy of your code, which it deduces from the version number. +One way to avoid this is to use [setuptools_scm]_, which will generate a +version number that includes the number of commits since your last tag and a +hash for a dirty directory. Another way to force a rebuild is purge your cache +or use Pip's ``--no-cache-dir`` option. + +.. [Ccache] https://ccache.dev + +.. [setuptools_scm] https://github.com/pypa/setuptools_scm .. _setup_helpers-pep518: @@ -85,7 +143,7 @@ Your ``pyproject.toml`` file will likely look something like this: .. code-block:: toml [build-system] - requires = ["setuptools", "wheel", "pybind11==2.6.0"] + requires = ["setuptools>=42", "wheel", "pybind11~=2.6.1"] build-backend = "setuptools.build_meta" .. note:: @@ -96,10 +154,12 @@ Your ``pyproject.toml`` file will likely look something like this: in Python) using something like `cibuildwheel`_, remember that ``setup.py`` and ``pyproject.toml`` are not even contained in the wheel, so this high Pip requirement is only for source builds, and will not affect users of - your binary wheels. + your binary wheels. If you are building SDists and wheels, then + `pypa-build`_ is the recommended official tool. .. _PEP 517: https://www.python.org/dev/peps/pep-0517/ .. _cibuildwheel: https://cibuildwheel.readthedocs.io +.. _pypa-build: https://pypa-build.readthedocs.io/en/latest/ .. _setup_helpers-setup_requires: @@ -140,6 +200,23 @@ this, you will need to import from a local file in ``setup.py`` and ensure the helper file is part of your MANIFEST. +Closely related, if you include pybind11 as a subproject, you can run the +``setup_helpers.py`` inplace. If loaded correctly, this should even pick up +the correct include for pybind11, though you can turn it off as shown above if +you want to input it manually. + +Suggested usage if you have pybind11 as a submodule in ``extern/pybind11``: + +.. code-block:: python + + DIR = os.path.abspath(os.path.dirname(__file__)) + + sys.path.append(os.path.join(DIR, "extern", "pybind11")) + from pybind11.setup_helpers import Pybind11Extension # noqa: E402 + + del sys.path[-1] + + .. versionchanged:: 2.6 Added ``setup_helpers`` file. @@ -184,6 +261,8 @@ PyPI integration, can be found in the [cmake_example]_ repository. .. versionchanged:: 2.6 CMake 3.4+ is required. +Further information can be found at :doc:`cmake/index`. + pybind11_add_module ------------------- @@ -224,8 +303,15 @@ As stated above, LTO is enabled by default. Some newer compilers also support different flavors of LTO such as `ThinLTO`_. Setting ``THIN_LTO`` will cause the function to prefer this flavor if available. The function falls back to regular LTO if ``-flto=thin`` is not available. If -``CMAKE_INTERPROCEDURAL_OPTIMIZATION`` is set (either ON or OFF), then that -will be respected instead of the built-in flag search. +``CMAKE_INTERPROCEDURAL_OPTIMIZATION`` is set (either ``ON`` or ``OFF``), then +that will be respected instead of the built-in flag search. + +.. note:: + + If you want to set the property form on targets or the + ``CMAKE_INTERPROCEDURAL_OPTIMIZATION_`` versions of this, you should + still use ``set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF)`` (otherwise a + no-op) to disable pybind11's ipo flags. The ``OPT_SIZE`` flag enables size-based optimization equivalent to the standard ``/Os`` or ``-Os`` compiler flags and the ``MinSizeRel`` build type, @@ -252,10 +338,9 @@ standard explicitly with .. code-block:: cmake - set(CMAKE_CXX_STANDARD 14) # or 11, 14, 17, 20 + set(CMAKE_CXX_STANDARD 14 CACHE STRING "C++ version selection") # or 11, 14, 17, 20 set(CMAKE_CXX_STANDARD_REQUIRED ON) # optional, ensure standard is supported - set(CMAKE_CXX_EXTENSIONS OFF) # optional, keep compiler extensionsn off - + set(CMAKE_CXX_EXTENSIONS OFF) # optional, keep compiler extensions off The variables can also be set when calling CMake from the command line using the ``-D=`` flag. You can also manually set ``CXX_STANDARD`` @@ -325,13 +410,14 @@ can refer to the same [cmake_example]_ repository for a full sample project FindPython mode --------------- -CMake 3.12+ (3.15+ recommended) added a new module called FindPython that had a -highly improved search algorithm and modern targets and tools. If you use -FindPython, pybind11 will detect this and use the existing targets instead: +CMake 3.12+ (3.15+ recommended, 3.18.2+ ideal) added a new module called +FindPython that had a highly improved search algorithm and modern targets +and tools. If you use FindPython, pybind11 will detect this and use the +existing targets instead: .. code-block:: cmake - cmake_minumum_required(VERSION 3.15...3.18) + cmake_minimum_required(VERSION 3.15...3.19) project(example LANGUAGES CXX) find_package(Python COMPONENTS Interpreter Development REQUIRED) @@ -357,6 +443,14 @@ setting ``Python_ROOT_DIR`` may be the most common one (though with virtualenv/venv support, and Conda support, this tends to find the correct Python version more often than the old system did). +.. warning:: + + When the Python libraries (i.e. ``libpythonXX.a`` and ``libpythonXX.so`` + on Unix) are not available, as is the case on a manylinux image, the + ``Development`` component will not be resolved by ``FindPython``. When not + using the embedding functionality, CMake 3.18+ allows you to specify + ``Development.Module`` instead of ``Development`` to resolve this issue. + .. versionadded:: 2.6 Advanced: interface library targets @@ -428,7 +522,7 @@ Instead of setting properties, you can set ``CMAKE_*`` variables to initialize t compiler flags are provided to ensure high quality code generation. In contrast to the ``pybind11_add_module()`` command, the CMake interface provides a *composable* set of targets to ensure that you retain flexibility. - It can be expecially important to provide or set these properties; the + It can be especially important to provide or set these properties; the :ref:`FAQ ` contains an explanation on why these are needed. .. versionadded:: 2.6 @@ -481,7 +575,7 @@ On Linux, you can compile an example such as the one given in .. code-block:: bash - $ c++ -O3 -Wall -shared -std=c++11 -fPIC `python3 -m pybind11 --includes` example.cpp -o example`python3-config --extension-suffix` + $ c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) example.cpp -o example$(python3-config --extension-suffix) The flags given here assume that you're using Python 3. For Python 2, just change the executable appropriately (to ``python`` or ``python2``). @@ -493,7 +587,7 @@ using ``pip`` or ``conda``. If it hasn't, you can also manually specify ``python3-config --includes``. Note that Python 2.7 modules don't use a special suffix, so you should simply -use ``example.so`` instead of ``example`python3-config --extension-suffix```. +use ``example.so`` instead of ``example$(python3-config --extension-suffix)``. Besides, the ``--extension-suffix`` option may or may not be available, depending on the distribution; in the latter case, the module extension can be manually set to ``.so``. @@ -504,7 +598,7 @@ building the module: .. code-block:: bash - $ c++ -O3 -Wall -shared -std=c++11 -undefined dynamic_lookup `python3 -m pybind11 --includes` example.cpp -o example`python3-config --extension-suffix` + $ c++ -O3 -Wall -shared -std=c++11 -undefined dynamic_lookup $(python3 -m pybind11 --includes) example.cpp -o example$(python3-config --extension-suffix) In general, it is advisable to include several additional build parameters that can considerably reduce the size of the created binary. Refer to section @@ -523,23 +617,11 @@ build system that works on all platforms including Windows. contains one (which will lead to a segfault). -Building with vcpkg +Building with Bazel =================== -You can download and install pybind11 using the Microsoft `vcpkg -`_ dependency manager: -.. code-block:: bash - - git clone https://github.com/Microsoft/vcpkg.git - cd vcpkg - ./bootstrap-vcpkg.sh - ./vcpkg integrate install - vcpkg install pybind11 - -The pybind11 port in vcpkg is kept up to date by Microsoft team members and -community contributors. If the version is out of date, please `create an issue -or pull request `_ on the vcpkg -repository. +You can build with the Bazel build system using the `pybind11_bazel +`_ repository. Generating binding code automatically ===================================== diff --git a/wrap/pybind11/docs/conf.py b/wrap/pybind11/docs/conf.py index 0946f30e2..092e274e0 100644 --- a/wrap/pybind11/docs/conf.py +++ b/wrap/pybind11/docs/conf.py @@ -13,57 +13,68 @@ # All configuration values have a default; values that are commented out # serve to show the default. -import sys import os -import shlex +import re import subprocess +import sys +from pathlib import Path + +DIR = Path(__file__).parent.resolve() # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -#sys.path.insert(0, os.path.abspath('.')) +# sys.path.insert(0, os.path.abspath('.')) # -- General configuration ------------------------------------------------ # If your documentation needs a minimal Sphinx version, state it here. -#needs_sphinx = '1.0' +# needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. -extensions = ['breathe'] +extensions = [ + "breathe", + "sphinxcontrib.rsvgconverter", + "sphinxcontrib.moderncmakedomain", +] -breathe_projects = {'pybind11': '.build/doxygenxml/'} -breathe_default_project = 'pybind11' -breathe_domain_by_extension = {'h': 'cpp'} +breathe_projects = {"pybind11": ".build/doxygenxml/"} +breathe_default_project = "pybind11" +breathe_domain_by_extension = {"h": "cpp"} # Add any paths that contain templates here, relative to this directory. -templates_path = ['.templates'] +templates_path = [".templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The encoding of source files. -#source_encoding = 'utf-8-sig' +# source_encoding = 'utf-8-sig' # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = 'pybind11' -copyright = '2017, Wenzel Jakob' -author = 'Wenzel Jakob' +project = "pybind11" +copyright = "2017, Wenzel Jakob" +author = "Wenzel Jakob" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. -# -# The short X.Y version. -version = '2.5' + +# Read the listed version +with open("../pybind11/_version.py") as f: + code = compile(f.read(), "../pybind11/_version.py", "exec") +loc = {} +exec(code, loc) + # The full version, including alpha/beta/rc tags. -release = '2.5.dev1' +version = loc["__version__"] # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -74,37 +85,37 @@ language = None # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: -#today = '' +# today = '' # Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' +# today_fmt = '%B %d, %Y' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. -exclude_patterns = ['.build', 'release.rst'] +exclude_patterns = [".build", "release.rst"] # The reST default role (used for this markup: `text`) to use for all # documents. -default_role = 'any' +default_role = "any" # If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True +# add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). -#add_module_names = True +# add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. -#show_authors = False +# show_authors = False # The name of the Pygments (syntax highlighting) style to use. -#pygments_style = 'monokai' +# pygments_style = 'monokai' # A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] +# modindex_common_prefix = [] # If true, keep warnings as "system message" paragraphs in the built documents. -#keep_warnings = False +# keep_warnings = False # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = False @@ -115,141 +126,150 @@ todo_include_todos = False # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -on_rtd = os.environ.get('READTHEDOCS', None) == 'True' +on_rtd = os.environ.get("READTHEDOCS", None) == "True" if not on_rtd: # only import and set the theme if we're building docs locally import sphinx_rtd_theme - html_theme = 'sphinx_rtd_theme' + + html_theme = "sphinx_rtd_theme" html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] - html_context = { - 'css_files': [ - '_static/theme_overrides.css' - ] - } + html_context = {"css_files": ["_static/theme_overrides.css"]} else: html_context = { - 'css_files': [ - '//media.readthedocs.org/css/sphinx_rtd_theme.css', - '//media.readthedocs.org/css/readthedocs-doc-embed.css', - '_static/theme_overrides.css' + "css_files": [ + "//media.readthedocs.org/css/sphinx_rtd_theme.css", + "//media.readthedocs.org/css/readthedocs-doc-embed.css", + "_static/theme_overrides.css", ] } # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. -#html_theme_options = {} +# html_theme_options = {} # Add any paths that contain custom themes here, relative to this directory. -#html_theme_path = [] +# html_theme_path = [] # The name for this set of Sphinx documents. If None, it defaults to -# " v documentation". -#html_title = None +# " v documentation". +# html_title = None # A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None +# html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. -#html_logo = None +# html_logo = None # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. -#html_favicon = None +# html_favicon = None # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied # directly to the root of the documentation. -#html_extra_path = [] +# html_extra_path = [] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' +# html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. -#html_use_smartypants = True +# html_use_smartypants = True # Custom sidebar templates, maps document names to template names. -#html_sidebars = {} +# html_sidebars = {} # Additional templates that should be rendered to pages, maps page names to # template names. -#html_additional_pages = {} +# html_additional_pages = {} # If false, no module index is generated. -#html_domain_indices = True +# html_domain_indices = True # If false, no index is generated. -#html_use_index = True +# html_use_index = True # If true, the index is split into individual pages for each letter. -#html_split_index = False +# html_split_index = False # If true, links to the reST sources are added to the pages. -#html_show_sourcelink = True +# html_show_sourcelink = True # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -#html_show_sphinx = True +# html_show_sphinx = True # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. -#html_show_copyright = True +# html_show_copyright = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. -#html_use_opensearch = '' +# html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = None +# html_file_suffix = None # Language to be used for generating the HTML full-text search index. # Sphinx supports the following languages: # 'da', 'de', 'en', 'es', 'fi', 'fr', 'h', 'it', 'ja' # 'nl', 'no', 'pt', 'ro', 'r', 'sv', 'tr' -#html_search_language = 'en' +# html_search_language = 'en' # A dictionary with options for the search language support, empty by default. # Now only 'ja' uses this config value -#html_search_options = {'type': 'default'} +# html_search_options = {'type': 'default'} # The name of a javascript file (relative to the configuration directory) that # implements a search results scorer. If empty, the default will be used. -#html_search_scorer = 'scorer.js' +# html_search_scorer = 'scorer.js' # Output file base name for HTML help builder. -htmlhelp_basename = 'pybind11doc' +htmlhelp_basename = "pybind11doc" # -- Options for LaTeX output --------------------------------------------- +latex_engine = "pdflatex" + latex_elements = { -# The paper size ('letterpaper' or 'a4paper'). -#'papersize': 'letterpaper', + # The paper size ('letterpaper' or 'a4paper'). + # 'papersize': 'letterpaper', + # + # The font size ('10pt', '11pt' or '12pt'). + # 'pointsize': '10pt', + # + # Additional stuff for the LaTeX preamble. + # remove blank pages (between the title page and the TOC, etc.) + "classoptions": ",openany,oneside", + "preamble": r""" +\usepackage{fontawesome} +\usepackage{textgreek} +\DeclareUnicodeCharacter{00A0}{} +\DeclareUnicodeCharacter{2194}{\faArrowsH} +\DeclareUnicodeCharacter{1F382}{\faBirthdayCake} +\DeclareUnicodeCharacter{1F355}{\faAdjust} +\DeclareUnicodeCharacter{0301}{'} +\DeclareUnicodeCharacter{03C0}{\textpi} -# The font size ('10pt', '11pt' or '12pt'). -#'pointsize': '10pt', - -# Additional stuff for the LaTeX preamble. -'preamble': r'\DeclareUnicodeCharacter{00A0}{}', - -# Latex figure (float) alignment -#'figure_align': 'htbp', +""", + # Latex figure (float) alignment + # 'figure_align': 'htbp', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'pybind11.tex', 'pybind11 Documentation', - 'Wenzel Jakob', 'manual'), + (master_doc, "pybind11.tex", "pybind11 Documentation", "Wenzel Jakob", "manual"), ] # The name of an image file (relative to this directory) to place at the top of @@ -258,32 +278,29 @@ latex_documents = [ # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. -#latex_use_parts = False +# latex_use_parts = False # If true, show page references after internal links. -#latex_show_pagerefs = False +# latex_show_pagerefs = False # If true, show URL addresses after external links. -#latex_show_urls = False +# latex_show_urls = False # Documents to append as an appendix to all manuals. -#latex_appendices = [] +# latex_appendices = [] # If false, no module index is generated. -#latex_domain_indices = True +# latex_domain_indices = True # -- Options for manual page output --------------------------------------- # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'pybind11', 'pybind11 Documentation', - [author], 1) -] +man_pages = [(master_doc, "pybind11", "pybind11 Documentation", [author], 1)] # If true, show URL addresses after external links. -#man_show_urls = False +# man_show_urls = False # -- Options for Texinfo output ------------------------------------------- @@ -292,41 +309,73 @@ man_pages = [ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'pybind11', 'pybind11 Documentation', - author, 'pybind11', 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + "pybind11", + "pybind11 Documentation", + author, + "pybind11", + "One line description of project.", + "Miscellaneous", + ), ] # Documents to append as an appendix to all manuals. -#texinfo_appendices = [] +# texinfo_appendices = [] # If false, no module index is generated. -#texinfo_domain_indices = True +# texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. -#texinfo_show_urls = 'footnote' +# texinfo_show_urls = 'footnote' # If true, do not generate a @detailmenu in the "Top" node's menu. -#texinfo_no_detailmenu = False +# texinfo_no_detailmenu = False -primary_domain = 'cpp' -highlight_language = 'cpp' +primary_domain = "cpp" +highlight_language = "cpp" def generate_doxygen_xml(app): - build_dir = os.path.join(app.confdir, '.build') + build_dir = os.path.join(app.confdir, ".build") if not os.path.exists(build_dir): os.mkdir(build_dir) try: - subprocess.call(['doxygen', '--version']) - retcode = subprocess.call(['doxygen'], cwd=app.confdir) + subprocess.call(["doxygen", "--version"]) + retcode = subprocess.call(["doxygen"], cwd=app.confdir) if retcode < 0: sys.stderr.write("doxygen error code: {}\n".format(-retcode)) except OSError as e: sys.stderr.write("doxygen execution failed: {}\n".format(e)) +def prepare(app): + with open(DIR.parent / "README.rst") as f: + contents = f.read() + + if app.builder.name == "latex": + # Remove badges and stuff from start + contents = contents[contents.find(r".. start") :] + + # Filter out section titles for index.rst for LaTeX + contents = re.sub(r"^(.*)\n[-~]{3,}$", r"**\1**", contents, flags=re.MULTILINE) + + with open(DIR / "readme.rst", "w") as f: + f.write(contents) + + +def clean_up(app, exception): + (DIR / "readme.rst").unlink() + + def setup(app): - """Add hook for building doxygen xml when needed""" + + # Add hook for building doxygen xml when needed app.connect("builder-inited", generate_doxygen_xml) + + # Copy the readme in + app.connect("builder-inited", prepare) + + # Clean up the generated readme + app.connect("build-finished", clean_up) diff --git a/wrap/pybind11/docs/faq.rst b/wrap/pybind11/docs/faq.rst index 5f7866fa7..e2f477b1f 100644 --- a/wrap/pybind11/docs/faq.rst +++ b/wrap/pybind11/docs/faq.rst @@ -5,7 +5,7 @@ Frequently asked questions =========================================================== 1. Make sure that the name specified in PYBIND11_MODULE is identical to the -filename of the extension library (without suffixes such as .so) +filename of the extension library (without suffixes such as ``.so``). 2. If the above did not fix the issue, you are likely using an incompatible version of Python (for instance, the extension library was compiled against @@ -27,18 +27,6 @@ The Python interpreter immediately crashes when importing my module See the first answer. -CMake doesn't detect the right Python version -============================================= - -The CMake-based build system will try to automatically detect the installed -version of Python and link against that. When this fails, or when there are -multiple versions of Python and it finds the wrong one, delete -``CMakeCache.txt`` and then invoke CMake as follows: - -.. code-block:: bash - - cmake -DPYTHON_EXECUTABLE:FILEPATH= . - .. _faq_reference_arguments: Limitations involving reference arguments @@ -66,7 +54,7 @@ provided by the caller -- in fact, it does nothing at all. .. code-block:: python def increment(i): - i += 1 # nope.. + i += 1 # nope.. pybind11 is also affected by such language-level conventions, which means that binding ``increment`` or ``increment_ptr`` will also create Python functions @@ -100,8 +88,8 @@ following example: .. code-block:: cpp - void init_ex1(py::module &); - void init_ex2(py::module &); + void init_ex1(py::module_ &); + void init_ex2(py::module_ &); /* ... */ PYBIND11_MODULE(example, m) { @@ -114,7 +102,7 @@ following example: .. code-block:: cpp - void init_ex1(py::module &m) { + void init_ex1(py::module_ &m) { m.def("add", [](int a, int b) { return a + b; }); } @@ -122,7 +110,7 @@ following example: .. code-block:: cpp - void init_ex2(py::module &m) { + void init_ex2(py::module_ &m) { m.def("sub", [](int a, int b) { return a - b; }); } @@ -181,8 +169,8 @@ can be changed, but even if it isn't it is not always enough to guarantee complete independence of the symbols involved when not using ``-fvisibility=hidden``. -Additionally, ``-fvisiblity=hidden`` can deliver considerably binary size -savings. (See the following section for more details). +Additionally, ``-fvisibility=hidden`` can deliver considerably binary size +savings. (See the following section for more details.) .. _`faq:symhidden`: @@ -192,7 +180,7 @@ How can I create smaller binaries? To do its job, pybind11 extensively relies on a programming technique known as *template metaprogramming*, which is a way of performing computation at compile -time using type information. Template metaprogamming usually instantiates code +time using type information. Template metaprogramming usually instantiates code involving significant numbers of deeply nested types that are either completely removed or reduced to just a few instructions during the compiler's optimization phase. However, due to the nested nature of these types, the resulting symbol @@ -275,17 +263,34 @@ been received, you must either explicitly interrupt execution by throwing }); } +CMake doesn't detect the right Python version +============================================= + +The CMake-based build system will try to automatically detect the installed +version of Python and link against that. When this fails, or when there are +multiple versions of Python and it finds the wrong one, delete +``CMakeCache.txt`` and then add ``-DPYTHON_EXECUTABLE=$(which python)`` to your +CMake configure line. (Replace ``$(which python)`` with a path to python if +your prefer.) + +You can alternatively try ``-DPYBIND11_FINDPYTHON=ON``, which will activate the +new CMake FindPython support instead of pybind11's custom search. Requires +CMake 3.12+, and 3.15+ or 3.18.2+ are even better. You can set this in your +``CMakeLists.txt`` before adding or finding pybind11, as well. + Inconsistent detection of Python version in CMake and pybind11 ============================================================== -The functions ``find_package(PythonInterp)`` and ``find_package(PythonLibs)`` provided by CMake -for Python version detection are not used by pybind11 due to unreliability and limitations that make -them unsuitable for pybind11's needs. Instead pybind provides its own, more reliable Python detection -CMake code. Conflicts can arise, however, when using pybind11 in a project that *also* uses the CMake -Python detection in a system with several Python versions installed. +The functions ``find_package(PythonInterp)`` and ``find_package(PythonLibs)`` +provided by CMake for Python version detection are modified by pybind11 due to +unreliability and limitations that make them unsuitable for pybind11's needs. +Instead pybind11 provides its own, more reliable Python detection CMake code. +Conflicts can arise, however, when using pybind11 in a project that *also* uses +the CMake Python detection in a system with several Python versions installed. -This difference may cause inconsistencies and errors if *both* mechanisms are used in the same project. Consider the following -CMake code executed in a system with Python 2.7 and 3.x installed: +This difference may cause inconsistencies and errors if *both* mechanisms are +used in the same project. Consider the following CMake code executed in a +system with Python 2.7 and 3.x installed: .. code-block:: cmake @@ -303,10 +308,24 @@ In contrast this code: find_package(PythonInterp) find_package(PythonLibs) -will detect Python 3.x for pybind11 and may crash on ``find_package(PythonLibs)`` afterwards. +will detect Python 3.x for pybind11 and may crash on +``find_package(PythonLibs)`` afterwards. -It is advised to avoid using ``find_package(PythonInterp)`` and ``find_package(PythonLibs)`` from CMake and rely -on pybind11 in detecting Python version. If this is not possible CMake machinery should be called *before* including pybind11. +There are three possible solutions: + +1. Avoid using ``find_package(PythonInterp)`` and ``find_package(PythonLibs)`` + from CMake and rely on pybind11 in detecting Python version. If this is not + possible, the CMake machinery should be called *before* including pybind11. +2. Set ``PYBIND11_FINDPYTHON`` to ``True`` or use ``find_package(Python + COMPONENTS Interpreter Development)`` on modern CMake (3.12+, 3.15+ better, + 3.18.2+ best). Pybind11 in these cases uses the new CMake FindPython instead + of the old, deprecated search tools, and these modules are much better at + finding the correct Python. +3. Set ``PYBIND11_NOPYTHON`` to ``TRUE``. Pybind11 will not search for Python. + However, you will have to use the target-based system, and do more setup + yourself, because it does not know about or include things that depend on + Python, like ``pybind11_add_module``. This might be ideal for integrating + into an existing system, like scikit-build's Python helpers. How to cite this project? ========================= diff --git a/wrap/pybind11/docs/index.rst b/wrap/pybind11/docs/index.rst index d236611b7..4e2e8ca3a 100644 --- a/wrap/pybind11/docs/index.rst +++ b/wrap/pybind11/docs/index.rst @@ -1,18 +1,17 @@ -.. only: not latex +.. only:: latex - .. image:: pybind11-logo.png + Intro + ===== -pybind11 --- Seamless operability between C++11 and Python -========================================================== +.. include:: readme.rst -.. only: not latex +.. only:: not latex Contents: .. toctree:: :maxdepth: 1 - intro changelog upgrade @@ -20,6 +19,7 @@ pybind11 --- Seamless operability between C++11 and Python :caption: The Basics :maxdepth: 2 + installing basics classes compiling @@ -45,3 +45,4 @@ pybind11 --- Seamless operability between C++11 and Python benchmark limitations reference + cmake/index diff --git a/wrap/pybind11/docs/intro.rst b/wrap/pybind11/docs/intro.rst deleted file mode 100644 index 10e1799a1..000000000 --- a/wrap/pybind11/docs/intro.rst +++ /dev/null @@ -1,93 +0,0 @@ -.. image:: pybind11-logo.png - -About this project -================== -**pybind11** is a lightweight header-only library that exposes C++ types in Python -and vice versa, mainly to create Python bindings of existing C++ code. Its -goals and syntax are similar to the excellent `Boost.Python`_ library by David -Abrahams: to minimize boilerplate code in traditional extension modules by -inferring type information using compile-time introspection. - -.. _Boost.Python: http://www.boost.org/doc/libs/release/libs/python/doc/index.html - -The main issue with Boost.Python—and the reason for creating such a similar -project—is Boost. Boost is an enormously large and complex suite of utility -libraries that works with almost every C++ compiler in existence. This -compatibility has its cost: arcane template tricks and workarounds are -necessary to support the oldest and buggiest of compiler specimens. Now that -C++11-compatible compilers are widely available, this heavy machinery has -become an excessively large and unnecessary dependency. -Think of this library as a tiny self-contained version of Boost.Python with -everything stripped away that isn't relevant for binding generation. Without -comments, the core header files only require ~4K lines of code and depend on -Python (2.7 or 3.x, or PyPy2.7 >= 5.7) and the C++ standard library. This -compact implementation was possible thanks to some of the new C++11 language -features (specifically: tuples, lambda functions and variadic templates). Since -its creation, this library has grown beyond Boost.Python in many ways, leading -to dramatically simpler binding code in many common situations. - -Core features -************* -The following core C++ features can be mapped to Python - -- Functions accepting and returning custom data structures per value, reference, or pointer -- Instance methods and static methods -- Overloaded functions -- Instance attributes and static attributes -- Arbitrary exception types -- Enumerations -- Callbacks -- Iterators and ranges -- Custom operators -- Single and multiple inheritance -- STL data structures -- Smart pointers with reference counting like ``std::shared_ptr`` -- Internal references with correct reference counting -- C++ classes with virtual (and pure virtual) methods can be extended in Python - -Goodies -******* -In addition to the core functionality, pybind11 provides some extra goodies: - -- Python 2.7, 3.x, and PyPy (PyPy2.7 >= 5.7) are supported with an - implementation-agnostic interface. - -- It is possible to bind C++11 lambda functions with captured variables. The - lambda capture data is stored inside the resulting Python function object. - -- pybind11 uses C++11 move constructors and move assignment operators whenever - possible to efficiently transfer custom data types. - -- It's easy to expose the internal storage of custom data types through - Pythons' buffer protocols. This is handy e.g. for fast conversion between - C++ matrix classes like Eigen and NumPy without expensive copy operations. - -- pybind11 can automatically vectorize functions so that they are transparently - applied to all entries of one or more NumPy array arguments. - -- Python's slice-based access and assignment operations can be supported with - just a few lines of code. - -- Everything is contained in just a few header files; there is no need to link - against any additional libraries. - -- Binaries are generally smaller by a factor of at least 2 compared to - equivalent bindings generated by Boost.Python. A recent pybind11 conversion - of `PyRosetta`_, an enormous Boost.Python binding project, reported a binary - size reduction of **5.4x** and compile time reduction by **5.8x**. - -- Function signatures are precomputed at compile time (using ``constexpr``), - leading to smaller binaries. - -- With little extra effort, C++ types can be pickled and unpickled similar to - regular Python objects. - -.. _PyRosetta: http://graylab.jhu.edu/RosettaCon2016/PyRosetta-4.pdf - -Supported compilers -******************* - -1. Clang/LLVM (any non-ancient version with C++11 support) -2. GCC 4.8 or newer -3. Microsoft Visual Studio 2015 or newer -4. Intel C++ compiler v17 or newer (v16 with pybind11 v2.0 and v15 with pybind11 v2.0 and a `workaround `_ ) diff --git a/wrap/pybind11/docs/limitations.rst b/wrap/pybind11/docs/limitations.rst index 59474f82f..def5ad659 100644 --- a/wrap/pybind11/docs/limitations.rst +++ b/wrap/pybind11/docs/limitations.rst @@ -1,6 +1,9 @@ Limitations ########### +Design choices +^^^^^^^^^^^^^^ + pybind11 strives to be a general solution to binding generation, but it also has certain limitations: @@ -11,9 +14,59 @@ certain limitations: - The NumPy interface ``pybind11::array`` greatly simplifies accessing numerical data from C++ (and vice versa), but it's not a full-blown array - class like ``Eigen::Array`` or ``boost.multi_array``. + class like ``Eigen::Array`` or ``boost.multi_array``. ``Eigen`` objects are + directly supported, however, with ``pybind11/eigen.h``. -These features could be implemented but would lead to a significant increase in -complexity. I've decided to draw the line here to keep this project simple and -compact. Users who absolutely require these features are encouraged to fork -pybind11. +Large but useful features could be implemented in pybind11 but would lead to a +significant increase in complexity. Pybind11 strives to be simple and compact. +Users who require large new features are encouraged to write an extension to +pybind11; see `pybind11_json `_ for an +example. + + +Known bugs +^^^^^^^^^^ + +These are issues that hopefully will one day be fixed, but currently are +unsolved. If you know how to help with one of these issues, contributions +are welcome! + +- Intel 20.2 is currently having an issue with the test suite. + `#2573 `_ + +- Debug mode Python does not support 1-5 tests in the test suite currently. + `#2422 `_ + +- PyPy3 7.3.1 and 7.3.2 have issues with several tests on 32-bit Windows. + +Known limitations +^^^^^^^^^^^^^^^^^ + +These are issues that are probably solvable, but have not been fixed yet. A +clean, well written patch would likely be accepted to solve them. + +- Type casters are not kept alive recursively. + `#2527 `_ + One consequence is that containers of ``char *`` are currently not supported. + `#2245 `_ + +- The ``cpptest`` does not run on Windows with Python 3.8 or newer, due to DLL + loader changes. User code that is correctly installed should not be affected. + `#2560 `_ + +Python 3.9.0 warning +^^^^^^^^^^^^^^^^^^^^ + +Combining older versions of pybind11 (< 2.6.0) with Python on exactly 3.9.0 +will trigger undefined behavior that typically manifests as crashes during +interpreter shutdown (but could also destroy your data. **You have been +warned**). + +This issue was `fixed in Python `_. +As a mitigation for this bug, pybind11 2.6.0 or newer includes a workaround +specifically when Python 3.9.0 is detected at runtime, leaking about 50 bytes +of memory when a callback function is garbage collected. For reference, the +pybind11 test suite has about 2,000 such callbacks, but only 49 are garbage +collected before the end-of-process. Wheels (even if built with Python 3.9.0) +will correctly avoid the leak when run in Python 3.9.1, and this does not +affect other 3.X versions. diff --git a/wrap/pybind11/docs/reference.rst b/wrap/pybind11/docs/reference.rst index e3a61afb6..e64a03519 100644 --- a/wrap/pybind11/docs/reference.rst +++ b/wrap/pybind11/docs/reference.rst @@ -52,6 +52,20 @@ Convenience classes for specific Python types .. doxygengroup:: pytypes :members: +Convenience functions converting to Python types +================================================ + +.. doxygenfunction:: make_tuple(Args&&...) + +.. doxygenfunction:: make_iterator(Iterator, Sentinel, Extra &&...) +.. doxygenfunction:: make_iterator(Type &, Extra&&...) + +.. doxygenfunction:: make_key_iterator(Iterator, Sentinel, Extra &&...) +.. doxygenfunction:: make_key_iterator(Type &, Extra&&...) + +.. doxygenfunction:: make_value_iterator(Iterator, Sentinel, Extra &&...) +.. doxygenfunction:: make_value_iterator(Type &, Extra&&...) + .. _extras: Passing extra arguments to ``def`` or ``class_`` @@ -110,7 +124,6 @@ Exceptions .. doxygenclass:: builtin_exception :members: - Literals ======== diff --git a/wrap/pybind11/docs/release.rst b/wrap/pybind11/docs/release.rst index 9846f971a..e761cdf7a 100644 --- a/wrap/pybind11/docs/release.rst +++ b/wrap/pybind11/docs/release.rst @@ -1,21 +1,97 @@ -To release a new version of pybind11: +On version numbers +^^^^^^^^^^^^^^^^^^ -- Update the version number and push to pypi - - Update ``pybind11/_version.py`` (set release version, remove 'dev'). - - Update ``PYBIND11_VERSION_MAJOR`` etc. in ``include/pybind11/detail/common.h``. - - Ensure that all the information in ``setup.py`` is up-to-date. - - Update version in ``docs/conf.py``. - - Tag release date in ``docs/changelog.rst``. - - ``git add`` and ``git commit``. - - if new minor version: ``git checkout -b vX.Y``, ``git push -u origin vX.Y`` +The two version numbers (C++ and Python) must match when combined (checked when +you build the PyPI package), and must be a valid `PEP 440 +`_ version when combined. + +For example: + +.. code-block:: C++ + + #define PYBIND11_VERSION_MAJOR X + #define PYBIND11_VERSION_MINOR Y + #define PYBIND11_VERSION_PATCH Z.dev1 + +For beta, ``PYBIND11_VERSION_PATCH`` should be ``Z.b1``. RC's can be ``Z.rc1``. +Always include the dot (even though PEP 440 allows it to be dropped). For a +final release, this must be a simple integer. There is also a HEX version of +the version just below. + + +To release a new version of pybind11: +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +If you don't have nox, you should either use ``pipx run nox`` instead, or use +``pipx install nox`` or ``brew install nox`` (Unix). + +- Update the version number + - Update ``PYBIND11_VERSION_MAJOR`` etc. in + ``include/pybind11/detail/common.h``. PATCH should be a simple integer. + - Update the version HEX just below, as well. + - Update ``pybind11/_version.py`` (match above) + - Run ``nox -s tests_packaging`` to ensure this was done correctly. + - Ensure that all the information in ``setup.cfg`` is up-to-date, like + supported Python versions. + - Add release date in ``docs/changelog.rst``. + - Check to make sure + `needs-changelog `_ + issues are entered in the changelog (clear the label when done). + - ``git add`` and ``git commit``, ``git push``. **Ensure CI passes**. (If it + fails due to a known flake issue, either ignore or restart CI.) +- Add a release branch if this is a new minor version, or update the existing release branch if it is a patch version + - New branch: ``git checkout -b vX.Y``, ``git push -u origin vX.Y`` + - Update branch: ``git checkout vX.Y``, ``git merge ``, ``git push`` +- Update tags (optional; if you skip this, the GitHub release makes a + non-annotated tag for you) - ``git tag -a vX.Y.Z -m 'vX.Y.Z release'``. - - ``git push`` - ``git push --tags``. - - ``python setup.py sdist upload``. - - ``python setup.py bdist_wheel upload``. +- Update stable + - ``git checkout stable`` + - ``git merge master`` + - ``git push`` +- Make a GitHub release (this shows up in the UI, sends new release + notifications to users watching releases, and also uploads PyPI packages). + (Note: if you do not use an existing tag, this creates a new lightweight tag + for you, so you could skip the above step.) + - GUI method: Under `releases `_ + click "Draft a new release" on the far right, fill in the tag name + (if you didn't tag above, it will be made here), fill in a release name + like "Version X.Y.Z", and copy-and-paste the markdown-formatted (!) changelog + into the description (usually ``cat docs/changelog.rst | pandoc -f rst -t gfm``). + Check "pre-release" if this is a beta/RC. + - CLI method: with ``gh`` installed, run ``gh release create vX.Y.Z -t "Version X.Y.Z"`` + If this is a pre-release, add ``-p``. + - Get back to work - - Update ``_version.py`` (add 'dev' and increment minor). - - Update version in ``docs/conf.py`` - - Update version macros in ``include/pybind11/common.h`` - - ``git add`` and ``git commit``. - ``git push`` + - Make sure you are on master, not somewhere else: ``git checkout master`` + - Update version macros in ``include/pybind11/detail/common.h`` (set PATCH to + ``0.dev1`` and increment MINOR). + - Update ``_version.py`` to match + - Run ``nox -s tests_packaging`` to ensure this was done correctly. + - Add a spot for in-development updates in ``docs/changelog.rst``. + - ``git add``, ``git commit``, ``git push`` + +If a version branch is updated, remember to set PATCH to ``1.dev1``. + +If you'd like to bump homebrew, run: + +.. code-block:: console + + brew bump-formula-pr --url https://github.com/pybind/pybind11/archive/vX.Y.Z.tar.gz + +Conda-forge should automatically make a PR in a few hours, and automatically +merge it if there are no issues. + + +Manual packaging +^^^^^^^^^^^^^^^^ + +If you need to manually upload releases, you can download the releases from the job artifacts and upload them with twine. You can also make the files locally (not recommended in general, as your local directory is more likely to be "dirty" and SDists love picking up random unrelated/hidden files); this is the procedure: + +.. code-block:: bash + + nox -s build + twine upload dist/* + +This makes SDists and wheels, and the final line uploads them. diff --git a/wrap/pybind11/docs/requirements.txt b/wrap/pybind11/docs/requirements.txt index f4c3dc2e0..b2801b1f0 100644 --- a/wrap/pybind11/docs/requirements.txt +++ b/wrap/pybind11/docs/requirements.txt @@ -1,5 +1,5 @@ -breathe==4.20.0 -commonmark==0.9.1 -recommonmark==0.6.0 -sphinx==3.2.1 -sphinx_rtd_theme==0.5.0 +breathe==4.31.0 +sphinx==3.5.4 +sphinx_rtd_theme==1.0.0 +sphinxcontrib-moderncmakedomain==3.19 +sphinxcontrib-svg2pdfconverter==1.1.1 diff --git a/wrap/pybind11/docs/upgrade.rst b/wrap/pybind11/docs/upgrade.rst index 62e2312e9..d91d51e6f 100644 --- a/wrap/pybind11/docs/upgrade.rst +++ b/wrap/pybind11/docs/upgrade.rst @@ -8,31 +8,90 @@ to a new version. But it goes into more detail. This includes things like deprecated APIs and their replacements, build system changes, general code modernization and other useful information. +.. _upgrade-guide-2.9: + +v2.9 +==== + +* Any usage of the recently added ``py::make_simple_namespace`` should be + converted to using ``py::module_::import("types").attr("SimpleNamespace")`` + instead. + +* The use of ``_`` in custom type casters can now be replaced with the more + readable ``const_name`` instead. The old ``_`` shortcut has been retained + unless it is being used as a macro (like for gettext). + + +.. _upgrade-guide-2.7: + +v2.7 +==== + +*Before* v2.7, ``py::str`` can hold ``PyUnicodeObject`` or ``PyBytesObject``, +and ``py::isinstance()`` is ``true`` for both ``py::str`` and +``py::bytes``. Starting with v2.7, ``py::str`` exclusively holds +``PyUnicodeObject`` (`#2409 `_), +and ``py::isinstance()`` is ``true`` only for ``py::str``. To help in +the transition of user code, the ``PYBIND11_STR_LEGACY_PERMISSIVE`` macro +is provided as an escape hatch to go back to the legacy behavior. This macro +will be removed in future releases. Two types of required fixes are expected +to be common: + +* Accidental use of ``py::str`` instead of ``py::bytes``, masked by the legacy + behavior. These are probably very easy to fix, by changing from + ``py::str`` to ``py::bytes``. + +* Reliance on py::isinstance(obj) being ``true`` for + ``py::bytes``. This is likely to be easy to fix in most cases by adding + ``|| py::isinstance(obj)``, but a fix may be more involved, e.g. if + ``py::isinstance`` appears in a template. Such situations will require + careful review and custom fixes. + + .. _upgrade-guide-2.6: v2.6 ==== -The ``tools/clang`` submodule and ``tools/mkdoc.py`` have been moved to a -standalone package, `pybind11-mkdoc`_. If you were using those tools, please -use them via a pip install from the new location. +Usage of the ``PYBIND11_OVERLOAD*`` macros and ``get_overload`` function should +be replaced by ``PYBIND11_OVERRIDE*`` and ``get_override``. In the future, the +old macros may be deprecated and removed. -.. _pybind11-mkdoc: https://github.com/pybind/pybind11-mkdoc +``py::module`` has been renamed ``py::module_``, but a backward compatible +typedef has been included. This change was to avoid a language change in C++20 +that requires unqualified ``module`` not be placed at the start of a logical +line. Qualified usage is unaffected and the typedef will remain unless the +C++ language rules change again. + +The public constructors of ``py::module_`` have been deprecated. Use +``PYBIND11_MODULE`` or ``module_::create_extension_module`` instead. An error is now thrown when ``__init__`` is forgotten on subclasses. This was incorrect before, but was not checked. Add a call to ``__init__`` if it is missing. +A ``py::type_error`` is now thrown when casting to a subclass (like +``py::bytes`` from ``py::object``) if the conversion is not valid. Make a valid +conversion instead. + The undocumented ``h.get_type()`` method has been deprecated and replaced by ``py::type::of(h)``. +Enums now have a ``__str__`` method pre-defined; if you want to override it, +the simplest fix is to add the new ``py::prepend()`` tag when defining +``"__str__"``. + If ``__eq__`` defined but not ``__hash__``, ``__hash__`` is now set to ``None``, as in normal CPython. You should add ``__hash__`` if you intended the class to be hashable, possibly using the new ``py::hash`` shortcut. -Usage of the ``PYBIND11_OVERLOAD*`` macros and ``get_overload`` function should -be replaced by ``PYBIND11_OVERRIDE*`` and ``get_override``. In the future, the -old macros may be deprecated and removed. +The constructors for ``py::array`` now always take signed integers for size, +for consistency. This may lead to compiler warnings on some systems. Cast to +``py::ssize_t`` instead of ``std::size_t``. + +The ``tools/clang`` submodule and ``tools/mkdoc.py`` have been moved to a +standalone package, `pybind11-mkdoc`_. If you were using those tools, please +use them via a pip install from the new location. The ``pybind11`` package on PyPI no longer fills the wheel "headers" slot - if you were using the headers from this slot, they are available by requesting the @@ -41,6 +100,8 @@ be unaffected, as the ``pybind11/include`` location is reported by ``python -m pybind11 --includes`` and ``pybind11.get_include()`` is still correct and has not changed since 2.5). +.. _pybind11-mkdoc: https://github.com/pybind/pybind11-mkdoc + CMake support: -------------- @@ -54,7 +115,7 @@ something. The changes are: * If you do not request a standard, pybind11 targets will compile with the compiler default, but not less than C++11, instead of forcing C++14 always. - If you depend on the old behavior, please use ``set(CMAKE_CXX_STANDARD 14)`` + If you depend on the old behavior, please use ``set(CMAKE_CXX_STANDARD 14 CACHE STRING "")`` instead. * Direct ``pybind11::module`` usage should always be accompanied by at least @@ -80,7 +141,8 @@ In addition, the following changes may be of interest: * Using ``find_package(Python COMPONENTS Interpreter Development)`` before pybind11 will cause pybind11 to use the new Python mechanisms instead of its own custom search, based on a patched version of classic ``FindPythonInterp`` - / ``FindPythonLibs``. In the future, this may become the default. + / ``FindPythonLibs``. In the future, this may become the default. A recent + (3.15+ or 3.18.2+) version of CMake is recommended. @@ -170,7 +232,7 @@ way to get and set object state. See :ref:`pickling` for details. ... .def(py::pickle( [](const Foo &self) { // __getstate__ - return py::make_tuple(f.value1(), f.value2(), ...); // unchanged + return py::make_tuple(self.value1(), self.value2(), ...); // unchanged }, [](py::tuple t) { // __setstate__, note: no `self` argument return new Foo(t[0].cast(), ...); @@ -234,7 +296,7 @@ Within pybind11's CMake build system, ``pybind11_add_module`` has always been setting the ``-fvisibility=hidden`` flag in release mode. From now on, it's being applied unconditionally, even in debug mode and it can no longer be opted out of with the ``NO_EXTRAS`` option. The ``pybind11::module`` target now also -adds this flag to it's interface. The ``pybind11::embed`` target is unchanged. +adds this flag to its interface. The ``pybind11::embed`` target is unchanged. The most significant change here is for the ``pybind11::module`` target. If you were previously relying on default visibility, i.e. if your Python module was diff --git a/wrap/pybind11/include/pybind11/attr.h b/wrap/pybind11/include/pybind11/attr.h index d0a8b34b8..f1b66fb80 100644 --- a/wrap/pybind11/include/pybind11/attr.h +++ b/wrap/pybind11/include/pybind11/attr.h @@ -12,13 +12,17 @@ #include "cast.h" +#include + PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) /// \addtogroup annotations /// @{ /// Annotation for methods -struct is_method { handle class_; is_method(const handle &c) : class_(c) { } }; +struct is_method { handle class_; + explicit is_method(const handle &c) : class_(c) {} +}; /// Annotation for operators struct is_operator { }; @@ -27,16 +31,24 @@ struct is_operator { }; struct is_final { }; /// Annotation for parent scope -struct scope { handle value; scope(const handle &s) : value(s) { } }; +struct scope { handle value; + explicit scope(const handle &s) : value(s) {} +}; /// Annotation for documentation -struct doc { const char *value; doc(const char *value) : value(value) { } }; +struct doc { const char *value; + explicit doc(const char *value) : value(value) {} +}; /// Annotation for function names -struct name { const char *value; name(const char *value) : value(value) { } }; +struct name { const char *value; + explicit name(const char *value) : value(value) {} +}; /// Annotation indicating that a function is an overload associated with a given "sibling" -struct sibling { handle value; sibling(const handle &value) : value(value.ptr()) { } }; +struct sibling { handle value; + explicit sibling(const handle &value) : value(value.ptr()) {} +}; /// Annotation indicating that a class derives from another given type template struct base { @@ -62,18 +74,41 @@ struct metaclass { handle value; PYBIND11_DEPRECATED("py::metaclass() is no longer required. It's turned on by default now.") - metaclass() { } // NOLINT(modernize-use-equals-default): breaks MSVC 2015 when adding an attribute + // NOLINTNEXTLINE(modernize-use-equals-default): breaks MSVC 2015 when adding an attribute + metaclass() {} /// Override pybind11's default metaclass explicit metaclass(handle value) : value(value) { } }; +/// Specifies a custom callback with signature `void (PyHeapTypeObject*)` that +/// may be used to customize the Python type. +/// +/// The callback is invoked immediately before `PyType_Ready`. +/// +/// Note: This is an advanced interface, and uses of it may require changes to +/// work with later versions of pybind11. You may wish to consult the +/// implementation of `make_new_python_type` in `detail/classes.h` to understand +/// the context in which the callback will be run. +struct custom_type_setup { + using callback = std::function; + + explicit custom_type_setup(callback value) : value(std::move(value)) {} + + callback value; +}; + /// Annotation that marks a class as local to the module: -struct module_local { const bool value; constexpr module_local(bool v = true) : value(v) { } }; +struct module_local { const bool value; + constexpr explicit module_local(bool v = true) : value(v) {} +}; /// Annotation to mark enums as an arithmetic type struct arithmetic { }; +/// Mark a function for addition at the beginning of the existing overload chain instead of the end +struct prepend { }; + /** \rst A call policy which places one or more guard variables (``Ts...``) around the function call. @@ -120,7 +155,7 @@ enum op_id : int; enum op_type : int; struct undefined_t; template struct op_; -inline void keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret); +void keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret); /// Internal data structure which holds metadata about a keyword argument struct argument_record { @@ -138,8 +173,8 @@ struct argument_record { struct function_record { function_record() : is_constructor(false), is_new_style_constructor(false), is_stateless(false), - is_operator(false), is_method(false), - has_args(false), has_kwargs(false), has_kw_only_args(false) { } + is_operator(false), is_method(false), has_args(false), + has_kwargs(false), prepend(false) { } /// Function name char *name = nullptr; /* why no C++ strings? They generate heavier code.. */ @@ -186,14 +221,15 @@ struct function_record { /// True if the function has a '**kwargs' argument bool has_kwargs : 1; - /// True once a 'py::kw_only' is encountered (any following args are keyword-only) - bool has_kw_only_args : 1; + /// True if this function is to be inserted at the beginning of the overload resolution chain + bool prepend : 1; /// Number of arguments (including py::args and/or py::kwargs, if present) std::uint16_t nargs; - /// Number of trailing arguments (counted in `nargs`) that are keyword-only - std::uint16_t nargs_kw_only = 0; + /// Number of leading positional arguments, which are terminated by a py::args or py::kwargs + /// argument or by a py::kw_only annotation. + std::uint16_t nargs_pos = 0; /// Number of leading arguments (counted in `nargs`) that are positional-only std::uint16_t nargs_pos_only = 0; @@ -253,6 +289,9 @@ struct type_record { /// Custom metaclass (optional) handle metaclass; + /// Custom type setup. + custom_type_setup::callback custom_type_setup_callback; + /// Multiple inheritance marker bool multiple_inheritance : 1; @@ -370,20 +409,23 @@ template <> struct process_attribute : process_attribu static void init(const is_new_style_constructor &, function_record *r) { r->is_new_style_constructor = true; } }; -inline void process_kw_only_arg(const arg &a, function_record *r) { - if (!a.name || strlen(a.name) == 0) - pybind11_fail("arg(): cannot specify an unnamed argument after an kw_only() annotation"); - ++r->nargs_kw_only; +inline void check_kw_only_arg(const arg &a, function_record *r) { + if (r->args.size() > r->nargs_pos && (!a.name || a.name[0] == '\0')) + pybind11_fail("arg(): cannot specify an unnamed argument after a kw_only() annotation or args() argument"); +} + +inline void append_self_arg_if_needed(function_record *r) { + if (r->is_method && r->args.empty()) + r->args.emplace_back("self", nullptr, handle(), /*convert=*/ true, /*none=*/ false); } /// Process a keyword argument attribute (*without* a default value) template <> struct process_attribute : process_attribute_default { static void init(const arg &a, function_record *r) { - if (r->is_method && r->args.empty()) - r->args.emplace_back("self", nullptr, handle(), true /*convert*/, false /*none not allowed*/); + append_self_arg_if_needed(r); r->args.emplace_back(a.name, nullptr, handle(), !a.flag_noconvert, a.flag_none); - if (r->has_kw_only_args) process_kw_only_arg(a, r); + check_kw_only_arg(a, r); } }; @@ -391,7 +433,7 @@ template <> struct process_attribute : process_attribute_default { template <> struct process_attribute : process_attribute_default { static void init(const arg_v &a, function_record *r) { if (r->is_method && r->args.empty()) - r->args.emplace_back("self", nullptr /*descr*/, handle() /*parent*/, true /*convert*/, false /*none not allowed*/); + r->args.emplace_back("self", /*descr=*/ nullptr, /*parent=*/ handle(), /*convert=*/ true, /*none=*/ false); if (!a.value) { #if !defined(NDEBUG) @@ -416,21 +458,28 @@ template <> struct process_attribute : process_attribute_default { } r->args.emplace_back(a.name, a.descr, a.value.inc_ref(), !a.flag_noconvert, a.flag_none); - if (r->has_kw_only_args) process_kw_only_arg(a, r); + check_kw_only_arg(a, r); } }; /// Process a keyword-only-arguments-follow pseudo argument template <> struct process_attribute : process_attribute_default { static void init(const kw_only &, function_record *r) { - r->has_kw_only_args = true; + append_self_arg_if_needed(r); + if (r->has_args && r->nargs_pos != static_cast(r->args.size())) + pybind11_fail("Mismatched args() and kw_only(): they must occur at the same relative argument location (or omit kw_only() entirely)"); + r->nargs_pos = static_cast(r->args.size()); } }; /// Process a positional-only-argument maker template <> struct process_attribute : process_attribute_default { static void init(const pos_only &, function_record *r) { + append_self_arg_if_needed(r); r->nargs_pos_only = static_cast(r->args.size()); + if (r->nargs_pos_only > r->nargs_pos) + pybind11_fail("pos_only(): cannot follow a py::args() argument"); + // It also can't follow a kw_only, but a static_assert in pybind11.h checks that } }; @@ -457,6 +506,13 @@ struct process_attribute : process_attribute_default static void init(const dynamic_attr &, type_record *r) { r->dynamic_attr = true; } }; +template <> +struct process_attribute { + static void init(const custom_type_setup &value, type_record *r) { + r->custom_type_setup_callback = value.value; + } +}; + template <> struct process_attribute : process_attribute_default { static void init(const is_final &, type_record *r) { r->is_final = true; } @@ -477,6 +533,12 @@ struct process_attribute : process_attribute_default static void init(const module_local &l, type_record *r) { r->module_local = l.value; } }; +/// Process a 'prepend' attribute, putting this at the beginning of the overload chain +template <> +struct process_attribute : process_attribute_default { + static void init(const prepend &, function_record *r) { r->prepend = true; } +}; + /// Process an 'arithmetic' attribute for enums (does nothing here) template <> struct process_attribute : process_attribute_default {}; @@ -503,20 +565,31 @@ template struct process_attribute struct process_attributes { static void init(const Args&... args, function_record *r) { - int unused[] = { 0, (process_attribute::type>::init(args, r), 0) ... }; - ignore_unused(unused); + PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(r); + PYBIND11_WORKAROUND_INCORRECT_GCC_UNUSED_BUT_SET_PARAMETER(r); + using expander = int[]; + (void) expander{ + 0, ((void) process_attribute::type>::init(args, r), 0)...}; } static void init(const Args&... args, type_record *r) { - int unused[] = { 0, (process_attribute::type>::init(args, r), 0) ... }; - ignore_unused(unused); + PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(r); + PYBIND11_WORKAROUND_INCORRECT_GCC_UNUSED_BUT_SET_PARAMETER(r); + using expander = int[]; + (void) expander{0, + (process_attribute::type>::init(args, r), 0)...}; } static void precall(function_call &call) { - int unused[] = { 0, (process_attribute::type>::precall(call), 0) ... }; - ignore_unused(unused); + PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(call); + using expander = int[]; + (void) expander{0, + (process_attribute::type>::precall(call), 0)...}; } static void postcall(function_call &call, handle fn_ret) { - int unused[] = { 0, (process_attribute::type>::postcall(call, fn_ret), 0) ... }; - ignore_unused(unused); + PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(call, fn_ret); + PYBIND11_WORKAROUND_INCORRECT_GCC_UNUSED_BUT_SET_PARAMETER(fn_ret); + using expander = int[]; + (void) expander{ + 0, (process_attribute::type>::postcall(call, fn_ret), 0)...}; } }; @@ -532,7 +605,8 @@ template ::value...), size_t self = constexpr_sum(std::is_same::value...)> constexpr bool expected_num_args(size_t nargs, bool has_args, bool has_kwargs) { - return named == 0 || (self + named + has_args + has_kwargs) == nargs; + PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(nargs, has_args, has_kwargs); + return named == 0 || (self + named + size_t(has_args) + size_t(has_kwargs)) == nargs; } PYBIND11_NAMESPACE_END(detail) diff --git a/wrap/pybind11/include/pybind11/buffer_info.h b/wrap/pybind11/include/pybind11/buffer_info.h index 308be06a3..eba68d1aa 100644 --- a/wrap/pybind11/include/pybind11/buffer_info.h +++ b/wrap/pybind11/include/pybind11/buffer_info.h @@ -13,6 +13,29 @@ PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +PYBIND11_NAMESPACE_BEGIN(detail) + +// Default, C-style strides +inline std::vector c_strides(const std::vector &shape, ssize_t itemsize) { + auto ndim = shape.size(); + std::vector strides(ndim, itemsize); + if (ndim > 0) + for (size_t i = ndim - 1; i > 0; --i) + strides[i - 1] = strides[i] * shape[i]; + return strides; +} + +// F-style strides; default when constructing an array_t with `ExtraFlags & f_style` +inline std::vector f_strides(const std::vector &shape, ssize_t itemsize) { + auto ndim = shape.size(); + std::vector strides(ndim, itemsize); + for (size_t i = 1; i < ndim; ++i) + strides[i] = strides[i - 1] * shape[i - 1]; + return strides; +} + +PYBIND11_NAMESPACE_END(detail) + /// Information record describing a Python buffer object struct buffer_info { void *ptr = nullptr; // Pointer to the underlying storage @@ -53,7 +76,14 @@ struct buffer_info { explicit buffer_info(Py_buffer *view, bool ownview = true) : buffer_info(view->buf, view->itemsize, view->format, view->ndim, - {view->shape, view->shape + view->ndim}, {view->strides, view->strides + view->ndim}, view->readonly) { + {view->shape, view->shape + view->ndim}, + /* Though buffer::request() requests PyBUF_STRIDES, ctypes objects + * ignore this flag and return a view with NULL strides. + * When strides are NULL, build them manually. */ + view->strides + ? std::vector(view->strides, view->strides + view->ndim) + : detail::c_strides({view->shape, view->shape + view->ndim}, view->itemsize), + (view->readonly != 0)) { this->m_view = view; this->ownview = ownview; } @@ -61,11 +91,9 @@ struct buffer_info { buffer_info(const buffer_info &) = delete; buffer_info& operator=(const buffer_info &) = delete; - buffer_info(buffer_info &&other) { - (*this) = std::move(other); - } + buffer_info(buffer_info &&other) noexcept { (*this) = std::move(other); } - buffer_info& operator=(buffer_info &&rhs) { + buffer_info &operator=(buffer_info &&rhs) noexcept { ptr = rhs.ptr; itemsize = rhs.itemsize; size = rhs.size; diff --git a/wrap/pybind11/include/pybind11/cast.h b/wrap/pybind11/include/pybind11/cast.h index b071008e6..165102443 100644 --- a/wrap/pybind11/include/pybind11/cast.h +++ b/wrap/pybind11/include/pybind11/cast.h @@ -11,938 +11,25 @@ #pragma once #include "pytypes.h" -#include "detail/typeid.h" +#include "detail/common.h" #include "detail/descr.h" -#include "detail/internals.h" +#include "detail/type_caster_base.h" +#include "detail/typeid.h" #include -#include +#include +#include +#include +#include +#include +#include #include #include - -#if defined(PYBIND11_CPP17) -# if defined(__has_include) -# if __has_include() -# define PYBIND11_HAS_STRING_VIEW -# endif -# elif defined(_MSC_VER) -# define PYBIND11_HAS_STRING_VIEW -# endif -#endif -#ifdef PYBIND11_HAS_STRING_VIEW -#include -#endif - -#if defined(__cpp_lib_char8_t) && __cpp_lib_char8_t >= 201811L -# define PYBIND11_HAS_U8STRING -#endif +#include +#include PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) PYBIND11_NAMESPACE_BEGIN(detail) -/// A life support system for temporary objects created by `type_caster::load()`. -/// Adding a patient will keep it alive up until the enclosing function returns. -class loader_life_support { -public: - /// A new patient frame is created when a function is entered - loader_life_support() { - get_internals().loader_patient_stack.push_back(nullptr); - } - - /// ... and destroyed after it returns - ~loader_life_support() { - auto &stack = get_internals().loader_patient_stack; - if (stack.empty()) - pybind11_fail("loader_life_support: internal error"); - - auto ptr = stack.back(); - stack.pop_back(); - Py_CLEAR(ptr); - - // A heuristic to reduce the stack's capacity (e.g. after long recursive calls) - if (stack.capacity() > 16 && !stack.empty() && stack.capacity() / stack.size() > 2) - stack.shrink_to_fit(); - } - - /// This can only be used inside a pybind11-bound function, either by `argument_loader` - /// at argument preparation time or by `py::cast()` at execution time. - PYBIND11_NOINLINE static void add_patient(handle h) { - auto &stack = get_internals().loader_patient_stack; - if (stack.empty()) - throw cast_error("When called outside a bound function, py::cast() cannot " - "do Python -> C++ conversions which require the creation " - "of temporary values"); - - auto &list_ptr = stack.back(); - if (list_ptr == nullptr) { - list_ptr = PyList_New(1); - if (!list_ptr) - pybind11_fail("loader_life_support: error allocating list"); - PyList_SET_ITEM(list_ptr, 0, h.inc_ref().ptr()); - } else { - auto result = PyList_Append(list_ptr, h.ptr()); - if (result == -1) - pybind11_fail("loader_life_support: error adding patient"); - } - } -}; - -// Gets the cache entry for the given type, creating it if necessary. The return value is the pair -// returned by emplace, i.e. an iterator for the entry and a bool set to `true` if the entry was -// just created. -inline std::pair all_type_info_get_cache(PyTypeObject *type); - -// Populates a just-created cache entry. -PYBIND11_NOINLINE inline void all_type_info_populate(PyTypeObject *t, std::vector &bases) { - std::vector check; - for (handle parent : reinterpret_borrow(t->tp_bases)) - check.push_back((PyTypeObject *) parent.ptr()); - - auto const &type_dict = get_internals().registered_types_py; - for (size_t i = 0; i < check.size(); i++) { - auto type = check[i]; - // Ignore Python2 old-style class super type: - if (!PyType_Check((PyObject *) type)) continue; - - // Check `type` in the current set of registered python types: - auto it = type_dict.find(type); - if (it != type_dict.end()) { - // We found a cache entry for it, so it's either pybind-registered or has pre-computed - // pybind bases, but we have to make sure we haven't already seen the type(s) before: we - // want to follow Python/virtual C++ rules that there should only be one instance of a - // common base. - for (auto *tinfo : it->second) { - // NB: Could use a second set here, rather than doing a linear search, but since - // having a large number of immediate pybind11-registered types seems fairly - // unlikely, that probably isn't worthwhile. - bool found = false; - for (auto *known : bases) { - if (known == tinfo) { found = true; break; } - } - if (!found) bases.push_back(tinfo); - } - } - else if (type->tp_bases) { - // It's some python type, so keep follow its bases classes to look for one or more - // registered types - if (i + 1 == check.size()) { - // When we're at the end, we can pop off the current element to avoid growing - // `check` when adding just one base (which is typical--i.e. when there is no - // multiple inheritance) - check.pop_back(); - i--; - } - for (handle parent : reinterpret_borrow(type->tp_bases)) - check.push_back((PyTypeObject *) parent.ptr()); - } - } -} - -/** - * Extracts vector of type_info pointers of pybind-registered roots of the given Python type. Will - * be just 1 pybind type for the Python type of a pybind-registered class, or for any Python-side - * derived class that uses single inheritance. Will contain as many types as required for a Python - * class that uses multiple inheritance to inherit (directly or indirectly) from multiple - * pybind-registered classes. Will be empty if neither the type nor any base classes are - * pybind-registered. - * - * The value is cached for the lifetime of the Python type. - */ -inline const std::vector &all_type_info(PyTypeObject *type) { - auto ins = all_type_info_get_cache(type); - if (ins.second) - // New cache entry: populate it - all_type_info_populate(type, ins.first->second); - - return ins.first->second; -} - -/** - * Gets a single pybind11 type info for a python type. Returns nullptr if neither the type nor any - * ancestors are pybind11-registered. Throws an exception if there are multiple bases--use - * `all_type_info` instead if you want to support multiple bases. - */ -PYBIND11_NOINLINE inline detail::type_info* get_type_info(PyTypeObject *type) { - auto &bases = all_type_info(type); - if (bases.empty()) - return nullptr; - if (bases.size() > 1) - pybind11_fail("pybind11::detail::get_type_info: type has multiple pybind11-registered bases"); - return bases.front(); -} - -inline detail::type_info *get_local_type_info(const std::type_index &tp) { - auto &locals = registered_local_types_cpp(); - auto it = locals.find(tp); - if (it != locals.end()) - return it->second; - return nullptr; -} - -inline detail::type_info *get_global_type_info(const std::type_index &tp) { - auto &types = get_internals().registered_types_cpp; - auto it = types.find(tp); - if (it != types.end()) - return it->second; - return nullptr; -} - -/// Return the type info for a given C++ type; on lookup failure can either throw or return nullptr. -PYBIND11_NOINLINE inline detail::type_info *get_type_info(const std::type_index &tp, - bool throw_if_missing = false) { - if (auto ltype = get_local_type_info(tp)) - return ltype; - if (auto gtype = get_global_type_info(tp)) - return gtype; - - if (throw_if_missing) { - std::string tname = tp.name(); - detail::clean_type_id(tname); - pybind11_fail("pybind11::detail::get_type_info: unable to find type info for \"" + tname + "\""); - } - return nullptr; -} - -PYBIND11_NOINLINE inline handle get_type_handle(const std::type_info &tp, bool throw_if_missing) { - detail::type_info *type_info = get_type_info(tp, throw_if_missing); - return handle(type_info ? ((PyObject *) type_info->type) : nullptr); -} - -struct value_and_holder { - instance *inst = nullptr; - size_t index = 0u; - const detail::type_info *type = nullptr; - void **vh = nullptr; - - // Main constructor for a found value/holder: - value_and_holder(instance *i, const detail::type_info *type, size_t vpos, size_t index) : - inst{i}, index{index}, type{type}, - vh{inst->simple_layout ? inst->simple_value_holder : &inst->nonsimple.values_and_holders[vpos]} - {} - - // Default constructor (used to signal a value-and-holder not found by get_value_and_holder()) - value_and_holder() = default; - - // Used for past-the-end iterator - value_and_holder(size_t index) : index{index} {} - - template V *&value_ptr() const { - return reinterpret_cast(vh[0]); - } - // True if this `value_and_holder` has a non-null value pointer - explicit operator bool() const { return value_ptr(); } - - template H &holder() const { - return reinterpret_cast(vh[1]); - } - bool holder_constructed() const { - return inst->simple_layout - ? inst->simple_holder_constructed - : inst->nonsimple.status[index] & instance::status_holder_constructed; - } - void set_holder_constructed(bool v = true) { - if (inst->simple_layout) - inst->simple_holder_constructed = v; - else if (v) - inst->nonsimple.status[index] |= instance::status_holder_constructed; - else - inst->nonsimple.status[index] &= (uint8_t) ~instance::status_holder_constructed; - } - bool instance_registered() const { - return inst->simple_layout - ? inst->simple_instance_registered - : inst->nonsimple.status[index] & instance::status_instance_registered; - } - void set_instance_registered(bool v = true) { - if (inst->simple_layout) - inst->simple_instance_registered = v; - else if (v) - inst->nonsimple.status[index] |= instance::status_instance_registered; - else - inst->nonsimple.status[index] &= (uint8_t) ~instance::status_instance_registered; - } -}; - -// Container for accessing and iterating over an instance's values/holders -struct values_and_holders { -private: - instance *inst; - using type_vec = std::vector; - const type_vec &tinfo; - -public: - values_and_holders(instance *inst) : inst{inst}, tinfo(all_type_info(Py_TYPE(inst))) {} - - struct iterator { - private: - instance *inst = nullptr; - const type_vec *types = nullptr; - value_and_holder curr; - friend struct values_and_holders; - iterator(instance *inst, const type_vec *tinfo) - : inst{inst}, types{tinfo}, - curr(inst /* instance */, - types->empty() ? nullptr : (*types)[0] /* type info */, - 0, /* vpos: (non-simple types only): the first vptr comes first */ - 0 /* index */) - {} - // Past-the-end iterator: - iterator(size_t end) : curr(end) {} - public: - bool operator==(const iterator &other) const { return curr.index == other.curr.index; } - bool operator!=(const iterator &other) const { return curr.index != other.curr.index; } - iterator &operator++() { - if (!inst->simple_layout) - curr.vh += 1 + (*types)[curr.index]->holder_size_in_ptrs; - ++curr.index; - curr.type = curr.index < types->size() ? (*types)[curr.index] : nullptr; - return *this; - } - value_and_holder &operator*() { return curr; } - value_and_holder *operator->() { return &curr; } - }; - - iterator begin() { return iterator(inst, &tinfo); } - iterator end() { return iterator(tinfo.size()); } - - iterator find(const type_info *find_type) { - auto it = begin(), endit = end(); - while (it != endit && it->type != find_type) ++it; - return it; - } - - size_t size() { return tinfo.size(); } -}; - -/** - * Extracts C++ value and holder pointer references from an instance (which may contain multiple - * values/holders for python-side multiple inheritance) that match the given type. Throws an error - * if the given type (or ValueType, if omitted) is not a pybind11 base of the given instance. If - * `find_type` is omitted (or explicitly specified as nullptr) the first value/holder are returned, - * regardless of type (and the resulting .type will be nullptr). - * - * The returned object should be short-lived: in particular, it must not outlive the called-upon - * instance. - */ -PYBIND11_NOINLINE inline value_and_holder instance::get_value_and_holder(const type_info *find_type /*= nullptr default in common.h*/, bool throw_if_missing /*= true in common.h*/) { - // Optimize common case: - if (!find_type || Py_TYPE(this) == find_type->type) - return value_and_holder(this, find_type, 0, 0); - - detail::values_and_holders vhs(this); - auto it = vhs.find(find_type); - if (it != vhs.end()) - return *it; - - if (!throw_if_missing) - return value_and_holder(); - -#if defined(NDEBUG) - pybind11_fail("pybind11::detail::instance::get_value_and_holder: " - "type is not a pybind11 base of the given instance " - "(compile in debug mode for type details)"); -#else - pybind11_fail("pybind11::detail::instance::get_value_and_holder: `" + - std::string(find_type->type->tp_name) + "' is not a pybind11 base of the given `" + - std::string(Py_TYPE(this)->tp_name) + "' instance"); -#endif -} - -PYBIND11_NOINLINE inline void instance::allocate_layout() { - auto &tinfo = all_type_info(Py_TYPE(this)); - - const size_t n_types = tinfo.size(); - - if (n_types == 0) - pybind11_fail("instance allocation failed: new instance has no pybind11-registered base types"); - - simple_layout = - n_types == 1 && tinfo.front()->holder_size_in_ptrs <= instance_simple_holder_in_ptrs(); - - // Simple path: no python-side multiple inheritance, and a small-enough holder - if (simple_layout) { - simple_value_holder[0] = nullptr; - simple_holder_constructed = false; - simple_instance_registered = false; - } - else { // multiple base types or a too-large holder - // Allocate space to hold: [v1*][h1][v2*][h2]...[bb...] where [vN*] is a value pointer, - // [hN] is the (uninitialized) holder instance for value N, and [bb...] is a set of bool - // values that tracks whether each associated holder has been initialized. Each [block] is - // padded, if necessary, to an integer multiple of sizeof(void *). - size_t space = 0; - for (auto t : tinfo) { - space += 1; // value pointer - space += t->holder_size_in_ptrs; // holder instance - } - size_t flags_at = space; - space += size_in_ptrs(n_types); // status bytes (holder_constructed and instance_registered) - - // Allocate space for flags, values, and holders, and initialize it to 0 (flags and values, - // in particular, need to be 0). Use Python's memory allocation functions: in Python 3.6 - // they default to using pymalloc, which is designed to be efficient for small allocations - // like the one we're doing here; in earlier versions (and for larger allocations) they are - // just wrappers around malloc. -#if PY_VERSION_HEX >= 0x03050000 - nonsimple.values_and_holders = (void **) PyMem_Calloc(space, sizeof(void *)); - if (!nonsimple.values_and_holders) throw std::bad_alloc(); -#else - nonsimple.values_and_holders = (void **) PyMem_New(void *, space); - if (!nonsimple.values_and_holders) throw std::bad_alloc(); - std::memset(nonsimple.values_and_holders, 0, space * sizeof(void *)); -#endif - nonsimple.status = reinterpret_cast(&nonsimple.values_and_holders[flags_at]); - } - owned = true; -} - -PYBIND11_NOINLINE inline void instance::deallocate_layout() { - if (!simple_layout) - PyMem_Free(nonsimple.values_and_holders); -} - -PYBIND11_NOINLINE inline bool isinstance_generic(handle obj, const std::type_info &tp) { - handle type = detail::get_type_handle(tp, false); - if (!type) - return false; - return isinstance(obj, type); -} - -PYBIND11_NOINLINE inline std::string error_string() { - if (!PyErr_Occurred()) { - PyErr_SetString(PyExc_RuntimeError, "Unknown internal error occurred"); - return "Unknown internal error occurred"; - } - - error_scope scope; // Preserve error state - - std::string errorString; - if (scope.type) { - errorString += handle(scope.type).attr("__name__").cast(); - errorString += ": "; - } - if (scope.value) - errorString += (std::string) str(scope.value); - - PyErr_NormalizeException(&scope.type, &scope.value, &scope.trace); - -#if PY_MAJOR_VERSION >= 3 - if (scope.trace != nullptr) - PyException_SetTraceback(scope.value, scope.trace); -#endif - -#if !defined(PYPY_VERSION) - if (scope.trace) { - auto *trace = (PyTracebackObject *) scope.trace; - - /* Get the deepest trace possible */ - while (trace->tb_next) - trace = trace->tb_next; - - PyFrameObject *frame = trace->tb_frame; - errorString += "\n\nAt:\n"; - while (frame) { - int lineno = PyFrame_GetLineNumber(frame); - errorString += - " " + handle(frame->f_code->co_filename).cast() + - "(" + std::to_string(lineno) + "): " + - handle(frame->f_code->co_name).cast() + "\n"; - frame = frame->f_back; - } - } -#endif - - return errorString; -} - -PYBIND11_NOINLINE inline handle get_object_handle(const void *ptr, const detail::type_info *type ) { - auto &instances = get_internals().registered_instances; - auto range = instances.equal_range(ptr); - for (auto it = range.first; it != range.second; ++it) { - for (const auto &vh : values_and_holders(it->second)) { - if (vh.type == type) - return handle((PyObject *) it->second); - } - } - return handle(); -} - -inline PyThreadState *get_thread_state_unchecked() { -#if defined(PYPY_VERSION) - return PyThreadState_GET(); -#elif PY_VERSION_HEX < 0x03000000 - return _PyThreadState_Current; -#elif PY_VERSION_HEX < 0x03050000 - return (PyThreadState*) _Py_atomic_load_relaxed(&_PyThreadState_Current); -#elif PY_VERSION_HEX < 0x03050200 - return (PyThreadState*) _PyThreadState_Current.value; -#else - return _PyThreadState_UncheckedGet(); -#endif -} - -// Forward declarations -inline void keep_alive_impl(handle nurse, handle patient); -inline PyObject *make_new_instance(PyTypeObject *type); - -class type_caster_generic { -public: - PYBIND11_NOINLINE type_caster_generic(const std::type_info &type_info) - : typeinfo(get_type_info(type_info)), cpptype(&type_info) { } - - type_caster_generic(const type_info *typeinfo) - : typeinfo(typeinfo), cpptype(typeinfo ? typeinfo->cpptype : nullptr) { } - - bool load(handle src, bool convert) { - return load_impl(src, convert); - } - - PYBIND11_NOINLINE static handle cast(const void *_src, return_value_policy policy, handle parent, - const detail::type_info *tinfo, - void *(*copy_constructor)(const void *), - void *(*move_constructor)(const void *), - const void *existing_holder = nullptr) { - if (!tinfo) // no type info: error will be set already - return handle(); - - void *src = const_cast(_src); - if (src == nullptr) - return none().release(); - - auto it_instances = get_internals().registered_instances.equal_range(src); - for (auto it_i = it_instances.first; it_i != it_instances.second; ++it_i) { - for (auto instance_type : detail::all_type_info(Py_TYPE(it_i->second))) { - if (instance_type && same_type(*instance_type->cpptype, *tinfo->cpptype)) - return handle((PyObject *) it_i->second).inc_ref(); - } - } - - auto inst = reinterpret_steal(make_new_instance(tinfo->type)); - auto wrapper = reinterpret_cast(inst.ptr()); - wrapper->owned = false; - void *&valueptr = values_and_holders(wrapper).begin()->value_ptr(); - - switch (policy) { - case return_value_policy::automatic: - case return_value_policy::take_ownership: - valueptr = src; - wrapper->owned = true; - break; - - case return_value_policy::automatic_reference: - case return_value_policy::reference: - valueptr = src; - wrapper->owned = false; - break; - - case return_value_policy::copy: - if (copy_constructor) - valueptr = copy_constructor(src); - else { -#if defined(NDEBUG) - throw cast_error("return_value_policy = copy, but type is " - "non-copyable! (compile in debug mode for details)"); -#else - std::string type_name(tinfo->cpptype->name()); - detail::clean_type_id(type_name); - throw cast_error("return_value_policy = copy, but type " + - type_name + " is non-copyable!"); -#endif - } - wrapper->owned = true; - break; - - case return_value_policy::move: - if (move_constructor) - valueptr = move_constructor(src); - else if (copy_constructor) - valueptr = copy_constructor(src); - else { -#if defined(NDEBUG) - throw cast_error("return_value_policy = move, but type is neither " - "movable nor copyable! " - "(compile in debug mode for details)"); -#else - std::string type_name(tinfo->cpptype->name()); - detail::clean_type_id(type_name); - throw cast_error("return_value_policy = move, but type " + - type_name + " is neither movable nor copyable!"); -#endif - } - wrapper->owned = true; - break; - - case return_value_policy::reference_internal: - valueptr = src; - wrapper->owned = false; - keep_alive_impl(inst, parent); - break; - - default: - throw cast_error("unhandled return_value_policy: should not happen!"); - } - - tinfo->init_instance(wrapper, existing_holder); - - return inst.release(); - } - - // Base methods for generic caster; there are overridden in copyable_holder_caster - void load_value(value_and_holder &&v_h) { - auto *&vptr = v_h.value_ptr(); - // Lazy allocation for unallocated values: - if (vptr == nullptr) { - auto *type = v_h.type ? v_h.type : typeinfo; - if (type->operator_new) { - vptr = type->operator_new(type->type_size); - } else { - #if defined(__cpp_aligned_new) && (!defined(_MSC_VER) || _MSC_VER >= 1912) - if (type->type_align > __STDCPP_DEFAULT_NEW_ALIGNMENT__) - vptr = ::operator new(type->type_size, - std::align_val_t(type->type_align)); - else - #endif - vptr = ::operator new(type->type_size); - } - } - value = vptr; - } - bool try_implicit_casts(handle src, bool convert) { - for (auto &cast : typeinfo->implicit_casts) { - type_caster_generic sub_caster(*cast.first); - if (sub_caster.load(src, convert)) { - value = cast.second(sub_caster.value); - return true; - } - } - return false; - } - bool try_direct_conversions(handle src) { - for (auto &converter : *typeinfo->direct_conversions) { - if (converter(src.ptr(), value)) - return true; - } - return false; - } - void check_holder_compat() {} - - PYBIND11_NOINLINE static void *local_load(PyObject *src, const type_info *ti) { - auto caster = type_caster_generic(ti); - if (caster.load(src, false)) - return caster.value; - return nullptr; - } - - /// Try to load with foreign typeinfo, if available. Used when there is no - /// native typeinfo, or when the native one wasn't able to produce a value. - PYBIND11_NOINLINE bool try_load_foreign_module_local(handle src) { - constexpr auto *local_key = PYBIND11_MODULE_LOCAL_ID; - const auto pytype = type::handle_of(src); - if (!hasattr(pytype, local_key)) - return false; - - type_info *foreign_typeinfo = reinterpret_borrow(getattr(pytype, local_key)); - // Only consider this foreign loader if actually foreign and is a loader of the correct cpp type - if (foreign_typeinfo->module_local_load == &local_load - || (cpptype && !same_type(*cpptype, *foreign_typeinfo->cpptype))) - return false; - - if (auto result = foreign_typeinfo->module_local_load(src.ptr(), foreign_typeinfo)) { - value = result; - return true; - } - return false; - } - - // Implementation of `load`; this takes the type of `this` so that it can dispatch the relevant - // bits of code between here and copyable_holder_caster where the two classes need different - // logic (without having to resort to virtual inheritance). - template - PYBIND11_NOINLINE bool load_impl(handle src, bool convert) { - if (!src) return false; - if (!typeinfo) return try_load_foreign_module_local(src); - if (src.is_none()) { - // Defer accepting None to other overloads (if we aren't in convert mode): - if (!convert) return false; - value = nullptr; - return true; - } - - auto &this_ = static_cast(*this); - this_.check_holder_compat(); - - PyTypeObject *srctype = Py_TYPE(src.ptr()); - - // Case 1: If src is an exact type match for the target type then we can reinterpret_cast - // the instance's value pointer to the target type: - if (srctype == typeinfo->type) { - this_.load_value(reinterpret_cast(src.ptr())->get_value_and_holder()); - return true; - } - // Case 2: We have a derived class - else if (PyType_IsSubtype(srctype, typeinfo->type)) { - auto &bases = all_type_info(srctype); - bool no_cpp_mi = typeinfo->simple_type; - - // Case 2a: the python type is a Python-inherited derived class that inherits from just - // one simple (no MI) pybind11 class, or is an exact match, so the C++ instance is of - // the right type and we can use reinterpret_cast. - // (This is essentially the same as case 2b, but because not using multiple inheritance - // is extremely common, we handle it specially to avoid the loop iterator and type - // pointer lookup overhead) - if (bases.size() == 1 && (no_cpp_mi || bases.front()->type == typeinfo->type)) { - this_.load_value(reinterpret_cast(src.ptr())->get_value_and_holder()); - return true; - } - // Case 2b: the python type inherits from multiple C++ bases. Check the bases to see if - // we can find an exact match (or, for a simple C++ type, an inherited match); if so, we - // can safely reinterpret_cast to the relevant pointer. - else if (bases.size() > 1) { - for (auto base : bases) { - if (no_cpp_mi ? PyType_IsSubtype(base->type, typeinfo->type) : base->type == typeinfo->type) { - this_.load_value(reinterpret_cast(src.ptr())->get_value_and_holder(base)); - return true; - } - } - } - - // Case 2c: C++ multiple inheritance is involved and we couldn't find an exact type match - // in the registered bases, above, so try implicit casting (needed for proper C++ casting - // when MI is involved). - if (this_.try_implicit_casts(src, convert)) - return true; - } - - // Perform an implicit conversion - if (convert) { - for (auto &converter : typeinfo->implicit_conversions) { - auto temp = reinterpret_steal(converter(src.ptr(), typeinfo->type)); - if (load_impl(temp, false)) { - loader_life_support::add_patient(temp); - return true; - } - } - if (this_.try_direct_conversions(src)) - return true; - } - - // Failed to match local typeinfo. Try again with global. - if (typeinfo->module_local) { - if (auto gtype = get_global_type_info(*typeinfo->cpptype)) { - typeinfo = gtype; - return load(src, false); - } - } - - // Global typeinfo has precedence over foreign module_local - return try_load_foreign_module_local(src); - } - - - // Called to do type lookup and wrap the pointer and type in a pair when a dynamic_cast - // isn't needed or can't be used. If the type is unknown, sets the error and returns a pair - // with .second = nullptr. (p.first = nullptr is not an error: it becomes None). - PYBIND11_NOINLINE static std::pair src_and_type( - const void *src, const std::type_info &cast_type, const std::type_info *rtti_type = nullptr) { - if (auto *tpi = get_type_info(cast_type)) - return {src, const_cast(tpi)}; - - // Not found, set error: - std::string tname = rtti_type ? rtti_type->name() : cast_type.name(); - detail::clean_type_id(tname); - std::string msg = "Unregistered type : " + tname; - PyErr_SetString(PyExc_TypeError, msg.c_str()); - return {nullptr, nullptr}; - } - - const type_info *typeinfo = nullptr; - const std::type_info *cpptype = nullptr; - void *value = nullptr; -}; - -/** - * Determine suitable casting operator for pointer-or-lvalue-casting type casters. The type caster - * needs to provide `operator T*()` and `operator T&()` operators. - * - * If the type supports moving the value away via an `operator T&&() &&` method, it should use - * `movable_cast_op_type` instead. - */ -template -using cast_op_type = - conditional_t>::value, - typename std::add_pointer>::type, - typename std::add_lvalue_reference>::type>; - -/** - * Determine suitable casting operator for a type caster with a movable value. Such a type caster - * needs to provide `operator T*()`, `operator T&()`, and `operator T&&() &&`. The latter will be - * called in appropriate contexts where the value can be moved rather than copied. - * - * These operator are automatically provided when using the PYBIND11_TYPE_CASTER macro. - */ -template -using movable_cast_op_type = - conditional_t::type>::value, - typename std::add_pointer>::type, - conditional_t::value, - typename std::add_rvalue_reference>::type, - typename std::add_lvalue_reference>::type>>; - -// std::is_copy_constructible isn't quite enough: it lets std::vector (and similar) through when -// T is non-copyable, but code containing such a copy constructor fails to actually compile. -template struct is_copy_constructible : std::is_copy_constructible {}; - -// Specialization for types that appear to be copy constructible but also look like stl containers -// (we specifically check for: has `value_type` and `reference` with `reference = value_type&`): if -// so, copy constructability depends on whether the value_type is copy constructible. -template struct is_copy_constructible, - std::is_same, - // Avoid infinite recursion - negation> - >::value>> : is_copy_constructible {}; - -// Likewise for std::pair -// (after C++17 it is mandatory that the copy constructor not exist when the two types aren't themselves -// copy constructible, but this can not be relied upon when T1 or T2 are themselves containers). -template struct is_copy_constructible> - : all_of, is_copy_constructible> {}; - -// The same problems arise with std::is_copy_assignable, so we use the same workaround. -template struct is_copy_assignable : std::is_copy_assignable {}; -template struct is_copy_assignable, - std::is_same - >::value>> : is_copy_assignable {}; -template struct is_copy_assignable> - : all_of, is_copy_assignable> {}; - -PYBIND11_NAMESPACE_END(detail) - -// polymorphic_type_hook::get(src, tinfo) determines whether the object pointed -// to by `src` actually is an instance of some class derived from `itype`. -// If so, it sets `tinfo` to point to the std::type_info representing that derived -// type, and returns a pointer to the start of the most-derived object of that type -// (in which `src` is a subobject; this will be the same address as `src` in most -// single inheritance cases). If not, or if `src` is nullptr, it simply returns `src` -// and leaves `tinfo` at its default value of nullptr. -// -// The default polymorphic_type_hook just returns src. A specialization for polymorphic -// types determines the runtime type of the passed object and adjusts the this-pointer -// appropriately via dynamic_cast. This is what enables a C++ Animal* to appear -// to Python as a Dog (if Dog inherits from Animal, Animal is polymorphic, Dog is -// registered with pybind11, and this Animal is in fact a Dog). -// -// You may specialize polymorphic_type_hook yourself for types that want to appear -// polymorphic to Python but do not use C++ RTTI. (This is a not uncommon pattern -// in performance-sensitive applications, used most notably in LLVM.) -// -// polymorphic_type_hook_base allows users to specialize polymorphic_type_hook with -// std::enable_if. User provided specializations will always have higher priority than -// the default implementation and specialization provided in polymorphic_type_hook_base. -template -struct polymorphic_type_hook_base -{ - static const void *get(const itype *src, const std::type_info*&) { return src; } -}; -template -struct polymorphic_type_hook_base::value>> -{ - static const void *get(const itype *src, const std::type_info*& type) { - type = src ? &typeid(*src) : nullptr; - return dynamic_cast(src); - } -}; -template -struct polymorphic_type_hook : public polymorphic_type_hook_base {}; - -PYBIND11_NAMESPACE_BEGIN(detail) - -/// Generic type caster for objects stored on the heap -template class type_caster_base : public type_caster_generic { - using itype = intrinsic_t; - -public: - static constexpr auto name = _(); - - type_caster_base() : type_caster_base(typeid(type)) { } - explicit type_caster_base(const std::type_info &info) : type_caster_generic(info) { } - - static handle cast(const itype &src, return_value_policy policy, handle parent) { - if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference) - policy = return_value_policy::copy; - return cast(&src, policy, parent); - } - - static handle cast(itype &&src, return_value_policy, handle parent) { - return cast(&src, return_value_policy::move, parent); - } - - // Returns a (pointer, type_info) pair taking care of necessary type lookup for a - // polymorphic type (using RTTI by default, but can be overridden by specializing - // polymorphic_type_hook). If the instance isn't derived, returns the base version. - static std::pair src_and_type(const itype *src) { - auto &cast_type = typeid(itype); - const std::type_info *instance_type = nullptr; - const void *vsrc = polymorphic_type_hook::get(src, instance_type); - if (instance_type && !same_type(cast_type, *instance_type)) { - // This is a base pointer to a derived type. If the derived type is registered - // with pybind11, we want to make the full derived object available. - // In the typical case where itype is polymorphic, we get the correct - // derived pointer (which may be != base pointer) by a dynamic_cast to - // most derived type. If itype is not polymorphic, we won't get here - // except via a user-provided specialization of polymorphic_type_hook, - // and the user has promised that no this-pointer adjustment is - // required in that case, so it's OK to use static_cast. - if (const auto *tpi = get_type_info(*instance_type)) - return {vsrc, tpi}; - } - // Otherwise we have either a nullptr, an `itype` pointer, or an unknown derived pointer, so - // don't do a cast - return type_caster_generic::src_and_type(src, cast_type, instance_type); - } - - static handle cast(const itype *src, return_value_policy policy, handle parent) { - auto st = src_and_type(src); - return type_caster_generic::cast( - st.first, policy, parent, st.second, - make_copy_constructor(src), make_move_constructor(src)); - } - - static handle cast_holder(const itype *src, const void *holder) { - auto st = src_and_type(src); - return type_caster_generic::cast( - st.first, return_value_policy::take_ownership, {}, st.second, - nullptr, nullptr, holder); - } - - template using cast_op_type = detail::cast_op_type; - - operator itype*() { return (type *) value; } - operator itype&() { if (!value) throw reference_cast_error(); return *((itype *) value); } - -protected: - using Constructor = void *(*)(const void *); - - /* Only enabled when the types are {copy,move}-constructible *and* when the type - does not have a private operator new implementation. */ - template ::value>> - static auto make_copy_constructor(const T *x) -> decltype(new T(*x), Constructor{}) { - return [](const void *arg) -> void * { - return new T(*reinterpret_cast(arg)); - }; - } - - template ::value>> - static auto make_move_constructor(const T *x) -> decltype(new T(std::move(*const_cast(x))), Constructor{}) { - return [](const void *arg) -> void * { - return new T(std::move(*const_cast(reinterpret_cast(arg)))); - }; - } - - static Constructor make_copy_constructor(...) { return nullptr; } - static Constructor make_move_constructor(...) { return nullptr; } -}; - template class type_caster : public type_caster_base { }; template using make_caster = type_caster>; @@ -960,9 +47,14 @@ template class type_caster> { private: using caster_t = make_caster; caster_t subcaster; - using subcaster_cast_op_type = typename caster_t::template cast_op_type; - static_assert(std::is_same::type &, subcaster_cast_op_type>::value, - "std::reference_wrapper caster requires T to have a caster with an `T &` operator"); + using reference_t = type&; + using subcaster_cast_op_type = + typename caster_t::template cast_op_type; + + static_assert(std::is_same::type &, subcaster_cast_op_type>::value || + std::is_same::value, + "std::reference_wrapper caster requires T to have a caster with an " + "`operator T &()` or `operator const T &()`"); public: bool load(handle src, bool convert) { return subcaster.load(src, convert); } static constexpr auto name = caster_t::name; @@ -973,28 +65,31 @@ public: return caster_t::cast(&src.get(), policy, parent); } template using cast_op_type = std::reference_wrapper; - operator std::reference_wrapper() { return subcaster.operator subcaster_cast_op_type&(); } + explicit operator std::reference_wrapper() { return cast_op(subcaster); } }; -#define PYBIND11_TYPE_CASTER(type, py_name) \ - protected: \ - type value; \ - public: \ - static constexpr auto name = py_name; \ - template >::value, int> = 0> \ - static handle cast(T_ *src, return_value_policy policy, handle parent) { \ - if (!src) return none().release(); \ - if (policy == return_value_policy::take_ownership) { \ - auto h = cast(std::move(*src), policy, parent); delete src; return h; \ - } else { \ - return cast(*src, policy, parent); \ - } \ - } \ - operator type*() { return &value; } \ - operator type&() { return value; } \ - operator type&&() && { return std::move(value); } \ - template using cast_op_type = pybind11::detail::movable_cast_op_type - +#define PYBIND11_TYPE_CASTER(type, py_name) \ +protected: \ + type value; \ + \ +public: \ + static constexpr auto name = py_name; \ + template >::value, int> = 0> \ + static handle cast(T_ *src, return_value_policy policy, handle parent) { \ + if (!src) \ + return none().release(); \ + if (policy == return_value_policy::take_ownership) { \ + auto h = cast(std::move(*src), policy, parent); \ + delete src; \ + return h; \ + } \ + return cast(*src, policy, parent); \ + } \ + operator type *() { return &value; } /* NOLINT(bugprone-macro-parentheses) */ \ + operator type &() { return value; } /* NOLINT(bugprone-macro-parentheses) */ \ + operator type &&() && { return std::move(value); } /* NOLINT(bugprone-macro-parentheses) */ \ + template \ + using cast_op_type = pybind11::detail::movable_cast_op_type template using is_std_char_type = any_of< std::is_same, /* std::string */ @@ -1020,19 +115,46 @@ public: if (!src) return false; +#if !defined(PYPY_VERSION) + auto index_check = [](PyObject *o) { return PyIndex_Check(o); }; +#else + // In PyPy 7.3.3, `PyIndex_Check` is implemented by calling `__index__`, + // while CPython only considers the existence of `nb_index`/`__index__`. + auto index_check = [](PyObject *o) { return hasattr(o, "__index__"); }; +#endif + if (std::is_floating_point::value) { if (convert || PyFloat_Check(src.ptr())) py_value = (py_type) PyFloat_AsDouble(src.ptr()); else return false; - } else if (PyFloat_Check(src.ptr())) { + } else if (PyFloat_Check(src.ptr()) + || (!convert && !PYBIND11_LONG_CHECK(src.ptr()) && !index_check(src.ptr()))) { return false; - } else if (std::is_unsigned::value) { - py_value = as_unsigned(src.ptr()); - } else { // signed integer: - py_value = sizeof(T) <= sizeof(long) - ? (py_type) PyLong_AsLong(src.ptr()) - : (py_type) PYBIND11_LONG_AS_LONGLONG(src.ptr()); + } else { + handle src_or_index = src; + // PyPy: 7.3.7's 3.8 does not implement PyLong_*'s __index__ calls. +#if PY_VERSION_HEX < 0x03080000 || defined(PYPY_VERSION) + object index; + if (!PYBIND11_LONG_CHECK(src.ptr())) { // So: index_check(src.ptr()) + index = reinterpret_steal(PyNumber_Index(src.ptr())); + if (!index) { + PyErr_Clear(); + if (!convert) + return false; + } + else { + src_or_index = index; + } + } +#endif + if (std::is_unsigned::value) { + py_value = as_unsigned(src_or_index.ptr()); + } else { // signed integer: + py_value = sizeof(T) <= sizeof(long) + ? (py_type) PyLong_AsLong(src_or_index.ptr()) + : (py_type) PYBIND11_LONG_AS_LONGLONG(src_or_index.ptr()); + } } // Python API reported an error @@ -1041,15 +163,8 @@ public: // Check to see if the conversion is valid (integers should match exactly) // Signed/unsigned checks happen elsewhere if (py_err || (std::is_integral::value && sizeof(py_type) != sizeof(T) && py_value != (py_type) (T) py_value)) { - bool type_error = py_err && PyErr_ExceptionMatches( -#if PY_VERSION_HEX < 0x03000000 && !defined(PYPY_VERSION) - PyExc_SystemError -#else - PyExc_TypeError -#endif - ); PyErr_Clear(); - if (type_error && convert && PyNumber_Check(src.ptr())) { + if (py_err && convert && (PyNumber_Check(src.ptr()) != 0)) { auto tmp = reinterpret_steal(std::is_floating_point::value ? PyNumber_Float(src.ptr()) : PyNumber_Long(src.ptr())); @@ -1093,7 +208,7 @@ public: return PyLong_FromUnsignedLongLong((unsigned long long) src); } - PYBIND11_TYPE_CASTER(T, _::value>("int", "float")); + PYBIND11_TYPE_CASTER(T, const_name::value>("int", "float")); }; template struct void_caster { @@ -1106,7 +221,7 @@ public: static handle cast(T, return_value_policy /* policy */, handle /* parent */) { return none().inc_ref(); } - PYBIND11_TYPE_CASTER(T, _("None")); + PYBIND11_TYPE_CASTER(T, const_name("None")); }; template <> class type_caster : public void_caster {}; @@ -1118,7 +233,8 @@ public: bool load(handle h, bool) { if (!h) { return false; - } else if (h.is_none()) { + } + if (h.is_none()) { value = nullptr; return true; } @@ -1143,13 +259,12 @@ public: static handle cast(const void *ptr, return_value_policy /* policy */, handle /* parent */) { if (ptr) return capsule(ptr).release(); - else - return none().inc_ref(); + return none().inc_ref(); } template using cast_op_type = void*&; - operator void *&() { return value; } - static constexpr auto name = _("capsule"); + explicit operator void *&() { return value; } + static constexpr auto name = const_name("capsule"); private: void *value = nullptr; }; @@ -1160,9 +275,15 @@ template <> class type_caster { public: bool load(handle src, bool convert) { if (!src) return false; - else if (src.ptr() == Py_True) { value = true; return true; } - else if (src.ptr() == Py_False) { value = false; return true; } - else if (convert || !strcmp("numpy.bool_", Py_TYPE(src.ptr())->tp_name)) { + if (src.ptr() == Py_True) { + value = true; + return true; + } + if (src.ptr() == Py_False) { + value = false; + return true; + } + if (convert || (std::strcmp("numpy.bool_", Py_TYPE(src.ptr())->tp_name) == 0)) { // (allow non-implicit conversion for numpy booleans) Py_ssize_t res = -1; @@ -1184,18 +305,17 @@ public: } #endif if (res == 0 || res == 1) { - value = (bool) res; + value = (res != 0); return true; - } else { - PyErr_Clear(); } + PyErr_Clear(); } return false; } static handle cast(bool src, return_value_policy /* policy */, handle /* parent */) { return handle(src ? Py_True : Py_False).inc_ref(); } - PYBIND11_TYPE_CASTER(bool, _("bool")); + PYBIND11_TYPE_CASTER(bool, const_name("bool")); }; // Helper class for UTF-{8,16,32} C++ stl strings: @@ -1222,7 +342,8 @@ template struct string_caster { handle load_src = src; if (!src) { return false; - } else if (!PyUnicode_Check(load_src.ptr())) { + } + if (!PyUnicode_Check(load_src.ptr())) { #if PY_MAJOR_VERSION >= 3 return load_bytes(load_src); #else @@ -1240,13 +361,33 @@ template struct string_caster { #endif } - object utfNbytes = reinterpret_steal(PyUnicode_AsEncodedString( +#if PY_VERSION_HEX >= 0x03030000 + // On Python >= 3.3, for UTF-8 we avoid the need for a temporary `bytes` + // object by using `PyUnicode_AsUTF8AndSize`. + if (PYBIND11_SILENCE_MSVC_C4127(UTF_N == 8)) { + Py_ssize_t size = -1; + const auto *buffer + = reinterpret_cast(PyUnicode_AsUTF8AndSize(load_src.ptr(), &size)); + if (!buffer) { + PyErr_Clear(); + return false; + } + value = StringType(buffer, static_cast(size)); + return true; + } +#endif + + auto utfNbytes = reinterpret_steal(PyUnicode_AsEncodedString( load_src.ptr(), UTF_N == 8 ? "utf-8" : UTF_N == 16 ? "utf-16" : "utf-32", nullptr)); if (!utfNbytes) { PyErr_Clear(); return false; } const auto *buffer = reinterpret_cast(PYBIND11_BYTES_AS_STRING(utfNbytes.ptr())); size_t length = (size_t) PYBIND11_BYTES_SIZE(utfNbytes.ptr()) / sizeof(CharT); - if (UTF_N > 8) { buffer++; length--; } // Skip BOM for UTF-16/32 + // Skip BOM for UTF-16/32 + if (PYBIND11_SILENCE_MSVC_C4127(UTF_N > 8)) { + buffer++; + length--; + } value = StringType(buffer, length); // If we're loading a string_view we need to keep the encoded Python object alive: @@ -1264,7 +405,7 @@ template struct string_caster { return s; } - PYBIND11_TYPE_CASTER(StringType, _(PYBIND11_STRING_NAME)); + PYBIND11_TYPE_CASTER(StringType, const_name(PYBIND11_STRING_NAME)); private: static handle decode_utfN(const char *buffer, ssize_t nbytes) { @@ -1274,10 +415,8 @@ private: UTF_N == 16 ? PyUnicode_DecodeUTF16(buffer, nbytes, nullptr, nullptr) : PyUnicode_DecodeUTF32(buffer, nbytes, nullptr, nullptr); #else - // PyPy seems to have multiple problems related to PyUnicode_UTF*: the UTF8 version - // sometimes segfaults for unknown reasons, while the UTF16 and 32 versions require a - // non-const char * arguments, which is also a nuisance, so bypass the whole thing by just - // passing the encoding as a string value, which works properly: + // PyPy segfaults when on PyUnicode_DecodeUTF16 (and possibly on PyUnicode_DecodeUTF32 as well), + // so bypass the whole thing by just passing the encoding as a string value, which works properly: return PyUnicode_Decode(buffer, nbytes, UTF_N == 8 ? "utf-8" : UTF_N == 16 ? "utf-16" : "utf-32", nullptr); #endif } @@ -1348,8 +487,10 @@ public: return StringCaster::cast(StringType(1, src), policy, parent); } - operator CharT*() { return none ? nullptr : const_cast(static_cast(str_caster).c_str()); } - operator CharT&() { + explicit operator CharT *() { + return none ? nullptr : const_cast(static_cast(str_caster).c_str()); + } + explicit operator CharT &() { if (none) throw value_error("Cannot convert None to a character"); @@ -1363,12 +504,16 @@ public: // out how long the first encoded character is in bytes to distinguish between these two // errors. We also allow want to allow unicode characters U+0080 through U+00FF, as those // can fit into a single char value. - if (StringCaster::UTF_N == 8 && str_len > 1 && str_len <= 4) { + if (PYBIND11_SILENCE_MSVC_C4127(StringCaster::UTF_N == 8) && str_len > 1 && str_len <= 4) { auto v0 = static_cast(value[0]); - size_t char0_bytes = !(v0 & 0x80) ? 1 : // low bits only: 0-127 - (v0 & 0xE0) == 0xC0 ? 2 : // 0b110xxxxx - start of 2-byte sequence - (v0 & 0xF0) == 0xE0 ? 3 : // 0b1110xxxx - start of 3-byte sequence - 4; // 0b11110xxx - start of 4-byte sequence + // low bits only: 0-127 + // 0b110xxxxx - start of 2-byte sequence + // 0b1110xxxx - start of 3-byte sequence + // 0b11110xxx - start of 4-byte sequence + size_t char0_bytes = (v0 & 0x80) == 0 ? 1 + : (v0 & 0xE0) == 0xC0 ? 2 + : (v0 & 0xF0) == 0xE0 ? 3 + : 4; if (char0_bytes == str_len) { // If we have a 128-255 value, we can decode it into a single char: @@ -1384,7 +529,7 @@ public: // UTF-16 is much easier: we can only have a surrogate pair for values above U+FFFF, thus a // surrogate pair with total length 2 instantly indicates a range error (but not a "your // string was too long" error). - else if (StringCaster::UTF_N == 16 && str_len == 2) { + else if (PYBIND11_SILENCE_MSVC_C4127(StringCaster::UTF_N == 16) && str_len == 2) { one_char = static_cast(value[0]); if (one_char >= 0xD800 && one_char < 0xE000) throw value_error("Character code point not in range(0x10000)"); @@ -1397,7 +542,7 @@ public: return one_char; } - static constexpr auto name = _(PYBIND11_STRING_NAME); + static constexpr auto name = const_name(PYBIND11_STRING_NAME); template using cast_op_type = pybind11::detail::cast_op_type<_T>; }; @@ -1427,18 +572,19 @@ public: static handle cast(T *src, return_value_policy policy, handle parent) { if (!src) return none().release(); if (policy == return_value_policy::take_ownership) { - auto h = cast(std::move(*src), policy, parent); delete src; return h; - } else { - return cast(*src, policy, parent); + auto h = cast(std::move(*src), policy, parent); + delete src; + return h; } + return cast(*src, policy, parent); } - static constexpr auto name = _("Tuple[") + concat(make_caster::name...) + _("]"); + static constexpr auto name = const_name("Tuple[") + concat(make_caster::name...) + const_name("]"); template using cast_op_type = type; - operator type() & { return implicit_cast(indices{}); } - operator type() && { return std::move(*this).implicit_cast(indices{}); } + explicit operator type() & { return implicit_cast(indices{}); } + explicit operator type() && { return std::move(*this).implicit_cast(indices{}); } protected: template @@ -1464,6 +610,8 @@ protected: /* Implementation: Convert a C++ tuple into a Python tuple */ template static handle cast_impl(T &&src, return_value_policy policy, handle parent, index_sequence) { + PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(src, policy, parent); + PYBIND11_WORKAROUND_INCORRECT_GCC_UNUSED_BUT_SET_PARAMETER(policy, parent); std::array entries{{ reinterpret_steal(make_caster::cast(std::get(std::forward(src)), policy, parent))... }}; @@ -1494,7 +642,11 @@ struct holder_helper { }; /// Type caster for holder types like std::shared_ptr, etc. -template +/// The SFINAE hook is provided to help work around the current lack of support +/// for smart-pointer interoperability. Please consider it an implementation +/// detail that may change in the future, as formal support for smart-pointer +/// interoperability is added into pybind11. +template struct copyable_holder_caster : public type_caster_base { public: using base = type_caster_base; @@ -1514,14 +666,7 @@ public: // see issue #2180 explicit operator type&() { return *(static_cast(this->value)); } explicit operator holder_type*() { return std::addressof(holder); } - - // Workaround for Intel compiler bug - // see pybind11 issue 94 - #if defined(__ICC) || defined(__INTEL_COMPILER) - operator holder_type&() { return holder; } - #else explicit operator holder_type&() { return holder; } - #endif static handle cast(const holder_type &src, return_value_policy, handle) { const auto *ptr = holder_helper::get(src); @@ -1540,14 +685,14 @@ protected: value = v_h.value_ptr(); holder = v_h.template holder(); return true; - } else { - throw cast_error("Unable to cast from non-held to held instance (T& to Holder) " -#if defined(NDEBUG) - "(compile in debug mode for type information)"); -#else - "of type '" + type_id() + "''"); -#endif } + throw cast_error("Unable to cast from non-held to held instance (T& to Holder) " +#if defined(NDEBUG) + "(compile in debug mode for type information)"); +#else + "of type '" + + type_id() + "''"); +#endif } template ::value, int> = 0> @@ -1576,7 +721,10 @@ protected: template class type_caster> : public copyable_holder_caster> { }; -template +/// Type caster for holder types like std::unique_ptr. +/// Please consider the SFINAE hook an implementation detail, as explained +/// in the comment for the copyable_holder_caster. +template struct move_only_holder_caster { static_assert(std::is_base_of, type_caster>::value, "Holder classes are only supported for custom types"); @@ -1616,14 +764,16 @@ template struct is_holder_type : template struct is_holder_type> : std::true_type {}; -template struct handle_type_name { static constexpr auto name = _(); }; -template <> struct handle_type_name { static constexpr auto name = _(PYBIND11_BYTES_NAME); }; -template <> struct handle_type_name { static constexpr auto name = _("int"); }; -template <> struct handle_type_name { static constexpr auto name = _("Iterable"); }; -template <> struct handle_type_name { static constexpr auto name = _("Iterator"); }; -template <> struct handle_type_name { static constexpr auto name = _("None"); }; -template <> struct handle_type_name { static constexpr auto name = _("*args"); }; -template <> struct handle_type_name { static constexpr auto name = _("**kwargs"); }; +template struct handle_type_name { static constexpr auto name = const_name(); }; +template <> struct handle_type_name { static constexpr auto name = const_name("bool"); }; +template <> struct handle_type_name { static constexpr auto name = const_name(PYBIND11_BYTES_NAME); }; +template <> struct handle_type_name { static constexpr auto name = const_name("int"); }; +template <> struct handle_type_name { static constexpr auto name = const_name("Iterable"); }; +template <> struct handle_type_name { static constexpr auto name = const_name("Iterator"); }; +template <> struct handle_type_name { static constexpr auto name = const_name("float"); }; +template <> struct handle_type_name { static constexpr auto name = const_name("None"); }; +template <> struct handle_type_name { static constexpr auto name = const_name("*args"); }; +template <> struct handle_type_name { static constexpr auto name = const_name("**kwargs"); }; template struct pyobject_caster { @@ -1632,6 +782,17 @@ struct pyobject_caster { template ::value, int> = 0> bool load(handle src, bool /* convert */) { +#if PY_MAJOR_VERSION < 3 && !defined(PYBIND11_STR_LEGACY_PERMISSIVE) + // For Python 2, without this implicit conversion, Python code would + // need to be cluttered with six.ensure_text() or similar, only to be + // un-cluttered later after Python 2 support is dropped. + if (PYBIND11_SILENCE_MSVC_C4127(std::is_same::value) && isinstance(src)) { + PyObject *str_from_bytes = PyUnicode_FromEncodedObject(src.ptr(), "utf-8", nullptr); + if (!str_from_bytes) throw error_already_set(); + value = reinterpret_steal(str_from_bytes); + return true; + } +#endif if (!isinstance(src)) return false; value = reinterpret_borrow(src); @@ -1779,8 +940,7 @@ template detail::enable_if_t::value, T> cast template detail::enable_if_t::value, T> cast(object &&object) { if (object.ref_count() > 1) return cast(object); - else - return move(std::move(object)); + return move(std::move(object)); } template detail::enable_if_t::value, T> cast(object &&object) { return cast(object); @@ -1820,6 +980,21 @@ template <> inline void cast_safe(object &&) {} PYBIND11_NAMESPACE_END(detail) +// The overloads could coexist, i.e. the #if is not strictly speaking needed, +// but it is an easy minor optimization. +#if defined(NDEBUG) +inline cast_error cast_error_unable_to_convert_call_arg() { + return cast_error( + "Unable to convert call argument to Python object (compile in debug mode for details)"); +} +#else +inline cast_error cast_error_unable_to_convert_call_arg(const std::string &name, + const std::string &type) { + return cast_error("Unable to convert call argument '" + name + "' of type '" + type + + "' to Python object"); +} +#endif + template tuple make_tuple() { return tuple(0); } @@ -1833,11 +1008,10 @@ template argtypes { {type_id()...} }; - throw cast_error("make_tuple(): unable to convert argument of type '" + - argtypes[i] + "' to Python object"); + throw cast_error_unable_to_convert_call_arg(std::to_string(i), argtypes[i]); #endif } } @@ -1879,7 +1053,14 @@ private: #if !defined(NDEBUG) , type(type_id()) #endif - { } + { + // Workaround! See: + // https://github.com/pybind/pybind11/issues/2336 + // https://github.com/pybind/pybind11/pull/2685#issuecomment-731286700 + if (PyErr_Occurred()) { + PyErr_Clear(); + } + } public: /// Direct construction with name, default, and description @@ -1919,7 +1100,9 @@ struct kw_only {}; struct pos_only {}; template -arg_v arg::operator=(T &&value) const { return {std::move(*this), std::forward(value)}; } +arg_v arg::operator=(T &&value) const { + return {*this, std::forward(value)}; +} /// Alias for backward compatibility -- to be removed in version 2.0 template using arg_t = arg_v; @@ -1933,6 +1116,9 @@ constexpr arg operator"" _a(const char *name, size_t) { return arg(name); } PYBIND11_NAMESPACE_BEGIN(detail) +template using is_kw_only = std::is_same, kw_only>; +template using is_pos_only = std::is_same, pos_only>; + // forward declaration (definition in attr.h) struct function_record; @@ -1968,17 +1154,18 @@ class argument_loader { template using argument_is_args = std::is_same, args>; template using argument_is_kwargs = std::is_same, kwargs>; - // Get args/kwargs argument positions relative to the end of the argument list: - static constexpr auto args_pos = constexpr_first() - (int) sizeof...(Args), - kwargs_pos = constexpr_first() - (int) sizeof...(Args); + // Get kwargs argument position, or -1 if not present: + static constexpr auto kwargs_pos = constexpr_last(); - static constexpr bool args_kwargs_are_last = kwargs_pos >= - 1 && args_pos >= kwargs_pos - 1; - - static_assert(args_kwargs_are_last, "py::args/py::kwargs are only permitted as the last argument(s) of a function"); + static_assert(kwargs_pos == -1 || kwargs_pos == (int) sizeof...(Args) - 1, "py::kwargs is only permitted as the last argument of a function"); public: - static constexpr bool has_kwargs = kwargs_pos < 0; - static constexpr bool has_args = args_pos < 0; + static constexpr bool has_kwargs = kwargs_pos != -1; + + // py::args argument position; -1 if not present. + static constexpr int args_pos = constexpr_last(); + + static_assert(args_pos == -1 || args_pos == constexpr_first(), "py::args cannot be specified more than once"); static constexpr auto arg_names = concat(type_descr(make_caster::name)...); @@ -1987,13 +1174,14 @@ public: } template + // NOLINTNEXTLINE(readability-const-return-type) enable_if_t::value, Return> call(Func &&f) && { - return std::move(*this).template call_impl(std::forward(f), indices{}, Guard{}); + return std::move(*this).template call_impl>(std::forward(f), indices{}, Guard{}); } template enable_if_t::value, void_type> call(Func &&f) && { - std::move(*this).template call_impl(std::forward(f), indices{}, Guard{}); + std::move(*this).template call_impl>(std::forward(f), indices{}, Guard{}); return void_type(); } @@ -2057,8 +1245,8 @@ public: // Tuples aren't (easily) resizable so a list is needed for collection, // but the actual function call strictly requires a tuple. auto args_list = list(); - int _[] = { 0, (process(args_list, std::forward(values)), 0)... }; - ignore_unused(_); + using expander = int[]; + (void) expander{0, (process(args_list, std::forward(values)), 0)...}; m_args = std::move(args_list); } @@ -2083,16 +1271,17 @@ private: auto o = reinterpret_steal(detail::make_caster::cast(std::forward(x), policy, {})); if (!o) { #if defined(NDEBUG) - argument_cast_error(); + throw cast_error_unable_to_convert_call_arg(); #else - argument_cast_error(std::to_string(args_list.size()), type_id()); + throw cast_error_unable_to_convert_call_arg( + std::to_string(args_list.size()), type_id()); #endif } args_list.append(o); } void process(list &args_list, detail::args_proxy ap) { - for (const auto &a : ap) + for (auto a : ap) args_list.append(a); } @@ -2113,9 +1302,9 @@ private: } if (!a.value) { #if defined(NDEBUG) - argument_cast_error(); + throw cast_error_unable_to_convert_call_arg(); #else - argument_cast_error(a.name, a.type); + throw cast_error_unable_to_convert_call_arg(a.name, a.type); #endif } m_kwargs[a.name] = a.value; @@ -2124,7 +1313,7 @@ private: void process(list &/*args_list*/, detail::kwargs_proxy kp) { if (!kp) return; - for (const auto &k : reinterpret_borrow(kp)) { + for (auto k : reinterpret_borrow(kp)) { if (m_kwargs.contains(k.first)) { #if defined(NDEBUG) multiple_values_error(); @@ -2141,7 +1330,7 @@ private: "may be passed via py::arg() to a python function call. " "(compile in debug mode for details)"); } - [[noreturn]] static void nameless_argument_error(std::string type) { + [[noreturn]] static void nameless_argument_error(const std::string &type) { throw type_error("Got kwargs without a name of type '" + type + "'; only named " "arguments may be passed via py::arg() to a python function call. "); } @@ -2150,35 +1339,35 @@ private: "(compile in debug mode for details)"); } - [[noreturn]] static void multiple_values_error(std::string name) { + [[noreturn]] static void multiple_values_error(const std::string &name) { throw type_error("Got multiple values for keyword argument '" + name + "'"); } - [[noreturn]] static void argument_cast_error() { - throw cast_error("Unable to convert call argument to Python object " - "(compile in debug mode for details)"); - } - - [[noreturn]] static void argument_cast_error(std::string name, std::string type) { - throw cast_error("Unable to convert call argument '" + name - + "' of type '" + type + "' to Python object"); - } - private: tuple m_args; dict m_kwargs; }; +// [workaround(intel)] Separate function required here +// We need to put this into a separate function because the Intel compiler +// fails to compile enable_if_t...>::value> +// (tested with ICC 2021.1 Beta 20200827). +template +constexpr bool args_are_all_positional() +{ + return all_of...>::value; +} + /// Collect only positional arguments for a Python function call template ...>::value>> + typename = enable_if_t()>> simple_collector collect_arguments(Args &&...args) { return simple_collector(std::forward(args)...); } /// Collect all arguments, including keywords and unpacking (only instantiated when needed) template ...>::value>> + typename = enable_if_t()>> unpacking_collector collect_arguments(Args &&...args) { // Following argument order rules for generalized unpacking according to PEP 448 static_assert( @@ -2193,6 +1382,11 @@ unpacking_collector collect_arguments(Args &&...args) { template template object object_api::operator()(Args &&...args) const { +#if !defined(NDEBUG) && PY_VERSION_HEX >= 0x03060000 + if (!PyGILState_Check()) { + pybind11_fail("pybind11::object_api<>::operator() PyGILState_Check() failure."); + } +#endif return detail::collect_arguments(std::forward(args)...).call(derived().ptr()); } diff --git a/wrap/pybind11/include/pybind11/chrono.h b/wrap/pybind11/include/pybind11/chrono.h index cbe9acec3..460a28fa5 100644 --- a/wrap/pybind11/include/pybind11/chrono.h +++ b/wrap/pybind11/include/pybind11/chrono.h @@ -11,9 +11,12 @@ #pragma once #include "pybind11.h" + +#include #include #include -#include +#include + #include // Backport the PyDateTime_DELTA functions from Python3.3 if required @@ -32,10 +35,10 @@ PYBIND11_NAMESPACE_BEGIN(detail) template class duration_caster { public: - typedef typename type::rep rep; + using rep = typename type::rep; using period = typename type::period; - using days = std::chrono::duration>; + using days = std::chrono::duration>; // signed 25 bits required by the standard. bool load(handle src, bool) { using namespace std::chrono; @@ -53,11 +56,11 @@ public: return true; } // If invoked with a float we assume it is seconds and convert - else if (PyFloat_Check(src.ptr())) { + if (PyFloat_Check(src.ptr())) { value = type(duration_cast>(duration(PyFloat_AsDouble(src.ptr())))); return true; } - else return false; + return false; } // If this is a duration just return it back @@ -92,9 +95,25 @@ public: return PyDelta_FromDSU(dd.count(), ss.count(), us.count()); } - PYBIND11_TYPE_CASTER(type, _("datetime.timedelta")); + PYBIND11_TYPE_CASTER(type, const_name("datetime.timedelta")); }; +inline std::tm *localtime_thread_safe(const std::time_t *time, std::tm *buf) { +#if (defined(__STDC_LIB_EXT1__) && defined(__STDC_WANT_LIB_EXT1__)) || defined(_MSC_VER) + if (localtime_s(buf, time)) + return nullptr; + return buf; +#else + static std::mutex mtx; + std::lock_guard lock(mtx); + std::tm *tm_ptr = std::localtime(time); + if (tm_ptr != nullptr) { + *buf = *tm_ptr; + } + return tm_ptr; +#endif +} + // This is for casting times on the system clock into datetime.datetime instances template class type_caster> { public: @@ -161,10 +180,11 @@ public: // > If std::time_t has lower precision, it is implementation-defined whether the value is rounded or truncated. // (https://en.cppreference.com/w/cpp/chrono/system_clock/to_time_t) std::time_t tt = system_clock::to_time_t(time_point_cast(src - us)); - // this function uses static memory so it's best to copy it out asap just in case - // otherwise other code that is using localtime may break this (not just python code) - std::tm localtime = *std::localtime(&tt); + std::tm localtime; + std::tm *localtime_ptr = localtime_thread_safe(&tt, &localtime); + if (!localtime_ptr) + throw cast_error("Unable to represent system_clock in local time"); return PyDateTime_FromDateAndTime(localtime.tm_year + 1900, localtime.tm_mon + 1, localtime.tm_mday, @@ -173,7 +193,7 @@ public: localtime.tm_sec, us.count()); } - PYBIND11_TYPE_CASTER(type, _("datetime.datetime")); + PYBIND11_TYPE_CASTER(type, const_name("datetime.datetime")); }; // Other clocks that are not the system clock are not measured as datetime.datetime objects diff --git a/wrap/pybind11/include/pybind11/complex.h b/wrap/pybind11/include/pybind11/complex.h index f8327eb37..e1ecf4358 100644 --- a/wrap/pybind11/include/pybind11/complex.h +++ b/wrap/pybind11/include/pybind11/complex.h @@ -59,7 +59,7 @@ public: return PyComplex_FromDoubles((double) src.real(), (double) src.imag()); } - PYBIND11_TYPE_CASTER(std::complex, _("complex")); + PYBIND11_TYPE_CASTER(std::complex, const_name("complex")); }; PYBIND11_NAMESPACE_END(detail) PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/wrap/pybind11/include/pybind11/detail/class.h b/wrap/pybind11/include/pybind11/detail/class.h index b4a11c0a0..cc1e40ce7 100644 --- a/wrap/pybind11/include/pybind11/detail/class.h +++ b/wrap/pybind11/include/pybind11/detail/class.h @@ -24,6 +24,18 @@ PYBIND11_NAMESPACE_BEGIN(detail) # define PYBIND11_SET_OLDPY_QUALNAME(obj, nameobj) setattr((PyObject *) obj, "__qualname__", nameobj) #endif +inline std::string get_fully_qualified_tp_name(PyTypeObject *type) { +#if !defined(PYPY_VERSION) + return type->tp_name; +#else + auto module_name = handle((PyObject *) type).attr("__module__").cast(); + if (module_name == PYBIND11_BUILTINS_MODULE) + return type->tp_name; + else + return std::move(module_name) + "." + type->tp_name; +#endif +} + inline PyTypeObject *type_incref(PyTypeObject *type) { Py_INCREF(type); return type; @@ -117,8 +129,9 @@ extern "C" inline int pybind11_meta_setattro(PyObject* obj, PyObject* name, PyOb // 2. `Type.static_prop = other_static_prop` --> setattro: replace existing `static_prop` // 3. `Type.regular_attribute = value` --> setattro: regular attribute assignment const auto static_prop = (PyObject *) get_internals().static_property_type; - const auto call_descr_set = descr && PyObject_IsInstance(descr, static_prop) - && !PyObject_IsInstance(value, static_prop); + const auto call_descr_set = (descr != nullptr) && (value != nullptr) + && (PyObject_IsInstance(descr, static_prop) != 0) + && (PyObject_IsInstance(value, static_prop) == 0); if (call_descr_set) { // Call `static_property.__set__()` instead of replacing the `static_property`. #if !defined(PYPY_VERSION) @@ -150,9 +163,7 @@ extern "C" inline PyObject *pybind11_meta_getattro(PyObject *obj, PyObject *name Py_INCREF(descr); return descr; } - else { - return PyType_Type.tp_getattro(obj, name); - } + return PyType_Type.tp_getattro(obj, name); } #endif @@ -172,7 +183,7 @@ extern "C" inline PyObject *pybind11_meta_call(PyObject *type, PyObject *args, P for (const auto &vh : values_and_holders(instance)) { if (!vh.holder_constructed()) { PyErr_Format(PyExc_TypeError, "%.200s.__init__() must be called when overriding __init__", - vh.type->type->tp_name); + get_fully_qualified_tp_name(vh.type->type).c_str()); Py_DECREF(self); return nullptr; } @@ -181,6 +192,44 @@ extern "C" inline PyObject *pybind11_meta_call(PyObject *type, PyObject *args, P return self; } +/// Cleanup the type-info for a pybind11-registered type. +extern "C" inline void pybind11_meta_dealloc(PyObject *obj) { + auto *type = (PyTypeObject *) obj; + auto &internals = get_internals(); + + // A pybind11-registered type will: + // 1) be found in internals.registered_types_py + // 2) have exactly one associated `detail::type_info` + auto found_type = internals.registered_types_py.find(type); + if (found_type != internals.registered_types_py.end() && + found_type->second.size() == 1 && + found_type->second[0]->type == type) { + + auto *tinfo = found_type->second[0]; + auto tindex = std::type_index(*tinfo->cpptype); + internals.direct_conversions.erase(tindex); + + if (tinfo->module_local) + get_local_internals().registered_types_cpp.erase(tindex); + else + internals.registered_types_cpp.erase(tindex); + internals.registered_types_py.erase(tinfo->type); + + // Actually just `std::erase_if`, but that's only available in C++20 + auto &cache = internals.inactive_override_cache; + for (auto it = cache.begin(), last = cache.end(); it != last; ) { + if (it->first == (PyObject *) tinfo->type) + it = cache.erase(it); + else + ++it; + } + + delete tinfo; + } + + PyType_Type.tp_dealloc(obj); +} + /** This metaclass is assigned by default to all pybind11 types and is required in order for static properties to function correctly. Users may override this using `py::metaclass`. Return value: New reference. */ @@ -213,6 +262,8 @@ inline PyTypeObject* make_default_metaclass() { type->tp_getattro = pybind11_meta_getattro; #endif + type->tp_dealloc = pybind11_meta_dealloc; + if (PyType_Ready(type) < 0) pybind11_fail("make_default_metaclass(): failure in PyType_Ready()!"); @@ -250,7 +301,7 @@ inline bool deregister_instance_impl(void *ptr, instance *self) { auto ®istered_instances = get_internals().registered_instances; auto range = registered_instances.equal_range(ptr); for (auto it = range.first; it != range.second; ++it) { - if (Py_TYPE(self) == Py_TYPE(it->second)) { + if (self == it->second) { registered_instances.erase(it); return true; } @@ -277,7 +328,7 @@ inline bool deregister_instance(instance *self, void *valptr, const type_info *t inline PyObject *make_new_instance(PyTypeObject *type) { #if defined(PYPY_VERSION) // PyPy gets tp_basicsize wrong (issue 2482) under multiple inheritance when the first inherited - // object is a a plain Python type (i.e. not derived from an extension type). Fix it. + // object is a plain Python type (i.e. not derived from an extension type). Fix it. ssize_t instance_size = static_cast(sizeof(instance)); if (type->tp_basicsize < instance_size) { type->tp_basicsize = instance_size; @@ -288,8 +339,6 @@ inline PyObject *make_new_instance(PyTypeObject *type) { // Allocate the value/holder internals: inst->allocate_layout(); - inst->owned = true; - return self; } @@ -304,12 +353,7 @@ extern "C" inline PyObject *pybind11_object_new(PyTypeObject *type, PyObject *, /// following default function will be used which simply throws an exception. extern "C" inline int pybind11_object_init(PyObject *self, PyObject *, PyObject *) { PyTypeObject *type = Py_TYPE(self); - std::string msg; -#if defined(PYPY_VERSION) - msg += handle((PyObject *) type).attr("__module__").cast() + "."; -#endif - msg += type->tp_name; - msg += ": No constructor defined!"; + std::string msg = get_fully_qualified_tp_name(type) + ": No constructor defined!"; PyErr_SetString(PyExc_TypeError, msg.c_str()); return -1; } @@ -448,7 +492,7 @@ extern "C" inline PyObject *pybind11_get_dict(PyObject *self, void *) { extern "C" inline int pybind11_set_dict(PyObject *self, PyObject *new_dict, void *) { if (!PyDict_Check(new_dict)) { PyErr_Format(PyExc_TypeError, "__dict__ must be set to a dictionary, not a '%.200s'", - Py_TYPE(new_dict)->tp_name); + get_fully_qualified_tp_name(Py_TYPE(new_dict)).c_str()); return -1; } PyObject *&dict = *_PyObject_GetDictPtr(self); @@ -475,11 +519,6 @@ extern "C" inline int pybind11_clear(PyObject *self) { /// Give instances of this type a `__dict__` and opt into garbage collection. inline void enable_dynamic_attributes(PyHeapTypeObject *heap_type) { auto type = &heap_type->ht_type; -#if defined(PYPY_VERSION) && (PYPY_VERSION_NUM < 0x06000000) - pybind11_fail(std::string(type->tp_name) + ": dynamic attributes are " - "currently not supported in " - "conjunction with PyPy!"); -#endif type->tp_flags |= Py_TPFLAGS_HAVE_GC; type->tp_dictoffset = type->tp_basicsize; // place dict at the end type->tp_basicsize += (ssize_t)sizeof(PyObject *); // and allocate enough space for it @@ -510,6 +549,12 @@ extern "C" inline int pybind11_getbuffer(PyObject *obj, Py_buffer *view, int fla } std::memset(view, 0, sizeof(Py_buffer)); buffer_info *info = tinfo->get_buffer(obj, tinfo->get_buffer_data); + if ((flags & PyBUF_WRITABLE) == PyBUF_WRITABLE && info->readonly) { + delete info; + // view->obj = nullptr; // Was just memset to 0, so not necessary + PyErr_SetString(PyExc_BufferError, "Writable buffer requested for readonly storage"); + return -1; + } view->obj = obj; view->ndim = 1; view->internal = info; @@ -518,13 +563,7 @@ extern "C" inline int pybind11_getbuffer(PyObject *obj, Py_buffer *view, int fla view->len = view->itemsize; for (auto s : info->shape) view->len *= s; - view->readonly = info->readonly; - if ((flags & PyBUF_WRITABLE) == PyBUF_WRITABLE && info->readonly) { - if (view) - view->obj = nullptr; - PyErr_SetString(PyExc_BufferError, "Writable buffer requested for readonly storage"); - return -1; - } + view->readonly = static_cast(info->readonly); if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) view->format = const_cast(info->format.c_str()); if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) { @@ -567,17 +606,17 @@ inline PyObject* make_new_python_type(const type_record &rec) { #endif } - object module; + object module_; if (rec.scope) { if (hasattr(rec.scope, "__module__")) - module = rec.scope.attr("__module__"); + module_ = rec.scope.attr("__module__"); else if (hasattr(rec.scope, "__name__")) - module = rec.scope.attr("__name__"); + module_ = rec.scope.attr("__name__"); } auto full_name = c_str( #if !defined(PYPY_VERSION) - module ? str(module).cast() + "." + rec.name : + module_ ? str(module_).cast() + "." + rec.name : #endif rec.name); @@ -585,9 +624,9 @@ inline PyObject* make_new_python_type(const type_record &rec) { if (rec.doc && options::show_user_defined_docstrings()) { /* Allocate memory for docstring (using PyObject_MALLOC, since Python will free this later on) */ - size_t size = strlen(rec.doc) + 1; + size_t size = std::strlen(rec.doc) + 1; tp_doc = (char *) PyObject_MALLOC(size); - memcpy((void *) tp_doc, rec.doc, size); + std::memcpy((void *) tp_doc, rec.doc, size); } auto &internals = get_internals(); @@ -644,11 +683,13 @@ inline PyObject* make_new_python_type(const type_record &rec) { if (rec.buffer_protocol) enable_buffer_protocol(heap_type); + if (rec.custom_type_setup_callback) + rec.custom_type_setup_callback(heap_type); + if (PyType_Ready(type) < 0) pybind11_fail(std::string(rec.name) + ": PyType_Ready failed (" + error_string() + ")!"); - assert(rec.dynamic_attr ? PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC) - : !PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC)); + assert(!rec.dynamic_attr || PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC)); /* Register type with the parent scope */ if (rec.scope) @@ -656,8 +697,8 @@ inline PyObject* make_new_python_type(const type_record &rec) { else Py_INCREF(type); // Keep it alive forever (reference leak) - if (module) // Needed by pydoc - setattr((PyObject *) type, "__module__", module); + if (module_) // Needed by pydoc + setattr((PyObject *) type, "__module__", module_); PYBIND11_SET_OLDPY_QUALNAME(type, qualname); diff --git a/wrap/pybind11/include/pybind11/detail/common.h b/wrap/pybind11/include/pybind11/detail/common.h index 1f8390fba..5c59b4141 100644 --- a/wrap/pybind11/include/pybind11/detail/common.h +++ b/wrap/pybind11/include/pybind11/detail/common.h @@ -10,8 +10,12 @@ #pragma once #define PYBIND11_VERSION_MAJOR 2 -#define PYBIND11_VERSION_MINOR 6 -#define PYBIND11_VERSION_PATCH 0.dev1 +#define PYBIND11_VERSION_MINOR 9 +#define PYBIND11_VERSION_PATCH 1 + +// Similar to Python's convention: https://docs.python.org/3/c-api/apiabiversion.html +// Additional convention: 0xD = dev +#define PYBIND11_VERSION_HEX 0x02090100 #define PYBIND11_NAMESPACE_BEGIN(name) namespace name { #define PYBIND11_NAMESPACE_END(name) } @@ -27,11 +31,14 @@ # endif #endif -#if !(defined(_MSC_VER) && __cplusplus == 199711L) && !defined(__INTEL_COMPILER) +#if !(defined(_MSC_VER) && __cplusplus == 199711L) # if __cplusplus >= 201402L # define PYBIND11_CPP14 # if __cplusplus >= 201703L # define PYBIND11_CPP17 +# if __cplusplus >= 202002L +# define PYBIND11_CPP20 +# endif # endif # endif #elif defined(_MSC_VER) && __cplusplus == 199711L @@ -41,15 +48,23 @@ # define PYBIND11_CPP14 # if _MSVC_LANG > 201402L && _MSC_VER >= 1910 # define PYBIND11_CPP17 +# if _MSVC_LANG >= 202002L +# define PYBIND11_CPP20 +# endif # endif # endif #endif // Compiler version assertions #if defined(__INTEL_COMPILER) -# if __INTEL_COMPILER < 1700 -# error pybind11 requires Intel C++ compiler v17 or newer +# if __INTEL_COMPILER < 1800 +# error pybind11 requires Intel C++ compiler v18 or newer +# elif __INTEL_COMPILER < 1900 && defined(PYBIND11_CPP14) +# error pybind11 supports only C++11 with Intel C++ compiler v18. Use v19 or newer for C++14. # endif +/* The following pragma cannot be pop'ed: + https://community.intel.com/t5/Intel-C-Compiler/Inline-and-no-inline-warning/td-p/1216764 */ +# pragma warning disable 2196 // warning #2196: routine is both "inline" and "noinline" #elif defined(__clang__) && !defined(__apple_build_version__) # if __clang_major__ < 3 || (__clang_major__ == 3 && __clang_minor__ < 3) # error pybind11 requires clang 3.3 or newer @@ -80,13 +95,43 @@ # endif #endif -#if defined(_MSC_VER) -# define PYBIND11_NOINLINE __declspec(noinline) -#else -# define PYBIND11_NOINLINE __attribute__ ((noinline)) +#if !defined(PYBIND11_EXPORT_EXCEPTION) +# ifdef __MINGW32__ +// workaround for: +// error: 'dllexport' implies default visibility, but xxx has already been declared with a different visibility +# define PYBIND11_EXPORT_EXCEPTION +# else +# define PYBIND11_EXPORT_EXCEPTION PYBIND11_EXPORT +# endif #endif -#if defined(PYBIND11_CPP14) +// For CUDA, GCC7, GCC8: +// PYBIND11_NOINLINE_FORCED is incompatible with `-Wattributes -Werror`. +// When defining PYBIND11_NOINLINE_FORCED, it is best to also use `-Wno-attributes`. +// However, the measured shared-library size saving when using noinline are only +// 1.7% for CUDA, -0.2% for GCC7, and 0.0% for GCC8 (using -DCMAKE_BUILD_TYPE=MinSizeRel, +// the default under pybind11/tests). +#if !defined(PYBIND11_NOINLINE_FORCED) && \ + (defined(__CUDACC__) || (defined(__GNUC__) && (__GNUC__ == 7 || __GNUC__ == 8))) +# define PYBIND11_NOINLINE_DISABLED +#endif + +// The PYBIND11_NOINLINE macro is for function DEFINITIONS. +// In contrast, FORWARD DECLARATIONS should never use this macro: +// https://stackoverflow.com/questions/9317473/forward-declaration-of-inline-functions +#if defined(PYBIND11_NOINLINE_DISABLED) // Option for maximum portability and experimentation. +# define PYBIND11_NOINLINE inline +#elif defined(_MSC_VER) +# define PYBIND11_NOINLINE __declspec(noinline) inline +#else +# define PYBIND11_NOINLINE __attribute__ ((noinline)) inline +#endif + +#if defined(__MINGW32__) +// For unknown reasons all PYBIND11_DEPRECATED member trigger a warning when declared +// whether it is used or not +# define PYBIND11_DEPRECATED(reason) +#elif defined(PYBIND11_CPP14) # define PYBIND11_DEPRECATED(reason) [[deprecated(reason)]] #else # define PYBIND11_DEPRECATED(reason) __attribute__((deprecated(reason))) @@ -112,13 +157,61 @@ # define HAVE_ROUND 1 # endif # pragma warning(push) -# pragma warning(disable: 4510 4610 4512 4005) +// C4505: 'PySlice_GetIndicesEx': unreferenced local function has been removed (PyPy only) +# pragma warning(disable: 4505) # if defined(_DEBUG) && !defined(Py_DEBUG) +// Workaround for a VS 2022 issue. +// NOTE: This workaround knowingly violates the Python.h include order requirement: +// https://docs.python.org/3/c-api/intro.html#include-files +// See https://github.com/pybind/pybind11/pull/3497 for full context. +# include +# if _MSVC_STL_VERSION >= 143 +# include +# endif # define PYBIND11_DEBUG_MARKER # undef _DEBUG # endif #endif +// https://en.cppreference.com/w/c/chrono/localtime +#if defined(__STDC_LIB_EXT1__) && !defined(__STDC_WANT_LIB_EXT1__) +# define __STDC_WANT_LIB_EXT1__ +#endif + +#ifdef __has_include +// std::optional (but including it in c++14 mode isn't allowed) +# if defined(PYBIND11_CPP17) && __has_include() +# define PYBIND11_HAS_OPTIONAL 1 +# endif +// std::experimental::optional (but not allowed in c++11 mode) +# if defined(PYBIND11_CPP14) && (__has_include() && \ + !__has_include()) +# define PYBIND11_HAS_EXP_OPTIONAL 1 +# endif +// std::variant +# if defined(PYBIND11_CPP17) && __has_include() +# define PYBIND11_HAS_VARIANT 1 +# endif +#elif defined(_MSC_VER) && defined(PYBIND11_CPP17) +# define PYBIND11_HAS_OPTIONAL 1 +# define PYBIND11_HAS_VARIANT 1 +#endif + +#if defined(PYBIND11_CPP17) +# if defined(__has_include) +# if __has_include() +# define PYBIND11_HAS_STRING_VIEW +# endif +# elif defined(_MSC_VER) +# define PYBIND11_HAS_STRING_VIEW +# endif +#endif + +#if defined(__cpp_lib_char8_t) && __cpp_lib_char8_t >= 201811L +# define PYBIND11_HAS_U8STRING +#endif + + #include #include #include @@ -160,6 +253,24 @@ #include #include #include +#if defined(__has_include) +# if __has_include() +# include +# endif +#endif + +// #define PYBIND11_STR_LEGACY_PERMISSIVE +// If DEFINED, pybind11::str can hold PyUnicodeObject or PyBytesObject +// (probably surprising and never documented, but this was the +// legacy behavior until and including v2.6.x). As a side-effect, +// pybind11::isinstance() is true for both pybind11::str and +// pybind11::bytes. +// If UNDEFINED, pybind11::str can only hold PyUnicodeObject, and +// pybind11::isinstance() is true only for pybind11::str. +// However, for Python 2 only (!), the pybind11::str caster +// implicitly decodes bytes to PyUnicodeObject. This is to ease +// the transition from the legacy behavior to the non-permissive +// behavior. #if PY_MAJOR_VERSION >= 3 /// Compatibility macros for various Python versions #define PYBIND11_INSTANCE_METHOD_NEW(ptr, class_) PyInstanceMethod_New(ptr) @@ -173,8 +284,8 @@ #define PYBIND11_BYTES_SIZE PyBytes_Size #define PYBIND11_LONG_CHECK(o) PyLong_Check(o) #define PYBIND11_LONG_AS_LONGLONG(o) PyLong_AsLongLong(o) -#define PYBIND11_LONG_FROM_SIGNED(o) PyLong_FromSsize_t((ssize_t) o) -#define PYBIND11_LONG_FROM_UNSIGNED(o) PyLong_FromSize_t((size_t) o) +#define PYBIND11_LONG_FROM_SIGNED(o) PyLong_FromSsize_t((ssize_t) (o)) +#define PYBIND11_LONG_FROM_UNSIGNED(o) PyLong_FromSize_t((size_t) (o)) #define PYBIND11_BYTES_NAME "bytes" #define PYBIND11_STRING_NAME "str" #define PYBIND11_SLICE_OBJECT PyObject @@ -182,6 +293,7 @@ #define PYBIND11_STR_TYPE ::pybind11::str #define PYBIND11_BOOL_ATTR "__bool__" #define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_bool) +#define PYBIND11_BUILTINS_MODULE "builtins" // Providing a separate declaration to make Clang's -Wmissing-prototypes happy. // See comment for PYBIND11_MODULE below for why this is marked "maybe unused". #define PYBIND11_PLUGIN_IMPL(name) \ @@ -209,6 +321,7 @@ #define PYBIND11_STR_TYPE ::pybind11::bytes #define PYBIND11_BOOL_ATTR "__nonzero__" #define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_nonzero) +#define PYBIND11_BUILTINS_MODULE "__builtin__" // Providing a separate PyInit decl to make Clang's -Wmissing-prototypes happy. // See comment for PYBIND11_MODULE below for why this is marked "maybe unused". #define PYBIND11_PLUGIN_IMPL(name) \ @@ -250,6 +363,19 @@ extern "C" { } \ } +#if PY_VERSION_HEX >= 0x03030000 + +#define PYBIND11_CATCH_INIT_EXCEPTIONS \ + catch (pybind11::error_already_set &e) { \ + pybind11::raise_from(e, PyExc_ImportError, "initialization failed"); \ + return nullptr; \ + } catch (const std::exception &e) { \ + PyErr_SetString(PyExc_ImportError, e.what()); \ + return nullptr; \ + } \ + +#else + #define PYBIND11_CATCH_INIT_EXCEPTIONS \ catch (pybind11::error_already_set &e) { \ PyErr_SetString(PyExc_ImportError, e.what()); \ @@ -259,17 +385,19 @@ extern "C" { return nullptr; \ } \ +#endif + /** \rst ***Deprecated in favor of PYBIND11_MODULE*** This macro creates the entry point that will be invoked when the Python interpreter - imports a plugin library. Please create a `module` in the function body and return + imports a plugin library. Please create a `module_` in the function body and return the pointer to its underlying Python object at the end. .. code-block:: cpp PYBIND11_PLUGIN(example) { - pybind11::module m("example", "pybind11 example plugin"); + pybind11::module_ m("example", "pybind11 example plugin"); /// Set up bindings here return m.ptr(); } @@ -290,7 +418,7 @@ extern "C" { This macro creates the entry point that will be invoked when the Python interpreter imports an extension module. The module name is given as the fist argument and it should not be in quotes. The second macro argument defines a variable of type - `py::module` which can be used to initialize the module. + `py::module_` which can be used to initialize the module. The entry point is marked as "maybe unused" to aid dead-code detection analysis: since the entry point is typically only looked up at runtime and not referenced @@ -307,26 +435,35 @@ extern "C" { }); } \endrst */ -#define PYBIND11_MODULE(name, variable) \ - PYBIND11_MAYBE_UNUSED \ - static void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &); \ - PYBIND11_PLUGIN_IMPL(name) { \ - PYBIND11_CHECK_PYTHON_VERSION \ - PYBIND11_ENSURE_INTERNALS_READY \ - auto m = pybind11::module(PYBIND11_TOSTRING(name)); \ - try { \ - PYBIND11_CONCAT(pybind11_init_, name)(m); \ - return m.ptr(); \ - } PYBIND11_CATCH_INIT_EXCEPTIONS \ - } \ - void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &variable) - +#define PYBIND11_MODULE(name, variable) \ + static ::pybind11::module_::module_def PYBIND11_CONCAT(pybind11_module_def_, name) \ + PYBIND11_MAYBE_UNUSED; \ + PYBIND11_MAYBE_UNUSED \ + static void PYBIND11_CONCAT(pybind11_init_, name)(::pybind11::module_ &); \ + PYBIND11_PLUGIN_IMPL(name) { \ + PYBIND11_CHECK_PYTHON_VERSION \ + PYBIND11_ENSURE_INTERNALS_READY \ + auto m = ::pybind11::module_::create_extension_module( \ + PYBIND11_TOSTRING(name), nullptr, &PYBIND11_CONCAT(pybind11_module_def_, name)); \ + try { \ + PYBIND11_CONCAT(pybind11_init_, name)(m); \ + return m.ptr(); \ + } \ + PYBIND11_CATCH_INIT_EXCEPTIONS \ + } \ + void PYBIND11_CONCAT(pybind11_init_, name)(::pybind11::module_ & (variable)) PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) using ssize_t = Py_ssize_t; using size_t = std::size_t; +template +inline ssize_t ssize_t_cast(const IntType &val) { + static_assert(sizeof(IntType) <= sizeof(ssize_t), "Implicit narrowing is not permitted."); + return static_cast(val); +} + /// Approach used to cast a previously unknown C++ instance into a Python object enum class return_value_policy : uint8_t { /** This is the default return value policy, which falls back to the policy @@ -481,6 +618,18 @@ template using remove_cv_t = typename std::remove_cv::type; template using remove_reference_t = typename std::remove_reference::type; #endif +#if defined(PYBIND11_CPP20) +using std::remove_cvref; +using std::remove_cvref_t; +#else +template +struct remove_cvref { + using type = remove_cv_t>; +}; +template +using remove_cvref_t = typename remove_cvref::type; +#endif + /// Index sequences #if defined(PYBIND11_CPP14) using std::index_sequence; @@ -488,7 +637,7 @@ using std::make_index_sequence; #else template struct index_sequence { }; template struct make_index_sequence_impl : make_index_sequence_impl { }; -template struct make_index_sequence_impl <0, S...> { typedef index_sequence type; }; +template struct make_index_sequence_impl <0, S...> { using type = index_sequence; }; template using make_index_sequence = typename make_index_sequence_impl::type; #endif @@ -502,10 +651,10 @@ template using select_indices = typename select_indices_impl using bool_constant = std::integral_constant; template struct negation : bool_constant { }; -// PGI cannot detect operator delete with the "compatible" void_t impl, so +// PGI/Intel cannot detect operator delete with the "compatible" void_t impl, so // using the new one (C++14 defect, so generally works on newer compilers, even // if not in C++17 mode) -#if defined(__PGIC__) +#if defined(__PGIC__) || defined(__INTEL_COMPILER) template using void_t = void; #else template struct void_t_impl { using type = void; }; @@ -618,8 +767,9 @@ template using is_strict_base_of = bool_consta /// Like is_base_of, but also requires that the base type is accessible (i.e. that a Derived pointer /// can be converted to a Base pointer) +/// For unions, `is_base_of::value` is False, so we need to check `is_same` as well. template using is_accessible_base_of = bool_constant< - std::is_base_of::value && std::is_convertible::value>; + (std::is_same::value || std::is_base_of::value) && std::is_convertible::value>; template class Base> struct is_template_base_of_impl { @@ -656,6 +806,10 @@ template using is_function_pointer = bool_constant< std::is_pointer::value && std::is_function::type>::value>; template struct strip_function_object { + // If you are encountering an + // 'error: name followed by "::" must be a class or namespace name' + // with the Intel compiler and a noexcept function here, + // try to use noexcept(true) instead of plain noexcept. using type = typename remove_class::type; }; @@ -677,11 +831,10 @@ using function_signature_t = conditional_t< template using is_lambda = satisfies_none_of, std::is_function, std::is_pointer, std::is_member_pointer>; -/// Ignore that a variable is unused in compiler warnings -inline void ignore_unused(const int *) { } - +// [workaround(intel)] Internal error on fold expression /// Apply a function over each element of a parameter pack -#ifdef __cpp_fold_expressions +#if defined(__cpp_fold_expressions) && !defined(__INTEL_COMPILER) +// Intel compiler produces an internal error on this fold expression (tested with ICC 19.0.2) #define PYBIND11_EXPAND_SIDE_EFFECTS(PATTERN) (((PATTERN), void()), ...) #else using expand_side_effects = bool[]; @@ -690,16 +843,23 @@ using expand_side_effects = bool[]; PYBIND11_NAMESPACE_END(detail) +#if defined(_MSC_VER) +# pragma warning(push) +# pragma warning(disable: 4275) // warning C4275: An exported class was derived from a class that wasn't exported. Can be ignored when derived from a STL class. +#endif /// C++ bindings of builtin Python exceptions -class builtin_exception : public std::runtime_error { +class PYBIND11_EXPORT_EXCEPTION builtin_exception : public std::runtime_error { public: using std::runtime_error::runtime_error; /// Set the error using the Python C API virtual void set_error() const = 0; }; +#if defined(_MSC_VER) +# pragma warning(pop) +#endif #define PYBIND11_RUNTIME_EXCEPTION(name, type) \ - class name : public builtin_exception { public: \ + class PYBIND11_EXPORT_EXCEPTION name : public builtin_exception { public: \ using builtin_exception::builtin_exception; \ name() : name("") { } \ void set_error() const override { PyErr_SetString(type, what()); } \ @@ -712,11 +872,12 @@ PYBIND11_RUNTIME_EXCEPTION(value_error, PyExc_ValueError) PYBIND11_RUNTIME_EXCEPTION(type_error, PyExc_TypeError) PYBIND11_RUNTIME_EXCEPTION(buffer_error, PyExc_BufferError) PYBIND11_RUNTIME_EXCEPTION(import_error, PyExc_ImportError) +PYBIND11_RUNTIME_EXCEPTION(attribute_error, PyExc_AttributeError) PYBIND11_RUNTIME_EXCEPTION(cast_error, PyExc_RuntimeError) /// Thrown when pybind11::cast or handle::call fail due to a type casting error PYBIND11_RUNTIME_EXCEPTION(reference_cast_error, PyExc_RuntimeError) /// Used internally -[[noreturn]] PYBIND11_NOINLINE inline void pybind11_fail(const char *reason) { throw std::runtime_error(reason); } -[[noreturn]] PYBIND11_NOINLINE inline void pybind11_fail(const std::string &reason) { throw std::runtime_error(reason); } +[[noreturn]] PYBIND11_NOINLINE void pybind11_fail(const char *reason) { throw std::runtime_error(reason); } +[[noreturn]] PYBIND11_NOINLINE void pybind11_fail(const std::string &reason) { throw std::runtime_error(reason); } template struct format_descriptor { }; @@ -761,7 +922,8 @@ struct nodelete { template void operator()(T*) { } }; PYBIND11_NAMESPACE_BEGIN(detail) template struct overload_cast_impl { - constexpr overload_cast_impl() {}; // NOLINT(modernize-use-equals-default): MSVC 2015 needs this + // NOLINTNEXTLINE(modernize-use-equals-default): MSVC 2015 needs this + constexpr overload_cast_impl() {} template constexpr auto operator()(Return (*pf)(Args...)) const noexcept @@ -817,6 +979,7 @@ public: // Implicit conversion constructor from any arbitrary container type with values convertible to T template ())), T>::value>> + // NOLINTNEXTLINE(google-explicit-constructor) any_container(const Container &c) : any_container(std::begin(c), std::end(c)) { } // initializer_list's aren't deducible, so don't get matched by the above template; we need this @@ -825,9 +988,11 @@ public: any_container(const std::initializer_list &c) : any_container(c.begin(), c.end()) { } // Avoid copying if given an rvalue vector of the correct type. + // NOLINTNEXTLINE(google-explicit-constructor) any_container(std::vector &&v) : v(std::move(v)) { } // Moves the vector out of an rvalue any_container + // NOLINTNEXTLINE(google-explicit-constructor) operator std::vector &&() && { return std::move(v); } // Dereferencing obtains a reference to the underlying vector @@ -839,8 +1004,60 @@ public: const std::vector *operator->() const { return &v; } }; +// Forward-declaration; see detail/class.h +std::string get_fully_qualified_tp_name(PyTypeObject*); + +template +inline static std::shared_ptr try_get_shared_from_this(std::enable_shared_from_this *holder_value_ptr) { +// Pre C++17, this code path exploits undefined behavior, but is known to work on many platforms. +// Use at your own risk! +// See also https://en.cppreference.com/w/cpp/memory/enable_shared_from_this, and in particular +// the `std::shared_ptr gp1 = not_so_good.getptr();` and `try`-`catch` parts of the example. +#if defined(__cpp_lib_enable_shared_from_this) && (!defined(_MSC_VER) || _MSC_VER >= 1912) + return holder_value_ptr->weak_from_this().lock(); +#else + try { + return holder_value_ptr->shared_from_this(); + } + catch (const std::bad_weak_ptr &) { + return nullptr; + } +#endif +} + +// For silencing "unused" compiler warnings in special situations. +template +#if defined(_MSC_VER) && _MSC_VER >= 1910 && _MSC_VER < 1920 // MSVC 2017 +constexpr +#endif +inline void silence_unused_warnings(Args &&...) {} + +// MSVC warning C4100: Unreferenced formal parameter +#if defined(_MSC_VER) && _MSC_VER <= 1916 +# define PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(...) \ + detail::silence_unused_warnings(__VA_ARGS__) +#else +# define PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(...) +#endif + +// GCC -Wunused-but-set-parameter All GCC versions (as of July 2021). +#if defined(__GNUG__) && !defined(__clang__) && !defined(__INTEL_COMPILER) +# define PYBIND11_WORKAROUND_INCORRECT_GCC_UNUSED_BUT_SET_PARAMETER(...) \ + detail::silence_unused_warnings(__VA_ARGS__) +#else +# define PYBIND11_WORKAROUND_INCORRECT_GCC_UNUSED_BUT_SET_PARAMETER(...) +#endif + +#if defined(_MSC_VER) // All versions (as of July 2021). + +// warning C4127: Conditional expression is constant +constexpr inline bool silence_msvc_c4127(bool cond) { return cond; } + +# define PYBIND11_SILENCE_MSVC_C4127(...) ::pybind11::detail::silence_msvc_c4127(__VA_ARGS__) + +#else +# define PYBIND11_SILENCE_MSVC_C4127(...) __VA_ARGS__ +#endif + PYBIND11_NAMESPACE_END(detail) - - - PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/wrap/pybind11/include/pybind11/detail/descr.h b/wrap/pybind11/include/pybind11/detail/descr.h index 92720cd56..0f93e06b2 100644 --- a/wrap/pybind11/include/pybind11/detail/descr.h +++ b/wrap/pybind11/include/pybind11/detail/descr.h @@ -23,15 +23,17 @@ PYBIND11_NAMESPACE_BEGIN(detail) /* Concatenate type signatures at compile time */ template struct descr { - char text[N + 1]; + char text[N + 1]{'\0'}; - constexpr descr() : text{'\0'} { } + constexpr descr() = default; + // NOLINTNEXTLINE(google-explicit-constructor) constexpr descr(char const (&s)[N+1]) : descr(s, make_index_sequence()) { } template constexpr descr(char const (&s)[N+1], index_sequence) : text{s[Is]..., '\0'} { } template + // NOLINTNEXTLINE(google-explicit-constructor) constexpr descr(char c, Chars... cs) : text{c, static_cast(cs)..., '\0'} { } static constexpr std::array types() { @@ -42,6 +44,7 @@ struct descr { template constexpr descr plus_impl(const descr &a, const descr &b, index_sequence, index_sequence) { + PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(b); return {a.text[Is1]..., b.text[Is2]...}; } @@ -51,34 +54,64 @@ constexpr descr operator+(const descr &a, c } template -constexpr descr _(char const(&text)[N]) { return descr(text); } -constexpr descr<0> _(char const(&)[1]) { return {}; } +constexpr descr const_name(char const(&text)[N]) { return descr(text); } +constexpr descr<0> const_name(char const(&)[1]) { return {}; } template struct int_to_str : int_to_str { }; template struct int_to_str<0, Digits...> { + // WARNING: This only works with C++17 or higher. static constexpr auto digits = descr(('0' + Digits)...); }; // Ternary description (like std::conditional) template -constexpr enable_if_t> _(char const(&text1)[N1], char const(&)[N2]) { - return _(text1); +constexpr enable_if_t> const_name(char const(&text1)[N1], char const(&)[N2]) { + return const_name(text1); } template -constexpr enable_if_t> _(char const(&)[N1], char const(&text2)[N2]) { - return _(text2); +constexpr enable_if_t> const_name(char const(&)[N1], char const(&text2)[N2]) { + return const_name(text2); } template -constexpr enable_if_t _(const T1 &d, const T2 &) { return d; } +constexpr enable_if_t const_name(const T1 &d, const T2 &) { return d; } template -constexpr enable_if_t _(const T1 &, const T2 &d) { return d; } +constexpr enable_if_t const_name(const T1 &, const T2 &d) { return d; } -template auto constexpr _() -> decltype(int_to_str::digits) { +template +auto constexpr const_name() -> remove_cv_t::digits)> { return int_to_str::digits; } -template constexpr descr<1, Type> _() { return {'%'}; } +template constexpr descr<1, Type> const_name() { return {'%'}; } + +// If "_" is defined as a macro, py::detail::_ cannot be provided. +// It is therefore best to use py::detail::const_name universally. +// This block is for backward compatibility only. +// (The const_name code is repeated to avoid introducing a "_" #define ourselves.) +#ifndef _ +#define PYBIND11_DETAIL_UNDERSCORE_BACKWARD_COMPATIBILITY +template +constexpr descr _(char const(&text)[N]) { return const_name(text); } +template +constexpr enable_if_t> _(char const(&text1)[N1], char const(&text2)[N2]) { + return const_name(text1, text2); +} +template +constexpr enable_if_t> _(char const(&text1)[N1], char const(&text2)[N2]) { + return const_name(text1, text2); +} +template +constexpr enable_if_t _(const T1 &d1, const T2 &d2) { return const_name(d1, d2); } +template +constexpr enable_if_t _(const T1 &d1, const T2 &d2) { return const_name(d1, d2); } + +template +auto constexpr _() -> remove_cv_t::digits)> { + return const_name(); +} +template constexpr descr<1, Type> _() { return const_name(); } +#endif // #ifndef _ constexpr descr<0> concat() { return {}; } @@ -88,12 +121,12 @@ constexpr descr concat(const descr &descr) { return descr; } template constexpr auto concat(const descr &d, const Args &...args) -> decltype(std::declval>() + concat(args...)) { - return d + _(", ") + concat(args...); + return d + const_name(", ") + concat(args...); } template constexpr descr type_descr(const descr &descr) { - return _("{") + descr + _("}"); + return const_name("{") + descr + const_name("}"); } PYBIND11_NAMESPACE_END(detail) diff --git a/wrap/pybind11/include/pybind11/detail/init.h b/wrap/pybind11/include/pybind11/detail/init.h index 3ef78c117..eaaad5a07 100644 --- a/wrap/pybind11/include/pybind11/detail/init.h +++ b/wrap/pybind11/include/pybind11/detail/init.h @@ -23,8 +23,8 @@ public: } template using cast_op_type = value_and_holder &; - operator value_and_holder &() { return *value; } - static constexpr auto name = _(); + explicit operator value_and_holder &() { return *value; } + static constexpr auto name = const_name(); private: value_and_holder *value = nullptr; @@ -94,8 +94,9 @@ void construct(...) { // construct an Alias from the returned base instance. template void construct(value_and_holder &v_h, Cpp *ptr, bool need_alias) { + PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(need_alias); no_nullptr(ptr); - if (Class::has_alias && need_alias && !is_alias(ptr)) { + if (PYBIND11_SILENCE_MSVC_C4127(Class::has_alias) && need_alias && !is_alias(ptr)) { // We're going to try to construct an alias by moving the cpp type. Whether or not // that succeeds, we still need to destroy the original cpp pointer (either the // moved away leftover, if the alias construction works, or the value itself if we @@ -131,10 +132,11 @@ void construct(value_and_holder &v_h, Alias *alias_ptr, bool) { // derived type (through those holder's implicit conversion from derived class holder constructors). template void construct(value_and_holder &v_h, Holder holder, bool need_alias) { + PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(need_alias); auto *ptr = holder_helper>::get(holder); no_nullptr(ptr); // If we need an alias, check that the held pointer is actually an alias instance - if (Class::has_alias && need_alias && !is_alias(ptr)) + if (PYBIND11_SILENCE_MSVC_C4127(Class::has_alias) && need_alias && !is_alias(ptr)) throw type_error("pybind11::init(): construction failed: returned holder-wrapped instance " "is not an alias instance"); @@ -148,9 +150,10 @@ void construct(value_and_holder &v_h, Holder holder, bool need_alias) { // need it, we simply move-construct the cpp value into a new instance. template void construct(value_and_holder &v_h, Cpp &&result, bool need_alias) { + PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(need_alias); static_assert(std::is_move_constructible>::value, "pybind11::init() return-by-value factory function requires a movable class"); - if (Class::has_alias && need_alias) + if (PYBIND11_SILENCE_MSVC_C4127(Class::has_alias) && need_alias) construct_alias_from_cpp(is_alias_constructible{}, v_h, std::move(result)); else v_h.value_ptr() = new Cpp(std::move(result)); @@ -219,7 +222,8 @@ template struct factory { remove_reference_t class_factory; - factory(Func &&f) : class_factory(std::forward(f)) { } + // NOLINTNEXTLINE(google-explicit-constructor) + factory(Func &&f) : class_factory(std::forward(f)) {} // The given class either has no alias or has no separate alias factory; // this always constructs the class itself. If the class is registered with an alias @@ -293,7 +297,13 @@ template ::value, int> = 0> void setstate(value_and_holder &v_h, std::pair &&result, bool need_alias) { construct(v_h, std::move(result.first), need_alias); - setattr((PyObject *) v_h.inst, "__dict__", result.second); + auto d = handle(result.second); + if (PyDict_Check(d.ptr()) && PyDict_Size(d.ptr()) == 0) { + // Skipping setattr below, to not force use of py::dynamic_attr() for Class unnecessarily. + // See PR #2972 for details. + return; + } + setattr((PyObject *) v_h.inst, "__dict__", d); } /// Implementation for py::pickle(GetState, SetState) diff --git a/wrap/pybind11/include/pybind11/detail/internals.h b/wrap/pybind11/include/pybind11/detail/internals.h index 133d2f4c8..9edb9492e 100644 --- a/wrap/pybind11/include/pybind11/detail/internals.h +++ b/wrap/pybind11/include/pybind11/detail/internals.h @@ -10,9 +10,32 @@ #pragma once #include "../pytypes.h" +#include + +/// Tracks the `internals` and `type_info` ABI version independent of the main library version. +/// +/// Some portions of the code use an ABI that is conditional depending on this +/// version number. That allows ABI-breaking changes to be "pre-implemented". +/// Once the default version number is incremented, the conditional logic that +/// no longer applies can be removed. Additionally, users that need not +/// maintain ABI compatibility can increase the version number in order to take +/// advantage of any functionality/efficiency improvements that depend on the +/// newer ABI. +/// +/// WARNING: If you choose to manually increase the ABI version, note that +/// pybind11 may not be tested as thoroughly with a non-default ABI version, and +/// further ABI-incompatible changes may be made before the ABI is officially +/// changed to the new version. +#ifndef PYBIND11_INTERNALS_VERSION +# define PYBIND11_INTERNALS_VERSION 4 +#endif PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) + +using ExceptionTranslator = void (*)(std::exception_ptr); + PYBIND11_NAMESPACE_BEGIN(detail) + // Forward declarations inline PyTypeObject *make_static_property_type(); inline PyTypeObject *make_default_metaclass(); @@ -21,30 +44,59 @@ inline PyObject *make_object_base_type(PyTypeObject *metaclass); // The old Python Thread Local Storage (TLS) API is deprecated in Python 3.7 in favor of the new // Thread Specific Storage (TSS) API. #if PY_VERSION_HEX >= 0x03070000 -# define PYBIND11_TLS_KEY_INIT(var) Py_tss_t *var = nullptr -# define PYBIND11_TLS_GET_VALUE(key) PyThread_tss_get((key)) -# define PYBIND11_TLS_REPLACE_VALUE(key, value) PyThread_tss_set((key), (value)) -# define PYBIND11_TLS_DELETE_VALUE(key) PyThread_tss_set((key), nullptr) -# define PYBIND11_TLS_FREE(key) PyThread_tss_free(key) -#else - // Usually an int but a long on Cygwin64 with Python 3.x -# define PYBIND11_TLS_KEY_INIT(var) decltype(PyThread_create_key()) var = 0 -# define PYBIND11_TLS_GET_VALUE(key) PyThread_get_key_value((key)) -# if PY_MAJOR_VERSION < 3 -# define PYBIND11_TLS_DELETE_VALUE(key) \ - PyThread_delete_key_value(key) -# define PYBIND11_TLS_REPLACE_VALUE(key, value) \ - do { \ - PyThread_delete_key_value((key)); \ - PyThread_set_key_value((key), (value)); \ - } while (false) +// Avoid unnecessary allocation of `Py_tss_t`, since we cannot use +// `Py_LIMITED_API` anyway. +# if PYBIND11_INTERNALS_VERSION > 4 +# define PYBIND11_TLS_KEY_REF Py_tss_t & +# ifdef __GNUC__ +// Clang on macOS warns due to `Py_tss_NEEDS_INIT` not specifying an initializer +// for every field. +# define PYBIND11_TLS_KEY_INIT(var) \ + _Pragma("GCC diagnostic push") /**/ \ + _Pragma("GCC diagnostic ignored \"-Wmissing-field-initializers\"") /**/ \ + Py_tss_t var \ + = Py_tss_NEEDS_INIT; \ + _Pragma("GCC diagnostic pop") +# else +# define PYBIND11_TLS_KEY_INIT(var) Py_tss_t var = Py_tss_NEEDS_INIT; +# endif +# define PYBIND11_TLS_KEY_CREATE(var) (PyThread_tss_create(&(var)) == 0) +# define PYBIND11_TLS_GET_VALUE(key) PyThread_tss_get(&(key)) +# define PYBIND11_TLS_REPLACE_VALUE(key, value) PyThread_tss_set(&(key), (value)) +# define PYBIND11_TLS_DELETE_VALUE(key) PyThread_tss_set(&(key), nullptr) +# define PYBIND11_TLS_FREE(key) PyThread_tss_delete(&(key)) # else -# define PYBIND11_TLS_DELETE_VALUE(key) \ - PyThread_set_key_value((key), nullptr) -# define PYBIND11_TLS_REPLACE_VALUE(key, value) \ - PyThread_set_key_value((key), (value)) +# define PYBIND11_TLS_KEY_REF Py_tss_t * +# define PYBIND11_TLS_KEY_INIT(var) Py_tss_t *var = nullptr; +# define PYBIND11_TLS_KEY_CREATE(var) \ + (((var) = PyThread_tss_alloc()) != nullptr && (PyThread_tss_create((var)) == 0)) +# define PYBIND11_TLS_GET_VALUE(key) PyThread_tss_get((key)) +# define PYBIND11_TLS_REPLACE_VALUE(key, value) PyThread_tss_set((key), (value)) +# define PYBIND11_TLS_DELETE_VALUE(key) PyThread_tss_set((key), nullptr) +# define PYBIND11_TLS_FREE(key) PyThread_tss_free(key) # endif -# define PYBIND11_TLS_FREE(key) (void)key +#else +// Usually an int but a long on Cygwin64 with Python 3.x +# define PYBIND11_TLS_KEY_REF decltype(PyThread_create_key()) +# define PYBIND11_TLS_KEY_INIT(var) PYBIND11_TLS_KEY_REF var = 0; +# define PYBIND11_TLS_KEY_CREATE(var) (((var) = PyThread_create_key()) != -1) +# define PYBIND11_TLS_GET_VALUE(key) PyThread_get_key_value((key)) +# if PY_MAJOR_VERSION < 3 || defined(PYPY_VERSION) +// On CPython < 3.4 and on PyPy, `PyThread_set_key_value` strangely does not set +// the value if it has already been set. Instead, it must first be deleted and +// then set again. +inline void tls_replace_value(PYBIND11_TLS_KEY_REF key, void *value) { + PyThread_delete_key_value(key); + PyThread_set_key_value(key, value); +} +# define PYBIND11_TLS_DELETE_VALUE(key) PyThread_delete_key_value(key) +# define PYBIND11_TLS_REPLACE_VALUE(key, value) \ + ::pybind11::detail::tls_replace_value((key), (value)) +# else +# define PYBIND11_TLS_DELETE_VALUE(key) PyThread_set_key_value((key), nullptr) +# define PYBIND11_TLS_REPLACE_VALUE(key, value) PyThread_set_key_value((key), (value)) +# endif +# define PYBIND11_TLS_FREE(key) (void) key #endif // Python loads modules by default with dlopen with the RTLD_LOCAL flag; under libc++ and possibly @@ -100,24 +152,33 @@ struct internals { std::unordered_set, override_hash> inactive_override_cache; type_map> direct_conversions; std::unordered_map> patients; - std::forward_list registered_exception_translators; + std::forward_list registered_exception_translators; std::unordered_map shared_data; // Custom data to be shared across extensions - std::vector loader_patient_stack; // Used by `loader_life_support` +#if PYBIND11_INTERNALS_VERSION == 4 + std::vector unused_loader_patient_stack_remove_at_v5; +#endif std::forward_list static_strings; // Stores the std::strings backing detail::c_str() PyTypeObject *static_property_type; PyTypeObject *default_metaclass; PyObject *instance_base; #if defined(WITH_THREAD) - PYBIND11_TLS_KEY_INIT(tstate); + PYBIND11_TLS_KEY_INIT(tstate) +# if PYBIND11_INTERNALS_VERSION > 4 + PYBIND11_TLS_KEY_INIT(loader_life_support_tls_key) +# endif // PYBIND11_INTERNALS_VERSION > 4 PyInterpreterState *istate = nullptr; ~internals() { +# if PYBIND11_INTERNALS_VERSION > 4 + PYBIND11_TLS_FREE(loader_life_support_tls_key); +# endif // PYBIND11_INTERNALS_VERSION > 4 + // This destructor is called *after* Py_Finalize() in finalize_interpreter(). - // That *SHOULD BE* fine. The following details what happens whe PyThread_tss_free is called. - // PYBIND11_TLS_FREE is PyThread_tss_free on python 3.7+. On older python, it does nothing. - // PyThread_tss_free calls PyThread_tss_delete and PyMem_RawFree. - // PyThread_tss_delete just calls TlsFree (on Windows) or pthread_key_delete (on *NIX). Neither - // of those have anything to do with CPython internals. - // PyMem_RawFree *requires* that the `tstate` be allocated with the CPython allocator. + // That *SHOULD BE* fine. The following details what happens when PyThread_tss_free is + // called. PYBIND11_TLS_FREE is PyThread_tss_free on python 3.7+. On older python, it does + // nothing. PyThread_tss_free calls PyThread_tss_delete and PyMem_RawFree. + // PyThread_tss_delete just calls TlsFree (on Windows) or pthread_key_delete (on *NIX). + // Neither of those have anything to do with CPython internals. PyMem_RawFree *requires* + // that the `tstate` be allocated with the CPython allocator. PYBIND11_TLS_FREE(tstate); } #endif @@ -139,7 +200,9 @@ struct type_info { void *get_buffer_data = nullptr; void *(*module_local_load)(PyObject *, const type_info *) = nullptr; /* A simple type never occurs as a (direct or indirect) parent - * of a class that makes use of multiple inheritance */ + * of a class that makes use of multiple inheritance. + * A type can be simple even if it has non-simple ancestors as long as it has no descendants. + */ bool simple_type : 1; /* True if there is no multiple inheritance in this type's inheritance tree */ bool simple_ancestors : 1; @@ -149,54 +212,62 @@ struct type_info { bool module_local : 1; }; -/// Tracks the `internals` and `type_info` ABI version independent of the main library version -#define PYBIND11_INTERNALS_VERSION 4 - /// On MSVC, debug and release builds are not ABI-compatible! #if defined(_MSC_VER) && defined(_DEBUG) -# define PYBIND11_BUILD_TYPE "_debug" +# define PYBIND11_BUILD_TYPE "_debug" #else -# define PYBIND11_BUILD_TYPE "" +# define PYBIND11_BUILD_TYPE "" #endif /// Let's assume that different compilers are ABI-incompatible. -#if defined(_MSC_VER) -# define PYBIND11_COMPILER_TYPE "_msvc" -#elif defined(__INTEL_COMPILER) -# define PYBIND11_COMPILER_TYPE "_icc" -#elif defined(__clang__) -# define PYBIND11_COMPILER_TYPE "_clang" -#elif defined(__PGI) -# define PYBIND11_COMPILER_TYPE "_pgi" -#elif defined(__MINGW32__) -# define PYBIND11_COMPILER_TYPE "_mingw" -#elif defined(__CYGWIN__) -# define PYBIND11_COMPILER_TYPE "_gcc_cygwin" -#elif defined(__GNUC__) -# define PYBIND11_COMPILER_TYPE "_gcc" -#else -# define PYBIND11_COMPILER_TYPE "_unknown" +/// A user can manually set this string if they know their +/// compiler is compatible. +#ifndef PYBIND11_COMPILER_TYPE +# if defined(_MSC_VER) +# define PYBIND11_COMPILER_TYPE "_msvc" +# elif defined(__INTEL_COMPILER) +# define PYBIND11_COMPILER_TYPE "_icc" +# elif defined(__clang__) +# define PYBIND11_COMPILER_TYPE "_clang" +# elif defined(__PGI) +# define PYBIND11_COMPILER_TYPE "_pgi" +# elif defined(__MINGW32__) +# define PYBIND11_COMPILER_TYPE "_mingw" +# elif defined(__CYGWIN__) +# define PYBIND11_COMPILER_TYPE "_gcc_cygwin" +# elif defined(__GNUC__) +# define PYBIND11_COMPILER_TYPE "_gcc" +# else +# define PYBIND11_COMPILER_TYPE "_unknown" +# endif #endif -#if defined(_LIBCPP_VERSION) -# define PYBIND11_STDLIB "_libcpp" -#elif defined(__GLIBCXX__) || defined(__GLIBCPP__) -# define PYBIND11_STDLIB "_libstdcpp" -#else -# define PYBIND11_STDLIB "" +/// Also standard libs +#ifndef PYBIND11_STDLIB +# if defined(_LIBCPP_VERSION) +# define PYBIND11_STDLIB "_libcpp" +# elif defined(__GLIBCXX__) || defined(__GLIBCPP__) +# define PYBIND11_STDLIB "_libstdcpp" +# else +# define PYBIND11_STDLIB "" +# endif #endif /// On Linux/OSX, changes in __GXX_ABI_VERSION__ indicate ABI incompatibility. -#if defined(__GXX_ABI_VERSION) -# define PYBIND11_BUILD_ABI "_cxxabi" PYBIND11_TOSTRING(__GXX_ABI_VERSION) -#else -# define PYBIND11_BUILD_ABI "" +#ifndef PYBIND11_BUILD_ABI +# if defined(__GXX_ABI_VERSION) +# define PYBIND11_BUILD_ABI "_cxxabi" PYBIND11_TOSTRING(__GXX_ABI_VERSION) +# else +# define PYBIND11_BUILD_ABI "" +# endif #endif -#if defined(WITH_THREAD) -# define PYBIND11_INTERNALS_KIND "" -#else -# define PYBIND11_INTERNALS_KIND "_without_thread" +#ifndef PYBIND11_INTERNALS_KIND +# if defined(WITH_THREAD) +# define PYBIND11_INTERNALS_KIND "" +# else +# define PYBIND11_INTERNALS_KIND "_without_thread" +# endif #endif #define PYBIND11_INTERNALS_ID "__pybind11_internals_v" \ @@ -212,21 +283,104 @@ inline internals **&get_internals_pp() { return internals_pp; } +#if PY_VERSION_HEX >= 0x03030000 +// forward decl +inline void translate_exception(std::exception_ptr); + +template >::value, int> = 0> +bool handle_nested_exception(const T &exc, const std::exception_ptr &p) { + std::exception_ptr nested = exc.nested_ptr(); + if (nested != nullptr && nested != p) { + translate_exception(nested); + return true; + } + return false; +} + +template >::value, int> = 0> +bool handle_nested_exception(const T &exc, const std::exception_ptr &p) { + if (auto *nep = dynamic_cast(std::addressof(exc))) { + return handle_nested_exception(*nep, p); + } + return false; +} + +#else + +template +bool handle_nested_exception(const T &, std::exception_ptr &) { + return false; +} +#endif + +inline bool raise_err(PyObject *exc_type, const char *msg) { +#if PY_VERSION_HEX >= 0x03030000 + if (PyErr_Occurred()) { + raise_from(exc_type, msg); + return true; + } +#endif + PyErr_SetString(exc_type, msg); + return false; +} + inline void translate_exception(std::exception_ptr p) { + if (!p) { + return; + } try { - if (p) std::rethrow_exception(p); - } catch (error_already_set &e) { e.restore(); return; - } catch (const builtin_exception &e) { e.set_error(); return; - } catch (const std::bad_alloc &e) { PyErr_SetString(PyExc_MemoryError, e.what()); return; - } catch (const std::domain_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return; - } catch (const std::invalid_argument &e) { PyErr_SetString(PyExc_ValueError, e.what()); return; - } catch (const std::length_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return; - } catch (const std::out_of_range &e) { PyErr_SetString(PyExc_IndexError, e.what()); return; - } catch (const std::range_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return; - } catch (const std::overflow_error &e) { PyErr_SetString(PyExc_OverflowError, e.what()); return; - } catch (const std::exception &e) { PyErr_SetString(PyExc_RuntimeError, e.what()); return; + std::rethrow_exception(p); + } catch (error_already_set &e) { + handle_nested_exception(e, p); + e.restore(); + return; + } catch (const builtin_exception &e) { + // Could not use template since it's an abstract class. + if (auto *nep = dynamic_cast(std::addressof(e))) { + handle_nested_exception(*nep, p); + } + e.set_error(); + return; + } catch (const std::bad_alloc &e) { + handle_nested_exception(e, p); + raise_err(PyExc_MemoryError, e.what()); + return; + } catch (const std::domain_error &e) { + handle_nested_exception(e, p); + raise_err(PyExc_ValueError, e.what()); + return; + } catch (const std::invalid_argument &e) { + handle_nested_exception(e, p); + raise_err(PyExc_ValueError, e.what()); + return; + } catch (const std::length_error &e) { + handle_nested_exception(e, p); + raise_err(PyExc_ValueError, e.what()); + return; + } catch (const std::out_of_range &e) { + handle_nested_exception(e, p); + raise_err(PyExc_IndexError, e.what()); + return; + } catch (const std::range_error &e) { + handle_nested_exception(e, p); + raise_err(PyExc_ValueError, e.what()); + return; + } catch (const std::overflow_error &e) { + handle_nested_exception(e, p); + raise_err(PyExc_OverflowError, e.what()); + return; + } catch (const std::exception &e) { + handle_nested_exception(e, p); + raise_err(PyExc_RuntimeError, e.what()); + return; + } catch (const std::nested_exception &e) { + handle_nested_exception(e, p); + raise_err(PyExc_RuntimeError, "Caught an unknown nested exception!"); + return; } catch (...) { - PyErr_SetString(PyExc_RuntimeError, "Caught an unknown exception!"); + raise_err(PyExc_RuntimeError, "Caught an unknown exception!"); return; } } @@ -242,7 +396,7 @@ inline void translate_local_exception(std::exception_ptr p) { #endif /// Return a reference to the current `internals` data -PYBIND11_NOINLINE inline internals &get_internals() { +PYBIND11_NOINLINE internals &get_internals() { auto **&internals_pp = get_internals_pp(); if (internals_pp && *internals_pp) return **internals_pp; @@ -255,7 +409,7 @@ PYBIND11_NOINLINE inline internals &get_internals() { const PyGILState_STATE state; } gil; - constexpr auto *id = PYBIND11_INTERNALS_ID; + PYBIND11_STR_TYPE id(PYBIND11_INTERNALS_ID); auto builtins = handle(PyEval_GetBuiltins()); if (builtins.contains(id) && isinstance(builtins[id])) { internals_pp = static_cast(capsule(builtins[id])); @@ -265,6 +419,8 @@ PYBIND11_NOINLINE inline internals &get_internals() { // initial exception translator, below, so add another for our local exception classes. // // libstdc++ doesn't require this (types there are identified only by name) + // libc++ with CPython doesn't require this (types are explicitly exported) + // libc++ with PyPy still need it, awaiting further investigation #if !defined(__GLIBCXX__) (*internals_pp)->registered_exception_translators.push_front(&translate_local_exception); #endif @@ -274,21 +430,21 @@ PYBIND11_NOINLINE inline internals &get_internals() { internals_ptr = new internals(); #if defined(WITH_THREAD) - #if PY_VERSION_HEX < 0x03090000 - PyEval_InitThreads(); - #endif +# if PY_VERSION_HEX < 0x03090000 + PyEval_InitThreads(); +# endif PyThreadState *tstate = PyThreadState_Get(); - #if PY_VERSION_HEX >= 0x03070000 - internals_ptr->tstate = PyThread_tss_alloc(); - if (!internals_ptr->tstate || PyThread_tss_create(internals_ptr->tstate)) - pybind11_fail("get_internals: could not successfully initialize the TSS key!"); - PyThread_tss_set(internals_ptr->tstate, tstate); - #else - internals_ptr->tstate = PyThread_create_key(); - if (internals_ptr->tstate == -1) - pybind11_fail("get_internals: could not successfully initialize the TLS key!"); - PyThread_set_key_value(internals_ptr->tstate, tstate); - #endif + if (!PYBIND11_TLS_KEY_CREATE(internals_ptr->tstate)) { + pybind11_fail("get_internals: could not successfully initialize the tstate TSS key!"); + } + PYBIND11_TLS_REPLACE_VALUE(internals_ptr->tstate, tstate); + +# if PYBIND11_INTERNALS_VERSION > 4 + if (!PYBIND11_TLS_KEY_CREATE(internals_ptr->loader_life_support_tls_key)) { + pybind11_fail("get_internals: could not successfully initialize the " + "loader_life_support TSS key!"); + } +# endif internals_ptr->istate = tstate->interp; #endif builtins[id] = capsule(internals_pp); @@ -300,12 +456,57 @@ PYBIND11_NOINLINE inline internals &get_internals() { return **internals_pp; } -/// Works like `internals.registered_types_cpp`, but for module-local registered types: -inline type_map ®istered_local_types_cpp() { - static type_map locals{}; - return locals; +// the internals struct (above) is shared between all the modules. local_internals are only +// for a single module. Any changes made to internals may require an update to +// PYBIND11_INTERNALS_VERSION, breaking backwards compatibility. local_internals is, by design, +// restricted to a single module. Whether a module has local internals or not should not +// impact any other modules, because the only things accessing the local internals is the +// module that contains them. +struct local_internals { + type_map registered_types_cpp; + std::forward_list registered_exception_translators; +#if defined(WITH_THREAD) && PYBIND11_INTERNALS_VERSION == 4 + + // For ABI compatibility, we can't store the loader_life_support TLS key in + // the `internals` struct directly. Instead, we store it in `shared_data` and + // cache a copy in `local_internals`. If we allocated a separate TLS key for + // each instance of `local_internals`, we could end up allocating hundreds of + // TLS keys if hundreds of different pybind11 modules are loaded (which is a + // plausible number). + PYBIND11_TLS_KEY_INIT(loader_life_support_tls_key) + + // Holds the shared TLS key for the loader_life_support stack. + struct shared_loader_life_support_data { + PYBIND11_TLS_KEY_INIT(loader_life_support_tls_key) + shared_loader_life_support_data() { + if (!PYBIND11_TLS_KEY_CREATE(loader_life_support_tls_key)) { + pybind11_fail("local_internals: could not successfully initialize the " + "loader_life_support TLS key!"); + } + } + // We can't help but leak the TLS key, because Python never unloads extension modules. + }; + + local_internals() { + auto &internals = get_internals(); + // Get or create the `loader_life_support_stack_key`. + auto &ptr = internals.shared_data["_life_support"]; + if (!ptr) { + ptr = new shared_loader_life_support_data; + } + loader_life_support_tls_key + = static_cast(ptr)->loader_life_support_tls_key; + } +#endif // defined(WITH_THREAD) && PYBIND11_INTERNALS_VERSION == 4 +}; + +/// Works like `get_internals`, but for things which are locally registered. +inline local_internals &get_local_internals() { + static local_internals locals; + return locals; } + /// Constructs a std::string with the given arguments, stores it in `internals`, and returns its /// `c_str()`. Such strings objects have a long storage duration -- the internal strings are only /// cleared when the program exits or after interpreter shutdown (when embedding), and so are @@ -322,14 +523,14 @@ PYBIND11_NAMESPACE_END(detail) /// Returns a named pointer that is shared among all extension modules (using the same /// pybind11 version) running in the current interpreter. Names starting with underscores /// are reserved for internal usage. Returns `nullptr` if no matching entry was found. -inline PYBIND11_NOINLINE void *get_shared_data(const std::string &name) { +PYBIND11_NOINLINE void *get_shared_data(const std::string &name) { auto &internals = detail::get_internals(); auto it = internals.shared_data.find(name); return it != internals.shared_data.end() ? it->second : nullptr; } /// Set the shared data that can be later recovered by `get_shared_data()`. -inline PYBIND11_NOINLINE void *set_shared_data(const std::string &name, void *data) { +PYBIND11_NOINLINE void *set_shared_data(const std::string &name, void *data) { detail::get_internals().shared_data[name] = data; return data; } diff --git a/wrap/pybind11/include/pybind11/detail/type_caster_base.h b/wrap/pybind11/include/pybind11/detail/type_caster_base.h new file mode 100644 index 000000000..48e218b2f --- /dev/null +++ b/wrap/pybind11/include/pybind11/detail/type_caster_base.h @@ -0,0 +1,985 @@ +/* + pybind11/detail/type_caster_base.h (originally first part of pybind11/cast.h) + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "../pytypes.h" +#include "common.h" +#include "descr.h" +#include "internals.h" +#include "typeid.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +PYBIND11_NAMESPACE_BEGIN(detail) + +/// A life support system for temporary objects created by `type_caster::load()`. +/// Adding a patient will keep it alive up until the enclosing function returns. +class loader_life_support { +private: + loader_life_support* parent = nullptr; + std::unordered_set keep_alive; + +#if defined(WITH_THREAD) + // Store stack pointer in thread-local storage. + static PYBIND11_TLS_KEY_REF get_stack_tls_key() { +# if PYBIND11_INTERNALS_VERSION == 4 + return get_local_internals().loader_life_support_tls_key; +# else + return get_internals().loader_life_support_tls_key; +# endif + } + static loader_life_support *get_stack_top() { + return static_cast(PYBIND11_TLS_GET_VALUE(get_stack_tls_key())); + } + static void set_stack_top(loader_life_support *value) { + PYBIND11_TLS_REPLACE_VALUE(get_stack_tls_key(), value); + } +#else + // Use single global variable for stack. + static loader_life_support **get_stack_pp() { + static loader_life_support *global_stack = nullptr; + return global_stack; + } + static loader_life_support *get_stack_top() { return *get_stack_pp(); } + static void set_stack_top(loader_life_support *value) { *get_stack_pp() = value; } +#endif + +public: + /// A new patient frame is created when a function is entered + loader_life_support() { + parent = get_stack_top(); + set_stack_top(this); + } + + /// ... and destroyed after it returns + ~loader_life_support() { + if (get_stack_top() != this) + pybind11_fail("loader_life_support: internal error"); + set_stack_top(parent); + for (auto* item : keep_alive) + Py_DECREF(item); + } + + /// This can only be used inside a pybind11-bound function, either by `argument_loader` + /// at argument preparation time or by `py::cast()` at execution time. + PYBIND11_NOINLINE static void add_patient(handle h) { + loader_life_support *frame = get_stack_top(); + if (!frame) { + // NOTE: It would be nice to include the stack frames here, as this indicates + // use of pybind11::cast<> outside the normal call framework, finding such + // a location is challenging. Developers could consider printing out + // stack frame addresses here using something like __builtin_frame_address(0) + throw cast_error("When called outside a bound function, py::cast() cannot " + "do Python -> C++ conversions which require the creation " + "of temporary values"); + } + + if (frame->keep_alive.insert(h.ptr()).second) + Py_INCREF(h.ptr()); + } +}; + +// Gets the cache entry for the given type, creating it if necessary. The return value is the pair +// returned by emplace, i.e. an iterator for the entry and a bool set to `true` if the entry was +// just created. +inline std::pair all_type_info_get_cache(PyTypeObject *type); + +// Populates a just-created cache entry. +PYBIND11_NOINLINE void all_type_info_populate(PyTypeObject *t, std::vector &bases) { + std::vector check; + for (handle parent : reinterpret_borrow(t->tp_bases)) + check.push_back((PyTypeObject *) parent.ptr()); + + auto const &type_dict = get_internals().registered_types_py; + for (size_t i = 0; i < check.size(); i++) { + auto type = check[i]; + // Ignore Python2 old-style class super type: + if (!PyType_Check((PyObject *) type)) continue; + + // Check `type` in the current set of registered python types: + auto it = type_dict.find(type); + if (it != type_dict.end()) { + // We found a cache entry for it, so it's either pybind-registered or has pre-computed + // pybind bases, but we have to make sure we haven't already seen the type(s) before: we + // want to follow Python/virtual C++ rules that there should only be one instance of a + // common base. + for (auto *tinfo : it->second) { + // NB: Could use a second set here, rather than doing a linear search, but since + // having a large number of immediate pybind11-registered types seems fairly + // unlikely, that probably isn't worthwhile. + bool found = false; + for (auto *known : bases) { + if (known == tinfo) { found = true; break; } + } + if (!found) bases.push_back(tinfo); + } + } + else if (type->tp_bases) { + // It's some python type, so keep follow its bases classes to look for one or more + // registered types + if (i + 1 == check.size()) { + // When we're at the end, we can pop off the current element to avoid growing + // `check` when adding just one base (which is typical--i.e. when there is no + // multiple inheritance) + check.pop_back(); + i--; + } + for (handle parent : reinterpret_borrow(type->tp_bases)) + check.push_back((PyTypeObject *) parent.ptr()); + } + } +} + +/** + * Extracts vector of type_info pointers of pybind-registered roots of the given Python type. Will + * be just 1 pybind type for the Python type of a pybind-registered class, or for any Python-side + * derived class that uses single inheritance. Will contain as many types as required for a Python + * class that uses multiple inheritance to inherit (directly or indirectly) from multiple + * pybind-registered classes. Will be empty if neither the type nor any base classes are + * pybind-registered. + * + * The value is cached for the lifetime of the Python type. + */ +inline const std::vector &all_type_info(PyTypeObject *type) { + auto ins = all_type_info_get_cache(type); + if (ins.second) + // New cache entry: populate it + all_type_info_populate(type, ins.first->second); + + return ins.first->second; +} + +/** + * Gets a single pybind11 type info for a python type. Returns nullptr if neither the type nor any + * ancestors are pybind11-registered. Throws an exception if there are multiple bases--use + * `all_type_info` instead if you want to support multiple bases. + */ +PYBIND11_NOINLINE detail::type_info* get_type_info(PyTypeObject *type) { + auto &bases = all_type_info(type); + if (bases.empty()) + return nullptr; + if (bases.size() > 1) + pybind11_fail("pybind11::detail::get_type_info: type has multiple pybind11-registered bases"); + return bases.front(); +} + +inline detail::type_info *get_local_type_info(const std::type_index &tp) { + auto &locals = get_local_internals().registered_types_cpp; + auto it = locals.find(tp); + if (it != locals.end()) + return it->second; + return nullptr; +} + +inline detail::type_info *get_global_type_info(const std::type_index &tp) { + auto &types = get_internals().registered_types_cpp; + auto it = types.find(tp); + if (it != types.end()) + return it->second; + return nullptr; +} + +/// Return the type info for a given C++ type; on lookup failure can either throw or return nullptr. +PYBIND11_NOINLINE detail::type_info *get_type_info(const std::type_index &tp, + bool throw_if_missing = false) { + if (auto ltype = get_local_type_info(tp)) + return ltype; + if (auto gtype = get_global_type_info(tp)) + return gtype; + + if (throw_if_missing) { + std::string tname = tp.name(); + detail::clean_type_id(tname); + pybind11_fail("pybind11::detail::get_type_info: unable to find type info for \"" + tname + "\""); + } + return nullptr; +} + +PYBIND11_NOINLINE handle get_type_handle(const std::type_info &tp, bool throw_if_missing) { + detail::type_info *type_info = get_type_info(tp, throw_if_missing); + return handle(type_info ? ((PyObject *) type_info->type) : nullptr); +} + +// Searches the inheritance graph for a registered Python instance, using all_type_info(). +PYBIND11_NOINLINE handle find_registered_python_instance(void *src, + const detail::type_info *tinfo) { + auto it_instances = get_internals().registered_instances.equal_range(src); + for (auto it_i = it_instances.first; it_i != it_instances.second; ++it_i) { + for (auto instance_type : detail::all_type_info(Py_TYPE(it_i->second))) { + if (instance_type && same_type(*instance_type->cpptype, *tinfo->cpptype)) + return handle((PyObject *) it_i->second).inc_ref(); + } + } + return handle(); +} + +struct value_and_holder { + instance *inst = nullptr; + size_t index = 0u; + const detail::type_info *type = nullptr; + void **vh = nullptr; + + // Main constructor for a found value/holder: + value_and_holder(instance *i, const detail::type_info *type, size_t vpos, size_t index) : + inst{i}, index{index}, type{type}, + vh{inst->simple_layout ? inst->simple_value_holder : &inst->nonsimple.values_and_holders[vpos]} + {} + + // Default constructor (used to signal a value-and-holder not found by get_value_and_holder()) + value_and_holder() = default; + + // Used for past-the-end iterator + explicit value_and_holder(size_t index) : index{index} {} + + template V *&value_ptr() const { + return reinterpret_cast(vh[0]); + } + // True if this `value_and_holder` has a non-null value pointer + explicit operator bool() const { return value_ptr() != nullptr; } + + template H &holder() const { + return reinterpret_cast(vh[1]); + } + bool holder_constructed() const { + return inst->simple_layout + ? inst->simple_holder_constructed + : (inst->nonsimple.status[index] & instance::status_holder_constructed) != 0u; + } + // NOLINTNEXTLINE(readability-make-member-function-const) + void set_holder_constructed(bool v = true) { + if (inst->simple_layout) + inst->simple_holder_constructed = v; + else if (v) + inst->nonsimple.status[index] |= instance::status_holder_constructed; + else + inst->nonsimple.status[index] &= (std::uint8_t) ~instance::status_holder_constructed; + } + bool instance_registered() const { + return inst->simple_layout + ? inst->simple_instance_registered + : ((inst->nonsimple.status[index] & instance::status_instance_registered) != 0); + } + // NOLINTNEXTLINE(readability-make-member-function-const) + void set_instance_registered(bool v = true) { + if (inst->simple_layout) + inst->simple_instance_registered = v; + else if (v) + inst->nonsimple.status[index] |= instance::status_instance_registered; + else + inst->nonsimple.status[index] &= (std::uint8_t) ~instance::status_instance_registered; + } +}; + +// Container for accessing and iterating over an instance's values/holders +struct values_and_holders { +private: + instance *inst; + using type_vec = std::vector; + const type_vec &tinfo; + +public: + explicit values_and_holders(instance *inst) + : inst{inst}, tinfo(all_type_info(Py_TYPE(inst))) {} + + struct iterator { + private: + instance *inst = nullptr; + const type_vec *types = nullptr; + value_and_holder curr; + friend struct values_and_holders; + iterator(instance *inst, const type_vec *tinfo) + : inst{inst}, types{tinfo}, + curr(inst /* instance */, + types->empty() ? nullptr : (*types)[0] /* type info */, + 0, /* vpos: (non-simple types only): the first vptr comes first */ + 0 /* index */) + {} + // Past-the-end iterator: + explicit iterator(size_t end) : curr(end) {} + + public: + bool operator==(const iterator &other) const { return curr.index == other.curr.index; } + bool operator!=(const iterator &other) const { return curr.index != other.curr.index; } + iterator &operator++() { + if (!inst->simple_layout) + curr.vh += 1 + (*types)[curr.index]->holder_size_in_ptrs; + ++curr.index; + curr.type = curr.index < types->size() ? (*types)[curr.index] : nullptr; + return *this; + } + value_and_holder &operator*() { return curr; } + value_and_holder *operator->() { return &curr; } + }; + + iterator begin() { return iterator(inst, &tinfo); } + iterator end() { return iterator(tinfo.size()); } + + iterator find(const type_info *find_type) { + auto it = begin(), endit = end(); + while (it != endit && it->type != find_type) ++it; + return it; + } + + size_t size() { return tinfo.size(); } +}; + +/** + * Extracts C++ value and holder pointer references from an instance (which may contain multiple + * values/holders for python-side multiple inheritance) that match the given type. Throws an error + * if the given type (or ValueType, if omitted) is not a pybind11 base of the given instance. If + * `find_type` is omitted (or explicitly specified as nullptr) the first value/holder are returned, + * regardless of type (and the resulting .type will be nullptr). + * + * The returned object should be short-lived: in particular, it must not outlive the called-upon + * instance. + */ +PYBIND11_NOINLINE value_and_holder instance::get_value_and_holder(const type_info *find_type /*= nullptr default in common.h*/, bool throw_if_missing /*= true in common.h*/) { + // Optimize common case: + if (!find_type || Py_TYPE(this) == find_type->type) + return value_and_holder(this, find_type, 0, 0); + + detail::values_and_holders vhs(this); + auto it = vhs.find(find_type); + if (it != vhs.end()) + return *it; + + if (!throw_if_missing) + return value_and_holder(); + +#if defined(NDEBUG) + pybind11_fail("pybind11::detail::instance::get_value_and_holder: " + "type is not a pybind11 base of the given instance " + "(compile in debug mode for type details)"); +#else + pybind11_fail("pybind11::detail::instance::get_value_and_holder: `" + + get_fully_qualified_tp_name(find_type->type) + "' is not a pybind11 base of the given `" + + get_fully_qualified_tp_name(Py_TYPE(this)) + "' instance"); +#endif +} + +PYBIND11_NOINLINE void instance::allocate_layout() { + auto &tinfo = all_type_info(Py_TYPE(this)); + + const size_t n_types = tinfo.size(); + + if (n_types == 0) + pybind11_fail("instance allocation failed: new instance has no pybind11-registered base types"); + + simple_layout = + n_types == 1 && tinfo.front()->holder_size_in_ptrs <= instance_simple_holder_in_ptrs(); + + // Simple path: no python-side multiple inheritance, and a small-enough holder + if (simple_layout) { + simple_value_holder[0] = nullptr; + simple_holder_constructed = false; + simple_instance_registered = false; + } + else { // multiple base types or a too-large holder + // Allocate space to hold: [v1*][h1][v2*][h2]...[bb...] where [vN*] is a value pointer, + // [hN] is the (uninitialized) holder instance for value N, and [bb...] is a set of bool + // values that tracks whether each associated holder has been initialized. Each [block] is + // padded, if necessary, to an integer multiple of sizeof(void *). + size_t space = 0; + for (auto t : tinfo) { + space += 1; // value pointer + space += t->holder_size_in_ptrs; // holder instance + } + size_t flags_at = space; + space += size_in_ptrs(n_types); // status bytes (holder_constructed and instance_registered) + + // Allocate space for flags, values, and holders, and initialize it to 0 (flags and values, + // in particular, need to be 0). Use Python's memory allocation functions: in Python 3.6 + // they default to using pymalloc, which is designed to be efficient for small allocations + // like the one we're doing here; in earlier versions (and for larger allocations) they are + // just wrappers around malloc. +#if PY_VERSION_HEX >= 0x03050000 + nonsimple.values_and_holders = (void **) PyMem_Calloc(space, sizeof(void *)); + if (!nonsimple.values_and_holders) throw std::bad_alloc(); +#else + nonsimple.values_and_holders = (void **) PyMem_New(void *, space); + if (!nonsimple.values_and_holders) throw std::bad_alloc(); + std::memset(nonsimple.values_and_holders, 0, space * sizeof(void *)); +#endif + nonsimple.status = reinterpret_cast(&nonsimple.values_and_holders[flags_at]); + } + owned = true; +} + +// NOLINTNEXTLINE(readability-make-member-function-const) +PYBIND11_NOINLINE void instance::deallocate_layout() { + if (!simple_layout) + PyMem_Free(nonsimple.values_and_holders); +} + +PYBIND11_NOINLINE bool isinstance_generic(handle obj, const std::type_info &tp) { + handle type = detail::get_type_handle(tp, false); + if (!type) + return false; + return isinstance(obj, type); +} + +PYBIND11_NOINLINE std::string error_string() { + if (!PyErr_Occurred()) { + PyErr_SetString(PyExc_RuntimeError, "Unknown internal error occurred"); + return "Unknown internal error occurred"; + } + + error_scope scope; // Preserve error state + + std::string errorString; + if (scope.type) { + errorString += handle(scope.type).attr("__name__").cast(); + errorString += ": "; + } + if (scope.value) + errorString += (std::string) str(scope.value); + + PyErr_NormalizeException(&scope.type, &scope.value, &scope.trace); + +#if PY_MAJOR_VERSION >= 3 + if (scope.trace != nullptr) + PyException_SetTraceback(scope.value, scope.trace); +#endif + +#if !defined(PYPY_VERSION) + if (scope.trace) { + auto *trace = (PyTracebackObject *) scope.trace; + + /* Get the deepest trace possible */ + while (trace->tb_next) + trace = trace->tb_next; + + PyFrameObject *frame = trace->tb_frame; + errorString += "\n\nAt:\n"; + while (frame) { +#if PY_VERSION_HEX >= 0x03090000 + PyCodeObject *f_code = PyFrame_GetCode(frame); +#else + PyCodeObject *f_code = frame->f_code; + Py_INCREF(f_code); +#endif + int lineno = PyFrame_GetLineNumber(frame); + errorString += + " " + handle(f_code->co_filename).cast() + + "(" + std::to_string(lineno) + "): " + + handle(f_code->co_name).cast() + "\n"; + frame = frame->f_back; + Py_DECREF(f_code); + } + } +#endif + + return errorString; +} + +PYBIND11_NOINLINE handle get_object_handle(const void *ptr, const detail::type_info *type ) { + auto &instances = get_internals().registered_instances; + auto range = instances.equal_range(ptr); + for (auto it = range.first; it != range.second; ++it) { + for (const auto &vh : values_and_holders(it->second)) { + if (vh.type == type) + return handle((PyObject *) it->second); + } + } + return handle(); +} + +inline PyThreadState *get_thread_state_unchecked() { +#if defined(PYPY_VERSION) + return PyThreadState_GET(); +#elif PY_VERSION_HEX < 0x03000000 + return _PyThreadState_Current; +#elif PY_VERSION_HEX < 0x03050000 + return (PyThreadState*) _Py_atomic_load_relaxed(&_PyThreadState_Current); +#elif PY_VERSION_HEX < 0x03050200 + return (PyThreadState*) _PyThreadState_Current.value; +#else + return _PyThreadState_UncheckedGet(); +#endif +} + +// Forward declarations +void keep_alive_impl(handle nurse, handle patient); +inline PyObject *make_new_instance(PyTypeObject *type); + +class type_caster_generic { +public: + PYBIND11_NOINLINE explicit type_caster_generic(const std::type_info &type_info) + : typeinfo(get_type_info(type_info)), cpptype(&type_info) {} + + explicit type_caster_generic(const type_info *typeinfo) + : typeinfo(typeinfo), cpptype(typeinfo ? typeinfo->cpptype : nullptr) {} + + bool load(handle src, bool convert) { + return load_impl(src, convert); + } + + PYBIND11_NOINLINE static handle cast(const void *_src, return_value_policy policy, handle parent, + const detail::type_info *tinfo, + void *(*copy_constructor)(const void *), + void *(*move_constructor)(const void *), + const void *existing_holder = nullptr) { + if (!tinfo) // no type info: error will be set already + return handle(); + + void *src = const_cast(_src); + if (src == nullptr) + return none().release(); + + if (handle registered_inst = find_registered_python_instance(src, tinfo)) + return registered_inst; + + auto inst = reinterpret_steal(make_new_instance(tinfo->type)); + auto wrapper = reinterpret_cast(inst.ptr()); + wrapper->owned = false; + void *&valueptr = values_and_holders(wrapper).begin()->value_ptr(); + + switch (policy) { + case return_value_policy::automatic: + case return_value_policy::take_ownership: + valueptr = src; + wrapper->owned = true; + break; + + case return_value_policy::automatic_reference: + case return_value_policy::reference: + valueptr = src; + wrapper->owned = false; + break; + + case return_value_policy::copy: + if (copy_constructor) + valueptr = copy_constructor(src); + else { +#if defined(NDEBUG) + throw cast_error("return_value_policy = copy, but type is " + "non-copyable! (compile in debug mode for details)"); +#else + std::string type_name(tinfo->cpptype->name()); + detail::clean_type_id(type_name); + throw cast_error("return_value_policy = copy, but type " + + type_name + " is non-copyable!"); +#endif + } + wrapper->owned = true; + break; + + case return_value_policy::move: + if (move_constructor) + valueptr = move_constructor(src); + else if (copy_constructor) + valueptr = copy_constructor(src); + else { +#if defined(NDEBUG) + throw cast_error("return_value_policy = move, but type is neither " + "movable nor copyable! " + "(compile in debug mode for details)"); +#else + std::string type_name(tinfo->cpptype->name()); + detail::clean_type_id(type_name); + throw cast_error("return_value_policy = move, but type " + + type_name + " is neither movable nor copyable!"); +#endif + } + wrapper->owned = true; + break; + + case return_value_policy::reference_internal: + valueptr = src; + wrapper->owned = false; + keep_alive_impl(inst, parent); + break; + + default: + throw cast_error("unhandled return_value_policy: should not happen!"); + } + + tinfo->init_instance(wrapper, existing_holder); + + return inst.release(); + } + + // Base methods for generic caster; there are overridden in copyable_holder_caster + void load_value(value_and_holder &&v_h) { + auto *&vptr = v_h.value_ptr(); + // Lazy allocation for unallocated values: + if (vptr == nullptr) { + auto *type = v_h.type ? v_h.type : typeinfo; + if (type->operator_new) { + vptr = type->operator_new(type->type_size); + } else { + #if defined(__cpp_aligned_new) && (!defined(_MSC_VER) || _MSC_VER >= 1912) + if (type->type_align > __STDCPP_DEFAULT_NEW_ALIGNMENT__) + vptr = ::operator new(type->type_size, + std::align_val_t(type->type_align)); + else + #endif + vptr = ::operator new(type->type_size); + } + } + value = vptr; + } + bool try_implicit_casts(handle src, bool convert) { + for (auto &cast : typeinfo->implicit_casts) { + type_caster_generic sub_caster(*cast.first); + if (sub_caster.load(src, convert)) { + value = cast.second(sub_caster.value); + return true; + } + } + return false; + } + bool try_direct_conversions(handle src) { + for (auto &converter : *typeinfo->direct_conversions) { + if (converter(src.ptr(), value)) + return true; + } + return false; + } + void check_holder_compat() {} + + PYBIND11_NOINLINE static void *local_load(PyObject *src, const type_info *ti) { + auto caster = type_caster_generic(ti); + if (caster.load(src, false)) + return caster.value; + return nullptr; + } + + /// Try to load with foreign typeinfo, if available. Used when there is no + /// native typeinfo, or when the native one wasn't able to produce a value. + PYBIND11_NOINLINE bool try_load_foreign_module_local(handle src) { + constexpr auto *local_key = PYBIND11_MODULE_LOCAL_ID; + const auto pytype = type::handle_of(src); + if (!hasattr(pytype, local_key)) + return false; + + type_info *foreign_typeinfo = reinterpret_borrow(getattr(pytype, local_key)); + // Only consider this foreign loader if actually foreign and is a loader of the correct cpp type + if (foreign_typeinfo->module_local_load == &local_load + || (cpptype && !same_type(*cpptype, *foreign_typeinfo->cpptype))) + return false; + + if (auto result = foreign_typeinfo->module_local_load(src.ptr(), foreign_typeinfo)) { + value = result; + return true; + } + return false; + } + + // Implementation of `load`; this takes the type of `this` so that it can dispatch the relevant + // bits of code between here and copyable_holder_caster where the two classes need different + // logic (without having to resort to virtual inheritance). + template + PYBIND11_NOINLINE bool load_impl(handle src, bool convert) { + if (!src) return false; + if (!typeinfo) return try_load_foreign_module_local(src); + + auto &this_ = static_cast(*this); + this_.check_holder_compat(); + + PyTypeObject *srctype = Py_TYPE(src.ptr()); + + // Case 1: If src is an exact type match for the target type then we can reinterpret_cast + // the instance's value pointer to the target type: + if (srctype == typeinfo->type) { + this_.load_value(reinterpret_cast(src.ptr())->get_value_and_holder()); + return true; + } + // Case 2: We have a derived class + if (PyType_IsSubtype(srctype, typeinfo->type)) { + auto &bases = all_type_info(srctype); + bool no_cpp_mi = typeinfo->simple_type; + + // Case 2a: the python type is a Python-inherited derived class that inherits from just + // one simple (no MI) pybind11 class, or is an exact match, so the C++ instance is of + // the right type and we can use reinterpret_cast. + // (This is essentially the same as case 2b, but because not using multiple inheritance + // is extremely common, we handle it specially to avoid the loop iterator and type + // pointer lookup overhead) + if (bases.size() == 1 && (no_cpp_mi || bases.front()->type == typeinfo->type)) { + this_.load_value(reinterpret_cast(src.ptr())->get_value_and_holder()); + return true; + } + // Case 2b: the python type inherits from multiple C++ bases. Check the bases to see if + // we can find an exact match (or, for a simple C++ type, an inherited match); if so, we + // can safely reinterpret_cast to the relevant pointer. + if (bases.size() > 1) { + for (auto base : bases) { + if (no_cpp_mi ? PyType_IsSubtype(base->type, typeinfo->type) : base->type == typeinfo->type) { + this_.load_value(reinterpret_cast(src.ptr())->get_value_and_holder(base)); + return true; + } + } + } + + // Case 2c: C++ multiple inheritance is involved and we couldn't find an exact type match + // in the registered bases, above, so try implicit casting (needed for proper C++ casting + // when MI is involved). + if (this_.try_implicit_casts(src, convert)) + return true; + } + + // Perform an implicit conversion + if (convert) { + for (auto &converter : typeinfo->implicit_conversions) { + auto temp = reinterpret_steal(converter(src.ptr(), typeinfo->type)); + if (load_impl(temp, false)) { + loader_life_support::add_patient(temp); + return true; + } + } + if (this_.try_direct_conversions(src)) + return true; + } + + // Failed to match local typeinfo. Try again with global. + if (typeinfo->module_local) { + if (auto gtype = get_global_type_info(*typeinfo->cpptype)) { + typeinfo = gtype; + return load(src, false); + } + } + + // Global typeinfo has precedence over foreign module_local + if (try_load_foreign_module_local(src)) { + return true; + } + + // Custom converters didn't take None, now we convert None to nullptr. + if (src.is_none()) { + // Defer accepting None to other overloads (if we aren't in convert mode): + if (!convert) return false; + value = nullptr; + return true; + } + + return false; + } + + + // Called to do type lookup and wrap the pointer and type in a pair when a dynamic_cast + // isn't needed or can't be used. If the type is unknown, sets the error and returns a pair + // with .second = nullptr. (p.first = nullptr is not an error: it becomes None). + PYBIND11_NOINLINE static std::pair src_and_type( + const void *src, const std::type_info &cast_type, const std::type_info *rtti_type = nullptr) { + if (auto *tpi = get_type_info(cast_type)) + return {src, const_cast(tpi)}; + + // Not found, set error: + std::string tname = rtti_type ? rtti_type->name() : cast_type.name(); + detail::clean_type_id(tname); + std::string msg = "Unregistered type : " + tname; + PyErr_SetString(PyExc_TypeError, msg.c_str()); + return {nullptr, nullptr}; + } + + const type_info *typeinfo = nullptr; + const std::type_info *cpptype = nullptr; + void *value = nullptr; +}; + +/** + * Determine suitable casting operator for pointer-or-lvalue-casting type casters. The type caster + * needs to provide `operator T*()` and `operator T&()` operators. + * + * If the type supports moving the value away via an `operator T&&() &&` method, it should use + * `movable_cast_op_type` instead. + */ +template +using cast_op_type = + conditional_t>::value, + typename std::add_pointer>::type, + typename std::add_lvalue_reference>::type>; + +/** + * Determine suitable casting operator for a type caster with a movable value. Such a type caster + * needs to provide `operator T*()`, `operator T&()`, and `operator T&&() &&`. The latter will be + * called in appropriate contexts where the value can be moved rather than copied. + * + * These operator are automatically provided when using the PYBIND11_TYPE_CASTER macro. + */ +template +using movable_cast_op_type = + conditional_t::type>::value, + typename std::add_pointer>::type, + conditional_t::value, + typename std::add_rvalue_reference>::type, + typename std::add_lvalue_reference>::type>>; + +// std::is_copy_constructible isn't quite enough: it lets std::vector (and similar) through when +// T is non-copyable, but code containing such a copy constructor fails to actually compile. +template struct is_copy_constructible : std::is_copy_constructible {}; + +// Specialization for types that appear to be copy constructible but also look like stl containers +// (we specifically check for: has `value_type` and `reference` with `reference = value_type&`): if +// so, copy constructability depends on whether the value_type is copy constructible. +template struct is_copy_constructible, + std::is_same, + // Avoid infinite recursion + negation> + >::value>> : is_copy_constructible {}; + +// Likewise for std::pair +// (after C++17 it is mandatory that the copy constructor not exist when the two types aren't themselves +// copy constructible, but this can not be relied upon when T1 or T2 are themselves containers). +template struct is_copy_constructible> + : all_of, is_copy_constructible> {}; + +// The same problems arise with std::is_copy_assignable, so we use the same workaround. +template struct is_copy_assignable : std::is_copy_assignable {}; +template struct is_copy_assignable, + std::is_same + >::value>> : is_copy_assignable {}; +template struct is_copy_assignable> + : all_of, is_copy_assignable> {}; + +PYBIND11_NAMESPACE_END(detail) + +// polymorphic_type_hook::get(src, tinfo) determines whether the object pointed +// to by `src` actually is an instance of some class derived from `itype`. +// If so, it sets `tinfo` to point to the std::type_info representing that derived +// type, and returns a pointer to the start of the most-derived object of that type +// (in which `src` is a subobject; this will be the same address as `src` in most +// single inheritance cases). If not, or if `src` is nullptr, it simply returns `src` +// and leaves `tinfo` at its default value of nullptr. +// +// The default polymorphic_type_hook just returns src. A specialization for polymorphic +// types determines the runtime type of the passed object and adjusts the this-pointer +// appropriately via dynamic_cast. This is what enables a C++ Animal* to appear +// to Python as a Dog (if Dog inherits from Animal, Animal is polymorphic, Dog is +// registered with pybind11, and this Animal is in fact a Dog). +// +// You may specialize polymorphic_type_hook yourself for types that want to appear +// polymorphic to Python but do not use C++ RTTI. (This is a not uncommon pattern +// in performance-sensitive applications, used most notably in LLVM.) +// +// polymorphic_type_hook_base allows users to specialize polymorphic_type_hook with +// std::enable_if. User provided specializations will always have higher priority than +// the default implementation and specialization provided in polymorphic_type_hook_base. +template +struct polymorphic_type_hook_base +{ + static const void *get(const itype *src, const std::type_info*&) { return src; } +}; +template +struct polymorphic_type_hook_base::value>> +{ + static const void *get(const itype *src, const std::type_info*& type) { + type = src ? &typeid(*src) : nullptr; + return dynamic_cast(src); + } +}; +template +struct polymorphic_type_hook : public polymorphic_type_hook_base {}; + +PYBIND11_NAMESPACE_BEGIN(detail) + +/// Generic type caster for objects stored on the heap +template class type_caster_base : public type_caster_generic { + using itype = intrinsic_t; + +public: + static constexpr auto name = const_name(); + + type_caster_base() : type_caster_base(typeid(type)) { } + explicit type_caster_base(const std::type_info &info) : type_caster_generic(info) { } + + static handle cast(const itype &src, return_value_policy policy, handle parent) { + if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference) + policy = return_value_policy::copy; + return cast(&src, policy, parent); + } + + static handle cast(itype &&src, return_value_policy, handle parent) { + return cast(&src, return_value_policy::move, parent); + } + + // Returns a (pointer, type_info) pair taking care of necessary type lookup for a + // polymorphic type (using RTTI by default, but can be overridden by specializing + // polymorphic_type_hook). If the instance isn't derived, returns the base version. + static std::pair src_and_type(const itype *src) { + auto &cast_type = typeid(itype); + const std::type_info *instance_type = nullptr; + const void *vsrc = polymorphic_type_hook::get(src, instance_type); + if (instance_type && !same_type(cast_type, *instance_type)) { + // This is a base pointer to a derived type. If the derived type is registered + // with pybind11, we want to make the full derived object available. + // In the typical case where itype is polymorphic, we get the correct + // derived pointer (which may be != base pointer) by a dynamic_cast to + // most derived type. If itype is not polymorphic, we won't get here + // except via a user-provided specialization of polymorphic_type_hook, + // and the user has promised that no this-pointer adjustment is + // required in that case, so it's OK to use static_cast. + if (const auto *tpi = get_type_info(*instance_type)) + return {vsrc, tpi}; + } + // Otherwise we have either a nullptr, an `itype` pointer, or an unknown derived pointer, so + // don't do a cast + return type_caster_generic::src_and_type(src, cast_type, instance_type); + } + + static handle cast(const itype *src, return_value_policy policy, handle parent) { + auto st = src_and_type(src); + return type_caster_generic::cast( + st.first, policy, parent, st.second, + make_copy_constructor(src), make_move_constructor(src)); + } + + static handle cast_holder(const itype *src, const void *holder) { + auto st = src_and_type(src); + return type_caster_generic::cast( + st.first, return_value_policy::take_ownership, {}, st.second, + nullptr, nullptr, holder); + } + + template using cast_op_type = detail::cast_op_type; + + // NOLINTNEXTLINE(google-explicit-constructor) + operator itype*() { return (type *) value; } + // NOLINTNEXTLINE(google-explicit-constructor) + operator itype&() { if (!value) throw reference_cast_error(); return *((itype *) value); } + +protected: + using Constructor = void *(*)(const void *); + + /* Only enabled when the types are {copy,move}-constructible *and* when the type + does not have a private operator new implementation. A comma operator is used in the decltype + argument to apply SFINAE to the public copy/move constructors.*/ + template ::value>> + static auto make_copy_constructor(const T *) -> decltype(new T(std::declval()), Constructor{}) { + return [](const void *arg) -> void * { + return new T(*reinterpret_cast(arg)); + }; + } + + template ::value>> + static auto make_move_constructor(const T *) -> decltype(new T(std::declval()), Constructor{}) { + return [](const void *arg) -> void * { + return new T(std::move(*const_cast(reinterpret_cast(arg)))); + }; + } + + static Constructor make_copy_constructor(...) { return nullptr; } + static Constructor make_move_constructor(...) { return nullptr; } +}; + +PYBIND11_NAMESPACE_END(detail) +PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/wrap/pybind11/include/pybind11/detail/typeid.h b/wrap/pybind11/include/pybind11/detail/typeid.h index 148889ffe..39ba8ce0f 100644 --- a/wrap/pybind11/include/pybind11/detail/typeid.h +++ b/wrap/pybind11/include/pybind11/detail/typeid.h @@ -29,7 +29,7 @@ inline void erase_all(std::string &string, const std::string &search) { } } -PYBIND11_NOINLINE inline void clean_type_id(std::string &name) { +PYBIND11_NOINLINE void clean_type_id(std::string &name) { #if defined(__GNUG__) int status = 0; std::unique_ptr res { diff --git a/wrap/pybind11/include/pybind11/eigen.h b/wrap/pybind11/include/pybind11/eigen.h index 12ce9bd3e..696099fa6 100644 --- a/wrap/pybind11/include/pybind11/eigen.h +++ b/wrap/pybind11/include/pybind11/eigen.h @@ -9,33 +9,31 @@ #pragma once +/* HINT: To suppress warnings originating from the Eigen headers, use -isystem. + See also: + https://stackoverflow.com/questions/2579576/i-dir-vs-isystem-dir + https://stackoverflow.com/questions/1741816/isystem-for-ms-visual-studio-c-compiler +*/ + #include "numpy.h" -#if defined(__INTEL_COMPILER) -# pragma warning(disable: 1682) // implicit conversion of a 64-bit integral type to a smaller integral type (potential portability problem) -#elif defined(__GNUG__) || defined(__clang__) -# pragma GCC diagnostic push -# pragma GCC diagnostic ignored "-Wconversion" -# pragma GCC diagnostic ignored "-Wdeprecated-declarations" -# ifdef __clang__ -// Eigen generates a bunch of implicit-copy-constructor-is-deprecated warnings with -Wdeprecated -// under Clang, so disable that warning here: -# pragma GCC diagnostic ignored "-Wdeprecated" -# endif -# if __GNUC__ >= 7 -# pragma GCC diagnostic ignored "-Wint-in-bool-context" -# endif -#endif - +// The C4127 suppression was introduced for Eigen 3.4.0. In theory we could +// make it version specific, or even remove it later, but considering that +// 1. C4127 is generally far more distracting than useful for modern template code, and +// 2. we definitely want to ignore any MSVC warnings originating from Eigen code, +// it is probably best to keep this around indefinitely. #if defined(_MSC_VER) # pragma warning(push) -# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant -# pragma warning(disable: 4996) // warning C4996: std::unary_negate is deprecated in C++17 +# pragma warning(disable: 4127) // C4127: conditional expression is constant #endif #include #include +#if defined(_MSC_VER) +# pragma warning(pop) +#endif + // Eigen prior to 3.2.7 doesn't have proper move constructors--but worse, some classes get implicit // move constructors that break things. We could detect this an explicitly copy, but an extra copy // of matrices seems highly undesirable. @@ -52,8 +50,12 @@ PYBIND11_NAMESPACE_BEGIN(detail) #if EIGEN_VERSION_AT_LEAST(3,3,0) using EigenIndex = Eigen::Index; +template +using EigenMapSparseMatrix = Eigen::Map>; #else using EigenIndex = EIGEN_DEFAULT_DENSE_INDEX_TYPE; +template +using EigenMapSparseMatrix = Eigen::MappedSparseMatrix; #endif // Matches Eigen::Map, Eigen::Ref, blocks, etc: @@ -77,18 +79,17 @@ template struct EigenConformable { EigenDStride stride{0, 0}; // Only valid if negativestrides is false! bool negativestrides = false; // If true, do not use stride! + // NOLINTNEXTLINE(google-explicit-constructor) EigenConformable(bool fits = false) : conformable{fits} {} // Matrix type: EigenConformable(EigenIndex r, EigenIndex c, EigenIndex rstride, EigenIndex cstride) : - conformable{true}, rows{r}, cols{c} { - // TODO: when Eigen bug #747 is fixed, remove the tests for non-negativity. http://eigen.tuxfamily.org/bz/show_bug.cgi?id=747 - if (rstride < 0 || cstride < 0) { - negativestrides = true; - } else { - stride = {EigenRowMajor ? rstride : cstride /* outer stride */, - EigenRowMajor ? cstride : rstride /* inner stride */ }; - } + conformable{true}, rows{r}, cols{c}, + //TODO: when Eigen bug #747 is fixed, remove the tests for non-negativity. http://eigen.tuxfamily.org/bz/show_bug.cgi?id=747 + stride{EigenRowMajor ? (rstride > 0 ? rstride : 0) : (cstride > 0 ? cstride : 0) /* outer stride */, + EigenRowMajor ? (cstride > 0 ? cstride : 0) : (rstride > 0 ? rstride : 0) /* inner stride */ }, + negativestrides{rstride < 0 || cstride < 0} { + } // Vector type: EigenConformable(EigenIndex r, EigenIndex c, EigenIndex stride) @@ -104,6 +105,7 @@ template struct EigenConformable { (props::outer_stride == Eigen::Dynamic || props::outer_stride == stride.outer() || (EigenRowMajor ? rows : cols) == 1); } + // NOLINTNEXTLINE(google-explicit-constructor) operator bool() const { return conformable; } }; @@ -153,7 +155,8 @@ template struct EigenProps { np_cols = a.shape(1), np_rstride = a.strides(0) / static_cast(sizeof(Scalar)), np_cstride = a.strides(1) / static_cast(sizeof(Scalar)); - if ((fixed_rows && np_rows != rows) || (fixed_cols && np_cols != cols)) + if ((PYBIND11_SILENCE_MSVC_C4127(fixed_rows) && np_rows != rows) || + (PYBIND11_SILENCE_MSVC_C4127(fixed_cols) && np_cols != cols)) return false; return {np_rows, np_cols, np_rstride, np_cstride}; @@ -165,25 +168,22 @@ template struct EigenProps { stride = a.strides(0) / static_cast(sizeof(Scalar)); if (vector) { // Eigen type is a compile-time vector - if (fixed && size != n) + if (PYBIND11_SILENCE_MSVC_C4127(fixed) && size != n) return false; // Vector size mismatch return {rows == 1 ? 1 : n, cols == 1 ? 1 : n, stride}; } - else if (fixed) { + if (fixed) { // The type has a fixed size, but is not a vector: abort return false; } - else if (fixed_cols) { + if (fixed_cols) { // Since this isn't a vector, cols must be != 1. We allow this only if it exactly // equals the number of elements (rows is Dynamic, and so 1 row is allowed). if (cols != n) return false; return {1, n, stride}; - } - else { - // Otherwise it's either fully dynamic, or column dynamic; both become a column vector - if (fixed_rows && rows != n) return false; + } // Otherwise it's either fully dynamic, or column dynamic; both become a column vector + if (PYBIND11_SILENCE_MSVC_C4127(fixed_rows) && rows != n) return false; return {n, 1, stride}; - } } static constexpr bool show_writeable = is_eigen_dense_map::value && is_eigen_mutable_map::value; @@ -192,20 +192,20 @@ template struct EigenProps { static constexpr bool show_f_contiguous = !show_c_contiguous && show_order && requires_col_major; static constexpr auto descriptor = - _("numpy.ndarray[") + npy_format_descriptor::name + - _("[") + _(_<(size_t) rows>(), _("m")) + - _(", ") + _(_<(size_t) cols>(), _("n")) + - _("]") + + const_name("numpy.ndarray[") + npy_format_descriptor::name + + const_name("[") + const_name(const_name<(size_t) rows>(), const_name("m")) + + const_name(", ") + const_name(const_name<(size_t) cols>(), const_name("n")) + + const_name("]") + // For a reference type (e.g. Ref) we have other constraints that might need to be // satisfied: writeable=True (for a mutable reference), and, depending on the map's stride // options, possibly f_contiguous or c_contiguous. We include them in the descriptor output // to provide some hint as to why a TypeError is occurring (otherwise it can be confusing to // see that a function accepts a 'numpy.ndarray[float64[3,2]]' and an error message that you // *gave* a numpy.ndarray of the right type and dimensions. - _(", flags.writeable", "") + - _(", flags.c_contiguous", "") + - _(", flags.f_contiguous", "") + - _("]"); + const_name(", flags.writeable", "") + + const_name(", flags.c_contiguous", "") + + const_name(", flags.f_contiguous", "") + + const_name("]"); }; // Casts an Eigen type to numpy array. If given a base, the numpy array references the src data, @@ -344,8 +344,11 @@ public: static constexpr auto name = props::descriptor; + // NOLINTNEXTLINE(google-explicit-constructor) operator Type*() { return &value; } + // NOLINTNEXTLINE(google-explicit-constructor) operator Type&() { return value; } + // NOLINTNEXTLINE(google-explicit-constructor) operator Type&&() && { return std::move(value); } template using cast_op_type = movable_cast_op_type; @@ -432,7 +435,7 @@ public: if (!need_copy) { // We don't need a converting copy, but we also need to check whether the strides are // compatible with the Ref's stride requirements - Array aref = reinterpret_borrow(src); + auto aref = reinterpret_borrow(src); if (aref && (!need_writeable || aref.writeable())) { fits = props::conformable(aref); @@ -469,7 +472,9 @@ public: return true; } + // NOLINTNEXTLINE(google-explicit-constructor) operator Type*() { return ref.get(); } + // NOLINTNEXTLINE(google-explicit-constructor) operator Type&() { return *ref; } template using cast_op_type = pybind11::detail::cast_op_type<_T>; @@ -539,9 +544,9 @@ public: template struct type_caster::value>> { - typedef typename Type::Scalar Scalar; - typedef remove_reference_t().outerIndexPtr())> StorageIndex; - typedef typename Type::Index Index; + using Scalar = typename Type::Scalar; + using StorageIndex = remove_reference_t().outerIndexPtr())>; + using Index = typename Type::Index; static constexpr bool rowMajor = Type::IsRowMajor; bool load(handle src, bool) { @@ -549,7 +554,7 @@ struct type_caster::value>> { return false; auto obj = reinterpret_borrow(src); - object sparse_module = module::import("scipy.sparse"); + object sparse_module = module_::import("scipy.sparse"); object matrix_type = sparse_module.attr( rowMajor ? "csr_matrix" : "csc_matrix"); @@ -570,7 +575,9 @@ struct type_caster::value>> { if (!values || !innerIndices || !outerIndices) return false; - value = Eigen::MappedSparseMatrix( + value = EigenMapSparseMatrix( shape[0].cast(), shape[1].cast(), nnz, outerIndices.mutable_data(), innerIndices.mutable_data(), values.mutable_data()); @@ -580,7 +587,7 @@ struct type_caster::value>> { static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) { const_cast(src).makeCompressed(); - object matrix_type = module::import("scipy.sparse").attr( + object matrix_type = module_::import("scipy.sparse").attr( rowMajor ? "csr_matrix" : "csc_matrix"); array data(src.nonZeros(), src.valuePtr()); @@ -593,15 +600,9 @@ struct type_caster::value>> { ).release(); } - PYBIND11_TYPE_CASTER(Type, _<(Type::IsRowMajor) != 0>("scipy.sparse.csr_matrix[", "scipy.sparse.csc_matrix[") - + npy_format_descriptor::name + _("]")); + PYBIND11_TYPE_CASTER(Type, const_name<(Type::IsRowMajor) != 0>("scipy.sparse.csr_matrix[", "scipy.sparse.csc_matrix[") + + npy_format_descriptor::name + const_name("]")); }; PYBIND11_NAMESPACE_END(detail) PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE) - -#if defined(__GNUG__) || defined(__clang__) -# pragma GCC diagnostic pop -#elif defined(_MSC_VER) -# pragma warning(pop) -#endif diff --git a/wrap/pybind11/include/pybind11/embed.h b/wrap/pybind11/include/pybind11/embed.h index eae86c714..9ab1ce9c0 100644 --- a/wrap/pybind11/include/pybind11/embed.h +++ b/wrap/pybind11/include/pybind11/embed.h @@ -12,6 +12,9 @@ #include "pybind11.h" #include "eval.h" +#include +#include + #if defined(PYPY_VERSION) # error Embedding the interpreter is not supported with PyPy #endif @@ -45,27 +48,23 @@ }); } \endrst */ -#define PYBIND11_EMBEDDED_MODULE(name, variable) \ - static void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &); \ - static PyObject PYBIND11_CONCAT(*pybind11_init_wrapper_, name)() { \ - auto m = pybind11::module(PYBIND11_TOSTRING(name)); \ - try { \ - PYBIND11_CONCAT(pybind11_init_, name)(m); \ - return m.ptr(); \ - } catch (pybind11::error_already_set &e) { \ - PyErr_SetString(PyExc_ImportError, e.what()); \ - return nullptr; \ - } catch (const std::exception &e) { \ - PyErr_SetString(PyExc_ImportError, e.what()); \ - return nullptr; \ - } \ - } \ - PYBIND11_EMBEDDED_MODULE_IMPL(name) \ - pybind11::detail::embedded_module PYBIND11_CONCAT(pybind11_module_, name) \ - (PYBIND11_TOSTRING(name), \ - PYBIND11_CONCAT(pybind11_init_impl_, name)); \ - void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &variable) - +#define PYBIND11_EMBEDDED_MODULE(name, variable) \ + static ::pybind11::module_::module_def PYBIND11_CONCAT(pybind11_module_def_, name); \ + static void PYBIND11_CONCAT(pybind11_init_, name)(::pybind11::module_ &); \ + static PyObject PYBIND11_CONCAT(*pybind11_init_wrapper_, name)() { \ + auto m = ::pybind11::module_::create_extension_module( \ + PYBIND11_TOSTRING(name), nullptr, &PYBIND11_CONCAT(pybind11_module_def_, name)); \ + try { \ + PYBIND11_CONCAT(pybind11_init_, name)(m); \ + return m.ptr(); \ + } \ + PYBIND11_CATCH_INIT_EXCEPTIONS \ + } \ + PYBIND11_EMBEDDED_MODULE_IMPL(name) \ + ::pybind11::detail::embedded_module PYBIND11_CONCAT(pybind11_module_, name)( \ + PYBIND11_TOSTRING(name), PYBIND11_CONCAT(pybind11_init_impl_, name)); \ + void PYBIND11_CONCAT(pybind11_init_, name)(::pybind11::module_ \ + & variable) // NOLINT(bugprone-macro-parentheses) PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) PYBIND11_NAMESPACE_BEGIN(detail) @@ -78,7 +77,7 @@ struct embedded_module { using init_t = void (*)(); #endif embedded_module(const char *name, init_t init) { - if (Py_IsInitialized()) + if (Py_IsInitialized() != 0) pybind11_fail("Can't add new modules after the interpreter has been initialized"); auto result = PyImport_AppendInittab(name, init); @@ -87,29 +86,118 @@ struct embedded_module { } }; +struct wide_char_arg_deleter { + void operator()(wchar_t *ptr) const { +#if PY_VERSION_HEX >= 0x030500f0 + // API docs: https://docs.python.org/3/c-api/sys.html#c.Py_DecodeLocale + PyMem_RawFree(ptr); +#else + delete[] ptr; +#endif + } +}; + +inline wchar_t *widen_chars(const char *safe_arg) { +#if PY_VERSION_HEX >= 0x030500f0 + wchar_t *widened_arg = Py_DecodeLocale(safe_arg, nullptr); +#else + wchar_t *widened_arg = nullptr; + +// warning C4996: 'mbstowcs': This function or variable may be unsafe. +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable:4996) +#endif + +# if defined(HAVE_BROKEN_MBSTOWCS) && HAVE_BROKEN_MBSTOWCS + size_t count = std::strlen(safe_arg); +# else + size_t count = std::mbstowcs(nullptr, safe_arg, 0); +# endif + if (count != static_cast(-1)) { + widened_arg = new wchar_t[count + 1]; + std::mbstowcs(widened_arg, safe_arg, count + 1); + } + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + +#endif + return widened_arg; +} + +/// Python 2.x/3.x-compatible version of `PySys_SetArgv` +inline void set_interpreter_argv(int argc, const char *const *argv, bool add_program_dir_to_path) { + // Before it was special-cased in python 3.8, passing an empty or null argv + // caused a segfault, so we have to reimplement the special case ourselves. + bool special_case = (argv == nullptr || argc <= 0); + + const char *const empty_argv[]{"\0"}; + const char *const *safe_argv = special_case ? empty_argv : argv; + if (special_case) + argc = 1; + + auto argv_size = static_cast(argc); +#if PY_MAJOR_VERSION >= 3 + // SetArgv* on python 3 takes wchar_t, so we have to convert. + std::unique_ptr widened_argv(new wchar_t *[argv_size]); + std::vector> widened_argv_entries; + widened_argv_entries.reserve(argv_size); + for (size_t ii = 0; ii < argv_size; ++ii) { + widened_argv_entries.emplace_back(widen_chars(safe_argv[ii])); + if (!widened_argv_entries.back()) { + // A null here indicates a character-encoding failure or the python + // interpreter out of memory. Give up. + return; + } + widened_argv[ii] = widened_argv_entries.back().get(); + } + + auto pysys_argv = widened_argv.get(); +#else + // python 2.x + std::vector strings{safe_argv, safe_argv + argv_size}; + std::vector char_strings{argv_size}; + for (std::size_t i = 0; i < argv_size; ++i) + char_strings[i] = &strings[i][0]; + char **pysys_argv = char_strings.data(); +#endif + + PySys_SetArgvEx(argc, pysys_argv, static_cast(add_program_dir_to_path)); +} + PYBIND11_NAMESPACE_END(detail) /** \rst Initialize the Python interpreter. No other pybind11 or CPython API functions can be called before this is done; with the exception of `PYBIND11_EMBEDDED_MODULE`. The - optional parameter can be used to skip the registration of signal handlers (see the - `Python documentation`_ for details). Calling this function again after the interpreter - has already been initialized is a fatal error. + optional `init_signal_handlers` parameter can be used to skip the registration of + signal handlers (see the `Python documentation`_ for details). Calling this function + again after the interpreter has already been initialized is a fatal error. If initializing the Python interpreter fails, then the program is terminated. (This is controlled by the CPython runtime and is an exception to pybind11's normal behavior of throwing exceptions on errors.) + The remaining optional parameters, `argc`, `argv`, and `add_program_dir_to_path` are + used to populate ``sys.argv`` and ``sys.path``. + See the |PySys_SetArgvEx documentation|_ for details. + .. _Python documentation: https://docs.python.org/3/c-api/init.html#c.Py_InitializeEx + .. |PySys_SetArgvEx documentation| replace:: ``PySys_SetArgvEx`` documentation + .. _PySys_SetArgvEx documentation: https://docs.python.org/3/c-api/init.html#c.PySys_SetArgvEx \endrst */ -inline void initialize_interpreter(bool init_signal_handlers = true) { - if (Py_IsInitialized()) +inline void initialize_interpreter(bool init_signal_handlers = true, + int argc = 0, + const char *const *argv = nullptr, + bool add_program_dir_to_path = true) { + if (Py_IsInitialized() != 0) pybind11_fail("The interpreter is already running"); Py_InitializeEx(init_signal_handlers ? 1 : 0); - // Make .py files in the working directory available by default - module::import("sys").attr("path").cast().append("."); + detail::set_interpreter_argv(argc, argv, add_program_dir_to_path); } /** \rst @@ -171,6 +259,8 @@ inline void finalize_interpreter() { Scope guard version of `initialize_interpreter` and `finalize_interpreter`. This a move-only guard and only a single instance can exist. + See `initialize_interpreter` for a discussion of its constructor arguments. + .. code-block:: cpp #include @@ -182,8 +272,11 @@ inline void finalize_interpreter() { \endrst */ class scoped_interpreter { public: - scoped_interpreter(bool init_signal_handlers = true) { - initialize_interpreter(init_signal_handlers); + explicit scoped_interpreter(bool init_signal_handlers = true, + int argc = 0, + const char *const *argv = nullptr, + bool add_program_dir_to_path = true) { + initialize_interpreter(init_signal_handlers, argc, argv, add_program_dir_to_path); } scoped_interpreter(const scoped_interpreter &) = delete; diff --git a/wrap/pybind11/include/pybind11/eval.h b/wrap/pybind11/include/pybind11/eval.h index ba82cf42a..4248551e9 100644 --- a/wrap/pybind11/include/pybind11/eval.h +++ b/wrap/pybind11/include/pybind11/eval.h @@ -1,5 +1,5 @@ /* - pybind11/exec.h: Support for evaluating Python expressions and statements + pybind11/eval.h: Support for evaluating Python expressions and statements from strings and files Copyright (c) 2016 Klemens Morgenstern and @@ -11,9 +11,27 @@ #pragma once +#include + #include "pybind11.h" PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +PYBIND11_NAMESPACE_BEGIN(detail) + +inline void ensure_builtins_in_globals(object &global) { + #if defined(PYPY_VERSION) || PY_VERSION_HEX < 0x03080000 + // Running exec and eval on Python 2 and 3 adds `builtins` module under + // `__builtins__` key to globals if not yet present. + // Python 3.8 made PyRun_String behave similarly. Let's also do that for + // older versions, for consistency. This was missing from PyPy3.8 7.3.7. + if (!global.contains("__builtins__")) + global["__builtins__"] = module_::import(PYBIND11_BUILTINS_MODULE); + #else + (void) global; + #endif +} + +PYBIND11_NAMESPACE_END(detail) enum eval_mode { /// Evaluate a string containing an isolated expression @@ -27,15 +45,17 @@ enum eval_mode { }; template -object eval(str expr, object global = globals(), object local = object()) { +object eval(const str &expr, object global = globals(), object local = object()) { if (!local) local = global; + detail::ensure_builtins_in_globals(global); + /* PyRun_String does not accept a PyObject / encoding specifier, this seems to be the only alternative */ std::string buffer = "# -*- coding: utf-8 -*-\n" + (std::string) expr; - int start; + int start = 0; switch (mode) { case eval_expr: start = Py_eval_input; break; case eval_single_statement: start = Py_single_input; break; @@ -52,13 +72,13 @@ object eval(str expr, object global = globals(), object local = object()) { template object eval(const char (&s)[N], object global = globals(), object local = object()) { /* Support raw string literals by removing common leading whitespace */ - auto expr = (s[0] == '\n') ? str(module::import("textwrap").attr("dedent")(s)) + auto expr = (s[0] == '\n') ? str(module_::import("textwrap").attr("dedent")(s)) : str(s); return eval(expr, global, local); } -inline void exec(str expr, object global = globals(), object local = object()) { - eval(expr, global, local); +inline void exec(const str &expr, object global = globals(), object local = object()) { + eval(expr, std::move(global), std::move(local)); } template @@ -66,7 +86,7 @@ void exec(const char (&s)[N], object global = globals(), object local = object() eval(s, global, local); } -#if defined(PYPY_VERSION) && PY_VERSION_HEX >= 0x3000000 +#if defined(PYPY_VERSION) && PY_VERSION_HEX >= 0x03000000 template object eval_file(str, object, object) { pybind11_fail("eval_file not supported in PyPy3. Use eval"); @@ -85,7 +105,9 @@ object eval_file(str fname, object global = globals(), object local = object()) if (!local) local = global; - int start; + detail::ensure_builtins_in_globals(global); + + int start = 0; switch (mode) { case eval_expr: start = Py_eval_input; break; case eval_single_statement: start = Py_single_input; break; @@ -114,6 +136,15 @@ object eval_file(str fname, object global = globals(), object local = object()) pybind11_fail("File \"" + fname_str + "\" could not be opened!"); } + // In Python2, this should be encoded by getfilesystemencoding. + // We don't boher setting it since Python2 is past EOL anyway. + // See PR#3233 +#if PY_VERSION_HEX >= 0x03000000 + if (!global.contains("__file__")) { + global["__file__"] = std::move(fname); + } +#endif + #if PY_VERSION_HEX < 0x03000000 && defined(PYPY_VERSION) PyObject *result = PyRun_File(f, fname_str.c_str(), start, global.ptr(), local.ptr()); diff --git a/wrap/pybind11/include/pybind11/functional.h b/wrap/pybind11/include/pybind11/functional.h index 57b6cd210..7912aef17 100644 --- a/wrap/pybind11/include/pybind11/functional.h +++ b/wrap/pybind11/include/pybind11/functional.h @@ -43,22 +43,43 @@ public: captured variables), in which case the roundtrip can be avoided. */ if (auto cfunc = func.cpp_function()) { - auto c = reinterpret_borrow(PyCFunction_GET_SELF(cfunc.ptr())); - auto rec = (function_record *) c; + auto cfunc_self = PyCFunction_GET_SELF(cfunc.ptr()); + if (isinstance(cfunc_self)) { + auto c = reinterpret_borrow(cfunc_self); + auto rec = (function_record *) c; - if (rec && rec->is_stateless && - same_type(typeid(function_type), *reinterpret_cast(rec->data[1]))) { - struct capture { function_type f; }; - value = ((capture *) &rec->data)->f; - return true; + while (rec != nullptr) { + if (rec->is_stateless + && same_type(typeid(function_type), + *reinterpret_cast(rec->data[1]))) { + struct capture { + function_type f; + }; + value = ((capture *) &rec->data)->f; + return true; + } + rec = rec->next; + } } + // PYPY segfaults here when passing builtin function like sum. + // Raising an fail exception here works to prevent the segfault, but only on gcc. + // See PR #1413 for full details } // ensure GIL is held during functor destruction struct func_handle { function f; - func_handle(function&& f_) : f(std::move(f_)) {} - func_handle(const func_handle&) = default; +#if !(defined(_MSC_VER) && _MSC_VER == 1916 && defined(PYBIND11_CPP17)) + // This triggers a syntax error under very special conditions (very weird indeed). + explicit +#endif + func_handle(function &&f_) noexcept : f(std::move(f_)) {} + func_handle(const func_handle &f_) { operator=(f_); } + func_handle &operator=(const func_handle &f_) { + gil_scoped_acquire acq; + f = f_.f; + return *this; + } ~func_handle() { gil_scoped_acquire acq; function kill_f(std::move(f)); @@ -68,7 +89,7 @@ public: // to emulate 'move initialization capture' in C++11 struct func_wrapper { func_handle hfunc; - func_wrapper(func_handle&& hf): hfunc(std::move(hf)) {} + explicit func_wrapper(func_handle &&hf) noexcept : hfunc(std::move(hf)) {} Return operator()(Args... args) const { gil_scoped_acquire acq; object retval(hfunc.f(std::forward(args)...)); @@ -89,12 +110,11 @@ public: auto result = f_.template target(); if (result) return cpp_function(*result, policy).release(); - else - return cpp_function(std::forward(f_), policy).release(); + return cpp_function(std::forward(f_), policy).release(); } - PYBIND11_TYPE_CASTER(type, _("Callable[[") + concat(make_caster::name...) + _("], ") - + make_caster::name + _("]")); + PYBIND11_TYPE_CASTER(type, const_name("Callable[[") + concat(make_caster::name...) + const_name("], ") + + make_caster::name + const_name("]")); }; PYBIND11_NAMESPACE_END(detail) diff --git a/wrap/pybind11/include/pybind11/gil.h b/wrap/pybind11/include/pybind11/gil.h new file mode 100644 index 000000000..b73aaa3f5 --- /dev/null +++ b/wrap/pybind11/include/pybind11/gil.h @@ -0,0 +1,193 @@ +/* + pybind11/gil.h: RAII helpers for managing the GIL + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "detail/common.h" +#include "detail/internals.h" + +PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) + + +PYBIND11_NAMESPACE_BEGIN(detail) + +// forward declarations +PyThreadState *get_thread_state_unchecked(); + +PYBIND11_NAMESPACE_END(detail) + + +#if defined(WITH_THREAD) && !defined(PYPY_VERSION) + +/* The functions below essentially reproduce the PyGILState_* API using a RAII + * pattern, but there are a few important differences: + * + * 1. When acquiring the GIL from an non-main thread during the finalization + * phase, the GILState API blindly terminates the calling thread, which + * is often not what is wanted. This API does not do this. + * + * 2. The gil_scoped_release function can optionally cut the relationship + * of a PyThreadState and its associated thread, which allows moving it to + * another thread (this is a fairly rare/advanced use case). + * + * 3. The reference count of an acquired thread state can be controlled. This + * can be handy to prevent cases where callbacks issued from an external + * thread would otherwise constantly construct and destroy thread state data + * structures. + * + * See the Python bindings of NanoGUI (http://github.com/wjakob/nanogui) for an + * example which uses features 2 and 3 to migrate the Python thread of + * execution to another thread (to run the event loop on the original thread, + * in this case). + */ + +class gil_scoped_acquire { +public: + PYBIND11_NOINLINE gil_scoped_acquire() { + auto &internals = detail::get_internals(); + tstate = (PyThreadState *) PYBIND11_TLS_GET_VALUE(internals.tstate); + + if (!tstate) { + /* Check if the GIL was acquired using the PyGILState_* API instead (e.g. if + calling from a Python thread). Since we use a different key, this ensures + we don't create a new thread state and deadlock in PyEval_AcquireThread + below. Note we don't save this state with internals.tstate, since we don't + create it we would fail to clear it (its reference count should be > 0). */ + tstate = PyGILState_GetThisThreadState(); + } + + if (!tstate) { + tstate = PyThreadState_New(internals.istate); + #if !defined(NDEBUG) + if (!tstate) + pybind11_fail("scoped_acquire: could not create thread state!"); + #endif + tstate->gilstate_counter = 0; + PYBIND11_TLS_REPLACE_VALUE(internals.tstate, tstate); + } else { + release = detail::get_thread_state_unchecked() != tstate; + } + + if (release) { + PyEval_AcquireThread(tstate); + } + + inc_ref(); + } + + void inc_ref() { + ++tstate->gilstate_counter; + } + + PYBIND11_NOINLINE void dec_ref() { + --tstate->gilstate_counter; + #if !defined(NDEBUG) + if (detail::get_thread_state_unchecked() != tstate) + pybind11_fail("scoped_acquire::dec_ref(): thread state must be current!"); + if (tstate->gilstate_counter < 0) + pybind11_fail("scoped_acquire::dec_ref(): reference count underflow!"); + #endif + if (tstate->gilstate_counter == 0) { + #if !defined(NDEBUG) + if (!release) + pybind11_fail("scoped_acquire::dec_ref(): internal error!"); + #endif + PyThreadState_Clear(tstate); + if (active) + PyThreadState_DeleteCurrent(); + PYBIND11_TLS_DELETE_VALUE(detail::get_internals().tstate); + release = false; + } + } + + /// This method will disable the PyThreadState_DeleteCurrent call and the + /// GIL won't be acquired. This method should be used if the interpreter + /// could be shutting down when this is called, as thread deletion is not + /// allowed during shutdown. Check _Py_IsFinalizing() on Python 3.7+, and + /// protect subsequent code. + PYBIND11_NOINLINE void disarm() { + active = false; + } + + PYBIND11_NOINLINE ~gil_scoped_acquire() { + dec_ref(); + if (release) + PyEval_SaveThread(); + } +private: + PyThreadState *tstate = nullptr; + bool release = true; + bool active = true; +}; + +class gil_scoped_release { +public: + explicit gil_scoped_release(bool disassoc = false) : disassoc(disassoc) { + // `get_internals()` must be called here unconditionally in order to initialize + // `internals.tstate` for subsequent `gil_scoped_acquire` calls. Otherwise, an + // initialization race could occur as multiple threads try `gil_scoped_acquire`. + auto &internals = detail::get_internals(); + tstate = PyEval_SaveThread(); + if (disassoc) { + auto key = internals.tstate; + PYBIND11_TLS_DELETE_VALUE(key); + } + } + + /// This method will disable the PyThreadState_DeleteCurrent call and the + /// GIL won't be acquired. This method should be used if the interpreter + /// could be shutting down when this is called, as thread deletion is not + /// allowed during shutdown. Check _Py_IsFinalizing() on Python 3.7+, and + /// protect subsequent code. + PYBIND11_NOINLINE void disarm() { + active = false; + } + + ~gil_scoped_release() { + if (!tstate) + return; + // `PyEval_RestoreThread()` should not be called if runtime is finalizing + if (active) + PyEval_RestoreThread(tstate); + if (disassoc) { + auto key = detail::get_internals().tstate; + PYBIND11_TLS_REPLACE_VALUE(key, tstate); + } + } +private: + PyThreadState *tstate; + bool disassoc; + bool active = true; +}; +#elif defined(PYPY_VERSION) +class gil_scoped_acquire { + PyGILState_STATE state; +public: + gil_scoped_acquire() { state = PyGILState_Ensure(); } + ~gil_scoped_acquire() { PyGILState_Release(state); } + void disarm() {} +}; + +class gil_scoped_release { + PyThreadState *state; +public: + gil_scoped_release() { state = PyEval_SaveThread(); } + ~gil_scoped_release() { PyEval_RestoreThread(state); } + void disarm() {} +}; +#else +class gil_scoped_acquire { + void disarm() {} +}; +class gil_scoped_release { + void disarm() {} +}; +#endif + +PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/wrap/pybind11/include/pybind11/iostream.h b/wrap/pybind11/include/pybind11/iostream.h index 48479f2d1..95449a07b 100644 --- a/wrap/pybind11/include/pybind11/iostream.h +++ b/wrap/pybind11/include/pybind11/iostream.h @@ -5,17 +5,31 @@ All rights reserved. Use of this source code is governed by a BSD-style license that can be found in the LICENSE file. + + WARNING: The implementation in this file is NOT thread safe. Multiple + threads writing to a redirected ostream concurrently cause data races + and potentially buffer overflows. Therefore it is currently a requirement + that all (possibly) concurrent redirected ostream writes are protected by + a mutex. + #HelpAppreciated: Work on iostream.h thread safety. + For more background see the discussions under + https://github.com/pybind/pybind11/pull/2982 and + https://github.com/pybind/pybind11/pull/2995. */ #pragma once #include "pybind11.h" -#include -#include -#include -#include +#include +#include #include +#include +#include +#include +#include +#include +#include PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) PYBIND11_NAMESPACE_BEGIN(detail) @@ -38,21 +52,68 @@ private: return sync() == 0 ? traits_type::not_eof(c) : traits_type::eof(); } - // This function must be non-virtual to be called in a destructor. If the - // rare MSVC test failure shows up with this version, then this should be - // simplified to a fully qualified call. - int _sync() { - if (pbase() != pptr()) { - // This subtraction cannot be negative, so dropping the sign - str line(pbase(), static_cast(pptr() - pbase())); + // Computes how many bytes at the end of the buffer are part of an + // incomplete sequence of UTF-8 bytes. + // Precondition: pbase() < pptr() + size_t utf8_remainder() const { + const auto rbase = std::reverse_iterator(pbase()); + const auto rpptr = std::reverse_iterator(pptr()); + auto is_ascii = [](char c) { + return (static_cast(c) & 0x80) == 0x00; + }; + auto is_leading = [](char c) { + return (static_cast(c) & 0xC0) == 0xC0; + }; + auto is_leading_2b = [](char c) { + return static_cast(c) <= 0xDF; + }; + auto is_leading_3b = [](char c) { + return static_cast(c) <= 0xEF; + }; + // If the last character is ASCII, there are no incomplete code points + if (is_ascii(*rpptr)) + return 0; + // Otherwise, work back from the end of the buffer and find the first + // UTF-8 leading byte + const auto rpend = rbase - rpptr >= 3 ? rpptr + 3 : rbase; + const auto leading = std::find_if(rpptr, rpend, is_leading); + if (leading == rbase) + return 0; + const auto dist = static_cast(leading - rpptr); + size_t remainder = 0; - { - gil_scoped_acquire tmp; + if (dist == 0) + remainder = 1; // 1-byte code point is impossible + else if (dist == 1) + remainder = is_leading_2b(*leading) ? 0 : dist + 1; + else if (dist == 2) + remainder = is_leading_3b(*leading) ? 0 : dist + 1; + // else if (dist >= 3), at least 4 bytes before encountering an UTF-8 + // leading byte, either no remainder or invalid UTF-8. + // Invalid UTF-8 will cause an exception later when converting + // to a Python string, so that's not handled here. + return remainder; + } + + // This function must be non-virtual to be called in a destructor. + int _sync() { + if (pbase() != pptr()) { // If buffer is not empty + gil_scoped_acquire tmp; + // This subtraction cannot be negative, so dropping the sign. + auto size = static_cast(pptr() - pbase()); + size_t remainder = utf8_remainder(); + + if (size > remainder) { + str line(pbase(), size - remainder); pywrite(line); pyflush(); } + // Copy the remainder at the end of the buffer to the beginning: + if (remainder > 0) + std::memmove(pbase(), pptr() - remainder, remainder); setp(pbase(), epptr()); + pbump(static_cast(remainder)); } return 0; } @@ -62,11 +123,8 @@ private: } public: - - pythonbuf(object pyostream, size_t buffer_size = 1024) - : buf_size(buffer_size), - d_buffer(new char[buf_size]), - pywrite(pyostream.attr("write")), + explicit pythonbuf(const object &pyostream, size_t buffer_size = 1024) + : buf_size(buffer_size), d_buffer(new char[buf_size]), pywrite(pyostream.attr("write")), pyflush(pyostream.attr("flush")) { setp(d_buffer.get(), d_buffer.get() + buf_size - 1); } @@ -103,7 +161,7 @@ PYBIND11_NAMESPACE_END(detail) { py::scoped_ostream_redirect output{std::cerr, py::module::import("sys").attr("stderr")}; - std::cerr << "Hello, World!"; + std::cout << "Hello, World!"; } \endrst */ class scoped_ostream_redirect { @@ -113,9 +171,9 @@ protected: detail::pythonbuf buffer; public: - scoped_ostream_redirect( - std::ostream &costream = std::cout, - object pyostream = module::import("sys").attr("stdout")) + explicit scoped_ostream_redirect(std::ostream &costream = std::cout, + const object &pyostream + = module_::import("sys").attr("stdout")) : costream(costream), buffer(pyostream) { old = costream.rdbuf(&buffer); } @@ -144,10 +202,10 @@ public: \endrst */ class scoped_estream_redirect : public scoped_ostream_redirect { public: - scoped_estream_redirect( - std::ostream &costream = std::cerr, - object pyostream = module::import("sys").attr("stderr")) - : scoped_ostream_redirect(costream,pyostream) {} + explicit scoped_estream_redirect(std::ostream &costream = std::cerr, + const object &pyostream + = module_::import("sys").attr("stderr")) + : scoped_ostream_redirect(costream, pyostream) {} }; @@ -161,7 +219,7 @@ class OstreamRedirect { std::unique_ptr redirect_stderr; public: - OstreamRedirect(bool do_stdout = true, bool do_stderr = true) + explicit OstreamRedirect(bool do_stdout = true, bool do_stderr = true) : do_stdout_(do_stdout), do_stderr_(do_stderr) {} void enter() { @@ -206,11 +264,12 @@ PYBIND11_NAMESPACE_END(detail) m.noisy_function_with_error_printing() \endrst */ -inline class_ add_ostream_redirect(module m, std::string name = "ostream_redirect") { - return class_(m, name.c_str(), module_local()) - .def(init(), arg("stdout")=true, arg("stderr")=true) +inline class_ +add_ostream_redirect(module_ m, const std::string &name = "ostream_redirect") { + return class_(std::move(m), name.c_str(), module_local()) + .def(init(), arg("stdout") = true, arg("stderr") = true) .def("__enter__", &detail::OstreamRedirect::enter) - .def("__exit__", [](detail::OstreamRedirect &self_, args) { self_.exit(); }); + .def("__exit__", [](detail::OstreamRedirect &self_, const args &) { self_.exit(); }); } PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/wrap/pybind11/include/pybind11/numpy.h b/wrap/pybind11/include/pybind11/numpy.h index 03e1ed61e..95a743ace 100644 --- a/wrap/pybind11/include/pybind11/numpy.h +++ b/wrap/pybind11/include/pybind11/numpy.h @@ -20,20 +20,18 @@ #include #include #include +#include #include #include #include -#if defined(_MSC_VER) -# pragma warning(push) -# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant -#endif - /* This will be true on all flat address space platforms and allows us to reduce the whole npy_intp / ssize_t / Py_intptr_t business down to just ssize_t for all size and dimension types (e.g. shape, strides, indexing), instead of inflicting this upon the library user. */ -static_assert(sizeof(ssize_t) == sizeof(Py_intptr_t), "ssize_t != Py_intptr_t"); +static_assert(sizeof(::pybind11::ssize_t) == sizeof(Py_intptr_t), "ssize_t != Py_intptr_t"); +static_assert(std::is_signed::value, "Py_intptr_t must be signed"); +// We now can reinterpret_cast between py::ssize_t and Py_intptr_t (MSVC + PyPy cares) PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) @@ -41,7 +39,7 @@ class array; // Forward declaration PYBIND11_NAMESPACE_BEGIN(detail) -template <> struct handle_type_name { static constexpr auto name = _("numpy.ndarray"); }; +template <> struct handle_type_name { static constexpr auto name = const_name("numpy.ndarray"); }; template struct npy_format_descriptor; @@ -101,7 +99,7 @@ struct numpy_internals { } }; -inline PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) { +PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) { ptr = &get_or_create_shared_data("_numpy_internals"); } @@ -161,10 +159,10 @@ struct npy_api { NPY_ULONG_, NPY_ULONGLONG_, NPY_UINT_), }; - typedef struct { + struct PyArray_Dims { Py_intptr_t *ptr; int len; - } PyArray_Dims; + }; static npy_api& get() { static npy_api api = lookup(); @@ -172,10 +170,10 @@ struct npy_api { } bool PyArray_Check_(PyObject *obj) const { - return (bool) PyObject_TypeCheck(obj, PyArray_Type_); + return PyObject_TypeCheck(obj, PyArray_Type_) != 0; } bool PyArrayDescr_Check_(PyObject *obj) const { - return (bool) PyObject_TypeCheck(obj, PyArrayDescr_Type_); + return PyObject_TypeCheck(obj, PyArrayDescr_Type_) != 0; } unsigned int (*PyArray_GetNDArrayCFeatureVersion_)(); @@ -200,6 +198,9 @@ struct npy_api { // Unused. Not removed because that affects ABI of the class. int (*PyArray_SetBaseObject_)(PyObject *, PyObject *); PyObject* (*PyArray_Resize_)(PyObject*, PyArray_Dims*, int, int); + PyObject* (*PyArray_Newshape_)(PyObject*, PyArray_Dims*, int); + PyObject* (*PyArray_View_)(PyObject*, PyObject*, PyObject*); + private: enum functions { API_PyArray_GetNDArrayCFeatureVersion = 211, @@ -214,15 +215,17 @@ private: API_PyArray_NewCopy = 85, API_PyArray_NewFromDescr = 94, API_PyArray_DescrNewFromType = 96, + API_PyArray_Newshape = 135, + API_PyArray_Squeeze = 136, + API_PyArray_View = 137, API_PyArray_DescrConverter = 174, API_PyArray_EquivTypes = 182, API_PyArray_GetArrayParamsFromObject = 278, - API_PyArray_Squeeze = 136, API_PyArray_SetBaseObject = 282 }; static npy_api lookup() { - module_ m = module::import("numpy.core.multiarray"); + module_ m = module_::import("numpy.core.multiarray"); auto c = m.attr("_ARRAY_API"); #if PY_MAJOR_VERSION >= 3 void **api_ptr = (void **) PyCapsule_GetPointer(c.ptr(), NULL); @@ -245,11 +248,14 @@ private: DECL_NPY_API(PyArray_NewCopy); DECL_NPY_API(PyArray_NewFromDescr); DECL_NPY_API(PyArray_DescrNewFromType); + DECL_NPY_API(PyArray_Newshape); + DECL_NPY_API(PyArray_Squeeze); + DECL_NPY_API(PyArray_View); DECL_NPY_API(PyArray_DescrConverter); DECL_NPY_API(PyArray_EquivTypes); DECL_NPY_API(PyArray_GetArrayParamsFromObject); - DECL_NPY_API(PyArray_Squeeze); DECL_NPY_API(PyArray_SetBaseObject); + #undef DECL_NPY_API return api; } @@ -284,7 +290,7 @@ template struct array_info_scalar { using type = T; static constexpr bool is_array = false; static constexpr bool is_empty = false; - static constexpr auto extents = _(""); + static constexpr auto extents = const_name(""); static void append_extents(list& /* shape */) { } }; // Computes underlying type and a comma-separated list of extents for array @@ -303,8 +309,8 @@ template struct array_info> { array_info::append_extents(shape); } - static constexpr auto extents = _::is_array>( - concat(_(), array_info::extents), _() + static constexpr auto extents = const_name::is_array>( + concat(const_name(), array_info::extents), const_name() ); }; // For numpy we have special handling for arrays of characters, so we don't include @@ -316,18 +322,23 @@ template using remove_all_extents_t = typename array_info::type; template using is_pod_struct = all_of< std::is_standard_layout, // since we're accessing directly in memory we need a standard layout type -#if !defined(__GNUG__) || defined(_LIBCPP_VERSION) || defined(_GLIBCXX_USE_CXX11_ABI) - // _GLIBCXX_USE_CXX11_ABI indicates that we're using libstdc++ from GCC 5 or newer, independent - // of the actual compiler (Clang can also use libstdc++, but it always defines __GNUC__ == 4). - std::is_trivially_copyable, -#else - // GCC 4 doesn't implement is_trivially_copyable, so approximate it +#if defined(__GLIBCXX__) && (__GLIBCXX__ < 20150422 || __GLIBCXX__ == 20150426 || __GLIBCXX__ == 20150623 || __GLIBCXX__ == 20150626 || __GLIBCXX__ == 20160803) + // libstdc++ < 5 (including versions 4.8.5, 4.9.3 and 4.9.4 which were released after 5) + // don't implement is_trivially_copyable, so approximate it std::is_trivially_destructible, satisfies_any_of, +#else + std::is_trivially_copyable, #endif satisfies_none_of >; +// Replacement for std::is_pod (deprecated in C++20) +template using is_pod = all_of< + std::is_standard_layout, + std::is_trivial +>; + template ssize_t byte_offset_unsafe(const Strides &) { return 0; } template ssize_t byte_offset_unsafe(const Strides &strides, ssize_t i, Ix... index) { @@ -419,6 +430,10 @@ class unchecked_mutable_reference : public unchecked_reference { using ConstBase::ConstBase; using ConstBase::Dynamic; public: + // Bring in const-qualified versions from base class + using ConstBase::operator(); + using ConstBase::operator[]; + /// Mutable, unchecked access to data at the given indices. template T& operator()(Ix... index) { static_assert(ssize_t{sizeof...(Ix)} == Dims || Dynamic, @@ -453,28 +468,30 @@ public: explicit dtype(const buffer_info &info) { dtype descr(_dtype_from_pep3118()(PYBIND11_STR_TYPE(info.format))); // If info.itemsize == 0, use the value calculated from the format string - m_ptr = descr.strip_padding(info.itemsize ? info.itemsize : descr.itemsize()).release().ptr(); + m_ptr = descr.strip_padding(info.itemsize != 0 ? info.itemsize : descr.itemsize()) + .release() + .ptr(); } explicit dtype(const std::string &format) { m_ptr = from_args(pybind11::str(format)).release().ptr(); } - dtype(const char *format) : dtype(std::string(format)) { } + explicit dtype(const char *format) : dtype(std::string(format)) {} dtype(list names, list formats, list offsets, ssize_t itemsize) { dict args; - args["names"] = names; - args["formats"] = formats; - args["offsets"] = offsets; + args["names"] = std::move(names); + args["formats"] = std::move(formats); + args["offsets"] = std::move(offsets); args["itemsize"] = pybind11::int_(itemsize); - m_ptr = from_args(args).release().ptr(); + m_ptr = from_args(std::move(args)).release().ptr(); } /// This is essentially the same as calling numpy.dtype(args) in Python. static dtype from_args(object args) { PyObject *ptr = nullptr; - if (!detail::npy_api::get().PyArray_DescrConverter_(args.ptr(), &ptr) || !ptr) + if ((detail::npy_api::get().PyArray_DescrConverter_(args.ptr(), &ptr) == 0) || !ptr) throw error_already_set(); return reinterpret_steal(ptr); } @@ -494,14 +511,24 @@ public: return detail::array_descriptor_proxy(m_ptr)->names != nullptr; } - /// Single-character type code. + /// Single-character code for dtype's kind. + /// For example, floating point types are 'f' and integral types are 'i'. char kind() const { return detail::array_descriptor_proxy(m_ptr)->kind; } + /// Single-character for dtype's type. + /// For example, ``float`` is 'f', ``double`` 'd', ``int`` 'i', and ``long`` 'l'. + char char_() const { + // Note: The signature, `dtype::char_` follows the naming of NumPy's + // public Python API (i.e., ``dtype.char``), rather than its internal + // C API (``PyArray_Descr::type``). + return detail::array_descriptor_proxy(m_ptr)->type; + } + private: static object _dtype_from_pep3118() { - static PyObject *obj = module::import("numpy.core._internal") + static PyObject *obj = module_::import("numpy.core._internal") .attr("_dtype_from_pep3118").cast().release().ptr(); return reinterpret_borrow(obj); } @@ -520,7 +547,7 @@ private: auto name = spec[0].cast(); auto format = spec[1].cast()[0].cast(); auto offset = spec[1].cast()[1].cast(); - if (!len(name) && format.kind() == 'V') + if ((len(name) == 0u) && format.kind() == 'V') continue; field_descriptors.push_back({(PYBIND11_STR_TYPE) name, format.strip_padding(format.itemsize()), offset}); } @@ -536,7 +563,7 @@ private: formats.append(descr.format); offsets.append(descr.offset); } - return dtype(names, formats, offsets, itemsize); + return dtype(std::move(names), std::move(formats), std::move(offsets), itemsize); } }; @@ -560,7 +587,7 @@ public: const void *ptr = nullptr, handle base = handle()) { if (strides->empty()) - *strides = c_strides(*shape, dt.itemsize()); + *strides = detail::c_strides(*shape, dt.itemsize()); auto ndim = shape->size(); if (ndim != strides->size()) @@ -579,7 +606,10 @@ public: auto &api = detail::npy_api::get(); auto tmp = reinterpret_steal(api.PyArray_NewFromDescr_( - api.PyArray_Type_, descr.release().ptr(), (int) ndim, shape->data(), strides->data(), + api.PyArray_Type_, descr.release().ptr(), (int) ndim, + // Use reinterpret_cast for PyPy on Windows (remove if fixed, checked on 7.3.1) + reinterpret_cast(shape->data()), + reinterpret_cast(strides->data()), const_cast(ptr), flags, nullptr)); if (!tmp) throw error_already_set(); @@ -720,7 +750,7 @@ public: * and the caller must take care not to access invalid dimensions or dimension indices. */ template detail::unchecked_mutable_reference mutable_unchecked() & { - if (Dims >= 0 && ndim() != Dims) + if (PYBIND11_SILENCE_MSVC_C4127(Dims >= 0) && ndim() != Dims) throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) + "; expected " + std::to_string(Dims)); return detail::unchecked_mutable_reference(mutable_data(), shape(), strides(), ndim()); @@ -734,7 +764,7 @@ public: * invalid dimensions or dimension indices. */ template detail::unchecked_reference unchecked() const & { - if (Dims >= 0 && ndim() != Dims) + if (PYBIND11_SILENCE_MSVC_C4127(Dims >= 0) && ndim() != Dims) throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) + "; expected " + std::to_string(Dims)); return detail::unchecked_reference(data(), shape(), strides(), ndim()); @@ -751,16 +781,45 @@ public: /// then resize will succeed only if it makes a reshape, i.e. original size doesn't change void resize(ShapeContainer new_shape, bool refcheck = true) { detail::npy_api::PyArray_Dims d = { - new_shape->data(), int(new_shape->size()) + // Use reinterpret_cast for PyPy on Windows (remove if fixed, checked on 7.3.1) + reinterpret_cast(new_shape->data()), + int(new_shape->size()) }; // try to resize, set ordering param to -1 cause it's not used anyway - object new_array = reinterpret_steal( + auto new_array = reinterpret_steal( detail::npy_api::get().PyArray_Resize_(m_ptr, &d, int(refcheck), -1) ); if (!new_array) throw error_already_set(); if (isinstance(new_array)) { *this = std::move(new_array); } } + /// Optional `order` parameter omitted, to be added as needed. + array reshape(ShapeContainer new_shape) { + detail::npy_api::PyArray_Dims d + = {reinterpret_cast(new_shape->data()), int(new_shape->size())}; + auto new_array + = reinterpret_steal(detail::npy_api::get().PyArray_Newshape_(m_ptr, &d, 0)); + if (!new_array) { + throw error_already_set(); + } + return new_array; + } + + /// Create a view of an array in a different data type. + /// This function may fundamentally reinterpret the data in the array. + /// It is the responsibility of the caller to ensure that this is safe. + /// Only supports the `dtype` argument, the `type` argument is omitted, + /// to be added as needed. + array view(const std::string &dtype) { + auto &api = detail::npy_api::get(); + auto new_view = reinterpret_steal(api.PyArray_View_( + m_ptr, dtype::from_args(pybind11::str(dtype)).release().ptr(), nullptr)); + if (!new_view) { + throw error_already_set(); + } + return new_view; + } + /// Ensure that the argument is a NumPy array /// In case of an error, nullptr is returned and the Python error is cleared. static array ensure(handle h, int ExtraFlags = 0) { @@ -788,25 +847,6 @@ protected: throw std::domain_error("array is not writeable"); } - // Default, C-style strides - static std::vector c_strides(const std::vector &shape, ssize_t itemsize) { - auto ndim = shape.size(); - std::vector strides(ndim, itemsize); - if (ndim > 0) - for (size_t i = ndim - 1; i > 0; --i) - strides[i - 1] = strides[i] * shape[i]; - return strides; - } - - // F-style strides; default when constructing an array_t with `ExtraFlags & f_style` - static std::vector f_strides(const std::vector &shape, ssize_t itemsize) { - auto ndim = shape.size(); - std::vector strides(ndim, itemsize); - for (size_t i = 1; i < ndim; ++i) - strides[i] = strides[i - 1] * shape[i - 1]; - return strides; - } - template void check_dimensions(Ix... index) const { check_dimensions_impl(ssize_t(0), shape(), ssize_t(index)...); } @@ -854,6 +894,7 @@ public: if (!is_borrowed) Py_XDECREF(h.ptr()); } + // NOLINTNEXTLINE(google-explicit-constructor) array_t(const object &o) : array(raw_array_t(o.ptr()), stolen_t{}) { if (!m_ptr) throw error_already_set(); } @@ -864,9 +905,12 @@ public: : array(std::move(shape), std::move(strides), ptr, base) { } explicit array_t(ShapeContainer shape, const T *ptr = nullptr, handle base = handle()) - : array_t(private_ctor{}, std::move(shape), - ExtraFlags & f_style ? f_strides(*shape, itemsize()) : c_strides(*shape, itemsize()), - ptr, base) { } + : array_t(private_ctor{}, + std::move(shape), + (ExtraFlags & f_style) != 0 ? detail::f_strides(*shape, itemsize()) + : detail::c_strides(*shape, itemsize()), + ptr, + base) {} explicit array_t(ssize_t count, const T *ptr = nullptr, handle base = handle()) : array({count}, {}, ptr, base) { } @@ -977,7 +1021,7 @@ template struct format_descriptor::is_array>> { static std::string format() { using namespace detail; - static constexpr auto extents = _("(") + array_info::extents + _(")"); + static constexpr auto extents = const_name("(") + array_info::extents + const_name(")"); return extents.text + format_descriptor>::format(); } }; @@ -1012,23 +1056,28 @@ struct npy_format_descriptor_name; template struct npy_format_descriptor_name::value>> { - static constexpr auto name = _::value>( - _("bool"), _::value>("numpy.int", "numpy.uint") + _() + static constexpr auto name = const_name::value>( + const_name("bool"), const_name::value>("numpy.int", "numpy.uint") + const_name() ); }; template struct npy_format_descriptor_name::value>> { - static constexpr auto name = _::value || std::is_same::value>( - _("numpy.float") + _(), _("numpy.longdouble") + static constexpr auto name = const_name::value + || std::is_same::value + || std::is_same::value + || std::is_same::value>( + const_name("numpy.float") + const_name(), const_name("numpy.longdouble") ); }; template struct npy_format_descriptor_name::value>> { - static constexpr auto name = _::value - || std::is_same::value>( - _("numpy.complex") + _(), _("numpy.longcomplex") + static constexpr auto name = const_name::value + || std::is_same::value + || std::is_same::value + || std::is_same::value>( + const_name("numpy.complex") + const_name(), const_name("numpy.longcomplex") ); }; @@ -1056,7 +1105,7 @@ public: }; #define PYBIND11_DECL_CHAR_FMT \ - static constexpr auto name = _("S") + _(); \ + static constexpr auto name = const_name("S") + const_name(); \ static pybind11::dtype dtype() { return pybind11::dtype(std::string("S") + std::to_string(N)); } template struct npy_format_descriptor { PYBIND11_DECL_CHAR_FMT }; template struct npy_format_descriptor> { PYBIND11_DECL_CHAR_FMT }; @@ -1068,7 +1117,7 @@ private: public: static_assert(!array_info::is_empty, "Zero-sized arrays are not supported"); - static constexpr auto name = _("(") + array_info::extents + _(")") + base_descr::name; + static constexpr auto name = const_name("(") + array_info::extents + const_name(")") + base_descr::name; static pybind11::dtype dtype() { list shape; array_info::append_extents(shape); @@ -1092,7 +1141,7 @@ struct field_descriptor { dtype descr; }; -inline PYBIND11_NOINLINE void register_structured_dtype( +PYBIND11_NOINLINE void register_structured_dtype( any_container fields, const std::type_info& tinfo, ssize_t itemsize, bool (*direct_converter)(PyObject *, void *&)) { @@ -1116,7 +1165,10 @@ inline PYBIND11_NOINLINE void register_structured_dtype( formats.append(field.descr); offsets.append(pybind11::int_(field.offset)); } - auto dtype_ptr = pybind11::dtype(names, formats, offsets, itemsize).release().ptr(); + auto dtype_ptr + = pybind11::dtype(std::move(names), std::move(formats), std::move(offsets), itemsize) + .release() + .ptr(); // There is an existing bug in NumPy (as of v1.11): trailing bytes are // not encoded explicitly into the format string. This will supposedly @@ -1270,26 +1322,13 @@ private: #endif // __CLION_IDE__ -template -using array_iterator = typename std::add_pointer::type; - -template -array_iterator array_begin(const buffer_info& buffer) { - return array_iterator(reinterpret_cast(buffer.ptr)); -} - -template -array_iterator array_end(const buffer_info& buffer) { - return array_iterator(reinterpret_cast(buffer.ptr) + buffer.size); -} - class common_iterator { public: using container_type = std::vector; using value_type = container_type::value_type; using size_type = container_type::size_type; - common_iterator() : p_ptr(0), m_strides() {} + common_iterator() : m_strides() {} common_iterator(void* ptr, const container_type& strides, const container_type& shape) : p_ptr(reinterpret_cast(ptr)), m_strides(strides.size()) { @@ -1310,7 +1349,7 @@ public: } private: - char* p_ptr; + char *p_ptr{0}; container_type m_strides; }; @@ -1338,9 +1377,8 @@ public: if (++m_index[i] != m_shape[i]) { increment_common_iterator(i); break; - } else { - m_index[i] = 0; } + m_index[i] = 0; } return *this; } @@ -1474,7 +1512,7 @@ struct vectorize_arg { using call_type = remove_reference_t; // Is this a vectorized argument? static constexpr bool vectorize = - satisfies_any_of::value && + satisfies_any_of::value && satisfies_none_of::value && (!std::is_reference::value || (std::is_lvalue_reference::value && std::is_const::value)); @@ -1482,6 +1520,55 @@ struct vectorize_arg { using type = conditional_t, array::forcecast>, T>; }; + +// py::vectorize when a return type is present +template +struct vectorize_returned_array { + using Type = array_t; + + static Type create(broadcast_trivial trivial, const std::vector &shape) { + if (trivial == broadcast_trivial::f_trivial) + return array_t(shape); + return array_t(shape); + } + + static Return *mutable_data(Type &array) { + return array.mutable_data(); + } + + static Return call(Func &f, Args &... args) { + return f(args...); + } + + static void call(Return *out, size_t i, Func &f, Args &... args) { + out[i] = f(args...); + } +}; + +// py::vectorize when a return type is not present +template +struct vectorize_returned_array { + using Type = none; + + static Type create(broadcast_trivial, const std::vector &) { + return none(); + } + + static void *mutable_data(Type &) { + return nullptr; + } + + static detail::void_type call(Func &f, Args &... args) { + f(args...); + return {}; + } + + static void call(void *, size_t, Func &f, Args &... args) { + f(args...); + } +}; + + template struct vectorize_helper { @@ -1498,8 +1585,11 @@ private: "pybind11::vectorize(...) requires a function with at least one vectorizable argument"); public: - template - explicit vectorize_helper(T &&f) : f(std::forward(f)) { } + template ::type>::value>> + explicit vectorize_helper(T &&f) : f(std::forward(f)) {} object operator()(typename vectorize_arg::type... args) { return run(args..., @@ -1516,6 +1606,8 @@ private: using arg_call_types = std::tuple::call_type...>; template using param_n_t = typename std::tuple_element::type; + using returned_array = vectorize_returned_array; + // Runs a vectorized function given arguments tuple and three index sequences: // - Index is the full set of 0 ... (N-1) argument indices; // - VIndex is the subset of argument indices with vectorized parameters, letting us access @@ -1547,20 +1639,19 @@ private: // not wrapped in an array). if (size == 1 && ndim == 0) { PYBIND11_EXPAND_SIDE_EFFECTS(params[VIndex] = buffers[BIndex].ptr); - return cast(f(*reinterpret_cast *>(params[Index])...)); + return cast(returned_array::call(f, *reinterpret_cast *>(params[Index])...)); } - array_t result; - if (trivial == broadcast_trivial::f_trivial) result = array_t(shape); - else result = array_t(shape); + auto result = returned_array::create(trivial, shape); if (size == 0) return std::move(result); /* Call the function */ + auto mutable_data = returned_array::mutable_data(result); if (trivial == broadcast_trivial::non_trivial) - apply_broadcast(buffers, params, result, i_seq, vi_seq, bi_seq); + apply_broadcast(buffers, params, mutable_data, size, shape, i_seq, vi_seq, bi_seq); else - apply_trivial(buffers, params, result.mutable_data(), size, i_seq, vi_seq, bi_seq); + apply_trivial(buffers, params, mutable_data, size, i_seq, vi_seq, bi_seq); return std::move(result); } @@ -1583,7 +1674,7 @@ private: }}; for (size_t i = 0; i < size; ++i) { - out[i] = f(*reinterpret_cast *>(params[Index])...); + returned_array::call(out, i, f, *reinterpret_cast *>(params[Index])...); for (auto &x : vecparams) x.first += x.second; } } @@ -1591,19 +1682,18 @@ private: template void apply_broadcast(std::array &buffers, std::array ¶ms, - array_t &output_array, + Return *out, + size_t size, + const std::vector &output_shape, index_sequence, index_sequence, index_sequence) { - buffer_info output = output_array.request(); - multi_array_iterator input_iter(buffers, output.shape); + multi_array_iterator input_iter(buffers, output_shape); - for (array_iterator iter = array_begin(output), end = array_end(output); - iter != end; - ++iter, ++input_iter) { + for (size_t i = 0; i < size; ++i, ++input_iter) { PYBIND11_EXPAND_SIDE_EFFECTS(( params[VIndex] = input_iter.template data() )); - *iter = f(*reinterpret_cast *>(std::get(params))...); + returned_array::call(out, i, f, *reinterpret_cast *>(std::get(params))...); } } }; @@ -1615,7 +1705,7 @@ vectorize_extractor(const Func &f, Return (*) (Args ...)) { } template struct handle_type_name> { - static constexpr auto name = _("numpy.ndarray[") + npy_format_descriptor::name + _("]"); + static constexpr auto name = const_name("numpy.ndarray[") + npy_format_descriptor::name + const_name("]"); }; PYBIND11_NAMESPACE_END(detail) @@ -1649,7 +1739,3 @@ Helper vectorize(Return (Class::*f)(Args...) const) { } PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE) - -#if defined(_MSC_VER) -#pragma warning(pop) -#endif diff --git a/wrap/pybind11/include/pybind11/operators.h b/wrap/pybind11/include/pybind11/operators.h index 086cb4cfd..2a6153158 100644 --- a/wrap/pybind11/include/pybind11/operators.h +++ b/wrap/pybind11/include/pybind11/operators.h @@ -11,13 +11,6 @@ #include "pybind11.h" -#if defined(__clang__) && !defined(__INTEL_COMPILER) -# pragma clang diagnostic ignored "-Wunsequenced" // multiple unsequenced modifications to 'self' (when using def(py::self OP Type())) -#elif defined(_MSC_VER) -# pragma warning(push) -# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant -#endif - PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) PYBIND11_NAMESPACE_BEGIN(detail) @@ -58,7 +51,8 @@ template struct op_ { using op = op_impl; cl.def(op::name(), &op::execute, is_operator(), extra...); #if PY_MAJOR_VERSION < 3 - if (id == op_truediv || id == op_itruediv) + if (PYBIND11_SILENCE_MSVC_C4127(id == op_truediv) || + PYBIND11_SILENCE_MSVC_C4127(id == op_itruediv)) cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? "__div__" : "__rdiv__", &op::execute, is_operator(), extra...); #endif @@ -167,7 +161,3 @@ using detail::self; using detail::hash; PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE) - -#if defined(_MSC_VER) -# pragma warning(pop) -#endif diff --git a/wrap/pybind11/include/pybind11/pybind11.h b/wrap/pybind11/include/pybind11/pybind11.h index f6dba4ed2..7aa93bb5a 100644 --- a/wrap/pybind11/include/pybind11/pybind11.h +++ b/wrap/pybind11/include/pybind11/pybind11.h @@ -10,56 +10,84 @@ #pragma once -#if defined(__INTEL_COMPILER) -# pragma warning push -# pragma warning disable 68 // integer conversion resulted in a change of sign -# pragma warning disable 186 // pointless comparison of unsigned integer with zero -# pragma warning disable 878 // incompatible exception specifications -# pragma warning disable 1334 // the "template" keyword used for syntactic disambiguation may only be used within a template -# pragma warning disable 1682 // implicit conversion of a 64-bit integral type to a smaller integral type (potential portability problem) -# pragma warning disable 1786 // function "strdup" was declared deprecated -# pragma warning disable 1875 // offsetof applied to non-POD (Plain Old Data) types is nonstandard -# pragma warning disable 2196 // warning #2196: routine is both "inline" and "noinline" -#elif defined(_MSC_VER) -# pragma warning(push) -# pragma warning(disable: 4100) // warning C4100: Unreferenced formal parameter -# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant -# pragma warning(disable: 4512) // warning C4512: Assignment operator was implicitly defined as deleted -# pragma warning(disable: 4800) // warning C4800: 'int': forcing value to bool 'true' or 'false' (performance warning) -# pragma warning(disable: 4996) // warning C4996: The POSIX name for this item is deprecated. Instead, use the ISO C and C++ conformant name -# pragma warning(disable: 4702) // warning C4702: unreachable code -# pragma warning(disable: 4522) // warning C4522: multiple assignment operators specified -#elif defined(__GNUG__) && !defined(__clang__) -# pragma GCC diagnostic push -# pragma GCC diagnostic ignored "-Wunused-but-set-parameter" -# pragma GCC diagnostic ignored "-Wunused-but-set-variable" -# pragma GCC diagnostic ignored "-Wmissing-field-initializers" -# pragma GCC diagnostic ignored "-Wstrict-aliasing" -# pragma GCC diagnostic ignored "-Wattributes" -# if __GNUC__ >= 7 -# pragma GCC diagnostic ignored "-Wnoexcept-type" -# endif -#endif - #include "attr.h" +#include "gil.h" #include "options.h" #include "detail/class.h" #include "detail/init.h" +#include +#include +#include +#include +#include +#include + +#include + +#if defined(__cpp_lib_launder) && !(defined(_MSC_VER) && (_MSC_VER < 1914)) +# define PYBIND11_STD_LAUNDER std::launder +# define PYBIND11_HAS_STD_LAUNDER 1 +#else +# define PYBIND11_STD_LAUNDER +# define PYBIND11_HAS_STD_LAUNDER 0 +#endif #if defined(__GNUG__) && !defined(__clang__) # include #endif +/* https://stackoverflow.com/questions/46798456/handling-gccs-noexcept-type-warning + This warning is about ABI compatibility, not code health. + It is only actually needed in a couple places, but apparently GCC 7 "generates this warning if + and only if the first template instantiation ... involves noexcept" [stackoverflow], therefore + it could get triggered from seemingly random places, depending on user code. + No other GCC version generates this warning. + */ +#if defined(__GNUC__) && __GNUC__ == 7 +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wnoexcept-type" +#endif + PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +PYBIND11_NAMESPACE_BEGIN(detail) + +// Apply all the extensions translators from a list +// Return true if one of the translators completed without raising an exception +// itself. Return of false indicates that if there are other translators +// available, they should be tried. +inline bool apply_exception_translators(std::forward_list& translators) { + auto last_exception = std::current_exception(); + + for (auto &translator : translators) { + try { + translator(last_exception); + return true; + } catch (...) { + last_exception = std::current_exception(); + } + } + return false; +} + +#if defined(_MSC_VER) +# define PYBIND11_COMPAT_STRDUP _strdup +#else +# define PYBIND11_COMPAT_STRDUP strdup +#endif + +PYBIND11_NAMESPACE_END(detail) + /// Wraps an arbitrary C++ function/method/lambda function/.. into a callable Python object class cpp_function : public function { public: cpp_function() = default; + // NOLINTNEXTLINE(google-explicit-constructor) cpp_function(std::nullptr_t) { } /// Construct a cpp_function from a vanilla function pointer template + // NOLINTNEXTLINE(google-explicit-constructor) cpp_function(Return (*f)(Args...), const Extra&... extra) { initialize(f, f, extra...); } @@ -67,6 +95,7 @@ public: /// Construct a cpp_function from a lambda function (possibly with internal state) template ::value>> + // NOLINTNEXTLINE(google-explicit-constructor) cpp_function(Func &&f, const Extra&... extra) { initialize(std::forward(f), (detail::function_signature_t *) nullptr, extra...); @@ -74,6 +103,7 @@ public: /// Construct a cpp_function from a class method (non-const, no ref-qualifier) template + // NOLINTNEXTLINE(google-explicit-constructor) cpp_function(Return (Class::*f)(Arg...), const Extra&... extra) { initialize([f](Class *c, Arg... args) -> Return { return (c->*f)(std::forward(args)...); }, (Return (*) (Class *, Arg...)) nullptr, extra...); @@ -83,13 +113,15 @@ public: /// A copy of the overload for non-const functions without explicit ref-qualifier /// but with an added `&`. template + // NOLINTNEXTLINE(google-explicit-constructor) cpp_function(Return (Class::*f)(Arg...)&, const Extra&... extra) { - initialize([f](Class *c, Arg... args) -> Return { return (c->*f)(args...); }, + initialize([f](Class *c, Arg... args) -> Return { return (c->*f)(std::forward(args)...); }, (Return (*) (Class *, Arg...)) nullptr, extra...); } /// Construct a cpp_function from a class method (const, no ref-qualifier) template + // NOLINTNEXTLINE(google-explicit-constructor) cpp_function(Return (Class::*f)(Arg...) const, const Extra&... extra) { initialize([f](const Class *c, Arg... args) -> Return { return (c->*f)(std::forward(args)...); }, (Return (*)(const Class *, Arg ...)) nullptr, extra...); @@ -99,8 +131,9 @@ public: /// A copy of the overload for const functions without explicit ref-qualifier /// but with an added `&`. template + // NOLINTNEXTLINE(google-explicit-constructor) cpp_function(Return (Class::*f)(Arg...) const&, const Extra&... extra) { - initialize([f](const Class *c, Arg... args) -> Return { return (c->*f)(args...); }, + initialize([f](const Class *c, Arg... args) -> Return { return (c->*f)(std::forward(args)...); }, (Return (*)(const Class *, Arg ...)) nullptr, extra...); } @@ -108,9 +141,16 @@ public: object name() const { return attr("__name__"); } protected: + struct InitializingFunctionRecordDeleter { + // `destruct(function_record, false)`: `initialize_generic` copies strings and + // takes care of cleaning up in case of exceptions. So pass `false` to `free_strings`. + void operator()(detail::function_record * rec) { destruct(rec, false); } + }; + using unique_function_record = std::unique_ptr; + /// Space optimization: don't inline this frequently instantiated fragment - PYBIND11_NOINLINE detail::function_record *make_function_record() { - return new detail::function_record(); + PYBIND11_NOINLINE unique_function_record make_function_record() { + return unique_function_record(new detail::function_record()); } /// Special internal constructor for functors, lambda functions, etc. @@ -120,23 +160,38 @@ protected: struct capture { remove_reference_t f; }; /* Store the function including any extra state it might have (e.g. a lambda capture object) */ - auto rec = make_function_record(); + // The unique_ptr makes sure nothing is leaked in case of an exception. + auto unique_rec = make_function_record(); + auto rec = unique_rec.get(); /* Store the capture object directly in the function record if there is enough space */ - if (sizeof(capture) <= sizeof(rec->data)) { + if (PYBIND11_SILENCE_MSVC_C4127(sizeof(capture) <= sizeof(rec->data))) { /* Without these pragmas, GCC warns that there might not be enough space to use the placement new operator. However, the 'if' statement above ensures that this is the case. */ -#if defined(__GNUG__) && !defined(__clang__) && __GNUC__ >= 6 +#if defined(__GNUG__) && __GNUC__ >= 6 && !defined(__clang__) && !defined(__INTEL_COMPILER) # pragma GCC diagnostic push # pragma GCC diagnostic ignored "-Wplacement-new" #endif new ((capture *) &rec->data) capture { std::forward(f) }; -#if defined(__GNUG__) && !defined(__clang__) && __GNUC__ >= 6 +#if defined(__GNUG__) && __GNUC__ >= 6 && !defined(__clang__) && !defined(__INTEL_COMPILER) +# pragma GCC diagnostic pop +#endif +#if defined(__GNUG__) && !PYBIND11_HAS_STD_LAUNDER && !defined(__INTEL_COMPILER) +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif + // UB without std::launder, but without breaking ABI and/or + // a significant refactoring it's "impossible" to solve. + if (!std::is_trivially_destructible::value) + rec->free_data = [](function_record *r) { + auto data = PYBIND11_STD_LAUNDER((capture *) &r->data); + (void) data; + data->~capture(); + }; +#if defined(__GNUG__) && !PYBIND11_HAS_STD_LAUNDER && !defined(__INTEL_COMPILER) # pragma GCC diagnostic pop #endif - if (!std::is_trivially_destructible::value) - rec->free_data = [](function_record *r) { ((capture *) &r->data)->~capture(); }; } else { rec->data[0] = new capture { std::forward(f) }; rec->free_data = [](function_record *r) { delete ((capture *) r->data[0]); }; @@ -148,7 +203,7 @@ protected: conditional_t::value, void_type, Return> >; - static_assert(expected_num_args(sizeof...(Args), cast_in::has_args, cast_in::has_kwargs), + static_assert(expected_num_args(sizeof...(Args), cast_in::args_pos >= 0, cast_in::has_kwargs), "The number of argument annotations does not match the number of function arguments"); /* Dispatch code which converts function arguments and performs the actual function call */ @@ -183,28 +238,36 @@ protected: return result; }; + rec->nargs_pos = cast_in::args_pos >= 0 + ? static_cast(cast_in::args_pos) + : sizeof...(Args) - cast_in::has_kwargs; // Will get reduced more if we have a kw_only + rec->has_args = cast_in::args_pos >= 0; + rec->has_kwargs = cast_in::has_kwargs; + /* Process any user-provided function attributes */ process_attributes::init(extra..., rec); { constexpr bool has_kw_only_args = any_of...>::value, has_pos_only_args = any_of...>::value, - has_args = any_of...>::value, has_arg_annotations = any_of...>::value; static_assert(has_arg_annotations || !has_kw_only_args, "py::kw_only requires the use of argument annotations"); static_assert(has_arg_annotations || !has_pos_only_args, "py::pos_only requires the use of argument annotations (for docstrings and aligning the annotations to the argument)"); - static_assert(!(has_args && has_kw_only_args), "py::kw_only cannot be combined with a py::args argument"); + + static_assert(constexpr_sum(is_kw_only::value...) <= 1, "py::kw_only may be specified only once"); + static_assert(constexpr_sum(is_pos_only::value...) <= 1, "py::pos_only may be specified only once"); + constexpr auto kw_only_pos = constexpr_first(); + constexpr auto pos_only_pos = constexpr_first(); + static_assert(!(has_kw_only_args && has_pos_only_args) || pos_only_pos < kw_only_pos, "py::pos_only must come before py::kw_only"); } /* Generate a readable signature describing the function's arguments and return value types */ - static constexpr auto signature = _("(") + cast_in::arg_names + _(") -> ") + cast_out::name; + static constexpr auto signature = const_name("(") + cast_in::arg_names + const_name(") -> ") + cast_out::name; PYBIND11_DESCR_CONSTEXPR auto types = decltype(signature)::types(); /* Register the function with Python from generic (non-templated) code */ - initialize_generic(rec, signature.text, types.data(), sizeof...(Args)); - - if (cast_in::has_args) rec->has_args = true; - if (cast_in::has_kwargs) rec->has_kwargs = true; + // Pass on the ownership over the `unique_rec` to `initialize_generic`. `rec` stays valid. + initialize_generic(std::move(unique_rec), signature.text, types.data(), sizeof...(Args)); /* Stash some additional information used by an important optimization in 'functional.h' */ using FunctionType = Return (*)(Args...); @@ -217,27 +280,59 @@ protected: } } + // Utility class that keeps track of all duplicated strings, and cleans them up in its destructor, + // unless they are released. Basically a RAII-solution to deal with exceptions along the way. + class strdup_guard { + public: + ~strdup_guard() { + for (auto s : strings) + std::free(s); + } + char *operator()(const char *s) { + auto t = PYBIND11_COMPAT_STRDUP(s); + strings.push_back(t); + return t; + } + void release() { + strings.clear(); + } + private: + std::vector strings; + }; + /// Register a function call with Python (generic non-templated code goes here) - void initialize_generic(detail::function_record *rec, const char *text, + void initialize_generic(unique_function_record &&unique_rec, const char *text, const std::type_info *const *types, size_t args) { + // Do NOT receive `unique_rec` by value. If this function fails to move out the unique_ptr, + // we do not want this to destuct the pointer. `initialize` (the caller) still relies on the + // pointee being alive after this call. Only move out if a `capsule` is going to keep it alive. + auto rec = unique_rec.get(); + + // Keep track of strdup'ed strings, and clean them up as long as the function's capsule + // has not taken ownership yet (when `unique_rec.release()` is called). + // Note: This cannot easily be fixed by a `unique_ptr` with custom deleter, because the strings + // are only referenced before strdup'ing. So only *after* the following block could `destruct` + // safely be called, but even then, `repr` could still throw in the middle of copying all strings. + strdup_guard guarded_strdup; /* Create copies of all referenced C-style strings */ - rec->name = strdup(rec->name ? rec->name : ""); - if (rec->doc) rec->doc = strdup(rec->doc); + rec->name = guarded_strdup(rec->name ? rec->name : ""); + if (rec->doc) rec->doc = guarded_strdup(rec->doc); for (auto &a: rec->args) { if (a.name) - a.name = strdup(a.name); + a.name = guarded_strdup(a.name); if (a.descr) - a.descr = strdup(a.descr); + a.descr = guarded_strdup(a.descr); else if (a.value) - a.descr = strdup(repr(a.value).cast().c_str()); + a.descr = guarded_strdup(repr(a.value).cast().c_str()); } - rec->is_constructor = !strcmp(rec->name, "__init__") || !strcmp(rec->name, "__setstate__"); + rec->is_constructor = (std::strcmp(rec->name, "__init__") == 0) + || (std::strcmp(rec->name, "__setstate__") == 0); #if !defined(NDEBUG) && !defined(PYBIND11_DISABLE_NEW_STYLE_INIT_WARNING) if (rec->is_constructor && !rec->is_new_style_constructor) { - const auto class_name = std::string(((PyTypeObject *) rec->scope.ptr())->tp_name); + const auto class_name = detail::get_fully_qualified_tp_name((PyTypeObject *) rec->scope.ptr()); const auto func_name = std::string(rec->name); PyErr_WarnEx( PyExc_FutureWarning, @@ -252,16 +347,18 @@ protected: /* Generate a proper function signature */ std::string signature; size_t type_index = 0, arg_index = 0; + bool is_starred = false; for (auto *pc = text; *pc != '\0'; ++pc) { const auto c = *pc; if (c == '{') { // Write arg name for everything except *args and **kwargs. - if (*(pc + 1) == '*') + is_starred = *(pc + 1) == '*'; + if (is_starred) continue; // Separator for keyword-only arguments, placed before the kw - // arguments start - if (rec->nargs_kw_only > 0 && arg_index + rec->nargs_kw_only == args) + // arguments start (unless we are already putting an *args) + if (!rec->has_args && arg_index == rec->nargs_pos) signature += "*, "; if (arg_index < rec->args.size() && rec->args[arg_index].name) { signature += rec->args[arg_index].name; @@ -273,7 +370,7 @@ protected: signature += ": "; } else if (c == '}') { // Write default value if available. - if (arg_index < rec->args.size() && rec->args[arg_index].descr) { + if (!is_starred && arg_index < rec->args.size() && rec->args[arg_index].descr) { signature += " = "; signature += rec->args[arg_index].descr; } @@ -281,7 +378,8 @@ protected: // argument, rather than before like * if (rec->nargs_pos_only > 0 && (arg_index + 1) == rec->nargs_pos_only) signature += ", /"; - arg_index++; + if (!is_starred) + arg_index++; } else if (c == '%') { const std::type_info *t = types[type_index++]; if (!t) @@ -307,19 +405,19 @@ protected: } } - if (arg_index != args || types[type_index] != nullptr) + if (arg_index != args - rec->has_args - rec->has_kwargs || types[type_index] != nullptr) pybind11_fail("Internal error while parsing type signature (2)"); #if PY_MAJOR_VERSION < 3 - if (strcmp(rec->name, "__next__") == 0) { + if (std::strcmp(rec->name, "__next__") == 0) { std::free(rec->name); - rec->name = strdup("next"); - } else if (strcmp(rec->name, "__bool__") == 0) { + rec->name = guarded_strdup("next"); + } else if (std::strcmp(rec->name, "__bool__") == 0) { std::free(rec->name); - rec->name = strdup("__nonzero__"); + rec->name = guarded_strdup("__nonzero__"); } #endif - rec->signature = strdup(signature.c_str()); + rec->signature = guarded_strdup(signature.c_str()); rec->args.shrink_to_fit(); rec->nargs = (std::uint16_t) args; @@ -329,7 +427,8 @@ protected: detail::function_record *chain = nullptr, *chain_start = rec; if (rec->sibling) { if (PyCFunction_Check(rec->sibling.ptr())) { - auto rec_capsule = reinterpret_borrow(PyCFunction_GET_SELF(rec->sibling.ptr())); + auto *self = PyCFunction_GET_SELF(rec->sibling.ptr()); + capsule rec_capsule = isinstance(self) ? reinterpret_borrow(self) : capsule(self); chain = (detail::function_record *) rec_capsule; /* Never append a method to an overload chain of a parent class; instead, hide the parent's overloads in this case */ @@ -347,12 +446,14 @@ protected: rec->def = new PyMethodDef(); std::memset(rec->def, 0, sizeof(PyMethodDef)); rec->def->ml_name = rec->name; - rec->def->ml_meth = reinterpret_cast(reinterpret_cast(*dispatcher)); + rec->def->ml_meth + = reinterpret_cast(reinterpret_cast(dispatcher)); rec->def->ml_flags = METH_VARARGS | METH_KEYWORDS; - capsule rec_capsule(rec, [](void *ptr) { + capsule rec_capsule(unique_rec.release(), [](void *ptr) { destruct((detail::function_record *) ptr); }); + guarded_strdup.release(); object scope_module; if (rec->scope) { @@ -367,10 +468,9 @@ protected: if (!m_ptr) pybind11_fail("cpp_function::cpp_function(): Could not allocate function object"); } else { - /* Append at the end of the overload chain */ + /* Append at the beginning or end of the overload chain */ m_ptr = rec->sibling.ptr(); inc_ref(); - chain_start = chain; if (chain->is_method != rec->is_method) pybind11_fail("overloading a method with both static and instance methods is not supported; " #if defined(NDEBUG) @@ -380,9 +480,24 @@ protected: std::string(pybind11::str(rec->scope.attr("__name__"))) + "." + std::string(rec->name) + signature #endif ); - while (chain->next) - chain = chain->next; - chain->next = rec; + + if (rec->prepend) { + // Beginning of chain; we need to replace the capsule's current head-of-the-chain + // pointer with this one, then make this one point to the previous head of the + // chain. + chain_start = rec; + rec->next = chain; + auto rec_capsule = reinterpret_borrow(((PyCFunctionObject *) m_ptr)->m_self); + rec_capsule.set_pointer(unique_rec.release()); + guarded_strdup.release(); + } else { + // Or end of chain (normal behavior) + chain_start = chain; + while (chain->next) + chain = chain->next; + chain->next = unique_rec.release(); + guarded_strdup.release(); + } } std::string signatures; @@ -406,7 +521,7 @@ protected: signatures += it->signature; signatures += "\n"; } - if (it->doc && strlen(it->doc) > 0 && options::show_user_defined_docstrings()) { + if (it->doc && it->doc[0] != '\0' && options::show_user_defined_docstrings()) { // If we're appending another docstring, and aren't printing function signatures, we // need to append a newline first: if (!options::show_function_signatures()) { @@ -421,9 +536,10 @@ protected: /* Install docstring */ auto *func = (PyCFunctionObject *) m_ptr; - if (func->m_ml->ml_doc) - std::free(const_cast(func->m_ml->ml_doc)); - func->m_ml->ml_doc = strdup(signatures.c_str()); + std::free(const_cast(func->m_ml->ml_doc)); + // Install docstring if it's non-empty (when at least one option is enabled) + func->m_ml->ml_doc + = signatures.empty() ? nullptr : PYBIND11_COMPAT_STRDUP(signatures.c_str()); if (rec->is_method) { m_ptr = PYBIND11_INSTANCE_METHOD_NEW(m_ptr, rec->scope.ptr()); @@ -434,28 +550,49 @@ protected: } /// When a cpp_function is GCed, release any memory allocated by pybind11 - static void destruct(detail::function_record *rec) { + static void destruct(detail::function_record *rec, bool free_strings = true) { + // If on Python 3.9, check the interpreter "MICRO" (patch) version. + // If this is running on 3.9.0, we have to work around a bug. + #if !defined(PYPY_VERSION) && PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION == 9 + static bool is_zero = Py_GetVersion()[4] == '0'; + #endif + while (rec) { detail::function_record *next = rec->next; if (rec->free_data) rec->free_data(rec); - std::free((char *) rec->name); - std::free((char *) rec->doc); - std::free((char *) rec->signature); - for (auto &arg: rec->args) { - std::free(const_cast(arg.name)); - std::free(const_cast(arg.descr)); - arg.value.dec_ref(); + // During initialization, these strings might not have been copied yet, + // so they cannot be freed. Once the function has been created, they can. + // Check `make_function_record` for more details. + if (free_strings) { + std::free((char *) rec->name); + std::free((char *) rec->doc); + std::free((char *) rec->signature); + for (auto &arg: rec->args) { + std::free(const_cast(arg.name)); + std::free(const_cast(arg.descr)); + } } + for (auto &arg: rec->args) + arg.value.dec_ref(); if (rec->def) { std::free(const_cast(rec->def->ml_doc)); - delete rec->def; + // Python 3.9.0 decref's these in the wrong order; rec->def + // If loaded on 3.9.0, let these leak (use Python 3.9.1 at runtime to fix) + // See https://github.com/python/cpython/pull/22670 + #if !defined(PYPY_VERSION) && PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION == 9 + if (!is_zero) + delete rec->def; + #else + delete rec->def; + #endif } delete rec; rec = next; } } + /// Main dispatch logic for calls to functions bound using pybind11 static PyObject *dispatcher(PyObject *self, PyObject *args_in, PyObject *kwargs_in) { using namespace detail; @@ -472,15 +609,15 @@ protected: auto self_value_and_holder = value_and_holder(); if (overloads->is_constructor) { - const auto tinfo = get_type_info((PyTypeObject *) overloads->scope.ptr()); - const auto pi = reinterpret_cast(parent.ptr()); - self_value_and_holder = pi->get_value_and_holder(tinfo, false); - - if (!self_value_and_holder.type || !self_value_and_holder.inst) { - PyErr_SetString(PyExc_TypeError, "__init__(self, ...) called with invalid `self` argument"); + if (!parent || !PyObject_TypeCheck(parent.ptr(), (PyTypeObject *) overloads->scope.ptr())) { + PyErr_SetString(PyExc_TypeError, "__init__(self, ...) called with invalid or missing `self` argument"); return nullptr; } + const auto tinfo = get_type_info((PyTypeObject *) overloads->scope.ptr()); + const auto pi = reinterpret_cast(parent.ptr()); + self_value_and_holder = pi->get_value_and_holder(tinfo, true); + // If this value is already registered it must mean __init__ is invoked multiple times; // we really can't support that in C++, so just ignore the second __init__. if (self_value_and_holder.instance_registered()) @@ -504,7 +641,7 @@ protected: named positional arguments weren't *also* specified via kwarg. 2. If we weren't given enough, try to make up the omitted ones by checking whether they were provided by a kwarg matching the `py::arg("name")` name. If - so, use it (and remove it from kwargs; if not, see if the function binding + so, use it (and remove it from kwargs); if not, see if the function binding provided a default that we can use. 3. Ensure that either all keyword arguments were "consumed", or that the function takes a kwargs argument to accept unconsumed kwargs. @@ -522,7 +659,7 @@ protected: size_t num_args = func.nargs; // Number of positional arguments that we need if (func.has_args) --num_args; // (but don't count py::args if (func.has_kwargs) --num_args; // or py::kwargs) - size_t pos_args = num_args - func.nargs_kw_only; + size_t pos_args = func.nargs_pos; if (!func.has_args && n_args_in > pos_args) continue; // Too many positional arguments for this overload @@ -552,7 +689,7 @@ protected: bool bad_arg = false; for (; args_copied < args_to_copy; ++args_copied) { const argument_record *arg_rec = args_copied < func.args.size() ? &func.args[args_copied] : nullptr; - if (kwargs_in && arg_rec && arg_rec->name && PyDict_GetItemString(kwargs_in, arg_rec->name)) { + if (kwargs_in && arg_rec && arg_rec->name && dict_getitemstring(kwargs_in, arg_rec->name)) { bad_arg = true; break; } @@ -568,21 +705,25 @@ protected: if (bad_arg) continue; // Maybe it was meant for another overload (issue #688) + // Keep track of how many position args we copied out in case we need to come back + // to copy the rest into a py::args argument. + size_t positional_args_copied = args_copied; + // We'll need to copy this if we steal some kwargs for defaults dict kwargs = reinterpret_borrow(kwargs_in); // 1.5. Fill in any missing pos_only args from defaults if they exist if (args_copied < func.nargs_pos_only) { for (; args_copied < func.nargs_pos_only; ++args_copied) { - const auto &arg = func.args[args_copied]; + const auto &arg_rec = func.args[args_copied]; handle value; - if (arg.value) { - value = arg.value; + if (arg_rec.value) { + value = arg_rec.value; } if (value) { call.args.push_back(value); - call.args_convert.push_back(arg.convert); + call.args_convert.push_back(arg_rec.convert); } else break; } @@ -596,11 +737,11 @@ protected: bool copied_kwargs = false; for (; args_copied < num_args; ++args_copied) { - const auto &arg = func.args[args_copied]; + const auto &arg_rec = func.args[args_copied]; handle value; - if (kwargs_in && arg.name) - value = PyDict_GetItemString(kwargs.ptr(), arg.name); + if (kwargs_in && arg_rec.name) + value = dict_getitemstring(kwargs.ptr(), arg_rec.name); if (value) { // Consume a kwargs value @@ -608,14 +749,24 @@ protected: kwargs = reinterpret_steal(PyDict_Copy(kwargs.ptr())); copied_kwargs = true; } - PyDict_DelItemString(kwargs.ptr(), arg.name); - } else if (arg.value) { - value = arg.value; + if (PyDict_DelItemString(kwargs.ptr(), arg_rec.name) == -1) { + throw error_already_set(); + } + } else if (arg_rec.value) { + value = arg_rec.value; + } + + if (!arg_rec.none && value.is_none()) { + break; } if (value) { + // If we're at the py::args index then first insert a stub for it to be replaced later + if (func.has_args && call.args.size() == func.nargs_pos) + call.args.push_back(none()); + call.args.push_back(value); - call.args_convert.push_back(arg.convert); + call.args_convert.push_back(arg_rec.convert); } else break; @@ -636,16 +787,19 @@ protected: // We didn't copy out any position arguments from the args_in tuple, so we // can reuse it directly without copying: extra_args = reinterpret_borrow(args_in); - } else if (args_copied >= n_args_in) { + } else if (positional_args_copied >= n_args_in) { extra_args = tuple(0); } else { - size_t args_size = n_args_in - args_copied; + size_t args_size = n_args_in - positional_args_copied; extra_args = tuple(args_size); for (size_t i = 0; i < args_size; ++i) { - extra_args[i] = PyTuple_GET_ITEM(args_in, args_copied + i); + extra_args[i] = PyTuple_GET_ITEM(args_in, positional_args_copied + i); } } - call.args.push_back(extra_args); + if (call.args.size() <= func.nargs_pos) + call.args.push_back(extra_args); + else + call.args[func.nargs_pos] = extra_args; call.args_convert.push_back(false); call.args_ref = std::move(extra_args); } @@ -724,14 +878,18 @@ protected: } catch (error_already_set &e) { e.restore(); return nullptr; -#if defined(__GNUG__) && !defined(__clang__) +#ifdef __GLIBCXX__ } catch ( abi::__forced_unwind& ) { throw; #endif } catch (...) { /* When an exception is caught, give each registered exception - translator a chance to translate it to a Python exception - in reverse order of registration. + translator a chance to translate it to a Python exception. First + all module-local translators will be tried in reverse order of + registration. If none of the module-locale translators handle + the exception (or there are no module-locale translators) then + the global translators will be tried, also in reverse order of + registration. A translator may choose to do one of the following: @@ -740,17 +898,15 @@ protected: - do nothing and let the exception fall through to the next translator, or - delegate translation to the next translator by throwing a new type of exception. */ - auto last_exception = std::current_exception(); - auto ®istered_exception_translators = get_internals().registered_exception_translators; - for (auto& translator : registered_exception_translators) { - try { - translator(last_exception); - } catch (...) { - last_exception = std::current_exception(); - continue; - } + auto &local_exception_translators = get_local_internals().registered_exception_translators; + if (detail::apply_exception_translators(local_exception_translators)) { return nullptr; } + auto &exception_translators = get_internals().registered_exception_translators; + if (detail::apply_exception_translators(exception_translators)) { + return nullptr; + } + PyErr_SetString(PyExc_SystemError, "Exception escaped from default exception translator!"); return nullptr; } @@ -832,47 +988,54 @@ protected: } append_note_if_missing_header_is_suspected(msg); +#if PY_VERSION_HEX >= 0x03030000 + // Attach additional error info to the exception if supported + if (PyErr_Occurred()) { + // #HelpAppreciated: unit test coverage for this branch. + raise_from(PyExc_TypeError, msg.c_str()); + return nullptr; + } +#endif PyErr_SetString(PyExc_TypeError, msg.c_str()); return nullptr; - } else if (!result) { + } + if (!result) { std::string msg = "Unable to convert function return value to a " "Python type! The signature was\n\t"; msg += it->signature; append_note_if_missing_header_is_suspected(msg); +#if PY_VERSION_HEX >= 0x03030000 + // Attach additional error info to the exception if supported + if (PyErr_Occurred()) { + raise_from(PyExc_TypeError, msg.c_str()); + return nullptr; + } +#endif PyErr_SetString(PyExc_TypeError, msg.c_str()); return nullptr; - } else { - if (overloads->is_constructor && !self_value_and_holder.holder_constructed()) { - auto *pi = reinterpret_cast(parent.ptr()); - self_value_and_holder.type->init_instance(pi, nullptr); - } - return result.ptr(); } + if (overloads->is_constructor && !self_value_and_holder.holder_constructed()) { + auto *pi = reinterpret_cast(parent.ptr()); + self_value_and_holder.type->init_instance(pi, nullptr); + } + return result.ptr(); } }; + /// Wrapper for Python extension modules class module_ : public object { public: PYBIND11_OBJECT_DEFAULT(module_, object, PyModule_Check) /// Create a new top-level Python module with the given name and docstring + PYBIND11_DEPRECATED("Use PYBIND11_MODULE or module_::create_extension_module instead") explicit module_(const char *name, const char *doc = nullptr) { - if (!options::show_user_defined_docstrings()) doc = nullptr; #if PY_MAJOR_VERSION >= 3 - auto *def = new PyModuleDef(); - std::memset(def, 0, sizeof(PyModuleDef)); - def->m_name = name; - def->m_doc = doc; - def->m_size = -1; - Py_INCREF(def); - m_ptr = PyModule_Create(def); + *this = create_extension_module(name, doc, new PyModuleDef()); #else - m_ptr = Py_InitModule3(name, nullptr, doc); + *this = create_extension_module(name, doc, nullptr); #endif - if (m_ptr == nullptr) - pybind11_fail("Internal error in module_::module_()"); - inc_ref(); } /** \rst @@ -896,9 +1059,9 @@ public: .. code-block:: cpp - py::module m("example", "pybind11 example plugin"); - py::module m2 = m.def_submodule("sub", "A submodule of 'example'"); - py::module m3 = m2.def_submodule("subsub", "A submodule of 'example.sub'"); + py::module_ m("example", "pybind11 example plugin"); + py::module_ m2 = m.def_submodule("sub", "A submodule of 'example'"); + py::module_ m3 = m2.def_submodule("subsub", "A submodule of 'example.sub'"); \endrst */ module_ def_submodule(const char *name, const char *doc = nullptr) { std::string full_name = std::string(PyModule_GetName(m_ptr)) @@ -926,11 +1089,13 @@ public: *this = reinterpret_steal(obj); } - // Adds an object to the module using the given name. Throws if an object with the given name - // already exists. - // - // overwrite should almost always be false: attempting to overwrite objects that pybind11 has - // established will, in most cases, break things. + /** \rst + Adds an object to the module using the given name. Throws if an object with the given name + already exists. + + ``overwrite`` should almost always be false: attempting to overwrite objects that pybind11 has + established will, in most cases, break things. + \endrst */ PYBIND11_NOINLINE void add_object(const char *name, handle obj, bool overwrite = false) { if (!overwrite && hasattr(*this, name)) pybind11_fail("Error during initialization: multiple incompatible definitions with name \"" + @@ -938,8 +1103,53 @@ public: PyModule_AddObject(ptr(), name, obj.inc_ref().ptr() /* steals a reference */); } + +#if PY_MAJOR_VERSION >= 3 + using module_def = PyModuleDef; +#else + struct module_def {}; +#endif + + /** \rst + Create a new top-level module that can be used as the main module of a C extension. + + For Python 3, ``def`` should point to a statically allocated module_def. + For Python 2, ``def`` can be a nullptr and is completely ignored. + \endrst */ + static module_ create_extension_module(const char *name, const char *doc, module_def *def) { +#if PY_MAJOR_VERSION >= 3 + // module_def is PyModuleDef + def = new (def) PyModuleDef { // Placement new (not an allocation). + /* m_base */ PyModuleDef_HEAD_INIT, + /* m_name */ name, + /* m_doc */ options::show_user_defined_docstrings() ? doc : nullptr, + /* m_size */ -1, + /* m_methods */ nullptr, + /* m_slots */ nullptr, + /* m_traverse */ nullptr, + /* m_clear */ nullptr, + /* m_free */ nullptr + }; + auto m = PyModule_Create(def); +#else + // Ignore module_def *def; only necessary for Python 3 + (void) def; + auto m = Py_InitModule3(name, nullptr, options::show_user_defined_docstrings() ? doc : nullptr); +#endif + if (m == nullptr) { + if (PyErr_Occurred()) + throw error_already_set(); + pybind11_fail("Internal error in module_::create_extension_module()"); + } + // TODO: Should be reinterpret_steal for Python 3, but Python also steals it again when returned from PyInit_... + // For Python 2, reinterpret_borrow is correct. + return reinterpret_borrow(m); + } }; +// When inside a namespace (or anywhere as long as it's not the first item on a line), +// C++20 allows "module" to be used. This is provided for backward compatibility, and for +// simplicity, if someone wants to use py::module for example, that is perfectly safe. using module = module_; /// \ingroup python_builtins @@ -947,22 +1157,31 @@ using module = module_; /// or ``__main__.__dict__`` if there is no frame (usually when the interpreter is embedded). inline dict globals() { PyObject *p = PyEval_GetGlobals(); - return reinterpret_borrow(p ? p : module::import("__main__").attr("__dict__").ptr()); + return reinterpret_borrow(p ? p : module_::import("__main__").attr("__dict__").ptr()); } +#if PY_VERSION_HEX >= 0x03030000 +template ()>> +PYBIND11_DEPRECATED("make_simple_namespace should be replaced with py::module_::import(\"types\").attr(\"SimpleNamespace\") ") +object make_simple_namespace(Args&&... args_) { + return module_::import("types").attr("SimpleNamespace")(std::forward(args_)...); +} +#endif + PYBIND11_NAMESPACE_BEGIN(detail) /// Generic support for creating new Python heap types class generic_type : public object { - template friend class class_; public: PYBIND11_OBJECT_DEFAULT(generic_type, object, PyType_Check) protected: void initialize(const type_record &rec) { - if (rec.scope && hasattr(rec.scope, rec.name)) + if (rec.scope && hasattr(rec.scope, "__dict__") && rec.scope.attr("__dict__").contains(rec.name)) pybind11_fail("generic_type: cannot initialize type \"" + std::string(rec.name) + "\": an object with that name is already defined"); - if (rec.module_local ? get_local_type_info(*rec.type) : get_global_type_info(*rec.type)) + if ((rec.module_local ? get_local_type_info(*rec.type) : get_global_type_info(*rec.type)) + != nullptr) pybind11_fail("generic_type: type \"" + std::string(rec.name) + "\" is already registered!"); @@ -987,7 +1206,7 @@ protected: auto tindex = std::type_index(*rec.type); tinfo->direct_conversions = &internals.direct_conversions[tindex]; if (rec.module_local) - registered_local_types_cpp()[tindex] = tinfo; + get_local_internals().registered_types_cpp[tindex] = tinfo; else internals.registered_types_cpp[tindex] = tinfo; internals.registered_types_py[(PyTypeObject *) m_ptr] = { tinfo }; @@ -997,8 +1216,12 @@ protected: tinfo->simple_ancestors = false; } else if (rec.bases.size() == 1) { - auto parent_tinfo = get_type_info((PyTypeObject *) rec.bases[0].ptr()); - tinfo->simple_ancestors = parent_tinfo->simple_ancestors; + auto *parent_tinfo = get_type_info((PyTypeObject *) rec.bases[0].ptr()); + assert(parent_tinfo != nullptr); + bool parent_simple_ancestors = parent_tinfo->simple_ancestors; + tinfo->simple_ancestors = parent_simple_ancestors; + // The parent can no longer be a simple type if it has MI and has a child + parent_tinfo->simple_type = parent_tinfo->simple_type && parent_simple_ancestors; } if (rec.module_local) { @@ -1028,7 +1251,7 @@ protected: if (!type->ht_type.tp_as_buffer) pybind11_fail( "To be able to register buffer protocol support for the type '" + - std::string(tinfo->type->tp_name) + + get_fully_qualified_tp_name(tinfo->type) + "' the associated class<>(..) invocation must " "include the pybind11::buffer_protocol() annotation!"); @@ -1040,8 +1263,9 @@ protected: void def_property_static_impl(const char *name, handle fget, handle fset, detail::function_record *rec_func) { - const auto is_static = rec_func && !(rec_func->is_method && rec_func->scope); - const auto has_doc = rec_func && rec_func->doc && pybind11::options::show_user_defined_docstrings(); + const auto is_static = (rec_func != nullptr) && !(rec_func->is_method && rec_func->scope); + const auto has_doc = (rec_func != nullptr) && (rec_func->doc != nullptr) + && pybind11::options::show_user_defined_docstrings(); auto property = handle((PyObject *) (is_static ? get_internals().static_property_type : &PyProperty_Type)); attr(name) = property(fget.ptr() ? fget : none(), @@ -1090,8 +1314,8 @@ inline void call_operator_delete(void *p, size_t s, size_t a) { inline void add_class_method(object& cls, const char *name_, const cpp_function &cf) { cls.attr(cf.name()) = cf; - if (strcmp(name_, "__eq__") == 0 && !cls.attr("__dict__").contains("__hash__")) { - cls.attr("__hash__") = none(); + if (std::strcmp(name_, "__eq__") == 0 && !cls.attr("__dict__").contains("__hash__")) { + cls.attr("__hash__") = none(); } } @@ -1173,7 +1397,7 @@ public: generic_type::initialize(record); if (has_alias) { - auto &instances = record.module_local ? registered_local_types_cpp() : get_internals().registered_types_cpp; + auto &instances = record.module_local ? get_local_internals().registered_types_cpp : get_internals().registered_types_cpp; instances[std::type_index(typeid(type_alias))] = instances[std::type_index(typeid(type))]; } } @@ -1220,12 +1444,14 @@ public: template class_ &def(const detail::initimpl::constructor &init, const Extra&... extra) { + PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(init); init.execute(*this, extra...); return *this; } template class_ &def(const detail::initimpl::alias_constructor &init, const Extra&... extra) { + PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(init); init.execute(*this, extra...); return *this; } @@ -1242,7 +1468,8 @@ public: return *this; } - template class_& def_buffer(Func &&func) { + template + class_& def_buffer(Func &&func) { struct capture { Func func; }; auto *ptr = new capture { std::forward(func) }; install_buffer_funcs([](PyObject *obj, void *ptr) -> buffer_info* { @@ -1251,6 +1478,10 @@ public: return nullptr; return new buffer_info(((capture *) ptr)->func(caster)); }, ptr); + weakref(m_ptr, cpp_function([ptr](handle wr) { + delete ptr; + wr.dec_ref(); + })).release(); return *this; } @@ -1283,15 +1514,15 @@ public: template class_ &def_readwrite_static(const char *name, D *pm, const Extra& ...extra) { - cpp_function fget([pm](object) -> const D &{ return *pm; }, scope(*this)), - fset([pm](object, const D &value) { *pm = value; }, scope(*this)); + cpp_function fget([pm](const object &) -> const D & { return *pm; }, scope(*this)), + fset([pm](const object &, const D &value) { *pm = value; }, scope(*this)); def_property_static(name, fget, fset, return_value_policy::reference, extra...); return *this; } template class_ &def_readonly_static(const char *name, const D *pm, const Extra& ...extra) { - cpp_function fget([pm](object) -> const D &{ return *pm; }, scope(*this)); + cpp_function fget([pm](const object &) -> const D & { return *pm; }, scope(*this)); def_property_readonly_static(name, fget, return_value_policy::reference, extra...); return *this; } @@ -1355,16 +1586,16 @@ public: char *doc_prev = rec_fget->doc; /* 'extra' field may include a property-specific documentation string */ detail::process_attributes::init(extra..., rec_fget); if (rec_fget->doc && rec_fget->doc != doc_prev) { - free(doc_prev); - rec_fget->doc = strdup(rec_fget->doc); + std::free(doc_prev); + rec_fget->doc = PYBIND11_COMPAT_STRDUP(rec_fget->doc); } } if (rec_fset) { char *doc_prev = rec_fset->doc; detail::process_attributes::init(extra..., rec_fset); if (rec_fset->doc && rec_fset->doc != doc_prev) { - free(doc_prev); - rec_fset->doc = strdup(rec_fset->doc); + std::free(doc_prev); + rec_fset->doc = PYBIND11_COMPAT_STRDUP(rec_fset->doc); } if (! rec_active) rec_active = rec_fset; } @@ -1377,14 +1608,13 @@ private: template static void init_holder(detail::instance *inst, detail::value_and_holder &v_h, const holder_type * /* unused */, const std::enable_shared_from_this * /* dummy */) { - try { - auto sh = std::dynamic_pointer_cast( - v_h.value_ptr()->shared_from_this()); - if (sh) { - new (std::addressof(v_h.holder())) holder_type(std::move(sh)); - v_h.set_holder_constructed(); - } - } catch (const std::bad_weak_ptr &) {} + + auto sh = std::dynamic_pointer_cast( + detail::try_get_shared_from_this(v_h.value_ptr())); + if (sh) { + new (std::addressof(v_h.holder())) holder_type(std::move(sh)); + v_h.set_holder_constructed(); + } if (!v_h.holder_constructed() && inst->owned) { new (std::addressof(v_h.holder())) holder_type(v_h.value_ptr()); @@ -1481,8 +1711,18 @@ detail::initimpl::pickle_factory pickle(GetState &&g, SetSta } PYBIND11_NAMESPACE_BEGIN(detail) + +inline str enum_name(handle arg) { + dict entries = arg.get_type().attr("__entries"); + for (auto kv : entries) { + if (handle(kv.second[int_(0)]).equal(arg)) + return pybind11::str(kv.first); + } + return "???"; +} + struct enum_base { - enum_base(handle base, handle parent) : m_base(base), m_parent(parent) { } + enum_base(const handle &base, const handle &parent) : m_base(base), m_parent(parent) { } PYBIND11_NOINLINE void init(bool is_arithmetic, bool is_convertible) { m_base.attr("__entries") = dict(); @@ -1490,29 +1730,22 @@ struct enum_base { auto static_property = handle((PyObject *) get_internals().static_property_type); m_base.attr("__repr__") = cpp_function( - [](handle arg) -> str { + [](const object &arg) -> str { handle type = type::handle_of(arg); object type_name = type.attr("__name__"); - dict entries = type.attr("__entries"); - for (const auto &kv : entries) { - object other = kv.second[int_(0)]; - if (other.equal(arg)) - return pybind11::str("{}.{}").format(type_name, kv.first); - } - return pybind11::str("{}.???").format(type_name); - }, name("__repr__"), is_method(m_base) - ); + return pybind11::str("<{}.{}: {}>").format(type_name, enum_name(arg), int_(arg)); + }, + name("__repr__"), + is_method(m_base)); - m_base.attr("name") = property(cpp_function( + m_base.attr("name") = property(cpp_function(&enum_name, name("name"), is_method(m_base))); + + m_base.attr("__str__") = cpp_function( [](handle arg) -> str { - dict entries = type::handle_of(arg).attr("__entries"); - for (const auto &kv : entries) { - if (handle(kv.second[int_(0)]).equal(arg)) - return pybind11::str(kv.first); - } - return "???"; + object type_name = type::handle_of(arg).attr("__name__"); + return pybind11::str("{}.{}").format(type_name, enum_name(arg)); }, name("name"), is_method(m_base) - )); + ); m_base.attr("__doc__") = static_property(cpp_function( [](handle arg) -> std::string { @@ -1521,7 +1754,7 @@ struct enum_base { if (((PyTypeObject *) arg.ptr())->tp_doc) docstring += std::string(((PyTypeObject *) arg.ptr())->tp_doc) + "\n\n"; docstring += "Members:"; - for (const auto &kv : entries) { + for (auto kv : entries) { auto key = std::string(pybind11::str(kv.first)); auto comment = kv.second[int_(1)]; docstring += "\n\n " + key; @@ -1535,36 +1768,42 @@ struct enum_base { m_base.attr("__members__") = static_property(cpp_function( [](handle arg) -> dict { dict entries = arg.attr("__entries"), m; - for (const auto &kv : entries) + for (auto kv : entries) m[kv.first] = kv.second[int_(0)]; return m; }, name("__members__")), none(), none(), "" ); - #define PYBIND11_ENUM_OP_STRICT(op, expr, strict_behavior) \ - m_base.attr(op) = cpp_function( \ - [](object a, object b) { \ - if (!type::handle_of(a).is(type::handle_of(b))) \ - strict_behavior; \ - return expr; \ - }, \ - name(op), is_method(m_base)) +#define PYBIND11_ENUM_OP_STRICT(op, expr, strict_behavior) \ + m_base.attr(op) = cpp_function( \ + [](const object &a, const object &b) { \ + if (!type::handle_of(a).is(type::handle_of(b))) \ + strict_behavior; /* NOLINT(bugprone-macro-parentheses) */ \ + return expr; \ + }, \ + name(op), \ + is_method(m_base), \ + arg("other")) - #define PYBIND11_ENUM_OP_CONV(op, expr) \ - m_base.attr(op) = cpp_function( \ - [](object a_, object b_) { \ - int_ a(a_), b(b_); \ - return expr; \ - }, \ - name(op), is_method(m_base)) +#define PYBIND11_ENUM_OP_CONV(op, expr) \ + m_base.attr(op) = cpp_function( \ + [](const object &a_, const object &b_) { \ + int_ a(a_), b(b_); \ + return expr; \ + }, \ + name(op), \ + is_method(m_base), \ + arg("other")) - #define PYBIND11_ENUM_OP_CONV_LHS(op, expr) \ - m_base.attr(op) = cpp_function( \ - [](object a_, object b) { \ - int_ a(a_); \ - return expr; \ - }, \ - name(op), is_method(m_base)) +#define PYBIND11_ENUM_OP_CONV_LHS(op, expr) \ + m_base.attr(op) = cpp_function( \ + [](const object &a_, const object &b) { \ + int_ a(a_); \ + return expr; \ + }, \ + name(op), \ + is_method(m_base), \ + arg("other")) if (is_convertible) { PYBIND11_ENUM_OP_CONV_LHS("__eq__", !b.is_none() && a.equal(b)); @@ -1581,8 +1820,10 @@ struct enum_base { PYBIND11_ENUM_OP_CONV("__ror__", a | b); PYBIND11_ENUM_OP_CONV("__xor__", a ^ b); PYBIND11_ENUM_OP_CONV("__rxor__", a ^ b); - m_base.attr("__invert__") = cpp_function( - [](object arg) { return ~(int_(arg)); }, name("__invert__"), is_method(m_base)); + m_base.attr("__invert__") + = cpp_function([](const object &arg) { return ~(int_(arg)); }, + name("__invert__"), + is_method(m_base)); } } else { PYBIND11_ENUM_OP_STRICT("__eq__", int_(a).equal(int_(b)), return false); @@ -1603,10 +1844,10 @@ struct enum_base { #undef PYBIND11_ENUM_OP_STRICT m_base.attr("__getstate__") = cpp_function( - [](object arg) { return int_(arg); }, name("__getstate__"), is_method(m_base)); + [](const object &arg) { return int_(arg); }, name("__getstate__"), is_method(m_base)); m_base.attr("__hash__") = cpp_function( - [](object arg) { return int_(arg); }, name("__hash__"), is_method(m_base)); + [](const object &arg) { return int_(arg); }, name("__hash__"), is_method(m_base)); } PYBIND11_NOINLINE void value(char const* name_, object value, const char *doc = nullptr) { @@ -1623,7 +1864,7 @@ struct enum_base { PYBIND11_NOINLINE void export_values() { dict entries = m_base.attr("__entries"); - for (const auto &kv : entries) + for (auto kv : entries) m_parent.attr(kv.first) = kv.second[int_(0)]; } @@ -1631,6 +1872,19 @@ struct enum_base { handle m_parent; }; +template struct equivalent_integer {}; +template <> struct equivalent_integer { using type = int8_t; }; +template <> struct equivalent_integer { using type = uint8_t; }; +template <> struct equivalent_integer { using type = int16_t; }; +template <> struct equivalent_integer { using type = uint16_t; }; +template <> struct equivalent_integer { using type = int32_t; }; +template <> struct equivalent_integer { using type = uint32_t; }; +template <> struct equivalent_integer { using type = int64_t; }; +template <> struct equivalent_integer { using type = uint64_t; }; + +template +using equivalent_integer_t = typename equivalent_integer::value, sizeof(IntLike)>::type; + PYBIND11_NAMESPACE_END(detail) /// Binds C++ enumerations and enumeration classes to Python @@ -1641,16 +1895,21 @@ public: using Base::attr; using Base::def_property_readonly; using Base::def_property_readonly_static; - using Scalar = typename std::underlying_type::type; + using Underlying = typename std::underlying_type::type; + // Scalar is the integer representation of underlying type + using Scalar = detail::conditional_t, std::is_same + >::value, detail::equivalent_integer_t, Underlying>; template enum_(const handle &scope, const char *name, const Extra&... extra) : class_(scope, name, extra...), m_base(*this, scope) { constexpr bool is_arithmetic = detail::any_of...>::value; - constexpr bool is_convertible = std::is_convertible::value; + constexpr bool is_convertible = std::is_convertible::value; m_base.init(is_arithmetic, is_convertible); - def(init([](Scalar i) { return static_cast(i); })); + def(init([](Scalar i) { return static_cast(i); }), arg("value")); + def_property_readonly("value", [](Type value) { return (Scalar) value; }); def("__int__", [](Type value) { return (Scalar) value; }); #if PY_MAJOR_VERSION < 3 def("__long__", [](Type value) { return (Scalar) value; }); @@ -1664,7 +1923,7 @@ public: detail::initimpl::setstate(v_h, static_cast(arg), Py_TYPE(v_h.inst) != v_h.type->type); }, detail::is_new_style_constructor(), - pybind11::name("__setstate__"), is_method(*this)); + pybind11::name("__setstate__"), is_method(*this), arg("state")); } /// Export enumeration entries into the parent scope @@ -1686,7 +1945,7 @@ private: PYBIND11_NAMESPACE_BEGIN(detail) -inline void keep_alive_impl(handle nurse, handle patient) { +PYBIND11_NOINLINE void keep_alive_impl(handle nurse, handle patient) { if (!nurse || !patient) pybind11_fail("Could not activate keep_alive!"); @@ -1713,13 +1972,13 @@ inline void keep_alive_impl(handle nurse, handle patient) { } } -PYBIND11_NOINLINE inline void keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret) { +PYBIND11_NOINLINE void keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret) { auto get_arg = [&](size_t n) { if (n == 0) return ret; - else if (n == 1 && call.init_self) + if (n == 1 && call.init_self) return call.init_self; - else if (n <= call.args.size()) + if (n <= call.args.size()) return call.args[n - 1]; return handle(); }; @@ -1739,6 +1998,16 @@ inline std::pair all_t // gets destroyed: weakref((PyObject *) type, cpp_function([type](handle wr) { get_internals().registered_types_py.erase(type); + + // TODO consolidate the erasure code in pybind11_meta_dealloc() in class.h + auto &cache = get_internals().inactive_override_cache; + for (auto it = cache.begin(), last = cache.end(); it != last; ) { + if (it->first == reinterpret_cast(type)) + it = cache.erase(it); + else + ++it; + } + wr.dec_ref(); })).release(); } @@ -1746,23 +2015,79 @@ inline std::pair all_t return res; } -template +/* There are a large number of apparently unused template arguments because + * each combination requires a separate py::class_ registration. + */ +template struct iterator_state { Iterator it; Sentinel end; bool first_or_done; }; -PYBIND11_NAMESPACE_END(detail) +// Note: these helpers take the iterator by non-const reference because some +// iterators in the wild can't be dereferenced when const. The & after Iterator +// is required for MSVC < 16.9. SFINAE cannot be reused for result_type due to +// bugs in ICC, NVCC, and PGI compilers. See PR #3293. +template ())> +struct iterator_access { + using result_type = decltype(*std::declval()); + // NOLINTNEXTLINE(readability-const-return-type) // PR #3263 + result_type operator()(Iterator &it) const { + return *it; + } +}; -/// Makes a python iterator from a first and past-the-end C++ InputIterator. -template ()).first) > +class iterator_key_access { +private: + using pair_type = decltype(*std::declval()); + +public: + /* If either the pair itself or the element of the pair is a reference, we + * want to return a reference, otherwise a value. When the decltype + * expression is parenthesized it is based on the value category of the + * expression; otherwise it is the declared type of the pair member. + * The use of declval in the second branch rather than directly + * using *std::declval() is a workaround for nvcc + * (it's not used in the first branch because going via decltype and back + * through declval does not perfectly preserve references). + */ + using result_type = conditional_t< + std::is_reference())>::value, + decltype(((*std::declval()).first)), + decltype(std::declval().first) + >; + result_type operator()(Iterator &it) const { + return (*it).first; + } +}; + +template ()).second)> +class iterator_value_access { +private: + using pair_type = decltype(*std::declval()); + +public: + using result_type = conditional_t< + std::is_reference())>::value, + decltype(((*std::declval()).second)), + decltype(std::declval().second) + >; + result_type operator()(Iterator &it) const { + return (*it).second; + } +}; + +template ()), + typename ValueType, typename... Extra> -iterator make_iterator(Iterator first, Sentinel last, Extra &&... extra) { - typedef detail::iterator_state state; +iterator make_iterator_impl(Iterator first, Sentinel last, Extra &&... extra) { + using state = detail::iterator_state; + // TODO: state captures only the types of Extra, not the values if (!detail::get_type_info(typeid(state), false)) { class_(handle(), "iterator", pybind11::module_local()) @@ -1776,40 +2101,63 @@ iterator make_iterator(Iterator first, Sentinel last, Extra &&... extra) { s.first_or_done = true; throw stop_iteration(); } - return *s.it; + return Access()(s.it); + // NOLINTNEXTLINE(readability-const-return-type) // PR #3263 }, std::forward(extra)..., Policy); } return cast(state{first, last, true}); } -/// Makes an python iterator over the keys (`.first`) of a iterator over pairs from a +PYBIND11_NAMESPACE_END(detail) + +/// Makes a python iterator from a first and past-the-end C++ InputIterator. +template ::result_type, + typename... Extra> +iterator make_iterator(Iterator first, Sentinel last, Extra &&... extra) { + return detail::make_iterator_impl< + detail::iterator_access, + Policy, + Iterator, + Sentinel, + ValueType, + Extra...>(first, last, std::forward(extra)...); +} + +/// Makes a python iterator over the keys (`.first`) of a iterator over pairs from a /// first and past-the-end InputIterator. template ()).first), + typename KeyType = typename detail::iterator_key_access::result_type, typename... Extra> -iterator make_key_iterator(Iterator first, Sentinel last, Extra &&... extra) { - using state = detail::iterator_state; +iterator make_key_iterator(Iterator first, Sentinel last, Extra &&...extra) { + return detail::make_iterator_impl< + detail::iterator_key_access, + Policy, + Iterator, + Sentinel, + KeyType, + Extra...>(first, last, std::forward(extra)...); +} - if (!detail::get_type_info(typeid(state), false)) { - class_(handle(), "iterator", pybind11::module_local()) - .def("__iter__", [](state &s) -> state& { return s; }) - .def("__next__", [](state &s) -> KeyType { - if (!s.first_or_done) - ++s.it; - else - s.first_or_done = false; - if (s.it == s.end) { - s.first_or_done = true; - throw stop_iteration(); - } - return (*s.it).first; - }, std::forward(extra)..., Policy); - } - - return cast(state{first, last, true}); +/// Makes a python iterator over the values (`.second`) of a iterator over pairs from a +/// first and past-the-end InputIterator. +template ::result_type, + typename... Extra> +iterator make_value_iterator(Iterator first, Sentinel last, Extra &&...extra) { + return detail::make_iterator_impl< + detail::iterator_value_access, + Policy, Iterator, + Sentinel, + ValueType, + Extra...>(first, last, std::forward(extra)...); } /// Makes an iterator over values of an stl container or other container supporting @@ -1826,10 +2174,17 @@ template (std::begin(value), std::end(value), extra...); } +/// Makes an iterator over the values (`.second`) of a stl map-like container supporting +/// `std::begin()`/`std::end()` +template iterator make_value_iterator(Type &value, Extra&&... extra) { + return make_value_iterator(std::begin(value), std::end(value), extra...); +} + template void implicitly_convertible() { struct set_flag { bool &flag; - set_flag(bool &flag) : flag(flag) { flag = true; } + explicit set_flag(bool &flag_) : flag(flag_) { flag_ = true; } ~set_flag() { flag = false; } }; auto implicit_caster = [](PyObject *obj, PyTypeObject *type) -> PyObject * { @@ -1853,12 +2208,24 @@ template void implicitly_convertible() pybind11_fail("implicitly_convertible: Unable to find type " + type_id()); } -template -void register_exception_translator(ExceptionTranslator&& translator) { + +inline void register_exception_translator(ExceptionTranslator &&translator) { detail::get_internals().registered_exception_translators.push_front( std::forward(translator)); } + +/** + * Add a new module-local exception translator. Locally registered functions + * will be tried before any globally registered exception translators, which + * will only be invoked if the module-local handlers do not deal with + * the exception. + */ +inline void register_local_exception_translator(ExceptionTranslator &&translator) { + detail::get_local_internals().registered_exception_translators.push_front( + std::forward(translator)); +} + /** * Wrapper to generate a new Python exception type. * @@ -1874,7 +2241,7 @@ public: std::string full_name = scope.attr("__name__").cast() + std::string(".") + name; m_ptr = PyErr_NewException(const_cast(full_name.c_str()), base.ptr(), NULL); - if (hasattr(scope, name)) + if (hasattr(scope, "__dict__") && scope.attr("__dict__").contains(name)) pybind11_fail("Error during initialization: multiple incompatible " "definitions with name \"" + std::string(name) + "\""); scope.attr(name) = *this; @@ -1892,22 +2259,20 @@ PYBIND11_NAMESPACE_BEGIN(detail) // directly in register_exception, but that makes clang <3.5 segfault - issue #1349). template exception &get_exception_object() { static exception ex; return ex; } -PYBIND11_NAMESPACE_END(detail) -/** - * Registers a Python exception in `m` of the given `name` and installs an exception translator to - * translate the C++ exception to the created Python exception using the exceptions what() method. - * This is intended for simple exception translations; for more complex translation, register the - * exception object and translator directly. - */ +// Helper function for register_exception and register_local_exception template -exception ®ister_exception(handle scope, - const char *name, - handle base = PyExc_Exception) { +exception ®ister_exception_impl(handle scope, + const char *name, + handle base, + bool isLocal) { auto &ex = detail::get_exception_object(); if (!ex) ex = exception(scope, name, base); - register_exception_translator([](std::exception_ptr p) { + auto register_func = isLocal ? ®ister_local_exception_translator + : ®ister_exception_translator; + + register_func([](std::exception_ptr p) { if (!p) return; try { std::rethrow_exception(p); @@ -1918,8 +2283,38 @@ exception ®ister_exception(handle scope, return ex; } +PYBIND11_NAMESPACE_END(detail) + +/** + * Registers a Python exception in `m` of the given `name` and installs a translator to + * translate the C++ exception to the created Python exception using the what() method. + * This is intended for simple exception translations; for more complex translation, register the + * exception object and translator directly. + */ +template +exception ®ister_exception(handle scope, + const char *name, + handle base = PyExc_Exception) { + return detail::register_exception_impl(scope, name, base, false /* isLocal */); +} + +/** + * Registers a Python exception in `m` of the given `name` and installs a translator to + * translate the C++ exception to the created Python exception using the what() method. + * This translator will only be used for exceptions that are thrown in this module and will be + * tried before global exception translators, including those registered with register_exception. + * This is intended for simple exception translations; for more complex translation, register the + * exception object and translator directly. + */ +template +exception ®ister_local_exception(handle scope, + const char *name, + handle base = PyExc_Exception) { + return detail::register_exception_impl(scope, name, base, true /* isLocal */); +} + PYBIND11_NAMESPACE_BEGIN(detail) -PYBIND11_NOINLINE inline void print(tuple args, dict kwargs) { +PYBIND11_NOINLINE void print(const tuple &args, const dict &kwargs) { auto strings = tuple(args.size()); for (size_t i = 0; i < args.size(); ++i) { strings[i] = str(args[i]); @@ -1932,7 +2327,7 @@ PYBIND11_NOINLINE inline void print(tuple args, dict kwargs) { file = kwargs["file"].cast(); } else { try { - file = module::import("sys").attr("stdout"); + file = module_::import("sys").attr("stdout"); } catch (const error_already_set &) { /* If print() is called from code that is executed as part of garbage collection during interpreter shutdown, @@ -1957,151 +2352,6 @@ void print(Args &&...args) { detail::print(c.args(), c.kwargs()); } -#if defined(WITH_THREAD) && !defined(PYPY_VERSION) - -/* The functions below essentially reproduce the PyGILState_* API using a RAII - * pattern, but there are a few important differences: - * - * 1. When acquiring the GIL from an non-main thread during the finalization - * phase, the GILState API blindly terminates the calling thread, which - * is often not what is wanted. This API does not do this. - * - * 2. The gil_scoped_release function can optionally cut the relationship - * of a PyThreadState and its associated thread, which allows moving it to - * another thread (this is a fairly rare/advanced use case). - * - * 3. The reference count of an acquired thread state can be controlled. This - * can be handy to prevent cases where callbacks issued from an external - * thread would otherwise constantly construct and destroy thread state data - * structures. - * - * See the Python bindings of NanoGUI (http://github.com/wjakob/nanogui) for an - * example which uses features 2 and 3 to migrate the Python thread of - * execution to another thread (to run the event loop on the original thread, - * in this case). - */ - -class gil_scoped_acquire { -public: - PYBIND11_NOINLINE gil_scoped_acquire() { - auto const &internals = detail::get_internals(); - tstate = (PyThreadState *) PYBIND11_TLS_GET_VALUE(internals.tstate); - - if (!tstate) { - /* Check if the GIL was acquired using the PyGILState_* API instead (e.g. if - calling from a Python thread). Since we use a different key, this ensures - we don't create a new thread state and deadlock in PyEval_AcquireThread - below. Note we don't save this state with internals.tstate, since we don't - create it we would fail to clear it (its reference count should be > 0). */ - tstate = PyGILState_GetThisThreadState(); - } - - if (!tstate) { - tstate = PyThreadState_New(internals.istate); - #if !defined(NDEBUG) - if (!tstate) - pybind11_fail("scoped_acquire: could not create thread state!"); - #endif - tstate->gilstate_counter = 0; - PYBIND11_TLS_REPLACE_VALUE(internals.tstate, tstate); - } else { - release = detail::get_thread_state_unchecked() != tstate; - } - - if (release) { - /* Work around an annoying assertion in PyThreadState_Swap */ - #if defined(Py_DEBUG) - PyInterpreterState *interp = tstate->interp; - tstate->interp = nullptr; - #endif - PyEval_AcquireThread(tstate); - #if defined(Py_DEBUG) - tstate->interp = interp; - #endif - } - - inc_ref(); - } - - void inc_ref() { - ++tstate->gilstate_counter; - } - - PYBIND11_NOINLINE void dec_ref() { - --tstate->gilstate_counter; - #if !defined(NDEBUG) - if (detail::get_thread_state_unchecked() != tstate) - pybind11_fail("scoped_acquire::dec_ref(): thread state must be current!"); - if (tstate->gilstate_counter < 0) - pybind11_fail("scoped_acquire::dec_ref(): reference count underflow!"); - #endif - if (tstate->gilstate_counter == 0) { - #if !defined(NDEBUG) - if (!release) - pybind11_fail("scoped_acquire::dec_ref(): internal error!"); - #endif - PyThreadState_Clear(tstate); - PyThreadState_DeleteCurrent(); - PYBIND11_TLS_DELETE_VALUE(detail::get_internals().tstate); - release = false; - } - } - - PYBIND11_NOINLINE ~gil_scoped_acquire() { - dec_ref(); - if (release) - PyEval_SaveThread(); - } -private: - PyThreadState *tstate = nullptr; - bool release = true; -}; - -class gil_scoped_release { -public: - explicit gil_scoped_release(bool disassoc = false) : disassoc(disassoc) { - // `get_internals()` must be called here unconditionally in order to initialize - // `internals.tstate` for subsequent `gil_scoped_acquire` calls. Otherwise, an - // initialization race could occur as multiple threads try `gil_scoped_acquire`. - const auto &internals = detail::get_internals(); - tstate = PyEval_SaveThread(); - if (disassoc) { - auto key = internals.tstate; - PYBIND11_TLS_DELETE_VALUE(key); - } - } - ~gil_scoped_release() { - if (!tstate) - return; - PyEval_RestoreThread(tstate); - if (disassoc) { - auto key = detail::get_internals().tstate; - PYBIND11_TLS_REPLACE_VALUE(key, tstate); - } - } -private: - PyThreadState *tstate; - bool disassoc; -}; -#elif defined(PYPY_VERSION) -class gil_scoped_acquire { - PyGILState_STATE state; -public: - gil_scoped_acquire() { state = PyGILState_Ensure(); } - ~gil_scoped_acquire() { PyGILState_Release(state); } -}; - -class gil_scoped_release { - PyThreadState *state; -public: - gil_scoped_release() { state = PyEval_SaveThread(); } - ~gil_scoped_release() { PyEval_RestoreThread(state); } -}; -#else -class gil_scoped_acquire { }; -class gil_scoped_release { }; -#endif - error_already_set::~error_already_set() { if (m_type) { gil_scoped_acquire gil; @@ -2134,16 +2384,42 @@ inline function get_type_override(const void *this_ptr, const type_info *this_ty /* Don't call dispatch code if invoked from overridden function. Unfortunately this doesn't work on PyPy. */ -#if !defined(PYPY_VERSION) +#if !defined(PYPY_VERSION) && PY_VERSION_HEX < 0x030B0000 + // TODO: Remove PyPy workaround for Python 3.11. + // Current API fails on 3.11 since co_varnames can be null. +#if PY_VERSION_HEX >= 0x03090000 + PyFrameObject *frame = PyThreadState_GetFrame(PyThreadState_Get()); + if (frame != nullptr) { + PyCodeObject *f_code = PyFrame_GetCode(frame); + // f_code is guaranteed to not be NULL + if ((std::string) str(f_code->co_name) == name && f_code->co_argcount > 0) { + PyObject* locals = PyEval_GetLocals(); + if (locals != nullptr && f_code->co_varnames != nullptr) { + PyObject *self_caller = dict_getitem( + locals, PyTuple_GET_ITEM(f_code->co_varnames, 0) + ); + if (self_caller == self.ptr()) { + Py_DECREF(f_code); + Py_DECREF(frame); + return function(); + } + } + } + Py_DECREF(f_code); + Py_DECREF(frame); + } +#else PyFrameObject *frame = PyThreadState_Get()->frame; - if (frame && (std::string) str(frame->f_code->co_name) == name && - frame->f_code->co_argcount > 0) { + if (frame != nullptr && (std::string) str(frame->f_code->co_name) == name + && frame->f_code->co_argcount > 0) { PyFrame_FastToLocals(frame); - PyObject *self_caller = PyDict_GetItem( + PyObject *self_caller = dict_getitem( frame->f_locals, PyTuple_GET_ITEM(frame->f_code->co_varnames, 0)); if (self_caller == self.ptr()) return function(); } +#endif + #else /* PyPy currently doesn't provide a detailed cpyext emulation of frame objects, so we have to emulate this using Python. This @@ -2174,7 +2450,7 @@ PYBIND11_NAMESPACE_END(detail) /** \rst Try to retrieve a python method by the provided name from the instance pointed to by the this_ptr. - :this_ptr: The pointer to the object the overriden method should be retrieved for. This should be + :this_ptr: The pointer to the object the overridden method should be retrieved for. This should be the first non-trampoline class encountered in the inheritance chain. :name: The name of the overridden Python method to retrieve. :return: The Python method by this name from the object or an empty function wrapper. @@ -2184,18 +2460,19 @@ template function get_override(const T *this_ptr, const char *name) { return tinfo ? detail::get_type_override(this_ptr, tinfo, name) : function(); } -#define PYBIND11_OVERRIDE_IMPL(ret_type, cname, name, ...) \ - do { \ - pybind11::gil_scoped_acquire gil; \ - pybind11::function override = pybind11::get_override(static_cast(this), name); \ - if (override) { \ - auto o = override(__VA_ARGS__); \ - if (pybind11::detail::cast_is_temporary_value_reference::value) { \ - static pybind11::detail::override_caster_t caster; \ - return pybind11::detail::cast_ref(std::move(o), caster); \ - } \ - else return pybind11::detail::cast_safe(std::move(o)); \ - } \ +#define PYBIND11_OVERRIDE_IMPL(ret_type, cname, name, ...) \ + do { \ + pybind11::gil_scoped_acquire gil; \ + pybind11::function override \ + = pybind11::get_override(static_cast(this), name); \ + if (override) { \ + auto o = override(__VA_ARGS__); \ + if (pybind11::detail::cast_is_temporary_value_reference::value) { \ + static pybind11::detail::override_caster_t caster; \ + return pybind11::detail::cast_ref(std::move(o), caster); \ + } \ + return pybind11::detail::cast_safe(std::move(o)); \ + } \ } while (false) /** \rst @@ -2291,8 +2568,6 @@ inline function get_overload(const T *this_ptr, const char *name) { PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE) -#if defined(_MSC_VER) && !defined(__INTEL_COMPILER) -# pragma warning(pop) -#elif defined(__GNUG__) && !defined(__clang__) -# pragma GCC diagnostic pop +#if defined(__GNUC__) && __GNUC__ == 7 +# pragma GCC diagnostic pop // -Wnoexcept-type #endif diff --git a/wrap/pybind11/include/pybind11/pytypes.h b/wrap/pybind11/include/pybind11/pytypes.h index a2f7cec48..902fb1f07 100644 --- a/wrap/pybind11/include/pybind11/pytypes.h +++ b/wrap/pybind11/include/pybind11/pytypes.h @@ -14,6 +14,14 @@ #include #include +#if defined(PYBIND11_HAS_OPTIONAL) +# include +#endif + +#ifdef PYBIND11_HAS_STRING_VIEW +# include +#endif + PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) /* A few forward declarations */ @@ -24,7 +32,7 @@ struct arg; struct arg_v; PYBIND11_NAMESPACE_BEGIN(detail) class args_proxy; -inline bool isinstance_generic(handle obj, const std::type_info &tp); +bool isinstance_generic(handle obj, const std::type_info &tp); // Accessor forward declarations template class accessor; @@ -153,7 +161,7 @@ public: /// Return the object's current reference count int ref_count() const { return static_cast(Py_REFCNT(derived().ptr())); } - PYBIND11_DEPRECATED("Call py::type::handle_of(h) or py::type::of(h) instead of h.get_type()") + // TODO PYBIND11_DEPRECATED("Call py::type::handle_of(h) or py::type::of(h) instead of h.get_type()") handle get_type() const; private: @@ -178,6 +186,7 @@ public: /// The default constructor creates a handle with a ``nullptr``-valued pointer handle() = default; /// Creates a ``handle`` from the given raw Python object pointer + // NOLINTNEXTLINE(google-explicit-constructor) handle(PyObject *ptr) : m_ptr(ptr) { } // Allow implicit conversion from PyObject* /// Return the underlying ``PyObject *`` pointer @@ -254,8 +263,11 @@ public: object& operator=(const object &other) { other.inc_ref(); - dec_ref(); + // Use temporary variable to ensure `*this` remains valid while + // `Py_XDECREF` executes, in case `*this` is accessible from Python. + handle temp(m_ptr); m_ptr = other.m_ptr; + temp.dec_ref(); return *this; } @@ -279,8 +291,10 @@ protected: struct borrowed_t { }; struct stolen_t { }; + /// @cond BROKEN template friend T reinterpret_borrow(handle); template friend T reinterpret_steal(handle); + /// @endcond public: // Only accessible from derived classes and the reinterpret_* functions @@ -314,14 +328,18 @@ template T reinterpret_borrow(handle h) { return {h, object::borrow template T reinterpret_steal(handle h) { return {h, object::stolen_t{}}; } PYBIND11_NAMESPACE_BEGIN(detail) -inline std::string error_string(); +std::string error_string(); PYBIND11_NAMESPACE_END(detail) +#if defined(_MSC_VER) +# pragma warning(push) +# pragma warning(disable: 4275 4251) // warning C4275: An exported class was derived from a class that wasn't exported. Can be ignored when derived from a STL class. +#endif /// Fetch and hold an error which was already set in Python. An instance of this is typically /// thrown to propagate python-side errors back through C++ which can either be caught manually or /// else falls back to the function dispatcher (which then raises the captured error back to /// python). -class error_already_set : public std::runtime_error { +class PYBIND11_EXPORT_EXCEPTION error_already_set : public std::runtime_error { public: /// Constructs a new exception from the current Python error indicator, if any. The current /// Python error indicator will be cleared. @@ -339,16 +357,17 @@ public: /// error variables (but the `.what()` string is still available). void restore() { PyErr_Restore(m_type.release().ptr(), m_value.release().ptr(), m_trace.release().ptr()); } - /// If it is impossible to raise the currently-held error, such as in destructor, we can write - /// it out using Python's unraisable hook (sys.unraisablehook). The error context should be - /// some object whose repr() helps identify the location of the error. Python already knows the - /// type and value of the error, so there is no need to repeat that. For example, __func__ could - /// be helpful. After this call, the current object no longer stores the error variables, - /// and neither does Python. + /// If it is impossible to raise the currently-held error, such as in a destructor, we can write + /// it out using Python's unraisable hook (`sys.unraisablehook`). The error context should be + /// some object whose `repr()` helps identify the location of the error. Python already knows the + /// type and value of the error, so there is no need to repeat that. After this call, the current + /// object no longer stores the error variables, and neither does Python. void discard_as_unraisable(object err_context) { restore(); PyErr_WriteUnraisable(err_context.ptr()); } + /// An alternate version of `discard_as_unraisable()`, where a string provides information on the + /// location of the error. For example, `__func__` could be helpful. void discard_as_unraisable(const char *err_context) { discard_as_unraisable(reinterpret_steal(PYBIND11_FROM_STRING(err_context))); } @@ -360,7 +379,9 @@ public: /// Check if the currently trapped error type matches the given Python exception class (or a /// subclass thereof). May also be passed a tuple to search for any exception class matches in /// the given tuple. - bool matches(handle exc) const { return PyErr_GivenExceptionMatches(m_type.ptr(), exc.ptr()); } + bool matches(handle exc) const { + return (PyErr_GivenExceptionMatches(m_type.ptr(), exc.ptr()) != 0); + } const object& type() const { return m_type; } const object& value() const { return m_value; } @@ -369,8 +390,52 @@ public: private: object m_type, m_value, m_trace; }; +#if defined(_MSC_VER) +# pragma warning(pop) +#endif -/** \defgroup python_builtins _ +#if PY_VERSION_HEX >= 0x03030000 + +/// Replaces the current Python error indicator with the chosen error, performing a +/// 'raise from' to indicate that the chosen error was caused by the original error. +inline void raise_from(PyObject *type, const char *message) { + // Based on _PyErr_FormatVFromCause: + // https://github.com/python/cpython/blob/467ab194fc6189d9f7310c89937c51abeac56839/Python/errors.c#L405 + // See https://github.com/pybind/pybind11/pull/2112 for details. + PyObject *exc = nullptr, *val = nullptr, *val2 = nullptr, *tb = nullptr; + + assert(PyErr_Occurred()); + PyErr_Fetch(&exc, &val, &tb); + PyErr_NormalizeException(&exc, &val, &tb); + if (tb != nullptr) { + PyException_SetTraceback(val, tb); + Py_DECREF(tb); + } + Py_DECREF(exc); + assert(!PyErr_Occurred()); + + PyErr_SetString(type, message); + + PyErr_Fetch(&exc, &val2, &tb); + PyErr_NormalizeException(&exc, &val2, &tb); + Py_INCREF(val); + PyException_SetCause(val2, val); + PyException_SetContext(val2, val); + PyErr_Restore(exc, val2, tb); +} + +/// Sets the current Python error indicator with the chosen error, performing a 'raise from' +/// from the error contained in error_already_set to indicate that the chosen error was +/// caused by the original error. After this function is called error_already_set will +/// no longer contain an error. +inline void raise_from(error_already_set& err, PyObject *type, const char *message) { + err.restore(); + raise_from(type, message); +} + +#endif + +/** \defgroup python_builtins const_name Unless stated otherwise, the following C++ functions behave the same as their Python counterparts. */ @@ -431,19 +496,17 @@ inline object getattr(handle obj, const char *name) { inline object getattr(handle obj, handle name, handle default_) { if (PyObject *result = PyObject_GetAttr(obj.ptr(), name.ptr())) { return reinterpret_steal(result); - } else { - PyErr_Clear(); - return reinterpret_borrow(default_); } + PyErr_Clear(); + return reinterpret_borrow(default_); } inline object getattr(handle obj, const char *name, handle default_) { if (PyObject *result = PyObject_GetAttrString(obj.ptr(), name)) { return reinterpret_steal(result); - } else { - PyErr_Clear(); - return reinterpret_borrow(default_); } + PyErr_Clear(); + return reinterpret_borrow(default_); } inline void setattr(handle obj, handle name, handle value) { @@ -476,6 +539,43 @@ inline handle get_function(handle value) { return value; } +// Reimplementation of python's dict helper functions to ensure that exceptions +// aren't swallowed (see #2862) + +// copied from cpython _PyDict_GetItemStringWithError +inline PyObject * dict_getitemstring(PyObject *v, const char *key) +{ +#if PY_MAJOR_VERSION >= 3 + PyObject *kv = nullptr, *rv = nullptr; + kv = PyUnicode_FromString(key); + if (kv == NULL) { + throw error_already_set(); + } + + rv = PyDict_GetItemWithError(v, kv); + Py_DECREF(kv); + if (rv == NULL && PyErr_Occurred()) { + throw error_already_set(); + } + return rv; +#else + return PyDict_GetItemString(v, key); +#endif +} + +inline PyObject * dict_getitem(PyObject *v, PyObject *key) +{ +#if PY_MAJOR_VERSION >= 3 + PyObject *rv = PyDict_GetItemWithError(v, key); + if (rv == NULL && PyErr_Occurred()) { + throw error_already_set(); + } + return rv; +#else + return PyDict_GetItem(v, key); +#endif +} + // Helper aliases/functions to support implicit casting of values given to python accessors/methods. // When given a pyobject, this simply returns the pyobject as-is; for other C++ type, the value goes // through pybind11::cast(obj) to convert it to an `object`. @@ -487,6 +587,10 @@ object object_or_cast(T &&o); // Match a PyObject*, which we want to convert directly to handle via its converting constructor inline handle object_or_cast(PyObject *ptr) { return ptr; } +#if defined(_MSC_VER) && _MSC_VER < 1920 +# pragma warning(push) +# pragma warning(disable: 4522) // warning C4522: multiple assignment operators specified +#endif template class accessor : public object_api> { using key_type = typename Policy::key_type; @@ -494,7 +598,7 @@ class accessor : public object_api> { public: accessor(handle obj, key_type key) : obj(obj), key(std::move(key)) { } accessor(const accessor &) = default; - accessor(accessor &&) = default; + accessor(accessor &&) noexcept = default; // accessor overload required to override default assignment operator (templates are not allowed // to replace default compiler-generated assignments). @@ -520,6 +624,7 @@ public: return obj.contains(key); } + // NOLINTNEXTLINE(google-explicit-constructor) operator object() const { return get_cache(); } PyObject *ptr() const { return get_cache().ptr(); } template T cast() const { return get_cache().template cast(); } @@ -535,6 +640,9 @@ private: key_type key; mutable object cache; }; +#if defined(_MSC_VER) && _MSC_VER < 1920 +# pragma warning(pop) +#endif PYBIND11_NAMESPACE_BEGIN(accessor_policies) struct obj_attr { @@ -566,15 +674,17 @@ struct generic_item { struct sequence_item { using key_type = size_t; - static object get(handle obj, size_t index) { - PyObject *result = PySequence_GetItem(obj.ptr(), static_cast(index)); + template ::value, int> = 0> + static object get(handle obj, const IdxType &index) { + PyObject *result = PySequence_GetItem(obj.ptr(), ssize_t_cast(index)); if (!result) { throw error_already_set(); } return reinterpret_steal(result); } - static void set(handle obj, size_t index, handle val) { + template ::value, int> = 0> + static void set(handle obj, const IdxType &index, handle val) { // PySequence_SetItem does not steal a reference to 'val' - if (PySequence_SetItem(obj.ptr(), static_cast(index), val.ptr()) != 0) { + if (PySequence_SetItem(obj.ptr(), ssize_t_cast(index), val.ptr()) != 0) { throw error_already_set(); } } @@ -583,15 +693,17 @@ struct sequence_item { struct list_item { using key_type = size_t; - static object get(handle obj, size_t index) { - PyObject *result = PyList_GetItem(obj.ptr(), static_cast(index)); + template ::value, int> = 0> + static object get(handle obj, const IdxType &index) { + PyObject *result = PyList_GetItem(obj.ptr(), ssize_t_cast(index)); if (!result) { throw error_already_set(); } return reinterpret_borrow(result); } - static void set(handle obj, size_t index, handle val) { + template ::value, int> = 0> + static void set(handle obj, const IdxType &index, handle val) { // PyList_SetItem steals a reference to 'val' - if (PyList_SetItem(obj.ptr(), static_cast(index), val.inc_ref().ptr()) != 0) { + if (PyList_SetItem(obj.ptr(), ssize_t_cast(index), val.inc_ref().ptr()) != 0) { throw error_already_set(); } } @@ -600,15 +712,17 @@ struct list_item { struct tuple_item { using key_type = size_t; - static object get(handle obj, size_t index) { - PyObject *result = PyTuple_GetItem(obj.ptr(), static_cast(index)); + template ::value, int> = 0> + static object get(handle obj, const IdxType &index) { + PyObject *result = PyTuple_GetItem(obj.ptr(), ssize_t_cast(index)); if (!result) { throw error_already_set(); } return reinterpret_borrow(result); } - static void set(handle obj, size_t index, handle val) { + template ::value, int> = 0> + static void set(handle obj, const IdxType &index, handle val) { // PyTuple_SetItem steals a reference to 'val' - if (PyTuple_SetItem(obj.ptr(), static_cast(index), val.inc_ref().ptr()) != 0) { + if (PyTuple_SetItem(obj.ptr(), ssize_t_cast(index), val.inc_ref().ptr()) != 0) { throw error_already_set(); } } @@ -630,7 +744,9 @@ public: generic_iterator() = default; generic_iterator(handle seq, ssize_t index) : Policy(seq, index) { } + // NOLINTNEXTLINE(readability-const-return-type) // PR #3263 reference operator*() const { return Policy::dereference(); } + // NOLINTNEXTLINE(readability-const-return-type) // PR #3263 reference operator[](difference_type n) const { return *(*this + n); } pointer operator->() const { return **this; } @@ -660,7 +776,8 @@ template struct arrow_proxy { T value; - arrow_proxy(T &&value) : value(std::move(value)) { } + // NOLINTNEXTLINE(google-explicit-constructor) + arrow_proxy(T &&value) noexcept : value(std::move(value)) { } T *operator->() const { return &value; } }; @@ -669,11 +786,12 @@ class sequence_fast_readonly { protected: using iterator_category = std::random_access_iterator_tag; using value_type = handle; - using reference = const handle; + using reference = const handle; // PR #3263 using pointer = arrow_proxy; sequence_fast_readonly(handle obj, ssize_t n) : ptr(PySequence_Fast_ITEMS(obj.ptr()) + n) { } + // NOLINTNEXTLINE(readability-const-return-type) // PR #3263 reference dereference() const { return *ptr; } void increment() { ++ptr; } void decrement() { --ptr; } @@ -712,14 +830,19 @@ class dict_readonly { protected: using iterator_category = std::forward_iterator_tag; using value_type = std::pair; - using reference = const value_type; + using reference = const value_type; // PR #3263 using pointer = arrow_proxy; dict_readonly() = default; dict_readonly(handle obj, ssize_t pos) : obj(obj), pos(pos) { increment(); } + // NOLINTNEXTLINE(readability-const-return-type) // PR #3263 reference dereference() const { return {key, value}; } - void increment() { if (!PyDict_Next(obj.ptr(), &pos, &key, &value)) { pos = -1; } } + void increment() { + if (PyDict_Next(obj.ptr(), &pos, &key, &value) == 0) { + pos = -1; + } + } bool equal(const dict_readonly &b) const { return pos == b.pos; } private: @@ -745,16 +868,20 @@ inline bool PyIterable_Check(PyObject *obj) { if (iter) { Py_DECREF(iter); return true; - } else { - PyErr_Clear(); - return false; } + PyErr_Clear(); + return false; } inline bool PyNone_Check(PyObject *o) { return o == Py_None; } inline bool PyEllipsis_Check(PyObject *o) { return o == Py_Ellipsis; } +#ifdef PYBIND11_STR_LEGACY_PERMISSIVE inline bool PyUnicode_Check_Permissive(PyObject *o) { return PyUnicode_Check(o) || PYBIND11_BYTES_CHECK(o); } +#define PYBIND11_STR_CHECK_FUN detail::PyUnicode_Check_Permissive +#else +#define PYBIND11_STR_CHECK_FUN PyUnicode_Check +#endif inline bool PyStaticMethod_Check(PyObject *o) { return o->ob_type == &PyStaticMethod_Type; } @@ -797,26 +924,42 @@ PYBIND11_NAMESPACE_END(detail) Name(handle h, borrowed_t) : Parent(h, borrowed_t{}) { } \ Name(handle h, stolen_t) : Parent(h, stolen_t{}) { } \ PYBIND11_DEPRECATED("Use py::isinstance(obj) instead") \ - bool check() const { return m_ptr != nullptr && (bool) CheckFun(m_ptr); } \ + bool check() const { return m_ptr != nullptr && (CheckFun(m_ptr) != 0); } \ static bool check_(handle h) { return h.ptr() != nullptr && CheckFun(h.ptr()); } \ template \ + /* NOLINTNEXTLINE(google-explicit-constructor) */ \ Name(const ::pybind11::detail::accessor &a) : Name(object(a)) { } #define PYBIND11_OBJECT_CVT(Name, Parent, CheckFun, ConvertFun) \ PYBIND11_OBJECT_COMMON(Name, Parent, CheckFun) \ /* This is deliberately not 'explicit' to allow implicit conversion from object: */ \ + /* NOLINTNEXTLINE(google-explicit-constructor) */ \ Name(const object &o) \ : Parent(check_(o) ? o.inc_ref().ptr() : ConvertFun(o.ptr()), stolen_t{}) \ { if (!m_ptr) throw error_already_set(); } \ + /* NOLINTNEXTLINE(google-explicit-constructor) */ \ Name(object &&o) \ : Parent(check_(o) ? o.release().ptr() : ConvertFun(o.ptr()), stolen_t{}) \ { if (!m_ptr) throw error_already_set(); } +#define PYBIND11_OBJECT_CVT_DEFAULT(Name, Parent, CheckFun, ConvertFun) \ + PYBIND11_OBJECT_CVT(Name, Parent, CheckFun, ConvertFun) \ + Name() : Parent() { } + +#define PYBIND11_OBJECT_CHECK_FAILED(Name, o_ptr) \ + ::pybind11::type_error("Object of type '" + \ + ::pybind11::detail::get_fully_qualified_tp_name(Py_TYPE(o_ptr)) + \ + "' is not an instance of '" #Name "'") + #define PYBIND11_OBJECT(Name, Parent, CheckFun) \ PYBIND11_OBJECT_COMMON(Name, Parent, CheckFun) \ /* This is deliberately not 'explicit' to allow implicit conversion from object: */ \ - Name(const object &o) : Parent(o) { } \ - Name(object &&o) : Parent(std::move(o)) { } + /* NOLINTNEXTLINE(google-explicit-constructor) */ \ + Name(const object &o) : Parent(o) \ + { if (m_ptr && !check_(m_ptr)) throw PYBIND11_OBJECT_CHECK_FAILED(Name, m_ptr); } \ + /* NOLINTNEXTLINE(google-explicit-constructor) */ \ + Name(object &&o) : Parent(std::move(o)) \ + { if (m_ptr && !check_(m_ptr)) throw PYBIND11_OBJECT_CHECK_FAILED(Name, m_ptr); } #define PYBIND11_OBJECT_DEFAULT(Name, Parent, CheckFun) \ PYBIND11_OBJECT(Name, Parent, CheckFun) \ @@ -838,7 +981,7 @@ public: using iterator_category = std::input_iterator_tag; using difference_type = ssize_t; using value_type = handle; - using reference = const handle; + using reference = const handle; // PR #3263 using pointer = const handle *; PYBIND11_OBJECT_DEFAULT(iterator, object, PyIter_Check) @@ -854,6 +997,7 @@ public: return rv; } + // NOLINTNEXTLINE(readability-const-return-type) // PR #3263 reference operator*() const { if (m_ptr && !value.ptr()) { auto& self = const_cast(*this); @@ -927,21 +1071,38 @@ class bytes; class str : public object { public: - PYBIND11_OBJECT_CVT(str, object, detail::PyUnicode_Check_Permissive, raw_str) + PYBIND11_OBJECT_CVT(str, object, PYBIND11_STR_CHECK_FUN, raw_str) - str(const char *c, size_t n) - : object(PyUnicode_FromStringAndSize(c, (ssize_t) n), stolen_t{}) { + template ::value, int> = 0> + str(const char *c, const SzType &n) + : object(PyUnicode_FromStringAndSize(c, ssize_t_cast(n)), stolen_t{}) { if (!m_ptr) pybind11_fail("Could not allocate string object!"); } // 'explicit' is explicitly omitted from the following constructors to allow implicit conversion to py::str from C++ string-like objects + // NOLINTNEXTLINE(google-explicit-constructor) str(const char *c = "") : object(PyUnicode_FromString(c), stolen_t{}) { if (!m_ptr) pybind11_fail("Could not allocate string object!"); } + // NOLINTNEXTLINE(google-explicit-constructor) str(const std::string &s) : str(s.data(), s.size()) { } +#ifdef PYBIND11_HAS_STRING_VIEW + // enable_if is needed to avoid "ambiguous conversion" errors (see PR #3521). + template ::value, int> = 0> + // NOLINTNEXTLINE(google-explicit-constructor) + str(T s) : str(s.data(), s.size()) { } + +# ifdef PYBIND11_HAS_U8STRING + // reinterpret_cast here is safe (C++20 guarantees char8_t has the same size/alignment as char) + // NOLINTNEXTLINE(google-explicit-constructor) + str(std::u8string_view s) : str(reinterpret_cast(s.data()), s.size()) { } +# endif + +#endif + explicit str(const bytes &b); /** \rst @@ -950,15 +1111,16 @@ public: \endrst */ explicit str(handle h) : object(raw_str(h.ptr()), stolen_t{}) { if (!m_ptr) throw error_already_set(); } + // NOLINTNEXTLINE(google-explicit-constructor) operator std::string() const { object temp = *this; if (PyUnicode_Check(m_ptr)) { temp = reinterpret_steal(PyUnicode_AsUTF8String(m_ptr)); if (!temp) - pybind11_fail("Unable to extract string contents! (encoding issue)"); + throw error_already_set(); } - char *buffer; - ssize_t length; + char *buffer = nullptr; + ssize_t length = 0; if (PYBIND11_BYTES_AS_STRING_AND_SIZE(temp.ptr(), &buffer, &length)) pybind11_fail("Unable to extract string contents! (invalid type)"); return std::string(buffer, (size_t) length); @@ -997,28 +1159,52 @@ public: PYBIND11_OBJECT(bytes, object, PYBIND11_BYTES_CHECK) // Allow implicit conversion: + // NOLINTNEXTLINE(google-explicit-constructor) bytes(const char *c = "") : object(PYBIND11_BYTES_FROM_STRING(c), stolen_t{}) { if (!m_ptr) pybind11_fail("Could not allocate bytes object!"); } - bytes(const char *c, size_t n) - : object(PYBIND11_BYTES_FROM_STRING_AND_SIZE(c, (ssize_t) n), stolen_t{}) { + template ::value, int> = 0> + bytes(const char *c, const SzType &n) + : object(PYBIND11_BYTES_FROM_STRING_AND_SIZE(c, ssize_t_cast(n)), stolen_t{}) { if (!m_ptr) pybind11_fail("Could not allocate bytes object!"); } // Allow implicit conversion: + // NOLINTNEXTLINE(google-explicit-constructor) bytes(const std::string &s) : bytes(s.data(), s.size()) { } explicit bytes(const pybind11::str &s); + // NOLINTNEXTLINE(google-explicit-constructor) operator std::string() const { - char *buffer; - ssize_t length; + char *buffer = nullptr; + ssize_t length = 0; if (PYBIND11_BYTES_AS_STRING_AND_SIZE(m_ptr, &buffer, &length)) pybind11_fail("Unable to extract bytes contents!"); return std::string(buffer, (size_t) length); } + +#ifdef PYBIND11_HAS_STRING_VIEW + // enable_if is needed to avoid "ambiguous conversion" errors (see PR #3521). + template ::value, int> = 0> + // NOLINTNEXTLINE(google-explicit-constructor) + bytes(T s) : bytes(s.data(), s.size()) { } + + // Obtain a string view that views the current `bytes` buffer value. Note that this is only + // valid so long as the `bytes` instance remains alive and so generally should not outlive the + // lifetime of the `bytes` instance. + // NOLINTNEXTLINE(google-explicit-constructor) + operator std::string_view() const { + char *buffer = nullptr; + ssize_t length = 0; + if (PYBIND11_BYTES_AS_STRING_AND_SIZE(m_ptr, &buffer, &length)) + pybind11_fail("Unable to extract bytes contents!"); + return {buffer, static_cast(length)}; + } +#endif + }; // Note: breathe >= 4.17.0 will fail to build docs if the below two constructors // are included in the doxygen group; close here and reopen after as a workaround @@ -1031,8 +1217,8 @@ inline bytes::bytes(const pybind11::str &s) { if (!temp) pybind11_fail("Unable to extract string contents! (encoding issue)"); } - char *buffer; - ssize_t length; + char *buffer = nullptr; + ssize_t length = 0; if (PYBIND11_BYTES_AS_STRING_AND_SIZE(temp.ptr(), &buffer, &length)) pybind11_fail("Unable to extract string contents! (invalid type)"); auto obj = reinterpret_steal(PYBIND11_BYTES_FROM_STRING_AND_SIZE(buffer, length)); @@ -1042,16 +1228,45 @@ inline bytes::bytes(const pybind11::str &s) { } inline str::str(const bytes& b) { - char *buffer; - ssize_t length; + char *buffer = nullptr; + ssize_t length = 0; if (PYBIND11_BYTES_AS_STRING_AND_SIZE(b.ptr(), &buffer, &length)) pybind11_fail("Unable to extract bytes contents!"); - auto obj = reinterpret_steal(PyUnicode_FromStringAndSize(buffer, (ssize_t) length)); + auto obj = reinterpret_steal(PyUnicode_FromStringAndSize(buffer, length)); if (!obj) pybind11_fail("Could not allocate string object!"); m_ptr = obj.release().ptr(); } +/// \addtogroup pytypes +/// @{ +class bytearray : public object { +public: + PYBIND11_OBJECT_CVT(bytearray, object, PyByteArray_Check, PyByteArray_FromObject) + + template ::value, int> = 0> + bytearray(const char *c, const SzType &n) + : object(PyByteArray_FromStringAndSize(c, ssize_t_cast(n)), stolen_t{}) { + if (!m_ptr) pybind11_fail("Could not allocate bytearray object!"); + } + + bytearray() + : bytearray("", 0) {} + + explicit bytearray(const std::string &s) : bytearray(s.data(), s.size()) { } + + size_t size() const { return static_cast(PyByteArray_Size(m_ptr)); } + + explicit operator std::string() const { + char *buffer = PyByteArray_AS_STRING(m_ptr); + ssize_t size = PyByteArray_GET_SIZE(m_ptr); + return std::string(buffer, static_cast(size)); + } +}; +// Note: breathe >= 4.17.0 will fail to build docs if the below two constructors +// are included in the doxygen group; close here and reopen after as a workaround +/// @} pytypes + /// \addtogroup pytypes /// @{ class none : public object { @@ -1071,15 +1286,17 @@ public: PYBIND11_OBJECT_CVT(bool_, object, PyBool_Check, raw_bool) bool_() : object(Py_False, borrowed_t{}) { } // Allow implicit conversion from and to `bool`: + // NOLINTNEXTLINE(google-explicit-constructor) bool_(bool value) : object(value ? Py_True : Py_False, borrowed_t{}) { } - operator bool() const { return m_ptr && PyLong_AsLong(m_ptr) != 0; } + // NOLINTNEXTLINE(google-explicit-constructor) + operator bool() const { return (m_ptr != nullptr) && PyLong_AsLong(m_ptr) != 0; } private: /// Return the truth value of an object -- always returns a new reference static PyObject *raw_bool(PyObject *op) { const auto value = PyObject_IsTrue(op); if (value == -1) return nullptr; - return handle(value ? Py_True : Py_False).inc_ref().ptr(); + return handle(value != 0 ? Py_True : Py_False).inc_ref().ptr(); } }; @@ -1090,18 +1307,16 @@ PYBIND11_NAMESPACE_BEGIN(detail) // unsigned type: (A)-1 != (B)-1 when A and B are unsigned types of different sizes). template Unsigned as_unsigned(PyObject *o) { - if (sizeof(Unsigned) <= sizeof(unsigned long) + if (PYBIND11_SILENCE_MSVC_C4127(sizeof(Unsigned) <= sizeof(unsigned long)) #if PY_VERSION_HEX < 0x03000000 - || PyInt_Check(o) + || PyInt_Check(o) #endif ) { unsigned long v = PyLong_AsUnsignedLong(o); return v == (unsigned long) -1 && PyErr_Occurred() ? (Unsigned) -1 : (Unsigned) v; } - else { - unsigned long long v = PyLong_AsUnsignedLongLong(o); - return v == (unsigned long long) -1 && PyErr_Occurred() ? (Unsigned) -1 : (Unsigned) v; - } + unsigned long long v = PyLong_AsUnsignedLongLong(o); + return v == (unsigned long long) -1 && PyErr_Occurred() ? (Unsigned) -1 : (Unsigned) v; } PYBIND11_NAMESPACE_END(detail) @@ -1112,8 +1327,9 @@ public: // Allow implicit conversion from C++ integral types: template ::value, int> = 0> + // NOLINTNEXTLINE(google-explicit-constructor) int_(T value) { - if (sizeof(T) <= sizeof(long)) { + if (PYBIND11_SILENCE_MSVC_C4127(sizeof(T) <= sizeof(long))) { if (std::is_signed::value) m_ptr = PyLong_FromLong((long) value); else @@ -1129,6 +1345,7 @@ public: template ::value, int> = 0> + // NOLINTNEXTLINE(google-explicit-constructor) operator T() const { return std::is_unsigned::value ? detail::as_unsigned(m_ptr) @@ -1142,33 +1359,51 @@ class float_ : public object { public: PYBIND11_OBJECT_CVT(float_, object, PyFloat_Check, PyNumber_Float) // Allow implicit conversion from float/double: + // NOLINTNEXTLINE(google-explicit-constructor) float_(float value) : object(PyFloat_FromDouble((double) value), stolen_t{}) { if (!m_ptr) pybind11_fail("Could not allocate float object!"); } + // NOLINTNEXTLINE(google-explicit-constructor) float_(double value = .0) : object(PyFloat_FromDouble((double) value), stolen_t{}) { if (!m_ptr) pybind11_fail("Could not allocate float object!"); } + // NOLINTNEXTLINE(google-explicit-constructor) operator float() const { return (float) PyFloat_AsDouble(m_ptr); } + // NOLINTNEXTLINE(google-explicit-constructor) operator double() const { return (double) PyFloat_AsDouble(m_ptr); } }; class weakref : public object { public: - PYBIND11_OBJECT_DEFAULT(weakref, object, PyWeakref_Check) + PYBIND11_OBJECT_CVT_DEFAULT(weakref, object, PyWeakref_Check, raw_weakref) explicit weakref(handle obj, handle callback = {}) : object(PyWeakref_NewRef(obj.ptr(), callback.ptr()), stolen_t{}) { if (!m_ptr) pybind11_fail("Could not allocate weak reference!"); } + +private: + static PyObject *raw_weakref(PyObject *o) { + return PyWeakref_NewRef(o, nullptr); + } }; class slice : public object { public: PYBIND11_OBJECT_DEFAULT(slice, object, PySlice_Check) - slice(ssize_t start_, ssize_t stop_, ssize_t step_) { - int_ start(start_), stop(stop_), step(step_); + slice(handle start, handle stop, handle step) { m_ptr = PySlice_New(start.ptr(), stop.ptr(), step.ptr()); - if (!m_ptr) pybind11_fail("Could not allocate slice object!"); + if (!m_ptr) + pybind11_fail("Could not allocate slice object!"); } + +#ifdef PYBIND11_HAS_OPTIONAL + slice(std::optional start, std::optional stop, std::optional step) + : slice(index_to_object(start), index_to_object(stop), index_to_object(step)) {} +#else + slice(ssize_t start_, ssize_t stop_, ssize_t step_) + : slice(int_(start_), int_(stop_), int_(step_)) {} +#endif + bool compute(size_t length, size_t *start, size_t *stop, size_t *step, size_t *slicelength) const { return PySlice_GetIndicesEx((PYBIND11_SLICE_OBJECT *) m_ptr, @@ -1183,6 +1418,12 @@ public: stop, step, slicelength) == 0; } + +private: + template + static object index_to_object(T index) { + return index ? object(int_(*index)) : object(none()); + } }; class capsule : public object { @@ -1218,7 +1459,7 @@ public: pybind11_fail("Could not set capsule context!"); } - capsule(void (*destructor)()) { + explicit capsule(void (*destructor)()) { m_ptr = PyCapsule_New(reinterpret_cast(destructor), nullptr, [](PyObject *o) { auto destructor = reinterpret_cast(PyCapsule_GetPointer(o, nullptr)); destructor(); @@ -1228,20 +1469,41 @@ public: pybind11_fail("Could not allocate capsule object!"); } + // NOLINTNEXTLINE(google-explicit-constructor) template operator T *() const { + return get_pointer(); + } + + /// Get the pointer the capsule holds. + template + T* get_pointer() const { auto name = this->name(); - T * result = static_cast(PyCapsule_GetPointer(m_ptr, name)); - if (!result) pybind11_fail("Unable to extract capsule contents!"); + T *result = static_cast(PyCapsule_GetPointer(m_ptr, name)); + if (!result) { + PyErr_Clear(); + pybind11_fail("Unable to extract capsule contents!"); + } return result; } + /// Replaces a capsule's pointer *without* calling the destructor on the existing one. + void set_pointer(const void *value) { + if (PyCapsule_SetPointer(m_ptr, const_cast(value)) != 0) { + PyErr_Clear(); + pybind11_fail("Could not set capsule pointer"); + } + } + const char *name() const { return PyCapsule_GetName(m_ptr); } }; class tuple : public object { public: PYBIND11_OBJECT_CVT(tuple, object, PyTuple_Check, PySequence_Tuple) - explicit tuple(size_t size = 0) : object(PyTuple_New((ssize_t) size), stolen_t{}) { + template ::value, int> = 0> + // Some compilers generate link errors when using `const SzType &` here: + explicit tuple(SzType size = 0) : object(PyTuple_New(ssize_t_cast(size)), stolen_t{}) { if (!m_ptr) pybind11_fail("Could not allocate tuple object!"); } size_t size() const { return (size_t) PyTuple_Size(m_ptr); } @@ -1252,6 +1514,15 @@ public: detail::tuple_iterator end() const { return {*this, PyTuple_GET_SIZE(m_ptr)}; } }; +// We need to put this into a separate function because the Intel compiler +// fails to compile enable_if_t...>::value> part below +// (tested with ICC 2021.1 Beta 20200827). +template +constexpr bool args_are_all_keyword_or_ds() +{ + return detail::all_of...>::value; +} + class dict : public object { public: PYBIND11_OBJECT_CVT(dict, object, PyDict_Check, raw_dict) @@ -1259,7 +1530,7 @@ public: if (!m_ptr) pybind11_fail("Could not allocate dict object!"); } template ...>::value>, + typename = detail::enable_if_t()>, // MSVC workaround: it can't compile an out-of-line definition, so defer the collector typename collector = detail::deferred_t, Args...>> explicit dict(Args &&...args) : dict(collector(std::forward(args)...).kwargs()) { } @@ -1268,7 +1539,7 @@ public: bool empty() const { return size() == 0; } detail::dict_iterator begin() const { return {*this, 0}; } detail::dict_iterator end() const { return {}; } - void clear() const { PyDict_Clear(ptr()); } + void clear() /* py-non-const */ { PyDict_Clear(ptr()); } template bool contains(T &&key) const { return PyDict_Contains(m_ptr, detail::object_or_cast(std::forward(key)).ptr()) == 1; } @@ -1301,7 +1572,10 @@ public: class list : public object { public: PYBIND11_OBJECT_CVT(list, object, PyList_Check, PySequence_List) - explicit list(size_t size = 0) : object(PyList_New((ssize_t) size), stolen_t{}) { + template ::value, int> = 0> + // Some compilers generate link errors when using `const SzType &` here: + explicit list(SzType size = 0) : object(PyList_New(ssize_t_cast(size)), stolen_t{}) { if (!m_ptr) pybind11_fail("Could not allocate list object!"); } size_t size() const { return (size_t) PyList_Size(m_ptr); } @@ -1310,12 +1584,15 @@ public: detail::item_accessor operator[](handle h) const { return object::operator[](h); } detail::list_iterator begin() const { return {*this, 0}; } detail::list_iterator end() const { return {*this, PyList_GET_SIZE(m_ptr)}; } - template void append(T &&val) const { + template void append(T &&val) /* py-non-const */ { PyList_Append(m_ptr, detail::object_or_cast(std::forward(val)).ptr()); } - template void insert(size_t index, T &&val) const { - PyList_Insert(m_ptr, static_cast(index), - detail::object_or_cast(std::forward(val)).ptr()); + template ::value, int> = 0> + void insert(const IdxType &index, ValType &&val) /* py-non-const */ { + PyList_Insert( + m_ptr, ssize_t_cast(index), detail::object_or_cast(std::forward(val)).ptr()); } }; @@ -1330,10 +1607,10 @@ public: } size_t size() const { return (size_t) PySet_Size(m_ptr); } bool empty() const { return size() == 0; } - template bool add(T &&val) const { + template bool add(T &&val) /* py-non-const */ { return PySet_Add(m_ptr, detail::object_or_cast(std::forward(val)).ptr()) == 0; } - void clear() const { PySet_Clear(m_ptr); } + void clear() /* py-non-const */ { PySet_Clear(m_ptr); } template bool contains(T &&val) const { return PySet_Contains(m_ptr, detail::object_or_cast(std::forward(val)).ptr()) == 1; } @@ -1429,7 +1706,7 @@ public: detail::any_container shape, detail::any_container strides) { return memoryview::from_buffer( - const_cast(ptr), itemsize, format, shape, strides, true); + const_cast(ptr), itemsize, format, std::move(shape), std::move(strides), true); } template @@ -1475,10 +1752,17 @@ public: static memoryview from_memory(const void *mem, ssize_t size) { return memoryview::from_memory(const_cast(mem), size, true); } + +#ifdef PYBIND11_HAS_STRING_VIEW + static memoryview from_memory(std::string_view mem) { + return from_memory(const_cast(mem.data()), static_cast(mem.size()), true); + } +#endif + #endif }; -#ifndef DOXYGEN_SHOULD_SKIP_THIS +/// @cond DUPLICATE inline memoryview memoryview::from_buffer( void *ptr, ssize_t itemsize, const char* format, detail::any_container shape, @@ -1486,7 +1770,7 @@ inline memoryview memoryview::from_buffer( size_t ndim = shape->size(); if (ndim != strides->size()) pybind11_fail("memoryview: shape length doesn't match strides length"); - ssize_t size = ndim ? 1 : 0; + ssize_t size = ndim != 0u ? 1 : 0; for (size_t i = 0; i < ndim; ++i) size *= (*shape)[i]; Py_buffer view; @@ -1506,18 +1790,22 @@ inline memoryview memoryview::from_buffer( throw error_already_set(); return memoryview(object(obj, stolen_t{})); } -#endif // DOXYGEN_SHOULD_SKIP_THIS +/// @endcond /// @} pytypes /// \addtogroup python_builtins /// @{ + +/// Get the length of a Python object. inline size_t len(handle h) { ssize_t result = PyObject_Length(h.ptr()); if (result < 0) - pybind11_fail("Unable to compute length of object"); + throw error_already_set(); return (size_t) result; } +/// Get the length hint of a Python object. +/// Returns 0 when this cannot be determined. inline size_t len_hint(handle h) { #if PY_VERSION_HEX >= 0x03040000 ssize_t result = PyObject_LengthHint(h.ptr(), 0); @@ -1580,8 +1868,7 @@ template str_attr_accessor object_api::doc() const { return attr("__doc__"); } template -PYBIND11_DEPRECATED("Use py::type::of(h) instead of h.get_type()") -handle object_api::get_type() const { return type::handle_of(*this); } +handle object_api::get_type() const { return type::handle_of(derived()); } template bool object_api::rich_compare(object_api const &other, int value) const { diff --git a/wrap/pybind11/include/pybind11/stl.h b/wrap/pybind11/include/pybind11/stl.h index 721bb669f..430349482 100644 --- a/wrap/pybind11/include/pybind11/stl.h +++ b/wrap/pybind11/include/pybind11/stl.h @@ -9,6 +9,7 @@ #pragma once +#include "detail/common.h" #include "pybind11.h" #include #include @@ -19,33 +20,15 @@ #include #include -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant +// See `detail/common.h` for implementation of these guards. +#if defined(PYBIND11_HAS_OPTIONAL) +# include +#elif defined(PYBIND11_HAS_EXP_OPTIONAL) +# include #endif -#ifdef __has_include -// std::optional (but including it in c++14 mode isn't allowed) -# if defined(PYBIND11_CPP17) && __has_include() -# include -# define PYBIND11_HAS_OPTIONAL 1 -# endif -// std::experimental::optional (but not allowed in c++11 mode) -# if defined(PYBIND11_CPP14) && (__has_include() && \ - !__has_include()) -# include -# define PYBIND11_HAS_EXP_OPTIONAL 1 -# endif -// std::variant -# if defined(PYBIND11_CPP17) && __has_include() -# include -# define PYBIND11_HAS_VARIANT 1 -# endif -#elif defined(_MSC_VER) && defined(PYBIND11_CPP17) -# include +#if defined(PYBIND11_HAS_VARIANT) # include -# define PYBIND11_HAS_OPTIONAL 1 -# define PYBIND11_HAS_VARIANT 1 #endif PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) @@ -95,7 +78,7 @@ template struct set_caster { return s.release(); } - PYBIND11_TYPE_CASTER(type, _("Set[") + key_conv::name + _("]")); + PYBIND11_TYPE_CASTER(type, const_name("Set[") + key_conv::name + const_name("]")); }; template struct map_caster { @@ -137,14 +120,14 @@ template struct map_caster { return d.release(); } - PYBIND11_TYPE_CASTER(Type, _("Dict[") + key_conv::name + _(", ") + value_conv::name + _("]")); + PYBIND11_TYPE_CASTER(Type, const_name("Dict[") + key_conv::name + const_name(", ") + value_conv::name + const_name("]")); }; template struct list_caster { using value_conv = make_caster; bool load(handle src, bool convert) { - if (!isinstance(src) || isinstance(src)) + if (!isinstance(src) || isinstance(src) || isinstance(src)) return false; auto s = reinterpret_borrow(src); value.clear(); @@ -159,10 +142,13 @@ template struct list_caster { } private: - template ().reserve(0)), void>::value, int> = 0> - void reserve_maybe(sequence s, Type *) { value.reserve(s.size()); } - void reserve_maybe(sequence, void *) { } + template < + typename T = Type, + enable_if_t().reserve(0)), void>::value, int> = 0> + void reserve_maybe(const sequence &s, Type *) { + value.reserve(s.size()); + } + void reserve_maybe(const sequence &, void *) {} public: template @@ -170,17 +156,17 @@ public: if (!std::is_lvalue_reference::value) policy = return_value_policy_override::policy(policy); list l(src.size()); - size_t index = 0; + ssize_t index = 0; for (auto &&value : src) { auto value_ = reinterpret_steal(value_conv::cast(forward_like(value), policy, parent)); if (!value_) return handle(); - PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference + PyList_SET_ITEM(l.ptr(), index++, value_.release().ptr()); // steals a reference } return l.release(); } - PYBIND11_TYPE_CASTER(Type, _("List[") + value_conv::name + _("]")); + PYBIND11_TYPE_CASTER(Type, const_name("List[") + value_conv::name + const_name("]")); }; template struct type_caster> @@ -227,17 +213,17 @@ public: template static handle cast(T &&src, return_value_policy policy, handle parent) { list l(src.size()); - size_t index = 0; + ssize_t index = 0; for (auto &&value : src) { auto value_ = reinterpret_steal(value_conv::cast(forward_like(value), policy, parent)); if (!value_) return handle(); - PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference + PyList_SET_ITEM(l.ptr(), index++, value_.release().ptr()); // steals a reference } return l.release(); } - PYBIND11_TYPE_CASTER(ArrayType, _("List[") + value_conv::name + _(_(""), _("[") + _() + _("]")) + _("]")); + PYBIND11_TYPE_CASTER(ArrayType, const_name("List[") + value_conv::name + const_name(const_name(""), const_name("[") + const_name() + const_name("]")) + const_name("]")); }; template struct type_caster> @@ -259,34 +245,35 @@ template , Key, Value> { }; // This type caster is intended to be used for std::optional and std::experimental::optional -template struct optional_caster { - using value_conv = make_caster; +template struct optional_caster { + using value_conv = make_caster; - template - static handle cast(T_ &&src, return_value_policy policy, handle parent) { + template + static handle cast(T &&src, return_value_policy policy, handle parent) { if (!src) return none().inc_ref(); if (!std::is_lvalue_reference::value) { - policy = return_value_policy_override::policy(policy); + policy = return_value_policy_override::policy(policy); } - return value_conv::cast(*std::forward(src), policy, parent); + return value_conv::cast(*std::forward(src), policy, parent); } bool load(handle src, bool convert) { if (!src) { return false; - } else if (src.is_none()) { + } + if (src.is_none()) { return true; // default-constructed value is already empty } value_conv inner_caster; if (!inner_caster.load(src, convert)) return false; - value.emplace(cast_op(std::move(inner_caster))); + value.emplace(cast_op(std::move(inner_caster))); return true; } - PYBIND11_TYPE_CASTER(T, _("Optional[") + value_conv::name + _("]")); + PYBIND11_TYPE_CASTER(Type, const_name("Optional[") + value_conv::name + const_name("]")); }; #if defined(PYBIND11_HAS_OPTIONAL) @@ -366,7 +353,7 @@ struct variant_caster> { } using Type = V; - PYBIND11_TYPE_CASTER(Type, _("Union[") + detail::concat(make_caster::name...) + _("]")); + PYBIND11_TYPE_CASTER(Type, const_name("Union[") + detail::concat(make_caster::name...) + const_name("]")); }; #if defined(PYBIND11_HAS_VARIANT) @@ -377,12 +364,12 @@ struct type_caster> : variant_caster> { PYBIND11_NAMESPACE_END(detail) inline std::ostream &operator<<(std::ostream &os, const handle &obj) { +#ifdef PYBIND11_HAS_STRING_VIEW + os << str(obj).cast(); +#else os << (std::string) str(obj); +#endif return os; } PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE) - -#if defined(_MSC_VER) -#pragma warning(pop) -#endif diff --git a/wrap/pybind11/include/pybind11/stl/filesystem.h b/wrap/pybind11/include/pybind11/stl/filesystem.h new file mode 100644 index 000000000..a9a6c8512 --- /dev/null +++ b/wrap/pybind11/include/pybind11/stl/filesystem.h @@ -0,0 +1,103 @@ +// Copyright (c) 2021 The Pybind Development Team. +// All rights reserved. Use of this source code is governed by a +// BSD-style license that can be found in the LICENSE file. + +#pragma once + +#include "../cast.h" +#include "../pybind11.h" +#include "../pytypes.h" + +#include "../detail/common.h" +#include "../detail/descr.h" + +#include + +#ifdef __has_include +# if defined(PYBIND11_CPP17) && __has_include() && \ + PY_VERSION_HEX >= 0x03060000 +# include +# define PYBIND11_HAS_FILESYSTEM 1 +# endif +#endif + +#if !defined(PYBIND11_HAS_FILESYSTEM) && !defined(PYBIND11_HAS_FILESYSTEM_IS_OPTIONAL) +# error \ + "#include is not available. (Use -DPYBIND11_HAS_FILESYSTEM_IS_OPTIONAL to ignore.)" +#endif + +PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +PYBIND11_NAMESPACE_BEGIN(detail) + +#if defined(PYBIND11_HAS_FILESYSTEM) +template struct path_caster { + +private: + static PyObject* unicode_from_fs_native(const std::string& w) { +#if !defined(PYPY_VERSION) + return PyUnicode_DecodeFSDefaultAndSize(w.c_str(), ssize_t(w.size())); +#else + // PyPy mistakenly declares the first parameter as non-const. + return PyUnicode_DecodeFSDefaultAndSize( + const_cast(w.c_str()), ssize_t(w.size())); +#endif + } + + static PyObject* unicode_from_fs_native(const std::wstring& w) { + return PyUnicode_FromWideChar(w.c_str(), ssize_t(w.size())); + } + +public: + static handle cast(const T& path, return_value_policy, handle) { + if (auto py_str = unicode_from_fs_native(path.native())) { + return module_::import("pathlib").attr("Path")(reinterpret_steal(py_str)) + .release(); + } + return nullptr; + } + + bool load(handle handle, bool) { + // PyUnicode_FSConverter and PyUnicode_FSDecoder normally take care of + // calling PyOS_FSPath themselves, but that's broken on PyPy (PyPy + // issue #3168) so we do it ourselves instead. + PyObject* buf = PyOS_FSPath(handle.ptr()); + if (!buf) { + PyErr_Clear(); + return false; + } + PyObject* native = nullptr; + if constexpr (std::is_same_v) { + if (PyUnicode_FSConverter(buf, &native) != 0) { + if (auto c_str = PyBytes_AsString(native)) { + // AsString returns a pointer to the internal buffer, which + // must not be free'd. + value = c_str; + } + } + } else if constexpr (std::is_same_v) { + if (PyUnicode_FSDecoder(buf, &native) != 0) { + if (auto c_str = PyUnicode_AsWideCharString(native, nullptr)) { + // AsWideCharString returns a new string that must be free'd. + value = c_str; // Copies the string. + PyMem_Free(c_str); + } + } + } + Py_XDECREF(native); + Py_DECREF(buf); + if (PyErr_Occurred()) { + PyErr_Clear(); + return false; + } + return true; + } + + PYBIND11_TYPE_CASTER(T, const_name("os.PathLike")); +}; + +template<> struct type_caster + : public path_caster {}; +#endif // PYBIND11_HAS_FILESYSTEM + +PYBIND11_NAMESPACE_END(detail) +PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/wrap/pybind11/include/pybind11/stl_bind.h b/wrap/pybind11/include/pybind11/stl_bind.h index 9d8ed0c82..050be83cc 100644 --- a/wrap/pybind11/include/pybind11/stl_bind.h +++ b/wrap/pybind11/include/pybind11/stl_bind.h @@ -128,11 +128,11 @@ void vector_modifiers(enable_if_t(new Vector()); v->reserve(len_hint(it)); for (handle h : it) - v->push_back(h.cast()); + v->push_back(h.cast()); return v.release(); })); @@ -151,27 +151,28 @@ void vector_modifiers(enable_if_t()); - } - } catch (const cast_error &) { - v.erase(v.begin() + static_cast(old_size), v.end()); - try { - v.shrink_to_fit(); - } catch (const std::exception &) { - // Do nothing - } - throw; - } - }, - arg("L"), - "Extend the list by appending all the items in the given list" - ); + cl.def( + "extend", + [](Vector &v, const iterable &it) { + const size_t old_size = v.size(); + v.reserve(old_size + len_hint(it)); + try { + for (handle h : it) { + v.push_back(h.cast()); + } + } catch (const cast_error &) { + v.erase(v.begin() + static_cast(old_size), + v.end()); + try { + v.shrink_to_fit(); + } catch (const std::exception &) { + // Do nothing + } + throw; + } + }, + arg("L"), + "Extend the list by appending all the items in the given list"); cl.def("insert", [](Vector &v, DiffType i, const T &x) { @@ -190,7 +191,7 @@ void vector_modifiers(enable_if_t Vector * { - size_t start, stop, step, slicelength; + size_t start = 0, stop = 0, step = 0, slicelength = 0; if (!slice.compute(v.size(), &start, &stop, &step, &slicelength)) throw error_already_set(); @@ -233,12 +235,12 @@ void vector_modifiers(enable_if_t), @@ -375,10 +375,20 @@ struct vector_has_data_and_format : std::false_type {}; template struct vector_has_data_and_format::format(), std::declval().data()), typename Vector::value_type*>::value>> : std::true_type {}; +// [workaround(intel)] Separate function required here +// Workaround as the Intel compiler does not compile the enable_if_t part below +// (tested with icc (ICC) 2021.1 Beta 20200827) +template +constexpr bool args_any_are_buffer() { + return detail::any_of...>::value; +} + +// [workaround(intel)] Separate function required here +// [workaround(msvc)] Can't use constexpr bool in return type + // Add the buffer interface to a vector template -enable_if_t...>::value> -vector_buffer(Class_& cl) { +void vector_buffer_impl(Class_& cl, std::true_type) { using T = typename Vector::value_type; static_assert(vector_has_data_and_format::value, "There is not an appropriate format descriptor for this vector"); @@ -390,7 +400,7 @@ vector_buffer(Class_& cl) { return buffer_info(v.data(), static_cast(sizeof(T)), format_descriptor::format(), 1, {v.size()}, {sizeof(T)}); }); - cl.def(init([](buffer buf) { + cl.def(init([](const buffer &buf) { auto info = buf.request(); if (info.ndim != 1 || info.strides[0] % static_cast(sizeof(T))) throw type_error("Only valid 1D buffers can be copied to a vector"); @@ -403,20 +413,24 @@ vector_buffer(Class_& cl) { if (step == 1) { return Vector(p, end); } - else { - Vector vec; - vec.reserve((size_t) info.shape[0]); - for (; p != end; p += step) - vec.push_back(*p); - return vec; - } + Vector vec; + vec.reserve((size_t) info.shape[0]); + for (; p != end; p += step) + vec.push_back(*p); + return vec; + })); return; } template -enable_if_t...>::value> vector_buffer(Class_&) {} +void vector_buffer_impl(Class_&, std::false_type) {} + +template +void vector_buffer(Class_& cl) { + vector_buffer_impl(cl, detail::any_of...>{}); +} PYBIND11_NAMESPACE_END(detail) @@ -581,6 +595,23 @@ template auto map_if_insertion_operator(Class_ & ); } +template +struct keys_view +{ + Map ↦ +}; + +template +struct values_view +{ + Map ↦ +}; + +template +struct items_view +{ + Map ↦ +}; PYBIND11_NAMESPACE_END(detail) @@ -588,6 +619,9 @@ template , typename... class_ bind_map(handle scope, const std::string &name, Args&&... args) { using KeyType = typename Map::key_type; using MappedType = typename Map::mapped_type; + using KeysView = detail::keys_view; + using ValuesView = detail::values_view; + using ItemsView = detail::items_view; using Class_ = class_; // If either type is a non-module-local bound type then make the map binding non-local as well; @@ -601,6 +635,12 @@ class_ bind_map(handle scope, const std::string &name, Args&&. } Class_ cl(scope, name.c_str(), pybind11::module_local(local), std::forward(args)...); + class_ keys_view( + scope, ("KeysView[" + name + "]").c_str(), pybind11::module_local(local)); + class_ values_view( + scope, ("ValuesView[" + name + "]").c_str(), pybind11::module_local(local)); + class_ items_view( + scope, ("ItemsView[" + name + "]").c_str(), pybind11::module_local(local)); cl.def(init<>()); @@ -614,12 +654,22 @@ class_ bind_map(handle scope, const std::string &name, Args&&. cl.def("__iter__", [](Map &m) { return make_key_iterator(m.begin(), m.end()); }, - keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */ + keep_alive<0, 1>() /* Essential: keep map alive while iterator exists */ + ); + + cl.def("keys", + [](Map &m) { return KeysView{m}; }, + keep_alive<0, 1>() /* Essential: keep map alive while view exists */ + ); + + cl.def("values", + [](Map &m) { return ValuesView{m}; }, + keep_alive<0, 1>() /* Essential: keep map alive while view exists */ ); cl.def("items", - [](Map &m) { return make_iterator(m.begin(), m.end()); }, - keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */ + [](Map &m) { return ItemsView{m}; }, + keep_alive<0, 1>() /* Essential: keep map alive while view exists */ ); cl.def("__getitem__", @@ -640,6 +690,8 @@ class_ bind_map(handle scope, const std::string &name, Args&&. return true; } ); + // Fallback for when the object is not of the key type + cl.def("__contains__", [](Map &, const object &) -> bool { return false; }); // Assignment provided only if the type is copyable detail::map_assignment(cl); @@ -655,6 +707,40 @@ class_ bind_map(handle scope, const std::string &name, Args&&. cl.def("__len__", &Map::size); + keys_view.def("__len__", [](KeysView &view) { return view.map.size(); }); + keys_view.def("__iter__", + [](KeysView &view) { + return make_key_iterator(view.map.begin(), view.map.end()); + }, + keep_alive<0, 1>() /* Essential: keep view alive while iterator exists */ + ); + keys_view.def("__contains__", + [](KeysView &view, const KeyType &k) -> bool { + auto it = view.map.find(k); + if (it == view.map.end()) + return false; + return true; + } + ); + // Fallback for when the object is not of the key type + keys_view.def("__contains__", [](KeysView &, const object &) -> bool { return false; }); + + values_view.def("__len__", [](ValuesView &view) { return view.map.size(); }); + values_view.def("__iter__", + [](ValuesView &view) { + return make_value_iterator(view.map.begin(), view.map.end()); + }, + keep_alive<0, 1>() /* Essential: keep view alive while iterator exists */ + ); + + items_view.def("__len__", [](ItemsView &view) { return view.map.size(); }); + items_view.def("__iter__", + [](ItemsView &view) { + return make_iterator(view.map.begin(), view.map.end()); + }, + keep_alive<0, 1>() /* Essential: keep view alive while iterator exists */ + ); + return cl; } diff --git a/wrap/pybind11/noxfile.py b/wrap/pybind11/noxfile.py new file mode 100644 index 000000000..4adffac2e --- /dev/null +++ b/wrap/pybind11/noxfile.py @@ -0,0 +1,93 @@ +import nox + +nox.options.sessions = ["lint", "tests", "tests_packaging"] + +PYTHON_VERISONS = ["2.7", "3.5", "3.6", "3.7", "3.8", "3.9", "3.10", "3.11"] + + +@nox.session(reuse_venv=True) +def lint(session: nox.Session) -> None: + """ + Lint the codebase (except for clang-format/tidy). + """ + session.install("pre-commit") + session.run("pre-commit", "run", "-a") + + +@nox.session(python=PYTHON_VERISONS) +def tests(session: nox.Session) -> None: + """ + Run the tests (requires a compiler). + """ + tmpdir = session.create_tmp() + session.install("cmake") + session.install("-r", "tests/requirements.txt") + session.run( + "cmake", + "-S", + ".", + "-B", + tmpdir, + "-DPYBIND11_WERROR=ON", + "-DDOWNLOAD_CATCH=ON", + "-DDOWNLOAD_EIGEN=ON", + *session.posargs + ) + session.run("cmake", "--build", tmpdir) + session.run("cmake", "--build", tmpdir, "--config=Release", "--target", "check") + + +@nox.session +def tests_packaging(session: nox.Session) -> None: + """ + Run the packaging tests. + """ + + session.install("-r", "tests/requirements.txt", "--prefer-binary") + session.run("pytest", "tests/extra_python_package") + + +@nox.session(reuse_venv=True) +def docs(session: nox.Session) -> None: + """ + Build the docs. Pass "serve" to serve. + """ + + session.install("-r", "docs/requirements.txt") + session.chdir("docs") + + if "pdf" in session.posargs: + session.run("sphinx-build", "-b", "latexpdf", ".", "_build") + return + + session.run("sphinx-build", "-b", "html", ".", "_build") + + if "serve" in session.posargs: + session.log("Launching docs at http://localhost:8000/ - use Ctrl-C to quit") + session.run("python", "-m", "http.server", "8000", "-d", "_build/html") + elif session.posargs: + session.error("Unsupported argument to docs") + + +@nox.session(reuse_venv=True) +def make_changelog(session: nox.Session) -> None: + """ + Inspect the closed issues and make entries for a changelog. + """ + session.install("ghapi", "rich") + session.run("python", "tools/make_changelog.py") + + +@nox.session(reuse_venv=True) +def build(session: nox.Session) -> None: + """ + Build SDists and wheels. + """ + + session.install("build") + session.log("Building normal files") + session.run("python", "-m", "build", *session.posargs) + session.log("Building pybind11-global files (PYBIND11_GLOBAL_SDIST=1)") + session.run( + "python", "-m", "build", *session.posargs, env={"PYBIND11_GLOBAL_SDIST": "1"} + ) diff --git a/wrap/pybind11/pybind11/__init__.py b/wrap/pybind11/pybind11/__init__.py index ad6542089..64e999ba0 100644 --- a/wrap/pybind11/pybind11/__init__.py +++ b/wrap/pybind11/pybind11/__init__.py @@ -1,8 +1,7 @@ # -*- coding: utf-8 -*- -from ._version import version_info, __version__ -from .commands import get_include, get_cmake_dir - +from ._version import __version__, version_info +from .commands import get_cmake_dir, get_include __all__ = ( "version_info", diff --git a/wrap/pybind11/pybind11/__main__.py b/wrap/pybind11/pybind11/__main__.py index f4d543783..3235747be 100644 --- a/wrap/pybind11/pybind11/__main__.py +++ b/wrap/pybind11/pybind11/__main__.py @@ -5,10 +5,11 @@ import argparse import sys import sysconfig -from .commands import get_include, get_cmake_dir +from .commands import get_cmake_dir, get_include def print_includes(): + # type: () -> None dirs = [ sysconfig.get_path("include"), sysconfig.get_path("platinclude"), @@ -18,13 +19,15 @@ def print_includes(): # Make unique but preserve order unique_dirs = [] for d in dirs: - if d not in unique_dirs: + if d and d not in unique_dirs: unique_dirs.append(d) print(" ".join("-I" + d for d in unique_dirs)) def main(): + # type: () -> None + parser = argparse.ArgumentParser() parser.add_argument( "--includes", diff --git a/wrap/pybind11/pybind11/_version.py b/wrap/pybind11/pybind11/_version.py index ca84c262c..9d39b77a4 100644 --- a/wrap/pybind11/pybind11/_version.py +++ b/wrap/pybind11/pybind11/_version.py @@ -8,5 +8,5 @@ def _to_int(s): return s -__version__ = "2.6.0.dev1" +__version__ = "2.9.1" version_info = tuple(_to_int(s) for s in __version__.split(".")) diff --git a/wrap/pybind11/pybind11/_version.pyi b/wrap/pybind11/pybind11/_version.pyi new file mode 100644 index 000000000..d45e5dc90 --- /dev/null +++ b/wrap/pybind11/pybind11/_version.pyi @@ -0,0 +1,6 @@ +from typing import Tuple, Union + +def _to_int(s: str) -> Union[int, str]: ... + +__version__: str +version_info: Tuple[Union[int, str], ...] diff --git a/wrap/pybind11/pybind11/commands.py b/wrap/pybind11/pybind11/commands.py index fa7eac3cc..11f81d2d6 100644 --- a/wrap/pybind11/pybind11/commands.py +++ b/wrap/pybind11/pybind11/commands.py @@ -1,17 +1,18 @@ # -*- coding: utf-8 -*- import os - DIR = os.path.abspath(os.path.dirname(__file__)) def get_include(user=False): + # type: (bool) -> str installed_path = os.path.join(DIR, "include") source_path = os.path.join(os.path.dirname(DIR), "include") return installed_path if os.path.exists(installed_path) else source_path def get_cmake_dir(): + # type: () -> str cmake_installed_path = os.path.join(DIR, "share", "cmake", "pybind11") if os.path.exists(cmake_installed_path): return cmake_installed_path diff --git a/wrap/pybind11/pybind11/py.typed b/wrap/pybind11/pybind11/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/wrap/pybind11/pybind11/setup_helpers.py b/wrap/pybind11/pybind11/setup_helpers.py index 041e22689..5b7c9aab1 100644 --- a/wrap/pybind11/pybind11/setup_helpers.py +++ b/wrap/pybind11/pybind11/setup_helpers.py @@ -33,25 +33,34 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """ +# IMPORTANT: If you change this file in the pybind11 repo, also review +# setup_helpers.pyi for matching changes. +# +# If you copy this file in, you don't +# need the .pyi file; it's just an interface file for static type checkers. + import contextlib import os +import platform +import shlex import shutil import sys +import sysconfig import tempfile import threading import warnings try: - from setuptools.command.build_ext import build_ext as _build_ext from setuptools import Extension as _Extension + from setuptools.command.build_ext import build_ext as _build_ext except ImportError: from distutils.command.build_ext import build_ext as _build_ext from distutils.extension import Extension as _Extension +import distutils.ccompiler import distutils.errors - -WIN = sys.platform.startswith("win32") +WIN = sys.platform.startswith("win32") and "mingw" not in sysconfig.get_platform() PY2 = sys.version_info[0] < 3 MACOS = sys.platform.startswith("darwin") STD_TMPL = "/std:c++{}" if WIN else "-std=c++{}" @@ -76,7 +85,7 @@ class Pybind11Extension(_Extension): * ``stdlib=libc++`` on macOS * ``visibility=hidden`` and ``-g0`` on Unix - Finally, you can set ``cxx_std`` via constructor or afterwords to enable + Finally, you can set ``cxx_std`` via constructor or afterwards to enable flags for C++ std, and a few extra helper flags related to the C++ standard level. It is _highly_ recommended you either set this, or use the provided ``build_ext``, which will search for the highest supported extension for @@ -91,15 +100,14 @@ class Pybind11Extension(_Extension): this is an ugly old-style class due to Distutils. """ - def _add_cflags(self, *flags): - for flag in flags: - if flag not in self.extra_compile_args: - self.extra_compile_args.append(flag) + # flags are prepended, so that they can be further overridden, e.g. by + # ``extra_compile_args=["-g"]``. - def _add_lflags(self, *flags): - for flag in flags: - if flag not in self.extra_compile_args: - self.extra_link_args.append(flag) + def _add_cflags(self, flags): + self.extra_compile_args[:0] = flags + + def _add_ldflags(self, flags): + self.extra_link_args[:0] = flags def __init__(self, *args, **kwargs): @@ -131,13 +139,22 @@ class Pybind11Extension(_Extension): # Have to use the accessor manually to support Python 2 distutils Pybind11Extension.cxx_std.__set__(self, cxx_std) + cflags = [] + ldflags = [] if WIN: - self._add_cflags("/EHsc", "/bigobj") + cflags += ["/EHsc", "/bigobj"] else: - self._add_cflags("-fvisibility=hidden", "-g0") + cflags += ["-fvisibility=hidden"] + env_cflags = os.environ.get("CFLAGS", "") + env_cppflags = os.environ.get("CPPFLAGS", "") + c_cpp_flags = shlex.split(env_cflags) + shlex.split(env_cppflags) + if not any(opt.startswith("-g") for opt in c_cpp_flags): + cflags += ["-g0"] if MACOS: - self._add_cflags("-stdlib=libc++") - self._add_lflags("-stdlib=libc++") + cflags += ["-stdlib=libc++"] + ldflags += ["-stdlib=libc++"] + self._add_cflags(cflags) + self._add_ldflags(ldflags) @property def cxx_std(self): @@ -156,7 +173,8 @@ class Pybind11Extension(_Extension): if self._cxx_level: warnings.warn("You cannot safely change the cxx_level after setting it!") - # MSVC 2015 Update 3 and later only have 14 (and later 17) modes + # MSVC 2015 Update 3 and later only have 14 (and later 17) modes, so + # force a valid flag here. if WIN and level == 11: level = 14 @@ -165,19 +183,34 @@ class Pybind11Extension(_Extension): if not level: return - self.extra_compile_args.append(STD_TMPL.format(level)) + cflags = [STD_TMPL.format(level)] + ldflags = [] if MACOS and "MACOSX_DEPLOYMENT_TARGET" not in os.environ: - # C++17 requires a higher min version of macOS - macosx_min = "-mmacosx-version-min=" + ("10.9" if level < 17 else "10.14") - self.extra_compile_args.append(macosx_min) - self.extra_link_args.append(macosx_min) + # C++17 requires a higher min version of macOS. An earlier version + # (10.12 or 10.13) can be set manually via environment variable if + # you are careful in your feature usage, but 10.14 is the safest + # setting for general use. However, never set higher than the + # current macOS version! + current_macos = tuple(int(x) for x in platform.mac_ver()[0].split(".")[:2]) + desired_macos = (10, 9) if level < 17 else (10, 14) + macos_string = ".".join(str(x) for x in min(current_macos, desired_macos)) + macosx_min = "-mmacosx-version-min=" + macos_string + cflags += [macosx_min] + ldflags += [macosx_min] if PY2: - if level >= 17: - self.extra_compile_args.append("/wd503" if WIN else "-Wno-register") - elif not WIN and level >= 14: - self.extra_compile_args.append("-Wno-deprecated-register") + if WIN: + # Will be ignored on MSVC 2015, where C++17 is not supported so + # this flag is not valid. + cflags += ["/wd5033"] + elif level >= 17: + cflags += ["-Wno-register"] + elif level >= 14: + cflags += ["-Wno-deprecated-register"] + + self._add_cflags(cflags) + self._add_ldflags(ldflags) # Just in case someone clever tries to multithread @@ -212,7 +245,8 @@ def has_flag(compiler, flag): with tmp_chdir(): fname = "flagcheck.cpp" with open(fname, "w") as f: - f.write("int main (int argc, char **argv) { return 0; }") + # Don't trigger -Wunused-parameter. + f.write("int main (int, char **) { return 0; }") try: compiler.compile([fname], extra_postargs=[flag]) @@ -227,9 +261,12 @@ cpp_flag_cache = None def auto_cpp_level(compiler): """ - Return the max supported C++ std level (17, 14, or 11). + Return the max supported C++ std level (17, 14, or 11). Returns latest on Windows. """ + if WIN: + return "latest" + global cpp_flag_cache # If this has been previously calculated with the same args, return that @@ -237,7 +274,7 @@ def auto_cpp_level(compiler): if cpp_flag_cache: return cpp_flag_cache - levels = [17, 14] + ([] if WIN else [11]) + levels = [17, 14, 11] for level in levels: if has_flag(compiler, STD_TMPL.format(level)): @@ -252,7 +289,8 @@ def auto_cpp_level(compiler): class build_ext(_build_ext): # noqa: N801 """ Customized build_ext that allows an auto-search for the highest supported - C++ level for Pybind11Extension. + C++ level for Pybind11Extension. This is only needed for the auto-search + for now, and is completely optional otherwise. """ def build_extensions(self): @@ -268,3 +306,189 @@ class build_ext(_build_ext): # noqa: N801 # Python 2 doesn't allow super here, since distutils uses old-style # classes! _build_ext.build_extensions(self) + + +def intree_extensions(paths, package_dir=None): + """ + Generate Pybind11Extensions from source files directly located in a Python + source tree. + + ``package_dir`` behaves as in ``setuptools.setup``. If unset, the Python + package root parent is determined as the first parent directory that does + not contain an ``__init__.py`` file. + """ + exts = [] + for path in paths: + if package_dir is None: + parent, _ = os.path.split(path) + while os.path.exists(os.path.join(parent, "__init__.py")): + parent, _ = os.path.split(parent) + relname, _ = os.path.splitext(os.path.relpath(path, parent)) + qualified_name = relname.replace(os.path.sep, ".") + exts.append(Pybind11Extension(qualified_name, [path])) + else: + found = False + for prefix, parent in package_dir.items(): + if path.startswith(parent): + found = True + relname, _ = os.path.splitext(os.path.relpath(path, parent)) + qualified_name = relname.replace(os.path.sep, ".") + if prefix: + qualified_name = prefix + "." + qualified_name + exts.append(Pybind11Extension(qualified_name, [path])) + if not found: + raise ValueError( + "path {} is not a child of any of the directories listed " + "in 'package_dir' ({})".format(path, package_dir) + ) + return exts + + +def naive_recompile(obj, src): + """ + This will recompile only if the source file changes. It does not check + header files, so a more advanced function or Ccache is better if you have + editable header files in your package. + """ + return os.stat(obj).st_mtime < os.stat(src).st_mtime + + +def no_recompile(obg, src): + """ + This is the safest but slowest choice (and is the default) - will always + recompile sources. + """ + return True + + +# Optional parallel compile utility +# inspired by: http://stackoverflow.com/questions/11013851/speeding-up-build-process-with-distutils +# and: https://github.com/tbenthompson/cppimport/blob/stable/cppimport/build_module.py +# and NumPy's parallel distutils module: +# https://github.com/numpy/numpy/blob/master/numpy/distutils/ccompiler.py +class ParallelCompile(object): + """ + Make a parallel compile function. Inspired by + numpy.distutils.ccompiler.CCompiler_compile and cppimport. + + This takes several arguments that allow you to customize the compile + function created: + + envvar: + Set an environment variable to control the compilation threads, like + NPY_NUM_BUILD_JOBS + default: + 0 will automatically multithread, or 1 will only multithread if the + envvar is set. + max: + The limit for automatic multithreading if non-zero + needs_recompile: + A function of (obj, src) that returns True when recompile is needed. No + effect in isolated mode; use ccache instead, see + https://github.com/matplotlib/matplotlib/issues/1507/ + + To use:: + + ParallelCompile("NPY_NUM_BUILD_JOBS").install() + + or:: + + with ParallelCompile("NPY_NUM_BUILD_JOBS"): + setup(...) + + By default, this assumes all files need to be recompiled. A smarter + function can be provided via needs_recompile. If the output has not yet + been generated, the compile will always run, and this function is not + called. + """ + + __slots__ = ("envvar", "default", "max", "_old", "needs_recompile") + + def __init__(self, envvar=None, default=0, max=0, needs_recompile=no_recompile): + self.envvar = envvar + self.default = default + self.max = max + self.needs_recompile = needs_recompile + self._old = [] + + def function(self): + """ + Builds a function object usable as distutils.ccompiler.CCompiler.compile. + """ + + def compile_function( + compiler, + sources, + output_dir=None, + macros=None, + include_dirs=None, + debug=0, + extra_preargs=None, + extra_postargs=None, + depends=None, + ): + + # These lines are directly from distutils.ccompiler.CCompiler + macros, objects, extra_postargs, pp_opts, build = compiler._setup_compile( + output_dir, macros, include_dirs, sources, depends, extra_postargs + ) + cc_args = compiler._get_cc_args(pp_opts, debug, extra_preargs) + + # The number of threads; start with default. + threads = self.default + + # Determine the number of compilation threads, unless set by an environment variable. + if self.envvar is not None: + threads = int(os.environ.get(self.envvar, self.default)) + + def _single_compile(obj): + try: + src, ext = build[obj] + except KeyError: + return + + if not os.path.exists(obj) or self.needs_recompile(obj, src): + compiler._compile(obj, src, ext, cc_args, extra_postargs, pp_opts) + + try: + # Importing .synchronize checks for platforms that have some multiprocessing + # capabilities but lack semaphores, such as AWS Lambda and Android Termux. + import multiprocessing.synchronize + from multiprocessing.pool import ThreadPool + except ImportError: + threads = 1 + + if threads == 0: + try: + threads = multiprocessing.cpu_count() + threads = self.max if self.max and self.max < threads else threads + except NotImplementedError: + threads = 1 + + if threads > 1: + pool = ThreadPool(threads) + # In Python 2, ThreadPool can't be used as a context manager. + # Once we are no longer supporting it, this can be 'with pool:' + try: + for _ in pool.imap_unordered(_single_compile, objects): + pass + finally: + pool.terminate() + else: + for ob in objects: + _single_compile(ob) + + return objects + + return compile_function + + def install(self): + distutils.ccompiler.CCompiler.compile = self.function() + return self + + def __enter__(self): + self._old.append(distutils.ccompiler.CCompiler.compile) + return self.install() + + def __exit__(self, *args): + distutils.ccompiler.CCompiler.compile = self._old.pop() diff --git a/wrap/pybind11/pybind11/setup_helpers.pyi b/wrap/pybind11/pybind11/setup_helpers.pyi new file mode 100644 index 000000000..074744eb8 --- /dev/null +++ b/wrap/pybind11/pybind11/setup_helpers.pyi @@ -0,0 +1,63 @@ +# IMPORTANT: Should stay in sync with setup_helpers.py (mostly checked by CI / +# pre-commit). + +import contextlib +import distutils.ccompiler +from distutils.command.build_ext import build_ext as _build_ext # type: ignore +from distutils.extension import Extension as _Extension +from types import TracebackType +from typing import Any, Callable, Dict, Iterator, List, Optional, Type, TypeVar, Union + +WIN: bool +PY2: bool +MACOS: bool +STD_TMPL: str + +class Pybind11Extension(_Extension): + def _add_cflags(self, *flags: str) -> None: ... + def _add_lflags(self, *flags: str) -> None: ... + def __init__( + self, *args: Any, cxx_std: int = 0, language: str = "c++", **kwargs: Any + ) -> None: ... + @property + def cxx_std(self) -> int: ... + @cxx_std.setter + def cxx_std(self, level: int) -> None: ... + +@contextlib.contextmanager +def tmp_chdir() -> Iterator[str]: ... +def has_flag(compiler: distutils.ccompiler.CCompiler, flag: str) -> bool: ... +def auto_cpp_level(compiler: distutils.ccompiler.CCompiler) -> Union[int, str]: ... + +class build_ext(_build_ext): # type: ignore + def build_extensions(self) -> None: ... + +def intree_extensions( + paths: Iterator[str], package_dir: Optional[Dict[str, str]] = None +) -> List[Pybind11Extension]: ... +def no_recompile(obj: str, src: str) -> bool: ... +def naive_recompile(obj: str, src: str) -> bool: ... + +T = TypeVar("T", bound="ParallelCompile") + +class ParallelCompile: + envvar: Optional[str] + default: int + max: int + needs_recompile: Callable[[str, str], bool] + def __init__( + self, + envvar: Optional[str] = None, + default: int = 0, + max: int = 0, + needs_recompile: Callable[[str, str], bool] = no_recompile, + ) -> None: ... + def function(self) -> Any: ... + def install(self: T) -> T: ... + def __enter__(self: T) -> T: ... + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: ... diff --git a/wrap/pybind11/pyproject.toml b/wrap/pybind11/pyproject.toml index 3bab1c1a2..7d7a1c821 100644 --- a/wrap/pybind11/pyproject.toml +++ b/wrap/pybind11/pyproject.toml @@ -1,3 +1,41 @@ [build-system] -requires = ["setuptools", "wheel", "cmake==3.18.0", "ninja"] +requires = ["setuptools>=42", "wheel", "cmake>=3.18", "ninja"] build-backend = "setuptools.build_meta" + +[tool.check-manifest] +ignore = [ + "tests/**", + "docs/**", + "tools/**", + "include/**", + ".*", + "pybind11/include/**", + "pybind11/share/**", + "CMakeLists.txt", + "noxfile.py", +] + +[tool.isort] +# Needs the compiled .so modules and env.py from tests +known_first_party = "env,pybind11_cross_module_tests,pybind11_tests," +# For black compatibility +profile = "black" + +[tool.mypy] +files = "pybind11" +python_version = "2.7" +warn_unused_configs = true + +disallow_any_generics = true +disallow_subclassing_any = true +disallow_untyped_calls = true +disallow_untyped_defs = true +disallow_incomplete_defs = true +check_untyped_defs = true +disallow_untyped_decorators = true +no_implicit_optional = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_return_any = true +no_implicit_reexport = true +strict_equality = true diff --git a/wrap/pybind11/setup.cfg b/wrap/pybind11/setup.cfg index ca0d59a4d..317c44bbf 100644 --- a/wrap/pybind11/setup.cfg +++ b/wrap/pybind11/setup.cfg @@ -1,10 +1,10 @@ [metadata] -long_description = file: README.md -long_description_content_type = text/markdown +long_description = file: README.rst +long_description_content_type = text/x-rst description = Seamless operability between C++11 and Python author = Wenzel Jakob -author_email = "wenzel.jakob@epfl.ch" -url = "https://github.com/pybind/pybind11" +author_email = wenzel.jakob@epfl.ch +url = https://github.com/pybind/pybind11 license = BSD classifiers = @@ -19,6 +19,8 @@ classifiers = Programming Language :: Python :: 3.6 Programming Language :: Python :: 3.7 Programming Language :: Python :: 3.8 + Programming Language :: Python :: 3.9 + Programming Language :: Python :: 3.10 License :: OSI Approved :: BSD License Programming Language :: Python :: Implementation :: PyPy Programming Language :: Python :: Implementation :: CPython @@ -29,29 +31,20 @@ keywords = C++11 Python bindings +project_urls = + Documentation = https://pybind11.readthedocs.io/ + Bug Tracker = https://github.com/pybind/pybind11/issues + Discussions = https://github.com/pybind/pybind11/discussions + Changelog = https://pybind11.readthedocs.io/en/latest/changelog.html + Chat = https://gitter.im/pybind/Lobby + [options] -python_requires = >=2.7, !=3.0, !=3.1, !=3.2, !=3.3, !=3.4 +python_requires = >=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.* zip_safe = False [bdist_wheel] universal=1 -[check-manifest] -ignore = - tests/** - docs/** - tools/** - include/** - .appveyor.yml - .cmake-format.yaml - .gitmodules - .pre-commit-config.yaml - .readthedocs.yml - .clang-tidy - pybind11/include/** - pybind11/share/** - CMakeLists.txt - [flake8] max-line-length = 99 @@ -64,3 +57,7 @@ ignore = N813 # Black conflict W503, E203 + + +[tool:pytest] +timeout = 300 diff --git a/wrap/pybind11/setup.py b/wrap/pybind11/setup.py index c9ba77d6d..0e7348982 100644 --- a/wrap/pybind11/setup.py +++ b/wrap/pybind11/setup.py @@ -4,6 +4,7 @@ # Setup script for PyPI; use CMakeFile.txt to build extension modules import contextlib +import io import os import re import shutil @@ -19,6 +20,36 @@ VERSION_REGEX = re.compile( r"^\s*#\s*define\s+PYBIND11_VERSION_([A-Z]+)\s+(.*)$", re.MULTILINE ) + +def build_expected_version_hex(matches): + patch_level_serial = matches["PATCH"] + serial = None + try: + major = int(matches["MAJOR"]) + minor = int(matches["MINOR"]) + flds = patch_level_serial.split(".") + if flds: + patch = int(flds[0]) + level = None + if len(flds) == 1: + level = "0" + serial = 0 + elif len(flds) == 2: + level_serial = flds[1] + for level in ("a", "b", "c", "dev"): + if level_serial.startswith(level): + serial = int(level_serial[len(level) :]) + break + except ValueError: + pass + if serial is None: + msg = 'Invalid PYBIND11_VERSION_PATCH: "{}"'.format(patch_level_serial) + raise RuntimeError(msg) + return "0x{:02x}{:02x}{:02x}{}{:x}".format( + major, minor, patch, level[:1].upper(), serial + ) + + # PYBIND11_GLOBAL_SDIST will build a different sdist, with the python-headers # files, and the sys.prefix files (CMake and headers). @@ -35,12 +66,12 @@ to_src = ( # Read the listed version with open("pybind11/_version.py") as f: code = compile(f.read(), "pybind11/_version.py", "exec") - loc = {} - exec(code, loc) - version = loc["__version__"] +loc = {} +exec(code, loc) +version = loc["__version__"] # Verify that the version matches the one in C++ -with open("include/pybind11/detail/common.h") as f: +with io.open("include/pybind11/detail/common.h", encoding="utf8") as f: matches = dict(VERSION_REGEX.findall(f.read())) cpp_version = "{MAJOR}.{MINOR}.{PATCH}".format(**matches) if version != cpp_version: @@ -49,6 +80,15 @@ if version != cpp_version: ) raise RuntimeError(msg) +version_hex = matches.get("HEX", "MISSING") +expected_version_hex = build_expected_version_hex(matches) +if version_hex != expected_version_hex: + msg = "PYBIND11_VERSION_HEX {} does not match expected value {}!".format( + version_hex, + expected_version_hex, + ) + raise RuntimeError(msg) + def get_and_replace(filename, binary=False, **opts): with open(filename, "rb" if binary else "r") as f: @@ -106,6 +146,13 @@ with remove_output("pybind11/include", "pybind11/share"): "-DBUILD_TESTING=OFF", "-DPYBIND11_NOPYTHON=ON", ] + if "CMAKE_ARGS" in os.environ: + fcommand = [ + c + for c in os.environ["CMAKE_ARGS"].split() + if "DCMAKE_INSTALL_PREFIX" not in c + ] + cmd += fcommand cmake_opts = dict(cwd=DIR, stdout=sys.stdout, stderr=sys.stderr) subprocess.check_call(cmd, **cmake_opts) subprocess.check_call(["cmake", "--install", tmpdir], **cmake_opts) diff --git a/wrap/pybind11/tests/CMakeLists.txt b/wrap/pybind11/tests/CMakeLists.txt index 45e094b08..9040cf8c0 100644 --- a/wrap/pybind11/tests/CMakeLists.txt +++ b/wrap/pybind11/tests/CMakeLists.txt @@ -10,27 +10,34 @@ cmake_minimum_required(VERSION 3.4) # The `cmake_minimum_required(VERSION 3.4...3.18)` syntax does not work with # some versions of VS that have a patched CMake 3.11. This forces us to emulate # the behavior using the following workaround: -if(${CMAKE_VERSION} VERSION_LESS 3.18) +if(${CMAKE_VERSION} VERSION_LESS 3.21) cmake_policy(VERSION ${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION}) else() - cmake_policy(VERSION 3.18) + cmake_policy(VERSION 3.21) endif() # Only needed for CMake < 3.5 support include(CMakeParseArguments) -# Filter out items; print an optional message if any items filtered +# Filter out items; print an optional message if any items filtered. This ignores extensions. # # Usage: # pybind11_filter_tests(LISTNAME file1.cpp file2.cpp ... MESSAGE "") # -macro(PYBIND11_FILTER_TESTS LISTNAME) +macro(pybind11_filter_tests LISTNAME) cmake_parse_arguments(ARG "" "MESSAGE" "" ${ARGN}) set(PYBIND11_FILTER_TESTS_FOUND OFF) + # Make a list of the test without any extensions, for easier filtering. + set(_TMP_ACTUAL_LIST "${${LISTNAME}};") # enforce ';' at the end to allow matching last item. + string(REGEX REPLACE "\\.[^.;]*;" ";" LIST_WITHOUT_EXTENSIONS "${_TMP_ACTUAL_LIST}") foreach(filename IN LISTS ARG_UNPARSED_ARGUMENTS) - list(FIND ${LISTNAME} ${filename} _FILE_FOUND) + string(REGEX REPLACE "\\.[^.]*$" "" filename_no_ext ${filename}) + # Search in the list without extensions. + list(FIND LIST_WITHOUT_EXTENSIONS ${filename_no_ext} _FILE_FOUND) if(_FILE_FOUND GREATER -1) - list(REMOVE_AT ${LISTNAME} ${_FILE_FOUND}) + list(REMOVE_AT ${LISTNAME} ${_FILE_FOUND}) # And remove from the list with extensions. + list(REMOVE_AT LIST_WITHOUT_EXTENSIONS ${_FILE_FOUND} + )# And our search list, to ensure it is in sync. set(PYBIND11_FILTER_TESTS_FOUND ON) endif() endforeach() @@ -39,6 +46,26 @@ macro(PYBIND11_FILTER_TESTS LISTNAME) endif() endmacro() +macro(possibly_uninitialized) + foreach(VARNAME ${ARGN}) + if(NOT DEFINED "${VARNAME}") + set("${VARNAME}" "") + endif() + endforeach() +endmacro() + +# Function to add additional targets if any of the provided tests are found. +# Needles; Specifies the test names to look for. +# Additions; Specifies the additional test targets to add when any of the needles are found. +macro(tests_extra_targets needles additions) + # Add the index for this relation to the index extra targets map. + list(LENGTH PYBIND11_TEST_EXTRA_TARGETS PYBIND11_TEST_EXTRA_TARGETS_LEN) + list(APPEND PYBIND11_TEST_EXTRA_TARGETS ${PYBIND11_TEST_EXTRA_TARGETS_LEN}) + # Add the test names to look for, and the associated test target additions. + set(PYBIND11_TEST_EXTRA_TARGETS_NEEDLES_${PYBIND11_TEST_EXTRA_TARGETS_LEN} ${needles}) + set(PYBIND11_TEST_EXTRA_TARGETS_ADDITION_${PYBIND11_TEST_EXTRA_TARGETS_LEN} ${additions}) +endmacro() + # New Python support if(DEFINED Python_EXECUTABLE) set(PYTHON_EXECUTABLE "${Python_EXECUTABLE}") @@ -67,7 +94,7 @@ if(CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR) find_package(pybind11 REQUIRED CONFIG) endif() -if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES) +if(NOT CMAKE_BUILD_TYPE AND NOT DEFINED CMAKE_CONFIGURATION_TYPES) message(STATUS "Setting tests build type to MinSizeRel as none was specified") set(CMAKE_BUILD_TYPE MinSizeRel @@ -84,52 +111,67 @@ if(PYBIND11_CUDA_TESTS) set(CMAKE_CUDA_STANDARD_REQUIRED ON) endif() -# Full set of test files (you can override these; see below) +# Full set of test files (you can override these; see below, overrides ignore extension) +# Any test that has no extension is both .py and .cpp, so 'foo' will add 'foo.cpp' and 'foo.py'. +# Any test that has an extension is exclusively that and handled as such. set(PYBIND11_TEST_FILES - test_async.cpp - test_buffers.cpp - test_builtin_casters.cpp - test_call_policies.cpp - test_callbacks.cpp - test_chrono.cpp - test_class.cpp - test_constants_and_functions.cpp - test_copy_move.cpp - test_custom_type_casters.cpp - test_docstring_options.cpp - test_eigen.cpp - test_enum.cpp - test_eval.cpp - test_exceptions.cpp - test_factory_constructors.cpp - test_gil_scoped.cpp - test_iostream.cpp - test_kwargs_and_defaults.cpp - test_local_bindings.cpp - test_methods_and_attributes.cpp - test_modules.cpp - test_multiple_inheritance.cpp - test_numpy_array.cpp - test_numpy_dtypes.cpp - test_numpy_vectorize.cpp - test_opaque_types.cpp - test_operator_overloading.cpp - test_pickling.cpp - test_pytypes.cpp - test_sequences_and_iterators.cpp - test_smart_ptr.cpp - test_stl.cpp - test_stl_binders.cpp - test_tagbased_polymorphic.cpp - test_union.cpp - test_virtual_functions.cpp) + test_async + test_buffers + test_builtin_casters + test_call_policies + test_callbacks + test_chrono + test_class + test_const_name + test_constants_and_functions + test_copy_move + test_custom_type_casters + test_custom_type_setup + test_docstring_options + test_eigen + test_enum + test_eval + test_exceptions + test_factory_constructors + test_gil_scoped + test_iostream + test_kwargs_and_defaults + test_local_bindings + test_methods_and_attributes + test_modules + test_multiple_inheritance + test_numpy_array + test_numpy_dtypes + test_numpy_vectorize + test_opaque_types + test_operator_overloading + test_pickling + test_pytypes + test_sequences_and_iterators + test_smart_ptr + test_stl + test_stl_binders + test_tagbased_polymorphic + test_thread + test_union + test_virtual_functions) # Invoking cmake with something like: # cmake -DPYBIND11_TEST_OVERRIDE="test_callbacks.cpp;test_pickling.cpp" .. # lets you override the tests that get compiled and run. You can restore to all tests with: # cmake -DPYBIND11_TEST_OVERRIDE= .. if(PYBIND11_TEST_OVERRIDE) - set(PYBIND11_TEST_FILES ${PYBIND11_TEST_OVERRIDE}) + # Instead of doing a direct override here, we iterate over the overrides without extension and + # match them against entries from the PYBIND11_TEST_FILES, anything that not matches goes into the filter list. + string(REGEX REPLACE "\\.[^.;]*;" ";" TEST_OVERRIDE_NO_EXT "${PYBIND11_TEST_OVERRIDE};") + string(REGEX REPLACE "\\.[^.;]*;" ";" TEST_FILES_NO_EXT "${PYBIND11_TEST_FILES};") + # This allows the override to be done with extensions, preserving backwards compatibility. + foreach(test_name ${TEST_FILES_NO_EXT}) + if(NOT ${test_name} IN_LIST TEST_OVERRIDE_NO_EXT + )# If not in the whitelist, add to be filtered out. + list(APPEND PYBIND11_TEST_FILTER ${test_name}) + endif() + endforeach() endif() # You can also filter tests: @@ -151,15 +193,46 @@ if(PYBIND11_CUDA_TESTS) "Skipping test_constants_and_functions due to incompatible exception specifications") endif() -string(REPLACE ".cpp" ".py" PYBIND11_PYTEST_FILES "${PYBIND11_TEST_FILES}") +# Now that the test filtering is complete, we need to split the list into the test for PYTEST +# and the list for the cpp targets. +set(PYBIND11_CPPTEST_FILES "") +set(PYBIND11_PYTEST_FILES "") + +foreach(test_name ${PYBIND11_TEST_FILES}) + if(test_name MATCHES "\\.py$") # Ends in .py, purely python test. + list(APPEND PYBIND11_PYTEST_FILES ${test_name}) + elseif(test_name MATCHES "\\.cpp$") # Ends in .cpp, purely cpp test. + list(APPEND PYBIND11_CPPTEST_FILES ${test_name}) + elseif(NOT test_name MATCHES "\\.") # No extension specified, assume both, add extension. + list(APPEND PYBIND11_PYTEST_FILES ${test_name}.py) + list(APPEND PYBIND11_CPPTEST_FILES ${test_name}.cpp) + else() + message(WARNING "Unhanded test extension in test: ${test_name}") + endif() +endforeach() +set(PYBIND11_TEST_FILES ${PYBIND11_CPPTEST_FILES}) +list(SORT PYBIND11_PYTEST_FILES) # Contains the set of test files that require pybind11_cross_module_tests to be # built; if none of these are built (i.e. because TEST_OVERRIDE is used and # doesn't include them) the second module doesn't get built. -set(PYBIND11_CROSS_MODULE_TESTS test_exceptions.py test_local_bindings.py test_stl.py - test_stl_binders.py) +tests_extra_targets("test_exceptions.py;test_local_bindings.py;test_stl.py;test_stl_binders.py" + "pybind11_cross_module_tests") -set(PYBIND11_CROSS_MODULE_GIL_TESTS test_gil_scoped.py) +# And add additional targets for other tests. +tests_extra_targets("test_gil_scoped.py" "cross_module_gil_utils") + +set(PYBIND11_EIGEN_REPO + "https://gitlab.com/libeigen/eigen.git" + CACHE STRING "Eigen repository to use for tests") +# Always use a hash for reconfigure speed and security reasons +# Include the version number for pretty printing (keep in sync) +set(PYBIND11_EIGEN_VERSION_AND_HASH + "3.4.0;929bc0e191d0927b1735b9a1ddc0e8b77e3a25ec" + CACHE STRING "Eigen version to use for tests, format: VERSION;HASH") + +list(GET PYBIND11_EIGEN_VERSION_AND_HASH 0 PYBIND11_EIGEN_VERSION_STRING) +list(GET PYBIND11_EIGEN_VERSION_AND_HASH 1 PYBIND11_EIGEN_VERSION_HASH) # Check if Eigen is available; if not, remove from PYBIND11_TEST_FILES (but # keep it in PYBIND11_PYTEST_FILES, so that we get the "eigen is not installed" @@ -174,22 +247,26 @@ if(PYBIND11_TEST_FILES_EIGEN_I GREATER -1) message(FATAL_ERROR "CMake 3.11+ required when using DOWNLOAD_EIGEN") endif() - set(EIGEN3_VERSION_STRING "3.3.7") - include(FetchContent) FetchContent_Declare( eigen - GIT_REPOSITORY https://gitlab.com/libeigen/eigen.git - GIT_TAG ${EIGEN3_VERSION_STRING}) + GIT_REPOSITORY "${PYBIND11_EIGEN_REPO}" + GIT_TAG "${PYBIND11_EIGEN_VERSION_HASH}") FetchContent_GetProperties(eigen) if(NOT eigen_POPULATED) - message(STATUS "Downloading Eigen") + message( + STATUS + "Downloading Eigen ${PYBIND11_EIGEN_VERSION_STRING} (${PYBIND11_EIGEN_VERSION_HASH}) from ${PYBIND11_EIGEN_REPO}" + ) FetchContent_Populate(eigen) endif() set(EIGEN3_INCLUDE_DIR ${eigen_SOURCE_DIR}) set(EIGEN3_FOUND TRUE) + # When getting locally, the version is not visible from a superprojet, + # so just force it. + set(EIGEN3_VERSION "${PYBIND11_EIGEN_VERSION_STRING}") else() find_package(Eigen3 3.2.7 QUIET CONFIG) @@ -217,7 +294,8 @@ if(PYBIND11_TEST_FILES_EIGEN_I GREATER -1) message(STATUS "Building tests with Eigen v${EIGEN3_VERSION}") else() list(REMOVE_AT PYBIND11_TEST_FILES ${PYBIND11_TEST_FILES_EIGEN_I}) - message(STATUS "Building tests WITHOUT Eigen, use -DDOWNLOAD_EIGEN on CMake 3.11+ to download") + message( + STATUS "Building tests WITHOUT Eigen, use -DDOWNLOAD_EIGEN=ON on CMake 3.11+ to download") endif() endif() @@ -226,25 +304,69 @@ find_package(Boost 1.56) if(Boost_FOUND) if(NOT TARGET Boost::headers) + add_library(Boost::headers IMPORTED INTERFACE) if(TARGET Boost::boost) # Classic FindBoost - add_library(Boost::headers ALIAS Boost::boost) + set_property(TARGET Boost::boost PROPERTY INTERFACE_LINK_LIBRARIES Boost::boost) else() # Very old FindBoost, or newer Boost than CMake in older CMakes - add_library(Boost::headers IMPORTED INTERFACE) set_property(TARGET Boost::headers PROPERTY INTERFACE_INCLUDE_DIRECTORIES ${Boost_INCLUDE_DIRS}) endif() endif() endif() +# Check if we need to add -lstdc++fs or -lc++fs or nothing +if(DEFINED CMAKE_CXX_STANDARD AND CMAKE_CXX_STANDARD LESS 17) + set(STD_FS_NO_LIB_NEEDED TRUE) +elseif(MSVC) + set(STD_FS_NO_LIB_NEEDED TRUE) +else() + file( + WRITE ${CMAKE_CURRENT_BINARY_DIR}/main.cpp + "#include \nint main(int argc, char ** argv) {\n std::filesystem::path p(argv[0]);\n return p.string().length();\n}" + ) + try_compile( + STD_FS_NO_LIB_NEEDED ${CMAKE_CURRENT_BINARY_DIR} + SOURCES ${CMAKE_CURRENT_BINARY_DIR}/main.cpp + COMPILE_DEFINITIONS -std=c++17) + try_compile( + STD_FS_NEEDS_STDCXXFS ${CMAKE_CURRENT_BINARY_DIR} + SOURCES ${CMAKE_CURRENT_BINARY_DIR}/main.cpp + COMPILE_DEFINITIONS -std=c++17 + LINK_LIBRARIES stdc++fs) + try_compile( + STD_FS_NEEDS_CXXFS ${CMAKE_CURRENT_BINARY_DIR} + SOURCES ${CMAKE_CURRENT_BINARY_DIR}/main.cpp + COMPILE_DEFINITIONS -std=c++17 + LINK_LIBRARIES c++fs) +endif() + +if(${STD_FS_NEEDS_STDCXXFS}) + set(STD_FS_LIB stdc++fs) +elseif(${STD_FS_NEEDS_CXXFS}) + set(STD_FS_LIB c++fs) +elseif(${STD_FS_NO_LIB_NEEDED}) + set(STD_FS_LIB "") +else() + message(WARNING "Unknown C++17 compiler - not passing -lstdc++fs") + set(STD_FS_LIB "") +endif() + # Compile with compiler warnings turned on function(pybind11_enable_warnings target_name) if(MSVC) target_compile_options(${target_name} PRIVATE /W4) elseif(CMAKE_CXX_COMPILER_ID MATCHES "(GNU|Intel|Clang)" AND NOT PYBIND11_CUDA_TESTS) - target_compile_options(${target_name} PRIVATE -Wall -Wextra -Wconversion -Wcast-qual - -Wdeprecated -Wundef) + target_compile_options( + ${target_name} + PRIVATE -Wall + -Wextra + -Wconversion + -Wcast-qual + -Wdeprecated + -Wundef + -Wnon-virtual-dtor) endif() if(PYBIND11_WERROR) @@ -252,12 +374,22 @@ function(pybind11_enable_warnings target_name) target_compile_options(${target_name} PRIVATE /WX) elseif(PYBIND11_CUDA_TESTS) target_compile_options(${target_name} PRIVATE "SHELL:-Werror all-warnings") - elseif(CMAKE_CXX_COMPILER_ID MATCHES "(GNU|Intel|Clang)") + elseif(CMAKE_CXX_COMPILER_ID MATCHES "(GNU|Clang|IntelLLVM)") target_compile_options(${target_name} PRIVATE -Werror) + elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Intel") + if(CMAKE_CXX_STANDARD EQUAL 17) # See PR #3570 + target_compile_options(${target_name} PRIVATE -Wno-conversion) + endif() + target_compile_options( + ${target_name} + PRIVATE + -Werror-all + # "Inlining inhibited by limit max-size", "Inlining inhibited by limit max-total-size" + -diag-disable 11074,11076) endif() endif() - # Needs to be readded since the ordering requires these to be after the ones above + # Needs to be re-added since the ordering requires these to be after the ones above if(CMAKE_CXX_STANDARD AND CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND PYTHON_VERSION VERSION_LESS 3.0) @@ -271,21 +403,17 @@ endfunction() set(test_targets pybind11_tests) -# Build pybind11_cross_module_tests if any test_whatever.py are being built that require it -foreach(t ${PYBIND11_CROSS_MODULE_TESTS}) - list(FIND PYBIND11_PYTEST_FILES ${t} i) - if(i GREATER -1) - list(APPEND test_targets pybind11_cross_module_tests) - break() - endif() -endforeach() - -foreach(t ${PYBIND11_CROSS_MODULE_GIL_TESTS}) - list(FIND PYBIND11_PYTEST_FILES ${t} i) - if(i GREATER -1) - list(APPEND test_targets cross_module_gil_utils) - break() - endif() +# Check if any tests need extra targets by iterating through the mappings registered. +foreach(i ${PYBIND11_TEST_EXTRA_TARGETS}) + foreach(needle ${PYBIND11_TEST_EXTRA_TARGETS_NEEDLES_${i}}) + if(needle IN_LIST PYBIND11_PYTEST_FILES) + # Add all the additional targets to the test list. List join in newer cmake. + foreach(extra_target ${PYBIND11_TEST_EXTRA_TARGETS_ADDITION_${i}}) + list(APPEND test_targets ${extra_target}) + endforeach() + break() # Breaks out of the needle search, continues with the next mapping. + endif() + endforeach() endforeach() # Support CUDA testing by forcing the target file to compile with NVCC @@ -334,38 +462,34 @@ foreach(target ${test_targets}) target_compile_definitions(${target} PRIVATE -DPYBIND11_TEST_BOOST) endif() + target_link_libraries(${target} PRIVATE ${STD_FS_LIB}) + # Always write the output file directly into the 'tests' directory (even on MSVC) if(NOT CMAKE_LIBRARY_OUTPUT_DIRECTORY) set_target_properties(${target} PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}") - foreach(config ${CMAKE_CONFIGURATION_TYPES}) - string(TOUPPER ${config} config) - set_target_properties(${target} PROPERTIES LIBRARY_OUTPUT_DIRECTORY_${config} - "${CMAKE_CURRENT_BINARY_DIR}") - endforeach() + + if(DEFINED CMAKE_CONFIGURATION_TYPES) + foreach(config ${CMAKE_CONFIGURATION_TYPES}) + string(TOUPPER ${config} config) + set_target_properties(${target} PROPERTIES LIBRARY_OUTPUT_DIRECTORY_${config} + "${CMAKE_CURRENT_BINARY_DIR}") + endforeach() + endif() endif() endforeach() -# Make sure pytest is found or produce a warning -if(NOT PYBIND11_PYTEST_FOUND) - execute_process( - COMMAND ${PYTHON_EXECUTABLE} -c "import pytest; print(pytest.__version__)" - RESULT_VARIABLE pytest_not_found - OUTPUT_VARIABLE pytest_version - ERROR_QUIET) - if(pytest_not_found) - message(WARNING "Running the tests requires pytest. Please install it manually" - " (try: ${PYTHON_EXECUTABLE} -m pip install pytest)") - elseif(pytest_version VERSION_LESS 3.1) - message(WARNING "Running the tests requires pytest >= 3.1. Found: ${pytest_version}" - "Please update it (try: ${PYTHON_EXECUTABLE} -m pip install -U pytest)") - else() - set(PYBIND11_PYTEST_FOUND - TRUE - CACHE INTERNAL "") - endif() +# Provide nice organisation in IDEs +if(NOT CMAKE_VERSION VERSION_LESS 3.8) + source_group( + TREE "${CMAKE_CURRENT_SOURCE_DIR}/../include" + PREFIX "Header Files" + FILES ${PYBIND11_HEADERS}) endif() +# Make sure pytest is found or produce a warning +pybind11_find_import(pytest VERSION 3.1) + if(NOT CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_CURRENT_BINARY_DIR) # This is not used later in the build, so it's okay to regenerate each time. configure_file("${CMAKE_CURRENT_SOURCE_DIR}/pytest.ini" "${CMAKE_CURRENT_BINARY_DIR}/pytest.ini" @@ -377,15 +501,20 @@ endif() # cmake 3.12 added list(transform prepend # but we can't use it yet -string(REPLACE "test_" "${CMAKE_CURRENT_BINARY_DIR}/test_" PYBIND11_BINARY_TEST_FILES +string(REPLACE "test_" "${CMAKE_CURRENT_SOURCE_DIR}/test_" PYBIND11_ABS_PYTEST_FILES "${PYBIND11_PYTEST_FILES}") +set(PYBIND11_TEST_PREFIX_COMMAND + "" + CACHE STRING "Put this before pytest, use for checkers and such") + # A single command to compile and run the tests add_custom_target( pytest - COMMAND ${PYTHON_EXECUTABLE} -m pytest ${PYBIND11_BINARY_PYTEST_FILES} + COMMAND ${PYBIND11_TEST_PREFIX_COMMAND} ${PYTHON_EXECUTABLE} -m pytest + ${PYBIND11_ABS_PYTEST_FILES} DEPENDS ${test_targets} - WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}" USES_TERMINAL) if(PYBIND11_TEST_OVERRIDE) @@ -396,6 +525,27 @@ if(PYBIND11_TEST_OVERRIDE) "Note: not all tests run: -DPYBIND11_TEST_OVERRIDE is in effect") endif() +# cmake-format: off +add_custom_target( + memcheck + COMMAND + PYTHONMALLOC=malloc + valgrind + --leak-check=full + --show-leak-kinds=definite,indirect + --errors-for-leak-kinds=definite,indirect + --error-exitcode=1 + --read-var-info=yes + --track-origins=yes + --suppressions="${CMAKE_CURRENT_SOURCE_DIR}/valgrind-python.supp" + --suppressions="${CMAKE_CURRENT_SOURCE_DIR}/valgrind-numpy-scipy.supp" + --gen-suppressions=all + ${PYTHON_EXECUTABLE} -m pytest ${PYBIND11_ABS_PYTEST_FILES} + DEPENDS ${test_targets} + WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}" + USES_TERMINAL) +# cmake-format: on + # Add a check target to run all the tests, starting with pytest (we add dependencies to this below) add_custom_target(check DEPENDS pytest) diff --git a/wrap/pybind11/tests/conftest.py b/wrap/pybind11/tests/conftest.py index a2350d041..362eb8069 100644 --- a/wrap/pybind11/tests/conftest.py +++ b/wrap/pybind11/tests/conftest.py @@ -18,9 +18,9 @@ import env # Early diagnostic for failed imports import pybind11_tests # noqa: F401 -_unicode_marker = re.compile(r'u(\'[^\']*\')') -_long_marker = re.compile(r'([0-9])L') -_hexadecimal = re.compile(r'0x[0-9a-fA-F]+') +_unicode_marker = re.compile(r"u(\'[^\']*\')") +_long_marker = re.compile(r"([0-9])L") +_hexadecimal = re.compile(r"0x[0-9a-fA-F]+") # Avoid collecting Python3 only files collect_ignore = [] @@ -30,7 +30,7 @@ if env.PY2: def _strip_and_dedent(s): """For triple-quote strings""" - return textwrap.dedent(s.lstrip('\n').rstrip()) + return textwrap.dedent(s.lstrip("\n").rstrip()) def _split_and_sort(s): @@ -40,11 +40,14 @@ def _split_and_sort(s): def _make_explanation(a, b): """Explanation for a failed assert -- the a and b arguments are List[str]""" - return ["--- actual / +++ expected"] + [line.strip('\n') for line in difflib.ndiff(a, b)] + return ["--- actual / +++ expected"] + [ + line.strip("\n") for line in difflib.ndiff(a, b) + ] class Output(object): """Basic output post-processing and comparison""" + def __init__(self, string): self.string = string self.explanation = [] @@ -54,7 +57,11 @@ class Output(object): def __eq__(self, other): # Ignore constructor/destructor output which is prefixed with "###" - a = [line for line in self.string.strip().splitlines() if not line.startswith("###")] + a = [ + line + for line in self.string.strip().splitlines() + if not line.startswith("###") + ] b = _strip_and_dedent(other).splitlines() if a == b: return True @@ -65,6 +72,7 @@ class Output(object): class Unordered(Output): """Custom comparison for output without strict line ordering""" + def __eq__(self, other): a = _split_and_sort(self.string) b = _split_and_sort(other) @@ -175,7 +183,7 @@ def msg(): # noinspection PyUnusedLocal def pytest_assertrepr_compare(op, left, right): """Hook to insert custom failure explanation""" - if hasattr(left, 'explanation'): + if hasattr(left, "explanation"): return left.explanation @@ -189,8 +197,8 @@ def suppress(exception): def gc_collect(): - ''' Run the garbage collector twice (needed when running - reference counting tests with PyPy) ''' + """Run the garbage collector twice (needed when running + reference counting tests with PyPy)""" gc.collect() gc.collect() diff --git a/wrap/pybind11/tests/constructor_stats.h b/wrap/pybind11/tests/constructor_stats.h index abfaf9161..805968a09 100644 --- a/wrap/pybind11/tests/constructor_stats.h +++ b/wrap/pybind11/tests/constructor_stats.h @@ -120,7 +120,7 @@ public: throw py::error_already_set(); Py_DECREF(result); #else - py::module::import("gc").attr("collect")(); + py::module_::import("gc").attr("collect")(); #endif } diff --git a/wrap/pybind11/tests/env.py b/wrap/pybind11/tests/env.py index 5cded4412..6172b451b 100644 --- a/wrap/pybind11/tests/env.py +++ b/wrap/pybind11/tests/env.py @@ -2,6 +2,8 @@ import platform import sys +import pytest + LINUX = sys.platform.startswith("linux") MACOS = sys.platform.startswith("darwin") WIN = sys.platform.startswith("win32") or sys.platform.startswith("cygwin") @@ -12,3 +14,20 @@ PYPY = platform.python_implementation() == "PyPy" PY2 = sys.version_info.major == 2 PY = sys.version_info + + +def deprecated_call(): + """ + pytest.deprecated_call() seems broken in pytest<3.9.x; concretely, it + doesn't work on CPython 3.8.0 with pytest==3.3.2 on Ubuntu 18.04 (#2922). + + This is a narrowed reimplementation of the following PR :( + https://github.com/pytest-dev/pytest/pull/4104 + """ + # TODO: Remove this when testing requires pytest>=3.9. + pieces = pytest.__version__.split(".") + pytest_major_minor = (int(pieces[0]), int(pieces[1])) + if pytest_major_minor < (3, 9): + return pytest.warns((DeprecationWarning, PendingDeprecationWarning)) + else: + return pytest.deprecated_call() diff --git a/wrap/pybind11/tests/extra_python_package/test_files.py b/wrap/pybind11/tests/extra_python_package/test_files.py index ac8ca1f97..337a72dfe 100644 --- a/wrap/pybind11/tests/extra_python_package/test_files.py +++ b/wrap/pybind11/tests/extra_python_package/test_files.py @@ -25,6 +25,7 @@ main_headers = { "include/pybind11/embed.h", "include/pybind11/eval.h", "include/pybind11/functional.h", + "include/pybind11/gil.h", "include/pybind11/iostream.h", "include/pybind11/numpy.h", "include/pybind11/operators.h", @@ -41,9 +42,14 @@ detail_headers = { "include/pybind11/detail/descr.h", "include/pybind11/detail/init.h", "include/pybind11/detail/internals.h", + "include/pybind11/detail/type_caster_base.h", "include/pybind11/detail/typeid.h", } +stl_headers = { + "include/pybind11/stl/filesystem.h", +} + cmake_files = { "share/cmake/pybind11/FindPythonLibsNew.cmake", "share/cmake/pybind11/pybind11Common.cmake", @@ -58,11 +64,14 @@ py_files = { "__init__.py", "__main__.py", "_version.py", + "_version.pyi", "commands.py", + "py.typed", "setup_helpers.py", + "setup_helpers.pyi", } -headers = main_headers | detail_headers +headers = main_headers | detail_headers | stl_headers src_files = headers | cmake_files all_files = src_files | py_files @@ -72,6 +81,7 @@ sdist_files = { "pybind11/include", "pybind11/include/pybind11", "pybind11/include/pybind11/detail", + "pybind11/include/pybind11/stl", "pybind11/share", "pybind11/share/cmake", "pybind11/share/cmake/pybind11", @@ -80,7 +90,7 @@ sdist_files = { "setup.py", "LICENSE", "MANIFEST.in", - "README.md", + "README.rst", "PKG-INFO", } @@ -116,7 +126,7 @@ def test_build_sdist(monkeypatch, tmpdir): with tarfile.open(str(sdist)) as tar: start = tar.getnames()[0] + "/" version = start[9:-1] - simpler = set(n.split("/", 1)[-1] for n in tar.getnames()[1:]) + simpler = {n.split("/", 1)[-1] for n in tar.getnames()[1:]} with contextlib.closing( tar.extractfile(tar.getmember(start + "setup.py")) @@ -128,9 +138,19 @@ def test_build_sdist(monkeypatch, tmpdir): ) as f: pyproject_toml = f.read() - files = set("pybind11/{}".format(n) for n in all_files) + with contextlib.closing( + tar.extractfile( + tar.getmember( + start + "pybind11/share/cmake/pybind11/pybind11Config.cmake" + ) + ) + ) as f: + contents = f.read().decode("utf8") + assert 'set(pybind11_INCLUDE_DIR "${PACKAGE_PREFIX_DIR}/include")' in contents + + files = {"pybind11/{}".format(n) for n in all_files} files |= sdist_files - files |= set("pybind11{}".format(n) for n in local_sdist_files) + files |= {"pybind11{}".format(n) for n in local_sdist_files} files.add("pybind11.egg-info/entry_points.txt") files.add("pybind11.egg-info/requires.txt") assert simpler == files @@ -141,11 +161,11 @@ def test_build_sdist(monkeypatch, tmpdir): .substitute(version=version, extra_cmd="") .encode() ) - assert setup_py == contents + assert setup_py == contents with open(os.path.join(MAIN_DIR, "tools", "pyproject.toml"), "rb") as f: contents = f.read() - assert pyproject_toml == contents + assert pyproject_toml == contents def test_build_global_dist(monkeypatch, tmpdir): @@ -171,7 +191,7 @@ def test_build_global_dist(monkeypatch, tmpdir): with tarfile.open(str(sdist)) as tar: start = tar.getnames()[0] + "/" version = start[16:-1] - simpler = set(n.split("/", 1)[-1] for n in tar.getnames()[1:]) + simpler = {n.split("/", 1)[-1] for n in tar.getnames()[1:]} with contextlib.closing( tar.extractfile(tar.getmember(start + "setup.py")) @@ -183,9 +203,9 @@ def test_build_global_dist(monkeypatch, tmpdir): ) as f: pyproject_toml = f.read() - files = set("pybind11/{}".format(n) for n in all_files) + files = {"pybind11/{}".format(n) for n in all_files} files |= sdist_files - files |= set("pybind11_global{}".format(n) for n in local_sdist_files) + files |= {"pybind11_global{}".format(n) for n in local_sdist_files} assert simpler == files with open(os.path.join(MAIN_DIR, "tools", "setup_global.py.in"), "rb") as f: @@ -210,7 +230,7 @@ def tests_build_wheel(monkeypatch, tmpdir): (wheel,) = tmpdir.visit("*.whl") - files = set("pybind11/{}".format(n) for n in all_files) + files = {"pybind11/{}".format(n) for n in all_files} files |= { "dist-info/LICENSE", "dist-info/METADATA", @@ -223,10 +243,10 @@ def tests_build_wheel(monkeypatch, tmpdir): with zipfile.ZipFile(str(wheel)) as z: names = z.namelist() - trimmed = set(n for n in names if "dist-info" not in n) - trimmed |= set( + trimmed = {n for n in names if "dist-info" not in n} + trimmed |= { "dist-info/{}".format(n.split("/", 1)[-1]) for n in names if "dist-info" in n - ) + } assert files == trimmed @@ -240,8 +260,8 @@ def tests_build_global_wheel(monkeypatch, tmpdir): (wheel,) = tmpdir.visit("*.whl") - files = set("data/data/{}".format(n) for n in src_files) - files |= set("data/headers/{}".format(n[8:]) for n in headers) + files = {"data/data/{}".format(n) for n in src_files} + files |= {"data/headers/{}".format(n[8:]) for n in headers} files |= { "dist-info/LICENSE", "dist-info/METADATA", @@ -254,6 +274,6 @@ def tests_build_global_wheel(monkeypatch, tmpdir): names = z.namelist() beginning = names[0].split("/", 1)[0].rsplit(".", 1)[0] - trimmed = set(n[len(beginning) + 1 :] for n in names) + trimmed = {n[len(beginning) + 1 :] for n in names} assert files == trimmed diff --git a/wrap/pybind11/tests/extra_setuptools/test_setuphelper.py b/wrap/pybind11/tests/extra_setuptools/test_setuphelper.py index de0b516a9..788f368b1 100644 --- a/wrap/pybind11/tests/extra_setuptools/test_setuphelper.py +++ b/wrap/pybind11/tests/extra_setuptools/test_setuphelper.py @@ -1,17 +1,19 @@ # -*- coding: utf-8 -*- import os -import sys import subprocess +import sys from textwrap import dedent import pytest DIR = os.path.abspath(os.path.dirname(__file__)) MAIN_DIR = os.path.dirname(os.path.dirname(DIR)) +WIN = sys.platform.startswith("win32") or sys.platform.startswith("cygwin") +@pytest.mark.parametrize("parallel", [False, True]) @pytest.mark.parametrize("std", [11, 0]) -def test_simple_setup_py(monkeypatch, tmpdir, std): +def test_simple_setup_py(monkeypatch, tmpdir, parallel, std): monkeypatch.chdir(tmpdir) monkeypatch.syspath_prepend(MAIN_DIR) @@ -39,13 +41,18 @@ def test_simple_setup_py(monkeypatch, tmpdir, std): cmdclass["build_ext"] = build_ext + parallel = {parallel} + if parallel: + from pybind11.setup_helpers import ParallelCompile + ParallelCompile().install() + setup( name="simple_setup_package", cmdclass=cmdclass, ext_modules=ext_modules, ) """ - ).format(MAIN_DIR=MAIN_DIR, std=std), + ).format(MAIN_DIR=MAIN_DIR, std=std, parallel=parallel), encoding="ascii", ) @@ -65,13 +72,20 @@ def test_simple_setup_py(monkeypatch, tmpdir, std): encoding="ascii", ) - subprocess.check_call( + out = subprocess.check_output( [sys.executable, "setup.py", "build_ext", "--inplace"], - stdout=sys.stdout, - stderr=sys.stderr, ) + if not WIN: + assert b"-g0" in out + out = subprocess.check_output( + [sys.executable, "setup.py", "build_ext", "--inplace", "--force"], + env=dict(os.environ, CFLAGS="-g"), + ) + if not WIN: + assert b"-g0" not in out # Debug helper printout, normally hidden + print(out) for item in tmpdir.listdir(): print(item.basename) @@ -93,3 +107,45 @@ def test_simple_setup_py(monkeypatch, tmpdir, std): subprocess.check_call( [sys.executable, "test.py"], stdout=sys.stdout, stderr=sys.stderr ) + + +def test_intree_extensions(monkeypatch, tmpdir): + monkeypatch.syspath_prepend(MAIN_DIR) + + from pybind11.setup_helpers import intree_extensions + + monkeypatch.chdir(tmpdir) + root = tmpdir + root.ensure_dir() + subdir = root / "dir" + subdir.ensure_dir() + src = subdir / "ext.cpp" + src.ensure() + (ext,) = intree_extensions([src.relto(tmpdir)]) + assert ext.name == "ext" + subdir.ensure("__init__.py") + (ext,) = intree_extensions([src.relto(tmpdir)]) + assert ext.name == "dir.ext" + + +def test_intree_extensions_package_dir(monkeypatch, tmpdir): + monkeypatch.syspath_prepend(MAIN_DIR) + + from pybind11.setup_helpers import intree_extensions + + monkeypatch.chdir(tmpdir) + root = tmpdir / "src" + root.ensure_dir() + subdir = root / "dir" + subdir.ensure_dir() + src = subdir / "ext.cpp" + src.ensure() + (ext,) = intree_extensions([src.relto(tmpdir)], package_dir={"": "src"}) + assert ext.name == "dir.ext" + (ext,) = intree_extensions([src.relto(tmpdir)], package_dir={"foo": "src"}) + assert ext.name == "foo.dir.ext" + subdir.ensure("__init__.py") + (ext,) = intree_extensions([src.relto(tmpdir)], package_dir={"": "src"}) + assert ext.name == "dir.ext" + (ext,) = intree_extensions([src.relto(tmpdir)], package_dir={"foo": "src"}) + assert ext.name == "foo.dir.ext" diff --git a/wrap/pybind11/tests/local_bindings.h b/wrap/pybind11/tests/local_bindings.h index 22537b13a..4c936c19a 100644 --- a/wrap/pybind11/tests/local_bindings.h +++ b/wrap/pybind11/tests/local_bindings.h @@ -1,10 +1,12 @@ #pragma once +#include + #include "pybind11_tests.h" /// Simple class used to test py::local: template class LocalBase { public: - LocalBase(int i) : i(i) { } + explicit LocalBase(int i) : i(i) { } int i = -1; }; @@ -33,6 +35,25 @@ using NonLocalVec2 = std::vector; using NonLocalMap = std::unordered_map; using NonLocalMap2 = std::unordered_map; + +// Exception that will be caught via the module local translator. +class LocalException : public std::exception { +public: + explicit LocalException(const char * m) : message{m} {} + const char * what() const noexcept override {return message.c_str();} +private: + std::string message = ""; +}; + +// Exception that will be registered with register_local_exception_translator +class LocalSimpleException : public std::exception { +public: + explicit LocalSimpleException(const char * m) : message{m} {} + const char * what() const noexcept override {return message.c_str();} +private: + std::string message = ""; +}; + PYBIND11_MAKE_OPAQUE(LocalVec); PYBIND11_MAKE_OPAQUE(LocalVec2); PYBIND11_MAKE_OPAQUE(LocalMap); @@ -54,11 +75,11 @@ py::class_ bind_local(Args && ...args) { namespace pets { class Pet { public: - Pet(std::string name) : name_(name) {} + explicit Pet(std::string name) : name_(std::move(name)) {} std::string name_; - const std::string &name() { return name_; } + const std::string &name() const { return name_; } }; } // namespace pets -struct MixGL { int i; MixGL(int i) : i{i} {} }; -struct MixGL2 { int i; MixGL2(int i) : i{i} {} }; +struct MixGL { int i; explicit MixGL(int i) : i{i} {} }; +struct MixGL2 { int i; explicit MixGL2(int i) : i{i} {} }; diff --git a/wrap/pybind11/tests/object.h b/wrap/pybind11/tests/object.h index 9235f19c2..df34c2bad 100644 --- a/wrap/pybind11/tests/object.h +++ b/wrap/pybind11/tests/object.h @@ -1,5 +1,4 @@ -#if !defined(__OBJECT_H) -#define __OBJECT_H +#pragma once #include #include "constructor_stats.h" @@ -65,7 +64,7 @@ public: ref() : m_ptr(nullptr) { print_default_created(this); track_default_created((ref_tag*) this); } /// Construct a reference from a pointer - ref(T *ptr) : m_ptr(ptr) { + explicit ref(T *ptr) : m_ptr(ptr) { if (m_ptr) ((Object *) m_ptr)->incRef(); print_created(this, "from pointer", m_ptr); track_created((ref_tag*) this, "from pointer"); @@ -81,7 +80,7 @@ public: } /// Move constructor - ref(ref &&r) : m_ptr(r.m_ptr) { + ref(ref &&r) noexcept : m_ptr(r.m_ptr) { r.m_ptr = nullptr; print_move_created(this, "with pointer", m_ptr); track_move_created((ref_tag*) this); @@ -96,7 +95,7 @@ public: } /// Move another reference into the current one - ref& operator=(ref&& r) { + ref &operator=(ref &&r) noexcept { print_move_assigned(this, "pointer", r.m_ptr); track_move_assigned((ref_tag*) this); if (*this == r) @@ -110,7 +109,11 @@ public: /// Overwrite this reference with another reference ref& operator=(const ref& r) { - print_copy_assigned(this, "pointer", r.m_ptr); track_copy_assigned((ref_tag*) this); + if (this == &r) { + return *this; + } + print_copy_assigned(this, "pointer", r.m_ptr); + track_copy_assigned((ref_tag *) this); if (m_ptr == r.m_ptr) return *this; @@ -161,7 +164,7 @@ public: const T& operator*() const { return *m_ptr; } /// Return a pointer to the referenced object - operator T* () { return m_ptr; } + explicit operator T* () { return m_ptr; } /// Return a const pointer to the referenced object T* get_ptr() { return m_ptr; } @@ -171,5 +174,3 @@ public: private: T *m_ptr; }; - -#endif /* __OBJECT_H */ diff --git a/wrap/pybind11/tests/pybind11_cross_module_tests.cpp b/wrap/pybind11/tests/pybind11_cross_module_tests.cpp index f705e3106..5838cb274 100644 --- a/wrap/pybind11/tests/pybind11_cross_module_tests.cpp +++ b/wrap/pybind11/tests/pybind11_cross_module_tests.cpp @@ -9,8 +9,12 @@ #include "pybind11_tests.h" #include "local_bindings.h" +#include "test_exceptions.h" + #include + #include +#include PYBIND11_MODULE(pybind11_cross_module_tests, m) { m.doc() = "pybind11 cross-module test module"; @@ -25,11 +29,32 @@ PYBIND11_MODULE(pybind11_cross_module_tests, m) { bind_local(m, "ExternalType2", py::module_local()); // test_exceptions.py + py::register_local_exception(m, "LocalSimpleException"); m.def("raise_runtime_error", []() { PyErr_SetString(PyExc_RuntimeError, "My runtime error"); throw py::error_already_set(); }); m.def("raise_value_error", []() { PyErr_SetString(PyExc_ValueError, "My value error"); throw py::error_already_set(); }); m.def("throw_pybind_value_error", []() { throw py::value_error("pybind11 value error"); }); m.def("throw_pybind_type_error", []() { throw py::type_error("pybind11 type error"); }); m.def("throw_stop_iteration", []() { throw py::stop_iteration(); }); + m.def("throw_local_error", []() { throw LocalException("just local"); }); + m.def("throw_local_simple_error", []() { throw LocalSimpleException("external mod"); }); + py::register_exception_translator([](std::exception_ptr p) { + try { + if (p) std::rethrow_exception(p); + } catch (const shared_exception &e) { + PyErr_SetString(PyExc_KeyError, e.what()); + } + }); + + // translate the local exception into a key error but only in this module + py::register_local_exception_translator([](std::exception_ptr p) { + try { + if (p) { + std::rethrow_exception(p); + } + } catch (const LocalException &e) { + PyErr_SetString(PyExc_KeyError, e.what()); + } + }); // test_local_bindings.py // Local to both: @@ -83,7 +108,7 @@ PYBIND11_MODULE(pybind11_cross_module_tests, m) { m.def("get_mixed_lg", [](int i) { return MixedLocalGlobal(i); }); // test_internal_locals_differ - m.def("local_cpp_types_addr", []() { return (uintptr_t) &py::detail::registered_local_types_cpp(); }); + m.def("local_cpp_types_addr", []() { return (uintptr_t) &py::detail::get_local_internals().registered_types_cpp; }); // test_stl_caster_vs_stl_bind py::bind_vector>(m, "VectorInt"); @@ -96,7 +121,10 @@ PYBIND11_MODULE(pybind11_cross_module_tests, m) { m.def("return_self", [](LocalVec *v) { return v; }); m.def("return_copy", [](const LocalVec &v) { return LocalVec(v); }); - class Dog : public pets::Pet { public: Dog(std::string name) : Pet(name) {}; }; + class Dog : public pets::Pet { + public: + explicit Dog(std::string name) : Pet(std::move(name)) {} + }; py::class_(m, "Pet", py::module_local()) .def("name", &pets::Pet::name); // Binding for local extending class: @@ -118,6 +146,6 @@ PYBIND11_MODULE(pybind11_cross_module_tests, m) { // test_missing_header_message // The main module already includes stl.h, but we need to test the error message // which appears when this header is missing. - m.def("missing_header_arg", [](std::vector) { }); + m.def("missing_header_arg", [](const std::vector &) {}); m.def("missing_header_return", []() { return std::vector(); }); } diff --git a/wrap/pybind11/tests/pybind11_tests.cpp b/wrap/pybind11/tests/pybind11_tests.cpp index 24b65df6f..439cd4012 100644 --- a/wrap/pybind11/tests/pybind11_tests.cpp +++ b/wrap/pybind11/tests/pybind11_tests.cpp @@ -26,8 +26,8 @@ productively. Instead, see the "How can I reduce the build time?" question in the "Frequently asked questions" section of the documentation for good practice on splitting binding code over multiple files. */ -std::list> &initializers() { - static std::list> inits; +std::list> &initializers() { + static std::list> inits; return inits; } @@ -36,13 +36,13 @@ test_initializer::test_initializer(Initializer init) { } test_initializer::test_initializer(const char *submodule_name, Initializer init) { - initializers().emplace_back([=](py::module &parent) { + initializers().emplace_back([=](py::module_ &parent) { auto m = parent.def_submodule(submodule_name); init(m); }); } -void bind_ConstructorStats(py::module &m) { +void bind_ConstructorStats(py::module_ &m) { py::class_(m, "ConstructorStats") .def("alive", &ConstructorStats::alive) .def("values", &ConstructorStats::values) diff --git a/wrap/pybind11/tests/pybind11_tests.h b/wrap/pybind11/tests/pybind11_tests.h index 1e4741627..9b9992323 100644 --- a/wrap/pybind11/tests/pybind11_tests.h +++ b/wrap/pybind11/tests/pybind11_tests.h @@ -1,27 +1,29 @@ #pragma once + #include +#include #if defined(_MSC_VER) && _MSC_VER < 1910 // We get some really long type names here which causes MSVC 2015 to emit warnings -# pragma warning(disable: 4503) // warning C4503: decorated name length exceeded, name was truncated +# pragma warning( \ + disable : 4503) // warning C4503: decorated name length exceeded, name was truncated #endif namespace py = pybind11; using namespace pybind11::literals; class test_initializer { - using Initializer = void (*)(py::module &); + using Initializer = void (*)(py::module_ &); public: - test_initializer(Initializer init); + explicit test_initializer(Initializer init); test_initializer(const char *submodule_name, Initializer init); }; -#define TEST_SUBMODULE(name, variable) \ - void test_submodule_##name(py::module &); \ - test_initializer name(#name, test_submodule_##name); \ - void test_submodule_##name(py::module &variable) - +#define TEST_SUBMODULE(name, variable) \ + void test_submodule_##name(py::module_ &); \ + test_initializer name(#name, test_submodule_##name); \ + void test_submodule_##name(py::module_ &(variable)) /// Dummy type which is not exported anywhere -- something to trigger a conversion error struct UnregisteredType { }; @@ -30,7 +32,7 @@ struct UnregisteredType { }; class UserType { public: UserType() = default; - UserType(int i) : i(i) { } + explicit UserType(int i) : i(i) { } int value() const { return i; } void set(int set) { i = set; } @@ -50,6 +52,12 @@ public: IncType &operator=(IncType &&) = delete; }; +/// A simple union for basic testing +union IntFloat { + int i; + float f; +}; + /// Custom cast-only type that casts to a string "rvalue" or "lvalue" depending on the cast context. /// Used to test recursive casters (e.g. std::tuple, stl containers). struct RValueCaster {}; @@ -57,9 +65,21 @@ PYBIND11_NAMESPACE_BEGIN(pybind11) PYBIND11_NAMESPACE_BEGIN(detail) template<> class type_caster { public: - PYBIND11_TYPE_CASTER(RValueCaster, _("RValueCaster")); + PYBIND11_TYPE_CASTER(RValueCaster, const_name("RValueCaster")); static handle cast(RValueCaster &&, return_value_policy, handle) { return py::str("rvalue").release(); } static handle cast(const RValueCaster &, return_value_policy, handle) { return py::str("lvalue").release(); } }; PYBIND11_NAMESPACE_END(detail) PYBIND11_NAMESPACE_END(pybind11) + +template +void ignoreOldStyleInitWarnings(F &&body) { + py::exec(R"( + message = "pybind11-bound class '.+' is using an old-style placement-new '(?:__init__|__setstate__)' which has been deprecated" + + import warnings + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message=message, category=FutureWarning) + body() + )", py::dict(py::arg("body") = py::cpp_function(body))); +} diff --git a/wrap/pybind11/tests/pytest.ini b/wrap/pybind11/tests/pytest.ini index c47cbe9c1..a3871d6c3 100644 --- a/wrap/pybind11/tests/pytest.ini +++ b/wrap/pybind11/tests/pytest.ini @@ -7,11 +7,11 @@ addopts = -rs # capture only Python print and C++ py::print, but not C output (low-level Python errors) --capture=sys - # enable all warnings - -Wa filterwarnings = # make warnings into errors but ignore certain third-party extension issues error + # somehow, some DeprecationWarnings do not get turned into errors + always::DeprecationWarning # importing scipy submodules on some version of Python ignore::ImportWarning # bogus numpy ABI warning (see numpy/#432) diff --git a/wrap/pybind11/tests/requirements.txt b/wrap/pybind11/tests/requirements.txt index 39bd57a1c..98ca46d28 100644 --- a/wrap/pybind11/tests/requirements.txt +++ b/wrap/pybind11/tests/requirements.txt @@ -1,8 +1,12 @@ ---extra-index-url https://antocuni.github.io/pypy-wheels/manylinux2010/ -numpy==1.16.6; python_version<"3.6" -numpy==1.18.0; platform_python_implementation=="PyPy" and sys_platform=="darwin" and python_version>="3.6" -numpy==1.19.1; (platform_python_implementation!="PyPy" or sys_platform!="darwin") and python_version>="3.6" and python_version<"3.9" +numpy==1.16.6; python_version<"3.6" and sys_platform!="win32" and platform_python_implementation!="PyPy" +numpy==1.19.0; platform_python_implementation=="PyPy" and sys_platform=="linux" and python_version=="3.6" +numpy==1.20.0; platform_python_implementation=="PyPy" and sys_platform=="linux" and python_version=="3.7" +numpy==1.19.3; platform_python_implementation!="PyPy" and python_version=="3.6" +numpy==1.21.3; platform_python_implementation!="PyPy" and python_version>="3.7" and python_version<"3.11" +py @ git+https://github.com/pytest-dev/py; python_version>="3.11" pytest==4.6.9; python_version<"3.5" -pytest==5.4.3; python_version>="3.5" -scipy==1.2.3; (platform_python_implementation!="PyPy" or sys_platform!="darwin") and python_version<"3.6" -scipy==1.5.2; (platform_python_implementation!="PyPy" or sys_platform!="darwin") and python_version>="3.6" and python_version<"3.9" +pytest==6.1.2; python_version=="3.5" +pytest==6.2.4; python_version>="3.6" +pytest-timeout +scipy==1.2.3; platform_python_implementation!="PyPy" and python_version<"3.6" +scipy==1.5.4; platform_python_implementation!="PyPy" and python_version>="3.6" and python_version<"3.10" diff --git a/wrap/pybind11/tests/test_async.cpp b/wrap/pybind11/tests/test_async.cpp index f0ad0d535..e6e01d72c 100644 --- a/wrap/pybind11/tests/test_async.cpp +++ b/wrap/pybind11/tests/test_async.cpp @@ -18,7 +18,7 @@ TEST_SUBMODULE(async_module, m) { .def(py::init<>()) .def("__await__", [](const SupportsAsync& self) -> py::object { static_cast(self); - py::object loop = py::module::import("asyncio.events").attr("get_event_loop")(); + py::object loop = py::module_::import("asyncio.events").attr("get_event_loop")(); py::object f = loop.attr("create_future")(); f.attr("set_result")(5); return f.attr("__await__")(); diff --git a/wrap/pybind11/tests/test_buffers.cpp b/wrap/pybind11/tests/test_buffers.cpp index 1bc67ff7b..3a8e3e7b7 100644 --- a/wrap/pybind11/tests/test_buffers.cpp +++ b/wrap/pybind11/tests/test_buffers.cpp @@ -9,12 +9,13 @@ #include "pybind11_tests.h" #include "constructor_stats.h" +#include TEST_SUBMODULE(buffers, m) { // test_from_python / test_to_python: class Matrix { public: - Matrix(ssize_t rows, ssize_t cols) : m_rows(rows), m_cols(cols) { + Matrix(py::ssize_t rows, py::ssize_t cols) : m_rows(rows), m_cols(cols) { print_created(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix"); m_data = new float[(size_t) (rows*cols)]; memset(m_data, 0, sizeof(float) * (size_t) (rows * cols)); @@ -26,7 +27,7 @@ TEST_SUBMODULE(buffers, m) { memcpy(m_data, s.m_data, sizeof(float) * (size_t) (m_rows * m_cols)); } - Matrix(Matrix &&s) : m_rows(s.m_rows), m_cols(s.m_cols), m_data(s.m_data) { + Matrix(Matrix &&s) noexcept : m_rows(s.m_rows), m_cols(s.m_cols), m_data(s.m_data) { print_move_created(this); s.m_rows = 0; s.m_cols = 0; @@ -39,7 +40,11 @@ TEST_SUBMODULE(buffers, m) { } Matrix &operator=(const Matrix &s) { - print_copy_assigned(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix"); + if (this == &s) { + return *this; + } + print_copy_assigned(this, + std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix"); delete[] m_data; m_rows = s.m_rows; m_cols = s.m_cols; @@ -48,7 +53,7 @@ TEST_SUBMODULE(buffers, m) { return *this; } - Matrix &operator=(Matrix &&s) { + Matrix &operator=(Matrix &&s) noexcept { print_move_assigned(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix"); if (&s != this) { delete[] m_data; @@ -58,27 +63,27 @@ TEST_SUBMODULE(buffers, m) { return *this; } - float operator()(ssize_t i, ssize_t j) const { + float operator()(py::ssize_t i, py::ssize_t j) const { return m_data[(size_t) (i*m_cols + j)]; } - float &operator()(ssize_t i, ssize_t j) { + float &operator()(py::ssize_t i, py::ssize_t j) { return m_data[(size_t) (i*m_cols + j)]; } float *data() { return m_data; } - ssize_t rows() const { return m_rows; } - ssize_t cols() const { return m_cols; } + py::ssize_t rows() const { return m_rows; } + py::ssize_t cols() const { return m_cols; } private: - ssize_t m_rows; - ssize_t m_cols; + py::ssize_t m_rows; + py::ssize_t m_cols; float *m_data; }; py::class_(m, "Matrix", py::buffer_protocol()) - .def(py::init()) + .def(py::init()) /// Construct from a buffer - .def(py::init([](py::buffer const b) { + .def(py::init([](const py::buffer &b) { py::buffer_info info = b.request(); if (info.format != py::format_descriptor::format() || info.ndim != 2) throw std::runtime_error("Incompatible buffer format!"); @@ -88,40 +93,40 @@ TEST_SUBMODULE(buffers, m) { return v; })) - .def("rows", &Matrix::rows) - .def("cols", &Matrix::cols) + .def("rows", &Matrix::rows) + .def("cols", &Matrix::cols) /// Bare bones interface - .def("__getitem__", [](const Matrix &m, std::pair i) { - if (i.first >= m.rows() || i.second >= m.cols()) - throw py::index_error(); - return m(i.first, i.second); - }) - .def("__setitem__", [](Matrix &m, std::pair i, float v) { - if (i.first >= m.rows() || i.second >= m.cols()) - throw py::index_error(); - m(i.first, i.second) = v; - }) - /// Provide buffer access - .def_buffer([](Matrix &m) -> py::buffer_info { + .def("__getitem__", + [](const Matrix &m, std::pair i) { + if (i.first >= m.rows() || i.second >= m.cols()) + throw py::index_error(); + return m(i.first, i.second); + }) + .def("__setitem__", + [](Matrix &m, std::pair i, float v) { + if (i.first >= m.rows() || i.second >= m.cols()) + throw py::index_error(); + m(i.first, i.second) = v; + }) + /// Provide buffer access + .def_buffer([](Matrix &m) -> py::buffer_info { return py::buffer_info( m.data(), /* Pointer to buffer */ { m.rows(), m.cols() }, /* Buffer dimensions */ { sizeof(float) * size_t(m.cols()), /* Strides (in bytes) for each index */ sizeof(float) } ); - }) - ; - + }); // test_inherited_protocol class SquareMatrix : public Matrix { public: - SquareMatrix(ssize_t n) : Matrix(n, n) { } + explicit SquareMatrix(py::ssize_t n) : Matrix(n, n) {} }; // Derived classes inherit the buffer protocol and the buffer access function py::class_(m, "SquareMatrix") - .def(py::init()); + .def(py::init()); // test_pointer_to_member_fn @@ -153,7 +158,7 @@ TEST_SUBMODULE(buffers, m) { py::format_descriptor::format(), 1); } - ConstBuffer() : value(new int32_t{0}) { }; + ConstBuffer() : value(new int32_t{0}) {} }; py::class_(m, "ConstBuffer", py::buffer_protocol()) .def(py::init<>()) @@ -168,7 +173,7 @@ TEST_SUBMODULE(buffers, m) { struct BufferReadOnly { const uint8_t value = 0; - BufferReadOnly(uint8_t value): value(value) {} + explicit BufferReadOnly(uint8_t value) : value(value) {} py::buffer_info get_buffer_info() { return py::buffer_info(&value, 1); @@ -192,4 +197,20 @@ TEST_SUBMODULE(buffers, m) { .def_readwrite("readonly", &BufferReadOnlySelect::readonly) .def_buffer(&BufferReadOnlySelect::get_buffer_info); + // Expose buffer_info for testing. + py::class_(m, "buffer_info") + .def(py::init<>()) + .def_readonly("itemsize", &py::buffer_info::itemsize) + .def_readonly("size", &py::buffer_info::size) + .def_readonly("format", &py::buffer_info::format) + .def_readonly("ndim", &py::buffer_info::ndim) + .def_readonly("shape", &py::buffer_info::shape) + .def_readonly("strides", &py::buffer_info::strides) + .def_readonly("readonly", &py::buffer_info::readonly) + .def("__repr__", [](py::handle self) { + return py::str("itemsize={0.itemsize!r}, size={0.size!r}, format={0.format!r}, ndim={0.ndim!r}, shape={0.shape!r}, strides={0.strides!r}, readonly={0.readonly!r}").format(self); + }) + ; + + m.def("get_buffer_info", [](const py::buffer &buffer) { return buffer.request(); }); } diff --git a/wrap/pybind11/tests/test_buffers.py b/wrap/pybind11/tests/test_buffers.py index d6adaf1f5..0d5bf16c3 100644 --- a/wrap/pybind11/tests/test_buffers.py +++ b/wrap/pybind11/tests/test_buffers.py @@ -1,13 +1,13 @@ # -*- coding: utf-8 -*- +import ctypes import io import struct import pytest -import env # noqa: F401 - -from pybind11_tests import buffers as m +import env from pybind11_tests import ConstructorStats +from pybind11_tests import buffers as m np = pytest.importorskip("numpy") @@ -36,6 +36,10 @@ def test_from_python(): # https://foss.heptapod.net/pypy/pypy/-/issues/2444 +# TODO: fix on recent PyPy +@pytest.mark.xfail( + env.PYPY, reason="PyPy 7.3.7 doesn't clear this anymore", strict=False +) def test_to_python(): mat = m.Matrix(5, 4) assert memoryview(mat).shape == (5, 4) @@ -45,8 +49,8 @@ def test_to_python(): mat[3, 2] = 7.0 assert mat[2, 3] == 4 assert mat[3, 2] == 7 - assert struct.unpack_from('f', mat, (3 * 4 + 2) * 4) == (7, ) - assert struct.unpack_from('f', mat, (2 * 4 + 3) * 4) == (4, ) + assert struct.unpack_from("f", mat, (3 * 4 + 2) * 4) == (7,) + assert struct.unpack_from("f", mat, (2 * 4 + 3) * 4) == (4,) mat2 = np.array(mat, copy=False) assert mat2.shape == (5, 4) @@ -82,28 +86,82 @@ def test_pointer_to_member_fn(): for cls in [m.Buffer, m.ConstBuffer, m.DerivedBuffer]: buf = cls() buf.value = 0x12345678 - value = struct.unpack('i', bytearray(buf))[0] + value = struct.unpack("i", bytearray(buf))[0] assert value == 0x12345678 def test_readonly_buffer(): buf = m.BufferReadOnly(0x64) view = memoryview(buf) - assert view[0] == b'd' if env.PY2 else 0x64 + assert view[0] == b"d" if env.PY2 else 0x64 assert view.readonly + with pytest.raises(TypeError): + view[0] = b"\0" if env.PY2 else 0 def test_selective_readonly_buffer(): buf = m.BufferReadOnlySelect() - memoryview(buf)[0] = b'd' if env.PY2 else 0x64 + memoryview(buf)[0] = b"d" if env.PY2 else 0x64 assert buf.value == 0x64 - io.BytesIO(b'A').readinto(buf) - assert buf.value == ord(b'A') + io.BytesIO(b"A").readinto(buf) + assert buf.value == ord(b"A") buf.readonly = True with pytest.raises(TypeError): - memoryview(buf)[0] = b'\0' if env.PY2 else 0 + memoryview(buf)[0] = b"\0" if env.PY2 else 0 with pytest.raises(TypeError): - io.BytesIO(b'1').readinto(buf) + io.BytesIO(b"1").readinto(buf) + + +def test_ctypes_array_1d(): + char1d = (ctypes.c_char * 10)() + int1d = (ctypes.c_int * 15)() + long1d = (ctypes.c_long * 7)() + + for carray in (char1d, int1d, long1d): + info = m.get_buffer_info(carray) + assert info.itemsize == ctypes.sizeof(carray._type_) + assert info.size == len(carray) + assert info.ndim == 1 + assert info.shape == [info.size] + assert info.strides == [info.itemsize] + assert not info.readonly + + +def test_ctypes_array_2d(): + char2d = ((ctypes.c_char * 10) * 4)() + int2d = ((ctypes.c_int * 15) * 3)() + long2d = ((ctypes.c_long * 7) * 2)() + + for carray in (char2d, int2d, long2d): + info = m.get_buffer_info(carray) + assert info.itemsize == ctypes.sizeof(carray[0]._type_) + assert info.size == len(carray) * len(carray[0]) + assert info.ndim == 2 + assert info.shape == [len(carray), len(carray[0])] + assert info.strides == [info.itemsize * len(carray[0]), info.itemsize] + assert not info.readonly + + +@pytest.mark.skipif( + "env.PYPY and env.PY2", reason="PyPy2 bytes buffer not reported as readonly" +) +def test_ctypes_from_buffer(): + test_pystr = b"0123456789" + for pyarray in (test_pystr, bytearray(test_pystr)): + pyinfo = m.get_buffer_info(pyarray) + + if pyinfo.readonly: + cbytes = (ctypes.c_char * len(pyarray)).from_buffer_copy(pyarray) + cinfo = m.get_buffer_info(cbytes) + else: + cbytes = (ctypes.c_char * len(pyarray)).from_buffer(pyarray) + cinfo = m.get_buffer_info(cbytes) + + assert cinfo.size == pyinfo.size + assert cinfo.ndim == pyinfo.ndim + assert cinfo.shape == pyinfo.shape + assert cinfo.strides == pyinfo.strides + assert not cinfo.readonly diff --git a/wrap/pybind11/tests/test_builtin_casters.cpp b/wrap/pybind11/tests/test_builtin_casters.cpp index acc9f8fb3..4a9f33837 100644 --- a/wrap/pybind11/tests/test_builtin_casters.cpp +++ b/wrap/pybind11/tests/test_builtin_casters.cpp @@ -10,10 +10,64 @@ #include "pybind11_tests.h" #include -#if defined(_MSC_VER) -# pragma warning(push) -# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant -#endif +struct ConstRefCasted { + int tag; +}; + +PYBIND11_NAMESPACE_BEGIN(pybind11) +PYBIND11_NAMESPACE_BEGIN(detail) +template <> +class type_caster { + public: + static constexpr auto name = const_name(); + + // Input is unimportant, a new value will always be constructed based on the + // cast operator. + bool load(handle, bool) { return true; } + + explicit operator ConstRefCasted &&() { + value = {1}; + // NOLINTNEXTLINE(performance-move-const-arg) + return std::move(value); + } + explicit operator ConstRefCasted &() { + value = {2}; + return value; + } + explicit operator ConstRefCasted *() { + value = {3}; + return &value; + } + + explicit operator const ConstRefCasted &() { + value = {4}; + return value; + } + explicit operator const ConstRefCasted *() { + value = {5}; + return &value; + } + + // custom cast_op to explicitly propagate types to the conversion operators. + template + using cast_op_type = + /// const + conditional_t< + std::is_same, const ConstRefCasted*>::value, const ConstRefCasted*, + conditional_t< + std::is_same::value, const ConstRefCasted&, + /// non-const + conditional_t< + std::is_same, ConstRefCasted*>::value, ConstRefCasted*, + conditional_t< + std::is_same::value, ConstRefCasted&, + /* else */ConstRefCasted&&>>>>; + + private: + ConstRefCasted value = {0}; +}; +PYBIND11_NAMESPACE_END(detail) +PYBIND11_NAMESPACE_END(pybind11) TEST_SUBMODULE(builtin_casters, m) { // test_simple_string @@ -26,7 +80,7 @@ TEST_SUBMODULE(builtin_casters, m) { std::wstring wstr; wstr.push_back(0x61); // a wstr.push_back(0x2e18); // ⸘ - if (sizeof(wchar_t) == 2) { wstr.push_back(mathbfA16_1); wstr.push_back(mathbfA16_2); } // 𝐀, utf16 + if (PYBIND11_SILENCE_MSVC_C4127(sizeof(wchar_t) == 2)) { wstr.push_back(mathbfA16_1); wstr.push_back(mathbfA16_2); } // 𝐀, utf16 else { wstr.push_back((wchar_t) mathbfA32); } // 𝐀, utf32 wstr.push_back(0x7a); // z @@ -36,11 +90,12 @@ TEST_SUBMODULE(builtin_casters, m) { m.def("good_wchar_string", [=]() { return wstr; }); // a‽𝐀z m.def("bad_utf8_string", []() { return std::string("abc\xd0" "def"); }); m.def("bad_utf16_string", [=]() { return std::u16string({ b16, char16_t(0xd800), z16 }); }); +#if PY_MAJOR_VERSION >= 3 // Under Python 2.7, invalid unicode UTF-32 characters don't appear to trigger UnicodeDecodeError - if (PY_MAJOR_VERSION >= 3) - m.def("bad_utf32_string", [=]() { return std::u32string({ a32, char32_t(0xd800), z32 }); }); - if (PY_MAJOR_VERSION >= 3 || sizeof(wchar_t) == 2) + m.def("bad_utf32_string", [=]() { return std::u32string({ a32, char32_t(0xd800), z32 }); }); + if (PYBIND11_SILENCE_MSVC_C4127(sizeof(wchar_t) == 2)) m.def("bad_wchar_string", [=]() { return std::wstring({ wchar_t(0x61), wchar_t(0xd800) }); }); +#endif m.def("u8_Z", []() -> char { return 'Z'; }); m.def("u8_eacute", []() -> char { return '\xe9'; }); m.def("u16_ibang", [=]() -> char16_t { return ib16; }); @@ -58,7 +113,7 @@ TEST_SUBMODULE(builtin_casters, m) { // test_bytes_to_string m.def("strlen", [](char *s) { return strlen(s); }); - m.def("string_length", [](std::string s) { return s.length(); }); + m.def("string_length", [](const std::string &s) { return s.length(); }); #ifdef PYBIND11_HAS_U8STRING m.attr("has_u8string") = true; @@ -85,11 +140,35 @@ TEST_SUBMODULE(builtin_casters, m) { m.def("string_view16_return", []() { return std::u16string_view(u"utf16 secret \U0001f382"); }); m.def("string_view32_return", []() { return std::u32string_view(U"utf32 secret \U0001f382"); }); + // The inner lambdas here are to also test implicit conversion + using namespace std::literals; + m.def("string_view_bytes", []() { return [](py::bytes b) { return b; }("abc \x80\x80 def"sv); }); + m.def("string_view_str", []() { return [](py::str s) { return s; }("abc \342\200\275 def"sv); }); + m.def("string_view_from_bytes", [](const py::bytes &b) { return [](std::string_view s) { return s; }(b); }); +#if PY_MAJOR_VERSION >= 3 + m.def("string_view_memoryview", []() { + static constexpr auto val = "Have some \360\237\216\202"sv; + return py::memoryview::from_memory(val); + }); +#endif + # ifdef PYBIND11_HAS_U8STRING m.def("string_view8_print", [](std::u8string_view s) { py::print(s, s.size()); }); m.def("string_view8_chars", [](std::u8string_view s) { py::list l; for (auto c : s) l.append((std::uint8_t) c); return l; }); m.def("string_view8_return", []() { return std::u8string_view(u8"utf8 secret \U0001f382"); }); + m.def("string_view8_str", []() { return py::str{std::u8string_view{u8"abc ‽ def"}}; }); # endif + + struct TypeWithBothOperatorStringAndStringView { + // NOLINTNEXTLINE(google-explicit-constructor) + operator std::string() const { return "success"; } + // NOLINTNEXTLINE(google-explicit-constructor) + operator std::string_view() const { return "failure"; } + }; + m.def("bytes_from_type_with_both_operator_string_and_string_view", + []() { return py::bytes(TypeWithBothOperatorStringAndStringView()); }); + m.def("str_from_type_with_both_operator_string_and_string_view", + []() { return py::str(TypeWithBothOperatorStringAndStringView()); }); #endif // test_integer_casting @@ -98,10 +177,17 @@ TEST_SUBMODULE(builtin_casters, m) { m.def("i64_str", [](std::int64_t v) { return std::to_string(v); }); m.def("u64_str", [](std::uint64_t v) { return std::to_string(v); }); + // test_int_convert + m.def("int_passthrough", [](int arg) { return arg; }); + m.def("int_passthrough_noconvert", [](int arg) { return arg; }, py::arg{}.noconvert()); + // test_tuple - m.def("pair_passthrough", [](std::pair input) { - return std::make_pair(input.second, input.first); - }, "Return a pair in reversed order"); + m.def( + "pair_passthrough", + [](const std::pair &input) { + return std::make_pair(input.second, input.first); + }, + "Return a pair in reversed order"); m.def("tuple_passthrough", [](std::tuple input) { return std::make_tuple(std::get<2>(input), std::get<1>(input), std::get<0>(input)); }, "Return a triple in reversed order"); @@ -130,23 +216,45 @@ TEST_SUBMODULE(builtin_casters, m) { // test_none_deferred m.def("defer_none_cstring", [](char *) { return false; }); - m.def("defer_none_cstring", [](py::none) { return true; }); + m.def("defer_none_cstring", [](const py::none &) { return true; }); m.def("defer_none_custom", [](UserType *) { return false; }); - m.def("defer_none_custom", [](py::none) { return true; }); + m.def("defer_none_custom", [](const py::none &) { return true; }); m.def("nodefer_none_void", [](void *) { return true; }); - m.def("nodefer_none_void", [](py::none) { return false; }); + m.def("nodefer_none_void", [](const py::none &) { return false; }); // test_void_caster m.def("load_nullptr_t", [](std::nullptr_t) {}); // not useful, but it should still compile m.def("cast_nullptr_t", []() { return std::nullptr_t{}; }); + // [workaround(intel)] ICC 20/21 breaks with py::arg().stuff, using py::arg{}.stuff works. + // test_bool_caster m.def("bool_passthrough", [](bool arg) { return arg; }); - m.def("bool_passthrough_noconvert", [](bool arg) { return arg; }, py::arg().noconvert()); + m.def("bool_passthrough_noconvert", [](bool arg) { return arg; }, py::arg{}.noconvert()); + + // TODO: This should be disabled and fixed in future Intel compilers +#if !defined(__INTEL_COMPILER) + // Test "bool_passthrough_noconvert" again, but using () instead of {} to construct py::arg + // When compiled with the Intel compiler, this results in segmentation faults when importing + // the module. Tested with icc (ICC) 2021.1 Beta 20200827, this should be tested again when + // a newer version of icc is available. + m.def("bool_passthrough_noconvert2", [](bool arg) { return arg; }, py::arg().noconvert()); +#endif // test_reference_wrapper m.def("refwrap_builtin", [](std::reference_wrapper p) { return 10 * p.get(); }); m.def("refwrap_usertype", [](std::reference_wrapper p) { return p.get().value(); }); + m.def("refwrap_usertype_const", [](std::reference_wrapper p) { return p.get().value(); }); + + m.def("refwrap_lvalue", []() -> std::reference_wrapper { + static UserType x(1); + return std::ref(x); + }); + m.def("refwrap_lvalue_const", []() -> std::reference_wrapper { + static UserType x(1); + return std::cref(x); + }); + // Not currently supported (std::pair caster has return-by-value cast operator); // triggers static_assert failure. //m.def("refwrap_pair", [](std::reference_wrapper>) { }); @@ -162,7 +270,7 @@ TEST_SUBMODULE(builtin_casters, m) { }, "copy"_a); m.def("refwrap_iiw", [](const IncType &w) { return w.value(); }); - m.def("refwrap_call_iiw", [](IncType &w, py::function f) { + m.def("refwrap_call_iiw", [](IncType &w, const py::function &f) { py::list l; l.append(f(std::ref(w))); l.append(f(std::cref(w))); @@ -189,4 +297,14 @@ TEST_SUBMODULE(builtin_casters, m) { py::object o = py::cast(v); return py::cast(o) == v; }); + + // Tests const/non-const propagation in cast_op. + m.def("takes", [](ConstRefCasted x) { return x.tag; }); + m.def("takes_move", [](ConstRefCasted&& x) { return x.tag; }); + m.def("takes_ptr", [](ConstRefCasted* x) { return x->tag; }); + m.def("takes_ref", [](ConstRefCasted& x) { return x.tag; }); + m.def("takes_ref_wrap", [](std::reference_wrapper x) { return x.get().tag; }); + m.def("takes_const_ptr", [](const ConstRefCasted* x) { return x->tag; }); + m.def("takes_const_ref", [](const ConstRefCasted& x) { return x.tag; }); + m.def("takes_const_ref_wrap", [](std::reference_wrapper x) { return x.get().tag; }); } diff --git a/wrap/pybind11/tests/test_builtin_casters.py b/wrap/pybind11/tests/test_builtin_casters.py index 08d38bc15..b1f1e395a 100644 --- a/wrap/pybind11/tests/test_builtin_casters.py +++ b/wrap/pybind11/tests/test_builtin_casters.py @@ -1,10 +1,9 @@ # -*- coding: utf-8 -*- import pytest -import env # noqa: F401 - +import env +from pybind11_tests import IncType, UserType from pybind11_tests import builtin_casters as m -from pybind11_tests import UserType, IncType def test_simple_string(): @@ -37,79 +36,85 @@ def test_unicode_conversion(): with pytest.raises(UnicodeDecodeError): m.bad_utf8_u8string() - assert m.u8_Z() == 'Z' - assert m.u8_eacute() == u'é' - assert m.u16_ibang() == u'‽' - assert m.u32_mathbfA() == u'𝐀' - assert m.wchar_heart() == u'♥' + assert m.u8_Z() == "Z" + assert m.u8_eacute() == u"é" + assert m.u16_ibang() == u"‽" + assert m.u32_mathbfA() == u"𝐀" + assert m.wchar_heart() == u"♥" if hasattr(m, "has_u8string"): - assert m.u8_char8_Z() == 'Z' + assert m.u8_char8_Z() == "Z" def test_single_char_arguments(): """Tests failures for passing invalid inputs to char-accepting functions""" + def toobig_message(r): - return "Character code point not in range({0:#x})".format(r) + return "Character code point not in range({:#x})".format(r) + toolong_message = "Expected a character, but multi-character string found" - assert m.ord_char(u'a') == 0x61 # simple ASCII - assert m.ord_char_lv(u'b') == 0x62 - assert m.ord_char(u'é') == 0xE9 # requires 2 bytes in utf-8, but can be stuffed in a char + assert m.ord_char(u"a") == 0x61 # simple ASCII + assert m.ord_char_lv(u"b") == 0x62 + assert ( + m.ord_char(u"é") == 0xE9 + ) # requires 2 bytes in utf-8, but can be stuffed in a char with pytest.raises(ValueError) as excinfo: - assert m.ord_char(u'Ā') == 0x100 # requires 2 bytes, doesn't fit in a char + assert m.ord_char(u"Ā") == 0x100 # requires 2 bytes, doesn't fit in a char assert str(excinfo.value) == toobig_message(0x100) with pytest.raises(ValueError) as excinfo: - assert m.ord_char(u'ab') + assert m.ord_char(u"ab") assert str(excinfo.value) == toolong_message - assert m.ord_char16(u'a') == 0x61 - assert m.ord_char16(u'é') == 0xE9 - assert m.ord_char16_lv(u'ê') == 0xEA - assert m.ord_char16(u'Ā') == 0x100 - assert m.ord_char16(u'‽') == 0x203d - assert m.ord_char16(u'♥') == 0x2665 - assert m.ord_char16_lv(u'♡') == 0x2661 + assert m.ord_char16(u"a") == 0x61 + assert m.ord_char16(u"é") == 0xE9 + assert m.ord_char16_lv(u"ê") == 0xEA + assert m.ord_char16(u"Ā") == 0x100 + assert m.ord_char16(u"‽") == 0x203D + assert m.ord_char16(u"♥") == 0x2665 + assert m.ord_char16_lv(u"♡") == 0x2661 with pytest.raises(ValueError) as excinfo: - assert m.ord_char16(u'🎂') == 0x1F382 # requires surrogate pair + assert m.ord_char16(u"🎂") == 0x1F382 # requires surrogate pair assert str(excinfo.value) == toobig_message(0x10000) with pytest.raises(ValueError) as excinfo: - assert m.ord_char16(u'aa') + assert m.ord_char16(u"aa") assert str(excinfo.value) == toolong_message - assert m.ord_char32(u'a') == 0x61 - assert m.ord_char32(u'é') == 0xE9 - assert m.ord_char32(u'Ā') == 0x100 - assert m.ord_char32(u'‽') == 0x203d - assert m.ord_char32(u'♥') == 0x2665 - assert m.ord_char32(u'🎂') == 0x1F382 + assert m.ord_char32(u"a") == 0x61 + assert m.ord_char32(u"é") == 0xE9 + assert m.ord_char32(u"Ā") == 0x100 + assert m.ord_char32(u"‽") == 0x203D + assert m.ord_char32(u"♥") == 0x2665 + assert m.ord_char32(u"🎂") == 0x1F382 with pytest.raises(ValueError) as excinfo: - assert m.ord_char32(u'aa') + assert m.ord_char32(u"aa") assert str(excinfo.value) == toolong_message - assert m.ord_wchar(u'a') == 0x61 - assert m.ord_wchar(u'é') == 0xE9 - assert m.ord_wchar(u'Ā') == 0x100 - assert m.ord_wchar(u'‽') == 0x203d - assert m.ord_wchar(u'♥') == 0x2665 + assert m.ord_wchar(u"a") == 0x61 + assert m.ord_wchar(u"é") == 0xE9 + assert m.ord_wchar(u"Ā") == 0x100 + assert m.ord_wchar(u"‽") == 0x203D + assert m.ord_wchar(u"♥") == 0x2665 if m.wchar_size == 2: with pytest.raises(ValueError) as excinfo: - assert m.ord_wchar(u'🎂') == 0x1F382 # requires surrogate pair + assert m.ord_wchar(u"🎂") == 0x1F382 # requires surrogate pair assert str(excinfo.value) == toobig_message(0x10000) else: - assert m.ord_wchar(u'🎂') == 0x1F382 + assert m.ord_wchar(u"🎂") == 0x1F382 with pytest.raises(ValueError) as excinfo: - assert m.ord_wchar(u'aa') + assert m.ord_wchar(u"aa") assert str(excinfo.value) == toolong_message if hasattr(m, "has_u8string"): - assert m.ord_char8(u'a') == 0x61 # simple ASCII - assert m.ord_char8_lv(u'b') == 0x62 - assert m.ord_char8(u'é') == 0xE9 # requires 2 bytes in utf-8, but can be stuffed in a char + assert m.ord_char8(u"a") == 0x61 # simple ASCII + assert m.ord_char8_lv(u"b") == 0x62 + assert ( + m.ord_char8(u"é") == 0xE9 + ) # requires 2 bytes in utf-8, but can be stuffed in a char with pytest.raises(ValueError) as excinfo: - assert m.ord_char8(u'Ā') == 0x100 # requires 2 bytes, doesn't fit in a char + assert m.ord_char8(u"Ā") == 0x100 # requires 2 bytes, doesn't fit in a char assert str(excinfo.value) == toobig_message(0x100) with pytest.raises(ValueError) as excinfo: - assert m.ord_char8(u'ab') + assert m.ord_char8(u"ab") assert str(excinfo.value) == toolong_message @@ -129,19 +134,19 @@ def test_bytes_to_string(): assert m.strlen(to_bytes("a\x00b")) == 1 # C-string limitation # passing in a utf8 encoded string should work - assert m.string_length(u'💩'.encode("utf8")) == 4 + assert m.string_length(u"💩".encode("utf8")) == 4 @pytest.mark.skipif(not hasattr(m, "has_string_view"), reason="no ") def test_string_view(capture): """Tests support for C++17 string_view arguments and return values""" assert m.string_view_chars("Hi") == [72, 105] - assert m.string_view_chars("Hi 🎂") == [72, 105, 32, 0xf0, 0x9f, 0x8e, 0x82] - assert m.string_view16_chars(u"Hi 🎂") == [72, 105, 32, 0xd83c, 0xdf82] + assert m.string_view_chars("Hi 🎂") == [72, 105, 32, 0xF0, 0x9F, 0x8E, 0x82] + assert m.string_view16_chars(u"Hi 🎂") == [72, 105, 32, 0xD83C, 0xDF82] assert m.string_view32_chars(u"Hi 🎂") == [72, 105, 32, 127874] if hasattr(m, "has_u8string"): assert m.string_view8_chars("Hi") == [72, 105] - assert m.string_view8_chars(u"Hi 🎂") == [72, 105, 32, 0xf0, 0x9f, 0x8e, 0x82] + assert m.string_view8_chars(u"Hi 🎂") == [72, 105, 32, 0xF0, 0x9F, 0x8E, 0x82] assert m.string_view_return() == u"utf8 secret 🎂" assert m.string_view16_return() == u"utf16 secret 🎂" @@ -154,40 +159,63 @@ def test_string_view(capture): m.string_view_print("utf8 🎂") m.string_view16_print(u"utf16 🎂") m.string_view32_print(u"utf32 🎂") - assert capture == u""" + assert ( + capture + == u""" Hi 2 utf8 🎂 9 utf16 🎂 8 utf32 🎂 7 """ + ) if hasattr(m, "has_u8string"): with capture: m.string_view8_print("Hi") m.string_view8_print(u"utf8 🎂") - assert capture == u""" + assert ( + capture + == u""" Hi 2 utf8 🎂 9 """ + ) with capture: m.string_view_print("Hi, ascii") m.string_view_print("Hi, utf8 🎂") m.string_view16_print(u"Hi, utf16 🎂") m.string_view32_print(u"Hi, utf32 🎂") - assert capture == u""" + assert ( + capture + == u""" Hi, ascii 9 Hi, utf8 🎂 13 Hi, utf16 🎂 12 Hi, utf32 🎂 11 """ + ) if hasattr(m, "has_u8string"): with capture: m.string_view8_print("Hi, ascii") m.string_view8_print(u"Hi, utf8 🎂") - assert capture == u""" + assert ( + capture + == u""" Hi, ascii 9 Hi, utf8 🎂 13 """ + ) + + assert m.string_view_bytes() == b"abc \x80\x80 def" + assert m.string_view_str() == u"abc ‽ def" + assert m.string_view_from_bytes(u"abc ‽ def".encode("utf-8")) == u"abc ‽ def" + if hasattr(m, "has_u8string"): + assert m.string_view8_str() == u"abc ‽ def" + if not env.PY2: + assert m.string_view_memoryview() == "Have some 🎂".encode() + + assert m.bytes_from_type_with_both_operator_string_and_string_view() == b"success" + assert m.str_from_type_with_both_operator_string_and_string_view() == "success" def test_integer_casting(): @@ -199,8 +227,14 @@ def test_integer_casting(): if env.PY2: assert m.i32_str(long(-1)) == "-1" # noqa: F821 undefined name 'long' assert m.i64_str(long(-1)) == "-1" # noqa: F821 undefined name 'long' - assert m.i64_str(long(-999999999999)) == "-999999999999" # noqa: F821 undefined name - assert m.u64_str(long(999999999999)) == "999999999999" # noqa: F821 undefined name 'long' + assert ( + m.i64_str(long(-999999999999)) # noqa: F821 undefined name 'long' + == "-999999999999" + ) + assert ( + m.u64_str(long(999999999999)) # noqa: F821 undefined name 'long' + == "999999999999" + ) else: assert m.i64_str(-999999999999) == "-999999999999" assert m.u64_str(999999999999) == "999999999999" @@ -227,6 +261,101 @@ def test_integer_casting(): assert "incompatible function arguments" in str(excinfo.value) +def test_int_convert(): + class Int(object): + def __int__(self): + return 42 + + class NotInt(object): + pass + + class Float(object): + def __float__(self): + return 41.99999 + + class Index(object): + def __index__(self): + return 42 + + class IntAndIndex(object): + def __int__(self): + return 42 + + def __index__(self): + return 0 + + class RaisingTypeErrorOnIndex(object): + def __index__(self): + raise TypeError + + def __int__(self): + return 42 + + class RaisingValueErrorOnIndex(object): + def __index__(self): + raise ValueError + + def __int__(self): + return 42 + + convert, noconvert = m.int_passthrough, m.int_passthrough_noconvert + + def requires_conversion(v): + pytest.raises(TypeError, noconvert, v) + + def cant_convert(v): + pytest.raises(TypeError, convert, v) + + assert convert(7) == 7 + assert noconvert(7) == 7 + cant_convert(3.14159) + # TODO: Avoid DeprecationWarning in `PyLong_AsLong` (and similar) + # TODO: PyPy 3.8 does not behave like CPython 3.8 here yet (7.3.7) + if (3, 8) <= env.PY < (3, 10) and env.CPYTHON: + with env.deprecated_call(): + assert convert(Int()) == 42 + else: + assert convert(Int()) == 42 + requires_conversion(Int()) + cant_convert(NotInt()) + cant_convert(Float()) + + # Before Python 3.8, `PyLong_AsLong` does not pick up on `obj.__index__`, + # but pybind11 "backports" this behavior. + assert convert(Index()) == 42 + assert noconvert(Index()) == 42 + assert convert(IntAndIndex()) == 0 # Fishy; `int(DoubleThought)` == 42 + assert noconvert(IntAndIndex()) == 0 + assert convert(RaisingTypeErrorOnIndex()) == 42 + requires_conversion(RaisingTypeErrorOnIndex()) + assert convert(RaisingValueErrorOnIndex()) == 42 + requires_conversion(RaisingValueErrorOnIndex()) + + +def test_numpy_int_convert(): + np = pytest.importorskip("numpy") + + convert, noconvert = m.int_passthrough, m.int_passthrough_noconvert + + def require_implicit(v): + pytest.raises(TypeError, noconvert, v) + + # `np.intc` is an alias that corresponds to a C++ `int` + assert convert(np.intc(42)) == 42 + assert noconvert(np.intc(42)) == 42 + + # The implicit conversion from np.float32 is undesirable but currently accepted. + # TODO: Avoid DeprecationWarning in `PyLong_AsLong` (and similar) + # TODO: PyPy 3.8 does not behave like CPython 3.8 here yet (7.3.7) + # https://github.com/pybind/pybind11/issues/3408 + if (3, 8) <= env.PY < (3, 10) and env.CPYTHON: + with env.deprecated_call(): + assert convert(np.float32(3.14159)) == 3 + else: + assert convert(np.float32(3.14159)) == 3 + require_implicit(np.float32(3.14159)) + + def test_tuple(doc): """std::pair <-> tuple & std::tuple <-> tuple""" assert m.pair_passthrough((True, "test")) == ("test", True) @@ -236,16 +365,22 @@ def test_tuple(doc): assert m.tuple_passthrough([True, "test", 5]) == (5, "test", True) assert m.empty_tuple() == () - assert doc(m.pair_passthrough) == """ + assert ( + doc(m.pair_passthrough) + == """ pair_passthrough(arg0: Tuple[bool, str]) -> Tuple[str, bool] Return a pair in reversed order """ - assert doc(m.tuple_passthrough) == """ + ) + assert ( + doc(m.tuple_passthrough) + == """ tuple_passthrough(arg0: Tuple[bool, str, int]) -> Tuple[int, str, bool] Return a triple in reversed order """ + ) assert m.rvalue_pair() == ("rvalue", "rvalue") assert m.lvalue_pair() == ("lvalue", "lvalue") @@ -285,6 +420,7 @@ def test_reference_wrapper(): """std::reference_wrapper for builtin and user types""" assert m.refwrap_builtin(42) == 420 assert m.refwrap_usertype(UserType(42)) == 42 + assert m.refwrap_usertype_const(UserType(42)) == 42 with pytest.raises(TypeError) as excinfo: m.refwrap_builtin(None) @@ -294,6 +430,9 @@ def test_reference_wrapper(): m.refwrap_usertype(None) assert "incompatible function arguments" in str(excinfo.value) + assert m.refwrap_lvalue().value == 1 + assert m.refwrap_lvalue_const().value == 1 + a1 = m.refwrap_list(copy=True) a2 = m.refwrap_list(copy=True) assert [x.value for x in a1] == [2, 3] @@ -372,7 +511,7 @@ def test_numpy_bool(): assert convert(np.bool_(False)) is False assert noconvert(np.bool_(True)) is True assert noconvert(np.bool_(False)) is False - cant_convert(np.zeros(2, dtype='int')) + cant_convert(np.zeros(2, dtype="int")) def test_int_long(): @@ -382,7 +521,8 @@ def test_int_long(): long.""" import sys - must_be_long = type(getattr(sys, 'maxint', 1) + 1) + + must_be_long = type(getattr(sys, "maxint", 1) + 1) assert isinstance(m.int_cast(), int) assert isinstance(m.long_cast(), int) assert isinstance(m.longlong_cast(), must_be_long) @@ -390,3 +530,21 @@ def test_int_long(): def test_void_caster_2(): assert m.test_void_caster() + + +def test_const_ref_caster(): + """Verifies that const-ref is propagated through type_caster cast_op. + The returned ConstRefCasted type is a minimal type that is constructed to + reference the casting mode used. + """ + x = False + assert m.takes(x) == 1 + assert m.takes_move(x) == 1 + + assert m.takes_ptr(x) == 3 + assert m.takes_ref(x) == 2 + assert m.takes_ref_wrap(x) == 2 + + assert m.takes_const_ptr(x) == 5 + assert m.takes_const_ref(x) == 4 + assert m.takes_const_ref_wrap(x) == 4 diff --git a/wrap/pybind11/tests/test_call_policies.cpp b/wrap/pybind11/tests/test_call_policies.cpp index 26c83f81b..7cb98d0d8 100644 --- a/wrap/pybind11/tests/test_call_policies.cpp +++ b/wrap/pybind11/tests/test_call_policies.cpp @@ -51,6 +51,7 @@ TEST_SUBMODULE(call_policies, m) { void addChild(Child *) { } Child *returnChild() { return new Child(); } Child *returnNullChild() { return nullptr; } + static Child *staticFunction(Parent*) { return new Child(); } }; py::class_(m, "Parent") .def(py::init<>()) @@ -60,7 +61,12 @@ TEST_SUBMODULE(call_policies, m) { .def("returnChild", &Parent::returnChild) .def("returnChildKeepAlive", &Parent::returnChild, py::keep_alive<1, 0>()) .def("returnNullChildKeepAliveChild", &Parent::returnNullChild, py::keep_alive<1, 0>()) - .def("returnNullChildKeepAliveParent", &Parent::returnNullChild, py::keep_alive<0, 1>()); + .def("returnNullChildKeepAliveParent", &Parent::returnNullChild, py::keep_alive<0, 1>()) + .def_static( + "staticFunction", &Parent::staticFunction, py::keep_alive<1, 0>()); + + m.def("free_function", [](Parent*, Child*) {}, py::keep_alive<1, 2>()); + m.def("invalid_arg_index", []{}, py::keep_alive<0, 1>()); #if !defined(PYPY_VERSION) // test_alive_gc diff --git a/wrap/pybind11/tests/test_call_policies.py b/wrap/pybind11/tests/test_call_policies.py index ec005c132..3599cf81a 100644 --- a/wrap/pybind11/tests/test_call_policies.py +++ b/wrap/pybind11/tests/test_call_policies.py @@ -2,9 +2,8 @@ import pytest import env # noqa: F401 - -from pybind11_tests import call_policies as m from pybind11_tests import ConstructorStats +from pybind11_tests import call_policies as m @pytest.mark.xfail("env.PYPY", reason="sometimes comes out 1 off on PyPy", strict=False) @@ -16,10 +15,13 @@ def test_keep_alive_argument(capture): with capture: p.addChild(m.Child()) assert ConstructorStats.detail_reg_inst() == n_inst + 1 - assert capture == """ + assert ( + capture + == """ Allocating child. Releasing child. """ + ) with capture: del p assert ConstructorStats.detail_reg_inst() == n_inst @@ -35,10 +37,26 @@ def test_keep_alive_argument(capture): with capture: del p assert ConstructorStats.detail_reg_inst() == n_inst - assert capture == """ + assert ( + capture + == """ Releasing parent. Releasing child. """ + ) + + p = m.Parent() + c = m.Child() + assert ConstructorStats.detail_reg_inst() == n_inst + 2 + m.free_function(p, c) + del c + assert ConstructorStats.detail_reg_inst() == n_inst + 2 + del p + assert ConstructorStats.detail_reg_inst() == n_inst + + with pytest.raises(RuntimeError) as excinfo: + m.invalid_arg_index() + assert str(excinfo.value) == "Could not activate keep_alive!" def test_keep_alive_return_value(capture): @@ -49,10 +67,13 @@ def test_keep_alive_return_value(capture): with capture: p.returnChild() assert ConstructorStats.detail_reg_inst() == n_inst + 1 - assert capture == """ + assert ( + capture + == """ Allocating child. Releasing child. """ + ) with capture: del p assert ConstructorStats.detail_reg_inst() == n_inst @@ -68,10 +89,30 @@ def test_keep_alive_return_value(capture): with capture: del p assert ConstructorStats.detail_reg_inst() == n_inst - assert capture == """ + assert ( + capture + == """ Releasing parent. Releasing child. """ + ) + + p = m.Parent() + assert ConstructorStats.detail_reg_inst() == n_inst + 1 + with capture: + m.Parent.staticFunction(p) + assert ConstructorStats.detail_reg_inst() == n_inst + 2 + assert capture == "Allocating child." + with capture: + del p + assert ConstructorStats.detail_reg_inst() == n_inst + assert ( + capture + == """ + Releasing parent. + Releasing child. + """ + ) # https://foss.heptapod.net/pypy/pypy/-/issues/2447 @@ -82,14 +123,17 @@ def test_alive_gc(capture): p.addChildKeepAlive(m.Child()) assert ConstructorStats.detail_reg_inst() == n_inst + 2 lst = [p] - lst.append(lst) # creates a circular reference + lst.append(lst) # creates a circular reference with capture: del p, lst assert ConstructorStats.detail_reg_inst() == n_inst - assert capture == """ + assert ( + capture + == """ Releasing parent. Releasing child. """ + ) def test_alive_gc_derived(capture): @@ -101,14 +145,17 @@ def test_alive_gc_derived(capture): p.addChildKeepAlive(m.Child()) assert ConstructorStats.detail_reg_inst() == n_inst + 2 lst = [p] - lst.append(lst) # creates a circular reference + lst.append(lst) # creates a circular reference with capture: del p, lst assert ConstructorStats.detail_reg_inst() == n_inst - assert capture == """ + assert ( + capture + == """ Releasing parent. Releasing child. """ + ) def test_alive_gc_multi_derived(capture): @@ -123,15 +170,18 @@ def test_alive_gc_multi_derived(capture): # +3 rather than +2 because Derived corresponds to two registered instances assert ConstructorStats.detail_reg_inst() == n_inst + 3 lst = [p] - lst.append(lst) # creates a circular reference + lst.append(lst) # creates a circular reference with capture: del p, lst assert ConstructorStats.detail_reg_inst() == n_inst - assert capture == """ + assert ( + capture + == """ Releasing parent. Releasing child. Releasing child. """ + ) def test_return_none(capture): @@ -167,17 +217,23 @@ def test_keep_alive_constructor(capture): with capture: p = m.Parent(m.Child()) assert ConstructorStats.detail_reg_inst() == n_inst + 2 - assert capture == """ + assert ( + capture + == """ Allocating child. Allocating parent. """ + ) with capture: del p assert ConstructorStats.detail_reg_inst() == n_inst - assert capture == """ + assert ( + capture + == """ Releasing parent. Releasing child. """ + ) def test_call_guard(): diff --git a/wrap/pybind11/tests/test_callbacks.cpp b/wrap/pybind11/tests/test_callbacks.cpp index 71b88c44c..58688b6e8 100644 --- a/wrap/pybind11/tests/test_callbacks.cpp +++ b/wrap/pybind11/tests/test_callbacks.cpp @@ -17,8 +17,8 @@ int dummy_function(int i) { return i + 1; } TEST_SUBMODULE(callbacks, m) { // test_callbacks, test_function_signatures - m.def("test_callback1", [](py::object func) { return func(); }); - m.def("test_callback2", [](py::object func) { return func("Hello", 'x', true, 5); }); + m.def("test_callback1", [](const py::object &func) { return func(); }); + m.def("test_callback2", [](const py::object &func) { return func("Hello", 'x', true, 5); }); m.def("test_callback3", [](const std::function &func) { return "func(43) = " + std::to_string(func(43)); }); m.def("test_callback4", []() -> std::function { return [](int i) { return i+1; }; }); @@ -27,51 +27,48 @@ TEST_SUBMODULE(callbacks, m) { }); // test_keyword_args_and_generalized_unpacking - m.def("test_tuple_unpacking", [](py::function f) { + m.def("test_tuple_unpacking", [](const py::function &f) { auto t1 = py::make_tuple(2, 3); auto t2 = py::make_tuple(5, 6); return f("positional", 1, *t1, 4, *t2); }); - m.def("test_dict_unpacking", [](py::function f) { + m.def("test_dict_unpacking", [](const py::function &f) { auto d1 = py::dict("key"_a="value", "a"_a=1); auto d2 = py::dict(); auto d3 = py::dict("b"_a=2); return f("positional", 1, **d1, **d2, **d3); }); - m.def("test_keyword_args", [](py::function f) { - return f("x"_a=10, "y"_a=20); - }); + m.def("test_keyword_args", [](const py::function &f) { return f("x"_a = 10, "y"_a = 20); }); - m.def("test_unpacking_and_keywords1", [](py::function f) { + m.def("test_unpacking_and_keywords1", [](const py::function &f) { auto args = py::make_tuple(2); auto kwargs = py::dict("d"_a=4); return f(1, *args, "c"_a=3, **kwargs); }); - m.def("test_unpacking_and_keywords2", [](py::function f) { + m.def("test_unpacking_and_keywords2", [](const py::function &f) { auto kwargs1 = py::dict("a"_a=1); auto kwargs2 = py::dict("c"_a=3, "d"_a=4); return f("positional", *py::make_tuple(1), 2, *py::make_tuple(3, 4), 5, "key"_a="value", **kwargs1, "b"_a=2, **kwargs2, "e"_a=5); }); - m.def("test_unpacking_error1", [](py::function f) { + m.def("test_unpacking_error1", [](const py::function &f) { auto kwargs = py::dict("x"_a=3); return f("x"_a=1, "y"_a=2, **kwargs); // duplicate ** after keyword }); - m.def("test_unpacking_error2", [](py::function f) { + m.def("test_unpacking_error2", [](const py::function &f) { auto kwargs = py::dict("x"_a=3); return f(**kwargs, "x"_a=1); // duplicate keyword after ** }); - m.def("test_arg_conversion_error1", [](py::function f) { - f(234, UnregisteredType(), "kw"_a=567); - }); + m.def("test_arg_conversion_error1", + [](const py::function &f) { f(234, UnregisteredType(), "kw"_a = 567); }); - m.def("test_arg_conversion_error2", [](py::function f) { + m.def("test_arg_conversion_error2", [](const py::function &f) { f(234, "expected_name"_a=UnregisteredType(), "kw"_a=567); }); @@ -80,23 +77,64 @@ TEST_SUBMODULE(callbacks, m) { Payload() { print_default_created(this); } ~Payload() { print_destroyed(this); } Payload(const Payload &) { print_copy_created(this); } - Payload(Payload &&) { print_move_created(this); } + Payload(Payload &&) noexcept { print_move_created(this); } }; // Export the payload constructor statistics for testing purposes: m.def("payload_cstats", &ConstructorStats::get); - /* Test cleanup of lambda closure */ - m.def("test_cleanup", []() -> std::function { + m.def("test_lambda_closure_cleanup", []() -> std::function { Payload p; + // In this situation, `Func` in the implementation of + // `cpp_function::initialize` is NOT trivially destructible. return [p]() { /* p should be cleaned up when the returned function is garbage collected */ (void) p; }; }); + class CppCallable { + public: + CppCallable() { track_default_created(this); } + ~CppCallable() { track_destroyed(this); } + CppCallable(const CppCallable &) { track_copy_created(this); } + CppCallable(CppCallable &&) noexcept { track_move_created(this); } + void operator()() {} + }; + + m.def("test_cpp_callable_cleanup", []() { + // Related issue: https://github.com/pybind/pybind11/issues/3228 + // Related PR: https://github.com/pybind/pybind11/pull/3229 + py::list alive_counts; + ConstructorStats &stat = ConstructorStats::get(); + alive_counts.append(stat.alive()); + { + CppCallable cpp_callable; + alive_counts.append(stat.alive()); + { + // In this situation, `Func` in the implementation of + // `cpp_function::initialize` IS trivially destructible, + // only `capture` is not. + py::cpp_function py_func(cpp_callable); + py::detail::silence_unused_warnings(py_func); + alive_counts.append(stat.alive()); + } + alive_counts.append(stat.alive()); + { + py::cpp_function py_func(std::move(cpp_callable)); + py::detail::silence_unused_warnings(py_func); + alive_counts.append(stat.alive()); + } + alive_counts.append(stat.alive()); + } + alive_counts.append(stat.alive()); + return alive_counts; + }); + // test_cpp_function_roundtrip /* Test if passing a function pointer from C++ -> Python -> C++ yields the original pointer */ m.def("dummy_function", &dummy_function); + m.def("dummy_function_overloaded", [](int i, int j) { return i + j; }); + m.def("dummy_function_overloaded", &dummy_function); m.def("dummy_function2", [](int i, int j) { return i + j; }); m.def("roundtrip", [](std::function f, bool expect_none = false) { if (expect_none && f) @@ -109,16 +147,25 @@ TEST_SUBMODULE(callbacks, m) { if (!result) { auto r = f(1); return "can't convert to function pointer: eval(1) = " + std::to_string(r); - } else if (*result == dummy_function) { + } + if (*result == dummy_function) { auto r = (*result)(1); return "matches dummy_function: eval(1) = " + std::to_string(r); - } else { - return "argument does NOT match dummy_function. This should never happen!"; } + return "argument does NOT match dummy_function. This should never happen!"; + }); - class AbstractBase { public: virtual unsigned int func() = 0; }; - m.def("func_accepting_func_accepting_base", [](std::function) { }); + class AbstractBase { + public: + // [workaround(intel)] = default does not work here + // Defaulting this destructor results in linking errors with the Intel compiler + // (in Debug builds only, tested with icpc (ICC) 2021.1 Beta 20200827) + virtual ~AbstractBase() {} // NOLINT(modernize-use-equals-default) + virtual unsigned int func() = 0; + }; + m.def("func_accepting_func_accepting_base", + [](const std::function &) {}); struct MovableObject { bool valid = true; @@ -126,8 +173,8 @@ TEST_SUBMODULE(callbacks, m) { MovableObject() = default; MovableObject(const MovableObject &) = default; MovableObject &operator=(const MovableObject &) = default; - MovableObject(MovableObject &&o) : valid(o.valid) { o.valid = false; } - MovableObject &operator=(MovableObject &&o) { + MovableObject(MovableObject &&o) noexcept : valid(o.valid) { o.valid = false; } + MovableObject &operator=(MovableObject &&o) noexcept { valid = o.valid; o.valid = false; return *this; @@ -136,7 +183,7 @@ TEST_SUBMODULE(callbacks, m) { py::class_(m, "MovableObject"); // test_movable_object - m.def("callback_with_movable", [](std::function f) { + m.def("callback_with_movable", [](const std::function &f) { auto x = MovableObject(); f(x); // lvalue reference shouldn't move out object return x.valid; // must still return `true` @@ -148,9 +195,15 @@ TEST_SUBMODULE(callbacks, m) { .def(py::init<>()) .def("triple", [](CppBoundMethodTest &, int val) { return 3 * val; }); + // This checks that builtin functions can be passed as callbacks + // rather than throwing RuntimeError due to trying to extract as capsule + m.def("test_sum_builtin", [](const std::function &sum_builtin, const py::iterable &i) { + return sum_builtin(i); + }); + // test async Python callbacks using callback_f = std::function; - m.def("test_async_callback", [](callback_f f, py::list work) { + m.def("test_async_callback", [](const callback_f &f, const py::list &work) { // make detached thread that calls `f` with piece of work after a little delay auto start_f = [f](int j) { auto invoke_f = [f, j] { @@ -165,4 +218,10 @@ TEST_SUBMODULE(callbacks, m) { for (auto i : work) start_f(py::cast(i)); }); + + m.def("callback_num_times", [](const py::function &f, std::size_t num) { + for (std::size_t i = 0; i < num; i++) { + f(); + } + }); } diff --git a/wrap/pybind11/tests/test_callbacks.py b/wrap/pybind11/tests/test_callbacks.py index d5d0e045d..f41ad86e7 100644 --- a/wrap/pybind11/tests/test_callbacks.py +++ b/wrap/pybind11/tests/test_callbacks.py @@ -1,8 +1,12 @@ # -*- coding: utf-8 -*- -import pytest -from pybind11_tests import callbacks as m +import time from threading import Thread +import pytest + +import env # noqa: F401 +from pybind11_tests import callbacks as m + def test_callbacks(): from functools import partial @@ -42,17 +46,19 @@ def test_bound_method_callback(): def test_keyword_args_and_generalized_unpacking(): - def f(*args, **kwargs): return args, kwargs assert m.test_tuple_unpacking(f) == (("positional", 1, 2, 3, 4, 5, 6), {}) - assert m.test_dict_unpacking(f) == (("positional", 1), {"key": "value", "a": 1, "b": 2}) + assert m.test_dict_unpacking(f) == ( + ("positional", 1), + {"key": "value", "a": 1, "b": 2}, + ) assert m.test_keyword_args(f) == ((), {"x": 10, "y": 20}) assert m.test_unpacking_and_keywords1(f) == ((1, 2), {"c": 3, "d": 4}) assert m.test_unpacking_and_keywords2(f) == ( ("positional", 1, 2, 3, 4, 5), - {"key": "value", "a": 1, "b": 2, "c": 3, "d": 4, "e": 5} + {"key": "value", "a": 1, "b": 2, "c": 3, "d": 4, "e": 5}, ) with pytest.raises(TypeError) as excinfo: @@ -73,22 +79,37 @@ def test_keyword_args_and_generalized_unpacking(): def test_lambda_closure_cleanup(): - m.test_cleanup() + m.test_lambda_closure_cleanup() cstats = m.payload_cstats() assert cstats.alive() == 0 assert cstats.copy_constructions == 1 assert cstats.move_constructions >= 1 +def test_cpp_callable_cleanup(): + alive_counts = m.test_cpp_callable_cleanup() + assert alive_counts == [0, 1, 2, 1, 2, 1, 0] + + def test_cpp_function_roundtrip(): """Test if passing a function pointer from C++ -> Python -> C++ yields the original pointer""" - assert m.test_dummy_function(m.dummy_function) == "matches dummy_function: eval(1) = 2" - assert (m.test_dummy_function(m.roundtrip(m.dummy_function)) == - "matches dummy_function: eval(1) = 2") + assert ( + m.test_dummy_function(m.dummy_function) == "matches dummy_function: eval(1) = 2" + ) + assert ( + m.test_dummy_function(m.roundtrip(m.dummy_function)) + == "matches dummy_function: eval(1) = 2" + ) + assert ( + m.test_dummy_function(m.dummy_function_overloaded) + == "matches dummy_function: eval(1) = 2" + ) assert m.roundtrip(None, expect_none=True) is None - assert (m.test_dummy_function(lambda x: x + 2) == - "can't convert to function pointer: eval(1) = 3") + assert ( + m.test_dummy_function(lambda x: x + 2) + == "can't convert to function pointer: eval(1) = 3" + ) with pytest.raises(TypeError) as excinfo: m.test_dummy_function(m.dummy_function2) @@ -96,8 +117,10 @@ def test_cpp_function_roundtrip(): with pytest.raises(TypeError) as excinfo: m.test_dummy_function(lambda x, y: x + y) - assert any(s in str(excinfo.value) for s in ("missing 1 required positional argument", - "takes exactly 2 arguments")) + assert any( + s in str(excinfo.value) + for s in ("missing 1 required positional argument", "takes exactly 2 arguments") + ) def test_function_signatures(doc): @@ -109,6 +132,16 @@ def test_movable_object(): assert m.callback_with_movable(lambda _: None) is True +@pytest.mark.skipif( + "env.PYPY", + reason="PyPy segfaults on here. See discussion on #1413.", +) +def test_python_builtins(): + """Test if python builtins like sum() can be used as callbacks""" + assert m.test_sum_builtin(sum, [1, 2, 3]) == 6 + assert m.test_sum_builtin(sum, []) == 0 + + def test_async_callbacks(): # serves as state for async callback class Item: @@ -127,11 +160,43 @@ def test_async_callbacks(): m.test_async_callback(gen_f(), work) # wait until work is done from time import sleep + sleep(0.5) - assert sum(res) == sum([x + 3 for x in work]) + assert sum(res) == sum(x + 3 for x in work) def test_async_async_callbacks(): t = Thread(target=test_async_callbacks) t.start() t.join() + + +def test_callback_num_times(): + # Super-simple micro-benchmarking related to PR #2919. + # Example runtimes (Intel Xeon 2.2GHz, fully optimized): + # num_millions 1, repeats 2: 0.1 secs + # num_millions 20, repeats 10: 11.5 secs + one_million = 1000000 + num_millions = 1 # Try 20 for actual micro-benchmarking. + repeats = 2 # Try 10. + rates = [] + for rep in range(repeats): + t0 = time.time() + m.callback_num_times(lambda: None, num_millions * one_million) + td = time.time() - t0 + rate = num_millions / td if td else 0 + rates.append(rate) + if not rep: + print() + print( + "callback_num_times: {:d} million / {:.3f} seconds = {:.3f} million / second".format( + num_millions, td, rate + ) + ) + if len(rates) > 1: + print("Min Mean Max") + print( + "{:6.3f} {:6.3f} {:6.3f}".format( + min(rates), sum(rates) / len(rates), max(rates) + ) + ) diff --git a/wrap/pybind11/tests/test_chrono.py b/wrap/pybind11/tests/test_chrono.py index ae24b7dda..fdd73d690 100644 --- a/wrap/pybind11/tests/test_chrono.py +++ b/wrap/pybind11/tests/test_chrono.py @@ -1,9 +1,10 @@ # -*- coding: utf-8 -*- -from pybind11_tests import chrono as m import datetime + import pytest import env # noqa: F401 +from pybind11_tests import chrono as m def test_chrono_system_clock(): @@ -39,9 +40,7 @@ def test_chrono_system_clock_roundtrip(): # They should be identical (no information lost on roundtrip) diff = abs(date1 - date2) - assert diff.days == 0 - assert diff.seconds == 0 - assert diff.microseconds == 0 + assert diff == datetime.timedelta(0) def test_chrono_system_clock_roundtrip_date(): @@ -64,9 +63,7 @@ def test_chrono_system_clock_roundtrip_date(): assert diff.microseconds == 0 # Year, Month & Day should be the same after the round trip - assert date1.year == date2.year - assert date1.month == date2.month - assert date1.day == date2.day + assert date1 == date2 # There should be no time information assert time2.hour == 0 @@ -80,22 +77,28 @@ SKIP_TZ_ENV_ON_WIN = pytest.mark.skipif( ) -@pytest.mark.parametrize("time1", [ - datetime.datetime.today().time(), - datetime.time(0, 0, 0), - datetime.time(0, 0, 0, 1), - datetime.time(0, 28, 45, 109827), - datetime.time(0, 59, 59, 999999), - datetime.time(1, 0, 0), - datetime.time(5, 59, 59, 0), - datetime.time(5, 59, 59, 1), -]) -@pytest.mark.parametrize("tz", [ - None, - pytest.param("Europe/Brussels", marks=SKIP_TZ_ENV_ON_WIN), - pytest.param("Asia/Pyongyang", marks=SKIP_TZ_ENV_ON_WIN), - pytest.param("America/New_York", marks=SKIP_TZ_ENV_ON_WIN), -]) +@pytest.mark.parametrize( + "time1", + [ + datetime.datetime.today().time(), + datetime.time(0, 0, 0), + datetime.time(0, 0, 0, 1), + datetime.time(0, 28, 45, 109827), + datetime.time(0, 59, 59, 999999), + datetime.time(1, 0, 0), + datetime.time(5, 59, 59, 0), + datetime.time(5, 59, 59, 1), + ], +) +@pytest.mark.parametrize( + "tz", + [ + None, + pytest.param("Europe/Brussels", marks=SKIP_TZ_ENV_ON_WIN), + pytest.param("Asia/Pyongyang", marks=SKIP_TZ_ENV_ON_WIN), + pytest.param("America/New_York", marks=SKIP_TZ_ENV_ON_WIN), + ], +) def test_chrono_system_clock_roundtrip_time(time1, tz, monkeypatch): if tz is not None: monkeypatch.setenv("TZ", "/usr/share/zoneinfo/{}".format(tz)) @@ -111,10 +114,7 @@ def test_chrono_system_clock_roundtrip_time(time1, tz, monkeypatch): assert isinstance(time2, datetime.time) # Hour, Minute, Second & Microsecond should be the same after the round trip - assert time1.hour == time2.hour - assert time1.minute == time2.minute - assert time1.second == time2.second - assert time1.microsecond == time2.microsecond + assert time1 == time2 # There should be no date information (i.e. date = python base date) assert date2.year == 1970 @@ -134,9 +134,13 @@ def test_chrono_duration_roundtrip(): cpp_diff = m.test_chrono3(diff) - assert cpp_diff.days == diff.days - assert cpp_diff.seconds == diff.seconds - assert cpp_diff.microseconds == diff.microseconds + assert cpp_diff == diff + + # Negative timedelta roundtrip + diff = datetime.timedelta(microseconds=-1) + cpp_diff = m.test_chrono3(diff) + + assert cpp_diff == diff def test_chrono_duration_subtraction_equivalence(): @@ -147,9 +151,7 @@ def test_chrono_duration_subtraction_equivalence(): diff = date2 - date1 cpp_diff = m.test_chrono4(date2, date1) - assert cpp_diff.days == diff.days - assert cpp_diff.seconds == diff.seconds - assert cpp_diff.microseconds == diff.microseconds + assert cpp_diff == diff def test_chrono_duration_subtraction_equivalence_date(): @@ -160,9 +162,7 @@ def test_chrono_duration_subtraction_equivalence_date(): diff = date2 - date1 cpp_diff = m.test_chrono4(date2, date1) - assert cpp_diff.days == diff.days - assert cpp_diff.seconds == diff.seconds - assert cpp_diff.microseconds == diff.microseconds + assert cpp_diff == diff def test_chrono_steady_clock(): @@ -177,9 +177,7 @@ def test_chrono_steady_clock_roundtrip(): assert isinstance(time2, datetime.timedelta) # They should be identical (no information lost on roundtrip) - assert time1.days == time2.days - assert time1.seconds == time2.seconds - assert time1.microseconds == time2.microseconds + assert time1 == time2 def test_floating_point_duration(): @@ -199,7 +197,7 @@ def test_floating_point_duration(): def test_nano_timepoint(): time = datetime.datetime.now() time1 = m.test_nano_timepoint(time, datetime.timedelta(seconds=60)) - assert(time1 == time + datetime.timedelta(seconds=60)) + assert time1 == time + datetime.timedelta(seconds=60) def test_chrono_different_resolutions(): diff --git a/wrap/pybind11/tests/test_class.cpp b/wrap/pybind11/tests/test_class.cpp index b0e3d3a4b..52a41a3bc 100644 --- a/wrap/pybind11/tests/test_class.cpp +++ b/wrap/pybind11/tests/test_class.cpp @@ -7,18 +7,27 @@ BSD-style license that can be found in the LICENSE file. */ +#if defined(__INTEL_COMPILER) && __cplusplus >= 201703L +// Intel compiler requires a separate header file to support aligned new operators +// and does not set the __cpp_aligned_new feature macro. +// This header needs to be included before pybind11. +#include +#endif + #include "pybind11_tests.h" #include "constructor_stats.h" #include "local_bindings.h" #include +#include + #if defined(_MSC_VER) # pragma warning(disable: 4324) // warning C4324: structure was padded due to alignment specifier #endif // test_brace_initialization struct NoBraceInitialization { - NoBraceInitialization(std::vector v) : vec{std::move(v)} {} + explicit NoBraceInitialization(std::vector v) : vec{std::move(v)} {} template NoBraceInitialization(std::initializer_list l) : vec(l) {} @@ -38,10 +47,26 @@ TEST_SUBMODULE(class_, m) { } ~NoConstructor() { print_destroyed(this); } }; + struct NoConstructorNew { + NoConstructorNew() = default; + NoConstructorNew(const NoConstructorNew &) = default; + NoConstructorNew(NoConstructorNew &&) = default; + static NoConstructorNew *new_instance() { + auto *ptr = new NoConstructorNew(); + print_created(ptr, "via new_instance"); + return ptr; + } + ~NoConstructorNew() { print_destroyed(this); } + }; py::class_(m, "NoConstructor") .def_static("new_instance", &NoConstructor::new_instance, "Return an instance"); + py::class_(m, "NoConstructorNew") + .def(py::init([](const NoConstructorNew &self) { return self; })) // Need a NOOP __init__ + .def_static("__new__", + [](const py::object &) { return NoConstructorNew::new_instance(); }); + // test_inheritance class Pet { public: @@ -56,18 +81,18 @@ TEST_SUBMODULE(class_, m) { class Dog : public Pet { public: - Dog(const std::string &name) : Pet(name, "dog") {} + explicit Dog(const std::string &name) : Pet(name, "dog") {} std::string bark() const { return "Woof!"; } }; class Rabbit : public Pet { public: - Rabbit(const std::string &name) : Pet(name, "parrot") {} + explicit Rabbit(const std::string &name) : Pet(name, "parrot") {} }; class Hamster : public Pet { public: - Hamster(const std::string &name) : Pet(name, "rodent") {} + explicit Hamster(const std::string &name) : Pet(name, "rodent") {} }; class Chimera : public Pet { @@ -122,7 +147,7 @@ TEST_SUBMODULE(class_, m) { m.def("return_none", []() -> BaseClass* { return nullptr; }); // test_isinstance - m.def("check_instances", [](py::list l) { + m.def("check_instances", [](const py::list &l) { return py::make_tuple( py::isinstance(l[0]), py::isinstance(l[1]), @@ -144,21 +169,16 @@ TEST_SUBMODULE(class_, m) { // return py::type::of(); if (category == 1) return py::type::of(); - else - return py::type::of(); + return py::type::of(); }); - m.def("get_type_of", [](py::object ob) { - return py::type::of(ob); + m.def("get_type_of", [](py::object ob) { return py::type::of(std::move(ob)); }); + + m.def("get_type_classic", [](py::handle h) { + return h.get_type(); }); - m.def("as_type", [](py::object ob) { - auto tp = py::type(ob); - if (py::isinstance(ob)) - return tp; - else - throw std::runtime_error("Invalid type"); - }); + m.def("as_type", [](const py::object &ob) { return py::type(ob); }); // test_mismatched_holder struct MismatchBase1 { }; @@ -168,12 +188,12 @@ TEST_SUBMODULE(class_, m) { struct MismatchDerived2 : MismatchBase2 { }; m.def("mismatched_holder_1", []() { - auto mod = py::module::import("__main__"); + auto mod = py::module_::import("__main__"); py::class_>(mod, "MismatchBase1"); py::class_(mod, "MismatchDerived1"); }); m.def("mismatched_holder_2", []() { - auto mod = py::module::import("__main__"); + auto mod = py::module_::import("__main__"); py::class_(mod, "MismatchBase2"); py::class_, MismatchBase2>(mod, "MismatchDerived2"); @@ -204,7 +224,7 @@ TEST_SUBMODULE(class_, m) { struct ConvertibleFromUserType { int i; - ConvertibleFromUserType(UserType u) : i(u.value()) { } + explicit ConvertibleFromUserType(UserType u) : i(u.value()) {} }; py::class_(m, "AcceptsUserType") @@ -212,7 +232,7 @@ TEST_SUBMODULE(class_, m) { py::implicitly_convertible(); m.def("implicitly_convert_argument", [](const ConvertibleFromUserType &r) { return r.i; }); - m.def("implicitly_convert_variable", [](py::object o) { + m.def("implicitly_convert_variable", [](const py::object &o) { // `o` is `UserType` and `r` is a reference to a temporary created by implicit // conversion. This is valid when called inside a bound function because the temp // object is attached to the same life support system as the arguments. @@ -231,7 +251,8 @@ TEST_SUBMODULE(class_, m) { }; auto def = new PyMethodDef{"f", f, METH_VARARGS, nullptr}; - return py::reinterpret_steal(PyCFunction_NewEx(def, nullptr, m.ptr())); + py::capsule def_capsule(def, [](void *ptr) { delete reinterpret_cast(ptr); }); + return py::reinterpret_steal(PyCFunction_NewEx(def, def_capsule.ptr(), m.ptr())); }()); // test_operator_new_delete @@ -258,7 +279,7 @@ TEST_SUBMODULE(class_, m) { }; struct PyAliasedHasOpNewDelSize : AliasedHasOpNewDelSize { PyAliasedHasOpNewDelSize() = default; - PyAliasedHasOpNewDelSize(int) { } + explicit PyAliasedHasOpNewDelSize(int) {} std::uint64_t j; }; struct HasOpNewDelBoth { @@ -322,6 +343,10 @@ TEST_SUBMODULE(class_, m) { class PublicistB : public ProtectedB { public: + // [workaround(intel)] = default does not work here + // Removing or defaulting this destructor results in linking errors with the Intel compiler + // (in Debug builds only, tested with icpc (ICC) 2021.1 Beta 20200827) + ~PublicistB() override {}; // NOLINT(modernize-use-equals-default) using ProtectedB::foo; }; @@ -385,7 +410,7 @@ TEST_SUBMODULE(class_, m) { struct StringWrapper { std::string str; }; m.def("test_error_after_conversions", [](int) {}); m.def("test_error_after_conversions", - [](StringWrapper) -> NotRegistered { return {}; }); + [](const StringWrapper &) -> NotRegistered { return {}; }); py::class_(m, "StringWrapper").def(py::init()); py::implicitly_convertible(); @@ -406,6 +431,7 @@ TEST_SUBMODULE(class_, m) { struct IsNonFinalFinal {}; py::class_(m, "IsNonFinalFinal", py::is_final()); + // test_exception_rvalue_abort struct PyPrintDestructor { PyPrintDestructor() = default; ~PyPrintDestructor() { @@ -416,6 +442,55 @@ TEST_SUBMODULE(class_, m) { py::class_(m, "PyPrintDestructor") .def(py::init<>()) .def("throw_something", &PyPrintDestructor::throw_something); + + // test_multiple_instances_with_same_pointer + struct SamePointer {}; + static SamePointer samePointer; + py::class_>(m, "SamePointer") + .def(py::init([]() { return &samePointer; })); + + struct Empty {}; + py::class_(m, "Empty") + .def(py::init<>()); + + // test_base_and_derived_nested_scope + struct BaseWithNested { + struct Nested {}; + }; + + struct DerivedWithNested : BaseWithNested { + struct Nested {}; + }; + + py::class_ baseWithNested_class(m, "BaseWithNested"); + py::class_ derivedWithNested_class(m, "DerivedWithNested"); + py::class_(baseWithNested_class, "Nested") + .def_static("get_name", []() { return "BaseWithNested::Nested"; }); + py::class_(derivedWithNested_class, "Nested") + .def_static("get_name", []() { return "DerivedWithNested::Nested"; }); + + // test_register_duplicate_class + struct Duplicate {}; + struct OtherDuplicate {}; + struct DuplicateNested {}; + struct OtherDuplicateNested {}; + + m.def("register_duplicate_class_name", [](const py::module_ &m) { + py::class_(m, "Duplicate"); + py::class_(m, "Duplicate"); + }); + m.def("register_duplicate_class_type", [](const py::module_ &m) { + py::class_(m, "OtherDuplicate"); + py::class_(m, "YetAnotherDuplicate"); + }); + m.def("register_duplicate_nested_class_name", [](const py::object >) { + py::class_(gt, "DuplicateNested"); + py::class_(gt, "DuplicateNested"); + }); + m.def("register_duplicate_nested_class_type", [](const py::object >) { + py::class_(gt, "OtherDuplicateNested"); + py::class_(gt, "YetAnotherDuplicateNested"); + }); } template class BreaksBase { public: @@ -433,15 +508,15 @@ using DoesntBreak5 = py::class_>; using DoesntBreak6 = py::class_, std::shared_ptr>, BreaksTramp<6>>; using DoesntBreak7 = py::class_, BreaksTramp<7>, std::shared_ptr>>; using DoesntBreak8 = py::class_, std::shared_ptr>>; -#define CHECK_BASE(N) static_assert(std::is_same>::value, \ +#define CHECK_BASE(N) static_assert(std::is_same>::value, \ "DoesntBreak" #N " has wrong type!") CHECK_BASE(1); CHECK_BASE(2); CHECK_BASE(3); CHECK_BASE(4); CHECK_BASE(5); CHECK_BASE(6); CHECK_BASE(7); CHECK_BASE(8); -#define CHECK_ALIAS(N) static_assert(DoesntBreak##N::has_alias && std::is_same>::value, \ +#define CHECK_ALIAS(N) static_assert(DoesntBreak##N::has_alias && std::is_same>::value, \ "DoesntBreak" #N " has wrong type_alias!") #define CHECK_NOALIAS(N) static_assert(!DoesntBreak##N::has_alias && std::is_void::value, \ "DoesntBreak" #N " has type alias, but shouldn't!") CHECK_ALIAS(1); CHECK_ALIAS(2); CHECK_NOALIAS(3); CHECK_ALIAS(4); CHECK_NOALIAS(5); CHECK_ALIAS(6); CHECK_ALIAS(7); CHECK_NOALIAS(8); -#define CHECK_HOLDER(N, TYPE) static_assert(std::is_same>>::value, \ +#define CHECK_HOLDER(N, TYPE) static_assert(std::is_same>>::value, \ "DoesntBreak" #N " has wrong holder_type!") CHECK_HOLDER(1, unique); CHECK_HOLDER(2, unique); CHECK_HOLDER(3, unique); CHECK_HOLDER(4, unique); CHECK_HOLDER(5, unique); CHECK_HOLDER(6, shared); CHECK_HOLDER(7, shared); CHECK_HOLDER(8, shared); @@ -451,7 +526,7 @@ CHECK_HOLDER(6, shared); CHECK_HOLDER(7, shared); CHECK_HOLDER(8, shared); // failures occurs). // We have to actually look into the type: the typedef alone isn't enough to instantiate the type: -#define CHECK_BROKEN(N) static_assert(std::is_same>::value, \ +#define CHECK_BROKEN(N) static_assert(std::is_same>::value, \ "Breaks1 has wrong type!"); //// Two holder classes: diff --git a/wrap/pybind11/tests/test_class.py b/wrap/pybind11/tests/test_class.py index be21f3709..caafe2068 100644 --- a/wrap/pybind11/tests/test_class.py +++ b/wrap/pybind11/tests/test_class.py @@ -2,9 +2,8 @@ import pytest import env # noqa: F401 - +from pybind11_tests import ConstructorStats, UserType from pybind11_tests import class_ as m -from pybind11_tests import UserType, ConstructorStats def test_repr(): @@ -26,13 +25,23 @@ def test_instance(msg): assert cstats.alive() == 0 +def test_instance_new(msg): + instance = m.NoConstructorNew() # .__new__(m.NoConstructor.__class__) + cstats = ConstructorStats.get(m.NoConstructorNew) + assert cstats.alive() == 1 + del instance + assert cstats.alive() == 0 + + def test_type(): assert m.check_type(1) == m.DerivedClass1 with pytest.raises(RuntimeError) as execinfo: m.check_type(0) - assert 'pybind11::detail::get_type_info: unable to find type info' in str(execinfo.value) - assert 'Invalid' in str(execinfo.value) + assert "pybind11::detail::get_type_info: unable to find type info" in str( + execinfo.value + ) + assert "Invalid" in str(execinfo.value) # Currently not supported # See https://github.com/pybind/pybind11/issues/2486 @@ -45,6 +54,12 @@ def test_type_of_py(): assert m.get_type_of(int) == type +def test_type_of_classic(): + assert m.get_type_classic(1) == int + assert m.get_type_classic(m.DerivedClass1()) == m.DerivedClass1 + assert m.get_type_classic(int) == type + + def test_type_of_py_nodelete(): # If the above test deleted the class, this will segfault assert m.get_type_of(m.DerivedClass1()) == m.DerivedClass1 @@ -53,10 +68,10 @@ def test_type_of_py_nodelete(): def test_as_type_py(): assert m.as_type(int) == int - with pytest.raises(RuntimeError): + with pytest.raises(TypeError): assert m.as_type(1) == int - with pytest.raises(RuntimeError): + with pytest.raises(TypeError): assert m.as_type(m.DerivedClass1()) == m.DerivedClass1 @@ -67,18 +82,24 @@ def test_docstrings(doc): assert UserType.get_value.__name__ == "get_value" assert UserType.get_value.__module__ == "pybind11_tests" - assert doc(UserType.get_value) == """ + assert ( + doc(UserType.get_value) + == """ get_value(self: m.UserType) -> int Get value using a method """ + ) assert doc(UserType.value) == "Get/set value using a property" - assert doc(m.NoConstructor.new_instance) == """ + assert ( + doc(m.NoConstructor.new_instance) + == """ new_instance() -> m.class_.NoConstructor Return an instance """ + ) def test_qualname(doc): @@ -87,51 +108,69 @@ def test_qualname(doc): assert m.NestBase.__qualname__ == "NestBase" assert m.NestBase.Nested.__qualname__ == "NestBase.Nested" - assert doc(m.NestBase.__init__) == """ + assert ( + doc(m.NestBase.__init__) + == """ __init__(self: m.class_.NestBase) -> None """ - assert doc(m.NestBase.g) == """ + ) + assert ( + doc(m.NestBase.g) + == """ g(self: m.class_.NestBase, arg0: m.class_.NestBase.Nested) -> None """ - assert doc(m.NestBase.Nested.__init__) == """ + ) + assert ( + doc(m.NestBase.Nested.__init__) + == """ __init__(self: m.class_.NestBase.Nested) -> None """ - assert doc(m.NestBase.Nested.fn) == """ + ) + assert ( + doc(m.NestBase.Nested.fn) + == """ fn(self: m.class_.NestBase.Nested, arg0: int, arg1: m.class_.NestBase, arg2: m.class_.NestBase.Nested) -> None """ # noqa: E501 line too long - assert doc(m.NestBase.Nested.fa) == """ + ) + assert ( + doc(m.NestBase.Nested.fa) + == """ fa(self: m.class_.NestBase.Nested, a: int, b: m.class_.NestBase, c: m.class_.NestBase.Nested) -> None """ # noqa: E501 line too long + ) assert m.NestBase.__module__ == "pybind11_tests.class_" assert m.NestBase.Nested.__module__ == "pybind11_tests.class_" def test_inheritance(msg): - roger = m.Rabbit('Rabbit') + roger = m.Rabbit("Rabbit") assert roger.name() + " is a " + roger.species() == "Rabbit is a parrot" assert m.pet_name_species(roger) == "Rabbit is a parrot" - polly = m.Pet('Polly', 'parrot') + polly = m.Pet("Polly", "parrot") assert polly.name() + " is a " + polly.species() == "Polly is a parrot" assert m.pet_name_species(polly) == "Polly is a parrot" - molly = m.Dog('Molly') + molly = m.Dog("Molly") assert molly.name() + " is a " + molly.species() == "Molly is a dog" assert m.pet_name_species(molly) == "Molly is a dog" - fred = m.Hamster('Fred') + fred = m.Hamster("Fred") assert fred.name() + " is a " + fred.species() == "Fred is a rodent" assert m.dog_bark(molly) == "Woof!" with pytest.raises(TypeError) as excinfo: m.dog_bark(polly) - assert msg(excinfo.value) == """ + assert ( + msg(excinfo.value) + == """ dog_bark(): incompatible function arguments. The following argument types are supported: 1. (arg0: m.class_.Dog) -> str Invoked with: """ + ) with pytest.raises(TypeError) as excinfo: m.Chimera("lion", "goat") @@ -144,12 +183,11 @@ def test_inheritance_init(msg): class Python(m.Pet): def __init__(self): pass + with pytest.raises(TypeError) as exc_info: Python() - expected = ["m.class_.Pet.__init__() must be called when overriding __init__", - "Pet.__init__() must be called when overriding __init__"] # PyPy? - # TODO: fix PyPy error message wrt. tp_name/__qualname__? - assert msg(exc_info.value) in expected + expected = "m.class_.Pet.__init__() must be called when overriding __init__" + assert msg(exc_info.value) == expected # Multiple bases class RabbitHamster(m.Rabbit, m.Hamster): @@ -158,9 +196,8 @@ def test_inheritance_init(msg): with pytest.raises(TypeError) as exc_info: RabbitHamster() - expected = ["m.class_.Hamster.__init__() must be called when overriding __init__", - "Hamster.__init__() must be called when overriding __init__"] # PyPy - assert msg(exc_info.value) in expected + expected = "m.class_.Hamster.__init__() must be called when overriding __init__" + assert msg(exc_info.value) == expected def test_automatic_upcasting(): @@ -188,13 +225,19 @@ def test_mismatched_holder(): with pytest.raises(RuntimeError) as excinfo: m.mismatched_holder_1() - assert re.match('generic_type: type ".*MismatchDerived1" does not have a non-default ' - 'holder type while its base ".*MismatchBase1" does', str(excinfo.value)) + assert re.match( + 'generic_type: type ".*MismatchDerived1" does not have a non-default ' + 'holder type while its base ".*MismatchBase1" does', + str(excinfo.value), + ) with pytest.raises(RuntimeError) as excinfo: m.mismatched_holder_2() - assert re.match('generic_type: type ".*MismatchDerived2" has a non-default holder type ' - 'while its base ".*MismatchBase2" does not', str(excinfo.value)) + assert re.match( + 'generic_type: type ".*MismatchDerived2" has a non-default holder type ' + 'while its base ".*MismatchBase2" does not', + str(excinfo.value), + ) def test_override_static(): @@ -226,20 +269,20 @@ def test_operator_new_delete(capture): a = m.HasOpNewDel() b = m.HasOpNewDelSize() d = m.HasOpNewDelBoth() - assert capture == """ + assert ( + capture + == """ A new 8 B new 4 D new 32 """ + ) sz_alias = str(m.AliasedHasOpNewDelSize.size_alias) sz_noalias = str(m.AliasedHasOpNewDelSize.size_noalias) with capture: c = m.AliasedHasOpNewDelSize() c2 = SubAliased() - assert capture == ( - "C new " + sz_noalias + "\n" + - "C new " + sz_alias + "\n" - ) + assert capture == ("C new " + sz_noalias + "\n" + "C new " + sz_alias + "\n") with capture: del a @@ -248,21 +291,21 @@ def test_operator_new_delete(capture): pytest.gc_collect() del d pytest.gc_collect() - assert capture == """ + assert ( + capture + == """ A delete B delete 4 D delete """ + ) with capture: del c pytest.gc_collect() del c2 pytest.gc_collect() - assert capture == ( - "C delete " + sz_noalias + "\n" + - "C delete " + sz_alias + "\n" - ) + assert capture == ("C delete " + sz_noalias + "\n" + "C delete " + sz_alias + "\n") def test_bind_protected_functions(): @@ -285,7 +328,7 @@ def test_bind_protected_functions(): def test_brace_initialization(): - """ Tests that simple POD classes can be constructed using C++11 brace initialization """ + """Tests that simple POD classes can be constructed using C++11 brace initialization""" a = m.BraceInitialization(123, "test") assert a.field1 == 123 assert a.field2 == "test" @@ -322,19 +365,23 @@ def test_reentrant_implicit_conversion_failure(msg): # ensure that there is no runaway reentrant implicit conversion (#1035) with pytest.raises(TypeError) as excinfo: m.BogusImplicitConversion(0) - assert msg(excinfo.value) == ''' + assert ( + msg(excinfo.value) + == """ __init__(): incompatible constructor arguments. The following argument types are supported: 1. m.class_.BogusImplicitConversion(arg0: m.class_.BogusImplicitConversion) Invoked with: 0 - ''' + """ + ) def test_error_after_conversions(): with pytest.raises(TypeError) as exc_info: m.test_error_after_conversions("hello") assert str(exc_info.value).startswith( - "Unable to convert function return value to a Python type!") + "Unable to convert function return value to a Python type!" + ) def test_aligned(): @@ -347,8 +394,10 @@ def test_aligned(): @pytest.mark.xfail("env.PYPY") def test_final(): with pytest.raises(TypeError) as exc_info: + class PyFinalChild(m.IsFinal): pass + assert str(exc_info.value).endswith("is not an acceptable base type") @@ -356,8 +405,10 @@ def test_final(): @pytest.mark.xfail("env.PYPY") def test_non_final_final(): with pytest.raises(TypeError) as exc_info: + class PyNonFinalFinalChild(m.IsNonFinalFinal): pass + assert str(exc_info.value).endswith("is not an acceptable base type") @@ -365,3 +416,58 @@ def test_non_final_final(): def test_exception_rvalue_abort(): with pytest.raises(RuntimeError): m.PyPrintDestructor().throw_something() + + +# https://github.com/pybind/pybind11/issues/1568 +def test_multiple_instances_with_same_pointer(capture): + n = 100 + instances = [m.SamePointer() for _ in range(n)] + for i in range(n): + # We need to reuse the same allocated memory for with a different type, + # to ensure the bug in `deregister_instance_impl` is detected. Otherwise + # `Py_TYPE(self) == Py_TYPE(it->second)` will still succeed, even though + # the `instance` is already deleted. + instances[i] = m.Empty() + # No assert: if this does not trigger the error + # pybind11_fail("pybind11_object_dealloc(): Tried to deallocate unregistered instance!"); + # and just completes without crashing, we're good. + + +# https://github.com/pybind/pybind11/issues/1624 +def test_base_and_derived_nested_scope(): + assert issubclass(m.DerivedWithNested, m.BaseWithNested) + assert m.BaseWithNested.Nested != m.DerivedWithNested.Nested + assert m.BaseWithNested.Nested.get_name() == "BaseWithNested::Nested" + assert m.DerivedWithNested.Nested.get_name() == "DerivedWithNested::Nested" + + +def test_register_duplicate_class(): + import types + + module_scope = types.ModuleType("module_scope") + with pytest.raises(RuntimeError) as exc_info: + m.register_duplicate_class_name(module_scope) + expected = ( + 'generic_type: cannot initialize type "Duplicate": ' + "an object with that name is already defined" + ) + assert str(exc_info.value) == expected + with pytest.raises(RuntimeError) as exc_info: + m.register_duplicate_class_type(module_scope) + expected = 'generic_type: type "YetAnotherDuplicate" is already registered!' + assert str(exc_info.value) == expected + + class ClassScope: + pass + + with pytest.raises(RuntimeError) as exc_info: + m.register_duplicate_nested_class_name(ClassScope) + expected = ( + 'generic_type: cannot initialize type "DuplicateNested": ' + "an object with that name is already defined" + ) + assert str(exc_info.value) == expected + with pytest.raises(RuntimeError) as exc_info: + m.register_duplicate_nested_class_type(ClassScope) + expected = 'generic_type: type "YetAnotherDuplicateNested" is already registered!' + assert str(exc_info.value) == expected diff --git a/wrap/pybind11/tests/test_cmake_build/CMakeLists.txt b/wrap/pybind11/tests/test_cmake_build/CMakeLists.txt index 0c0578ad3..8bfaa386a 100644 --- a/wrap/pybind11/tests/test_cmake_build/CMakeLists.txt +++ b/wrap/pybind11/tests/test_cmake_build/CMakeLists.txt @@ -25,7 +25,7 @@ function(pybind11_add_build_test name) endif() if(NOT ARG_INSTALL) - list(APPEND build_options "-DPYBIND11_PROJECT_DIR=${pybind11_SOURCE_DIR}") + list(APPEND build_options "-Dpybind11_SOURCE_DIR=${pybind11_SOURCE_DIR}") else() list(APPEND build_options "-DCMAKE_PREFIX_PATH=${pybind11_BINARY_DIR}/mock_install") endif() @@ -55,6 +55,8 @@ function(pybind11_add_build_test name) add_dependencies(test_cmake_build test_build_${name}) endfunction() +possibly_uninitialized(PYTHON_MODULE_EXTENSION Python_INTERPRETER_ID) + pybind11_add_build_test(subdirectory_function) pybind11_add_build_test(subdirectory_target) if("${PYTHON_MODULE_EXTENSION}" MATCHES "pypy" OR "${Python_INTERPRETER_ID}" STREQUAL "PyPy") @@ -77,3 +79,6 @@ if(PYBIND11_INSTALL) endif() add_dependencies(check test_cmake_build) + +add_subdirectory(subdirectory_target EXCLUDE_FROM_ALL) +add_subdirectory(subdirectory_embed EXCLUDE_FROM_ALL) diff --git a/wrap/pybind11/tests/test_cmake_build/embed.cpp b/wrap/pybind11/tests/test_cmake_build/embed.cpp index b9581d2fd..a3abc8a84 100644 --- a/wrap/pybind11/tests/test_cmake_build/embed.cpp +++ b/wrap/pybind11/tests/test_cmake_build/embed.cpp @@ -12,10 +12,10 @@ int main(int argc, char *argv[]) { py::scoped_interpreter guard{}; - auto m = py::module::import("test_cmake_build"); + auto m = py::module_::import("test_cmake_build"); if (m.attr("add")(1, 2).cast() != 3) throw std::runtime_error("embed.cpp failed"); - py::module::import("sys").attr("argv") = py::make_tuple("test.py", "embed.cpp"); + py::module_::import("sys").attr("argv") = py::make_tuple("test.py", "embed.cpp"); py::eval_file(test_py_file, py::globals()); } diff --git a/wrap/pybind11/tests/test_cmake_build/installed_embed/CMakeLists.txt b/wrap/pybind11/tests/test_cmake_build/installed_embed/CMakeLists.txt index 64ae5c4bf..f7d693998 100644 --- a/wrap/pybind11/tests/test_cmake_build/installed_embed/CMakeLists.txt +++ b/wrap/pybind11/tests/test_cmake_build/installed_embed/CMakeLists.txt @@ -22,5 +22,7 @@ set_target_properties(test_installed_embed PROPERTIES OUTPUT_NAME test_cmake_bui # This may be needed to resolve header conflicts, e.g. between Python release and debug headers. set_target_properties(test_installed_embed PROPERTIES NO_SYSTEM_FROM_IMPORTED ON) -add_custom_target(check_installed_embed $ - ${PROJECT_SOURCE_DIR}/../test.py) +add_custom_target( + check_installed_embed + $ ${PROJECT_SOURCE_DIR}/../test.py + DEPENDS test_installed_embed) diff --git a/wrap/pybind11/tests/test_cmake_build/installed_function/CMakeLists.txt b/wrap/pybind11/tests/test_cmake_build/installed_function/CMakeLists.txt index 1a502863c..d7ca4db55 100644 --- a/wrap/pybind11/tests/test_cmake_build/installed_function/CMakeLists.txt +++ b/wrap/pybind11/tests/test_cmake_build/installed_function/CMakeLists.txt @@ -35,4 +35,5 @@ add_custom_target( PYTHONPATH=$ ${_Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/../test.py - ${PROJECT_NAME}) + ${PROJECT_NAME} + DEPENDS test_installed_function) diff --git a/wrap/pybind11/tests/test_cmake_build/installed_target/CMakeLists.txt b/wrap/pybind11/tests/test_cmake_build/installed_target/CMakeLists.txt index b38eb7747..bc5e101f1 100644 --- a/wrap/pybind11/tests/test_cmake_build/installed_target/CMakeLists.txt +++ b/wrap/pybind11/tests/test_cmake_build/installed_target/CMakeLists.txt @@ -42,4 +42,5 @@ add_custom_target( PYTHONPATH=$ ${_Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/../test.py - ${PROJECT_NAME}) + ${PROJECT_NAME} + DEPENDS test_installed_target) diff --git a/wrap/pybind11/tests/test_cmake_build/subdirectory_embed/CMakeLists.txt b/wrap/pybind11/tests/test_cmake_build/subdirectory_embed/CMakeLists.txt index c7df0cf77..58cdd7cfd 100644 --- a/wrap/pybind11/tests/test_cmake_build/subdirectory_embed/CMakeLists.txt +++ b/wrap/pybind11/tests/test_cmake_build/subdirectory_embed/CMakeLists.txt @@ -16,15 +16,17 @@ set(PYBIND11_INSTALL CACHE BOOL "") set(PYBIND11_EXPORT_NAME test_export) -add_subdirectory(${PYBIND11_PROJECT_DIR} pybind11) +add_subdirectory("${pybind11_SOURCE_DIR}" pybind11) # Test basic target functionality add_executable(test_subdirectory_embed ../embed.cpp) target_link_libraries(test_subdirectory_embed PRIVATE pybind11::embed) set_target_properties(test_subdirectory_embed PROPERTIES OUTPUT_NAME test_cmake_build) -add_custom_target(check_subdirectory_embed $ - ${PROJECT_SOURCE_DIR}/../test.py) +add_custom_target( + check_subdirectory_embed + $ "${PROJECT_SOURCE_DIR}/../test.py" + DEPENDS test_subdirectory_embed) # Test custom export group -- PYBIND11_EXPORT_NAME add_library(test_embed_lib ../embed.cpp) diff --git a/wrap/pybind11/tests/test_cmake_build/subdirectory_function/CMakeLists.txt b/wrap/pybind11/tests/test_cmake_build/subdirectory_function/CMakeLists.txt index 624c600f8..01557c439 100644 --- a/wrap/pybind11/tests/test_cmake_build/subdirectory_function/CMakeLists.txt +++ b/wrap/pybind11/tests/test_cmake_build/subdirectory_function/CMakeLists.txt @@ -11,7 +11,7 @@ endif() project(test_subdirectory_function CXX) -add_subdirectory("${PYBIND11_PROJECT_DIR}" pybind11) +add_subdirectory("${pybind11_SOURCE_DIR}" pybind11) pybind11_add_module(test_subdirectory_function ../main.cpp) set_target_properties(test_subdirectory_function PROPERTIES OUTPUT_NAME test_cmake_build) @@ -31,4 +31,5 @@ add_custom_target( PYTHONPATH=$ ${_Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/../test.py - ${PROJECT_NAME}) + ${PROJECT_NAME} + DEPENDS test_subdirectory_function) diff --git a/wrap/pybind11/tests/test_cmake_build/subdirectory_target/CMakeLists.txt b/wrap/pybind11/tests/test_cmake_build/subdirectory_target/CMakeLists.txt index 2471941fb..ba82fdee2 100644 --- a/wrap/pybind11/tests/test_cmake_build/subdirectory_target/CMakeLists.txt +++ b/wrap/pybind11/tests/test_cmake_build/subdirectory_target/CMakeLists.txt @@ -11,7 +11,7 @@ endif() project(test_subdirectory_target CXX) -add_subdirectory(${PYBIND11_PROJECT_DIR} pybind11) +add_subdirectory("${pybind11_SOURCE_DIR}" pybind11) add_library(test_subdirectory_target MODULE ../main.cpp) set_target_properties(test_subdirectory_target PROPERTIES OUTPUT_NAME test_cmake_build) @@ -37,4 +37,5 @@ add_custom_target( PYTHONPATH=$ ${_Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/../test.py - ${PROJECT_NAME}) + ${PROJECT_NAME} + DEPENDS test_subdirectory_target) diff --git a/wrap/pybind11/tests/test_cmake_build/test.py b/wrap/pybind11/tests/test_cmake_build/test.py index 87ed5135f..972a27bea 100644 --- a/wrap/pybind11/tests/test_cmake_build/test.py +++ b/wrap/pybind11/tests/test_cmake_build/test.py @@ -1,6 +1,10 @@ # -*- coding: utf-8 -*- import sys + import test_cmake_build +if str is not bytes: # If not Python2 + assert isinstance(__file__, str) # Test this is properly set + assert test_cmake_build.add(1, 2) == 3 print("{} imports, runs, and adds: 1 + 2 = 3".format(sys.argv[1])) diff --git a/wrap/pybind11/tests/test_const_name.cpp b/wrap/pybind11/tests/test_const_name.cpp new file mode 100644 index 000000000..5cb3d16c1 --- /dev/null +++ b/wrap/pybind11/tests/test_const_name.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2021 The Pybind Development Team. +// All rights reserved. Use of this source code is governed by a +// BSD-style license that can be found in the LICENSE file. + +#include "pybind11_tests.h" + +#if defined(_MSC_VER) && _MSC_VER < 1910 + +// MSVC 2015 fails in bizarre ways. +# define PYBIND11_SKIP_TEST_CONST_NAME + +#else // Only test with MSVC 2017 or newer. + +// IUT = Implementation Under Test +# define CONST_NAME_TESTS(TEST_FUNC, IUT) \ + std::string TEST_FUNC(int selector) { \ + switch (selector) { \ + case 0: \ + return IUT("").text; \ + case 1: \ + return IUT("A").text; \ + case 2: \ + return IUT("Bd").text; \ + case 3: \ + return IUT("Cef").text; \ + case 4: \ + return IUT().text; /*NOLINT(bugprone-macro-parentheses)*/ \ + case 5: \ + return IUT().text; /*NOLINT(bugprone-macro-parentheses)*/ \ + case 6: \ + return IUT("T1", "T2").text; /*NOLINT(bugprone-macro-parentheses)*/ \ + case 7: \ + return IUT("U1", "U2").text; /*NOLINT(bugprone-macro-parentheses)*/ \ + case 8: \ + /*NOLINTNEXTLINE(bugprone-macro-parentheses)*/ \ + return IUT(IUT("D1"), IUT("D2")).text; \ + case 9: \ + /*NOLINTNEXTLINE(bugprone-macro-parentheses)*/ \ + return IUT(IUT("E1"), IUT("E2")).text; \ + case 10: \ + return IUT("KeepAtEnd").text; \ + default: \ + break; \ + } \ + throw std::runtime_error("Invalid selector value."); \ + } + +CONST_NAME_TESTS(const_name_tests, py::detail::const_name) + +# ifdef PYBIND11_DETAIL_UNDERSCORE_BACKWARD_COMPATIBILITY +CONST_NAME_TESTS(underscore_tests, py::detail::_) +# endif + +#endif // MSVC >= 2017 + +TEST_SUBMODULE(const_name, m) { +#ifdef PYBIND11_SKIP_TEST_CONST_NAME + m.attr("const_name_tests") = "PYBIND11_SKIP_TEST_CONST_NAME"; +#else + m.def("const_name_tests", const_name_tests); +#endif + +#ifdef PYBIND11_SKIP_TEST_CONST_NAME + m.attr("underscore_tests") = "PYBIND11_SKIP_TEST_CONST_NAME"; +#elif defined(PYBIND11_DETAIL_UNDERSCORE_BACKWARD_COMPATIBILITY) + m.def("underscore_tests", underscore_tests); +#else + m.attr("underscore_tests") = "PYBIND11_DETAIL_UNDERSCORE_BACKWARD_COMPATIBILITY not defined."; +#endif +} diff --git a/wrap/pybind11/tests/test_const_name.py b/wrap/pybind11/tests/test_const_name.py new file mode 100644 index 000000000..d4e45e5e9 --- /dev/null +++ b/wrap/pybind11/tests/test_const_name.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- +import pytest + +import env +from pybind11_tests import const_name as m + + +@pytest.mark.parametrize("func", (m.const_name_tests, m.underscore_tests)) +@pytest.mark.parametrize( + "selector, expected", + enumerate( + ( + "", + "A", + "Bd", + "Cef", + "%", + "%", + "T1", + "U2", + "D1", + "E2", + "KeepAtEnd", + ) + ), +) +def test_const_name(func, selector, expected): + if isinstance(func, type(u"") if env.PY2 else str): + pytest.skip(func) + text = func(selector) + assert text == expected diff --git a/wrap/pybind11/tests/test_constants_and_functions.cpp b/wrap/pybind11/tests/test_constants_and_functions.cpp index f60779559..c0554503f 100644 --- a/wrap/pybind11/tests/test_constants_and_functions.cpp +++ b/wrap/pybind11/tests/test_constants_and_functions.cpp @@ -1,5 +1,6 @@ /* - tests/test_constants_and_functions.cpp -- global constants and functions, enumerations, raw byte strings + tests/test_constants_and_functions.cpp -- global constants and functions, enumerations, raw + byte strings Copyright (c) 2016 Wenzel Jakob @@ -33,7 +34,7 @@ py::bytes return_bytes() { return std::string(data, 4); } -std::string print_bytes(py::bytes bytes) { +std::string print_bytes(const py::bytes &bytes) { std::string ret = "bytes["; const auto value = static_cast(bytes); for (size_t i = 0; i < value.length(); ++i) { @@ -46,15 +47,23 @@ std::string print_bytes(py::bytes bytes) { // Test that we properly handle C++17 exception specifiers (which are part of the function signature // in C++17). These should all still work before C++17, but don't affect the function signature. namespace test_exc_sp { +// [workaround(intel)] Unable to use noexcept instead of noexcept(true) +// Make the f1 test basically the same as the f2 test in C++17 mode for the Intel compiler as +// it fails to compile with a plain noexcept (tested with icc (ICC) 2021.1 Beta 20200827). +#if defined(__INTEL_COMPILER) && defined(PYBIND11_CPP17) +int f1(int x) noexcept(true) { return x+1; } +#else int f1(int x) noexcept { return x+1; } +#endif int f2(int x) noexcept(true) { return x+2; } int f3(int x) noexcept(false) { return x+3; } -#if defined(__GNUG__) +#if defined(__GNUG__) && !defined(__INTEL_COMPILER) # pragma GCC diagnostic push # pragma GCC diagnostic ignored "-Wdeprecated" #endif +// NOLINTNEXTLINE(modernize-use-noexcept) int f4(int x) throw() { return x+4; } // Deprecated equivalent to noexcept(true) -#if defined(__GNUG__) +#if defined(__GNUG__) && !defined(__INTEL_COMPILER) # pragma GCC diagnostic pop #endif struct C { @@ -64,13 +73,15 @@ struct C { int m4(int x) const noexcept(true) { return x-4; } int m5(int x) noexcept(false) { return x-5; } int m6(int x) const noexcept(false) { return x-6; } -#if defined(__GNUG__) +#if defined(__GNUG__) && !defined(__INTEL_COMPILER) # pragma GCC diagnostic push # pragma GCC diagnostic ignored "-Wdeprecated" #endif - int m7(int x) throw() { return x-7; } - int m8(int x) const throw() { return x-8; } -#if defined(__GNUG__) + // NOLINTNEXTLINE(modernize-use-noexcept) + int m7(int x) throw() { return x - 7; } + // NOLINTNEXTLINE(modernize-use-noexcept) + int m8(int x) const throw() { return x - 8; } +#if defined(__GNUG__) && !defined(__INTEL_COMPILER) # pragma GCC diagnostic pop #endif }; @@ -122,6 +133,33 @@ TEST_SUBMODULE(constants_and_functions, m) { ; m.def("f1", f1); m.def("f2", f2); +#if defined(__INTEL_COMPILER) +# pragma warning push +# pragma warning disable 878 // incompatible exception specifications +#endif m.def("f3", f3); +#if defined(__INTEL_COMPILER) +# pragma warning pop +#endif m.def("f4", f4); + + // test_function_record_leaks + struct LargeCapture { + // This should always be enough to trigger the alternative branch + // where `sizeof(capture) > sizeof(rec->data)` + uint64_t zeros[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + }; + m.def("register_large_capture_with_invalid_arguments", [](py::module_ m) { + LargeCapture capture; // VS 2015's MSVC is acting up if we create the array here + m.def("should_raise", [capture](int) { return capture.zeros[9] + 33; }, py::kw_only(), py::arg()); + }); + m.def("register_with_raising_repr", [](py::module_ m, const py::object &default_value) { + m.def( + "should_raise", + [](int, int, const py::object &) { return 42; }, + "some docstring", + py::arg_v("x", 42), + py::arg_v("y", 42, ""), + py::arg_v("z", default_value)); + }); } diff --git a/wrap/pybind11/tests/test_constants_and_functions.py b/wrap/pybind11/tests/test_constants_and_functions.py index b980ccf1c..ff13bd0f2 100644 --- a/wrap/pybind11/tests/test_constants_and_functions.py +++ b/wrap/pybind11/tests/test_constants_and_functions.py @@ -40,3 +40,14 @@ def test_exception_specifiers(): assert m.f2(53) == 55 assert m.f3(86) == 89 assert m.f4(140) == 144 + + +def test_function_record_leaks(): + class RaisingRepr: + def __repr__(self): + raise RuntimeError("Surprise!") + + with pytest.raises(RuntimeError): + m.register_large_capture_with_invalid_arguments(m) + with pytest.raises(RuntimeError): + m.register_with_raising_repr(m, RaisingRepr()) diff --git a/wrap/pybind11/tests/test_copy_move.cpp b/wrap/pybind11/tests/test_copy_move.cpp index 05d5c4767..4711a9482 100644 --- a/wrap/pybind11/tests/test_copy_move.cpp +++ b/wrap/pybind11/tests/test_copy_move.cpp @@ -37,9 +37,16 @@ template <> lacking_move_ctor empty::instance_ = {}; class MoveOnlyInt { public: MoveOnlyInt() { print_default_created(this); } - MoveOnlyInt(int v) : value{std::move(v)} { print_created(this, value); } - MoveOnlyInt(MoveOnlyInt &&m) { print_move_created(this, m.value); std::swap(value, m.value); } - MoveOnlyInt &operator=(MoveOnlyInt &&m) { print_move_assigned(this, m.value); std::swap(value, m.value); return *this; } + explicit MoveOnlyInt(int v) : value{v} { print_created(this, value); } + MoveOnlyInt(MoveOnlyInt &&m) noexcept { + print_move_created(this, m.value); + std::swap(value, m.value); + } + MoveOnlyInt &operator=(MoveOnlyInt &&m) noexcept { + print_move_assigned(this, m.value); + std::swap(value, m.value); + return *this; + } MoveOnlyInt(const MoveOnlyInt &) = delete; MoveOnlyInt &operator=(const MoveOnlyInt &) = delete; ~MoveOnlyInt() { print_destroyed(this); } @@ -49,9 +56,16 @@ public: class MoveOrCopyInt { public: MoveOrCopyInt() { print_default_created(this); } - MoveOrCopyInt(int v) : value{std::move(v)} { print_created(this, value); } - MoveOrCopyInt(MoveOrCopyInt &&m) { print_move_created(this, m.value); std::swap(value, m.value); } - MoveOrCopyInt &operator=(MoveOrCopyInt &&m) { print_move_assigned(this, m.value); std::swap(value, m.value); return *this; } + explicit MoveOrCopyInt(int v) : value{v} { print_created(this, value); } + MoveOrCopyInt(MoveOrCopyInt &&m) noexcept { + print_move_created(this, m.value); + std::swap(value, m.value); + } + MoveOrCopyInt &operator=(MoveOrCopyInt &&m) noexcept { + print_move_assigned(this, m.value); + std::swap(value, m.value); + return *this; + } MoveOrCopyInt(const MoveOrCopyInt &c) { print_copy_created(this, c.value); value = c.value; } MoveOrCopyInt &operator=(const MoveOrCopyInt &c) { print_copy_assigned(this, c.value); value = c.value; return *this; } ~MoveOrCopyInt() { print_destroyed(this); } @@ -61,7 +75,7 @@ public: class CopyOnlyInt { public: CopyOnlyInt() { print_default_created(this); } - CopyOnlyInt(int v) : value{std::move(v)} { print_created(this, value); } + explicit CopyOnlyInt(int v) : value{v} { print_created(this, value); } CopyOnlyInt(const CopyOnlyInt &c) { print_copy_created(this, c.value); value = c.value; } CopyOnlyInt &operator=(const CopyOnlyInt &c) { print_copy_assigned(this, c.value); value = c.value; return *this; } ~CopyOnlyInt() { print_destroyed(this); } @@ -71,13 +85,13 @@ public: PYBIND11_NAMESPACE_BEGIN(pybind11) PYBIND11_NAMESPACE_BEGIN(detail) template <> struct type_caster { - PYBIND11_TYPE_CASTER(MoveOnlyInt, _("MoveOnlyInt")); + PYBIND11_TYPE_CASTER(MoveOnlyInt, const_name("MoveOnlyInt")); bool load(handle src, bool) { value = MoveOnlyInt(src.cast()); return true; } static handle cast(const MoveOnlyInt &m, return_value_policy r, handle p) { return pybind11::cast(m.value, r, p); } }; template <> struct type_caster { - PYBIND11_TYPE_CASTER(MoveOrCopyInt, _("MoveOrCopyInt")); + PYBIND11_TYPE_CASTER(MoveOrCopyInt, const_name("MoveOrCopyInt")); bool load(handle src, bool) { value = MoveOrCopyInt(src.cast()); return true; } static handle cast(const MoveOrCopyInt &m, return_value_policy r, handle p) { return pybind11::cast(m.value, r, p); } }; @@ -86,15 +100,15 @@ template <> struct type_caster { protected: CopyOnlyInt value; public: - static constexpr auto name = _("CopyOnlyInt"); + static constexpr auto name = const_name("CopyOnlyInt"); bool load(handle src, bool) { value = CopyOnlyInt(src.cast()); return true; } static handle cast(const CopyOnlyInt &m, return_value_policy r, handle p) { return pybind11::cast(m.value, r, p); } static handle cast(const CopyOnlyInt *src, return_value_policy policy, handle parent) { if (!src) return none().release(); return cast(*src, policy, parent); } - operator CopyOnlyInt*() { return &value; } - operator CopyOnlyInt&() { return value; } + explicit operator CopyOnlyInt *() { return &value; } + explicit operator CopyOnlyInt &() { return value; } template using cast_op_type = pybind11::detail::cast_op_type; }; PYBIND11_NAMESPACE_END(detail) @@ -111,14 +125,15 @@ TEST_SUBMODULE(copy_move_policies, m) { py::return_value_policy::move); // test_move_and_copy_casts - m.def("move_and_copy_casts", [](py::object o) { + // NOLINTNEXTLINE(performance-unnecessary-value-param) + m.def("move_and_copy_casts", [](const py::object &o) { int r = 0; r += py::cast(o).value; /* moves */ r += py::cast(o).value; /* moves */ r += py::cast(o).value; /* copies */ - MoveOrCopyInt m1(py::cast(o)); /* moves */ - MoveOnlyInt m2(py::cast(o)); /* moves */ - CopyOnlyInt m3(py::cast(o)); /* copies */ + auto m1(py::cast(o)); /* moves */ + auto m2(py::cast(o)); /* moves */ + auto m3(py::cast(o)); /* copies */ r += m1.value + m2.value + m3.value; return r; @@ -126,7 +141,11 @@ TEST_SUBMODULE(copy_move_policies, m) { // test_move_and_copy_loads m.def("move_only", [](MoveOnlyInt m) { return m.value; }); + // Changing this breaks the existing test: needs careful review. + // NOLINTNEXTLINE(performance-unnecessary-value-param) m.def("move_or_copy", [](MoveOrCopyInt m) { return m.value; }); + // Changing this breaks the existing test: needs careful review. + // NOLINTNEXTLINE(performance-unnecessary-value-param) m.def("copy_only", [](CopyOnlyInt m) { return m.value; }); m.def("move_pair", [](std::pair p) { return p.first.value + p.second.value; @@ -186,8 +205,7 @@ TEST_SUBMODULE(copy_move_policies, m) { void *ptr = std::malloc(bytes); if (ptr) return ptr; - else - throw std::bad_alloc{}; + throw std::bad_alloc{}; } }; py::class_(m, "PrivateOpNew").def_readonly("value", &PrivateOpNew::value); @@ -201,7 +219,7 @@ TEST_SUBMODULE(copy_move_policies, m) { // #389: rvp::move should fall-through to copy on non-movable objects struct MoveIssue1 { int v; - MoveIssue1(int v) : v{v} {} + explicit MoveIssue1(int v) : v{v} {} MoveIssue1(const MoveIssue1 &c) = default; MoveIssue1(MoveIssue1 &&) = delete; }; @@ -209,11 +227,12 @@ TEST_SUBMODULE(copy_move_policies, m) { struct MoveIssue2 { int v; - MoveIssue2(int v) : v{v} {} + explicit MoveIssue2(int v) : v{v} {} MoveIssue2(MoveIssue2 &&) = default; }; py::class_(m, "MoveIssue2").def(py::init()).def_readwrite("value", &MoveIssue2::v); - m.def("get_moveissue1", [](int i) { return new MoveIssue1(i); }, py::return_value_policy::move); + // #2742: Don't expect ownership of raw pointer to `new`ed object to be transferred with `py::return_value_policy::move` + m.def("get_moveissue1", [](int i) { return std::unique_ptr(new MoveIssue1(i)); }, py::return_value_policy::move); m.def("get_moveissue2", [](int i) { return MoveIssue2(i); }, py::return_value_policy::move); } diff --git a/wrap/pybind11/tests/test_copy_move.py b/wrap/pybind11/tests/test_copy_move.py index 6b53993a9..eb1efddd5 100644 --- a/wrap/pybind11/tests/test_copy_move.py +++ b/wrap/pybind11/tests/test_copy_move.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- import pytest + from pybind11_tests import copy_move_policies as m @@ -19,7 +20,11 @@ def test_move_and_copy_casts(): """Cast some values in C++ via custom type casters and count the number of moves/copies.""" cstats = m.move_and_copy_cstats() - c_m, c_mc, c_c = cstats["MoveOnlyInt"], cstats["MoveOrCopyInt"], cstats["CopyOnlyInt"] + c_m, c_mc, c_c = ( + cstats["MoveOnlyInt"], + cstats["MoveOrCopyInt"], + cstats["CopyOnlyInt"], + ) # The type move constructions/assignments below each get incremented: the move assignment comes # from the type_caster load; the move construction happens when extracting that via a cast or @@ -43,7 +48,11 @@ def test_move_and_copy_loads(): moves/copies.""" cstats = m.move_and_copy_cstats() - c_m, c_mc, c_c = cstats["MoveOnlyInt"], cstats["MoveOrCopyInt"], cstats["CopyOnlyInt"] + c_m, c_mc, c_c = ( + cstats["MoveOnlyInt"], + cstats["MoveOrCopyInt"], + cstats["CopyOnlyInt"], + ) assert m.move_only(10) == 10 # 1 move, c_m assert m.move_or_copy(11) == 11 # 1 move, c_mc @@ -66,12 +75,16 @@ def test_move_and_copy_loads(): assert c_m.alive() + c_mc.alive() + c_c.alive() == 0 -@pytest.mark.skipif(not m.has_optional, reason='no ') +@pytest.mark.skipif(not m.has_optional, reason="no ") def test_move_and_copy_load_optional(): """Tests move/copy loads of std::optional arguments""" cstats = m.move_and_copy_cstats() - c_m, c_mc, c_c = cstats["MoveOnlyInt"], cstats["MoveOrCopyInt"], cstats["CopyOnlyInt"] + c_m, c_mc, c_c = ( + cstats["MoveOnlyInt"], + cstats["MoveOrCopyInt"], + cstats["CopyOnlyInt"], + ) # The extra move/copy constructions below come from the std::optional move (which has to move # its arguments): @@ -107,7 +120,7 @@ def test_private_op_new(): def test_move_fallback(): """#389: rvp::move should fall-through to copy on non-movable objects""" - m2 = m.get_moveissue2(2) - assert m2.value == 2 m1 = m.get_moveissue1(1) assert m1.value == 1 + m2 = m.get_moveissue2(2) + assert m2.value == 2 diff --git a/wrap/pybind11/tests/test_custom_type_casters.cpp b/wrap/pybind11/tests/test_custom_type_casters.cpp index d565add26..48613ee5a 100644 --- a/wrap/pybind11/tests/test_custom_type_casters.cpp +++ b/wrap/pybind11/tests/test_custom_type_casters.cpp @@ -18,7 +18,12 @@ class ArgAlwaysConverts { }; namespace pybind11 { namespace detail { template <> struct type_caster { public: + // Classic +#ifdef PYBIND11_DETAIL_UNDERSCORE_BACKWARD_COMPATIBILITY PYBIND11_TYPE_CASTER(ArgInspector1, _("ArgInspector1")); +#else + PYBIND11_TYPE_CASTER(ArgInspector1, const_name("ArgInspector1")); +#endif bool load(handle src, bool convert) { value.arg = "loading ArgInspector1 argument " + @@ -33,7 +38,7 @@ public: }; template <> struct type_caster { public: - PYBIND11_TYPE_CASTER(ArgInspector2, _("ArgInspector2")); + PYBIND11_TYPE_CASTER(ArgInspector2, const_name("ArgInspector2")); bool load(handle src, bool convert) { value.arg = "loading ArgInspector2 argument " + @@ -48,7 +53,7 @@ public: }; template <> struct type_caster { public: - PYBIND11_TYPE_CASTER(ArgAlwaysConverts, _("ArgAlwaysConverts")); + PYBIND11_TYPE_CASTER(ArgAlwaysConverts, const_name("ArgAlwaysConverts")); bool load(handle, bool convert) { return convert; @@ -67,13 +72,16 @@ public: DestructionTester() { print_default_created(this); } ~DestructionTester() { print_destroyed(this); } DestructionTester(const DestructionTester &) { print_copy_created(this); } - DestructionTester(DestructionTester &&) { print_move_created(this); } + DestructionTester(DestructionTester &&) noexcept { print_move_created(this); } DestructionTester &operator=(const DestructionTester &) { print_copy_assigned(this); return *this; } - DestructionTester &operator=(DestructionTester &&) { print_move_assigned(this); return *this; } + DestructionTester &operator=(DestructionTester &&) noexcept { + print_move_assigned(this); + return *this; + } }; namespace pybind11 { namespace detail { template <> struct type_caster { - PYBIND11_TYPE_CASTER(DestructionTester, _("DestructionTester")); + PYBIND11_TYPE_CASTER(DestructionTester, const_name("DestructionTester")); bool load(handle, bool) { return true; } static handle cast(const DestructionTester &, return_value_policy, handle) { @@ -94,24 +102,35 @@ TEST_SUBMODULE(custom_type_casters, m) { class ArgInspector { public: ArgInspector1 f(ArgInspector1 a, ArgAlwaysConverts) { return a; } - std::string g(ArgInspector1 a, const ArgInspector1 &b, int c, ArgInspector2 *d, ArgAlwaysConverts) { + std::string g(const ArgInspector1 &a, + const ArgInspector1 &b, + int c, + ArgInspector2 *d, + ArgAlwaysConverts) { return a.arg + "\n" + b.arg + "\n" + std::to_string(c) + "\n" + d->arg; } static ArgInspector2 h(ArgInspector2 a, ArgAlwaysConverts) { return a; } }; + // [workaround(intel)] ICC 20/21 breaks with py::arg().stuff, using py::arg{}.stuff works. py::class_(m, "ArgInspector") .def(py::init<>()) .def("f", &ArgInspector::f, py::arg(), py::arg() = ArgAlwaysConverts()) .def("g", &ArgInspector::g, "a"_a.noconvert(), "b"_a, "c"_a.noconvert()=13, "d"_a=ArgInspector2(), py::arg() = ArgAlwaysConverts()) - .def_static("h", &ArgInspector::h, py::arg().noconvert(), py::arg() = ArgAlwaysConverts()) + .def_static("h", &ArgInspector::h, py::arg{}.noconvert(), py::arg() = ArgAlwaysConverts()) ; - m.def("arg_inspect_func", [](ArgInspector2 a, ArgInspector1 b, ArgAlwaysConverts) { return a.arg + "\n" + b.arg; }, - py::arg().noconvert(false), py::arg_v(nullptr, ArgInspector1()).noconvert(true), py::arg() = ArgAlwaysConverts()); + m.def( + "arg_inspect_func", + [](const ArgInspector2 &a, const ArgInspector1 &b, ArgAlwaysConverts) { + return a.arg + "\n" + b.arg; + }, + py::arg{}.noconvert(false), + py::arg_v(nullptr, ArgInspector1()).noconvert(true), + py::arg() = ArgAlwaysConverts()); - m.def("floats_preferred", [](double f) { return 0.5 * f; }, py::arg("f")); - m.def("floats_only", [](double f) { return 0.5 * f; }, py::arg("f").noconvert()); - m.def("ints_preferred", [](int i) { return i / 2; }, py::arg("i")); - m.def("ints_only", [](int i) { return i / 2; }, py::arg("i").noconvert()); + m.def("floats_preferred", [](double f) { return 0.5 * f; }, "f"_a); + m.def("floats_only", [](double f) { return 0.5 * f; }, "f"_a.noconvert()); + m.def("ints_preferred", [](int i) { return i / 2; }, "i"_a); + m.def("ints_only", [](int i) { return i / 2; }, "i"_a.noconvert()); // test_custom_caster_destruction // Test that `take_ownership` works on types with a custom type caster when given a pointer diff --git a/wrap/pybind11/tests/test_custom_type_casters.py b/wrap/pybind11/tests/test_custom_type_casters.py index 9475c4516..a10646ff4 100644 --- a/wrap/pybind11/tests/test_custom_type_casters.py +++ b/wrap/pybind11/tests/test_custom_type_casters.py @@ -1,69 +1,96 @@ # -*- coding: utf-8 -*- import pytest + from pybind11_tests import custom_type_casters as m def test_noconvert_args(msg): a = m.ArgInspector() - assert msg(a.f("hi")) == """ + assert ( + msg(a.f("hi")) + == """ loading ArgInspector1 argument WITH conversion allowed. Argument value = hi """ - assert msg(a.g("this is a", "this is b")) == """ + ) + assert ( + msg(a.g("this is a", "this is b")) + == """ loading ArgInspector1 argument WITHOUT conversion allowed. Argument value = this is a loading ArgInspector1 argument WITH conversion allowed. Argument value = this is b 13 loading ArgInspector2 argument WITH conversion allowed. Argument value = (default arg inspector 2) """ # noqa: E501 line too long - assert msg(a.g("this is a", "this is b", 42)) == """ + ) + assert ( + msg(a.g("this is a", "this is b", 42)) + == """ loading ArgInspector1 argument WITHOUT conversion allowed. Argument value = this is a loading ArgInspector1 argument WITH conversion allowed. Argument value = this is b 42 loading ArgInspector2 argument WITH conversion allowed. Argument value = (default arg inspector 2) """ # noqa: E501 line too long - assert msg(a.g("this is a", "this is b", 42, "this is d")) == """ + ) + assert ( + msg(a.g("this is a", "this is b", 42, "this is d")) + == """ loading ArgInspector1 argument WITHOUT conversion allowed. Argument value = this is a loading ArgInspector1 argument WITH conversion allowed. Argument value = this is b 42 loading ArgInspector2 argument WITH conversion allowed. Argument value = this is d """ - assert (a.h("arg 1") == - "loading ArgInspector2 argument WITHOUT conversion allowed. Argument value = arg 1") - assert msg(m.arg_inspect_func("A1", "A2")) == """ + ) + assert ( + a.h("arg 1") + == "loading ArgInspector2 argument WITHOUT conversion allowed. Argument value = arg 1" + ) + assert ( + msg(m.arg_inspect_func("A1", "A2")) + == """ loading ArgInspector2 argument WITH conversion allowed. Argument value = A1 loading ArgInspector1 argument WITHOUT conversion allowed. Argument value = A2 """ + ) assert m.floats_preferred(4) == 2.0 assert m.floats_only(4.0) == 2.0 with pytest.raises(TypeError) as excinfo: m.floats_only(4) - assert msg(excinfo.value) == """ + assert ( + msg(excinfo.value) + == """ floats_only(): incompatible function arguments. The following argument types are supported: 1. (f: float) -> float Invoked with: 4 """ + ) assert m.ints_preferred(4) == 2 assert m.ints_preferred(True) == 0 with pytest.raises(TypeError) as excinfo: m.ints_preferred(4.0) - assert msg(excinfo.value) == """ + assert ( + msg(excinfo.value) + == """ ints_preferred(): incompatible function arguments. The following argument types are supported: 1. (i: int) -> int Invoked with: 4.0 """ # noqa: E501 line too long + ) assert m.ints_only(4) == 2 with pytest.raises(TypeError) as excinfo: m.ints_only(4.0) - assert msg(excinfo.value) == """ + assert ( + msg(excinfo.value) + == """ ints_only(): incompatible function arguments. The following argument types are supported: 1. (i: int) -> int Invoked with: 4.0 """ + ) def test_custom_caster_destruction(): diff --git a/wrap/pybind11/tests/test_custom_type_setup.cpp b/wrap/pybind11/tests/test_custom_type_setup.cpp new file mode 100644 index 000000000..42fae05d5 --- /dev/null +++ b/wrap/pybind11/tests/test_custom_type_setup.cpp @@ -0,0 +1,41 @@ +/* + tests/test_custom_type_setup.cpp -- Tests `pybind11::custom_type_setup` + + Copyright (c) Google LLC + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#include + +#include "pybind11_tests.h" + +namespace py = pybind11; + +namespace { + +struct OwnsPythonObjects { + py::object value = py::none(); +}; +} // namespace + +TEST_SUBMODULE(custom_type_setup, m) { + py::class_ cls( + m, "OwnsPythonObjects", py::custom_type_setup([](PyHeapTypeObject *heap_type) { + auto *type = &heap_type->ht_type; + type->tp_flags |= Py_TPFLAGS_HAVE_GC; + type->tp_traverse = [](PyObject *self_base, visitproc visit, void *arg) { + auto &self = py::cast(py::handle(self_base)); + Py_VISIT(self.value.ptr()); + return 0; + }; + type->tp_clear = [](PyObject *self_base) { + auto &self = py::cast(py::handle(self_base)); + self.value = py::none(); + return 0; + }; + })); + cls.def(py::init<>()); + cls.def_readwrite("value", &OwnsPythonObjects::value); +} diff --git a/wrap/pybind11/tests/test_custom_type_setup.py b/wrap/pybind11/tests/test_custom_type_setup.py new file mode 100644 index 000000000..ef96f0814 --- /dev/null +++ b/wrap/pybind11/tests/test_custom_type_setup.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- + +import gc +import weakref + +import pytest + +import env # noqa: F401 +from pybind11_tests import custom_type_setup as m + + +@pytest.fixture +def gc_tester(): + """Tests that an object is garbage collected. + + Assumes that any unreferenced objects are fully collected after calling + `gc.collect()`. That is true on CPython, but does not appear to reliably + hold on PyPy. + """ + + weak_refs = [] + + def add_ref(obj): + # PyPy does not support `gc.is_tracked`. + if hasattr(gc, "is_tracked"): + assert gc.is_tracked(obj) + weak_refs.append(weakref.ref(obj)) + + yield add_ref + + gc.collect() + for ref in weak_refs: + assert ref() is None + + +# PyPy does not seem to reliably garbage collect. +@pytest.mark.skipif("env.PYPY") +def test_self_cycle(gc_tester): + obj = m.OwnsPythonObjects() + obj.value = obj + gc_tester(obj) + + +# PyPy does not seem to reliably garbage collect. +@pytest.mark.skipif("env.PYPY") +def test_indirect_cycle(gc_tester): + obj = m.OwnsPythonObjects() + obj_list = [obj] + obj.value = obj_list + gc_tester(obj) diff --git a/wrap/pybind11/tests/test_docstring_options.cpp b/wrap/pybind11/tests/test_docstring_options.cpp index 8c8f79fd5..8a97af55f 100644 --- a/wrap/pybind11/tests/test_docstring_options.cpp +++ b/wrap/pybind11/tests/test_docstring_options.cpp @@ -45,6 +45,14 @@ TEST_SUBMODULE(docstring_options, m) { m.def("test_function7", [](int, int) {}, py::arg("a"), py::arg("b"), "A custom docstring"); + { + py::options options; + options.disable_user_defined_docstrings(); + options.disable_function_signatures(); + + m.def("test_function8", []() {}); + } + { py::options options; options.disable_user_defined_docstrings(); diff --git a/wrap/pybind11/tests/test_docstring_options.py b/wrap/pybind11/tests/test_docstring_options.py index 80ade0f15..8ee661388 100644 --- a/wrap/pybind11/tests/test_docstring_options.py +++ b/wrap/pybind11/tests/test_docstring_options.py @@ -18,10 +18,10 @@ def test_docstring_options(): assert m.test_overloaded3.__doc__ == "Overload docstr" # options.enable_function_signatures() - assert m.test_function3.__doc__ .startswith("test_function3(a: int, b: int) -> None") + assert m.test_function3.__doc__.startswith("test_function3(a: int, b: int) -> None") - assert m.test_function4.__doc__ .startswith("test_function4(a: int, b: int) -> None") - assert m.test_function4.__doc__ .endswith("A custom docstring\n") + assert m.test_function4.__doc__.startswith("test_function4(a: int, b: int) -> None") + assert m.test_function4.__doc__.endswith("A custom docstring\n") # options.disable_function_signatures() # options.disable_user_defined_docstrings() @@ -31,8 +31,11 @@ def test_docstring_options(): assert m.test_function6.__doc__ == "A custom docstring" # RAII destructor - assert m.test_function7.__doc__ .startswith("test_function7(a: int, b: int) -> None") - assert m.test_function7.__doc__ .endswith("A custom docstring\n") + assert m.test_function7.__doc__.startswith("test_function7(a: int, b: int) -> None") + assert m.test_function7.__doc__.endswith("A custom docstring\n") + + # when all options are disabled, no docstring (instead of an empty one) should be generated + assert m.test_function8.__doc__ is None # Suppression of user-defined docstrings for non-function objects assert not m.DocstringTestFoo.__doc__ diff --git a/wrap/pybind11/tests/test_eigen.cpp b/wrap/pybind11/tests/test_eigen.cpp index 56aa1a4a6..d22a94a1a 100644 --- a/wrap/pybind11/tests/test_eigen.cpp +++ b/wrap/pybind11/tests/test_eigen.cpp @@ -13,6 +13,9 @@ #include #if defined(_MSC_VER) +#if _MSC_VER < 1910 // VS 2015's MSVC +# pragma warning(disable: 4127) // C4127: conditional expression is constant +#endif # pragma warning(disable: 4996) // C4996: std::unary_negation is deprecated #endif @@ -54,15 +57,15 @@ void reset_refs() { } // Returns element 2,1 from a matrix (used to test copy/nocopy) -double get_elem(Eigen::Ref m) { return m(2, 1); }; - +double get_elem(const Eigen::Ref &m) { return m(2, 1); }; // Returns a matrix with 10*r + 100*c added to each matrix element (to help test that the matrix // reference is referencing rows/columns correctly). template Eigen::MatrixXd adjust_matrix(MatrixArgType m) { Eigen::MatrixXd ret(m); - for (int c = 0; c < m.cols(); c++) for (int r = 0; r < m.rows(); r++) - ret(r, c) += 10*r + 100*c; + for (int c = 0; c < m.cols(); c++) + for (int r = 0; r < m.rows(); r++) + ret(r, c) += 10*r + 100*c; // NOLINT(clang-analyzer-core.uninitialized.Assign) return ret; } @@ -93,15 +96,18 @@ TEST_SUBMODULE(eigen, m) { m.def("double_complex", [](const Eigen::VectorXcf &x) -> Eigen::VectorXcf { return 2.0f * x; }); m.def("double_threec", [](py::EigenDRef x) { x *= 2; }); m.def("double_threer", [](py::EigenDRef x) { x *= 2; }); - m.def("double_mat_cm", [](Eigen::MatrixXf x) -> Eigen::MatrixXf { return 2.0f * x; }); - m.def("double_mat_rm", [](DenseMatrixR x) -> DenseMatrixR { return 2.0f * x; }); + m.def("double_mat_cm", [](const Eigen::MatrixXf &x) -> Eigen::MatrixXf { return 2.0f * x; }); + m.def("double_mat_rm", [](const DenseMatrixR &x) -> DenseMatrixR { return 2.0f * x; }); // test_eigen_ref_to_python // Different ways of passing via Eigen::Ref; the first and second are the Eigen-recommended - m.def("cholesky1", [](Eigen::Ref x) -> Eigen::MatrixXd { return x.llt().matrixL(); }); + m.def("cholesky1", + [](const Eigen::Ref &x) -> Eigen::MatrixXd { return x.llt().matrixL(); }); m.def("cholesky2", [](const Eigen::Ref &x) -> Eigen::MatrixXd { return x.llt().matrixL(); }); m.def("cholesky3", [](const Eigen::Ref &x) -> Eigen::MatrixXd { return x.llt().matrixL(); }); - m.def("cholesky4", [](Eigen::Ref x) -> Eigen::MatrixXd { return x.llt().matrixL(); }); + m.def("cholesky4", [](const Eigen::Ref &x) -> Eigen::MatrixXd { + return x.llt().matrixL(); + }); // test_eigen_ref_mutators // Mutators: these add some value to the given element using Eigen, but Eigen should be mapping into @@ -175,6 +181,7 @@ TEST_SUBMODULE(eigen, m) { ReturnTester() { print_created(this); } ~ReturnTester() { print_destroyed(this); } static Eigen::MatrixXd create() { return Eigen::MatrixXd::Ones(10, 10); } + // NOLINTNEXTLINE(readability-const-return-type) static const Eigen::MatrixXd createConst() { return Eigen::MatrixXd::Ones(10, 10); } Eigen::MatrixXd &get() { return mat; } Eigen::MatrixXd *getPtr() { return &mat; } @@ -241,21 +248,27 @@ TEST_SUBMODULE(eigen, m) { // test_fixed, and various other tests m.def("fixed_r", [mat]() -> FixedMatrixR { return FixedMatrixR(mat); }); + // Our Eigen does a hack which respects constness through the numpy writeable flag. + // Therefore, the const return actually affects this type despite being an rvalue. + // NOLINTNEXTLINE(readability-const-return-type) m.def("fixed_r_const", [mat]() -> const FixedMatrixR { return FixedMatrixR(mat); }); m.def("fixed_c", [mat]() -> FixedMatrixC { return FixedMatrixC(mat); }); m.def("fixed_copy_r", [](const FixedMatrixR &m) -> FixedMatrixR { return m; }); m.def("fixed_copy_c", [](const FixedMatrixC &m) -> FixedMatrixC { return m; }); // test_mutator_descriptors - m.def("fixed_mutator_r", [](Eigen::Ref) {}); - m.def("fixed_mutator_c", [](Eigen::Ref) {}); - m.def("fixed_mutator_a", [](py::EigenDRef) {}); + m.def("fixed_mutator_r", [](const Eigen::Ref &) {}); + m.def("fixed_mutator_c", [](const Eigen::Ref &) {}); + m.def("fixed_mutator_a", [](const py::EigenDRef &) {}); // test_dense m.def("dense_r", [mat]() -> DenseMatrixR { return DenseMatrixR(mat); }); m.def("dense_c", [mat]() -> DenseMatrixC { return DenseMatrixC(mat); }); m.def("dense_copy_r", [](const DenseMatrixR &m) -> DenseMatrixR { return m; }); m.def("dense_copy_c", [](const DenseMatrixC &m) -> DenseMatrixC { return m; }); // test_sparse, test_sparse_signature - m.def("sparse_r", [mat]() -> SparseMatrixR { return Eigen::SparseView(mat); }); + m.def("sparse_r", [mat]() -> SparseMatrixR { + // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.UndefReturn) + return Eigen::SparseView(mat); + }); m.def("sparse_c", [mat]() -> SparseMatrixC { return Eigen::SparseView(mat); }); m.def("sparse_copy_r", [](const SparseMatrixR &m) -> SparseMatrixR { return m; }); m.def("sparse_copy_c", [](const SparseMatrixC &m) -> SparseMatrixC { return m; }); @@ -272,39 +285,47 @@ TEST_SUBMODULE(eigen, m) { m.def("cpp_ref_r", [](py::handle m) { return m.cast>()(1, 0); }); m.def("cpp_ref_any", [](py::handle m) { return m.cast>()(1, 0); }); + // [workaround(intel)] ICC 20/21 breaks with py::arg().stuff, using py::arg{}.stuff works. // test_nocopy_wrapper // Test that we can prevent copying into an argument that would normally copy: First a version // that would allow copying (if types or strides don't match) for comparison: m.def("get_elem", &get_elem); // Now this alternative that calls the tells pybind to fail rather than copy: - m.def("get_elem_nocopy", [](Eigen::Ref m) -> double { return get_elem(m); }, - py::arg().noconvert()); + m.def( + "get_elem_nocopy", + [](const Eigen::Ref &m) -> double { return get_elem(m); }, + py::arg{}.noconvert()); // Also test a row-major-only no-copy const ref: m.def("get_elem_rm_nocopy", [](Eigen::Ref> &m) -> long { return m(2, 1); }, - py::arg().noconvert()); + py::arg{}.noconvert()); // test_issue738 // Issue #738: 1xN or Nx1 2D matrices were neither accepted nor properly copied with an // incompatible stride value on the length-1 dimension--but that should be allowed (without // requiring a copy!) because the stride value can be safely ignored on a size-1 dimension. - m.def("iss738_f1", &adjust_matrix &>, py::arg().noconvert()); - m.def("iss738_f2", &adjust_matrix> &>, py::arg().noconvert()); + m.def("iss738_f1", &adjust_matrix &>, py::arg{}.noconvert()); + m.def("iss738_f2", &adjust_matrix> &>, py::arg{}.noconvert()); // test_issue1105 // Issue #1105: when converting from a numpy two-dimensional (Nx1) or (1xN) value into a dense - // eigen Vector or RowVector, the argument would fail to load because the numpy copy would fail: - // numpy won't broadcast a Nx1 into a 1-dimensional vector. - m.def("iss1105_col", [](Eigen::VectorXd) { return true; }); - m.def("iss1105_row", [](Eigen::RowVectorXd) { return true; }); + // eigen Vector or RowVector, the argument would fail to load because the numpy copy would + // fail: numpy won't broadcast a Nx1 into a 1-dimensional vector. + m.def("iss1105_col", [](const Eigen::VectorXd &) { return true; }); + m.def("iss1105_row", [](const Eigen::RowVectorXd &) { return true; }); // test_named_arguments // Make sure named arguments are working properly: - m.def("matrix_multiply", [](const py::EigenDRef A, const py::EigenDRef B) - -> Eigen::MatrixXd { - if (A.cols() != B.rows()) throw std::domain_error("Nonconformable matrices!"); - return A * B; - }, py::arg("A"), py::arg("B")); + m.def( + "matrix_multiply", + [](const py::EigenDRef &A, + const py::EigenDRef &B) -> Eigen::MatrixXd { + if (A.cols() != B.rows()) + throw std::domain_error("Nonconformable matrices!"); + return A * B; + }, + py::arg("A"), + py::arg("B")); // test_custom_operator_new py::class_(m, "CustomOperatorNew") @@ -316,12 +337,12 @@ TEST_SUBMODULE(eigen, m) { // In case of a failure (the caster's temp array does not live long enough), creating // a new array (np.ones(10)) increases the chances that the temp array will be garbage // collected and/or that its memory will be overridden with different values. - m.def("get_elem_direct", [](Eigen::Ref v) { - py::module::import("numpy").attr("ones")(10); + m.def("get_elem_direct", [](const Eigen::Ref &v) { + py::module_::import("numpy").attr("ones")(10); return v(5); }); m.def("get_elem_indirect", [](std::vector> v) { - py::module::import("numpy").attr("ones")(10); + py::module_::import("numpy").attr("ones")(10); return v[0](5); }); } diff --git a/wrap/pybind11/tests/test_eigen.py b/wrap/pybind11/tests/test_eigen.py index ac6847147..e53826cbb 100644 --- a/wrap/pybind11/tests/test_eigen.py +++ b/wrap/pybind11/tests/test_eigen.py @@ -1,16 +1,21 @@ # -*- coding: utf-8 -*- import pytest + from pybind11_tests import ConstructorStats np = pytest.importorskip("numpy") m = pytest.importorskip("pybind11_tests.eigen") -ref = np.array([[ 0., 3, 0, 0, 0, 11], - [22, 0, 0, 0, 17, 11], - [ 7, 5, 0, 1, 0, 11], - [ 0, 0, 0, 0, 0, 11], - [ 0, 0, 14, 0, 8, 11]]) +ref = np.array( + [ + [0.0, 3, 0, 0, 0, 11], + [22, 0, 0, 0, 17, 11], + [7, 5, 0, 1, 0, 11], + [0, 0, 0, 0, 0, 11], + [0, 0, 14, 0, 8, 11], + ] +) def assert_equal_ref(mat): @@ -40,28 +45,37 @@ def test_dense(): def test_partially_fixed(): - ref2 = np.array([[0., 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]) + ref2 = np.array([[0.0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]) np.testing.assert_array_equal(m.partial_copy_four_rm_r(ref2), ref2) np.testing.assert_array_equal(m.partial_copy_four_rm_c(ref2), ref2) np.testing.assert_array_equal(m.partial_copy_four_rm_r(ref2[:, 1]), ref2[:, [1]]) np.testing.assert_array_equal(m.partial_copy_four_rm_c(ref2[0, :]), ref2[[0], :]) - np.testing.assert_array_equal(m.partial_copy_four_rm_r(ref2[:, (0, 2)]), ref2[:, (0, 2)]) np.testing.assert_array_equal( - m.partial_copy_four_rm_c(ref2[(3, 1, 2), :]), ref2[(3, 1, 2), :]) + m.partial_copy_four_rm_r(ref2[:, (0, 2)]), ref2[:, (0, 2)] + ) + np.testing.assert_array_equal( + m.partial_copy_four_rm_c(ref2[(3, 1, 2), :]), ref2[(3, 1, 2), :] + ) np.testing.assert_array_equal(m.partial_copy_four_cm_r(ref2), ref2) np.testing.assert_array_equal(m.partial_copy_four_cm_c(ref2), ref2) np.testing.assert_array_equal(m.partial_copy_four_cm_r(ref2[:, 1]), ref2[:, [1]]) np.testing.assert_array_equal(m.partial_copy_four_cm_c(ref2[0, :]), ref2[[0], :]) - np.testing.assert_array_equal(m.partial_copy_four_cm_r(ref2[:, (0, 2)]), ref2[:, (0, 2)]) np.testing.assert_array_equal( - m.partial_copy_four_cm_c(ref2[(3, 1, 2), :]), ref2[(3, 1, 2), :]) + m.partial_copy_four_cm_r(ref2[:, (0, 2)]), ref2[:, (0, 2)] + ) + np.testing.assert_array_equal( + m.partial_copy_four_cm_c(ref2[(3, 1, 2), :]), ref2[(3, 1, 2), :] + ) # TypeError should be raise for a shape mismatch - functions = [m.partial_copy_four_rm_r, m.partial_copy_four_rm_c, - m.partial_copy_four_cm_r, m.partial_copy_four_cm_c] - matrix_with_wrong_shape = [[1, 2], - [3, 4]] + functions = [ + m.partial_copy_four_rm_r, + m.partial_copy_four_rm_c, + m.partial_copy_four_cm_r, + m.partial_copy_four_cm_c, + ] + matrix_with_wrong_shape = [[1, 2], [3, 4]] for f in functions: with pytest.raises(TypeError) as excinfo: f(matrix_with_wrong_shape) @@ -69,7 +83,7 @@ def test_partially_fixed(): def test_mutator_descriptors(): - zr = np.arange(30, dtype='float32').reshape(5, 6) # row-major + zr = np.arange(30, dtype="float32").reshape(5, 6) # row-major zc = zr.reshape(6, 5).transpose() # column-major m.fixed_mutator_r(zr) @@ -78,18 +92,21 @@ def test_mutator_descriptors(): m.fixed_mutator_a(zc) with pytest.raises(TypeError) as excinfo: m.fixed_mutator_r(zc) - assert ('(arg0: numpy.ndarray[numpy.float32[5, 6],' - ' flags.writeable, flags.c_contiguous]) -> None' - in str(excinfo.value)) + assert ( + "(arg0: numpy.ndarray[numpy.float32[5, 6]," + " flags.writeable, flags.c_contiguous]) -> None" in str(excinfo.value) + ) with pytest.raises(TypeError) as excinfo: m.fixed_mutator_c(zr) - assert ('(arg0: numpy.ndarray[numpy.float32[5, 6],' - ' flags.writeable, flags.f_contiguous]) -> None' - in str(excinfo.value)) + assert ( + "(arg0: numpy.ndarray[numpy.float32[5, 6]," + " flags.writeable, flags.f_contiguous]) -> None" in str(excinfo.value) + ) with pytest.raises(TypeError) as excinfo: - m.fixed_mutator_a(np.array([[1, 2], [3, 4]], dtype='float32')) - assert ('(arg0: numpy.ndarray[numpy.float32[5, 6], flags.writeable]) -> None' - in str(excinfo.value)) + m.fixed_mutator_a(np.array([[1, 2], [3, 4]], dtype="float32")) + assert "(arg0: numpy.ndarray[numpy.float32[5, 6], flags.writeable]) -> None" in str( + excinfo.value + ) zr.flags.writeable = False with pytest.raises(TypeError): m.fixed_mutator_r(zr) @@ -98,26 +115,26 @@ def test_mutator_descriptors(): def test_cpp_casting(): - assert m.cpp_copy(m.fixed_r()) == 22. - assert m.cpp_copy(m.fixed_c()) == 22. - z = np.array([[5., 6], [7, 8]]) - assert m.cpp_copy(z) == 7. - assert m.cpp_copy(m.get_cm_ref()) == 21. - assert m.cpp_copy(m.get_rm_ref()) == 21. - assert m.cpp_ref_c(m.get_cm_ref()) == 21. - assert m.cpp_ref_r(m.get_rm_ref()) == 21. + assert m.cpp_copy(m.fixed_r()) == 22.0 + assert m.cpp_copy(m.fixed_c()) == 22.0 + z = np.array([[5.0, 6], [7, 8]]) + assert m.cpp_copy(z) == 7.0 + assert m.cpp_copy(m.get_cm_ref()) == 21.0 + assert m.cpp_copy(m.get_rm_ref()) == 21.0 + assert m.cpp_ref_c(m.get_cm_ref()) == 21.0 + assert m.cpp_ref_r(m.get_rm_ref()) == 21.0 with pytest.raises(RuntimeError) as excinfo: # Can't reference m.fixed_c: it contains floats, m.cpp_ref_any wants doubles m.cpp_ref_any(m.fixed_c()) - assert 'Unable to cast Python instance' in str(excinfo.value) + assert "Unable to cast Python instance" in str(excinfo.value) with pytest.raises(RuntimeError) as excinfo: # Can't reference m.fixed_r: it contains floats, m.cpp_ref_any wants doubles m.cpp_ref_any(m.fixed_r()) - assert 'Unable to cast Python instance' in str(excinfo.value) - assert m.cpp_ref_any(m.ReturnTester.create()) == 1. + assert "Unable to cast Python instance" in str(excinfo.value) + assert m.cpp_ref_any(m.ReturnTester.create()) == 1.0 - assert m.cpp_ref_any(m.get_cm_ref()) == 21. - assert m.cpp_ref_any(m.get_cm_ref()) == 21. + assert m.cpp_ref_any(m.get_cm_ref()) == 21.0 + assert m.cpp_ref_any(m.get_cm_ref()) == 21.0 def test_pass_readonly_array(): @@ -149,7 +166,7 @@ def test_nonunit_stride_from_python(): # Mutator: m.double_threer(second_row) m.double_threec(second_col) - np.testing.assert_array_equal(counting_mat, [[0., 2, 2], [6, 16, 10], [6, 14, 8]]) + np.testing.assert_array_equal(counting_mat, [[0.0, 2, 2], [6, 16, 10], [6, 14, 8]]) def test_negative_stride_from_python(msg): @@ -178,26 +195,36 @@ def test_negative_stride_from_python(msg): # Mutator: with pytest.raises(TypeError) as excinfo: m.double_threer(second_row) - assert msg(excinfo.value) == """ + assert ( + msg(excinfo.value) + == """ double_threer(): incompatible function arguments. The following argument types are supported: 1. (arg0: numpy.ndarray[numpy.float32[1, 3], flags.writeable]) -> None - Invoked with: """ + repr(np.array([ 5., 4., 3.], dtype='float32')) # noqa: E501 line too long + Invoked with: """ # noqa: E501 line too long + + repr(np.array([5.0, 4.0, 3.0], dtype="float32")) + ) with pytest.raises(TypeError) as excinfo: m.double_threec(second_col) - assert msg(excinfo.value) == """ + assert ( + msg(excinfo.value) + == """ double_threec(): incompatible function arguments. The following argument types are supported: 1. (arg0: numpy.ndarray[numpy.float32[3, 1], flags.writeable]) -> None - Invoked with: """ + repr(np.array([ 7., 4., 1.], dtype='float32')) # noqa: E501 line too long + Invoked with: """ # noqa: E501 line too long + + repr(np.array([7.0, 4.0, 1.0], dtype="float32")) + ) def test_nonunit_stride_to_python(): assert np.all(m.diagonal(ref) == ref.diagonal()) assert np.all(m.diagonal_1(ref) == ref.diagonal(1)) for i in range(-5, 7): - assert np.all(m.diagonal_n(ref, i) == ref.diagonal(i)), "m.diagonal_n({})".format(i) + assert np.all( + m.diagonal_n(ref, i) == ref.diagonal(i) + ), "m.diagonal_n({})".format(i) assert np.all(m.block(ref, 2, 1, 3, 3) == ref[2:5, 1:4]) assert np.all(m.block(ref, 1, 4, 4, 2) == ref[1:, 4:]) @@ -207,8 +234,10 @@ def test_nonunit_stride_to_python(): def test_eigen_ref_to_python(): chols = [m.cholesky1, m.cholesky2, m.cholesky3, m.cholesky4] for i, chol in enumerate(chols, start=1): - mymat = chol(np.array([[1., 2, 4], [2, 13, 23], [4, 23, 77]])) - assert np.all(mymat == np.array([[1, 0, 0], [2, 3, 0], [4, 5, 6]])), "cholesky{}".format(i) + mymat = chol(np.array([[1.0, 2, 4], [2, 13, 23], [4, 23, 77]])) + assert np.all( + mymat == np.array([[1, 0, 0], [2, 3, 0], [4, 5, 6]]) + ), "cholesky{}".format(i) def assign_both(a1, a2, r, c, v): @@ -325,8 +354,12 @@ def test_eigen_return_references(): np.testing.assert_array_equal(a_block1, master[3:5, 3:5]) np.testing.assert_array_equal(a_block2, master[2:5, 2:4]) np.testing.assert_array_equal(a_block3, master[6:10, 7:10]) - np.testing.assert_array_equal(a_corn1, master[0::master.shape[0] - 1, 0::master.shape[1] - 1]) - np.testing.assert_array_equal(a_corn2, master[0::master.shape[0] - 1, 0::master.shape[1] - 1]) + np.testing.assert_array_equal( + a_corn1, master[0 :: master.shape[0] - 1, 0 :: master.shape[1] - 1] + ) + np.testing.assert_array_equal( + a_corn2, master[0 :: master.shape[0] - 1, 0 :: master.shape[1] - 1] + ) np.testing.assert_array_equal(a_copy1, c1want) np.testing.assert_array_equal(a_copy2, c2want) @@ -355,16 +388,28 @@ def test_eigen_keepalive(): cstats = ConstructorStats.get(m.ReturnTester) assert cstats.alive() == 1 unsafe = [a.ref(), a.ref_const(), a.block(1, 2, 3, 4)] - copies = [a.copy_get(), a.copy_view(), a.copy_ref(), a.copy_ref_const(), - a.copy_block(4, 3, 2, 1)] + copies = [ + a.copy_get(), + a.copy_view(), + a.copy_ref(), + a.copy_ref_const(), + a.copy_block(4, 3, 2, 1), + ] del a assert cstats.alive() == 0 del unsafe del copies - for meth in [m.ReturnTester.get, m.ReturnTester.get_ptr, m.ReturnTester.view, - m.ReturnTester.view_ptr, m.ReturnTester.ref_safe, m.ReturnTester.ref_const_safe, - m.ReturnTester.corners, m.ReturnTester.corners_const]: + for meth in [ + m.ReturnTester.get, + m.ReturnTester.get_ptr, + m.ReturnTester.view, + m.ReturnTester.view_ptr, + m.ReturnTester.ref_safe, + m.ReturnTester.ref_const_safe, + m.ReturnTester.corners, + m.ReturnTester.corners_const, + ]: assert_keeps_alive(m.ReturnTester, meth) for meth in [m.ReturnTester.block_safe, m.ReturnTester.block_const]: @@ -374,18 +419,18 @@ def test_eigen_keepalive(): def test_eigen_ref_mutators(): """Tests Eigen's ability to mutate numpy values""" - orig = np.array([[1., 2, 3], [4, 5, 6], [7, 8, 9]]) + orig = np.array([[1.0, 2, 3], [4, 5, 6], [7, 8, 9]]) zr = np.array(orig) - zc = np.array(orig, order='F') + zc = np.array(orig, order="F") m.add_rm(zr, 1, 0, 100) - assert np.all(zr == np.array([[1., 2, 3], [104, 5, 6], [7, 8, 9]])) + assert np.all(zr == np.array([[1.0, 2, 3], [104, 5, 6], [7, 8, 9]])) m.add_cm(zc, 1, 0, 200) - assert np.all(zc == np.array([[1., 2, 3], [204, 5, 6], [7, 8, 9]])) + assert np.all(zc == np.array([[1.0, 2, 3], [204, 5, 6], [7, 8, 9]])) m.add_any(zr, 1, 0, 20) - assert np.all(zr == np.array([[1., 2, 3], [124, 5, 6], [7, 8, 9]])) + assert np.all(zr == np.array([[1.0, 2, 3], [124, 5, 6], [7, 8, 9]])) m.add_any(zc, 1, 0, 10) - assert np.all(zc == np.array([[1., 2, 3], [214, 5, 6], [7, 8, 9]])) + assert np.all(zc == np.array([[1.0, 2, 3], [214, 5, 6], [7, 8, 9]])) # Can't reference a col-major array with a row-major Ref, and vice versa: with pytest.raises(TypeError): @@ -406,8 +451,8 @@ def test_eigen_ref_mutators(): cornersr = zr[0::2, 0::2] cornersc = zc[0::2, 0::2] - assert np.all(cornersr == np.array([[1., 3], [7, 9]])) - assert np.all(cornersc == np.array([[1., 3], [7, 9]])) + assert np.all(cornersr == np.array([[1.0, 3], [7, 9]])) + assert np.all(cornersc == np.array([[1.0, 3], [7, 9]])) with pytest.raises(TypeError): m.add_rm(cornersr, 0, 1, 25) @@ -419,8 +464,8 @@ def test_eigen_ref_mutators(): m.add_cm(cornersc, 0, 1, 25) m.add_any(cornersr, 0, 1, 25) m.add_any(cornersc, 0, 1, 44) - assert np.all(zr == np.array([[1., 2, 28], [4, 5, 6], [7, 8, 9]])) - assert np.all(zc == np.array([[1., 2, 47], [4, 5, 6], [7, 8, 9]])) + assert np.all(zr == np.array([[1.0, 2, 28], [4, 5, 6], [7, 8, 9]])) + assert np.all(zc == np.array([[1.0, 2, 47], [4, 5, 6], [7, 8, 9]])) # You shouldn't be allowed to pass a non-writeable array to a mutating Eigen method: zro = zr[0:4, 0:4] @@ -458,7 +503,7 @@ def test_numpy_ref_mutators(): assert not zrro.flags.owndata and not zrro.flags.writeable zc[1, 2] = 99 - expect = np.array([[11., 12, 13], [21, 22, 99], [31, 32, 33]]) + expect = np.array([[11.0, 12, 13], [21, 22, 99], [31, 32, 33]]) # We should have just changed zc, of course, but also zcro and the original eigen matrix assert np.all(zc == expect) assert np.all(zcro == expect) @@ -506,18 +551,20 @@ def test_both_ref_mutators(): assert np.all(z == z3) assert np.all(z == z4) assert np.all(z == z5) - expect = np.array([[0., 22, 20], [31, 37, 33], [41, 42, 38]]) + expect = np.array([[0.0, 22, 20], [31, 37, 33], [41, 42, 38]]) assert np.all(z == expect) - y = np.array(range(100), dtype='float64').reshape(10, 10) + y = np.array(range(100), dtype="float64").reshape(10, 10) y2 = m.incr_matrix_any(y, 10) # np -> eigen -> np - y3 = m.incr_matrix_any(y2[0::2, 0::2], -33) # np -> eigen -> np slice -> np -> eigen -> np + y3 = m.incr_matrix_any( + y2[0::2, 0::2], -33 + ) # np -> eigen -> np slice -> np -> eigen -> np y4 = m.even_rows(y3) # numpy -> eigen slice -> (... y3) y5 = m.even_cols(y4) # numpy -> eigen slice -> (... y4) y6 = m.incr_matrix_any(y5, 1000) # numpy -> eigen -> (... y5) # Apply same mutations using just numpy: - yexpect = np.array(range(100), dtype='float64').reshape(10, 10) + yexpect = np.array(range(100), dtype="float64").reshape(10, 10) yexpect += 10 yexpect[0::2, 0::2] -= 33 yexpect[0::4, 0::4] += 1000 @@ -532,10 +579,14 @@ def test_both_ref_mutators(): def test_nocopy_wrapper(): # get_elem requires a column-contiguous matrix reference, but should be # callable with other types of matrix (via copying): - int_matrix_colmajor = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], order='F') - dbl_matrix_colmajor = np.array(int_matrix_colmajor, dtype='double', order='F', copy=True) - int_matrix_rowmajor = np.array(int_matrix_colmajor, order='C', copy=True) - dbl_matrix_rowmajor = np.array(int_matrix_rowmajor, dtype='double', order='C', copy=True) + int_matrix_colmajor = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], order="F") + dbl_matrix_colmajor = np.array( + int_matrix_colmajor, dtype="double", order="F", copy=True + ) + int_matrix_rowmajor = np.array(int_matrix_colmajor, order="C", copy=True) + dbl_matrix_rowmajor = np.array( + int_matrix_rowmajor, dtype="double", order="C", copy=True + ) # All should be callable via get_elem: assert m.get_elem(int_matrix_colmajor) == 8 @@ -546,32 +597,38 @@ def test_nocopy_wrapper(): # All but the second should fail with m.get_elem_nocopy: with pytest.raises(TypeError) as excinfo: m.get_elem_nocopy(int_matrix_colmajor) - assert ('get_elem_nocopy(): incompatible function arguments.' in str(excinfo.value) and - ', flags.f_contiguous' in str(excinfo.value)) + assert "get_elem_nocopy(): incompatible function arguments." in str( + excinfo.value + ) and ", flags.f_contiguous" in str(excinfo.value) assert m.get_elem_nocopy(dbl_matrix_colmajor) == 8 with pytest.raises(TypeError) as excinfo: m.get_elem_nocopy(int_matrix_rowmajor) - assert ('get_elem_nocopy(): incompatible function arguments.' in str(excinfo.value) and - ', flags.f_contiguous' in str(excinfo.value)) + assert "get_elem_nocopy(): incompatible function arguments." in str( + excinfo.value + ) and ", flags.f_contiguous" in str(excinfo.value) with pytest.raises(TypeError) as excinfo: m.get_elem_nocopy(dbl_matrix_rowmajor) - assert ('get_elem_nocopy(): incompatible function arguments.' in str(excinfo.value) and - ', flags.f_contiguous' in str(excinfo.value)) + assert "get_elem_nocopy(): incompatible function arguments." in str( + excinfo.value + ) and ", flags.f_contiguous" in str(excinfo.value) # For the row-major test, we take a long matrix in row-major, so only the third is allowed: with pytest.raises(TypeError) as excinfo: m.get_elem_rm_nocopy(int_matrix_colmajor) - assert ('get_elem_rm_nocopy(): incompatible function arguments.' in str(excinfo.value) and - ', flags.c_contiguous' in str(excinfo.value)) + assert "get_elem_rm_nocopy(): incompatible function arguments." in str( + excinfo.value + ) and ", flags.c_contiguous" in str(excinfo.value) with pytest.raises(TypeError) as excinfo: m.get_elem_rm_nocopy(dbl_matrix_colmajor) - assert ('get_elem_rm_nocopy(): incompatible function arguments.' in str(excinfo.value) and - ', flags.c_contiguous' in str(excinfo.value)) + assert "get_elem_rm_nocopy(): incompatible function arguments." in str( + excinfo.value + ) and ", flags.c_contiguous" in str(excinfo.value) assert m.get_elem_rm_nocopy(int_matrix_rowmajor) == 8 with pytest.raises(TypeError) as excinfo: m.get_elem_rm_nocopy(dbl_matrix_rowmajor) - assert ('get_elem_rm_nocopy(): incompatible function arguments.' in str(excinfo.value) and - ', flags.c_contiguous' in str(excinfo.value)) + assert "get_elem_rm_nocopy(): incompatible function arguments." in str( + excinfo.value + ) and ", flags.c_contiguous" in str(excinfo.value) def test_eigen_ref_life_support(): @@ -589,12 +646,9 @@ def test_eigen_ref_life_support(): def test_special_matrix_objects(): - assert np.all(m.incr_diag(7) == np.diag([1., 2, 3, 4, 5, 6, 7])) + assert np.all(m.incr_diag(7) == np.diag([1.0, 2, 3, 4, 5, 6, 7])) - asymm = np.array([[ 1., 2, 3, 4], - [ 5, 6, 7, 8], - [ 9, 10, 11, 12], - [13, 14, 15, 16]]) + asymm = np.array([[1.0, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]) symm_lower = np.array(asymm) symm_upper = np.array(asymm) for i in range(4): @@ -607,41 +661,51 @@ def test_special_matrix_objects(): def test_dense_signature(doc): - assert doc(m.double_col) == """ + assert ( + doc(m.double_col) + == """ double_col(arg0: numpy.ndarray[numpy.float32[m, 1]]) -> numpy.ndarray[numpy.float32[m, 1]] """ - assert doc(m.double_row) == """ + ) + assert ( + doc(m.double_row) + == """ double_row(arg0: numpy.ndarray[numpy.float32[1, n]]) -> numpy.ndarray[numpy.float32[1, n]] """ - assert doc(m.double_complex) == (""" + ) + assert doc(m.double_complex) == ( + """ double_complex(arg0: numpy.ndarray[numpy.complex64[m, 1]])""" - """ -> numpy.ndarray[numpy.complex64[m, 1]] - """) - assert doc(m.double_mat_rm) == (""" + """ -> numpy.ndarray[numpy.complex64[m, 1]] + """ + ) + assert doc(m.double_mat_rm) == ( + """ double_mat_rm(arg0: numpy.ndarray[numpy.float32[m, n]])""" - """ -> numpy.ndarray[numpy.float32[m, n]] - """) + """ -> numpy.ndarray[numpy.float32[m, n]] + """ + ) def test_named_arguments(): a = np.array([[1.0, 2], [3, 4], [5, 6]]) b = np.ones((2, 1)) - assert np.all(m.matrix_multiply(a, b) == np.array([[3.], [7], [11]])) - assert np.all(m.matrix_multiply(A=a, B=b) == np.array([[3.], [7], [11]])) - assert np.all(m.matrix_multiply(B=b, A=a) == np.array([[3.], [7], [11]])) + assert np.all(m.matrix_multiply(a, b) == np.array([[3.0], [7], [11]])) + assert np.all(m.matrix_multiply(A=a, B=b) == np.array([[3.0], [7], [11]])) + assert np.all(m.matrix_multiply(B=b, A=a) == np.array([[3.0], [7], [11]])) with pytest.raises(ValueError) as excinfo: m.matrix_multiply(b, a) - assert str(excinfo.value) == 'Nonconformable matrices!' + assert str(excinfo.value) == "Nonconformable matrices!" with pytest.raises(ValueError) as excinfo: m.matrix_multiply(A=b, B=a) - assert str(excinfo.value) == 'Nonconformable matrices!' + assert str(excinfo.value) == "Nonconformable matrices!" with pytest.raises(ValueError) as excinfo: m.matrix_multiply(B=a, A=b) - assert str(excinfo.value) == 'Nonconformable matrices!' + assert str(excinfo.value) == "Nonconformable matrices!" def test_sparse(): @@ -656,21 +720,31 @@ def test_sparse(): def test_sparse_signature(doc): pytest.importorskip("scipy") - assert doc(m.sparse_copy_r) == """ + assert ( + doc(m.sparse_copy_r) + == """ sparse_copy_r(arg0: scipy.sparse.csr_matrix[numpy.float32]) -> scipy.sparse.csr_matrix[numpy.float32] """ # noqa: E501 line too long - assert doc(m.sparse_copy_c) == """ + ) + assert ( + doc(m.sparse_copy_c) + == """ sparse_copy_c(arg0: scipy.sparse.csc_matrix[numpy.float32]) -> scipy.sparse.csc_matrix[numpy.float32] """ # noqa: E501 line too long + ) def test_issue738(): """Ignore strides on a length-1 dimension (even if they would be incompatible length > 1)""" - assert np.all(m.iss738_f1(np.array([[1., 2, 3]])) == np.array([[1., 102, 203]])) - assert np.all(m.iss738_f1(np.array([[1.], [2], [3]])) == np.array([[1.], [12], [23]])) + assert np.all(m.iss738_f1(np.array([[1.0, 2, 3]])) == np.array([[1.0, 102, 203]])) + assert np.all( + m.iss738_f1(np.array([[1.0], [2], [3]])) == np.array([[1.0], [12], [23]]) + ) - assert np.all(m.iss738_f2(np.array([[1., 2, 3]])) == np.array([[1., 102, 203]])) - assert np.all(m.iss738_f2(np.array([[1.], [2], [3]])) == np.array([[1.], [12], [23]])) + assert np.all(m.iss738_f2(np.array([[1.0, 2, 3]])) == np.array([[1.0, 102, 203]])) + assert np.all( + m.iss738_f2(np.array([[1.0], [2], [3]])) == np.array([[1.0], [12], [23]]) + ) def test_issue1105(): diff --git a/wrap/pybind11/tests/test_embed/CMakeLists.txt b/wrap/pybind11/tests/test_embed/CMakeLists.txt index 2e298fa7e..edb8961a7 100644 --- a/wrap/pybind11/tests/test_embed/CMakeLists.txt +++ b/wrap/pybind11/tests/test_embed/CMakeLists.txt @@ -1,10 +1,13 @@ +possibly_uninitialized(PYTHON_MODULE_EXTENSION Python_INTERPRETER_ID) + if("${PYTHON_MODULE_EXTENSION}" MATCHES "pypy" OR "${Python_INTERPRETER_ID}" STREQUAL "PyPy") + message(STATUS "Skipping embed test on PyPy") add_custom_target(cpptest) # Dummy target on PyPy. Embedding is not supported. set(_suppress_unused_variable_warning "${DOWNLOAD_CATCH}") return() endif() -find_package(Catch 2.13.0) +find_package(Catch 2.13.2) if(CATCH_FOUND) message(STATUS "Building interpreter tests using Catch v${CATCH_VERSION}") @@ -22,12 +25,13 @@ pybind11_enable_warnings(test_embed) target_link_libraries(test_embed PRIVATE pybind11::embed Catch2::Catch2 Threads::Threads) if(NOT CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_CURRENT_BINARY_DIR) - file(COPY test_interpreter.py DESTINATION "${CMAKE_CURRENT_BINARY_DIR}") + file(COPY test_interpreter.py test_trampoline.py DESTINATION "${CMAKE_CURRENT_BINARY_DIR}") endif() add_custom_target( cpptest COMMAND "$" + DEPENDS test_embed WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}") pybind11_add_module(external_module THIN_LTO external_module.cpp) diff --git a/wrap/pybind11/tests/test_embed/external_module.cpp b/wrap/pybind11/tests/test_embed/external_module.cpp index e9a6058b1..490952299 100644 --- a/wrap/pybind11/tests/test_embed/external_module.cpp +++ b/wrap/pybind11/tests/test_embed/external_module.cpp @@ -9,7 +9,7 @@ namespace py = pybind11; PYBIND11_MODULE(external_module, m) { class A { public: - A(int value) : v{value} {}; + explicit A(int value) : v{value} {}; int v; }; diff --git a/wrap/pybind11/tests/test_embed/test_interpreter.cpp b/wrap/pybind11/tests/test_embed/test_interpreter.cpp index 753ce54dc..508975eb3 100644 --- a/wrap/pybind11/tests/test_embed/test_interpreter.cpp +++ b/wrap/pybind11/tests/test_embed/test_interpreter.cpp @@ -8,20 +8,23 @@ #include -#include +#include #include #include +#include +#include namespace py = pybind11; using namespace py::literals; class Widget { public: - Widget(std::string message) : message(message) { } + explicit Widget(std::string message) : message(std::move(message)) {} virtual ~Widget() = default; std::string the_message() const { return message; } virtual int the_answer() const = 0; + virtual std::string argv0() const = 0; private: std::string message; @@ -31,6 +34,23 @@ class PyWidget final : public Widget { using Widget::Widget; int the_answer() const override { PYBIND11_OVERRIDE_PURE(int, Widget, the_answer); } + std::string argv0() const override { PYBIND11_OVERRIDE_PURE(std::string, Widget, argv0); } +}; + +class test_override_cache_helper { + +public: + virtual int func() { return 0; } + + test_override_cache_helper() = default; + virtual ~test_override_cache_helper() = default; + // Non-copyable + test_override_cache_helper &operator=(test_override_cache_helper const &Right) = delete; + test_override_cache_helper(test_override_cache_helper const &Copy) = delete; +}; + +class test_override_cache_helper_trampoline : public test_override_cache_helper { + int func() override { PYBIND11_OVERRIDE(int, test_override_cache_helper, func); } }; PYBIND11_EMBEDDED_MODULE(widget_module, m) { @@ -41,6 +61,12 @@ PYBIND11_EMBEDDED_MODULE(widget_module, m) { m.def("add", [](int i, int j) { return i + j; }); } +PYBIND11_EMBEDDED_MODULE(trampoline_module, m) { + py::class_>(m, "test_override_cache_helper") + .def(py::init_alias<>()) + .def("func", &test_override_cache_helper::func); +} + PYBIND11_EMBEDDED_MODULE(throw_exception, ) { throw std::runtime_error("C++ Error"); } @@ -51,17 +77,17 @@ PYBIND11_EMBEDDED_MODULE(throw_error_already_set, ) { } TEST_CASE("Pass classes and data between modules defined in C++ and Python") { - auto module = py::module::import("test_interpreter"); - REQUIRE(py::hasattr(module, "DerivedWidget")); + auto module_ = py::module_::import("test_interpreter"); + REQUIRE(py::hasattr(module_, "DerivedWidget")); - auto locals = py::dict("hello"_a="Hello, World!", "x"_a=5, **module.attr("__dict__")); + auto locals = py::dict("hello"_a="Hello, World!", "x"_a=5, **module_.attr("__dict__")); py::exec(R"( widget = DerivedWidget("{} - {}".format(hello, x)) message = widget.the_message )", py::globals(), locals); REQUIRE(locals["message"].cast() == "Hello, World! - 5"); - auto py_widget = module.attr("DerivedWidget")("The question"); + auto py_widget = module_.attr("DerivedWidget")("The question"); auto message = py_widget.attr("the_message"); REQUIRE(message.cast() == "The question"); @@ -69,12 +95,55 @@ TEST_CASE("Pass classes and data between modules defined in C++ and Python") { REQUIRE(cpp_widget.the_answer() == 42); } +TEST_CASE("Override cache") { + auto module_ = py::module_::import("test_trampoline"); + REQUIRE(py::hasattr(module_, "func")); + REQUIRE(py::hasattr(module_, "func2")); + + auto locals = py::dict(**module_.attr("__dict__")); + + int i = 0; + for (; i < 1500; ++i) { + std::shared_ptr p_obj; + std::shared_ptr p_obj2; + + py::object loc_inst = locals["func"](); + p_obj = py::cast>(loc_inst); + + int ret = p_obj->func(); + + REQUIRE(ret == 42); + + loc_inst = locals["func2"](); + + p_obj2 = py::cast>(loc_inst); + + p_obj2->func(); + } +} + TEST_CASE("Import error handling") { - REQUIRE_NOTHROW(py::module::import("widget_module")); - REQUIRE_THROWS_WITH(py::module::import("throw_exception"), + REQUIRE_NOTHROW(py::module_::import("widget_module")); + REQUIRE_THROWS_WITH(py::module_::import("throw_exception"), "ImportError: C++ Error"); - REQUIRE_THROWS_WITH(py::module::import("throw_error_already_set"), +#if PY_VERSION_HEX >= 0x03030000 + REQUIRE_THROWS_WITH(py::module_::import("throw_error_already_set"), + Catch::Contains("ImportError: initialization failed")); + + auto locals = py::dict("is_keyerror"_a=false, "message"_a="not set"); + py::exec(R"( + try: + import throw_error_already_set + except ImportError as e: + is_keyerror = type(e.__cause__) == KeyError + message = str(e.__cause__) + )", py::globals(), locals); + REQUIRE(locals["is_keyerror"].cast() == true); + REQUIRE(locals["message"].cast() == "'missing'"); +#else + REQUIRE_THROWS_WITH(py::module_::import("throw_error_already_set"), Catch::Contains("ImportError: KeyError")); +#endif } TEST_CASE("There can be only one interpreter") { @@ -102,19 +171,19 @@ bool has_pybind11_internals_builtin() { bool has_pybind11_internals_static() { auto **&ipp = py::detail::get_internals_pp(); - return ipp && *ipp; + return (ipp != nullptr) && (*ipp != nullptr); } TEST_CASE("Restart the interpreter") { // Verify pre-restart state. - REQUIRE(py::module::import("widget_module").attr("add")(1, 2).cast() == 3); + REQUIRE(py::module_::import("widget_module").attr("add")(1, 2).cast() == 3); REQUIRE(has_pybind11_internals_builtin()); REQUIRE(has_pybind11_internals_static()); - REQUIRE(py::module::import("external_module").attr("A")(123).attr("value").cast() == 123); + REQUIRE(py::module_::import("external_module").attr("A")(123).attr("value").cast() == 123); // local and foreign module internals should point to the same internals: REQUIRE(reinterpret_cast(*py::detail::get_internals_pp()) == - py::module::import("external_module").attr("internals_at")().cast()); + py::module_::import("external_module").attr("internals_at")().cast()); // Restart the interpreter. py::finalize_interpreter(); @@ -130,14 +199,14 @@ TEST_CASE("Restart the interpreter") { REQUIRE(has_pybind11_internals_builtin()); REQUIRE(has_pybind11_internals_static()); REQUIRE(reinterpret_cast(*py::detail::get_internals_pp()) == - py::module::import("external_module").attr("internals_at")().cast()); + py::module_::import("external_module").attr("internals_at")().cast()); // Make sure that an interpreter with no get_internals() created until finalize still gets the // internals destroyed py::finalize_interpreter(); py::initialize_interpreter(); bool ran = false; - py::module::import("__main__").attr("internals_destroy_test") = + py::module_::import("__main__").attr("internals_destroy_test") = py::capsule(&ran, [](void *ran) { py::detail::get_internals(); *static_cast(ran) = true; }); REQUIRE_FALSE(has_pybind11_internals_builtin()); REQUIRE_FALSE(has_pybind11_internals_static()); @@ -149,20 +218,20 @@ TEST_CASE("Restart the interpreter") { REQUIRE_FALSE(has_pybind11_internals_static()); // C++ modules can be reloaded. - auto cpp_module = py::module::import("widget_module"); + auto cpp_module = py::module_::import("widget_module"); REQUIRE(cpp_module.attr("add")(1, 2).cast() == 3); // C++ type information is reloaded and can be used in python modules. - auto py_module = py::module::import("test_interpreter"); + auto py_module = py::module_::import("test_interpreter"); auto py_widget = py_module.attr("DerivedWidget")("Hello after restart"); REQUIRE(py_widget.attr("the_message").cast() == "Hello after restart"); } TEST_CASE("Subinterpreter") { // Add tags to the modules in the main interpreter and test the basics. - py::module::import("__main__").attr("main_tag") = "main interpreter"; + py::module_::import("__main__").attr("main_tag") = "main interpreter"; { - auto m = py::module::import("widget_module"); + auto m = py::module_::import("widget_module"); m.attr("extension_module_tag") = "added to module in main interpreter"; REQUIRE(m.attr("add")(1, 2).cast() == 3); @@ -181,9 +250,9 @@ TEST_CASE("Subinterpreter") { REQUIRE(has_pybind11_internals_static()); // Modules tags should be gone. - REQUIRE_FALSE(py::hasattr(py::module::import("__main__"), "tag")); + REQUIRE_FALSE(py::hasattr(py::module_::import("__main__"), "tag")); { - auto m = py::module::import("widget_module"); + auto m = py::module_::import("widget_module"); REQUIRE_FALSE(py::hasattr(m, "extension_module_tag")); // Function bindings should still work. @@ -194,8 +263,8 @@ TEST_CASE("Subinterpreter") { Py_EndInterpreter(sub_tstate); PyThreadState_Swap(main_tstate); - REQUIRE(py::hasattr(py::module::import("__main__"), "main_tag")); - REQUIRE(py::hasattr(py::module::import("widget_module"), "extension_module_tag")); + REQUIRE(py::hasattr(py::module_::import("__main__"), "main_tag")); + REQUIRE(py::hasattr(py::module_::import("widget_module"), "extension_module_tag")); } TEST_CASE("Execution frame") { @@ -245,7 +314,7 @@ TEST_CASE("Reload module from file") { // Disable generation of cached bytecode (.pyc files) for this test, otherwise // Python might pick up an old version from the cache instead of the new versions // of the .py files generated below - auto sys = py::module::import("sys"); + auto sys = py::module_::import("sys"); bool dont_write_bytecode = sys.attr("dont_write_bytecode").cast(); sys.attr("dont_write_bytecode") = true; // Reset the value at scope exit @@ -267,8 +336,8 @@ TEST_CASE("Reload module from file") { }); // Import the module from file - auto module = py::module::import(module_name.c_str()); - int result = module.attr("test")().cast(); + auto module_ = py::module_::import(module_name.c_str()); + int result = module_.attr("test")().cast(); REQUIRE(result == 1); // Update the module .py file with a small change @@ -278,7 +347,29 @@ TEST_CASE("Reload module from file") { test_module.close(); // Reload the module - module.reload(); - result = module.attr("test")().cast(); + module_.reload(); + result = module_.attr("test")().cast(); REQUIRE(result == 2); } + +TEST_CASE("sys.argv gets initialized properly") { + py::finalize_interpreter(); + { + py::scoped_interpreter default_scope; + auto module = py::module::import("test_interpreter"); + auto py_widget = module.attr("DerivedWidget")("The question"); + const auto &cpp_widget = py_widget.cast(); + REQUIRE(cpp_widget.argv0().empty()); + } + + { + char *argv[] = {strdup("a.out")}; + py::scoped_interpreter argv_scope(true, 1, argv); + std::free(argv[0]); + auto module = py::module::import("test_interpreter"); + auto py_widget = module.attr("DerivedWidget")("The question"); + const auto &cpp_widget = py_widget.cast(); + REQUIRE(cpp_widget.argv0() == "a.out"); + } + py::initialize_interpreter(); +} diff --git a/wrap/pybind11/tests/test_embed/test_interpreter.py b/wrap/pybind11/tests/test_embed/test_interpreter.py index 6174ede44..5ab55a4b3 100644 --- a/wrap/pybind11/tests/test_embed/test_interpreter.py +++ b/wrap/pybind11/tests/test_embed/test_interpreter.py @@ -1,4 +1,6 @@ # -*- coding: utf-8 -*- +import sys + from widget_module import Widget @@ -8,3 +10,6 @@ class DerivedWidget(Widget): def the_answer(self): return 42 + + def argv0(self): + return sys.argv[0] diff --git a/wrap/pybind11/tests/test_embed/test_trampoline.py b/wrap/pybind11/tests/test_embed/test_trampoline.py new file mode 100644 index 000000000..87c8fa44c --- /dev/null +++ b/wrap/pybind11/tests/test_embed/test_trampoline.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- + +import trampoline_module + + +def func(): + class Test(trampoline_module.test_override_cache_helper): + def func(self): + return 42 + + return Test() + + +def func2(): + class Test(trampoline_module.test_override_cache_helper): + pass + + return Test() diff --git a/wrap/pybind11/tests/test_enum.cpp b/wrap/pybind11/tests/test_enum.cpp index 315308920..40c48d412 100644 --- a/wrap/pybind11/tests/test_enum.cpp +++ b/wrap/pybind11/tests/test_enum.cpp @@ -84,4 +84,65 @@ TEST_SUBMODULE(enums, m) { .value("ONE", SimpleEnum::THREE) .export_values(); }); + + // test_enum_scalar + enum UnscopedUCharEnum : unsigned char {}; + enum class ScopedShortEnum : short {}; + enum class ScopedLongEnum : long {}; + enum UnscopedUInt64Enum : std::uint64_t {}; + static_assert(py::detail::all_of< + std::is_same::Scalar, unsigned char>, + std::is_same::Scalar, short>, + std::is_same::Scalar, long>, + std::is_same::Scalar, std::uint64_t> + >::value, "Error during the deduction of enum's scalar type with normal integer underlying"); + + // test_enum_scalar_with_char_underlying + enum class ScopedCharEnum : char { Zero, Positive }; + enum class ScopedWCharEnum : wchar_t { Zero, Positive }; + enum class ScopedChar32Enum : char32_t { Zero, Positive }; + enum class ScopedChar16Enum : char16_t { Zero, Positive }; + + // test the scalar of char type enums according to chapter 'Character types' + // from https://en.cppreference.com/w/cpp/language/types + static_assert(py::detail::any_of< + std::is_same::Scalar, signed char>, // e.g. gcc on x86 + std::is_same::Scalar, unsigned char> // e.g. arm linux + >::value, "char should be cast to either signed char or unsigned char"); + static_assert( + sizeof(py::enum_::Scalar) == 2 || + sizeof(py::enum_::Scalar) == 4 + , "wchar_t should be either 16 bits (Windows) or 32 (everywhere else)"); + static_assert(py::detail::all_of< + std::is_same::Scalar, std::uint_least32_t>, + std::is_same::Scalar, std::uint_least16_t> + >::value, "char32_t, char16_t (and char8_t)'s size, signedness, and alignment is determined"); +#if defined(PYBIND11_HAS_U8STRING) + enum class ScopedChar8Enum : char8_t { Zero, Positive }; + static_assert(std::is_same::Scalar, unsigned char>::value); +#endif + + // test_char_underlying_enum + py::enum_(m, "ScopedCharEnum") + .value("Zero", ScopedCharEnum::Zero) + .value("Positive", ScopedCharEnum::Positive); + py::enum_(m, "ScopedWCharEnum") + .value("Zero", ScopedWCharEnum::Zero) + .value("Positive", ScopedWCharEnum::Positive); + py::enum_(m, "ScopedChar32Enum") + .value("Zero", ScopedChar32Enum::Zero) + .value("Positive", ScopedChar32Enum::Positive); + py::enum_(m, "ScopedChar16Enum") + .value("Zero", ScopedChar16Enum::Zero) + .value("Positive", ScopedChar16Enum::Positive); + + // test_bool_underlying_enum + enum class ScopedBoolEnum : bool { FALSE, TRUE }; + + // bool is unsigned (std::is_signed returns false) and 1-byte long, so represented with u8 + static_assert(std::is_same::Scalar, std::uint8_t>::value, ""); + + py::enum_(m, "ScopedBoolEnum") + .value("FALSE", ScopedBoolEnum::FALSE) + .value("TRUE", ScopedBoolEnum::TRUE); } diff --git a/wrap/pybind11/tests/test_enum.py b/wrap/pybind11/tests/test_enum.py index bfaa193e9..14c754e72 100644 --- a/wrap/pybind11/tests/test_enum.py +++ b/wrap/pybind11/tests/test_enum.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- import pytest + +import env from pybind11_tests import enums as m @@ -7,32 +9,50 @@ def test_unscoped_enum(): assert str(m.UnscopedEnum.EOne) == "UnscopedEnum.EOne" assert str(m.UnscopedEnum.ETwo) == "UnscopedEnum.ETwo" assert str(m.EOne) == "UnscopedEnum.EOne" + assert repr(m.UnscopedEnum.EOne) == "" + assert repr(m.UnscopedEnum.ETwo) == "" + assert repr(m.EOne) == "" # name property assert m.UnscopedEnum.EOne.name == "EOne" + assert m.UnscopedEnum.EOne.value == 1 assert m.UnscopedEnum.ETwo.name == "ETwo" - assert m.EOne.name == "EOne" - # name readonly + assert m.UnscopedEnum.ETwo.value == 2 + assert m.EOne is m.UnscopedEnum.EOne + # name, value readonly with pytest.raises(AttributeError): m.UnscopedEnum.EOne.name = "" - # name returns a copy - foo = m.UnscopedEnum.EOne.name - foo = "bar" + with pytest.raises(AttributeError): + m.UnscopedEnum.EOne.value = 10 + # name, value returns a copy + # TODO: Neither the name nor value tests actually check against aliasing. + # Use a mutable type that has reference semantics. + nonaliased_name = m.UnscopedEnum.EOne.name + nonaliased_name = "bar" # noqa: F841 assert m.UnscopedEnum.EOne.name == "EOne" + nonaliased_value = m.UnscopedEnum.EOne.value + nonaliased_value = 10 # noqa: F841 + assert m.UnscopedEnum.EOne.value == 1 # __members__ property - assert m.UnscopedEnum.__members__ == \ - {"EOne": m.UnscopedEnum.EOne, "ETwo": m.UnscopedEnum.ETwo, "EThree": m.UnscopedEnum.EThree} + assert m.UnscopedEnum.__members__ == { + "EOne": m.UnscopedEnum.EOne, + "ETwo": m.UnscopedEnum.ETwo, + "EThree": m.UnscopedEnum.EThree, + } # __members__ readonly with pytest.raises(AttributeError): m.UnscopedEnum.__members__ = {} # __members__ returns a copy - foo = m.UnscopedEnum.__members__ - foo["bar"] = "baz" - assert m.UnscopedEnum.__members__ == \ - {"EOne": m.UnscopedEnum.EOne, "ETwo": m.UnscopedEnum.ETwo, "EThree": m.UnscopedEnum.EThree} + nonaliased_members = m.UnscopedEnum.__members__ + nonaliased_members["bar"] = "baz" + assert m.UnscopedEnum.__members__ == { + "EOne": m.UnscopedEnum.EOne, + "ETwo": m.UnscopedEnum.ETwo, + "EThree": m.UnscopedEnum.EThree, + } - for docstring_line in '''An unscoped enumeration + for docstring_line in """An unscoped enumeration Members: @@ -40,7 +60,9 @@ Members: ETwo : Docstring for ETwo - EThree : Docstring for EThree'''.split('\n'): + EThree : Docstring for EThree""".split( + "\n" + ): assert docstring_line in m.UnscopedEnum.__doc__ # Unscoped enums will accept ==/!= int comparisons @@ -50,10 +72,10 @@ Members: assert y != 3 assert 3 != y # Compare with None - assert (y != None) # noqa: E711 + assert y != None # noqa: E711 assert not (y == None) # noqa: E711 # Compare with an object - assert (y != object()) + assert y != object() assert not (y == object()) # Compare with string assert y != "2" @@ -62,16 +84,16 @@ Members: assert not (y == "2") with pytest.raises(TypeError): - y < object() + y < object() # noqa: B015 with pytest.raises(TypeError): - y <= object() + y <= object() # noqa: B015 with pytest.raises(TypeError): - y > object() + y > object() # noqa: B015 with pytest.raises(TypeError): - y >= object() + y >= object() # noqa: B015 with pytest.raises(TypeError): y | object() @@ -116,20 +138,20 @@ def test_scoped_enum(): assert z != 3 assert 3 != z # Compare with None - assert (z != None) # noqa: E711 + assert z != None # noqa: E711 assert not (z == None) # noqa: E711 # Compare with an object - assert (z != object()) + assert z != object() assert not (z == object()) # Scoped enums will *NOT* accept >, <, >= and <= int comparisons (Will throw exceptions) with pytest.raises(TypeError): - z > 3 + z > 3 # noqa: B015 with pytest.raises(TypeError): - z < 3 + z < 3 # noqa: B015 with pytest.raises(TypeError): - z >= 3 + z >= 3 # noqa: B015 with pytest.raises(TypeError): - z <= 3 + z <= 3 # noqa: B015 # order assert m.ScopedEnum.Two < m.ScopedEnum.Three @@ -143,6 +165,8 @@ def test_scoped_enum(): def test_implicit_conversion(): assert str(m.ClassWithUnscopedEnum.EMode.EFirstMode) == "EMode.EFirstMode" assert str(m.ClassWithUnscopedEnum.EFirstMode) == "EMode.EFirstMode" + assert repr(m.ClassWithUnscopedEnum.EMode.EFirstMode) == "" + assert repr(m.ClassWithUnscopedEnum.EFirstMode) == "" f = m.ClassWithUnscopedEnum.test_function first = m.ClassWithUnscopedEnum.EFirstMode @@ -167,7 +191,7 @@ def test_implicit_conversion(): x[f(first)] = 3 x[f(second)] = 4 # Hashing test - assert str(x) == "{EMode.EFirstMode: 3, EMode.ESecondMode: 4}" + assert repr(x) == "{: 3, : 4}" def test_binary_operators(): @@ -195,13 +219,54 @@ def test_binary_operators(): def test_enum_to_int(): m.test_enum_to_int(m.Flags.Read) m.test_enum_to_int(m.ClassWithUnscopedEnum.EMode.EFirstMode) + m.test_enum_to_int(m.ScopedCharEnum.Positive) + m.test_enum_to_int(m.ScopedBoolEnum.TRUE) m.test_enum_to_uint(m.Flags.Read) m.test_enum_to_uint(m.ClassWithUnscopedEnum.EMode.EFirstMode) + m.test_enum_to_uint(m.ScopedCharEnum.Positive) + m.test_enum_to_uint(m.ScopedBoolEnum.TRUE) m.test_enum_to_long_long(m.Flags.Read) m.test_enum_to_long_long(m.ClassWithUnscopedEnum.EMode.EFirstMode) + m.test_enum_to_long_long(m.ScopedCharEnum.Positive) + m.test_enum_to_long_long(m.ScopedBoolEnum.TRUE) def test_duplicate_enum_name(): with pytest.raises(ValueError) as excinfo: m.register_bad_enum() assert str(excinfo.value) == 'SimpleEnum: element "ONE" already exists!' + + +def test_char_underlying_enum(): # Issue #1331/PR #1334: + assert type(m.ScopedCharEnum.Positive.__int__()) is int + assert int(m.ScopedChar16Enum.Zero) == 0 + assert hash(m.ScopedChar32Enum.Positive) == 1 + if env.PY2: + assert m.ScopedCharEnum.Positive.__getstate__() == 1 # long + else: + assert type(m.ScopedCharEnum.Positive.__getstate__()) is int + assert m.ScopedWCharEnum(1) == m.ScopedWCharEnum.Positive + with pytest.raises(TypeError): + # Even if the underlying type is char, only an int can be used to construct the enum: + m.ScopedCharEnum("0") + + +def test_bool_underlying_enum(): + assert type(m.ScopedBoolEnum.TRUE.__int__()) is int + assert int(m.ScopedBoolEnum.FALSE) == 0 + assert hash(m.ScopedBoolEnum.TRUE) == 1 + if env.PY2: + assert m.ScopedBoolEnum.TRUE.__getstate__() == 1 # long + else: + assert type(m.ScopedBoolEnum.TRUE.__getstate__()) is int + assert m.ScopedBoolEnum(1) == m.ScopedBoolEnum.TRUE + # Enum could construct with a bool + # (bool is a strict subclass of int, and False will be converted to 0) + assert m.ScopedBoolEnum(False) == m.ScopedBoolEnum.FALSE + + +def test_docstring_signatures(): + for enum_type in [m.ScopedEnum, m.UnscopedEnum]: + for attr in enum_type.__dict__.values(): + # Issue #2623/PR #2637: Add argument names to enum_ methods + assert "arg0" not in (attr.__doc__ or "") diff --git a/wrap/pybind11/tests/test_eval.cpp b/wrap/pybind11/tests/test_eval.cpp index e09482191..29366f679 100644 --- a/wrap/pybind11/tests/test_eval.cpp +++ b/wrap/pybind11/tests/test_eval.cpp @@ -9,12 +9,14 @@ #include + #include "pybind11_tests.h" +#include TEST_SUBMODULE(eval_, m) { // test_evals - auto global = py::dict(py::module::import("__main__").attr("__dict__")); + auto global = py::dict(py::module_::import("__main__").attr("__dict__")); m.def("test_eval_statements", [global]() { auto local = py::dict(); @@ -64,10 +66,10 @@ TEST_SUBMODULE(eval_, m) { auto local = py::dict(); local["y"] = py::int_(43); - int val_out; + int val_out = 0; local["call_test2"] = py::cpp_function([&](int value) { val_out = value; }); - auto result = py::eval_file(filename, global, local); + auto result = py::eval_file(std::move(filename), global, local); return val_out == 43 && result.is_none(); }); @@ -88,4 +90,30 @@ TEST_SUBMODULE(eval_, m) { } return false; }); + + // test_eval_empty_globals + m.def("eval_empty_globals", [](py::object global) { + if (global.is_none()) + global = py::dict(); + auto int_class = py::eval("isinstance(42, int)", global); + return global; + }); + + // test_eval_closure + m.def("test_eval_closure", []() { + py::dict global; + global["closure_value"] = 42; + py::dict local; + local["closure_value"] = 0; + py::exec(R"( + local_value = closure_value + + def func_global(): + return closure_value + + def func_local(): + return local_value + )", global, local); + return std::make_pair(global, local); + }); } diff --git a/wrap/pybind11/tests/test_eval.py b/wrap/pybind11/tests/test_eval.py index b6f9d1881..1bbd991bc 100644 --- a/wrap/pybind11/tests/test_eval.py +++ b/wrap/pybind11/tests/test_eval.py @@ -4,7 +4,6 @@ import os import pytest import env # noqa: F401 - from pybind11_tests import eval_ as m @@ -25,3 +24,28 @@ def test_eval_file(): assert m.test_eval_file(filename) assert m.test_eval_file_failure() + + +def test_eval_empty_globals(): + assert "__builtins__" in m.eval_empty_globals(None) + + g = {} + assert "__builtins__" in m.eval_empty_globals(g) + assert "__builtins__" in g + + +def test_eval_closure(): + global_, local = m.test_eval_closure() + + assert global_["closure_value"] == 42 + assert local["closure_value"] == 0 + + assert "local_value" not in global_ + assert local["local_value"] == 0 + + assert "func_global" not in global_ + assert local["func_global"]() == 42 + + assert "func_local" not in global_ + with pytest.raises(NameError): + local["func_local"]() diff --git a/wrap/pybind11/tests/test_eval_call.py b/wrap/pybind11/tests/test_eval_call.py index d42a0a6d3..373b67bac 100644 --- a/wrap/pybind11/tests/test_eval_call.py +++ b/wrap/pybind11/tests/test_eval_call.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- # This file is called from 'test_eval.py' -if 'call_test2' in locals(): +if "call_test2" in locals(): call_test2(y) # noqa: F821 undefined name diff --git a/wrap/pybind11/tests/test_exceptions.cpp b/wrap/pybind11/tests/test_exceptions.cpp index 6187f2efb..3aa967382 100644 --- a/wrap/pybind11/tests/test_exceptions.cpp +++ b/wrap/pybind11/tests/test_exceptions.cpp @@ -6,8 +6,14 @@ All rights reserved. Use of this source code is governed by a BSD-style license that can be found in the LICENSE file. */ +#include "test_exceptions.h" + +#include "local_bindings.h" #include "pybind11_tests.h" +#include +#include +#include // A type that should be raised as an exception in Python class MyException : public std::exception { @@ -32,6 +38,13 @@ class MyException3 { public: explicit MyException3(const char * m) : message{m} {} virtual const char * what() const noexcept {return message.c_str();} + // Rule of 5 BEGIN: to preempt compiler warnings. + MyException3(const MyException3&) = default; + MyException3(MyException3&&) = default; + MyException3& operator=(const MyException3&) = default; + MyException3& operator=(MyException3&&) = default; + virtual ~MyException3() = default; + // Rule of 5 END. private: std::string message = ""; }; @@ -58,8 +71,19 @@ class MyException5_1 : public MyException5 { using MyException5::MyException5; }; + +// Exception that will be caught via the module local translator. +class MyException6 : public std::exception { +public: + explicit MyException6(const char * m) : message{m} {} + const char * what() const noexcept override {return message.c_str();} +private: + std::string message = ""; +}; + + struct PythonCallInDestructor { - PythonCallInDestructor(const py::dict &d) : d(d) {} + explicit PythonCallInDestructor(const py::dict &d) : d(d) {} ~PythonCallInDestructor() { d["good"] = true; } py::dict d; @@ -68,7 +92,7 @@ struct PythonCallInDestructor { struct PythonAlreadySetInDestructor { - PythonAlreadySetInDestructor(const py::str &s) : s(s) {} + explicit PythonAlreadySetInDestructor(const py::str &s) : s(s) {} ~PythonAlreadySetInDestructor() { py::dict foo; try { @@ -83,7 +107,6 @@ struct PythonAlreadySetInDestructor { py::str s; }; - TEST_SUBMODULE(exceptions, m) { m.def("throw_std_exception", []() { throw std::runtime_error("This exception was intentionally thrown."); @@ -128,14 +151,29 @@ TEST_SUBMODULE(exceptions, m) { // A slightly more complicated one that declares MyException5_1 as a subclass of MyException5 py::register_exception(m, "MyException5_1", ex5.ptr()); + //py::register_local_exception(m, "LocalSimpleException") + + py::register_local_exception_translator([](std::exception_ptr p) { + try { + if (p) { + std::rethrow_exception(p); + } + } catch (const MyException6 &e) { + PyErr_SetString(PyExc_RuntimeError, e.what()); + } + }); + m.def("throws1", []() { throw MyException("this error should go to a custom type"); }); m.def("throws2", []() { throw MyException2("this error should go to a standard Python exception"); }); m.def("throws3", []() { throw MyException3("this error cannot be translated"); }); m.def("throws4", []() { throw MyException4("this error is rethrown"); }); m.def("throws5", []() { throw MyException5("this is a helper-defined translated exception"); }); m.def("throws5_1", []() { throw MyException5_1("MyException5 subclass"); }); + m.def("throws6", []() { throw MyException6("MyException6 only handled in this module"); }); m.def("throws_logic_error", []() { throw std::logic_error("this error should fall through to the standard handler"); }); - m.def("throws_overflow_error", []() {throw std::overflow_error(""); }); + m.def("throws_overflow_error", []() { throw std::overflow_error(""); }); + m.def("throws_local_error", []() { throw LocalException("never caught"); }); + m.def("throws_local_simple_error", []() { throw LocalSimpleException("this mod"); }); m.def("exception_matches", []() { py::dict foo; try { @@ -163,7 +201,7 @@ TEST_SUBMODULE(exceptions, m) { m.def("modulenotfound_exception_matches_base", []() { try { // On Python >= 3.6, this raises a ModuleNotFoundError, a subclass of ImportError - py::module::import("nonexistent"); + py::module_::import("nonexistent"); } catch (py::error_already_set &ex) { if (!ex.matches(PyExc_ImportError)) throw; @@ -191,34 +229,65 @@ TEST_SUBMODULE(exceptions, m) { throw py::error_already_set(); }); - m.def("python_call_in_destructor", [](py::dict d) { + m.def("python_call_in_destructor", [](const py::dict &d) { + bool retval = false; try { PythonCallInDestructor set_dict_in_destructor(d); PyErr_SetString(PyExc_ValueError, "foo"); throw py::error_already_set(); } catch (const py::error_already_set&) { - return true; + retval = true; } - return false; + return retval; }); - m.def("python_alreadyset_in_destructor", [](py::str s) { + m.def("python_alreadyset_in_destructor", [](const py::str &s) { PythonAlreadySetInDestructor alreadyset_in_destructor(s); return true; }); // test_nested_throws - m.def("try_catch", [m](py::object exc_type, py::function f, py::args args) { - try { f(*args); } - catch (py::error_already_set &ex) { - if (ex.matches(exc_type)) - py::print(ex.what()); - else - throw; - } - }); + m.def("try_catch", + [m](const py::object &exc_type, const py::function &f, const py::args &args) { + try { + f(*args); + } catch (py::error_already_set &ex) { + if (ex.matches(exc_type)) + py::print(ex.what()); + else + throw; + } + }); // Test repr that cannot be displayed m.def("simple_bool_passthrough", [](bool x) {return x;}); + m.def("throw_should_be_translated_to_key_error", []() { throw shared_exception(); }); + +#if PY_VERSION_HEX >= 0x03030000 + + m.def("raise_from", []() { + PyErr_SetString(PyExc_ValueError, "inner"); + py::raise_from(PyExc_ValueError, "outer"); + throw py::error_already_set(); + }); + + m.def("raise_from_already_set", []() { + try { + PyErr_SetString(PyExc_ValueError, "inner"); + throw py::error_already_set(); + } catch (py::error_already_set& e) { + py::raise_from(e, PyExc_ValueError, "outer"); + throw py::error_already_set(); + } + }); + + m.def("throw_nested_exception", []() { + try { + throw std::runtime_error("Inner Exception"); + } catch (const std::runtime_error &) { + std::throw_with_nested(std::runtime_error("Outer Exception")); + } + }); +#endif } diff --git a/wrap/pybind11/tests/test_exceptions.h b/wrap/pybind11/tests/test_exceptions.h new file mode 100644 index 000000000..9d428312e --- /dev/null +++ b/wrap/pybind11/tests/test_exceptions.h @@ -0,0 +1,12 @@ +#pragma once +#include "pybind11_tests.h" +#include + +// shared exceptions for cross_module_tests + +class PYBIND11_EXPORT_EXCEPTION shared_exception : public pybind11::builtin_exception { +public: + using builtin_exception::builtin_exception; + explicit shared_exception() : shared_exception("") {} + void set_error() const override { PyErr_SetString(PyExc_RuntimeError, what()); } +}; diff --git a/wrap/pybind11/tests/test_exceptions.py b/wrap/pybind11/tests/test_exceptions.py index 7d7088d00..d698b1312 100644 --- a/wrap/pybind11/tests/test_exceptions.py +++ b/wrap/pybind11/tests/test_exceptions.py @@ -3,8 +3,9 @@ import sys import pytest -from pybind11_tests import exceptions as m +import env import pybind11_cross_module_tests as cm +from pybind11_tests import exceptions as m def test_std_exception(msg): @@ -23,7 +24,23 @@ def test_error_already_set(msg): assert msg(excinfo.value) == "foo" -def test_cross_module_exceptions(): +@pytest.mark.skipif("env.PY2") +def test_raise_from(msg): + with pytest.raises(ValueError) as excinfo: + m.raise_from() + assert msg(excinfo.value) == "outer" + assert msg(excinfo.value.__cause__) == "inner" + + +@pytest.mark.skipif("env.PY2") +def test_raise_from_already_set(msg): + with pytest.raises(ValueError) as excinfo: + m.raise_from_already_set() + assert msg(excinfo.value) == "outer" + assert msg(excinfo.value.__cause__) == "inner" + + +def test_cross_module_exceptions(msg): with pytest.raises(RuntimeError) as excinfo: cm.raise_runtime_error() assert str(excinfo.value) == "My runtime error" @@ -43,6 +60,27 @@ def test_cross_module_exceptions(): with pytest.raises(StopIteration) as excinfo: cm.throw_stop_iteration() + with pytest.raises(cm.LocalSimpleException) as excinfo: + cm.throw_local_simple_error() + assert msg(excinfo.value) == "external mod" + + with pytest.raises(KeyError) as excinfo: + cm.throw_local_error() + # KeyError is a repr of the key, so it has an extra set of quotes + assert str(excinfo.value) == "'just local'" + + +# TODO: FIXME +@pytest.mark.xfail( + "env.PYPY and env.MACOS", + raises=RuntimeError, + reason="Expected failure with PyPy and libc++ (Issue #2847 & PR #2999)", +) +def test_cross_module_exception_translator(): + with pytest.raises(KeyError): + # translator registered in cross_module_tests + m.throw_should_be_translated_to_key_error() + def test_python_call_in_catch(): d = {} @@ -50,31 +88,44 @@ def test_python_call_in_catch(): assert d["good"] is True +def ignore_pytest_unraisable_warning(f): + unraisable = "PytestUnraisableExceptionWarning" + if hasattr(pytest, unraisable): # Python >= 3.8 and pytest >= 6 + dec = pytest.mark.filterwarnings("ignore::pytest.{}".format(unraisable)) + return dec(f) + else: + return f + + +# TODO: find out why this fails on PyPy, https://foss.heptapod.net/pypy/pypy/-/issues/3583 +@pytest.mark.xfail(env.PYPY, reason="Failure on PyPy 3.8 (7.3.7)", strict=False) +@ignore_pytest_unraisable_warning def test_python_alreadyset_in_destructor(monkeypatch, capsys): hooked = False triggered = [False] # mutable, so Python 2.7 closure can modify it - if hasattr(sys, 'unraisablehook'): # Python 3.8+ + if hasattr(sys, "unraisablehook"): # Python 3.8+ hooked = True - default_hook = sys.unraisablehook + # Don't take `sys.unraisablehook`, as that's overwritten by pytest + default_hook = sys.__unraisablehook__ def hook(unraisable_hook_args): exc_type, exc_value, exc_tb, err_msg, obj = unraisable_hook_args - if obj == 'already_set demo': + if obj == "already_set demo": triggered[0] = True default_hook(unraisable_hook_args) return # Use monkeypatch so pytest can apply and remove the patch as appropriate - monkeypatch.setattr(sys, 'unraisablehook', hook) + monkeypatch.setattr(sys, "unraisablehook", hook) - assert m.python_alreadyset_in_destructor('already_set demo') is True + assert m.python_alreadyset_in_destructor("already_set demo") is True if hooked: assert triggered[0] is True _, captured_stderr = capsys.readouterr() # Error message is different in Python 2 and 3, check for words that appear in both - assert 'ignored' in captured_stderr and 'already_set demo' in captured_stderr + assert "ignored" in captured_stderr and "already_set demo" in captured_stderr def test_exception_matches(): @@ -107,7 +158,9 @@ def test_custom(msg): # Can we fall-through to the default handler? with pytest.raises(RuntimeError) as excinfo: m.throws_logic_error() - assert msg(excinfo.value) == "this error should fall through to the standard handler" + assert ( + msg(excinfo.value) == "this error should fall through to the standard handler" + ) # OverFlow error translation. with pytest.raises(OverflowError) as excinfo: @@ -166,7 +219,13 @@ def test_nested_throws(capture): # C++ -> Python -> C++ -> Python with capture: m.try_catch( - m.MyException5, pycatch, m.MyException, m.try_catch, m.MyException, throw_myex5) + m.MyException5, + pycatch, + m.MyException, + m.try_catch, + m.MyException, + throw_myex5, + ) assert str(capture).startswith("MyException5: nested error 5") # C++ -> Python -> C++ @@ -180,12 +239,37 @@ def test_nested_throws(capture): assert str(excinfo.value) == "this is a helper-defined translated exception" +@pytest.mark.skipif("env.PY2") +def test_throw_nested_exception(): + with pytest.raises(RuntimeError) as excinfo: + m.throw_nested_exception() + assert str(excinfo.value) == "Outer Exception" + assert str(excinfo.value.__cause__) == "Inner Exception" + + # This can often happen if you wrap a pybind11 class in a Python wrapper def test_invalid_repr(): - class MyRepr(object): def __repr__(self): raise AttributeError("Example error") with pytest.raises(TypeError): m.simple_bool_passthrough(MyRepr()) + + +def test_local_translator(msg): + """Tests that a local translator works and that the local translator from + the cross module is not applied""" + with pytest.raises(RuntimeError) as excinfo: + m.throws6() + assert msg(excinfo.value) == "MyException6 only handled in this module" + + with pytest.raises(RuntimeError) as excinfo: + m.throws_local_error() + assert not isinstance(excinfo.value, KeyError) + assert msg(excinfo.value) == "never caught" + + with pytest.raises(Exception) as excinfo: + m.throws_local_simple_error() + assert not isinstance(excinfo.value, cm.LocalSimpleException) + assert msg(excinfo.value) == "this mod" diff --git a/wrap/pybind11/tests/test_factory_constructors.cpp b/wrap/pybind11/tests/test_factory_constructors.cpp index 2368dabb8..660e2896a 100644 --- a/wrap/pybind11/tests/test_factory_constructors.cpp +++ b/wrap/pybind11/tests/test_factory_constructors.cpp @@ -8,35 +8,45 @@ BSD-style license that can be found in the LICENSE file. */ -#include "pybind11_tests.h" #include "constructor_stats.h" +#include "pybind11_tests.h" #include #include +#include // Classes for testing python construction via C++ factory function: // Not publicly constructible, copyable, or movable: class TestFactory1 { friend class TestFactoryHelper; TestFactory1() : value("(empty)") { print_default_created(this); } - TestFactory1(int v) : value(std::to_string(v)) { print_created(this, value); } - TestFactory1(std::string v) : value(std::move(v)) { print_created(this, value); } + explicit TestFactory1(int v) : value(std::to_string(v)) { print_created(this, value); } + explicit TestFactory1(std::string v) : value(std::move(v)) { print_created(this, value); } + +public: + std::string value; TestFactory1(TestFactory1 &&) = delete; TestFactory1(const TestFactory1 &) = delete; TestFactory1 &operator=(TestFactory1 &&) = delete; TestFactory1 &operator=(const TestFactory1 &) = delete; -public: - std::string value; ~TestFactory1() { print_destroyed(this); } }; // Non-public construction, but moveable: class TestFactory2 { friend class TestFactoryHelper; TestFactory2() : value("(empty2)") { print_default_created(this); } - TestFactory2(int v) : value(std::to_string(v)) { print_created(this, value); } - TestFactory2(std::string v) : value(std::move(v)) { print_created(this, value); } + explicit TestFactory2(int v) : value(std::to_string(v)) { print_created(this, value); } + explicit TestFactory2(std::string v) : value(std::move(v)) { print_created(this, value); } + public: - TestFactory2(TestFactory2 &&m) { value = std::move(m.value); print_move_created(this); } - TestFactory2 &operator=(TestFactory2 &&m) { value = std::move(m.value); print_move_assigned(this); return *this; } + TestFactory2(TestFactory2 &&m) noexcept { + value = std::move(m.value); + print_move_created(this); + } + TestFactory2 &operator=(TestFactory2 &&m) noexcept { + value = std::move(m.value); + print_move_assigned(this); + return *this; + } std::string value; ~TestFactory2() { print_destroyed(this); } }; @@ -45,11 +55,19 @@ class TestFactory3 { protected: friend class TestFactoryHelper; TestFactory3() : value("(empty3)") { print_default_created(this); } - TestFactory3(int v) : value(std::to_string(v)) { print_created(this, value); } + explicit TestFactory3(int v) : value(std::to_string(v)) { print_created(this, value); } + public: - TestFactory3(std::string v) : value(std::move(v)) { print_created(this, value); } - TestFactory3(TestFactory3 &&m) { value = std::move(m.value); print_move_created(this); } - TestFactory3 &operator=(TestFactory3 &&m) { value = std::move(m.value); print_move_assigned(this); return *this; } + explicit TestFactory3(std::string v) : value(std::move(v)) { print_created(this, value); } + TestFactory3(TestFactory3 &&m) noexcept { + value = std::move(m.value); + print_move_created(this); + } + TestFactory3 &operator=(TestFactory3 &&m) noexcept { + value = std::move(m.value); + print_move_assigned(this); + return *this; + } std::string value; virtual ~TestFactory3() { print_destroyed(this); } }; @@ -57,13 +75,13 @@ public: class TestFactory4 : public TestFactory3 { public: TestFactory4() : TestFactory3() { print_default_created(this); } - TestFactory4(int v) : TestFactory3(v) { print_created(this, v); } + explicit TestFactory4(int v) : TestFactory3(v) { print_created(this, v); } ~TestFactory4() override { print_destroyed(this); } }; // Another class for an invalid downcast test class TestFactory5 : public TestFactory3 { public: - TestFactory5(int i) : TestFactory3(i) { print_created(this, i); } + explicit TestFactory5(int i) : TestFactory3(i) { print_created(this, i); } ~TestFactory5() override { print_destroyed(this); } }; @@ -72,22 +90,35 @@ protected: int value; bool alias = false; public: - TestFactory6(int i) : value{i} { print_created(this, i); } - TestFactory6(TestFactory6 &&f) { print_move_created(this); value = f.value; alias = f.alias; } + explicit TestFactory6(int i) : value{i} { print_created(this, i); } + TestFactory6(TestFactory6 &&f) noexcept { + print_move_created(this); + value = f.value; + alias = f.alias; + } TestFactory6(const TestFactory6 &f) { print_copy_created(this); value = f.value; alias = f.alias; } virtual ~TestFactory6() { print_destroyed(this); } virtual int get() { return value; } - bool has_alias() { return alias; } + bool has_alias() const { return alias; } }; class PyTF6 : public TestFactory6 { public: // Special constructor that allows the factory to construct a PyTF6 from a TestFactory6 only // when an alias is needed: - PyTF6(TestFactory6 &&base) : TestFactory6(std::move(base)) { alias = true; print_created(this, "move", value); } - PyTF6(int i) : TestFactory6(i) { alias = true; print_created(this, i); } - PyTF6(PyTF6 &&f) : TestFactory6(std::move(f)) { print_move_created(this); } + explicit PyTF6(TestFactory6 &&base) : TestFactory6(std::move(base)) { + alias = true; + print_created(this, "move", value); + } + explicit PyTF6(int i) : TestFactory6(i) { + alias = true; + print_created(this, i); + } + PyTF6(PyTF6 &&f) noexcept : TestFactory6(std::move(f)) { print_move_created(this); } PyTF6(const PyTF6 &f) : TestFactory6(f) { print_copy_created(this); } - PyTF6(std::string s) : TestFactory6((int) s.size()) { alias = true; print_created(this, s); } + explicit PyTF6(std::string s) : TestFactory6((int) s.size()) { + alias = true; + print_created(this, s); + } ~PyTF6() override { print_destroyed(this); } int get() override { PYBIND11_OVERRIDE(int, TestFactory6, get, /*no args*/); } }; @@ -97,17 +128,24 @@ protected: int value; bool alias = false; public: - TestFactory7(int i) : value{i} { print_created(this, i); } - TestFactory7(TestFactory7 &&f) { print_move_created(this); value = f.value; alias = f.alias; } + explicit TestFactory7(int i) : value{i} { print_created(this, i); } + TestFactory7(TestFactory7 &&f) noexcept { + print_move_created(this); + value = f.value; + alias = f.alias; + } TestFactory7(const TestFactory7 &f) { print_copy_created(this); value = f.value; alias = f.alias; } virtual ~TestFactory7() { print_destroyed(this); } virtual int get() { return value; } - bool has_alias() { return alias; } + bool has_alias() const { return alias; } }; class PyTF7 : public TestFactory7 { public: - PyTF7(int i) : TestFactory7(i) { alias = true; print_created(this, i); } - PyTF7(PyTF7 &&f) : TestFactory7(std::move(f)) { print_move_created(this); } + explicit PyTF7(int i) : TestFactory7(i) { + alias = true; + print_created(this, i); + } + PyTF7(PyTF7 &&f) noexcept : TestFactory7(std::move(f)) { print_move_created(this); } PyTF7(const PyTF7 &f) : TestFactory7(f) { print_copy_created(this); } ~PyTF7() override { print_destroyed(this); } int get() override { PYBIND11_OVERRIDE(int, TestFactory7, get, /*no args*/); } @@ -122,7 +160,9 @@ public: // Holder: static std::unique_ptr construct1(int a) { return std::unique_ptr(new TestFactory1(a)); } // pointer again - static TestFactory1 *construct1_string(std::string a) { return new TestFactory1(a); } + static TestFactory1 *construct1_string(std::string a) { + return new TestFactory1(std::move(a)); + } // Moveable type: // pointer: @@ -130,7 +170,7 @@ public: // holder: static std::unique_ptr construct2(int a) { return std::unique_ptr(new TestFactory2(a)); } // by value moving: - static TestFactory2 construct2(std::string a) { return TestFactory2(a); } + static TestFactory2 construct2(std::string a) { return TestFactory2(std::move(a)); } // shared_ptr holder type: // pointer: @@ -142,7 +182,7 @@ public: TEST_SUBMODULE(factory_constructors, m) { // Define various trivial types to allow simpler overload resolution: - py::module m_tag = m.def_submodule("tag"); + py::module_ m_tag = m.def_submodule("tag"); #define MAKE_TAG_TYPE(Name) \ struct Name##_tag {}; \ py::class_(m_tag, #Name "_tag").def(py::init<>()); \ @@ -173,21 +213,27 @@ TEST_SUBMODULE(factory_constructors, m) { ; py::class_(m, "TestFactory2") .def(py::init([](pointer_tag, int v) { return TestFactoryHelper::construct2(v); })) - .def(py::init([](unique_ptr_tag, std::string v) { return TestFactoryHelper::construct2(v); })) + .def(py::init([](unique_ptr_tag, std::string v) { + return TestFactoryHelper::construct2(std::move(v)); + })) .def(py::init([](move_tag) { return TestFactoryHelper::construct2(); })) - .def_readwrite("value", &TestFactory2::value) - ; + .def_readwrite("value", &TestFactory2::value); // Stateful & reused: int c = 1; auto c4a = [c](pointer_tag, TF4_tag, int a) { (void) c; return new TestFactory4(a);}; // test_init_factory_basic, test_init_factory_casting - py::class_>(m, "TestFactory3") + py::class_> pyTestFactory3(m, "TestFactory3"); + pyTestFactory3 .def(py::init([](pointer_tag, int v) { return TestFactoryHelper::construct3(v); })) - .def(py::init([](shared_ptr_tag) { return TestFactoryHelper::construct3(); })) - .def("__init__", [](TestFactory3 &self, std::string v) { new (&self) TestFactory3(v); }) // placement-new ctor - + .def(py::init([](shared_ptr_tag) { return TestFactoryHelper::construct3(); })); + ignoreOldStyleInitWarnings([&pyTestFactory3]() { + pyTestFactory3.def("__init__", [](TestFactory3 &self, std::string v) { + new (&self) TestFactory3(std::move(v)); + }); // placement-new ctor + }); + pyTestFactory3 // factories returning a derived type: .def(py::init(c4a)) // derived ptr .def(py::init([](pointer_tag, TF5_tag, int a) { return new TestFactory5(a); })) @@ -216,58 +262,60 @@ TEST_SUBMODULE(factory_constructors, m) { py::class_(m, "TestFactory6") .def(py::init([](base_tag, int i) { return TestFactory6(i); })) .def(py::init([](alias_tag, int i) { return PyTF6(i); })) - .def(py::init([](alias_tag, std::string s) { return PyTF6(s); })) + .def(py::init([](alias_tag, std::string s) { return PyTF6(std::move(s)); })) .def(py::init([](alias_tag, pointer_tag, int i) { return new PyTF6(i); })) .def(py::init([](base_tag, pointer_tag, int i) { return new TestFactory6(i); })) - .def(py::init([](base_tag, alias_tag, pointer_tag, int i) { return (TestFactory6 *) new PyTF6(i); })) + .def(py::init( + [](base_tag, alias_tag, pointer_tag, int i) { return (TestFactory6 *) new PyTF6(i); })) .def("get", &TestFactory6::get) .def("has_alias", &TestFactory6::has_alias) - .def_static("get_cstats", &ConstructorStats::get, py::return_value_policy::reference) - .def_static("get_alias_cstats", &ConstructorStats::get, py::return_value_policy::reference) - ; + .def_static( + "get_cstats", &ConstructorStats::get, py::return_value_policy::reference) + .def_static( + "get_alias_cstats", &ConstructorStats::get, py::return_value_policy::reference); // test_init_factory_dual // Separate alias constructor testing py::class_>(m, "TestFactory7") - .def(py::init( - [](int i) { return TestFactory7(i); }, - [](int i) { return PyTF7(i); })) - .def(py::init( - [](pointer_tag, int i) { return new TestFactory7(i); }, - [](pointer_tag, int i) { return new PyTF7(i); })) - .def(py::init( - [](mixed_tag, int i) { return new TestFactory7(i); }, - [](mixed_tag, int i) { return PyTF7(i); })) - .def(py::init( - [](mixed_tag, std::string s) { return TestFactory7((int) s.size()); }, - [](mixed_tag, std::string s) { return new PyTF7((int) s.size()); })) - .def(py::init( - [](base_tag, pointer_tag, int i) { return new TestFactory7(i); }, - [](base_tag, pointer_tag, int i) { return (TestFactory7 *) new PyTF7(i); })) - .def(py::init( - [](alias_tag, pointer_tag, int i) { return new PyTF7(i); }, - [](alias_tag, pointer_tag, int i) { return new PyTF7(10*i); })) + .def(py::init([](int i) { return TestFactory7(i); }, [](int i) { return PyTF7(i); })) + .def(py::init([](pointer_tag, int i) { return new TestFactory7(i); }, + [](pointer_tag, int i) { return new PyTF7(i); })) + .def(py::init([](mixed_tag, int i) { return new TestFactory7(i); }, + [](mixed_tag, int i) { return PyTF7(i); })) + .def(py::init([](mixed_tag, const std::string &s) { return TestFactory7((int) s.size()); }, + [](mixed_tag, const std::string &s) { return new PyTF7((int) s.size()); })) + .def(py::init([](base_tag, pointer_tag, int i) { return new TestFactory7(i); }, + [](base_tag, pointer_tag, int i) { return (TestFactory7 *) new PyTF7(i); })) + .def(py::init([](alias_tag, pointer_tag, int i) { return new PyTF7(i); }, + [](alias_tag, pointer_tag, int i) { return new PyTF7(10 * i); })) .def(py::init( [](shared_ptr_tag, base_tag, int i) { return std::make_shared(i); }, - [](shared_ptr_tag, base_tag, int i) { auto *p = new PyTF7(i); return std::shared_ptr(p); })) - .def(py::init( - [](shared_ptr_tag, invalid_base_tag, int i) { return std::make_shared(i); }, - [](shared_ptr_tag, invalid_base_tag, int i) { return std::make_shared(i); })) // <-- invalid alias factory + [](shared_ptr_tag, base_tag, int i) { + auto *p = new PyTF7(i); + return std::shared_ptr(p); + })) + .def(py::init([](shared_ptr_tag, + invalid_base_tag, + int i) { return std::make_shared(i); }, + [](shared_ptr_tag, invalid_base_tag, int i) { + return std::make_shared(i); + })) // <-- invalid alias factory .def("get", &TestFactory7::get) .def("has_alias", &TestFactory7::has_alias) - .def_static("get_cstats", &ConstructorStats::get, py::return_value_policy::reference) - .def_static("get_alias_cstats", &ConstructorStats::get, py::return_value_policy::reference) - ; + .def_static( + "get_cstats", &ConstructorStats::get, py::return_value_policy::reference) + .def_static( + "get_alias_cstats", &ConstructorStats::get, py::return_value_policy::reference); // test_placement_new_alternative // Class with a custom new operator but *without* a placement new operator (issue #948) class NoPlacementNew { public: - NoPlacementNew(int i) : i(i) { } + explicit NoPlacementNew(int i) : i(i) {} static void *operator new(std::size_t s) { auto *p = ::operator new(s); py::print("operator new called, returning", reinterpret_cast(p)); @@ -291,8 +339,8 @@ TEST_SUBMODULE(factory_constructors, m) { // Class that has verbose operator_new/operator_delete calls struct NoisyAlloc { NoisyAlloc(const NoisyAlloc &) = default; - NoisyAlloc(int i) { py::print(py::str("NoisyAlloc(int {})").format(i)); } - NoisyAlloc(double d) { py::print(py::str("NoisyAlloc(double {})").format(d)); } + explicit NoisyAlloc(int i) { py::print(py::str("NoisyAlloc(int {})").format(i)); } + explicit NoisyAlloc(double d) { py::print(py::str("NoisyAlloc(double {})").format(d)); } ~NoisyAlloc() { py::print("~NoisyAlloc()"); } static void *operator new(size_t s) { py::print("noisy new"); return ::operator new(s); } @@ -304,27 +352,33 @@ TEST_SUBMODULE(factory_constructors, m) { static void operator delete(void *p) { py::print("noisy delete"); ::operator delete(p); } #endif }; - py::class_(m, "NoisyAlloc") + + + py::class_ pyNoisyAlloc(m, "NoisyAlloc"); // Since these overloads have the same number of arguments, the dispatcher will try each of // them until the arguments convert. Thus we can get a pre-allocation here when passing a // single non-integer: - .def("__init__", [](NoisyAlloc *a, int i) { new (a) NoisyAlloc(i); }) // Regular constructor, runs first, requires preallocation - .def(py::init([](double d) { return new NoisyAlloc(d); })) - - // The two-argument version: first the factory pointer overload. - .def(py::init([](int i, int) { return new NoisyAlloc(i); })) - // Return-by-value: - .def(py::init([](double d, int) { return NoisyAlloc(d); })) - // Old-style placement new init; requires preallocation - .def("__init__", [](NoisyAlloc &a, double d, double) { new (&a) NoisyAlloc(d); }) - // Requires deallocation of previous overload preallocated value: - .def(py::init([](int i, double) { return new NoisyAlloc(i); })) - // Regular again: requires yet another preallocation - .def("__init__", [](NoisyAlloc &a, int i, std::string) { new (&a) NoisyAlloc(i); }) - ; - + ignoreOldStyleInitWarnings([&pyNoisyAlloc]() { + pyNoisyAlloc.def("__init__", [](NoisyAlloc *a, int i) { new (a) NoisyAlloc(i); }); // Regular constructor, runs first, requires preallocation + }); + pyNoisyAlloc.def(py::init([](double d) { return new NoisyAlloc(d); })); + // The two-argument version: first the factory pointer overload. + pyNoisyAlloc.def(py::init([](int i, int) { return new NoisyAlloc(i); })); + // Return-by-value: + pyNoisyAlloc.def(py::init([](double d, int) { return NoisyAlloc(d); })); + // Old-style placement new init; requires preallocation + ignoreOldStyleInitWarnings([&pyNoisyAlloc]() { + pyNoisyAlloc.def("__init__", [](NoisyAlloc &a, double d, double) { new (&a) NoisyAlloc(d); }); + }); + // Requires deallocation of previous overload preallocated value: + pyNoisyAlloc.def(py::init([](int i, double) { return new NoisyAlloc(i); })); + // Regular again: requires yet another preallocation + ignoreOldStyleInitWarnings([&pyNoisyAlloc]() { + pyNoisyAlloc.def( + "__init__", [](NoisyAlloc &a, int i, const std::string &) { new (&a) NoisyAlloc(i); }); + }); // static_assert testing (the following def's should all fail with appropriate compilation errors): #if 0 diff --git a/wrap/pybind11/tests/test_factory_constructors.py b/wrap/pybind11/tests/test_factory_constructors.py index b141c13de..8bc026985 100644 --- a/wrap/pybind11/tests/test_factory_constructors.py +++ b/wrap/pybind11/tests/test_factory_constructors.py @@ -1,18 +1,21 @@ # -*- coding: utf-8 -*- -import pytest import re -import env # noqa: F401 +import pytest +import env # noqa: F401 +from pybind11_tests import ConstructorStats from pybind11_tests import factory_constructors as m from pybind11_tests.factory_constructors import tag -from pybind11_tests import ConstructorStats def test_init_factory_basic(): """Tests py::init_factory() wrapper around various ways of returning the object""" - cstats = [ConstructorStats.get(c) for c in [m.TestFactory1, m.TestFactory2, m.TestFactory3]] + cstats = [ + ConstructorStats.get(c) + for c in [m.TestFactory1, m.TestFactory2, m.TestFactory3] + ] cstats[0].alive() # force gc n_inst = ConstructorStats.detail_reg_inst() @@ -41,12 +44,12 @@ def test_init_factory_basic(): z3 = m.TestFactory3("bye") assert z3.value == "bye" - for null_ptr_kind in [tag.null_ptr, - tag.null_unique_ptr, - tag.null_shared_ptr]: + for null_ptr_kind in [tag.null_ptr, tag.null_unique_ptr, tag.null_shared_ptr]: with pytest.raises(TypeError) as excinfo: m.TestFactory3(null_ptr_kind) - assert str(excinfo.value) == "pybind11::init(): factory function returned nullptr" + assert ( + str(excinfo.value) == "pybind11::init(): factory function returned nullptr" + ) assert [i.alive() for i in cstats] == [3, 3, 3] assert ConstructorStats.detail_reg_inst() == n_inst + 9 @@ -61,7 +64,7 @@ def test_init_factory_basic(): assert [i.values() for i in cstats] == [ ["3", "hi!"], ["7", "hi again"], - ["42", "bye"] + ["42", "bye"], ] assert [i.default_constructions for i in cstats] == [1, 1, 1] @@ -69,7 +72,9 @@ def test_init_factory_basic(): def test_init_factory_signature(msg): with pytest.raises(TypeError) as excinfo: m.TestFactory1("invalid", "constructor", "arguments") - assert msg(excinfo.value) == """ + assert ( + msg(excinfo.value) + == """ __init__(): incompatible constructor arguments. The following argument types are supported: 1. m.factory_constructors.TestFactory1(arg0: m.factory_constructors.tag.unique_ptr_tag, arg1: int) 2. m.factory_constructors.TestFactory1(arg0: str) @@ -78,8 +83,11 @@ def test_init_factory_signature(msg): Invoked with: 'invalid', 'constructor', 'arguments' """ # noqa: E501 line too long + ) - assert msg(m.TestFactory1.__init__.__doc__) == """ + assert ( + msg(m.TestFactory1.__init__.__doc__) + == """ __init__(*args, **kwargs) Overloaded function. @@ -91,12 +99,16 @@ def test_init_factory_signature(msg): 4. __init__(self: m.factory_constructors.TestFactory1, arg0: handle, arg1: int, arg2: handle) -> None """ # noqa: E501 line too long + ) def test_init_factory_casting(): """Tests py::init_factory() wrapper with various upcasting and downcasting returns""" - cstats = [ConstructorStats.get(c) for c in [m.TestFactory3, m.TestFactory4, m.TestFactory5]] + cstats = [ + ConstructorStats.get(c) + for c in [m.TestFactory3, m.TestFactory4, m.TestFactory5] + ] cstats[0].alive() # force gc n_inst = ConstructorStats.detail_reg_inst() @@ -134,7 +146,7 @@ def test_init_factory_casting(): assert [i.values() for i in cstats] == [ ["4", "5", "6", "7", "8"], ["4", "5", "8"], - ["6", "7"] + ["6", "7"], ] @@ -204,7 +216,7 @@ def test_init_factory_alias(): assert [i.values() for i in cstats] == [ ["1", "8", "3", "4", "5", "6", "123", "10", "47"], - ["hi there", "3", "4", "6", "move", "123", "why hello!", "move", "47"] + ["hi there", "3", "4", "6", "move", "123", "why hello!", "move", "47"], ] @@ -268,9 +280,11 @@ def test_init_factory_dual(): assert not g1.has_alias() with pytest.raises(TypeError) as excinfo: PythFactory7(tag.shared_ptr, tag.invalid_base, 14) - assert (str(excinfo.value) == - "pybind11::init(): construction failed: returned holder-wrapped instance is not an " - "alias instance") + assert ( + str(excinfo.value) + == "pybind11::init(): construction failed: returned holder-wrapped instance is not an " + "alias instance" + ) assert [i.alive() for i in cstats] == [13, 7] assert ConstructorStats.detail_reg_inst() == n_inst + 13 @@ -284,7 +298,7 @@ def test_init_factory_dual(): assert [i.values() for i in cstats] == [ ["1", "2", "3", "4", "5", "6", "7", "8", "9", "100", "11", "12", "13", "14"], - ["2", "4", "6", "8", "9", "100", "12"] + ["2", "4", "6", "8", "9", "100", "12"], ] @@ -294,7 +308,7 @@ def test_no_placement_new(capture): with capture: a = m.NoPlacementNew(123) - found = re.search(r'^operator new called, returning (\d+)\n$', str(capture)) + found = re.search(r"^operator new called, returning (\d+)\n$", str(capture)) assert found assert a.i == 123 with capture: @@ -305,7 +319,7 @@ def test_no_placement_new(capture): with capture: b = m.NoPlacementNew() - found = re.search(r'^operator new called, returning (\d+)\n$', str(capture)) + found = re.search(r"^operator new called, returning (\d+)\n$", str(capture)) assert found assert b.i == 100 with capture: @@ -333,7 +347,7 @@ def create_and_destroy(*args): def strip_comments(s): - return re.sub(r'\s+#.*', '', s) + return re.sub(r"\s+#.*", "", s) def test_reallocation_a(capture, msg): @@ -345,7 +359,9 @@ def test_reallocation_a(capture, msg): with capture: create_and_destroy(1) - assert msg(capture) == """ + assert ( + msg(capture) + == """ noisy new noisy placement new NoisyAlloc(int 1) @@ -353,12 +369,14 @@ def test_reallocation_a(capture, msg): ~NoisyAlloc() noisy delete """ + ) def test_reallocation_b(capture, msg): with capture: create_and_destroy(1.5) - assert msg(capture) == strip_comments(""" + assert msg(capture) == strip_comments( + """ noisy new # allocation required to attempt first overload noisy delete # have to dealloc before considering factory init overload noisy new # pointer factory calling "new", part 1: allocation @@ -366,51 +384,59 @@ def test_reallocation_b(capture, msg): --- ~NoisyAlloc() # Destructor noisy delete # operator delete - """) + """ + ) def test_reallocation_c(capture, msg): with capture: create_and_destroy(2, 3) - assert msg(capture) == strip_comments(""" + assert msg(capture) == strip_comments( + """ noisy new # pointer factory calling "new", allocation NoisyAlloc(int 2) # constructor --- ~NoisyAlloc() # Destructor noisy delete # operator delete - """) + """ + ) def test_reallocation_d(capture, msg): with capture: create_and_destroy(2.5, 3) - assert msg(capture) == strip_comments(""" + assert msg(capture) == strip_comments( + """ NoisyAlloc(double 2.5) # construction (local func variable: operator_new not called) noisy new # return-by-value "new" part 1: allocation ~NoisyAlloc() # moved-away local func variable destruction --- ~NoisyAlloc() # Destructor noisy delete # operator delete - """) + """ + ) def test_reallocation_e(capture, msg): with capture: create_and_destroy(3.5, 4.5) - assert msg(capture) == strip_comments(""" + assert msg(capture) == strip_comments( + """ noisy new # preallocation needed before invoking placement-new overload noisy placement new # Placement new NoisyAlloc(double 3.5) # construction --- ~NoisyAlloc() # Destructor noisy delete # operator delete - """) + """ + ) def test_reallocation_f(capture, msg): with capture: create_and_destroy(4, 0.5) - assert msg(capture) == strip_comments(""" + assert msg(capture) == strip_comments( + """ noisy new # preallocation needed before invoking placement-new overload noisy delete # deallocation of preallocated storage noisy new # Factory pointer allocation @@ -418,13 +444,15 @@ def test_reallocation_f(capture, msg): --- ~NoisyAlloc() # Destructor noisy delete # operator delete - """) + """ + ) def test_reallocation_g(capture, msg): with capture: create_and_destroy(5, "hi") - assert msg(capture) == strip_comments(""" + assert msg(capture) == strip_comments( + """ noisy new # preallocation needed before invoking first placement new noisy delete # delete before considering new-style constructor noisy new # preallocation for second placement new @@ -433,13 +461,15 @@ def test_reallocation_g(capture, msg): --- ~NoisyAlloc() # Destructor noisy delete # operator delete - """) + """ + ) @pytest.mark.skipif("env.PY2") def test_invalid_self(): """Tests invocation of the pybind-registered base class with an invalid `self` argument. You can only actually do this on Python 3: Python 2 raises an exception itself if you try.""" + class NotPybindDerived(object): pass @@ -456,23 +486,35 @@ def test_invalid_self(): # Same as above, but for a class with an alias: class BrokenTF6(m.TestFactory6): def __init__(self, bad): - if bad == 1: + if bad == 0: + m.TestFactory6.__init__() + elif bad == 1: a = m.TestFactory2(tag.pointer, 1) m.TestFactory6.__init__(a, tag.base, 1) elif bad == 2: a = m.TestFactory2(tag.pointer, 1) m.TestFactory6.__init__(a, tag.alias, 1) elif bad == 3: - m.TestFactory6.__init__(NotPybindDerived.__new__(NotPybindDerived), tag.base, 1) + m.TestFactory6.__init__( + NotPybindDerived.__new__(NotPybindDerived), tag.base, 1 + ) elif bad == 4: - m.TestFactory6.__init__(NotPybindDerived.__new__(NotPybindDerived), tag.alias, 1) + m.TestFactory6.__init__( + NotPybindDerived.__new__(NotPybindDerived), tag.alias, 1 + ) for arg in (1, 2): with pytest.raises(TypeError) as excinfo: BrokenTF1(arg) - assert str(excinfo.value) == "__init__(self, ...) called with invalid `self` argument" + assert ( + str(excinfo.value) + == "__init__(self, ...) called with invalid or missing `self` argument" + ) - for arg in (1, 2, 3, 4): + for arg in (0, 1, 2, 3, 4): with pytest.raises(TypeError) as excinfo: BrokenTF6(arg) - assert str(excinfo.value) == "__init__(self, ...) called with invalid `self` argument" + assert ( + str(excinfo.value) + == "__init__(self, ...) called with invalid or missing `self` argument" + ) diff --git a/wrap/pybind11/tests/test_gil_scoped.cpp b/wrap/pybind11/tests/test_gil_scoped.cpp index eb6308956..b261085c8 100644 --- a/wrap/pybind11/tests/test_gil_scoped.cpp +++ b/wrap/pybind11/tests/test_gil_scoped.cpp @@ -35,20 +35,15 @@ TEST_SUBMODULE(gil_scoped, m) { .def("virtual_func", &VirtClass::virtual_func) .def("pure_virtual_func", &VirtClass::pure_virtual_func); - m.def("test_callback_py_obj", - [](py::object func) { func(); }); - m.def("test_callback_std_func", - [](const std::function &func) { func(); }); - m.def("test_callback_virtual_func", - [](VirtClass &virt) { virt.virtual_func(); }); - m.def("test_callback_pure_virtual_func", - [](VirtClass &virt) { virt.pure_virtual_func(); }); - m.def("test_cross_module_gil", - []() { - auto cm = py::module::import("cross_module_gil_utils"); - auto gil_acquire = reinterpret_cast( - PyLong_AsVoidPtr(cm.attr("gil_acquire_funcaddr").ptr())); - py::gil_scoped_release gil_release; - gil_acquire(); - }); + m.def("test_callback_py_obj", [](py::object &func) { func(); }); + m.def("test_callback_std_func", [](const std::function &func) { func(); }); + m.def("test_callback_virtual_func", [](VirtClass &virt) { virt.virtual_func(); }); + m.def("test_callback_pure_virtual_func", [](VirtClass &virt) { virt.pure_virtual_func(); }); + m.def("test_cross_module_gil", []() { + auto cm = py::module_::import("cross_module_gil_utils"); + auto gil_acquire + = reinterpret_cast(PyLong_AsVoidPtr(cm.attr("gil_acquire_funcaddr").ptr())); + py::gil_scoped_release gil_release; + gil_acquire(); + }); } diff --git a/wrap/pybind11/tests/test_gil_scoped.py b/wrap/pybind11/tests/test_gil_scoped.py index 27122cca2..0a1d62747 100644 --- a/wrap/pybind11/tests/test_gil_scoped.py +++ b/wrap/pybind11/tests/test_gil_scoped.py @@ -2,10 +2,6 @@ import multiprocessing import threading -import pytest - -import env # noqa: F401 - from pybind11_tests import gil_scoped as m @@ -25,6 +21,7 @@ def _run_in_process(target, *args, **kwargs): def _python_to_cpp_to_python(): """Calls different C++ functions that come back to Python.""" + class ExtendedVirtClass(m.VirtClass): def virtual_func(self): pass @@ -54,8 +51,7 @@ def _python_to_cpp_to_python_from_threads(num_threads, parallel=False): thread.join() -# TODO: FIXME, sometimes returns -11 instead of 0 -@pytest.mark.xfail("env.PY > (3,8) and env.MACOS", strict=False) +# TODO: FIXME, sometimes returns -11 (segfault) instead of 0 on macOS Python 3.9 def test_python_to_cpp_to_python_from_thread(): """Makes sure there is no GIL deadlock when running in a thread. @@ -64,8 +60,7 @@ def test_python_to_cpp_to_python_from_thread(): assert _run_in_process(_python_to_cpp_to_python_from_threads, 1) == 0 -# TODO: FIXME -@pytest.mark.xfail("env.PY > (3,8) and env.MACOS", strict=False) +# TODO: FIXME on macOS Python 3.9 def test_python_to_cpp_to_python_from_thread_multiple_parallel(): """Makes sure there is no GIL deadlock when running in a thread multiple times in parallel. @@ -74,18 +69,18 @@ def test_python_to_cpp_to_python_from_thread_multiple_parallel(): assert _run_in_process(_python_to_cpp_to_python_from_threads, 8, parallel=True) == 0 -# TODO: FIXME -@pytest.mark.xfail("env.PY > (3,8) and env.MACOS", strict=False) +# TODO: FIXME on macOS Python 3.9 def test_python_to_cpp_to_python_from_thread_multiple_sequential(): """Makes sure there is no GIL deadlock when running in a thread multiple times sequentially. It runs in a separate process to be able to stop and assert if it deadlocks. """ - assert _run_in_process(_python_to_cpp_to_python_from_threads, 8, parallel=False) == 0 + assert ( + _run_in_process(_python_to_cpp_to_python_from_threads, 8, parallel=False) == 0 + ) -# TODO: FIXME -@pytest.mark.xfail("env.PY > (3,8) and env.MACOS", strict=False) +# TODO: FIXME on macOS Python 3.9 def test_python_to_cpp_to_python_from_process(): """Makes sure there is no GIL deadlock when using processes. diff --git a/wrap/pybind11/tests/test_iostream.cpp b/wrap/pybind11/tests/test_iostream.cpp index e67f88af5..c620b5949 100644 --- a/wrap/pybind11/tests/test_iostream.cpp +++ b/wrap/pybind11/tests/test_iostream.cpp @@ -7,37 +7,87 @@ BSD-style license that can be found in the LICENSE file. */ +#if defined(_MSC_VER) && _MSC_VER < 1910 // VS 2015's MSVC +# pragma warning(disable: 4702) // unreachable code in system header (xatomic.h(382)) +#endif #include #include "pybind11_tests.h" +#include #include +#include +#include +#include - -void noisy_function(std::string msg, bool flush) { +void noisy_function(const std::string &msg, bool flush) { std::cout << msg; if (flush) std::cout << std::flush; } -void noisy_funct_dual(std::string msg, std::string emsg) { +void noisy_funct_dual(const std::string &msg, const std::string &emsg) { std::cout << msg; std::cerr << emsg; } +// object to manage C++ thread +// simply repeatedly write to std::cerr until stopped +// redirect is called at some point to test the safety of scoped_estream_redirect +struct TestThread { + TestThread() : stop_{false} { + auto thread_f = [this] { + static std::mutex cout_mutex; + while (!stop_) { + { + // #HelpAppreciated: Work on iostream.h thread safety. + // Without this lock, the clang ThreadSanitizer (tsan) reliably reports a + // data race, and this test is predictably flakey on Windows. + // For more background see the discussion under + // https://github.com/pybind/pybind11/pull/2982 and + // https://github.com/pybind/pybind11/pull/2995. + const std::lock_guard lock(cout_mutex); + std::cout << "x" << std::flush; + } + std::this_thread::sleep_for(std::chrono::microseconds(50)); + } }; + t_ = new std::thread(std::move(thread_f)); + } + + ~TestThread() { + delete t_; + } + + void stop() { stop_ = true; } + + void join() const { + py::gil_scoped_release gil_lock; + t_->join(); + } + + void sleep() { + py::gil_scoped_release gil_lock; + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + + std::thread *t_{nullptr}; + std::atomic stop_; +}; + + TEST_SUBMODULE(iostream, m) { add_ostream_redirect(m); // test_evals - m.def("captured_output_default", [](std::string msg) { + m.def("captured_output_default", [](const std::string &msg) { py::scoped_ostream_redirect redir; std::cout << msg << std::flush; }); - m.def("captured_output", [](std::string msg) { - py::scoped_ostream_redirect redir(std::cout, py::module::import("sys").attr("stdout")); + m.def("captured_output", [](const std::string &msg) { + py::scoped_ostream_redirect redir(std::cout, py::module_::import("sys").attr("stdout")); std::cout << msg << std::flush; }); @@ -45,8 +95,8 @@ TEST_SUBMODULE(iostream, m) { py::call_guard(), py::arg("msg"), py::arg("flush")=true); - m.def("captured_err", [](std::string msg) { - py::scoped_ostream_redirect redir(std::cerr, py::module::import("sys").attr("stderr")); + m.def("captured_err", [](const std::string &msg) { + py::scoped_ostream_redirect redir(std::cerr, py::module_::import("sys").attr("stderr")); std::cerr << msg << std::flush; }); @@ -56,18 +106,20 @@ TEST_SUBMODULE(iostream, m) { py::call_guard(), py::arg("msg"), py::arg("emsg")); - m.def("raw_output", [](std::string msg) { - std::cout << msg << std::flush; - }); + m.def("raw_output", [](const std::string &msg) { std::cout << msg << std::flush; }); - m.def("raw_err", [](std::string msg) { - std::cerr << msg << std::flush; - }); + m.def("raw_err", [](const std::string &msg) { std::cerr << msg << std::flush; }); - m.def("captured_dual", [](std::string msg, std::string emsg) { - py::scoped_ostream_redirect redirout(std::cout, py::module::import("sys").attr("stdout")); - py::scoped_ostream_redirect redirerr(std::cerr, py::module::import("sys").attr("stderr")); + m.def("captured_dual", [](const std::string &msg, const std::string &emsg) { + py::scoped_ostream_redirect redirout(std::cout, py::module_::import("sys").attr("stdout")); + py::scoped_ostream_redirect redirerr(std::cerr, py::module_::import("sys").attr("stderr")); std::cout << msg << std::flush; std::cerr << emsg << std::flush; }); + + py::class_(m, "TestThread") + .def(py::init<>()) + .def("stop", &TestThread::stop) + .def("join", &TestThread::join) + .def("sleep", &TestThread::sleep); } diff --git a/wrap/pybind11/tests/test_iostream.py b/wrap/pybind11/tests/test_iostream.py index 7ac4fcece..7f18ca65c 100644 --- a/wrap/pybind11/tests/test_iostream.py +++ b/wrap/pybind11/tests/test_iostream.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- -from pybind11_tests import iostream as m import sys - from contextlib import contextmanager +from pybind11_tests import iostream as m + try: # Python 3 from io import StringIO @@ -18,6 +18,7 @@ try: # Python 3.4 from contextlib import redirect_stdout except ImportError: + @contextmanager def redirect_stdout(target): original = sys.stdout @@ -25,10 +26,12 @@ except ImportError: yield sys.stdout = original + try: # Python 3.5 from contextlib import redirect_stderr except ImportError: + @contextmanager def redirect_stderr(target): original = sys.stderr @@ -42,16 +45,16 @@ def test_captured(capsys): m.captured_output(msg) stdout, stderr = capsys.readouterr() assert stdout == msg - assert stderr == '' + assert stderr == "" m.captured_output_default(msg) stdout, stderr = capsys.readouterr() assert stdout == msg - assert stderr == '' + assert stderr == "" m.captured_err(msg) stdout, stderr = capsys.readouterr() - assert stdout == '' + assert stdout == "" assert stderr == msg @@ -63,7 +66,97 @@ def test_captured_large_string(capsys): m.captured_output_default(msg) stdout, stderr = capsys.readouterr() assert stdout == msg - assert stderr == '' + assert stderr == "" + + +def test_captured_utf8_2byte_offset0(capsys): + msg = "\u07FF" + msg = "" + msg * (1024 // len(msg) + 1) + + m.captured_output_default(msg) + stdout, stderr = capsys.readouterr() + assert stdout == msg + assert stderr == "" + + +def test_captured_utf8_2byte_offset1(capsys): + msg = "\u07FF" + msg = "1" + msg * (1024 // len(msg) + 1) + + m.captured_output_default(msg) + stdout, stderr = capsys.readouterr() + assert stdout == msg + assert stderr == "" + + +def test_captured_utf8_3byte_offset0(capsys): + msg = "\uFFFF" + msg = "" + msg * (1024 // len(msg) + 1) + + m.captured_output_default(msg) + stdout, stderr = capsys.readouterr() + assert stdout == msg + assert stderr == "" + + +def test_captured_utf8_3byte_offset1(capsys): + msg = "\uFFFF" + msg = "1" + msg * (1024 // len(msg) + 1) + + m.captured_output_default(msg) + stdout, stderr = capsys.readouterr() + assert stdout == msg + assert stderr == "" + + +def test_captured_utf8_3byte_offset2(capsys): + msg = "\uFFFF" + msg = "12" + msg * (1024 // len(msg) + 1) + + m.captured_output_default(msg) + stdout, stderr = capsys.readouterr() + assert stdout == msg + assert stderr == "" + + +def test_captured_utf8_4byte_offset0(capsys): + msg = "\U0010FFFF" + msg = "" + msg * (1024 // len(msg) + 1) + + m.captured_output_default(msg) + stdout, stderr = capsys.readouterr() + assert stdout == msg + assert stderr == "" + + +def test_captured_utf8_4byte_offset1(capsys): + msg = "\U0010FFFF" + msg = "1" + msg * (1024 // len(msg) + 1) + + m.captured_output_default(msg) + stdout, stderr = capsys.readouterr() + assert stdout == msg + assert stderr == "" + + +def test_captured_utf8_4byte_offset2(capsys): + msg = "\U0010FFFF" + msg = "12" + msg * (1024 // len(msg) + 1) + + m.captured_output_default(msg) + stdout, stderr = capsys.readouterr() + assert stdout == msg + assert stderr == "" + + +def test_captured_utf8_4byte_offset3(capsys): + msg = "\U0010FFFF" + msg = "123" + msg * (1024 // len(msg) + 1) + + m.captured_output_default(msg) + stdout, stderr = capsys.readouterr() + assert stdout == msg + assert stderr == "" def test_guard_capture(capsys): @@ -71,7 +164,7 @@ def test_guard_capture(capsys): m.guard_output(msg) stdout, stderr = capsys.readouterr() assert stdout == msg - assert stderr == '' + assert stderr == "" def test_series_captured(capture): @@ -88,7 +181,7 @@ def test_flush(capfd): with m.ostream_redirect(): m.noisy_function(msg, flush=False) stdout, stderr = capfd.readouterr() - assert stdout == '' + assert stdout == "" m.noisy_function(msg2, flush=True) stdout, stderr = capfd.readouterr() @@ -107,15 +200,15 @@ def test_not_captured(capfd): m.raw_output(msg) stdout, stderr = capfd.readouterr() assert stdout == msg - assert stderr == '' - assert stream.getvalue() == '' + assert stderr == "" + assert stream.getvalue() == "" stream = StringIO() with redirect_stdout(stream): m.captured_output(msg) stdout, stderr = capfd.readouterr() - assert stdout == '' - assert stderr == '' + assert stdout == "" + assert stderr == "" assert stream.getvalue() == msg @@ -125,16 +218,16 @@ def test_err(capfd): with redirect_stderr(stream): m.raw_err(msg) stdout, stderr = capfd.readouterr() - assert stdout == '' + assert stdout == "" assert stderr == msg - assert stream.getvalue() == '' + assert stream.getvalue() == "" stream = StringIO() with redirect_stderr(stream): m.captured_err(msg) stdout, stderr = capfd.readouterr() - assert stdout == '' - assert stderr == '' + assert stdout == "" + assert stderr == "" assert stream.getvalue() == msg @@ -146,8 +239,8 @@ def test_multi_captured(capfd): m.captured_output("c") m.raw_output("d") stdout, stderr = capfd.readouterr() - assert stdout == 'bd' - assert stream.getvalue() == 'ac' + assert stdout == "bd" + assert stream.getvalue() == "ac" def test_dual(capsys): @@ -164,14 +257,14 @@ def test_redirect(capfd): m.raw_output(msg) stdout, stderr = capfd.readouterr() assert stdout == msg - assert stream.getvalue() == '' + assert stream.getvalue() == "" stream = StringIO() with redirect_stdout(stream): with m.ostream_redirect(): m.raw_output(msg) stdout, stderr = capfd.readouterr() - assert stdout == '' + assert stdout == "" assert stream.getvalue() == msg stream = StringIO() @@ -179,7 +272,7 @@ def test_redirect(capfd): m.raw_output(msg) stdout, stderr = capfd.readouterr() assert stdout == msg - assert stream.getvalue() == '' + assert stream.getvalue() == "" def test_redirect_err(capfd): @@ -193,7 +286,7 @@ def test_redirect_err(capfd): m.raw_err(msg2) stdout, stderr = capfd.readouterr() assert stdout == msg - assert stderr == '' + assert stderr == "" assert stream.getvalue() == msg2 @@ -209,7 +302,30 @@ def test_redirect_both(capfd): m.raw_output(msg) m.raw_err(msg2) stdout, stderr = capfd.readouterr() - assert stdout == '' - assert stderr == '' + assert stdout == "" + assert stderr == "" assert stream.getvalue() == msg assert stream2.getvalue() == msg2 + + +def test_threading(): + with m.ostream_redirect(stdout=True, stderr=False): + # start some threads + threads = [] + + # start some threads + for _j in range(20): + threads.append(m.TestThread()) + + # give the threads some time to fail + threads[0].sleep() + + # stop all the threads + for t in threads: + t.stop() + + for t in threads: + t.join() + + # if a thread segfaults, we don't get here + assert True diff --git a/wrap/pybind11/tests/test_kwargs_and_defaults.cpp b/wrap/pybind11/tests/test_kwargs_and_defaults.cpp index 641ec88c4..34ad2a864 100644 --- a/wrap/pybind11/tests/test_kwargs_and_defaults.cpp +++ b/wrap/pybind11/tests/test_kwargs_and_defaults.cpp @@ -11,6 +11,8 @@ #include "constructor_stats.h" #include +#include + TEST_SUBMODULE(kwargs_and_defaults, m) { auto kw_func = [](int x, int y) { return "x=" + std::to_string(x) + ", y=" + std::to_string(y); }; @@ -37,18 +39,16 @@ TEST_SUBMODULE(kwargs_and_defaults, m) { m.def("args_function", [](py::args args) -> py::tuple { return std::move(args); }); - m.def("args_kwargs_function", [](py::args args, py::kwargs kwargs) { + m.def("args_kwargs_function", [](const py::args &args, const py::kwargs &kwargs) { return py::make_tuple(args, kwargs); }); // test_mixed_args_and_kwargs - m.def("mixed_plus_args", [](int i, double j, py::args args) { - return py::make_tuple(i, j, args); - }); - m.def("mixed_plus_kwargs", [](int i, double j, py::kwargs kwargs) { - return py::make_tuple(i, j, kwargs); - }); - auto mixed_plus_both = [](int i, double j, py::args args, py::kwargs kwargs) { + m.def("mixed_plus_args", + [](int i, double j, const py::args &args) { return py::make_tuple(i, j, args); }); + m.def("mixed_plus_kwargs", + [](int i, double j, const py::kwargs &kwargs) { return py::make_tuple(i, j, kwargs); }); + auto mixed_plus_both = [](int i, double j, const py::args &args, const py::kwargs &kwargs) { return py::make_tuple(i, j, args, kwargs); }; m.def("mixed_plus_args_kwargs", mixed_plus_both); @@ -56,6 +56,23 @@ TEST_SUBMODULE(kwargs_and_defaults, m) { m.def("mixed_plus_args_kwargs_defaults", mixed_plus_both, py::arg("i") = 1, py::arg("j") = 3.14159); + m.def("args_kwonly", + [](int i, double j, const py::args &args, int z) { return py::make_tuple(i, j, args, z); }, + "i"_a, "j"_a, "z"_a); + m.def("args_kwonly_kwargs", + [](int i, double j, const py::args &args, int z, const py::kwargs &kwargs) { + return py::make_tuple(i, j, args, z, kwargs); }, + "i"_a, "j"_a, py::kw_only{}, "z"_a); + m.def("args_kwonly_kwargs_defaults", + [](int i, double j, const py::args &args, int z, const py::kwargs &kwargs) { + return py::make_tuple(i, j, args, z, kwargs); }, + "i"_a = 1, "j"_a = 3.14159, "z"_a = 42); + m.def("args_kwonly_full_monty", + [](int h, int i, double j, const py::args &args, int z, const py::kwargs &kwargs) { + return py::make_tuple(h, i, j, args, z, kwargs); }, + py::arg() = 1, py::arg() = 2, py::pos_only{}, "j"_a = 3.14159, "z"_a = 42); + + // test_args_refcount // PyPy needs a garbage collection to get the reference count values to match CPython's behaviour #ifdef PYPY_VERSION @@ -65,22 +82,25 @@ TEST_SUBMODULE(kwargs_and_defaults, m) { #endif m.def("arg_refcount_h", [](py::handle h) { GC_IF_NEEDED; return h.ref_count(); }); m.def("arg_refcount_h", [](py::handle h, py::handle, py::handle) { GC_IF_NEEDED; return h.ref_count(); }); - m.def("arg_refcount_o", [](py::object o) { GC_IF_NEEDED; return o.ref_count(); }); + m.def("arg_refcount_o", [](const py::object &o) { + GC_IF_NEEDED; + return o.ref_count(); + }); m.def("args_refcount", [](py::args a) { GC_IF_NEEDED; py::tuple t(a.size()); for (size_t i = 0; i < a.size(); i++) // Use raw Python API here to avoid an extra, intermediate incref on the tuple item: - t[i] = (int) Py_REFCNT(PyTuple_GET_ITEM(a.ptr(), static_cast(i))); + t[i] = (int) Py_REFCNT(PyTuple_GET_ITEM(a.ptr(), static_cast(i))); return t; }); - m.def("mixed_args_refcount", [](py::object o, py::args a) { + m.def("mixed_args_refcount", [](const py::object &o, py::args a) { GC_IF_NEEDED; py::tuple t(a.size() + 1); t[0] = o.ref_count(); for (size_t i = 0; i < a.size(); i++) // Use raw Python API here to avoid an extra, intermediate incref on the tuple item: - t[i + 1] = (int) Py_REFCNT(PyTuple_GET_ITEM(a.ptr(), static_cast(i))); + t[i + 1] = (int) Py_REFCNT(PyTuple_GET_ITEM(a.ptr(), static_cast(i))); return t; }); @@ -103,11 +123,17 @@ TEST_SUBMODULE(kwargs_and_defaults, m) { py::arg() = 3, "j"_a = 4, py::kw_only(), "k"_a = 5, "z"_a); m.def("kw_only_mixed", [](int i, int j) { return py::make_tuple(i, j); }, "i"_a, py::kw_only(), "j"_a); - m.def("kw_only_plus_more", [](int i, int j, int k, py::kwargs kwargs) { - return py::make_tuple(i, j, k, kwargs); }, - py::arg() /* positional */, py::arg("j") = -1 /* both */, py::kw_only(), py::arg("k") /* kw-only */); + m.def( + "kw_only_plus_more", + [](int i, int j, int k, const py::kwargs &kwargs) { + return py::make_tuple(i, j, k, kwargs); + }, + py::arg() /* positional */, + py::arg("j") = -1 /* both */, + py::kw_only(), + py::arg("k") /* kw-only */); - m.def("register_invalid_kw_only", [](py::module m) { + m.def("register_invalid_kw_only", [](py::module_ m) { m.def("bad_kw_only", [](int i, int j) { return py::make_tuple(i, j); }, py::kw_only(), py::arg() /* invalid unnamed argument */, "j"_a); }); @@ -137,6 +163,25 @@ TEST_SUBMODULE(kwargs_and_defaults, m) { // Make sure a class (not an instance) can be used as a default argument. // The return value doesn't matter, only that the module is importable. - m.def("class_default_argument", [](py::object a) { return py::repr(a); }, - "a"_a = py::module::import("decimal").attr("Decimal")); + m.def( + "class_default_argument", + [](py::object a) { return py::repr(std::move(a)); }, + "a"_a = py::module_::import("decimal").attr("Decimal")); + + // Initial implementation of kw_only was broken when used on a method/constructor before any + // other arguments + // https://github.com/pybind/pybind11/pull/3402#issuecomment-963341987 + + struct first_arg_kw_only {}; + py::class_(m, "first_arg_kw_only") + .def(py::init([](int) { return first_arg_kw_only(); }), + py::kw_only(), // This being before any args was broken + py::arg("i") = 0) + .def("method", [](first_arg_kw_only&, int, int) {}, + py::kw_only(), // and likewise here + py::arg("i") = 1, py::arg("j") = 2) + // Closely related: pos_only marker didn't show up properly when it was before any other + // arguments (although that is fairly useless in practice). + .def("pos_only", [](first_arg_kw_only&, int, int) {}, + py::pos_only{}, py::arg("i"), py::arg("j")); } diff --git a/wrap/pybind11/tests/test_kwargs_and_defaults.py b/wrap/pybind11/tests/test_kwargs_and_defaults.py index 2a81dbdc5..d61cf2aa5 100644 --- a/wrap/pybind11/tests/test_kwargs_and_defaults.py +++ b/wrap/pybind11/tests/test_kwargs_and_defaults.py @@ -2,7 +2,6 @@ import pytest import env # noqa: F401 - from pybind11_tests import kwargs_and_defaults as m @@ -15,11 +14,17 @@ def test_function_signatures(doc): assert doc(m.kw_func_udl) == "kw_func_udl(x: int, y: int = 300) -> str" assert doc(m.kw_func_udl_z) == "kw_func_udl_z(x: int, y: int = 0) -> str" assert doc(m.args_function) == "args_function(*args) -> tuple" - assert doc(m.args_kwargs_function) == "args_kwargs_function(*args, **kwargs) -> tuple" - assert doc(m.KWClass.foo0) == \ - "foo0(self: m.kwargs_and_defaults.KWClass, arg0: int, arg1: float) -> None" - assert doc(m.KWClass.foo1) == \ - "foo1(self: m.kwargs_and_defaults.KWClass, x: int, y: float) -> None" + assert ( + doc(m.args_kwargs_function) == "args_kwargs_function(*args, **kwargs) -> tuple" + ) + assert ( + doc(m.KWClass.foo0) + == "foo0(self: m.kwargs_and_defaults.KWClass, arg0: int, arg1: float) -> None" + ) + assert ( + doc(m.KWClass.foo1) + == "foo1(self: m.kwargs_and_defaults.KWClass, x: int, y: float) -> None" + ) def test_named_arguments(msg): @@ -40,7 +45,9 @@ def test_named_arguments(msg): # noinspection PyArgumentList m.kw_func2(x=5, y=10, z=12) assert excinfo.match( - r'(?s)^kw_func2\(\): incompatible.*Invoked with: kwargs: ((x=5|y=10|z=12)(, |$))' + '{3}$') + r"(?s)^kw_func2\(\): incompatible.*Invoked with: kwargs: ((x=5|y=10|z=12)(, |$))" + + "{3}$" + ) assert m.kw_func4() == "{13 17}" assert m.kw_func4(myList=[1, 2, 3]) == "{1 2 3}" @@ -50,11 +57,11 @@ def test_named_arguments(msg): def test_arg_and_kwargs(): - args = 'arg1_value', 'arg2_value', 3 + args = "arg1_value", "arg2_value", 3 assert m.args_function(*args) == args - args = 'a1', 'a2' - kwargs = dict(arg3='a3', arg4=4) + args = "a1", "a2" + kwargs = dict(arg3="a3", arg4=4) assert m.args_kwargs_function(*args, **kwargs) == (args, kwargs) @@ -68,47 +75,118 @@ def test_mixed_args_and_kwargs(msg): assert mpa(1, 2.5) == (1, 2.5, ()) with pytest.raises(TypeError) as excinfo: assert mpa(1) - assert msg(excinfo.value) == """ + assert ( + msg(excinfo.value) + == """ mixed_plus_args(): incompatible function arguments. The following argument types are supported: 1. (arg0: int, arg1: float, *args) -> tuple Invoked with: 1 """ # noqa: E501 line too long + ) with pytest.raises(TypeError) as excinfo: assert mpa() - assert msg(excinfo.value) == """ + assert ( + msg(excinfo.value) + == """ mixed_plus_args(): incompatible function arguments. The following argument types are supported: 1. (arg0: int, arg1: float, *args) -> tuple Invoked with: """ # noqa: E501 line too long + ) - assert mpk(-2, 3.5, pi=3.14159, e=2.71828) == (-2, 3.5, {'e': 2.71828, 'pi': 3.14159}) + assert mpk(-2, 3.5, pi=3.14159, e=2.71828) == ( + -2, + 3.5, + {"e": 2.71828, "pi": 3.14159}, + ) assert mpak(7, 7.7, 7.77, 7.777, 7.7777, minusseven=-7) == ( - 7, 7.7, (7.77, 7.777, 7.7777), {'minusseven': -7}) + 7, + 7.7, + (7.77, 7.777, 7.7777), + {"minusseven": -7}, + ) assert mpakd() == (1, 3.14159, (), {}) assert mpakd(3) == (3, 3.14159, (), {}) assert mpakd(j=2.71828) == (1, 2.71828, (), {}) - assert mpakd(k=42) == (1, 3.14159, (), {'k': 42}) + assert mpakd(k=42) == (1, 3.14159, (), {"k": 42}) assert mpakd(1, 1, 2, 3, 5, 8, then=13, followedby=21) == ( - 1, 1, (2, 3, 5, 8), {'then': 13, 'followedby': 21}) + 1, + 1, + (2, 3, 5, 8), + {"then": 13, "followedby": 21}, + ) # Arguments specified both positionally and via kwargs should fail: with pytest.raises(TypeError) as excinfo: assert mpakd(1, i=1) - assert msg(excinfo.value) == """ + assert ( + msg(excinfo.value) + == """ mixed_plus_args_kwargs_defaults(): incompatible function arguments. The following argument types are supported: 1. (i: int = 1, j: float = 3.14159, *args, **kwargs) -> tuple Invoked with: 1; kwargs: i=1 """ # noqa: E501 line too long + ) with pytest.raises(TypeError) as excinfo: assert mpakd(1, 2, j=1) - assert msg(excinfo.value) == """ + assert ( + msg(excinfo.value) + == """ mixed_plus_args_kwargs_defaults(): incompatible function arguments. The following argument types are supported: 1. (i: int = 1, j: float = 3.14159, *args, **kwargs) -> tuple Invoked with: 1, 2; kwargs: j=1 """ # noqa: E501 line too long + ) + + # Arguments after a py::args are automatically keyword-only (pybind 2.9+) + assert m.args_kwonly(2, 2.5, z=22) == (2, 2.5, (), 22) + assert m.args_kwonly(2, 2.5, "a", "b", "c", z=22) == (2, 2.5, ("a", "b", "c"), 22) + assert m.args_kwonly(z=22, i=4, j=16) == (4, 16, (), 22) + + with pytest.raises(TypeError) as excinfo: + assert m.args_kwonly(2, 2.5, 22) # missing z= keyword + assert ( + msg(excinfo.value) + == """ + args_kwonly(): incompatible function arguments. The following argument types are supported: + 1. (i: int, j: float, *args, z: int) -> tuple + + Invoked with: 2, 2.5, 22 + """ + ) + + assert m.args_kwonly_kwargs(i=1, k=4, j=10, z=-1, y=9) == ( + 1, + 10, + (), + -1, + {"k": 4, "y": 9}, + ) + assert m.args_kwonly_kwargs(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, z=11, y=12) == ( + 1, + 2, + (3, 4, 5, 6, 7, 8, 9, 10), + 11, + {"y": 12}, + ) + assert ( + m.args_kwonly_kwargs.__doc__ + == "args_kwonly_kwargs(i: int, j: float, *args, z: int, **kwargs) -> tuple\n" + ) + + assert ( + m.args_kwonly_kwargs_defaults.__doc__ + == "args_kwonly_kwargs_defaults(i: int = 1, j: float = 3.14159, *args, z: int = 42, **kwargs) -> tuple\n" # noqa: E501 line too long + ) + assert m.args_kwonly_kwargs_defaults() == (1, 3.14159, (), 42, {}) + assert m.args_kwonly_kwargs_defaults(2) == (2, 3.14159, (), 42, {}) + assert m.args_kwonly_kwargs_defaults(z=-99) == (1, 3.14159, (), -99, {}) + assert m.args_kwonly_kwargs_defaults(5, 6, 7, 8) == (5, 6, (7, 8), 42, {}) + assert m.args_kwonly_kwargs_defaults(5, 6, 7, m=8) == (5, 6, (7,), 42, {"m": 8}) + assert m.args_kwonly_kwargs_defaults(5, 6, 7, m=8, z=9) == (5, 6, (7,), 9, {"m": 8}) def test_keyword_only_args(msg): @@ -134,9 +212,9 @@ def test_keyword_only_args(msg): assert m.kw_only_mixed(j=2, i=3) == (3, 2) assert m.kw_only_mixed(i=2, j=3) == (2, 3) - assert m.kw_only_plus_more(4, 5, k=6, extra=7) == (4, 5, 6, {'extra': 7}) - assert m.kw_only_plus_more(3, k=5, j=4, extra=6) == (3, 4, 5, {'extra': 6}) - assert m.kw_only_plus_more(2, k=3, extra=4) == (2, -1, 3, {'extra': 4}) + assert m.kw_only_plus_more(4, 5, k=6, extra=7) == (4, 5, 6, {"extra": 7}) + assert m.kw_only_plus_more(3, k=5, j=4, extra=6) == (3, 4, 5, {"extra": 6}) + assert m.kw_only_plus_more(2, k=3, extra=4) == (2, -1, 3, {"extra": 4}) with pytest.raises(TypeError) as excinfo: assert m.kw_only_mixed(i=1) == (1,) @@ -144,9 +222,25 @@ def test_keyword_only_args(msg): with pytest.raises(RuntimeError) as excinfo: m.register_invalid_kw_only(m) - assert msg(excinfo.value) == """ - arg(): cannot specify an unnamed argument after an kw_only() annotation + assert ( + msg(excinfo.value) + == """ + arg(): cannot specify an unnamed argument after a kw_only() annotation or args() argument """ + ) + + # https://github.com/pybind/pybind11/pull/3402#issuecomment-963341987 + x = m.first_arg_kw_only(i=1) + x.method() + x.method(i=1, j=2) + assert ( + m.first_arg_kw_only.__init__.__doc__ + == "__init__(self: pybind11_tests.kwargs_and_defaults.first_arg_kw_only, *, i: int = 0) -> None\n" # noqa: E501 line too long + ) + assert ( + m.first_arg_kw_only.method.__doc__ + == "method(self: pybind11_tests.kwargs_and_defaults.first_arg_kw_only, *, i: int = 1, j: int = 2) -> None\n" # noqa: E501 line too long + ) def test_positional_only_args(msg): @@ -188,13 +282,65 @@ def test_positional_only_args(msg): m.pos_only_def_mix(1, j=4) assert "incompatible function arguments" in str(excinfo.value) + # Mix it with args and kwargs: + assert ( + m.args_kwonly_full_monty.__doc__ + == "args_kwonly_full_monty(arg0: int = 1, arg1: int = 2, /, j: float = 3.14159, *args, z: int = 42, **kwargs) -> tuple\n" # noqa: E501 line too long + ) + assert m.args_kwonly_full_monty() == (1, 2, 3.14159, (), 42, {}) + assert m.args_kwonly_full_monty(8) == (8, 2, 3.14159, (), 42, {}) + assert m.args_kwonly_full_monty(8, 9) == (8, 9, 3.14159, (), 42, {}) + assert m.args_kwonly_full_monty(8, 9, 10) == (8, 9, 10.0, (), 42, {}) + assert m.args_kwonly_full_monty(3, 4, 5, 6, 7, m=8, z=9) == ( + 3, + 4, + 5.0, + ( + 6, + 7, + ), + 9, + {"m": 8}, + ) + assert m.args_kwonly_full_monty(3, 4, 5, 6, 7, m=8, z=9) == ( + 3, + 4, + 5.0, + ( + 6, + 7, + ), + 9, + {"m": 8}, + ) + assert m.args_kwonly_full_monty(5, j=7, m=8, z=9) == (5, 2, 7.0, (), 9, {"m": 8}) + assert m.args_kwonly_full_monty(i=5, j=7, m=8, z=9) == ( + 1, + 2, + 7.0, + (), + 9, + {"i": 5, "m": 8}, + ) + + # pos_only at the beginning of the argument list was "broken" in how it was displayed (though + # this is fairly useless in practice). Related to: + # https://github.com/pybind/pybind11/pull/3402#issuecomment-963341987 + assert ( + m.first_arg_kw_only.pos_only.__doc__ + == "pos_only(self: pybind11_tests.kwargs_and_defaults.first_arg_kw_only, /, i: int, j: int) -> None\n" # noqa: E501 line too long + ) + def test_signatures(): assert "kw_only_all(*, i: int, j: int) -> tuple\n" == m.kw_only_all.__doc__ assert "kw_only_mixed(i: int, *, j: int) -> tuple\n" == m.kw_only_mixed.__doc__ assert "pos_only_all(i: int, j: int, /) -> tuple\n" == m.pos_only_all.__doc__ assert "pos_only_mix(i: int, /, j: int) -> tuple\n" == m.pos_only_mix.__doc__ - assert "pos_kw_only_mix(i: int, /, j: int, *, k: int) -> tuple\n" == m.pos_kw_only_mix.__doc__ + assert ( + "pos_kw_only_mix(i: int, /, j: int, *, k: int) -> tuple\n" + == m.pos_kw_only_mix.__doc__ + ) @pytest.mark.xfail("env.PYPY and env.PY2", reason="PyPy2 doesn't double count") @@ -219,11 +365,18 @@ def test_args_refcount(): assert m.args_function(-1, myval) == (-1, myval) assert refcount(myval) == expected - assert m.mixed_plus_args_kwargs(5, 6.0, myval, a=myval) == (5, 6.0, (myval,), {"a": myval}) + assert m.mixed_plus_args_kwargs(5, 6.0, myval, a=myval) == ( + 5, + 6.0, + (myval,), + {"a": myval}, + ) assert refcount(myval) == expected - assert m.args_kwargs_function(7, 8, myval, a=1, b=myval) == \ - ((7, 8, myval), {"a": 1, "b": myval}) + assert m.args_kwargs_function(7, 8, myval, a=1, b=myval) == ( + (7, 8, myval), + {"a": 1, "b": myval}, + ) assert refcount(myval) == expected exp3 = refcount(myval, myval, myval) diff --git a/wrap/pybind11/tests/test_local_bindings.cpp b/wrap/pybind11/tests/test_local_bindings.cpp index 97c02dbeb..a5808e2f2 100644 --- a/wrap/pybind11/tests/test_local_bindings.cpp +++ b/wrap/pybind11/tests/test_local_bindings.cpp @@ -10,9 +10,12 @@ #include "pybind11_tests.h" #include "local_bindings.h" + #include #include + #include +#include TEST_SUBMODULE(local_bindings, m) { // test_load_external @@ -41,7 +44,7 @@ TEST_SUBMODULE(local_bindings, m) { // should raise a runtime error from the duplicate definition attempt. If test_class isn't // available it *also* throws a runtime error (with "test_class not enabled" as value). m.def("register_local_external", [m]() { - auto main = py::module::import("pybind11_tests"); + auto main = py::module_::import("pybind11_tests"); if (py::hasattr(main, "class_")) { bind_local(m, "LocalExternal", py::module_local()); } @@ -75,7 +78,7 @@ TEST_SUBMODULE(local_bindings, m) { m.def("get_mixed_lg", [](int i) { return MixedLocalGlobal(i); }); // test_internal_locals_differ - m.def("local_cpp_types_addr", []() { return (uintptr_t) &py::detail::registered_local_types_cpp(); }); + m.def("local_cpp_types_addr", []() { return (uintptr_t) &py::detail::get_local_internals().registered_types_cpp; }); // test_stl_caster_vs_stl_bind m.def("load_vector_via_caster", [](std::vector v) { @@ -86,7 +89,10 @@ TEST_SUBMODULE(local_bindings, m) { m.def("return_self", [](LocalVec *v) { return v; }); m.def("return_copy", [](const LocalVec &v) { return LocalVec(v); }); - class Cat : public pets::Pet { public: Cat(std::string name) : Pet(name) {}; }; + class Cat : public pets::Pet { + public: + explicit Cat(std::string name) : Pet(std::move(name)) {} + }; py::class_(m, "Pet", py::module_local()) .def("get_name", &pets::Pet::name); // Binding for local extending class: diff --git a/wrap/pybind11/tests/test_local_bindings.py b/wrap/pybind11/tests/test_local_bindings.py index 5460727e1..52b1b6335 100644 --- a/wrap/pybind11/tests/test_local_bindings.py +++ b/wrap/pybind11/tests/test_local_bindings.py @@ -2,7 +2,6 @@ import pytest import env # noqa: F401 - from pybind11_tests import local_bindings as m @@ -36,8 +35,8 @@ def test_local_bindings(): assert i2.get() == 11 assert i2.get2() == 12 - assert not hasattr(i1, 'get2') - assert not hasattr(i2, 'get3') + assert not hasattr(i1, "get2") + assert not hasattr(i2, "get3") # Loading within the local module assert m.local_value(i1) == 5 @@ -55,7 +54,9 @@ def test_nonlocal_failure(): with pytest.raises(RuntimeError) as excinfo: cm.register_nonlocal() - assert str(excinfo.value) == 'generic_type: type "NonLocalType" is already registered!' + assert ( + str(excinfo.value) == 'generic_type: type "NonLocalType" is already registered!' + ) def test_duplicate_local(): @@ -63,9 +64,12 @@ def test_duplicate_local(): with pytest.raises(RuntimeError) as excinfo: m.register_local_external() import pybind11_tests + assert str(excinfo.value) == ( 'generic_type: type "LocalExternal" is already registered!' - if hasattr(pybind11_tests, 'class_') else 'test_class not enabled') + if hasattr(pybind11_tests, "class_") + else "test_class not enabled" + ) def test_stl_bind_local(): @@ -98,8 +102,8 @@ def test_stl_bind_local(): d1["b"] = v1[1] d2["c"] = v2[0] d2["d"] = v2[1] - assert {i: d1[i].get() for i in d1} == {'a': 0, 'b': 1} - assert {i: d2[i].get() for i in d2} == {'c': 2, 'd': 3} + assert {i: d1[i].get() for i in d1} == {"a": 0, "b": 1} + assert {i: d2[i].get() for i in d2} == {"c": 2, "d": 3} def test_stl_bind_global(): @@ -107,15 +111,21 @@ def test_stl_bind_global(): with pytest.raises(RuntimeError) as excinfo: cm.register_nonlocal_map() - assert str(excinfo.value) == 'generic_type: type "NonLocalMap" is already registered!' + assert ( + str(excinfo.value) == 'generic_type: type "NonLocalMap" is already registered!' + ) with pytest.raises(RuntimeError) as excinfo: cm.register_nonlocal_vec() - assert str(excinfo.value) == 'generic_type: type "NonLocalVec" is already registered!' + assert ( + str(excinfo.value) == 'generic_type: type "NonLocalVec" is already registered!' + ) with pytest.raises(RuntimeError) as excinfo: cm.register_nonlocal_map2() - assert str(excinfo.value) == 'generic_type: type "NonLocalMap2" is already registered!' + assert ( + str(excinfo.value) == 'generic_type: type "NonLocalMap2" is already registered!' + ) def test_mixed_local_global(): @@ -123,6 +133,7 @@ def test_mixed_local_global(): type can be registered even if the type is already registered globally. With the module, casting will go to the local type; outside the module casting goes to the global type.""" import pybind11_cross_module_tests as cm + m.register_mixed_global() m.register_mixed_local() @@ -145,17 +156,30 @@ def test_mixed_local_global(): a.append(cm.get_mixed_gl(11)) a.append(cm.get_mixed_lg(12)) - assert [x.get() for x in a] == \ - [101, 1002, 103, 1004, 105, 1006, 207, 2008, 109, 1010, 211, 2012] + assert [x.get() for x in a] == [ + 101, + 1002, + 103, + 1004, + 105, + 1006, + 207, + 2008, + 109, + 1010, + 211, + 2012, + ] def test_internal_locals_differ(): """Makes sure the internal local type map differs across the two modules""" import pybind11_cross_module_tests as cm + assert m.local_cpp_types_addr() != cm.local_cpp_types_addr() -@pytest.mark.xfail("env.PYPY") +@pytest.mark.xfail("env.PYPY and sys.pypy_version_info < (7, 3, 2)") def test_stl_caster_vs_stl_bind(msg): """One module uses a generic vector caster from `` while the other exports `std::vector` via `py:bind_vector` and `py::module_local`""" @@ -168,13 +192,16 @@ def test_stl_caster_vs_stl_bind(msg): v2 = [1, 2, 3] assert m.load_vector_via_caster(v2) == 6 with pytest.raises(TypeError) as excinfo: - cm.load_vector_via_binding(v2) == 6 - assert msg(excinfo.value) == """ + cm.load_vector_via_binding(v2) + assert ( + msg(excinfo.value) + == """ load_vector_via_binding(): incompatible function arguments. The following argument types are supported: 1. (arg0: pybind11_cross_module_tests.VectorInt) -> int Invoked with: [1, 2, 3] """ # noqa: E501 line too long + ) def test_cross_module_calls(): diff --git a/wrap/pybind11/tests/test_methods_and_attributes.cpp b/wrap/pybind11/tests/test_methods_and_attributes.cpp index 11d4e7b35..9e55452de 100644 --- a/wrap/pybind11/tests/test_methods_and_attributes.cpp +++ b/wrap/pybind11/tests/test_methods_and_attributes.cpp @@ -19,19 +19,21 @@ using overload_cast_ = pybind11::detail::overload_cast_impl; class ExampleMandA { public: ExampleMandA() { print_default_created(this); } - ExampleMandA(int value) : value(value) { print_created(this, value); } + explicit ExampleMandA(int value) : value(value) { print_created(this, value); } ExampleMandA(const ExampleMandA &e) : value(e.value) { print_copy_created(this); } - ExampleMandA(std::string&&) {} - ExampleMandA(ExampleMandA &&e) : value(e.value) { print_move_created(this); } + explicit ExampleMandA(std::string &&) {} + ExampleMandA(ExampleMandA &&e) noexcept : value(e.value) { print_move_created(this); } ~ExampleMandA() { print_destroyed(this); } - std::string toString() { - return "ExampleMandA[value=" + std::to_string(value) + "]"; - } + std::string toString() const { return "ExampleMandA[value=" + std::to_string(value) + "]"; } void operator=(const ExampleMandA &e) { print_copy_assigned(this); value = e.value; } - void operator=(ExampleMandA &&e) { print_move_assigned(this); value = e.value; } + void operator=(ExampleMandA &&e) noexcept { + print_move_assigned(this); + value = e.value; + } + // NOLINTNEXTLINE(performance-unnecessary-value-param) void add1(ExampleMandA other) { value += other.value; } // passing by value void add2(ExampleMandA &other) { value += other.value; } // passing by reference void add3(const ExampleMandA &other) { value += other.value; } // passing by const reference @@ -41,6 +43,7 @@ public: void add6(int other) { value += other; } // passing by value void add7(int &other) { value += other; } // passing by reference void add8(const int &other) { value += other; } // passing by const reference + // NOLINTNEXTLINE(readability-non-const-parameter) Deliberately non-const for testing void add9(int *other) { value += *other; } // passing by pointer void add10(const int *other) { value += *other; } // passing by const pointer @@ -48,13 +51,13 @@ public: ExampleMandA self1() { return *this; } // return by value ExampleMandA &self2() { return *this; } // return by reference - const ExampleMandA &self3() { return *this; } // return by const reference + const ExampleMandA &self3() const { return *this; } // return by const reference ExampleMandA *self4() { return this; } // return by pointer - const ExampleMandA *self5() { return this; } // return by const pointer + const ExampleMandA *self5() const { return this; } // return by const pointer - int internal1() { return value; } // return by value + int internal1() const { return value; } // return by value int &internal2() { return value; } // return by reference - const int &internal3() { return value; } // return by const reference + const int &internal3() const { return value; } // return by const reference int *internal4() { return &value; } // return by pointer const int *internal5() { return &value; } // return by const pointer @@ -114,13 +117,21 @@ int none1(const NoneTester &obj) { return obj.answer; } int none2(NoneTester *obj) { return obj ? obj->answer : -1; } int none3(std::shared_ptr &obj) { return obj ? obj->answer : -1; } int none4(std::shared_ptr *obj) { return obj && *obj ? (*obj)->answer : -1; } -int none5(std::shared_ptr obj) { return obj ? obj->answer : -1; } +int none5(const std::shared_ptr &obj) { return obj ? obj->answer : -1; } + +// Issue #2778: implicit casting from None to object (not pointer) +class NoneCastTester { +public: + int answer = -1; + NoneCastTester() = default; + explicit NoneCastTester(int v) : answer(v) {} +}; struct StrIssue { int val = -1; StrIssue() = default; - StrIssue(int i) : val{i} {} + explicit StrIssue(int i) : val{i} {} }; // Issues #854, #910: incompatible function args when member function/pointer is in unregistered base class @@ -148,6 +159,14 @@ struct RefQualified { int constRefQualified(int other) const & { return value + other; } }; +// Test rvalue ref param +struct RValueRefParam { + std::size_t func1(std::string&& s) { return s.size(); } + std::size_t func2(std::string&& s) const { return s.size(); } + std::size_t func3(std::string&& s) & { return s.size(); } + std::size_t func4(std::string&& s) const & { return s.size(); } +}; + TEST_SUBMODULE(methods_and_attributes, m) { // test_methods_and_attributes py::class_ emna(m, "ExampleMandA"); @@ -207,12 +226,12 @@ TEST_SUBMODULE(methods_and_attributes, m) { // test_no_mixed_overloads // Raise error if trying to mix static/non-static overloads on the same name: .def_static("add_mixed_overloads1", []() { - auto emna = py::reinterpret_borrow>(py::module::import("pybind11_tests.methods_and_attributes").attr("ExampleMandA")); + auto emna = py::reinterpret_borrow>(py::module_::import("pybind11_tests.methods_and_attributes").attr("ExampleMandA")); emna.def ("overload_mixed1", static_cast(&ExampleMandA::overloaded)) .def_static("overload_mixed1", static_cast(&ExampleMandA::overloaded)); }) .def_static("add_mixed_overloads2", []() { - auto emna = py::reinterpret_borrow>(py::module::import("pybind11_tests.methods_and_attributes").attr("ExampleMandA")); + auto emna = py::reinterpret_borrow>(py::module_::import("pybind11_tests.methods_and_attributes").attr("ExampleMandA")); emna.def_static("overload_mixed2", static_cast(&ExampleMandA::overloaded)) .def ("overload_mixed2", static_cast(&ExampleMandA::overloaded)); }) @@ -228,36 +247,41 @@ TEST_SUBMODULE(methods_and_attributes, m) { .def(py::init<>()) .def_readonly("def_readonly", &TestProperties::value) .def_readwrite("def_readwrite", &TestProperties::value) - .def_property("def_writeonly", nullptr, - [](TestProperties& s,int v) { s.value = v; } ) + .def_property("def_writeonly", nullptr, [](TestProperties &s, int v) { s.value = v; }) .def_property("def_property_writeonly", nullptr, &TestProperties::set) .def_property_readonly("def_property_readonly", &TestProperties::get) .def_property("def_property", &TestProperties::get, &TestProperties::set) .def_property("def_property_impossible", nullptr, nullptr) .def_readonly_static("def_readonly_static", &TestProperties::static_value) .def_readwrite_static("def_readwrite_static", &TestProperties::static_value) - .def_property_static("def_writeonly_static", nullptr, - [](py::object, int v) { TestProperties::static_value = v; }) - .def_property_readonly_static("def_property_readonly_static", - [](py::object) { return TestProperties::static_get(); }) - .def_property_static("def_property_writeonly_static", nullptr, - [](py::object, int v) { return TestProperties::static_set(v); }) - .def_property_static("def_property_static", - [](py::object) { return TestProperties::static_get(); }, - [](py::object, int v) { TestProperties::static_set(v); }) - .def_property_static("static_cls", - [](py::object cls) { return cls; }, - [](py::object cls, py::function f) { f(cls); }); + .def_property_static("def_writeonly_static", + nullptr, + [](const py::object &, int v) { TestProperties::static_value = v; }) + .def_property_readonly_static( + "def_property_readonly_static", + [](const py::object &) { return TestProperties::static_get(); }) + .def_property_static( + "def_property_writeonly_static", + nullptr, + [](const py::object &, int v) { return TestProperties::static_set(v); }) + .def_property_static( + "def_property_static", + [](const py::object &) { return TestProperties::static_get(); }, + [](const py::object &, int v) { TestProperties::static_set(v); }) + .def_property_static( + "static_cls", + [](py::object cls) { return cls; }, + [](const py::object &cls, const py::function &f) { f(cls); }); py::class_(m, "TestPropertiesOverride") .def(py::init<>()) .def_readonly("def_readonly", &TestPropertiesOverride::value) .def_readonly_static("def_readonly_static", &TestPropertiesOverride::static_value); - auto static_get1 = [](py::object) -> const UserType & { return TestPropRVP::sv1; }; - auto static_get2 = [](py::object) -> const UserType & { return TestPropRVP::sv2; }; - auto static_set1 = [](py::object, int v) { TestPropRVP::sv1.set(v); }; - auto static_set2 = [](py::object, int v) { TestPropRVP::sv2.set(v); }; + auto static_get1 = [](const py::object &) -> const UserType & { return TestPropRVP::sv1; }; + auto static_get2 = [](const py::object &) -> const UserType & { return TestPropRVP::sv2; }; + auto static_set1 = [](const py::object &, int v) { TestPropRVP::sv1.set(v); }; + auto static_set2 = [](const py::object &, int v) { TestPropRVP::sv2.set(v); }; auto rvp_copy = py::return_value_policy::copy; // test_property_return_value_policies @@ -268,21 +292,30 @@ TEST_SUBMODULE(methods_and_attributes, m) { .def_property_readonly("ro_func", py::cpp_function(&TestPropRVP::get2, rvp_copy)) .def_property("rw_ref", &TestPropRVP::get1, &TestPropRVP::set1) .def_property("rw_copy", &TestPropRVP::get2, &TestPropRVP::set2, rvp_copy) - .def_property("rw_func", py::cpp_function(&TestPropRVP::get2, rvp_copy), &TestPropRVP::set2) + .def_property( + "rw_func", py::cpp_function(&TestPropRVP::get2, rvp_copy), &TestPropRVP::set2) .def_property_readonly_static("static_ro_ref", static_get1) .def_property_readonly_static("static_ro_copy", static_get2, rvp_copy) .def_property_readonly_static("static_ro_func", py::cpp_function(static_get2, rvp_copy)) .def_property_static("static_rw_ref", static_get1, static_set1) .def_property_static("static_rw_copy", static_get2, static_set2, rvp_copy) - .def_property_static("static_rw_func", py::cpp_function(static_get2, rvp_copy), static_set2) + .def_property_static( + "static_rw_func", py::cpp_function(static_get2, rvp_copy), static_set2) // test_property_rvalue_policy .def_property_readonly("rvalue", &TestPropRVP::get_rvalue) - .def_property_readonly_static("static_rvalue", [](py::object) { return UserType(1); }); + .def_property_readonly_static("static_rvalue", + [](const py::object &) { return UserType(1); }); // test_metaclass_override struct MetaclassOverride { }; py::class_(m, "MetaclassOverride", py::metaclass((PyObject *) &PyType_Type)) - .def_property_readonly_static("readonly", [](py::object) { return 1; }); + .def_property_readonly_static("readonly", [](const py::object &) { return 1; }); + + // test_overload_ordering + m.def("overload_order", [](const std::string &) { return 1; }); + m.def("overload_order", [](const std::string &) { return 2; }); + m.def("overload_order", [](int) { return 3; }); + m.def("overload_order", [](int) { return 4; }, py::prepend{}); #if !defined(PYPY_VERSION) // test_dynamic_attributes @@ -308,28 +341,43 @@ TEST_SUBMODULE(methods_and_attributes, m) { m.attr("debug_enabled") = false; #endif m.def("bad_arg_def_named", []{ - auto m = py::module::import("pybind11_tests"); + auto m = py::module_::import("pybind11_tests"); m.def("should_fail", [](int, UnregisteredType) {}, py::arg(), py::arg("a") = UnregisteredType()); }); m.def("bad_arg_def_unnamed", []{ - auto m = py::module::import("pybind11_tests"); + auto m = py::module_::import("pybind11_tests"); m.def("should_fail", [](int, UnregisteredType) {}, py::arg(), py::arg() = UnregisteredType()); }); + // [workaround(intel)] ICC 20/21 breaks with py::arg().stuff, using py::arg{}.stuff works. + // test_accepts_none py::class_>(m, "NoneTester") .def(py::init<>()); - m.def("no_none1", &none1, py::arg().none(false)); - m.def("no_none2", &none2, py::arg().none(false)); - m.def("no_none3", &none3, py::arg().none(false)); - m.def("no_none4", &none4, py::arg().none(false)); - m.def("no_none5", &none5, py::arg().none(false)); + m.def("no_none1", &none1, py::arg{}.none(false)); + m.def("no_none2", &none2, py::arg{}.none(false)); + m.def("no_none3", &none3, py::arg{}.none(false)); + m.def("no_none4", &none4, py::arg{}.none(false)); + m.def("no_none5", &none5, py::arg{}.none(false)); m.def("ok_none1", &none1); - m.def("ok_none2", &none2, py::arg().none(true)); + m.def("ok_none2", &none2, py::arg{}.none(true)); m.def("ok_none3", &none3); - m.def("ok_none4", &none4, py::arg().none(true)); + m.def("ok_none4", &none4, py::arg{}.none(true)); m.def("ok_none5", &none5); + m.def("no_none_kwarg", &none2, "a"_a.none(false)); + m.def("no_none_kwarg_kw_only", &none2, py::kw_only(), "a"_a.none(false)); + + // test_casts_none + // Issue #2778: implicit casting from None to object (not pointer) + py::class_(m, "NoneCastTester") + .def(py::init<>()) + .def(py::init()) + .def(py::init([](py::none const&) { return NoneCastTester{}; })); + py::implicitly_convertible(); + m.def("ok_obj_or_none", [](NoneCastTester const& foo) { return foo.answer; }); + + // test_str_issue // Issue #283: __str__ called on uninitialized instance when constructor arguments invalid py::class_(m, "StrIssue") @@ -351,14 +399,14 @@ TEST_SUBMODULE(methods_and_attributes, m) { .def("increase_value", &RegisteredDerived::increase_value) .def_readwrite("rw_value", &RegisteredDerived::rw_value) .def_readonly("ro_value", &RegisteredDerived::ro_value) - // These should trigger a static_assert if uncommented - //.def_readwrite("fails", &UserType::value) // should trigger a static_assert if uncommented - //.def_readonly("fails", &UserType::value) // should trigger a static_assert if uncommented + // Uncommenting the next line should trigger a static_assert: + // .def_readwrite("fails", &UserType::value) + // Uncommenting the next line should trigger a static_assert: + // .def_readonly("fails", &UserType::value) .def_property("rw_value_prop", &RegisteredDerived::get_int, &RegisteredDerived::set_int) .def_property_readonly("ro_value_prop", &RegisteredDerived::get_double) // This one is in the registered class: - .def("sum", &RegisteredDerived::sum) - ; + .def("sum", &RegisteredDerived::sum); using Adapted = decltype(py::method_adaptor(&RegisteredDerived::do_nothing)); static_assert(std::is_same::value, ""); @@ -369,4 +417,11 @@ TEST_SUBMODULE(methods_and_attributes, m) { .def_readonly("value", &RefQualified::value) .def("refQualified", &RefQualified::refQualified) .def("constRefQualified", &RefQualified::constRefQualified); + + py::class_(m, "RValueRefParam") + .def(py::init<>()) + .def("func1", &RValueRefParam::func1) + .def("func2", &RValueRefParam::func2) + .def("func3", &RValueRefParam::func3) + .def("func4", &RValueRefParam::func4); } diff --git a/wrap/pybind11/tests/test_methods_and_attributes.py b/wrap/pybind11/tests/test_methods_and_attributes.py index c296b6868..fa026f9ed 100644 --- a/wrap/pybind11/tests/test_methods_and_attributes.py +++ b/wrap/pybind11/tests/test_methods_and_attributes.py @@ -2,9 +2,8 @@ import pytest import env # noqa: F401 - -from pybind11_tests import methods_and_attributes as m from pybind11_tests import ConstructorStats +from pybind11_tests import methods_and_attributes as m def test_methods_and_attributes(): @@ -40,17 +39,17 @@ def test_methods_and_attributes(): assert instance1.overloaded(0) == "(int)" assert instance1.overloaded(1, 1.0) == "(int, float)" assert instance1.overloaded(2.0, 2) == "(float, int)" - assert instance1.overloaded(3, 3) == "(int, int)" - assert instance1.overloaded(4., 4.) == "(float, float)" + assert instance1.overloaded(3, 3) == "(int, int)" + assert instance1.overloaded(4.0, 4.0) == "(float, float)" assert instance1.overloaded_const(-3) == "(int) const" assert instance1.overloaded_const(5, 5.0) == "(int, float) const" assert instance1.overloaded_const(6.0, 6) == "(float, int) const" - assert instance1.overloaded_const(7, 7) == "(int, int) const" - assert instance1.overloaded_const(8., 8.) == "(float, float) const" + assert instance1.overloaded_const(7, 7) == "(int, int) const" + assert instance1.overloaded_const(8.0, 8.0) == "(float, float) const" assert instance1.overloaded_float(1, 1) == "(float, float)" - assert instance1.overloaded_float(1, 1.) == "(float, float)" - assert instance1.overloaded_float(1., 1) == "(float, float)" - assert instance1.overloaded_float(1., 1.) == "(float, float)" + assert instance1.overloaded_float(1, 1.0) == "(float, float)" + assert instance1.overloaded_float(1.0, 1) == "(float, float)" + assert instance1.overloaded_float(1.0, 1.0) == "(float, float)" assert instance1.value == 320 instance1.value = 100 @@ -103,7 +102,7 @@ def test_properties(): assert instance.def_property == 3 with pytest.raises(AttributeError) as excinfo: - dummy = instance.def_property_writeonly # noqa: F841 unused var + dummy = instance.def_property_writeonly # unused var assert "unreadable attribute" in str(excinfo.value) instance.def_property_writeonly = 4 @@ -128,7 +127,7 @@ def test_static_properties(): assert m.TestProperties.def_readwrite_static == 2 with pytest.raises(AttributeError) as excinfo: - dummy = m.TestProperties.def_writeonly_static # noqa: F841 unused var + dummy = m.TestProperties.def_writeonly_static # unused var assert "unreadable attribute" in str(excinfo.value) m.TestProperties.def_writeonly_static = 3 @@ -171,6 +170,19 @@ def test_static_properties(): assert m.TestPropertiesOverride().def_readonly == 99 assert m.TestPropertiesOverride.def_readonly_static == 99 + # Only static attributes can be deleted + del m.TestPropertiesOverride.def_readonly_static + assert ( + hasattr(m.TestPropertiesOverride, "def_readonly_static") + and m.TestPropertiesOverride.def_readonly_static + is m.TestProperties.def_readonly_static + ) + assert "def_readonly_static" not in m.TestPropertiesOverride.__dict__ + properties_override = m.TestPropertiesOverride() + with pytest.raises(AttributeError) as excinfo: + del properties_override.def_readonly + assert "can't delete attribute" in str(excinfo.value) + def test_static_cls(): """Static property getter and setters expect the type object as the their only argument""" @@ -193,7 +205,10 @@ def test_metaclass_override(): assert type(m.MetaclassOverride).__name__ == "type" assert m.MetaclassOverride.readonly == 1 - assert type(m.MetaclassOverride.__dict__["readonly"]).__name__ == "pybind11_static_property" + assert ( + type(m.MetaclassOverride.__dict__["readonly"]).__name__ + == "pybind11_static_property" + ) # Regular `type` replaces the property instead of calling `__set__()` m.MetaclassOverride.readonly = 2 @@ -206,22 +221,26 @@ def test_no_mixed_overloads(): with pytest.raises(RuntimeError) as excinfo: m.ExampleMandA.add_mixed_overloads1() - assert (str(excinfo.value) == - "overloading a method with both static and instance methods is not supported; " + - ("compile in debug mode for more details" if not debug_enabled else - "error while attempting to bind static method ExampleMandA.overload_mixed1" - "(arg0: float) -> str") - ) + assert str( + excinfo.value + ) == "overloading a method with both static and instance methods is not supported; " + ( + "compile in debug mode for more details" + if not debug_enabled + else "error while attempting to bind static method ExampleMandA.overload_mixed1" + "(arg0: float) -> str" + ) with pytest.raises(RuntimeError) as excinfo: m.ExampleMandA.add_mixed_overloads2() - assert (str(excinfo.value) == - "overloading a method with both static and instance methods is not supported; " + - ("compile in debug mode for more details" if not debug_enabled else - "error while attempting to bind instance method ExampleMandA.overload_mixed2" - "(self: pybind11_tests.methods_and_attributes.ExampleMandA, arg0: int, arg1: int)" - " -> str") - ) + assert str( + excinfo.value + ) == "overloading a method with both static and instance methods is not supported; " + ( + "compile in debug mode for more details" + if not debug_enabled + else "error while attempting to bind instance method ExampleMandA.overload_mixed2" + "(self: pybind11_tests.methods_and_attributes.ExampleMandA, arg0: int, arg1: int)" + " -> str" + ) @pytest.mark.parametrize("access", ["ro", "rw", "static_ro", "static_rw"]) @@ -333,8 +352,8 @@ def test_bad_arg_default(msg): assert msg(excinfo.value) == ( "arg(): could not convert default argument 'a: UnregisteredType' in function " "'should_fail' into a Python object (type not registered yet?)" - if debug_enabled else - "arg(): could not convert default argument into a Python object (type not registered " + if debug_enabled + else "arg(): could not convert default argument into a Python object (type not registered " "yet?). Compile in debug mode for more information." ) @@ -343,8 +362,8 @@ def test_bad_arg_default(msg): assert msg(excinfo.value) == ( "arg(): could not convert default argument 'UnregisteredType' in function " "'should_fail' into a Python object (type not registered yet?)" - if debug_enabled else - "arg(): could not convert default argument into a Python object (type not registered " + if debug_enabled + else "arg(): could not convert default argument into a Python object (type not registered " "yet?). Compile in debug mode for more information." ) @@ -381,12 +400,15 @@ def test_accepts_none(msg): # The first one still raises because you can't pass None as a lvalue reference arg: with pytest.raises(TypeError) as excinfo: assert m.ok_none1(None) == -1 - assert msg(excinfo.value) == """ + assert ( + msg(excinfo.value) + == """ ok_none1(): incompatible function arguments. The following argument types are supported: 1. (arg0: m.methods_and_attributes.NoneTester) -> int Invoked with: None """ + ) # The rest take the argument as pointer or holder, and accept None: assert m.ok_none2(None) == -1 @@ -394,6 +416,30 @@ def test_accepts_none(msg): assert m.ok_none4(None) == -1 assert m.ok_none5(None) == -1 + with pytest.raises(TypeError) as excinfo: + m.no_none_kwarg(None) + assert "incompatible function arguments" in str(excinfo.value) + with pytest.raises(TypeError) as excinfo: + m.no_none_kwarg(a=None) + assert "incompatible function arguments" in str(excinfo.value) + with pytest.raises(TypeError) as excinfo: + m.no_none_kwarg_kw_only(None) + assert "incompatible function arguments" in str(excinfo.value) + with pytest.raises(TypeError) as excinfo: + m.no_none_kwarg_kw_only(a=None) + assert "incompatible function arguments" in str(excinfo.value) + + +def test_casts_none(): + """#2778: implicit casting from None to object (not pointer)""" + a = m.NoneCastTester() + assert m.ok_obj_or_none(a) == -1 + a = m.NoneCastTester(4) + assert m.ok_obj_or_none(a) == 4 + a = m.NoneCastTester(None) + assert m.ok_obj_or_none(a) == -1 + assert m.ok_obj_or_none(None) == -1 + def test_str_issue(msg): """#283: __str__ called on uninitialized instance when constructor arguments invalid""" @@ -402,13 +448,16 @@ def test_str_issue(msg): with pytest.raises(TypeError) as excinfo: str(m.StrIssue("no", "such", "constructor")) - assert msg(excinfo.value) == """ + assert ( + msg(excinfo.value) + == """ __init__(): incompatible constructor arguments. The following argument types are supported: 1. m.methods_and_attributes.StrIssue(arg0: int) 2. m.methods_and_attributes.StrIssue() Invoked with: 'no', 'such', 'constructor' """ + ) def test_unregistered_base_implementations(): @@ -438,3 +487,39 @@ def test_ref_qualified(): r.refQualified(17) assert r.value == 17 assert r.constRefQualified(23) == 40 + + +def test_overload_ordering(): + "Check to see if the normal overload order (first defined) and prepend overload order works" + assert m.overload_order("string") == 1 + assert m.overload_order(0) == 4 + + # Different for Python 2 vs. 3 + uni_name = type(u"").__name__ + + assert "1. overload_order(arg0: int) -> int" in m.overload_order.__doc__ + assert ( + "2. overload_order(arg0: {}) -> int".format(uni_name) + in m.overload_order.__doc__ + ) + assert ( + "3. overload_order(arg0: {}) -> int".format(uni_name) + in m.overload_order.__doc__ + ) + assert "4. overload_order(arg0: int) -> int" in m.overload_order.__doc__ + + with pytest.raises(TypeError) as err: + m.overload_order(1.1) + + assert "1. (arg0: int) -> int" in str(err.value) + assert "2. (arg0: {}) -> int".format(uni_name) in str(err.value) + assert "3. (arg0: {}) -> int".format(uni_name) in str(err.value) + assert "4. (arg0: int) -> int" in str(err.value) + + +def test_rvalue_ref_param(): + r = m.RValueRefParam() + assert r.func1("123") == 3 + assert r.func2("1234") == 4 + assert r.func3("12345") == 5 + assert r.func4("123456") == 6 diff --git a/wrap/pybind11/tests/test_modules.cpp b/wrap/pybind11/tests/test_modules.cpp index c1475fa62..ce61c1a25 100644 --- a/wrap/pybind11/tests/test_modules.cpp +++ b/wrap/pybind11/tests/test_modules.cpp @@ -13,17 +13,19 @@ TEST_SUBMODULE(modules, m) { // test_nested_modules + // This is intentionally "py::module" to verify it still can be used in place of "py::module_" py::module m_sub = m.def_submodule("subsubmodule"); m_sub.def("submodule_func", []() { return "submodule_func()"; }); // test_reference_internal class A { public: - A(int v) : v(v) { print_created(this, v); } + explicit A(int v) : v(v) { print_created(this, v); } ~A() { print_destroyed(this); } A(const A&) { print_copy_created(this); } A& operator=(const A ©) { print_copy_assigned(this); v = copy.v; return *this; } - std::string toString() { return "A[" + std::to_string(v) + "]"; } + std::string toString() const { return "A[" + std::to_string(v) + "]"; } + private: int v; }; @@ -50,6 +52,7 @@ TEST_SUBMODULE(modules, m) { .def_readwrite("a1", &B::a1) // def_readonly uses an internal reference return policy by default .def_readwrite("a2", &B::a2); + // This is intentionally "py::module" to verify it still can be used in place of "py::module_" m.attr("OD") = py::module::import("collections").attr("OrderedDict"); // test_duplicate_registration @@ -60,7 +63,8 @@ TEST_SUBMODULE(modules, m) { class Dupe3 { }; class DupeException { }; - auto dm = py::module("dummy"); + // Go ahead and leak, until we have a non-leaking py::module_ constructor + auto dm = py::module_::create_extension_module("dummy", nullptr, new py::module_::module_def); auto failures = py::list(); py::class_(dm, "Dupe1"); diff --git a/wrap/pybind11/tests/test_modules.py b/wrap/pybind11/tests/test_modules.py index 7e2100524..49e1ea5e3 100644 --- a/wrap/pybind11/tests/test_modules.py +++ b/wrap/pybind11/tests/test_modules.py @@ -1,14 +1,18 @@ # -*- coding: utf-8 -*- +from pybind11_tests import ConstructorStats from pybind11_tests import modules as m from pybind11_tests.modules import subsubmodule as ms -from pybind11_tests import ConstructorStats def test_nested_modules(): import pybind11_tests + assert pybind11_tests.__name__ == "pybind11_tests" assert pybind11_tests.modules.__name__ == "pybind11_tests.modules" - assert pybind11_tests.modules.subsubmodule.__name__ == "pybind11_tests.modules.subsubmodule" + assert ( + pybind11_tests.modules.subsubmodule.__name__ + == "pybind11_tests.modules.subsubmodule" + ) assert m.__name__ == "pybind11_tests.modules" assert ms.__name__ == "pybind11_tests.modules.subsubmodule" @@ -35,7 +39,7 @@ def test_reference_internal(): del b assert astats.alive() == 0 assert bstats.alive() == 0 - assert astats.values() == ['1', '2', '42', '43'] + assert astats.values() == ["1", "2", "42", "43"] assert bstats.values() == [] assert astats.default_constructions == 0 assert bstats.default_constructions == 1 @@ -50,18 +54,20 @@ def test_reference_internal(): def test_importing(): - from pybind11_tests.modules import OD from collections import OrderedDict + from pybind11_tests.modules import OD + assert OD is OrderedDict - assert str(OD([(1, 'a'), (2, 'b')])) == "OrderedDict([(1, 'a'), (2, 'b')])" + assert str(OD([(1, "a"), (2, "b")])) == "OrderedDict([(1, 'a'), (2, 'b')])" def test_pydoc(): """Pydoc needs to be able to provide help() for everything inside a pybind11 module""" - import pybind11_tests import pydoc + import pybind11_tests + assert pybind11_tests.__name__ == "pybind11_tests" assert pybind11_tests.__doc__ == "pybind11 test module" assert pydoc.text.docmodule(pybind11_tests) @@ -71,3 +77,16 @@ def test_duplicate_registration(): """Registering two things with the same name""" assert m.duplicate_registration() == [] + + +def test_builtin_key_type(): + """Test that all the keys in the builtin modules have type str. + + Previous versions of pybind11 would add a unicode key in python 2. + """ + if hasattr(__builtins__, "keys"): + keys = __builtins__.keys() + else: # this is to make pypy happy since builtins is different there. + keys = __builtins__.__dict__.keys() + + assert {type(k) for k in keys} == {str} diff --git a/wrap/pybind11/tests/test_multiple_inheritance.cpp b/wrap/pybind11/tests/test_multiple_inheritance.cpp index 70e341785..4689df4e4 100644 --- a/wrap/pybind11/tests/test_multiple_inheritance.cpp +++ b/wrap/pybind11/tests/test_multiple_inheritance.cpp @@ -11,10 +11,12 @@ #include "pybind11_tests.h" #include "constructor_stats.h" +namespace { + // Many bases for testing that multiple inheritance from many classes (i.e. requiring extra // space for holder constructed flags) works. template struct BaseN { - BaseN(int i) : i(i) { } + explicit BaseN(int i) : i(i) {} int i; }; @@ -43,13 +45,40 @@ int WithStatic2::static_value2 = 2; int VanillaStaticMix1::static_value = 12; int VanillaStaticMix2::static_value = 12; +// test_multiple_inheritance_virtbase +struct Base1a { + explicit Base1a(int i) : i(i) {} + int foo() const { return i; } + int i; +}; +struct Base2a { + explicit Base2a(int i) : i(i) {} + int bar() const { return i; } + int i; +}; +struct Base12a : Base1a, Base2a { + Base12a(int i, int j) : Base1a(i), Base2a(j) { } +}; + +// test_mi_unaligned_base +// test_mi_base_return +struct I801B1 { int a = 1; I801B1() = default; I801B1(const I801B1 &) = default; virtual ~I801B1() = default; }; +struct I801B2 { int b = 2; I801B2() = default; I801B2(const I801B2 &) = default; virtual ~I801B2() = default; }; +struct I801C : I801B1, I801B2 {}; +struct I801D : I801C {}; // Indirect MI + +} // namespace + TEST_SUBMODULE(multiple_inheritance, m) { + // Please do not interleave `struct` and `class` definitions with bindings code, + // but implement `struct`s and `class`es in the anonymous namespace above. + // This helps keeping the smart_holder branch in sync with master. // test_multiple_inheritance_mix1 // test_multiple_inheritance_mix2 struct Base1 { - Base1(int i) : i(i) { } - int foo() { return i; } + explicit Base1(int i) : i(i) {} + int foo() const { return i; } int i; }; py::class_ b1(m, "Base1"); @@ -57,8 +86,8 @@ TEST_SUBMODULE(multiple_inheritance, m) { .def("foo", &Base1::foo); struct Base2 { - Base2(int i) : i(i) { } - int bar() { return i; } + explicit Base2(int i) : i(i) {} + int bar() const { return i; } int i; }; py::class_ b2(m, "Base2"); @@ -79,7 +108,10 @@ TEST_SUBMODULE(multiple_inheritance, m) { // test_multiple_inheritance_python_many_bases - #define PYBIND11_BASEN(N) py::class_>(m, "BaseN" #N).def(py::init()).def("f" #N, [](BaseN &b) { return b.i + N; }) +#define PYBIND11_BASEN(N) \ + py::class_>(m, "BaseN" #N).def(py::init()).def("f" #N, [](BaseN &b) { \ + return b.i + (N); \ + }) PYBIND11_BASEN( 1); PYBIND11_BASEN( 2); PYBIND11_BASEN( 3); PYBIND11_BASEN( 4); PYBIND11_BASEN( 5); PYBIND11_BASEN( 6); PYBIND11_BASEN( 7); PYBIND11_BASEN( 8); PYBIND11_BASEN( 9); PYBIND11_BASEN(10); PYBIND11_BASEN(11); PYBIND11_BASEN(12); @@ -99,41 +131,24 @@ TEST_SUBMODULE(multiple_inheritance, m) { // test_multiple_inheritance_virtbase // Test the case where not all base classes are specified, and where pybind11 requires the // py::multiple_inheritance flag to perform proper casting between types. - struct Base1a { - Base1a(int i) : i(i) { } - int foo() { return i; } - int i; - }; py::class_>(m, "Base1a") .def(py::init()) .def("foo", &Base1a::foo); - struct Base2a { - Base2a(int i) : i(i) { } - int bar() { return i; } - int i; - }; py::class_>(m, "Base2a") .def(py::init()) .def("bar", &Base2a::bar); - struct Base12a : Base1a, Base2a { - Base12a(int i, int j) : Base1a(i), Base2a(j) { } - }; py::class_>(m, "Base12a", py::multiple_inheritance()) .def(py::init()); m.def("bar_base2a", [](Base2a *b) { return b->bar(); }); - m.def("bar_base2a_sharedptr", [](std::shared_ptr b) { return b->bar(); }); + m.def("bar_base2a_sharedptr", [](const std::shared_ptr &b) { return b->bar(); }); // test_mi_unaligned_base // test_mi_base_return // Issue #801: invalid casting to derived type with MI bases - struct I801B1 { int a = 1; I801B1() = default; I801B1(const I801B1 &) = default; virtual ~I801B1() = default; }; - struct I801B2 { int b = 2; I801B2() = default; I801B2(const I801B2 &) = default; virtual ~I801B2() = default; }; - struct I801C : I801B1, I801B2 {}; - struct I801D : I801C {}; // Indirect MI // Unregistered classes: struct I801B3 { int c = 3; virtual ~I801B3() = default; }; struct I801E : I801B3, I801D {}; @@ -193,14 +208,12 @@ TEST_SUBMODULE(multiple_inheritance, m) { .def_readwrite_static("static_value", &VanillaStaticMix2::static_value); -#if !(defined(PYPY_VERSION) && (PYPY_VERSION_NUM < 0x06000000)) struct WithDict { }; struct VanillaDictMix1 : Vanilla, WithDict { }; struct VanillaDictMix2 : WithDict, Vanilla { }; py::class_(m, "WithDict", py::dynamic_attr()).def(py::init<>()); py::class_(m, "VanillaDictMix1").def(py::init<>()); py::class_(m, "VanillaDictMix2").def(py::init<>()); -#endif // test_diamond_inheritance // Issue #959: segfault when constructing diamond inheritance instance @@ -217,4 +230,87 @@ TEST_SUBMODULE(multiple_inheritance, m) { .def("c1", [](C1 *self) { return self; }); py::class_(m, "D") .def(py::init<>()); + + // test_pr3635_diamond_* + // - functions are get_{base}_{var}, return {var} + struct MVB { + MVB() = default; + MVB(const MVB &) = default; + virtual ~MVB() = default; + + int b = 1; + int get_b_b() const { return b; } + }; + struct MVC : virtual MVB { + int c = 2; + int get_c_b() const { return b; } + int get_c_c() const { return c; } + }; + struct MVD0 : virtual MVC { + int d0 = 3; + int get_d0_b() const { return b; } + int get_d0_c() const { return c; } + int get_d0_d0() const { return d0; } + }; + struct MVD1 : virtual MVC { + int d1 = 4; + int get_d1_b() const { return b; } + int get_d1_c() const { return c; } + int get_d1_d1() const { return d1; } + }; + struct MVE : virtual MVD0, virtual MVD1 { + int e = 5; + int get_e_b() const { return b; } + int get_e_c() const { return c; } + int get_e_d0() const { return d0; } + int get_e_d1() const { return d1; } + int get_e_e() const { return e; } + }; + struct MVF : virtual MVE { + int f = 6; + int get_f_b() const { return b; } + int get_f_c() const { return c; } + int get_f_d0() const { return d0; } + int get_f_d1() const { return d1; } + int get_f_e() const { return e; } + int get_f_f() const { return f; } + }; + py::class_(m, "MVB") + .def(py::init<>()) + .def("get_b_b", &MVB::get_b_b) + .def_readwrite("b", &MVB::b); + py::class_(m, "MVC") + .def(py::init<>()) + .def("get_c_b", &MVC::get_c_b) + .def("get_c_c", &MVC::get_c_c) + .def_readwrite("c", &MVC::c); + py::class_(m, "MVD0") + .def(py::init<>()) + .def("get_d0_b", &MVD0::get_d0_b) + .def("get_d0_c", &MVD0::get_d0_c) + .def("get_d0_d0", &MVD0::get_d0_d0) + .def_readwrite("d0", &MVD0::d0); + py::class_(m, "MVD1") + .def(py::init<>()) + .def("get_d1_b", &MVD1::get_d1_b) + .def("get_d1_c", &MVD1::get_d1_c) + .def("get_d1_d1", &MVD1::get_d1_d1) + .def_readwrite("d1", &MVD1::d1); + py::class_(m, "MVE") + .def(py::init<>()) + .def("get_e_b", &MVE::get_e_b) + .def("get_e_c", &MVE::get_e_c) + .def("get_e_d0", &MVE::get_e_d0) + .def("get_e_d1", &MVE::get_e_d1) + .def("get_e_e", &MVE::get_e_e) + .def_readwrite("e", &MVE::e); + py::class_(m, "MVF") + .def(py::init<>()) + .def("get_f_b", &MVF::get_f_b) + .def("get_f_c", &MVF::get_f_c) + .def("get_f_d0", &MVF::get_f_d0) + .def("get_f_d1", &MVF::get_f_d1) + .def("get_f_e", &MVF::get_f_e) + .def("get_f_f", &MVF::get_f_f) + .def_readwrite("f", &MVF::f); } diff --git a/wrap/pybind11/tests/test_multiple_inheritance.py b/wrap/pybind11/tests/test_multiple_inheritance.py index 7a0259d21..abdf25d60 100644 --- a/wrap/pybind11/tests/test_multiple_inheritance.py +++ b/wrap/pybind11/tests/test_multiple_inheritance.py @@ -2,7 +2,6 @@ import pytest import env # noqa: F401 - from pybind11_tests import ConstructorStats from pybind11_tests import multiple_inheritance as m @@ -57,7 +56,6 @@ def test_multiple_inheritance_mix2(): @pytest.mark.skipif("env.PYPY and env.PY2") @pytest.mark.xfail("env.PYPY and not env.PY2") def test_multiple_inheritance_python(): - class MI1(m.Base1, m.Base2): def __init__(self, i, j): m.Base1.__init__(self, i) @@ -163,7 +161,6 @@ def test_multiple_inheritance_python(): def test_multiple_inheritance_python_many_bases(): - class MIMany14(m.BaseN1, m.BaseN2, m.BaseN3, m.BaseN4): def __init__(self): m.BaseN1.__init__(self, 1) @@ -178,8 +175,16 @@ def test_multiple_inheritance_python_many_bases(): m.BaseN7.__init__(self, 7) m.BaseN8.__init__(self, 8) - class MIMany916(m.BaseN9, m.BaseN10, m.BaseN11, m.BaseN12, m.BaseN13, m.BaseN14, m.BaseN15, - m.BaseN16): + class MIMany916( + m.BaseN9, + m.BaseN10, + m.BaseN11, + m.BaseN12, + m.BaseN13, + m.BaseN14, + m.BaseN15, + m.BaseN16, + ): def __init__(self): m.BaseN9.__init__(self, 9) m.BaseN10.__init__(self, 10) @@ -225,7 +230,6 @@ def test_multiple_inheritance_python_many_bases(): def test_multiple_inheritance_virtbase(): - class MITypePy(m.Base12a): def __init__(self, i, j): m.Base12a.__init__(self, i, j) @@ -238,7 +242,7 @@ def test_multiple_inheritance_virtbase(): def test_mi_static_properties(): """Mixing bases with and without static properties should be possible - and the result should be independent of base definition order""" + and the result should be independent of base definition order""" for d in (m.VanillaStaticMix1(), m.VanillaStaticMix2()): assert d.vanilla() == "Vanilla" @@ -354,3 +358,139 @@ def test_diamond_inheritance(): assert d is d.c0().b() assert d is d.c1().b() assert d is d.c0().c1().b().c0().b() + + +def test_pr3635_diamond_b(): + o = m.MVB() + assert o.b == 1 + + assert o.get_b_b() == 1 + + +def test_pr3635_diamond_c(): + o = m.MVC() + assert o.b == 1 + assert o.c == 2 + + assert o.get_b_b() == 1 + assert o.get_c_b() == 1 + + assert o.get_c_c() == 2 + + +def test_pr3635_diamond_d0(): + o = m.MVD0() + assert o.b == 1 + assert o.c == 2 + assert o.d0 == 3 + + assert o.get_b_b() == 1 + assert o.get_c_b() == 1 + assert o.get_d0_b() == 1 + + assert o.get_c_c() == 2 + assert o.get_d0_c() == 2 + + assert o.get_d0_d0() == 3 + + +def test_pr3635_diamond_d1(): + o = m.MVD1() + assert o.b == 1 + assert o.c == 2 + assert o.d1 == 4 + + assert o.get_b_b() == 1 + assert o.get_c_b() == 1 + assert o.get_d1_b() == 1 + + assert o.get_c_c() == 2 + assert o.get_d1_c() == 2 + + assert o.get_d1_d1() == 4 + + +def test_pr3635_diamond_e(): + o = m.MVE() + assert o.b == 1 + assert o.c == 2 + assert o.d0 == 3 + assert o.d1 == 4 + assert o.e == 5 + + assert o.get_b_b() == 1 + assert o.get_c_b() == 1 + assert o.get_d0_b() == 1 + assert o.get_d1_b() == 1 + assert o.get_e_b() == 1 + + assert o.get_c_c() == 2 + assert o.get_d0_c() == 2 + assert o.get_d1_c() == 2 + assert o.get_e_c() == 2 + + assert o.get_d0_d0() == 3 + assert o.get_e_d0() == 3 + + assert o.get_d1_d1() == 4 + assert o.get_e_d1() == 4 + + assert o.get_e_e() == 5 + + +def test_pr3635_diamond_f(): + o = m.MVF() + assert o.b == 1 + assert o.c == 2 + assert o.d0 == 3 + assert o.d1 == 4 + assert o.e == 5 + assert o.f == 6 + + assert o.get_b_b() == 1 + assert o.get_c_b() == 1 + assert o.get_d0_b() == 1 + assert o.get_d1_b() == 1 + assert o.get_e_b() == 1 + assert o.get_f_b() == 1 + + assert o.get_c_c() == 2 + assert o.get_d0_c() == 2 + assert o.get_d1_c() == 2 + assert o.get_e_c() == 2 + assert o.get_f_c() == 2 + + assert o.get_d0_d0() == 3 + assert o.get_e_d0() == 3 + assert o.get_f_d0() == 3 + + assert o.get_d1_d1() == 4 + assert o.get_e_d1() == 4 + assert o.get_f_d1() == 4 + + assert o.get_e_e() == 5 + assert o.get_f_e() == 5 + + assert o.get_f_f() == 6 + + +def test_python_inherit_from_mi(): + """Tests extending a Python class from a single inheritor of a MI class""" + + class PyMVF(m.MVF): + g = 7 + + def get_g_g(self): + return self.g + + o = PyMVF() + + assert o.b == 1 + assert o.c == 2 + assert o.d0 == 3 + assert o.d1 == 4 + assert o.e == 5 + assert o.f == 6 + assert o.g == 7 + + assert o.get_g_g() == 7 diff --git a/wrap/pybind11/tests/test_numpy_array.cpp b/wrap/pybind11/tests/test_numpy_array.cpp index 33f1d7857..30a71acc9 100644 --- a/wrap/pybind11/tests/test_numpy_array.cpp +++ b/wrap/pybind11/tests/test_numpy_array.cpp @@ -13,6 +13,7 @@ #include #include +#include // Size / dtype checks. struct DtypeCheck { @@ -22,7 +23,7 @@ struct DtypeCheck { template DtypeCheck get_dtype_check(const char* name) { - py::module np = py::module::import("numpy"); + py::module_ np = py::module_::import("numpy"); DtypeCheck check{}; check.numpy = np.attr("dtype")(np.attr(name)); check.pybind11 = py::dtype::of(); @@ -89,23 +90,23 @@ template arr data_t(const arr_t& a, Ix... index) { template arr& mutate_data(arr& a, Ix... index) { auto ptr = (uint8_t *) a.mutable_data(index...); - for (ssize_t i = 0; i < a.nbytes() - a.offset_at(index...); i++) + for (py::ssize_t i = 0; i < a.nbytes() - a.offset_at(index...); i++) ptr[i] = (uint8_t) (ptr[i] * 2); return a; } template arr_t& mutate_data_t(arr_t& a, Ix... index) { auto ptr = a.mutable_data(index...); - for (ssize_t i = 0; i < a.size() - a.index_at(index...); i++) + for (py::ssize_t i = 0; i < a.size() - a.index_at(index...); i++) ptr[i]++; return a; } -template ssize_t index_at(const arr& a, Ix... idx) { return a.index_at(idx...); } -template ssize_t index_at_t(const arr_t& a, Ix... idx) { return a.index_at(idx...); } -template ssize_t offset_at(const arr& a, Ix... idx) { return a.offset_at(idx...); } -template ssize_t offset_at_t(const arr_t& a, Ix... idx) { return a.offset_at(idx...); } -template ssize_t at_t(const arr_t& a, Ix... idx) { return a.at(idx...); } +template py::ssize_t index_at(const arr& a, Ix... idx) { return a.index_at(idx...); } +template py::ssize_t index_at_t(const arr_t& a, Ix... idx) { return a.index_at(idx...); } +template py::ssize_t offset_at(const arr& a, Ix... idx) { return a.offset_at(idx...); } +template py::ssize_t offset_at_t(const arr_t& a, Ix... idx) { return a.offset_at(idx...); } +template py::ssize_t at_t(const arr_t& a, Ix... idx) { return a.at(idx...); } template arr_t& mutate_at_t(arr_t& a, Ix... idx) { a.mutable_at(idx...)++; return a; } #define def_index_fn(name, type) \ @@ -133,7 +134,7 @@ template py::handle auxiliaries(T &&r, T2 &&r2) { static int data_i = 42; TEST_SUBMODULE(numpy_array, sm) { - try { py::module::import("numpy"); } + try { py::module_::import("numpy"); } catch (...) { return; } // test_dtypes @@ -159,9 +160,9 @@ TEST_SUBMODULE(numpy_array, sm) { // test_array_attributes sm.def("ndim", [](const arr& a) { return a.ndim(); }); sm.def("shape", [](const arr& a) { return arr(a.ndim(), a.shape()); }); - sm.def("shape", [](const arr& a, ssize_t dim) { return a.shape(dim); }); + sm.def("shape", [](const arr& a, py::ssize_t dim) { return a.shape(dim); }); sm.def("strides", [](const arr& a) { return arr(a.ndim(), a.strides()); }); - sm.def("strides", [](const arr& a, ssize_t dim) { return a.strides(dim); }); + sm.def("strides", [](const arr& a, py::ssize_t dim) { return a.strides(dim); }); sm.def("writeable", [](const arr& a) { return a.writeable(); }); sm.def("size", [](const arr& a) { return a.size(); }); sm.def("itemsize", [](const arr& a) { return a.itemsize(); }); @@ -192,7 +193,7 @@ TEST_SUBMODULE(numpy_array, sm) { sm.def("scalar_int", []() { return py::array(py::dtype("i"), {}, {}, &data_i); }); // test_wrap - sm.def("wrap", [](py::array a) { + sm.def("wrap", [](const py::array &a) { return py::array( a.dtype(), {a.shape(), a.shape() + a.ndim()}, @@ -222,9 +223,10 @@ TEST_SUBMODULE(numpy_array, sm) { // test_isinstance sm.def("isinstance_untyped", [](py::object yes, py::object no) { - return py::isinstance(yes) && !py::isinstance(no); + return py::isinstance(std::move(yes)) + && !py::isinstance(std::move(no)); }); - sm.def("isinstance_typed", [](py::object o) { + sm.def("isinstance_typed", [](const py::object &o) { return py::isinstance>(o) && !py::isinstance>(o); }); @@ -236,7 +238,7 @@ TEST_SUBMODULE(numpy_array, sm) { "array_t"_a=py::array_t() ); }); - sm.def("converting_constructors", [](py::object o) { + sm.def("converting_constructors", [](const py::object &o) { return py::dict( "array"_a=py::array(o), "array_t"_a=py::array_t(o), @@ -245,69 +247,78 @@ TEST_SUBMODULE(numpy_array, sm) { }); // test_overload_resolution - sm.def("overloaded", [](py::array_t) { return "double"; }); - sm.def("overloaded", [](py::array_t) { return "float"; }); - sm.def("overloaded", [](py::array_t) { return "int"; }); - sm.def("overloaded", [](py::array_t) { return "unsigned short"; }); - sm.def("overloaded", [](py::array_t) { return "long long"; }); - sm.def("overloaded", [](py::array_t>) { return "double complex"; }); - sm.def("overloaded", [](py::array_t>) { return "float complex"; }); + sm.def("overloaded", [](const py::array_t &) { return "double"; }); + sm.def("overloaded", [](const py::array_t &) { return "float"; }); + sm.def("overloaded", [](const py::array_t &) { return "int"; }); + sm.def("overloaded", [](const py::array_t &) { return "unsigned short"; }); + sm.def("overloaded", [](const py::array_t &) { return "long long"; }); + sm.def("overloaded", + [](const py::array_t> &) { return "double complex"; }); + sm.def("overloaded", [](const py::array_t> &) { return "float complex"; }); - sm.def("overloaded2", [](py::array_t>) { return "double complex"; }); - sm.def("overloaded2", [](py::array_t) { return "double"; }); - sm.def("overloaded2", [](py::array_t>) { return "float complex"; }); - sm.def("overloaded2", [](py::array_t) { return "float"; }); + sm.def("overloaded2", + [](const py::array_t> &) { return "double complex"; }); + sm.def("overloaded2", [](const py::array_t &) { return "double"; }); + sm.def("overloaded2", + [](const py::array_t> &) { return "float complex"; }); + sm.def("overloaded2", [](const py::array_t &) { return "float"; }); + + // [workaround(intel)] ICC 20/21 breaks with py::arg().stuff, using py::arg{}.stuff works. // Only accept the exact types: - sm.def("overloaded3", [](py::array_t) { return "int"; }, py::arg().noconvert()); - sm.def("overloaded3", [](py::array_t) { return "double"; }, py::arg().noconvert()); + sm.def( + "overloaded3", [](const py::array_t &) { return "int"; }, py::arg{}.noconvert()); + sm.def( + "overloaded3", + [](const py::array_t &) { return "double"; }, + py::arg{}.noconvert()); // Make sure we don't do unsafe coercion (e.g. float to int) when not using forcecast, but // rather that float gets converted via the safe (conversion to double) overload: - sm.def("overloaded4", [](py::array_t) { return "long long"; }); - sm.def("overloaded4", [](py::array_t) { return "double"; }); + sm.def("overloaded4", [](const py::array_t &) { return "long long"; }); + sm.def("overloaded4", [](const py::array_t &) { return "double"; }); // But we do allow conversion to int if forcecast is enabled (but only if no overload matches // without conversion) - sm.def("overloaded5", [](py::array_t) { return "unsigned int"; }); - sm.def("overloaded5", [](py::array_t) { return "double"; }); + sm.def("overloaded5", [](const py::array_t &) { return "unsigned int"; }); + sm.def("overloaded5", [](const py::array_t &) { return "double"; }); // test_greedy_string_overload // Issue 685: ndarray shouldn't go to std::string overload - sm.def("issue685", [](std::string) { return "string"; }); - sm.def("issue685", [](py::array) { return "array"; }); - sm.def("issue685", [](py::object) { return "other"; }); + sm.def("issue685", [](const std::string &) { return "string"; }); + sm.def("issue685", [](const py::array &) { return "array"; }); + sm.def("issue685", [](const py::object &) { return "other"; }); // test_array_unchecked_fixed_dims sm.def("proxy_add2", [](py::array_t a, double v) { auto r = a.mutable_unchecked<2>(); - for (ssize_t i = 0; i < r.shape(0); i++) - for (ssize_t j = 0; j < r.shape(1); j++) + for (py::ssize_t i = 0; i < r.shape(0); i++) + for (py::ssize_t j = 0; j < r.shape(1); j++) r(i, j) += v; - }, py::arg().noconvert(), py::arg()); + }, py::arg{}.noconvert(), py::arg()); sm.def("proxy_init3", [](double start) { py::array_t a({ 3, 3, 3 }); auto r = a.mutable_unchecked<3>(); - for (ssize_t i = 0; i < r.shape(0); i++) - for (ssize_t j = 0; j < r.shape(1); j++) - for (ssize_t k = 0; k < r.shape(2); k++) + for (py::ssize_t i = 0; i < r.shape(0); i++) + for (py::ssize_t j = 0; j < r.shape(1); j++) + for (py::ssize_t k = 0; k < r.shape(2); k++) r(i, j, k) = start++; return a; }); sm.def("proxy_init3F", [](double start) { py::array_t a({ 3, 3, 3 }); auto r = a.mutable_unchecked<3>(); - for (ssize_t k = 0; k < r.shape(2); k++) - for (ssize_t j = 0; j < r.shape(1); j++) - for (ssize_t i = 0; i < r.shape(0); i++) + for (py::ssize_t k = 0; k < r.shape(2); k++) + for (py::ssize_t j = 0; j < r.shape(1); j++) + for (py::ssize_t i = 0; i < r.shape(0); i++) r(i, j, k) = start++; return a; }); - sm.def("proxy_squared_L2_norm", [](py::array_t a) { + sm.def("proxy_squared_L2_norm", [](const py::array_t &a) { auto r = a.unchecked<1>(); double sumsq = 0; - for (ssize_t i = 0; i < r.shape(0); i++) + for (py::ssize_t i = 0; i < r.shape(0); i++) sumsq += r[i] * r(i); // Either notation works for a 1D array return sumsq; }); @@ -318,22 +329,34 @@ TEST_SUBMODULE(numpy_array, sm) { return auxiliaries(r, r2); }); + sm.def("proxy_auxiliaries1_const_ref", [](py::array_t a) { + const auto &r = a.unchecked<1>(); + const auto &r2 = a.mutable_unchecked<1>(); + return r(0) == r2(0) && r[0] == r2[0]; + }); + + sm.def("proxy_auxiliaries2_const_ref", [](py::array_t a) { + const auto &r = a.unchecked<2>(); + const auto &r2 = a.mutable_unchecked<2>(); + return r(0, 0) == r2(0, 0); + }); + // test_array_unchecked_dyn_dims // Same as the above, but without a compile-time dimensions specification: sm.def("proxy_add2_dyn", [](py::array_t a, double v) { auto r = a.mutable_unchecked(); if (r.ndim() != 2) throw std::domain_error("error: ndim != 2"); - for (ssize_t i = 0; i < r.shape(0); i++) - for (ssize_t j = 0; j < r.shape(1); j++) + for (py::ssize_t i = 0; i < r.shape(0); i++) + for (py::ssize_t j = 0; j < r.shape(1); j++) r(i, j) += v; - }, py::arg().noconvert(), py::arg()); + }, py::arg{}.noconvert(), py::arg()); sm.def("proxy_init3_dyn", [](double start) { py::array_t a({ 3, 3, 3 }); auto r = a.mutable_unchecked(); if (r.ndim() != 3) throw std::domain_error("error: ndim != 3"); - for (ssize_t i = 0; i < r.shape(0); i++) - for (ssize_t j = 0; j < r.shape(1); j++) - for (ssize_t k = 0; k < r.shape(2); k++) + for (py::ssize_t i = 0; i < r.shape(0); i++) + for (py::ssize_t j = 0; j < r.shape(1); j++) + for (py::ssize_t k = 0; k < r.shape(2); k++) r(i, j, k) = start++; return a; }); @@ -362,7 +385,7 @@ TEST_SUBMODULE(numpy_array, sm) { // test_array_resize // reshape array to 2D without changing size sm.def("array_reshape2", [](py::array_t a) { - const auto dim_sz = (ssize_t)std::sqrt(a.size()); + const auto dim_sz = (py::ssize_t)std::sqrt(a.size()); if (dim_sz * dim_sz != a.size()) throw std::domain_error("array_reshape2: input array total size is not a squared integer"); a.resize({dim_sz, dim_sz}); @@ -382,45 +405,68 @@ TEST_SUBMODULE(numpy_array, sm) { return a; }); - sm.def("index_using_ellipsis", [](py::array a) { - return a[py::make_tuple(0, py::ellipsis(), 0)]; + sm.def("array_view", + [](py::array_t a, const std::string &dtype) { return a.view(dtype); }); + + sm.def("reshape_initializer_list", [](py::array_t a, size_t N, size_t M, size_t O) { + return a.reshape({N, M, O}); + }); + sm.def("reshape_tuple", [](py::array_t a, const std::vector &new_shape) { + return a.reshape(new_shape); }); + sm.def("index_using_ellipsis", + [](const py::array &a) { return a[py::make_tuple(0, py::ellipsis(), 0)]; }); + // test_argument_conversions - sm.def("accept_double", - [](py::array_t) {}, - py::arg("a")); - sm.def("accept_double_forcecast", - [](py::array_t) {}, - py::arg("a")); - sm.def("accept_double_c_style", - [](py::array_t) {}, - py::arg("a")); - sm.def("accept_double_c_style_forcecast", - [](py::array_t) {}, - py::arg("a")); - sm.def("accept_double_f_style", - [](py::array_t) {}, - py::arg("a")); - sm.def("accept_double_f_style_forcecast", - [](py::array_t) {}, - py::arg("a")); - sm.def("accept_double_noconvert", - [](py::array_t) {}, - py::arg("a").noconvert()); - sm.def("accept_double_forcecast_noconvert", - [](py::array_t) {}, - py::arg("a").noconvert()); - sm.def("accept_double_c_style_noconvert", - [](py::array_t) {}, - py::arg("a").noconvert()); - sm.def("accept_double_c_style_forcecast_noconvert", - [](py::array_t) {}, - py::arg("a").noconvert()); - sm.def("accept_double_f_style_noconvert", - [](py::array_t) {}, - py::arg("a").noconvert()); - sm.def("accept_double_f_style_forcecast_noconvert", - [](py::array_t) {}, - py::arg("a").noconvert()); + sm.def( + "accept_double", [](const py::array_t &) {}, py::arg("a")); + sm.def( + "accept_double_forcecast", + [](const py::array_t &) {}, + py::arg("a")); + sm.def( + "accept_double_c_style", + [](const py::array_t &) {}, + py::arg("a")); + sm.def( + "accept_double_c_style_forcecast", + [](const py::array_t &) {}, + py::arg("a")); + sm.def( + "accept_double_f_style", + [](const py::array_t &) {}, + py::arg("a")); + sm.def( + "accept_double_f_style_forcecast", + [](const py::array_t &) {}, + py::arg("a")); + sm.def( + "accept_double_noconvert", [](const py::array_t &) {}, "a"_a.noconvert()); + sm.def( + "accept_double_forcecast_noconvert", + [](const py::array_t &) {}, + "a"_a.noconvert()); + sm.def( + "accept_double_c_style_noconvert", + [](const py::array_t &) {}, + "a"_a.noconvert()); + sm.def( + "accept_double_c_style_forcecast_noconvert", + [](const py::array_t &) {}, + "a"_a.noconvert()); + sm.def( + "accept_double_f_style_noconvert", + [](const py::array_t &) {}, + "a"_a.noconvert()); + sm.def( + "accept_double_f_style_forcecast_noconvert", + [](const py::array_t &) {}, + "a"_a.noconvert()); + + // Check that types returns correct npy format descriptor + sm.def("test_fmt_desc_float", [](const py::array_t &) {}); + sm.def("test_fmt_desc_double", [](const py::array_t &) {}); + sm.def("test_fmt_desc_const_float", [](const py::array_t &) {}); + sm.def("test_fmt_desc_const_double", [](const py::array_t &) {}); } diff --git a/wrap/pybind11/tests/test_numpy_array.py b/wrap/pybind11/tests/test_numpy_array.py index a36e707c1..e4138f023 100644 --- a/wrap/pybind11/tests/test_numpy_array.py +++ b/wrap/pybind11/tests/test_numpy_array.py @@ -2,7 +2,6 @@ import pytest import env # noqa: F401 - from pybind11_tests import numpy_array as m np = pytest.importorskip("numpy") @@ -19,33 +18,36 @@ def test_dtypes(): print(check) assert check.numpy == check.pybind11, check if check.numpy.num != check.pybind11.num: - print("NOTE: typenum mismatch for {}: {} != {}".format( - check, check.numpy.num, check.pybind11.num)) + print( + "NOTE: typenum mismatch for {}: {} != {}".format( + check, check.numpy.num, check.pybind11.num + ) + ) -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def arr(): - return np.array([[1, 2, 3], [4, 5, 6]], '=u2') + return np.array([[1, 2, 3], [4, 5, 6]], "=u2") def test_array_attributes(): - a = np.array(0, 'f8') + a = np.array(0, "f8") assert m.ndim(a) == 0 assert all(m.shape(a) == []) assert all(m.strides(a) == []) with pytest.raises(IndexError) as excinfo: m.shape(a, 0) - assert str(excinfo.value) == 'invalid axis: 0 (ndim = 0)' + assert str(excinfo.value) == "invalid axis: 0 (ndim = 0)" with pytest.raises(IndexError) as excinfo: m.strides(a, 0) - assert str(excinfo.value) == 'invalid axis: 0 (ndim = 0)' + assert str(excinfo.value) == "invalid axis: 0 (ndim = 0)" assert m.writeable(a) assert m.size(a) == 1 assert m.itemsize(a) == 8 assert m.nbytes(a) == 8 assert m.owndata(a) - a = np.array([[1, 2, 3], [4, 5, 6]], 'u2').view() + a = np.array([[1, 2, 3], [4, 5, 6]], "u2").view() a.flags.writeable = False assert m.ndim(a) == 2 assert all(m.shape(a) == [2, 3]) @@ -56,10 +58,10 @@ def test_array_attributes(): assert m.strides(a, 1) == 2 with pytest.raises(IndexError) as excinfo: m.shape(a, 2) - assert str(excinfo.value) == 'invalid axis: 2 (ndim = 2)' + assert str(excinfo.value) == "invalid axis: 2 (ndim = 2)" with pytest.raises(IndexError) as excinfo: m.strides(a, 2) - assert str(excinfo.value) == 'invalid axis: 2 (ndim = 2)' + assert str(excinfo.value) == "invalid axis: 2 (ndim = 2)" assert not m.writeable(a) assert m.size(a) == 6 assert m.itemsize(a) == 2 @@ -67,7 +69,9 @@ def test_array_attributes(): assert not m.owndata(a) -@pytest.mark.parametrize('args, ret', [([], 0), ([0], 0), ([1], 3), ([0, 1], 1), ([1, 2], 5)]) +@pytest.mark.parametrize( + "args, ret", [([], 0), ([0], 0), ([1], 3), ([0, 1], 1), ([1, 2], 5)] +) def test_index_offset(arr, args, ret): assert m.index_at(arr, *args) == ret assert m.index_at_t(arr, *args) == ret @@ -76,31 +80,46 @@ def test_index_offset(arr, args, ret): def test_dim_check_fail(arr): - for func in (m.index_at, m.index_at_t, m.offset_at, m.offset_at_t, m.data, m.data_t, - m.mutate_data, m.mutate_data_t): + for func in ( + m.index_at, + m.index_at_t, + m.offset_at, + m.offset_at_t, + m.data, + m.data_t, + m.mutate_data, + m.mutate_data_t, + ): with pytest.raises(IndexError) as excinfo: func(arr, 1, 2, 3) - assert str(excinfo.value) == 'too many indices for an array: 3 (ndim = 2)' + assert str(excinfo.value) == "too many indices for an array: 3 (ndim = 2)" -@pytest.mark.parametrize('args, ret', - [([], [1, 2, 3, 4, 5, 6]), - ([1], [4, 5, 6]), - ([0, 1], [2, 3, 4, 5, 6]), - ([1, 2], [6])]) +@pytest.mark.parametrize( + "args, ret", + [ + ([], [1, 2, 3, 4, 5, 6]), + ([1], [4, 5, 6]), + ([0, 1], [2, 3, 4, 5, 6]), + ([1, 2], [6]), + ], +) def test_data(arr, args, ret): from sys import byteorder + assert all(m.data_t(arr, *args) == ret) - assert all(m.data(arr, *args)[(0 if byteorder == 'little' else 1)::2] == ret) - assert all(m.data(arr, *args)[(1 if byteorder == 'little' else 0)::2] == 0) + assert all(m.data(arr, *args)[(0 if byteorder == "little" else 1) :: 2] == ret) + assert all(m.data(arr, *args)[(1 if byteorder == "little" else 0) :: 2] == 0) -@pytest.mark.parametrize('dim', [0, 1, 3]) +@pytest.mark.parametrize("dim", [0, 1, 3]) def test_at_fail(arr, dim): for func in m.at_t, m.mutate_at_t: with pytest.raises(IndexError) as excinfo: func(arr, *([0] * dim)) - assert str(excinfo.value) == 'index dimension mismatch: {} (ndim = 2)'.format(dim) + assert str(excinfo.value) == "index dimension mismatch: {} (ndim = 2)".format( + dim + ) def test_at(arr): @@ -113,10 +132,14 @@ def test_at(arr): def test_mutate_readonly(arr): arr.flags.writeable = False - for func, args in (m.mutate_data, ()), (m.mutate_data_t, ()), (m.mutate_at_t, (0, 0)): + for func, args in ( + (m.mutate_data, ()), + (m.mutate_data_t, ()), + (m.mutate_at_t, (0, 0)), + ): with pytest.raises(ValueError) as excinfo: func(arr, *args) - assert str(excinfo.value) == 'array is not writeable' + assert str(excinfo.value) == "array is not writeable" def test_mutate_data(arr): @@ -134,14 +157,22 @@ def test_mutate_data(arr): def test_bounds_check(arr): - for func in (m.index_at, m.index_at_t, m.data, m.data_t, - m.mutate_data, m.mutate_data_t, m.at_t, m.mutate_at_t): + for func in ( + m.index_at, + m.index_at_t, + m.data, + m.data_t, + m.mutate_data, + m.mutate_data_t, + m.at_t, + m.mutate_at_t, + ): with pytest.raises(IndexError) as excinfo: func(arr, 2, 0) - assert str(excinfo.value) == 'index 2 is out of bounds for axis 0 with size 2' + assert str(excinfo.value) == "index 2 is out of bounds for axis 0 with size 2" with pytest.raises(IndexError) as excinfo: func(arr, 0, 4) - assert str(excinfo.value) == 'index 4 is out of bounds for axis 1 with size 3' + assert str(excinfo.value) == "index 4 is out of bounds for axis 1 with size 3" def test_make_c_f_array(): @@ -163,10 +194,11 @@ def test_make_empty_shaped_array(): def test_wrap(): def assert_references(a, b, base=None): from distutils.version import LooseVersion + if base is None: base = a assert a is not b - assert a.__array_interface__['data'][0] == b.__array_interface__['data'][0] + assert a.__array_interface__["data"][0] == b.__array_interface__["data"][0] assert a.shape == b.shape assert a.strides == b.strides assert a.flags.c_contiguous == b.flags.c_contiguous @@ -189,12 +221,12 @@ def test_wrap(): a2 = m.wrap(a1) assert_references(a1, a2) - a1 = np.array([[1, 2], [3, 4]], dtype=np.float32, order='F') + a1 = np.array([[1, 2], [3, 4]], dtype=np.float32, order="F") assert a1.flags.owndata and a1.base is None a2 = m.wrap(a1) assert_references(a1, a2) - a1 = np.array([[1, 2], [3, 4]], dtype=np.float32, order='C') + a1 = np.array([[1, 2], [3, 4]], dtype=np.float32, order="C") a1.flags.writeable = False a2 = m.wrap(a1) assert_references(a1, a2) @@ -224,11 +256,14 @@ def test_numpy_view(capture): assert np.all(ac_view_1 == np.array([1, 2], dtype=np.int32)) del ac pytest.gc_collect() - assert capture == """ + assert ( + capture + == """ ArrayClass() ArrayClass::numpy_view() ArrayClass::numpy_view() """ + ) ac_view_1[0] = 4 ac_view_1[1] = 3 assert ac_view_2[0] == 4 @@ -238,9 +273,12 @@ def test_numpy_view(capture): del ac_view_2 pytest.gc_collect() pytest.gc_collect() - assert capture == """ + assert ( + capture + == """ ~ArrayClass() """ + ) def test_cast_numpy_int64_to_uint64(): @@ -271,20 +309,22 @@ def test_constructors(): def test_overload_resolution(msg): # Exact overload matches: - assert m.overloaded(np.array([1], dtype='float64')) == 'double' - assert m.overloaded(np.array([1], dtype='float32')) == 'float' - assert m.overloaded(np.array([1], dtype='ushort')) == 'unsigned short' - assert m.overloaded(np.array([1], dtype='intc')) == 'int' - assert m.overloaded(np.array([1], dtype='longlong')) == 'long long' - assert m.overloaded(np.array([1], dtype='complex')) == 'double complex' - assert m.overloaded(np.array([1], dtype='csingle')) == 'float complex' + assert m.overloaded(np.array([1], dtype="float64")) == "double" + assert m.overloaded(np.array([1], dtype="float32")) == "float" + assert m.overloaded(np.array([1], dtype="ushort")) == "unsigned short" + assert m.overloaded(np.array([1], dtype="intc")) == "int" + assert m.overloaded(np.array([1], dtype="longlong")) == "long long" + assert m.overloaded(np.array([1], dtype="complex")) == "double complex" + assert m.overloaded(np.array([1], dtype="csingle")) == "float complex" # No exact match, should call first convertible version: - assert m.overloaded(np.array([1], dtype='uint8')) == 'double' + assert m.overloaded(np.array([1], dtype="uint8")) == "double" with pytest.raises(TypeError) as excinfo: m.overloaded("not an array") - assert msg(excinfo.value) == """ + assert ( + msg(excinfo.value) + == """ overloaded(): incompatible function arguments. The following argument types are supported: 1. (arg0: numpy.ndarray[numpy.float64]) -> str 2. (arg0: numpy.ndarray[numpy.float32]) -> str @@ -296,15 +336,16 @@ def test_overload_resolution(msg): Invoked with: 'not an array' """ + ) - assert m.overloaded2(np.array([1], dtype='float64')) == 'double' - assert m.overloaded2(np.array([1], dtype='float32')) == 'float' - assert m.overloaded2(np.array([1], dtype='complex64')) == 'float complex' - assert m.overloaded2(np.array([1], dtype='complex128')) == 'double complex' - assert m.overloaded2(np.array([1], dtype='float32')) == 'float' + assert m.overloaded2(np.array([1], dtype="float64")) == "double" + assert m.overloaded2(np.array([1], dtype="float32")) == "float" + assert m.overloaded2(np.array([1], dtype="complex64")) == "float complex" + assert m.overloaded2(np.array([1], dtype="complex128")) == "double complex" + assert m.overloaded2(np.array([1], dtype="float32")) == "float" - assert m.overloaded3(np.array([1], dtype='float64')) == 'double' - assert m.overloaded3(np.array([1], dtype='intc')) == 'int' + assert m.overloaded3(np.array([1], dtype="float64")) == "double" + assert m.overloaded3(np.array([1], dtype="intc")) == "int" expected_exc = """ overloaded3(): incompatible function arguments. The following argument types are supported: 1. (arg0: numpy.ndarray[numpy.int32]) -> str @@ -313,47 +354,49 @@ def test_overload_resolution(msg): Invoked with: """ with pytest.raises(TypeError) as excinfo: - m.overloaded3(np.array([1], dtype='uintc')) - assert msg(excinfo.value) == expected_exc + repr(np.array([1], dtype='uint32')) + m.overloaded3(np.array([1], dtype="uintc")) + assert msg(excinfo.value) == expected_exc + repr(np.array([1], dtype="uint32")) with pytest.raises(TypeError) as excinfo: - m.overloaded3(np.array([1], dtype='float32')) - assert msg(excinfo.value) == expected_exc + repr(np.array([1.], dtype='float32')) + m.overloaded3(np.array([1], dtype="float32")) + assert msg(excinfo.value) == expected_exc + repr(np.array([1.0], dtype="float32")) with pytest.raises(TypeError) as excinfo: - m.overloaded3(np.array([1], dtype='complex')) - assert msg(excinfo.value) == expected_exc + repr(np.array([1. + 0.j])) + m.overloaded3(np.array([1], dtype="complex")) + assert msg(excinfo.value) == expected_exc + repr(np.array([1.0 + 0.0j])) # Exact matches: - assert m.overloaded4(np.array([1], dtype='double')) == 'double' - assert m.overloaded4(np.array([1], dtype='longlong')) == 'long long' + assert m.overloaded4(np.array([1], dtype="double")) == "double" + assert m.overloaded4(np.array([1], dtype="longlong")) == "long long" # Non-exact matches requiring conversion. Since float to integer isn't a # save conversion, it should go to the double overload, but short can go to # either (and so should end up on the first-registered, the long long). - assert m.overloaded4(np.array([1], dtype='float32')) == 'double' - assert m.overloaded4(np.array([1], dtype='short')) == 'long long' + assert m.overloaded4(np.array([1], dtype="float32")) == "double" + assert m.overloaded4(np.array([1], dtype="short")) == "long long" - assert m.overloaded5(np.array([1], dtype='double')) == 'double' - assert m.overloaded5(np.array([1], dtype='uintc')) == 'unsigned int' - assert m.overloaded5(np.array([1], dtype='float32')) == 'unsigned int' + assert m.overloaded5(np.array([1], dtype="double")) == "double" + assert m.overloaded5(np.array([1], dtype="uintc")) == "unsigned int" + assert m.overloaded5(np.array([1], dtype="float32")) == "unsigned int" def test_greedy_string_overload(): """Tests fix for #685 - ndarray shouldn't go to std::string overload""" assert m.issue685("abc") == "string" - assert m.issue685(np.array([97, 98, 99], dtype='b')) == "array" + assert m.issue685(np.array([97, 98, 99], dtype="b")) == "array" assert m.issue685(123) == "other" def test_array_unchecked_fixed_dims(msg): - z1 = np.array([[1, 2], [3, 4]], dtype='float64') + z1 = np.array([[1, 2], [3, 4]], dtype="float64") m.proxy_add2(z1, 10) assert np.all(z1 == [[11, 12], [13, 14]]) with pytest.raises(ValueError) as excinfo: - m.proxy_add2(np.array([1., 2, 3]), 5.0) - assert msg(excinfo.value) == "array has incorrect number of dimensions: 1; expected 2" + m.proxy_add2(np.array([1.0, 2, 3]), 5.0) + assert ( + msg(excinfo.value) == "array has incorrect number of dimensions: 1; expected 2" + ) - expect_c = np.ndarray(shape=(3, 3, 3), buffer=np.array(range(3, 30)), dtype='int') + expect_c = np.ndarray(shape=(3, 3, 3), buffer=np.array(range(3, 30)), dtype="int") assert np.all(m.proxy_init3(3.0) == expect_c) expect_f = np.transpose(expect_c) assert np.all(m.proxy_init3F(3.0) == expect_f) @@ -364,13 +407,16 @@ def test_array_unchecked_fixed_dims(msg): assert m.proxy_auxiliaries2(z1) == [11, 11, True, 2, 8, 2, 2, 4, 32] assert m.proxy_auxiliaries2(z1) == m.array_auxiliaries2(z1) + assert m.proxy_auxiliaries1_const_ref(z1[0, :]) + assert m.proxy_auxiliaries2_const_ref(z1) -def test_array_unchecked_dyn_dims(msg): - z1 = np.array([[1, 2], [3, 4]], dtype='float64') + +def test_array_unchecked_dyn_dims(): + z1 = np.array([[1, 2], [3, 4]], dtype="float64") m.proxy_add2_dyn(z1, 10) assert np.all(z1 == [[11, 12], [13, 14]]) - expect_c = np.ndarray(shape=(3, 3, 3), buffer=np.array(range(3, 30)), dtype='int') + expect_c = np.ndarray(shape=(3, 3, 3), buffer=np.array(range(3, 30)), dtype="int") assert np.all(m.proxy_init3_dyn(3.0) == expect_c) assert m.proxy_auxiliaries2_dyn(z1) == [11, 11, True, 2, 8, 2, 2, 4, 32] @@ -380,15 +426,15 @@ def test_array_unchecked_dyn_dims(msg): def test_array_failure(): with pytest.raises(ValueError) as excinfo: m.array_fail_test() - assert str(excinfo.value) == 'cannot create a pybind11::array from a nullptr' + assert str(excinfo.value) == "cannot create a pybind11::array from a nullptr" with pytest.raises(ValueError) as excinfo: m.array_t_fail_test() - assert str(excinfo.value) == 'cannot create a pybind11::array_t from a nullptr' + assert str(excinfo.value) == "cannot create a pybind11::array_t from a nullptr" with pytest.raises(ValueError) as excinfo: m.array_fail_test_negative_size() - assert str(excinfo.value) == 'negative dimensions are not allowed' + assert str(excinfo.value) == "negative dimensions are not allowed" def test_initializer_list(): @@ -398,36 +444,76 @@ def test_initializer_list(): assert m.array_initializer_list4().shape == (1, 2, 3, 4) -def test_array_resize(msg): - a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype='float64') +def test_array_resize(): + a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype="float64") m.array_reshape2(a) - assert(a.size == 9) - assert(np.all(a == [[1, 2, 3], [4, 5, 6], [7, 8, 9]])) + assert a.size == 9 + assert np.all(a == [[1, 2, 3], [4, 5, 6], [7, 8, 9]]) # total size change should succced with refcheck off m.array_resize3(a, 4, False) - assert(a.size == 64) + assert a.size == 64 # ... and fail with refcheck on try: m.array_resize3(a, 3, True) except ValueError as e: - assert(str(e).startswith("cannot resize an array")) + assert str(e).startswith("cannot resize an array") # transposed array doesn't own data b = a.transpose() try: m.array_resize3(b, 3, False) except ValueError as e: - assert(str(e).startswith("cannot resize this array: it does not own its data")) + assert str(e).startswith("cannot resize this array: it does not own its data") # ... but reshape should be fine m.array_reshape2(b) - assert(b.shape == (8, 8)) + assert b.shape == (8, 8) @pytest.mark.xfail("env.PYPY") -def test_array_create_and_resize(msg): +def test_array_create_and_resize(): a = m.create_and_resize(2) - assert(a.size == 4) - assert(np.all(a == 42.)) + assert a.size == 4 + assert np.all(a == 42.0) + + +def test_array_view(): + a = np.ones(100 * 4).astype("uint8") + a_float_view = m.array_view(a, "float32") + assert a_float_view.shape == (100 * 1,) # 1 / 4 bytes = 8 / 32 + + a_int16_view = m.array_view(a, "int16") # 1 / 2 bytes = 16 / 32 + assert a_int16_view.shape == (100 * 2,) + + +def test_array_view_invalid(): + a = np.ones(100 * 4).astype("uint8") + with pytest.raises(TypeError): + m.array_view(a, "deadly_dtype") + + +def test_reshape_initializer_list(): + a = np.arange(2 * 7 * 3) + 1 + x = m.reshape_initializer_list(a, 2, 7, 3) + assert x.shape == (2, 7, 3) + assert list(x[1][4]) == [34, 35, 36] + with pytest.raises(ValueError) as excinfo: + m.reshape_initializer_list(a, 1, 7, 3) + assert str(excinfo.value) == "cannot reshape array of size 42 into shape (1,7,3)" + + +def test_reshape_tuple(): + a = np.arange(3 * 7 * 2) + 1 + x = m.reshape_tuple(a, (3, 7, 2)) + assert x.shape == (3, 7, 2) + assert list(x[1][4]) == [23, 24] + y = m.reshape_tuple(x, (x.size,)) + assert y.shape == (42,) + with pytest.raises(ValueError) as excinfo: + m.reshape_tuple(a, (3, 7, 1)) + assert str(excinfo.value) == "cannot reshape array of size 42 into shape (3,7,1)" + with pytest.raises(ValueError) as excinfo: + m.reshape_tuple(a, ()) + assert str(excinfo.value) == "cannot reshape array of size 42 into shape ()" def test_index_using_ellipsis(): @@ -435,17 +521,30 @@ def test_index_using_ellipsis(): assert a.shape == (6,) +@pytest.mark.parametrize( + "test_func", + [ + m.test_fmt_desc_float, + m.test_fmt_desc_double, + m.test_fmt_desc_const_float, + m.test_fmt_desc_const_double, + ], +) +def test_format_descriptors_for_floating_point_types(test_func): + assert "numpy.ndarray[numpy.float" in test_func.__doc__ + + @pytest.mark.parametrize("forcecast", [False, True]) -@pytest.mark.parametrize("contiguity", [None, 'C', 'F']) +@pytest.mark.parametrize("contiguity", [None, "C", "F"]) @pytest.mark.parametrize("noconvert", [False, True]) @pytest.mark.filterwarnings( "ignore:Casting complex values to real discards the imaginary part:numpy.ComplexWarning" ) def test_argument_conversions(forcecast, contiguity, noconvert): function_name = "accept_double" - if contiguity == 'C': + if contiguity == "C": function_name += "_c_style" - elif contiguity == 'F': + elif contiguity == "F": function_name += "_f_style" if forcecast: function_name += "_forcecast" @@ -453,37 +552,39 @@ def test_argument_conversions(forcecast, contiguity, noconvert): function_name += "_noconvert" function = getattr(m, function_name) - for dtype in [np.dtype('float32'), np.dtype('float64'), np.dtype('complex128')]: - for order in ['C', 'F']: + for dtype in [np.dtype("float32"), np.dtype("float64"), np.dtype("complex128")]: + for order in ["C", "F"]: for shape in [(2, 2), (1, 3, 1, 1), (1, 1, 1), (0,)]: if not noconvert: # If noconvert is not passed, only complex128 needs to be truncated and # "cannot be safely obtained". So without `forcecast`, the argument shouldn't # be accepted. - should_raise = dtype.name == 'complex128' and not forcecast + should_raise = dtype.name == "complex128" and not forcecast else: # If noconvert is passed, only float64 and the matching order is accepted. # If at most one dimension has a size greater than 1, the array is also # trivially contiguous. trivially_contiguous = sum(1 for d in shape if d > 1) <= 1 - should_raise = ( - dtype.name != 'float64' or - (contiguity is not None and - contiguity != order and - not trivially_contiguous) + should_raise = dtype.name != "float64" or ( + contiguity is not None + and contiguity != order + and not trivially_contiguous ) array = np.zeros(shape, dtype=dtype, order=order) if not should_raise: function(array) else: - with pytest.raises(TypeError, match="incompatible function arguments"): + with pytest.raises( + TypeError, match="incompatible function arguments" + ): function(array) @pytest.mark.xfail("env.PYPY") def test_dtype_refcount_leak(): from sys import getrefcount + dtype = np.dtype(np.float_) a = np.array([1], dtype=dtype) before = getrefcount(dtype) diff --git a/wrap/pybind11/tests/test_numpy_dtypes.cpp b/wrap/pybind11/tests/test_numpy_dtypes.cpp index 467e0253f..bf4f4cee7 100644 --- a/wrap/pybind11/tests/test_numpy_dtypes.cpp +++ b/wrap/pybind11/tests/test_numpy_dtypes.cpp @@ -108,9 +108,11 @@ PYBIND11_PACKED(struct EnumStruct { std::ostream& operator<<(std::ostream& os, const StringStruct& v) { os << "a='"; - for (size_t i = 0; i < 3 && v.a[i]; i++) os << v.a[i]; + for (size_t i = 0; i < 3 && (v.a[i] != 0); i++) + os << v.a[i]; os << "',b='"; - for (size_t i = 0; i < 3 && v.b[i]; i++) os << v.b[i]; + for (size_t i = 0; i < 3 && (v.b[i] != 0); i++) + os << v.b[i]; return os << "'"; } @@ -146,11 +148,13 @@ py::array mkarray_via_buffer(size_t n) { 1, { n }, { sizeof(T) })); } -#define SET_TEST_VALS(s, i) do { \ - s.bool_ = (i) % 2 != 0; \ - s.uint_ = (uint32_t) (i); \ - s.float_ = (float) (i) * 1.5f; \ - s.ldbl_ = (long double) (i) * -2.5L; } while (0) +#define SET_TEST_VALS(s, i) \ + do { \ + (s).bool_ = (i) % 2 != 0; \ + (s).uint_ = (uint32_t) (i); \ + (s).float_ = (float) (i) *1.5f; \ + (s).ldbl_ = (long double) (i) * -2.5L; \ + } while (0) template py::array_t create_recarray(size_t n) { @@ -168,7 +172,7 @@ py::list print_recarray(py::array_t arr) { const auto req = arr.request(); const auto ptr = static_cast(req.ptr); auto l = py::list(); - for (ssize_t i = 0; i < req.size; i++) { + for (py::ssize_t i = 0; i < req.size; i++) { std::stringstream ss; ss << ptr[i]; l.append(py::str(ss.str())); @@ -180,8 +184,8 @@ py::array_t test_array_ctors(int i) { using arr_t = py::array_t; std::vector data { 1, 2, 3, 4, 5, 6 }; - std::vector shape { 3, 2 }; - std::vector strides { 8, 4 }; + std::vector shape { 3, 2 }; + std::vector strides { 8, 4 }; auto ptr = data.data(); auto vptr = (void *) ptr; @@ -255,11 +259,31 @@ struct A {}; struct B {}; TEST_SUBMODULE(numpy_dtypes, m) { - try { py::module::import("numpy"); } + try { py::module_::import("numpy"); } catch (...) { return; } // typeinfo may be registered before the dtype descriptor for scalar casts to work... - py::class_(m, "SimpleStruct"); + py::class_(m, "SimpleStruct") + // Explicit construct to ensure zero-valued initialization. + .def(py::init([]() { return SimpleStruct(); })) + .def_readwrite("bool_", &SimpleStruct::bool_) + .def_readwrite("uint_", &SimpleStruct::uint_) + .def_readwrite("float_", &SimpleStruct::float_) + .def_readwrite("ldbl_", &SimpleStruct::ldbl_) + .def("astuple", + [](const SimpleStruct &self) { + return py::make_tuple(self.bool_, self.uint_, self.float_, self.ldbl_); + }) + .def_static("fromtuple", [](const py::tuple &tup) { + if (py::len(tup) != 4) { + throw py::cast_error("Invalid size"); + } + return SimpleStruct{ + tup[0].cast(), + tup[1].cast(), + tup[2].cast(), + tup[3].cast()}; + }); PYBIND11_NUMPY_DTYPE(SimpleStruct, bool_, uint_, float_, ldbl_); PYBIND11_NUMPY_DTYPE(SimpleStructReordered, bool_, uint_, float_, ldbl_); @@ -339,6 +363,14 @@ TEST_SUBMODULE(numpy_dtypes, m) { }); // test_dtype + std::vector dtype_names{ + "byte", "short", "intc", "int_", "longlong", + "ubyte", "ushort", "uintc", "uint", "ulonglong", + "half", "single", "double", "longdouble", + "csingle", "cdouble", "clongdouble", + "bool_", "datetime64", "timedelta64", "object_" + }; + m.def("print_dtypes", []() { py::list l; for (const py::handle &d : { @@ -357,6 +389,18 @@ TEST_SUBMODULE(numpy_dtypes, m) { return l; }); m.def("test_dtype_ctors", &test_dtype_ctors); + m.def("test_dtype_kind", [dtype_names]() { + py::list list; + for (auto& dt_name : dtype_names) + list.append(py::dtype(dt_name).kind()); + return list; + }); + m.def("test_dtype_char_", [dtype_names]() { + py::list list; + for (auto& dt_name : dtype_names) + list.append(py::dtype(dt_name).char_()); + return list; + }); m.def("test_dtype_methods", []() { py::list list; auto dt1 = py::dtype::of(); @@ -379,7 +423,7 @@ TEST_SUBMODULE(numpy_dtypes, m) { if (non_empty) { auto req = arr.request(); auto ptr = static_cast(req.ptr); - for (ssize_t i = 0; i < req.size * req.itemsize; i++) + for (py::ssize_t i = 0; i < req.size * req.itemsize; i++) static_cast(req.ptr)[i] = 0; ptr[1].a[0] = 'a'; ptr[1].b[0] = 'a'; ptr[2].a[0] = 'a'; ptr[2].b[0] = 'a'; @@ -462,10 +506,16 @@ TEST_SUBMODULE(numpy_dtypes, m) { m.def("buffer_to_dtype", [](py::buffer& buf) { return py::dtype(buf.request()); }); // test_scalar_conversion - m.def("f_simple", [](SimpleStruct s) { return s.uint_ * 10; }); + auto f_simple = [](SimpleStruct s) { return s.uint_ * 10; }; + m.def("f_simple", f_simple); m.def("f_packed", [](PackedStruct s) { return s.uint_ * 10; }); m.def("f_nested", [](NestedStruct s) { return s.a.uint_ * 10; }); + // test_vectorize + m.def("f_simple_vectorized", py::vectorize(f_simple)); + auto f_simple_pass_thru = [](SimpleStruct s) { return s; }; + m.def("f_simple_pass_thru_vectorized", py::vectorize(f_simple_pass_thru)); + // test_register_dtype m.def("register_dtype", []() { PYBIND11_NUMPY_DTYPE(SimpleStruct, bool_, uint_, float_, ldbl_); }); diff --git a/wrap/pybind11/tests/test_numpy_dtypes.py b/wrap/pybind11/tests/test_numpy_dtypes.py index 417d6f1cf..06e578329 100644 --- a/wrap/pybind11/tests/test_numpy_dtypes.py +++ b/wrap/pybind11/tests/test_numpy_dtypes.py @@ -4,63 +4,82 @@ import re import pytest import env # noqa: F401 - from pybind11_tests import numpy_dtypes as m np = pytest.importorskip("numpy") -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def simple_dtype(): - ld = np.dtype('longdouble') - return np.dtype({'names': ['bool_', 'uint_', 'float_', 'ldbl_'], - 'formats': ['?', 'u4', 'f4', 'f{}'.format(ld.itemsize)], - 'offsets': [0, 4, 8, (16 if ld.alignment > 4 else 12)]}) + ld = np.dtype("longdouble") + return np.dtype( + { + "names": ["bool_", "uint_", "float_", "ldbl_"], + "formats": ["?", "u4", "f4", "f{}".format(ld.itemsize)], + "offsets": [0, 4, 8, (16 if ld.alignment > 4 else 12)], + } + ) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def packed_dtype(): - return np.dtype([('bool_', '?'), ('uint_', 'u4'), ('float_', 'f4'), ('ldbl_', 'g')]) + return np.dtype([("bool_", "?"), ("uint_", "u4"), ("float_", "f4"), ("ldbl_", "g")]) def dt_fmt(): from sys import byteorder - e = '<' if byteorder == 'little' else '>' - return ("{{'names':['bool_','uint_','float_','ldbl_']," - " 'formats':['?','" + e + "u4','" + e + "f4','" + e + "f{}']," - " 'offsets':[0,4,8,{}], 'itemsize':{}}}") + + e = "<" if byteorder == "little" else ">" + return ( + "{{'names':['bool_','uint_','float_','ldbl_']," + " 'formats':['?','" + e + "u4','" + e + "f4','" + e + "f{}']," + " 'offsets':[0,4,8,{}], 'itemsize':{}}}" + ) def simple_dtype_fmt(): - ld = np.dtype('longdouble') + ld = np.dtype("longdouble") simple_ld_off = 12 + 4 * (ld.alignment > 4) return dt_fmt().format(ld.itemsize, simple_ld_off, simple_ld_off + ld.itemsize) def packed_dtype_fmt(): from sys import byteorder + return "[('bool_', '?'), ('uint_', '{e}u4'), ('float_', '{e}f4'), ('ldbl_', '{e}f{}')]".format( - np.dtype('longdouble').itemsize, e='<' if byteorder == 'little' else '>') + np.dtype("longdouble").itemsize, e="<" if byteorder == "little" else ">" + ) def partial_ld_offset(): - return 12 + 4 * (np.dtype('uint64').alignment > 4) + 8 + 8 * ( - np.dtype('longdouble').alignment > 8) + return ( + 12 + + 4 * (np.dtype("uint64").alignment > 4) + + 8 + + 8 * (np.dtype("longdouble").alignment > 8) + ) def partial_dtype_fmt(): - ld = np.dtype('longdouble') + ld = np.dtype("longdouble") partial_ld_off = partial_ld_offset() - return dt_fmt().format(ld.itemsize, partial_ld_off, partial_ld_off + ld.itemsize) + partial_size = partial_ld_off + ld.itemsize + partial_end_padding = partial_size % np.dtype("uint64").alignment + return dt_fmt().format( + ld.itemsize, partial_ld_off, partial_size + partial_end_padding + ) def partial_nested_fmt(): - ld = np.dtype('longdouble') + ld = np.dtype("longdouble") partial_nested_off = 8 + 8 * (ld.alignment > 8) partial_ld_off = partial_ld_offset() - partial_nested_size = partial_nested_off * 2 + partial_ld_off + ld.itemsize + partial_size = partial_ld_off + ld.itemsize + partial_end_padding = partial_size % np.dtype("uint64").alignment + partial_nested_size = partial_nested_off * 2 + partial_size + partial_end_padding return "{{'names':['a'], 'formats':[{}], 'offsets':[{}], 'itemsize':{}}}".format( - partial_dtype_fmt(), partial_nested_off, partial_nested_size) + partial_dtype_fmt(), partial_nested_off, partial_nested_size + ) def assert_equal(actual, expected_data, expected_dtype): @@ -70,15 +89,21 @@ def assert_equal(actual, expected_data, expected_dtype): def test_format_descriptors(): with pytest.raises(RuntimeError) as excinfo: m.get_format_unbound() - assert re.match('^NumPy type info missing for .*UnboundStruct.*$', str(excinfo.value)) + assert re.match( + "^NumPy type info missing for .*UnboundStruct.*$", str(excinfo.value) + ) - ld = np.dtype('longdouble') - ldbl_fmt = ('4x' if ld.alignment > 4 else '') + ld.char + ld = np.dtype("longdouble") + ldbl_fmt = ("4x" if ld.alignment > 4 else "") + ld.char ss_fmt = "^T{?:bool_:3xI:uint_:f:float_:" + ldbl_fmt + ":ldbl_:}" - dbl = np.dtype('double') - partial_fmt = ("^T{?:bool_:3xI:uint_:f:float_:" + - str(4 * (dbl.alignment > 4) + dbl.itemsize + 8 * (ld.alignment > 8)) + - "xg:ldbl_:}") + dbl = np.dtype("double") + end_padding = ld.itemsize % np.dtype("uint64").alignment + partial_fmt = ( + "^T{?:bool_:3xI:uint_:f:float_:" + + str(4 * (dbl.alignment > 4) + dbl.itemsize + 8 * (ld.alignment > 8)) + + "xg:ldbl_:" + + (str(end_padding) + "x}" if end_padding > 0 else "}") + ) nested_extra = str(max(8, ld.alignment)) assert m.print_format_descriptors() == [ ss_fmt, @@ -88,14 +113,15 @@ def test_format_descriptors(): "^T{" + nested_extra + "x" + partial_fmt + ":a:" + nested_extra + "x}", "^T{3s:a:3s:b:}", "^T{(3)4s:a:(2)i:b:(3)B:c:1x(4, 2)f:d:}", - '^T{q:e1:B:e2:}', - '^T{Zf:cflt:Zd:cdbl:}' + "^T{q:e1:B:e2:}", + "^T{Zf:cflt:Zd:cdbl:}", ] def test_dtype(simple_dtype): from sys import byteorder - e = '<' if byteorder == 'little' else '>' + + e = "<" if byteorder == "little" else ">" assert m.print_dtypes() == [ simple_dtype_fmt(), @@ -104,30 +130,63 @@ def test_dtype(simple_dtype): partial_dtype_fmt(), partial_nested_fmt(), "[('a', 'S3'), ('b', 'S3')]", - ("{{'names':['a','b','c','d'], " + - "'formats':[('S4', (3,)),('" + e + "i4', (2,)),('u1', (3,)),('" + e + "f4', (4, 2))], " + - "'offsets':[0,12,20,24], 'itemsize':56}}").format(e=e), + ( + "{{'names':['a','b','c','d'], " + + "'formats':[('S4', (3,)),('" + + e + + "i4', (2,)),('u1', (3,)),('" + + e + + "f4', (4, 2))], " + + "'offsets':[0,12,20,24], 'itemsize':56}}" + ).format(e=e), "[('e1', '" + e + "i8'), ('e2', 'u1')]", "[('x', 'i1'), ('y', '" + e + "u8')]", - "[('cflt', '" + e + "c8'), ('cdbl', '" + e + "c16')]" + "[('cflt', '" + e + "c8'), ('cdbl', '" + e + "c16')]", ] - d1 = np.dtype({'names': ['a', 'b'], 'formats': ['int32', 'float64'], - 'offsets': [1, 10], 'itemsize': 20}) - d2 = np.dtype([('a', 'i4'), ('b', 'f4')]) - assert m.test_dtype_ctors() == [np.dtype('int32'), np.dtype('float64'), - np.dtype('bool'), d1, d1, np.dtype('uint32'), d2] + d1 = np.dtype( + { + "names": ["a", "b"], + "formats": ["int32", "float64"], + "offsets": [1, 10], + "itemsize": 20, + } + ) + d2 = np.dtype([("a", "i4"), ("b", "f4")]) + assert m.test_dtype_ctors() == [ + np.dtype("int32"), + np.dtype("float64"), + np.dtype("bool"), + d1, + d1, + np.dtype("uint32"), + d2, + ] - assert m.test_dtype_methods() == [np.dtype('int32'), simple_dtype, False, True, - np.dtype('int32').itemsize, simple_dtype.itemsize] + assert m.test_dtype_methods() == [ + np.dtype("int32"), + simple_dtype, + False, + True, + np.dtype("int32").itemsize, + simple_dtype.itemsize, + ] - assert m.trailing_padding_dtype() == m.buffer_to_dtype(np.zeros(1, m.trailing_padding_dtype())) + assert m.trailing_padding_dtype() == m.buffer_to_dtype( + np.zeros(1, m.trailing_padding_dtype()) + ) + + assert m.test_dtype_kind() == list("iiiiiuuuuuffffcccbMmO") + assert m.test_dtype_char_() == list("bhilqBHILQefdgFDG?MmO") def test_recarray(simple_dtype, packed_dtype): elements = [(False, 0, 0.0, -0.0), (True, 1, 1.5, -2.5), (False, 2, 3.0, -5.0)] - for func, dtype in [(m.create_rec_simple, simple_dtype), (m.create_rec_packed, packed_dtype)]: + for func, dtype in [ + (m.create_rec_simple, simple_dtype), + (m.create_rec_packed, packed_dtype), + ]: arr = func(0) assert arr.dtype == dtype assert_equal(arr, [], simple_dtype) @@ -138,20 +197,24 @@ def test_recarray(simple_dtype, packed_dtype): assert_equal(arr, elements, simple_dtype) assert_equal(arr, elements, packed_dtype) + # Show what recarray's look like in NumPy. + assert type(arr[0]) == np.void + assert type(arr[0].item()) == tuple + if dtype == simple_dtype: assert m.print_rec_simple(arr) == [ "s:0,0,0,-0", "s:1,1,1.5,-2.5", - "s:0,2,3,-5" + "s:0,2,3,-5", ] else: assert m.print_rec_packed(arr) == [ "p:0,0,0,-0", "p:1,1,1.5,-2.5", - "p:0,2,3,-5" + "p:0,2,3,-5", ] - nested_dtype = np.dtype([('a', simple_dtype), ('b', packed_dtype)]) + nested_dtype = np.dtype([("a", simple_dtype), ("b", packed_dtype)]) arr = m.create_rec_nested(0) assert arr.dtype == nested_dtype @@ -159,33 +222,39 @@ def test_recarray(simple_dtype, packed_dtype): arr = m.create_rec_nested(3) assert arr.dtype == nested_dtype - assert_equal(arr, [((False, 0, 0.0, -0.0), (True, 1, 1.5, -2.5)), - ((True, 1, 1.5, -2.5), (False, 2, 3.0, -5.0)), - ((False, 2, 3.0, -5.0), (True, 3, 4.5, -7.5))], nested_dtype) + assert_equal( + arr, + [ + ((False, 0, 0.0, -0.0), (True, 1, 1.5, -2.5)), + ((True, 1, 1.5, -2.5), (False, 2, 3.0, -5.0)), + ((False, 2, 3.0, -5.0), (True, 3, 4.5, -7.5)), + ], + nested_dtype, + ) assert m.print_rec_nested(arr) == [ "n:a=s:0,0,0,-0;b=p:1,1,1.5,-2.5", "n:a=s:1,1,1.5,-2.5;b=p:0,2,3,-5", - "n:a=s:0,2,3,-5;b=p:1,3,4.5,-7.5" + "n:a=s:0,2,3,-5;b=p:1,3,4.5,-7.5", ] arr = m.create_rec_partial(3) assert str(arr.dtype) == partial_dtype_fmt() partial_dtype = arr.dtype - assert '' not in arr.dtype.fields + assert "" not in arr.dtype.fields assert partial_dtype.itemsize > simple_dtype.itemsize assert_equal(arr, elements, simple_dtype) assert_equal(arr, elements, packed_dtype) arr = m.create_rec_partial_nested(3) assert str(arr.dtype) == partial_nested_fmt() - assert '' not in arr.dtype.fields - assert '' not in arr.dtype.fields['a'][0].fields + assert "" not in arr.dtype.fields + assert "" not in arr.dtype.fields["a"][0].fields assert arr.dtype.itemsize > partial_dtype.itemsize - np.testing.assert_equal(arr['a'], m.create_rec_partial(3)) + np.testing.assert_equal(arr["a"], m.create_rec_partial(3)) def test_array_constructors(): - data = np.arange(1, 7, dtype='int32') + data = np.arange(1, 7, dtype="int32") for i in range(8): np.testing.assert_array_equal(m.test_array_ctors(10 + i), data.reshape((3, 2))) np.testing.assert_array_equal(m.test_array_ctors(20 + i), data.reshape((3, 2))) @@ -201,82 +270,92 @@ def test_string_array(): "a='',b=''", "a='a',b='a'", "a='ab',b='ab'", - "a='abc',b='abc'" + "a='abc',b='abc'", ] dtype = arr.dtype - assert arr['a'].tolist() == [b'', b'a', b'ab', b'abc'] - assert arr['b'].tolist() == [b'', b'a', b'ab', b'abc'] + assert arr["a"].tolist() == [b"", b"a", b"ab", b"abc"] + assert arr["b"].tolist() == [b"", b"a", b"ab", b"abc"] arr = m.create_string_array(False) assert dtype == arr.dtype def test_array_array(): from sys import byteorder - e = '<' if byteorder == 'little' else '>' + + e = "<" if byteorder == "little" else ">" arr = m.create_array_array(3) assert str(arr.dtype) == ( - "{{'names':['a','b','c','d'], " + - "'formats':[('S4', (3,)),('" + e + "i4', (2,)),('u1', (3,)),('{e}f4', (4, 2))], " + - "'offsets':[0,12,20,24], 'itemsize':56}}").format(e=e) + "{{'names':['a','b','c','d'], " + + "'formats':[('S4', (3,)),('" + + e + + "i4', (2,)),('u1', (3,)),('{e}f4', (4, 2))], " + + "'offsets':[0,12,20,24], 'itemsize':56}}" + ).format(e=e) assert m.print_array_array(arr) == [ - "a={{A,B,C,D},{K,L,M,N},{U,V,W,X}},b={0,1}," + - "c={0,1,2},d={{0,1},{10,11},{20,21},{30,31}}", - "a={{W,X,Y,Z},{G,H,I,J},{Q,R,S,T}},b={1000,1001}," + - "c={10,11,12},d={{100,101},{110,111},{120,121},{130,131}}", - "a={{S,T,U,V},{C,D,E,F},{M,N,O,P}},b={2000,2001}," + - "c={20,21,22},d={{200,201},{210,211},{220,221},{230,231}}", + "a={{A,B,C,D},{K,L,M,N},{U,V,W,X}},b={0,1}," + + "c={0,1,2},d={{0,1},{10,11},{20,21},{30,31}}", + "a={{W,X,Y,Z},{G,H,I,J},{Q,R,S,T}},b={1000,1001}," + + "c={10,11,12},d={{100,101},{110,111},{120,121},{130,131}}", + "a={{S,T,U,V},{C,D,E,F},{M,N,O,P}},b={2000,2001}," + + "c={20,21,22},d={{200,201},{210,211},{220,221},{230,231}}", ] - assert arr['a'].tolist() == [[b'ABCD', b'KLMN', b'UVWX'], - [b'WXYZ', b'GHIJ', b'QRST'], - [b'STUV', b'CDEF', b'MNOP']] - assert arr['b'].tolist() == [[0, 1], [1000, 1001], [2000, 2001]] + assert arr["a"].tolist() == [ + [b"ABCD", b"KLMN", b"UVWX"], + [b"WXYZ", b"GHIJ", b"QRST"], + [b"STUV", b"CDEF", b"MNOP"], + ] + assert arr["b"].tolist() == [[0, 1], [1000, 1001], [2000, 2001]] assert m.create_array_array(0).dtype == arr.dtype def test_enum_array(): from sys import byteorder - e = '<' if byteorder == 'little' else '>' + + e = "<" if byteorder == "little" else ">" arr = m.create_enum_array(3) dtype = arr.dtype - assert dtype == np.dtype([('e1', e + 'i8'), ('e2', 'u1')]) - assert m.print_enum_array(arr) == [ - "e1=A,e2=X", - "e1=B,e2=Y", - "e1=A,e2=X" - ] - assert arr['e1'].tolist() == [-1, 1, -1] - assert arr['e2'].tolist() == [1, 2, 1] + assert dtype == np.dtype([("e1", e + "i8"), ("e2", "u1")]) + assert m.print_enum_array(arr) == ["e1=A,e2=X", "e1=B,e2=Y", "e1=A,e2=X"] + assert arr["e1"].tolist() == [-1, 1, -1] + assert arr["e2"].tolist() == [1, 2, 1] assert m.create_enum_array(0).dtype == dtype def test_complex_array(): from sys import byteorder - e = '<' if byteorder == 'little' else '>' + + e = "<" if byteorder == "little" else ">" arr = m.create_complex_array(3) dtype = arr.dtype - assert dtype == np.dtype([('cflt', e + 'c8'), ('cdbl', e + 'c16')]) + assert dtype == np.dtype([("cflt", e + "c8"), ("cdbl", e + "c16")]) assert m.print_complex_array(arr) == [ "c:(0,0.25),(0.5,0.75)", "c:(1,1.25),(1.5,1.75)", - "c:(2,2.25),(2.5,2.75)" + "c:(2,2.25),(2.5,2.75)", ] - assert arr['cflt'].tolist() == [0.0 + 0.25j, 1.0 + 1.25j, 2.0 + 2.25j] - assert arr['cdbl'].tolist() == [0.5 + 0.75j, 1.5 + 1.75j, 2.5 + 2.75j] + assert arr["cflt"].tolist() == [0.0 + 0.25j, 1.0 + 1.25j, 2.0 + 2.25j] + assert arr["cdbl"].tolist() == [0.5 + 0.75j, 1.5 + 1.75j, 2.5 + 2.75j] assert m.create_complex_array(0).dtype == dtype def test_signature(doc): - assert doc(m.create_rec_nested) == \ - "create_rec_nested(arg0: int) -> numpy.ndarray[NestedStruct]" + assert ( + doc(m.create_rec_nested) + == "create_rec_nested(arg0: int) -> numpy.ndarray[NestedStruct]" + ) def test_scalar_conversion(): n = 3 - arrays = [m.create_rec_simple(n), m.create_rec_packed(n), - m.create_rec_nested(n), m.create_enum_array(n)] + arrays = [ + m.create_rec_simple(n), + m.create_rec_packed(n), + m.create_rec_nested(n), + m.create_enum_array(n), + ] funcs = [m.f_simple, m.f_packed, m.f_nested] for i, func in enumerate(funcs): @@ -286,18 +365,68 @@ def test_scalar_conversion(): else: with pytest.raises(TypeError) as excinfo: func(arr[0]) - assert 'incompatible function arguments' in str(excinfo.value) + assert "incompatible function arguments" in str(excinfo.value) + + +def test_vectorize(): + n = 3 + array = m.create_rec_simple(n) + values = m.f_simple_vectorized(array) + np.testing.assert_array_equal(values, [0, 10, 20]) + array_2 = m.f_simple_pass_thru_vectorized(array) + np.testing.assert_array_equal(array, array_2) + + +def test_cls_and_dtype_conversion(simple_dtype): + s = m.SimpleStruct() + assert s.astuple() == (False, 0, 0.0, 0.0) + assert m.SimpleStruct.fromtuple(s.astuple()).astuple() == s.astuple() + + s.uint_ = 2 + assert m.f_simple(s) == 20 + + # Try as recarray of shape==(1,). + s_recarray = np.array([(False, 2, 0.0, 0.0)], dtype=simple_dtype) + # Show that this will work for vectorized case. + np.testing.assert_array_equal(m.f_simple_vectorized(s_recarray), [20]) + + # Show as a scalar that inherits from np.generic. + s_scalar = s_recarray[0] + assert isinstance(s_scalar, np.void) + assert m.f_simple(s_scalar) == 20 + + # Show that an *array* scalar (np.ndarray.shape == ()) does not convert. + # More specifically, conversion to SimpleStruct is not implicit. + s_recarray_scalar = s_recarray.reshape(()) + assert isinstance(s_recarray_scalar, np.ndarray) + assert s_recarray_scalar.dtype == simple_dtype + with pytest.raises(TypeError) as excinfo: + m.f_simple(s_recarray_scalar) + assert "incompatible function arguments" in str(excinfo.value) + # Explicitly convert to m.SimpleStruct. + assert m.f_simple(m.SimpleStruct.fromtuple(s_recarray_scalar.item())) == 20 + + # Show that an array of dtype=object does *not* convert. + s_array_object = np.array([s]) + assert s_array_object.dtype == object + with pytest.raises(TypeError) as excinfo: + m.f_simple_vectorized(s_array_object) + assert "incompatible function arguments" in str(excinfo.value) + # Explicitly convert to `np.array(..., dtype=simple_dtype)` + s_array = np.array([s.astuple()], dtype=simple_dtype) + np.testing.assert_array_equal(m.f_simple_vectorized(s_array), [20]) def test_register_dtype(): with pytest.raises(RuntimeError) as excinfo: m.register_dtype() - assert 'dtype is already registered' in str(excinfo.value) + assert "dtype is already registered" in str(excinfo.value) @pytest.mark.xfail("env.PYPY") def test_str_leak(): from sys import getrefcount + fmt = "f4" pytest.gc_collect() start = getrefcount(fmt) diff --git a/wrap/pybind11/tests/test_numpy_vectorize.cpp b/wrap/pybind11/tests/test_numpy_vectorize.cpp index e76e462cb..eb5281fb1 100644 --- a/wrap/pybind11/tests/test_numpy_vectorize.cpp +++ b/wrap/pybind11/tests/test_numpy_vectorize.cpp @@ -11,13 +11,15 @@ #include "pybind11_tests.h" #include +#include + double my_func(int x, float y, double z) { py::print("my_func(x:int={}, y:float={:.0f}, z:float={:.0f})"_s.format(x, y, z)); return (float) x*y*z; } TEST_SUBMODULE(numpy_vectorize, m) { - try { py::module::import("numpy"); } + try { py::module_::import("numpy"); } catch (...) { return; } // test_vectorize, test_docs, test_array_collapse @@ -25,11 +27,10 @@ TEST_SUBMODULE(numpy_vectorize, m) { m.def("vectorized_func", py::vectorize(my_func)); // Vectorize a lambda function with a capture object (e.g. to exclude some arguments from the vectorization) - m.def("vectorized_func2", - [](py::array_t x, py::array_t y, float z) { - return py::vectorize([z](int x, float y) { return my_func(x, y, z); })(x, y); - } - ); + m.def("vectorized_func2", [](py::array_t x, py::array_t y, float z) { + return py::vectorize([z](int x, float y) { return my_func(x, y, z); })(std::move(x), + std::move(y)); + }); // Vectorize a complex-valued function m.def("vectorized_func3", py::vectorize( @@ -38,29 +39,40 @@ TEST_SUBMODULE(numpy_vectorize, m) { // test_type_selection // NumPy function which only accepts specific data types - m.def("selective_func", [](py::array_t) { return "Int branch taken."; }); - m.def("selective_func", [](py::array_t) { return "Float branch taken."; }); - m.def("selective_func", [](py::array_t, py::array::c_style>) { return "Complex float branch taken."; }); - + // A lot of these no lints could be replaced with const refs, and probably should at some point. + m.def("selective_func", + [](const py::array_t &) { return "Int branch taken."; }); + m.def("selective_func", + [](const py::array_t &) { return "Float branch taken."; }); + m.def("selective_func", [](const py::array_t, py::array::c_style> &) { + return "Complex float branch taken."; + }); // test_passthrough_arguments // Passthrough test: references and non-pod types should be automatically passed through (in the // function definition below, only `b`, `d`, and `g` are vectorized): struct NonPODClass { - NonPODClass(int v) : value{v} {} + explicit NonPODClass(int v) : value{v} {} int value; }; - py::class_(m, "NonPODClass").def(py::init()); - m.def("vec_passthrough", py::vectorize( - [](double *a, double b, py::array_t c, const int &d, int &e, NonPODClass f, const double g) { - return *a + b + c.at(0) + d + e + f.value + g; - } - )); + py::class_(m, "NonPODClass") + .def(py::init()) + .def_readwrite("value", &NonPODClass::value); + m.def("vec_passthrough", + py::vectorize([](const double *a, + double b, + // Changing this broke things + // NOLINTNEXTLINE(performance-unnecessary-value-param) + py::array_t c, + const int &d, + int &e, + NonPODClass f, + const double g) { return *a + b + c.at(0) + d + e + f.value + g; })); // test_method_vectorization struct VectorizeTestClass { - VectorizeTestClass(int v) : value{v} {}; - float method(int x, float y) { return y + (float) (x + value); } + explicit VectorizeTestClass(int v) : value{v} {}; + float method(int x, float y) const { return y + (float) (x + value); } int value = 0; }; py::class_ vtc(m, "VectorizeTestClass"); @@ -76,14 +88,16 @@ TEST_SUBMODULE(numpy_vectorize, m) { .value("f_trivial", py::detail::broadcast_trivial::f_trivial) .value("c_trivial", py::detail::broadcast_trivial::c_trivial) .value("non_trivial", py::detail::broadcast_trivial::non_trivial); - m.def("vectorized_is_trivial", []( - py::array_t arg1, - py::array_t arg2, - py::array_t arg3 - ) { - ssize_t ndim; - std::vector shape; - std::array buffers {{ arg1.request(), arg2.request(), arg3.request() }}; - return py::detail::broadcast(buffers, ndim, shape); - }); + m.def("vectorized_is_trivial", + [](const py::array_t &arg1, + const py::array_t &arg2, + const py::array_t &arg3) { + py::ssize_t ndim = 0; + std::vector shape; + std::array buffers{ + {arg1.request(), arg2.request(), arg3.request()}}; + return py::detail::broadcast(buffers, ndim, shape); + }); + + m.def("add_to", py::vectorize([](NonPODClass& x, int a) { x.value += a; })); } diff --git a/wrap/pybind11/tests/test_numpy_vectorize.py b/wrap/pybind11/tests/test_numpy_vectorize.py index 54e44cd8d..de5c9a607 100644 --- a/wrap/pybind11/tests/test_numpy_vectorize.py +++ b/wrap/pybind11/tests/test_numpy_vectorize.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- import pytest + from pybind11_tests import numpy_vectorize as m np = pytest.importorskip("numpy") @@ -17,28 +18,40 @@ def test_vectorize(capture): assert capture == "my_func(x:int=1, y:float=2, z:float=3)" with capture: assert np.allclose(f(np.array([1, 3]), np.array([2, 4]), 3), [6, 36]) - assert capture == """ + assert ( + capture + == """ my_func(x:int=1, y:float=2, z:float=3) my_func(x:int=3, y:float=4, z:float=3) """ + ) with capture: - a = np.array([[1, 2], [3, 4]], order='F') - b = np.array([[10, 20], [30, 40]], order='F') + a = np.array([[1, 2], [3, 4]], order="F") + b = np.array([[10, 20], [30, 40]], order="F") c = 3 result = f(a, b, c) assert np.allclose(result, a * b * c) assert result.flags.f_contiguous # All inputs are F order and full or singletons, so we the result is in col-major order: - assert capture == """ + assert ( + capture + == """ my_func(x:int=1, y:float=10, z:float=3) my_func(x:int=3, y:float=30, z:float=3) my_func(x:int=2, y:float=20, z:float=3) my_func(x:int=4, y:float=40, z:float=3) """ + ) with capture: - a, b, c = np.array([[1, 3, 5], [7, 9, 11]]), np.array([[2, 4, 6], [8, 10, 12]]), 3 + a, b, c = ( + np.array([[1, 3, 5], [7, 9, 11]]), + np.array([[2, 4, 6], [8, 10, 12]]), + 3, + ) assert np.allclose(f(a, b, c), a * b * c) - assert capture == """ + assert ( + capture + == """ my_func(x:int=1, y:float=2, z:float=3) my_func(x:int=3, y:float=4, z:float=3) my_func(x:int=5, y:float=6, z:float=3) @@ -46,10 +59,13 @@ def test_vectorize(capture): my_func(x:int=9, y:float=10, z:float=3) my_func(x:int=11, y:float=12, z:float=3) """ + ) with capture: a, b, c = np.array([[1, 2, 3], [4, 5, 6]]), np.array([2, 3, 4]), 2 assert np.allclose(f(a, b, c), a * b * c) - assert capture == """ + assert ( + capture + == """ my_func(x:int=1, y:float=2, z:float=2) my_func(x:int=2, y:float=3, z:float=2) my_func(x:int=3, y:float=4, z:float=2) @@ -57,10 +73,13 @@ def test_vectorize(capture): my_func(x:int=5, y:float=3, z:float=2) my_func(x:int=6, y:float=4, z:float=2) """ + ) with capture: a, b, c = np.array([[1, 2, 3], [4, 5, 6]]), np.array([[2], [3]]), 2 assert np.allclose(f(a, b, c), a * b * c) - assert capture == """ + assert ( + capture + == """ my_func(x:int=1, y:float=2, z:float=2) my_func(x:int=2, y:float=2, z:float=2) my_func(x:int=3, y:float=2, z:float=2) @@ -68,10 +87,17 @@ def test_vectorize(capture): my_func(x:int=5, y:float=3, z:float=2) my_func(x:int=6, y:float=3, z:float=2) """ + ) with capture: - a, b, c = np.array([[1, 2, 3], [4, 5, 6]], order='F'), np.array([[2], [3]]), 2 + a, b, c = ( + np.array([[1, 2, 3], [4, 5, 6]], order="F"), + np.array([[2], [3]]), + 2, + ) assert np.allclose(f(a, b, c), a * b * c) - assert capture == """ + assert ( + capture + == """ my_func(x:int=1, y:float=2, z:float=2) my_func(x:int=2, y:float=2, z:float=2) my_func(x:int=3, y:float=2, z:float=2) @@ -79,36 +105,53 @@ def test_vectorize(capture): my_func(x:int=5, y:float=3, z:float=2) my_func(x:int=6, y:float=3, z:float=2) """ + ) with capture: a, b, c = np.array([[1, 2, 3], [4, 5, 6]])[::, ::2], np.array([[2], [3]]), 2 assert np.allclose(f(a, b, c), a * b * c) - assert capture == """ + assert ( + capture + == """ my_func(x:int=1, y:float=2, z:float=2) my_func(x:int=3, y:float=2, z:float=2) my_func(x:int=4, y:float=3, z:float=2) my_func(x:int=6, y:float=3, z:float=2) """ + ) with capture: - a, b, c = np.array([[1, 2, 3], [4, 5, 6]], order='F')[::, ::2], np.array([[2], [3]]), 2 + a, b, c = ( + np.array([[1, 2, 3], [4, 5, 6]], order="F")[::, ::2], + np.array([[2], [3]]), + 2, + ) assert np.allclose(f(a, b, c), a * b * c) - assert capture == """ + assert ( + capture + == """ my_func(x:int=1, y:float=2, z:float=2) my_func(x:int=3, y:float=2, z:float=2) my_func(x:int=4, y:float=3, z:float=2) my_func(x:int=6, y:float=3, z:float=2) """ + ) def test_type_selection(): assert m.selective_func(np.array([1], dtype=np.int32)) == "Int branch taken." assert m.selective_func(np.array([1.0], dtype=np.float32)) == "Float branch taken." - assert m.selective_func(np.array([1.0j], dtype=np.complex64)) == "Complex float branch taken." + assert ( + m.selective_func(np.array([1.0j], dtype=np.complex64)) + == "Complex float branch taken." + ) def test_docs(doc): - assert doc(m.vectorized_func) == """ + assert ( + doc(m.vectorized_func) + == """ vectorized_func(arg0: numpy.ndarray[numpy.int32], arg1: numpy.ndarray[numpy.float32], arg2: numpy.ndarray[numpy.float64]) -> object """ # noqa: E501 line too long + ) def test_trivial_broadcasting(): @@ -116,16 +159,24 @@ def test_trivial_broadcasting(): assert vectorized_is_trivial(1, 2, 3) == trivial.c_trivial assert vectorized_is_trivial(np.array(1), np.array(2), 3) == trivial.c_trivial - assert vectorized_is_trivial(np.array([1, 3]), np.array([2, 4]), 3) == trivial.c_trivial + assert ( + vectorized_is_trivial(np.array([1, 3]), np.array([2, 4]), 3) + == trivial.c_trivial + ) assert trivial.c_trivial == vectorized_is_trivial( - np.array([[1, 3, 5], [7, 9, 11]]), np.array([[2, 4, 6], [8, 10, 12]]), 3) - assert vectorized_is_trivial( - np.array([[1, 2, 3], [4, 5, 6]]), np.array([2, 3, 4]), 2) == trivial.non_trivial - assert vectorized_is_trivial( - np.array([[1, 2, 3], [4, 5, 6]]), np.array([[2], [3]]), 2) == trivial.non_trivial - z1 = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype='int32') - z2 = np.array(z1, dtype='float32') - z3 = np.array(z1, dtype='float64') + np.array([[1, 3, 5], [7, 9, 11]]), np.array([[2, 4, 6], [8, 10, 12]]), 3 + ) + assert ( + vectorized_is_trivial(np.array([[1, 2, 3], [4, 5, 6]]), np.array([2, 3, 4]), 2) + == trivial.non_trivial + ) + assert ( + vectorized_is_trivial(np.array([[1, 2, 3], [4, 5, 6]]), np.array([[2], [3]]), 2) + == trivial.non_trivial + ) + z1 = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype="int32") + z2 = np.array(z1, dtype="float32") + z3 = np.array(z1, dtype="float64") assert vectorized_is_trivial(z1, z2, z3) == trivial.c_trivial assert vectorized_is_trivial(1, z2, z3) == trivial.c_trivial assert vectorized_is_trivial(z1, 1, z3) == trivial.c_trivial @@ -135,7 +186,7 @@ def test_trivial_broadcasting(): assert vectorized_is_trivial(1, 1, z3[::2, ::2]) == trivial.non_trivial assert vectorized_is_trivial(z1, 1, z3[1::4, 1::4]) == trivial.c_trivial - y1 = np.array(z1, order='F') + y1 = np.array(z1, order="F") y2 = np.array(y1) y3 = np.array(y1) assert vectorized_is_trivial(y1, y2, y3) == trivial.f_trivial @@ -156,30 +207,41 @@ def test_trivial_broadcasting(): def test_passthrough_arguments(doc): assert doc(m.vec_passthrough) == ( - "vec_passthrough(" + ", ".join([ - "arg0: float", - "arg1: numpy.ndarray[numpy.float64]", - "arg2: numpy.ndarray[numpy.float64]", - "arg3: numpy.ndarray[numpy.int32]", - "arg4: int", - "arg5: m.numpy_vectorize.NonPODClass", - "arg6: numpy.ndarray[numpy.float64]"]) + ") -> object") + "vec_passthrough(" + + ", ".join( + [ + "arg0: float", + "arg1: numpy.ndarray[numpy.float64]", + "arg2: numpy.ndarray[numpy.float64]", + "arg3: numpy.ndarray[numpy.int32]", + "arg4: int", + "arg5: m.numpy_vectorize.NonPODClass", + "arg6: numpy.ndarray[numpy.float64]", + ] + ) + + ") -> object" + ) - b = np.array([[10, 20, 30]], dtype='float64') + b = np.array([[10, 20, 30]], dtype="float64") c = np.array([100, 200]) # NOT a vectorized argument - d = np.array([[1000], [2000], [3000]], dtype='int') - g = np.array([[1000000, 2000000, 3000000]], dtype='int') # requires casting + d = np.array([[1000], [2000], [3000]], dtype="int") + g = np.array([[1000000, 2000000, 3000000]], dtype="int") # requires casting assert np.all( - m.vec_passthrough(1, b, c, d, 10000, m.NonPODClass(100000), g) == - np.array([[1111111, 2111121, 3111131], - [1112111, 2112121, 3112131], - [1113111, 2113121, 3113131]])) + m.vec_passthrough(1, b, c, d, 10000, m.NonPODClass(100000), g) + == np.array( + [ + [1111111, 2111121, 3111131], + [1112111, 2112121, 3112131], + [1113111, 2113121, 3113131], + ] + ) + ) def test_method_vectorization(): o = m.VectorizeTestClass(3) - x = np.array([1, 2], dtype='int') - y = np.array([[10], [20]], dtype='float32') + x = np.array([1, 2], dtype="int") + y = np.array([[10], [20]], dtype="float32") assert np.all(o.method(x, y) == [[14, 15], [24, 25]]) @@ -188,7 +250,18 @@ def test_array_collapse(): assert not isinstance(m.vectorized_func(np.array(1), 2, 3), np.ndarray) z = m.vectorized_func([1], 2, 3) assert isinstance(z, np.ndarray) - assert z.shape == (1, ) + assert z.shape == (1,) z = m.vectorized_func(1, [[[2]]], 3) assert isinstance(z, np.ndarray) assert z.shape == (1, 1, 1) + + +def test_vectorized_noreturn(): + x = m.NonPODClass(0) + assert x.value == 0 + m.add_to(x, [1, 2, 3, 4]) + assert x.value == 10 + m.add_to(x, 1) + assert x.value == 11 + m.add_to(x, [[1, 1], [2, 3]]) + assert x.value == 18 diff --git a/wrap/pybind11/tests/test_opaque_types.cpp b/wrap/pybind11/tests/test_opaque_types.cpp index 594c45a08..804de6d4f 100644 --- a/wrap/pybind11/tests/test_opaque_types.cpp +++ b/wrap/pybind11/tests/test_opaque_types.cpp @@ -44,7 +44,7 @@ TEST_SUBMODULE(opaque_types, m) { m.def("print_opaque_list", [](const StringList &l) { std::string ret = "Opaque list: ["; bool first = true; - for (auto entry : l) { + for (const auto &entry : l) { if (!first) ret += ", "; ret += entry; @@ -64,4 +64,10 @@ TEST_SUBMODULE(opaque_types, m) { result->push_back("some value"); return std::unique_ptr(result); }); + + // test unions + py::class_(m, "IntFloat") + .def(py::init<>()) + .def_readwrite("i", &IntFloat::i) + .def_readwrite("f", &IntFloat::f); } diff --git a/wrap/pybind11/tests/test_opaque_types.py b/wrap/pybind11/tests/test_opaque_types.py index 3f2392775..5495cb6b4 100644 --- a/wrap/pybind11/tests/test_opaque_types.py +++ b/wrap/pybind11/tests/test_opaque_types.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- import pytest -from pybind11_tests import opaque_types as m + from pybind11_tests import ConstructorStats, UserType +from pybind11_tests import opaque_types as m def test_string_list(): @@ -32,12 +33,15 @@ def test_pointers(msg): with pytest.raises(TypeError) as excinfo: m.get_void_ptr_value([1, 2, 3]) # This should not work - assert msg(excinfo.value) == """ + assert ( + msg(excinfo.value) + == """ get_void_ptr_value(): incompatible function arguments. The following argument types are supported: 1. (arg0: capsule) -> int Invoked with: [1, 2, 3] """ # noqa: E501 line too long + ) assert m.return_null_str() is None assert m.get_null_str_value(m.return_null_str()) is not None @@ -45,3 +49,11 @@ def test_pointers(msg): ptr = m.return_unique_ptr() assert "StringList" in repr(ptr) assert m.print_opaque_list(ptr) == "Opaque list: [some value]" + + +def test_unions(): + int_float_union = m.IntFloat() + int_float_union.i = 42 + assert int_float_union.i == 42 + int_float_union.f = 3.0 + assert int_float_union.f == 3.0 diff --git a/wrap/pybind11/tests/test_operator_overloading.cpp b/wrap/pybind11/tests/test_operator_overloading.cpp index d55495471..ffa059d5b 100644 --- a/wrap/pybind11/tests/test_operator_overloading.cpp +++ b/wrap/pybind11/tests/test_operator_overloading.cpp @@ -7,18 +7,28 @@ BSD-style license that can be found in the LICENSE file. */ -#include "pybind11_tests.h" #include "constructor_stats.h" -#include +#include "pybind11_tests.h" #include +#include +#include class Vector2 { public: Vector2(float x, float y) : x(x), y(y) { print_created(this, toString()); } Vector2(const Vector2 &v) : x(v.x), y(v.y) { print_copy_created(this); } - Vector2(Vector2 &&v) : x(v.x), y(v.y) { print_move_created(this); v.x = v.y = 0; } + Vector2(Vector2 &&v) noexcept : x(v.x), y(v.y) { + print_move_created(this); + v.x = v.y = 0; + } Vector2 &operator=(const Vector2 &v) { x = v.x; y = v.y; print_copy_assigned(this); return *this; } - Vector2 &operator=(Vector2 &&v) { x = v.x; y = v.y; v.x = v.y = 0; print_move_assigned(this); return *this; } + Vector2 &operator=(Vector2 &&v) noexcept { + x = v.x; + y = v.y; + v.x = v.y = 0; + print_move_assigned(this); + return *this; + } ~Vector2() { print_destroyed(this); } std::string toString() const { return "[" + std::to_string(x) + ", " + std::to_string(y) + "]"; } @@ -62,6 +72,12 @@ int operator+(const C2 &, const C2 &) { return 22; } int operator+(const C2 &, const C1 &) { return 21; } int operator+(const C1 &, const C2 &) { return 12; } +struct HashMe { + std::string member; +}; + +bool operator==(const HashMe &lhs, const HashMe &rhs) { return lhs.member == rhs.member; } + // Note: Specializing explicit within `namespace std { ... }` is done due to a // bug in GCC<7. If you are supporting compilers later than this, consider // specializing `using template<> struct std::hash<...>` in the global @@ -73,6 +89,14 @@ namespace std { // Not a good hash function, but easy to test size_t operator()(const Vector2 &) { return 4; } }; + + // HashMe has a hash function in C++ but no `__hash__` for Python. + template <> + struct hash { + std::size_t operator()(const HashMe &selector) const { + return std::hash()(selector.member); + } + }; } // namespace std // Not a good abs function, but easy to test. @@ -80,8 +104,8 @@ std::string abs(const Vector2&) { return "abs(Vector2)"; } -// MSVC warns about unknown pragmas, and warnings are errors. -#ifndef _MSC_VER +// MSVC & Intel warns about unknown pragmas, and warnings are errors. +#if !defined(_MSC_VER) && !defined(__INTEL_COMPILER) #pragma GCC diagnostic push // clang 7.0.0 and Apple LLVM 10.0.1 introduce `-Wself-assign-overloaded` to // `-Wall`, which is used here for overloading (e.g. `py::self += py::self `). @@ -89,7 +113,7 @@ std::string abs(const Vector2&) { // Taken from: https://github.com/RobotLocomotion/drake/commit/aaf84b46 // TODO(eric): This could be resolved using a function / functor (e.g. `py::self()`). #if defined(__APPLE__) && defined(__clang__) - #if (__clang_major__ >= 10) && (__clang_minor__ >= 0) && (__clang_patchlevel__ >= 1) + #if (__clang_major__ >= 10) #pragma GCC diagnostic ignored "-Wself-assign-overloaded" #endif #elif defined(__clang__) @@ -219,8 +243,12 @@ TEST_SUBMODULE(operators, m) { .def("__hash__", &Hashable::hash) .def(py::init()) .def(py::self == py::self); -} -#ifndef _MSC_VER + // define __eq__ but not __hash__ + py::class_(m, "HashMe").def(py::self == py::self); + + m.def("get_unhashable_HashMe_set", []() { return std::unordered_set{{"one"}}; }); +} +#if !defined(_MSC_VER) && !defined(__INTEL_COMPILER) #pragma GCC diagnostic pop #endif diff --git a/wrap/pybind11/tests/test_operator_overloading.py b/wrap/pybind11/tests/test_operator_overloading.py index 39e3aee27..8cf375b6d 100644 --- a/wrap/pybind11/tests/test_operator_overloading.py +++ b/wrap/pybind11/tests/test_operator_overloading.py @@ -1,7 +1,9 @@ # -*- coding: utf-8 -*- import pytest -from pybind11_tests import operators as m + +import env from pybind11_tests import ConstructorStats +from pybind11_tests import operators as m def test_operator_overloading(): @@ -56,23 +58,23 @@ def test_operator_overloading(): del v3 assert cstats.alive() == 0 assert cstats.values() == [ - '[1.000000, 2.000000]', - '[3.000000, -1.000000]', - '[1.000000, 2.000000]', - '[-3.000000, 1.000000]', - '[4.000000, 1.000000]', - '[-2.000000, 3.000000]', - '[-7.000000, -6.000000]', - '[9.000000, 10.000000]', - '[8.000000, 16.000000]', - '[0.125000, 0.250000]', - '[7.000000, 6.000000]', - '[9.000000, 10.000000]', - '[8.000000, 16.000000]', - '[8.000000, 4.000000]', - '[3.000000, -2.000000]', - '[3.000000, -0.500000]', - '[6.000000, -2.000000]', + "[1.000000, 2.000000]", + "[3.000000, -1.000000]", + "[1.000000, 2.000000]", + "[-3.000000, 1.000000]", + "[4.000000, 1.000000]", + "[-2.000000, 3.000000]", + "[-7.000000, -6.000000]", + "[9.000000, 10.000000]", + "[8.000000, 16.000000]", + "[0.125000, 0.250000]", + "[7.000000, 6.000000]", + "[9.000000, 10.000000]", + "[8.000000, 16.000000]", + "[8.000000, 4.000000]", + "[3.000000, -2.000000]", + "[3.000000, -0.500000]", + "[6.000000, -2.000000]", ] assert cstats.default_constructions == 0 assert cstats.copy_constructions == 0 @@ -134,8 +136,9 @@ def test_overriding_eq_reset_hash(): assert m.Comparable(15) is not m.Comparable(15) assert m.Comparable(15) == m.Comparable(15) - with pytest.raises(TypeError): - hash(m.Comparable(15)) # TypeError: unhashable type: 'm.Comparable' + with pytest.raises(TypeError) as excinfo: + hash(m.Comparable(15)) + assert str(excinfo.value).startswith("unhashable type:") for hashable in (m.Hashable, m.Hashable2): assert hashable(15) is not hashable(15) @@ -143,3 +146,10 @@ def test_overriding_eq_reset_hash(): assert hash(hashable(15)) == 15 assert hash(hashable(15)) == hash(hashable(15)) + + +def test_return_set_of_unhashable(): + with pytest.raises(TypeError) as excinfo: + m.get_unhashable_HashMe_set() + if not env.PY2: + assert str(excinfo.value.__cause__).startswith("unhashable type:") diff --git a/wrap/pybind11/tests/test_pickling.cpp b/wrap/pybind11/tests/test_pickling.cpp index 9dc63bda3..b77636dd1 100644 --- a/wrap/pybind11/tests/test_pickling.cpp +++ b/wrap/pybind11/tests/test_pickling.cpp @@ -1,7 +1,9 @@ +// clang-format off /* tests/test_pickling.cpp -- pickle support Copyright (c) 2016 Wenzel Jakob + Copyright (c) 2021 The Pybind Development Team. All rights reserved. Use of this source code is governed by a BSD-style license that can be found in the LICENSE file. @@ -9,11 +11,63 @@ #include "pybind11_tests.h" +// clang-format on + +#include +#include +#include + +namespace exercise_trampoline { + +struct SimpleBase { + int num = 0; + virtual ~SimpleBase() = default; + + // For compatibility with old clang versions: + SimpleBase() = default; + SimpleBase(const SimpleBase &) = default; +}; + +struct SimpleBaseTrampoline : SimpleBase {}; + +struct SimpleCppDerived : SimpleBase {}; + +void wrap(py::module m) { + py::class_(m, "SimpleBase") + .def(py::init<>()) + .def_readwrite("num", &SimpleBase::num) + .def(py::pickle( + [](const py::object &self) { + py::dict d; + if (py::hasattr(self, "__dict__")) + d = self.attr("__dict__"); + return py::make_tuple(self.attr("num"), d); + }, + [](const py::tuple &t) { + if (t.size() != 2) + throw std::runtime_error("Invalid state!"); + auto cpp_state = std::unique_ptr(new SimpleBaseTrampoline); + cpp_state->num = t[0].cast(); + auto py_state = t[1].cast(); + return std::make_pair(std::move(cpp_state), py_state); + })); + + m.def("make_SimpleCppDerivedAsBase", + []() { return std::unique_ptr(new SimpleCppDerived); }); + m.def("check_dynamic_cast_SimpleCppDerived", [](const SimpleBase *base_ptr) { + return dynamic_cast(base_ptr) != nullptr; + }); +} + +} // namespace exercise_trampoline + +// clang-format off + TEST_SUBMODULE(pickling, m) { // test_roundtrip class Pickleable { public: - Pickleable(const std::string &value) : m_value(value) { } + explicit Pickleable(const std::string &value) : m_value(value) { } const std::string &value() const { return m_value; } void setExtra1(int extra1) { m_extra1 = extra1; } @@ -31,7 +85,8 @@ TEST_SUBMODULE(pickling, m) { using Pickleable::Pickleable; }; - py::class_(m, "Pickleable") + py::class_ pyPickleable(m, "Pickleable"); + pyPickleable .def(py::init()) .def("value", &Pickleable::value) .def("extra1", &Pickleable::extra1) @@ -43,8 +98,9 @@ TEST_SUBMODULE(pickling, m) { .def("__getstate__", [](const Pickleable &p) { /* Return a tuple that fully encodes the state of the object */ return py::make_tuple(p.value(), p.extra1(), p.extra2()); - }) - .def("__setstate__", [](Pickleable &p, py::tuple t) { + }); + ignoreOldStyleInitWarnings([&pyPickleable]() { + pyPickleable.def("__setstate__", [](Pickleable &p, const py::tuple &t) { if (t.size() != 3) throw std::runtime_error("Invalid state!"); /* Invoke the constructor (need to use in-place version) */ @@ -54,6 +110,7 @@ TEST_SUBMODULE(pickling, m) { p.setExtra1(t[1].cast()); p.setExtra2(t[2].cast()); }); + }); py::class_(m, "PickleableNew") .def(py::init()) @@ -61,7 +118,7 @@ TEST_SUBMODULE(pickling, m) { [](const PickleableNew &p) { return py::make_tuple(p.value(), p.extra1(), p.extra2()); }, - [](py::tuple t) { + [](const py::tuple &t) { if (t.size() != 3) throw std::runtime_error("Invalid state!"); auto p = PickleableNew(t[0].cast()); @@ -69,14 +126,13 @@ TEST_SUBMODULE(pickling, m) { p.setExtra1(t[1].cast()); p.setExtra2(t[2].cast()); return p; - } - )); + })); #if !defined(PYPY_VERSION) // test_roundtrip_with_dict class PickleableWithDict { public: - PickleableWithDict(const std::string &value) : value(value) { } + explicit PickleableWithDict(const std::string &value) : value(value) { } std::string value; int extra; @@ -87,19 +143,20 @@ TEST_SUBMODULE(pickling, m) { using PickleableWithDict::PickleableWithDict; }; - py::class_(m, "PickleableWithDict", py::dynamic_attr()) - .def(py::init()) + py::class_ pyPickleableWithDict(m, "PickleableWithDict", py::dynamic_attr()); + pyPickleableWithDict.def(py::init()) .def_readwrite("value", &PickleableWithDict::value) .def_readwrite("extra", &PickleableWithDict::extra) - .def("__getstate__", [](py::object self) { + .def("__getstate__", [](const py::object &self) { /* Also include __dict__ in state */ return py::make_tuple(self.attr("value"), self.attr("extra"), self.attr("__dict__")); - }) - .def("__setstate__", [](py::object self, py::tuple t) { + }); + ignoreOldStyleInitWarnings([&pyPickleableWithDict]() { + pyPickleableWithDict.def("__setstate__", [](const py::object &self, const py::tuple &t) { if (t.size() != 3) throw std::runtime_error("Invalid state!"); /* Cast and construct */ - auto& p = self.cast(); + auto &p = self.cast(); new (&p) PickleableWithDict(t[0].cast()); /* Assign C++ state */ @@ -108,11 +165,12 @@ TEST_SUBMODULE(pickling, m) { /* Assign Python state */ self.attr("__dict__") = t[2]; }); + }); py::class_(m, "PickleableWithDictNew") .def(py::init()) .def(py::pickle( - [](py::object self) { + [](const py::object &self) { return py::make_tuple(self.attr("value"), self.attr("extra"), self.attr("__dict__")); }, [](const py::tuple &t) { @@ -124,7 +182,8 @@ TEST_SUBMODULE(pickling, m) { auto py_state = t[2].cast(); return std::make_pair(cpp_state, py_state); - } - )); + })); #endif + + exercise_trampoline::wrap(m); } diff --git a/wrap/pybind11/tests/test_pickling.py b/wrap/pybind11/tests/test_pickling.py index 9aee70505..9f68f37dc 100644 --- a/wrap/pybind11/tests/test_pickling.py +++ b/wrap/pybind11/tests/test_pickling.py @@ -1,8 +1,7 @@ # -*- coding: utf-8 -*- import pytest -import env # noqa: F401 - +import env from pybind11_tests import pickling as m try: @@ -42,5 +41,42 @@ def test_roundtrip_with_dict(cls_name): def test_enum_pickle(): from pybind11_tests import enums as e + data = pickle.dumps(e.EOne, 2) assert e.EOne == pickle.loads(data) + + +# +# exercise_trampoline +# +class SimplePyDerived(m.SimpleBase): + pass + + +def test_roundtrip_simple_py_derived(): + p = SimplePyDerived() + p.num = 202 + p.stored_in_dict = 303 + data = pickle.dumps(p, pickle.HIGHEST_PROTOCOL) + p2 = pickle.loads(data) + assert isinstance(p2, SimplePyDerived) + assert p2.num == 202 + assert p2.stored_in_dict == 303 + + +def test_roundtrip_simple_cpp_derived(): + p = m.make_SimpleCppDerivedAsBase() + assert m.check_dynamic_cast_SimpleCppDerived(p) + p.num = 404 + if not env.PYPY: + # To ensure that this unit test is not accidentally invalidated. + with pytest.raises(AttributeError): + # Mimics the `setstate` C++ implementation. + setattr(p, "__dict__", {}) # noqa: B010 + data = pickle.dumps(p, pickle.HIGHEST_PROTOCOL) + p2 = pickle.loads(data) + assert isinstance(p2, m.SimpleBase) + assert p2.num == 404 + # Issue #3062: pickleable base C++ classes can incur object slicing + # if derived typeid is not registered with pybind11 + assert not m.check_dynamic_cast_SimpleCppDerived(p2) diff --git a/wrap/pybind11/tests/test_pytypes.cpp b/wrap/pybind11/tests/test_pytypes.cpp index 4ef1b9ff0..85cb98fcb 100644 --- a/wrap/pybind11/tests/test_pytypes.cpp +++ b/wrap/pybind11/tests/test_pytypes.cpp @@ -7,17 +7,28 @@ BSD-style license that can be found in the LICENSE file. */ +#include + #include "pybind11_tests.h" TEST_SUBMODULE(pytypes, m) { + // test_bool + m.def("get_bool", []{return py::bool_(false);}); // test_int m.def("get_int", []{return py::int_(0);}); // test_iterator m.def("get_iterator", []{return py::iterator();}); // test_iterable m.def("get_iterable", []{return py::iterable();}); + // test_float + m.def("get_float", []{return py::float_(0.0f);}); // test_list + m.def("list_no_args", []() { return py::list{}; }); + m.def("list_ssize_t", []() { return py::list{(py::ssize_t) 0}; }); + m.def("list_size_t", []() { return py::list{(py::size_t) 0}; }); + m.def("list_insert_ssize_t", [](py::list *l) { return l->insert((py::ssize_t) 1, 83); }); + m.def("list_insert_size_t", [](py::list *l) { return l->insert((py::size_t) 3, 57); }); m.def("get_list", []() { py::list list; list.append("value"); @@ -27,16 +38,14 @@ TEST_SUBMODULE(pytypes, m) { list.insert(2, "inserted-2"); return list; }); - m.def("print_list", [](py::list list) { + m.def("print_list", [](const py::list &list) { int index = 0; for (auto item : list) py::print("list item {}: {}"_s.format(index++, item)); }); // test_none m.def("get_none", []{return py::none();}); - m.def("print_none", [](py::none none) { - py::print("none: {}"_s.format(none)); - }); + m.def("print_none", [](const py::none &none) { py::print("none: {}"_s.format(none)); }); // test_set m.def("get_set", []() { @@ -46,20 +55,17 @@ TEST_SUBMODULE(pytypes, m) { set.add(std::string("key3")); return set; }); - m.def("print_set", [](py::set set) { + m.def("print_set", [](const py::set &set) { for (auto item : set) py::print("key:", item); }); - m.def("set_contains", [](py::set set, py::object key) { - return set.contains(key); - }); - m.def("set_contains", [](py::set set, const char* key) { - return set.contains(key); - }); + m.def("set_contains", + [](const py::set &set, const py::object &key) { return set.contains(key); }); + m.def("set_contains", [](const py::set &set, const char *key) { return set.contains(key); }); // test_dict m.def("get_dict", []() { return py::dict("key"_a="value"); }); - m.def("print_dict", [](py::dict dict) { + m.def("print_dict", [](const py::dict &dict) { for (auto item : dict) py::print("key: {}, value={}"_s.format(item.first, item.second)); }); @@ -68,19 +74,38 @@ TEST_SUBMODULE(pytypes, m) { auto d2 = py::dict("z"_a=3, **d1); return d2; }); - m.def("dict_contains", [](py::dict dict, py::object val) { - return dict.contains(val); - }); - m.def("dict_contains", [](py::dict dict, const char* val) { - return dict.contains(val); + m.def("dict_contains", + [](const py::dict &dict, py::object val) { return dict.contains(val); }); + m.def("dict_contains", + [](const py::dict &dict, const char *val) { return dict.contains(val); }); + + // test_tuple + m.def("tuple_no_args", []() { return py::tuple{}; }); + m.def("tuple_ssize_t", []() { return py::tuple{(py::ssize_t) 0}; }); + m.def("tuple_size_t", []() { return py::tuple{(py::size_t) 0}; }); + m.def("get_tuple", []() { return py::make_tuple(42, py::none(), "spam"); }); + +#if PY_VERSION_HEX >= 0x03030000 + // test_simple_namespace + m.def("get_simple_namespace", []() { + auto ns = py::module_::import("types").attr("SimpleNamespace")("attr"_a=42, "x"_a="foo", "wrong"_a=1); + py::delattr(ns, "wrong"); + py::setattr(ns, "right", py::int_(2)); + return ns; }); +#endif // test_str + m.def("str_from_char_ssize_t", []() { return py::str{"red", (py::ssize_t) 3}; }); + m.def("str_from_char_size_t", []() { return py::str{"blue", (py::size_t) 4}; }); m.def("str_from_string", []() { return py::str(std::string("baz")); }); m.def("str_from_bytes", []() { return py::str(py::bytes("boo", 3)); }); m.def("str_from_object", [](const py::object& obj) { return py::str(obj); }); m.def("repr_from_object", [](const py::object& obj) { return py::repr(obj); }); m.def("str_from_handle", [](py::handle h) { return py::str(h); }); + m.def("str_from_string_from_str", [](const py::str& obj) { + return py::str(static_cast(obj)); + }); m.def("str_format", []() { auto s1 = "{} + {} = {}"_s.format(1, 2, 3); @@ -89,9 +114,17 @@ TEST_SUBMODULE(pytypes, m) { }); // test_bytes + m.def("bytes_from_char_ssize_t", []() { return py::bytes{"green", (py::ssize_t) 5}; }); + m.def("bytes_from_char_size_t", []() { return py::bytes{"purple", (py::size_t) 6}; }); m.def("bytes_from_string", []() { return py::bytes(std::string("foo")); }); m.def("bytes_from_str", []() { return py::bytes(py::str("bar", 3)); }); + // test bytearray + m.def("bytearray_from_char_ssize_t", []() { return py::bytearray{"$%", (py::ssize_t) 2}; }); + m.def("bytearray_from_char_size_t", []() { return py::bytearray{"@$!", (py::size_t) 3}; }); + m.def("bytearray_from_string", []() { return py::bytearray(std::string("foo")); }); + m.def("bytearray_size", []() { return py::bytearray("foo").size(); }); + // test_capsule m.def("return_capsule_with_destructor", []() { py::print("creating capsule"); @@ -108,7 +141,7 @@ TEST_SUBMODULE(pytypes, m) { }); m.def("return_capsule_with_name_and_destructor", []() { - auto capsule = py::capsule((void *) 1234, "pointer type description", [](PyObject *ptr) { + auto capsule = py::capsule((void *) 12345, "pointer type description", [](PyObject *ptr) { if (ptr) { auto name = PyCapsule_GetName(ptr); py::print("destructing capsule ({}, '{}')"_s.format( @@ -116,19 +149,30 @@ TEST_SUBMODULE(pytypes, m) { )); } }); - void *contents = capsule; - py::print("created capsule ({}, '{}')"_s.format((size_t) contents, capsule.name())); + + capsule.set_pointer((void *) 1234); + + // Using get_pointer() + void* contents1 = static_cast(capsule); + void* contents2 = capsule.get_pointer(); + void* contents3 = capsule.get_pointer(); + + auto result1 = reinterpret_cast(contents1); + auto result2 = reinterpret_cast(contents2); + auto result3 = reinterpret_cast(contents3); + + py::print("created capsule ({}, '{}')"_s.format(result1 & result2 & result3, capsule.name())); return capsule; }); // test_accessors - m.def("accessor_api", [](py::object o) { + m.def("accessor_api", [](const py::object &o) { auto d = py::dict(); d["basic_attr"] = o.attr("basic_attr"); auto l = py::list(); - for (const auto &item : o.attr("begin_end")) { + for (auto item : o.attr("begin_end")) { l.append(item); } d["begin_end"] = l; @@ -163,7 +207,7 @@ TEST_SUBMODULE(pytypes, m) { return d; }); - m.def("tuple_accessor", [](py::tuple existing_t) { + m.def("tuple_accessor", [](const py::tuple &existing_t) { try { existing_t[0] = 1; } catch (const py::error_already_set &) { @@ -199,6 +243,7 @@ TEST_SUBMODULE(pytypes, m) { m.def("default_constructors", []() { return py::dict( "bytes"_a=py::bytes(), + "bytearray"_a=py::bytearray(), "str"_a=py::str(), "bool"_a=py::bool_(), "int"_a=py::int_(), @@ -210,9 +255,10 @@ TEST_SUBMODULE(pytypes, m) { ); }); - m.def("converting_constructors", [](py::dict d) { + m.def("converting_constructors", [](const py::dict &d) { return py::dict( "bytes"_a=py::bytes(d["bytes"]), + "bytearray"_a=py::bytearray(d["bytearray"]), "str"_a=py::str(d["str"]), "bool"_a=py::bool_(d["bool"]), "int"_a=py::int_(d["int"]), @@ -225,10 +271,11 @@ TEST_SUBMODULE(pytypes, m) { ); }); - m.def("cast_functions", [](py::dict d) { + m.def("cast_functions", [](const py::dict &d) { // When converting between Python types, obj.cast() should be the same as T(obj) return py::dict( "bytes"_a=d["bytes"].cast(), + "bytearray"_a=d["bytearray"].cast(), "str"_a=d["str"].cast(), "bool"_a=d["bool"].cast(), "int"_a=d["int"].cast(), @@ -241,7 +288,24 @@ TEST_SUBMODULE(pytypes, m) { ); }); - m.def("convert_to_pybind11_str", [](py::object o) { return py::str(o); }); + m.def("convert_to_pybind11_str", [](const py::object &o) { return py::str(o); }); + + m.def("nonconverting_constructor", + [](const std::string &type, py::object value, bool move) -> py::object { + if (type == "bytes") { + return move ? py::bytes(std::move(value)) : py::bytes(value); + } + if (type == "none") { + return move ? py::none(std::move(value)) : py::none(value); + } + if (type == "ellipsis") { + return move ? py::ellipsis(std::move(value)) : py::ellipsis(value); + } + if (type == "type") { + return move ? py::type(std::move(value)) : py::type(value); + } + throw std::runtime_error("Invalid type"); + }); m.def("get_implicit_casting", []() { py::dict d; @@ -289,7 +353,7 @@ TEST_SUBMODULE(pytypes, m) { py::print("no new line here", "end"_a=" -- "); py::print("next print"); - auto py_stderr = py::module::import("sys").attr("stderr"); + auto py_stderr = py::module_::import("sys").attr("stderr"); py::print("this goes to stderr", "file"_a=py_stderr); py::print("flush", "flush"_a=true); @@ -299,9 +363,9 @@ TEST_SUBMODULE(pytypes, m) { m.def("print_failure", []() { py::print(42, UnregisteredType()); }); - m.def("hash_function", [](py::object obj) { return py::hash(obj); }); + m.def("hash_function", [](py::object obj) { return py::hash(std::move(obj)); }); - m.def("test_number_protocol", [](py::object a, py::object b) { + m.def("test_number_protocol", [](const py::object &a, const py::object &b) { py::list l; l.append(a.equal(b)); l.append(a.not_equal(b)); @@ -321,9 +385,7 @@ TEST_SUBMODULE(pytypes, m) { return l; }); - m.def("test_list_slicing", [](py::list a) { - return a[py::slice(0, -1, 2)]; - }); + m.def("test_list_slicing", [](const py::list &a) { return a[py::slice(0, -1, 2)]; }); // See #2361 m.def("issue2361_str_implicit_copy_none", []() { @@ -335,13 +397,10 @@ TEST_SUBMODULE(pytypes, m) { return is_this_none; }); - m.def("test_memoryview_object", [](py::buffer b) { - return py::memoryview(b); - }); + m.def("test_memoryview_object", [](const py::buffer &b) { return py::memoryview(b); }); - m.def("test_memoryview_buffer_info", [](py::buffer b) { - return py::memoryview(b.request()); - }); + m.def("test_memoryview_buffer_info", + [](const py::buffer &b) { return py::memoryview(b.request()); }); m.def("test_memoryview_from_buffer", [](bool is_unsigned) { static const int16_t si16[] = { 3, 1, 4, 1, 5 }; @@ -349,9 +408,7 @@ TEST_SUBMODULE(pytypes, m) { if (is_unsigned) return py::memoryview::from_buffer( ui16, { 4 }, { sizeof(uint16_t) }); - else - return py::memoryview::from_buffer( - si16, { 5 }, { sizeof(int16_t) }); + return py::memoryview::from_buffer(si16, {5}, {sizeof(int16_t)}); }); m.def("test_memoryview_from_buffer_nativeformat", []() { @@ -380,7 +437,128 @@ TEST_SUBMODULE(pytypes, m) { m.def("test_memoryview_from_memory", []() { const char* buf = "\xff\xe1\xab\x37"; return py::memoryview::from_memory( - buf, static_cast(strlen(buf))); + buf, static_cast(strlen(buf))); }); #endif + + // test_builtin_functions + m.def("get_len", [](py::handle h) { return py::len(h); }); + +#ifdef PYBIND11_STR_LEGACY_PERMISSIVE + m.attr("PYBIND11_STR_LEGACY_PERMISSIVE") = true; +#endif + + m.def("isinstance_pybind11_bytes", + [](py::object o) { return py::isinstance(std::move(o)); }); + m.def("isinstance_pybind11_str", + [](py::object o) { return py::isinstance(std::move(o)); }); + + m.def("pass_to_pybind11_bytes", [](py::bytes b) { return py::len(std::move(b)); }); + m.def("pass_to_pybind11_str", [](py::str s) { return py::len(std::move(s)); }); + m.def("pass_to_std_string", [](const std::string &s) { return s.size(); }); + + // test_weakref + m.def("weakref_from_handle", + [](py::handle h) { return py::weakref(h); }); + m.def("weakref_from_handle_and_function", + [](py::handle h, py::function f) { return py::weakref(h, std::move(f)); }); + m.def("weakref_from_object", [](const py::object &o) { return py::weakref(o); }); + m.def("weakref_from_object_and_function", + [](py::object o, py::function f) { return py::weakref(std::move(o), std::move(f)); }); + +// See PR #3263 for background (https://github.com/pybind/pybind11/pull/3263): +// pytypes.h could be changed to enforce the "most correct" user code below, by removing +// `const` from iterator `reference` using type aliases, but that will break existing +// user code. +#if (defined(__APPLE__) && defined(__clang__)) || defined(PYPY_VERSION) +// This is "most correct" and enforced on these platforms. +# define PYBIND11_AUTO_IT auto it +#else +// This works on many platforms and is (unfortunately) reflective of existing user code. +// NOLINTNEXTLINE(bugprone-macro-parentheses) +# define PYBIND11_AUTO_IT auto &it +#endif + + m.def("tuple_iterator", []() { + auto tup = py::make_tuple(5, 7); + int tup_sum = 0; + for (PYBIND11_AUTO_IT : tup) { + tup_sum += it.cast(); + } + return tup_sum; + }); + + m.def("dict_iterator", []() { + py::dict dct; + dct[py::int_(3)] = 5; + dct[py::int_(7)] = 11; + int kv_sum = 0; + for (PYBIND11_AUTO_IT : dct) { + kv_sum += it.first.cast() * 100 + it.second.cast(); + } + return kv_sum; + }); + + m.def("passed_iterator", [](const py::iterator &py_it) { + int elem_sum = 0; + for (PYBIND11_AUTO_IT : py_it) { + elem_sum += it.cast(); + } + return elem_sum; + }); + +#undef PYBIND11_AUTO_IT + + // Tests below this line are for pybind11 IMPLEMENTATION DETAILS: + + m.def("sequence_item_get_ssize_t", [](const py::object &o) { + return py::detail::accessor_policies::sequence_item::get(o, (py::ssize_t) 1); + }); + m.def("sequence_item_set_ssize_t", [](const py::object &o) { + auto s = py::str{"peppa", 5}; + py::detail::accessor_policies::sequence_item::set(o, (py::ssize_t) 1, s); + }); + m.def("sequence_item_get_size_t", [](const py::object &o) { + return py::detail::accessor_policies::sequence_item::get(o, (py::size_t) 2); + }); + m.def("sequence_item_set_size_t", [](const py::object &o) { + auto s = py::str{"george", 6}; + py::detail::accessor_policies::sequence_item::set(o, (py::size_t) 2, s); + }); + m.def("list_item_get_ssize_t", [](const py::object &o) { + return py::detail::accessor_policies::list_item::get(o, (py::ssize_t) 3); + }); + m.def("list_item_set_ssize_t", [](const py::object &o) { + auto s = py::str{"rebecca", 7}; + py::detail::accessor_policies::list_item::set(o, (py::ssize_t) 3, s); + }); + m.def("list_item_get_size_t", [](const py::object &o) { + return py::detail::accessor_policies::list_item::get(o, (py::size_t) 4); + }); + m.def("list_item_set_size_t", [](const py::object &o) { + auto s = py::str{"richard", 7}; + py::detail::accessor_policies::list_item::set(o, (py::size_t) 4, s); + }); + m.def("tuple_item_get_ssize_t", [](const py::object &o) { + return py::detail::accessor_policies::tuple_item::get(o, (py::ssize_t) 5); + }); + m.def("tuple_item_set_ssize_t", []() { + auto s0 = py::str{"emely", 5}; + auto s1 = py::str{"edmond", 6}; + auto o = py::tuple{2}; + py::detail::accessor_policies::tuple_item::set(o, (py::ssize_t) 0, s0); + py::detail::accessor_policies::tuple_item::set(o, (py::ssize_t) 1, s1); + return o; + }); + m.def("tuple_item_get_size_t", [](const py::object &o) { + return py::detail::accessor_policies::tuple_item::get(o, (py::size_t) 6); + }); + m.def("tuple_item_set_size_t", []() { + auto s0 = py::str{"candy", 5}; + auto s1 = py::str{"cat", 3}; + auto o = py::tuple{2}; + py::detail::accessor_policies::tuple_item::set(o, (py::size_t) 1, s1); + py::detail::accessor_policies::tuple_item::set(o, (py::size_t) 0, s0); + return o; + }); } diff --git a/wrap/pybind11/tests/test_pytypes.py b/wrap/pybind11/tests/test_pytypes.py index 0618cd54c..2cd6c3f03 100644 --- a/wrap/pybind11/tests/test_pytypes.py +++ b/wrap/pybind11/tests/test_pytypes.py @@ -1,12 +1,17 @@ # -*- coding: utf-8 -*- from __future__ import division -import pytest + import sys -import env # noqa: F401 +import pytest -from pybind11_tests import pytypes as m +import env from pybind11_tests import debug_enabled +from pybind11_tests import pytypes as m + + +def test_bool(doc): + assert doc(m.get_bool) == "get_bool() -> bool" def test_int(doc): @@ -21,20 +26,36 @@ def test_iterable(doc): assert doc(m.get_iterable) == "get_iterable() -> Iterable" +def test_float(doc): + assert doc(m.get_float) == "get_float() -> float" + + def test_list(capture, doc): + assert m.list_no_args() == [] + assert m.list_ssize_t() == [] + assert m.list_size_t() == [] + lins = [1, 2] + m.list_insert_ssize_t(lins) + assert lins == [1, 83, 2] + m.list_insert_size_t(lins) + assert lins == [1, 83, 2, 57] + with capture: lst = m.get_list() assert lst == ["inserted-0", "overwritten", "inserted-2"] lst.append("value2") m.print_list(lst) - assert capture.unordered == """ + assert ( + capture.unordered + == """ Entry at position 0: value list item 0: inserted-0 list item 1: overwritten list item 2: inserted-2 list item 3: value2 """ + ) assert doc(m.get_list) == "get_list() -> list" assert doc(m.print_list) == "print_list(arg0: list) -> None" @@ -52,14 +73,17 @@ def test_set(capture, doc): with capture: s.add("key4") m.print_set(s) - assert capture.unordered == """ + assert ( + capture.unordered + == """ key: key1 key: key2 key: key3 key: key4 """ + ) - assert not m.set_contains(set([]), 42) + assert not m.set_contains(set(), 42) assert m.set_contains({42}, 42) assert m.set_contains({"foo"}, "foo") @@ -74,10 +98,13 @@ def test_dict(capture, doc): with capture: d["key2"] = "value2" m.print_dict(d) - assert capture.unordered == """ + assert ( + capture.unordered + == """ key: key, value=value key: key2, value=value2 """ + ) assert not m.dict_contains({}, 42) assert m.dict_contains({42: None}, 42) @@ -89,7 +116,25 @@ def test_dict(capture, doc): assert m.dict_keyword_constructor() == {"x": 1, "y": 2, "z": 3} +def test_tuple(): + assert m.tuple_no_args() == () + assert m.tuple_ssize_t() == () + assert m.tuple_size_t() == () + assert m.get_tuple() == (42, None, "spam") + + +@pytest.mark.skipif("env.PY2") +def test_simple_namespace(): + ns = m.get_simple_namespace() + assert ns.attr == 42 + assert ns.x == "foo" + assert ns.right == 2 + assert not hasattr(ns, "wrong") + + def test_str(doc): + assert m.str_from_char_ssize_t().encode().decode() == "red" + assert m.str_from_char_size_t().encode().decode() == "blue" assert m.str_from_string().encode().decode() == "baz" assert m.str_from_bytes().encode().decode() == "boo" @@ -111,18 +156,31 @@ def test_str(doc): assert s1 == s2 malformed_utf8 = b"\x80" - assert m.str_from_object(malformed_utf8) is malformed_utf8 # To be fixed; see #2380 + if hasattr(m, "PYBIND11_STR_LEGACY_PERMISSIVE"): + assert m.str_from_object(malformed_utf8) is malformed_utf8 + elif env.PY2: + with pytest.raises(UnicodeDecodeError): + m.str_from_object(malformed_utf8) + else: + assert m.str_from_object(malformed_utf8) == "b'\\x80'" if env.PY2: - # with pytest.raises(UnicodeDecodeError): - # m.str_from_object(malformed_utf8) with pytest.raises(UnicodeDecodeError): m.str_from_handle(malformed_utf8) else: - # assert m.str_from_object(malformed_utf8) == "b'\\x80'" assert m.str_from_handle(malformed_utf8) == "b'\\x80'" + assert m.str_from_string_from_str("this is a str") == "this is a str" + ucs_surrogates_str = u"\udcc3" + if env.PY2: + assert u"\udcc3" == m.str_from_string_from_str(ucs_surrogates_str) + else: + with pytest.raises(UnicodeEncodeError): + m.str_from_string_from_str(ucs_surrogates_str) + def test_bytes(doc): + assert m.bytes_from_char_ssize_t().decode() == "green" + assert m.bytes_from_char_size_t().decode() == "purple" assert m.bytes_from_string().decode() == "foo" assert m.bytes_from_str().decode() == "bar" @@ -131,34 +189,50 @@ def test_bytes(doc): ) +def test_bytearray(doc): + assert m.bytearray_from_char_ssize_t().decode() == "$%" + assert m.bytearray_from_char_size_t().decode() == "@$!" + assert m.bytearray_from_string().decode() == "foo" + assert m.bytearray_size() == len("foo") + + def test_capsule(capture): pytest.gc_collect() with capture: a = m.return_capsule_with_destructor() del a pytest.gc_collect() - assert capture.unordered == """ + assert ( + capture.unordered + == """ creating capsule destructing capsule """ + ) with capture: a = m.return_capsule_with_destructor_2() del a pytest.gc_collect() - assert capture.unordered == """ + assert ( + capture.unordered + == """ creating capsule destructing capsule: 1234 """ + ) with capture: a = m.return_capsule_with_name_and_destructor() del a pytest.gc_collect() - assert capture.unordered == """ + assert ( + capture.unordered + == """ created capsule (1234, 'pointer type description') destructing capsule (1234, 'pointer type description') """ + ) def test_accessors(): @@ -202,7 +276,7 @@ def test_accessors(): def test_constructors(): """C++ default and converting constructors are equivalent to type calls in Python""" - types = [bytes, str, bool, int, float, tuple, list, dict, set] + types = [bytes, bytearray, str, bool, int, float, tuple, list, dict, set] expected = {t.__name__: t() for t in types} if env.PY2: # Note that bytes.__name__ == 'str' in Python 2. @@ -212,7 +286,8 @@ def test_constructors(): assert m.default_constructors() == expected data = { - bytes: b'41', # Currently no supported or working conversions. + bytes: b"41", # Currently no supported or working conversions. + bytearray: bytearray(b"41"), str: 42, bool: "Not empty", int: "42", @@ -221,14 +296,14 @@ def test_constructors(): list: range(3), dict: [("two", 2), ("one", 1), ("three", 3)], set: [4, 4, 5, 6, 6, 6], - memoryview: b'abc' + memoryview: b"abc", } inputs = {k.__name__: v for k, v in data.items()} expected = {k.__name__: k(v) for k, v in data.items()} if env.PY2: # Similar to the above. See comments above. - inputs["bytes"] = b'41' + inputs["bytes"] = b"41" inputs["str"] = 42 - expected["bytes"] = b'41' + expected["bytes"] = b"41" expected["str"] = u"42" assert m.converting_constructors(inputs) == expected @@ -245,16 +320,33 @@ def test_constructors(): assert noconv2[k] is expected[k] +def test_non_converting_constructors(): + non_converting_test_cases = [ + ("bytes", range(10)), + ("none", 42), + ("ellipsis", 42), + ("type", 42), + ] + for t, v in non_converting_test_cases: + for move in [True, False]: + with pytest.raises(TypeError) as excinfo: + m.nonconverting_constructor(t, v, move) + expected_error = "Object of type '{}' is not an instance of '{}'".format( + type(v).__name__, t + ) + assert str(excinfo.value) == expected_error + + def test_pybind11_str_raw_str(): # specifically to exercise pybind11::str::raw_str cvt = m.convert_to_pybind11_str assert cvt(u"Str") == u"Str" - assert cvt(b'Bytes') == u"Bytes" if env.PY2 else "b'Bytes'" + assert cvt(b"Bytes") == u"Bytes" if env.PY2 else "b'Bytes'" assert cvt(None) == u"None" assert cvt(False) == u"False" assert cvt(True) == u"True" assert cvt(42) == u"42" - assert cvt(2**65) == u"36893488147419103232" + assert cvt(2 ** 65) == u"36893488147419103232" assert cvt(-1.50) == u"-1.5" assert cvt(()) == u"()" assert cvt((18,)) == u"(18,)" @@ -268,30 +360,54 @@ def test_pybind11_str_raw_str(): valid_orig = u"DZ" valid_utf8 = valid_orig.encode("utf-8") valid_cvt = cvt(valid_utf8) - assert type(valid_cvt) == bytes # Probably surprising. - assert valid_cvt == b'\xc7\xb1' + if hasattr(m, "PYBIND11_STR_LEGACY_PERMISSIVE"): + assert valid_cvt is valid_utf8 + else: + assert type(valid_cvt) is unicode if env.PY2 else str # noqa: F821 + if env.PY2: + assert valid_cvt == valid_orig + else: + assert valid_cvt == "b'\\xc7\\xb1'" - malformed_utf8 = b'\x80' - malformed_cvt = cvt(malformed_utf8) - assert type(malformed_cvt) == bytes # Probably surprising. - assert malformed_cvt == b'\x80' + malformed_utf8 = b"\x80" + if hasattr(m, "PYBIND11_STR_LEGACY_PERMISSIVE"): + assert cvt(malformed_utf8) is malformed_utf8 + else: + if env.PY2: + with pytest.raises(UnicodeDecodeError): + cvt(malformed_utf8) + else: + malformed_cvt = cvt(malformed_utf8) + assert type(malformed_cvt) is str + assert malformed_cvt == "b'\\x80'" def test_implicit_casting(): """Tests implicit casting when assigning or appending to dicts and lists.""" z = m.get_implicit_casting() - assert z['d'] == { - 'char*_i1': 'abc', 'char*_i2': 'abc', 'char*_e': 'abc', 'char*_p': 'abc', - 'str_i1': 'str', 'str_i2': 'str1', 'str_e': 'str2', 'str_p': 'str3', - 'int_i1': 42, 'int_i2': 42, 'int_e': 43, 'int_p': 44 + assert z["d"] == { + "char*_i1": "abc", + "char*_i2": "abc", + "char*_e": "abc", + "char*_p": "abc", + "str_i1": "str", + "str_i2": "str1", + "str_e": "str2", + "str_p": "str3", + "int_i1": 42, + "int_i2": 42, + "int_e": 43, + "int_p": 44, } - assert z['l'] == [3, 6, 9, 12, 15] + assert z["l"] == [3, 6, 9, 12, 15] def test_print(capture): with capture: m.print_function() - assert capture == """ + assert ( + capture + == """ Hello, World! 1 2.0 three True -- multiple args *args-and-a-custom-separator @@ -299,14 +415,15 @@ def test_print(capture): flush py::print + str.format = this """ + ) assert capture.stderr == "this goes to stderr" with pytest.raises(RuntimeError) as excinfo: m.print_failure() - assert str(excinfo.value) == "make_tuple(): unable to convert " + ( - "argument of type 'UnregisteredType' to Python object" - if debug_enabled else - "arguments to Python object (compile in debug mode for details)" + assert str(excinfo.value) == "Unable to convert call argument " + ( + "'1' of type 'UnregisteredType' to Python object" + if debug_enabled + else "to Python object (compile in debug mode for details)" ) @@ -328,8 +445,23 @@ def test_hash(): def test_number_protocol(): for a, b in [(1, 1), (3, 5)]: - li = [a == b, a != b, a < b, a <= b, a > b, a >= b, a + b, - a - b, a * b, a / b, a | b, a & b, a ^ b, a >> b, a << b] + li = [ + a == b, + a != b, + a < b, + a <= b, + a > b, + a >= b, + a + b, + a - b, + a * b, + a / b, + a | b, + a & b, + a ^ b, + a >> b, + a << b, + ] assert m.test_number_protocol(a, b) == li @@ -343,16 +475,20 @@ def test_issue2361(): assert m.issue2361_str_implicit_copy_none() == "None" with pytest.raises(TypeError) as excinfo: assert m.issue2361_dict_implicit_copy_none() - assert "'NoneType' object is not iterable" in str(excinfo.value) + assert "NoneType" in str(excinfo.value) + assert "iterable" in str(excinfo.value) -@pytest.mark.parametrize('method, args, fmt, expected_view', [ - (m.test_memoryview_object, (b'red',), 'B', b'red'), - (m.test_memoryview_buffer_info, (b'green',), 'B', b'green'), - (m.test_memoryview_from_buffer, (False,), 'h', [3, 1, 4, 1, 5]), - (m.test_memoryview_from_buffer, (True,), 'H', [2, 7, 1, 8]), - (m.test_memoryview_from_buffer_nativeformat, (), '@i', [4, 7, 5]), -]) +@pytest.mark.parametrize( + "method, args, fmt, expected_view", + [ + (m.test_memoryview_object, (b"red",), "B", b"red"), + (m.test_memoryview_buffer_info, (b"green",), "B", b"green"), + (m.test_memoryview_from_buffer, (False,), "h", [3, 1, 4, 1, 5]), + (m.test_memoryview_from_buffer, (True,), "H", [2, 7, 1, 8]), + (m.test_memoryview_from_buffer_nativeformat, (), "@i", [4, 7, 5]), + ], +) def test_memoryview(method, args, fmt, expected_view): view = method(*args) assert isinstance(view, memoryview) @@ -361,17 +497,20 @@ def test_memoryview(method, args, fmt, expected_view): view_as_list = list(view) else: # Using max to pick non-zero byte (big-endian vs little-endian). - view_as_list = [max([ord(c) for c in s]) for s in view] + view_as_list = [max(ord(c) for c in s) for s in view] assert view_as_list == list(expected_view) @pytest.mark.xfail("env.PYPY", reason="getrefcount is not available") -@pytest.mark.parametrize('method', [ - m.test_memoryview_object, - m.test_memoryview_buffer_info, -]) +@pytest.mark.parametrize( + "method", + [ + m.test_memoryview_object, + m.test_memoryview_buffer_info, + ], +) def test_memoryview_refcount(method): - buf = b'\x0a\x0b\x0c\x0d' + buf = b"\x0a\x0b\x0c\x0d" ref_before = sys.getrefcount(buf) view = method(buf) ref_after = sys.getrefcount(buf) @@ -382,13 +521,13 @@ def test_memoryview_refcount(method): def test_memoryview_from_buffer_empty_shape(): view = m.test_memoryview_from_buffer_empty_shape() assert isinstance(view, memoryview) - assert view.format == 'B' + assert view.format == "B" if env.PY2: # Python 2 behavior is weird, but Python 3 (the future) is fine. # PyPy3 has #include +#include +#include + +#ifdef PYBIND11_HAS_OPTIONAL +#include +#endif // PYBIND11_HAS_OPTIONAL + template class NonZeroIterator { const T* ptr_; public: - NonZeroIterator(const T* ptr) : ptr_(ptr) {} + explicit NonZeroIterator(const T *ptr) : ptr_(ptr) {} const T& operator*() const { return *ptr_; } NonZeroIterator& operator++() { ++ptr_; return *this; } }; @@ -31,6 +38,40 @@ bool operator==(const NonZeroIterator>& it, const NonZeroSentine return !(*it).first || !(*it).second; } +/* Iterator where dereferencing returns prvalues instead of references. */ +template +class NonRefIterator { + const T* ptr_; +public: + explicit NonRefIterator(const T *ptr) : ptr_(ptr) {} + T operator*() const { return T(*ptr_); } + NonRefIterator& operator++() { ++ptr_; return *this; } + bool operator==(const NonRefIterator &other) const { return ptr_ == other.ptr_; } +}; + +class NonCopyableInt { +public: + explicit NonCopyableInt(int value) : value_(value) {} + NonCopyableInt(const NonCopyableInt &) = delete; + NonCopyableInt(NonCopyableInt &&other) noexcept : value_(other.value_) { + other.value_ = -1; // detect when an unwanted move occurs + } + NonCopyableInt &operator=(const NonCopyableInt &) = delete; + NonCopyableInt &operator=(NonCopyableInt &&other) noexcept { + value_ = other.value_; + other.value_ = -1; // detect when an unwanted move occurs + return *this; + } + int get() const { return value_; } + void set(int value) { value_ = value; } + ~NonCopyableInt() = default; +private: + int value_; +}; +using NonCopyableIntPair = std::pair; +PYBIND11_MAKE_OPAQUE(std::vector); +PYBIND11_MAKE_OPAQUE(std::vector); + template py::list test_random_access_iterator(PythonType x) { if (x.size() < 5) @@ -76,32 +117,43 @@ TEST_SUBMODULE(sequences_and_iterators, m) { // test_sliceable class Sliceable{ public: - Sliceable(int n): size(n) {} - int start,stop,step; - int size; + explicit Sliceable(int n) : size(n) {} + int start, stop, step; + int size; }; - py::class_(m,"Sliceable") + py::class_(m, "Sliceable") .def(py::init()) - .def("__getitem__",[](const Sliceable &s, py::slice slice) { - ssize_t start, stop, step, slicelength; - if (!slice.compute(s.size, &start, &stop, &step, &slicelength)) - throw py::error_already_set(); - int istart = static_cast(start); - int istop = static_cast(stop); - int istep = static_cast(step); - return std::make_tuple(istart,istop,istep); - }) - ; + .def("__getitem__", [](const Sliceable &s, const py::slice &slice) { + py::ssize_t start = 0, stop = 0, step = 0, slicelength = 0; + if (!slice.compute(s.size, &start, &stop, &step, &slicelength)) + throw py::error_already_set(); + int istart = static_cast(start); + int istop = static_cast(stop); + int istep = static_cast(step); + return std::make_tuple(istart, istop, istep); + }); + + m.def("make_forward_slice_size_t", []() { return py::slice(0, -1, 1); }); + m.def("make_reversed_slice_object", []() { return py::slice(py::none(), py::none(), py::int_(-1)); }); +#ifdef PYBIND11_HAS_OPTIONAL + m.attr("has_optional") = true; + m.def("make_reversed_slice_size_t_optional_verbose", []() { return py::slice(std::nullopt, std::nullopt, -1); }); + // Warning: The following spelling may still compile if optional<> is not present and give wrong answers. + // Please use with caution. + m.def("make_reversed_slice_size_t_optional", []() { return py::slice({}, {}, -1); }); +#else + m.attr("has_optional") = false; +#endif // test_sequence class Sequence { public: - Sequence(size_t size) : m_size(size) { + explicit Sequence(size_t size) : m_size(size) { print_created(this, "of size", m_size); m_data = new float[size]; memset(m_data, 0, sizeof(float) * size); } - Sequence(const std::vector &value) : m_size(value.size()) { + explicit Sequence(const std::vector &value) : m_size(value.size()) { print_created(this, "of size", m_size, "from std::vector"); m_data = new float[m_size]; memcpy(m_data, &value[0], sizeof(float) * m_size); @@ -111,7 +163,7 @@ TEST_SUBMODULE(sequences_and_iterators, m) { m_data = new float[m_size]; memcpy(m_data, s.m_data, sizeof(float)*m_size); } - Sequence(Sequence &&s) : m_size(s.m_size), m_data(s.m_data) { + Sequence(Sequence &&s) noexcept : m_size(s.m_size), m_data(s.m_data) { print_move_created(this); s.m_size = 0; s.m_data = nullptr; @@ -130,7 +182,7 @@ TEST_SUBMODULE(sequences_and_iterators, m) { return *this; } - Sequence &operator=(Sequence &&s) { + Sequence &operator=(Sequence &&s) noexcept { if (&s != this) { delete[] m_data; m_size = s.m_size; @@ -179,43 +231,54 @@ TEST_SUBMODULE(sequences_and_iterators, m) { }; py::class_(m, "Sequence") .def(py::init()) - .def(py::init&>()) + .def(py::init &>()) /// Bare bones interface - .def("__getitem__", [](const Sequence &s, size_t i) { - if (i >= s.size()) throw py::index_error(); - return s[i]; - }) - .def("__setitem__", [](Sequence &s, size_t i, float v) { - if (i >= s.size()) throw py::index_error(); - s[i] = v; - }) + .def("__getitem__", + [](const Sequence &s, size_t i) { + if (i >= s.size()) + throw py::index_error(); + return s[i]; + }) + .def("__setitem__", + [](Sequence &s, size_t i, float v) { + if (i >= s.size()) + throw py::index_error(); + s[i] = v; + }) .def("__len__", &Sequence::size) /// Optional sequence protocol operations - .def("__iter__", [](const Sequence &s) { return py::make_iterator(s.begin(), s.end()); }, - py::keep_alive<0, 1>() /* Essential: keep object alive while iterator exists */) + .def( + "__iter__", + [](const Sequence &s) { return py::make_iterator(s.begin(), s.end()); }, + py::keep_alive<0, 1>() /* Essential: keep object alive while iterator exists */) .def("__contains__", [](const Sequence &s, float v) { return s.contains(v); }) .def("__reversed__", [](const Sequence &s) -> Sequence { return s.reversed(); }) /// Slicing protocol (optional) - .def("__getitem__", [](const Sequence &s, py::slice slice) -> Sequence* { - size_t start, stop, step, slicelength; - if (!slice.compute(s.size(), &start, &stop, &step, &slicelength)) - throw py::error_already_set(); - auto *seq = new Sequence(slicelength); - for (size_t i = 0; i < slicelength; ++i) { - (*seq)[i] = s[start]; start += step; - } - return seq; - }) - .def("__setitem__", [](Sequence &s, py::slice slice, const Sequence &value) { - size_t start, stop, step, slicelength; - if (!slice.compute(s.size(), &start, &stop, &step, &slicelength)) - throw py::error_already_set(); - if (slicelength != value.size()) - throw std::runtime_error("Left and right hand size of slice assignment have different sizes!"); - for (size_t i = 0; i < slicelength; ++i) { - s[start] = value[i]; start += step; - } - }) + .def("__getitem__", + [](const Sequence &s, const py::slice &slice) -> Sequence * { + size_t start = 0, stop = 0, step = 0, slicelength = 0; + if (!slice.compute(s.size(), &start, &stop, &step, &slicelength)) + throw py::error_already_set(); + auto *seq = new Sequence(slicelength); + for (size_t i = 0; i < slicelength; ++i) { + (*seq)[i] = s[start]; + start += step; + } + return seq; + }) + .def("__setitem__", + [](Sequence &s, const py::slice &slice, const Sequence &value) { + size_t start = 0, stop = 0, step = 0, slicelength = 0; + if (!slice.compute(s.size(), &start, &stop, &step, &slicelength)) + throw py::error_already_set(); + if (slicelength != value.size()) + throw std::runtime_error( + "Left and right hand size of slice assignment have different sizes!"); + for (size_t i = 0; i < slicelength; ++i) { + s[start] = value[i]; + start += step; + } + }) /// Comparisons .def(py::self == py::self) .def(py::self != py::self) @@ -228,11 +291,11 @@ TEST_SUBMODULE(sequences_and_iterators, m) { class StringMap { public: StringMap() = default; - StringMap(std::unordered_map init) + explicit StringMap(std::unordered_map init) : map(std::move(init)) {} - void set(std::string key, std::string val) { map[key] = val; } - std::string get(std::string key) const { return map.at(key); } + void set(const std::string &key, std::string val) { map[key] = std::move(val); } + std::string get(const std::string &key) const { return map.at(key); } size_t size() const { return map.size(); } private: std::unordered_map map; @@ -243,38 +306,117 @@ TEST_SUBMODULE(sequences_and_iterators, m) { py::class_(m, "StringMap") .def(py::init<>()) .def(py::init>()) - .def("__getitem__", [](const StringMap &map, std::string key) { - try { return map.get(key); } - catch (const std::out_of_range&) { - throw py::key_error("key '" + key + "' does not exist"); - } - }) + .def("__getitem__", + [](const StringMap &map, const std::string &key) { + try { + return map.get(key); + } catch (const std::out_of_range &) { + throw py::key_error("key '" + key + "' does not exist"); + } + }) .def("__setitem__", &StringMap::set) .def("__len__", &StringMap::size) - .def("__iter__", [](const StringMap &map) { return py::make_key_iterator(map.begin(), map.end()); }, - py::keep_alive<0, 1>()) - .def("items", [](const StringMap &map) { return py::make_iterator(map.begin(), map.end()); }, - py::keep_alive<0, 1>()) - ; + .def( + "__iter__", + [](const StringMap &map) { return py::make_key_iterator(map.begin(), map.end()); }, + py::keep_alive<0, 1>()) + .def( + "items", + [](const StringMap &map) { return py::make_iterator(map.begin(), map.end()); }, + py::keep_alive<0, 1>()) + .def( + "values", + [](const StringMap &map) { return py::make_value_iterator(map.begin(), map.end()); }, + py::keep_alive<0, 1>()); // test_generalized_iterators class IntPairs { public: - IntPairs(std::vector> data) : data_(std::move(data)) {} + explicit IntPairs(std::vector> data) : data_(std::move(data)) {} const std::pair* begin() const { return data_.data(); } + // .end() only required for py::make_iterator(self) overload + const std::pair* end() const { return data_.data() + data_.size(); } private: std::vector> data_; }; py::class_(m, "IntPairs") .def(py::init>>()) .def("nonzero", [](const IntPairs& s) { - return py::make_iterator(NonZeroIterator>(s.begin()), NonZeroSentinel()); + return py::make_iterator(NonZeroIterator>(s.begin()), NonZeroSentinel()); }, py::keep_alive<0, 1>()) .def("nonzero_keys", [](const IntPairs& s) { return py::make_key_iterator(NonZeroIterator>(s.begin()), NonZeroSentinel()); }, py::keep_alive<0, 1>()) + .def("nonzero_values", [](const IntPairs& s) { + return py::make_value_iterator(NonZeroIterator>(s.begin()), NonZeroSentinel()); + }, py::keep_alive<0, 1>()) + + // test iterator that returns values instead of references + .def("nonref", [](const IntPairs& s) { + return py::make_iterator(NonRefIterator>(s.begin()), + NonRefIterator>(s.end())); + }, py::keep_alive<0, 1>()) + .def("nonref_keys", [](const IntPairs& s) { + return py::make_key_iterator(NonRefIterator>(s.begin()), + NonRefIterator>(s.end())); + }, py::keep_alive<0, 1>()) + .def("nonref_values", [](const IntPairs& s) { + return py::make_value_iterator(NonRefIterator>(s.begin()), + NonRefIterator>(s.end())); + }, py::keep_alive<0, 1>()) + + // test single-argument make_iterator + .def("simple_iterator", [](IntPairs& self) { + return py::make_iterator(self); + }, py::keep_alive<0, 1>()) + .def("simple_keys", [](IntPairs& self) { + return py::make_key_iterator(self); + }, py::keep_alive<0, 1>()) + .def("simple_values", [](IntPairs& self) { + return py::make_value_iterator(self); + }, py::keep_alive<0, 1>()) + + // Test iterator with an Extra (doesn't do anything useful, so not used + // at runtime, but tests need to be able to compile with the correct + // overload. See PR #3293. + .def("_make_iterator_extras", [](IntPairs& self) { + return py::make_iterator(self, py::call_guard()); + }, py::keep_alive<0, 1>()) + .def("_make_key_extras", [](IntPairs& self) { + return py::make_key_iterator(self, py::call_guard()); + }, py::keep_alive<0, 1>()) + .def("_make_value_extras", [](IntPairs& self) { + return py::make_value_iterator(self, py::call_guard()); + }, py::keep_alive<0, 1>()) ; + // test_iterater_referencing + py::class_(m, "NonCopyableInt") + .def(py::init()) + .def("set", &NonCopyableInt::set) + .def("__int__", &NonCopyableInt::get) + ; + py::class_>(m, "VectorNonCopyableInt") + .def(py::init<>()) + .def("append", [](std::vector &vec, int value) { + vec.emplace_back(value); + }) + .def("__iter__", [](std::vector &vec) { + return py::make_iterator(vec.begin(), vec.end()); + }) + ; + py::class_>(m, "VectorNonCopyableIntPair") + .def(py::init<>()) + .def("append", [](std::vector &vec, const std::pair &value) { + vec.emplace_back(NonCopyableInt(value.first), NonCopyableInt(value.second)); + }) + .def("keys", [](std::vector &vec) { + return py::make_key_iterator(vec.begin(), vec.end()); + }) + .def("values", [](std::vector &vec) { + return py::make_value_iterator(vec.begin(), vec.end()); + }) + ; #if 0 // Obsolete: special data structure for exposing custom iterator types to python @@ -304,7 +446,7 @@ TEST_SUBMODULE(sequences_and_iterators, m) { #endif // test_python_iterator_in_cpp - m.def("object_to_list", [](py::object o) { + m.def("object_to_list", [](const py::object &o) { auto l = py::list(); for (auto item : o) { l.append(item); @@ -322,22 +464,22 @@ TEST_SUBMODULE(sequences_and_iterators, m) { }); // test_sequence_length: check that Python sequences can be converted to py::sequence. - m.def("sequence_length", [](py::sequence seq) { return seq.size(); }); + m.def("sequence_length", [](const py::sequence &seq) { return seq.size(); }); // Make sure that py::iterator works with std algorithms - m.def("count_none", [](py::object o) { + m.def("count_none", [](const py::object &o) { return std::count_if(o.begin(), o.end(), [](py::handle h) { return h.is_none(); }); }); - m.def("find_none", [](py::object o) { + m.def("find_none", [](const py::object &o) { auto it = std::find_if(o.begin(), o.end(), [](py::handle h) { return h.is_none(); }); return it->is_none(); }); - m.def("count_nonzeros", [](py::dict d) { - return std::count_if(d.begin(), d.end(), [](std::pair p) { - return p.second.cast() != 0; - }); + m.def("count_nonzeros", [](const py::dict &d) { + return std::count_if(d.begin(), d.end(), [](std::pair p) { + return p.second.cast() != 0; + }); }); m.def("tuple_iterator", &test_random_access_iterator); diff --git a/wrap/pybind11/tests/test_sequences_and_iterators.py b/wrap/pybind11/tests/test_sequences_and_iterators.py index 8f6c0c4bb..6985918a1 100644 --- a/wrap/pybind11/tests/test_sequences_and_iterators.py +++ b/wrap/pybind11/tests/test_sequences_and_iterators.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- import pytest -from pybind11_tests import sequences_and_iterators as m + from pybind11_tests import ConstructorStats +from pybind11_tests import sequences_and_iterators as m def isclose(a, b, rel_tol=1e-05, abs_tol=0.0): @@ -10,7 +11,20 @@ def isclose(a, b, rel_tol=1e-05, abs_tol=0.0): def allclose(a_list, b_list, rel_tol=1e-05, abs_tol=0.0): - return all(isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol) for a, b in zip(a_list, b_list)) + return all( + isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol) for a, b in zip(a_list, b_list) + ) + + +def test_slice_constructors(): + assert m.make_forward_slice_size_t() == slice(0, -1, 1) + assert m.make_reversed_slice_object() == slice(None, None, -1) + + +@pytest.mark.skipif(not m.has_optional, reason="no ") +def test_slice_constructors_explicit_optional(): + assert m.make_reversed_slice_size_t_optional() == slice(None, None, -1) + assert m.make_reversed_slice_size_t_optional_verbose() == slice(None, None, -1) def test_generalized_iterators(): @@ -22,6 +36,10 @@ def test_generalized_iterators(): assert list(m.IntPairs([(1, 2), (2, 0), (0, 3), (4, 5)]).nonzero_keys()) == [1] assert list(m.IntPairs([(0, 3), (1, 2), (3, 4)]).nonzero_keys()) == [] + assert list(m.IntPairs([(1, 2), (3, 4), (0, 5)]).nonzero_values()) == [2, 4] + assert list(m.IntPairs([(1, 2), (2, 0), (0, 3), (4, 5)]).nonzero_values()) == [2] + assert list(m.IntPairs([(0, 3), (1, 2), (3, 4)]).nonzero_values()) == [] + # __next__ must continue to raise StopIteration it = m.IntPairs([(0, 0)]).nonzero() for _ in range(3): @@ -34,6 +52,47 @@ def test_generalized_iterators(): next(it) +def test_nonref_iterators(): + pairs = m.IntPairs([(1, 2), (3, 4), (0, 5)]) + assert list(pairs.nonref()) == [(1, 2), (3, 4), (0, 5)] + assert list(pairs.nonref_keys()) == [1, 3, 0] + assert list(pairs.nonref_values()) == [2, 4, 5] + + +def test_generalized_iterators_simple(): + assert list(m.IntPairs([(1, 2), (3, 4), (0, 5)]).simple_iterator()) == [ + (1, 2), + (3, 4), + (0, 5), + ] + assert list(m.IntPairs([(1, 2), (3, 4), (0, 5)]).simple_keys()) == [1, 3, 0] + assert list(m.IntPairs([(1, 2), (3, 4), (0, 5)]).simple_values()) == [2, 4, 5] + + +def test_iterator_referencing(): + """Test that iterators reference rather than copy their referents.""" + vec = m.VectorNonCopyableInt() + vec.append(3) + vec.append(5) + assert [int(x) for x in vec] == [3, 5] + # Increment everything to make sure the referents can be mutated + for x in vec: + x.set(int(x) + 1) + assert [int(x) for x in vec] == [4, 6] + + vec = m.VectorNonCopyableIntPair() + vec.append([3, 4]) + vec.append([5, 7]) + assert [int(x) for x in vec.keys()] == [3, 5] + assert [int(x) for x in vec.values()] == [4, 7] + for x in vec.keys(): + x.set(int(x) + 1) + for x in vec.values(): + x.set(int(x) + 10) + assert [int(x) for x in vec.keys()] == [4, 6] + assert [int(x) for x in vec.values()] == [14, 17] + + def test_sliceable(): sliceable = m.Sliceable(100) assert sliceable[::] == (0, 100, 1) @@ -51,7 +110,7 @@ def test_sequence(): cstats = ConstructorStats.get(m.Sequence) s = m.Sequence(5) - assert cstats.values() == ['of size', '5'] + assert cstats.values() == ["of size", "5"] assert "Sequence" in repr(s) assert len(s) == 5 @@ -62,16 +121,16 @@ def test_sequence(): assert isclose(s[0], 12.34) and isclose(s[3], 56.78) rev = reversed(s) - assert cstats.values() == ['of size', '5'] + assert cstats.values() == ["of size", "5"] rev2 = s[::-1] - assert cstats.values() == ['of size', '5'] + assert cstats.values() == ["of size", "5"] it = iter(m.Sequence(0)) for _ in range(3): # __next__ must continue to raise StopIteration with pytest.raises(StopIteration): next(it) - assert cstats.values() == ['of size', '0'] + assert cstats.values() == ["of size", "0"] expected = [0, 56.78, 0, 0, 12.34] assert allclose(rev, expected) @@ -79,7 +138,7 @@ def test_sequence(): assert rev == rev2 rev[0::2] = m.Sequence([2.0, 2.0, 2.0]) - assert cstats.values() == ['of size', '3', 'from std::vector'] + assert cstats.values() == ["of size", "3", "from std::vector"] assert allclose(rev, [2, 56.78, 2, 0, 2]) @@ -102,11 +161,12 @@ def test_sequence(): def test_sequence_length(): - """#2076: Exception raised by len(arg) should be propagated """ + """#2076: Exception raised by len(arg) should be propagated""" + class BadLen(RuntimeError): pass - class SequenceLike(): + class SequenceLike: def __getitem__(self, i): return None @@ -121,21 +181,22 @@ def test_sequence_length(): def test_map_iterator(): - sm = m.StringMap({'hi': 'bye', 'black': 'white'}) - assert sm['hi'] == 'bye' + sm = m.StringMap({"hi": "bye", "black": "white"}) + assert sm["hi"] == "bye" assert len(sm) == 2 - assert sm['black'] == 'white' + assert sm["black"] == "white" with pytest.raises(KeyError): - assert sm['orange'] - sm['orange'] = 'banana' - assert sm['orange'] == 'banana' + assert sm["orange"] + sm["orange"] = "banana" + assert sm["orange"] == "banana" - expected = {'hi': 'bye', 'black': 'white', 'orange': 'banana'} + expected = {"hi": "bye", "black": "white", "orange": "banana"} for k in sm: assert sm[k] == expected[k] for k, v in sm.items(): assert v == expected[k] + assert list(sm.values()) == [expected[k] for k in sm] it = iter(m.StringMap({})) for _ in range(3): # __next__ must continue to raise StopIteration @@ -179,11 +240,12 @@ def test_iterator_passthrough(): """#181: iterator passthrough did not compile""" from pybind11_tests.sequences_and_iterators import iterator_passthrough - assert list(iterator_passthrough(iter([3, 5, 7, 9, 11, 13, 15]))) == [3, 5, 7, 9, 11, 13, 15] + values = [3, 5, 7, 9, 11, 13, 15] + assert list(iterator_passthrough(iter(values))) == values def test_iterator_rvp(): - """#388: Can't make iterators via make_iterator() with different r/v policies """ + """#388: Can't make iterators via make_iterator() with different r/v policies""" import pybind11_tests.sequences_and_iterators as m assert list(m.make_iterator_1()) == [1, 2, 3] diff --git a/wrap/pybind11/tests/test_smart_ptr.cpp b/wrap/pybind11/tests/test_smart_ptr.cpp index 60c2e692e..94f04330a 100644 --- a/wrap/pybind11/tests/test_smart_ptr.cpp +++ b/wrap/pybind11/tests/test_smart_ptr.cpp @@ -8,30 +8,14 @@ BSD-style license that can be found in the LICENSE file. */ -#if defined(_MSC_VER) && _MSC_VER < 1910 -# pragma warning(disable: 4702) // unreachable code in system header +#if defined(_MSC_VER) && _MSC_VER < 1910 // VS 2015's MSVC +# pragma warning(disable: 4702) // unreachable code in system header (xatomic.h(382)) #endif #include "pybind11_tests.h" #include "object.h" -// Make pybind aware of the ref-counted wrapper type (s): - -// ref is a wrapper for 'Object' which uses intrusive reference counting -// It is always possible to construct a ref from an Object* pointer without -// possible inconsistencies, hence the 'true' argument at the end. -PYBIND11_DECLARE_HOLDER_TYPE(T, ref, true); -// Make pybind11 aware of the non-standard getter member function -namespace pybind11 { namespace detail { - template - struct holder_helper> { - static const T *get(const ref &p) { return p.get_ptr(); } - }; -} // namespace detail -} // namespace pybind11 - -// The following is not required anymore for std::shared_ptr, but it should compile without error: -PYBIND11_DECLARE_HOLDER_TYPE(T, std::shared_ptr); +namespace { // This is just a wrapper around unique_ptr, but with extra fields to deliberately bloat up the // holder size to trigger the non-simple-layout internal instance layout for single inheritance with @@ -40,21 +24,19 @@ template class huge_unique_ptr { std::unique_ptr ptr; uint64_t padding[10]; public: - huge_unique_ptr(T *p) : ptr(p) {}; + explicit huge_unique_ptr(T *p) : ptr(p) {} T *get() { return ptr.get(); } }; -PYBIND11_DECLARE_HOLDER_TYPE(T, huge_unique_ptr); // Simple custom holder that works like unique_ptr template class custom_unique_ptr { std::unique_ptr impl; public: - custom_unique_ptr(T* p) : impl(p) { } + explicit custom_unique_ptr(T *p) : impl(p) {} T* get() const { return impl.get(); } T* release_ptr() { return impl.release(); } }; -PYBIND11_DECLARE_HOLDER_TYPE(T, custom_unique_ptr); // Simple custom holder that works like shared_ptr and has operator& overload // To obtain address of an instance of this holder pybind should use std::addressof @@ -64,11 +46,10 @@ class shared_ptr_with_addressof_operator { std::shared_ptr impl; public: shared_ptr_with_addressof_operator( ) = default; - shared_ptr_with_addressof_operator(T* p) : impl(p) { } + explicit shared_ptr_with_addressof_operator(T *p) : impl(p) {} T* get() const { return impl.get(); } T** operator&() { throw std::logic_error("Call of overloaded operator& is not expected"); } }; -PYBIND11_DECLARE_HOLDER_TYPE(T, shared_ptr_with_addressof_operator); // Simple custom holder that works like unique_ptr and has operator& overload // To obtain address of an instance of this holder pybind should use std::addressof @@ -78,15 +59,226 @@ class unique_ptr_with_addressof_operator { std::unique_ptr impl; public: unique_ptr_with_addressof_operator() = default; - unique_ptr_with_addressof_operator(T* p) : impl(p) { } + explicit unique_ptr_with_addressof_operator(T *p) : impl(p) {} T* get() const { return impl.get(); } T* release_ptr() { return impl.release(); } T** operator&() { throw std::logic_error("Call of overloaded operator& is not expected"); } }; + +// Custom object with builtin reference counting (see 'object.h' for the implementation) +class MyObject1 : public Object { +public: + explicit MyObject1(int value) : value(value) { print_created(this, toString()); } + std::string toString() const override { return "MyObject1[" + std::to_string(value) + "]"; } +protected: + ~MyObject1() override { print_destroyed(this); } +private: + int value; +}; + +// Object managed by a std::shared_ptr<> +class MyObject2 { +public: + MyObject2(const MyObject2 &) = default; + explicit MyObject2(int value) : value(value) { print_created(this, toString()); } + std::string toString() const { return "MyObject2[" + std::to_string(value) + "]"; } + virtual ~MyObject2() { print_destroyed(this); } +private: + int value; +}; + +// Object managed by a std::shared_ptr<>, additionally derives from std::enable_shared_from_this<> +class MyObject3 : public std::enable_shared_from_this { +public: + MyObject3(const MyObject3 &) = default; + explicit MyObject3(int value) : value(value) { print_created(this, toString()); } + std::string toString() const { return "MyObject3[" + std::to_string(value) + "]"; } + virtual ~MyObject3() { print_destroyed(this); } +private: + int value; +}; + +// test_unique_nodelete +// Object with a private destructor +class MyObject4; +std::unordered_set myobject4_instances; +class MyObject4 { +public: + explicit MyObject4(int value) : value{value} { + print_created(this); + myobject4_instances.insert(this); + } + int value; + + static void cleanupAllInstances() { + auto tmp = std::move(myobject4_instances); + myobject4_instances.clear(); + for (auto o : tmp) + delete o; + } +private: + ~MyObject4() { + myobject4_instances.erase(this); + print_destroyed(this); + } +}; + +// test_unique_deleter +// Object with std::unique_ptr where D is not matching the base class +// Object with a protected destructor +class MyObject4a; +std::unordered_set myobject4a_instances; +class MyObject4a { +public: + explicit MyObject4a(int i) { + value = i; + print_created(this); + myobject4a_instances.insert(this); + }; + int value; + + static void cleanupAllInstances() { + auto tmp = std::move(myobject4a_instances); + myobject4a_instances.clear(); + for (auto o : tmp) + delete o; + } +protected: + virtual ~MyObject4a() { + myobject4a_instances.erase(this); + print_destroyed(this); + } +}; + +// Object derived but with public destructor and no Deleter in default holder +class MyObject4b : public MyObject4a { +public: + explicit MyObject4b(int i) : MyObject4a(i) { print_created(this); } + ~MyObject4b() override { print_destroyed(this); } +}; + +// test_large_holder +class MyObject5 { // managed by huge_unique_ptr +public: + explicit MyObject5(int value) : value{value} { print_created(this); } + ~MyObject5() { print_destroyed(this); } + int value; +}; + +// test_shared_ptr_and_references +struct SharedPtrRef { + struct A { + A() { print_created(this); } + A(const A &) { print_copy_created(this); } + A(A &&) noexcept { print_move_created(this); } + ~A() { print_destroyed(this); } + }; + + A value = {}; + std::shared_ptr shared = std::make_shared(); +}; + +// test_shared_ptr_from_this_and_references +struct SharedFromThisRef { + struct B : std::enable_shared_from_this { + B() { print_created(this); } + // NOLINTNEXTLINE(bugprone-copy-constructor-init) + B(const B &) : std::enable_shared_from_this() { print_copy_created(this); } + B(B &&) noexcept : std::enable_shared_from_this() { print_move_created(this); } + ~B() { print_destroyed(this); } + }; + + B value = {}; + std::shared_ptr shared = std::make_shared(); +}; + +// Issue #865: shared_from_this doesn't work with virtual inheritance +struct SharedFromThisVBase : std::enable_shared_from_this { + SharedFromThisVBase() = default; + SharedFromThisVBase(const SharedFromThisVBase &) = default; + virtual ~SharedFromThisVBase() = default; +}; +struct SharedFromThisVirt : virtual SharedFromThisVBase {}; + +// test_move_only_holder +struct C { + C() { print_created(this); } + ~C() { print_destroyed(this); } +}; + +// test_holder_with_addressof_operator +struct TypeForHolderWithAddressOf { + TypeForHolderWithAddressOf() { print_created(this); } + TypeForHolderWithAddressOf(const TypeForHolderWithAddressOf &) { print_copy_created(this); } + TypeForHolderWithAddressOf(TypeForHolderWithAddressOf &&) noexcept { + print_move_created(this); + } + ~TypeForHolderWithAddressOf() { print_destroyed(this); } + std::string toString() const { + return "TypeForHolderWithAddressOf[" + std::to_string(value) + "]"; + } + int value = 42; +}; + +// test_move_only_holder_with_addressof_operator +struct TypeForMoveOnlyHolderWithAddressOf { + explicit TypeForMoveOnlyHolderWithAddressOf(int value) : value{value} { print_created(this); } + ~TypeForMoveOnlyHolderWithAddressOf() { print_destroyed(this); } + std::string toString() const { + return "MoveOnlyHolderWithAddressOf[" + std::to_string(value) + "]"; + } + int value; +}; + +// test_smart_ptr_from_default +struct HeldByDefaultHolder { }; + +// test_shared_ptr_gc +// #187: issue involving std::shared_ptr<> return value policy & garbage collection +struct ElementBase { + virtual ~ElementBase() = default; /* Force creation of virtual table */ + ElementBase() = default; + ElementBase(const ElementBase&) = delete; +}; + +struct ElementA : ElementBase { + explicit ElementA(int v) : v(v) {} + int value() const { return v; } + int v; +}; + +struct ElementList { + void add(const std::shared_ptr &e) { l.push_back(e); } + std::vector> l; +}; + +} // namespace + +// ref is a wrapper for 'Object' which uses intrusive reference counting +// It is always possible to construct a ref from an Object* pointer without +// possible inconsistencies, hence the 'true' argument at the end. +// Make pybind11 aware of the non-standard getter member function +namespace pybind11 { namespace detail { + template + struct holder_helper> { + static const T *get(const ref &p) { return p.get_ptr(); } + }; +} // namespace detail +} // namespace pybind11 + +// Make pybind aware of the ref-counted wrapper type (s): +PYBIND11_DECLARE_HOLDER_TYPE(T, ref, true); +// The following is not required anymore for std::shared_ptr, but it should compile without error: +PYBIND11_DECLARE_HOLDER_TYPE(T, std::shared_ptr); +PYBIND11_DECLARE_HOLDER_TYPE(T, huge_unique_ptr); +PYBIND11_DECLARE_HOLDER_TYPE(T, custom_unique_ptr); +PYBIND11_DECLARE_HOLDER_TYPE(T, shared_ptr_with_addressof_operator); PYBIND11_DECLARE_HOLDER_TYPE(T, unique_ptr_with_addressof_operator); - TEST_SUBMODULE(smart_ptr, m) { + // Please do not interleave `struct` and `class` definitions with bindings code, + // but implement `struct`s and `class`es in the anonymous namespace above. + // This helps keeping the smart_holder branch in sync with master. // test_smart_ptr @@ -94,24 +286,14 @@ TEST_SUBMODULE(smart_ptr, m) { py::class_> obj(m, "Object"); obj.def("getRefCount", &Object::getRefCount); - // Custom object with builtin reference counting (see 'object.h' for the implementation) - class MyObject1 : public Object { - public: - MyObject1(int value) : value(value) { print_created(this, toString()); } - std::string toString() const override { return "MyObject1[" + std::to_string(value) + "]"; } - protected: - ~MyObject1() override { print_destroyed(this); } - private: - int value; - }; py::class_>(m, "MyObject1", obj) .def(py::init()); py::implicitly_convertible(); m.def("make_object_1", []() -> Object * { return new MyObject1(1); }); - m.def("make_object_2", []() -> ref { return new MyObject1(2); }); + m.def("make_object_2", []() -> ref { return ref(new MyObject1(2)); }); m.def("make_myobject1_1", []() -> MyObject1 * { return new MyObject1(4); }); - m.def("make_myobject1_2", []() -> ref { return new MyObject1(5); }); + m.def("make_myobject1_2", []() -> ref { return ref(new MyObject1(5)); }); m.def("print_object_1", [](const Object *obj) { py::print(obj->toString()); }); m.def("print_object_2", [](ref obj) { py::print(obj->toString()); }); m.def("print_object_3", [](const ref &obj) { py::print(obj->toString()); }); @@ -124,48 +306,29 @@ TEST_SUBMODULE(smart_ptr, m) { // Expose constructor stats for the ref type m.def("cstats_ref", &ConstructorStats::get); - - // Object managed by a std::shared_ptr<> - class MyObject2 { - public: - MyObject2(const MyObject2 &) = default; - MyObject2(int value) : value(value) { print_created(this, toString()); } - std::string toString() const { return "MyObject2[" + std::to_string(value) + "]"; } - virtual ~MyObject2() { print_destroyed(this); } - private: - int value; - }; py::class_>(m, "MyObject2") .def(py::init()); m.def("make_myobject2_1", []() { return new MyObject2(6); }); m.def("make_myobject2_2", []() { return std::make_shared(7); }); m.def("print_myobject2_1", [](const MyObject2 *obj) { py::print(obj->toString()); }); + // NOLINTNEXTLINE(performance-unnecessary-value-param) m.def("print_myobject2_2", [](std::shared_ptr obj) { py::print(obj->toString()); }); m.def("print_myobject2_3", [](const std::shared_ptr &obj) { py::print(obj->toString()); }); m.def("print_myobject2_4", [](const std::shared_ptr *obj) { py::print((*obj)->toString()); }); - // Object managed by a std::shared_ptr<>, additionally derives from std::enable_shared_from_this<> - class MyObject3 : public std::enable_shared_from_this { - public: - MyObject3(const MyObject3 &) = default; - MyObject3(int value) : value(value) { print_created(this, toString()); } - std::string toString() const { return "MyObject3[" + std::to_string(value) + "]"; } - virtual ~MyObject3() { print_destroyed(this); } - private: - int value; - }; py::class_>(m, "MyObject3") .def(py::init()); m.def("make_myobject3_1", []() { return new MyObject3(8); }); m.def("make_myobject3_2", []() { return std::make_shared(9); }); m.def("print_myobject3_1", [](const MyObject3 *obj) { py::print(obj->toString()); }); + // NOLINTNEXTLINE(performance-unnecessary-value-param) m.def("print_myobject3_2", [](std::shared_ptr obj) { py::print(obj->toString()); }); m.def("print_myobject3_3", [](const std::shared_ptr &obj) { py::print(obj->toString()); }); m.def("print_myobject3_4", [](const std::shared_ptr *obj) { py::print((*obj)->toString()); }); // test_smart_ptr_refcounting m.def("test_object1_refcounting", []() { - ref o = new MyObject1(0); + auto o = ref(new MyObject1(0)); bool good = o->getRefCount() == 1; py::object o2 = py::cast(o, py::return_value_policy::reference); // always request (partial) ownership for objects with intrusive @@ -175,155 +338,88 @@ TEST_SUBMODULE(smart_ptr, m) { }); // test_unique_nodelete - // Object with a private destructor - class MyObject4 { - public: - MyObject4(int value) : value{value} { print_created(this); } - int value; - private: - ~MyObject4() { print_destroyed(this); } - }; py::class_>(m, "MyObject4") .def(py::init()) - .def_readwrite("value", &MyObject4::value); + .def_readwrite("value", &MyObject4::value) + .def_static("cleanup_all_instances", &MyObject4::cleanupAllInstances); // test_unique_deleter - // Object with std::unique_ptr where D is not matching the base class - // Object with a protected destructor - class MyObject4a { - public: - MyObject4a(int i) { - value = i; - print_created(this); - }; - int value; - protected: - virtual ~MyObject4a() { print_destroyed(this); } - }; py::class_>(m, "MyObject4a") .def(py::init()) - .def_readwrite("value", &MyObject4a::value); + .def_readwrite("value", &MyObject4a::value) + .def_static("cleanup_all_instances", &MyObject4a::cleanupAllInstances); - // Object derived but with public destructor and no Deleter in default holder - class MyObject4b : public MyObject4a { - public: - MyObject4b(int i) : MyObject4a(i) { print_created(this); } - ~MyObject4b() override { print_destroyed(this); } - }; - py::class_(m, "MyObject4b") + py::class_>(m, "MyObject4b") .def(py::init()); // test_large_holder - class MyObject5 { // managed by huge_unique_ptr - public: - MyObject5(int value) : value{value} { print_created(this); } - ~MyObject5() { print_destroyed(this); } - int value; - }; py::class_>(m, "MyObject5") .def(py::init()) .def_readwrite("value", &MyObject5::value); // test_shared_ptr_and_references - struct SharedPtrRef { - struct A { - A() { print_created(this); } - A(const A &) { print_copy_created(this); } - A(A &&) { print_move_created(this); } - ~A() { print_destroyed(this); } - }; - - A value = {}; - std::shared_ptr shared = std::make_shared(); - }; using A = SharedPtrRef::A; py::class_>(m, "A"); - py::class_(m, "SharedPtrRef") + py::class_>(m, "SharedPtrRef") .def(py::init<>()) .def_readonly("ref", &SharedPtrRef::value) - .def_property_readonly("copy", [](const SharedPtrRef &s) { return s.value; }, - py::return_value_policy::copy) + .def_property_readonly( + "copy", [](const SharedPtrRef &s) { return s.value; }, py::return_value_policy::copy) .def_readonly("holder_ref", &SharedPtrRef::shared) - .def_property_readonly("holder_copy", [](const SharedPtrRef &s) { return s.shared; }, - py::return_value_policy::copy) + .def_property_readonly( + "holder_copy", + [](const SharedPtrRef &s) { return s.shared; }, + py::return_value_policy::copy) .def("set_ref", [](SharedPtrRef &, const A &) { return true; }) + // NOLINTNEXTLINE(performance-unnecessary-value-param) .def("set_holder", [](SharedPtrRef &, std::shared_ptr) { return true; }); // test_shared_ptr_from_this_and_references - struct SharedFromThisRef { - struct B : std::enable_shared_from_this { - B() { print_created(this); } - B(const B &) : std::enable_shared_from_this() { print_copy_created(this); } - B(B &&) : std::enable_shared_from_this() { print_move_created(this); } - ~B() { print_destroyed(this); } - }; - - B value = {}; - std::shared_ptr shared = std::make_shared(); - }; using B = SharedFromThisRef::B; py::class_>(m, "B"); - py::class_(m, "SharedFromThisRef") + py::class_>(m, "SharedFromThisRef") .def(py::init<>()) .def_readonly("bad_wp", &SharedFromThisRef::value) - .def_property_readonly("ref", [](const SharedFromThisRef &s) -> const B & { return *s.shared; }) - .def_property_readonly("copy", [](const SharedFromThisRef &s) { return s.value; }, - py::return_value_policy::copy) + .def_property_readonly("ref", + [](const SharedFromThisRef &s) -> const B & { return *s.shared; }) + .def_property_readonly( + "copy", + [](const SharedFromThisRef &s) { return s.value; }, + py::return_value_policy::copy) .def_readonly("holder_ref", &SharedFromThisRef::shared) - .def_property_readonly("holder_copy", [](const SharedFromThisRef &s) { return s.shared; }, - py::return_value_policy::copy) + .def_property_readonly( + "holder_copy", + [](const SharedFromThisRef &s) { return s.shared; }, + py::return_value_policy::copy) .def("set_ref", [](SharedFromThisRef &, const B &) { return true; }) + // NOLINTNEXTLINE(performance-unnecessary-value-param) .def("set_holder", [](SharedFromThisRef &, std::shared_ptr) { return true; }); // Issue #865: shared_from_this doesn't work with virtual inheritance - struct SharedFromThisVBase : std::enable_shared_from_this { - SharedFromThisVBase() = default; - SharedFromThisVBase(const SharedFromThisVBase &) = default; - virtual ~SharedFromThisVBase() = default; - }; - struct SharedFromThisVirt : virtual SharedFromThisVBase {}; static std::shared_ptr sft(new SharedFromThisVirt()); py::class_>(m, "SharedFromThisVirt") .def_static("get", []() { return sft.get(); }); // test_move_only_holder - struct C { - C() { print_created(this); } - ~C() { print_destroyed(this); } - }; py::class_>(m, "TypeWithMoveOnlyHolder") .def_static("make", []() { return custom_unique_ptr(new C); }) .def_static("make_as_object", []() { return py::cast(custom_unique_ptr(new C)); }); // test_holder_with_addressof_operator - struct TypeForHolderWithAddressOf { - TypeForHolderWithAddressOf() { print_created(this); } - TypeForHolderWithAddressOf(const TypeForHolderWithAddressOf &) { print_copy_created(this); } - TypeForHolderWithAddressOf(TypeForHolderWithAddressOf &&) { print_move_created(this); } - ~TypeForHolderWithAddressOf() { print_destroyed(this); } - std::string toString() const { - return "TypeForHolderWithAddressOf[" + std::to_string(value) + "]"; - } - int value = 42; - }; using HolderWithAddressOf = shared_ptr_with_addressof_operator; py::class_(m, "TypeForHolderWithAddressOf") .def_static("make", []() { return HolderWithAddressOf(new TypeForHolderWithAddressOf); }) .def("get", [](const HolderWithAddressOf &self) { return self.get(); }) - .def("print_object_1", [](const TypeForHolderWithAddressOf *obj) { py::print(obj->toString()); }) + .def("print_object_1", + [](const TypeForHolderWithAddressOf *obj) { py::print(obj->toString()); }) + // NOLINTNEXTLINE(performance-unnecessary-value-param) .def("print_object_2", [](HolderWithAddressOf obj) { py::print(obj.get()->toString()); }) - .def("print_object_3", [](const HolderWithAddressOf &obj) { py::print(obj.get()->toString()); }) - .def("print_object_4", [](const HolderWithAddressOf *obj) { py::print((*obj).get()->toString()); }); + .def("print_object_3", + [](const HolderWithAddressOf &obj) { py::print(obj.get()->toString()); }) + .def("print_object_4", + [](const HolderWithAddressOf *obj) { py::print((*obj).get()->toString()); }); // test_move_only_holder_with_addressof_operator - struct TypeForMoveOnlyHolderWithAddressOf { - TypeForMoveOnlyHolderWithAddressOf(int value) : value{value} { print_created(this); } - ~TypeForMoveOnlyHolderWithAddressOf() { print_destroyed(this); } - std::string toString() const { - return "MoveOnlyHolderWithAddressOf[" + std::to_string(value) + "]"; - } - int value; - }; using MoveOnlyHolderWithAddressOf = unique_ptr_with_addressof_operator; py::class_(m, "TypeForMoveOnlyHolderWithAddressOf") .def_static("make", []() { return MoveOnlyHolderWithAddressOf(new TypeForMoveOnlyHolderWithAddressOf(0)); }) @@ -331,33 +427,19 @@ TEST_SUBMODULE(smart_ptr, m) { .def("print_object", [](const TypeForMoveOnlyHolderWithAddressOf *obj) { py::print(obj->toString()); }); // test_smart_ptr_from_default - struct HeldByDefaultHolder { }; - py::class_(m, "HeldByDefaultHolder") + py::class_>(m, "HeldByDefaultHolder") .def(py::init<>()) + // NOLINTNEXTLINE(performance-unnecessary-value-param) .def_static("load_shared_ptr", [](std::shared_ptr) {}); // test_shared_ptr_gc // #187: issue involving std::shared_ptr<> return value policy & garbage collection - struct ElementBase { - virtual ~ElementBase() = default; /* Force creation of virtual table */ - ElementBase() = default; - ElementBase(const ElementBase&) = delete; - }; py::class_>(m, "ElementBase"); - struct ElementA : ElementBase { - ElementA(int v) : v(v) { } - int value() { return v; } - int v; - }; py::class_>(m, "ElementA") .def(py::init()) .def("value", &ElementA::value); - struct ElementList { - void add(std::shared_ptr e) { l.push_back(e); } - std::vector> l; - }; py::class_>(m, "ElementList") .def(py::init<>()) .def("add", &ElementList::add) diff --git a/wrap/pybind11/tests/test_smart_ptr.py b/wrap/pybind11/tests/test_smart_ptr.py index 0b1ca45b5..85f61a322 100644 --- a/wrap/pybind11/tests/test_smart_ptr.py +++ b/wrap/pybind11/tests/test_smart_ptr.py @@ -7,7 +7,9 @@ from pybind11_tests import ConstructorStats # noqa: E402 def test_smart_ptr(capture): # Object1 - for i, o in enumerate([m.make_object_1(), m.make_object_2(), m.MyObject1(3)], start=1): + for i, o in enumerate( + [m.make_object_1(), m.make_object_2(), m.MyObject1(3)], start=1 + ): assert o.getRefCount() == 1 with capture: m.print_object_1(o) @@ -16,8 +18,9 @@ def test_smart_ptr(capture): m.print_object_4(o) assert capture == "MyObject1[{i}]\n".format(i=i) * 4 - for i, o in enumerate([m.make_myobject1_1(), m.make_myobject1_2(), m.MyObject1(6), 7], - start=4): + for i, o in enumerate( + [m.make_myobject1_1(), m.make_myobject1_2(), m.MyObject1(6), 7], start=4 + ): print(o) with capture: if not isinstance(o, int): @@ -29,11 +32,15 @@ def test_smart_ptr(capture): m.print_myobject1_2(o) m.print_myobject1_3(o) m.print_myobject1_4(o) - assert capture == "MyObject1[{i}]\n".format(i=i) * (4 if isinstance(o, int) else 8) + + times = 4 if isinstance(o, int) else 8 + assert capture == "MyObject1[{i}]\n".format(i=i) * times cstats = ConstructorStats.get(m.MyObject1) assert cstats.alive() == 0 - expected_values = ['MyObject1[{}]'.format(i) for i in range(1, 7)] + ['MyObject1[7]'] * 4 + expected_values = ["MyObject1[{}]".format(i) for i in range(1, 7)] + [ + "MyObject1[7]" + ] * 4 assert cstats.values() == expected_values assert cstats.default_constructions == 0 assert cstats.copy_constructions == 0 @@ -42,7 +49,9 @@ def test_smart_ptr(capture): assert cstats.move_assignments == 0 # Object2 - for i, o in zip([8, 6, 7], [m.MyObject2(8), m.make_myobject2_1(), m.make_myobject2_2()]): + for i, o in zip( + [8, 6, 7], [m.MyObject2(8), m.make_myobject2_1(), m.make_myobject2_2()] + ): print(o) with capture: m.print_myobject2_1(o) @@ -55,7 +64,7 @@ def test_smart_ptr(capture): assert cstats.alive() == 1 o = None assert cstats.alive() == 0 - assert cstats.values() == ['MyObject2[8]', 'MyObject2[6]', 'MyObject2[7]'] + assert cstats.values() == ["MyObject2[8]", "MyObject2[6]", "MyObject2[7]"] assert cstats.default_constructions == 0 assert cstats.copy_constructions == 0 # assert cstats.move_constructions >= 0 # Doesn't invoke any @@ -63,7 +72,9 @@ def test_smart_ptr(capture): assert cstats.move_assignments == 0 # Object3 - for i, o in zip([9, 8, 9], [m.MyObject3(9), m.make_myobject3_1(), m.make_myobject3_2()]): + for i, o in zip( + [9, 8, 9], [m.MyObject3(9), m.make_myobject3_1(), m.make_myobject3_2()] + ): print(o) with capture: m.print_myobject3_1(o) @@ -76,7 +87,7 @@ def test_smart_ptr(capture): assert cstats.alive() == 1 o = None assert cstats.alive() == 0 - assert cstats.values() == ['MyObject3[9]', 'MyObject3[8]', 'MyObject3[9]'] + assert cstats.values() == ["MyObject3[9]", "MyObject3[8]", "MyObject3[9]"] assert cstats.default_constructions == 0 assert cstats.copy_constructions == 0 # assert cstats.move_constructions >= 0 # Doesn't invoke any @@ -96,7 +107,7 @@ def test_smart_ptr(capture): # ref<> cstats = m.cstats_ref() assert cstats.alive() == 0 - assert cstats.values() == ['from pointer'] * 10 + assert cstats.values() == ["from pointer"] * 10 assert cstats.default_constructions == 30 assert cstats.copy_constructions == 12 # assert cstats.move_constructions >= 0 # Doesn't invoke any @@ -114,7 +125,9 @@ def test_unique_nodelete(): cstats = ConstructorStats.get(m.MyObject4) assert cstats.alive() == 1 del o - assert cstats.alive() == 1 # Leak, but that's intentional + assert cstats.alive() == 1 + m.MyObject4.cleanup_all_instances() + assert cstats.alive() == 0 def test_unique_nodelete4a(): @@ -123,19 +136,25 @@ def test_unique_nodelete4a(): cstats = ConstructorStats.get(m.MyObject4a) assert cstats.alive() == 1 del o - assert cstats.alive() == 1 # Leak, but that's intentional + assert cstats.alive() == 1 + m.MyObject4a.cleanup_all_instances() + assert cstats.alive() == 0 def test_unique_deleter(): + m.MyObject4a(0) o = m.MyObject4b(23) assert o.value == 23 cstats4a = ConstructorStats.get(m.MyObject4a) - assert cstats4a.alive() == 2 # Two because of previous test + assert cstats4a.alive() == 2 cstats4b = ConstructorStats.get(m.MyObject4b) assert cstats4b.alive() == 1 del o - assert cstats4a.alive() == 1 # Should now only be one leftover from previous test + assert cstats4a.alive() == 1 # Should now only be one leftover assert cstats4b.alive() == 0 # Should be deleted + m.MyObject4a.cleanup_all_instances() + assert cstats4a.alive() == 0 + assert cstats4b.alive() == 0 def test_large_holder(): @@ -186,7 +205,9 @@ def test_shared_ptr_from_this_and_references(): ref = s.ref # init_holder_helper(holder_ptr=false, owned=false, bad_wp=false) assert stats.alive() == 2 assert s.set_ref(ref) - assert s.set_holder(ref) # std::enable_shared_from_this can create a holder from a reference + assert s.set_holder( + ref + ) # std::enable_shared_from_this can create a holder from a reference bad_wp = s.bad_wp # init_holder_helper(holder_ptr=false, owned=false, bad_wp=true) assert stats.alive() == 2 @@ -200,12 +221,16 @@ def test_shared_ptr_from_this_and_references(): assert s.set_ref(copy) assert s.set_holder(copy) - holder_ref = s.holder_ref # init_holder_helper(holder_ptr=true, owned=false, bad_wp=false) + holder_ref = ( + s.holder_ref + ) # init_holder_helper(holder_ptr=true, owned=false, bad_wp=false) assert stats.alive() == 3 assert s.set_ref(holder_ref) assert s.set_holder(holder_ref) - holder_copy = s.holder_copy # init_holder_helper(holder_ptr=true, owned=true, bad_wp=false) + holder_copy = ( + s.holder_copy + ) # init_holder_helper(holder_ptr=true, owned=true, bad_wp=false) assert stats.alive() == 3 assert s.set_ref(holder_copy) assert s.set_holder(holder_copy) @@ -277,8 +302,10 @@ def test_smart_ptr_from_default(): instance = m.HeldByDefaultHolder() with pytest.raises(RuntimeError) as excinfo: m.HeldByDefaultHolder.load_shared_ptr(instance) - assert "Unable to load a custom holder type from a " \ - "default-holder instance" in str(excinfo.value) + assert ( + "Unable to load a custom holder type from a " + "default-holder instance" in str(excinfo.value) + ) def test_shared_ptr_gc(): diff --git a/wrap/pybind11/tests/test_stl.cpp b/wrap/pybind11/tests/test_stl.cpp index 059016277..bc5c6553a 100644 --- a/wrap/pybind11/tests/test_stl.cpp +++ b/wrap/pybind11/tests/test_stl.cpp @@ -11,9 +11,26 @@ #include "constructor_stats.h" #include +#ifndef PYBIND11_HAS_FILESYSTEM_IS_OPTIONAL +#define PYBIND11_HAS_FILESYSTEM_IS_OPTIONAL +#endif +#include + #include #include +#if defined(PYBIND11_TEST_BOOST) +#include + +namespace pybind11 { namespace detail { +template +struct type_caster> : optional_caster> {}; + +template <> +struct type_caster : void_caster {}; +}} // namespace pybind11::detail +#endif + // Test with `std::variant` in C++17 mode, or with `boost::variant` in C++11/14 #if defined(PYBIND11_HAS_VARIANT) using std::variant; @@ -40,7 +57,8 @@ PYBIND11_MAKE_OPAQUE(std::vector>); /// Issue #528: templated constructor struct TplCtorClass { - template TplCtorClass(const T &) { } + template + explicit TplCtorClass(const T &) {} bool operator==(const TplCtorClass &) const { return true; } }; @@ -53,7 +71,8 @@ namespace std { template