diff --git a/.github/scripts/python.sh b/.github/scripts/python.sh index 3f5701281..0855dbc21 100644 --- a/.github/scripts/python.sh +++ b/.github/scripts/python.sh @@ -75,7 +75,7 @@ cmake $GITHUB_WORKSPACE -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \ -DGTSAM_UNSTABLE_BUILD_PYTHON=${GTSAM_BUILD_UNSTABLE:-ON} \ -DGTSAM_PYTHON_VERSION=$PYTHON_VERSION \ -DPYTHON_EXECUTABLE:FILEPATH=$(which $PYTHON) \ - -DGTSAM_ALLOW_DEPRECATED_SINCE_V41=OFF \ + -DGTSAM_ALLOW_DEPRECATED_SINCE_V42=OFF \ -DCMAKE_INSTALL_PREFIX=$GITHUB_WORKSPACE/gtsam_install @@ -83,6 +83,6 @@ cmake $GITHUB_WORKSPACE -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \ make -j2 install cd $GITHUB_WORKSPACE/build/python -$PYTHON setup.py install --user --prefix= +$PYTHON -m pip install --user . cd $GITHUB_WORKSPACE/python/gtsam/tests $PYTHON -m unittest discover -v diff --git a/.github/scripts/unix.sh b/.github/scripts/unix.sh index 9689d346c..d890577b6 100644 --- a/.github/scripts/unix.sh +++ b/.github/scripts/unix.sh @@ -64,7 +64,7 @@ function configure() -DGTSAM_BUILD_UNSTABLE=${GTSAM_BUILD_UNSTABLE:-ON} \ -DGTSAM_WITH_TBB=${GTSAM_WITH_TBB:-OFF} \ -DGTSAM_BUILD_EXAMPLES_ALWAYS=${GTSAM_BUILD_EXAMPLES_ALWAYS:-ON} \ - -DGTSAM_ALLOW_DEPRECATED_SINCE_V41=${GTSAM_ALLOW_DEPRECATED_SINCE_V41:-OFF} \ + -DGTSAM_ALLOW_DEPRECATED_SINCE_V42=${GTSAM_ALLOW_DEPRECATED_SINCE_V42:-OFF} \ -DGTSAM_USE_QUATERNIONS=${GTSAM_USE_QUATERNIONS:-OFF} \ -DGTSAM_ROT3_EXPMAP=${GTSAM_ROT3_EXPMAP:-ON} \ -DGTSAM_POSE3_EXPMAP=${GTSAM_POSE3_EXPMAP:-ON} \ diff --git a/.github/workflows/build-linux.yml b/.github/workflows/build-linux.yml index f52e5eec3..7b13b6646 100644 --- a/.github/workflows/build-linux.yml +++ b/.github/workflows/build-linux.yml @@ -15,7 +15,7 @@ jobs: BOOST_VERSION: 1.67.0 strategy: - fail-fast: false + fail-fast: true matrix: # Github Actions requires a single row to be added to the build matrix. # See https://help.github.com/en/articles/workflow-syntax-for-github-actions. diff --git a/.github/workflows/build-special.yml b/.github/workflows/build-special.yml index 647b9c0f1..d357b9a34 100644 --- a/.github/workflows/build-special.yml +++ b/.github/workflows/build-special.yml @@ -110,7 +110,7 @@ jobs: - name: Set Allow Deprecated Flag if: matrix.flag == 'deprecated' run: | - echo "GTSAM_ALLOW_DEPRECATED_SINCE_V41=ON" >> $GITHUB_ENV + echo "GTSAM_ALLOW_DEPRECATED_SINCE_V42=ON" >> $GITHUB_ENV echo "Allow deprecated since version 4.1" - name: Set Use Quaternions Flag diff --git a/.github/workflows/build-windows.yml b/.github/workflows/build-windows.yml index 5dfdcd013..a9e794b3f 100644 --- a/.github/workflows/build-windows.yml +++ b/.github/workflows/build-windows.yml @@ -26,7 +26,11 @@ jobs: windows-2019-cl, ] - build_type: [Debug, Release] + build_type: [ + Debug, + #TODO(Varun) The release build takes over 2.5 hours, need to figure out why. + # Release + ] build_unstable: [ON] include: #TODO This build fails, need to understand why. @@ -90,13 +94,18 @@ jobs: - name: Checkout uses: actions/checkout@v2 - - name: Build + - name: Configuration run: | cmake -E remove_directory build cmake -B build -S . -DGTSAM_BUILD_EXAMPLES_ALWAYS=OFF -DBOOST_ROOT="${env:BOOST_ROOT}" -DBOOST_INCLUDEDIR="${env:BOOST_ROOT}\boost\include" -DBOOST_LIBRARYDIR="${env:BOOST_ROOT}\lib" - cmake --build build --config ${{ matrix.build_type }} --target gtsam - cmake --build build --config ${{ matrix.build_type }} --target gtsam_unstable - cmake --build build --config ${{ matrix.build_type }} --target wrap - cmake --build build --config ${{ matrix.build_type }} --target check.base - cmake --build build --config ${{ matrix.build_type }} --target check.base_unstable - cmake --build build --config ${{ matrix.build_type }} --target check.linear + + - name: Build + run: | + # Since Visual Studio is a multi-generator, we need to use --config + # https://stackoverflow.com/a/24470998/1236990 + cmake --build build -j 4 --config ${{ matrix.build_type }} --target gtsam + cmake --build build -j 4 --config ${{ matrix.build_type }} --target gtsam_unstable + cmake --build build -j 4 --config ${{ matrix.build_type }} --target wrap + cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.base + cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.base_unstable + cmake --build build -j 4 --config ${{ matrix.build_type }} --target check.linear diff --git a/.gitignore b/.gitignore index cde059767..e6e38132f 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ .idea *.pyc *.DS_Store +*.swp /examples/Data/dubrovnik-3-7-pre-rewritten.txt /examples/Data/pose2example-rewritten.txt /examples/Data/pose3example-rewritten.txt diff --git a/CMakeLists.txt b/CMakeLists.txt index b8480867e..a79e812ef 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,12 +9,18 @@ endif() # Set the version number for the library set (GTSAM_VERSION_MAJOR 4) -set (GTSAM_VERSION_MINOR 1) +set (GTSAM_VERSION_MINOR 2) set (GTSAM_VERSION_PATCH 0) +set (GTSAM_PRERELEASE_VERSION "a4") math (EXPR GTSAM_VERSION_NUMERIC "10000 * ${GTSAM_VERSION_MAJOR} + 100 * ${GTSAM_VERSION_MINOR} + ${GTSAM_VERSION_PATCH}") -set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}.${GTSAM_VERSION_PATCH}") -set (CMAKE_PROJECT_VERSION ${GTSAM_VERSION_STRING}) +if (${GTSAM_VERSION_PATCH} EQUAL 0) + set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}${GTSAM_PRERELEASE_VERSION}") +else() + set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}.${GTSAM_VERSION_PATCH}${GTSAM_PRERELEASE_VERSION}") +endif() +message(STATUS "GTSAM Version: ${GTSAM_VERSION_STRING}") + set (CMAKE_PROJECT_VERSION_MAJOR ${GTSAM_VERSION_MAJOR}) set (CMAKE_PROJECT_VERSION_MINOR ${GTSAM_VERSION_MINOR}) set (CMAKE_PROJECT_VERSION_PATCH ${GTSAM_VERSION_PATCH}) @@ -87,6 +93,13 @@ if(GTSAM_BUILD_PYTHON OR GTSAM_INSTALL_MATLAB_TOOLBOX) CACHE STRING "The Python version to use for wrapping") # Set the include directory for matlab.h set(GTWRAP_INCLUDE_NAME "wrap") + + # Copy matlab.h to the correct folder. + configure_file(${PROJECT_SOURCE_DIR}/wrap/matlab.h + ${PROJECT_BINARY_DIR}/wrap/matlab.h COPYONLY) + # Add the include directories so that matlab.h can be found + include_directories("${PROJECT_BINARY_DIR}" "${GTSAM_EIGEN_INCLUDE_FOR_BUILD}") + add_subdirectory(wrap) list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/wrap/cmake") endif() diff --git a/README.md b/README.md index 046132301..52ac0a5d8 100644 --- a/README.md +++ b/README.md @@ -2,9 +2,9 @@ **Important Note** -As of August 1 2020, the `develop` branch is officially in "Pre 4.1" mode, and features deprecated in 4.0 have been removed. Please use the last [4.0.3 release](https://github.com/borglab/gtsam/releases/tag/4.0.3) if you need those features. +As of Dec 2021, the `develop` branch is officially in "Pre 4.2" mode. A great new feature we will be adding in 4.2 is *hybrid inference* a la DCSLAM (Kevin Doherty et al) and we envision several API-breaking changes will happen in the discrete folder. -However, most are easily converted and can be tracked down (in 4.0.3) by disabling the cmake flag `GTSAM_ALLOW_DEPRECATED_SINCE_V4`. +In addition, features deprecated in 4.1 will be removed. Please use the last [4.1.1 release](https://github.com/borglab/gtsam/releases/tag/4.1.1) if you need those features. However, most (not all, unfortunately) are easily converted and can be tracked down (in 4.1.1) by disabling the cmake flag `GTSAM_ALLOW_DEPRECATED_SINCE_V42`. ## What is GTSAM? @@ -57,7 +57,7 @@ GTSAM 4 introduces several new features, most notably Expressions and a Python t GTSAM 4 also deprecated some legacy functionality and wrongly named methods. If you are on a 4.0.X release, you can define the flag `GTSAM_ALLOW_DEPRECATED_SINCE_V4` to use the deprecated methods. -GTSAM 4.1 added a new pybind wrapper, and **removed** the deprecated functionality. There is a flag `GTSAM_ALLOW_DEPRECATED_SINCE_V41` for newly deprecated methods since the 4.1 release, which is on by default, allowing anyone to just pull version 4.1 and compile. +GTSAM 4.1 added a new pybind wrapper, and **removed** the deprecated functionality. There is a flag `GTSAM_ALLOW_DEPRECATED_SINCE_V42` for newly deprecated methods since the 4.1 release, which is on by default, allowing anyone to just pull version 4.1 and compile. ## Wrappers diff --git a/Using-GTSAM-EXPORT.md b/Using-GTSAM-EXPORT.md index cae1d499c..faeebc97f 100644 --- a/Using-GTSAM-EXPORT.md +++ b/Using-GTSAM-EXPORT.md @@ -29,7 +29,7 @@ Rule #1 doesn't seem very bad, until you combine it with rule #2 ***Compiler Rule #2*** Anything declared in a header file is not included in a DLL. -When these two rules are combined, you get some very confusing results. For example, a class which is completely defined in a header (e.g. LieMatrix) cannot use `GTSAM_EXPORT` in its definition. If LieMatrix is defined with `GTSAM_EXPORT`, then the compiler _must_ find LieMatrix in a DLL. Because LieMatrix is a header-only class, however, it can't find it, leading to a very confusing "I can't find this symbol" type of error. Note that the linker says it can't find the symbol even though the compiler found the header file that completely defines the class. +When these two rules are combined, you get some very confusing results. For example, a class which is completely defined in a header (e.g. Foo) cannot use `GTSAM_EXPORT` in its definition. If Foo is defined with `GTSAM_EXPORT`, then the compiler _must_ find Foo in a DLL. Because Foo is a header-only class, however, it can't find it, leading to a very confusing "I can't find this symbol" type of error. Note that the linker says it can't find the symbol even though the compiler found the header file that completely defines the class. Also note that when a class that you want to export inherits from another class that is not exportable, this can cause significant issues. According to this [MSVC Warning page](https://docs.microsoft.com/en-us/cpp/error-messages/compiler-warnings/compiler-warning-level-2-c4275?view=vs-2019), it may not strictly be a rule, but we have seen several linker errors when a class that is defined with `GTSAM_EXPORT` extended an Eigen class. In general, it appears that any inheritance of non-exportable class by an exportable class is a bad idea. diff --git a/cmake/HandleGeneralOptions.cmake b/cmake/HandleGeneralOptions.cmake index 64c239f39..7c8f8533f 100644 --- a/cmake/HandleGeneralOptions.cmake +++ b/cmake/HandleGeneralOptions.cmake @@ -25,7 +25,7 @@ option(GTSAM_WITH_EIGEN_MKL_OPENMP "Eigen, when using Intel MKL, will a option(GTSAM_THROW_CHEIRALITY_EXCEPTION "Throw exception when a triangulated point is behind a camera" ON) option(GTSAM_BUILD_PYTHON "Enable/Disable building & installation of Python module with pybind11" OFF) option(GTSAM_INSTALL_MATLAB_TOOLBOX "Enable/Disable installation of matlab toolbox" OFF) -option(GTSAM_ALLOW_DEPRECATED_SINCE_V41 "Allow use of methods/functions deprecated in GTSAM 4.1" ON) +option(GTSAM_ALLOW_DEPRECATED_SINCE_V42 "Allow use of methods/functions deprecated in GTSAM 4.1" ON) option(GTSAM_SUPPORT_NESTED_DISSECTION "Support Metis-based nested dissection" ON) option(GTSAM_TANGENT_PREINTEGRATION "Use new ImuFactor with integration on tangent space" ON) option(GTSAM_SLOW_BUT_CORRECT_BETWEENFACTOR "Use the slower but correct version of BetweenFactor" OFF) diff --git a/cmake/HandlePrintConfiguration.cmake b/cmake/HandlePrintConfiguration.cmake index ad6ac5c5c..43ee5b57b 100644 --- a/cmake/HandlePrintConfiguration.cmake +++ b/cmake/HandlePrintConfiguration.cmake @@ -86,7 +86,7 @@ print_enabled_config(${GTSAM_USE_QUATERNIONS} "Quaternions as defaul print_enabled_config(${GTSAM_ENABLE_CONSISTENCY_CHECKS} "Runtime consistency checking ") print_enabled_config(${GTSAM_ROT3_EXPMAP} "Rot3 retract is full ExpMap ") print_enabled_config(${GTSAM_POSE3_EXPMAP} "Pose3 retract is full ExpMap ") -print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V41} "Allow features deprecated in GTSAM 4.1") +print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V42} "Allow features deprecated in GTSAM 4.1") print_enabled_config(${GTSAM_SUPPORT_NESTED_DISSECTION} "Metis-based Nested Dissection ") print_enabled_config(${GTSAM_TANGENT_PREINTEGRATION} "Use tangent-space preintegration") diff --git a/doc/Doxyfile.in b/doc/Doxyfile.in index fd7f4e5f6..12193d0be 100644 --- a/doc/Doxyfile.in +++ b/doc/Doxyfile.in @@ -1188,7 +1188,7 @@ USE_MATHJAX = YES # MathJax, but it is strongly recommended to install a local copy of MathJax # before deployment. -MATHJAX_RELPATH = https://cdn.mathjax.org/mathjax/latest +# MATHJAX_RELPATH = https://cdn.mathjax.org/mathjax/latest # The MATHJAX_EXTENSIONS tag can be used to specify one or MathJax extension # names that should be enabled during MathJax rendering. diff --git a/doc/gtsam.lyx b/doc/gtsam.lyx index a5adc2b60..29d03cd35 100644 --- a/doc/gtsam.lyx +++ b/doc/gtsam.lyx @@ -1,5 +1,5 @@ -#LyX 2.2 created this file. For more info see http://www.lyx.org/ -\lyxformat 508 +#LyX 2.3 created this file. For more info see http://www.lyx.org/ +\lyxformat 544 \begin_document \begin_header \save_transient_properties true @@ -62,6 +62,8 @@ \font_osf false \font_sf_scale 100 100 \font_tt_scale 100 100 +\use_microtype false +\use_dash_ligatures true \graphics default \default_output_format default \output_sync 0 @@ -91,6 +93,7 @@ \suppress_date false \justification true \use_refstyle 0 +\use_minted 0 \index Index \shortcut idx \color #008000 @@ -105,7 +108,10 @@ \tocdepth 3 \paragraph_separation indent \paragraph_indentation default -\quotes_language english +\is_math_indent 0 +\math_numbering_side default +\quotes_style english +\dynamic_quotes 0 \papercolumns 1 \papersides 1 \paperpagestyle default @@ -168,6 +174,7 @@ Factor graphs \begin_inset CommandInset citation LatexCommand citep key "Koller09book" +literal "true" \end_inset @@ -270,6 +277,7 @@ Let us start with a one-page primer on factor graphs, which in no way replaces \begin_inset CommandInset citation LatexCommand citet key "Kschischang01it" +literal "true" \end_inset @@ -277,6 +285,7 @@ key "Kschischang01it" \begin_inset CommandInset citation LatexCommand citet key "Loeliger04spm" +literal "true" \end_inset @@ -1321,6 +1330,7 @@ r in a pre-existing map, or indeed the presence of absence of ceiling lights \begin_inset CommandInset citation LatexCommand citet key "Dellaert99b" +literal "true" \end_inset @@ -1542,6 +1552,7 @@ which is done on line 12. \begin_inset CommandInset citation LatexCommand citealt key "Dellaert06ijrr" +literal "true" \end_inset @@ -1936,8 +1947,8 @@ reference "fig:CompareMarginals" \end_inset -, where I show the marginals on position as covariance ellipses that contain - 68.26% of all probability mass. +, where I show the marginals on position as 5-sigma covariance ellipses + that contain 99.9996% of all probability mass. For the odometry marginals, it is immediately apparent from the figure that (1) the uncertainty on pose keeps growing, and (2) the uncertainty on angular odometry translates into increasing uncertainty on y. @@ -1992,6 +2003,7 @@ PoseSLAM \begin_inset CommandInset citation LatexCommand citep key "DurrantWhyte06ram" +literal "true" \end_inset @@ -2190,9 +2202,9 @@ reference "fig:example" \end_inset , along with covariance ellipses shown in green. - These covariance ellipses in 2D indicate the marginal over position, over - all possible orientations, and show the area which contain 68.26% of the - probability mass (in 1D this would correspond to one standard deviation). + These 5-sigma covariance ellipses in 2D indicate the marginal over position, + over all possible orientations, and show the area which contain 99.9996% + of the probability mass. The graph shows in a clear manner that the uncertainty on pose \begin_inset Formula $x_{5}$ \end_inset @@ -3076,6 +3088,7 @@ reference "fig:Victoria-1" \begin_inset CommandInset citation LatexCommand citep key "Kaess09ras" +literal "true" \end_inset @@ -3088,6 +3101,7 @@ key "Kaess09ras" \begin_inset CommandInset citation LatexCommand citep key "Kaess08tro" +literal "true" \end_inset @@ -3355,6 +3369,7 @@ iSAM \begin_inset CommandInset citation LatexCommand citet key "Kaess08tro,Kaess12ijrr" +literal "true" \end_inset @@ -3606,6 +3621,7 @@ subgraph preconditioning \begin_inset CommandInset citation LatexCommand citet key "Dellaert10iros,Jian11iccv" +literal "true" \end_inset @@ -3638,6 +3654,7 @@ Visual Odometry \begin_inset CommandInset citation LatexCommand citet key "Nister04cvpr2" +literal "true" \end_inset @@ -3661,6 +3678,7 @@ Visual SLAM \begin_inset CommandInset citation LatexCommand citet key "Davison03iccv" +literal "true" \end_inset @@ -3711,6 +3729,7 @@ Filtering \begin_inset CommandInset citation LatexCommand citep key "Smith87b" +literal "true" \end_inset diff --git a/doc/gtsam.pdf b/doc/gtsam.pdf index c6a39a79c..d4cb8908f 100644 Binary files a/doc/gtsam.pdf and b/doc/gtsam.pdf differ diff --git a/doc/math.lyx b/doc/math.lyx index 4ee89a9cc..86ed2b220 100644 --- a/doc/math.lyx +++ b/doc/math.lyx @@ -2668,7 +2668,7 @@ reference "eq:pushforward" \begin{eqnarray*} \varphi(a)e^{\yhat} & = & \varphi(ae^{\xhat})\\ a^{-1}e^{\yhat} & = & \left(ae^{\xhat}\right)^{-1}\\ -e^{\yhat} & = & -ae^{\xhat}a^{-1}\\ +e^{\yhat} & = & ae^{-\xhat}a^{-1}\\ \yhat & = & -\Ad a\xhat \end{eqnarray*} @@ -3003,8 +3003,8 @@ between \begin_inset Formula \begin{align} \varphi(g,h)e^{\yhat} & =\varphi(ge^{\xhat},h)\nonumber \\ -g^{-1}he^{\yhat} & =\left(ge^{\xhat}\right)^{-1}h=-e^{\xhat}g^{-1}h\nonumber \\ -e^{\yhat} & =-\left(h^{-1}g\right)e^{\xhat}\left(h^{-1}g\right)^{-1}=-\exp\Ad{\left(h^{-1}g\right)}\xhat\nonumber \\ +g^{-1}he^{\yhat} & =\left(ge^{\xhat}\right)^{-1}h=e^{-\xhat}g^{-1}h\nonumber \\ +e^{\yhat} & =\left(h^{-1}g\right)e^{-\xhat}\left(h^{-1}g\right)^{-1}=\exp\Ad{\left(h^{-1}g\right)}(-\xhat)\nonumber \\ \yhat & =-\Ad{\left(h^{-1}g\right)}\xhat=-\Ad{\varphi\left(h,g\right)}\xhat\label{eq:Dbetween1} \end{align} @@ -6674,7 +6674,7 @@ One representation of a line is through 2 vectors \begin_inset Formula $d$ \end_inset - points from the orgin to the closest point on the line. + points from the origin to the closest point on the line. \end_layout \begin_layout Standard diff --git a/doc/math.pdf b/doc/math.pdf index 40980354e..71533e1e8 100644 Binary files a/doc/math.pdf and b/doc/math.pdf differ diff --git a/examples/DiscreteBayesNetExample.cpp b/examples/DiscreteBayesNetExample.cpp index 5dca116c3..dfd7beb63 100644 --- a/examples/DiscreteBayesNetExample.cpp +++ b/examples/DiscreteBayesNetExample.cpp @@ -53,11 +53,10 @@ int main(int argc, char **argv) { // Create solver and eliminate Ordering ordering; ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7); - DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering); // solve - DiscreteFactor::sharedValues mpe = chordal->optimize(); - GTSAM_PRINT(*mpe); + auto mpe = fg.optimize(); + GTSAM_PRINT(mpe); // We can also build a Bayes tree (directed junction tree). // The elimination order above will do fine: @@ -69,15 +68,15 @@ int main(int argc, char **argv) { fg.add(Dyspnea, "0 1"); // solve again, now with evidence - DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering); - DiscreteFactor::sharedValues mpe2 = chordal2->optimize(); - GTSAM_PRINT(*mpe2); + auto mpe2 = fg.optimize(); + GTSAM_PRINT(mpe2); // We can also sample from it + DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering); cout << "\n10 samples:" << endl; for (size_t i = 0; i < 10; i++) { - DiscreteFactor::sharedValues sample = chordal2->sample(); - GTSAM_PRINT(*sample); + auto sample = chordal->sample(); + GTSAM_PRINT(sample); } return 0; } diff --git a/examples/DiscreteBayesNet_FG.cpp b/examples/DiscreteBayesNet_FG.cpp index 121df4bef..88904001a 100644 --- a/examples/DiscreteBayesNet_FG.cpp +++ b/examples/DiscreteBayesNet_FG.cpp @@ -33,11 +33,11 @@ using namespace gtsam; int main(int argc, char **argv) { // Define keys and a print function Key C(1), S(2), R(3), W(4); - auto print = [=](DiscreteFactor::sharedValues values) { - cout << boolalpha << "Cloudy = " << static_cast((*values)[C]) - << " Sprinkler = " << static_cast((*values)[S]) - << " Rain = " << boolalpha << static_cast((*values)[R]) - << " WetGrass = " << static_cast((*values)[W]) << endl; + auto print = [=](const DiscreteFactor::Values& values) { + cout << boolalpha << "Cloudy = " << static_cast(values.at(C)) + << " Sprinkler = " << static_cast(values.at(S)) + << " Rain = " << boolalpha << static_cast(values.at(R)) + << " WetGrass = " << static_cast(values.at(W)) << endl; }; // We assume binary state variables @@ -85,7 +85,7 @@ int main(int argc, char **argv) { } // "Most Probable Explanation", i.e., configuration with largest value - DiscreteFactor::sharedValues mpe = graph.eliminateSequential()->optimize(); + auto mpe = graph.optimize(); cout << "\nMost Probable Explanation (MPE):" << endl; print(mpe); @@ -96,8 +96,7 @@ int main(int argc, char **argv) { graph.add(Cloudy, "1 0"); // solve again, now with evidence - DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); - DiscreteFactor::sharedValues mpe_with_evidence = chordal->optimize(); + auto mpe_with_evidence = graph.optimize(); cout << "\nMPE given C=0:" << endl; print(mpe_with_evidence); @@ -110,10 +109,11 @@ int main(int argc, char **argv) { cout << "\nP(W=1|C=0):" << marginals.marginalProbabilities(WetGrass)[1] << endl; - // We can also sample from it + // We can also sample from the eliminated graph + DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); cout << "\n10 samples:" << endl; for (size_t i = 0; i < 10; i++) { - DiscreteFactor::sharedValues sample = chordal->sample(); + auto sample = chordal->sample(); print(sample); } return 0; diff --git a/examples/FisheyeExample.cpp b/examples/FisheyeExample.cpp index 223149299..fc0aed0d7 100644 --- a/examples/FisheyeExample.cpp +++ b/examples/FisheyeExample.cpp @@ -122,8 +122,7 @@ int main(int argc, char *argv[]) { std::cout << "initial error=" << graph.error(initialEstimate) << std::endl; std::cout << "final error=" << graph.error(result) << std::endl; - std::ofstream os("examples/vio_batch.dot"); - graph.saveGraph(os, result); + graph.saveGraph("examples/vio_batch.dot", result); return 0; } diff --git a/examples/HMMExample.cpp b/examples/HMMExample.cpp index ee861e381..3a7673001 100644 --- a/examples/HMMExample.cpp +++ b/examples/HMMExample.cpp @@ -59,21 +59,21 @@ int main(int argc, char **argv) { // Convert to factor graph DiscreteFactorGraph factorGraph(hmm); + // Do max-prodcut + auto mpe = factorGraph.optimize(); + GTSAM_PRINT(mpe); + // Create solver and eliminate // This will create a DAG ordered with arrow of time reversed DiscreteBayesNet::shared_ptr chordal = factorGraph.eliminateSequential(ordering); chordal->print("Eliminated"); - // solve - DiscreteFactor::sharedValues mpe = chordal->optimize(); - GTSAM_PRINT(*mpe); - // We can also sample from it cout << "\n10 samples:" << endl; for (size_t k = 0; k < 10; k++) { - DiscreteFactor::sharedValues sample = chordal->sample(); - GTSAM_PRINT(*sample); + auto sample = chordal->sample(); + GTSAM_PRINT(sample); } // Or compute the marginals. This re-eliminates the FG into a Bayes tree diff --git a/examples/Pose2SLAMExample_graphviz.cpp b/examples/Pose2SLAMExample_graphviz.cpp index 27d556725..a8768e2b8 100644 --- a/examples/Pose2SLAMExample_graphviz.cpp +++ b/examples/Pose2SLAMExample_graphviz.cpp @@ -60,11 +60,10 @@ int main(int argc, char** argv) { // save factor graph as graphviz dot file // Render to PDF using "fdp Pose2SLAMExample.dot -Tpdf > graph.pdf" - ofstream os("Pose2SLAMExample.dot"); - graph.saveGraph(os, result); + graph.saveGraph("Pose2SLAMExample.dot", result); // Also print out to console - graph.saveGraph(cout, result); + graph.dot(cout, result); return 0; } diff --git a/examples/UGM_chain.cpp b/examples/UGM_chain.cpp index 3a885a844..ad21af9fa 100644 --- a/examples/UGM_chain.cpp +++ b/examples/UGM_chain.cpp @@ -68,10 +68,9 @@ int main(int argc, char** argv) { << graph.size() << " factors (Unary+Edge)."; // "Decoding", i.e., configuration with largest value - // We use sequential variable elimination - DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); - DiscreteFactor::sharedValues optimalDecoding = chordal->optimize(); - optimalDecoding->print("\nMost Probable Explanation (optimalDecoding)\n"); + // Uses max-product. + auto optimalDecoding = graph.optimize(); + optimalDecoding.print("\nMost Probable Explanation (optimalDecoding)\n"); // "Inference" Computing marginals for each node // Here we'll make use of DiscreteMarginals class, which makes use of diff --git a/examples/UGM_small.cpp b/examples/UGM_small.cpp index 27a6205a3..bc6a41317 100644 --- a/examples/UGM_small.cpp +++ b/examples/UGM_small.cpp @@ -50,8 +50,8 @@ int main(int argc, char** argv) { // Print the UGM distribution cout << "\nUGM distribution:" << endl; - vector allPosbValues = cartesianProduct( - Cathy & Heather & Mark & Allison); + auto allPosbValues = + DiscreteValues::CartesianProduct(Cathy & Heather & Mark & Allison); for (size_t i = 0; i < allPosbValues.size(); ++i) { DiscreteFactor::Values values = allPosbValues[i]; double prodPot = graph(values); @@ -61,10 +61,9 @@ int main(int argc, char** argv) { } // "Decoding", i.e., configuration with largest value (MPE) - // We use sequential variable elimination - DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); - DiscreteFactor::sharedValues optimalDecoding = chordal->optimize(); - optimalDecoding->print("\noptimalDecoding"); + // Uses max-product + auto optimalDecoding = graph.optimize(); + GTSAM_PRINT(optimalDecoding); // "Inference" Computing marginals cout << "\nComputing Node Marginals .." << endl; diff --git a/gtsam/3rdparty/Eigen/Eigen/src/Core/TriangularMatrix.h b/gtsam/3rdparty/Eigen/Eigen/src/Core/TriangularMatrix.h index 667ef09dc..9db32744e 100644 --- a/gtsam/3rdparty/Eigen/Eigen/src/Core/TriangularMatrix.h +++ b/gtsam/3rdparty/Eigen/Eigen/src/Core/TriangularMatrix.h @@ -440,7 +440,7 @@ template class TriangularViewImpl<_Mat EIGEN_DEVICE_FUNC void lazyAssign(const TriangularBase& other); - /** \deprecated */ + /** @deprecated */ template EIGEN_DEVICE_FUNC void lazyAssign(const MatrixBase& other); @@ -523,7 +523,7 @@ template class TriangularViewImpl<_Mat call_assignment(derived(), other.const_cast_derived(), internal::swap_assign_op()); } - /** \deprecated + /** @deprecated * Shortcut for \code (*this).swap(other.triangularView<(*this)::Mode>()) \endcode */ template EIGEN_DEVICE_FUNC diff --git a/gtsam/CMakeLists.txt b/gtsam/CMakeLists.txt index 535d60eb1..a293c6ec2 100644 --- a/gtsam/CMakeLists.txt +++ b/gtsam/CMakeLists.txt @@ -15,7 +15,7 @@ set (gtsam_subdirs sam sfm slam - navigation + navigation ) set(gtsam_srcs) diff --git a/gtsam/base/CMakeLists.txt b/gtsam/base/CMakeLists.txt index 99984e7b3..66d3ec721 100644 --- a/gtsam/base/CMakeLists.txt +++ b/gtsam/base/CMakeLists.txt @@ -5,8 +5,5 @@ install(FILES ${base_headers} DESTINATION include/gtsam/base) file(GLOB base_headers_tree "treeTraversal/*.h") install(FILES ${base_headers_tree} DESTINATION include/gtsam/base/treeTraversal) -file(GLOB deprecated_headers "deprecated/*.h") -install(FILES ${deprecated_headers} DESTINATION include/gtsam/base/deprecated) - # Build tests add_subdirectory(tests) diff --git a/gtsam/base/LieMatrix.h b/gtsam/base/LieMatrix.h deleted file mode 100644 index 210bdcc73..000000000 --- a/gtsam/base/LieMatrix.h +++ /dev/null @@ -1,26 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file LieMatrix.h - * @brief External deprecation warning, see deprecated/LieMatrix.h for details - * @author Paul Drews - */ - -#pragma once - -#ifdef _MSC_VER -#pragma message("LieMatrix.h is deprecated. Please use Eigen::Matrix instead.") -#else -#warning "LieMatrix.h is deprecated. Please use Eigen::Matrix instead." -#endif - -#include "gtsam/base/deprecated/LieMatrix.h" diff --git a/gtsam/base/LieScalar.h b/gtsam/base/LieScalar.h deleted file mode 100644 index e159ffa87..000000000 --- a/gtsam/base/LieScalar.h +++ /dev/null @@ -1,26 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file LieScalar.h - * @brief External deprecation warning, see deprecated/LieScalar.h for details - * @author Kai Ni - */ - -#pragma once - -#ifdef _MSC_VER -#pragma message("LieScalar.h is deprecated. Please use double/float instead.") -#else - #warning "LieScalar.h is deprecated. Please use double/float instead." -#endif - -#include diff --git a/gtsam/base/LieVector.h b/gtsam/base/LieVector.h deleted file mode 100644 index a7491d804..000000000 --- a/gtsam/base/LieVector.h +++ /dev/null @@ -1,26 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file LieVector.h - * @brief Deprecation warning for LieVector, see deprecated/LieVector.h for details. - * @author Paul Drews - */ - -#pragma once - -#ifdef _MSC_VER -#pragma message("LieVector.h is deprecated. Please use Eigen::Vector instead.") -#else -#warning "LieVector.h is deprecated. Please use Eigen::Vector instead." -#endif - -#include diff --git a/gtsam/base/TestableAssertions.h b/gtsam/base/TestableAssertions.h index 0e6e1c276..e5bd34d19 100644 --- a/gtsam/base/TestableAssertions.h +++ b/gtsam/base/TestableAssertions.h @@ -80,12 +80,13 @@ bool assert_equal(const V& expected, const boost::optional& actual, do return assert_equal(expected, *actual, tol); } +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 /** * Version of assert_equals to work with vectors - * \deprecated: use container equals instead + * @deprecated: use container equals instead */ template -bool assert_equal(const std::vector& expected, const std::vector& actual, double tol = 1e-9) { +bool GTSAM_DEPRECATED assert_equal(const std::vector& expected, const std::vector& actual, double tol = 1e-9) { bool match = true; if (expected.size() != actual.size()) match = false; @@ -108,6 +109,7 @@ bool assert_equal(const std::vector& expected, const std::vector& actual, } return true; } +#endif /** * Function for comparing maps of testable->testable diff --git a/gtsam/base/Vector.h b/gtsam/base/Vector.h index 35c68c4b4..36dc2288d 100644 --- a/gtsam/base/Vector.h +++ b/gtsam/base/Vector.h @@ -203,18 +203,19 @@ inline double inner_prod(const V1 &a, const V2& b) { return a.dot(b); } +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 /** * BLAS Level 1 scal: x <- alpha*x - * \deprecated: use operators instead + * @deprecated: use operators instead */ -inline void scal(double alpha, Vector& x) { x *= alpha; } +inline void GTSAM_DEPRECATED scal(double alpha, Vector& x) { x *= alpha; } /** * BLAS Level 1 axpy: y <- alpha*x + y - * \deprecated: use operators instead + * @deprecated: use operators instead */ template -inline void axpy(double alpha, const V1& x, V2& y) { +inline void GTSAM_DEPRECATED axpy(double alpha, const V1& x, V2& y) { assert (y.size()==x.size()); y += alpha * x; } @@ -222,6 +223,7 @@ inline void axpy(double alpha, const Vector& x, SubVector y) { assert (y.size()==x.size()); y += alpha * x; } +#endif /** * house(x,j) computes HouseHolder vector v and scaling factor beta diff --git a/gtsam/base/base.i b/gtsam/base/base.i index d9c51fbe8..9838f97d3 100644 --- a/gtsam/base/base.i +++ b/gtsam/base/base.i @@ -38,7 +38,7 @@ class DSFMap { DSFMap(); KEY find(const KEY& key) const; void merge(const KEY& x, const KEY& y); - std::map sets(); + std::map sets(); }; class IndexPairSet { diff --git a/gtsam/base/deprecated/LieMatrix.h b/gtsam/base/deprecated/LieMatrix.h deleted file mode 100644 index a3d0a4328..000000000 --- a/gtsam/base/deprecated/LieMatrix.h +++ /dev/null @@ -1,152 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file LieMatrix.h - * @brief A wrapper around Matrix providing Lie compatibility - * @author Richard Roberts and Alex Cunningham - */ - -#pragma once - -#include - -#include -#include - -namespace gtsam { - -/** - * @deprecated: LieMatrix, LieVector and LieMatrix are obsolete in GTSAM 4.0 as - * we can directly add double, Vector, and Matrix into values now, because of - * gtsam::traits. - */ -struct LieMatrix : public Matrix { - - /// @name Constructors - /// @{ - enum { dimension = Eigen::Dynamic }; - - /** default constructor - only for serialize */ - LieMatrix() {} - - /** initialize from a normal matrix */ - LieMatrix(const Matrix& v) : Matrix(v) {} - - template - LieMatrix(const M& v) : Matrix(v) {} - -// Currently TMP constructor causes ICE on MSVS 2013 -#if (_MSC_VER < 1800) - /** initialize from a fixed size normal vector */ - template - LieMatrix(const Eigen::Matrix& v) : Matrix(v) {} -#endif - - /** constructor with size and initial data, row order ! */ - LieMatrix(size_t m, size_t n, const double* const data) : - Matrix(Eigen::Map(data, m, n)) {} - - /// @} - /// @name Testable interface - /// @{ - - /** print @param s optional string naming the object */ - void print(const std::string& name = "") const { - gtsam::print(matrix(), name); - } - /** equality up to tolerance */ - inline bool equals(const LieMatrix& expected, double tol=1e-5) const { - return gtsam::equal_with_abs_tol(matrix(), expected.matrix(), tol); - } - - /// @} - /// @name Standard Interface - /// @{ - - /** get the underlying matrix */ - inline Matrix matrix() const { - return static_cast(*this); - } - - /// @} - - /// @name Group - /// @{ - LieMatrix compose(const LieMatrix& q) { return (*this)+q;} - LieMatrix between(const LieMatrix& q) { return q-(*this);} - LieMatrix inverse() { return -(*this);} - /// @} - - /// @name Manifold - /// @{ - Vector localCoordinates(const LieMatrix& q) { return between(q).vector();} - LieMatrix retract(const Vector& v) {return compose(LieMatrix(v));} - /// @} - - /// @name Lie Group - /// @{ - static Vector Logmap(const LieMatrix& p) {return p.vector();} - static LieMatrix Expmap(const Vector& v) { return LieMatrix(v);} - /// @} - - /// @name VectorSpace requirements - /// @{ - - /** Returns dimensionality of the tangent space */ - inline size_t dim() const { return size(); } - - /** Convert to vector, is done row-wise - TODO why? */ - inline Vector vector() const { - Vector result(size()); - typedef Eigen::Matrix RowMajor; - Eigen::Map(&result(0), rows(), cols()) = *this; - return result; - } - - /** identity - NOTE: no known size at compile time - so zero length */ - inline static LieMatrix identity() { - throw std::runtime_error("LieMatrix::identity(): Don't use this function"); - return LieMatrix(); - } - /// @} - -private: - - // Serialization function - friend class boost::serialization::access; - template - void serialize(Archive & ar, const unsigned int /*version*/) { - ar & boost::serialization::make_nvp("Matrix", - boost::serialization::base_object(*this)); - - } - -}; - - -template<> -struct traits : public internal::VectorSpace { - - // Override Retract, as the default version does not know how to initialize - static LieMatrix Retract(const LieMatrix& origin, const TangentVector& v, - ChartJacobian H1 = boost::none, ChartJacobian H2 = boost::none) { - if (H1) *H1 = Eye(origin); - if (H2) *H2 = Eye(origin); - typedef const Eigen::Matrix RowMajor; - return origin + Eigen::Map(&v(0), origin.rows(), origin.cols()); - } - -}; - -} // \namespace gtsam diff --git a/gtsam/base/deprecated/LieScalar.h b/gtsam/base/deprecated/LieScalar.h deleted file mode 100644 index 6c9a5f766..000000000 --- a/gtsam/base/deprecated/LieScalar.h +++ /dev/null @@ -1,88 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file LieScalar.h - * @brief A wrapper around scalar providing Lie compatibility - * @author Kai Ni - */ - -#pragma once - -#include -#include -#include - -namespace gtsam { - - /** - * @deprecated: LieScalar, LieVector and LieMatrix are obsolete in GTSAM 4.0 as - * we can directly add double, Vector, and Matrix into values now, because of - * gtsam::traits. - */ - struct LieScalar { - - enum { dimension = 1 }; - - /** default constructor */ - LieScalar() : d_(0.0) {} - - /** wrap a double */ - /*explicit*/ LieScalar(double d) : d_(d) {} - - /** access the underlying value */ - double value() const { return d_; } - - /** Automatic conversion to underlying value */ - operator double() const { return d_; } - - /** convert vector */ - Vector1 vector() const { Vector1 v; v< - struct traits : public internal::ScalarTraits {}; - -} // \namespace gtsam diff --git a/gtsam/base/deprecated/LieVector.h b/gtsam/base/deprecated/LieVector.h deleted file mode 100644 index 745189c3d..000000000 --- a/gtsam/base/deprecated/LieVector.h +++ /dev/null @@ -1,121 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file LieVector.h - * @brief A wrapper around vector providing Lie compatibility - * @author Alex Cunningham - */ - -#pragma once - -#include -#include - -namespace gtsam { - -/** - * @deprecated: LieVector, LieVector and LieMatrix are obsolete in GTSAM 4.0 as - * we can directly add double, Vector, and Matrix into values now, because of - * gtsam::traits. - */ -struct LieVector : public Vector { - - enum { dimension = Eigen::Dynamic }; - - /** default constructor - should be unnecessary */ - LieVector() {} - - /** initialize from a normal vector */ - LieVector(const Vector& v) : Vector(v) {} - - template - LieVector(const V& v) : Vector(v) {} - -// Currently TMP constructor causes ICE on MSVS 2013 -#if (_MSC_VER < 1800) - /** initialize from a fixed size normal vector */ - template - LieVector(const Eigen::Matrix& v) : Vector(v) {} -#endif - - /** wrap a double */ - LieVector(double d) : Vector((Vector(1) << d).finished()) {} - - /** constructor with size and initial data, row order ! */ - LieVector(size_t m, const double* const data) : Vector(m) { - for (size_t i = 0; i < m; i++) (*this)(i) = data[i]; - } - - /// @name Testable - /// @{ - void print(const std::string& name="") const { - gtsam::print(vector(), name); - } - bool equals(const LieVector& expected, double tol=1e-5) const { - return gtsam::equal(vector(), expected.vector(), tol); - } - /// @} - - /// @name Group - /// @{ - LieVector compose(const LieVector& q) { return (*this)+q;} - LieVector between(const LieVector& q) { return q-(*this);} - LieVector inverse() { return -(*this);} - /// @} - - /// @name Manifold - /// @{ - Vector localCoordinates(const LieVector& q) { return between(q).vector();} - LieVector retract(const Vector& v) {return compose(LieVector(v));} - /// @} - - /// @name Lie Group - /// @{ - static Vector Logmap(const LieVector& p) {return p.vector();} - static LieVector Expmap(const Vector& v) { return LieVector(v);} - /// @} - - /// @name VectorSpace requirements - /// @{ - - /** get the underlying vector */ - Vector vector() const { - return static_cast(*this); - } - - /** Returns dimensionality of the tangent space */ - size_t dim() const { return this->size(); } - - /** identity - NOTE: no known size at compile time - so zero length */ - static LieVector identity() { - throw std::runtime_error("LieVector::identity(): Don't use this function"); - return LieVector(); - } - - /// @} - -private: - - // Serialization function - friend class boost::serialization::access; - template - void serialize(Archive & ar, const unsigned int /*version*/) { - ar & boost::serialization::make_nvp("Vector", - boost::serialization::base_object(*this)); - } -}; - - -template<> -struct traits : public internal::VectorSpace {}; - -} // \namespace gtsam diff --git a/gtsam/base/serialization.h b/gtsam/base/serialization.h index f589ecc5e..24355c684 100644 --- a/gtsam/base/serialization.h +++ b/gtsam/base/serialization.h @@ -19,8 +19,9 @@ #pragma once -#include +#include #include +#include #include // includes for standard serialization types @@ -40,6 +41,17 @@ #include #include +// Workaround a bug in GCC >= 7 and C++17 +// ref. https://gitlab.com/libeigen/eigen/-/issues/1676 +#ifdef __GNUC__ +#if __GNUC__ >= 7 && __cplusplus >= 201703L +namespace boost { namespace serialization { struct U; } } +namespace Eigen { namespace internal { +template<> struct traits {enum {Flags=0};}; +} } +#endif +#endif + namespace gtsam { /** @name Standard serialization diff --git a/gtsam/base/tests/testLieMatrix.cpp b/gtsam/base/tests/testLieMatrix.cpp deleted file mode 100644 index 8c68bf8a0..000000000 --- a/gtsam/base/tests/testLieMatrix.cpp +++ /dev/null @@ -1,70 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file testLieMatrix.cpp - * @author Richard Roberts - */ - -#include -#include -#include -#include - -using namespace gtsam; - -GTSAM_CONCEPT_TESTABLE_INST(LieMatrix) -GTSAM_CONCEPT_LIE_INST(LieMatrix) - -/* ************************************************************************* */ -TEST( LieMatrix, construction ) { - Matrix m = (Matrix(2,2) << 1.0,2.0, 3.0,4.0).finished(); - LieMatrix lie1(m), lie2(m); - - EXPECT(traits::GetDimension(m) == 4); - EXPECT(assert_equal(m, lie1.matrix())); - EXPECT(assert_equal(lie1, lie2)); -} - -/* ************************************************************************* */ -TEST( LieMatrix, other_constructors ) { - Matrix init = (Matrix(2,2) << 10.0,20.0, 30.0,40.0).finished(); - LieMatrix exp(init); - double data[] = {10,30,20,40}; - LieMatrix b(2,2,data); - EXPECT(assert_equal(exp, b)); -} - -/* ************************************************************************* */ -TEST(LieMatrix, retract) { - LieMatrix init((Matrix(2,2) << 1.0,2.0,3.0,4.0).finished()); - Vector update = (Vector(4) << 3.0, 4.0, 6.0, 7.0).finished(); - - LieMatrix expected((Matrix(2,2) << 4.0, 6.0, 9.0, 11.0).finished()); - LieMatrix actual = traits::Retract(init,update); - - EXPECT(assert_equal(expected, actual)); - - Vector expectedUpdate = update; - Vector actualUpdate = traits::Local(init,actual); - - EXPECT(assert_equal(expectedUpdate, actualUpdate)); - - Vector expectedLogmap = (Vector(4) << 1, 2, 3, 4).finished(); - Vector actualLogmap = traits::Logmap(LieMatrix((Matrix(2,2) << 1.0, 2.0, 3.0, 4.0).finished())); - EXPECT(assert_equal(expectedLogmap, actualLogmap)); -} - -/* ************************************************************************* */ -int main() { TestResult tr; return TestRegistry::runAllTests(tr); } -/* ************************************************************************* */ - - diff --git a/gtsam/base/tests/testLieScalar.cpp b/gtsam/base/tests/testLieScalar.cpp deleted file mode 100644 index 74f5e0d41..000000000 --- a/gtsam/base/tests/testLieScalar.cpp +++ /dev/null @@ -1,64 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file testLieScalar.cpp - * @author Kai Ni - */ - -#include -#include -#include -#include - -using namespace gtsam; - -GTSAM_CONCEPT_TESTABLE_INST(LieScalar) -GTSAM_CONCEPT_LIE_INST(LieScalar) - -const double tol=1e-9; - -//****************************************************************************** -TEST(LieScalar , Concept) { - BOOST_CONCEPT_ASSERT((IsGroup)); - BOOST_CONCEPT_ASSERT((IsManifold)); - BOOST_CONCEPT_ASSERT((IsLieGroup)); -} - -//****************************************************************************** -TEST(LieScalar , Invariants) { - LieScalar lie1(2), lie2(3); - CHECK(check_group_invariants(lie1, lie2)); - CHECK(check_manifold_invariants(lie1, lie2)); -} - -/* ************************************************************************* */ -TEST( testLieScalar, construction ) { - double d = 2.; - LieScalar lie1(d), lie2(d); - - EXPECT_DOUBLES_EQUAL(2., lie1.value(),tol); - EXPECT_DOUBLES_EQUAL(2., lie2.value(),tol); - EXPECT(traits::dimension == 1); - EXPECT(assert_equal(lie1, lie2)); -} - -/* ************************************************************************* */ -TEST( testLieScalar, localCoordinates ) { - LieScalar lie1(1.), lie2(3.); - - Vector1 actual = traits::Local(lie1, lie2); - EXPECT( assert_equal((Vector)(Vector(1) << 2).finished(), actual)); -} - -/* ************************************************************************* */ -int main() { TestResult tr; return TestRegistry::runAllTests(tr); } -/* ************************************************************************* */ diff --git a/gtsam/base/tests/testLieVector.cpp b/gtsam/base/tests/testLieVector.cpp deleted file mode 100644 index 76c4fc490..000000000 --- a/gtsam/base/tests/testLieVector.cpp +++ /dev/null @@ -1,66 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file testLieVector.cpp - * @author Alex Cunningham - */ - -#include -#include -#include -#include - -using namespace gtsam; - -GTSAM_CONCEPT_TESTABLE_INST(LieVector) -GTSAM_CONCEPT_LIE_INST(LieVector) - -//****************************************************************************** -TEST(LieVector , Concept) { - BOOST_CONCEPT_ASSERT((IsGroup)); - BOOST_CONCEPT_ASSERT((IsManifold)); - BOOST_CONCEPT_ASSERT((IsLieGroup)); -} - -//****************************************************************************** -TEST(LieVector , Invariants) { - Vector v = Vector3(1.0, 2.0, 3.0); - LieVector lie1(v), lie2(v); - check_manifold_invariants(lie1, lie2); -} - -//****************************************************************************** -TEST( testLieVector, construction ) { - Vector v = Vector3(1.0, 2.0, 3.0); - LieVector lie1(v), lie2(v); - - EXPECT(lie1.dim() == 3); - EXPECT(assert_equal(v, lie1.vector())); - EXPECT(assert_equal(lie1, lie2)); -} - -//****************************************************************************** -TEST( testLieVector, other_constructors ) { - Vector init = Vector2(10.0, 20.0); - LieVector exp(init); - double data[] = { 10, 20 }; - LieVector b(2, data); - EXPECT(assert_equal(exp, b)); -} - -/* ************************************************************************* */ -int main() { - TestResult tr; - return TestRegistry::runAllTests(tr); -} -/* ************************************************************************* */ - diff --git a/gtsam/base/tests/testMatrix.cpp b/gtsam/base/tests/testMatrix.cpp index a7c218705..7802f27e1 100644 --- a/gtsam/base/tests/testMatrix.cpp +++ b/gtsam/base/tests/testMatrix.cpp @@ -173,7 +173,7 @@ TEST(Matrix, stack ) { Matrix A = (Matrix(2, 2) << -5.0, 3.0, 00.0, -5.0).finished(); Matrix B = (Matrix(3, 2) << -0.5, 2.1, 1.1, 3.4, 2.6, 7.1).finished(); - Matrix AB = stack(2, &A, &B); + Matrix AB = gtsam::stack(2, &A, &B); Matrix C(5, 2); for (int i = 0; i < 2; i++) for (int j = 0; j < 2; j++) @@ -187,7 +187,7 @@ TEST(Matrix, stack ) std::vector matrices; matrices.push_back(A); matrices.push_back(B); - Matrix AB2 = stack(matrices); + Matrix AB2 = gtsam::stack(matrices); EQUALITY(C,AB2); } diff --git a/gtsam/base/tests/testTestableAssertions.cpp b/gtsam/base/tests/testTestableAssertions.cpp deleted file mode 100644 index 305aa7ca9..000000000 --- a/gtsam/base/tests/testTestableAssertions.cpp +++ /dev/null @@ -1,35 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file testTestableAssertions - * @author Alex Cunningham - */ - -#include -#include -#include - -using namespace gtsam; - -/* ************************************************************************* */ -TEST( testTestableAssertions, optional ) { - typedef boost::optional OptionalScalar; - LieScalar x(1.0); - OptionalScalar ox(x), dummy = boost::none; - EXPECT(assert_equal(ox, ox)); - EXPECT(assert_equal(x, ox)); - EXPECT(assert_equal(dummy, dummy)); -} - -/* ************************************************************************* */ -int main() { TestResult tr; return TestRegistry::runAllTests(tr); } -/* ************************************************************************* */ diff --git a/gtsam/base/tests/testVector.cpp b/gtsam/base/tests/testVector.cpp index bd715e3cb..c87732b09 100644 --- a/gtsam/base/tests/testVector.cpp +++ b/gtsam/base/tests/testVector.cpp @@ -220,8 +220,8 @@ TEST(Vector, axpy ) Vector x = Vector3(10., 20., 30.); Vector y0 = Vector3(2.0, 5.0, 6.0); Vector y1 = y0, y2 = y0; - axpy(0.1,x,y1); - axpy(0.1,x,y2.head(3)); + y1 += 0.1 * x; + y2.head(3) += 0.1 * x; Vector expected = Vector3(3.0, 7.0, 9.0); EXPECT(assert_equal(expected,y1)); EXPECT(assert_equal(expected,Vector(y2))); diff --git a/gtsam/base/types.h b/gtsam/base/types.h index aaada3cee..a0d24f1a6 100644 --- a/gtsam/base/types.h +++ b/gtsam/base/types.h @@ -34,6 +34,14 @@ #include #endif +#if defined(__GNUC__) || defined(__clang__) +#define GTSAM_DEPRECATED __attribute__((deprecated)) +#elif defined(_MSC_VER) +#define GTSAM_DEPRECATED __declspec(deprecated) +#else +#define GTSAM_DEPRECATED +#endif + #ifdef GTSAM_USE_EIGEN_MKL_OPENMP #include #endif diff --git a/gtsam/base/utilities.cpp b/gtsam/base/utilities.cpp new file mode 100644 index 000000000..189156c91 --- /dev/null +++ b/gtsam/base/utilities.cpp @@ -0,0 +1,13 @@ +#include + +namespace gtsam { + +std::string RedirectCout::str() const { + return ssBuffer_.str(); +} + +RedirectCout::~RedirectCout() { + std::cout.rdbuf(coutBuffer_); +} + +} diff --git a/gtsam/base/utilities.h b/gtsam/base/utilities.h index 8eb5617a8..d9b92b8aa 100644 --- a/gtsam/base/utilities.h +++ b/gtsam/base/utilities.h @@ -1,5 +1,9 @@ #pragma once +#include +#include +#include + namespace gtsam { /** * For Python __str__(). @@ -12,14 +16,10 @@ struct RedirectCout { RedirectCout() : ssBuffer_(), coutBuffer_(std::cout.rdbuf(ssBuffer_.rdbuf())) {} /// return the string - std::string str() const { - return ssBuffer_.str(); - } + std::string str() const; /// destructor -- redirect stdout buffer to its original buffer - ~RedirectCout() { - std::cout.rdbuf(coutBuffer_); - } + ~RedirectCout(); private: std::stringstream ssBuffer_; diff --git a/gtsam/basis/ParameterMatrix.h b/gtsam/basis/ParameterMatrix.h index df2d9f62e..eddcbfeae 100644 --- a/gtsam/basis/ParameterMatrix.h +++ b/gtsam/basis/ParameterMatrix.h @@ -153,7 +153,7 @@ class ParameterMatrix { return matrix_ * other; } - /// @name Vector Space requirements, following LieMatrix + /// @name Vector Space requirements /// @{ /** diff --git a/gtsam/basis/basis.i b/gtsam/basis/basis.i index 8f06fd2e1..c9c027438 100644 --- a/gtsam/basis/basis.i +++ b/gtsam/basis/basis.i @@ -140,7 +140,7 @@ class FitBasis { static gtsam::GaussianFactorGraph::shared_ptr LinearGraph( const std::map& sequence, const gtsam::noiseModel::Base* model, size_t N); - Parameters parameters() const; + This::Parameters parameters() const; }; } // namespace gtsam diff --git a/gtsam/config.h.in b/gtsam/config.h.in index e7623c52b..d47329a62 100644 --- a/gtsam/config.h.in +++ b/gtsam/config.h.in @@ -70,7 +70,7 @@ #cmakedefine GTSAM_THROW_CHEIRALITY_EXCEPTION // Make sure dependent projects that want it can see deprecated functions -#cmakedefine GTSAM_ALLOW_DEPRECATED_SINCE_V41 +#cmakedefine GTSAM_ALLOW_DEPRECATED_SINCE_V42 // Support Metis-based nested dissection #cmakedefine GTSAM_SUPPORT_NESTED_DISSECTION diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index 9cc55ed6a..828f0b1a2 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -18,8 +18,13 @@ #pragma once +#include #include +#include +#include +#include +#include namespace gtsam { /** @@ -27,21 +32,28 @@ namespace gtsam { * Just has some nice constructors and some syntactic sugar * TODO: consider eliminating this class altogether? */ - template - class AlgebraicDecisionTree: public DecisionTree { + template + class GTSAM_EXPORT AlgebraicDecisionTree : public DecisionTree { + /** + * @brief Default method used by `labelFormatter` or `valueFormatter` when + * printing. + * + * @param x The value passed to format. + * @return std::string + */ + static std::string DefaultFormatter(const L& x) { + std::stringstream ss; + ss << x; + return ss.str(); + } - public: - - typedef DecisionTree Super; + public: + using Base = DecisionTree; /** The Real ring with addition and multiplication */ struct Ring { - static inline double zero() { - return 0.0; - } - static inline double one() { - return 1.0; - } + static inline double zero() { return 0.0; } + static inline double one() { return 1.0; } static inline double add(const double& a, const double& b) { return a + b; } @@ -54,63 +66,68 @@ namespace gtsam { static inline double div(const double& a, const double& b) { return a / b; } - static inline double id(const double& x) { - return x; - } + static inline double id(const double& x) { return x; } }; - AlgebraicDecisionTree() : - Super(1.0) { - } + AlgebraicDecisionTree() : Base(1.0) {} - AlgebraicDecisionTree(const Super& add) : - Super(add) { - } + // Explicitly non-explicit constructor + AlgebraicDecisionTree(const Base& add) : Base(add) {} /** Create a new leaf function splitting on a variable */ - AlgebraicDecisionTree(const L& label, double y1, double y2) : - Super(label, y1, y2) { - } + AlgebraicDecisionTree(const L& label, double y1, double y2) + : Base(label, y1, y2) {} /** Create a new leaf function splitting on a variable */ - AlgebraicDecisionTree(const typename Super::LabelC& labelC, double y1, double y2) : - Super(labelC, y1, y2) { - } + AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1, + double y2) + : Base(labelC, y1, y2) {} /** Create from keys and vector table */ - AlgebraicDecisionTree // - (const std::vector& labelCs, const std::vector& ys) { - this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(), - ys.end()); + AlgebraicDecisionTree // + (const std::vector& labelCs, + const std::vector& ys) { + this->root_ = + Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); } /** Create from keys and string table */ - AlgebraicDecisionTree // - (const std::vector& labelCs, const std::string& table) { + AlgebraicDecisionTree // + (const std::vector& labelCs, + const std::string& table) { // Convert string to doubles std::vector ys; std::istringstream iss(table); std::copy(std::istream_iterator(iss), - std::istream_iterator(), std::back_inserter(ys)); + std::istream_iterator(), std::back_inserter(ys)); // now call recursive Create - this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(), - ys.end()); + this->root_ = + Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); } /** Create a new function splitting on a variable */ - template - AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) : - Super(nullptr) { + template + AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) + : Base(nullptr) { this->root_ = compose(begin, end, label); } - /** Convert */ - template + /** + * Convert labels from type M to type L. + * + * @param other: The AlgebraicDecisionTree with label type M to convert. + * @param map: Map from label type M to label type L. + */ + template AlgebraicDecisionTree(const AlgebraicDecisionTree& other, - const std::map& map) { - this->root_ = this->template convert(other.root_, map, - Ring::id); + const std::map& map) { + // Functor for label conversion so we can use `convertFrom`. + std::function L_of_M = [&map](const M& label) -> L { + return map.at(label); + }; + std::function op = Ring::id; + this->root_ = this->template convertFrom(other.root_, L_of_M, op); } /** sum */ @@ -134,12 +151,31 @@ namespace gtsam { } /** sum out variable */ - AlgebraicDecisionTree sum(const typename Super::LabelC& labelC) const { + AlgebraicDecisionTree sum(const typename Base::LabelC& labelC) const { return this->combine(labelC, &Ring::add); } - }; -// AlgebraicDecisionTree + /// print method customized to value type `double`. + void print(const std::string& s, + const typename Base::LabelFormatter& labelFormatter = + &DefaultFormatter) const { + auto valueFormatter = [](const double& v) { + return (boost::format("%4.4g") % v).str(); + }; + Base::print(s, labelFormatter, valueFormatter); + } -} -// namespace gtsam + /// Equality method customized to value type `double`. + bool equals(const AlgebraicDecisionTree& other, double tol = 1e-9) const { + // lambda for comparison of two doubles upto some tolerance. + auto compare = [tol](double a, double b) { + return std::abs(a - b) < tol; + }; + return Base::equals(other, compare); + } + }; + +template +struct traits> + : public Testable> {}; +} // namespace gtsam diff --git a/gtsam/discrete/Assignment.h b/gtsam/discrete/Assignment.h index 3665d6dfa..cdbf0a2e9 100644 --- a/gtsam/discrete/Assignment.h +++ b/gtsam/discrete/Assignment.h @@ -19,32 +19,30 @@ #pragma once #include -#include #include - +#include +#include namespace gtsam { - /** - * An assignment from labels to value index (size_t). - * Assigns to each label a value. Implemented as a simple map. - * A discrete factor takes an Assignment and returns a value. - */ - template - class Assignment: public std::map { - public: - void print(const std::string& s = "Assignment: ") const { - std::cout << s << ": "; - for(const typename Assignment::value_type& keyValue: *this) - std::cout << "(" << keyValue.first << ", " << keyValue.second << ")"; - std::cout << std::endl; - } - - bool equals(const Assignment& other, double tol = 1e-9) const { - return (*this == other); - } - }; //Assignment +/** + * An assignment from labels to value index (size_t). + * Assigns to each label a value. Implemented as a simple map. + * A discrete factor takes an Assignment and returns a value. + */ +template +class Assignment : public std::map { + public: + void print(const std::string& s = "Assignment: ") const { + std::cout << s << ": "; + for (const typename Assignment::value_type& keyValue : *this) + std::cout << "(" << keyValue.first << ", " << keyValue.second << ")"; + std::cout << std::endl; + } + bool equals(const Assignment& other, double tol = 1e-9) const { + return (*this == other); + } /** * @brief Get Cartesian product consisting all possible configurations @@ -58,29 +56,28 @@ namespace gtsam { * variables with each having cardinalities 4, we get 4096 possible * configurations!! */ - template - std::vector > cartesianProduct( - const std::vector >& keys) { - std::vector > allPossValues; - Assignment values; + template > + static std::vector CartesianProduct( + const std::vector>& keys) { + std::vector allPossValues; + Derived values; typedef std::pair DiscreteKey; - for(const DiscreteKey& key: keys) - values[key.first] = 0; //Initialize from 0 + for (const DiscreteKey& key : keys) + values[key.first] = 0; // Initialize from 0 while (1) { allPossValues.push_back(values); size_t j = 0; for (j = 0; j < keys.size(); j++) { L idx = keys[j].first; values[idx]++; - if (values[idx] < keys[j].second) - break; - //Wrap condition + if (values[idx] < keys[j].second) break; + // Wrap condition values[idx] = 0; } - if (j == keys.size()) - break; + if (j == keys.size()) break; } return allPossValues; } +}; // Assignment -} // namespace gtsam +} // namespace gtsam diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 439889ebf..01c7b689c 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -20,42 +20,45 @@ #pragma once #include -#include +#include +#include #include +#include +#include #include #include -#include -using boost::assign::operator+=; +#include #include -#include - -#include #include #include +#include +#include +#include #include +#include +#include + +using boost::assign::operator+=; namespace gtsam { - /*********************************************************************************/ + /****************************************************************************/ // Node - /*********************************************************************************/ + /****************************************************************************/ #ifdef DT_DEBUG_MEMORY template int DecisionTree::Node::nrNodes = 0; #endif - /*********************************************************************************/ + /****************************************************************************/ // Leaf - /*********************************************************************************/ - template - class DecisionTree::Leaf: public DecisionTree::Node { - + /****************************************************************************/ + template + struct DecisionTree::Leaf : public DecisionTree::Node { /** constant stored in this leaf */ Y constant_; - public: - /** Constructor from constant */ Leaf(const Y& constant) : constant_(constant) {} @@ -76,23 +79,26 @@ namespace gtsam { } /** equality up to tolerance */ - bool equals(const Node& q, double tol) const override { - const Leaf* other = dynamic_cast (&q); + bool equals(const Node& q, const CompareFunc& compare) const override { + const Leaf* other = dynamic_cast(&q); if (!other) return false; - return std::abs(double(this->constant_ - other->constant_)) < tol; + return compare(this->constant_, other->constant_); } /** print */ - void print(const std::string& s) const override { - bool showZero = true; - if (showZero || constant_) std::cout << s << " Leaf " << constant_ << std::endl; + void print(const std::string& s, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter) const override { + std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl; } - /** to graphviz file */ - void dot(std::ostream& os, bool showZero) const override { - if (showZero || constant_) os << "\"" << this->id() << "\" [label=\"" - << boost::format("%4.2g") % constant_ - << "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; // width=0.55, + /** Write graphviz format to stream `os`. */ + void dot(std::ostream& os, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero) const override { + std::string value = valueFormatter(constant_); + if (showZero || value.compare("0")) + os << "\"" << this->id() << "\" [label=\"" << value + << "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; } /** evaluate */ @@ -117,13 +123,13 @@ namespace gtsam { // Applying binary operator to two leaves results in a leaf NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override { - NodePtr h(new Leaf(op(fL.constant_, constant_))); // fL op gL + NodePtr h(new Leaf(op(fL.constant_, constant_))); // fL op gL return h; } // If second argument is a Choice node, call it's apply with leaf as second NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override { - return fC.apply_fC_op_gL(*this, op); // operand order back to normal + return fC.apply_fC_op_gL(*this, op); // operand order back to normal } /** choose a branch, create new memory ! */ @@ -132,32 +138,30 @@ namespace gtsam { } bool isLeaf() const override { return true; } + }; // Leaf - }; // Leaf - - /*********************************************************************************/ + /****************************************************************************/ // Choice - /*********************************************************************************/ + /****************************************************************************/ template - class DecisionTree::Choice: public DecisionTree::Node { - + struct DecisionTree::Choice: public DecisionTree::Node { /** the label of the variable on which we split */ L label_; /** The children of this Choice node. */ std::vector branches_; - private: + private: /** incremental allSame */ size_t allSame_; - typedef boost::shared_ptr ChoicePtr; - - public: + using ChoicePtr = boost::shared_ptr; + public: ~Choice() override { #ifdef DT_DEBUG_MEMORY - std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id() << std::std::endl; + std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id() + << std::std::endl; #endif } @@ -168,7 +172,8 @@ namespace gtsam { assert(f->branches().size() > 0); NodePtr f0 = f->branches_[0]; assert(f0->isLeaf()); - NodePtr newLeaf(new Leaf(boost::dynamic_pointer_cast(f0)->constant())); + NodePtr newLeaf( + new Leaf(boost::dynamic_pointer_cast(f0)->constant())); return newLeaf; } else #endif @@ -188,7 +193,6 @@ namespace gtsam { */ Choice(const Choice& f, const Choice& g, const Binary& op) : allSame_(true) { - // Choose what to do based on label if (f.label() > g.label()) { // f higher than g @@ -236,32 +240,38 @@ namespace gtsam { } /** print (as a tree) */ - void print(const std::string& s) const override { + void print(const std::string& s, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter) const override { std::cout << s << " Choice("; - // std::cout << this << ","; - std::cout << label_ << ") " << std::endl; + std::cout << labelFormatter(label_) << ") " << std::endl; for (size_t i = 0; i < branches_.size(); i++) - branches_[i]->print((boost::format("%s %d") % s % i).str()); + branches_[i]->print((boost::format("%s %d") % s % i).str(), + labelFormatter, valueFormatter); } /** output to graphviz (as a a graph) */ - void dot(std::ostream& os, bool showZero) const override { + void dot(std::ostream& os, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero) const override { os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_ << "\"]\n"; - for (size_t i = 0; i < branches_.size(); i++) { - NodePtr branch = branches_[i]; + size_t B = branches_.size(); + for (size_t i = 0; i < B; i++) { + const NodePtr& branch = branches_[i]; // Check if zero if (!showZero) { - const Leaf* leaf = dynamic_cast (branch.get()); - if (leaf && !leaf->constant()) continue; + const Leaf* leaf = dynamic_cast(branch.get()); + if (leaf && valueFormatter(leaf->constant()).compare("0")) continue; } os << "\"" << this->id() << "\" -> \"" << branch->id() << "\""; - if (i == 0) os << " [style=dashed]"; - if (i > 1) os << " [style=bold]"; + if (B == 2) { + if (i == 0) os << " [style=dashed]"; + if (i > 1) os << " [style=bold]"; + } os << std::endl; - branch->dot(os, showZero); + branch->dot(os, labelFormatter, valueFormatter, showZero); } } @@ -275,15 +285,16 @@ namespace gtsam { return (q.isLeaf() && q.sameLeaf(*this)); } - /** equality up to tolerance */ - bool equals(const Node& q, double tol) const override { - const Choice* other = dynamic_cast (&q); + /** equality */ + bool equals(const Node& q, const CompareFunc& compare) const override { + const Choice* other = dynamic_cast(&q); if (!other) return false; if (this->label_ != other->label_) return false; if (branches_.size() != other->branches_.size()) return false; // we don't care about shared pointers being equal here for (size_t i = 0; i < branches_.size(); i++) - if (!(branches_[i]->equals(*(other->branches_[i]), tol))) return false; + if (!(branches_[i]->equals(*(other->branches_[i]), compare))) + return false; return true; } @@ -307,15 +318,13 @@ namespace gtsam { */ Choice(const L& label, const Choice& f, const Unary& op) : label_(label), allSame_(true) { - - branches_.reserve(f.branches_.size()); // reserve space - for (const NodePtr& branch: f.branches_) - push_back(branch->apply(op)); + branches_.reserve(f.branches_.size()); // reserve space + for (const NodePtr& branch : f.branches_) push_back(branch->apply(op)); } /** apply unary operator */ NodePtr apply(const Unary& op) const override { - boost::shared_ptr r(new Choice(label_, *this, op)); + auto r = boost::make_shared(label_, *this, op); return Unique(r); } @@ -330,44 +339,42 @@ namespace gtsam { // If second argument of binary op is Leaf node, recurse on branches NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override { - boost::shared_ptr h(new Choice(label(), nrChoices())); - for(NodePtr branch: branches_) - h->push_back(fL.apply_f_op_g(*branch, op)); + auto h = boost::make_shared(label(), nrChoices()); + for (auto&& branch : branches_) + h->push_back(fL.apply_f_op_g(*branch, op)); return Unique(h); } // If second argument of binary op is Choice, call constructor NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override { - boost::shared_ptr h(new Choice(fC, *this, op)); + auto h = boost::make_shared(fC, *this, op); return Unique(h); } // If second argument of binary op is Leaf template NodePtr apply_fC_op_gL(const Leaf& gL, OP op) const { - boost::shared_ptr h(new Choice(label(), nrChoices())); - for(const NodePtr& branch: branches_) - h->push_back(branch->apply_f_op_g(gL, op)); + auto h = boost::make_shared(label(), nrChoices()); + for (auto&& branch : branches_) + h->push_back(branch->apply_f_op_g(gL, op)); return Unique(h); } /** choose a branch, recursively */ NodePtr choose(const L& label, size_t index) const override { - if (label_ == label) - return branches_[index]; // choose branch + if (label_ == label) return branches_[index]; // choose branch // second case, not label of interest, just recurse - boost::shared_ptr r(new Choice(label_, branches_.size())); - for(const NodePtr& branch: branches_) - r->push_back(branch->choose(label, index)); + auto r = boost::make_shared(label_, branches_.size()); + for (auto&& branch : branches_) + r->push_back(branch->choose(label, index)); return Unique(r); } + }; // Choice - }; // Choice - - /*********************************************************************************/ + /****************************************************************************/ // DecisionTree - /*********************************************************************************/ + /****************************************************************************/ template DecisionTree::DecisionTree() { } @@ -377,37 +384,36 @@ namespace gtsam { root_(root) { } - /*********************************************************************************/ + /****************************************************************************/ template DecisionTree::DecisionTree(const Y& y) { root_ = NodePtr(new Leaf(y)); } - /*********************************************************************************/ - template - DecisionTree::DecisionTree(// - const L& label, const Y& y1, const Y& y2) { - boost::shared_ptr a(new Choice(label, 2)); + /****************************************************************************/ + template + DecisionTree::DecisionTree(const L& label, const Y& y1, const Y& y2) { + auto a = boost::make_shared(label, 2); NodePtr l1(new Leaf(y1)), l2(new Leaf(y2)); a->push_back(l1); a->push_back(l2); root_ = Choice::Unique(a); } - /*********************************************************************************/ - template - DecisionTree::DecisionTree(// - const LabelC& labelC, const Y& y1, const Y& y2) { + /****************************************************************************/ + template + DecisionTree::DecisionTree(const LabelC& labelC, const Y& y1, + const Y& y2) { if (labelC.second != 2) throw std::invalid_argument( "DecisionTree: binary constructor called with non-binary label"); - boost::shared_ptr a(new Choice(labelC.first, 2)); + auto a = boost::make_shared(labelC.first, 2); NodePtr l1(new Leaf(y1)), l2(new Leaf(y2)); a->push_back(l1); a->push_back(l2); root_ = Choice::Unique(a); } - /*********************************************************************************/ + /****************************************************************************/ template DecisionTree::DecisionTree(const std::vector& labelCs, const std::vector& ys) { @@ -415,29 +421,28 @@ namespace gtsam { root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); } - /*********************************************************************************/ + /****************************************************************************/ template DecisionTree::DecisionTree(const std::vector& labelCs, const std::string& table) { - // Convert std::string to values of type Y std::vector ys; std::istringstream iss(table); copy(std::istream_iterator(iss), std::istream_iterator(), - back_inserter(ys)); + back_inserter(ys)); // now call recursive Create root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); } - /*********************************************************************************/ + /****************************************************************************/ template template DecisionTree::DecisionTree( Iterator begin, Iterator end, const L& label) { root_ = compose(begin, end, label); } - /*********************************************************************************/ + /****************************************************************************/ template DecisionTree::DecisionTree(const L& label, const DecisionTree& f0, const DecisionTree& f1) { @@ -446,24 +451,35 @@ namespace gtsam { root_ = compose(functions.begin(), functions.end(), label); } - /*********************************************************************************/ - template - template - DecisionTree::DecisionTree(const DecisionTree& other, - const std::map& map, std::function op) { - root_ = convert(other.root_, map, op); + /****************************************************************************/ + template + template + DecisionTree::DecisionTree(const DecisionTree& other, + Func Y_of_X) { + // Define functor for identity mapping of node label. + auto L_of_L = [](const L& label) { return label; }; + root_ = convertFrom(other.root_, L_of_L, Y_of_X); } - /*********************************************************************************/ - // Called by two constructors above. - // Takes a label and a corresponding range of decision trees, and creates a new - // decision tree. However, the order of the labels needs to be respected, so we - // cannot just create a root Choice node on the label: if the label is not the - // highest label, we need to do a complicated and expensive recursive call. - template template - typename DecisionTree::NodePtr DecisionTree::compose(Iterator begin, - Iterator end, const L& label) const { + /****************************************************************************/ + template + template + DecisionTree::DecisionTree(const DecisionTree& other, + const std::map& map, Func Y_of_X) { + auto L_of_M = [&map](const M& label) -> L { return map.at(label); }; + root_ = convertFrom(other.root_, L_of_M, Y_of_X); + } + /****************************************************************************/ + // Called by two constructors above. + // Takes a label and a corresponding range of decision trees, and creates a + // new decision tree. However, the order of the labels needs to be respected, + // so we cannot just create a root Choice node on the label: if the label is + // not the highest label, we need a complicated/ expensive recursive call. + template + template + typename DecisionTree::NodePtr DecisionTree::compose( + Iterator begin, Iterator end, const L& label) const { // find highest label among branches boost::optional highestLabel; size_t nrChoices = 0; @@ -480,13 +496,14 @@ namespace gtsam { // if label is already in correct order, just put together a choice on label if (!nrChoices || !highestLabel || label > *highestLabel) { - boost::shared_ptr choiceOnLabel(new Choice(label, end - begin)); + auto choiceOnLabel = boost::make_shared(label, end - begin); for (Iterator it = begin; it != end; it++) choiceOnLabel->push_back(it->root_); return Choice::Unique(choiceOnLabel); } else { // Set up a new choice on the highest label - boost::shared_ptr choiceOnHighestLabel(new Choice(*highestLabel, nrChoices)); + auto choiceOnHighestLabel = + boost::make_shared(*highestLabel, nrChoices); // now, for all possible values of highestLabel for (size_t index = 0; index < nrChoices; index++) { // make a new set of functions for composing by iterating over the given @@ -505,7 +522,7 @@ namespace gtsam { } } - /*********************************************************************************/ + /****************************************************************************/ // "create" is a bit of a complicated thing, but very useful. // It takes a range of labels and a corresponding range of values, // and creates a decision tree, as follows: @@ -530,7 +547,6 @@ namespace gtsam { template typename DecisionTree::NodePtr DecisionTree::create( It begin, It end, ValueIt beginY, ValueIt endY) const { - // get crucial counts size_t nrChoices = begin->second; size_t size = endY - beginY; @@ -542,10 +558,14 @@ namespace gtsam { // Create a simple choice node with values as leaves. if (size != nrChoices) { std::cout << "Trying to create DD on " << begin->first << std::endl; - std::cout << boost::format("DecisionTree::create: expected %d values but got %d instead") % nrChoices % size << std::endl; + std::cout << boost::format( + "DecisionTree::create: expected %d values but got %d " + "instead") % + nrChoices % size + << std::endl; throw std::invalid_argument("DecisionTree::create invalid argument"); } - boost::shared_ptr choice(new Choice(begin->first, endY - beginY)); + auto choice = boost::make_shared(begin->first, endY - beginY); for (ValueIt y = beginY; y != endY; y++) choice->push_back(NodePtr(new Leaf(*y))); return Choice::Unique(choice); @@ -558,56 +578,140 @@ namespace gtsam { size_t split = size / nrChoices; for (size_t i = 0; i < nrChoices; i++, beginY += split) { NodePtr f = create(labelC, end, beginY, beginY + split); - functions += DecisionTree(f); + functions.emplace_back(f); } return compose(functions.begin(), functions.end(), begin->first); } - /*********************************************************************************/ - template - template - typename DecisionTree::NodePtr DecisionTree::convert( - const typename DecisionTree::NodePtr& f, const std::map& map, - std::function op) { + /****************************************************************************/ + template + template + typename DecisionTree::NodePtr DecisionTree::convertFrom( + const typename DecisionTree::NodePtr& f, + std::function L_of_M, + std::function Y_of_X) const { + using LY = DecisionTree; - typedef DecisionTree MX; - typedef typename MX::Leaf MXLeaf; - typedef typename MX::Choice MXChoice; - typedef typename MX::NodePtr MXNodePtr; - typedef DecisionTree LY; - - // ugliness below because apparently we can't have templated virtual functions - // If leaf, apply unary conversion "op" and create a unique leaf - const MXLeaf* leaf = dynamic_cast (f.get()); - if (leaf) return NodePtr(new Leaf(op(leaf->constant()))); + // ugliness below because apparently we can't have templated virtual + // functions If leaf, apply unary conversion "op" and create a unique leaf + using MXLeaf = typename DecisionTree::Leaf; + if (auto leaf = boost::dynamic_pointer_cast(f)) + return NodePtr(new Leaf(Y_of_X(leaf->constant()))); // Check if Choice - boost::shared_ptr choice = boost::dynamic_pointer_cast (f); + using MXChoice = typename DecisionTree::Choice; + auto choice = boost::dynamic_pointer_cast(f); if (!choice) throw std::invalid_argument( - "DecisionTree::Convert: Invalid NodePtr"); + "DecisionTree::convertFrom: Invalid NodePtr"); // get new label - M oldLabel = choice->label(); - L newLabel = map.at(oldLabel); + const M oldLabel = choice->label(); + const L newLabel = L_of_M(oldLabel); // put together via Shannon expansion otherwise not sorted. std::vector functions; - for(const MXNodePtr& branch: choice->branches()) { - LY converted(convert(branch, map, op)); - functions += converted; + for (auto&& branch : choice->branches()) { + functions.emplace_back(convertFrom(branch, L_of_M, Y_of_X)); } return LY::compose(functions.begin(), functions.end(), newLabel); } - /*********************************************************************************/ - template - bool DecisionTree::equals(const DecisionTree& other, double tol) const { - return root_->equals(*other.root_, tol); + /****************************************************************************/ + // Functor performing depth-first visit without Assignment argument. + template + struct Visit { + using F = std::function; + explicit Visit(F f) : f(f) {} ///< Construct from folding function. + F f; ///< folding function object. + + /// Do a depth-first visit on the tree rooted at node. + void operator()(const typename DecisionTree::NodePtr& node) const { + using Leaf = typename DecisionTree::Leaf; + if (auto leaf = boost::dynamic_pointer_cast(node)) + return f(leaf->constant()); + + using Choice = typename DecisionTree::Choice; + auto choice = boost::dynamic_pointer_cast(node); + if (!choice) + throw std::invalid_argument("DecisionTree::Visit: Invalid NodePtr"); + for (auto&& branch : choice->branches()) (*this)(branch); // recurse! + } + }; + + template + template + void DecisionTree::visit(Func f) const { + Visit visit(f); + visit(root_); } - template - void DecisionTree::print(const std::string& s) const { - root_->print(s); + /****************************************************************************/ + // Functor performing depth-first visit with Assignment argument. + template + struct VisitWith { + using Choices = Assignment; + using F = std::function; + explicit VisitWith(F f) : f(f) {} ///< Construct from folding function. + Choices choices; ///< Assignment, mutating through recursion. + F f; ///< folding function object. + + /// Do a depth-first visit on the tree rooted at node. + void operator()(const typename DecisionTree::NodePtr& node) { + using Leaf = typename DecisionTree::Leaf; + if (auto leaf = boost::dynamic_pointer_cast(node)) + return f(choices, leaf->constant()); + + using Choice = typename DecisionTree::Choice; + auto choice = boost::dynamic_pointer_cast(node); + if (!choice) + throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr"); + for (size_t i = 0; i < choice->nrChoices(); i++) { + choices[choice->label()] = i; // Set assignment for label to i + (*this)(choice->branches()[i]); // recurse! + } + } + }; + + template + template + void DecisionTree::visitWith(Func f) const { + VisitWith visit(f); + visit(root_); + } + + /****************************************************************************/ + // fold is just done with a visit + template + template + X DecisionTree::fold(Func f, X x0) const { + visit([&](const Y& y) { x0 = f(y, x0); }); + return x0; + } + + /****************************************************************************/ + // labels is just done with a visit + template + std::set DecisionTree::labels() const { + std::set unique; + auto f = [&](const Assignment& choices, const Y&) { + for (auto&& kv : choices) unique.insert(kv.first); + }; + visitWith(f); + return unique; + } + +/****************************************************************************/ + template + bool DecisionTree::equals(const DecisionTree& other, + const CompareFunc& compare) const { + return root_->equals(*other.root_, compare); + } + + template + void DecisionTree::print(const std::string& s, + const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter) const { + root_->print(s, labelFormatter, valueFormatter); } template @@ -622,13 +726,23 @@ namespace gtsam { template DecisionTree DecisionTree::apply(const Unary& op) const { + // It is unclear what should happen if tree is empty: + if (empty()) { + throw std::runtime_error( + "DecisionTree::apply(unary op) undefined for empty tree."); + } return DecisionTree(root_->apply(op)); } - /*********************************************************************************/ + /****************************************************************************/ template DecisionTree DecisionTree::apply(const DecisionTree& g, const Binary& op) const { + // It is unclear what should happen if either tree is empty: + if (empty() || g.empty()) { + throw std::runtime_error( + "DecisionTree::apply(binary op) undefined for empty trees."); + } // apply the operaton on the root of both diagrams NodePtr h = root_->apply_f_op_g(*g.root_, op); // create a new class with the resulting root "h" @@ -636,7 +750,7 @@ namespace gtsam { return result; } - /*********************************************************************************/ + /****************************************************************************/ // The way this works: // We have an ADT, picture it as a tree. // At a certain depth, we have a branch on "label". @@ -656,25 +770,40 @@ namespace gtsam { return result; } - /*********************************************************************************/ - template - void DecisionTree::dot(std::ostream& os, bool showZero) const { + /****************************************************************************/ + template + void DecisionTree::dot(std::ostream& os, + const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero) const { os << "digraph G {\n"; - root_->dot(os, showZero); + root_->dot(os, labelFormatter, valueFormatter, showZero); os << " [ordering=out]}" << std::endl; } - template - void DecisionTree::dot(const std::string& name, bool showZero) const { + template + void DecisionTree::dot(const std::string& name, + const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero) const { std::ofstream os((name + ".dot").c_str()); - dot(os, showZero); - int result = system( - ("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str()); - if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed"); -} + dot(os, labelFormatter, valueFormatter, showZero); + int result = + system(("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null") + .c_str()); + if (result == -1) + throw std::runtime_error("DecisionTree::dot system call failed"); + } -/*********************************************************************************/ - -} // namespace gtsam + template + std::string DecisionTree::dot(const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero) const { + std::stringstream ss; + dot(ss, labelFormatter, valueFormatter, showZero); + return ss.str(); + } +/******************************************************************************/ + } // namespace gtsam diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 0ee0b8be0..d655756b8 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -19,12 +19,17 @@ #pragma once +#include #include #include #include #include #include +#include +#include +#include +#include #include namespace gtsam { @@ -36,24 +41,31 @@ namespace gtsam { */ template class DecisionTree { + protected: + /// Default method for comparison of two objects of type Y. + static bool DefaultCompare(const Y& a, const Y& b) { + return a == b; + } - public: + public: + using LabelFormatter = std::function; + using ValueFormatter = std::function; + using CompareFunc = std::function; /** Handy typedefs for unary and binary function types */ - typedef std::function Unary; - typedef std::function Binary; + using Unary = std::function; + using Binary = std::function; /** A label annotated with cardinality */ - typedef std::pair LabelC; + using LabelC = std::pair; /** DTs consist of Leaf and Choice nodes, both subclasses of Node */ - class Leaf; - class Choice; + struct Leaf; + struct Choice; /** ------------------------ Node base class --------------------------- */ - class Node { - public: - typedef boost::shared_ptr Ptr; + struct Node { + using Ptr = boost::shared_ptr; #ifdef DT_DEBUG_MEMORY static int nrNodes; @@ -62,14 +74,16 @@ namespace gtsam { // Constructor Node() { #ifdef DT_DEBUG_MEMORY - std::cout << ++nrNodes << " constructed " << id() << std::endl; std::cout.flush(); + std::cout << ++nrNodes << " constructed " << id() << std::endl; + std::cout.flush(); #endif } // Destructor virtual ~Node() { #ifdef DT_DEBUG_MEMORY - std::cout << --nrNodes << " destructed " << id() << std::endl; std::cout.flush(); + std::cout << --nrNodes << " destructed " << id() << std::endl; + std::cout.flush(); #endif } @@ -77,11 +91,16 @@ namespace gtsam { const void* id() const { return this; } // everything else is virtual, no documentation here as internal - virtual void print(const std::string& s = "") const = 0; - virtual void dot(std::ostream& os, bool showZero) const = 0; + virtual void print(const std::string& s, + const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter) const = 0; + virtual void dot(std::ostream& os, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero) const = 0; virtual bool sameLeaf(const Leaf& q) const = 0; virtual bool sameLeaf(const Node& q) const = 0; - virtual bool equals(const Node& other, double tol = 1e-9) const = 0; + virtual bool equals(const Node& other, const CompareFunc& compare = + &DefaultCompare) const = 0; virtual const Y& operator()(const Assignment& x) const = 0; virtual Ptr apply(const Unary& op) const = 0; virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0; @@ -92,35 +111,44 @@ namespace gtsam { }; /** ------------------------ Node base class --------------------------- */ - public: - + public: /** A function is a shared pointer to the root of a DT */ - typedef typename Node::Ptr NodePtr; + using NodePtr = typename Node::Ptr; - /* a DecisionTree just contains the root */ + /// A DecisionTree just contains the root. TODO(dellaert): make protected. NodePtr root_; - protected: - - /** Internal recursive function to create from keys, cardinalities, and Y values */ + protected: + /** Internal recursive function to create from keys, cardinalities, + * and Y values + */ template NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const; - /** Convert to a different type */ - template NodePtr - convert(const typename DecisionTree::NodePtr& f, const std::map& map, std::function op); - - /** Default constructor */ - DecisionTree(); - - public: + /** + * @brief Convert from a DecisionTree to DecisionTree. + * + * @tparam M The previous label type. + * @tparam X The previous value type. + * @param f The node pointer to the root of the previous DecisionTree. + * @param L_of_M Functor to convert from label type M to type L. + * @param Y_of_X Functor to convert from value type X to type Y. + * @return NodePtr + */ + template + NodePtr convertFrom(const typename DecisionTree::NodePtr& f, + std::function L_of_M, + std::function Y_of_X) const; + public: /// @name Standard Constructors /// @{ + /** Default constructor (for serialization) */ + DecisionTree(); + /** Create a constant */ - DecisionTree(const Y& y); + explicit DecisionTree(const Y& y); /** Create a new leaf function splitting on a variable */ DecisionTree(const L& label, const Y& y1, const Y& y2); @@ -139,23 +167,50 @@ namespace gtsam { DecisionTree(Iterator begin, Iterator end, const L& label); /** Create DecisionTree from two others */ - DecisionTree(const L& label, // - const DecisionTree& f0, const DecisionTree& f1); + DecisionTree(const L& label, const DecisionTree& f0, + const DecisionTree& f1); - /** Convert from a different type */ - template - DecisionTree(const DecisionTree& other, - const std::map& map, std::function op); + /** + * @brief Convert from a different value type. + * + * @tparam X The previous value type. + * @param other The DecisionTree to convert from. + * @param Y_of_X Functor to convert from value type X to type Y. + */ + template + DecisionTree(const DecisionTree& other, Func Y_of_X); + + /** + * @brief Convert from a different value type X to value type Y, also transate + * labels via map from type M to L. + * + * @tparam M Previous label type. + * @tparam X Previous value type. + * @param other The decision tree to convert. + * @param L_of_M Map from label type M to type L. + * @param Y_of_X Functor to convert from type X to type Y. + */ + template + DecisionTree(const DecisionTree& other, const std::map& map, + Func Y_of_X); /// @} /// @name Testable /// @{ - /** GTSAM-style print */ - void print(const std::string& s = "DecisionTree") const; + /** + * @brief GTSAM-style print + * + * @param s Prefix string. + * @param labelFormatter Functor to format the node label. + * @param valueFormatter Functor to format the node value. + */ + void print(const std::string& s, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter) const; // Testable - bool equals(const DecisionTree& other, double tol = 1e-9) const; + bool equals(const DecisionTree& other, + const CompareFunc& compare = &DefaultCompare) const; /// @} /// @name Standard Interface @@ -165,12 +220,66 @@ namespace gtsam { virtual ~DecisionTree() { } + /// Check if tree is empty. + bool empty() const { return !root_; } + /** equality */ bool operator==(const DecisionTree& q) const; /** evaluate */ const Y& operator()(const Assignment& x) const; + /** + * @brief Visit all leaves in depth-first fashion. + * + * @param f side-effect taking a value. + * + * @note Due to pruning, leaves might not exhaust choices. + * + * Example: + * int sum = 0; + * auto visitor = [&](int y) { sum += y; }; + * tree.visitWith(visitor); + */ + template + void visit(Func f) const; + + /** + * @brief Visit all leaves in depth-first fashion. + * + * @param f side-effect taking an assignment and a value. + * + * @note Due to pruning, leaves might not exhaust choices. + * + * Example: + * int sum = 0; + * auto visitor = [&](const Assignment& choices, int y) { sum += y; }; + * tree.visitWith(visitor); + */ + template + void visitWith(Func f) const; + + /** + * @brief Fold a binary function over the tree, returning accumulator. + * + * @tparam X type for accumulator. + * @param f binary function: Y * X -> X returning an updated accumulator. + * @param x0 initial value for accumulator. + * @return X final value for accumulator. + * + * @note X is always passed by value. + * @note Due to pruning, leaves might not exhaust choices. + * + * Example: + * auto add = [](const double& y, double x) { return y + x; }; + * double sum = tree.fold(add, 0.0); + */ + template + X fold(Func f, X x0) const; + + /** Retrieve all unique labels as a set. */ + std::set labels() const; + /** apply Unary operation "op" to f */ DecisionTree apply(const Unary& op) const; @@ -185,7 +294,8 @@ namespace gtsam { } /** combine subtrees on key with binary operation "op" */ - DecisionTree combine(const L& label, size_t cardinality, const Binary& op) const; + DecisionTree combine(const L& label, size_t cardinality, + const Binary& op) const; /** combine with LabelC for convenience */ DecisionTree combine(const LabelC& labelC, const Binary& op) const { @@ -193,38 +303,61 @@ namespace gtsam { } /** output to graphviz format, stream version */ - void dot(std::ostream& os, bool showZero = true) const; + void dot(std::ostream& os, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, bool showZero = true) const; /** output to graphviz format, open a file */ - void dot(const std::string& name, bool showZero = true) const; + void dot(const std::string& name, const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, bool showZero = true) const; + + /** output to graphviz format string */ + std::string dot(const LabelFormatter& labelFormatter, + const ValueFormatter& valueFormatter, + bool showZero = true) const; /// @name Advanced Interface /// @{ // internal use only - DecisionTree(const NodePtr& root); + explicit DecisionTree(const NodePtr& root); // internal use only template NodePtr compose(Iterator begin, Iterator end, const L& label) const; /// @} - - }; // DecisionTree + }; // DecisionTree /** free versions of apply */ - template + /// Apply unary operator `op` to DecisionTree `f`. + template DecisionTree apply(const DecisionTree& f, const typename DecisionTree::Unary& op) { return f.apply(op); } - template + /// Apply binary operator `op` to DecisionTree `f`. + template DecisionTree apply(const DecisionTree& f, const DecisionTree& g, const typename DecisionTree::Binary& op) { return f.apply(g, op); } -} // namespace gtsam + /** + * @brief unzip a DecisionTree with `std::pair` values. + * + * @param input the DecisionTree with `(T1,T2)` values. + * @return a pair of DecisionTree on T1 and T2, respectively. + */ + template + std::pair, DecisionTree > unzip( + const DecisionTree >& input) { + return std::make_pair( + DecisionTree(input, [](std::pair i) { return i.first; }), + DecisionTree(input, + [](std::pair i) { return i.second; })); + } + +} // namespace gtsam diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index b7b9d7034..ef4cc48f6 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -17,74 +17,90 @@ * @author Frank Dellaert */ +#include #include #include -#include #include +#include +#include using namespace std; namespace gtsam { - /* ******************************************************************************** */ - DecisionTreeFactor::DecisionTreeFactor() { - } + /* ************************************************************************ */ + DecisionTreeFactor::DecisionTreeFactor() {} - /* ******************************************************************************** */ + /* ************************************************************************ */ DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, - const ADT& potentials) : - DiscreteFactor(keys.indices()), Potentials(keys, potentials) { - } + const ADT& potentials) + : DiscreteFactor(keys.indices()), + ADT(potentials), + cardinalities_(keys.cardinalities()) {} - /* *************************************************************************/ - DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) : - DiscreteFactor(c.keys()), Potentials(c) { - } + /* ************************************************************************ */ + DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) + : DiscreteFactor(c.keys()), + AlgebraicDecisionTree(c), + cardinalities_(c.cardinalities_) {} - /* ************************************************************************* */ - bool DecisionTreeFactor::equals(const DiscreteFactor& other, double tol) const { - if(!dynamic_cast(&other)) { + /* ************************************************************************ */ + bool DecisionTreeFactor::equals(const DiscreteFactor& other, + double tol) const { + if (!dynamic_cast(&other)) { return false; - } - else { - const DecisionTreeFactor& f(static_cast(other)); - return Potentials::equals(f, tol); + } else { + const auto& f(static_cast(other)); + return ADT::equals(f, tol); } } - /* ************************************************************************* */ + /* ************************************************************************ */ + double DecisionTreeFactor::safe_div(const double& a, const double& b) { + // The use for safe_div is when we divide the product factor by the sum + // factor. If the product or sum is zero, we accord zero probability to the + // event. + return (a == 0 || b == 0) ? 0 : (a / b); + } + + /* ************************************************************************ */ void DecisionTreeFactor::print(const string& s, - const KeyFormatter& formatter) const { + const KeyFormatter& formatter) const { cout << s; - Potentials::print("Potentials:",formatter); + cout << " f["; + for (auto&& key : keys()) + cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key); + cout << " ]" << endl; + ADT::print("", formatter); } - /* ************************************************************************* */ + /* ************************************************************************ */ DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f, - ADT::Binary op) const { - map cs; // new cardinalities + ADT::Binary op) const { + map cs; // new cardinalities // make unique key-cardinality map - for(Key j: keys()) cs[j] = cardinality(j); - for(Key j: f.keys()) cs[j] = f.cardinality(j); + for (Key j : keys()) cs[j] = cardinality(j); + for (Key j : f.keys()) cs[j] = f.cardinality(j); // Convert map into keys DiscreteKeys keys; - for(const std::pair& key: cs) - keys.push_back(key); + for (const std::pair& key : cs) keys.push_back(key); // apply operand ADT result = ADT::apply(f, op); // Make a new factor return DecisionTreeFactor(keys, result); } - /* ************************************************************************* */ - DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(size_t nrFrontals, - ADT::Binary op) const { - - if (nrFrontals > size()) throw invalid_argument( - (boost::format( - "DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d") - % nrFrontals % size()).str()); + /* ************************************************************************ */ + DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine( + size_t nrFrontals, ADT::Binary op) const { + if (nrFrontals > size()) + throw invalid_argument( + (boost::format( + "DecisionTreeFactor::combine: invalid number of frontal " + "keys %d, nr.keys=%d") % + nrFrontals % size()) + .str()); // sum over nrFrontals keys size_t i; @@ -98,20 +114,21 @@ namespace gtsam { DiscreteKeys dkeys; for (; i < keys().size(); i++) { Key j = keys()[i]; - dkeys.push_back(DiscreteKey(j,cardinality(j))); + dkeys.push_back(DiscreteKey(j, cardinality(j))); } return boost::make_shared(dkeys, result); } - - /* ************************************************************************* */ - DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(const Ordering& frontalKeys, - ADT::Binary op) const { - - if (frontalKeys.size() > size()) throw invalid_argument( - (boost::format( - "DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d") - % frontalKeys.size() % size()).str()); + /* ************************************************************************ */ + DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine( + const Ordering& frontalKeys, ADT::Binary op) const { + if (frontalKeys.size() > size()) + throw invalid_argument( + (boost::format( + "DecisionTreeFactor::combine: invalid number of frontal " + "keys %d, nr.keys=%d") % + frontalKeys.size() % size()) + .str()); // sum over nrFrontals keys size_t i; @@ -122,17 +139,155 @@ namespace gtsam { } // create new factor, note we collect keys that are not in frontalKeys - // TODO: why do we need this??? result should contain correct keys!!! + // TODO(frank): why do we need this??? result should contain correct keys!!! DiscreteKeys dkeys; for (i = 0; i < keys().size(); i++) { Key j = keys()[i]; - // TODO: inefficient! - if (std::find(frontalKeys.begin(), frontalKeys.end(), j) != frontalKeys.end()) + // TODO(frank): inefficient! + if (std::find(frontalKeys.begin(), frontalKeys.end(), j) != + frontalKeys.end()) continue; - dkeys.push_back(DiscreteKey(j,cardinality(j))); + dkeys.push_back(DiscreteKey(j, cardinality(j))); } return boost::make_shared(dkeys, result); } -/* ************************************************************************* */ -} // namespace gtsam + /* ************************************************************************ */ + std::vector> DecisionTreeFactor::enumerate() + const { + // Get all possible assignments + std::vector> pairs; + for (auto& key : keys()) { + pairs.emplace_back(key, cardinalities_.at(key)); + } + // Reverse to make cartesian product output a more natural ordering. + std::vector> rpairs(pairs.rbegin(), pairs.rend()); + const auto assignments = DiscreteValues::CartesianProduct(rpairs); + + // Construct unordered_map with values + std::vector> result; + for (const auto& assignment : assignments) { + result.emplace_back(assignment, operator()(assignment)); + } + return result; + } + + /* ************************************************************************ */ + DiscreteKeys DecisionTreeFactor::discreteKeys() const { + DiscreteKeys result; + for (auto&& key : keys()) { + DiscreteKey dkey(key, cardinality(key)); + if (std::find(result.begin(), result.end(), dkey) == result.end()) { + result.push_back(dkey); + } + } + return result; + } + + /* ************************************************************************ */ + static std::string valueFormatter(const double& v) { + return (boost::format("%4.2g") % v).str(); + } + + /** output to graphviz format, stream version */ + void DecisionTreeFactor::dot(std::ostream& os, + const KeyFormatter& keyFormatter, + bool showZero) const { + ADT::dot(os, keyFormatter, valueFormatter, showZero); + } + + /** output to graphviz format, open a file */ + void DecisionTreeFactor::dot(const std::string& name, + const KeyFormatter& keyFormatter, + bool showZero) const { + ADT::dot(name, keyFormatter, valueFormatter, showZero); + } + + /** output to graphviz format string */ + std::string DecisionTreeFactor::dot(const KeyFormatter& keyFormatter, + bool showZero) const { + return ADT::dot(keyFormatter, valueFormatter, showZero); + } + + // Print out header. + /* ************************************************************************ */ + string DecisionTreeFactor::markdown(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; + + // Print out header. + ss << "|"; + for (auto& key : keys()) { + ss << keyFormatter(key) << "|"; + } + ss << "value|\n"; + + // Print out separator with alignment hints. + ss << "|"; + for (size_t j = 0; j < size(); j++) ss << ":-:|"; + ss << ":-:|\n"; + + // Print out all rows. + auto rows = enumerate(); + for (const auto& kv : rows) { + ss << "|"; + auto assignment = kv.first; + for (auto& key : keys()) { + size_t index = assignment.at(key); + ss << DiscreteValues::Translate(names, key, index) << "|"; + } + ss << kv.second << "|\n"; + } + return ss.str(); + } + + /* ************************************************************************ */ + string DecisionTreeFactor::html(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; + + // Print out preamble. + ss << "
\n\n \n"; + + // Print out header row. + ss << " "; + for (auto& key : keys()) { + ss << ""; + } + ss << "\n"; + + // Finish header and start body. + ss << " \n \n"; + + // Print out all rows. + auto rows = enumerate(); + for (const auto& kv : rows) { + ss << " "; + auto assignment = kv.first; + for (auto& key : keys()) { + size_t index = assignment.at(key); + ss << ""; + } + ss << ""; // value + ss << "\n"; + } + ss << " \n
" << keyFormatter(key) << "value
" << DiscreteValues::Translate(names, key, index) << "" << kv.second << "
\n
"; + return ss.str(); + } + + /* ************************************************************************ */ + DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, + const vector& table) + : DiscreteFactor(keys.indices()), + AlgebraicDecisionTree(keys, table), + cardinalities_(keys.cardinalities()) {} + + /* ************************************************************************ */ + DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, + const string& table) + : DiscreteFactor(keys.indices()), + AlgebraicDecisionTree(keys, table), + cardinalities_(keys.cardinalities()) {} + + /* ************************************************************************ */ +} // namespace gtsam diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index d1696a281..91fa7c484 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -18,15 +18,18 @@ #pragma once +#include #include -#include +#include #include +#include #include - -#include -#include +#include #include +#include +#include +#include namespace gtsam { @@ -35,34 +38,46 @@ namespace gtsam { /** * A discrete probabilistic factor */ - class GTSAM_EXPORT DecisionTreeFactor: public DiscreteFactor, public Potentials { - - public: - + class GTSAM_EXPORT DecisionTreeFactor : public DiscreteFactor, + public AlgebraicDecisionTree { + public: // typedefs needed to play nice with gtsam typedef DecisionTreeFactor This; - typedef DiscreteFactor Base; ///< Typedef to base class + typedef DiscreteFactor Base; ///< Typedef to base class typedef boost::shared_ptr shared_ptr; + typedef AlgebraicDecisionTree ADT; - public: + protected: + std::map cardinalities_; + public: /// @name Standard Constructors /// @{ /** Default constructor for I/O */ DecisionTreeFactor(); - /** Constructor from Indices, Ordering, and AlgebraicDecisionDiagram */ + /** Constructor from DiscreteKeys and AlgebraicDecisionTree */ DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials); - /** Constructor from Indices and (string or doubles) */ - template - DecisionTreeFactor(const DiscreteKeys& keys, SOURCE table) : - DiscreteFactor(keys.indices()), Potentials(keys, table) { - } + /** Constructor from doubles */ + DecisionTreeFactor(const DiscreteKeys& keys, + const std::vector& table); + + /** Constructor from string */ + DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table); + + /// Single-key specialization + template + DecisionTreeFactor(const DiscreteKey& key, SOURCE table) + : DecisionTreeFactor(DiscreteKeys{key}, table) {} + + /// Single-key specialization, with vector of doubles. + DecisionTreeFactor(const DiscreteKey& key, const std::vector& row) + : DecisionTreeFactor(DiscreteKeys{key}, row) {} /** Construct from a DiscreteConditional type */ - DecisionTreeFactor(const DiscreteConditional& c); + explicit DecisionTreeFactor(const DiscreteConditional& c); /// @} /// @name Testable @@ -72,7 +87,8 @@ namespace gtsam { bool equals(const DiscreteFactor& other, double tol = 1e-9) const override; // print - void print(const std::string& s = "DecisionTreeFactor:\n", + void print( + const std::string& s = "DecisionTreeFactor:\n", const KeyFormatter& formatter = DefaultKeyFormatter) const override; /// @} @@ -80,8 +96,8 @@ namespace gtsam { /// @{ /// Value is just look up in AlgebraicDecisonTree - double operator()(const Values& values) const override { - return Potentials::operator()(values); + double operator()(const DiscreteValues& values) const override { + return ADT::operator()(values); } /// multiply two factors @@ -89,15 +105,17 @@ namespace gtsam { return apply(f, ADT::Ring::mul); } + static double safe_div(const double& a, const double& b); + + size_t cardinality(Key j) const { return cardinalities_.at(j); } + /// divide by factor f (safely) DecisionTreeFactor operator/(const DecisionTreeFactor& f) const { return apply(f, safe_div); } /// Convert into a decisiontree - DecisionTreeFactor toDecisionTreeFactor() const override { - return *this; - } + DecisionTreeFactor toDecisionTreeFactor() const override { return *this; } /// Create new factor by summing all values with the same separator values shared_ptr sum(size_t nrFrontals) const { @@ -109,11 +127,16 @@ namespace gtsam { return combine(keys, ADT::Ring::add); } - /// Create new factor by maximizing over all values with the same separator values + /// Create new factor by maximizing over all values with the same separator. shared_ptr max(size_t nrFrontals) const { return combine(nrFrontals, ADT::Ring::max); } + /// Create new factor by maximizing over all values with the same separator. + shared_ptr max(const Ordering& keys) const { + return combine(keys, ADT::Ring::max); + } + /// @} /// @name Advanced Interface /// @{ @@ -121,14 +144,14 @@ namespace gtsam { /** * Apply binary operator (*this) "op" f * @param f the second argument for op - * @param op a binary operator that operates on AlgebraicDecisionDiagram potentials + * @param op a binary operator that operates on AlgebraicDecisionTree */ DecisionTreeFactor apply(const DecisionTreeFactor& f, ADT::Binary op) const; /** * Combine frontal variables using binary operator "op" * @param nrFrontals nr. of frontal to combine variables in this factor - * @param op a binary operator that operates on AlgebraicDecisionDiagram potentials + * @param op a binary operator that operates on AlgebraicDecisionTree * @return shared pointer to newly created DecisionTreeFactor */ shared_ptr combine(size_t nrFrontals, ADT::Binary op) const; @@ -136,37 +159,60 @@ namespace gtsam { /** * Combine frontal variables in an Ordering using binary operator "op" * @param nrFrontals nr. of frontal to combine variables in this factor - * @param op a binary operator that operates on AlgebraicDecisionDiagram potentials + * @param op a binary operator that operates on AlgebraicDecisionTree * @return shared pointer to newly created DecisionTreeFactor */ shared_ptr combine(const Ordering& keys, ADT::Binary op) const; + /// Enumerate all values into a map from values to double. + std::vector> enumerate() const; -// /** -// * @brief Permutes the keys in Potentials and DiscreteFactor -// * -// * This re-implements the permuteWithInverse() in both Potentials -// * and DiscreteFactor by doing both of them together. -// */ -// -// void permuteWithInverse(const Permutation& inversePermutation){ -// DiscreteFactor::permuteWithInverse(inversePermutation); -// Potentials::permuteWithInverse(inversePermutation); -// } -// -// /** -// * Apply a reduction, which is a remapping of variable indices. -// */ -// virtual void reduceWithInverse(const internal::Reduction& inverseReduction) { -// DiscreteFactor::reduceWithInverse(inverseReduction); -// Potentials::reduceWithInverse(inverseReduction); -// } + /// Return all the discrete keys associated with this factor. + DiscreteKeys discreteKeys() const; /// @} -}; -// DecisionTreeFactor + /// @name Wrapper support + /// @{ + + /** output to graphviz format, stream version */ + void dot(std::ostream& os, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + bool showZero = true) const; + + /** output to graphviz format, open a file */ + void dot(const std::string& name, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + bool showZero = true) const; + + /** output to graphviz format string */ + std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + bool showZero = true) const; + + /** + * @brief Render as markdown table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, category names corresponding to choices. + * @return std::string a markdown string. + */ + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override; + + /** + * @brief Render as html table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, category names corresponding to choices. + * @return std::string a html string. + */ + std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override; + + /// @} + }; // traits -template<> struct traits : public Testable {}; +template <> +struct traits : public Testable {}; -}// namespace gtsam +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index 84a80c565..ccc52585e 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -25,51 +25,78 @@ namespace gtsam { - // Instantiate base class - template class FactorGraph; - - /* ************************************************************************* */ - bool DiscreteBayesNet::equals(const This& bn, double tol) const - { - return Base::equals(bn, tol); - } - - /* ************************************************************************* */ -// void DiscreteBayesNet::add_front(const Signature& s) { -// push_front(boost::make_shared(s)); -// } - - /* ************************************************************************* */ - void DiscreteBayesNet::add(const Signature& s) { - push_back(boost::make_shared(s)); - } - - /* ************************************************************************* */ - double DiscreteBayesNet::evaluate(const DiscreteConditional::Values & values) const { - // evaluate all conditionals and multiply - double result = 1.0; - for(DiscreteConditional::shared_ptr conditional: *this) - result *= (*conditional)(values); - return result; - } - - /* ************************************************************************* */ - DiscreteFactor::sharedValues DiscreteBayesNet::optimize() const { - // solve each node in turn in topological sort order (parents first) - DiscreteFactor::sharedValues result(new DiscreteFactor::Values()); - for (auto conditional: boost::adaptors::reverse(*this)) - conditional->solveInPlace(*result); - return result; - } - - /* ************************************************************************* */ - DiscreteFactor::sharedValues DiscreteBayesNet::sample() const { - // sample each node in turn in topological sort order (parents first) - DiscreteFactor::sharedValues result(new DiscreteFactor::Values()); - for (auto conditional: boost::adaptors::reverse(*this)) - conditional->sampleInPlace(*result); - return result; - } +// Instantiate base class +template class FactorGraph; /* ************************************************************************* */ -} // namespace +bool DiscreteBayesNet::equals(const This& bn, double tol) const { + return Base::equals(bn, tol); +} + +/* ************************************************************************* */ +double DiscreteBayesNet::evaluate(const DiscreteValues& values) const { + // evaluate all conditionals and multiply + double result = 1.0; + for (const DiscreteConditional::shared_ptr& conditional : *this) + result *= (*conditional)(values); + return result; +} + +/* ************************************************************************* */ +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 +DiscreteValues DiscreteBayesNet::optimize() const { + DiscreteValues result; + return optimize(result); +} + +DiscreteValues DiscreteBayesNet::optimize(DiscreteValues result) const { + // solve each node in turn in topological sort order (parents first) +#ifdef _MSC_VER +#pragma message("DiscreteBayesNet::optimize (deprecated) does not compute MPE!") +#else +#warning "DiscreteBayesNet::optimize (deprecated) does not compute MPE!" +#endif + for (auto conditional : boost::adaptors::reverse(*this)) + conditional->solveInPlace(&result); + return result; +} +#endif + +/* ************************************************************************* */ +DiscreteValues DiscreteBayesNet::sample() const { + DiscreteValues result; + return sample(result); +} + +DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const { + // sample each node in turn in topological sort order (parents first) + for (auto conditional : boost::adaptors::reverse(*this)) + conditional->sampleInPlace(&result); + return result; +} + +/* *********************************************************************** */ +std::string DiscreteBayesNet::markdown( + const KeyFormatter& keyFormatter, + const DiscreteFactor::Names& names) const { + using std::endl; + std::stringstream ss; + ss << "`DiscreteBayesNet` of size " << size() << endl << endl; + for (const DiscreteConditional::shared_ptr& conditional : *this) + ss << conditional->markdown(keyFormatter, names) << endl; + return ss.str(); +} + +/* *********************************************************************** */ +std::string DiscreteBayesNet::html(const KeyFormatter& keyFormatter, + const DiscreteFactor::Names& names) const { + using std::endl; + std::stringstream ss; + ss << "

DiscreteBayesNet of size " << size() << "

"; + for (const DiscreteConditional::shared_ptr& conditional : *this) + ss << conditional->html(keyFormatter, names) << endl; + return ss.str(); +} + +/* ************************************************************************* */ +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index d5ba30584..df94d6908 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -13,25 +13,31 @@ * @file DiscreteBayesNet.h * @date Feb 15, 2011 * @author Duy-Nguyen Ta + * @author Frank dellaert */ #pragma once -#include -#include -#include +#include +#include #include #include -#include + +#include +#include +#include +#include +#include namespace gtsam { -/** A Bayes net made from linear-Discrete densities */ - class GTSAM_EXPORT DiscreteBayesNet: public BayesNet - { - public: - - typedef FactorGraph Base; +/** + * A Bayes net made from discrete conditional distributions. + * @addtogroup discrete + */ +class GTSAM_EXPORT DiscreteBayesNet: public BayesNet { + public: + typedef BayesNet Base; typedef DiscreteBayesNet This; typedef DiscreteConditional ConditionalType; typedef boost::shared_ptr shared_ptr; @@ -40,20 +46,24 @@ namespace gtsam { /// @name Standard Constructors /// @{ - /** Construct empty factor graph */ + /// Construct empty Bayes net. DiscreteBayesNet() {} /** Construct from iterator over conditionals */ - template - DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {} + template + DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) + : Base(firstConditional, lastConditional) {} /** Construct from container of factors (shared_ptr or plain objects) */ - template - explicit DiscreteBayesNet(const CONTAINER& conditionals) : Base(conditionals) {} + template + explicit DiscreteBayesNet(const CONTAINER& conditionals) + : Base(conditionals) {} - /** Implicit copy/downcast constructor to override explicit template container constructor */ - template - DiscreteBayesNet(const FactorGraph& graph) : Base(graph) {} + /** Implicit copy/downcast constructor to override explicit template + * container constructor */ + template + DiscreteBayesNet(const FactorGraph& graph) + : Base(graph) {} /// Destructor virtual ~DiscreteBayesNet() {} @@ -71,26 +81,73 @@ namespace gtsam { /// @name Standard Interface /// @{ + // Add inherited versions of add. + using Base::add; + + /** Add a DiscreteDistribution using a table or a string */ + void add(const DiscreteKey& key, const std::string& spec) { + emplace_shared(key, spec); + } + /** Add a DiscreteCondtional */ - void add(const Signature& s); + template + void add(Args&&... args) { + emplace_shared(std::forward(args)...); + } + + //** evaluate for given DiscreteValues */ + double evaluate(const DiscreteValues & values) const; -// /** Add a DiscreteCondtional in front, when listing parents first*/ -// GTSAM_EXPORT void add_front(const Signature& s); - - //** evaluate for given Values */ - double evaluate(const DiscreteConditional::Values & values) const; + //** (Preferred) sugar for the above for given DiscreteValues */ + double operator()(const DiscreteValues & values) const { + return evaluate(values); + } /** - * Solve the DiscreteBayesNet by back-substitution - */ - DiscreteFactor::sharedValues optimize() const; + * @brief do ancestral sampling + * + * Assumes the Bayes net is reverse topologically sorted, i.e. last + * conditional will be sampled first. If the Bayes net resulted from + * eliminating a factor graph, this is true for the elimination ordering. + * + * @return a sampled value for all variables. + */ + DiscreteValues sample() const; - /** Do ancestral sampling */ - DiscreteFactor::sharedValues sample() const; + /** + * @brief do ancestral sampling, given certain variables. + * + * Assumes the Bayes net is reverse topologically sorted *and* that the + * Bayes net does not contain any conditionals for the given values. + * + * @return given values extended with sampled value for all other variables. + */ + DiscreteValues sample(DiscreteValues given) const; + + ///@} + /// @name Wrapper support + /// @{ + + /// Render as markdown tables. + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteFactor::Names& names = {}) const; + + /// Render as html tables. + std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteFactor::Names& names = {}) const; ///@} - private: +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + /// @name Deprecated functionality + /// @{ + + DiscreteValues GTSAM_DEPRECATED optimize() const; + DiscreteValues GTSAM_DEPRECATED optimize(DiscreteValues given) const; + /// @} +#endif + + private: /** Serialization function */ friend class boost::serialization::access; template diff --git a/gtsam/discrete/DiscreteBayesTree.cpp b/gtsam/discrete/DiscreteBayesTree.cpp index 990d10dbe..139292eee 100644 --- a/gtsam/discrete/DiscreteBayesTree.cpp +++ b/gtsam/discrete/DiscreteBayesTree.cpp @@ -31,7 +31,7 @@ namespace gtsam { /* ************************************************************************* */ double DiscreteBayesTreeClique::evaluate( - const DiscreteConditional::Values& values) const { + const DiscreteValues& values) const { // evaluate all conditionals and multiply double result = (*conditional_)(values); for (const auto& child : children) { @@ -47,7 +47,7 @@ namespace gtsam { /* ************************************************************************* */ double DiscreteBayesTree::evaluate( - const DiscreteConditional::Values& values) const { + const DiscreteValues& values) const { double result = 1.0; for (const auto& root : roots_) { result *= root->evaluate(values); @@ -55,8 +55,40 @@ namespace gtsam { return result; } -} // \namespace gtsam - - + /* **************************************************************************/ + std::string DiscreteBayesTree::markdown( + const KeyFormatter& keyFormatter, + const DiscreteFactor::Names& names) const { + using std::endl; + std::stringstream ss; + ss << "`DiscreteBayesTree` of size " << nodes_.size() << endl << endl; + auto visitor = [&](const DiscreteBayesTreeClique::shared_ptr& clique, + size_t& indent) { + ss << "\n" << clique->conditional()->markdown(keyFormatter, names); + return indent + 1; + }; + size_t indent; + treeTraversal::DepthFirstForest(*this, indent, visitor); + return ss.str(); + } + /* **************************************************************************/ + std::string DiscreteBayesTree::html( + const KeyFormatter& keyFormatter, + const DiscreteFactor::Names& names) const { + using std::endl; + std::stringstream ss; + ss << "

DiscreteBayesTree of size " << nodes_.size() + << "

"; + auto visitor = [&](const DiscreteBayesTreeClique::shared_ptr& clique, + size_t& indent) { + ss << clique->conditional()->html(keyFormatter, names); + return indent + 1; + }; + size_t indent; + treeTraversal::DepthFirstForest(*this, indent, visitor); + return ss.str(); + } + /* **************************************************************************/ + } // namespace gtsam diff --git a/gtsam/discrete/DiscreteBayesTree.h b/gtsam/discrete/DiscreteBayesTree.h index 29da5817e..809ce9c83 100644 --- a/gtsam/discrete/DiscreteBayesTree.h +++ b/gtsam/discrete/DiscreteBayesTree.h @@ -57,8 +57,8 @@ class GTSAM_EXPORT DiscreteBayesTreeClique conditional_->printSignature(s, formatter); } - //** evaluate conditional probability of subtree for given Values */ - double evaluate(const DiscreteConditional::Values& values) const; + //** evaluate conditional probability of subtree for given DiscreteValues */ + double evaluate(const DiscreteValues& values) const; }; /* ************************************************************************* */ @@ -72,14 +72,35 @@ class GTSAM_EXPORT DiscreteBayesTree typedef DiscreteBayesTree This; typedef boost::shared_ptr shared_ptr; + /// @name Standard interface + /// @{ /** Default constructor, creates an empty Bayes tree */ DiscreteBayesTree() {} /** Check equality */ bool equals(const This& other, double tol = 1e-9) const; - //** evaluate probability for given Values */ - double evaluate(const DiscreteConditional::Values& values) const; + //** evaluate probability for given DiscreteValues */ + double evaluate(const DiscreteValues& values) const; + + //** (Preferred) sugar for the above for given DiscreteValues */ + double operator()(const DiscreteValues& values) const { + return evaluate(values); + } + + /// @} + /// @name Wrapper support + /// @{ + + /// Render as markdown tables. + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteFactor::Names& names = {}) const; + + /// Render as html tables. + std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteFactor::Names& names = {}) const; + + /// @} }; } // namespace gtsam diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index ac7c58405..06b2856f8 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -16,57 +16,119 @@ * @author Frank Dellaert */ +#include +#include #include #include #include -#include -#include - -#include #include +#include #include +#include #include #include +#include #include using namespace std; - +using std::pair; +using std::stringstream; +using std::vector; namespace gtsam { // Instantiate base class -template class Conditional ; +template class GTSAM_EXPORT + Conditional; -/* ******************************************************************************** */ +/* ************************************************************************** */ DiscreteConditional::DiscreteConditional(const size_t nrFrontals, - const DecisionTreeFactor& f) : - BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) { -} + const DecisionTreeFactor& f) + : BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {} -/* ******************************************************************************** */ -DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint, - const DecisionTreeFactor& marginal) : - BaseFactor( - ISDEBUG("DiscreteConditional::COUNT") ? joint : joint / marginal), BaseConditional( - joint.size()-marginal.size()) { - if (ISDEBUG("DiscreteConditional::DiscreteConditional")) - cout << (firstFrontalKey()) << endl; //TODO Print all keys -} +/* ************************************************************************** */ +DiscreteConditional::DiscreteConditional(size_t nrFrontals, + const DiscreteKeys& keys, + const ADT& potentials) + : BaseFactor(keys, potentials), BaseConditional(nrFrontals) {} -/* ******************************************************************************** */ +/* ************************************************************************** */ DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint, - const DecisionTreeFactor& marginal, const Ordering& orderedKeys) : - DiscreteConditional(joint, marginal) { + const DecisionTreeFactor& marginal) + : BaseFactor(joint / marginal), + BaseConditional(joint.size() - marginal.size()) {} + +/* ************************************************************************** */ +DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint, + const DecisionTreeFactor& marginal, + const Ordering& orderedKeys) + : DiscreteConditional(joint, marginal) { keys_.clear(); keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end()); } -/* ******************************************************************************** */ +/* ************************************************************************** */ DiscreteConditional::DiscreteConditional(const Signature& signature) : BaseFactor(signature.discreteKeys(), signature.cpt()), BaseConditional(1) {} -/* ******************************************************************************** */ +/* ************************************************************************** */ +DiscreteConditional DiscreteConditional::operator*( + const DiscreteConditional& other) const { + // Take union of frontal keys + std::set newFrontals; + for (auto&& key : this->frontals()) newFrontals.insert(key); + for (auto&& key : other.frontals()) newFrontals.insert(key); + + // Check if frontals overlapped + if (nrFrontals() + other.nrFrontals() > newFrontals.size()) + throw std::invalid_argument( + "DiscreteConditional::operator* called with overlapping frontal keys."); + + // Now, add cardinalities. + DiscreteKeys discreteKeys; + for (auto&& key : frontals()) + discreteKeys.emplace_back(key, cardinality(key)); + for (auto&& key : other.frontals()) + discreteKeys.emplace_back(key, other.cardinality(key)); + + // Sort + std::sort(discreteKeys.begin(), discreteKeys.end()); + + // Add parents to set, to make them unique + std::set parents; + for (auto&& key : this->parents()) + if (!newFrontals.count(key)) parents.emplace(key, cardinality(key)); + for (auto&& key : other.parents()) + if (!newFrontals.count(key)) parents.emplace(key, other.cardinality(key)); + + // Finally, add parents to keys, in order + for (auto&& dk : parents) discreteKeys.push_back(dk); + + ADT product = ADT::apply(other, ADT::Ring::mul); + return DiscreteConditional(newFrontals.size(), discreteKeys, product); +} + +/* ************************************************************************** */ +DiscreteConditional DiscreteConditional::marginal(Key key) const { + if (nrParents() > 0) + throw std::invalid_argument( + "DiscreteConditional::marginal: single argument version only valid for " + "fully specified joint distributions (i.e., no parents)."); + + // Calculate the keys as the frontal keys without the given key. + DiscreteKeys discreteKeys{{key, cardinality(key)}}; + + // Calculate sum + ADT adt(*this); + for (auto&& k : frontals()) + if (k != key) adt = adt.sum(k, cardinality(k)); + + // Return new factor + return DiscreteConditional(1, discreteKeys, adt); +} + +/* ************************************************************************** */ void DiscreteConditional::print(const string& s, const KeyFormatter& formatter) const { cout << s << " P( "; @@ -79,122 +141,196 @@ void DiscreteConditional::print(const string& s, cout << formatter(*it) << " "; } } - cout << ")"; - Potentials::print(""); + cout << "):\n"; + ADT::print("", formatter); cout << endl; } -/* ******************************************************************************** */ +/* ************************************************************************** */ bool DiscreteConditional::equals(const DiscreteFactor& other, - double tol) const { - if (!dynamic_cast(&other)) + double tol) const { + if (!dynamic_cast(&other)) { return false; - else { - const DecisionTreeFactor& f( - static_cast(other)); + } else { + const DecisionTreeFactor& f(static_cast(other)); return DecisionTreeFactor::equals(f, tol); } } -/* ******************************************************************************** */ -Potentials::ADT DiscreteConditional::choose(const Values& parentsValues) const { - ADT pFS(*this); - Key j; size_t value; - for(Key key: parents()) { +/* ************************************************************************** */ +DiscreteConditional::ADT DiscreteConditional::choose( + const DiscreteValues& given, bool forceComplete) const { + // Get the big decision tree with all the levels, and then go down the + // branches based on the value of the parent variables. + DiscreteConditional::ADT adt(*this); + size_t value; + for (Key j : parents()) { try { - j = (key); - value = parentsValues.at(j); - pFS = pFS.choose(j, value); - } catch (exception&) { - cout << "Key: " << j << " Value: " << value << endl; - parentsValues.print("parentsValues: "); - // pFS.print("pFS: "); - throw runtime_error("DiscreteConditional::choose: parent value missing"); - }; + value = given.at(j); + adt = adt.choose(j, value); // ADT keeps getting smaller. + } catch (std::out_of_range&) { + if (forceComplete) { + given.print("parentsValues: "); + throw runtime_error( + "DiscreteConditional::choose: parent value missing"); + } + } } - - return pFS; + return adt; } -/* ******************************************************************************** */ -void DiscreteConditional::solveInPlace(Values& values) const { - // TODO: Abhijit asks: is this really the fastest way? He thinks it is. - ADT pFS = choose(values); // P(F|S=parentsValues) +/* ************************************************************************** */ +DiscreteConditional::shared_ptr DiscreteConditional::choose( + const DiscreteValues& given) const { + ADT adt = choose(given, false); // P(F|S=given) + + // Collect all keys not in given. + DiscreteKeys dKeys; + for (Key j : frontals()) { + dKeys.emplace_back(j, this->cardinality(j)); + } + for (size_t i = nrFrontals(); i < size(); i++) { + Key j = keys_[i]; + if (given.count(j) == 0) { + dKeys.emplace_back(j, this->cardinality(j)); + } + } + return boost::make_shared(nrFrontals(), dKeys, adt); +} + +/* ************************************************************************** */ +DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( + const DiscreteValues& frontalValues) const { + // Get the big decision tree with all the levels, and then go down the + // branches based on the value of the frontal variables. + ADT adt(*this); + size_t value; + for (Key j : frontals()) { + try { + value = frontalValues.at(j); + adt = adt.choose(j, value); // ADT keeps getting smaller. + } catch (exception&) { + frontalValues.print("frontalValues: "); + throw runtime_error("DiscreteConditional::choose: frontal value missing"); + } + } + + // Convert ADT to factor. + DiscreteKeys discreteKeys; + for (Key j : parents()) { + discreteKeys.emplace_back(j, this->cardinality(j)); + } + return boost::make_shared(discreteKeys, adt); +} + +/* ****************************************************************************/ +DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( + size_t parent_value) const { + if (nrFrontals() != 1) + throw std::invalid_argument( + "Single value likelihood can only be invoked on single-variable " + "conditional"); + DiscreteValues values; + values.emplace(keys_[0], parent_value); + return likelihood(values); +} + +/* ************************************************************************** */ +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 +void DiscreteConditional::solveInPlace(DiscreteValues* values) const { + ADT pFS = choose(*values, true); // P(F|S=parentsValues) // Initialize - Values mpe; + DiscreteValues mpe; double maxP = 0; - DiscreteKeys keys; - for(Key idx: frontals()) { - DiscreteKey dk(idx, cardinality(idx)); - keys & dk; - } // Get all Possible Configurations - vector allPosbValues = cartesianProduct(keys); + const auto allPosbValues = frontalAssignments(); - // Find the MPE - for(Values& frontalVals: allPosbValues) { - double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues) - // Update MPE solution if better + // Find the maximum + for (const auto& frontalVals : allPosbValues) { + double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues) + // Update maximum solution if better if (pValueS > maxP) { maxP = pValueS; mpe = frontalVals; } } - //set values (inPlace) to mpe - for(Key j: frontals()) { - values[j] = mpe[j]; + // set values (inPlace) to maximum + for (Key j : frontals()) { + (*values)[j] = mpe[j]; } } -/* ******************************************************************************** */ -void DiscreteConditional::sampleInPlace(Values& values) const { - assert(nrFrontals() == 1); - Key j = (firstFrontalKey()); - size_t sampled = sample(values); // Sample variable - values[j] = sampled; // store result in partial solution -} - -/* ******************************************************************************** */ -size_t DiscreteConditional::solve(const Values& parentsValues) const { - - // TODO: is this really the fastest way? I think it is. - ADT pFS = choose(parentsValues); // P(F|S=parentsValues) +/* ************************************************************************** */ +size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const { + ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues) // Then, find the max over all remaining - // TODO, only works for one key now, seems horribly slow this way - size_t mpe = 0; - Values frontals; + size_t max = 0; double maxP = 0; + DiscreteValues frontals; assert(nrFrontals() == 1); Key j = (firstFrontalKey()); for (size_t value = 0; value < cardinality(j); value++) { frontals[j] = value; double pValueS = pFS(frontals); // P(F=value|S=parentsValues) + // Update solution if better + if (pValueS > maxP) { + maxP = pValueS; + max = value; + } + } + return max; +} +#endif + +/* ************************************************************************** */ +size_t DiscreteConditional::argmax() const { + size_t maxValue = 0; + double maxP = 0; + assert(nrFrontals() == 1); + assert(nrParents() == 0); + DiscreteValues frontals; + Key j = firstFrontalKey(); + for (size_t value = 0; value < cardinality(j); value++) { + frontals[j] = value; + double pValueS = (*this)(frontals); // Update MPE solution if better if (pValueS > maxP) { maxP = pValueS; - mpe = value; + maxValue = value; } } - return mpe; + return maxValue; } -/* ******************************************************************************** */ -size_t DiscreteConditional::sample(const Values& parentsValues) const { +/* ************************************************************************** */ +void DiscreteConditional::sampleInPlace(DiscreteValues* values) const { + assert(nrFrontals() == 1); + Key j = (firstFrontalKey()); + size_t sampled = sample(*values); // Sample variable given parents + (*values)[j] = sampled; // store result in partial solution +} + +/* ************************************************************************** */ +size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const { static mt19937 rng(2); // random number generator // Get the correct conditional density - ADT pFS = choose(parentsValues); // P(F|S=parentsValues) + ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues) // TODO(Duy): only works for one key now, seems horribly slow this way - assert(nrFrontals() == 1); + if (nrFrontals() != 1) { + throw std::invalid_argument( + "DiscreteConditional::sample can only be called on single variable " + "conditionals"); + } Key key = firstFrontalKey(); size_t nj = cardinality(key); vector p(nj); - Values frontals; + DiscreteValues frontals; for (size_t value = 0; value < nj; value++) { frontals[key] = value; p[value] = pFS(frontals); // P(F=value|S=parentsValues) @@ -206,6 +342,174 @@ size_t DiscreteConditional::sample(const Values& parentsValues) const { return distribution(rng); } -/* ******************************************************************************** */ +/* ************************************************************************** */ +size_t DiscreteConditional::sample(size_t parent_value) const { + if (nrParents() != 1) + throw std::invalid_argument( + "Single value sample() can only be invoked on single-parent " + "conditional"); + DiscreteValues values; + values.emplace(keys_.back(), parent_value); + return sample(values); +} -}// namespace +/* ************************************************************************** */ +size_t DiscreteConditional::sample() const { + if (nrParents() != 0) + throw std::invalid_argument( + "sample() can only be invoked on no-parent prior"); + DiscreteValues values; + return sample(values); +} + +/* ************************************************************************* */ +vector DiscreteConditional::frontalAssignments() const { + vector> pairs; + for (Key key : frontals()) pairs.emplace_back(key, cardinalities_.at(key)); + vector> rpairs(pairs.rbegin(), pairs.rend()); + return DiscreteValues::CartesianProduct(rpairs); +} + +/* ************************************************************************* */ +vector DiscreteConditional::allAssignments() const { + vector> pairs; + for (Key key : parents()) pairs.emplace_back(key, cardinalities_.at(key)); + for (Key key : frontals()) pairs.emplace_back(key, cardinalities_.at(key)); + vector> rpairs(pairs.rbegin(), pairs.rend()); + return DiscreteValues::CartesianProduct(rpairs); +} + +/* ************************************************************************* */ +// Print out signature. +static void streamSignature(const DiscreteConditional& conditional, + const KeyFormatter& keyFormatter, + stringstream* ss) { + *ss << "P("; + bool first = true; + for (Key key : conditional.frontals()) { + if (!first) *ss << ","; + *ss << keyFormatter(key); + first = false; + } + if (conditional.nrParents() > 0) { + *ss << "|"; + bool first = true; + for (Key parent : conditional.parents()) { + if (!first) *ss << ","; + *ss << keyFormatter(parent); + first = false; + } + } + *ss << "):"; +} + +/* ************************************************************************* */ +std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; + ss << " *"; + streamSignature(*this, keyFormatter, &ss); + ss << "*\n" << std::endl; + if (nrParents() == 0) { + // We have no parents, call factor method. + ss << DecisionTreeFactor::markdown(keyFormatter, names); + return ss.str(); + } + + // Print out header. + ss << "|"; + for (Key parent : parents()) { + ss << "*" << keyFormatter(parent) << "*|"; + } + + auto frontalAssignments = this->frontalAssignments(); + for (const auto& a : frontalAssignments) { + for (auto&& it = beginFrontals(); it != endFrontals(); ++it) { + size_t index = a.at(*it); + ss << DiscreteValues::Translate(names, *it, index); + } + ss << "|"; + } + ss << "\n"; + + // Print out separator with alignment hints. + ss << "|"; + size_t n = frontalAssignments.size(); + for (size_t j = 0; j < nrParents() + n; j++) ss << ":-:|"; + ss << "\n"; + + // Print out all rows. + size_t count = 0; + for (const auto& a : allAssignments()) { + if (count == 0) { + ss << "|"; + for (auto&& it = beginParents(); it != endParents(); ++it) { + size_t index = a.at(*it); + ss << DiscreteValues::Translate(names, *it, index) << "|"; + } + } + ss << operator()(a) << "|"; + count = (count + 1) % n; + if (count == 0) ss << "\n"; + } + return ss.str(); +} + +/* ************************************************************************ */ +string DiscreteConditional::html(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; + ss << "
\n

"; + streamSignature(*this, keyFormatter, &ss); + ss << "

\n"; + if (nrParents() == 0) { + // We have no parents, call factor method. + ss << DecisionTreeFactor::html(keyFormatter, names); + return ss.str(); + } + + // Print out preamble. + ss << "\n \n"; + + // Print out header row. + ss << " "; + for (Key parent : parents()) { + ss << ""; + } + auto frontalAssignments = this->frontalAssignments(); + for (const auto& a : frontalAssignments) { + ss << ""; + } + ss << "\n"; + + // Finish header and start body. + ss << " \n \n"; + + // Output all rows, one per assignment: + size_t count = 0, n = frontalAssignments.size(); + for (const auto& a : allAssignments()) { + if (count == 0) { + ss << " "; + for (auto&& it = beginParents(); it != endParents(); ++it) { + size_t index = a.at(*it); + ss << ""; + } + } + ss << ""; // value + count = (count + 1) % n; + if (count == 0) ss << "\n"; + } + + // Finish up + ss << " \n
" << keyFormatter(parent) << ""; + for (auto&& it = beginFrontals(); it != endFrontals(); ++it) { + size_t index = a.at(*it); + ss << DiscreteValues::Translate(names, *it, index); + } + ss << "
" << DiscreteValues::Translate(names, *it, index) << "" << operator()(a) << "
\n
"; + return ss.str(); +} + +/* ************************************************************************* */ + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 8299fab2c..48d94a383 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -21,10 +21,11 @@ #include #include #include -#include -#include +#include +#include #include +#include namespace gtsam { @@ -32,59 +33,109 @@ namespace gtsam { * Discrete Conditional Density * Derives from DecisionTreeFactor */ -class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor, - public Conditional { - -public: +class GTSAM_EXPORT DiscreteConditional + : public DecisionTreeFactor, + public Conditional { + public: // typedefs needed to play nice with gtsam - typedef DiscreteConditional This; ///< Typedef to this class - typedef boost::shared_ptr shared_ptr; ///< shared_ptr to this class - typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class - typedef Conditional BaseConditional; ///< Typedef to our conditional base class + typedef DiscreteConditional This; ///< Typedef to this class + typedef boost::shared_ptr shared_ptr; ///< shared_ptr to this class + typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class + typedef Conditional + BaseConditional; ///< Typedef to our conditional base class - /** A map from keys to values.. - * TODO: Again, do we need this??? */ - typedef Assignment Values; - typedef boost::shared_ptr sharedValues; + using Values = DiscreteValues; ///< backwards compatibility /// @name Standard Constructors /// @{ - /** default constructor needed for serialization */ - DiscreteConditional() { - } + /// Default constructor needed for serialization. + DiscreteConditional() {} - /** constructor from factor */ + /// Construct from factor, taking the first `nFrontals` keys as frontals. DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f); + /** + * Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first + * `nFrontals` keys as frontals, in the order given. + */ + DiscreteConditional(size_t nFrontals, const DiscreteKeys& keys, + const ADT& potentials); + /** Construct from signature */ - DiscreteConditional(const Signature& signature); - - /** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */ - DiscreteConditional(const DecisionTreeFactor& joint, - const DecisionTreeFactor& marginal); - - /** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */ - DiscreteConditional(const DecisionTreeFactor& joint, - const DecisionTreeFactor& marginal, const Ordering& orderedKeys); + explicit DiscreteConditional(const Signature& signature); /** - * Combine several conditional into a single one. - * The conditionals must be given in increasing order, meaning that the parents - * of any conditional may not include a conditional coming before it. - * @param firstConditional Iterator to the first conditional to combine, must dereference to a shared_ptr. - * @param lastConditional Iterator to after the last conditional to combine, must dereference to a shared_ptr. - * */ - template - static shared_ptr Combine(ITERATOR firstConditional, - ITERATOR lastConditional); + * Construct from key, parents, and a Signature::Table specifying the + * conditional probability table (CPT) in 00 01 10 11 order. For + * three-valued, it would be 00 01 02 10 11 12 20 21 22, etc.... + * + * Example: DiscreteConditional P(D, {B,E}, table); + */ + DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents, + const Signature::Table& table) + : DiscreteConditional(Signature(key, parents, table)) {} + + /** + * Construct from key, parents, and a string specifying the conditional + * probability table (CPT) in 00 01 10 11 order. For three-valued, it would + * be 00 01 02 10 11 12 20 21 22, etc.... + * + * The string is parsed into a Signature::Table. + * + * Example: DiscreteConditional P(D, {B,E}, "9/1 2/8 3/7 1/9"); + */ + DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents, + const std::string& spec) + : DiscreteConditional(Signature(key, parents, spec)) {} + + /// No-parent specialization; can also use DiscreteDistribution. + DiscreteConditional(const DiscreteKey& key, const std::string& spec) + : DiscreteConditional(Signature(key, {}, spec)) {} + + /** + * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) + * Assumes but *does not check* that f(Y)=sum_X f(X,Y). + */ + DiscreteConditional(const DecisionTreeFactor& joint, + const DecisionTreeFactor& marginal); + + /** + * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) + * Assumes but *does not check* that f(Y)=sum_X f(X,Y). + * Makes sure the keys are ordered as given. Does not check orderedKeys. + */ + DiscreteConditional(const DecisionTreeFactor& joint, + const DecisionTreeFactor& marginal, + const Ordering& orderedKeys); + + /** + * @brief Combine two conditionals, yielding a new conditional with the union + * of the frontal keys, ordered by gtsam::Key. + * + * The two conditionals must make a valid Bayes net fragment, i.e., + * the frontal variables cannot overlap, and must be acyclic: + * Example of correct use: + * P(A,B) = P(A|B) * P(B) + * P(A,B|C) = P(A|B) * P(B|C) + * P(A,B,C) = P(A,B|C) * P(C) + * Example of incorrect use: + * P(A|B) * P(A|C) = ? + * P(A|B) * P(B|A) = ? + * We check for overlapping frontals, but do *not* check for cyclic. + */ + DiscreteConditional operator*(const DiscreteConditional& other) const; + + /** Calculate marginal on given key, no parent case. */ + DiscreteConditional marginal(Key key) const; /// @} /// @name Testable /// @{ /// GTSAM-style print - void print(const std::string& s = "Discrete Conditional: ", + void print( + const std::string& s = "Discrete Conditional: ", const KeyFormatter& formatter = DefaultKeyFormatter) const override; /// GTSAM-style equals @@ -102,68 +153,95 @@ public: } /// Evaluate, just look up in AlgebraicDecisonTree - double operator()(const Values& values) const override { - return Potentials::operator()(values); + double operator()(const DiscreteValues& values) const override { + return ADT::operator()(values); } - /** Convert to a factor */ - DecisionTreeFactor::shared_ptr toFactor() const { - return DecisionTreeFactor::shared_ptr(new DecisionTreeFactor(*this)); - } - - /** Restrict to given parent values, returns AlgebraicDecisionDiagram */ - ADT choose(const Assignment& parentsValues) const; - /** - * solve a conditional - * @param parentsValues Known values of the parents - * @return MPE value of the child (1 frontal variable). + * @brief restrict to given *parent* values. + * + * Note: does not need be complete set. Examples: + * + * P(C|D,E) + . -> P(C|D,E) + * P(C|D,E) + E -> P(C|D) + * P(C|D,E) + D -> P(C|E) + * P(C|D,E) + D,E -> P(C) + * P(C|D,E) + C -> error! + * + * @return a shared_ptr to a new DiscreteConditional */ - size_t solve(const Values& parentsValues) const; + shared_ptr choose(const DiscreteValues& given) const; + + /** Convert to a likelihood factor by providing value before bar. */ + DecisionTreeFactor::shared_ptr likelihood( + const DiscreteValues& frontalValues) const; + + /** Single variable version of likelihood. */ + DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const; /** * sample * @param parentsValues Known values of the parents * @return sample from conditional */ - size_t sample(const Values& parentsValues) const; + size_t sample(const DiscreteValues& parentsValues) const; + + /// Single parent version. + size_t sample(size_t parent_value) const; + + /// Zero parent version. + size_t sample() const; + + /** + * @brief Return assignment that maximizes distribution. + * @return Optimal assignment (1 frontal variable). + */ + size_t argmax() const; /// @} /// @name Advanced Interface /// @{ - /// solve a conditional, in place - void solveInPlace(Values& parentsValues) const; - /// sample in place, stores result in partial solution - void sampleInPlace(Values& parentsValues) const; + void sampleInPlace(DiscreteValues* parentsValues) const; + + /// Return all assignments for frontal variables. + std::vector frontalAssignments() const; + + /// Return all assignments for frontal *and* parent variables. + std::vector allAssignments() const; + + /// @} + /// @name Wrapper support + /// @{ + + /// Render as markdown table. + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override; + + /// Render as html table. + std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override; /// @} +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + /// @name Deprecated functionality + /// @{ + size_t GTSAM_DEPRECATED solve(const DiscreteValues& parentsValues) const; + void GTSAM_DEPRECATED solveInPlace(DiscreteValues* parentsValues) const; + /// @} +#endif + + protected: + /// Internal version of choose + DiscreteConditional::ADT choose(const DiscreteValues& given, + bool forceComplete) const; }; // DiscreteConditional // traits -template<> struct traits : public Testable {}; - -/* ************************************************************************* */ -template -DiscreteConditional::shared_ptr DiscreteConditional::Combine( - ITERATOR firstConditional, ITERATOR lastConditional) { - // TODO: check for being a clique - - // multiply all the potentials of the given conditionals - size_t nrFrontals = 0; - DecisionTreeFactor product; - for (ITERATOR it = firstConditional; it != lastConditional; - ++it, ++nrFrontals) { - DiscreteConditional::shared_ptr c = *it; - DecisionTreeFactor::shared_ptr factor = c->toFactor(); - product = (*factor) * product; - } - // and then create a new multi-frontal conditional - return boost::make_shared(nrFrontals, product); -} - -} // gtsam +template <> +struct traits : public Testable {}; +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteDistribution.cpp b/gtsam/discrete/DiscreteDistribution.cpp new file mode 100644 index 000000000..739771470 --- /dev/null +++ b/gtsam/discrete/DiscreteDistribution.cpp @@ -0,0 +1,52 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscreteDistribution.cpp + * @date December 2021 + * @author Frank Dellaert + */ + +#include + +#include + +namespace gtsam { + +void DiscreteDistribution::print(const std::string& s, + const KeyFormatter& formatter) const { + Base::print(s, formatter); +} + +double DiscreteDistribution::operator()(size_t value) const { + if (nrFrontals() != 1) + throw std::invalid_argument( + "Single value operator can only be invoked on single-variable " + "priors"); + DiscreteValues values; + values.emplace(keys_[0], value); + return Base::operator()(values); +} + +std::vector DiscreteDistribution::pmf() const { + if (nrFrontals() != 1) + throw std::invalid_argument( + "DiscreteDistribution::pmf only defined for single-variable priors"); + const size_t nrValues = cardinalities_.at(keys_[0]); + std::vector array; + array.reserve(nrValues); + for (size_t v = 0; v < nrValues; v++) { + array.push_back(operator()(v)); + } + return array; +} + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteDistribution.h b/gtsam/discrete/DiscreteDistribution.h new file mode 100644 index 000000000..c5147dbc1 --- /dev/null +++ b/gtsam/discrete/DiscreteDistribution.h @@ -0,0 +1,107 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscreteDistribution.h + * @date December 2021 + * @author Frank Dellaert + */ + +#pragma once + +#include + +#include +#include + +namespace gtsam { + +/** + * A prior probability on a set of discrete variables. + * Derives from DiscreteConditional + */ +class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional { + public: + using Base = DiscreteConditional; + + /// @name Standard Constructors + /// @{ + + /// Default constructor needed for serialization. + DiscreteDistribution() {} + + /// Constructor from factor. + explicit DiscreteDistribution(const DecisionTreeFactor& f) + : Base(f.size(), f) {} + + /** + * Construct from a Signature. + * + * Example: DiscreteDistribution P(D % "3/2"); + */ + explicit DiscreteDistribution(const Signature& s) : Base(s) {} + + /** + * Construct from key and a vector of floats specifying the probability mass + * function (PMF). + * + * Example: DiscreteDistribution P(D, {0.4, 0.6}); + */ + DiscreteDistribution(const DiscreteKey& key, const std::vector& spec) + : DiscreteDistribution(Signature(key, {}, Signature::Table{spec})) {} + + /** + * Construct from key and a string specifying the probability mass function + * (PMF). + * + * Example: DiscreteDistribution P(D, "9/1 2/8 3/7 1/9"); + */ + DiscreteDistribution(const DiscreteKey& key, const std::string& spec) + : DiscreteDistribution(Signature(key, {}, spec)) {} + + /// @} + /// @name Testable + /// @{ + + /// GTSAM-style print + void print( + const std::string& s = "Discrete Prior: ", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; + + /// @} + /// @name Standard interface + /// @{ + + /// Evaluate given a single value. + double operator()(size_t value) const; + + /// We also want to keep the Base version, taking DiscreteValues: + // TODO(dellaert): does not play well with wrapper! + // using Base::operator(); + + /// Return entire probability mass function. + std::vector pmf() const; + + /// @} +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + /// @name Deprecated functionality + /// @{ + size_t GTSAM_DEPRECATED solve() const { return Base::solve({}); } + /// @} +#endif +}; +// DiscreteDistribution + +// traits +template <> +struct traits : public Testable {}; + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteFactor.cpp b/gtsam/discrete/DiscreteFactor.cpp index c101653d2..08309e2e1 100644 --- a/gtsam/discrete/DiscreteFactor.cpp +++ b/gtsam/discrete/DiscreteFactor.cpp @@ -17,11 +17,59 @@ * @author Frank Dellaert */ +#include #include +#include +#include + using namespace std; namespace gtsam { /* ************************************************************************* */ -} // namespace gtsam +std::vector expNormalize(const std::vector& logProbs) { + double maxLogProb = -std::numeric_limits::infinity(); + for (size_t i = 0; i < logProbs.size(); i++) { + double logProb = logProbs[i]; + if ((logProb != std::numeric_limits::infinity()) && + logProb > maxLogProb) { + maxLogProb = logProb; + } + } + + // After computing the max = "Z" of the log probabilities L_i, we compute + // the log of the normalizing constant, log S, where S = sum_j exp(L_j - Z). + double total = 0.0; + for (size_t i = 0; i < logProbs.size(); i++) { + double probPrime = exp(logProbs[i] - maxLogProb); + total += probPrime; + } + double logTotal = log(total); + + // Now we compute the (normalized) probability (for each i): + // p_i = exp(L_i - Z - log S) + double checkNormalization = 0.0; + std::vector probs; + for (size_t i = 0; i < logProbs.size(); i++) { + double prob = exp(logProbs[i] - maxLogProb - logTotal); + probs.push_back(prob); + checkNormalization += prob; + } + + // Numerical tolerance for floating point comparisons + double tol = 1e-9; + + if (!gtsam::fpEqual(checkNormalization, 1.0, tol)) { + std::string errMsg = + std::string("expNormalize failed to normalize probabilities. ") + + std::string("Expected normalization constant = 1.0. Got value: ") + + std::to_string(checkNormalization) + + std::string( + "\n This could have resulted from numerical overflow/underflow."); + throw std::logic_error(errMsg); + } + return probs; +} + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 6b0919507..212ade8cf 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -18,10 +18,11 @@ #pragma once -#include +#include #include #include +#include namespace gtsam { class DecisionTreeFactor; @@ -40,18 +41,7 @@ public: typedef boost::shared_ptr shared_ptr; ///< shared_ptr to this class typedef Factor Base; ///< Our base class - /** A map from keys to values - * TODO: Do we need this? Should we just use gtsam::Values? - * We just need another special DiscreteValue to represent labels, - * However, all other Lie's operators are undefined in this class. - * The good thing is we can have a Hybrid graph of discrete/continuous variables - * together.. - * Another good thing is we don't need to have the special DiscreteKey which stores - * cardinality of a Discrete variable. It should be handled naturally in - * the new class DiscreteValue, as the varible's type (domain) - */ - typedef Assignment Values; - typedef boost::shared_ptr sharedValues; + using Values = DiscreteValues; ///< backwards compatibility public: @@ -84,27 +74,72 @@ public: Base::print(s, formatter); } - /** Test whether the factor is empty */ - virtual bool empty() const { return size() == 0; } - /// @} /// @name Standard Interface /// @{ /// Find value for given assignment of values to variables - virtual double operator()(const Values&) const = 0; + virtual double operator()(const DiscreteValues&) const = 0; /// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; + /// @} + /// @name Wrapper support + /// @{ + + /// Translation table from values to strings. + using Names = DiscreteValues::Names; + + /** + * @brief Render as markdown table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, category names corresponding to choices. + * @return std::string a markdown string. + */ + virtual std::string markdown( + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const = 0; + + /** + * @brief Render as html table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, category names corresponding to choices. + * @return std::string a html string. + */ + virtual std::string html( + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const = 0; + /// @} }; // DiscreteFactor // traits template<> struct traits : public Testable {}; -template<> struct traits : public Testable {}; + + +/** + * @brief Normalize a set of log probabilities. + * + * Normalizing a set of log probabilities in a numerically stable way is + * tricky. To avoid overflow/underflow issues, we compute the largest + * (finite) log probability and subtract it from each log probability before + * normalizing. This comes from the observation that if: + * p_i = exp(L_i) / ( sum_j exp(L_j) ), + * Then, + * p_i = exp(Z) exp(L_i - Z) / (exp(Z) sum_j exp(L_j - Z)), + * = exp(L_i - Z) / ( sum_j exp(L_j - Z) ) + * + * Setting Z = max_j L_j, we can avoid numerical issues that arise when all + * of the (unnormalized) log probabilities are either very large or very + * small. + */ +std::vector expNormalize(const std::vector &logProbs); + }// namespace gtsam diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index e41968d6b..ebcac445c 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -16,15 +16,18 @@ * @author Frank Dellaert */ -//#define ENABLE_TIMING -#include -#include #include +#include #include +#include #include -#include +#include #include -#include +#include + +using std::vector; +using std::string; +using std::map; namespace gtsam { @@ -41,11 +44,25 @@ namespace gtsam { /* ************************************************************************* */ KeySet DiscreteFactorGraph::keys() const { KeySet keys; - for(const sharedFactor& factor: *this) - if (factor) keys.insert(factor->begin(), factor->end()); + for (const sharedFactor& factor : *this) { + if (factor) keys.insert(factor->begin(), factor->end()); + } return keys; } + /* ************************************************************************* */ + DiscreteKeys DiscreteFactorGraph::discreteKeys() const { + DiscreteKeys result; + for (auto&& factor : *this) { + if (auto p = boost::dynamic_pointer_cast(factor)) { + DiscreteKeys factor_keys = p->discreteKeys(); + result.insert(result.end(), factor_keys.begin(), factor_keys.end()); + } + } + + return result; + } + /* ************************************************************************* */ DecisionTreeFactor DiscreteFactorGraph::product() const { DecisionTreeFactor result; @@ -56,7 +73,7 @@ namespace gtsam { /* ************************************************************************* */ double DiscreteFactorGraph::operator()( - const DiscreteFactor::Values &values) const { + const DiscreteValues &values) const { double product = 1.0; for( const sharedFactor& factor: factors_ ) product *= (*factor)(values); @@ -64,7 +81,7 @@ namespace gtsam { } /* ************************************************************************* */ - void DiscreteFactorGraph::print(const std::string& s, + void DiscreteFactorGraph::print(const string& s, const KeyFormatter& formatter) const { std::cout << s << std::endl; std::cout << "size: " << size() << std::endl; @@ -93,22 +110,99 @@ namespace gtsam { // } // } - /* ************************************************************************* */ - DiscreteFactor::sharedValues DiscreteFactorGraph::optimize() const - { - gttic(DiscreteFactorGraph_optimize); - return BaseEliminateable::eliminateSequential()->optimize(); - } - - /* ************************************************************************* */ + /* ************************************************************************ */ + // Alternate eliminate function for MPE std::pair // - EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - + EliminateForMPE(const DiscreteFactorGraph& factors, + const Ordering& frontalKeys) { // PRODUCT: multiply all factors gttic(product); DecisionTreeFactor product; - for(const DiscreteFactor::shared_ptr& factor: factors) - product = (*factor) * product; + for (auto&& factor : factors) product = (*factor) * product; + gttoc(product); + + // max out frontals, this is the factor on the separator + gttic(max); + DecisionTreeFactor::shared_ptr max = product.max(frontalKeys); + gttoc(max); + + // Ordering keys for the conditional so that frontalKeys are really in front + DiscreteKeys orderedKeys; + for (auto&& key : frontalKeys) + orderedKeys.emplace_back(key, product.cardinality(key)); + for (auto&& key : max->keys()) + orderedKeys.emplace_back(key, product.cardinality(key)); + + // Make lookup with product + gttic(lookup); + size_t nrFrontals = frontalKeys.size(); + auto lookup = boost::make_shared(nrFrontals, + orderedKeys, product); + gttoc(lookup); + + return std::make_pair( + boost::dynamic_pointer_cast(lookup), max); + } + + /* ************************************************************************ */ + // sumProduct is just an alias for regular eliminateSequential. + DiscreteBayesNet DiscreteFactorGraph::sumProduct( + OptionalOrderingType orderingType) const { + gttic(DiscreteFactorGraph_sumProduct); + auto bayesNet = eliminateSequential(orderingType); + return *bayesNet; + } + + DiscreteBayesNet DiscreteFactorGraph::sumProduct( + const Ordering& ordering) const { + gttic(DiscreteFactorGraph_sumProduct); + auto bayesNet = eliminateSequential(ordering); + return *bayesNet; + } + + /* ************************************************************************ */ + // The max-product solution below is a bit clunky: the elimination machinery + // does not allow for differently *typed* versions of elimination, so we + // eliminate into a Bayes Net using the special eliminate function above, and + // then create the DiscreteLookupDAG after the fact, in linear time. + + DiscreteLookupDAG DiscreteFactorGraph::maxProduct( + OptionalOrderingType orderingType) const { + gttic(DiscreteFactorGraph_maxProduct); + auto bayesNet = eliminateSequential(orderingType, EliminateForMPE); + return DiscreteLookupDAG::FromBayesNet(*bayesNet); + } + + DiscreteLookupDAG DiscreteFactorGraph::maxProduct( + const Ordering& ordering) const { + gttic(DiscreteFactorGraph_maxProduct); + auto bayesNet = eliminateSequential(ordering, EliminateForMPE); + return DiscreteLookupDAG::FromBayesNet(*bayesNet); + } + + /* ************************************************************************ */ + DiscreteValues DiscreteFactorGraph::optimize( + OptionalOrderingType orderingType) const { + gttic(DiscreteFactorGraph_optimize); + DiscreteLookupDAG dag = maxProduct(orderingType); + return dag.argmax(); + } + + DiscreteValues DiscreteFactorGraph::optimize( + const Ordering& ordering) const { + gttic(DiscreteFactorGraph_optimize); + DiscreteLookupDAG dag = maxProduct(ordering); + return dag.argmax(); + } + + /* ************************************************************************ */ + std::pair // + EliminateDiscrete(const DiscreteFactorGraph& factors, + const Ordering& frontalKeys) { + // PRODUCT: multiply all factors + gttic(product); + DecisionTreeFactor product; + for (auto&& factor : factors) product = (*factor) * product; gttoc(product); // sum out frontals, this is the factor on the separator @@ -118,17 +212,46 @@ namespace gtsam { // Ordering keys for the conditional so that frontalKeys are really in front Ordering orderedKeys; - orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end()); - orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end()); + orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), + frontalKeys.end()); + orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), + sum->keys().end()); // now divide product/sum to get conditional gttic(divide); - DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum, orderedKeys)); + auto conditional = + boost::make_shared(product, *sum, orderedKeys); gttoc(divide); - return std::make_pair(cond, sum); + return std::make_pair(conditional, sum); } -/* ************************************************************************* */ -} // namespace + /* ************************************************************************ */ + string DiscreteFactorGraph::markdown( + const KeyFormatter& keyFormatter, + const DiscreteFactor::Names& names) const { + using std::endl; + std::stringstream ss; + ss << "`DiscreteFactorGraph` of size " << size() << endl << endl; + for (size_t i = 0; i < factors_.size(); i++) { + ss << "factor " << i << ":\n"; + ss << factors_[i]->markdown(keyFormatter, names) << endl; + } + return ss.str(); + } + /* ************************************************************************ */ + string DiscreteFactorGraph::html(const KeyFormatter& keyFormatter, + const DiscreteFactor::Names& names) const { + using std::endl; + std::stringstream ss; + ss << "

DiscreteFactorGraph of size " << size() << "

"; + for (size_t i = 0; i < factors_.size(); i++) { + ss << "

factor " << i << ":

"; + ss << factors_[i]->html(keyFormatter, names) << endl; + } + return ss.str(); + } + + /* ************************************************************************ */ + } // namespace gtsam diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index f39adc9a8..f962b1802 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -18,19 +18,22 @@ #pragma once -#include -#include -#include #include -#include +#include +#include +#include +#include #include + #include +#include +#include +#include namespace gtsam { // Forward declarations class DiscreteFactorGraph; -class DiscreteFactor; class DiscreteConditional; class DiscreteBayesNet; class DiscreteEliminationTree; @@ -62,33 +65,35 @@ template<> struct EliminationTraits * A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e. * Factor == DiscreteFactor */ -class GTSAM_EXPORT DiscreteFactorGraph: public FactorGraph, -public EliminateableFactorGraph { -public: +class GTSAM_EXPORT DiscreteFactorGraph + : public FactorGraph, + public EliminateableFactorGraph { + public: + using This = DiscreteFactorGraph; ///< this class + using Base = FactorGraph; ///< base factor graph type + using BaseEliminateable = + EliminateableFactorGraph; ///< for elimination + using shared_ptr = boost::shared_ptr; ///< shared_ptr to This - typedef DiscreteFactorGraph This; ///< Typedef to this class - typedef FactorGraph Base; ///< Typedef to base factor graph type - typedef EliminateableFactorGraph BaseEliminateable; ///< Typedef to base elimination class - typedef boost::shared_ptr shared_ptr; ///< shared_ptr to this class + using Values = DiscreteValues; ///< backwards compatibility - /** A map from keys to values */ - typedef KeyVector Indices; - typedef Assignment Values; - typedef boost::shared_ptr sharedValues; + using Indices = KeyVector; ///> map from keys to values /** Default constructor */ DiscreteFactorGraph() {} /** Construct from iterator over factors */ - template - DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor) : Base(firstFactor, lastFactor) {} + template + DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor) + : Base(firstFactor, lastFactor) {} /** Construct from container of factors (shared_ptr or plain objects) */ - template + template explicit DiscreteFactorGraph(const CONTAINER& factors) : Base(factors) {} - /** Implicit copy/downcast constructor to override explicit template container constructor */ - template + /** Implicit copy/downcast constructor to override explicit template container + * constructor */ + template DiscreteFactorGraph(const FactorGraph& graph) : Base(graph) {} /// Destructor @@ -101,57 +106,111 @@ public: /// @} - template - void add(const DiscreteKey& j, SOURCE table) { - DiscreteKeys keys; - keys.push_back(j); - push_back(boost::make_shared(keys, table)); - } - - template - void add(const DiscreteKey& j1, const DiscreteKey& j2, SOURCE table) { - DiscreteKeys keys; - keys.push_back(j1); - keys.push_back(j2); - push_back(boost::make_shared(keys, table)); - } - - /** add shared discreteFactor immediately from arguments */ - template - void add(const DiscreteKeys& keys, SOURCE table) { - push_back(boost::make_shared(keys, table)); + /** Add a decision-tree factor */ + template + void add(Args&&... args) { + emplace_shared(std::forward(args)...); } /** Return the set of variables involved in the factors (set union) */ KeySet keys() const; + /// Return the DiscreteKeys in this factor graph. + DiscreteKeys discreteKeys() const; + /** return product of all factors as a single factor */ DecisionTreeFactor product() const; - /** Evaluates the factor graph given values, returns the joint probability of the factor graph given specific instantiation of values*/ - double operator()(const DiscreteFactor::Values & values) const; + /** + * Evaluates the factor graph given values, returns the joint probability of + * the factor graph given specific instantiation of values + */ + double operator()(const DiscreteValues& values) const; /// print void print( const std::string& s = "DiscreteFactorGraph", const KeyFormatter& formatter = DefaultKeyFormatter) const override; - /** Solve the factor graph by performing variable elimination in COLAMD order using - * the dense elimination function specified in \c function, - * followed by back-substitution resulting from elimination. Is equivalent - * to calling graph.eliminateSequential()->optimize(). */ - DiscreteFactor::sharedValues optimize() const; + /** + * @brief Implement the sum-product algorithm + * + * @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM + * @return DiscreteBayesNet encoding posterior P(X|Z) + */ + DiscreteBayesNet sumProduct( + OptionalOrderingType orderingType = boost::none) const; + /** + * @brief Implement the sum-product algorithm + * + * @param ordering + * @return DiscreteBayesNet encoding posterior P(X|Z) + */ + DiscreteBayesNet sumProduct(const Ordering& ordering) const; -// /** Permute the variables in the factors */ -// GTSAM_EXPORT void permuteWithInverse(const Permutation& inversePermutation); -// -// /** Apply a reduction, which is a remapping of variable indices. */ -// GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& inverseReduction); + /** + * @brief Implement the max-product algorithm + * + * @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM + * @return DiscreteLookupDAG DAG with lookup tables + */ + DiscreteLookupDAG maxProduct( + OptionalOrderingType orderingType = boost::none) const; -}; // \ DiscreteFactorGraph + /** + * @brief Implement the max-product algorithm + * + * @param ordering + * @return DiscreteLookupDAG `DAG with lookup tables + */ + DiscreteLookupDAG maxProduct(const Ordering& ordering) const; + + /** + * @brief Find the maximum probable explanation (MPE) by doing max-product. + * + * @param orderingType + * @return DiscreteValues : MPE + */ + DiscreteValues optimize( + OptionalOrderingType orderingType = boost::none) const; + + /** + * @brief Find the maximum probable explanation (MPE) by doing max-product. + * + * @param ordering + * @return DiscreteValues : MPE + */ + DiscreteValues optimize(const Ordering& ordering) const; + + /// @name Wrapper support + /// @{ + + /** + * @brief Render as markdown tables + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, a map from Key to category names. + * @return std::string a (potentially long) markdown string. + */ + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteFactor::Names& names = {}) const; + + /** + * @brief Render as html tables + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, a map from Key to category names. + * @return std::string a (potentially long) html string. + */ + std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteFactor::Names& names = {}) const; + + /// @} +}; // \ DiscreteFactorGraph /// traits -template<> struct traits : public Testable {}; +template <> +struct traits : public Testable {}; -} // \ namespace gtsam +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteKey.cpp b/gtsam/discrete/DiscreteKey.cpp index 5ddad22b0..121d61103 100644 --- a/gtsam/discrete/DiscreteKey.cpp +++ b/gtsam/discrete/DiscreteKey.cpp @@ -33,16 +33,13 @@ namespace gtsam { KeyVector DiscreteKeys::indices() const { KeyVector js; - for(const DiscreteKey& key: *this) - js.push_back(key.first); + for (const DiscreteKey& key : *this) js.push_back(key.first); return js; } - map DiscreteKeys::cardinalities() const { - map cs; - cs.insert(begin(),end()); -// for(const DiscreteKey& key: *this) -// cs.insert(key); + map DiscreteKeys::cardinalities() const { + map cs; + cs.insert(begin(), end()); return cs; } diff --git a/gtsam/discrete/DiscreteKey.h b/gtsam/discrete/DiscreteKey.h index c041c7e8e..ce0c56dbe 100644 --- a/gtsam/discrete/DiscreteKey.h +++ b/gtsam/discrete/DiscreteKey.h @@ -28,21 +28,26 @@ namespace gtsam { /** - * Key type for discrete conditionals - * Includes name and cardinality + * Key type for discrete variables. + * Includes Key and cardinality. */ - typedef std::pair DiscreteKey; + using DiscreteKey = std::pair; /// DiscreteKeys is a set of keys that can be assembled using the & operator - struct DiscreteKeys: public std::vector { + struct GTSAM_EXPORT DiscreteKeys: public std::vector { - /// Default constructor - DiscreteKeys() { - } + // Forward all constructors. + using std::vector::vector; + + /// Constructor for serialization + DiscreteKeys() : std::vector::vector() {} /// Construct from a key - DiscreteKeys(const DiscreteKey& key) { - push_back(key); + explicit DiscreteKeys(const DiscreteKey& key) { push_back(key); } + + /// Construct from cardinalities. + explicit DiscreteKeys(std::map cardinalities) { + for (auto&& kv : cardinalities) emplace_back(kv); } /// Construct from a vector of keys @@ -51,13 +56,13 @@ namespace gtsam { } /// Construct from cardinalities with default names - GTSAM_EXPORT DiscreteKeys(const std::vector& cs); + DiscreteKeys(const std::vector& cs); /// Return a vector of indices - GTSAM_EXPORT KeyVector indices() const; + KeyVector indices() const; /// Return a map from index to cardinality - GTSAM_EXPORT std::map cardinalities() const; + std::map cardinalities() const; /// Add a key (non-const!) DiscreteKeys& operator&(const DiscreteKey& key) { @@ -67,5 +72,5 @@ namespace gtsam { }; // DiscreteKeys /// Create a list from two keys - GTSAM_EXPORT DiscreteKeys operator&(const DiscreteKey& key1, const DiscreteKey& key2); + DiscreteKeys operator&(const DiscreteKey& key1, const DiscreteKey& key2); } diff --git a/gtsam/discrete/DiscreteLookupDAG.cpp b/gtsam/discrete/DiscreteLookupDAG.cpp new file mode 100644 index 000000000..d96b38b0e --- /dev/null +++ b/gtsam/discrete/DiscreteLookupDAG.cpp @@ -0,0 +1,127 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscreteLookupDAG.cpp + * @date Feb 14, 2011 + * @author Duy-Nguyen Ta + * @author Frank Dellaert + */ + +#include +#include +#include + +#include +#include + +using std::pair; +using std::vector; + +namespace gtsam { + +/* ************************************************************************** */ +// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-( +void DiscreteLookupTable::print(const std::string& s, + const KeyFormatter& formatter) const { + using std::cout; + using std::endl; + + cout << s << " g( "; + for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) { + cout << formatter(*it) << " "; + } + if (nrParents()) { + cout << "; "; + for (const_iterator it = beginParents(); it != endParents(); ++it) { + cout << formatter(*it) << " "; + } + } + cout << "):\n"; + ADT::print("", formatter); + cout << endl; +} + +/* ************************************************************************** */ +void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const { + ADT pFS = choose(*values, true); // P(F|S=parentsValues) + + // Initialize + DiscreteValues mpe; + double maxP = 0; + + // Get all Possible Configurations + const auto allPosbValues = frontalAssignments(); + + // Find the maximum + for (const auto& frontalVals : allPosbValues) { + double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues) + // Update maximum solution if better + if (pValueS > maxP) { + maxP = pValueS; + mpe = frontalVals; + } + } + + // set values (inPlace) to maximum + for (Key j : frontals()) { + (*values)[j] = mpe[j]; + } +} + +/* ************************************************************************** */ +size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const { + ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues) + + // Then, find the max over all remaining + // TODO(Duy): only works for one key now, seems horribly slow this way + size_t mpe = 0; + double maxP = 0; + DiscreteValues frontals; + assert(nrFrontals() == 1); + Key j = (firstFrontalKey()); + for (size_t value = 0; value < cardinality(j); value++) { + frontals[j] = value; + double pValueS = pFS(frontals); // P(F=value|S=parentsValues) + // Update MPE solution if better + if (pValueS > maxP) { + maxP = pValueS; + mpe = value; + } + } + return mpe; +} + +/* ************************************************************************** */ +DiscreteLookupDAG DiscreteLookupDAG::FromBayesNet( + const DiscreteBayesNet& bayesNet) { + DiscreteLookupDAG dag; + for (auto&& conditional : bayesNet) { + if (auto lookupTable = + boost::dynamic_pointer_cast(conditional)) { + dag.push_back(lookupTable); + } else { + throw std::runtime_error( + "DiscreteFactorGraph::maxProduct: Expected look up table."); + } + } + return dag; +} + +DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const { + // Argmax each node in turn in topological sort order (parents first). + for (auto lookupTable : boost::adaptors::reverse(*this)) + lookupTable->argmaxInPlace(&result); + return result; +} +/* ************************************************************************** */ + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteLookupDAG.h b/gtsam/discrete/DiscreteLookupDAG.h new file mode 100644 index 000000000..8cb651f28 --- /dev/null +++ b/gtsam/discrete/DiscreteLookupDAG.h @@ -0,0 +1,140 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscreteLookupDAG.h + * @date January, 2022 + * @author Frank dellaert + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include + +namespace gtsam { + +class DiscreteBayesNet; + +/** + * @brief DiscreteLookupTable table for max-product + * + * Inherits from discrete conditional for convenience, but is not normalized. + * Is used in the max-product algorithm. + */ +class DiscreteLookupTable : public DiscreteConditional { + public: + using This = DiscreteLookupTable; + using shared_ptr = boost::shared_ptr; + using BaseConditional = Conditional; + + /** + * @brief Construct a new Discrete Lookup Table object + * + * @param nFrontals number of frontal variables + * @param keys a orted list of gtsam::Keys + * @param potentials the algebraic decision tree with lookup values + */ + DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys, + const ADT& potentials) + : DiscreteConditional(nFrontals, keys, potentials) {} + + /// GTSAM-style print + void print( + const std::string& s = "Discrete Lookup Table: ", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; + + /** + * @brief return assignment for single frontal variable that maximizes value. + * @param parentsValues Known assignments for the parents. + * @return maximizing assignment for the frontal variable. + */ + size_t argmax(const DiscreteValues& parentsValues) const; + + /** + * @brief Calculate assignment for frontal variables that maximizes value. + * @param (in/out) parentsValues Known assignments for the parents. + */ + void argmaxInPlace(DiscreteValues* parentsValues) const; +}; + +/** A DAG made from lookup tables, as defined above. */ +class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet { + public: + using Base = BayesNet; + using This = DiscreteLookupDAG; + using shared_ptr = boost::shared_ptr; + + /// @name Standard Constructors + /// @{ + + /// Construct empty DAG. + DiscreteLookupDAG() {} + + /// Create from BayesNet with LookupTables + static DiscreteLookupDAG FromBayesNet(const DiscreteBayesNet& bayesNet); + + /// Destructor + virtual ~DiscreteLookupDAG() {} + + /// @} + + /// @name Testable + /// @{ + + /** Check equality */ + bool equals(const This& bn, double tol = 1e-9) const; + + /// @} + + /// @name Standard Interface + /// @{ + + /** Add a DiscreteLookupTable */ + template + void add(Args&&... args) { + emplace_shared(std::forward(args)...); + } + + /** + * @brief argmax by back-substitution, optionally given certain variables. + * + * Assumes the DAG is reverse topologically sorted, i.e. last + * conditional will be optimized first *and* that the + * DAG does not contain any conditionals for the given variables. If the DAG + * resulted from eliminating a factor graph, this is true for the elimination + * ordering. + * + * @return given assignment extended w. optimal assignment for all variables. + */ + DiscreteValues argmax(DiscreteValues given = DiscreteValues()) const; + /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + } +}; + +// traits +template <> +struct traits : public Testable {}; + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteMarginals.h b/gtsam/discrete/DiscreteMarginals.h index c2a188e08..a2207a10b 100644 --- a/gtsam/discrete/DiscreteMarginals.h +++ b/gtsam/discrete/DiscreteMarginals.h @@ -29,7 +29,7 @@ namespace gtsam { /** * A class for computing marginals of variables in a DiscreteFactorGraph */ - class DiscreteMarginals { +class GTSAM_EXPORT DiscreteMarginals { protected: @@ -37,6 +37,8 @@ namespace gtsam { public: + DiscreteMarginals() {} + /** Construct a marginals class. * @param graph The factor graph defining the full joint density on all variables. */ @@ -64,7 +66,7 @@ namespace gtsam { //Create result Vector vResult(key.second); for (size_t state = 0; state < key.second ; ++ state) { - DiscreteFactor::Values values; + DiscreteValues values; values[key.first] = state; vResult(state) = (*marginalFactor)(values); } diff --git a/gtsam/discrete/DiscreteValues.cpp b/gtsam/discrete/DiscreteValues.cpp new file mode 100644 index 000000000..5d0c8dd3d --- /dev/null +++ b/gtsam/discrete/DiscreteValues.cpp @@ -0,0 +1,97 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscreteValues.cpp + * @date January, 2022 + * @author Frank Dellaert + */ + +#include + +#include + +using std::cout; +using std::endl; +using std::string; +using std::stringstream; + +namespace gtsam { + +void DiscreteValues::print(const string& s, + const KeyFormatter& keyFormatter) const { + cout << s << ": "; + for (auto&& kv : *this) + cout << "(" << keyFormatter(kv.first) << ", " << kv.second << ")"; + cout << endl; +} + +string DiscreteValues::Translate(const Names& names, Key key, size_t index) { + if (names.empty()) { + stringstream ss; + ss << index; + return ss.str(); + } else { + return names.at(key)[index]; + } +} + +string DiscreteValues::markdown(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; + + // Print out header and separator with alignment hints. + ss << "|Variable|value|\n|:-:|:-:|\n"; + + // Print out all rows. + for (const auto& kv : *this) { + ss << "|" << keyFormatter(kv.first) << "|" + << Translate(names, kv.first, kv.second) << "|\n"; + } + + return ss.str(); +} + +string DiscreteValues::html(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; + + // Print out preamble. + ss << "
\n\n \n"; + + // Print out header row. + ss << " \n"; + + // Finish header and start body. + ss << " \n \n"; + + // Print out all rows. + for (const auto& kv : *this) { + ss << " "; + ss << ""; + ss << "\n"; + } + ss << " \n
Variablevalue
" << keyFormatter(kv.first) << "" + << Translate(names, kv.first, kv.second) << "
\n
"; + return ss.str(); +} + +string markdown(const DiscreteValues& values, const KeyFormatter& keyFormatter, + const DiscreteValues::Names& names) { + return values.markdown(keyFormatter, names); +} + +string html(const DiscreteValues& values, const KeyFormatter& keyFormatter, + const DiscreteValues::Names& names) { + return values.html(keyFormatter, names); +} + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteValues.h b/gtsam/discrete/DiscreteValues.h new file mode 100644 index 000000000..81997a783 --- /dev/null +++ b/gtsam/discrete/DiscreteValues.h @@ -0,0 +1,106 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscreteValues.h + * @date Dec 13, 2021 + * @author Frank Dellaert + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include + +namespace gtsam { + +/** A map from keys to values + * TODO(dellaert): Do we need this? Should we just use gtsam::DiscreteValues? + * We just need another special DiscreteValue to represent labels, + * However, all other Lie's operators are undefined in this class. + * The good thing is we can have a Hybrid graph of discrete/continuous variables + * together.. + * Another good thing is we don't need to have the special DiscreteKey which + * stores cardinality of a Discrete variable. It should be handled naturally in + * the new class DiscreteValue, as the variable's type (domain) + */ +class DiscreteValues : public Assignment { + public: + using Base = Assignment; // base class + + using Assignment::Assignment; // all constructors + + // Define the implicit default constructor. + DiscreteValues() = default; + + // Construct from assignment. + explicit DiscreteValues(const Base& a) : Base(a) {} + + void print(const std::string& s = "", + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + + static std::vector CartesianProduct( + const DiscreteKeys& keys) { + return Base::CartesianProduct(keys); + } + + /// @name Wrapper support + /// @{ + + /// Translation table from values to strings. + using Names = std::map>; + + /// Translate an integer index value for given key to a string. + static std::string Translate(const Names& names, Key key, size_t index); + + /** + * @brief Output as a markdown table. + * + * @param keyFormatter function that formats keys. + * @param names translation table for values. + * @return string markdown output. + */ + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const; + + /** + * @brief Output as a html table. + * + * @param keyFormatter function that formats keys. + * @param names translation table for values. + * @return string html output. + */ + std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const; + + /// @} +}; + +/// Free version of markdown. +std::string markdown(const DiscreteValues& values, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteValues::Names& names = {}); + +/// Free version of html. +std::string html(const DiscreteValues& values, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteValues::Names& names = {}); + +// traits +template <> +struct traits : public Testable {}; + +} // namespace gtsam diff --git a/gtsam/discrete/Potentials.cpp b/gtsam/discrete/Potentials.cpp deleted file mode 100644 index 331a76c13..000000000 --- a/gtsam/discrete/Potentials.cpp +++ /dev/null @@ -1,100 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file Potentials.cpp - * @date March 24, 2011 - * @author Frank Dellaert - */ - -#include -#include - -#include - -#include - -using namespace std; - -namespace gtsam { - -// explicit instantiation -template class DecisionTree; -template class AlgebraicDecisionTree; - -/* ************************************************************************* */ -double Potentials::safe_div(const double& a, const double& b) { - // cout << boost::format("%g / %g = %g\n") % a % b % ((a == 0) ? 0 : (a / b)); - // The use for safe_div is when we divide the product factor by the sum - // factor. If the product or sum is zero, we accord zero probability to the - // event. - return (a == 0 || b == 0) ? 0 : (a / b); -} - -/* ******************************************************************************** - */ -Potentials::Potentials() : ADT(1.0) {} - -/* ******************************************************************************** - */ -Potentials::Potentials(const DiscreteKeys& keys, const ADT& decisionTree) - : ADT(decisionTree), cardinalities_(keys.cardinalities()) {} - -/* ************************************************************************* */ -bool Potentials::equals(const Potentials& other, double tol) const { - return ADT::equals(other, tol); -} - -/* ************************************************************************* */ -void Potentials::print(const string& s, const KeyFormatter& formatter) const { - cout << s << "\n Cardinalities: {"; - for (const std::pair& key : cardinalities_) - cout << formatter(key.first) << ":" << key.second << ", "; - cout << "}" << endl; - ADT::print(" "); -} -// -// /* ************************************************************************* */ -// template -// void Potentials::remapIndices(const P& remapping) { -// // Permute the _cardinalities (TODO: Inefficient Consider Improving) -// DiscreteKeys keys; -// map ordering; -// -// // Get the original keys from cardinalities_ -// for(const DiscreteKey& key: cardinalities_) -// keys & key; -// -// // Perform Permutation -// for(DiscreteKey& key: keys) { -// ordering[key.first] = remapping[key.first]; -// key.first = ordering[key.first]; -// } -// -// // Change *this -// AlgebraicDecisionTree permuted((*this), ordering); -// *this = permuted; -// cardinalities_ = keys.cardinalities(); -// } -// -// /* ************************************************************************* */ -// void Potentials::permuteWithInverse(const Permutation& inversePermutation) { -// remapIndices(inversePermutation); -// } -// -// /* ************************************************************************* */ -// void Potentials::reduceWithInverse(const internal::Reduction& inverseReduction) { -// remapIndices(inverseReduction); -// } - - /* ************************************************************************* */ - -} // namespace gtsam diff --git a/gtsam/discrete/Potentials.h b/gtsam/discrete/Potentials.h deleted file mode 100644 index 1078b4c61..000000000 --- a/gtsam/discrete/Potentials.h +++ /dev/null @@ -1,97 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file Potentials.h - * @date March 24, 2011 - * @author Frank Dellaert - */ - -#pragma once - -#include -#include -#include - -#include -#include - -namespace gtsam { - - /** - * A base class for both DiscreteFactor and DiscreteConditional - */ - class Potentials: public AlgebraicDecisionTree { - - public: - - typedef AlgebraicDecisionTree ADT; - - protected: - - /// Cardinality for each key, used in combine - std::map cardinalities_; - - /** Constructor from ColumnIndex, and ADT */ - Potentials(const ADT& potentials) : - ADT(potentials) { - } - - // Safe division for probabilities - GTSAM_EXPORT static double safe_div(const double& a, const double& b); - -// // Apply either a permutation or a reduction -// template -// void remapIndices(const P& remapping); - - public: - - /** Default constructor for I/O */ - GTSAM_EXPORT Potentials(); - - /** Constructor from Indices and ADT */ - GTSAM_EXPORT Potentials(const DiscreteKeys& keys, const ADT& decisionTree); - - /** Constructor from Indices and (string or doubles) */ - template - Potentials(const DiscreteKeys& keys, SOURCE table) : - ADT(keys, table), cardinalities_(keys.cardinalities()) { - } - - // Testable - GTSAM_EXPORT bool equals(const Potentials& other, double tol = 1e-9) const; - GTSAM_EXPORT void print(const std::string& s = "Potentials: ", - const KeyFormatter& formatter = DefaultKeyFormatter) const; - - size_t cardinality(Key j) const { return cardinalities_.at(j);} - -// /** -// * @brief Permutes the keys in Potentials -// * -// * This permutes the Indices and performs necessary re-ordering of ADD. -// * This is virtual so that derived types e.g. DecisionTreeFactor can -// * re-implement it. -// */ -// GTSAM_EXPORT virtual void permuteWithInverse(const Permutation& inversePermutation); -// -// /** -// * Apply a reduction, which is a remapping of variable indices. -// */ -// GTSAM_EXPORT virtual void reduceWithInverse(const internal::Reduction& inverseReduction); - - }; // Potentials - -// traits -template<> struct traits : public Testable {}; -template<> struct traits : public Testable {}; - - -} // namespace gtsam diff --git a/gtsam/discrete/Signature.cpp b/gtsam/discrete/Signature.cpp index 94b160a29..146555898 100644 --- a/gtsam/discrete/Signature.cpp +++ b/gtsam/discrete/Signature.cpp @@ -38,19 +38,7 @@ namespace gtsam { using boost::phoenix::push_back; // Special rows, true and false - Signature::Row createF() { - Signature::Row r(2); - r[0] = 1; - r[1] = 0; - return r; - } - Signature::Row createT() { - Signature::Row r(2); - r[0] = 0; - r[1] = 1; - return r; - } - Signature::Row T = createT(), F = createF(); + Signature::Row F{1, 0}, T{0, 1}; // Special tables (inefficient, but do we care for user input?) Signature::Table logic(bool ff, bool ft, bool tf, bool tt) { @@ -69,40 +57,13 @@ namespace gtsam { table = or_ | and_ | rows; or_ = qi::lit("OR")[qi::_val = logic(false, true, true, true)]; and_ = qi::lit("AND")[qi::_val = logic(false, false, false, true)]; - rows = +(row | true_ | false_); // only loads first of the rows under boost 1.42 + rows = +(row | true_ | false_); row = qi::double_ >> +("/" >> qi::double_); true_ = qi::lit("T")[qi::_val = T]; false_ = qi::lit("F")[qi::_val = F]; } } grammar; - // Create simpler parsing function to avoid the issue of only parsing a single row - bool parse_table(const string& spec, Signature::Table& table) { - // check for OR, AND on whole phrase - It f = spec.begin(), l = spec.end(); - if (qi::parse(f, l, - qi::lit("OR")[ph::ref(table) = logic(false, true, true, true)]) || - qi::parse(f, l, - qi::lit("AND")[ph::ref(table) = logic(false, false, false, true)])) - return true; - - // tokenize into separate rows - istringstream iss(spec); - string token; - while (iss >> token) { - Signature::Row values; - It tf = token.begin(), tl = token.end(); - bool r = qi::parse(tf, tl, - qi::double_[push_back(ph::ref(values), qi::_1)] >> +("/" >> qi::double_[push_back(ph::ref(values), qi::_1)]) | - qi::lit("T")[ph::ref(values) = T] | - qi::lit("F")[ph::ref(values) = F] ); - if (!r) - return false; - table.push_back(values); - } - - return true; - } } // \namespace parser ostream& operator <<(ostream &os, const Signature::Row &row) { @@ -118,6 +79,18 @@ namespace gtsam { return os; } + Signature::Signature(const DiscreteKey& key, const DiscreteKeys& parents, + const Table& table) + : key_(key), parents_(parents) { + operator=(table); + } + + Signature::Signature(const DiscreteKey& key, const DiscreteKeys& parents, + const std::string& spec) + : key_(key), parents_(parents) { + operator=(spec); + } + Signature::Signature(const DiscreteKey& key) : key_(key) { } @@ -166,14 +139,11 @@ namespace gtsam { Signature& Signature::operator=(const string& spec) { spec_.reset(spec); Table table; - // NOTE: using simpler parse function to ensure boost back compatibility -// parser::It f = spec.begin(), l = spec.end(); - bool success = // -// qi::phrase_parse(f, l, parser::grammar.table, qi::space, table); // using full grammar - parser::parse_table(spec, table); + parser::It f = spec.begin(), l = spec.end(); + bool success = + qi::phrase_parse(f, l, parser::grammar.table, qi::space, table); if (success) { - for(Row& row: table) - normalize(row); + for (Row& row : table) normalize(row); table_.reset(table); } return *this; diff --git a/gtsam/discrete/Signature.h b/gtsam/discrete/Signature.h index 6c59b5bff..ff83caa53 100644 --- a/gtsam/discrete/Signature.h +++ b/gtsam/discrete/Signature.h @@ -30,7 +30,7 @@ namespace gtsam { * The format is (Key % string) for nodes with no parents, * and (Key | Key, Key = string) for nodes with parents. * - * The string specifies a conditional probability spec in the 00 01 10 11 order. + * The string specifies a conditional probability table in 00 01 10 11 order. * For three-valued, it would be 00 01 02 10 11 12 20 21 22, etc... * * For example, given the following keys @@ -45,9 +45,9 @@ namespace gtsam { * T|A = "99/1 95/5" * L|S = "99/1 90/10" * B|S = "70/30 40/60" - * E|T,L = "F F F 1" + * (E|T,L) = "F F F 1" * X|E = "95/5 2/98" - * D|E,B = "9/1 2/8 3/7 1/9" + * (D|E,B) = "9/1 2/8 3/7 1/9" */ class GTSAM_EXPORT Signature { @@ -72,45 +72,73 @@ namespace gtsam { boost::optional table_; public: + /** + * Construct from key, parents, and a Signature::Table specifying the + * conditional probability table (CPT) in 00 01 10 11 order. For + * three-valued, it would be 00 01 02 10 11 12 20 21 22, etc.... + * + * The first string is parsed to add a key and parents. + * + * Example: + * Signature::Table table{{0.9, 0.1}, {0.2, 0.8}, {0.3, 0.7}, {0.1, 0.9}}; + * Signature sig(D, {E, B}, table); + */ + Signature(const DiscreteKey& key, const DiscreteKeys& parents, + const Table& table); - /** Constructor from DiscreteKey */ - Signature(const DiscreteKey& key); + /** + * Construct from key, parents, and a string specifying the conditional + * probability table (CPT) in 00 01 10 11 order. For three-valued, it would + * be 00 01 02 10 11 12 20 21 22, etc.... + * + * The first string is parsed to add a key and parents. The second string + * parses into a table. + * + * Example (same CPT as above): + * Signature sig(D, {B,E}, "9/1 2/8 3/7 1/9"); + */ + Signature(const DiscreteKey& key, const DiscreteKeys& parents, + const std::string& spec); - /** the variable key */ - const DiscreteKey& key() const { - return key_; - } + /** + * Construct from a single DiscreteKey. + * + * The resulting signature has no parents or CPT table. Typical use then + * either adds parents with | and , operators below, or assigns a table with + * operator=(). + */ + Signature(const DiscreteKey& key); - /** the parent keys */ - const DiscreteKeys& parents() const { - return parents_; - } + /** the variable key */ + const DiscreteKey& key() const { return key_; } - /** All keys, with variable key first */ - DiscreteKeys discreteKeys() const; + /** the parent keys */ + const DiscreteKeys& parents() const { return parents_; } - /** All key indices, with variable key first */ - KeyVector indices() const; + /** All keys, with variable key first */ + DiscreteKeys discreteKeys() const; - // the CPT as parsed, if successful - const boost::optional
& table() const { - return table_; - } + /** All key indices, with variable key first */ + KeyVector indices() const; - // the CPT as a vector of doubles, with key's values most rapidly changing - std::vector cpt() const; + // the CPT as parsed, if successful + const boost::optional
& table() const { return table_; } - /** Add a parent */ - Signature& operator,(const DiscreteKey& parent); + // the CPT as a vector of doubles, with key's values most rapidly changing + std::vector cpt() const; - /** Add the CPT spec - Fails in boost 1.40 */ - Signature& operator=(const std::string& spec); + /** Add a parent */ + Signature& operator,(const DiscreteKey& parent); - /** Add the CPT spec directly as a table */ - Signature& operator=(const Table& table); + /** Add the CPT spec */ + Signature& operator=(const std::string& spec); - /** provide streaming */ - GTSAM_EXPORT friend std::ostream& operator <<(std::ostream &os, const Signature &s); + /** Add the CPT spec directly as a table */ + Signature& operator=(const Table& table); + + /** provide streaming */ + GTSAM_EXPORT friend std::ostream& operator<<(std::ostream& os, + const Signature& s); }; /** @@ -122,7 +150,6 @@ namespace gtsam { /** * Helper function to create Signature objects * example: Signature s(D % "99/1"); - * Uses string parser, which requires BOOST 1.42 or higher */ GTSAM_EXPORT Signature operator%(const DiscreteKey& key, const std::string& parent); diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i new file mode 100644 index 000000000..56e7248a3 --- /dev/null +++ b/gtsam/discrete/discrete.i @@ -0,0 +1,299 @@ +//************************************************************************* +// discrete +//************************************************************************* + +namespace gtsam { + + +#include +class DiscreteKey {}; + +class DiscreteKeys { + DiscreteKeys(); + size_t size() const; + bool empty() const; + gtsam::DiscreteKey at(size_t n) const; + void push_back(const gtsam::DiscreteKey& point_pair); +}; + +// DiscreteValues is added in specializations/discrete.h as a std::map +string markdown( + const gtsam::DiscreteValues& values, + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter); +string markdown(const gtsam::DiscreteValues& values, + const gtsam::KeyFormatter& keyFormatter, + std::map> names); +string html( + const gtsam::DiscreteValues& values, + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter); +string html(const gtsam::DiscreteValues& values, + const gtsam::KeyFormatter& keyFormatter, + std::map> names); + +#include +class DiscreteFactor { + void print(string s = "DiscreteFactor\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const; + bool empty() const; + size_t size() const; + double operator()(const gtsam::DiscreteValues& values) const; +}; + +#include +virtual class DecisionTreeFactor : gtsam::DiscreteFactor { + DecisionTreeFactor(); + + DecisionTreeFactor(const gtsam::DiscreteKey& key, + const std::vector& spec); + DecisionTreeFactor(const gtsam::DiscreteKey& key, string table); + + DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table); + DecisionTreeFactor(const std::vector& keys, string table); + + DecisionTreeFactor(const gtsam::DiscreteConditional& c); + + void print(string s = "DecisionTreeFactor\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; + + double operator()(const gtsam::DiscreteValues& values) const; + gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const; + size_t cardinality(gtsam::Key j) const; + gtsam::DecisionTreeFactor operator/(const gtsam::DecisionTreeFactor& f) const; + gtsam::DecisionTreeFactor* sum(size_t nrFrontals) const; + gtsam::DecisionTreeFactor* sum(const gtsam::Ordering& keys) const; + gtsam::DecisionTreeFactor* max(size_t nrFrontals) const; + + string dot( + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + bool showZero = true) const; + std::vector> enumerate() const; + string markdown(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + string markdown(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; + string html(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + string html(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; +}; + +#include +virtual class DiscreteConditional : gtsam::DecisionTreeFactor { + DiscreteConditional(); + DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f); + DiscreteConditional(const gtsam::DiscreteKey& key, string spec); + DiscreteConditional(const gtsam::DiscreteKey& key, + const gtsam::DiscreteKeys& parents, string spec); + DiscreteConditional(const gtsam::DiscreteKey& key, + const std::vector& parents, string spec); + DiscreteConditional(const gtsam::DecisionTreeFactor& joint, + const gtsam::DecisionTreeFactor& marginal); + DiscreteConditional(const gtsam::DecisionTreeFactor& joint, + const gtsam::DecisionTreeFactor& marginal, + const gtsam::Ordering& orderedKeys); + gtsam::DiscreteConditional operator*( + const gtsam::DiscreteConditional& other) const; + DiscreteConditional marginal(gtsam::Key key) const; + void print(string s = "Discrete Conditional\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const; + gtsam::Key firstFrontalKey() const; + size_t nrFrontals() const; + size_t nrParents() const; + void printSignature( + string s = "Discrete Conditional: ", + const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; + gtsam::DecisionTreeFactor* choose(const gtsam::DiscreteValues& given) const; + gtsam::DecisionTreeFactor* likelihood( + const gtsam::DiscreteValues& frontalValues) const; + gtsam::DecisionTreeFactor* likelihood(size_t value) const; + size_t sample(const gtsam::DiscreteValues& parentsValues) const; + size_t sample(size_t value) const; + size_t sample() const; + void sampleInPlace(gtsam::DiscreteValues @parentsValues) const; + string markdown(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + string markdown(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; + string html(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + string html(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; +}; + +#include +virtual class DiscreteDistribution : gtsam::DiscreteConditional { + DiscreteDistribution(); + DiscreteDistribution(const gtsam::DecisionTreeFactor& f); + DiscreteDistribution(const gtsam::DiscreteKey& key, string spec); + DiscreteDistribution(const gtsam::DiscreteKey& key, std::vector spec); + void print(string s = "Discrete Prior\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + double operator()(size_t value) const; + std::vector pmf() const; + size_t argmax() const; +}; + +#include +class DiscreteBayesNet { + DiscreteBayesNet(); + void add(const gtsam::DiscreteConditional& s); + void add(const gtsam::DiscreteKey& key, string spec); + void add(const gtsam::DiscreteKey& key, const gtsam::DiscreteKeys& parents, + string spec); + void add(const gtsam::DiscreteKey& key, + const std::vector& parents, string spec); + bool empty() const; + size_t size() const; + gtsam::KeySet keys() const; + const gtsam::DiscreteConditional* at(size_t i) const; + void print(string s = "DiscreteBayesNet\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const; + double operator()(const gtsam::DiscreteValues& values) const; + gtsam::DiscreteValues sample() const; + gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const; + + string dot( + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; + void saveGraph( + string s, + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; + string markdown(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + string markdown(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; + string html(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + string html(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; +}; + +#include +class DiscreteBayesTreeClique { + DiscreteBayesTreeClique(); + DiscreteBayesTreeClique(const gtsam::DiscreteConditional* conditional); + const gtsam::DiscreteConditional* conditional() const; + bool isRoot() const; + void printSignature( + const string& s = "Clique: ", + const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; + double evaluate(const gtsam::DiscreteValues& values) const; +}; + +class DiscreteBayesTree { + DiscreteBayesTree(); + void print(string s = "DiscreteBayesTree\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::DiscreteBayesTree& other, double tol = 1e-9) const; + + size_t size() const; + bool empty() const; + const DiscreteBayesTreeClique* operator[](size_t j) const; + + string dot(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + void saveGraph(string s, + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + double operator()(const gtsam::DiscreteValues& values) const; + + string markdown(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + string markdown(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; + string html(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + string html(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; +}; + +#include +class DiscreteLookupDAG { + DiscreteLookupDAG(); + void push_back(const gtsam::DiscreteLookupTable* table); + bool empty() const; + size_t size() const; + gtsam::KeySet keys() const; + const gtsam::DiscreteLookupTable* at(size_t i) const; + void print(string s = "DiscreteLookupDAG\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + gtsam::DiscreteValues argmax() const; + gtsam::DiscreteValues argmax(gtsam::DiscreteValues given) const; +}; + +#include +class DiscreteFactorGraph { + DiscreteFactorGraph(); + DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet); + + // Building the graph + void push_back(const gtsam::DiscreteFactor* factor); + void push_back(const gtsam::DiscreteConditional* conditional); + void push_back(const gtsam::DiscreteFactorGraph& graph); + void push_back(const gtsam::DiscreteBayesNet& bayesNet); + void push_back(const gtsam::DiscreteBayesTree& bayesTree); + void add(const gtsam::DiscreteKey& j, string spec); + void add(const gtsam::DiscreteKey& j, const std::vector& spec); + void add(const gtsam::DiscreteKeys& keys, string spec); + void add(const std::vector& keys, string spec); + + bool empty() const; + size_t size() const; + gtsam::KeySet keys() const; + const gtsam::DiscreteFactor* at(size_t i) const; + + void print(string s = "") const; + bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const; + + gtsam::DecisionTreeFactor product() const; + double operator()(const gtsam::DiscreteValues& values) const; + gtsam::DiscreteValues optimize() const; + + gtsam::DiscreteBayesNet sumProduct(); + gtsam::DiscreteBayesNet sumProduct(gtsam::Ordering::OrderingType type); + gtsam::DiscreteBayesNet sumProduct(const gtsam::Ordering& ordering); + + gtsam::DiscreteLookupDAG maxProduct(); + gtsam::DiscreteLookupDAG maxProduct(gtsam::Ordering::OrderingType type); + gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering); + + gtsam::DiscreteBayesNet eliminateSequential(); + gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering); + std::pair + eliminatePartialSequential(const gtsam::Ordering& ordering); + gtsam::DiscreteBayesTree eliminateMultifrontal(); + gtsam::DiscreteBayesTree eliminateMultifrontal(const gtsam::Ordering& ordering); + std::pair + eliminatePartialMultifrontal(const gtsam::Ordering& ordering); + + string dot( + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; + void saveGraph( + string s, + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; + + string markdown(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + string markdown(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; + string html(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + string html(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; +}; + +} // namespace gtsam diff --git a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp index be720dbca..9d130a1f6 100644 --- a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp +++ b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp @@ -17,37 +17,39 @@ */ #include -#include // make sure we have traits +#include // make sure we have traits +#include // headers first to make sure no missing headers //#define DT_NO_PRUNING #include -#include // for convert only +#include // for convert only #define DISABLE_TIMING -#include #include #include +#include using namespace boost::assign; #include -#include #include +#include using namespace std; using namespace gtsam; -/* ******************************************************************************** */ +/* ************************************************************************** */ typedef AlgebraicDecisionTree ADT; // traits namespace gtsam { -template<> struct traits : public Testable {}; -} +template <> +struct traits : public Testable {}; +} // namespace gtsam #define DISABLE_DOT -template -void dot(const T&f, const string& filename) { +template +void dot(const T& f, const string& filename) { #ifndef DISABLE_DOT f.dot(filename); #endif @@ -62,8 +64,8 @@ void dot(const T&f, const string& filename) { // If second argument of binary op is Leaf template - typename DecisionTree::Node::Ptr DecisionTree::Choice::apply_fC_op_gL( - Cache& cache, const Leaf& gL, Mul op) const { + typename DecisionTree::Node::Ptr DecisionTree::Choice::apply_fC_op_gL( Cache& cache, const Leaf& gL, Mul op) const { Ptr h(new Choice(label(), cardinality())); for(const NodePtr& branch: branches_) h->push_back(branch->apply_f_op_g(cache, gL, op)); @@ -71,9 +73,9 @@ void dot(const T&f, const string& filename) { } */ -/* ******************************************************************************** */ +/* ************************************************************************** */ // instrumented operators -/* ******************************************************************************** */ +/* ************************************************************************** */ size_t muls = 0, adds = 0; double elapsed; void resetCounts() { @@ -82,8 +84,9 @@ void resetCounts() { } void printCounts(const string& s) { #ifndef DISABLE_TIMING - cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds - % (1000 * elapsed) << endl; + cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds % + (1000 * elapsed) + << endl; #endif resetCounts(); } @@ -96,12 +99,11 @@ double add_(const double& a, const double& b) { return a + b; } -/* ******************************************************************************** */ +/* ************************************************************************** */ // test ADT -TEST(ADT, example3) -{ +TEST(ADT, example3) { // Create labels - DiscreteKey A(0,2), B(1,2), C(2,2), D(3,2), E(4,2); + DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(3, 2), E(4, 2); // Literals ADT a(A, 0.5, 0.5); @@ -113,38 +115,37 @@ TEST(ADT, example3) ADT cnotb = c * notb; dot(cnotb, "ADT-cnotb"); -// a.print("a: "); -// cnotb.print("cnotb: "); + // a.print("a: "); + // cnotb.print("cnotb: "); ADT acnotb = a * cnotb; -// acnotb.print("acnotb: "); -// acnotb.printCache("acnotb Cache:"); + // acnotb.print("acnotb: "); + // acnotb.printCache("acnotb Cache:"); dot(acnotb, "ADT-acnotb"); - ADT big = apply(apply(d, note, &mul), acnotb, &add_); dot(big, "ADT-big"); } -/* ******************************************************************************** */ +/* ************************************************************************** */ // Asia Bayes Network -/* ******************************************************************************** */ +/* ************************************************************************** */ /** Convert Signature into CPT */ ADT create(const Signature& signature) { ADT p(signature.discreteKeys(), signature.cpt()); static size_t count = 0; const DiscreteKey& key = signature.key(); - string dotfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str(); - dot(p, dotfile); + string DOTfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str(); + dot(p, DOTfile); return p; } /* ************************************************************************* */ // test Asia Joint -TEST(ADT, joint) -{ - DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), D(7, 2); +TEST(ADT, joint) { + DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), + D(7, 2); resetCounts(); gttic_(asiaCPTs); @@ -203,10 +204,9 @@ TEST(ADT, joint) /* ************************************************************************* */ // test Inference with joint -TEST(ADT, inference) -{ - DiscreteKey A(0,2), D(1,2),// - B(2,2), L(3,2), E(4,2), S(5,2), T(6,2), X(7,2); +TEST(ADT, inference) { + DiscreteKey A(0, 2), D(1, 2), // + B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2); resetCounts(); gttic_(infCPTs); @@ -243,7 +243,7 @@ TEST(ADT, inference) dot(joint, "Joint-Product-ASTLBEX"); joint = apply(joint, pD, &mul); dot(joint, "Joint-Product-ASTLBEXD"); - EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering + EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering gttoc_(asiaProd); tictoc_getNode(asiaProdNode, asiaProd); elapsed = asiaProdNode->secs() + asiaProdNode->wall(); @@ -270,9 +270,8 @@ TEST(ADT, inference) } /* ************************************************************************* */ -TEST(ADT, factor_graph) -{ - DiscreteKey B(0,2), L(1,2), E(2,2), S(3,2), T(4,2), X(5,2); +TEST(ADT, factor_graph) { + DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2); resetCounts(); gttic_(createCPTs); @@ -402,50 +401,49 @@ TEST(ADT, factor_graph) /* ************************************************************************* */ // test equality -TEST(ADT, equality_noparser) -{ - DiscreteKey A(0,2), B(1,2); +TEST(ADT, equality_noparser) { + DiscreteKey A(0, 2), B(1, 2); Signature::Table tableA, tableB; Signature::Row rA, rB; - rA += 80, 20; rB += 60, 40; - tableA += rA; tableB += rB; + rA += 80, 20; + rB += 60, 40; + tableA += rA; + tableB += rB; // Check straight equality ADT pA1 = create(A % tableA); ADT pA2 = create(A % tableA); - EXPECT(pA1 == pA2); // should be equal + EXPECT(pA1.equals(pA2)); // should be equal // Check equality after apply ADT pB = create(B % tableB); ADT pAB1 = apply(pA1, pB, &mul); ADT pAB2 = apply(pB, pA1, &mul); - EXPECT(pAB2 == pAB1); + EXPECT(pAB2.equals(pAB1)); } /* ************************************************************************* */ // test equality -TEST(ADT, equality_parser) -{ - DiscreteKey A(0,2), B(1,2); +TEST(ADT, equality_parser) { + DiscreteKey A(0, 2), B(1, 2); // Check straight equality ADT pA1 = create(A % "80/20"); ADT pA2 = create(A % "80/20"); - EXPECT(pA1 == pA2); // should be equal + EXPECT(pA1.equals(pA2)); // should be equal // Check equality after apply ADT pB = create(B % "60/40"); ADT pAB1 = apply(pA1, pB, &mul); ADT pAB2 = apply(pB, pA1, &mul); - EXPECT(pAB2 == pAB1); + EXPECT(pAB2.equals(pAB1)); } -/* ******************************************************************************** */ +/* ************************************************************************** */ // Factor graph construction // test constructor from strings -TEST(ADT, constructor) -{ - DiscreteKey v0(0,2), v1(1,3); - Assignment x00, x01, x02, x10, x11, x12; +TEST(ADT, constructor) { + DiscreteKey v0(0, 2), v1(1, 3); + DiscreteValues x00, x01, x02, x10, x11, x12; x00[0] = 0, x00[1] = 0; x01[0] = 0, x01[1] = 1; x02[0] = 0, x02[1] = 2; @@ -469,13 +467,12 @@ TEST(ADT, constructor) EXPECT_DOUBLES_EQUAL(3, f2(x11), 1e-9); EXPECT_DOUBLES_EQUAL(5, f2(x12), 1e-9); - DiscreteKey z0(0,5), z1(1,4), z2(2,3), z3(3,2); + DiscreteKey z0(0, 5), z1(1, 4), z2(2, 3), z3(3, 2); vector table(5 * 4 * 3 * 2); double x = 0; - for(double& t: table) - t = x++; + for (double& t : table) t = x++; ADT f3(z0 & z1 & z2 & z3, table); - Assignment assignment; + DiscreteValues assignment; assignment[0] = 0; assignment[1] = 0; assignment[2] = 0; @@ -486,9 +483,8 @@ TEST(ADT, constructor) /* ************************************************************************* */ // test conversion to integer indices // Only works if DiscreteKeys are binary, as size_t has binary cardinality! -TEST(ADT, conversion) -{ - DiscreteKey X(0,2), Y(1,2); +TEST(ADT, conversion) { + DiscreteKey X(0, 2), Y(1, 2); ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6"); dot(fDiscreteKey, "conversion-f1"); @@ -501,7 +497,7 @@ TEST(ADT, conversion) // f2.print("f2"); dot(fIndexKey, "conversion-f2"); - Assignment x00, x01, x02, x10, x11, x12; + DiscreteValues x00, x01, x02, x10, x11, x12; x00[5] = 0, x00[2] = 0; x01[5] = 0, x01[2] = 1; x10[5] = 1, x10[2] = 0; @@ -512,11 +508,10 @@ TEST(ADT, conversion) EXPECT_DOUBLES_EQUAL(0.6, fIndexKey(x11), 1e-9); } -/* ******************************************************************************** */ +/* ************************************************************************** */ // test operations in elimination -TEST(ADT, elimination) -{ - DiscreteKey A(0,2), B(1,3), C(2,2); +TEST(ADT, elimination) { + DiscreteKey A(0, 2), B(1, 3), C(2, 2); ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5"); dot(f1, "elimination-f1"); @@ -524,60 +519,58 @@ TEST(ADT, elimination) // sum out lower key ADT actualSum = f1.sum(C); ADT expectedSum(A & B, "3 7 11 9 6 10"); - CHECK(assert_equal(expectedSum,actualSum)); + CHECK(assert_equal(expectedSum, actualSum)); // normalize ADT actual = f1 / actualSum; vector cpt; - cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, // - 1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10; + cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, // + 1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10; ADT expected(A & B & C, cpt); - CHECK(assert_equal(expected,actual)); + CHECK(assert_equal(expected, actual)); } { // sum out lower 2 keys ADT actualSum = f1.sum(C).sum(B); ADT expectedSum(A, 21, 25); - CHECK(assert_equal(expectedSum,actualSum)); + CHECK(assert_equal(expectedSum, actualSum)); // normalize ADT actual = f1 / actualSum; vector cpt; - cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, // - 1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25; + cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, // + 1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25; ADT expected(A & B & C, cpt); - CHECK(assert_equal(expected,actual)); + CHECK(assert_equal(expected, actual)); } } -/* ******************************************************************************** */ +/* ************************************************************************** */ // Test non-commutative op -TEST(ADT, div) -{ - DiscreteKey A(0,2), B(1,2); +TEST(ADT, div) { + DiscreteKey A(0, 2), B(1, 2); // Literals ADT a(A, 8, 16); ADT b(B, 2, 4); - ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4 - ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16 + ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4 + ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16 EXPECT(assert_equal(expected_a_div_b, a / b)); EXPECT(assert_equal(expected_b_div_a, b / a)); } -/* ******************************************************************************** */ +/* ************************************************************************** */ // test zero shortcut -TEST(ADT, zero) -{ - DiscreteKey A(0,2), B(1,2); +TEST(ADT, zero) { + DiscreteKey A(0, 2), B(1, 2); // Literals ADT a(A, 0, 1); ADT notb(B, 1, 0); ADT anotb = a * notb; // GTSAM_PRINT(anotb); - Assignment x00, x01, x10, x11; + DiscreteValues x00, x01, x10, x11; x00[0] = 0, x00[1] = 0; x01[0] = 0, x01[1] = 1; x10[0] = 1, x10[1] = 0; diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index 96f503abc..dbfb2dc40 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -24,60 +24,98 @@ using namespace boost::assign; #include #include -//#define DT_DEBUG_MEMORY -//#define DT_NO_PRUNING +// #define DT_DEBUG_MEMORY +// #define DT_NO_PRUNING #define DISABLE_DOT #include using namespace std; using namespace gtsam; -template -void dot(const T&f, const string& filename) { +template +void dot(const T& f, const string& filename) { #ifndef DISABLE_DOT f.dot(filename); #endif } -#define DOT(x)(dot(x,#x)) +#define DOT(x) (dot(x, #x)) -struct Crazy { int a; double b; }; -typedef DecisionTree CrazyDecisionTree; // check that DecisionTree is actually generic (as it pretends to be) +struct Crazy { + int a; + double b; +}; -// traits -namespace gtsam { -template<> struct traits : public Testable {}; -} - -/* ******************************************************************************** */ -// Test string labels and int range -/* ******************************************************************************** */ - -typedef DecisionTree DT; - -// traits -namespace gtsam { -template<> struct traits
: public Testable
{}; -} - -struct Ring { - static inline int zero() { - return 0; +struct CrazyDecisionTree : public DecisionTree { + /// print to stdout + void print(const std::string& s = "") const { + auto keyFormatter = [](const std::string& s) { return s; }; + auto valueFormatter = [](const Crazy& v) { + return (boost::format("{%d,%4.2g}") % v.a % v.b).str(); + }; + DecisionTree::print("", keyFormatter, valueFormatter); } - static inline int one() { - return 1; - } - static inline int add(const int& a, const int& b) { - return a + b; - } - static inline int mul(const int& a, const int& b) { - return a * b; + /// Equality method customized to Crazy node type + bool equals(const CrazyDecisionTree& other, double tol = 1e-9) const { + auto compare = [tol](const Crazy& v, const Crazy& w) { + return v.a == w.a && std::abs(v.b - w.b) < tol; + }; + return DecisionTree::equals(other, compare); } }; -/* ******************************************************************************** */ +// traits +namespace gtsam { +template <> +struct traits : public Testable {}; +} // namespace gtsam + +GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree) + +/* ************************************************************************** */ +// Test string labels and int range +/* ************************************************************************** */ + +struct DT : public DecisionTree { + using Base = DecisionTree; + using DecisionTree::DecisionTree; + DT() = default; + + DT(const Base& dt) : Base(dt) {} + + /// print to stdout + void print(const std::string& s = "") const { + auto keyFormatter = [](const std::string& s) { return s; }; + auto valueFormatter = [](const int& v) { + return (boost::format("%d") % v).str(); + }; + Base::print("", keyFormatter, valueFormatter); + } + /// Equality method customized to int node type + bool equals(const Base& other, double tol = 1e-9) const { + auto compare = [](const int& v, const int& w) { return v == w; }; + return Base::equals(other, compare); + } +}; + +// traits +namespace gtsam { +template <> +struct traits
: public Testable
{}; +} // namespace gtsam + +GTSAM_CONCEPT_TESTABLE_INST(DT) + +struct Ring { + static inline int zero() { return 0; } + static inline int one() { return 1; } + static inline int id(const int& a) { return a; } + static inline int add(const int& a, const int& b) { return a + b; } + static inline int mul(const int& a, const int& b) { return a * b; } +}; + +/* ************************************************************************** */ // test DT -TEST(DT, example) -{ +TEST(DecisionTree, example) { // Create labels string A("A"), B("B"), C("C"); @@ -88,54 +126,62 @@ TEST(DT, example) x10[A] = 1, x10[B] = 0; x11[A] = 1, x11[B] = 1; + // empty + DT empty; + // A DT a(A, 0, 5); - LONGS_EQUAL(0,a(x00)) - LONGS_EQUAL(5,a(x10)) + LONGS_EQUAL(0, a(x00)) + LONGS_EQUAL(5, a(x10)) DOT(a); // pruned DT p(A, 2, 2); - LONGS_EQUAL(2,p(x00)) - LONGS_EQUAL(2,p(x10)) + LONGS_EQUAL(2, p(x00)) + LONGS_EQUAL(2, p(x10)) DOT(p); // \neg B DT notb(B, 5, 0); - LONGS_EQUAL(5,notb(x00)) - LONGS_EQUAL(5,notb(x10)) + LONGS_EQUAL(5, notb(x00)) + LONGS_EQUAL(5, notb(x10)) DOT(notb); + // Check supplying empty trees yields an exception + CHECK_EXCEPTION(apply(empty, &Ring::id), std::runtime_error); + CHECK_EXCEPTION(apply(empty, a, &Ring::mul), std::runtime_error); + CHECK_EXCEPTION(apply(a, empty, &Ring::mul), std::runtime_error); + // apply, two nodes, in natural order DT anotb = apply(a, notb, &Ring::mul); - LONGS_EQUAL(0,anotb(x00)) - LONGS_EQUAL(0,anotb(x01)) - LONGS_EQUAL(25,anotb(x10)) - LONGS_EQUAL(0,anotb(x11)) + LONGS_EQUAL(0, anotb(x00)) + LONGS_EQUAL(0, anotb(x01)) + LONGS_EQUAL(25, anotb(x10)) + LONGS_EQUAL(0, anotb(x11)) DOT(anotb); // check pruning DT pnotb = apply(p, notb, &Ring::mul); - LONGS_EQUAL(10,pnotb(x00)) - LONGS_EQUAL( 0,pnotb(x01)) - LONGS_EQUAL(10,pnotb(x10)) - LONGS_EQUAL( 0,pnotb(x11)) + LONGS_EQUAL(10, pnotb(x00)) + LONGS_EQUAL(0, pnotb(x01)) + LONGS_EQUAL(10, pnotb(x10)) + LONGS_EQUAL(0, pnotb(x11)) DOT(pnotb); // check pruning DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul); - LONGS_EQUAL(0,zeros(x00)) - LONGS_EQUAL(0,zeros(x01)) - LONGS_EQUAL(0,zeros(x10)) - LONGS_EQUAL(0,zeros(x11)) + LONGS_EQUAL(0, zeros(x00)) + LONGS_EQUAL(0, zeros(x01)) + LONGS_EQUAL(0, zeros(x10)) + LONGS_EQUAL(0, zeros(x11)) DOT(zeros); // apply, two nodes, in switched order DT notba = apply(a, notb, &Ring::mul); - LONGS_EQUAL(0,notba(x00)) - LONGS_EQUAL(0,notba(x01)) - LONGS_EQUAL(25,notba(x10)) - LONGS_EQUAL(0,notba(x11)) + LONGS_EQUAL(0, notba(x00)) + LONGS_EQUAL(0, notba(x01)) + LONGS_EQUAL(25, notba(x10)) + LONGS_EQUAL(0, notba(x11)) DOT(notba); // Test choose 0 @@ -150,10 +196,10 @@ TEST(DT, example) // apply, two nodes at same level DT a_and_a = apply(a, a, &Ring::mul); - LONGS_EQUAL(0,a_and_a(x00)) - LONGS_EQUAL(0,a_and_a(x01)) - LONGS_EQUAL(25,a_and_a(x10)) - LONGS_EQUAL(25,a_and_a(x11)) + LONGS_EQUAL(0, a_and_a(x00)) + LONGS_EQUAL(0, a_and_a(x01)) + LONGS_EQUAL(25, a_and_a(x10)) + LONGS_EQUAL(25, a_and_a(x11)) DOT(a_and_a); // create a function on C @@ -165,27 +211,42 @@ TEST(DT, example) // mul notba with C DT notbac = apply(notba, c, &Ring::mul); - LONGS_EQUAL(125,notbac(x101)) + LONGS_EQUAL(125, notbac(x101)) DOT(notbac); // mul now in different order DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul); - LONGS_EQUAL(125,acnotb(x101)) + LONGS_EQUAL(125, acnotb(x101)) DOT(acnotb); } -/* ******************************************************************************** */ -// test Conversion -enum Label { - U, V, X, Y, Z -}; -typedef DecisionTree BDT; -bool convert(const int& y) { - return y != 0; +/* ************************************************************************** */ +// test Conversion of values +bool bool_of_int(const int& y) { return y != 0; }; +typedef DecisionTree StringBoolTree; + +TEST(DecisionTree, ConvertValuesOnly) { + // Create labels + string A("A"), B("B"); + + // apply, two nodes, in natural order + DT f1 = apply(DT(A, 0, 5), DT(B, 5, 0), &Ring::mul); + + // convert + StringBoolTree f2(f1, bool_of_int); + + // Check a value + Assignment x00; + x00["A"] = 0, x00["B"] = 0; + EXPECT(!f2(x00)); } -TEST(DT, conversion) -{ +/* ************************************************************************** */ +// test Conversion of both values and labels. +enum Label { U, V, X, Y, Z }; +typedef DecisionTree LabelBoolTree; + +TEST(DecisionTree, ConvertBoth) { // Create labels string A("A"), B("B"); @@ -196,12 +257,9 @@ TEST(DT, conversion) map ordering; ordering[A] = X; ordering[B] = Y; - std::function op = convert; - BDT f2(f1, ordering, op); - // f1.print("f1"); - // f2.print("f2"); + LabelBoolTree f2(f1, ordering, &bool_of_int); - // create a value + // Check some values Assignment
\n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + "
ABvalue
Zero-1
Zero+2
One-3
One+4
Two-5
Two+6
\n" + "
"; + auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; + DecisionTreeFactor::Names names{{12, {"Zero", "One", "Two"}}, + {5, {"-", "+"}}}; + string actual = f.html(keyFormatter, names); + EXPECT(actual == expected); } /* ************************************************************************* */ diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index 2b440e5a0..cfc9c1bb5 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -38,21 +38,26 @@ using namespace boost::assign; using namespace std; using namespace gtsam; +static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), + LungCancer(6, 2), Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2); + +using ADT = AlgebraicDecisionTree; + /* ************************************************************************* */ TEST(DiscreteBayesNet, bayesNet) { DiscreteBayesNet bayesNet; DiscreteKey Parent(0, 2), Child(1, 2); auto prior = boost::make_shared(Parent % "6/4"); - CHECK(assert_equal(Potentials::ADT({Parent}, "0.6 0.4"), - (Potentials::ADT)*prior)); + CHECK(assert_equal(ADT({Parent}, "0.6 0.4"), + (ADT)*prior)); bayesNet.push_back(prior); auto conditional = boost::make_shared(Child | Parent = "7/3 8/2"); EXPECT_LONGS_EQUAL(1, *(conditional->beginFrontals())); - Potentials::ADT expected(Child & Parent, "0.7 0.8 0.3 0.2"); - CHECK(assert_equal(expected, (Potentials::ADT)*conditional)); + ADT expected(Child & Parent, "0.7 0.8 0.3 0.2"); + CHECK(assert_equal(expected, (ADT)*conditional)); bayesNet.push_back(conditional); DiscreteFactorGraph fg(bayesNet); @@ -71,11 +76,9 @@ TEST(DiscreteBayesNet, bayesNet) { /* ************************************************************************* */ TEST(DiscreteBayesNet, Asia) { DiscreteBayesNet asia; - DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), LungCancer(6, 2), - Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2); - asia.add(Asia % "99/1"); - asia.add(Smoking % "50/50"); + asia.add(Asia, "99/1"); + asia.add(Smoking % "50/50"); // Signature version asia.add(Tuberculosis | Asia = "99/1 95/5"); asia.add(LungCancer | Smoking = "99/1 90/10"); @@ -103,39 +106,26 @@ TEST(DiscreteBayesNet, Asia) { DiscreteConditional expected2(Bronchitis % "11/9"); EXPECT(assert_equal(expected2, *chordal->back())); - // solve - DiscreteFactor::sharedValues actualMPE = chordal->optimize(); - DiscreteFactor::Values expectedMPE; - insert(expectedMPE)(Asia.first, 0)(Dyspnea.first, 0)(XRay.first, 0)( - Tuberculosis.first, 0)(Smoking.first, 0)(Either.first, 0)( - LungCancer.first, 0)(Bronchitis.first, 0); - EXPECT(assert_equal(expectedMPE, *actualMPE)); - // add evidence, we were in Asia and we have dyspnea fg.add(Asia, "0 1"); fg.add(Dyspnea, "0 1"); // solve again, now with evidence DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering); - DiscreteFactor::sharedValues actualMPE2 = chordal2->optimize(); - DiscreteFactor::Values expectedMPE2; - insert(expectedMPE2)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 0)( - Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 0)( - LungCancer.first, 0)(Bronchitis.first, 1); - EXPECT(assert_equal(expectedMPE2, *actualMPE2)); + EXPECT(assert_equal(expected2, *chordal->back())); // now sample from it - DiscreteFactor::Values expectedSample; + DiscreteValues expectedSample; SETDEBUG("DiscreteConditional::sample", false); insert(expectedSample)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 1)( Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 1)( LungCancer.first, 1)(Bronchitis.first, 0); - DiscreteFactor::sharedValues actualSample = chordal2->sample(); - EXPECT(assert_equal(expectedSample, *actualSample)); + auto actualSample = chordal2->sample(); + EXPECT(assert_equal(expectedSample, actualSample)); } /* ************************************************************************* */ -TEST_UNSAFE(DiscreteBayesNet, Sugar) { +TEST(DiscreteBayesNet, Sugar) { DiscreteKey T(0, 2), L(1, 2), E(2, 2), C(8, 3), S(7, 2); DiscreteBayesNet bn; @@ -149,6 +139,61 @@ TEST_UNSAFE(DiscreteBayesNet, Sugar) { bn.add(C | S = "1/1/2 5/2/3"); } +/* ************************************************************************* */ +TEST(DiscreteBayesNet, Dot) { + DiscreteBayesNet fragment; + fragment.add(Asia % "99/1"); + fragment.add(Smoking % "50/50"); + + fragment.add(Tuberculosis | Asia = "99/1 95/5"); + fragment.add(LungCancer | Smoking = "99/1 90/10"); + fragment.add((Either | Tuberculosis, LungCancer) = "F T T T"); + + string actual = fragment.dot(); + cout << actual << endl; + EXPECT(actual == + "digraph {\n" + " size=\"5,5\";\n" + "\n" + " var0[label=\"0\"];\n" + " var3[label=\"3\"];\n" + " var4[label=\"4\"];\n" + " var5[label=\"5\"];\n" + " var6[label=\"6\"];\n" + "\n" + " var3->var5\n" + " var6->var5\n" + " var4->var6\n" + " var0->var3\n" + "}"); +} + +/* ************************************************************************* */ +// Check markdown representation looks as expected. +TEST(DiscreteBayesNet, markdown) { + DiscreteBayesNet fragment; + fragment.add(Asia % "99/1"); + fragment.add(Smoking | Asia = "8/2 7/3"); + + string expected = + "`DiscreteBayesNet` of size 2\n" + "\n" + " *P(Asia):*\n\n" + "|Asia|value|\n" + "|:-:|:-:|\n" + "|0|0.99|\n" + "|1|0.01|\n" + "\n" + " *P(Smoking|Asia):*\n\n" + "|*Asia*|0|1|\n" + "|:-:|:-:|:-:|\n" + "|0|0.8|0.2|\n" + "|1|0.7|0.3|\n\n"; + auto formatter = [](Key key) { return key == 0 ? "Asia" : "Smoking"; }; + string actual = fragment.markdown(formatter); + EXPECT(actual == expected); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/discrete/tests/testDiscreteBayesTree.cpp b/gtsam/discrete/tests/testDiscreteBayesTree.cpp index ecf485036..26356be3d 100644 --- a/gtsam/discrete/tests/testDiscreteBayesTree.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesTree.cpp @@ -26,88 +26,101 @@ using namespace boost::assign; #include +#include #include using namespace std; using namespace gtsam; - -static bool debug = false; +static constexpr bool debug = false; /* ************************************************************************* */ - -TEST_UNSAFE(DiscreteBayesTree, ThinTree) { - const int nrNodes = 15; - const size_t nrStates = 2; - - // define variables - vector key; - for (int i = 0; i < nrNodes; i++) { - DiscreteKey key_i(i, nrStates); - key.push_back(key_i); - } - - // create a thin-tree Bayesnet, a la Jean-Guillaume +struct TestFixture { + vector keys; DiscreteBayesNet bayesNet; - bayesNet.add(key[14] % "1/3"); + boost::shared_ptr bayesTree; - bayesNet.add(key[13] | key[14] = "1/3 3/1"); - bayesNet.add(key[12] | key[14] = "3/1 3/1"); + /** + * Create a thin-tree Bayesnet, a la Jean-Guillaume Durand (former student), + * and then create the Bayes tree from it. + */ + TestFixture() { + // Define variables. + for (int i = 0; i < 15; i++) { + DiscreteKey key_i(i, 2); + keys.push_back(key_i); + } - bayesNet.add((key[11] | key[13], key[14]) = "1/4 2/3 3/2 4/1"); - bayesNet.add((key[10] | key[13], key[14]) = "1/4 3/2 2/3 4/1"); - bayesNet.add((key[9] | key[12], key[14]) = "4/1 2/3 F 1/4"); - bayesNet.add((key[8] | key[12], key[14]) = "T 1/4 3/2 4/1"); + // Create thin-tree Bayesnet. + bayesNet.add(keys[14] % "1/3"); - bayesNet.add((key[7] | key[11], key[13]) = "1/4 2/3 3/2 4/1"); - bayesNet.add((key[6] | key[11], key[13]) = "1/4 3/2 2/3 4/1"); - bayesNet.add((key[5] | key[10], key[13]) = "4/1 2/3 3/2 1/4"); - bayesNet.add((key[4] | key[10], key[13]) = "2/3 1/4 3/2 4/1"); + bayesNet.add(keys[13] | keys[14] = "1/3 3/1"); + bayesNet.add(keys[12] | keys[14] = "3/1 3/1"); - bayesNet.add((key[3] | key[9], key[12]) = "1/4 2/3 3/2 4/1"); - bayesNet.add((key[2] | key[9], key[12]) = "1/4 8/2 2/3 4/1"); - bayesNet.add((key[1] | key[8], key[12]) = "4/1 2/3 3/2 1/4"); - bayesNet.add((key[0] | key[8], key[12]) = "2/3 1/4 3/2 4/1"); + bayesNet.add((keys[11] | keys[13], keys[14]) = "1/4 2/3 3/2 4/1"); + bayesNet.add((keys[10] | keys[13], keys[14]) = "1/4 3/2 2/3 4/1"); + bayesNet.add((keys[9] | keys[12], keys[14]) = "4/1 2/3 F 1/4"); + bayesNet.add((keys[8] | keys[12], keys[14]) = "T 1/4 3/2 4/1"); + + bayesNet.add((keys[7] | keys[11], keys[13]) = "1/4 2/3 3/2 4/1"); + bayesNet.add((keys[6] | keys[11], keys[13]) = "1/4 3/2 2/3 4/1"); + bayesNet.add((keys[5] | keys[10], keys[13]) = "4/1 2/3 3/2 1/4"); + bayesNet.add((keys[4] | keys[10], keys[13]) = "2/3 1/4 3/2 4/1"); + + bayesNet.add((keys[3] | keys[9], keys[12]) = "1/4 2/3 3/2 4/1"); + bayesNet.add((keys[2] | keys[9], keys[12]) = "1/4 8/2 2/3 4/1"); + bayesNet.add((keys[1] | keys[8], keys[12]) = "4/1 2/3 3/2 1/4"); + bayesNet.add((keys[0] | keys[8], keys[12]) = "2/3 1/4 3/2 4/1"); + + // Create a BayesTree out of the Bayes net. + bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal(); + } +}; + +/* ************************************************************************* */ +TEST(DiscreteBayesTree, ThinTree) { + const TestFixture self; + const auto& keys = self.keys; if (debug) { - GTSAM_PRINT(bayesNet); - bayesNet.saveGraph("/tmp/discreteBayesNet.dot"); + GTSAM_PRINT(self.bayesNet); + self.bayesNet.saveGraph("/tmp/discreteBayesNet.dot"); } // create a BayesTree out of a Bayes net - auto bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal(); if (debug) { - GTSAM_PRINT(*bayesTree); - bayesTree->saveGraph("/tmp/discreteBayesTree.dot"); + GTSAM_PRINT(*self.bayesTree); + self.bayesTree->saveGraph("/tmp/discreteBayesTree.dot"); } // Check frontals and parents for (size_t i : {13, 14, 9, 3, 2, 8, 1, 0, 10, 5, 4}) { - auto clique_i = (*bayesTree)[i]; + auto clique_i = (*self.bayesTree)[i]; EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals())); } - auto R = bayesTree->roots().front(); + auto R = self.bayesTree->roots().front(); // Check whether BN and BT give the same answer on all configurations - vector allPosbValues = cartesianProduct( - key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7] & - key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]); + auto allPosbValues = DiscreteValues::CartesianProduct( + keys[0] & keys[1] & keys[2] & keys[3] & keys[4] & keys[5] & keys[6] & + keys[7] & keys[8] & keys[9] & keys[10] & keys[11] & keys[12] & keys[13] & + keys[14]); for (size_t i = 0; i < allPosbValues.size(); ++i) { - DiscreteFactor::Values x = allPosbValues[i]; - double expected = bayesNet.evaluate(x); - double actual = bayesTree->evaluate(x); + DiscreteValues x = allPosbValues[i]; + double expected = self.bayesNet.evaluate(x); + double actual = self.bayesTree->evaluate(x); DOUBLES_EQUAL(expected, actual, 1e-9); } - // Calculate all some marginals for Values==all1 + // Calculate all some marginals for DiscreteValues==all1 Vector marginals = Vector::Zero(15); double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0, joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0, joint_4_11 = 0, joint_11_13 = 0, joint_11_13_14 = 0, joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0; for (size_t i = 0; i < allPosbValues.size(); ++i) { - DiscreteFactor::Values x = allPosbValues[i]; - double px = bayesTree->evaluate(x); + DiscreteValues x = allPosbValues[i]; + double px = self.bayesTree->evaluate(x); for (size_t i = 0; i < 15; i++) if (x[i]) marginals[i] += px; if (x[12] && x[14]) { @@ -138,49 +151,49 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) { } } } - DiscreteFactor::Values all1 = allPosbValues.back(); + DiscreteValues all1 = allPosbValues.back(); // check separator marginal P(S0) - auto clique = (*bayesTree)[0]; + auto clique = (*self.bayesTree)[0]; DiscreteFactorGraph separatorMarginal0 = clique->separatorMarginal(EliminateDiscrete); DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9); // check separator marginal P(S9), should be P(14) - clique = (*bayesTree)[9]; + clique = (*self.bayesTree)[9]; DiscreteFactorGraph separatorMarginal9 = clique->separatorMarginal(EliminateDiscrete); DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9); // check separator marginal of root, should be empty - clique = (*bayesTree)[11]; + clique = (*self.bayesTree)[11]; DiscreteFactorGraph separatorMarginal11 = clique->separatorMarginal(EliminateDiscrete); LONGS_EQUAL(0, separatorMarginal11.size()); // check shortcut P(S9||R) to root - clique = (*bayesTree)[9]; + clique = (*self.bayesTree)[9]; DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete); LONGS_EQUAL(1, shortcut.size()); DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9); // check shortcut P(S8||R) to root - clique = (*bayesTree)[8]; + clique = (*self.bayesTree)[8]; shortcut = clique->shortcut(R, EliminateDiscrete); DOUBLES_EQUAL(joint_11_12_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9); // check shortcut P(S2||R) to root - clique = (*bayesTree)[2]; + clique = (*self.bayesTree)[2]; shortcut = clique->shortcut(R, EliminateDiscrete); DOUBLES_EQUAL(joint_9_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9); // check shortcut P(S0||R) to root - clique = (*bayesTree)[0]; + clique = (*self.bayesTree)[0]; shortcut = clique->shortcut(R, EliminateDiscrete); DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9); // calculate all shortcuts to root - DiscreteBayesTree::Nodes cliques = bayesTree->nodes(); + DiscreteBayesTree::Nodes cliques = self.bayesTree->nodes(); for (auto clique : cliques) { DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete); if (debug) { @@ -192,7 +205,7 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) { // Check all marginals DiscreteFactor::shared_ptr marginalFactor; for (size_t i = 0; i < 15; i++) { - marginalFactor = bayesTree->marginalFactor(i, EliminateDiscrete); + marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete); double actual = (*marginalFactor)(all1); DOUBLES_EQUAL(marginals[i], actual, 1e-9); } @@ -200,30 +213,60 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) { DiscreteBayesNet::shared_ptr actualJoint; // Check joint P(8, 2) - actualJoint = bayesTree->jointBayesNet(8, 2, EliminateDiscrete); + actualJoint = self.bayesTree->jointBayesNet(8, 2, EliminateDiscrete); DOUBLES_EQUAL(joint82, actualJoint->evaluate(all1), 1e-9); // Check joint P(1, 2) - actualJoint = bayesTree->jointBayesNet(1, 2, EliminateDiscrete); + actualJoint = self.bayesTree->jointBayesNet(1, 2, EliminateDiscrete); DOUBLES_EQUAL(joint12, actualJoint->evaluate(all1), 1e-9); // Check joint P(2, 4) - actualJoint = bayesTree->jointBayesNet(2, 4, EliminateDiscrete); + actualJoint = self.bayesTree->jointBayesNet(2, 4, EliminateDiscrete); DOUBLES_EQUAL(joint24, actualJoint->evaluate(all1), 1e-9); // Check joint P(4, 5) - actualJoint = bayesTree->jointBayesNet(4, 5, EliminateDiscrete); + actualJoint = self.bayesTree->jointBayesNet(4, 5, EliminateDiscrete); DOUBLES_EQUAL(joint45, actualJoint->evaluate(all1), 1e-9); // Check joint P(4, 6) - actualJoint = bayesTree->jointBayesNet(4, 6, EliminateDiscrete); + actualJoint = self.bayesTree->jointBayesNet(4, 6, EliminateDiscrete); DOUBLES_EQUAL(joint46, actualJoint->evaluate(all1), 1e-9); // Check joint P(4, 11) - actualJoint = bayesTree->jointBayesNet(4, 11, EliminateDiscrete); + actualJoint = self.bayesTree->jointBayesNet(4, 11, EliminateDiscrete); DOUBLES_EQUAL(joint_4_11, actualJoint->evaluate(all1), 1e-9); } +/* ************************************************************************* */ +TEST(DiscreteBayesTree, Dot) { + const TestFixture self; + string actual = self.bayesTree->dot(); + EXPECT(actual == + "digraph G{\n" + "0[label=\"13,11,6,7\"];\n" + "0->1\n" + "1[label=\"14 : 11,13\"];\n" + "1->2\n" + "2[label=\"9,12 : 14\"];\n" + "2->3\n" + "3[label=\"3 : 9,12\"];\n" + "2->4\n" + "4[label=\"2 : 9,12\"];\n" + "2->5\n" + "5[label=\"8 : 12,14\"];\n" + "5->6\n" + "6[label=\"1 : 8,12\"];\n" + "5->7\n" + "7[label=\"0 : 8,12\"];\n" + "1->8\n" + "8[label=\"10 : 13,14\"];\n" + "8->9\n" + "9[label=\"5 : 10,13\"];\n" + "8->10\n" + "10[label=\"4 : 10,13\"];\n" + "}"); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 3ac3ffc9e..13a34dd19 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -10,10 +10,11 @@ * -------------------------------------------------------------------------- */ /* - * @file testDecisionTreeFactor.cpp + * @file testDiscreteConditional.cpp * @brief unit tests for DiscreteConditional * @author Duy-Nguyen Ta - * @date Feb 14, 2011 + * @author Frank dellaert + * @date Feb 14, 2011 */ #include @@ -24,31 +25,30 @@ using namespace boost::assign; #include #include #include +#include using namespace std; using namespace gtsam; /* ************************************************************************* */ -TEST( DiscreteConditional, constructors) -{ - DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering ! +TEST(DiscreteConditional, constructors) { + DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering ! + + DiscreteConditional actual(X | Y = "1/1 2/3 1/4"); + EXPECT_LONGS_EQUAL(0, *(actual.beginFrontals())); + EXPECT_LONGS_EQUAL(2, *(actual.beginParents())); + EXPECT(actual.endParents() == actual.end()); + EXPECT(actual.endFrontals() == actual.beginParents()); - DiscreteConditional::shared_ptr expected1 = // - boost::make_shared(X | Y = "1/1 2/3 1/4"); - EXPECT(expected1); - EXPECT_LONGS_EQUAL(0, *(expected1->beginFrontals())); - EXPECT_LONGS_EQUAL(2, *(expected1->beginParents())); - EXPECT(expected1->endParents() == expected1->end()); - EXPECT(expected1->endFrontals() == expected1->beginParents()); - DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); - DiscreteConditional actual1(1, f1); - EXPECT(assert_equal(*expected1, actual1, 1e-9)); + DiscreteConditional expected1(1, f1); + EXPECT(assert_equal(expected1, actual, 1e-9)); - DecisionTreeFactor f2(X & Y & Z, - "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); + DecisionTreeFactor f2( + X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); DiscreteConditional actual2(1, f2); - EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9)); + DecisionTreeFactor expected2 = f2 / *f2.sum(1); + EXPECT(assert_equal(expected2, static_cast(actual2))); } /* ************************************************************************* */ @@ -61,50 +61,314 @@ TEST(DiscreteConditional, constructors_alt_interface) { r2 += 2.0, 3.0; r3 += 1.0, 4.0; table += r1, r2, r3; - auto actual1 = boost::make_shared(X | Y = table); - EXPECT(actual1); + DiscreteConditional actual1(X, {Y}, table); + DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); DiscreteConditional expected1(1, f1); - EXPECT(assert_equal(expected1, *actual1, 1e-9)); + EXPECT(assert_equal(expected1, actual1, 1e-9)); DecisionTreeFactor f2( X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); DiscreteConditional actual2(1, f2); - EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9)); + DecisionTreeFactor expected2 = f2 / *f2.sum(1); + EXPECT(assert_equal(expected2, static_cast(actual2))); } /* ************************************************************************* */ TEST(DiscreteConditional, constructors2) { - // Declare keys and ordering DiscreteKey C(0, 2), B(1, 2); - DecisionTreeFactor actual(C & B, "0.8 0.75 0.2 0.25"); Signature signature((C | B) = "4/1 3/1"); - DiscreteConditional expected(signature); - DecisionTreeFactor::shared_ptr expectedFactor = expected.toFactor(); - EXPECT(assert_equal(*expectedFactor, actual)); + DiscreteConditional actual(signature); + + DecisionTreeFactor expected(C & B, "0.8 0.75 0.2 0.25"); + EXPECT(assert_equal(expected, static_cast(actual))); } /* ************************************************************************* */ TEST(DiscreteConditional, constructors3) { - // Declare keys and ordering DiscreteKey C(0, 2), B(1, 2), A(2, 2); - DecisionTreeFactor actual(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8"); Signature signature((C | B, A) = "4/1 1/1 1/1 1/4"); - DiscreteConditional expected(signature); - DecisionTreeFactor::shared_ptr expectedFactor = expected.toFactor(); - EXPECT(assert_equal(*expectedFactor, actual)); + DiscreteConditional actual(signature); + + DecisionTreeFactor expected(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8"); + EXPECT(assert_equal(expected, static_cast(actual))); } /* ************************************************************************* */ -TEST(DiscreteConditional, Combine) { - DiscreteKey A(0, 2), B(1, 2); - vector c; - c.push_back(boost::make_shared(A | B = "1/2 2/1")); - c.push_back(boost::make_shared(B % "1/2")); - DecisionTreeFactor factor(A & B, "0.111111 0.444444 0.222222 0.222222"); - DiscreteConditional actual(2, factor); - auto expected = DiscreteConditional::Combine(c.begin(), c.end()); - EXPECT(assert_equal(*expected, actual, 1e-5)); +// Check calculation of joint P(A,B) +TEST(DiscreteConditional, Multiply) { + DiscreteKey A(1, 2), B(0, 2); + DiscreteConditional conditional(A | B = "1/2 2/1"); + DiscreteConditional prior(B % "1/2"); + + // The expected factor + DecisionTreeFactor f(A & B, "1 4 2 2"); + DiscreteConditional expected(2, f); + + // P(A,B) = P(A|B) * P(B) = P(B) * P(A|B) + for (auto&& actual : {prior * conditional, conditional * prior}) { + EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); + KeyVector frontals(actual.beginFrontals(), actual.endFrontals()); + EXPECT((frontals == KeyVector{0, 1})); + for (auto&& it : actual.enumerate()) { + const DiscreteValues& v = it.first; + EXPECT_DOUBLES_EQUAL(actual(v), conditional(v) * prior(v), 1e-9); + } + // And for good measure: + EXPECT(assert_equal(expected, actual)); + } +} + +/* ************************************************************************* */ +// Check calculation of conditional joint P(A,B|C) +TEST(DiscreteConditional, Multiply2) { + DiscreteKey A(0, 2), B(1, 2), C(2, 2); + DiscreteConditional A_given_B(A | B = "1/3 3/1"); + DiscreteConditional B_given_C(B | C = "1/3 3/1"); + + // P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B) + for (auto&& actual : {A_given_B * B_given_C, B_given_C * A_given_B}) { + EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); + EXPECT_LONGS_EQUAL(1, actual.nrParents()); + KeyVector frontals(actual.beginFrontals(), actual.endFrontals()); + EXPECT((frontals == KeyVector{0, 1})); + for (auto&& it : actual.enumerate()) { + const DiscreteValues& v = it.first; + EXPECT_DOUBLES_EQUAL(actual(v), A_given_B(v) * B_given_C(v), 1e-9); + } + } +} + +/* ************************************************************************* */ +// Check calculation of conditional joint P(A,B|C), double check keys +TEST(DiscreteConditional, Multiply3) { + DiscreteKey A(1, 2), B(2, 2), C(0, 2); // different keys!!! + DiscreteConditional A_given_B(A | B = "1/3 3/1"); + DiscreteConditional B_given_C(B | C = "1/3 3/1"); + + // P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B) + for (auto&& actual : {A_given_B * B_given_C, B_given_C * A_given_B}) { + EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); + EXPECT_LONGS_EQUAL(1, actual.nrParents()); + KeyVector frontals(actual.beginFrontals(), actual.endFrontals()); + EXPECT((frontals == KeyVector{1, 2})); + for (auto&& it : actual.enumerate()) { + const DiscreteValues& v = it.first; + EXPECT_DOUBLES_EQUAL(actual(v), A_given_B(v) * B_given_C(v), 1e-9); + } + } +} + +/* ************************************************************************* */ +// Check calculation of conditional joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E) +TEST(DiscreteConditional, Multiply4) { + DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(4, 2), E(3, 2); + DiscreteConditional A_given_B(A | B = "1/3 3/1"); + DiscreteConditional B_given_D(B | D = "1/3 3/1"); + DiscreteConditional AB_given_D = A_given_B * B_given_D; + DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4"); + + // P(A,B,C|D,E) = P(A,B|D) P(C|D,E) = P(C|D,E) P(A,B|D) + for (auto&& actual : {AB_given_D * C_given_DE, C_given_DE * AB_given_D}) { + EXPECT_LONGS_EQUAL(3, actual.nrFrontals()); + EXPECT_LONGS_EQUAL(2, actual.nrParents()); + KeyVector frontals(actual.beginFrontals(), actual.endFrontals()); + EXPECT((frontals == KeyVector{0, 1, 2})); + KeyVector parents(actual.beginParents(), actual.endParents()); + EXPECT((parents == KeyVector{3, 4})); + for (auto&& it : actual.enumerate()) { + const DiscreteValues& v = it.first; + EXPECT_DOUBLES_EQUAL(actual(v), AB_given_D(v) * C_given_DE(v), 1e-9); + } + } +} + +/* ************************************************************************* */ +// Check calculation of marginals for joint P(A,B) +TEST(DiscreteConditional, marginals) { + DiscreteKey A(1, 2), B(0, 2); + DiscreteConditional conditional(A | B = "1/2 2/1"); + DiscreteConditional prior(B % "1/2"); + DiscreteConditional pAB = prior * conditional; + + // P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 1*1 + 2*2 = 5 + // P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4 + DiscreteConditional actualA = pAB.marginal(A.first); + DiscreteConditional pA(A % "5/4"); + EXPECT(assert_equal(pA, actualA)); + EXPECT(actualA.frontals() == KeyVector{1}); + EXPECT_LONGS_EQUAL(0, actualA.nrParents()); + + DiscreteConditional actualB = pAB.marginal(B.first); + EXPECT(assert_equal(prior, actualB)); + EXPECT(actualB.frontals() == KeyVector{0}); + EXPECT_LONGS_EQUAL(0, actualB.nrParents()); +} + +/* ************************************************************************* */ +// Check calculation of marginals in case branches are pruned +TEST(DiscreteConditional, marginals2) { + DiscreteKey A(0, 2), B(1, 2); // changing keys need to make pruning happen! + DiscreteConditional conditional(A | B = "2/2 3/1"); + DiscreteConditional prior(B % "1/2"); + DiscreteConditional pAB = prior * conditional; + GTSAM_PRINT(pAB); + // P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8 + // P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4 + DiscreteConditional actualA = pAB.marginal(A.first); + DiscreteConditional pA(A % "8/4"); + EXPECT(assert_equal(pA, actualA)); + + DiscreteConditional actualB = pAB.marginal(B.first); + EXPECT(assert_equal(prior, actualB)); +} + +/* ************************************************************************* */ +TEST(DiscreteConditional, likelihood) { + DiscreteKey X(0, 2), Y(1, 3); + DiscreteConditional conditional(X | Y = "2/8 4/6 5/5"); + + auto actual0 = conditional.likelihood(0); + DecisionTreeFactor expected0(Y, "0.2 0.4 0.5"); + EXPECT(assert_equal(expected0, *actual0, 1e-9)); + + auto actual1 = conditional.likelihood(1); + DecisionTreeFactor expected1(Y, "0.8 0.6 0.5"); + EXPECT(assert_equal(expected1, *actual1, 1e-9)); +} + +/* ************************************************************************* */ +// Check choose on P(C|D,E) +TEST(DiscreteConditional, choose) { + DiscreteKey C(2, 2), D(4, 2), E(3, 2); + DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4"); + + // Case 1: no given values: no-op + DiscreteValues given; + auto actual1 = C_given_DE.choose(given); + EXPECT(assert_equal(C_given_DE, *actual1, 1e-9)); + + // Case 2: 1 given value + given[D.first] = 1; + auto actual2 = C_given_DE.choose(given); + EXPECT_LONGS_EQUAL(1, actual2->nrFrontals()); + EXPECT_LONGS_EQUAL(1, actual2->nrParents()); + DiscreteConditional expected2(C | E = "1/1 1/4"); + EXPECT(assert_equal(expected2, *actual2, 1e-9)); + + // Case 2: 2 given values + given[E.first] = 0; + auto actual3 = C_given_DE.choose(given); + EXPECT_LONGS_EQUAL(1, actual3->nrFrontals()); + EXPECT_LONGS_EQUAL(0, actual3->nrParents()); + DiscreteConditional expected3(C % "1/1"); + EXPECT(assert_equal(expected3, *actual3, 1e-9)); +} + +/* ************************************************************************* */ +// Check markdown representation looks as expected, no parents. +TEST(DiscreteConditional, markdown_prior) { + DiscreteKey A(Symbol('x', 1), 3); + DiscreteConditional conditional(A % "1/2/2"); + string expected = + " *P(x1):*\n\n" + "|x1|value|\n" + "|:-:|:-:|\n" + "|0|0.2|\n" + "|1|0.4|\n" + "|2|0.4|\n"; + string actual = conditional.markdown(); + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +// Check markdown representation looks as expected, no parents + names. +TEST(DiscreteConditional, markdown_prior_names) { + Symbol x1('x', 1); + DiscreteKey A(x1, 3); + DiscreteConditional conditional(A % "1/2/2"); + string expected = + " *P(x1):*\n\n" + "|x1|value|\n" + "|:-:|:-:|\n" + "|A0|0.2|\n" + "|A1|0.4|\n" + "|A2|0.4|\n"; + DecisionTreeFactor::Names names{{x1, {"A0", "A1", "A2"}}}; + string actual = conditional.markdown(DefaultKeyFormatter, names); + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +// Check markdown representation looks as expected, multivalued. +TEST(DiscreteConditional, markdown_multivalued) { + DiscreteKey A(Symbol('a', 1), 3), B(Symbol('b', 1), 5); + DiscreteConditional conditional( + A | B = "2/88/10 2/20/78 33/33/34 33/33/34 95/2/3"); + string expected = + " *P(a1|b1):*\n\n" + "|*b1*|0|1|2|\n" + "|:-:|:-:|:-:|:-:|\n" + "|0|0.02|0.88|0.1|\n" + "|1|0.02|0.2|0.78|\n" + "|2|0.33|0.33|0.34|\n" + "|3|0.33|0.33|0.34|\n" + "|4|0.95|0.02|0.03|\n"; + string actual = conditional.markdown(); + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +// Check markdown representation looks as expected, two parents + names. +TEST(DiscreteConditional, markdown) { + DiscreteKey A(2, 2), B(1, 2), C(0, 3); + DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0"); + string expected = + " *P(A|B,C):*\n\n" + "|*B*|*C*|T|F|\n" + "|:-:|:-:|:-:|:-:|\n" + "|-|Zero|0|1|\n" + "|-|One|0.25|0.75|\n" + "|-|Two|0.5|0.5|\n" + "|+|Zero|0.75|0.25|\n" + "|+|One|0|1|\n" + "|+|Two|1|0|\n"; + vector keyNames{"C", "B", "A"}; + auto formatter = [keyNames](Key key) { return keyNames[key]; }; + DecisionTreeFactor::Names names{ + {0, {"Zero", "One", "Two"}}, {1, {"-", "+"}}, {2, {"T", "F"}}}; + string actual = conditional.markdown(formatter, names); + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +// Check html representation looks as expected, two parents + names. +TEST(DiscreteConditional, html) { + DiscreteKey A(2, 2), B(1, 2), C(0, 3); + DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0"); + string expected = + "
\n" + "

P(A|B,C):

\n" + "\n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + "
BCTF
-Zero01
-One0.250.75
-Two0.50.5
+Zero0.750.25
+One01
+Two10
\n" + "
"; + vector keyNames{"C", "B", "A"}; + auto formatter = [keyNames](Key key) { return keyNames[key]; }; + DecisionTreeFactor::Names names{ + {0, {"Zero", "One", "Two"}}, {1, {"-", "+"}}, {2, {"T", "F"}}}; + string actual = conditional.html(formatter, names); + EXPECT(actual == expected); } /* ************************************************************************* */ diff --git a/gtsam/discrete/tests/testDiscreteDistribution.cpp b/gtsam/discrete/tests/testDiscreteDistribution.cpp new file mode 100644 index 000000000..d88b510f8 --- /dev/null +++ b/gtsam/discrete/tests/testDiscreteDistribution.cpp @@ -0,0 +1,88 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * @file testDiscreteDistribution.cpp + * @brief unit tests for DiscreteDistribution + * @author Frank dellaert + * @date December 2021 + */ + +#include +#include +#include + +using namespace gtsam; + +static const DiscreteKey X(0, 2); + +/* ************************************************************************* */ +TEST(DiscreteDistribution, constructors) { + DecisionTreeFactor f(X, "0.4 0.6"); + DiscreteDistribution expected(f); + + DiscreteDistribution actual(X % "2/3"); + EXPECT_LONGS_EQUAL(1, actual.nrFrontals()); + EXPECT_LONGS_EQUAL(0, actual.nrParents()); + EXPECT(assert_equal(expected, actual, 1e-9)); + + const std::vector pmf{0.4, 0.6}; + DiscreteDistribution actual2(X, pmf); + EXPECT_LONGS_EQUAL(1, actual2.nrFrontals()); + EXPECT_LONGS_EQUAL(0, actual2.nrParents()); + EXPECT(assert_equal(expected, actual2, 1e-9)); +} + +/* ************************************************************************* */ +TEST(DiscreteDistribution, Multiply) { + DiscreteKey A(0, 2), B(1, 2); + DiscreteConditional conditional(A | B = "1/2 2/1"); + DiscreteDistribution prior(B, "1/2"); + DiscreteConditional actual = prior * conditional; // P(A|B) * P(B) + + EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); // = P(A,B) + DecisionTreeFactor factor(A & B, "1 4 2 2"); + DiscreteConditional expected(2, factor); + EXPECT(assert_equal(expected, actual, 1e-5)); +} + +/* ************************************************************************* */ +TEST(DiscreteDistribution, operator) { + DiscreteDistribution prior(X % "2/3"); + EXPECT_DOUBLES_EQUAL(prior(0), 0.4, 1e-9); + EXPECT_DOUBLES_EQUAL(prior(1), 0.6, 1e-9); +} + +/* ************************************************************************* */ +TEST(DiscreteDistribution, pmf) { + DiscreteDistribution prior(X % "2/3"); + std::vector expected{0.4, 0.6}; + EXPECT(prior.pmf() == expected); +} + +/* ************************************************************************* */ +TEST(DiscreteDistribution, sample) { + DiscreteDistribution prior(X % "2/3"); + prior.sample(); +} + +/* ************************************************************************* */ +TEST(DiscreteDistribution, argmax) { + DiscreteDistribution prior(X % "2/3"); + EXPECT_LONGS_EQUAL(prior.argmax(), 1); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index 1defd5acf..0a7d869ec 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -30,8 +30,8 @@ using namespace std; using namespace gtsam; /* ************************************************************************* */ -TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) { - DiscreteKey PC(0,4), ME(1, 4), AI(2, 4), A(3, 3); +TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) { + DiscreteKey PC(0, 4), ME(1, 4), AI(2, 4), A(3, 3); DiscreteFactorGraph graph; graph.add(AI, "1 0 0 1"); @@ -47,25 +47,11 @@ TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) { graph.add(PC & ME, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0"); graph.add(PC & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0"); -// graph.print("Graph: "); - DecisionTreeFactor product = graph.product(); - DecisionTreeFactor::shared_ptr sum = product.sum(1); -// sum->print("Debug SUM: "); - DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum)); - -// cond->print("marginal:"); - -// pair result = EliminateDiscrete(graph, 1); -// result.first->print("BayesNet: "); -// result.second->print("New factor: "); -// - Ordering ordering; - ordering += Key(0),Key(1),Key(2),Key(3); - DiscreteEliminationTree eliminationTree(graph, ordering); -// eliminationTree.print("Elimination tree: "); - eliminationTree.eliminate(EliminateDiscrete); -// solver.optimize(); -// DiscreteBayesNet::shared_ptr bayesNet = solver.eliminate(); + // Check MPE. + auto actualMPE = graph.optimize(); + DiscreteValues mpe; + insert(mpe)(0, 2)(1, 1)(2, 0)(3, 0); + EXPECT(assert_equal(mpe, actualMPE)); } /* ************************************************************************* */ @@ -81,8 +67,8 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) { graph.add(P2, "0.9 0.6"); graph.add(P1 & P2, "4 1 10 4"); - // Instantiate Values - DiscreteFactor::Values values; + // Instantiate DiscreteValues + DiscreteValues values; values[0] = 1; values[1] = 1; @@ -115,10 +101,9 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) { } /* ************************************************************************* */ -TEST( DiscreteFactorGraph, test) -{ +TEST(DiscreteFactorGraph, test) { // Declare keys and ordering - DiscreteKey C(0,2), B(1,2), A(2,2); + DiscreteKey C(0, 2), B(1, 2), A(2, 2); // A simple factor graph (A)-fAC-(C)-fBC-(B) // with smoothness priors @@ -127,77 +112,124 @@ TEST( DiscreteFactorGraph, test) graph.add(C & B, "3 1 1 3"); // Test EliminateDiscrete - // FIXME: apparently Eliminate returns a conditional rather than a net Ordering frontalKeys; frontalKeys += Key(0); DiscreteConditional::shared_ptr conditional; DecisionTreeFactor::shared_ptr newFactor; boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys); - // Check Bayes net + // Check Conditional CHECK(conditional); - DiscreteBayesNet expected; Signature signature((C | B, A) = "9/1 1/1 1/1 1/9"); - // cout << signature << endl; DiscreteConditional expectedConditional(signature); EXPECT(assert_equal(expectedConditional, *conditional)); - expected.add(signature); // Check Factor CHECK(newFactor); DecisionTreeFactor expectedFactor(B & A, "10 6 6 10"); EXPECT(assert_equal(expectedFactor, *newFactor)); - // add conditionals to complete expected Bayes net - expected.add(B | A = "5/3 3/5"); - expected.add(A % "1/1"); - // GTSAM_PRINT(expected); - - // Test elimination tree + // Test using elimination tree Ordering ordering; ordering += Key(0), Key(1), Key(2); DiscreteEliminationTree etree(graph, ordering); DiscreteBayesNet::shared_ptr actual; DiscreteFactorGraph::shared_ptr remainingGraph; boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete); - EXPECT(assert_equal(expected, *actual)); -// // Test solver -// DiscreteBayesNet::shared_ptr actual2 = solver.eliminate(); -// EXPECT(assert_equal(expected, *actual2)); + // Check Bayes net + DiscreteBayesNet expectedBayesNet; + expectedBayesNet.add(signature); + expectedBayesNet.add(B | A = "5/3 3/5"); + expectedBayesNet.add(A % "1/1"); + EXPECT(assert_equal(expectedBayesNet, *actual)); - // Test optimization - DiscreteFactor::Values expectedValues; - insert(expectedValues)(0, 0)(1, 0)(2, 0); - DiscreteFactor::sharedValues actualValues = graph.optimize(); - EXPECT(assert_equal(expectedValues, *actualValues)); + // Test eliminateSequential + DiscreteBayesNet::shared_ptr actual2 = graph.eliminateSequential(ordering); + EXPECT(assert_equal(expectedBayesNet, *actual2)); + + // Test mpe + DiscreteValues mpe; + insert(mpe)(0, 0)(1, 0)(2, 0); + auto actualMPE = graph.optimize(); + EXPECT(assert_equal(mpe, actualMPE)); + EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression + + // Test sumProduct alias with all orderings: + auto mpeProbability = expectedBayesNet(mpe); + EXPECT_DOUBLES_EQUAL(0.28125, mpeProbability, 1e-5); // regression + + // Using custom ordering + DiscreteBayesNet bayesNet = graph.sumProduct(ordering); + EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5); + + for (Ordering::OrderingType orderingType : + {Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL, + Ordering::CUSTOM}) { + auto bayesNet = graph.sumProduct(orderingType); + EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5); + } } /* ************************************************************************* */ -TEST( DiscreteFactorGraph, testMPE) -{ +TEST_UNSAFE(DiscreteFactorGraph, testMaxProduct) { // Declare a bunch of keys - DiscreteKey C(0,2), A(1,2), B(2,2); + DiscreteKey C(0, 2), A(1, 2), B(2, 2); // Create Factor graph DiscreteFactorGraph graph; graph.add(C & A, "0.2 0.8 0.3 0.7"); graph.add(C & B, "0.1 0.9 0.4 0.6"); - // graph.product().print(); - // DiscreteSequentialSolver(graph).eliminate()->print(); - DiscreteFactor::sharedValues actualMPE = graph.optimize(); + // Created expected MPE + DiscreteValues mpe; + insert(mpe)(0, 0)(1, 1)(2, 1); - DiscreteFactor::Values expectedMPE; - insert(expectedMPE)(0, 0)(1, 1)(2, 1); - EXPECT(assert_equal(expectedMPE, *actualMPE)); + // Do max-product with different orderings + for (Ordering::OrderingType orderingType : + {Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL, + Ordering::CUSTOM}) { + DiscreteLookupDAG dag = graph.maxProduct(orderingType); + auto actualMPE = dag.argmax(); + EXPECT(assert_equal(mpe, actualMPE)); + auto actualMPE2 = graph.optimize(); // all in one + EXPECT(assert_equal(mpe, actualMPE2)); + } } /* ************************************************************************* */ -TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244) -{ +TEST(DiscreteFactorGraph, marginalIsNotMPE) { + // Declare 2 keys + DiscreteKey A(0, 2), B(1, 2); + + // Create Bayes net such that marginal on A is bigger for 0 than 1, but the + // MPE does not have A=0. + DiscreteBayesNet bayesNet; + bayesNet.add(B | A = "1/1 1/2"); + bayesNet.add(A % "10/9"); + + // The expected MPE is A=1, B=1 + DiscreteValues mpe; + insert(mpe)(0, 1)(1, 1); + + // Which we verify using max-product: + DiscreteFactorGraph graph(bayesNet); + auto actualMPE = graph.optimize(); + EXPECT(assert_equal(mpe, actualMPE)); + EXPECT_DOUBLES_EQUAL(0.315789, graph(mpe), 1e-5); // regression + +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + // Optimize on BayesNet maximizes marginal, then the conditional marginals: + auto notOptimal = bayesNet.optimize(); + EXPECT(graph(notOptimal) < graph(mpe)); + EXPECT_DOUBLES_EQUAL(0.263158, graph(notOptimal), 1e-5); // regression +#endif +} + +/* ************************************************************************* */ +TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) { // The factor graph in Darwiche09book, page 244 - DiscreteKey A(4,2), C(3,2), S(2,2), T1(0,2), T2(1,2); + DiscreteKey A(4, 2), C(3, 2), S(2, 2), T1(0, 2), T2(1, 2); // Create Factor graph DiscreteFactorGraph graph; @@ -206,53 +238,35 @@ TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244) graph.add(C & T1, "0.80 0.20 0.20 0.80"); graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95"); graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0"); - graph.add(A, "1 0");// evidence, A = yes (first choice in Darwiche) - //graph.product().print("Darwiche-product"); - // graph.product().potentials().dot("Darwiche-product"); - // DiscreteSequentialSolver(graph).eliminate()->print(); + graph.add(A, "1 0"); // evidence, A = yes (first choice in Darwiche) - DiscreteFactor::Values expectedMPE; - insert(expectedMPE)(4, 0)(2, 0)(3, 1)(0, 1)(1, 1); + DiscreteValues mpe; + insert(mpe)(4, 0)(2, 1)(3, 1)(0, 1)(1, 1); + EXPECT_DOUBLES_EQUAL(0.33858, graph(mpe), 1e-5); // regression + // You can check visually by printing product: + // graph.product().print("Darwiche-product"); - // Use the solver machinery. - DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); - DiscreteFactor::sharedValues actualMPE = chordal->optimize(); - EXPECT(assert_equal(expectedMPE, *actualMPE)); -// DiscreteConditional::shared_ptr root = chordal->back(); -// EXPECT_DOUBLES_EQUAL(0.4, (*root)(*actualMPE), 1e-9); - - // Let us create the Bayes tree here, just for fun, because we don't use it now -// typedef JunctionTreeOrdered JT; -// GenericMultifrontalSolver solver(graph); -// BayesTreeOrdered::shared_ptr bayesTree = solver.eliminate(&EliminateDiscrete); -//// bayesTree->print("Bayes Tree"); -// EXPECT_LONGS_EQUAL(2,bayesTree->size()); + // Check MPE. + auto actualMPE = graph.optimize(); + EXPECT(assert_equal(mpe, actualMPE)); + // Check Bayes Net Ordering ordering; - ordering += Key(0),Key(1),Key(2),Key(3),Key(4); - DiscreteBayesTree::shared_ptr bayesTree = graph.eliminateMultifrontal(ordering); - // bayesTree->print("Bayes Tree"); - EXPECT_LONGS_EQUAL(2,bayesTree->size()); - -#ifdef OLD -// Create the elimination tree manually -VariableIndexOrdered structure(graph); -typedef EliminationTreeOrdered ETree; -ETree::shared_ptr eTree = ETree::Create(graph, structure); -//eTree->print(">>>>>>>>>>> Elimination Tree <<<<<<<<<<<<<<<<<"); - -// eliminate normally and check solution -DiscreteBayesNet::shared_ptr bayesNet = eTree->eliminate(&EliminateDiscrete); -// bayesNet->print(">>>>>>>>>>>>>> Bayes Net <<<<<<<<<<<<<<<<<<"); -DiscreteFactor::sharedValues actualMPE = optimize(*bayesNet); -EXPECT(assert_equal(expectedMPE, *actualMPE)); - -// Approximate and check solution -// DiscreteBayesNet::shared_ptr approximateNet = eTree->approximate(); -// approximateNet->print(">>>>>>>>>>>>>> Approximate Net <<<<<<<<<<<<<<<<<<"); -// EXPECT(assert_equal(expectedMPE, *actualMPE)); + ordering += Key(0), Key(1), Key(2), Key(3), Key(4); + auto chordal = graph.eliminateSequential(ordering); + EXPECT_LONGS_EQUAL(5, chordal->size()); +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + auto notOptimal = chordal->optimize(); // not MPE ! + EXPECT(graph(notOptimal) < graph(mpe)); #endif + + // Let us create the Bayes tree here, just for fun, because we don't use it + DiscreteBayesTree::shared_ptr bayesTree = + graph.eliminateMultifrontal(ordering); + // bayesTree->print("Bayes Tree"); + EXPECT_LONGS_EQUAL(2, bayesTree->size()); } + #ifdef OLD /* ************************************************************************* */ @@ -359,6 +373,100 @@ cout << unicorns; } #endif +/* ************************************************************************* */ +TEST(DiscreteFactorGraph, Dot) { + // Create Factor graph + DiscreteFactorGraph graph; + DiscreteKey C(0, 2), A(1, 2), B(2, 2); + graph.add(C & A, "0.2 0.8 0.3 0.7"); + graph.add(C & B, "0.1 0.9 0.4 0.6"); + + string actual = graph.dot(); + string expected = + "graph {\n" + " size=\"5,5\";\n" + "\n" + " var0[label=\"0\"];\n" + " var1[label=\"1\"];\n" + " var2[label=\"2\"];\n" + "\n" + " factor0[label=\"\", shape=point];\n" + " var0--factor0;\n" + " var1--factor0;\n" + " factor1[label=\"\", shape=point];\n" + " var0--factor1;\n" + " var2--factor1;\n" + "}\n"; + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +TEST(DiscreteFactorGraph, DotWithNames) { + // Create Factor graph + DiscreteFactorGraph graph; + DiscreteKey C(0, 2), A(1, 2), B(2, 2); + graph.add(C & A, "0.2 0.8 0.3 0.7"); + graph.add(C & B, "0.1 0.9 0.4 0.6"); + + vector names{"C", "A", "B"}; + auto formatter = [names](Key key) { return names[key]; }; + string actual = graph.dot(formatter); + string expected = + "graph {\n" + " size=\"5,5\";\n" + "\n" + " varC[label=\"C\"];\n" + " varA[label=\"A\"];\n" + " varB[label=\"B\"];\n" + "\n" + " factor0[label=\"\", shape=point];\n" + " varC--factor0;\n" + " varA--factor0;\n" + " factor1[label=\"\", shape=point];\n" + " varC--factor1;\n" + " varB--factor1;\n" + "}\n"; + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +// Check markdown representation looks as expected. +TEST(DiscreteFactorGraph, markdown) { + // Create Factor graph + DiscreteFactorGraph graph; + DiscreteKey C(0, 2), A(1, 2), B(2, 2); + graph.add(C & A, "0.2 0.8 0.3 0.7"); + graph.add(C & B, "0.1 0.9 0.4 0.6"); + + string expected = + "`DiscreteFactorGraph` of size 2\n" + "\n" + "factor 0:\n" + "|C|A|value|\n" + "|:-:|:-:|:-:|\n" + "|0|0|0.2|\n" + "|0|1|0.8|\n" + "|1|0|0.3|\n" + "|1|1|0.7|\n" + "\n" + "factor 1:\n" + "|C|B|value|\n" + "|:-:|:-:|:-:|\n" + "|0|0|0.1|\n" + "|0|1|0.9|\n" + "|1|0|0.4|\n" + "|1|1|0.6|\n\n"; + vector names{"C", "A", "B"}; + auto formatter = [names](Key key) { return names[key]; }; + string actual = graph.markdown(formatter); + EXPECT(actual == expected); + + // Make sure values are correctly displayed. + DiscreteValues values; + values[0] = 1; + values[1] = 0; + EXPECT_DOUBLES_EQUAL(0.3, graph[0]->operator()(values), 1e-9); +} /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/discrete/tests/testDiscreteLookupDAG.cpp b/gtsam/discrete/tests/testDiscreteLookupDAG.cpp new file mode 100644 index 000000000..04b859780 --- /dev/null +++ b/gtsam/discrete/tests/testDiscreteLookupDAG.cpp @@ -0,0 +1,58 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * testDiscreteLookupDAG.cpp + * + * @date January, 2022 + * @author Frank Dellaert + */ + +#include +#include +#include + +#include +#include + +using namespace gtsam; +using namespace boost::assign; + +/* ************************************************************************* */ +TEST(DiscreteLookupDAG, argmax) { + using ADT = AlgebraicDecisionTree; + + // Declare 2 keys + DiscreteKey A(0, 2), B(1, 2); + + // Create lookup table corresponding to "marginalIsNotMPE" in testDFG. + DiscreteLookupDAG dag; + + ADT adtB(DiscreteKeys{B, A}, std::vector{0.5, 1. / 3, 0.5, 2. / 3}); + dag.add(1, DiscreteKeys{B, A}, adtB); + + ADT adtA(A, 0.5 * 10 / 19, (2. / 3) * (9. / 19)); + dag.add(1, DiscreteKeys{A}, adtA); + + // The expected MPE is A=1, B=1 + DiscreteValues mpe; + insert(mpe)(0, 1)(1, 1); + + // check: + auto actualMPE = dag.argmax(); + EXPECT(assert_equal(mpe, actualMPE)); +} +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/gtsam/discrete/tests/testDiscreteMarginals.cpp b/gtsam/discrete/tests/testDiscreteMarginals.cpp index e1eb92af3..3208f81c5 100644 --- a/gtsam/discrete/tests/testDiscreteMarginals.cpp +++ b/gtsam/discrete/tests/testDiscreteMarginals.cpp @@ -47,7 +47,7 @@ TEST_UNSAFE( DiscreteMarginals, UGM_small ) { DiscreteMarginals marginals(graph); DiscreteFactor::shared_ptr actualC = marginals(Cathy.first); - DiscreteFactor::Values values; + DiscreteValues values; values[Cathy.first] = 0; EXPECT_DOUBLES_EQUAL( 0.359631, (*actualC)(values), 1e-6); @@ -94,7 +94,7 @@ TEST_UNSAFE( DiscreteMarginals, UGM_chain ) { DiscreteMarginals marginals(graph); DiscreteFactor::shared_ptr actualC = marginals(key[2].first); - DiscreteFactor::Values values; + DiscreteValues values; values[key[2].first] = 0; EXPECT_DOUBLES_EQUAL( 0.03426, (*actualC)(values), 1e-4); @@ -164,11 +164,11 @@ TEST_UNSAFE(DiscreteMarginals, truss2) { graph.add(key[2] & key[3] & key[4], "1 2 3 4 5 6 7 8"); // Calculate the marginals by brute force - vector allPosbValues = - cartesianProduct(key[0] & key[1] & key[2] & key[3] & key[4]); + auto allPosbValues = DiscreteValues::CartesianProduct( + key[0] & key[1] & key[2] & key[3] & key[4]); Vector T = Z_5x1, F = Z_5x1; for (size_t i = 0; i < allPosbValues.size(); ++i) { - DiscreteFactor::Values x = allPosbValues[i]; + DiscreteValues x = allPosbValues[i]; double px = graph(x); for (size_t j = 0; j < 5; j++) if (x[j]) diff --git a/gtsam/discrete/tests/testDiscreteValues.cpp b/gtsam/discrete/tests/testDiscreteValues.cpp new file mode 100644 index 000000000..c8a1fa168 --- /dev/null +++ b/gtsam/discrete/tests/testDiscreteValues.cpp @@ -0,0 +1,76 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * testDiscreteValues.cpp + * + * @date Jan, 2022 + * @author Frank Dellaert + */ + +#include +#include +#include +#include + +#include +using namespace boost::assign; + +using namespace std; +using namespace gtsam; + +/* ************************************************************************* */ +// Check markdown representation with a value formatter. +TEST(DiscreteValues, markdownWithValueFormatter) { + DiscreteValues values; + values[12] = 1; // A + values[5] = 0; // B + string expected = + "|Variable|value|\n" + "|:-:|:-:|\n" + "|B|-|\n" + "|A|One|\n"; + auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; + DiscreteValues::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}}; + string actual = values.markdown(keyFormatter, names); + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +// Check html representation with a value formatter. +TEST(DiscreteValues, htmlWithValueFormatter) { + DiscreteValues values; + values[12] = 1; // A + values[5] = 0; // B + string expected = + "
\n" + "\n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + " \n" + "
Variablevalue
B-
AOne
\n" + "
"; + auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; + DiscreteValues::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}}; + string actual = values.html(keyFormatter, names); + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/gtsam/discrete/tests/testSignature.cpp b/gtsam/discrete/tests/testSignature.cpp index 049c455f7..737bd8aef 100644 --- a/gtsam/discrete/tests/testSignature.cpp +++ b/gtsam/discrete/tests/testSignature.cpp @@ -32,22 +32,27 @@ DiscreteKey X(0, 2), Y(1, 3), Z(2, 2); /* ************************************************************************* */ TEST(testSignature, simple_conditional) { - Signature sig(X | Y = "1/1 2/3 1/4"); + Signature sig(X, {Y}, "1/1 2/3 1/4"); + CHECK(sig.table()); Signature::Table table = *sig.table(); vector row[3]{{0.5, 0.5}, {0.4, 0.6}, {0.2, 0.8}}; + LONGS_EQUAL(3, table.size()); CHECK(row[0] == table[0]); CHECK(row[1] == table[1]); CHECK(row[2] == table[2]); - DiscreteKey actKey = sig.key(); - LONGS_EQUAL(X.first, actKey.first); - DiscreteKeys actKeys = sig.discreteKeys(); - LONGS_EQUAL(2, actKeys.size()); - LONGS_EQUAL(X.first, actKeys.front().first); - LONGS_EQUAL(Y.first, actKeys.back().first); + CHECK(sig.key() == X); - vector actCpt = sig.cpt(); - EXPECT_LONGS_EQUAL(6, actCpt.size()); + DiscreteKeys keys = sig.discreteKeys(); + LONGS_EQUAL(2, keys.size()); + CHECK(keys[0] == X); + CHECK(keys[1] == Y); + + DiscreteKeys parents = sig.parents(); + LONGS_EQUAL(1, parents.size()); + CHECK(parents[0] == Y); + + EXPECT_LONGS_EQUAL(6, sig.cpt().size()); } /* ************************************************************************* */ @@ -60,16 +65,56 @@ TEST(testSignature, simple_conditional_nonparser) { table += row1, row2, row3; Signature sig(X | Y = table); - DiscreteKey actKey = sig.key(); - EXPECT_LONGS_EQUAL(X.first, actKey.first); + CHECK(sig.key() == X); - DiscreteKeys actKeys = sig.discreteKeys(); - LONGS_EQUAL(2, actKeys.size()); - LONGS_EQUAL(X.first, actKeys.front().first); - LONGS_EQUAL(Y.first, actKeys.back().first); + DiscreteKeys keys = sig.discreteKeys(); + LONGS_EQUAL(2, keys.size()); + CHECK(keys[0] == X); + CHECK(keys[1] == Y); - vector actCpt = sig.cpt(); - EXPECT_LONGS_EQUAL(6, actCpt.size()); + DiscreteKeys parents = sig.parents(); + LONGS_EQUAL(1, parents.size()); + CHECK(parents[0] == Y); + + EXPECT_LONGS_EQUAL(6, sig.cpt().size()); +} + +/* ************************************************************************* */ +DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), D(7, 2); + +// Make sure we can create all signatures for Asia network with constructor. +TEST(testSignature, all_examples) { + DiscreteKey X(6, 2); + Signature a(A, {}, "99/1"); + Signature s(S, {}, "50/50"); + Signature t(T, {A}, "99/1 95/5"); + Signature l(L, {S}, "99/1 90/10"); + Signature b(B, {S}, "70/30 40/60"); + Signature e(E, {T, L}, "F F F 1"); + Signature x(X, {E}, "95/5 2/98"); +} + +// Make sure we can create all signatures for Asia network with operator magic. +TEST(testSignature, all_examples_magic) { + DiscreteKey X(6, 2); + Signature a(A % "99/1"); + Signature s(S % "50/50"); + Signature t(T | A = "99/1 95/5"); + Signature l(L | S = "99/1 90/10"); + Signature b(B | S = "70/30 40/60"); + Signature e((E | T, L) = "F F F 1"); + Signature x(X | E = "95/5 2/98"); +} + +// Check example from docs. +TEST(testSignature, doxygen_example) { + Signature::Table table{{0.9, 0.1}, {0.2, 0.8}, {0.3, 0.7}, {0.1, 0.9}}; + Signature d1(D, {E, B}, table); + Signature d2((D | E, B) = "9/1 2/8 3/7 1/9"); + Signature d3(D, {E, B}, "9/1 2/8 3/7 1/9"); + EXPECT(*(d1.table()) == table); + EXPECT(*(d2.table()) == table); + EXPECT(*(d3.table()) == table); } /* ************************************************************************* */ diff --git a/gtsam/geometry/Cal3.h b/gtsam/geometry/Cal3.h index 08ce4c1e6..1690615dd 100644 --- a/gtsam/geometry/Cal3.h +++ b/gtsam/geometry/Cal3.h @@ -170,9 +170,9 @@ class GTSAM_EXPORT Cal3 { return K; } -#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V41 +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 /** @deprecated The following function has been deprecated, use K above */ - Matrix3 matrix() const { return K(); } + Matrix3 GTSAM_DEPRECATED matrix() const { return K(); } #endif /// Return inverted calibration matrix inv(K) diff --git a/gtsam/geometry/Cal3Bundler.h b/gtsam/geometry/Cal3Bundler.h index 0d7c1be9d..b240603fc 100644 --- a/gtsam/geometry/Cal3Bundler.h +++ b/gtsam/geometry/Cal3Bundler.h @@ -97,12 +97,12 @@ class GTSAM_EXPORT Cal3Bundler : public Cal3 { Vector3 vector() const; -#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V41 +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 /// get parameter u0 - inline double u0() const { return u0_; } + inline double GTSAM_DEPRECATED u0() const { return u0_; } /// get parameter v0 - inline double v0() const { return v0_; } + inline double GTSAM_DEPRECATED v0() const { return v0_; } #endif /** diff --git a/gtsam/geometry/Cal3Fisheye.cpp b/gtsam/geometry/Cal3Fisheye.cpp index 52d475d5d..fd2c7ab65 100644 --- a/gtsam/geometry/Cal3Fisheye.cpp +++ b/gtsam/geometry/Cal3Fisheye.cpp @@ -46,9 +46,9 @@ double Cal3Fisheye::Scaling(double r) { /* ************************************************************************* */ Point2 Cal3Fisheye::uncalibrate(const Point2& p, OptionalJacobian<2, 9> H1, OptionalJacobian<2, 2> H2) const { - const double xi = p.x(), yi = p.y(); + const double xi = p.x(), yi = p.y(), zi = 1; const double r2 = xi * xi + yi * yi, r = sqrt(r2); - const double t = atan(r); + const double t = atan2(r, zi); const double t2 = t * t, t4 = t2 * t2, t6 = t2 * t4, t8 = t4 * t4; Vector5 K, T; K << 1, k1_, k2_, k3_, k4_; @@ -76,28 +76,32 @@ Point2 Cal3Fisheye::uncalibrate(const Point2& p, OptionalJacobian<2, 9> H1, // Derivative for points in intrinsic coords (2 by 2) if (H2) { - const double dtd_dt = - 1 + 3 * k1_ * t2 + 5 * k2_ * t4 + 7 * k3_ * t6 + 9 * k4_ * t8; - const double dt_dr = 1 / (1 + r2); - const double rinv = 1 / r; - const double dr_dxi = xi * rinv; - const double dr_dyi = yi * rinv; - const double dtd_dxi = dtd_dt * dt_dr * dr_dxi; - const double dtd_dyi = dtd_dt * dt_dr * dr_dyi; + if (r2==0) { + *H2 = DK; + } else { + const double dtd_dt = + 1 + 3 * k1_ * t2 + 5 * k2_ * t4 + 7 * k3_ * t6 + 9 * k4_ * t8; + const double R2 = r2 + zi*zi; + const double dt_dr = zi / R2; + const double rinv = 1 / r; + const double dr_dxi = xi * rinv; + const double dr_dyi = yi * rinv; + const double dtd_dr = dtd_dt * dt_dr; + + const double c2 = dr_dxi * dr_dxi; + const double s2 = dr_dyi * dr_dyi; + const double cs = dr_dxi * dr_dyi; - const double td = t * K.dot(T); - const double rrinv = 1 / r2; - const double dxd_dxi = - dtd_dxi * dr_dxi + td * rinv - td * xi * rrinv * dr_dxi; - const double dxd_dyi = dtd_dyi * dr_dxi - td * xi * rrinv * dr_dyi; - const double dyd_dxi = dtd_dxi * dr_dyi - td * yi * rrinv * dr_dxi; - const double dyd_dyi = - dtd_dyi * dr_dyi + td * rinv - td * yi * rrinv * dr_dyi; + const double dxd_dxi = dtd_dr * c2 + s * (1 - c2); + const double dxd_dyi = (dtd_dr - s) * cs; + const double dyd_dxi = dxd_dyi; + const double dyd_dyi = dtd_dr * s2 + s * (1 - s2); - Matrix2 DR; - DR << dxd_dxi, dxd_dyi, dyd_dxi, dyd_dyi; + Matrix2 DR; + DR << dxd_dxi, dxd_dyi, dyd_dxi, dyd_dyi; - *H2 = DK * DR; + *H2 = DK * DR; + } } return uv; diff --git a/gtsam/geometry/PinholeCamera.h b/gtsam/geometry/PinholeCamera.h index c1f0b6b3f..61e9f0909 100644 --- a/gtsam/geometry/PinholeCamera.h +++ b/gtsam/geometry/PinholeCamera.h @@ -312,6 +312,16 @@ public: return range(camera.pose(), Dcamera, Dother); } + /// for Linear Triangulation + Matrix34 cameraProjectionMatrix() const { + return K_.K() * PinholeBase::pose().inverse().matrix().block(0, 0, 3, 4); + } + + /// for Nonlinear Triangulation + Vector defaultErrorWhenTriangulatingBehindCamera() const { + return Eigen::Matrix::dimension,1>::Constant(2.0 * K_.fx());; + } + private: /** Serialization function */ diff --git a/gtsam/geometry/PinholePose.h b/gtsam/geometry/PinholePose.h index 7a0b08227..b4999af7c 100644 --- a/gtsam/geometry/PinholePose.h +++ b/gtsam/geometry/PinholePose.h @@ -121,6 +121,13 @@ public: return _project(pw, Dpose, Dpoint, Dcal); } + /// project a 3D point from world coordinates into the image + Point2 reprojectionError(const Point3& pw, const Point2& measured, OptionalJacobian<2, 6> Dpose = boost::none, + OptionalJacobian<2, 3> Dpoint = boost::none, + OptionalJacobian<2, DimK> Dcal = boost::none) const { + return Point2(_project(pw, Dpose, Dpoint, Dcal) - measured); + } + /// project a point at infinity from world coordinates into the image Point2 project(const Unit3& pw, OptionalJacobian<2, 6> Dpose = boost::none, OptionalJacobian<2, 2> Dpoint = boost::none, @@ -159,7 +166,6 @@ public: return result; } - /// backproject a 2-dimensional point to a 3-dimensional point at infinity Unit3 backprojectPointAtInfinity(const Point2& p) const { const Point2 pn = calibration().calibrate(p); @@ -410,6 +416,16 @@ public: return PinholePose(); // assumes that the default constructor is valid } + /// for Linear Triangulation + Matrix34 cameraProjectionMatrix() const { + Matrix34 P = Matrix34(PinholeBase::pose().inverse().matrix().block(0, 0, 3, 4)); + return K_->K() * P; + } + + /// for Nonlinear Triangulation + Vector defaultErrorWhenTriangulatingBehindCamera() const { + return Eigen::Matrix::dimension,1>::Constant(2.0 * K_->fx());; + } /// @} private: diff --git a/gtsam/geometry/Quaternion.h b/gtsam/geometry/Quaternion.h index 1557a09db..2ef47d58e 100644 --- a/gtsam/geometry/Quaternion.h +++ b/gtsam/geometry/Quaternion.h @@ -117,13 +117,23 @@ struct traits { omega = (-8. / 3. - 2. / 3. * qw) * q.vec(); } else { // Normal, away from zero case - _Scalar angle = 2 * acos(qw), s = sqrt(1 - qw * qw); - // Important: convert to [-pi,pi] to keep error continuous - if (angle > M_PI) - angle -= twoPi; - else if (angle < -M_PI) - angle += twoPi; - omega = (angle / s) * q.vec(); + if (qw > 0) { + _Scalar angle = 2 * acos(qw), s = sqrt(1 - qw * qw); + // Important: convert to [-pi,pi] to keep error continuous + if (angle > M_PI) + angle -= twoPi; + else if (angle < -M_PI) + angle += twoPi; + omega = (angle / s) * q.vec(); + } else { + // Make sure that we are using a canonical quaternion with w > 0 + _Scalar angle = 2 * acos(-qw), s = sqrt(1 - qw * qw); + if (angle > M_PI) + angle -= twoPi; + else if (angle < -M_PI) + angle += twoPi; + omega = (angle / s) * -q.vec(); + } } if(H) *H = SO3::LogmapDerivative(omega.template cast()); diff --git a/gtsam/geometry/Rot3.h b/gtsam/geometry/Rot3.h index abd74e063..18bd88b52 100644 --- a/gtsam/geometry/Rot3.h +++ b/gtsam/geometry/Rot3.h @@ -49,16 +49,14 @@ namespace gtsam { - /** - * @brief A 3D rotation represented as a rotation matrix if the preprocessor - * symbol GTSAM_USE_QUATERNIONS is not defined, or as a quaternion if it - * is defined. - * @addtogroup geometry - * \nosubgrouping - */ - class GTSAM_EXPORT Rot3 : public LieGroup { - - private: +/** + * @brief Rot3 is a 3D rotation represented as a rotation matrix if the + * preprocessor symbol GTSAM_USE_QUATERNIONS is not defined, or as a quaternion + * if it is defined. + * @addtogroup geometry + */ +class GTSAM_EXPORT Rot3 : public LieGroup { + private: #ifdef GTSAM_USE_QUATERNIONS /** Internal Eigen Quaternion */ @@ -67,8 +65,7 @@ namespace gtsam { SO3 rot_; #endif - public: - + public: /// @name Constructors and named constructors /// @{ @@ -83,7 +80,7 @@ namespace gtsam { */ Rot3(const Point3& col1, const Point3& col2, const Point3& col3); - /** constructor from a rotation matrix, as doubles in *row-major* order !!! */ + /// Construct from a rotation matrix, as doubles in *row-major* order !!! Rot3(double R11, double R12, double R13, double R21, double R22, double R23, double R31, double R32, double R33); @@ -567,6 +564,9 @@ namespace gtsam { #endif }; + /// std::vector of Rot3s, mainly for wrapper + using Rot3Vector = std::vector >; + /** * [RQ] receives a 3 by 3 matrix and returns an upper triangular matrix R * and 3 rotation angles corresponding to the rotation matrix Q=Qz'*Qy'*Qx' @@ -585,5 +585,6 @@ namespace gtsam { template<> struct traits : public internal::LieGroup {}; -} + +} // namespace gtsam diff --git a/gtsam/geometry/Similarity3.cpp b/gtsam/geometry/Similarity3.cpp index fcaf0c874..e8d6e7510 100644 --- a/gtsam/geometry/Similarity3.cpp +++ b/gtsam/geometry/Similarity3.cpp @@ -40,8 +40,10 @@ static Point3Pairs subtractCentroids(const Point3Pairs &abPointPairs, } /// Form inner products x and y and calculate scale. -static const double calculateScale(const Point3Pairs &d_abPointPairs, - const Rot3 &aRb) { +// We force the scale to be a non-negative quantity +// (see Section 10.1 of https://ethaneade.com/lie_groups.pdf) +static double calculateScale(const Point3Pairs &d_abPointPairs, + const Rot3 &aRb) { double x = 0, y = 0; Point3 da, db; for (const Point3Pair& d_abPair : d_abPointPairs) { @@ -50,7 +52,7 @@ static const double calculateScale(const Point3Pairs &d_abPointPairs, y += da.transpose() * da_prime; x += da_prime.transpose() * da_prime; } - const double s = y / x; + const double s = std::fabs(y / x); return s; } diff --git a/gtsam/geometry/SimpleCamera.cpp b/gtsam/geometry/SimpleCamera.cpp index d1a5ed330..be6a010b2 100644 --- a/gtsam/geometry/SimpleCamera.cpp +++ b/gtsam/geometry/SimpleCamera.cpp @@ -21,8 +21,8 @@ namespace gtsam { -#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V41 - SimpleCamera simpleCamera(const Matrix34& P) { +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + SimpleCamera GTSAM_DEPRECATED simpleCamera(const Matrix34& P) { // P = [A|a] = s K cRw [I|-T], with s the unknown scale Matrix3 A = P.topLeftCorner(3, 3); diff --git a/gtsam/geometry/SimpleCamera.h b/gtsam/geometry/SimpleCamera.h index 5ff6b9816..f0776c2e2 100644 --- a/gtsam/geometry/SimpleCamera.h +++ b/gtsam/geometry/SimpleCamera.h @@ -37,7 +37,7 @@ namespace gtsam { using PinholeCameraCal3Unified = gtsam::PinholeCamera; using PinholeCameraCal3Fisheye = gtsam::PinholeCamera; -#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V41 +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 /** * @deprecated: SimpleCamera for backwards compatability with GTSAM 3.x * Use PinholeCameraCal3_S2 instead diff --git a/gtsam/geometry/SphericalCamera.cpp b/gtsam/geometry/SphericalCamera.cpp new file mode 100644 index 000000000..58a29dc09 --- /dev/null +++ b/gtsam/geometry/SphericalCamera.cpp @@ -0,0 +1,109 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file SphericalCamera.h + * @brief Calibrated camera with spherical projection + * @date Aug 26, 2021 + * @author Luca Carlone + */ + +#include + +using namespace std; + +namespace gtsam { + +/* ************************************************************************* */ +bool SphericalCamera::equals(const SphericalCamera& camera, double tol) const { + return pose_.equals(camera.pose(), tol); +} + +/* ************************************************************************* */ +void SphericalCamera::print(const string& s) const { pose_.print(s + ".pose"); } + +/* ************************************************************************* */ +pair SphericalCamera::projectSafe(const Point3& pw) const { + const Point3 pc = pose().transformTo(pw); + Unit3 pu = Unit3::FromPoint3(pc); + return make_pair(pu, pc.norm() > 1e-8); +} + +/* ************************************************************************* */ +Unit3 SphericalCamera::project2(const Point3& pw, OptionalJacobian<2, 6> Dpose, + OptionalJacobian<2, 3> Dpoint) const { + Matrix36 Dtf_pose; + Matrix3 Dtf_point; // calculated by transformTo if needed + const Point3 pc = + pose().transformTo(pw, Dpose ? &Dtf_pose : 0, Dpoint ? &Dtf_point : 0); + + if (pc.norm() <= 1e-8) throw("point cannot be at the center of the camera"); + + Matrix23 Dunit; // calculated by FromPoint3 if needed + Unit3 pu = Unit3::FromPoint3(Point3(pc), Dpoint ? &Dunit : 0); + + if (Dpose) *Dpose = Dunit * Dtf_pose; // 2x3 * 3x6 = 2x6 + if (Dpoint) *Dpoint = Dunit * Dtf_point; // 2x3 * 3x3 = 2x3 + return pu; +} + +/* ************************************************************************* */ +Unit3 SphericalCamera::project2(const Unit3& pwu, OptionalJacobian<2, 6> Dpose, + OptionalJacobian<2, 2> Dpoint) const { + Matrix23 Dtf_rot; + Matrix2 Dtf_point; // calculated by transformTo if needed + const Unit3 pu = pose().rotation().unrotate(pwu, Dpose ? &Dtf_rot : 0, + Dpoint ? &Dtf_point : 0); + + if (Dpose) + *Dpose << Dtf_rot, Matrix::Zero(2, 3); // 2x6 (translation part is zero) + if (Dpoint) *Dpoint = Dtf_point; // 2x2 + return pu; +} + +/* ************************************************************************* */ +Point3 SphericalCamera::backproject(const Unit3& pu, const double depth) const { + return pose().transformFrom(depth * pu); +} + +/* ************************************************************************* */ +Unit3 SphericalCamera::backprojectPointAtInfinity(const Unit3& p) const { + return pose().rotation().rotate(p); +} + +/* ************************************************************************* */ +Unit3 SphericalCamera::project(const Point3& point, + OptionalJacobian<2, 6> Dcamera, + OptionalJacobian<2, 3> Dpoint) const { + return project2(point, Dcamera, Dpoint); +} + +/* ************************************************************************* */ +Vector2 SphericalCamera::reprojectionError( + const Point3& point, const Unit3& measured, OptionalJacobian<2, 6> Dpose, + OptionalJacobian<2, 3> Dpoint) const { + // project point + if (Dpose || Dpoint) { + Matrix26 H_project_pose; + Matrix23 H_project_point; + Matrix22 H_error; + Unit3 projected = project2(point, H_project_pose, H_project_point); + Vector2 error = measured.errorVector(projected, boost::none, H_error); + if (Dpose) *Dpose = H_error * H_project_pose; + if (Dpoint) *Dpoint = H_error * H_project_point; + return error; + } else { + return measured.errorVector(project2(point, Dpose, Dpoint)); + } +} + +/* ************************************************************************* */ +} // namespace gtsam diff --git a/gtsam/geometry/SphericalCamera.h b/gtsam/geometry/SphericalCamera.h new file mode 100644 index 000000000..4880423d3 --- /dev/null +++ b/gtsam/geometry/SphericalCamera.h @@ -0,0 +1,241 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file SphericalCamera.h + * @brief Calibrated camera with spherical projection + * @date Aug 26, 2021 + * @author Luca Carlone + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace gtsam { + +/** + * Empty calibration. Only needed to play well with other cameras + * (e.g., when templating functions wrt cameras), since other cameras + * have constuctors in the form ‘camera(pose,calibration)’ + * @addtogroup geometry + * \nosubgrouping + */ +class GTSAM_EXPORT EmptyCal { + public: + enum { dimension = 0 }; + EmptyCal() {} + virtual ~EmptyCal() = default; + using shared_ptr = boost::shared_ptr; + + /// return DOF, dimensionality of tangent space + inline static size_t Dim() { return dimension; } + + void print(const std::string& s) const { + std::cout << "empty calibration: " << s << std::endl; + } + + private: + /// Serialization function + friend class boost::serialization::access; + template + void serialize(Archive& ar, const unsigned int /*version*/) { + ar& boost::serialization::make_nvp( + "EmptyCal", boost::serialization::base_object(*this)); + } +}; + +/** + * A spherical camera class that has a Pose3 and measures bearing vectors. + * The camera has an ‘Empty’ calibration and the only 6 dof are the pose + * @addtogroup geometry + * \nosubgrouping + */ +class GTSAM_EXPORT SphericalCamera { + public: + enum { dimension = 6 }; + + using Measurement = Unit3; + using MeasurementVector = std::vector; + using CalibrationType = EmptyCal; + + private: + Pose3 pose_; ///< 3D pose of camera + + protected: + EmptyCal::shared_ptr emptyCal_; + + public: + /// @} + /// @name Standard Constructors + /// @{ + + /// Default constructor + SphericalCamera() + : pose_(Pose3::identity()), emptyCal_(boost::make_shared()) {} + + /// Constructor with pose + explicit SphericalCamera(const Pose3& pose) + : pose_(pose), emptyCal_(boost::make_shared()) {} + + /// Constructor with empty intrinsics (needed for smart factors) + explicit SphericalCamera(const Pose3& pose, + const EmptyCal::shared_ptr& cal) + : pose_(pose), emptyCal_(cal) {} + + /// @} + /// @name Advanced Constructors + /// @{ + explicit SphericalCamera(const Vector& v) : pose_(Pose3::Expmap(v)) {} + + /// Default destructor + virtual ~SphericalCamera() = default; + + /// return shared pointer to calibration + const EmptyCal::shared_ptr& sharedCalibration() const { + return emptyCal_; + } + + /// return calibration + const EmptyCal& calibration() const { return *emptyCal_; } + + /// @} + /// @name Testable + /// @{ + + /// assert equality up to a tolerance + bool equals(const SphericalCamera& camera, double tol = 1e-9) const; + + /// print + virtual void print(const std::string& s = "SphericalCamera") const; + + /// @} + /// @name Standard Interface + /// @{ + + /// return pose, constant version + const Pose3& pose() const { return pose_; } + + /// get rotation + const Rot3& rotation() const { return pose_.rotation(); } + + /// get translation + const Point3& translation() const { return pose_.translation(); } + + // /// return pose, with derivative + // const Pose3& getPose(OptionalJacobian<6, 6> H) const; + + /// @} + /// @name Transformations and measurement functions + /// @{ + + /// Project a point into the image and check depth + std::pair projectSafe(const Point3& pw) const; + + /** Project point into the image + * (note: there is no CheiralityException for a spherical camera) + * @param point 3D point in world coordinates + * @return the intrinsic coordinates of the projected point + */ + Unit3 project2(const Point3& pw, OptionalJacobian<2, 6> Dpose = boost::none, + OptionalJacobian<2, 3> Dpoint = boost::none) const; + + /** Project point into the image + * (note: there is no CheiralityException for a spherical camera) + * @param point 3D direction in world coordinates + * @return the intrinsic coordinates of the projected point + */ + Unit3 project2(const Unit3& pwu, OptionalJacobian<2, 6> Dpose = boost::none, + OptionalJacobian<2, 2> Dpoint = boost::none) const; + + /// backproject a 2-dimensional point to a 3-dimensional point at given depth + Point3 backproject(const Unit3& p, const double depth) const; + + /// backproject point at infinity + Unit3 backprojectPointAtInfinity(const Unit3& p) const; + + /** Project point into the image + * (note: there is no CheiralityException for a spherical camera) + * @param point 3D point in world coordinates + * @return the intrinsic coordinates of the projected point + */ + Unit3 project(const Point3& point, OptionalJacobian<2, 6> Dpose = boost::none, + OptionalJacobian<2, 3> Dpoint = boost::none) const; + + /** Compute reprojection error for a given 3D point in world coordinates + * @param point 3D point in world coordinates + * @return the tangent space error between the projection and the measurement + */ + Vector2 reprojectionError(const Point3& point, const Unit3& measured, + OptionalJacobian<2, 6> Dpose = boost::none, + OptionalJacobian<2, 3> Dpoint = boost::none) const; + /// @} + + /// move a cameras according to d + SphericalCamera retract(const Vector6& d) const { + return SphericalCamera(pose().retract(d)); + } + + /// return canonical coordinate + Vector6 localCoordinates(const SphericalCamera& p) const { + return pose().localCoordinates(p.pose()); + } + + /// for Canonical + static SphericalCamera identity() { + return SphericalCamera( + Pose3::identity()); // assumes that the default constructor is valid + } + + /// for Linear Triangulation + Matrix34 cameraProjectionMatrix() const { + return Matrix34(pose_.inverse().matrix().block(0, 0, 3, 4)); + } + + /// for Nonlinear Triangulation + Vector defaultErrorWhenTriangulatingBehindCamera() const { + return Eigen::Matrix::dimension, 1>::Constant(0.0); + } + + /// @deprecated + size_t dim() const { return 6; } + + /// @deprecated + static size_t Dim() { return 6; } + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(Archive& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_NVP(pose_); + } + + public: + GTSAM_MAKE_ALIGNED_OPERATOR_NEW +}; +// end of class SphericalCamera + +template <> +struct traits : public internal::LieGroup {}; + +template <> +struct traits : public internal::LieGroup {}; + +} // namespace gtsam diff --git a/gtsam/geometry/StereoCamera.h b/gtsam/geometry/StereoCamera.h index 3b5bdaefc..c53fc11c9 100644 --- a/gtsam/geometry/StereoCamera.h +++ b/gtsam/geometry/StereoCamera.h @@ -170,6 +170,11 @@ public: OptionalJacobian<3, 3> H2 = boost::none, OptionalJacobian<3, 0> H3 = boost::none) const; + /// for Nonlinear Triangulation + Vector defaultErrorWhenTriangulatingBehindCamera() const { + return Eigen::Matrix::dimension,1>::Constant(2.0 * K_->fx());; + } + /// @} private: diff --git a/gtsam/geometry/geometry.i b/gtsam/geometry/geometry.i index a40951d3e..1e42966f8 100644 --- a/gtsam/geometry/geometry.i +++ b/gtsam/geometry/geometry.i @@ -27,9 +27,6 @@ class Point2 { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; class Point2Pairs { @@ -104,9 +101,6 @@ class StereoPoint2 { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -131,9 +125,6 @@ class Point3 { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; class Point3Pairs { @@ -191,9 +182,6 @@ class Rot2 { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -372,9 +360,6 @@ class Rot3 { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -433,9 +418,6 @@ class Pose2 { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; boost::optional align(const gtsam::Point2Pairs& pairs); @@ -502,9 +484,6 @@ class Pose3 { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; class Pose3Pairs { @@ -547,9 +526,6 @@ class Unit3 { // enabling serialization functionality void serialize() const; - // enable pickling in python - void pickle() const; - // enabling function to compare objects bool equals(const gtsam::Unit3& expected, double tol) const; }; @@ -611,9 +587,6 @@ class Cal3_S2 { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -642,9 +615,6 @@ virtual class Cal3DS2_Base { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -668,9 +638,6 @@ virtual class Cal3DS2 : gtsam::Cal3DS2_Base { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -705,9 +672,6 @@ virtual class Cal3Unified : gtsam::Cal3DS2_Base { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -750,9 +714,6 @@ class Cal3Fisheye { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -811,9 +772,6 @@ class Cal3Bundler { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -847,9 +805,6 @@ class CalibratedCamera { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -889,9 +844,6 @@ class PinholeCamera { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -962,9 +914,6 @@ class StereoCamera { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -974,27 +923,34 @@ class StereoCamera { gtsam::Point3 triangulatePoint3(const gtsam::Pose3Vector& poses, gtsam::Cal3_S2* sharedCal, const gtsam::Point2Vector& measurements, - double rank_tol, bool optimize); + double rank_tol, bool optimize, + const gtsam::SharedNoiseModel& model = nullptr); gtsam::Point3 triangulatePoint3(const gtsam::Pose3Vector& poses, gtsam::Cal3DS2* sharedCal, const gtsam::Point2Vector& measurements, - double rank_tol, bool optimize); + double rank_tol, bool optimize, + const gtsam::SharedNoiseModel& model = nullptr); gtsam::Point3 triangulatePoint3(const gtsam::Pose3Vector& poses, gtsam::Cal3Bundler* sharedCal, const gtsam::Point2Vector& measurements, - double rank_tol, bool optimize); + double rank_tol, bool optimize, + const gtsam::SharedNoiseModel& model = nullptr); gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3_S2& cameras, const gtsam::Point2Vector& measurements, - double rank_tol, bool optimize); + double rank_tol, bool optimize, + const gtsam::SharedNoiseModel& model = nullptr); gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3Bundler& cameras, const gtsam::Point2Vector& measurements, - double rank_tol, bool optimize); + double rank_tol, bool optimize, + const gtsam::SharedNoiseModel& model = nullptr); gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3Fisheye& cameras, const gtsam::Point2Vector& measurements, - double rank_tol, bool optimize); + double rank_tol, bool optimize, + const gtsam::SharedNoiseModel& model = nullptr); gtsam::Point3 triangulatePoint3(const gtsam::CameraSetCal3Unified& cameras, const gtsam::Point2Vector& measurements, - double rank_tol, bool optimize); + double rank_tol, bool optimize, + const gtsam::SharedNoiseModel& model = nullptr); gtsam::Point3 triangulateNonlinear(const gtsam::Pose3Vector& poses, gtsam::Cal3_S2* sharedCal, const gtsam::Point2Vector& measurements, diff --git a/gtsam/geometry/tests/testSimpleCamera.cpp b/gtsam/geometry/tests/testSimpleCamera.cpp index 18a25c553..173ccf05b 100644 --- a/gtsam/geometry/tests/testSimpleCamera.cpp +++ b/gtsam/geometry/tests/testSimpleCamera.cpp @@ -26,7 +26,7 @@ using namespace std; using namespace gtsam; -#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V41 +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 static const Cal3_S2 K(625, 625, 0, 0, 0); diff --git a/gtsam/geometry/tests/testSphericalCamera.cpp b/gtsam/geometry/tests/testSphericalCamera.cpp new file mode 100644 index 000000000..4bc851f35 --- /dev/null +++ b/gtsam/geometry/tests/testSphericalCamera.cpp @@ -0,0 +1,163 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file SphericalCamera.h + * @brief Calibrated camera with spherical projection + * @date Aug 26, 2021 + * @author Luca Carlone + */ + +#include +#include +#include +#include + +#include +#include + +using namespace std::placeholders; +using namespace std; +using namespace gtsam; + +typedef SphericalCamera Camera; + +// static const Cal3_S2 K(625, 625, 0, 0, 0); +// +static const Pose3 pose(Rot3(Vector3(1, -1, -1).asDiagonal()), + Point3(0, 0, 0.5)); +static const Camera camera(pose); +// +static const Pose3 pose1(Rot3(), Point3(0, 1, 0.5)); +static const Camera camera1(pose1); + +static const Point3 point1(-0.08, -0.08, 0.0); +static const Point3 point2(-0.08, 0.08, 0.0); +static const Point3 point3(0.08, 0.08, 0.0); +static const Point3 point4(0.08, -0.08, 0.0); + +// manually computed in matlab +static const Unit3 bearing1(-0.156054862928174, 0.156054862928174, + 0.975342893301088); +static const Unit3 bearing2(-0.156054862928174, -0.156054862928174, + 0.975342893301088); +static const Unit3 bearing3(0.156054862928174, -0.156054862928174, + 0.975342893301088); +static const Unit3 bearing4(0.156054862928174, 0.156054862928174, + 0.975342893301088); + +static double depth = 0.512640224719052; +/* ************************************************************************* */ +TEST(SphericalCamera, constructor) { + EXPECT(assert_equal(pose, camera.pose())); +} + +/* ************************************************************************* */ +TEST(SphericalCamera, project) { + // expected from manual calculation in Matlab + EXPECT(assert_equal(camera.project(point1), bearing1)); + EXPECT(assert_equal(camera.project(point2), bearing2)); + EXPECT(assert_equal(camera.project(point3), bearing3)); + EXPECT(assert_equal(camera.project(point4), bearing4)); +} + +/* ************************************************************************* */ +TEST(SphericalCamera, backproject) { + EXPECT(assert_equal(camera.backproject(bearing1, depth), point1)); + EXPECT(assert_equal(camera.backproject(bearing2, depth), point2)); + EXPECT(assert_equal(camera.backproject(bearing3, depth), point3)); + EXPECT(assert_equal(camera.backproject(bearing4, depth), point4)); +} + +/* ************************************************************************* */ +TEST(SphericalCamera, backproject2) { + Point3 origin(0, 0, 0); + Rot3 rot(1., 0., 0., 0., 0., 1., 0., -1., 0.); // a camera1 looking down + Camera camera(Pose3(rot, origin)); + + Point3 actual = camera.backproject(Unit3(0, 0, 1), 1.); + Point3 expected(0., 1., 0.); + pair x = camera.projectSafe(expected); + + EXPECT(assert_equal(expected, actual)); + EXPECT(assert_equal(Unit3(0, 0, 1), x.first)); + EXPECT(x.second); +} + +/* ************************************************************************* */ +static Unit3 project3(const Pose3& pose, const Point3& point) { + return Camera(pose).project(point); +} + +/* ************************************************************************* */ +TEST(SphericalCamera, Dproject) { + Matrix Dpose, Dpoint; + Unit3 result = camera.project(point1, Dpose, Dpoint); + Matrix numerical_pose = numericalDerivative21(project3, pose, point1); + Matrix numerical_point = numericalDerivative22(project3, pose, point1); + EXPECT(assert_equal(bearing1, result)); + EXPECT(assert_equal(numerical_pose, Dpose, 1e-7)); + EXPECT(assert_equal(numerical_point, Dpoint, 1e-7)); +} + +/* ************************************************************************* */ +static Vector2 reprojectionError2(const Pose3& pose, const Point3& point, + const Unit3& measured) { + return Camera(pose).reprojectionError(point, measured); +} + +/* ************************************************************************* */ +TEST(SphericalCamera, reprojectionError) { + Matrix Dpose, Dpoint; + Vector2 result = camera.reprojectionError(point1, bearing1, Dpose, Dpoint); + Matrix numerical_pose = + numericalDerivative31(reprojectionError2, pose, point1, bearing1); + Matrix numerical_point = + numericalDerivative32(reprojectionError2, pose, point1, bearing1); + EXPECT(assert_equal(Vector2(0.0, 0.0), result)); + EXPECT(assert_equal(numerical_pose, Dpose, 1e-7)); + EXPECT(assert_equal(numerical_point, Dpoint, 1e-7)); +} + +/* ************************************************************************* */ +TEST(SphericalCamera, reprojectionError_noisy) { + Matrix Dpose, Dpoint; + Unit3 bearing_noisy = bearing1.retract(Vector2(0.01, 0.05)); + Vector2 result = + camera.reprojectionError(point1, bearing_noisy, Dpose, Dpoint); + Matrix numerical_pose = + numericalDerivative31(reprojectionError2, pose, point1, bearing_noisy); + Matrix numerical_point = + numericalDerivative32(reprojectionError2, pose, point1, bearing_noisy); + EXPECT(assert_equal(Vector2(-0.050282, 0.00833482), result, 1e-5)); + EXPECT(assert_equal(numerical_pose, Dpose, 1e-7)); + EXPECT(assert_equal(numerical_point, Dpoint, 1e-7)); +} + +/* ************************************************************************* */ +// Add a test with more arbitrary rotation +TEST(SphericalCamera, Dproject2) { + static const Pose3 pose1(Rot3::Ypr(0.1, -0.1, 0.4), Point3(0, 0, -10)); + static const Camera camera(pose1); + Matrix Dpose, Dpoint; + camera.project2(point1, Dpose, Dpoint); + Matrix numerical_pose = numericalDerivative21(project3, pose1, point1); + Matrix numerical_point = numericalDerivative22(project3, pose1, point1); + CHECK(assert_equal(numerical_pose, Dpose, 1e-7)); + CHECK(assert_equal(numerical_point, Dpoint, 1e-7)); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/gtsam/geometry/tests/testTriangulation.cpp b/gtsam/geometry/tests/testTriangulation.cpp index 4f71a48da..3a09f49bc 100644 --- a/gtsam/geometry/tests/testTriangulation.cpp +++ b/gtsam/geometry/tests/testTriangulation.cpp @@ -10,22 +10,23 @@ * -------------------------------------------------------------------------- */ /** - * testTriangulation.cpp - * - * Created on: July 30th, 2013 - * Author: cbeall3 + * @file testTriangulation.cpp + * @brief triangulation utilities + * @date July 30th, 2013 + * @author Chris Beall (cbeall3) + * @author Luca Carlone */ -#include -#include -#include -#include -#include -#include -#include -#include #include - +#include +#include +#include +#include +#include +#include +#include +#include +#include #include #include @@ -36,7 +37,7 @@ using namespace boost::assign; // Some common constants -static const boost::shared_ptr sharedCal = // +static const boost::shared_ptr sharedCal = // boost::make_shared(1500, 1200, 0, 640, 480); // Looking along X-axis, 1 meter above ground plane (x-y) @@ -57,8 +58,7 @@ Point2 z2 = camera2.project(landmark); //****************************************************************************** // Simple test with a well-behaved two camera situation -TEST( triangulation, twoPoses) { - +TEST(triangulation, twoPoses) { vector poses; Point2Vector measurements; @@ -69,36 +69,36 @@ TEST( triangulation, twoPoses) { // 1. Test simple DLT, perfect in no noise situation bool optimize = false; - boost::optional actual1 = // - triangulatePoint3(poses, sharedCal, measurements, rank_tol, optimize); + boost::optional actual1 = // + triangulatePoint3(poses, sharedCal, measurements, rank_tol, optimize); EXPECT(assert_equal(landmark, *actual1, 1e-7)); // 2. test with optimization on, same answer optimize = true; - boost::optional actual2 = // - triangulatePoint3(poses, sharedCal, measurements, rank_tol, optimize); + boost::optional actual2 = // + triangulatePoint3(poses, sharedCal, measurements, rank_tol, optimize); EXPECT(assert_equal(landmark, *actual2, 1e-7)); - // 3. Add some noise and try again: result should be ~ (4.995, 0.499167, 1.19814) + // 3. Add some noise and try again: result should be ~ (4.995, + // 0.499167, 1.19814) measurements.at(0) += Point2(0.1, 0.5); measurements.at(1) += Point2(-0.2, 0.3); optimize = false; - boost::optional actual3 = // - triangulatePoint3(poses, sharedCal, measurements, rank_tol, optimize); + boost::optional actual3 = // + triangulatePoint3(poses, sharedCal, measurements, rank_tol, optimize); EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19814), *actual3, 1e-4)); // 4. Now with optimization on optimize = true; - boost::optional actual4 = // - triangulatePoint3(poses, sharedCal, measurements, rank_tol, optimize); + boost::optional actual4 = // + triangulatePoint3(poses, sharedCal, measurements, rank_tol, optimize); EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19814), *actual4, 1e-4)); } //****************************************************************************** // Similar, but now with Bundler calibration -TEST( triangulation, twoPosesBundler) { - - boost::shared_ptr bundlerCal = // +TEST(triangulation, twoPosesBundler) { + boost::shared_ptr bundlerCal = // boost::make_shared(1500, 0, 0, 640, 480); PinholeCamera camera1(pose1, *bundlerCal); PinholeCamera camera2(pose2, *bundlerCal); @@ -116,37 +116,38 @@ TEST( triangulation, twoPosesBundler) { bool optimize = true; double rank_tol = 1e-9; - boost::optional actual = // - triangulatePoint3(poses, bundlerCal, measurements, rank_tol, optimize); + boost::optional actual = // + triangulatePoint3(poses, bundlerCal, measurements, rank_tol, optimize); EXPECT(assert_equal(landmark, *actual, 1e-7)); // Add some noise and try again measurements.at(0) += Point2(0.1, 0.5); measurements.at(1) += Point2(-0.2, 0.3); - boost::optional actual2 = // - triangulatePoint3(poses, bundlerCal, measurements, rank_tol, optimize); + boost::optional actual2 = // + triangulatePoint3(poses, bundlerCal, measurements, rank_tol, optimize); EXPECT(assert_equal(Point3(4.995, 0.499167, 1.19847), *actual2, 1e-4)); } //****************************************************************************** -TEST( triangulation, fourPoses) { +TEST(triangulation, fourPoses) { vector poses; Point2Vector measurements; poses += pose1, pose2; measurements += z1, z2; - boost::optional actual = triangulatePoint3(poses, sharedCal, - measurements); + boost::optional actual = + triangulatePoint3(poses, sharedCal, measurements); EXPECT(assert_equal(landmark, *actual, 1e-2)); - // 2. Add some noise and try again: result should be ~ (4.995, 0.499167, 1.19814) + // 2. Add some noise and try again: result should be ~ (4.995, + // 0.499167, 1.19814) measurements.at(0) += Point2(0.1, 0.5); measurements.at(1) += Point2(-0.2, 0.3); - boost::optional actual2 = // - triangulatePoint3(poses, sharedCal, measurements); + boost::optional actual2 = // + triangulatePoint3(poses, sharedCal, measurements); EXPECT(assert_equal(landmark, *actual2, 1e-2)); // 3. Add a slightly rotated third camera above, again with measurement noise @@ -157,13 +158,13 @@ TEST( triangulation, fourPoses) { poses += pose3; measurements += z3 + Point2(0.1, -0.1); - boost::optional triangulated_3cameras = // - triangulatePoint3(poses, sharedCal, measurements); + boost::optional triangulated_3cameras = // + triangulatePoint3(poses, sharedCal, measurements); EXPECT(assert_equal(landmark, *triangulated_3cameras, 1e-2)); // Again with nonlinear optimization - boost::optional triangulated_3cameras_opt = triangulatePoint3(poses, - sharedCal, measurements, 1e-9, true); + boost::optional triangulated_3cameras_opt = + triangulatePoint3(poses, sharedCal, measurements, 1e-9, true); EXPECT(assert_equal(landmark, *triangulated_3cameras_opt, 1e-2)); // 4. Test failure: Add a 4th camera facing the wrong way @@ -176,13 +177,101 @@ TEST( triangulation, fourPoses) { poses += pose4; measurements += Point2(400, 400); - CHECK_EXCEPTION(triangulatePoint3(poses, sharedCal, measurements), - TriangulationCheiralityException); + CHECK_EXCEPTION(triangulatePoint3(poses, sharedCal, measurements), + TriangulationCheiralityException); #endif } //****************************************************************************** -TEST( triangulation, fourPoses_distinct_Ks) { +TEST(triangulation, threePoses_robustNoiseModel) { + + Pose3 pose3 = pose1 * Pose3(Rot3::Ypr(0.1, 0.2, 0.1), Point3(0.1, -2, -.1)); + PinholeCamera camera3(pose3, *sharedCal); + Point2 z3 = camera3.project(landmark); + + vector poses; + Point2Vector measurements; + poses += pose1, pose2, pose3; + measurements += z1, z2, z3; + + // noise free, so should give exactly the landmark + boost::optional actual = + triangulatePoint3(poses, sharedCal, measurements); + EXPECT(assert_equal(landmark, *actual, 1e-2)); + + // Add outlier + measurements.at(0) += Point2(100, 120); // very large pixel noise! + + // now estimate does not match landmark + boost::optional actual2 = // + triangulatePoint3(poses, sharedCal, measurements); + // DLT is surprisingly robust, but still off (actual error is around 0.26m): + EXPECT( (landmark - *actual2).norm() >= 0.2); + EXPECT( (landmark - *actual2).norm() <= 0.5); + + // Again with nonlinear optimization + boost::optional actual3 = + triangulatePoint3(poses, sharedCal, measurements, 1e-9, true); + // result from nonlinear (but non-robust optimization) is close to DLT and still off + EXPECT(assert_equal(*actual2, *actual3, 0.1)); + + // Again with nonlinear optimization, this time with robust loss + auto model = noiseModel::Robust::Create( + noiseModel::mEstimator::Huber::Create(1.345), noiseModel::Unit::Create(2)); + boost::optional actual4 = triangulatePoint3( + poses, sharedCal, measurements, 1e-9, true, model); + // using the Huber loss we now have a quite small error!! nice! + EXPECT(assert_equal(landmark, *actual4, 0.05)); +} + +//****************************************************************************** +TEST(triangulation, fourPoses_robustNoiseModel) { + + Pose3 pose3 = pose1 * Pose3(Rot3::Ypr(0.1, 0.2, 0.1), Point3(0.1, -2, -.1)); + PinholeCamera camera3(pose3, *sharedCal); + Point2 z3 = camera3.project(landmark); + + vector poses; + Point2Vector measurements; + poses += pose1, pose1, pose2, pose3; // 2 measurements from pose 1 + measurements += z1, z1, z2, z3; + + // noise free, so should give exactly the landmark + boost::optional actual = + triangulatePoint3(poses, sharedCal, measurements); + EXPECT(assert_equal(landmark, *actual, 1e-2)); + + // Add outlier + measurements.at(0) += Point2(100, 120); // very large pixel noise! + // add noise on other measurements: + measurements.at(1) += Point2(0.1, 0.2); // small noise + measurements.at(2) += Point2(0.2, 0.2); + measurements.at(3) += Point2(0.3, 0.1); + + // now estimate does not match landmark + boost::optional actual2 = // + triangulatePoint3(poses, sharedCal, measurements); + // DLT is surprisingly robust, but still off (actual error is around 0.17m): + EXPECT( (landmark - *actual2).norm() >= 0.1); + EXPECT( (landmark - *actual2).norm() <= 0.5); + + // Again with nonlinear optimization + boost::optional actual3 = + triangulatePoint3(poses, sharedCal, measurements, 1e-9, true); + // result from nonlinear (but non-robust optimization) is close to DLT and still off + EXPECT(assert_equal(*actual2, *actual3, 0.1)); + + // Again with nonlinear optimization, this time with robust loss + auto model = noiseModel::Robust::Create( + noiseModel::mEstimator::Huber::Create(1.345), noiseModel::Unit::Create(2)); + boost::optional actual4 = triangulatePoint3( + poses, sharedCal, measurements, 1e-9, true, model); + // using the Huber loss we now have a quite small error!! nice! + EXPECT(assert_equal(landmark, *actual4, 0.05)); +} + +//****************************************************************************** +TEST(triangulation, fourPoses_distinct_Ks) { Cal3_S2 K1(1500, 1200, 0, 640, 480); // create first camera. Looking along X-axis, 1 meter above ground plane (x-y) PinholeCamera camera1(pose1, K1); @@ -195,22 +284,23 @@ TEST( triangulation, fourPoses_distinct_Ks) { Point2 z1 = camera1.project(landmark); Point2 z2 = camera2.project(landmark); - CameraSet > cameras; + CameraSet> cameras; Point2Vector measurements; cameras += camera1, camera2; measurements += z1, z2; - boost::optional actual = // - triangulatePoint3(cameras, measurements); + boost::optional actual = // + triangulatePoint3(cameras, measurements); EXPECT(assert_equal(landmark, *actual, 1e-2)); - // 2. Add some noise and try again: result should be ~ (4.995, 0.499167, 1.19814) + // 2. Add some noise and try again: result should be ~ (4.995, + // 0.499167, 1.19814) measurements.at(0) += Point2(0.1, 0.5); measurements.at(1) += Point2(-0.2, 0.3); - boost::optional actual2 = // - triangulatePoint3(cameras, measurements); + boost::optional actual2 = // + triangulatePoint3(cameras, measurements); EXPECT(assert_equal(landmark, *actual2, 1e-2)); // 3. Add a slightly rotated third camera above, again with measurement noise @@ -222,13 +312,13 @@ TEST( triangulation, fourPoses_distinct_Ks) { cameras += camera3; measurements += z3 + Point2(0.1, -0.1); - boost::optional triangulated_3cameras = // - triangulatePoint3(cameras, measurements); + boost::optional triangulated_3cameras = // + triangulatePoint3(cameras, measurements); EXPECT(assert_equal(landmark, *triangulated_3cameras, 1e-2)); // Again with nonlinear optimization - boost::optional triangulated_3cameras_opt = triangulatePoint3(cameras, - measurements, 1e-9, true); + boost::optional triangulated_3cameras_opt = + triangulatePoint3(cameras, measurements, 1e-9, true); EXPECT(assert_equal(landmark, *triangulated_3cameras_opt, 1e-2)); // 4. Test failure: Add a 4th camera facing the wrong way @@ -241,13 +331,13 @@ TEST( triangulation, fourPoses_distinct_Ks) { cameras += camera4; measurements += Point2(400, 400); - CHECK_EXCEPTION(triangulatePoint3(cameras, measurements), - TriangulationCheiralityException); + CHECK_EXCEPTION(triangulatePoint3(cameras, measurements), + TriangulationCheiralityException); #endif } //****************************************************************************** -TEST( triangulation, outliersAndFarLandmarks) { +TEST(triangulation, outliersAndFarLandmarks) { Cal3_S2 K1(1500, 1200, 0, 640, 480); // create first camera. Looking along X-axis, 1 meter above ground plane (x-y) PinholeCamera camera1(pose1, K1); @@ -260,24 +350,29 @@ TEST( triangulation, outliersAndFarLandmarks) { Point2 z1 = camera1.project(landmark); Point2 z2 = camera2.project(landmark); - CameraSet > cameras; + CameraSet> cameras; Point2Vector measurements; cameras += camera1, camera2; measurements += z1, z2; - double landmarkDistanceThreshold = 10; // landmark is closer than that - TriangulationParameters params(1.0, false, landmarkDistanceThreshold); // all default except landmarkDistanceThreshold - TriangulationResult actual = triangulateSafe(cameras,measurements,params); + double landmarkDistanceThreshold = 10; // landmark is closer than that + TriangulationParameters params( + 1.0, false, landmarkDistanceThreshold); // all default except + // landmarkDistanceThreshold + TriangulationResult actual = triangulateSafe(cameras, measurements, params); EXPECT(assert_equal(landmark, *actual, 1e-2)); EXPECT(actual.valid()); - landmarkDistanceThreshold = 4; // landmark is farther than that - TriangulationParameters params2(1.0, false, landmarkDistanceThreshold); // all default except landmarkDistanceThreshold - actual = triangulateSafe(cameras,measurements,params2); + landmarkDistanceThreshold = 4; // landmark is farther than that + TriangulationParameters params2( + 1.0, false, landmarkDistanceThreshold); // all default except + // landmarkDistanceThreshold + actual = triangulateSafe(cameras, measurements, params2); EXPECT(actual.farPoint()); - // 3. Add a slightly rotated third camera above with a wrong measurement (OUTLIER) + // 3. Add a slightly rotated third camera above with a wrong measurement + // (OUTLIER) Pose3 pose3 = pose1 * Pose3(Rot3::Ypr(0.1, 0.2, 0.1), Point3(0.1, -2, -.1)); Cal3_S2 K3(700, 500, 0, 640, 480); PinholeCamera camera3(pose3, K3); @@ -286,21 +381,23 @@ TEST( triangulation, outliersAndFarLandmarks) { cameras += camera3; measurements += z3 + Point2(10, -10); - landmarkDistanceThreshold = 10; // landmark is closer than that - double outlierThreshold = 100; // loose, the outlier is going to pass - TriangulationParameters params3(1.0, false, landmarkDistanceThreshold,outlierThreshold); - actual = triangulateSafe(cameras,measurements,params3); + landmarkDistanceThreshold = 10; // landmark is closer than that + double outlierThreshold = 100; // loose, the outlier is going to pass + TriangulationParameters params3(1.0, false, landmarkDistanceThreshold, + outlierThreshold); + actual = triangulateSafe(cameras, measurements, params3); EXPECT(actual.valid()); // now set stricter threshold for outlier rejection - outlierThreshold = 5; // tighter, the outlier is not going to pass - TriangulationParameters params4(1.0, false, landmarkDistanceThreshold,outlierThreshold); - actual = triangulateSafe(cameras,measurements,params4); + outlierThreshold = 5; // tighter, the outlier is not going to pass + TriangulationParameters params4(1.0, false, landmarkDistanceThreshold, + outlierThreshold); + actual = triangulateSafe(cameras, measurements, params4); EXPECT(actual.outlier()); } //****************************************************************************** -TEST( triangulation, twoIdenticalPoses) { +TEST(triangulation, twoIdenticalPoses) { // create first camera. Looking along X-axis, 1 meter above ground plane (x-y) PinholeCamera camera1(pose1, *sharedCal); @@ -313,12 +410,12 @@ TEST( triangulation, twoIdenticalPoses) { poses += pose1, pose1; measurements += z1, z1; - CHECK_EXCEPTION(triangulatePoint3(poses, sharedCal, measurements), - TriangulationUnderconstrainedException); + CHECK_EXCEPTION(triangulatePoint3(poses, sharedCal, measurements), + TriangulationUnderconstrainedException); } //****************************************************************************** -TEST( triangulation, onePose) { +TEST(triangulation, onePose) { // we expect this test to fail with a TriangulationUnderconstrainedException // because there's only one camera observation @@ -326,28 +423,26 @@ TEST( triangulation, onePose) { Point2Vector measurements; poses += Pose3(); - measurements += Point2(0,0); + measurements += Point2(0, 0); - CHECK_EXCEPTION(triangulatePoint3(poses, sharedCal, measurements), - TriangulationUnderconstrainedException); + CHECK_EXCEPTION(triangulatePoint3(poses, sharedCal, measurements), + TriangulationUnderconstrainedException); } //****************************************************************************** -TEST( triangulation, StereotriangulateNonlinear ) { - - auto stereoK = boost::make_shared(1733.75, 1733.75, 0, 689.645, 508.835, 0.0699612); +TEST(triangulation, StereotriangulateNonlinear) { + auto stereoK = boost::make_shared(1733.75, 1733.75, 0, 689.645, + 508.835, 0.0699612); // two camera poses m1, m2 Matrix4 m1, m2; - m1 << 0.796888717, 0.603404026, -0.0295271487, 46.6673779, - 0.592783835, -0.77156583, 0.230856632, 66.2186159, - 0.116517574, -0.201470143, -0.9725393, -4.28382528, - 0, 0, 0, 1; + m1 << 0.796888717, 0.603404026, -0.0295271487, 46.6673779, 0.592783835, + -0.77156583, 0.230856632, 66.2186159, 0.116517574, -0.201470143, + -0.9725393, -4.28382528, 0, 0, 0, 1; - m2 << -0.955959025, -0.29288915, -0.0189328569, 45.7169799, - -0.29277519, 0.947083213, 0.131587097, 65.843136, - -0.0206094928, 0.131334858, -0.991123524, -4.3525033, - 0, 0, 0, 1; + m2 << -0.955959025, -0.29288915, -0.0189328569, 45.7169799, -0.29277519, + 0.947083213, 0.131587097, 65.843136, -0.0206094928, 0.131334858, + -0.991123524, -4.3525033, 0, 0, 0, 1; typedef CameraSet Cameras; Cameras cameras; @@ -358,18 +453,19 @@ TEST( triangulation, StereotriangulateNonlinear ) { measurements += StereoPoint2(226.936, 175.212, 424.469); measurements += StereoPoint2(339.571, 285.547, 669.973); - Point3 initial = Point3(46.0536958, 66.4621179, -6.56285929); // error: 96.5715555191 + Point3 initial = + Point3(46.0536958, 66.4621179, -6.56285929); // error: 96.5715555191 - Point3 actual = triangulateNonlinear(cameras, measurements, initial); + Point3 actual = triangulateNonlinear(cameras, measurements, initial); - Point3 expected(46.0484569, 66.4710686, -6.55046613); // error: 0.763510644187 + Point3 expected(46.0484569, 66.4710686, + -6.55046613); // error: 0.763510644187 EXPECT(assert_equal(expected, actual, 1e-4)); - // regular stereo factor comparison - expect very similar result as above { - typedef GenericStereoFactor StereoFactor; + typedef GenericStereoFactor StereoFactor; Values values; values.insert(Symbol('x', 1), Pose3(m1)); @@ -378,17 +474,19 @@ TEST( triangulation, StereotriangulateNonlinear ) { NonlinearFactorGraph graph; static SharedNoiseModel unit(noiseModel::Unit::Create(3)); - graph.emplace_shared(measurements[0], unit, Symbol('x',1), Symbol('l',1), stereoK); - graph.emplace_shared(measurements[1], unit, Symbol('x',2), Symbol('l',1), stereoK); + graph.emplace_shared(measurements[0], unit, Symbol('x', 1), + Symbol('l', 1), stereoK); + graph.emplace_shared(measurements[1], unit, Symbol('x', 2), + Symbol('l', 1), stereoK); const SharedDiagonal posePrior = noiseModel::Isotropic::Sigma(6, 1e-9); - graph.addPrior(Symbol('x',1), Pose3(m1), posePrior); - graph.addPrior(Symbol('x',2), Pose3(m2), posePrior); + graph.addPrior(Symbol('x', 1), Pose3(m1), posePrior); + graph.addPrior(Symbol('x', 2), Pose3(m2), posePrior); LevenbergMarquardtOptimizer optimizer(graph, values); Values result = optimizer.optimize(); - EXPECT(assert_equal(expected, result.at(Symbol('l',1)), 1e-4)); + EXPECT(assert_equal(expected, result.at(Symbol('l', 1)), 1e-4)); } // use Triangulation Factor directly - expect same result as above @@ -399,13 +497,15 @@ TEST( triangulation, StereotriangulateNonlinear ) { NonlinearFactorGraph graph; static SharedNoiseModel unit(noiseModel::Unit::Create(3)); - graph.emplace_shared >(cameras[0], measurements[0], unit, Symbol('l',1)); - graph.emplace_shared >(cameras[1], measurements[1], unit, Symbol('l',1)); + graph.emplace_shared>( + cameras[0], measurements[0], unit, Symbol('l', 1)); + graph.emplace_shared>( + cameras[1], measurements[1], unit, Symbol('l', 1)); LevenbergMarquardtOptimizer optimizer(graph, values); Values result = optimizer.optimize(); - EXPECT(assert_equal(expected, result.at(Symbol('l',1)), 1e-4)); + EXPECT(assert_equal(expected, result.at(Symbol('l', 1)), 1e-4)); } // use ExpressionFactor - expect same result as above @@ -416,11 +516,13 @@ TEST( triangulation, StereotriangulateNonlinear ) { NonlinearFactorGraph graph; static SharedNoiseModel unit(noiseModel::Unit::Create(3)); - Expression point_(Symbol('l',1)); + Expression point_(Symbol('l', 1)); Expression camera0_(cameras[0]); Expression camera1_(cameras[1]); - Expression project0_(camera0_, &StereoCamera::project2, point_); - Expression project1_(camera1_, &StereoCamera::project2, point_); + Expression project0_(camera0_, &StereoCamera::project2, + point_); + Expression project1_(camera1_, &StereoCamera::project2, + point_); graph.addExpressionFactor(unit, measurements[0], project0_); graph.addExpressionFactor(unit, measurements[1], project1_); @@ -428,10 +530,172 @@ TEST( triangulation, StereotriangulateNonlinear ) { LevenbergMarquardtOptimizer optimizer(graph, values); Values result = optimizer.optimize(); - EXPECT(assert_equal(expected, result.at(Symbol('l',1)), 1e-4)); + EXPECT(assert_equal(expected, result.at(Symbol('l', 1)), 1e-4)); } } +//****************************************************************************** +// Simple test with a well-behaved two camera situation +TEST(triangulation, twoPoses_sphericalCamera) { + vector poses; + std::vector measurements; + + // Project landmark into two cameras and triangulate + SphericalCamera cam1(pose1); + SphericalCamera cam2(pose2); + Unit3 u1 = cam1.project(landmark); + Unit3 u2 = cam2.project(landmark); + + poses += pose1, pose2; + measurements += u1, u2; + + CameraSet cameras; + cameras.push_back(cam1); + cameras.push_back(cam2); + + double rank_tol = 1e-9; + + // 1. Test linear triangulation via DLT + auto projection_matrices = projectionMatricesFromCameras(cameras); + Point3 point = triangulateDLT(projection_matrices, measurements, rank_tol); + EXPECT(assert_equal(landmark, point, 1e-7)); + + // 2. Test nonlinear triangulation + point = triangulateNonlinear(cameras, measurements, point); + EXPECT(assert_equal(landmark, point, 1e-7)); + + // 3. Test simple DLT, now within triangulatePoint3 + bool optimize = false; + boost::optional actual1 = // + triangulatePoint3(cameras, measurements, rank_tol, + optimize); + EXPECT(assert_equal(landmark, *actual1, 1e-7)); + + // 4. test with optimization on, same answer + optimize = true; + boost::optional actual2 = // + triangulatePoint3(cameras, measurements, rank_tol, + optimize); + EXPECT(assert_equal(landmark, *actual2, 1e-7)); + + // 5. Add some noise and try again: result should be ~ (4.995, + // 0.499167, 1.19814) + measurements.at(0) = + u1.retract(Vector2(0.01, 0.05)); // note: perturbation smaller for Unit3 + measurements.at(1) = u2.retract(Vector2(-0.02, 0.03)); + optimize = false; + boost::optional actual3 = // + triangulatePoint3(cameras, measurements, rank_tol, + optimize); + EXPECT(assert_equal(Point3(5.9432, 0.654319, 1.48192), *actual3, 1e-3)); + + // 6. Now with optimization on + optimize = true; + boost::optional actual4 = // + triangulatePoint3(cameras, measurements, rank_tol, + optimize); + EXPECT(assert_equal(Point3(5.9432, 0.654334, 1.48192), *actual4, 1e-3)); +} + +//****************************************************************************** +TEST(triangulation, twoPoses_sphericalCamera_extremeFOV) { + vector poses; + std::vector measurements; + + // Project landmark into two cameras and triangulate + Pose3 poseA = Pose3( + Rot3::Ypr(-M_PI / 2, 0., -M_PI / 2), + Point3(0.0, 0.0, 0.0)); // with z pointing along x axis of global frame + Pose3 poseB = Pose3(Rot3::Ypr(-M_PI / 2, 0., -M_PI / 2), + Point3(2.0, 0.0, 0.0)); // 2m in front of poseA + Point3 landmarkL( + 1.0, -1.0, + 0.0); // 1m to the right of both cameras, in front of poseA, behind poseB + SphericalCamera cam1(poseA); + SphericalCamera cam2(poseB); + Unit3 u1 = cam1.project(landmarkL); + Unit3 u2 = cam2.project(landmarkL); + + EXPECT(assert_equal(Unit3(Point3(1.0, 0.0, 1.0)), u1, + 1e-7)); // in front and to the right of PoseA + EXPECT(assert_equal(Unit3(Point3(1.0, 0.0, -1.0)), u2, + 1e-7)); // behind and to the right of PoseB + + poses += pose1, pose2; + measurements += u1, u2; + + CameraSet cameras; + cameras.push_back(cam1); + cameras.push_back(cam2); + + double rank_tol = 1e-9; + + { + // 1. Test simple DLT, when 1 point is behind spherical camera + bool optimize = false; +#ifdef GTSAM_THROW_CHEIRALITY_EXCEPTION + CHECK_EXCEPTION(triangulatePoint3(cameras, measurements, + rank_tol, optimize), + TriangulationCheiralityException); +#else // otherwise project should not throw the exception + boost::optional actual1 = // + triangulatePoint3(cameras, measurements, rank_tol, + optimize); + EXPECT(assert_equal(landmarkL, *actual1, 1e-7)); +#endif + } + { + // 2. test with optimization on, same answer + bool optimize = true; +#ifdef GTSAM_THROW_CHEIRALITY_EXCEPTION + CHECK_EXCEPTION(triangulatePoint3(cameras, measurements, + rank_tol, optimize), + TriangulationCheiralityException); +#else // otherwise project should not throw the exception + boost::optional actual1 = // + triangulatePoint3(cameras, measurements, rank_tol, + optimize); + EXPECT(assert_equal(landmarkL, *actual1, 1e-7)); +#endif + } +} + +//****************************************************************************** +TEST(triangulation, reprojectionError_cameraComparison) { + Pose3 poseA = Pose3( + Rot3::Ypr(-M_PI / 2, 0., -M_PI / 2), + Point3(0.0, 0.0, 0.0)); // with z pointing along x axis of global frame + Point3 landmarkL(5.0, 0.0, 0.0); // 1m in front of poseA + SphericalCamera sphericalCamera(poseA); + Unit3 u = sphericalCamera.project(landmarkL); + + static Cal3_S2::shared_ptr sharedK(new Cal3_S2(60, 640, 480)); + PinholePose pinholeCamera(poseA, sharedK); + Vector2 px = pinholeCamera.project(landmarkL); + + // add perturbation and compare error in both cameras + Vector2 px_noise(1.0, 2.0); // px perturbation vertically and horizontally + Vector2 measured_px = px + px_noise; + Vector2 measured_px_calibrated = sharedK->calibrate(measured_px); + Unit3 measured_u = + Unit3(measured_px_calibrated[0], measured_px_calibrated[1], 1.0); + Unit3 expected_measured_u = + Unit3(px_noise[0] / sharedK->fx(), px_noise[1] / sharedK->fy(), 1.0); + EXPECT(assert_equal(expected_measured_u, measured_u, 1e-7)); + + Vector2 actualErrorPinhole = + pinholeCamera.reprojectionError(landmarkL, measured_px); + Vector2 expectedErrorPinhole = Vector2(-px_noise[0], -px_noise[1]); + EXPECT(assert_equal(expectedErrorPinhole, actualErrorPinhole, + 1e-7)); //- sign due to definition of error + + Vector2 actualErrorSpherical = + sphericalCamera.reprojectionError(landmarkL, measured_u); + // expectedError: not easy to calculate, since it involves the unit3 basis + Vector2 expectedErrorSpherical(-0.00360842, 0.00180419); + EXPECT(assert_equal(expectedErrorSpherical, actualErrorSpherical, 1e-7)); +} + //****************************************************************************** int main() { TestResult tr; diff --git a/gtsam/geometry/triangulation.cpp b/gtsam/geometry/triangulation.cpp index a5d2e04cd..026afef24 100644 --- a/gtsam/geometry/triangulation.cpp +++ b/gtsam/geometry/triangulation.cpp @@ -53,15 +53,57 @@ Vector4 triangulateHomogeneousDLT( return v; } -Point3 triangulateDLT(const std::vector>& projection_matrices, +Vector4 triangulateHomogeneousDLT( + const std::vector>& projection_matrices, + const std::vector& measurements, double rank_tol) { + + // number of cameras + size_t m = projection_matrices.size(); + + // Allocate DLT matrix + Matrix A = Matrix::Zero(m * 2, 4); + + for (size_t i = 0; i < m; i++) { + size_t row = i * 2; + const Matrix34& projection = projection_matrices.at(i); + const Point3& p = measurements.at(i).point3(); // to get access to x,y,z of the bearing vector + + // build system of equations + A.row(row) = p.x() * projection.row(2) - p.z() * projection.row(0); + A.row(row + 1) = p.y() * projection.row(2) - p.z() * projection.row(1); + } + int rank; + double error; + Vector v; + boost::tie(rank, error, v) = DLT(A, rank_tol); + + if (rank < 3) + throw(TriangulationUnderconstrainedException()); + + return v; +} + +Point3 triangulateDLT( + const std::vector>& projection_matrices, const Point2Vector& measurements, double rank_tol) { - Vector4 v = triangulateHomogeneousDLT(projection_matrices, measurements, rank_tol); - + Vector4 v = triangulateHomogeneousDLT(projection_matrices, measurements, + rank_tol); // Create 3D point from homogeneous coordinates return Point3(v.head<3>() / v[3]); } +Point3 triangulateDLT( + const std::vector>& projection_matrices, + const std::vector& measurements, double rank_tol) { + + // contrary to previous triangulateDLT, this is now taking Unit3 inputs + Vector4 v = triangulateHomogeneousDLT(projection_matrices, measurements, + rank_tol); + // Create 3D point from homogeneous coordinates + return Point3(v.head<3>() / v[3]); +} + /// /** * Optimize for triangulation @@ -71,7 +113,7 @@ Point3 triangulateDLT(const std::vector #include #include +#include #include #include #include @@ -59,6 +61,18 @@ GTSAM_EXPORT Vector4 triangulateHomogeneousDLT( const std::vector>& projection_matrices, const Point2Vector& measurements, double rank_tol = 1e-9); +/** + * Same math as Hartley and Zisserman, 2nd Ed., page 312, but with unit-norm bearing vectors + * (contrarily to pinhole projection, the z entry is not assumed to be 1 as in Hartley and Zisserman) + * @param projection_matrices Projection matrices (K*P^-1) + * @param measurements Unit3 bearing measurements + * @param rank_tol SVD rank tolerance + * @return Triangulated point, in homogeneous coordinates + */ +GTSAM_EXPORT Vector4 triangulateHomogeneousDLT( + const std::vector>& projection_matrices, + const std::vector& measurements, double rank_tol = 1e-9); + /** * DLT triangulation: See Hartley and Zisserman, 2nd Ed., page 312 * @param projection_matrices Projection matrices (K*P^-1) @@ -71,6 +85,14 @@ GTSAM_EXPORT Point3 triangulateDLT( const Point2Vector& measurements, double rank_tol = 1e-9); +/** + * overload of previous function to work with Unit3 (projected to canonical camera) + */ +GTSAM_EXPORT Point3 triangulateDLT( + const std::vector>& projection_matrices, + const std::vector& measurements, + double rank_tol = 1e-9); + /** * Create a factor graph with projection factors from poses and one calibration * @param poses Camera poses @@ -84,18 +106,18 @@ template std::pair triangulationGraph( const std::vector& poses, boost::shared_ptr sharedCal, const Point2Vector& measurements, Key landmarkKey, - const Point3& initialEstimate) { + const Point3& initialEstimate, + const SharedNoiseModel& model = nullptr) { Values values; values.insert(landmarkKey, initialEstimate); // Initial landmark value NonlinearFactorGraph graph; static SharedNoiseModel unit2(noiseModel::Unit::Create(2)); - static SharedNoiseModel prior_model(noiseModel::Isotropic::Sigma(6, 1e-6)); for (size_t i = 0; i < measurements.size(); i++) { const Pose3& pose_i = poses[i]; typedef PinholePose Camera; Camera camera_i(pose_i, sharedCal); graph.emplace_shared > // - (camera_i, measurements[i], unit2, landmarkKey); + (camera_i, measurements[i], model? model : unit2, landmarkKey); } return std::make_pair(graph, values); } @@ -113,7 +135,8 @@ template std::pair triangulationGraph( const CameraSet& cameras, const typename CAMERA::MeasurementVector& measurements, Key landmarkKey, - const Point3& initialEstimate) { + const Point3& initialEstimate, + const SharedNoiseModel& model = nullptr) { Values values; values.insert(landmarkKey, initialEstimate); // Initial landmark value NonlinearFactorGraph graph; @@ -122,7 +145,7 @@ std::pair triangulationGraph( for (size_t i = 0; i < measurements.size(); i++) { const CAMERA& camera_i = cameras[i]; graph.emplace_shared > // - (camera_i, measurements[i], unit, landmarkKey); + (camera_i, measurements[i], model? model : unit, landmarkKey); } return std::make_pair(graph, values); } @@ -148,13 +171,14 @@ GTSAM_EXPORT Point3 optimize(const NonlinearFactorGraph& graph, template Point3 triangulateNonlinear(const std::vector& poses, boost::shared_ptr sharedCal, - const Point2Vector& measurements, const Point3& initialEstimate) { + const Point2Vector& measurements, const Point3& initialEstimate, + const SharedNoiseModel& model = nullptr) { // Create a factor graph and initial values Values values; NonlinearFactorGraph graph; boost::tie(graph, values) = triangulationGraph // - (poses, sharedCal, measurements, Symbol('p', 0), initialEstimate); + (poses, sharedCal, measurements, Symbol('p', 0), initialEstimate, model); return optimize(graph, values, Symbol('p', 0)); } @@ -169,37 +193,39 @@ Point3 triangulateNonlinear(const std::vector& poses, template Point3 triangulateNonlinear( const CameraSet& cameras, - const typename CAMERA::MeasurementVector& measurements, const Point3& initialEstimate) { + const typename CAMERA::MeasurementVector& measurements, const Point3& initialEstimate, + const SharedNoiseModel& model = nullptr) { // Create a factor graph and initial values Values values; NonlinearFactorGraph graph; boost::tie(graph, values) = triangulationGraph // - (cameras, measurements, Symbol('p', 0), initialEstimate); + (cameras, measurements, Symbol('p', 0), initialEstimate, model); return optimize(graph, values, Symbol('p', 0)); } -/** - * Create a 3*4 camera projection matrix from calibration and pose. - * Functor for partial application on calibration - * @param pose The camera pose - * @param cal The calibration - * @return Returns a Matrix34 - */ +template +std::vector> +projectionMatricesFromCameras(const CameraSet &cameras) { + std::vector> projection_matrices; + for (const CAMERA &camera: cameras) { + projection_matrices.push_back(camera.cameraProjectionMatrix()); + } + return projection_matrices; +} + +// overload, assuming pinholePose template -struct CameraProjectionMatrix { - CameraProjectionMatrix(const CALIBRATION& calibration) : - K_(calibration.K()) { +std::vector> projectionMatricesFromPoses( + const std::vector &poses, boost::shared_ptr sharedCal) { + std::vector> projection_matrices; + for (size_t i = 0; i < poses.size(); i++) { + PinholePose camera(poses.at(i), sharedCal); + projection_matrices.push_back(camera.cameraProjectionMatrix()); } - Matrix34 operator()(const Pose3& pose) const { - return K_ * (pose.inverse().matrix()).block<3, 4>(0, 0); - } -private: - const Matrix3 K_; -public: - GTSAM_MAKE_ALIGNED_OPERATOR_NEW -}; + return projection_matrices; +} /** * Function to triangulate 3D landmark point from an arbitrary number @@ -217,17 +243,15 @@ template Point3 triangulatePoint3(const std::vector& poses, boost::shared_ptr sharedCal, const Point2Vector& measurements, double rank_tol = 1e-9, - bool optimize = false) { + bool optimize = false, + const SharedNoiseModel& model = nullptr) { assert(poses.size() == measurements.size()); if (poses.size() < 2) throw(TriangulationUnderconstrainedException()); // construct projection matrices from poses & calibration - std::vector> projection_matrices; - CameraProjectionMatrix createP(*sharedCal); // partially apply - for(const Pose3& pose: poses) - projection_matrices.push_back(createP(pose)); + auto projection_matrices = projectionMatricesFromPoses(poses, sharedCal); // Triangulate linearly Point3 point = triangulateDLT(projection_matrices, measurements, rank_tol); @@ -235,7 +259,7 @@ Point3 triangulatePoint3(const std::vector& poses, // Then refine using non-linear optimization if (optimize) point = triangulateNonlinear // - (poses, sharedCal, measurements, point); + (poses, sharedCal, measurements, point, model); #ifdef GTSAM_THROW_CHEIRALITY_EXCEPTION // verify that the triangulated point lies in front of all cameras @@ -265,7 +289,8 @@ template Point3 triangulatePoint3( const CameraSet& cameras, const typename CAMERA::MeasurementVector& measurements, double rank_tol = 1e-9, - bool optimize = false) { + bool optimize = false, + const SharedNoiseModel& model = nullptr) { size_t m = cameras.size(); assert(measurements.size() == m); @@ -274,16 +299,12 @@ Point3 triangulatePoint3( throw(TriangulationUnderconstrainedException()); // construct projection matrices from poses & calibration - std::vector> projection_matrices; - for(const CAMERA& camera: cameras) - projection_matrices.push_back( - CameraProjectionMatrix(camera.calibration())( - camera.pose())); + auto projection_matrices = projectionMatricesFromCameras(cameras); Point3 point = triangulateDLT(projection_matrices, measurements, rank_tol); // The n refine using non-linear optimization if (optimize) - point = triangulateNonlinear(cameras, measurements, point); + point = triangulateNonlinear(cameras, measurements, point, model); #ifdef GTSAM_THROW_CHEIRALITY_EXCEPTION // verify that the triangulated point lies in front of all cameras @@ -302,9 +323,10 @@ template Point3 triangulatePoint3( const CameraSet >& cameras, const Point2Vector& measurements, double rank_tol = 1e-9, - bool optimize = false) { + bool optimize = false, + const SharedNoiseModel& model = nullptr) { return triangulatePoint3 > // - (cameras, measurements, rank_tol, optimize); + (cameras, measurements, rank_tol, optimize, model); } struct GTSAM_EXPORT TriangulationParameters { @@ -326,20 +348,25 @@ struct GTSAM_EXPORT TriangulationParameters { */ double dynamicOutlierRejectionThreshold; + SharedNoiseModel noiseModel; ///< used in the nonlinear triangulation + /** * Constructor * @param rankTol tolerance used to check if point triangulation is degenerate * @param enableEPI if true refine triangulation with embedded LM iterations * @param landmarkDistanceThreshold flag as degenerate if point further than this * @param dynamicOutlierRejectionThreshold or if average error larger than this + * @param noiseModel noise model to use during nonlinear triangulation * */ TriangulationParameters(const double _rankTolerance = 1.0, const bool _enableEPI = false, double _landmarkDistanceThreshold = -1, - double _dynamicOutlierRejectionThreshold = -1) : + double _dynamicOutlierRejectionThreshold = -1, + const SharedNoiseModel& _noiseModel = nullptr) : rankTolerance(_rankTolerance), enableEPI(_enableEPI), // landmarkDistanceThreshold(_landmarkDistanceThreshold), // - dynamicOutlierRejectionThreshold(_dynamicOutlierRejectionThreshold) { + dynamicOutlierRejectionThreshold(_dynamicOutlierRejectionThreshold), + noiseModel(_noiseModel){ } // stream to output @@ -351,6 +378,7 @@ struct GTSAM_EXPORT TriangulationParameters { << std::endl; os << "dynamicOutlierRejectionThreshold = " << p.dynamicOutlierRejectionThreshold << std::endl; + os << "noise model" << std::endl; return os; } @@ -453,8 +481,9 @@ TriangulationResult triangulateSafe(const CameraSet& cameras, else // We triangulate the 3D position of the landmark try { - Point3 point = triangulatePoint3(cameras, measured, - params.rankTolerance, params.enableEPI); + Point3 point = + triangulatePoint3(cameras, measured, params.rankTolerance, + params.enableEPI, params.noiseModel); // Check landmark distance and re-projection errors to avoid outliers size_t i = 0; @@ -474,8 +503,8 @@ TriangulationResult triangulateSafe(const CameraSet& cameras, #endif // Check reprojection error if (params.dynamicOutlierRejectionThreshold > 0) { - const Point2& zi = measured.at(i); - Point2 reprojectionError(camera.project(point) - zi); + const typename CAMERA::Measurement& zi = measured.at(i); + Point2 reprojectionError = camera.reprojectionError(point, zi); maxReprojError = std::max(maxReprojError, reprojectionError.norm()); } i += 1; @@ -503,6 +532,6 @@ using CameraSetCal3Bundler = CameraSet>; using CameraSetCal3_S2 = CameraSet>; using CameraSetCal3Fisheye = CameraSet>; using CameraSetCal3Unified = CameraSet>; - +using CameraSetSpherical = CameraSet; } // \namespace gtsam diff --git a/gtsam/gtsam.i b/gtsam/gtsam.i index 67c3278a3..d4e959c3d 100644 --- a/gtsam/gtsam.i +++ b/gtsam/gtsam.i @@ -39,9 +39,6 @@ class KeyList { void remove(size_t key); void serialize() const; - - // enable pickling in python - void pickle() const; }; // Actually a FastSet @@ -67,9 +64,6 @@ class KeySet { bool count(size_t key) const; // returns true if value exists void serialize() const; - - // enable pickling in python - void pickle() const; }; // Actually a vector @@ -91,9 +85,6 @@ class KeyVector { void push_back(size_t key) const; void serialize() const; - - // enable pickling in python - void pickle() const; }; // Actually a FastMap @@ -165,6 +156,7 @@ gtsam::Values allPose2s(gtsam::Values& values); Matrix extractPose2(const gtsam::Values& values); gtsam::Values allPose3s(gtsam::Values& values); Matrix extractPose3(const gtsam::Values& values); +Matrix extractVectors(const gtsam::Values& values, char c); void perturbPoint2(gtsam::Values& values, double sigma, int seed = 42u); void perturbPose2(gtsam::Values& values, double sigmaT, double sigmaR, int seed = 42u); diff --git a/gtsam/inference/BayesNet-inst.h b/gtsam/inference/BayesNet-inst.h index a73762258..afde5498d 100644 --- a/gtsam/inference/BayesNet-inst.h +++ b/gtsam/inference/BayesNet-inst.h @@ -10,46 +10,76 @@ * -------------------------------------------------------------------------- */ /** -* @file BayesNet.h -* @brief Bayes network -* @author Frank Dellaert -* @author Richard Roberts -*/ + * @file BayesNet.h + * @brief Bayes network + * @author Frank Dellaert + * @author Richard Roberts + */ #pragma once -#include #include +#include #include #include +#include namespace gtsam { /* ************************************************************************* */ template -void BayesNet::print( - const std::string& s, const KeyFormatter& formatter) const { +void BayesNet::print(const std::string& s, + const KeyFormatter& formatter) const { Base::print(s, formatter); } /* ************************************************************************* */ template -void BayesNet::saveGraph(const std::string& s, - const KeyFormatter& keyFormatter) const { - std::ofstream of(s.c_str()); - of << "digraph G{\n"; +void BayesNet::dot(std::ostream& os, + const KeyFormatter& keyFormatter, + const DotWriter& writer) const { + writer.digraphPreamble(&os); + // Create nodes for each variable in the graph + for (Key key : this->keys()) { + auto position = writer.variablePos(key); + writer.drawVariable(key, keyFormatter, position, &os); + } + os << "\n"; + + // Reverse order as typically Bayes nets stored in reverse topological sort. for (auto conditional : boost::adaptors::reverse(*this)) { - typename CONDITIONAL::Frontals frontals = conditional->frontals(); - Key me = frontals.front(); - typename CONDITIONAL::Parents parents = conditional->parents(); - for (Key p : parents) - of << keyFormatter(p) << "->" << keyFormatter(me) << std::endl; + auto frontals = conditional->frontals(); + const Key me = frontals.front(); + auto parents = conditional->parents(); + for (const Key& p : parents) + os << " var" << keyFormatter(p) << "->var" << keyFormatter(me) << "\n"; } - of << "}"; + os << "}"; + std::flush(os); +} + +/* ************************************************************************* */ +template +std::string BayesNet::dot(const KeyFormatter& keyFormatter, + const DotWriter& writer) const { + std::stringstream ss; + dot(ss, keyFormatter, writer); + return ss.str(); +} + +/* ************************************************************************* */ +template +void BayesNet::saveGraph(const std::string& filename, + const KeyFormatter& keyFormatter, + const DotWriter& writer) const { + std::ofstream of(filename.c_str()); + dot(of, keyFormatter, writer); of.close(); } +/* ************************************************************************* */ + } // namespace gtsam diff --git a/gtsam/inference/BayesNet.h b/gtsam/inference/BayesNet.h index 938278d5a..219864c54 100644 --- a/gtsam/inference/BayesNet.h +++ b/gtsam/inference/BayesNet.h @@ -10,67 +10,79 @@ * -------------------------------------------------------------------------- */ /** -* @file BayesNet.h -* @brief Bayes network -* @author Frank Dellaert -* @author Richard Roberts -*/ + * @file BayesNet.h + * @brief Bayes network + * @author Frank Dellaert + * @author Richard Roberts + */ #pragma once -#include - #include +#include +#include + namespace gtsam { - /** - * A BayesNet is a tree of conditionals, stored in elimination order. - * - * todo: how to handle Bayes nets with an optimize function? Currently using global functions. - * \nosubgrouping - */ - template - class BayesNet : public FactorGraph { +/** + * A BayesNet is a tree of conditionals, stored in elimination order. + * @addtogroup inference + */ +template +class BayesNet : public FactorGraph { + private: + typedef FactorGraph Base; - private: + public: + typedef typename boost::shared_ptr + sharedConditional; ///< A shared pointer to a conditional - typedef FactorGraph Base; + protected: + /// @name Standard Constructors + /// @{ - public: - typedef typename boost::shared_ptr sharedConditional; ///< A shared pointer to a conditional + /** Default constructor as an empty BayesNet */ + BayesNet() {} - protected: - /// @name Standard Constructors - /// @{ + /** Construct from iterator over conditionals */ + template + BayesNet(ITERATOR firstConditional, ITERATOR lastConditional) + : Base(firstConditional, lastConditional) {} - /** Default constructor as an empty BayesNet */ - BayesNet() {}; + /// @} - /** Construct from iterator over conditionals */ - template - BayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {} + public: + /// @name Testable + /// @{ - /// @} + /** print out graph */ + void print( + const std::string& s = "BayesNet", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; - public: - /// @name Testable - /// @{ + /// @} - /** print out graph */ - void print( - const std::string& s = "BayesNet", - const KeyFormatter& formatter = DefaultKeyFormatter) const override; + /// @name Graph Display + /// @{ - /// @} + /// Output to graphviz format, stream version. + void dot(std::ostream& os, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DotWriter& writer = DotWriter()) const; - /// @name Standard Interface - /// @{ + /// Output to graphviz format string. + std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DotWriter& writer = DotWriter()) const; - void saveGraph(const std::string& s, - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; - }; + /// output to file with graphviz format. + void saveGraph(const std::string& filename, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DotWriter& writer = DotWriter()) const; -} + /// @} +}; + +} // namespace gtsam #include diff --git a/gtsam/inference/BayesTree-inst.h b/gtsam/inference/BayesTree-inst.h index 5b53a5719..9b937fefb 100644 --- a/gtsam/inference/BayesTree-inst.h +++ b/gtsam/inference/BayesTree-inst.h @@ -63,20 +63,40 @@ namespace gtsam { } /* ************************************************************************* */ - template - void BayesTree::saveGraph(const std::string &s, const KeyFormatter& keyFormatter) const { - if (roots_.empty()) throw std::invalid_argument("the root of Bayes tree has not been initialized!"); - std::ofstream of(s.c_str()); - of<< "digraph G{\n"; - for(const sharedClique& root: roots_) - saveGraph(of, root, keyFormatter); - of<<"}"; + template + void BayesTree::dot(std::ostream& os, + const KeyFormatter& keyFormatter) const { + if (roots_.empty()) + throw std::invalid_argument( + "the root of Bayes tree has not been initialized!"); + os << "digraph G{\n"; + for (const sharedClique& root : roots_) dot(os, root, keyFormatter); + os << "}"; + std::flush(os); + } + + /* ************************************************************************* */ + template + std::string BayesTree::dot(const KeyFormatter& keyFormatter) const { + std::stringstream ss; + dot(ss, keyFormatter); + return ss.str(); + } + + /* ************************************************************************* */ + template + void BayesTree::saveGraph(const std::string& filename, + const KeyFormatter& keyFormatter) const { + std::ofstream of(filename.c_str()); + dot(of, keyFormatter); of.close(); } /* ************************************************************************* */ - template - void BayesTree::saveGraph(std::ostream &s, sharedClique clique, const KeyFormatter& indexFormatter, int parentnum) const { + template + void BayesTree::dot(std::ostream& s, sharedClique clique, + const KeyFormatter& indexFormatter, + int parentnum) const { static int num = 0; bool first = true; std::stringstream out; @@ -107,7 +127,7 @@ namespace gtsam { for (sharedClique c : clique->children) { num++; - saveGraph(s, c, indexFormatter, parentnum); + dot(s, c, indexFormatter, parentnum); } } diff --git a/gtsam/inference/BayesTree.h b/gtsam/inference/BayesTree.h index 7199da0ad..a32b3ce22 100644 --- a/gtsam/inference/BayesTree.h +++ b/gtsam/inference/BayesTree.h @@ -19,6 +19,8 @@ #pragma once +#include + #include #include #include @@ -141,7 +143,7 @@ namespace gtsam { const Nodes& nodes() const { return nodes_; } /** Access node by variable */ - const sharedNode operator[](Key j) const { return nodes_.at(j); } + sharedClique operator[](Key j) const { return nodes_.at(j); } /** return root cliques */ const Roots& roots() const { return roots_; } @@ -180,13 +182,20 @@ namespace gtsam { */ sharedBayesNet jointBayesNet(Key j1, Key j2, const Eliminate& function = EliminationTraitsType::DefaultEliminate) const; - /** - * Read only with side effects - */ + /// @name Graph Display + /// @{ - /** saves the Tree to a text file in GraphViz format */ - void saveGraph(const std::string& s, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + /// Output to graphviz format, stream version. + void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + /// Output to graphviz format string. + std::string dot( + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + + /// output to file with graphviz format. + void saveGraph(const std::string& filename, + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + /// @} /// @name Advanced Interface /// @{ @@ -234,8 +243,8 @@ namespace gtsam { protected: /** private helper method for saving the Tree to a text file in GraphViz format */ - void saveGraph(std::ostream &s, sharedClique clique, const KeyFormatter& keyFormatter, - int parentnum = 0) const; + void dot(std::ostream &s, sharedClique clique, const KeyFormatter& keyFormatter, + int parentnum = 0) const; /** Gather data on a single clique */ void getCliqueData(sharedClique clique, BayesTreeCliqueData* stats) const; @@ -247,7 +256,7 @@ namespace gtsam { void fillNodesIndex(const sharedClique& subtree); // Friend JunctionTree because it directly fills roots and nodes index. - template friend class EliminatableClusterTree; + template friend class EliminatableClusterTree; private: /** Serialization function */ diff --git a/gtsam/inference/Conditional.h b/gtsam/inference/Conditional.h index 295122879..7594da78d 100644 --- a/gtsam/inference/Conditional.h +++ b/gtsam/inference/Conditional.h @@ -25,15 +25,12 @@ namespace gtsam { /** - * TODO: Update comments. The following comments are out of date!!! - * - * Base class for conditional densities, templated on KEY type. This class - * provides storage for the keys involved in a conditional, and iterators and + * Base class for conditional densities. This class iterators and * access to the frontal and separator keys. * * Derived classes *must* redefine the Factor and shared_ptr typedefs to refer * to the associated factor type and shared_ptr type of the derived class. See - * IndexConditional and GaussianConditional for examples. + * SymbolicConditional and GaussianConditional for examples. * \nosubgrouping */ template diff --git a/gtsam/inference/DotWriter.cpp b/gtsam/inference/DotWriter.cpp new file mode 100644 index 000000000..ad5330575 --- /dev/null +++ b/gtsam/inference/DotWriter.cpp @@ -0,0 +1,129 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DotWriter.cpp + * @brief Graphviz formatting for factor graphs. + * @author Frank Dellaert + * @date December, 2021 + */ + +#include + +#include +#include + +#include + +using namespace std; + +namespace gtsam { + +void DotWriter::graphPreamble(ostream* os) const { + *os << "graph {\n"; + *os << " size=\"" << figureWidthInches << "," << figureHeightInches + << "\";\n\n"; +} + +void DotWriter::digraphPreamble(ostream* os) const { + *os << "digraph {\n"; + *os << " size=\"" << figureWidthInches << "," << figureHeightInches + << "\";\n\n"; +} + +void DotWriter::drawVariable(Key key, const KeyFormatter& keyFormatter, + const boost::optional& position, + ostream* os) const { + // Label the node with the label from the KeyFormatter + *os << " var" << keyFormatter(key) << "[label=\"" << keyFormatter(key) + << "\""; + if (position) { + *os << ", pos=\"" << position->x() << "," << position->y() << "!\""; + } + if (boxes.count(key)) { + *os << ", shape=box"; + } + *os << "];\n"; +} + +void DotWriter::DrawFactor(size_t i, const boost::optional& position, + ostream* os) { + *os << " factor" << i << "[label=\"\", shape=point"; + if (position) { + *os << ", pos=\"" << position->x() << "," << position->y() << "!\""; + } + *os << "];\n"; +} + +static void ConnectVariables(Key key1, Key key2, + const KeyFormatter& keyFormatter, ostream* os) { + *os << " var" << keyFormatter(key1) << "--" + << "var" << keyFormatter(key2) << ";\n"; +} + +static void ConnectVariableFactor(Key key, const KeyFormatter& keyFormatter, + size_t i, ostream* os) { + *os << " var" << keyFormatter(key) << "--" + << "factor" << i << ";\n"; +} + +/// Return variable position or none +boost::optional DotWriter::variablePos(Key key) const { + boost::optional result = boost::none; + + // Check position hint + Symbol symbol(key); + auto hint = positionHints.find(symbol.chr()); + if (hint != positionHints.end()) + result.reset(Vector2(symbol.index(), hint->second)); + + // Override with explicit position, if given. + auto pos = variablePositions.find(key); + if (pos != variablePositions.end()) + result.reset(pos->second); + + return result; +} + +void DotWriter::processFactor(size_t i, const KeyVector& keys, + const KeyFormatter& keyFormatter, + const boost::optional& position, + ostream* os) const { + if (plotFactorPoints) { + if (binaryEdges && keys.size() == 2) { + ConnectVariables(keys[0], keys[1], keyFormatter, os); + } else { + // Create dot for the factor. + if (!position && factorPositions.count(i)) + DrawFactor(i, factorPositions.at(i), os); + else + DrawFactor(i, position, os); + + // Make factor-variable connections + if (connectKeysToFactor) { + for (Key key : keys) { + ConnectVariableFactor(key, keyFormatter, i, os); + } + } + } + } else { + // just connect variables in a clique + for (Key key1 : keys) { + for (Key key2 : keys) { + if (key2 > key1) { + ConnectVariables(key1, key2, keyFormatter, os); + } + } + } + } +} + +} // namespace gtsam diff --git a/gtsam/inference/DotWriter.h b/gtsam/inference/DotWriter.h new file mode 100644 index 000000000..23302ee60 --- /dev/null +++ b/gtsam/inference/DotWriter.h @@ -0,0 +1,100 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DotWriter.h + * @brief Graphviz formatter + * @author Frank Dellaert + * @date December, 2021 + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include + +namespace gtsam { + +/** + * @brief DotWriter is a helper class for writing graphviz .dot files. + * @addtogroup inference + */ +struct GTSAM_EXPORT DotWriter { + double figureWidthInches; ///< The figure width on paper in inches + double figureHeightInches; ///< The figure height on paper in inches + bool plotFactorPoints; ///< Plots each factor as a dot between the variables + bool connectKeysToFactor; ///< Draw a line from each key within a factor to + ///< the dot of the factor + bool binaryEdges; ///< just use non-dotted edges for binary factors + + /** + * Variable positions can be optionally specified and will be included in the + * dot file with a "!' sign, so "neato" can use it to render them. + */ + std::map variablePositions; + + /** + * The position hints allow one to use symbol character and index to specify + * position. Unless variable positions are specified, if a hint is present for + * a given symbol, it will be used to calculate the positions as (index,hint). + */ + std::map positionHints; + + /** A set of keys that will be displayed as a box */ + std::set boxes; + + /** + * Factor positions can be optionally specified and will be included in the + * dot file with a "!' sign, so "neato" can use it to render them. + */ + std::map factorPositions; + + explicit DotWriter(double figureWidthInches = 5, + double figureHeightInches = 5, + bool plotFactorPoints = true, + bool connectKeysToFactor = true, bool binaryEdges = false) + : figureWidthInches(figureWidthInches), + figureHeightInches(figureHeightInches), + plotFactorPoints(plotFactorPoints), + connectKeysToFactor(connectKeysToFactor), + binaryEdges(binaryEdges) {} + + /// Write out preamble for graph, including size. + void graphPreamble(std::ostream* os) const; + + /// Write out preamble for digraph, including size. + void digraphPreamble(std::ostream* os) const; + + /// Create a variable dot fragment. + void drawVariable(Key key, const KeyFormatter& keyFormatter, + const boost::optional& position, + std::ostream* os) const; + + /// Create factor dot. + static void DrawFactor(size_t i, const boost::optional& position, + std::ostream* os); + + /// Return variable position or none + boost::optional variablePos(Key key) const; + + /// Draw a single factor, specified by its index i and its variable keys. + void processFactor(size_t i, const KeyVector& keys, + const KeyFormatter& keyFormatter, + const boost::optional& position, + std::ostream* os) const; +}; + +} // namespace gtsam diff --git a/gtsam/inference/EliminateableFactorGraph-inst.h b/gtsam/inference/EliminateableFactorGraph-inst.h index 4157336d1..35e7505c9 100644 --- a/gtsam/inference/EliminateableFactorGraph-inst.h +++ b/gtsam/inference/EliminateableFactorGraph-inst.h @@ -78,29 +78,31 @@ namespace gtsam { } /* ************************************************************************* */ - template - boost::shared_ptr::BayesTreeType> - EliminateableFactorGraph::eliminateMultifrontal( - OptionalOrderingType orderingType, const Eliminate& function, - OptionalVariableIndex variableIndex) const - { - if(!variableIndex) { - // If no VariableIndex provided, compute one and call this function again IMPORTANT: we check - // for no variable index first so that it's always computed if we need to call COLAMD because - // no Ordering is provided. When removing optional from VariableIndex, create VariableIndex - // before creating ordering. + template + boost::shared_ptr< + typename EliminateableFactorGraph::BayesTreeType> + EliminateableFactorGraph::eliminateMultifrontal( + OptionalOrderingType orderingType, const Eliminate& function, + OptionalVariableIndex variableIndex) const { + if (!variableIndex) { + // If no VariableIndex provided, compute one and call this function again + // IMPORTANT: we check for no variable index first so that it's always + // computed if we need to call COLAMD because no Ordering is provided. + // When removing optional from VariableIndex, create VariableIndex before + // creating ordering. VariableIndex computedVariableIndex(asDerived()); - return eliminateMultifrontal(function, computedVariableIndex, orderingType); - } - else { - // Compute an ordering and call this function again. We are guaranteed to have a - // VariableIndex already here because we computed one if needed in the previous 'if' block. + return eliminateMultifrontal(orderingType, function, + computedVariableIndex); + } else { + // Compute an ordering and call this function again. We are guaranteed to + // have a VariableIndex already here because we computed one if needed in + // the previous 'if' block. if (orderingType == Ordering::METIS) { Ordering computedOrdering = Ordering::Metis(asDerived()); - return eliminateMultifrontal(computedOrdering, function, variableIndex, orderingType); + return eliminateMultifrontal(computedOrdering, function, variableIndex); } else { Ordering computedOrdering = Ordering::Colamd(*variableIndex); - return eliminateMultifrontal(computedOrdering, function, variableIndex, orderingType); + return eliminateMultifrontal(computedOrdering, function, variableIndex); } } } @@ -273,7 +275,7 @@ namespace gtsam { else { // No ordering was provided for the unmarginalized variables, so order them with COLAMD. - return factorGraph->eliminateSequential(function); + return factorGraph->eliminateSequential(Ordering::COLAMD, function); } } } @@ -340,7 +342,7 @@ namespace gtsam { else { // No ordering was provided for the unmarginalized variables, so order them with COLAMD. - return factorGraph->eliminateMultifrontal(function); + return factorGraph->eliminateMultifrontal(Ordering::COLAMD, function); } } } diff --git a/gtsam/inference/EliminateableFactorGraph.h b/gtsam/inference/EliminateableFactorGraph.h index edc4883e7..c904d2f7f 100644 --- a/gtsam/inference/EliminateableFactorGraph.h +++ b/gtsam/inference/EliminateableFactorGraph.h @@ -288,8 +288,9 @@ namespace gtsam { FactorGraphType& asDerived() { return static_cast(*this); } public: - /** \deprecated ordering and orderingType shouldn't both be specified */ - boost::shared_ptr eliminateSequential( + #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + /** @deprecated ordering and orderingType shouldn't both be specified */ + boost::shared_ptr GTSAM_DEPRECATED eliminateSequential( const Ordering& ordering, const Eliminate& function, OptionalVariableIndex variableIndex, @@ -297,16 +298,16 @@ namespace gtsam { return eliminateSequential(ordering, function, variableIndex); } - /** \deprecated orderingType specified first for consistency */ - boost::shared_ptr eliminateSequential( + /** @deprecated orderingType specified first for consistency */ + boost::shared_ptr GTSAM_DEPRECATED eliminateSequential( const Eliminate& function, OptionalVariableIndex variableIndex = boost::none, OptionalOrderingType orderingType = boost::none) const { return eliminateSequential(orderingType, function, variableIndex); } - /** \deprecated ordering and orderingType shouldn't both be specified */ - boost::shared_ptr eliminateMultifrontal( + /** @deprecated ordering and orderingType shouldn't both be specified */ + boost::shared_ptr GTSAM_DEPRECATED eliminateMultifrontal( const Ordering& ordering, const Eliminate& function, OptionalVariableIndex variableIndex, @@ -314,16 +315,16 @@ namespace gtsam { return eliminateMultifrontal(ordering, function, variableIndex); } - /** \deprecated orderingType specified first for consistency */ - boost::shared_ptr eliminateMultifrontal( + /** @deprecated orderingType specified first for consistency */ + boost::shared_ptr GTSAM_DEPRECATED eliminateMultifrontal( const Eliminate& function, OptionalVariableIndex variableIndex = boost::none, OptionalOrderingType orderingType = boost::none) const { return eliminateMultifrontal(orderingType, function, variableIndex); } - /** \deprecated */ - boost::shared_ptr marginalMultifrontalBayesNet( + /** @deprecated */ + boost::shared_ptr GTSAM_DEPRECATED marginalMultifrontalBayesNet( boost::variant variables, boost::none_t, const Eliminate& function = EliminationTraitsType::DefaultEliminate, @@ -331,14 +332,15 @@ namespace gtsam { return marginalMultifrontalBayesNet(variables, function, variableIndex); } - /** \deprecated */ - boost::shared_ptr marginalMultifrontalBayesTree( + /** @deprecated */ + boost::shared_ptr GTSAM_DEPRECATED marginalMultifrontalBayesTree( boost::variant variables, boost::none_t, const Eliminate& function = EliminationTraitsType::DefaultEliminate, OptionalVariableIndex variableIndex = boost::none) const { return marginalMultifrontalBayesTree(variables, function, variableIndex); } + #endif }; } diff --git a/gtsam/inference/Factor.h b/gtsam/inference/Factor.h index 6ea81030a..27b85ef67 100644 --- a/gtsam/inference/Factor.h +++ b/gtsam/inference/Factor.h @@ -22,6 +22,7 @@ #pragma once #include +#include #include #include @@ -111,6 +112,9 @@ typedef FastSet FactorIndexSet; /// @name Standard Interface /// @{ + /// Whether the factor is empty (involves zero variables). + bool empty() const { return keys_.empty(); } + /// First key Key front() const { return keys_.front(); } @@ -149,13 +153,11 @@ typedef FastSet FactorIndexSet; const std::string& s = "Factor", const KeyFormatter& formatter = DefaultKeyFormatter) const; - protected: /// check equality bool equals(const This& other, double tol = 1e-9) const; /// @} - public: /// @name Advanced Interface /// @{ diff --git a/gtsam/inference/FactorGraph-inst.h b/gtsam/inference/FactorGraph-inst.h index 166ae41f4..a2ae07101 100644 --- a/gtsam/inference/FactorGraph-inst.h +++ b/gtsam/inference/FactorGraph-inst.h @@ -26,6 +26,7 @@ #include #include #include // for cout :-( +#include #include #include @@ -125,4 +126,50 @@ FactorIndices FactorGraph::add_factors(const CONTAINER& factors, return newFactorIndices; } +/* ************************************************************************* */ +template +void FactorGraph::dot(std::ostream& os, + const KeyFormatter& keyFormatter, + const DotWriter& writer) const { + writer.graphPreamble(&os); + + // Create nodes for each variable in the graph + for (Key key : keys()) { + auto position = writer.variablePos(key); + writer.drawVariable(key, keyFormatter, position, &os); + } + os << "\n"; + + // Create factors and variable connections + for (size_t i = 0; i < size(); ++i) { + const auto& factor = at(i); + if (factor) { + const KeyVector& factorKeys = factor->keys(); + writer.processFactor(i, factorKeys, keyFormatter, boost::none, &os); + } + } + + os << "}\n"; + std::flush(os); +} + +/* ************************************************************************* */ +template +std::string FactorGraph::dot(const KeyFormatter& keyFormatter, + const DotWriter& writer) const { + std::stringstream ss; + dot(ss, keyFormatter, writer); + return ss.str(); +} + +/* ************************************************************************* */ +template +void FactorGraph::saveGraph(const std::string& filename, + const KeyFormatter& keyFormatter, + const DotWriter& writer) const { + std::ofstream of(filename.c_str()); + dot(of, keyFormatter, writer); + of.close(); +} + } // namespace gtsam diff --git a/gtsam/inference/FactorGraph.h b/gtsam/inference/FactorGraph.h index e337e3249..afea63da8 100644 --- a/gtsam/inference/FactorGraph.h +++ b/gtsam/inference/FactorGraph.h @@ -22,9 +22,10 @@ #pragma once +#include +#include #include #include -#include #include // for Eigen::aligned_allocator @@ -36,6 +37,7 @@ #include #include #include +#include namespace gtsam { /// Define collection type: @@ -126,6 +128,11 @@ class FactorGraph { /** Collection of factors */ FastVector factors_; + /// Check exact equality of the factor pointers. Useful for derived ==. + bool isEqual(const FactorGraph& other) const { + return factors_ == other.factors_; + } + /// @name Standard Constructors /// @{ @@ -288,11 +295,11 @@ class FactorGraph { /// @name Testable /// @{ - /// print out graph + /// Print out graph to std::cout, with optional key formatter. virtual void print(const std::string& s = "FactorGraph", const KeyFormatter& formatter = DefaultKeyFormatter) const; - /** Check equality */ + /// Check equality up to tolerance. bool equals(const This& fg, double tol = 1e-9) const; /// @} @@ -371,6 +378,24 @@ class FactorGraph { return factors_.erase(first, last); } + /// @} + /// @name Graph Display + /// @{ + + /// Output to graphviz format, stream version. + void dot(std::ostream& os, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DotWriter& writer = DotWriter()) const; + + /// Output to graphviz format string. + std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DotWriter& writer = DotWriter()) const; + + /// output to file with graphviz format. + void saveGraph(const std::string& filename, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DotWriter& writer = DotWriter()) const; + /// @} /// @name Advanced Interface /// @{ diff --git a/gtsam/inference/MetisIndex-inl.h b/gtsam/inference/MetisIndex-inl.h index eb9670254..646523372 100644 --- a/gtsam/inference/MetisIndex-inl.h +++ b/gtsam/inference/MetisIndex-inl.h @@ -23,8 +23,8 @@ namespace gtsam { /* ************************************************************************* */ -template -void MetisIndex::augment(const FactorGraph& factors) { +template +void MetisIndex::augment(const FACTORGRAPH& factors) { std::map > iAdjMap; // Stores a set of keys that are adjacent to key x, with adjMap.first std::map >::iterator iAdjMapIt; std::set keySet; diff --git a/gtsam/inference/MetisIndex.h b/gtsam/inference/MetisIndex.h index 7ec435caa..7431bff4c 100644 --- a/gtsam/inference/MetisIndex.h +++ b/gtsam/inference/MetisIndex.h @@ -62,8 +62,8 @@ public: nKeys_(0) { } - template - MetisIndex(const FG& factorGraph) : + template + MetisIndex(const FACTORGRAPH& factorGraph) : nKeys_(0) { augment(factorGraph); } @@ -78,8 +78,8 @@ public: * Augment the variable index with new factors. This can be used when * solving problems incrementally. */ - template - void augment(const FactorGraph& factors); + template + void augment(const FACTORGRAPH& factors); const std::vector& xadj() const { return xadj_; diff --git a/gtsam/inference/inference.i b/gtsam/inference/inference.i new file mode 100644 index 000000000..5a661d5cf --- /dev/null +++ b/gtsam/inference/inference.i @@ -0,0 +1,168 @@ +//************************************************************************* +// inference +//************************************************************************* + +namespace gtsam { + +#include + +// Default keyformatter +void PrintKeyList( + const gtsam::KeyList& keys, const string& s = "", + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter); +void PrintKeyVector( + const gtsam::KeyVector& keys, const string& s = "", + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter); +void PrintKeySet( + const gtsam::KeySet& keys, const string& s = "", + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter); + +#include +class Symbol { + Symbol(); + Symbol(char c, uint64_t j); + Symbol(size_t key); + + size_t key() const; + void print(const string& s = "") const; + bool equals(const gtsam::Symbol& expected, double tol) const; + + char chr() const; + uint64_t index() const; + string string() const; +}; + +size_t symbol(char chr, size_t index); +char symbolChr(size_t key); +size_t symbolIndex(size_t key); + +namespace symbol_shorthand { +size_t A(size_t j); +size_t B(size_t j); +size_t C(size_t j); +size_t D(size_t j); +size_t E(size_t j); +size_t F(size_t j); +size_t G(size_t j); +size_t H(size_t j); +size_t I(size_t j); +size_t J(size_t j); +size_t K(size_t j); +size_t L(size_t j); +size_t M(size_t j); +size_t N(size_t j); +size_t O(size_t j); +size_t P(size_t j); +size_t Q(size_t j); +size_t R(size_t j); +size_t S(size_t j); +size_t T(size_t j); +size_t U(size_t j); +size_t V(size_t j); +size_t W(size_t j); +size_t X(size_t j); +size_t Y(size_t j); +size_t Z(size_t j); +} // namespace symbol_shorthand + +#include +class LabeledSymbol { + LabeledSymbol(size_t full_key); + LabeledSymbol(const gtsam::LabeledSymbol& key); + LabeledSymbol(unsigned char valType, unsigned char label, size_t j); + + size_t key() const; + unsigned char label() const; + unsigned char chr() const; + size_t index() const; + + gtsam::LabeledSymbol upper() const; + gtsam::LabeledSymbol lower() const; + gtsam::LabeledSymbol newChr(unsigned char c) const; + gtsam::LabeledSymbol newLabel(unsigned char label) const; + + void print(string s = "") const; +}; + +size_t mrsymbol(unsigned char c, unsigned char label, size_t j); +unsigned char mrsymbolChr(size_t key); +unsigned char mrsymbolLabel(size_t key); +size_t mrsymbolIndex(size_t key); + +#include +class Ordering { + /// Type of ordering to use + enum OrderingType { COLAMD, METIS, NATURAL, CUSTOM }; + + // Standard Constructors and Named Constructors + Ordering(); + Ordering(const gtsam::Ordering& other); + + template + static gtsam::Ordering Colamd(const FACTOR_GRAPH& graph); + + // Testable + void print(string s = "", const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::Ordering& ord, double tol) const; + + // Standard interface + size_t size() const; + size_t at(size_t key) const; + void push_back(size_t key); + + // enabling serialization functionality + void serialize() const; +}; + +#include +class DotWriter { + DotWriter(double figureWidthInches = 5, double figureHeightInches = 5, + bool plotFactorPoints = true, bool connectKeysToFactor = true, + bool binaryEdges = true); + + double figureWidthInches; + double figureHeightInches; + bool plotFactorPoints; + bool connectKeysToFactor; + bool binaryEdges; + + std::map variablePositions; + std::map positionHints; + std::set boxes; + std::map factorPositions; +}; + +#include + +// Headers for overloaded methods below, break hierarchy :-/ +#include +#include +#include + +class VariableIndex { + // Standard Constructors and Named Constructors + VariableIndex(); + // TODO: Templetize constructor when wrap supports it + // template + // VariableIndex(const T& factorGraph, size_t nVariables); + // VariableIndex(const T& factorGraph); + VariableIndex(const gtsam::SymbolicFactorGraph& sfg); + VariableIndex(const gtsam::GaussianFactorGraph& gfg); + VariableIndex(const gtsam::NonlinearFactorGraph& fg); + VariableIndex(const gtsam::VariableIndex& other); + + // Testable + bool equals(const gtsam::VariableIndex& other, double tol) const; + void print(string s = "VariableIndex: ", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + + // Standard interface + size_t size() const; + size_t nFactors() const; + size_t nEntries() const; +}; + +} // namespace gtsam diff --git a/gtsam/linear/Errors.cpp b/gtsam/linear/Errors.cpp index 3fe2f3307..41c6c3d09 100644 --- a/gtsam/linear/Errors.cpp +++ b/gtsam/linear/Errors.cpp @@ -110,11 +110,10 @@ double dot(const Errors& a, const Errors& b) { } /* ************************************************************************* */ -template<> -void axpy(double alpha, const Errors& x, Errors& y) { +void axpy(double alpha, const Errors& x, Errors& y) { Errors::const_iterator it = x.begin(); for(Vector& yi: y) - axpy(alpha,*(it++),yi); + yi += alpha * (*(it++)); } /* ************************************************************************* */ diff --git a/gtsam/linear/Errors.h b/gtsam/linear/Errors.h index eb844e04d..f6e147084 100644 --- a/gtsam/linear/Errors.h +++ b/gtsam/linear/Errors.h @@ -65,8 +65,7 @@ namespace gtsam { /** * BLAS level 2 style */ - template <> - GTSAM_EXPORT void axpy(double alpha, const Errors& x, Errors& y); + GTSAM_EXPORT void axpy(double alpha, const Errors& x, Errors& y); /** print with optional string */ GTSAM_EXPORT void print(const Errors& a, const std::string& s = "Error"); diff --git a/gtsam/linear/GaussianBayesNet.cpp b/gtsam/linear/GaussianBayesNet.cpp index 1e790d0f1..8fd4f2c26 100644 --- a/gtsam/linear/GaussianBayesNet.cpp +++ b/gtsam/linear/GaussianBayesNet.cpp @@ -205,23 +205,5 @@ namespace gtsam { } /* ************************************************************************* */ - void GaussianBayesNet::saveGraph(const std::string& s, - const KeyFormatter& keyFormatter) const { - std::ofstream of(s.c_str()); - of << "digraph G{\n"; - - for (auto conditional : boost::adaptors::reverse(*this)) { - typename GaussianConditional::Frontals frontals = conditional->frontals(); - Key me = frontals.front(); - typename GaussianConditional::Parents parents = conditional->parents(); - for (Key p : parents) - of << keyFormatter(p) << "->" << keyFormatter(me) << std::endl; - } - - of << "}"; - of.close(); - } - - /* ************************************************************************* */ } // namespace gtsam diff --git a/gtsam/linear/GaussianBayesNet.h b/gtsam/linear/GaussianBayesNet.h index e55a89bcd..6d906d65e 100644 --- a/gtsam/linear/GaussianBayesNet.h +++ b/gtsam/linear/GaussianBayesNet.h @@ -21,17 +21,22 @@ #pragma once #include +#include #include #include +#include namespace gtsam { - /** A Bayes net made from linear-Gaussian densities */ - class GTSAM_EXPORT GaussianBayesNet: public FactorGraph + /** + * GaussianBayesNet is a Bayes net made from linear-Gaussian conditionals. + * @addtogroup linear + */ + class GTSAM_EXPORT GaussianBayesNet: public BayesNet { public: - typedef FactorGraph Base; + typedef BayesNet Base; typedef GaussianBayesNet This; typedef GaussianConditional ConditionalType; typedef boost::shared_ptr shared_ptr; @@ -44,16 +49,21 @@ namespace gtsam { GaussianBayesNet() {} /** Construct from iterator over conditionals */ - template - GaussianBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {} + template + GaussianBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) + : Base(firstConditional, lastConditional) {} /** Construct from container of factors (shared_ptr or plain objects) */ - template - explicit GaussianBayesNet(const CONTAINER& conditionals) : Base(conditionals) {} + template + explicit GaussianBayesNet(const CONTAINER& conditionals) { + push_back(conditionals); + } - /** Implicit copy/downcast constructor to override explicit template container constructor */ - template - GaussianBayesNet(const FactorGraph& graph) : Base(graph) {} + /** Implicit copy/downcast constructor to override explicit template + * container constructor */ + template + explicit GaussianBayesNet(const FactorGraph& graph) + : Base(graph) {} /// Destructor virtual ~GaussianBayesNet() {} @@ -66,6 +76,13 @@ namespace gtsam { /** Check equality */ bool equals(const This& bn, double tol = 1e-9) const; + /// print graph + void print( + const std::string& s = "", + const KeyFormatter& formatter = DefaultKeyFormatter) const override { + Base::print(s, formatter); + } + /// @} /// @name Standard Interface @@ -180,23 +197,6 @@ namespace gtsam { */ VectorValues backSubstituteTranspose(const VectorValues& gx) const; - /// print graph - void print( - const std::string& s = "", - const KeyFormatter& formatter = DefaultKeyFormatter) const override { - Base::print(s, formatter); - } - - /** - * @brief Save the GaussianBayesNet as an image. Requires `dot` to be - * installed. - * - * @param s The name of the figure. - * @param keyFormatter Formatter to use for styling keys in the graph. - */ - void saveGraph(const std::string& s, const KeyFormatter& keyFormatter = - DefaultKeyFormatter) const; - /// @} private: diff --git a/gtsam/linear/GaussianConditional.cpp b/gtsam/linear/GaussianConditional.cpp index 9297d6461..d87c39eea 100644 --- a/gtsam/linear/GaussianConditional.cpp +++ b/gtsam/linear/GaussianConditional.cpp @@ -193,12 +193,15 @@ namespace gtsam { } /* ************************************************************************* */ - void GaussianConditional::scaleFrontalsBySigma(VectorValues& gy) const { +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + void GTSAM_DEPRECATED + GaussianConditional::scaleFrontalsBySigma(VectorValues& gy) const { DenseIndex vectorPosition = 0; for (const_iterator frontal = beginFrontals(); frontal != endFrontals(); ++frontal) { gy[*frontal].array() *= model_->sigmas().segment(vectorPosition, getDim(frontal)).array(); vectorPosition += getDim(frontal); } } +#endif } // namespace gtsam diff --git a/gtsam/linear/GaussianConditional.h b/gtsam/linear/GaussianConditional.h index 0ea597f99..d93f65b42 100644 --- a/gtsam/linear/GaussianConditional.h +++ b/gtsam/linear/GaussianConditional.h @@ -125,12 +125,11 @@ namespace gtsam { /** Performs transpose backsubstition in place on values */ void solveTransposeInPlace(VectorValues& gy) const; +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 /** Scale the values in \c gy according to the sigmas for the frontal variables in this * conditional. */ - void scaleFrontalsBySigma(VectorValues& gy) const; - - // FIXME: deprecated flag doesn't appear to exist? - // __declspec(deprecated) void scaleFrontalsBySigma(VectorValues& gy) const; + void GTSAM_DEPRECATED scaleFrontalsBySigma(VectorValues& gy) const; +#endif private: /** Serialization function */ diff --git a/gtsam/linear/GaussianFactor.h b/gtsam/linear/GaussianFactor.h index 334722868..672f5aa0d 100644 --- a/gtsam/linear/GaussianFactor.h +++ b/gtsam/linear/GaussianFactor.h @@ -117,9 +117,6 @@ namespace gtsam { /** Clone a factor (make a deep copy) */ virtual GaussianFactor::shared_ptr clone() const = 0; - /** Test whether the factor is empty */ - virtual bool empty() const = 0; - /** * Construct the corresponding anti-factor to negate information * stored stored in this factor. diff --git a/gtsam/linear/GaussianFactorGraph.cpp b/gtsam/linear/GaussianFactorGraph.cpp index 24c4b9a0d..72eb107d0 100644 --- a/gtsam/linear/GaussianFactorGraph.cpp +++ b/gtsam/linear/GaussianFactorGraph.cpp @@ -19,7 +19,6 @@ */ #include -#include #include #include #include @@ -290,10 +289,11 @@ namespace gtsam { return blocks; } - /* ************************************************************************* */ + /* ************************************************************************ */ VectorValues GaussianFactorGraph::optimize(const Eliminate& function) const { gttic(GaussianFactorGraph_optimize); - return BaseEliminateable::eliminateMultifrontal(function)->optimize(); + return BaseEliminateable::eliminateMultifrontal(Ordering::COLAMD, function) + ->optimize(); } /* ************************************************************************* */ @@ -379,7 +379,7 @@ namespace gtsam { gttic(Compute_minimizing_step_size); // Compute minimizing step size - double step = -gradientSqNorm / dot(Rg, Rg); + double step = -gradientSqNorm / gtsam::dot(Rg, Rg); gttoc(Compute_minimizing_step_size); gttic(Compute_point); @@ -503,13 +503,6 @@ namespace gtsam { return e; } - /* ************************************************************************* */ - /** \deprecated */ - VectorValues GaussianFactorGraph::optimize(boost::none_t, - const Eliminate& function) const { - return optimize(function); - } - /* ************************************************************************* */ void GaussianFactorGraph::printErrors( const VectorValues& values, const std::string& str, diff --git a/gtsam/linear/GaussianFactorGraph.h b/gtsam/linear/GaussianFactorGraph.h index d41374854..0d5057aa8 100644 --- a/gtsam/linear/GaussianFactorGraph.h +++ b/gtsam/linear/GaussianFactorGraph.h @@ -21,12 +21,13 @@ #pragma once -#include #include +#include +#include // Included here instead of fw-declared so we can use Errors::iterator #include -#include #include -#include // Included here instead of fw-declared so we can use Errors::iterator +#include +#include namespace gtsam { @@ -98,6 +99,12 @@ namespace gtsam { /// @} + /// Check exact equality. + friend bool operator==(const GaussianFactorGraph& lhs, + const GaussianFactorGraph& rhs) { + return lhs.isEqual(rhs); + } + /** Add a factor by value - makes a copy */ void add(const GaussianFactor& factor) { push_back(factor.clone()); } @@ -153,7 +160,8 @@ namespace gtsam { /** Unnormalized probability. O(n) */ double probPrime(const VectorValues& c) const { - return exp(-0.5 * error(c)); + // NOTE the 0.5 constant is handled by the factor error. + return exp(-error(c)); } /** @@ -395,9 +403,14 @@ namespace gtsam { public: - /** \deprecated */ - VectorValues optimize(boost::none_t, - const Eliminate& function = EliminationTraitsType::DefaultEliminate) const; +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + /** @deprecated */ + VectorValues GTSAM_DEPRECATED + optimize(boost::none_t, const Eliminate& function = + EliminationTraitsType::DefaultEliminate) const { + return optimize(function); + } +#endif }; @@ -407,7 +420,7 @@ namespace gtsam { */ GTSAM_EXPORT bool hasConstraints(const GaussianFactorGraph& factors); - /****** Linear Algebra Opeations ******/ + /****** Linear Algebra Operations ******/ ///* matrix-vector operations */ //GTSAM_EXPORT void residual(const GaussianFactorGraph& fg, const VectorValues &x, VectorValues &r); diff --git a/gtsam/linear/HessianFactor.h b/gtsam/linear/HessianFactor.h index 0f4c993fe..7020d6edd 100644 --- a/gtsam/linear/HessianFactor.h +++ b/gtsam/linear/HessianFactor.h @@ -221,9 +221,6 @@ namespace gtsam { */ GaussianFactor::shared_ptr negate() const override; - /** Check if the factor is empty. TODO: How should this be defined? */ - bool empty() const override { return size() == 0 /*|| rows() == 0*/; } - /** Return the constant term \f$ f \f$ as described above * @return The constant term \f$ f \f$ */ diff --git a/gtsam/linear/JacobianFactor.h b/gtsam/linear/JacobianFactor.h index 4d4480d32..ddf614910 100644 --- a/gtsam/linear/JacobianFactor.h +++ b/gtsam/linear/JacobianFactor.h @@ -260,9 +260,6 @@ namespace gtsam { */ GaussianFactor::shared_ptr negate() const override; - /** Check if the factor is empty. TODO: How should this be defined? */ - bool empty() const override { return size() == 0 /*|| rows() == 0*/; } - /** is noise model constrained ? */ bool isConstrained() const { return model_ && model_->isConstrained(); diff --git a/gtsam/linear/NoiseModel.h b/gtsam/linear/NoiseModel.h index 2fb54d329..5c379beb8 100644 --- a/gtsam/linear/NoiseModel.h +++ b/gtsam/linear/NoiseModel.h @@ -177,17 +177,16 @@ namespace gtsam { return *sqrt_information_; } - protected: - - /** protected constructor takes square root information matrix */ - Gaussian(size_t dim = 1, const boost::optional& sqrt_information = boost::none) : - Base(dim), sqrt_information_(sqrt_information) { - } public: typedef boost::shared_ptr shared_ptr; + /** constructor takes square root information matrix */ + Gaussian(size_t dim = 1, + const boost::optional& sqrt_information = boost::none) + : Base(dim), sqrt_information_(sqrt_information) {} + ~Gaussian() override {} /** @@ -290,13 +289,13 @@ namespace gtsam { Vector sigmas_, invsigmas_, precisions_; protected: - /** protected constructor - no initializations */ - Diagonal(); /** constructor to allow for disabling initialization of invsigmas */ Diagonal(const Vector& sigmas); public: + /** constructor - no initializations, for serialization */ + Diagonal(); typedef boost::shared_ptr shared_ptr; @@ -387,14 +386,6 @@ namespace gtsam { // Sigmas are contained in the base class Vector mu_; ///< Penalty function weight - needs to be large enough to dominate soft constraints - /** - * protected constructor takes sigmas. - * prevents any inf values - * from appearing in invsigmas or precisions. - * mu set to large default value (1000.0) - */ - Constrained(const Vector& sigmas = Z_1x1); - /** * Constructor that prevents any inf values * from appearing in invsigmas or precisions. @@ -406,6 +397,14 @@ namespace gtsam { typedef boost::shared_ptr shared_ptr; + /** + * protected constructor takes sigmas. + * prevents any inf values + * from appearing in invsigmas or precisions. + * mu set to large default value (1000.0) + */ + Constrained(const Vector& sigmas = Z_1x1); + ~Constrained() override {} /// true if a constrained noise mode, saves slow/clumsy dynamic casting @@ -461,6 +460,11 @@ namespace gtsam { return MixedVariances(precisions.array().inverse()); } + /** + * The squaredMahalanobisDistance function for a constrained noisemodel, + * for non-constrained versions, uses sigmas, otherwise + * uses the penalty function with mu + */ double squaredMahalanobisDistance(const Vector& v) const override; /** Fully constrained variations */ @@ -531,11 +535,11 @@ namespace gtsam { Isotropic(size_t dim, double sigma) : Diagonal(Vector::Constant(dim, sigma)),sigma_(sigma),invsigma_(1.0/sigma) {} + public: + /* dummy constructor to allow for serialization */ Isotropic() : Diagonal(Vector1::Constant(1.0)),sigma_(1.0),invsigma_(1.0) {} - public: - ~Isotropic() override {} typedef boost::shared_ptr shared_ptr; @@ -592,14 +596,13 @@ namespace gtsam { * Unit: i.i.d. unit-variance noise on all m dimensions. */ class GTSAM_EXPORT Unit : public Isotropic { - protected: - - Unit(size_t dim=1): Isotropic(dim,1.0) {} - public: typedef boost::shared_ptr shared_ptr; + /** constructor for serialization */ + Unit(size_t dim=1): Isotropic(dim,1.0) {} + ~Unit() override {} /** @@ -682,19 +685,19 @@ namespace gtsam { /// Return the contained noise model const NoiseModel::shared_ptr& noise() const { return noise_; } - // TODO: functions below are dummy but necessary for the noiseModel::Base + // Functions below are dummy but necessary for the noiseModel::Base inline Vector whiten(const Vector& v) const override { Vector r = v; this->WhitenSystem(r); return r; } inline Matrix Whiten(const Matrix& A) const override { Vector b; Matrix B=A; this->WhitenSystem(B,b); return B; } inline Vector unwhiten(const Vector& /*v*/) const override { throw std::invalid_argument("unwhiten is not currently supported for robust noise models."); } - + /// Compute loss from the m-estimator using the Mahalanobis distance. double loss(const double squared_distance) const override { return robust_->loss(std::sqrt(squared_distance)); } - // TODO: these are really robust iterated re-weighting support functions + // These are really robust iterated re-weighting support functions virtual void WhitenSystem(Vector& b) const; void WhitenSystem(std::vector& A, Vector& b) const override; void WhitenSystem(Matrix& A, Vector& b) const override; @@ -705,7 +708,6 @@ namespace gtsam { return noise_->unweightedWhiten(v); } double weight(const Vector& v) const override { - // Todo(mikebosse): make the robust weight function input a vector. return robust_->weight(v.norm()); } @@ -728,8 +730,8 @@ namespace gtsam { } // namespace noiseModel - /** Note, deliberately not in noiseModel namespace. - * Deprecated. Only for compatibility with previous version. + /** + * Aliases. Deliberately not in noiseModel namespace. */ typedef noiseModel::Base::shared_ptr SharedNoiseModel; typedef noiseModel::Gaussian::shared_ptr SharedGaussian; diff --git a/gtsam/linear/SubgraphBuilder.cpp b/gtsam/linear/SubgraphBuilder.cpp index 1919d38be..18e19cd20 100644 --- a/gtsam/linear/SubgraphBuilder.cpp +++ b/gtsam/linear/SubgraphBuilder.cpp @@ -446,30 +446,29 @@ SubgraphBuilder::Weights SubgraphBuilder::weights( } /*****************************************************************************/ -GaussianFactorGraph::shared_ptr buildFactorSubgraph( - const GaussianFactorGraph &gfg, const Subgraph &subgraph, - const bool clone) { - auto subgraphFactors = boost::make_shared(); - subgraphFactors->reserve(subgraph.size()); +GaussianFactorGraph buildFactorSubgraph(const GaussianFactorGraph &gfg, + const Subgraph &subgraph, + const bool clone) { + GaussianFactorGraph subgraphFactors; + subgraphFactors.reserve(subgraph.size()); for (const auto &e : subgraph) { const auto factor = gfg[e.index]; - subgraphFactors->push_back(clone ? factor->clone() : factor); + subgraphFactors.push_back(clone ? factor->clone() : factor); } return subgraphFactors; } /**************************************************************************************************/ -std::pair // -splitFactorGraph(const GaussianFactorGraph &factorGraph, - const Subgraph &subgraph) { +std::pair splitFactorGraph( + const GaussianFactorGraph &factorGraph, const Subgraph &subgraph) { // Get the subgraph by calling cheaper method auto subgraphFactors = buildFactorSubgraph(factorGraph, subgraph, false); // Now, copy all factors then set subGraph factors to zero - auto remaining = boost::make_shared(factorGraph); + GaussianFactorGraph remaining = factorGraph; for (const auto &e : subgraph) { - remaining->remove(e.index); + remaining.remove(e.index); } return std::make_pair(subgraphFactors, remaining); diff --git a/gtsam/linear/SubgraphBuilder.h b/gtsam/linear/SubgraphBuilder.h index 84a477a5e..a900c7531 100644 --- a/gtsam/linear/SubgraphBuilder.h +++ b/gtsam/linear/SubgraphBuilder.h @@ -172,12 +172,13 @@ class GTSAM_EXPORT SubgraphBuilder { }; /** Select the factors in a factor graph according to the subgraph. */ -boost::shared_ptr buildFactorSubgraph( - const GaussianFactorGraph &gfg, const Subgraph &subgraph, const bool clone); +GaussianFactorGraph buildFactorSubgraph(const GaussianFactorGraph &gfg, + const Subgraph &subgraph, + const bool clone); /** Split the graph into a subgraph and the remaining edges. * Note that the remaining factorgraph has null factors. */ -std::pair, boost::shared_ptr > -splitFactorGraph(const GaussianFactorGraph &factorGraph, const Subgraph &subgraph); +std::pair splitFactorGraph( + const GaussianFactorGraph &factorGraph, const Subgraph &subgraph); } // namespace gtsam diff --git a/gtsam/linear/SubgraphPreconditioner.cpp b/gtsam/linear/SubgraphPreconditioner.cpp index fdcb4f7ac..6689cdbed 100644 --- a/gtsam/linear/SubgraphPreconditioner.cpp +++ b/gtsam/linear/SubgraphPreconditioner.cpp @@ -77,16 +77,16 @@ static void setSubvector(const Vector &src, const KeyInfo &keyInfo, /* ************************************************************************* */ // Convert any non-Jacobian factors to Jacobians (e.g. Hessian -> Jacobian with // Cholesky) -static GaussianFactorGraph::shared_ptr convertToJacobianFactors( +static GaussianFactorGraph convertToJacobianFactors( const GaussianFactorGraph &gfg) { - auto result = boost::make_shared(); + GaussianFactorGraph result; for (const auto &factor : gfg) if (factor) { auto jf = boost::dynamic_pointer_cast(factor); if (!jf) { jf = boost::make_shared(*factor); } - result->push_back(jf); + result.push_back(jf); } return result; } @@ -96,42 +96,42 @@ SubgraphPreconditioner::SubgraphPreconditioner(const SubgraphPreconditionerParam parameters_(p) {} /* ************************************************************************* */ -SubgraphPreconditioner::SubgraphPreconditioner(const sharedFG& Ab2, - const sharedBayesNet& Rc1, const sharedValues& xbar, const SubgraphPreconditionerParameters &p) : - Ab2_(convertToJacobianFactors(*Ab2)), Rc1_(Rc1), xbar_(xbar), - b2bar_(new Errors(-Ab2_->gaussianErrors(*xbar))), parameters_(p) { +SubgraphPreconditioner::SubgraphPreconditioner(const GaussianFactorGraph& Ab2, + const GaussianBayesNet& Rc1, const VectorValues& xbar, const SubgraphPreconditionerParameters &p) : + Ab2_(convertToJacobianFactors(Ab2)), Rc1_(Rc1), xbar_(xbar), + b2bar_(-Ab2_.gaussianErrors(xbar)), parameters_(p) { } /* ************************************************************************* */ // x = xbar + inv(R1)*y VectorValues SubgraphPreconditioner::x(const VectorValues& y) const { - return *xbar_ + Rc1_->backSubstitute(y); + return xbar_ + Rc1_.backSubstitute(y); } /* ************************************************************************* */ double SubgraphPreconditioner::error(const VectorValues& y) const { Errors e(y); VectorValues x = this->x(y); - Errors e2 = Ab2()->gaussianErrors(x); + Errors e2 = Ab2_.gaussianErrors(x); return 0.5 * (dot(e, e) + dot(e2,e2)); } /* ************************************************************************* */ // gradient is y + inv(R1')*A2'*(A2*inv(R1)*y-b2bar), VectorValues SubgraphPreconditioner::gradient(const VectorValues &y) const { - VectorValues x = Rc1()->backSubstitute(y); /* inv(R1)*y */ - Errors e = (*Ab2() * x - *b2bar()); /* (A2*inv(R1)*y-b2bar) */ + VectorValues x = Rc1_.backSubstitute(y); /* inv(R1)*y */ + Errors e = Ab2_ * x - b2bar_; /* (A2*inv(R1)*y-b2bar) */ VectorValues v = VectorValues::Zero(x); - Ab2()->transposeMultiplyAdd(1.0, e, v); /* A2'*(A2*inv(R1)*y-b2bar) */ - return y + Rc1()->backSubstituteTranspose(v); + Ab2_.transposeMultiplyAdd(1.0, e, v); /* A2'*(A2*inv(R1)*y-b2bar) */ + return y + Rc1_.backSubstituteTranspose(v); } /* ************************************************************************* */ // Apply operator A, A*y = [I;A2*inv(R1)]*y = [y; A2*inv(R1)*y] -Errors SubgraphPreconditioner::operator*(const VectorValues& y) const { +Errors SubgraphPreconditioner::operator*(const VectorValues &y) const { Errors e(y); - VectorValues x = Rc1()->backSubstitute(y); /* x=inv(R1)*y */ - Errors e2 = *Ab2() * x; /* A2*x */ + VectorValues x = Rc1_.backSubstitute(y); /* x=inv(R1)*y */ + Errors e2 = Ab2_ * x; /* A2*x */ e.splice(e.end(), e2); return e; } @@ -147,8 +147,8 @@ void SubgraphPreconditioner::multiplyInPlace(const VectorValues& y, Errors& e) c } // Add A2 contribution - VectorValues x = Rc1()->backSubstitute(y); // x=inv(R1)*y - Ab2()->multiplyInPlace(x, ei); // use iterator version + VectorValues x = Rc1_.backSubstitute(y); // x=inv(R1)*y + Ab2_.multiplyInPlace(x, ei); // use iterator version } /* ************************************************************************* */ @@ -173,7 +173,7 @@ void SubgraphPreconditioner::transposeMultiplyAdd Errors::const_iterator it = e.begin(); for(auto& key_value: y) { const Vector& ei = *it; - axpy(alpha, ei, key_value.second); + key_value.second += alpha * ei; ++it; } transposeMultiplyAdd2(alpha, it, e.end(), y); @@ -190,14 +190,14 @@ void SubgraphPreconditioner::transposeMultiplyAdd2 (double alpha, while (it != end) e2.push_back(*(it++)); VectorValues x = VectorValues::Zero(y); // x = 0 - Ab2_->transposeMultiplyAdd(1.0,e2,x); // x += A2'*e2 - axpy(alpha, Rc1_->backSubstituteTranspose(x), y); // y += alpha*inv(R1')*x + Ab2_.transposeMultiplyAdd(1.0,e2,x); // x += A2'*e2 + y += alpha * Rc1_.backSubstituteTranspose(x); // y += alpha*inv(R1')*x } /* ************************************************************************* */ void SubgraphPreconditioner::print(const std::string& s) const { cout << s << endl; - Ab2_->print(); + Ab2_.print(); } /*****************************************************************************/ @@ -205,7 +205,7 @@ void SubgraphPreconditioner::solve(const Vector &y, Vector &x) const { assert(x.size() == y.size()); /* back substitute */ - for (const auto &cg : boost::adaptors::reverse(*Rc1_)) { + for (const auto &cg : boost::adaptors::reverse(Rc1_)) { /* collect a subvector of x that consists of the parents of cg (S) */ const KeyVector parentKeys(cg->beginParents(), cg->endParents()); const KeyVector frontalKeys(cg->beginFrontals(), cg->endFrontals()); @@ -228,7 +228,7 @@ void SubgraphPreconditioner::transposeSolve(const Vector &y, Vector &x) const { std::copy(y.data(), y.data() + y.rows(), x.data()); /* in place back substitute */ - for (const auto &cg : *Rc1_) { + for (const auto &cg : Rc1_) { const KeyVector frontalKeys(cg->beginFrontals(), cg->endFrontals()); const Vector rhsFrontal = getSubvector(x, keyInfo_, frontalKeys); const Vector solFrontal = @@ -261,10 +261,10 @@ void SubgraphPreconditioner::build(const GaussianFactorGraph &gfg, const KeyInfo keyInfo_ = keyInfo; /* build factor subgraph */ - GaussianFactorGraph::shared_ptr gfg_subgraph = buildFactorSubgraph(gfg, subgraph, true); + auto gfg_subgraph = buildFactorSubgraph(gfg, subgraph, true); /* factorize and cache BayesNet */ - Rc1_ = gfg_subgraph->eliminateSequential(); + Rc1_ = *gfg_subgraph.eliminateSequential(); } /*****************************************************************************/ diff --git a/gtsam/linear/SubgraphPreconditioner.h b/gtsam/linear/SubgraphPreconditioner.h index 681c12e40..81c8968b1 100644 --- a/gtsam/linear/SubgraphPreconditioner.h +++ b/gtsam/linear/SubgraphPreconditioner.h @@ -19,6 +19,8 @@ #include #include +#include +#include #include #include #include @@ -53,16 +55,12 @@ namespace gtsam { public: typedef boost::shared_ptr shared_ptr; - typedef boost::shared_ptr sharedBayesNet; - typedef boost::shared_ptr sharedFG; - typedef boost::shared_ptr sharedValues; - typedef boost::shared_ptr sharedErrors; private: - sharedFG Ab2_; - sharedBayesNet Rc1_; - sharedValues xbar_; ///< A1 \ b1 - sharedErrors b2bar_; ///< A2*xbar - b2 + GaussianFactorGraph Ab2_; + GaussianBayesNet Rc1_; + VectorValues xbar_; ///< A1 \ b1 + Errors b2bar_; ///< A2*xbar - b2 KeyInfo keyInfo_; SubgraphPreconditionerParameters parameters_; @@ -77,7 +75,7 @@ namespace gtsam { * @param Rc1: the Bayes Net R1*x=c1 * @param xbar: the solution to R1*x=c1 */ - SubgraphPreconditioner(const sharedFG& Ab2, const sharedBayesNet& Rc1, const sharedValues& xbar, + SubgraphPreconditioner(const GaussianFactorGraph& Ab2, const GaussianBayesNet& Rc1, const VectorValues& xbar, const SubgraphPreconditionerParameters &p = SubgraphPreconditionerParameters()); ~SubgraphPreconditioner() override {} @@ -86,13 +84,13 @@ namespace gtsam { void print(const std::string& s = "SubgraphPreconditioner") const; /** Access Ab2 */ - const sharedFG& Ab2() const { return Ab2_; } + const GaussianFactorGraph& Ab2() const { return Ab2_; } /** Access Rc1 */ - const sharedBayesNet& Rc1() const { return Rc1_; } + const GaussianBayesNet& Rc1() const { return Rc1_; } /** Access b2bar */ - const sharedErrors b2bar() const { return b2bar_; } + const Errors b2bar() const { return b2bar_; } /** * Add zero-mean i.i.d. Gaussian prior terms to each variable @@ -104,8 +102,7 @@ namespace gtsam { /* A zero VectorValues with the structure of xbar */ VectorValues zero() const { - assert(xbar_); - return VectorValues::Zero(*xbar_); + return VectorValues::Zero(xbar_); } /** diff --git a/gtsam/linear/SubgraphSolver.cpp b/gtsam/linear/SubgraphSolver.cpp index f49f9a135..0156c717e 100644 --- a/gtsam/linear/SubgraphSolver.cpp +++ b/gtsam/linear/SubgraphSolver.cpp @@ -34,24 +34,24 @@ namespace gtsam { SubgraphSolver::SubgraphSolver(const GaussianFactorGraph &Ab, const Parameters ¶meters, const Ordering& ordering) : parameters_(parameters) { - GaussianFactorGraph::shared_ptr Ab1,Ab2; + GaussianFactorGraph Ab1, Ab2; std::tie(Ab1, Ab2) = splitGraph(Ab); if (parameters_.verbosity()) - cout << "Split A into (A1) " << Ab1->size() << " and (A2) " << Ab2->size() + cout << "Split A into (A1) " << Ab1.size() << " and (A2) " << Ab2.size() << " factors" << endl; - auto Rc1 = Ab1->eliminateSequential(ordering, EliminateQR); - auto xbar = boost::make_shared(Rc1->optimize()); + auto Rc1 = *Ab1.eliminateSequential(ordering, EliminateQR); + auto xbar = Rc1.optimize(); pc_ = boost::make_shared(Ab2, Rc1, xbar); } /**************************************************************************************************/ // Taking eliminated tree [R1|c] and constraint graph [A2|b2] -SubgraphSolver::SubgraphSolver(const GaussianBayesNet::shared_ptr &Rc1, - const GaussianFactorGraph::shared_ptr &Ab2, +SubgraphSolver::SubgraphSolver(const GaussianBayesNet &Rc1, + const GaussianFactorGraph &Ab2, const Parameters ¶meters) : parameters_(parameters) { - auto xbar = boost::make_shared(Rc1->optimize()); + auto xbar = Rc1.optimize(); pc_ = boost::make_shared(Ab2, Rc1, xbar); } @@ -59,10 +59,10 @@ SubgraphSolver::SubgraphSolver(const GaussianBayesNet::shared_ptr &Rc1, // Taking subgraphs [A1|b1] and [A2|b2] // delegate up SubgraphSolver::SubgraphSolver(const GaussianFactorGraph &Ab1, - const GaussianFactorGraph::shared_ptr &Ab2, + const GaussianFactorGraph &Ab2, const Parameters ¶meters, const Ordering &ordering) - : SubgraphSolver(Ab1.eliminateSequential(ordering, EliminateQR), Ab2, + : SubgraphSolver(*Ab1.eliminateSequential(ordering, EliminateQR), Ab2, parameters) {} /**************************************************************************************************/ @@ -78,7 +78,7 @@ VectorValues SubgraphSolver::optimize(const GaussianFactorGraph &gfg, return VectorValues(); } /**************************************************************************************************/ -pair // +pair // SubgraphSolver::splitGraph(const GaussianFactorGraph &factorGraph) { /* identify the subgraph structure */ diff --git a/gtsam/linear/SubgraphSolver.h b/gtsam/linear/SubgraphSolver.h index a41738321..0598b3321 100644 --- a/gtsam/linear/SubgraphSolver.h +++ b/gtsam/linear/SubgraphSolver.h @@ -99,15 +99,13 @@ class GTSAM_EXPORT SubgraphSolver : public IterativeSolver { * eliminate Ab1. We take Ab1 as a const reference, as it will be transformed * into Rc1, but take Ab2 as a shared pointer as we need to keep it around. */ - SubgraphSolver(const GaussianFactorGraph &Ab1, - const boost::shared_ptr &Ab2, + SubgraphSolver(const GaussianFactorGraph &Ab1, const GaussianFactorGraph &Ab2, const Parameters ¶meters, const Ordering &ordering); /** * The same as above, but we assume A1 was solved by caller. * We take two shared pointers as we keep both around. */ - SubgraphSolver(const boost::shared_ptr &Rc1, - const boost::shared_ptr &Ab2, + SubgraphSolver(const GaussianBayesNet &Rc1, const GaussianFactorGraph &Ab2, const Parameters ¶meters); /// Destructor @@ -131,9 +129,8 @@ class GTSAM_EXPORT SubgraphSolver : public IterativeSolver { /// @{ /// Split graph using Kruskal algorithm, treating binary factors as edges. - std::pair < boost::shared_ptr, - boost::shared_ptr > splitGraph( - const GaussianFactorGraph &gfg); + std::pair splitGraph( + const GaussianFactorGraph &gfg); /// @} }; diff --git a/gtsam/linear/iterative-inl.h b/gtsam/linear/iterative-inl.h index 58ef7d733..906ee80fd 100644 --- a/gtsam/linear/iterative-inl.h +++ b/gtsam/linear/iterative-inl.h @@ -72,7 +72,7 @@ namespace gtsam { double takeOptimalStep(V& x) { // TODO: can we use gamma instead of dot(d,g) ????? Answer not trivial double alpha = -dot(d, g) / dot(Ad, Ad); // calculate optimal step-size - axpy(alpha, d, x); // // do step in new search direction, x += alpha*d + x += alpha * d; // do step in new search direction, x += alpha*d return alpha; } @@ -106,7 +106,7 @@ namespace gtsam { double beta = new_gamma / gamma; // d = g + d*beta; d *= beta; - axpy(1.0, g, d); + d += 1.0 * g; } gamma = new_gamma; diff --git a/gtsam/linear/linear.i b/gtsam/linear/linear.i index c74161f26..b079c3dd1 100644 --- a/gtsam/linear/linear.i +++ b/gtsam/linear/linear.i @@ -255,9 +255,6 @@ class VectorValues { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -302,6 +299,7 @@ virtual class JacobianFactor : gtsam::GaussianFactor { void print(string s = "", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; void printKeys(string s) const; + gtsam::KeyVector& keys() const; bool equals(const gtsam::GaussianFactor& lf, double tol) const; size_t size() const; Vector unweighted_error(const gtsam::VectorValues& c) const; @@ -328,9 +326,6 @@ virtual class JacobianFactor : gtsam::GaussianFactor { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -363,9 +358,6 @@ virtual class HessianFactor : gtsam::GaussianFactor { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -445,46 +437,53 @@ class GaussianFactorGraph { pair hessian() const; pair hessian(const gtsam::Ordering& ordering) const; + string dot( + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; + void saveGraph( + string s, + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; + // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include virtual class GaussianConditional : gtsam::JacobianFactor { - //Constructors - GaussianConditional(size_t key, Vector d, Matrix R, const gtsam::noiseModel::Diagonal* sigmas); + // Constructors + GaussianConditional(size_t key, Vector d, Matrix R, + const gtsam::noiseModel::Diagonal* sigmas); GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S, - const gtsam::noiseModel::Diagonal* sigmas); + const gtsam::noiseModel::Diagonal* sigmas); GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S, - size_t name2, Matrix T, const gtsam::noiseModel::Diagonal* sigmas); + size_t name2, Matrix T, + const gtsam::noiseModel::Diagonal* sigmas); - //Constructors with no noise model + // Constructors with no noise model GaussianConditional(size_t key, Vector d, Matrix R); - GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S); - GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S, - size_t name2, Matrix T); + GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S); + GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S, + size_t name2, Matrix T); - //Standard Interface - void print(string s = "GaussianConditional", - const gtsam::KeyFormatter& keyFormatter = - gtsam::DefaultKeyFormatter) const; - bool equals(const gtsam::GaussianConditional& cg, double tol) const; + // Standard Interface + void print(string s = "GaussianConditional", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::GaussianConditional& cg, double tol) const; + gtsam::Key firstFrontalKey() const; + + // Advanced Interface + gtsam::VectorValues solve(const gtsam::VectorValues& parents) const; + gtsam::VectorValues solveOtherRHS(const gtsam::VectorValues& parents, + const gtsam::VectorValues& rhs) const; + void solveTransposeInPlace(gtsam::VectorValues& gy) const; + Matrix R() const; + Matrix S() const; + Vector d() const; - // Advanced Interface - gtsam::VectorValues solve(const gtsam::VectorValues& parents) const; - gtsam::VectorValues solveOtherRHS(const gtsam::VectorValues& parents, - const gtsam::VectorValues& rhs) const; - void solveTransposeInPlace(gtsam::VectorValues& gy) const; - void scaleFrontalsBySigma(gtsam::VectorValues& gy) const; - Matrix R() const; - Matrix S() const; - Vector d() const; - - // enabling serialization functionality - void serialize() const; + // enabling serialization functionality + void serialize() const; }; #include @@ -514,9 +513,9 @@ virtual class GaussianBayesNet { size_t size() const; // FactorGraph derived interface - // size_t size() const; gtsam::GaussianConditional* at(size_t idx) const; gtsam::KeySet keys() const; + gtsam::KeyVector keyVector() const; bool exists(size_t idx) const; void saveGraph(const string& s) const; @@ -536,6 +535,14 @@ virtual class GaussianBayesNet { double logDeterminant() const; gtsam::VectorValues backSubstitute(const gtsam::VectorValues& gx) const; gtsam::VectorValues backSubstituteTranspose(const gtsam::VectorValues& gx) const; + + string dot( + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; + void saveGraph( + string s, + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; }; #include @@ -636,7 +643,7 @@ virtual class SubgraphSolverParameters : gtsam::ConjugateGradientParameters { virtual class SubgraphSolver { SubgraphSolver(const gtsam::GaussianFactorGraph &A, const gtsam::SubgraphSolverParameters ¶meters, const gtsam::Ordering& ordering); - SubgraphSolver(const gtsam::GaussianFactorGraph &Ab1, const gtsam::GaussianFactorGraph* Ab2, const gtsam::SubgraphSolverParameters ¶meters, const gtsam::Ordering& ordering); + SubgraphSolver(const gtsam::GaussianFactorGraph &Ab1, const gtsam::GaussianFactorGraph& Ab2, const gtsam::SubgraphSolverParameters ¶meters, const gtsam::Ordering& ordering); gtsam::VectorValues optimize() const; }; diff --git a/gtsam/linear/tests/testErrors.cpp b/gtsam/linear/tests/testErrors.cpp index 74eef9a2c..f11fb90b9 100644 --- a/gtsam/linear/tests/testErrors.cpp +++ b/gtsam/linear/tests/testErrors.cpp @@ -32,7 +32,7 @@ TEST( Errors, arithmetic ) e += Vector2(1.0,2.0), Vector3(3.0,4.0,5.0); DOUBLES_EQUAL(1+4+9+16+25,dot(e,e),1e-9); - axpy(2.0,e,e); + axpy(2.0, e, e); Errors expected; expected += Vector2(3.0,6.0), Vector3(9.0,12.0,15.0); CHECK(assert_equal(expected,e)); diff --git a/gtsam/linear/tests/testGaussianBayesNet.cpp b/gtsam/linear/tests/testGaussianBayesNet.cpp index 00a338e54..f62da15dd 100644 --- a/gtsam/linear/tests/testGaussianBayesNet.cpp +++ b/gtsam/linear/tests/testGaussianBayesNet.cpp @@ -301,5 +301,31 @@ TEST(GaussianBayesNet, ComputeSteepestDescentPoint) { } /* ************************************************************************* */ -int main() { TestResult tr; return TestRegistry::runAllTests(tr);} +TEST(GaussianBayesNet, Dot) { + GaussianBayesNet fragment; + DotWriter writer; + writer.variablePositions.emplace(_x_, Vector2(10, 20)); + writer.variablePositions.emplace(_y_, Vector2(50, 20)); + + auto position = writer.variablePos(_x_); + CHECK(position); + EXPECT(assert_equal(Vector2(10, 20), *position, 1e-5)); + + string actual = noisyBayesNet.dot(DefaultKeyFormatter, writer); + EXPECT(actual == + "digraph {\n" + " size=\"5,5\";\n" + "\n" + " var11[label=\"11\", pos=\"10,20!\"];\n" + " var22[label=\"22\", pos=\"50,20!\"];\n" + "\n" + " var22->var11\n" + "}"); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} /* ************************************************************************* */ diff --git a/gtsam/linear/tests/testGaussianFactorGraph.cpp b/gtsam/linear/tests/testGaussianFactorGraph.cpp index bb07a36aa..41464a110 100644 --- a/gtsam/linear/tests/testGaussianFactorGraph.cpp +++ b/gtsam/linear/tests/testGaussianFactorGraph.cpp @@ -426,6 +426,7 @@ TEST(GaussianFactorGraph, hessianDiagonal) { EXPECT(assert_equal(expected, actual)); } +/* ************************************************************************* */ TEST(GaussianFactorGraph, DenseSolve) { GaussianFactorGraph fg = createSimpleGaussianFactorGraph(); VectorValues expected = fg.optimize(); @@ -433,6 +434,28 @@ TEST(GaussianFactorGraph, DenseSolve) { EXPECT(assert_equal(expected, actual)); } +/* ************************************************************************* */ +TEST(GaussianFactorGraph, ProbPrime) { + GaussianFactorGraph gfg; + gfg.emplace_shared(1, I_1x1, Z_1x1, + noiseModel::Isotropic::Sigma(1, 1.0)); + + VectorValues values; + values.insert(1, I_1x1); + + // We are testing the normal distribution PDF where info matrix Σ = 1, + // mean mu = 0 and x = 1. + // Therefore factor squared error: y = 0.5 * (Σ*x - mu)^2 = + // 0.5 * (1.0 - 0)^2 = 0.5 + // NOTE the 0.5 constant is a part of the factor error. + EXPECT_DOUBLES_EQUAL(0.5, gfg.error(values), 1e-12); + + // The gaussian PDF value is: exp^(-0.5 * (Σ*x - mu)^2) / sqrt(2 * PI) + // Ignore the denominator and we get: exp^(-0.5 * (1.0)^2) = exp^(-0.5) + double expected = exp(-0.5); + EXPECT_DOUBLES_EQUAL(expected, gfg.probPrime(values), 1e-12); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/linear/tests/testNoiseModel.cpp b/gtsam/linear/tests/testNoiseModel.cpp index 42d68a603..b974b6cd5 100644 --- a/gtsam/linear/tests/testNoiseModel.cpp +++ b/gtsam/linear/tests/testNoiseModel.cpp @@ -662,25 +662,14 @@ TEST(NoiseModel, robustNoiseL2WithDeadZone) { double dead_zone_size = 1.0; SharedNoiseModel robust = noiseModel::Robust::Create( - noiseModel::mEstimator::L2WithDeadZone::Create(dead_zone_size), - Unit::Create(3)); - -/* - * TODO(mike): There is currently a bug in GTSAM, where none of the mEstimator classes - * implement a loss function, and GTSAM calls the weight function to evaluate the - * total penalty, rather than calling the loss function. The weight function should be - * used during iteratively reweighted least squares optimization, but should not be used to - * evaluate the total penalty. The long-term solution is for all mEstimators to implement - * both a weight and a loss function, and for GTSAM to call the loss function when - * evaluating the total penalty. This bug causes the test below to fail, so I'm leaving it - * commented out until the underlying bug in GTSAM is fixed. - * - * for (int i = 0; i < 5; i++) { - * Vector3 error = Vector3(i, 0, 0); - * DOUBLES_EQUAL(0.5*max(0,i-1)*max(0,i-1), robust->distance(error), 1e-8); - * } - */ + noiseModel::mEstimator::L2WithDeadZone::Create(dead_zone_size), + Unit::Create(3)); + for (int i = 0; i < 5; i++) { + Vector3 error = Vector3(i, 0, 0); + DOUBLES_EQUAL(std::fmax(0, i - dead_zone_size) * i, + robust->squaredMahalanobisDistance(error), 1e-8); + } } TEST(NoiseModel, lossFunctionAtZero) @@ -707,9 +696,9 @@ TEST(NoiseModel, lossFunctionAtZero) auto dcs = mEstimator::DCS::Create(k); DOUBLES_EQUAL(dcs->loss(0), 0, 1e-8); DOUBLES_EQUAL(dcs->weight(0), 1, 1e-8); - // auto lsdz = mEstimator::L2WithDeadZone::Create(k); - // DOUBLES_EQUAL(lsdz->loss(0), 0, 1e-8); - // DOUBLES_EQUAL(lsdz->weight(0), 1, 1e-8); + auto lsdz = mEstimator::L2WithDeadZone::Create(k); + DOUBLES_EQUAL(lsdz->loss(0), 0, 1e-8); + DOUBLES_EQUAL(lsdz->weight(0), 0, 1e-8); } diff --git a/gtsam/navigation/AHRSFactor.cpp b/gtsam/navigation/AHRSFactor.cpp index 4604a55dd..f4db42d0f 100644 --- a/gtsam/navigation/AHRSFactor.cpp +++ b/gtsam/navigation/AHRSFactor.cpp @@ -168,13 +168,12 @@ Vector AHRSFactor::evaluateError(const Rot3& Ri, const Rot3& Rj, } //------------------------------------------------------------------------------ -Rot3 AHRSFactor::Predict( - const Rot3& rot_i, const Vector3& bias, - const PreintegratedAhrsMeasurements preintegratedMeasurements) { - const Vector3 biascorrectedOmega = preintegratedMeasurements.predict(bias); +Rot3 AHRSFactor::Predict(const Rot3& rot_i, const Vector3& bias, + const PreintegratedAhrsMeasurements& pim) { + const Vector3 biascorrectedOmega = pim.predict(bias); // Coriolis term - const Vector3 coriolis = preintegratedMeasurements.integrateCoriolis(rot_i); + const Vector3 coriolis = pim.integrateCoriolis(rot_i); const Vector3 correctedOmega = biascorrectedOmega - coriolis; const Rot3 correctedDeltaRij = Rot3::Expmap(correctedOmega); @@ -184,27 +183,26 @@ Rot3 AHRSFactor::Predict( //------------------------------------------------------------------------------ AHRSFactor::AHRSFactor(Key rot_i, Key rot_j, Key bias, - const PreintegratedMeasurements& pim, + const PreintegratedAhrsMeasurements& pim, const Vector3& omegaCoriolis, const boost::optional& body_P_sensor) - : Base(noiseModel::Gaussian::Covariance(pim.preintMeasCov_), rot_i, rot_j, bias), + : Base(noiseModel::Gaussian::Covariance(pim.preintMeasCov_), rot_i, rot_j, + bias), _PIM_(pim) { - boost::shared_ptr p = - boost::make_shared(pim.p()); + auto p = boost::make_shared(pim.p()); p->body_P_sensor = body_P_sensor; _PIM_.p_ = p; } //------------------------------------------------------------------------------ Rot3 AHRSFactor::predict(const Rot3& rot_i, const Vector3& bias, - const PreintegratedMeasurements pim, + const PreintegratedAhrsMeasurements& pim, const Vector3& omegaCoriolis, const boost::optional& body_P_sensor) { - boost::shared_ptr p = - boost::make_shared(pim.p()); + auto p = boost::make_shared(pim.p()); p->omegaCoriolis = omegaCoriolis; p->body_P_sensor = body_P_sensor; - PreintegratedMeasurements newPim = pim; + PreintegratedAhrsMeasurements newPim = pim; newPim.p_ = p; return Predict(rot_i, bias, newPim); } diff --git a/gtsam/navigation/AHRSFactor.h b/gtsam/navigation/AHRSFactor.h index 1ab2d7cdc..10c33d101 100644 --- a/gtsam/navigation/AHRSFactor.h +++ b/gtsam/navigation/AHRSFactor.h @@ -104,11 +104,10 @@ class GTSAM_EXPORT PreintegratedAhrsMeasurements : public PreintegratedRotation static Vector DeltaAngles(const Vector& msr_gyro_t, const double msr_dt, const Vector3& delta_angles); - /// @deprecated constructor + /// @deprecated constructor, but used in tests. PreintegratedAhrsMeasurements(const Vector3& biasHat, const Matrix3& measuredOmegaCovariance) - : PreintegratedRotation(boost::make_shared()), - biasHat_(biasHat) { + : PreintegratedRotation(boost::make_shared()), biasHat_(biasHat) { p_->gyroscopeCovariance = measuredOmegaCovariance; resetIntegration(); } @@ -182,24 +181,26 @@ public: /// predicted states from IMU /// TODO(frank): relationship with PIM predict ?? - static Rot3 Predict( - const Rot3& rot_i, const Vector3& bias, - const PreintegratedAhrsMeasurements preintegratedMeasurements); + static Rot3 Predict(const Rot3& rot_i, const Vector3& bias, + const PreintegratedAhrsMeasurements& pim); + /// @deprecated constructor, but used in tests. + AHRSFactor(Key rot_i, Key rot_j, Key bias, + const PreintegratedAhrsMeasurements& pim, + const Vector3& omegaCoriolis, + const boost::optional& body_P_sensor = boost::none); + + /// @deprecated static function, but used in tests. + static Rot3 predict( + const Rot3& rot_i, const Vector3& bias, + const PreintegratedAhrsMeasurements& pim, const Vector3& omegaCoriolis, + const boost::optional& body_P_sensor = boost::none); + +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 /// @deprecated name typedef PreintegratedAhrsMeasurements PreintegratedMeasurements; - /// @deprecated constructor - AHRSFactor(Key rot_i, Key rot_j, Key bias, - const PreintegratedMeasurements& preintegratedMeasurements, - const Vector3& omegaCoriolis, - const boost::optional& body_P_sensor = boost::none); - - /// @deprecated static function - static Rot3 predict(const Rot3& rot_i, const Vector3& bias, - const PreintegratedMeasurements preintegratedMeasurements, - const Vector3& omegaCoriolis, - const boost::optional& body_P_sensor = boost::none); +#endif private: diff --git a/gtsam/navigation/BarometricFactor.cpp b/gtsam/navigation/BarometricFactor.cpp new file mode 100644 index 000000000..2f0ff7436 --- /dev/null +++ b/gtsam/navigation/BarometricFactor.cpp @@ -0,0 +1,55 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file BarometricFactor.cpp + * @author Peter Milani + * @brief Implementation file for Barometric factor + * @date December 16, 2021 + **/ + +#include "BarometricFactor.h" + +using namespace std; + +namespace gtsam { + +//*************************************************************************** +void BarometricFactor::print(const string& s, + const KeyFormatter& keyFormatter) const { + cout << (s.empty() ? "" : s + " ") << "Barometric Factor on " + << keyFormatter(key1()) << "Barometric Bias on " + << keyFormatter(key2()) << "\n"; + + cout << " Baro measurement: " << nT_ << "\n"; + noiseModel_->print(" noise model: "); +} + +//*************************************************************************** +bool BarometricFactor::equals(const NonlinearFactor& expected, + double tol) const { + const This* e = dynamic_cast(&expected); + return e != nullptr && Base::equals(*e, tol) && + traits::Equals(nT_, e->nT_, tol); +} + +//*************************************************************************** +Vector BarometricFactor::evaluateError(const Pose3& p, const double& bias, + boost::optional H, + boost::optional H2) const { + Matrix tH; + Vector ret = (Vector(1) << (p.translation(tH).z() + bias - nT_)).finished(); + if (H) (*H) = tH.block<1, 6>(2, 0); + if (H2) (*H2) = (Matrix(1, 1) << 1.0).finished(); + return ret; +} + +} // namespace gtsam diff --git a/gtsam/navigation/BarometricFactor.h b/gtsam/navigation/BarometricFactor.h new file mode 100644 index 000000000..e7bf6f998 --- /dev/null +++ b/gtsam/navigation/BarometricFactor.h @@ -0,0 +1,109 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file BarometricFactor.h + * @author Peter Milani + * @brief Header file for Barometric factor + * @date December 16, 2021 + **/ +#pragma once + +#include +#include +#include + +namespace gtsam { + +/** + * Prior on height in a cartesian frame. + * Receive barometric pressure in kilopascals + * Model with a slowly moving bias to capture differences + * between the height and the standard atmosphere + * https://www.grc.nasa.gov/www/k-12/airplane/atmosmet.html + * @addtogroup Navigation + */ +class GTSAM_EXPORT BarometricFactor : public NoiseModelFactor2 { + private: + typedef NoiseModelFactor2 Base; + + double nT_; ///< Height Measurement based on a standard atmosphere + + public: + /// shorthand for a smart pointer to a factor + typedef boost::shared_ptr shared_ptr; + + /// Typedef to this class + typedef BarometricFactor This; + + /** default constructor - only use for serialization */ + BarometricFactor() : nT_(0) {} + + ~BarometricFactor() override {} + + /** + * @brief Constructor from a measurement of pressure in KPa. + * @param key of the Pose3 variable that will be constrained + * @param key of the barometric bias that will be constrained + * @param baroIn measurement in KPa + * @param model Gaussian noise model 1 dimension + */ + BarometricFactor(Key key, Key baroKey, const double& baroIn, + const SharedNoiseModel& model) + : Base(model, key, baroKey), nT_(heightOut(baroIn)) {} + + /// @return a deep copy of this factor + gtsam::NonlinearFactor::shared_ptr clone() const override { + return boost::static_pointer_cast( + gtsam::NonlinearFactor::shared_ptr(new This(*this))); + } + + /// print + void print( + const std::string& s = "", + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override; + + /// equals + bool equals(const NonlinearFactor& expected, + double tol = 1e-9) const override; + + /// vector of errors + Vector evaluateError( + const Pose3& p, const double& b, + boost::optional H = boost::none, + boost::optional H2 = boost::none) const override; + + inline const double& measurementIn() const { return nT_; } + + inline double heightOut(double n) const { + // From https://www.grc.nasa.gov/www/k-12/airplane/atmosmet.html + return (std::pow(n / 101.29, 1. / 5.256) * 288.08 - 273.1 - 15.04) / + -0.00649; + }; + + inline double baroOut(const double& meters) { + double temp = 15.04 - 0.00649 * meters; + return 101.29 * std::pow(((temp + 273.1) / 288.08), 5.256); + }; + + private: + /// Serialization function + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) { + ar& boost::serialization::make_nvp( + "NoiseModelFactor1", + boost::serialization::base_object(*this)); + ar& BOOST_SERIALIZATION_NVP(nT_); + } +}; + +} // namespace gtsam diff --git a/gtsam/navigation/ImuBias.h b/gtsam/navigation/ImuBias.h index fad952232..9346a4a77 100644 --- a/gtsam/navigation/ImuBias.h +++ b/gtsam/navigation/ImuBias.h @@ -131,30 +131,30 @@ public: /// @} +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 /// @name Deprecated /// @{ - ConstantBias inverse() { - return -(*this); - } - ConstantBias compose(const ConstantBias& q) { + ConstantBias GTSAM_DEPRECATED inverse() { return -(*this); } + ConstantBias GTSAM_DEPRECATED compose(const ConstantBias& q) { return (*this) + q; } - ConstantBias between(const ConstantBias& q) { + ConstantBias GTSAM_DEPRECATED between(const ConstantBias& q) { return q - (*this); } - Vector6 localCoordinates(const ConstantBias& q) { - return between(q).vector(); + Vector6 GTSAM_DEPRECATED localCoordinates(const ConstantBias& q) { + return (q - (*this)).vector(); } - ConstantBias retract(const Vector6& v) { - return compose(ConstantBias(v)); + ConstantBias GTSAM_DEPRECATED retract(const Vector6& v) { + return (*this) + ConstantBias(v); } - static Vector6 Logmap(const ConstantBias& p) { + static Vector6 GTSAM_DEPRECATED Logmap(const ConstantBias& p) { return p.vector(); } - static ConstantBias Expmap(const Vector6& v) { + static ConstantBias GTSAM_DEPRECATED Expmap(const Vector6& v) { return ConstantBias(v); } /// @} +#endif private: diff --git a/gtsam/navigation/navigation.i b/gtsam/navigation/navigation.i index 7a879c3ef..6ede1645f 100644 --- a/gtsam/navigation/navigation.i +++ b/gtsam/navigation/navigation.i @@ -18,29 +18,21 @@ class ConstantBias { // Group static gtsam::imuBias::ConstantBias identity(); - gtsam::imuBias::ConstantBias inverse() const; - gtsam::imuBias::ConstantBias compose(const gtsam::imuBias::ConstantBias& b) const; - gtsam::imuBias::ConstantBias between(const gtsam::imuBias::ConstantBias& b) const; // Operator Overloads gtsam::imuBias::ConstantBias operator-() const; gtsam::imuBias::ConstantBias operator+(const gtsam::imuBias::ConstantBias& b) const; gtsam::imuBias::ConstantBias operator-(const gtsam::imuBias::ConstantBias& b) const; - // Manifold - gtsam::imuBias::ConstantBias retract(Vector v) const; - Vector localCoordinates(const gtsam::imuBias::ConstantBias& b) const; - - // Lie Group - static gtsam::imuBias::ConstantBias Expmap(Vector v); - static Vector Logmap(const gtsam::imuBias::ConstantBias& b); - // Standard Interface Vector vector() const; Vector accelerometer() const; Vector gyroscope() const; Vector correctAccelerometer(Vector measurement) const; Vector correctGyroscope(Vector measurement) const; + + // enabling serialization functionality + void serialize() const; }; }///\namespace imuBias @@ -64,6 +56,9 @@ class NavState { gtsam::NavState retract(const Vector& x) const; Vector localCoordinates(const gtsam::NavState& g) const; + + // enabling serialization functionality + void serialize() const; }; #include @@ -106,6 +101,9 @@ virtual class PreintegrationParams : gtsam::PreintegratedRotationParams { Matrix getAccelerometerCovariance() const; Matrix getIntegrationCovariance() const; bool getUse2ndOrderCoriolis() const; + + // enabling serialization functionality + void serialize() const; }; #include @@ -135,6 +133,9 @@ class PreintegratedImuMeasurements { Vector biasHatVector() const; gtsam::NavState predict(const gtsam::NavState& state_i, const gtsam::imuBias::ConstantBias& bias) const; + + // enabling serialization functionality + void serialize() const; }; virtual class ImuFactor: gtsam::NonlinearFactor { diff --git a/gtsam/navigation/tests/testAHRSFactor.cpp b/gtsam/navigation/tests/testAHRSFactor.cpp index a4d06d01a..779f6abcc 100644 --- a/gtsam/navigation/tests/testAHRSFactor.cpp +++ b/gtsam/navigation/tests/testAHRSFactor.cpp @@ -54,11 +54,11 @@ Rot3 evaluateRotationError(const AHRSFactor& factor, const Rot3 rot_i, return Rot3::Expmap(factor.evaluateError(rot_i, rot_j, bias).tail(3)); } -AHRSFactor::PreintegratedMeasurements evaluatePreintegratedMeasurements( +PreintegratedAhrsMeasurements evaluatePreintegratedMeasurements( const Vector3& bias, const list& measuredOmegas, const list& deltaTs, const Vector3& initialRotationRate = Vector3::Zero()) { - AHRSFactor::PreintegratedMeasurements result(bias, I_3x3); + PreintegratedAhrsMeasurements result(bias, I_3x3); list::const_iterator itOmega = measuredOmegas.begin(); list::const_iterator itDeltaT = deltaTs.begin(); @@ -86,10 +86,10 @@ Rot3 evaluateRotation(const Vector3 measuredOmega, const Vector3 biasOmega, Vector3 evaluateLogRotation(const Vector3 thetahat, const Vector3 deltatheta) { return Rot3::Logmap(Rot3::Expmap(thetahat).compose(Rot3::Expmap(deltatheta))); } - } + //****************************************************************************** -TEST( AHRSFactor, PreintegratedMeasurements ) { +TEST( AHRSFactor, PreintegratedAhrsMeasurements ) { // Linearization point Vector3 bias(0,0,0); ///< Current estimate of angular rate bias @@ -102,7 +102,7 @@ TEST( AHRSFactor, PreintegratedMeasurements ) { double expectedDeltaT1(0.5); // Actual preintegrated values - AHRSFactor::PreintegratedMeasurements actual1(bias, Z_3x3); + PreintegratedAhrsMeasurements actual1(bias, Z_3x3); actual1.integrateMeasurement(measuredOmega, deltaT); EXPECT(assert_equal(expectedDeltaR1, Rot3(actual1.deltaRij()), 1e-6)); @@ -113,7 +113,7 @@ TEST( AHRSFactor, PreintegratedMeasurements ) { double expectedDeltaT2(1); // Actual preintegrated values - AHRSFactor::PreintegratedMeasurements actual2 = actual1; + PreintegratedAhrsMeasurements actual2 = actual1; actual2.integrateMeasurement(measuredOmega, deltaT); EXPECT(assert_equal(expectedDeltaR2, Rot3(actual2.deltaRij()), 1e-6)); @@ -159,7 +159,7 @@ TEST(AHRSFactor, Error) { Vector3 measuredOmega; measuredOmega << M_PI / 100, 0, 0; double deltaT = 1.0; - AHRSFactor::PreintegratedMeasurements pim(bias, Z_3x3); + PreintegratedAhrsMeasurements pim(bias, Z_3x3); pim.integrateMeasurement(measuredOmega, deltaT); // Create factor @@ -217,7 +217,7 @@ TEST(AHRSFactor, ErrorWithBiases) { measuredOmega << 0, 0, M_PI / 10.0 + 0.3; double deltaT = 1.0; - AHRSFactor::PreintegratedMeasurements pim(Vector3(0,0,0), + PreintegratedAhrsMeasurements pim(Vector3(0,0,0), Z_3x3); pim.integrateMeasurement(measuredOmega, deltaT); @@ -360,7 +360,7 @@ TEST( AHRSFactor, FirstOrderPreIntegratedMeasurements ) { } // Actual preintegrated values - AHRSFactor::PreintegratedMeasurements preintegrated = + PreintegratedAhrsMeasurements preintegrated = evaluatePreintegratedMeasurements(bias, measuredOmegas, deltaTs, Vector3(M_PI / 100.0, 0.0, 0.0)); @@ -397,7 +397,7 @@ TEST( AHRSFactor, ErrorWithBiasesAndSensorBodyDisplacement ) { const Pose3 body_P_sensor(Rot3::Expmap(Vector3(0, 0.10, 0.10)), Point3(1, 0, 0)); - AHRSFactor::PreintegratedMeasurements pim(Vector3::Zero(), kMeasuredAccCovariance); + PreintegratedAhrsMeasurements pim(Vector3::Zero(), kMeasuredAccCovariance); pim.integrateMeasurement(measuredOmega, deltaT); @@ -439,7 +439,7 @@ TEST (AHRSFactor, predictTest) { Vector3 measuredOmega; measuredOmega << 0, 0, M_PI / 10.0; double deltaT = 0.2; - AHRSFactor::PreintegratedMeasurements pim(bias, kMeasuredAccCovariance); + PreintegratedAhrsMeasurements pim(bias, kMeasuredAccCovariance); for (int i = 0; i < 1000; ++i) { pim.integrateMeasurement(measuredOmega, deltaT); } @@ -456,9 +456,9 @@ TEST (AHRSFactor, predictTest) { Rot3 actualRot = factor.predict(x, bias, pim, kZeroOmegaCoriolis); EXPECT(assert_equal(expectedRot, actualRot, 1e-6)); - // AHRSFactor::PreintegratedMeasurements::predict + // PreintegratedAhrsMeasurements::predict Matrix expectedH = numericalDerivative11( - std::bind(&AHRSFactor::PreintegratedMeasurements::predict, + std::bind(&PreintegratedAhrsMeasurements::predict, &pim, std::placeholders::_1, boost::none), bias); // Actual Jacobians @@ -478,7 +478,7 @@ TEST (AHRSFactor, graphTest) { // PreIntegrator Vector3 biasHat(0, 0, 0); - AHRSFactor::PreintegratedMeasurements pim(biasHat, kMeasuredAccCovariance); + PreintegratedAhrsMeasurements pim(biasHat, kMeasuredAccCovariance); // Pre-integrate measurements Vector3 measuredOmega(0, M_PI / 20, 0); diff --git a/gtsam/navigation/tests/testBarometricFactor.cpp b/gtsam/navigation/tests/testBarometricFactor.cpp new file mode 100644 index 000000000..47f4824c1 --- /dev/null +++ b/gtsam/navigation/tests/testBarometricFactor.cpp @@ -0,0 +1,129 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file testBarometricFactor.cpp + * @brief Unit test for BarometricFactor + * @author Peter Milani + * @date 16 Dec, 2021 + */ + +#include +#include +#include +#include + +#include + +using namespace std::placeholders; +using namespace std; +using namespace gtsam; + +// ************************************************************************* +namespace example {} + +double metersToBaro(const double& meters) { + double temp = 15.04 - 0.00649 * meters; + return 101.29 * std::pow(((temp + 273.1) / 288.08), 5.256); +} + +// ************************************************************************* +TEST(BarometricFactor, Constructor) { + using namespace example; + + // meters to barometric. + + double baroMeasurement = metersToBaro(10.); + + // Factor + Key key(1); + Key key2(2); + SharedNoiseModel model = noiseModel::Isotropic::Sigma(1, 0.25); + BarometricFactor factor(key, key2, baroMeasurement, model); + + // Create a linearization point at zero error + Pose3 T(Rot3::RzRyRx(0., 0., 0.), Point3(0., 0., 10.)); + double baroBias = 0.; + Vector1 zero; + zero << 0.; + EXPECT(assert_equal(zero, factor.evaluateError(T, baroBias), 1e-5)); + + // Calculate numerical derivatives + Matrix expectedH = numericalDerivative21( + std::bind(&BarometricFactor::evaluateError, &factor, + std::placeholders::_1, std::placeholders::_2, boost::none, + boost::none), + T, baroBias); + + Matrix expectedH2 = numericalDerivative22( + std::bind(&BarometricFactor::evaluateError, &factor, + std::placeholders::_1, std::placeholders::_2, boost::none, + boost::none), + T, baroBias); + + // Use the factor to calculate the derivative + Matrix actualH, actualH2; + factor.evaluateError(T, baroBias, actualH, actualH2); + + // Verify we get the expected error + EXPECT(assert_equal(expectedH, actualH, 1e-8)); + EXPECT(assert_equal(expectedH2, actualH2, 1e-8)); +} + +// ************************************************************************* + +//*************************************************************************** +TEST(BarometricFactor, nonZero) { + using namespace example; + + // meters to barometric. + + double baroMeasurement = metersToBaro(10.); + + // Factor + Key key(1); + Key key2(2); + SharedNoiseModel model = noiseModel::Isotropic::Sigma(1, 0.25); + BarometricFactor factor(key, key2, baroMeasurement, model); + + Pose3 T(Rot3::RzRyRx(0.5, 1., 1.), Point3(20., 30., 1.)); + double baroBias = 5.; + + // Calculate numerical derivatives + Matrix expectedH = numericalDerivative21( + std::bind(&BarometricFactor::evaluateError, &factor, + std::placeholders::_1, std::placeholders::_2, boost::none, + boost::none), + T, baroBias); + + Matrix expectedH2 = numericalDerivative22( + std::bind(&BarometricFactor::evaluateError, &factor, + std::placeholders::_1, std::placeholders::_2, boost::none, + boost::none), + T, baroBias); + + // Use the factor to calculate the derivative and the error + Matrix actualH, actualH2; + Vector error = factor.evaluateError(T, baroBias, actualH, actualH2); + Vector actual = (Vector(1) << -4.0).finished(); + + // Verify we get the expected error + EXPECT(assert_equal(expectedH, actualH, 1e-8)); + EXPECT(assert_equal(expectedH2, actualH2, 1e-8)); + EXPECT(assert_equal(error, actual, 1e-8)); +} + +// ************************************************************************* +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +// ************************************************************************* diff --git a/gtsam/navigation/tests/testGPSFactor.cpp b/gtsam/navigation/tests/testGPSFactor.cpp index b784c0c94..c94e1d3d5 100644 --- a/gtsam/navigation/tests/testGPSFactor.cpp +++ b/gtsam/navigation/tests/testGPSFactor.cpp @@ -72,7 +72,7 @@ TEST( GPSFactor, Constructor ) { // Calculate numerical derivatives Matrix expectedH = numericalDerivative11( - std::bind(&GPSFactor::evaluateError, &factor, _1, boost::none), T); + std::bind(&GPSFactor::evaluateError, &factor, std::placeholders::_1, boost::none), T); // Use the factor to calculate the derivative Matrix actualH; @@ -101,7 +101,7 @@ TEST( GPSFactor2, Constructor ) { // Calculate numerical derivatives Matrix expectedH = numericalDerivative11( - std::bind(&GPSFactor2::evaluateError, &factor, _1, boost::none), T); + std::bind(&GPSFactor2::evaluateError, &factor, std::placeholders::_1, boost::none), T); // Use the factor to calculate the derivative Matrix actualH; diff --git a/gtsam/navigation/tests/testImuBias.cpp b/gtsam/navigation/tests/testImuBias.cpp index b486a4a98..81a1a2ceb 100644 --- a/gtsam/navigation/tests/testImuBias.cpp +++ b/gtsam/navigation/tests/testImuBias.cpp @@ -47,20 +47,19 @@ TEST(ImuBias, Constructor) { } /* ************************************************************************* */ +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 TEST(ImuBias, inverse) { Bias biasActual = bias1.inverse(); Bias biasExpected = Bias(-biasAcc1, -biasGyro1); EXPECT(assert_equal(biasExpected, biasActual)); } -/* ************************************************************************* */ TEST(ImuBias, compose) { Bias biasActual = bias2.compose(bias1); Bias biasExpected = Bias(biasAcc1 + biasAcc2, biasGyro1 + biasGyro2); EXPECT(assert_equal(biasExpected, biasActual)); } -/* ************************************************************************* */ TEST(ImuBias, between) { // p.between(q) == q - p Bias biasActual = bias2.between(bias1); @@ -68,7 +67,6 @@ TEST(ImuBias, between) { EXPECT(assert_equal(biasExpected, biasActual)); } -/* ************************************************************************* */ TEST(ImuBias, localCoordinates) { Vector deltaActual = Vector(bias2.localCoordinates(bias1)); Vector deltaExpected = @@ -76,7 +74,6 @@ TEST(ImuBias, localCoordinates) { EXPECT(assert_equal(deltaExpected, deltaActual)); } -/* ************************************************************************* */ TEST(ImuBias, retract) { Vector6 delta; delta << 0.1, 0.2, -0.3, 0.1, -0.1, 0.2; @@ -86,14 +83,12 @@ TEST(ImuBias, retract) { EXPECT(assert_equal(biasExpected, biasActual)); } -/* ************************************************************************* */ TEST(ImuBias, Logmap) { Vector deltaActual = bias2.Logmap(bias1); Vector deltaExpected = bias1.vector(); EXPECT(assert_equal(deltaExpected, deltaActual)); } -/* ************************************************************************* */ TEST(ImuBias, Expmap) { Vector6 delta; delta << 0.1, 0.2, -0.3, 0.1, -0.1, 0.2; @@ -101,6 +96,7 @@ TEST(ImuBias, Expmap) { Bias biasExpected = Bias(delta); EXPECT(assert_equal(biasExpected, biasActual)); } +#endif /* ************************************************************************* */ TEST(ImuBias, operatorSub) { diff --git a/gtsam/navigation/tests/testMagFactor.cpp b/gtsam/navigation/tests/testMagFactor.cpp index 5107b3b6b..85447facd 100644 --- a/gtsam/navigation/tests/testMagFactor.cpp +++ b/gtsam/navigation/tests/testMagFactor.cpp @@ -64,7 +64,7 @@ TEST( MagFactor, unrotate ) { Point3 expected(22735.5, 314.502, 44202.5); EXPECT( assert_equal(expected, MagFactor::unrotate(theta,nM,H),1e-1)); EXPECT( assert_equal(numericalDerivative11 // - (std::bind(&MagFactor::unrotate, _1, nM, none), theta), H, 1e-6)); + (std::bind(&MagFactor::unrotate, std::placeholders::_1, nM, none), theta), H, 1e-6)); } // ************************************************************************* @@ -76,35 +76,35 @@ TEST( MagFactor, Factors ) { MagFactor f(1, measured, s, dir, bias, model); EXPECT( assert_equal(Z_3x1,f.evaluateError(theta,H1),1e-5)); EXPECT( assert_equal((Matrix)numericalDerivative11 // - (std::bind(&MagFactor::evaluateError, &f, _1, none), theta), H1, 1e-7)); + (std::bind(&MagFactor::evaluateError, &f, std::placeholders::_1, none), theta), H1, 1e-7)); // MagFactor1 MagFactor1 f1(1, measured, s, dir, bias, model); EXPECT( assert_equal(Z_3x1,f1.evaluateError(nRb,H1),1e-5)); EXPECT( assert_equal(numericalDerivative11 // - (std::bind(&MagFactor1::evaluateError, &f1, _1, none), nRb), H1, 1e-7)); + (std::bind(&MagFactor1::evaluateError, &f1, std::placeholders::_1, none), nRb), H1, 1e-7)); // MagFactor2 MagFactor2 f2(1, 2, measured, nRb, model); EXPECT( assert_equal(Z_3x1,f2.evaluateError(scaled,bias,H1,H2),1e-5)); EXPECT( assert_equal(numericalDerivative11 // - (std::bind(&MagFactor2::evaluateError, &f2, _1, bias, none, none), scaled),// + (std::bind(&MagFactor2::evaluateError, &f2, std::placeholders::_1, bias, none, none), scaled),// H1, 1e-7)); EXPECT( assert_equal(numericalDerivative11 // - (std::bind(&MagFactor2::evaluateError, &f2, scaled, _1, none, none), bias),// + (std::bind(&MagFactor2::evaluateError, &f2, scaled, std::placeholders::_1, none, none), bias),// H2, 1e-7)); // MagFactor2 MagFactor3 f3(1, 2, 3, measured, nRb, model); EXPECT(assert_equal(Z_3x1,f3.evaluateError(s,dir,bias,H1,H2,H3),1e-5)); EXPECT(assert_equal((Matrix)numericalDerivative11 // - (std::bind(&MagFactor3::evaluateError, &f3, _1, dir, bias, none, none, none), s),// + (std::bind(&MagFactor3::evaluateError, &f3, std::placeholders::_1, dir, bias, none, none, none), s),// H1, 1e-7)); EXPECT(assert_equal(numericalDerivative11 // - (std::bind(&MagFactor3::evaluateError, &f3, s, _1, bias, none, none, none), dir),// + (std::bind(&MagFactor3::evaluateError, &f3, s, std::placeholders::_1, bias, none, none, none), dir),// H2, 1e-7)); EXPECT(assert_equal(numericalDerivative11 // - (std::bind(&MagFactor3::evaluateError, &f3, s, dir, _1, none, none, none), bias),// + (std::bind(&MagFactor3::evaluateError, &f3, s, dir, std::placeholders::_1, none, none, none), bias),// H3, 1e-7)); } diff --git a/gtsam/nonlinear/DoglegOptimizerImpl.cpp b/gtsam/nonlinear/DoglegOptimizerImpl.cpp index c319f26e6..7e9db6b64 100644 --- a/gtsam/nonlinear/DoglegOptimizerImpl.cpp +++ b/gtsam/nonlinear/DoglegOptimizerImpl.cpp @@ -78,7 +78,8 @@ VectorValues DoglegOptimizerImpl::ComputeBlend(double delta, const VectorValues& // Compute blended point if(verbose) cout << "In blend region with fraction " << tau << " of Newton's method point" << endl; - VectorValues blend = (1. - tau) * x_u; axpy(tau, x_n, blend); + VectorValues blend = (1. - tau) * x_u; + blend += tau * x_n; return blend; } diff --git a/gtsam/nonlinear/ExpressionFactor.h b/gtsam/nonlinear/ExpressionFactor.h index b55d643aa..11bf873e7 100644 --- a/gtsam/nonlinear/ExpressionFactor.h +++ b/gtsam/nonlinear/ExpressionFactor.h @@ -295,17 +295,17 @@ struct traits> // ExpressionFactorN -#if defined(GTSAM_ALLOW_DEPRECATED_SINCE_V41) +#if defined(GTSAM_ALLOW_DEPRECATED_SINCE_V42) /** * Binary specialization of ExpressionFactor meant as a base class for binary * factors. Enforces an 'expression' method with two keys, and provides * 'evaluateError'. Derived class (a binary factor!) needs to call 'initialize'. * * \sa ExpressionFactorN - * \deprecated Prefer the more general ExpressionFactorN<>. + * @deprecated Prefer the more general ExpressionFactorN<>. */ template -class ExpressionFactor2 : public ExpressionFactorN { +class GTSAM_DEPRECATED ExpressionFactor2 : public ExpressionFactorN { public: /// Destructor ~ExpressionFactor2() override {} diff --git a/gtsam/nonlinear/ExtendedKalmanFilter.h b/gtsam/nonlinear/ExtendedKalmanFilter.h index 77bb1ca6c..df27d16ff 100644 --- a/gtsam/nonlinear/ExtendedKalmanFilter.h +++ b/gtsam/nonlinear/ExtendedKalmanFilter.h @@ -51,9 +51,11 @@ class ExtendedKalmanFilter { typedef boost::shared_ptr > shared_ptr; typedef VALUE T; +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 //@deprecated: any NoiseModelFactor will do, as long as they have the right keys typedef NoiseModelFactor2 MotionFactor; typedef NoiseModelFactor1 MeasurementFactor; +#endif protected: T x_; // linearization point diff --git a/gtsam/nonlinear/GraphvizFormatting.cpp b/gtsam/nonlinear/GraphvizFormatting.cpp new file mode 100644 index 000000000..ca3466b6a --- /dev/null +++ b/gtsam/nonlinear/GraphvizFormatting.cpp @@ -0,0 +1,145 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file GraphvizFormatting.cpp + * @brief Graphviz formatter for NonlinearFactorGraph + * @author Frank Dellaert + * @date December, 2021 + */ + +#include +#include + +// TODO(frank): nonlinear should not depend on geometry: +#include +#include + +#include + +namespace gtsam { + +Vector2 GraphvizFormatting::findBounds(const Values& values, + const KeySet& keys) const { + Vector2 min; + min.x() = std::numeric_limits::infinity(); + min.y() = std::numeric_limits::infinity(); + for (const Key& key : keys) { + if (values.exists(key)) { + boost::optional xy = extractPosition(values.at(key)); + if (xy) { + if (xy->x() < min.x()) min.x() = xy->x(); + if (xy->y() < min.y()) min.y() = xy->y(); + } + } + } + return min; +} + +boost::optional GraphvizFormatting::extractPosition( + const Value& value) const { + Vector3 t; + if (const GenericValue* p = + dynamic_cast*>(&value)) { + t << p->value().x(), p->value().y(), 0; + } else if (const GenericValue* p = + dynamic_cast*>(&value)) { + t << p->value().x(), p->value().y(), 0; + } else if (const GenericValue* p = + dynamic_cast*>(&value)) { + if (p->dim() == 2) { + const Eigen::Ref p_2d(p->value()); + t << p_2d.x(), p_2d.y(), 0; + } else if (p->dim() == 3) { + const Eigen::Ref p_3d(p->value()); + t = p_3d; + } else { + return boost::none; + } + } else if (const GenericValue* p = + dynamic_cast*>(&value)) { + t = p->value().translation(); + } else if (const GenericValue* p = + dynamic_cast*>(&value)) { + t = p->value(); + } else { + return boost::none; + } + double x, y; + switch (paperHorizontalAxis) { + case X: + x = t.x(); + break; + case Y: + x = t.y(); + break; + case Z: + x = t.z(); + break; + case NEGX: + x = -t.x(); + break; + case NEGY: + x = -t.y(); + break; + case NEGZ: + x = -t.z(); + break; + default: + throw std::runtime_error("Invalid enum value"); + } + switch (paperVerticalAxis) { + case X: + y = t.x(); + break; + case Y: + y = t.y(); + break; + case Z: + y = t.z(); + break; + case NEGX: + y = -t.x(); + break; + case NEGY: + y = -t.y(); + break; + case NEGZ: + y = -t.z(); + break; + default: + throw std::runtime_error("Invalid enum value"); + } + return Vector2(x, y); +} + +boost::optional GraphvizFormatting::variablePos(const Values& values, + const Vector2& min, + Key key) const { + if (!values.exists(key)) return DotWriter::variablePos(key); + boost::optional xy = extractPosition(values.at(key)); + if (xy) { + xy->x() = scale * (xy->x() - min.x()); + xy->y() = scale * (xy->y() - min.y()); + } + return xy; +} + +boost::optional GraphvizFormatting::factorPos(const Vector2& min, + size_t i) const { + if (factorPositions.size() == 0) return boost::none; + auto it = factorPositions.find(i); + if (it == factorPositions.end()) return boost::none; + auto pos = it->second; + return Vector2(scale * (pos.x() - min.x()), scale * (pos.y() - min.y())); +} + +} // namespace gtsam diff --git a/gtsam/nonlinear/GraphvizFormatting.h b/gtsam/nonlinear/GraphvizFormatting.h new file mode 100644 index 000000000..03cdb3469 --- /dev/null +++ b/gtsam/nonlinear/GraphvizFormatting.h @@ -0,0 +1,66 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file GraphvizFormatting.h + * @brief Graphviz formatter for NonlinearFactorGraph + * @author Frank Dellaert + * @date December, 2021 + */ + +#pragma once + +#include + +namespace gtsam { + +class Values; +class Value; + +/** + * Formatting options and functions for saving a NonlinearFactorGraph instance + * in GraphViz format. + */ +struct GTSAM_EXPORT GraphvizFormatting : public DotWriter { + /// World axes to be assigned to paper axes + enum Axis { X, Y, Z, NEGX, NEGY, NEGZ }; + + Axis paperHorizontalAxis; ///< The world axis assigned to the horizontal + ///< paper axis + Axis paperVerticalAxis; ///< The world axis assigned to the vertical paper + ///< axis + double scale; ///< Scale all positions to reduce / increase density + bool mergeSimilarFactors; ///< Merge multiple factors that have the same + ///< connectivity + + /// Default constructor sets up robot coordinates. Paper horizontal is robot + /// Y, paper vertical is robot X. Default figure size of 5x5 in. + GraphvizFormatting() + : paperHorizontalAxis(Y), + paperVerticalAxis(X), + scale(1), + mergeSimilarFactors(false) {} + + // Find bounds + Vector2 findBounds(const Values& values, const KeySet& keys) const; + + /// Extract a Vector2 from either Vector2, Pose2, Pose3, or Point3 + boost::optional extractPosition(const Value& value) const; + + /// Return affinely transformed variable position if it exists. + boost::optional variablePos(const Values& values, const Vector2& min, + Key key) const; + + /// Return affinely transformed factor position if it exists. + boost::optional factorPos(const Vector2& min, size_t i) const; +}; + +} // namespace gtsam diff --git a/gtsam/nonlinear/LinearContainerFactor.h b/gtsam/nonlinear/LinearContainerFactor.h index 8c5b34f01..efc095775 100644 --- a/gtsam/nonlinear/LinearContainerFactor.h +++ b/gtsam/nonlinear/LinearContainerFactor.h @@ -29,9 +29,6 @@ protected: GaussianFactor::shared_ptr factor_; boost::optional linearizationPoint_; - /** Default constructor - necessary for serialization */ - LinearContainerFactor() {} - /** direct copy constructor */ GTSAM_EXPORT LinearContainerFactor(const GaussianFactor::shared_ptr& factor, const boost::optional& linearizationPoint); @@ -43,6 +40,9 @@ public: typedef boost::shared_ptr shared_ptr; + /** Default constructor - necessary for serialization */ + LinearContainerFactor() {} + /** Primary constructor: store a linear factor with optional linearization point */ GTSAM_EXPORT LinearContainerFactor(const JacobianFactor& factor, const Values& linearizationPoint = Values()); diff --git a/gtsam/nonlinear/Marginals.cpp b/gtsam/nonlinear/Marginals.cpp index c29a79623..41212ed76 100644 --- a/gtsam/nonlinear/Marginals.cpp +++ b/gtsam/nonlinear/Marginals.cpp @@ -80,11 +80,15 @@ Marginals::Marginals(const GaussianFactorGraph& graph, const VectorValues& solut /* ************************************************************************* */ void Marginals::computeBayesTree() { + // The default ordering to use. + const Ordering::OrderingType defaultOrderingType = Ordering::COLAMD; // Compute BayesTree - if(factorization_ == CHOLESKY) - bayesTree_ = *graph_.eliminateMultifrontal(EliminatePreferCholesky); - else if(factorization_ == QR) - bayesTree_ = *graph_.eliminateMultifrontal(EliminateQR); + if (factorization_ == CHOLESKY) + bayesTree_ = *graph_.eliminateMultifrontal(defaultOrderingType, + EliminatePreferCholesky); + else if (factorization_ == QR) + bayesTree_ = + *graph_.eliminateMultifrontal(defaultOrderingType, EliminateQR); } /* ************************************************************************* */ diff --git a/gtsam/nonlinear/Marginals.h b/gtsam/nonlinear/Marginals.h index 9935bafdd..028545d01 100644 --- a/gtsam/nonlinear/Marginals.h +++ b/gtsam/nonlinear/Marginals.h @@ -131,17 +131,19 @@ protected: void computeBayesTree(const Ordering& ordering); public: - /** \deprecated argument order changed due to removing boost::optional */ - Marginals(const NonlinearFactorGraph& graph, const Values& solution, Factorization factorization, +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + /** @deprecated argument order changed due to removing boost::optional */ + GTSAM_DEPRECATED Marginals(const NonlinearFactorGraph& graph, const Values& solution, Factorization factorization, const Ordering& ordering) : Marginals(graph, solution, ordering, factorization) {} - /** \deprecated argument order changed due to removing boost::optional */ - Marginals(const GaussianFactorGraph& graph, const Values& solution, Factorization factorization, + /** @deprecated argument order changed due to removing boost::optional */ + GTSAM_DEPRECATED Marginals(const GaussianFactorGraph& graph, const Values& solution, Factorization factorization, const Ordering& ordering) : Marginals(graph, solution, ordering, factorization) {} - /** \deprecated argument order changed due to removing boost::optional */ - Marginals(const GaussianFactorGraph& graph, const VectorValues& solution, Factorization factorization, + /** @deprecated argument order changed due to removing boost::optional */ + GTSAM_DEPRECATED Marginals(const GaussianFactorGraph& graph, const VectorValues& solution, Factorization factorization, const Ordering& ordering) : Marginals(graph, solution, ordering, factorization) {} +#endif }; diff --git a/gtsam/nonlinear/NonlinearFactor.cpp b/gtsam/nonlinear/NonlinearFactor.cpp index 8b8d2da6c..3d572e970 100644 --- a/gtsam/nonlinear/NonlinearFactor.cpp +++ b/gtsam/nonlinear/NonlinearFactor.cpp @@ -114,7 +114,7 @@ double NoiseModelFactor::weight(const Values& c) const { if (noiseModel_) { const Vector b = unwhitenedError(c); check(noiseModel_, b.size()); - return 0.5 * noiseModel_->weight(b); + return noiseModel_->weight(b); } else return 1.0; diff --git a/gtsam/nonlinear/NonlinearFactor.h b/gtsam/nonlinear/NonlinearFactor.h index 3a59a4db3..d7061215e 100644 --- a/gtsam/nonlinear/NonlinearFactor.h +++ b/gtsam/nonlinear/NonlinearFactor.h @@ -402,7 +402,7 @@ ALIAS_X(X6, 5, 5 < sizeof...(VALUES)); * objects in non-linear manifolds (Lie groups). */ template -class NoiseModelFactorN +class GTSAM_EXPORT NoiseModelFactorN : public NoiseModelFactor, public detail::AliasX, // using X = VALUE1 public detail::AliasX1, // using X1 = VALUE1 diff --git a/gtsam/nonlinear/NonlinearFactorGraph.cpp b/gtsam/nonlinear/NonlinearFactorGraph.cpp index 8e4cf277c..dfa54f26f 100644 --- a/gtsam/nonlinear/NonlinearFactorGraph.cpp +++ b/gtsam/nonlinear/NonlinearFactorGraph.cpp @@ -33,9 +33,10 @@ # include #endif +#include #include #include -#include +#include using namespace std; @@ -46,7 +47,8 @@ template class FactorGraph; /* ************************************************************************* */ double NonlinearFactorGraph::probPrime(const Values& values) const { - return exp(-0.5 * error(values)); + // NOTE the 0.5 constant is handled by the factor error. + return exp(-error(values)); } /* ************************************************************************* */ @@ -55,9 +57,14 @@ void NonlinearFactorGraph::print(const std::string& str, const KeyFormatter& key for (size_t i = 0; i < factors_.size(); i++) { stringstream ss; ss << "Factor " << i << ": "; - if (factors_[i] != nullptr) factors_[i]->print(ss.str(), keyFormatter); - cout << endl; + if (factors_[i] != nullptr) { + factors_[i]->print(ss.str(), keyFormatter); + cout << "\n"; + } else { + cout << ss.str() << "nullptr\n"; + } } + std::cout.flush(); } /* ************************************************************************* */ @@ -81,8 +88,9 @@ void NonlinearFactorGraph::printErrors(const Values& values, const std::string& factor->print(ss.str(), keyFormatter); cout << "error = " << errorValue << "\n"; } - cout << endl; // only one "endl" at end might be faster, \n for each factor + cout << "\n"; } + std::cout.flush(); } /* ************************************************************************* */ @@ -91,89 +99,25 @@ bool NonlinearFactorGraph::equals(const NonlinearFactorGraph& other, double tol) } /* ************************************************************************* */ -void NonlinearFactorGraph::saveGraph(std::ostream &stm, const Values& values, - const GraphvizFormatting& formatting, - const KeyFormatter& keyFormatter) const -{ - stm << "graph {\n"; - stm << " size=\"" << formatting.figureWidthInches << "," << - formatting.figureHeightInches << "\";\n\n"; +void NonlinearFactorGraph::dot(std::ostream& os, const Values& values, + const KeyFormatter& keyFormatter, + const GraphvizFormatting& writer) const { + writer.graphPreamble(&os); + // Find bounds (imperative) KeySet keys = this->keys(); - - // Local utility function to extract x and y coordinates - struct { boost::optional operator()( - const Value& value, const GraphvizFormatting& graphvizFormatting) - { - Vector3 t; - if (const GenericValue* p = dynamic_cast*>(&value)) { - t << p->value().x(), p->value().y(), 0; - } else if (const GenericValue* p = dynamic_cast*>(&value)) { - t << p->value().x(), p->value().y(), 0; - } else if (const GenericValue* p = dynamic_cast*>(&value)) { - t = p->value().translation(); - } else if (const GenericValue* p = dynamic_cast*>(&value)) { - t = p->value(); - } else { - return boost::none; - } - double x, y; - switch (graphvizFormatting.paperHorizontalAxis) { - case GraphvizFormatting::X: x = t.x(); break; - case GraphvizFormatting::Y: x = t.y(); break; - case GraphvizFormatting::Z: x = t.z(); break; - case GraphvizFormatting::NEGX: x = -t.x(); break; - case GraphvizFormatting::NEGY: x = -t.y(); break; - case GraphvizFormatting::NEGZ: x = -t.z(); break; - default: throw std::runtime_error("Invalid enum value"); - } - switch (graphvizFormatting.paperVerticalAxis) { - case GraphvizFormatting::X: y = t.x(); break; - case GraphvizFormatting::Y: y = t.y(); break; - case GraphvizFormatting::Z: y = t.z(); break; - case GraphvizFormatting::NEGX: y = -t.x(); break; - case GraphvizFormatting::NEGY: y = -t.y(); break; - case GraphvizFormatting::NEGZ: y = -t.z(); break; - default: throw std::runtime_error("Invalid enum value"); - } - return Point2(x,y); - }} getXY; - - // Find bounds - double minX = numeric_limits::infinity(), maxX = -numeric_limits::infinity(); - double minY = numeric_limits::infinity(), maxY = -numeric_limits::infinity(); - for (const Key& key : keys) { - if (values.exists(key)) { - boost::optional xy = getXY(values.at(key), formatting); - if(xy) { - if(xy->x() < minX) - minX = xy->x(); - if(xy->x() > maxX) - maxX = xy->x(); - if(xy->y() < minY) - minY = xy->y(); - if(xy->y() > maxY) - maxY = xy->y(); - } - } - } + Vector2 min = writer.findBounds(values, keys); // Create nodes for each variable in the graph - for(Key key: keys){ - // Label the node with the label from the KeyFormatter - stm << " var" << key << "[label=\"" << keyFormatter(key) << "\""; - if(values.exists(key)) { - boost::optional xy = getXY(values.at(key), formatting); - if(xy) - stm << ", pos=\"" << formatting.scale*(xy->x() - minX) << "," << formatting.scale*(xy->y() - minY) << "!\""; - } - stm << "];\n"; + for (Key key : keys) { + auto position = writer.variablePos(values, min, key); + writer.drawVariable(key, keyFormatter, position, &os); } - stm << "\n"; + os << "\n"; - if (formatting.mergeSimilarFactors) { + if (writer.mergeSimilarFactors) { // Remove duplicate factors - std::set structure; + std::set structure; for (const sharedFactor& factor : factors_) { if (factor) { KeyVector factorKeys = factor->keys(); @@ -184,86 +128,41 @@ void NonlinearFactorGraph::saveGraph(std::ostream &stm, const Values& values, // Create factors and variable connections size_t i = 0; - for(const KeyVector& factorKeys: structure){ - // Make each factor a dot - stm << " factor" << i << "[label=\"\", shape=point"; - { - map::const_iterator pos = formatting.factorPositions.find(i); - if(pos != formatting.factorPositions.end()) - stm << ", pos=\"" << formatting.scale*(pos->second.x() - minX) << "," - << formatting.scale*(pos->second.y() - minY) << "!\""; - } - stm << "];\n"; - - // Make factor-variable connections - for(Key key: factorKeys) { - stm << " var" << key << "--" << "factor" << i << ";\n"; - } - - ++ i; + for (const KeyVector& factorKeys : structure) { + writer.processFactor(i++, factorKeys, keyFormatter, boost::none, &os); } } else { // Create factors and variable connections - for(size_t i = 0; i < size(); ++i) { + for (size_t i = 0; i < size(); ++i) { const NonlinearFactor::shared_ptr& factor = at(i); - // If null pointer, move on to the next - if (!factor) { - continue; - } - - if (formatting.plotFactorPoints) { - const KeyVector& keys = factor->keys(); - if (formatting.binaryEdges && keys.size() == 2) { - stm << " var" << keys[0] << "--" - << "var" << keys[1] << ";\n"; - } else { - // Make each factor a dot - stm << " factor" << i << "[label=\"\", shape=point"; - { - map::const_iterator pos = - formatting.factorPositions.find(i); - if (pos != formatting.factorPositions.end()) - stm << ", pos=\"" << formatting.scale * (pos->second.x() - minX) - << "," << formatting.scale * (pos->second.y() - minY) - << "!\""; - } - stm << "];\n"; - - // Make factor-variable connections - if (formatting.connectKeysToFactor && factor) { - for (Key key : *factor) { - stm << " var" << key << "--" - << "factor" << i << ";\n"; - } - } - } - } else { - Key k; - bool firstTime = true; - for (Key key : *this->at(i)) { - if (firstTime) { - k = key; - firstTime = false; - continue; - } - stm << " var" << key << "--" - << "var" << k << ";\n"; - k = key; - } + if (factor) { + const KeyVector& factorKeys = factor->keys(); + writer.processFactor(i, factorKeys, keyFormatter, + writer.factorPos(min, i), &os); } } } - stm << "}\n"; + os << "}\n"; + std::flush(os); } /* ************************************************************************* */ -void NonlinearFactorGraph::saveGraph( - const std::string& file, const Values& values, - const GraphvizFormatting& graphvizFormatting, - const KeyFormatter& keyFormatter) const { - std::ofstream of(file); - saveGraph(of, values, graphvizFormatting, keyFormatter); +std::string NonlinearFactorGraph::dot(const Values& values, + const KeyFormatter& keyFormatter, + const GraphvizFormatting& writer) const { + std::stringstream ss; + dot(ss, values, keyFormatter, writer); + return ss.str(); +} + +/* ************************************************************************* */ +void NonlinearFactorGraph::saveGraph(const std::string& filename, + const Values& values, + const KeyFormatter& keyFormatter, + const GraphvizFormatting& writer) const { + std::ofstream of(filename); + dot(of, values, keyFormatter, writer); of.close(); } diff --git a/gtsam/nonlinear/NonlinearFactorGraph.h b/gtsam/nonlinear/NonlinearFactorGraph.h index 4d321f8ab..3237d7c1e 100644 --- a/gtsam/nonlinear/NonlinearFactorGraph.h +++ b/gtsam/nonlinear/NonlinearFactorGraph.h @@ -23,6 +23,7 @@ #include #include +#include #include #include @@ -42,38 +43,14 @@ namespace gtsam { class ExpressionFactor; /** - * Formatting options when saving in GraphViz format using - * NonlinearFactorGraph::saveGraph. - */ - struct GTSAM_EXPORT GraphvizFormatting { - enum Axis { X, Y, Z, NEGX, NEGY, NEGZ }; ///< World axes to be assigned to paper axes - Axis paperHorizontalAxis; ///< The world axis assigned to the horizontal paper axis - Axis paperVerticalAxis; ///< The world axis assigned to the vertical paper axis - double figureWidthInches; ///< The figure width on paper in inches - double figureHeightInches; ///< The figure height on paper in inches - double scale; ///< Scale all positions to reduce / increase density - bool mergeSimilarFactors; ///< Merge multiple factors that have the same connectivity - bool plotFactorPoints; ///< Plots each factor as a dot between the variables - bool connectKeysToFactor; ///< Draw a line from each key within a factor to the dot of the factor - bool binaryEdges; ///< just use non-dotted edges for binary factors - std::map factorPositions; ///< (optional for each factor) Manually specify factor "dot" positions. - /// Default constructor sets up robot coordinates. Paper horizontal is robot Y, - /// paper vertical is robot X. Default figure size of 5x5 in. - GraphvizFormatting() : - paperHorizontalAxis(Y), paperVerticalAxis(X), - figureWidthInches(5), figureHeightInches(5), scale(1), - mergeSimilarFactors(false), plotFactorPoints(true), - connectKeysToFactor(true), binaryEdges(true) {} - }; - - - /** - * A non-linear factor graph is a graph of non-Gaussian, i.e. non-linear factors, - * which derive from NonlinearFactor. The values structures are typically (in SAM) more general - * than just vectors, e.g., Rot3 or Pose3, which are objects in non-linear manifolds. - * Linearizing the non-linear factor graph creates a linear factor graph on the - * tangent vector space at the linearization point. Because the tangent space is a true - * vector space, the config type will be an VectorValues in that linearized factor graph. + * A NonlinearFactorGraph is a graph of non-Gaussian, i.e. non-linear factors, + * which derive from NonlinearFactor. The values structures are typically (in + * SAM) more general than just vectors, e.g., Rot3 or Pose3, which are objects + * in non-linear manifolds. Linearizing the non-linear factor graph creates a + * linear factor graph on the tangent vector space at the linearization point. + * Because the tangent space is a true vector space, the config type will be + * an VectorValues in that linearized factor graph. + * @addtogroup nonlinear */ class GTSAM_EXPORT NonlinearFactorGraph: public FactorGraph { @@ -83,6 +60,9 @@ namespace gtsam { typedef NonlinearFactorGraph This; typedef boost::shared_ptr shared_ptr; + /// @name Standard Constructors + /// @{ + /** Default constructor */ NonlinearFactorGraph() {} @@ -101,6 +81,10 @@ namespace gtsam { /// Destructor virtual ~NonlinearFactorGraph() {} + /// @} + /// @name Testable + /// @{ + /** print */ void print( const std::string& str = "NonlinearFactorGraph: ", @@ -115,22 +99,11 @@ namespace gtsam { /** Test equality */ bool equals(const NonlinearFactorGraph& other, double tol = 1e-9) const; - /// Write the graph in GraphViz format for visualization - void saveGraph(std::ostream& stm, const Values& values = Values(), - const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(), - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + /// @} + /// @name Standard Interface + /// @{ - /** - * Write the graph in GraphViz format to file for visualization. - * - * This is a wrapper friendly version since wrapped languages don't have - * access to C++ streams. - */ - void saveGraph(const std::string& file, const Values& values = Values(), - const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(), - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; - - /** unnormalized error, \f$ 0.5 \sum_i (h_i(X_i)-z)^2/\sigma^2 \f$ in the most common case */ + /** unnormalized error, \f$ \sum_i 0.5 (h_i(X_i)-z)^2 / \sigma^2 \f$ in the most common case */ double error(const Values& values) const; /** Unnormalized probability. O(n) */ @@ -246,7 +219,32 @@ namespace gtsam { emplace_shared>(key, prior, covariance); } - private: + /// @} + /// @name Graph Display + /// @{ + + using FactorGraph::dot; + using FactorGraph::saveGraph; + + /// Output to graphviz format, stream version, with Values/extra options. + void dot(std::ostream& os, const Values& values, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const GraphvizFormatting& writer = GraphvizFormatting()) const; + + /// Output to graphviz format string, with Values/extra options. + std::string dot( + const Values& values, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const GraphvizFormatting& writer = GraphvizFormatting()) const; + + /// output to file with graphviz format, with Values/extra options. + void saveGraph( + const std::string& filename, const Values& values, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const GraphvizFormatting& writer = GraphvizFormatting()) const; + /// @} + + private: /** * Linearize from Scatter rather than from Ordering. Made private because @@ -265,16 +263,36 @@ namespace gtsam { public: - /** \deprecated */ - boost::shared_ptr linearizeToHessianFactor( +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + /// @name Deprecated + /// @{ + /** @deprecated */ + boost::shared_ptr GTSAM_DEPRECATED linearizeToHessianFactor( const Values& values, boost::none_t, const Dampen& dampen = nullptr) const {return linearizeToHessianFactor(values, dampen);} - /** \deprecated */ - Values updateCholesky(const Values& values, boost::none_t, + /** @deprecated */ + Values GTSAM_DEPRECATED updateCholesky(const Values& values, boost::none_t, const Dampen& dampen = nullptr) const {return updateCholesky(values, dampen);} + /** @deprecated */ + void GTSAM_DEPRECATED saveGraph( + std::ostream& os, const Values& values = Values(), + const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(), + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { + dot(os, values, keyFormatter, graphvizFormatting); + } + /** @deprecated */ + void GTSAM_DEPRECATED + saveGraph(const std::string& filename, const Values& values, + const GraphvizFormatting& graphvizFormatting, + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { + saveGraph(filename, values, keyFormatter, graphvizFormatting); + } + /// @} +#endif + }; /// traits diff --git a/gtsam/nonlinear/NonlinearOptimizer.cpp b/gtsam/nonlinear/NonlinearOptimizer.cpp index 0d7e9e17f..3ce6db4af 100644 --- a/gtsam/nonlinear/NonlinearOptimizer.cpp +++ b/gtsam/nonlinear/NonlinearOptimizer.cpp @@ -147,11 +147,13 @@ VectorValues NonlinearOptimizer::solve(const GaussianFactorGraph& gfg, } else if (params.isSequential()) { // Sequential QR or Cholesky (decided by params.getEliminationFunction()) if (params.ordering) - delta = gfg.eliminateSequential(*params.ordering, params.getEliminationFunction(), - boost::none, params.orderingType)->optimize(); + delta = gfg.eliminateSequential(*params.ordering, + params.getEliminationFunction()) + ->optimize(); else - delta = gfg.eliminateSequential(params.getEliminationFunction(), boost::none, - params.orderingType)->optimize(); + delta = gfg.eliminateSequential(params.orderingType, + params.getEliminationFunction()) + ->optimize(); } else if (params.isIterative()) { // Conjugate Gradient -> needs params.iterativeParams if (!params.iterativeParams) diff --git a/gtsam/nonlinear/Values-inl.h b/gtsam/nonlinear/Values-inl.h index 8ebdcab17..dfcb7e174 100644 --- a/gtsam/nonlinear/Values-inl.h +++ b/gtsam/nonlinear/Values-inl.h @@ -391,4 +391,10 @@ namespace gtsam { update(j, static_cast(GenericValue(val))); } + // insert_or_assign with templated value + template + void Values::insert_or_assign(Key j, const ValueType& val) { + insert_or_assign(j, static_cast(GenericValue(val))); + } + } diff --git a/gtsam/nonlinear/Values.cpp b/gtsam/nonlinear/Values.cpp index ebc9c51f6..adadc99c0 100644 --- a/gtsam/nonlinear/Values.cpp +++ b/gtsam/nonlinear/Values.cpp @@ -171,6 +171,25 @@ namespace gtsam { } } + /* ************************************************************************ */ + void Values::insert_or_assign(Key j, const Value& val) { + if (this->exists(j)) { + // If key already exists, perform an update. + this->update(j, val); + } else { + // If key does not exist, perform an insert. + this->insert(j, val); + } + } + + /* ************************************************************************ */ + void Values::insert_or_assign(const Values& values) { + for (const_iterator key_value = values.begin(); key_value != values.end(); + ++key_value) { + this->insert_or_assign(key_value->key, key_value->value); + } + } + /* ************************************************************************* */ void Values::erase(Key j) { KeyValueMap::iterator item = values_.find(j); diff --git a/gtsam/nonlinear/Values.h b/gtsam/nonlinear/Values.h index 33e9e7d82..cfe6347b5 100644 --- a/gtsam/nonlinear/Values.h +++ b/gtsam/nonlinear/Values.h @@ -24,6 +24,7 @@ #pragma once +#include #include #include #include @@ -62,17 +63,18 @@ namespace gtsam { class GTSAM_EXPORT Values { private: - // Internally we store a boost ptr_map, with a ValueCloneAllocator (defined - // below) to clone and deallocate the Value objects, and a boost - // fast_pool_allocator to allocate map nodes. In this way, all memory is - // allocated in a boost memory pool. + // below) to clone and deallocate the Value objects, and our compile-flag- + // dependent FastDefaultAllocator to allocate map nodes. In this way, the + // user defines the allocation details (i.e. optimize for memory pool/arenas + // concurrency). + typedef internal::FastDefaultAllocator>::type KeyValuePtrPairAllocator; typedef boost::ptr_map< Key, Value, std::less, ValueCloneAllocator, - boost::fast_pool_allocator > > KeyValueMap; + KeyValuePtrPairAllocator > KeyValueMap; // The member to store the values, see just above KeyValueMap values_; @@ -283,6 +285,19 @@ namespace gtsam { /** update the current available values without adding new ones */ void update(const Values& values); + /// If key j exists, update value, else perform an insert. + void insert_or_assign(Key j, const Value& val); + + /** + * Update a set of variables. + * If any variable key doe not exist, then perform an insert. + */ + void insert_or_assign(const Values& values); + + /// Templated version to insert_or_assign a variable with the given j. + template + void insert_or_assign(Key j, const ValueType& val); + /** Remove a variable from the config, throws KeyDoesNotExist if j is not present */ void erase(Key j); diff --git a/gtsam/nonlinear/nonlinear.i b/gtsam/nonlinear/nonlinear.i index ecf63094d..eedf421bc 100644 --- a/gtsam/nonlinear/nonlinear.i +++ b/gtsam/nonlinear/nonlinear.i @@ -23,117 +23,19 @@ namespace gtsam { #include #include #include -#include #include #include -class Symbol { - Symbol(); - Symbol(char c, uint64_t j); - Symbol(size_t key); +#include +class GraphvizFormatting : gtsam::DotWriter { + GraphvizFormatting(); - size_t key() const; - void print(const string& s = "") const; - bool equals(const gtsam::Symbol& expected, double tol) const; + enum Axis { X, Y, Z, NEGX, NEGY, NEGZ }; + Axis paperHorizontalAxis; + Axis paperVerticalAxis; - char chr() const; - uint64_t index() const; - string string() const; -}; - -size_t symbol(char chr, size_t index); -char symbolChr(size_t key); -size_t symbolIndex(size_t key); - -namespace symbol_shorthand { -size_t A(size_t j); -size_t B(size_t j); -size_t C(size_t j); -size_t D(size_t j); -size_t E(size_t j); -size_t F(size_t j); -size_t G(size_t j); -size_t H(size_t j); -size_t I(size_t j); -size_t J(size_t j); -size_t K(size_t j); -size_t L(size_t j); -size_t M(size_t j); -size_t N(size_t j); -size_t O(size_t j); -size_t P(size_t j); -size_t Q(size_t j); -size_t R(size_t j); -size_t S(size_t j); -size_t T(size_t j); -size_t U(size_t j); -size_t V(size_t j); -size_t W(size_t j); -size_t X(size_t j); -size_t Y(size_t j); -size_t Z(size_t j); -} // namespace symbol_shorthand - -// Default keyformatter -void PrintKeyList( - const gtsam::KeyList& keys, const string& s = "", - const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter); -void PrintKeyVector( - const gtsam::KeyVector& keys, const string& s = "", - const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter); -void PrintKeySet( - const gtsam::KeySet& keys, const string& s = "", - const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter); - -#include -class LabeledSymbol { - LabeledSymbol(size_t full_key); - LabeledSymbol(const gtsam::LabeledSymbol& key); - LabeledSymbol(unsigned char valType, unsigned char label, size_t j); - - size_t key() const; - unsigned char label() const; - unsigned char chr() const; - size_t index() const; - - gtsam::LabeledSymbol upper() const; - gtsam::LabeledSymbol lower() const; - gtsam::LabeledSymbol newChr(unsigned char c) const; - gtsam::LabeledSymbol newLabel(unsigned char label) const; - - void print(string s = "") const; -}; - -size_t mrsymbol(unsigned char c, unsigned char label, size_t j); -unsigned char mrsymbolChr(size_t key); -unsigned char mrsymbolLabel(size_t key); -size_t mrsymbolIndex(size_t key); - -#include -class Ordering { - // Standard Constructors and Named Constructors - Ordering(); - Ordering(const gtsam::Ordering& other); - - template - static gtsam::Ordering Colamd(const FACTOR_GRAPH& graph); - - // Testable - void print(string s = "", const gtsam::KeyFormatter& keyFormatter = - gtsam::DefaultKeyFormatter) const; - bool equals(const gtsam::Ordering& ord, double tol) const; - - // Standard interface - size_t size() const; - size_t at(size_t key) const; - void push_back(size_t key); - - // enabling serialization functionality - void serialize() const; - - // enable pickling in python - void pickle() const; + double scale; + bool mergeSimilarFactors; }; #include @@ -193,13 +95,17 @@ class NonlinearFactorGraph { gtsam::GaussianFactorGraph* linearize(const gtsam::Values& values) const; gtsam::NonlinearFactorGraph clone() const; + string dot( + const gtsam::Values& values, + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const GraphvizFormatting& formatting = GraphvizFormatting()); + void saveGraph( + const string& s, const gtsam::Values& values, + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const GraphvizFormatting& formatting = GraphvizFormatting()) const; + // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; - - void saveGraph(const string& s) const; }; #include @@ -275,6 +181,7 @@ class Values { void insert(const gtsam::Values& values); void update(const gtsam::Values& values); + void insert_or_assign(const gtsam::Values& values); void erase(size_t j); void swap(gtsam::Values& values); @@ -289,9 +196,6 @@ class Values { // enabling serialization functionality void serialize() const; - // enable pickling in python - void pickle() const; - // New in 4.0, we have to specialize every insert/update/at to generate // wrappers Instead of the old: void insert(size_t j, const gtsam::Value& // value); void update(size_t j, const gtsam::Value& val); gtsam::Value @@ -351,6 +255,32 @@ class Values { void update(size_t j, Matrix matrix); void update(size_t j, double c); + void insert_or_assign(size_t j, const gtsam::Point2& point2); + void insert_or_assign(size_t j, const gtsam::Point3& point3); + void insert_or_assign(size_t j, const gtsam::Rot2& rot2); + void insert_or_assign(size_t j, const gtsam::Pose2& pose2); + void insert_or_assign(size_t j, const gtsam::SO3& R); + void insert_or_assign(size_t j, const gtsam::SO4& Q); + void insert_or_assign(size_t j, const gtsam::SOn& P); + void insert_or_assign(size_t j, const gtsam::Rot3& rot3); + void insert_or_assign(size_t j, const gtsam::Pose3& pose3); + void insert_or_assign(size_t j, const gtsam::Unit3& unit3); + void insert_or_assign(size_t j, const gtsam::Cal3_S2& cal3_s2); + void insert_or_assign(size_t j, const gtsam::Cal3DS2& cal3ds2); + void insert_or_assign(size_t j, const gtsam::Cal3Bundler& cal3bundler); + void insert_or_assign(size_t j, const gtsam::Cal3Fisheye& cal3fisheye); + void insert_or_assign(size_t j, const gtsam::Cal3Unified& cal3unified); + void insert_or_assign(size_t j, const gtsam::EssentialMatrix& essential_matrix); + void insert_or_assign(size_t j, const gtsam::PinholeCamera& camera); + void insert_or_assign(size_t j, const gtsam::PinholeCamera& camera); + void insert_or_assign(size_t j, const gtsam::PinholeCamera& camera); + void insert_or_assign(size_t j, const gtsam::PinholeCamera& camera); + void insert_or_assign(size_t j, const gtsam::imuBias::ConstantBias& constant_bias); + void insert_or_assign(size_t j, const gtsam::NavState& nav_state); + void insert_or_assign(size_t j, Vector vector); + void insert_or_assign(size_t j, Matrix matrix); + void insert_or_assign(size_t j, double c); + template , gtsam::PinholeCamera, Vector, Matrix}> VALUE calculateEstimate(size_t key) const; - gtsam::Values calculateBestEstimate() const; Matrix marginalCovariance(size_t key) const; + gtsam::Values calculateBestEstimate() const; gtsam::VectorValues getDelta() const; + double error(const gtsam::VectorValues& x) const; gtsam::NonlinearFactorGraph getFactorsUnsafe() const; gtsam::VariableIndex getVariableIndex() const; + const gtsam::KeySet& getFixedVariables() const; gtsam::ISAM2Params params() const; + + void printStats() const; + gtsam::VectorValues gradientAtZero() const; + + string dot(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + void saveGraph(string s, + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; }; #include @@ -814,9 +760,6 @@ virtual class PriorFactor : gtsam::NoiseModelFactor { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include diff --git a/gtsam/geometry/tests/testUtilities.cpp b/gtsam/nonlinear/tests/testUtilities.cpp similarity index 68% rename from gtsam/geometry/tests/testUtilities.cpp rename to gtsam/nonlinear/tests/testUtilities.cpp index 25ac3acc8..55a7fdb13 100644 --- a/gtsam/geometry/tests/testUtilities.cpp +++ b/gtsam/nonlinear/tests/testUtilities.cpp @@ -21,7 +21,6 @@ #include #include #include -#include #include using namespace gtsam; @@ -55,6 +54,26 @@ TEST(Utilities, ExtractPoint3) { EXPECT_LONGS_EQUAL(2, all_points.rows()); } +/* ************************************************************************* */ +TEST(Utilities, ExtractVector) { + // Test normal case with 3 vectors and 1 non-vector (ignore non-vector) + auto values = Values(); + values.insert(X(0), (Vector(4) << 1, 2, 3, 4).finished()); + values.insert(X(2), (Vector(4) << 13, 14, 15, 16).finished()); + values.insert(X(1), (Vector(4) << 6, 7, 8, 9).finished()); + values.insert(X(3), Pose3()); + auto actual = utilities::extractVectors(values, 'x'); + auto expected = + (Matrix(3, 4) << 1, 2, 3, 4, 6, 7, 8, 9, 13, 14, 15, 16).finished(); + EXPECT(assert_equal(expected, actual)); + + // Check that mis-sized vectors fail + values.insert(X(4), (Vector(2) << 1, 2).finished()); + THROWS_EXCEPTION(utilities::extractVectors(values, 'x')); + values.update(X(4), (Vector(6) << 1, 2, 3, 4, 5, 6).finished()); + THROWS_EXCEPTION(utilities::extractVectors(values, 'x')); +} + /* ************************************************************************* */ int main() { srand(time(nullptr)); diff --git a/gtsam/nonlinear/tests/testValues.cpp b/gtsam/nonlinear/tests/testValues.cpp index b894f4816..bed2a8af9 100644 --- a/gtsam/nonlinear/tests/testValues.cpp +++ b/gtsam/nonlinear/tests/testValues.cpp @@ -172,6 +172,22 @@ TEST( Values, update_element ) CHECK(assert_equal((Vector)v2, cfg.at(key1))); } +TEST(Values, InsertOrAssign) { + Values values; + Key X(0); + double x = 1; + + CHECK(values.size() == 0); + // This should perform an insert. + values.insert_or_assign(X, x); + EXPECT(assert_equal(values.at(X), x)); + + // This should perform an update. + double y = 2; + values.insert_or_assign(X, y); + EXPECT(assert_equal(values.at(X), y)); +} + /* ************************************************************************* */ TEST(Values, basic_functions) { diff --git a/gtsam/nonlinear/utilities.h b/gtsam/nonlinear/utilities.h index fdc1da2c4..d2b38d374 100644 --- a/gtsam/nonlinear/utilities.h +++ b/gtsam/nonlinear/utilities.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -162,6 +163,34 @@ Matrix extractPose3(const Values& values) { return result; } +/// Extract all Vector values with a given symbol character into an mxn matrix, +/// where m is the number of symbols that match the character and n is the +/// dimension of the variables. If not all variables have dimension n, then a +/// runtime error will be thrown. The order of returned values are sorted by +/// the symbol. +/// For example, calling extractVector(values, 'x'), where values contains 200 +/// variables x1, x2, ..., x200 of type Vector each 5-dimensional, will return a +/// 200x5 matrix with row i containing xi. +Matrix extractVectors(const Values& values, char c) { + Values::ConstFiltered vectors = + values.filter(Symbol::ChrTest(c)); + if (vectors.size() == 0) { + return Matrix(); + } + auto dim = vectors.begin()->value.size(); + Matrix result(vectors.size(), dim); + Eigen::Index rowi = 0; + for (const auto& kv : vectors) { + if (kv.value.size() != dim) { + throw std::runtime_error( + "Tried to extract different-sized vectors into a single matrix"); + } + result.row(rowi) = kv.value; + ++rowi; + } + return result; +} + /// Perturb all Point2 values using normally distributed noise void perturbPoint2(Values& values, double sigma, int32_t seed = 42u) { noiseModel::Isotropic::shared_ptr model = diff --git a/gtsam/sfm/MFAS.h b/gtsam/sfm/MFAS.h index decfbed0f..151b318ad 100644 --- a/gtsam/sfm/MFAS.h +++ b/gtsam/sfm/MFAS.h @@ -48,7 +48,7 @@ namespace gtsam { unit translations in a projection direction. @addtogroup SFM */ -class MFAS { +class GTSAM_EXPORT MFAS { public: // used to represent edges between two nodes in the graph. When used in // translation averaging for global SfM diff --git a/gtsam/slam/EssentialMatrixFactor.h b/gtsam/slam/EssentialMatrixFactor.h index 787efac51..5997ad224 100644 --- a/gtsam/slam/EssentialMatrixFactor.h +++ b/gtsam/slam/EssentialMatrixFactor.h @@ -1,7 +1,20 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010-2014, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + /* - * @file EssentialMatrixFactor.cpp + * @file EssentialMatrixFactor.h * @brief EssentialMatrixFactor class * @author Frank Dellaert + * @author Ayush Baid + * @author Akshay Krishnan * @date December 17, 2013 */ diff --git a/gtsam/slam/ReadMe.md b/gtsam/slam/README.md similarity index 100% rename from gtsam/slam/ReadMe.md rename to gtsam/slam/README.md diff --git a/gtsam/slam/RegularImplicitSchurFactor.h b/gtsam/slam/RegularImplicitSchurFactor.h index b4a341719..340f84018 100644 --- a/gtsam/slam/RegularImplicitSchurFactor.h +++ b/gtsam/slam/RegularImplicitSchurFactor.h @@ -260,10 +260,6 @@ public: "RegularImplicitSchurFactor::clone non implemented"); } - bool empty() const override { - return false; - } - GaussianFactor::shared_ptr negate() const override { return boost::make_shared >(keys_, FBlocks_, PointCovariance_, E_, b_); diff --git a/gtsam/slam/SmartFactorBase.h b/gtsam/slam/SmartFactorBase.h index ddf56b289..209c1196d 100644 --- a/gtsam/slam/SmartFactorBase.h +++ b/gtsam/slam/SmartFactorBase.h @@ -47,7 +47,7 @@ namespace gtsam { * @tparam CAMERA should behave like a PinholeCamera. */ template -class SmartFactorBase: public NonlinearFactor { +class GTSAM_EXPORT SmartFactorBase: public NonlinearFactor { private: typedef NonlinearFactor Base; diff --git a/gtsam/slam/SmartProjectionPoseFactor.h b/gtsam/slam/SmartProjectionPoseFactor.h index c7b1d5424..3cd69c46f 100644 --- a/gtsam/slam/SmartProjectionPoseFactor.h +++ b/gtsam/slam/SmartProjectionPoseFactor.h @@ -41,11 +41,10 @@ namespace gtsam { * If the calibration should be optimized, as well, use SmartProjectionFactor instead! * @addtogroup SLAM */ -template -class SmartProjectionPoseFactor: public SmartProjectionFactor< - PinholePose > { - -private: +template +class GTSAM_EXPORT SmartProjectionPoseFactor + : public SmartProjectionFactor > { + private: typedef PinholePose Camera; typedef SmartProjectionFactor Base; typedef SmartProjectionPoseFactor This; @@ -156,7 +155,6 @@ public: ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); ar & BOOST_SERIALIZATION_NVP(K_); } - }; // end of class declaration diff --git a/gtsam/slam/SmartProjectionRigFactor.h b/gtsam/slam/SmartProjectionRigFactor.h index 8d6918b3e..149c12928 100644 --- a/gtsam/slam/SmartProjectionRigFactor.h +++ b/gtsam/slam/SmartProjectionRigFactor.h @@ -54,6 +54,8 @@ class SmartProjectionRigFactor : public SmartProjectionFactor { typedef SmartProjectionFactor Base; typedef SmartProjectionRigFactor This; typedef typename CAMERA::CalibrationType CALIBRATION; + typedef typename CAMERA::Measurement MEASUREMENT; + typedef typename CAMERA::MeasurementVector MEASUREMENTS; static const int DimPose = 6; ///< Pose3 dimension static const int ZDim = 2; ///< Measurement dimension @@ -118,7 +120,7 @@ class SmartProjectionRigFactor : public SmartProjectionFactor { * @param cameraId ID of the camera in the rig taking the measurement (default * 0) */ - void add(const Point2& measured, const Key& poseKey, + void add(const MEASUREMENT& measured, const Key& poseKey, const size_t& cameraId = 0) { // store measurement and key this->measured_.push_back(measured); @@ -144,7 +146,7 @@ class SmartProjectionRigFactor : public SmartProjectionFactor { * @param cameraIds IDs of the cameras in the rig taking each measurement * (same order as the measurements) */ - void add(const Point2Vector& measurements, const KeyVector& poseKeys, + void add(const MEASUREMENTS& measurements, const KeyVector& poseKeys, const FastVector& cameraIds = FastVector()) { if (poseKeys.size() != measurements.size() || (poseKeys.size() != cameraIds.size() && cameraIds.size() != 0)) { diff --git a/gtsam/slam/TriangulationFactor.h b/gtsam/slam/TriangulationFactor.h index f12053d29..40e9538e2 100644 --- a/gtsam/slam/TriangulationFactor.h +++ b/gtsam/slam/TriangulationFactor.h @@ -33,18 +33,18 @@ class TriangulationFactor: public NoiseModelFactor1 { public: /// CAMERA type - typedef CAMERA Camera; + using Camera = CAMERA; protected: /// shorthand for base class type - typedef NoiseModelFactor1 Base; + using Base = NoiseModelFactor1; /// shorthand for this class - typedef TriangulationFactor This; + using This = TriangulationFactor; /// shorthand for measurement type, e.g. Point2 or StereoPoint2 - typedef typename CAMERA::Measurement Measurement; + using Measurement = typename CAMERA::Measurement; // Keep a copy of measurement and calibration for I/O const CAMERA camera_; ///< CAMERA in which this landmark was seen @@ -55,9 +55,10 @@ protected: const bool verboseCheirality_; ///< If true, prints text for Cheirality exceptions (default: false) public: + EIGEN_MAKE_ALIGNED_OPERATOR_NEW /// shorthand for a smart pointer to a factor - typedef boost::shared_ptr shared_ptr; + using shared_ptr = boost::shared_ptr; /// Default constructor TriangulationFactor() : @@ -129,7 +130,7 @@ public: << std::endl; if (throwCheirality_) throw e; - return Eigen::Matrix::dimension,1>::Constant(2.0 * camera_.calibration().fx()); + return camera_.defaultErrorWhenTriangulatingBehindCamera(); } } diff --git a/gtsam/slam/dataset.cpp b/gtsam/slam/dataset.cpp index c8a8b15c5..0684063de 100644 --- a/gtsam/slam/dataset.cpp +++ b/gtsam/slam/dataset.cpp @@ -392,7 +392,7 @@ parseMeasurements(const std::string &filename, size_t maxIndex) { ParseMeasurement parse{model ? createSampler(model) : nullptr, maxIndex, true, NoiseFormatAUTO, - KernelFunctionTypeNONE}; + KernelFunctionTypeNONE, nullptr}; return parseToVector>(filename, parse); } @@ -1304,14 +1304,14 @@ parse3DFactors(const std::string &filename, return parseFactors(filename, model, maxIndex); } -#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V41 -std::map parse3DPoses(const std::string &filename, - size_t maxIndex) { +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 +std::map GTSAM_DEPRECATED +parse3DPoses(const std::string &filename, size_t maxIndex) { return parseVariables(filename, maxIndex); } -std::map parse3DLandmarks(const std::string &filename, - size_t maxIndex) { +std::map GTSAM_DEPRECATED +parse3DLandmarks(const std::string &filename, size_t maxIndex) { return parseVariables(filename, maxIndex); } #endif diff --git a/gtsam/slam/dataset.h b/gtsam/slam/dataset.h index ec5d6dce9..db5d7d76a 100644 --- a/gtsam/slam/dataset.h +++ b/gtsam/slam/dataset.h @@ -172,10 +172,6 @@ GTSAM_EXPORT GraphAndValues load2D(const std::string& filename, false, bool smart = true, NoiseFormat noiseFormat = NoiseFormatAUTO, // KernelFunctionType kernelFunctionType = KernelFunctionTypeNONE); -/// @deprecated load2D now allows for arbitrary models and wrapping a robust kernel -GTSAM_EXPORT GraphAndValues load2D_robust(const std::string& filename, - const noiseModel::Base::shared_ptr& model, size_t maxIndex = 0); - /** save 2d graph */ GTSAM_EXPORT void save2D(const NonlinearFactorGraph& graph, const Values& config, const noiseModel::Diagonal::shared_ptr model, @@ -504,17 +500,21 @@ parse3DFactors(const std::string &filename, size_t maxIndex = 0); using BinaryMeasurementsUnit3 = std::vector>; -#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V41 -inline boost::optional parseVertex(std::istream &is, - const std::string &tag) { + +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 +inline boost::optional GTSAM_DEPRECATED +parseVertex(std::istream& is, const std::string& tag) { return parseVertexPose(is, tag); } -GTSAM_EXPORT std::map parse3DPoses(const std::string &filename, - size_t maxIndex = 0); +GTSAM_EXPORT std::map GTSAM_DEPRECATED +parse3DPoses(const std::string& filename, size_t maxIndex = 0); -GTSAM_EXPORT std::map -parse3DLandmarks(const std::string &filename, size_t maxIndex = 0); +GTSAM_EXPORT std::map GTSAM_DEPRECATED +parse3DLandmarks(const std::string& filename, size_t maxIndex = 0); +GTSAM_EXPORT GraphAndValues GTSAM_DEPRECATED +load2D_robust(const std::string& filename, + const noiseModel::Base::shared_ptr& model, size_t maxIndex = 0); #endif } // namespace gtsam diff --git a/gtsam/slam/slam.i b/gtsam/slam/slam.i index 60000dbab..e044dd2c1 100644 --- a/gtsam/slam/slam.i +++ b/gtsam/slam/slam.i @@ -11,7 +11,7 @@ namespace gtsam { // ###### #include -template virtual class BetweenFactor : gtsam::NoiseModelFactor { @@ -21,9 +21,6 @@ virtual class BetweenFactor : gtsam::NoiseModelFactor { // enabling serialization functionality void serialize() const; - - // enable pickling in python - void pickle() const; }; #include @@ -168,6 +165,10 @@ template virtual class PoseTranslationPrior : gtsam::NoiseModelFactor { PoseTranslationPrior(size_t key, const POSE& pose_z, const gtsam::noiseModel::Base* noiseModel); + POSE::Translation measured() const; + + // enabling serialization functionality + void serialize() const; }; typedef gtsam::PoseTranslationPrior PoseTranslationPrior2D; @@ -178,6 +179,7 @@ template virtual class PoseRotationPrior : gtsam::NoiseModelFactor { PoseRotationPrior(size_t key, const POSE& pose_z, const gtsam::noiseModel::Base* noiseModel); + POSE::Rotation measured() const; }; typedef gtsam::PoseRotationPrior PoseRotationPrior2D; @@ -188,6 +190,21 @@ virtual class EssentialMatrixFactor : gtsam::NoiseModelFactor { EssentialMatrixFactor(size_t key, const gtsam::Point2& pA, const gtsam::Point2& pB, const gtsam::noiseModel::Base* noiseModel); + void print(string s = "", const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::EssentialMatrixFactor& other, double tol) const; + Vector evaluateError(const gtsam::EssentialMatrix& E) const; +}; + +#include +virtual class EssentialMatrixConstraint : gtsam::NoiseModelFactor { + EssentialMatrixConstraint(size_t key1, size_t key2, const gtsam::EssentialMatrix &measuredE, + const gtsam::noiseModel::Base *model); + void print(string s = "", const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::EssentialMatrixConstraint& other, double tol) const; + Vector evaluateError(const gtsam::Pose3& p1, const gtsam::Pose3& p2) const; + const gtsam::EssentialMatrix& measured() const; }; #include @@ -211,9 +228,6 @@ class SfmTrack { // enabling serialization functionality void serialize() const; - // enable pickling in python - void pickle() const; - // enabling function to compare objects bool equals(const gtsam::SfmTrack& expected, double tol) const; }; @@ -230,9 +244,6 @@ class SfmData { // enabling serialization functionality void serialize() const; - // enable pickling in python - void pickle() const; - // enabling function to compare objects bool equals(const gtsam::SfmData& expected, double tol) const; }; @@ -253,8 +264,6 @@ pair load2D( pair load2D( string filename, gtsam::noiseModel::Diagonal* model); pair load2D(string filename); -pair load2D_robust( - string filename, gtsam::noiseModel::Base* model, int maxIndex); void save2D(const gtsam::NonlinearFactorGraph& graph, const gtsam::Values& config, gtsam::noiseModel::Diagonal* model, string filename); @@ -314,6 +323,8 @@ virtual class KarcherMeanFactor : gtsam::NonlinearFactor { KarcherMeanFactor(const gtsam::KeyVector& keys); }; +gtsam::Rot3 FindKarcherMean(const gtsam::Rot3Vector& rotations); + #include gtsam::noiseModel::Isotropic* ConvertNoiseModel(gtsam::noiseModel::Base* model, size_t d); diff --git a/gtsam/slam/tests/smartFactorScenarios.h b/gtsam/slam/tests/smartFactorScenarios.h index b17ffdac6..66be08c67 100644 --- a/gtsam/slam/tests/smartFactorScenarios.h +++ b/gtsam/slam/tests/smartFactorScenarios.h @@ -17,11 +17,13 @@ */ #pragma once -#include -#include -#include -#include #include +#include +#include +#include +#include +#include + #include "../SmartProjectionRigFactor.h" using namespace std; @@ -44,7 +46,7 @@ Pose3 pose_above = level_pose * Pose3(Rot3(), Point3(0, -1, 0)); // Create a noise unit2 for the pixel error static SharedNoiseModel unit2(noiseModel::Unit::Create(2)); -static double fov = 60; // degrees +static double fov = 60; // degrees static size_t w = 640, h = 480; /* ************************************************************************* */ @@ -63,7 +65,7 @@ Camera cam2(pose_right, K2); Camera cam3(pose_above, K2); typedef GeneralSFMFactor SFMFactor; SmartProjectionParams params; -} +} // namespace vanilla /* ************************************************************************* */ // default Cal3_S2 poses @@ -78,7 +80,7 @@ Camera level_camera_right(pose_right, sharedK); Camera cam1(level_pose, sharedK); Camera cam2(pose_right, sharedK); Camera cam3(pose_above, sharedK); -} +} // namespace vanillaPose /* ************************************************************************* */ // default Cal3_S2 poses @@ -93,7 +95,7 @@ Camera level_camera_right(pose_right, sharedK2); Camera cam1(level_pose, sharedK2); Camera cam2(pose_right, sharedK2); Camera cam3(pose_above, sharedK2); -} +} // namespace vanillaPose2 /* *************************************************************************/ // Cal3Bundler cameras @@ -111,7 +113,8 @@ Camera cam1(level_pose, K); Camera cam2(pose_right, K); Camera cam3(pose_above, K); typedef GeneralSFMFactor SFMFactor; -} +} // namespace bundler + /* *************************************************************************/ // Cal3Bundler poses namespace bundlerPose { @@ -119,35 +122,50 @@ typedef PinholePose Camera; typedef CameraSet Cameras; typedef SmartProjectionPoseFactor SmartFactor; typedef SmartProjectionRigFactor SmartRigFactor; -static boost::shared_ptr sharedBundlerK( - new Cal3Bundler(500, 1e-3, 1e-3, 1000, 2000)); +static boost::shared_ptr sharedBundlerK(new Cal3Bundler(500, 1e-3, + 1e-3, 1000, + 2000)); Camera level_camera(level_pose, sharedBundlerK); Camera level_camera_right(pose_right, sharedBundlerK); Camera cam1(level_pose, sharedBundlerK); Camera cam2(pose_right, sharedBundlerK); Camera cam3(pose_above, sharedBundlerK); -} +} // namespace bundlerPose + +/* ************************************************************************* */ +// sphericalCamera +namespace sphericalCamera { +typedef SphericalCamera Camera; +typedef CameraSet Cameras; +typedef SmartProjectionRigFactor SmartFactorP; +static EmptyCal::shared_ptr emptyK(new EmptyCal()); +Camera level_camera(level_pose); +Camera level_camera_right(pose_right); +Camera cam1(level_pose); +Camera cam2(pose_right); +Camera cam3(pose_above); +} // namespace sphericalCamera /* *************************************************************************/ -template +template CAMERA perturbCameraPose(const CAMERA& camera) { - Pose3 noise_pose = Pose3(Rot3::Ypr(-M_PI / 10, 0., -M_PI / 10), - Point3(0.5, 0.1, 0.3)); + Pose3 noise_pose = + Pose3(Rot3::Ypr(-M_PI / 10, 0., -M_PI / 10), Point3(0.5, 0.1, 0.3)); Pose3 cameraPose = camera.pose(); Pose3 perturbedCameraPose = cameraPose.compose(noise_pose); return CAMERA(perturbedCameraPose, camera.calibration()); } -template -void projectToMultipleCameras(const CAMERA& cam1, const CAMERA& cam2, - const CAMERA& cam3, Point3 landmark, typename CAMERA::MeasurementVector& measurements_cam) { - Point2 cam1_uv1 = cam1.project(landmark); - Point2 cam2_uv1 = cam2.project(landmark); - Point2 cam3_uv1 = cam3.project(landmark); +template +void projectToMultipleCameras( + const CAMERA& cam1, const CAMERA& cam2, const CAMERA& cam3, Point3 landmark, + typename CAMERA::MeasurementVector& measurements_cam) { + typename CAMERA::Measurement cam1_uv1 = cam1.project(landmark); + typename CAMERA::Measurement cam2_uv1 = cam2.project(landmark); + typename CAMERA::Measurement cam3_uv1 = cam3.project(landmark); measurements_cam.push_back(cam1_uv1); measurements_cam.push_back(cam2_uv1); measurements_cam.push_back(cam3_uv1); } /* ************************************************************************* */ - diff --git a/gtsam/slam/tests/testEssentialMatrixConstraint.cpp b/gtsam/slam/tests/testEssentialMatrixConstraint.cpp index 080239b35..2faac24d1 100644 --- a/gtsam/slam/tests/testEssentialMatrixConstraint.cpp +++ b/gtsam/slam/tests/testEssentialMatrixConstraint.cpp @@ -10,7 +10,7 @@ * -------------------------------------------------------------------------- */ /** - * @file testEssentialMatrixConstraint.cpp + * @file TestEssentialMatrixConstraint.cpp * @brief Unit tests for EssentialMatrixConstraint Class * @author Frank Dellaert * @author Pablo Alcantarilla diff --git a/gtsam/slam/tests/testSmartProjectionRigFactor.cpp b/gtsam/slam/tests/testSmartProjectionRigFactor.cpp index b8150a1aa..b4876b27e 100644 --- a/gtsam/slam/tests/testSmartProjectionRigFactor.cpp +++ b/gtsam/slam/tests/testSmartProjectionRigFactor.cpp @@ -55,8 +55,6 @@ Key cameraId3 = 2; static Point2 measurement1(323.0, 240.0); LevenbergMarquardtParams lmParams; -// Make more verbose like so (in tests): -// params.verbosityLM = LevenbergMarquardtParams::SUMMARY; /* ************************************************************************* */ // default Cal3_S2 poses with rolling shutter effect @@ -1187,10 +1185,9 @@ TEST(SmartProjectionRigFactor, optimization_3poses_measurementsFromSamePose) { // this factor is slightly slower (but comparable) to original // SmartProjectionPoseFactor //-Total: 0 CPU (0 times, 0 wall, 0.17 children, min: 0 max: 0) -//| -SmartRigFactor LINEARIZE: 0.06 CPU -//(10000 times, 0.061226 wall, 0.06 children, min: 0 max: 0) -//| -SmartPoseFactor LINEARIZE: 0.06 CPU -//(10000 times, 0.073037 wall, 0.06 children, min: 0 max: 0) +//| -SmartRigFactor LINEARIZE: 0.05 CPU (10000 times, 0.057952 wall, 0.05 +// children, min: 0 max: 0) | -SmartPoseFactor LINEARIZE: 0.05 CPU (10000 +// times, 0.069647 wall, 0.05 children, min: 0 max: 0) /* *************************************************************************/ TEST(SmartProjectionRigFactor, timing) { using namespace vanillaRig; @@ -1249,6 +1246,355 @@ TEST(SmartProjectionRigFactor, timing) { } #endif +/* *************************************************************************/ +TEST(SmartProjectionFactorP, optimization_3poses_sphericalCamera) { + using namespace sphericalCamera; + Camera::MeasurementVector measurements_lmk1, measurements_lmk2, + measurements_lmk3; + + // Project three landmarks into three cameras + projectToMultipleCameras(cam1, cam2, cam3, landmark1, + measurements_lmk1); + projectToMultipleCameras(cam1, cam2, cam3, landmark2, + measurements_lmk2); + projectToMultipleCameras(cam1, cam2, cam3, landmark3, + measurements_lmk3); + + // create inputs + KeyVector keys; + keys.push_back(x1); + keys.push_back(x2); + keys.push_back(x3); + + boost::shared_ptr cameraRig(new Cameras()); + cameraRig->push_back(Camera(Pose3::identity(), emptyK)); + + SmartProjectionParams params( + gtsam::HESSIAN, + gtsam::ZERO_ON_DEGENERACY); // only config that works with rig factors + params.setRankTolerance(0.1); + + SmartFactorP::shared_ptr smartFactor1( + new SmartFactorP(model, cameraRig, params)); + smartFactor1->add(measurements_lmk1, keys); + + SmartFactorP::shared_ptr smartFactor2( + new SmartFactorP(model, cameraRig, params)); + smartFactor2->add(measurements_lmk2, keys); + + SmartFactorP::shared_ptr smartFactor3( + new SmartFactorP(model, cameraRig, params)); + smartFactor3->add(measurements_lmk3, keys); + + const SharedDiagonal noisePrior = noiseModel::Isotropic::Sigma(6, 0.10); + + NonlinearFactorGraph graph; + graph.push_back(smartFactor1); + graph.push_back(smartFactor2); + graph.push_back(smartFactor3); + graph.addPrior(x1, level_pose, noisePrior); + graph.addPrior(x2, pose_right, noisePrior); + + Values groundTruth; + groundTruth.insert(x1, level_pose); + groundTruth.insert(x2, pose_right); + groundTruth.insert(x3, pose_above); + DOUBLES_EQUAL(0, graph.error(groundTruth), 1e-9); + + Pose3 noise_pose = Pose3(Rot3::Ypr(-M_PI / 10, 0., -M_PI / 100), + Point3(0.2, 0.2, 0.2)); // note: larger noise! + + Values values; + values.insert(x1, level_pose); + values.insert(x2, pose_right); + // initialize third pose with some noise, we expect it to move back to + // original pose_above + values.insert(x3, pose_above * noise_pose); + + DOUBLES_EQUAL(0.94148963675515274, graph.error(values), 1e-9); + + Values result; + LevenbergMarquardtOptimizer optimizer(graph, values, lmParams); + result = optimizer.optimize(); + + EXPECT(assert_equal(pose_above, result.at(x3), 1e-5)); +} + +#ifndef DISABLE_TIMING +#include +// using spherical camera is slightly slower (but comparable) to +// PinholePose +//| -SmartFactorP spherical LINEARIZE: 0.01 CPU (1000 times, 0.008178 wall, +// 0.01 children, min: 0 max: 0) | -SmartFactorP pinhole LINEARIZE: 0.01 CPU +//(1000 times, 0.005717 wall, 0.01 children, min: 0 max: 0) +/* *************************************************************************/ +TEST(SmartProjectionFactorP, timing_sphericalCamera) { + // create common data + Rot3 R = Rot3::identity(); + Pose3 pose1 = Pose3(R, Point3(0, 0, 0)); + Pose3 pose2 = Pose3(R, Point3(1, 0, 0)); + Pose3 body_P_sensorId = Pose3::identity(); + Point3 landmark1(0, 0, 10); + + // create spherical data + EmptyCal::shared_ptr emptyK; + SphericalCamera cam1_sphere(pose1, emptyK), cam2_sphere(pose2, emptyK); + // Project 2 landmarks into 2 cameras + std::vector measurements_lmk1_sphere; + measurements_lmk1_sphere.push_back(cam1_sphere.project(landmark1)); + measurements_lmk1_sphere.push_back(cam2_sphere.project(landmark1)); + + // create Cal3_S2 data + static Cal3_S2::shared_ptr sharedKSimple(new Cal3_S2(100, 100, 0, 0, 0)); + PinholePose cam1(pose1, sharedKSimple), cam2(pose2, sharedKSimple); + // Project 2 landmarks into 2 cameras + std::vector measurements_lmk1; + measurements_lmk1.push_back(cam1.project(landmark1)); + measurements_lmk1.push_back(cam2.project(landmark1)); + + SmartProjectionParams params( + gtsam::HESSIAN, + gtsam::ZERO_ON_DEGENERACY); // only config that works with rig factors + + size_t nrTests = 1000; + + for (size_t i = 0; i < nrTests; i++) { + boost::shared_ptr> cameraRig( + new CameraSet()); // single camera in the rig + cameraRig->push_back(SphericalCamera(body_P_sensorId, emptyK)); + + SmartProjectionRigFactor::shared_ptr smartFactorP( + new SmartProjectionRigFactor(model, cameraRig, + params)); + smartFactorP->add(measurements_lmk1_sphere[0], x1); + smartFactorP->add(measurements_lmk1_sphere[1], x1); + + Values values; + values.insert(x1, pose1); + values.insert(x2, pose2); + gttic_(SmartFactorP_spherical_LINEARIZE); + smartFactorP->linearize(values); + gttoc_(SmartFactorP_spherical_LINEARIZE); + } + + for (size_t i = 0; i < nrTests; i++) { + boost::shared_ptr>> cameraRig( + new CameraSet>()); // single camera in the rig + cameraRig->push_back(PinholePose(body_P_sensorId, sharedKSimple)); + + SmartProjectionRigFactor>::shared_ptr smartFactorP2( + new SmartProjectionRigFactor>(model, cameraRig, + params)); + smartFactorP2->add(measurements_lmk1[0], x1); + smartFactorP2->add(measurements_lmk1[1], x1); + + Values values; + values.insert(x1, pose1); + values.insert(x2, pose2); + gttic_(SmartFactorP_pinhole_LINEARIZE); + smartFactorP2->linearize(values); + gttoc_(SmartFactorP_pinhole_LINEARIZE); + } + tictoc_print_(); +} +#endif + +/* *************************************************************************/ +TEST(SmartProjectionFactorP, 2poses_rankTol) { + Pose3 poseA = Pose3( + Rot3::Ypr(-M_PI / 2, 0., -M_PI / 2), + Point3(0.0, 0.0, 0.0)); // with z pointing along x axis of global frame + Pose3 poseB = Pose3(Rot3::Ypr(-M_PI / 2, 0., -M_PI / 2), + Point3(0.0, -0.1, 0.0)); // 10cm to the right of poseA + Point3 landmarkL = Point3(5.0, 0.0, 0.0); // 5m in front of poseA + + // triangulate from a stereo with 10cm baseline, assuming standard calibration + { // default rankTol = 1 gives a valid point (compare with calibrated and + // spherical cameras below) + using namespace vanillaPose; // pinhole with Cal3_S2 calibration + + Camera cam1(poseA, sharedK); + Camera cam2(poseB, sharedK); + + SmartProjectionParams params( + gtsam::HESSIAN, + gtsam::ZERO_ON_DEGENERACY); // only config that works with rig factors + params.setRankTolerance(1); + + boost::shared_ptr>> cameraRig( + new CameraSet>()); // single camera in the rig + cameraRig->push_back(PinholePose(Pose3::identity(), sharedK)); + + SmartRigFactor::shared_ptr smartFactor1( + new SmartRigFactor(model, cameraRig, params)); + smartFactor1->add(cam1.project(landmarkL), x1); + smartFactor1->add(cam2.project(landmarkL), x2); + + NonlinearFactorGraph graph; + graph.push_back(smartFactor1); + + Values groundTruth; + groundTruth.insert(x1, poseA); + groundTruth.insert(x2, poseB); + DOUBLES_EQUAL(0, graph.error(groundTruth), 1e-9); + + // get point + TriangulationResult point = smartFactor1->point(); + EXPECT(point.valid()); // valid triangulation + EXPECT(assert_equal(landmarkL, *point, 1e-7)); + } + // triangulate from a stereo with 10cm baseline, assuming canonical + // calibration + { // default rankTol = 1 or 0.1 gives a degenerate point, which is + // undesirable for a point 5m away and 10cm baseline + using namespace vanillaPose; // pinhole with Cal3_S2 calibration + static Cal3_S2::shared_ptr canonicalK( + new Cal3_S2(1.0, 1.0, 0.0, 0.0, 0.0)); // canonical camera + + Camera cam1(poseA, canonicalK); + Camera cam2(poseB, canonicalK); + + SmartProjectionParams params( + gtsam::HESSIAN, + gtsam::ZERO_ON_DEGENERACY); // only config that works with rig factors + params.setRankTolerance(0.1); + + boost::shared_ptr>> cameraRig( + new CameraSet>()); // single camera in the rig + cameraRig->push_back(PinholePose(Pose3::identity(), canonicalK)); + + SmartRigFactor::shared_ptr smartFactor1( + new SmartRigFactor(model, cameraRig, params)); + smartFactor1->add(cam1.project(landmarkL), x1); + smartFactor1->add(cam2.project(landmarkL), x2); + + NonlinearFactorGraph graph; + graph.push_back(smartFactor1); + + Values groundTruth; + groundTruth.insert(x1, poseA); + groundTruth.insert(x2, poseB); + DOUBLES_EQUAL(0, graph.error(groundTruth), 1e-9); + + // get point + TriangulationResult point = smartFactor1->point(); + EXPECT(point.degenerate()); // valid triangulation + } + // triangulate from a stereo with 10cm baseline, assuming canonical + // calibration + { // smaller rankTol = 0.01 gives a valid point (compare with calibrated and + // spherical cameras below) + using namespace vanillaPose; // pinhole with Cal3_S2 calibration + static Cal3_S2::shared_ptr canonicalK( + new Cal3_S2(1.0, 1.0, 0.0, 0.0, 0.0)); // canonical camera + + Camera cam1(poseA, canonicalK); + Camera cam2(poseB, canonicalK); + + SmartProjectionParams params( + gtsam::HESSIAN, + gtsam::ZERO_ON_DEGENERACY); // only config that works with rig factors + params.setRankTolerance(0.01); + + boost::shared_ptr>> cameraRig( + new CameraSet>()); // single camera in the rig + cameraRig->push_back(PinholePose(Pose3::identity(), canonicalK)); + + SmartRigFactor::shared_ptr smartFactor1( + new SmartRigFactor(model, cameraRig, params)); + smartFactor1->add(cam1.project(landmarkL), x1); + smartFactor1->add(cam2.project(landmarkL), x2); + + NonlinearFactorGraph graph; + graph.push_back(smartFactor1); + + Values groundTruth; + groundTruth.insert(x1, poseA); + groundTruth.insert(x2, poseB); + DOUBLES_EQUAL(0, graph.error(groundTruth), 1e-9); + + // get point + TriangulationResult point = smartFactor1->point(); + EXPECT(point.valid()); // valid triangulation + EXPECT(assert_equal(landmarkL, *point, 1e-7)); + } +} + +/* *************************************************************************/ +TEST(SmartProjectionFactorP, 2poses_sphericalCamera_rankTol) { + typedef SphericalCamera Camera; + typedef SmartProjectionRigFactor SmartRigFactor; + EmptyCal::shared_ptr emptyK(new EmptyCal()); + Pose3 poseA = Pose3( + Rot3::Ypr(-M_PI / 2, 0., -M_PI / 2), + Point3(0.0, 0.0, 0.0)); // with z pointing along x axis of global frame + Pose3 poseB = Pose3(Rot3::Ypr(-M_PI / 2, 0., -M_PI / 2), + Point3(0.0, -0.1, 0.0)); // 10cm to the right of poseA + Point3 landmarkL = Point3(5.0, 0.0, 0.0); // 5m in front of poseA + + Camera cam1(poseA); + Camera cam2(poseB); + + boost::shared_ptr> cameraRig( + new CameraSet()); // single camera in the rig + cameraRig->push_back(SphericalCamera(Pose3::identity(), emptyK)); + + // TRIANGULATION TEST WITH DEFAULT RANK TOL + { // rankTol = 1 or 0.1 gives a degenerate point, which is undesirable for a + // point 5m away and 10cm baseline + SmartProjectionParams params( + gtsam::HESSIAN, + gtsam::ZERO_ON_DEGENERACY); // only config that works with rig factors + params.setRankTolerance(0.1); + + SmartRigFactor::shared_ptr smartFactor1( + new SmartRigFactor(model, cameraRig, params)); + smartFactor1->add(cam1.project(landmarkL), x1); + smartFactor1->add(cam2.project(landmarkL), x2); + + NonlinearFactorGraph graph; + graph.push_back(smartFactor1); + + Values groundTruth; + groundTruth.insert(x1, poseA); + groundTruth.insert(x2, poseB); + DOUBLES_EQUAL(0, graph.error(groundTruth), 1e-9); + + // get point + TriangulationResult point = smartFactor1->point(); + EXPECT(point.degenerate()); // not enough parallax + } + // SAME TEST WITH SMALLER RANK TOL + { // rankTol = 0.01 gives a valid point + // By playing with this test, we can show we can triangulate also with a + // baseline of 5cm (even for points far away, >100m), but the test fails + // when the baseline becomes 1cm. This suggests using rankTol = 0.01 and + // setting a reasonable max landmark distance to obtain best results. + SmartProjectionParams params( + gtsam::HESSIAN, + gtsam::ZERO_ON_DEGENERACY); // only config that works with rig factors + params.setRankTolerance(0.01); + + SmartRigFactor::shared_ptr smartFactor1( + new SmartRigFactor(model, cameraRig, params)); + smartFactor1->add(cam1.project(landmarkL), x1); + smartFactor1->add(cam2.project(landmarkL), x2); + + NonlinearFactorGraph graph; + graph.push_back(smartFactor1); + + Values groundTruth; + groundTruth.insert(x1, poseA); + groundTruth.insert(x2, poseB); + DOUBLES_EQUAL(0, graph.error(groundTruth), 1e-9); + + // get point + TriangulationResult point = smartFactor1->point(); + EXPECT(point.valid()); // valid triangulation + EXPECT(assert_equal(landmarkL, *point, 1e-7)); + } +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/symbolic/SymbolicBayesNet.cpp b/gtsam/symbolic/SymbolicBayesNet.cpp index 5bc20ad12..f7113b23a 100644 --- a/gtsam/symbolic/SymbolicBayesNet.cpp +++ b/gtsam/symbolic/SymbolicBayesNet.cpp @@ -16,41 +16,16 @@ * @author Richard Roberts */ -#include -#include #include - -#include -#include +#include namespace gtsam { - // Instantiate base class - template class FactorGraph; - - /* ************************************************************************* */ - bool SymbolicBayesNet::equals(const This& bn, double tol) const - { - return Base::equals(bn, tol); - } - - /* ************************************************************************* */ - void SymbolicBayesNet::saveGraph(const std::string &s, const KeyFormatter& keyFormatter) const - { - std::ofstream of(s.c_str()); - of << "digraph G{\n"; - - for (auto conditional: boost::adaptors::reverse(*this)) { - SymbolicConditional::Frontals frontals = conditional->frontals(); - Key me = frontals.front(); - SymbolicConditional::Parents parents = conditional->parents(); - for(Key p: parents) - of << p << "->" << me << std::endl; - } - - of << "}"; - of.close(); - } - +// Instantiate base class +template class FactorGraph; +/* ************************************************************************* */ +bool SymbolicBayesNet::equals(const This& bn, double tol) const { + return Base::equals(bn, tol); } +} // namespace gtsam diff --git a/gtsam/symbolic/SymbolicBayesNet.h b/gtsam/symbolic/SymbolicBayesNet.h index 464af060b..2f66b80e2 100644 --- a/gtsam/symbolic/SymbolicBayesNet.h +++ b/gtsam/symbolic/SymbolicBayesNet.h @@ -19,19 +19,19 @@ #pragma once #include +#include #include #include namespace gtsam { - /** Symbolic Bayes Net - * \nosubgrouping + /** + * A SymbolicBayesNet is a Bayes Net of purely symbolic conditionals. + * @addtogroup symbolic */ - class SymbolicBayesNet : public FactorGraph { - - public: - - typedef FactorGraph Base; + class SymbolicBayesNet : public BayesNet { + public: + typedef BayesNet Base; typedef SymbolicBayesNet This; typedef SymbolicConditional ConditionalType; typedef boost::shared_ptr shared_ptr; @@ -44,16 +44,21 @@ namespace gtsam { SymbolicBayesNet() {} /** Construct from iterator over conditionals */ - template - SymbolicBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {} + template + SymbolicBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) + : Base(firstConditional, lastConditional) {} /** Construct from container of factors (shared_ptr or plain objects) */ - template - explicit SymbolicBayesNet(const CONTAINER& conditionals) : Base(conditionals) {} + template + explicit SymbolicBayesNet(const CONTAINER& conditionals) { + push_back(conditionals); + } - /** Implicit copy/downcast constructor to override explicit template container constructor */ - template - SymbolicBayesNet(const FactorGraph& graph) : Base(graph) {} + /** Implicit copy/downcast constructor to override explicit template + * container constructor */ + template + explicit SymbolicBayesNet(const FactorGraph& graph) + : Base(graph) {} /// Destructor virtual ~SymbolicBayesNet() {} @@ -75,13 +80,6 @@ namespace gtsam { /// @} - /// @name Standard Interface - /// @{ - - GTSAM_EXPORT void saveGraph(const std::string &s, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; - - /// @} - private: /** Serialization function */ friend class boost::serialization::access; diff --git a/gtsam/symbolic/SymbolicFactor.h b/gtsam/symbolic/SymbolicFactor.h index 2a488a4da..767998d22 100644 --- a/gtsam/symbolic/SymbolicFactor.h +++ b/gtsam/symbolic/SymbolicFactor.h @@ -144,9 +144,6 @@ namespace gtsam { /// @name Standard Interface /// @{ - /** Whether the factor is empty (involves zero variables). */ - bool empty() const { return keys_.empty(); } - /** Eliminate the variables in \c keys, in the order specified in \c keys, returning a * conditional and marginal. */ std::pair, boost::shared_ptr > diff --git a/gtsam/symbolic/symbolic.i b/gtsam/symbolic/symbolic.i index 4e7cca68a..1f1d4b48f 100644 --- a/gtsam/symbolic/symbolic.i +++ b/gtsam/symbolic/symbolic.i @@ -3,11 +3,6 @@ //************************************************************************* namespace gtsam { -#include -#include - -// ################### - #include virtual class SymbolicFactor { // Standard Constructors and Named Constructors @@ -82,6 +77,14 @@ virtual class SymbolicFactorGraph { const gtsam::KeyVector& key_vector, const gtsam::Ordering& marginalizedVariableOrdering); gtsam::SymbolicFactorGraph* marginal(const gtsam::KeyVector& key_vector); + + string dot( + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; + void saveGraph( + string s, + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; }; #include @@ -103,6 +106,7 @@ virtual class SymbolicConditional : gtsam::SymbolicFactor { bool equals(const gtsam::SymbolicConditional& other, double tol) const; // Standard interface + gtsam::Key firstFrontalKey() const; size_t nrFrontals() const; size_t nrParents() const; }; @@ -125,6 +129,14 @@ class SymbolicBayesNet { gtsam::SymbolicConditional* back() const; void push_back(gtsam::SymbolicConditional* conditional); void push_back(const gtsam::SymbolicBayesNet& bayesNet); + + string dot( + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; + void saveGraph( + string s, + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; }; #include @@ -173,29 +185,4 @@ class SymbolicBayesTreeClique { void deleteCachedShortcuts(); }; -#include -class VariableIndex { - // Standard Constructors and Named Constructors - VariableIndex(); - // TODO: Templetize constructor when wrap supports it - // template - // VariableIndex(const T& factorGraph, size_t nVariables); - // VariableIndex(const T& factorGraph); - VariableIndex(const gtsam::SymbolicFactorGraph& sfg); - VariableIndex(const gtsam::GaussianFactorGraph& gfg); - VariableIndex(const gtsam::NonlinearFactorGraph& fg); - VariableIndex(const gtsam::VariableIndex& other); - - // Testable - bool equals(const gtsam::VariableIndex& other, double tol) const; - void print(string s = "VariableIndex: ", - const gtsam::KeyFormatter& keyFormatter = - gtsam::DefaultKeyFormatter) const; - - // Standard interface - size_t size() const; - size_t nFactors() const; - size_t nEntries() const; -}; - } // namespace gtsam diff --git a/gtsam/symbolic/tests/testSymbolicBayesNet.cpp b/gtsam/symbolic/tests/testSymbolicBayesNet.cpp index a92d66f68..2e13be10e 100644 --- a/gtsam/symbolic/tests/testSymbolicBayesNet.cpp +++ b/gtsam/symbolic/tests/testSymbolicBayesNet.cpp @@ -15,13 +15,16 @@ * @author Frank Dellaert */ -#include +#include +#include +#include +#include +#include +#include #include -#include -#include -#include +#include using namespace std; using namespace gtsam; @@ -30,7 +33,6 @@ static const Key _L_ = 0; static const Key _A_ = 1; static const Key _B_ = 2; static const Key _C_ = 3; -static const Key _D_ = 4; static SymbolicConditional::shared_ptr B(new SymbolicConditional(_B_)), @@ -78,14 +80,41 @@ TEST( SymbolicBayesNet, combine ) } /* ************************************************************************* */ -TEST(SymbolicBayesNet, saveGraph) { +TEST(SymbolicBayesNet, Dot) { + using symbol_shorthand::A; + using symbol_shorthand::X; SymbolicBayesNet bn; - bn += SymbolicConditional(_A_, _B_); - KeyVector keys {_B_, _C_, _D_}; - bn += SymbolicConditional::FromKeys(keys,2); - bn += SymbolicConditional(_D_); + bn += SymbolicConditional(X(3), X(2), A(2)); + bn += SymbolicConditional(X(2), X(1), A(1)); + bn += SymbolicConditional(X(1)); - bn.saveGraph("SymbolicBayesNet.dot"); + DotWriter writer; + writer.positionHints.emplace('a', 2); + writer.positionHints.emplace('x', 1); + writer.boxes.emplace(A(1)); + writer.boxes.emplace(A(2)); + + auto position = writer.variablePos(A(1)); + CHECK(position); + EXPECT(assert_equal(Vector2(1, 2), *position, 1e-5)); + + string actual = bn.dot(DefaultKeyFormatter, writer); + bn.saveGraph("bn.dot", DefaultKeyFormatter, writer); + EXPECT(actual == + "digraph {\n" + " size=\"5,5\";\n" + "\n" + " vara1[label=\"a1\", pos=\"1,2!\", shape=box];\n" + " vara2[label=\"a2\", pos=\"2,2!\", shape=box];\n" + " varx1[label=\"x1\", pos=\"1,1!\"];\n" + " varx2[label=\"x2\", pos=\"2,1!\"];\n" + " varx3[label=\"x3\", pos=\"3,1!\"];\n" + "\n" + " varx1->varx2\n" + " vara1->varx2\n" + " varx2->varx3\n" + " vara2->varx3\n" + "}"); } /* ************************************************************************* */ diff --git a/gtsam_unstable/discrete/AllDiff.cpp b/gtsam_unstable/discrete/AllDiff.cpp index 9e124954f..bff524bc2 100644 --- a/gtsam_unstable/discrete/AllDiff.cpp +++ b/gtsam_unstable/discrete/AllDiff.cpp @@ -5,107 +5,109 @@ * @author Frank Dellaert */ -#include -#include #include +#include +#include + #include namespace gtsam { - /* ************************************************************************* */ - AllDiff::AllDiff(const DiscreteKeys& dkeys) : - Constraint(dkeys.indices()) { - for(const DiscreteKey& dkey: dkeys) - cardinalities_.insert(dkey); - } +/* ************************************************************************* */ +AllDiff::AllDiff(const DiscreteKeys& dkeys) : Constraint(dkeys.indices()) { + for (const DiscreteKey& dkey : dkeys) cardinalities_.insert(dkey); +} - /* ************************************************************************* */ - void AllDiff::print(const std::string& s, - const KeyFormatter& formatter) const { - std::cout << s << "AllDiff on "; - for (Key dkey: keys_) - std::cout << formatter(dkey) << " "; - std::cout << std::endl; - } +/* ************************************************************************* */ +void AllDiff::print(const std::string& s, const KeyFormatter& formatter) const { + std::cout << s << "AllDiff on "; + for (Key dkey : keys_) std::cout << formatter(dkey) << " "; + std::cout << std::endl; +} - /* ************************************************************************* */ - double AllDiff::operator()(const Values& values) const { - std::set < size_t > taken; // record values taken by keys - for(Key dkey: keys_) { - size_t value = values.at(dkey); // get the value for that key - if (taken.count(value)) return 0.0;// check if value alreday taken - taken.insert(value);// if not, record it as taken and keep checking +/* ************************************************************************* */ +double AllDiff::operator()(const DiscreteValues& values) const { + std::set taken; // record values taken by keys + for (Key dkey : keys_) { + size_t value = values.at(dkey); // get the value for that key + if (taken.count(value)) return 0.0; // check if value alreday taken + taken.insert(value); // if not, record it as taken and keep checking + } + return 1.0; +} + +/* ************************************************************************* */ +DecisionTreeFactor AllDiff::toDecisionTreeFactor() const { + // We will do this by converting the allDif into many BinaryAllDiff + // constraints + DecisionTreeFactor converted; + size_t nrKeys = keys_.size(); + for (size_t i1 = 0; i1 < nrKeys; i1++) + for (size_t i2 = i1 + 1; i2 < nrKeys; i2++) { + BinaryAllDiff binary12(discreteKey(i1), discreteKey(i2)); + converted = converted * binary12.toDecisionTreeFactor(); } - return 1.0; + return converted; +} + +/* ************************************************************************* */ +DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const { + // TODO: can we do this more efficiently? + return toDecisionTreeFactor() * f; +} + +/* ************************************************************************* */ +bool AllDiff::ensureArcConsistency(Key j, Domains* domains) const { + Domain& Dj = domains->at(j); + + // Though strictly not part of allDiff, we check for + // a value in domains->at(j) that does not occur in any other connected domain. + // If found, we make this a singleton... + // TODO: make a new constraint where this really is true + boost::optional maybeChanged = Dj.checkAllDiff(keys_, *domains); + if (maybeChanged) { + Dj = *maybeChanged; + return true; } - /* ************************************************************************* */ - DecisionTreeFactor AllDiff::toDecisionTreeFactor() const { - // We will do this by converting the allDif into many BinaryAllDiff constraints - DecisionTreeFactor converted; - size_t nrKeys = keys_.size(); - for (size_t i1 = 0; i1 < nrKeys; i1++) - for (size_t i2 = i1 + 1; i2 < nrKeys; i2++) { - BinaryAllDiff binary12(discreteKey(i1),discreteKey(i2)); - converted = converted * binary12.toDecisionTreeFactor(); - } - return converted; - } - - /* ************************************************************************* */ - DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const { - // TODO: can we do this more efficiently? - return toDecisionTreeFactor() * f; - } - - /* ************************************************************************* */ - bool AllDiff::ensureArcConsistency(size_t j, std::vector& domains) const { - // Though strictly not part of allDiff, we check for - // a value in domains[j] that does not occur in any other connected domain. - // If found, we make this a singleton... - // TODO: make a new constraint where this really is true - Domain& Dj = domains[j]; - if (Dj.checkAllDiff(keys_, domains)) return true; - - // Check all other domains for singletons and erase corresponding values - // This is the same as arc-consistency on the equivalent binary constraints - bool changed = false; - for(Key k: keys_) - if (k != j) { - const Domain& Dk = domains[k]; - if (Dk.isSingleton()) { // check if singleton - size_t value = Dk.firstValue(); - if (Dj.contains(value)) { - Dj.erase(value); // erase value if true - changed = true; - } + // Check all other domains for singletons and erase corresponding values. + // This is the same as arc-consistency on the equivalent binary constraints + bool changed = false; + for (Key k : keys_) + if (k != j) { + const Domain& Dk = domains->at(k); + if (Dk.isSingleton()) { // check if singleton + size_t value = Dk.firstValue(); + if (Dj.contains(value)) { + Dj.erase(value); // erase value if true + changed = true; } } - return changed; - } + } + return changed; +} - /* ************************************************************************* */ - Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const { - DiscreteKeys newKeys; - // loop over keys and add them only if they do not appear in values - for(Key k: keys_) - if (values.find(k) == values.end()) { - newKeys.push_back(DiscreteKey(k,cardinalities_.at(k))); - } - return boost::make_shared(newKeys); - } +/* ************************************************************************* */ +Constraint::shared_ptr AllDiff::partiallyApply(const DiscreteValues& values) const { + DiscreteKeys newKeys; + // loop over keys and add them only if they do not appear in values + for (Key k : keys_) + if (values.find(k) == values.end()) { + newKeys.push_back(DiscreteKey(k, cardinalities_.at(k))); + } + return boost::make_shared(newKeys); +} - /* ************************************************************************* */ - Constraint::shared_ptr AllDiff::partiallyApply( - const std::vector& domains) const { - DiscreteFactor::Values known; - for(Key k: keys_) { - const Domain& Dk = domains[k]; - if (Dk.isSingleton()) - known[k] = Dk.firstValue(); - } - return partiallyApply(known); +/* ************************************************************************* */ +Constraint::shared_ptr AllDiff::partiallyApply( + const Domains& domains) const { + DiscreteValues known; + for (Key k : keys_) { + const Domain& Dk = domains.at(k); + if (Dk.isSingleton()) known[k] = Dk.firstValue(); } + return partiallyApply(known); +} - /* ************************************************************************* */ -} // namespace gtsam +/* ************************************************************************* */ +} // namespace gtsam diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index 80e700b29..9496fc1a6 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -7,71 +7,66 @@ #pragma once -#include #include +#include namespace gtsam { - /** - * General AllDiff constraint - * Returns 1 if values for all keys are different, 0 otherwise - * DiscreteFactors are all awkward in that they have to store two types of keys: - * for each variable we have a Key and an Key. In this factor, we - * keep the Indices locally, and the Indices are stored in IndexFactor. +/** + * General AllDiff constraint. + * Returns 1 if values for all keys are different, 0 otherwise. + */ +class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { + std::map cardinalities_; + + DiscreteKey discreteKey(size_t i) const { + Key j = keys_[i]; + return DiscreteKey(j, cardinalities_.at(j)); + } + + public: + /// Construct from keys. + AllDiff(const DiscreteKeys& dkeys); + + // print + void print(const std::string& s = "", const KeyFormatter& formatter = + DefaultKeyFormatter) const override; + + /// equals + bool equals(const DiscreteFactor& other, double tol) const override { + if (!dynamic_cast(&other)) + return false; + else { + const AllDiff& f(static_cast(other)); + return cardinalities_.size() == f.cardinalities_.size() && + std::equal(cardinalities_.begin(), cardinalities_.end(), + f.cardinalities_.begin()); + } + } + + /// Calculate value = expensive ! + double operator()(const DiscreteValues& values) const override; + + /// Convert into a decisiontree, can be *very* expensive ! + DecisionTreeFactor toDecisionTreeFactor() const override; + + /// Multiply into a decisiontree + DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + + /* + * Ensure Arc-consistency by checking every possible value of domain j. + * @param j domain to be checked + * @param (in/out) domains all domains, but only domains->at(j) will be checked. + * @return true if domains->at(j) was changed, false otherwise. */ - class GTSAM_UNSTABLE_EXPORT AllDiff: public Constraint { + bool ensureArcConsistency(Key j, Domains* domains) const override; - std::map cardinalities_; + /// Partially apply known values + Constraint::shared_ptr partiallyApply(const DiscreteValues&) const override; - DiscreteKey discreteKey(size_t i) const { - Key j = keys_[i]; - return DiscreteKey(j,cardinalities_.at(j)); - } + /// Partially apply known values, domain version + Constraint::shared_ptr partiallyApply( + const Domains&) const override; +}; - public: - - /// Constructor - AllDiff(const DiscreteKeys& dkeys); - - // print - void print(const std::string& s = "", - const KeyFormatter& formatter = DefaultKeyFormatter) const override; - - /// equals - bool equals(const DiscreteFactor& other, double tol) const override { - if(!dynamic_cast(&other)) - return false; - else { - const AllDiff& f(static_cast(other)); - return cardinalities_.size() == f.cardinalities_.size() - && std::equal(cardinalities_.begin(), cardinalities_.end(), - f.cardinalities_.begin()); - } - } - - /// Calculate value = expensive ! - double operator()(const Values& values) const override; - - /// Convert into a decisiontree, can be *very* expensive ! - DecisionTreeFactor toDecisionTreeFactor() const override; - - /// Multiply into a decisiontree - DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; - - /* - * Ensure Arc-consistency - * Arc-consistency involves creating binaryAllDiff constraints - * In which case the combinatorial hyper-arc explosion disappears. - * @param j domain to be checked - * @param domains all other domains - */ - bool ensureArcConsistency(size_t j, std::vector& domains) const override; - - /// Partially apply known values - Constraint::shared_ptr partiallyApply(const Values&) const override; - - /// Partially apply known values, domain version - Constraint::shared_ptr partiallyApply(const std::vector&) const override; - }; - -} // namespace gtsam +} // namespace gtsam diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index bbb60e2f1..b207acb9d 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -7,94 +7,90 @@ #pragma once -#include -#include #include +#include +#include namespace gtsam { - /** - * Binary AllDiff constraint - * Returns 1 if values for two keys are different, 0 otherwise - * DiscreteFactors are all awkward in that they have to store two types of keys: - * for each variable we have a Index and an Index. In this factor, we - * keep the Indices locally, and the Indices are stored in IndexFactor. - */ - class BinaryAllDiff: public Constraint { +/** + * Binary AllDiff constraint + * Returns 1 if values for two keys are different, 0 otherwise. + */ +class BinaryAllDiff : public Constraint { + size_t cardinality0_, cardinality1_; /// cardinality - size_t cardinality0_, cardinality1_; /// cardinality + public: + /// Constructor + BinaryAllDiff(const DiscreteKey& key1, const DiscreteKey& key2) + : Constraint(key1.first, key2.first), + cardinality0_(key1.second), + cardinality1_(key2.second) {} - public: + // print + void print( + const std::string& s = "", + const KeyFormatter& formatter = DefaultKeyFormatter) const override { + std::cout << s << "BinaryAllDiff on " << formatter(keys_[0]) << " and " + << formatter(keys_[1]) << std::endl; + } - /// Constructor - BinaryAllDiff(const DiscreteKey& key1, const DiscreteKey& key2) : - Constraint(key1.first, key2.first), - cardinality0_(key1.second), cardinality1_(key2.second) { - } - - // print - void print(const std::string& s = "", - const KeyFormatter& formatter = DefaultKeyFormatter) const override { - std::cout << s << "BinaryAllDiff on " << formatter(keys_[0]) << " and " - << formatter(keys_[1]) << std::endl; - } - - /// equals - bool equals(const DiscreteFactor& other, double tol) const override { - if(!dynamic_cast(&other)) - return false; - else { - const BinaryAllDiff& f(static_cast(other)); - return (cardinality0_==f.cardinality0_) && (cardinality1_==f.cardinality1_); - } - } - - /// Calculate value - double operator()(const Values& values) const override { - return (double) (values.at(keys_[0]) != values.at(keys_[1])); - } - - /// Convert into a decisiontree - DecisionTreeFactor toDecisionTreeFactor() const override { - DiscreteKeys keys; - keys.push_back(DiscreteKey(keys_[0],cardinality0_)); - keys.push_back(DiscreteKey(keys_[1],cardinality1_)); - std::vector table; - for (size_t i1 = 0; i1 < cardinality0_; i1++) - for (size_t i2 = 0; i2 < cardinality1_; i2++) - table.push_back(i1 != i2); - DecisionTreeFactor converted(keys, table); - return converted; - } - - /// Multiply into a decisiontree - DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { - // TODO: can we do this more efficiently? - return toDecisionTreeFactor() * f; - } - - /* - * Ensure Arc-consistency - * @param j domain to be checked - * @param domains all other domains - */ - /// - bool ensureArcConsistency(size_t j, std::vector& domains) const override { -// throw std::runtime_error( -// "BinaryAllDiff::ensureArcConsistency not implemented"); + /// equals + bool equals(const DiscreteFactor& other, double tol) const override { + if (!dynamic_cast(&other)) return false; + else { + const BinaryAllDiff& f(static_cast(other)); + return (cardinality0_ == f.cardinality0_) && + (cardinality1_ == f.cardinality1_); } + } - /// Partially apply known values - Constraint::shared_ptr partiallyApply(const Values&) const override { - throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); - } + /// Calculate value + double operator()(const DiscreteValues& values) const override { + return (double)(values.at(keys_[0]) != values.at(keys_[1])); + } - /// Partially apply known values, domain version - Constraint::shared_ptr partiallyApply( - const std::vector&) const override { - throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); - } - }; + /// Convert into a decisiontree + DecisionTreeFactor toDecisionTreeFactor() const override { + DiscreteKeys keys; + keys.push_back(DiscreteKey(keys_[0], cardinality0_)); + keys.push_back(DiscreteKey(keys_[1], cardinality1_)); + std::vector table; + for (size_t i1 = 0; i1 < cardinality0_; i1++) + for (size_t i2 = 0; i2 < cardinality1_; i2++) table.push_back(i1 != i2); + DecisionTreeFactor converted(keys, table); + return converted; + } -} // namespace gtsam + /// Multiply into a decisiontree + DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { + // TODO: can we do this more efficiently? + return toDecisionTreeFactor() * f; + } + + /* + * Ensure Arc-consistency by checking every possible value of domain j. + * @param j domain to be checked + * @param (in/out) domains all domains, but only domains->at(j) will be checked. + * @return true if domains->at(j) was changed, false otherwise. + */ + bool ensureArcConsistency(Key j, Domains* domains) const override { + throw std::runtime_error( + "BinaryAllDiff::ensureArcConsistency not implemented"); + return false; + } + + /// Partially apply known values + Constraint::shared_ptr partiallyApply(const DiscreteValues&) const override { + throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); + } + + /// Partially apply known values, domain version + Constraint::shared_ptr partiallyApply( + const Domains&) const override { + throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); + } +}; + +} // namespace gtsam diff --git a/gtsam_unstable/discrete/CSP.cpp b/gtsam_unstable/discrete/CSP.cpp index 525abd098..08143c469 100644 --- a/gtsam_unstable/discrete/CSP.cpp +++ b/gtsam_unstable/discrete/CSP.cpp @@ -5,99 +5,84 @@ * @author Frank Dellaert */ -#include -#include #include +#include +#include +#include using namespace std; namespace gtsam { - /// Find the best total assignment - can be expensive - CSP::sharedValues CSP::optimalAssignment() const { - DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential(); - sharedValues mpe = chordal->optimize(); - return mpe; - } +bool CSP::runArcConsistency(const VariableIndex& index, + Domains* domains) const { + bool changed = false; - /// Find the best total assignment - can be expensive - CSP::sharedValues CSP::optimalAssignment(const Ordering& ordering) const { - DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential(ordering); - sharedValues mpe = chordal->optimize(); - return mpe; - } + // iterate over all variables in the index + for (auto entry : index) { + // Get the variable's key and associated factors: + const Key key = entry.first; + const FactorIndices& factors = entry.second; - void CSP::runArcConsistency(size_t cardinality, size_t nrIterations, bool print) const { - // Create VariableIndex - VariableIndex index(*this); - // index.print(); + // If this domain is already a singleton, we do nothing. + if (domains->at(key).isSingleton()) continue; - size_t n = index.size(); - - // Initialize domains - std::vector < Domain > domains; - for (size_t j = 0; j < n; j++) - domains.push_back(Domain(DiscreteKey(j,cardinality))); - - // Create array of flags indicating a domain changed or not - std::vector changed(n); - - // iterate nrIterations over entire grid - for (size_t it = 0; it < nrIterations; it++) { - bool anyChange = false; - // iterate over all cells - for (size_t v = 0; v < n; v++) { - // keep track of which domains changed - changed[v] = false; - // loop over all factors/constraints for variable v - const FactorIndices& factors = index[v]; - for(size_t f: factors) { - // if not already a singleton - if (!domains[v].isSingleton()) { - // get the constraint and call its ensureArcConsistency method - Constraint::shared_ptr constraint = boost::dynamic_pointer_cast((*this)[f]); - if (!constraint) throw runtime_error("CSP:runArcConsistency: non-constraint factor"); - changed[v] = constraint->ensureArcConsistency(v,domains) || changed[v]; - } - } // f - if (changed[v]) anyChange = true; - } // v - if (!anyChange) break; - // TODO: Sudoku specific hack - if (print) { - if (cardinality == 9 && n == 81) { - for (size_t i = 0, v = 0; i < (size_t)std::sqrt((double)n); i++) { - for (size_t j = 0; j < (size_t)std::sqrt((double)n); j++, v++) { - if (changed[v]) cout << "*"; - domains[v].print(); - cout << "\t"; - } // i - cout << endl; - } // j - } else { - for (size_t v = 0; v < n; v++) { - if (changed[v]) cout << "*"; - domains[v].print(); - cout << "\t"; - } // v - } - cout << endl; - } // print - } // it - -#ifndef INPROGRESS - // Now create new problem with all singleton variables removed - // We do this by adding simplifying all factors using parial application - // TODO: create a new ordering as we go, to ensure a connected graph - // KeyOrdering ordering; - // vector dkeys; - for(const DiscreteFactor::shared_ptr& f: factors_) { - Constraint::shared_ptr constraint = boost::dynamic_pointer_cast(f); - if (!constraint) throw runtime_error("CSP:runArcConsistency: non-constraint factor"); - Constraint::shared_ptr reduced = constraint->partiallyApply(domains); - if (print) reduced->print(); + // Otherwise, loop over all factors/constraints for variable with given key. + for (size_t f : factors) { + // If this factor is a constraint, call its ensureArcConsistency method: + auto constraint = boost::dynamic_pointer_cast((*this)[f]); + if (constraint) { + changed = constraint->ensureArcConsistency(key, domains) || changed; + } } -#endif } -} // gtsam + return changed; +} +// TODO(dellaert): This is AC1, which is inefficient as any change will cause +// the algorithm to revisit *all* variables again. Implement AC3. +Domains CSP::runArcConsistency(size_t cardinality, size_t maxIterations) const { + // Create VariableIndex + VariableIndex index(*this); + + // Initialize domains + Domains domains; + for (auto entry : index) { + const Key key = entry.first; + domains.emplace(key, DiscreteKey(key, cardinality)); + } + + // Iterate until convergence or not a single domain changed. + for (size_t it = 0; it < maxIterations; it++) { + bool changed = runArcConsistency(index, &domains); + if (!changed) break; + } + return domains; +} + +CSP CSP::partiallyApply(const Domains& domains) const { + // Create new problem with all singleton variables removed + // We do this by adding simplifying all factors using partial application. + // TODO: create a new ordering as we go, to ensure a connected graph + // KeyOrdering ordering; + // vector dkeys; + CSP new_csp; + + // Add tightened domains as new factors: + for (auto key_domain : domains) { + new_csp.emplace_shared(key_domain.second); + } + + // Reduce all existing factors: + for (const DiscreteFactor::shared_ptr& f : factors_) { + auto constraint = boost::dynamic_pointer_cast(f); + if (!constraint) + throw runtime_error("CSP:runArcConsistency: non-constraint factor"); + Constraint::shared_ptr reduced = constraint->partiallyApply(domains); + if (reduced->size() > 1) { + new_csp.push_back(reduced); + } + } + return new_csp; +} +} // namespace gtsam diff --git a/gtsam_unstable/discrete/CSP.h b/gtsam_unstable/discrete/CSP.h index 9e843f667..40853bed6 100644 --- a/gtsam_unstable/discrete/CSP.h +++ b/gtsam_unstable/discrete/CSP.h @@ -7,84 +7,70 @@ #pragma once +#include #include #include -#include namespace gtsam { - /** - * Constraint Satisfaction Problem class - * A specialization of a DiscreteFactorGraph. - * It knows about CSP-specific constraints and algorithms +/** + * Constraint Satisfaction Problem class + * A specialization of a DiscreteFactorGraph. + * It knows about CSP-specific constraints and algorithms + */ +class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph { + public: + using Values = DiscreteValues; ///< backwards compatibility + + /// Add a unary constraint, allowing only a single value + void addSingleValue(const DiscreteKey& dkey, size_t value) { + emplace_shared(dkey, value); + } + + /// Add a binary AllDiff constraint + void addAllDiff(const DiscreteKey& key1, const DiscreteKey& key2) { + emplace_shared(key1, key2); + } + + /// Add a general AllDiff constraint + void addAllDiff(const DiscreteKeys& dkeys) { emplace_shared(dkeys); } + + // /** return product of all factors as a single factor */ + // DecisionTreeFactor product() const { + // DecisionTreeFactor result; + // for(const sharedFactor& factor: *this) + // if (factor) result = (*factor) * result; + // return result; + // } + + // /* + // * Perform loopy belief propagation + // * True belief propagation would check for each value in domain + // * whether any satisfying separator assignment can be found. + // * This corresponds to hyper-arc consistency in CSP speak. + // * This can be done by creating a mini-factor graph and search. + // * For a nine-by-nine Sudoku, the search tree will be 8+6+6=20 levels + // deep. + // * It will be very expensive to exclude values that way. + // */ + // void applyBeliefPropagation(size_t maxIterations = 10) const; + + /* + * Apply arc-consistency ~ Approximate loopy belief propagation + * We need to give the domains to a constraint, and it returns + * a domain whose values don't conflict in the arc-consistency way. + * TODO: should get cardinality from DiscreteKeys */ - class GTSAM_UNSTABLE_EXPORT CSP: public DiscreteFactorGraph { - public: + Domains runArcConsistency(size_t cardinality, + size_t maxIterations = 10) const; - /** A map from keys to values */ - typedef KeyVector Indices; - typedef Assignment Values; - typedef boost::shared_ptr sharedValues; + /// Run arc consistency for all variables, return true if any domain changed. + bool runArcConsistency(const VariableIndex& index, Domains* domains) const; - public: - -// /// Constructor -// CSP() { -// } - - /// Add a unary constraint, allowing only a single value - void addSingleValue(const DiscreteKey& dkey, size_t value) { - boost::shared_ptr factor(new SingleValue(dkey, value)); - push_back(factor); - } - - /// Add a binary AllDiff constraint - void addAllDiff(const DiscreteKey& key1, const DiscreteKey& key2) { - boost::shared_ptr factor( - new BinaryAllDiff(key1, key2)); - push_back(factor); - } - - /// Add a general AllDiff constraint - void addAllDiff(const DiscreteKeys& dkeys) { - boost::shared_ptr factor(new AllDiff(dkeys)); - push_back(factor); - } - -// /** return product of all factors as a single factor */ -// DecisionTreeFactor product() const { -// DecisionTreeFactor result; -// for(const sharedFactor& factor: *this) -// if (factor) result = (*factor) * result; -// return result; -// } - - /// Find the best total assignment - can be expensive - sharedValues optimalAssignment() const; - - /// Find the best total assignment - can be expensive - sharedValues optimalAssignment(const Ordering& ordering) const; - -// /* -// * Perform loopy belief propagation -// * True belief propagation would check for each value in domain -// * whether any satisfying separator assignment can be found. -// * This corresponds to hyper-arc consistency in CSP speak. -// * This can be done by creating a mini-factor graph and search. -// * For a nine-by-nine Sudoku, the search tree will be 8+6+6=20 levels deep. -// * It will be very expensive to exclude values that way. -// */ -// void applyBeliefPropagation(size_t nrIterations = 10) const; - - /* - * Apply arc-consistency ~ Approximate loopy belief propagation - * We need to give the domains to a constraint, and it returns - * a domain whose values don't conflict in the arc-consistency way. - * TODO: should get cardinality from Indices - */ - void runArcConsistency(size_t cardinality, size_t nrIterations = 10, - bool print = false) const; - }; // CSP - -} // gtsam + /* + * Create a new CSP, applying the given Domain constraints. + */ + CSP partiallyApply(const Domains& domains) const; +}; // CSP +} // namespace gtsam diff --git a/gtsam_unstable/discrete/Constraint.h b/gtsam_unstable/discrete/Constraint.h index c3a26de68..4ee7b85eb 100644 --- a/gtsam_unstable/discrete/Constraint.h +++ b/gtsam_unstable/discrete/Constraint.h @@ -17,77 +17,88 @@ #pragma once -#include #include +#include +#include + #include +#include +#include namespace gtsam { - class Domain; +class Domain; +using Domains = std::map; - /** - * Base class for discrete probabilistic factors - * The most general one is the derived DecisionTreeFactor +/** + * Base class for constraint factors + * Derived classes include SingleValue, BinaryAllDiff, and AllDiff. + */ +class GTSAM_EXPORT Constraint : public DiscreteFactor { + public: + typedef boost::shared_ptr shared_ptr; + + protected: + /// Construct unary constraint factor. + Constraint(Key j) : DiscreteFactor(boost::assign::cref_list_of<1>(j)) {} + + /// Construct binary constraint factor. + Constraint(Key j1, Key j2) + : DiscreteFactor(boost::assign::cref_list_of<2>(j1)(j2)) {} + + /// Construct n-way constraint factor. + Constraint(const KeyVector& js) : DiscreteFactor(js) {} + + /// construct from container + template + Constraint(KeyIterator beginKey, KeyIterator endKey) + : DiscreteFactor(beginKey, endKey) {} + + public: + /// @name Standard Constructors + /// @{ + + /// Default constructor for I/O + Constraint(); + + /// Virtual destructor + ~Constraint() override {} + + /// @} + /// @name Standard Interface + /// @{ + + /* + * Ensure Arc-consistency by checking every possible value of domain j. + * @param j domain to be checked + * @param (in/out) domains all domains, but only domains->at(j) will be checked. + * @return true if domains->at(j) was changed, false otherwise. */ - class Constraint : public DiscreteFactor { + virtual bool ensureArcConsistency(Key j, Domains* domains) const = 0; - public: + /// Partially apply known values + virtual shared_ptr partiallyApply(const DiscreteValues&) const = 0; - typedef boost::shared_ptr shared_ptr; + /// Partially apply known values, domain version + virtual shared_ptr partiallyApply(const Domains&) const = 0; + /// @} + /// @name Wrapper support + /// @{ - protected: + /// Render as markdown table. + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override { + return (boost::format("`Constraint` on %1% variables\n") % (size())).str(); + } - /// Construct n-way factor - Constraint(const KeyVector& js) : - DiscreteFactor(js) { - } + /// Render as html table. + std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override { + return (boost::format("

Constraint on %1% variables

") % (size())).str(); + } - /// Construct unary factor - Constraint(Key j) : - DiscreteFactor(boost::assign::cref_list_of<1>(j)) { - } - - /// Construct binary factor - Constraint(Key j1, Key j2) : - DiscreteFactor(boost::assign::cref_list_of<2>(j1)(j2)) { - } - - /// construct from container - template - Constraint(KeyIterator beginKey, KeyIterator endKey) : - DiscreteFactor(beginKey, endKey) { - } - - public: - - /// @name Standard Constructors - /// @{ - - /// Default constructor for I/O - Constraint(); - - /// Virtual destructor - ~Constraint() override {} - - /// @} - /// @name Standard Interface - /// @{ - - /* - * Ensure Arc-consistency - * @param j domain to be checked - * @param domains all other domains - */ - virtual bool ensureArcConsistency(size_t j, std::vector& domains) const = 0; - - /// Partially apply known values - virtual shared_ptr partiallyApply(const Values&) const = 0; - - - /// Partially apply known values, domain version - virtual shared_ptr partiallyApply(const std::vector&) const = 0; - /// @} - }; + /// @} +}; // DiscreteFactor -}// namespace gtsam +} // namespace gtsam diff --git a/gtsam_unstable/discrete/Domain.cpp b/gtsam_unstable/discrete/Domain.cpp index 740ef067c..7acc10cb4 100644 --- a/gtsam_unstable/discrete/Domain.cpp +++ b/gtsam_unstable/discrete/Domain.cpp @@ -5,92 +5,94 @@ * @author Frank Dellaert */ -#include -#include #include -#include +#include +#include +#include +#include namespace gtsam { - using namespace std; - - /* ************************************************************************* */ - void Domain::print(const string& s, - const KeyFormatter& formatter) const { -// cout << s << ": Domain on " << formatter(keys_[0]) << " (j=" << -// formatter(keys_[0]) << ") with values"; -// for (size_t v: values_) cout << " " << v; -// cout << endl; - for (size_t v: values_) cout << v; - } - - /* ************************************************************************* */ - double Domain::operator()(const Values& values) const { - return contains(values.at(keys_[0])); - } - - /* ************************************************************************* */ - DecisionTreeFactor Domain::toDecisionTreeFactor() const { - DiscreteKeys keys; - keys += DiscreteKey(keys_[0],cardinality_); - vector table; - for (size_t i1 = 0; i1 < cardinality_; ++i1) - table.push_back(contains(i1)); - DecisionTreeFactor converted(keys, table); - return converted; - } - - /* ************************************************************************* */ - DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const { - // TODO: can we do this more efficiently? - return toDecisionTreeFactor() * f; - } - - /* ************************************************************************* */ - bool Domain::ensureArcConsistency(size_t j, vector& domains) const { - if (j != keys_[0]) throw invalid_argument("Domain check on wrong domain"); - Domain& D = domains[j]; - for(size_t value: values_) - if (!D.contains(value)) throw runtime_error("Unsatisfiable"); - D = *this; - return true; - } - - /* ************************************************************************* */ - bool Domain::checkAllDiff(const KeyVector keys, vector& domains) { - Key j = keys_[0]; - // for all values in this domain - for(size_t value: values_) { - // for all connected domains - for(Key k: keys) - // if any domain contains the value we cannot make this domain singleton - if (k!=j && domains[k].contains(value)) - goto found; - values_.clear(); - values_.insert(value); - return true; // we changed it - found:; - } - return false; // we did not change it - } - - /* ************************************************************************* */ - Constraint::shared_ptr Domain::partiallyApply( - const Values& values) const { - Values::const_iterator it = values.find(keys_[0]); - if (it != values.end() && !contains(it->second)) throw runtime_error( - "Domain::partiallyApply: unsatisfiable"); - return boost::make_shared < Domain > (*this); - } - - /* ************************************************************************* */ - Constraint::shared_ptr Domain::partiallyApply( - const vector& domains) const { - const Domain& Dk = domains[keys_[0]]; - if (Dk.isSingleton() && !contains(*Dk.begin())) throw runtime_error( - "Domain::partiallyApply: unsatisfiable"); - return boost::make_shared < Domain > (Dk); - } +using namespace std; /* ************************************************************************* */ -} // namespace gtsam +void Domain::print(const string& s, const KeyFormatter& formatter) const { + cout << s << ": Domain on " << formatter(key()) << " (j=" << formatter(key()) + << ") with values"; + for (size_t v : values_) cout << " " << v; + cout << endl; +} + +/* ************************************************************************* */ +string Domain::base1Str() const { + stringstream ss; + for (size_t v : values_) ss << v + 1; + return ss.str(); +} + +/* ************************************************************************* */ +double Domain::operator()(const DiscreteValues& values) const { + return contains(values.at(key())); +} + +/* ************************************************************************* */ +DecisionTreeFactor Domain::toDecisionTreeFactor() const { + DiscreteKeys keys; + keys += DiscreteKey(key(), cardinality_); + vector table; + for (size_t i1 = 0; i1 < cardinality_; ++i1) table.push_back(contains(i1)); + DecisionTreeFactor converted(keys, table); + return converted; +} + +/* ************************************************************************* */ +DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const { + // TODO: can we do this more efficiently? + return toDecisionTreeFactor() * f; +} + +/* ************************************************************************* */ +bool Domain::ensureArcConsistency(Key j, Domains* domains) const { + if (j != key()) throw invalid_argument("Domain check on wrong domain"); + Domain& D = domains->at(j); + for (size_t value : values_) + if (!D.contains(value)) throw runtime_error("Unsatisfiable"); + D = *this; + return true; +} + +/* ************************************************************************* */ +boost::optional Domain::checkAllDiff(const KeyVector keys, + const Domains& domains) const { + Key j = key(); + // for all values in this domain + for (const size_t value : values_) { + // for all connected domains + for (const Key k : keys) + // if any domain contains the value we cannot make this domain singleton + if (k != j && domains.at(k).contains(value)) goto found; + // Otherwise: return a singleton: + return Domain(this->discreteKey(), value); + found:; + } + return boost::none; // we did not change it +} + +/* ************************************************************************* */ +Constraint::shared_ptr Domain::partiallyApply(const DiscreteValues& values) const { + DiscreteValues::const_iterator it = values.find(key()); + if (it != values.end() && !contains(it->second)) + throw runtime_error("Domain::partiallyApply: unsatisfiable"); + return boost::make_shared(*this); +} + +/* ************************************************************************* */ +Constraint::shared_ptr Domain::partiallyApply(const Domains& domains) const { + const Domain& Dk = domains.at(key()); + if (Dk.isSingleton() && !contains(*Dk.begin())) + throw runtime_error("Domain::partiallyApply: unsatisfiable"); + return boost::make_shared(Dk); +} + +/* ************************************************************************* */ +} // namespace gtsam diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index 5acc5a08f..1047101c5 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -7,111 +7,107 @@ #pragma once -#include #include +#include namespace gtsam { - /** - * Domain restriction constraint +/** + * The Domain class represents a constraint that restricts the possible values a + * particular variable, with given key, can take on. + */ +class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { + size_t cardinality_; /// Cardinality + std::set values_; /// allowed values + + public: + typedef boost::shared_ptr shared_ptr; + + // Constructor on Discrete Key initializes an "all-allowed" domain + Domain(const DiscreteKey& dkey) + : Constraint(dkey.first), cardinality_(dkey.second) { + for (size_t v = 0; v < cardinality_; v++) values_.insert(v); + } + + // Constructor on Discrete Key with single allowed value + // Consider SingleValue constraint + Domain(const DiscreteKey& dkey, size_t v) + : Constraint(dkey.first), cardinality_(dkey.second) { + values_.insert(v); + } + + /// The one key + Key key() const { return keys_[0]; } + + // The associated discrete key + DiscreteKey discreteKey() const { return DiscreteKey(key(), cardinality_); } + + /// Insert a value, non const :-( + void insert(size_t value) { values_.insert(value); } + + /// Erase a value, non const :-( + void erase(size_t value) { values_.erase(value); } + + size_t nrValues() const { return values_.size(); } + + bool isSingleton() const { return nrValues() == 1; } + + size_t firstValue() const { return *values_.begin(); } + + // print + void print(const std::string& s = "", const KeyFormatter& formatter = + DefaultKeyFormatter) const override; + + /// equals + bool equals(const DiscreteFactor& other, double tol) const override { + if (!dynamic_cast(&other)) + return false; + else { + const Domain& f(static_cast(other)); + return (cardinality_ == f.cardinality_) && (values_ == f.values_); + } + } + + // Return concise string representation, mostly to debug arc consistency. + // Converts from base 0 to base1. + std::string base1Str() const; + + // Check whether domain cotains a specific value. + bool contains(size_t value) const { return values_.count(value) > 0; } + + /// Calculate value + double operator()(const DiscreteValues& values) const override; + + /// Convert into a decisiontree + DecisionTreeFactor toDecisionTreeFactor() const override; + + /// Multiply into a decisiontree + DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + + /* + * Ensure Arc-consistency by checking every possible value of domain j. + * @param j domain to be checked + * @param (in/out) domains all domains, but only domains->at(j) will be + * checked. + * @return true if domains->at(j) was changed, false otherwise. */ - class GTSAM_UNSTABLE_EXPORT Domain: public Constraint { + bool ensureArcConsistency(Key j, Domains* domains) const override; - size_t cardinality_; /// Cardinality - std::set values_; /// allowed values + /** + * Check for a value in domain that does not occur in any other connected + * domain. If found, return a a new singleton domain... + * Called in AllDiff::ensureArcConsistency + * @param keys connected domains through alldiff + * @param keys other domains + */ + boost::optional checkAllDiff(const KeyVector keys, + const Domains& domains) const; - public: + /// Partially apply known values + Constraint::shared_ptr partiallyApply(const DiscreteValues& values) const override; - typedef boost::shared_ptr shared_ptr; + /// Partially apply known values, domain version + Constraint::shared_ptr partiallyApply(const Domains& domains) const override; +}; - // Constructor on Discrete Key initializes an "all-allowed" domain - Domain(const DiscreteKey& dkey) : - Constraint(dkey.first), cardinality_(dkey.second) { - for (size_t v = 0; v < cardinality_; v++) - values_.insert(v); - } - - // Constructor on Discrete Key with single allowed value - // Consider SingleValue constraint - Domain(const DiscreteKey& dkey, size_t v) : - Constraint(dkey.first), cardinality_(dkey.second) { - values_.insert(v); - } - - /// Constructor - Domain(const Domain& other) : - Constraint(other.keys_[0]), values_(other.values_) { - } - - /// insert a value, non const :-( - void insert(size_t value) { - values_.insert(value); - } - - /// erase a value, non const :-( - void erase(size_t value) { - values_.erase(value); - } - - size_t nrValues() const { - return values_.size(); - } - - bool isSingleton() const { - return nrValues() == 1; - } - - size_t firstValue() const { - return *values_.begin(); - } - - // print - void print(const std::string& s = "", - const KeyFormatter& formatter = DefaultKeyFormatter) const override; - - /// equals - bool equals(const DiscreteFactor& other, double tol) const override { - if(!dynamic_cast(&other)) - return false; - else { - const Domain& f(static_cast(other)); - return (cardinality_==f.cardinality_) && (values_==f.values_); - } - } - - bool contains(size_t value) const { - return values_.count(value)>0; - } - - /// Calculate value - double operator()(const Values& values) const override; - - /// Convert into a decisiontree - DecisionTreeFactor toDecisionTreeFactor() const override; - - /// Multiply into a decisiontree - DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; - - /* - * Ensure Arc-consistency - * @param j domain to be checked - * @param domains all other domains - */ - bool ensureArcConsistency(size_t j, std::vector& domains) const override; - - /** - * Check for a value in domain that does not occur in any other connected domain. - * If found, we make this a singleton... Called in AllDiff::ensureArcConsistency - * @param keys connected domains through alldiff - */ - bool checkAllDiff(const KeyVector keys, std::vector& domains); - - /// Partially apply known values - Constraint::shared_ptr partiallyApply(const Values& values) const override; - - /// Partially apply known values, domain version - Constraint::shared_ptr partiallyApply( - const std::vector& domains) const override; - }; - -} // namespace gtsam +} // namespace gtsam diff --git a/gtsam_unstable/discrete/Scheduler.cpp b/gtsam_unstable/discrete/Scheduler.cpp index 3273778c4..b86df6c29 100644 --- a/gtsam_unstable/discrete/Scheduler.cpp +++ b/gtsam_unstable/discrete/Scheduler.cpp @@ -5,298 +5,268 @@ * @author Frank Dellaert */ -#include -#include #include #include +#include +#include #include - +#include #include #include -#include namespace gtsam { - using namespace std; +using namespace std; - Scheduler::Scheduler(size_t maxNrStudents, const string& filename): - maxNrStudents_(maxNrStudents) - { - typedef boost::tokenizer > Tokenizer; +Scheduler::Scheduler(size_t maxNrStudents, const string& filename) + : maxNrStudents_(maxNrStudents) { + typedef boost::tokenizer > Tokenizer; - // open file - ifstream is(filename.c_str()); - if (!is) { - cerr << "Scheduler: could not open file " << filename << endl; - throw runtime_error("Scheduler: could not open file " + filename); - } - - string line; // buffer - - // process first line with faculty - if (getline(is, line, '\r')) { - Tokenizer tok(line); - Tokenizer::iterator it = tok.begin(); - for (++it; it != tok.end(); ++it) - addFaculty(*it); - } - - // for all remaining lines - size_t count = 0; - while (getline(is, line, '\r')) { - if (count++ > 100) throw runtime_error("reached 100 lines, exiting"); - Tokenizer tok(line); - Tokenizer::iterator it = tok.begin(); - addSlot(*it++); // add slot - // add availability - for (; it != tok.end(); ++it) - available_ += (it->empty()) ? "0 " : "1 "; - available_ += '\n'; - } - } // constructor - - /** addStudent has to be called after adding slots and faculty */ - void Scheduler::addStudent(const string& studentName, - const string& area1, const string& area2, - const string& area3, const string& advisor) { - assert(nrStudents() area) const { - return area ? students_[s].keys_[*area] : students_[s].key_; + // open file + ifstream is(filename.c_str()); + if (!is) { + cerr << "Scheduler: could not open file " << filename << endl; + throw runtime_error("Scheduler: could not open file " + filename); } - const string& Scheduler::studentName(size_t i) const { - assert(i 100) throw runtime_error("reached 100 lines, exiting"); + Tokenizer tok(line); + Tokenizer::iterator it = tok.begin(); + addSlot(*it++); // add slot + // add availability + for (; it != tok.end(); ++it) available_ += (it->empty()) ? "0 " : "1 "; + available_ += '\n'; + } +} // constructor + +/** addStudent has to be called after adding slots and faculty */ +void Scheduler::addStudent(const string& studentName, const string& area1, + const string& area2, const string& area3, + const string& advisor) { + assert(nrStudents() < maxNrStudents_); + assert(facultyInArea_.count(area1)); + assert(facultyInArea_.count(area2)); + assert(facultyInArea_.count(area3)); + size_t advisorIndex = facultyIndex_[advisor]; + Student student(nrFaculty(), advisorIndex); + student.name_ = studentName; + // We fix the ordering by assigning a higher index to the student + // and numbering the areas lower + Key j = 3 * maxNrStudents_ + nrStudents(); + student.key_ = DiscreteKey(j, nrTimeSlots()); + Key base = 3 * nrStudents(); + student.keys_[0] = DiscreteKey(base + 0, nrFaculty()); + student.keys_[1] = DiscreteKey(base + 1, nrFaculty()); + student.keys_[2] = DiscreteKey(base + 2, nrFaculty()); + student.areaName_[0] = area1; + student.areaName_[1] = area2; + student.areaName_[2] = area3; + students_.push_back(student); +} + +/** get key for student and area, 0 is time slot itself */ +const DiscreteKey& Scheduler::key(size_t s, + boost::optional area) const { + return area ? students_[s].keys_[*area] : students_[s].key_; +} + +const string& Scheduler::studentName(size_t i) const { + assert(i < nrStudents()); + return students_[i].name_; +} + +const DiscreteKey& Scheduler::studentKey(size_t i) const { + assert(i < nrStudents()); + return students_[i].key_; +} + +const string& Scheduler::studentArea(size_t i, size_t area) const { + assert(i < nrStudents()); + return students_[i].areaName_[area]; +} + +/** Add student-specific constraints to the graph */ +void Scheduler::addStudentSpecificConstraints(size_t i, + boost::optional slot) { + bool debug = ISDEBUG("Scheduler::buildGraph"); + + assert(i < nrStudents()); + const Student& s = students_[i]; + + if (!slot && !slotsAvailable_.empty()) { + if (debug) cout << "Adding availability of slots" << endl; + assert(slotsAvailable_.size() == s.key_.second); + CSP::add(s.key_, slotsAvailable_); } - const string& Scheduler::studentArea(size_t i, size_t area) const { - assert(i slot) { - bool debug = ISDEBUG("Scheduler::buildGraph"); + if (debug) cout << "Area constraints " << areaName << endl; + assert(facultyInArea_[areaName].size() == areaKey.second); + CSP::add(areaKey, facultyInArea_[areaName]); - assert(i p(dummy & areaKey, + available_); // available_ is Doodle string + auto q = p.choose(dummyIndex, *slot); + CSP::add(areaKey, q); } else { - if (debug) cout << "Mutex for Students" << endl; - for (size_t i1 = 0; i1 < nrStudents(); i1++) { - // if mutexBound=1, we only mutex with next student - size_t bound = min((i1 + 1 + mutexBound), nrStudents()); - for (size_t i2 = i1 + 1; i2 < bound; i2++) { - addAllDiff(studentKey(i1), studentKey(i2)); - } + DiscreteKeys keys {s.key_, areaKey}; + CSP::add(keys, available_); // available_ is Doodle string + } + } + + // add mutex + if (debug) cout << "Mutex for faculty" << endl; + addAllDiff(s.keys_[0] & s.keys_[1] & s.keys_[2]); +} + +/** Main routine that builds factor graph */ +void Scheduler::buildGraph(size_t mutexBound) { + bool debug = ISDEBUG("Scheduler::buildGraph"); + + if (debug) cout << "Adding student-specific constraints" << endl; + for (size_t i = 0; i < nrStudents(); i++) addStudentSpecificConstraints(i); + + // special constraint for MN + if (studentName(0) == "Michael N") + CSP::add(studentKey(0), "0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1"); + + if (!mutexBound) { + DiscreteKeys dkeys; + for (const Student& s : students_) dkeys.push_back(s.key_); + addAllDiff(dkeys); + } else { + if (debug) cout << "Mutex for Students" << endl; + for (size_t i1 = 0; i1 < nrStudents(); i1++) { + // if mutexBound=1, we only mutex with next student + size_t bound = min((i1 + 1 + mutexBound), nrStudents()); + for (size_t i2 = i1 + 1; i2 < bound; i2++) { + addAllDiff(studentKey(i1), studentKey(i2)); } } - } // buildGraph - - /** print */ - void Scheduler::print(const string& s, const KeyFormatter& formatter) const { - cout << s << " Faculty:" << endl; - for(const string& name: facultyName_) - cout << name << '\n'; - cout << endl; - - cout << s << " Slots:\n"; - size_t i = 0; - for(const string& name: slotName_) - cout << i++ << " " << name << endl; - cout << endl; - - cout << "Availability:\n" << available_ << '\n'; - - cout << s << " Area constraints:\n"; - for(const FacultyInArea::value_type& it: facultyInArea_) - { - cout << setw(12) << it.first << ": "; - for(double v: it.second) - cout << v << " "; - cout << '\n'; - } - cout << endl; - - cout << s << " Students:\n"; - for (const Student& student: students_) - student.print(); - cout << endl; - - CSP::print(s + " Factor graph"); - cout << endl; - } // print - - /** Print readable form of assignment */ - void Scheduler::printAssignment(sharedValues assignment) const { - // Not intended to be general! Assumes very particular ordering ! - cout << endl; - for (size_t s = 0; s < nrStudents(); s++) { - Key j = 3*maxNrStudents_ + s; - size_t slot = assignment->at(j); - cout << studentName(s) << " slot: " << slotName_[slot] << endl; - Key base = 3*s; - for (size_t area = 0; area < 3; area++) { - size_t faculty = assignment->at(base+area); - cout << setw(12) << studentArea(s,area) << ": " << facultyName_[faculty] - << endl; - } - cout << endl; - } } +} // buildGraph - /** Special print for single-student case */ - void Scheduler::printSpecial(sharedValues assignment) const { - Values::const_iterator it = assignment->begin(); - for (size_t area = 0; area < 3; area++, it++) { - size_t f = it->second; - cout << setw(12) << studentArea(0,area) << ": " << facultyName_[f] << endl; +/** print */ +void Scheduler::print(const string& s, const KeyFormatter& formatter) const { + cout << s << " Faculty:" << endl; + for (const string& name : facultyName_) cout << name << '\n'; + cout << endl; + + cout << s << " Slots:\n"; + size_t i = 0; + for (const string& name : slotName_) cout << i++ << " " << name << endl; + cout << endl; + + cout << "Availability:\n" << available_ << '\n'; + + cout << s << " Area constraints:\n"; + for (const FacultyInArea::value_type& it : facultyInArea_) { + cout << setw(12) << it.first << ": "; + for (double v : it.second) cout << v << " "; + cout << '\n'; + } + cout << endl; + + cout << s << " Students:\n"; + for (const Student& student : students_) student.print(); + cout << endl; + + CSP::print(s + " Factor graph"); + cout << endl; +} // print + +/** Print readable form of assignment */ +void Scheduler::printAssignment(const DiscreteValues& assignment) const { + // Not intended to be general! Assumes very particular ordering ! + cout << endl; + for (size_t s = 0; s < nrStudents(); s++) { + Key j = 3 * maxNrStudents_ + s; + size_t slot = assignment.at(j); + cout << studentName(s) << " slot: " << slotName_[slot] << endl; + Key base = 3 * s; + for (size_t area = 0; area < 3; area++) { + size_t faculty = assignment.at(base + area); + cout << setw(12) << studentArea(s, area) << ": " << facultyName_[faculty] + << endl; } cout << endl; } +} - /** Accumulate faculty stats */ - void Scheduler::accumulateStats(sharedValues assignment, vector< - size_t>& stats) const { - for (size_t s = 0; s < nrStudents(); s++) { - Key base = 3*s; - for (size_t area = 0; area < 3; area++) { - size_t f = assignment->at(base+area); - assert(fsecond; + cout << setw(12) << studentArea(0, area) << ": " << facultyName_[f] << endl; } + cout << endl; +} - /** Eliminate, return a Bayes net */ - DiscreteBayesNet::shared_ptr Scheduler::eliminate() const { - gttic(my_eliminate); - // TODO: fix this!! - size_t maxKey = keys().size(); - Ordering defaultKeyOrdering; - for (size_t i = 0; ieliminateSequential(defaultKeyOrdering); - gttoc(my_eliminate); - return chordal; - } +/** Accumulate faculty stats */ +void Scheduler::accumulateStats(const DiscreteValues& assignment, + vector& stats) const { + for (size_t s = 0; s < nrStudents(); s++) { + Key base = 3 * s; + for (size_t area = 0; area < 3; area++) { + size_t f = assignment.at(base + area); + assert(f < stats.size()); + stats[f]++; + } // area + } // s +} - /** Find the best total assignment - can be expensive */ - Scheduler::sharedValues Scheduler::optimalAssignment() const { - DiscreteBayesNet::shared_ptr chordal = eliminate(); +/** Eliminate, return a Bayes net */ +DiscreteBayesNet::shared_ptr Scheduler::eliminate() const { + gttic(my_eliminate); + // TODO: fix this!! + size_t maxKey = keys().size(); + Ordering defaultKeyOrdering; + for (size_t i = 0; i < maxKey; ++i) defaultKeyOrdering += Key(i); + DiscreteBayesNet::shared_ptr chordal = + this->eliminateSequential(defaultKeyOrdering); + gttoc(my_eliminate); + return chordal; +} - if (ISDEBUG("Scheduler::optimalAssignment")) { - DiscreteBayesNet::const_iterator it = chordal->end()-1; - const Student & student = students_.front(); - cout << endl; - (*it)->print(student.name_); - } - - gttic(my_optimize); - sharedValues mpe = chordal->optimize(); - gttoc(my_optimize); - return mpe; - } - - /** find the assignment of students to slots with most possible committees */ - Scheduler::sharedValues Scheduler::bestSchedule() const { - sharedValues best; - throw runtime_error("bestSchedule not implemented"); - return best; - } - - /** find the corresponding most desirable committee assignment */ - Scheduler::sharedValues Scheduler::bestAssignment( - sharedValues bestSchedule) const { - sharedValues best; - throw runtime_error("bestAssignment not implemented"); - return best; - } - -} // gtsam +/** find the assignment of students to slots with most possible committees */ +DiscreteValues Scheduler::bestSchedule() const { + DiscreteValues best; + throw runtime_error("bestSchedule not implemented"); + return best; +} +/** find the corresponding most desirable committee assignment */ +DiscreteValues Scheduler::bestAssignment(const DiscreteValues& bestSchedule) const { + DiscreteValues best; + throw runtime_error("bestAssignment not implemented"); + return best; +} +} // namespace gtsam diff --git a/gtsam_unstable/discrete/Scheduler.h b/gtsam_unstable/discrete/Scheduler.h index 6faf9956f..8d269e81a 100644 --- a/gtsam_unstable/discrete/Scheduler.h +++ b/gtsam_unstable/discrete/Scheduler.h @@ -8,168 +8,151 @@ #pragma once #include +#include namespace gtsam { +/** + * Scheduler class + * Creates one variable for each student, and three variables for each + * of the student's areas, for a total of 4*nrStudents variables. + * The "student" variable will determine when the student takes the qual. + * The "area" variables determine which faculty are on his/her committee. + */ +class GTSAM_UNSTABLE_EXPORT Scheduler : public CSP { + private: + /** Internal data structure for students */ + struct Student { + std::string name_; + DiscreteKey key_; // key for student + std::vector keys_; // key for areas + std::vector areaName_; + std::vector advisor_; + Student(size_t nrFaculty, size_t advisorIndex) + : keys_(3), areaName_(3), advisor_(nrFaculty, 1.0) { + advisor_[advisorIndex] = 0.0; + } + void print() const { + using std::cout; + cout << name_ << ": "; + for (size_t area = 0; area < 3; area++) cout << areaName_[area] << " "; + cout << std::endl; + } + }; + + /** Maximum number of students */ + size_t maxNrStudents_; + + /** discrete keys, indexed by student and area index */ + std::vector students_; + + /** faculty identifiers */ + std::map facultyIndex_; + std::vector facultyName_, slotName_, areaName_; + + /** area constraints */ + typedef std::map > FacultyInArea; + FacultyInArea facultyInArea_; + + /** nrTimeSlots * nrFaculty availability constraints */ + std::string available_; + + /** which slots are good */ + std::vector slotsAvailable_; + + public: /** - * Scheduler class - * Creates one variable for each student, and three variables for each - * of the student's areas, for a total of 4*nrStudents variables. - * The "student" variable will determine when the student takes the qual. - * The "area" variables determine which faculty are on his/her committee. + * Constructor + * We need to know the number of students in advance for ordering keys. + * then add faculty, slots, areas, availability, students, in that order */ - class GTSAM_UNSTABLE_EXPORT Scheduler : public CSP { + Scheduler(size_t maxNrStudents) : maxNrStudents_(maxNrStudents) {} - private: + /// Destructor + virtual ~Scheduler() {} - /** Internal data structure for students */ - struct Student { - std::string name_; - DiscreteKey key_; // key for student - std::vector keys_; // key for areas - std::vector areaName_; - std::vector advisor_; - Student(size_t nrFaculty, size_t advisorIndex) : - keys_(3), areaName_(3), advisor_(nrFaculty, 1.0) { - advisor_[advisorIndex] = 0.0; - } - void print() const { - using std::cout; - cout << name_ << ": "; - for (size_t area = 0; area < 3; area++) - cout << areaName_[area] << " "; - cout << std::endl; - } - }; + void addFaculty(const std::string& facultyName) { + facultyIndex_[facultyName] = nrFaculty(); + facultyName_.push_back(facultyName); + } - /** Maximum number of students */ - size_t maxNrStudents_; + size_t nrFaculty() const { return facultyName_.size(); } - /** discrete keys, indexed by student and area index */ - std::vector students_; + /** boolean std::string of nrTimeSlots * nrFaculty */ + void setAvailability(const std::string& available) { available_ = available; } - /** faculty identifiers */ - std::map facultyIndex_; - std::vector facultyName_, slotName_, areaName_; + void addSlot(const std::string& slotName) { slotName_.push_back(slotName); } - /** area constraints */ - typedef std::map > FacultyInArea; - FacultyInArea facultyInArea_; + size_t nrTimeSlots() const { return slotName_.size(); } - /** nrTimeSlots * nrFaculty availability constraints */ - std::string available_; + const std::string& slotName(size_t s) const { return slotName_[s]; } - /** which slots are good */ - std::vector slotsAvailable_; + /** slots available, boolean */ + void setSlotsAvailable(const std::vector& slotsAvailable) { + slotsAvailable_ = slotsAvailable; + } - public: + void addArea(const std::string& facultyName, const std::string& areaName) { + areaName_.push_back(areaName); + std::vector& table = + facultyInArea_[areaName]; // will create if needed + if (table.empty()) table.resize(nrFaculty(), 0); + table[facultyIndex_[facultyName]] = 1; + } - /** - * Constructor - * We need to know the number of students in advance for ordering keys. - * then add faculty, slots, areas, availability, students, in that order - */ - Scheduler(size_t maxNrStudents) : maxNrStudents_(maxNrStudents) {} + /** + * Constructor that reads in faculty, slots, availibility. + * Still need to add areas and students after this + */ + Scheduler(size_t maxNrStudents, const std::string& filename); - /// Destructor - virtual ~Scheduler() {} + /** get key for student and area, 0 is time slot itself */ + const DiscreteKey& key(size_t s, + boost::optional area = boost::none) const; - void addFaculty(const std::string& facultyName) { - facultyIndex_[facultyName] = nrFaculty(); - facultyName_.push_back(facultyName); - } + /** addStudent has to be called after adding slots and faculty */ + void addStudent(const std::string& studentName, const std::string& area1, + const std::string& area2, const std::string& area3, + const std::string& advisor); - size_t nrFaculty() const { - return facultyName_.size(); - } + /// current number of students + size_t nrStudents() const { return students_.size(); } - /** boolean std::string of nrTimeSlots * nrFaculty */ - void setAvailability(const std::string& available) { - available_ = available; - } + const std::string& studentName(size_t i) const; + const DiscreteKey& studentKey(size_t i) const; + const std::string& studentArea(size_t i, size_t area) const; - void addSlot(const std::string& slotName) { - slotName_.push_back(slotName); - } + /** Add student-specific constraints to the graph */ + void addStudentSpecificConstraints( + size_t i, boost::optional slot = boost::none); - size_t nrTimeSlots() const { - return slotName_.size(); - } + /** Main routine that builds factor graph */ + void buildGraph(size_t mutexBound = 7); - const std::string& slotName(size_t s) const { - return slotName_[s]; - } + /** print */ + void print( + const std::string& s = "Scheduler", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; - /** slots available, boolean */ - void setSlotsAvailable(const std::vector& slotsAvailable) { - slotsAvailable_ = slotsAvailable; - } + /** Print readable form of assignment */ + void printAssignment(const DiscreteValues& assignment) const; - void addArea(const std::string& facultyName, const std::string& areaName) { - areaName_.push_back(areaName); - std::vector& table = facultyInArea_[areaName]; // will create if needed - if (table.empty()) table.resize(nrFaculty(), 0); - table[facultyIndex_[facultyName]] = 1; - } + /** Special print for single-student case */ + void printSpecial(const DiscreteValues& assignment) const; - /** - * Constructor that reads in faculty, slots, availibility. - * Still need to add areas and students after this - */ - Scheduler(size_t maxNrStudents, const std::string& filename); + /** Accumulate faculty stats */ + void accumulateStats(const DiscreteValues& assignment, + std::vector& stats) const; - /** get key for student and area, 0 is time slot itself */ - const DiscreteKey& key(size_t s, boost::optional area = boost::none) const; + /** Eliminate, return a Bayes net */ + DiscreteBayesNet::shared_ptr eliminate() const; - /** addStudent has to be called after adding slots and faculty */ - void addStudent(const std::string& studentName, const std::string& area1, - const std::string& area2, const std::string& area3, - const std::string& advisor); + /** find the assignment of students to slots with most possible committees */ + DiscreteValues bestSchedule() const; - /// current number of students - size_t nrStudents() const { - return students_.size(); - } - - const std::string& studentName(size_t i) const; - const DiscreteKey& studentKey(size_t i) const; - const std::string& studentArea(size_t i, size_t area) const; - - /** Add student-specific constraints to the graph */ - void addStudentSpecificConstraints(size_t i, boost::optional slot = boost::none); - - /** Main routine that builds factor graph */ - void buildGraph(size_t mutexBound = 7); - - /** print */ - void print( - const std::string& s = "Scheduler", - const KeyFormatter& formatter = DefaultKeyFormatter) const override; - - /** Print readable form of assignment */ - void printAssignment(sharedValues assignment) const; - - /** Special print for single-student case */ - void printSpecial(sharedValues assignment) const; - - /** Accumulate faculty stats */ - void accumulateStats(sharedValues assignment, - std::vector& stats) const; - - /** Eliminate, return a Bayes net */ - DiscreteBayesNet::shared_ptr eliminate() const; - - /** Find the best total assignment - can be expensive */ - sharedValues optimalAssignment() const; - - /** find the assignment of students to slots with most possible committees */ - sharedValues bestSchedule() const; - - /** find the corresponding most desirable committee assignment */ - sharedValues bestAssignment(sharedValues bestSchedule) const; - - }; // Scheduler - -} // gtsam + /** find the corresponding most desirable committee assignment */ + DiscreteValues bestAssignment(const DiscreteValues& bestSchedule) const; +}; // Scheduler +} // namespace gtsam diff --git a/gtsam_unstable/discrete/SingleValue.cpp b/gtsam_unstable/discrete/SingleValue.cpp index 6324f14cd..6dd81a7dc 100644 --- a/gtsam_unstable/discrete/SingleValue.cpp +++ b/gtsam_unstable/discrete/SingleValue.cpp @@ -5,75 +5,73 @@ * @author Frank Dellaert */ -#include -#include -#include #include +#include +#include +#include + #include namespace gtsam { - using namespace std; - - /* ************************************************************************* */ - void SingleValue::print(const string& s, - const KeyFormatter& formatter) const { - cout << s << "SingleValue on " << "j=" << formatter(keys_[0]) - << " with value " << value_ << endl; - } - - /* ************************************************************************* */ - double SingleValue::operator()(const Values& values) const { - return (double) (values.at(keys_[0]) == value_); - } - - /* ************************************************************************* */ - DecisionTreeFactor SingleValue::toDecisionTreeFactor() const { - DiscreteKeys keys; - keys += DiscreteKey(keys_[0],cardinality_); - vector table; - for (size_t i1 = 0; i1 < cardinality_; i1++) - table.push_back(i1 == value_); - DecisionTreeFactor converted(keys, table); - return converted; - } - - /* ************************************************************************* */ - DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const { - // TODO: can we do this more efficiently? - return toDecisionTreeFactor() * f; - } - - /* ************************************************************************* */ - bool SingleValue::ensureArcConsistency(size_t j, - vector& domains) const { - if (j != keys_[0]) throw invalid_argument( - "SingleValue check on wrong domain"); - Domain& D = domains[j]; - if (D.isSingleton()) { - if (D.firstValue() != value_) throw runtime_error("Unsatisfiable"); - return false; - } - D = Domain(discreteKey(),value_); - return true; - } - - /* ************************************************************************* */ - Constraint::shared_ptr SingleValue::partiallyApply(const Values& values) const { - Values::const_iterator it = values.find(keys_[0]); - if (it != values.end() && it->second != value_) throw runtime_error( - "SingleValue::partiallyApply: unsatisfiable"); - return boost::make_shared < SingleValue > (keys_[0], cardinality_, value_); - } - - /* ************************************************************************* */ - Constraint::shared_ptr SingleValue::partiallyApply( - const vector& domains) const { - const Domain& Dk = domains[keys_[0]]; - if (Dk.isSingleton() && !Dk.contains(value_)) throw runtime_error( - "SingleValue::partiallyApply: unsatisfiable"); - return boost::make_shared < SingleValue > (discreteKey(), value_); - } +using namespace std; /* ************************************************************************* */ -} // namespace gtsam +void SingleValue::print(const string& s, const KeyFormatter& formatter) const { + cout << s << "SingleValue on " + << "j=" << formatter(keys_[0]) << " with value " << value_ << endl; +} + +/* ************************************************************************* */ +double SingleValue::operator()(const DiscreteValues& values) const { + return (double)(values.at(keys_[0]) == value_); +} + +/* ************************************************************************* */ +DecisionTreeFactor SingleValue::toDecisionTreeFactor() const { + DiscreteKeys keys; + keys += DiscreteKey(keys_[0], cardinality_); + vector table; + for (size_t i1 = 0; i1 < cardinality_; i1++) table.push_back(i1 == value_); + DecisionTreeFactor converted(keys, table); + return converted; +} + +/* ************************************************************************* */ +DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const { + // TODO: can we do this more efficiently? + return toDecisionTreeFactor() * f; +} + +/* ************************************************************************* */ +bool SingleValue::ensureArcConsistency(Key j, Domains* domains) const { + if (j != keys_[0]) + throw invalid_argument("SingleValue check on wrong domain"); + Domain& D = domains->at(j); + if (D.isSingleton()) { + if (D.firstValue() != value_) throw runtime_error("Unsatisfiable"); + return false; + } + D = Domain(discreteKey(), value_); + return true; +} + +/* ************************************************************************* */ +Constraint::shared_ptr SingleValue::partiallyApply(const DiscreteValues& values) const { + DiscreteValues::const_iterator it = values.find(keys_[0]); + if (it != values.end() && it->second != value_) + throw runtime_error("SingleValue::partiallyApply: unsatisfiable"); + return boost::make_shared(keys_[0], cardinality_, value_); +} + +/* ************************************************************************* */ +Constraint::shared_ptr SingleValue::partiallyApply( + const Domains& domains) const { + const Domain& Dk = domains.at(keys_[0]); + if (Dk.isSingleton() && !Dk.contains(value_)) + throw runtime_error("SingleValue::partiallyApply: unsatisfiable"); + return boost::make_shared(discreteKey(), value_); +} + +/* ************************************************************************* */ +} // namespace gtsam diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h index c4d2addec..3b2d6e80b 100644 --- a/gtsam_unstable/discrete/SingleValue.h +++ b/gtsam_unstable/discrete/SingleValue.h @@ -7,76 +7,71 @@ #pragma once -#include #include +#include namespace gtsam { - /** - * SingleValue constraint +/** + * SingleValue constraint: ensures a variable takes on a certain value. + * This could of course also be implemented by changing its `Domain`. + */ +class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { + size_t cardinality_; /// < Number of values + size_t value_; ///< allowed value + + DiscreteKey discreteKey() const { + return DiscreteKey(keys_[0], cardinality_); + } + + public: + typedef boost::shared_ptr shared_ptr; + + /// Construct from key, cardinality, and given value. + SingleValue(Key key, size_t n, size_t value) + : Constraint(key), cardinality_(n), value_(value) {} + + /// Construct from DiscreteKey and given value. + SingleValue(const DiscreteKey& dkey, size_t value) + : Constraint(dkey.first), cardinality_(dkey.second), value_(value) {} + + // print + void print(const std::string& s = "", const KeyFormatter& formatter = + DefaultKeyFormatter) const override; + + /// equals + bool equals(const DiscreteFactor& other, double tol) const override { + if (!dynamic_cast(&other)) + return false; + else { + const SingleValue& f(static_cast(other)); + return (cardinality_ == f.cardinality_) && (value_ == f.value_); + } + } + + /// Calculate value + double operator()(const DiscreteValues& values) const override; + + /// Convert into a decisiontree + DecisionTreeFactor toDecisionTreeFactor() const override; + + /// Multiply into a decisiontree + DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + + /* + * Ensure Arc-consistency: just sets domain[j] to {value_}. + * @param j domain to be checked + * @param (in/out) domains all domains, but only domains->at(j) will be checked. + * @return true if domains->at(j) was changed, false otherwise. */ - class GTSAM_UNSTABLE_EXPORT SingleValue: public Constraint { + bool ensureArcConsistency(Key j, Domains* domains) const override; - /// Number of values - size_t cardinality_; + /// Partially apply known values + Constraint::shared_ptr partiallyApply(const DiscreteValues& values) const override; - /// allowed value - size_t value_; + /// Partially apply known values, domain version + Constraint::shared_ptr partiallyApply( + const Domains& domains) const override; +}; - DiscreteKey discreteKey() const { - return DiscreteKey(keys_[0],cardinality_); - } - - public: - - typedef boost::shared_ptr shared_ptr; - - /// Constructor - SingleValue(Key key, size_t n, size_t value) : - Constraint(key), cardinality_(n), value_(value) { - } - - /// Constructor - SingleValue(const DiscreteKey& dkey, size_t value) : - Constraint(dkey.first), cardinality_(dkey.second), value_(value) { - } - - // print - void print(const std::string& s = "", - const KeyFormatter& formatter = DefaultKeyFormatter) const override; - - /// equals - bool equals(const DiscreteFactor& other, double tol) const override { - if(!dynamic_cast(&other)) - return false; - else { - const SingleValue& f(static_cast(other)); - return (cardinality_==f.cardinality_) && (value_==f.value_); - } - } - - /// Calculate value - double operator()(const Values& values) const override; - - /// Convert into a decisiontree - DecisionTreeFactor toDecisionTreeFactor() const override; - - /// Multiply into a decisiontree - DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; - - /* - * Ensure Arc-consistency - * @param j domain to be checked - * @param domains all other domains - */ - bool ensureArcConsistency(size_t j, std::vector& domains) const override; - - /// Partially apply known values - Constraint::shared_ptr partiallyApply(const Values& values) const override; - - /// Partially apply known values, domain version - Constraint::shared_ptr partiallyApply( - const std::vector& domains) const override; - }; - -} // namespace gtsam +} // namespace gtsam diff --git a/gtsam_unstable/discrete/examples/schedulingExample.cpp b/gtsam_unstable/discrete/examples/schedulingExample.cpp index e9f63b2d8..487edc97a 100644 --- a/gtsam_unstable/discrete/examples/schedulingExample.cpp +++ b/gtsam_unstable/discrete/examples/schedulingExample.cpp @@ -115,14 +115,14 @@ void runLargeExample() { // Do brute force product and output that to file if (scheduler.nrStudents() == 1) { // otherwise too slow DecisionTreeFactor product = scheduler.product(); - product.dot("scheduling-large", false); + product.dot("scheduling-large", DefaultKeyFormatter, false); } // Do exact inference // SETDEBUG("timing-verbose", true); SETDEBUG("DiscreteConditional::DiscreteConditional", true); gttic(large); - DiscreteFactor::sharedValues MPE = scheduler.optimalAssignment(); + auto MPE = scheduler.optimize(); gttoc(large); tictoc_finishedIteration(); tictoc_print(); @@ -165,11 +165,11 @@ void solveStaged(size_t addMutex = 2) { root->print(""/*scheduler.studentName(s)*/); // solve root node only - Scheduler::Values values; - size_t bestSlot = root->solve(values); + size_t bestSlot = root->argmax(); // get corresponding count DiscreteKey dkey = scheduler.studentKey(6 - s); + DiscreteValues values; values[dkey.first] = bestSlot; size_t count = (*root)(values); @@ -225,7 +225,7 @@ void sampleSolutions() { // now, sample schedules for (size_t n = 0; n < 500; n++) { vector stats(19, 0); - vector samples; + vector samples; for (size_t i = 0; i < 7; i++) { samples.push_back(samplers[i]->sample()); schedulers[i].accumulateStats(samples[i], stats); @@ -319,11 +319,11 @@ void accomodateStudent() { // GTSAM_PRINT(*chordal); // solve root node only - Scheduler::Values values; - size_t bestSlot = root->solve(values); + size_t bestSlot = root->argmax(); // get corresponding count DiscreteKey dkey = scheduler.studentKey(0); + DiscreteValues values; values[dkey.first] = bestSlot; size_t count = (*root)(values); cout << boost::format("%s = %d (%d), count = %d") % scheduler.studentName(0) @@ -331,7 +331,7 @@ void accomodateStudent() { // sample schedules for (size_t n = 0; n < 10; n++) { - Scheduler::sharedValues sample0 = chordal->sample(); + auto sample0 = chordal->sample(); scheduler.printAssignment(sample0); } } diff --git a/gtsam_unstable/discrete/examples/schedulingQuals12.cpp b/gtsam_unstable/discrete/examples/schedulingQuals12.cpp index 1fc4a1459..830d59ba7 100644 --- a/gtsam_unstable/discrete/examples/schedulingQuals12.cpp +++ b/gtsam_unstable/discrete/examples/schedulingQuals12.cpp @@ -115,7 +115,7 @@ void runLargeExample() { // Do brute force product and output that to file if (scheduler.nrStudents() == 1) { // otherwise too slow DecisionTreeFactor product = scheduler.product(); - product.dot("scheduling-large", false); + product.dot("scheduling-large", DefaultKeyFormatter, false); } // Do exact inference @@ -129,7 +129,7 @@ void runLargeExample() { tictoc_finishedIteration(); tictoc_print(); for (size_t i=0;i<100;i++) { - DiscreteFactor::sharedValues assignment = chordal->sample(); + auto assignment = chordal->sample(); vector stats(scheduler.nrFaculty()); scheduler.accumulateStats(assignment, stats); size_t max = *max_element(stats.begin(), stats.end()); @@ -143,7 +143,7 @@ void runLargeExample() { } #else gttic(large); - DiscreteFactor::sharedValues MPE = scheduler.optimalAssignment(); + auto MPE = scheduler.optimize(); gttoc(large); tictoc_finishedIteration(); tictoc_print(); @@ -190,11 +190,11 @@ void solveStaged(size_t addMutex = 2) { root->print(""/*scheduler.studentName(s)*/); // solve root node only - Scheduler::Values values; - size_t bestSlot = root->solve(values); + size_t bestSlot = root->argmax(); // get corresponding count DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s); + DiscreteValues values; values[dkey.first] = bestSlot; size_t count = (*root)(values); @@ -234,7 +234,7 @@ void sampleSolutions() { // now, sample schedules for (size_t n = 0; n < 500; n++) { vector stats(19, 0); - vector samples; + vector samples; for (size_t i = 0; i < NRSTUDENTS; i++) { samples.push_back(samplers[i]->sample()); schedulers[i].accumulateStats(samples[i], stats); diff --git a/gtsam_unstable/discrete/examples/schedulingQuals13.cpp b/gtsam_unstable/discrete/examples/schedulingQuals13.cpp index 95b64f289..b24f9bf0a 100644 --- a/gtsam_unstable/discrete/examples/schedulingQuals13.cpp +++ b/gtsam_unstable/discrete/examples/schedulingQuals13.cpp @@ -139,7 +139,7 @@ void runLargeExample() { // Do brute force product and output that to file if (scheduler.nrStudents() == 1) { // otherwise too slow DecisionTreeFactor product = scheduler.product(); - product.dot("scheduling-large", false); + product.dot("scheduling-large", DefaultKeyFormatter, false); } // Do exact inference @@ -153,7 +153,7 @@ void runLargeExample() { tictoc_finishedIteration(); tictoc_print(); for (size_t i=0;i<100;i++) { - DiscreteFactor::sharedValues assignment = sample(*chordal); + auto assignment = sample(*chordal); vector stats(scheduler.nrFaculty()); scheduler.accumulateStats(assignment, stats); size_t max = *max_element(stats.begin(), stats.end()); @@ -167,7 +167,7 @@ void runLargeExample() { } #else gttic(large); - DiscreteFactor::sharedValues MPE = scheduler.optimalAssignment(); + auto MPE = scheduler.optimize(); gttoc(large); tictoc_finishedIteration(); tictoc_print(); @@ -212,11 +212,11 @@ void solveStaged(size_t addMutex = 2) { root->print(""/*scheduler.studentName(s)*/); // solve root node only - Scheduler::Values values; - size_t bestSlot = root->solve(values); + size_t bestSlot = root->argmax(); // get corresponding count DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s); + DiscreteValues values; values[dkey.first] = bestSlot; double count = (*root)(values); @@ -259,7 +259,7 @@ void sampleSolutions() { // now, sample schedules for (size_t n = 0; n < 10000; n++) { vector stats(nrFaculty, 0); - vector samples; + vector samples; for (size_t i = 0; i < NRSTUDENTS; i++) { samples.push_back(samplers[i]->sample()); schedulers[i].accumulateStats(samples[i], stats); diff --git a/gtsam_unstable/discrete/tests/testCSP.cpp b/gtsam_unstable/discrete/tests/testCSP.cpp index 3dd493b1b..fb386b255 100644 --- a/gtsam_unstable/discrete/tests/testCSP.cpp +++ b/gtsam_unstable/discrete/tests/testCSP.cpp @@ -7,59 +7,119 @@ #include #include + #include using boost::assign::insert; #include -#include + #include +#include using namespace std; using namespace gtsam; /* ************************************************************************* */ -TEST_UNSAFE( BinaryAllDif, allInOne) -{ - // Create keys and ordering +TEST(CSP, SingleValue) { + // Create keys for Idaho, Arizona, and Utah, allowing two colors for each: + size_t nrColors = 3; + DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); + + // Check that a single value is equal to a decision stump with only one "1": + SingleValue singleValue(AZ, 2); + DecisionTreeFactor f1(AZ, "0 0 1"); + EXPECT(assert_equal(f1, singleValue.toDecisionTreeFactor())); + + // Create domains + Domains domains; + domains.emplace(0, Domain(ID)); + domains.emplace(1, Domain(AZ)); + domains.emplace(2, Domain(UT)); + + // Ensure arc-consistency: just wipes out values in AZ domain: + EXPECT(singleValue.ensureArcConsistency(1, &domains)); + LONGS_EQUAL(3, domains.at(0).nrValues()); + LONGS_EQUAL(1, domains.at(1).nrValues()); + LONGS_EQUAL(3, domains.at(2).nrValues()); +} + +/* ************************************************************************* */ +TEST(CSP, BinaryAllDif) { + // Create keys for Idaho, Arizona, and Utah, allowing 2 colors for each: size_t nrColors = 2; -// DiscreteKey ID("Idaho", nrColors), UT("Utah", nrColors), AZ("Arizona", nrColors); - DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors); + DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); // Check construction and conversion BinaryAllDiff c1(ID, UT); DecisionTreeFactor f1(ID & UT, "0 1 1 0"); - EXPECT(assert_equal(f1,c1.toDecisionTreeFactor())); + EXPECT(assert_equal(f1, c1.toDecisionTreeFactor())); // Check construction and conversion BinaryAllDiff c2(UT, AZ); DecisionTreeFactor f2(UT & AZ, "0 1 1 0"); - EXPECT(assert_equal(f2,c2.toDecisionTreeFactor())); + EXPECT(assert_equal(f2, c2.toDecisionTreeFactor())); - DecisionTreeFactor f3 = f1*f2; - EXPECT(assert_equal(f3,c1*f2)); - EXPECT(assert_equal(f3,c2*f1)); + // Check multiplication of factors with constraint: + DecisionTreeFactor f3 = f1 * f2; + EXPECT(assert_equal(f3, c1 * f2)); + EXPECT(assert_equal(f3, c2 * f1)); } /* ************************************************************************* */ -TEST_UNSAFE( CSP, allInOne) -{ - // Create keys and ordering +TEST(CSP, AllDiff) { + // Create keys for Idaho, Arizona, and Utah, allowing two colors for each: + size_t nrColors = 3; + DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); + + // Check construction and conversion + vector dkeys{ID, UT, AZ}; + AllDiff alldiff(dkeys); + DecisionTreeFactor actual = alldiff.toDecisionTreeFactor(); + // GTSAM_PRINT(actual); + actual.dot("actual"); + DecisionTreeFactor f2( + ID & AZ & UT, + "0 0 0 0 0 1 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 1 0 0 0 0 0"); + EXPECT(assert_equal(f2, actual)); + + // Create domains. + Domains domains; + domains.emplace(0, Domain(ID)); + domains.emplace(1, Domain(AZ)); + domains.emplace(2, Domain(UT)); + + // First constrict AZ domain: + SingleValue singleValue(AZ, 2); + EXPECT(singleValue.ensureArcConsistency(1, &domains)); + + // Arc-consistency + EXPECT(alldiff.ensureArcConsistency(0, &domains)); + EXPECT(!alldiff.ensureArcConsistency(1, &domains)); + EXPECT(alldiff.ensureArcConsistency(2, &domains)); + LONGS_EQUAL(2, domains.at(0).nrValues()); + LONGS_EQUAL(1, domains.at(1).nrValues()); + LONGS_EQUAL(2, domains.at(2).nrValues()); +} + +/* ************************************************************************* */ +TEST(CSP, allInOne) { + // Create keys for Idaho, Arizona, and Utah, allowing 3 colors for each: size_t nrColors = 2; - DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors); + DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); // Create the CSP CSP csp; - csp.addAllDiff(ID,UT); - csp.addAllDiff(UT,AZ); + csp.addAllDiff(ID, UT); + csp.addAllDiff(UT, AZ); // Check an invalid combination, with ID==UT==AZ all same color - DiscreteFactor::Values invalid; + DiscreteValues invalid; invalid[ID.first] = 0; invalid[UT.first] = 0; invalid[AZ.first] = 0; EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9); // Check a valid combination - DiscreteFactor::Values valid; + DiscreteValues valid; valid[ID.first] = 0; valid[UT.first] = 1; valid[AZ.first] = 0; @@ -69,68 +129,62 @@ TEST_UNSAFE( CSP, allInOne) DecisionTreeFactor product = csp.product(); // product.dot("product"); DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0"); - EXPECT(assert_equal(expectedProduct,product)); + EXPECT(assert_equal(expectedProduct, product)); // Solve - CSP::sharedValues mpe = csp.optimalAssignment(); - CSP::Values expected; + auto mpe = csp.optimize(); + DiscreteValues expected; insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 1); - EXPECT(assert_equal(expected,*mpe)); - EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); + EXPECT(assert_equal(expected, mpe)); + EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9); } /* ************************************************************************* */ -TEST_UNSAFE( CSP, WesternUS) -{ - // Create keys +TEST(CSP, WesternUS) { + // Create keys for all states in Western US, with 4 color possibilities. size_t nrColors = 4; - DiscreteKey - // Create ordering according to example in ND-CSP.lyx - WA(0, nrColors), OR(3, nrColors), CA(1, nrColors),NV(2, nrColors), - ID(8, nrColors), UT(9, nrColors), AZ(10, nrColors), - MT(4, nrColors), WY(5, nrColors), CO(7, nrColors), NM(6, nrColors); + DiscreteKey WA(0, nrColors), OR(3, nrColors), CA(1, nrColors), + NV(2, nrColors), ID(8, nrColors), UT(9, nrColors), AZ(10, nrColors), + MT(4, nrColors), WY(5, nrColors), CO(7, nrColors), NM(6, nrColors); // Create the CSP CSP csp; - csp.addAllDiff(WA,ID); - csp.addAllDiff(WA,OR); - csp.addAllDiff(OR,ID); - csp.addAllDiff(OR,CA); - csp.addAllDiff(OR,NV); - csp.addAllDiff(CA,NV); - csp.addAllDiff(CA,AZ); - csp.addAllDiff(ID,MT); - csp.addAllDiff(ID,WY); - csp.addAllDiff(ID,UT); - csp.addAllDiff(ID,NV); - csp.addAllDiff(NV,UT); - csp.addAllDiff(NV,AZ); - csp.addAllDiff(UT,WY); - csp.addAllDiff(UT,CO); - csp.addAllDiff(UT,NM); - csp.addAllDiff(UT,AZ); - csp.addAllDiff(AZ,CO); - csp.addAllDiff(AZ,NM); - csp.addAllDiff(MT,WY); - csp.addAllDiff(WY,CO); - csp.addAllDiff(CO,NM); + csp.addAllDiff(WA, ID); + csp.addAllDiff(WA, OR); + csp.addAllDiff(OR, ID); + csp.addAllDiff(OR, CA); + csp.addAllDiff(OR, NV); + csp.addAllDiff(CA, NV); + csp.addAllDiff(CA, AZ); + csp.addAllDiff(ID, MT); + csp.addAllDiff(ID, WY); + csp.addAllDiff(ID, UT); + csp.addAllDiff(ID, NV); + csp.addAllDiff(NV, UT); + csp.addAllDiff(NV, AZ); + csp.addAllDiff(UT, WY); + csp.addAllDiff(UT, CO); + csp.addAllDiff(UT, NM); + csp.addAllDiff(UT, AZ); + csp.addAllDiff(AZ, CO); + csp.addAllDiff(AZ, NM); + csp.addAllDiff(MT, WY); + csp.addAllDiff(WY, CO); + csp.addAllDiff(CO, NM); - // Solve + DiscreteValues mpe; + insert(mpe)(0, 2)(1, 3)(2, 2)(3, 1)(4, 1)(5, 3)(6, 3)(7, 2)(8, 0)(9, 1)(10, 0); + + // Create ordering according to example in ND-CSP.lyx Ordering ordering; - ordering += Key(0),Key(1),Key(2),Key(3),Key(4),Key(5),Key(6),Key(7),Key(8),Key(9),Key(10); - CSP::sharedValues mpe = csp.optimalAssignment(ordering); - // GTSAM_PRINT(*mpe); - CSP::Values expected; - insert(expected) - (WA.first,1)(CA.first,1)(NV.first,3)(OR.first,0) - (MT.first,1)(WY.first,0)(NM.first,3)(CO.first,2) - (ID.first,2)(UT.first,1)(AZ.first,0); + ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7), + Key(8), Key(9), Key(10); - // TODO: Fix me! mpe result seems to be right. (See the printing) - // It has the same prob as the expected solution. - // Is mpe another solution, or the expected solution is unique??? - EXPECT(assert_equal(expected,*mpe)); - EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); + // Solve using that ordering: + auto actualMPE = csp.optimize(ordering); + + EXPECT(assert_equal(mpe, actualMPE)); + EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9); // Write out the dual graph for hmetis #ifdef DUAL @@ -142,85 +196,74 @@ TEST_UNSAFE( CSP, WesternUS) } /* ************************************************************************* */ -TEST_UNSAFE( CSP, AllDiff) -{ - // Create keys and ordering +TEST(CSP, ArcConsistency) { + // Create keys for Idaho, Arizona, and Utah, allowing three colors for each: size_t nrColors = 3; - DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors); + DiscreteKey ID(0, nrColors), AZ(1, nrColors), UT(2, nrColors); - // Create the CSP + // Create the CSP using just one all-diff constraint, plus constrain Arizona. CSP csp; - vector dkeys; - dkeys += ID,UT,AZ; + vector dkeys{ID, UT, AZ}; csp.addAllDiff(dkeys); - csp.addSingleValue(AZ,2); -// GTSAM_PRINT(csp); - - // Check construction and conversion - SingleValue s(AZ,2); - DecisionTreeFactor f1(AZ,"0 0 1"); - EXPECT(assert_equal(f1,s.toDecisionTreeFactor())); - - // Check construction and conversion - AllDiff alldiff(dkeys); - DecisionTreeFactor actual = alldiff.toDecisionTreeFactor(); -// GTSAM_PRINT(actual); -// actual.dot("actual"); - DecisionTreeFactor f2(ID & AZ & UT, - "0 0 0 0 0 1 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 1 0 0 0 0 0"); - EXPECT(assert_equal(f2,actual)); + csp.addSingleValue(AZ, 2); + // GTSAM_PRINT(csp); // Check an invalid combination, with ID==UT==AZ all same color - DiscreteFactor::Values invalid; + DiscreteValues invalid; invalid[ID.first] = 0; invalid[UT.first] = 1; invalid[AZ.first] = 0; EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9); // Check a valid combination - DiscreteFactor::Values valid; + DiscreteValues valid; valid[ID.first] = 0; valid[UT.first] = 1; valid[AZ.first] = 2; EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9); // Solve - CSP::sharedValues mpe = csp.optimalAssignment(); - CSP::Values expected; + auto mpe = csp.optimize(); + DiscreteValues expected; insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 2); - EXPECT(assert_equal(expected,*mpe)); - EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); + EXPECT(assert_equal(expected, mpe)); + EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9); - // Arc-consistency - vector domains; - domains += Domain(ID), Domain(AZ), Domain(UT); - SingleValue singleValue(AZ,2); - EXPECT(singleValue.ensureArcConsistency(1,domains)); - EXPECT(alldiff.ensureArcConsistency(0,domains)); - EXPECT(!alldiff.ensureArcConsistency(1,domains)); - EXPECT(alldiff.ensureArcConsistency(2,domains)); - LONGS_EQUAL(2,domains[0].nrValues()); - LONGS_EQUAL(1,domains[1].nrValues()); - LONGS_EQUAL(2,domains[2].nrValues()); + // ensure arc-consistency, i.e., narrow domains... + Domains domains; + domains.emplace(0, Domain(ID)); + domains.emplace(1, Domain(AZ)); + domains.emplace(2, Domain(UT)); + + SingleValue singleValue(AZ, 2); + AllDiff alldiff(dkeys); + EXPECT(singleValue.ensureArcConsistency(1, &domains)); + EXPECT(alldiff.ensureArcConsistency(0, &domains)); + EXPECT(!alldiff.ensureArcConsistency(1, &domains)); + EXPECT(alldiff.ensureArcConsistency(2, &domains)); + LONGS_EQUAL(2, domains.at(0).nrValues()); + LONGS_EQUAL(1, domains.at(1).nrValues()); + LONGS_EQUAL(2, domains.at(2).nrValues()); // Parial application, version 1 - DiscreteFactor::Values known; + DiscreteValues known; known[AZ.first] = 2; DiscreteFactor::shared_ptr reduced1 = alldiff.partiallyApply(known); DecisionTreeFactor f3(ID & UT, "0 1 1 1 0 1 1 1 0"); - EXPECT(assert_equal(f3,reduced1->toDecisionTreeFactor())); + EXPECT(assert_equal(f3, reduced1->toDecisionTreeFactor())); DiscreteFactor::shared_ptr reduced2 = singleValue.partiallyApply(known); DecisionTreeFactor f4(AZ, "0 0 1"); - EXPECT(assert_equal(f4,reduced2->toDecisionTreeFactor())); + EXPECT(assert_equal(f4, reduced2->toDecisionTreeFactor())); // Parial application, version 2 DiscreteFactor::shared_ptr reduced3 = alldiff.partiallyApply(domains); - EXPECT(assert_equal(f3,reduced3->toDecisionTreeFactor())); + EXPECT(assert_equal(f3, reduced3->toDecisionTreeFactor())); DiscreteFactor::shared_ptr reduced4 = singleValue.partiallyApply(domains); - EXPECT(assert_equal(f4,reduced4->toDecisionTreeFactor())); + EXPECT(assert_equal(f4, reduced4->toDecisionTreeFactor())); // full arc-consistency test csp.runArcConsistency(nrColors); + // GTSAM_PRINT(csp); } /* ************************************************************************* */ @@ -229,4 +272,3 @@ int main() { return TestRegistry::runAllTests(tr); } /* ************************************************************************* */ - diff --git a/gtsam_unstable/discrete/tests/testLoopyBelief.cpp b/gtsam_unstable/discrete/tests/testLoopyBelief.cpp index 9929938d5..eac0d834e 100644 --- a/gtsam_unstable/discrete/tests/testLoopyBelief.cpp +++ b/gtsam_unstable/discrete/tests/testLoopyBelief.cpp @@ -5,14 +5,16 @@ * @date Oct 11, 2013 */ -#include +#include #include #include -#include -#include +#include +#include + #include -#include +#include #include +#include using namespace std; using namespace boost; @@ -23,11 +25,12 @@ using namespace gtsam; * Loopy belief solver for graphs with only binary and unary factors */ class LoopyBelief { - /** Star graph struct for each node, containing * - the star graph itself - * - the product of original unary factors so we don't have to recompute it later, and - * - the factor indices of the corrected belief factors of the neighboring nodes + * - the product of original unary factors so we don't have to recompute it + * later, and + * - the factor indices of the corrected belief factors of the neighboring + * nodes */ typedef std::map CorrectedBeliefIndices; struct StarGraph { @@ -36,41 +39,41 @@ class LoopyBelief { DecisionTreeFactor::shared_ptr unary; VariableIndex varIndex_; StarGraph(const DiscreteFactorGraph::shared_ptr& _star, - const CorrectedBeliefIndices& _beliefIndices, - const DecisionTreeFactor::shared_ptr& _unary) : - star(_star), correctedBeliefIndices(_beliefIndices), unary(_unary), varIndex_( - *_star) { - } + const CorrectedBeliefIndices& _beliefIndices, + const DecisionTreeFactor::shared_ptr& _unary) + : star(_star), + correctedBeliefIndices(_beliefIndices), + unary(_unary), + varIndex_(*_star) {} void print(const std::string& s = "") const { cout << s << ":" << endl; star->print("Star graph: "); - for(Key key: correctedBeliefIndices | boost::adaptors::map_keys) { + for (Key key : correctedBeliefIndices | boost::adaptors::map_keys) { cout << "Belief factor index for " << key << ": " - << correctedBeliefIndices.at(key) << endl; + << correctedBeliefIndices.at(key) << endl; } - if (unary) - unary->print("Unary: "); + if (unary) unary->print("Unary: "); } }; typedef std::map StarGraphs; - StarGraphs starGraphs_; ///< star graph at each variable + StarGraphs starGraphs_; ///< star graph at each variable -public: + public: /** Constructor - * Need all discrete keys to access node's cardinality for creating belief factors + * Need all discrete keys to access node's cardinality for creating belief + * factors * TODO: so troublesome!! */ LoopyBelief(const DiscreteFactorGraph& graph, - const std::map& allDiscreteKeys) : - starGraphs_(buildStarGraphs(graph, allDiscreteKeys)) { - } + const std::map& allDiscreteKeys) + : starGraphs_(buildStarGraphs(graph, allDiscreteKeys)) {} /// print void print(const std::string& s = "") const { cout << s << ":" << endl; - for(Key key: starGraphs_ | boost::adaptors::map_keys) { + for (Key key : starGraphs_ | boost::adaptors::map_keys) { starGraphs_.at(key).print((boost::format("Node %d:") % key).str()); } } @@ -79,12 +82,13 @@ public: DiscreteFactorGraph::shared_ptr iterate( const std::map& allDiscreteKeys) { static const bool debug = false; - static DiscreteConditional::shared_ptr dummyCond; // unused by-product of elimination + static DiscreteConditional::shared_ptr + dummyCond; // unused by-product of elimination DiscreteFactorGraph::shared_ptr beliefs(new DiscreteFactorGraph()); std::map > allMessages; // Eliminate each star graph - for(Key key: starGraphs_ | boost::adaptors::map_keys) { -// cout << "***** Node " << key << "*****" << endl; + for (Key key : starGraphs_ | boost::adaptors::map_keys) { + // cout << "***** Node " << key << "*****" << endl; // initialize belief to the unary factor from the original graph DecisionTreeFactor::shared_ptr beliefAtKey; @@ -92,15 +96,16 @@ public: std::map messages; // eliminate each neighbor in this star graph one by one - for(Key neighbor: starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) { + for (Key neighbor : starGraphs_.at(key).correctedBeliefIndices | + boost::adaptors::map_keys) { DiscreteFactorGraph subGraph; - for(size_t factor: starGraphs_.at(key).varIndex_[neighbor]) { + for (size_t factor : starGraphs_.at(key).varIndex_[neighbor]) { subGraph.push_back(starGraphs_.at(key).star->at(factor)); } if (debug) subGraph.print("------- Subgraph:"); DiscreteFactor::shared_ptr message; - boost::tie(dummyCond, message) = EliminateDiscrete(subGraph, - Ordering(list_of(neighbor))); + boost::tie(dummyCond, message) = + EliminateDiscrete(subGraph, Ordering(list_of(neighbor))); // store the new factor into messages messages.insert(make_pair(neighbor, message)); if (debug) message->print("------- Message: "); @@ -108,14 +113,12 @@ public: // Belief is the product of all messages and the unary factor // Incorporate new the factor to belief if (!beliefAtKey) - beliefAtKey = boost::dynamic_pointer_cast( - message); - else beliefAtKey = - boost::make_shared( - (*beliefAtKey) - * (*boost::dynamic_pointer_cast( - message))); + boost::dynamic_pointer_cast(message); + else + beliefAtKey = boost::make_shared( + (*beliefAtKey) * + (*boost::dynamic_pointer_cast(message))); } if (starGraphs_.at(key).unary) beliefAtKey = boost::make_shared( @@ -124,7 +127,7 @@ public: // normalize belief double sum = 0.0; for (size_t v = 0; v < allDiscreteKeys.at(key).second; ++v) { - DiscreteFactor::Values val; + DiscreteValues val; val[key] = v; sum += (*beliefAtKey)(val); } @@ -133,7 +136,8 @@ public: sumFactorTable = (boost::format("%s %f") % sumFactorTable % sum).str(); DecisionTreeFactor sumFactor(allDiscreteKeys.at(key), sumFactorTable); if (debug) sumFactor.print("denomFactor: "); - beliefAtKey = boost::make_shared((*beliefAtKey) / sumFactor); + beliefAtKey = + boost::make_shared((*beliefAtKey) / sumFactor); if (debug) beliefAtKey->print("New belief at key normalized: "); beliefs->push_back(beliefAtKey); allMessages[key] = messages; @@ -141,17 +145,20 @@ public: // Update corrected beliefs VariableIndex beliefFactors(*beliefs); - for(Key key: starGraphs_ | boost::adaptors::map_keys) { + for (Key key : starGraphs_ | boost::adaptors::map_keys) { std::map messages = allMessages[key]; - for(Key neighbor: starGraphs_.at(key).correctedBeliefIndices | boost::adaptors::map_keys) { - DecisionTreeFactor correctedBelief = (*boost::dynamic_pointer_cast< - DecisionTreeFactor>(beliefs->at(beliefFactors[key].front()))) - / (*boost::dynamic_pointer_cast( + for (Key neighbor : starGraphs_.at(key).correctedBeliefIndices | + boost::adaptors::map_keys) { + DecisionTreeFactor correctedBelief = + (*boost::dynamic_pointer_cast( + beliefs->at(beliefFactors[key].front()))) / + (*boost::dynamic_pointer_cast( messages.at(neighbor))); if (debug) correctedBelief.print("correctedBelief: "); - size_t beliefIndex = starGraphs_.at(neighbor).correctedBeliefIndices.at( - key); - starGraphs_.at(neighbor).star->replace(beliefIndex, + size_t beliefIndex = + starGraphs_.at(neighbor).correctedBeliefIndices.at(key); + starGraphs_.at(neighbor).star->replace( + beliefIndex, boost::make_shared(correctedBelief)); } } @@ -161,21 +168,22 @@ public: return beliefs; } -private: + private: /** * Build star graphs for each node. */ - StarGraphs buildStarGraphs(const DiscreteFactorGraph& graph, + StarGraphs buildStarGraphs( + const DiscreteFactorGraph& graph, const std::map& allDiscreteKeys) const { StarGraphs starGraphs; - VariableIndex varIndex(graph); ///< access to all factors of each node - for(Key key: varIndex | boost::adaptors::map_keys) { + VariableIndex varIndex(graph); ///< access to all factors of each node + for (Key key : varIndex | boost::adaptors::map_keys) { // initialize to multiply with other unary factors later DecisionTreeFactor::shared_ptr prodOfUnaries; // collect all factors involving this key in the original graph DiscreteFactorGraph::shared_ptr star(new DiscreteFactorGraph()); - for(size_t factorIndex: varIndex[key]) { + for (size_t factorIndex : varIndex[key]) { star->push_back(graph.at(factorIndex)); // accumulate unary factors @@ -185,9 +193,9 @@ private: graph.at(factorIndex)); else prodOfUnaries = boost::make_shared( - *prodOfUnaries - * (*boost::dynamic_pointer_cast( - graph.at(factorIndex)))); + *prodOfUnaries * + (*boost::dynamic_pointer_cast( + graph.at(factorIndex)))); } } @@ -196,7 +204,7 @@ private: KeySet neighbors = star->keys(); neighbors.erase(key); CorrectedBeliefIndices correctedBeliefIndices; - for(Key neighbor: neighbors) { + for (Key neighbor : neighbors) { // TODO: default table for keys with more than 2 values? string initialBelief; for (size_t v = 0; v < allDiscreteKeys.at(neighbor).second - 1; ++v) { @@ -207,9 +215,8 @@ private: DecisionTreeFactor(allDiscreteKeys.at(neighbor), initialBelief)); correctedBeliefIndices.insert(make_pair(neighbor, star->size() - 1)); } - starGraphs.insert( - make_pair(key, - StarGraph(star, correctedBeliefIndices, prodOfUnaries))); + starGraphs.insert(make_pair( + key, StarGraph(star, correctedBeliefIndices, prodOfUnaries))); } return starGraphs; } @@ -249,7 +256,6 @@ TEST_UNSAFE(LoopyBelief, construction) { DiscreteFactorGraph::shared_ptr beliefs = solver.iterate(allKeys); beliefs->print(); } - } /* ************************************************************************* */ diff --git a/gtsam_unstable/discrete/tests/testScheduler.cpp b/gtsam_unstable/discrete/tests/testScheduler.cpp index 3f6c6a1e0..086057a46 100644 --- a/gtsam_unstable/discrete/tests/testScheduler.cpp +++ b/gtsam_unstable/discrete/tests/testScheduler.cpp @@ -5,14 +5,13 @@ */ //#define ENABLE_TIMING -#include +#include #include #include +#include -#include - -#include #include +#include #include using namespace boost::assign; @@ -22,7 +21,6 @@ using namespace gtsam; /* ************************************************************************* */ // Create the expected graph of constraints DiscreteFactorGraph createExpected() { - // Start building size_t nrFaculty = 4, nrTimeSlots = 3; @@ -47,27 +45,27 @@ DiscreteFactorGraph createExpected() { string available = "1 1 1 0 1 1 1 1 0 1 1 1"; // Akansel - expected.add(A1, faculty_in_A); // Area 1 - expected.add(A1, "1 1 1 0"); // Advisor + expected.add(A1, faculty_in_A); // Area 1 + expected.add(A1, "1 1 1 0"); // Advisor expected.add(A & A1, available); - expected.add(A2, faculty_in_M); // Area 2 - expected.add(A2, "1 1 1 0"); // Advisor + expected.add(A2, faculty_in_M); // Area 2 + expected.add(A2, "1 1 1 0"); // Advisor expected.add(A & A2, available); - expected.add(A3, faculty_in_P); // Area 3 - expected.add(A3, "1 1 1 0"); // Advisor + expected.add(A3, faculty_in_P); // Area 3 + expected.add(A3, "1 1 1 0"); // Advisor expected.add(A & A3, available); // Mutual exclusion for faculty expected.addAllDiff(A1 & A2 & A3); // Jake - expected.add(J1, faculty_in_H); // Area 1 - expected.add(J1, "1 0 1 1"); // Advisor + expected.add(J1, faculty_in_H); // Area 1 + expected.add(J1, "1 0 1 1"); // Advisor expected.add(J & J1, available); - expected.add(J2, faculty_in_C); // Area 2 - expected.add(J2, "1 0 1 1"); // Advisor + expected.add(J2, faculty_in_C); // Area 2 + expected.add(J2, "1 0 1 1"); // Advisor expected.add(J & J2, available); - expected.add(J3, faculty_in_A); // Area 3 - expected.add(J3, "1 0 1 1"); // Advisor + expected.add(J3, faculty_in_A); // Area 3 + expected.add(J3, "1 0 1 1"); // Advisor expected.add(J & J3, available); // Mutual exclusion for faculty expected.addAllDiff(J1 & J2 & J3); @@ -79,8 +77,7 @@ DiscreteFactorGraph createExpected() { } /* ************************************************************************* */ -TEST( schedulingExample, test) -{ +TEST(schedulingExample, test) { Scheduler s(2); // add faculty @@ -121,33 +118,32 @@ TEST( schedulingExample, test) // Do brute force product and output that to file DecisionTreeFactor product = s.product(); - //product.dot("scheduling", false); + // product.dot("scheduling", false); // Do exact inference gttic(small); - DiscreteFactor::sharedValues MPE = s.optimalAssignment(); + auto MPE = s.optimize(); gttoc(small); // print MPE, commented out as unit tests don't print -// s.printAssignment(MPE); + // s.printAssignment(MPE); // Commented out as does not work yet // s.runArcConsistency(8,10,true); // find the assignment of students to slots with most possible committees // Commented out as not implemented yet -// sharedValues bestSchedule = s.bestSchedule(); -// GTSAM_PRINT(*bestSchedule); + // auto bestSchedule = s.bestSchedule(); + // GTSAM_PRINT(bestSchedule); // find the corresponding most desirable committee assignment // Commented out as not implemented yet -// sharedValues bestAssignment = s.bestAssignment(bestSchedule); -// GTSAM_PRINT(*bestAssignment); + // auto bestAssignment = s.bestAssignment(bestSchedule); + // GTSAM_PRINT(bestAssignment); } /* ************************************************************************* */ -TEST( schedulingExample, smallFromFile) -{ +TEST(schedulingExample, smallFromFile) { string path(TOPSRCDIR "/gtsam_unstable/discrete/examples/"); Scheduler s(2, path + "small.csv"); @@ -179,4 +175,3 @@ int main() { return TestRegistry::runAllTests(tr); } /* ************************************************************************* */ - diff --git a/gtsam_unstable/discrete/tests/testSudoku.cpp b/gtsam_unstable/discrete/tests/testSudoku.cpp index e2115e8bc..8b2858169 100644 --- a/gtsam_unstable/discrete/tests/testSudoku.cpp +++ b/gtsam_unstable/discrete/tests/testSudoku.cpp @@ -5,74 +5,69 @@ * @author Frank Dellaert */ -#include #include +#include +#include + #include using boost::assign::insert; +#include + #include #include -#include using namespace std; using namespace gtsam; #define PRINT false -class Sudoku: public CSP { +/// A class that encodes Sudoku's as a CSP problem +class Sudoku : public CSP { + size_t n_; ///< Side of Sudoku, e.g. 4 or 9 - /// sudoku size - size_t n_; - - /// discrete keys - typedef std::pair IJ; + /// Mapping from base i,j coordinates to discrete keys: + using IJ = std::pair; std::map dkeys_; -public: - + public: /// return DiscreteKey for cell(i,j) const DiscreteKey& dkey(size_t i, size_t j) const { return dkeys_.at(IJ(i, j)); } /// return Key for cell(i,j) - Key key(size_t i, size_t j) const { - return dkey(i, j).first; - } + Key key(size_t i, size_t j) const { return dkey(i, j).first; } /// Constructor - Sudoku(size_t n, ...) : - n_(n) { + Sudoku(size_t n, ...) : n_(n) { // Create variables, ordering, and unary constraints va_list ap; va_start(ap, n); - Key k=0; for (size_t i = 0; i < n; ++i) { - for (size_t j = 0; j < n; ++j, ++k) { + for (size_t j = 0; j < n; ++j) { // create the key IJ ij(i, j); - dkeys_[ij] = DiscreteKey(k, n); + Symbol key('1' + i, j + 1); + dkeys_[ij] = DiscreteKey(key, n); // get the unary constraint, if any int value = va_arg(ap, int); - // cout << value << " "; if (value != 0) addSingleValue(dkeys_[ij], value - 1); } - //cout << endl; + // cout << endl; } va_end(ap); // add row constraints for (size_t i = 0; i < n; i++) { DiscreteKeys dkeys; - for (size_t j = 0; j < n; j++) - dkeys += dkey(i, j); + for (size_t j = 0; j < n; j++) dkeys += dkey(i, j); addAllDiff(dkeys); } // add col constraints for (size_t j = 0; j < n; j++) { DiscreteKeys dkeys; - for (size_t i = 0; i < n; i++) - dkeys += dkey(i, j); + for (size_t i = 0; i < n; i++) dkeys += dkey(i, j); addAllDiff(dkeys); } @@ -84,8 +79,7 @@ public: // Box I,J DiscreteKeys dkeys; for (size_t i = i0; i < i0 + N; i++) - for (size_t j = j0; j < j0 + N; j++) - dkeys += dkey(i, j); + for (size_t j = j0; j < j0 + N; j++) dkeys += dkey(i, j); addAllDiff(dkeys); j0 += N; } @@ -94,120 +88,171 @@ public: } /// Print readable form of assignment - void printAssignment(DiscreteFactor::sharedValues assignment) const { + void printAssignment(const DiscreteValues& assignment) const { for (size_t i = 0; i < n_; i++) { for (size_t j = 0; j < n_; j++) { Key k = key(i, j); - cout << 1 + assignment->at(k) << " "; + cout << 1 + assignment.at(k) << " "; } cout << endl; } } /// solve and print solution - void printSolution() { - DiscreteFactor::sharedValues MPE = optimalAssignment(); + void printSolution() const { + auto MPE = optimize(); printAssignment(MPE); } + // Print domain + void printDomains(const Domains& domains) { + for (size_t i = 0; i < n_; i++) { + for (size_t j = 0; j < n_; j++) { + Key k = key(i, j); + cout << domains.at(k).base1Str(); + cout << "\t"; + } // i + cout << endl; + } // j + } }; /* ************************************************************************* */ -TEST_UNSAFE( Sudoku, small) -{ - Sudoku csp(4, - 1,0, 0,4, - 0,0, 0,0, - - 4,0, 2,0, - 0,1, 0,0); - - // Do BP - csp.runArcConsistency(4,10,PRINT); +TEST(Sudoku, small) { + Sudoku csp(4, // + 1, 0, 0, 4, // + 0, 0, 0, 0, // + 4, 0, 2, 0, // + 0, 1, 0, 0); // optimize and check - CSP::sharedValues solution = csp.optimalAssignment(); - CSP::Values expected; - insert(expected) - (csp.key(0,0), 0)(csp.key(0,1), 1)(csp.key(0,2), 2)(csp.key(0,3), 3) - (csp.key(1,0), 2)(csp.key(1,1), 3)(csp.key(1,2), 0)(csp.key(1,3), 1) - (csp.key(2,0), 3)(csp.key(2,1), 2)(csp.key(2,2), 1)(csp.key(2,3), 0) - (csp.key(3,0), 1)(csp.key(3,1), 0)(csp.key(3,2), 3)(csp.key(3,3), 2); - EXPECT(assert_equal(expected,*solution)); - //csp.printAssignment(solution); + auto solution = csp.optimize(); + DiscreteValues expected; + insert(expected)(csp.key(0, 0), 0)(csp.key(0, 1), 1)(csp.key(0, 2), 2)( + csp.key(0, 3), 3)(csp.key(1, 0), 2)(csp.key(1, 1), 3)(csp.key(1, 2), 0)( + csp.key(1, 3), 1)(csp.key(2, 0), 3)(csp.key(2, 1), 2)(csp.key(2, 2), 1)( + csp.key(2, 3), 0)(csp.key(3, 0), 1)(csp.key(3, 1), 0)(csp.key(3, 2), 3)( + csp.key(3, 3), 2); + EXPECT(assert_equal(expected, solution)); + // csp.printAssignment(solution); + + // Do BP (AC1) + auto domains = csp.runArcConsistency(4, 3); + // csp.printDomains(domains); + Domain domain44 = domains.at(Symbol('4', 4)); + EXPECT_LONGS_EQUAL(1, domain44.nrValues()); + + // Test Creation of a new, simpler CSP + CSP new_csp = csp.partiallyApply(domains); + // Should only be 16 new Domains + EXPECT_LONGS_EQUAL(16, new_csp.size()); + + // Check that solution + auto new_solution = new_csp.optimize(); + // csp.printAssignment(new_solution); + EXPECT(assert_equal(expected, new_solution)); } /* ************************************************************************* */ -TEST_UNSAFE( Sudoku, easy) -{ - Sudoku sudoku(9, - 0,0,5, 0,9,0, 0,0,1, - 0,0,0, 0,0,2, 0,7,3, - 7,6,0, 0,0,8, 2,0,0, +TEST(Sudoku, easy) { + Sudoku csp(9, // + 0, 0, 5, 0, 9, 0, 0, 0, 1, // + 0, 0, 0, 0, 0, 2, 0, 7, 3, // + 7, 6, 0, 0, 0, 8, 2, 0, 0, // - 0,1,2, 0,0,9, 0,0,4, - 0,0,0, 2,0,3, 0,0,0, - 3,0,0, 1,0,0, 9,6,0, + 0, 1, 2, 0, 0, 9, 0, 0, 4, // + 0, 0, 0, 2, 0, 3, 0, 0, 0, // + 3, 0, 0, 1, 0, 0, 9, 6, 0, // - 0,0,1, 9,0,0, 0,5,8, - 9,7,0, 5,0,0, 0,0,0, - 5,0,0, 0,3,0, 7,0,0); + 0, 0, 1, 9, 0, 0, 0, 5, 8, // + 9, 7, 0, 5, 0, 0, 0, 0, 0, // + 5, 0, 0, 0, 3, 0, 7, 0, 0); - // Do BP - sudoku.runArcConsistency(4,10,PRINT); + // csp.printSolution(); // don't do it - // sudoku.printSolution(); // don't do it + // Do BP (AC1) + auto domains = csp.runArcConsistency(9, 10); + // csp.printDomains(domains); + Key key99 = Symbol('9', 9); + Domain domain99 = domains.at(key99); + EXPECT_LONGS_EQUAL(1, domain99.nrValues()); + + // Test Creation of a new, simpler CSP + CSP new_csp = csp.partiallyApply(domains); + // 81 new Domains, and still 26 all-diff constraints + EXPECT_LONGS_EQUAL(81 + 26, new_csp.size()); + + // csp.printSolution(); // still don't do it ! :-( } /* ************************************************************************* */ -TEST_UNSAFE( Sudoku, extreme) -{ - Sudoku sudoku(9, - 0,0,9, 7,4,8, 0,0,0, - 7,0,0, 0,0,0, 0,0,0, - 0,2,0, 1,0,9, 0,0,0, - - 0,0,7, 0,0,0, 2,4,0, - 0,6,4, 0,1,0, 5,9,0, - 0,9,8, 0,0,0, 3,0,0, - - 0,0,0, 8,0,3, 0,2,0, - 0,0,0, 0,0,0, 0,0,6, - 0,0,0, 2,7,5, 9,0,0); +TEST(Sudoku, extreme) { + Sudoku csp(9, // + 0, 0, 9, 7, 4, 8, 0, 0, 0, 7, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, // + 0, 1, 0, 9, 0, 0, 0, 0, 0, 7, // + 0, 0, 0, 2, 4, 0, 0, 6, 4, 0, // + 1, 0, 5, 9, 0, 0, 9, 8, 0, 0, // + 0, 3, 0, 0, 0, 0, 0, 8, 0, 3, // + 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 6, 0, 0, 0, 2, 7, 5, 9, 0, 0); // Do BP - sudoku.runArcConsistency(9,10,PRINT); + csp.runArcConsistency(9, 10); #ifdef METIS - VariableIndexOrdered index(sudoku); + VariableIndexOrdered index(csp); index.print("index"); ofstream os("/Users/dellaert/src/hmetis-1.5-osx-i686/extreme-dual.txt"); index.outputMetisFormat(os); #endif - //sudoku.printSolution(); // don't do it + // Do BP (AC1) + auto domains = csp.runArcConsistency(9, 10); + // csp.printDomains(domains); + Key key99 = Symbol('9', 9); + Domain domain99 = domains.at(key99); + EXPECT_LONGS_EQUAL(2, domain99.nrValues()); + + // Test Creation of a new, simpler CSP + CSP new_csp = csp.partiallyApply(domains); + // 81 new Domains, and still 20 all-diff constraints + EXPECT_LONGS_EQUAL(81 + 20, new_csp.size()); + + // csp.printSolution(); // still don't do it ! :-( } /* ************************************************************************* */ -TEST_UNSAFE( Sudoku, AJC_3star_Feb8_2012) -{ - Sudoku sudoku(9, - 9,5,0, 0,0,6, 0,0,0, - 0,8,4, 0,7,0, 0,0,0, - 6,2,0, 5,0,0, 4,0,0, +TEST(Sudoku, AJC_3star_Feb8_2012) { + Sudoku csp(9, // + 9, 5, 0, 0, 0, 6, 0, 0, 0, // + 0, 8, 4, 0, 7, 0, 0, 0, 0, // + 6, 2, 0, 5, 0, 0, 4, 0, 0, // - 0,0,0, 2,9,0, 6,0,0, - 0,9,0, 0,0,0, 0,2,0, - 0,0,2, 0,6,3, 0,0,0, + 0, 0, 0, 2, 9, 0, 6, 0, 0, // + 0, 9, 0, 0, 0, 0, 0, 2, 0, // + 0, 0, 2, 0, 6, 3, 0, 0, 0, // - 0,0,9, 0,0,7, 0,6,8, - 0,0,0, 0,3,0, 2,9,0, - 0,0,0, 1,0,0, 0,3,7); + 0, 0, 9, 0, 0, 7, 0, 6, 8, // + 0, 0, 0, 0, 3, 0, 2, 9, 0, // + 0, 0, 0, 1, 0, 0, 0, 3, 7); - // Do BP - sudoku.runArcConsistency(9,10,PRINT); + // Do BP (AC1) + auto domains = csp.runArcConsistency(9, 10); + // csp.printDomains(domains); + Key key99 = Symbol('9', 9); + Domain domain99 = domains.at(key99); + EXPECT_LONGS_EQUAL(1, domain99.nrValues()); - //sudoku.printSolution(); // don't do it + // Test Creation of a new, simpler CSP + CSP new_csp = csp.partiallyApply(domains); + // Just the 81 new Domains + EXPECT_LONGS_EQUAL(81, new_csp.size()); + + // Check that solution + auto solution = new_csp.optimize(); + // csp.printAssignment(solution); + EXPECT_LONGS_EQUAL(6, solution.at(key99)); } /* ************************************************************************* */ @@ -216,4 +261,3 @@ int main() { return TestRegistry::runAllTests(tr); } /* ************************************************************************* */ - diff --git a/gtsam_unstable/slam/BetweenFactorEM.h b/gtsam_unstable/slam/BetweenFactorEM.h index 572935da3..9c19bae8c 100644 --- a/gtsam_unstable/slam/BetweenFactorEM.h +++ b/gtsam_unstable/slam/BetweenFactorEM.h @@ -421,4 +421,8 @@ private: }; // \class BetweenFactorEM +/// traits +template +struct traits > : public Testable > {}; + } // namespace gtsam diff --git a/gtsam_unstable/slam/EquivInertialNavFactor_GlobalVel_NoBias.h b/gtsam_unstable/slam/EquivInertialNavFactor_GlobalVel_NoBias.h index 0e2aebd7f..b053b13f8 100644 --- a/gtsam_unstable/slam/EquivInertialNavFactor_GlobalVel_NoBias.h +++ b/gtsam_unstable/slam/EquivInertialNavFactor_GlobalVel_NoBias.h @@ -372,15 +372,15 @@ public: Matrix Z_3x3 = Z_3x3; Matrix I_3x3 = I_3x3; - Matrix H_pos_pos = numericalDerivative11(std::bind(&PreIntegrateIMUObservations_delta_pos, msr_dt, _1, delta_vel_in_t0), delta_pos_in_t0); - Matrix H_pos_vel = numericalDerivative11(std::bind(&PreIntegrateIMUObservations_delta_pos, msr_dt, delta_pos_in_t0, _1), delta_vel_in_t0); + Matrix H_pos_pos = numericalDerivative11(std::bind(&PreIntegrateIMUObservations_delta_pos, msr_dt, _1, delta_vel_in_t0), delta_pos_in_t0); + Matrix H_pos_vel = numericalDerivative11(std::bind(&PreIntegrateIMUObservations_delta_pos, msr_dt, delta_pos_in_t0, _1), delta_vel_in_t0); Matrix H_pos_angles = Z_3x3; - Matrix H_vel_vel = numericalDerivative11(std::bind(&PreIntegrateIMUObservations_delta_vel, msr_gyro_t, msr_acc_t, msr_dt, delta_angles, _1, flag_use_body_P_sensor, body_P_sensor), delta_vel_in_t0); - Matrix H_vel_angles = numericalDerivative11(std::bind(&PreIntegrateIMUObservations_delta_vel, msr_gyro_t, msr_acc_t, msr_dt, _1, delta_vel_in_t0, flag_use_body_P_sensor, body_P_sensor), delta_angles); + Matrix H_vel_vel = numericalDerivative11(std::bind(&PreIntegrateIMUObservations_delta_vel, msr_gyro_t, msr_acc_t, msr_dt, delta_angles, _1, flag_use_body_P_sensor, body_P_sensor), delta_vel_in_t0); + Matrix H_vel_angles = numericalDerivative11(std::bind(&PreIntegrateIMUObservations_delta_vel, msr_gyro_t, msr_acc_t, msr_dt, _1, delta_vel_in_t0, flag_use_body_P_sensor, body_P_sensor), delta_angles); Matrix H_vel_pos = Z_3x3; - Matrix H_angles_angles = numericalDerivative11(std::bind(&PreIntegrateIMUObservations_delta_angles, msr_gyro_t, msr_dt, _1, flag_use_body_P_sensor, body_P_sensor), delta_angles); + Matrix H_angles_angles = numericalDerivative11(std::bind(&PreIntegrateIMUObservations_delta_angles, msr_gyro_t, msr_dt, _1, flag_use_body_P_sensor, body_P_sensor), delta_angles); Matrix H_angles_pos = Z_3x3; Matrix H_angles_vel = Z_3x3; diff --git a/gtsam_unstable/slam/LocalOrientedPlane3Factor.h b/gtsam_unstable/slam/LocalOrientedPlane3Factor.h index 5264c8f4b..f81c18bfa 100644 --- a/gtsam_unstable/slam/LocalOrientedPlane3Factor.h +++ b/gtsam_unstable/slam/LocalOrientedPlane3Factor.h @@ -9,6 +9,8 @@ #include #include +#include + #include namespace gtsam { @@ -32,16 +34,16 @@ namespace gtsam { * a local linearisation point for the plane. The plane is representated and * optimized in x1 frame in the optimization. */ -class LocalOrientedPlane3Factor: public NoiseModelFactor3 { -protected: +class GTSAM_UNSTABLE_EXPORT LocalOrientedPlane3Factor + : public NoiseModelFactor3 { + protected: OrientedPlane3 measured_p_; typedef NoiseModelFactor3 Base; public: /// Constructor LocalOrientedPlane3Factor() {} - virtual ~LocalOrientedPlane3Factor() {} + ~LocalOrientedPlane3Factor() override {} /** Constructor with measured plane (a,b,c,d) coefficients * @param z measured plane (a,b,c,d) coefficients as 4D vector @@ -54,12 +56,12 @@ public: * Note: The anchorPoseKey can simply be chosen as the first pose a plane * is observed. */ - LocalOrientedPlane3Factor(const Vector4& z, const SharedGaussian& noiseModel, + LocalOrientedPlane3Factor(const Vector4& z, const SharedNoiseModel& noiseModel, Key poseKey, Key anchorPoseKey, Key landmarkKey) : Base(noiseModel, poseKey, anchorPoseKey, landmarkKey), measured_p_(z) {} LocalOrientedPlane3Factor(const OrientedPlane3& z, - const SharedGaussian& noiseModel, + const SharedNoiseModel& noiseModel, Key poseKey, Key anchorPoseKey, Key landmarkKey) : Base(noiseModel, poseKey, anchorPoseKey, landmarkKey), measured_p_(z) {} diff --git a/gtsam_unstable/slam/PoseToPointFactor.h b/gtsam_unstable/slam/PoseToPointFactor.h index ec7da22ef..cab48e506 100644 --- a/gtsam_unstable/slam/PoseToPointFactor.h +++ b/gtsam_unstable/slam/PoseToPointFactor.h @@ -1,11 +1,14 @@ /** * @file PoseToPointFactor.hpp - * @brief This factor can be used to track a 3D landmark over time by - *providing local measurements of its location. + * @brief This factor can be used to model relative position measurements + * from a (2D or 3D) pose to a landmark * @author David Wisth + * @author Luca Carlone **/ #pragma once +#include +#include #include #include #include @@ -17,12 +20,13 @@ namespace gtsam { * A class for a measurement between a pose and a point. * @addtogroup SLAM */ -class PoseToPointFactor : public NoiseModelFactor2 { +template +class PoseToPointFactor : public NoiseModelFactor2 { private: typedef PoseToPointFactor This; - typedef NoiseModelFactor2 Base; + typedef NoiseModelFactor2 Base; - Point3 measured_; /** the point measurement in local coordinates */ + POINT measured_; /** the point measurement in local coordinates */ public: // shorthand for a smart pointer to a factor @@ -32,7 +36,7 @@ class PoseToPointFactor : public NoiseModelFactor2 { PoseToPointFactor() {} /** Constructor */ - PoseToPointFactor(Key key1, Key key2, const Point3& measured, + PoseToPointFactor(Key key1, Key key2, const POINT& measured, const SharedNoiseModel& model) : Base(model, key1, key2), measured_(measured) {} @@ -41,8 +45,8 @@ class PoseToPointFactor : public NoiseModelFactor2 { /** implement functions needed for Testable */ /** print */ - virtual void print(const std::string& s, const KeyFormatter& keyFormatter = - DefaultKeyFormatter) const { + void print(const std::string& s, const KeyFormatter& keyFormatter = + DefaultKeyFormatter) const override { std::cout << s << "PoseToPointFactor(" << keyFormatter(this->key1()) << "," << keyFormatter(this->key2()) << ")\n" << " measured: " << measured_.transpose() << std::endl; @@ -50,30 +54,31 @@ class PoseToPointFactor : public NoiseModelFactor2 { } /** equals */ - virtual bool equals(const NonlinearFactor& expected, - double tol = 1e-9) const { + bool equals(const NonlinearFactor& expected, + double tol = 1e-9) const override { const This* e = dynamic_cast(&expected); return e != nullptr && Base::equals(*e, tol) && - traits::Equals(this->measured_, e->measured_, tol); + traits::Equals(this->measured_, e->measured_, tol); } /** implement functions needed to derive from Factor */ /** vector of errors - * @brief Error = wTwi.inverse()*wPwp - measured_ - * @param wTwi The pose of the sensor in world coordinates - * @param wPwp The estimated point location in world coordinates + * @brief Error = w_T_b.inverse()*w_P - measured_ + * @param w_T_b The pose of the body in world coordinates + * @param w_P The estimated point location in world coordinates * * Note: measured_ and the error are in local coordiantes. */ - Vector evaluateError(const Pose3& wTwi, const Point3& wPwp, - boost::optional H1 = boost::none, - boost::optional H2 = boost::none) const { - return wTwi.transformTo(wPwp, H1, H2) - measured_; + Vector evaluateError( + const POSE& w_T_b, const POINT& w_P, + boost::optional H1 = boost::none, + boost::optional H2 = boost::none) const override { + return w_T_b.transformTo(w_P, H1, H2) - measured_; } /** return the measured */ - const Point3& measured() const { return measured_; } + const POINT& measured() const { return measured_; } private: /** Serialization function */ diff --git a/gtsam_unstable/slam/ProjectionFactorPPPC.h b/gtsam_unstable/slam/ProjectionFactorPPPC.h index fbc11503c..18ee13b9a 100644 --- a/gtsam_unstable/slam/ProjectionFactorPPPC.h +++ b/gtsam_unstable/slam/ProjectionFactorPPPC.h @@ -18,9 +18,11 @@ #pragma once -#include -#include #include +#include +#include +#include + #include namespace gtsam { @@ -30,60 +32,50 @@ namespace gtsam { * estimates the body pose, body-camera transform, 3D landmark, and calibration. * @addtogroup SLAM */ - template - class ProjectionFactorPPPC: public NoiseModelFactor4 { - protected: +template +class GTSAM_UNSTABLE_EXPORT ProjectionFactorPPPC + : public NoiseModelFactor4 { + protected: + Point2 measured_; ///< 2D measurement - Point2 measured_; ///< 2D measurement + // verbosity handling for Cheirality Exceptions + bool throwCheirality_; ///< If true, rethrows Cheirality exceptions (default: false) + bool verboseCheirality_; ///< If true, prints text for Cheirality exceptions (default: false) - // verbosity handling for Cheirality Exceptions - bool throwCheirality_; ///< If true, rethrows Cheirality exceptions (default: false) - bool verboseCheirality_; ///< If true, prints text for Cheirality exceptions (default: false) + public: + /// shorthand for base class type + typedef NoiseModelFactor4 Base; - public: + /// shorthand for this class + typedef ProjectionFactorPPPC This; - /// shorthand for base class type - typedef NoiseModelFactor4 Base; + /// shorthand for a smart pointer to a factor + typedef boost::shared_ptr shared_ptr; - /// shorthand for this class - typedef ProjectionFactorPPPC This; - - /// shorthand for a smart pointer to a factor - typedef boost::shared_ptr shared_ptr; - - /// Default constructor + /// Default constructor ProjectionFactorPPPC() : measured_(0.0, 0.0), throwCheirality_(false), verboseCheirality_(false) { } - /** - * Constructor - * TODO: Mark argument order standard (keys, measurement, parameters) - * @param measured is the 2 dimensional location of point in image (the measurement) - * @param model is the standard deviation - * @param poseKey is the index of the camera - * @param pointKey is the index of the landmark - * @param K shared pointer to the constant calibration - */ - ProjectionFactorPPPC(const Point2& measured, const SharedNoiseModel& model, - Key poseKey, Key transformKey, Key pointKey, Key calibKey) : - Base(model, poseKey, transformKey, pointKey, calibKey), measured_(measured), - throwCheirality_(false), verboseCheirality_(false) {} /** * Constructor with exception-handling flags * TODO: Mark argument order standard (keys, measurement, parameters) - * @param measured is the 2 dimensional location of point in image (the measurement) + * @param measured is the 2 dimensional location of point in image (the + * measurement) * @param model is the standard deviation * @param poseKey is the index of the camera + * @param transformKey is the index of the extrinsic calibration * @param pointKey is the index of the landmark - * @param K shared pointer to the constant calibration - * @param throwCheirality determines whether Cheirality exceptions are rethrown - * @param verboseCheirality determines whether exceptions are printed for Cheirality + * @param calibKey is the index of the intrinsic calibration + * @param throwCheirality determines whether Cheirality exceptions are + * rethrown + * @param verboseCheirality determines whether exceptions are printed for + * Cheirality */ ProjectionFactorPPPC(const Point2& measured, const SharedNoiseModel& model, Key poseKey, Key transformKey, Key pointKey, Key calibKey, - bool throwCheirality, bool verboseCheirality) : + bool throwCheirality = false, bool verboseCheirality = false) : Base(model, poseKey, transformKey, pointKey, calibKey), measured_(measured), throwCheirality_(throwCheirality), verboseCheirality_(verboseCheirality) {} @@ -123,8 +115,8 @@ namespace gtsam { try { if(H1 || H2 || H3 || H4) { Matrix H0, H02; - PinholeCamera camera(pose.compose(transform, H0, H02), K); - Point2 reprojectionError(camera.project(point, H1, H3, H4) - measured_); + const PinholeCamera camera(pose.compose(transform, H0, H02), K); + const Point2 reprojectionError(camera.project(point, H1, H3, H4) - measured_); *H2 = *H1 * H02; *H1 = *H1 * H0; return reprojectionError; @@ -168,7 +160,7 @@ namespace gtsam { ar & BOOST_SERIALIZATION_NVP(throwCheirality_); ar & BOOST_SERIALIZATION_NVP(verboseCheirality_); } - }; +}; /// traits template diff --git a/gtsam_unstable/slam/ProjectionFactorRollingShutter.h b/gtsam_unstable/slam/ProjectionFactorRollingShutter.h index c92653c13..2aeaa4824 100644 --- a/gtsam_unstable/slam/ProjectionFactorRollingShutter.h +++ b/gtsam_unstable/slam/ProjectionFactorRollingShutter.h @@ -21,6 +21,7 @@ #include #include #include +#include #include @@ -40,7 +41,7 @@ namespace gtsam { * @addtogroup SLAM */ -class ProjectionFactorRollingShutter +class GTSAM_UNSTABLE_EXPORT ProjectionFactorRollingShutter : public NoiseModelFactor3 { protected: // Keep a copy of measurement and calibration for I/O diff --git a/gtsam_unstable/slam/ReadMe.md b/gtsam_unstable/slam/README.md similarity index 100% rename from gtsam_unstable/slam/ReadMe.md rename to gtsam_unstable/slam/README.md diff --git a/gtsam_unstable/slam/SmartProjectionPoseFactorRollingShutter.h b/gtsam_unstable/slam/SmartProjectionPoseFactorRollingShutter.h index 23203be67..ff84fcd16 100644 --- a/gtsam_unstable/slam/SmartProjectionPoseFactorRollingShutter.h +++ b/gtsam_unstable/slam/SmartProjectionPoseFactorRollingShutter.h @@ -20,6 +20,7 @@ #include #include +#include namespace gtsam { /** @@ -41,12 +42,14 @@ namespace gtsam { * @addtogroup SLAM */ template -class SmartProjectionPoseFactorRollingShutter +class GTSAM_UNSTABLE_EXPORT SmartProjectionPoseFactorRollingShutter : public SmartProjectionFactor { private: typedef SmartProjectionFactor Base; typedef SmartProjectionPoseFactorRollingShutter This; typedef typename CAMERA::CalibrationType CALIBRATION; + typedef typename CAMERA::Measurement MEASUREMENT; + typedef typename CAMERA::MeasurementVector MEASUREMENTS; protected: /// The keys of the pose of the body (with respect to an external world @@ -68,12 +71,6 @@ class SmartProjectionPoseFactorRollingShutter public: EIGEN_MAKE_ALIGNED_OPERATOR_NEW - typedef CAMERA Camera; - typedef CameraSet Cameras; - - /// shorthand for a smart pointer to a factor - typedef boost::shared_ptr shared_ptr; - static const int DimBlock = 12; ///< size of the variable stacking 2 poses from which the observation ///< pose is interpolated @@ -84,6 +81,12 @@ class SmartProjectionPoseFactorRollingShutter typedef std::vector> FBlocks; // vector of F blocks + typedef CAMERA Camera; + typedef CameraSet Cameras; + + /// shorthand for a smart pointer to a factor + typedef boost::shared_ptr shared_ptr; + /// Default constructor, only for serialization SmartProjectionPoseFactorRollingShutter() {} @@ -125,7 +128,7 @@ class SmartProjectionPoseFactorRollingShutter * interpolated pose is the same as world_P_body_key1 * @param cameraId ID of the camera taking the measurement (default 0) */ - void add(const Point2& measured, const Key& world_P_body_key1, + void add(const MEASUREMENT& measured, const Key& world_P_body_key1, const Key& world_P_body_key2, const double& alpha, const size_t& cameraId = 0) { // store measurements in base class @@ -164,7 +167,7 @@ class SmartProjectionPoseFactorRollingShutter * @param cameraIds IDs of the cameras taking each measurement (same order as * the measurements) */ - void add(const Point2Vector& measurements, + void add(const MEASUREMENTS& measurements, const std::vector>& world_P_body_key_pairs, const std::vector& alphas, const FastVector& cameraIds = FastVector()) { @@ -330,12 +333,13 @@ class SmartProjectionPoseFactorRollingShutter const typename Base::Camera& camera_i = (*cameraRig_)[cameraIds_[i]]; auto body_P_cam = camera_i.pose(); auto w_P_cam = w_P_body.compose(body_P_cam, dPoseCam_dInterpPose); - PinholeCamera camera(w_P_cam, camera_i.calibration()); + typename Base::Camera camera( + w_P_cam, make_shared( + camera_i.calibration())); // get jacobians and error vector for current measurement - Point2 reprojectionError_i = - Point2(camera.project(*this->result_, dProject_dPoseCam, Ei) - - this->measured_.at(i)); + Point2 reprojectionError_i = camera.reprojectionError( + *this->result_, this->measured_.at(i), dProject_dPoseCam, Ei); Eigen::Matrix J; // 2 x 12 J.block(0, 0, ZDim, 6) = dProject_dPoseCam * dPoseCam_dInterpPose * @@ -403,7 +407,7 @@ class SmartProjectionPoseFactorRollingShutter for (size_t i = 0; i < Fs.size(); i++) Fs[i] = this->noiseModel_->Whiten(Fs[i]); - Matrix3 P = Base::Cameras::PointCov(E, lambda, diagonalDamping); + Matrix3 P = Cameras::PointCov(E, lambda, diagonalDamping); // Collect all the key pairs: these are the keys that correspond to the // blocks in Fs (on which we apply the Schur Complement) diff --git a/gtsam_unstable/slam/SmartStereoProjectionFactor.h b/gtsam_unstable/slam/SmartStereoProjectionFactor.h index 88e112998..5cdfb2ab7 100644 --- a/gtsam_unstable/slam/SmartStereoProjectionFactor.h +++ b/gtsam_unstable/slam/SmartStereoProjectionFactor.h @@ -20,18 +20,18 @@ #pragma once -#include -#include - -#include #include #include -#include +#include #include +#include +#include +#include #include +#include -#include #include +#include #include namespace gtsam { @@ -49,8 +49,9 @@ typedef SmartProjectionParams SmartStereoProjectionParams; * If you'd like to store poses in values instead of cameras, use * SmartStereoProjectionPoseFactor instead */ -class SmartStereoProjectionFactor: public SmartFactorBase { -private: +class GTSAM_UNSTABLE_EXPORT SmartStereoProjectionFactor + : public SmartFactorBase { + private: typedef SmartFactorBase Base; diff --git a/gtsam_unstable/slam/SmartStereoProjectionFactorPP.h b/gtsam_unstable/slam/SmartStereoProjectionFactorPP.h index ce6df15cb..e20241a0e 100644 --- a/gtsam_unstable/slam/SmartStereoProjectionFactorPP.h +++ b/gtsam_unstable/slam/SmartStereoProjectionFactorPP.h @@ -40,7 +40,8 @@ namespace gtsam { * are Pose3 variables). * @addtogroup SLAM */ -class SmartStereoProjectionFactorPP : public SmartStereoProjectionFactor { +class GTSAM_UNSTABLE_EXPORT SmartStereoProjectionFactorPP + : public SmartStereoProjectionFactor { protected: /// shared pointer to calibration object (one for each camera) std::vector> K_all_; @@ -294,7 +295,6 @@ class SmartStereoProjectionFactorPP : public SmartStereoProjectionFactor { ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); ar & BOOST_SERIALIZATION_NVP(K_all_); } - }; // end of class declaration diff --git a/gtsam_unstable/slam/SmartStereoProjectionPoseFactor.h b/gtsam_unstable/slam/SmartStereoProjectionPoseFactor.h index 2a8180ac5..a46000a68 100644 --- a/gtsam_unstable/slam/SmartStereoProjectionPoseFactor.h +++ b/gtsam_unstable/slam/SmartStereoProjectionPoseFactor.h @@ -43,7 +43,8 @@ namespace gtsam { * This factor requires that values contains the involved poses (Pose3). * @addtogroup SLAM */ -class SmartStereoProjectionPoseFactor : public SmartStereoProjectionFactor { +class GTSAM_UNSTABLE_EXPORT SmartStereoProjectionPoseFactor + : public SmartStereoProjectionFactor { protected: /// shared pointer to calibration object (one for each camera) std::vector> K_all_; diff --git a/gtsam_unstable/slam/serialization.cpp b/gtsam_unstable/slam/serialization.cpp index 88a94fd51..d87ca6f2d 100644 --- a/gtsam_unstable/slam/serialization.cpp +++ b/gtsam_unstable/slam/serialization.cpp @@ -5,8 +5,6 @@ * @author Alex Cunningham */ -#include -#include #include #include @@ -31,8 +29,6 @@ using namespace gtsam; // Creating as many permutations of factors as possible -typedef PriorFactor PriorFactorLieVector; -typedef PriorFactor PriorFactorLieMatrix; typedef PriorFactor PriorFactorPoint2; typedef PriorFactor PriorFactorStereoPoint2; typedef PriorFactor PriorFactorPoint3; @@ -46,8 +42,6 @@ typedef PriorFactor PriorFactorCalibratedCamera; typedef PriorFactor PriorFactorPinholeCameraCal3_S2; typedef PriorFactor PriorFactorStereoCamera; -typedef BetweenFactor BetweenFactorLieVector; -typedef BetweenFactor BetweenFactorLieMatrix; typedef BetweenFactor BetweenFactorPoint2; typedef BetweenFactor BetweenFactorPoint3; typedef BetweenFactor BetweenFactorRot2; @@ -55,8 +49,6 @@ typedef BetweenFactor BetweenFactorRot3; typedef BetweenFactor BetweenFactorPose2; typedef BetweenFactor BetweenFactorPose3; -typedef NonlinearEquality NonlinearEqualityLieVector; -typedef NonlinearEquality NonlinearEqualityLieMatrix; typedef NonlinearEquality NonlinearEqualityPoint2; typedef NonlinearEquality NonlinearEqualityStereoPoint2; typedef NonlinearEquality NonlinearEqualityPoint3; @@ -112,8 +104,6 @@ BOOST_CLASS_EXPORT_GUID(gtsam::SharedDiagonal, "gtsam_SharedDiagonal"); /* Create GUIDs for geometry */ /* ************************************************************************* */ -GTSAM_VALUE_EXPORT(gtsam::LieVector); -GTSAM_VALUE_EXPORT(gtsam::LieMatrix); GTSAM_VALUE_EXPORT(gtsam::Point2); GTSAM_VALUE_EXPORT(gtsam::StereoPoint2); GTSAM_VALUE_EXPORT(gtsam::Point3); @@ -133,8 +123,6 @@ GTSAM_VALUE_EXPORT(gtsam::StereoCamera); BOOST_CLASS_EXPORT_GUID(gtsam::JacobianFactor, "gtsam::JacobianFactor"); BOOST_CLASS_EXPORT_GUID(gtsam::HessianFactor , "gtsam::HessianFactor"); -BOOST_CLASS_EXPORT_GUID(PriorFactorLieVector, "gtsam::PriorFactorLieVector"); -BOOST_CLASS_EXPORT_GUID(PriorFactorLieMatrix, "gtsam::PriorFactorLieMatrix"); BOOST_CLASS_EXPORT_GUID(PriorFactorPoint2, "gtsam::PriorFactorPoint2"); BOOST_CLASS_EXPORT_GUID(PriorFactorStereoPoint2, "gtsam::PriorFactorStereoPoint2"); BOOST_CLASS_EXPORT_GUID(PriorFactorPoint3, "gtsam::PriorFactorPoint3"); @@ -147,8 +135,6 @@ BOOST_CLASS_EXPORT_GUID(PriorFactorCal3DS2, "gtsam::PriorFactorCal3DS2"); BOOST_CLASS_EXPORT_GUID(PriorFactorCalibratedCamera, "gtsam::PriorFactorCalibratedCamera"); BOOST_CLASS_EXPORT_GUID(PriorFactorStereoCamera, "gtsam::PriorFactorStereoCamera"); -BOOST_CLASS_EXPORT_GUID(BetweenFactorLieVector, "gtsam::BetweenFactorLieVector"); -BOOST_CLASS_EXPORT_GUID(BetweenFactorLieMatrix, "gtsam::BetweenFactorLieMatrix"); BOOST_CLASS_EXPORT_GUID(BetweenFactorPoint2, "gtsam::BetweenFactorPoint2"); BOOST_CLASS_EXPORT_GUID(BetweenFactorPoint3, "gtsam::BetweenFactorPoint3"); BOOST_CLASS_EXPORT_GUID(BetweenFactorRot2, "gtsam::BetweenFactorRot2"); @@ -156,8 +142,6 @@ BOOST_CLASS_EXPORT_GUID(BetweenFactorRot3, "gtsam::BetweenFactorRot3"); BOOST_CLASS_EXPORT_GUID(BetweenFactorPose2, "gtsam::BetweenFactorPose2"); BOOST_CLASS_EXPORT_GUID(BetweenFactorPose3, "gtsam::BetweenFactorPose3"); -BOOST_CLASS_EXPORT_GUID(NonlinearEqualityLieVector, "gtsam::NonlinearEqualityLieVector"); -BOOST_CLASS_EXPORT_GUID(NonlinearEqualityLieMatrix, "gtsam::NonlinearEqualityLieMatrix"); BOOST_CLASS_EXPORT_GUID(NonlinearEqualityPoint2, "gtsam::NonlinearEqualityPoint2"); BOOST_CLASS_EXPORT_GUID(NonlinearEqualityStereoPoint2, "gtsam::NonlinearEqualityStereoPoint2"); BOOST_CLASS_EXPORT_GUID(NonlinearEqualityPoint3, "gtsam::NonlinearEqualityPoint3"); @@ -189,7 +173,7 @@ BOOST_CLASS_EXPORT_GUID(GeneralSFMFactor2Cal3_S2, "gtsam::GeneralSFMFactor2Cal3_ BOOST_CLASS_EXPORT_GUID(GenericStereoFactor3D, "gtsam::GenericStereoFactor3D"); -#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V41 +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 typedef PriorFactor PriorFactorSimpleCamera; typedef NonlinearEquality NonlinearEqualitySimpleCamera; diff --git a/gtsam_unstable/slam/tests/testBetweenFactorEM.cpp b/gtsam_unstable/slam/tests/testBetweenFactorEM.cpp index 4d6e1912a..f43ae293e 100644 --- a/gtsam_unstable/slam/tests/testBetweenFactorEM.cpp +++ b/gtsam_unstable/slam/tests/testBetweenFactorEM.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include @@ -21,26 +22,24 @@ using namespace gtsam; // Disabled this test because it is currently failing - remove the lines "#if 0" and "#endif" below // to reenable the test. -#if 0 +// #if 0 /* ************************************************************************* */ -LieVector predictionError(const Pose2& p1, const Pose2& p2, const gtsam::Key& key1, const gtsam::Key& key2, const BetweenFactorEM& factor){ +Vector predictionError(const Pose2& p1, const Pose2& p2, const gtsam::Key& key1, const gtsam::Key& key2, const BetweenFactorEM& factor){ gtsam::Values values; values.insert(key1, p1); values.insert(key2, p2); - // LieVector err = factor.whitenedError(values); - // return err; - return LieVector::Expmap(factor.whitenedError(values)); + return factor.whitenedError(values); } /* ************************************************************************* */ -LieVector predictionError_standard(const Pose2& p1, const Pose2& p2, const gtsam::Key& key1, const gtsam::Key& key2, const BetweenFactor& factor){ +Vector predictionError_standard(const Pose2& p1, const Pose2& p2, const gtsam::Key& key1, const gtsam::Key& key2, const BetweenFactor& factor){ gtsam::Values values; values.insert(key1, p1); values.insert(key2, p2); - // LieVector err = factor.whitenedError(values); + // Vector err = factor.whitenedError(values); // return err; - return LieVector::Expmap(factor.whitenedError(values)); + return factor.whitenedError(values); } /* ************************************************************************* */ @@ -99,8 +98,8 @@ TEST( BetweenFactorEM, EvaluateError) Vector actual_err_wh = f.whitenedError(values); - Vector actual_err_wh_inlier = (Vector(3) << actual_err_wh[0], actual_err_wh[1], actual_err_wh[2]); - Vector actual_err_wh_outlier = (Vector(3) << actual_err_wh[3], actual_err_wh[4], actual_err_wh[5]); + Vector3 actual_err_wh_inlier = Vector3(actual_err_wh[0], actual_err_wh[1], actual_err_wh[2]); + Vector3 actual_err_wh_outlier = Vector3(actual_err_wh[3], actual_err_wh[4], actual_err_wh[5]); // cout << "Inlier test. norm of actual_err_wh_inlier, actual_err_wh_outlier: "< h_EM(key1, key2, rel_pose_msr, model_inlier, model_outlier, prior_inlier, prior_outlier); actual_err_wh = h_EM.whitenedError(values); - actual_err_wh_inlier = (Vector(3) << actual_err_wh[0], actual_err_wh[1], actual_err_wh[2]); + actual_err_wh_inlier = Vector3(actual_err_wh[0], actual_err_wh[1], actual_err_wh[2]); BetweenFactor h(key1, key2, rel_pose_msr, model_inlier ); Vector actual_err_wh_stnd = h.whitenedError(values); @@ -178,7 +177,7 @@ TEST (BetweenFactorEM, jacobian ) { // compare to standard between factor BetweenFactor h(key1, key2, rel_pose_msr, model_inlier ); Vector actual_err_wh_stnd = h.whitenedError(values); - Vector actual_err_wh_inlier = (Vector(3) << actual_err_wh[0], actual_err_wh[1], actual_err_wh[2]); + Vector actual_err_wh_inlier = Vector3(actual_err_wh[0], actual_err_wh[1], actual_err_wh[2]); // CHECK( assert_equal(actual_err_wh_stnd, actual_err_wh_inlier, 1e-8)); std::vector H_actual_stnd_unwh(2); (void)h.unwhitenedError(values, H_actual_stnd_unwh); @@ -190,12 +189,13 @@ TEST (BetweenFactorEM, jacobian ) { // CHECK( assert_equal(H2_actual_stnd, H2_actual, 1e-8)); double stepsize = 1.0e-9; - Matrix H1_expected = gtsam::numericalDerivative11(std::bind(&predictionError, _1, p2, key1, key2, f), p1, stepsize); - Matrix H2_expected = gtsam::numericalDerivative11(std::bind(&predictionError, p1, _1, key1, key2, f), p2, stepsize); + using std::placeholders::_1; + Matrix H1_expected = gtsam::numericalDerivative11(std::bind(&predictionError, _1, p2, key1, key2, f), p1, stepsize); + Matrix H2_expected = gtsam::numericalDerivative11(std::bind(&predictionError, p1, _1, key1, key2, f), p2, stepsize); // try to check numerical derivatives of a standard between factor - Matrix H1_expected_stnd = gtsam::numericalDerivative11(std::bind(&predictionError_standard, _1, p2, key1, key2, h), p1, stepsize); + Matrix H1_expected_stnd = gtsam::numericalDerivative11(std::bind(&predictionError_standard, _1, p2, key1, key2, h), p1, stepsize); // CHECK( assert_equal(H1_expected_stnd, H1_actual_stnd, 1e-5)); // // @@ -240,8 +240,8 @@ TEST( BetweenFactorEM, CaseStudy) Vector actual_err_unw = f.unwhitenedError(values); Vector actual_err_wh = f.whitenedError(values); - Vector actual_err_wh_inlier = (Vector(3) << actual_err_wh[0], actual_err_wh[1], actual_err_wh[2]); - Vector actual_err_wh_outlier = (Vector(3) << actual_err_wh[3], actual_err_wh[4], actual_err_wh[5]); + Vector3 actual_err_wh_inlier = Vector3(actual_err_wh[0], actual_err_wh[1], actual_err_wh[2]); + Vector3 actual_err_wh_outlier = Vector3(actual_err_wh[3], actual_err_wh[4], actual_err_wh[5]); if (debug){ cout << "p_inlier_outler: "<print("model_inlier:"); - model_outlier->print("model_outlier:"); - model_inlier_new->print("model_inlier_new:"); - model_outlier_new->print("model_outlier_new:"); + // model_inlier->print("model_inlier:"); + // model_outlier->print("model_outlier:"); + // model_inlier_new->print("model_inlier_new:"); + // model_outlier_new->print("model_outlier_new:"); } -#endif +// #endif /* ************************************************************************* */ int main() { TestResult tr; return TestRegistry::runAllTests(tr);} diff --git a/gtsam_unstable/slam/tests/testGaussMarkov1stOrderFactor.cpp b/gtsam_unstable/slam/tests/testGaussMarkov1stOrderFactor.cpp index 8692cf584..ed4092c60 100644 --- a/gtsam_unstable/slam/tests/testGaussMarkov1stOrderFactor.cpp +++ b/gtsam_unstable/slam/tests/testGaussMarkov1stOrderFactor.cpp @@ -16,22 +16,23 @@ * @date Jan 17, 2012 */ -#include -#include -#include -#include #include -#include +#include +#include +#include +#include +#include using namespace std::placeholders; using namespace std; using namespace gtsam; //! Factors -typedef GaussMarkov1stOrderFactor GaussMarkovFactor; +typedef GaussMarkov1stOrderFactor GaussMarkovFactor; /* ************************************************************************* */ -LieVector predictionError(const LieVector& v1, const LieVector& v2, const GaussMarkovFactor factor) { +Vector predictionError(const Vector& v1, const Vector& v2, + const GaussMarkovFactor factor) { return factor.evaluateError(v1, v2); } @@ -58,29 +59,29 @@ TEST( GaussMarkovFactor, error ) Key x1(1); Key x2(2); double delta_t = 0.10; - Vector tau = Vector3(100.0, 150.0, 10.0); + Vector3 tau(100.0, 150.0, 10.0); SharedGaussian model = noiseModel::Isotropic::Sigma(3, 1.0); - LieVector v1 = LieVector(Vector3(10.0, 12.0, 13.0)); - LieVector v2 = LieVector(Vector3(10.0, 15.0, 14.0)); + Vector3 v1(10.0, 12.0, 13.0); + Vector3 v2(10.0, 15.0, 14.0); // Create two nodes linPoint.insert(x1, v1); linPoint.insert(x2, v2); GaussMarkovFactor factor(x1, x2, delta_t, tau, model); - Vector Err1( factor.evaluateError(v1, v2) ); + Vector3 error1 = factor.evaluateError(v1, v2); // Manually calculate the error - Vector alpha(tau.size()); - Vector alpha_v1(tau.size()); + Vector3 alpha(tau.size()); + Vector3 alpha_v1(tau.size()); for(int i=0; i +#include +#include + +using namespace gtsam; +using namespace gtsam::noiseModel; + +/* ************************************************************************* */ +// Verify zero error when there is no noise +TEST(PoseToPointFactor, errorNoiseless_2D) { + Pose2 pose = Pose2::identity(); + Point2 point(1.0, 2.0); + Point2 noise(0.0, 0.0); + Point2 measured = point + noise; + + Key pose_key(1); + Key point_key(2); + PoseToPointFactor factor(pose_key, point_key, measured, + Isotropic::Sigma(2, 0.05)); + Vector expectedError = Vector2(0.0, 0.0); + Vector actualError = factor.evaluateError(pose, point); + EXPECT(assert_equal(expectedError, actualError, 1E-5)); +} + +/* ************************************************************************* */ +// Verify expected error in test scenario +TEST(PoseToPointFactor, errorNoise_2D) { + Pose2 pose = Pose2::identity(); + Point2 point(1.0, 2.0); + Point2 noise(-1.0, 0.5); + Point2 measured = point + noise; + + Key pose_key(1); + Key point_key(2); + PoseToPointFactor factor(pose_key, point_key, measured, + Isotropic::Sigma(2, 0.05)); + Vector expectedError = -noise; + Vector actualError = factor.evaluateError(pose, point); + EXPECT(assert_equal(expectedError, actualError, 1E-5)); +} + +/* ************************************************************************* */ +// Check Jacobians are correct +TEST(PoseToPointFactor, jacobian_2D) { + // Measurement + gtsam::Point2 l_meas(1, 2); + + // Linearisation point + gtsam::Point2 p_t(-5, 12); + gtsam::Rot2 p_R(1.5 * M_PI); + Pose2 p(p_R, p_t); + + gtsam::Point2 l(3, 0); + + // Factor + Key pose_key(1); + Key point_key(2); + SharedGaussian noise = noiseModel::Diagonal::Sigmas(Vector2(0.1, 0.1)); + PoseToPointFactor factor(pose_key, point_key, l_meas, noise); + + // Calculate numerical derivatives + auto f = std::bind(&PoseToPointFactor::evaluateError, factor, + std::placeholders::_1, std::placeholders::_2, boost::none, + boost::none); + Matrix numerical_H1 = numericalDerivative21(f, p, l); + Matrix numerical_H2 = numericalDerivative22(f, p, l); + + // Use the factor to calculate the derivative + Matrix actual_H1; + Matrix actual_H2; + factor.evaluateError(p, l, actual_H1, actual_H2); + + // Verify we get the expected error + EXPECT(assert_equal(numerical_H1, actual_H1, 1e-8)); + EXPECT(assert_equal(numerical_H2, actual_H2, 1e-8)); +} + +/* ************************************************************************* */ +// Verify zero error when there is no noise +TEST(PoseToPointFactor, errorNoiseless_3D) { + Pose3 pose = Pose3::identity(); + Point3 point(1.0, 2.0, 3.0); + Point3 noise(0.0, 0.0, 0.0); + Point3 measured = point + noise; + + Key pose_key(1); + Key point_key(2); + PoseToPointFactor factor(pose_key, point_key, measured, + Isotropic::Sigma(3, 0.05)); + Vector expectedError = Vector3(0.0, 0.0, 0.0); + Vector actualError = factor.evaluateError(pose, point); + EXPECT(assert_equal(expectedError, actualError, 1E-5)); +} + +/* ************************************************************************* */ +// Verify expected error in test scenario +TEST(PoseToPointFactor, errorNoise_3D) { + Pose3 pose = Pose3::identity(); + Point3 point(1.0, 2.0, 3.0); + Point3 noise(-1.0, 0.5, 0.3); + Point3 measured = point + noise; + + Key pose_key(1); + Key point_key(2); + PoseToPointFactor factor(pose_key, point_key, measured, + Isotropic::Sigma(3, 0.05)); + Vector expectedError = -noise; + Vector actualError = factor.evaluateError(pose, point); + EXPECT(assert_equal(expectedError, actualError, 1E-5)); +} + +/* ************************************************************************* */ +// Check Jacobians are correct +TEST(PoseToPointFactor, jacobian_3D) { + // Measurement + gtsam::Point3 l_meas = gtsam::Point3(1, 2, 3); + + // Linearisation point + gtsam::Point3 p_t = gtsam::Point3(-5, 12, 2); + gtsam::Rot3 p_R = gtsam::Rot3::RzRyRx(1.5 * M_PI, -0.3 * M_PI, 0.4 * M_PI); + Pose3 p(p_R, p_t); + + gtsam::Point3 l = gtsam::Point3(3, 0, 5); + + // Factor + Key pose_key(1); + Key point_key(2); + SharedGaussian noise = noiseModel::Diagonal::Sigmas(Vector3(0.1, 0.1, 0.1)); + PoseToPointFactor factor(pose_key, point_key, l_meas, noise); + + // Calculate numerical derivatives + auto f = std::bind(&PoseToPointFactor::evaluateError, factor, + std::placeholders::_1, std::placeholders::_2, boost::none, + boost::none); + Matrix numerical_H1 = numericalDerivative21(f, p, l); + Matrix numerical_H2 = numericalDerivative22(f, p, l); + + // Use the factor to calculate the derivative + Matrix actual_H1; + Matrix actual_H2; + factor.evaluateError(p, l, actual_H1, actual_H2); + + // Verify we get the expected error + EXPECT(assert_equal(numerical_H1, actual_H1, 1e-8)); + EXPECT(assert_equal(numerical_H2, actual_H2, 1e-8)); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/gtsam_unstable/slam/tests/testPoseToPointFactor.h b/gtsam_unstable/slam/tests/testPoseToPointFactor.h deleted file mode 100644 index e0e5c4581..000000000 --- a/gtsam_unstable/slam/tests/testPoseToPointFactor.h +++ /dev/null @@ -1,86 +0,0 @@ -/** - * @file testPoseToPointFactor.cpp - * @brief - * @author David Wisth - * @date June 20, 2020 - */ - -#include -#include -#include - -using namespace gtsam; -using namespace gtsam::noiseModel; - -/// Verify zero error when there is no noise -TEST(PoseToPointFactor, errorNoiseless) { - Pose3 pose = Pose3::identity(); - Point3 point(1.0, 2.0, 3.0); - Point3 noise(0.0, 0.0, 0.0); - Point3 measured = t + noise; - - Key pose_key(1); - Key point_key(2); - PoseToPointFactor factor(pose_key, point_key, measured, - Isotropic::Sigma(3, 0.05)); - Vector expectedError = Vector3(0.0, 0.0, 0.0); - Vector actualError = factor.evaluateError(pose, point); - EXPECT(assert_equal(expectedError, actualError, 1E-5)); -} - -/// Verify expected error in test scenario -TEST(PoseToPointFactor, errorNoise) { - Pose3 pose = Pose3::identity(); - Point3 point(1.0, 2.0, 3.0); - Point3 noise(-1.0, 0.5, 0.3); - Point3 measured = t + noise; - - Key pose_key(1); - Key point_key(2); - PoseToPointFactor factor(pose_key, point_key, measured, - Isotropic::Sigma(3, 0.05)); - Vector expectedError = noise; - Vector actualError = factor.evaluateError(pose, point); - EXPECT(assert_equal(expectedError, actualError, 1E-5)); -} - -/// Check Jacobians are correct -TEST(PoseToPointFactor, jacobian) { - // Measurement - gtsam::Point3 l_meas = gtsam::Point3(1, 2, 3); - - // Linearisation point - gtsam::Point3 p_t = gtsam::Point3(-5, 12, 2); - gtsam::Rot3 p_R = gtsam::Rot3::RzRyRx(1.5 * M_PI, -0.3 * M_PI, 0.4 * M_PI); - Pose3 p(p_R, p_t); - - gtsam::Point3 l = gtsam::Point3(3, 0, 5); - - // Factor - Key pose_key(1); - Key point_key(2); - SharedGaussian noise = noiseModel::Diagonal::Sigmas(Vector3(0.1, 0.1, 0.1)); - PoseToPointFactor factor(pose_key, point_key, l_meas, noise); - - // Calculate numerical derivatives - auto f = std::bind(&PoseToPointFactor::evaluateError, factor, _1, _2, - boost::none, boost::none); - Matrix numerical_H1 = numericalDerivative21(f, p, l); - Matrix numerical_H2 = numericalDerivative22(f, p, l); - - // Use the factor to calculate the derivative - Matrix actual_H1; - Matrix actual_H2; - factor.evaluateError(p, l, actual_H1, actual_H2); - - // Verify we get the expected error - EXPECT_TRUE(assert_equal(numerical_H1, actual_H1, 1e-8)); - EXPECT_TRUE(assert_equal(numerical_H2, actual_H2, 1e-8)); -} - -/* ************************************************************************* */ -int main() { - TestResult tr; - return TestRegistry::runAllTests(tr); -} -/* ************************************************************************* */ diff --git a/gtsam_unstable/slam/tests/testSerialization.cpp b/gtsam_unstable/slam/tests/testSerialization.cpp index 792fd1133..e9157317e 100644 --- a/gtsam_unstable/slam/tests/testSerialization.cpp +++ b/gtsam_unstable/slam/tests/testSerialization.cpp @@ -10,7 +10,7 @@ #include #include -#include +#include #include #include diff --git a/gtsam_unstable/slam/tests/testSmartProjectionPoseFactorRollingShutter.cpp b/gtsam_unstable/slam/tests/testSmartProjectionPoseFactorRollingShutter.cpp index c17ad7e1c..b5962d777 100644 --- a/gtsam_unstable/slam/tests/testSmartProjectionPoseFactorRollingShutter.cpp +++ b/gtsam_unstable/slam/tests/testSmartProjectionPoseFactorRollingShutter.cpp @@ -1317,10 +1317,10 @@ TEST(SmartProjectionPoseFactorRollingShutter, #ifndef DISABLE_TIMING #include //-Total: 0 CPU (0 times, 0 wall, 0.21 children, min: 0 max: 0) -//| -SF RS LINEARIZE: 0.09 CPU -// (10000 times, 0.124106 wall, 0.09 children, min: 0 max: 0) -//| -RS LINEARIZE: 0.09 CPU -// (10000 times, 0.068719 wall, 0.09 children, min: 0 max: 0) +//| -SF RS LINEARIZE: 0.14 CPU +//(10000 times, 0.131202 wall, 0.14 children, min: 0 max: 0) +//| -RS LINEARIZE: 0.06 CPU +//(10000 times, 0.066951 wall, 0.06 children, min: 0 max: 0) /* *************************************************************************/ TEST(SmartProjectionPoseFactorRollingShutter, timing) { using namespace vanillaPose; @@ -1384,6 +1384,105 @@ TEST(SmartProjectionPoseFactorRollingShutter, timing) { } #endif +#include +/* ************************************************************************* */ +// spherical Camera with rolling shutter effect +namespace sphericalCameraRS { +typedef SphericalCamera Camera; +typedef CameraSet Cameras; +typedef SmartProjectionPoseFactorRollingShutter SmartFactorRS_spherical; +Pose3 interp_pose1 = interpolate(level_pose, pose_right, interp_factor1); +Pose3 interp_pose2 = interpolate(pose_right, pose_above, interp_factor2); +Pose3 interp_pose3 = interpolate(pose_above, level_pose, interp_factor3); +static EmptyCal::shared_ptr emptyK(new EmptyCal()); +Camera cam1(interp_pose1, emptyK); +Camera cam2(interp_pose2, emptyK); +Camera cam3(interp_pose3, emptyK); +} // namespace sphericalCameraRS + +/* *************************************************************************/ +TEST(SmartProjectionPoseFactorRollingShutter, + optimization_3poses_sphericalCameras) { + using namespace sphericalCameraRS; + std::vector measurements_lmk1, measurements_lmk2, measurements_lmk3; + + // Project three landmarks into three cameras + projectToMultipleCameras(cam1, cam2, cam3, landmark1, + measurements_lmk1); + projectToMultipleCameras(cam1, cam2, cam3, landmark2, + measurements_lmk2); + projectToMultipleCameras(cam1, cam2, cam3, landmark3, + measurements_lmk3); + + // create inputs + std::vector> key_pairs; + key_pairs.push_back(std::make_pair(x1, x2)); + key_pairs.push_back(std::make_pair(x2, x3)); + key_pairs.push_back(std::make_pair(x3, x1)); + + std::vector interp_factors; + interp_factors.push_back(interp_factor1); + interp_factors.push_back(interp_factor2); + interp_factors.push_back(interp_factor3); + + SmartProjectionParams params( + gtsam::HESSIAN, + gtsam::ZERO_ON_DEGENERACY); // only config that works with RS factors + params.setRankTolerance(0.1); + + boost::shared_ptr cameraRig(new Cameras()); + cameraRig->push_back(Camera(Pose3::identity(), emptyK)); + + SmartFactorRS_spherical::shared_ptr smartFactor1( + new SmartFactorRS_spherical(model, cameraRig, params)); + smartFactor1->add(measurements_lmk1, key_pairs, interp_factors); + + SmartFactorRS_spherical::shared_ptr smartFactor2( + new SmartFactorRS_spherical(model, cameraRig, params)); + smartFactor2->add(measurements_lmk2, key_pairs, interp_factors); + + SmartFactorRS_spherical::shared_ptr smartFactor3( + new SmartFactorRS_spherical(model, cameraRig, params)); + smartFactor3->add(measurements_lmk3, key_pairs, interp_factors); + + const SharedDiagonal noisePrior = noiseModel::Isotropic::Sigma(6, 0.10); + + NonlinearFactorGraph graph; + graph.push_back(smartFactor1); + graph.push_back(smartFactor2); + graph.push_back(smartFactor3); + graph.addPrior(x1, level_pose, noisePrior); + graph.addPrior(x2, pose_right, noisePrior); + + Values groundTruth; + groundTruth.insert(x1, level_pose); + groundTruth.insert(x2, pose_right); + groundTruth.insert(x3, pose_above); + DOUBLES_EQUAL(0, graph.error(groundTruth), 1e-9); + + // Pose3 noise_pose = Pose3(Rot3::Ypr(-M_PI/10, 0., -M_PI/10), + // Point3(0.5,0.1,0.3)); // noise from regular projection factor test below + Pose3 noise_pose = Pose3(Rot3::Ypr(-M_PI / 100, 0., -M_PI / 100), + Point3(0.1, 0.1, 0.1)); // smaller noise + Values values; + values.insert(x1, level_pose); + values.insert(x2, pose_right); + // initialize third pose with some noise, we expect it to move back to + // original pose_above + values.insert(x3, pose_above * noise_pose); + EXPECT( // check that the pose is actually noisy + assert_equal(Pose3(Rot3(0, -0.0314107591, 0.99950656, -0.99950656, + -0.0313952598, -0.000986635786, 0.0314107591, + -0.999013364, -0.0313952598), + Point3(0.1, -0.1, 1.9)), + values.at(x3))); + + Values result; + LevenbergMarquardtOptimizer optimizer(graph, values, lmParams); + result = optimizer.optimize(); + EXPECT(assert_equal(pose_above, result.at(x3), 1e-6)); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/matlab/+gtsam/Contents.m b/matlab/+gtsam/Contents.m index fb6d3081e..77536e5c9 100644 --- a/matlab/+gtsam/Contents.m +++ b/matlab/+gtsam/Contents.m @@ -49,9 +49,6 @@ % Ordering - class Ordering, see Doxygen page for details % Value - class Value, see Doxygen page for details % Values - class Values, see Doxygen page for details -% LieScalar - class LieScalar, see Doxygen page for details -% LieVector - class LieVector, see Doxygen page for details -% LieMatrix - class LieMatrix, see Doxygen page for details % NonlinearFactor - class NonlinearFactor, see Doxygen page for details % NonlinearFactorGraph - class NonlinearFactorGraph, see Doxygen page for details % @@ -101,9 +98,6 @@ % BearingFactor2D - class BearingFactor2D, see Doxygen page for details % BearingFactor3D - class BearingFactor3D, see Doxygen page for details % BearingRangeFactor2D - class BearingRangeFactor2D, see Doxygen page for details -% BetweenFactorLieMatrix - class BetweenFactorLieMatrix, see Doxygen page for details -% BetweenFactorLieScalar - class BetweenFactorLieScalar, see Doxygen page for details -% BetweenFactorLieVector - class BetweenFactorLieVector, see Doxygen page for details % BetweenFactorPoint2 - class BetweenFactorPoint2, see Doxygen page for details % BetweenFactorPoint3 - class BetweenFactorPoint3, see Doxygen page for details % BetweenFactorPose2 - class BetweenFactorPose2, see Doxygen page for details @@ -116,9 +110,6 @@ % GenericStereoFactor3D - class GenericStereoFactor3D, see Doxygen page for details % NonlinearEqualityCal3_S2 - class NonlinearEqualityCal3_S2, see Doxygen page for details % NonlinearEqualityCalibratedCamera - class NonlinearEqualityCalibratedCamera, see Doxygen page for details -% NonlinearEqualityLieMatrix - class NonlinearEqualityLieMatrix, see Doxygen page for details -% NonlinearEqualityLieScalar - class NonlinearEqualityLieScalar, see Doxygen page for details -% NonlinearEqualityLieVector - class NonlinearEqualityLieVector, see Doxygen page for details % NonlinearEqualityPoint2 - class NonlinearEqualityPoint2, see Doxygen page for details % NonlinearEqualityPoint3 - class NonlinearEqualityPoint3, see Doxygen page for details % NonlinearEqualityPose2 - class NonlinearEqualityPose2, see Doxygen page for details @@ -129,9 +120,6 @@ % NonlinearEqualityStereoPoint2 - class NonlinearEqualityStereoPoint2, see Doxygen page for details % PriorFactorCal3_S2 - class PriorFactorCal3_S2, see Doxygen page for details % PriorFactorCalibratedCamera - class PriorFactorCalibratedCamera, see Doxygen page for details -% PriorFactorLieMatrix - class PriorFactorLieMatrix, see Doxygen page for details -% PriorFactorLieScalar - class PriorFactorLieScalar, see Doxygen page for details -% PriorFactorLieVector - class PriorFactorLieVector, see Doxygen page for details % PriorFactorPoint2 - class PriorFactorPoint2, see Doxygen page for details % PriorFactorPoint3 - class PriorFactorPoint3, see Doxygen page for details % PriorFactorPose2 - class PriorFactorPose2, see Doxygen page for details diff --git a/matlab/+gtsam/VisualISAMInitialize.m b/matlab/+gtsam/VisualISAMInitialize.m index 29f8b3b46..9b834e3e1 100644 --- a/matlab/+gtsam/VisualISAMInitialize.m +++ b/matlab/+gtsam/VisualISAMInitialize.m @@ -12,11 +12,11 @@ end isam = ISAM2(params); %% Set Noise parameters -noiseModels.pose = noiseModel.Diagonal.Sigmas([0.001 0.001 0.001 0.1 0.1 0.1]'); +noiseModels.pose = noiseModel.Diagonal.Sigmas([0.001 0.001 0.001 0.1 0.1 0.1]', true); %noiseModels.odometry = noiseModel.Diagonal.Sigmas([0.001 0.001 0.001 0.1 0.1 0.1]'); -noiseModels.odometry = noiseModel.Diagonal.Sigmas([0.05 0.05 0.05 0.2 0.2 0.2]'); -noiseModels.point = noiseModel.Isotropic.Sigma(3, 0.1); -noiseModels.measurement = noiseModel.Isotropic.Sigma(2, 1.0); +noiseModels.odometry = noiseModel.Diagonal.Sigmas([0.05 0.05 0.05 0.2 0.2 0.2]', true); +noiseModels.point = noiseModel.Isotropic.Sigma(3, 0.1, true); +noiseModels.measurement = noiseModel.Isotropic.Sigma(2, 1.0, true); %% Add constraints/priors % TODO: should not be from ground truth! diff --git a/matlab/CMakeLists.txt b/matlab/CMakeLists.txt index 28e7cce6e..749ad870a 100644 --- a/matlab/CMakeLists.txt +++ b/matlab/CMakeLists.txt @@ -64,8 +64,21 @@ set(ignore gtsam::Point3 gtsam::CustomFactor) +set(interface_files + ${GTSAM_SOURCE_DIR}/gtsam/gtsam.i + ${GTSAM_SOURCE_DIR}/gtsam/base/base.i + ${GTSAM_SOURCE_DIR}/gtsam/basis/basis.i + ${GTSAM_SOURCE_DIR}/gtsam/geometry/geometry.i + ${GTSAM_SOURCE_DIR}/gtsam/linear/linear.i + ${GTSAM_SOURCE_DIR}/gtsam/nonlinear/nonlinear.i + ${GTSAM_SOURCE_DIR}/gtsam/symbolic/symbolic.i + ${GTSAM_SOURCE_DIR}/gtsam/sam/sam.i + ${GTSAM_SOURCE_DIR}/gtsam/slam/slam.i + ${GTSAM_SOURCE_DIR}/gtsam/sfm/sfm.i + ${GTSAM_SOURCE_DIR}/gtsam/navigation/navigation.i +) # Wrap -matlab_wrap(${GTSAM_SOURCE_DIR}/gtsam/gtsam.i "${GTSAM_ADDITIONAL_LIBRARIES}" +matlab_wrap("${interface_files}" "gtsam" "${GTSAM_ADDITIONAL_LIBRARIES}" "" "${mexFlags}" "${ignore}") # Wrap version for gtsam_unstable diff --git a/matlab/gtsam_tests/testUtilities.m b/matlab/gtsam_tests/testUtilities.m index da8dec789..2bfe81a83 100644 --- a/matlab/gtsam_tests/testUtilities.m +++ b/matlab/gtsam_tests/testUtilities.m @@ -45,3 +45,12 @@ CHECK('KeySet', isa(actual,'gtsam.KeySet')); CHECK('size==3', actual.size==3); CHECK('actual.count(x1)', actual.count(x1)); +% test extractVectors +values = Values(); +values.insert(symbol('x', 0), (1:6)'); +values.insert(symbol('x', 1), (7:12)'); +values.insert(symbol('x', 2), (13:18)'); +values.insert(symbol('x', 7), Pose3()); +actual = utilities.extractVectors(values, 'x'); +expected = reshape(1:18, 6, 3)'; +CHECK('extractVectors', all(actual == expected, 'all')); diff --git a/matlab/unstable_examples/+imuSimulator/IMUComparison.m b/matlab/unstable_examples/+imuSimulator/IMUComparison.m index 871f023ef..ccc975d84 100644 --- a/matlab/unstable_examples/+imuSimulator/IMUComparison.m +++ b/matlab/unstable_examples/+imuSimulator/IMUComparison.m @@ -51,13 +51,13 @@ isam = gtsam.ISAM2(isamParams); initialValues = Values; initialValues.insert(symbol('x',0), currentPoseGlobal); -initialValues.insert(symbol('v',0), LieVector(currentVelocityGlobal)); +initialValues.insert(symbol('v',0), currentVelocityGlobal); initialValues.insert(symbol('b',0), imuBias.ConstantBias([0;0;0],[0;0;0])); initialFactors = NonlinearFactorGraph; initialFactors.add(PriorFactorPose3(symbol('x',0), ... currentPoseGlobal, noiseModel.Isotropic.Sigma(6, 1.0))); -initialFactors.add(PriorFactorLieVector(symbol('v',0), ... - LieVector(currentVelocityGlobal), noiseModel.Isotropic.Sigma(3, 1.0))); +initialFactors.add(PriorFactorVector(symbol('v',0), ... + currentVelocityGlobal, noiseModel.Isotropic.Sigma(3, 1.0))); initialFactors.add(PriorFactorConstantBias(symbol('b',0), ... imuBias.ConstantBias([0;0;0],[0;0;0]), noiseModel.Isotropic.Sigma(6, 1.0))); @@ -96,7 +96,7 @@ for t = times initialVel = isam.calculateEstimate(symbol('v',lastSummaryIndex)); else initialPose = Pose3; - initialVel = LieVector(velocity); + initialVel = velocity; end initialValues.insert(symbol('x',lastSummaryIndex+1), initialPose); initialValues.insert(symbol('v',lastSummaryIndex+1), initialVel); diff --git a/matlab/unstable_examples/+imuSimulator/IMUComparison_with_cov.m b/matlab/unstable_examples/+imuSimulator/IMUComparison_with_cov.m index 450697de0..6adc8e9dc 100644 --- a/matlab/unstable_examples/+imuSimulator/IMUComparison_with_cov.m +++ b/matlab/unstable_examples/+imuSimulator/IMUComparison_with_cov.m @@ -43,15 +43,15 @@ sigma_init_b = 1.0; initialValues = Values; initialValues.insert(symbol('x',0), currentPoseGlobal); -initialValues.insert(symbol('v',0), LieVector(currentVelocityGlobal)); +initialValues.insert(symbol('v',0), currentVelocityGlobal); initialValues.insert(symbol('b',0), imuBias.ConstantBias([0;0;0],[0;0;0])); initialFactors = NonlinearFactorGraph; % Prior on initial pose initialFactors.add(PriorFactorPose3(symbol('x',0), ... currentPoseGlobal, noiseModel.Isotropic.Sigma(6, sigma_init_x))); % Prior on initial velocity -initialFactors.add(PriorFactorLieVector(symbol('v',0), ... - LieVector(currentVelocityGlobal), noiseModel.Isotropic.Sigma(3, sigma_init_v))); +initialFactors.add(PriorFactorVector(symbol('v',0), ... + currentVelocityGlobal, noiseModel.Isotropic.Sigma(3, sigma_init_v))); % Prior on initial bias initialFactors.add(PriorFactorConstantBias(symbol('b',0), ... imuBias.ConstantBias([0;0;0],[0;0;0]), noiseModel.Isotropic.Sigma(6, sigma_init_b))); @@ -91,7 +91,7 @@ for t = times initialVel = isam.calculateEstimate(symbol('v',lastSummaryIndex)); else initialPose = Pose3; - initialVel = LieVector(velocity); + initialVel = velocity; end initialValues.insert(symbol('x',lastSummaryIndex+1), initialPose); initialValues.insert(symbol('v',lastSummaryIndex+1), initialVel); diff --git a/matlab/unstable_examples/+imuSimulator/coriolisExample.m b/matlab/unstable_examples/+imuSimulator/coriolisExample.m index ee4deb433..61dc78d96 100644 --- a/matlab/unstable_examples/+imuSimulator/coriolisExample.m +++ b/matlab/unstable_examples/+imuSimulator/coriolisExample.m @@ -175,9 +175,9 @@ for i = 1:length(times) % known initial conditions currentPoseEstimate = currentPoseFixedGT; if navFrameRotating == 1 - currentVelocityEstimate = LieVector(currentVelocityRotatingGT); + currentVelocityEstimate = currentVelocityRotatingGT; else - currentVelocityEstimate = LieVector(currentVelocityFixedGT); + currentVelocityEstimate = currentVelocityFixedGT; end % Set Priors @@ -186,7 +186,7 @@ for i = 1:length(times) newValues.insert(currentBiasKey, zeroBias); % Initial values, same for IMU types 1 and 2 newFactors.add(PriorFactorPose3(currentPoseKey, currentPoseEstimate, sigma_init_x)); - newFactors.add(PriorFactorLieVector(currentVelKey, currentVelocityEstimate, sigma_init_v)); + newFactors.add(PriorFactorVector(currentVelKey, currentVelocityEstimate, sigma_init_v)); newFactors.add(PriorFactorConstantBias(currentBiasKey, zeroBias, sigma_init_b)); % Store data diff --git a/matlab/unstable_examples/+imuSimulator/covarianceAnalysisCreateFactorGraph.m b/matlab/unstable_examples/+imuSimulator/covarianceAnalysisCreateFactorGraph.m index 07f146dcb..037065ac5 100644 --- a/matlab/unstable_examples/+imuSimulator/covarianceAnalysisCreateFactorGraph.m +++ b/matlab/unstable_examples/+imuSimulator/covarianceAnalysisCreateFactorGraph.m @@ -27,7 +27,7 @@ for i=0:length(measurements) if options.includeIMUFactors == 1 currentVelKey = symbol('v', 0); currentVel = values.atPoint3(currentVelKey); - graph.add(PriorFactorLieVector(currentVelKey, LieVector(currentVel), noiseModels.noiseVel)); + graph.add(PriorFactorVector(currentVelKey, currentVel, noiseModels.noiseVel)); currentBiasKey = symbol('b', 0); currentBias = values.atPoint3(currentBiasKey); diff --git a/matlab/unstable_examples/+imuSimulator/covarianceAnalysisCreateTrajectory.m b/matlab/unstable_examples/+imuSimulator/covarianceAnalysisCreateTrajectory.m index 3d8a9b5d2..5fb6589d6 100644 --- a/matlab/unstable_examples/+imuSimulator/covarianceAnalysisCreateTrajectory.m +++ b/matlab/unstable_examples/+imuSimulator/covarianceAnalysisCreateTrajectory.m @@ -82,7 +82,7 @@ if options.useRealData == 1 end % Add Values: velocity and bias - values.insert(currentVelKey, LieVector(currentVel)); + values.insert(currentVelKey, currentVel); values.insert(currentBiasKey, metadata.imu.zeroBias); end diff --git a/matlab/unstable_examples/FlightCameraTransformIMU.m b/matlab/unstable_examples/FlightCameraTransformIMU.m index d2f2bc34d..aeac2e243 100644 --- a/matlab/unstable_examples/FlightCameraTransformIMU.m +++ b/matlab/unstable_examples/FlightCameraTransformIMU.m @@ -167,7 +167,7 @@ for i=1:size(trajectory)-1 %% priors on first two poses if i < 3 - % fg.add(PriorFactorLieVector(currentVelKey, currentVelocityGlobal, sigma_init_v)); + % fg.add(PriorFactorVector(currentVelKey, currentVelocityGlobal, sigma_init_v)); fg.add(PriorFactorConstantBias(currentBiasKey, currentBias, sigma_init_b)); end diff --git a/matlab/unstable_examples/IMUKittiExampleVO.m b/matlab/unstable_examples/IMUKittiExampleVO.m index 6434e750a..f35d36512 100644 --- a/matlab/unstable_examples/IMUKittiExampleVO.m +++ b/matlab/unstable_examples/IMUKittiExampleVO.m @@ -46,7 +46,7 @@ clear logposes relposes %% Get initial conditions for the estimated trajectory currentPoseGlobal = Pose3; -currentVelocityGlobal = LieVector([0;0;0]); % the vehicle is stationary at the beginning +currentVelocityGlobal = [0;0;0]; % the vehicle is stationary at the beginning currentBias = imuBias.ConstantBias(zeros(3,1), zeros(3,1)); sigma_init_x = noiseModel.Isotropic.Sigmas([ 1.0; 1.0; 0.01; 0.01; 0.01; 0.01 ]); sigma_init_v = noiseModel.Isotropic.Sigma(3, 1000.0); @@ -88,7 +88,7 @@ for measurementIndex = 1:length(timestamps) newValues.insert(currentVelKey, currentVelocityGlobal); newValues.insert(currentBiasKey, currentBias); newFactors.add(PriorFactorPose3(currentPoseKey, currentPoseGlobal, sigma_init_x)); - newFactors.add(PriorFactorLieVector(currentVelKey, currentVelocityGlobal, sigma_init_v)); + newFactors.add(PriorFactorVector(currentVelKey, currentVelocityGlobal, sigma_init_v)); newFactors.add(PriorFactorConstantBias(currentBiasKey, currentBias, sigma_init_b)); else t_previous = timestamps(measurementIndex-1, 1); diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index e2444a51a..85ddc7b6d 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -33,8 +33,6 @@ add_custom_target(gtsam_unstable_header DEPENDS "${PROJECT_SOURCE_DIR}/gtsam_uns set(ignore gtsam::Point2 gtsam::Point3 - gtsam::LieVector - gtsam::LieMatrix gtsam::ISAM2ThresholdMapValue gtsam::FactorIndices gtsam::FactorIndexSet @@ -47,13 +45,17 @@ set(ignore gtsam::Point3Pairs gtsam::Pose3Pairs gtsam::Pose3Vector + gtsam::Rot3Vector gtsam::KeyVector gtsam::BinaryMeasurementsUnit3 + gtsam::DiscreteKey gtsam::KeyPairDoubleMap) set(interface_headers ${PROJECT_SOURCE_DIR}/gtsam/gtsam.i ${PROJECT_SOURCE_DIR}/gtsam/base/base.i + ${PROJECT_SOURCE_DIR}/gtsam/inference/inference.i + ${PROJECT_SOURCE_DIR}/gtsam/discrete/discrete.i ${PROJECT_SOURCE_DIR}/gtsam/geometry/geometry.i ${PROJECT_SOURCE_DIR}/gtsam/linear/linear.i ${PROJECT_SOURCE_DIR}/gtsam/nonlinear/nonlinear.i @@ -114,8 +116,6 @@ if(GTSAM_UNSTABLE_BUILD_PYTHON) set(ignore gtsam::Point2 gtsam::Point3 - gtsam::LieVector - gtsam::LieMatrix gtsam::ISAM2ThresholdMapValue gtsam::FactorIndices gtsam::FactorIndexSet @@ -183,5 +183,5 @@ add_custom_target( ${CMAKE_COMMAND} -E env # add package to python path so no need to install "PYTHONPATH=${GTSAM_PYTHON_BUILD_DIRECTORY}/$ENV{PYTHONPATH}" ${PYTHON_EXECUTABLE} -m unittest discover -v -s . - DEPENDS ${GTSAM_PYTHON_DEPENDENCIES} + DEPENDS ${GTSAM_PYTHON_DEPENDENCIES} ${GTSAM_PYTHON_TEST_FILES} WORKING_DIRECTORY "${GTSAM_PYTHON_BUILD_DIRECTORY}/gtsam/tests") diff --git a/python/gtsam/notebooks/DiscreteBayesTree.ipynb b/python/gtsam/notebooks/DiscreteBayesTree.ipynb new file mode 100644 index 000000000..066c31d6a --- /dev/null +++ b/python/gtsam/notebooks/DiscreteBayesTree.ipynb @@ -0,0 +1,200 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# The Discrete Bayes Tree\n", + "\n", + "An example of building a Bayes net, then eliminating it into a Bayes tree. Mirrors the code in `testDiscreteBayesTree.cpp` .\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from gtsam import DiscreteBayesTree, DiscreteBayesNet, DiscreteKeys, DiscreteFactorGraph, Ordering\n", + "from gtsam.symbol_shorthand import S\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "def P(*args):\n", + " \"\"\" Create a DiscreteKeys instances from a variable number of DiscreteKey pairs.\"\"\"\n", + " #TODO: We can make life easier by providing variable argument functions in C++ itself.\n", + " dks = DiscreteKeys()\n", + " for key in args:\n", + " dks.push_back(key)\n", + " return dks" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import graphviz\n", + "class show(graphviz.Source):\n", + " \"\"\" Display an object with a dot method as a graph.\"\"\"\n", + "\n", + " def __init__(self, obj):\n", + " \"\"\"Construct from object with 'dot' method.\"\"\"\n", + " # This small class takes an object, calls its dot function, and uses the\n", + " # resulting string to initialize a graphviz.Source instance. This in turn\n", + " # has a _repr_mimebundle_ method, which then renders it in the notebook.\n", + " super().__init__(obj.dot())" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n\n\n\nG\n\n\n\n8\n\n8\n\n\n\n0\n\n0\n\n\n\n8->0\n\n\n\n\n\n1\n\n1\n\n\n\n8->1\n\n\n\n\n\n12\n\n12\n\n\n\n12->8\n\n\n\n\n\n12->0\n\n\n\n\n\n12->1\n\n\n\n\n\n9\n\n9\n\n\n\n12->9\n\n\n\n\n\n2\n\n2\n\n\n\n12->2\n\n\n\n\n\n3\n\n3\n\n\n\n12->3\n\n\n\n\n\n9->2\n\n\n\n\n\n9->3\n\n\n\n\n\n10\n\n10\n\n\n\n4\n\n4\n\n\n\n10->4\n\n\n\n\n\n5\n\n5\n\n\n\n10->5\n\n\n\n\n\n13\n\n13\n\n\n\n13->10\n\n\n\n\n\n13->4\n\n\n\n\n\n13->5\n\n\n\n\n\n11\n\n11\n\n\n\n13->11\n\n\n\n\n\n6\n\n6\n\n\n\n13->6\n\n\n\n\n\n7\n\n7\n\n\n\n13->7\n\n\n\n\n\n11->6\n\n\n\n\n\n11->7\n\n\n\n\n\n14\n\n14\n\n\n\n14->8\n\n\n\n\n\n14->12\n\n\n\n\n\n14->9\n\n\n\n\n\n14->10\n\n\n\n\n\n14->13\n\n\n\n\n\n14->11\n\n\n\n\n\n", + "text/plain": [ + "<__main__.show at 0x109c615b0>" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Define DiscreteKey pairs.\n", + "keys = [(j, 2) for j in range(15)]\n", + "\n", + "# Create thin-tree Bayesnet.\n", + "bayesNet = DiscreteBayesNet()\n", + "\n", + "\n", + "bayesNet.add(keys[0], P(keys[8], keys[12]), \"2/3 1/4 3/2 4/1\")\n", + "bayesNet.add(keys[1], P(keys[8], keys[12]), \"4/1 2/3 3/2 1/4\")\n", + "bayesNet.add(keys[2], P(keys[9], keys[12]), \"1/4 8/2 2/3 4/1\")\n", + "bayesNet.add(keys[3], P(keys[9], keys[12]), \"1/4 2/3 3/2 4/1\")\n", + "\n", + "bayesNet.add(keys[4], P(keys[10], keys[13]), \"2/3 1/4 3/2 4/1\")\n", + "bayesNet.add(keys[5], P(keys[10], keys[13]), \"4/1 2/3 3/2 1/4\")\n", + "bayesNet.add(keys[6], P(keys[11], keys[13]), \"1/4 3/2 2/3 4/1\")\n", + "bayesNet.add(keys[7], P(keys[11], keys[13]), \"1/4 2/3 3/2 4/1\")\n", + "\n", + "bayesNet.add(keys[8], P(keys[12], keys[14]), \"T 1/4 3/2 4/1\")\n", + "bayesNet.add(keys[9], P(keys[12], keys[14]), \"4/1 2/3 F 1/4\")\n", + "bayesNet.add(keys[10], P(keys[13], keys[14]), \"1/4 3/2 2/3 4/1\")\n", + "bayesNet.add(keys[11], P(keys[13], keys[14]), \"1/4 2/3 3/2 4/1\")\n", + "\n", + "bayesNet.add(keys[12], P(keys[14]), \"3/1 3/1\")\n", + "bayesNet.add(keys[13], P(keys[14]), \"1/3 3/1\")\n", + "\n", + "bayesNet.add(keys[14], P(), \"1/3\")\n", + "\n", + "show(bayesNet)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "DiscreteValues{0: 1, 1: 1, 2: 0, 3: 1, 4: 1, 5: 1, 6: 0, 7: 1, 8: 0, 9: 0, 10: 0, 11: 0, 12: 1, 13: 1, 14: 0}\n", + "DiscreteValues{0: 0, 1: 1, 2: 0, 3: 0, 4: 1, 5: 0, 6: 0, 7: 0, 8: 1, 9: 1, 10: 0, 11: 1, 12: 0, 13: 0, 14: 1}\n", + "DiscreteValues{0: 1, 1: 0, 2: 1, 3: 1, 4: 0, 5: 0, 6: 1, 7: 0, 8: 1, 9: 0, 10: 1, 11: 1, 12: 0, 13: 1, 14: 0}\n", + "DiscreteValues{0: 1, 1: 1, 2: 0, 3: 0, 4: 1, 5: 1, 6: 1, 7: 1, 8: 0, 9: 1, 10: 0, 11: 0, 12: 1, 13: 0, 14: 1}\n", + "DiscreteValues{0: 0, 1: 0, 2: 1, 3: 0, 4: 1, 5: 1, 6: 1, 7: 0, 8: 1, 9: 1, 10: 0, 11: 1, 12: 0, 13: 0, 14: 1}\n" + ] + } + ], + "source": [ + "# Sample Bayes net (needs conditionals added in elimination order!)\n", + "for i in range(5):\n", + " print(bayesNet.sample())" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n\n\n\n\n\n\nvar0\n\n0\n\n\n\nfactor0\n\n\n\n\nvar0--factor0\n\n\n\n\nvar1\n\n1\n\n\n\nfactor1\n\n\n\n\nvar1--factor1\n\n\n\n\nvar2\n\n2\n\n\n\nfactor2\n\n\n\n\nvar2--factor2\n\n\n\n\nvar3\n\n3\n\n\n\nfactor3\n\n\n\n\nvar3--factor3\n\n\n\n\nvar4\n\n4\n\n\n\nfactor4\n\n\n\n\nvar4--factor4\n\n\n\n\nvar5\n\n5\n\n\n\nfactor5\n\n\n\n\nvar5--factor5\n\n\n\n\nvar6\n\n6\n\n\n\nfactor6\n\n\n\n\nvar6--factor6\n\n\n\n\nvar7\n\n7\n\n\n\nfactor7\n\n\n\n\nvar7--factor7\n\n\n\n\nvar8\n\n8\n\n\n\nvar8--factor0\n\n\n\n\nvar8--factor1\n\n\n\n\nfactor8\n\n\n\n\nvar8--factor8\n\n\n\n\nvar9\n\n9\n\n\n\nvar9--factor2\n\n\n\n\nvar9--factor3\n\n\n\n\nfactor9\n\n\n\n\nvar9--factor9\n\n\n\n\nvar10\n\n10\n\n\n\nvar10--factor4\n\n\n\n\nvar10--factor5\n\n\n\n\nfactor10\n\n\n\n\nvar10--factor10\n\n\n\n\nvar11\n\n11\n\n\n\nvar11--factor6\n\n\n\n\nvar11--factor7\n\n\n\n\nfactor11\n\n\n\n\nvar11--factor11\n\n\n\n\nvar12\n\n12\n\n\n\nvar14\n\n14\n\n\n\nvar12--var14\n\n\n\n\nvar12--factor0\n\n\n\n\nvar12--factor1\n\n\n\n\nvar12--factor2\n\n\n\n\nvar12--factor3\n\n\n\n\nvar12--factor8\n\n\n\n\nvar12--factor9\n\n\n\n\nvar13\n\n13\n\n\n\nvar13--var14\n\n\n\n\nvar13--factor4\n\n\n\n\nvar13--factor5\n\n\n\n\nvar13--factor6\n\n\n\n\nvar13--factor7\n\n\n\n\nvar13--factor10\n\n\n\n\nvar13--factor11\n\n\n\n\nvar14--factor8\n\n\n\n\nvar14--factor9\n\n\n\n\nvar14--factor10\n\n\n\n\nvar14--factor11\n\n\n\n\nfactor14\n\n\n\n\nvar14--factor14\n\n\n\n\n", + "text/plain": [ + "<__main__.show at 0x109c61f10>" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Create a factor graph out of the Bayes net.\n", + "factorGraph = DiscreteFactorGraph(bayesNet)\n", + "show(factorGraph)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n\n\n\nG\n\n\n\n0\n\n8,12,14\n\n\n\n1\n\n0 : 8,12\n\n\n\n0->1\n\n\n\n\n\n2\n\n1 : 8,12\n\n\n\n0->2\n\n\n\n\n\n3\n\n9 : 12,14\n\n\n\n0->3\n\n\n\n\n\n6\n\n10,13 : 14\n\n\n\n0->6\n\n\n\n\n\n4\n\n2 : 9,12\n\n\n\n3->4\n\n\n\n\n\n5\n\n3 : 9,12\n\n\n\n3->5\n\n\n\n\n\n7\n\n4 : 10,13\n\n\n\n6->7\n\n\n\n\n\n8\n\n5 : 10,13\n\n\n\n6->8\n\n\n\n\n\n9\n\n11 : 13,14\n\n\n\n6->9\n\n\n\n\n\n10\n\n6 : 11,13\n\n\n\n9->10\n\n\n\n\n\n11\n\n7 : 11,13\n\n\n\n9->11\n\n\n\n\n\n", + "text/plain": [ + "<__main__.show at 0x109c61b50>" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Create a BayesTree out of the factor graph.\n", + "ordering = Ordering()\n", + "for j in range(15): ordering.push_back(j)\n", + "bayesTree = factorGraph.eliminateMultifrontal(ordering)\n", + "show(bayesTree)" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" + }, + "kernelspec": { + "display_name": "Python 3.8.9 64-bit", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/python/gtsam/notebooks/DiscreteSwitching.ipynb b/python/gtsam/notebooks/DiscreteSwitching.ipynb new file mode 100644 index 000000000..6872e78c8 --- /dev/null +++ b/python/gtsam/notebooks/DiscreteSwitching.ipynb @@ -0,0 +1,155 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# A Discrete Switching System\n", + "\n", + "A la MHS, but all discrete.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from gtsam import DiscreteBayesNet, DiscreteKeys, DiscreteFactorGraph, Ordering\n", + "from gtsam.symbol_shorthand import S\n", + "from gtsam.symbol_shorthand import M\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def P(*args):\n", + " \"\"\" Create a DiscreteKeys instances from a variable number of DiscreteKey pairs.\"\"\"\n", + " # TODO: We can make life easier by providing variable argument functions in C++ itself.\n", + " dks = DiscreteKeys()\n", + " for key in args:\n", + " dks.push_back(key)\n", + " return dks\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import graphviz\n", + "\n", + "\n", + "class show(graphviz.Source):\n", + " \"\"\" Display an object with a dot method as a graph.\"\"\"\n", + "\n", + " def __init__(self, obj):\n", + " \"\"\"Construct from object with 'dot' method.\"\"\"\n", + " # This small class takes an object, calls its dot function, and uses the\n", + " # resulting string to initialize a graphviz.Source instance. This in turn\n", + " # has a _repr_mimebundle_ method, which then renders it in the notebook.\n", + " super().__init__(obj.dot())\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "nrStates = 3\n", + "K = 5\n", + "\n", + "bayesNet = DiscreteBayesNet()\n", + "for k in range(1, K):\n", + " key = S(k), nrStates\n", + " key_plus = S(k+1), nrStates\n", + " mode = M(k), 2\n", + " bayesNet.add(key_plus, P(mode, key), \"9/1/0 1/8/1 0/1/9 1/9/0 0/1/9 9/0/1\")\n", + "\n", + "bayesNet" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "show(bayesNet)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a factor graph out of the Bayes net.\n", + "factorGraph = DiscreteFactorGraph(bayesNet)\n", + "show(factorGraph)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a BayesTree out of the factor graph.\n", + "ordering = Ordering()\n", + "# First eliminate \"continuous\" states in time order\n", + "for k in range(1, K+1):\n", + " ordering.push_back(S(k))\n", + "for k in range(1, K):\n", + " ordering.push_back(M(k))\n", + "print(ordering)\n", + "bayesTree = factorGraph.eliminateMultifrontal(ordering)\n", + "bayesTree" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "show(bayesTree)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "interpreter": { + "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" + }, + "kernelspec": { + "display_name": "Python 3.8.9 64-bit", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/python/gtsam/preamble/discrete.h b/python/gtsam/preamble/discrete.h new file mode 100644 index 000000000..608508c32 --- /dev/null +++ b/python/gtsam/preamble/discrete.h @@ -0,0 +1,16 @@ +/* Please refer to: + * https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html + * These are required to save one copy operation on Python calls. + * + * NOTES + * ================= + * + * `PYBIND11_MAKE_OPAQUE` will mark the type as "opaque" for the pybind11 + * automatic STL binding, such that the raw objects can be accessed in Python. + * Without this they will be automatically converted to a Python object, and all + * mutations on Python side will not be reflected on C++. + */ + +#include + +PYBIND11_MAKE_OPAQUE(gtsam::DiscreteKeys); diff --git a/python/gtsam/preamble/inference.h b/python/gtsam/preamble/inference.h new file mode 100644 index 000000000..320e0ac71 --- /dev/null +++ b/python/gtsam/preamble/inference.h @@ -0,0 +1,15 @@ +/* Please refer to: + * https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html + * These are required to save one copy operation on Python calls. + * + * NOTES + * ================= + * + * `PYBIND11_MAKE_OPAQUE` will mark the type as "opaque" for the pybind11 + * automatic STL binding, such that the raw objects can be accessed in Python. + * Without this they will be automatically converted to a Python object, and all + * mutations on Python side will not be reflected on C++. + */ + +#include + diff --git a/python/gtsam/preamble/slam.h b/python/gtsam/preamble/slam.h index 34dbb4b7a..f7bf5863c 100644 --- a/python/gtsam/preamble/slam.h +++ b/python/gtsam/preamble/slam.h @@ -15,3 +15,4 @@ PYBIND11_MAKE_OPAQUE( std::vector > >); PYBIND11_MAKE_OPAQUE( std::vector > >); +PYBIND11_MAKE_OPAQUE(gtsam::Rot3Vector); diff --git a/python/gtsam/specializations/discrete.h b/python/gtsam/specializations/discrete.h new file mode 100644 index 000000000..458a2ea4c --- /dev/null +++ b/python/gtsam/specializations/discrete.h @@ -0,0 +1,17 @@ +/* Please refer to: + * https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html + * These are required to save one copy operation on Python calls. + * + * NOTES + * ================= + * + * `py::bind_vector` and similar machinery gives the std container a Python-like + * interface, but without the `` copying mechanism. Combined + * with `PYBIND11_MAKE_OPAQUE` this allows the types to be modified with Python, + * and saves one copy operation. + */ + +// Seems this is not a good idea with inherited stl +//py::bind_vector>(m_, "DiscreteKeys"); + +py::bind_map(m_, "DiscreteValues"); diff --git a/python/gtsam/specializations/inference.h b/python/gtsam/specializations/inference.h new file mode 100644 index 000000000..22fe3beff --- /dev/null +++ b/python/gtsam/specializations/inference.h @@ -0,0 +1,13 @@ +/* Please refer to: + * https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html + * These are required to save one copy operation on Python calls. + * + * NOTES + * ================= + * + * `py::bind_vector` and similar machinery gives the std container a Python-like + * interface, but without the `` copying mechanism. Combined + * with `PYBIND11_MAKE_OPAQUE` this allows the types to be modified with Python, + * and saves one copy operation. + */ + diff --git a/python/gtsam/specializations/slam.h b/python/gtsam/specializations/slam.h index 198485a72..6a439c370 100644 --- a/python/gtsam/specializations/slam.h +++ b/python/gtsam/specializations/slam.h @@ -12,8 +12,9 @@ */ py::bind_vector< - std::vector > > >( + std::vector>>>( m_, "BetweenFactorPose3s"); py::bind_vector< - std::vector > > >( + std::vector>>>( m_, "BetweenFactorPose2s"); +py::bind_vector(m_, "Rot3Vector"); diff --git a/python/gtsam/tests/testEssentialMatrixConstraint.py b/python/gtsam/tests/testEssentialMatrixConstraint.py new file mode 100644 index 000000000..8439ad2e9 --- /dev/null +++ b/python/gtsam/tests/testEssentialMatrixConstraint.py @@ -0,0 +1,47 @@ +""" +GTSAM Copyright 2010-2019, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +visual_isam unit tests. +Author: Frank Dellaert & Pablo Alcantarilla +""" + +import unittest + +import gtsam +import numpy as np +from gtsam import (EssentialMatrix, EssentialMatrixConstraint, Point3, Pose3, + Rot3, Unit3, symbol) +from gtsam.utils.test_case import GtsamTestCase + + +class TestVisualISAMExample(GtsamTestCase): + def test_VisualISAMExample(self): + + # Create a factor + poseKey1 = symbol('x', 1) + poseKey2 = symbol('x', 2) + trueRotation = Rot3.RzRyRx(0.15, 0.15, -0.20) + trueTranslation = Point3(+0.5, -1.0, +1.0) + trueDirection = Unit3(trueTranslation) + E = EssentialMatrix(trueRotation, trueDirection) + model = gtsam.noiseModel.Isotropic.Sigma(5, 0.25) + factor = EssentialMatrixConstraint(poseKey1, poseKey2, E, model) + + # Create a linearization point at the zero-error point + pose1 = Pose3(Rot3.RzRyRx(0.00, -0.15, 0.30), Point3(-4.0, 7.0, -10.0)) + pose2 = Pose3( + Rot3.RzRyRx(0.179693265735950, 0.002945368776519, + 0.102274823253840), + Point3(-3.37493895, 6.14660244, -8.93650986)) + + expected = np.zeros((5, 1)) + actual = factor.evaluateError(pose1, pose2) + self.gtsamAssertEquals(actual, expected, 1e-8) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/gtsam/tests/test_Cal3Fisheye.py b/python/gtsam/tests/test_Cal3Fisheye.py index 298c6e57b..e54afc757 100644 --- a/python/gtsam/tests/test_Cal3Fisheye.py +++ b/python/gtsam/tests/test_Cal3Fisheye.py @@ -17,6 +17,15 @@ import gtsam from gtsam.utils.test_case import GtsamTestCase from gtsam.symbol_shorthand import K, L, P + +def ulp(ftype=np.float64): + """ + Unit in the last place of floating point datatypes + """ + f = np.finfo(ftype) + return f.tiny / ftype(1 << f.nmant) + + class TestCal3Fisheye(GtsamTestCase): @classmethod @@ -105,6 +114,71 @@ class TestCal3Fisheye(GtsamTestCase): score = graph.error(state) self.assertAlmostEqual(score, 0) + def test_jacobian_on_axis(self): + """Check of jacobian at optical axis""" + obj_point_on_axis = np.array([0, 0, 1]) + img_point = np.array([0, 0]) + f, z, H = self.evaluate_jacobian(obj_point_on_axis, img_point) + self.assertAlmostEqual(f, 0) + self.gtsamAssertEquals(z, np.zeros(2)) + self.gtsamAssertEquals(H @ H.T, 3*np.eye(2)) + + def test_jacobian_convergence(self): + """Test stability of jacobian close to optical axis""" + t = ulp(np.float64) + obj_point_close_to_axis = np.array([t, 0, 1]) + img_point = np.array([np.sqrt(t), 0]) + f, z, H = self.evaluate_jacobian(obj_point_close_to_axis, img_point) + self.assertAlmostEqual(f, 0) + self.gtsamAssertEquals(z, np.zeros(2)) + self.gtsamAssertEquals(H @ H.T, 3*np.eye(2)) + + # With a height of sqrt(ulp), this may cause an overflow + t = ulp(np.float64) + obj_point_close_to_axis = np.array([np.sqrt(t), 0, 1]) + img_point = np.array([np.sqrt(t), 0]) + f, z, H = self.evaluate_jacobian(obj_point_close_to_axis, img_point) + self.assertAlmostEqual(f, 0) + self.gtsamAssertEquals(z, np.zeros(2)) + self.gtsamAssertEquals(H @ H.T, 3*np.eye(2)) + + def test_scaling_factor(self): + """Check convergence of atan2(r, z)/r ~ 1/z for small r""" + r = ulp(np.float64) + s = np.arctan(r) / r + self.assertEqual(s, 1.0) + z = 1 + s = self.scaling_factor(r, z) + self.assertEqual(s, 1.0/z) + z = 2 + s = self.scaling_factor(r, z) + self.assertEqual(s, 1.0/z) + s = self.scaling_factor(2*r, z) + self.assertEqual(s, 1.0/z) + + @staticmethod + def scaling_factor(r, z): + """Projection factor theta/r for equidistant fisheye lens model""" + return np.arctan2(r, z) / r if r/z != 0 else 1.0/z + + @staticmethod + def evaluate_jacobian(obj_point, img_point): + """Evaluate jacobian at given object point""" + pose = gtsam.Pose3() + camera = gtsam.Cal3Fisheye() + state = gtsam.Values() + camera_key, pose_key, landmark_key = K(0), P(0), L(0) + state.insert_point3(landmark_key, obj_point) + state.insert_pose3(pose_key, pose) + g = gtsam.NonlinearFactorGraph() + noise_model = gtsam.noiseModel.Unit.Create(2) + factor = gtsam.GenericProjectionFactorCal3Fisheye(img_point, noise_model, pose_key, landmark_key, camera) + g.add(factor) + f = g.error(state) + gaussian_factor_graph = g.linearize(state) + H, z = gaussian_factor_graph.jacobian() + return f, z, H + @unittest.skip("triangulatePoint3 currently seems to require perspective projections.") def test_triangulation_skipped(self): """Estimate spatial point from image measurements""" diff --git a/python/gtsam/tests/test_Cal3Unified.py b/python/gtsam/tests/test_Cal3Unified.py index dab1ae446..bafbacfa4 100644 --- a/python/gtsam/tests/test_Cal3Unified.py +++ b/python/gtsam/tests/test_Cal3Unified.py @@ -117,6 +117,28 @@ class TestCal3Unified(GtsamTestCase): score = graph.error(state) self.assertAlmostEqual(score, 0) + def test_jacobian(self): + """Evaluate jacobian at optical axis""" + obj_point_on_axis = np.array([0, 0, 1]) + img_point = np.array([0.0, 0.0]) + pose = gtsam.Pose3() + camera = gtsam.Cal3Unified() + state = gtsam.Values() + camera_key, pose_key, landmark_key = K(0), P(0), L(0) + state.insert_cal3unified(camera_key, camera) + state.insert_point3(landmark_key, obj_point_on_axis) + state.insert_pose3(pose_key, pose) + g = gtsam.NonlinearFactorGraph() + noise_model = gtsam.noiseModel.Unit.Create(2) + factor = gtsam.GeneralSFMFactor2Cal3Unified(img_point, noise_model, pose_key, landmark_key, camera_key) + g.add(factor) + f = g.error(state) + gaussian_factor_graph = g.linearize(state) + H, z = gaussian_factor_graph.jacobian() + self.assertAlmostEqual(f, 0) + self.gtsamAssertEquals(z, np.zeros(2)) + self.gtsamAssertEquals(H @ H.T, 4*np.eye(2)) + @unittest.skip("triangulatePoint3 currently seems to require perspective projections.") def test_triangulation(self): """Estimate spatial point from image measurements""" diff --git a/python/gtsam/tests/test_DecisionTreeFactor.py b/python/gtsam/tests/test_DecisionTreeFactor.py new file mode 100644 index 000000000..0499e7215 --- /dev/null +++ b/python/gtsam/tests/test_DecisionTreeFactor.py @@ -0,0 +1,98 @@ +""" +GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for DecisionTreeFactors. +Author: Frank Dellaert +""" + +# pylint: disable=no-name-in-module, invalid-name + +import unittest + +from gtsam import DecisionTreeFactor, DiscreteValues, DiscreteDistribution, Ordering +from gtsam.utils.test_case import GtsamTestCase + + +class TestDecisionTreeFactor(GtsamTestCase): + """Tests for DecisionTreeFactors.""" + + def setUp(self): + self.A = (12, 3) + self.B = (5, 2) + self.factor = DecisionTreeFactor([self.A, self.B], "1 2 3 4 5 6") + + def test_enumerate(self): + actual = self.factor.enumerate() + _, values = zip(*actual) + self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + + def test_multiplication(self): + """Test whether multiplication works with overloading.""" + v0 = (0, 2) + v1 = (1, 2) + v2 = (2, 2) + + # Multiply with a DiscreteDistribution, i.e., Bayes Law! + prior = DiscreteDistribution(v1, [1, 3]) + f1 = DecisionTreeFactor([v0, v1], "1 2 3 4") + expected = DecisionTreeFactor([v0, v1], "0.25 1.5 0.75 3") + self.gtsamAssertEquals(DecisionTreeFactor(prior) * f1, expected) + self.gtsamAssertEquals(f1 * prior, expected) + + # Multiply two factors + f2 = DecisionTreeFactor([v1, v2], "5 6 7 8") + actual = f1 * f2 + expected2 = DecisionTreeFactor([v0, v1, v2], "5 6 14 16 15 18 28 32") + self.gtsamAssertEquals(actual, expected2) + + def test_methods(self): + """Test whether we can call methods in python.""" + # double operator()(const DiscreteValues& values) const; + values = DiscreteValues() + values[self.A[0]] = 0 + values[self.B[0]] = 0 + self.assertIsInstance(self.factor(values), float) + + # size_t cardinality(Key j) const; + self.assertIsInstance(self.factor.cardinality(self.A[0]), int) + + # DecisionTreeFactor operator/(const DecisionTreeFactor& f) const; + self.assertIsInstance(self.factor / self.factor, DecisionTreeFactor) + + # DecisionTreeFactor* sum(size_t nrFrontals) const; + self.assertIsInstance(self.factor.sum(1), DecisionTreeFactor) + + # DecisionTreeFactor* sum(const Ordering& keys) const; + ordering = Ordering() + ordering.push_back(self.A[0]) + self.assertIsInstance(self.factor.sum(ordering), DecisionTreeFactor) + + # DecisionTreeFactor* max(size_t nrFrontals) const; + self.assertIsInstance(self.factor.max(1), DecisionTreeFactor) + + def test_markdown(self): + """Test whether the _repr_markdown_ method.""" + + expected = \ + "|A|B|value|\n" \ + "|:-:|:-:|:-:|\n" \ + "|0|0|1|\n" \ + "|0|1|2|\n" \ + "|1|0|3|\n" \ + "|1|1|4|\n" \ + "|2|0|5|\n" \ + "|2|1|6|\n" + + def formatter(x: int): + return "A" if x == 12 else "B" + + actual = self.factor._repr_markdown_(formatter) + self.assertEqual(actual, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/gtsam/tests/test_DiscreteBayesNet.py b/python/gtsam/tests/test_DiscreteBayesNet.py new file mode 100644 index 000000000..3ae3b625c --- /dev/null +++ b/python/gtsam/tests/test_DiscreteBayesNet.py @@ -0,0 +1,131 @@ +""" +GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Discrete Bayes Nets. +Author: Frank Dellaert +""" + +# pylint: disable=no-name-in-module, invalid-name + +import unittest + +from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph, + DiscreteKeys, DiscreteDistribution, DiscreteValues, Ordering) +from gtsam.utils.test_case import GtsamTestCase + +# Some keys: +Asia = (0, 2) +Smoking = (4, 2) +Tuberculosis = (3, 2) +LungCancer = (6, 2) + +Bronchitis = (7, 2) +Either = (5, 2) +XRay = (2, 2) +Dyspnea = (1, 2) + + +class TestDiscreteBayesNet(GtsamTestCase): + """Tests for Discrete Bayes Nets.""" + + def test_constructor(self): + """Test constructing a Bayes net.""" + + bayesNet = DiscreteBayesNet() + Parent, Child = (0, 2), (1, 2) + empty = DiscreteKeys() + prior = DiscreteConditional(Parent, empty, "6/4") + bayesNet.add(prior) + + parents = DiscreteKeys() + parents.push_back(Parent) + conditional = DiscreteConditional(Child, parents, "7/3 8/2") + bayesNet.add(conditional) + + # Check conversion to factor graph: + fg = DiscreteFactorGraph(bayesNet) + self.assertEqual(fg.size(), 2) + self.assertEqual(fg.at(1).size(), 2) + + def test_Asia(self): + """Test full Asia example.""" + + asia = DiscreteBayesNet() + asia.add(Asia, "99/1") + asia.add(Smoking, "50/50") + + asia.add(Tuberculosis, [Asia], "99/1 95/5") + asia.add(LungCancer, [Smoking], "99/1 90/10") + asia.add(Bronchitis, [Smoking], "70/30 40/60") + + asia.add(Either, [Tuberculosis, LungCancer], "F T T T") + + asia.add(XRay, [Either], "95/5 2/98") + asia.add(Dyspnea, [Either, Bronchitis], "9/1 2/8 3/7 1/9") + + # Convert to factor graph + fg = DiscreteFactorGraph(asia) + + # Create solver and eliminate + ordering = Ordering() + for j in range(8): + ordering.push_back(j) + chordal = fg.eliminateSequential(ordering) + expected2 = DiscreteDistribution(Bronchitis, "11/9") + self.gtsamAssertEquals(chordal.at(7), expected2) + + # solve + actualMPE = fg.optimize() + expectedMPE = DiscreteValues() + for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]: + expectedMPE[key[0]] = 0 + self.assertEqual(list(actualMPE.items()), + list(expectedMPE.items())) + + # Check value for MPE is the same + self.assertAlmostEqual(asia(actualMPE), fg(actualMPE)) + + # add evidence, we were in Asia and we have dyspnea + fg.add(Asia, "0 1") + fg.add(Dyspnea, "0 1") + + # solve again, now with evidence + actualMPE2 = fg.optimize() + expectedMPE2 = DiscreteValues() + for key in [XRay, Tuberculosis, Either, LungCancer]: + expectedMPE2[key[0]] = 0 + for key in [Asia, Dyspnea, Smoking, Bronchitis]: + expectedMPE2[key[0]] = 1 + self.assertEqual(list(actualMPE2.items()), + list(expectedMPE2.items())) + + # now sample from it + chordal2 = fg.eliminateSequential(ordering) + actualSample = chordal2.sample() + self.assertEqual(len(actualSample), 8) + + def test_fragment(self): + """Test sampling and optimizing for Asia fragment.""" + + # Create a reverse-topologically sorted fragment: + fragment = DiscreteBayesNet() + fragment.add(Either, [Tuberculosis, LungCancer], "F T T T") + fragment.add(Tuberculosis, [Asia], "99/1 95/5") + fragment.add(LungCancer, [Smoking], "99/1 90/10") + + # Create assignment with missing values: + given = DiscreteValues() + for key in [Asia, Smoking]: + given[key[0]] = 0 + + # Now sample from fragment: + actual = fragment.sample(given) + self.assertEqual(len(actual), 5) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/gtsam/tests/test_DiscreteBayesTree.dot b/python/gtsam/tests/test_DiscreteBayesTree.dot new file mode 100644 index 000000000..d7cf7d9bc --- /dev/null +++ b/python/gtsam/tests/test_DiscreteBayesTree.dot @@ -0,0 +1,25 @@ +digraph G{ +0[label="8,12,14"]; +0->1 +1[label="0 : 8,12"]; +0->2 +2[label="1 : 8,12"]; +0->3 +3[label="9 : 12,14"]; +3->4 +4[label="2 : 9,12"]; +3->5 +5[label="3 : 9,12"]; +0->6 +6[label="10,13 : 14"]; +6->7 +7[label="4 : 10,13"]; +6->8 +8[label="5 : 10,13"]; +6->9 +9[label="11 : 13,14"]; +9->10 +10[label="6 : 11,13"]; +9->11 +11[label="7 : 11,13"]; +} \ No newline at end of file diff --git a/python/gtsam/tests/test_DiscreteBayesTree.py b/python/gtsam/tests/test_DiscreteBayesTree.py new file mode 100644 index 000000000..b1ed4fe69 --- /dev/null +++ b/python/gtsam/tests/test_DiscreteBayesTree.py @@ -0,0 +1,79 @@ +""" +GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Discrete Bayes trees. +Author: Frank Dellaert +""" + +# pylint: disable=no-name-in-module, invalid-name + +import unittest + +from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique, + DiscreteConditional, DiscreteFactorGraph, Ordering) +from gtsam.utils.test_case import GtsamTestCase + + +class TestDiscreteBayesNet(GtsamTestCase): + """Tests for Discrete Bayes Nets.""" + + def test_elimination(self): + """Test Multifrontal elimination.""" + + # Define DiscreteKey pairs. + keys = [(j, 2) for j in range(15)] + + # Create thin-tree Bayesnet. + bayesNet = DiscreteBayesNet() + + bayesNet.add(keys[0], [keys[8], keys[12]], "2/3 1/4 3/2 4/1") + bayesNet.add(keys[1], [keys[8], keys[12]], "4/1 2/3 3/2 1/4") + bayesNet.add(keys[2], [keys[9], keys[12]], "1/4 8/2 2/3 4/1") + bayesNet.add(keys[3], [keys[9], keys[12]], "1/4 2/3 3/2 4/1") + + bayesNet.add(keys[4], [keys[10], keys[13]], "2/3 1/4 3/2 4/1") + bayesNet.add(keys[5], [keys[10], keys[13]], "4/1 2/3 3/2 1/4") + bayesNet.add(keys[6], [keys[11], keys[13]], "1/4 3/2 2/3 4/1") + bayesNet.add(keys[7], [keys[11], keys[13]], "1/4 2/3 3/2 4/1") + + bayesNet.add(keys[8], [keys[12], keys[14]], "T 1/4 3/2 4/1") + bayesNet.add(keys[9], [keys[12], keys[14]], "4/1 2/3 F 1/4") + bayesNet.add(keys[10], [keys[13], keys[14]], "1/4 3/2 2/3 4/1") + bayesNet.add(keys[11], [keys[13], keys[14]], "1/4 2/3 3/2 4/1") + + bayesNet.add(keys[12], [keys[14]], "3/1 3/1") + bayesNet.add(keys[13], [keys[14]], "1/3 3/1") + + bayesNet.add(keys[14], "1/3") + + # Create a factor graph out of the Bayes net. + factorGraph = DiscreteFactorGraph(bayesNet) + + # Create a BayesTree out of the factor graph. + ordering = Ordering() + for j in range(15): + ordering.push_back(j) + bayesTree = factorGraph.eliminateMultifrontal(ordering) + + # Uncomment these for visualization: + # print(bayesTree) + # for key in range(15): + # bayesTree[key].printSignature() + # bayesTree.saveGraph("test_DiscreteBayesTree.dot") + + self.assertFalse(bayesTree.empty()) + self.assertEqual(12, bayesTree.size()) + + # The root is P( 8 12 14), we can retrieve it by key: + root = bayesTree[8] + self.assertIsInstance(root, DiscreteBayesTreeClique) + self.assertTrue(root.isRoot()) + self.assertIsInstance(root.conditional(), DiscreteConditional) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/gtsam/tests/test_DiscreteConditional.py b/python/gtsam/tests/test_DiscreteConditional.py new file mode 100644 index 000000000..241a5f0be --- /dev/null +++ b/python/gtsam/tests/test_DiscreteConditional.py @@ -0,0 +1,124 @@ +""" +GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Discrete Conditionals. +Author: Varun Agrawal +""" + +# pylint: disable=no-name-in-module, invalid-name + +import unittest + +from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys +from gtsam.utils.test_case import GtsamTestCase + +# Some DiscreteKeys for binary variables: +A = 0, 2 +B = 1, 2 +C = 2, 2 +D = 4, 2 +E = 3, 2 + + +class TestDiscreteConditional(GtsamTestCase): + """Tests for Discrete Conditionals.""" + + def test_single_value_versions(self): + X = (0, 2) + Y = (1, 3) + conditional = DiscreteConditional(X, [Y], "2/8 4/6 5/5") + + actual0 = conditional.likelihood(0) + expected0 = DecisionTreeFactor(Y, "0.2 0.4 0.5") + self.gtsamAssertEquals(actual0, expected0, 1e-9) + + actual1 = conditional.likelihood(1) + expected1 = DecisionTreeFactor(Y, "0.8 0.6 0.5") + self.gtsamAssertEquals(actual1, expected1, 1e-9) + + actual = conditional.sample(2) + self.assertIsInstance(actual, int) + + def test_multiply(self): + """Check calculation of joint P(A,B)""" + conditional = DiscreteConditional(A, [B], "1/2 2/1") + prior = DiscreteConditional(B, "1/2") + + # P(A,B) = P(A|B) * P(B) = P(B) * P(A|B) + for actual in [prior * conditional, conditional * prior]: + self.assertEqual(2, actual.nrFrontals()) + for v, value in actual.enumerate(): + self.assertAlmostEqual(actual(v), conditional(v) * prior(v)) + + def test_multiply2(self): + """Check calculation of conditional joint P(A,B|C)""" + A_given_B = DiscreteConditional(A, [B], "1/3 3/1") + B_given_C = DiscreteConditional(B, [C], "1/3 3/1") + + # P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B) + for actual in [A_given_B * B_given_C, B_given_C * A_given_B]: + self.assertEqual(2, actual.nrFrontals()) + self.assertEqual(1, actual.nrParents()) + for v, value in actual.enumerate(): + self.assertAlmostEqual(actual(v), A_given_B(v) * B_given_C(v)) + + def test_multiply4(self): + """Check calculation of joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E)""" + A_given_B = DiscreteConditional(A, [B], "1/3 3/1") + B_given_D = DiscreteConditional(B, [D], "1/3 3/1") + AB_given_D = A_given_B * B_given_D + C_given_DE = DiscreteConditional(C, [D, E], "4/1 1/1 1/1 1/4") + + # P(A,B,C|D,E) = P(A,B|D) P(C|D,E) = P(C|D,E) P(A,B|D) + for actual in [AB_given_D * C_given_DE, C_given_DE * AB_given_D]: + self.assertEqual(3, actual.nrFrontals()) + self.assertEqual(2, actual.nrParents()) + for v, value in actual.enumerate(): + self.assertAlmostEqual( + actual(v), AB_given_D(v) * C_given_DE(v)) + + def test_marginals(self): + conditional = DiscreteConditional(A, [B], "1/2 2/1") + prior = DiscreteConditional(B, "1/2") + pAB = prior * conditional + self.gtsamAssertEquals(prior, pAB.marginal(B[0])) + + pA = DiscreteConditional(A, "5/4") + self.gtsamAssertEquals(pA, pAB.marginal(A[0])) + + def test_markdown(self): + """Test whether the _repr_markdown_ method.""" + + A = (2, 2) + B = (1, 2) + C = (0, 3) + parents = DiscreteKeys() + parents.push_back(B) + parents.push_back(C) + + conditional = DiscreteConditional(A, parents, + "0/1 1/3 1/1 3/1 0/1 1/0") + expected = " *P(A|B,C):*\n\n" \ + "|*B*|*C*|0|1|\n" \ + "|:-:|:-:|:-:|:-:|\n" \ + "|0|0|0|1|\n" \ + "|0|1|0.25|0.75|\n" \ + "|0|2|0.5|0.5|\n" \ + "|1|0|0.75|0.25|\n" \ + "|1|1|0|1|\n" \ + "|1|2|1|0|\n" + + def formatter(x: int): + names = ["C", "B", "A"] + return names[x] + + actual = conditional._repr_markdown_(formatter) + self.assertEqual(actual, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/gtsam/tests/test_DiscreteDistribution.py b/python/gtsam/tests/test_DiscreteDistribution.py new file mode 100644 index 000000000..3986bf2df --- /dev/null +++ b/python/gtsam/tests/test_DiscreteDistribution.py @@ -0,0 +1,69 @@ +""" +GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Discrete Priors. +Author: Frank Dellaert +""" + +# pylint: disable=no-name-in-module, invalid-name + +import unittest + +import numpy as np +from gtsam import DecisionTreeFactor, DiscreteKeys, DiscreteDistribution +from gtsam.utils.test_case import GtsamTestCase + +X = 0, 2 + + +class TestDiscreteDistribution(GtsamTestCase): + """Tests for Discrete Priors.""" + + def test_constructor(self): + """Test various constructors.""" + keys = DiscreteKeys() + keys.push_back(X) + f = DecisionTreeFactor(keys, "0.4 0.6") + expected = DiscreteDistribution(f) + + actual = DiscreteDistribution(X, "2/3") + self.gtsamAssertEquals(actual, expected) + + actual2 = DiscreteDistribution(X, [0.4, 0.6]) + self.gtsamAssertEquals(actual2, expected) + + def test_operator(self): + prior = DiscreteDistribution(X, "2/3") + self.assertAlmostEqual(prior(0), 0.4) + self.assertAlmostEqual(prior(1), 0.6) + + def test_pmf(self): + prior = DiscreteDistribution(X, "2/3") + expected = np.array([0.4, 0.6]) + np.testing.assert_allclose(expected, prior.pmf()) + + def test_sample(self): + prior = DiscreteDistribution(X, "2/3") + actual = prior.sample() + self.assertIsInstance(actual, int) + + def test_markdown(self): + """Test the _repr_markdown_ method.""" + + prior = DiscreteDistribution(X, "2/3") + expected = " *P(0):*\n\n" \ + "|0|value|\n" \ + "|:-:|:-:|\n" \ + "|0|0.4|\n" \ + "|1|0.6|\n" \ + + actual = prior._repr_markdown_() + self.assertEqual(actual, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/gtsam/tests/test_DiscreteFactorGraph.py b/python/gtsam/tests/test_DiscreteFactorGraph.py new file mode 100644 index 000000000..ef85fc753 --- /dev/null +++ b/python/gtsam/tests/test_DiscreteFactorGraph.py @@ -0,0 +1,160 @@ +""" +GTSAM Copyright 2010-2019, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Discrete Factor Graphs. +Author: Frank Dellaert +""" + +# pylint: disable=no-name-in-module, invalid-name + +import unittest + +from gtsam import DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering +from gtsam.utils.test_case import GtsamTestCase + +OrderingType = Ordering.OrderingType + + +class TestDiscreteFactorGraph(GtsamTestCase): + """Tests for Discrete Factor Graphs.""" + + def test_evaluation(self): + """Test constructing and evaluating a discrete factor graph.""" + + # Three keys + P1 = (0, 2) + P2 = (1, 2) + P3 = (2, 3) + + # Create the DiscreteFactorGraph + graph = DiscreteFactorGraph() + + # Add two unary factors (priors) + graph.add(P1, [0.9, 0.3]) + graph.add(P2, "0.9 0.6") + + # Add a binary factor + graph.add([P1, P2], "4 1 10 4") + + # Instantiate Values + assignment = DiscreteValues() + assignment[0] = 1 + assignment[1] = 1 + + # Check if graph evaluation works ( 0.3*0.6*4 ) + self.assertAlmostEqual(.72, graph(assignment)) + + # Create a new test with third node and adding unary and ternary factor + graph.add(P3, "0.9 0.2 0.5") + keys = DiscreteKeys() + keys.push_back(P1) + keys.push_back(P2) + keys.push_back(P3) + graph.add(keys, "1 2 3 4 5 6 7 8 9 10 11 12") + + # Below assignment selects the 8th index in the ternary factor table + assignment[0] = 1 + assignment[1] = 0 + assignment[2] = 1 + + # Check if graph evaluation works (0.3*0.9*1*0.2*8) + self.assertAlmostEqual(4.32, graph(assignment)) + + # Below assignment selects the 3rd index in the ternary factor table + assignment[0] = 0 + assignment[1] = 1 + assignment[2] = 0 + + # Check if graph evaluation works (0.9*0.6*1*0.9*4) + self.assertAlmostEqual(1.944, graph(assignment)) + + # Check if graph product works + product = graph.product() + self.assertAlmostEqual(1.944, product(assignment)) + + def test_optimize(self): + """Test constructing and optizing a discrete factor graph.""" + + # Three keys + C = (0, 2) + B = (1, 2) + A = (2, 2) + + # A simple factor graph (A)-fAC-(C)-fBC-(B) + # with smoothness priors + graph = DiscreteFactorGraph() + graph.add([A, C], "3 1 1 3") + graph.add([C, B], "3 1 1 3") + + # Test optimization + expectedValues = DiscreteValues() + expectedValues[0] = 0 + expectedValues[1] = 0 + expectedValues[2] = 0 + actualValues = graph.optimize() + self.assertEqual(list(actualValues.items()), + list(expectedValues.items())) + + def test_MPE(self): + """Test maximum probable explanation (MPE): same as optimize.""" + + # Declare a bunch of keys + C, A, B = (0, 2), (1, 2), (2, 2) + + # Create Factor graph + graph = DiscreteFactorGraph() + graph.add([C, A], "0.2 0.8 0.3 0.7") + graph.add([C, B], "0.1 0.9 0.4 0.6") + + # We know MPE + mpe = DiscreteValues() + mpe[0] = 0 + mpe[1] = 1 + mpe[2] = 1 + + # Use maxProduct + dag = graph.maxProduct(OrderingType.COLAMD) + actualMPE = dag.argmax() + self.assertEqual(list(actualMPE.items()), + list(mpe.items())) + + # All in one + actualMPE2 = graph.optimize() + self.assertEqual(list(actualMPE2.items()), + list(mpe.items())) + + def test_sumProduct(self): + """Test sumProduct.""" + + # Declare a bunch of keys + C, A, B = (0, 2), (1, 2), (2, 2) + + # Create Factor graph + graph = DiscreteFactorGraph() + graph.add([C, A], "0.2 0.8 0.3 0.7") + graph.add([C, B], "0.1 0.9 0.4 0.6") + + # We know MPE + mpe = DiscreteValues() + mpe[0] = 0 + mpe[1] = 1 + mpe[2] = 1 + + # Use default sumProduct + bayesNet = graph.sumProduct() + mpeProbability = bayesNet(mpe) + self.assertAlmostEqual(mpeProbability, 0.36) # regression + + # Use sumProduct + for ordering_type in [OrderingType.COLAMD, OrderingType.METIS, OrderingType.NATURAL, + OrderingType.CUSTOM]: + bayesNet = graph.sumProduct(ordering_type) + self.assertEqual(bayesNet(mpe), mpeProbability) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/gtsam/tests/test_GraphvizFormatting.py b/python/gtsam/tests/test_GraphvizFormatting.py new file mode 100644 index 000000000..5962366ef --- /dev/null +++ b/python/gtsam/tests/test_GraphvizFormatting.py @@ -0,0 +1,135 @@ +""" +See LICENSE for the license information + +Unit tests for Graphviz formatting of NonlinearFactorGraph. +Author: senselessDev (contact by mentioning on GitHub, e.g. in PR#1059) +""" + +# pylint: disable=no-member, invalid-name + +import unittest +import textwrap + +import numpy as np + +import gtsam +from gtsam.utils.test_case import GtsamTestCase + + +class TestGraphvizFormatting(GtsamTestCase): + """Tests for saving NonlinearFactorGraph to GraphViz format.""" + + def setUp(self): + self.graph = gtsam.NonlinearFactorGraph() + + odometry = gtsam.Pose2(2.0, 0.0, 0.0) + odometryNoise = gtsam.noiseModel.Diagonal.Sigmas( + np.array([0.2, 0.2, 0.1])) + self.graph.add(gtsam.BetweenFactorPose2(0, 1, odometry, odometryNoise)) + self.graph.add(gtsam.BetweenFactorPose2(1, 2, odometry, odometryNoise)) + + self.values = gtsam.Values() + self.values.insert_pose2(0, gtsam.Pose2(0., 0., 0.)) + self.values.insert_pose2(1, gtsam.Pose2(2., 0., 0.)) + self.values.insert_pose2(2, gtsam.Pose2(4., 0., 0.)) + + def test_default(self): + """Test with default GraphvizFormatting""" + expected_result = """\ + graph { + size="5,5"; + + var0[label="0", pos="0,0!"]; + var1[label="1", pos="0,2!"]; + var2[label="2", pos="0,4!"]; + + factor0[label="", shape=point]; + var0--factor0; + var1--factor0; + factor1[label="", shape=point]; + var1--factor1; + var2--factor1; + } + """ + + self.assertEqual(self.graph.dot(self.values), + textwrap.dedent(expected_result)) + + def test_swapped_axes(self): + """Test with user-defined GraphvizFormatting swapping x and y""" + expected_result = """\ + graph { + size="5,5"; + + var0[label="0", pos="0,0!"]; + var1[label="1", pos="2,0!"]; + var2[label="2", pos="4,0!"]; + + factor0[label="", shape=point]; + var0--factor0; + var1--factor0; + factor1[label="", shape=point]; + var1--factor1; + var2--factor1; + } + """ + + graphviz_formatting = gtsam.GraphvizFormatting() + graphviz_formatting.paperHorizontalAxis = gtsam.GraphvizFormatting.Axis.X + graphviz_formatting.paperVerticalAxis = gtsam.GraphvizFormatting.Axis.Y + self.assertEqual(self.graph.dot(self.values, + formatting=graphviz_formatting), + textwrap.dedent(expected_result)) + + def test_factor_points(self): + """Test with user-defined GraphvizFormatting without factor points""" + expected_result = """\ + graph { + size="5,5"; + + var0[label="0", pos="0,0!"]; + var1[label="1", pos="0,2!"]; + var2[label="2", pos="0,4!"]; + + var0--var1; + var1--var2; + } + """ + + graphviz_formatting = gtsam.GraphvizFormatting() + graphviz_formatting.plotFactorPoints = False + + self.assertEqual(self.graph.dot(self.values, + formatting=graphviz_formatting), + textwrap.dedent(expected_result)) + + def test_width_height(self): + """Test with user-defined GraphvizFormatting for width and height""" + expected_result = """\ + graph { + size="20,10"; + + var0[label="0", pos="0,0!"]; + var1[label="1", pos="0,2!"]; + var2[label="2", pos="0,4!"]; + + factor0[label="", shape=point]; + var0--factor0; + var1--factor0; + factor1[label="", shape=point]; + var1--factor1; + var2--factor1; + } + """ + + graphviz_formatting = gtsam.GraphvizFormatting() + graphviz_formatting.figureWidthInches = 20 + graphviz_formatting.figureHeightInches = 10 + + self.assertEqual(self.graph.dot(self.values, + formatting=graphviz_formatting), + textwrap.dedent(expected_result)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/gtsam/tests/test_KarcherMeanFactor.py b/python/gtsam/tests/test_KarcherMeanFactor.py index a315a506c..f4ec64283 100644 --- a/python/gtsam/tests/test_KarcherMeanFactor.py +++ b/python/gtsam/tests/test_KarcherMeanFactor.py @@ -15,27 +15,15 @@ import unittest import gtsam import numpy as np +from gtsam import Rot3 from gtsam.utils.test_case import GtsamTestCase KEY = 0 MODEL = gtsam.noiseModel.Unit.Create(3) -def find_Karcher_mean_Rot3(rotations): - """Find the Karcher mean of given values.""" - # Cost function C(R) = \sum PriorFactor(R_i)::error(R) - # No closed form solution. - graph = gtsam.NonlinearFactorGraph() - for R in rotations: - graph.add(gtsam.PriorFactorRot3(KEY, R, MODEL)) - initial = gtsam.Values() - initial.insert(KEY, gtsam.Rot3()) - result = gtsam.GaussNewtonOptimizer(graph, initial).optimize() - return result.atRot3(KEY) - - # Rot3 version -R = gtsam.Rot3.Expmap(np.array([0.1, 0, 0])) +R = Rot3.Expmap(np.array([0.1, 0, 0])) class TestKarcherMean(GtsamTestCase): @@ -43,11 +31,23 @@ class TestKarcherMean(GtsamTestCase): def test_find(self): # Check that optimizing for Karcher mean (which minimizes Between distance) # gets correct result. - rotations = {R, R.inverse()} - expected = gtsam.Rot3() - actual = find_Karcher_mean_Rot3(rotations) + rotations = gtsam.Rot3Vector([R, R.inverse()]) + expected = Rot3() + actual = gtsam.FindKarcherMean(rotations) self.gtsamAssertEquals(expected, actual) + def test_find_karcher_mean_identity(self): + """Averaging 3 identity rotations should yield the identity.""" + a1Rb1 = Rot3() + a2Rb2 = Rot3() + a3Rb3 = Rot3() + + aRb_list = gtsam.Rot3Vector([a1Rb1, a2Rb2, a3Rb3]) + aRb_expected = Rot3() + + aRb = gtsam.FindKarcherMean(aRb_list) + self.gtsamAssertEquals(aRb, aRb_expected) + def test_factor(self): """Check that the InnerConstraint factor leaves the mean unchanged.""" # Make a graph with two variables, one between, and one InnerConstraint @@ -66,11 +66,11 @@ class TestKarcherMean(GtsamTestCase): initial = gtsam.Values() initial.insert(1, R.inverse()) initial.insert(2, R) - expected = find_Karcher_mean_Rot3([R, R.inverse()]) + expected = Rot3() result = gtsam.GaussNewtonOptimizer(graph, initial).optimize() - actual = find_Karcher_mean_Rot3( - [result.atRot3(1), result.atRot3(2)]) + actual = gtsam.FindKarcherMean( + gtsam.Rot3Vector([result.atRot3(1), result.atRot3(2)])) self.gtsamAssertEquals(expected, actual) self.gtsamAssertEquals( R12, result.atRot3(1).between(result.atRot3(2))) diff --git a/python/gtsam/tests/test_Sim3.py b/python/gtsam/tests/test_Sim3.py index 001321e2c..c00a36435 100644 --- a/python/gtsam/tests/test_Sim3.py +++ b/python/gtsam/tests/test_Sim3.py @@ -10,6 +10,7 @@ Author: John Lambert """ # pylint: disable=no-name-in-module import unittest +from typing import List, Optional import numpy as np @@ -129,6 +130,587 @@ class TestSim3(GtsamTestCase): for aTi, bTi in zip(aTi_list, bTi_list): self.gtsamAssertEquals(aTi, aSb.transformFrom(bTi)) + def test_align_via_Sim3_to_poses_skydio32(self) -> None: + """Ensure scale estimate of Sim(3) object is non-negative. + + Comes from real data (from Skydio-32 Crane Mast sequence with a SIFT front-end). + """ + poses_gt = [ + Pose3( + Rot3( + [ + [0.696305769, -0.0106830792, -0.717665705], + [0.00546412488, 0.999939148, -0.00958346857], + [0.717724415, 0.00275160848, 0.696321772], + ] + ), + Point3(5.83077801, -0.94815149, 0.397751679), + ), + Pose3( + Rot3( + [ + [0.692272397, -0.00529704529, -0.721616549], + [0.00634689669, 0.999979075, -0.00125157022], + [0.721608079, -0.0037136016, 0.692291531], + ] + ), + Point3(5.03853323, -0.97547405, -0.348177392), + ), + Pose3( + Rot3( + [ + [0.945991981, -0.00633548292, -0.324128225], + [0.00450436485, 0.999969379, -0.00639931046], + [0.324158843, 0.00459370582, 0.945991552], + ] + ), + Point3(4.13186176, -0.956364218, -0.796029527), + ), + Pose3( + Rot3( + [ + [0.999553623, -0.00346470207, -0.0296740626], + [0.00346104216, 0.999993995, -0.00017469881], + [0.0296744897, 7.19175654e-05, 0.999559612], + ] + ), + Point3(3.1113898, -0.928583423, -0.90539337), + ), + Pose3( + Rot3( + [ + [0.967850252, -0.00144846042, 0.251522892], + [0.000254511591, 0.999988546, 0.00477934325], + [-0.251526934, -0.00456167299, 0.967839535], + ] + ), + Point3(2.10584013, -0.921303194, -0.809322971), + ), + Pose3( + Rot3( + [ + [0.969854065, 0.000629052774, 0.243685716], + [0.000387180179, 0.999991428, -0.00412234326], + [-0.243686221, 0.00409242166, 0.969845508], + ] + ), + Point3(1.0753788, -0.913035975, -0.616584091), + ), + Pose3( + Rot3( + [ + [0.998189342, 0.00110235337, 0.0601400045], + [-0.00110890447, 0.999999382, 7.55559042e-05], + [-0.060139884, -0.000142108649, 0.998189948], + ] + ), + Point3(0.029993558, -0.951495122, -0.425525143), + ), + Pose3( + Rot3( + [ + [0.999999996, -2.62868666e-05, -8.67178281e-05], + [2.62791334e-05, 0.999999996, -8.91767396e-05], + [8.67201719e-05, 8.91744604e-05, 0.999999992], + ] + ), + Point3(-0.973569417, -0.936340994, -0.253464928), + ), + Pose3( + Rot3( + [ + [0.99481227, -0.00153645011, 0.101716252], + [0.000916919443, 0.999980747, 0.00613725239], + [-0.101723724, -0.00601214847, 0.994794525], + ] + ), + Point3(-2.02071256, -0.955446292, -0.240707879), + ), + Pose3( + Rot3( + [ + [0.89795602, -0.00978591184, 0.43997636], + [0.00645921401, 0.999938116, 0.00905779513], + [-0.440037771, -0.00529159974, 0.89796366], + ] + ), + Point3(-2.94096695, -0.939974858, 0.0934225593), + ), + Pose3( + Rot3( + [ + [0.726299119, -0.00916784876, 0.687318077], + [0.00892018672, 0.999952563, 0.0039118575], + [-0.687321336, 0.00328981905, 0.726346444], + ] + ), + Point3(-3.72843416, -0.897889251, 0.685129502), + ), + Pose3( + Rot3( + [ + [0.506756029, -0.000331706105, 0.862089858], + [0.00613841257, 0.999975964, -0.00322354286], + [-0.862068067, 0.00692541035, 0.506745885], + ] + ), + Point3(-4.3909926, -0.890883291, 1.43029524), + ), + Pose3( + Rot3( + [ + [0.129316352, -0.00206958814, 0.991601896], + [0.00515932597, 0.999985691, 0.00141424797], + [-0.991590634, 0.00493310721, 0.129325179], + ] + ), + Point3(-4.58510846, -0.922534227, 2.36884523), + ), + Pose3( + Rot3( + [ + [0.599853194, -0.00890004681, -0.800060263], + [0.00313716318, 0.999956608, -0.00877161373], + [0.800103615, 0.00275175707, 0.599855085], + ] + ), + Point3(5.71559638, 0.486863076, 0.279141372), + ), + Pose3( + Rot3( + [ + [0.762552447, 0.000836438681, -0.646926069], + [0.00211337894, 0.999990607, 0.00378404105], + [0.646923157, -0.00425272942, 0.762543517], + ] + ), + Point3(5.00243443, 0.513321893, -0.466921769), + ), + Pose3( + Rot3( + [ + [0.930381645, -0.00340164355, -0.36657678], + [0.00425636616, 0.999989781, 0.00152338305], + [0.366567852, -0.00297761145, 0.930386617], + ] + ), + Point3(4.05404984, 0.493385291, -0.827904571), + ), + Pose3( + Rot3( + [ + [0.999996073, -0.00278379707, -0.000323508543], + [0.00278790921, 0.999905063, 0.0134941517], + [0.000285912831, -0.0134950006, 0.999908897], + ] + ), + Point3(3.04724478, 0.491451306, -0.989571061), + ), + Pose3( + Rot3( + [ + [0.968578343, -0.002544616, 0.248695527], + [0.000806130148, 0.999974526, 0.00709200332], + [-0.248707238, -0.0066686795, 0.968555721], + ] + ), + Point3(2.05737869, 0.46840177, -0.546344594), + ), + Pose3( + Rot3( + [ + [0.968827882, 0.000182770584, 0.247734722], + [-0.000558107079, 0.9999988, 0.00144484904], + [-0.24773416, -0.00153807255, 0.968826821], + ] + ), + Point3(1.14019947, 0.469674641, -0.0491053805), + ), + Pose3( + Rot3( + [ + [0.991647805, 0.00197867892, 0.128960146], + [-0.00247518407, 0.999990129, 0.00368991165], + [-0.128951572, -0.00397829284, 0.991642914], + ] + ), + Point3(0.150270471, 0.457867448, 0.103628642), + ), + Pose3( + Rot3( + [ + [0.992244594, 0.00477781876, -0.124208847], + [-0.0037682125, 0.999957938, 0.00836195891], + [0.124243574, -0.00782906317, 0.992220862], + ] + ), + Point3(-0.937954641, 0.440532658, 0.154265069), + ), + Pose3( + Rot3( + [ + [0.999591078, 0.00215462857, -0.0285137564], + [-0.00183807224, 0.999936443, 0.0111234301], + [0.028535911, -0.0110664711, 0.999531507], + ] + ), + Point3(-1.95622231, 0.448914367, -0.0859439782), + ), + Pose3( + Rot3( + [ + [0.931835342, 0.000956922238, 0.362880212], + [0.000941640753, 0.99998678, -0.00505501434], + [-0.362880252, 0.00505214382, 0.931822122], + ] + ), + Point3(-2.85557418, 0.434739285, 0.0793777177), + ), + Pose3( + Rot3( + [ + [0.781615218, -0.0109886966, 0.623664238], + [0.00516954657, 0.999924591, 0.011139446], + [-0.623739616, -0.00548270158, 0.781613084], + ] + ), + Point3(-3.67524552, 0.444074681, 0.583718622), + ), + Pose3( + Rot3( + [ + [0.521291761, 0.00264805046, 0.853374051], + [0.00659087718, 0.999952868, -0.00712898365], + [-0.853352707, 0.00934076542, 0.521249738], + ] + ), + Point3(-4.35541796, 0.413479707, 1.31179007), + ), + Pose3( + Rot3( + [ + [0.320164205, -0.00890839482, 0.947319884], + [0.00458409304, 0.999958649, 0.007854118], + [-0.947350678, 0.00182799903, 0.320191803], + ] + ), + Point3(-4.71617526, 0.476674479, 2.16502998), + ), + Pose3( + Rot3( + [ + [0.464861609, 0.0268597443, -0.884976415], + [-0.00947397841, 0.999633409, 0.0253631906], + [0.885333239, -0.00340614699, 0.464945663], + ] + ), + Point3(6.11772094, 1.63029238, 0.491786626), + ), + Pose3( + Rot3( + [ + [0.691647251, 0.0216006293, -0.721912024], + [-0.0093228132, 0.999736395, 0.020981541], + [0.722174939, -0.00778156302, 0.691666308], + ] + ), + Point3(5.46912979, 1.68759322, -0.288499782), + ), + Pose3( + Rot3( + [ + [0.921208931, 0.00622640471, -0.389018433], + [-0.00686296262, 0.999976419, -0.000246683913], + [0.389007724, 0.00289706631, 0.92122994], + ] + ), + Point3(4.70156942, 1.72186229, -0.806181015), + ), + Pose3( + Rot3( + [ + [0.822397705, 0.00276497594, 0.568906142], + [0.00804891535, 0.999831556, -0.016494662], + [-0.568855921, 0.0181442503, 0.822236923], + ] + ), + Point3(-3.51368714, 1.59619714, 0.437437437), + ), + Pose3( + Rot3( + [ + [0.726822937, -0.00545541524, 0.686803193], + [0.00913794245, 0.999956756, -0.00172754968], + [-0.686764068, 0.00753159111, 0.726841357], + ] + ), + Point3(-4.29737821, 1.61462527, 1.11537749), + ), + Pose3( + Rot3( + [ + [0.402595481, 0.00697612855, 0.915351441], + [0.0114113638, 0.999855006, -0.0126391687], + [-0.915306892, 0.0155338804, 0.4024575], + ] + ), + Point3(-4.6516433, 1.6323107, 1.96579585), + ), + ] + # from estimated cameras + unaligned_pose_dict = { + 2: Pose3( + Rot3( + [ + [-0.681949, -0.568276, 0.460444], + [0.572389, -0.0227514, 0.819667], + [-0.455321, 0.822524, 0.34079], + ] + ), + Point3(-1.52189, 0.78906, -1.60608), + ), + 4: Pose3( + Rot3( + [ + [-0.817805393, -0.575044816, 0.022755196], + [0.0478829397, -0.0285875849, 0.998443776], + [-0.573499401, 0.81762229, 0.0509139174], + ] + ), + Point3(-1.22653168, 0.686485651, -1.39294168), + ), + 3: Pose3( + Rot3( + [ + [-0.783051568, -0.571905041, 0.244448085], + [0.314861464, -0.0255673164, 0.948793218], + [-0.536369743, 0.819921299, 0.200091385], + ] + ), + Point3(-1.37620079, 0.721408674, -1.49945316), + ), + 5: Pose3( + Rot3( + [ + [-0.818916586, -0.572896131, 0.0341415873], + [0.0550548476, -0.0192038786, 0.99829864], + [-0.571265778, 0.819402974, 0.0472670839], + ] + ), + Point3(-1.06370243, 0.663084159, -1.27672831), + ), + 6: Pose3( + Rot3( + [ + [-0.798825521, -0.571995242, 0.186277293], + [0.243311017, -0.0240196245, 0.969650869], + [-0.550161372, 0.819905178, 0.158360233], + ] + ), + Point3(-0.896250742, 0.640768239, -1.16984756), + ), + 7: Pose3( + Rot3( + [ + [-0.786416666, -0.570215296, 0.237493882], + [0.305475635, -0.0248440676, 0.951875732], + [-0.536873788, 0.821119534, 0.193724669], + ] + ), + Point3(-0.740385043, 0.613956842, -1.05908579), + ), + 8: Pose3( + Rot3( + [ + [-0.806252832, -0.57019757, 0.157578877], + [0.211046715, -0.0283979846, 0.977063375], + [-0.55264424, 0.821016617, 0.143234279], + ] + ), + Point3(-0.58333517, 0.549832698, -0.9542864), + ), + 9: Pose3( + Rot3( + [ + [-0.821191354, -0.557772774, -0.120558255], + [-0.125347331, -0.0297958331, 0.991665395], + [-0.556716092, 0.829458703, -0.0454472483], + ] + ), + Point3(-0.436483039, 0.55003923, -0.850733187), + ), + 21: Pose3( + Rot3( + [ + [-0.778607603, -0.575075476, 0.251114312], + [0.334920968, -0.0424301164, 0.941290407], + [-0.53065822, 0.816999316, 0.225641247], + ] + ), + Point3(-0.736735967, 0.571415247, -0.738663611), + ), + 17: Pose3( + Rot3( + [ + [-0.818569806, -0.573904529, 0.0240221722], + [0.0512889176, -0.0313725422, 0.998190969], + [-0.572112681, 0.818321059, 0.0551155579], + ] + ), + Point3(-1.36150982, 0.724829031, -1.16055631), + ), + 18: Pose3( + Rot3( + [ + [-0.812668105, -0.582027424, 0.0285417146], + [0.0570298244, -0.0306936169, 0.997900547], + [-0.579929436, 0.812589675, 0.0581366453], + ] + ), + Point3(-1.20484771, 0.762370343, -1.05057127), + ), + 20: Pose3( + Rot3( + [ + [-0.748446406, -0.580905382, 0.319963926], + [0.416860654, -0.0368374152, 0.908223651], + [-0.515805363, 0.813137099, 0.269727429], + ] + ), + Point3(569.449421, -907.892555, -794.585647), + ), + 22: Pose3( + Rot3( + [ + [-0.826878177, -0.559495019, -0.0569017041], + [-0.0452256802, -0.0346974602, 0.99837404], + [-0.560559647, 0.828107125, 0.00338702978], + ] + ), + Point3(-0.591431172, 0.55422253, -0.654656597), + ), + 29: Pose3( + Rot3( + [ + [-0.785759779, -0.574532433, -0.229115805], + [-0.246020939, -0.049553424, 0.967996981], + [-0.567499134, 0.81698038, -0.102409921], + ] + ), + Point3(69.4916073, 240.595227, -493.278045), + ), + 23: Pose3( + Rot3( + [ + [-0.783524382, -0.548569702, -0.291823276], + [-0.316457553, -0.051878563, 0.94718701], + [-0.534737468, 0.834493797, -0.132950906], + ] + ), + Point3(-5.93496204, 41.9304933, -3.06881633), + ), + 10: Pose3( + Rot3( + [ + [-0.766833992, -0.537641809, -0.350580824], + [-0.389506676, -0.0443270797, 0.919956336], + [-0.510147213, 0.84200736, -0.175423563], + ] + ), + Point3(234.185458, 326.007989, -691.769777), + ), + 30: Pose3( + Rot3( + [ + [-0.754844165, -0.559278755, -0.342662459], + [-0.375790683, -0.0594160018, 0.92479787], + [-0.537579435, 0.826847636, -0.165321923], + ] + ), + Point3(-5.93398168, 41.9107972, -3.07385081), + ), + } + + unaligned_pose_list = [] + for i in range(32): + wTi = unaligned_pose_dict.get(i, None) + unaligned_pose_list.append(wTi) + # GT poses are the reference/target + rSe = align_poses_sim3_ignore_missing(aTi_list=poses_gt, bTi_list=unaligned_pose_list) + assert rSe.scale() >= 0 + + +def align_poses_sim3_ignore_missing(aTi_list: List[Optional[Pose3]], bTi_list: List[Optional[Pose3]]) -> Similarity3: + """Align by similarity transformation, but allow missing estimated poses in the input. + + Note: this is a wrapper for align_poses_sim3() that allows for missing poses/dropped cameras. + This is necessary, as align_poses_sim3() requires a valid pose for every input pair. + + We force SIM(3) alignment rather than SE(3) alignment. + We assume the two trajectories are of the exact same length. + + Args: + aTi_list: reference poses in frame "a" which are the targets for alignment + bTi_list: input poses which need to be aligned to frame "a" + + Returns: + aSb: Similarity(3) object that aligns the two pose graphs. + """ + assert len(aTi_list) == len(bTi_list) + + # only choose target poses for which there is a corresponding estimated pose + corresponding_aTi_list = [] + valid_camera_idxs = [] + valid_bTi_list = [] + for i, bTi in enumerate(bTi_list): + if bTi is not None: + valid_camera_idxs.append(i) + valid_bTi_list.append(bTi) + corresponding_aTi_list.append(aTi_list[i]) + + aSb = align_poses_sim3(aTi_list=corresponding_aTi_list, bTi_list=valid_bTi_list) + return aSb + + +def align_poses_sim3(aTi_list: List[Pose3], bTi_list: List[Pose3]) -> Similarity3: + """Align two pose graphs via similarity transformation. Note: poses cannot be missing/invalid. + + We force SIM(3) alignment rather than SE(3) alignment. + We assume the two trajectories are of the exact same length. + + Args: + aTi_list: reference poses in frame "a" which are the targets for alignment + bTi_list: input poses which need to be aligned to frame "a" + + Returns: + aSb: Similarity(3) object that aligns the two pose graphs. + """ + n_to_align = len(aTi_list) + assert len(aTi_list) == len(bTi_list) + assert n_to_align >= 2, "SIM(3) alignment uses at least 2 frames" + + ab_pairs = Pose3Pairs(list(zip(aTi_list, bTi_list))) + + aSb = Similarity3.Align(ab_pairs) + + if np.isnan(aSb.scale()) or aSb.scale() == 0: + # we have run into a case where points have no translation between them (i.e. panorama). + # We will first align the rotations and then align the translation by using centroids. + # TODO: handle it in GTSAM + + # align the rotations first, so that we can find the translation between the two panoramas + aSb = Similarity3(aSb.rotation(), np.zeros((3,)), 1.0) + aTi_list_rot_aligned = [aSb.transformFrom(bTi) for bTi in bTi_list] + + # fit a single translation motion to the centroid + aTi_centroid = np.array([aTi.translation() for aTi in aTi_list]).mean(axis=0) + aTi_rot_aligned_centroid = np.array([aTi.translation() for aTi in aTi_list_rot_aligned]).mean(axis=0) + + # construct the final SIM3 transform + aSb = Similarity3(aSb.rotation(), aTi_centroid - aTi_rot_aligned_centroid, 1.0) + + return aSb + if __name__ == "__main__": unittest.main() diff --git a/python/gtsam/tests/test_Triangulation.py b/python/gtsam/tests/test_Triangulation.py index 399cf019d..0a258a0af 100644 --- a/python/gtsam/tests/test_Triangulation.py +++ b/python/gtsam/tests/test_Triangulation.py @@ -6,28 +6,40 @@ All Rights Reserved See LICENSE for the license information Test Triangulation -Author: Frank Dellaert & Fan Jiang (Python) +Authors: Frank Dellaert & Fan Jiang (Python) & Sushmita Warrier & John Lambert """ import unittest +from typing import Iterable, List, Optional, Tuple, Union import numpy as np import gtsam -from gtsam import (Cal3_S2, Cal3Bundler, CameraSetCal3_S2, - CameraSetCal3Bundler, PinholeCameraCal3_S2, - PinholeCameraCal3Bundler, Point2Vector, Point3, Pose3, - Pose3Vector, Rot3) +from gtsam import ( + Cal3_S2, + Cal3Bundler, + CameraSetCal3_S2, + CameraSetCal3Bundler, + PinholeCameraCal3_S2, + PinholeCameraCal3Bundler, + Point2, + Point2Vector, + Point3, + Pose3, + Pose3Vector, + Rot3, +) from gtsam.utils.test_case import GtsamTestCase +UPRIGHT = Rot3.Ypr(-np.pi / 2, 0.0, -np.pi / 2) -class TestVisualISAMExample(GtsamTestCase): - """ Tests for triangulation with shared and individual calibrations """ + +class TestTriangulationExample(GtsamTestCase): + """Tests for triangulation with shared and individual calibrations""" def setUp(self): - """ Set up two camera poses """ + """Set up two camera poses""" # Looking along X-axis, 1 meter above ground plane (x-y) - upright = Rot3.Ypr(-np.pi / 2, 0., -np.pi / 2) - pose1 = Pose3(upright, Point3(0, 0, 1)) + pose1 = Pose3(UPRIGHT, Point3(0, 0, 1)) # create second camera 1 meter to the right of first camera pose2 = pose1.compose(Pose3(Rot3(), Point3(1, 0, 0))) @@ -39,15 +51,24 @@ class TestVisualISAMExample(GtsamTestCase): # landmark ~5 meters infront of camera self.landmark = Point3(5, 0.5, 1.2) - def generate_measurements(self, calibration, camera_model, cal_params, camera_set=None): + def generate_measurements( + self, + calibration: Union[Cal3Bundler, Cal3_S2], + camera_model: Union[PinholeCameraCal3Bundler, PinholeCameraCal3_S2], + cal_params: Iterable[Iterable[Union[int, float]]], + camera_set: Optional[Union[CameraSetCal3Bundler, + CameraSetCal3_S2]] = None, + ) -> Tuple[Point2Vector, Union[CameraSetCal3Bundler, CameraSetCal3_S2, + List[Cal3Bundler], List[Cal3_S2]]]: """ Generate vector of measurements for given calibration and camera model. - Args: + Args: calibration: Camera calibration e.g. Cal3_S2 camera_model: Camera model e.g. PinholeCameraCal3_S2 cal_params: Iterable of camera parameters for `calibration` e.g. [K1, K2] camera_set: Cameraset object (for individual calibrations) + Returns: list of measurements and list/CameraSet object for cameras """ @@ -66,14 +87,15 @@ class TestVisualISAMExample(GtsamTestCase): return measurements, cameras - def test_TriangulationExample(self): - """ Tests triangulation with shared Cal3_S2 calibration""" + def test_TriangulationExample(self) -> None: + """Tests triangulation with shared Cal3_S2 calibration""" # Some common constants sharedCal = (1500, 1200, 0, 640, 480) - measurements, _ = self.generate_measurements(Cal3_S2, - PinholeCameraCal3_S2, - (sharedCal, sharedCal)) + measurements, _ = self.generate_measurements( + calibration=Cal3_S2, + camera_model=PinholeCameraCal3_S2, + cal_params=(sharedCal, sharedCal)) triangulated_landmark = gtsam.triangulatePoint3(self.poses, Cal3_S2(sharedCal), @@ -95,16 +117,17 @@ class TestVisualISAMExample(GtsamTestCase): self.gtsamAssertEquals(self.landmark, triangulated_landmark, 1e-2) - def test_distinct_Ks(self): - """ Tests triangulation with individual Cal3_S2 calibrations """ + def test_distinct_Ks(self) -> None: + """Tests triangulation with individual Cal3_S2 calibrations""" # two camera parameters K1 = (1500, 1200, 0, 640, 480) K2 = (1600, 1300, 0, 650, 440) - measurements, cameras = self.generate_measurements(Cal3_S2, - PinholeCameraCal3_S2, - (K1, K2), - camera_set=CameraSetCal3_S2) + measurements, cameras = self.generate_measurements( + calibration=Cal3_S2, + camera_model=PinholeCameraCal3_S2, + cal_params=(K1, K2), + camera_set=CameraSetCal3_S2) triangulated_landmark = gtsam.triangulatePoint3(cameras, measurements, @@ -112,16 +135,17 @@ class TestVisualISAMExample(GtsamTestCase): optimize=True) self.gtsamAssertEquals(self.landmark, triangulated_landmark, 1e-9) - def test_distinct_Ks_Bundler(self): - """ Tests triangulation with individual Cal3Bundler calibrations""" + def test_distinct_Ks_Bundler(self) -> None: + """Tests triangulation with individual Cal3Bundler calibrations""" # two camera parameters K1 = (1500, 0, 0, 640, 480) K2 = (1600, 0, 0, 650, 440) - measurements, cameras = self.generate_measurements(Cal3Bundler, - PinholeCameraCal3Bundler, - (K1, K2), - camera_set=CameraSetCal3Bundler) + measurements, cameras = self.generate_measurements( + calibration=Cal3Bundler, + camera_model=PinholeCameraCal3Bundler, + cal_params=(K1, K2), + camera_set=CameraSetCal3Bundler) triangulated_landmark = gtsam.triangulatePoint3(cameras, measurements, @@ -129,6 +153,71 @@ class TestVisualISAMExample(GtsamTestCase): optimize=True) self.gtsamAssertEquals(self.landmark, triangulated_landmark, 1e-9) + def test_triangulation_robust_three_poses(self) -> None: + """Ensure triangulation with a robust model works.""" + sharedCal = Cal3_S2(1500, 1200, 0, 640, 480) + + # landmark ~5 meters infront of camera + landmark = Point3(5, 0.5, 1.2) + + pose1 = Pose3(UPRIGHT, Point3(0, 0, 1)) + pose2 = pose1 * Pose3(Rot3(), Point3(1, 0, 0)) + pose3 = pose1 * Pose3(Rot3.Ypr(0.1, 0.2, 0.1), Point3(0.1, -2, -0.1)) + + camera1 = PinholeCameraCal3_S2(pose1, sharedCal) + camera2 = PinholeCameraCal3_S2(pose2, sharedCal) + camera3 = PinholeCameraCal3_S2(pose3, sharedCal) + + z1: Point2 = camera1.project(landmark) + z2: Point2 = camera2.project(landmark) + z3: Point2 = camera3.project(landmark) + + poses = gtsam.Pose3Vector([pose1, pose2, pose3]) + measurements = Point2Vector([z1, z2, z3]) + + # noise free, so should give exactly the landmark + actual = gtsam.triangulatePoint3(poses, + sharedCal, + measurements, + rank_tol=1e-9, + optimize=False) + self.assertTrue(np.allclose(landmark, actual, atol=1e-2)) + + # Add outlier + measurements[0] += Point2(100, 120) # very large pixel noise! + + # now estimate does not match landmark + actual2 = gtsam.triangulatePoint3(poses, + sharedCal, + measurements, + rank_tol=1e-9, + optimize=False) + # DLT is surprisingly robust, but still off (actual error is around 0.26m) + self.assertTrue(np.linalg.norm(landmark - actual2) >= 0.2) + self.assertTrue(np.linalg.norm(landmark - actual2) <= 0.5) + + # Again with nonlinear optimization + actual3 = gtsam.triangulatePoint3(poses, + sharedCal, + measurements, + rank_tol=1e-9, + optimize=True) + # result from nonlinear (but non-robust optimization) is close to DLT and still off + self.assertTrue(np.allclose(actual2, actual3, atol=0.1)) + + # Again with nonlinear optimization, this time with robust loss + model = gtsam.noiseModel.Robust.Create( + gtsam.noiseModel.mEstimator.Huber.Create(1.345), + gtsam.noiseModel.Unit.Create(2)) + actual4 = gtsam.triangulatePoint3(poses, + sharedCal, + measurements, + rank_tol=1e-9, + optimize=True, + model=model) + # using the Huber loss we now have a quite small error!! nice! + self.assertTrue(np.allclose(landmark, actual4, atol=0.05)) + if __name__ == "__main__": unittest.main() diff --git a/python/gtsam/utils/plot.py b/python/gtsam/utils/plot.py index 7ea393077..5ff7fd7aa 100644 --- a/python/gtsam/utils/plot.py +++ b/python/gtsam/utils/plot.py @@ -10,8 +10,15 @@ from matplotlib import patches from mpl_toolkits.mplot3d import Axes3D # pylint: disable=unused-import import gtsam -from gtsam import Marginals, Point3, Pose2, Pose3, Values +from gtsam import Marginals, Point2, Point3, Pose2, Pose3, Values +# For future reference: following +# https://www.xarg.org/2018/04/how-to-plot-a-covariance-error-ellipse/ +# we have, in 2D: +# def kk(p): return math.sqrt(-2*math.log(1-p)) # k to get p probability mass +# def pp(k): return 1-math.exp(-float(k**2)/2.0) # p as a function of k +# Some values: +# k = 5 => p = 99.9996 % def set_axes_equal(fignum: int) -> None: """ @@ -108,6 +115,66 @@ def plot_covariance_ellipse_3d(axes, axes.plot_surface(x, y, z, alpha=alpha, cmap='hot') +def plot_point2_on_axes(axes, + point: Point2, + linespec: str, + P: Optional[np.ndarray] = None) -> None: + """ + Plot a 2D point on given axis `axes` with given `linespec`. + + Args: + axes (matplotlib.axes.Axes): Matplotlib axes. + point: The point to be plotted. + linespec: String representing formatting options for Matplotlib. + P: Marginal covariance matrix to plot the uncertainty of the estimation. + """ + axes.plot([point[0]], [point[1]], linespec, marker='.', markersize=10) + if P is not None: + w, v = np.linalg.eig(P) + + # 5 sigma corresponds to 99.9996%, see note above + k = 5.0 + + angle = np.arctan2(v[1, 0], v[0, 0]) + e1 = patches.Ellipse(point, + np.sqrt(w[0] * k), + np.sqrt(w[1] * k), + np.rad2deg(angle), + fill=False) + axes.add_patch(e1) + + +def plot_point2( + fignum: int, + point: Point2, + linespec: str, + P: np.ndarray = None, + axis_labels: Iterable[str] = ("X axis", "Y axis"), +) -> plt.Figure: + """ + Plot a 2D point on given figure with given `linespec`. + + Args: + fignum: Integer representing the figure number to use for plotting. + point: The point to be plotted. + linespec: String representing formatting options for Matplotlib. + P: Marginal covariance matrix to plot the uncertainty of the estimation. + axis_labels: List of axis labels to set. + + Returns: + fig: The matplotlib figure. + + """ + fig = plt.figure(fignum) + axes = fig.gca() + plot_point2_on_axes(axes, point, linespec, P) + + axes.set_xlabel(axis_labels[0]) + axes.set_ylabel(axis_labels[1]) + + return fig + + def plot_pose2_on_axes(axes, pose: Pose2, axis_length: float = 0.1, @@ -142,7 +209,7 @@ def plot_pose2_on_axes(axes, w, v = np.linalg.eig(gPp) - # k = 2.296 + # 5 sigma corresponds to 99.9996%, see note above k = 5.0 angle = np.arctan2(v[1, 0], v[0, 0]) diff --git a/python/gtsam_unstable/gtsam_unstable.tpl b/python/gtsam_unstable/gtsam_unstable.tpl index aa7ac6bdb..055fcaea7 100644 --- a/python/gtsam_unstable/gtsam_unstable.tpl +++ b/python/gtsam_unstable/gtsam_unstable.tpl @@ -40,7 +40,7 @@ PYBIND11_MODULE({module_name}, m_) {{ {wrapped_namespace} -#include "python/gtsam_unstable/specializations.h" +#include "python/gtsam_unstable/specializations/gtsam_unstable.h" }} diff --git a/python/gtsam_unstable/specializations.h b/python/gtsam_unstable/specializations/gtsam_unstable.h similarity index 100% rename from python/gtsam_unstable/specializations.h rename to python/gtsam_unstable/specializations/gtsam_unstable.h diff --git a/tests/smallExample.h b/tests/smallExample.h index 944899e70..ca9a8580f 100644 --- a/tests/smallExample.h +++ b/tests/smallExample.h @@ -679,26 +679,25 @@ inline Ordering planarOrdering(size_t N) { } /* ************************************************************************* */ -inline std::pair splitOffPlanarTree(size_t N, - const GaussianFactorGraph& original) { - auto T = boost::make_shared(), C= boost::make_shared(); +inline std::pair splitOffPlanarTree( + size_t N, const GaussianFactorGraph& original) { + GaussianFactorGraph T, C; // Add the x11 constraint to the tree - T->push_back(original[0]); + T.push_back(original[0]); // Add all horizontal constraints to the tree size_t i = 1; for (size_t x = 1; x < N; x++) - for (size_t y = 1; y <= N; y++, i++) - T->push_back(original[i]); + for (size_t y = 1; y <= N; y++, i++) T.push_back(original[i]); // Add first vertical column of constraints to T, others to C for (size_t x = 1; x <= N; x++) for (size_t y = 1; y < N; y++, i++) if (x == 1) - T->push_back(original[i]); + T.push_back(original[i]); else - C->push_back(original[i]); + C.push_back(original[i]); return std::make_pair(T, C); } diff --git a/tests/testNonlinearFactor.cpp b/tests/testNonlinearFactor.cpp index fba7949a1..27b61cf89 100644 --- a/tests/testNonlinearFactor.cpp +++ b/tests/testNonlinearFactor.cpp @@ -101,6 +101,82 @@ TEST( NonlinearFactor, NonlinearFactor ) DOUBLES_EQUAL(expected,actual,0.00000001); } +/* ************************************************************************* */ +TEST(NonlinearFactor, Weight) { + // create a values structure for the non linear factor graph + Values values; + + // Instantiate a concrete class version of a NoiseModelFactor + PriorFactor factor1(X(1), Point2(0, 0)); + values.insert(X(1), Point2(0.1, 0.1)); + + CHECK(assert_equal(1.0, factor1.weight(values))); + + // Factor with noise model + auto noise = noiseModel::Isotropic::Sigma(2, 0.2); + PriorFactor factor2(X(2), Point2(1, 1), noise); + values.insert(X(2), Point2(1.1, 1.1)); + + CHECK(assert_equal(1.0, factor2.weight(values))); + + Point2 estimate(3, 3), prior(1, 1); + double distance = (estimate - prior).norm(); + + auto gaussian = noiseModel::Isotropic::Sigma(2, 0.2); + + PriorFactor factor; + + // vector to store all the robust models in so we can test iteratively. + vector robust_models; + + // Fair noise model + auto fair = noiseModel::Robust::Create( + noiseModel::mEstimator::Fair::Create(1.3998), gaussian); + robust_models.push_back(fair); + + // Huber noise model + auto huber = noiseModel::Robust::Create( + noiseModel::mEstimator::Huber::Create(1.345), gaussian); + robust_models.push_back(huber); + + // Cauchy noise model + auto cauchy = noiseModel::Robust::Create( + noiseModel::mEstimator::Cauchy::Create(0.1), gaussian); + robust_models.push_back(cauchy); + + // Tukey noise model + auto tukey = noiseModel::Robust::Create( + noiseModel::mEstimator::Tukey::Create(4.6851), gaussian); + robust_models.push_back(tukey); + + // Welsch noise model + auto welsch = noiseModel::Robust::Create( + noiseModel::mEstimator::Welsch::Create(2.9846), gaussian); + robust_models.push_back(welsch); + + // Geman-McClure noise model + auto gm = noiseModel::Robust::Create( + noiseModel::mEstimator::GemanMcClure::Create(1.0), gaussian); + robust_models.push_back(gm); + + // DCS noise model + auto dcs = noiseModel::Robust::Create( + noiseModel::mEstimator::DCS::Create(1.0), gaussian); + robust_models.push_back(dcs); + + // L2WithDeadZone noise model + auto l2 = noiseModel::Robust::Create( + noiseModel::mEstimator::L2WithDeadZone::Create(1.0), gaussian); + robust_models.push_back(l2); + + for(auto&& model: robust_models) { + factor = PriorFactor(X(3), prior, model); + values.clear(); + values.insert(X(3), estimate); + CHECK(assert_equal(model->robust()->weight(distance), factor.weight(values))); + } +} + /* ************************************************************************* */ TEST( NonlinearFactor, linearize_f1 ) { diff --git a/tests/testNonlinearFactorGraph.cpp b/tests/testNonlinearFactorGraph.cpp index fdb080a63..05a6e7f45 100644 --- a/tests/testNonlinearFactorGraph.cpp +++ b/tests/testNonlinearFactorGraph.cpp @@ -15,6 +15,7 @@ * @brief testNonlinearFactorGraph * @author Carlos Nieto * @author Christian Potthast + * @author Frank Dellaert */ #include @@ -106,6 +107,24 @@ TEST( NonlinearFactorGraph, probPrime ) DOUBLES_EQUAL(expected,actual,0); } +/* ************************************************************************* */ +TEST(NonlinearFactorGraph, ProbPrime2) { + NonlinearFactorGraph fg; + fg.emplace_shared>(1, 0.0, + noiseModel::Isotropic::Sigma(1, 1.0)); + + Values values; + values.insert(1, 1.0); + + // The prior factor squared error is: 0.5. + EXPECT_DOUBLES_EQUAL(0.5, fg.error(values), 1e-12); + + // The probability value is: exp^(-factor_error) / sqrt(2 * PI) + // Ignore the denominator and we get: exp^(-factor_error) = exp^(-0.5) + double expected = exp(-0.5); + EXPECT_DOUBLES_EQUAL(expected, fg.probPrime(values), 1e-12); +} + /* ************************************************************************* */ TEST( NonlinearFactorGraph, linearize ) { @@ -285,6 +304,7 @@ TEST(testNonlinearFactorGraph, addPrior) { EXPECT(0 != graph.error(values)); } +/* ************************************************************************* */ TEST(NonlinearFactorGraph, printErrors) { const NonlinearFactorGraph fg = createNonlinearFactorGraph(); @@ -309,6 +329,65 @@ TEST(NonlinearFactorGraph, printErrors) for (bool visit : visited) EXPECT(visit==true); } +/* ************************************************************************* */ +TEST(NonlinearFactorGraph, dot) { + string expected = + "graph {\n" + " size=\"5,5\";\n" + "\n" + " varl1[label=\"l1\"];\n" + " varx1[label=\"x1\"];\n" + " varx2[label=\"x2\"];\n" + "\n" + " factor0[label=\"\", shape=point];\n" + " varx1--factor0;\n" + " factor1[label=\"\", shape=point];\n" + " varx1--factor1;\n" + " varx2--factor1;\n" + " factor2[label=\"\", shape=point];\n" + " varx1--factor2;\n" + " varl1--factor2;\n" + " factor3[label=\"\", shape=point];\n" + " varx2--factor3;\n" + " varl1--factor3;\n" + "}\n"; + + const NonlinearFactorGraph fg = createNonlinearFactorGraph(); + string actual = fg.dot(); + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +TEST(NonlinearFactorGraph, dot_extra) { + string expected = + "graph {\n" + " size=\"5,5\";\n" + "\n" + " varl1[label=\"l1\", pos=\"0,0!\"];\n" + " varx1[label=\"x1\", pos=\"1,0!\"];\n" + " varx2[label=\"x2\", pos=\"1,1.5!\"];\n" + "\n" + " factor0[label=\"\", shape=point];\n" + " varx1--factor0;\n" + " factor1[label=\"\", shape=point];\n" + " varx1--factor1;\n" + " varx2--factor1;\n" + " factor2[label=\"\", shape=point];\n" + " varx1--factor2;\n" + " varl1--factor2;\n" + " factor3[label=\"\", shape=point];\n" + " varx2--factor3;\n" + " varl1--factor3;\n" + "}\n"; + + const NonlinearFactorGraph fg = createNonlinearFactorGraph(); + const Values c = createValues(); + + stringstream ss; + fg.dot(ss, c); + EXPECT(ss.str() == expected); +} + /* ************************************************************************* */ int main() { TestResult tr; return TestRegistry::runAllTests(tr); } /* ************************************************************************* */ diff --git a/tests/testSubgraphPreconditioner.cpp b/tests/testSubgraphPreconditioner.cpp index 84ccc131a..eeba38b68 100644 --- a/tests/testSubgraphPreconditioner.cpp +++ b/tests/testSubgraphPreconditioner.cpp @@ -77,8 +77,8 @@ TEST(SubgraphPreconditioner, planarGraph) { DOUBLES_EQUAL(0, error(A, xtrue), 1e-9); // check zero error for xtrue // Check that xtrue is optimal - GaussianBayesNet::shared_ptr R1 = A.eliminateSequential(); - VectorValues actual = R1->optimize(); + GaussianBayesNet R1 = *A.eliminateSequential(); + VectorValues actual = R1.optimize(); EXPECT(assert_equal(xtrue, actual)); } @@ -90,14 +90,14 @@ TEST(SubgraphPreconditioner, splitOffPlanarTree) { boost::tie(A, xtrue) = planarGraph(3); // Get the spanning tree and constraints, and check their sizes - GaussianFactorGraph::shared_ptr T, C; + GaussianFactorGraph T, C; boost::tie(T, C) = splitOffPlanarTree(3, A); - LONGS_EQUAL(9, T->size()); - LONGS_EQUAL(4, C->size()); + LONGS_EQUAL(9, T.size()); + LONGS_EQUAL(4, C.size()); // Check that the tree can be solved to give the ground xtrue - GaussianBayesNet::shared_ptr R1 = T->eliminateSequential(); - VectorValues xbar = R1->optimize(); + GaussianBayesNet R1 = *T.eliminateSequential(); + VectorValues xbar = R1.optimize(); EXPECT(assert_equal(xtrue, xbar)); } @@ -110,31 +110,29 @@ TEST(SubgraphPreconditioner, system) { boost::tie(Ab, xtrue) = planarGraph(N); // A*x-b // Get the spanning tree and remaining graph - GaussianFactorGraph::shared_ptr Ab1, Ab2; // A1*x-b1 and A2*x-b2 + GaussianFactorGraph Ab1, Ab2; // A1*x-b1 and A2*x-b2 boost::tie(Ab1, Ab2) = splitOffPlanarTree(N, Ab); // Eliminate the spanning tree to build a prior const Ordering ord = planarOrdering(N); - auto Rc1 = Ab1->eliminateSequential(ord); // R1*x-c1 - VectorValues xbar = Rc1->optimize(); // xbar = inv(R1)*c1 + auto Rc1 = *Ab1.eliminateSequential(ord); // R1*x-c1 + VectorValues xbar = Rc1.optimize(); // xbar = inv(R1)*c1 // Create Subgraph-preconditioned system - VectorValues::shared_ptr xbarShared( - new VectorValues(xbar)); // TODO: horrible - const SubgraphPreconditioner system(Ab2, Rc1, xbarShared); + const SubgraphPreconditioner system(Ab2, Rc1, xbar); // Get corresponding matrices for tests. Add dummy factors to Ab2 to make // sure it works with the ordering. - Ordering ordering = Rc1->ordering(); // not ord in general! - Ab2->add(key(1, 1), Z_2x2, Z_2x1); - Ab2->add(key(1, 2), Z_2x2, Z_2x1); - Ab2->add(key(1, 3), Z_2x2, Z_2x1); + Ordering ordering = Rc1.ordering(); // not ord in general! + Ab2.add(key(1, 1), Z_2x2, Z_2x1); + Ab2.add(key(1, 2), Z_2x2, Z_2x1); + Ab2.add(key(1, 3), Z_2x2, Z_2x1); Matrix A, A1, A2; Vector b, b1, b2; std::tie(A, b) = Ab.jacobian(ordering); - std::tie(A1, b1) = Ab1->jacobian(ordering); - std::tie(A2, b2) = Ab2->jacobian(ordering); - Matrix R1 = Rc1->matrix(ordering).first; + std::tie(A1, b1) = Ab1.jacobian(ordering); + std::tie(A2, b2) = Ab2.jacobian(ordering); + Matrix R1 = Rc1.matrix(ordering).first; Matrix Abar(13 * 2, 9 * 2); Abar.topRows(9 * 2) = Matrix::Identity(9 * 2, 9 * 2); Abar.bottomRows(8) = A2.topRows(8) * R1.inverse(); @@ -151,7 +149,7 @@ TEST(SubgraphPreconditioner, system) { y1[key(3, 3)] = Vector2(1.0, -1.0); // Check backSubstituteTranspose works with R1 - VectorValues actual = Rc1->backSubstituteTranspose(y1); + VectorValues actual = Rc1.backSubstituteTranspose(y1); Vector expected = R1.transpose().inverse() * vec(y1); EXPECT(assert_equal(expected, vec(actual))); @@ -230,7 +228,7 @@ TEST(SubgraphSolver, Solves) { system.build(Ab, keyInfo, lambda); // Create a perturbed (non-zero) RHS - const auto xbar = system.Rc1()->optimize(); // merely for use in zero below + const auto xbar = system.Rc1().optimize(); // merely for use in zero below auto values_y = VectorValues::Zero(xbar); auto it = values_y.begin(); it->second.setConstant(100); @@ -238,13 +236,13 @@ TEST(SubgraphSolver, Solves) { it->second.setConstant(-100); // Solve the VectorValues way - auto values_x = system.Rc1()->backSubstitute(values_y); + auto values_x = system.Rc1().backSubstitute(values_y); // Solve the matrix way, this really just checks BN::backSubstitute // This only works with Rc1 ordering, not with keyInfo ! // TODO(frank): why does this not work with an arbitrary ordering? - const auto ord = system.Rc1()->ordering(); - const Matrix R1 = system.Rc1()->matrix(ord).first; + const auto ord = system.Rc1().ordering(); + const Matrix R1 = system.Rc1().matrix(ord).first; auto ord_y = values_y.vector(ord); auto vector_x = R1.inverse() * ord_y; EXPECT(assert_equal(vector_x, values_x.vector(ord))); @@ -261,7 +259,7 @@ TEST(SubgraphSolver, Solves) { // Test that transposeSolve does implement x = R^{-T} y // We do this by asserting it gives same answer as backSubstituteTranspose - auto values_x2 = system.Rc1()->backSubstituteTranspose(values_y); + auto values_x2 = system.Rc1().backSubstituteTranspose(values_y); Vector solveT_x = Vector::Zero(N); system.transposeSolve(vector_y, solveT_x); EXPECT(assert_equal(values_x2.vector(ordering), solveT_x)); @@ -277,18 +275,15 @@ TEST(SubgraphPreconditioner, conjugateGradients) { boost::tie(Ab, xtrue) = planarGraph(N); // A*x-b // Get the spanning tree - GaussianFactorGraph::shared_ptr Ab1, Ab2; // A1*x-b1 and A2*x-b2 + GaussianFactorGraph Ab1, Ab2; // A1*x-b1 and A2*x-b2 boost::tie(Ab1, Ab2) = splitOffPlanarTree(N, Ab); // Eliminate the spanning tree to build a prior - SubgraphPreconditioner::sharedBayesNet Rc1 = - Ab1->eliminateSequential(); // R1*x-c1 - VectorValues xbar = Rc1->optimize(); // xbar = inv(R1)*c1 + GaussianBayesNet Rc1 = *Ab1.eliminateSequential(); // R1*x-c1 + VectorValues xbar = Rc1.optimize(); // xbar = inv(R1)*c1 // Create Subgraph-preconditioned system - VectorValues::shared_ptr xbarShared( - new VectorValues(xbar)); // TODO: horrible - SubgraphPreconditioner system(Ab2, Rc1, xbarShared); + SubgraphPreconditioner system(Ab2, Rc1, xbar); // Create zero config y0 and perturbed config y1 VectorValues y0 = VectorValues::Zero(xbar); diff --git a/tests/testSubgraphSolver.cpp b/tests/testSubgraphSolver.cpp index cca13c822..5d8d88775 100644 --- a/tests/testSubgraphSolver.cpp +++ b/tests/testSubgraphSolver.cpp @@ -68,10 +68,10 @@ TEST( SubgraphSolver, splitFactorGraph ) auto subgraph = builder(Ab); EXPECT_LONGS_EQUAL(9, subgraph.size()); - GaussianFactorGraph::shared_ptr Ab1, Ab2; + GaussianFactorGraph Ab1, Ab2; std::tie(Ab1, Ab2) = splitFactorGraph(Ab, subgraph); - EXPECT_LONGS_EQUAL(9, Ab1->size()); - EXPECT_LONGS_EQUAL(13, Ab2->size()); + EXPECT_LONGS_EQUAL(9, Ab1.size()); + EXPECT_LONGS_EQUAL(13, Ab2.size()); } /* ************************************************************************* */ @@ -99,12 +99,12 @@ TEST( SubgraphSolver, constructor2 ) std::tie(Ab, xtrue) = example::planarGraph(N); // A*x-b // Get the spanning tree - GaussianFactorGraph::shared_ptr Ab1, Ab2; // A1*x-b1 and A2*x-b2 + GaussianFactorGraph Ab1, Ab2; // A1*x-b1 and A2*x-b2 std::tie(Ab1, Ab2) = example::splitOffPlanarTree(N, Ab); // The second constructor takes two factor graphs, so the caller can specify // the preconditioner (Ab1) and the constraints that are left out (Ab2) - SubgraphSolver solver(*Ab1, Ab2, kParameters, kOrdering); + SubgraphSolver solver(Ab1, Ab2, kParameters, kOrdering); VectorValues optimized = solver.optimize(); DOUBLES_EQUAL(0.0, error(Ab, optimized), 1e-5); } @@ -119,11 +119,11 @@ TEST( SubgraphSolver, constructor3 ) std::tie(Ab, xtrue) = example::planarGraph(N); // A*x-b // Get the spanning tree and corresponding kOrdering - GaussianFactorGraph::shared_ptr Ab1, Ab2; // A1*x-b1 and A2*x-b2 + GaussianFactorGraph Ab1, Ab2; // A1*x-b1 and A2*x-b2 std::tie(Ab1, Ab2) = example::splitOffPlanarTree(N, Ab); // The caller solves |A1*x-b1|^2 == |R1*x-c1|^2, where R1 is square UT - auto Rc1 = Ab1->eliminateSequential(); + auto Rc1 = *Ab1.eliminateSequential(); // The third constructor allows the caller to pass an already solved preconditioner Rc1_ // as a Bayes net, in addition to the "loop closing constraints" Ab2, as before diff --git a/wrap/cmake/MatlabWrap.cmake b/wrap/cmake/MatlabWrap.cmake index 083b88566..3cb058102 100644 --- a/wrap/cmake/MatlabWrap.cmake +++ b/wrap/cmake/MatlabWrap.cmake @@ -62,10 +62,10 @@ macro(find_and_configure_matlab) endmacro() # Consistent and user-friendly wrap function -function(matlab_wrap interfaceHeader linkLibraries +function(matlab_wrap interfaceHeader moduleName linkLibraries extraIncludeDirs extraMexFlags ignore_classes) find_and_configure_matlab() - wrap_and_install_library("${interfaceHeader}" "${linkLibraries}" + wrap_and_install_library("${interfaceHeader}" "${moduleName}" "${linkLibraries}" "${extraIncludeDirs}" "${extraMexFlags}" "${ignore_classes}") endfunction() @@ -77,6 +77,7 @@ endfunction() # Arguments: # # interfaceHeader: The relative path to the wrapper interface definition file. +# moduleName: The name of the wrapped module, e.g. gtsam # linkLibraries: Any *additional* libraries to link. Your project library # (e.g. `lba`), libraries it depends on, and any necessary MATLAB libraries will # be linked automatically. So normally, leave this empty. @@ -85,15 +86,15 @@ endfunction() # extraMexFlags: Any *additional* flags to pass to the compiler when building # the wrap code. Normally, leave this empty. # ignore_classes: List of classes to ignore in the wrapping. -function(wrap_and_install_library interfaceHeader linkLibraries +function(wrap_and_install_library interfaceHeader moduleName linkLibraries extraIncludeDirs extraMexFlags ignore_classes) - wrap_library_internal("${interfaceHeader}" "${linkLibraries}" + wrap_library_internal("${interfaceHeader}" "${moduleName}" "${linkLibraries}" "${extraIncludeDirs}" "${mexFlags}") - install_wrapped_library_internal("${interfaceHeader}") + install_wrapped_library_internal("${moduleName}") endfunction() # Internal function that wraps a library and compiles the wrapper -function(wrap_library_internal interfaceHeader linkLibraries extraIncludeDirs +function(wrap_library_internal interfaceHeader moduleName linkLibraries extraIncludeDirs extraMexFlags) if(UNIX AND NOT APPLE) if(CMAKE_SIZEOF_VOID_P EQUAL 8) @@ -120,7 +121,6 @@ function(wrap_library_internal interfaceHeader linkLibraries extraIncludeDirs # Extract module name from interface header file name get_filename_component(interfaceHeader "${interfaceHeader}" ABSOLUTE) get_filename_component(modulePath "${interfaceHeader}" PATH) - get_filename_component(moduleName "${interfaceHeader}" NAME_WE) # Paths for generated files set(generated_files_path "${PROJECT_BINARY_DIR}/wrap/${moduleName}") @@ -136,8 +136,7 @@ function(wrap_library_internal interfaceHeader linkLibraries extraIncludeDirs # explicit link libraries list so that the next block of code can unpack any # static libraries set(automaticDependencies "") - foreach(lib ${moduleName} ${linkLibraries}) - # message("MODULE NAME: ${moduleName}") + foreach(lib ${module} ${linkLibraries}) if(TARGET "${lib}") get_target_property(dependentLibraries ${lib} INTERFACE_LINK_LIBRARIES) # message("DEPENDENT LIBRARIES: ${dependentLibraries}") @@ -176,7 +175,7 @@ function(wrap_library_internal interfaceHeader linkLibraries extraIncludeDirs set(otherLibraryTargets "") set(otherLibraryNontargets "") set(otherSourcesAndObjects "") - foreach(lib ${moduleName} ${linkLibraries} ${automaticDependencies}) + foreach(lib ${module} ${linkLibraries} ${automaticDependencies}) if(TARGET "${lib}") if(WRAP_MEX_BUILD_STATIC_MODULE) get_target_property(target_sources ${lib} SOURCES) @@ -250,7 +249,7 @@ function(wrap_library_internal interfaceHeader linkLibraries extraIncludeDirs COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${GTWRAP_PACKAGE_DIR}${GTWRAP_PATH_SEPARATOR}$ENV{PYTHONPATH}" - ${PYTHON_EXECUTABLE} ${MATLAB_WRAP_SCRIPT} --src ${interfaceHeader} + ${PYTHON_EXECUTABLE} ${MATLAB_WRAP_SCRIPT} --src "${interfaceHeader}" --module_name ${moduleName} --out ${generated_files_path} --top_module_namespaces ${moduleName} --ignore ${ignore_classes} VERBATIM @@ -324,8 +323,8 @@ endfunction() # Internal function that installs a wrap toolbox function(install_wrapped_library_internal interfaceHeader) - get_filename_component(moduleName "${interfaceHeader}" NAME_WE) - set(generated_files_path "${PROJECT_BINARY_DIR}/wrap/${moduleName}") + get_filename_component(module "${interfaceHeader}" NAME_WE) + set(generated_files_path "${PROJECT_BINARY_DIR}/wrap/${module}") # NOTE: only installs .m and mex binary files (not .cpp) - the trailing slash # on the directory name here prevents creating the top-level module name diff --git a/wrap/cmake/PybindWrap.cmake b/wrap/cmake/PybindWrap.cmake index 2149c7195..2008bf2dd 100644 --- a/wrap/cmake/PybindWrap.cmake +++ b/wrap/cmake/PybindWrap.cmake @@ -55,15 +55,44 @@ function( set(GTWRAP_PATH_SEPARATOR ";") endif() + # Create a copy of interface_headers so we can freely manipulate it + set(interface_files ${interface_headers}) + + # Pop the main interface file so that interface_files has only submodules. + list(POP_FRONT interface_files main_interface) + # Convert .i file names to .cpp file names. - foreach(filepath ${interface_headers}) - get_filename_component(interface ${filepath} NAME) - string(REPLACE ".i" ".cpp" cpp_file ${interface}) + foreach(interface_file ${interface_files}) + # This block gets the interface file name and does the replacement + get_filename_component(interface ${interface_file} NAME_WLE) + set(cpp_file "${interface}.cpp") list(APPEND cpp_files ${cpp_file}) + + # Wrap the specific interface header + # This is done so that we can create CMake dependencies in such a way so that when changing a single .i file, + # the others don't need to be regenerated. + # NOTE: We have to use `add_custom_command` so set the dependencies correctly. + # https://stackoverflow.com/questions/40032593/cmake-does-not-rebuild-dependent-after-prerequisite-changes + add_custom_command( + OUTPUT ${cpp_file} + COMMAND + ${CMAKE_COMMAND} -E env + "PYTHONPATH=${GTWRAP_PACKAGE_DIR}${GTWRAP_PATH_SEPARATOR}$ENV{PYTHONPATH}" + ${PYTHON_EXECUTABLE} ${PYBIND_WRAP_SCRIPT} --src "${interface_file}" + --out "${cpp_file}" --module_name ${module_name} + --top_module_namespaces "${top_namespace}" --ignore ${ignore_classes} + --template ${module_template} --is_submodule ${_WRAP_BOOST_ARG} + DEPENDS "${interface_file}" ${module_template} "${module_name}/specializations/${interface}.h" "${module_name}/preamble/${interface}.h" + VERBATIM) + endforeach() + get_filename_component(main_interface_name ${main_interface} NAME_WLE) + set(main_cpp_file "${main_interface_name}.cpp") + list(PREPEND cpp_files ${main_cpp_file}) + add_custom_command( - OUTPUT ${cpp_files} + OUTPUT ${main_cpp_file} COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${GTWRAP_PACKAGE_DIR}${GTWRAP_PATH_SEPARATOR}$ENV{PYTHONPATH}" @@ -71,23 +100,10 @@ function( --out "${generated_cpp}" --module_name ${module_name} --top_module_namespaces "${top_namespace}" --ignore ${ignore_classes} --template ${module_template} ${_WRAP_BOOST_ARG} - DEPENDS "${interface_headers}" ${module_template} + DEPENDS "${main_interface}" ${module_template} "${module_name}/specializations/${main_interface_name}.h" "${module_name}/specializations/${main_interface_name}.h" VERBATIM) - add_custom_target(pybind_wrap_${module_name} ALL DEPENDS ${cpp_files}) - - # Late dependency injection, to make sure this gets called whenever the - # interface header or the wrap library are updated. - # ~~~ - # See: https://stackoverflow.com/questions/40032593/cmake-does-not-rebuild-dependent-after-prerequisite-changes - # ~~~ - add_custom_command( - OUTPUT ${cpp_files} - DEPENDS ${interface_headers} - # @GTWRAP_SOURCE_DIR@/gtwrap/interface_parser.py - # @GTWRAP_SOURCE_DIR@/gtwrap/pybind_wrapper.py - # @GTWRAP_SOURCE_DIR@/gtwrap/template_instantiator.py - APPEND) + add_custom_target(pybind_wrap_${module_name} DEPENDS ${cpp_files}) pybind11_add_module(${target} "${cpp_files}") diff --git a/wrap/gtwrap/interface_parser/type.py b/wrap/gtwrap/interface_parser/type.py index e94db4ff2..7aacf0b81 100644 --- a/wrap/gtwrap/interface_parser/type.py +++ b/wrap/gtwrap/interface_parser/type.py @@ -53,6 +53,10 @@ class Typename: self.name = t[-1] # the name is the last element in this list self.namespaces = t[:-1] + # If the first namespace is empty string, just get rid of it. + if self.namespaces and self.namespaces[0] == '': + self.namespaces.pop(0) + if instantiations: if isinstance(instantiations, Sequence): self.instantiations = instantiations # type: ignore @@ -92,8 +96,8 @@ class Typename: else: cpp_name = self.name return '{}{}{}'.format( - "::".join(self.namespaces[idx:]), - "::" if self.namespaces[idx:] else "", + "::".join(self.namespaces), + "::" if self.namespaces else "", cpp_name, ) diff --git a/wrap/gtwrap/matlab_wrapper/mixins.py b/wrap/gtwrap/matlab_wrapper/mixins.py index 2d7c75b39..f4a7988fd 100644 --- a/wrap/gtwrap/matlab_wrapper/mixins.py +++ b/wrap/gtwrap/matlab_wrapper/mixins.py @@ -108,7 +108,7 @@ class FormatMixin: elif is_method: formatted_type_name += self.data_type_param.get(name) or name else: - formatted_type_name += name + formatted_type_name += str(name) if separator == "::": # C++ templates = [] @@ -192,10 +192,9 @@ class FormatMixin: method = '' if isinstance(static_method, parser.StaticMethod): - method += "".join([separator + x for x in static_method.parent.namespaces()]) + \ - separator + static_method.parent.name + separator + method += static_method.parent.to_cpp() + separator - return method[2 * len(separator):] + return method def _format_global_function(self, function: Union[parser.GlobalFunction, Any], diff --git a/wrap/gtwrap/matlab_wrapper/templates.py b/wrap/gtwrap/matlab_wrapper/templates.py index 7aaf8f487..3d1306dca 100644 --- a/wrap/gtwrap/matlab_wrapper/templates.py +++ b/wrap/gtwrap/matlab_wrapper/templates.py @@ -66,7 +66,7 @@ class WrapperTemplate: mxDestroyArray(registry); mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL); - if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) {{ + if(mexPutVariable("global", "gtsam_{module_name}_rttiRegistry_created", newAlreadyCreated) != 0) {{ mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); }} mxDestroyArray(newAlreadyCreated); diff --git a/wrap/gtwrap/matlab_wrapper/wrapper.py b/wrap/gtwrap/matlab_wrapper/wrapper.py index 97945f73a..42610999d 100755 --- a/wrap/gtwrap/matlab_wrapper/wrapper.py +++ b/wrap/gtwrap/matlab_wrapper/wrapper.py @@ -5,6 +5,7 @@ that Matlab's MEX compiler can use. # pylint: disable=too-many-lines, no-self-use, too-many-arguments, too-many-branches, too-many-statements +import copy import os import os.path as osp import textwrap @@ -13,6 +14,7 @@ from typing import Dict, Iterable, List, Union import gtwrap.interface_parser as parser import gtwrap.template_instantiator as instantiator +from gtwrap.interface_parser.function import ArgumentList from gtwrap.matlab_wrapper.mixins import CheckMixin, FormatMixin from gtwrap.matlab_wrapper.templates import WrapperTemplate @@ -137,6 +139,37 @@ class MatlabWrapper(CheckMixin, FormatMixin): """ return x + '\n' + ('' if y == '' else ' ') + y + @staticmethod + def _expand_default_arguments(method, save_backup=True): + """Recursively expand all possibilities for optional default arguments. + We create "overload" functions with fewer arguments, but since we have to "remember" what + the default arguments are for later, we make a backup. + """ + def args_copy(args): + return ArgumentList([copy.copy(arg) for arg in args.list()]) + def method_copy(method): + method2 = copy.copy(method) + method2.args = args_copy(method.args) + method2.args.backup = method.args.backup + return method2 + if save_backup: + method.args.backup = args_copy(method.args) + method = method_copy(method) + for arg in reversed(method.args.list()): + if arg.default is not None: + arg.default = None + methodWithArg = method_copy(method) + method.args.list().remove(arg) + return [ + methodWithArg, + *MatlabWrapper._expand_default_arguments(method, save_backup=False) + ] + break + assert all(arg.default is None for arg in method.args.list()), \ + 'In parsing method {:}: Arguments with default values cannot appear before ones ' \ + 'without default values.'.format(method.name) + return [method] + def _group_methods(self, methods): """Group overloaded methods together""" method_map = {} @@ -147,9 +180,9 @@ class MatlabWrapper(CheckMixin, FormatMixin): if method_index is None: method_map[method.name] = len(method_out) - method_out.append([method]) + method_out.append(MatlabWrapper._expand_default_arguments(method)) else: - method_out[method_index].append(method) + method_out[method_index] += MatlabWrapper._expand_default_arguments(method) return method_out @@ -239,18 +272,18 @@ class MatlabWrapper(CheckMixin, FormatMixin): return var_list_wrap - def _wrap_method_check_statement(self, args): + def _wrap_method_check_statement(self, args: parser.ArgumentList): """ Wrap the given arguments into either just a varargout call or a call in an if statement that checks if the parameters are accurate. + + TODO Update this method so that default arguments are supported. """ - check_statement = '' arg_id = 1 - if check_statement == '': - check_statement = \ - 'if length(varargin) == {param_count}'.format( - param_count=len(args.list())) + param_count = len(args) + check_statement = 'if length(varargin) == {param_count}'.format( + param_count=param_count) for _, arg in enumerate(args.list()): name = arg.ctype.typename.name @@ -301,13 +334,9 @@ class MatlabWrapper(CheckMixin, FormatMixin): ((a), Test& t = *unwrap_shared_ptr< Test >(in[1], "ptr_Test");), ((a), std::shared_ptr p1 = unwrap_shared_ptr< Test >(in[1], "ptr_Test");) """ - params = '' body_args = '' for arg in args.list(): - if params != '': - params += ',' - if self.is_ref(arg.ctype): # and not constructor: ctype_camel = self._format_type_name(arg.ctype.typename, separator='') @@ -336,8 +365,6 @@ class MatlabWrapper(CheckMixin, FormatMixin): name=arg.name, id=arg_id)), prefix=' ') - if call_type == "": - params += "*" else: body_args += textwrap.indent(textwrap.dedent('''\ @@ -347,10 +374,29 @@ class MatlabWrapper(CheckMixin, FormatMixin): id=arg_id)), prefix=' ') - params += arg.name - arg_id += 1 + params = '' + explicit_arg_names = [arg.name for arg in args.list()] + # when returning the params list, we need to re-include the default args. + for arg in args.backup.list(): + if params != '': + params += ',' + + if (arg.default is not None) and (arg.name not in explicit_arg_names): + params += arg.default + continue + + if (not self.is_ref(arg.ctype)) and (self.is_shared_ptr(arg.ctype)) and (self.is_ptr( + arg.ctype)) and (arg.ctype.typename.name not in self.ignore_namespace): + if arg.ctype.is_shared_ptr: + call_type = arg.ctype.is_shared_ptr + else: + call_type = arg.ctype.is_ptr + if call_type == "": + params += "*" + params += arg.name + return params, body_args @staticmethod @@ -555,6 +601,8 @@ class MatlabWrapper(CheckMixin, FormatMixin): if not isinstance(ctors, Iterable): ctors = [ctors] + ctors = sum((MatlabWrapper._expand_default_arguments(ctor) for ctor in ctors), []) + methods_wrap = textwrap.indent(textwrap.dedent("""\ methods function obj = {class_name}(varargin) @@ -674,20 +722,7 @@ class MatlabWrapper(CheckMixin, FormatMixin): def _group_class_methods(self, methods): """Group overloaded methods together""" - method_map = {} - method_out = [] - - for method in methods: - method_index = method_map.get(method.name) - - if method_index is None: - method_map[method.name] = len(method_out) - method_out.append([method]) - else: - # print("[_group_methods] Merging {} with {}".format(method_index, method.name)) - method_out[method_index].append(method) - - return method_out + return self._group_methods(methods) @classmethod def _format_varargout(cls, return_type, return_type_formatted): @@ -809,7 +844,7 @@ class MatlabWrapper(CheckMixin, FormatMixin): for static_method in static_methods: format_name = list(static_method[0].name) - format_name[0] = format_name[0].upper() + format_name[0] = format_name[0] if static_method[0].name in self.ignore_methods: continue @@ -850,12 +885,13 @@ class MatlabWrapper(CheckMixin, FormatMixin): wrapper=self._wrapper_name(), id=self._update_wrapper_id( (namespace_name, instantiated_class, - static_overload.name, static_overload)), + static_overload.name, static_overload)), class_name=instantiated_class.name, end_statement=end_statement), - prefix=' ') + prefix=' ') - #TODO Figure out what is static_overload doing here. + # If the arguments don't match any of the checks above, + # throw an error with the class and method name. method_text += textwrap.indent(textwrap.dedent("""\ error('Arguments do not match any overload of function {class_name}.{method_name}'); """.format(class_name=class_name, @@ -1081,7 +1117,6 @@ class MatlabWrapper(CheckMixin, FormatMixin): obj_start = '' if isinstance(method, instantiator.InstantiatedMethod): - # method_name = method.original.name method_name = method.to_cpp() obj_start = 'obj->' @@ -1090,6 +1125,10 @@ class MatlabWrapper(CheckMixin, FormatMixin): # self._format_type_name(method.instantiations)) method = method.to_cpp() + elif isinstance(method, instantiator.InstantiatedStaticMethod): + method_name = self._format_static_method(method, '::') + method_name += method.original.name + elif isinstance(method, parser.GlobalFunction): method_name = self._format_global_function(method, '::') method_name += method.name @@ -1230,9 +1269,9 @@ class MatlabWrapper(CheckMixin, FormatMixin): Collector_{class_name}::iterator item; item = collector_{class_name}.find(self); if(item != collector_{class_name}.end()) {{ - delete self; collector_{class_name}.erase(item); }} + delete self; ''').format(class_name_sep=class_name_separated, class_name=class_name), prefix=' ') @@ -1250,7 +1289,7 @@ class MatlabWrapper(CheckMixin, FormatMixin): method_name = '' if is_static_method: - method_name = self._format_static_method(extra) + '.' + method_name = self._format_static_method(extra, '.') method_name += extra.name @@ -1567,23 +1606,23 @@ class MatlabWrapper(CheckMixin, FormatMixin): def wrap(self, files, path): """High level function to wrap the project.""" + content = "" modules = {} for file in files: with open(file, 'r') as f: - content = f.read() + content += f.read() - # Parse the contents of the interface file - parsed_result = parser.Module.parseString(content) - # print(parsed_result) + # Parse the contents of the interface file + parsed_result = parser.Module.parseString(content) - # Instantiate the module - module = instantiator.instantiate_namespace(parsed_result) + # Instantiate the module + module = instantiator.instantiate_namespace(parsed_result) - if module.name in modules: - modules[module. - name].content[0].content += module.content[0].content - else: - modules[module.name] = module + if module.name in modules: + modules[ + module.name].content[0].content += module.content[0].content + else: + modules[module.name] = module for module in modules.values(): # Wrap the full namespace diff --git a/wrap/gtwrap/pybind_wrapper.py b/wrap/gtwrap/pybind_wrapper.py index 809c69b56..1a3f10bf5 100755 --- a/wrap/gtwrap/pybind_wrapper.py +++ b/wrap/gtwrap/pybind_wrapper.py @@ -14,6 +14,7 @@ Author: Duy Nguyen Ta, Fan Jiang, Matthew Sklar, Varun Agrawal, and Frank Dellae import re from pathlib import Path +from typing import List import gtwrap.interface_parser as parser import gtwrap.template_instantiator as instantiator @@ -46,6 +47,11 @@ class PybindWrapper: # amount of indentation to add before each function/method declaration. self.method_indent = '\n' + (' ' * 8) + # Special methods which are leveraged by ipython/jupyter notebooks + self._ipython_special_methods = [ + "svg", "png", "jpeg", "html", "javascript", "markdown", "latex" + ] + def _py_args_names(self, args): """Set the argument names in Pybind11 format.""" names = args.names() @@ -86,34 +92,99 @@ class PybindWrapper: )) return res + def _wrap_serialization(self, cpp_class): + """Helper method to add serialize, deserialize and pickle methods to the wrapped class.""" + if not cpp_class in self._serializing_classes: + self._serializing_classes.append(cpp_class) + + serialize_method = self.method_indent + \ + ".def(\"serialize\", []({class_inst} self){{ return gtsam::serialize(*self); }})".format(class_inst=cpp_class + '*') + + deserialize_method = self.method_indent + \ + '.def("deserialize", []({class_inst} self, string serialized)' \ + '{{ gtsam::deserialize(serialized, *self); }}, py::arg("serialized"))' \ + .format(class_inst=cpp_class + '*') + + # Since this class supports serialization, we also add the pickle method. + pickle_method = self.method_indent + \ + ".def(py::pickle({indent} [](const {cpp_class} &a){{ /* __getstate__: Returns a string that encodes the state of the object */ return py::make_tuple(gtsam::serialize(a)); }},{indent} [](py::tuple t){{ /* __setstate__ */ {cpp_class} obj; gtsam::deserialize(t[0].cast(), obj); return obj; }}))" + + return serialize_method + deserialize_method + \ + pickle_method.format(cpp_class=cpp_class, indent=self.method_indent) + + def _wrap_print(self, ret: str, method: parser.Method, cpp_class: str, + args_names: List[str], args_signature_with_names: str, + py_args_names: str, prefix: str, suffix: str): + """ + Update the print method to print to the output stream and append a __repr__ method. + + Args: + ret (str): The result of the parser. + method (parser.Method): The method to be wrapped. + cpp_class (str): The C++ name of the class to which the method belongs. + args_names (List[str]): List of argument variable names passed to the method. + args_signature_with_names (str): C++ arguments containing their names and type signatures. + py_args_names (str): The pybind11 formatted version of the argument list. + prefix (str): Prefix to add to the wrapped method when writing to the cpp file. + suffix (str): Suffix to add to the wrapped method when writing to the cpp file. + + Returns: + str: The wrapped print method. + """ + # Redirect stdout - see pybind docs for why this is a good idea: + # https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html#capturing-standard-output-from-ostream + ret = ret.replace('self->print', + 'py::scoped_ostream_redirect output; self->print') + + # Make __repr__() call .print() internally + ret += '''{prefix}.def("__repr__", + [](const {cpp_class}& self{opt_comma}{args_signature_with_names}){{ + gtsam::RedirectCout redirect; + self.{method_name}({method_args}); + return redirect.str(); + }}{py_args_names}){suffix}'''.format( + prefix=prefix, + cpp_class=cpp_class, + opt_comma=', ' if args_names else '', + args_signature_with_names=args_signature_with_names, + method_name=method.name, + method_args=", ".join(args_names) if args_names else '', + py_args_names=py_args_names, + suffix=suffix) + return ret + def _wrap_method(self, method, cpp_class, prefix, suffix, method_suffix=""): + """ + Wrap the `method` for the class specified by `cpp_class`. + + Args: + method: The method to wrap. + cpp_class: The C++ name of the class to which the method belongs. + prefix: Prefix to add to the wrapped method when writing to the cpp file. + suffix: Suffix to add to the wrapped method when writing to the cpp file. + method_suffix: A string to append to the wrapped method name. + """ py_method = method.name + method_suffix cpp_method = method.to_cpp() - if cpp_method in ["serialize", "serializable"]: - if not cpp_class in self._serializing_classes: - self._serializing_classes.append(cpp_class) - serialize_method = self.method_indent + \ - ".def(\"serialize\", []({class_inst} self){{ return gtsam::serialize(*self); }})".format(class_inst=cpp_class + '*') - deserialize_method = self.method_indent + \ - '.def("deserialize", []({class_inst} self, string serialized)' \ - '{{ gtsam::deserialize(serialized, *self); }}, py::arg("serialized"))' \ - .format(class_inst=cpp_class + '*') - return serialize_method + deserialize_method + args_names = method.args.names() + py_args_names = self._py_args_names(method.args) + args_signature_with_names = self._method_args_signature(method.args) - if cpp_method == "pickle": - if not cpp_class in self._serializing_classes: - raise ValueError( - "Cannot pickle a class which is not serializable") - pickle_method = self.method_indent + \ - ".def(py::pickle({indent} [](const {cpp_class} &a){{ /* __getstate__: Returns a string that encodes the state of the object */ return py::make_tuple(gtsam::serialize(a)); }},{indent} [](py::tuple t){{ /* __setstate__ */ {cpp_class} obj; gtsam::deserialize(t[0].cast(), obj); return obj; }}))" - return pickle_method.format(cpp_class=cpp_class, - indent=self.method_indent) + # Special handling for the serialize/serializable method + if cpp_method in ["serialize", "serializable"]: + return self._wrap_serialization(cpp_class) + + # Special handling of ipython specific methods + # https://ipython.readthedocs.io/en/stable/config/integrating.html + if cpp_method in self._ipython_special_methods: + idx = self._ipython_special_methods.index(cpp_method) + py_method = f"_repr_{self._ipython_special_methods[idx]}_" # Add underscore to disambiguate if the method name matches a python keyword if py_method in self.python_keywords: @@ -125,9 +196,6 @@ class PybindWrapper: method, (parser.StaticMethod, instantiator.InstantiatedStaticMethod)) return_void = method.return_type.is_void() - args_names = method.args.names() - py_args_names = self._py_args_names(method.args) - args_signature_with_names = self._method_args_signature(method.args) caller = cpp_class + "::" if not is_method else "self->" function_call = ('{opt_return} {caller}{method_name}' @@ -158,27 +226,9 @@ class PybindWrapper: # Create __repr__ override # We allow all arguments to .print() and let the compiler handle type mismatches. if method.name == 'print': - # Redirect stdout - see pybind docs for why this is a good idea: - # https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html#capturing-standard-output-from-ostream - ret = ret.replace( - 'self->print', - 'py::scoped_ostream_redirect output; self->print') - - # Make __repr__() call .print() internally - ret += '''{prefix}.def("__repr__", - [](const {cpp_class}& self{opt_comma}{args_signature_with_names}){{ - gtsam::RedirectCout redirect; - self.{method_name}({method_args}); - return redirect.str(); - }}{py_args_names}){suffix}'''.format( - prefix=prefix, - cpp_class=cpp_class, - opt_comma=', ' if args_names else '', - args_signature_with_names=args_signature_with_names, - method_name=method.name, - method_args=", ".join(args_names) if args_names else '', - py_args_names=py_args_names, - suffix=suffix) + ret = self._wrap_print(ret, method, cpp_class, args_names, + args_signature_with_names, py_args_names, + prefix, suffix) return ret @@ -624,28 +674,47 @@ class PybindWrapper: submodules_init="\n".join(submodules_init), ) - def wrap(self, sources, main_output): + def wrap_submodule(self, source): """ - Wrap all the source interface files. + Wrap a list of submodule files, i.e. a set of interface files which are + in support of a larger wrapping project. + + E.g. This is used in GTSAM where we have a main gtsam.i, but various smaller .i files + which are the submodules. + The benefit of this scheme is that it reduces compute and memory usage during compilation. + + Args: + source: Interface file which forms the submodule. + """ + filename = Path(source).name + module_name = Path(source).stem + + # Read in the complete interface (.i) file + with open(source, "r") as f: + content = f.read() + # Wrap the read-in content + cc_content = self.wrap_file(content, module_name=module_name) + + # Generate the C++ code which Pybind11 will use. + with open(filename.replace(".i", ".cpp"), "w") as f: + f.write(cc_content) + + def wrap(self, sources, main_module_name): + """ + Wrap all the main interface file. Args: sources: List of all interface files. - main_output: The name for the main module. + The first file should be the main module. + main_module_name: The name for the main module. """ main_module = sources[0] + + # Get all the submodule names. submodules = [] for source in sources[1:]: - filename = Path(source).name module_name = Path(source).stem - # Read in the complete interface (.i) file - with open(source, "r") as f: - content = f.read() submodules.append(module_name) - cc_content = self.wrap_file(content, module_name=module_name) - - # Generate the C++ code which Pybind11 will use. - with open(filename.replace(".i", ".cpp"), "w") as f: - f.write(cc_content) with open(main_module, "r") as f: content = f.read() @@ -654,5 +723,5 @@ class PybindWrapper: submodules=submodules) # Generate the C++ code which Pybind11 will use. - with open(main_output, "w") as f: + with open(main_module_name, "w") as f: f.write(cc_content) diff --git a/wrap/gtwrap/template_instantiator/helpers.py b/wrap/gtwrap/template_instantiator/helpers.py index b55515dba..194c6f686 100644 --- a/wrap/gtwrap/template_instantiator/helpers.py +++ b/wrap/gtwrap/template_instantiator/helpers.py @@ -55,16 +55,14 @@ def instantiate_type( # make a deep copy so that there is no overwriting of original template params ctype = deepcopy(ctype) - # Check if the return type has template parameters + # Check if the return type has template parameters as the typename's name if ctype.typename.instantiations: for idx, instantiation in enumerate(ctype.typename.instantiations): if instantiation.name in template_typenames: template_idx = template_typenames.index(instantiation.name) - ctype.typename.instantiations[ - idx] = instantiations[ # type: ignore - template_idx] + ctype.typename.instantiations[idx].name =\ + instantiations[template_idx] - return ctype str_arg_typename = str(ctype.typename) @@ -125,9 +123,18 @@ def instantiate_type( # Case when 'This' is present in the type namespace, e.g `This::Subclass`. elif 'This' in str_arg_typename: - # Simply get the index of `This` in the namespace and replace it with the instantiated name. - namespace_idx = ctype.typename.namespaces.index('This') - ctype.typename.namespaces[namespace_idx] = cpp_typename.name + # Check if `This` is in the namespaces + if 'This' in ctype.typename.namespaces: + # Simply get the index of `This` in the namespace and + # replace it with the instantiated name. + namespace_idx = ctype.typename.namespaces.index('This') + ctype.typename.namespaces[namespace_idx] = cpp_typename.name + # Else check if it is in the template namespace, e.g vector + else: + for idx, instantiation in enumerate(ctype.typename.instantiations): + if 'This' in instantiation.namespaces: + ctype.typename.instantiations[idx].namespaces = \ + cpp_typename.namespaces + [cpp_typename.name] return ctype else: diff --git a/wrap/scripts/pybind_wrap.py b/wrap/scripts/pybind_wrap.py index c82a1d24c..577060243 100644 --- a/wrap/scripts/pybind_wrap.py +++ b/wrap/scripts/pybind_wrap.py @@ -19,7 +19,7 @@ def main(): arg_parser.add_argument("--src", type=str, required=True, - help="Input interface .i/.h file") + help="Input interface .i/.h file(s)") arg_parser.add_argument( "--module_name", type=str, @@ -31,7 +31,7 @@ def main(): "--out", type=str, required=True, - help="Name of the output pybind .cc file", + help="Name of the output pybind .cc file(s)", ) arg_parser.add_argument( "--use-boost", @@ -60,7 +60,10 @@ def main(): ) arg_parser.add_argument("--template", type=str, - help="The module template file") + help="The module template file (e.g. module.tpl).") + arg_parser.add_argument("--is_submodule", + default=False, + action="store_true") args = arg_parser.parse_args() top_module_namespaces = args.top_module_namespaces.split("::") @@ -78,9 +81,13 @@ def main(): module_template=template_content, ) - # Wrap the code and get back the cpp/cc code. - sources = args.src.split(';') - wrapper.wrap(sources, args.out) + if args.is_submodule: + wrapper.wrap_submodule(args.src) + + else: + # Wrap the code and get back the cpp/cc code. + sources = args.src.split(';') + wrapper.wrap(sources, args.out) if __name__ == "__main__": diff --git a/wrap/tests/actual/.gitignore b/wrap/tests/actual/.gitignore new file mode 100644 index 000000000..7e0359a99 --- /dev/null +++ b/wrap/tests/actual/.gitignore @@ -0,0 +1,2 @@ +./* +!.gitignore diff --git a/wrap/tests/expected/matlab/+gtsam/GeneralSFMFactorCal3Bundler.m b/wrap/tests/expected/matlab/+gtsam/GeneralSFMFactorCal3Bundler.m new file mode 100644 index 000000000..0ce4051af --- /dev/null +++ b/wrap/tests/expected/matlab/+gtsam/GeneralSFMFactorCal3Bundler.m @@ -0,0 +1,31 @@ +%class GeneralSFMFactorCal3Bundler, see Doxygen page for details +%at https://gtsam.org/doxygen/ +% +classdef GeneralSFMFactorCal3Bundler < handle + properties + ptr_gtsamGeneralSFMFactorCal3Bundler = 0 + end + methods + function obj = GeneralSFMFactorCal3Bundler(varargin) + if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) + my_ptr = varargin{2}; + special_cases_wrapper(7, my_ptr); + else + error('Arguments do not match any overload of gtsam.GeneralSFMFactorCal3Bundler constructor'); + end + obj.ptr_gtsamGeneralSFMFactorCal3Bundler = my_ptr; + end + + function delete(obj) + special_cases_wrapper(8, obj.ptr_gtsamGeneralSFMFactorCal3Bundler); + end + + function display(obj), obj.print(''); end + %DISPLAY Calls print on the object + function disp(obj), obj.display; end + %DISP Calls print on the object + end + + methods(Static = true) + end +end diff --git a/wrap/tests/expected/matlab/+gtsam/Point3.m b/wrap/tests/expected/matlab/+gtsam/Point3.m index 06d378ac2..b3290faf2 100644 --- a/wrap/tests/expected/matlab/+gtsam/Point3.m +++ b/wrap/tests/expected/matlab/+gtsam/Point3.m @@ -78,7 +78,7 @@ classdef Point3 < handle error('Arguments do not match any overload of function Point3.StaticFunctionRet'); end - function varargout = StaticFunction(varargin) + function varargout = staticFunction(varargin) % STATICFUNCTION usage: staticFunction() : returns double % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 0 diff --git a/wrap/tests/expected/matlab/+gtsam/SfmTrack.m b/wrap/tests/expected/matlab/+gtsam/SfmTrack.m new file mode 100644 index 000000000..428da2706 --- /dev/null +++ b/wrap/tests/expected/matlab/+gtsam/SfmTrack.m @@ -0,0 +1,31 @@ +%class SfmTrack, see Doxygen page for details +%at https://gtsam.org/doxygen/ +% +classdef SfmTrack < handle + properties + ptr_gtsamSfmTrack = 0 + end + methods + function obj = SfmTrack(varargin) + if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) + my_ptr = varargin{2}; + special_cases_wrapper(3, my_ptr); + else + error('Arguments do not match any overload of gtsam.SfmTrack constructor'); + end + obj.ptr_gtsamSfmTrack = my_ptr; + end + + function delete(obj) + special_cases_wrapper(4, obj.ptr_gtsamSfmTrack); + end + + function display(obj), obj.print(''); end + %DISPLAY Calls print on the object + function disp(obj), obj.display; end + %DISP Calls print on the object + end + + methods(Static = true) + end +end diff --git a/wrap/tests/expected/matlab/+gtsam/Values.m b/wrap/tests/expected/matlab/+gtsam/Values.m new file mode 100644 index 000000000..d85b24b91 --- /dev/null +++ b/wrap/tests/expected/matlab/+gtsam/Values.m @@ -0,0 +1,59 @@ +%class Values, see Doxygen page for details +%at https://gtsam.org/doxygen/ +% +%-------Constructors------- +%Values() +%Values(Values other) +% +%-------Methods------- +%insert(size_t j, Vector vector) : returns void +%insert(size_t j, Matrix matrix) : returns void +% +classdef Values < handle + properties + ptr_gtsamValues = 0 + end + methods + function obj = Values(varargin) + if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) + my_ptr = varargin{2}; + namespaces_wrapper(26, my_ptr); + elseif nargin == 0 + my_ptr = namespaces_wrapper(27); + elseif nargin == 1 && isa(varargin{1},'gtsam.Values') + my_ptr = namespaces_wrapper(28, varargin{1}); + else + error('Arguments do not match any overload of gtsam.Values constructor'); + end + obj.ptr_gtsamValues = my_ptr; + end + + function delete(obj) + namespaces_wrapper(29, obj.ptr_gtsamValues); + end + + function display(obj), obj.print(''); end + %DISPLAY Calls print on the object + function disp(obj), obj.display; end + %DISP Calls print on the object + function varargout = insert(this, varargin) + % INSERT usage: insert(size_t j, Vector vector) : returns void + % Doxygen can be found at https://gtsam.org/doxygen/ + if length(varargin) == 2 && isa(varargin{1},'numeric') && isa(varargin{2},'double') && size(varargin{2},2)==1 + namespaces_wrapper(30, this, varargin{:}); + return + end + % INSERT usage: insert(size_t j, Matrix matrix) : returns void + % Doxygen can be found at https://gtsam.org/doxygen/ + if length(varargin) == 2 && isa(varargin{1},'numeric') && isa(varargin{2},'double') + namespaces_wrapper(31, this, varargin{:}); + return + end + error('Arguments do not match any overload of function gtsam.Values.insert'); + end + + end + + methods(Static = true) + end +end diff --git a/wrap/tests/expected/matlab/+ns2/ClassA.m b/wrap/tests/expected/matlab/+ns2/ClassA.m index 4640e7cca..71718ccba 100644 --- a/wrap/tests/expected/matlab/+ns2/ClassA.m +++ b/wrap/tests/expected/matlab/+ns2/ClassA.m @@ -74,7 +74,7 @@ classdef ClassA < handle end methods(Static = true) - function varargout = Afunction(varargin) + function varargout = afunction(varargin) % AFUNCTION usage: afunction() : returns double % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 0 diff --git a/wrap/tests/expected/matlab/DefaultFuncInt.m b/wrap/tests/expected/matlab/DefaultFuncInt.m new file mode 100644 index 000000000..6c9c4116b --- /dev/null +++ b/wrap/tests/expected/matlab/DefaultFuncInt.m @@ -0,0 +1,10 @@ +function varargout = DefaultFuncInt(varargin) + if length(varargin) == 2 && isa(varargin{1},'numeric') && isa(varargin{2},'numeric') + functions_wrapper(8, varargin{:}); + elseif length(varargin) == 1 && isa(varargin{1},'numeric') + functions_wrapper(9, varargin{:}); + elseif length(varargin) == 0 + functions_wrapper(10, varargin{:}); + else + error('Arguments do not match any overload of function DefaultFuncInt'); + end diff --git a/wrap/tests/expected/matlab/DefaultFuncObj.m b/wrap/tests/expected/matlab/DefaultFuncObj.m new file mode 100644 index 000000000..15d9ba0fa --- /dev/null +++ b/wrap/tests/expected/matlab/DefaultFuncObj.m @@ -0,0 +1,8 @@ +function varargout = DefaultFuncObj(varargin) + if length(varargin) == 1 && isa(varargin{1},'gtsam.KeyFormatter') + functions_wrapper(14, varargin{:}); + elseif length(varargin) == 0 + functions_wrapper(15, varargin{:}); + else + error('Arguments do not match any overload of function DefaultFuncObj'); + end diff --git a/wrap/tests/expected/matlab/DefaultFuncString.m b/wrap/tests/expected/matlab/DefaultFuncString.m new file mode 100644 index 000000000..d26201c15 --- /dev/null +++ b/wrap/tests/expected/matlab/DefaultFuncString.m @@ -0,0 +1,10 @@ +function varargout = DefaultFuncString(varargin) + if length(varargin) == 2 && isa(varargin{1},'char') && isa(varargin{2},'char') + functions_wrapper(11, varargin{:}); + elseif length(varargin) == 1 && isa(varargin{1},'char') + functions_wrapper(12, varargin{:}); + elseif length(varargin) == 0 + functions_wrapper(13, varargin{:}); + else + error('Arguments do not match any overload of function DefaultFuncString'); + end diff --git a/wrap/tests/expected/matlab/DefaultFuncVector.m b/wrap/tests/expected/matlab/DefaultFuncVector.m new file mode 100644 index 000000000..344533fad --- /dev/null +++ b/wrap/tests/expected/matlab/DefaultFuncVector.m @@ -0,0 +1,10 @@ +function varargout = DefaultFuncVector(varargin) + if length(varargin) == 2 && isa(varargin{1},'std.vectornumeric') && isa(varargin{2},'std.vectorchar') + functions_wrapper(20, varargin{:}); + elseif length(varargin) == 1 && isa(varargin{1},'std.vectornumeric') + functions_wrapper(21, varargin{:}); + elseif length(varargin) == 0 + functions_wrapper(22, varargin{:}); + else + error('Arguments do not match any overload of function DefaultFuncVector'); + end diff --git a/wrap/tests/expected/matlab/DefaultFuncZero.m b/wrap/tests/expected/matlab/DefaultFuncZero.m new file mode 100644 index 000000000..0ebba2e5c --- /dev/null +++ b/wrap/tests/expected/matlab/DefaultFuncZero.m @@ -0,0 +1,12 @@ +function varargout = DefaultFuncZero(varargin) + if length(varargin) == 5 && isa(varargin{1},'numeric') && isa(varargin{2},'numeric') && isa(varargin{3},'double') && isa(varargin{4},'numeric') && isa(varargin{5},'logical') + functions_wrapper(16, varargin{:}); + elseif length(varargin) == 4 && isa(varargin{1},'numeric') && isa(varargin{2},'numeric') && isa(varargin{3},'double') && isa(varargin{4},'numeric') + functions_wrapper(17, varargin{:}); + elseif length(varargin) == 3 && isa(varargin{1},'numeric') && isa(varargin{2},'numeric') && isa(varargin{3},'double') + functions_wrapper(18, varargin{:}); + elseif length(varargin) == 2 && isa(varargin{1},'numeric') && isa(varargin{2},'numeric') + functions_wrapper(19, varargin{:}); + else + error('Arguments do not match any overload of function DefaultFuncZero'); + end diff --git a/wrap/tests/expected/matlab/ForwardKinematics.m b/wrap/tests/expected/matlab/ForwardKinematics.m new file mode 100644 index 000000000..c2ff701c7 --- /dev/null +++ b/wrap/tests/expected/matlab/ForwardKinematics.m @@ -0,0 +1,38 @@ +%class ForwardKinematics, see Doxygen page for details +%at https://gtsam.org/doxygen/ +% +%-------Constructors------- +%ForwardKinematics(Robot robot, string start_link_name, string end_link_name, Values joint_angles, Pose3 l2Tp) +% +classdef ForwardKinematics < handle + properties + ptr_ForwardKinematics = 0 + end + methods + function obj = ForwardKinematics(varargin) + if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) + my_ptr = varargin{2}; + class_wrapper(57, my_ptr); + elseif nargin == 5 && isa(varargin{1},'gtdynamics.Robot') && isa(varargin{2},'char') && isa(varargin{3},'char') && isa(varargin{4},'gtsam.Values') && isa(varargin{5},'gtsam.Pose3') + my_ptr = class_wrapper(58, varargin{1}, varargin{2}, varargin{3}, varargin{4}, varargin{5}); + elseif nargin == 4 && isa(varargin{1},'gtdynamics.Robot') && isa(varargin{2},'char') && isa(varargin{3},'char') && isa(varargin{4},'gtsam.Values') + my_ptr = class_wrapper(59, varargin{1}, varargin{2}, varargin{3}, varargin{4}); + else + error('Arguments do not match any overload of ForwardKinematics constructor'); + end + obj.ptr_ForwardKinematics = my_ptr; + end + + function delete(obj) + class_wrapper(60, obj.ptr_ForwardKinematics); + end + + function display(obj), obj.print(''); end + %DISPLAY Calls print on the object + function disp(obj), obj.display; end + %DISP Calls print on the object + end + + methods(Static = true) + end +end diff --git a/wrap/tests/expected/matlab/ForwardKinematicsFactor.m b/wrap/tests/expected/matlab/ForwardKinematicsFactor.m new file mode 100644 index 000000000..46aa41392 --- /dev/null +++ b/wrap/tests/expected/matlab/ForwardKinematicsFactor.m @@ -0,0 +1,36 @@ +%class ForwardKinematicsFactor, see Doxygen page for details +%at https://gtsam.org/doxygen/ +% +classdef ForwardKinematicsFactor < gtsam.BetweenFactor + properties + ptr_ForwardKinematicsFactor = 0 + end + methods + function obj = ForwardKinematicsFactor(varargin) + if (nargin == 2 || (nargin == 3 && strcmp(varargin{3}, 'void'))) && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) + if nargin == 2 + my_ptr = varargin{2}; + else + my_ptr = inheritance_wrapper(36, varargin{2}); + end + base_ptr = inheritance_wrapper(35, my_ptr); + else + error('Arguments do not match any overload of ForwardKinematicsFactor constructor'); + end + obj = obj@gtsam.BetweenFactorPose3(uint64(5139824614673773682), base_ptr); + obj.ptr_ForwardKinematicsFactor = my_ptr; + end + + function delete(obj) + inheritance_wrapper(37, obj.ptr_ForwardKinematicsFactor); + end + + function display(obj), obj.print(''); end + %DISPLAY Calls print on the object + function disp(obj), obj.display; end + %DISP Calls print on the object + end + + methods(Static = true) + end +end diff --git a/wrap/tests/expected/matlab/FunDouble.m b/wrap/tests/expected/matlab/FunDouble.m index 78609c7f6..5f432341b 100644 --- a/wrap/tests/expected/matlab/FunDouble.m +++ b/wrap/tests/expected/matlab/FunDouble.m @@ -3,6 +3,7 @@ % %-------Methods------- %multiTemplatedMethodStringSize_t(double d, string t, size_t u) : returns Fun +%sets() : returns std::map::double> %templatedMethodString(double d, string t) : returns Fun % %-------Static Methods------- @@ -46,11 +47,21 @@ classdef FunDouble < handle error('Arguments do not match any overload of function FunDouble.multiTemplatedMethodStringSize_t'); end + function varargout = sets(this, varargin) + % SETS usage: sets() : returns std.mapdoubledouble + % Doxygen can be found at https://gtsam.org/doxygen/ + if length(varargin) == 0 + varargout{1} = class_wrapper(8, this, varargin{:}); + return + end + error('Arguments do not match any overload of function FunDouble.sets'); + end + function varargout = templatedMethodString(this, varargin) % TEMPLATEDMETHODSTRING usage: templatedMethodString(double d, string t) : returns Fun % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 2 && isa(varargin{1},'double') && isa(varargin{2},'char') - varargout{1} = class_wrapper(8, this, varargin{:}); + varargout{1} = class_wrapper(9, this, varargin{:}); return end error('Arguments do not match any overload of function FunDouble.templatedMethodString'); @@ -59,22 +70,22 @@ classdef FunDouble < handle end methods(Static = true) - function varargout = StaticMethodWithThis(varargin) + function varargout = staticMethodWithThis(varargin) % STATICMETHODWITHTHIS usage: staticMethodWithThis() : returns Fundouble % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 0 - varargout{1} = class_wrapper(9, varargin{:}); + varargout{1} = class_wrapper(10, varargin{:}); return end error('Arguments do not match any overload of function FunDouble.staticMethodWithThis'); end - function varargout = TemplatedStaticMethodInt(varargin) + function varargout = templatedStaticMethodInt(varargin) % TEMPLATEDSTATICMETHODINT usage: templatedStaticMethodInt(int m) : returns double % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'numeric') - varargout{1} = class_wrapper(10, varargin{:}); + varargout{1} = class_wrapper(11, varargin{:}); return end diff --git a/wrap/tests/expected/matlab/FunRange.m b/wrap/tests/expected/matlab/FunRange.m index 1d1a6f7b8..52ee78aa2 100644 --- a/wrap/tests/expected/matlab/FunRange.m +++ b/wrap/tests/expected/matlab/FunRange.m @@ -52,7 +52,7 @@ classdef FunRange < handle end methods(Static = true) - function varargout = Create(varargin) + function varargout = create(varargin) % CREATE usage: create() : returns FunRange % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 0 diff --git a/wrap/tests/expected/matlab/MultipleTemplatesIntDouble.m b/wrap/tests/expected/matlab/MultipleTemplatesIntDouble.m index 863d30ee8..ebf263bcb 100644 --- a/wrap/tests/expected/matlab/MultipleTemplatesIntDouble.m +++ b/wrap/tests/expected/matlab/MultipleTemplatesIntDouble.m @@ -9,7 +9,7 @@ classdef MultipleTemplatesIntDouble < handle function obj = MultipleTemplatesIntDouble(varargin) if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) my_ptr = varargin{2}; - class_wrapper(50, my_ptr); + class_wrapper(53, my_ptr); else error('Arguments do not match any overload of MultipleTemplatesIntDouble constructor'); end @@ -17,7 +17,7 @@ classdef MultipleTemplatesIntDouble < handle end function delete(obj) - class_wrapper(51, obj.ptr_MultipleTemplatesIntDouble); + class_wrapper(54, obj.ptr_MultipleTemplatesIntDouble); end function display(obj), obj.print(''); end diff --git a/wrap/tests/expected/matlab/MultipleTemplatesIntFloat.m b/wrap/tests/expected/matlab/MultipleTemplatesIntFloat.m index b7f1fdac5..02290f032 100644 --- a/wrap/tests/expected/matlab/MultipleTemplatesIntFloat.m +++ b/wrap/tests/expected/matlab/MultipleTemplatesIntFloat.m @@ -9,7 +9,7 @@ classdef MultipleTemplatesIntFloat < handle function obj = MultipleTemplatesIntFloat(varargin) if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) my_ptr = varargin{2}; - class_wrapper(52, my_ptr); + class_wrapper(55, my_ptr); else error('Arguments do not match any overload of MultipleTemplatesIntFloat constructor'); end @@ -17,7 +17,7 @@ classdef MultipleTemplatesIntFloat < handle end function delete(obj) - class_wrapper(53, obj.ptr_MultipleTemplatesIntFloat); + class_wrapper(56, obj.ptr_MultipleTemplatesIntFloat); end function display(obj), obj.print(''); end diff --git a/wrap/tests/expected/matlab/MyFactorPosePoint2.m b/wrap/tests/expected/matlab/MyFactorPosePoint2.m index 7634ae2cb..7457fe749 100644 --- a/wrap/tests/expected/matlab/MyFactorPosePoint2.m +++ b/wrap/tests/expected/matlab/MyFactorPosePoint2.m @@ -15,9 +15,9 @@ classdef MyFactorPosePoint2 < handle function obj = MyFactorPosePoint2(varargin) if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) my_ptr = varargin{2}; - class_wrapper(63, my_ptr); + class_wrapper(67, my_ptr); elseif nargin == 4 && isa(varargin{1},'numeric') && isa(varargin{2},'numeric') && isa(varargin{3},'double') && isa(varargin{4},'gtsam.noiseModel.Base') - my_ptr = class_wrapper(64, varargin{1}, varargin{2}, varargin{3}, varargin{4}); + my_ptr = class_wrapper(68, varargin{1}, varargin{2}, varargin{3}, varargin{4}); else error('Arguments do not match any overload of MyFactorPosePoint2 constructor'); end @@ -25,7 +25,7 @@ classdef MyFactorPosePoint2 < handle end function delete(obj) - class_wrapper(65, obj.ptr_MyFactorPosePoint2); + class_wrapper(69, obj.ptr_MyFactorPosePoint2); end function display(obj), obj.print(''); end @@ -36,7 +36,19 @@ classdef MyFactorPosePoint2 < handle % PRINT usage: print(string s, KeyFormatter keyFormatter) : returns void % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 2 && isa(varargin{1},'char') && isa(varargin{2},'gtsam.KeyFormatter') - class_wrapper(66, this, varargin{:}); + class_wrapper(70, this, varargin{:}); + return + end + % PRINT usage: print(string s) : returns void + % Doxygen can be found at https://gtsam.org/doxygen/ + if length(varargin) == 1 && isa(varargin{1},'char') + class_wrapper(71, this, varargin{:}); + return + end + % PRINT usage: print() : returns void + % Doxygen can be found at https://gtsam.org/doxygen/ + if length(varargin) == 0 + class_wrapper(72, this, varargin{:}); return end error('Arguments do not match any overload of function MyFactorPosePoint2.print'); diff --git a/wrap/tests/expected/matlab/MyVector12.m b/wrap/tests/expected/matlab/MyVector12.m index 291d0d71b..53e554a10 100644 --- a/wrap/tests/expected/matlab/MyVector12.m +++ b/wrap/tests/expected/matlab/MyVector12.m @@ -12,9 +12,9 @@ classdef MyVector12 < handle function obj = MyVector12(varargin) if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) my_ptr = varargin{2}; - class_wrapper(47, my_ptr); + class_wrapper(50, my_ptr); elseif nargin == 0 - my_ptr = class_wrapper(48); + my_ptr = class_wrapper(51); else error('Arguments do not match any overload of MyVector12 constructor'); end @@ -22,7 +22,7 @@ classdef MyVector12 < handle end function delete(obj) - class_wrapper(49, obj.ptr_MyVector12); + class_wrapper(52, obj.ptr_MyVector12); end function display(obj), obj.print(''); end diff --git a/wrap/tests/expected/matlab/MyVector3.m b/wrap/tests/expected/matlab/MyVector3.m index 3051dc2e2..0f6ea84ab 100644 --- a/wrap/tests/expected/matlab/MyVector3.m +++ b/wrap/tests/expected/matlab/MyVector3.m @@ -12,9 +12,9 @@ classdef MyVector3 < handle function obj = MyVector3(varargin) if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) my_ptr = varargin{2}; - class_wrapper(44, my_ptr); + class_wrapper(47, my_ptr); elseif nargin == 0 - my_ptr = class_wrapper(45); + my_ptr = class_wrapper(48); else error('Arguments do not match any overload of MyVector3 constructor'); end @@ -22,7 +22,7 @@ classdef MyVector3 < handle end function delete(obj) - class_wrapper(46, obj.ptr_MyVector3); + class_wrapper(49, obj.ptr_MyVector3); end function display(obj), obj.print(''); end diff --git a/wrap/tests/expected/matlab/PrimitiveRefDouble.m b/wrap/tests/expected/matlab/PrimitiveRefDouble.m index dd0a6d2da..e1039e567 100644 --- a/wrap/tests/expected/matlab/PrimitiveRefDouble.m +++ b/wrap/tests/expected/matlab/PrimitiveRefDouble.m @@ -19,9 +19,9 @@ classdef PrimitiveRefDouble < handle function obj = PrimitiveRefDouble(varargin) if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) my_ptr = varargin{2}; - class_wrapper(40, my_ptr); + class_wrapper(43, my_ptr); elseif nargin == 0 - my_ptr = class_wrapper(41); + my_ptr = class_wrapper(44); else error('Arguments do not match any overload of PrimitiveRefDouble constructor'); end @@ -29,7 +29,7 @@ classdef PrimitiveRefDouble < handle end function delete(obj) - class_wrapper(42, obj.ptr_PrimitiveRefDouble); + class_wrapper(45, obj.ptr_PrimitiveRefDouble); end function display(obj), obj.print(''); end @@ -43,7 +43,7 @@ classdef PrimitiveRefDouble < handle % BRUTAL usage: Brutal(double t) : returns PrimitiveRefdouble % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'double') - varargout{1} = class_wrapper(43, varargin{:}); + varargout{1} = class_wrapper(46, varargin{:}); return end diff --git a/wrap/tests/expected/matlab/ScopedTemplateResult.m b/wrap/tests/expected/matlab/ScopedTemplateResult.m new file mode 100644 index 000000000..8cb8ed7d0 --- /dev/null +++ b/wrap/tests/expected/matlab/ScopedTemplateResult.m @@ -0,0 +1,36 @@ +%class ScopedTemplateResult, see Doxygen page for details +%at https://gtsam.org/doxygen/ +% +%-------Constructors------- +%ScopedTemplateResult(Result::Value arg) +% +classdef ScopedTemplateResult < handle + properties + ptr_ScopedTemplateResult = 0 + end + methods + function obj = ScopedTemplateResult(varargin) + if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) + my_ptr = varargin{2}; + template_wrapper(6, my_ptr); + elseif nargin == 1 && isa(varargin{1},'Result::Value') + my_ptr = template_wrapper(7, varargin{1}); + else + error('Arguments do not match any overload of ScopedTemplateResult constructor'); + end + obj.ptr_ScopedTemplateResult = my_ptr; + end + + function delete(obj) + template_wrapper(8, obj.ptr_ScopedTemplateResult); + end + + function display(obj), obj.print(''); end + %DISPLAY Calls print on the object + function disp(obj), obj.display; end + %DISP Calls print on the object + end + + methods(Static = true) + end +end diff --git a/wrap/tests/expected/matlab/TemplatedConstructor.m b/wrap/tests/expected/matlab/TemplatedConstructor.m new file mode 100644 index 000000000..70beb26ce --- /dev/null +++ b/wrap/tests/expected/matlab/TemplatedConstructor.m @@ -0,0 +1,45 @@ +%class TemplatedConstructor, see Doxygen page for details +%at https://gtsam.org/doxygen/ +% +%-------Constructors------- +%TemplatedConstructor() +%TemplatedConstructor(string arg) +%TemplatedConstructor(int arg) +%TemplatedConstructor(double arg) +% +classdef TemplatedConstructor < handle + properties + ptr_TemplatedConstructor = 0 + end + methods + function obj = TemplatedConstructor(varargin) + if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) + my_ptr = varargin{2}; + template_wrapper(0, my_ptr); + elseif nargin == 0 + my_ptr = template_wrapper(1); + elseif nargin == 1 && isa(varargin{1},'char') + my_ptr = template_wrapper(2, varargin{1}); + elseif nargin == 1 && isa(varargin{1},'numeric') + my_ptr = template_wrapper(3, varargin{1}); + elseif nargin == 1 && isa(varargin{1},'double') + my_ptr = template_wrapper(4, varargin{1}); + else + error('Arguments do not match any overload of TemplatedConstructor constructor'); + end + obj.ptr_TemplatedConstructor = my_ptr; + end + + function delete(obj) + template_wrapper(5, obj.ptr_TemplatedConstructor); + end + + function display(obj), obj.print(''); end + %DISPLAY Calls print on the object + function disp(obj), obj.display; end + %DISP Calls print on the object + end + + methods(Static = true) + end +end diff --git a/wrap/tests/expected/matlab/TemplatedFunctionRot3.m b/wrap/tests/expected/matlab/TemplatedFunctionRot3.m index 4216201b4..eb5cb4abe 100644 --- a/wrap/tests/expected/matlab/TemplatedFunctionRot3.m +++ b/wrap/tests/expected/matlab/TemplatedFunctionRot3.m @@ -1,6 +1,6 @@ function varargout = TemplatedFunctionRot3(varargin) if length(varargin) == 1 && isa(varargin{1},'gtsam.Rot3') - functions_wrapper(14, varargin{:}); + functions_wrapper(25, varargin{:}); else error('Arguments do not match any overload of function TemplatedFunctionRot3'); end diff --git a/wrap/tests/expected/matlab/Test.m b/wrap/tests/expected/matlab/Test.m index 8569120c5..66ba4721c 100644 --- a/wrap/tests/expected/matlab/Test.m +++ b/wrap/tests/expected/matlab/Test.m @@ -11,6 +11,7 @@ %create_ptrs() : returns pair< Test, Test > %get_container() : returns std::vector %lambda() : returns void +%markdown(KeyFormatter keyFormatter) : returns string %print() : returns void %return_Point2Ptr(bool value) : returns Point2 %return_Test(Test value) : returns Test @@ -40,11 +41,11 @@ classdef Test < handle function obj = Test(varargin) if nargin == 2 && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682) my_ptr = varargin{2}; - class_wrapper(11, my_ptr); + class_wrapper(12, my_ptr); elseif nargin == 0 - my_ptr = class_wrapper(12); + my_ptr = class_wrapper(13); elseif nargin == 2 && isa(varargin{1},'double') && isa(varargin{2},'double') - my_ptr = class_wrapper(13, varargin{1}, varargin{2}); + my_ptr = class_wrapper(14, varargin{1}, varargin{2}); else error('Arguments do not match any overload of Test constructor'); end @@ -52,7 +53,7 @@ classdef Test < handle end function delete(obj) - class_wrapper(14, obj.ptr_Test); + class_wrapper(15, obj.ptr_Test); end function display(obj), obj.print(''); end @@ -63,7 +64,7 @@ classdef Test < handle % ARG_EIGENCONSTREF usage: arg_EigenConstRef(Matrix value) : returns void % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'double') - class_wrapper(15, this, varargin{:}); + class_wrapper(16, this, varargin{:}); return end error('Arguments do not match any overload of function Test.arg_EigenConstRef'); @@ -73,7 +74,7 @@ classdef Test < handle % CREATE_MIXEDPTRS usage: create_MixedPtrs() : returns pair< Test, Test > % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 0 - [ varargout{1} varargout{2} ] = class_wrapper(16, this, varargin{:}); + [ varargout{1} varargout{2} ] = class_wrapper(17, this, varargin{:}); return end error('Arguments do not match any overload of function Test.create_MixedPtrs'); @@ -83,7 +84,7 @@ classdef Test < handle % CREATE_PTRS usage: create_ptrs() : returns pair< Test, Test > % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 0 - [ varargout{1} varargout{2} ] = class_wrapper(17, this, varargin{:}); + [ varargout{1} varargout{2} ] = class_wrapper(18, this, varargin{:}); return end error('Arguments do not match any overload of function Test.create_ptrs'); @@ -93,7 +94,7 @@ classdef Test < handle % GET_CONTAINER usage: get_container() : returns std.vectorTest % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 0 - varargout{1} = class_wrapper(18, this, varargin{:}); + varargout{1} = class_wrapper(19, this, varargin{:}); return end error('Arguments do not match any overload of function Test.get_container'); @@ -103,17 +104,33 @@ classdef Test < handle % LAMBDA usage: lambda() : returns void % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 0 - class_wrapper(19, this, varargin{:}); + class_wrapper(20, this, varargin{:}); return end error('Arguments do not match any overload of function Test.lambda'); end + function varargout = markdown(this, varargin) + % MARKDOWN usage: markdown(KeyFormatter keyFormatter) : returns string + % Doxygen can be found at https://gtsam.org/doxygen/ + if length(varargin) == 1 && isa(varargin{1},'gtsam.KeyFormatter') + varargout{1} = class_wrapper(21, this, varargin{:}); + return + end + % MARKDOWN usage: markdown() : returns string + % Doxygen can be found at https://gtsam.org/doxygen/ + if length(varargin) == 0 + varargout{1} = class_wrapper(22, this, varargin{:}); + return + end + error('Arguments do not match any overload of function Test.markdown'); + end + function varargout = print(this, varargin) % PRINT usage: print() : returns void % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 0 - class_wrapper(20, this, varargin{:}); + class_wrapper(23, this, varargin{:}); return end error('Arguments do not match any overload of function Test.print'); @@ -123,7 +140,7 @@ classdef Test < handle % RETURN_POINT2PTR usage: return_Point2Ptr(bool value) : returns Point2 % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'logical') - varargout{1} = class_wrapper(21, this, varargin{:}); + varargout{1} = class_wrapper(24, this, varargin{:}); return end error('Arguments do not match any overload of function Test.return_Point2Ptr'); @@ -133,7 +150,7 @@ classdef Test < handle % RETURN_TEST usage: return_Test(Test value) : returns Test % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'Test') - varargout{1} = class_wrapper(22, this, varargin{:}); + varargout{1} = class_wrapper(25, this, varargin{:}); return end error('Arguments do not match any overload of function Test.return_Test'); @@ -143,7 +160,7 @@ classdef Test < handle % RETURN_TESTPTR usage: return_TestPtr(Test value) : returns Test % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'Test') - varargout{1} = class_wrapper(23, this, varargin{:}); + varargout{1} = class_wrapper(26, this, varargin{:}); return end error('Arguments do not match any overload of function Test.return_TestPtr'); @@ -153,7 +170,7 @@ classdef Test < handle % RETURN_BOOL usage: return_bool(bool value) : returns bool % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'logical') - varargout{1} = class_wrapper(24, this, varargin{:}); + varargout{1} = class_wrapper(27, this, varargin{:}); return end error('Arguments do not match any overload of function Test.return_bool'); @@ -163,7 +180,7 @@ classdef Test < handle % RETURN_DOUBLE usage: return_double(double value) : returns double % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'double') - varargout{1} = class_wrapper(25, this, varargin{:}); + varargout{1} = class_wrapper(28, this, varargin{:}); return end error('Arguments do not match any overload of function Test.return_double'); @@ -173,7 +190,7 @@ classdef Test < handle % RETURN_FIELD usage: return_field(Test t) : returns bool % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'Test') - varargout{1} = class_wrapper(26, this, varargin{:}); + varargout{1} = class_wrapper(29, this, varargin{:}); return end error('Arguments do not match any overload of function Test.return_field'); @@ -183,7 +200,7 @@ classdef Test < handle % RETURN_INT usage: return_int(int value) : returns int % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'numeric') - varargout{1} = class_wrapper(27, this, varargin{:}); + varargout{1} = class_wrapper(30, this, varargin{:}); return end error('Arguments do not match any overload of function Test.return_int'); @@ -193,7 +210,7 @@ classdef Test < handle % RETURN_MATRIX1 usage: return_matrix1(Matrix value) : returns Matrix % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'double') - varargout{1} = class_wrapper(28, this, varargin{:}); + varargout{1} = class_wrapper(31, this, varargin{:}); return end error('Arguments do not match any overload of function Test.return_matrix1'); @@ -203,7 +220,7 @@ classdef Test < handle % RETURN_MATRIX2 usage: return_matrix2(Matrix value) : returns Matrix % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'double') - varargout{1} = class_wrapper(29, this, varargin{:}); + varargout{1} = class_wrapper(32, this, varargin{:}); return end error('Arguments do not match any overload of function Test.return_matrix2'); @@ -213,13 +230,13 @@ classdef Test < handle % RETURN_PAIR usage: return_pair(Vector v, Matrix A) : returns pair< Vector, Matrix > % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 2 && isa(varargin{1},'double') && size(varargin{1},2)==1 && isa(varargin{2},'double') - [ varargout{1} varargout{2} ] = class_wrapper(30, this, varargin{:}); + [ varargout{1} varargout{2} ] = class_wrapper(33, this, varargin{:}); return end % RETURN_PAIR usage: return_pair(Vector v) : returns pair< Vector, Matrix > % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'double') && size(varargin{1},2)==1 - [ varargout{1} varargout{2} ] = class_wrapper(31, this, varargin{:}); + [ varargout{1} varargout{2} ] = class_wrapper(34, this, varargin{:}); return end error('Arguments do not match any overload of function Test.return_pair'); @@ -229,7 +246,7 @@ classdef Test < handle % RETURN_PTRS usage: return_ptrs(Test p1, Test p2) : returns pair< Test, Test > % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 2 && isa(varargin{1},'Test') && isa(varargin{2},'Test') - [ varargout{1} varargout{2} ] = class_wrapper(32, this, varargin{:}); + [ varargout{1} varargout{2} ] = class_wrapper(35, this, varargin{:}); return end error('Arguments do not match any overload of function Test.return_ptrs'); @@ -239,7 +256,7 @@ classdef Test < handle % RETURN_SIZE_T usage: return_size_t(size_t value) : returns size_t % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'numeric') - varargout{1} = class_wrapper(33, this, varargin{:}); + varargout{1} = class_wrapper(36, this, varargin{:}); return end error('Arguments do not match any overload of function Test.return_size_t'); @@ -249,7 +266,7 @@ classdef Test < handle % RETURN_STRING usage: return_string(string value) : returns string % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'char') - varargout{1} = class_wrapper(34, this, varargin{:}); + varargout{1} = class_wrapper(37, this, varargin{:}); return end error('Arguments do not match any overload of function Test.return_string'); @@ -259,7 +276,7 @@ classdef Test < handle % RETURN_VECTOR1 usage: return_vector1(Vector value) : returns Vector % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'double') && size(varargin{1},2)==1 - varargout{1} = class_wrapper(35, this, varargin{:}); + varargout{1} = class_wrapper(38, this, varargin{:}); return end error('Arguments do not match any overload of function Test.return_vector1'); @@ -269,7 +286,7 @@ classdef Test < handle % RETURN_VECTOR2 usage: return_vector2(Vector value) : returns Vector % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'double') && size(varargin{1},2)==1 - varargout{1} = class_wrapper(36, this, varargin{:}); + varargout{1} = class_wrapper(39, this, varargin{:}); return end error('Arguments do not match any overload of function Test.return_vector2'); @@ -279,19 +296,19 @@ classdef Test < handle % SET_CONTAINER usage: set_container(vector container) : returns void % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'std.vectorTest') - class_wrapper(37, this, varargin{:}); + class_wrapper(40, this, varargin{:}); return end % SET_CONTAINER usage: set_container(vector container) : returns void % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'std.vectorTest') - class_wrapper(38, this, varargin{:}); + class_wrapper(41, this, varargin{:}); return end % SET_CONTAINER usage: set_container(vector container) : returns void % Doxygen can be found at https://gtsam.org/doxygen/ if length(varargin) == 1 && isa(varargin{1},'std.vectorTest') - class_wrapper(39, this, varargin{:}); + class_wrapper(42, this, varargin{:}); return end error('Arguments do not match any overload of function Test.set_container'); diff --git a/wrap/tests/expected/matlab/class_wrapper.cpp b/wrap/tests/expected/matlab/class_wrapper.cpp index df6cb3307..03a25c358 100644 --- a/wrap/tests/expected/matlab/class_wrapper.cpp +++ b/wrap/tests/expected/matlab/class_wrapper.cpp @@ -145,7 +145,7 @@ void _class_RTTIRegister() { mxDestroyArray(registry); mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL); - if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) { + if(mexPutVariable("global", "gtsam_class_rttiRegistry_created", newAlreadyCreated) != 0) { mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); } mxDestroyArray(newAlreadyCreated); @@ -180,9 +180,9 @@ void FunRange_deconstructor_2(int nargout, mxArray *out[], int nargin, const mxA Collector_FunRange::iterator item; item = collector_FunRange.find(self); if(item != collector_FunRange.end()) { - delete self; collector_FunRange.erase(item); } + delete self; } void FunRange_range_3(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -216,9 +216,9 @@ void FunDouble_deconstructor_6(int nargout, mxArray *out[], int nargin, const mx Collector_FunDouble::iterator item; item = collector_FunDouble.find(self); if(item != collector_FunDouble.end()) { - delete self; collector_FunDouble.erase(item); } + delete self; } void FunDouble_multiTemplatedMethod_7(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -231,7 +231,14 @@ void FunDouble_multiTemplatedMethod_7(int nargout, mxArray *out[], int nargin, c out[0] = wrap_shared_ptr(boost::make_shared>(obj->multiTemplatedMethod(d,t,u)),"Fun", false); } -void FunDouble_templatedMethod_8(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void FunDouble_sets_8(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("sets",nargout,nargin-1,0); + auto obj = unwrap_shared_ptr>(in[0], "ptr_FunDouble"); + out[0] = wrap_shared_ptr(boost::make_shared::double>>(obj->sets()),"std.mapdoubledouble", false); +} + +void FunDouble_templatedMethod_9(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("templatedMethodString",nargout,nargin-1,2); auto obj = unwrap_shared_ptr>(in[0], "ptr_FunDouble"); @@ -240,20 +247,20 @@ void FunDouble_templatedMethod_8(int nargout, mxArray *out[], int nargin, const out[0] = wrap_shared_ptr(boost::make_shared>(obj->templatedMethod(d,t)),"Fun", false); } -void FunDouble_staticMethodWithThis_9(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void FunDouble_staticMethodWithThis_10(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { - checkArguments("FunDouble.staticMethodWithThis",nargout,nargin,0); + checkArguments("Fun.staticMethodWithThis",nargout,nargin,0); out[0] = wrap_shared_ptr(boost::make_shared>(Fun::staticMethodWithThis()),"Fundouble", false); } -void FunDouble_templatedStaticMethodInt_10(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void FunDouble_templatedStaticMethodInt_11(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { - checkArguments("FunDouble.templatedStaticMethodInt",nargout,nargin,1); + checkArguments("Fun.templatedStaticMethodInt",nargout,nargin,1); int m = unwrap< int >(in[0]); - out[0] = wrap< double >(Fun::templatedStaticMethodInt(m)); + out[0] = wrap< double >(Fun::templatedStaticMethod(m)); } -void Test_collectorInsertAndMakeBase_11(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_collectorInsertAndMakeBase_12(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr Shared; @@ -262,7 +269,7 @@ void Test_collectorInsertAndMakeBase_11(int nargout, mxArray *out[], int nargin, collector_Test.insert(self); } -void Test_constructor_12(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_constructor_13(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr Shared; @@ -273,7 +280,7 @@ void Test_constructor_12(int nargout, mxArray *out[], int nargin, const mxArray *reinterpret_cast (mxGetData(out[0])) = self; } -void Test_constructor_13(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_constructor_14(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr Shared; @@ -286,7 +293,7 @@ void Test_constructor_13(int nargout, mxArray *out[], int nargin, const mxArray *reinterpret_cast (mxGetData(out[0])) = self; } -void Test_deconstructor_14(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_deconstructor_15(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { typedef boost::shared_ptr Shared; checkArguments("delete_Test",nargout,nargin,1); @@ -294,12 +301,12 @@ void Test_deconstructor_14(int nargout, mxArray *out[], int nargin, const mxArra Collector_Test::iterator item; item = collector_Test.find(self); if(item != collector_Test.end()) { - delete self; collector_Test.erase(item); } + delete self; } -void Test_arg_EigenConstRef_15(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_arg_EigenConstRef_16(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("arg_EigenConstRef",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -307,7 +314,7 @@ void Test_arg_EigenConstRef_15(int nargout, mxArray *out[], int nargin, const mx obj->arg_EigenConstRef(value); } -void Test_create_MixedPtrs_16(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_create_MixedPtrs_17(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("create_MixedPtrs",nargout,nargin-1,0); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -316,7 +323,7 @@ void Test_create_MixedPtrs_16(int nargout, mxArray *out[], int nargin, const mxA out[1] = wrap_shared_ptr(pairResult.second,"Test", false); } -void Test_create_ptrs_17(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_create_ptrs_18(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("create_ptrs",nargout,nargin-1,0); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -325,28 +332,43 @@ void Test_create_ptrs_17(int nargout, mxArray *out[], int nargin, const mxArray out[1] = wrap_shared_ptr(pairResult.second,"Test", false); } -void Test_get_container_18(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_get_container_19(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("get_container",nargout,nargin-1,0); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); out[0] = wrap_shared_ptr(boost::make_shared>(obj->get_container()),"std.vectorTest", false); } -void Test_lambda_19(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_lambda_20(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("lambda",nargout,nargin-1,0); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); obj->lambda(); } -void Test_print_20(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_markdown_21(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("markdown",nargout,nargin-1,1); + auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); + gtsam::KeyFormatter& keyFormatter = *unwrap_shared_ptr< gtsam::KeyFormatter >(in[1], "ptr_gtsamKeyFormatter"); + out[0] = wrap< string >(obj->markdown(keyFormatter)); +} + +void Test_markdown_22(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("markdown",nargout,nargin-1,0); + auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); + out[0] = wrap< string >(obj->markdown(gtsam::DefaultKeyFormatter)); +} + +void Test_print_23(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("print",nargout,nargin-1,0); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); obj->print(); } -void Test_return_Point2Ptr_21(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_return_Point2Ptr_24(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("return_Point2Ptr",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -357,7 +379,7 @@ void Test_return_Point2Ptr_21(int nargout, mxArray *out[], int nargin, const mxA } } -void Test_return_Test_22(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_return_Test_25(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("return_Test",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -365,7 +387,7 @@ void Test_return_Test_22(int nargout, mxArray *out[], int nargin, const mxArray out[0] = wrap_shared_ptr(boost::make_shared(obj->return_Test(value)),"Test", false); } -void Test_return_TestPtr_23(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_return_TestPtr_26(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("return_TestPtr",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -373,7 +395,7 @@ void Test_return_TestPtr_23(int nargout, mxArray *out[], int nargin, const mxArr out[0] = wrap_shared_ptr(obj->return_TestPtr(value),"Test", false); } -void Test_return_bool_24(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_return_bool_27(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("return_bool",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -381,7 +403,7 @@ void Test_return_bool_24(int nargout, mxArray *out[], int nargin, const mxArray out[0] = wrap< bool >(obj->return_bool(value)); } -void Test_return_double_25(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_return_double_28(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("return_double",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -389,7 +411,7 @@ void Test_return_double_25(int nargout, mxArray *out[], int nargin, const mxArra out[0] = wrap< double >(obj->return_double(value)); } -void Test_return_field_26(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_return_field_29(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("return_field",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -397,7 +419,7 @@ void Test_return_field_26(int nargout, mxArray *out[], int nargin, const mxArray out[0] = wrap< bool >(obj->return_field(t)); } -void Test_return_int_27(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_return_int_30(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("return_int",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -405,7 +427,7 @@ void Test_return_int_27(int nargout, mxArray *out[], int nargin, const mxArray * out[0] = wrap< int >(obj->return_int(value)); } -void Test_return_matrix1_28(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_return_matrix1_31(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("return_matrix1",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -413,7 +435,7 @@ void Test_return_matrix1_28(int nargout, mxArray *out[], int nargin, const mxArr out[0] = wrap< Matrix >(obj->return_matrix1(value)); } -void Test_return_matrix2_29(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_return_matrix2_32(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("return_matrix2",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -421,7 +443,7 @@ void Test_return_matrix2_29(int nargout, mxArray *out[], int nargin, const mxArr out[0] = wrap< Matrix >(obj->return_matrix2(value)); } -void Test_return_pair_30(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_return_pair_33(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("return_pair",nargout,nargin-1,2); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -432,7 +454,7 @@ void Test_return_pair_30(int nargout, mxArray *out[], int nargin, const mxArray out[1] = wrap< Matrix >(pairResult.second); } -void Test_return_pair_31(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_return_pair_34(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("return_pair",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -442,7 +464,7 @@ void Test_return_pair_31(int nargout, mxArray *out[], int nargin, const mxArray out[1] = wrap< Matrix >(pairResult.second); } -void Test_return_ptrs_32(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_return_ptrs_35(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("return_ptrs",nargout,nargin-1,2); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -453,7 +475,7 @@ void Test_return_ptrs_32(int nargout, mxArray *out[], int nargin, const mxArray out[1] = wrap_shared_ptr(pairResult.second,"Test", false); } -void Test_return_size_t_33(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_return_size_t_36(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("return_size_t",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -461,7 +483,7 @@ void Test_return_size_t_33(int nargout, mxArray *out[], int nargin, const mxArra out[0] = wrap< size_t >(obj->return_size_t(value)); } -void Test_return_string_34(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_return_string_37(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("return_string",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -469,7 +491,7 @@ void Test_return_string_34(int nargout, mxArray *out[], int nargin, const mxArra out[0] = wrap< string >(obj->return_string(value)); } -void Test_return_vector1_35(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_return_vector1_38(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("return_vector1",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -477,7 +499,7 @@ void Test_return_vector1_35(int nargout, mxArray *out[], int nargin, const mxArr out[0] = wrap< Vector >(obj->return_vector1(value)); } -void Test_return_vector2_36(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_return_vector2_39(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("return_vector2",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -485,7 +507,7 @@ void Test_return_vector2_36(int nargout, mxArray *out[], int nargin, const mxArr out[0] = wrap< Vector >(obj->return_vector2(value)); } -void Test_set_container_37(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_set_container_40(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("set_container",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -493,7 +515,7 @@ void Test_set_container_37(int nargout, mxArray *out[], int nargin, const mxArra obj->set_container(*container); } -void Test_set_container_38(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_set_container_41(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("set_container",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -501,7 +523,7 @@ void Test_set_container_38(int nargout, mxArray *out[], int nargin, const mxArra obj->set_container(*container); } -void Test_set_container_39(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void Test_set_container_42(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("set_container",nargout,nargin-1,1); auto obj = unwrap_shared_ptr(in[0], "ptr_Test"); @@ -509,7 +531,7 @@ void Test_set_container_39(int nargout, mxArray *out[], int nargin, const mxArra obj->set_container(*container); } -void PrimitiveRefDouble_collectorInsertAndMakeBase_40(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void PrimitiveRefDouble_collectorInsertAndMakeBase_43(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr> Shared; @@ -518,7 +540,7 @@ void PrimitiveRefDouble_collectorInsertAndMakeBase_40(int nargout, mxArray *out[ collector_PrimitiveRefDouble.insert(self); } -void PrimitiveRefDouble_constructor_41(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void PrimitiveRefDouble_constructor_44(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr> Shared; @@ -529,7 +551,7 @@ void PrimitiveRefDouble_constructor_41(int nargout, mxArray *out[], int nargin, *reinterpret_cast (mxGetData(out[0])) = self; } -void PrimitiveRefDouble_deconstructor_42(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void PrimitiveRefDouble_deconstructor_45(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { typedef boost::shared_ptr> Shared; checkArguments("delete_PrimitiveRefDouble",nargout,nargin,1); @@ -537,19 +559,19 @@ void PrimitiveRefDouble_deconstructor_42(int nargout, mxArray *out[], int nargin Collector_PrimitiveRefDouble::iterator item; item = collector_PrimitiveRefDouble.find(self); if(item != collector_PrimitiveRefDouble.end()) { - delete self; collector_PrimitiveRefDouble.erase(item); } + delete self; } -void PrimitiveRefDouble_Brutal_43(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void PrimitiveRefDouble_Brutal_46(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { - checkArguments("PrimitiveRefDouble.Brutal",nargout,nargin,1); + checkArguments("PrimitiveRef.Brutal",nargout,nargin,1); double t = unwrap< double >(in[0]); out[0] = wrap_shared_ptr(boost::make_shared>(PrimitiveRef::Brutal(t)),"PrimitiveRefdouble", false); } -void MyVector3_collectorInsertAndMakeBase_44(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MyVector3_collectorInsertAndMakeBase_47(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr> Shared; @@ -558,7 +580,7 @@ void MyVector3_collectorInsertAndMakeBase_44(int nargout, mxArray *out[], int na collector_MyVector3.insert(self); } -void MyVector3_constructor_45(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MyVector3_constructor_48(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr> Shared; @@ -569,7 +591,7 @@ void MyVector3_constructor_45(int nargout, mxArray *out[], int nargin, const mxA *reinterpret_cast (mxGetData(out[0])) = self; } -void MyVector3_deconstructor_46(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MyVector3_deconstructor_49(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { typedef boost::shared_ptr> Shared; checkArguments("delete_MyVector3",nargout,nargin,1); @@ -577,12 +599,12 @@ void MyVector3_deconstructor_46(int nargout, mxArray *out[], int nargin, const m Collector_MyVector3::iterator item; item = collector_MyVector3.find(self); if(item != collector_MyVector3.end()) { - delete self; collector_MyVector3.erase(item); } + delete self; } -void MyVector12_collectorInsertAndMakeBase_47(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MyVector12_collectorInsertAndMakeBase_50(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr> Shared; @@ -591,7 +613,7 @@ void MyVector12_collectorInsertAndMakeBase_47(int nargout, mxArray *out[], int n collector_MyVector12.insert(self); } -void MyVector12_constructor_48(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MyVector12_constructor_51(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr> Shared; @@ -602,7 +624,7 @@ void MyVector12_constructor_48(int nargout, mxArray *out[], int nargin, const mx *reinterpret_cast (mxGetData(out[0])) = self; } -void MyVector12_deconstructor_49(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MyVector12_deconstructor_52(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { typedef boost::shared_ptr> Shared; checkArguments("delete_MyVector12",nargout,nargin,1); @@ -610,12 +632,12 @@ void MyVector12_deconstructor_49(int nargout, mxArray *out[], int nargin, const Collector_MyVector12::iterator item; item = collector_MyVector12.find(self); if(item != collector_MyVector12.end()) { - delete self; collector_MyVector12.erase(item); } + delete self; } -void MultipleTemplatesIntDouble_collectorInsertAndMakeBase_50(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MultipleTemplatesIntDouble_collectorInsertAndMakeBase_53(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr> Shared; @@ -624,7 +646,7 @@ void MultipleTemplatesIntDouble_collectorInsertAndMakeBase_50(int nargout, mxArr collector_MultipleTemplatesIntDouble.insert(self); } -void MultipleTemplatesIntDouble_deconstructor_51(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MultipleTemplatesIntDouble_deconstructor_54(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { typedef boost::shared_ptr> Shared; checkArguments("delete_MultipleTemplatesIntDouble",nargout,nargin,1); @@ -632,12 +654,12 @@ void MultipleTemplatesIntDouble_deconstructor_51(int nargout, mxArray *out[], in Collector_MultipleTemplatesIntDouble::iterator item; item = collector_MultipleTemplatesIntDouble.find(self); if(item != collector_MultipleTemplatesIntDouble.end()) { - delete self; collector_MultipleTemplatesIntDouble.erase(item); } + delete self; } -void MultipleTemplatesIntFloat_collectorInsertAndMakeBase_52(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MultipleTemplatesIntFloat_collectorInsertAndMakeBase_55(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr> Shared; @@ -646,7 +668,7 @@ void MultipleTemplatesIntFloat_collectorInsertAndMakeBase_52(int nargout, mxArra collector_MultipleTemplatesIntFloat.insert(self); } -void MultipleTemplatesIntFloat_deconstructor_53(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MultipleTemplatesIntFloat_deconstructor_56(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { typedef boost::shared_ptr> Shared; checkArguments("delete_MultipleTemplatesIntFloat",nargout,nargin,1); @@ -654,12 +676,12 @@ void MultipleTemplatesIntFloat_deconstructor_53(int nargout, mxArray *out[], int Collector_MultipleTemplatesIntFloat::iterator item; item = collector_MultipleTemplatesIntFloat.find(self); if(item != collector_MultipleTemplatesIntFloat.end()) { - delete self; collector_MultipleTemplatesIntFloat.erase(item); } + delete self; } -void ForwardKinematics_collectorInsertAndMakeBase_54(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void ForwardKinematics_collectorInsertAndMakeBase_57(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr Shared; @@ -668,7 +690,7 @@ void ForwardKinematics_collectorInsertAndMakeBase_54(int nargout, mxArray *out[] collector_ForwardKinematics.insert(self); } -void ForwardKinematics_constructor_55(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void ForwardKinematics_constructor_58(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr Shared; @@ -684,7 +706,22 @@ void ForwardKinematics_constructor_55(int nargout, mxArray *out[], int nargin, c *reinterpret_cast (mxGetData(out[0])) = self; } -void ForwardKinematics_deconstructor_56(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void ForwardKinematics_constructor_59(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + mexAtExit(&_deleteAllObjects); + typedef boost::shared_ptr Shared; + + gtdynamics::Robot& robot = *unwrap_shared_ptr< gtdynamics::Robot >(in[0], "ptr_gtdynamicsRobot"); + string& start_link_name = *unwrap_shared_ptr< string >(in[1], "ptr_string"); + string& end_link_name = *unwrap_shared_ptr< string >(in[2], "ptr_string"); + gtsam::Values& joint_angles = *unwrap_shared_ptr< gtsam::Values >(in[3], "ptr_gtsamValues"); + Shared *self = new Shared(new ForwardKinematics(robot,start_link_name,end_link_name,joint_angles,gtsam::Pose3())); + collector_ForwardKinematics.insert(self); + out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL); + *reinterpret_cast (mxGetData(out[0])) = self; +} + +void ForwardKinematics_deconstructor_60(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { typedef boost::shared_ptr Shared; checkArguments("delete_ForwardKinematics",nargout,nargin,1); @@ -692,12 +729,12 @@ void ForwardKinematics_deconstructor_56(int nargout, mxArray *out[], int nargin, Collector_ForwardKinematics::iterator item; item = collector_ForwardKinematics.find(self); if(item != collector_ForwardKinematics.end()) { - delete self; collector_ForwardKinematics.erase(item); } + delete self; } -void TemplatedConstructor_collectorInsertAndMakeBase_57(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void TemplatedConstructor_collectorInsertAndMakeBase_61(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr Shared; @@ -706,7 +743,7 @@ void TemplatedConstructor_collectorInsertAndMakeBase_57(int nargout, mxArray *ou collector_TemplatedConstructor.insert(self); } -void TemplatedConstructor_constructor_58(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void TemplatedConstructor_constructor_62(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr Shared; @@ -717,7 +754,7 @@ void TemplatedConstructor_constructor_58(int nargout, mxArray *out[], int nargin *reinterpret_cast (mxGetData(out[0])) = self; } -void TemplatedConstructor_constructor_59(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void TemplatedConstructor_constructor_63(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr Shared; @@ -729,7 +766,7 @@ void TemplatedConstructor_constructor_59(int nargout, mxArray *out[], int nargin *reinterpret_cast (mxGetData(out[0])) = self; } -void TemplatedConstructor_constructor_60(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void TemplatedConstructor_constructor_64(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr Shared; @@ -741,7 +778,7 @@ void TemplatedConstructor_constructor_60(int nargout, mxArray *out[], int nargin *reinterpret_cast (mxGetData(out[0])) = self; } -void TemplatedConstructor_constructor_61(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void TemplatedConstructor_constructor_65(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr Shared; @@ -753,7 +790,7 @@ void TemplatedConstructor_constructor_61(int nargout, mxArray *out[], int nargin *reinterpret_cast (mxGetData(out[0])) = self; } -void TemplatedConstructor_deconstructor_62(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void TemplatedConstructor_deconstructor_66(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { typedef boost::shared_ptr Shared; checkArguments("delete_TemplatedConstructor",nargout,nargin,1); @@ -761,12 +798,12 @@ void TemplatedConstructor_deconstructor_62(int nargout, mxArray *out[], int narg Collector_TemplatedConstructor::iterator item; item = collector_TemplatedConstructor.find(self); if(item != collector_TemplatedConstructor.end()) { - delete self; collector_TemplatedConstructor.erase(item); } + delete self; } -void MyFactorPosePoint2_collectorInsertAndMakeBase_63(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MyFactorPosePoint2_collectorInsertAndMakeBase_67(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr> Shared; @@ -775,7 +812,7 @@ void MyFactorPosePoint2_collectorInsertAndMakeBase_63(int nargout, mxArray *out[ collector_MyFactorPosePoint2.insert(self); } -void MyFactorPosePoint2_constructor_64(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MyFactorPosePoint2_constructor_68(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { mexAtExit(&_deleteAllObjects); typedef boost::shared_ptr> Shared; @@ -790,7 +827,7 @@ void MyFactorPosePoint2_constructor_64(int nargout, mxArray *out[], int nargin, *reinterpret_cast (mxGetData(out[0])) = self; } -void MyFactorPosePoint2_deconstructor_65(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MyFactorPosePoint2_deconstructor_69(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { typedef boost::shared_ptr> Shared; checkArguments("delete_MyFactorPosePoint2",nargout,nargin,1); @@ -798,12 +835,12 @@ void MyFactorPosePoint2_deconstructor_65(int nargout, mxArray *out[], int nargin Collector_MyFactorPosePoint2::iterator item; item = collector_MyFactorPosePoint2.find(self); if(item != collector_MyFactorPosePoint2.end()) { - delete self; collector_MyFactorPosePoint2.erase(item); } + delete self; } -void MyFactorPosePoint2_print_66(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void MyFactorPosePoint2_print_70(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("print",nargout,nargin-1,2); auto obj = unwrap_shared_ptr>(in[0], "ptr_MyFactorPosePoint2"); @@ -812,6 +849,21 @@ void MyFactorPosePoint2_print_66(int nargout, mxArray *out[], int nargin, const obj->print(s,keyFormatter); } +void MyFactorPosePoint2_print_71(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("print",nargout,nargin-1,1); + auto obj = unwrap_shared_ptr>(in[0], "ptr_MyFactorPosePoint2"); + string& s = *unwrap_shared_ptr< string >(in[1], "ptr_string"); + obj->print(s,gtsam::DefaultKeyFormatter); +} + +void MyFactorPosePoint2_print_72(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("print",nargout,nargin-1,0); + auto obj = unwrap_shared_ptr>(in[0], "ptr_MyFactorPosePoint2"); + obj->print("factor: ",gtsam::DefaultKeyFormatter); +} + void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { @@ -849,181 +901,199 @@ void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[]) FunDouble_multiTemplatedMethod_7(nargout, out, nargin-1, in+1); break; case 8: - FunDouble_templatedMethod_8(nargout, out, nargin-1, in+1); + FunDouble_sets_8(nargout, out, nargin-1, in+1); break; case 9: - FunDouble_staticMethodWithThis_9(nargout, out, nargin-1, in+1); + FunDouble_templatedMethod_9(nargout, out, nargin-1, in+1); break; case 10: - FunDouble_templatedStaticMethodInt_10(nargout, out, nargin-1, in+1); + FunDouble_staticMethodWithThis_10(nargout, out, nargin-1, in+1); break; case 11: - Test_collectorInsertAndMakeBase_11(nargout, out, nargin-1, in+1); + FunDouble_templatedStaticMethodInt_11(nargout, out, nargin-1, in+1); break; case 12: - Test_constructor_12(nargout, out, nargin-1, in+1); + Test_collectorInsertAndMakeBase_12(nargout, out, nargin-1, in+1); break; case 13: Test_constructor_13(nargout, out, nargin-1, in+1); break; case 14: - Test_deconstructor_14(nargout, out, nargin-1, in+1); + Test_constructor_14(nargout, out, nargin-1, in+1); break; case 15: - Test_arg_EigenConstRef_15(nargout, out, nargin-1, in+1); + Test_deconstructor_15(nargout, out, nargin-1, in+1); break; case 16: - Test_create_MixedPtrs_16(nargout, out, nargin-1, in+1); + Test_arg_EigenConstRef_16(nargout, out, nargin-1, in+1); break; case 17: - Test_create_ptrs_17(nargout, out, nargin-1, in+1); + Test_create_MixedPtrs_17(nargout, out, nargin-1, in+1); break; case 18: - Test_get_container_18(nargout, out, nargin-1, in+1); + Test_create_ptrs_18(nargout, out, nargin-1, in+1); break; case 19: - Test_lambda_19(nargout, out, nargin-1, in+1); + Test_get_container_19(nargout, out, nargin-1, in+1); break; case 20: - Test_print_20(nargout, out, nargin-1, in+1); + Test_lambda_20(nargout, out, nargin-1, in+1); break; case 21: - Test_return_Point2Ptr_21(nargout, out, nargin-1, in+1); + Test_markdown_21(nargout, out, nargin-1, in+1); break; case 22: - Test_return_Test_22(nargout, out, nargin-1, in+1); + Test_markdown_22(nargout, out, nargin-1, in+1); break; case 23: - Test_return_TestPtr_23(nargout, out, nargin-1, in+1); + Test_print_23(nargout, out, nargin-1, in+1); break; case 24: - Test_return_bool_24(nargout, out, nargin-1, in+1); + Test_return_Point2Ptr_24(nargout, out, nargin-1, in+1); break; case 25: - Test_return_double_25(nargout, out, nargin-1, in+1); + Test_return_Test_25(nargout, out, nargin-1, in+1); break; case 26: - Test_return_field_26(nargout, out, nargin-1, in+1); + Test_return_TestPtr_26(nargout, out, nargin-1, in+1); break; case 27: - Test_return_int_27(nargout, out, nargin-1, in+1); + Test_return_bool_27(nargout, out, nargin-1, in+1); break; case 28: - Test_return_matrix1_28(nargout, out, nargin-1, in+1); + Test_return_double_28(nargout, out, nargin-1, in+1); break; case 29: - Test_return_matrix2_29(nargout, out, nargin-1, in+1); + Test_return_field_29(nargout, out, nargin-1, in+1); break; case 30: - Test_return_pair_30(nargout, out, nargin-1, in+1); + Test_return_int_30(nargout, out, nargin-1, in+1); break; case 31: - Test_return_pair_31(nargout, out, nargin-1, in+1); + Test_return_matrix1_31(nargout, out, nargin-1, in+1); break; case 32: - Test_return_ptrs_32(nargout, out, nargin-1, in+1); + Test_return_matrix2_32(nargout, out, nargin-1, in+1); break; case 33: - Test_return_size_t_33(nargout, out, nargin-1, in+1); + Test_return_pair_33(nargout, out, nargin-1, in+1); break; case 34: - Test_return_string_34(nargout, out, nargin-1, in+1); + Test_return_pair_34(nargout, out, nargin-1, in+1); break; case 35: - Test_return_vector1_35(nargout, out, nargin-1, in+1); + Test_return_ptrs_35(nargout, out, nargin-1, in+1); break; case 36: - Test_return_vector2_36(nargout, out, nargin-1, in+1); + Test_return_size_t_36(nargout, out, nargin-1, in+1); break; case 37: - Test_set_container_37(nargout, out, nargin-1, in+1); + Test_return_string_37(nargout, out, nargin-1, in+1); break; case 38: - Test_set_container_38(nargout, out, nargin-1, in+1); + Test_return_vector1_38(nargout, out, nargin-1, in+1); break; case 39: - Test_set_container_39(nargout, out, nargin-1, in+1); + Test_return_vector2_39(nargout, out, nargin-1, in+1); break; case 40: - PrimitiveRefDouble_collectorInsertAndMakeBase_40(nargout, out, nargin-1, in+1); + Test_set_container_40(nargout, out, nargin-1, in+1); break; case 41: - PrimitiveRefDouble_constructor_41(nargout, out, nargin-1, in+1); + Test_set_container_41(nargout, out, nargin-1, in+1); break; case 42: - PrimitiveRefDouble_deconstructor_42(nargout, out, nargin-1, in+1); + Test_set_container_42(nargout, out, nargin-1, in+1); break; case 43: - PrimitiveRefDouble_Brutal_43(nargout, out, nargin-1, in+1); + PrimitiveRefDouble_collectorInsertAndMakeBase_43(nargout, out, nargin-1, in+1); break; case 44: - MyVector3_collectorInsertAndMakeBase_44(nargout, out, nargin-1, in+1); + PrimitiveRefDouble_constructor_44(nargout, out, nargin-1, in+1); break; case 45: - MyVector3_constructor_45(nargout, out, nargin-1, in+1); + PrimitiveRefDouble_deconstructor_45(nargout, out, nargin-1, in+1); break; case 46: - MyVector3_deconstructor_46(nargout, out, nargin-1, in+1); + PrimitiveRefDouble_Brutal_46(nargout, out, nargin-1, in+1); break; case 47: - MyVector12_collectorInsertAndMakeBase_47(nargout, out, nargin-1, in+1); + MyVector3_collectorInsertAndMakeBase_47(nargout, out, nargin-1, in+1); break; case 48: - MyVector12_constructor_48(nargout, out, nargin-1, in+1); + MyVector3_constructor_48(nargout, out, nargin-1, in+1); break; case 49: - MyVector12_deconstructor_49(nargout, out, nargin-1, in+1); + MyVector3_deconstructor_49(nargout, out, nargin-1, in+1); break; case 50: - MultipleTemplatesIntDouble_collectorInsertAndMakeBase_50(nargout, out, nargin-1, in+1); + MyVector12_collectorInsertAndMakeBase_50(nargout, out, nargin-1, in+1); break; case 51: - MultipleTemplatesIntDouble_deconstructor_51(nargout, out, nargin-1, in+1); + MyVector12_constructor_51(nargout, out, nargin-1, in+1); break; case 52: - MultipleTemplatesIntFloat_collectorInsertAndMakeBase_52(nargout, out, nargin-1, in+1); + MyVector12_deconstructor_52(nargout, out, nargin-1, in+1); break; case 53: - MultipleTemplatesIntFloat_deconstructor_53(nargout, out, nargin-1, in+1); + MultipleTemplatesIntDouble_collectorInsertAndMakeBase_53(nargout, out, nargin-1, in+1); break; case 54: - ForwardKinematics_collectorInsertAndMakeBase_54(nargout, out, nargin-1, in+1); + MultipleTemplatesIntDouble_deconstructor_54(nargout, out, nargin-1, in+1); break; case 55: - ForwardKinematics_constructor_55(nargout, out, nargin-1, in+1); + MultipleTemplatesIntFloat_collectorInsertAndMakeBase_55(nargout, out, nargin-1, in+1); break; case 56: - ForwardKinematics_deconstructor_56(nargout, out, nargin-1, in+1); + MultipleTemplatesIntFloat_deconstructor_56(nargout, out, nargin-1, in+1); break; case 57: - TemplatedConstructor_collectorInsertAndMakeBase_57(nargout, out, nargin-1, in+1); + ForwardKinematics_collectorInsertAndMakeBase_57(nargout, out, nargin-1, in+1); break; case 58: - TemplatedConstructor_constructor_58(nargout, out, nargin-1, in+1); + ForwardKinematics_constructor_58(nargout, out, nargin-1, in+1); break; case 59: - TemplatedConstructor_constructor_59(nargout, out, nargin-1, in+1); + ForwardKinematics_constructor_59(nargout, out, nargin-1, in+1); break; case 60: - TemplatedConstructor_constructor_60(nargout, out, nargin-1, in+1); + ForwardKinematics_deconstructor_60(nargout, out, nargin-1, in+1); break; case 61: - TemplatedConstructor_constructor_61(nargout, out, nargin-1, in+1); + TemplatedConstructor_collectorInsertAndMakeBase_61(nargout, out, nargin-1, in+1); break; case 62: - TemplatedConstructor_deconstructor_62(nargout, out, nargin-1, in+1); + TemplatedConstructor_constructor_62(nargout, out, nargin-1, in+1); break; case 63: - MyFactorPosePoint2_collectorInsertAndMakeBase_63(nargout, out, nargin-1, in+1); + TemplatedConstructor_constructor_63(nargout, out, nargin-1, in+1); break; case 64: - MyFactorPosePoint2_constructor_64(nargout, out, nargin-1, in+1); + TemplatedConstructor_constructor_64(nargout, out, nargin-1, in+1); break; case 65: - MyFactorPosePoint2_deconstructor_65(nargout, out, nargin-1, in+1); + TemplatedConstructor_constructor_65(nargout, out, nargin-1, in+1); break; case 66: - MyFactorPosePoint2_print_66(nargout, out, nargin-1, in+1); + TemplatedConstructor_deconstructor_66(nargout, out, nargin-1, in+1); + break; + case 67: + MyFactorPosePoint2_collectorInsertAndMakeBase_67(nargout, out, nargin-1, in+1); + break; + case 68: + MyFactorPosePoint2_constructor_68(nargout, out, nargin-1, in+1); + break; + case 69: + MyFactorPosePoint2_deconstructor_69(nargout, out, nargin-1, in+1); + break; + case 70: + MyFactorPosePoint2_print_70(nargout, out, nargin-1, in+1); + break; + case 71: + MyFactorPosePoint2_print_71(nargout, out, nargin-1, in+1); + break; + case 72: + MyFactorPosePoint2_print_72(nargout, out, nargin-1, in+1); break; } } catch(const std::exception& e) { diff --git a/wrap/tests/expected/matlab/functions_wrapper.cpp b/wrap/tests/expected/matlab/functions_wrapper.cpp index d0f0f8ca6..17b5fb494 100644 --- a/wrap/tests/expected/matlab/functions_wrapper.cpp +++ b/wrap/tests/expected/matlab/functions_wrapper.cpp @@ -51,7 +51,7 @@ void _functions_RTTIRegister() { mxDestroyArray(registry); mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL); - if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) { + if(mexPutVariable("global", "gtsam_functions_rttiRegistry_created", newAlreadyCreated) != 0) { mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); } mxDestroyArray(newAlreadyCreated); @@ -130,43 +130,110 @@ void DefaultFuncInt_8(int nargout, mxArray *out[], int nargin, const mxArray *in int b = unwrap< int >(in[1]); DefaultFuncInt(a,b); } -void DefaultFuncString_9(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void DefaultFuncInt_9(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("DefaultFuncInt",nargout,nargin,1); + int a = unwrap< int >(in[0]); + DefaultFuncInt(a,0); +} +void DefaultFuncInt_10(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("DefaultFuncInt",nargout,nargin,0); + DefaultFuncInt(123,0); +} +void DefaultFuncString_11(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("DefaultFuncString",nargout,nargin,2); string& s = *unwrap_shared_ptr< string >(in[0], "ptr_string"); string& name = *unwrap_shared_ptr< string >(in[1], "ptr_string"); DefaultFuncString(s,name); } -void DefaultFuncObj_10(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void DefaultFuncString_12(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("DefaultFuncString",nargout,nargin,1); + string& s = *unwrap_shared_ptr< string >(in[0], "ptr_string"); + DefaultFuncString(s,""); +} +void DefaultFuncString_13(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("DefaultFuncString",nargout,nargin,0); + DefaultFuncString("hello",""); +} +void DefaultFuncObj_14(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("DefaultFuncObj",nargout,nargin,1); gtsam::KeyFormatter& keyFormatter = *unwrap_shared_ptr< gtsam::KeyFormatter >(in[0], "ptr_gtsamKeyFormatter"); DefaultFuncObj(keyFormatter); } -void DefaultFuncZero_11(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void DefaultFuncObj_15(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("DefaultFuncObj",nargout,nargin,0); + DefaultFuncObj(gtsam::DefaultKeyFormatter); +} +void DefaultFuncZero_16(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("DefaultFuncZero",nargout,nargin,5); int a = unwrap< int >(in[0]); int b = unwrap< int >(in[1]); double c = unwrap< double >(in[2]); - bool d = unwrap< bool >(in[3]); + int d = unwrap< int >(in[3]); bool e = unwrap< bool >(in[4]); DefaultFuncZero(a,b,c,d,e); } -void DefaultFuncVector_12(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void DefaultFuncZero_17(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("DefaultFuncZero",nargout,nargin,4); + int a = unwrap< int >(in[0]); + int b = unwrap< int >(in[1]); + double c = unwrap< double >(in[2]); + int d = unwrap< int >(in[3]); + DefaultFuncZero(a,b,c,d,false); +} +void DefaultFuncZero_18(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("DefaultFuncZero",nargout,nargin,3); + int a = unwrap< int >(in[0]); + int b = unwrap< int >(in[1]); + double c = unwrap< double >(in[2]); + DefaultFuncZero(a,b,c,0,false); +} +void DefaultFuncZero_19(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("DefaultFuncZero",nargout,nargin,2); + int a = unwrap< int >(in[0]); + int b = unwrap< int >(in[1]); + DefaultFuncZero(a,b,0.0,0,false); +} +void DefaultFuncVector_20(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("DefaultFuncVector",nargout,nargin,2); std::vector& i = *unwrap_shared_ptr< std::vector >(in[0], "ptr_stdvectorint"); std::vector& s = *unwrap_shared_ptr< std::vector >(in[1], "ptr_stdvectorstring"); DefaultFuncVector(i,s); } -void setPose_13(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void DefaultFuncVector_21(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("DefaultFuncVector",nargout,nargin,1); + std::vector& i = *unwrap_shared_ptr< std::vector >(in[0], "ptr_stdvectorint"); + DefaultFuncVector(i,{"borglab", "gtsam"}); +} +void DefaultFuncVector_22(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("DefaultFuncVector",nargout,nargin,0); + DefaultFuncVector({1, 2, 3},{"borglab", "gtsam"}); +} +void setPose_23(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("setPose",nargout,nargin,1); gtsam::Pose3& pose = *unwrap_shared_ptr< gtsam::Pose3 >(in[0], "ptr_gtsamPose3"); setPose(pose); } -void TemplatedFunctionRot3_14(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void setPose_24(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("setPose",nargout,nargin,0); + setPose(gtsam::Pose3()); +} +void TemplatedFunctionRot3_25(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("TemplatedFunctionRot3",nargout,nargin,1); gtsam::Rot3& t = *unwrap_shared_ptr< gtsam::Rot3 >(in[0], "ptr_gtsamRot3"); @@ -212,22 +279,55 @@ void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[]) DefaultFuncInt_8(nargout, out, nargin-1, in+1); break; case 9: - DefaultFuncString_9(nargout, out, nargin-1, in+1); + DefaultFuncInt_9(nargout, out, nargin-1, in+1); break; case 10: - DefaultFuncObj_10(nargout, out, nargin-1, in+1); + DefaultFuncInt_10(nargout, out, nargin-1, in+1); break; case 11: - DefaultFuncZero_11(nargout, out, nargin-1, in+1); + DefaultFuncString_11(nargout, out, nargin-1, in+1); break; case 12: - DefaultFuncVector_12(nargout, out, nargin-1, in+1); + DefaultFuncString_12(nargout, out, nargin-1, in+1); break; case 13: - setPose_13(nargout, out, nargin-1, in+1); + DefaultFuncString_13(nargout, out, nargin-1, in+1); break; case 14: - TemplatedFunctionRot3_14(nargout, out, nargin-1, in+1); + DefaultFuncObj_14(nargout, out, nargin-1, in+1); + break; + case 15: + DefaultFuncObj_15(nargout, out, nargin-1, in+1); + break; + case 16: + DefaultFuncZero_16(nargout, out, nargin-1, in+1); + break; + case 17: + DefaultFuncZero_17(nargout, out, nargin-1, in+1); + break; + case 18: + DefaultFuncZero_18(nargout, out, nargin-1, in+1); + break; + case 19: + DefaultFuncZero_19(nargout, out, nargin-1, in+1); + break; + case 20: + DefaultFuncVector_20(nargout, out, nargin-1, in+1); + break; + case 21: + DefaultFuncVector_21(nargout, out, nargin-1, in+1); + break; + case 22: + DefaultFuncVector_22(nargout, out, nargin-1, in+1); + break; + case 23: + setPose_23(nargout, out, nargin-1, in+1); + break; + case 24: + setPose_24(nargout, out, nargin-1, in+1); + break; + case 25: + TemplatedFunctionRot3_25(nargout, out, nargin-1, in+1); break; } } catch(const std::exception& e) { diff --git a/wrap/tests/expected/matlab/geometry_wrapper.cpp b/wrap/tests/expected/matlab/geometry_wrapper.cpp index 81631390c..ee1f04359 100644 --- a/wrap/tests/expected/matlab/geometry_wrapper.cpp +++ b/wrap/tests/expected/matlab/geometry_wrapper.cpp @@ -118,9 +118,9 @@ void gtsamPoint2_deconstructor_3(int nargout, mxArray *out[], int nargin, const Collector_gtsamPoint2::iterator item; item = collector_gtsamPoint2.find(self); if(item != collector_gtsamPoint2.end()) { - delete self; collector_gtsamPoint2.erase(item); } + delete self; } void gtsamPoint2_argChar_4(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -262,9 +262,9 @@ void gtsamPoint3_deconstructor_20(int nargout, mxArray *out[], int nargin, const Collector_gtsamPoint3::iterator item; item = collector_gtsamPoint3.find(self); if(item != collector_gtsamPoint3.end()) { - delete self; collector_gtsamPoint3.erase(item); } + delete self; } void gtsamPoint3_norm_21(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -286,14 +286,14 @@ void gtsamPoint3_string_serialize_22(int nargout, mxArray *out[], int nargin, co } void gtsamPoint3_StaticFunctionRet_23(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { - checkArguments("gtsamPoint3.StaticFunctionRet",nargout,nargin,1); + checkArguments("gtsam::Point3.StaticFunctionRet",nargout,nargin,1); double z = unwrap< double >(in[0]); out[0] = wrap< Point3 >(gtsam::Point3::StaticFunctionRet(z)); } void gtsamPoint3_staticFunction_24(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { - checkArguments("gtsamPoint3.staticFunction",nargout,nargin,0); + checkArguments("gtsam::Point3.staticFunction",nargout,nargin,0); out[0] = wrap< double >(gtsam::Point3::staticFunction()); } diff --git a/wrap/tests/expected/matlab/inheritance_wrapper.cpp b/wrap/tests/expected/matlab/inheritance_wrapper.cpp index 8e61ac8c6..0cf17eedd 100644 --- a/wrap/tests/expected/matlab/inheritance_wrapper.cpp +++ b/wrap/tests/expected/matlab/inheritance_wrapper.cpp @@ -88,7 +88,7 @@ void _inheritance_RTTIRegister() { mxDestroyArray(registry); mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL); - if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) { + if(mexPutVariable("global", "gtsam_inheritance_rttiRegistry_created", newAlreadyCreated) != 0) { mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); } mxDestroyArray(newAlreadyCreated); @@ -121,9 +121,9 @@ void MyBase_deconstructor_2(int nargout, mxArray *out[], int nargin, const mxArr Collector_MyBase::iterator item; item = collector_MyBase.find(self); if(item != collector_MyBase.end()) { - delete self; collector_MyBase.erase(item); } + delete self; } void MyTemplatePoint2_collectorInsertAndMakeBase_3(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -171,9 +171,9 @@ void MyTemplatePoint2_deconstructor_6(int nargout, mxArray *out[], int nargin, c Collector_MyTemplatePoint2::iterator item; item = collector_MyTemplatePoint2.find(self); if(item != collector_MyTemplatePoint2.end()) { - delete self; collector_MyTemplatePoint2.erase(item); } + delete self; } void MyTemplatePoint2_accept_T_7(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -289,7 +289,7 @@ void MyTemplatePoint2_templatedMethod_17(int nargout, mxArray *out[], int nargin void MyTemplatePoint2_Level_18(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { - checkArguments("MyTemplatePoint2.Level",nargout,nargin,1); + checkArguments("MyTemplate.Level",nargout,nargin,1); Point2 K = unwrap< Point2 >(in[0]); out[0] = wrap_shared_ptr(boost::make_shared>(MyTemplate::Level(K)),"MyTemplatePoint2", false); } @@ -339,9 +339,9 @@ void MyTemplateMatrix_deconstructor_22(int nargout, mxArray *out[], int nargin, Collector_MyTemplateMatrix::iterator item; item = collector_MyTemplateMatrix.find(self); if(item != collector_MyTemplateMatrix.end()) { - delete self; collector_MyTemplateMatrix.erase(item); } + delete self; } void MyTemplateMatrix_accept_T_23(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -457,7 +457,7 @@ void MyTemplateMatrix_templatedMethod_33(int nargout, mxArray *out[], int nargin void MyTemplateMatrix_Level_34(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { - checkArguments("MyTemplateMatrix.Level",nargout,nargin,1); + checkArguments("MyTemplate.Level",nargout,nargin,1); Matrix K = unwrap< Matrix >(in[0]); out[0] = wrap_shared_ptr(boost::make_shared>(MyTemplate::Level(K)),"MyTemplateMatrix", false); } @@ -492,9 +492,9 @@ void ForwardKinematicsFactor_deconstructor_37(int nargout, mxArray *out[], int n Collector_ForwardKinematicsFactor::iterator item; item = collector_ForwardKinematicsFactor.find(self); if(item != collector_ForwardKinematicsFactor.end()) { - delete self; collector_ForwardKinematicsFactor.erase(item); } + delete self; } diff --git a/wrap/tests/expected/matlab/multiple_files_wrapper.cpp b/wrap/tests/expected/matlab/multiple_files_wrapper.cpp index 66ab7ff73..864ae75d6 100644 --- a/wrap/tests/expected/matlab/multiple_files_wrapper.cpp +++ b/wrap/tests/expected/matlab/multiple_files_wrapper.cpp @@ -75,7 +75,7 @@ void _multiple_files_RTTIRegister() { mxDestroyArray(registry); mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL); - if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) { + if(mexPutVariable("global", "gtsam_multiple_files_rttiRegistry_created", newAlreadyCreated) != 0) { mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); } mxDestroyArray(newAlreadyCreated); @@ -110,9 +110,9 @@ void gtsamClass1_deconstructor_2(int nargout, mxArray *out[], int nargin, const Collector_gtsamClass1::iterator item; item = collector_gtsamClass1.find(self); if(item != collector_gtsamClass1.end()) { - delete self; collector_gtsamClass1.erase(item); } + delete self; } void gtsamClass2_collectorInsertAndMakeBase_3(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -143,9 +143,9 @@ void gtsamClass2_deconstructor_5(int nargout, mxArray *out[], int nargin, const Collector_gtsamClass2::iterator item; item = collector_gtsamClass2.find(self); if(item != collector_gtsamClass2.end()) { - delete self; collector_gtsamClass2.erase(item); } + delete self; } void gtsamClassA_collectorInsertAndMakeBase_6(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -176,9 +176,9 @@ void gtsamClassA_deconstructor_8(int nargout, mxArray *out[], int nargin, const Collector_gtsamClassA::iterator item; item = collector_gtsamClassA.find(self); if(item != collector_gtsamClassA.end()) { - delete self; collector_gtsamClassA.erase(item); } + delete self; } diff --git a/wrap/tests/expected/matlab/namespaces_wrapper.cpp b/wrap/tests/expected/matlab/namespaces_wrapper.cpp index 604ede5da..b2fe3eed6 100644 --- a/wrap/tests/expected/matlab/namespaces_wrapper.cpp +++ b/wrap/tests/expected/matlab/namespaces_wrapper.cpp @@ -112,7 +112,7 @@ void _namespaces_RTTIRegister() { mxDestroyArray(registry); mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL); - if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) { + if(mexPutVariable("global", "gtsam_namespaces_rttiRegistry_created", newAlreadyCreated) != 0) { mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); } mxDestroyArray(newAlreadyCreated); @@ -147,9 +147,9 @@ void ns1ClassA_deconstructor_2(int nargout, mxArray *out[], int nargin, const mx Collector_ns1ClassA::iterator item; item = collector_ns1ClassA.find(self); if(item != collector_ns1ClassA.end()) { - delete self; collector_ns1ClassA.erase(item); } + delete self; } void ns1ClassB_collectorInsertAndMakeBase_3(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -180,9 +180,9 @@ void ns1ClassB_deconstructor_5(int nargout, mxArray *out[], int nargin, const mx Collector_ns1ClassB::iterator item; item = collector_ns1ClassB.find(self); if(item != collector_ns1ClassB.end()) { - delete self; collector_ns1ClassB.erase(item); } + delete self; } void aGlobalFunction_6(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -218,9 +218,9 @@ void ns2ClassA_deconstructor_9(int nargout, mxArray *out[], int nargin, const mx Collector_ns2ClassA::iterator item; item = collector_ns2ClassA.find(self); if(item != collector_ns2ClassA.end()) { - delete self; collector_ns2ClassA.erase(item); } + delete self; } void ns2ClassA_memberFunction_10(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -248,7 +248,7 @@ void ns2ClassA_nsReturn_12(int nargout, mxArray *out[], int nargin, const mxArra void ns2ClassA_afunction_13(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { - checkArguments("ns2ClassA.afunction",nargout,nargin,0); + checkArguments("ns2::ClassA.afunction",nargout,nargin,0); out[0] = wrap< double >(ns2::ClassA::afunction()); } @@ -280,9 +280,9 @@ void ns2ns3ClassB_deconstructor_16(int nargout, mxArray *out[], int nargin, cons Collector_ns2ns3ClassB::iterator item; item = collector_ns2ns3ClassB.find(self); if(item != collector_ns2ns3ClassB.end()) { - delete self; collector_ns2ns3ClassB.erase(item); } + delete self; } void ns2ClassC_collectorInsertAndMakeBase_17(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -313,9 +313,9 @@ void ns2ClassC_deconstructor_19(int nargout, mxArray *out[], int nargin, const m Collector_ns2ClassC::iterator item; item = collector_ns2ClassC.find(self); if(item != collector_ns2ClassC.end()) { - delete self; collector_ns2ClassC.erase(item); } + delete self; } void aGlobalFunction_20(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -364,9 +364,9 @@ void ClassD_deconstructor_25(int nargout, mxArray *out[], int nargin, const mxAr Collector_ClassD::iterator item; item = collector_ClassD.find(self); if(item != collector_ClassD.end()) { - delete self; collector_ClassD.erase(item); } + delete self; } void gtsamValues_collectorInsertAndMakeBase_26(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -409,9 +409,9 @@ void gtsamValues_deconstructor_29(int nargout, mxArray *out[], int nargin, const Collector_gtsamValues::iterator item; item = collector_gtsamValues.find(self); if(item != collector_gtsamValues.end()) { - delete self; collector_gtsamValues.erase(item); } + delete self; } void gtsamValues_insert_30(int nargout, mxArray *out[], int nargin, const mxArray *in[]) diff --git a/wrap/tests/expected/matlab/setPose.m b/wrap/tests/expected/matlab/setPose.m new file mode 100644 index 000000000..d45cc5692 --- /dev/null +++ b/wrap/tests/expected/matlab/setPose.m @@ -0,0 +1,8 @@ +function varargout = setPose(varargin) + if length(varargin) == 1 && isa(varargin{1},'gtsam.Pose3') + functions_wrapper(23, varargin{:}); + elseif length(varargin) == 0 + functions_wrapper(24, varargin{:}); + else + error('Arguments do not match any overload of function setPose'); + end diff --git a/wrap/tests/expected/matlab/special_cases_wrapper.cpp b/wrap/tests/expected/matlab/special_cases_wrapper.cpp index 69abbf73b..c6704c20f 100644 --- a/wrap/tests/expected/matlab/special_cases_wrapper.cpp +++ b/wrap/tests/expected/matlab/special_cases_wrapper.cpp @@ -84,7 +84,7 @@ void _special_cases_RTTIRegister() { mxDestroyArray(registry); mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL); - if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) { + if(mexPutVariable("global", "gtsam_special_cases_rttiRegistry_created", newAlreadyCreated) != 0) { mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); } mxDestroyArray(newAlreadyCreated); @@ -108,9 +108,9 @@ void gtsamNonlinearFactorGraph_deconstructor_1(int nargout, mxArray *out[], int Collector_gtsamNonlinearFactorGraph::iterator item; item = collector_gtsamNonlinearFactorGraph.find(self); if(item != collector_gtsamNonlinearFactorGraph.end()) { - delete self; collector_gtsamNonlinearFactorGraph.erase(item); } + delete self; } void gtsamNonlinearFactorGraph_addPrior_2(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -140,9 +140,9 @@ void gtsamSfmTrack_deconstructor_4(int nargout, mxArray *out[], int nargin, cons Collector_gtsamSfmTrack::iterator item; item = collector_gtsamSfmTrack.find(self); if(item != collector_gtsamSfmTrack.end()) { - delete self; collector_gtsamSfmTrack.erase(item); } + delete self; } void gtsamPinholeCameraCal3Bundler_collectorInsertAndMakeBase_5(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -162,9 +162,9 @@ void gtsamPinholeCameraCal3Bundler_deconstructor_6(int nargout, mxArray *out[], Collector_gtsamPinholeCameraCal3Bundler::iterator item; item = collector_gtsamPinholeCameraCal3Bundler.find(self); if(item != collector_gtsamPinholeCameraCal3Bundler.end()) { - delete self; collector_gtsamPinholeCameraCal3Bundler.erase(item); } + delete self; } void gtsamGeneralSFMFactorCal3Bundler_collectorInsertAndMakeBase_7(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -184,9 +184,9 @@ void gtsamGeneralSFMFactorCal3Bundler_deconstructor_8(int nargout, mxArray *out[ Collector_gtsamGeneralSFMFactorCal3Bundler::iterator item; item = collector_gtsamGeneralSFMFactorCal3Bundler.find(self); if(item != collector_gtsamGeneralSFMFactorCal3Bundler.end()) { - delete self; collector_gtsamGeneralSFMFactorCal3Bundler.erase(item); } + delete self; } diff --git a/wrap/tests/expected/matlab/template_wrapper.cpp b/wrap/tests/expected/matlab/template_wrapper.cpp index 532bdd57a..a0b1aaa7e 100644 --- a/wrap/tests/expected/matlab/template_wrapper.cpp +++ b/wrap/tests/expected/matlab/template_wrapper.cpp @@ -67,7 +67,7 @@ void _template_RTTIRegister() { mxDestroyArray(registry); mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL); - if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) { + if(mexPutVariable("global", "gtsam_template_rttiRegistry_created", newAlreadyCreated) != 0) { mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); } mxDestroyArray(newAlreadyCreated); @@ -138,9 +138,9 @@ void TemplatedConstructor_deconstructor_5(int nargout, mxArray *out[], int nargi Collector_TemplatedConstructor::iterator item; item = collector_TemplatedConstructor.find(self); if(item != collector_TemplatedConstructor.end()) { - delete self; collector_TemplatedConstructor.erase(item); } + delete self; } void ScopedTemplateResult_collectorInsertAndMakeBase_6(int nargout, mxArray *out[], int nargin, const mxArray *in[]) @@ -172,9 +172,9 @@ void ScopedTemplateResult_deconstructor_8(int nargout, mxArray *out[], int nargi Collector_ScopedTemplateResult::iterator item; item = collector_ScopedTemplateResult.find(self); if(item != collector_ScopedTemplateResult.end()) { - delete self; collector_ScopedTemplateResult.erase(item); } + delete self; } diff --git a/wrap/tests/expected/python/class_pybind.cpp b/wrap/tests/expected/python/class_pybind.cpp index a54d9135b..fd5398912 100644 --- a/wrap/tests/expected/python/class_pybind.cpp +++ b/wrap/tests/expected/python/class_pybind.cpp @@ -31,6 +31,7 @@ PYBIND11_MODULE(class_py, m_) { py::class_, std::shared_ptr>>(m_, "FunDouble") .def("templatedMethodString",[](Fun* self, double d, string t){return self->templatedMethod(d, t);}, py::arg("d"), py::arg("t")) .def("multiTemplatedMethodStringSize_t",[](Fun* self, double d, string t, size_t u){return self->multiTemplatedMethod(d, t, u);}, py::arg("d"), py::arg("t"), py::arg("u")) + .def("sets",[](Fun* self){return self->sets();}) .def_static("staticMethodWithThis",[](){return Fun::staticMethodWithThis();}) .def_static("templatedStaticMethodInt",[](const int& m){return Fun::templatedStaticMethod(m);}, py::arg("m")); @@ -68,6 +69,7 @@ PYBIND11_MODULE(class_py, m_) { .def("set_container",[](Test* self, std::vector> container){ self->set_container(container);}, py::arg("container")) .def("set_container",[](Test* self, std::vector container){ self->set_container(container);}, py::arg("container")) .def("get_container",[](Test* self){return self->get_container();}) + .def("_repr_markdown_",[](Test* self, const gtsam::KeyFormatter& keyFormatter){return self->markdown(keyFormatter);}, py::arg("keyFormatter") = gtsam::DefaultKeyFormatter) .def_readwrite("model_ptr", &Test::model_ptr); py::class_, std::shared_ptr>>(m_, "PrimitiveRefDouble") diff --git a/wrap/tests/expected/python/functions_pybind.cpp b/wrap/tests/expected/python/functions_pybind.cpp index bee95ec03..8511b5d3c 100644 --- a/wrap/tests/expected/python/functions_pybind.cpp +++ b/wrap/tests/expected/python/functions_pybind.cpp @@ -33,7 +33,7 @@ PYBIND11_MODULE(functions_py, m_) { m_.def("DefaultFuncInt",[](int a, int b){ ::DefaultFuncInt(a, b);}, py::arg("a") = 123, py::arg("b") = 0); m_.def("DefaultFuncString",[](const string& s, const string& name){ ::DefaultFuncString(s, name);}, py::arg("s") = "hello", py::arg("name") = ""); m_.def("DefaultFuncObj",[](const gtsam::KeyFormatter& keyFormatter){ ::DefaultFuncObj(keyFormatter);}, py::arg("keyFormatter") = gtsam::DefaultKeyFormatter); - m_.def("DefaultFuncZero",[](int a, int b, double c, bool d, bool e){ ::DefaultFuncZero(a, b, c, d, e);}, py::arg("a") = 0, py::arg("b"), py::arg("c") = 0.0, py::arg("d") = false, py::arg("e")); + m_.def("DefaultFuncZero",[](int a, int b, double c, int d, bool e){ ::DefaultFuncZero(a, b, c, d, e);}, py::arg("a"), py::arg("b"), py::arg("c") = 0.0, py::arg("d") = 0, py::arg("e") = false); m_.def("DefaultFuncVector",[](const std::vector& i, const std::vector& s){ ::DefaultFuncVector(i, s);}, py::arg("i") = {1, 2, 3}, py::arg("s") = {"borglab", "gtsam"}); m_.def("setPose",[](const gtsam::Pose3& pose){ ::setPose(pose);}, py::arg("pose") = gtsam::Pose3()); m_.def("TemplatedFunctionRot3",[](const gtsam::Rot3& t){ ::TemplatedFunction(t);}, py::arg("t")); diff --git a/wrap/tests/fixtures/class.i b/wrap/tests/fixtures/class.i index 40a565506..f38c27411 100644 --- a/wrap/tests/fixtures/class.i +++ b/wrap/tests/fixtures/class.i @@ -18,6 +18,8 @@ class Fun { template This multiTemplatedMethod(double d, T t, U u); + + std::map sets(); }; @@ -75,6 +77,10 @@ class Test { void set_container(std::vector container); std::vector get_container() const; + // special ipython method + string markdown(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + // comments at the end! // even more comments at the end! diff --git a/wrap/tests/fixtures/functions.i b/wrap/tests/fixtures/functions.i index 0852a3e1e..7f3c83332 100644 --- a/wrap/tests/fixtures/functions.i +++ b/wrap/tests/fixtures/functions.i @@ -31,7 +31,7 @@ typedef TemplatedFunction TemplatedFunctionRot3; void DefaultFuncInt(int a = 123, int b = 0); void DefaultFuncString(const string& s = "hello", const string& name = ""); void DefaultFuncObj(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter); -void DefaultFuncZero(int a = 0, int b, double c = 0.0, bool d = false, bool e); +void DefaultFuncZero(int a, int b, double c = 0.0, int d = 0, bool e = false); void DefaultFuncVector(const std::vector &i = {1, 2, 3}, const std::vector &s = {"borglab", "gtsam"}); // Test for non-trivial default constructor diff --git a/wrap/tests/fixtures/geometry.i b/wrap/tests/fixtures/geometry.i index a7b900f80..e1460666c 100644 --- a/wrap/tests/fixtures/geometry.i +++ b/wrap/tests/fixtures/geometry.i @@ -24,9 +24,6 @@ class Point2 { VectorNotEigen vectorConfusion(); void serializable() const; // Sets flag and creates export, but does not make serialization functions - - // enable pickling in python - void pickle() const; }; #include @@ -40,9 +37,6 @@ class Point3 { // enabling serialization functionality void serialize() const; // Just triggers a flag internally and removes actual function - - // enable pickling in python - void pickle() const; }; } diff --git a/wrap/tests/test_interface_parser.py b/wrap/tests/test_interface_parser.py index 49165328c..2603e9db4 100644 --- a/wrap/tests/test_interface_parser.py +++ b/wrap/tests/test_interface_parser.py @@ -657,8 +657,6 @@ class TestInterfaceParser(unittest.TestCase): int globalVar; """) - # print("module: ", module) - # print(dir(module.content[0].name)) self.assertEqual(["one", "Global", "globalVar"], [x.name for x in module.content]) self.assertEqual(["two", "two_dummy", "two", "oneVar"], diff --git a/wrap/tests/test_matlab_wrapper.py b/wrap/tests/test_matlab_wrapper.py index 34940d62e..43fedf7aa 100644 --- a/wrap/tests/test_matlab_wrapper.py +++ b/wrap/tests/test_matlab_wrapper.py @@ -92,10 +92,19 @@ class TestWrap(unittest.TestCase): wrapper.wrap([file], path=self.MATLAB_ACTUAL_DIR) files = [ - 'functions_wrapper.cpp', 'aGlobalFunction.m', 'load2D.m', + 'functions_wrapper.cpp', + 'aGlobalFunction.m', + 'load2D.m', 'MultiTemplatedFunctionDoubleSize_tDouble.m', 'MultiTemplatedFunctionStringSize_tDouble.m', - 'overloadedGlobalFunction.m', 'TemplatedFunctionRot3.m' + 'overloadedGlobalFunction.m', + 'TemplatedFunctionRot3.m', + 'DefaultFuncInt.m', + 'DefaultFuncObj.m', + 'DefaultFuncString.m', + 'DefaultFuncVector.m', + 'DefaultFuncZero.m', + 'setPose.m', ] for file in files: @@ -115,10 +124,17 @@ class TestWrap(unittest.TestCase): wrapper.wrap([file], path=self.MATLAB_ACTUAL_DIR) files = [ - 'class_wrapper.cpp', 'FunDouble.m', 'FunRange.m', - 'MultipleTemplatesIntDouble.m', 'MultipleTemplatesIntFloat.m', - 'MyFactorPosePoint2.m', 'MyVector3.m', 'MyVector12.m', - 'PrimitiveRefDouble.m', 'Test.m' + 'class_wrapper.cpp', + 'FunDouble.m', + 'FunRange.m', + 'MultipleTemplatesIntDouble.m', + 'MultipleTemplatesIntFloat.m', + 'MyFactorPosePoint2.m', + 'MyVector3.m', + 'MyVector12.m', + 'PrimitiveRefDouble.m', + 'Test.m', + 'ForwardKinematics.m', ] for file in files: @@ -137,7 +153,10 @@ class TestWrap(unittest.TestCase): wrapper.wrap([file], path=self.MATLAB_ACTUAL_DIR) - files = ['template_wrapper.cpp'] + files = [ + 'template_wrapper.cpp', 'ScopedTemplateResult.m', + 'TemplatedConstructor.m' + ] for file in files: actual = osp.join(self.MATLAB_ACTUAL_DIR, file) @@ -155,8 +174,11 @@ class TestWrap(unittest.TestCase): wrapper.wrap([file], path=self.MATLAB_ACTUAL_DIR) files = [ - 'inheritance_wrapper.cpp', 'MyBase.m', 'MyTemplateMatrix.m', - 'MyTemplatePoint2.m' + 'inheritance_wrapper.cpp', + 'MyBase.m', + 'MyTemplateMatrix.m', + 'MyTemplatePoint2.m', + 'ForwardKinematicsFactor.m', ] for file in files: @@ -178,10 +200,17 @@ class TestWrap(unittest.TestCase): wrapper.wrap([file], path=self.MATLAB_ACTUAL_DIR) files = [ - 'namespaces_wrapper.cpp', '+ns1/aGlobalFunction.m', - '+ns1/ClassA.m', '+ns1/ClassB.m', '+ns2/+ns3/ClassB.m', - '+ns2/aGlobalFunction.m', '+ns2/ClassA.m', '+ns2/ClassC.m', - '+ns2/overloadedGlobalFunction.m', 'ClassD.m' + 'namespaces_wrapper.cpp', + '+ns1/aGlobalFunction.m', + '+ns1/ClassA.m', + '+ns1/ClassB.m', + '+ns2/+ns3/ClassB.m', + '+ns2/aGlobalFunction.m', + '+ns2/ClassA.m', + '+ns2/ClassC.m', + '+ns2/overloadedGlobalFunction.m', + 'ClassD.m', + '+gtsam/Values.m', ] for file in files: @@ -203,8 +232,10 @@ class TestWrap(unittest.TestCase): files = [ 'special_cases_wrapper.cpp', - '+gtsam/PinholeCameraCal3Bundler.m', + '+gtsam/GeneralSFMFactorCal3Bundler.m', '+gtsam/NonlinearFactorGraph.m', + '+gtsam/PinholeCameraCal3Bundler.m', + '+gtsam/SfmTrack.m', ] for file in files: