Merge branch 'develop' into hybrid-tablefactor-3

release/4.3a0
Varun Agrawal 2023-07-23 17:21:38 -04:00
commit 381c33c6d4
62 changed files with 1262 additions and 814 deletions

View File

@ -1,323 +0,0 @@
# The MIT License (MIT)
#
# Copyright (c) 2015 Justus Calvin
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# FindTBB
# -------
#
# Find TBB include directories and libraries.
#
# Usage:
#
# find_package(TBB [major[.minor]] [EXACT]
# [QUIET] [REQUIRED]
# [[COMPONENTS] [components...]]
# [OPTIONAL_COMPONENTS components...])
#
# where the allowed components are tbbmalloc and tbb_preview. Users may modify
# the behavior of this module with the following variables:
#
# * TBB_ROOT_DIR - The base directory of the TBB installation.
# * TBB_INCLUDE_DIR - The directory that contains the TBB header files.
# * TBB_LIBRARY - The directory that contains the TBB library files.
# * TBB_<library>_LIBRARY - The path of the corresponding TBB library.
# These libraries, if specified, override the
# corresponding library search results, where <library>
# may be tbb, tbb_debug, tbbmalloc, tbbmalloc_debug,
# tbb_preview, or tbb_preview_debug.
# * TBB_USE_DEBUG_BUILD - The debug version of tbb libraries, if present, will
# be used instead of the release version.
#
# Users may modify the behavior of this module with the following environment
# variables:
#
# * TBB_INSTALL_DIR
# * TBBROOT
# * LIBRARY_PATH
#
# This module will set the following variables:
#
# * TBB_FOUND - Set to false, or undefined, if we haven't found, or
# don't want to use TBB.
# * TBB_<component>_FOUND - If False, optional <component> part of TBB system is
# not available.
# * TBB_VERSION - The full version string
# * TBB_VERSION_MAJOR - The major version
# * TBB_VERSION_MINOR - The minor version
# * TBB_INTERFACE_VERSION - The interface version number defined in
# tbb/tbb_stddef.h.
# * TBB_<library>_LIBRARY_RELEASE - The path of the TBB release version of
# <library>, where <library> may be tbb, tbb_debug,
# tbbmalloc, tbbmalloc_debug, tbb_preview, or
# tbb_preview_debug.
# * TBB_<library>_LIBRARY_DEBUG - The path of the TBB debug version of
# <library>, where <library> may be tbb, tbb_debug,
# tbbmalloc, tbbmalloc_debug, tbb_preview, or
# tbb_preview_debug.
#
# The following variables should be used to build and link with TBB:
#
# * TBB_INCLUDE_DIRS - The include directory for TBB.
# * TBB_LIBRARIES - The libraries to link against to use TBB.
# * TBB_LIBRARIES_RELEASE - The release libraries to link against to use TBB.
# * TBB_LIBRARIES_DEBUG - The debug libraries to link against to use TBB.
# * TBB_DEFINITIONS - Definitions to use when compiling code that uses
# TBB.
# * TBB_DEFINITIONS_RELEASE - Definitions to use when compiling release code that
# uses TBB.
# * TBB_DEFINITIONS_DEBUG - Definitions to use when compiling debug code that
# uses TBB.
#
# This module will also create the "tbb" target that may be used when building
# executables and libraries.
include(FindPackageHandleStandardArgs)
if(NOT TBB_FOUND)
##################################
# Check the build type
##################################
if(NOT DEFINED TBB_USE_DEBUG_BUILD)
# Set build type to RELEASE by default for optimization.
set(TBB_BUILD_TYPE RELEASE)
elseif(TBB_USE_DEBUG_BUILD)
set(TBB_BUILD_TYPE DEBUG)
else()
set(TBB_BUILD_TYPE RELEASE)
endif()
##################################
# Set the TBB search directories
##################################
# Define search paths based on user input and environment variables
set(TBB_SEARCH_DIR ${TBB_ROOT_DIR} $ENV{TBB_INSTALL_DIR} $ENV{TBBROOT})
# Define the search directories based on the current platform
if(CMAKE_SYSTEM_NAME STREQUAL "Windows")
set(TBB_DEFAULT_SEARCH_DIR "C:/Program Files/Intel/TBB"
"C:/Program Files (x86)/Intel/TBB")
# Set the target architecture
if(CMAKE_SIZEOF_VOID_P EQUAL 8)
set(TBB_ARCHITECTURE "intel64")
else()
set(TBB_ARCHITECTURE "ia32")
endif()
# Set the TBB search library path search suffix based on the version of VC
if(WINDOWS_STORE)
set(TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc11_ui")
elseif(MSVC14)
set(TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc14")
elseif(MSVC12)
set(TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc12")
elseif(MSVC11)
set(TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc11")
elseif(MSVC10)
set(TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc10")
endif()
# Add the library path search suffix for the VC independent version of TBB
list(APPEND TBB_LIB_PATH_SUFFIX "lib/${TBB_ARCHITECTURE}/vc_mt")
elseif(CMAKE_SYSTEM_NAME STREQUAL "Darwin")
# OS X
set(TBB_DEFAULT_SEARCH_DIR "/opt/intel/tbb"
"/usr/local/opt/tbb")
# TODO: Check to see which C++ library is being used by the compiler.
if(NOT ${CMAKE_SYSTEM_VERSION} VERSION_LESS 13.0)
# The default C++ library on OS X 10.9 and later is libc++
set(TBB_LIB_PATH_SUFFIX "lib/libc++" "lib")
else()
set(TBB_LIB_PATH_SUFFIX "lib")
endif()
elseif(CMAKE_SYSTEM_NAME STREQUAL "Linux")
# Linux
set(TBB_DEFAULT_SEARCH_DIR "/opt/intel/tbb")
# TODO: Check the compiler version to see whether the suffix should be
# <arch>/gcc4.1 or <arch>/gcc4.4. For now, assume that the compiler is
# gcc 4.4.x or later.
if(CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64")
set(TBB_LIB_PATH_SUFFIX "lib/intel64/gcc4.4")
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^i.86$")
set(TBB_LIB_PATH_SUFFIX "lib/ia32/gcc4.4")
endif()
endif()
##################################
# Find the TBB include dir
##################################
find_path(TBB_INCLUDE_DIRS tbb/tbb.h
HINTS ${TBB_INCLUDE_DIR} ${TBB_SEARCH_DIR}
PATHS ${TBB_DEFAULT_SEARCH_DIR}
PATH_SUFFIXES include)
##################################
# Set version strings
##################################
if(TBB_INCLUDE_DIRS)
set(_tbb_version_file_prior_to_tbb_2021_1 "${TBB_INCLUDE_DIRS}/tbb/tbb_stddef.h")
set(_tbb_version_file_after_tbb_2021_1 "${TBB_INCLUDE_DIRS}/oneapi/tbb/version.h")
if (EXISTS "${_tbb_version_file_prior_to_tbb_2021_1}")
file(READ "${_tbb_version_file_prior_to_tbb_2021_1}" _tbb_version_file )
elseif (EXISTS "${_tbb_version_file_after_tbb_2021_1}")
file(READ "${_tbb_version_file_after_tbb_2021_1}" _tbb_version_file )
else()
message(FATAL_ERROR "Found TBB installation: ${TBB_INCLUDE_DIRS} "
"missing version header.")
endif()
string(REGEX REPLACE ".*#define TBB_VERSION_MAJOR ([0-9]+).*" "\\1"
TBB_VERSION_MAJOR "${_tbb_version_file}")
string(REGEX REPLACE ".*#define TBB_VERSION_MINOR ([0-9]+).*" "\\1"
TBB_VERSION_MINOR "${_tbb_version_file}")
string(REGEX REPLACE ".*#define TBB_INTERFACE_VERSION ([0-9]+).*" "\\1"
TBB_INTERFACE_VERSION "${_tbb_version_file}")
set(TBB_VERSION "${TBB_VERSION_MAJOR}.${TBB_VERSION_MINOR}")
endif()
##################################
# Find TBB components
##################################
if(TBB_VERSION VERSION_LESS 4.3)
set(TBB_SEARCH_COMPOMPONENTS tbb_preview tbbmalloc tbb)
else()
set(TBB_SEARCH_COMPOMPONENTS tbb_preview tbbmalloc_proxy tbbmalloc tbb)
endif()
# Find each component
foreach(_comp ${TBB_SEARCH_COMPOMPONENTS})
if(";${TBB_FIND_COMPONENTS};tbb;" MATCHES ";${_comp};")
# Search for the libraries
find_library(TBB_${_comp}_LIBRARY_RELEASE ${_comp}
HINTS ${TBB_LIBRARY} ${TBB_SEARCH_DIR}
PATHS ${TBB_DEFAULT_SEARCH_DIR} ENV LIBRARY_PATH
PATH_SUFFIXES ${TBB_LIB_PATH_SUFFIX})
find_library(TBB_${_comp}_LIBRARY_DEBUG ${_comp}_debug
HINTS ${TBB_LIBRARY} ${TBB_SEARCH_DIR}
PATHS ${TBB_DEFAULT_SEARCH_DIR} ENV LIBRARY_PATH
PATH_SUFFIXES ${TBB_LIB_PATH_SUFFIX})
if(TBB_${_comp}_LIBRARY_DEBUG)
list(APPEND TBB_LIBRARIES_DEBUG "${TBB_${_comp}_LIBRARY_DEBUG}")
endif()
if(TBB_${_comp}_LIBRARY_RELEASE)
list(APPEND TBB_LIBRARIES_RELEASE "${TBB_${_comp}_LIBRARY_RELEASE}")
endif()
if(TBB_${_comp}_LIBRARY_${TBB_BUILD_TYPE} AND NOT TBB_${_comp}_LIBRARY)
set(TBB_${_comp}_LIBRARY "${TBB_${_comp}_LIBRARY_${TBB_BUILD_TYPE}}")
endif()
if(TBB_${_comp}_LIBRARY AND EXISTS "${TBB_${_comp}_LIBRARY}")
set(TBB_${_comp}_FOUND TRUE)
else()
set(TBB_${_comp}_FOUND FALSE)
endif()
# Mark internal variables as advanced
mark_as_advanced(TBB_${_comp}_LIBRARY_RELEASE)
mark_as_advanced(TBB_${_comp}_LIBRARY_DEBUG)
mark_as_advanced(TBB_${_comp}_LIBRARY)
endif()
endforeach()
##################################
# Set compile flags and libraries
##################################
set(TBB_DEFINITIONS_RELEASE "")
set(TBB_DEFINITIONS_DEBUG "-DTBB_USE_DEBUG=1")
if(TBB_LIBRARIES_${TBB_BUILD_TYPE})
set(TBB_DEFINITIONS "${TBB_DEFINITIONS_${TBB_BUILD_TYPE}}")
set(TBB_LIBRARIES "${TBB_LIBRARIES_${TBB_BUILD_TYPE}}")
elseif(TBB_LIBRARIES_RELEASE)
set(TBB_DEFINITIONS "${TBB_DEFINITIONS_RELEASE}")
set(TBB_LIBRARIES "${TBB_LIBRARIES_RELEASE}")
elseif(TBB_LIBRARIES_DEBUG)
set(TBB_DEFINITIONS "${TBB_DEFINITIONS_DEBUG}")
set(TBB_LIBRARIES "${TBB_LIBRARIES_DEBUG}")
endif()
find_package_handle_standard_args(TBB
REQUIRED_VARS TBB_INCLUDE_DIRS TBB_LIBRARIES
HANDLE_COMPONENTS
VERSION_VAR TBB_VERSION)
##################################
# Create targets
##################################
if(NOT CMAKE_VERSION VERSION_LESS 3.0 AND TBB_FOUND)
# Start fix to support different targets for tbb, tbbmalloc, etc.
# (Jose Luis Blanco, Jan 2019)
# Iterate over tbb, tbbmalloc, etc.
foreach(libname ${TBB_SEARCH_COMPOMPONENTS})
if ((NOT TBB_${libname}_LIBRARY_RELEASE) AND (NOT TBB_${libname}_LIBRARY_DEBUG))
continue()
endif()
add_library(${libname} SHARED IMPORTED)
set_target_properties(${libname} PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES ${TBB_INCLUDE_DIRS}
IMPORTED_LOCATION ${TBB_${libname}_LIBRARY_RELEASE})
if(TBB_${libname}_LIBRARY_RELEASE AND TBB_${libname}_LIBRARY_DEBUG)
set_target_properties(${libname} PROPERTIES
INTERFACE_COMPILE_DEFINITIONS "$<$<OR:$<CONFIG:Debug>,$<CONFIG:RelWithDebInfo>>:TBB_USE_DEBUG=1>"
IMPORTED_LOCATION_DEBUG ${TBB_${libname}_LIBRARY_DEBUG}
IMPORTED_LOCATION_RELWITHDEBINFO ${TBB_${libname}_LIBRARY_DEBUG}
IMPORTED_LOCATION_RELEASE ${TBB_${libname}_LIBRARY_RELEASE}
IMPORTED_LOCATION_MINSIZEREL ${TBB_${libname}_LIBRARY_RELEASE}
)
elseif(TBB_${libname}_LIBRARY_RELEASE)
set_target_properties(${libname} PROPERTIES IMPORTED_LOCATION ${TBB_${libname}_LIBRARY_RELEASE})
else()
set_target_properties(${libname} PROPERTIES
INTERFACE_COMPILE_DEFINITIONS "${TBB_DEFINITIONS_DEBUG}"
IMPORTED_LOCATION ${TBB_${libname}_LIBRARY_DEBUG}
)
endif()
endforeach()
# End of fix to support different targets
endif()
mark_as_advanced(TBB_INCLUDE_DIRS TBB_LIBRARIES)
unset(TBB_ARCHITECTURE)
unset(TBB_BUILD_TYPE)
unset(TBB_LIB_PATH_SUFFIX)
unset(TBB_DEFAULT_SEARCH_DIR)
endif()

View File

@ -14,7 +14,7 @@ if (GTSAM_WITH_TBB)
endif() endif()
# all definitions and link requisites will go via imported targets: # all definitions and link requisites will go via imported targets:
# tbb & tbbmalloc # tbb & tbbmalloc
list(APPEND GTSAM_ADDITIONAL_LIBRARIES tbb tbbmalloc) list(APPEND GTSAM_ADDITIONAL_LIBRARIES TBB::tbb TBB::tbbmalloc)
else() else()
set(GTSAM_USE_TBB 0) # This will go into config.h set(GTSAM_USE_TBB 0) # This will go into config.h
endif() endif()

View File

@ -22,13 +22,6 @@
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#ifdef GTSAM_USE_BOOST_FEATURES
#include <boost/concept_check.hpp>
#include <boost/concept/requires.hpp>
#include <boost/type_traits/is_base_of.hpp>
#include <boost/static_assert.hpp>
#endif
#include <utility> #include <utility>
namespace gtsam { namespace gtsam {

View File

@ -19,9 +19,10 @@
#include <gtsam/base/debug.h> #include <gtsam/base/debug.h>
#include <gtsam/base/timing.h> #include <gtsam/base/timing.h>
#include <algorithm>
#include <cassert>
#include <cmath> #include <cmath>
#include <cstddef> #include <cstddef>
#include <cassert>
#include <iomanip> #include <iomanip>
#include <iostream> #include <iostream>
#include <map> #include <map>

View File

@ -19,16 +19,11 @@
#pragma once #pragma once
#include <gtsam/config.h> // for GTSAM_USE_TBB
#include <gtsam/dllexport.h> #include <gtsam/dllexport.h>
#ifdef GTSAM_USE_BOOST_FEATURES
#include <boost/concept/assert.hpp>
#include <boost/range/concepts.hpp>
#endif
#include <gtsam/config.h> // for GTSAM_USE_TBB
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include <exception> #include <exception>
#include <string> #include <string>

View File

@ -29,9 +29,9 @@
namespace gtsam { namespace gtsam {
/** /**
* Algebraic Decision Trees fix the range to double * An algebraic decision tree fixes the range of a DecisionTree to double.
* Just has some nice constructors and some syntactic sugar * Just has some nice constructors and some syntactic sugar.
* TODO: consider eliminating this class altogether? * TODO(dellaert): consider eliminating this class altogether?
* *
* @ingroup discrete * @ingroup discrete
*/ */
@ -81,20 +81,62 @@ namespace gtsam {
AlgebraicDecisionTree(const L& label, double y1, double y2) AlgebraicDecisionTree(const L& label, double y1, double y2)
: Base(label, y1, y2) {} : Base(label, y1, y2) {}
/** Create a new leaf function splitting on a variable */ /**
* @brief Create a new leaf function splitting on a variable
*
* @param labelC: The label with cardinality 2
* @param y1: The value for the first key
* @param y2: The value for the second key
*
* Example:
* @code{.cpp}
* std::pair<string, size_t> A {"a", 2};
* AlgebraicDecisionTree<string> a(A, 0.6, 0.4);
* @endcode
*/
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1, AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1,
double y2) double y2)
: Base(labelC, y1, y2) {} : Base(labelC, y1, y2) {}
/** Create from keys and vector table */ /**
* @brief Create from keys with cardinalities and a vector table
*
* @param labelCs: The keys, with cardinalities, given as pairs
* @param ys: The vector table
*
* Example with three keys, A, B, and C, with cardinalities 2, 3, and 2,
* respectively, and a vector table of size 12:
* @code{.cpp}
* DiscreteKey A(0, 2), B(1, 3), C(2, 2);
* const vector<double> cpt{
* 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
* 1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10};
* AlgebraicDecisionTree<Key> expected(A & B & C, cpt);
* @endcode
* The table is given in the following order:
* A=0, B=0, C=0
* A=0, B=0, C=1
* ...
* A=1, B=1, C=1
* Hence, the first line in the table is for A==0, and the second for A==1.
* In each line, the first two entries are for B==0, the next two for B==1,
* and the last two for B==2. Each pair is for a C value of 0 and 1.
*/
AlgebraicDecisionTree // AlgebraicDecisionTree //
(const std::vector<typename Base::LabelC>& labelCs, (const std::vector<typename Base::LabelC>& labelCs,
const std::vector<double>& ys) { const std::vector<double>& ys) {
this->root_ = this->root_ =
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
} }
/** Create from keys and string table */ /**
* @brief Create from keys and string table
*
* @param labelCs: The keys, with cardinalities, given as pairs
* @param table: The string table, given as a string of doubles.
*
* @note Table needs to be in same order as the vector table in the other constructor.
*/
AlgebraicDecisionTree // AlgebraicDecisionTree //
(const std::vector<typename Base::LabelC>& labelCs, (const std::vector<typename Base::LabelC>& labelCs,
const std::string& table) { const std::string& table) {
@ -109,7 +151,13 @@ namespace gtsam {
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
} }
/** Create a new function splitting on a variable */ /**
* @brief Create a range of decision trees, splitting on a single variable.
*
* @param begin: Iterator to beginning of a range of decision trees
* @param end: Iterator to end of a range of decision trees
* @param label: The label to split on
*/
template <typename Iterator> template <typename Iterator>
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label)
: Base(nullptr) { : Base(nullptr) {

View File

@ -93,7 +93,8 @@ namespace gtsam {
/// print /// print
void print(const std::string& s, const LabelFormatter& labelFormatter, void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const override { const ValueFormatter& valueFormatter) const override {
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl; std::cout << s << " Leaf [" << nrAssignments() << "] "
<< valueFormatter(constant_) << std::endl;
} }
/** Write graphviz format to stream `os`. */ /** Write graphviz format to stream `os`. */
@ -626,7 +627,7 @@ namespace gtsam {
// B=1 // B=1
// A=0: 3 // A=0: 3
// A=1: 4 // A=1: 4
// Note, through the magic of "compose", create([A B],[1 2 3 4]) will produce // Note, through the magic of "compose", create([A B],[1 3 2 4]) will produce
// exactly the same tree as above: the highest label is always the root. // exactly the same tree as above: the highest label is always the root.
// However, it will be *way* faster if labels are given highest to lowest. // However, it will be *way* faster if labels are given highest to lowest.
template<typename L, typename Y> template<typename L, typename Y>
@ -827,6 +828,16 @@ namespace gtsam {
return total; return total;
} }
/****************************************************************************/
template <typename L, typename Y>
size_t DecisionTree<L, Y>::nrAssignments() const {
size_t n = 0;
this->visitLeaf([&n](const DecisionTree<L, Y>::Leaf& leaf) {
n += leaf.nrAssignments();
});
return n;
}
/****************************************************************************/ /****************************************************************************/
// fold is just done with a visit // fold is just done with a visit
template <typename L, typename Y> template <typename L, typename Y>

View File

@ -39,9 +39,23 @@
namespace gtsam { namespace gtsam {
/** /**
* Decision Tree * @brief a decision tree is a function from assignments to values.
* L = label for variables * @tparam L label for variables
* Y = function range (any algebra), e.g., bool, int, double * @tparam Y function range (any algebra), e.g., bool, int, double
*
* After creating a decision tree on some variables, the tree can be evaluated
* on an assignment to those variables. Example:
*
* @code{.cpp}
* // Create a decision stump on one variable 'a' with values 10 and 20.
* DecisionTree<char, int> tree('a', 10, 20);
*
* // Evaluate the tree on an assignment to the variable.
* int value0 = tree({{'a', 0}}); // value0 = 10
* int value1 = tree({{'a', 1}}); // value1 = 20
* @endcode
*
* More examples can be found in testDecisionTree.cpp
* *
* @ingroup discrete * @ingroup discrete
*/ */
@ -136,7 +150,8 @@ namespace gtsam {
NodePtr root_; NodePtr root_;
protected: protected:
/** Internal recursive function to create from keys, cardinalities, /**
* Internal recursive function to create from keys, cardinalities,
* and Y values * and Y values
*/ */
template<typename It, typename ValueIt> template<typename It, typename ValueIt>
@ -167,7 +182,13 @@ namespace gtsam {
/** Create a constant */ /** Create a constant */
explicit DecisionTree(const Y& y); explicit DecisionTree(const Y& y);
/// Create tree with 2 assignments `y1`, `y2`, splitting on variable `label` /**
* @brief Create tree with 2 assignments `y1`, `y2`, splitting on variable `label`
*
* @param label The variable to split on.
* @param y1 The value for the first assignment.
* @param y2 The value for the second assignment.
*/
DecisionTree(const L& label, const Y& y1, const Y& y2); DecisionTree(const L& label, const Y& y1, const Y& y2);
/** Allow Label+Cardinality for convenience */ /** Allow Label+Cardinality for convenience */
@ -299,6 +320,42 @@ namespace gtsam {
/// Return the number of leaves in the tree. /// Return the number of leaves in the tree.
size_t nrLeaves() const; size_t nrLeaves() const;
/**
* @brief This is a convenience function which returns the total number of
* leaf assignments in the decision tree.
* This function is not used for any major operations within the discrete
* factor graph framework.
*
* Leaf assignments represent the cardinality of each leaf node, e.g. in a
* binary tree each leaf has 2 assignments. This includes counts removed
* by implicit pruning, hence it will always be >= nrLeaves().
*
* E.g. we have a decision tree as below, where each node has 2 branches:
*
* Choice(m1)
* 0 Choice(m0)
* 0 0 Leaf 0.0
* 0 1 Leaf 0.0
* 1 Choice(m0)
* 1 0 Leaf 1.0
* 1 1 Leaf 2.0
*
* In the unpruned form, the tree will have 4 assignments, 2 for each key,
* and 4 leaves.
*
* In the pruned form, the number of assignments is still 4 but the number
* of leaves is now 3, as below:
*
* Choice(m1)
* 0 Leaf 0.0
* 1 Choice(m0)
* 1 0 Leaf 1.0
* 1 1 Leaf 2.0
*
* @return size_t
*/
size_t nrAssignments() const;
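
To illustrate the relation above, a minimal sketch (hypothetical keys; assumes an AlgebraicDecisionTree<Key> over two binary variables, and that identical sibling leaves are merged as in the pruned example):

// assumes #include <gtsam/discrete/AlgebraicDecisionTree.h> and using namespace gtsam
DiscreteKey m0(0, 2), m1(1, 2);
// Values laid out as in the example above: m1=0 -> {0.0, 0.0}, m1=1 -> {1.0, 2.0}
AlgebraicDecisionTree<Key> tree({m1, m0}, std::vector<double>{0.0, 0.0, 1.0, 2.0});
// The two identical 0.0 leaves under m1=0 collapse into a single leaf, so
// tree.nrLeaves() evaluates to 3 while tree.nrAssignments() evaluates to 4.
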
/** /**
* @brief Fold a binary function over the tree, returning accumulator. * @brief Fold a binary function over the tree, returning accumulator.
* *

View File

@ -117,6 +117,14 @@ namespace gtsam {
return DecisionTreeFactor(keys, result); return DecisionTreeFactor(keys, result);
} }
/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(ADT::UnaryAssignment op) const {
// apply operand
ADT result = ADT::apply(op);
// Make a new factor
return DecisionTreeFactor(discreteKeys(), result);
}
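
A minimal sketch of calling the new overload (hypothetical keys; assumes ADT::UnaryAssignment is a callable over the leaf's assignment and value):

// assumes #include <gtsam/discrete/DecisionTreeFactor.h> and using namespace gtsam
DiscreteKey A(0, 2), B(1, 2);
DecisionTreeFactor f({A, B}, "1 2 3 4");
// Zero out every entry where A == 1, leave the others untouched:
DecisionTreeFactor masked =
    f.apply([](const Assignment<Key>& assignment, const double& value) {
      return assignment.at(0) == 1 ? 0.0 : value;
    });
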
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine( DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
size_t nrFrontals, ADT::Binary op) const { size_t nrFrontals, ADT::Binary op) const {

View File

@ -59,11 +59,46 @@ namespace gtsam {
/** Constructor from DiscreteKeys and AlgebraicDecisionTree */ /** Constructor from DiscreteKeys and AlgebraicDecisionTree */
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials); DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
/** Constructor from doubles */ /**
* @brief Constructor from doubles
*
* @param keys The discrete keys.
* @param table The table of values.
*
* @throw std::invalid_argument if the size of `table` does not match the
* number of assignments.
*
* Example:
* @code{.cpp}
* DiscreteKey X(0,2), Y(1,3);
* const std::vector<double> table {2, 5, 3, 6, 4, 7};
* DecisionTreeFactor f1({X, Y}, table);
* @endcode
*
* The values in the table should be laid out so that the first key varies
* the slowest, and the last key the fastest.
*/
DecisionTreeFactor(const DiscreteKeys& keys, DecisionTreeFactor(const DiscreteKeys& keys,
const std::vector<double>& table); const std::vector<double>& table);
/** Constructor from string */ /**
* @brief Constructor from string
*
* @param keys The discrete keys.
* @param table The table of values.
*
* @throw std::invalid_argument if the size of `table` does not match the
* number of assignments.
*
* Example:
* @code{.cpp}
* DiscreteKey X(0,2), Y(1,3);
* DecisionTreeFactor factor({X, Y}, "2 5 3 6 4 7");
* @endcode
*
* The values in the table should be laid out so that the first key varies
* the slowest, and the last key the fastest.
*/
DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table); DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);
/// Single-key specialization /// Single-key specialization

View File

@ -58,6 +58,11 @@ class GTSAM_EXPORT DiscreteBayesTreeClique
//** evaluate conditional probability of subtree for given DiscreteValues */ //** evaluate conditional probability of subtree for given DiscreteValues */
double evaluate(const DiscreteValues& values) const; double evaluate(const DiscreteValues& values) const;
//** (Preferred) sugar for the above for given DiscreteValues */
double operator()(const DiscreteValues& values) const {
return evaluate(values);
}
}; };
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -41,16 +41,30 @@ class DiscreteJunctionTree;
/** /**
* @brief Main elimination function for DiscreteFactorGraph. * @brief Main elimination function for DiscreteFactorGraph.
* *
* @param factors * @param factors The factor graph to eliminate.
* @param keys * @param frontalKeys An ordering for which variables to eliminate.
* @return GTSAM_EXPORT * @return A pair of the resulting conditional and the separator factor.
* @ingroup discrete * @ingroup discrete
*/ */
GTSAM_EXPORT std::pair<std::shared_ptr<DiscreteConditional>, DecisionTreeFactor::shared_ptr> GTSAM_EXPORT
EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& keys); std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr>
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys);
/**
* @brief Alternate elimination function that creates non-normalized lookup tables.
*
* @param factors The factor graph to eliminate.
* @param frontalKeys An ordering for which variables to eliminate.
* @return A pair of the resulting lookup table and the separator factor.
* @ingroup discrete
*/
GTSAM_EXPORT
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr>
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys);
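
For reference, a minimal sketch of calling the elimination functions directly (hypothetical keys and potentials):

// assumes #include <gtsam/discrete/DiscreteFactorGraph.h> and using namespace gtsam
DiscreteKey A(0, 2), B(1, 2);
DiscreteFactorGraph graph;
graph.add(A & B, "1 2 3 4");  // unnormalized potential on A and B
// Eliminate A: yields the conditional P(A | B) and a factor on the separator B.
auto [conditional, separatorFactor] = EliminateDiscrete(graph, Ordering{A.first});
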
/* ************************************************************************* */
template<> struct EliminationTraits<DiscreteFactorGraph> template<> struct EliminationTraits<DiscreteFactorGraph>
{ {
typedef DiscreteFactor FactorType; ///< Type of factors in factor graph typedef DiscreteFactor FactorType; ///< Type of factors in factor graph
@ -60,12 +74,14 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
typedef DiscreteEliminationTree EliminationTreeType; ///< Type of elimination tree typedef DiscreteEliminationTree EliminationTreeType; ///< Type of elimination tree
typedef DiscreteBayesTree BayesTreeType; ///< Type of Bayes tree typedef DiscreteBayesTree BayesTreeType; ///< Type of Bayes tree
typedef DiscreteJunctionTree JunctionTreeType; ///< Type of Junction tree typedef DiscreteJunctionTree JunctionTreeType; ///< Type of Junction tree
/// The default dense elimination function /// The default dense elimination function
static std::pair<std::shared_ptr<ConditionalType>, static std::pair<std::shared_ptr<ConditionalType>,
std::shared_ptr<FactorType> > std::shared_ptr<FactorType> >
DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) { DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
return EliminateDiscrete(factors, keys); return EliminateDiscrete(factors, keys);
} }
/// The default ordering generation function /// The default ordering generation function
static Ordering DefaultOrderingFunc( static Ordering DefaultOrderingFunc(
const FactorGraphType& graph, const FactorGraphType& graph,
@ -74,7 +90,6 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
} }
}; };
/* ************************************************************************* */
/** /**
* A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e. * A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e.
* Factor == DiscreteFactor * Factor == DiscreteFactor
@ -108,8 +123,8 @@ class GTSAM_EXPORT DiscreteFactorGraph
/** Implicit copy/downcast constructor to override explicit template container /** Implicit copy/downcast constructor to override explicit template container
* constructor */ * constructor */
template <class DERIVEDFACTOR> template <class DERIVED_FACTOR>
DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {} DiscreteFactorGraph(const FactorGraph<DERIVED_FACTOR>& graph) : Base(graph) {}
/// @name Testable /// @name Testable
/// @{ /// @{
@ -227,10 +242,6 @@ class GTSAM_EXPORT DiscreteFactorGraph
/// @} /// @}
}; // \ DiscreteFactorGraph }; // \ DiscreteFactorGraph
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys);
/// traits /// traits
template <> template <>
struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {}; struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};

View File

@ -66,4 +66,6 @@ namespace gtsam {
DiscreteJunctionTree(const DiscreteEliminationTree& eliminationTree); DiscreteJunctionTree(const DiscreteEliminationTree& eliminationTree);
}; };
/// typedef for wrapper:
using DiscreteCluster = DiscreteJunctionTree::Cluster;
} }

View File

@ -120,6 +120,11 @@ class GTSAM_EXPORT DiscreteValues : public Assignment<Key> {
/// @} /// @}
}; };
/// Free version of CartesianProduct.
inline std::vector<DiscreteValues> cartesianProduct(const DiscreteKeys& keys) {
return DiscreteValues::CartesianProduct(keys);
}
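
A short sketch of the free function in use (hypothetical keys):

// assumes #include <gtsam/discrete/DiscreteValues.h> and using namespace gtsam
DiscreteKey A(0, 2), B(1, 3);
size_t count = 0;
for (const DiscreteValues& assignment : cartesianProduct(A & B)) {
  ++count;  // visits each of the 2 * 3 = 6 joint assignments of A and B
}
// count == 6
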
/// Free version of markdown. /// Free version of markdown.
std::string markdown(const DiscreteValues& values, std::string markdown(const DiscreteValues& values,
const KeyFormatter& keyFormatter = DefaultKeyFormatter, const KeyFormatter& keyFormatter = DefaultKeyFormatter,

View File

@ -17,6 +17,8 @@ class DiscreteKeys {
}; };
// DiscreteValues is added in specializations/discrete.h as a std::map // DiscreteValues is added in specializations/discrete.h as a std::map
std::vector<gtsam::DiscreteValues> cartesianProduct(
const gtsam::DiscreteKeys& keys);
string markdown( string markdown(
const gtsam::DiscreteValues& values, const gtsam::DiscreteValues& values,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter); const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
@ -31,27 +33,30 @@ string html(const gtsam::DiscreteValues& values,
std::map<gtsam::Key, std::vector<std::string>> names); std::map<gtsam::Key, std::vector<std::string>> names);
#include <gtsam/discrete/DiscreteFactor.h> #include <gtsam/discrete/DiscreteFactor.h>
class DiscreteFactor { virtual class DiscreteFactor : gtsam::Factor {
void print(string s = "DiscreteFactor\n", void print(string s = "DiscreteFactor\n",
const gtsam::KeyFormatter& keyFormatter = const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const; bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const;
bool empty() const;
size_t size() const;
double operator()(const gtsam::DiscreteValues& values) const; double operator()(const gtsam::DiscreteValues& values) const;
}; };
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
virtual class DecisionTreeFactor : gtsam::DiscreteFactor { virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
DecisionTreeFactor(); DecisionTreeFactor();
DecisionTreeFactor(const gtsam::DiscreteKey& key, DecisionTreeFactor(const gtsam::DiscreteKey& key,
const std::vector<double>& spec); const std::vector<double>& spec);
DecisionTreeFactor(const gtsam::DiscreteKey& key, string table); DecisionTreeFactor(const gtsam::DiscreteKey& key, string table);
DecisionTreeFactor(const gtsam::DiscreteKeys& keys,
const std::vector<double>& table);
DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table); DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table);
DecisionTreeFactor(const std::vector<gtsam::DiscreteKey>& keys,
const std::vector<double>& table);
DecisionTreeFactor(const std::vector<gtsam::DiscreteKey>& keys, string table); DecisionTreeFactor(const std::vector<gtsam::DiscreteKey>& keys, string table);
DecisionTreeFactor(const gtsam::DiscreteConditional& c); DecisionTreeFactor(const gtsam::DiscreteConditional& c);
void print(string s = "DecisionTreeFactor\n", void print(string s = "DecisionTreeFactor\n",
@ -59,6 +64,8 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
size_t cardinality(gtsam::Key j) const;
double operator()(const gtsam::DiscreteValues& values) const; double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const; gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const;
size_t cardinality(gtsam::Key j) const; size_t cardinality(gtsam::Key j) const;
@ -66,6 +73,7 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
gtsam::DecisionTreeFactor* sum(size_t nrFrontals) const; gtsam::DecisionTreeFactor* sum(size_t nrFrontals) const;
gtsam::DecisionTreeFactor* sum(const gtsam::Ordering& keys) const; gtsam::DecisionTreeFactor* sum(const gtsam::Ordering& keys) const;
gtsam::DecisionTreeFactor* max(size_t nrFrontals) const; gtsam::DecisionTreeFactor* max(size_t nrFrontals) const;
gtsam::DecisionTreeFactor* max(const gtsam::Ordering& keys) const;
string dot( string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
@ -203,10 +211,16 @@ class DiscreteBayesTreeClique {
DiscreteBayesTreeClique(const gtsam::DiscreteConditional* conditional); DiscreteBayesTreeClique(const gtsam::DiscreteConditional* conditional);
const gtsam::DiscreteConditional* conditional() const; const gtsam::DiscreteConditional* conditional() const;
bool isRoot() const; bool isRoot() const;
size_t nrChildren() const;
const gtsam::DiscreteBayesTreeClique* operator[](size_t i) const;
void print(string s = "DiscreteBayesTreeClique",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void printSignature( void printSignature(
const string& s = "Clique: ", const string& s = "Clique: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
double evaluate(const gtsam::DiscreteValues& values) const; double evaluate(const gtsam::DiscreteValues& values) const;
double operator()(const gtsam::DiscreteValues& values) const;
}; };
class DiscreteBayesTree { class DiscreteBayesTree {
@ -220,6 +234,9 @@ class DiscreteBayesTree {
bool empty() const; bool empty() const;
const DiscreteBayesTreeClique* operator[](size_t j) const; const DiscreteBayesTreeClique* operator[](size_t j) const;
double evaluate(const gtsam::DiscreteValues& values) const;
double operator()(const gtsam::DiscreteValues& values) const;
string dot(const gtsam::KeyFormatter& keyFormatter = string dot(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
void saveGraph(string s, void saveGraph(string s,
@ -242,9 +259,9 @@ class DiscreteBayesTree {
class DiscreteLookupTable : gtsam::DiscreteConditional{ class DiscreteLookupTable : gtsam::DiscreteConditional{
DiscreteLookupTable(size_t nFrontals, const gtsam::DiscreteKeys& keys, DiscreteLookupTable(size_t nFrontals, const gtsam::DiscreteKeys& keys,
const gtsam::DecisionTreeFactor::ADT& potentials); const gtsam::DecisionTreeFactor::ADT& potentials);
void print( void print(string s = "Discrete Lookup Table: ",
const std::string& s = "Discrete Lookup Table: ", const gtsam::KeyFormatter& keyFormatter =
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
size_t argmax(const gtsam::DiscreteValues& parentsValues) const; size_t argmax(const gtsam::DiscreteValues& parentsValues) const;
}; };
@ -263,6 +280,14 @@ class DiscreteLookupDAG {
}; };
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
std::pair<gtsam::DiscreteConditional*, gtsam::DecisionTreeFactor*>
EliminateDiscrete(const gtsam::DiscreteFactorGraph& factors,
const gtsam::Ordering& frontalKeys);
std::pair<gtsam::DiscreteConditional*, gtsam::DecisionTreeFactor*>
EliminateForMPE(const gtsam::DiscreteFactorGraph& factors,
const gtsam::Ordering& frontalKeys);
class DiscreteFactorGraph { class DiscreteFactorGraph {
DiscreteFactorGraph(); DiscreteFactorGraph();
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet); DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);
@ -277,6 +302,7 @@ class DiscreteFactorGraph {
void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec); void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec);
void add(const gtsam::DiscreteKeys& keys, string spec); void add(const gtsam::DiscreteKeys& keys, string spec);
void add(const std::vector<gtsam::DiscreteKey>& keys, string spec); void add(const std::vector<gtsam::DiscreteKey>& keys, string spec);
void add(const std::vector<gtsam::DiscreteKey>& keys, const std::vector<double>& spec);
bool empty() const; bool empty() const;
size_t size() const; size_t size() const;
@ -290,25 +316,46 @@ class DiscreteFactorGraph {
double operator()(const gtsam::DiscreteValues& values) const; double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const; gtsam::DiscreteValues optimize() const;
gtsam::DiscreteBayesNet sumProduct(); gtsam::DiscreteBayesNet sumProduct(
gtsam::DiscreteBayesNet sumProduct(gtsam::Ordering::OrderingType type); gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
gtsam::DiscreteBayesNet sumProduct(const gtsam::Ordering& ordering); gtsam::DiscreteBayesNet sumProduct(const gtsam::Ordering& ordering);
gtsam::DiscreteLookupDAG maxProduct(); gtsam::DiscreteLookupDAG maxProduct(
gtsam::DiscreteLookupDAG maxProduct(gtsam::Ordering::OrderingType type); gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering); gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering);
gtsam::DiscreteBayesNet* eliminateSequential(); gtsam::DiscreteBayesNet* eliminateSequential(
gtsam::DiscreteBayesNet* eliminateSequential(gtsam::Ordering::OrderingType type); gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
gtsam::DiscreteBayesNet* eliminateSequential(
gtsam::Ordering::OrderingType type,
const gtsam::DiscreteFactorGraph::Eliminate& function);
gtsam::DiscreteBayesNet* eliminateSequential(const gtsam::Ordering& ordering); gtsam::DiscreteBayesNet* eliminateSequential(const gtsam::Ordering& ordering);
gtsam::DiscreteBayesNet* eliminateSequential(
const gtsam::Ordering& ordering,
const gtsam::DiscreteFactorGraph::Eliminate& function);
pair<gtsam::DiscreteBayesNet*, gtsam::DiscreteFactorGraph*> pair<gtsam::DiscreteBayesNet*, gtsam::DiscreteFactorGraph*>
eliminatePartialSequential(const gtsam::Ordering& ordering); eliminatePartialSequential(const gtsam::Ordering& ordering);
pair<gtsam::DiscreteBayesNet*, gtsam::DiscreteFactorGraph*>
eliminatePartialSequential(
const gtsam::Ordering& ordering,
const gtsam::DiscreteFactorGraph::Eliminate& function);
gtsam::DiscreteBayesTree* eliminateMultifrontal(); gtsam::DiscreteBayesTree* eliminateMultifrontal(
gtsam::DiscreteBayesTree* eliminateMultifrontal(gtsam::Ordering::OrderingType type); gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
gtsam::DiscreteBayesTree* eliminateMultifrontal(const gtsam::Ordering& ordering); gtsam::DiscreteBayesTree* eliminateMultifrontal(
gtsam::Ordering::OrderingType type,
const gtsam::DiscreteFactorGraph::Eliminate& function);
gtsam::DiscreteBayesTree* eliminateMultifrontal(
const gtsam::Ordering& ordering);
gtsam::DiscreteBayesTree* eliminateMultifrontal(
const gtsam::Ordering& ordering,
const gtsam::DiscreteFactorGraph::Eliminate& function);
pair<gtsam::DiscreteBayesTree*, gtsam::DiscreteFactorGraph*> pair<gtsam::DiscreteBayesTree*, gtsam::DiscreteFactorGraph*>
eliminatePartialMultifrontal(const gtsam::Ordering& ordering); eliminatePartialMultifrontal(const gtsam::Ordering& ordering);
pair<gtsam::DiscreteBayesTree*, gtsam::DiscreteFactorGraph*>
eliminatePartialMultifrontal(
const gtsam::Ordering& ordering,
const gtsam::DiscreteFactorGraph::Eliminate& function);
string dot( string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
@ -328,4 +375,41 @@ class DiscreteFactorGraph {
std::map<gtsam::Key, std::vector<std::string>> names) const; std::map<gtsam::Key, std::vector<std::string>> names) const;
}; };
#include <gtsam/discrete/DiscreteEliminationTree.h>
class DiscreteEliminationTree {
DiscreteEliminationTree(const gtsam::DiscreteFactorGraph& factorGraph,
const gtsam::VariableIndex& structure,
const gtsam::Ordering& order);
DiscreteEliminationTree(const gtsam::DiscreteFactorGraph& factorGraph,
const gtsam::Ordering& order);
void print(
string name = "EliminationTree: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteEliminationTree& other,
double tol = 1e-9) const;
};
#include <gtsam/discrete/DiscreteJunctionTree.h>
class DiscreteCluster {
gtsam::Ordering orderedFrontalKeys;
gtsam::DiscreteFactorGraph factors;
const gtsam::DiscreteCluster& operator[](size_t i) const;
size_t nrChildren() const;
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};
class DiscreteJunctionTree {
DiscreteJunctionTree(const gtsam::DiscreteEliminationTree& eliminationTree);
void print(
string name = "JunctionTree: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
size_t nrRoots() const;
const gtsam::DiscreteCluster& operator[](size_t i) const;
};
} // namespace gtsam } // namespace gtsam

View File

@ -25,6 +25,7 @@
#include <gtsam/base/serializationTestHelpers.h> #include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTree-inl.h> #include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Symbol.h>
#include <iomanip> #include <iomanip>
@ -75,6 +76,19 @@ struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree) GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree)
/* ************************************************************************** */
// Test char labels and int range
/* ************************************************************************** */
// Create a decision stump on one variable 'a' with values 10 and 20.
TEST(DecisionTree, Constructor) {
DecisionTree<char, int> tree('a', 10, 20);
// Evaluate the tree on an assignment to the variable.
EXPECT_LONGS_EQUAL(10, tree({{'a', 0}}));
EXPECT_LONGS_EQUAL(20, tree({{'a', 1}}));
}
/* ************************************************************************** */ /* ************************************************************************** */
// Test string labels and int range // Test string labels and int range
/* ************************************************************************** */ /* ************************************************************************** */
@ -118,18 +132,47 @@ struct Ring {
static inline int mul(const int& a, const int& b) { return a * b; } static inline int mul(const int& a, const int& b) { return a * b; }
}; };
/* ************************************************************************** */
// Check that creating decision trees respects key order.
TEST(DecisionTree, ConstructorOrder) {
// Create labels
string A("A"), B("B");
const std::vector<int> ys1 = {1, 2, 3, 4};
DT tree1({{B, 2}, {A, 2}}, ys1); // faster version, as B is "higher" than A!
const std::vector<int> ys2 = {1, 3, 2, 4};
DT tree2({{A, 2}, {B, 2}}, ys2); // slower version !
// Both trees will be the same; the tree is ordered from high to low labels.
// Choice(B)
// 0 Choice(A)
// 0 0 Leaf 1
// 0 1 Leaf 2
// 1 Choice(A)
// 1 0 Leaf 3
// 1 1 Leaf 4
EXPECT(tree2.equals(tree1));
// Check the values are as expected by calling the () operator:
EXPECT_LONGS_EQUAL(1, tree1({{A, 0}, {B, 0}}));
EXPECT_LONGS_EQUAL(3, tree1({{A, 0}, {B, 1}}));
EXPECT_LONGS_EQUAL(2, tree1({{A, 1}, {B, 0}}));
EXPECT_LONGS_EQUAL(4, tree1({{A, 1}, {B, 1}}));
}
/* ************************************************************************** */ /* ************************************************************************** */
// test DT // test DT
TEST(DecisionTree, example) { TEST(DecisionTree, Example) {
// Create labels // Create labels
string A("A"), B("B"), C("C"); string A("A"), B("B"), C("C");
// create a value // Create assignments using brace initialization:
Assignment<string> x00, x01, x10, x11; Assignment<string> x00{{A, 0}, {B, 0}};
x00[A] = 0, x00[B] = 0; Assignment<string> x01{{A, 0}, {B, 1}};
x01[A] = 0, x01[B] = 1; Assignment<string> x10{{A, 1}, {B, 0}};
x10[A] = 1, x10[B] = 0; Assignment<string> x11{{A, 1}, {B, 1}};
x11[A] = 1, x11[B] = 1;
// empty // empty
DT empty; DT empty;
@ -241,8 +284,7 @@ TEST(DecisionTree, ConvertValuesOnly) {
StringBoolTree f2(f1, bool_of_int); StringBoolTree f2(f1, bool_of_int);
// Check a value // Check a value
Assignment<string> x00; Assignment<string> x00 {{A, 0}, {B, 0}};
x00["A"] = 0, x00["B"] = 0;
EXPECT(!f2(x00)); EXPECT(!f2(x00));
} }
@ -266,10 +308,11 @@ TEST(DecisionTree, ConvertBoth) {
// Check some values // Check some values
Assignment<Label> x00, x01, x10, x11; Assignment<Label> x00, x01, x10, x11;
x00[X] = 0, x00[Y] = 0; x00 = {{X, 0}, {Y, 0}};
x01[X] = 0, x01[Y] = 1; x01 = {{X, 0}, {Y, 1}};
x10[X] = 1, x10[Y] = 0; x10 = {{X, 1}, {Y, 0}};
x11[X] = 1, x11[Y] = 1; x11 = {{X, 1}, {Y, 1}};
EXPECT(!f2(x00)); EXPECT(!f2(x00));
EXPECT(!f2(x01)); EXPECT(!f2(x01));
EXPECT(f2(x10)); EXPECT(f2(x10));

View File

@ -27,6 +27,18 @@
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
/* ************************************************************************* */
TEST(DecisionTreeFactor, ConstructorsMatch) {
// Declare two keys
DiscreteKey X(0, 2), Y(1, 3);
// Create with vector and with string
const std::vector<double> table {2, 5, 3, 6, 4, 7};
DecisionTreeFactor f1({X, Y}, table);
DecisionTreeFactor f2({X, Y}, "2 5 3 6 4 7");
EXPECT(assert_equal(f1, f2));
}
/* ************************************************************************* */ /* ************************************************************************* */
TEST( DecisionTreeFactor, constructors) TEST( DecisionTreeFactor, constructors)
{ {
@ -41,21 +53,18 @@ TEST( DecisionTreeFactor, constructors)
EXPECT_LONGS_EQUAL(2,f2.size()); EXPECT_LONGS_EQUAL(2,f2.size());
EXPECT_LONGS_EQUAL(3,f3.size()); EXPECT_LONGS_EQUAL(3,f3.size());
DiscreteValues values; DiscreteValues x121{{0, 1}, {1, 2}, {2, 1}};
values[0] = 1; // x EXPECT_DOUBLES_EQUAL(8, f1(x121), 1e-9);
values[1] = 2; // y EXPECT_DOUBLES_EQUAL(7, f2(x121), 1e-9);
values[2] = 1; // z EXPECT_DOUBLES_EQUAL(75, f3(x121), 1e-9);
EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9);
EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9);
EXPECT_DOUBLES_EQUAL(75, f3(values), 1e-9);
// Assert that error = -log(value) // Assert that error = -log(value)
EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9); EXPECT_DOUBLES_EQUAL(-log(f1(x121)), f1.error(x121), 1e-9);
// Construct from DiscreteConditional // Construct from DiscreteConditional
DiscreteConditional conditional(X | Y = "1/1 2/3 1/4"); DiscreteConditional conditional(X | Y = "1/1 2/3 1/4");
DecisionTreeFactor f4(conditional); DecisionTreeFactor f4(conditional);
EXPECT_DOUBLES_EQUAL(0.8, f4(values), 1e-9); EXPECT_DOUBLES_EQUAL(0.8, f4(x121), 1e-9);
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -16,23 +16,24 @@
*/ */
#include <gtsam/base/Vector.h> #include <gtsam/base/Vector.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/discrete/DiscreteBayesNet.h> #include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteBayesTree.h> #include <gtsam/discrete/DiscreteBayesTree.h>
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/inference/BayesNet.h>
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <iostream> #include <iostream>
#include <vector> #include <vector>
using namespace std;
using namespace gtsam; using namespace gtsam;
static constexpr bool debug = false; static constexpr bool debug = false;
/* ************************************************************************* */ /* ************************************************************************* */
struct TestFixture { struct TestFixture {
vector<DiscreteKey> keys; DiscreteKeys keys;
std::vector<DiscreteValues> assignments;
DiscreteBayesNet bayesNet; DiscreteBayesNet bayesNet;
std::shared_ptr<DiscreteBayesTree> bayesTree; std::shared_ptr<DiscreteBayesTree> bayesTree;
@ -47,6 +48,9 @@ struct TestFixture {
keys.push_back(key_i); keys.push_back(key_i);
} }
// Enumerate all assignments.
assignments = DiscreteValues::CartesianProduct(keys);
// Create thin-tree Bayesnet. // Create thin-tree Bayesnet.
bayesNet.add(keys[14] % "1/3"); bayesNet.add(keys[14] % "1/3");
@ -74,9 +78,9 @@ struct TestFixture {
}; };
/* ************************************************************************* */ /* ************************************************************************* */
// Check that BN and BT give the same answer on all configurations
TEST(DiscreteBayesTree, ThinTree) { TEST(DiscreteBayesTree, ThinTree) {
const TestFixture self; TestFixture self;
const auto& keys = self.keys;
if (debug) { if (debug) {
GTSAM_PRINT(self.bayesNet); GTSAM_PRINT(self.bayesNet);
@ -95,47 +99,56 @@ TEST(DiscreteBayesTree, ThinTree) {
EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals())); EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals()));
} }
auto R = self.bayesTree->roots().front(); for (const auto& x : self.assignments) {
// Check whether BN and BT give the same answer on all configurations
auto allPosbValues = DiscreteValues::CartesianProduct(
keys[0] & keys[1] & keys[2] & keys[3] & keys[4] & keys[5] & keys[6] &
keys[7] & keys[8] & keys[9] & keys[10] & keys[11] & keys[12] & keys[13] &
keys[14]);
for (size_t i = 0; i < allPosbValues.size(); ++i) {
DiscreteValues x = allPosbValues[i];
double expected = self.bayesNet.evaluate(x); double expected = self.bayesNet.evaluate(x);
double actual = self.bayesTree->evaluate(x); double actual = self.bayesTree->evaluate(x);
DOUBLES_EQUAL(expected, actual, 1e-9); DOUBLES_EQUAL(expected, actual, 1e-9);
} }
}
// Calculate all some marginals for DiscreteValues==all1 /* ************************************************************************* */
Vector marginals = Vector::Zero(15); // Check calculation of separator marginals
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0, TEST(DiscreteBayesTree, SeparatorMarginals) {
joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0, TestFixture self;
joint_4_11 = 0, joint_11_13 = 0, joint_11_13_14 = 0,
joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0; // Calculate some marginals for DiscreteValues==all1
for (size_t i = 0; i < allPosbValues.size(); ++i) { double marginal_14 = 0, joint_8_12 = 0;
DiscreteValues x = allPosbValues[i]; for (auto& x : self.assignments) {
double px = self.bayesTree->evaluate(x); double px = self.bayesTree->evaluate(x);
for (size_t i = 0; i < 15; i++)
if (x[i]) marginals[i] += px;
if (x[12] && x[14]) {
joint_12_14 += px;
if (x[9]) joint_9_12_14 += px;
if (x[8]) joint_8_12_14 += px;
}
if (x[8] && x[12]) joint_8_12 += px; if (x[8] && x[12]) joint_8_12 += px;
if (x[2]) { if (x[14]) marginal_14 += px;
if (x[8]) joint82 += px; }
if (x[1]) joint12 += px; DiscreteValues all1 = self.assignments.back();
}
if (x[4]) { // check separator marginal P(S0)
if (x[2]) joint24 += px; auto clique = (*self.bayesTree)[0];
if (x[5]) joint45 += px; DiscreteFactorGraph separatorMarginal0 =
if (x[6]) joint46 += px; clique->separatorMarginal(EliminateDiscrete);
if (x[11]) joint_4_11 += px; DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
}
// check separator marginal P(S9), should be P(14)
clique = (*self.bayesTree)[9];
DiscreteFactorGraph separatorMarginal9 =
clique->separatorMarginal(EliminateDiscrete);
DOUBLES_EQUAL(marginal_14, separatorMarginal9(all1), 1e-9);
// check separator marginal of root, should be empty
clique = (*self.bayesTree)[11];
DiscreteFactorGraph separatorMarginal11 =
clique->separatorMarginal(EliminateDiscrete);
LONGS_EQUAL(0, separatorMarginal11.size());
}
/* ************************************************************************* */
// Check shortcuts in the tree
TEST(DiscreteBayesTree, Shortcuts) {
TestFixture self;
// Calculate some marginals for DiscreteValues==all1
double joint_11_13 = 0, joint_11_13_14 = 0, joint_11_12_13_14 = 0,
joint_9_11_12_13 = 0, joint_8_11_12_13 = 0;
for (auto& x : self.assignments) {
double px = self.bayesTree->evaluate(x);
if (x[11] && x[13]) { if (x[11] && x[13]) {
joint_11_13 += px; joint_11_13 += px;
if (x[8] && x[12]) joint_8_11_12_13 += px; if (x[8] && x[12]) joint_8_11_12_13 += px;
@ -148,32 +161,12 @@ TEST(DiscreteBayesTree, ThinTree) {
} }
} }
} }
DiscreteValues all1 = allPosbValues.back(); DiscreteValues all1 = self.assignments.back();
// check separator marginal P(S0) auto R = self.bayesTree->roots().front();
auto clique = (*self.bayesTree)[0];
DiscreteFactorGraph separatorMarginal0 =
clique->separatorMarginal(EliminateDiscrete);
DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
DOUBLES_EQUAL(joint_12_14, 0.1875, 1e-9);
DOUBLES_EQUAL(joint_8_12_14, 0.0375, 1e-9);
DOUBLES_EQUAL(joint_9_12_14, 0.15, 1e-9);
// check separator marginal P(S9), should be P(14)
clique = (*self.bayesTree)[9];
DiscreteFactorGraph separatorMarginal9 =
clique->separatorMarginal(EliminateDiscrete);
DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
// check separator marginal of root, should be empty
clique = (*self.bayesTree)[11];
DiscreteFactorGraph separatorMarginal11 =
clique->separatorMarginal(EliminateDiscrete);
LONGS_EQUAL(0, separatorMarginal11.size());
// check shortcut P(S9||R) to root // check shortcut P(S9||R) to root
clique = (*self.bayesTree)[9]; auto clique = (*self.bayesTree)[9];
DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete); DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete);
LONGS_EQUAL(1, shortcut.size()); LONGS_EQUAL(1, shortcut.size());
DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9); DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
@ -202,15 +195,67 @@ TEST(DiscreteBayesTree, ThinTree) {
shortcut.print("shortcut:"); shortcut.print("shortcut:");
} }
} }
}
/* ************************************************************************* */
// Check all marginals
TEST(DiscreteBayesTree, MarginalFactors) {
TestFixture self;
Vector marginals = Vector::Zero(15);
for (size_t i = 0; i < self.assignments.size(); ++i) {
DiscreteValues& x = self.assignments[i];
double px = self.bayesTree->evaluate(x);
for (size_t i = 0; i < 15; i++)
if (x[i]) marginals[i] += px;
}
// Check all marginals // Check all marginals
DiscreteFactor::shared_ptr marginalFactor; DiscreteValues all1 = self.assignments.back();
for (size_t i = 0; i < 15; i++) { for (size_t i = 0; i < 15; i++) {
marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete); auto marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete);
double actual = (*marginalFactor)(all1); double actual = (*marginalFactor)(all1);
DOUBLES_EQUAL(marginals[i], actual, 1e-9); DOUBLES_EQUAL(marginals[i], actual, 1e-9);
} }
}
/* ************************************************************************* */
// Check a number of joint marginals.
TEST(DiscreteBayesTree, Joints) {
TestFixture self;
// Calculate some marginals for DiscreteValues==all1
Vector marginals = Vector::Zero(15);
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint82 = 0,
joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0, joint_4_11 = 0;
for (size_t i = 0; i < self.assignments.size(); ++i) {
DiscreteValues& x = self.assignments[i];
double px = self.bayesTree->evaluate(x);
for (size_t i = 0; i < 15; i++)
if (x[i]) marginals[i] += px;
if (x[12] && x[14]) {
joint_12_14 += px;
if (x[9]) joint_9_12_14 += px;
if (x[8]) joint_8_12_14 += px;
}
if (x[2]) {
if (x[8]) joint82 += px;
if (x[1]) joint12 += px;
}
if (x[4]) {
if (x[2]) joint24 += px;
if (x[5]) joint45 += px;
if (x[6]) joint46 += px;
if (x[11]) joint_4_11 += px;
}
}
// regression tests:
DOUBLES_EQUAL(joint_12_14, 0.1875, 1e-9);
DOUBLES_EQUAL(joint_8_12_14, 0.0375, 1e-9);
DOUBLES_EQUAL(joint_9_12_14, 0.15, 1e-9);
DiscreteValues all1 = self.assignments.back();
DiscreteBayesNet::shared_ptr actualJoint;
// Check joint P(8, 2)
@ -240,8 +285,8 @@ TEST(DiscreteBayesTree, ThinTree) {
/* ************************************************************************* */
TEST(DiscreteBayesTree, Dot) {
TestFixture self;
std::string actual = self.bayesTree->dot();
EXPECT(actual ==
"digraph G{\n"
"0[label=\"13, 11, 6, 7\"];\n"
@ -268,6 +313,62 @@ TEST(DiscreteBayesTree, Dot) {
"}");
}
/* ************************************************************************* */
// Check that we can have a multi-frontal lookup table
TEST(DiscreteBayesTree, Lookup) {
using gtsam::symbol_shorthand::A;
using gtsam::symbol_shorthand::X;
// Make a small planning-like graph: 3 states, 2 actions
DiscreteFactorGraph graph;
const DiscreteKey x1{X(1), 3}, x2{X(2), 3}, x3{X(3), 3};
const DiscreteKey a1{A(1), 2}, a2{A(2), 2};
// Constraint on start and goal
graph.add(DiscreteKeys{x1}, std::vector<double>{1, 0, 0});
graph.add(DiscreteKeys{x3}, std::vector<double>{0, 0, 1});
// Should I stay or should I go?
// "Reward" (exp(-cost)) for an action is 10, and rewards multiply:
const double r = 10;
std::vector<double> table{
r, 0, 0, 0, r, 0, // x1 = 0
0, r, 0, 0, 0, r, // x1 = 1
0, 0, r, 0, 0, r // x1 = 2
};
graph.add(DiscreteKeys{x1, a1, x2}, table);
graph.add(DiscreteKeys{x2, a2, x3}, table);
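// Note on the table layout (an assumption about the convention, not stated in
// this test): with keys ordered {x1, a1, x2}, the last key varies fastest, so
// each commented row above enumerates the six (a1, x2) combinations for one
// value of x1. Staying (a=0) keeps the state, going (a=1) advances it by one,
// and from state 2 either action leaves the robot in state 2.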
// eliminate for MPE (maximum probable explanation).
Ordering ordering{A(2), X(3), X(1), A(1), X(2)};
auto lookup = graph.eliminateMultifrontal(ordering, EliminateForMPE);
// Check that the lookup table is correct
EXPECT_LONGS_EQUAL(2, lookup->size());
auto lookup_x1_a1_x2 = (*lookup)[X(1)]->conditional();
EXPECT_LONGS_EQUAL(3, lookup_x1_a1_x2->frontals().size());
// check that sum is 1.0 (not 100, as we now normalize)
DiscreteValues empty;
EXPECT_DOUBLES_EQUAL(1.0, (*lookup_x1_a1_x2->sum(3))(empty), 1e-9);
// And that only non-zero reward is for x1 a1 x2 == 0 1 1
EXPECT_DOUBLES_EQUAL(1.0, (*lookup_x1_a1_x2)({{X(1),0},{A(1),1},{X(2),1}}), 1e-9);
auto lookup_a2_x3 = (*lookup)[X(3)]->conditional();
// check that the sum depends on x2 and is non-zero only for x2 \in {1,2}
auto sum_x2 = lookup_a2_x3->sum(2);
EXPECT_DOUBLES_EQUAL(0, (*sum_x2)({{X(2),0}}), 1e-9);
EXPECT_DOUBLES_EQUAL(1.0, (*sum_x2)({{X(2),1}}), 1e-9);
EXPECT_DOUBLES_EQUAL(2.0, (*sum_x2)({{X(2),2}}), 1e-9);
EXPECT_LONGS_EQUAL(2, lookup_a2_x3->frontals().size());
// And that the non-zero rewards are for
// x2 a2 x3 == 1 1 2
EXPECT_DOUBLES_EQUAL(1.0, (*lookup_a2_x3)({{X(2),1},{A(2),1},{X(3),2}}), 1e-9);
// x2 a2 x3 == 2 0 2
EXPECT_DOUBLES_EQUAL(1.0, (*lookup_a2_x3)({{X(2),2},{A(2),0},{X(3),2}}), 1e-9);
// x2 a2 x3 == 2 1 2
EXPECT_DOUBLES_EQUAL(1.0, (*lookup_a2_x3)({{X(2),2},{A(2),1},{X(3),2}}), 1e-9);
}
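/* ************************************************************************* */
// Hedged companion sketch: assuming DiscreteFactorGraph::optimize() performs
// max-product and returns the most probable assignment, the same MPE can be
// read off directly from a graph built exactly as in the Lookup test above.
TEST(DiscreteBayesTree, LookupMPESketch) {
using gtsam::symbol_shorthand::A;
using gtsam::symbol_shorthand::X;
DiscreteFactorGraph graph;
const DiscreteKey x1{X(1), 3}, x2{X(2), 3}, x3{X(3), 3};
const DiscreteKey a1{A(1), 2}, a2{A(2), 2};
graph.add(DiscreteKeys{x1}, std::vector<double>{1, 0, 0});
graph.add(DiscreteKeys{x3}, std::vector<double>{0, 0, 1});
const double r = 10;
std::vector<double> table{r, 0, 0, 0, r, 0,  //
                          0, r, 0, 0, 0, r,  //
                          0, 0, r, 0, 0, r};
graph.add(DiscreteKeys{x1, a1, x2}, table);
graph.add(DiscreteKeys{x2, a2, x3}, table);
// Only x1=0, a1=1, x2=1, a2=1, x3=2 satisfies both endpoint constraints with
// nonzero reward, so it should be the most probable explanation.
DiscreteValues mpe = graph.optimize();
EXPECT_LONGS_EQUAL(1, mpe.at(A(1)));
EXPECT_LONGS_EQUAL(1, mpe.at(A(2)));
EXPECT_LONGS_EQUAL(2, mpe.at(X(3)));
}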
/* ************************************************************************* */
int main() {
TestResult tr;

@ -597,6 +597,25 @@ TEST(Rot3, quaternion) {
EXPECT(assert_equal(expected2, actual2));
}
/* ************************************************************************* */
TEST(Rot3, ConvertQuaternion) {
Eigen::Quaterniond eigenQuaternion;
eigenQuaternion.w() = 1.0;
eigenQuaternion.x() = 2.0;
eigenQuaternion.y() = 3.0;
eigenQuaternion.z() = 4.0;
EXPECT_DOUBLES_EQUAL(1, eigenQuaternion.w(), 1e-9);
EXPECT_DOUBLES_EQUAL(2, eigenQuaternion.x(), 1e-9);
EXPECT_DOUBLES_EQUAL(3, eigenQuaternion.y(), 1e-9);
EXPECT_DOUBLES_EQUAL(4, eigenQuaternion.z(), 1e-9);
Rot3 R(eigenQuaternion);
EXPECT_DOUBLES_EQUAL(1, R.toQuaternion().w(), 1e-9);
EXPECT_DOUBLES_EQUAL(2, R.toQuaternion().x(), 1e-9);
EXPECT_DOUBLES_EQUAL(3, R.toQuaternion().y(), 1e-9);
EXPECT_DOUBLES_EQUAL(4, R.toQuaternion().z(), 1e-9);
}
/* ************************************************************************* */
Matrix Cayley(const Matrix& A) {
Matrix::Index n = A.cols();

@ -37,24 +37,6 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
return Base::equals(bn, tol);
}
/* ************************************************************************* */
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
AlgebraicDecisionTree<Key> discreteProbs;
// The canonical decision tree factor which will get
// the discrete conditionals added to it.
DecisionTreeFactor discreteProbsFactor;
for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
// Convert to a DecisionTreeFactor and add it to the main factor.
DecisionTreeFactor f(*conditional->asDiscrete());
discreteProbsFactor = discreteProbsFactor * f;
}
}
return std::make_shared<DecisionTreeFactor>(discreteProbsFactor);
}
/* ************************************************************************* */
/**
* @brief Helper function to get the pruner functional.
@ -144,53 +126,52 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
}
/* ************************************************************************* */
DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
size_t maxNrLeaves) {
// Get the joint distribution of only the discrete keys
gttic_(HybridBayesNet_PruneDiscreteConditionals);
// The joint discrete probability.
DiscreteConditional discreteProbs;
std::vector<size_t> discrete_factor_idxs;
// Record frontal keys so we can maintain ordering
Ordering discrete_frontals;
// Loop with index since we need it later.
for (size_t i = 0; i < this->size(); i++) {
auto conditional = this->at(i);
if (conditional->isDiscrete()) {
discreteProbs = discreteProbs * (*conditional->asDiscrete());
Ordering conditional_keys(conditional->frontals());
discrete_frontals += conditional_keys;
discrete_factor_idxs.push_back(i);
// Apply prunerFunc to the underlying AlgebraicDecisionTree
DecisionTreeFactor::ADT prunedDiscreteTree =
discreteTree->apply(prunerFunc(prunedDiscreteProbs, *conditional));
gttic_(HybridBayesNet_MakeConditional);
// Create the new (hybrid) conditional
KeyVector frontals(discrete->frontals().begin(),
discrete->frontals().end());
auto prunedDiscrete = std::make_shared<DiscreteLookupTable>(
frontals.size(), conditional->discreteKeys(), prunedDiscreteTree);
conditional = std::make_shared<HybridConditional>(prunedDiscrete);
gttoc_(HybridBayesNet_MakeConditional);
// Add it back to the BayesNet
this->at(i) = conditional;
}
}
const DecisionTreeFactor prunedDiscreteProbs =
discreteProbs.prune(maxNrLeaves);
gttoc_(HybridBayesNet_PruneDiscreteConditionals);
// Eliminate joint probability back into conditionals
gttic_(HybridBayesNet_UpdateDiscreteConditionals);
DiscreteFactorGraph dfg{prunedDiscreteProbs};
DiscreteBayesNet::shared_ptr dbn = dfg.eliminateSequential(discrete_frontals);
// Assign pruned discrete conditionals back at the correct indices.
for (size_t i = 0; i < discrete_factor_idxs.size(); i++) {
size_t idx = discrete_factor_idxs.at(i);
this->at(idx) = std::make_shared<HybridConditional>(dbn->at(i));
}
gttoc_(HybridBayesNet_UpdateDiscreteConditionals);
return prunedDiscreteProbs;
}
/* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
DecisionTreeFactor prunedDiscreteProbs =
this->pruneDiscreteConditionals(maxNrLeaves);
DecisionTreeFactor::shared_ptr discreteConditionals =
this->discreteConditionals();
const DecisionTreeFactor prunedDiscreteProbs =
discreteConditionals->prune(maxNrLeaves);
gttoc_(HybridBayesNet_PruneDiscreteConditionals);
/* To prune, we visitWith every leaf in the GaussianMixture.
this->updateDiscreteConditionals(prunedDiscreteProbs);
gttoc_(HybridBayesNet_UpdateDiscreteConditionals);
/* To Prune, we visitWith every leaf in the GaussianMixture.
* For each leaf, using the assignment we can check the discrete decision tree
* for 0.0 probability, then just set the leaf to a nullptr.
*
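* A hedged sketch of that idea (the callback signature is an assumption for
* illustration only, not the actual GaussianMixture traversal API):
*
*   auto pruner = [&prunedDiscreteProbs](
*       const DiscreteValues &choices,
*       const GaussianConditional::shared_ptr &leaf)
*       -> GaussianConditional::shared_ptr {
*     // Null out leaves whose discrete assignment has zero pruned probability.
*     return prunedDiscreteProbs(choices) == 0.0 ? nullptr : leaf;
*   };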

@ -136,13 +136,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
*/
VectorValues optimize(const DiscreteValues &assignment) const;
/**
* @brief Get all the discrete conditionals as a decision tree factor.
*
* @return DecisionTreeFactor::shared_ptr
*/
DecisionTreeFactor::shared_ptr discreteConditionals() const;
/**
* @brief Sample from an incomplete BayesNet, given missing variables.
*
@ -222,12 +215,11 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
private:
/**
* @brief Prune all the discrete conditionals.
*
* @param maxNrLeaves
*/
DecisionTreeFactor pruneDiscreteConditionals(size_t maxNrLeaves);
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */

@ -20,6 +20,7 @@
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/inference/Factor.h>
#include <gtsam/linear/GaussianFactorGraph.h>
#include <gtsam/nonlinear/Values.h>

@ -17,7 +17,6 @@
* @date January, 2023
*/
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/hybrid/HybridFactorGraph.h>
namespace gtsam {
@ -26,7 +25,7 @@ namespace gtsam {
std::set<DiscreteKey> HybridFactorGraph::discreteKeys() const {
std::set<DiscreteKey> keys;
for (auto& factor : factors_) {
if (auto p = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
for (const DiscreteKey& key : p->discreteKeys()) {
keys.insert(key);
}
@ -67,6 +66,8 @@ const KeySet HybridFactorGraph::continuousKeySet() const {
for (const Key& key : p->continuousKeys()) {
keys.insert(key);
}
} else if (auto p = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
keys.insert(p->keys().begin(), p->keys().end());
}
}
return keys;

@ -48,8 +48,6 @@
#include <utility>
#include <vector>
// #define HYBRID_TIMING
namespace gtsam {
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
@ -120,7 +118,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
// TODO(dellaert): in C++20, we can use std::visit.
continue;
}
} else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
// Don't do anything for discrete-only factors
// since we want to eliminate continuous values only.
continue;
@ -167,8 +165,8 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
DiscreteFactorGraph dfg;
for (auto &f : factors) {
if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) {
dfg.push_back(df);
} else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
// Ignore orphaned clique.
// TODO(dellaert): is this correct? If so explain here.
@ -231,17 +229,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
return {nullptr, nullptr};
}
#ifdef HYBRID_TIMING
gttic_(hybrid_eliminate);
#endif
gttic_(hybrid_continuous_eliminate);
auto result = EliminatePreferCholesky(graph, frontalKeys);
gttoc_(hybrid_continuous_eliminate);
#ifdef HYBRID_TIMING
gttoc_(hybrid_eliminate);
#endif
return result;
};
@ -359,64 +347,68 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
// When the number of assignments is large we may encounter stack overflows.
// However this is also the case with iSAM2, so no pressure :)
// Check the factors:
// 1. if all factors are discrete, then we can do discrete elimination:
// 2. if all factors are continuous, then we can do continuous elimination:
// 3. if not, we do hybrid elimination:
bool only_discrete = true, only_continuous = true;
for (auto &&factor : factors) {
if (auto hybrid_factor = std::dynamic_pointer_cast<HybridFactor>(factor)) {
if (hybrid_factor->isDiscrete()) {
only_continuous = false;
} else if (hybrid_factor->isContinuous()) {
only_discrete = false;
} else if (hybrid_factor->isHybrid()) {
only_continuous = false;
only_discrete = false;
}
} else if (auto cont_factor =
std::dynamic_pointer_cast<GaussianFactor>(factor)) {
only_discrete = false;
} else if (auto discrete_factor =
std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
only_continuous = false;
}
}
// Fill in discrete discrete separator keys and continuous separator keys.
std::set<DiscreteKey> discreteSeparatorSet;
KeyVector continuousSeparator;
for (auto &k : separatorKeys) {
if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
discreteSeparatorSet.insert(mapFromKeyToDiscreteKey.at(k));
} else {
continuousSeparator.push_back(k);
}
}
// Check if we have any continuous keys:
const bool discrete_only =
continuousFrontals.empty() && continuousSeparator.empty();
// NOTE: We should really defer the product here because of pruning
if (only_discrete) {
// Case 1: we are only dealing with discrete
return discreteElimination(factors, frontalKeys);
} else if (only_continuous) {
// Case 2: we are only dealing with continuous
return continuousElimination(factors, frontalKeys);
} else {
// Case 3: We are now in the hybrid land!
KeySet frontalKeysSet(frontalKeys.begin(), frontalKeys.end());
// Find all the keys in the set of continuous keys
// which are not in the frontal keys. This is our continuous separator.
KeyVector continuousSeparator;
auto continuousKeySet = factors.continuousKeySet();
std::set_difference(
continuousKeySet.begin(), continuousKeySet.end(),
frontalKeysSet.begin(), frontalKeysSet.end(),
std::inserter(continuousSeparator, continuousSeparator.begin()));
// Similarly for the discrete separator.
KeySet discreteSeparatorSet;
std::set<DiscreteKey> discreteSeparator;
auto discreteKeySet = factors.discreteKeySet();
std::set_difference(
discreteKeySet.begin(), discreteKeySet.end(), frontalKeysSet.begin(),
frontalKeysSet.end(),
std::inserter(discreteSeparatorSet, discreteSeparatorSet.begin()));
// Convert from set of keys to set of DiscreteKeys
auto discreteKeyMap = factors.discreteKeyMap();
for (auto key : discreteSeparatorSet) {
discreteSeparator.insert(discreteKeyMap.at(key));
}
return hybridElimination(factors, frontalKeys, continuousSeparator,
discreteSeparator);
}
}
@ -440,7 +432,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
// Add the gaussian factor error to every leaf of the error tree.
error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; });
} else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
// If factor at `idx` is discrete-only, we skip.
continue;
} else {

@ -40,6 +40,7 @@ class HybridEliminationTree;
class HybridBayesTree;
class HybridJunctionTree;
class DecisionTreeFactor;
class TableFactor;
class JacobianFactor;
class HybridValues;

@ -66,7 +66,7 @@ struct HybridConstructorTraversalData {
for (auto& k : hf->discreteKeys()) {
data.discreteKeys.insert(k.first);
}
} else if (auto hf = std::dynamic_pointer_cast<DiscreteFactor>(f)) {
for (auto& k : hf->discreteKeys()) {
data.discreteKeys.insert(k.first);
}
@ -161,7 +161,7 @@ HybridJunctionTree::HybridJunctionTree(
Data rootData(0);
rootData.junctionTreeNode =
std::make_shared<typename Base::Node>(); // Make a dummy node to gather
// the junction tree roots
treeTraversal::DepthFirstForest(eliminationTree, rootData,
Data::ConstructorTraversalVisitorPre,
Data::ConstructorTraversalVisitorPost);

@ -17,6 +17,7 @@
*/
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
@ -67,7 +68,7 @@ HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(
} else if (auto nlf = dynamic_pointer_cast<NonlinearFactor>(f)) {
const GaussianFactor::shared_ptr& gf = nlf->linearize(continuousValues);
linearFG->push_back(gf);
} else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
// If discrete-only: doesn't need linearization.
linearFG->push_back(f);
} else if (auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {

@ -72,7 +72,8 @@ void HybridSmoother::update(HybridGaussianFactorGraph graph,
addConditionals(graph, hybridBayesNet_, ordering);
// Eliminate.
HybridBayesNet::shared_ptr bayesNetFragment =
graph.eliminateSequential(ordering);
/// Prune
if (maxNrLeaves) {
@ -96,7 +97,8 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
HybridGaussianFactorGraph graph(originalGraph);
HybridBayesNet hybridBayesNet(originalHybridBayesNet);
// If hybridBayesNet is not empty,
// it means we have conditionals to add to the factor graph.
if (!hybridBayesNet.empty()) {
// We add all relevant conditional mixtures on the last continuous variable
// in the previous `hybridBayesNet` to the graph

@ -35,14 +35,11 @@ class HybridValues {
};
#include <gtsam/hybrid/HybridFactor.h>
virtual class HybridFactor : gtsam::Factor {
void print(string s = "HybridFactor\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::HybridFactor& other, double tol = 1e-9) const;
bool empty() const;
size_t size() const;
gtsam::KeyVector keys() const;
// Standard interface:
double error(const gtsam::HybridValues &values) const;

@ -202,31 +202,16 @@ struct Switching {
* @brief Add "mode chain" to HybridNonlinearFactorGraph from M(0) to M(K-2).
* E.g. if K=4, we want M0, M1 and M2.
*
* @param fg The factor graph to which the mode chain is added.
*/
template <typename FACTORGRAPH>
void addModeChain(FACTORGRAPH *fg,
std::string discrete_transition_prob = "1/2 3/2") {
fg->template emplace_shared<DiscreteDistribution>(modes[0], "1/1");
for (size_t k = 0; k < K - 2; k++) {
auto parents = {modes[k]};
fg->template emplace_shared<DiscreteConditional>(
modes[k + 1], parents, discrete_transition_prob);
}
}
/**
* @brief Add "mode chain" to HybridGaussianFactorGraph from M(0) to M(K-2).
* E.g. if K=4, we want M0, M1 and M2.
*
* @param fg The gaussian factor graph to which the mode chain is added.
*/
void addModeChain(HybridGaussianFactorGraph *fg,
std::string discrete_transition_prob = "1/2 3/2") {
fg->emplace_shared<DiscreteDistribution>(modes[0], "1/1");
for (size_t k = 0; k < K - 2; k++) {
auto parents = {modes[k]};
fg->emplace_shared<DiscreteConditional>(modes[k + 1], parents,
discrete_transition_prob);
}
}
};
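/* ************************************************************************* */
// Hedged usage sketch: with the templated addModeChain above, one helper now
// serves both graph types. The Switching constructor defaults assumed here
// (measurements and transition table) follow the pattern used in the tests.
inline void exampleAddModeChain() {
  Switching switching(4, 1.0, 0.1);  // K = 4, remaining arguments assumed to default
  HybridNonlinearFactorGraph nonlinear;
  switching.addModeChain(&nonlinear);  // adds P(M0), P(M1|M0), P(M2|M1)
  HybridGaussianFactorGraph gaussian;
  switching.addModeChain(&gaussian, "1/2 3/2");  // same call, different graph type
}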

@ -108,7 +108,7 @@ TEST(GaussianMixtureFactor, Printing) {
std::string expected =
R"(Hybrid [x1 x2; 1]{
Choice(1)
0 Leaf [1] :
A[x1] = [
0;
0
@ -120,7 +120,7 @@ TEST(GaussianMixtureFactor, Printing) {
b = [ 0 0 ]
No noise model
1 Leaf [1] :
A[x1] = [
0;
0

@ -231,7 +231,7 @@ TEST(HybridBayesNet, Pruning) {
auto prunedTree = prunedBayesNet.evaluate(delta.continuous());
// Regression test on pruned logProbability tree
std::vector<double> pruned_leaves = {0.0, 32.713418, 0.0, 31.735823};
AlgebraicDecisionTree<Key> expected_pruned(discrete_keys, pruned_leaves);
EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6));
@ -248,8 +248,10 @@ TEST(HybridBayesNet, Pruning) {
logProbability +=
posterior->at(4)->asDiscrete()->logProbability(hybridValues);
// Regression
double density = exp(logProbability);
EXPECT_DOUBLES_EQUAL(density,
1.6078460548731697 * actualTree(discrete_values), 1e-6);
EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9);
EXPECT_DOUBLES_EQUAL(logProbability, posterior->logProbability(hybridValues),
1e-9);
@ -283,20 +285,30 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
EXPECT_LONGS_EQUAL(7, posterior->size());
size_t maxNrLeaves = 3;
DiscreteConditional discreteConditionals;
for (auto&& conditional : *posterior) {
if (conditional->isDiscrete()) {
discreteConditionals =
discreteConditionals * (*conditional->asDiscrete());
}
}
const DecisionTreeFactor::shared_ptr prunedDecisionTree =
std::make_shared<DecisionTreeFactor>(
discreteConditionals.prune(maxNrLeaves));
EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,
prunedDecisionTree->nrLeaves());
// regression
DiscreteKeys dkeys{{M(0), 2}, {M(1), 2}, {M(2), 2}};
DecisionTreeFactor::ADT potentials(
dkeys, std::vector<double>{0, 0, 0, 0.505145423, 0, 1, 0, 0.494854577});
DiscreteConditional expected_discrete_conditionals(1, dkeys, potentials);
// Prune!
posterior->prune(maxNrLeaves);
// Functor to verify values against the expected_discrete_conditionals
auto checker = [&](const Assignment<Key>& assignment,
double probability) -> double {
// typecast so we can use this to get probability value
@ -304,7 +316,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
if (prunedDecisionTree->operator()(choices) == 0) {
EXPECT_DOUBLES_EQUAL(0.0, probability, 1e-9);
} else {
EXPECT_DOUBLES_EQUAL(expected_discrete_conditionals(choices), probability,
1e-9);
}
return 0.0;

@ -146,7 +146,7 @@ TEST(HybridBayesTree, Optimize) {
DiscreteFactorGraph dfg;
for (auto&& f : *remainingFactorGraph) {
auto discreteFactor = dynamic_pointer_cast<DiscreteFactor>(f);
assert(discreteFactor);
dfg.push_back(discreteFactor);
}

@ -140,6 +140,61 @@ TEST(HybridEstimation, IncrementalSmoother) {
EXPECT(assert_equal(expected_continuous, result));
}
/****************************************************************************/
// Test approximate inference with an additional pruning step.
TEST(HybridEstimation, ISAM) {
size_t K = 15;
std::vector<double> measurements = {0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6,
7, 8, 9, 9, 9, 10, 11, 11, 11, 11};
// Ground truth discrete seq
std::vector<size_t> discrete_seq = {1, 1, 0, 0, 0, 1, 1, 1, 1, 0,
1, 1, 1, 0, 0, 1, 1, 0, 0, 0};
// Switching example of robot moving in 1D
// with given measurements and equal mode priors.
Switching switching(K, 1.0, 0.1, measurements, "1/1 1/1");
HybridNonlinearISAM isam;
HybridNonlinearFactorGraph graph;
Values initial;
// gttic_(Estimation);
// Add the X(0) prior
graph.push_back(switching.nonlinearFactorGraph.at(0));
initial.insert(X(0), switching.linearizationPoint.at<double>(X(0)));
HybridGaussianFactorGraph linearized;
for (size_t k = 1; k < K; k++) {
// Motion Model
graph.push_back(switching.nonlinearFactorGraph.at(k));
// Measurement
graph.push_back(switching.nonlinearFactorGraph.at(k + K - 1));
initial.insert(X(k), switching.linearizationPoint.at<double>(X(k)));
isam.update(graph, initial, 3);
// isam.bayesTree().print("\n\n");
graph.resize(0);
initial.clear();
}
Values result = isam.estimate();
DiscreteValues assignment = isam.assignment();
DiscreteValues expected_discrete;
for (size_t k = 0; k < K - 1; k++) {
expected_discrete[M(k)] = discrete_seq[k];
}
EXPECT(assert_equal(expected_discrete, assignment));
Values expected_continuous;
for (size_t k = 0; k < K; k++) {
expected_continuous.insert(X(k), measurements[k]);
}
EXPECT(assert_equal(expected_continuous, result));
}
/**
* @brief A function to get a specific 1D robot motion problem as a linearized
* factor graph. This is the problem P(X|Z, M), i.e. estimating the continuous

@ -18,7 +18,9 @@
#include <gtsam/base/TestableAssertions.h>
#include <gtsam/base/utilities.h>
#include <gtsam/hybrid/HybridFactorGraph.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/JacobianFactor.h>
#include <gtsam/nonlinear/PriorFactor.h>
using namespace std;
@ -37,6 +39,32 @@ TEST(HybridFactorGraph, Constructor) {
HybridFactorGraph fg;
}
/* ************************************************************************* */
// Test if methods to get keys work as expected.
TEST(HybridFactorGraph, Keys) {
HybridGaussianFactorGraph hfg;
// Add prior on x0
hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1));
// Add factor between x0 and x1
hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1));
// Add a gaussian mixture factor ϕ(x1, c1)
DiscreteKey m1(M(1), 2);
DecisionTree<Key, GaussianFactor::shared_ptr> dt(
M(1), std::make_shared<JacobianFactor>(X(1), I_3x3, Z_3x1),
std::make_shared<JacobianFactor>(X(1), I_3x3, Vector3::Ones()));
hfg.add(GaussianMixtureFactor({X(1)}, {m1}, dt));
KeySet expected_continuous{X(0), X(1)};
EXPECT(
assert_container_equality(expected_continuous, hfg.continuousKeySet()));
KeySet expected_discrete{M(1)};
EXPECT(assert_container_equality(expected_discrete, hfg.discreteKeySet()));
}
/* ************************************************************************* */
int main() {
TestResult tr;

@ -902,7 +902,7 @@ TEST(HybridGaussianFactorGraph, EliminateSwitchingNetwork) {
// Test resulting posterior Bayes net has correct size:
EXPECT_LONGS_EQUAL(8, posterior->size());
// Ratio test
EXPECT(ratioTest(bn, measurements, *posterior));
}

@ -492,7 +492,7 @@ factor 0:
factor 1:
Hybrid [x0 x1; m0]{
Choice(m0)
0 Leaf [1] :
A[x0] = [
-1
]
@ -502,7 +502,7 @@ Hybrid [x0 x1; m0]{
b = [ -1 ]
No noise model
1 Leaf [1] :
A[x0] = [
-1
]
@ -516,7 +516,7 @@ Hybrid [x0 x1; m0]{
factor 2:
Hybrid [x1 x2; m1]{
Choice(m1)
0 Leaf [1] :
A[x1] = [
-1
]
@ -526,7 +526,7 @@ Hybrid [x1 x2; m1]{
b = [ -1 ]
No noise model
1 Leaf [1] :
A[x1] = [
-1
]
@ -550,16 +550,16 @@ factor 4:
b = [ -10 ]
No noise model
factor 5: P( m0 ):
Leaf [2] 0.5
factor 6: P( m1 | m0 ):
Choice(m1)
0 Choice(m0)
0 0 Leaf [1] 0.33333333
0 1 Leaf [1] 0.6
1 Choice(m0)
1 0 Leaf [1] 0.66666667
1 1 Leaf [1] 0.4
)";
EXPECT(assert_print_equal(expected_hybridFactorGraph, linearizedFactorGraph));
@ -570,13 +570,13 @@ size: 3
conditional 0: Hybrid P( x0 | x1 m0)
Discrete Keys = (m0, 2),
Choice(m0)
0 Leaf [1] p(x0 | x1)
R = [ 10.0499 ]
S[x1] = [ -0.0995037 ]
d = [ -9.85087 ]
No noise model
1 Leaf [1] p(x0 | x1)
R = [ 10.0499 ]
S[x1] = [ -0.0995037 ]
d = [ -9.95037 ]
@ -586,26 +586,26 @@ conditional 1: Hybrid P( x1 | x2 m0 m1)
Discrete Keys = (m0, 2), (m1, 2),
Choice(m1)
0 Choice(m0)
0 0 Leaf [1] p(x1 | x2)
R = [ 10.099 ]
S[x2] = [ -0.0990196 ]
d = [ -9.99901 ]
No noise model
0 1 Leaf [1] p(x1 | x2)
R = [ 10.099 ]
S[x2] = [ -0.0990196 ]
d = [ -9.90098 ]
No noise model
1 Choice(m0)
1 0 Leaf [1] p(x1 | x2)
R = [ 10.099 ]
S[x2] = [ -0.0990196 ]
d = [ -10.098 ]
No noise model
1 1 Leaf [1] p(x1 | x2)
R = [ 10.099 ]
S[x2] = [ -0.0990196 ]
d = [ -10 ]
@ -615,14 +615,14 @@ conditional 2: Hybrid P( x2 | m0 m1)
Discrete Keys = (m0, 2), (m1, 2),
Choice(m1)
0 Choice(m0)
0 0 Leaf [1] p(x2)
R = [ 10.0494 ]
d = [ -10.1489 ]
mean: 1 elements
x2: -1.0099
No noise model
0 1 Leaf [1] p(x2)
R = [ 10.0494 ]
d = [ -10.1479 ]
mean: 1 elements
@ -630,14 +630,14 @@ conditional 2: Hybrid P( x2 | m0 m1)
No noise model
1 Choice(m0)
1 0 Leaf [1] p(x2)
R = [ 10.0494 ]
d = [ -10.0504 ]
mean: 1 elements
x2: -1.0001
No noise model
1 1 Leaf [1] p(x2)
R = [ 10.0494 ]
d = [ -10.0494 ]
mean: 1 elements

@ -63,8 +63,8 @@ TEST(MixtureFactor, Printing) {
R"(Hybrid [x1 x2; 1]
MixtureFactor
Choice(1)
0 Leaf [1] Nonlinear factor on 2 keys
1 Leaf [1] Nonlinear factor on 2 keys
)";
EXPECT(assert_print_equal(expected, mixtureFactor));
}

@ -140,9 +140,15 @@ namespace gtsam {
/** Access the conditional */
const sharedConditional& conditional() const { return conditional_; }
/// Return true if this clique is the root of a Bayes tree.
inline bool isRoot() const { return parent_.expired(); }
/// Return the number of children.
size_t nrChildren() const { return children.size(); }
/// Return the child at index i.
const derived_ptr operator[](size_t i) const { return children.at(i); }
/** The size of subtree rooted at this clique, i.e., nr of Cliques */
size_t treeSize() const;
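// Hedged usage sketch for the child accessors added above (assumes a
// DiscreteBayesTree, as in the discrete tests elsewhere in this changeset):
//
//   void printRootChildren(const DiscreteBayesTree& bayesTree) {
//     auto root = bayesTree.roots().front();
//     for (size_t i = 0; i < root->nrChildren(); i++) {
//       auto child = (*root)[i];  // operator[](size_t) returns the i-th child
//       child->conditional()->print("child conditional: ");
//     }
//   }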

@ -49,7 +49,7 @@ class ClusterTree {
virtual ~Cluster() {}
const Cluster& operator[](size_t i) const {
return *(children.at(i));
}
/// Construct from factors associated with a single key
@ -161,7 +161,7 @@ class ClusterTree {
}
const Cluster& operator[](size_t i) const {
return *(roots_.at(i));
}
/// @}

@ -74,8 +74,9 @@ namespace gtsam {
EliminationTreeType etree(asDerived(), (*variableIndex).get(), ordering);
const auto [bayesNet, factorGraph] = etree.eliminate(function);
// If any factors are remaining, the ordering was incomplete
if(!factorGraph->empty()) {
throw InconsistentEliminationRequested(factorGraph->keys());
}
// Return the Bayes net
return bayesNet;
}
@ -136,8 +137,9 @@ namespace gtsam {
JunctionTreeType junctionTree(etree);
const auto [bayesTree, factorGraph] = junctionTree.eliminate(function);
// If any factors are remaining, the ordering was incomplete
if(!factorGraph->empty()) {
throw InconsistentEliminationRequested(factorGraph->keys());
}
// Return the Bayes tree
return bayesTree;
}

@ -51,12 +51,12 @@ namespace gtsam {
* algorithms. Any factor graph holding eliminateable factors can derive from this class to
* expose functions for computing marginals, conditional marginals, doing multifrontal and
* sequential elimination, etc. */
template<class FACTOR_GRAPH>
class EliminateableFactorGraph
{
private:
typedef EliminateableFactorGraph<FACTOR_GRAPH> This; ///< Typedef to this class.
typedef FACTOR_GRAPH FactorGraphType; ///< Typedef to factor graph type
// Base factor type stored in this graph (private because derived classes will get this from
// their FactorGraph base class)
typedef typename EliminationTraits<FactorGraphType>::FactorType _FactorType;

@ -29,10 +29,6 @@
#include <Eigen/Core> // for Eigen::aligned_allocator
#ifdef GTSAM_USE_BOOST_FEATURES
#include <boost/assign/list_inserter.hpp>
#endif
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
#include <boost/serialization/nvp.hpp>
#include <boost/serialization/vector.hpp>
@ -53,45 +49,6 @@ class BayesTree;
class HybridValues;
/** Helper */
template <class C>
class CRefCallPushBack {
C& obj;
public:
explicit CRefCallPushBack(C& obj) : obj(obj) {}
template <typename A>
void operator()(const A& a) {
obj.push_back(a);
}
};
/** Helper */
template <class C>
class RefCallPushBack {
C& obj;
public:
explicit RefCallPushBack(C& obj) : obj(obj) {}
template <typename A>
void operator()(A& a) {
obj.push_back(a);
}
};
/** Helper */
template <class C>
class CRefCallAddCopy {
C& obj;
public:
explicit CRefCallAddCopy(C& obj) : obj(obj) {}
template <typename A>
void operator()(const A& a) {
obj.addCopy(a);
}
};
/**
* A factor graph is a bipartite graph with factor nodes connected to variable
* nodes. In this class, however, only factor nodes are kept around.
@ -215,17 +172,26 @@ class FactorGraph {
push_back(factor);
}
/// Append factor to factor graph
template <class DERIVEDFACTOR>
typename std::enable_if<std::is_base_of<FactorType, DERIVEDFACTOR>::value,
This>::type&
operator+=(std::shared_ptr<DERIVEDFACTOR> factor) {
push_back(factor);
return *this;
}
/**
* @brief Overload comma operator to allow for append chaining.
*
* E.g. fg += factor1, factor2, ...
*/
template <class DERIVEDFACTOR>
typename std::enable_if<std::is_base_of<FactorType, DERIVEDFACTOR>::value, This>::type& operator,(
std::shared_ptr<DERIVEDFACTOR> factor) {
push_back(factor);
return *this;
}
#endif
/// @}
/// @name Adding via iterators
@ -276,18 +242,15 @@ class FactorGraph {
push_back(factorOrContainer);
}
#ifdef GTSAM_USE_BOOST_FEATURES
/**
* Add a factor or container of factors, including STL collections,
* BayesTrees, etc.
*/
template <class FACTOR_OR_CONTAINER>
This& operator+=(const FACTOR_OR_CONTAINER& factorOrContainer) {
push_back(factorOrContainer);
return *this;
}
#endif
/// @}
/// @name Specialized versions

@ -281,6 +281,18 @@ void Ordering::print(const std::string& str,
cout.flush();
}
/* ************************************************************************* */
Ordering::This& Ordering::operator+=(Key key) {
this->push_back(key);
return *this;
}
/* ************************************************************************* */
Ordering::This& Ordering::operator,(Key key) {
this->push_back(key);
return *this;
}
/* ************************************************************************* */
Ordering::This& Ordering::operator+=(KeyVector& keys) {
this->insert(this->end(), keys.begin(), keys.end());

@ -25,10 +25,6 @@
#include <gtsam/inference/MetisIndex.h>
#include <gtsam/base/FastSet.h>
#ifdef GTSAM_USE_BOOST_FEATURES
#include <boost/assign/list_inserter.hpp>
#endif
#include <algorithm>
#include <vector>
@ -61,15 +57,13 @@ public:
Base(keys.begin(), keys.end()) {
}
/// Add new variables to the ordering as
/// `ordering += key1, key2, ...`.
This& operator+=(Key key);
/// Overloading the comma operator allows for chaining appends
// e.g. keys += key1, key2
This& operator,(Key key);
/**
* @brief Append new keys to the ordering as `ordering += keys`.

@ -104,6 +104,7 @@ class Ordering {
// Standard Constructors and Named Constructors
Ordering();
Ordering(const gtsam::Ordering& other);
Ordering(const std::vector<size_t>& keys);
template <
FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph,
@ -148,7 +149,7 @@ class Ordering {
// Standard interface
size_t size() const;
size_t at(size_t i) const;
void push_back(size_t key);
// enabling serialization functionality
@ -194,4 +195,15 @@ class VariableIndex {
size_t nEntries() const;
};
#include <gtsam/inference/Factor.h>
virtual class Factor {
void print(string s = "Factor\n", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void printKeys(string s = "") const;
bool equals(const gtsam::Factor& other, double tol = 1e-9) const;
bool empty() const;
size_t size() const;
gtsam::KeyVector keys() const;
};
} // namespace gtsam

@ -0,0 +1,60 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file inferenceExceptions.cpp
* @brief Exceptions that may be thrown by inference algorithms
* @author Richard Roberts, Varun Agrawal
* @date Apr 25, 2013
*/
#include <gtsam/inference/inferenceExceptions.h>
#include <sstream>
namespace gtsam {
InconsistentEliminationRequested::InconsistentEliminationRequested(
const KeySet& keys, const KeyFormatter& key_formatter)
: keys_(keys.begin(), keys.end()), keyFormatter(key_formatter) {}
const char* InconsistentEliminationRequested::what() const noexcept {
// Format keys for printing
std::stringstream sstr;
size_t nrKeysToDisplay = std::min(size_t(4), keys_.size());
for (size_t i = 0; i < nrKeysToDisplay; i++) {
sstr << keyFormatter(keys_.at(i));
if (i < nrKeysToDisplay - 1) {
sstr << ", ";
}
}
if (keys_.size() > nrKeysToDisplay) {
sstr << ", ... (total " << keys_.size() << " keys)";
}
sstr << ".";
std::string keys = sstr.str();
std::string msg =
"An inference algorithm was called with inconsistent "
"arguments. "
"The\n"
"factor graph, ordering, or variable index were "
"inconsistent with "
"each\n"
"other, or a full elimination routine was called with "
"an ordering "
"that\n"
"does not include all of the variables.\n";
msg += ("Leftover keys after elimination: " + keys);
// Allocate the message on the heap so the C string returned by what()
// remains valid after this function returns (a small, intentional leak).
return (new std::string(msg))->c_str();
}
} // namespace gtsam

@ -12,30 +12,35 @@
/**
* @file inferenceExceptions.h
* @brief Exceptions that may be thrown by inference algorithms
* @author Richard Roberts, Varun Agrawal
* @date Apr 25, 2013
*/
#pragma once
#include <gtsam/global_includes.h>
#include <gtsam/inference/Key.h>
#include <exception>
namespace gtsam {
/** An inference algorithm was called with inconsistent arguments. The factor
* graph, ordering, or variable index were inconsistent with each other, or a
* full elimination routine was called with an ordering that does not include
* all of the variables. */
class InconsistentEliminationRequested : public std::exception {
KeyVector keys_;
const KeyFormatter& keyFormatter = DefaultKeyFormatter;
const char* what() const noexcept override {
return
"An inference algorithm was called with inconsistent arguments. The\n"
"factor graph, ordering, or variable index were inconsistent with each\n"
"other, or a full elimination routine was called with an ordering that\n"
"does not include all of the variables.";
}
};
public:
InconsistentEliminationRequested() noexcept {}
InconsistentEliminationRequested(
const KeySet& keys,
const KeyFormatter& key_formatter = DefaultKeyFormatter);
~InconsistentEliminationRequested() noexcept override {}
const char* what() const noexcept override;
};
} // namespace gtsam

@ -196,6 +196,20 @@ TEST(Ordering, csr_format_3) {
EXPECT(adjExpected == adjAcutal);
}
/* ************************************************************************* */
TEST(Ordering, AppendKey) {
using symbol_shorthand::X;
Ordering actual;
actual += X(0);
Ordering expected1{X(0)};
EXPECT(assert_equal(expected1, actual));
actual += X(1), X(2), X(3);
Ordering expected2{X(0), X(1), X(2), X(3)};
EXPECT(assert_equal(expected2, actual));
}
/* ************************************************************************* */
TEST(Ordering, AppendVector) {
using symbol_shorthand::X;

@ -99,7 +99,7 @@ namespace gtsam {
/* ************************************************************************ */
void GaussianConditional::print(const string &s, const KeyFormatter& formatter) const {
cout << (s.empty() ? "" : s + " ") << "p(";
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
cout << formatter(*it) << (nrFrontals() > 1 ? " " : "");
}

@ -261,8 +261,7 @@ class VectorValues {
};
#include <gtsam/linear/GaussianFactor.h>
virtual class GaussianFactor : gtsam::Factor {
gtsam::KeyVector keys() const;
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::GaussianFactor& lf, double tol) const;
@ -273,8 +272,6 @@ virtual class GaussianFactor {
Matrix information() const;
Matrix augmentedJacobian() const;
pair<Matrix, Vector> jacobian() const;
size_t size() const;
bool empty() const;
};
#include <gtsam/linear/JacobianFactor.h>
@ -301,10 +298,7 @@ virtual class JacobianFactor : gtsam::GaussianFactor {
//Testable
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void printKeys(string s) const;
gtsam::KeyVector& keys() const;
bool equals(const gtsam::GaussianFactor& lf, double tol) const;
size_t size() const;
Vector unweighted_error(const gtsam::VectorValues& c) const;
Vector error_vector(const gtsam::VectorValues& c) const;
double error(const gtsam::VectorValues& c) const;
@ -346,10 +340,8 @@ virtual class HessianFactor : gtsam::GaussianFactor {
HessianFactor(const gtsam::GaussianFactorGraph& factors);
//Testable
size_t size() const;
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void printKeys(string s) const;
bool equals(const gtsam::GaussianFactor& lf, double tol) const;
double error(const gtsam::VectorValues& c) const;

View File

@ -21,6 +21,7 @@
#include <gtsam/linear/GaussianFactorGraph.h>
#include <gtsam/linear/GaussianConditional.h>
#include <gtsam/linear/GaussianBayesNet.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/inference/VariableSlots.h>
#include <gtsam/inference/VariableIndex.h>
#include <gtsam/base/debug.h>
@ -70,6 +71,28 @@ TEST(GaussianFactorGraph, initialization) {
EQUALITY(expectedIJS, actualIJS);
}
/* ************************************************************************* */
TEST(GaussianFactorGraph, Append) {
// Create empty graph
GaussianFactorGraph fg;
SharedDiagonal unit2 = noiseModel::Unit::Create(2);
auto f1 =
make_shared<JacobianFactor>(0, 10 * I_2x2, -1.0 * Vector::Ones(2), unit2);
auto f2 = make_shared<JacobianFactor>(0, -10 * I_2x2, 1, 10 * I_2x2,
Vector2(2.0, -1.0), unit2);
auto f3 = make_shared<JacobianFactor>(0, -5 * I_2x2, 2, 5 * I_2x2,
Vector2(0.0, 1.0), unit2);
fg += f1;
fg += f2;
EXPECT_LONGS_EQUAL(2, fg.size());
fg = GaussianFactorGraph();
fg += f1, f2, f3;
EXPECT_LONGS_EQUAL(3, fg.size());
}
/* ************************************************************************* */
TEST(GaussianFactorGraph, sparseJacobian) {
// Create factor graph:
@ -435,6 +458,64 @@ TEST(GaussianFactorGraph, ProbPrime) {
EXPECT_DOUBLES_EQUAL(expected, gfg.probPrime(values), 1e-12);
}
TEST(GaussianFactorGraph, InconsistentEliminationMessage) {
// Create empty graph
GaussianFactorGraph fg;
SharedDiagonal unit2 = noiseModel::Unit::Create(2);
using gtsam::symbol_shorthand::X;
fg.emplace_shared<JacobianFactor>(0, 10 * I_2x2, -1.0 * Vector::Ones(2),
unit2);
fg.emplace_shared<JacobianFactor>(0, -10 * I_2x2, 1, 10 * I_2x2,
Vector2(2.0, -1.0), unit2);
fg.emplace_shared<JacobianFactor>(1, -5 * I_2x2, 2, 5 * I_2x2,
Vector2(-1.0, 1.5), unit2);
fg.emplace_shared<JacobianFactor>(2, -5 * I_2x2, X(3), 5 * I_2x2,
Vector2(-1.0, 1.5), unit2);
Ordering ordering{0, 1};
try {
fg.eliminateSequential(ordering);
} catch (const exception& exc) {
std::string expected_exception_message = "An inference algorithm was called with inconsistent "
"arguments. "
"The\n"
"factor graph, ordering, or variable index were "
"inconsistent with "
"each\n"
"other, or a full elimination routine was called with "
"an ordering "
"that\n"
"does not include all of the variables.\n"
"Leftover keys after elimination: 2, x3.";
EXPECT(expected_exception_message == exc.what());
}
// Test large number of keys
fg = GaussianFactorGraph();
for (size_t i = 0; i < 1000; i++) {
fg.emplace_shared<JacobianFactor>(i, -I_2x2, i + 1, I_2x2,
Vector2(2.0, -1.0), unit2);
}
try {
fg.eliminateSequential(ordering);
} catch (const exception& exc) {
std::string expected_exception_message = "An inference algorithm was called with inconsistent "
"arguments. "
"The\n"
"factor graph, ordering, or variable index were "
"inconsistent with "
"each\n"
"other, or a full elimination routine was called with "
"an ordering "
"that\n"
"does not include all of the variables.\n"
"Leftover keys after elimination: 2, 3, 4, 5, ... (total 999 keys).";
EXPECT(expected_exception_message == exc.what());
}
}
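
Note for reviewers: the richer message is also what a wrapper user would see. A minimal sketch, assuming the C++ InconsistentEliminationRequested is surfaced as a RuntimeError by the Python bindings (the toy factors, keys, and variable names below are illustrative, not part of this change):

# Illustrative sketch only; assumes the C++ exception maps to RuntimeError.
import numpy as np
import gtsam

fg = gtsam.GaussianFactorGraph()
unit2 = gtsam.noiseModel.Unit.Create(2)
fg.add(gtsam.JacobianFactor(0, np.eye(2), np.zeros(2), unit2))
fg.add(gtsam.JacobianFactor(1, np.eye(2), np.zeros(2), unit2))
try:
    # Key 1 is deliberately missing from the ordering.
    fg.eliminateSequential(gtsam.Ordering(keys=[0]))
except RuntimeError as e:
    print(e)  # message now ends with "Leftover keys after elimination: ..."
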
/* ************************************************************************* */
int main() {
TestResult tr;

View File

@ -109,13 +109,10 @@ class NonlinearFactorGraph {
};
#include <gtsam/nonlinear/NonlinearFactor.h>
virtual class NonlinearFactor {
virtual class NonlinearFactor : gtsam::Factor {
// Factor base class
size_t size() const;
gtsam::KeyVector keys() const;
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void printKeys(string s) const;
// NonlinearFactor
bool equals(const gtsam::NonlinearFactor& other, double tol) const;
double error(const gtsam::Values& c) const;

View File

@ -894,6 +894,9 @@ template <size_t d>
std::pair<Values, double> ShonanAveraging<d>::run(const Values &initialEstimate,
size_t pMin,
size_t pMax) const {
if (pMin < d) {
throw std::runtime_error("pMin is smaller than the base dimension d");
}
Values Qstar;
Values initialSOp = LiftTo<Rot>(pMin, initialEstimate);  // lift to pMin!
for (size_t p = pMin; p <= pMax; p++) {
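
Note for reviewers: a minimal sketch of what the new guard means for callers, assuming the std::runtime_error surfaces as a RuntimeError through the Python bindings (the two-measurement setup is illustrative):

# Illustrative sketch only; run() now rejects pMin below the base dimension d.
import math
import gtsam
from gtsam import BinaryMeasurementRot3, Rot3, ShonanAveraging3

unit3 = gtsam.noiseModel.Unit.Create(3)
measurements = [
    BinaryMeasurementRot3(0, 1, Rot3.Yaw(math.radians(90)), unit3),
    BinaryMeasurementRot3(1, 2, Rot3.Yaw(math.radians(90)), unit3),
]
shonan = ShonanAveraging3(measurements)
initial = shonan.initializeRandomly()
try:
    shonan.run(initial, min_p=2, max_p=5)  # pMin == 2 < d == 3
except RuntimeError as e:
    print(e)  # "pMin is smaller than the base dimension d"
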

View File

@ -415,6 +415,20 @@ TEST(ShonanAveraging3, PriorWeights) {
auto result = shonan.run(initial, 3, 3);
EXPECT_DOUBLES_EQUAL(0.0015, shonan.cost(result.first), 1e-4);
}
/* ************************************************************************* */
// Check a small graph created using binary measurements
TEST(ShonanAveraging3, BinaryMeasurements) {
std::vector<BinaryMeasurement<Rot3>> measurements;
auto unit3 = noiseModel::Unit::Create(3);
measurements.emplace_back(0, 1, Rot3::Yaw(M_PI_2), unit3);
measurements.emplace_back(1, 2, Rot3::Yaw(M_PI_2), unit3);
ShonanAveraging3 shonan(measurements);
Values initial = shonan.initializeRandomly();
auto result = shonan.run(initial, 3, 5);
EXPECT_DOUBLES_EQUAL(0.0, shonan.cost(result.first), 1e-4);
}
/* ************************************************************************* */
int main() {
TestResult tr;

View File

@ -65,4 +65,6 @@ namespace gtsam {
SymbolicJunctionTree(const SymbolicEliminationTree& eliminationTree);
};
/// typedef for wrapper:
using SymbolicCluster = SymbolicJunctionTree::Cluster;
}

View File

@ -4,7 +4,7 @@
namespace gtsam {
#include <gtsam/symbolic/SymbolicFactor.h>
virtual class SymbolicFactor {
virtual class SymbolicFactor : gtsam::Factor {
// Standard Constructors and Named Constructors
SymbolicFactor(const gtsam::SymbolicFactor& f);
SymbolicFactor();
@ -18,12 +18,10 @@ virtual class SymbolicFactor {
static gtsam::SymbolicFactor FromKeys(const gtsam::KeyVector& js);
// From Factor
size_t size() const;
void print(string s = "SymbolicFactor", void print(string s = "SymbolicFactor",
const gtsam::KeyFormatter& keyFormatter = const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::SymbolicFactor& other, double tol) const; bool equals(const gtsam::SymbolicFactor& other, double tol) const;
gtsam::KeyVector keys();
};
#include <gtsam/symbolic/SymbolicFactorGraph.h>
@ -139,7 +137,60 @@ class SymbolicBayesNet {
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
};
#include <gtsam/symbolic/SymbolicEliminationTree.h>
class SymbolicEliminationTree {
SymbolicEliminationTree(const gtsam::SymbolicFactorGraph& factorGraph,
const gtsam::VariableIndex& structure,
const gtsam::Ordering& order);
SymbolicEliminationTree(const gtsam::SymbolicFactorGraph& factorGraph,
const gtsam::Ordering& order);
void print(
string name = "EliminationTree: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::SymbolicEliminationTree& other,
double tol = 1e-9) const;
};
#include <gtsam/symbolic/SymbolicJunctionTree.h>
class SymbolicCluster {
gtsam::Ordering orderedFrontalKeys;
gtsam::SymbolicFactorGraph factors;
const gtsam::SymbolicCluster& operator[](size_t i) const;
size_t nrChildren() const;
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};
class SymbolicJunctionTree {
SymbolicJunctionTree(const gtsam::SymbolicEliminationTree& eliminationTree);
void print(
string name = "JunctionTree: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
size_t nrRoots() const;
const gtsam::SymbolicCluster& operator[](size_t i) const;
};
#include <gtsam/symbolic/SymbolicBayesTree.h>
class SymbolicBayesTreeClique {
SymbolicBayesTreeClique();
SymbolicBayesTreeClique(const gtsam::SymbolicConditional* conditional);
bool equals(const gtsam::SymbolicBayesTreeClique& other, double tol) const;
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter);
const gtsam::SymbolicConditional* conditional() const;
bool isRoot() const;
gtsam::SymbolicBayesTreeClique* parent() const;
size_t treeSize() const;
size_t numCachedSeparatorMarginals() const;
void deleteCachedShortcuts();
};
class SymbolicBayesTree {
// Constructors
SymbolicBayesTree();
@ -151,9 +202,14 @@ class SymbolicBayesTree {
bool equals(const gtsam::SymbolicBayesTree& other, double tol) const;
// Standard Interface
// size_t findParentClique(const gtsam::IndexVector& parents) const;
size_t size();
bool empty() const;
size_t size() const;
void saveGraph(string s) const;
const gtsam::SymbolicBayesTreeClique* operator[](size_t j) const;
void saveGraph(string s,
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void clear();
void deleteCachedShortcuts();
size_t numCachedSeparatorMarginals() const;
@ -161,28 +217,9 @@ class SymbolicBayesTree {
gtsam::SymbolicConditional* marginalFactor(size_t key) const;
gtsam::SymbolicFactorGraph* joint(size_t key1, size_t key2) const;
gtsam::SymbolicBayesNet* jointBayesNet(size_t key1, size_t key2) const;
string dot(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};
class SymbolicBayesTreeClique {
SymbolicBayesTreeClique();
// SymbolicBayesTreeClique(gtsam::sharedConditional* conditional);
bool equals(const gtsam::SymbolicBayesTreeClique& other, double tol) const;
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
size_t numCachedSeparatorMarginals() const;
// gtsam::sharedConditional* conditional() const;
bool isRoot() const;
size_t treeSize() const;
gtsam::SymbolicBayesTreeClique* parent() const;
// // TODO: need wrapped versions graphs, BayesNet
// BayesNet<ConditionalType> shortcut(derived_ptr root, Eliminate function)
// const; FactorGraph<FactorType> marginal(derived_ptr root, Eliminate
// function) const; FactorGraph<FactorType> joint(derived_ptr C2, derived_ptr
// root, Eliminate function) const;
//
void deleteCachedShortcuts();
};
} // namespace gtsam
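
Note for reviewers: with these declarations the symbolic elimination and junction trees become constructible from Python. A minimal sketch, under the assumption that the wrapped operator[] is exposed as __getitem__ in the generated bindings (the toy graph and key values below are illustrative):

# Illustrative sketch only; exercises the newly wrapped symbolic tree classes.
import gtsam

graph = gtsam.SymbolicFactorGraph()
graph.push_factor(0, 1)
graph.push_factor(1, 2)
graph.push_factor(2, 3)

ordering = gtsam.Ordering(keys=[0, 1, 2, 3])
etree = gtsam.SymbolicEliminationTree(graph, ordering)
jtree = gtsam.SymbolicJunctionTree(etree)
jtree.print("JunctionTree: ")

for i in range(jtree.nrRoots()):
    cluster = jtree[i]   # a SymbolicCluster
    cluster.print()      # frontal keys and the factors collected in this cluster
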

View File

@ -181,7 +181,7 @@ TEST(QPSolver, iterate) {
QPSolver::State state(currentSolution, VectorValues(), workingSet, false,
100);
int it = 0;
// int it = 0;
while (!state.converged) {
state = solver.iterate(state);
// These checks will fail because the expected solutions obtained from
@ -190,7 +190,7 @@ TEST(QPSolver, iterate) {
// do not recompute dual variables after every step!!!
// CHECK(assert_equal(expected[it], state.values, 1e-10));
// CHECK(assert_equal(expectedDuals[it], state.duals, 1e-10));
it++;
// it++;
}
CHECK(assert_equal(expected[3], state.values, 1e-10));

View File

@ -26,7 +26,13 @@ class TestDecisionTreeFactor(GtsamTestCase):
self.B = (5, 2)
self.factor = DecisionTreeFactor([self.A, self.B], "1 2 3 4 5 6")
def test_from_floats(self):
"""Test whether we can construct a factor from floats."""
actual = DecisionTreeFactor([self.A, self.B], [1., 2., 3., 4., 5., 6.])
self.gtsamAssertEquals(actual, self.factor)
def test_enumerate(self):
"""Test whether we can enumerate the factor."""
actual = self.factor.enumerate()
_, values = zip(*actual)
self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])

View File

@ -13,10 +13,15 @@ Author: Frank Dellaert
import unittest
from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique,
DiscreteConditional, DiscreteFactorGraph, Ordering)
import numpy as np
from gtsam.symbol_shorthand import A, X
from gtsam.utils.test_case import GtsamTestCase
import gtsam
from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique,
DiscreteConditional, DiscreteFactorGraph,
DiscreteValues, Ordering)
class TestDiscreteBayesNet(GtsamTestCase):
"""Tests for Discrete Bayes Nets."""
@ -27,7 +32,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
# Define DiscreteKey pairs.
keys = [(j, 2) for j in range(15)]
# Create thin-tree Bayesnet.
# Create thin-tree Bayes net.
bayesNet = DiscreteBayesNet()
bayesNet.add(keys[0], [keys[8], keys[12]], "2/3 1/4 3/2 4/1")
@ -65,15 +70,91 @@ class TestDiscreteBayesNet(GtsamTestCase):
# bayesTree[key].printSignature()
# bayesTree.saveGraph("test_DiscreteBayesTree.dot")
self.assertFalse(bayesTree.empty())
self.assertEqual(12, bayesTree.size())
# The root is P( 8 12 14), we can retrieve it by key:
root = bayesTree[8]
self.assertIsInstance(root, DiscreteBayesTreeClique)
self.assertTrue(root.isRoot())
self.assertIsInstance(root.conditional(), DiscreteConditional)
# Test all methods in DiscreteBayesTree
self.gtsamAssertEquals(bayesTree, bayesTree)
# Check value at 0
zero_values = DiscreteValues()
for j in range(15):
zero_values[j] = 0
value_at_zeros = bayesTree.evaluate(zero_values)
self.assertAlmostEqual(value_at_zeros, 0.0)
# Check value at max
values_star = factorGraph.optimize()
max_value = bayesTree.evaluate(values_star)
self.assertAlmostEqual(max_value, 0.002548)
# Check operator sugar
max_value = bayesTree(values_star)
self.assertAlmostEqual(max_value, 0.002548)
self.assertFalse(bayesTree.empty())
self.assertEqual(12, bayesTree.size())
def test_discrete_bayes_tree_lookup(self):
"""Check that we can have a multi-frontal lookup table."""
# Make a small planning-like graph: 3 states, 2 actions
graph = DiscreteFactorGraph()
x1, x2, x3 = (X(1), 3), (X(2), 3), (X(3), 3)
a1, a2 = (A(1), 2), (A(2), 2)
# Constraint on start and goal
graph.add([x1], np.array([1, 0, 0]))
graph.add([x3], np.array([0, 0, 1]))
# Should I stay or should I go?
# "Reward" (exp(-cost)) for an action is 10, and rewards multiply:
r = 10
table = np.array([
r, 0, 0, 0, r, 0, # x1 = 0
0, r, 0, 0, 0, r, # x1 = 1
0, 0, r, 0, 0, r # x1 = 2
])
graph.add([x1, a1, x2], table)
graph.add([x2, a2, x3], table)
# Eliminate for MPE (maximum probable explanation).
ordering = Ordering(keys=[A(2), X(3), X(1), A(1), X(2)])
lookup = graph.eliminateMultifrontal(ordering, gtsam.EliminateForMPE)
# Check that the lookup table is correct
assert lookup.size() == 2
lookup_x1_a1_x2 = lookup[X(1)].conditional()
assert lookup_x1_a1_x2.nrFrontals() == 3
# Check that sum is 1.0 (not 100, as we now normalize to prevent underflow)
empty = gtsam.DiscreteValues()
self.assertAlmostEqual(lookup_x1_a1_x2.sum(3)(empty), 1.0)
# And that only non-zero reward is for x1 a1 x2 == 0 1 1
values = DiscreteValues()
values[X(1)] = 0
values[A(1)] = 1
values[X(2)] = 1
self.assertAlmostEqual(lookup_x1_a1_x2(values), 1.0)
lookup_a2_x3 = lookup[X(3)].conditional()
# Check that the sum depends on x2 and is non-zero only for x2 in {1, 2}
sum_x2 = lookup_a2_x3.sum(2)
values = DiscreteValues()
values[X(2)] = 0
self.assertAlmostEqual(sum_x2(values), 0)
values[X(2)] = 1
self.assertAlmostEqual(sum_x2(values), 1.0) # not 10, as we normalize
values[X(2)] = 2
self.assertAlmostEqual(sum_x2(values), 2.0) # not 20, as we normalize
assert lookup_a2_x3.nrFrontals() == 2
# And that the non-zero rewards are for x2 a2 x3 == 1 1 2
values = DiscreteValues()
values[X(2)] = 1
values[A(2)] = 1
values[X(3)] = 2
self.assertAlmostEqual(lookup_a2_x3(values), 1.0) # not 10...
if __name__ == "__main__":
unittest.main()

View File

@ -10,14 +10,16 @@ Author: Frank Dellaert
""" """
# pylint: disable=invalid-name, no-name-in-module, no-member # pylint: disable=invalid-name, no-name-in-module, no-member
import math
import unittest
import numpy as np
from gtsam.utils.test_case import GtsamTestCase
import gtsam
from gtsam import (BetweenFactorPose2, LevenbergMarquardtParams, Pose2, Rot2,
ShonanAveraging2, ShonanAveraging3,
ShonanAveragingParameters2, ShonanAveragingParameters3)
from gtsam import (BetweenFactorPose2, BetweenFactorPose3,
BinaryMeasurementRot3, LevenbergMarquardtParams, Pose2,
Pose3, Rot2, Rot3, ShonanAveraging2, ShonanAveraging3,
ShonanAveragingParameters2, ShonanAveragingParameters3)
DEFAULT_PARAMS = ShonanAveragingParameters3(
@ -197,6 +199,19 @@ class TestShonanAveraging(GtsamTestCase):
expected_thetas_deg = np.array([0.0, 90.0, 0.0])
np.testing.assert_allclose(thetas_deg, expected_thetas_deg, atol=0.1)
def test_measurements3(self):
"""Create from Measurements."""
measurements = []
unit3 = gtsam.noiseModel.Unit.Create(3)
m01 = BinaryMeasurementRot3(0, 1, Rot3.Yaw(math.radians(90)), unit3)
m12 = BinaryMeasurementRot3(1, 2, Rot3.Yaw(math.radians(90)), unit3)
measurements.append(m01)
measurements.append(m12)
obj = ShonanAveraging3(measurements)
self.assertIsInstance(obj, ShonanAveraging3)
initial = obj.initializeRandomly()
_, cost = obj.run(initial, min_p=3, max_p=5)
self.assertAlmostEqual(cost, 0)
if __name__ == "__main__":
unittest.main()

View File

@ -84,7 +84,7 @@ class TestVisualISAMExample(GtsamTestCase):
values.insert(key, v)
self.assertAlmostEqual(isam.error(values), 34212421.14731998)
self.assertAlmostEqual(isam.error(values), 34212421.14732)
def test_isam2_update(self):
"""